🚧 remove User type

This commit is contained in:
felinae98 2023-03-20 14:39:04 +08:00
parent d535f5212d
commit 4118329bb0
24 changed files with 298 additions and 221 deletions

View File

@ -4,6 +4,7 @@ from fastapi.exceptions import HTTPException
from fastapi.param_functions import Depends
from fastapi.routing import APIRouter
from fastapi.security.oauth2 import OAuth2PasswordBearer
from nonebot_plugin_saa import TargetQQGroup
from ..apis import check_sub_target
from ..config import (
@ -15,7 +16,7 @@ from ..config import (
from ..config.db_config import SubscribeDupException
from ..platform import platform_manager
from ..types import Target as T_Target
from ..types import User, WeightConfig
from ..types import WeightConfig
from ..utils.get_bot import get_bot, get_groups
from .jwt import load_jwt, pack_jwt
from .token_manager import token_manager
@ -75,7 +76,7 @@ async def get_admin_groups(qq: int):
res = []
for group in await get_groups():
group_id = group["group_id"]
bot = get_bot(User(group_id, "group"))
bot = get_bot(TargetQQGroup(group_id=group_id))
if not bot:
continue
users = await bot.get_group_member_list(group_id=group_id)
@ -131,7 +132,7 @@ async def get_subs_info(jwt_obj: dict = Depends(get_jwt_obj)) -> SubscribeResp:
res: SubscribeResp = {}
for group in groups:
group_id = group["id"]
raw_subs = await config.list_subscribe(group_id, "group")
raw_subs = await config.list_subscribe(TargetQQGroup(group_id=group_id))
subs = list(
map(
lambda sub: SubscribeConfig(
@ -157,8 +158,7 @@ async def get_target_name(platformName: str, target: str):
async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp:
try:
await config.add_subscribe(
int(groupNumber),
"group",
TargetQQGroup(group_id=groupNumber),
T_Target(req.target),
req.targetName,
req.platformName,
@ -173,7 +173,9 @@ async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp:
@router.delete("/subs", dependencies=[Depends(check_group_permission)])
async def del_group_sub(groupNumber: int, platformName: str, target: str):
try:
await config.del_subscribe(int(groupNumber), "group", target, platformName)
await config.del_subscribe(
TargetQQGroup(group_id=groupNumber), target, platformName
)
except (NoSuchUserException, NoSuchSubscribeException):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe")
return StatusResp(ok=True, msg="")

View File

@ -4,15 +4,14 @@ from datetime import datetime, time
from typing import Awaitable, Callable, Optional, Sequence
from nonebot_plugin_datastore import create_session
from nonebot_plugin_saa import PlatformTarget
from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload
from ..types import Category, PlatformWeightConfigResp, Tag
from ..types import Target as T_Target
from ..types import TimeWeightConfig
from ..types import User as T_User
from ..types import UserSubInfo, WeightConfig
from ..types import TimeWeightConfig, UserSubInfo, WeightConfig
from .db_model import ScheduleTimeWeight, Subscribe, Target, User
from .utils import NoSuchTargetException
@ -40,8 +39,7 @@ class DBConfig:
async def add_subscribe(
self,
user: int,
user_type: str,
user: PlatformTarget,
target: T_Target,
target_name: str,
platform_name: str,
@ -49,12 +47,10 @@ class DBConfig:
tags: list[Tag],
):
async with create_session() as session:
db_user_stmt = (
select(User).where(User.uid == user).where(User.type == user_type)
)
db_user_stmt = select(User).where(User.user_target == user.dict())
db_user: Optional[User] = await session.scalar(db_user_stmt)
if not db_user:
db_user = User(uid=user, type=user_type)
db_user = User(user_target=user.dict())
session.add(db_user)
db_target_stmt = (
select(Target)
@ -85,11 +81,11 @@ class DBConfig:
raise SubscribeDupException()
raise e
async def list_subscribe(self, user: int, user_type: str) -> Sequence[Subscribe]:
async def list_subscribe(self, user: PlatformTarget) -> Sequence[Subscribe]:
async with create_session() as session:
query_stmt = (
select(Subscribe)
.where(User.type == user_type, User.uid == user)
.where(User.user_target == user.dict())
.join(User)
.options(selectinload(Subscribe.target))
)
@ -109,11 +105,11 @@ class DBConfig:
return subs
async def del_subscribe(
self, user: int, user_type: str, target: str, platform_name: str
self, user: PlatformTarget, target: str, platform_name: str
):
async with create_session() as session:
user_obj = await session.scalar(
select(User).where(User.uid == user, User.type == user_type)
select(User).where(User.user_target == user.dict())
)
target_obj = await session.scalar(
select(Target).where(
@ -142,8 +138,7 @@ class DBConfig:
async def update_subscribe(
self,
user: int,
user_type: str,
user: PlatformTarget,
target: str,
target_name: str,
platform_name: str,
@ -154,8 +149,7 @@ class DBConfig:
subscribe_obj: Subscribe = await sess.scalar(
select(Subscribe)
.where(
User.uid == user,
User.type == user_type,
User.user_target == user.dict(),
Target.target == target,
Target.platform_name == platform_name,
)
@ -272,7 +266,7 @@ class DBConfig:
return list(
map(
lambda subscribe: UserSubInfo(
T_User(subscribe.user.uid, subscribe.user.type),
PlatformTarget.deserialize(subscribe.user.user_target),
subscribe.categories,
subscribe.tags,
),

View File

@ -1,5 +1,6 @@
from nonebot.log import logger
from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from sqlalchemy.ext.asyncio.session import AsyncSession
from .config_legacy import Config, ConfigContent, drop
@ -21,7 +22,11 @@ async def data_migrate():
subscribe_to_create = []
platform_target_map: dict[str, tuple[Target, str, int]] = {}
for user in all_subs:
db_user = User(uid=user["user"], type=user["user_type"])
if user["user_type"] == "group":
user_target = TargetQQGroup(group_id=user["user"])
else:
user_target = TargetQQPrivate(user_id=user["user"])
db_user = User(user_target=user_target.dict())
user_to_create.append(db_user)
user_sub_set = set()
for sub in user["subs"]:

View File

@ -14,15 +14,15 @@ get_plugin_data().set_migration_dir(Path(__file__).parent / "migrations")
class User(Model):
__table_args__ = (UniqueConstraint("type", "uid", name="unique-user-constraint"),)
id: Mapped[int] = mapped_column(primary_key=True)
type: Mapped[str] = mapped_column(String(20))
uid: Mapped[int]
user_target: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
user_target: Mapped[dict] = mapped_column(JSON)
subscribes: Mapped[list["Subscribe"]] = relationship(back_populates="user")
@property
def saa_target(self) -> PlatformTarget:
return PlatformTarget.deserialize(self.user_target)
class Target(Model):
__table_args__ = (

View File

@ -18,6 +18,7 @@ depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.drop_constraint("unique-user-constraint", type_="unique")
batch_op.add_column(sa.Column("user_target", sa.JSON(), nullable=True))
# ### end Alembic commands ###
@ -27,5 +28,6 @@ def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.drop_column("user_target")
batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"])
# ### end Alembic commands ###

View File

@ -0,0 +1,34 @@
"""make user_target not nullable
Revision ID: 67c38b3f39c2
Revises: a5466912fad0
Create Date: 2023-03-20 11:08:42.883556
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import sqlite
# revision identifiers, used by Alembic.
revision = "67c38b3f39c2"
down_revision = "a5466912fad0"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.alter_column(
"user_target", existing_type=sqlite.JSON(), nullable=False
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.alter_column("user_target", existing_type=sqlite.JSON(), nullable=True)
# ### end Alembic commands ###

View File

@ -0,0 +1,33 @@
"""remove uid and type
Revision ID: 8d3863e9d74b
Revises: 67c38b3f39c2
Create Date: 2023-03-20 15:38:20.220599
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "8d3863e9d74b"
down_revision = "67c38b3f39c2"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.drop_column("uid")
batch_op.drop_column("type")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.add_column(sa.Column("type", sa.VARCHAR(length=20), nullable=False))
batch_op.add_column(sa.Column("uid", sa.INTEGER(), nullable=False))
# ### end Alembic commands ###

View File

@ -15,13 +15,14 @@ from nonebot.params import Depends, EventPlainText, EventToMe
from nonebot.permission import SUPERUSER
from nonebot.rule import to_me
from nonebot.typing import T_State
from nonebot_plugin_saa import PlatformTarget, TargetQQGroup, extract_target
from .apis import check_sub_target
from .config import config
from .config.db_config import SubscribeDupException
from .platform import Platform, platform_manager
from .plugin_config import plugin_config
from .types import Category, Target, User
from .types import Category, Target
from .utils import parse_text
@ -61,12 +62,8 @@ def ensure_user_info(matcher: Type[Matcher]):
async def set_target_user_info(event: MessageEvent, state: T_State):
if isinstance(event, GroupMessageEvent):
user = User(event.group_id, "group")
state["target_user_info"] = user
elif isinstance(event, PrivateMessageEvent):
user = User(event.user_id, "private")
state["target_user_info"] = user
user = extract_target(event)
state["target_user_info"] = user
def do_add_sub(add_sub: Type[Matcher]):
@ -201,14 +198,11 @@ def do_add_sub(add_sub: Type[Matcher]):
@add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)])
async def add_sub_process(event: Event, state: T_State):
user = cast(User, state.get("target_user_info"))
assert isinstance(user, User)
user = cast(PlatformTarget, state.get("target_user_info"))
assert isinstance(user, PlatformTarget)
try:
await config.add_subscribe(
# state.get("_user_id") or event.group_id,
# user_type="group",
user=user.user,
user_type=user.user_type,
user=user,
target=state["id"],
target_name=state["name"],
platform_name=state["platform"],
@ -228,12 +222,8 @@ def do_query_sub(query_sub: Type[Matcher]):
@query_sub.handle()
async def _(state: T_State):
user_info = state["target_user_info"]
assert isinstance(user_info, User)
sub_list = await config.list_subscribe(
# state.get("_user_id") or event.group_id, "group"
user_info.user,
user_info.user_type,
)
assert isinstance(user_info, PlatformTarget)
sub_list = await config.list_subscribe(user_info)
res = "订阅的帐号为:\n"
for sub in sub_list:
res += "{} {} {}".format(
@ -261,13 +251,9 @@ def do_del_sub(del_sub: Type[Matcher]):
@del_sub.handle()
async def send_list(bot: Bot, event: Event, state: T_State):
user_info = state["target_user_info"]
assert isinstance(user_info, User)
assert isinstance(user_info, PlatformTarget)
try:
sub_list = await config.list_subscribe(
# state.get("_user_id") or event.group_id, "group"
user_info.user,
user_info.user_type,
)
sub_list = await config.list_subscribe(user_info)
assert sub_list
except AssertionError:
await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅")
@ -309,14 +295,8 @@ def do_del_sub(del_sub: Type[Matcher]):
try:
index = int(user_msg)
user_info = state["target_user_info"]
assert isinstance(user_info, User)
await config.del_subscribe(
# state.get("_user_id") or event.group_id,
# "group",
user_info.user,
user_info.user_type,
**state["sub_table"][index],
)
assert isinstance(user_info, PlatformTarget)
await config.del_subscribe(user_info, **state["sub_table"][index])
except Exception as e:
await del_sub.reject("删除错误")
else:
@ -398,7 +378,7 @@ async def do_choose_group_number(state: T_State):
group_number_idx: dict[int, int] = state["group_number_idx"]
idx: int = state["group_idx"]
group_id = group_number_idx[idx]
state["target_user_info"] = User(user=group_id, user_type="group")
state["target_user_info"] = TargetQQGroup(group_id=group_id)
async def _check_command(event_msg: str = EventPlainText()):

View File

@ -10,10 +10,11 @@ from typing import Any, Collection, Optional, Type
import httpx
from httpx import AsyncClient
from nonebot.log import logger
from nonebot_plugin_saa import PlatformTarget
from ..plugin_config import plugin_config
from ..post import Post
from ..types import Category, RawPost, Tag, Target, User, UserSubInfo
from ..types import Category, RawPost, Tag, Target, UserSubInfo
from ..utils import ProcessContext, SchedulerConfig
@ -84,12 +85,12 @@ class Platform(metaclass=PlatformABCMeta, base=True):
@abstractmethod
async def fetch_new_post(
self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]:
) -> list[tuple[PlatformTarget, list[Post]]]:
...
async def do_fetch_new_post(
self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]:
) -> list[tuple[PlatformTarget, list[Post]]]:
try:
return await self.fetch_new_post(target, users)
except httpx.RequestError as err:
@ -197,8 +198,8 @@ class Platform(metaclass=PlatformABCMeta, base=True):
async def dispatch_user_post(
self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]:
res: list[tuple[User, list[Post]]] = []
) -> list[tuple[PlatformTarget, list[Post]]]:
res: list[tuple[PlatformTarget, list[Post]]] = []
for user, cats, required_tags in users:
user_raw_post = await self.filter_user_custom(
new_posts, cats, required_tags
@ -314,7 +315,7 @@ class NewMessage(MessageProcess, abstract=True):
async def fetch_new_post(
self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]:
) -> list[tuple[PlatformTarget, list[Post]]]:
post_list = await self.get_sub_list(target)
new_posts = await self.filter_common_with_diff(target, post_list)
if not new_posts:
@ -353,7 +354,7 @@ class StatusChange(Platform, abstract=True):
async def fetch_new_post(
self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]:
) -> list[tuple[PlatformTarget, list[Post]]]:
try:
new_status = await self.get_status(target)
except self.FetchError as err:
@ -381,7 +382,7 @@ class SimplePost(MessageProcess, abstract=True):
async def fetch_new_post(
self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]:
) -> list[tuple[PlatformTarget, list[Post]]]:
new_posts = await self.get_sub_list(target)
if not new_posts:
return []

View File

@ -3,6 +3,7 @@ from datetime import time
from typing import Any, Literal, NamedTuple, NewType
from httpx import URL
from nonebot_plugin_saa import PlatformTarget as SendTarget
from pydantic import BaseModel
RawPost = Any
@ -25,7 +26,7 @@ class PlatformTarget:
class UserSubInfo(NamedTuple):
user: User
user: SendTarget
categories: list[Category]
tags: list[Tag]

View File

@ -1,53 +1,57 @@
""" 提供获取 Bot 的方法 """
import random
from collections import defaultdict
from typing import Any, Optional
import nonebot
from nonebot import get_driver, on_notice
from nonebot.adapters import Bot
from nonebot.adapters.onebot.v11 import Bot as Ob11Bot
from nonebot.adapters.onebot.v11 import (
Bot,
FriendAddNoticeEvent,
GroupDecreaseNoticeEvent,
GroupIncreaseNoticeEvent,
)
from ..types import User
from nonebot_plugin_saa import PlatformTarget, TargetQQGroup, TargetQQPrivate
GROUP: dict[int, list[Bot]] = {}
USER: dict[int, list[Bot]] = {}
BOT_CACHE: dict[PlatformTarget, list[Bot]] = defaultdict(list)
def get_bots() -> list[Bot]:
"""获取所有 OneBot 11 Bot"""
# TODO: support ob12
bots = []
for bot in nonebot.get_bots().values():
if isinstance(bot, Bot):
if isinstance(bot, Ob11Bot):
bots.append(bot)
return bots
async def _refresh_ob11(bot: Ob11Bot):
# 获取群列表
groups = await bot.get_group_list()
for group in groups:
group_id = group["group_id"]
target = TargetQQGroup(group_id=group_id)
BOT_CACHE[target].append(bot)
# 获取好友列表
users = await bot.get_friend_list()
for user in users:
user_id = user["user_id"]
target = TargetQQPrivate(user_id=user_id)
BOT_CACHE[target].append(bot)
async def refresh_bots():
"""刷新缓存的 Bot 数据"""
GROUP.clear()
USER.clear()
BOT_CACHE.clear()
for bot in get_bots():
# 获取群列表
groups = await bot.get_group_list()
for group in groups:
group_id = group["group_id"]
if group_id not in GROUP:
GROUP[group_id] = [bot]
else:
GROUP[group_id].append(bot)
# 获取好友列表
users = await bot.get_friend_list()
for user in users:
user_id = user["user_id"]
if user_id not in USER:
USER[user_id] = [bot]
else:
USER[user_id].append(bot)
match bot:
case Ob11Bot():
await _refresh_ob11(bot)
driver = get_driver()
@ -75,15 +79,9 @@ async def _(bot: Bot, event: GroupDecreaseNoticeEvent | GroupIncreaseNoticeEvent
await refresh_bots()
def get_bot(user: User) -> Optional[Bot]:
def get_bot(user: PlatformTarget) -> Optional[Bot]:
"""获取 Bot"""
bots = []
if user.user_type == "group":
bots = GROUP.get(user.user, [])
if user.user_type == "private":
bots = USER.get(user.user, [])
bots = BOT_CACHE.get(user)
if not bots:
return
@ -92,6 +90,7 @@ def get_bot(user: User) -> Optional[Bot]:
async def get_groups() -> list[dict[str, Any]]:
"""获取所有群号"""
# TODO
all_groups: dict[int, dict[str, Any]] = {}
for bot in get_bots():
groups = await bot.get_group_list()

22
poetry.lock generated
View File

@ -1470,16 +1470,20 @@ name = "nonebot-plugin-send-anything-anywhere"
version = "0.2.4"
description = "An adaptor for nonebot2 adaptors"
optional = false
python-versions = ">=3.8,<4.0"
files = [
{file = "nonebot_plugin_send_anything_anywhere-0.2.4-py3-none-any.whl", hash = "sha256:97c1c1565479c1750c21ce471545ea293a1f26d606cbe5ae071dab0047200408"},
{file = "nonebot_plugin_send_anything_anywhere-0.2.4.tar.gz", hash = "sha256:71217c6bd7f84d6f3d266914562c60dadf9b28e66801c3996d6d7c36bafa7fca"},
]
python-versions = "^3.8"
files = []
develop = false
[package.dependencies]
nonebot2 = ">=2.0.0rc1,<3.0.0"
pydantic = ">=1.10.5,<2.0.0"
strenum = ">=0.4.8,<0.5.0"
nonebot2 = "^2.0.0rc1"
pydantic = "^1.10.5"
strenum = "^0.4.8"
[package.source]
type = "git"
url = "https://github.com/felinae98/nonebot-plugin-send-anything-anywhere.git"
reference = "main"
resolved_reference = "7f8a57afc72b5b6a7f909935f1a87411bf597173"
[[package]]
name = "nonebot2"
@ -2869,4 +2873,4 @@ yaml = []
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<4.0.0"
content-hash = "a8af95b0b5285f14d48ba11d7237cf636ca2102e7374d07d6b808eb5fdba8a76"
content-hash = "efba4feca911691e91af2b93cb810268f6e35a6e985811587e6b00999c2bd263"

View File

@ -34,7 +34,7 @@ nonebot-adapter-onebot = "^2.0.0-beta.1"
nonebot-plugin-htmlrender = ">=0.2.0"
nonebot-plugin-datastore = "^0.6.2"
nonebot-plugin-apscheduler = "^0.2.0"
nonebot-plugin-send-anything-anywhere = "^0.2.1"
nonebot-plugin-send-anything-anywhere = {git = "https://github.com/felinae98/nonebot-plugin-send-anything-anywhere.git", rev = "main"}
[tool.poetry.group.dev.dependencies]
ipdb = "^0.13.4"

View File

@ -4,6 +4,7 @@ from nonebug.app import App
async def test_add_subscribe(app: App, init_scheduler):
from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.expression import select
@ -12,8 +13,7 @@ async def test_add_subscribe(app: App, init_scheduler):
from nonebot_bison.types import Target as TTarget
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
@ -21,15 +21,14 @@ async def test_add_subscribe(app: App, init_scheduler):
tags=[],
)
await config.add_subscribe(
user=234,
user_type="group",
TargetQQGroup(group_id=234),
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
confs = await config.list_subscribe(123, "group")
confs = await config.list_subscribe(TargetQQGroup(group_id=123))
assert len(confs) == 1
conf: Subscribe = confs[0]
async with AsyncSession(get_engine()) as sess:
@ -39,22 +38,23 @@ async def test_add_subscribe(app: App, init_scheduler):
related_target_obj = await sess.scalar(
select(Target).where(Target.id == conf.target_id)
)
assert related_user_obj.uid == 123
assert related_user_obj
assert related_target_obj
assert related_user_obj.user_target["group_id"] == 123
assert related_target_obj.target_name == "weibo_name"
assert related_target_obj.target == "weibo_id"
assert conf.target.target == "weibo_id"
assert conf.categories == []
await config.update_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=TTarget("weibo_id"),
platform_name="weibo",
target_name="weibo_name2",
cats=[1],
tags=["tag"],
)
confs = await config.list_subscribe(123, "group")
confs = await config.list_subscribe(TargetQQGroup(group_id=123))
assert len(confs) == 1
conf: Subscribe = confs[0]
async with AsyncSession(get_engine()) as sess:
@ -64,7 +64,9 @@ async def test_add_subscribe(app: App, init_scheduler):
related_target_obj = await sess.scalar(
select(Target).where(Target.id == conf.target_id)
)
assert related_user_obj.uid == 123
assert related_user_obj
assert related_target_obj
assert related_user_obj.user_target["group_id"] == 123
assert related_target_obj.target_name == "weibo_name2"
assert related_target_obj.target == "weibo_id"
assert conf.target.target == "weibo_id"
@ -73,12 +75,13 @@ async def test_add_subscribe(app: App, init_scheduler):
async def test_add_dup_sub(init_scheduler):
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config.db_config import SubscribeDupException, config
from nonebot_bison.types import Target as TTarget
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
@ -88,8 +91,7 @@ async def test_add_dup_sub(init_scheduler):
with pytest.raises(SubscribeDupException):
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
@ -100,6 +102,7 @@ async def test_add_dup_sub(init_scheduler):
async def test_del_subsribe(init_scheduler):
from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.expression import select
from sqlalchemy.sql.functions import func
@ -109,8 +112,7 @@ async def test_del_subsribe(init_scheduler):
from nonebot_bison.types import Target as TTarget
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
@ -118,8 +120,7 @@ async def test_del_subsribe(init_scheduler):
tags=[],
)
await config.del_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=TTarget("weibo_id"),
platform_name="weibo",
)
@ -128,8 +129,7 @@ async def test_del_subsribe(init_scheduler):
assert (await sess.scalar(select(func.count()).select_from(Target))) == 1
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
@ -138,8 +138,7 @@ async def test_del_subsribe(init_scheduler):
)
await config.add_subscribe(
user=124,
user_type="group",
TargetQQGroup(group_id=124),
target=TTarget("weibo_id"),
target_name="weibo_name_new",
platform_name="weibo",
@ -148,8 +147,7 @@ async def test_del_subsribe(init_scheduler):
)
await config.del_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=TTarget("weibo_id"),
platform_name="weibo",
)
@ -157,5 +155,6 @@ async def test_del_subsribe(init_scheduler):
async with AsyncSession(get_engine()) as sess:
assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 1
assert (await sess.scalar(select(func.count()).select_from(Target))) == 1
target: Target = await sess.scalar(select(Target))
target = await sess.scalar(select(Target))
assert target
assert target.target_name == "weibo_name_new"

View File

@ -1,5 +1,6 @@
async def test_migration(use_legacy_config):
from nonebot_plugin_datastore.db import init_db
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config.config_legacy import Config
from nonebot_bison.config.db_config import config
@ -34,7 +35,7 @@ async def test_migration(use_legacy_config):
)
# await data_migrate()
await init_db()
user123_config = await config.list_subscribe(123, "group")
user123_config = await config.list_subscribe(TargetQQGroup(group_id=123))
assert len(user123_config) == 2
for c in user123_config:
if c.target.target == "weibo_id":
@ -47,7 +48,7 @@ async def test_migration(use_legacy_config):
assert c.target.target_name == "weibo_name2"
assert c.target.platform_name == "weibo"
assert c.tags == ["tag"]
user234_config = await config.list_subscribe(234, "group")
user234_config = await config.list_subscribe(TargetQQGroup(group_id=234))
assert len(user234_config) == 1
assert user234_config[0].categories == [1]
assert user234_config[0].target.target == "weibo_id"
@ -57,6 +58,7 @@ async def test_migration(use_legacy_config):
async def test_migrate_dup(use_legacy_config):
from nonebot_plugin_datastore.db import init_db
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config.config_legacy import Config
from nonebot_bison.config.db_config import config
@ -82,5 +84,5 @@ async def test_migrate_dup(use_legacy_config):
)
# await data_migrate()
await init_db()
user123_config = await config.list_subscribe(123, "group")
user123_config = await config.list_subscribe(TargetQQGroup(group_id=123))
assert len(user123_config) == 1

View File

@ -6,14 +6,14 @@ from pytest_mock import MockerFixture
async def test_create_config(init_scheduler):
from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as T_Target
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=T_Target("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
@ -21,8 +21,7 @@ async def test_create_config(init_scheduler):
tags=[],
)
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
@ -58,6 +57,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture):
from datetime import time
from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import db_config
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
@ -65,8 +65,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture):
from nonebot_bison.types import Target as T_Target
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=T_Target("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
@ -74,8 +73,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture):
tags=[],
)
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
@ -83,8 +81,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture):
tags=[],
)
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=T_Target("weibo_id1"),
target_name="weibo_name2",
platform_name="bilibili",
@ -124,6 +121,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture):
async def test_get_platform_target(app: App, init_scheduler):
from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.expression import select
@ -133,8 +131,7 @@ async def test_get_platform_target(app: App, init_scheduler):
from nonebot_bison.types import Target as T_Target
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=T_Target("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
@ -142,8 +139,7 @@ async def test_get_platform_target(app: App, init_scheduler):
tags=[],
)
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
@ -151,8 +147,7 @@ async def test_get_platform_target(app: App, init_scheduler):
tags=[],
)
await config.add_subscribe(
user=245,
user_type="group",
TargetQQGroup(group_id=245),
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
@ -161,10 +156,14 @@ async def test_get_platform_target(app: App, init_scheduler):
)
res = await config.get_platform_target("weibo")
assert len(res) == 2
await config.del_subscribe(123, "group", T_Target("weibo_id1"), "weibo")
await config.del_subscribe(
TargetQQGroup(group_id=123), T_Target("weibo_id1"), "weibo"
)
res = await config.get_platform_target("weibo")
assert len(res) == 2
await config.del_subscribe(123, "group", T_Target("weibo_id"), "weibo")
await config.del_subscribe(
TargetQQGroup(group_id=123), T_Target("weibo_id"), "weibo"
)
res = await config.get_platform_target("weibo")
assert len(res) == 1
@ -175,6 +174,7 @@ async def test_get_platform_target(app: App, init_scheduler):
async def test_get_platform_target_subscribers(app: App, init_scheduler):
from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.expression import select
@ -182,12 +182,10 @@ async def test_get_platform_target_subscribers(app: App, init_scheduler):
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as T_Target
from nonebot_bison.types import User as T_User
from nonebot_bison.types import UserSubInfo
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=T_Target("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
@ -195,8 +193,7 @@ async def test_get_platform_target_subscribers(app: App, init_scheduler):
tags=["tag1"],
)
await config.add_subscribe(
user=123,
user_type="group",
TargetQQGroup(group_id=123),
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
@ -204,8 +201,7 @@ async def test_get_platform_target_subscribers(app: App, init_scheduler):
tags=["tag2"],
)
await config.add_subscribe(
user=245,
user_type="group",
TargetQQGroup(group_id=245),
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
@ -215,9 +211,9 @@ async def test_get_platform_target_subscribers(app: App, init_scheduler):
res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id"))
assert len(res) == 1
assert res[0] == UserSubInfo(T_User(123, "group"), [1], ["tag1"])
assert res[0] == UserSubInfo(TargetQQGroup(group_id=123), [1], ["tag1"])
res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id1"))
assert len(res) == 2
assert UserSubInfo(T_User(123, "group"), [2], ["tag2"]) in res
assert UserSubInfo(T_User(245, "group"), [3], ["tag3"]) in res
assert UserSubInfo(TargetQQGroup(group_id=123), [2], ["tag2"]) in res
assert UserSubInfo(TargetQQGroup(group_id=245), [3], ["tag3"]) in res

View File

@ -67,9 +67,11 @@ async def app(tmp_path: Path, request: pytest.FixtureRequest, mocker: MockerFixt
@pytest.fixture
def dummy_user_subinfo(app: App):
from nonebot_bison.types import User, UserSubInfo
from nonebot_plugin_saa import TargetQQGroup
user = User(123, "group")
from nonebot_bison.types import UserSubInfo
user = TargetQQGroup(group_id=123)
return UserSubInfo(user=user, categories=[], tags=[])

View File

@ -16,9 +16,11 @@ def bili_live(app: App):
@pytest.fixture
def dummy_only_open_user_subinfo(app: App):
from nonebot_bison.types import User, UserSubInfo
from nonebot_plugin_saa import TargetQQGroup
user = User(123, "group")
from nonebot_bison.types import UserSubInfo
user = TargetQQGroup(group_id=123)
return UserSubInfo(user=user, categories=[1], tags=[])
@ -68,9 +70,11 @@ async def test_fetch_bililive_only_live_open(bili_live, dummy_only_open_user_sub
@pytest.fixture
def dummy_only_title_user_subinfo(app: App):
from nonebot_bison.types import User, UserSubInfo
from nonebot_plugin_saa import TargetQQGroup
user = User(123, "group")
from nonebot_bison.types import UserSubInfo
user = TargetQQGroup(group_id=123)
return UserSubInfo(user=user, categories=[2], tags=[])
@ -128,9 +132,11 @@ async def test_fetch_bililive_only_title_change(
@pytest.fixture
def dummy_only_close_user_subinfo(app: App):
from nonebot_bison.types import User, UserSubInfo
from nonebot_plugin_saa import TargetQQGroup
user = User(123, "group")
from nonebot_bison.types import UserSubInfo
user = TargetQQGroup(group_id=123)
return UserSubInfo(user=user, categories=[3], tags=[])
@ -187,9 +193,11 @@ async def test_fetch_bililive_only_close(bili_live, dummy_only_close_user_subinf
@pytest.fixture
def dummy_bililive_user_subinfo(app: App):
from nonebot_bison.types import User, UserSubInfo
from nonebot_plugin_saa import TargetQQGroup
user = User(123, "group")
from nonebot_bison.types import UserSubInfo
user = TargetQQGroup(group_id=123)
return UserSubInfo(user=user, categories=[1, 2, 3], tags=[])

View File

@ -24,9 +24,9 @@ raw_post_list_2 = raw_post_list_1 + [
@pytest.fixture
def dummy_user(app: App):
from nonebot_bison.types import User
from nonebot_plugin_saa import TargetQQGroup
user = User(123, "group")
user = TargetQQGroup(group_id=123)
return user

View File

@ -12,6 +12,8 @@ if typing.TYPE_CHECKING:
async def get_schedule_times(
scheduler_config: Type["SchedulerConfig"], time: int
) -> dict[str, int]:
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.scheduler import scheduler_dict
scheduler = scheduler_dict[scheduler_config]
@ -25,6 +27,8 @@ async def get_schedule_times(
async def test_scheduler_without_time(init_scheduler):
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config
from nonebot_bison.config.db_config import WeightConfig
from nonebot_bison.platform.bilibili import BilibiliSchedConf
@ -32,13 +36,13 @@ async def test_scheduler_without_time(init_scheduler):
from nonebot_bison.types import Target as T_Target
await config.add_subscribe(
123, "group", T_Target("t1"), "target1", "bilibili", [], []
TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili", [], []
)
await config.add_subscribe(
123, "group", T_Target("t2"), "target1", "bilibili", [], []
TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili", [], []
)
await config.add_subscribe(
123, "group", T_Target("t2"), "target1", "bilibili-live", [], []
TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili-live", [], []
)
await config.update_time_weight_config(
@ -62,6 +66,8 @@ async def test_scheduler_without_time(init_scheduler):
async def test_scheduler_with_time(app: App, init_scheduler, mocker: MockerFixture):
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config, db_config
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig
from nonebot_bison.platform.bilibili import BilibiliSchedConf
@ -69,13 +75,13 @@ async def test_scheduler_with_time(app: App, init_scheduler, mocker: MockerFixtu
from nonebot_bison.types import Target as T_Target
await config.add_subscribe(
123, "group", T_Target("t1"), "target1", "bilibili", [], []
TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili", [], []
)
await config.add_subscribe(
123, "group", T_Target("t2"), "target1", "bilibili", [], []
TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili", [], []
)
await config.add_subscribe(
123, "group", T_Target("t2"), "target1", "bilibili-live", [], []
TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili-live", [], []
)
await config.update_time_weight_config(
@ -113,38 +119,42 @@ async def test_scheduler_with_time(app: App, init_scheduler, mocker: MockerFixtu
async def test_scheduler_add_new(init_scheduler):
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config
from nonebot_bison.platform.bilibili import BilibiliSchedConf
from nonebot_bison.scheduler.manager import init_scheduler
from nonebot_bison.types import Target as T_Target
await config.add_subscribe(
123, "group", T_Target("t1"), "target1", "bilibili", [], []
TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili", [], []
)
await init_scheduler()
await config.add_subscribe(
2345, "group", T_Target("t1"), "target1", "bilibili", [], []
TargetQQGroup(group_id=2345), T_Target("t1"), "target1", "bilibili", [], []
)
await config.add_subscribe(
123, "group", T_Target("t2"), "target2", "bilibili", [], []
TargetQQGroup(group_id=123), T_Target("t2"), "target2", "bilibili", [], []
)
stat_res = await get_schedule_times(BilibiliSchedConf, 1)
assert stat_res["bilibili-t2"] == 1
async def test_schedule_delete(init_scheduler):
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config
from nonebot_bison.platform.bilibili import BilibiliSchedConf
from nonebot_bison.scheduler.manager import init_scheduler
from nonebot_bison.types import Target as T_Target
await config.add_subscribe(
123, "group", T_Target("t1"), "target1", "bilibili", [], []
TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili", [], []
)
await config.add_subscribe(
123, "group", T_Target("t2"), "target1", "bilibili", [], []
TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili", [], []
)
await init_scheduler()
@ -153,6 +163,6 @@ async def test_schedule_delete(init_scheduler):
assert stat_res["bilibili-t2"] == 1
assert stat_res["bilibili-t1"] == 1
await config.del_subscribe(123, "group", T_Target("t1"), "bilibili")
await config.del_subscribe(TargetQQGroup(group_id=123), T_Target("t1"), "bilibili")
stat_res = await get_schedule_times(BilibiliSchedConf, 2)
assert stat_res["bilibili-t2"] == 2

View File

@ -279,6 +279,7 @@ async def test_abort_add_on_tag(app: App, init_scheduler):
async def test_abort_del_sub(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.bot import Bot
from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config
from nonebot_bison.config_manager import del_sub_matcher
@ -286,8 +287,7 @@ async def test_abort_del_sub(app: App, init_scheduler):
from nonebot_bison.types import Target as T_Target
await config.add_subscribe(
10000,
"group",
TargetQQGroup(group_id=10000),
T_Target("6279793937"),
"明日方舟Arknights",
"weibo",
@ -316,5 +316,5 @@ async def test_abort_del_sub(app: App, init_scheduler):
ctx.receive_event(bot, event_abort)
ctx.should_call_send(event_abort, "删除中止", True)
ctx.should_finished()
subs = await config.list_subscribe(10000, "group")
subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert subs

View File

@ -64,6 +64,7 @@ async def test_configurable_at_me_false(app: App):
async def test_add_with_target(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config
from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -171,7 +172,7 @@ async def test_add_with_target(app: App, init_scheduler):
event_6_ok, BotReply.add_reply_subscribe_success("明日方舟Arknights"), True
)
ctx.should_finished()
subs = await config.list_subscribe(10000, "group")
subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert len(subs) == 1
sub = subs[0]
assert sub.target.target == "6279793937"
@ -188,6 +189,7 @@ async def test_add_with_target(app: App, init_scheduler):
async def test_add_with_target_no_cat(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config
from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -233,7 +235,7 @@ async def test_add_with_target_no_cat(app: App, init_scheduler):
event_4_ok, BotReply.add_reply_subscribe_success("塞壬唱片-MSR"), True
)
ctx.should_finished()
subs = await config.list_subscribe(10000, "group")
subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert len(subs) == 1
sub = subs[0]
assert sub.target.target == "32540734"
@ -248,6 +250,7 @@ async def test_add_with_target_no_cat(app: App, init_scheduler):
async def test_add_no_target(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config
from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -284,7 +287,7 @@ async def test_add_no_target(app: App, init_scheduler):
event_4, BotReply.add_reply_subscribe_success("明日方舟游戏信息"), True
)
ctx.should_finished()
subs = await config.list_subscribe(10000, "group")
subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert len(subs) == 1
sub = subs[0]
assert sub.target.target == "default"
@ -334,6 +337,7 @@ async def test_platform_name_err(app: App):
async def test_add_with_get_id(app: App):
from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message, MessageSegment
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config
from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -407,7 +411,7 @@ async def test_add_with_get_id(app: App):
True,
)
ctx.should_finished()
subs = await config.list_subscribe(10000, "group")
subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert len(subs) == 0
@ -416,6 +420,7 @@ async def test_add_with_get_id(app: App):
async def test_add_with_bilibili_target_parser(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config
from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -524,7 +529,7 @@ async def test_add_with_bilibili_target_parser(app: App, init_scheduler):
event_6, BotReply.add_reply_subscribe_success("明日方舟"), True
)
ctx.should_finished()
subs = await config.list_subscribe(10000, "group")
subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert len(subs) == 1
sub = subs[0]
assert sub.target.target == "161775300"

View File

@ -10,6 +10,7 @@ from .utils import fake_admin_user, fake_group_message_event
@pytest.mark.asyncio
async def test_query_sub(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config
from nonebot_bison.config_manager import query_sub_matcher
@ -17,8 +18,7 @@ async def test_query_sub(app: App, init_scheduler):
from nonebot_bison.types import Target
await config.add_subscribe(
10000,
"group",
TargetQQGroup(group_id=10000),
Target("6279793937"),
"明日方舟Arknights",
"weibo",
@ -40,6 +40,7 @@ async def test_query_sub(app: App, init_scheduler):
async def test_del_sub(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.bot import Bot
from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config
from nonebot_bison.config_manager import del_sub_matcher
@ -47,8 +48,7 @@ async def test_del_sub(app: App, init_scheduler):
from nonebot_bison.types import Target
await config.add_subscribe(
10000,
"group",
TargetQQGroup(group_id=10000),
Target("6279793937"),
"明日方舟Arknights",
"weibo",
@ -83,7 +83,7 @@ async def test_del_sub(app: App, init_scheduler):
ctx.receive_event(bot, event_1_ok)
ctx.should_call_send(event_1_ok, "删除成功", True)
ctx.should_finished()
subs = await config.list_subscribe(10000, "group")
subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert len(subs) == 0

View File

@ -30,8 +30,8 @@ async def test_refresh_bots(app: App) -> None:
from nonebot import get_driver
from nonebot.adapters.onebot.v11 import Bot as BotV11
from nonebot.adapters.onebot.v12 import Bot as BotV12
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from nonebot_bison.types import User
from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots
async with app.test_api() as ctx:
@ -44,13 +44,13 @@ async def test_refresh_bots(app: App) -> None:
ctx.should_call_api("get_group_list", {}, [{"group_id": 1}])
ctx.should_call_api("get_friend_list", {}, [{"user_id": 2}])
assert get_bot(User(1, "group")) is None
assert get_bot(User(2, "private")) is None
assert get_bot(TargetQQGroup(group_id=1)) is None
assert get_bot(TargetQQPrivate(user_id=2)) is None
await refresh_bots()
assert get_bot(User(1, "group")) == botv11
assert get_bot(User(2, "private")) == botv11
assert get_bot(TargetQQGroup(group_id=1)) == botv11
assert get_bot(TargetQQPrivate(user_id=2)) == botv11
# 测试获取群列表
ctx.should_call_api("get_group_list", {}, [{"group_id": 3}])
@ -66,8 +66,8 @@ async def test_get_bot_two_bots(app: App) -> None:
from nonebot import get_driver
from nonebot.adapters.onebot.v11 import Bot as BotV11
from nonebot.adapters.onebot.v12 import Bot as BotV12
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from nonebot_bison.types import User
from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots
async with app.test_api() as ctx:
@ -85,14 +85,14 @@ async def test_get_bot_two_bots(app: App) -> None:
await refresh_bots()
assert get_bot(User(0, "group")) is None
assert get_bot(User(1, "group")) == bot1
assert get_bot(User(2, "group")) in (bot1, bot2)
assert get_bot(User(3, "group")) == bot2
assert get_bot(User(0, "private")) is None
assert get_bot(User(1, "private")) == bot1
assert get_bot(User(2, "private")) in (bot1, bot2)
assert get_bot(User(3, "private")) == bot2
assert get_bot(TargetQQGroup(group_id=0)) is None
assert get_bot(TargetQQGroup(group_id=1)) == bot1
assert get_bot(TargetQQGroup(group_id=2)) in (bot1, bot2)
assert get_bot(TargetQQGroup(group_id=3)) == bot2
assert get_bot(TargetQQPrivate(user_id=0)) is None
assert get_bot(TargetQQPrivate(user_id=1)) == bot1
assert get_bot(TargetQQPrivate(user_id=2)) in (bot1, bot2)
assert get_bot(TargetQQPrivate(user_id=3)) == bot2
ctx.should_call_api("get_group_list", {}, [{"group_id": 1}, {"group_id": 2}])
ctx.should_call_api("get_group_list", {}, [{"group_id": 2}, {"group_id": 3}])