add scheduler config

This commit is contained in:
felinae98 2022-06-01 23:03:46 +08:00
parent cf2ea83528
commit 68baa3b915
No known key found for this signature in database
GPG Key ID: 00C8B010587FF610
5 changed files with 240 additions and 29 deletions

View File

@ -1,15 +1,36 @@
from dataclasses import dataclass
from datetime import datetime, time
from typing import Optional
from nonebot_bison.types import Category, Tag, Target
from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.sql.expression import delete, select
from sqlalchemy.sql.functions import func
from .db_model import Subscribe as MSubscribe
from .db_model import Target as MTarget
from .db_model import User
from ..types import Category, Tag
from ..types import Target as T_Target
from .db_model import ScheduleTimeWeight, Subscribe, Target, User
def _get_time():
dt = datetime.now()
cur_time = time(hour=dt.hour, minute=dt.minute, second=dt.second)
return cur_time
@dataclass
class TimeWeightConfig:
start_time: time
end_time: time
weight: int
@dataclass
class WeightConfig:
default: int
time_config: list[TimeWeightConfig]
class DBConfig:
@ -17,7 +38,7 @@ class DBConfig:
self,
user: int,
user_type: str,
target: Target,
target: T_Target,
target_name: str,
platform_name: str,
cats: list[Category],
@ -32,18 +53,18 @@ class DBConfig:
db_user = User(uid=user, type=user_type)
session.add(db_user)
db_target_stmt = (
select(MTarget)
.where(MTarget.platform_name == platform_name)
.where(MTarget.target == target)
select(Target)
.where(Target.platform_name == platform_name)
.where(Target.target == target)
)
db_target: Optional[MTarget] = await session.scalar(db_target_stmt)
db_target: Optional[Target] = await session.scalar(db_target_stmt)
if not db_target:
db_target = MTarget(
db_target = Target(
target=target, platform_name=platform_name, target_name=target_name
)
else:
db_target.target_name = target_name # type: ignore
subscribe = MSubscribe(
subscribe = Subscribe(
categories=cats,
tags=tags,
user=db_user,
@ -52,15 +73,15 @@ class DBConfig:
session.add(subscribe)
await session.commit()
async def list_subscribe(self, user: int, user_type: str) -> list[MSubscribe]:
async def list_subscribe(self, user: int, user_type: str) -> list[Subscribe]:
async with AsyncSession(get_engine()) as session:
query_stmt = (
select(MSubscribe)
select(Subscribe)
.where(User.type == user_type, User.uid == user)
.join(User)
.options(selectinload(MSubscribe.target)) # type:ignore
.options(selectinload(Subscribe.target)) # type:ignore
)
subs: list[MSubscribe] = (await session.scalars(query_stmt)).all()
subs: list[Subscribe] = (await session.scalars(query_stmt)).all()
return subs
async def del_subscribe(
@ -71,19 +92,19 @@ class DBConfig:
select(User).where(User.uid == user, User.type == user_type)
)
target_obj = await session.scalar(
select(MTarget).where(
MTarget.platform_name == platform_name, MTarget.target == target
select(Target).where(
Target.platform_name == platform_name, MTarget.target == target
)
)
await session.execute(
delete(MSubscribe).where(
MSubscribe.user == user_obj, MSubscribe.target == target_obj
delete(Subscribe).where(
Subscribe.user == user_obj, MSubscribe.target == target_obj
)
)
target_count = await session.scalar(
select(func.count())
.select_from(MSubscribe)
.where(MSubscribe.target == target_obj)
.select_from(Subscribe)
.where(Subscribe.target == target_obj)
)
if target_count == 0:
# delete empty target
@ -101,22 +122,93 @@ class DBConfig:
tags: list,
):
async with AsyncSession(get_engine()) as sess:
subscribe_obj: MSubscribe = await sess.scalar(
select(MSubscribe)
subscribe_obj: Subscribe = await sess.scalar(
select(Subscribe)
.where(
User.uid == user,
User.type == user_type,
MTarget.target == target,
MTarget.platform_name == platform_name,
Target.target == target,
Target.platform_name == platform_name,
)
.join(User)
.join(MTarget)
.options(selectinload(MSubscribe.target)) # type:ignore
.join(Target)
.options(selectinload(Subscribe.target)) # type:ignore
)
subscribe_obj.tags = tags # type:ignore
subscribe_obj.categories = cats # type:ignore
subscribe_obj.target.target_name = target_name
await sess.commit()
async def get_time_weight_config(
self, target: T_Target, platform_name: str
) -> WeightConfig:
async with AsyncSession(get_engine()) as sess:
time_weight_conf: list[ScheduleTimeWeight] = await sess.scalars(
select(ScheduleTimeWeight)
.where(Target.platform_name == platform_name, Target.target == target)
.join(Target)
)
targetObj: Target = await sess.scalar(
select(Target).where(
Target.platform_name == platform_name, Target.target == target
)
)
return WeightConfig(
default=targetObj.default_schedule_weight,
time_config=[
TimeWeightConfig(
start_time=time_conf.start_time,
end_time=time_conf.end_time,
weight=time_conf.weight,
)
for time_conf in time_weight_conf
],
)
async def update_time_weight_config(
self, target: T_Target, platform_name: str, conf: WeightConfig
):
async with AsyncSession(get_engine()) as sess:
targetObj: Target = await sess.scalar(
select(Target).where(
Target.platform_name == platform_name, Target.target == target
)
)
target_id = targetObj.id
targetObj.default_schedule_weight = conf.default
delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id)
for time_conf in conf.time_config:
new_conf = ScheduleTimeWeight(
start_time=time_conf.start_time,
end_time=time_conf.end_time,
weight=time_conf.weight,
target=targetObj,
)
sess.add(new_conf)
await sess.commit()
async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, int]:
res = {}
cur_time = _get_time()
async with AsyncSession(get_engine()) as sess:
targets: list[Target] = await sess.scalars(
select(Target)
.where(Target.platform_name.in_(platform_list))
.options(selectinload(Target.time_weight))
)
for target in targets:
key = f"{target.platform_name}-{target.target}"
weight = target.default_schedule_weight
for time_conf in target.time_weight:
if (
time_conf.start_time <= cur_time
and time_conf.end_time > cur_time
):
weight = time_conf.weight
break
res[key] = weight
return res
config = DBConfig()

View File

@ -1 +0,0 @@
from .scheduler_config import SchedulerConfig

View File

@ -10,8 +10,9 @@ from nonebot.plugin import require
from ..plugin_config import plugin_config
from .http import http_client
from .scheduler_config import SchedulerConfig
__all__ = ["http_client", "Singleton", "parse_text", "html_to_text"]
__all__ = ["http_client", "Singleton", "parse_text", "html_to_text", "SchedulerConfig"]
class Singleton(type):

View File

@ -0,0 +1,119 @@
from datetime import time
from nonebug import App
async def test_create_config(app: App, db_migration):
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as T_Target
from nonebot_plugin_datastore.db import get_engine
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
cats=[],
tags=[],
)
await config.update_time_weight_config(
target=T_Target("weibo_id"),
platform_name="weibo",
conf=WeightConfig(
default=10,
time_config=[
TimeWeightConfig(start_time=time(1, 0), end_time=time(2, 0), weight=20)
],
),
)
test_config = await config.get_time_weight_config(
target=T_Target("weibo_id"), platform_name="weibo"
)
assert test_config.default == 10
assert test_config.time_config == [
TimeWeightConfig(start_time=time(1, 0), end_time=time(2, 0), weight=20)
]
test_config1 = await config.get_time_weight_config(
target=T_Target("weibo_id1"), platform_name="weibo"
)
assert test_config1.default == 10
assert test_config1.time_config == []
async def test_get_current_weight(app: App, db_migration):
from datetime import time
from nonebot_bison.config import db_config
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as T_Target
from nonebot_plugin_datastore.db import get_engine
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
cats=[],
tags=[],
)
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id1"),
target_name="weibo_name2",
platform_name="weibo2",
cats=[],
tags=[],
)
await config.update_time_weight_config(
target=T_Target("weibo_id"),
platform_name="weibo",
conf=WeightConfig(
default=10,
time_config=[
TimeWeightConfig(start_time=time(1, 0), end_time=time(2, 0), weight=20),
TimeWeightConfig(start_time=time(4, 0), end_time=time(5, 0), weight=30),
],
),
)
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(1, 30))
weight = await config.get_current_weight_val(["weibo", "weibo2"])
assert len(weight) == 3
assert weight["weibo-weibo_id"] == 20
assert weight["weibo-weibo_id1"] == 10
assert weight["weibo2-weibo_id1"] == 10
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(4, 0))
weight = await config.get_current_weight_val(["weibo", "weibo2"])
assert len(weight) == 3
assert weight["weibo-weibo_id"] == 30
assert weight["weibo-weibo_id1"] == 10
assert weight["weibo2-weibo_id1"] == 10
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(5, 0))
weight = await config.get_current_weight_val(["weibo", "weibo2"])
assert len(weight) == 3
assert weight["weibo-weibo_id"] == 10
assert weight["weibo-weibo_id1"] == 10
assert weight["weibo2-weibo_id1"] == 10