import asyncio
from collections import defaultdict
from datetime import time, datetime
from collections.abc import Callable, Sequence, Awaitable

from sqlalchemy.orm import selectinload
from sqlalchemy.exc import IntegrityError
from sqlalchemy import func, delete, select
from nonebot_plugin_saa import PlatformTarget
from nonebot_plugin_datastore import create_session

from ..types import Tag
from ..types import Target as T_Target
from .utils import NoSuchTargetException
from .db_model import User, Target, Subscribe, ScheduleTimeWeight
from ..types import Category, UserSubInfo, WeightConfig, TimeWeightConfig, PlatformWeightConfigResp


def _get_time():
    """Return the current local wall-clock time, truncated to whole seconds."""
    dt = datetime.now()
    cur_time = time(hour=dt.hour, minute=dt.minute, second=dt.second)
    return cur_time


class SubscribeDupException(Exception):
    """Raised by ``DBConfig.add_subscribe`` when the subscription already exists."""


class DBConfig:
    """Async data-access layer for users, targets, subscriptions and schedule weights.

    Callers can register coroutine hooks that receive ``(platform_name, target)``:

    * add-target hooks are awaited when a subscription is added for a target that
      has no ``Target`` row yet;
    * delete-target hooks are awaited when deleting a subscription leaves the
      target with zero subscriptions.
    """

    def __init__(self):
        self.add_target_hook: list[Callable[[str, T_Target], Awaitable]] = []
        self.delete_target_hook: list[Callable[[str, T_Target], Awaitable]] = []

    def register_add_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
        """Register a coroutine function called as ``fun(platform_name, target)`` for new targets."""
        self.add_target_hook.append(fun)

    def register_delete_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
        """Register a coroutine function called as ``fun(platform_name, target)`` for emptied targets."""
        self.delete_target_hook.append(fun)

    async def add_subscribe(
        self,
        user: PlatformTarget,
        target: T_Target,
        target_name: str,
        platform_name: str,
        cats: list[Category],
        tags: list[Tag],
    ):
        """Subscribe ``user`` to ``target`` on ``platform_name``.

        Creates the ``User`` and ``Target`` rows when they do not exist yet. For an
        existing target, its stored ``target_name`` is refreshed. Add-target hooks
        are awaited only when a new ``Target`` row is created.

        Raises:
            SubscribeDupException: the commit failed with a UNIQUE-constraint error,
                i.e. this subscription already exists.
            IntegrityError: any other integrity violation raised on commit.
        """
        async with create_session() as session:
            db_user_stmt = select(User).where(User.user_target == user.dict())
            db_user: User | None = await session.scalar(db_user_stmt)
            if not db_user:
                db_user = User(user_target=user.dict())
                session.add(db_user)
            db_target_stmt = select(Target).where(Target.platform_name == platform_name).where(Target.target == target)
            db_target: Target | None = await session.scalar(db_target_stmt)
            if not db_target:
                # The new Target is not session.add()-ed explicitly; it is presumably
                # persisted through the Subscribe.target relationship below.
                db_target = Target(target=target, platform_name=platform_name, target_name=target_name)
                # NOTE: hooks run before the commit, so they also run if the commit fails.
                await asyncio.gather(*[hook(platform_name, target) for hook in self.add_target_hook])
            else:
                db_target.target_name = target_name
            subscribe = Subscribe(
                categories=cats,
                tags=tags,
                user=db_user,
                target=db_target,
            )
            session.add(subscribe)
            try:
                await session.commit()
            except IntegrityError as e:
                # The message check matches SQLite's wording for unique violations.
                if len(e.args) > 0 and "UNIQUE constraint failed" in e.args[0]:
                    raise SubscribeDupException() from e
                raise

    async def list_subscribe(self, user: PlatformTarget) -> Sequence[Subscribe]:
        """Return all subscriptions of ``user``, with ``Subscribe.target`` eagerly loaded."""
        async with create_session() as session:
            query_stmt = (
                select(Subscribe)
                .where(User.user_target == user.dict())
                .join(User)
                .options(selectinload(Subscribe.target))
            )
            subs = (await session.scalars(query_stmt)).all()
            return subs

    async def list_subs_with_all_info(self) -> Sequence[Subscribe]:
        """Fetch every subscription in the database with its user and target eagerly loaded."""
        async with create_session() as session:
            query_stmt = (
                select(Subscribe).join(User).options(selectinload(Subscribe.target), selectinload(Subscribe.user))
            )
            subs = (await session.scalars(query_stmt)).all()
            return subs

    async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str):
        """Delete ``user``'s subscription to ``target`` on ``platform_name``.

        If the target has no subscriptions left afterwards, the delete-target hooks
        are awaited. The ``Target`` row itself is kept. Does nothing when the user
        or the target does not exist.
        """
        async with create_session() as session:
            user_obj = await session.scalar(select(User).where(User.user_target == user.dict()))
            target_obj = await session.scalar(
                select(Target).where(Target.platform_name == platform_name, Target.target == target)
            )
            if user_obj is None or target_obj is None:
                # Comparing a relationship with None compiles to "IS NULL", which would
                # count 0 subscribers and fire delete hooks for a target that was never
                # stored. There is nothing to delete, so stop here.
                return
            await session.execute(delete(Subscribe).where(Subscribe.user == user_obj, Subscribe.target == target_obj))
            target_count = await session.scalar(
                select(func.count()).select_from(Subscribe).where(Subscribe.target == target_obj)
            )
            if target_count == 0:
                # The target no longer has any subscriber.
                await asyncio.gather(*[hook(platform_name, T_Target(target)) for hook in self.delete_target_hook])
            await session.commit()

    async def update_subscribe(
        self,
        user: PlatformTarget,
        target: str,
        target_name: str,
        platform_name: str,
        cats: list,
        tags: list,
    ):
        """Replace the categories/tags of an existing subscription and rename its target.

        NOTE(review): if no matching subscription exists, ``subscribe_obj`` is None
        and the attribute assignment below raises ``AttributeError``.
        """
        async with create_session() as sess:
            subscribe_obj: Subscribe = await sess.scalar(
                select(Subscribe)
                .where(
                    User.user_target == user.dict(),
                    Target.target == target,
                    Target.platform_name == platform_name,
                )
                .join(User)
                .join(Target)
                .options(selectinload(Subscribe.target))  # type:ignore
            )
            subscribe_obj.tags = tags  # type:ignore
            subscribe_obj.categories = cats  # type:ignore
            subscribe_obj.target.target_name = target_name
            await sess.commit()

    async def get_platform_target(self, platform_name: str) -> Sequence[Target]:
        """Return the targets of ``platform_name`` that have at least one subscription."""
        async with create_session() as sess:
            subq = select(Subscribe.target_id).distinct().subquery()
            query = select(Target).join(subq).where(Target.platform_name == platform_name)
            return (await sess.scalars(query)).all()

    async def get_time_weight_config(self, target: T_Target, platform_name: str) -> WeightConfig:
        """Return the default weight and per-time-window weights of a target.

        Raises:
            NoSuchTargetException: the target does not exist.
        """
        async with create_session() as sess:
            targetObj = await sess.scalar(
                select(Target).where(Target.platform_name == platform_name, Target.target == target)
            )
            # Raise explicitly instead of `assert`, which is stripped under -O and
            # matches the error used by update_time_weight_config.
            if not targetObj:
                raise NoSuchTargetException()
            time_weight_conf = (
                await sess.scalars(
                    select(ScheduleTimeWeight)
                    .where(Target.platform_name == platform_name, Target.target == target)
                    .join(Target)
                )
            ).all()
            return WeightConfig(
                default=targetObj.default_schedule_weight,
                time_config=[
                    TimeWeightConfig(
                        start_time=time_conf.start_time,
                        end_time=time_conf.end_time,
                        weight=time_conf.weight,
                    )
                    for time_conf in time_weight_conf
                ],
            )

    async def update_time_weight_config(self, target: T_Target, platform_name: str, conf: WeightConfig):
        """Overwrite a target's default weight and replace all of its time-window weights.

        Raises:
            NoSuchTargetException: the target does not exist.
        """
        async with create_session() as sess:
            targetObj = await sess.scalar(
                select(Target).where(Target.platform_name == platform_name, Target.target == target)
            )
            if not targetObj:
                raise NoSuchTargetException()
            target_id = targetObj.id
            targetObj.default_schedule_weight = conf.default
            # Drop the old windows, then insert the new ones in the same transaction.
            delete_statement = delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id)
            await sess.execute(delete_statement)
            for time_conf in conf.time_config:
                new_conf = ScheduleTimeWeight(
                    start_time=time_conf.start_time,
                    end_time=time_conf.end_time,
                    weight=time_conf.weight,
                    target=targetObj,
                )
                sess.add(new_conf)
            await sess.commit()

    async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, int]:
        """Return the weight in effect right now for every target of the given platforms.

        Keys are ``"{platform_name}-{target}"``. A time window matches when
        ``start_time <= now < end_time``, so a window whose end is not after its
        start never matches. The first matching window wins; otherwise the
        target's default weight is used.
        """
        res = {}
        cur_time = _get_time()
        async with create_session() as sess:
            targets = (
                await sess.scalars(
                    select(Target)
                    .where(Target.platform_name.in_(platform_list))
                    .options(selectinload(Target.time_weight))
                )
            ).all()
            for target in targets:
                key = f"{target.platform_name}-{target.target}"
                weight = target.default_schedule_weight
                for time_conf in target.time_weight:
                    if time_conf.start_time <= cur_time and time_conf.end_time > cur_time:
                        weight = time_conf.weight
                        break
                res[key] = weight
        return res

    async def get_platform_target_subscribers(self, platform_name: str, target: T_Target) -> list[UserSubInfo]:
        """Return subscriber info (user, categories, tags) for one target."""
        async with create_session() as sess:
            query = (
                select(Subscribe)
                .join(Target)
                .where(Target.platform_name == platform_name, Target.target == target)
                .options(selectinload(Subscribe.user))
            )
            subscribes = (await sess.scalars(query)).all()
            return [
                UserSubInfo(
                    PlatformTarget.deserialize(subscribe.user.user_target),
                    subscribe.categories,
                    subscribe.tags,
                )
                for subscribe in subscribes
            ]

    async def get_all_weight_config(
        self,
    ) -> dict[str, dict[str, PlatformWeightConfigResp]]:
        """Return the weight config of every target, grouped as ``{platform: {target: resp}}``."""
        res: dict[str, dict[str, PlatformWeightConfigResp]] = defaultdict(dict)
        async with create_session() as sess:
            query = select(Target)
            targets = (await sess.scalars(query)).all()
            query = select(ScheduleTimeWeight).options(selectinload(ScheduleTimeWeight.target))
            time_weights = (await sess.scalars(query)).all()

            # Register every target. Previously only the first target per platform
            # was kept, which dropped entries and caused KeyError below.
            for target in targets:
                platform_name = target.platform_name
                res[platform_name][target.target] = PlatformWeightConfigResp(
                    target=T_Target(target.target),
                    target_name=target.target_name,
                    platform_name=platform_name,
                    weight=WeightConfig(default=target.default_schedule_weight, time_config=[]),
                )

            for time_weight_config in time_weights:
                platform_name = time_weight_config.target.platform_name
                target_key = time_weight_config.target.target
                res[platform_name][target_key].weight.time_config.append(
                    TimeWeightConfig(
                        start_time=time_weight_config.start_time,
                        end_time=time_weight_config.end_time,
                        weight=time_weight_config.weight,
                    )
                )
        return res


config = DBConfig()