🎨 use Arg to replace assert

This commit is contained in:
felinae98 2023-06-11 15:14:38 +08:00
parent 4846d32e2e
commit 9f1730093c
3 changed files with 42 additions and 47 deletions

View File

@ -37,8 +37,6 @@ def do_add_sub(add_sub: Type[Matcher]):
@add_sub.got("platform", MessageTemplate("{_prompt}"), [handle_cancel]) @add_sub.got("platform", MessageTemplate("{_prompt}"), [handle_cancel])
async def parse_platform(state: T_State, platform: str = ArgPlainText()) -> None: async def parse_platform(state: T_State, platform: str = ArgPlainText()) -> None:
if not isinstance(state["platform"], Message):
return
if platform == "全部": if platform == "全部":
message = "全部平台\n" + "\n".join( message = "全部平台\n" + "\n".join(
[ [
@ -152,9 +150,9 @@ def do_add_sub(add_sub: Type[Matcher]):
state["tags"] = raw_tags_text.split() state["tags"] = raw_tags_text.split()
@add_sub.handle() @add_sub.handle()
async def add_sub_process(state: T_State): async def add_sub_process(
user = cast(PlatformTarget, state.get("target_user_info")) state: T_State, user: PlatformTarget = Arg("target_user_info")
assert isinstance(user, PlatformTarget) ):
try: try:
await config.add_subscribe( await config.add_subscribe(
user=user, user=user,

View File

@ -1,7 +1,7 @@
from typing import Type from typing import Type
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.params import EventPlainText from nonebot.params import Arg, EventPlainText
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot_plugin_saa import MessageFactory, PlatformTarget from nonebot_plugin_saa import MessageFactory, PlatformTarget
@ -18,15 +18,12 @@ def do_del_sub(del_sub: Type[Matcher]):
del_sub.handle()(ensure_user_info(del_sub)) del_sub.handle()(ensure_user_info(del_sub))
@del_sub.handle() @del_sub.handle()
async def send_list(state: T_State): async def send_list(
user_info = state["target_user_info"] state: T_State, user_info: PlatformTarget = Arg("target_user_info")
assert isinstance(user_info, PlatformTarget) ):
try:
sub_list = await config.list_subscribe(user_info) sub_list = await config.list_subscribe(user_info)
assert sub_list if not sub_list:
except AssertionError:
await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅") await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅")
else:
res = "订阅的帐号为:\n" res = "订阅的帐号为:\n"
state["sub_table"] = {} state["sub_table"] = {}
for index, sub in enumerate(sub_list, 1): for index, sub in enumerate(sub_list, 1):
@ -57,11 +54,13 @@ def do_del_sub(del_sub: Type[Matcher]):
await MessageFactory(await parse_text(res)).send() await MessageFactory(await parse_text(res)).send()
@del_sub.receive(parameterless=[handle_cancel]) @del_sub.receive(parameterless=[handle_cancel])
async def do_del(state: T_State, index_str: str = EventPlainText()): async def do_del(
state: T_State,
index_str: str = EventPlainText(),
user_info: PlatformTarget = Arg("target_user_info"),
):
try: try:
index = int(index_str) index = int(index_str)
user_info = state["target_user_info"]
assert isinstance(user_info, PlatformTarget)
await config.del_subscribe(user_info, **state["sub_table"][index]) await config.del_subscribe(user_info, **state["sub_table"][index])
except Exception as e: except Exception as e:
await del_sub.reject("删除错误") await del_sub.reject("删除错误")

View File

@ -1,7 +1,7 @@
from typing import Type from typing import Type
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.typing import T_State from nonebot.params import Arg
from nonebot_plugin_saa import MessageFactory, PlatformTarget from nonebot_plugin_saa import MessageFactory, PlatformTarget
from ..config import config from ..config import config
@ -15,9 +15,7 @@ def do_query_sub(query_sub: Type[Matcher]):
query_sub.handle()(ensure_user_info(query_sub)) query_sub.handle()(ensure_user_info(query_sub))
@query_sub.handle() @query_sub.handle()
async def _(state: T_State): async def _(user_info: PlatformTarget = Arg("target_user_info")):
user_info = state["target_user_info"]
assert isinstance(user_info, PlatformTarget)
sub_list = await config.list_subscribe(user_info) sub_list = await config.list_subscribe(user_info)
res = "订阅的帐号为:\n" res = "订阅的帐号为:\n"
for sub in sub_list: for sub in sub_list: