mirror of
https://github.com/suyiiyii/nonebot-bison.git
synced 2025-06-02 09:26:12 +08:00
✨ 新增可以在 fsm 抛出错误后重置 fsm 的装饰器工具
This commit is contained in:
parent
ab5236ee37
commit
088e7a439f
@ -2,10 +2,24 @@ import sys
|
||||
import asyncio
|
||||
import inspect
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from dataclasses import dataclass
|
||||
from collections.abc import Set as AbstractSet
|
||||
from collections.abc import Callable, Sequence, Awaitable, AsyncGenerator
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Protocol, TypeAlias, TypedDict, NamedTuple, runtime_checkable
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
TypeVar,
|
||||
Protocol,
|
||||
ParamSpec,
|
||||
TypeAlias,
|
||||
TypedDict,
|
||||
NamedTuple,
|
||||
Concatenate,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from nonebot import logger
|
||||
|
||||
@ -17,6 +31,7 @@ TAddon = TypeVar("TAddon", contravariant=True)
|
||||
TState = TypeVar("TState", contravariant=True)
|
||||
TEvent = TypeVar("TEvent", contravariant=True)
|
||||
TFSM = TypeVar("TFSM", bound="FSM", contravariant=True)
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class StateError(Exception): ...
|
||||
@ -163,6 +178,52 @@ class FSM(Generic[TState, TEvent, TAddon]):
|
||||
self.started = False
|
||||
|
||||
del self.machine
|
||||
self.current_state = self.graph["initial"]
|
||||
self.machine = self._core()
|
||||
|
||||
logger.trace("FSM closed")
|
||||
|
||||
|
||||
@overload
|
||||
def reset_on_exception(
|
||||
func: Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]],
|
||||
) -> Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]]:
|
||||
"""自动在发生异常后重置 FSM"""
|
||||
|
||||
|
||||
@overload
|
||||
def reset_on_exception(
|
||||
auto_start: bool = False,
|
||||
) -> Callable[
|
||||
[Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]]], Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]]
|
||||
]:
|
||||
"""自动在异常后重置 FSM,当 auto_start 为 True 时,自动启动 FSM"""
|
||||
|
||||
|
||||
# 参考自 dataclasses.dataclass 的实现
|
||||
def reset_on_exception(func=None, /, *, auto_start=False): # pyright: ignore[reportInconsistentOverload]
|
||||
def warp(func: Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]]):
|
||||
return __reset_clear_up(func, auto_start)
|
||||
|
||||
# 判断调用的是 @reset_on_exception 还是 @reset_on_exception(...)
|
||||
if func is None:
|
||||
# 调用的是带括号的
|
||||
return warp
|
||||
|
||||
# 调用的是不带括号的
|
||||
return warp(func)
|
||||
|
||||
|
||||
def __reset_clear_up(func: Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]], auto_start: bool):
|
||||
@wraps(func)
|
||||
async def wrapper(fsm_self: TFSM, *args: P.args, **kwargs: P.kwargs) -> ActionReturn:
|
||||
try:
|
||||
return await func(fsm_self, *args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in {func.__name__}: {e}")
|
||||
await fsm_self.reset()
|
||||
if auto_start and not fsm_self.started:
|
||||
await fsm_self.start()
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
|
@ -14,7 +14,7 @@ from httpx import URL as HttpxURL
|
||||
from nonebot_bison.types import Target
|
||||
|
||||
from .models import DynRawPost
|
||||
from .fsm import FSM, Condition, StateGraph, Transition, ActionReturn
|
||||
from .fsm import FSM, Condition, StateGraph, Transition, ActionReturn, reset_on_exception
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .platforms import Bilibili
|
||||
@ -218,6 +218,11 @@ class RetryFSM(FSM[RetryState, RetryEvent, RetryAddon[TBilibili]]):
|
||||
self.addon.reset_all()
|
||||
await super().reset()
|
||||
|
||||
@override
|
||||
@reset_on_exception
|
||||
async def emit(self, event: RetryEvent):
|
||||
await super().emit(event)
|
||||
|
||||
|
||||
# FIXME: 拿出来是方便测试了,但全局单例会导致所有被装饰的函数共享状态,有待改进
|
||||
_retry_fsm = RetryFSM(RETRY_GRAPH, RetryAddon["Bilibili"]())
|
||||
|
@ -58,6 +58,92 @@ def without_dynamic(app: App):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_on_exception(app: App):
|
||||
from strenum import StrEnum
|
||||
|
||||
from nonebot_bison.platform.bilibili.fsm import FSM, StateGraph, Transition, ActionReturn, reset_on_exception
|
||||
|
||||
class State(StrEnum):
|
||||
A = "A"
|
||||
B = "B"
|
||||
C = "C"
|
||||
|
||||
class Event(StrEnum):
|
||||
A = "A"
|
||||
B = "B"
|
||||
C = "C"
|
||||
|
||||
class Addon:
|
||||
pass
|
||||
|
||||
async def raction(from_: State, event: Event, to: State, addon: Addon) -> ActionReturn:
|
||||
logger.info(f"action: {from_} -> {to}")
|
||||
raise RuntimeError("test")
|
||||
|
||||
async def action(from_: State, event: Event, to: State, addon: Addon) -> ActionReturn:
|
||||
logger.info(f"action: {from_} -> {to}")
|
||||
|
||||
graph: StateGraph[State, Event, Addon] = {
|
||||
"transitions": {
|
||||
State.A: {
|
||||
Event.A: Transition(raction, State.B),
|
||||
Event.B: Transition(action, State.C),
|
||||
},
|
||||
State.B: {
|
||||
Event.B: Transition(action, State.C),
|
||||
},
|
||||
State.C: {
|
||||
Event.C: Transition(action, State.A),
|
||||
},
|
||||
},
|
||||
"initial": State.A,
|
||||
}
|
||||
|
||||
addon = Addon()
|
||||
|
||||
class AFSM(FSM[State, Event, Addon]):
|
||||
@reset_on_exception(auto_start=True)
|
||||
async def emit(self, event: Event):
|
||||
return await super().emit(event)
|
||||
|
||||
fsm = AFSM(graph, addon)
|
||||
|
||||
await fsm.start()
|
||||
with pytest.raises(RuntimeError):
|
||||
await fsm.emit(Event.A)
|
||||
|
||||
assert fsm.started is True
|
||||
await fsm.emit(Event.B)
|
||||
await fsm.emit(Event.C)
|
||||
|
||||
class BFSM(FSM[State, Event, Addon]):
|
||||
@reset_on_exception
|
||||
async def emit(self, event: Event):
|
||||
return await super().emit(event)
|
||||
|
||||
fsm = BFSM(graph, addon)
|
||||
await fsm.start()
|
||||
with pytest.raises(RuntimeError):
|
||||
await fsm.emit(Event.A)
|
||||
|
||||
assert fsm.started is False
|
||||
with pytest.raises(TypeError, match="can't send non-None value to a just-started async generator"):
|
||||
await fsm.emit(Event.B)
|
||||
|
||||
class CFSM(FSM[State, Event, Addon]): ...
|
||||
|
||||
fsm = CFSM(graph, addon)
|
||||
await fsm.start()
|
||||
with pytest.raises(RuntimeError):
|
||||
await fsm.emit(Event.A)
|
||||
|
||||
assert fsm.started is True
|
||||
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await fsm.emit(Event.B)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_for_352(app: App, mocker: MockerFixture):
|
||||
from nonebot_bison.post import Post
|
||||
|
Loading…
x
Reference in New Issue
Block a user