import sys
import asyncio
import inspect
from enum import Enum
from dataclasses import dataclass
from collections.abc import Set as AbstractSet
from collections.abc import Callable, Sequence, Awaitable, AsyncGenerator
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Protocol, TypeAlias, TypedDict, NamedTuple, runtime_checkable

from nonebot import logger


class StrEnum(str, Enum): ...


TAddon = TypeVar("TAddon", contravariant=True)
TState = TypeVar("TState", contravariant=True)
TEvent = TypeVar("TEvent", contravariant=True)
TFSM = TypeVar("TFSM", bound="FSM", contravariant=True)


class StateError(Exception): ...


ActionReturn: TypeAlias = Any


@runtime_checkable
class SupportStateOnExit(Generic[TAddon], Protocol):
    async def on_exit(self, addon: TAddon) -> None: ...


@runtime_checkable
class SupportStateOnEnter(Generic[TAddon], Protocol):
    async def on_enter(self, addon: TAddon) -> None: ...


class Action(Generic[TState, TEvent, TAddon], Protocol):
    async def __call__(self, from_: TState, event: TEvent, to: TState, addon: TAddon) -> ActionReturn: ...


ConditionFunc = Callable[[TAddon], Awaitable[bool]]


@dataclass(frozen=True)
class Condition(Generic[TAddon]):
    call: ConditionFunc[TAddon]
    not_: bool = False

    def __repr__(self):
        if inspect.isfunction(self.call) or inspect.isclass(self.call):
            call_str = self.call.__name__
        else:
            call_str = repr(self.call)
        return f"Condition(call={call_str})"

    async def __call__(self, addon: TAddon) -> bool:
        return (await self.call(addon)) ^ self.not_


# FIXME: Python 3.11+ 才支持 NamedTuple和TypedDict使用多继承添加泛型
# 所以什么时候 drop 3.10(?
if sys.version_info >= (3, 11) or TYPE_CHECKING:

    class Transition(Generic[TState, TEvent, TAddon], NamedTuple):
        action: Action[TState, TEvent, TAddon]
        to: TState
        conditions: AbstractSet[Condition[TAddon]] | None = None

    class StateGraph(Generic[TState, TEvent, TAddon], TypedDict):
        transitions: dict[
            TState,
            dict[
                TEvent,
                Transition[TState, TEvent, TAddon] | Sequence[Transition[TState, TEvent, TAddon]],
            ],
        ]
        initial: TState

else:

    class Transition(NamedTuple):
        action: Action
        to: Any
        conditions: AbstractSet[Condition] | None = None

    class StateGraph(TypedDict):
        transitions: dict[Any, dict[Any, Transition]]
        initial: Any


class FSM(Generic[TState, TEvent, TAddon]):
    def __init__(self, graph: StateGraph[TState, TEvent, TAddon], addon: TAddon):
        self.started = False
        self.graph = graph
        self.current_state = graph["initial"]
        self.machine = self._core()
        self.addon = addon

    async def _core(self) -> AsyncGenerator[ActionReturn, TEvent]:
        self.current_state = self.graph["initial"]
        res = None
        while True:
            event = yield res

            if not self.started:
                raise StateError("FSM not started, please call start() first")

            selected_transition = await self.cherry_pick(event)

            logger.trace(f"exit state: {self.current_state}")
            if isinstance(self.current_state, SupportStateOnExit):
                logger.trace(f"do {self.current_state}.on_exit")
                await self.current_state.on_exit(self.addon)

            logger.trace(f"do action: {selected_transition.action}")
            res = await selected_transition.action(self.current_state, event, selected_transition.to, self.addon)

            logger.trace(f"enter state: {selected_transition.to}")
            self.current_state = selected_transition.to

            if isinstance(self.current_state, SupportStateOnEnter):
                logger.trace(f"do {self.current_state}.on_enter")
                await self.current_state.on_enter(self.addon)

    async def start(self):
        await anext(self.machine)
        self.started = True
        logger.trace(f"FSM started, initial state: {self.current_state}")

    async def cherry_pick(self, event: TEvent) -> Transition[TState, TEvent, TAddon]:
        transitions = self.graph["transitions"][self.current_state].get(event)
        if transitions is None:
            raise StateError(f"Invalid event {event} in state {self.current_state}")

        if isinstance(transitions, Transition):
            return transitions
        elif isinstance(transitions, Sequence):
            no_conds: list[Transition[TState, TEvent, TAddon]] = []
            for transition in transitions:
                if not transition.conditions:
                    no_conds.append(transition)
                    continue

                values = await asyncio.gather(*(condition(self.addon) for condition in transition.conditions))

                if all(values):
                    logger.trace(f"conditions {transition.conditions} passed")
                    return transition
            else:
                if no_conds:
                    return no_conds.pop()
                else:
                    raise StateError(f"Invalid event {event} in state {self.current_state}")
        else:
            raise TypeError("Invalid transition type: {transitions}, expected Transition or Sequence[Transition]")

    async def emit(self, event: TEvent):
        return await self.machine.asend(event)

    async def reset(self):
        await self.machine.aclose()
        self.started = False

        del self.machine
        self.machine = self._core()

        logger.trace("FSM closed")