This commit is contained in:
felinae98 2022-06-05 16:45:35 +08:00
parent 7b4c79acd3
commit 331d0f6101
No known key found for this signature in database
GPG Key ID: 00C8B010587FF610
18 changed files with 298 additions and 52 deletions

View File

@ -2,6 +2,7 @@ from nonebot.plugin import require
from . import (
admin_page,
bootstrap,
config,
config_manager,
platform,

View File

@ -0,0 +1,15 @@
from nonebot import get_driver
from .config.config_legacy import start_up as legacy_db_startup
from .config.db import upgrade_db
from .scheduler.manager import init_scheduler
@get_driver().on_startup
async def bootstrap():
# legacy db
legacy_db_startup()
# new db
await upgrade_db()
# init scheduler
await init_scheduler()

View File

@ -243,6 +243,4 @@ def start_up():
config.update_send_cache()
nonebot.get_driver().on_startup(start_up)
config = Config()

View File

@ -70,7 +70,6 @@ async def data_migrate():
logger.info("migrate success")
@nonebot.get_driver().on_startup
async def upgrade_db():
alembic_cfg = Config()
alembic_cfg.set_main_option(

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass
from datetime import datetime, time
from typing import Optional
from typing import Any, Awaitable, Callable, Optional
from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
@ -34,6 +34,16 @@ class WeightConfig:
class DBConfig:
def __init__(self):
self.add_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None
self.delete_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None
def register_add_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
self.add_target_hook = fun
def register_delete_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
self.delete_target_hook = fun
async def add_subscribe(
self,
user: int,
@ -62,6 +72,8 @@ class DBConfig:
db_target = Target(
target=target, platform_name=platform_name, target_name=target_name
)
if self.add_target_hook:
await self.add_target_hook(platform_name, target)
else:
db_target.target_name = target_name # type: ignore
subscribe = Subscribe(
@ -93,12 +105,12 @@ class DBConfig:
)
target_obj = await session.scalar(
select(Target).where(
Target.platform_name == platform_name, MTarget.target == target
Target.platform_name == platform_name, Target.target == target
)
)
await session.execute(
delete(Subscribe).where(
Subscribe.user == user_obj, MSubscribe.target == target_obj
Subscribe.user == user_obj, Subscribe.target == target_obj
)
)
target_count = await session.scalar(
@ -108,7 +120,9 @@ class DBConfig:
)
if target_count == 0:
# delete empty target
await session.delete(target_obj)
# await session.delete(target_obj)
if self.delete_target_hook:
await self.delete_target_hook(platform_name, T_Target(target))
await session.commit()
async def update_subscribe(
@ -139,15 +153,27 @@ class DBConfig:
subscribe_obj.target.target_name = target_name
await sess.commit()
async def get_platform_target(self, platform_name: str) -> list[Target]:
async with AsyncSession(get_engine()) as sess:
subq = select(Subscribe.target_id).distinct().subquery()
query = (
select(Target).join(subq).where(Target.platform_name == platform_name)
)
return (await sess.scalars(query)).all()
async def get_time_weight_config(
self, target: T_Target, platform_name: str
) -> WeightConfig:
async with AsyncSession(get_engine()) as sess:
time_weight_conf: list[ScheduleTimeWeight] = await sess.scalars(
select(ScheduleTimeWeight)
.where(Target.platform_name == platform_name, Target.target == target)
.join(Target)
)
time_weight_conf: list[ScheduleTimeWeight] = (
await sess.scalars(
select(ScheduleTimeWeight)
.where(
Target.platform_name == platform_name, Target.target == target
)
.join(Target)
)
).all()
targetObj: Target = await sess.scalar(
select(Target).where(
Target.platform_name == platform_name, Target.target == target
@ -192,11 +218,13 @@ class DBConfig:
res = {}
cur_time = _get_time()
async with AsyncSession(get_engine()) as sess:
targets: list[Target] = await sess.scalars(
select(Target)
.where(Target.platform_name.in_(platform_list))
.options(selectinload(Target.time_weight))
)
targets: list[Target] = (
await sess.scalars(
select(Target)
.where(Target.platform_name.in_(platform_list))
.options(selectinload(Target.time_weight))
)
).all()
for target in targets:
key = f"{target.platform_name}-{target.target}"
weight = target.default_schedule_weight

View File

@ -7,9 +7,16 @@ from nonebot.plugin import require
from ..post import Post
from ..types import Category, RawPost, Target
from ..utils import http_client
from ..utils.scheduler_config import SchedulerConfig
from .platform import CategoryNotSupport, NewMessage, StatusChange
class ArknightsSchedConf(SchedulerConfig, name="arknights"):
schedule_type = "interval"
schedule_setting = {"seconds": 30}
class Arknights(NewMessage):
categories = {1: "游戏公告"}
@ -18,8 +25,7 @@ class Arknights(NewMessage):
enable_tag = False
enabled = True
is_common = False
schedule_type = "interval"
schedule_kw = {"seconds": 30}
scheduler_class = "arknights"
has_target = False
async def get_target_name(self, _: Target) -> str:
@ -91,8 +97,7 @@ class AkVersion(StatusChange):
enable_tag = False
enabled = True
is_common = False
schedule_type = "interval"
schedule_kw = {"seconds": 30}
scheduler_class = "arknights"
has_target = False
async def get_target_name(self, _: Target) -> str:
@ -147,8 +152,7 @@ class MonsterSiren(NewMessage):
enable_tag = False
enabled = True
is_common = False
schedule_type = "interval"
schedule_kw = {"seconds": 30}
scheduler_class = "arknights"
has_target = False
async def get_target_name(self, _: Target) -> str:
@ -199,8 +203,7 @@ class TerraHistoricusComic(NewMessage):
enable_tag = False
enabled = True
is_common = False
schedule_type = "interval"
schedule_kw = {"seconds": 30}
scheduler_class = "arknights"
has_target = False
async def get_target_name(self, _: Target) -> str:

View File

@ -4,10 +4,16 @@ from typing import Any, Optional
from ..post import Post
from ..types import Category, RawPost, Tag, Target
from ..utils import http_client
from ..utils import SchedulerConfig, http_client
from .platform import CategoryNotSupport, NewMessage, StatusChange
class BilibiliSchedConf(SchedulerConfig, name="bilibili.com"):
schedule_type = "interval"
schedule_setting = {"seconds": 10}
class Bilibili(NewMessage):
categories = {
@ -22,8 +28,7 @@ class Bilibili(NewMessage):
enable_tag = True
enabled = True
is_common = True
schedule_type = "interval"
schedule_kw = {"seconds": 10}
scheduler_class = "bilibili.com"
name = "B站"
has_target = True
parse_target_promot = "请输入用户主页的链接"
@ -167,8 +172,7 @@ class Bilibililive(StatusChange):
enable_tag = True
enabled = True
is_common = True
schedule_type = "interval"
schedule_kw = {"seconds": 10}
scheduler_class = "bilibili.com"
name = "Bilibili直播"
has_target = True

View File

@ -2,10 +2,16 @@ from typing import Any
from ..post import Post
from ..types import RawPost, Target
from ..utils import http_client
from ..utils import SchedulerConfig, http_client
from .platform import NewMessage
class FF14SchedConf(SchedulerConfig, name="ff14"):
schedule_type = "interval"
schedule_setting = {"seconds": 60}
class FF14(NewMessage):
categories = {}
@ -14,8 +20,7 @@ class FF14(NewMessage):
enable_tag = False
enabled = True
is_common = False
schedule_type = "interval"
schedule_kw = {"seconds": 60}
scheduler_class = "ff14"
has_target = False
async def get_target_name(self, _: Target) -> str:

View File

@ -3,10 +3,16 @@ from typing import Any, Optional
from ..post import Post
from ..types import RawPost, Target
from ..utils import http_client
from ..utils import SchedulerConfig, http_client
from .platform import NewMessage
class NcmSchedConf(SchedulerConfig, name="music.163.com"):
schedule_type = "interval"
schedule_setting = {"minutes": 1}
class NcmArtist(NewMessage):
categories = {}
@ -14,8 +20,7 @@ class NcmArtist(NewMessage):
enable_tag = False
enabled = True
is_common = True
schedule_type = "interval"
schedule_kw = {"minutes": 1}
scheduler_class = "music.163.com"
name = "网易云-歌手"
has_target = True
parse_target_promot = "请输入歌手主页包含数字ID的链接"

View File

@ -14,8 +14,7 @@ class NcmRadio(NewMessage):
enable_tag = False
enabled = True
is_common = False
schedule_type = "interval"
schedule_kw = {"minutes": 10}
scheduler_class = "music.163.com"
name = "网易云-电台"
has_target = True
parse_target_promot = "请输入主播电台主页包含数字ID的链接"

View File

@ -39,8 +39,7 @@ class RegistryABCMeta(RegistryMeta, ABC):
class Platform(metaclass=RegistryABCMeta, base=True):
schedule_type: Literal["date", "interval", "cron"]
schedule_kw: dict
scheduler_class: str
is_common: bool
enabled: bool
name: str
@ -332,11 +331,11 @@ class NoTargetGroup(Platform, abstract=True):
def __init__(self, platform_list: list[Platform]):
self.platform_list = platform_list
self.platform_name = platform_list[0].platform_name
name = self.DUMMY_STR
self.categories = {}
categories_keys = set()
self.schedule_type = platform_list[0].schedule_type
self.schedule_kw = platform_list[0].schedule_kw
self.scheduler_class = platform_list[0].scheduler_class
for platform in platform_list:
if platform.has_target:
raise RuntimeError(
@ -355,10 +354,7 @@ class NoTargetGroup(Platform, abstract=True):
)
categories_keys |= platform_category_key_set
self.categories.update(platform.categories)
if (
platform.schedule_kw != self.schedule_kw
or platform.schedule_type != self.schedule_type
):
if platform.scheduler_class != self.scheduler_class:
raise RuntimeError(
"Platform scheduler for {} not fit".format(self.platform_name)
)

View File

@ -6,10 +6,16 @@ from bs4 import BeautifulSoup as bs
from ..post import Post
from ..types import RawPost, Target
from ..utils import http_client
from ..utils import SchedulerConfig, http_client
from .platform import NewMessage
class RssSchedConf(SchedulerConfig, name="rss"):
schedule_type = "interval"
schedule_setting = {"seconds": 30}
class Rss(NewMessage):
categories = {}
@ -18,8 +24,7 @@ class Rss(NewMessage):
name = "Rss"
enabled = True
is_common = True
schedule_type = "interval"
schedule_kw = {"seconds": 30}
scheduler_class = "rss"
has_target = True
async def get_target_name(self, target: Target) -> Optional[str]:

View File

@ -8,10 +8,15 @@ from nonebot.log import logger
from ..post import Post
from ..types import *
from ..utils import http_client
from ..utils import SchedulerConfig, http_client
from .platform import NewMessage
class WeiboSchedConf(SchedulerConfig, name="weibo.com"):
schedule_type = "interval"
schedule_setting = {"seconds": 3}
class Weibo(NewMessage):
categories = {
@ -25,8 +30,7 @@ class Weibo(NewMessage):
name = "新浪微博"
enabled = True
is_common = True
schedule_type = "interval"
schedule_kw = {"seconds": 3}
scheduler_class = "weibo.com"
has_target = True
parse_target_promot = "请输入用户主页包含数字UID的链接"

View File

@ -0,0 +1 @@
from .manager import *

View File

@ -0,0 +1,43 @@
from nonebot.log import logger
from ..config import config
from ..config.db_model import Target
from ..platform import platform_manager
from ..types import Target as T_Target
from ..utils import SchedulerConfig
from .scheduler import Scheduler
scheduler_dict: dict[str, Scheduler] = {}
_schedule_class_dict: dict[str, list[Target]] = {}
async def init_scheduler():
for platform in platform_manager.values():
scheduler_class = platform.scheduler_class
platform_name = platform.platform_name
targets = await config.get_platform_target(platform_name)
if scheduler_class not in _schedule_class_dict:
_schedule_class_dict[scheduler_class] = targets
else:
_schedule_class_dict[scheduler_class].extend(targets)
for scheduler_class, target_list in _schedule_class_dict.items():
schedulable_args = []
for target in target_list:
schedulable_args.append((target.platform_name, T_Target(target.target)))
scheduler_dict[scheduler_class] = Scheduler(scheduler_class, schedulable_args)
async def handle_insert_new_target(platform_name: str, target: T_Target):
platform = platform_manager[platform_name]
scheduler_obj = scheduler_dict[platform.scheduler_class]
scheduler_obj.insert_new_schedulable(platform_name, target)
async def handle_delete_target(platform_name: str, target: T_Target):
platform = platform_manager[platform_name]
scheduler_obj = scheduler_dict[platform.scheduler_class]
scheduler_obj.delete_schedulable(platform_name, target)
config.register_add_target_hook(handle_delete_target)
config.register_delete_target_hook(handle_delete_target)

View File

@ -0,0 +1,85 @@
from dataclasses import dataclass
from typing import Optional
from nonebot.log import logger
from ..config import config
from ..platform.platform import Platform
from ..types import Target
from ..utils import SchedulerConfig
@dataclass
class Schedulable:
platform_name: str
target: Target
current_weight: int
class Scheduler:
schedulable_list: list[Schedulable]
def __init__(self, name: str, schedulables: list[tuple[str, Target]]):
conf = SchedulerConfig.registry.get(name)
if not conf:
logger.error(f"scheduler config [{name}] not found, exiting")
raise RuntimeError(f"{name} not found")
self.scheduler_config = conf
self.schedulable_list = []
platform_name_set = set()
for platform_name, target in schedulables:
self.schedulable_list.append(
Schedulable(
platform_name=platform_name, target=target, current_weight=0
)
)
platform_name_set.add(platform_name)
self.platform_name_list = list(platform_name_set)
self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值
async def schedule(self) -> Optional[Schedulable]:
if not self.schedulable_list:
return None
cur_weight = await config.get_current_weight_val(self.platform_name_list)
weight_sum = self.pre_weight_val
self.pre_weight_val = 0
cur_max_schedulable = None
for schedulable in self.schedulable_list:
schedulable.current_weight += cur_weight[
f"{schedulable.platform_name}-{schedulable.target}"
]
weight_sum += cur_weight[
f"{schedulable.platform_name}-{schedulable.target}"
]
if (
not cur_max_schedulable
or cur_max_schedulable.current_weight < schedulable.current_weight
):
cur_max_schedulable = schedulable
assert cur_max_schedulable
cur_max_schedulable.current_weight -= weight_sum
return cur_max_schedulable
def insert_new_schedulable(self, platform_name: str, target: Target):
self.pre_weight_val += 1000
self.schedulable_list.append(Schedulable(platform_name, target, 1000))
logger.info(
f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})"
)
def delete_schedulable(self, platform_name, target: Target):
if not self.schedulable_list:
return
to_find_idx = None
for idx, schedulable in enumerate(self.schedulable_list):
if (
schedulable.platform_name == platform_name
and schedulable.target == target
):
to_find_idx = idx
break
if to_find_idx is not None:
deleted_schdulable = self.schedulable_list.pop(to_find_idx)
self.pre_weight_val -= deleted_schdulable.current_weight
return

View File

@ -6,7 +6,12 @@ class SchedulerConfig:
schedule_type: Literal["date", "interval", "cron"]
schedule_setting: dict
registry: dict[str, Type["SchedulerConfig"]] = {}
name: str
def __init_subclass__(cls, *, name, **kwargs):
super().__init_subclass__(**kwargs)
cls.registry[name] = cls
cls.name = name
def __str__(self):
return f"[{self.name}]-{self.name}-{self.schedule_setting}"

View File

@ -117,3 +117,53 @@ async def test_get_current_weight(app: App, db_migration):
assert weight["weibo-weibo_id"] == 10
assert weight["weibo-weibo_id1"] == 10
assert weight["weibo2-weibo_id1"] == 10
async def test_get_platform_target(app: App, db_migration):
from nonebot_bison.config import db_config
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as T_Target
from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.expression import select
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
cats=[],
tags=[],
)
await config.add_subscribe(
user=245,
user_type="group",
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
cats=[],
tags=[],
)
res = await config.get_platform_target("weibo")
assert len(res) == 2
await config.del_subscribe(123, "group", T_Target("weibo_id1"), "weibo")
res = await config.get_platform_target("weibo")
assert len(res) == 2
await config.del_subscribe(123, "group", T_Target("weibo_id"), "weibo")
res = await config.get_platform_target("weibo")
assert len(res) == 1
async with AsyncSession(get_engine()) as sess:
res = await sess.scalars(select(Target).where(Target.platform_name == "weibo"))
assert len(res.all()) == 2