diff --git a/nonebot_bison/platform/bilibili/fsm.py b/nonebot_bison/platform/bilibili/fsm.py index 9f9466d..6252f53 100644 --- a/nonebot_bison/platform/bilibili/fsm.py +++ b/nonebot_bison/platform/bilibili/fsm.py @@ -2,10 +2,24 @@ import sys import asyncio import inspect from enum import Enum +from functools import wraps from dataclasses import dataclass from collections.abc import Set as AbstractSet from collections.abc import Callable, Sequence, Awaitable, AsyncGenerator -from typing import TYPE_CHECKING, Any, Generic, TypeVar, Protocol, TypeAlias, TypedDict, NamedTuple, runtime_checkable +from typing import ( + TYPE_CHECKING, + Any, + Generic, + TypeVar, + Protocol, + ParamSpec, + TypeAlias, + TypedDict, + NamedTuple, + Concatenate, + overload, + runtime_checkable, +) from nonebot import logger @@ -17,6 +31,7 @@ TAddon = TypeVar("TAddon", contravariant=True) TState = TypeVar("TState", contravariant=True) TEvent = TypeVar("TEvent", contravariant=True) TFSM = TypeVar("TFSM", bound="FSM", contravariant=True) +P = ParamSpec("P") class StateError(Exception): ... @@ -163,6 +178,52 @@ class FSM(Generic[TState, TEvent, TAddon]): self.started = False del self.machine + self.current_state = self.graph["initial"] self.machine = self._core() logger.trace("FSM closed") + + +@overload +def reset_on_exception( + func: Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]], +) -> Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]]: + """自动在发生异常后重置 FSM""" + + +@overload +def reset_on_exception( + auto_start: bool = False, +) -> Callable[ + [Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]]], Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]] +]: + """自动在异常后重置 FSM,当 auto_start 为 True 时,自动启动 FSM""" + + +# 参考自 dataclasses.dataclass 的实现 +def reset_on_exception(func=None, /, *, auto_start=False): # pyright: ignore[reportInconsistentOverload] + def warp(func: Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]]): + return __reset_clear_up(func, auto_start) + + # 判断调用的是 @reset_on_exception 还是 @reset_on_exception(...) + if func is None: + # 调用的是带括号的 + return warp + + # 调用的是不带括号的 + return warp(func) + + +def __reset_clear_up(func: Callable[Concatenate[TFSM, P], Awaitable[ActionReturn]], auto_start: bool): + @wraps(func) + async def wrapper(fsm_self: TFSM, *args: P.args, **kwargs: P.kwargs) -> ActionReturn: + try: + return await func(fsm_self, *args, **kwargs) + except Exception as e: + logger.error(f"Exception in {func.__name__}: {e}") + await fsm_self.reset() + if auto_start and not fsm_self.started: + await fsm_self.start() + raise e + + return wrapper diff --git a/nonebot_bison/platform/bilibili/retry.py b/nonebot_bison/platform/bilibili/retry.py index 94d0cb0..20cc9ed 100644 --- a/nonebot_bison/platform/bilibili/retry.py +++ b/nonebot_bison/platform/bilibili/retry.py @@ -14,7 +14,7 @@ from httpx import URL as HttpxURL from nonebot_bison.types import Target from .models import DynRawPost -from .fsm import FSM, Condition, StateGraph, Transition, ActionReturn +from .fsm import FSM, Condition, StateGraph, Transition, ActionReturn, reset_on_exception if TYPE_CHECKING: from .platforms import Bilibili @@ -218,6 +218,11 @@ class RetryFSM(FSM[RetryState, RetryEvent, RetryAddon[TBilibili]]): self.addon.reset_all() await super().reset() + @override + @reset_on_exception + async def emit(self, event: RetryEvent): + await super().emit(event) + # FIXME: 拿出来是方便测试了,但全局单例会导致所有被装饰的函数共享状态,有待改进 _retry_fsm = RetryFSM(RETRY_GRAPH, RetryAddon["Bilibili"]()) diff --git a/tests/platforms/test_bilibili.py b/tests/platforms/test_bilibili.py index ac54005..e31dfd8 100644 --- a/tests/platforms/test_bilibili.py +++ b/tests/platforms/test_bilibili.py @@ -58,6 +58,92 @@ def without_dynamic(app: App): ) +@pytest.mark.asyncio +async def test_reset_on_exception(app: App): + from strenum import StrEnum + + from nonebot_bison.platform.bilibili.fsm import FSM, StateGraph, Transition, ActionReturn, reset_on_exception + + class State(StrEnum): + A = "A" + B = "B" + C = "C" + + class Event(StrEnum): + A = "A" + B = "B" + C = "C" + + class Addon: + pass + + async def raction(from_: State, event: Event, to: State, addon: Addon) -> ActionReturn: + logger.info(f"action: {from_} -> {to}") + raise RuntimeError("test") + + async def action(from_: State, event: Event, to: State, addon: Addon) -> ActionReturn: + logger.info(f"action: {from_} -> {to}") + + graph: StateGraph[State, Event, Addon] = { + "transitions": { + State.A: { + Event.A: Transition(raction, State.B), + Event.B: Transition(action, State.C), + }, + State.B: { + Event.B: Transition(action, State.C), + }, + State.C: { + Event.C: Transition(action, State.A), + }, + }, + "initial": State.A, + } + + addon = Addon() + + class AFSM(FSM[State, Event, Addon]): + @reset_on_exception(auto_start=True) + async def emit(self, event: Event): + return await super().emit(event) + + fsm = AFSM(graph, addon) + + await fsm.start() + with pytest.raises(RuntimeError): + await fsm.emit(Event.A) + + assert fsm.started is True + await fsm.emit(Event.B) + await fsm.emit(Event.C) + + class BFSM(FSM[State, Event, Addon]): + @reset_on_exception + async def emit(self, event: Event): + return await super().emit(event) + + fsm = BFSM(graph, addon) + await fsm.start() + with pytest.raises(RuntimeError): + await fsm.emit(Event.A) + + assert fsm.started is False + with pytest.raises(TypeError, match="can't send non-None value to a just-started async generator"): + await fsm.emit(Event.B) + + class CFSM(FSM[State, Event, Addon]): ... + + fsm = CFSM(graph, addon) + await fsm.start() + with pytest.raises(RuntimeError): + await fsm.emit(Event.A) + + assert fsm.started is True + + with pytest.raises(StopAsyncIteration): + await fsm.emit(Event.B) + + @pytest.mark.asyncio async def test_retry_for_352(app: App, mocker: MockerFixture): from nonebot_bison.post import Post