diff --git a/src/plugins/nonebot_bison/bootstrap.py b/src/plugins/nonebot_bison/bootstrap.py index a13e672..96b2a23 100644 --- a/src/plugins/nonebot_bison/bootstrap.py +++ b/src/plugins/nonebot_bison/bootstrap.py @@ -2,6 +2,7 @@ from nonebot import get_driver from .config.config_legacy import start_up as legacy_db_startup from .config.db import upgrade_db +from .scheduler.aps import start_scheduler from .scheduler.manager import init_scheduler @@ -13,3 +14,5 @@ async def bootstrap(): await upgrade_db() # init scheduler await init_scheduler() + # start scheduler + start_scheduler() diff --git a/src/plugins/nonebot_bison/config/config_legacy.py b/src/plugins/nonebot_bison/config/config_legacy.py index 456541c..1c1a06c 100644 --- a/src/plugins/nonebot_bison/config/config_legacy.py +++ b/src/plugins/nonebot_bison/config/config_legacy.py @@ -32,6 +32,22 @@ def get_config_path() -> str: return new_path +def drop(): + if plugin_config.bison_config_path: + data_dir = plugin_config.bison_config_path + else: + working_dir = os.getcwd() + data_dir = path.join(working_dir, "data") + old_path = path.join(data_dir, "bison.json") + new_path = path.join(data_dir, "bison-legacy.json") + if os.path.exists(old_path): + config.db.close() + config.available = False + os.rename(old_path, new_path) + return True + return False + + class SubscribeContent(TypedDict): target: str target_type: str @@ -223,6 +239,8 @@ class Config(metaclass=Singleton): def start_up(): config = Config() + if not config.available: + return if not (search_res := config.kv_config.search(Query().name == "version")): config.kv_config.insert({"name": "version", "value": config.migrate_version}) elif search_res[0].get("value") < config.migrate_version: diff --git a/src/plugins/nonebot_bison/config/db.py b/src/plugins/nonebot_bison/config/db.py index 718783c..8e151c6 100644 --- a/src/plugins/nonebot_bison/config/db.py +++ b/src/plugins/nonebot_bison/config/db.py @@ -10,7 +10,7 @@ from nonebot_plugin_datastore.db import get_engine from sqlalchemy.engine.base import Connection from sqlalchemy.ext.asyncio.session import AsyncSession -from .config_legacy import ConfigContent, config +from .config_legacy import ConfigContent, config, drop from .db_model import Base, Subscribe, Target, User DATA = PluginData("bison") @@ -67,6 +67,7 @@ async def data_migrate(): + subscribe_to_create ) await sess.commit() + drop() logger.info("migrate success") diff --git a/src/plugins/nonebot_bison/config/db_config.py b/src/plugins/nonebot_bison/config/db_config.py index 72882ee..1548553 100644 --- a/src/plugins/nonebot_bison/config/db_config.py +++ b/src/plugins/nonebot_bison/config/db_config.py @@ -10,6 +10,8 @@ from sqlalchemy.sql.functions import func from ..types import Category, Tag from ..types import Target as T_Target +from ..types import User as T_User +from ..types import UserSubInfo from .db_model import ScheduleTimeWeight, Subscribe, Target, User @@ -238,5 +240,27 @@ class DBConfig: res[key] = weight return res + async def get_platform_target_subscribers( + self, platform_name: str, target: T_Target + ) -> list[UserSubInfo]: + async with AsyncSession(get_engine()) as sess: + query = ( + select(Subscribe) + .join(Target) + .where(Target.platform_name == platform_name, Target.target == target) + .options(selectinload(Subscribe.user)) + ) + subsribes: list[Subscribe] = (await sess.scalars(query)).all() + return list( + map( + lambda subscribe: UserSubInfo( + T_User(subscribe.user.uid, subscribe.user.type), + subscribe.categories, + subscribe.tags, + ), + subsribes, + ) + ) + config = DBConfig() diff --git a/src/plugins/nonebot_bison/platform/platform.py b/src/plugins/nonebot_bison/platform/platform.py index 23e0d35..4b19df5 100644 --- a/src/plugins/nonebot_bison/platform/platform.py +++ b/src/plugins/nonebot_bison/platform/platform.py @@ -135,9 +135,7 @@ class Platform(metaclass=RegistryABCMeta, base=True): self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo] ) -> list[tuple[User, list[Post]]]: res: list[tuple[User, list[Post]]] = [] - for user, category_getter, tag_getter in users: - required_tags = tag_getter(target) if self.enable_tag else [] - cats = category_getter(target) + for user, cats, required_tags in users: user_raw_post = await self.filter_user_custom( new_posts, cats, required_tags ) diff --git a/src/plugins/nonebot_bison/scheduler/aps.py b/src/plugins/nonebot_bison/scheduler/aps.py new file mode 100644 index 0000000..fea4ff1 --- /dev/null +++ b/src/plugins/nonebot_bison/scheduler/aps.py @@ -0,0 +1,31 @@ +import logging + +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from nonebot.log import LoguruHandler + +from ..plugin_config import plugin_config +from ..send import do_send_msgs + +aps = AsyncIOScheduler(timezone="Asia/Shanghai") + + +class CustomLogHandler(LoguruHandler): + def filter(self, record: logging.LogRecord): + return record.msg != ( + 'Execution of job "%s" ' + "skipped: maximum number of running instances reached (%d)" + ) + + +if plugin_config.bison_use_queue: + aps.add_job(do_send_msgs, "interval", seconds=0.3, coalesce=True) + + aps_logger = logging.getLogger("apscheduler") + aps_logger.setLevel(30) + aps_logger.handlers.clear() + aps_logger.addHandler(CustomLogHandler()) + + +def start_scheduler(): + aps.configure({"apscheduler.timezone": "Asia/Shanghai"}) + aps.start() diff --git a/src/plugins/nonebot_bison/scheduler/scheduler.py b/src/plugins/nonebot_bison/scheduler/scheduler.py index 60df150..39c703c 100644 --- a/src/plugins/nonebot_bison/scheduler/scheduler.py +++ b/src/plugins/nonebot_bison/scheduler/scheduler.py @@ -1,12 +1,17 @@ from dataclasses import dataclass from typing import Optional +import nonebot +from nonebot.adapters.onebot.v11.bot import Bot from nonebot.log import logger from ..config import config +from ..platform import platform_manager from ..platform.platform import Platform +from ..send import send_msgs from ..types import Target from ..utils import SchedulerConfig +from .aps import aps @dataclass @@ -22,6 +27,7 @@ class Scheduler: def __init__(self, name: str, schedulables: list[tuple[str, Target]]): conf = SchedulerConfig.registry.get(name) + self.name = name if not conf: logger.error(f"scheduler config [{name}] not found, exiting") raise RuntimeError(f"{name} not found") @@ -37,8 +43,16 @@ class Scheduler: platform_name_set.add(platform_name) self.platform_name_list = list(platform_name_set) self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值 + logger.info( + f"register scheduler for {name} with {self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}" + ) + aps.add_job( + self.exec_fetch, + self.scheduler_config.schedule_type, + **self.scheduler_config.schedule_setting, + ) - async def schedule(self) -> Optional[Schedulable]: + async def get_next_schedulable(self) -> Optional[Schedulable]: if not self.schedulable_list: return None cur_weight = await config.get_current_weight_val(self.platform_name_list) @@ -61,6 +75,35 @@ class Scheduler: cur_max_schedulable.current_weight -= weight_sum return cur_max_schedulable + async def exec_fetch(self): + if not (schedulable := await self.get_next_schedulable()): + return + logger.debug( + f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}" + ) + send_userinfo_list = await config.get_platform_target_subscribers( + schedulable.platform_name, schedulable.target + ) + to_send = await platform_manager[schedulable.platform_name].do_fetch_new_post( + schedulable.target, send_userinfo_list + ) + if not to_send: + return + bot = nonebot.get_bot() + assert isinstance(bot, Bot) + for user, send_list in to_send: + for send_post in send_list: + logger.info("send to {}: {}".format(user, send_post)) + if not bot: + logger.warning("no bot connected") + else: + await send_msgs( + bot, + user.user, + user.user_type, + await send_post.generate_messages(), + ) + def insert_new_schedulable(self, platform_name: str, target: Target): self.pre_weight_val += 1000 self.schedulable_list.append(Schedulable(platform_name, target, 1000)) diff --git a/src/plugins/nonebot_bison/types.py b/src/plugins/nonebot_bison/types.py index 954e90e..d2f2b9a 100644 --- a/src/plugins/nonebot_bison/types.py +++ b/src/plugins/nonebot_bison/types.py @@ -22,5 +22,5 @@ class PlatformTarget: class UserSubInfo(NamedTuple): user: User - category_getter: Callable[[Target], list[Category]] - tag_getter: Callable[[Target], list[Tag]] + categories: list[Category] + tags: list[Tag] diff --git a/tests/config/test_scheduler_conf.py b/tests/config/test_scheduler_conf.py index 619e9ee..2dabf53 100644 --- a/tests/config/test_scheduler_conf.py +++ b/tests/config/test_scheduler_conf.py @@ -3,7 +3,7 @@ from datetime import time from nonebug import App -async def test_create_config(app: App, db_migration): +async def test_create_config(app: App, init_scheduler): from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config from nonebot_bison.config.db_model import Subscribe, Target, User from nonebot_bison.types import Target as T_Target @@ -52,7 +52,7 @@ async def test_create_config(app: App, db_migration): assert test_config1.time_config == [] -async def test_get_current_weight(app: App, db_migration): +async def test_get_current_weight(app: App, init_scheduler): from datetime import time from nonebot_bison.config import db_config @@ -84,7 +84,7 @@ async def test_get_current_weight(app: App, db_migration): user_type="group", target=T_Target("weibo_id1"), target_name="weibo_name2", - platform_name="weibo2", + platform_name="bilibili", cats=[], tags=[], ) @@ -100,26 +100,26 @@ async def test_get_current_weight(app: App, db_migration): ), ) app.monkeypatch.setattr(db_config, "_get_time", lambda: time(1, 30)) - weight = await config.get_current_weight_val(["weibo", "weibo2"]) + weight = await config.get_current_weight_val(["weibo", "bilibili"]) assert len(weight) == 3 assert weight["weibo-weibo_id"] == 20 assert weight["weibo-weibo_id1"] == 10 - assert weight["weibo2-weibo_id1"] == 10 + assert weight["bilibili-weibo_id1"] == 10 app.monkeypatch.setattr(db_config, "_get_time", lambda: time(4, 0)) - weight = await config.get_current_weight_val(["weibo", "weibo2"]) + weight = await config.get_current_weight_val(["weibo", "bilibili"]) assert len(weight) == 3 assert weight["weibo-weibo_id"] == 30 assert weight["weibo-weibo_id1"] == 10 - assert weight["weibo2-weibo_id1"] == 10 + assert weight["bilibili-weibo_id1"] == 10 app.monkeypatch.setattr(db_config, "_get_time", lambda: time(5, 0)) - weight = await config.get_current_weight_val(["weibo", "weibo2"]) + weight = await config.get_current_weight_val(["weibo", "bilibili"]) assert len(weight) == 3 assert weight["weibo-weibo_id"] == 10 assert weight["weibo-weibo_id1"] == 10 - assert weight["weibo2-weibo_id1"] == 10 + assert weight["bilibili-weibo_id1"] == 10 -async def test_get_platform_target(app: App, db_migration): +async def test_get_platform_target(app: App, init_scheduler): from nonebot_bison.config import db_config from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config from nonebot_bison.config.db_model import Subscribe, Target, User @@ -167,3 +167,52 @@ async def test_get_platform_target(app: App, db_migration): async with AsyncSession(get_engine()) as sess: res = await sess.scalars(select(Target).where(Target.platform_name == "weibo")) assert len(res.all()) == 2 + + +async def test_get_platform_target_subscribers(app: App, init_scheduler): + from nonebot_bison.config import db_config + from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config + from nonebot_bison.config.db_model import Subscribe, Target, User + from nonebot_bison.types import Target as T_Target + from nonebot_bison.types import User as T_User + from nonebot_bison.types import UserSubInfo + from nonebot_plugin_datastore.db import get_engine + from sqlalchemy.ext.asyncio.session import AsyncSession + from sqlalchemy.sql.expression import select + + await config.add_subscribe( + user=123, + user_type="group", + target=T_Target("weibo_id"), + target_name="weibo_name", + platform_name="weibo", + cats=[1], + tags=["tag1"], + ) + await config.add_subscribe( + user=123, + user_type="group", + target=T_Target("weibo_id1"), + target_name="weibo_name1", + platform_name="weibo", + cats=[2], + tags=["tag2"], + ) + await config.add_subscribe( + user=245, + user_type="group", + target=T_Target("weibo_id1"), + target_name="weibo_name1", + platform_name="weibo", + cats=[3], + tags=["tag3"], + ) + + res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id")) + assert len(res) == 1 + assert res[0] == UserSubInfo(T_User(123, "group"), [1], ["tag1"]) + + res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id1")) + assert len(res) == 2 + assert UserSubInfo(T_User(123, "group"), [2], ["tag2"]) in res + assert UserSubInfo(T_User(245, "group"), [3], ["tag3"]) in res diff --git a/tests/conftest.py b/tests/conftest.py index 25b0794..c76550b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,7 @@ def dummy_user_subinfo(app: App): from nonebot_bison.types import User, UserSubInfo user = User(123, "group") - return UserSubInfo(user=user, category_getter=lambda _: [], tag_getter=lambda _: []) + return UserSubInfo(user=user, categories=[], tags=[]) @pytest.fixture @@ -48,6 +48,13 @@ async def db_migration(app: App): await sess.close() +@pytest.fixture +async def init_scheduler(db_migration): + from nonebot_bison.scheduler.manager import init_scheduler + + await init_scheduler() + + @pytest.fixture async def use_legacy_config(app: App): import aiofiles