From cc44e4588739e68c1edaea5fcd9a2ee17680160d Mon Sep 17 00:00:00 2001 From: felinae98 <731499577@qq.com> Date: Wed, 2 Mar 2022 11:33:14 +0800 Subject: [PATCH] update group admin --- poetry.lock | 51 ++--- pyproject.toml | 7 +- src/plugins/nonebot_bison/config_manager.py | 218 ++++++++++++++++---- src/plugins/nonebot_bison/types.py | 2 +- tests/conftest.py | 25 ++- tests/test_config_manager.py | 14 ++ tests/test_config_manager_admin.py | 45 ++++ tests/utils.py | 1 + 8 files changed, 289 insertions(+), 74 deletions(-) create mode 100644 tests/test_config_manager_admin.py diff --git a/poetry.lock b/poetry.lock index beee8da..69aae51 100644 --- a/poetry.lock +++ b/poetry.lock @@ -735,15 +735,21 @@ version = "0.2.1" description = "nonebot2 test framework" category = "dev" optional = false -python-versions = ">=3.7.3,<4.0.0" +python-versions = "^3.7.3" +develop = false [package.dependencies] -asgiref = ">=3.4.0,<4.0.0" -async-asgi-testclient = ">=1.4.8,<2.0.0" -nonebot2 = ">=2.0.0-beta.1,<3.0.0" -pytest = ">=6.2.5,<7.0.0" -pytest-asyncio = ">=0.16.0,<0.17.0" -typing-extensions = ">=4.0.0,<5.0.0" +asgiref = "^3.4.0" +async-asgi-testclient = "^1.4.8" +nonebot2 = "^2.0.0-beta.1" +pytest = "^7.0.0" +typing-extensions = "^4.0.0" + +[package.source] +type = "git" +url = "https://github.com/nonebot/nonebug.git" +reference = "40fcd4f" +resolved_reference = "40fcd4f3eff8f4b2118e95938fabc3d77ff6819c" [[package]] name = "packaging" @@ -988,7 +994,7 @@ diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pytest" -version = "6.2.5" +version = "7.0.1" description = "pytest: simple powerful testing with Python" category = "dev" optional = false @@ -1002,24 +1008,24 @@ iniconfig = "*" packaging = "*" pluggy = ">=0.12,<2.0" py = ">=1.8.2" -toml = "*" +tomli = ">=1.0.0" [package.extras] -testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] +testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] [[package]] name = "pytest-asyncio" -version = "0.16.0" -description = "Pytest support for asyncio." +version = "0.18.1" +description = "Pytest support for asyncio" category = "dev" optional = false -python-versions = ">= 3.6" +python-versions = ">=3.7" [package.dependencies] -pytest = ">=5.4.0" +pytest = ">=6.1.0" [package.extras] -testing = ["coverage", "hypothesis (>=5.7.1)"] +testing = ["coverage (==6.2)", "hypothesis (>=5.7.1)", "flaky (>=3.5.0)", "mypy (==0.931)"] [[package]] name = "pytest-cov" @@ -1455,7 +1461,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "7dbb53c4a2386da480678e34620664651cf61dd754dd3dae7b6e8e8a56c04b9d" +content-hash = "4a0d093b99ca03d9e9effe113ccd85daf9242c0a95f2336d2af8fcc7e764b6ce" [metadata.files] aiofiles = [ @@ -1920,10 +1926,7 @@ nonebot2 = [ {file = "nonebot2-2.0.0b2-py3-none-any.whl", hash = "sha256:8166490311b607f8fbf5e31934b005e29f6d39ff222a6771ec36c9456ec337ec"}, {file = "nonebot2-2.0.0b2.tar.gz", hash = "sha256:2950f27a62f2a98b2abf3128c19d898a24c2867e70fb5c6af231eadf558b18a8"}, ] -nonebug = [ - {file = "nonebug-0.2.1-py3-none-any.whl", hash = "sha256:f4d59effd50e400ee866df57902e4d749227a76857be26a0607fc2a5f6a05f7c"}, - {file = "nonebug-0.2.1.tar.gz", hash = "sha256:2f363bd5d65081c802b7b19a72b07ada1ad8e61968cf313176f38a5cf97e84e2"}, -] +nonebug = [] packaging = [ {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, @@ -2093,12 +2096,12 @@ pyparsing = [ {file = "pyparsing-3.0.7.tar.gz", hash = "sha256:18ee9022775d270c55187733956460083db60b37d0d0fb357445f3094eed3eea"}, ] pytest = [ - {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, - {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, + {file = "pytest-7.0.1-py3-none-any.whl", hash = "sha256:9ce3ff477af913ecf6321fe337b93a2c0dcf2a0a1439c43f5452112c1e4280db"}, + {file = "pytest-7.0.1.tar.gz", hash = "sha256:e30905a0c131d3d94b89624a1cc5afec3e0ba2fbdb151867d8e0ebd49850f171"}, ] pytest-asyncio = [ - {file = "pytest-asyncio-0.16.0.tar.gz", hash = "sha256:7496c5977ce88c34379df64a66459fe395cd05543f0a2f837016e7144391fcfb"}, - {file = "pytest_asyncio-0.16.0-py3-none-any.whl", hash = "sha256:5f2a21273c47b331ae6aa5b36087047b4899e40f03f18397c0e65fa5cca54e9b"}, + {file = "pytest-asyncio-0.18.1.tar.gz", hash = "sha256:c43fcdfea2335dd82ffe0f2774e40285ddfea78a8e81e56118d47b6a90fbb09e"}, + {file = "pytest_asyncio-0.18.1-py3-none-any.whl", hash = "sha256:c9ec48e8bbf5cc62755e18c4d8bc6907843ec9c5f4ac8f61464093baeba24a7e"}, ] pytest-cov = [ {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"}, diff --git a/pyproject.toml b/pyproject.toml index cb70049..1a2b388 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,11 +40,11 @@ nonebot-plugin-htmlrender = "^0.0.4" [tool.poetry.dev-dependencies] ipdb = "^0.13.4" -pytest = "^6.2.4" -pytest-asyncio = "^0.16" +pytest = "^7.0.1" +pytest-asyncio = "^0.18.1" respx = "^0.19.0" pytest-cov = "^3.0.0" -nonebug = "^0.2.0" +nonebug = {git = "https://github.com/nonebot/nonebug.git", rev = "40fcd4f"} black = "^22.1.0" isort = "^5.10.1" pre-commit = "^2.17.0" @@ -59,6 +59,7 @@ markers = [ "compare: compare fetching result with rsshub", "render: render img by chrome" ] +asyncio_mode = "auto" [tool.black] line-length = 88 diff --git a/src/plugins/nonebot_bison/config_manager.py b/src/plugins/nonebot_bison/config_manager.py index 7836121..ee5ebf9 100644 --- a/src/plugins/nonebot_bison/config_manager.py +++ b/src/plugins/nonebot_bison/config_manager.py @@ -1,13 +1,18 @@ -from typing import Type +import asyncio +from asyncio.tasks import Task +from datetime import datetime +from typing import Optional, Type from nonebot import on_command -from nonebot.adapters import Event as AbstractEvent -from nonebot.adapters.onebot.v11 import Bot, Event +from nonebot.adapters.onebot.v11 import Bot, Event, MessageEvent +from nonebot.adapters.onebot.v11.event import GroupMessageEvent, PrivateMessageEvent from nonebot.adapters.onebot.v11.message import Message from nonebot.adapters.onebot.v11.permission import GROUP_ADMIN, GROUP_OWNER +from nonebot.internal.params import ArgStr from nonebot.internal.rule import Rule +from nonebot.log import logger from nonebot.matcher import Matcher -from nonebot.params import Depends, EventToMe +from nonebot.params import Depends, EventMessage, EventPlainText, EventToMe from nonebot.permission import SUPERUSER from nonebot.rule import to_me from nonebot.typing import T_State @@ -15,7 +20,7 @@ from nonebot.typing import T_State from .config import Config from .platform import check_sub_target, platform_manager from .plugin_config import plugin_config -from .types import Category, Target +from .types import Category, Target, User from .utils import parse_text @@ -44,7 +49,28 @@ common_platform = [ ] +def ensure_user_info(matcher: Type[Matcher]): + async def _check_user_info(state: T_State): + if not state.get("target_user_info"): + await matcher.finish( + "No target_user_info set, this shouldn't happen, please issue" + ) + + return _check_user_info + + +async def set_target_user_info(event: MessageEvent, state: T_State): + if isinstance(event, GroupMessageEvent): + user = User(event.group_id, "group") + state["target_user_info"] = user + elif isinstance(event, PrivateMessageEvent): + user = User(event.user_id, "private") + state["target_user_info"] = user + + def do_add_sub(add_sub: Type[Matcher]): + add_sub.handle()(ensure_user_info(add_sub)) + @add_sub.handle() async def init_promote(state: T_State): state["_prompt"] = ( @@ -60,7 +86,7 @@ def do_add_sub(add_sub: Type[Matcher]): + "要查看全部平台请输入:“全部”" ) - async def parse_platform(event: AbstractEvent, state: T_State) -> None: + async def parse_platform(event: MessageEvent, state: T_State) -> None: if not isinstance(state["platform"], Message): return platform = str(event.get_message()).strip() @@ -91,7 +117,7 @@ def do_add_sub(add_sub: Type[Matcher]): Target("") ) - async def parse_id(event: AbstractEvent, state: T_State): + async def parse_id(event: MessageEvent, state: T_State): if not isinstance(state["id"], Message): return target = str(event.get_message()).strip() @@ -113,7 +139,7 @@ def do_add_sub(add_sub: Type[Matcher]): " ".join(list(platform_manager[state["platform"]].categories.values())) ) - async def parser_cats(event: AbstractEvent, state: T_State): + async def parser_cats(event: MessageEvent, state: T_State): if not isinstance(state["cats"], Message): return res = [] @@ -130,7 +156,7 @@ def do_add_sub(add_sub: Type[Matcher]): return state["_prompt"] = '请输入要订阅的tag,订阅所有tag输入"全部标签"' - async def parser_tags(event: AbstractEvent, state: T_State): + async def parser_tags(event: MessageEvent, state: T_State): if not isinstance(state["tags"], Message): return if str(event.get_message()).strip() == "全部标签": @@ -141,9 +167,13 @@ def do_add_sub(add_sub: Type[Matcher]): @add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)]) async def add_sub_process(event: Event, state: T_State): config = Config() + user = state.get("target_user_info") + assert isinstance(user, User) config.add_subscribe( - state.get("_user_id") or event.group_id, - user_type="group", + # state.get("_user_id") or event.group_id, + # user_type="group", + user=user.user, + user_type=user.user_type, target=state["id"], target_name=state["name"], target_type=state["platform"], @@ -154,11 +184,17 @@ def do_add_sub(add_sub: Type[Matcher]): def do_query_sub(query_sub: Type[Matcher]): + query_sub.handle()(ensure_user_info(query_sub)) + @query_sub.handle() - async def _(event: Event, state: T_State): + async def _(state: T_State): config: Config = Config() + user_info = state["target_user_info"] + assert isinstance(user_info, User) sub_list = config.list_subscribe( - state.get("_user_id") or event.group_id, "group" + # state.get("_user_id") or event.group_id, "group" + user_info.user, + user_info.user_type, ) res = "订阅的帐号为:\n" for sub in sub_list: @@ -179,11 +215,17 @@ def do_query_sub(query_sub: Type[Matcher]): def do_del_sub(del_sub: Type[Matcher]): + del_sub.handle()(ensure_user_info(del_sub)) + @del_sub.handle() async def send_list(bot: Bot, event: Event, state: T_State): config: Config = Config() + user_info = state["target_user_info"] + assert isinstance(user_info, User) sub_list = config.list_subscribe( - state.get("_user_id") or event.group_id, "group" + # state.get("_user_id") or event.group_id, "group" + user_info.user, + user_info.user_type, ) res = "订阅的帐号为:\n" state["sub_table"] = {} @@ -213,9 +255,13 @@ def do_del_sub(del_sub: Type[Matcher]): try: index = int(str(event.get_message()).strip()) config = Config() + user_info = state["target_user_info"] + assert isinstance(user_info, User) config.del_subscribe( - state.get("_user_id") or event.group_id, - "group", + # state.get("_user_id") or event.group_id, + # "group", + user_info.user, + user_info.user_type, **state["sub_table"][index], ) except Exception as e: @@ -224,41 +270,19 @@ def do_del_sub(del_sub: Type[Matcher]): await del_sub.finish("删除成功") -async def parse_group_number(event: AbstractEvent, state: T_State): - if not isinstance(state["_user_id"], Message): - return - state["_user_id"] = int(str(event.get_message())) - - add_sub_matcher = on_command( "添加订阅", rule=configurable_to_me, permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, priority=5, ) +add_sub_matcher.handle()(set_target_user_info) do_add_sub(add_sub_matcher) -manage_add_sub_matcher = on_command("管理-添加订阅", permission=SUPERUSER, priority=5) - - -@manage_add_sub_matcher.got("_user_id", "群号", [Depends(parse_group_number)]) -async def add_sub_handle(): - pass - - -do_add_sub(manage_add_sub_matcher) query_sub_matcher = on_command("查询订阅", rule=configurable_to_me, priority=5) +query_sub_matcher.handle()(set_target_user_info) do_query_sub(query_sub_matcher) -manage_query_sub_matcher = on_command("管理-查询订阅", permission=SUPERUSER, priority=5) - - -@manage_query_sub_matcher.got("_user_id", "群号", [Depends(parse_group_number)]) -async def query_sub_handle(): - pass - - -do_query_sub(manage_query_sub_matcher) del_sub_matcher = on_command( @@ -267,13 +291,117 @@ del_sub_matcher = on_command( permission=GROUP_ADMIN | GROUP_OWNER | SUPERUSER, priority=5, ) +del_sub_matcher.handle()(set_target_user_info) do_del_sub(del_sub_matcher) -manage_del_sub_matcher = on_command("管理-删除订阅", permission=SUPERUSER, priority=5) + +group_manage_matcher = on_command("群管理") -@manage_del_sub_matcher.got("_user_id", "群号", [Depends(parse_group_number)]) -async def del_sub_handle(): - pass +@group_manage_matcher.handle() +async def send_group_list(bot: Bot, state: T_State): + groups = await bot.call_api("get_group_list") + res_text = "请选择需要管理的群:\n" + group_number_idx = {} + for idx, group in enumerate(groups, 1): + group_number_idx[idx] = group["group_id"] + res_text += f'{idx}. {group["group_id"]} - {group["group_name"]}\n' + res_text += "请输入左侧序号" + # await group_manage_matcher.send(res_text) + state["_prompt"] = res_text + state["group_number_idx"] = group_number_idx -do_del_sub(manage_del_sub_matcher) +async def _parse_group_idx(state: T_State, event_msg: str = EventPlainText()): + if not isinstance(state["group_idx"], Message): + return + group_number_idx: Optional[dict[int, int]] = state.get("group_number_idx") + assert group_number_idx + try: + idx = int(event_msg) + assert idx in group_number_idx.keys() + state["group_idx"] = idx + except: + await group_manage_matcher.reject("请输入正确序号") + + +@group_manage_matcher.got( + "group_idx", _gen_prompt_template("{_prompt}"), [Depends(_parse_group_idx)] +) +async def do_choose_group_number(state: T_State): + group_number_idx: dict[int, int] = state["group_number_idx"] + idx: int = state["group_idx"] + group_id = group_number_idx[idx] + state["target_user_info"] = User(user=group_id, user_type="group") + + +async def _check_command(event_msg: str = EventPlainText()): + if event_msg not in {"添加订阅", "查询订阅", "删除订阅"}: + await group_manage_matcher.reject("请输入正确的命令") + return + + +@group_manage_matcher.got( + "command", "请输入需要使用的命令:添加订阅,查询订阅,删除订阅", [Depends(_check_command)] +) +async def do_dispatch_command( + bot: Bot, + event: MessageEvent, + state: T_State, + matcher: Matcher, + command: str = ArgStr(), +): + permission = await matcher.update_permission(bot, event) + new_matcher = Matcher.new( + "message", + Rule(), + permission, + None, + True, + priority=0, + block=True, + plugin=matcher.plugin, + module=matcher.module, + expire_time=datetime.now() + bot.config.session_expire_timeout, + default_state=matcher.state, + default_type_updater=matcher.__class__._default_type_updater, + default_permission_updater=matcher.__class__._default_permission_updater, + ) + if command == "查询订阅": + do_query_sub(new_matcher) + elif command == "添加订阅": + do_add_sub(new_matcher) + else: + do_del_sub(new_matcher) + new_matcher_ins = new_matcher() + asyncio.create_task(new_matcher_ins.run(bot, event, state)) + + +test_matcher = on_command("testtt") + + +@test_matcher.handle() +async def _handler(bot: Bot, event: Event, matcher: Matcher, state: T_State): + permission = await matcher.update_permission(bot, event) + new_matcher = Matcher.new( + "message", + Rule(), + permission, + None, + True, + priority=0, + block=True, + plugin=matcher.plugin, + module=matcher.module, + expire_time=datetime.now() + bot.config.session_expire_timeout, + default_state=matcher.state, + default_type_updater=matcher.__class__._default_type_updater, + default_permission_updater=matcher.__class__._default_permission_updater, + ) + + async def h(): + logger.warning("yes") + await new_matcher.send("666") + + new_matcher.handle()(h) + new_matcher_ins = new_matcher() + await new_matcher_ins.run(bot, event, state) diff --git a/src/plugins/nonebot_bison/types.py b/src/plugins/nonebot_bison/types.py index f447f38..734d877 100644 --- a/src/plugins/nonebot_bison/types.py +++ b/src/plugins/nonebot_bison/types.py @@ -9,7 +9,7 @@ Tag = str @dataclass(eq=True, frozen=True) class User: - user: str + user: int user_type: Literal["group", "private"] diff --git a/tests/conftest.py b/tests/conftest.py index 0fb5aca..18a0691 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import asyncio import typing from pathlib import Path @@ -13,6 +14,9 @@ async def app(nonebug_init: None, tmp_path: Path, monkeypatch: pytest.MonkeyPatc config = nonebot.get_driver().config config.bison_config_path = str(tmp_path) config.command_start = {""} + config.superusers = {"10001"} + config.log_level = "TRACE" + config.bison_filter_log = False return App(monkeypatch) @@ -20,5 +24,24 @@ async def app(nonebug_init: None, tmp_path: Path, monkeypatch: pytest.MonkeyPatc def dummy_user_subinfo(app: App): from nonebot_bison.types import User, UserSubInfo - user = User("123", "group") + user = User(123, "group") return UserSubInfo(user=user, category_getter=lambda _: [], tag_getter=lambda _: []) + + +@pytest.fixture +def task_watchdog(request): + def cancel_test_on_exception(task: asyncio.Task): + def maybe_cancel_clbk(t: asyncio.Task): + exception = t.exception() + if exception is None: + return + + for task in asyncio.all_tasks(): + coro = task.get_coro() + if coro.__qualname__ == request.function.__qualname__: + task.cancel() + return + + task.add_done_callback(maybe_cancel_clbk) + + return cancel_test_on_exception diff --git a/tests/test_config_manager.py b/tests/test_config_manager.py index 160563d..771170a 100644 --- a/tests/test_config_manager.py +++ b/tests/test_config_manager.py @@ -453,3 +453,17 @@ async def test_del_sub(app: App): ctx.should_finished() subs = config.list_subscribe(10000, "group") assert len(subs) == 0 + + +async def test_test(app: App): + from nonebot.adapters.onebot.v11.bot import Bot + from nonebot.adapters.onebot.v11.message import Message + from nonebot_bison.config_manager import test_matcher + + async with app.test_matcher(test_matcher) as ctx: + bot = ctx.create_bot(base=Bot) + event = fake_group_message_event(message=Message("testtt")) + ctx.receive_event(bot, event) + ctx.should_pass_permission() + ctx.should_pass_rule() + ctx.should_call_send(event, "666", True) diff --git a/tests/test_config_manager_admin.py b/tests/test_config_manager_admin.py new file mode 100644 index 0000000..fc6581d --- /dev/null +++ b/tests/test_config_manager_admin.py @@ -0,0 +1,45 @@ +from nonebug import App + +from .utils import fake_admin_user, fake_private_message_event, fake_superuser + + +async def test_query(app: App): + from nonebot.adapters.onebot.v11.bot import Bot + from nonebot.adapters.onebot.v11.message import Message + from nonebot_bison.config_manager import group_manage_matcher + + async with app.test_matcher(group_manage_matcher) as ctx: + bot = ctx.create_bot(base=Bot) + event = fake_private_message_event( + message=Message("群管理"), sender=fake_superuser + ) + ctx.receive_event(bot, event) + ctx.should_pass_rule() + ctx.should_pass_permission() + ctx.should_call_api( + "get_group_list", {}, [{"group_id": 101, "group_name": "test group"}] + ) + ctx.should_call_send( + event, Message("请选择需要管理的群:\n1. 101 - test group\n请输入左侧序号"), True + ) + event_1_err = fake_private_message_event( + message=Message("0"), sender=fake_superuser + ) + ctx.receive_event(bot, event_1_err) + ctx.should_rejected() + ctx.should_call_send(event_1_err, "请输入正确序号", True) + event_1_ok = fake_private_message_event( + message=Message("1"), sender=fake_superuser + ) + ctx.receive_event(bot, event_1_ok) + ctx.should_call_send(event_1_ok, "请输入需要使用的命令:添加订阅,查询订阅,删除订阅", True) + event_2_err = fake_private_message_event( + message=Message("222"), sender=fake_superuser + ) + ctx.receive_event(bot, event_2_err) + ctx.should_rejected() + ctx.should_call_send(event_2_err, "请输入正确的命令", True) + event_2_ok = fake_private_message_event( + message=Message("查询订阅"), sender=fake_superuser + ) + ctx.receive_event(bot, event_2_ok) diff --git a/tests/utils.py b/tests/utils.py index 1212a20..a31b67e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -68,3 +68,4 @@ def fake_private_message_event(**field) -> "PrivateMessageEvent": from nonebot.adapters.onebot.v11.event import Sender fake_admin_user = Sender(nickname="test", role="admin") +fake_superuser = Sender(user_id=10001, nickname="superuser")