diff --git a/src/plugins/nonebot_bison/config/db_config.py b/src/plugins/nonebot_bison/config/db_config.py index f7b15e0..8615bcf 100644 --- a/src/plugins/nonebot_bison/config/db_config.py +++ b/src/plugins/nonebot_bison/config/db_config.py @@ -56,10 +56,10 @@ class DBConfig: async with AsyncSession(get_engine()) as session: query_stmt = ( select(MSubscribe) - .where(User.type == user_type and User.uid == user) + .where(User.type == user_type, User.uid == user) .join(User) - .options(selectinload(MSubscribe.target)) - ) # type:ignore + .options(selectinload(MSubscribe.target)) # type:ignore + ) subs: list[MSubscribe] = (await session.scalars(query_stmt)).all() return subs @@ -68,16 +68,16 @@ class DBConfig: ): async with AsyncSession(get_engine()) as session: user_obj = await session.scalar( - select(User).where(User.uid == user and User.type == user_type) + select(User).where(User.uid == user, User.type == user_type) ) target_obj = await session.scalar( select(MTarget).where( - MTarget.platform_name == platform_name and MTarget.target == target + MTarget.platform_name == platform_name, MTarget.target == target ) ) await session.execute( delete(MSubscribe).where( - MSubscribe.user == user_obj and MSubscribe.target == target_obj + MSubscribe.user == user_obj, MSubscribe.target == target_obj ) ) target_count = await session.scalar( @@ -104,16 +104,18 @@ class DBConfig: subscribe_obj: MSubscribe = await sess.scalar( select(MSubscribe) .where( - User.uid == user - and User.type == user_type - and MTarget.target == target - and MTarget.platform_name == platform_name + User.uid == user, + User.type == user_type, + MTarget.target == target, + MTarget.platform_name == platform_name, ) .join(User) .join(MTarget) + .options(selectinload(MSubscribe.target)) # type:ignore ) subscribe_obj.tags = tags # type:ignore subscribe_obj.categories = cats # type:ignore + subscribe_obj.target.target_name = target_name await sess.commit() diff --git a/tests/config/test_config_operation.py b/tests/config/test_config_operation.py index f019654..333fc10 100644 --- a/tests/config/test_config_operation.py +++ b/tests/config/test_config_operation.py @@ -20,6 +20,15 @@ async def test_add_subscribe(app: App, db_migration): cats=[], tags=[], ) + await config.add_subscribe( + user=234, + user_type="group", + target=TTarget("weibo_id"), + target_name="weibo_name", + platform_name="weibo", + cats=[], + tags=[], + ) confs = await config.list_subscribe(123, "group") assert len(confs) == 1 conf: Subscribe = confs[0] @@ -36,6 +45,32 @@ async def test_add_subscribe(app: App, db_migration): assert conf.target.target == "weibo_id" assert conf.categories == [] + await config.update_subscribe( + user=123, + user_type="group", + target=TTarget("weibo_id"), + platform_name="weibo", + target_name="weibo_name2", + cats=[1], + tags=["tag"], + ) + confs = await config.list_subscribe(123, "group") + assert len(confs) == 1 + conf: Subscribe = confs[0] + async with AsyncSession(get_engine()) as sess: + related_user_obj = await sess.scalar( + select(User).where(User.id == conf.user_id) + ) + related_target_obj = await sess.scalar( + select(Target).where(Target.id == conf.target_id) + ) + assert related_user_obj.uid == 123 + assert related_target_obj.target_name == "weibo_name2" + assert related_target_obj.target == "weibo_id" + assert conf.target.target == "weibo_id" + assert conf.categories == [1] + assert conf.tags == ["tag"] + async def test_del_subsribe(db_migration): from nonebot_bison.config.db_config import config diff --git a/tests/config/test_data_migration.py b/tests/config/test_data_migration.py index e53fd4a..051cb9a 100644 --- a/tests/config/test_data_migration.py +++ b/tests/config/test_data_migration.py @@ -4,3 +4,52 @@ import pytest async def test_migration(use_legacy_config, db_migration): from nonebot_bison.config.config_legacy import config as config_legacy from nonebot_bison.config.db import data_migrate + from nonebot_bison.config.db_config import config + + config_legacy.add_subscribe( + user=123, + user_type="group", + target="weibo_id", + target_name="weibo_name", + target_type="weibo", + cats=[2, 3], + tags=[], + ) + config_legacy.add_subscribe( + user=123, + user_type="group", + target="weibo_id2", + target_name="weibo_name2", + target_type="weibo", + cats=[1, 2], + tags=["tag"], + ) + config_legacy.add_subscribe( + user=234, + user_type="group", + target="weibo_id", + target_name="weibo_name", + target_type="weibo", + cats=[1], + tags=[], + ) + await data_migrate() + user123_config = await config.list_subscribe(123, "group") + assert len(user123_config) == 2 + for c in user123_config: + if c.target.target == "weibo_id": + assert c.categories == [2, 3] + assert c.target.target_name == "weibo_name" + assert c.target.platform_name == "weibo" + assert c.tags == [] + elif c.target.target == "weibo_id2": + assert c.categories == [1, 2] + assert c.target.target_name == "weibo_name2" + assert c.target.platform_name == "weibo" + assert c.tags == ["tag"] + user234_config = await config.list_subscribe(234, "group") + assert len(user234_config) == 1 + assert user234_config[0].categories == [1] + assert user234_config[0].target.target == "weibo_id" + assert user234_config[0].target.target_name == "weibo_name" + assert user234_config[0].tags == []