# 2023-10-15 23:18:33 +08:00
#
# 263 lines
# 11 KiB
# Python

import asyncio
from collections import defaultdict
from datetime import time, datetime
from collections.abc import Callable, Sequence, Awaitable
from sqlalchemy.orm import selectinload
from sqlalchemy.exc import IntegrityError
from sqlalchemy import func, delete, select
from nonebot_plugin_saa import PlatformTarget
from nonebot_plugin_datastore import create_session
from ..types import Tag
from ..types import Target as T_Target
from .utils import NoSuchTargetException
from .db_model import User, Target, Subscribe, ScheduleTimeWeight
from ..types import Category, UserSubInfo, WeightConfig, TimeWeightConfig, PlatformWeightConfigResp
def _get_time():
dt = datetime.now()
cur_time = time(hour=dt.hour, minute=dt.minute, second=dt.second)
return cur_time
class SubscribeDupException(Exception):
    """Raised by ``DBConfig.add_subscribe`` when saving the subscription hits a
    UNIQUE constraint, i.e. the subscription already exists."""
class DBConfig:
    """Database access layer for subscriptions, targets and schedule weights.

    Besides the queries, it holds two lists of async hooks, each called as
    ``hook(platform_name, target)``:

    * ``add_target_hook`` runs when ``add_subscribe`` has to create a new
      ``Target`` row;
    * ``delete_target_hook`` runs when ``del_subscribe`` leaves a target with
      no remaining subscriptions.
    """

    def __init__(self):
        self.add_target_hook: list[Callable[[str, T_Target], Awaitable]] = []
        self.delete_target_hook: list[Callable[[str, T_Target], Awaitable]] = []

    def register_add_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
        """Register an async callback awaited when ``add_subscribe`` creates a new target."""
        self.add_target_hook.append(fun)

    def register_delete_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
        """Register an async callback awaited when a target loses its last subscription."""
        self.delete_target_hook.append(fun)

    async def add_subscribe(
        self,
        user: PlatformTarget,
        target: T_Target,
        target_name: str,
        platform_name: str,
        cats: list[Category],
        tags: list[Tag],
    ):
        """Subscribe ``user`` to ``target`` on ``platform_name``.

        Missing ``User`` and ``Target`` rows are created on the fly. When the
        target row is new, all add-target hooks are awaited concurrently before
        the commit; when it already exists, its ``target_name`` is refreshed.

        Raises:
            SubscribeDupException: the commit failed with a UNIQUE constraint
                violation (the subscription already exists).
            IntegrityError: any other integrity error raised by the commit.
        """
        async with create_session() as session:
            db_user_stmt = select(User).where(User.user_target == user.dict())
            db_user: User | None = await session.scalar(db_user_stmt)
            if not db_user:
                db_user = User(user_target=user.dict())
                session.add(db_user)
            db_target_stmt = select(Target).where(Target.platform_name == platform_name).where(Target.target == target)
            db_target: Target | None = await session.scalar(db_target_stmt)
            if not db_target:
                db_target = Target(target=target, platform_name=platform_name, target_name=target_name)
                # Hooks run before the commit, with the new target not yet persisted.
                await asyncio.gather(*[hook(platform_name, target) for hook in self.add_target_hook])
            else:
                db_target.target_name = target_name
            subscribe = Subscribe(
                categories=cats,
                tags=tags,
                user=db_user,
                target=db_target,
            )
            session.add(subscribe)
            try:
                await session.commit()
            except IntegrityError as e:
                # NOTE: "UNIQUE constraint failed" is SQLite's wording; other
                # backends may phrase duplicates differently and would surface
                # as a plain IntegrityError here.
                if len(e.args) > 0 and "UNIQUE constraint failed" in e.args[0]:
                    raise SubscribeDupException() from e
                raise

    async def list_subscribe(self, user: PlatformTarget) -> Sequence[Subscribe]:
        """Return all subscriptions of ``user``, with ``Subscribe.target`` eagerly loaded."""
        async with create_session() as session:
            query_stmt = (
                select(Subscribe)
                .where(User.user_target == user.dict())
                .join(User)
                .options(selectinload(Subscribe.target))
            )
            subs = (await session.scalars(query_stmt)).all()
            return subs

    async def list_subs_with_all_info(self) -> Sequence[Subscribe]:
        """Return every subscription in the database with its user and target eagerly loaded."""
        async with create_session() as session:
            query_stmt = (
                select(Subscribe).join(User).options(selectinload(Subscribe.target), selectinload(Subscribe.user))
            )
            subs = (await session.scalars(query_stmt)).all()
            return subs

    async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str):
        """Delete ``user``'s subscription to ``target`` on ``platform_name``.

        If no subscription references the target afterwards, the delete-target
        hooks are awaited concurrently. The ``Target`` row itself is not
        deleted here.
        """
        async with create_session() as session:
            user_obj = await session.scalar(select(User).where(User.user_target == user.dict()))
            target_obj = await session.scalar(
                select(Target).where(Target.platform_name == platform_name, Target.target == target)
            )
            await session.execute(delete(Subscribe).where(Subscribe.user == user_obj, Subscribe.target == target_obj))
            target_count = await session.scalar(
                select(func.count()).select_from(Subscribe).where(Subscribe.target == target_obj)
            )
            if target_count == 0:
                # Target has no subscribers left: notify the delete-target hooks.
                await asyncio.gather(*[hook(platform_name, T_Target(target)) for hook in self.delete_target_hook])
            await session.commit()

    async def update_subscribe(
        self,
        user: PlatformTarget,
        target: str,
        target_name: str,
        platform_name: str,
        cats: list,
        tags: list,
    ):
        """Overwrite tags and categories of an existing subscription and rename its target."""
        async with create_session() as sess:
            subscribe_obj: Subscribe = await sess.scalar(
                select(Subscribe)
                .where(
                    User.user_target == user.dict(),
                    Target.target == target,
                    Target.platform_name == platform_name,
                )
                .join(User)
                .join(Target)
                .options(selectinload(Subscribe.target))  # type:ignore
            )
            # NOTE(review): if no such subscription exists, scalar() returns None
            # and the assignments below raise AttributeError.
            subscribe_obj.tags = tags  # type:ignore
            subscribe_obj.categories = cats  # type:ignore
            subscribe_obj.target.target_name = target_name
            await sess.commit()

    async def get_platform_target(self, platform_name: str) -> Sequence[Target]:
        """Return the targets of ``platform_name`` that have at least one subscription."""
        async with create_session() as sess:
            subq = select(Subscribe.target_id).distinct().subquery()
            query = select(Target).join(subq).where(Target.platform_name == platform_name)
            return (await sess.scalars(query)).all()

    async def get_time_weight_config(self, target: T_Target, platform_name: str) -> WeightConfig:
        """Return the default weight and all time-window weights of one target.

        Raises:
            NoSuchTargetException: no target matches ``platform_name``/``target``.
        """
        async with create_session() as sess:
            target_obj = await sess.scalar(
                select(Target).where(Target.platform_name == platform_name, Target.target == target)
            )
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not target_obj:
                raise NoSuchTargetException()
            time_weight_conf = (
                await sess.scalars(
                    select(ScheduleTimeWeight)
                    .where(Target.platform_name == platform_name, Target.target == target)
                    .join(Target)
                )
            ).all()
            return WeightConfig(
                default=target_obj.default_schedule_weight,
                time_config=[
                    TimeWeightConfig(
                        start_time=time_conf.start_time,
                        end_time=time_conf.end_time,
                        weight=time_conf.weight,
                    )
                    for time_conf in time_weight_conf
                ],
            )

    async def update_time_weight_config(self, target: T_Target, platform_name: str, conf: WeightConfig):
        """Replace a target's default weight and all its time-window weights with ``conf``.

        Raises:
            NoSuchTargetException: no target matches ``platform_name``/``target``.
        """
        async with create_session() as sess:
            target_obj = await sess.scalar(
                select(Target).where(Target.platform_name == platform_name, Target.target == target)
            )
            if not target_obj:
                raise NoSuchTargetException()
            target_obj.default_schedule_weight = conf.default
            # Drop every existing window, then insert the new ones.
            await sess.execute(delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_obj.id))
            for time_conf in conf.time_config:
                sess.add(
                    ScheduleTimeWeight(
                        start_time=time_conf.start_time,
                        end_time=time_conf.end_time,
                        weight=time_conf.weight,
                        target=target_obj,
                    )
                )
            await sess.commit()

    async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, int]:
        """Return the weight in effect right now for every target of the given platforms.

        Keys are ``"{platform_name}-{target}"``. The first loaded time window
        with ``start_time <= now < end_time`` wins; otherwise the target's
        default weight is used.
        """
        res = {}
        cur_time = _get_time()
        async with create_session() as sess:
            targets = (
                await sess.scalars(
                    select(Target)
                    .where(Target.platform_name.in_(platform_list))
                    .options(selectinload(Target.time_weight))
                )
            ).all()
            for target in targets:
                key = f"{target.platform_name}-{target.target}"
                weight = target.default_schedule_weight
                for time_conf in target.time_weight:
                    if time_conf.start_time <= cur_time < time_conf.end_time:
                        weight = time_conf.weight
                        break
                res[key] = weight
            return res

    async def get_platform_target_subscribers(self, platform_name: str, target: T_Target) -> list[UserSubInfo]:
        """Return each subscriber of ``target`` with the categories and tags they subscribed to."""
        async with create_session() as sess:
            query = (
                select(Subscribe)
                .join(Target)
                .where(Target.platform_name == platform_name, Target.target == target)
                .options(selectinload(Subscribe.user))
            )
            subscribes = (await sess.scalars(query)).all()
            return [
                UserSubInfo(
                    PlatformTarget.deserialize(subscribe.user.user_target),
                    subscribe.categories,
                    subscribe.tags,
                )
                for subscribe in subscribes
            ]

    async def get_all_weight_config(
        self,
    ) -> dict[str, dict[str, PlatformWeightConfigResp]]:
        """Return weight configs of all targets, grouped as ``{platform_name: {target: resp}}``.

        Every target gets an entry, and its time-window weights are appended
        to ``resp.weight.time_config``.
        """
        res: dict[str, dict[str, PlatformWeightConfigResp]] = defaultdict(dict)
        async with create_session() as sess:
            targets = (await sess.scalars(select(Target))).all()
            time_weights = (
                await sess.scalars(select(ScheduleTimeWeight).options(selectinload(ScheduleTimeWeight.target)))
            ).all()
            # Register every target. The previous `platform_name not in res` guard
            # kept only the first target per platform and made the loop below
            # raise KeyError for the others.
            for target in targets:
                platform_name = target.platform_name
                res[platform_name][target.target] = PlatformWeightConfigResp(
                    target=T_Target(target.target),
                    target_name=target.target_name,
                    platform_name=platform_name,
                    weight=WeightConfig(default=target.default_schedule_weight, time_config=[]),
                )
            for time_weight_config in time_weights:
                owner = time_weight_config.target
                res[owner.platform_name][owner.target].weight.time_config.append(
                    TimeWeightConfig(
                        start_time=time_weight_config.start_time,
                        end_time=time_weight_config.end_time,
                        weight=time_weight_config.weight,
                    )
                )
        return res
# Module-level DBConfig instance shared by everything that imports this module.
config: DBConfig = DBConfig()