add dedup for data migration

This commit is contained in:
felinae98 2022-10-09 21:25:11 +08:00
parent 8db0ed3fe1
commit f6e392e8db
No known key found for this signature in database
GPG Key ID: 00C8B010587FF610
2 changed files with 39 additions and 4 deletions

View File

@ -28,11 +28,20 @@ async def data_migrate():
for user in all_subs:
db_user = User(uid=user["user"], type=user["user_type"])
user_to_create.append(db_user)
user_sub_set = set()
for sub in user["subs"]:
target = sub["target"]
platform_name = sub["target_type"]
target_name = sub["target_name"]
key = f"{target}-{platform_name}"
if key in user_sub_set:
# a user subscribe a target twice
logger.error(
f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次,"
"随机采用了一个订阅"
)
continue
user_sub_set.add(key)
if key in platform_target_map.keys():
target_obj, ext_user_type, ext_user = platform_target_map[key]
if target_obj.target_name != target_name:

View File

@ -1,9 +1,6 @@
import pytest
async def test_migration(use_legacy_config):
from nonebot_bison.config.config_legacy import config as config_legacy
from nonebot_bison.config.db import data_migrate, upgrade_db
from nonebot_bison.config.db import upgrade_db
from nonebot_bison.config.db_config import config
config_legacy.add_subscribe(
@ -54,3 +51,32 @@ async def test_migration(use_legacy_config):
assert user234_config[0].target.target == "weibo_id"
assert user234_config[0].target.target_name == "weibo_name"
assert user234_config[0].tags == []
async def test_migrate_dup(use_legacy_config):
from nonebot_bison.config.config_legacy import config as config_legacy
from nonebot_bison.config.db import upgrade_db
from nonebot_bison.config.db_config import config
config_legacy.add_subscribe(
user=123,
user_type="group",
target="weibo_id",
target_name="weibo_name",
target_type="weibo",
cats=[2, 3],
tags=[],
)
config_legacy.add_subscribe(
user=123,
user_type="group",
target="weibo_id",
target_name="weibo_name",
target_type="weibo",
cats=[2, 3],
tags=[],
)
# await data_migrate()
await upgrade_db()
user123_config = await config.list_subscribe(123, "group")
assert len(user123_config) == 1