🚧 sunset get_bot

This commit is contained in:
felinae98 2023-05-08 21:09:28 +08:00
parent 39c045c63f
commit 5010ca1ac5
7 changed files with 30 additions and 172 deletions

View File

@ -2,6 +2,9 @@ from nonebot.plugin import PluginMetadata, require
require("nonebot_plugin_apscheduler") require("nonebot_plugin_apscheduler")
require("nonebot_plugin_datastore") require("nonebot_plugin_datastore")
require("nonebot_plugin_saa")
import nonebot_plugin_saa
from . import ( from . import (
admin_page, admin_page,
@ -18,6 +21,8 @@ from . import (
from .plugin_config import plugin_config from .plugin_config import plugin_config
__help__version__ = "0.7.3" __help__version__ = "0.7.3"
nonebot_plugin_saa.enable_auto_select_bot()
__help__plugin__name__ = "nonebot_bison" __help__plugin__name__ = "nonebot_bison"
__usage__ = f"本bot可以提供b站、微博等社交媒体的消息订阅详情请查看本bot文档或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,发送“查询订阅”或“删除订阅”管理订阅" __usage__ = f"本bot可以提供b站、微博等社交媒体的消息订阅详情请查看本bot文档或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,发送“查询订阅”或“删除订阅”管理订阅"

View File

@ -5,6 +5,7 @@ from fastapi.param_functions import Depends
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from fastapi.security.oauth2 import OAuth2PasswordBearer from fastapi.security.oauth2 import OAuth2PasswordBearer
from nonebot_plugin_saa import TargetQQGroup from nonebot_plugin_saa import TargetQQGroup
from nonebot_plugin_saa.utils.auto_select_bot import get_bot
from ..apis import check_sub_target from ..apis import check_sub_target
from ..config import ( from ..config import (
@ -17,7 +18,7 @@ from ..config.db_config import SubscribeDupException
from ..platform import platform_manager from ..platform import platform_manager
from ..types import Target as T_Target from ..types import Target as T_Target
from ..types import WeightConfig from ..types import WeightConfig
from ..utils.get_bot import get_bot, get_groups from ..utils.get_bot import get_groups
from .jwt import load_jwt, pack_jwt from .jwt import load_jwt, pack_jwt
from .token_manager import token_manager from .token_manager import token_manager
from .types import ( from .types import (
@ -185,8 +186,7 @@ async def del_group_sub(groupNumber: int, platformName: str, target: str):
async def update_group_sub(groupNumber: int, req: AddSubscribeReq): async def update_group_sub(groupNumber: int, req: AddSubscribeReq):
try: try:
await config.update_subscribe( await config.update_subscribe(
int(groupNumber), TargetQQGroup(group_id=groupNumber),
"group",
req.target, req.target,
req.targetName, req.targetName,
req.platformName, req.platformName,

View File

@ -1,16 +1,15 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Type from typing import Optional, Type
from nonebot.adapters.onebot.v11.bot import Bot
from nonebot.log import logger from nonebot.log import logger
from nonebot_plugin_apscheduler import scheduler from nonebot_plugin_apscheduler import scheduler
from nonebot_plugin_saa.utils.exceptions import NoBotFound
from ..config import config from ..config import config
from ..platform import platform_manager from ..platform import platform_manager
from ..send import send_msgs from ..send import send_msgs
from ..types import Target from ..types import Target
from ..utils import ProcessContext, SchedulerConfig from ..utils import ProcessContext, SchedulerConfig
from ..utils.get_bot import get_bot
@dataclass @dataclass
@ -107,17 +106,15 @@ class Scheduler:
return return
for user, send_list in to_send: for user, send_list in to_send:
bot = get_bot(user)
for send_post in send_list: for send_post in send_list:
logger.info("send to {}: {}".format(user, send_post)) logger.info("send to {}: {}".format(user, send_post))
if not bot: try:
logger.warning("no bot connected")
else:
await send_msgs( await send_msgs(
bot,
user, user,
await send_post.generate_messages(), await send_post.generate_messages(),
) )
except NoBotFound:
logger.warning("no bot connected")
def insert_new_schedulable(self, platform_name: str, target: Target): def insert_new_schedulable(self, platform_name: str, target: Target):
self.pre_weight_val += 1000 self.pre_weight_val += 1000

View File

@ -2,24 +2,23 @@ import asyncio
from collections import deque from collections import deque
from typing import Deque from typing import Deque
from nonebot.adapters import Bot
from nonebot.adapters.onebot.v11.exception import ActionFailed from nonebot.adapters.onebot.v11.exception import ActionFailed
from nonebot.log import logger from nonebot.log import logger
from nonebot_plugin_saa import AggregatedMessageFactory, MessageFactory, PlatformTarget from nonebot_plugin_saa import AggregatedMessageFactory, MessageFactory, PlatformTarget
from nonebot_plugin_saa.utils.auto_select_bot import refresh_bots
from .plugin_config import plugin_config from .plugin_config import plugin_config
from .utils.get_bot import refresh_bots
Sendable = MessageFactory | AggregatedMessageFactory Sendable = MessageFactory | AggregatedMessageFactory
QUEUE: Deque[tuple[Bot, PlatformTarget, Sendable, int]] = deque() QUEUE: Deque[tuple[PlatformTarget, Sendable, int]] = deque()
MESSGE_SEND_INTERVAL = 1.5 MESSGE_SEND_INTERVAL = 1.5
async def _do_send(bot: "Bot", send_target: PlatformTarget, msg: Sendable): async def _do_send(send_target: PlatformTarget, msg: Sendable):
try: try:
await msg.send_to(send_target, bot) await msg.send_to(send_target)
except ActionFailed: # TODO: catch exception of other adapters except ActionFailed: # TODO: catch exception of other adapters
await refresh_bots() await refresh_bots()
logger.warning("send msg failed, refresh bots") logger.warning("send msg failed, refresh bots")
@ -34,14 +33,14 @@ async def do_send_msgs():
# the length of queue will be 0. # the length of queue will be 0.
# At that time, adding items to queue will trigger a new execution of this func, which is not expected. # At that time, adding items to queue will trigger a new execution of this func, which is not expected.
# So, read from queue first then pop from it # So, read from queue first then pop from it
bot, send_target, msg_factory, retry_time = QUEUE[0] send_target, msg_factory, retry_time = QUEUE[0]
try: try:
await _do_send(bot, send_target, msg_factory) await _do_send(send_target, msg_factory)
except Exception as e: except Exception as e:
await asyncio.sleep(MESSGE_SEND_INTERVAL) await asyncio.sleep(MESSGE_SEND_INTERVAL)
QUEUE.popleft() QUEUE.popleft()
if retry_time > 0: if retry_time > 0:
QUEUE.appendleft((bot, send_target, msg_factory, retry_time - 1)) QUEUE.appendleft((send_target, msg_factory, retry_time - 1))
else: else:
msg_str = str(msg_factory) msg_str = str(msg_factory)
if len(msg_str) > 50: if len(msg_str) > 50:
@ -56,27 +55,27 @@ async def do_send_msgs():
return return
async def _send_msgs_dispatch(bot: Bot, send_target: PlatformTarget, msg: Sendable): async def _send_msgs_dispatch(send_target: PlatformTarget, msg: Sendable):
if plugin_config.bison_use_queue: if plugin_config.bison_use_queue:
QUEUE.append((bot, send_target, msg, plugin_config.bison_resend_times)) QUEUE.append((send_target, msg, plugin_config.bison_resend_times))
# len(QUEUE) before append was 0 # len(QUEUE) before append was 0
if len(QUEUE) == 1: if len(QUEUE) == 1:
asyncio.create_task(do_send_msgs()) asyncio.create_task(do_send_msgs())
else: else:
await _do_send(bot, send_target, msg) await _do_send(send_target, msg)
async def send_msgs(bot: Bot, send_target: PlatformTarget, msgs: list[MessageFactory]): async def send_msgs(send_target: PlatformTarget, msgs: list[MessageFactory]):
if not plugin_config.bison_use_pic_merge: if not plugin_config.bison_use_pic_merge:
for msg in msgs: for msg in msgs:
await _send_msgs_dispatch(bot, send_target, msg) await _send_msgs_dispatch(send_target, msg)
return return
msgs = msgs.copy() msgs = msgs.copy()
if plugin_config.bison_use_pic_merge == 1: if plugin_config.bison_use_pic_merge == 1:
await _send_msgs_dispatch(bot, send_target, msgs.pop(0)) await _send_msgs_dispatch(send_target, msgs.pop(0))
if msgs: if msgs:
if len(msgs) == 1: # 只有一条消息序列就不合并转发 if len(msgs) == 1: # 只有一条消息序列就不合并转发
await _send_msgs_dispatch(bot, send_target, msgs.pop(0)) await _send_msgs_dispatch(send_target, msgs.pop(0))
else: else:
forward_message = AggregatedMessageFactory(list(msgs)) forward_message = AggregatedMessageFactory(list(msgs))
await _send_msgs_dispatch(bot, send_target, forward_message) await _send_msgs_dispatch(send_target, forward_message)

View File

@ -1,18 +1,11 @@
""" 提供获取 Bot 的方法 """ """ 提供获取 Bot 的方法 """
import random
from collections import defaultdict from collections import defaultdict
from typing import Any, Optional from typing import Any
import nonebot import nonebot
from nonebot import get_driver, on_notice
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.adapters.onebot.v11 import Bot as Ob11Bot from nonebot.adapters.onebot.v11 import Bot as Ob11Bot
from nonebot.adapters.onebot.v11 import ( from nonebot_plugin_saa import PlatformTarget
FriendAddNoticeEvent,
GroupDecreaseNoticeEvent,
GroupIncreaseNoticeEvent,
)
from nonebot_plugin_saa import PlatformTarget, TargetQQGroup, TargetQQPrivate
GROUP: dict[int, list[Bot]] = {} GROUP: dict[int, list[Bot]] = {}
USER: dict[int, list[Bot]] = {} USER: dict[int, list[Bot]] = {}
@ -29,65 +22,6 @@ def get_bots() -> list[Bot]:
return bots return bots
async def _refresh_ob11(bot: Ob11Bot):
# 获取群列表
groups = await bot.get_group_list()
for group in groups:
group_id = group["group_id"]
target = TargetQQGroup(group_id=group_id)
BOT_CACHE[target].append(bot)
# 获取好友列表
users = await bot.get_friend_list()
for user in users:
user_id = user["user_id"]
target = TargetQQPrivate(user_id=user_id)
BOT_CACHE[target].append(bot)
async def refresh_bots():
"""刷新缓存的 Bot 数据"""
BOT_CACHE.clear()
for bot in get_bots():
match bot:
case Ob11Bot():
await _refresh_ob11(bot)
driver = get_driver()
@driver.on_bot_connect
@driver.on_bot_disconnect
async def _(bot: Bot):
await refresh_bots()
change_notice = on_notice(priority=1)
@change_notice.handle()
async def _(bot: Bot, event: FriendAddNoticeEvent):
await refresh_bots()
# 01-06 16:56:51 [SUCCESS] nonebot | OneBot V11 **** | [notice.group_increase.approve]: {'time': 1672995411, 'self_id': ****, 'post_type': 'notice', 'notice_type': 'group_increase', 'sub_type': 'approve', 'user_id': ****, 'group_id': ****, 'operator_id': 0}
# 01-06 16:58:09 [SUCCESS] nonebot | OneBot V11 **** | [notice.group_decrease.kick_me]: {'time': 1672995489, 'self_id': ****, 'post_type': 'notice', 'notice_type': 'group_decrease', 'sub_type': 'kick_me', 'user_id': ****, 'group_id': ****, 'operator_id': ****}
@change_notice.handle()
async def _(bot: Bot, event: GroupDecreaseNoticeEvent | GroupIncreaseNoticeEvent):
if bot.self_id == event.user_id:
await refresh_bots()
def get_bot(user: PlatformTarget) -> Optional[Bot]:
"""获取 Bot"""
bots = BOT_CACHE.get(user)
if not bots:
return
return random.choice(bots)
async def get_groups() -> list[dict[str, Any]]: async def get_groups() -> list[dict[str, Any]]:
"""获取所有群号""" """获取所有群号"""
# TODO # TODO

View File

@ -30,6 +30,7 @@ def load_adapters(nonebug_init: None):
async def app(tmp_path: Path, request: pytest.FixtureRequest, mocker: MockerFixture): async def app(tmp_path: Path, request: pytest.FixtureRequest, mocker: MockerFixture):
sys.path.append(str(Path(__file__).parent.parent / "src" / "plugins")) sys.path.append(str(Path(__file__).parent.parent / "src" / "plugins"))
nonebot.require("nonebot_plugin_saa")
nonebot.require("nonebot_bison") nonebot.require("nonebot_bison")
from nonebot_plugin_datastore.config import plugin_config as datastore_config from nonebot_plugin_datastore.config import plugin_config as datastore_config
from nonebot_plugin_datastore.db import create_session, init_db from nonebot_plugin_datastore.db import create_session, init_db

View File

@ -22,81 +22,3 @@ async def test_get_bots(app: App) -> None:
assert botv11 in bot assert botv11 in bot
assert botv12 not in bot assert botv12 not in bot
@pytest.mark.asyncio
@pytest.mark.parametrize("app", [{"refresh_bot": True}], indirect=True)
async def test_refresh_bots(app: App) -> None:
from nonebot import get_driver
from nonebot.adapters.onebot.v11 import Bot as BotV11
from nonebot.adapters.onebot.v12 import Bot as BotV12
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots
async with app.test_api() as ctx:
botv11 = ctx.create_bot(base=BotV11, self_id="v11")
botv12 = ctx.create_bot(base=BotV12, self_id="v12", platform="qq", impl="walle")
driver = get_driver()
driver._bots = {botv11.self_id: botv11, botv12.self_id: botv12}
ctx.should_call_api("get_group_list", {}, [{"group_id": 1}])
ctx.should_call_api("get_friend_list", {}, [{"user_id": 2}])
assert get_bot(TargetQQGroup(group_id=1)) is None
assert get_bot(TargetQQPrivate(user_id=2)) is None
await refresh_bots()
assert get_bot(TargetQQGroup(group_id=1)) == botv11
assert get_bot(TargetQQPrivate(user_id=2)) == botv11
# 测试获取群列表
ctx.should_call_api("get_group_list", {}, [{"group_id": 3}])
groups = await get_groups()
assert groups == [{"group_id": 3}]
@pytest.mark.asyncio
@pytest.mark.parametrize("app", [{"refresh_bot": True}], indirect=True)
async def test_get_bot_two_bots(app: App) -> None:
from nonebot import get_driver
from nonebot.adapters.onebot.v11 import Bot as BotV11
from nonebot.adapters.onebot.v12 import Bot as BotV12
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots
async with app.test_api() as ctx:
bot1 = ctx.create_bot(base=BotV11, self_id="1")
bot2 = ctx.create_bot(base=BotV11, self_id="2")
botv12 = ctx.create_bot(base=BotV12, self_id="v12", platform="qq", impl="walle")
driver = get_driver()
driver._bots = {bot1.self_id: bot1, bot2.self_id: bot2, botv12.self_id: botv12}
ctx.should_call_api("get_group_list", {}, [{"group_id": 1}, {"group_id": 2}])
ctx.should_call_api("get_friend_list", {}, [{"user_id": 1}, {"user_id": 2}])
ctx.should_call_api("get_group_list", {}, [{"group_id": 2}, {"group_id": 3}])
ctx.should_call_api("get_friend_list", {}, [{"user_id": 2}, {"user_id": 3}])
await refresh_bots()
assert get_bot(TargetQQGroup(group_id=0)) is None
assert get_bot(TargetQQGroup(group_id=1)) == bot1
assert get_bot(TargetQQGroup(group_id=2)) in (bot1, bot2)
assert get_bot(TargetQQGroup(group_id=3)) == bot2
assert get_bot(TargetQQPrivate(user_id=0)) is None
assert get_bot(TargetQQPrivate(user_id=1)) == bot1
assert get_bot(TargetQQPrivate(user_id=2)) in (bot1, bot2)
assert get_bot(TargetQQPrivate(user_id=3)) == bot2
ctx.should_call_api("get_group_list", {}, [{"group_id": 1}, {"group_id": 2}])
ctx.should_call_api("get_group_list", {}, [{"group_id": 2}, {"group_id": 3}])
groups = await get_groups()
assert groups == [{"group_id": 1}, {"group_id": 2}, {"group_id": 3}]