mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-02 09:26:12 +08:00
115 lines
4.2 KiB
Python
115 lines
4.2 KiB
Python
from typing import cast
|
||
from collections import defaultdict
|
||
from collections.abc import Callable
|
||
|
||
from sqlalchemy import select
|
||
from nonebot.log import logger
|
||
from sqlalchemy.sql.selectable import Select
|
||
from nonebot_plugin_saa import PlatformTarget
|
||
from nonebot.compat import type_validate_python
|
||
from nonebot_plugin_datastore.db import create_session
|
||
from sqlalchemy.orm.strategy_options import selectinload
|
||
|
||
from .. import config
|
||
from .utils import NBESFVerMatchErr, row2dict
|
||
from .nbesf_model import NBESFBase, v1, v2, v3
|
||
from ..db_model import User, Cookie, Target, Subscribe, CookieTarget
|
||
|
||
|
||
async def subscribes_export(selector: Callable[[Select], Select]) -> v3.SubGroup:
    """Export Bison subscriptions as an NBESF v3 ``SubGroup``.

    NBESF stands for "Nonebot Bison Exchangable Subscribes File".

    Args:
        selector: A function that takes a sqlalchemy ``Select`` and returns a
            (possibly narrowed) ``Select``. It is applied to both the
            subscription query and the user query to limit what is exported,
            e.g. ``lambda stmt: stmt.where(User.uid == 2233, User.type == "group")``.

    Returns:
        A ``v3.SubGroup`` containing one ``SubPack`` per selected user and the
        cookies gathered from the cookie/target links plus every cookie
        returned by ``config.get_cookie(is_anonymous=False)`` that has no link.
    """
    async with create_session() as sess:
        # Subscriptions matched by the selector, with their targets eagerly loaded.
        sub_stmt = select(Subscribe).join(User)
        sub_stmt = selector(sub_stmt).options(selectinload(Subscribe.target))
        sub_stmt = cast(Select[tuple[Subscribe]], sub_stmt)
        sub_data = await sess.scalars(sub_stmt)

        # Distinct users that own at least one matched subscription.
        user_stmt = select(User).join(Subscribe)
        user_stmt = selector(user_stmt).distinct()
        user_stmt = cast(Select[tuple[User]], user_stmt)
        user_data = await sess.scalars(user_stmt)

        groups: list[v3.SubPack] = []
        user_id_sub_dict: dict[int, list[v3.SubPayload]] = defaultdict(list)

        # Bucket the subscription payloads by the id of the owning user.
        for sub in sub_data:
            sub_payload = type_validate_python(v3.SubPayload, sub)
            user_id_sub_dict[sub.user_id].append(sub_payload)

        for user in user_data:
            assert isinstance(user, User)
            groups.append(
                v3.SubPack(
                    user_target=PlatformTarget.deserialize(user.user_target),
                    subs=user_id_sub_dict[user.id],
                )
            )

    async with create_session() as sess:
        cookie_target_stmt = (
            select(CookieTarget)
            .join(Cookie)
            .join(Target)
            .options(selectinload(CookieTarget.target))
            .options(selectinload(CookieTarget.cookie))
        )
        cookie_target_data = await sess.scalars(cookie_target_stmt)

        # Consume the result while the session is still open: group the
        # target payloads under the cookie they are linked to.
        cookie_target_dict: dict[Cookie, list[v3.Target]] = defaultdict(list)
        for cookie_target in cookie_target_data:
            target_payload = type_validate_python(v3.Target, cookie_target.target)
            cookie_target_dict[cookie_target.cookie].append(target_payload)

    def cookie_transform(cookie: Cookie, targets: list[v3.Target]) -> v3.Cookie:
        """Build a ``v3.Cookie`` from a ``Cookie`` row and its linked targets."""
        cookie_dict = row2dict(cookie)
        # "tags" is taken from the model attribute, overriding whatever
        # row2dict produced for that key.
        cookie_dict["tags"] = cookie.tags
        cookie_dict["targets"] = targets
        return v3.Cookie(**cookie_dict)

    cookies: list[v3.Cookie] = []
    linked_cookie_ids = set()
    for cookie, targets in cookie_target_dict.items():
        assert isinstance(cookie, Cookie)
        cookies.append(cookie_transform(cookie, targets))
        linked_cookie_ids.add(cookie.id)

    # Also export cookies that are not linked to any target.
    all_cookies = await config.get_cookie(is_anonymous=False)
    cookies.extend(cookie_transform(c, []) for c in all_cookies if c.id not in linked_cookie_ids)

    return v3.SubGroup(groups=groups, cookies=cookies)
|
||
|
||
|
||
async def subscribes_import(
    nbesf_data: NBESFBase,
):
    """Import subscriptions from NBESF-formatted data.

    NBESF stands for "Nonebot Bison Exchangable Subscribes File". The
    importer matching ``nbesf_data.version`` is awaited; for version 3 the
    cookie importer is awaited as well.

    Args:
        nbesf_data: ``SubGroup`` data conforming to one of the ``nbesf_model``
            versions (v1, v2 or v3).

    Raises:
        NBESFVerMatchErr: If ``nbesf_data.version`` is not 1, 2 or 3.
        AssertionError: If ``nbesf_data`` is not an instance of the
            ``SubGroup`` class belonging to its declared version.
    """
    logger.info("开始添加订阅流程")

    # Read the version once, as the original ``match`` subject did.
    version = nbesf_data.version
    if version == 1:
        assert isinstance(nbesf_data, v1.SubGroup)
        await v1.subs_receipt_gen(nbesf_data)
    elif version == 2:
        assert isinstance(nbesf_data, v2.SubGroup)
        await v2.subs_receipt_gen(nbesf_data)
    elif version == 3:
        assert isinstance(nbesf_data, v3.SubGroup)
        await v3.subs_receipt_gen(nbesf_data)
        await v3.magic_cookie_gen(nbesf_data)
    else:
        raise NBESFVerMatchErr(f"不支持的NBESF版本:{nbesf_data.version}")

    logger.info("订阅流程结束,请检查所有订阅记录是否全部添加成功")
|