This commit is contained in:
felinae98 2022-03-29 22:43:39 +08:00
parent df23648b0f
commit cf35432757
No known key found for this signature in database
GPG Key ID: 00C8B010587FF610
15 changed files with 439 additions and 132 deletions

View File

@ -11,12 +11,12 @@ repos:
- id: isort - id: isort
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 22.1.0 rev: 22.3.0
hooks: hooks:
- id: black - id: black
- repo: https://github.com/pre-commit/mirrors-prettier - repo: https://github.com/pre-commit/mirrors-prettier
rev: v2.5.1 rev: v2.6.1
hooks: hooks:
- id: prettier - id: prettier
types_or: [markdown, ts, tsx] types_or: [markdown, ts, tsx]

View File

@ -1,2 +1,3 @@
from .config_legacy import NoSuchSubscribeException, NoSuchUserException, config from .config_legacy import NoSuchSubscribeException, NoSuchUserException
from .db import DATA from .db import DATA
from .db_config import config

View File

@ -1,4 +1,3 @@
import json
from pathlib import Path from pathlib import Path
import nonebot import nonebot
@ -7,7 +6,9 @@ from alembic.runtime.environment import EnvironmentContext
from alembic.script.base import ScriptDirectory from alembic.script.base import ScriptDirectory
from nonebot.log import logger from nonebot.log import logger
from nonebot_plugin_datastore import PluginData, create_session, db from nonebot_plugin_datastore import PluginData, create_session, db
from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.engine.base import Connection from sqlalchemy.engine.base import Connection
from sqlalchemy.ext.asyncio.session import AsyncSession
from .config_legacy import ConfigContent, config from .config_legacy import ConfigContent, config
from .db_model import Base, Subscribe, Target, User from .db_model import Base, Subscribe, Target, User
@ -21,8 +22,7 @@ async def data_migrate():
all_subs: list[ConfigContent] = list( all_subs: list[ConfigContent] = list(
map(lambda item: ConfigContent(**item), config.get_all_subscribe().all()) map(lambda item: ConfigContent(**item), config.get_all_subscribe().all())
) )
print(all_subs) async with AsyncSession(get_engine()) as sess:
sess = create_session()
user_to_create = [] user_to_create = []
subscribe_to_create = [] subscribe_to_create = []
platform_target_map: dict[str, tuple[Target, str, int]] = {} platform_target_map: dict[str, tuple[Target, str, int]] = {}
@ -57,8 +57,8 @@ async def data_migrate():
subscribe_obj = Subscribe( subscribe_obj = Subscribe(
user=db_user, user=db_user,
target=target_obj, target=target_obj,
categories=json.dumps(sub["cats"]), categories=sub["cats"],
tags=json.dumps(sub["tags"]), tags=sub["tags"],
) )
subscribe_to_create.append(subscribe_obj) subscribe_to_create.append(subscribe_obj)
sess.add_all( sess.add_all(
@ -85,7 +85,12 @@ async def upgrade_db():
return script._upgrade_revs("head", revision) return script._upgrade_revs("head", revision)
def do_run_migration(connection: Connection): def do_run_migration(connection: Connection):
env.configure(connection, target_metadata=Base.metadata, fn=migrate_fun) env.configure(
connection,
target_metadata=Base.metadata,
fn=migrate_fun,
render_as_batch=True,
)
with env.begin_transaction(): with env.begin_transaction():
env.run_migrations() env.run_migrations()
logger.info("Finish auto migrate") logger.info("Finish auto migrate")

View File

@ -1,9 +1,11 @@
import json
from typing import Optional from typing import Optional
from nonebot_bison.types import Category, Tag, Target from nonebot_bison.types import Category, Tag, Target
from nonebot_plugin_datastore.db import create_session from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.sql.expression import select from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.sql.expression import delete, select
from sqlalchemy.sql.functions import func
from .db_model import Subscribe as MSubscribe from .db_model import Subscribe as MSubscribe
from .db_model import Target as MTarget from .db_model import Target as MTarget
@ -11,9 +13,6 @@ from .db_model import User
class DBConfig: class DBConfig:
def __init__(self):
self.session = create_session()
async def add_subscribe( async def add_subscribe(
self, self,
user: int, user: int,
@ -24,21 +23,20 @@ class DBConfig:
cats: list[Category], cats: list[Category],
tags: list[Tag], tags: list[Tag],
): ):
async with AsyncSession(get_engine()) as session:
db_user_stmt = ( db_user_stmt = (
select(User).where(User.uid == user).where(User.type == user_type) select(User).where(User.uid == user).where(User.type == user_type)
) )
db_user: Optional[User] = (await self.session.scalars(db_user_stmt)).first() db_user: Optional[User] = await session.scalar(db_user_stmt)
if not db_user: if not db_user:
db_user = User(uid=user, type=user_type) db_user = User(uid=user, type=user_type)
self.session.add(db_user) session.add(db_user)
db_target_stmt = ( db_target_stmt = (
select(MTarget) select(MTarget)
.where(MTarget.platform_name == platform_name) .where(MTarget.platform_name == platform_name)
.where(MTarget.target == target) .where(MTarget.target == target)
) )
db_target: Optional[MTarget] = ( db_target: Optional[MTarget] = await session.scalar(db_target_stmt)
await self.session.scalars(db_target_stmt)
).first()
if not db_target: if not db_target:
db_target = MTarget( db_target = MTarget(
target=target, platform_name=platform_name, target_name=target_name target=target, platform_name=platform_name, target_name=target_name
@ -46,13 +44,77 @@ class DBConfig:
else: else:
db_target.target_name = target_name # type: ignore db_target.target_name = target_name # type: ignore
subscribe = MSubscribe( subscribe = MSubscribe(
categories=json.dumps(cats), categories=cats,
tags=json.dumps(tags), tags=tags,
user=db_user, user=db_user,
target=db_target, target=db_target,
) )
self.session.add(subscribe) session.add(subscribe)
await self.session.commit() await session.commit()
async def list_subscribe(self, user: int, user_type: str) -> list[MSubscribe]:
async with AsyncSession(get_engine()) as session:
query_stmt = (
select(MSubscribe)
.where(User.type == user_type and User.uid == user)
.join(User)
.options(selectinload(MSubscribe.target))
) # type:ignore
subs: list[MSubscribe] = (await session.scalars(query_stmt)).all()
return subs
async def del_subscribe(
self, user: int, user_type: str, target: str, platform_name: str
):
async with AsyncSession(get_engine()) as session:
user_obj = await session.scalar(
select(User).where(User.uid == user and User.type == user_type)
)
target_obj = await session.scalar(
select(MTarget).where(
MTarget.platform_name == platform_name and MTarget.target == target
)
)
await session.execute(
delete(MSubscribe).where(
MSubscribe.user == user_obj and MSubscribe.target == target_obj
)
)
target_count = await session.scalar(
select(func.count())
.select_from(MSubscribe)
.where(MSubscribe.target == target_obj)
)
if target_count == 0:
# delete empty target
await session.delete(target_obj)
await session.commit()
async def update_subscribe(
self,
user: int,
user_type: str,
target: str,
target_name: str,
platform_name: str,
cats: list,
tags: list,
):
async with AsyncSession(get_engine()) as sess:
subscribe_obj: MSubscribe = await sess.scalar(
select(MSubscribe)
.where(
User.uid == user
and User.type == user_type
and MTarget.target == target
and MTarget.platform_name == platform_name
)
.join(User)
.join(MTarget)
)
subscribe_obj.tags = tags # type:ignore
subscribe_obj.categories = cats # type:ignore
await sess.commit()
config = DBConfig() config = DBConfig()

View File

@ -1,35 +1,52 @@
from sqlalchemy.orm import declarative_base, relationship from datetime import datetime
from sqlalchemy.sql.schema import Column, ForeignKey
from sqlalchemy.sql.sqltypes import Integer, String from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.sql.schema import Column, ForeignKey, UniqueConstraint
from sqlalchemy.sql.sqltypes import JSON, DateTime, Integer, String
Base = declarative_base() Base = declarative_base()
class User(Base): class User(Base):
__tablename__ = "user" __tablename__ = "user"
__table_args__ = (UniqueConstraint("type", "uid", name="unique-user-constraint"),)
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
type = Column(String(20), nullable=False) type = Column(String(20), nullable=False)
uid = Column(Integer, nullable=False) uid = Column(Integer, nullable=False)
subscribes = relationship("Subscribe", back_populates="user")
class Target(Base): class Target(Base):
__tablename__ = "target" __tablename__ = "target"
__table_args__ = (
UniqueConstraint("target", "platform_name", name="unique-target-constraint"),
)
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
platform_name = Column(String(20), nullable=False) platform_name = Column(String(20), nullable=False)
target = Column(String(1024), nullable=False) target = Column(String(1024), nullable=False)
target_name = Column(String(1024), nullable=False) target_name = Column(String(1024), nullable=False)
last_schedule_time = Column(
DateTime(timezone=True), default=datetime(year=2000, month=1, day=1)
)
subscribes = relationship("Subscribe", back_populates="target")
class Subscribe(Base): class Subscribe(Base):
__tablename__ = "subscribe" __tablename__ = "subscribe"
__table_args__ = (
UniqueConstraint("target_id", "user_id", name="unique-subscribe-constraint"),
)
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
target_id = Column(Integer, ForeignKey(Target.id)) target_id = Column(Integer, ForeignKey(Target.id))
user_id = Column(Integer, ForeignKey(User.id)) user_id = Column(Integer, ForeignKey(User.id))
categories = Column(String(1024)) categories = Column(JSON)
tags = Column(String(1024)) tags = Column(JSON)
target = relationship("Target") target = relationship("Target", back_populates="subscribes")
user = relationship("User") user = relationship("User", back_populates="subscribes")

View File

@ -66,7 +66,12 @@ def do_run_migration(connection: Connection):
if __as_plugin: if __as_plugin:
context.configure(connection=connection) context.configure(connection=connection)
else: else:
context.configure(connection=connection, target_metadata=target_metadata) context.configure(
connection=connection,
target_metadata=target_metadata,
render_as_batch=True,
compare_type=True,
)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()

View File

@ -0,0 +1,53 @@
"""alter type
Revision ID: 4a46ba54a3f3
Revises: c97c445e2bdb
Create Date: 2022-03-27 21:50:10.911649
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "4a46ba54a3f3"
down_revision = "c97c445e2bdb"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("subscribe", schema=None) as batch_op:
batch_op.alter_column(
"categories",
existing_type=sa.VARCHAR(length=1024),
type_=sa.JSON(),
existing_nullable=True,
)
batch_op.alter_column(
"tags",
existing_type=sa.VARCHAR(length=1024),
type_=sa.JSON(),
existing_nullable=True,
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("subscribe", schema=None) as batch_op:
batch_op.alter_column(
"tags",
existing_type=sa.JSON(),
type_=sa.VARCHAR(length=1024),
existing_nullable=True,
)
batch_op.alter_column(
"categories",
existing_type=sa.JSON(),
type_=sa.VARCHAR(length=1024),
existing_nullable=True,
)
# ### end Alembic commands ###

View File

@ -0,0 +1,33 @@
"""add last scheduled time
Revision ID: a333d6224193
Revises: 4a46ba54a3f3
Create Date: 2022-03-29 21:01:38.213153
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "a333d6224193"
down_revision = "4a46ba54a3f3"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("target", schema=None) as batch_op:
batch_op.add_column(
sa.Column("last_schedule_time", sa.DateTime(timezone=True), nullable=True)
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("target", schema=None) as batch_op:
batch_op.drop_column("last_schedule_time")
# ### end Alembic commands ###

View File

@ -0,0 +1,47 @@
"""add constraint
Revision ID: c97c445e2bdb
Revises: 0571870f5222
Create Date: 2022-03-26 19:46:50.910721
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "c97c445e2bdb"
down_revision = "0571870f5222"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("subscribe", schema=None) as batch_op:
batch_op.create_unique_constraint(
"unique-subscribe-constraint", ["target_id", "user_id"]
)
with op.batch_alter_table("target", schema=None) as batch_op:
batch_op.create_unique_constraint(
"unique-target-constraint", ["target", "platform_name"]
)
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.drop_constraint("unique-user-constraint", type_="unique")
with op.batch_alter_table("target", schema=None) as batch_op:
batch_op.drop_constraint("unique-target-constraint", type_="unique")
with op.batch_alter_table("subscribe", schema=None) as batch_op:
batch_op.drop_constraint("unique-subscribe-constraint", type_="unique")
# ### end Alembic commands ###

View File

@ -1,6 +1,6 @@
import asyncio import asyncio
from datetime import datetime from datetime import datetime
from typing import Optional, Type from typing import Optional, Type, cast
from nonebot import on_command from nonebot import on_command
from nonebot.adapters.onebot.v11 import Bot, Event, MessageEvent from nonebot.adapters.onebot.v11 import Bot, Event, MessageEvent
@ -188,16 +188,16 @@ def do_add_sub(add_sub: Type[Matcher]):
@add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)]) @add_sub.got("tags", _gen_prompt_template("{_prompt}"), [Depends(parser_tags)])
async def add_sub_process(event: Event, state: T_State): async def add_sub_process(event: Event, state: T_State):
user = state.get("target_user_info") user = cast(User, state.get("target_user_info"))
assert isinstance(user, User) assert isinstance(user, User)
config.add_subscribe( await config.add_subscribe(
# state.get("_user_id") or event.group_id, # state.get("_user_id") or event.group_id,
# user_type="group", # user_type="group",
user=user.user, user=user.user,
user_type=user.user_type, user_type=user.user_type,
target=state["id"], target=state["id"],
target_name=state["name"], target_name=state["name"],
target_type=state["platform"], platform_name=state["platform"],
cats=state.get("cats", []), cats=state.get("cats", []),
tags=state.get("tags", []), tags=state.get("tags", []),
) )
@ -211,7 +211,7 @@ def do_query_sub(query_sub: Type[Matcher]):
async def _(state: T_State): async def _(state: T_State):
user_info = state["target_user_info"] user_info = state["target_user_info"]
assert isinstance(user_info, User) assert isinstance(user_info, User)
sub_list = config.list_subscribe( sub_list = await config.list_subscribe(
# state.get("_user_id") or event.group_id, "group" # state.get("_user_id") or event.group_id, "group"
user_info.user, user_info.user,
user_info.user_type, user_info.user_type,
@ -219,17 +219,20 @@ def do_query_sub(query_sub: Type[Matcher]):
res = "订阅的帐号为:\n" res = "订阅的帐号为:\n"
for sub in sub_list: for sub in sub_list:
res += "{} {} {}".format( res += "{} {} {}".format(
sub["target_type"], sub["target_name"], sub["target"] # sub["target_type"], sub["target_name"], sub["target"]
sub.target.platform_name,
sub.target.target_name,
sub.target.target,
) )
platform = platform_manager[sub["target_type"]] platform = platform_manager[sub.target.platform_name]
if platform.categories: if platform.categories:
res += " [{}]".format( res += " [{}]".format(
", ".join( ", ".join(
map(lambda x: platform.categories[Category(x)], sub["cats"]) map(lambda x: platform.categories[Category(x)], sub.categories)
) )
) )
if platform.enable_tag: if platform.enable_tag:
res += " {}".format(", ".join(sub["tags"])) res += " {}".format(", ".join(sub.tags))
res += "\n" res += "\n"
await query_sub.finish(Message(await parse_text(res))) await query_sub.finish(Message(await parse_text(res)))
@ -241,7 +244,7 @@ def do_del_sub(del_sub: Type[Matcher]):
async def send_list(bot: Bot, event: Event, state: T_State): async def send_list(bot: Bot, event: Event, state: T_State):
user_info = state["target_user_info"] user_info = state["target_user_info"]
assert isinstance(user_info, User) assert isinstance(user_info, User)
sub_list = config.list_subscribe( sub_list = await config.list_subscribe(
# state.get("_user_id") or event.group_id, "group" # state.get("_user_id") or event.group_id, "group"
user_info.user, user_info.user,
user_info.user_type, user_info.user_type,
@ -250,21 +253,24 @@ def do_del_sub(del_sub: Type[Matcher]):
state["sub_table"] = {} state["sub_table"] = {}
for index, sub in enumerate(sub_list, 1): for index, sub in enumerate(sub_list, 1):
state["sub_table"][index] = { state["sub_table"][index] = {
"target_type": sub["target_type"], "platform_name": sub.target.platform_name,
"target": sub["target"], "target": sub.target.target,
} }
res += "{} {} {} {}\n".format( res += "{} {} {} {}\n".format(
index, sub["target_type"], sub["target_name"], sub["target"] index,
sub.target.platform_name,
sub.target.target_name,
sub.target.target,
) )
platform = platform_manager[sub["target_type"]] platform = platform_manager[sub.target.platform_name]
if platform.categories: if platform.categories:
res += " [{}]".format( res += " [{}]".format(
", ".join( ", ".join(
map(lambda x: platform.categories[Category(x)], sub["cats"]) map(lambda x: platform.categories[Category(x)], sub.categories)
) )
) )
if platform.enable_tag: if platform.enable_tag:
res += " {}".format(", ".join(sub["tags"])) res += " {}".format(", ".join(sub.tags))
res += "\n" res += "\n"
res += "请输入要删除的订阅的序号" res += "请输入要删除的订阅的序号"
await bot.send(event=event, message=Message(await parse_text(res))) await bot.send(event=event, message=Message(await parse_text(res)))
@ -275,7 +281,7 @@ def do_del_sub(del_sub: Type[Matcher]):
index = int(str(event.get_message()).strip()) index = int(str(event.get_message()).strip())
user_info = state["target_user_info"] user_info = state["target_user_info"]
assert isinstance(user_info, User) assert isinstance(user_info, User)
config.del_subscribe( await config.del_subscribe(
# state.get("_user_id") or event.group_id, # state.get("_user_id") or event.group_id,
# "group", # "group",
user_info.user, user_info.user,

View File

@ -1,11 +1,11 @@
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Collection, Literal, Optional from typing import Any, Collection, Literal, Optional
import httpx import httpx
from nonebot.log import logger from nonebot.log import logger
from pydantic.dataclasses import dataclass
from ..plugin_config import plugin_config from ..plugin_config import plugin_config
from ..post import Post from ..post import Post

View File

@ -1,4 +1,3 @@
# from pydantic.dataclasses import dataclass
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import reduce from functools import reduce
from io import BytesIO from io import BytesIO

View File

@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import Any, Callable, Literal, NamedTuple, NewType from typing import Any, Callable, Literal, NamedTuple, NewType
from pydantic.dataclasses import dataclass
RawPost = NewType("RawPost", Any) RawPost = NewType("RawPost", Any)
Target = NewType("Target", str) Target = NewType("Target", str)
Category = int Category = int

View File

@ -1,17 +1,96 @@
from nonebug.app import App from nonebug.app import App
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.functions import func
from sqlmodel.sql.expression import select
async def test_add_subscrib(app: App): async def test_add_subscribe(app: App, db_migration):
from nonebot_bison.config.db_config import config from nonebot_bison.config.db_config import config
from nonebot_bison.types import Target from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as TTarget
from nonebot_plugin_datastore.db import get_engine
await config.add_subscribe( await config.add_subscribe(
user=123, user=123,
user_type="group", user_type="group",
target=Target("weibo_id"), target=TTarget("weibo_id"),
target_name="weibo_name", target_name="weibo_name",
platform_name="weibo", platform_name="weibo",
cats=[], cats=[],
tags=[], tags=[],
) )
confs = await config.list_subscribe(123, "group")
assert len(confs) == 1
conf: Subscribe = confs[0]
async with AsyncSession(get_engine()) as sess:
related_user_obj = await sess.scalar(
select(User).where(User.id == conf.user_id)
)
related_target_obj = await sess.scalar(
select(Target).where(Target.id == conf.target_id)
)
assert related_user_obj.uid == 123
assert related_target_obj.target_name == "weibo_name"
assert related_target_obj.target == "weibo_id"
assert conf.target.target == "weibo_id"
assert conf.categories == []
async def test_del_subsribe(db_migration):
from nonebot_bison.config.db_config import config
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as TTarget
from nonebot_plugin_datastore.db import get_engine
await config.add_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
await config.del_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
platform_name="weibo",
)
async with AsyncSession(get_engine()) as sess:
assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 0
assert (await sess.scalar(select(func.count()).select_from(Target))) == 0
await config.add_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
await config.add_subscribe(
user=124,
user_type="group",
target=TTarget("weibo_id"),
target_name="weibo_name_new",
platform_name="weibo",
cats=[],
tags=[],
)
await config.del_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
platform_name="weibo",
)
async with AsyncSession(get_engine()) as sess:
assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 1
assert (await sess.scalar(select(func.count()).select_from(Target))) == 1
target: Target = await sess.scalar(select(Target))
assert target.target_name == "weibo_name_new"

View File

@ -5,6 +5,8 @@ from pathlib import Path
import nonebot import nonebot
import pytest import pytest
from nonebug.app import App from nonebug.app import App
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.expression import delete
@pytest.fixture @pytest.fixture
@ -12,7 +14,10 @@ async def app(nonebug_init: None, tmp_path: Path, monkeypatch: pytest.MonkeyPatc
import nonebot import nonebot
config = nonebot.get_driver().config config = nonebot.get_driver().config
config.bison_config_path = str(tmp_path) config.bison_config_path = str(tmp_path / "legacy_config")
config.datastore_config_dir = str(tmp_path / "config")
config.datastore_cache_dir = str(tmp_path / "cache")
config.datastore_data_dir = str(tmp_path / "data")
config.command_start = {""} config.command_start = {""}
config.superusers = {"10001"} config.superusers = {"10001"}
config.log_level = "TRACE" config.log_level = "TRACE"
@ -29,19 +34,15 @@ def dummy_user_subinfo(app: App):
@pytest.fixture @pytest.fixture
def task_watchdog(request): async def db_migration(app: App):
def cancel_test_on_exception(task: asyncio.Task): from nonebot_bison.config.db import upgrade_db
def maybe_cancel_clbk(t: asyncio.Task): from nonebot_bison.config.db_model import Subscribe, Target, User
exception = t.exception() from nonebot_plugin_datastore.db import get_engine
if exception is None:
return
for task in asyncio.all_tasks(): await upgrade_db()
coro = task.get_coro() async with AsyncSession(get_engine()) as sess:
if coro.__qualname__ == request.function.__qualname__: await sess.execute(delete(User))
task.cancel() await sess.execute(delete(Subscribe))
return await sess.execute(delete(Target))
await sess.commit()
task.add_done_callback(maybe_cancel_clbk) await sess.close()
return cancel_test_on_exception