From 331d0f61016806dfaf7de67cda57648dd424c831 Mon Sep 17 00:00:00 2001 From: felinae98 <731499577@qq.com> Date: Sun, 5 Jun 2022 16:45:35 +0800 Subject: [PATCH] update --- src/plugins/nonebot_bison/__init__.py | 1 + src/plugins/nonebot_bison/bootstrap.py | 15 ++++ .../nonebot_bison/config/config_legacy.py | 2 - src/plugins/nonebot_bison/config/db.py | 1 - src/plugins/nonebot_bison/config/db_config.py | 56 +++++++++--- .../nonebot_bison/platform/arknights.py | 19 +++-- .../nonebot_bison/platform/bilibili.py | 14 +-- src/plugins/nonebot_bison/platform/ff14.py | 11 ++- .../nonebot_bison/platform/ncm_artist.py | 11 ++- .../nonebot_bison/platform/ncm_radio.py | 3 +- .../nonebot_bison/platform/platform.py | 12 +-- src/plugins/nonebot_bison/platform/rss.py | 11 ++- src/plugins/nonebot_bison/platform/weibo.py | 10 ++- .../nonebot_bison/scheduler/__init__.py | 1 + .../nonebot_bison/scheduler/manager.py | 43 ++++++++++ .../nonebot_bison/scheduler/scheduler.py | 85 +++++++++++++++++++ .../nonebot_bison/utils/scheduler_config.py | 5 ++ tests/config/test_scheduler_conf.py | 50 +++++++++++ 18 files changed, 298 insertions(+), 52 deletions(-) create mode 100644 src/plugins/nonebot_bison/bootstrap.py create mode 100644 src/plugins/nonebot_bison/scheduler/manager.py create mode 100644 src/plugins/nonebot_bison/scheduler/scheduler.py diff --git a/src/plugins/nonebot_bison/__init__.py b/src/plugins/nonebot_bison/__init__.py index 1f513fe..46c621b 100644 --- a/src/plugins/nonebot_bison/__init__.py +++ b/src/plugins/nonebot_bison/__init__.py @@ -2,6 +2,7 @@ from nonebot.plugin import require from . import ( admin_page, + bootstrap, config, config_manager, platform, diff --git a/src/plugins/nonebot_bison/bootstrap.py b/src/plugins/nonebot_bison/bootstrap.py new file mode 100644 index 0000000..a13e672 --- /dev/null +++ b/src/plugins/nonebot_bison/bootstrap.py @@ -0,0 +1,15 @@ +from nonebot import get_driver + +from .config.config_legacy import start_up as legacy_db_startup +from .config.db import upgrade_db +from .scheduler.manager import init_scheduler + + +@get_driver().on_startup +async def bootstrap(): + # legacy db + legacy_db_startup() + # new db + await upgrade_db() + # init scheduler + await init_scheduler() diff --git a/src/plugins/nonebot_bison/config/config_legacy.py b/src/plugins/nonebot_bison/config/config_legacy.py index d615005..456541c 100644 --- a/src/plugins/nonebot_bison/config/config_legacy.py +++ b/src/plugins/nonebot_bison/config/config_legacy.py @@ -243,6 +243,4 @@ def start_up(): config.update_send_cache() -nonebot.get_driver().on_startup(start_up) - config = Config() diff --git a/src/plugins/nonebot_bison/config/db.py b/src/plugins/nonebot_bison/config/db.py index 856b14d..718783c 100644 --- a/src/plugins/nonebot_bison/config/db.py +++ b/src/plugins/nonebot_bison/config/db.py @@ -70,7 +70,6 @@ async def data_migrate(): logger.info("migrate success") -@nonebot.get_driver().on_startup async def upgrade_db(): alembic_cfg = Config() alembic_cfg.set_main_option( diff --git a/src/plugins/nonebot_bison/config/db_config.py b/src/plugins/nonebot_bison/config/db_config.py index 3367734..72882ee 100644 --- a/src/plugins/nonebot_bison/config/db_config.py +++ b/src/plugins/nonebot_bison/config/db_config.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from datetime import datetime, time -from typing import Optional +from typing import Any, Awaitable, Callable, Optional from nonebot_plugin_datastore.db import get_engine from sqlalchemy.ext.asyncio.session import AsyncSession @@ -34,6 +34,16 @@ class WeightConfig: class DBConfig: + def __init__(self): + self.add_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None + self.delete_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None + + def register_add_target_hook(self, fun: Callable[[str, T_Target], Awaitable]): + self.add_target_hook = fun + + def register_delete_target_hook(self, fun: Callable[[str, T_Target], Awaitable]): + self.delete_target_hook = fun + async def add_subscribe( self, user: int, @@ -62,6 +72,8 @@ class DBConfig: db_target = Target( target=target, platform_name=platform_name, target_name=target_name ) + if self.add_target_hook: + await self.add_target_hook(platform_name, target) else: db_target.target_name = target_name # type: ignore subscribe = Subscribe( @@ -93,12 +105,12 @@ class DBConfig: ) target_obj = await session.scalar( select(Target).where( - Target.platform_name == platform_name, MTarget.target == target + Target.platform_name == platform_name, Target.target == target ) ) await session.execute( delete(Subscribe).where( - Subscribe.user == user_obj, MSubscribe.target == target_obj + Subscribe.user == user_obj, Subscribe.target == target_obj ) ) target_count = await session.scalar( @@ -108,7 +120,9 @@ class DBConfig: ) if target_count == 0: # delete empty target - await session.delete(target_obj) + # await session.delete(target_obj) + if self.delete_target_hook: + await self.delete_target_hook(platform_name, T_Target(target)) await session.commit() async def update_subscribe( @@ -139,15 +153,27 @@ class DBConfig: subscribe_obj.target.target_name = target_name await sess.commit() + async def get_platform_target(self, platform_name: str) -> list[Target]: + async with AsyncSession(get_engine()) as sess: + subq = select(Subscribe.target_id).distinct().subquery() + query = ( + select(Target).join(subq).where(Target.platform_name == platform_name) + ) + return (await sess.scalars(query)).all() + async def get_time_weight_config( self, target: T_Target, platform_name: str ) -> WeightConfig: async with AsyncSession(get_engine()) as sess: - time_weight_conf: list[ScheduleTimeWeight] = await sess.scalars( - select(ScheduleTimeWeight) - .where(Target.platform_name == platform_name, Target.target == target) - .join(Target) - ) + time_weight_conf: list[ScheduleTimeWeight] = ( + await sess.scalars( + select(ScheduleTimeWeight) + .where( + Target.platform_name == platform_name, Target.target == target + ) + .join(Target) + ) + ).all() targetObj: Target = await sess.scalar( select(Target).where( Target.platform_name == platform_name, Target.target == target @@ -192,11 +218,13 @@ class DBConfig: res = {} cur_time = _get_time() async with AsyncSession(get_engine()) as sess: - targets: list[Target] = await sess.scalars( - select(Target) - .where(Target.platform_name.in_(platform_list)) - .options(selectinload(Target.time_weight)) - ) + targets: list[Target] = ( + await sess.scalars( + select(Target) + .where(Target.platform_name.in_(platform_list)) + .options(selectinload(Target.time_weight)) + ) + ).all() for target in targets: key = f"{target.platform_name}-{target.target}" weight = target.default_schedule_weight diff --git a/src/plugins/nonebot_bison/platform/arknights.py b/src/plugins/nonebot_bison/platform/arknights.py index 721a425..1434a9f 100644 --- a/src/plugins/nonebot_bison/platform/arknights.py +++ b/src/plugins/nonebot_bison/platform/arknights.py @@ -7,9 +7,16 @@ from nonebot.plugin import require from ..post import Post from ..types import Category, RawPost, Target from ..utils import http_client +from ..utils.scheduler_config import SchedulerConfig from .platform import CategoryNotSupport, NewMessage, StatusChange +class ArknightsSchedConf(SchedulerConfig, name="arknights"): + + schedule_type = "interval" + schedule_setting = {"seconds": 30} + + class Arknights(NewMessage): categories = {1: "游戏公告"} @@ -18,8 +25,7 @@ class Arknights(NewMessage): enable_tag = False enabled = True is_common = False - schedule_type = "interval" - schedule_kw = {"seconds": 30} + scheduler_class = "arknights" has_target = False async def get_target_name(self, _: Target) -> str: @@ -91,8 +97,7 @@ class AkVersion(StatusChange): enable_tag = False enabled = True is_common = False - schedule_type = "interval" - schedule_kw = {"seconds": 30} + scheduler_class = "arknights" has_target = False async def get_target_name(self, _: Target) -> str: @@ -147,8 +152,7 @@ class MonsterSiren(NewMessage): enable_tag = False enabled = True is_common = False - schedule_type = "interval" - schedule_kw = {"seconds": 30} + scheduler_class = "arknights" has_target = False async def get_target_name(self, _: Target) -> str: @@ -199,8 +203,7 @@ class TerraHistoricusComic(NewMessage): enable_tag = False enabled = True is_common = False - schedule_type = "interval" - schedule_kw = {"seconds": 30} + scheduler_class = "arknights" has_target = False async def get_target_name(self, _: Target) -> str: diff --git a/src/plugins/nonebot_bison/platform/bilibili.py b/src/plugins/nonebot_bison/platform/bilibili.py index 56afd42..523f4b3 100644 --- a/src/plugins/nonebot_bison/platform/bilibili.py +++ b/src/plugins/nonebot_bison/platform/bilibili.py @@ -4,10 +4,16 @@ from typing import Any, Optional from ..post import Post from ..types import Category, RawPost, Tag, Target -from ..utils import http_client +from ..utils import SchedulerConfig, http_client from .platform import CategoryNotSupport, NewMessage, StatusChange +class BilibiliSchedConf(SchedulerConfig, name="bilibili.com"): + + schedule_type = "interval" + schedule_setting = {"seconds": 10} + + class Bilibili(NewMessage): categories = { @@ -22,8 +28,7 @@ class Bilibili(NewMessage): enable_tag = True enabled = True is_common = True - schedule_type = "interval" - schedule_kw = {"seconds": 10} + scheduler_class = "bilibili.com" name = "B站" has_target = True parse_target_promot = "请输入用户主页的链接" @@ -167,8 +172,7 @@ class Bilibililive(StatusChange): enable_tag = True enabled = True is_common = True - schedule_type = "interval" - schedule_kw = {"seconds": 10} + scheduler_class = "bilibili.com" name = "Bilibili直播" has_target = True diff --git a/src/plugins/nonebot_bison/platform/ff14.py b/src/plugins/nonebot_bison/platform/ff14.py index 0cbc92e..9f67b28 100644 --- a/src/plugins/nonebot_bison/platform/ff14.py +++ b/src/plugins/nonebot_bison/platform/ff14.py @@ -2,10 +2,16 @@ from typing import Any from ..post import Post from ..types import RawPost, Target -from ..utils import http_client +from ..utils import SchedulerConfig, http_client from .platform import NewMessage +class FF14SchedConf(SchedulerConfig, name="ff14"): + + schedule_type = "interval" + schedule_setting = {"seconds": 60} + + class FF14(NewMessage): categories = {} @@ -14,8 +20,7 @@ class FF14(NewMessage): enable_tag = False enabled = True is_common = False - schedule_type = "interval" - schedule_kw = {"seconds": 60} + scheduler_class = "ff14" has_target = False async def get_target_name(self, _: Target) -> str: diff --git a/src/plugins/nonebot_bison/platform/ncm_artist.py b/src/plugins/nonebot_bison/platform/ncm_artist.py index 00d329e..c98d4eb 100644 --- a/src/plugins/nonebot_bison/platform/ncm_artist.py +++ b/src/plugins/nonebot_bison/platform/ncm_artist.py @@ -3,10 +3,16 @@ from typing import Any, Optional from ..post import Post from ..types import RawPost, Target -from ..utils import http_client +from ..utils import SchedulerConfig, http_client from .platform import NewMessage +class NcmSchedConf(SchedulerConfig, name="music.163.com"): + + schedule_type = "interval" + schedule_setting = {"minutes": 1} + + class NcmArtist(NewMessage): categories = {} @@ -14,8 +20,7 @@ class NcmArtist(NewMessage): enable_tag = False enabled = True is_common = True - schedule_type = "interval" - schedule_kw = {"minutes": 1} + scheduler_class = "music.163.com" name = "网易云-歌手" has_target = True parse_target_promot = "请输入歌手主页(包含数字ID)的链接" diff --git a/src/plugins/nonebot_bison/platform/ncm_radio.py b/src/plugins/nonebot_bison/platform/ncm_radio.py index 14b439e..38d6967 100644 --- a/src/plugins/nonebot_bison/platform/ncm_radio.py +++ b/src/plugins/nonebot_bison/platform/ncm_radio.py @@ -14,8 +14,7 @@ class NcmRadio(NewMessage): enable_tag = False enabled = True is_common = False - schedule_type = "interval" - schedule_kw = {"minutes": 10} + scheduler_class = "music.163.com" name = "网易云-电台" has_target = True parse_target_promot = "请输入主播电台主页(包含数字ID)的链接" diff --git a/src/plugins/nonebot_bison/platform/platform.py b/src/plugins/nonebot_bison/platform/platform.py index 22c1200..23e0d35 100644 --- a/src/plugins/nonebot_bison/platform/platform.py +++ b/src/plugins/nonebot_bison/platform/platform.py @@ -39,8 +39,7 @@ class RegistryABCMeta(RegistryMeta, ABC): class Platform(metaclass=RegistryABCMeta, base=True): - schedule_type: Literal["date", "interval", "cron"] - schedule_kw: dict + scheduler_class: str is_common: bool enabled: bool name: str @@ -332,11 +331,11 @@ class NoTargetGroup(Platform, abstract=True): def __init__(self, platform_list: list[Platform]): self.platform_list = platform_list + self.platform_name = platform_list[0].platform_name name = self.DUMMY_STR self.categories = {} categories_keys = set() - self.schedule_type = platform_list[0].schedule_type - self.schedule_kw = platform_list[0].schedule_kw + self.scheduler_class = platform_list[0].scheduler_class for platform in platform_list: if platform.has_target: raise RuntimeError( @@ -355,10 +354,7 @@ class NoTargetGroup(Platform, abstract=True): ) categories_keys |= platform_category_key_set self.categories.update(platform.categories) - if ( - platform.schedule_kw != self.schedule_kw - or platform.schedule_type != self.schedule_type - ): + if platform.scheduler_class != self.scheduler_class: raise RuntimeError( "Platform scheduler for {} not fit".format(self.platform_name) ) diff --git a/src/plugins/nonebot_bison/platform/rss.py b/src/plugins/nonebot_bison/platform/rss.py index ed09e8a..b5e7cc0 100644 --- a/src/plugins/nonebot_bison/platform/rss.py +++ b/src/plugins/nonebot_bison/platform/rss.py @@ -6,10 +6,16 @@ from bs4 import BeautifulSoup as bs from ..post import Post from ..types import RawPost, Target -from ..utils import http_client +from ..utils import SchedulerConfig, http_client from .platform import NewMessage +class RssSchedConf(SchedulerConfig, name="rss"): + + schedule_type = "interval" + schedule_setting = {"seconds": 30} + + class Rss(NewMessage): categories = {} @@ -18,8 +24,7 @@ class Rss(NewMessage): name = "Rss" enabled = True is_common = True - schedule_type = "interval" - schedule_kw = {"seconds": 30} + scheduler_class = "rss" has_target = True async def get_target_name(self, target: Target) -> Optional[str]: diff --git a/src/plugins/nonebot_bison/platform/weibo.py b/src/plugins/nonebot_bison/platform/weibo.py index f7973a3..0749f28 100644 --- a/src/plugins/nonebot_bison/platform/weibo.py +++ b/src/plugins/nonebot_bison/platform/weibo.py @@ -8,10 +8,15 @@ from nonebot.log import logger from ..post import Post from ..types import * -from ..utils import http_client +from ..utils import SchedulerConfig, http_client from .platform import NewMessage +class WeiboSchedConf(SchedulerConfig, name="weibo.com"): + schedule_type = "interval" + schedule_setting = {"seconds": 3} + + class Weibo(NewMessage): categories = { @@ -25,8 +30,7 @@ class Weibo(NewMessage): name = "新浪微博" enabled = True is_common = True - schedule_type = "interval" - schedule_kw = {"seconds": 3} + scheduler_class = "weibo.com" has_target = True parse_target_promot = "请输入用户主页(包含数字UID)的链接" diff --git a/src/plugins/nonebot_bison/scheduler/__init__.py b/src/plugins/nonebot_bison/scheduler/__init__.py index e69de29..4fe6284 100644 --- a/src/plugins/nonebot_bison/scheduler/__init__.py +++ b/src/plugins/nonebot_bison/scheduler/__init__.py @@ -0,0 +1 @@ +from .manager import * diff --git a/src/plugins/nonebot_bison/scheduler/manager.py b/src/plugins/nonebot_bison/scheduler/manager.py new file mode 100644 index 0000000..7f332c4 --- /dev/null +++ b/src/plugins/nonebot_bison/scheduler/manager.py @@ -0,0 +1,43 @@ +from nonebot.log import logger + +from ..config import config +from ..config.db_model import Target +from ..platform import platform_manager +from ..types import Target as T_Target +from ..utils import SchedulerConfig +from .scheduler import Scheduler + +scheduler_dict: dict[str, Scheduler] = {} +_schedule_class_dict: dict[str, list[Target]] = {} + + +async def init_scheduler(): + for platform in platform_manager.values(): + scheduler_class = platform.scheduler_class + platform_name = platform.platform_name + targets = await config.get_platform_target(platform_name) + if scheduler_class not in _schedule_class_dict: + _schedule_class_dict[scheduler_class] = targets + else: + _schedule_class_dict[scheduler_class].extend(targets) + for scheduler_class, target_list in _schedule_class_dict.items(): + schedulable_args = [] + for target in target_list: + schedulable_args.append((target.platform_name, T_Target(target.target))) + scheduler_dict[scheduler_class] = Scheduler(scheduler_class, schedulable_args) + + +async def handle_insert_new_target(platform_name: str, target: T_Target): + platform = platform_manager[platform_name] + scheduler_obj = scheduler_dict[platform.scheduler_class] + scheduler_obj.insert_new_schedulable(platform_name, target) + + +async def handle_delete_target(platform_name: str, target: T_Target): + platform = platform_manager[platform_name] + scheduler_obj = scheduler_dict[platform.scheduler_class] + scheduler_obj.delete_schedulable(platform_name, target) + + +config.register_add_target_hook(handle_delete_target) +config.register_delete_target_hook(handle_delete_target) diff --git a/src/plugins/nonebot_bison/scheduler/scheduler.py b/src/plugins/nonebot_bison/scheduler/scheduler.py new file mode 100644 index 0000000..60df150 --- /dev/null +++ b/src/plugins/nonebot_bison/scheduler/scheduler.py @@ -0,0 +1,85 @@ +from dataclasses import dataclass +from typing import Optional + +from nonebot.log import logger + +from ..config import config +from ..platform.platform import Platform +from ..types import Target +from ..utils import SchedulerConfig + + +@dataclass +class Schedulable: + platform_name: str + target: Target + current_weight: int + + +class Scheduler: + + schedulable_list: list[Schedulable] + + def __init__(self, name: str, schedulables: list[tuple[str, Target]]): + conf = SchedulerConfig.registry.get(name) + if not conf: + logger.error(f"scheduler config [{name}] not found, exiting") + raise RuntimeError(f"{name} not found") + self.scheduler_config = conf + self.schedulable_list = [] + platform_name_set = set() + for platform_name, target in schedulables: + self.schedulable_list.append( + Schedulable( + platform_name=platform_name, target=target, current_weight=0 + ) + ) + platform_name_set.add(platform_name) + self.platform_name_list = list(platform_name_set) + self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值 + + async def schedule(self) -> Optional[Schedulable]: + if not self.schedulable_list: + return None + cur_weight = await config.get_current_weight_val(self.platform_name_list) + weight_sum = self.pre_weight_val + self.pre_weight_val = 0 + cur_max_schedulable = None + for schedulable in self.schedulable_list: + schedulable.current_weight += cur_weight[ + f"{schedulable.platform_name}-{schedulable.target}" + ] + weight_sum += cur_weight[ + f"{schedulable.platform_name}-{schedulable.target}" + ] + if ( + not cur_max_schedulable + or cur_max_schedulable.current_weight < schedulable.current_weight + ): + cur_max_schedulable = schedulable + assert cur_max_schedulable + cur_max_schedulable.current_weight -= weight_sum + return cur_max_schedulable + + def insert_new_schedulable(self, platform_name: str, target: Target): + self.pre_weight_val += 1000 + self.schedulable_list.append(Schedulable(platform_name, target, 1000)) + logger.info( + f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})" + ) + + def delete_schedulable(self, platform_name, target: Target): + if not self.schedulable_list: + return + to_find_idx = None + for idx, schedulable in enumerate(self.schedulable_list): + if ( + schedulable.platform_name == platform_name + and schedulable.target == target + ): + to_find_idx = idx + break + if to_find_idx is not None: + deleted_schdulable = self.schedulable_list.pop(to_find_idx) + self.pre_weight_val -= deleted_schdulable.current_weight + return diff --git a/src/plugins/nonebot_bison/utils/scheduler_config.py b/src/plugins/nonebot_bison/utils/scheduler_config.py index 8af28e4..ce01c0b 100644 --- a/src/plugins/nonebot_bison/utils/scheduler_config.py +++ b/src/plugins/nonebot_bison/utils/scheduler_config.py @@ -6,7 +6,12 @@ class SchedulerConfig: schedule_type: Literal["date", "interval", "cron"] schedule_setting: dict registry: dict[str, Type["SchedulerConfig"]] = {} + name: str def __init_subclass__(cls, *, name, **kwargs): super().__init_subclass__(**kwargs) cls.registry[name] = cls + cls.name = name + + def __str__(self): + return f"[{self.name}]-{self.name}-{self.schedule_setting}" diff --git a/tests/config/test_scheduler_conf.py b/tests/config/test_scheduler_conf.py index 337e072..619e9ee 100644 --- a/tests/config/test_scheduler_conf.py +++ b/tests/config/test_scheduler_conf.py @@ -117,3 +117,53 @@ async def test_get_current_weight(app: App, db_migration): assert weight["weibo-weibo_id"] == 10 assert weight["weibo-weibo_id1"] == 10 assert weight["weibo2-weibo_id1"] == 10 + + +async def test_get_platform_target(app: App, db_migration): + from nonebot_bison.config import db_config + from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config + from nonebot_bison.config.db_model import Subscribe, Target, User + from nonebot_bison.types import Target as T_Target + from nonebot_plugin_datastore.db import get_engine + from sqlalchemy.ext.asyncio.session import AsyncSession + from sqlalchemy.sql.expression import select + + await config.add_subscribe( + user=123, + user_type="group", + target=T_Target("weibo_id"), + target_name="weibo_name", + platform_name="weibo", + cats=[], + tags=[], + ) + await config.add_subscribe( + user=123, + user_type="group", + target=T_Target("weibo_id1"), + target_name="weibo_name1", + platform_name="weibo", + cats=[], + tags=[], + ) + await config.add_subscribe( + user=245, + user_type="group", + target=T_Target("weibo_id1"), + target_name="weibo_name1", + platform_name="weibo", + cats=[], + tags=[], + ) + res = await config.get_platform_target("weibo") + assert len(res) == 2 + await config.del_subscribe(123, "group", T_Target("weibo_id1"), "weibo") + res = await config.get_platform_target("weibo") + assert len(res) == 2 + await config.del_subscribe(123, "group", T_Target("weibo_id"), "weibo") + res = await config.get_platform_target("weibo") + assert len(res) == 1 + + async with AsyncSession(get_engine()) as sess: + res = await sess.scalars(select(Target).where(Target.platform_name == "weibo")) + assert len(res.all()) == 2