From 68baa3b915be4df758548dc236b743f5a700dfe4 Mon Sep 17 00:00:00 2001 From: felinae98 <731499577@qq.com> Date: Wed, 1 Jun 2022 23:03:46 +0800 Subject: [PATCH] add scheduler config --- src/plugins/nonebot_bison/config/db_config.py | 146 ++++++++++++++---- .../nonebot_bison/scheduler/__init__.py | 1 - src/plugins/nonebot_bison/utils/__init__.py | 3 +- .../{scheduler => utils}/scheduler_config.py | 0 tests/config/test_scheduler_conf.py | 119 ++++++++++++++ 5 files changed, 240 insertions(+), 29 deletions(-) rename src/plugins/nonebot_bison/{scheduler => utils}/scheduler_config.py (100%) create mode 100644 tests/config/test_scheduler_conf.py diff --git a/src/plugins/nonebot_bison/config/db_config.py b/src/plugins/nonebot_bison/config/db_config.py index 8615bcf..3367734 100644 --- a/src/plugins/nonebot_bison/config/db_config.py +++ b/src/plugins/nonebot_bison/config/db_config.py @@ -1,15 +1,36 @@ +from dataclasses import dataclass +from datetime import datetime, time from typing import Optional -from nonebot_bison.types import Category, Tag, Target from nonebot_plugin_datastore.db import get_engine from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.orm import selectinload from sqlalchemy.sql.expression import delete, select from sqlalchemy.sql.functions import func -from .db_model import Subscribe as MSubscribe -from .db_model import Target as MTarget -from .db_model import User +from ..types import Category, Tag +from ..types import Target as T_Target +from .db_model import ScheduleTimeWeight, Subscribe, Target, User + + +def _get_time(): + dt = datetime.now() + cur_time = time(hour=dt.hour, minute=dt.minute, second=dt.second) + return cur_time + + +@dataclass +class TimeWeightConfig: + start_time: time + end_time: time + weight: int + + +@dataclass +class WeightConfig: + + default: int + time_config: list[TimeWeightConfig] class DBConfig: @@ -17,7 +38,7 @@ class DBConfig: self, user: int, user_type: str, - target: Target, + target: T_Target, target_name: str, platform_name: str, cats: list[Category], @@ -32,18 +53,18 @@ class DBConfig: db_user = User(uid=user, type=user_type) session.add(db_user) db_target_stmt = ( - select(MTarget) - .where(MTarget.platform_name == platform_name) - .where(MTarget.target == target) + select(Target) + .where(Target.platform_name == platform_name) + .where(Target.target == target) ) - db_target: Optional[MTarget] = await session.scalar(db_target_stmt) + db_target: Optional[Target] = await session.scalar(db_target_stmt) if not db_target: - db_target = MTarget( + db_target = Target( target=target, platform_name=platform_name, target_name=target_name ) else: db_target.target_name = target_name # type: ignore - subscribe = MSubscribe( + subscribe = Subscribe( categories=cats, tags=tags, user=db_user, @@ -52,15 +73,15 @@ class DBConfig: session.add(subscribe) await session.commit() - async def list_subscribe(self, user: int, user_type: str) -> list[MSubscribe]: + async def list_subscribe(self, user: int, user_type: str) -> list[Subscribe]: async with AsyncSession(get_engine()) as session: query_stmt = ( - select(MSubscribe) + select(Subscribe) .where(User.type == user_type, User.uid == user) .join(User) - .options(selectinload(MSubscribe.target)) # type:ignore + .options(selectinload(Subscribe.target)) # type:ignore ) - subs: list[MSubscribe] = (await session.scalars(query_stmt)).all() + subs: list[Subscribe] = (await session.scalars(query_stmt)).all() return subs async def del_subscribe( @@ -71,19 +92,19 @@ class DBConfig: select(User).where(User.uid == user, User.type == user_type) ) target_obj = await session.scalar( - select(MTarget).where( - MTarget.platform_name == platform_name, MTarget.target == target + select(Target).where( + Target.platform_name == platform_name, MTarget.target == target ) ) await session.execute( - delete(MSubscribe).where( - MSubscribe.user == user_obj, MSubscribe.target == target_obj + delete(Subscribe).where( + Subscribe.user == user_obj, MSubscribe.target == target_obj ) ) target_count = await session.scalar( select(func.count()) - .select_from(MSubscribe) - .where(MSubscribe.target == target_obj) + .select_from(Subscribe) + .where(Subscribe.target == target_obj) ) if target_count == 0: # delete empty target @@ -101,22 +122,93 @@ class DBConfig: tags: list, ): async with AsyncSession(get_engine()) as sess: - subscribe_obj: MSubscribe = await sess.scalar( - select(MSubscribe) + subscribe_obj: Subscribe = await sess.scalar( + select(Subscribe) .where( User.uid == user, User.type == user_type, - MTarget.target == target, - MTarget.platform_name == platform_name, + Target.target == target, + Target.platform_name == platform_name, ) .join(User) - .join(MTarget) - .options(selectinload(MSubscribe.target)) # type:ignore + .join(Target) + .options(selectinload(Subscribe.target)) # type:ignore ) subscribe_obj.tags = tags # type:ignore subscribe_obj.categories = cats # type:ignore subscribe_obj.target.target_name = target_name await sess.commit() + async def get_time_weight_config( + self, target: T_Target, platform_name: str + ) -> WeightConfig: + async with AsyncSession(get_engine()) as sess: + time_weight_conf: list[ScheduleTimeWeight] = await sess.scalars( + select(ScheduleTimeWeight) + .where(Target.platform_name == platform_name, Target.target == target) + .join(Target) + ) + targetObj: Target = await sess.scalar( + select(Target).where( + Target.platform_name == platform_name, Target.target == target + ) + ) + return WeightConfig( + default=targetObj.default_schedule_weight, + time_config=[ + TimeWeightConfig( + start_time=time_conf.start_time, + end_time=time_conf.end_time, + weight=time_conf.weight, + ) + for time_conf in time_weight_conf + ], + ) + + async def update_time_weight_config( + self, target: T_Target, platform_name: str, conf: WeightConfig + ): + async with AsyncSession(get_engine()) as sess: + targetObj: Target = await sess.scalar( + select(Target).where( + Target.platform_name == platform_name, Target.target == target + ) + ) + target_id = targetObj.id + targetObj.default_schedule_weight = conf.default + delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id) + for time_conf in conf.time_config: + new_conf = ScheduleTimeWeight( + start_time=time_conf.start_time, + end_time=time_conf.end_time, + weight=time_conf.weight, + target=targetObj, + ) + sess.add(new_conf) + + await sess.commit() + + async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, int]: + res = {} + cur_time = _get_time() + async with AsyncSession(get_engine()) as sess: + targets: list[Target] = await sess.scalars( + select(Target) + .where(Target.platform_name.in_(platform_list)) + .options(selectinload(Target.time_weight)) + ) + for target in targets: + key = f"{target.platform_name}-{target.target}" + weight = target.default_schedule_weight + for time_conf in target.time_weight: + if ( + time_conf.start_time <= cur_time + and time_conf.end_time > cur_time + ): + weight = time_conf.weight + break + res[key] = weight + return res + config = DBConfig() diff --git a/src/plugins/nonebot_bison/scheduler/__init__.py b/src/plugins/nonebot_bison/scheduler/__init__.py index 46bfa64..e69de29 100644 --- a/src/plugins/nonebot_bison/scheduler/__init__.py +++ b/src/plugins/nonebot_bison/scheduler/__init__.py @@ -1 +0,0 @@ -from .scheduler_config import SchedulerConfig diff --git a/src/plugins/nonebot_bison/utils/__init__.py b/src/plugins/nonebot_bison/utils/__init__.py index 73c8baa..56b61e7 100644 --- a/src/plugins/nonebot_bison/utils/__init__.py +++ b/src/plugins/nonebot_bison/utils/__init__.py @@ -10,8 +10,9 @@ from nonebot.plugin import require from ..plugin_config import plugin_config from .http import http_client +from .scheduler_config import SchedulerConfig -__all__ = ["http_client", "Singleton", "parse_text", "html_to_text"] +__all__ = ["http_client", "Singleton", "parse_text", "html_to_text", "SchedulerConfig"] class Singleton(type): diff --git a/src/plugins/nonebot_bison/scheduler/scheduler_config.py b/src/plugins/nonebot_bison/utils/scheduler_config.py similarity index 100% rename from src/plugins/nonebot_bison/scheduler/scheduler_config.py rename to src/plugins/nonebot_bison/utils/scheduler_config.py diff --git a/tests/config/test_scheduler_conf.py b/tests/config/test_scheduler_conf.py new file mode 100644 index 0000000..337e072 --- /dev/null +++ b/tests/config/test_scheduler_conf.py @@ -0,0 +1,119 @@ +from datetime import time + +from nonebug import App + + +async def test_create_config(app: App, db_migration): + from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config + from nonebot_bison.config.db_model import Subscribe, Target, User + from nonebot_bison.types import Target as T_Target + from nonebot_plugin_datastore.db import get_engine + + await config.add_subscribe( + user=123, + user_type="group", + target=T_Target("weibo_id"), + target_name="weibo_name", + platform_name="weibo", + cats=[], + tags=[], + ) + await config.add_subscribe( + user=123, + user_type="group", + target=T_Target("weibo_id1"), + target_name="weibo_name1", + platform_name="weibo", + cats=[], + tags=[], + ) + await config.update_time_weight_config( + target=T_Target("weibo_id"), + platform_name="weibo", + conf=WeightConfig( + default=10, + time_config=[ + TimeWeightConfig(start_time=time(1, 0), end_time=time(2, 0), weight=20) + ], + ), + ) + + test_config = await config.get_time_weight_config( + target=T_Target("weibo_id"), platform_name="weibo" + ) + assert test_config.default == 10 + assert test_config.time_config == [ + TimeWeightConfig(start_time=time(1, 0), end_time=time(2, 0), weight=20) + ] + test_config1 = await config.get_time_weight_config( + target=T_Target("weibo_id1"), platform_name="weibo" + ) + assert test_config1.default == 10 + assert test_config1.time_config == [] + + +async def test_get_current_weight(app: App, db_migration): + from datetime import time + + from nonebot_bison.config import db_config + from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config + from nonebot_bison.config.db_model import Subscribe, Target, User + from nonebot_bison.types import Target as T_Target + from nonebot_plugin_datastore.db import get_engine + + await config.add_subscribe( + user=123, + user_type="group", + target=T_Target("weibo_id"), + target_name="weibo_name", + platform_name="weibo", + cats=[], + tags=[], + ) + await config.add_subscribe( + user=123, + user_type="group", + target=T_Target("weibo_id1"), + target_name="weibo_name1", + platform_name="weibo", + cats=[], + tags=[], + ) + await config.add_subscribe( + user=123, + user_type="group", + target=T_Target("weibo_id1"), + target_name="weibo_name2", + platform_name="weibo2", + cats=[], + tags=[], + ) + await config.update_time_weight_config( + target=T_Target("weibo_id"), + platform_name="weibo", + conf=WeightConfig( + default=10, + time_config=[ + TimeWeightConfig(start_time=time(1, 0), end_time=time(2, 0), weight=20), + TimeWeightConfig(start_time=time(4, 0), end_time=time(5, 0), weight=30), + ], + ), + ) + app.monkeypatch.setattr(db_config, "_get_time", lambda: time(1, 30)) + weight = await config.get_current_weight_val(["weibo", "weibo2"]) + assert len(weight) == 3 + assert weight["weibo-weibo_id"] == 20 + assert weight["weibo-weibo_id1"] == 10 + assert weight["weibo2-weibo_id1"] == 10 + app.monkeypatch.setattr(db_config, "_get_time", lambda: time(4, 0)) + weight = await config.get_current_weight_val(["weibo", "weibo2"]) + assert len(weight) == 3 + assert weight["weibo-weibo_id"] == 30 + assert weight["weibo-weibo_id1"] == 10 + assert weight["weibo2-weibo_id1"] == 10 + app.monkeypatch.setattr(db_config, "_get_time", lambda: time(5, 0)) + weight = await config.get_current_weight_val(["weibo", "weibo2"]) + assert len(weight) == 3 + assert weight["weibo-weibo_id"] == 10 + assert weight["weibo-weibo_id1"] == 10 + assert weight["weibo2-weibo_id1"] == 10