import asyncio from collections import defaultdict from datetime import datetime, time from typing import Awaitable, Callable, Optional, Sequence from nonebot_plugin_datastore import create_session from sqlalchemy import delete, func, select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import selectinload from ..types import Category, PlatformWeightConfigResp, Tag from ..types import Target as T_Target from ..types import TimeWeightConfig from ..types import User as T_User from ..types import UserSubInfo, WeightConfig from .db_model import ScheduleTimeWeight, Subscribe, Target, User from .utils import NoSuchTargetException def _get_time(): dt = datetime.now() cur_time = time(hour=dt.hour, minute=dt.minute, second=dt.second) return cur_time class SubscribeDupException(Exception): ... class DBConfig: def __init__(self): self.add_target_hook: list[Callable[[str, T_Target], Awaitable]] = [] self.delete_target_hook: list[Callable[[str, T_Target], Awaitable]] = [] def register_add_target_hook(self, fun: Callable[[str, T_Target], Awaitable]): self.add_target_hook.append(fun) def register_delete_target_hook(self, fun: Callable[[str, T_Target], Awaitable]): self.delete_target_hook.append(fun) async def add_subscribe( self, user: int, user_type: str, target: T_Target, target_name: str, platform_name: str, cats: list[Category], tags: list[Tag], ): async with create_session() as session: db_user_stmt = ( select(User).where(User.uid == user).where(User.type == user_type) ) db_user: Optional[User] = await session.scalar(db_user_stmt) if not db_user: db_user = User(uid=user, type=user_type) session.add(db_user) db_target_stmt = ( select(Target) .where(Target.platform_name == platform_name) .where(Target.target == target) ) db_target: Optional[Target] = await session.scalar(db_target_stmt) if not db_target: db_target = Target( target=target, platform_name=platform_name, target_name=target_name ) await asyncio.gather( *[hook(platform_name, target) for hook in self.add_target_hook] ) else: db_target.target_name = target_name subscribe = Subscribe( categories=cats, tags=tags, user=db_user, target=db_target, ) session.add(subscribe) try: await session.commit() except IntegrityError as e: if len(e.args) > 0 and "UNIQUE constraint failed" in e.args[0]: raise SubscribeDupException() raise e async def list_subscribe(self, user: int, user_type: str) -> Sequence[Subscribe]: async with create_session() as session: query_stmt = ( select(Subscribe) .where(User.type == user_type, User.uid == user) .join(User) .options(selectinload(Subscribe.target)) ) subs = (await session.scalars(query_stmt)).all() return subs async def list_subs_with_all_info(self) -> Sequence[Subscribe]: """获取数据库中带有user、target信息的subscribe数据""" async with create_session() as session: query_stmt = ( select(Subscribe) .join(User) .options(selectinload(Subscribe.target), selectinload(Subscribe.user)) ) subs = (await session.scalars(query_stmt)).all() return subs async def del_subscribe( self, user: int, user_type: str, target: str, platform_name: str ): async with create_session() as session: user_obj = await session.scalar( select(User).where(User.uid == user, User.type == user_type) ) target_obj = await session.scalar( select(Target).where( Target.platform_name == platform_name, Target.target == target ) ) await session.execute( delete(Subscribe).where( Subscribe.user == user_obj, Subscribe.target == target_obj ) ) target_count = await session.scalar( select(func.count()) .select_from(Subscribe) .where(Subscribe.target == target_obj) ) if target_count == 0: # delete empty target await asyncio.gather( *[ hook(platform_name, T_Target(target)) for hook in self.delete_target_hook ] ) await session.commit() async def update_subscribe( self, user: int, user_type: str, target: str, target_name: str, platform_name: str, cats: list, tags: list, ): async with create_session() as sess: subscribe_obj: Subscribe = await sess.scalar( select(Subscribe) .where( User.uid == user, User.type == user_type, Target.target == target, Target.platform_name == platform_name, ) .join(User) .join(Target) .options(selectinload(Subscribe.target)) # type:ignore ) subscribe_obj.tags = tags # type:ignore subscribe_obj.categories = cats # type:ignore subscribe_obj.target.target_name = target_name await sess.commit() async def get_platform_target(self, platform_name: str) -> Sequence[Target]: async with create_session() as sess: subq = select(Subscribe.target_id).distinct().subquery() query = ( select(Target).join(subq).where(Target.platform_name == platform_name) ) return (await sess.scalars(query)).all() async def get_time_weight_config( self, target: T_Target, platform_name: str ) -> WeightConfig: async with create_session() as sess: time_weight_conf = ( await sess.scalars( select(ScheduleTimeWeight) .where( Target.platform_name == platform_name, Target.target == target ) .join(Target) ) ).all() targetObj = await sess.scalar( select(Target).where( Target.platform_name == platform_name, Target.target == target ) ) return WeightConfig( default=targetObj.default_schedule_weight, time_config=[ TimeWeightConfig( start_time=time_conf.start_time, end_time=time_conf.end_time, weight=time_conf.weight, ) for time_conf in time_weight_conf ], ) async def update_time_weight_config( self, target: T_Target, platform_name: str, conf: WeightConfig ): async with create_session() as sess: targetObj = await sess.scalar( select(Target).where( Target.platform_name == platform_name, Target.target == target ) ) if not targetObj: raise NoSuchTargetException() target_id = targetObj.id targetObj.default_schedule_weight = conf.default delete_statement = delete(ScheduleTimeWeight).where( ScheduleTimeWeight.target_id == target_id ) await sess.execute(delete_statement) for time_conf in conf.time_config: new_conf = ScheduleTimeWeight( start_time=time_conf.start_time, end_time=time_conf.end_time, weight=time_conf.weight, target=targetObj, ) sess.add(new_conf) await sess.commit() async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, int]: res = {} cur_time = _get_time() async with create_session() as sess: targets = ( await sess.scalars( select(Target) .where(Target.platform_name.in_(platform_list)) .options(selectinload(Target.time_weight)) ) ).all() for target in targets: key = f"{target.platform_name}-{target.target}" weight = target.default_schedule_weight for time_conf in target.time_weight: if ( time_conf.start_time <= cur_time and time_conf.end_time > cur_time ): weight = time_conf.weight break res[key] = weight return res async def get_platform_target_subscribers( self, platform_name: str, target: T_Target ) -> list[UserSubInfo]: async with create_session() as sess: query = ( select(Subscribe) .join(Target) .where(Target.platform_name == platform_name, Target.target == target) .options(selectinload(Subscribe.user)) ) subsribes = (await sess.scalars(query)).all() return list( map( lambda subscribe: UserSubInfo( T_User(subscribe.user.uid, subscribe.user.type), subscribe.categories, subscribe.tags, ), subsribes, ) ) async def get_all_weight_config( self, ) -> dict[str, dict[str, PlatformWeightConfigResp]]: res: dict[str, dict[str, PlatformWeightConfigResp]] = defaultdict(dict) async with create_session() as sess: query = select(Target) targets = (await sess.scalars(query)).all() query = select(ScheduleTimeWeight).options( selectinload(ScheduleTimeWeight.target) ) time_weights = (await sess.scalars(query)).all() for target in targets: platform_name = target.platform_name if platform_name not in res.keys(): res[platform_name][target.target] = PlatformWeightConfigResp( target=T_Target(target.target), target_name=target.target_name, platform_name=platform_name, weight=WeightConfig( default=target.default_schedule_weight, time_config=[] ), ) for time_weight_config in time_weights: platform_name = time_weight_config.target.platform_name target = time_weight_config.target.target res[platform_name][target].weight.time_config.append( TimeWeightConfig( start_time=time_weight_config.start_time, end_time=time_weight_config.end_time, weight=time_weight_config.weight, ) ) return res config = DBConfig()