From 8a20897fe9a8eae917ba63f393e0486baa7fbb33 Mon Sep 17 00:00:00 2001 From: felinae98 <731499577@qq.com> Date: Wed, 7 Jun 2023 02:10:15 +0800 Subject: [PATCH] :recycle: refactor config manager --- nonebot_bison/config_manager.py | 170 ++++++++++++++------------------ 1 file changed, 75 insertions(+), 95 deletions(-) diff --git a/nonebot_bison/config_manager.py b/nonebot_bison/config_manager.py index b80aae3..cff5217 100644 --- a/nonebot_bison/config_manager.py +++ b/nonebot_bison/config_manager.py @@ -1,9 +1,10 @@ import asyncio from datetime import datetime -from typing import Optional, Type, cast +from typing import Annotated, Optional, Type, cast -from nonebot import on_command -from nonebot.adapters.onebot.v11 import Bot, Event, MessageEvent +from nonebot import logger, on_command +from nonebot.adapters import Bot, Event +from nonebot.adapters.onebot.v11 import MessageEvent from nonebot.adapters.onebot.v11.event import GroupMessageEvent, PrivateMessageEvent from nonebot.adapters.onebot.v11.message import Message from nonebot.adapters.onebot.v11.permission import GROUP_ADMIN, GROUP_OWNER @@ -11,7 +12,7 @@ from nonebot.adapters.onebot.v11.utils import unescape from nonebot.internal.params import ArgStr from nonebot.internal.rule import Rule from nonebot.matcher import Matcher -from nonebot.params import Depends, EventPlainText, EventToMe +from nonebot.params import ArgPlainText, Depends, EventPlainText, EventToMe from nonebot.permission import SUPERUSER from nonebot.rule import to_me from nonebot.typing import T_State @@ -31,12 +32,6 @@ from .types import Category, Target from .utils import parse_text -def _gen_prompt_template(prompt: str): - if hasattr(Message, "template"): - return Message.template(prompt) - return prompt - - def _configurable_to_me(to_me: bool = EventToMe()): if plugin_config.bison_to_me: return to_me @@ -66,12 +61,22 @@ def ensure_user_info(matcher: Type[Matcher]): return _check_user_info -async def set_target_user_info(event: MessageEvent, state: T_State): +async def set_target_user_info(event: Event, state: T_State): user = extract_target(event) state["target_user_info"] = user +def gen_handle_cancel(matcher: Type[Matcher], message: str): + async def _handle_cancel(text: Annotated[str, EventPlainText()]): + if text == "取消": + await matcher.finish(message) + + return Depends(_handle_cancel) + + def do_add_sub(add_sub: Type[Matcher]): + handle_cancel = gen_handle_cancel(add_sub, "已中止订阅") + add_sub.handle()(ensure_user_info(add_sub)) @add_sub.handle() @@ -89,10 +94,10 @@ def do_add_sub(add_sub: Type[Matcher]): + "要查看全部平台请输入:“全部”\n中止订阅过程请输入:“取消”" ) - async def parse_platform(event: MessageEvent, state: T_State) -> None: + @add_sub.got("platform", Message.template("{_prompt}"), [handle_cancel]) + async def parse_platform(state: T_State, platform: str = ArgPlainText()) -> None: if not isinstance(state["platform"], Message): return - platform = str(event.get_message()).strip() if platform == "全部": message = "全部平台\n" + "\n".join( [ @@ -108,10 +113,8 @@ def do_add_sub(add_sub: Type[Matcher]): else: await add_sub.reject("平台输入错误") - @add_sub.got( - "platform", _gen_prompt_template("{_prompt}"), [Depends(parse_platform)] - ) - async def init_id(state: T_State): + @add_sub.handle() + async def prepare_get_id(matcher: Matcher, state: T_State): cur_platform = platform_manager[state["platform"]] if cur_platform.has_target: state["_prompt"] = ( @@ -120,36 +123,29 @@ def do_add_sub(add_sub: Type[Matcher]): else "" ) + "请输入订阅用户的id\n查询id获取方法请回复:“查询”" else: + matcher.set_arg("raw_id", Message("no id")) state["id"] = "default" state["name"] = await check_sub_target(state["platform"], Target("")) - async def parse_id(event: MessageEvent, state: T_State): - if not isinstance(state["id"], Message): + @add_sub.got("raw_id", Message.template("{_prompt}"), [handle_cancel]) + async def got_id(state: T_State, raw_id: str = ArgPlainText()): + if state.get("id"): return - target = str(event.get_message()).strip() try: - if target == "查询": - raise LookupError - if target == "取消": - raise KeyboardInterrupt + if raw_id == "查询": + url = "https://nonebot-bison.netlify.app/usage/#%E6%89%80%E6%94%AF%E6%8C%81%E5%B9%B3%E5%8F%B0%E7%9A%84-uid" + title = "Bison所支持的平台UID" + content = "查询相关平台的uid格式或获取方式" + image = "https://s3.bmp.ovh/imgs/2022/03/ab3cc45d83bd3dd3.jpg" + getId_share = f"[CQ:share,url={url},title={title},content={content},image={image}]" # 缩短字符串格式长度,以及方便后续修改为消息段格式 + await add_sub.reject(Message(getId_share)) platform = platform_manager[state["platform"]] - target = await platform.parse_target(unescape(target)) - name = await check_sub_target(state["platform"], target) + raw_id = await platform.parse_target(unescape(raw_id)) + name = await check_sub_target(state["platform"], raw_id) if not name: - raise ValueError - state["id"] = target + await add_sub.reject("id输入错误") + state["id"] = raw_id state["name"] = name - except (LookupError): - url = "https://nonebot-bison.netlify.app/usage/#%E6%89%80%E6%94%AF%E6%8C%81%E5%B9%B3%E5%8F%B0%E7%9A%84-uid" - title = "Bison所支持的平台UID" - content = "查询相关平台的uid格式或获取方式" - image = "https://s3.bmp.ovh/imgs/2022/03/ab3cc45d83bd3dd3.jpg" - getId_share = f"[CQ:share,url={url},title={title},content={content},image={image}]" # 缩短字符串格式长度,以及方便后续修改为消息段格式 - await add_sub.reject(Message(getId_share)) - except (KeyboardInterrupt): - await add_sub.finish("已中止订阅") - except (ValueError): - await add_sub.reject("id输入错误") except (Platform.ParseTargetException): await add_sub.reject("不能从你的输入中提取出id,请检查你输入的内容是否符合预期") else: @@ -159,50 +155,51 @@ def do_add_sub(add_sub: Type[Matcher]): ) ) - @add_sub.got("id", _gen_prompt_template("{_prompt}"), [Depends(parse_id)]) - async def init_cat(state: T_State): + @add_sub.handle() + async def prepare_get_categories(matcher: Matcher, state: T_State): if not platform_manager[state["platform"]].categories: + matcher.set_arg("raw_cats", Message("")) state["cats"] = [] return state["_prompt"] = "请输入要订阅的类别,以空格分隔,支持的类别有:{}".format( " ".join(list(platform_manager[state["platform"]].categories.values())) ) - async def parser_cats(event: MessageEvent, state: T_State): - if not isinstance(state["cats"], Message): + @add_sub.got("raw_cats", Message.template("{_prompt}"), [handle_cancel]) + async def parser_cats(state: T_State, raw_cats: str = ArgPlainText()): + if "cats" in state.keys(): return res = [] - for cat in str(event.get_message()).strip().split(): - if cat == "取消": - await add_sub.finish("已中止订阅") - elif cat not in platform_manager[state["platform"]].reverse_category: - await add_sub.reject("不支持 {}".format(cat)) - res.append(platform_manager[state["platform"]].reverse_category[cat]) + if platform_manager[state["platform"]].categories: + for cat in raw_cats.split(): + if cat not in platform_manager[state["platform"]].reverse_category: + await add_sub.reject("不支持 {}".format(cat)) + res.append(platform_manager[state["platform"]].reverse_category[cat]) state["cats"] = res - @add_sub.got("cats", _gen_prompt_template("{_prompt}"), [Depends(parser_cats)]) - async def init_tag(state: T_State): + @add_sub.handle() + async def prepare_get_tags(matcher: Matcher, state: T_State): if not platform_manager[state["platform"]].enable_tag: + matcher.set_arg("raw_tags", Message("")) state["tags"] = [] return state["_prompt"] = '请输入要订阅/屏蔽的标签(不含#号)\n多个标签请使用空格隔开\n订阅所有标签输入"全部标签"\n具体规则回复"详情"' - async def parser_tags(event: MessageEvent, state: T_State): - if not isinstance(state["tags"], Message): + @add_sub.got("raw_tags", Message.template("{_prompt}"), [handle_cancel]) + async def parser_tags(state: T_State, raw_tags: str = ArgPlainText()): + if "tags" in state.keys(): return - if str(event.get_message()).strip() == "取消": # 一般不会有叫 取消 的tag吧 - await add_sub.finish("已中止订阅") - if str(event.get_message()).strip() == "详情": + if raw_tags == "详情": await add_sub.reject( "订阅标签直接输入标签内容\n屏蔽标签请在标签名称前添加~号\n详见https://nonebot-bison.netlify.app/usage/#%E5%B9%B3%E5%8F%B0%E8%AE%A2%E9%98%85%E6%A0%87%E7%AD%BE-tag" ) - if str(event.get_message()).strip() in ["全部标签", "全部", "全标签"]: + if raw_tags in ["全部标签", "全部", "全标签"]: state["tags"] = [] else: - state["tags"] = str(event.get_message()).strip().split() + state["tags"] = raw_tags.split() - @add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)]) - async def add_sub_process(event: Event, state: T_State): + @add_sub.handle() + async def add_sub_process(state: T_State): user = cast(PlatformTarget, state.get("target_user_info")) assert isinstance(user, PlatformTarget) try: @@ -252,6 +249,8 @@ def do_query_sub(query_sub: Type[Matcher]): def do_del_sub(del_sub: Type[Matcher]): + handle_cancel = gen_handle_cancel(del_sub, "删除中止") + del_sub.handle()(ensure_user_info(del_sub)) @del_sub.handle() @@ -293,13 +292,10 @@ def do_del_sub(del_sub: Type[Matcher]): res += "请输入要删除的订阅的序号\n输入'取消'中止" await MessageFactory(await parse_text(res)).send() - @del_sub.receive() - async def do_del(event: Event, state: T_State): - user_msg = str(event.get_message()).strip() - if user_msg == "取消": - await del_sub.finish("删除中止") + @del_sub.receive(parameterless=[handle_cancel]) + async def do_del(state: T_State, index_str: str = EventPlainText()): try: - index = int(user_msg) + index = int(index_str) user_info = state["target_user_info"] assert isinstance(user_info, PlatformTarget) await config.del_subscribe(user_info, **state["sub_table"][index]) @@ -339,6 +335,8 @@ group_manage_matcher = on_command( "群管理", rule=to_me(), permission=SUPERUSER, priority=4, block=True ) +group_handle_cancel = gen_handle_cancel(group_manage_matcher, "已取消") + @group_manage_matcher.handle() async def send_group_list_private(bot: Bot, event: GroupMessageEvent, state: T_State): @@ -359,42 +357,24 @@ async def send_group_list(bot: Bot, event: PrivateMessageEvent, state: T_State): state["group_number_idx"] = group_number_idx -async def _parse_group_idx(state: T_State, event_msg: str = EventPlainText()): - if not isinstance(state["group_idx"], Message): - return - group_number_idx: Optional[dict[int, int]] = state.get("group_number_idx") - assert group_number_idx - try: - assert event_msg != "取消", "userAbort" - idx = int(event_msg) - assert idx in group_number_idx.keys(), "idxNotInList" - state["group_idx"] = idx - except AssertionError as AE: - errType = AE.args[0] - if errType == "userAbort": - await group_manage_matcher.finish("已取消") - elif errType == "idxNotInList": - await group_manage_matcher.reject("请输入正确序号") - - @group_manage_matcher.got( - "group_idx", _gen_prompt_template("{_prompt}"), [Depends(_parse_group_idx)] + "group_idx", Message.template("{_prompt}"), [group_handle_cancel] ) -async def do_choose_group_number(state: T_State): +async def do_choose_group_number(state: T_State, group_idx: str = ArgPlainText()): + group_number_idx: dict[int, int] = state["group_number_idx"] + assert group_number_idx + idx = int(group_idx) + if idx not in group_number_idx.keys(): + await group_manage_matcher.reject("请输入正确序号") + state["group_idx"] = idx group_number_idx: dict[int, int] = state["group_number_idx"] idx: int = state["group_idx"] group_id = group_number_idx[idx] state["target_user_info"] = TargetQQGroup(group_id=group_id) -async def _check_command(event_msg: str = EventPlainText()): - if event_msg not in {"添加订阅", "查询订阅", "删除订阅", "取消"}: - await group_manage_matcher.reject("请输入正确的命令") - return - - @group_manage_matcher.got( - "command", "请输入需要使用的命令:添加订阅,查询订阅,删除订阅,取消", [Depends(_check_command)] + "command", "请输入需要使用的命令:添加订阅,查询订阅,删除订阅,取消", [group_handle_cancel] ) async def do_dispatch_command( bot: Bot, @@ -403,8 +383,8 @@ async def do_dispatch_command( matcher: Matcher, command: str = ArgStr(), ): - if command == "取消": - await group_manage_matcher.finish("已取消") + if command not in {"添加订阅", "查询订阅", "删除订阅", "取消"}: + await group_manage_matcher.reject("请输入正确的命令") permission = await matcher.update_permission(bot, event) new_matcher = Matcher.new( "message",