🚧 remove User type

This commit is contained in:
felinae98 2023-03-20 14:39:04 +08:00
parent d535f5212d
commit 4118329bb0
24 changed files with 298 additions and 221 deletions

View File

@ -4,6 +4,7 @@ from fastapi.exceptions import HTTPException
from fastapi.param_functions import Depends from fastapi.param_functions import Depends
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from fastapi.security.oauth2 import OAuth2PasswordBearer from fastapi.security.oauth2 import OAuth2PasswordBearer
from nonebot_plugin_saa import TargetQQGroup
from ..apis import check_sub_target from ..apis import check_sub_target
from ..config import ( from ..config import (
@ -15,7 +16,7 @@ from ..config import (
from ..config.db_config import SubscribeDupException from ..config.db_config import SubscribeDupException
from ..platform import platform_manager from ..platform import platform_manager
from ..types import Target as T_Target from ..types import Target as T_Target
from ..types import User, WeightConfig from ..types import WeightConfig
from ..utils.get_bot import get_bot, get_groups from ..utils.get_bot import get_bot, get_groups
from .jwt import load_jwt, pack_jwt from .jwt import load_jwt, pack_jwt
from .token_manager import token_manager from .token_manager import token_manager
@ -75,7 +76,7 @@ async def get_admin_groups(qq: int):
res = [] res = []
for group in await get_groups(): for group in await get_groups():
group_id = group["group_id"] group_id = group["group_id"]
bot = get_bot(User(group_id, "group")) bot = get_bot(TargetQQGroup(group_id=group_id))
if not bot: if not bot:
continue continue
users = await bot.get_group_member_list(group_id=group_id) users = await bot.get_group_member_list(group_id=group_id)
@ -131,7 +132,7 @@ async def get_subs_info(jwt_obj: dict = Depends(get_jwt_obj)) -> SubscribeResp:
res: SubscribeResp = {} res: SubscribeResp = {}
for group in groups: for group in groups:
group_id = group["id"] group_id = group["id"]
raw_subs = await config.list_subscribe(group_id, "group") raw_subs = await config.list_subscribe(TargetQQGroup(group_id=group_id))
subs = list( subs = list(
map( map(
lambda sub: SubscribeConfig( lambda sub: SubscribeConfig(
@ -157,8 +158,7 @@ async def get_target_name(platformName: str, target: str):
async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp: async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp:
try: try:
await config.add_subscribe( await config.add_subscribe(
int(groupNumber), TargetQQGroup(group_id=groupNumber),
"group",
T_Target(req.target), T_Target(req.target),
req.targetName, req.targetName,
req.platformName, req.platformName,
@ -173,7 +173,9 @@ async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp:
@router.delete("/subs", dependencies=[Depends(check_group_permission)]) @router.delete("/subs", dependencies=[Depends(check_group_permission)])
async def del_group_sub(groupNumber: int, platformName: str, target: str): async def del_group_sub(groupNumber: int, platformName: str, target: str):
try: try:
await config.del_subscribe(int(groupNumber), "group", target, platformName) await config.del_subscribe(
TargetQQGroup(group_id=groupNumber), target, platformName
)
except (NoSuchUserException, NoSuchSubscribeException): except (NoSuchUserException, NoSuchSubscribeException):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe") raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe")
return StatusResp(ok=True, msg="") return StatusResp(ok=True, msg="")

View File

@ -4,15 +4,14 @@ from datetime import datetime, time
from typing import Awaitable, Callable, Optional, Sequence from typing import Awaitable, Callable, Optional, Sequence
from nonebot_plugin_datastore import create_session from nonebot_plugin_datastore import create_session
from nonebot_plugin_saa import PlatformTarget
from sqlalchemy import delete, func, select from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from ..types import Category, PlatformWeightConfigResp, Tag from ..types import Category, PlatformWeightConfigResp, Tag
from ..types import Target as T_Target from ..types import Target as T_Target
from ..types import TimeWeightConfig from ..types import TimeWeightConfig, UserSubInfo, WeightConfig
from ..types import User as T_User
from ..types import UserSubInfo, WeightConfig
from .db_model import ScheduleTimeWeight, Subscribe, Target, User from .db_model import ScheduleTimeWeight, Subscribe, Target, User
from .utils import NoSuchTargetException from .utils import NoSuchTargetException
@ -40,8 +39,7 @@ class DBConfig:
async def add_subscribe( async def add_subscribe(
self, self,
user: int, user: PlatformTarget,
user_type: str,
target: T_Target, target: T_Target,
target_name: str, target_name: str,
platform_name: str, platform_name: str,
@ -49,12 +47,10 @@ class DBConfig:
tags: list[Tag], tags: list[Tag],
): ):
async with create_session() as session: async with create_session() as session:
db_user_stmt = ( db_user_stmt = select(User).where(User.user_target == user.dict())
select(User).where(User.uid == user).where(User.type == user_type)
)
db_user: Optional[User] = await session.scalar(db_user_stmt) db_user: Optional[User] = await session.scalar(db_user_stmt)
if not db_user: if not db_user:
db_user = User(uid=user, type=user_type) db_user = User(user_target=user.dict())
session.add(db_user) session.add(db_user)
db_target_stmt = ( db_target_stmt = (
select(Target) select(Target)
@ -85,11 +81,11 @@ class DBConfig:
raise SubscribeDupException() raise SubscribeDupException()
raise e raise e
async def list_subscribe(self, user: int, user_type: str) -> Sequence[Subscribe]: async def list_subscribe(self, user: PlatformTarget) -> Sequence[Subscribe]:
async with create_session() as session: async with create_session() as session:
query_stmt = ( query_stmt = (
select(Subscribe) select(Subscribe)
.where(User.type == user_type, User.uid == user) .where(User.user_target == user.dict())
.join(User) .join(User)
.options(selectinload(Subscribe.target)) .options(selectinload(Subscribe.target))
) )
@ -109,11 +105,11 @@ class DBConfig:
return subs return subs
async def del_subscribe( async def del_subscribe(
self, user: int, user_type: str, target: str, platform_name: str self, user: PlatformTarget, target: str, platform_name: str
): ):
async with create_session() as session: async with create_session() as session:
user_obj = await session.scalar( user_obj = await session.scalar(
select(User).where(User.uid == user, User.type == user_type) select(User).where(User.user_target == user.dict())
) )
target_obj = await session.scalar( target_obj = await session.scalar(
select(Target).where( select(Target).where(
@ -142,8 +138,7 @@ class DBConfig:
async def update_subscribe( async def update_subscribe(
self, self,
user: int, user: PlatformTarget,
user_type: str,
target: str, target: str,
target_name: str, target_name: str,
platform_name: str, platform_name: str,
@ -154,8 +149,7 @@ class DBConfig:
subscribe_obj: Subscribe = await sess.scalar( subscribe_obj: Subscribe = await sess.scalar(
select(Subscribe) select(Subscribe)
.where( .where(
User.uid == user, User.user_target == user.dict(),
User.type == user_type,
Target.target == target, Target.target == target,
Target.platform_name == platform_name, Target.platform_name == platform_name,
) )
@ -272,7 +266,7 @@ class DBConfig:
return list( return list(
map( map(
lambda subscribe: UserSubInfo( lambda subscribe: UserSubInfo(
T_User(subscribe.user.uid, subscribe.user.type), PlatformTarget.deserialize(subscribe.user.user_target),
subscribe.categories, subscribe.categories,
subscribe.tags, subscribe.tags,
), ),

View File

@ -1,5 +1,6 @@
from nonebot.log import logger from nonebot.log import logger
from nonebot_plugin_datastore.db import get_engine from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio.session import AsyncSession
from .config_legacy import Config, ConfigContent, drop from .config_legacy import Config, ConfigContent, drop
@ -21,7 +22,11 @@ async def data_migrate():
subscribe_to_create = [] subscribe_to_create = []
platform_target_map: dict[str, tuple[Target, str, int]] = {} platform_target_map: dict[str, tuple[Target, str, int]] = {}
for user in all_subs: for user in all_subs:
db_user = User(uid=user["user"], type=user["user_type"]) if user["user_type"] == "group":
user_target = TargetQQGroup(group_id=user["user"])
else:
user_target = TargetQQPrivate(user_id=user["user"])
db_user = User(user_target=user_target.dict())
user_to_create.append(db_user) user_to_create.append(db_user)
user_sub_set = set() user_sub_set = set()
for sub in user["subs"]: for sub in user["subs"]:

View File

@ -14,15 +14,15 @@ get_plugin_data().set_migration_dir(Path(__file__).parent / "migrations")
class User(Model): class User(Model):
__table_args__ = (UniqueConstraint("type", "uid", name="unique-user-constraint"),)
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
type: Mapped[str] = mapped_column(String(20)) user_target: Mapped[dict] = mapped_column(JSON)
uid: Mapped[int]
user_target: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
subscribes: Mapped[list["Subscribe"]] = relationship(back_populates="user") subscribes: Mapped[list["Subscribe"]] = relationship(back_populates="user")
@property
def saa_target(self) -> PlatformTarget:
return PlatformTarget.deserialize(self.user_target)
class Target(Model): class Target(Model):
__table_args__ = ( __table_args__ = (

View File

@ -18,6 +18,7 @@ depends_on = None
def upgrade() -> None: def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.drop_constraint("unique-user-constraint", type_="unique")
batch_op.add_column(sa.Column("user_target", sa.JSON(), nullable=True)) batch_op.add_column(sa.Column("user_target", sa.JSON(), nullable=True))
# ### end Alembic commands ### # ### end Alembic commands ###
@ -27,5 +28,6 @@ def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.drop_column("user_target") batch_op.drop_column("user_target")
batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"])
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@ -0,0 +1,34 @@
"""make user_target not nullable
Revision ID: 67c38b3f39c2
Revises: a5466912fad0
Create Date: 2023-03-20 11:08:42.883556
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import sqlite
# revision identifiers, used by Alembic.
revision = "67c38b3f39c2"
down_revision = "a5466912fad0"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.alter_column(
"user_target", existing_type=sqlite.JSON(), nullable=False
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.alter_column("user_target", existing_type=sqlite.JSON(), nullable=True)
# ### end Alembic commands ###

View File

@ -0,0 +1,33 @@
"""remove uid and type
Revision ID: 8d3863e9d74b
Revises: 67c38b3f39c2
Create Date: 2023-03-20 15:38:20.220599
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "8d3863e9d74b"
down_revision = "67c38b3f39c2"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.drop_column("uid")
batch_op.drop_column("type")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.add_column(sa.Column("type", sa.VARCHAR(length=20), nullable=False))
batch_op.add_column(sa.Column("uid", sa.INTEGER(), nullable=False))
# ### end Alembic commands ###

View File

@ -15,13 +15,14 @@ from nonebot.params import Depends, EventPlainText, EventToMe
from nonebot.permission import SUPERUSER from nonebot.permission import SUPERUSER
from nonebot.rule import to_me from nonebot.rule import to_me
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot_plugin_saa import PlatformTarget, TargetQQGroup, extract_target
from .apis import check_sub_target from .apis import check_sub_target
from .config import config from .config import config
from .config.db_config import SubscribeDupException from .config.db_config import SubscribeDupException
from .platform import Platform, platform_manager from .platform import Platform, platform_manager
from .plugin_config import plugin_config from .plugin_config import plugin_config
from .types import Category, Target, User from .types import Category, Target
from .utils import parse_text from .utils import parse_text
@ -61,12 +62,8 @@ def ensure_user_info(matcher: Type[Matcher]):
async def set_target_user_info(event: MessageEvent, state: T_State): async def set_target_user_info(event: MessageEvent, state: T_State):
if isinstance(event, GroupMessageEvent): user = extract_target(event)
user = User(event.group_id, "group") state["target_user_info"] = user
state["target_user_info"] = user
elif isinstance(event, PrivateMessageEvent):
user = User(event.user_id, "private")
state["target_user_info"] = user
def do_add_sub(add_sub: Type[Matcher]): def do_add_sub(add_sub: Type[Matcher]):
@ -201,14 +198,11 @@ def do_add_sub(add_sub: Type[Matcher]):
@add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)]) @add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)])
async def add_sub_process(event: Event, state: T_State): async def add_sub_process(event: Event, state: T_State):
user = cast(User, state.get("target_user_info")) user = cast(PlatformTarget, state.get("target_user_info"))
assert isinstance(user, User) assert isinstance(user, PlatformTarget)
try: try:
await config.add_subscribe( await config.add_subscribe(
# state.get("_user_id") or event.group_id, user=user,
# user_type="group",
user=user.user,
user_type=user.user_type,
target=state["id"], target=state["id"],
target_name=state["name"], target_name=state["name"],
platform_name=state["platform"], platform_name=state["platform"],
@ -228,12 +222,8 @@ def do_query_sub(query_sub: Type[Matcher]):
@query_sub.handle() @query_sub.handle()
async def _(state: T_State): async def _(state: T_State):
user_info = state["target_user_info"] user_info = state["target_user_info"]
assert isinstance(user_info, User) assert isinstance(user_info, PlatformTarget)
sub_list = await config.list_subscribe( sub_list = await config.list_subscribe(user_info)
# state.get("_user_id") or event.group_id, "group"
user_info.user,
user_info.user_type,
)
res = "订阅的帐号为:\n" res = "订阅的帐号为:\n"
for sub in sub_list: for sub in sub_list:
res += "{} {} {}".format( res += "{} {} {}".format(
@ -261,13 +251,9 @@ def do_del_sub(del_sub: Type[Matcher]):
@del_sub.handle() @del_sub.handle()
async def send_list(bot: Bot, event: Event, state: T_State): async def send_list(bot: Bot, event: Event, state: T_State):
user_info = state["target_user_info"] user_info = state["target_user_info"]
assert isinstance(user_info, User) assert isinstance(user_info, PlatformTarget)
try: try:
sub_list = await config.list_subscribe( sub_list = await config.list_subscribe(user_info)
# state.get("_user_id") or event.group_id, "group"
user_info.user,
user_info.user_type,
)
assert sub_list assert sub_list
except AssertionError: except AssertionError:
await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅") await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅")
@ -309,14 +295,8 @@ def do_del_sub(del_sub: Type[Matcher]):
try: try:
index = int(user_msg) index = int(user_msg)
user_info = state["target_user_info"] user_info = state["target_user_info"]
assert isinstance(user_info, User) assert isinstance(user_info, PlatformTarget)
await config.del_subscribe( await config.del_subscribe(user_info, **state["sub_table"][index])
# state.get("_user_id") or event.group_id,
# "group",
user_info.user,
user_info.user_type,
**state["sub_table"][index],
)
except Exception as e: except Exception as e:
await del_sub.reject("删除错误") await del_sub.reject("删除错误")
else: else:
@ -398,7 +378,7 @@ async def do_choose_group_number(state: T_State):
group_number_idx: dict[int, int] = state["group_number_idx"] group_number_idx: dict[int, int] = state["group_number_idx"]
idx: int = state["group_idx"] idx: int = state["group_idx"]
group_id = group_number_idx[idx] group_id = group_number_idx[idx]
state["target_user_info"] = User(user=group_id, user_type="group") state["target_user_info"] = TargetQQGroup(group_id=group_id)
async def _check_command(event_msg: str = EventPlainText()): async def _check_command(event_msg: str = EventPlainText()):

View File

@ -10,10 +10,11 @@ from typing import Any, Collection, Optional, Type
import httpx import httpx
from httpx import AsyncClient from httpx import AsyncClient
from nonebot.log import logger from nonebot.log import logger
from nonebot_plugin_saa import PlatformTarget
from ..plugin_config import plugin_config from ..plugin_config import plugin_config
from ..post import Post from ..post import Post
from ..types import Category, RawPost, Tag, Target, User, UserSubInfo from ..types import Category, RawPost, Tag, Target, UserSubInfo
from ..utils import ProcessContext, SchedulerConfig from ..utils import ProcessContext, SchedulerConfig
@ -84,12 +85,12 @@ class Platform(metaclass=PlatformABCMeta, base=True):
@abstractmethod @abstractmethod
async def fetch_new_post( async def fetch_new_post(
self, target: Target, users: list[UserSubInfo] self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]: ) -> list[tuple[PlatformTarget, list[Post]]]:
... ...
async def do_fetch_new_post( async def do_fetch_new_post(
self, target: Target, users: list[UserSubInfo] self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]: ) -> list[tuple[PlatformTarget, list[Post]]]:
try: try:
return await self.fetch_new_post(target, users) return await self.fetch_new_post(target, users)
except httpx.RequestError as err: except httpx.RequestError as err:
@ -197,8 +198,8 @@ class Platform(metaclass=PlatformABCMeta, base=True):
async def dispatch_user_post( async def dispatch_user_post(
self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo] self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]: ) -> list[tuple[PlatformTarget, list[Post]]]:
res: list[tuple[User, list[Post]]] = [] res: list[tuple[PlatformTarget, list[Post]]] = []
for user, cats, required_tags in users: for user, cats, required_tags in users:
user_raw_post = await self.filter_user_custom( user_raw_post = await self.filter_user_custom(
new_posts, cats, required_tags new_posts, cats, required_tags
@ -314,7 +315,7 @@ class NewMessage(MessageProcess, abstract=True):
async def fetch_new_post( async def fetch_new_post(
self, target: Target, users: list[UserSubInfo] self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]: ) -> list[tuple[PlatformTarget, list[Post]]]:
post_list = await self.get_sub_list(target) post_list = await self.get_sub_list(target)
new_posts = await self.filter_common_with_diff(target, post_list) new_posts = await self.filter_common_with_diff(target, post_list)
if not new_posts: if not new_posts:
@ -353,7 +354,7 @@ class StatusChange(Platform, abstract=True):
async def fetch_new_post( async def fetch_new_post(
self, target: Target, users: list[UserSubInfo] self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]: ) -> list[tuple[PlatformTarget, list[Post]]]:
try: try:
new_status = await self.get_status(target) new_status = await self.get_status(target)
except self.FetchError as err: except self.FetchError as err:
@ -381,7 +382,7 @@ class SimplePost(MessageProcess, abstract=True):
async def fetch_new_post( async def fetch_new_post(
self, target: Target, users: list[UserSubInfo] self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]: ) -> list[tuple[PlatformTarget, list[Post]]]:
new_posts = await self.get_sub_list(target) new_posts = await self.get_sub_list(target)
if not new_posts: if not new_posts:
return [] return []

View File

@ -3,6 +3,7 @@ from datetime import time
from typing import Any, Literal, NamedTuple, NewType from typing import Any, Literal, NamedTuple, NewType
from httpx import URL from httpx import URL
from nonebot_plugin_saa import PlatformTarget as SendTarget
from pydantic import BaseModel from pydantic import BaseModel
RawPost = Any RawPost = Any
@ -25,7 +26,7 @@ class PlatformTarget:
class UserSubInfo(NamedTuple): class UserSubInfo(NamedTuple):
user: User user: SendTarget
categories: list[Category] categories: list[Category]
tags: list[Tag] tags: list[Tag]

View File

@ -1,53 +1,57 @@
""" 提供获取 Bot 的方法 """ """ 提供获取 Bot 的方法 """
import random import random
from collections import defaultdict
from typing import Any, Optional from typing import Any, Optional
import nonebot import nonebot
from nonebot import get_driver, on_notice from nonebot import get_driver, on_notice
from nonebot.adapters import Bot
from nonebot.adapters.onebot.v11 import Bot as Ob11Bot
from nonebot.adapters.onebot.v11 import ( from nonebot.adapters.onebot.v11 import (
Bot,
FriendAddNoticeEvent, FriendAddNoticeEvent,
GroupDecreaseNoticeEvent, GroupDecreaseNoticeEvent,
GroupIncreaseNoticeEvent, GroupIncreaseNoticeEvent,
) )
from nonebot_plugin_saa import PlatformTarget, TargetQQGroup, TargetQQPrivate
from ..types import User
GROUP: dict[int, list[Bot]] = {} GROUP: dict[int, list[Bot]] = {}
USER: dict[int, list[Bot]] = {} USER: dict[int, list[Bot]] = {}
BOT_CACHE: dict[PlatformTarget, list[Bot]] = defaultdict(list)
def get_bots() -> list[Bot]: def get_bots() -> list[Bot]:
"""获取所有 OneBot 11 Bot""" """获取所有 OneBot 11 Bot"""
# TODO: support ob12
bots = [] bots = []
for bot in nonebot.get_bots().values(): for bot in nonebot.get_bots().values():
if isinstance(bot, Bot): if isinstance(bot, Ob11Bot):
bots.append(bot) bots.append(bot)
return bots return bots
async def _refresh_ob11(bot: Ob11Bot):
# 获取群列表
groups = await bot.get_group_list()
for group in groups:
group_id = group["group_id"]
target = TargetQQGroup(group_id=group_id)
BOT_CACHE[target].append(bot)
# 获取好友列表
users = await bot.get_friend_list()
for user in users:
user_id = user["user_id"]
target = TargetQQPrivate(user_id=user_id)
BOT_CACHE[target].append(bot)
async def refresh_bots(): async def refresh_bots():
"""刷新缓存的 Bot 数据""" """刷新缓存的 Bot 数据"""
GROUP.clear() BOT_CACHE.clear()
USER.clear()
for bot in get_bots(): for bot in get_bots():
# 获取群列表 match bot:
groups = await bot.get_group_list() case Ob11Bot():
for group in groups: await _refresh_ob11(bot)
group_id = group["group_id"]
if group_id not in GROUP:
GROUP[group_id] = [bot]
else:
GROUP[group_id].append(bot)
# 获取好友列表
users = await bot.get_friend_list()
for user in users:
user_id = user["user_id"]
if user_id not in USER:
USER[user_id] = [bot]
else:
USER[user_id].append(bot)
driver = get_driver() driver = get_driver()
@ -75,15 +79,9 @@ async def _(bot: Bot, event: GroupDecreaseNoticeEvent | GroupIncreaseNoticeEvent
await refresh_bots() await refresh_bots()
def get_bot(user: User) -> Optional[Bot]: def get_bot(user: PlatformTarget) -> Optional[Bot]:
"""获取 Bot""" """获取 Bot"""
bots = [] bots = BOT_CACHE.get(user)
if user.user_type == "group":
bots = GROUP.get(user.user, [])
if user.user_type == "private":
bots = USER.get(user.user, [])
if not bots: if not bots:
return return
@ -92,6 +90,7 @@ def get_bot(user: User) -> Optional[Bot]:
async def get_groups() -> list[dict[str, Any]]: async def get_groups() -> list[dict[str, Any]]:
"""获取所有群号""" """获取所有群号"""
# TODO
all_groups: dict[int, dict[str, Any]] = {} all_groups: dict[int, dict[str, Any]] = {}
for bot in get_bots(): for bot in get_bots():
groups = await bot.get_group_list() groups = await bot.get_group_list()

22
poetry.lock generated
View File

@ -1470,16 +1470,20 @@ name = "nonebot-plugin-send-anything-anywhere"
version = "0.2.4" version = "0.2.4"
description = "An adaptor for nonebot2 adaptors" description = "An adaptor for nonebot2 adaptors"
optional = false optional = false
python-versions = ">=3.8,<4.0" python-versions = "^3.8"
files = [ files = []
{file = "nonebot_plugin_send_anything_anywhere-0.2.4-py3-none-any.whl", hash = "sha256:97c1c1565479c1750c21ce471545ea293a1f26d606cbe5ae071dab0047200408"}, develop = false
{file = "nonebot_plugin_send_anything_anywhere-0.2.4.tar.gz", hash = "sha256:71217c6bd7f84d6f3d266914562c60dadf9b28e66801c3996d6d7c36bafa7fca"},
]
[package.dependencies] [package.dependencies]
nonebot2 = ">=2.0.0rc1,<3.0.0" nonebot2 = "^2.0.0rc1"
pydantic = ">=1.10.5,<2.0.0" pydantic = "^1.10.5"
strenum = ">=0.4.8,<0.5.0" strenum = "^0.4.8"
[package.source]
type = "git"
url = "https://github.com/felinae98/nonebot-plugin-send-anything-anywhere.git"
reference = "main"
resolved_reference = "7f8a57afc72b5b6a7f909935f1a87411bf597173"
[[package]] [[package]]
name = "nonebot2" name = "nonebot2"
@ -2869,4 +2873,4 @@ yaml = []
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<4.0.0" python-versions = ">=3.10,<4.0.0"
content-hash = "a8af95b0b5285f14d48ba11d7237cf636ca2102e7374d07d6b808eb5fdba8a76" content-hash = "efba4feca911691e91af2b93cb810268f6e35a6e985811587e6b00999c2bd263"

View File

@ -34,7 +34,7 @@ nonebot-adapter-onebot = "^2.0.0-beta.1"
nonebot-plugin-htmlrender = ">=0.2.0" nonebot-plugin-htmlrender = ">=0.2.0"
nonebot-plugin-datastore = "^0.6.2" nonebot-plugin-datastore = "^0.6.2"
nonebot-plugin-apscheduler = "^0.2.0" nonebot-plugin-apscheduler = "^0.2.0"
nonebot-plugin-send-anything-anywhere = "^0.2.1" nonebot-plugin-send-anything-anywhere = {git = "https://github.com/felinae98/nonebot-plugin-send-anything-anywhere.git", rev = "main"}
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
ipdb = "^0.13.4" ipdb = "^0.13.4"

View File

@ -4,6 +4,7 @@ from nonebug.app import App
async def test_add_subscribe(app: App, init_scheduler): async def test_add_subscribe(app: App, init_scheduler):
from nonebot_plugin_datastore.db import get_engine from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup
from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.expression import select from sqlalchemy.sql.expression import select
@ -12,8 +13,7 @@ async def test_add_subscribe(app: App, init_scheduler):
from nonebot_bison.types import Target as TTarget from nonebot_bison.types import Target as TTarget
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=TTarget("weibo_id"), target=TTarget("weibo_id"),
target_name="weibo_name", target_name="weibo_name",
platform_name="weibo", platform_name="weibo",
@ -21,15 +21,14 @@ async def test_add_subscribe(app: App, init_scheduler):
tags=[], tags=[],
) )
await config.add_subscribe( await config.add_subscribe(
user=234, TargetQQGroup(group_id=234),
user_type="group",
target=TTarget("weibo_id"), target=TTarget("weibo_id"),
target_name="weibo_name", target_name="weibo_name",
platform_name="weibo", platform_name="weibo",
cats=[], cats=[],
tags=[], tags=[],
) )
confs = await config.list_subscribe(123, "group") confs = await config.list_subscribe(TargetQQGroup(group_id=123))
assert len(confs) == 1 assert len(confs) == 1
conf: Subscribe = confs[0] conf: Subscribe = confs[0]
async with AsyncSession(get_engine()) as sess: async with AsyncSession(get_engine()) as sess:
@ -39,22 +38,23 @@ async def test_add_subscribe(app: App, init_scheduler):
related_target_obj = await sess.scalar( related_target_obj = await sess.scalar(
select(Target).where(Target.id == conf.target_id) select(Target).where(Target.id == conf.target_id)
) )
assert related_user_obj.uid == 123 assert related_user_obj
assert related_target_obj
assert related_user_obj.user_target["group_id"] == 123
assert related_target_obj.target_name == "weibo_name" assert related_target_obj.target_name == "weibo_name"
assert related_target_obj.target == "weibo_id" assert related_target_obj.target == "weibo_id"
assert conf.target.target == "weibo_id" assert conf.target.target == "weibo_id"
assert conf.categories == [] assert conf.categories == []
await config.update_subscribe( await config.update_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=TTarget("weibo_id"), target=TTarget("weibo_id"),
platform_name="weibo", platform_name="weibo",
target_name="weibo_name2", target_name="weibo_name2",
cats=[1], cats=[1],
tags=["tag"], tags=["tag"],
) )
confs = await config.list_subscribe(123, "group") confs = await config.list_subscribe(TargetQQGroup(group_id=123))
assert len(confs) == 1 assert len(confs) == 1
conf: Subscribe = confs[0] conf: Subscribe = confs[0]
async with AsyncSession(get_engine()) as sess: async with AsyncSession(get_engine()) as sess:
@ -64,7 +64,9 @@ async def test_add_subscribe(app: App, init_scheduler):
related_target_obj = await sess.scalar( related_target_obj = await sess.scalar(
select(Target).where(Target.id == conf.target_id) select(Target).where(Target.id == conf.target_id)
) )
assert related_user_obj.uid == 123 assert related_user_obj
assert related_target_obj
assert related_user_obj.user_target["group_id"] == 123
assert related_target_obj.target_name == "weibo_name2" assert related_target_obj.target_name == "weibo_name2"
assert related_target_obj.target == "weibo_id" assert related_target_obj.target == "weibo_id"
assert conf.target.target == "weibo_id" assert conf.target.target == "weibo_id"
@ -73,12 +75,13 @@ async def test_add_subscribe(app: App, init_scheduler):
async def test_add_dup_sub(init_scheduler): async def test_add_dup_sub(init_scheduler):
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config.db_config import SubscribeDupException, config from nonebot_bison.config.db_config import SubscribeDupException, config
from nonebot_bison.types import Target as TTarget from nonebot_bison.types import Target as TTarget
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=TTarget("weibo_id"), target=TTarget("weibo_id"),
target_name="weibo_name", target_name="weibo_name",
platform_name="weibo", platform_name="weibo",
@ -88,8 +91,7 @@ async def test_add_dup_sub(init_scheduler):
with pytest.raises(SubscribeDupException): with pytest.raises(SubscribeDupException):
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=TTarget("weibo_id"), target=TTarget("weibo_id"),
target_name="weibo_name", target_name="weibo_name",
platform_name="weibo", platform_name="weibo",
@ -100,6 +102,7 @@ async def test_add_dup_sub(init_scheduler):
async def test_del_subsribe(init_scheduler): async def test_del_subsribe(init_scheduler):
from nonebot_plugin_datastore.db import get_engine from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup
from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.expression import select from sqlalchemy.sql.expression import select
from sqlalchemy.sql.functions import func from sqlalchemy.sql.functions import func
@ -109,8 +112,7 @@ async def test_del_subsribe(init_scheduler):
from nonebot_bison.types import Target as TTarget from nonebot_bison.types import Target as TTarget
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=TTarget("weibo_id"), target=TTarget("weibo_id"),
target_name="weibo_name", target_name="weibo_name",
platform_name="weibo", platform_name="weibo",
@ -118,8 +120,7 @@ async def test_del_subsribe(init_scheduler):
tags=[], tags=[],
) )
await config.del_subscribe( await config.del_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=TTarget("weibo_id"), target=TTarget("weibo_id"),
platform_name="weibo", platform_name="weibo",
) )
@ -128,8 +129,7 @@ async def test_del_subsribe(init_scheduler):
assert (await sess.scalar(select(func.count()).select_from(Target))) == 1 assert (await sess.scalar(select(func.count()).select_from(Target))) == 1
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=TTarget("weibo_id"), target=TTarget("weibo_id"),
target_name="weibo_name", target_name="weibo_name",
platform_name="weibo", platform_name="weibo",
@ -138,8 +138,7 @@ async def test_del_subsribe(init_scheduler):
) )
await config.add_subscribe( await config.add_subscribe(
user=124, TargetQQGroup(group_id=124),
user_type="group",
target=TTarget("weibo_id"), target=TTarget("weibo_id"),
target_name="weibo_name_new", target_name="weibo_name_new",
platform_name="weibo", platform_name="weibo",
@ -148,8 +147,7 @@ async def test_del_subsribe(init_scheduler):
) )
await config.del_subscribe( await config.del_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=TTarget("weibo_id"), target=TTarget("weibo_id"),
platform_name="weibo", platform_name="weibo",
) )
@ -157,5 +155,6 @@ async def test_del_subsribe(init_scheduler):
async with AsyncSession(get_engine()) as sess: async with AsyncSession(get_engine()) as sess:
assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 1 assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 1
assert (await sess.scalar(select(func.count()).select_from(Target))) == 1 assert (await sess.scalar(select(func.count()).select_from(Target))) == 1
target: Target = await sess.scalar(select(Target)) target = await sess.scalar(select(Target))
assert target
assert target.target_name == "weibo_name_new" assert target.target_name == "weibo_name_new"

View File

@ -1,5 +1,6 @@
async def test_migration(use_legacy_config): async def test_migration(use_legacy_config):
from nonebot_plugin_datastore.db import init_db from nonebot_plugin_datastore.db import init_db
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config.config_legacy import Config from nonebot_bison.config.config_legacy import Config
from nonebot_bison.config.db_config import config from nonebot_bison.config.db_config import config
@ -34,7 +35,7 @@ async def test_migration(use_legacy_config):
) )
# await data_migrate() # await data_migrate()
await init_db() await init_db()
user123_config = await config.list_subscribe(123, "group") user123_config = await config.list_subscribe(TargetQQGroup(group_id=123))
assert len(user123_config) == 2 assert len(user123_config) == 2
for c in user123_config: for c in user123_config:
if c.target.target == "weibo_id": if c.target.target == "weibo_id":
@ -47,7 +48,7 @@ async def test_migration(use_legacy_config):
assert c.target.target_name == "weibo_name2" assert c.target.target_name == "weibo_name2"
assert c.target.platform_name == "weibo" assert c.target.platform_name == "weibo"
assert c.tags == ["tag"] assert c.tags == ["tag"]
user234_config = await config.list_subscribe(234, "group") user234_config = await config.list_subscribe(TargetQQGroup(group_id=234))
assert len(user234_config) == 1 assert len(user234_config) == 1
assert user234_config[0].categories == [1] assert user234_config[0].categories == [1]
assert user234_config[0].target.target == "weibo_id" assert user234_config[0].target.target == "weibo_id"
@ -57,6 +58,7 @@ async def test_migration(use_legacy_config):
async def test_migrate_dup(use_legacy_config): async def test_migrate_dup(use_legacy_config):
from nonebot_plugin_datastore.db import init_db from nonebot_plugin_datastore.db import init_db
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config.config_legacy import Config from nonebot_bison.config.config_legacy import Config
from nonebot_bison.config.db_config import config from nonebot_bison.config.db_config import config
@ -82,5 +84,5 @@ async def test_migrate_dup(use_legacy_config):
) )
# await data_migrate() # await data_migrate()
await init_db() await init_db()
user123_config = await config.list_subscribe(123, "group") user123_config = await config.list_subscribe(TargetQQGroup(group_id=123))
assert len(user123_config) == 1 assert len(user123_config) == 1

View File

@ -6,14 +6,14 @@ from pytest_mock import MockerFixture
async def test_create_config(init_scheduler): async def test_create_config(init_scheduler):
from nonebot_plugin_datastore.db import get_engine from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
from nonebot_bison.config.db_model import Subscribe, Target, User from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as T_Target from nonebot_bison.types import Target as T_Target
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=T_Target("weibo_id"), target=T_Target("weibo_id"),
target_name="weibo_name", target_name="weibo_name",
platform_name="weibo", platform_name="weibo",
@ -21,8 +21,7 @@ async def test_create_config(init_scheduler):
tags=[], tags=[],
) )
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=T_Target("weibo_id1"), target=T_Target("weibo_id1"),
target_name="weibo_name1", target_name="weibo_name1",
platform_name="weibo", platform_name="weibo",
@ -58,6 +57,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture):
from datetime import time from datetime import time
from nonebot_plugin_datastore.db import get_engine from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import db_config from nonebot_bison.config import db_config
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
@ -65,8 +65,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture):
from nonebot_bison.types import Target as T_Target from nonebot_bison.types import Target as T_Target
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=T_Target("weibo_id"), target=T_Target("weibo_id"),
target_name="weibo_name", target_name="weibo_name",
platform_name="weibo", platform_name="weibo",
@ -74,8 +73,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture):
tags=[], tags=[],
) )
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=T_Target("weibo_id1"), target=T_Target("weibo_id1"),
target_name="weibo_name1", target_name="weibo_name1",
platform_name="weibo", platform_name="weibo",
@ -83,8 +81,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture):
tags=[], tags=[],
) )
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=T_Target("weibo_id1"), target=T_Target("weibo_id1"),
target_name="weibo_name2", target_name="weibo_name2",
platform_name="bilibili", platform_name="bilibili",
@ -124,6 +121,7 @@ async def test_get_current_weight(init_scheduler, mocker: MockerFixture):
async def test_get_platform_target(app: App, init_scheduler): async def test_get_platform_target(app: App, init_scheduler):
from nonebot_plugin_datastore.db import get_engine from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup
from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.expression import select from sqlalchemy.sql.expression import select
@ -133,8 +131,7 @@ async def test_get_platform_target(app: App, init_scheduler):
from nonebot_bison.types import Target as T_Target from nonebot_bison.types import Target as T_Target
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=T_Target("weibo_id"), target=T_Target("weibo_id"),
target_name="weibo_name", target_name="weibo_name",
platform_name="weibo", platform_name="weibo",
@ -142,8 +139,7 @@ async def test_get_platform_target(app: App, init_scheduler):
tags=[], tags=[],
) )
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=T_Target("weibo_id1"), target=T_Target("weibo_id1"),
target_name="weibo_name1", target_name="weibo_name1",
platform_name="weibo", platform_name="weibo",
@ -151,8 +147,7 @@ async def test_get_platform_target(app: App, init_scheduler):
tags=[], tags=[],
) )
await config.add_subscribe( await config.add_subscribe(
user=245, TargetQQGroup(group_id=245),
user_type="group",
target=T_Target("weibo_id1"), target=T_Target("weibo_id1"),
target_name="weibo_name1", target_name="weibo_name1",
platform_name="weibo", platform_name="weibo",
@ -161,10 +156,14 @@ async def test_get_platform_target(app: App, init_scheduler):
) )
res = await config.get_platform_target("weibo") res = await config.get_platform_target("weibo")
assert len(res) == 2 assert len(res) == 2
await config.del_subscribe(123, "group", T_Target("weibo_id1"), "weibo") await config.del_subscribe(
TargetQQGroup(group_id=123), T_Target("weibo_id1"), "weibo"
)
res = await config.get_platform_target("weibo") res = await config.get_platform_target("weibo")
assert len(res) == 2 assert len(res) == 2
await config.del_subscribe(123, "group", T_Target("weibo_id"), "weibo") await config.del_subscribe(
TargetQQGroup(group_id=123), T_Target("weibo_id"), "weibo"
)
res = await config.get_platform_target("weibo") res = await config.get_platform_target("weibo")
assert len(res) == 1 assert len(res) == 1
@ -175,6 +174,7 @@ async def test_get_platform_target(app: App, init_scheduler):
async def test_get_platform_target_subscribers(app: App, init_scheduler): async def test_get_platform_target_subscribers(app: App, init_scheduler):
from nonebot_plugin_datastore.db import get_engine from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup
from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.expression import select from sqlalchemy.sql.expression import select
@ -182,12 +182,10 @@ async def test_get_platform_target_subscribers(app: App, init_scheduler):
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
from nonebot_bison.config.db_model import Subscribe, Target, User from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as T_Target from nonebot_bison.types import Target as T_Target
from nonebot_bison.types import User as T_User
from nonebot_bison.types import UserSubInfo from nonebot_bison.types import UserSubInfo
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=T_Target("weibo_id"), target=T_Target("weibo_id"),
target_name="weibo_name", target_name="weibo_name",
platform_name="weibo", platform_name="weibo",
@ -195,8 +193,7 @@ async def test_get_platform_target_subscribers(app: App, init_scheduler):
tags=["tag1"], tags=["tag1"],
) )
await config.add_subscribe( await config.add_subscribe(
user=123, TargetQQGroup(group_id=123),
user_type="group",
target=T_Target("weibo_id1"), target=T_Target("weibo_id1"),
target_name="weibo_name1", target_name="weibo_name1",
platform_name="weibo", platform_name="weibo",
@ -204,8 +201,7 @@ async def test_get_platform_target_subscribers(app: App, init_scheduler):
tags=["tag2"], tags=["tag2"],
) )
await config.add_subscribe( await config.add_subscribe(
user=245, TargetQQGroup(group_id=245),
user_type="group",
target=T_Target("weibo_id1"), target=T_Target("weibo_id1"),
target_name="weibo_name1", target_name="weibo_name1",
platform_name="weibo", platform_name="weibo",
@ -215,9 +211,9 @@ async def test_get_platform_target_subscribers(app: App, init_scheduler):
res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id")) res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id"))
assert len(res) == 1 assert len(res) == 1
assert res[0] == UserSubInfo(T_User(123, "group"), [1], ["tag1"]) assert res[0] == UserSubInfo(TargetQQGroup(group_id=123), [1], ["tag1"])
res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id1")) res = await config.get_platform_target_subscribers("weibo", T_Target("weibo_id1"))
assert len(res) == 2 assert len(res) == 2
assert UserSubInfo(T_User(123, "group"), [2], ["tag2"]) in res assert UserSubInfo(TargetQQGroup(group_id=123), [2], ["tag2"]) in res
assert UserSubInfo(T_User(245, "group"), [3], ["tag3"]) in res assert UserSubInfo(TargetQQGroup(group_id=245), [3], ["tag3"]) in res

View File

@ -67,9 +67,11 @@ async def app(tmp_path: Path, request: pytest.FixtureRequest, mocker: MockerFixt
@pytest.fixture @pytest.fixture
def dummy_user_subinfo(app: App): def dummy_user_subinfo(app: App):
from nonebot_bison.types import User, UserSubInfo from nonebot_plugin_saa import TargetQQGroup
user = User(123, "group") from nonebot_bison.types import UserSubInfo
user = TargetQQGroup(group_id=123)
return UserSubInfo(user=user, categories=[], tags=[]) return UserSubInfo(user=user, categories=[], tags=[])

View File

@ -16,9 +16,11 @@ def bili_live(app: App):
@pytest.fixture @pytest.fixture
def dummy_only_open_user_subinfo(app: App): def dummy_only_open_user_subinfo(app: App):
from nonebot_bison.types import User, UserSubInfo from nonebot_plugin_saa import TargetQQGroup
user = User(123, "group") from nonebot_bison.types import UserSubInfo
user = TargetQQGroup(group_id=123)
return UserSubInfo(user=user, categories=[1], tags=[]) return UserSubInfo(user=user, categories=[1], tags=[])
@ -68,9 +70,11 @@ async def test_fetch_bililive_only_live_open(bili_live, dummy_only_open_user_sub
@pytest.fixture @pytest.fixture
def dummy_only_title_user_subinfo(app: App): def dummy_only_title_user_subinfo(app: App):
from nonebot_bison.types import User, UserSubInfo from nonebot_plugin_saa import TargetQQGroup
user = User(123, "group") from nonebot_bison.types import UserSubInfo
user = TargetQQGroup(group_id=123)
return UserSubInfo(user=user, categories=[2], tags=[]) return UserSubInfo(user=user, categories=[2], tags=[])
@ -128,9 +132,11 @@ async def test_fetch_bililive_only_title_change(
@pytest.fixture @pytest.fixture
def dummy_only_close_user_subinfo(app: App): def dummy_only_close_user_subinfo(app: App):
from nonebot_bison.types import User, UserSubInfo from nonebot_plugin_saa import TargetQQGroup
user = User(123, "group") from nonebot_bison.types import UserSubInfo
user = TargetQQGroup(group_id=123)
return UserSubInfo(user=user, categories=[3], tags=[]) return UserSubInfo(user=user, categories=[3], tags=[])
@ -187,9 +193,11 @@ async def test_fetch_bililive_only_close(bili_live, dummy_only_close_user_subinf
@pytest.fixture @pytest.fixture
def dummy_bililive_user_subinfo(app: App): def dummy_bililive_user_subinfo(app: App):
from nonebot_bison.types import User, UserSubInfo from nonebot_plugin_saa import TargetQQGroup
user = User(123, "group") from nonebot_bison.types import UserSubInfo
user = TargetQQGroup(group_id=123)
return UserSubInfo(user=user, categories=[1, 2, 3], tags=[]) return UserSubInfo(user=user, categories=[1, 2, 3], tags=[])

View File

@ -24,9 +24,9 @@ raw_post_list_2 = raw_post_list_1 + [
@pytest.fixture @pytest.fixture
def dummy_user(app: App): def dummy_user(app: App):
from nonebot_bison.types import User from nonebot_plugin_saa import TargetQQGroup
user = User(123, "group") user = TargetQQGroup(group_id=123)
return user return user

View File

@ -12,6 +12,8 @@ if typing.TYPE_CHECKING:
async def get_schedule_times( async def get_schedule_times(
scheduler_config: Type["SchedulerConfig"], time: int scheduler_config: Type["SchedulerConfig"], time: int
) -> dict[str, int]: ) -> dict[str, int]:
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.scheduler import scheduler_dict from nonebot_bison.scheduler import scheduler_dict
scheduler = scheduler_dict[scheduler_config] scheduler = scheduler_dict[scheduler_config]
@ -25,6 +27,8 @@ async def get_schedule_times(
async def test_scheduler_without_time(init_scheduler): async def test_scheduler_without_time(init_scheduler):
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config from nonebot_bison.config import config
from nonebot_bison.config.db_config import WeightConfig from nonebot_bison.config.db_config import WeightConfig
from nonebot_bison.platform.bilibili import BilibiliSchedConf from nonebot_bison.platform.bilibili import BilibiliSchedConf
@ -32,13 +36,13 @@ async def test_scheduler_without_time(init_scheduler):
from nonebot_bison.types import Target as T_Target from nonebot_bison.types import Target as T_Target
await config.add_subscribe( await config.add_subscribe(
123, "group", T_Target("t1"), "target1", "bilibili", [], [] TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili", [], []
) )
await config.add_subscribe( await config.add_subscribe(
123, "group", T_Target("t2"), "target1", "bilibili", [], [] TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili", [], []
) )
await config.add_subscribe( await config.add_subscribe(
123, "group", T_Target("t2"), "target1", "bilibili-live", [], [] TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili-live", [], []
) )
await config.update_time_weight_config( await config.update_time_weight_config(
@ -62,6 +66,8 @@ async def test_scheduler_without_time(init_scheduler):
async def test_scheduler_with_time(app: App, init_scheduler, mocker: MockerFixture): async def test_scheduler_with_time(app: App, init_scheduler, mocker: MockerFixture):
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config, db_config from nonebot_bison.config import config, db_config
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig
from nonebot_bison.platform.bilibili import BilibiliSchedConf from nonebot_bison.platform.bilibili import BilibiliSchedConf
@ -69,13 +75,13 @@ async def test_scheduler_with_time(app: App, init_scheduler, mocker: MockerFixtu
from nonebot_bison.types import Target as T_Target from nonebot_bison.types import Target as T_Target
await config.add_subscribe( await config.add_subscribe(
123, "group", T_Target("t1"), "target1", "bilibili", [], [] TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili", [], []
) )
await config.add_subscribe( await config.add_subscribe(
123, "group", T_Target("t2"), "target1", "bilibili", [], [] TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili", [], []
) )
await config.add_subscribe( await config.add_subscribe(
123, "group", T_Target("t2"), "target1", "bilibili-live", [], [] TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili-live", [], []
) )
await config.update_time_weight_config( await config.update_time_weight_config(
@ -113,38 +119,42 @@ async def test_scheduler_with_time(app: App, init_scheduler, mocker: MockerFixtu
async def test_scheduler_add_new(init_scheduler): async def test_scheduler_add_new(init_scheduler):
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config from nonebot_bison.config import config
from nonebot_bison.platform.bilibili import BilibiliSchedConf from nonebot_bison.platform.bilibili import BilibiliSchedConf
from nonebot_bison.scheduler.manager import init_scheduler from nonebot_bison.scheduler.manager import init_scheduler
from nonebot_bison.types import Target as T_Target from nonebot_bison.types import Target as T_Target
await config.add_subscribe( await config.add_subscribe(
123, "group", T_Target("t1"), "target1", "bilibili", [], [] TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili", [], []
) )
await init_scheduler() await init_scheduler()
await config.add_subscribe( await config.add_subscribe(
2345, "group", T_Target("t1"), "target1", "bilibili", [], [] TargetQQGroup(group_id=2345), T_Target("t1"), "target1", "bilibili", [], []
) )
await config.add_subscribe( await config.add_subscribe(
123, "group", T_Target("t2"), "target2", "bilibili", [], [] TargetQQGroup(group_id=123), T_Target("t2"), "target2", "bilibili", [], []
) )
stat_res = await get_schedule_times(BilibiliSchedConf, 1) stat_res = await get_schedule_times(BilibiliSchedConf, 1)
assert stat_res["bilibili-t2"] == 1 assert stat_res["bilibili-t2"] == 1
async def test_schedule_delete(init_scheduler): async def test_schedule_delete(init_scheduler):
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config from nonebot_bison.config import config
from nonebot_bison.platform.bilibili import BilibiliSchedConf from nonebot_bison.platform.bilibili import BilibiliSchedConf
from nonebot_bison.scheduler.manager import init_scheduler from nonebot_bison.scheduler.manager import init_scheduler
from nonebot_bison.types import Target as T_Target from nonebot_bison.types import Target as T_Target
await config.add_subscribe( await config.add_subscribe(
123, "group", T_Target("t1"), "target1", "bilibili", [], [] TargetQQGroup(group_id=123), T_Target("t1"), "target1", "bilibili", [], []
) )
await config.add_subscribe( await config.add_subscribe(
123, "group", T_Target("t2"), "target1", "bilibili", [], [] TargetQQGroup(group_id=123), T_Target("t2"), "target1", "bilibili", [], []
) )
await init_scheduler() await init_scheduler()
@ -153,6 +163,6 @@ async def test_schedule_delete(init_scheduler):
assert stat_res["bilibili-t2"] == 1 assert stat_res["bilibili-t2"] == 1
assert stat_res["bilibili-t1"] == 1 assert stat_res["bilibili-t1"] == 1
await config.del_subscribe(123, "group", T_Target("t1"), "bilibili") await config.del_subscribe(TargetQQGroup(group_id=123), T_Target("t1"), "bilibili")
stat_res = await get_schedule_times(BilibiliSchedConf, 2) stat_res = await get_schedule_times(BilibiliSchedConf, 2)
assert stat_res["bilibili-t2"] == 2 assert stat_res["bilibili-t2"] == 2

View File

@ -279,6 +279,7 @@ async def test_abort_add_on_tag(app: App, init_scheduler):
async def test_abort_del_sub(app: App, init_scheduler): async def test_abort_del_sub(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.bot import Bot from nonebot.adapters.onebot.v11.bot import Bot
from nonebot.adapters.onebot.v11.message import Message from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config from nonebot_bison.config import config
from nonebot_bison.config_manager import del_sub_matcher from nonebot_bison.config_manager import del_sub_matcher
@ -286,8 +287,7 @@ async def test_abort_del_sub(app: App, init_scheduler):
from nonebot_bison.types import Target as T_Target from nonebot_bison.types import Target as T_Target
await config.add_subscribe( await config.add_subscribe(
10000, TargetQQGroup(group_id=10000),
"group",
T_Target("6279793937"), T_Target("6279793937"),
"明日方舟Arknights", "明日方舟Arknights",
"weibo", "weibo",
@ -316,5 +316,5 @@ async def test_abort_del_sub(app: App, init_scheduler):
ctx.receive_event(bot, event_abort) ctx.receive_event(bot, event_abort)
ctx.should_call_send(event_abort, "删除中止", True) ctx.should_call_send(event_abort, "删除中止", True)
ctx.should_finished() ctx.should_finished()
subs = await config.list_subscribe(10000, "group") subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert subs assert subs

View File

@ -64,6 +64,7 @@ async def test_configurable_at_me_false(app: App):
async def test_add_with_target(app: App, init_scheduler): async def test_add_with_target(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config from nonebot_bison.config import config
from nonebot_bison.config_manager import add_sub_matcher, common_platform from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -171,7 +172,7 @@ async def test_add_with_target(app: App, init_scheduler):
event_6_ok, BotReply.add_reply_subscribe_success("明日方舟Arknights"), True event_6_ok, BotReply.add_reply_subscribe_success("明日方舟Arknights"), True
) )
ctx.should_finished() ctx.should_finished()
subs = await config.list_subscribe(10000, "group") subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert len(subs) == 1 assert len(subs) == 1
sub = subs[0] sub = subs[0]
assert sub.target.target == "6279793937" assert sub.target.target == "6279793937"
@ -188,6 +189,7 @@ async def test_add_with_target(app: App, init_scheduler):
async def test_add_with_target_no_cat(app: App, init_scheduler): async def test_add_with_target_no_cat(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config from nonebot_bison.config import config
from nonebot_bison.config_manager import add_sub_matcher, common_platform from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -233,7 +235,7 @@ async def test_add_with_target_no_cat(app: App, init_scheduler):
event_4_ok, BotReply.add_reply_subscribe_success("塞壬唱片-MSR"), True event_4_ok, BotReply.add_reply_subscribe_success("塞壬唱片-MSR"), True
) )
ctx.should_finished() ctx.should_finished()
subs = await config.list_subscribe(10000, "group") subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert len(subs) == 1 assert len(subs) == 1
sub = subs[0] sub = subs[0]
assert sub.target.target == "32540734" assert sub.target.target == "32540734"
@ -248,6 +250,7 @@ async def test_add_with_target_no_cat(app: App, init_scheduler):
async def test_add_no_target(app: App, init_scheduler): async def test_add_no_target(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config from nonebot_bison.config import config
from nonebot_bison.config_manager import add_sub_matcher, common_platform from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -284,7 +287,7 @@ async def test_add_no_target(app: App, init_scheduler):
event_4, BotReply.add_reply_subscribe_success("明日方舟游戏信息"), True event_4, BotReply.add_reply_subscribe_success("明日方舟游戏信息"), True
) )
ctx.should_finished() ctx.should_finished()
subs = await config.list_subscribe(10000, "group") subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert len(subs) == 1 assert len(subs) == 1
sub = subs[0] sub = subs[0]
assert sub.target.target == "default" assert sub.target.target == "default"
@ -334,6 +337,7 @@ async def test_platform_name_err(app: App):
async def test_add_with_get_id(app: App): async def test_add_with_get_id(app: App):
from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message, MessageSegment from nonebot.adapters.onebot.v11.message import Message, MessageSegment
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config from nonebot_bison.config import config
from nonebot_bison.config_manager import add_sub_matcher, common_platform from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -407,7 +411,7 @@ async def test_add_with_get_id(app: App):
True, True,
) )
ctx.should_finished() ctx.should_finished()
subs = await config.list_subscribe(10000, "group") subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert len(subs) == 0 assert len(subs) == 0
@ -416,6 +420,7 @@ async def test_add_with_get_id(app: App):
async def test_add_with_bilibili_target_parser(app: App, init_scheduler): async def test_add_with_bilibili_target_parser(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.event import Sender from nonebot.adapters.onebot.v11.event import Sender
from nonebot.adapters.onebot.v11.message import Message from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config from nonebot_bison.config import config
from nonebot_bison.config_manager import add_sub_matcher, common_platform from nonebot_bison.config_manager import add_sub_matcher, common_platform
@ -524,7 +529,7 @@ async def test_add_with_bilibili_target_parser(app: App, init_scheduler):
event_6, BotReply.add_reply_subscribe_success("明日方舟"), True event_6, BotReply.add_reply_subscribe_success("明日方舟"), True
) )
ctx.should_finished() ctx.should_finished()
subs = await config.list_subscribe(10000, "group") subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert len(subs) == 1 assert len(subs) == 1
sub = subs[0] sub = subs[0]
assert sub.target.target == "161775300" assert sub.target.target == "161775300"

View File

@ -10,6 +10,7 @@ from .utils import fake_admin_user, fake_group_message_event
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_sub(app: App, init_scheduler): async def test_query_sub(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.message import Message from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config from nonebot_bison.config import config
from nonebot_bison.config_manager import query_sub_matcher from nonebot_bison.config_manager import query_sub_matcher
@ -17,8 +18,7 @@ async def test_query_sub(app: App, init_scheduler):
from nonebot_bison.types import Target from nonebot_bison.types import Target
await config.add_subscribe( await config.add_subscribe(
10000, TargetQQGroup(group_id=10000),
"group",
Target("6279793937"), Target("6279793937"),
"明日方舟Arknights", "明日方舟Arknights",
"weibo", "weibo",
@ -40,6 +40,7 @@ async def test_query_sub(app: App, init_scheduler):
async def test_del_sub(app: App, init_scheduler): async def test_del_sub(app: App, init_scheduler):
from nonebot.adapters.onebot.v11.bot import Bot from nonebot.adapters.onebot.v11.bot import Bot
from nonebot.adapters.onebot.v11.message import Message from nonebot.adapters.onebot.v11.message import Message
from nonebot_plugin_saa import TargetQQGroup
from nonebot_bison.config import config from nonebot_bison.config import config
from nonebot_bison.config_manager import del_sub_matcher from nonebot_bison.config_manager import del_sub_matcher
@ -47,8 +48,7 @@ async def test_del_sub(app: App, init_scheduler):
from nonebot_bison.types import Target from nonebot_bison.types import Target
await config.add_subscribe( await config.add_subscribe(
10000, TargetQQGroup(group_id=10000),
"group",
Target("6279793937"), Target("6279793937"),
"明日方舟Arknights", "明日方舟Arknights",
"weibo", "weibo",
@ -83,7 +83,7 @@ async def test_del_sub(app: App, init_scheduler):
ctx.receive_event(bot, event_1_ok) ctx.receive_event(bot, event_1_ok)
ctx.should_call_send(event_1_ok, "删除成功", True) ctx.should_call_send(event_1_ok, "删除成功", True)
ctx.should_finished() ctx.should_finished()
subs = await config.list_subscribe(10000, "group") subs = await config.list_subscribe(TargetQQGroup(group_id=10000))
assert len(subs) == 0 assert len(subs) == 0

View File

@ -30,8 +30,8 @@ async def test_refresh_bots(app: App) -> None:
from nonebot import get_driver from nonebot import get_driver
from nonebot.adapters.onebot.v11 import Bot as BotV11 from nonebot.adapters.onebot.v11 import Bot as BotV11
from nonebot.adapters.onebot.v12 import Bot as BotV12 from nonebot.adapters.onebot.v12 import Bot as BotV12
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from nonebot_bison.types import User
from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots
async with app.test_api() as ctx: async with app.test_api() as ctx:
@ -44,13 +44,13 @@ async def test_refresh_bots(app: App) -> None:
ctx.should_call_api("get_group_list", {}, [{"group_id": 1}]) ctx.should_call_api("get_group_list", {}, [{"group_id": 1}])
ctx.should_call_api("get_friend_list", {}, [{"user_id": 2}]) ctx.should_call_api("get_friend_list", {}, [{"user_id": 2}])
assert get_bot(User(1, "group")) is None assert get_bot(TargetQQGroup(group_id=1)) is None
assert get_bot(User(2, "private")) is None assert get_bot(TargetQQPrivate(user_id=2)) is None
await refresh_bots() await refresh_bots()
assert get_bot(User(1, "group")) == botv11 assert get_bot(TargetQQGroup(group_id=1)) == botv11
assert get_bot(User(2, "private")) == botv11 assert get_bot(TargetQQPrivate(user_id=2)) == botv11
# 测试获取群列表 # 测试获取群列表
ctx.should_call_api("get_group_list", {}, [{"group_id": 3}]) ctx.should_call_api("get_group_list", {}, [{"group_id": 3}])
@ -66,8 +66,8 @@ async def test_get_bot_two_bots(app: App) -> None:
from nonebot import get_driver from nonebot import get_driver
from nonebot.adapters.onebot.v11 import Bot as BotV11 from nonebot.adapters.onebot.v11 import Bot as BotV11
from nonebot.adapters.onebot.v12 import Bot as BotV12 from nonebot.adapters.onebot.v12 import Bot as BotV12
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from nonebot_bison.types import User
from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots from nonebot_bison.utils.get_bot import get_bot, get_groups, refresh_bots
async with app.test_api() as ctx: async with app.test_api() as ctx:
@ -85,14 +85,14 @@ async def test_get_bot_two_bots(app: App) -> None:
await refresh_bots() await refresh_bots()
assert get_bot(User(0, "group")) is None assert get_bot(TargetQQGroup(group_id=0)) is None
assert get_bot(User(1, "group")) == bot1 assert get_bot(TargetQQGroup(group_id=1)) == bot1
assert get_bot(User(2, "group")) in (bot1, bot2) assert get_bot(TargetQQGroup(group_id=2)) in (bot1, bot2)
assert get_bot(User(3, "group")) == bot2 assert get_bot(TargetQQGroup(group_id=3)) == bot2
assert get_bot(User(0, "private")) is None assert get_bot(TargetQQPrivate(user_id=0)) is None
assert get_bot(User(1, "private")) == bot1 assert get_bot(TargetQQPrivate(user_id=1)) == bot1
assert get_bot(User(2, "private")) in (bot1, bot2) assert get_bot(TargetQQPrivate(user_id=2)) in (bot1, bot2)
assert get_bot(User(3, "private")) == bot2 assert get_bot(TargetQQPrivate(user_id=3)) == bot2
ctx.should_call_api("get_group_list", {}, [{"group_id": 1}, {"group_id": 2}]) ctx.should_call_api("get_group_list", {}, [{"group_id": 1}, {"group_id": 2}])
ctx.should_call_api("get_group_list", {}, [{"group_id": 2}, {"group_id": 3}]) ctx.should_call_api("get_group_list", {}, [{"group_id": 2}, {"group_id": 3}])