AzideCupric 4e304a43b1
通过 nb-cli 实现数据库一键导入导出 (#210)
* feat: 实现导出存储的订阅信息的功能

* test: 编写导出功能测试

* test: 使用tmp_path

* feat: 实现导入订阅文件功能

* refactor: 将订阅导入导出部分独立出来

* fix: 修复一些拼写错误
test: 完成import的第一个测试

* feat: 将订阅导入导出函数加入nb script

test: 添加cli测试

* test: 完善subs import测试

* 🐛 fix nb cli entrypoint name error

* fix: 修改错误的entry_point, 关闭yaml导出时对键名的排序

* fix: 使用更简短的命令名

* 🚚 将subs_io迁移到config下

* ♻️ 不再使用抛出异常的方式创建目录

* refactor: 将subscribe_export类转为函数

* refactor: 将subscribe_import类转为函数

* refactor: 根据重写的subs_io重新调整cli

* test: 调整重写subs_io后的test

* chore: 清理未使用的import内容

* feat(cli): 将--yaml更改为--format

* test: 调整测试

* fix(cli): 为import添加不支持格式的报错

*  improve export performace

* feat: subscribes_import函数不再需要传入参数函数,而是指定为add_subscribes

fix: nbesf_parser在传入str时将调用parse_raw, 否则调用parse_obj

* feat: subscribes_import现在会根据nbesf_data的版本选择合适的导入方式

* fix(test): 调整测试

* feat: nb bison export命令不再将文件导出到data目录,而是当前工作目录

* docs: 增添相关文档

* fix(test): 修复错误的变量名

---------

Co-authored-by: felinae98 <731499577@qq.com>
2023-03-19 16:29:05 +08:00

89 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from collections import defaultdict
from functools import partial
from typing import Any, Callable, TypeVar
from nonebot.log import logger
from nonebot_plugin_datastore.db import create_session
from sqlalchemy import select
from sqlalchemy.orm.strategy_options import selectinload
from sqlalchemy.sql.selectable import Select
from ..db_model import Subscribe, User
from .nbesf_model import (
NBESFParseErr,
NBESFVerMatchErr,
SubGroup,
SubPack,
SubPayload,
UserHead,
)
from .utils import subs_receipt_gen_ver_1
T = TypeVar("T", bound=Select)
async def subscribes_export(selector: Callable[[T], T]) -> SubGroup:
"""
将Bison订阅导出为 Nonebot Bison Exchangable Subscribes File 标准格式的 SubGroup 类型数据
selector:
对 sqlalchemy Select 对象的操作函数,用于限定查询范围 e.g. lambda stmt: stmt.where(User.uid=2233, User.type="group")
"""
async with create_session() as sess:
sub_stmt = select(Subscribe).join(User)
sub_stmt = selector(sub_stmt).options(selectinload(Subscribe.target))
sub_data = await sess.scalars(sub_stmt)
user_stmt = select(User).join(Subscribe)
user_stmt = selector(user_stmt).distinct()
user_data = await sess.scalars(user_stmt)
groups: list[SubPack] = []
user_id_sub_dict: dict[int, list[SubPayload]] = defaultdict(list)
for sub in sub_data:
sub_paylaod = SubPayload.from_orm(sub)
user_id_sub_dict[sub.user_id].append(sub_paylaod)
for user in user_data:
user_head = UserHead.from_orm(user)
sub_pack = SubPack(user=user_head, subs=user_id_sub_dict[user.id])
groups.append(sub_pack)
sub_group = SubGroup(groups=groups)
return sub_group
async def subscribes_import(
nbesf_data: SubGroup,
):
"""
从 Nonebot Bison Exchangable Subscribes File 标准格式的数据中导入订阅
nbesf_data:
符合nbesf_model标准的 SubGroup 类型数据
"""
logger.info("开始添加订阅流程")
match nbesf_data.version:
case 1:
await subs_receipt_gen_ver_1(nbesf_data)
case _:
raise NBESFVerMatchErr(f"不支持的NBESF版本{nbesf_data.version}")
logger.info("订阅流程结束,请检查所有订阅记录是否全部添加成功")
def nbesf_parser(raw_data: Any) -> SubGroup:
try:
if isinstance(raw_data, str):
nbesf_data = SubGroup.parse_raw(raw_data)
else:
nbesf_data = SubGroup.parse_obj(raw_data)
except Exception as e:
logger.error("数据解析失败该数据格式可能不满足NBESF格式标准")
raise NBESFParseErr("数据解析失败") from e
else:
return nbesf_data