finish new scheduler

This commit is contained in:
felinae98 2022-06-06 00:22:18 +08:00
parent 331d0f6101
commit 9fa97704b0
No known key found for this signature in database
GPG Key ID: 00C8B010587FF610
10 changed files with 192 additions and 18 deletions

View File

@ -2,6 +2,7 @@ from nonebot import get_driver
from .config.config_legacy import start_up as legacy_db_startup from .config.config_legacy import start_up as legacy_db_startup
from .config.db import upgrade_db from .config.db import upgrade_db
from .scheduler.aps import start_scheduler
from .scheduler.manager import init_scheduler from .scheduler.manager import init_scheduler
@ -13,3 +14,5 @@ async def bootstrap():
await upgrade_db() await upgrade_db()
# init scheduler # init scheduler
await init_scheduler() await init_scheduler()
# start scheduler
start_scheduler()

View File

@ -32,6 +32,22 @@ def get_config_path() -> str:
return new_path return new_path
def drop():
if plugin_config.bison_config_path:
data_dir = plugin_config.bison_config_path
else:
working_dir = os.getcwd()
data_dir = path.join(working_dir, "data")
old_path = path.join(data_dir, "bison.json")
new_path = path.join(data_dir, "bison-legacy.json")
if os.path.exists(old_path):
config.db.close()
config.available = False
os.rename(old_path, new_path)
return True
return False
class SubscribeContent(TypedDict): class SubscribeContent(TypedDict):
target: str target: str
target_type: str target_type: str
@ -223,6 +239,8 @@ class Config(metaclass=Singleton):
def start_up(): def start_up():
config = Config() config = Config()
if not config.available:
return
if not (search_res := config.kv_config.search(Query().name == "version")): if not (search_res := config.kv_config.search(Query().name == "version")):
config.kv_config.insert({"name": "version", "value": config.migrate_version}) config.kv_config.insert({"name": "version", "value": config.migrate_version})
elif search_res[0].get("value") < config.migrate_version: elif search_res[0].get("value") < config.migrate_version:

View File

@ -10,7 +10,7 @@ from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.engine.base import Connection from sqlalchemy.engine.base import Connection
from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio.session import AsyncSession
from .config_legacy import ConfigContent, config from .config_legacy import ConfigContent, config, drop
from .db_model import Base, Subscribe, Target, User from .db_model import Base, Subscribe, Target, User
DATA = PluginData("bison") DATA = PluginData("bison")
@ -67,6 +67,7 @@ async def data_migrate():
+ subscribe_to_create + subscribe_to_create
) )
await sess.commit() await sess.commit()
drop()
logger.info("migrate success") logger.info("migrate success")

View File

@ -10,6 +10,8 @@ from sqlalchemy.sql.functions import func
from ..types import Category, Tag from ..types import Category, Tag
from ..types import Target as T_Target from ..types import Target as T_Target
from ..types import User as T_User
from ..types import UserSubInfo
from .db_model import ScheduleTimeWeight, Subscribe, Target, User from .db_model import ScheduleTimeWeight, Subscribe, Target, User
@ -238,5 +240,27 @@ class DBConfig:
res[key] = weight res[key] = weight
return res return res
async def get_platform_target_subscribers(
self, platform_name: str, target: T_Target
) -> list[UserSubInfo]:
async with AsyncSession(get_engine()) as sess:
query = (
select(Subscribe)
.join(Target)
.where(Target.platform_name == platform_name, Target.target == target)
.options(selectinload(Subscribe.user))
)
subsribes: list[Subscribe] = (await sess.scalars(query)).all()
return list(
map(
lambda subscribe: UserSubInfo(
T_User(subscribe.user.uid, subscribe.user.type),
subscribe.categories,
subscribe.tags,
),
subsribes,
)
)
config = DBConfig() config = DBConfig()

View File

@ -135,9 +135,7 @@ class Platform(metaclass=RegistryABCMeta, base=True):
self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo] self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]: ) -> list[tuple[User, list[Post]]]:
res: list[tuple[User, list[Post]]] = [] res: list[tuple[User, list[Post]]] = []
for user, category_getter, tag_getter in users: for user, cats, required_tags in users:
required_tags = tag_getter(target) if self.enable_tag else []
cats = category_getter(target)
user_raw_post = await self.filter_user_custom( user_raw_post = await self.filter_user_custom(
new_posts, cats, required_tags new_posts, cats, required_tags
) )

View File

@ -0,0 +1,31 @@
import logging
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from nonebot.log import LoguruHandler
from ..plugin_config import plugin_config
from ..send import do_send_msgs
aps = AsyncIOScheduler(timezone="Asia/Shanghai")
class CustomLogHandler(LoguruHandler):
def filter(self, record: logging.LogRecord):
return record.msg != (
'Execution of job "%s" '
"skipped: maximum number of running instances reached (%d)"
)
if plugin_config.bison_use_queue:
aps.add_job(do_send_msgs, "interval", seconds=0.3, coalesce=True)
aps_logger = logging.getLogger("apscheduler")
aps_logger.setLevel(30)
aps_logger.handlers.clear()
aps_logger.addHandler(CustomLogHandler())
def start_scheduler():
aps.configure({"apscheduler.timezone": "Asia/Shanghai"})
aps.start()

View File

@ -1,12 +1,17 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
import nonebot
from nonebot.adapters.onebot.v11.bot import Bot
from nonebot.log import logger from nonebot.log import logger
from ..config import config from ..config import config
from ..platform import platform_manager
from ..platform.platform import Platform from ..platform.platform import Platform
from ..send import send_msgs
from ..types import Target from ..types import Target
from ..utils import SchedulerConfig from ..utils import SchedulerConfig
from .aps import aps
@dataclass @dataclass
@ -22,6 +27,7 @@ class Scheduler:
def __init__(self, name: str, schedulables: list[tuple[str, Target]]): def __init__(self, name: str, schedulables: list[tuple[str, Target]]):
conf = SchedulerConfig.registry.get(name) conf = SchedulerConfig.registry.get(name)
self.name = name
if not conf: if not conf:
logger.error(f"scheduler config [{name}] not found, exiting") logger.error(f"scheduler config [{name}] not found, exiting")
raise RuntimeError(f"{name} not found") raise RuntimeError(f"{name} not found")
@ -37,8 +43,16 @@ class Scheduler:
platform_name_set.add(platform_name) platform_name_set.add(platform_name)
self.platform_name_list = list(platform_name_set) self.platform_name_list = list(platform_name_set)
self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值 self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值
logger.info(
f"register scheduler for {name} with {self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}"
)
aps.add_job(
self.exec_fetch,
self.scheduler_config.schedule_type,
**self.scheduler_config.schedule_setting,
)
async def schedule(self) -> Optional[Schedulable]: async def get_next_schedulable(self) -> Optional[Schedulable]:
if not self.schedulable_list: if not self.schedulable_list:
return None return None
cur_weight = await config.get_current_weight_val(self.platform_name_list) cur_weight = await config.get_current_weight_val(self.platform_name_list)
@ -61,6 +75,35 @@ class Scheduler:
cur_max_schedulable.current_weight -= weight_sum cur_max_schedulable.current_weight -= weight_sum
return cur_max_schedulable return cur_max_schedulable
async def exec_fetch(self):
if not (schedulable := await self.get_next_schedulable()):
return
logger.debug(
f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}"
)
send_userinfo_list = await config.get_platform_target_subscribers(
schedulable.platform_name, schedulable.target
)
to_send = await platform_manager[schedulable.platform_name].do_fetch_new_post(
schedulable.target, send_userinfo_list
)
if not to_send:
return
bot = nonebot.get_bot()
assert isinstance(bot, Bot)
for user, send_list in to_send:
for send_post in send_list:
logger.info("send to {}: {}".format(user, send_post))
if not bot:
logger.warning("no bot connected")
else:
await send_msgs(
bot,
user.user,
user.user_type,
await send_post.generate_messages(),
)
def insert_new_schedulable(self, platform_name: str, target: Target): def insert_new_schedulable(self, platform_name: str, target: Target):
self.pre_weight_val += 1000 self.pre_weight_val += 1000
self.schedulable_list.append(Schedulable(platform_name, target, 1000)) self.schedulable_list.append(Schedulable(platform_name, target, 1000))

View File

@ -22,5 +22,5 @@ class PlatformTarget:
class UserSubInfo(NamedTuple): class UserSubInfo(NamedTuple):
user: User user: User
category_getter: Callable[[Target], list[Category]] categories: list[Category]
tag_getter: Callable[[Target], list[Tag]] tags: list[Tag]

View File

@ -3,7 +3,7 @@ from datetime import time
from nonebug import App from nonebug import App
async def test_create_config(app: App, db_migration): async def test_create_config(app: App, init_scheduler):
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
from nonebot_bison.config.db_model import Subscribe, Target, User from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as T_Target from nonebot_bison.types import Target as T_Target
@ -52,7 +52,7 @@ async def test_create_config(app: App, db_migration):
assert test_config1.time_config == [] assert test_config1.time_config == []
async def test_get_current_weight(app: App, db_migration): async def test_get_current_weight(app: App, init_scheduler):
from datetime import time from datetime import time
from nonebot_bison.config import db_config from nonebot_bison.config import db_config
@ -84,7 +84,7 @@ async def test_get_current_weight(app: App, db_migration):
user_type="group", user_type="group",
target=T_Target("weibo_id1"), target=T_Target("weibo_id1"),
target_name="weibo_name2", target_name="weibo_name2",
platform_name="weibo2", platform_name="bilibili",
cats=[], cats=[],
tags=[], tags=[],
) )
@ -100,26 +100,26 @@ async def test_get_current_weight(app: App, db_migration):
), ),
) )
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(1, 30)) app.monkeypatch.setattr(db_config, "_get_time", lambda: time(1, 30))
weight = await config.get_current_weight_val(["weibo", "weibo2"]) weight = await config.get_current_weight_val(["weibo", "bilibili"])
assert len(weight) == 3 assert len(weight) == 3
assert weight["weibo-weibo_id"] == 20 assert weight["weibo-weibo_id"] == 20
assert weight["weibo-weibo_id1"] == 10 assert weight["weibo-weibo_id1"] == 10
assert weight["weibo2-weibo_id1"] == 10 assert weight["bilibili-weibo_id1"] == 10
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(4, 0)) app.monkeypatch.setattr(db_config, "_get_time", lambda: time(4, 0))
weight = await config.get_current_weight_val(["weibo", "weibo2"]) weight = await config.get_current_weight_val(["weibo", "bilibili"])
assert len(weight) == 3 assert len(weight) == 3
assert weight["weibo-weibo_id"] == 30 assert weight["weibo-weibo_id"] == 30
assert weight["weibo-weibo_id1"] == 10 assert weight["weibo-weibo_id1"] == 10
assert weight["weibo2-weibo_id1"] == 10 assert weight["bilibili-weibo_id1"] == 10
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(5, 0)) app.monkeypatch.setattr(db_config, "_get_time", lambda: time(5, 0))
weight = await config.get_current_weight_val(["weibo", "weibo2"]) weight = await config.get_current_weight_val(["weibo", "bilibili"])
assert len(weight) == 3 assert len(weight) == 3
assert weight["weibo-weibo_id"] == 10 assert weight["weibo-weibo_id"] == 10
assert weight["weibo-weibo_id1"] == 10 assert weight["weibo-weibo_id1"] == 10
assert weight["weibo2-weibo_id1"] == 10 assert weight["bilibili-weibo_id1"] == 10
async def test_get_platform_target(app: App, db_migration): async def test_get_platform_target(app: App, init_scheduler):
from nonebot_bison.config import db_config from nonebot_bison.config import db_config
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
from nonebot_bison.config.db_model import Subscribe, Target, User from nonebot_bison.config.db_model import Subscribe, Target, User
@ -167,3 +167,52 @@ async def test_get_platform_target(app: App, db_migration):
async with AsyncSession(get_engine()) as sess: async with AsyncSession(get_engine()) as sess:
res = await sess.scalars(select(Target).where(Target.platform_name == "weibo")) res = await sess.scalars(select(Target).where(Target.platform_name == "weibo"))
assert len(res.all()) == 2 assert len(res.all()) == 2
async def test_get_platform_target_subscribers(app: App, init_scheduler):
from nonebot_bison.config import db_config
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as T_Target
from nonebot_bison.types import User as T_User
from nonebot_bison.types import UserSubInfo
from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.expression import select
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[1],
tags=["tag1"],
)
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
cats=[2],
tags=["tag2"],
)
await config.add_subscribe(
user=245,
user_type="group",
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
cats=[3],
tags=["tag3"],
)
res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id"))
assert len(res) == 1
assert res[0] == UserSubInfo(T_User(123, "group"), [1], ["tag1"])
res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id1"))
assert len(res) == 2
assert UserSubInfo(T_User(123, "group"), [2], ["tag2"]) in res
assert UserSubInfo(T_User(245, "group"), [3], ["tag3"]) in res

View File

@ -30,7 +30,7 @@ def dummy_user_subinfo(app: App):
from nonebot_bison.types import User, UserSubInfo from nonebot_bison.types import User, UserSubInfo
user = User(123, "group") user = User(123, "group")
return UserSubInfo(user=user, category_getter=lambda _: [], tag_getter=lambda _: []) return UserSubInfo(user=user, categories=[], tags=[])
@pytest.fixture @pytest.fixture
@ -48,6 +48,13 @@ async def db_migration(app: App):
await sess.close() await sess.close()
@pytest.fixture
async def init_scheduler(db_migration):
from nonebot_bison.scheduler.manager import init_scheduler
await init_scheduler()
@pytest.fixture @pytest.fixture
async def use_legacy_config(app: App): async def use_legacy_config(app: App):
import aiofiles import aiofiles