This commit is contained in:
felinae98 2022-06-05 16:45:35 +08:00
parent 7b4c79acd3
commit 331d0f6101
No known key found for this signature in database
GPG Key ID: 00C8B010587FF610
18 changed files with 298 additions and 52 deletions

View File

@ -2,6 +2,7 @@ from nonebot.plugin import require
from . import ( from . import (
admin_page, admin_page,
bootstrap,
config, config,
config_manager, config_manager,
platform, platform,

View File

@ -0,0 +1,15 @@
from nonebot import get_driver
from .config.config_legacy import start_up as legacy_db_startup
from .config.db import upgrade_db
from .scheduler.manager import init_scheduler
@get_driver().on_startup
async def bootstrap():
    """Run the plugin's start-up steps, one after another.

    Registered on the driver's startup event.  The steps run strictly in
    this order: legacy database start-up, upgrade of the new database, then
    scheduler initialisation.
    """
    # 1. legacy db (plain synchronous call)
    legacy_db_startup()

    # 2. new db
    await upgrade_db()

    # 3. scheduler -- kept last; presumably it needs the upgraded database
    #    to look up subscribed targets (verify against init_scheduler).
    await init_scheduler()

View File

@ -243,6 +243,4 @@ def start_up():
config.update_send_cache() config.update_send_cache()
nonebot.get_driver().on_startup(start_up)
config = Config() config = Config()

View File

@ -70,7 +70,6 @@ async def data_migrate():
logger.info("migrate success") logger.info("migrate success")
@nonebot.get_driver().on_startup
async def upgrade_db(): async def upgrade_db():
alembic_cfg = Config() alembic_cfg = Config()
alembic_cfg.set_main_option( alembic_cfg.set_main_option(

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, time from datetime import datetime, time
from typing import Optional from typing import Any, Awaitable, Callable, Optional
from nonebot_plugin_datastore.db import get_engine from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio.session import AsyncSession
@ -34,6 +34,16 @@ class WeightConfig:
class DBConfig: class DBConfig:
def __init__(self):
    """Start with no target hooks registered."""
    # Each hook, once registered, is an async callable invoked as
    # ``hook(platform_name, target)``; ``None`` means "no hook".
    self.delete_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None
    self.add_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None
def register_add_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
    """Install ``fun`` as the add-target hook, replacing any previous one.

    :param fun: async callable taking ``(platform_name, target)``.
    """
    self.add_target_hook = fun
def register_delete_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
    """Install ``fun`` as the delete-target hook, replacing any previous one.

    :param fun: async callable taking ``(platform_name, target)``.
    """
    self.delete_target_hook = fun
async def add_subscribe( async def add_subscribe(
self, self,
user: int, user: int,
@ -62,6 +72,8 @@ class DBConfig:
db_target = Target( db_target = Target(
target=target, platform_name=platform_name, target_name=target_name target=target, platform_name=platform_name, target_name=target_name
) )
if self.add_target_hook:
await self.add_target_hook(platform_name, target)
else: else:
db_target.target_name = target_name # type: ignore db_target.target_name = target_name # type: ignore
subscribe = Subscribe( subscribe = Subscribe(
@ -93,12 +105,12 @@ class DBConfig:
) )
target_obj = await session.scalar( target_obj = await session.scalar(
select(Target).where( select(Target).where(
Target.platform_name == platform_name, MTarget.target == target Target.platform_name == platform_name, Target.target == target
) )
) )
await session.execute( await session.execute(
delete(Subscribe).where( delete(Subscribe).where(
Subscribe.user == user_obj, MSubscribe.target == target_obj Subscribe.user == user_obj, Subscribe.target == target_obj
) )
) )
target_count = await session.scalar( target_count = await session.scalar(
@ -108,7 +120,9 @@ class DBConfig:
) )
if target_count == 0: if target_count == 0:
# delete empty target # delete empty target
await session.delete(target_obj) # await session.delete(target_obj)
if self.delete_target_hook:
await self.delete_target_hook(platform_name, T_Target(target))
await session.commit() await session.commit()
async def update_subscribe( async def update_subscribe(
@ -139,15 +153,27 @@ class DBConfig:
subscribe_obj.target.target_name = target_name subscribe_obj.target.target_name = target_name
await sess.commit() await sess.commit()
async def get_platform_target(self, platform_name: str) -> list[Target]:
    """Return the targets of one platform that are referenced by a subscription.

    :param platform_name: platform whose targets are wanted.
    :return: ``Target`` rows of that platform joined against the distinct
        ``Subscribe.target_id`` values, so targets without any subscription
        are left out.
    """
    # Distinct ids of every target that some subscription points at.
    subscribed_ids = select(Subscribe.target_id).distinct().subquery()
    # The join condition is not spelled out here; it is left to SQLAlchemy
    # to derive it from the mapped relationship between the two tables.
    stmt = (
        select(Target)
        .join(subscribed_ids)
        .where(Target.platform_name == platform_name)
    )
    async with AsyncSession(get_engine()) as session:
        rows = await session.scalars(stmt)
        return rows.all()
async def get_time_weight_config( async def get_time_weight_config(
self, target: T_Target, platform_name: str self, target: T_Target, platform_name: str
) -> WeightConfig: ) -> WeightConfig:
async with AsyncSession(get_engine()) as sess: async with AsyncSession(get_engine()) as sess:
time_weight_conf: list[ScheduleTimeWeight] = await sess.scalars( time_weight_conf: list[ScheduleTimeWeight] = (
await sess.scalars(
select(ScheduleTimeWeight) select(ScheduleTimeWeight)
.where(Target.platform_name == platform_name, Target.target == target) .where(
Target.platform_name == platform_name, Target.target == target
)
.join(Target) .join(Target)
) )
).all()
targetObj: Target = await sess.scalar( targetObj: Target = await sess.scalar(
select(Target).where( select(Target).where(
Target.platform_name == platform_name, Target.target == target Target.platform_name == platform_name, Target.target == target
@ -192,11 +218,13 @@ class DBConfig:
res = {} res = {}
cur_time = _get_time() cur_time = _get_time()
async with AsyncSession(get_engine()) as sess: async with AsyncSession(get_engine()) as sess:
targets: list[Target] = await sess.scalars( targets: list[Target] = (
await sess.scalars(
select(Target) select(Target)
.where(Target.platform_name.in_(platform_list)) .where(Target.platform_name.in_(platform_list))
.options(selectinload(Target.time_weight)) .options(selectinload(Target.time_weight))
) )
).all()
for target in targets: for target in targets:
key = f"{target.platform_name}-{target.target}" key = f"{target.platform_name}-{target.target}"
weight = target.default_schedule_weight weight = target.default_schedule_weight

View File

@ -7,9 +7,16 @@ from nonebot.plugin import require
from ..post import Post from ..post import Post
from ..types import Category, RawPost, Target from ..types import Category, RawPost, Target
from ..utils import http_client from ..utils import http_client
from ..utils.scheduler_config import SchedulerConfig
from .platform import CategoryNotSupport, NewMessage, StatusChange from .platform import CategoryNotSupport, NewMessage, StatusChange
class ArknightsSchedConf(SchedulerConfig, name="arknights"):
    """Scheduler settings registered under the name ``"arknights"``."""

    # Trigger kind and its keyword settings: an interval of 30 seconds.
    schedule_type = "interval"
    schedule_setting = {"seconds": 30}
class Arknights(NewMessage): class Arknights(NewMessage):
categories = {1: "游戏公告"} categories = {1: "游戏公告"}
@ -18,8 +25,7 @@ class Arknights(NewMessage):
enable_tag = False enable_tag = False
enabled = True enabled = True
is_common = False is_common = False
schedule_type = "interval" scheduler_class = "arknights"
schedule_kw = {"seconds": 30}
has_target = False has_target = False
async def get_target_name(self, _: Target) -> str: async def get_target_name(self, _: Target) -> str:
@ -91,8 +97,7 @@ class AkVersion(StatusChange):
enable_tag = False enable_tag = False
enabled = True enabled = True
is_common = False is_common = False
schedule_type = "interval" scheduler_class = "arknights"
schedule_kw = {"seconds": 30}
has_target = False has_target = False
async def get_target_name(self, _: Target) -> str: async def get_target_name(self, _: Target) -> str:
@ -147,8 +152,7 @@ class MonsterSiren(NewMessage):
enable_tag = False enable_tag = False
enabled = True enabled = True
is_common = False is_common = False
schedule_type = "interval" scheduler_class = "arknights"
schedule_kw = {"seconds": 30}
has_target = False has_target = False
async def get_target_name(self, _: Target) -> str: async def get_target_name(self, _: Target) -> str:
@ -199,8 +203,7 @@ class TerraHistoricusComic(NewMessage):
enable_tag = False enable_tag = False
enabled = True enabled = True
is_common = False is_common = False
schedule_type = "interval" scheduler_class = "arknights"
schedule_kw = {"seconds": 30}
has_target = False has_target = False
async def get_target_name(self, _: Target) -> str: async def get_target_name(self, _: Target) -> str:

View File

@ -4,10 +4,16 @@ from typing import Any, Optional
from ..post import Post from ..post import Post
from ..types import Category, RawPost, Tag, Target from ..types import Category, RawPost, Tag, Target
from ..utils import http_client from ..utils import SchedulerConfig, http_client
from .platform import CategoryNotSupport, NewMessage, StatusChange from .platform import CategoryNotSupport, NewMessage, StatusChange
class BilibiliSchedConf(SchedulerConfig, name="bilibili.com"):
    """Scheduler settings registered under the name ``"bilibili.com"``."""

    # Trigger kind and its keyword settings: an interval of 10 seconds.
    schedule_type = "interval"
    schedule_setting = {"seconds": 10}
class Bilibili(NewMessage): class Bilibili(NewMessage):
categories = { categories = {
@ -22,8 +28,7 @@ class Bilibili(NewMessage):
enable_tag = True enable_tag = True
enabled = True enabled = True
is_common = True is_common = True
schedule_type = "interval" scheduler_class = "bilibili.com"
schedule_kw = {"seconds": 10}
name = "B站" name = "B站"
has_target = True has_target = True
parse_target_promot = "请输入用户主页的链接" parse_target_promot = "请输入用户主页的链接"
@ -167,8 +172,7 @@ class Bilibililive(StatusChange):
enable_tag = True enable_tag = True
enabled = True enabled = True
is_common = True is_common = True
schedule_type = "interval" scheduler_class = "bilibili.com"
schedule_kw = {"seconds": 10}
name = "Bilibili直播" name = "Bilibili直播"
has_target = True has_target = True

View File

@ -2,10 +2,16 @@ from typing import Any
from ..post import Post from ..post import Post
from ..types import RawPost, Target from ..types import RawPost, Target
from ..utils import http_client from ..utils import SchedulerConfig, http_client
from .platform import NewMessage from .platform import NewMessage
class FF14SchedConf(SchedulerConfig, name="ff14"):
    """Scheduler settings registered under the name ``"ff14"``."""

    # Trigger kind and its keyword settings: an interval of 60 seconds.
    schedule_type = "interval"
    schedule_setting = {"seconds": 60}
class FF14(NewMessage): class FF14(NewMessage):
categories = {} categories = {}
@ -14,8 +20,7 @@ class FF14(NewMessage):
enable_tag = False enable_tag = False
enabled = True enabled = True
is_common = False is_common = False
schedule_type = "interval" scheduler_class = "ff14"
schedule_kw = {"seconds": 60}
has_target = False has_target = False
async def get_target_name(self, _: Target) -> str: async def get_target_name(self, _: Target) -> str:

View File

@ -3,10 +3,16 @@ from typing import Any, Optional
from ..post import Post from ..post import Post
from ..types import RawPost, Target from ..types import RawPost, Target
from ..utils import http_client from ..utils import SchedulerConfig, http_client
from .platform import NewMessage from .platform import NewMessage
class NcmSchedConf(SchedulerConfig, name="music.163.com"):
    """Scheduler settings registered under the name ``"music.163.com"``."""

    # Trigger kind and its keyword settings: an interval of 1 minute.
    schedule_type = "interval"
    schedule_setting = {"minutes": 1}
class NcmArtist(NewMessage): class NcmArtist(NewMessage):
categories = {} categories = {}
@ -14,8 +20,7 @@ class NcmArtist(NewMessage):
enable_tag = False enable_tag = False
enabled = True enabled = True
is_common = True is_common = True
schedule_type = "interval" scheduler_class = "music.163.com"
schedule_kw = {"minutes": 1}
name = "网易云-歌手" name = "网易云-歌手"
has_target = True has_target = True
parse_target_promot = "请输入歌手主页包含数字ID的链接" parse_target_promot = "请输入歌手主页包含数字ID的链接"

View File

@ -14,8 +14,7 @@ class NcmRadio(NewMessage):
enable_tag = False enable_tag = False
enabled = True enabled = True
is_common = False is_common = False
schedule_type = "interval" scheduler_class = "music.163.com"
schedule_kw = {"minutes": 10}
name = "网易云-电台" name = "网易云-电台"
has_target = True has_target = True
parse_target_promot = "请输入主播电台主页包含数字ID的链接" parse_target_promot = "请输入主播电台主页包含数字ID的链接"

View File

@ -39,8 +39,7 @@ class RegistryABCMeta(RegistryMeta, ABC):
class Platform(metaclass=RegistryABCMeta, base=True): class Platform(metaclass=RegistryABCMeta, base=True):
schedule_type: Literal["date", "interval", "cron"] scheduler_class: str
schedule_kw: dict
is_common: bool is_common: bool
enabled: bool enabled: bool
name: str name: str
@ -332,11 +331,11 @@ class NoTargetGroup(Platform, abstract=True):
def __init__(self, platform_list: list[Platform]): def __init__(self, platform_list: list[Platform]):
self.platform_list = platform_list self.platform_list = platform_list
self.platform_name = platform_list[0].platform_name
name = self.DUMMY_STR name = self.DUMMY_STR
self.categories = {} self.categories = {}
categories_keys = set() categories_keys = set()
self.schedule_type = platform_list[0].schedule_type self.scheduler_class = platform_list[0].scheduler_class
self.schedule_kw = platform_list[0].schedule_kw
for platform in platform_list: for platform in platform_list:
if platform.has_target: if platform.has_target:
raise RuntimeError( raise RuntimeError(
@ -355,10 +354,7 @@ class NoTargetGroup(Platform, abstract=True):
) )
categories_keys |= platform_category_key_set categories_keys |= platform_category_key_set
self.categories.update(platform.categories) self.categories.update(platform.categories)
if ( if platform.scheduler_class != self.scheduler_class:
platform.schedule_kw != self.schedule_kw
or platform.schedule_type != self.schedule_type
):
raise RuntimeError( raise RuntimeError(
"Platform scheduler for {} not fit".format(self.platform_name) "Platform scheduler for {} not fit".format(self.platform_name)
) )

View File

@ -6,10 +6,16 @@ from bs4 import BeautifulSoup as bs
from ..post import Post from ..post import Post
from ..types import RawPost, Target from ..types import RawPost, Target
from ..utils import http_client from ..utils import SchedulerConfig, http_client
from .platform import NewMessage from .platform import NewMessage
class RssSchedConf(SchedulerConfig, name="rss"):
    """Scheduler settings registered under the name ``"rss"``."""

    # Trigger kind and its keyword settings: an interval of 30 seconds.
    schedule_type = "interval"
    schedule_setting = {"seconds": 30}
class Rss(NewMessage): class Rss(NewMessage):
categories = {} categories = {}
@ -18,8 +24,7 @@ class Rss(NewMessage):
name = "Rss" name = "Rss"
enabled = True enabled = True
is_common = True is_common = True
schedule_type = "interval" scheduler_class = "rss"
schedule_kw = {"seconds": 30}
has_target = True has_target = True
async def get_target_name(self, target: Target) -> Optional[str]: async def get_target_name(self, target: Target) -> Optional[str]:

View File

@ -8,10 +8,15 @@ from nonebot.log import logger
from ..post import Post from ..post import Post
from ..types import * from ..types import *
from ..utils import http_client from ..utils import SchedulerConfig, http_client
from .platform import NewMessage from .platform import NewMessage
class WeiboSchedConf(SchedulerConfig, name="weibo.com"):
    """Scheduler settings registered under the name ``"weibo.com"``."""

    # Trigger kind and its keyword settings: an interval of 3 seconds.
    schedule_type = "interval"
    schedule_setting = {"seconds": 3}
class Weibo(NewMessage): class Weibo(NewMessage):
categories = { categories = {
@ -25,8 +30,7 @@ class Weibo(NewMessage):
name = "新浪微博" name = "新浪微博"
enabled = True enabled = True
is_common = True is_common = True
schedule_type = "interval" scheduler_class = "weibo.com"
schedule_kw = {"seconds": 3}
has_target = True has_target = True
parse_target_promot = "请输入用户主页包含数字UID的链接" parse_target_promot = "请输入用户主页包含数字UID的链接"

View File

@ -0,0 +1 @@
from .manager import *

View File

@ -0,0 +1,43 @@
from nonebot.log import logger
from ..config import config
from ..config.db_model import Target
from ..platform import platform_manager
from ..types import Target as T_Target
from ..utils import SchedulerConfig
from .scheduler import Scheduler
scheduler_dict: dict[str, Scheduler] = {}
_schedule_class_dict: dict[str, list[Target]] = {}
async def init_scheduler():
    """Build one ``Scheduler`` per scheduler class from the stored targets.

    First every registered platform's targets are fetched through ``config``
    and grouped by the platform's ``scheduler_class``; then a ``Scheduler``
    is created for each group and stored in ``scheduler_dict``.
    """
    # Phase 1: group database targets by scheduler class.
    for plat in platform_manager.values():
        class_name = plat.scheduler_class
        fetched = await config.get_platform_target(plat.platform_name)
        bucket = _schedule_class_dict.get(class_name)
        if bucket is None:
            _schedule_class_dict[class_name] = fetched
        else:
            bucket.extend(fetched)

    # Phase 2: one Scheduler per class, fed with (platform_name, target) pairs.
    for class_name, db_targets in _schedule_class_dict.items():
        pairs = [
            (db_target.platform_name, T_Target(db_target.target))
            for db_target in db_targets
        ]
        scheduler_dict[class_name] = Scheduler(class_name, pairs)
async def handle_insert_new_target(platform_name: str, target: T_Target):
    """Add ``target`` to the scheduler that serves ``platform_name``.

    The scheduler is found through the platform's ``scheduler_class``; an
    unknown platform or scheduler class raises ``KeyError``.
    """
    scheduler_class = platform_manager[platform_name].scheduler_class
    scheduler_dict[scheduler_class].insert_new_schedulable(platform_name, target)
async def handle_delete_target(platform_name: str, target: T_Target):
    """Remove ``target`` from the scheduler that serves ``platform_name``.

    The scheduler is found through the platform's ``scheduler_class``; an
    unknown platform or scheduler class raises ``KeyError``.
    """
    scheduler_class = platform_manager[platform_name].scheduler_class
    scheduler_dict[scheduler_class].delete_schedulable(platform_name, target)
# Wire the schedulers to the subscription config: the add hook must insert the
# target into its scheduler and the delete hook must remove it.  (The add hook
# used to be registered with the delete handler, so new targets were never
# scheduled.)
config.register_add_target_hook(handle_insert_new_target)
config.register_delete_target_hook(handle_delete_target)

View File

@ -0,0 +1,85 @@
from dataclasses import dataclass
from typing import Optional
from nonebot.log import logger
from ..config import config
from ..platform.platform import Platform
from ..types import Target
from ..utils import SchedulerConfig
@dataclass
class Schedulable:
    """A (platform, target) pair tracked by a ``Scheduler``."""

    # Name of the platform the target belongs to.
    platform_name: str
    # Target identifier on that platform.
    target: Target
    # Running weight; the scheduler raises it every round and lowers it
    # when this entry is picked.
    current_weight: int
class Scheduler:
    """Weighted picker over the targets that share one scheduler config.

    Every call to :meth:`schedule` adds each entry's current configured
    weight to its running ``current_weight``, picks the entry with the
    largest running weight and then lowers the winner by the sum of all
    weights added in that round (a smooth weighted round-robin style scheme).
    """

    schedulable_list: list[Schedulable]

    def __init__(self, name: str, schedulables: list[tuple[str, Target]]):
        """Create the scheduler for the config registered as ``name``.

        :param name: key into ``SchedulerConfig.registry``.
        :param schedulables: initial ``(platform_name, target)`` pairs.
        :raises RuntimeError: if no config is registered under ``name``.
        """
        conf = SchedulerConfig.registry.get(name)
        if not conf:
            logger.error(f"scheduler config [{name}] not found, exiting")
            raise RuntimeError(f"{name} not found")
        self.scheduler_config = conf
        self.schedulable_list = [
            Schedulable(platform_name=platform_name, target=target, current_weight=0)
            for platform_name, target in schedulables
        ]
        # Unique platform names, in first-seen order.
        self.platform_name_list = list(
            dict.fromkeys(item.platform_name for item in self.schedulable_list)
        )
        # Initial value of the sum of weights added in "this round" of the
        # round-robin; absorbs the running weight of inserted/deleted entries.
        self.pre_weight_val = 0

    async def schedule(self) -> Optional[Schedulable]:
        """Pick the next entry to run, or ``None`` when there are no entries."""
        if not self.schedulable_list:
            return None
        cur_weight = await config.get_current_weight_val(self.platform_name_list)
        weight_sum = self.pre_weight_val
        self.pre_weight_val = 0
        cur_max_schedulable = None
        for schedulable in self.schedulable_list:
            weight = cur_weight[f"{schedulable.platform_name}-{schedulable.target}"]
            schedulable.current_weight += weight
            weight_sum += weight
            # Strict "<" keeps the earliest entry on ties.
            if (
                cur_max_schedulable is None
                or cur_max_schedulable.current_weight < schedulable.current_weight
            ):
                cur_max_schedulable = schedulable
        assert cur_max_schedulable
        cur_max_schedulable.current_weight -= weight_sum
        return cur_max_schedulable

    def insert_new_schedulable(self, platform_name: str, target: Target):
        """Add a new entry with a head-start running weight of 1000."""
        self.pre_weight_val += 1000
        self.schedulable_list.append(Schedulable(platform_name, target, 1000))
        # schedule() only asks for the weights of the platforms listed here;
        # without this, the first target of a platform that had none at
        # start-up would presumably be missing from the weight mapping and
        # make schedule() fail with KeyError.
        if platform_name not in self.platform_name_list:
            self.platform_name_list.append(platform_name)
        logger.info(
            f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})"
        )

    def delete_schedulable(self, platform_name, target: Target):
        """Remove the first entry matching ``(platform_name, target)``, if any."""
        for idx, schedulable in enumerate(self.schedulable_list):
            if (
                schedulable.platform_name == platform_name
                and schedulable.target == target
            ):
                deleted_schedulable = self.schedulable_list.pop(idx)
                # Carry the removed running weight into the next round's sum.
                self.pre_weight_val -= deleted_schedulable.current_weight
                return

View File

@ -6,7 +6,12 @@ class SchedulerConfig:
schedule_type: Literal["date", "interval", "cron"] schedule_type: Literal["date", "interval", "cron"]
schedule_setting: dict schedule_setting: dict
registry: dict[str, Type["SchedulerConfig"]] = {} registry: dict[str, Type["SchedulerConfig"]] = {}
name: str
def __init_subclass__(cls, *, name, **kwargs): def __init_subclass__(cls, *, name, **kwargs):
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
cls.registry[name] = cls cls.registry[name] = cls
cls.name = name
def __str__(self):
    """Return ``[name]-schedule_type-schedule_setting``.

    The middle field used to repeat ``name``; it now shows the schedule type.
    """
    return f"[{self.name}]-{self.schedule_type}-{self.schedule_setting}"

View File

@ -117,3 +117,53 @@ async def test_get_current_weight(app: App, db_migration):
assert weight["weibo-weibo_id"] == 10 assert weight["weibo-weibo_id"] == 10
assert weight["weibo-weibo_id1"] == 10 assert weight["weibo-weibo_id1"] == 10
assert weight["weibo2-weibo_id1"] == 10 assert weight["weibo2-weibo_id1"] == 10
async def test_get_platform_target(app: App, db_migration):
    """``get_platform_target`` lists only targets that still have a subscriber.

    Two weibo targets are subscribed (one of them by two users).  Removing
    one of the two subscribers keeps both targets listed; removing the only
    subscriber of the other target drops it from the listing, while both
    ``Target`` rows remain in the table.
    """
    from nonebot_bison.config.db_config import config
    from nonebot_bison.config.db_model import Target
    from nonebot_bison.types import Target as T_Target
    from nonebot_plugin_datastore.db import get_engine
    from sqlalchemy.ext.asyncio.session import AsyncSession
    from sqlalchemy.sql.expression import select

    # (user, target id, target name)
    subscriptions = (
        (123, "weibo_id", "weibo_name"),
        (123, "weibo_id1", "weibo_name1"),
        (245, "weibo_id1", "weibo_name1"),
    )
    for user, target_id, target_name in subscriptions:
        await config.add_subscribe(
            user=user,
            user_type="group",
            target=T_Target(target_id),
            target_name=target_name,
            platform_name="weibo",
            cats=[],
            tags=[],
        )

    res = await config.get_platform_target("weibo")
    assert len(res) == 2

    # weibo_id1 is still subscribed by user 245.
    await config.del_subscribe(123, "group", T_Target("weibo_id1"), "weibo")
    res = await config.get_platform_target("weibo")
    assert len(res) == 2

    # weibo_id has lost its only subscriber.
    await config.del_subscribe(123, "group", T_Target("weibo_id"), "weibo")
    res = await config.get_platform_target("weibo")
    assert len(res) == 1

    # The Target rows themselves are not deleted.
    async with AsyncSession(get_engine()) as sess:
        res = await sess.scalars(select(Target).where(Target.platform_name == "weibo"))
        assert len(res.all()) == 2