from __future__ import annotations

from typing import Any, TYPE_CHECKING

import torch

from ._dim_entry import DimEntry


if TYPE_CHECKING:
    from . import Dim, Tensor


class EnableAllLayers:
    """
    RAII-style context manager for enabling functorch vmap layers: it pushes
    one functorch dynamic layer per first-class dimension on entry and tears
    them all down again, in reverse order, on exit.

    This is probably one of the more algorithmically important parts of first
    class dims. Intuitively, FCD can be thought of as another way of using
    vmap where you don't actually have to vmap at the top level; instead, the
    vmaps are implicitly determined by inspecting the bound dimensions on the
    FCD tensors involved in a computation (this is similar to our concept of
    non-lexical modes that we spent a long time talking about years ago). But
    under the hood you still need to actually enable the vmap mode, so once
    FCD has determined all of the dims we are batching over, it needs to
    enable all of those layers so functorch can actually apply the batching
    rules. Therefore: enable all layers!
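
    A rough sketch of the intended usage (``levels``, ``input_levels``,
    ``batched_input``, ``op``, and ``has_device`` are placeholders for values
    produced by the surrounding FCD machinery)::

        with EnableAllLayers(levels) as guard:
            # stamp the freshly created levels onto each batched input
            guard.inplace_update_layers(batched_input, input_levels)
            batched_out = op(batched_input)
            # rewrap the batched result as an FCD Tensor
            result = guard.from_batched(batched_out, has_device)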
    """

    levels_start: int
    levels_to_dim: list[Dim]

    def __init__(self, levels: list[DimEntry]):
        """
        Initialize and push dynamic layers for all first-class dimensions.

        Args:
            levels: List of dimension entries to create layers for
        """

        from . import Dim

        self.levels_start = 0
        self.levels_to_dim = []

        for l in levels:
            if not l.is_positional():
                d = l.dim()
                assert isinstance(d, Dim)
                self.levels_to_dim.append(d)

        # Sort by level for stable ordering
        self.levels_to_dim.sort(key=lambda d: d._level)

    def __enter__(self) -> EnableAllLayers:  # noqa: PYI034
        """Push the functorch dynamic layers and record where they start."""
        # Create one functorch dynamic layer per first-class dim, in level order
        for i, dim in enumerate(self.levels_to_dim):
            batch_size = dim.size
            level = torch._C._functorch._vmap_increment_nesting(batch_size, "different")
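            # Dynamic layers get consecutive levels as they are pushed, so
            # remembering where the block starts is enough to recover the
            # level of every dim (see the checks in __exit__ and from_batched)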
            if i == 0:
                self.levels_start = level
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Clean up dynamic layers in reverse order."""
        to_remove = self.levels_start + len(self.levels_to_dim) - 1
        for i in range(len(self.levels_to_dim)):
            popped = torch._C._functorch._vmap_decrement_nesting()
            assert popped == to_remove - i, (
                f"Expected layer {to_remove - i}, got {popped}"
            )

    def from_batched(self, batchedtensor: torch.Tensor, has_device: bool) -> Tensor:
        """
        Create a Tensor from a batched tensor by unwrapping functorch layers.

        Args:
            batchedtensor: Batched tensor from functorch operation
            has_device: Whether tensor has device info

        Returns:
            Tensor with appropriate levels
        """
        # Start with positional levels for the visible dimensions of the
        # batched tensor; batch dims are spliced in below as we unwrap
        levels: list[DimEntry] = []
        for i in range(-batchedtensor.dim(), 0):
            levels.append(DimEntry(i))

        tensor = batchedtensor

        while torch._C._functorch.is_batchedtensor(tensor):
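            # Each BatchedTensor wrapper corresponds to one of the layers we
            # pushed; figure out which first-class Dim owns it, record that
            # Dim at the wrapper's batch-dim position, then unwrap one level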
            level = torch._C._functorch.maybe_get_level(tensor)
            assert level is not None
            assert level >= self.levels_start and level < self.levels_start + len(
                self.levels_to_dim
            )
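            # Levels were handed out consecutively starting at levels_start,
            # so this offset maps the functorch level back to its Dim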
            dim = DimEntry(self.levels_to_dim[level - self.levels_start])
            bdim = torch._C._functorch.maybe_get_bdim(tensor)
            assert bdim is not None
            levels.insert(bdim, dim)
            tensor = torch._C._functorch.get_unwrapped(tensor)

        from . import Tensor

        result = Tensor()
        result._tensor = tensor
        result._batchtensor = batchedtensor
        result._has_device = has_device
        result._levels = levels
        return result

    def inplace_update_layers(
        self, batchtensor: torch.Tensor, levels: list[DimEntry]
    ) -> None:
        """
        Update the levels of a batched tensor in place.

        This requires the _maybe_unsafe_set_level binding that we'll add to functorch.

        Args:
            batchtensor: Batched tensor to update
            levels: New levels to set
        """
        # Check if tensor is batched
        if not torch._C._functorch.is_batchedtensor(batchtensor):
            return

        impl = batchtensor

        for i in reversed(range(len(self.levels_to_dim))):
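            # levels_to_dim is sorted by level, so iterating in reverse walks
            # from the outermost batching layer inwards, unwrapping one
            # BatchedTensor per dim that actually appears in `levels`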
            if impl is None:
                break

            if any(l == DimEntry(self.levels_to_dim[i]) for l in levels):
                # This is very interesting!  The level stored on the batched
                # tensor is meaningless!  We set it RIGHT before we go into vmap
                torch._C._functorch._maybe_unsafe_set_level(impl, self.levels_start + i)
                impl = torch._C._functorch.get_unwrapped(impl)
