mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-04 02:26:11 +08:00
* 🐛 将Bilibili的调度速度降低到60s * ✨ 增加回避策略 * ✨ 降低轮询间隔,增加回避次数,抛出阶段随机刷新 * ♻️ 更清晰的调度逻辑实现 * 🐛 兼容3.10的NamedTuple多继承 * ♻️ 合并重复逻辑 * ♻️ ctx放入fsm * 🐛 测试并调整逻辑 * 🐛 补全类型标注 * ♻️ 添加Condition和State.on_exit/on_enter,以实现自动状态切换 * ✅ 调整测试 * 🐛 私有化命名方法 * 🔊 调整补充日志 * 🐛 添加测试后清理 * ✏️ fix typing typo
169 lines
5.5 KiB
Python
169 lines
5.5 KiB
Python
import sys
|
|
import asyncio
|
|
import inspect
|
|
from enum import Enum
|
|
from dataclasses import dataclass
|
|
from collections.abc import Set as AbstractSet
|
|
from collections.abc import Callable, Sequence, Awaitable, AsyncGenerator
|
|
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Protocol, TypeAlias, TypedDict, NamedTuple, runtime_checkable
|
|
|
|
from nonebot import logger
|
|
|
|
|
|
class StrEnum(str, Enum):
    """Enum whose members are also ``str`` instances.

    Members compare equal to their plain string values. Unlike ``enum.StrEnum``
    (Python 3.11+), no ``__str__``/``__format__`` override is defined here.
    """
|
|
|
|
|
|
# Type parameters shared by the FSM machinery. All are declared contravariant,
# matching their use as parameter types in the Action / state-hook protocols.
TAddon = TypeVar("TAddon", contravariant=True)  # extra object handed to actions, conditions and hooks
TState = TypeVar("TState", contravariant=True)  # state type
TEvent = TypeVar("TEvent", contravariant=True)  # event type
TFSM = TypeVar("TFSM", contravariant=True, bound="FSM")  # "FSM" is a forward reference to the class below
|
|
|
|
|
|
class StateError(Exception): ...
|
|
|
|
|
|
ActionReturn: TypeAlias = Any
|
|
|
|
|
|
@runtime_checkable
class SupportStateOnExit(Generic[TAddon], Protocol):
    """State that wants a callback when the FSM transitions away from it.

    Runtime-checkable: ``isinstance(state, SupportStateOnExit)`` only checks that
    an ``on_exit`` attribute exists, not its signature.
    """

    async def on_exit(self, addon: TAddon) -> None:
        """Hook awaited with the FSM addon object before the transition's action runs."""
|
|
|
|
|
|
@runtime_checkable
class SupportStateOnEnter(Generic[TAddon], Protocol):
    """State that wants a callback when the FSM transitions into it.

    Runtime-checkable: ``isinstance(state, SupportStateOnEnter)`` only checks that
    an ``on_enter`` attribute exists, not its signature.
    """

    async def on_enter(self, addon: TAddon) -> None:
        """Hook awaited with the FSM addon object after the FSM has switched to this state."""
|
|
|
|
|
|
class Action(Generic[TState, TEvent, TAddon], Protocol):
    """Async callable run when a transition fires.

    It receives the source state, the triggering event, the target state and
    the FSM addon object.
    """

    async def __call__(self, from_: TState, event: TEvent, to: TState, addon: TAddon) -> ActionReturn:
        """Run the transition's work; the result becomes the value returned by ``FSM.emit``."""
|
|
|
|
|
|
# Async predicate evaluated against the FSM addon object.
ConditionFunc = Callable[[TAddon], Awaitable[bool]]


@dataclass(frozen=True)
class Condition(Generic[TAddon]):
    """Guard attached to a transition.

    Wraps an async predicate ``call`` that receives the FSM addon object;
    ``not_=True`` inverts its result. The dataclass is frozen (and therefore
    hashable), so conditions can be collected in a set.
    """

    call: ConditionFunc[TAddon]
    not_: bool = False

    def __repr__(self):
        # Plain functions and classes are shown by bare name; any other
        # callable (partial, bound method, callable object) falls back to repr().
        if inspect.isfunction(self.call) or inspect.isclass(self.call):
            call_str = self.call.__name__
        else:
            call_str = repr(self.call)
        # Mention not_ only when set: plain conditions keep their short repr,
        # negated ones stay distinguishable in trace logs.
        if self.not_:
            return f"Condition(call={call_str}, not_=True)"
        return f"Condition(call={call_str})"

    async def __call__(self, addon: TAddon) -> bool:
        """Evaluate the predicate for ``addon``, inverted when ``not_`` is set.

        The predicate's result is coerced with ``bool()`` first. A truthy or
        falsy non-bool return (e.g. ``2`` or ``None``) is therefore negated
        correctly rather than being XOR-ed as a raw value.
        """
        return bool(await self.call(addon)) != self.not_
|
|
|
|
|
|
# FIXME: NamedTuple and TypedDict only accept Generic as an extra base class on
# Python 3.11+, so Python 3.10 gets the non-generic fallbacks in the else branch.
# When can 3.10 support be dropped?
if sys.version_info >= (3, 11) or TYPE_CHECKING:

    class Transition(Generic[TState, TEvent, TAddon], NamedTuple):
        """An edge of the state graph.

        ``action`` is awaited when the edge is taken and ``to`` is the target
        state. When ``conditions`` is non-empty, every condition must pass for
        the FSM to choose this edge from a sequence of candidates. An edge
        without conditions serves as the fallback.
        """

        action: Action[TState, TEvent, TAddon]
        to: TState
        conditions: AbstractSet[Condition[TAddon]] | None = None

    class StateGraph(Generic[TState, TEvent, TAddon], TypedDict):
        """Complete FSM definition.

        ``transitions`` maps each state to a mapping of event -> either a single
        transition or an ordered sequence of candidate transitions.
        ``initial`` is the state the FSM starts in.
        """

        transitions: dict[
            TState,
            dict[
                TEvent,
                Transition[TState, TEvent, TAddon] | Sequence[Transition[TState, TEvent, TAddon]],
            ],
        ]
        initial: TState

else:
    # Non-generic equivalents for Python 3.10. The annotations mirror the generic
    # branch above, including the single-or-sequence transition values.

    class Transition(NamedTuple):
        action: Action
        to: Any
        conditions: AbstractSet[Condition] | None = None

    class StateGraph(TypedDict):
        transitions: dict[Any, dict[Any, Transition | Sequence[Transition]]]
        initial: Any
|
|
|
|
|
|
class FSM(Generic[TState, TEvent, TAddon]):
    """Async finite-state machine driven by a ``StateGraph``.

    The transition loop lives in an async generator (``_core``):
    ``start()`` primes it, and every ``emit(event)`` performs exactly one
    transition and returns the action's result. ``reset()`` closes the
    generator and builds a fresh one. ``addon`` is passed unchanged to
    actions, conditions and state hooks.
    """

    def __init__(self, graph: StateGraph[TState, TEvent, TAddon], addon: TAddon):
        self.started = False
        self.graph = graph
        self.current_state = graph["initial"]
        # Creating the generator runs none of its body; start() does that.
        self.machine = self._core()
        self.addon = addon

    async def _core(self) -> AsyncGenerator[ActionReturn, TEvent]:
        """Transition loop: receive an event, perform one transition, repeat.

        For each event:
        - pick a transition with ``cherry_pick``;
        - await the current state's ``on_exit`` hook, if it has one;
        - await the action;
        - switch to the target state and await its ``on_enter`` hook, if it has one.

        The action's return value is yielded back to ``emit``. Any exception
        ends the generator, so the FSM must be ``reset()`` and ``start()``-ed
        again before further use.
        """
        self.current_state = self.graph["initial"]
        res = None
        while True:
            event = yield res

            # Defensive only: emit() already rejects events before start().
            if not self.started:
                raise StateError("FSM not started, please call start() first")

            selected_transition = await self.cherry_pick(event)

            logger.trace(f"exit state: {self.current_state}")
            if isinstance(self.current_state, SupportStateOnExit):
                logger.trace(f"do {self.current_state}.on_exit")
                await self.current_state.on_exit(self.addon)

            logger.trace(f"do action: {selected_transition.action}")
            res = await selected_transition.action(self.current_state, event, selected_transition.to, self.addon)

            logger.trace(f"enter state: {selected_transition.to}")
            self.current_state = selected_transition.to

            if isinstance(self.current_state, SupportStateOnEnter):
                logger.trace(f"do {self.current_state}.on_enter")
                await self.current_state.on_enter(self.addon)

    async def start(self):
        """Prime the generator (run it up to its first ``yield``) and mark the FSM started."""
        await anext(self.machine)
        self.started = True
        logger.trace(f"FSM started, initial state: {self.current_state}")

    async def cherry_pick(self, event: TEvent) -> Transition[TState, TEvent, TAddon]:
        """Select the transition to take for ``event`` from the current state.

        A single ``Transition`` entry is returned as-is. For a sequence, the
        first transition whose conditions all pass wins. If none pass, a
        transition without conditions is used as the fallback.

        Raises:
            StateError: if the current state has no transitions for ``event``
                (including states with no outgoing transitions at all), or no
                candidate's conditions pass and there is no fallback.
            TypeError: if the graph entry is neither a ``Transition`` nor a
                sequence of them.
        """
        # A state missing from the graph (e.g. a terminal state) is reported as
        # an invalid event instead of leaking a bare KeyError.
        state_transitions = self.graph["transitions"].get(self.current_state, {})
        transitions = state_transitions.get(event)
        if transitions is None:
            raise StateError(f"Invalid event {event} in state {self.current_state}")

        # Transition is a NamedTuple and therefore also a Sequence: check it first.
        if isinstance(transitions, Transition):
            return transitions
        if not isinstance(transitions, Sequence):
            raise TypeError(f"Invalid transition type: {transitions}, expected Transition or Sequence[Transition]")

        no_conds: list[Transition[TState, TEvent, TAddon]] = []
        for transition in transitions:
            if not transition.conditions:
                no_conds.append(transition)
                continue

            # All conditions of a single candidate are evaluated concurrently.
            values = await asyncio.gather(*(condition(self.addon) for condition in transition.conditions))
            if all(values):
                logger.trace(f"conditions {transition.conditions} passed")
                return transition

        if no_conds:
            # NOTE(review): pop() picks the *last* unconditioned transition when
            # several are listed for the same event; confirm this is intended.
            return no_conds.pop()
        raise StateError(f"Invalid event {event} in state {self.current_state}")

    async def emit(self, event: TEvent):
        """Send ``event`` to the machine and return the taken action's result.

        Raises:
            StateError: if ``start()`` has not been called, or the event cannot
                be handled in the current state.
        """
        # Without this guard, sending a non-None event to an unprimed generator
        # raises TypeError, and sending None silently primes it instead.
        if not self.started:
            raise StateError("FSM not started, please call start() first")
        return await self.machine.asend(event)

    async def reset(self):
        """Close the running generator and replace it with a fresh, unstarted one.

        ``start()`` must be called again before the next ``emit()``. That call
        also puts ``current_state`` back to the graph's initial state.
        """
        await self.machine.aclose()
        self.started = False
        self.machine = self._core()
        logger.trace("FSM closed")
|