AzideCupric 4e304a43b1
通过 nb-cli 实现数据库一键导入导出 (#210)
* feat: 实现导出存储的订阅信息的功能

* test: 编写导出功能测试

* test: 使用tmp_path

* feat: 实现导入订阅文件功能

* refactor: 将订阅导入导出部分独立出来

* fix: 修复一些拼写错误
test: 完成import的第一个测试

* feat: 将订阅导入导出函数加入nb script

test: 添加cli测试

* test: 完善subs import测试

* 🐛 fix nb cli entrypoint name error

* fix: 修改错误的entry_point, 关闭yaml导出时对键名的排序

* fix: 使用更简短的命令名

* 🚚 将subs_io迁移到config下

* ♻️ 不再使用抛出异常的方式创建目录

* refactor: 将subscribe_export类转为函数

* refactor: 将subscribe_import类转为函数

* refactor: 根据重写的subs_io重新调整cli

* test: 调整重写subs_io后的test

* chore: 清理未使用的import内容

* feat(cli): 将--yaml更改为--format

* test: 调整测试

* fix(cli): 为import添加不支持格式的报错

* improve export performance

* feat: subscribes_import函数不再需要传入参数函数,而是指定为add_subscribes

fix: nbesf_parser在传入str时将调用parse_raw, 否则调用parse_obj

* feat: subscribes_import现在会根据nbesf_data的版本选择合适的导入方式

* fix(test): 调整测试

* feat: nb bison export命令不再将文件导出到data目录,而是当前工作目录

* docs: 增添相关文档

* fix(test): 修复错误的变量名

---------

Co-authored-by: felinae98 <731499577@qq.com>
2023-03-19 16:29:05 +08:00

321 lines
12 KiB
Python

import asyncio
from collections import defaultdict
from datetime import datetime, time
from typing import Awaitable, Callable, Optional, Sequence
from nonebot_plugin_datastore import create_session
from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload
from ..types import Category, PlatformWeightConfigResp, Tag
from ..types import Target as T_Target
from ..types import TimeWeightConfig
from ..types import User as T_User
from ..types import UserSubInfo, WeightConfig
from .db_model import ScheduleTimeWeight, Subscribe, Target, User
from .utils import NoSuchTargetException
def _get_time():
    """Return the current local time of day, truncated to whole seconds.

    The value comes from the naive ``datetime.now()``, so it carries no
    tzinfo, and its microsecond field is always 0.
    """
    now = datetime.now()
    return time(now.hour, now.minute, now.second)
class SubscribeDupException(Exception):
    """Raised by ``DBConfig.add_subscribe`` when committing the new subscription
    fails with a UNIQUE constraint violation, i.e. it already exists."""
class DBConfig:
    """Async data-access helper for subscriptions, targets and schedule weights.

    Every coroutine opens its own session via ``create_session``. Callers may
    register hooks that are awaited as ``hook(platform_name, target)`` when a
    new Target row is created, or when a target loses its last subscription.
    """

    def __init__(self):
        # Awaited by ``add_subscribe`` when it creates a new Target row.
        self.add_target_hook: list[Callable[[str, T_Target], Awaitable]] = []
        # Awaited by ``del_subscribe`` when a target is left with no subscribers.
        self.delete_target_hook: list[Callable[[str, T_Target], Awaitable]] = []

    def register_add_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
        """Register ``fun`` to be awaited whenever a new target is stored."""
        self.add_target_hook.append(fun)

    def register_delete_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
        """Register ``fun`` to be awaited when a target loses its last subscriber."""
        self.delete_target_hook.append(fun)

    async def add_subscribe(
        self,
        user: int,
        user_type: str,
        target: T_Target,
        target_name: str,
        platform_name: str,
        cats: list[Category],
        tags: list[Tag],
    ):
        """Subscribe ``user`` to ``target`` on ``platform_name``.

        Missing User/Target rows are created. An existing Target gets its
        ``target_name`` refreshed. The add-target hooks run only when a new
        Target row is created.

        Raises:
            SubscribeDupException: commit failed on a UNIQUE constraint.
            IntegrityError: any other integrity failure, re-raised unchanged.
        """
        async with create_session() as session:
            db_user: Optional[User] = await session.scalar(
                select(User).where(User.uid == user).where(User.type == user_type)
            )
            if not db_user:
                db_user = User(uid=user, type=user_type)
                session.add(db_user)
            db_target: Optional[Target] = await session.scalar(
                select(Target)
                .where(Target.platform_name == platform_name)
                .where(Target.target == target)
            )
            if not db_target:
                db_target = Target(
                    target=target, platform_name=platform_name, target_name=target_name
                )
                await asyncio.gather(
                    *[hook(platform_name, target) for hook in self.add_target_hook]
                )
            else:
                db_target.target_name = target_name
            session.add(
                Subscribe(
                    categories=cats,
                    tags=tags,
                    user=db_user,
                    target=db_target,
                )
            )
            try:
                await session.commit()
            except IntegrityError as e:
                # The text matched here is SQLite's wording for unique-key errors.
                if e.args and "UNIQUE constraint failed" in str(e.args[0]):
                    raise SubscribeDupException() from e
                raise

    async def list_subscribe(self, user: int, user_type: str) -> Sequence[Subscribe]:
        """Return the subscriptions of one user, with ``target`` eagerly loaded."""
        async with create_session() as session:
            query_stmt = (
                select(Subscribe)
                .where(User.type == user_type, User.uid == user)
                .join(User)
                .options(selectinload(Subscribe.target))
            )
            subs = (await session.scalars(query_stmt)).all()
            return subs

    async def list_subs_with_all_info(self) -> Sequence[Subscribe]:
        """Return every subscription, with its user and target eagerly loaded."""
        async with create_session() as session:
            query_stmt = (
                select(Subscribe)
                .join(User)
                .options(selectinload(Subscribe.target), selectinload(Subscribe.user))
            )
            subs = (await session.scalars(query_stmt)).all()
            return subs

    async def del_subscribe(
        self, user: int, user_type: str, target: str, platform_name: str
    ):
        """Remove ``user``'s subscription to ``target`` on ``platform_name``.

        If the target is left with no subscriptions, the delete-target hooks are
        awaited. Nothing happens when the user or the target is not stored.
        """
        async with create_session() as session:
            user_obj = await session.scalar(
                select(User).where(User.uid == user, User.type == user_type)
            )
            target_obj = await session.scalar(
                select(Target).where(
                    Target.platform_name == platform_name, Target.target == target
                )
            )
            if user_obj is None or target_obj is None:
                # No such subscription can exist. Returning early avoids comparing
                # against NULL and firing the delete hooks spuriously.
                return
            await session.execute(
                delete(Subscribe).where(
                    Subscribe.user == user_obj, Subscribe.target == target_obj
                )
            )
            target_count = await session.scalar(
                select(func.count())
                .select_from(Subscribe)
                .where(Subscribe.target == target_obj)
            )
            if target_count == 0:
                # The target has no subscribers left; notify the hooks.
                await asyncio.gather(
                    *[
                        hook(platform_name, T_Target(target))
                        for hook in self.delete_target_hook
                    ]
                )
            await session.commit()

    async def update_subscribe(
        self,
        user: int,
        user_type: str,
        target: str,
        target_name: str,
        platform_name: str,
        cats: list,
        tags: list,
    ):
        """Overwrite the categories/tags of a subscription and rename its target.

        NOTE(review): if no matching subscription exists, ``subscribe_obj`` is
        None and this raises AttributeError; callers presumably ensure it exists.
        """
        async with create_session() as sess:
            subscribe_obj: Subscribe = await sess.scalar(
                select(Subscribe)
                .where(
                    User.uid == user,
                    User.type == user_type,
                    Target.target == target,
                    Target.platform_name == platform_name,
                )
                .join(User)
                .join(Target)
                .options(selectinload(Subscribe.target))  # type:ignore
            )
            subscribe_obj.tags = tags  # type:ignore
            subscribe_obj.categories = cats  # type:ignore
            subscribe_obj.target.target_name = target_name
            await sess.commit()

    async def get_platform_target(self, platform_name: str) -> Sequence[Target]:
        """Return targets on ``platform_name`` that have at least one subscription."""
        async with create_session() as sess:
            subq = select(Subscribe.target_id).distinct().subquery()
            query = (
                select(Target).join(subq).where(Target.platform_name == platform_name)
            )
            return (await sess.scalars(query)).all()

    async def get_time_weight_config(
        self, target: T_Target, platform_name: str
    ) -> WeightConfig:
        """Return the default weight and time-window weights of one target.

        Raises:
            NoSuchTargetException: the target is not stored. This matches
                ``update_time_weight_config``.
        """
        async with create_session() as sess:
            target_obj = await sess.scalar(
                select(Target).where(
                    Target.platform_name == platform_name, Target.target == target
                )
            )
            if not target_obj:
                raise NoSuchTargetException()
            time_weight_conf = (
                await sess.scalars(
                    select(ScheduleTimeWeight).where(
                        ScheduleTimeWeight.target_id == target_obj.id
                    )
                )
            ).all()
            return WeightConfig(
                default=target_obj.default_schedule_weight,
                time_config=[
                    TimeWeightConfig(
                        start_time=time_conf.start_time,
                        end_time=time_conf.end_time,
                        weight=time_conf.weight,
                    )
                    for time_conf in time_weight_conf
                ],
            )

    async def update_time_weight_config(
        self, target: T_Target, platform_name: str, conf: WeightConfig
    ):
        """Replace a target's default weight and all of its time-window weights.

        Raises:
            NoSuchTargetException: the target is not stored.
        """
        async with create_session() as sess:
            targetObj = await sess.scalar(
                select(Target).where(
                    Target.platform_name == platform_name, Target.target == target
                )
            )
            if not targetObj:
                raise NoSuchTargetException()
            target_id = targetObj.id
            targetObj.default_schedule_weight = conf.default
            # Drop the old windows, then insert the new ones in the same commit.
            delete_statement = delete(ScheduleTimeWeight).where(
                ScheduleTimeWeight.target_id == target_id
            )
            await sess.execute(delete_statement)
            for time_conf in conf.time_config:
                new_conf = ScheduleTimeWeight(
                    start_time=time_conf.start_time,
                    end_time=time_conf.end_time,
                    weight=time_conf.weight,
                    target=targetObj,
                )
                sess.add(new_conf)
            await sess.commit()

    async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, int]:
        """Map ``"{platform}-{target}"`` to the weight in effect right now.

        Covers every target on the given platforms. A window applies when
        ``start_time <= now < end_time``, and the first matching window wins.
        Otherwise the target's ``default_schedule_weight`` is used.
        """
        res = {}
        cur_time = _get_time()
        async with create_session() as sess:
            targets = (
                await sess.scalars(
                    select(Target)
                    .where(Target.platform_name.in_(platform_list))
                    .options(selectinload(Target.time_weight))
                )
            ).all()
            for target in targets:
                key = f"{target.platform_name}-{target.target}"
                weight = target.default_schedule_weight
                for time_conf in target.time_weight:
                    if (
                        time_conf.start_time <= cur_time
                        and time_conf.end_time > cur_time
                    ):
                        weight = time_conf.weight
                        break
                res[key] = weight
        return res

    async def get_platform_target_subscribers(
        self, platform_name: str, target: T_Target
    ) -> list[UserSubInfo]:
        """Return each subscriber of a target with its categories and tags."""
        async with create_session() as sess:
            query = (
                select(Subscribe)
                .join(Target)
                .where(Target.platform_name == platform_name, Target.target == target)
                .options(selectinload(Subscribe.user))
            )
            subscribes = (await sess.scalars(query)).all()
            return [
                UserSubInfo(
                    T_User(subscribe.user.uid, subscribe.user.type),
                    subscribe.categories,
                    subscribe.tags,
                )
                for subscribe in subscribes
            ]

    async def get_all_weight_config(
        self,
    ) -> dict[str, dict[str, PlatformWeightConfigResp]]:
        """Collect the weight configuration of every stored target.

        Returns ``{platform_name: {target: PlatformWeightConfigResp}}``. Every
        Target row is included, each with its time-window weights attached.
        """
        res: dict[str, dict[str, PlatformWeightConfigResp]] = defaultdict(dict)
        async with create_session() as sess:
            targets = (await sess.scalars(select(Target))).all()
            time_weights = (
                await sess.scalars(
                    select(ScheduleTimeWeight).options(
                        selectinload(ScheduleTimeWeight.target)
                    )
                )
            ).all()
            for target in targets:
                platform_name = target.platform_name
                # defaultdict creates the per-platform dict on demand, so every
                # target is registered, not just the first one of its platform.
                res[platform_name][target.target] = PlatformWeightConfigResp(
                    target=T_Target(target.target),
                    target_name=target.target_name,
                    platform_name=platform_name,
                    weight=WeightConfig(
                        default=target.default_schedule_weight, time_config=[]
                    ),
                )
            for time_weight_config in time_weights:
                owner = time_weight_config.target
                res[owner.platform_name][owner.target].weight.time_config.append(
                    TimeWeightConfig(
                        start_time=time_weight_config.start_time,
                        end_time=time_weight_config.end_time,
                        weight=time_weight_config.weight,
                    )
                )
        return res
# Module-level DBConfig instance shared by everything that imports this module.
config: DBConfig = DBConfig()