This commit is contained in:
felinae98
2022-03-29 22:43:39 +08:00
parent df23648b0f
commit cf35432757
15 changed files with 439 additions and 132 deletions
+82 -3
View File
@@ -1,17 +1,96 @@
from nonebug.app import App
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.functions import func
from sqlmodel.sql.expression import select
async def test_add_subscrib(app: App):
async def test_add_subscribe(app: App, db_migration):
from nonebot_bison.config.db_config import config
from nonebot_bison.types import Target
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as TTarget
from nonebot_plugin_datastore.db import get_engine
await config.add_subscribe(
user=123,
user_type="group",
target=Target("weibo_id"),
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
confs = await config.list_subscribe(123, "group")
assert len(confs) == 1
conf: Subscribe = confs[0]
async with AsyncSession(get_engine()) as sess:
related_user_obj = await sess.scalar(
select(User).where(User.id == conf.user_id)
)
related_target_obj = await sess.scalar(
select(Target).where(Target.id == conf.target_id)
)
assert related_user_obj.uid == 123
assert related_target_obj.target_name == "weibo_name"
assert related_target_obj.target == "weibo_id"
assert conf.target.target == "weibo_id"
assert conf.categories == []
async def test_del_subsribe(db_migration):
from nonebot_bison.config.db_config import config
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as TTarget
from nonebot_plugin_datastore.db import get_engine
await config.add_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
await config.del_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
platform_name="weibo",
)
async with AsyncSession(get_engine()) as sess:
assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 0
assert (await sess.scalar(select(func.count()).select_from(Target))) == 0
await config.add_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
await config.add_subscribe(
user=124,
user_type="group",
target=TTarget("weibo_id"),
target_name="weibo_name_new",
platform_name="weibo",
cats=[],
tags=[],
)
await config.del_subscribe(
user=123,
user_type="group",
target=TTarget("weibo_id"),
platform_name="weibo",
)
async with AsyncSession(get_engine()) as sess:
assert (await sess.scalar(select(func.count()).select_from(Subscribe))) == 1
assert (await sess.scalar(select(func.count()).select_from(Target))) == 1
target: Target = await sess.scalar(select(Target))
assert target.target_name == "weibo_name_new"
+17 -16
View File
@@ -5,6 +5,8 @@ from pathlib import Path
import nonebot
import pytest
from nonebug.app import App
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.sql.expression import delete
@pytest.fixture
@@ -12,7 +14,10 @@ async def app(nonebug_init: None, tmp_path: Path, monkeypatch: pytest.MonkeyPatc
import nonebot
config = nonebot.get_driver().config
config.bison_config_path = str(tmp_path)
config.bison_config_path = str(tmp_path / "legacy_config")
config.datastore_config_dir = str(tmp_path / "config")
config.datastore_cache_dir = str(tmp_path / "cache")
config.datastore_data_dir = str(tmp_path / "data")
config.command_start = {""}
config.superusers = {"10001"}
config.log_level = "TRACE"
@@ -29,19 +34,15 @@ def dummy_user_subinfo(app: App):
@pytest.fixture
def task_watchdog(request):
def cancel_test_on_exception(task: asyncio.Task):
def maybe_cancel_clbk(t: asyncio.Task):
exception = t.exception()
if exception is None:
return
async def db_migration(app: App):
from nonebot_bison.config.db import upgrade_db
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_plugin_datastore.db import get_engine
for task in asyncio.all_tasks():
coro = task.get_coro()
if coro.__qualname__ == request.function.__qualname__:
task.cancel()
return
task.add_done_callback(maybe_cancel_clbk)
return cancel_test_on_exception
await upgrade_db()
async with AsyncSession(get_engine()) as sess:
await sess.execute(delete(User))
await sess.execute(delete(Subscribe))
await sess.execute(delete(Target))
await sess.commit()
await sess.close()