From cf35432757885948c10d1f2248c2c90d380dd820 Mon Sep 17 00:00:00 2001 From: felinae98 <731499577@qq.com> Date: Tue, 29 Mar 2022 22:43:39 +0800 Subject: [PATCH] update --- .pre-commit-config.yaml | 4 +- src/plugins/nonebot_bison/config/__init__.py | 3 +- src/plugins/nonebot_bison/config/db.py | 99 ++++++------- src/plugins/nonebot_bison/config/db_config.py | 130 +++++++++++++----- src/plugins/nonebot_bison/config/db_model.py | 31 ++++- .../nonebot_bison/config/migrate/env.py | 7 +- .../versions/4a46ba54a3f3_alter_type.py | 53 +++++++ .../a333d6224193_add_last_scheduled_time.py | 33 +++++ .../versions/c97c445e2bdb_add_constraint.py | 47 +++++++ src/plugins/nonebot_bison/config_manager.py | 40 +++--- .../nonebot_bison/platform/platform.py | 2 +- src/plugins/nonebot_bison/post.py | 1 - src/plugins/nonebot_bison/types.py | 3 +- tests/config/test_config_operation.py | 85 +++++++++++- tests/conftest.py | 33 ++--- 15 files changed, 439 insertions(+), 132 deletions(-) create mode 100644 src/plugins/nonebot_bison/config/migrate/versions/4a46ba54a3f3_alter_type.py create mode 100644 src/plugins/nonebot_bison/config/migrate/versions/a333d6224193_add_last_scheduled_time.py create mode 100644 src/plugins/nonebot_bison/config/migrate/versions/c97c445e2bdb_add_constraint.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6796996..e44209a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,12 +11,12 @@ repos: - id: isort - repo: https://github.com/psf/black - rev: 22.1.0 + rev: 22.3.0 hooks: - id: black - repo: https://github.com/pre-commit/mirrors-prettier - rev: v2.5.1 + rev: v2.6.1 hooks: - id: prettier types_or: [markdown, ts, tsx] diff --git a/src/plugins/nonebot_bison/config/__init__.py b/src/plugins/nonebot_bison/config/__init__.py index 8d7c675..0544e51 100644 --- a/src/plugins/nonebot_bison/config/__init__.py +++ b/src/plugins/nonebot_bison/config/__init__.py @@ -1,2 +1,3 @@ -from .config_legacy import NoSuchSubscribeException, NoSuchUserException, config +from .config_legacy import NoSuchSubscribeException, NoSuchUserException from .db import DATA +from .db_config import config diff --git a/src/plugins/nonebot_bison/config/db.py b/src/plugins/nonebot_bison/config/db.py index 2a30494..856b14d 100644 --- a/src/plugins/nonebot_bison/config/db.py +++ b/src/plugins/nonebot_bison/config/db.py @@ -1,4 +1,3 @@ -import json from pathlib import Path import nonebot @@ -7,7 +6,9 @@ from alembic.runtime.environment import EnvironmentContext from alembic.script.base import ScriptDirectory from nonebot.log import logger from nonebot_plugin_datastore import PluginData, create_session, db +from nonebot_plugin_datastore.db import get_engine from sqlalchemy.engine.base import Connection +from sqlalchemy.ext.asyncio.session import AsyncSession from .config_legacy import ConfigContent, config from .db_model import Base, Subscribe, Target, User @@ -21,53 +22,52 @@ async def data_migrate(): all_subs: list[ConfigContent] = list( map(lambda item: ConfigContent(**item), config.get_all_subscribe().all()) ) - print(all_subs) - sess = create_session() - user_to_create = [] - subscribe_to_create = [] - platform_target_map: dict[str, tuple[Target, str, int]] = {} - for user in all_subs: - db_user = User(uid=user["user"], type=user["user_type"]) - user_to_create.append(db_user) - for sub in user["subs"]: - target = sub["target"] - platform_name = sub["target_type"] - target_name = sub["target_name"] - key = f"{target}-{platform_name}" - if key in platform_target_map.keys(): - target_obj, ext_user_type, ext_user = platform_target_map[key] - if target_obj.target_name != target_name: - # GG - logger.error( - f"你的旧版本数据库中存在数据不一致问题,请完成迁移后执行重新添加{platform_name}平台的{target}" - f"它的名字可能为{target_obj.target_name}或{target_name}" - ) + async with AsyncSession(get_engine()) as sess: + user_to_create = [] + subscribe_to_create = [] + platform_target_map: dict[str, tuple[Target, str, int]] = {} + for user in all_subs: + db_user = User(uid=user["user"], type=user["user_type"]) + user_to_create.append(db_user) + for sub in user["subs"]: + target = sub["target"] + platform_name = sub["target_type"] + target_name = sub["target_name"] + key = f"{target}-{platform_name}" + if key in platform_target_map.keys(): + target_obj, ext_user_type, ext_user = platform_target_map[key] + if target_obj.target_name != target_name: + # GG + logger.error( + f"你的旧版本数据库中存在数据不一致问题,请完成迁移后执行重新添加{platform_name}平台的{target}" + f"它的名字可能为{target_obj.target_name}或{target_name}" + ) - else: - target_obj = Target( - platform_name=platform_name, - target_name=target_name, - target=target, + else: + target_obj = Target( + platform_name=platform_name, + target_name=target_name, + target=target, + ) + platform_target_map[key] = ( + target_obj, + user["user_type"], + user["user"], + ) + subscribe_obj = Subscribe( + user=db_user, + target=target_obj, + categories=sub["cats"], + tags=sub["tags"], ) - platform_target_map[key] = ( - target_obj, - user["user_type"], - user["user"], - ) - subscribe_obj = Subscribe( - user=db_user, - target=target_obj, - categories=json.dumps(sub["cats"]), - tags=json.dumps(sub["tags"]), - ) - subscribe_to_create.append(subscribe_obj) - sess.add_all( - user_to_create - + list(map(lambda x: x[0], platform_target_map.values())) - + subscribe_to_create - ) - await sess.commit() - logger.info("migrate success") + subscribe_to_create.append(subscribe_obj) + sess.add_all( + user_to_create + + list(map(lambda x: x[0], platform_target_map.values())) + + subscribe_to_create + ) + await sess.commit() + logger.info("migrate success") @nonebot.get_driver().on_startup @@ -85,7 +85,12 @@ async def upgrade_db(): return script._upgrade_revs("head", revision) def do_run_migration(connection: Connection): - env.configure(connection, target_metadata=Base.metadata, fn=migrate_fun) + env.configure( + connection, + target_metadata=Base.metadata, + fn=migrate_fun, + render_as_batch=True, + ) with env.begin_transaction(): env.run_migrations() logger.info("Finish auto migrate") diff --git a/src/plugins/nonebot_bison/config/db_config.py b/src/plugins/nonebot_bison/config/db_config.py index 4843777..f7b15e0 100644 --- a/src/plugins/nonebot_bison/config/db_config.py +++ b/src/plugins/nonebot_bison/config/db_config.py @@ -1,9 +1,11 @@ -import json from typing import Optional from nonebot_bison.types import Category, Tag, Target -from nonebot_plugin_datastore.db import create_session -from sqlalchemy.sql.expression import select +from nonebot_plugin_datastore.db import get_engine +from sqlalchemy.ext.asyncio.session import AsyncSession +from sqlalchemy.orm import selectinload +from sqlalchemy.sql.expression import delete, select +from sqlalchemy.sql.functions import func from .db_model import Subscribe as MSubscribe from .db_model import Target as MTarget @@ -11,9 +13,6 @@ from .db_model import User class DBConfig: - def __init__(self): - self.session = create_session() - async def add_subscribe( self, user: int, @@ -24,35 +23,98 @@ class DBConfig: cats: list[Category], tags: list[Tag], ): - db_user_stmt = ( - select(User).where(User.uid == user).where(User.type == user_type) - ) - db_user: Optional[User] = (await self.session.scalars(db_user_stmt)).first() - if not db_user: - db_user = User(uid=user, type=user_type) - self.session.add(db_user) - db_target_stmt = ( - select(MTarget) - .where(MTarget.platform_name == platform_name) - .where(MTarget.target == target) - ) - db_target: Optional[MTarget] = ( - await self.session.scalars(db_target_stmt) - ).first() - if not db_target: - db_target = MTarget( - target=target, platform_name=platform_name, target_name=target_name + async with AsyncSession(get_engine()) as session: + db_user_stmt = ( + select(User).where(User.uid == user).where(User.type == user_type) ) - else: - db_target.target_name = target_name # type: ignore - subscribe = MSubscribe( - categories=json.dumps(cats), - tags=json.dumps(tags), - user=db_user, - target=db_target, - ) - self.session.add(subscribe) - await self.session.commit() + db_user: Optional[User] = await session.scalar(db_user_stmt) + if not db_user: + db_user = User(uid=user, type=user_type) + session.add(db_user) + db_target_stmt = ( + select(MTarget) + .where(MTarget.platform_name == platform_name) + .where(MTarget.target == target) + ) + db_target: Optional[MTarget] = await session.scalar(db_target_stmt) + if not db_target: + db_target = MTarget( + target=target, platform_name=platform_name, target_name=target_name + ) + else: + db_target.target_name = target_name # type: ignore + subscribe = MSubscribe( + categories=cats, + tags=tags, + user=db_user, + target=db_target, + ) + session.add(subscribe) + await session.commit() + + async def list_subscribe(self, user: int, user_type: str) -> list[MSubscribe]: + async with AsyncSession(get_engine()) as session: + query_stmt = ( + select(MSubscribe) + .where(User.type == user_type and User.uid == user) + .join(User) + .options(selectinload(MSubscribe.target)) + ) # type:ignore + subs: list[MSubscribe] = (await session.scalars(query_stmt)).all() + return subs + + async def del_subscribe( + self, user: int, user_type: str, target: str, platform_name: str + ): + async with AsyncSession(get_engine()) as session: + user_obj = await session.scalar( + select(User).where(User.uid == user and User.type == user_type) + ) + target_obj = await session.scalar( + select(MTarget).where( + MTarget.platform_name == platform_name and MTarget.target == target + ) + ) + await session.execute( + delete(MSubscribe).where( + MSubscribe.user == user_obj and MSubscribe.target == target_obj + ) + ) + target_count = await session.scalar( + select(func.count()) + .select_from(MSubscribe) + .where(MSubscribe.target == target_obj) + ) + if target_count == 0: + # delete empty target + await session.delete(target_obj) + await session.commit() + + async def update_subscribe( + self, + user: int, + user_type: str, + target: str, + target_name: str, + platform_name: str, + cats: list, + tags: list, + ): + async with AsyncSession(get_engine()) as sess: + subscribe_obj: MSubscribe = await sess.scalar( + select(MSubscribe) + .where( + User.uid == user + and User.type == user_type + and MTarget.target == target + and MTarget.platform_name == platform_name + ) + .join(User) + .join(MTarget) + ) + subscribe_obj.tags = tags # type:ignore + subscribe_obj.categories = cats # type:ignore + await sess.commit() config = DBConfig() diff --git a/src/plugins/nonebot_bison/config/db_model.py b/src/plugins/nonebot_bison/config/db_model.py index c905bc3..75faa12 100644 --- a/src/plugins/nonebot_bison/config/db_model.py +++ b/src/plugins/nonebot_bison/config/db_model.py @@ -1,35 +1,52 @@ -from sqlalchemy.orm import declarative_base, relationship -from sqlalchemy.sql.schema import Column, ForeignKey -from sqlalchemy.sql.sqltypes import Integer, String +from datetime import datetime + +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship +from sqlalchemy.sql.schema import Column, ForeignKey, UniqueConstraint +from sqlalchemy.sql.sqltypes import JSON, DateTime, Integer, String Base = declarative_base() class User(Base): __tablename__ = "user" + __table_args__ = (UniqueConstraint("type", "uid", name="unique-user-constraint"),) id = Column(Integer, primary_key=True, autoincrement=True) type = Column(String(20), nullable=False) uid = Column(Integer, nullable=False) + subscribes = relationship("Subscribe", back_populates="user") + class Target(Base): __tablename__ = "target" + __table_args__ = ( + UniqueConstraint("target", "platform_name", name="unique-target-constraint"), + ) id = Column(Integer, primary_key=True, autoincrement=True) platform_name = Column(String(20), nullable=False) target = Column(String(1024), nullable=False) target_name = Column(String(1024), nullable=False) + last_schedule_time = Column( + DateTime(timezone=True), default=datetime(year=2000, month=1, day=1) + ) + + subscribes = relationship("Subscribe", back_populates="target") class Subscribe(Base): __tablename__ = "subscribe" + __table_args__ = ( + UniqueConstraint("target_id", "user_id", name="unique-subscribe-constraint"), + ) id = Column(Integer, primary_key=True, autoincrement=True) target_id = Column(Integer, ForeignKey(Target.id)) user_id = Column(Integer, ForeignKey(User.id)) - categories = Column(String(1024)) - tags = Column(String(1024)) + categories = Column(JSON) + tags = Column(JSON) - target = relationship("Target") - user = relationship("User") + target = relationship("Target", back_populates="subscribes") + user = relationship("User", back_populates="subscribes") diff --git a/src/plugins/nonebot_bison/config/migrate/env.py b/src/plugins/nonebot_bison/config/migrate/env.py index 6ab47e5..79790cc 100644 --- a/src/plugins/nonebot_bison/config/migrate/env.py +++ b/src/plugins/nonebot_bison/config/migrate/env.py @@ -66,7 +66,12 @@ def do_run_migration(connection: Connection): if __as_plugin: context.configure(connection=connection) else: - context.configure(connection=connection, target_metadata=target_metadata) + context.configure( + connection=connection, + target_metadata=target_metadata, + render_as_batch=True, + compare_type=True, + ) with context.begin_transaction(): context.run_migrations() diff --git a/src/plugins/nonebot_bison/config/migrate/versions/4a46ba54a3f3_alter_type.py b/src/plugins/nonebot_bison/config/migrate/versions/4a46ba54a3f3_alter_type.py new file mode 100644 index 0000000..0c3a602 --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrate/versions/4a46ba54a3f3_alter_type.py @@ -0,0 +1,53 @@ +"""alter type + +Revision ID: 4a46ba54a3f3 +Revises: c97c445e2bdb +Create Date: 2022-03-27 21:50:10.911649 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "4a46ba54a3f3" +down_revision = "c97c445e2bdb" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("subscribe", schema=None) as batch_op: + batch_op.alter_column( + "categories", + existing_type=sa.VARCHAR(length=1024), + type_=sa.JSON(), + existing_nullable=True, + ) + batch_op.alter_column( + "tags", + existing_type=sa.VARCHAR(length=1024), + type_=sa.JSON(), + existing_nullable=True, + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("subscribe", schema=None) as batch_op: + batch_op.alter_column( + "tags", + existing_type=sa.JSON(), + type_=sa.VARCHAR(length=1024), + existing_nullable=True, + ) + batch_op.alter_column( + "categories", + existing_type=sa.JSON(), + type_=sa.VARCHAR(length=1024), + existing_nullable=True, + ) + + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config/migrate/versions/a333d6224193_add_last_scheduled_time.py b/src/plugins/nonebot_bison/config/migrate/versions/a333d6224193_add_last_scheduled_time.py new file mode 100644 index 0000000..43848e5 --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrate/versions/a333d6224193_add_last_scheduled_time.py @@ -0,0 +1,33 @@ +"""add last scheduled time + +Revision ID: a333d6224193 +Revises: 4a46ba54a3f3 +Create Date: 2022-03-29 21:01:38.213153 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a333d6224193" +down_revision = "4a46ba54a3f3" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("target", schema=None) as batch_op: + batch_op.add_column( + sa.Column("last_schedule_time", sa.DateTime(timezone=True), nullable=True) + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("target", schema=None) as batch_op: + batch_op.drop_column("last_schedule_time") + + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config/migrate/versions/c97c445e2bdb_add_constraint.py b/src/plugins/nonebot_bison/config/migrate/versions/c97c445e2bdb_add_constraint.py new file mode 100644 index 0000000..9119d3b --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrate/versions/c97c445e2bdb_add_constraint.py @@ -0,0 +1,47 @@ +"""add constraint + +Revision ID: c97c445e2bdb +Revises: 0571870f5222 +Create Date: 2022-03-26 19:46:50.910721 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "c97c445e2bdb" +down_revision = "0571870f5222" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("subscribe", schema=None) as batch_op: + batch_op.create_unique_constraint( + "unique-subscribe-constraint", ["target_id", "user_id"] + ) + + with op.batch_alter_table("target", schema=None) as batch_op: + batch_op.create_unique_constraint( + "unique-target-constraint", ["target", "platform_name"] + ) + + with op.batch_alter_table("user", schema=None) as batch_op: + batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"]) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("user", schema=None) as batch_op: + batch_op.drop_constraint("unique-user-constraint", type_="unique") + + with op.batch_alter_table("target", schema=None) as batch_op: + batch_op.drop_constraint("unique-target-constraint", type_="unique") + + with op.batch_alter_table("subscribe", schema=None) as batch_op: + batch_op.drop_constraint("unique-subscribe-constraint", type_="unique") + + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config_manager.py b/src/plugins/nonebot_bison/config_manager.py index dc12594..be4c58b 100644 --- a/src/plugins/nonebot_bison/config_manager.py +++ b/src/plugins/nonebot_bison/config_manager.py @@ -1,6 +1,6 @@ import asyncio from datetime import datetime -from typing import Optional, Type +from typing import Optional, Type, cast from nonebot import on_command from nonebot.adapters.onebot.v11 import Bot, Event, MessageEvent @@ -188,16 +188,16 @@ def do_add_sub(add_sub: Type[Matcher]): @add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)]) async def add_sub_process(event: Event, state: T_State): - user = state.get("target_user_info") + user = cast(User, state.get("target_user_info")) assert isinstance(user, User) - config.add_subscribe( + await config.add_subscribe( # state.get("_user_id") or event.group_id, # user_type="group", user=user.user, user_type=user.user_type, target=state["id"], target_name=state["name"], - target_type=state["platform"], + platform_name=state["platform"], cats=state.get("cats", []), tags=state.get("tags", []), ) @@ -211,7 +211,7 @@ def do_query_sub(query_sub: Type[Matcher]): async def _(state: T_State): user_info = state["target_user_info"] assert isinstance(user_info, User) - sub_list = config.list_subscribe( + sub_list = await config.list_subscribe( # state.get("_user_id") or event.group_id, "group" user_info.user, user_info.user_type, @@ -219,17 +219,20 @@ def do_query_sub(query_sub: Type[Matcher]): res = "订阅的帐号为:\n" for sub in sub_list: res += "{} {} {}".format( - sub["target_type"], sub["target_name"], sub["target"] + # sub["target_type"], sub["target_name"], sub["target"] + sub.target.platform_name, + sub.target.target_name, + sub.target.target, ) - platform = platform_manager[sub["target_type"]] + platform = platform_manager[sub.target.platform_name] if platform.categories: res += " [{}]".format( ", ".join( - map(lambda x: platform.categories[Category(x)], sub["cats"]) + map(lambda x: platform.categories[Category(x)], sub.categories) ) ) if platform.enable_tag: - res += " {}".format(", ".join(sub["tags"])) + res += " {}".format(", ".join(sub.tags)) res += "\n" await query_sub.finish(Message(await parse_text(res))) @@ -241,7 +244,7 @@ def do_del_sub(del_sub: Type[Matcher]): async def send_list(bot: Bot, event: Event, state: T_State): user_info = state["target_user_info"] assert isinstance(user_info, User) - sub_list = config.list_subscribe( + sub_list = await config.list_subscribe( # state.get("_user_id") or event.group_id, "group" user_info.user, user_info.user_type, @@ -250,21 +253,24 @@ def do_del_sub(del_sub: Type[Matcher]): state["sub_table"] = {} for index, sub in enumerate(sub_list, 1): state["sub_table"][index] = { - "target_type": sub["target_type"], - "target": sub["target"], + "platform_name": sub.target.platform_name, + "target": sub.target.target, } res += "{} {} {} {}\n".format( - index, sub["target_type"], sub["target_name"], sub["target"] + index, + sub.target.platform_name, + sub.target.target_name, + sub.target.target, ) - platform = platform_manager[sub["target_type"]] + platform = platform_manager[sub.target.platform_name] if platform.categories: res += " [{}]".format( ", ".join( - map(lambda x: platform.categories[Category(x)], sub["cats"]) + map(lambda x: platform.categories[Category(x)], sub.categories) ) ) if platform.enable_tag: - res += " {}".format(", ".join(sub["tags"])) + res += " {}".format(", ".join(sub.tags)) res += "\n" res += "请输入要删除的订阅的序号" await bot.send(event=event, message=Message(await parse_text(res))) @@ -275,7 +281,7 @@ def do_del_sub(del_sub: Type[Matcher]): index = int(str(event.get_message()).strip()) user_info = state["target_user_info"] assert isinstance(user_info, User) - config.del_subscribe( + await config.del_subscribe( # state.get("_user_id") or event.group_id, # "group", user_info.user, diff --git a/src/plugins/nonebot_bison/platform/platform.py b/src/plugins/nonebot_bison/platform/platform.py index bc30824..e1c3471 100644 --- a/src/plugins/nonebot_bison/platform/platform.py +++ b/src/plugins/nonebot_bison/platform/platform.py @@ -1,11 +1,11 @@ import time from abc import ABC, abstractmethod from collections import defaultdict +from dataclasses import dataclass from typing import Any, Collection, Literal, Optional import httpx from nonebot.log import logger -from pydantic.dataclasses import dataclass from ..plugin_config import plugin_config from ..post import Post diff --git a/src/plugins/nonebot_bison/post.py b/src/plugins/nonebot_bison/post.py index 0ed0cc5..ceca521 100644 --- a/src/plugins/nonebot_bison/post.py +++ b/src/plugins/nonebot_bison/post.py @@ -1,4 +1,3 @@ -# from pydantic.dataclasses import dataclass from dataclasses import dataclass, field from functools import reduce from io import BytesIO diff --git a/src/plugins/nonebot_bison/types.py b/src/plugins/nonebot_bison/types.py index c3fc5be..954e90e 100644 --- a/src/plugins/nonebot_bison/types.py +++ b/src/plugins/nonebot_bison/types.py @@ -1,7 +1,6 @@ +from dataclasses import dataclass from typing import Any, Callable, Literal, NamedTuple, NewType -from pydantic.dataclasses import dataclass - RawPost = NewType("RawPost", Any) Target = NewType("Target", str) Category = int diff --git a/tests/config/test_config_operation.py b/tests/config/test_config_operation.py index 3c3bba5..f019654 100644 --- a/tests/config/test_config_operation.py +++ b/tests/config/test_config_operation.py @@ -1,17 +1,96 @@ from nonebug.app import App +from sqlalchemy.ext.asyncio.session import AsyncSession +from sqlalchemy.sql.functions import func +from sqlmodel.sql.expression import select -async def test_add_subscrib(app: App): +async def test_add_subscribe(app: App, db_migration): from nonebot_bison.config.db_config import config - from nonebot_bison.types import Target + from nonebot_bison.config.db_model import Subscribe, Target, User + from nonebot_bison.types import Target as TTarget + from nonebot_plugin_datastore.db import get_engine await config.add_subscribe( user=123, user_type="group", - target=Target("weibo_id"), + target=TTarget("weibo_id"), target_name="weibo_name", platform_name="weibo", cats=[], tags=[], ) + confs = await config.list_subscribe(123, "group") + assert len(confs) == 1 + conf: Subscribe = confs[0] + async with AsyncSession(get_engine()) as sess: + related_user_obj = await sess.scalar( + select(User).where(User.id == conf.user_id) + ) + related_target_obj = await sess.scalar( + select(Target).where(Target.id == conf.target_id) + ) + assert related_user_obj.uid == 123 + assert related_target_obj.target_name == "weibo_name" + assert related_target_obj.target == "weibo_id" + assert conf.target.target == "weibo_id" + assert conf.categories == [] + + +async def test_del_subsribe(db_migration): + from nonebot_bison.config.db_config import config + from nonebot_bison.config.db_model import Subscribe, Target, User + from nonebot_bison.types import Target as TTarget + from nonebot_plugin_datastore.db import get_engine + + await config.add_subscribe( + user=123, + user_type="group", + target=TTarget("weibo_id"), + target_name="weibo_name", + platform_name="weibo", + cats=[], + tags=[], + ) + await config.del_subscribe( + user=123, + user_type="group", + target=TTarget("weibo_id"), + platform_name="weibo", + ) + async with AsyncSession(get_engine()) as sess: + assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 0 + assert (await sess.scalar(select(func.count()).select_from(Target))) == 0 + + await config.add_subscribe( + user=123, + user_type="group", + target=TTarget("weibo_id"), + target_name="weibo_name", + platform_name="weibo", + cats=[], + tags=[], + ) + + await config.add_subscribe( + user=124, + user_type="group", + target=TTarget("weibo_id"), + target_name="weibo_name_new", + platform_name="weibo", + cats=[], + tags=[], + ) + + await config.del_subscribe( + user=123, + user_type="group", + target=TTarget("weibo_id"), + platform_name="weibo", + ) + + async with AsyncSession(get_engine()) as sess: + assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 1 + assert (await sess.scalar(select(func.count()).select_from(Target))) == 1 + target: Target = await sess.scalar(select(Target)) + assert target.target_name == "weibo_name_new" diff --git a/tests/conftest.py b/tests/conftest.py index 18a0691..7e90b45 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,8 @@ from pathlib import Path import nonebot import pytest from nonebug.app import App +from sqlalchemy.ext.asyncio.session import AsyncSession +from sqlalchemy.sql.expression import delete @pytest.fixture @@ -12,7 +14,10 @@ async def app(nonebug_init: None, tmp_path: Path, monkeypatch: pytest.MonkeyPatc import nonebot config = nonebot.get_driver().config - config.bison_config_path = str(tmp_path) + config.bison_config_path = str(tmp_path / "legacy_config") + config.datastore_config_dir = str(tmp_path / "config") + config.datastore_cache_dir = str(tmp_path / "cache") + config.datastore_data_dir = str(tmp_path / "data") config.command_start = {""} config.superusers = {"10001"} config.log_level = "TRACE" @@ -29,19 +34,15 @@ def dummy_user_subinfo(app: App): @pytest.fixture -def task_watchdog(request): - def cancel_test_on_exception(task: asyncio.Task): - def maybe_cancel_clbk(t: asyncio.Task): - exception = t.exception() - if exception is None: - return +async def db_migration(app: App): + from nonebot_bison.config.db import upgrade_db + from nonebot_bison.config.db_model import Subscribe, Target, User + from nonebot_plugin_datastore.db import get_engine - for task in asyncio.all_tasks(): - coro = task.get_coro() - if coro.__qualname__ == request.function.__qualname__: - task.cancel() - return - - task.add_done_callback(maybe_cancel_clbk) - - return cancel_test_on_exception + await upgrade_db() + async with AsyncSession(get_engine()) as sess: + await sess.execute(delete(User)) + await sess.execute(delete(Subscribe)) + await sess.execute(delete(Target)) + await sess.commit() + await sess.close()