Merge remote-tracking branch 'upstream/feat/parse-target'

This commit is contained in:
Azide
2022-03-20 19:30:31 +08:00
12 changed files with 181 additions and 62 deletions
+15 -39
View File
@@ -1,5 +1,4 @@
import asyncio
from asyncio.tasks import Task
from datetime import datetime
from typing import Optional, Type
@@ -12,10 +11,11 @@ from nonebot.internal.params import ArgStr
from nonebot.internal.rule import Rule
from nonebot.log import logger
from nonebot.matcher import Matcher
from nonebot.params import Depends, EventMessage, EventPlainText, EventToMe
from nonebot.params import Depends, EventPlainText, EventToMe
from nonebot.permission import SUPERUSER
from nonebot.rule import to_me
from nonebot.typing import T_State
from nonebot_bison.platform.platform import Platform
from .config import Config
from .platform import check_sub_target, platform_manager
@@ -109,8 +109,11 @@ def do_add_sub(add_sub: Type[Matcher]):
"platform", _gen_prompt_template("{_prompt}"), [Depends(parse_platform)]
)
async def init_id(state: T_State):
if platform_manager[state["platform"]].has_target:
state["_prompt"] = "请输入订阅用户的id:\n查询id获取方法请回复:“查询”"
cur_platform = platform_manager[state["platform"]]
if cur_platform.has_target:
state["_prompt"] = (
cur_platform.parse_target_promot or "请输入订阅用户的id:\n查询id获取方法请回复:“查询”"
)
else:
state["id"] = "default"
state["name"] = await platform_manager[state["platform"]].get_target_name(
@@ -126,6 +129,8 @@ def do_add_sub(add_sub: Type[Matcher]):
raise LookupError
if target == "取消":
raise KeyboardInterrupt
platform = platform_manager[state["platform"]]
target = await platform.parse_target(target)
name = await check_sub_target(state["platform"], target)
if not name:
raise ValueError
@@ -142,6 +147,8 @@ def do_add_sub(add_sub: Type[Matcher]):
await add_sub.finish("已中止订阅")
except (ValueError):
await add_sub.reject("id输入错误")
except (Platform.ParseTargetException):
await add_sub.reject("不能从你的输入中提取出id,请检查你输入的内容是否符合预期")
else:
await add_sub.send(
"即将订阅的用户为:{} {} {}\n如有错误请输入“取消”重新订阅".format(
@@ -321,7 +328,7 @@ group_manage_matcher = on_command("群管理", rule=to_me(), permission=SUPERUSE
@group_manage_matcher.handle()
async def send_group_list(bot: Bot, event: GroupMessageEvent, state: T_State):
async def send_group_list_private(bot: Bot, event: GroupMessageEvent, state: T_State):
await group_manage_matcher.finish(Message("该功能只支持私聊使用,请私聊Bot"))
@@ -383,13 +390,13 @@ async def do_dispatch_command(
"message",
Rule(),
permission,
None,
True,
handlers=None,
temp=True,
priority=0,
block=True,
plugin=matcher.plugin,
module=matcher.module,
expire_time=datetime.now() + bot.config.session_expire_timeout,
expire_time=datetime.now(),
default_state=matcher.state,
default_type_updater=matcher.__class__._default_type_updater,
default_permission_updater=matcher.__class__._default_permission_updater,
@@ -402,34 +409,3 @@ async def do_dispatch_command(
do_del_sub(new_matcher)
new_matcher_ins = new_matcher()
asyncio.create_task(new_matcher_ins.run(bot, event, state))
test_matcher = on_command("testtt")
@test_matcher.handle()
async def _handler(bot: Bot, event: Event, matcher: Matcher, state: T_State):
permission = await matcher.update_permission(bot, event)
new_matcher = Matcher.new(
"message",
Rule(),
permission,
None,
True,
priority=0,
block=True,
plugin=matcher.plugin,
module=matcher.module,
expire_time=datetime.now() + bot.config.session_expire_timeout,
default_state=matcher.state,
default_type_updater=matcher.__class__._default_type_updater,
default_permission_updater=matcher.__class__._default_permission_updater,
)
async def h():
logger.warning("yes")
await new_matcher.send("666")
new_matcher.handle()(h)
new_matcher_ins = new_matcher()
await new_matcher_ins.run(bot, event, state)
+11 -1
View File
@@ -1,11 +1,12 @@
import json
import re
from typing import Any, Optional
import httpx
from ..post import Post
from ..types import Category, RawPost, Tag, Target
from .platform import CategoryNotSupport, NewMessage
from .platform import CategoryNotSupport, NewMessage, Platform
class Bilibili(NewMessage):
@@ -26,6 +27,7 @@ class Bilibili(NewMessage):
schedule_kw = {"seconds": 10}
name = "B站"
has_target = True
parse_target_promot = "请输入用户主页的链接"
async def get_target_name(self, target: Target) -> Optional[str]:
async with httpx.AsyncClient() as client:
@@ -37,6 +39,14 @@ class Bilibili(NewMessage):
return None
return res_data["data"]["name"]
async def parse_target(self, target_text: str) -> Target:
if re.match(r"\d+", target_text):
return Target(target_text)
elif match := re.match(r"(?:https://)?space.bilibili.com/(\d+)", target_text):
return Target(match.group(1))
else:
raise Platform.ParseTargetException()
async def get_sub_list(self, target: Target) -> list[RawPost]:
async with httpx.AsyncClient() as client:
params = {"host_uid": target, "offset": 0, "need_top": 0}
@@ -47,6 +47,7 @@ class Platform(metaclass=RegistryABCMeta, base=True):
enable_tag: bool
store: dict[Target, Any]
platform_name: str
parse_target_promot: Optional[str] = None
@abstractmethod
async def get_target_name(self, target: Target) -> Optional[str]:
@@ -73,6 +74,12 @@ class Platform(metaclass=RegistryABCMeta, base=True):
self.reverse_category[val] = key
self.store = dict()
class ParseTargetException(Exception):
pass
async def parse_target(self, target_string: str) -> Target:
return Target(target_string)
@abstractmethod
def get_tags(self, raw_post: RawPost) -> Optional[Collection[Tag]]:
"Return Tag list of given RawPost"
+1
View File
@@ -139,6 +139,7 @@ class Post:
msgs = [reduce(lambda x, y: x.append(y), msgs, Message())]
msgs.extend(self.extra_msg)
self._message = msgs
assert len(self._message) > 0, f"message list empty, {self}"
return self._message
def __str__(self):