diff --git a/nonebot_bison/admin_page/api.py b/nonebot_bison/admin_page/api.py index 8b3f554..9f2b00d 100644 --- a/nonebot_bison/admin_page/api.py +++ b/nonebot_bison/admin_page/api.py @@ -4,6 +4,7 @@ from fastapi.exceptions import HTTPException from fastapi.param_functions import Depends from fastapi.routing import APIRouter from fastapi.security.oauth2 import OAuth2PasswordBearer +from nonebot_plugin_saa import TargetQQGroup from ..apis import check_sub_target from ..config import ( @@ -15,7 +16,7 @@ from ..config import ( from ..config.db_config import SubscribeDupException from ..platform import platform_manager from ..types import Target as T_Target -from ..types import User, WeightConfig +from ..types import WeightConfig from ..utils.get_bot import get_bot, get_groups from .jwt import load_jwt, pack_jwt from .token_manager import token_manager @@ -75,7 +76,7 @@ async def get_admin_groups(qq: int): res = [] for group in await get_groups(): group_id = group["group_id"] - bot = get_bot(User(group_id, "group")) + bot = get_bot(TargetQQGroup(group_id=group_id)) if not bot: continue users = await bot.get_group_member_list(group_id=group_id) @@ -131,7 +132,7 @@ async def get_subs_info(jwt_obj: dict = Depends(get_jwt_obj)) -> SubscribeResp: res: SubscribeResp = {} for group in groups: group_id = group["id"] - raw_subs = await config.list_subscribe(group_id, "group") + raw_subs = await config.list_subscribe(TargetQQGroup(group_id=group_id)) subs = list( map( lambda sub: SubscribeConfig( @@ -157,8 +158,7 @@ async def get_target_name(platformName: str, target: str): async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp: try: await config.add_subscribe( - int(groupNumber), - "group", + TargetQQGroup(group_id=groupNumber), T_Target(req.target), req.targetName, req.platformName, @@ -173,7 +173,9 @@ async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp: @router.delete("/subs", dependencies=[Depends(check_group_permission)]) async def del_group_sub(groupNumber: int, platformName: str, target: str): try: - await config.del_subscribe(int(groupNumber), "group", target, platformName) + await config.del_subscribe( + TargetQQGroup(group_id=groupNumber), target, platformName + ) except (NoSuchUserException, NoSuchSubscribeException): raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe") return StatusResp(ok=True, msg="") diff --git a/nonebot_bison/config/db_config.py b/nonebot_bison/config/db_config.py index 3e9c4df..38ef9af 100644 --- a/nonebot_bison/config/db_config.py +++ b/nonebot_bison/config/db_config.py @@ -4,15 +4,14 @@ from datetime import datetime, time from typing import Awaitable, Callable, Optional, Sequence from nonebot_plugin_datastore import create_session +from nonebot_plugin_saa import PlatformTarget from sqlalchemy import delete, func, select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import selectinload from ..types import Category, PlatformWeightConfigResp, Tag from ..types import Target as T_Target -from ..types import TimeWeightConfig -from ..types import User as T_User -from ..types import UserSubInfo, WeightConfig +from ..types import TimeWeightConfig, UserSubInfo, WeightConfig from .db_model import ScheduleTimeWeight, Subscribe, Target, User from .utils import NoSuchTargetException @@ -40,8 +39,7 @@ class DBConfig: async def add_subscribe( self, - user: int, - user_type: str, + user: PlatformTarget, target: T_Target, target_name: str, platform_name: str, @@ -49,12 +47,10 @@ class DBConfig: tags: list[Tag], ): async with create_session() as session: - db_user_stmt = ( - select(User).where(User.uid == user).where(User.type == user_type) - ) + db_user_stmt = select(User).where(User.user_target == user.dict()) db_user: Optional[User] = await session.scalar(db_user_stmt) if not db_user: - db_user = User(uid=user, type=user_type) + db_user = User(user_target=user.dict()) session.add(db_user) db_target_stmt = ( select(Target) @@ -85,11 +81,11 @@ class DBConfig: raise SubscribeDupException() raise e - async def list_subscribe(self, user: int, user_type: str) -> Sequence[Subscribe]: + async def list_subscribe(self, user: PlatformTarget) -> Sequence[Subscribe]: async with create_session() as session: query_stmt = ( select(Subscribe) - .where(User.type == user_type, User.uid == user) + .where(User.user_target == user.dict()) .join(User) .options(selectinload(Subscribe.target)) ) @@ -109,11 +105,11 @@ class DBConfig: return subs async def del_subscribe( - self, user: int, user_type: str, target: str, platform_name: str + self, user: PlatformTarget, target: str, platform_name: str ): async with create_session() as session: user_obj = await session.scalar( - select(User).where(User.uid == user, User.type == user_type) + select(User).where(User.user_target == user.dict()) ) target_obj = await session.scalar( select(Target).where( @@ -142,8 +138,7 @@ class DBConfig: async def update_subscribe( self, - user: int, - user_type: str, + user: PlatformTarget, target: str, target_name: str, platform_name: str, @@ -154,8 +149,7 @@ class DBConfig: subscribe_obj: Subscribe = await sess.scalar( select(Subscribe) .where( - User.uid == user, - User.type == user_type, + User.user_target == user.dict(), Target.target == target, Target.platform_name == platform_name, ) @@ -272,7 +266,7 @@ class DBConfig: return list( map( lambda subscribe: UserSubInfo( - T_User(subscribe.user.uid, subscribe.user.type), + PlatformTarget.deserialize(subscribe.user.user_target), subscribe.categories, subscribe.tags, ), diff --git a/nonebot_bison/config/db_migration.py b/nonebot_bison/config/db_migration.py index 08af455..e20b1dc 100644 --- a/nonebot_bison/config/db_migration.py +++ b/nonebot_bison/config/db_migration.py @@ -1,5 +1,6 @@ from nonebot.log import logger from nonebot_plugin_datastore.db import get_engine +from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate from sqlalchemy.ext.asyncio.session import AsyncSession from .config_legacy import Config, ConfigContent, drop @@ -21,7 +22,11 @@ async def data_migrate(): subscribe_to_create = [] platform_target_map: dict[str, tuple[Target, str, int]] = {} for user in all_subs: - db_user = User(uid=user["user"], type=user["user_type"]) + if user["user_type"] == "group": + user_target = TargetQQGroup(group_id=user["user"]) + else: + user_target = TargetQQPrivate(user_id=user["user"]) + db_user = User(user_target=user_target.dict()) user_to_create.append(db_user) user_sub_set = set() for sub in user["subs"]: diff --git a/nonebot_bison/config/db_model.py b/nonebot_bison/config/db_model.py index bf45ff7..e1feade 100644 --- a/nonebot_bison/config/db_model.py +++ b/nonebot_bison/config/db_model.py @@ -14,15 +14,15 @@ get_plugin_data().set_migration_dir(Path(__file__).parent / "migrations") class User(Model): - __table_args__ = (UniqueConstraint("type", "uid", name="unique-user-constraint"),) - id: Mapped[int] = mapped_column(primary_key=True) - type: Mapped[str] = mapped_column(String(20)) - uid: Mapped[int] - user_target: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) + user_target: Mapped[dict] = mapped_column(JSON) subscribes: Mapped[list["Subscribe"]] = relationship(back_populates="user") + @property + def saa_target(self) -> PlatformTarget: + return PlatformTarget.deserialize(self.user_target) + class Target(Model): __table_args__ = ( diff --git a/nonebot_bison/config/migrations/632b8086bc2b_add_user_target.py b/nonebot_bison/config/migrations/632b8086bc2b_add_user_target.py index 2e661ac..9a867ac 100644 --- a/nonebot_bison/config/migrations/632b8086bc2b_add_user_target.py +++ b/nonebot_bison/config/migrations/632b8086bc2b_add_user_target.py @@ -18,6 +18,7 @@ depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: + batch_op.drop_constraint("unique-user-constraint", type_="unique") batch_op.add_column(sa.Column("user_target", sa.JSON(), nullable=True)) # ### end Alembic commands ### @@ -27,5 +28,6 @@ def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: batch_op.drop_column("user_target") + batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"]) # ### end Alembic commands ### diff --git a/nonebot_bison/config/migrations/67c38b3f39c2_make_user_target_not_nullable.py b/nonebot_bison/config/migrations/67c38b3f39c2_make_user_target_not_nullable.py new file mode 100644 index 0000000..a08e20a --- /dev/null +++ b/nonebot_bison/config/migrations/67c38b3f39c2_make_user_target_not_nullable.py @@ -0,0 +1,34 @@ +"""make user_target not nullable + +Revision ID: 67c38b3f39c2 +Revises: a5466912fad0 +Create Date: 2023-03-20 11:08:42.883556 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import sqlite + +# revision identifiers, used by Alembic. +revision = "67c38b3f39c2" +down_revision = "a5466912fad0" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: + batch_op.alter_column( + "user_target", existing_type=sqlite.JSON(), nullable=False + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: + batch_op.alter_column("user_target", existing_type=sqlite.JSON(), nullable=True) + + # ### end Alembic commands ### diff --git a/nonebot_bison/config/migrations/8d3863e9d74b_remove_uid_and_type.py b/nonebot_bison/config/migrations/8d3863e9d74b_remove_uid_and_type.py new file mode 100644 index 0000000..18890dd --- /dev/null +++ b/nonebot_bison/config/migrations/8d3863e9d74b_remove_uid_and_type.py @@ -0,0 +1,33 @@ +"""remove uid and type + +Revision ID: 8d3863e9d74b +Revises: 67c38b3f39c2 +Create Date: 2023-03-20 15:38:20.220599 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "8d3863e9d74b" +down_revision = "67c38b3f39c2" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: + batch_op.drop_column("uid") + batch_op.drop_column("type") + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: + batch_op.add_column(sa.Column("type", sa.VARCHAR(length=20), nullable=False)) + batch_op.add_column(sa.Column("uid", sa.INTEGER(), nullable=False)) + + # ### end Alembic commands ### diff --git a/nonebot_bison/config_manager.py b/nonebot_bison/config_manager.py index 5a1a546..87ff843 100644 --- a/nonebot_bison/config_manager.py +++ b/nonebot_bison/config_manager.py @@ -15,13 +15,14 @@ from nonebot.params import Depends, EventPlainText, EventToMe from nonebot.permission import SUPERUSER from nonebot.rule import to_me from nonebot.typing import T_State +from nonebot_plugin_saa import PlatformTarget, TargetQQGroup, extract_target from .apis import check_sub_target from .config import config from .config.db_config import SubscribeDupException from .platform import Platform, platform_manager from .plugin_config import plugin_config -from .types import Category, Target, User +from .types import Category, Target from .utils import parse_text @@ -61,12 +62,8 @@ def ensure_user_info(matcher: Type[Matcher]): async def set_target_user_info(event: MessageEvent, state: T_State): - if isinstance(event, GroupMessageEvent): - user = User(event.group_id, "group") - state["target_user_info"] = user - elif isinstance(event, PrivateMessageEvent): - user = User(event.user_id, "private") - state["target_user_info"] = user + user = extract_target(event) + state["target_user_info"] = user def do_add_sub(add_sub: Type[Matcher]): @@ -201,14 +198,11 @@ def do_add_sub(add_sub: Type[Matcher]): @add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)]) async def add_sub_process(event: Event, state: T_State): - user = cast(User, state.get("target_user_info")) - assert isinstance(user, User) + user = cast(PlatformTarget, state.get("target_user_info")) + assert isinstance(user, PlatformTarget) try: await config.add_subscribe( - # state.get("_user_id") or event.group_id, - # user_type="group", - user=user.user, - user_type=user.user_type, + user=user, target=state["id"], target_name=state["name"], platform_name=state["platform"], @@ -228,12 +222,8 @@ def do_query_sub(query_sub: Type[Matcher]): @query_sub.handle() async def _(state: T_State): user_info = state["target_user_info"] - assert isinstance(user_info, User) - sub_list = await config.list_subscribe( - # state.get("_user_id") or event.group_id, "group" - user_info.user, - user_info.user_type, - ) + assert isinstance(user_info, PlatformTarget) + sub_list = await config.list_subscribe(user_info) res = "订阅的帐号为:\n" for sub in sub_list: res += "{} {} {}".format( @@ -261,13 +251,9 @@ def do_del_sub(del_sub: Type[Matcher]): @del_sub.handle() async def send_list(bot: Bot, event: Event, state: T_State): user_info = state["target_user_info"] - assert isinstance(user_info, User) + assert isinstance(user_info, PlatformTarget) try: - sub_list = await config.list_subscribe( - # state.get("_user_id") or event.group_id, "group" - user_info.user, - user_info.user_type, - ) + sub_list = await config.list_subscribe(user_info) assert sub_list except AssertionError: await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅") @@ -309,14 +295,8 @@ def do_del_sub(del_sub: Type[Matcher]): try: index = int(user_msg) user_info = state["target_user_info"] - assert isinstance(user_info, User) - await config.del_subscribe( - # state.get("_user_id") or event.group_id, - # "group", - user_info.user, - user_info.user_type, - **state["sub_table"][index], - ) + assert isinstance(user_info, PlatformTarget) + await config.del_subscribe(user_info, **state["sub_table"][index]) except Exception as e: await del_sub.reject("删除错误") else: @@ -398,7 +378,7 @@ async def do_choose_group_number(state: T_State): group_number_idx: dict[int, int] = state["group_number_idx"] idx: int = state["group_idx"] group_id = group_number_idx[idx] - state["target_user_info"] = User(user=group_id, user_type="group") + state["target_user_info"] = TargetQQGroup(group_id=group_id) async def _check_command(event_msg: str = EventPlainText()): diff --git a/nonebot_bison/platform/platform.py b/nonebot_bison/platform/platform.py index c05daaa..283bf44 100644 --- a/nonebot_bison/platform/platform.py +++ b/nonebot_bison/platform/platform.py @@ -10,10 +10,11 @@ from typing import Any, Collection, Optional, Type import httpx from httpx import AsyncClient from nonebot.log import logger +from nonebot_plugin_saa import PlatformTarget from ..plugin_config import plugin_config from ..post import Post -from ..types import Category, RawPost, Tag, Target, User, UserSubInfo +from ..types import Category, RawPost, Tag, Target, UserSubInfo from ..utils import ProcessContext, SchedulerConfig @@ -84,12 +85,12 @@ class Platform(metaclass=PlatformABCMeta, base=True): @abstractmethod async def fetch_new_post( self, target: Target, users: list[UserSubInfo] - ) -> list[tuple[User, list[Post]]]: + ) -> list[tuple[PlatformTarget, list[Post]]]: ... async def do_fetch_new_post( self, target: Target, users: list[UserSubInfo] - ) -> list[tuple[User, list[Post]]]: + ) -> list[tuple[PlatformTarget, list[Post]]]: try: return await self.fetch_new_post(target, users) except httpx.RequestError as err: @@ -197,8 +198,8 @@ class Platform(metaclass=PlatformABCMeta, base=True): async def dispatch_user_post( self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo] - ) -> list[tuple[User, list[Post]]]: - res: list[tuple[User, list[Post]]] = [] + ) -> list[tuple[PlatformTarget, list[Post]]]: + res: list[tuple[PlatformTarget, list[Post]]] = [] for user, cats, required_tags in users: user_raw_post = await self.filter_user_custom( new_posts, cats, required_tags @@ -314,7 +315,7 @@ class NewMessage(MessageProcess, abstract=True): async def fetch_new_post( self, target: Target, users: list[UserSubInfo] - ) -> list[tuple[User, list[Post]]]: + ) -> list[tuple[PlatformTarget, list[Post]]]: post_list = await self.get_sub_list(target) new_posts = await self.filter_common_with_diff(target, post_list) if not new_posts: @@ -353,7 +354,7 @@ class StatusChange(Platform, abstract=True): async def fetch_new_post( self, target: Target, users: list[UserSubInfo] - ) -> list[tuple[User, list[Post]]]: + ) -> list[tuple[PlatformTarget, list[Post]]]: try: new_status = await self.get_status(target) except self.FetchError as err: @@ -381,7 +382,7 @@ class SimplePost(MessageProcess, abstract=True): async def fetch_new_post( self, target: Target, users: list[UserSubInfo] - ) -> list[tuple[User, list[Post]]]: + ) -> list[tuple[PlatformTarget, list[Post]]]: new_posts = await self.get_sub_list(target) if not new_posts: return [] diff --git a/nonebot_bison/types.py b/nonebot_bison/types.py index 838ce58..b50bdf0 100644 --- a/nonebot_bison/types.py +++ b/nonebot_bison/types.py @@ -3,6 +3,7 @@ from datetime import time from typing import Any, Literal, NamedTuple, NewType from httpx import URL +from nonebot_plugin_saa import PlatformTarget as SendTarget from pydantic import BaseModel RawPost = Any @@ -25,7 +26,7 @@ class PlatformTarget: class UserSubInfo(NamedTuple): - user: User + user: SendTarget categories: list[Category] tags: list[Tag] diff --git a/nonebot_bison/utils/get_bot.py b/nonebot_bison/utils/get_bot.py index e48e7d3..926dcd3 100644 --- a/nonebot_bison/utils/get_bot.py +++ b/nonebot_bison/utils/get_bot.py @@ -1,53 +1,57 @@ """ 提供获取 Bot 的方法 """ import random +from collections import defaultdict from typing import Any, Optional import nonebot from nonebot import get_driver, on_notice +from nonebot.adapters import Bot +from nonebot.adapters.onebot.v11 import Bot as Ob11Bot from nonebot.adapters.onebot.v11 import ( - Bot, FriendAddNoticeEvent, GroupDecreaseNoticeEvent, GroupIncreaseNoticeEvent, ) - -from ..types import User +from nonebot_plugin_saa import PlatformTarget, TargetQQGroup, TargetQQPrivate GROUP: dict[int, list[Bot]] = {} USER: dict[int, list[Bot]] = {} +BOT_CACHE: dict[PlatformTarget, list[Bot]] = defaultdict(list) def get_bots() -> list[Bot]: """获取所有 OneBot 11 Bot""" + # TODO: support ob12 bots = [] for bot in nonebot.get_bots().values(): - if isinstance(bot, Bot): + if isinstance(bot, Ob11Bot): bots.append(bot) return bots +async def _refresh_ob11(bot: Ob11Bot): + # 获取群列表 + groups = await bot.get_group_list() + for group in groups: + group_id = group["group_id"] + target = TargetQQGroup(group_id=group_id) + BOT_CACHE[target].append(bot) + + # 获取好友列表 + users = await bot.get_friend_list() + for user in users: + user_id = user["user_id"] + target = TargetQQPrivate(user_id=user_id) + BOT_CACHE[target].append(bot) + + async def refresh_bots(): """刷新缓存的 Bot 数据""" - GROUP.clear() - USER.clear() + BOT_CACHE.clear() for bot in get_bots(): - # 获取群列表 - groups = await bot.get_group_list() - for group in groups: - group_id = group["group_id"] - if group_id not in GROUP: - GROUP[group_id] = [bot] - else: - GROUP[group_id].append(bot) - - # 获取好友列表 - users = await bot.get_friend_list() - for user in users: - user_id = user["user_id"] - if user_id not in USER: - USER[user_id] = [bot] - else: - USER[user_id].append(bot) + match bot: + case Ob11Bot(): + await _refresh_ob11(bot) driver = get_driver() @@ -75,15 +79,9 @@ async def _(bot: Bot, event: GroupDecreaseNoticeEvent | GroupIncreaseNoticeEvent await refresh_bots() -def get_bot(user: User) -> Optional[Bot]: +def get_bot(user: PlatformTarget) -> Optional[Bot]: """获取 Bot""" - bots = [] - if user.user_type == "group": - bots = GROUP.get(user.user, []) - - if user.user_type == "private": - bots = USER.get(user.user, []) - + bots = BOT_CACHE.get(user) if not bots: return @@ -92,6 +90,7 @@ def get_bot(user: User) -> Optional[Bot]: async def get_groups() -> list[dict[str, Any]]: """获取所有群号""" + # TODO all_groups: dict[int, dict[str, Any]] = {} for bot in get_bots(): groups = await bot.get_group_list() diff --git a/poetry.lock b/poetry.lock index d53d873..c820acd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1470,16 +1470,20 @@ name = "nonebot-plugin-send-anything-anywhere" version = "0.2.4" description = "An adaptor for nonebot2 adaptors" optional = false -python-versions = ">=3.8,<4.0" -files = [ - {file = "nonebot_plugin_send_anything_anywhere-0.2.4-py3-none-any.whl", hash = "sha256:97c1c1565479c1750c21ce471545ea293a1f26d606cbe5ae071dab0047200408"}, - {file = "nonebot_plugin_send_anything_anywhere-0.2.4.tar.gz", hash = "sha256:71217c6bd7f84d6f3d266914562c60dadf9b28e66801c3996d6d7c36bafa7fca"}, -] +python-versions = "^3.8" +files = [] +develop = false [package.dependencies] -nonebot2 = ">=2.0.0rc1,<3.0.0" -pydantic = ">=1.10.5,<2.0.0" -strenum = ">=0.4.8,<0.5.0" +nonebot2 = "^2.0.0rc1" +pydantic = "^1.10.5" +strenum = "^0.4.8" + +[package.source] +type = "git" +url = "https://github.com/felinae98/nonebot-plugin-send-anything-anywhere.git" +reference = "main" +resolved_reference = "7f8a57afc72b5b6a7f909935f1a87411bf597173" [[package]] name = "nonebot2" @@ -2869,4 +2873,4 @@ yaml = [] [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0.0" -content-hash = "a8af95b0b5285f14d48ba11d7237cf636ca2102e7374d07d6b808eb5fdba8a76" +content-hash = "efba4feca911691e91af2b93cb810268f6e35a6e985811587e6b00999c2bd263" diff --git a/pyproject.toml b/pyproject.toml index a6a7d4f..69661f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ nonebot-adapter-onebot = "^2.0.0-beta.1" nonebot-plugin-htmlrender = ">=0.2.0" nonebot-plugin-datastore = "^0.6.2" nonebot-plugin-apscheduler = "^0.2.0" -nonebot-plugin-send-anything-anywhere = "^0.2.1" +nonebot-plugin-send-anything-anywhere = {git = "https://github.com/felinae98/nonebot-plugin-send-anything-anywhere.git", rev = "main"} [tool.poetry.group.dev.dependencies] ipdb = "^0.13.4" diff --git a/tests/config/test_config_operation.py b/tests/config/test_config_operation.py index 6185e96..6b25960 100644 --- a/tests/config/test_config_operation.py +++ b/tests/config/test_config_operation.py @@ -4,6 +4,7 @@ from nonebug.app import App async def test_add_subscribe(app: App, init_scheduler): from nonebot_plugin_datastore.db import get_engine + from nonebot_plugin_saa import TargetQQGroup from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.sql.expression import select @@ -12,8 +13,7 @@ async def test_add_subscribe(app: App, init_scheduler): from nonebot_bison.types import Target as TTarget await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=TTarget("weibo_id"), target_name="weibo_name", platform_name="weibo", @@ -21,15 +21,14 @@ async def test_add_subscribe(app: App, init_scheduler): tags=[], ) await config.add_subscribe( - user=234, - user_type="group", + TargetQQGroup(group_id=234), target=TTarget("weibo_id"), target_name="weibo_name", platform_name="weibo", cats=[], tags=[], ) - confs = await config.list_subscribe(123, "group") + confs = await config.list_subscribe(TargetQQGroup(group_id=123)) assert len(confs) == 1 conf: Subscribe = confs[0] async with AsyncSession(get_engine()) as sess: @@ -39,22 +38,23 @@ async def test_add_subscribe(app: App, init_scheduler): related_target_obj = await sess.scalar( select(Target).where(Target.id == conf.target_id) ) - assert related_user_obj.uid == 123 + assert related_user_obj + assert related_target_obj + assert related_user_obj.user_target["group_id"] == 123 assert related_target_obj.target_name == "weibo_name" assert related_target_obj.target == "weibo_id" assert conf.target.target == "weibo_id" assert conf.categories == [] await config.update_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=TTarget("weibo_id"), platform_name="weibo", target_name="weibo_name2", cats=[1], tags=["tag"], ) - confs = await config.list_subscribe(123, "group") + confs = await config.list_subscribe(TargetQQGroup(group_id=123)) assert len(confs) == 1 conf: Subscribe = confs[0] async with AsyncSession(get_engine()) as sess: @@ -64,7 +64,9 @@ async def test_add_subscribe(app: App, init_scheduler): related_target_obj = await sess.scalar( select(Target).where(Target.id == conf.target_id) ) - assert related_user_obj.uid == 123 + assert related_user_obj + assert related_target_obj + assert related_user_obj.user_target["group_id"] == 123 assert related_target_obj.target_name == "weibo_name2" assert related_target_obj.target == "weibo_id" assert conf.target.target == "weibo_id" @@ -73,12 +75,13 @@ async def test_add_subscribe(app: App, init_scheduler): async def test_add_dup_sub(init_scheduler): + from nonebot_plugin_saa import TargetQQGroup + from nonebot_bison.config.db_config import SubscribeDupException, config from nonebot_bison.types import Target as TTarget await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=TTarget("weibo_id"), target_name="weibo_name", platform_name="weibo", @@ -88,8 +91,7 @@ async def test_add_dup_sub(init_scheduler): with pytest.raises(SubscribeDupException): await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=TTarget("weibo_id"), target_name="weibo_name", platform_name="weibo", @@ -100,6 +102,7 @@ async def test_add_dup_sub(init_scheduler): async def test_del_subsribe(init_scheduler): from nonebot_plugin_datastore.db import get_engine + from nonebot_plugin_saa import TargetQQGroup from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.sql.expression import select from sqlalchemy.sql.functions import func @@ -109,8 +112,7 @@ async def test_del_subsribe(init_scheduler): from nonebot_bison.types import Target as TTarget await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=TTarget("weibo_id"), target_name="weibo_name", platform_name="weibo", @@ -118,8 +120,7 @@ async def test_del_subsribe(init_scheduler): tags=[], ) await config.del_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=TTarget("weibo_id"), platform_name="weibo", ) @@ -128,8 +129,7 @@ async def test_del_subsribe(init_scheduler): assert (await sess.scalar(select(func.count()).select_from(Target))) == 1 await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=TTarget("weibo_id"), target_name="weibo_name", platform_name="weibo", @@ -138,8 +138,7 @@ async def test_del_subsribe(init_scheduler): ) await config.add_subscribe( - user=124, - user_type="group", + TargetQQGroup(group_id=124), target=TTarget("weibo_id"), target_name="weibo_name_new", platform_name="weibo", @@ -148,8 +147,7 @@ async def test_del_subsribe(init_scheduler): ) await config.del_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=TTarget("weibo_id"), platform_name="weibo", ) @@ -157,5 +155,6 @@ async def test_del_subsribe(init_scheduler): async with AsyncSession(get_engine()) as sess: assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 1 assert (await sess.scalar(select(func.count()).select_from(Target))) == 1 - target: Target = await sess.scalar(select(Target)) + target = await sess.scalar(select(Target)) + assert target assert target.target_name == "weibo_name_new" diff --git a/tests/config/test_data_migration.py b/tests/config/test_data_migration.py index c0d0df6..06154cf 100644 --- a/tests/config/test_data_migration.py +++ b/tests/config/test_data_migration.py @@ -1,5 +1,6 @@ async def test_migration(use_legacy_config): from nonebot_plugin_datastore.db import init_db + from nonebot_plugin_saa import TargetQQGroup from nonebot_bison.config.config_legacy import Config from nonebot_bison.config.db_config import config @@ -34,7 +35,7 @@ async def test_migration(use_legacy_config): ) # await data_migrate() await init_db() - user123_config = await config.list_subscribe(123, "group") + user123_config = await config.list_subscribe(TargetQQGroup(group_id=123)) assert len(user123_config) == 2 for c in user123_config: if c.target.target == "weibo_id": @@ -47,7 +48,7 @@ async def test_migration(use_legacy_config): assert c.target.target_name == "weibo_name2" assert c.target.platform_name == "weibo" assert c.tags == ["tag"] - user234_config = await config.list_subscribe(234, "group") + user234_config = await config.list_subscribe(TargetQQGroup(group_id=234)) assert len(user234_config) == 1 assert user234_config[0].categories == [1] assert user234_config[0].target.target == "weibo_id" @@ -57,6 +58,7 @@ async def test_migration(use_legacy_config): async def test_migrate_dup(use_legacy_config): from nonebot_plugin_datastore.db import init_db + from nonebot_plugin_saa import TargetQQGroup from nonebot_bison.config.config_legacy import Config from nonebot_bison.config.db_config import config @@ -82,5 +84,5 @@ async def test_migrate_dup(use_legacy_config): ) # await data_migrate() await init_db() - user123_config = await config.list_subscribe(123, "group") + user123_config = await config.list_subscribe(TargetQQGroup(group_id=123)) assert len(user123_config) == 1 diff --git a/tests/config/test_scheduler_conf.py b/tests/config/test_scheduler_conf.py index c3d7f14..0f2fcb5 100644 --- a/tests/config/test_scheduler_conf.py +++ b/tests/config/test_scheduler_conf.py @@ -6,14 +6,14 @@ from pytest_mock import MockerFixture async def test_create_config(init_scheduler): from nonebot_plugin_datastore.db import get_engine + from nonebot_plugin_saa import TargetQQGroup from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config from nonebot_bison.config.db_model import Subscribe, Target, User from nonebot_bison.types import Target as T_Target await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=T_Target("weibo_id"), target_name="weibo_name", platform_name="weibo", @@ -21,8 +21,7 @@ async def test_create_config(init_scheduler): tags=[], ) await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=T_Target("weibo_id1"), target_name="weibo_name1", platform_name="weibo", @@ -58,6 +57,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture): from datetime import time from nonebot_plugin_datastore.db import get_engine + from nonebot_plugin_saa import TargetQQGroup from nonebot_bison.config import db_config from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config @@ -65,8 +65,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture): from nonebot_bison.types import Target as T_Target await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=T_Target("weibo_id"), target_name="weibo_name", platform_name="weibo", @@ -74,8 +73,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture): tags=[], ) await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=T_Target("weibo_id1"), target_name="weibo_name1", platform_name="weibo", @@ -83,8 +81,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture): tags=[], ) await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=T_Target("weibo_id1"), target_name="weibo_name2", platform_name="bilibili", @@ -124,6 +121,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture): async def test_get_platform_target(app: App, init_scheduler): from nonebot_plugin_datastore.db import get_engine + from nonebot_plugin_saa import TargetQQGroup from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.sql.expression import select @@ -133,8 +131,7 @@ async def test_get_platform_target(app: App, init_scheduler): from nonebot_bison.types import Target as T_Target await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=T_Target("weibo_id"), target_name="weibo_name", platform_name="weibo", @@ -142,8 +139,7 @@ async def test_get_platform_target(app: App, init_scheduler): tags=[], ) await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=T_Target("weibo_id1"), target_name="weibo_name1", platform_name="weibo", @@ -151,8 +147,7 @@ async def test_get_platform_target(app: App, init_scheduler): tags=[], ) await config.add_subscribe( - user=245, - user_type="group", + TargetQQGroup(group_id=245), target=T_Target("weibo_id1"), target_name="weibo_name1", platform_name="weibo", @@ -161,10 +156,14 @@ async def test_get_platform_target(app: App, init_scheduler): ) res = await config.get_platform_target("weibo") assert len(res) == 2 - await config.del_subscribe(123, "group", T_Target("weibo_id1"), "weibo") + await config.del_subscribe( + TargetQQGroup(group_id=123), T_Target("weibo_id1"), "weibo" + ) res = await config.get_platform_target("weibo") assert len(res) == 2 - await config.del_subscribe(123, "group", T_Target("weibo_id"), "weibo") + await config.del_subscribe( + TargetQQGroup(group_id=123), T_Target("weibo_id"), "weibo" + ) res = await config.get_platform_target("weibo") assert len(res) == 1 @@ -175,6 +174,7 @@ async def test_get_platform_target(app: App, init_scheduler): async def test_get_platform_target_subscribers(app: App, init_scheduler): from nonebot_plugin_datastore.db import get_engine + from nonebot_plugin_saa import TargetQQGroup from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.sql.expression import select @@ -182,12 +182,10 @@ async def test_get_platform_target_subscribers(app: App, init_scheduler): from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config from nonebot_bison.config.db_model import Subscribe, Target, User from nonebot_bison.types import Target as T_Target - from nonebot_bison.types import User as T_User from nonebot_bison.types import UserSubInfo await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=T_Target("weibo_id"), target_name="weibo_name", platform_name="weibo", @@ -195,8 +193,7 @@ async def test_get_platform_target_subscribers(app: App, init_scheduler): tags=["tag1"], ) await config.add_subscribe( - user=123, - user_type="group", + TargetQQGroup(group_id=123), target=T_Target("weibo_id1"), target_name="weibo_name1", platform_name="weibo", @@ -204,8 +201,7 @@ async def test_get_platform_target_subscribers(app: App, init_scheduler): tags=["tag2"], ) await config.add_subscribe( - user=245, - user_type="group", + TargetQQGroup(group_id=245), target=T_Target("weibo_id1"), target_name="weibo_name1", platform_name="weibo", @@ -215,9 +211,9 @@ async def test_get_platform_target_subscribers(app: App, init_scheduler): res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id")) assert len(res) == 1 - assert res[0] == UserSubInfo(T_User(123, "group"), [1], ["tag1"]) + assert res[0] == UserSubInfo(TargetQQGroup(group_id=123), [1], ["tag1"]) res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id1")) assert len(res) == 2 - assert UserSubInfo(T_User(123, "group"), [2], ["tag2"]) in res - assert UserSubInfo(T_User(245, "group"), [3], ["tag3"]) in res + assert UserSubInfo(TargetQQGroup(group_id=123), [2], ["tag2"]) in res + assert UserSubInfo(TargetQQGroup(group_id=245), [3], ["tag3"]) in res diff --git a/tests/conftest.py b/tests/conftest.py index 0d0c7d7..4732b48 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -67,9 +67,11 @@ async def app(tmp_path: Path, request: pytest.FixtureRequest, mocker: MockerFixt @pytest.fixture def dummy_user_subinfo(app: App): - from nonebot_bison.types import User, UserSubInfo + from nonebot_plugin_saa import TargetQQGroup - user = User(123, "group") + from nonebot_bison.types import UserSubInfo + + user = TargetQQGroup(group_id=123) return UserSubInfo(user=user, categories=[], tags=[]) diff --git a/tests/platforms/test_bilibili_live.py b/tests/platforms/test_bilibili_live.py index b9a5e25..ed202d0 100644 --- a/tests/platforms/test_bilibili_live.py +++ b/tests/platforms/test_bilibili_live.py @@ -16,9 +16,11 @@ def bili_live(app: App): @pytest.fixture def dummy_only_open_user_subinfo(app: App): - from nonebot_bison.types import User, UserSubInfo + from nonebot_plugin_saa import TargetQQGroup - user = User(123, "group") + from nonebot_bison.types import UserSubInfo + + user = TargetQQGroup(group_id=123) return UserSubInfo(user=user, categories=[1], tags=[]) @@ -68,9 +70,11 @@ async def test_fetch_bililive_only_live_open(bili_live, dummy_only_open_user_sub @pytest.fixture def dummy_only_title_user_subinfo(app: App): - from nonebot_bison.types import User, UserSubInfo + from nonebot_plugin_saa import TargetQQGroup - user = User(123, "group") + from nonebot_bison.types import UserSubInfo + + user = TargetQQGroup(group_id=123) return UserSubInfo(user=user, categories=[2], tags=[]) @@ -128,9 +132,11 @@ async def test_fetch_bililive_only_title_change( @pytest.fixture def dummy_only_close_user_subinfo(app: App): - from nonebot_bison.types import User, UserSubInfo + from nonebot_plugin_saa import TargetQQGroup - user = User(123, "group") + from nonebot_bison.types import UserSubInfo + + user = TargetQQGroup(group_id=123) return UserSubInfo(user=user, categories=[3], tags=[]) @@ -187,9 +193,11 @@ async def test_fetch_bililive_only_close(bili_live, dummy_only_close_user_subinf @pytest.fixture def dummy_bililive_user_subinfo(app: App): - from nonebot_bison.types import User, UserSubInfo + from nonebot_plugin_saa import TargetQQGroup - user = User(123, "group") + from nonebot_bison.types import UserSubInfo + + user = TargetQQGroup(group_id=123) return UserSubInfo(user=user, categories=[1, 2, 3], tags=[]) diff --git a/tests/platforms/test_platform.py b/tests/platforms/test_platform.py index 19c72e9..ce49393 100644 --- a/tests/platforms/test_platform.py +++ b/tests/platforms/test_platform.py @@ -24,9 +24,9 @@ raw_post_list_2 = raw_post_list_1 + [ @pytest.fixture def dummy_user(app: App): - from nonebot_bison.types import User + from nonebot_plugin_saa import TargetQQGroup - user = User(123, "group") + user = TargetQQGroup(group_id=123) return user diff --git a/tests/scheduler/test_scheduler.py b/tests/scheduler/test_scheduler.py index 0777d9b..62b127c 100644 --- a/tests/scheduler/test_scheduler.py +++ b/tests/scheduler/test_scheduler.py @@ -12,6 +12,8 @@ if typing.TYPE_CHECKING: async def get_schedule_times( scheduler_config: Type["SchedulerConfig"], time: int ) -> dict[str, int]: + from nonebot_plugin_saa import TargetQQGroup + from nonebot_bison.scheduler import scheduler_dict scheduler = scheduler_dict[scheduler_config] @@ -25,6 +27,8 @@ async def get_schedule_times( async def test_scheduler_without_time(init_scheduler): + from nonebot_plugin_saa import TargetQQGroup + from nonebot_bison.config import config from nonebot_bison.config.db_config import WeightConfig from nonebot_bison.platform.bilibili import BilibiliSchedConf @@ -32,13 +36,13 @@ async def test_scheduler_without_time(init_scheduler): from nonebot_bison.types import Target as T_Target await config.add_subscribe( - 123, "group", T_Target("t1"), "target1", "bilibili", [], [] + TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili", [], [] ) await config.add_subscribe( - 123, "group", T_Target("t2"), "target1", "bilibili", [], [] + TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili", [], [] ) await config.add_subscribe( - 123, "group", T_Target("t2"), "target1", "bilibili-live", [], [] + TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili-live", [], [] ) await config.update_time_weight_config( @@ -62,6 +66,8 @@ async def test_scheduler_without_time(init_scheduler): async def test_scheduler_with_time(app: App, init_scheduler, mocker: MockerFixture): + from nonebot_plugin_saa import TargetQQGroup + from nonebot_bison.config import config, db_config from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig from nonebot_bison.platform.bilibili import BilibiliSchedConf @@ -69,13 +75,13 @@ async def test_scheduler_with_time(app: App, init_scheduler, mocker: MockerFixtu from nonebot_bison.types import Target as T_Target await config.add_subscribe( - 123, "group", T_Target("t1"), "target1", "bilibili", [], [] + TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili", [], [] ) await config.add_subscribe( - 123, "group", T_Target("t2"), "target1", "bilibili", [], [] + TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili", [], [] ) await config.add_subscribe( - 123, "group", T_Target("t2"), "target1", "bilibili-live", [], [] + TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili-live", [], [] ) await config.update_time_weight_config( @@ -113,38 +119,42 @@ async def test_scheduler_with_time(app: App, init_scheduler, mocker: MockerFixtu async def test_scheduler_add_new(init_scheduler): + from nonebot_plugin_saa import TargetQQGroup + from nonebot_bison.config import config from nonebot_bison.platform.bilibili import BilibiliSchedConf from nonebot_bison.scheduler.manager import init_scheduler from nonebot_bison.types import Target as T_Target await config.add_subscribe( - 123, "group", T_Target("t1"), "target1", "bilibili", [], [] + TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili", [], [] ) await init_scheduler() await config.add_subscribe( - 2345, "group", T_Target("t1"), "target1", "bilibili", [], [] + TargetQQGroup(group_id=2345), T_Target("t1"), "target1", "bilibili", [], [] ) await config.add_subscribe( - 123, "group", T_Target("t2"), "target2", "bilibili", [], [] + TargetQQGroup(group_id=123), T_Target("t2"), "target2", "bilibili", [], [] ) stat_res = await get_schedule_times(BilibiliSchedConf, 1) assert stat_res["bilibili-t2"] == 1 async def test_schedule_delete(init_scheduler): + from nonebot_plugin_saa import TargetQQGroup + from nonebot_bison.config import config from nonebot_bison.platform.bilibili import BilibiliSchedConf from nonebot_bison.scheduler.manager import init_scheduler from nonebot_bison.types import Target as T_Target await config.add_subscribe( - 123, "group", T_Target("t1"), "target1", "bilibili", [], [] + TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili", [], [] ) await config.add_subscribe( - 123, "group", T_Target("t2"), "target1", "bilibili", [], [] + TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili", [], [] ) await init_scheduler() @@ -153,6 +163,6 @@ async def test_schedule_delete(init_scheduler): assert stat_res["bilibili-t2"] == 1 assert stat_res["bilibili-t1"] == 1 - await config.del_subscribe(123, "group", T_Target("t1"), "bilibili") + await config.del_subscribe(TargetQQGroup(group_id=123), T_Target("t1"), "bilibili") stat_res = await get_schedule_times(BilibiliSchedConf, 2) assert stat_res["bilibili-t2"] == 2 diff --git a/tests/test_config_manager_abort.py b/tests/test_config_manager_abort.py index 188b1d9..da56639 100644 --- a/tests/test_config_manager_abort.py +++ b/tests/test_config_manager_abort.py @@ -279,6 +279,7 @@ async def test_abort_add_on_tag(app: App, init_scheduler): async def test_abort_del_sub(app: App, init_scheduler): from nonebot.adapters.onebot.v11.bot import Bot from nonebot.adapters.onebot.v11.message import Message + from nonebot_plugin_saa import TargetQQGroup from nonebot_bison.config import config from nonebot_bison.config_manager import del_sub_matcher @@ -286,8 +287,7 @@ async def test_abort_del_sub(app: App, init_scheduler): from nonebot_bison.types import Target as T_Target await config.add_subscribe( - 10000, - "group", + TargetQQGroup(group_id=10000), T_Target("6279793937"), "明日方舟Arknights", "weibo", @@ -316,5 +316,5 @@ async def test_abort_del_sub(app: App, init_scheduler): ctx.receive_event(bot, event_abort) ctx.should_call_send(event_abort, "删除中止", True) ctx.should_finished() - subs = await config.list_subscribe(10000, "group") + subs = await config.list_subscribe(TargetQQGroup(group_id=10000)) assert subs diff --git a/tests/test_config_manager_add.py b/tests/test_config_manager_add.py index 2707d0f..026866e 100644 --- a/tests/test_config_manager_add.py +++ b/tests/test_config_manager_add.py @@ -64,6 +64,7 @@ async def test_configurable_at_me_false(app: App): async def test_add_with_target(app: App, init_scheduler): from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.message import Message + from nonebot_plugin_saa import TargetQQGroup from nonebot_bison.config import config from nonebot_bison.config_manager import add_sub_matcher, common_platform @@ -171,7 +172,7 @@ async def test_add_with_target(app: App, init_scheduler): event_6_ok, BotReply.add_reply_subscribe_success("明日方舟Arknights"), True ) ctx.should_finished() - subs = await config.list_subscribe(10000, "group") + subs = await config.list_subscribe(TargetQQGroup(group_id=10000)) assert len(subs) == 1 sub = subs[0] assert sub.target.target == "6279793937" @@ -188,6 +189,7 @@ async def test_add_with_target(app: App, init_scheduler): async def test_add_with_target_no_cat(app: App, init_scheduler): from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.message import Message + from nonebot_plugin_saa import TargetQQGroup from nonebot_bison.config import config from nonebot_bison.config_manager import add_sub_matcher, common_platform @@ -233,7 +235,7 @@ async def test_add_with_target_no_cat(app: App, init_scheduler): event_4_ok, BotReply.add_reply_subscribe_success("塞壬唱片-MSR"), True ) ctx.should_finished() - subs = await config.list_subscribe(10000, "group") + subs = await config.list_subscribe(TargetQQGroup(group_id=10000)) assert len(subs) == 1 sub = subs[0] assert sub.target.target == "32540734" @@ -248,6 +250,7 @@ async def test_add_with_target_no_cat(app: App, init_scheduler): async def test_add_no_target(app: App, init_scheduler): from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.message import Message + from nonebot_plugin_saa import TargetQQGroup from nonebot_bison.config import config from nonebot_bison.config_manager import add_sub_matcher, common_platform @@ -284,7 +287,7 @@ async def test_add_no_target(app: App, init_scheduler): event_4, BotReply.add_reply_subscribe_success("明日方舟游戏信息"), True ) ctx.should_finished() - subs = await config.list_subscribe(10000, "group") + subs = await config.list_subscribe(TargetQQGroup(group_id=10000)) assert len(subs) == 1 sub = subs[0] assert sub.target.target == "default" @@ -334,6 +337,7 @@ async def test_platform_name_err(app: App): async def test_add_with_get_id(app: App): from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.message import Message, MessageSegment + from nonebot_plugin_saa import TargetQQGroup from nonebot_bison.config import config from nonebot_bison.config_manager import add_sub_matcher, common_platform @@ -407,7 +411,7 @@ async def test_add_with_get_id(app: App): True, ) ctx.should_finished() - subs = await config.list_subscribe(10000, "group") + subs = await config.list_subscribe(TargetQQGroup(group_id=10000)) assert len(subs) == 0 @@ -416,6 +420,7 @@ async def test_add_with_get_id(app: App): async def test_add_with_bilibili_target_parser(app: App, init_scheduler): from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.message import Message + from nonebot_plugin_saa import TargetQQGroup from nonebot_bison.config import config from nonebot_bison.config_manager import add_sub_matcher, common_platform @@ -524,7 +529,7 @@ async def test_add_with_bilibili_target_parser(app: App, init_scheduler): event_6, BotReply.add_reply_subscribe_success("明日方舟"), True ) ctx.should_finished() - subs = await config.list_subscribe(10000, "group") + subs = await config.list_subscribe(TargetQQGroup(group_id=10000)) assert len(subs) == 1 sub = subs[0] assert sub.target.target == "161775300" diff --git a/tests/test_config_manager_query_del.py b/tests/test_config_manager_query_del.py index c829ba4..9abb696 100644 --- a/tests/test_config_manager_query_del.py +++ b/tests/test_config_manager_query_del.py @@ -10,6 +10,7 @@ from .utils import fake_admin_user, fake_group_message_event @pytest.mark.asyncio async def test_query_sub(app: App, init_scheduler): from nonebot.adapters.onebot.v11.message import Message + from nonebot_plugin_saa import TargetQQGroup from nonebot_bison.config import config from nonebot_bison.config_manager import query_sub_matcher @@ -17,8 +18,7 @@ async def test_query_sub(app: App, init_scheduler): from nonebot_bison.types import Target await config.add_subscribe( - 10000, - "group", + TargetQQGroup(group_id=10000), Target("6279793937"), "明日方舟Arknights", "weibo", @@ -40,6 +40,7 @@ async def test_query_sub(app: App, init_scheduler): async def test_del_sub(app: App, init_scheduler): from nonebot.adapters.onebot.v11.bot import Bot from nonebot.adapters.onebot.v11.message import Message + from nonebot_plugin_saa import TargetQQGroup from nonebot_bison.config import config from nonebot_bison.config_manager import del_sub_matcher @@ -47,8 +48,7 @@ async def test_del_sub(app: App, init_scheduler): from nonebot_bison.types import Target await config.add_subscribe( - 10000, - "group", + TargetQQGroup(group_id=10000), Target("6279793937"), "明日方舟Arknights", "weibo", @@ -83,7 +83,7 @@ async def test_del_sub(app: App, init_scheduler): ctx.receive_event(bot, event_1_ok) ctx.should_call_send(event_1_ok, "删除成功", True) ctx.should_finished() - subs = await config.list_subscribe(10000, "group") + subs = await config.list_subscribe(TargetQQGroup(group_id=10000)) assert len(subs) == 0 diff --git a/tests/test_get_bot.py b/tests/test_get_bot.py index 53534be..3c02a0a 100644 --- a/tests/test_get_bot.py +++ b/tests/test_get_bot.py @@ -30,8 +30,8 @@ async def test_refresh_bots(app: App) -> None: from nonebot import get_driver from nonebot.adapters.onebot.v11 import Bot as BotV11 from nonebot.adapters.onebot.v12 import Bot as BotV12 + from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate - from nonebot_bison.types import User from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots async with app.test_api() as ctx: @@ -44,13 +44,13 @@ async def test_refresh_bots(app: App) -> None: ctx.should_call_api("get_group_list", {}, [{"group_id": 1}]) ctx.should_call_api("get_friend_list", {}, [{"user_id": 2}]) - assert get_bot(User(1, "group")) is None - assert get_bot(User(2, "private")) is None + assert get_bot(TargetQQGroup(group_id=1)) is None + assert get_bot(TargetQQPrivate(user_id=2)) is None await refresh_bots() - assert get_bot(User(1, "group")) == botv11 - assert get_bot(User(2, "private")) == botv11 + assert get_bot(TargetQQGroup(group_id=1)) == botv11 + assert get_bot(TargetQQPrivate(user_id=2)) == botv11 # 测试获取群列表 ctx.should_call_api("get_group_list", {}, [{"group_id": 3}]) @@ -66,8 +66,8 @@ async def test_get_bot_two_bots(app: App) -> None: from nonebot import get_driver from nonebot.adapters.onebot.v11 import Bot as BotV11 from nonebot.adapters.onebot.v12 import Bot as BotV12 + from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate - from nonebot_bison.types import User from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots async with app.test_api() as ctx: @@ -85,14 +85,14 @@ async def test_get_bot_two_bots(app: App) -> None: await refresh_bots() - assert get_bot(User(0, "group")) is None - assert get_bot(User(1, "group")) == bot1 - assert get_bot(User(2, "group")) in (bot1, bot2) - assert get_bot(User(3, "group")) == bot2 - assert get_bot(User(0, "private")) is None - assert get_bot(User(1, "private")) == bot1 - assert get_bot(User(2, "private")) in (bot1, bot2) - assert get_bot(User(3, "private")) == bot2 + assert get_bot(TargetQQGroup(group_id=0)) is None + assert get_bot(TargetQQGroup(group_id=1)) == bot1 + assert get_bot(TargetQQGroup(group_id=2)) in (bot1, bot2) + assert get_bot(TargetQQGroup(group_id=3)) == bot2 + assert get_bot(TargetQQPrivate(user_id=0)) is None + assert get_bot(TargetQQPrivate(user_id=1)) == bot1 + assert get_bot(TargetQQPrivate(user_id=2)) in (bot1, bot2) + assert get_bot(TargetQQPrivate(user_id=3)) == bot2 ctx.should_call_api("get_group_list", {}, [{"group_id": 1}, {"group_id": 2}]) ctx.should_call_api("get_group_list", {}, [{"group_id": 2}, {"group_id": 3}])