add scheduler config

This commit is contained in:
felinae98 2022-06-01 23:03:46 +08:00
parent cf2ea83528
commit 68baa3b915
No known key found for this signature in database
GPG Key ID: 00C8B010587FF610
5 changed files with 240 additions and 29 deletions

View File

@ -1,15 +1,36 @@
from dataclasses import dataclass
from datetime import datetime, time
from typing import Optional from typing import Optional
from nonebot_bison.types import Category, Tag, Target
from nonebot_plugin_datastore.db import get_engine from nonebot_plugin_datastore.db import get_engine
from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from sqlalchemy.sql.expression import delete, select from sqlalchemy.sql.expression import delete, select
from sqlalchemy.sql.functions import func from sqlalchemy.sql.functions import func
from .db_model import Subscribe as MSubscribe from ..types import Category, Tag
from .db_model import Target as MTarget from ..types import Target as T_Target
from .db_model import User from .db_model import ScheduleTimeWeight, Subscribe, Target, User
def _get_time():
dt = datetime.now()
cur_time = time(hour=dt.hour, minute=dt.minute, second=dt.second)
return cur_time
@dataclass
class TimeWeightConfig:
start_time: time
end_time: time
weight: int
@dataclass
class WeightConfig:
default: int
time_config: list[TimeWeightConfig]
class DBConfig: class DBConfig:
@ -17,7 +38,7 @@ class DBConfig:
self, self,
user: int, user: int,
user_type: str, user_type: str,
target: Target, target: T_Target,
target_name: str, target_name: str,
platform_name: str, platform_name: str,
cats: list[Category], cats: list[Category],
@ -32,18 +53,18 @@ class DBConfig:
db_user = User(uid=user, type=user_type) db_user = User(uid=user, type=user_type)
session.add(db_user) session.add(db_user)
db_target_stmt = ( db_target_stmt = (
select(MTarget) select(Target)
.where(MTarget.platform_name == platform_name) .where(Target.platform_name == platform_name)
.where(MTarget.target == target) .where(Target.target == target)
) )
db_target: Optional[MTarget] = await session.scalar(db_target_stmt) db_target: Optional[Target] = await session.scalar(db_target_stmt)
if not db_target: if not db_target:
db_target = MTarget( db_target = Target(
target=target, platform_name=platform_name, target_name=target_name target=target, platform_name=platform_name, target_name=target_name
) )
else: else:
db_target.target_name = target_name # type: ignore db_target.target_name = target_name # type: ignore
subscribe = MSubscribe( subscribe = Subscribe(
categories=cats, categories=cats,
tags=tags, tags=tags,
user=db_user, user=db_user,
@ -52,15 +73,15 @@ class DBConfig:
session.add(subscribe) session.add(subscribe)
await session.commit() await session.commit()
async def list_subscribe(self, user: int, user_type: str) -> list[MSubscribe]: async def list_subscribe(self, user: int, user_type: str) -> list[Subscribe]:
async with AsyncSession(get_engine()) as session: async with AsyncSession(get_engine()) as session:
query_stmt = ( query_stmt = (
select(MSubscribe) select(Subscribe)
.where(User.type == user_type, User.uid == user) .where(User.type == user_type, User.uid == user)
.join(User) .join(User)
.options(selectinload(MSubscribe.target)) # type:ignore .options(selectinload(Subscribe.target)) # type:ignore
) )
subs: list[MSubscribe] = (await session.scalars(query_stmt)).all() subs: list[Subscribe] = (await session.scalars(query_stmt)).all()
return subs return subs
async def del_subscribe( async def del_subscribe(
@ -71,19 +92,19 @@ class DBConfig:
select(User).where(User.uid == user, User.type == user_type) select(User).where(User.uid == user, User.type == user_type)
) )
target_obj = await session.scalar( target_obj = await session.scalar(
select(MTarget).where( select(Target).where(
MTarget.platform_name == platform_name, MTarget.target == target Target.platform_name == platform_name, MTarget.target == target
) )
) )
await session.execute( await session.execute(
delete(MSubscribe).where( delete(Subscribe).where(
MSubscribe.user == user_obj, MSubscribe.target == target_obj Subscribe.user == user_obj, MSubscribe.target == target_obj
) )
) )
target_count = await session.scalar( target_count = await session.scalar(
select(func.count()) select(func.count())
.select_from(MSubscribe) .select_from(Subscribe)
.where(MSubscribe.target == target_obj) .where(Subscribe.target == target_obj)
) )
if target_count == 0: if target_count == 0:
# delete empty target # delete empty target
@ -101,22 +122,93 @@ class DBConfig:
tags: list, tags: list,
): ):
async with AsyncSession(get_engine()) as sess: async with AsyncSession(get_engine()) as sess:
subscribe_obj: MSubscribe = await sess.scalar( subscribe_obj: Subscribe = await sess.scalar(
select(MSubscribe) select(Subscribe)
.where( .where(
User.uid == user, User.uid == user,
User.type == user_type, User.type == user_type,
MTarget.target == target, Target.target == target,
MTarget.platform_name == platform_name, Target.platform_name == platform_name,
) )
.join(User) .join(User)
.join(MTarget) .join(Target)
.options(selectinload(MSubscribe.target)) # type:ignore .options(selectinload(Subscribe.target)) # type:ignore
) )
subscribe_obj.tags = tags # type:ignore subscribe_obj.tags = tags # type:ignore
subscribe_obj.categories = cats # type:ignore subscribe_obj.categories = cats # type:ignore
subscribe_obj.target.target_name = target_name subscribe_obj.target.target_name = target_name
await sess.commit() await sess.commit()
async def get_time_weight_config(
self, target: T_Target, platform_name: str
) -> WeightConfig:
async with AsyncSession(get_engine()) as sess:
time_weight_conf: list[ScheduleTimeWeight] = await sess.scalars(
select(ScheduleTimeWeight)
.where(Target.platform_name == platform_name, Target.target == target)
.join(Target)
)
targetObj: Target = await sess.scalar(
select(Target).where(
Target.platform_name == platform_name, Target.target == target
)
)
return WeightConfig(
default=targetObj.default_schedule_weight,
time_config=[
TimeWeightConfig(
start_time=time_conf.start_time,
end_time=time_conf.end_time,
weight=time_conf.weight,
)
for time_conf in time_weight_conf
],
)
async def update_time_weight_config(
self, target: T_Target, platform_name: str, conf: WeightConfig
):
async with AsyncSession(get_engine()) as sess:
targetObj: Target = await sess.scalar(
select(Target).where(
Target.platform_name == platform_name, Target.target == target
)
)
target_id = targetObj.id
targetObj.default_schedule_weight = conf.default
delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id)
for time_conf in conf.time_config:
new_conf = ScheduleTimeWeight(
start_time=time_conf.start_time,
end_time=time_conf.end_time,
weight=time_conf.weight,
target=targetObj,
)
sess.add(new_conf)
await sess.commit()
async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, int]:
res = {}
cur_time = _get_time()
async with AsyncSession(get_engine()) as sess:
targets: list[Target] = await sess.scalars(
select(Target)
.where(Target.platform_name.in_(platform_list))
.options(selectinload(Target.time_weight))
)
for target in targets:
key = f"{target.platform_name}-{target.target}"
weight = target.default_schedule_weight
for time_conf in target.time_weight:
if (
time_conf.start_time <= cur_time
and time_conf.end_time > cur_time
):
weight = time_conf.weight
break
res[key] = weight
return res
config = DBConfig() config = DBConfig()

View File

@ -1 +0,0 @@
from .scheduler_config import SchedulerConfig

View File

@ -10,8 +10,9 @@ from nonebot.plugin import require
from ..plugin_config import plugin_config from ..plugin_config import plugin_config
from .http import http_client from .http import http_client
from .scheduler_config import SchedulerConfig
__all__ = ["http_client", "Singleton", "parse_text", "html_to_text"] __all__ = ["http_client", "Singleton", "parse_text", "html_to_text", "SchedulerConfig"]
class Singleton(type): class Singleton(type):

View File

@ -0,0 +1,119 @@
from datetime import time
from nonebug import App
async def test_create_config(app: App, db_migration):
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as T_Target
from nonebot_plugin_datastore.db import get_engine
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
cats=[],
tags=[],
)
await config.update_time_weight_config(
target=T_Target("weibo_id"),
platform_name="weibo",
conf=WeightConfig(
default=10,
time_config=[
TimeWeightConfig(start_time=time(1, 0), end_time=time(2, 0), weight=20)
],
),
)
test_config = await config.get_time_weight_config(
target=T_Target("weibo_id"), platform_name="weibo"
)
assert test_config.default == 10
assert test_config.time_config == [
TimeWeightConfig(start_time=time(1, 0), end_time=time(2, 0), weight=20)
]
test_config1 = await config.get_time_weight_config(
target=T_Target("weibo_id1"), platform_name="weibo"
)
assert test_config1.default == 10
assert test_config1.time_config == []
async def test_get_current_weight(app: App, db_migration):
from datetime import time
from nonebot_bison.config import db_config
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
from nonebot_bison.config.db_model import Subscribe, Target, User
from nonebot_bison.types import Target as T_Target
from nonebot_plugin_datastore.db import get_engine
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id"),
target_name="weibo_name",
platform_name="weibo",
cats=[],
tags=[],
)
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id1"),
target_name="weibo_name1",
platform_name="weibo",
cats=[],
tags=[],
)
await config.add_subscribe(
user=123,
user_type="group",
target=T_Target("weibo_id1"),
target_name="weibo_name2",
platform_name="weibo2",
cats=[],
tags=[],
)
await config.update_time_weight_config(
target=T_Target("weibo_id"),
platform_name="weibo",
conf=WeightConfig(
default=10,
time_config=[
TimeWeightConfig(start_time=time(1, 0), end_time=time(2, 0), weight=20),
TimeWeightConfig(start_time=time(4, 0), end_time=time(5, 0), weight=30),
],
),
)
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(1, 30))
weight = await config.get_current_weight_val(["weibo", "weibo2"])
assert len(weight) == 3
assert weight["weibo-weibo_id"] == 20
assert weight["weibo-weibo_id1"] == 10
assert weight["weibo2-weibo_id1"] == 10
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(4, 0))
weight = await config.get_current_weight_val(["weibo", "weibo2"])
assert len(weight) == 3
assert weight["weibo-weibo_id"] == 30
assert weight["weibo-weibo_id1"] == 10
assert weight["weibo2-weibo_id1"] == 10
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(5, 0))
weight = await config.get_current_weight_val(["weibo", "weibo2"])
assert len(weight) == 3
assert weight["weibo-weibo_id"] == 10
assert weight["weibo-weibo_id1"] == 10
assert weight["weibo2-weibo_id1"] == 10