From 5010ca1ac5230667492d5c82b83021c0af2c76e2 Mon Sep 17 00:00:00 2001 From: felinae98 <731499577@qq.com> Date: Mon, 8 May 2023 21:09:28 +0800 Subject: [PATCH] :construction: sunset get_bot --- nonebot_bison/__init__.py | 5 ++ nonebot_bison/admin_page/api.py | 6 +-- nonebot_bison/scheduler/scheduler.py | 11 ++-- nonebot_bison/send.py | 31 ++++++----- nonebot_bison/utils/get_bot.py | 70 +------------------------ tests/conftest.py | 1 + tests/test_get_bot.py | 78 ---------------------------- 7 files changed, 30 insertions(+), 172 deletions(-) diff --git a/nonebot_bison/__init__.py b/nonebot_bison/__init__.py index 69b58d3..a339e5f 100644 --- a/nonebot_bison/__init__.py +++ b/nonebot_bison/__init__.py @@ -2,6 +2,9 @@ from nonebot.plugin import PluginMetadata, require require("nonebot_plugin_apscheduler") require("nonebot_plugin_datastore") +require("nonebot_plugin_saa") + +import nonebot_plugin_saa from . import ( admin_page, @@ -18,6 +21,8 @@ from . import ( from .plugin_config import plugin_config __help__version__ = "0.7.3" +nonebot_plugin_saa.enable_auto_select_bot() + __help__plugin__name__ = "nonebot_bison" __usage__ = f"本bot可以提供b站、微博等社交媒体的消息订阅,详情请查看本bot文档,或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,发送“查询订阅”或“删除订阅”管理订阅" diff --git a/nonebot_bison/admin_page/api.py b/nonebot_bison/admin_page/api.py index 9f2b00d..5a4fb81 100644 --- a/nonebot_bison/admin_page/api.py +++ b/nonebot_bison/admin_page/api.py @@ -5,6 +5,7 @@ from fastapi.param_functions import Depends from fastapi.routing import APIRouter from fastapi.security.oauth2 import OAuth2PasswordBearer from nonebot_plugin_saa import TargetQQGroup +from nonebot_plugin_saa.utils.auto_select_bot import get_bot from ..apis import check_sub_target from ..config import ( @@ -17,7 +18,7 @@ from ..config.db_config import SubscribeDupException from ..platform import platform_manager from ..types import Target as T_Target from ..types import WeightConfig -from ..utils.get_bot import get_bot, get_groups +from ..utils.get_bot import get_groups from .jwt import load_jwt, pack_jwt from .token_manager import token_manager from .types import ( @@ -185,8 +186,7 @@ async def del_group_sub(groupNumber: int, platformName: str, target: str): async def update_group_sub(groupNumber: int, req: AddSubscribeReq): try: await config.update_subscribe( - int(groupNumber), - "group", + TargetQQGroup(group_id=groupNumber), req.target, req.targetName, req.platformName, diff --git a/nonebot_bison/scheduler/scheduler.py b/nonebot_bison/scheduler/scheduler.py index 1a7bae4..4f26b69 100644 --- a/nonebot_bison/scheduler/scheduler.py +++ b/nonebot_bison/scheduler/scheduler.py @@ -1,16 +1,15 @@ from dataclasses import dataclass from typing import Optional, Type -from nonebot.adapters.onebot.v11.bot import Bot from nonebot.log import logger from nonebot_plugin_apscheduler import scheduler +from nonebot_plugin_saa.utils.exceptions import NoBotFound from ..config import config from ..platform import platform_manager from ..send import send_msgs from ..types import Target from ..utils import ProcessContext, SchedulerConfig -from ..utils.get_bot import get_bot @dataclass @@ -107,17 +106,15 @@ class Scheduler: return for user, send_list in to_send: - bot = get_bot(user) for send_post in send_list: logger.info("send to {}: {}".format(user, send_post)) - if not bot: - logger.warning("no bot connected") - else: + try: await send_msgs( - bot, user, await send_post.generate_messages(), ) + except NoBotFound: + logger.warning("no bot connected") def insert_new_schedulable(self, platform_name: str, target: Target): self.pre_weight_val += 1000 diff --git a/nonebot_bison/send.py b/nonebot_bison/send.py index 14ba366..540eba5 100644 --- a/nonebot_bison/send.py +++ b/nonebot_bison/send.py @@ -2,24 +2,23 @@ import asyncio from collections import deque from typing import Deque -from nonebot.adapters import Bot from nonebot.adapters.onebot.v11.exception import ActionFailed from nonebot.log import logger from nonebot_plugin_saa import AggregatedMessageFactory, MessageFactory, PlatformTarget +from nonebot_plugin_saa.utils.auto_select_bot import refresh_bots from .plugin_config import plugin_config -from .utils.get_bot import refresh_bots Sendable = MessageFactory | AggregatedMessageFactory -QUEUE: Deque[tuple[Bot, PlatformTarget, Sendable, int]] = deque() +QUEUE: Deque[tuple[PlatformTarget, Sendable, int]] = deque() MESSGE_SEND_INTERVAL = 1.5 -async def _do_send(bot: "Bot", send_target: PlatformTarget, msg: Sendable): +async def _do_send(send_target: PlatformTarget, msg: Sendable): try: - await msg.send_to(send_target, bot) + await msg.send_to(send_target) except ActionFailed: # TODO: catch exception of other adapters await refresh_bots() logger.warning("send msg failed, refresh bots") @@ -34,14 +33,14 @@ async def do_send_msgs(): # the length of queue will be 0. # At that time, adding items to queue will trigger a new execution of this func, which is not expected. # So, read from queue first then pop from it - bot, send_target, msg_factory, retry_time = QUEUE[0] + send_target, msg_factory, retry_time = QUEUE[0] try: - await _do_send(bot, send_target, msg_factory) + await _do_send(send_target, msg_factory) except Exception as e: await asyncio.sleep(MESSGE_SEND_INTERVAL) QUEUE.popleft() if retry_time > 0: - QUEUE.appendleft((bot, send_target, msg_factory, retry_time - 1)) + QUEUE.appendleft((send_target, msg_factory, retry_time - 1)) else: msg_str = str(msg_factory) if len(msg_str) > 50: @@ -56,27 +55,27 @@ async def do_send_msgs(): return -async def _send_msgs_dispatch(bot: Bot, send_target: PlatformTarget, msg: Sendable): +async def _send_msgs_dispatch(send_target: PlatformTarget, msg: Sendable): if plugin_config.bison_use_queue: - QUEUE.append((bot, send_target, msg, plugin_config.bison_resend_times)) + QUEUE.append((send_target, msg, plugin_config.bison_resend_times)) # len(QUEUE) before append was 0 if len(QUEUE) == 1: asyncio.create_task(do_send_msgs()) else: - await _do_send(bot, send_target, msg) + await _do_send(send_target, msg) -async def send_msgs(bot: Bot, send_target: PlatformTarget, msgs: list[MessageFactory]): +async def send_msgs(send_target: PlatformTarget, msgs: list[MessageFactory]): if not plugin_config.bison_use_pic_merge: for msg in msgs: - await _send_msgs_dispatch(bot, send_target, msg) + await _send_msgs_dispatch(send_target, msg) return msgs = msgs.copy() if plugin_config.bison_use_pic_merge == 1: - await _send_msgs_dispatch(bot, send_target, msgs.pop(0)) + await _send_msgs_dispatch(send_target, msgs.pop(0)) if msgs: if len(msgs) == 1: # 只有一条消息序列就不合并转发 - await _send_msgs_dispatch(bot, send_target, msgs.pop(0)) + await _send_msgs_dispatch(send_target, msgs.pop(0)) else: forward_message = AggregatedMessageFactory(list(msgs)) - await _send_msgs_dispatch(bot, send_target, forward_message) + await _send_msgs_dispatch(send_target, forward_message) diff --git a/nonebot_bison/utils/get_bot.py b/nonebot_bison/utils/get_bot.py index 926dcd3..9065e0e 100644 --- a/nonebot_bison/utils/get_bot.py +++ b/nonebot_bison/utils/get_bot.py @@ -1,18 +1,11 @@ """ 提供获取 Bot 的方法 """ -import random from collections import defaultdict -from typing import Any, Optional +from typing import Any import nonebot -from nonebot import get_driver, on_notice from nonebot.adapters import Bot from nonebot.adapters.onebot.v11 import Bot as Ob11Bot -from nonebot.adapters.onebot.v11 import ( - FriendAddNoticeEvent, - GroupDecreaseNoticeEvent, - GroupIncreaseNoticeEvent, -) -from nonebot_plugin_saa import PlatformTarget, TargetQQGroup, TargetQQPrivate +from nonebot_plugin_saa import PlatformTarget GROUP: dict[int, list[Bot]] = {} USER: dict[int, list[Bot]] = {} @@ -29,65 +22,6 @@ def get_bots() -> list[Bot]: return bots -async def _refresh_ob11(bot: Ob11Bot): - # 获取群列表 - groups = await bot.get_group_list() - for group in groups: - group_id = group["group_id"] - target = TargetQQGroup(group_id=group_id) - BOT_CACHE[target].append(bot) - - # 获取好友列表 - users = await bot.get_friend_list() - for user in users: - user_id = user["user_id"] - target = TargetQQPrivate(user_id=user_id) - BOT_CACHE[target].append(bot) - - -async def refresh_bots(): - """刷新缓存的 Bot 数据""" - BOT_CACHE.clear() - for bot in get_bots(): - match bot: - case Ob11Bot(): - await _refresh_ob11(bot) - - -driver = get_driver() - - -@driver.on_bot_connect -@driver.on_bot_disconnect -async def _(bot: Bot): - await refresh_bots() - - -change_notice = on_notice(priority=1) - - -@change_notice.handle() -async def _(bot: Bot, event: FriendAddNoticeEvent): - await refresh_bots() - - -# 01-06 16:56:51 [SUCCESS] nonebot | OneBot V11 **** | [notice.group_increase.approve]: {'time': 1672995411, 'self_id': ****, 'post_type': 'notice', 'notice_type': 'group_increase', 'sub_type': 'approve', 'user_id': ****, 'group_id': ****, 'operator_id': 0} -# 01-06 16:58:09 [SUCCESS] nonebot | OneBot V11 **** | [notice.group_decrease.kick_me]: {'time': 1672995489, 'self_id': ****, 'post_type': 'notice', 'notice_type': 'group_decrease', 'sub_type': 'kick_me', 'user_id': ****, 'group_id': ****, 'operator_id': ****} -@change_notice.handle() -async def _(bot: Bot, event: GroupDecreaseNoticeEvent | GroupIncreaseNoticeEvent): - if bot.self_id == event.user_id: - await refresh_bots() - - -def get_bot(user: PlatformTarget) -> Optional[Bot]: - """获取 Bot""" - bots = BOT_CACHE.get(user) - if not bots: - return - - return random.choice(bots) - - async def get_groups() -> list[dict[str, Any]]: """获取所有群号""" # TODO diff --git a/tests/conftest.py b/tests/conftest.py index 86dd982..c2bdd82 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,6 +30,7 @@ def load_adapters(nonebug_init: None): async def app(tmp_path: Path, request: pytest.FixtureRequest, mocker: MockerFixture): sys.path.append(str(Path(__file__).parent.parent / "src" / "plugins")) + nonebot.require("nonebot_plugin_saa") nonebot.require("nonebot_bison") from nonebot_plugin_datastore.config import plugin_config as datastore_config from nonebot_plugin_datastore.db import create_session, init_db diff --git a/tests/test_get_bot.py b/tests/test_get_bot.py index 3c02a0a..17e7405 100644 --- a/tests/test_get_bot.py +++ b/tests/test_get_bot.py @@ -22,81 +22,3 @@ async def test_get_bots(app: App) -> None: assert botv11 in bot assert botv12 not in bot - - -@pytest.mark.asyncio -@pytest.mark.parametrize("app", [{"refresh_bot": True}], indirect=True) -async def test_refresh_bots(app: App) -> None: - from nonebot import get_driver - from nonebot.adapters.onebot.v11 import Bot as BotV11 - from nonebot.adapters.onebot.v12 import Bot as BotV12 - from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate - - from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots - - async with app.test_api() as ctx: - botv11 = ctx.create_bot(base=BotV11, self_id="v11") - botv12 = ctx.create_bot(base=BotV12, self_id="v12", platform="qq", impl="walle") - - driver = get_driver() - driver._bots = {botv11.self_id: botv11, botv12.self_id: botv12} - - ctx.should_call_api("get_group_list", {}, [{"group_id": 1}]) - ctx.should_call_api("get_friend_list", {}, [{"user_id": 2}]) - - assert get_bot(TargetQQGroup(group_id=1)) is None - assert get_bot(TargetQQPrivate(user_id=2)) is None - - await refresh_bots() - - assert get_bot(TargetQQGroup(group_id=1)) == botv11 - assert get_bot(TargetQQPrivate(user_id=2)) == botv11 - - # 测试获取群列表 - ctx.should_call_api("get_group_list", {}, [{"group_id": 3}]) - - groups = await get_groups() - - assert groups == [{"group_id": 3}] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("app", [{"refresh_bot": True}], indirect=True) -async def test_get_bot_two_bots(app: App) -> None: - from nonebot import get_driver - from nonebot.adapters.onebot.v11 import Bot as BotV11 - from nonebot.adapters.onebot.v12 import Bot as BotV12 - from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate - - from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots - - async with app.test_api() as ctx: - bot1 = ctx.create_bot(base=BotV11, self_id="1") - bot2 = ctx.create_bot(base=BotV11, self_id="2") - botv12 = ctx.create_bot(base=BotV12, self_id="v12", platform="qq", impl="walle") - - driver = get_driver() - driver._bots = {bot1.self_id: bot1, bot2.self_id: bot2, botv12.self_id: botv12} - - ctx.should_call_api("get_group_list", {}, [{"group_id": 1}, {"group_id": 2}]) - ctx.should_call_api("get_friend_list", {}, [{"user_id": 1}, {"user_id": 2}]) - ctx.should_call_api("get_group_list", {}, [{"group_id": 2}, {"group_id": 3}]) - ctx.should_call_api("get_friend_list", {}, [{"user_id": 2}, {"user_id": 3}]) - - await refresh_bots() - - assert get_bot(TargetQQGroup(group_id=0)) is None - assert get_bot(TargetQQGroup(group_id=1)) == bot1 - assert get_bot(TargetQQGroup(group_id=2)) in (bot1, bot2) - assert get_bot(TargetQQGroup(group_id=3)) == bot2 - assert get_bot(TargetQQPrivate(user_id=0)) is None - assert get_bot(TargetQQPrivate(user_id=1)) == bot1 - assert get_bot(TargetQQPrivate(user_id=2)) in (bot1, bot2) - assert get_bot(TargetQQPrivate(user_id=3)) == bot2 - - ctx.should_call_api("get_group_list", {}, [{"group_id": 1}, {"group_id": 2}]) - ctx.should_call_api("get_group_list", {}, [{"group_id": 2}, {"group_id": 3}]) - - groups = await get_groups() - - assert groups == [{"group_id": 1}, {"group_id": 2}, {"group_id": 3}]