mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-05 19:36:43 +08:00
🚧 sunset get_bot
This commit is contained in:
parent
39c045c63f
commit
5010ca1ac5
@ -2,6 +2,9 @@ from nonebot.plugin import PluginMetadata, require
|
|||||||
|
|
||||||
require("nonebot_plugin_apscheduler")
|
require("nonebot_plugin_apscheduler")
|
||||||
require("nonebot_plugin_datastore")
|
require("nonebot_plugin_datastore")
|
||||||
|
require("nonebot_plugin_saa")
|
||||||
|
|
||||||
|
import nonebot_plugin_saa
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
admin_page,
|
admin_page,
|
||||||
@ -18,6 +21,8 @@ from . import (
|
|||||||
from .plugin_config import plugin_config
|
from .plugin_config import plugin_config
|
||||||
|
|
||||||
__help__version__ = "0.7.3"
|
__help__version__ = "0.7.3"
|
||||||
|
nonebot_plugin_saa.enable_auto_select_bot()
|
||||||
|
|
||||||
__help__plugin__name__ = "nonebot_bison"
|
__help__plugin__name__ = "nonebot_bison"
|
||||||
__usage__ = f"本bot可以提供b站、微博等社交媒体的消息订阅,详情请查看本bot文档,或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,发送“查询订阅”或“删除订阅”管理订阅"
|
__usage__ = f"本bot可以提供b站、微博等社交媒体的消息订阅,详情请查看本bot文档,或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号,发送“查询订阅”或“删除订阅”管理订阅"
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ from fastapi.param_functions import Depends
|
|||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from fastapi.security.oauth2 import OAuth2PasswordBearer
|
from fastapi.security.oauth2 import OAuth2PasswordBearer
|
||||||
from nonebot_plugin_saa import TargetQQGroup
|
from nonebot_plugin_saa import TargetQQGroup
|
||||||
|
from nonebot_plugin_saa.utils.auto_select_bot import get_bot
|
||||||
|
|
||||||
from ..apis import check_sub_target
|
from ..apis import check_sub_target
|
||||||
from ..config import (
|
from ..config import (
|
||||||
@ -17,7 +18,7 @@ from ..config.db_config import SubscribeDupException
|
|||||||
from ..platform import platform_manager
|
from ..platform import platform_manager
|
||||||
from ..types import Target as T_Target
|
from ..types import Target as T_Target
|
||||||
from ..types import WeightConfig
|
from ..types import WeightConfig
|
||||||
from ..utils.get_bot import get_bot, get_groups
|
from ..utils.get_bot import get_groups
|
||||||
from .jwt import load_jwt, pack_jwt
|
from .jwt import load_jwt, pack_jwt
|
||||||
from .token_manager import token_manager
|
from .token_manager import token_manager
|
||||||
from .types import (
|
from .types import (
|
||||||
@ -185,8 +186,7 @@ async def del_group_sub(groupNumber: int, platformName: str, target: str):
|
|||||||
async def update_group_sub(groupNumber: int, req: AddSubscribeReq):
|
async def update_group_sub(groupNumber: int, req: AddSubscribeReq):
|
||||||
try:
|
try:
|
||||||
await config.update_subscribe(
|
await config.update_subscribe(
|
||||||
int(groupNumber),
|
TargetQQGroup(group_id=groupNumber),
|
||||||
"group",
|
|
||||||
req.target,
|
req.target,
|
||||||
req.targetName,
|
req.targetName,
|
||||||
req.platformName,
|
req.platformName,
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
from nonebot.adapters.onebot.v11.bot import Bot
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot_plugin_apscheduler import scheduler
|
from nonebot_plugin_apscheduler import scheduler
|
||||||
|
from nonebot_plugin_saa.utils.exceptions import NoBotFound
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..platform import platform_manager
|
from ..platform import platform_manager
|
||||||
from ..send import send_msgs
|
from ..send import send_msgs
|
||||||
from ..types import Target
|
from ..types import Target
|
||||||
from ..utils import ProcessContext, SchedulerConfig
|
from ..utils import ProcessContext, SchedulerConfig
|
||||||
from ..utils.get_bot import get_bot
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -107,17 +106,15 @@ class Scheduler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
for user, send_list in to_send:
|
for user, send_list in to_send:
|
||||||
bot = get_bot(user)
|
|
||||||
for send_post in send_list:
|
for send_post in send_list:
|
||||||
logger.info("send to {}: {}".format(user, send_post))
|
logger.info("send to {}: {}".format(user, send_post))
|
||||||
if not bot:
|
try:
|
||||||
logger.warning("no bot connected")
|
|
||||||
else:
|
|
||||||
await send_msgs(
|
await send_msgs(
|
||||||
bot,
|
|
||||||
user,
|
user,
|
||||||
await send_post.generate_messages(),
|
await send_post.generate_messages(),
|
||||||
)
|
)
|
||||||
|
except NoBotFound:
|
||||||
|
logger.warning("no bot connected")
|
||||||
|
|
||||||
def insert_new_schedulable(self, platform_name: str, target: Target):
|
def insert_new_schedulable(self, platform_name: str, target: Target):
|
||||||
self.pre_weight_val += 1000
|
self.pre_weight_val += 1000
|
||||||
|
@ -2,24 +2,23 @@ import asyncio
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Deque
|
from typing import Deque
|
||||||
|
|
||||||
from nonebot.adapters import Bot
|
|
||||||
from nonebot.adapters.onebot.v11.exception import ActionFailed
|
from nonebot.adapters.onebot.v11.exception import ActionFailed
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot_plugin_saa import AggregatedMessageFactory, MessageFactory, PlatformTarget
|
from nonebot_plugin_saa import AggregatedMessageFactory, MessageFactory, PlatformTarget
|
||||||
|
from nonebot_plugin_saa.utils.auto_select_bot import refresh_bots
|
||||||
|
|
||||||
from .plugin_config import plugin_config
|
from .plugin_config import plugin_config
|
||||||
from .utils.get_bot import refresh_bots
|
|
||||||
|
|
||||||
Sendable = MessageFactory | AggregatedMessageFactory
|
Sendable = MessageFactory | AggregatedMessageFactory
|
||||||
|
|
||||||
QUEUE: Deque[tuple[Bot, PlatformTarget, Sendable, int]] = deque()
|
QUEUE: Deque[tuple[PlatformTarget, Sendable, int]] = deque()
|
||||||
|
|
||||||
MESSGE_SEND_INTERVAL = 1.5
|
MESSGE_SEND_INTERVAL = 1.5
|
||||||
|
|
||||||
|
|
||||||
async def _do_send(bot: "Bot", send_target: PlatformTarget, msg: Sendable):
|
async def _do_send(send_target: PlatformTarget, msg: Sendable):
|
||||||
try:
|
try:
|
||||||
await msg.send_to(send_target, bot)
|
await msg.send_to(send_target)
|
||||||
except ActionFailed: # TODO: catch exception of other adapters
|
except ActionFailed: # TODO: catch exception of other adapters
|
||||||
await refresh_bots()
|
await refresh_bots()
|
||||||
logger.warning("send msg failed, refresh bots")
|
logger.warning("send msg failed, refresh bots")
|
||||||
@ -34,14 +33,14 @@ async def do_send_msgs():
|
|||||||
# the length of queue will be 0.
|
# the length of queue will be 0.
|
||||||
# At that time, adding items to queue will trigger a new execution of this func, which is not expected.
|
# At that time, adding items to queue will trigger a new execution of this func, which is not expected.
|
||||||
# So, read from queue first then pop from it
|
# So, read from queue first then pop from it
|
||||||
bot, send_target, msg_factory, retry_time = QUEUE[0]
|
send_target, msg_factory, retry_time = QUEUE[0]
|
||||||
try:
|
try:
|
||||||
await _do_send(bot, send_target, msg_factory)
|
await _do_send(send_target, msg_factory)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await asyncio.sleep(MESSGE_SEND_INTERVAL)
|
await asyncio.sleep(MESSGE_SEND_INTERVAL)
|
||||||
QUEUE.popleft()
|
QUEUE.popleft()
|
||||||
if retry_time > 0:
|
if retry_time > 0:
|
||||||
QUEUE.appendleft((bot, send_target, msg_factory, retry_time - 1))
|
QUEUE.appendleft((send_target, msg_factory, retry_time - 1))
|
||||||
else:
|
else:
|
||||||
msg_str = str(msg_factory)
|
msg_str = str(msg_factory)
|
||||||
if len(msg_str) > 50:
|
if len(msg_str) > 50:
|
||||||
@ -56,27 +55,27 @@ async def do_send_msgs():
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
async def _send_msgs_dispatch(bot: Bot, send_target: PlatformTarget, msg: Sendable):
|
async def _send_msgs_dispatch(send_target: PlatformTarget, msg: Sendable):
|
||||||
if plugin_config.bison_use_queue:
|
if plugin_config.bison_use_queue:
|
||||||
QUEUE.append((bot, send_target, msg, plugin_config.bison_resend_times))
|
QUEUE.append((send_target, msg, plugin_config.bison_resend_times))
|
||||||
# len(QUEUE) before append was 0
|
# len(QUEUE) before append was 0
|
||||||
if len(QUEUE) == 1:
|
if len(QUEUE) == 1:
|
||||||
asyncio.create_task(do_send_msgs())
|
asyncio.create_task(do_send_msgs())
|
||||||
else:
|
else:
|
||||||
await _do_send(bot, send_target, msg)
|
await _do_send(send_target, msg)
|
||||||
|
|
||||||
|
|
||||||
async def send_msgs(bot: Bot, send_target: PlatformTarget, msgs: list[MessageFactory]):
|
async def send_msgs(send_target: PlatformTarget, msgs: list[MessageFactory]):
|
||||||
if not plugin_config.bison_use_pic_merge:
|
if not plugin_config.bison_use_pic_merge:
|
||||||
for msg in msgs:
|
for msg in msgs:
|
||||||
await _send_msgs_dispatch(bot, send_target, msg)
|
await _send_msgs_dispatch(send_target, msg)
|
||||||
return
|
return
|
||||||
msgs = msgs.copy()
|
msgs = msgs.copy()
|
||||||
if plugin_config.bison_use_pic_merge == 1:
|
if plugin_config.bison_use_pic_merge == 1:
|
||||||
await _send_msgs_dispatch(bot, send_target, msgs.pop(0))
|
await _send_msgs_dispatch(send_target, msgs.pop(0))
|
||||||
if msgs:
|
if msgs:
|
||||||
if len(msgs) == 1: # 只有一条消息序列就不合并转发
|
if len(msgs) == 1: # 只有一条消息序列就不合并转发
|
||||||
await _send_msgs_dispatch(bot, send_target, msgs.pop(0))
|
await _send_msgs_dispatch(send_target, msgs.pop(0))
|
||||||
else:
|
else:
|
||||||
forward_message = AggregatedMessageFactory(list(msgs))
|
forward_message = AggregatedMessageFactory(list(msgs))
|
||||||
await _send_msgs_dispatch(bot, send_target, forward_message)
|
await _send_msgs_dispatch(send_target, forward_message)
|
||||||
|
@ -1,18 +1,11 @@
|
|||||||
""" 提供获取 Bot 的方法 """
|
""" 提供获取 Bot 的方法 """
|
||||||
import random
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
import nonebot
|
import nonebot
|
||||||
from nonebot import get_driver, on_notice
|
|
||||||
from nonebot.adapters import Bot
|
from nonebot.adapters import Bot
|
||||||
from nonebot.adapters.onebot.v11 import Bot as Ob11Bot
|
from nonebot.adapters.onebot.v11 import Bot as Ob11Bot
|
||||||
from nonebot.adapters.onebot.v11 import (
|
from nonebot_plugin_saa import PlatformTarget
|
||||||
FriendAddNoticeEvent,
|
|
||||||
GroupDecreaseNoticeEvent,
|
|
||||||
GroupIncreaseNoticeEvent,
|
|
||||||
)
|
|
||||||
from nonebot_plugin_saa import PlatformTarget, TargetQQGroup, TargetQQPrivate
|
|
||||||
|
|
||||||
GROUP: dict[int, list[Bot]] = {}
|
GROUP: dict[int, list[Bot]] = {}
|
||||||
USER: dict[int, list[Bot]] = {}
|
USER: dict[int, list[Bot]] = {}
|
||||||
@ -29,65 +22,6 @@ def get_bots() -> list[Bot]:
|
|||||||
return bots
|
return bots
|
||||||
|
|
||||||
|
|
||||||
async def _refresh_ob11(bot: Ob11Bot):
|
|
||||||
# 获取群列表
|
|
||||||
groups = await bot.get_group_list()
|
|
||||||
for group in groups:
|
|
||||||
group_id = group["group_id"]
|
|
||||||
target = TargetQQGroup(group_id=group_id)
|
|
||||||
BOT_CACHE[target].append(bot)
|
|
||||||
|
|
||||||
# 获取好友列表
|
|
||||||
users = await bot.get_friend_list()
|
|
||||||
for user in users:
|
|
||||||
user_id = user["user_id"]
|
|
||||||
target = TargetQQPrivate(user_id=user_id)
|
|
||||||
BOT_CACHE[target].append(bot)
|
|
||||||
|
|
||||||
|
|
||||||
async def refresh_bots():
|
|
||||||
"""刷新缓存的 Bot 数据"""
|
|
||||||
BOT_CACHE.clear()
|
|
||||||
for bot in get_bots():
|
|
||||||
match bot:
|
|
||||||
case Ob11Bot():
|
|
||||||
await _refresh_ob11(bot)
|
|
||||||
|
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
|
|
||||||
|
|
||||||
@driver.on_bot_connect
|
|
||||||
@driver.on_bot_disconnect
|
|
||||||
async def _(bot: Bot):
|
|
||||||
await refresh_bots()
|
|
||||||
|
|
||||||
|
|
||||||
change_notice = on_notice(priority=1)
|
|
||||||
|
|
||||||
|
|
||||||
@change_notice.handle()
|
|
||||||
async def _(bot: Bot, event: FriendAddNoticeEvent):
|
|
||||||
await refresh_bots()
|
|
||||||
|
|
||||||
|
|
||||||
# 01-06 16:56:51 [SUCCESS] nonebot | OneBot V11 **** | [notice.group_increase.approve]: {'time': 1672995411, 'self_id': ****, 'post_type': 'notice', 'notice_type': 'group_increase', 'sub_type': 'approve', 'user_id': ****, 'group_id': ****, 'operator_id': 0}
|
|
||||||
# 01-06 16:58:09 [SUCCESS] nonebot | OneBot V11 **** | [notice.group_decrease.kick_me]: {'time': 1672995489, 'self_id': ****, 'post_type': 'notice', 'notice_type': 'group_decrease', 'sub_type': 'kick_me', 'user_id': ****, 'group_id': ****, 'operator_id': ****}
|
|
||||||
@change_notice.handle()
|
|
||||||
async def _(bot: Bot, event: GroupDecreaseNoticeEvent | GroupIncreaseNoticeEvent):
|
|
||||||
if bot.self_id == event.user_id:
|
|
||||||
await refresh_bots()
|
|
||||||
|
|
||||||
|
|
||||||
def get_bot(user: PlatformTarget) -> Optional[Bot]:
|
|
||||||
"""获取 Bot"""
|
|
||||||
bots = BOT_CACHE.get(user)
|
|
||||||
if not bots:
|
|
||||||
return
|
|
||||||
|
|
||||||
return random.choice(bots)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_groups() -> list[dict[str, Any]]:
|
async def get_groups() -> list[dict[str, Any]]:
|
||||||
"""获取所有群号"""
|
"""获取所有群号"""
|
||||||
# TODO
|
# TODO
|
||||||
|
@ -30,6 +30,7 @@ def load_adapters(nonebug_init: None):
|
|||||||
async def app(tmp_path: Path, request: pytest.FixtureRequest, mocker: MockerFixture):
|
async def app(tmp_path: Path, request: pytest.FixtureRequest, mocker: MockerFixture):
|
||||||
sys.path.append(str(Path(__file__).parent.parent / "src" / "plugins"))
|
sys.path.append(str(Path(__file__).parent.parent / "src" / "plugins"))
|
||||||
|
|
||||||
|
nonebot.require("nonebot_plugin_saa")
|
||||||
nonebot.require("nonebot_bison")
|
nonebot.require("nonebot_bison")
|
||||||
from nonebot_plugin_datastore.config import plugin_config as datastore_config
|
from nonebot_plugin_datastore.config import plugin_config as datastore_config
|
||||||
from nonebot_plugin_datastore.db import create_session, init_db
|
from nonebot_plugin_datastore.db import create_session, init_db
|
||||||
|
@ -22,81 +22,3 @@ async def test_get_bots(app: App) -> None:
|
|||||||
|
|
||||||
assert botv11 in bot
|
assert botv11 in bot
|
||||||
assert botv12 not in bot
|
assert botv12 not in bot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize("app", [{"refresh_bot": True}], indirect=True)
|
|
||||||
async def test_refresh_bots(app: App) -> None:
|
|
||||||
from nonebot import get_driver
|
|
||||||
from nonebot.adapters.onebot.v11 import Bot as BotV11
|
|
||||||
from nonebot.adapters.onebot.v12 import Bot as BotV12
|
|
||||||
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
|
|
||||||
|
|
||||||
from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots
|
|
||||||
|
|
||||||
async with app.test_api() as ctx:
|
|
||||||
botv11 = ctx.create_bot(base=BotV11, self_id="v11")
|
|
||||||
botv12 = ctx.create_bot(base=BotV12, self_id="v12", platform="qq", impl="walle")
|
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
driver._bots = {botv11.self_id: botv11, botv12.self_id: botv12}
|
|
||||||
|
|
||||||
ctx.should_call_api("get_group_list", {}, [{"group_id": 1}])
|
|
||||||
ctx.should_call_api("get_friend_list", {}, [{"user_id": 2}])
|
|
||||||
|
|
||||||
assert get_bot(TargetQQGroup(group_id=1)) is None
|
|
||||||
assert get_bot(TargetQQPrivate(user_id=2)) is None
|
|
||||||
|
|
||||||
await refresh_bots()
|
|
||||||
|
|
||||||
assert get_bot(TargetQQGroup(group_id=1)) == botv11
|
|
||||||
assert get_bot(TargetQQPrivate(user_id=2)) == botv11
|
|
||||||
|
|
||||||
# 测试获取群列表
|
|
||||||
ctx.should_call_api("get_group_list", {}, [{"group_id": 3}])
|
|
||||||
|
|
||||||
groups = await get_groups()
|
|
||||||
|
|
||||||
assert groups == [{"group_id": 3}]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize("app", [{"refresh_bot": True}], indirect=True)
|
|
||||||
async def test_get_bot_two_bots(app: App) -> None:
|
|
||||||
from nonebot import get_driver
|
|
||||||
from nonebot.adapters.onebot.v11 import Bot as BotV11
|
|
||||||
from nonebot.adapters.onebot.v12 import Bot as BotV12
|
|
||||||
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
|
|
||||||
|
|
||||||
from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots
|
|
||||||
|
|
||||||
async with app.test_api() as ctx:
|
|
||||||
bot1 = ctx.create_bot(base=BotV11, self_id="1")
|
|
||||||
bot2 = ctx.create_bot(base=BotV11, self_id="2")
|
|
||||||
botv12 = ctx.create_bot(base=BotV12, self_id="v12", platform="qq", impl="walle")
|
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
driver._bots = {bot1.self_id: bot1, bot2.self_id: bot2, botv12.self_id: botv12}
|
|
||||||
|
|
||||||
ctx.should_call_api("get_group_list", {}, [{"group_id": 1}, {"group_id": 2}])
|
|
||||||
ctx.should_call_api("get_friend_list", {}, [{"user_id": 1}, {"user_id": 2}])
|
|
||||||
ctx.should_call_api("get_group_list", {}, [{"group_id": 2}, {"group_id": 3}])
|
|
||||||
ctx.should_call_api("get_friend_list", {}, [{"user_id": 2}, {"user_id": 3}])
|
|
||||||
|
|
||||||
await refresh_bots()
|
|
||||||
|
|
||||||
assert get_bot(TargetQQGroup(group_id=0)) is None
|
|
||||||
assert get_bot(TargetQQGroup(group_id=1)) == bot1
|
|
||||||
assert get_bot(TargetQQGroup(group_id=2)) in (bot1, bot2)
|
|
||||||
assert get_bot(TargetQQGroup(group_id=3)) == bot2
|
|
||||||
assert get_bot(TargetQQPrivate(user_id=0)) is None
|
|
||||||
assert get_bot(TargetQQPrivate(user_id=1)) == bot1
|
|
||||||
assert get_bot(TargetQQPrivate(user_id=2)) in (bot1, bot2)
|
|
||||||
assert get_bot(TargetQQPrivate(user_id=3)) == bot2
|
|
||||||
|
|
||||||
ctx.should_call_api("get_group_list", {}, [{"group_id": 1}, {"group_id": 2}])
|
|
||||||
ctx.should_call_api("get_group_list", {}, [{"group_id": 2}, {"group_id": 3}])
|
|
||||||
|
|
||||||
groups = await get_groups()
|
|
||||||
|
|
||||||
assert groups == [{"group_id": 1}, {"group_id": 2}, {"group_id": 3}]
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user