mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-05 19:36:43 +08:00
add scheduler config
This commit is contained in:
parent
cf2ea83528
commit
68baa3b915
@ -1,15 +1,36 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, time
|
||||
from typing import Optional
|
||||
|
||||
from nonebot_bison.types import Category, Tag, Target
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.sql.expression import delete, select
|
||||
from sqlalchemy.sql.functions import func
|
||||
|
||||
from .db_model import Subscribe as MSubscribe
|
||||
from .db_model import Target as MTarget
|
||||
from .db_model import User
|
||||
from ..types import Category, Tag
|
||||
from ..types import Target as T_Target
|
||||
from .db_model import ScheduleTimeWeight, Subscribe, Target, User
|
||||
|
||||
|
||||
def _get_time():
|
||||
dt = datetime.now()
|
||||
cur_time = time(hour=dt.hour, minute=dt.minute, second=dt.second)
|
||||
return cur_time
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimeWeightConfig:
|
||||
start_time: time
|
||||
end_time: time
|
||||
weight: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeightConfig:
|
||||
|
||||
default: int
|
||||
time_config: list[TimeWeightConfig]
|
||||
|
||||
|
||||
class DBConfig:
|
||||
@ -17,7 +38,7 @@ class DBConfig:
|
||||
self,
|
||||
user: int,
|
||||
user_type: str,
|
||||
target: Target,
|
||||
target: T_Target,
|
||||
target_name: str,
|
||||
platform_name: str,
|
||||
cats: list[Category],
|
||||
@ -32,18 +53,18 @@ class DBConfig:
|
||||
db_user = User(uid=user, type=user_type)
|
||||
session.add(db_user)
|
||||
db_target_stmt = (
|
||||
select(MTarget)
|
||||
.where(MTarget.platform_name == platform_name)
|
||||
.where(MTarget.target == target)
|
||||
select(Target)
|
||||
.where(Target.platform_name == platform_name)
|
||||
.where(Target.target == target)
|
||||
)
|
||||
db_target: Optional[MTarget] = await session.scalar(db_target_stmt)
|
||||
db_target: Optional[Target] = await session.scalar(db_target_stmt)
|
||||
if not db_target:
|
||||
db_target = MTarget(
|
||||
db_target = Target(
|
||||
target=target, platform_name=platform_name, target_name=target_name
|
||||
)
|
||||
else:
|
||||
db_target.target_name = target_name # type: ignore
|
||||
subscribe = MSubscribe(
|
||||
subscribe = Subscribe(
|
||||
categories=cats,
|
||||
tags=tags,
|
||||
user=db_user,
|
||||
@ -52,15 +73,15 @@ class DBConfig:
|
||||
session.add(subscribe)
|
||||
await session.commit()
|
||||
|
||||
async def list_subscribe(self, user: int, user_type: str) -> list[MSubscribe]:
|
||||
async def list_subscribe(self, user: int, user_type: str) -> list[Subscribe]:
|
||||
async with AsyncSession(get_engine()) as session:
|
||||
query_stmt = (
|
||||
select(MSubscribe)
|
||||
select(Subscribe)
|
||||
.where(User.type == user_type, User.uid == user)
|
||||
.join(User)
|
||||
.options(selectinload(MSubscribe.target)) # type:ignore
|
||||
.options(selectinload(Subscribe.target)) # type:ignore
|
||||
)
|
||||
subs: list[MSubscribe] = (await session.scalars(query_stmt)).all()
|
||||
subs: list[Subscribe] = (await session.scalars(query_stmt)).all()
|
||||
return subs
|
||||
|
||||
async def del_subscribe(
|
||||
@ -71,19 +92,19 @@ class DBConfig:
|
||||
select(User).where(User.uid == user, User.type == user_type)
|
||||
)
|
||||
target_obj = await session.scalar(
|
||||
select(MTarget).where(
|
||||
MTarget.platform_name == platform_name, MTarget.target == target
|
||||
select(Target).where(
|
||||
Target.platform_name == platform_name, MTarget.target == target
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
delete(MSubscribe).where(
|
||||
MSubscribe.user == user_obj, MSubscribe.target == target_obj
|
||||
delete(Subscribe).where(
|
||||
Subscribe.user == user_obj, MSubscribe.target == target_obj
|
||||
)
|
||||
)
|
||||
target_count = await session.scalar(
|
||||
select(func.count())
|
||||
.select_from(MSubscribe)
|
||||
.where(MSubscribe.target == target_obj)
|
||||
.select_from(Subscribe)
|
||||
.where(Subscribe.target == target_obj)
|
||||
)
|
||||
if target_count == 0:
|
||||
# delete empty target
|
||||
@ -101,22 +122,93 @@ class DBConfig:
|
||||
tags: list,
|
||||
):
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
subscribe_obj: MSubscribe = await sess.scalar(
|
||||
select(MSubscribe)
|
||||
subscribe_obj: Subscribe = await sess.scalar(
|
||||
select(Subscribe)
|
||||
.where(
|
||||
User.uid == user,
|
||||
User.type == user_type,
|
||||
MTarget.target == target,
|
||||
MTarget.platform_name == platform_name,
|
||||
Target.target == target,
|
||||
Target.platform_name == platform_name,
|
||||
)
|
||||
.join(User)
|
||||
.join(MTarget)
|
||||
.options(selectinload(MSubscribe.target)) # type:ignore
|
||||
.join(Target)
|
||||
.options(selectinload(Subscribe.target)) # type:ignore
|
||||
)
|
||||
subscribe_obj.tags = tags # type:ignore
|
||||
subscribe_obj.categories = cats # type:ignore
|
||||
subscribe_obj.target.target_name = target_name
|
||||
await sess.commit()
|
||||
|
||||
async def get_time_weight_config(
|
||||
self, target: T_Target, platform_name: str
|
||||
) -> WeightConfig:
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
time_weight_conf: list[ScheduleTimeWeight] = await sess.scalars(
|
||||
select(ScheduleTimeWeight)
|
||||
.where(Target.platform_name == platform_name, Target.target == target)
|
||||
.join(Target)
|
||||
)
|
||||
targetObj: Target = await sess.scalar(
|
||||
select(Target).where(
|
||||
Target.platform_name == platform_name, Target.target == target
|
||||
)
|
||||
)
|
||||
return WeightConfig(
|
||||
default=targetObj.default_schedule_weight,
|
||||
time_config=[
|
||||
TimeWeightConfig(
|
||||
start_time=time_conf.start_time,
|
||||
end_time=time_conf.end_time,
|
||||
weight=time_conf.weight,
|
||||
)
|
||||
for time_conf in time_weight_conf
|
||||
],
|
||||
)
|
||||
|
||||
async def update_time_weight_config(
|
||||
self, target: T_Target, platform_name: str, conf: WeightConfig
|
||||
):
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
targetObj: Target = await sess.scalar(
|
||||
select(Target).where(
|
||||
Target.platform_name == platform_name, Target.target == target
|
||||
)
|
||||
)
|
||||
target_id = targetObj.id
|
||||
targetObj.default_schedule_weight = conf.default
|
||||
delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id)
|
||||
for time_conf in conf.time_config:
|
||||
new_conf = ScheduleTimeWeight(
|
||||
start_time=time_conf.start_time,
|
||||
end_time=time_conf.end_time,
|
||||
weight=time_conf.weight,
|
||||
target=targetObj,
|
||||
)
|
||||
sess.add(new_conf)
|
||||
|
||||
await sess.commit()
|
||||
|
||||
async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, int]:
|
||||
res = {}
|
||||
cur_time = _get_time()
|
||||
async with AsyncSession(get_engine()) as sess:
|
||||
targets: list[Target] = await sess.scalars(
|
||||
select(Target)
|
||||
.where(Target.platform_name.in_(platform_list))
|
||||
.options(selectinload(Target.time_weight))
|
||||
)
|
||||
for target in targets:
|
||||
key = f"{target.platform_name}-{target.target}"
|
||||
weight = target.default_schedule_weight
|
||||
for time_conf in target.time_weight:
|
||||
if (
|
||||
time_conf.start_time <= cur_time
|
||||
and time_conf.end_time > cur_time
|
||||
):
|
||||
weight = time_conf.weight
|
||||
break
|
||||
res[key] = weight
|
||||
return res
|
||||
|
||||
|
||||
config = DBConfig()
|
||||
|
@ -1 +0,0 @@
|
||||
from .scheduler_config import SchedulerConfig
|
@ -10,8 +10,9 @@ from nonebot.plugin import require
|
||||
|
||||
from ..plugin_config import plugin_config
|
||||
from .http import http_client
|
||||
from .scheduler_config import SchedulerConfig
|
||||
|
||||
__all__ = ["http_client", "Singleton", "parse_text", "html_to_text"]
|
||||
__all__ = ["http_client", "Singleton", "parse_text", "html_to_text", "SchedulerConfig"]
|
||||
|
||||
|
||||
class Singleton(type):
|
||||
|
119
tests/config/test_scheduler_conf.py
Normal file
119
tests/config/test_scheduler_conf.py
Normal file
@ -0,0 +1,119 @@
|
||||
from datetime import time
|
||||
|
||||
from nonebug import App
|
||||
|
||||
|
||||
async def test_create_config(app: App, db_migration):
|
||||
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
|
||||
from nonebot_bison.config.db_model import Subscribe, Target, User
|
||||
from nonebot_bison.types import Target as T_Target
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
|
||||
await config.add_subscribe(
|
||||
user=123,
|
||||
user_type="group",
|
||||
target=T_Target("weibo_id"),
|
||||
target_name="weibo_name",
|
||||
platform_name="weibo",
|
||||
cats=[],
|
||||
tags=[],
|
||||
)
|
||||
await config.add_subscribe(
|
||||
user=123,
|
||||
user_type="group",
|
||||
target=T_Target("weibo_id1"),
|
||||
target_name="weibo_name1",
|
||||
platform_name="weibo",
|
||||
cats=[],
|
||||
tags=[],
|
||||
)
|
||||
await config.update_time_weight_config(
|
||||
target=T_Target("weibo_id"),
|
||||
platform_name="weibo",
|
||||
conf=WeightConfig(
|
||||
default=10,
|
||||
time_config=[
|
||||
TimeWeightConfig(start_time=time(1, 0), end_time=time(2, 0), weight=20)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
test_config = await config.get_time_weight_config(
|
||||
target=T_Target("weibo_id"), platform_name="weibo"
|
||||
)
|
||||
assert test_config.default == 10
|
||||
assert test_config.time_config == [
|
||||
TimeWeightConfig(start_time=time(1, 0), end_time=time(2, 0), weight=20)
|
||||
]
|
||||
test_config1 = await config.get_time_weight_config(
|
||||
target=T_Target("weibo_id1"), platform_name="weibo"
|
||||
)
|
||||
assert test_config1.default == 10
|
||||
assert test_config1.time_config == []
|
||||
|
||||
|
||||
async def test_get_current_weight(app: App, db_migration):
|
||||
from datetime import time
|
||||
|
||||
from nonebot_bison.config import db_config
|
||||
from nonebot_bison.config.db_config import TimeWeightConfig, WeightConfig, config
|
||||
from nonebot_bison.config.db_model import Subscribe, Target, User
|
||||
from nonebot_bison.types import Target as T_Target
|
||||
from nonebot_plugin_datastore.db import get_engine
|
||||
|
||||
await config.add_subscribe(
|
||||
user=123,
|
||||
user_type="group",
|
||||
target=T_Target("weibo_id"),
|
||||
target_name="weibo_name",
|
||||
platform_name="weibo",
|
||||
cats=[],
|
||||
tags=[],
|
||||
)
|
||||
await config.add_subscribe(
|
||||
user=123,
|
||||
user_type="group",
|
||||
target=T_Target("weibo_id1"),
|
||||
target_name="weibo_name1",
|
||||
platform_name="weibo",
|
||||
cats=[],
|
||||
tags=[],
|
||||
)
|
||||
await config.add_subscribe(
|
||||
user=123,
|
||||
user_type="group",
|
||||
target=T_Target("weibo_id1"),
|
||||
target_name="weibo_name2",
|
||||
platform_name="weibo2",
|
||||
cats=[],
|
||||
tags=[],
|
||||
)
|
||||
await config.update_time_weight_config(
|
||||
target=T_Target("weibo_id"),
|
||||
platform_name="weibo",
|
||||
conf=WeightConfig(
|
||||
default=10,
|
||||
time_config=[
|
||||
TimeWeightConfig(start_time=time(1, 0), end_time=time(2, 0), weight=20),
|
||||
TimeWeightConfig(start_time=time(4, 0), end_time=time(5, 0), weight=30),
|
||||
],
|
||||
),
|
||||
)
|
||||
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(1, 30))
|
||||
weight = await config.get_current_weight_val(["weibo", "weibo2"])
|
||||
assert len(weight) == 3
|
||||
assert weight["weibo-weibo_id"] == 20
|
||||
assert weight["weibo-weibo_id1"] == 10
|
||||
assert weight["weibo2-weibo_id1"] == 10
|
||||
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(4, 0))
|
||||
weight = await config.get_current_weight_val(["weibo", "weibo2"])
|
||||
assert len(weight) == 3
|
||||
assert weight["weibo-weibo_id"] == 30
|
||||
assert weight["weibo-weibo_id1"] == 10
|
||||
assert weight["weibo2-weibo_id1"] == 10
|
||||
app.monkeypatch.setattr(db_config, "_get_time", lambda: time(5, 0))
|
||||
weight = await config.get_current_weight_val(["weibo", "weibo2"])
|
||||
assert len(weight) == 3
|
||||
assert weight["weibo-weibo_id"] == 10
|
||||
assert weight["weibo-weibo_id1"] == 10
|
||||
assert weight["weibo2-weibo_id1"] == 10
|
Loading…
x
Reference in New Issue
Block a user