# mypy: allow-untyped-defs
"""Registry for flash attention implementations.

This module contains the registration system for flash attention implementations.
It has no torch dependencies to avoid circular imports during initialization.
"""

from collections.abc import Callable
from typing import Literal, Protocol


class FlashAttentionHandle(Protocol):
    """Structural type for the object a registration callable may return.

    Any object that exposes a zero-argument ``remove()`` method satisfies
    this protocol.
    """

    def remove(self) -> None:
        """Provider-defined hook; nothing in this module calls it."""


# Shape of a provider's activation hook: accepts any arguments and returns
# either a handle or nothing.
_RegisterFn = Callable[..., FlashAttentionHandle | None]
# Implementation names spelled out for type checkers; the public functions in
# this module also accept arbitrary ``str`` identifiers.
_FlashAttentionImpl = Literal["FA4"]

# Maps impl name -> activation hook; filled by register_flash_attention_impl().
_FLASH_ATTENTION_IMPLS: dict[str, _RegisterFn] = {}

# Name of the impl most recently activated, or None before any activation.
_FLASH_ATTENTION_ACTIVE: str | None = None
# Handles returned by activation hooks, keyed by impl name, so that they stay
# referenced by the registry.
_FLASH_ATTENTION_HANDLES: dict[str, FlashAttentionHandle] = {}


def register_flash_attention_impl(
    impl: str | _FlashAttentionImpl,
    *,
    register_fn: _RegisterFn,
) -> None:
    """
    Register the callable that activates a flash attention impl.

    Registering the same ``impl`` again replaces the previously stored
    callable.

    .. note::
        This function is intended for SDPA backend providers to register their
        implementations. End users should use :func:`activate_flash_attention_impl`
        to activate a registered implementation.

    Args:
        impl: Implementation identifier (e.g., ``"FA4"``).
        register_fn: Callable that performs the actual dispatcher registration.
            This function will be invoked by :func:`activate_flash_attention_impl`
            and should register custom kernels with the PyTorch dispatcher.
            It may optionally return a handle implementing
            :class:`FlashAttentionHandle` to keep any necessary state alive.

    Raises:
        TypeError: If ``register_fn`` is not callable.

    Example:
        >>> def my_impl_register(module_path: str = "my_flash_impl"):
        ...     # Register custom kernels with torch dispatcher
        ...     pass  # doctest: +SKIP
        >>> register_flash_attention_impl(
        ...     "MyImpl", register_fn=my_impl_register
        ... )  # doctest: +SKIP
    """
    # Reject bad hooks up front. A stored ``None`` would otherwise make
    # activation report the impl as unknown while also listing it as
    # available, and any other non-callable would only fail when activated.
    if not callable(register_fn):
        raise TypeError(
            f"register_fn for flash attention impl '{impl}' must be callable, "
            f"got {type(register_fn).__name__}"
        )
    _FLASH_ATTENTION_IMPLS[impl] = register_fn


def activate_flash_attention_impl(
    impl: str | _FlashAttentionImpl,
) -> None:
    """
    Install a previously registered flash attention impl into the dispatcher.

    The callable stored for ``impl`` is invoked with no arguments. Requesting
    the impl that is already recorded as active is a no-op.

    .. note::
        Backend providers should NOT automatically activate their implementation
        on import. Users should explicitly opt-in by calling this function or via
        environment variables to ensure multiple provider libraries can coexist.

    .. note::
        If the callable registered for ``impl`` returns a
        :class:`FlashAttentionHandle`, the registry stores that handle so it
        stays alive for the lifetime of the process (until explicit uninstall
        support exists).

    Args:
        impl: Implementation identifier to activate. See
            :func:`~torch.nn.attention.list_flash_attention_impls` for available
            implementations.

    Raises:
        ValueError: If no callable is registered under ``impl``.

    Example:
        >>> activate_flash_attention_impl("FA4")  # doctest: +SKIP
    """
    global _FLASH_ATTENTION_ACTIVE

    activator = _FLASH_ATTENTION_IMPLS.get(impl)
    if activator is None:
        known = list_flash_attention_impls()
        raise ValueError(
            f"Unknown flash attention impl '{impl}'. "
            f"Available implementations: {known}"
        )

    # TODO: The only way to actually register a new impl is to unregister the
    # current impl, reinstall the default impl, and then register the new impl.
    if _FLASH_ATTENTION_ACTIVE != impl:
        returned_handle = activator()
        if returned_handle is not None:
            # Hold a reference so the provider's state is not dropped.
            _FLASH_ATTENTION_HANDLES[impl] = returned_handle
        # Only recorded once the hook has returned without raising.
        _FLASH_ATTENTION_ACTIVE = impl


def list_flash_attention_impls() -> list[str]:
    """List every registered flash attention implementation name, sorted."""
    # Iterating a dict yields its keys, so no explicit ``.keys()`` is needed.
    return sorted(_FLASH_ATTENTION_IMPLS)


def current_flash_attention_impl() -> str | None:
    """
    Return the currently activated flash attention impl name, if any.

    The value is the identifier passed to the most recent
    :func:`activate_flash_attention_impl` call that completed without raising.
    ``None`` indicates that no custom impl has been activated.
    """
    return _FLASH_ATTENTION_ACTIVE
