From b47995a09c1658c87be623dacd0d72194addca7c Mon Sep 17 00:00:00 2001 From: felinae98 <731499577@qq.com> Date: Thu, 9 Jun 2022 00:42:14 +0800 Subject: [PATCH] fix bug, add test case --- .../nonebot_bison/scheduler/manager.py | 14 +- .../nonebot_bison/scheduler/scheduler.py | 11 +- tests/scheduler/test_scheduler.py | 147 ++++++++++++++++++ 3 files changed, 165 insertions(+), 7 deletions(-) create mode 100644 tests/scheduler/test_scheduler.py diff --git a/src/plugins/nonebot_bison/scheduler/manager.py b/src/plugins/nonebot_bison/scheduler/manager.py index 7f332c4..f3acad6 100644 --- a/src/plugins/nonebot_bison/scheduler/manager.py +++ b/src/plugins/nonebot_bison/scheduler/manager.py @@ -8,10 +8,11 @@ from ..utils import SchedulerConfig from .scheduler import Scheduler scheduler_dict: dict[str, Scheduler] = {} -_schedule_class_dict: dict[str, list[Target]] = {} async def init_scheduler(): + _schedule_class_dict: dict[str, list[Target]] = {} + _schedule_class_platform_dict: dict[str, list[str]] = {} for platform in platform_manager.values(): scheduler_class = platform.scheduler_class platform_name = platform.platform_name @@ -20,11 +21,18 @@ async def init_scheduler(): _schedule_class_dict[scheduler_class] = targets else: _schedule_class_dict[scheduler_class].extend(targets) + if scheduler_class not in _schedule_class_platform_dict: + _schedule_class_platform_dict[scheduler_class] = [platform_name] + else: + _schedule_class_platform_dict[scheduler_class].append(platform_name) for scheduler_class, target_list in _schedule_class_dict.items(): schedulable_args = [] for target in target_list: schedulable_args.append((target.platform_name, T_Target(target.target))) - scheduler_dict[scheduler_class] = Scheduler(scheduler_class, schedulable_args) + platform_name_list = _schedule_class_platform_dict[scheduler_class] + scheduler_dict[scheduler_class] = Scheduler( + scheduler_class, schedulable_args, platform_name_list + ) async def handle_insert_new_target(platform_name: str, target: T_Target): @@ -39,5 +47,5 @@ async def handle_delete_target(platform_name: str, target: T_Target): scheduler_obj.delete_schedulable(platform_name, target) -config.register_add_target_hook(handle_delete_target) +config.register_add_target_hook(handle_insert_new_target) config.register_delete_target_hook(handle_delete_target) diff --git a/src/plugins/nonebot_bison/scheduler/scheduler.py b/src/plugins/nonebot_bison/scheduler/scheduler.py index 39c703c..1cc15e5 100644 --- a/src/plugins/nonebot_bison/scheduler/scheduler.py +++ b/src/plugins/nonebot_bison/scheduler/scheduler.py @@ -25,7 +25,12 @@ class Scheduler: schedulable_list: list[Schedulable] - def __init__(self, name: str, schedulables: list[tuple[str, Target]]): + def __init__( + self, + name: str, + schedulables: list[tuple[str, Target]], + platform_name_list: list[str], + ): conf = SchedulerConfig.registry.get(name) self.name = name if not conf: @@ -33,15 +38,13 @@ class Scheduler: raise RuntimeError(f"{name} not found") self.scheduler_config = conf self.schedulable_list = [] - platform_name_set = set() for platform_name, target in schedulables: self.schedulable_list.append( Schedulable( platform_name=platform_name, target=target, current_weight=0 ) ) - platform_name_set.add(platform_name) - self.platform_name_list = list(platform_name_set) + self.platform_name_list = platform_name_list self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值 logger.info( f"register scheduler for {name} with {self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}" diff --git a/tests/scheduler/test_scheduler.py b/tests/scheduler/test_scheduler.py new file mode 100644 index 0000000..fa4294f --- /dev/null +++ b/tests/scheduler/test_scheduler.py @@ -0,0 +1,147 @@ +import typing +from datetime import time + +import pytest +from nonebug import App + +if typing.TYPE_CHECKING: + from nonebot_bison.scheduler.scheduler import Schedulable + + +async def get_schedule_times(scheduler_class: str, time: int) -> dict[str, int]: + from nonebot_bison.scheduler import scheduler_dict + + scheduler = scheduler_dict[scheduler_class] + res = {} + for _ in range(time): + schedulable = await scheduler.get_next_schedulable() + assert schedulable + key = f"{schedulable.platform_name}-{schedulable.target}" + res[key] = res.get(key, 0) + 1 + return res + + +async def test_scheduler_without_time(init_scheduler): + from nonebot_bison.config import config + from nonebot_bison.config.db_config import WeightConfig + from nonebot_bison.scheduler.manager import init_scheduler, scheduler_dict + from nonebot_bison.types import Target as T_Target + + await config.add_subscribe( + 123, "group", T_Target("t1"), "target1", "bilibili", [], [] + ) + await config.add_subscribe( + 123, "group", T_Target("t2"), "target1", "bilibili", [], [] + ) + await config.add_subscribe( + 123, "group", T_Target("t2"), "target1", "bilibili-live", [], [] + ) + + await config.update_time_weight_config( + T_Target("t2"), "bilibili", WeightConfig(20, []) + ) + await config.update_time_weight_config( + T_Target("t2"), "bilibili-live", WeightConfig(30, []) + ) + + await init_scheduler() + + static_res = await get_schedule_times("bilibili.com", 6) + assert static_res["bilibili-t1"] == 1 + assert static_res["bilibili-t2"] == 2 + assert static_res["bilibili-live-t2"] == 3 + + static_res = await get_schedule_times("bilibili.com", 6) + assert static_res["bilibili-t1"] == 1 + assert static_res["bilibili-t2"] == 2 + assert static_res["bilibili-live-t2"] == 3 + + +async def test_scheduler_with_time(app: App, init_scheduler): + from nonebot_bison.config import config, db_config + from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig + from nonebot_bison.scheduler.manager import init_scheduler, scheduler_dict + from nonebot_bison.types import Target as T_Target + + await config.add_subscribe( + 123, "group", T_Target("t1"), "target1", "bilibili", [], [] + ) + await config.add_subscribe( + 123, "group", T_Target("t2"), "target1", "bilibili", [], [] + ) + await config.add_subscribe( + 123, "group", T_Target("t2"), "target1", "bilibili-live", [], [] + ) + + await config.update_time_weight_config( + T_Target("t2"), + "bilibili", + WeightConfig(20, [TimeWeightConfig(time(10), time(11), 1000)]), + ) + await config.update_time_weight_config( + T_Target("t2"), "bilibili-live", WeightConfig(30, []) + ) + + await init_scheduler() + + app.monkeypatch.setattr(db_config, "_get_time", lambda: time(1, 30)) + static_res = await get_schedule_times("bilibili.com", 6) + assert static_res["bilibili-t1"] == 1 + assert static_res["bilibili-t2"] == 2 + assert static_res["bilibili-live-t2"] == 3 + + static_res = await get_schedule_times("bilibili.com", 6) + assert static_res["bilibili-t1"] == 1 + assert static_res["bilibili-t2"] == 2 + assert static_res["bilibili-live-t2"] == 3 + + app.monkeypatch.setattr(db_config, "_get_time", lambda: time(10, 30)) + + static_res = await get_schedule_times("bilibili.com", 6) + assert static_res["bilibili-t2"] == 6 + + +async def test_scheduler_add_new(init_scheduler): + from nonebot_bison.config import config + from nonebot_bison.config.db_config import WeightConfig + from nonebot_bison.scheduler.manager import init_scheduler, scheduler_dict + from nonebot_bison.types import Target as T_Target + + await config.add_subscribe( + 123, "group", T_Target("t1"), "target1", "bilibili", [], [] + ) + + await init_scheduler() + + await config.add_subscribe( + 2345, "group", T_Target("t1"), "target1", "bilibili", [], [] + ) + await config.add_subscribe( + 123, "group", T_Target("t2"), "target2", "bilibili", [], [] + ) + stat_res = await get_schedule_times("bilibili.com", 1) + assert stat_res["bilibili-t2"] == 1 + + +async def test_schedule_delete(init_scheduler): + from nonebot_bison.config import config + from nonebot_bison.config.db_config import WeightConfig + from nonebot_bison.scheduler.manager import init_scheduler, scheduler_dict + from nonebot_bison.types import Target as T_Target + + await config.add_subscribe( + 123, "group", T_Target("t1"), "target1", "bilibili", [], [] + ) + await config.add_subscribe( + 123, "group", T_Target("t2"), "target1", "bilibili", [], [] + ) + + await init_scheduler() + + stat_res = await get_schedule_times("bilibili.com", 2) + assert stat_res["bilibili-t2"] == 1 + assert stat_res["bilibili-t1"] == 1 + + await config.del_subscribe(123, "group", T_Target("t1"), "bilibili") + stat_res = await get_schedule_times("bilibili.com", 2) + assert stat_res["bilibili-t2"] == 2