♻️ refactor config manager

This commit is contained in:
felinae98 2023-06-07 02:10:15 +08:00
parent 3bdbed9f45
commit 8a20897fe9

View File

@ -1,9 +1,10 @@
import asyncio
from datetime import datetime
from typing import Optional, Type, cast
from typing import Annotated, Optional, Type, cast
from nonebot import on_command
from nonebot.adapters.onebot.v11 import Bot, Event, MessageEvent
from nonebot import logger, on_command
from nonebot.adapters import Bot, Event
from nonebot.adapters.onebot.v11 import MessageEvent
from nonebot.adapters.onebot.v11.event import GroupMessageEvent, PrivateMessageEvent
from nonebot.adapters.onebot.v11.message import Message
from nonebot.adapters.onebot.v11.permission import GROUP_ADMIN, GROUP_OWNER
@ -11,7 +12,7 @@ from nonebot.adapters.onebot.v11.utils import unescape
from nonebot.internal.params import ArgStr
from nonebot.internal.rule import Rule
from nonebot.matcher import Matcher
from nonebot.params import Depends, EventPlainText, EventToMe
from nonebot.params import ArgPlainText, Depends, EventPlainText, EventToMe
from nonebot.permission import SUPERUSER
from nonebot.rule import to_me
from nonebot.typing import T_State
@ -31,12 +32,6 @@ from .types import Category, Target
from .utils import parse_text
def _gen_prompt_template(prompt: str):
if hasattr(Message, "template"):
return Message.template(prompt)
return prompt
def _configurable_to_me(to_me: bool = EventToMe()):
if plugin_config.bison_to_me:
return to_me
@ -66,12 +61,22 @@ def ensure_user_info(matcher: Type[Matcher]):
return _check_user_info
async def set_target_user_info(event: MessageEvent, state: T_State):
async def set_target_user_info(event: Event, state: T_State):
user = extract_target(event)
state["target_user_info"] = user
def gen_handle_cancel(matcher: Type[Matcher], message: str):
async def _handle_cancel(text: Annotated[str, EventPlainText()]):
if text == "取消":
await matcher.finish(message)
return Depends(_handle_cancel)
def do_add_sub(add_sub: Type[Matcher]):
handle_cancel = gen_handle_cancel(add_sub, "已中止订阅")
add_sub.handle()(ensure_user_info(add_sub))
@add_sub.handle()
@ -89,10 +94,10 @@ def do_add_sub(add_sub: Type[Matcher]):
+ "要查看全部平台请输入:“全部”\n中止订阅过程请输入:“取消”"
)
async def parse_platform(event: MessageEvent, state: T_State) -> None:
@add_sub.got("platform", Message.template("{_prompt}"), [handle_cancel])
async def parse_platform(state: T_State, platform: str = ArgPlainText()) -> None:
if not isinstance(state["platform"], Message):
return
platform = str(event.get_message()).strip()
if platform == "全部":
message = "全部平台\n" + "\n".join(
[
@ -108,10 +113,8 @@ def do_add_sub(add_sub: Type[Matcher]):
else:
await add_sub.reject("平台输入错误")
@add_sub.got(
"platform", _gen_prompt_template("{_prompt}"), [Depends(parse_platform)]
)
async def init_id(state: T_State):
@add_sub.handle()
async def prepare_get_id(matcher: Matcher, state: T_State):
cur_platform = platform_manager[state["platform"]]
if cur_platform.has_target:
state["_prompt"] = (
@ -120,36 +123,29 @@ def do_add_sub(add_sub: Type[Matcher]):
else ""
) + "请输入订阅用户的id\n查询id获取方法请回复:“查询”"
else:
matcher.set_arg("raw_id", Message("no id"))
state["id"] = "default"
state["name"] = await check_sub_target(state["platform"], Target(""))
async def parse_id(event: MessageEvent, state: T_State):
if not isinstance(state["id"], Message):
@add_sub.got("raw_id", Message.template("{_prompt}"), [handle_cancel])
async def got_id(state: T_State, raw_id: str = ArgPlainText()):
if state.get("id"):
return
target = str(event.get_message()).strip()
try:
if target == "查询":
raise LookupError
if target == "取消":
raise KeyboardInterrupt
if raw_id == "查询":
url = "https://nonebot-bison.netlify.app/usage/#%E6%89%80%E6%94%AF%E6%8C%81%E5%B9%B3%E5%8F%B0%E7%9A%84-uid"
title = "Bison所支持的平台UID"
content = "查询相关平台的uid格式或获取方式"
image = "https://s3.bmp.ovh/imgs/2022/03/ab3cc45d83bd3dd3.jpg"
getId_share = f"[CQ:share,url={url},title={title},content={content},image={image}]" # 缩短字符串格式长度,以及方便后续修改为消息段格式
await add_sub.reject(Message(getId_share))
platform = platform_manager[state["platform"]]
target = await platform.parse_target(unescape(target))
name = await check_sub_target(state["platform"], target)
raw_id = await platform.parse_target(unescape(raw_id))
name = await check_sub_target(state["platform"], raw_id)
if not name:
raise ValueError
state["id"] = target
await add_sub.reject("id输入错误")
state["id"] = raw_id
state["name"] = name
except (LookupError):
url = "https://nonebot-bison.netlify.app/usage/#%E6%89%80%E6%94%AF%E6%8C%81%E5%B9%B3%E5%8F%B0%E7%9A%84-uid"
title = "Bison所支持的平台UID"
content = "查询相关平台的uid格式或获取方式"
image = "https://s3.bmp.ovh/imgs/2022/03/ab3cc45d83bd3dd3.jpg"
getId_share = f"[CQ:share,url={url},title={title},content={content},image={image}]" # 缩短字符串格式长度,以及方便后续修改为消息段格式
await add_sub.reject(Message(getId_share))
except (KeyboardInterrupt):
await add_sub.finish("已中止订阅")
except (ValueError):
await add_sub.reject("id输入错误")
except (Platform.ParseTargetException):
await add_sub.reject("不能从你的输入中提取出id请检查你输入的内容是否符合预期")
else:
@ -159,50 +155,51 @@ def do_add_sub(add_sub: Type[Matcher]):
)
)
@add_sub.got("id", _gen_prompt_template("{_prompt}"), [Depends(parse_id)])
async def init_cat(state: T_State):
@add_sub.handle()
async def prepare_get_categories(matcher: Matcher, state: T_State):
if not platform_manager[state["platform"]].categories:
matcher.set_arg("raw_cats", Message(""))
state["cats"] = []
return
state["_prompt"] = "请输入要订阅的类别,以空格分隔,支持的类别有:{}".format(
" ".join(list(platform_manager[state["platform"]].categories.values()))
)
async def parser_cats(event: MessageEvent, state: T_State):
if not isinstance(state["cats"], Message):
@add_sub.got("raw_cats", Message.template("{_prompt}"), [handle_cancel])
async def parser_cats(state: T_State, raw_cats: str = ArgPlainText()):
if "cats" in state.keys():
return
res = []
for cat in str(event.get_message()).strip().split():
if cat == "取消":
await add_sub.finish("已中止订阅")
elif cat not in platform_manager[state["platform"]].reverse_category:
await add_sub.reject("不支持 {}".format(cat))
res.append(platform_manager[state["platform"]].reverse_category[cat])
if platform_manager[state["platform"]].categories:
for cat in raw_cats.split():
if cat not in platform_manager[state["platform"]].reverse_category:
await add_sub.reject("不支持 {}".format(cat))
res.append(platform_manager[state["platform"]].reverse_category[cat])
state["cats"] = res
@add_sub.got("cats", _gen_prompt_template("{_prompt}"), [Depends(parser_cats)])
async def init_tag(state: T_State):
@add_sub.handle()
async def prepare_get_tags(matcher: Matcher, state: T_State):
if not platform_manager[state["platform"]].enable_tag:
matcher.set_arg("raw_tags", Message(""))
state["tags"] = []
return
state["_prompt"] = '请输入要订阅/屏蔽的标签(不含#号)\n多个标签请使用空格隔开\n订阅所有标签输入"全部标签"\n具体规则回复"详情"'
async def parser_tags(event: MessageEvent, state: T_State):
if not isinstance(state["tags"], Message):
@add_sub.got("raw_tags", Message.template("{_prompt}"), [handle_cancel])
async def parser_tags(state: T_State, raw_tags: str = ArgPlainText()):
if "tags" in state.keys():
return
if str(event.get_message()).strip() == "取消": # 一般不会有叫 取消 的tag吧
await add_sub.finish("已中止订阅")
if str(event.get_message()).strip() == "详情":
if raw_tags == "详情":
await add_sub.reject(
"订阅标签直接输入标签内容\n屏蔽标签请在标签名称前添加~号\n详见https://nonebot-bison.netlify.app/usage/#%E5%B9%B3%E5%8F%B0%E8%AE%A2%E9%98%85%E6%A0%87%E7%AD%BE-tag"
)
if str(event.get_message()).strip() in ["全部标签", "全部", "全标签"]:
if raw_tags in ["全部标签", "全部", "全标签"]:
state["tags"] = []
else:
state["tags"] = str(event.get_message()).strip().split()
state["tags"] = raw_tags.split()
@add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)])
async def add_sub_process(event: Event, state: T_State):
@add_sub.handle()
async def add_sub_process(state: T_State):
user = cast(PlatformTarget, state.get("target_user_info"))
assert isinstance(user, PlatformTarget)
try:
@ -252,6 +249,8 @@ def do_query_sub(query_sub: Type[Matcher]):
def do_del_sub(del_sub: Type[Matcher]):
handle_cancel = gen_handle_cancel(del_sub, "删除中止")
del_sub.handle()(ensure_user_info(del_sub))
@del_sub.handle()
@ -293,13 +292,10 @@ def do_del_sub(del_sub: Type[Matcher]):
res += "请输入要删除的订阅的序号\n输入'取消'中止"
await MessageFactory(await parse_text(res)).send()
@del_sub.receive()
async def do_del(event: Event, state: T_State):
user_msg = str(event.get_message()).strip()
if user_msg == "取消":
await del_sub.finish("删除中止")
@del_sub.receive(parameterless=[handle_cancel])
async def do_del(state: T_State, index_str: str = EventPlainText()):
try:
index = int(user_msg)
index = int(index_str)
user_info = state["target_user_info"]
assert isinstance(user_info, PlatformTarget)
await config.del_subscribe(user_info, **state["sub_table"][index])
@ -339,6 +335,8 @@ group_manage_matcher = on_command(
"群管理", rule=to_me(), permission=SUPERUSER, priority=4, block=True
)
group_handle_cancel = gen_handle_cancel(group_manage_matcher, "已取消")
@group_manage_matcher.handle()
async def send_group_list_private(bot: Bot, event: GroupMessageEvent, state: T_State):
@ -359,42 +357,24 @@ async def send_group_list(bot: Bot, event: PrivateMessageEvent, state: T_State):
state["group_number_idx"] = group_number_idx
async def _parse_group_idx(state: T_State, event_msg: str = EventPlainText()):
if not isinstance(state["group_idx"], Message):
return
group_number_idx: Optional[dict[int, int]] = state.get("group_number_idx")
assert group_number_idx
try:
assert event_msg != "取消", "userAbort"
idx = int(event_msg)
assert idx in group_number_idx.keys(), "idxNotInList"
state["group_idx"] = idx
except AssertionError as AE:
errType = AE.args[0]
if errType == "userAbort":
await group_manage_matcher.finish("已取消")
elif errType == "idxNotInList":
await group_manage_matcher.reject("请输入正确序号")
@group_manage_matcher.got(
"group_idx", _gen_prompt_template("{_prompt}"), [Depends(_parse_group_idx)]
"group_idx", Message.template("{_prompt}"), [group_handle_cancel]
)
async def do_choose_group_number(state: T_State):
async def do_choose_group_number(state: T_State, group_idx: str = ArgPlainText()):
group_number_idx: dict[int, int] = state["group_number_idx"]
assert group_number_idx
idx = int(group_idx)
if idx not in group_number_idx.keys():
await group_manage_matcher.reject("请输入正确序号")
state["group_idx"] = idx
group_number_idx: dict[int, int] = state["group_number_idx"]
idx: int = state["group_idx"]
group_id = group_number_idx[idx]
state["target_user_info"] = TargetQQGroup(group_id=group_id)
async def _check_command(event_msg: str = EventPlainText()):
if event_msg not in {"添加订阅", "查询订阅", "删除订阅", "取消"}:
await group_manage_matcher.reject("请输入正确的命令")
return
@group_manage_matcher.got(
"command", "请输入需要使用的命令:添加订阅,查询订阅,删除订阅,取消", [Depends(_check_command)]
"command", "请输入需要使用的命令:添加订阅,查询订阅,删除订阅,取消", [group_handle_cancel]
)
async def do_dispatch_command(
bot: Bot,
@ -403,8 +383,8 @@ async def do_dispatch_command(
matcher: Matcher,
command: str = ArgStr(),
):
if command == "取消":
await group_manage_matcher.finish("已取消")
if command not in {"添加订阅", "查询订阅", "删除订阅", "取消"}:
await group_manage_matcher.reject("请输入正确的命令")
permission = await matcher.update_permission(bot, event)
new_matcher = Matcher.new(
"message",