mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-04 02:26:11 +08:00
230 lines
7.2 KiB
Python
230 lines
7.2 KiB
Python
import sys
|
||
import asyncio
|
||
import inspect
|
||
from enum import Enum
|
||
from functools import wraps
|
||
from dataclasses import dataclass
|
||
from collections.abc import Set as AbstractSet
|
||
from collections.abc import Callable, Sequence, Awaitable, AsyncGenerator
|
||
from typing import (
|
||
TYPE_CHECKING,
|
||
Any,
|
||
Generic,
|
||
TypeVar,
|
||
Protocol,
|
||
ParamSpec,
|
||
TypeAlias,
|
||
TypedDict,
|
||
NamedTuple,
|
||
Concatenate,
|
||
overload,
|
||
runtime_checkable,
|
||
)
|
||
|
||
from nonebot import logger
|
||
|
||
|
||
class StrEnum(str, Enum):
    """Enum whose members are also ``str`` instances, so they compare equal to their string values."""
|
||
|
||
|
||
TAddon = TypeVar("TAddon", contravariant=True)
|
||
TState = TypeVar("TState", contravariant=True)
|
||
TEvent = TypeVar("TEvent", contravariant=True)
|
||
TFSM = TypeVar("TFSM", bound="FSM", contravariant=True)
|
||
P = ParamSpec("P")
|
||
|
||
|
||
class StateError(Exception):
    """Raised by the FSM when it is used before being started or receives an event invalid in its current state."""
|
||
|
||
|
||
ActionReturn: TypeAlias = Any
|
||
|
||
|
||
@runtime_checkable
class SupportStateOnExit(Generic[TAddon], Protocol):
    """Structural type for states that provide an async ``on_exit(addon)`` hook.

    Because the protocol is ``runtime_checkable``, ``isinstance`` only checks that an
    ``on_exit`` attribute is present; the FSM awaits it before running a transition's action.
    """

    async def on_exit(self, addon: TAddon) -> None:
        """Hook awaited with the FSM's addon when this state is left."""
|
||
|
||
|
||
@runtime_checkable
class SupportStateOnEnter(Generic[TAddon], Protocol):
    """Structural type for states that provide an async ``on_enter(addon)`` hook.

    Because the protocol is ``runtime_checkable``, ``isinstance`` only checks that an
    ``on_enter`` attribute is present; the FSM awaits it right after switching to the state.
    """

    async def on_enter(self, addon: TAddon) -> None:
        """Hook awaited with the FSM's addon when this state is entered."""
|
||
|
||
|
||
class Action(Generic[TState, TEvent, TAddon], Protocol):
    """Async callable executed while the FSM performs a transition.

    It receives the state being left, the triggering event, the target state and the
    FSM's addon. Whatever it returns is handed back to the caller of ``FSM.emit``.
    """

    async def __call__(self, from_: TState, event: TEvent, to: TState, addon: TAddon) -> ActionReturn:
        """Run the side effects of moving from ``from_`` to ``to`` on ``event``."""
|
||
|
||
|
||
# Async predicate over the FSM's addon; wrapped by ``Condition`` to guard a transition.
ConditionFunc: TypeAlias = Callable[[TAddon], Awaitable[bool]]
|
||
|
||
|
||
@dataclass(frozen=True)
class Condition(Generic[TAddon]):
    """A (possibly negated) async predicate guarding a ``Transition``.

    Calling the condition awaits ``call(addon)`` and returns its truth value,
    inverted when ``not_`` is set.
    """

    # async predicate evaluated against the FSM's addon
    call: ConditionFunc[TAddon]
    # when True, the predicate's result is inverted
    not_: bool = False

    def __repr__(self) -> str:
        # Functions and classes are shown by name; other callables use their own repr.
        if inspect.isfunction(self.call) or inspect.isclass(self.call):
            call_str = self.call.__name__
        else:
            call_str = repr(self.call)
        # Show the negation flag, otherwise a condition and its negation render
        # identically in the FSM's trace logs.
        if self.not_:
            return f"Condition(call={call_str}, not_=True)"
        return f"Condition(call={call_str})"

    async def __call__(self, addon: TAddon) -> bool:
        """Evaluate the predicate for ``addon``, applying the negation flag.

        The result is coerced with ``bool()`` first. XOR-ing a raw truthy value would
        give wrong answers, e.g. ``2 ^ True == 3`` is still truthy. It would also fail
        for non-numeric results such as ``None``.
        """
        return bool(await self.call(addon)) ^ self.not_
|
||
|
||
|
||
# FIXME: only Python 3.11+ lets NamedTuple and TypedDict take generics through
# multiple inheritance (an extra ``Generic[...]`` base).
# So when do we drop 3.10 (?)
if sys.version_info >= (3, 11) or TYPE_CHECKING:

    class Transition(Generic[TState, TEvent, TAddon], NamedTuple):
        # coroutine run while moving to ``to``; its result is returned by FSM.emit
        action: Action[TState, TEvent, TAddon]
        # target state
        to: TState
        # every condition must pass for this transition to be selected; None/empty = unconditional
        conditions: AbstractSet[Condition[TAddon]] | None = None

    class StateGraph(Generic[TState, TEvent, TAddon], TypedDict):
        # state -> event -> a single transition, or candidates tried in order
        transitions: dict[
            TState,
            dict[
                TEvent,
                Transition[TState, TEvent, TAddon] | Sequence[Transition[TState, TEvent, TAddon]],
            ],
        ]
        # state the FSM starts in (and returns to on reset)
        initial: TState

else:
    # Python 3.10 fallback: same runtime shape, without generic parameters.

    class Transition(NamedTuple):
        # coroutine run while moving to ``to``; its result is returned by FSM.emit
        action: Action
        # target state
        to: Any
        # every condition must pass for this transition to be selected; None/empty = unconditional
        conditions: AbstractSet[Condition] | None = None

    class StateGraph(TypedDict):
        # state -> event -> a single transition, or candidates tried in order
        # (kept in sync with the typed branch above, which also allows a Sequence)
        transitions: dict[Any, dict[Any, Transition | Sequence[Transition]]]
        # state the FSM starts in (and returns to on reset)
        initial: Any
|
||
|
||
|
||
class FSM(Generic[TState, TEvent, TAddon]):
    """Asynchronous finite-state machine driven by a ``StateGraph``.

    The machine is the async generator returned by ``_core``; events are fed into it
    with ``asend``. Call ``start()`` before the first ``emit()``. An exception raised
    while handling an event finalizes that generator, so callers must ``reset()``
    before reusing the FSM (``reset_on_exception`` automates this).
    """

    def __init__(self, graph: StateGraph[TState, TEvent, TAddon], addon: TAddon):
        """Create an FSM that has not been started yet.

        Args:
            graph: transition table plus the initial state.
            addon: object passed to every action, condition and state hook.
        """
        self.started = False
        self.graph = graph
        self.current_state = graph["initial"]
        self.machine = self._core()
        self.addon = addon

    async def _core(self) -> AsyncGenerator[ActionReturn, TEvent]:
        """Generator body of the machine: each received event performs one transition.

        Every ``yield`` hands back the previous action's result and receives the next
        event. The priming ``anext`` in ``start`` yields ``None``.
        """
        self.current_state = self.graph["initial"]
        res = None
        while True:
            event = yield res

            selected_transition = await self.cherry_pick(event)

            logger.trace(f"exit state: {self.current_state}")
            if isinstance(self.current_state, SupportStateOnExit):
                logger.trace(f"do {self.current_state}.on_exit")
                await self.current_state.on_exit(self.addon)

            logger.trace(f"do action: {selected_transition.action}")
            res = await selected_transition.action(self.current_state, event, selected_transition.to, self.addon)

            logger.trace(f"enter state: {selected_transition.to}")
            self.current_state = selected_transition.to

            if isinstance(self.current_state, SupportStateOnEnter):
                logger.trace(f"do {self.current_state}.on_enter")
                await self.current_state.on_enter(self.addon)

    async def start(self):
        """Prime the machine so it can accept events; calling it again is a no-op."""
        if self.started:
            # Re-priming a running machine would send ``None`` into it, where it
            # would be handled as an event.
            logger.trace("FSM already started, ignoring start()")
            return
        await anext(self.machine)
        self.started = True
        logger.trace(f"FSM started, initial state: {self.current_state}")

    async def cherry_pick(self, event: TEvent) -> Transition[TState, TEvent, TAddon]:
        """Select the transition to take for ``event`` from the current state.

        A single ``Transition`` is returned as-is. For a sequence, the first transition
        whose conditions all pass wins (its conditions are evaluated concurrently). If
        none passes, the last unconditional transition is used as the fallback.

        Raises:
            StateError: no transition matches ``event`` in the current state (including
                states that have no outgoing transitions at all).
            TypeError: the graph entry is neither a Transition nor a sequence of them.
        """
        # ``.get`` on the outer table too: a terminal state without an entry is an
        # invalid event, not a KeyError.
        transitions = self.graph["transitions"].get(self.current_state, {}).get(event)
        if transitions is None:
            raise StateError(f"Invalid event {event} in state {self.current_state}")

        # A Transition is a NamedTuple and therefore also a Sequence, so test it first.
        if isinstance(transitions, Transition):
            return transitions
        if not isinstance(transitions, Sequence):
            raise TypeError(f"Invalid transition type: {transitions}, expected Transition or Sequence[Transition]")

        fallback: Transition[TState, TEvent, TAddon] | None = None
        for transition in transitions:
            if not transition.conditions:
                # unconditional candidate: only used if no conditional one passes; the last one wins
                fallback = transition
                continue

            values = await asyncio.gather(*(condition(self.addon) for condition in transition.conditions))
            if all(values):
                logger.trace(f"conditions {transition.conditions} passed")
                return transition

        if fallback is not None:
            return fallback
        raise StateError(f"Invalid event {event} in state {self.current_state}")

    async def emit(self, event: TEvent) -> ActionReturn:
        """Feed ``event`` to the machine and return the executed action's result.

        Raises:
            StateError: the FSM is not started, or ``event`` is invalid in the current state.
        """
        # Checked here rather than inside ``_core``: before ``start()`` the generator is
        # not primed, so ``asend(event)`` would fail with an unrelated TypeError.
        if not self.started:
            raise StateError("FSM not started, please call start() first")
        return await self.machine.asend(event)

    async def reset(self):
        """Close the current machine and rebuild an unstarted one at the initial state."""
        await self.machine.aclose()
        self.started = False
        self.current_state = self.graph["initial"]
        self.machine = self._core()
        logger.trace("FSM closed")
|
||
|
||
|
||
@overload
def reset_on_exception(
    func: Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]],
) -> Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]]:
    """Automatically reset the FSM after an exception occurs."""


@overload
def reset_on_exception(
    *,
    auto_start: bool = False,
) -> Callable[
    [Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]]], Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]]
]:
    """Automatically reset the FSM after an exception; when ``auto_start`` is True, start it again."""


# Modeled on the implementation of ``dataclasses.dataclass``
def reset_on_exception(func=None, /, *, auto_start=False):  # pyright: ignore[reportInconsistentOverload]
    """Decorate an FSM method so an exception resets (and optionally restarts) the FSM.

    Usable bare (``@reset_on_exception``) or with arguments
    (``@reset_on_exception(auto_start=True)``). ``auto_start`` is keyword-only, which
    the second overload now reflects. A positional argument always binds to ``func``.
    """

    def wrap(func: Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]]):
        return __reset_clear_up(func, auto_start)

    # Distinguish ``@reset_on_exception`` from ``@reset_on_exception(...)``
    if func is None:
        # used with parentheses
        return wrap

    # used without parentheses
    return wrap(func)
|
||
|
||
|
||
def __reset_clear_up(func: Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]], auto_start: bool):
    """Wrap the FSM method ``func`` so a failure resets the FSM before propagating.

    If ``func`` raises an ``Exception``, the wrapper does the following, in order:
    log the error, await ``reset()`` on the FSM instance (the first positional
    argument), await ``start()`` when ``auto_start`` is set and the FSM is not running,
    and finally re-raise the original exception.
    """

    @wraps(func)
    async def wrapper(fsm_self: TFSM, *args: P.args, **kwargs: P.kwargs) -> ActionReturn:
        try:
            return await func(fsm_self, *args, **kwargs)
        except Exception as e:
            logger.error(f"Exception in {func.__name__}: {e}")
            await fsm_self.reset()
            if auto_start and not fsm_self.started:
                await fsm_self.start()
            # Bare ``raise`` keeps the original traceback instead of adding a second
            # entry for this frame, as ``raise e`` does (ruff TRY201).
            raise

    return wrapper
|