mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-08 04:43:00 +08:00
finish new scheduler
This commit is contained in:
parent
331d0f6101
commit
9fa97704b0
@ -2,6 +2,7 @@ from nonebot import get_driver
|
|||||||
|
|
||||||
from .config.config_legacy import start_up as legacy_db_startup
|
from .config.config_legacy import start_up as legacy_db_startup
|
||||||
from .config.db import upgrade_db
|
from .config.db import upgrade_db
|
||||||
|
from .scheduler.aps import start_scheduler
|
||||||
from .scheduler.manager import init_scheduler
|
from .scheduler.manager import init_scheduler
|
||||||
|
|
||||||
|
|
||||||
@ -13,3 +14,5 @@ async def bootstrap():
|
|||||||
await upgrade_db()
|
await upgrade_db()
|
||||||
# init scheduler
|
# init scheduler
|
||||||
await init_scheduler()
|
await init_scheduler()
|
||||||
|
# start scheduler
|
||||||
|
start_scheduler()
|
||||||
|
@ -32,6 +32,22 @@ def get_config_path() -> str:
|
|||||||
return new_path
|
return new_path
|
||||||
|
|
||||||
|
|
||||||
|
def drop():
|
||||||
|
if plugin_config.bison_config_path:
|
||||||
|
data_dir = plugin_config.bison_config_path
|
||||||
|
else:
|
||||||
|
working_dir = os.getcwd()
|
||||||
|
data_dir = path.join(working_dir, "data")
|
||||||
|
old_path = path.join(data_dir, "bison.json")
|
||||||
|
new_path = path.join(data_dir, "bison-legacy.json")
|
||||||
|
if os.path.exists(old_path):
|
||||||
|
config.db.close()
|
||||||
|
config.available = False
|
||||||
|
os.rename(old_path, new_path)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class SubscribeContent(TypedDict):
|
class SubscribeContent(TypedDict):
|
||||||
target: str
|
target: str
|
||||||
target_type: str
|
target_type: str
|
||||||
@ -223,6 +239,8 @@ class Config(metaclass=Singleton):
|
|||||||
|
|
||||||
def start_up():
|
def start_up():
|
||||||
config = Config()
|
config = Config()
|
||||||
|
if not config.available:
|
||||||
|
return
|
||||||
if not (search_res := config.kv_config.search(Query().name == "version")):
|
if not (search_res := config.kv_config.search(Query().name == "version")):
|
||||||
config.kv_config.insert({"name": "version", "value": config.migrate_version})
|
config.kv_config.insert({"name": "version", "value": config.migrate_version})
|
||||||
elif search_res[0].get("value") < config.migrate_version:
|
elif search_res[0].get("value") < config.migrate_version:
|
||||||
|
@ -10,7 +10,7 @@ from nonebot_plugin_datastore.db import get_engine
|
|||||||
from sqlalchemy.engine.base import Connection
|
from sqlalchemy.engine.base import Connection
|
||||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from .config_legacy import ConfigContent, config
|
from .config_legacy import ConfigContent, config, drop
|
||||||
from .db_model import Base, Subscribe, Target, User
|
from .db_model import Base, Subscribe, Target, User
|
||||||
|
|
||||||
DATA = PluginData("bison")
|
DATA = PluginData("bison")
|
||||||
@ -67,6 +67,7 @@ async def data_migrate():
|
|||||||
+ subscribe_to_create
|
+ subscribe_to_create
|
||||||
)
|
)
|
||||||
await sess.commit()
|
await sess.commit()
|
||||||
|
drop()
|
||||||
logger.info("migrate success")
|
logger.info("migrate success")
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,6 +10,8 @@ from sqlalchemy.sql.functions import func
|
|||||||
|
|
||||||
from ..types import Category, Tag
|
from ..types import Category, Tag
|
||||||
from ..types import Target as T_Target
|
from ..types import Target as T_Target
|
||||||
|
from ..types import User as T_User
|
||||||
|
from ..types import UserSubInfo
|
||||||
from .db_model import ScheduleTimeWeight, Subscribe, Target, User
|
from .db_model import ScheduleTimeWeight, Subscribe, Target, User
|
||||||
|
|
||||||
|
|
||||||
@ -238,5 +240,27 @@ class DBConfig:
|
|||||||
res[key] = weight
|
res[key] = weight
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
async def get_platform_target_subscribers(
|
||||||
|
self, platform_name: str, target: T_Target
|
||||||
|
) -> list[UserSubInfo]:
|
||||||
|
async with AsyncSession(get_engine()) as sess:
|
||||||
|
query = (
|
||||||
|
select(Subscribe)
|
||||||
|
.join(Target)
|
||||||
|
.where(Target.platform_name == platform_name, Target.target == target)
|
||||||
|
.options(selectinload(Subscribe.user))
|
||||||
|
)
|
||||||
|
subsribes: list[Subscribe] = (await sess.scalars(query)).all()
|
||||||
|
return list(
|
||||||
|
map(
|
||||||
|
lambda subscribe: UserSubInfo(
|
||||||
|
T_User(subscribe.user.uid, subscribe.user.type),
|
||||||
|
subscribe.categories,
|
||||||
|
subscribe.tags,
|
||||||
|
),
|
||||||
|
subsribes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
config = DBConfig()
|
config = DBConfig()
|
||||||
|
@ -135,9 +135,7 @@ class Platform(metaclass=RegistryABCMeta, base=True):
|
|||||||
self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]
|
self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]
|
||||||
) -> list[tuple[User, list[Post]]]:
|
) -> list[tuple[User, list[Post]]]:
|
||||||
res: list[tuple[User, list[Post]]] = []
|
res: list[tuple[User, list[Post]]] = []
|
||||||
for user, category_getter, tag_getter in users:
|
for user, cats, required_tags in users:
|
||||||
required_tags = tag_getter(target) if self.enable_tag else []
|
|
||||||
cats = category_getter(target)
|
|
||||||
user_raw_post = await self.filter_user_custom(
|
user_raw_post = await self.filter_user_custom(
|
||||||
new_posts, cats, required_tags
|
new_posts, cats, required_tags
|
||||||
)
|
)
|
||||||
|
31
src/plugins/nonebot_bison/scheduler/aps.py
Normal file
31
src/plugins/nonebot_bison/scheduler/aps.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
|
from nonebot.log import LoguruHandler
|
||||||
|
|
||||||
|
from ..plugin_config import plugin_config
|
||||||
|
from ..send import do_send_msgs
|
||||||
|
|
||||||
|
aps = AsyncIOScheduler(timezone="Asia/Shanghai")
|
||||||
|
|
||||||
|
|
||||||
|
class CustomLogHandler(LoguruHandler):
|
||||||
|
def filter(self, record: logging.LogRecord):
|
||||||
|
return record.msg != (
|
||||||
|
'Execution of job "%s" '
|
||||||
|
"skipped: maximum number of running instances reached (%d)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if plugin_config.bison_use_queue:
|
||||||
|
aps.add_job(do_send_msgs, "interval", seconds=0.3, coalesce=True)
|
||||||
|
|
||||||
|
aps_logger = logging.getLogger("apscheduler")
|
||||||
|
aps_logger.setLevel(30)
|
||||||
|
aps_logger.handlers.clear()
|
||||||
|
aps_logger.addHandler(CustomLogHandler())
|
||||||
|
|
||||||
|
|
||||||
|
def start_scheduler():
|
||||||
|
aps.configure({"apscheduler.timezone": "Asia/Shanghai"})
|
||||||
|
aps.start()
|
@ -1,12 +1,17 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import nonebot
|
||||||
|
from nonebot.adapters.onebot.v11.bot import Bot
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
|
from ..platform import platform_manager
|
||||||
from ..platform.platform import Platform
|
from ..platform.platform import Platform
|
||||||
|
from ..send import send_msgs
|
||||||
from ..types import Target
|
from ..types import Target
|
||||||
from ..utils import SchedulerConfig
|
from ..utils import SchedulerConfig
|
||||||
|
from .aps import aps
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -22,6 +27,7 @@ class Scheduler:
|
|||||||
|
|
||||||
def __init__(self, name: str, schedulables: list[tuple[str, Target]]):
|
def __init__(self, name: str, schedulables: list[tuple[str, Target]]):
|
||||||
conf = SchedulerConfig.registry.get(name)
|
conf = SchedulerConfig.registry.get(name)
|
||||||
|
self.name = name
|
||||||
if not conf:
|
if not conf:
|
||||||
logger.error(f"scheduler config [{name}] not found, exiting")
|
logger.error(f"scheduler config [{name}] not found, exiting")
|
||||||
raise RuntimeError(f"{name} not found")
|
raise RuntimeError(f"{name} not found")
|
||||||
@ -37,8 +43,16 @@ class Scheduler:
|
|||||||
platform_name_set.add(platform_name)
|
platform_name_set.add(platform_name)
|
||||||
self.platform_name_list = list(platform_name_set)
|
self.platform_name_list = list(platform_name_set)
|
||||||
self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值
|
self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值
|
||||||
|
logger.info(
|
||||||
|
f"register scheduler for {name} with {self.scheduler_config.schedule_type} {self.scheduler_config.schedule_setting}"
|
||||||
|
)
|
||||||
|
aps.add_job(
|
||||||
|
self.exec_fetch,
|
||||||
|
self.scheduler_config.schedule_type,
|
||||||
|
**self.scheduler_config.schedule_setting,
|
||||||
|
)
|
||||||
|
|
||||||
async def schedule(self) -> Optional[Schedulable]:
|
async def get_next_schedulable(self) -> Optional[Schedulable]:
|
||||||
if not self.schedulable_list:
|
if not self.schedulable_list:
|
||||||
return None
|
return None
|
||||||
cur_weight = await config.get_current_weight_val(self.platform_name_list)
|
cur_weight = await config.get_current_weight_val(self.platform_name_list)
|
||||||
@ -61,6 +75,35 @@ class Scheduler:
|
|||||||
cur_max_schedulable.current_weight -= weight_sum
|
cur_max_schedulable.current_weight -= weight_sum
|
||||||
return cur_max_schedulable
|
return cur_max_schedulable
|
||||||
|
|
||||||
|
async def exec_fetch(self):
|
||||||
|
if not (schedulable := await self.get_next_schedulable()):
|
||||||
|
return
|
||||||
|
logger.debug(
|
||||||
|
f"scheduler {self.name} fetching next target: [{schedulable.platform_name}]{schedulable.target}"
|
||||||
|
)
|
||||||
|
send_userinfo_list = await config.get_platform_target_subscribers(
|
||||||
|
schedulable.platform_name, schedulable.target
|
||||||
|
)
|
||||||
|
to_send = await platform_manager[schedulable.platform_name].do_fetch_new_post(
|
||||||
|
schedulable.target, send_userinfo_list
|
||||||
|
)
|
||||||
|
if not to_send:
|
||||||
|
return
|
||||||
|
bot = nonebot.get_bot()
|
||||||
|
assert isinstance(bot, Bot)
|
||||||
|
for user, send_list in to_send:
|
||||||
|
for send_post in send_list:
|
||||||
|
logger.info("send to {}: {}".format(user, send_post))
|
||||||
|
if not bot:
|
||||||
|
logger.warning("no bot connected")
|
||||||
|
else:
|
||||||
|
await send_msgs(
|
||||||
|
bot,
|
||||||
|
user.user,
|
||||||
|
user.user_type,
|
||||||
|
await send_post.generate_messages(),
|
||||||
|
)
|
||||||
|
|
||||||
def insert_new_schedulable(self, platform_name: str, target: Target):
|
def insert_new_schedulable(self, platform_name: str, target: Target):
|
||||||
self.pre_weight_val += 1000
|
self.pre_weight_val += 1000
|
||||||
self.schedulable_list.append(Schedulable(platform_name, target, 1000))
|
self.schedulable_list.append(Schedulable(platform_name, target, 1000))
|
||||||
|
@ -22,5 +22,5 @@ class PlatformTarget:
|
|||||||
|
|
||||||
class UserSubInfo(NamedTuple):
|
class UserSubInfo(NamedTuple):
|
||||||
user: User
|
user: User
|
||||||
category_getter: Callable[[Target], list[Category]]
|
categories: list[Category]
|
||||||
tag_getter: Callable[[Target], list[Tag]]
|
tags: list[Tag]
|
||||||
|
@ -3,7 +3,7 @@ from datetime import time
|
|||||||
from nonebug import App
|
from nonebug import App
|
||||||
|
|
||||||
|
|
||||||
async def test_create_config(app: App, db_migration):
|
async def test_create_config(app: App, init_scheduler):
|
||||||
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
|
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
|
||||||
from nonebot_bison.config.db_model import Subscribe, Target, User
|
from nonebot_bison.config.db_model import Subscribe, Target, User
|
||||||
from nonebot_bison.types import Target as T_Target
|
from nonebot_bison.types import Target as T_Target
|
||||||
@ -52,7 +52,7 @@ async def test_create_config(app: App, db_migration):
|
|||||||
assert test_config1.time_config == []
|
assert test_config1.time_config == []
|
||||||
|
|
||||||
|
|
||||||
async def test_get_current_weight(app: App, db_migration):
|
async def test_get_current_weight(app: App, init_scheduler):
|
||||||
from datetime import time
|
from datetime import time
|
||||||
|
|
||||||
from nonebot_bison.config import db_config
|
from nonebot_bison.config import db_config
|
||||||
@ -84,7 +84,7 @@ async def test_get_current_weight(app: App, db_migration):
|
|||||||
user_type="group",
|
user_type="group",
|
||||||
target=T_Target("weibo_id1"),
|
target=T_Target("weibo_id1"),
|
||||||
target_name="weibo_name2",
|
target_name="weibo_name2",
|
||||||
platform_name="weibo2",
|
platform_name="bilibili",
|
||||||
cats=[],
|
cats=[],
|
||||||
tags=[],
|
tags=[],
|
||||||
)
|
)
|
||||||
@ -100,26 +100,26 @@ async def test_get_current_weight(app: App, db_migration):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(1, 30))
|
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(1, 30))
|
||||||
weight = await config.get_current_weight_val(["weibo", "weibo2"])
|
weight = await config.get_current_weight_val(["weibo", "bilibili"])
|
||||||
assert len(weight) == 3
|
assert len(weight) == 3
|
||||||
assert weight["weibo-weibo_id"] == 20
|
assert weight["weibo-weibo_id"] == 20
|
||||||
assert weight["weibo-weibo_id1"] == 10
|
assert weight["weibo-weibo_id1"] == 10
|
||||||
assert weight["weibo2-weibo_id1"] == 10
|
assert weight["bilibili-weibo_id1"] == 10
|
||||||
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(4, 0))
|
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(4, 0))
|
||||||
weight = await config.get_current_weight_val(["weibo", "weibo2"])
|
weight = await config.get_current_weight_val(["weibo", "bilibili"])
|
||||||
assert len(weight) == 3
|
assert len(weight) == 3
|
||||||
assert weight["weibo-weibo_id"] == 30
|
assert weight["weibo-weibo_id"] == 30
|
||||||
assert weight["weibo-weibo_id1"] == 10
|
assert weight["weibo-weibo_id1"] == 10
|
||||||
assert weight["weibo2-weibo_id1"] == 10
|
assert weight["bilibili-weibo_id1"] == 10
|
||||||
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(5, 0))
|
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(5, 0))
|
||||||
weight = await config.get_current_weight_val(["weibo", "weibo2"])
|
weight = await config.get_current_weight_val(["weibo", "bilibili"])
|
||||||
assert len(weight) == 3
|
assert len(weight) == 3
|
||||||
assert weight["weibo-weibo_id"] == 10
|
assert weight["weibo-weibo_id"] == 10
|
||||||
assert weight["weibo-weibo_id1"] == 10
|
assert weight["weibo-weibo_id1"] == 10
|
||||||
assert weight["weibo2-weibo_id1"] == 10
|
assert weight["bilibili-weibo_id1"] == 10
|
||||||
|
|
||||||
|
|
||||||
async def test_get_platform_target(app: App, db_migration):
|
async def test_get_platform_target(app: App, init_scheduler):
|
||||||
from nonebot_bison.config import db_config
|
from nonebot_bison.config import db_config
|
||||||
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
|
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
|
||||||
from nonebot_bison.config.db_model import Subscribe, Target, User
|
from nonebot_bison.config.db_model import Subscribe, Target, User
|
||||||
@ -167,3 +167,52 @@ async def test_get_platform_target(app: App, db_migration):
|
|||||||
async with AsyncSession(get_engine()) as sess:
|
async with AsyncSession(get_engine()) as sess:
|
||||||
res = await sess.scalars(select(Target).where(Target.platform_name == "weibo"))
|
res = await sess.scalars(select(Target).where(Target.platform_name == "weibo"))
|
||||||
assert len(res.all()) == 2
|
assert len(res.all()) == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_platform_target_subscribers(app: App, init_scheduler):
|
||||||
|
from nonebot_bison.config import db_config
|
||||||
|
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
|
||||||
|
from nonebot_bison.config.db_model import Subscribe, Target, User
|
||||||
|
from nonebot_bison.types import Target as T_Target
|
||||||
|
from nonebot_bison.types import User as T_User
|
||||||
|
from nonebot_bison.types import UserSubInfo
|
||||||
|
from nonebot_plugin_datastore.db import get_engine
|
||||||
|
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||||
|
from sqlalchemy.sql.expression import select
|
||||||
|
|
||||||
|
await config.add_subscribe(
|
||||||
|
user=123,
|
||||||
|
user_type="group",
|
||||||
|
target=T_Target("weibo_id"),
|
||||||
|
target_name="weibo_name",
|
||||||
|
platform_name="weibo",
|
||||||
|
cats=[1],
|
||||||
|
tags=["tag1"],
|
||||||
|
)
|
||||||
|
await config.add_subscribe(
|
||||||
|
user=123,
|
||||||
|
user_type="group",
|
||||||
|
target=T_Target("weibo_id1"),
|
||||||
|
target_name="weibo_name1",
|
||||||
|
platform_name="weibo",
|
||||||
|
cats=[2],
|
||||||
|
tags=["tag2"],
|
||||||
|
)
|
||||||
|
await config.add_subscribe(
|
||||||
|
user=245,
|
||||||
|
user_type="group",
|
||||||
|
target=T_Target("weibo_id1"),
|
||||||
|
target_name="weibo_name1",
|
||||||
|
platform_name="weibo",
|
||||||
|
cats=[3],
|
||||||
|
tags=["tag3"],
|
||||||
|
)
|
||||||
|
|
||||||
|
res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id"))
|
||||||
|
assert len(res) == 1
|
||||||
|
assert res[0] == UserSubInfo(T_User(123, "group"), [1], ["tag1"])
|
||||||
|
|
||||||
|
res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id1"))
|
||||||
|
assert len(res) == 2
|
||||||
|
assert UserSubInfo(T_User(123, "group"), [2], ["tag2"]) in res
|
||||||
|
assert UserSubInfo(T_User(245, "group"), [3], ["tag3"]) in res
|
||||||
|
@ -30,7 +30,7 @@ def dummy_user_subinfo(app: App):
|
|||||||
from nonebot_bison.types import User, UserSubInfo
|
from nonebot_bison.types import User, UserSubInfo
|
||||||
|
|
||||||
user = User(123, "group")
|
user = User(123, "group")
|
||||||
return UserSubInfo(user=user, category_getter=lambda _: [], tag_getter=lambda _: [])
|
return UserSubInfo(user=user, categories=[], tags=[])
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -48,6 +48,13 @@ async def db_migration(app: App):
|
|||||||
await sess.close()
|
await sess.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def init_scheduler(db_migration):
|
||||||
|
from nonebot_bison.scheduler.manager import init_scheduler
|
||||||
|
|
||||||
|
await init_scheduler()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def use_legacy_config(app: App):
|
async def use_legacy_config(app: App):
|
||||||
import aiofiles
|
import aiofiles
|
||||||
|
Loading…
x
Reference in New Issue
Block a user