🐛 register db hook after init scheduler

This commit is contained in:
felinae98 2023-03-12 17:43:32 +08:00
parent e43edaa717
commit 010e6335f2
2 changed files with 17 additions and 14 deletions

View File

@ -1,3 +1,4 @@
import asyncio
from collections import defaultdict from collections import defaultdict
from datetime import datetime, time from datetime import datetime, time
from typing import Awaitable, Callable, Optional, Sequence from typing import Awaitable, Callable, Optional, Sequence
@ -28,14 +29,14 @@ class SubscribeDupException(Exception):
class DBConfig: class DBConfig:
def __init__(self): def __init__(self):
self.add_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None self.add_target_hook: list[Callable[[str, T_Target], Awaitable]] = []
self.delete_target_hook: Optional[Callable[[str, T_Target], Awaitable]] = None self.delete_target_hook: list[Callable[[str, T_Target], Awaitable]] = []
def register_add_target_hook(self, fun: Callable[[str, T_Target], Awaitable]): def register_add_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
self.add_target_hook = fun self.add_target_hook.append(fun)
def register_delete_target_hook(self, fun: Callable[[str, T_Target], Awaitable]): def register_delete_target_hook(self, fun: Callable[[str, T_Target], Awaitable]):
self.delete_target_hook = fun self.delete_target_hook.append(fun)
async def add_subscribe( async def add_subscribe(
self, self,
@ -65,8 +66,9 @@ class DBConfig:
db_target = Target( db_target = Target(
target=target, platform_name=platform_name, target_name=target_name target=target, platform_name=platform_name, target_name=target_name
) )
if self.add_target_hook: await asyncio.gather(
await self.add_target_hook(platform_name, target) *[hook(platform_name, target) for hook in self.add_target_hook]
)
else: else:
db_target.target_name = target_name db_target.target_name = target_name
subscribe = Subscribe( subscribe = Subscribe(
@ -118,9 +120,12 @@ class DBConfig:
) )
if target_count == 0: if target_count == 0:
# delete empty target # delete empty target
# await session.delete(target_obj) await asyncio.gather(
if self.delete_target_hook: *[
await self.delete_target_hook(platform_name, T_Target(target)) hook(platform_name, T_Target(target))
for hook in self.delete_target_hook
]
)
await session.commit() await session.commit()
async def update_subscribe( async def update_subscribe(

View File

@ -21,7 +21,7 @@ async def init_scheduler():
platform_name = platform.platform_name platform_name = platform.platform_name
targets = await config.get_platform_target(platform_name) targets = await config.get_platform_target(platform_name)
if scheduler_config not in _schedule_class_dict: if scheduler_config not in _schedule_class_dict:
_schedule_class_dict[scheduler_config] = targets _schedule_class_dict[scheduler_config] = list(targets)
else: else:
_schedule_class_dict[scheduler_config].extend(targets) _schedule_class_dict[scheduler_config].extend(targets)
if scheduler_config not in _schedule_class_platform_dict: if scheduler_config not in _schedule_class_platform_dict:
@ -36,6 +36,8 @@ async def init_scheduler():
scheduler_dict[scheduler_config] = Scheduler( scheduler_dict[scheduler_config] = Scheduler(
scheduler_config, schedulable_args, platform_name_list scheduler_config, schedulable_args, platform_name_list
) )
config.register_add_target_hook(handle_insert_new_target)
config.register_delete_target_hook(handle_delete_target)
async def handle_insert_new_target(platform_name: str, target: T_Target): async def handle_insert_new_target(platform_name: str, target: T_Target):
@ -48,7 +50,3 @@ async def handle_delete_target(platform_name: str, target: T_Target):
platform = platform_manager[platform_name] platform = platform_manager[platform_name]
scheduler_obj = scheduler_dict[platform.scheduler] scheduler_obj = scheduler_dict[platform.scheduler]
scheduler_obj.delete_schedulable(platform_name, target) scheduler_obj.delete_schedulable(platform_name, target)
config.register_add_target_hook(handle_insert_new_target)
config.register_delete_target_hook(handle_delete_target)