Merge branch 'dev' into feat/db

This commit is contained in:
felinae98
2022-03-22 20:07:23 +08:00
24 changed files with 1123 additions and 261 deletions
+45 -43
View File
@@ -13,9 +13,10 @@ from nonebot.log import logger
from nonebot.matcher import Matcher
from nonebot.params import Depends, EventPlainText, EventToMe
from nonebot.permission import SUPERUSER
from nonebot.rule import to_me
from nonebot.typing import T_State
from .config import config
from .config import Config
from .platform import check_sub_target, platform_manager
from .plugin_config import plugin_config
from .types import Category, Target, User
@@ -81,7 +82,7 @@ def do_add_sub(add_sub: Type[Matcher]):
for platform_name in common_platform
]
)
+ "要查看全部平台请输入:“全部”"
+ "要查看全部平台请输入:“全部”\n中止订阅过程请输入:“取消”"
)
async def parse_platform(event: MessageEvent, state: T_State) -> None:
@@ -96,6 +97,8 @@ def do_add_sub(add_sub: Type[Matcher]):
]
)
await add_sub.reject(message)
elif platform == "取消":
await add_sub.finish("已中止订阅")
elif platform in platform_manager:
state["platform"] = platform
else:
@@ -106,9 +109,7 @@ def do_add_sub(add_sub: Type[Matcher]):
)
async def init_id(state: T_State):
if platform_manager[state["platform"]].has_target:
state[
"_prompt"
] = "请输入订阅用户的id,详情查阅https://nonebot-bison.vercel.app/usage/#%E6%89%80%E6%94%AF%E6%8C%81%E5%B9%B3%E5%8F%B0%E7%9A%84uid"
state["_prompt"] = "请输入订阅用户的id:\n查询id获取方法请回复:“查询”"
else:
state["id"] = "default"
state["name"] = await platform_manager[state["platform"]].get_target_name(
@@ -120,13 +121,32 @@ def do_add_sub(add_sub: Type[Matcher]):
return
target = str(event.get_message()).strip()
try:
if target == "查询":
raise LookupError
if target == "取消":
raise KeyboardInterrupt
name = await check_sub_target(state["platform"], target)
if not name:
raise ValueError
state["id"] = target
state["name"] = name
except:
except (LookupError):
url = "https://nonebot-bison.vercel.app/usage/#%E6%89%80%E6%94%AF%E6%8C%81%E5%B9%B3%E5%8F%B0%E7%9A%84-uid"
title = "Bison所支持的平台UID"
content = "查询相关平台的uid格式或获取方式"
image = "https://s3.bmp.ovh/imgs/2022/03/ab3cc45d83bd3dd3.jpg"
getId_share = f"[CQ:share,url={url},title={title},content={content},image={image}]" # 缩短字符串格式长度,以及方便后续修改为消息段格式
await add_sub.reject(Message(getId_share))
except (KeyboardInterrupt):
await add_sub.finish("已中止订阅")
except (ValueError):
await add_sub.reject("id输入错误")
else:
await add_sub.send(
"即将订阅的用户为:{} {} {}\n如有错误请输入“取消”重新订阅".format(
state["platform"], state["name"], state["id"]
)
)
@add_sub.got("id", _gen_prompt_template("{_prompt}"), [Depends(parse_id)])
async def init_cat(state: T_State):
@@ -142,7 +162,9 @@ def do_add_sub(add_sub: Type[Matcher]):
return
res = []
for cat in str(event.get_message()).strip().split():
if cat not in platform_manager[state["platform"]].reverse_category:
if cat == "取消":
await add_sub.finish("已中止订阅")
elif cat not in platform_manager[state["platform"]].reverse_category:
await add_sub.reject("不支持 {}".format(cat))
res.append(platform_manager[state["platform"]].reverse_category[cat])
state["cats"] = res
@@ -157,6 +179,8 @@ def do_add_sub(add_sub: Type[Matcher]):
async def parser_tags(event: MessageEvent, state: T_State):
if not isinstance(state["tags"], Message):
return
if str(event.get_message()).strip() == "取消": # 一般不会有叫 取消 的tag吧
await add_sub.finish("已中止订阅")
if str(event.get_message()).strip() == "全部标签":
state["tags"] = []
else:
@@ -164,6 +188,7 @@ def do_add_sub(add_sub: Type[Matcher]):
@add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)])
async def add_sub_process(event: Event, state: T_State):
config = Config()
user = state.get("target_user_info")
assert isinstance(user, User)
config.add_subscribe(
@@ -185,6 +210,7 @@ def do_query_sub(query_sub: Type[Matcher]):
@query_sub.handle()
async def _(state: T_State):
config: Config = Config()
user_info = state["target_user_info"]
assert isinstance(user_info, User)
sub_list = config.list_subscribe(
@@ -215,6 +241,7 @@ def do_del_sub(del_sub: Type[Matcher]):
@del_sub.handle()
async def send_list(bot: Bot, event: Event, state: T_State):
config: Config = Config()
user_info = state["target_user_info"]
assert isinstance(user_info, User)
sub_list = config.list_subscribe(
@@ -249,6 +276,7 @@ def do_del_sub(del_sub: Type[Matcher]):
async def do_del(event: Event, state: T_State):
try:
index = int(str(event.get_message()).strip())
config = Config()
user_info = state["target_user_info"]
assert isinstance(user_info, User)
config.del_subscribe(
@@ -288,11 +316,16 @@ del_sub_matcher = on_command(
del_sub_matcher.handle()(set_target_user_info)
do_del_sub(del_sub_matcher)
group_manage_matcher = on_command("群管理")
group_manage_matcher = on_command("群管理", rule=to_me(), permission=SUPERUSER, priority=4)
@group_manage_matcher.handle()
async def send_group_list(bot: Bot, state: T_State):
async def send_group_list_private(bot: Bot, event: GroupMessageEvent, state: T_State):
await group_manage_matcher.finish(Message("该功能只支持私聊使用,请私聊Bot"))
@group_manage_matcher.handle()
async def send_group_list(bot: Bot, event: PrivateMessageEvent, state: T_State):
groups = await bot.call_api("get_group_list")
res_text = "请选择需要管理的群:\n"
group_number_idx = {}
@@ -349,13 +382,13 @@ async def do_dispatch_command(
"message",
Rule(),
permission,
None,
True,
handlers=None,
temp=True,
priority=0,
block=True,
plugin=matcher.plugin,
module=matcher.module,
expire_time=datetime.now() + bot.config.session_expire_timeout,
expire_time=datetime.now(),
default_state=matcher.state,
default_type_updater=matcher.__class__._default_type_updater,
default_permission_updater=matcher.__class__._default_permission_updater,
@@ -368,34 +401,3 @@ async def do_dispatch_command(
do_del_sub(new_matcher)
new_matcher_ins = new_matcher()
asyncio.create_task(new_matcher_ins.run(bot, event, state))
test_matcher = on_command("testtt")
@test_matcher.handle()
async def _handler(bot: Bot, event: Event, matcher: Matcher, state: T_State):
permission = await matcher.update_permission(bot, event)
new_matcher = Matcher.new(
"message",
Rule(),
permission,
None,
True,
priority=0,
block=True,
plugin=matcher.plugin,
module=matcher.module,
expire_time=datetime.now() + bot.config.session_expire_timeout,
default_state=matcher.state,
default_type_updater=matcher.__class__._default_type_updater,
default_permission_updater=matcher.__class__._default_permission_updater,
)
async def h():
logger.warning("yes")
await new_matcher.send("666")
new_matcher.handle()(h)
new_matcher_ins = new_matcher()
await new_matcher_ins.run(bot, event, state)
@@ -12,6 +12,8 @@ class PlugConfig(BaseSettings):
bison_filter_log: bool = False
bison_to_me: bool = True
bison_skip_browser_check: bool = False
bison_use_pic_merge: int = 0 # 多图片时启用图片合并转发(仅限群),当bison_use_queue为False时该配置不会生效
# 0:不启用;1:首条消息单独发送,剩余照片合并转发;2以及以上:所有消息全部合并转发
bison_resend_times: int = 0
class Config:
+13 -11
View File
@@ -24,7 +24,7 @@ class Post:
pics: list[Union[str, bytes]] = field(default_factory=list)
extra_msg: list[Message] = field(default_factory=list)
_message: Optional[list] = None
_message: Optional[list[Message]] = None
def _use_pic(self):
if not self.override_use_pic is None:
@@ -107,10 +107,10 @@ class Post:
self.pics = self.pics[matrix[0] * matrix[1] :]
self.pics.insert(0, target_io.getvalue())
async def generate_messages(self):
async def generate_messages(self) -> list[Message]:
if self._message is None:
await self._pic_merge()
msgs = []
msg_segments: list[MessageSegment] = []
text = ""
if self.text:
if self._use_pic():
@@ -123,22 +123,24 @@ class Post:
if self.target_name:
text += " {}".format(self.target_name)
if self._use_pic():
msgs.append(await parse_text(text))
msg_segments.append(await parse_text(text))
if not self.target_type == "rss" and self.url:
msgs.append(MessageSegment.text(self.url))
msg_segments.append(MessageSegment.text(self.url))
else:
if self.url:
text += " \n详情: {}".format(self.url)
msgs.append(MessageSegment.text(text))
msg_segments.append(MessageSegment.text(text))
for pic in self.pics:
# if isinstance(pic, bytes):
# pic = 'base64://' + base64.b64encode(pic).decode()
# msgs.append(Message("[CQ:image,file={url}]".format(url=pic)))
msgs.append(MessageSegment.image(pic))
msg_segments.append(MessageSegment.image(pic))
if self.compress:
msgs = [reduce(lambda x, y: x.append(y), msgs, Message())]
msgs = [reduce(lambda x, y: x.append(y), msg_segments, Message())]
else:
msgs = list(
map(lambda msg_segment: Message([msg_segment]), msg_segments)
)
msgs.extend(self.extra_msg)
self._message = msgs
assert len(self._message) > 0, f"message list empty, {self}"
return self._message
def __str__(self):
+69 -9
View File
@@ -1,23 +1,36 @@
import time
from typing import Literal, Union
from nonebot.adapters import Message, MessageSegment
from nonebot.adapters.onebot.v11.bot import Bot
from nonebot.adapters.onebot.v11.message import Message, MessageSegment
from nonebot.log import logger
from .plugin_config import plugin_config
QUEUE = []
QUEUE: list[
tuple[
Bot,
int,
Literal["private", "group", "group-forward"],
Union[str, Message],
int,
]
] = []
LAST_SEND_TIME = time.time()
async def _do_send(
bot: "Bot", user: str, user_type: str, msg: Union[str, Message, MessageSegment]
bot: "Bot",
user: int,
user_type: Literal["group", "private", "group-forward"],
msg: Union[str, Message],
):
if user_type == "group":
await bot.call_api("send_group_msg", group_id=user, message=msg)
await bot.send_group_msg(group_id=user, message=msg)
elif user_type == "private":
await bot.call_api("send_private_msg", user_id=user, message=msg)
await bot.send_private_msg(user_id=user, message=msg)
elif user_type == "group-forward":
await bot.send_group_forward_msg(group_id=user, messages=msg)
async def do_send_msgs():
@@ -39,10 +52,57 @@ async def do_send_msgs():
LAST_SEND_TIME = time.time()
async def send_msgs(bot: Bot, user, user_type: Literal["private", "group"], msgs: list):
async def _send_msgs_dispatch(
bot: Bot,
user,
user_type: Literal["private", "group", "group-forward"],
msg: Union[str, Message],
):
if plugin_config.bison_use_queue:
for msg in msgs:
QUEUE.append((bot, user, user_type, msg, plugin_config.bison_resend_times))
QUEUE.append((bot, user, user_type, msg, plugin_config.bison_resend_times))
else:
await _do_send(bot, user, user_type, msg)
async def send_msgs(
bot: Bot, user, user_type: Literal["private", "group"], msgs: list[Message]
):
if not plugin_config.bison_use_pic_merge or user_type == "private":
for msg in msgs:
await _do_send(bot, user, user_type, msg)
await _send_msgs_dispatch(bot, user, user_type, msg)
return
msgs = msgs.copy()
if plugin_config.bison_use_pic_merge == 1:
await _send_msgs_dispatch(bot, user, "group", msgs.pop(0))
if msgs:
if len(msgs) == 1: # 只有一条消息序列就不合并转发
await _send_msgs_dispatch(bot, user, "group", msgs.pop(0))
else:
group_bot_info = await bot.get_group_member_info(
group_id=user, user_id=int(bot.self_id), no_cache=True
) # 调用api获取群内bot的相关参数
# forward_msg = Message(
# [
# MessageSegment.node_custom(
# group_bot_info["user_id"],
# nickname=group_bot_info["card"] or group_bot_info["nickname"],
# content=msg,
# )
# for msg in msgs
# ]
# )
# FIXME: Because of https://github.com/nonebot/adapter-onebot/issues/9
forward_msg = [
{
"type": "node",
"data": {
"name": group_bot_info["card"] or group_bot_info["nickname"],
"uin": group_bot_info["user_id"],
"content": msg,
},
}
for msg in msgs
]
await _send_msgs_dispatch(bot, user, "group-forward", forward_msg)