mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-05 19:36:43 +08:00
88 lines
2.8 KiB
Python
88 lines
2.8 KiB
Python
from collections import defaultdict
|
||
from typing import Any, Callable, cast
|
||
|
||
from nonebot.log import logger
|
||
from nonebot_plugin_datastore.db import create_session
|
||
from sqlalchemy import select
|
||
from sqlalchemy.orm.strategy_options import selectinload
|
||
from sqlalchemy.sql.selectable import Select
|
||
|
||
from ..db_model import Subscribe, User
|
||
from .nbesf_model import (
|
||
NBESFParseErr,
|
||
NBESFVerMatchErr,
|
||
SubGroup,
|
||
SubPack,
|
||
SubPayload,
|
||
UserHead,
|
||
)
|
||
from .utils import subs_receipt_gen_ver_1
|
||
|
||
|
||
async def subscribes_export(selector: Callable[[Select], Select]) -> SubGroup:
    """Export Bison subscriptions as NBESF (Nonebot Bison Exchangable Subscribes File) ``SubGroup`` data.

    Args:
        selector: A function that receives an sqlalchemy ``Select`` and returns a
            narrowed ``Select``. It is applied to both the subscription query and
            the user query to restrict the export scope, e.g.
            ``lambda stmt: stmt.where(User.uid == 2233, User.type == "group")``.

    Returns:
        A ``SubGroup`` containing one ``SubPack`` per user returned by the user
        query, each pairing that user's ``UserHead`` with the ``SubPayload``
        entries of its subscriptions returned by the subscription query.
    """
    async with create_session() as sess:
        # Subscriptions in scope; ``target`` is eager-loaded via selectinload so
        # it is already populated when the payload is built.
        sub_stmt = select(Subscribe).join(User)
        sub_stmt = selector(sub_stmt).options(selectinload(Subscribe.target))
        sub_stmt = cast(Select[tuple[Subscribe]], sub_stmt)
        sub_data = (await sess.scalars(sub_stmt)).all()

        # Users joined to at least one subscription; DISTINCT collapses the
        # duplicate user rows produced by the join.
        user_stmt = select(User).join(Subscribe)
        user_stmt = selector(user_stmt).distinct()
        user_stmt = cast(Select[tuple[User]], user_stmt)
        user_data = (await sess.scalars(user_stmt)).all()

        # Convert ORM rows to payload models while the session is still open,
        # so no attribute access happens on rows after the session is closed.
        user_id_sub_dict: dict[int, list[SubPayload]] = defaultdict(list)
        for sub in sub_data:
            user_id_sub_dict[sub.user_id].append(SubPayload.from_orm(sub))

        groups: list[SubPack] = [
            SubPack(user=UserHead.from_orm(user), subs=user_id_sub_dict[user.id])
            for user in user_data
        ]

    return SubGroup(groups=groups)
|
||
|
||
|
||
async def subscribes_import(
    nbesf_data: SubGroup,
):
    """Import subscriptions from NBESF (Nonebot Bison Exchangable Subscribes File) data.

    Args:
        nbesf_data: ``SubGroup`` data conforming to the nbesf_model standard.

    Raises:
        NBESFVerMatchErr: if ``nbesf_data.version`` is not a supported version
            (only version 1 is handled here).
    """
    logger.info("开始添加订阅流程")

    # Guard clause: reject any NBESF version this importer does not know.
    if nbesf_data.version != 1:
        raise NBESFVerMatchErr(f"不支持的NBESF版本:{nbesf_data.version}")

    await subs_receipt_gen_ver_1(nbesf_data)

    logger.info("订阅流程结束,请检查所有订阅记录是否全部添加成功")
|
||
|
||
|
||
def nbesf_parser(raw_data: Any) -> SubGroup:
    """Parse raw input into an NBESF ``SubGroup``.

    Args:
        raw_data: Either a JSON document given as ``str`` or ``bytes`` (parsed
            with ``SubGroup.parse_raw``), or an already-decoded object such as a
            ``dict`` (validated with ``SubGroup.parse_obj``).

    Returns:
        The parsed ``SubGroup``.

    Raises:
        NBESFParseErr: if ``raw_data`` cannot be parsed or validated as a
            ``SubGroup``; the underlying exception is chained as ``__cause__``.
    """
    try:
        # JSON text may arrive as str or as bytes (e.g. content read in binary
        # mode); both are handled by the raw-JSON parser. Previously bytes fell
        # through to parse_obj and always failed.
        if isinstance(raw_data, (str, bytes)):
            nbesf_data = SubGroup.parse_raw(raw_data)
        else:
            nbesf_data = SubGroup.parse_obj(raw_data)
    except Exception as e:
        # Boundary conversion: any parse/validation failure becomes the domain error.
        logger.error("数据解析失败,该数据格式可能不满足NBESF格式标准!")
        raise NBESFParseErr("数据解析失败") from e
    return nbesf_data
|