mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-08 04:43:00 +08:00
update
This commit is contained in:
parent
7b4c79acd3
commit
331d0f6101
@ -2,6 +2,7 @@ from nonebot.plugin import require
|
|||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
admin_page,
|
admin_page,
|
||||||
|
bootstrap,
|
||||||
config,
|
config,
|
||||||
config_manager,
|
config_manager,
|
||||||
platform,
|
platform,
|
||||||
|
15
src/plugins/nonebot_bison/bootstrap.py
Normal file
15
src/plugins/nonebot_bison/bootstrap.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
from .config.config_legacy import start_up as legacy_db_startup
|
||||||
|
from .config.db import upgrade_db
|
||||||
|
from .scheduler.manager import init_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
@get_driver().on_startup
|
||||||
|
async def bootstrap():
|
||||||
|
# legacy db
|
||||||
|
legacy_db_startup()
|
||||||
|
# new db
|
||||||
|
await upgrade_db()
|
||||||
|
# init scheduler
|
||||||
|
await init_scheduler()
|
@ -243,6 +243,4 @@ def start_up():
|
|||||||
config.update_send_cache()
|
config.update_send_cache()
|
||||||
|
|
||||||
|
|
||||||
nonebot.get_driver().on_startup(start_up)
|
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
|
@ -70,7 +70,6 @@ async def data_migrate():
|
|||||||
logger.info("migrate success")
|
logger.info("migrate success")
|
||||||
|
|
||||||
|
|
||||||
@nonebot.get_driver().on_startup
|
|
||||||
async def upgrade_db():
|
async def upgrade_db():
|
||||||
alembic_cfg = Config()
|
alembic_cfg = Config()
|
||||||
alembic_cfg.set_main_option(
|
alembic_cfg.set_main_option(
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, time
|
from datetime import datetime, time
|
||||||
from typing import Optional
|
from typing import Any, Awaitable, Callable, Optional
|
||||||
|
|
||||||
from nonebot_plugin_datastore.db import get_engine
|
from nonebot_plugin_datastore.db import get_engine
|
||||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||||
@ -34,6 +34,16 @@ class WeightConfig:
|
|||||||
|
|
||||||
|
|
||||||
class DBConfig:
|
class DBConfig:
|
||||||
|
def __init__(self):
|
||||||
|
self.add_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None
|
||||||
|
self.delete_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None
|
||||||
|
|
||||||
|
def register_add_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
|
||||||
|
self.add_target_hook = fun
|
||||||
|
|
||||||
|
def register_delete_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
|
||||||
|
self.delete_target_hook = fun
|
||||||
|
|
||||||
async def add_subscribe(
|
async def add_subscribe(
|
||||||
self,
|
self,
|
||||||
user: int,
|
user: int,
|
||||||
@ -62,6 +72,8 @@ class DBConfig:
|
|||||||
db_target = Target(
|
db_target = Target(
|
||||||
target=target, platform_name=platform_name, target_name=target_name
|
target=target, platform_name=platform_name, target_name=target_name
|
||||||
)
|
)
|
||||||
|
if self.add_target_hook:
|
||||||
|
await self.add_target_hook(platform_name, target)
|
||||||
else:
|
else:
|
||||||
db_target.target_name = target_name # type: ignore
|
db_target.target_name = target_name # type: ignore
|
||||||
subscribe = Subscribe(
|
subscribe = Subscribe(
|
||||||
@ -93,12 +105,12 @@ class DBConfig:
|
|||||||
)
|
)
|
||||||
target_obj = await session.scalar(
|
target_obj = await session.scalar(
|
||||||
select(Target).where(
|
select(Target).where(
|
||||||
Target.platform_name == platform_name, MTarget.target == target
|
Target.platform_name == platform_name, Target.target == target
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await session.execute(
|
await session.execute(
|
||||||
delete(Subscribe).where(
|
delete(Subscribe).where(
|
||||||
Subscribe.user == user_obj, MSubscribe.target == target_obj
|
Subscribe.user == user_obj, Subscribe.target == target_obj
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
target_count = await session.scalar(
|
target_count = await session.scalar(
|
||||||
@ -108,7 +120,9 @@ class DBConfig:
|
|||||||
)
|
)
|
||||||
if target_count == 0:
|
if target_count == 0:
|
||||||
# delete empty target
|
# delete empty target
|
||||||
await session.delete(target_obj)
|
# await session.delete(target_obj)
|
||||||
|
if self.delete_target_hook:
|
||||||
|
await self.delete_target_hook(platform_name, T_Target(target))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def update_subscribe(
|
async def update_subscribe(
|
||||||
@ -139,15 +153,27 @@ class DBConfig:
|
|||||||
subscribe_obj.target.target_name = target_name
|
subscribe_obj.target.target_name = target_name
|
||||||
await sess.commit()
|
await sess.commit()
|
||||||
|
|
||||||
|
async def get_platform_target(self, platform_name: str) -> list[Target]:
|
||||||
|
async with AsyncSession(get_engine()) as sess:
|
||||||
|
subq = select(Subscribe.target_id).distinct().subquery()
|
||||||
|
query = (
|
||||||
|
select(Target).join(subq).where(Target.platform_name == platform_name)
|
||||||
|
)
|
||||||
|
return (await sess.scalars(query)).all()
|
||||||
|
|
||||||
async def get_time_weight_config(
|
async def get_time_weight_config(
|
||||||
self, target: T_Target, platform_name: str
|
self, target: T_Target, platform_name: str
|
||||||
) -> WeightConfig:
|
) -> WeightConfig:
|
||||||
async with AsyncSession(get_engine()) as sess:
|
async with AsyncSession(get_engine()) as sess:
|
||||||
time_weight_conf: list[ScheduleTimeWeight] = await sess.scalars(
|
time_weight_conf: list[ScheduleTimeWeight] = (
|
||||||
|
await sess.scalars(
|
||||||
select(ScheduleTimeWeight)
|
select(ScheduleTimeWeight)
|
||||||
.where(Target.platform_name == platform_name, Target.target == target)
|
.where(
|
||||||
|
Target.platform_name == platform_name, Target.target == target
|
||||||
|
)
|
||||||
.join(Target)
|
.join(Target)
|
||||||
)
|
)
|
||||||
|
).all()
|
||||||
targetObj: Target = await sess.scalar(
|
targetObj: Target = await sess.scalar(
|
||||||
select(Target).where(
|
select(Target).where(
|
||||||
Target.platform_name == platform_name, Target.target == target
|
Target.platform_name == platform_name, Target.target == target
|
||||||
@ -192,11 +218,13 @@ class DBConfig:
|
|||||||
res = {}
|
res = {}
|
||||||
cur_time = _get_time()
|
cur_time = _get_time()
|
||||||
async with AsyncSession(get_engine()) as sess:
|
async with AsyncSession(get_engine()) as sess:
|
||||||
targets: list[Target] = await sess.scalars(
|
targets: list[Target] = (
|
||||||
|
await sess.scalars(
|
||||||
select(Target)
|
select(Target)
|
||||||
.where(Target.platform_name.in_(platform_list))
|
.where(Target.platform_name.in_(platform_list))
|
||||||
.options(selectinload(Target.time_weight))
|
.options(selectinload(Target.time_weight))
|
||||||
)
|
)
|
||||||
|
).all()
|
||||||
for target in targets:
|
for target in targets:
|
||||||
key = f"{target.platform_name}-{target.target}"
|
key = f"{target.platform_name}-{target.target}"
|
||||||
weight = target.default_schedule_weight
|
weight = target.default_schedule_weight
|
||||||
|
@ -7,9 +7,16 @@ from nonebot.plugin import require
|
|||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import Category, RawPost, Target
|
from ..types import Category, RawPost, Target
|
||||||
from ..utils import http_client
|
from ..utils import http_client
|
||||||
|
from ..utils.scheduler_config import SchedulerConfig
|
||||||
from .platform import CategoryNotSupport, NewMessage, StatusChange
|
from .platform import CategoryNotSupport, NewMessage, StatusChange
|
||||||
|
|
||||||
|
|
||||||
|
class ArknightsSchedConf(SchedulerConfig, name="arknights"):
|
||||||
|
|
||||||
|
schedule_type = "interval"
|
||||||
|
schedule_setting = {"seconds": 30}
|
||||||
|
|
||||||
|
|
||||||
class Arknights(NewMessage):
|
class Arknights(NewMessage):
|
||||||
|
|
||||||
categories = {1: "游戏公告"}
|
categories = {1: "游戏公告"}
|
||||||
@ -18,8 +25,7 @@ class Arknights(NewMessage):
|
|||||||
enable_tag = False
|
enable_tag = False
|
||||||
enabled = True
|
enabled = True
|
||||||
is_common = False
|
is_common = False
|
||||||
schedule_type = "interval"
|
scheduler_class = "arknights"
|
||||||
schedule_kw = {"seconds": 30}
|
|
||||||
has_target = False
|
has_target = False
|
||||||
|
|
||||||
async def get_target_name(self, _: Target) -> str:
|
async def get_target_name(self, _: Target) -> str:
|
||||||
@ -91,8 +97,7 @@ class AkVersion(StatusChange):
|
|||||||
enable_tag = False
|
enable_tag = False
|
||||||
enabled = True
|
enabled = True
|
||||||
is_common = False
|
is_common = False
|
||||||
schedule_type = "interval"
|
scheduler_class = "arknights"
|
||||||
schedule_kw = {"seconds": 30}
|
|
||||||
has_target = False
|
has_target = False
|
||||||
|
|
||||||
async def get_target_name(self, _: Target) -> str:
|
async def get_target_name(self, _: Target) -> str:
|
||||||
@ -147,8 +152,7 @@ class MonsterSiren(NewMessage):
|
|||||||
enable_tag = False
|
enable_tag = False
|
||||||
enabled = True
|
enabled = True
|
||||||
is_common = False
|
is_common = False
|
||||||
schedule_type = "interval"
|
scheduler_class = "arknights"
|
||||||
schedule_kw = {"seconds": 30}
|
|
||||||
has_target = False
|
has_target = False
|
||||||
|
|
||||||
async def get_target_name(self, _: Target) -> str:
|
async def get_target_name(self, _: Target) -> str:
|
||||||
@ -199,8 +203,7 @@ class TerraHistoricusComic(NewMessage):
|
|||||||
enable_tag = False
|
enable_tag = False
|
||||||
enabled = True
|
enabled = True
|
||||||
is_common = False
|
is_common = False
|
||||||
schedule_type = "interval"
|
scheduler_class = "arknights"
|
||||||
schedule_kw = {"seconds": 30}
|
|
||||||
has_target = False
|
has_target = False
|
||||||
|
|
||||||
async def get_target_name(self, _: Target) -> str:
|
async def get_target_name(self, _: Target) -> str:
|
||||||
|
@ -4,10 +4,16 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import Category, RawPost, Tag, Target
|
from ..types import Category, RawPost, Tag, Target
|
||||||
from ..utils import http_client
|
from ..utils import SchedulerConfig, http_client
|
||||||
from .platform import CategoryNotSupport, NewMessage, StatusChange
|
from .platform import CategoryNotSupport, NewMessage, StatusChange
|
||||||
|
|
||||||
|
|
||||||
|
class BilibiliSchedConf(SchedulerConfig, name="bilibili.com"):
|
||||||
|
|
||||||
|
schedule_type = "interval"
|
||||||
|
schedule_setting = {"seconds": 10}
|
||||||
|
|
||||||
|
|
||||||
class Bilibili(NewMessage):
|
class Bilibili(NewMessage):
|
||||||
|
|
||||||
categories = {
|
categories = {
|
||||||
@ -22,8 +28,7 @@ class Bilibili(NewMessage):
|
|||||||
enable_tag = True
|
enable_tag = True
|
||||||
enabled = True
|
enabled = True
|
||||||
is_common = True
|
is_common = True
|
||||||
schedule_type = "interval"
|
scheduler_class = "bilibili.com"
|
||||||
schedule_kw = {"seconds": 10}
|
|
||||||
name = "B站"
|
name = "B站"
|
||||||
has_target = True
|
has_target = True
|
||||||
parse_target_promot = "请输入用户主页的链接"
|
parse_target_promot = "请输入用户主页的链接"
|
||||||
@ -167,8 +172,7 @@ class Bilibililive(StatusChange):
|
|||||||
enable_tag = True
|
enable_tag = True
|
||||||
enabled = True
|
enabled = True
|
||||||
is_common = True
|
is_common = True
|
||||||
schedule_type = "interval"
|
scheduler_class = "bilibili.com"
|
||||||
schedule_kw = {"seconds": 10}
|
|
||||||
name = "Bilibili直播"
|
name = "Bilibili直播"
|
||||||
has_target = True
|
has_target = True
|
||||||
|
|
||||||
|
@ -2,10 +2,16 @@ from typing import Any
|
|||||||
|
|
||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import RawPost, Target
|
from ..types import RawPost, Target
|
||||||
from ..utils import http_client
|
from ..utils import SchedulerConfig, http_client
|
||||||
from .platform import NewMessage
|
from .platform import NewMessage
|
||||||
|
|
||||||
|
|
||||||
|
class FF14SchedConf(SchedulerConfig, name="ff14"):
|
||||||
|
|
||||||
|
schedule_type = "interval"
|
||||||
|
schedule_setting = {"seconds": 60}
|
||||||
|
|
||||||
|
|
||||||
class FF14(NewMessage):
|
class FF14(NewMessage):
|
||||||
|
|
||||||
categories = {}
|
categories = {}
|
||||||
@ -14,8 +20,7 @@ class FF14(NewMessage):
|
|||||||
enable_tag = False
|
enable_tag = False
|
||||||
enabled = True
|
enabled = True
|
||||||
is_common = False
|
is_common = False
|
||||||
schedule_type = "interval"
|
scheduler_class = "ff14"
|
||||||
schedule_kw = {"seconds": 60}
|
|
||||||
has_target = False
|
has_target = False
|
||||||
|
|
||||||
async def get_target_name(self, _: Target) -> str:
|
async def get_target_name(self, _: Target) -> str:
|
||||||
|
@ -3,10 +3,16 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import RawPost, Target
|
from ..types import RawPost, Target
|
||||||
from ..utils import http_client
|
from ..utils import SchedulerConfig, http_client
|
||||||
from .platform import NewMessage
|
from .platform import NewMessage
|
||||||
|
|
||||||
|
|
||||||
|
class NcmSchedConf(SchedulerConfig, name="music.163.com"):
|
||||||
|
|
||||||
|
schedule_type = "interval"
|
||||||
|
schedule_setting = {"minutes": 1}
|
||||||
|
|
||||||
|
|
||||||
class NcmArtist(NewMessage):
|
class NcmArtist(NewMessage):
|
||||||
|
|
||||||
categories = {}
|
categories = {}
|
||||||
@ -14,8 +20,7 @@ class NcmArtist(NewMessage):
|
|||||||
enable_tag = False
|
enable_tag = False
|
||||||
enabled = True
|
enabled = True
|
||||||
is_common = True
|
is_common = True
|
||||||
schedule_type = "interval"
|
scheduler_class = "music.163.com"
|
||||||
schedule_kw = {"minutes": 1}
|
|
||||||
name = "网易云-歌手"
|
name = "网易云-歌手"
|
||||||
has_target = True
|
has_target = True
|
||||||
parse_target_promot = "请输入歌手主页(包含数字ID)的链接"
|
parse_target_promot = "请输入歌手主页(包含数字ID)的链接"
|
||||||
|
@ -14,8 +14,7 @@ class NcmRadio(NewMessage):
|
|||||||
enable_tag = False
|
enable_tag = False
|
||||||
enabled = True
|
enabled = True
|
||||||
is_common = False
|
is_common = False
|
||||||
schedule_type = "interval"
|
scheduler_class = "music.163.com"
|
||||||
schedule_kw = {"minutes": 10}
|
|
||||||
name = "网易云-电台"
|
name = "网易云-电台"
|
||||||
has_target = True
|
has_target = True
|
||||||
parse_target_promot = "请输入主播电台主页(包含数字ID)的链接"
|
parse_target_promot = "请输入主播电台主页(包含数字ID)的链接"
|
||||||
|
@ -39,8 +39,7 @@ class RegistryABCMeta(RegistryMeta, ABC):
|
|||||||
|
|
||||||
class Platform(metaclass=RegistryABCMeta, base=True):
|
class Platform(metaclass=RegistryABCMeta, base=True):
|
||||||
|
|
||||||
schedule_type: Literal["date", "interval", "cron"]
|
scheduler_class: str
|
||||||
schedule_kw: dict
|
|
||||||
is_common: bool
|
is_common: bool
|
||||||
enabled: bool
|
enabled: bool
|
||||||
name: str
|
name: str
|
||||||
@ -332,11 +331,11 @@ class NoTargetGroup(Platform, abstract=True):
|
|||||||
|
|
||||||
def __init__(self, platform_list: list[Platform]):
|
def __init__(self, platform_list: list[Platform]):
|
||||||
self.platform_list = platform_list
|
self.platform_list = platform_list
|
||||||
|
self.platform_name = platform_list[0].platform_name
|
||||||
name = self.DUMMY_STR
|
name = self.DUMMY_STR
|
||||||
self.categories = {}
|
self.categories = {}
|
||||||
categories_keys = set()
|
categories_keys = set()
|
||||||
self.schedule_type = platform_list[0].schedule_type
|
self.scheduler_class = platform_list[0].scheduler_class
|
||||||
self.schedule_kw = platform_list[0].schedule_kw
|
|
||||||
for platform in platform_list:
|
for platform in platform_list:
|
||||||
if platform.has_target:
|
if platform.has_target:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -355,10 +354,7 @@ class NoTargetGroup(Platform, abstract=True):
|
|||||||
)
|
)
|
||||||
categories_keys |= platform_category_key_set
|
categories_keys |= platform_category_key_set
|
||||||
self.categories.update(platform.categories)
|
self.categories.update(platform.categories)
|
||||||
if (
|
if platform.scheduler_class != self.scheduler_class:
|
||||||
platform.schedule_kw != self.schedule_kw
|
|
||||||
or platform.schedule_type != self.schedule_type
|
|
||||||
):
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Platform scheduler for {} not fit".format(self.platform_name)
|
"Platform scheduler for {} not fit".format(self.platform_name)
|
||||||
)
|
)
|
||||||
|
@ -6,10 +6,16 @@ from bs4 import BeautifulSoup as bs
|
|||||||
|
|
||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import RawPost, Target
|
from ..types import RawPost, Target
|
||||||
from ..utils import http_client
|
from ..utils import SchedulerConfig, http_client
|
||||||
from .platform import NewMessage
|
from .platform import NewMessage
|
||||||
|
|
||||||
|
|
||||||
|
class RssSchedConf(SchedulerConfig, name="rss"):
|
||||||
|
|
||||||
|
schedule_type = "interval"
|
||||||
|
schedule_setting = {"seconds": 30}
|
||||||
|
|
||||||
|
|
||||||
class Rss(NewMessage):
|
class Rss(NewMessage):
|
||||||
|
|
||||||
categories = {}
|
categories = {}
|
||||||
@ -18,8 +24,7 @@ class Rss(NewMessage):
|
|||||||
name = "Rss"
|
name = "Rss"
|
||||||
enabled = True
|
enabled = True
|
||||||
is_common = True
|
is_common = True
|
||||||
schedule_type = "interval"
|
scheduler_class = "rss"
|
||||||
schedule_kw = {"seconds": 30}
|
|
||||||
has_target = True
|
has_target = True
|
||||||
|
|
||||||
async def get_target_name(self, target: Target) -> Optional[str]:
|
async def get_target_name(self, target: Target) -> Optional[str]:
|
||||||
|
@ -8,10 +8,15 @@ from nonebot.log import logger
|
|||||||
|
|
||||||
from ..post import Post
|
from ..post import Post
|
||||||
from ..types import *
|
from ..types import *
|
||||||
from ..utils import http_client
|
from ..utils import SchedulerConfig, http_client
|
||||||
from .platform import NewMessage
|
from .platform import NewMessage
|
||||||
|
|
||||||
|
|
||||||
|
class WeiboSchedConf(SchedulerConfig, name="weibo.com"):
|
||||||
|
schedule_type = "interval"
|
||||||
|
schedule_setting = {"seconds": 3}
|
||||||
|
|
||||||
|
|
||||||
class Weibo(NewMessage):
|
class Weibo(NewMessage):
|
||||||
|
|
||||||
categories = {
|
categories = {
|
||||||
@ -25,8 +30,7 @@ class Weibo(NewMessage):
|
|||||||
name = "新浪微博"
|
name = "新浪微博"
|
||||||
enabled = True
|
enabled = True
|
||||||
is_common = True
|
is_common = True
|
||||||
schedule_type = "interval"
|
scheduler_class = "weibo.com"
|
||||||
schedule_kw = {"seconds": 3}
|
|
||||||
has_target = True
|
has_target = True
|
||||||
parse_target_promot = "请输入用户主页(包含数字UID)的链接"
|
parse_target_promot = "请输入用户主页(包含数字UID)的链接"
|
||||||
|
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
from .manager import *
|
43
src/plugins/nonebot_bison/scheduler/manager.py
Normal file
43
src/plugins/nonebot_bison/scheduler/manager.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from nonebot.log import logger
|
||||||
|
|
||||||
|
from ..config import config
|
||||||
|
from ..config.db_model import Target
|
||||||
|
from ..platform import platform_manager
|
||||||
|
from ..types import Target as T_Target
|
||||||
|
from ..utils import SchedulerConfig
|
||||||
|
from .scheduler import Scheduler
|
||||||
|
|
||||||
|
scheduler_dict: dict[str, Scheduler] = {}
|
||||||
|
_schedule_class_dict: dict[str, list[Target]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
async def init_scheduler():
|
||||||
|
for platform in platform_manager.values():
|
||||||
|
scheduler_class = platform.scheduler_class
|
||||||
|
platform_name = platform.platform_name
|
||||||
|
targets = await config.get_platform_target(platform_name)
|
||||||
|
if scheduler_class not in _schedule_class_dict:
|
||||||
|
_schedule_class_dict[scheduler_class] = targets
|
||||||
|
else:
|
||||||
|
_schedule_class_dict[scheduler_class].extend(targets)
|
||||||
|
for scheduler_class, target_list in _schedule_class_dict.items():
|
||||||
|
schedulable_args = []
|
||||||
|
for target in target_list:
|
||||||
|
schedulable_args.append((target.platform_name, T_Target(target.target)))
|
||||||
|
scheduler_dict[scheduler_class] = Scheduler(scheduler_class, schedulable_args)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_insert_new_target(platform_name: str, target: T_Target):
|
||||||
|
platform = platform_manager[platform_name]
|
||||||
|
scheduler_obj = scheduler_dict[platform.scheduler_class]
|
||||||
|
scheduler_obj.insert_new_schedulable(platform_name, target)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_delete_target(platform_name: str, target: T_Target):
|
||||||
|
platform = platform_manager[platform_name]
|
||||||
|
scheduler_obj = scheduler_dict[platform.scheduler_class]
|
||||||
|
scheduler_obj.delete_schedulable(platform_name, target)
|
||||||
|
|
||||||
|
|
||||||
|
config.register_add_target_hook(handle_delete_target)
|
||||||
|
config.register_delete_target_hook(handle_delete_target)
|
85
src/plugins/nonebot_bison/scheduler/scheduler.py
Normal file
85
src/plugins/nonebot_bison/scheduler/scheduler.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from nonebot.log import logger
|
||||||
|
|
||||||
|
from ..config import config
|
||||||
|
from ..platform.platform import Platform
|
||||||
|
from ..types import Target
|
||||||
|
from ..utils import SchedulerConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Schedulable:
|
||||||
|
platform_name: str
|
||||||
|
target: Target
|
||||||
|
current_weight: int
|
||||||
|
|
||||||
|
|
||||||
|
class Scheduler:
|
||||||
|
|
||||||
|
schedulable_list: list[Schedulable]
|
||||||
|
|
||||||
|
def __init__(self, name: str, schedulables: list[tuple[str, Target]]):
|
||||||
|
conf = SchedulerConfig.registry.get(name)
|
||||||
|
if not conf:
|
||||||
|
logger.error(f"scheduler config [{name}] not found, exiting")
|
||||||
|
raise RuntimeError(f"{name} not found")
|
||||||
|
self.scheduler_config = conf
|
||||||
|
self.schedulable_list = []
|
||||||
|
platform_name_set = set()
|
||||||
|
for platform_name, target in schedulables:
|
||||||
|
self.schedulable_list.append(
|
||||||
|
Schedulable(
|
||||||
|
platform_name=platform_name, target=target, current_weight=0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
platform_name_set.add(platform_name)
|
||||||
|
self.platform_name_list = list(platform_name_set)
|
||||||
|
self.pre_weight_val = 0 # 轮调度中“本轮”增加权重和的初值
|
||||||
|
|
||||||
|
async def schedule(self) -> Optional[Schedulable]:
|
||||||
|
if not self.schedulable_list:
|
||||||
|
return None
|
||||||
|
cur_weight = await config.get_current_weight_val(self.platform_name_list)
|
||||||
|
weight_sum = self.pre_weight_val
|
||||||
|
self.pre_weight_val = 0
|
||||||
|
cur_max_schedulable = None
|
||||||
|
for schedulable in self.schedulable_list:
|
||||||
|
schedulable.current_weight += cur_weight[
|
||||||
|
f"{schedulable.platform_name}-{schedulable.target}"
|
||||||
|
]
|
||||||
|
weight_sum += cur_weight[
|
||||||
|
f"{schedulable.platform_name}-{schedulable.target}"
|
||||||
|
]
|
||||||
|
if (
|
||||||
|
not cur_max_schedulable
|
||||||
|
or cur_max_schedulable.current_weight < schedulable.current_weight
|
||||||
|
):
|
||||||
|
cur_max_schedulable = schedulable
|
||||||
|
assert cur_max_schedulable
|
||||||
|
cur_max_schedulable.current_weight -= weight_sum
|
||||||
|
return cur_max_schedulable
|
||||||
|
|
||||||
|
def insert_new_schedulable(self, platform_name: str, target: Target):
|
||||||
|
self.pre_weight_val += 1000
|
||||||
|
self.schedulable_list.append(Schedulable(platform_name, target, 1000))
|
||||||
|
logger.info(
|
||||||
|
f"insert [{platform_name}]{target} to Schduler({self.scheduler_config.name})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_schedulable(self, platform_name, target: Target):
|
||||||
|
if not self.schedulable_list:
|
||||||
|
return
|
||||||
|
to_find_idx = None
|
||||||
|
for idx, schedulable in enumerate(self.schedulable_list):
|
||||||
|
if (
|
||||||
|
schedulable.platform_name == platform_name
|
||||||
|
and schedulable.target == target
|
||||||
|
):
|
||||||
|
to_find_idx = idx
|
||||||
|
break
|
||||||
|
if to_find_idx is not None:
|
||||||
|
deleted_schdulable = self.schedulable_list.pop(to_find_idx)
|
||||||
|
self.pre_weight_val -= deleted_schdulable.current_weight
|
||||||
|
return
|
@ -6,7 +6,12 @@ class SchedulerConfig:
|
|||||||
schedule_type: Literal["date", "interval", "cron"]
|
schedule_type: Literal["date", "interval", "cron"]
|
||||||
schedule_setting: dict
|
schedule_setting: dict
|
||||||
registry: dict[str, Type["SchedulerConfig"]] = {}
|
registry: dict[str, Type["SchedulerConfig"]] = {}
|
||||||
|
name: str
|
||||||
|
|
||||||
def __init_subclass__(cls, *, name, **kwargs):
|
def __init_subclass__(cls, *, name, **kwargs):
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
cls.registry[name] = cls
|
cls.registry[name] = cls
|
||||||
|
cls.name = name
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"[{self.name}]-{self.name}-{self.schedule_setting}"
|
||||||
|
@ -117,3 +117,53 @@ async def test_get_current_weight(app: App, db_migration):
|
|||||||
assert weight["weibo-weibo_id"] == 10
|
assert weight["weibo-weibo_id"] == 10
|
||||||
assert weight["weibo-weibo_id1"] == 10
|
assert weight["weibo-weibo_id1"] == 10
|
||||||
assert weight["weibo2-weibo_id1"] == 10
|
assert weight["weibo2-weibo_id1"] == 10
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_platform_target(app: App, db_migration):
|
||||||
|
from nonebot_bison.config import db_config
|
||||||
|
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
|
||||||
|
from nonebot_bison.config.db_model import Subscribe, Target, User
|
||||||
|
from nonebot_bison.types import Target as T_Target
|
||||||
|
from nonebot_plugin_datastore.db import get_engine
|
||||||
|
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||||
|
from sqlalchemy.sql.expression import select
|
||||||
|
|
||||||
|
await config.add_subscribe(
|
||||||
|
user=123,
|
||||||
|
user_type="group",
|
||||||
|
target=T_Target("weibo_id"),
|
||||||
|
target_name="weibo_name",
|
||||||
|
platform_name="weibo",
|
||||||
|
cats=[],
|
||||||
|
tags=[],
|
||||||
|
)
|
||||||
|
await config.add_subscribe(
|
||||||
|
user=123,
|
||||||
|
user_type="group",
|
||||||
|
target=T_Target("weibo_id1"),
|
||||||
|
target_name="weibo_name1",
|
||||||
|
platform_name="weibo",
|
||||||
|
cats=[],
|
||||||
|
tags=[],
|
||||||
|
)
|
||||||
|
await config.add_subscribe(
|
||||||
|
user=245,
|
||||||
|
user_type="group",
|
||||||
|
target=T_Target("weibo_id1"),
|
||||||
|
target_name="weibo_name1",
|
||||||
|
platform_name="weibo",
|
||||||
|
cats=[],
|
||||||
|
tags=[],
|
||||||
|
)
|
||||||
|
res = await config.get_platform_target("weibo")
|
||||||
|
assert len(res) == 2
|
||||||
|
await config.del_subscribe(123, "group", T_Target("weibo_id1"), "weibo")
|
||||||
|
res = await config.get_platform_target("weibo")
|
||||||
|
assert len(res) == 2
|
||||||
|
await config.del_subscribe(123, "group", T_Target("weibo_id"), "weibo")
|
||||||
|
res = await config.get_platform_target("weibo")
|
||||||
|
assert len(res) == 1
|
||||||
|
|
||||||
|
async with AsyncSession(get_engine()) as sess:
|
||||||
|
res = await sess.scalars(select(Target).where(Target.platform_name == "weibo"))
|
||||||
|
assert len(res.all()) == 2
|
||||||
|
Loading…
x
Reference in New Issue
Block a user