🚧 remove User type

This commit is contained in:
felinae98
2023-03-20 14:39:04 +08:00
parent d535f5212d
commit 4118329bb0
24 changed files with 298 additions and 221 deletions
+8 -6
View File
@@ -4,6 +4,7 @@ from fastapi.exceptions import HTTPException
from fastapi.param_functions import Depends
from fastapi.routing import APIRouter
from fastapi.security.oauth2 import OAuth2PasswordBearer
from nonebot_plugin_saa import TargetQQGroup
from ..apis import check_sub_target
from ..config import (
@@ -15,7 +16,7 @@ from ..config import (
from ..config.db_config import SubscribeDupException
from ..platform import platform_manager
from ..types import Target as T_Target
from ..types import User, WeightConfig
from ..types import WeightConfig
from ..utils.get_bot import get_bot, get_groups
from .jwt import load_jwt, pack_jwt
from .token_manager import token_manager
@@ -75,7 +76,7 @@ async def get_admin_groups(qq: int):
res = []
for group in await get_groups():
group_id = group["group_id"]
bot = get_bot(User(group_id, "group"))
bot = get_bot(TargetQQGroup(group_id=group_id))
if not bot:
continue
users = await bot.get_group_member_list(group_id=group_id)
@@ -131,7 +132,7 @@ async def get_subs_info(jwt_obj: dict = Depends(get_jwt_obj)) -> SubscribeResp:
res: SubscribeResp = {}
for group in groups:
group_id = group["id"]
raw_subs = await config.list_subscribe(group_id, "group")
raw_subs = await config.list_subscribe(TargetQQGroup(group_id=group_id))
subs = list(
map(
lambda sub: SubscribeConfig(
@@ -157,8 +158,7 @@ async def get_target_name(platformName: str, target: str):
async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp:
try:
await config.add_subscribe(
int(groupNumber),
"group",
TargetQQGroup(group_id=groupNumber),
T_Target(req.target),
req.targetName,
req.platformName,
@@ -173,7 +173,9 @@ async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp:
@router.delete("/subs", dependencies=[Depends(check_group_permission)])
async def del_group_sub(groupNumber: int, platformName: str, target: str):
try:
await config.del_subscribe(int(groupNumber), "group", target, platformName)
await config.del_subscribe(
TargetQQGroup(group_id=groupNumber), target, platformName
)
except (NoSuchUserException, NoSuchSubscribeException):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe")
return StatusResp(ok=True, msg="")
+12 -18
View File
@@ -4,15 +4,14 @@ from datetime import datetime, time
from typing import Awaitable, Callable, Optional, Sequence
from nonebot_plugin_datastore import create_session
from nonebot_plugin_saa import PlatformTarget
from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload
from ..types import Category, PlatformWeightConfigResp, Tag
from ..types import Target as T_Target
from ..types import TimeWeightConfig
from ..types import User as T_User
from ..types import UserSubInfo, WeightConfig
from ..types import TimeWeightConfig, UserSubInfo, WeightConfig
from .db_model import ScheduleTimeWeight, Subscribe, Target, User
from .utils import NoSuchTargetException
@@ -40,8 +39,7 @@ class DBConfig:
async def add_subscribe(
self,
user: int,
user_type: str,
user: PlatformTarget,
target: T_Target,
target_name: str,
platform_name: str,
@@ -49,12 +47,10 @@ class DBConfig:
tags: list[Tag],
):
async with create_session() as session:
db_user_stmt = (
select(User).where(User.uid == user).where(User.type == user_type)
)
db_user_stmt = select(User).where(User.user_target == user.dict())
db_user: Optional[User] = await session.scalar(db_user_stmt)
if not db_user:
db_user = User(uid=user, type=user_type)
db_user = User(user_target=user.dict())
session.add(db_user)
db_target_stmt = (
select(Target)
@@ -85,11 +81,11 @@ class DBConfig:
raise SubscribeDupException()
raise e
async def list_subscribe(self, user: int, user_type: str) -> Sequence[Subscribe]:
async def list_subscribe(self, user: PlatformTarget) -> Sequence[Subscribe]:
async with create_session() as session:
query_stmt = (
select(Subscribe)
.where(User.type == user_type, User.uid == user)
.where(User.user_target == user.dict())
.join(User)
.options(selectinload(Subscribe.target))
)
@@ -109,11 +105,11 @@ class DBConfig:
return subs
async def del_subscribe(
self, user: int, user_type: str, target: str, platform_name: str
self, user: PlatformTarget, target: str, platform_name: str
):
async with create_session() as session:
user_obj = await session.scalar(
select(User).where(User.uid == user, User.type == user_type)
select(User).where(User.user_target == user.dict())
)
target_obj = await session.scalar(
select(Target).where(
@@ -142,8 +138,7 @@ class DBConfig:
async def update_subscribe(
self,
user: int,
user_type: str,
user: PlatformTarget,
target: str,
target_name: str,
platform_name: str,
@@ -154,8 +149,7 @@ class DBConfig:
subscribe_obj: Subscribe = await sess.scalar(
select(Subscribe)
.where(
User.uid == user,
User.type == user_type,
User.user_target == user.dict(),
Target.target == target,
Target.platform_name == platform_name,
)
@@ -272,7 +266,7 @@ class DBConfig:
return list(
map(
lambda subscribe: UserSubInfo(
T_User(subscribe.user.uid, subscribe.user.type),
PlatformTarget.deserialize(subscribe.user.user_target),
subscribe.categories,
subscribe.tags,
),
+6 -1
View File
@@ -1,5 +1,6 @@
from nonebot.log import logger
from nonebot_plugin_datastore.db import get_engine
from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate
from sqlalchemy.ext.asyncio.session import AsyncSession
from .config_legacy import Config, ConfigContent, drop
@@ -21,7 +22,11 @@ async def data_migrate():
subscribe_to_create = []
platform_target_map: dict[str, tuple[Target, str, int]] = {}
for user in all_subs:
db_user = User(uid=user["user"], type=user["user_type"])
if user["user_type"] == "group":
user_target = TargetQQGroup(group_id=user["user"])
else:
user_target = TargetQQPrivate(user_id=user["user"])
db_user = User(user_target=user_target.dict())
user_to_create.append(db_user)
user_sub_set = set()
for sub in user["subs"]:
+5 -5
View File
@@ -14,15 +14,15 @@ get_plugin_data().set_migration_dir(Path(__file__).parent / "migrations")
class User(Model):
__table_args__ = (UniqueConstraint("type", "uid", name="unique-user-constraint"),)
id: Mapped[int] = mapped_column(primary_key=True)
type: Mapped[str] = mapped_column(String(20))
uid: Mapped[int]
user_target: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
user_target: Mapped[dict] = mapped_column(JSON)
subscribes: Mapped[list["Subscribe"]] = relationship(back_populates="user")
@property
def saa_target(self) -> PlatformTarget:
return PlatformTarget.deserialize(self.user_target)
class Target(Model):
__table_args__ = (
@@ -18,6 +18,7 @@ depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.drop_constraint("unique-user-constraint", type_="unique")
batch_op.add_column(sa.Column("user_target", sa.JSON(), nullable=True))
# ### end Alembic commands ###
@@ -27,5 +28,6 @@ def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.drop_column("user_target")
batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"])
# ### end Alembic commands ###
@@ -0,0 +1,34 @@
"""make user_target not nullable
Revision ID: 67c38b3f39c2
Revises: a5466912fad0
Create Date: 2023-03-20 11:08:42.883556
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import sqlite
# revision identifiers, used by Alembic.
revision = "67c38b3f39c2"
down_revision = "a5466912fad0"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.alter_column(
"user_target", existing_type=sqlite.JSON(), nullable=False
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.alter_column("user_target", existing_type=sqlite.JSON(), nullable=True)
# ### end Alembic commands ###
@@ -0,0 +1,33 @@
"""remove uid and type
Revision ID: 8d3863e9d74b
Revises: 67c38b3f39c2
Create Date: 2023-03-20 15:38:20.220599
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "8d3863e9d74b"
down_revision = "67c38b3f39c2"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.drop_column("uid")
batch_op.drop_column("type")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op:
batch_op.add_column(sa.Column("type", sa.VARCHAR(length=20), nullable=False))
batch_op.add_column(sa.Column("uid", sa.INTEGER(), nullable=False))
# ### end Alembic commands ###
+14 -34
View File
@@ -15,13 +15,14 @@ from nonebot.params import Depends, EventPlainText, EventToMe
from nonebot.permission import SUPERUSER
from nonebot.rule import to_me
from nonebot.typing import T_State
from nonebot_plugin_saa import PlatformTarget, TargetQQGroup, extract_target
from .apis import check_sub_target
from .config import config
from .config.db_config import SubscribeDupException
from .platform import Platform, platform_manager
from .plugin_config import plugin_config
from .types import Category, Target, User
from .types import Category, Target
from .utils import parse_text
@@ -61,12 +62,8 @@ def ensure_user_info(matcher: Type[Matcher]):
async def set_target_user_info(event: MessageEvent, state: T_State):
if isinstance(event, GroupMessageEvent):
user = User(event.group_id, "group")
state["target_user_info"] = user
elif isinstance(event, PrivateMessageEvent):
user = User(event.user_id, "private")
state["target_user_info"] = user
user = extract_target(event)
state["target_user_info"] = user
def do_add_sub(add_sub: Type[Matcher]):
@@ -201,14 +198,11 @@ def do_add_sub(add_sub: Type[Matcher]):
@add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)])
async def add_sub_process(event: Event, state: T_State):
user = cast(User, state.get("target_user_info"))
assert isinstance(user, User)
user = cast(PlatformTarget, state.get("target_user_info"))
assert isinstance(user, PlatformTarget)
try:
await config.add_subscribe(
# state.get("_user_id") or event.group_id,
# user_type="group",
user=user.user,
user_type=user.user_type,
user=user,
target=state["id"],
target_name=state["name"],
platform_name=state["platform"],
@@ -228,12 +222,8 @@ def do_query_sub(query_sub: Type[Matcher]):
@query_sub.handle()
async def _(state: T_State):
user_info = state["target_user_info"]
assert isinstance(user_info, User)
sub_list = await config.list_subscribe(
# state.get("_user_id") or event.group_id, "group"
user_info.user,
user_info.user_type,
)
assert isinstance(user_info, PlatformTarget)
sub_list = await config.list_subscribe(user_info)
res = "订阅的帐号为:\n"
for sub in sub_list:
res += "{} {} {}".format(
@@ -261,13 +251,9 @@ def do_del_sub(del_sub: Type[Matcher]):
@del_sub.handle()
async def send_list(bot: Bot, event: Event, state: T_State):
user_info = state["target_user_info"]
assert isinstance(user_info, User)
assert isinstance(user_info, PlatformTarget)
try:
sub_list = await config.list_subscribe(
# state.get("_user_id") or event.group_id, "group"
user_info.user,
user_info.user_type,
)
sub_list = await config.list_subscribe(user_info)
assert sub_list
except AssertionError:
await del_sub.finish("暂无已订阅账号\n请使用“添加订阅”命令添加订阅")
@@ -309,14 +295,8 @@ def do_del_sub(del_sub: Type[Matcher]):
try:
index = int(user_msg)
user_info = state["target_user_info"]
assert isinstance(user_info, User)
await config.del_subscribe(
# state.get("_user_id") or event.group_id,
# "group",
user_info.user,
user_info.user_type,
**state["sub_table"][index],
)
assert isinstance(user_info, PlatformTarget)
await config.del_subscribe(user_info, **state["sub_table"][index])
except Exception as e:
await del_sub.reject("删除错误")
else:
@@ -398,7 +378,7 @@ async def do_choose_group_number(state: T_State):
group_number_idx: dict[int, int] = state["group_number_idx"]
idx: int = state["group_idx"]
group_id = group_number_idx[idx]
state["target_user_info"] = User(user=group_id, user_type="group")
state["target_user_info"] = TargetQQGroup(group_id=group_id)
async def _check_command(event_msg: str = EventPlainText()):
+9 -8
View File
@@ -10,10 +10,11 @@ from typing import Any, Collection, Optional, Type
import httpx
from httpx import AsyncClient
from nonebot.log import logger
from nonebot_plugin_saa import PlatformTarget
from ..plugin_config import plugin_config
from ..post import Post
from ..types import Category, RawPost, Tag, Target, User, UserSubInfo
from ..types import Category, RawPost, Tag, Target, UserSubInfo
from ..utils import ProcessContext, SchedulerConfig
@@ -84,12 +85,12 @@ class Platform(metaclass=PlatformABCMeta, base=True):
@abstractmethod
async def fetch_new_post(
self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]:
) -> list[tuple[PlatformTarget, list[Post]]]:
...
async def do_fetch_new_post(
self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]:
) -> list[tuple[PlatformTarget, list[Post]]]:
try:
return await self.fetch_new_post(target, users)
except httpx.RequestError as err:
@@ -197,8 +198,8 @@ class Platform(metaclass=PlatformABCMeta, base=True):
async def dispatch_user_post(
self, target: Target, new_posts: list[RawPost], users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]:
res: list[tuple[User, list[Post]]] = []
) -> list[tuple[PlatformTarget, list[Post]]]:
res: list[tuple[PlatformTarget, list[Post]]] = []
for user, cats, required_tags in users:
user_raw_post = await self.filter_user_custom(
new_posts, cats, required_tags
@@ -314,7 +315,7 @@ class NewMessage(MessageProcess, abstract=True):
async def fetch_new_post(
self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]:
) -> list[tuple[PlatformTarget, list[Post]]]:
post_list = await self.get_sub_list(target)
new_posts = await self.filter_common_with_diff(target, post_list)
if not new_posts:
@@ -353,7 +354,7 @@ class StatusChange(Platform, abstract=True):
async def fetch_new_post(
self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]:
) -> list[tuple[PlatformTarget, list[Post]]]:
try:
new_status = await self.get_status(target)
except self.FetchError as err:
@@ -381,7 +382,7 @@ class SimplePost(MessageProcess, abstract=True):
async def fetch_new_post(
self, target: Target, users: list[UserSubInfo]
) -> list[tuple[User, list[Post]]]:
) -> list[tuple[PlatformTarget, list[Post]]]:
new_posts = await self.get_sub_list(target)
if not new_posts:
return []
+2 -1
View File
@@ -3,6 +3,7 @@ from datetime import time
from typing import Any, Literal, NamedTuple, NewType
from httpx import URL
from nonebot_plugin_saa import PlatformTarget as SendTarget
from pydantic import BaseModel
RawPost = Any
@@ -25,7 +26,7 @@ class PlatformTarget:
class UserSubInfo(NamedTuple):
user: User
user: SendTarget
categories: list[Category]
tags: list[Tag]
+30 -31
View File
@@ -1,53 +1,57 @@
""" 提供获取 Bot 的方法 """
import random
from collections import defaultdict
from typing import Any, Optional
import nonebot
from nonebot import get_driver, on_notice
from nonebot.adapters import Bot
from nonebot.adapters.onebot.v11 import Bot as Ob11Bot
from nonebot.adapters.onebot.v11 import (
Bot,
FriendAddNoticeEvent,
GroupDecreaseNoticeEvent,
GroupIncreaseNoticeEvent,
)
from ..types import User
from nonebot_plugin_saa import PlatformTarget, TargetQQGroup, TargetQQPrivate
GROUP: dict[int, list[Bot]] = {}
USER: dict[int, list[Bot]] = {}
BOT_CACHE: dict[PlatformTarget, list[Bot]] = defaultdict(list)
def get_bots() -> list[Bot]:
"""获取所有 OneBot 11 Bot"""
# TODO: support ob12
bots = []
for bot in nonebot.get_bots().values():
if isinstance(bot, Bot):
if isinstance(bot, Ob11Bot):
bots.append(bot)
return bots
async def _refresh_ob11(bot: Ob11Bot):
# 获取群列表
groups = await bot.get_group_list()
for group in groups:
group_id = group["group_id"]
target = TargetQQGroup(group_id=group_id)
BOT_CACHE[target].append(bot)
# 获取好友列表
users = await bot.get_friend_list()
for user in users:
user_id = user["user_id"]
target = TargetQQPrivate(user_id=user_id)
BOT_CACHE[target].append(bot)
async def refresh_bots():
"""刷新缓存的 Bot 数据"""
GROUP.clear()
USER.clear()
BOT_CACHE.clear()
for bot in get_bots():
# 获取群列表
groups = await bot.get_group_list()
for group in groups:
group_id = group["group_id"]
if group_id not in GROUP:
GROUP[group_id] = [bot]
else:
GROUP[group_id].append(bot)
# 获取好友列表
users = await bot.get_friend_list()
for user in users:
user_id = user["user_id"]
if user_id not in USER:
USER[user_id] = [bot]
else:
USER[user_id].append(bot)
match bot:
case Ob11Bot():
await _refresh_ob11(bot)
driver = get_driver()
@@ -75,15 +79,9 @@ async def _(bot: Bot, event: GroupDecreaseNoticeEvent | GroupIncreaseNoticeEvent
await refresh_bots()
def get_bot(user: User) -> Optional[Bot]:
def get_bot(user: PlatformTarget) -> Optional[Bot]:
"""获取 Bot"""
bots = []
if user.user_type == "group":
bots = GROUP.get(user.user, [])
if user.user_type == "private":
bots = USER.get(user.user, [])
bots = BOT_CACHE.get(user)
if not bots:
return
@@ -92,6 +90,7 @@ def get_bot(user: User) -> Optional[Bot]:
async def get_groups() -> list[dict[str, Any]]:
"""获取所有群号"""
# TODO
all_groups: dict[int, dict[str, Any]] = {}
for bot in get_bots():
groups = await bot.get_group_list()