Azide 38f0edcd25
🐛 Bilibili调度新增回避策略 (#573)
* 🐛 将Bilibili的调度速度降低到60s

*  增加回避策略

*  降低轮询间隔,增加回避次数,抛出阶段随机刷新

* ♻️ 更清晰的调度逻辑实现

* 🐛 兼容3.10的NamedTuple多继承

* ♻️ 合并重复逻辑

* ♻️ ctx放入fsm

* 🐛 测试并调整逻辑

* 🐛 补全类型标注

* ♻️ 添加Condition和State.on_exit/on_enter,以实现自动状态切换

*  调整测试

* 🐛 私有化命名方法

* 🔊 调整补充日志

* 🐛 添加测试后清理

* ✏️ fix typing typo
2024-08-04 18:40:01 +08:00

169 lines
5.5 KiB
Python

import sys
import asyncio
import inspect
from enum import Enum
from dataclasses import dataclass
from collections.abc import Set as AbstractSet
from collections.abc import Callable, Sequence, Awaitable, AsyncGenerator
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Protocol, TypeAlias, TypedDict, NamedTuple, runtime_checkable
from nonebot import logger
class StrEnum(str, Enum):
    """Enum base class whose members are also ``str`` instances.

    Subclasses assign string values to their members; each member then
    compares equal to its plain string value.
    """
# Type variables parameterizing the state machine:
#   TAddon - the object handed to state hooks, actions and conditions
#   TState - the type of the states
#   TEvent - the type of the events that trigger transitions
TAddon = TypeVar("TAddon", contravariant=True)
TState = TypeVar("TState", contravariant=True)
TEvent = TypeVar("TEvent", contravariant=True)
# Type variable bound to the FSM class (given as a forward-reference string).
TFSM = TypeVar("TFSM", bound="FSM", contravariant=True)
class StateError(Exception):
    """Error raised by the state machine for invalid usage.

    Used when an event is emitted before the machine was started, or when
    no transition is available for an event in the current state.
    """
# Value returned by a transition action; left unconstrained on purpose.
ActionReturn: TypeAlias = Any
@runtime_checkable
class SupportStateOnExit(Generic[TAddon], Protocol):
    """Structural type for states that provide an ``on_exit`` hook.

    It is ``runtime_checkable`` so that ``isinstance`` can be used to detect
    whether a state object offers the hook.
    """

    async def on_exit(self, addon: TAddon) -> None:
        """Hook awaited with the ``addon`` object when the state is left."""
@runtime_checkable
class SupportStateOnEnter(Generic[TAddon], Protocol):
    """Structural type for states that provide an ``on_enter`` hook.

    It is ``runtime_checkable`` so that ``isinstance`` can be used to detect
    whether a state object offers the hook.
    """

    async def on_enter(self, addon: TAddon) -> None:
        """Hook awaited with the ``addon`` object when the state is entered."""
class Action(Generic[TState, TEvent, TAddon], Protocol):
    """Callable protocol for the coroutine executed during a transition."""

    async def __call__(self, from_: TState, event: TEvent, to: TState, addon: TAddon) -> ActionReturn:
        """Run the transition action.

        Args:
            from_: state being left.
            event: event that triggered the transition.
            to: state being entered.
            addon: the object shared with hooks and conditions.

        Returns:
            An arbitrary value (``ActionReturn``).
        """
# Signature of a condition predicate: takes the addon, returns an awaitable bool.
ConditionFunc = Callable[[TAddon], Awaitable[bool]]
@dataclass(frozen=True)
class Condition(Generic[TAddon]):
    """Async predicate over the addon, optionally negated.

    The dataclass is frozen, which makes instances hashable so they can be
    stored in a set of conditions.
    """

    # Predicate awaited with the addon.
    call: ConditionFunc[TAddon]
    # When true, the predicate's result is inverted.
    not_: bool = False

    def __repr__(self):
        """Return a short representation showing the predicate and any negation."""
        if inspect.isfunction(self.call) or inspect.isclass(self.call):
            call_str = self.call.__name__
        else:
            call_str = repr(self.call)
        # Show the negation flag only when set, so that a negated condition is
        # distinguishable in logs while plain conditions keep their short form.
        negation = ", not_=True" if self.not_ else ""
        return f"Condition(call={call_str}{negation})"

    async def __call__(self, addon: TAddon) -> bool:
        """Await the predicate with ``addon`` and return its (possibly negated) result.

        The result is coerced to ``bool`` before negation: a bitwise XOR on a
        truthy non-bool value (e.g. ``2 ^ True == 3``) would stay truthy and
        silently skip the negation.
        """
        return bool(await self.call(addon)) != bool(self.not_)
# FIXME: only Python 3.11+ supports adding generics to NamedTuple and TypedDict
# through multiple inheritance, so the generic definitions are used on 3.11+
# (and by type checkers) while 3.10 gets non-generic fallbacks.
# So when can 3.10 be dropped(?
if sys.version_info >= (3, 11) or TYPE_CHECKING:

    class Transition(Generic[TState, TEvent, TAddon], NamedTuple):
        """One edge of the state graph."""

        # Coroutine executed when the transition is taken.
        action: Action[TState, TEvent, TAddon]
        # State entered after the action has run.
        to: TState
        # Conditions guarding the transition; None or empty means unconditional.
        conditions: AbstractSet[Condition[TAddon]] | None = None

    class StateGraph(Generic[TState, TEvent, TAddon], TypedDict):
        """Description of a state machine: its transitions and initial state."""

        # state -> event -> a single transition or a sequence of candidates
        transitions: dict[
            TState,
            dict[
                TEvent,
                Transition[TState, TEvent, TAddon] | Sequence[Transition[TState, TEvent, TAddon]],
            ],
        ]
        # State the machine starts in.
        initial: TState

else:

    class Transition(NamedTuple):
        """One edge of the state graph (non-generic fallback for Python 3.10)."""

        # Coroutine executed when the transition is taken.
        action: Action
        # State entered after the action has run.
        to: Any
        # Conditions guarding the transition; None or empty means unconditional.
        conditions: AbstractSet[Condition] | None = None

    class StateGraph(TypedDict):
        """Description of a state machine (non-generic fallback for Python 3.10)."""

        # state -> event -> a single transition or a sequence of candidates;
        # kept in line with the generic definition above.
        transitions: dict[Any, dict[Any, Transition | Sequence[Transition]]]
        # State the machine starts in.
        initial: Any
class FSM(Generic[TState, TEvent, TAddon]):
    """Asynchronous finite state machine backed by an async generator.

    Usage: construct it with a state graph and an addon object, ``await
    start()``, then feed events with ``await emit(event)``. ``reset()``
    closes the machine so that it can be started again.
    """

    def __init__(self, graph: StateGraph[TState, TEvent, TAddon], addon: TAddon):
        """Store the graph and addon and create the (not yet running) machine.

        Args:
            graph: mapping with the ``transitions`` table and ``initial`` state.
            addon: object passed to state hooks, actions and conditions.
        """
        self.started = False
        self.graph = graph
        self.current_state = graph["initial"]
        # Calling the async generator function only creates the generator;
        # none of its body runs until ``start`` advances it.
        self.machine = self._core()
        self.addon = addon

    async def _core(self) -> AsyncGenerator[ActionReturn, TEvent]:
        """Event loop of the machine.

        Each event sent in selects a transition, runs the current state's
        ``on_exit`` hook (if any), the transition's action, then the new
        state's ``on_enter`` hook (if any). The action's return value is
        yielded back to the sender.

        Raises:
            StateError: if an event arrives while ``started`` is false, or if
                no transition can be selected for the event.
        """
        self.current_state = self.graph["initial"]
        res = None
        while True:
            # Hand the previous action's result to the sender and wait for
            # the next event.
            event = yield res

            if not self.started:
                raise StateError("FSM not started, please call start() first")

            selected_transition = await self.cherry_pick(event)

            logger.trace(f"exit state: {self.current_state}")
            if isinstance(self.current_state, SupportStateOnExit):
                logger.trace(f"do {self.current_state}.on_exit")
                await self.current_state.on_exit(self.addon)

            logger.trace(f"do action: {selected_transition.action}")
            res = await selected_transition.action(self.current_state, event, selected_transition.to, self.addon)

            logger.trace(f"enter state: {selected_transition.to}")
            self.current_state = selected_transition.to

            if isinstance(self.current_state, SupportStateOnEnter):
                logger.trace(f"do {self.current_state}.on_enter")
                await self.current_state.on_enter(self.addon)

    async def start(self):
        """Advance the machine to its first ``yield`` and mark it as started."""
        await anext(self.machine)
        self.started = True
        logger.trace(f"FSM started, initial state: {self.current_state}")

    async def cherry_pick(self, event: TEvent) -> Transition[TState, TEvent, TAddon]:
        """Select the transition to take for ``event`` in the current state.

        A single ``Transition`` is returned as is. For a sequence, the first
        transition whose conditions all pass is returned; if none passes, the
        last unconditional transition in the sequence is used as a fallback.

        Raises:
            StateError: if the event has no entry for the current state, or
                no candidate in the sequence is applicable.
            TypeError: if the entry is neither a Transition nor a Sequence.
        """
        transitions = self.graph["transitions"][self.current_state].get(event)
        if transitions is None:
            raise StateError(f"Invalid event {event} in state {self.current_state}")

        # Transition is a NamedTuple and therefore itself a Sequence, so this
        # check must come before the generic Sequence check.
        if isinstance(transitions, Transition):
            return transitions

        if not isinstance(transitions, Sequence):
            # The f-prefix is required for the offending value to be reported.
            raise TypeError(
                f"Invalid transition type: {transitions}, expected Transition or Sequence[Transition]"
            )

        no_conds: list[Transition[TState, TEvent, TAddon]] = []
        for transition in transitions:
            if not transition.conditions:
                no_conds.append(transition)
                continue

            # Evaluate all conditions of this candidate concurrently.
            values = await asyncio.gather(*(condition(self.addon) for condition in transition.conditions))
            if all(values):
                logger.trace(f"conditions {transition.conditions} passed")
                return transition

        # No conditional transition matched: fall back to an unconditional one.
        if no_conds:
            return no_conds.pop()
        raise StateError(f"Invalid event {event} in state {self.current_state}")

    async def emit(self, event: TEvent):
        """Send ``event`` into the machine and return the action's result.

        NOTE(review): an exception escaping the generator (e.g. ``StateError``
        for an invalid event) finishes it, so later ``emit`` calls will fail
        until ``reset()`` and ``start()`` are called again - confirm callers
        handle this.
        """
        return await self.machine.asend(event)

    async def reset(self):
        """Close the running machine and replace it with a fresh, unstarted one.

        ``current_state`` is not changed here; the new generator sets it back
        to the initial state when it first runs, i.e. on the next ``start()``.
        """
        await self.machine.aclose()
        self.started = False
        self.machine = self._core()
        logger.trace("FSM closed")