388 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
from collections import defaultdict
from datetime import time, datetime
from collections.abc import Callable, Sequence, Awaitable
from nonebot.compat import model_dump
from sqlalchemy.orm import selectinload
from sqlalchemy.exc import IntegrityError
from sqlalchemy import func, delete, select
from nonebot_plugin_saa import PlatformTarget
from nonebot_plugin_datastore import create_session
from ..types import Tag
from ..types import Target as T_Target
from .utils import NoSuchTargetException, DuplicateCookieTargetException
from .db_model import User, Cookie, Target, Subscribe, CookieTarget, ScheduleTimeWeight
from ..types import Category, UserSubInfo, WeightConfig, TimeWeightConfig, PlatformWeightConfigResp
def _get_time():
dt = datetime.now()
cur_time = time(hour=dt.hour, minute=dt.minute, second=dt.second)
return cur_time
class SubscribeDupException(Exception):
    """Raised by ``DBConfig.add_subscribe`` when the new subscription violates a UNIQUE constraint."""
class DBConfig:
    """Async data-access layer for users, targets, subscriptions, schedule weights and cookies.

    Every operation opens its own session through ``create_session()``.
    Callbacks may be registered that are awaited when a target row is first
    created and when a target loses its last subscription.
    """

    def __init__(self):
        # Callbacks are awaited as ``hook(platform_name, target)``.
        self.add_target_hook: list[Callable[[str, T_Target], Awaitable]] = []
        self.delete_target_hook: list[Callable[[str, T_Target], Awaitable]] = []

    def register_add_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
        """Register ``fun`` to be awaited whenever ``add_subscribe`` creates a new target."""
        self.add_target_hook.append(fun)

    def register_delete_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
        """Register ``fun`` to be awaited whenever ``del_subscribe`` leaves a target with no subscriptions."""
        self.delete_target_hook.append(fun)

    async def add_subscribe(
        self,
        user: PlatformTarget,
        target: T_Target,
        target_name: str,
        platform_name: str,
        cats: list[Category],
        tags: list[Tag],
    ):
        """Subscribe ``user`` to ``target`` on ``platform_name``.

        The user row and the target row are created on demand; an existing
        target has its ``target_name`` refreshed.  Add-target hooks are awaited
        when the target row is new (before the commit).

        Raises:
            SubscribeDupException: the commit failed with a
                "UNIQUE constraint failed" integrity error.
            IntegrityError: any other integrity error, re-raised unchanged.
        """
        async with create_session() as session:
            db_user_stmt = select(User).where(User.user_target == model_dump(user))
            db_user: User | None = await session.scalar(db_user_stmt)
            if not db_user:
                db_user = User(user_target=model_dump(user))
                session.add(db_user)
            db_target_stmt = select(Target).where(Target.platform_name == platform_name).where(Target.target == target)
            db_target: Target | None = await session.scalar(db_target_stmt)
            if not db_target:
                db_target = Target(target=target, platform_name=platform_name, target_name=target_name)
                await asyncio.gather(*[hook(platform_name, target) for hook in self.add_target_hook])
            else:
                db_target.target_name = target_name
            subscribe = Subscribe(
                categories=cats,
                tags=tags,
                user=db_user,
                target=db_target,
            )
            session.add(subscribe)
            try:
                await session.commit()
            except IntegrityError as e:
                # str() guards the substring test against a non-string first argument.
                if e.args and "UNIQUE constraint failed" in str(e.args[0]):
                    # Keep the database error attached as the cause for debugging.
                    raise SubscribeDupException() from e
                raise

    async def list_subscribe(self, user: PlatformTarget) -> Sequence[Subscribe]:
        """Return the subscriptions of ``user`` with ``Subscribe.target`` eagerly loaded."""
        async with create_session() as session:
            query_stmt = (
                select(Subscribe)
                .where(User.user_target == model_dump(user))
                .join(User)
                .options(selectinload(Subscribe.target))
            )
            subs = (await session.scalars(query_stmt)).all()
            return subs

    async def list_subs_with_all_info(self) -> Sequence[Subscribe]:
        """Return every subscription with its ``user`` and ``target`` eagerly loaded."""
        async with create_session() as session:
            query_stmt = (
                select(Subscribe).join(User).options(selectinload(Subscribe.target), selectinload(Subscribe.user))
            )
            subs = (await session.scalars(query_stmt)).all()
            return subs

    async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str):
        """Delete the subscription of ``user`` to ``target`` on ``platform_name``.

        If no subscription references the target afterwards, the delete-target
        hooks are awaited.  The target row itself is not deleted here.
        """
        async with create_session() as session:
            user_obj = await session.scalar(select(User).where(User.user_target == model_dump(user)))
            target_obj = await session.scalar(
                select(Target).where(Target.platform_name == platform_name, Target.target == target)
            )
            await session.execute(delete(Subscribe).where(Subscribe.user == user_obj, Subscribe.target == target_obj))
            target_count = await session.scalar(
                select(func.count()).select_from(Subscribe).where(Subscribe.target == target_obj)
            )
            if target_count == 0:
                # The target has no subscriptions left: notify the hooks (before the commit).
                await asyncio.gather(*[hook(platform_name, T_Target(target)) for hook in self.delete_target_hook])
            await session.commit()

    async def update_subscribe(
        self,
        user: PlatformTarget,
        target: str,
        target_name: str,
        platform_name: str,
        cats: list,
        tags: list,
    ):
        """Replace the tags and categories of an existing subscription and refresh the target name."""
        async with create_session() as sess:
            # NOTE(review): if no subscription matches, the scalar is None and the
            # attribute assignments below raise AttributeError - confirm whether
            # callers expect a dedicated exception instead.
            subscribe_obj: Subscribe = await sess.scalar(
                select(Subscribe)
                .where(
                    User.user_target == model_dump(user),
                    Target.target == target,
                    Target.platform_name == platform_name,
                )
                .join(User)
                .join(Target)
                .options(selectinload(Subscribe.target))  # type:ignore
            )
            subscribe_obj.tags = tags  # type:ignore
            subscribe_obj.categories = cats  # type:ignore
            subscribe_obj.target.target_name = target_name
            await sess.commit()

    async def get_platform_target(self, platform_name: str) -> Sequence[Target]:
        """Return the targets of ``platform_name`` that are referenced by at least one subscription."""
        async with create_session() as sess:
            # Joining against the distinct subscribed target ids drops targets nobody subscribes to.
            subq = select(Subscribe.target_id).distinct().subquery()
            query = select(Target).join(subq).where(Target.platform_name == platform_name)
            return (await sess.scalars(query)).all()

    async def get_time_weight_config(self, target: T_Target, platform_name: str) -> WeightConfig:
        """Return the default weight and the time-window weights of a target.

        Raises:
            NoSuchTargetException: no target matches ``platform_name``/``target``.
        """
        async with create_session() as sess:
            target_obj = await sess.scalar(
                select(Target).where(Target.platform_name == platform_name, Target.target == target)
            )
            # Explicit check instead of ``assert`` so it survives ``python -O`` and
            # matches the error raised by ``update_time_weight_config``.
            if not target_obj:
                raise NoSuchTargetException()
            time_weight_conf = (
                await sess.scalars(
                    select(ScheduleTimeWeight)
                    .where(Target.platform_name == platform_name, Target.target == target)
                    .join(Target)
                )
            ).all()
            return WeightConfig(
                default=target_obj.default_schedule_weight,
                time_config=[
                    TimeWeightConfig(
                        start_time=time_conf.start_time,
                        end_time=time_conf.end_time,
                        weight=time_conf.weight,
                    )
                    for time_conf in time_weight_conf
                ],
            )

    async def update_time_weight_config(self, target: T_Target, platform_name: str, conf: WeightConfig):
        """Overwrite the default weight and replace all time-window weights of a target.

        Raises:
            NoSuchTargetException: no target matches ``platform_name``/``target``.
        """
        async with create_session() as sess:
            target_obj = await sess.scalar(
                select(Target).where(Target.platform_name == platform_name, Target.target == target)
            )
            if not target_obj:
                raise NoSuchTargetException()
            target_id = target_obj.id
            target_obj.default_schedule_weight = conf.default
            # Drop every existing window, then insert the new set.
            delete_statement = delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id)
            await sess.execute(delete_statement)
            for time_conf in conf.time_config:
                new_conf = ScheduleTimeWeight(
                    start_time=time_conf.start_time,
                    end_time=time_conf.end_time,
                    weight=time_conf.weight,
                    target=target_obj,
                )
                sess.add(new_conf)
            await sess.commit()

    async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, int]:
        """Return the weight in effect right now for every target of the given platforms.

        Keys have the form ``"{platform_name}-{target}"``.  The weight of the
        first window with ``start_time <= now < end_time`` is used; otherwise
        the target's default weight applies.
        """
        res = {}
        cur_time = _get_time()
        async with create_session() as sess:
            targets = (
                await sess.scalars(
                    select(Target)
                    .where(Target.platform_name.in_(platform_list))
                    .options(selectinload(Target.time_weight))
                )
            ).all()
            for target in targets:
                key = f"{target.platform_name}-{target.target}"
                weight = target.default_schedule_weight
                for time_conf in target.time_weight:
                    if time_conf.start_time <= cur_time and time_conf.end_time > cur_time:
                        weight = time_conf.weight
                        break
                res[key] = weight
        return res

    async def get_platform_target_subscribers(self, platform_name: str, target: T_Target) -> list[UserSubInfo]:
        """Return one ``UserSubInfo`` (user, categories, tags) per subscription to the given target."""
        async with create_session() as sess:
            query = (
                select(Subscribe)
                .join(Target)
                .where(Target.platform_name == platform_name, Target.target == target)
                .options(selectinload(Subscribe.user))
            )
            subscribes = (await sess.scalars(query)).all()
            return [
                UserSubInfo(
                    PlatformTarget.deserialize(subscribe.user.user_target),
                    subscribe.categories,
                    subscribe.tags,
                )
                for subscribe in subscribes
            ]

    async def get_all_weight_config(
        self,
    ) -> dict[str, dict[str, PlatformWeightConfigResp]]:
        """Return the weight configuration of every target as ``{platform_name: {target: config}}``."""
        res: dict[str, dict[str, PlatformWeightConfigResp]] = defaultdict(dict)
        async with create_session() as sess:
            query = select(Target)
            targets = (await sess.scalars(query)).all()
            query = select(ScheduleTimeWeight).options(selectinload(ScheduleTimeWeight.target))
            time_weights = (await sess.scalars(query)).all()
            # Register every target.  Guarding on "platform not yet present" kept only
            # the first target of each platform and made the loop below raise KeyError
            # for the others.
            for target in targets:
                platform_name = target.platform_name
                res[platform_name][target.target] = PlatformWeightConfigResp(
                    target=T_Target(target.target),
                    target_name=target.target_name,
                    platform_name=platform_name,
                    weight=WeightConfig(default=target.default_schedule_weight, time_config=[]),
                )
            # Attach each time window to the entry of the target it belongs to.
            for time_weight_config in time_weights:
                platform_name = time_weight_config.target.platform_name
                target_key = time_weight_config.target.target
                res[platform_name][target_key].weight.time_config.append(
                    TimeWeightConfig(
                        start_time=time_weight_config.start_time,
                        end_time=time_weight_config.end_time,
                        weight=time_weight_config.weight,
                    )
                )
        return res

    async def get_cookie(
        self,
        site_name: str | None = None,
        target: T_Target | None = None,
        is_universal: bool | None = None,
        is_anonymous: bool | None = None,
    ) -> Sequence[Cookie]:
        """Return all cookies matching the given filters.

        Filters left as ``None`` (or an empty ``site_name``/``target``) are not
        applied.  When ``target`` is given, a cookie is kept if it is linked to
        a target with that value or if it is universal.
        """
        async with create_session() as sess:
            query = select(Cookie).distinct()
            if is_universal is not None:
                query = query.where(Cookie.is_universal == is_universal)
            if is_anonymous is not None:
                query = query.where(Cookie.is_anonymous == is_anonymous)
            if site_name:
                query = query.where(Cookie.site_name == site_name)
            query = query.outerjoin(CookieTarget).options(selectinload(Cookie.targets))
            res = (await sess.scalars(query)).all()
            if target:
                # A target was given: drop the cookies that do not satisfy it.
                id_query = select(CookieTarget.cookie_id).join(Target).where(Target.target == target)
                ids = set((await sess.scalars(id_query)).all())
                # Universal cookies that passed the filters above are kept as well.
                res = [cookie for cookie in res if cookie.id in ids or cookie.is_universal]
            return res

    async def get_cookie_by_id(self, cookie_id: int) -> Cookie:
        """Return the cookie with ``cookie_id``.

        Raises:
            NoSuchTargetException: no cookie has that id.
        """
        async with create_session() as sess:
            cookie = await sess.scalar(select(Cookie).where(Cookie.id == cookie_id))
            if not cookie:
                raise NoSuchTargetException(f"cookie {cookie_id} not found")
            return cookie

    async def add_cookie(self, cookie: Cookie) -> int:
        """Persist ``cookie`` and return its id."""
        async with create_session() as sess:
            sess.add(cookie)
            await sess.commit()
            # Reload the row after the commit before reading the id.
            await sess.refresh(cookie)
            return cookie.id

    async def update_cookie(self, cookie: Cookie):
        """Copy content, name, last usage, status and tags of ``cookie`` onto the stored row with the same id.

        Raises:
            ValueError: no cookie with ``cookie.id`` exists.
        """
        async with create_session() as sess:
            cookie_in_db: Cookie | None = await sess.scalar(select(Cookie).where(Cookie.id == cookie.id))
            if not cookie_in_db:
                raise ValueError(f"cookie {cookie.id} not found")
            cookie_in_db.content = cookie.content
            cookie_in_db.cookie_name = cookie.cookie_name
            cookie_in_db.last_usage = cookie.last_usage
            cookie_in_db.status = cookie.status
            cookie_in_db.tags = cookie.tags
            await sess.commit()

    async def delete_cookie_by_id(self, cookie_id: int):
        """Delete the cookie with ``cookie_id`` if it is not linked to any target.

        Raises:
            NoSuchTargetException: no cookie has that id.
            Exception: the cookie still has linked targets.
        """
        async with create_session() as sess:
            cookie = await sess.scalar(
                select(Cookie)
                .where(Cookie.id == cookie_id)
                .outerjoin(CookieTarget)
                .options(selectinload(Cookie.targets))
            )
            if not cookie:
                raise NoSuchTargetException(f"cookie {cookie_id} not found")
            if len(cookie.targets) > 0:
                raise Exception(f"cookie {cookie.id} in use")
            await sess.execute(delete(Cookie).where(Cookie.id == cookie_id))
            await sess.commit()

    async def add_cookie_target(self, target: T_Target, platform_name: str, cookie_id: int):
        """Link a cookie to a target.

        ``cookie_id`` uniquely identifies a Cookie; ``target`` together with
        ``platform_name`` uniquely identifies a Target.

        Raises:
            DuplicateCookieTargetException: the link already exists.
        """
        async with create_session() as sess:
            target_obj = await sess.scalar(
                select(Target).where(Target.platform_name == platform_name, Target.target == target)
            )
            # check if relation exists
            cookie_target = await sess.scalar(
                select(CookieTarget).where(CookieTarget.target == target_obj, CookieTarget.cookie_id == cookie_id)
            )
            if cookie_target:
                raise DuplicateCookieTargetException()
            cookie_obj = await sess.scalar(select(Cookie).where(Cookie.id == cookie_id))
            cookie_target = CookieTarget(target=target_obj, cookie=cookie_obj)
            sess.add(cookie_target)
            await sess.commit()

    async def delete_cookie_target(self, target: T_Target, platform_name: str, cookie_id: int):
        """Remove the link between the cookie ``cookie_id`` and the given target."""
        async with create_session() as sess:
            target_obj = await sess.scalar(
                select(Target).where(Target.platform_name == platform_name, Target.target == target)
            )
            cookie_obj = await sess.scalar(select(Cookie).where(Cookie.id == cookie_id))
            await sess.execute(
                delete(CookieTarget).where(CookieTarget.target == target_obj, CookieTarget.cookie == cookie_obj)
            )
            await sess.commit()

    async def delete_cookie_target_by_id(self, cookie_target_id: int):
        """Remove the cookie-target link whose own id is ``cookie_target_id``."""
        async with create_session() as sess:
            await sess.execute(delete(CookieTarget).where(CookieTarget.id == cookie_target_id))
            await sess.commit()

    async def get_cookie_target(self) -> list[CookieTarget]:
        """Return every cookie-target link with ``target`` and ``cookie`` loaded.

        The list is sorted by (target platform name, cookie id, target id).
        """
        async with create_session() as sess:
            query = (
                select(CookieTarget)
                .outerjoin(Target)
                .options(selectinload(CookieTarget.target))
                .outerjoin(Cookie)
                .options(selectinload(CookieTarget.cookie))
            )
            res = list((await sess.scalars(query)).all())
            res.sort(key=lambda x: (x.target.platform_name, x.cookie_id, x.target_id))
            return res

    async def clear_db(self):
        """Empty every table; used to clean up the environment in unit tests."""
        async with create_session() as sess:
            # Tables carrying ``*_id`` references go first so the order stays
            # valid if the database enforces foreign keys.
            await sess.execute(delete(Subscribe))
            await sess.execute(delete(ScheduleTimeWeight))
            await sess.execute(delete(CookieTarget))
            await sess.execute(delete(User))
            await sess.execute(delete(Target))
            await sess.execute(delete(Cookie))
            await sess.commit()
# Module-level DBConfig instance, created when this module is imported.
config = DBConfig()