from __future__ import annotations

from typing import Any, TYPE_CHECKING, Union


if TYPE_CHECKING:
    from collections.abc import Sequence

import torch  # noqa: TC002

from ._dim_entry import _match_levels, DimEntry, ndim_of_levels


def _wrap_dim(arg: Any, orig_ndim: int, allow_none: bool = True) -> DimEntry:
    """
    Normalize a user-supplied dimension specifier into a DimEntry.

    Args:
        arg: A Dim, an int index, None, or any other object.
        orig_ndim: Number of positional dimensions; subtracted from
            non-negative int indices.
        allow_none: Selects the explicit ``None`` branch. The fallback at the
            end of the function returns the same empty entry, so a ``None``
            arg yields ``DimEntry()`` whatever this flag is.

    Returns:
        ``DimEntry(arg)`` for a Dim. For an int, a DimEntry holding the index:
        negative ints are used unchanged, non-negative ints have ``orig_ndim``
        subtracted (so an in-range index becomes its negative equivalent).
        An empty ``DimEntry()`` for everything else.
    """
    from . import Dim

    if arg is None and allow_none:
        return DimEntry()
    if isinstance(arg, Dim):
        return DimEntry(arg)
    if not isinstance(arg, int):
        # Not a recognised specifier: callers detect this via is_none().
        return DimEntry()
    # Express the index relative to the end of the positional dims.
    return DimEntry(arg if arg < 0 else arg - orig_ndim)


def order(
    tensor_or_dim: Union[torch.Tensor, Any], *dims: Union[Any, Sequence[Any]]
) -> torch.Tensor:
    """
    Move selected dimensions of a first-class tensor into positional slots.

    Every dimension named in ``dims`` is removed from the input's levels and
    re-emitted, in the order given, as a positional dimension. The new
    positional dimensions are placed immediately before the first positional
    dimension that was not selected, or after all remaining levels when there
    is none. An argument that is an iterable of dimensions (other than a
    DimList) is flattened into a single positional dimension via ``reshape``.

    Args:
        tensor_or_dim: A first-class ``Tensor``, or a ``Dim`` (which is treated
            as a one-level tensor backed by ``Dim._get_range()``).
        *dims: Each item is a Dim, an int positional index, a DimList (each
            member becomes its own positional dim), or an iterable of
            Dim/int to be flattened together.

    Returns:
        The result of ``Tensor.from_positional`` on the rearranged data.

    Raises:
        ValueError: If the first argument is neither a Tensor nor a Dim, if a
            requested dimension is absent or named twice, if an argument is
            not a Dim, int, DimList or iterable, or if an element of an
            iterable is not a Dim or int.
        AssertionError: If the input carries no underlying data.

    Examples:
        >>> import torch
        >>> from functorch.dim import dims
        >>> batch, channel, height, width = dims(4)
        >>> x = torch.randn(2, 3, 4, 5)[batch, channel, height, width]
        >>> # Reorder to [height, width, batch, channel]
        >>> y = order(x, height, width, batch, channel)
    """
    from . import Dim, DimList, Tensor

    # Resolve the source levels, the backing data and the device flag.
    if isinstance(tensor_or_dim, Tensor):
        source_levels = tensor_or_dim._levels[:]
        data = tensor_or_dim._tensor
        has_device = tensor_or_dim._has_device
    elif isinstance(tensor_or_dim, Dim):
        source_levels = [DimEntry(tensor_or_dim)]
        data = tensor_or_dim._get_range()
        has_device = False
    else:
        raise ValueError("First argument must be a Tensor or Dim object")

    # Working copy: a slot is blanked out once its dim has been selected.
    remaining = source_levels[:]
    orig_ndim = ndim_of_levels(remaining)

    picked: list[DimEntry] = []  # selected dims, in output order
    groups: list[tuple[int, int]] = []  # (offset into picked, member count)

    def take(d: DimEntry) -> None:
        """Select ``d``: blank its slot in ``remaining`` and queue it."""
        try:
            slot = remaining.index(d)
        except ValueError:
            slot = -1
        # Raised outside the handler so no exception context is attached.
        if slot < 0:
            if d.is_positional():
                raise ValueError(
                    f"tensor has {orig_ndim} positional dimensions, but {d.position() + orig_ndim} specified, "
                    f"or it was specified twice"
                )
            raise ValueError(
                f"tensor does not contain dim {d.dim()} or it was specified twice"
            )
        remaining[slot] = DimEntry()
        picked.append(d)

    # Number of positional dims this call produces (a group counts once).
    n_outputs = 0
    for spec in dims:
        wrapped = _wrap_dim(spec, orig_ndim, False)
        if not wrapped.is_none():
            # A single Dim or int index.
            take(wrapped)
            n_outputs += 1
            continue

        if isinstance(spec, DimList):
            # Each member of a DimList becomes its own output dim.
            for member in spec._dims:
                take(DimEntry(member))
                n_outputs += 1
            continue

        # Anything else must be an iterable whose members are merged.
        if not hasattr(spec, "__iter__"):
            raise ValueError("expected a Dim, List[Dim], or Sequence[Dim]")
        n_outputs += 1
        members = list(spec)
        groups.append((len(picked), len(members)))
        for member in members:
            wrapped = _wrap_dim(member, orig_ndim, False)
            if wrapped.is_none():
                raise ValueError("expected a Dim or int")
            take(wrapped)

    # Lay out the new levels: unselected levels keep their relative order and
    # the selected ones go in front of the first surviving positional level.
    anchor = -1
    new_levels: list[DimEntry] = []
    for level in remaining:
        if level.is_none():
            continue
        if level.is_positional() and anchor == -1:
            anchor = len(new_levels)
            new_levels.extend(picked)
        new_levels.append(level)
    if anchor == -1:
        # No positional level survived, so the selected dims go last.
        anchor = len(new_levels)
        new_levels.extend(picked)

    assert data is not None, "Cannot reorder None tensor"
    ndata = _match_levels(data, source_levels, new_levels)

    if groups:
        sizes = ndata.size()
        n_picked = len(picked)

        # Sizes of the levels that precede the selected dims.
        view_shape = [sizes[k] for k in range(anchor)]

        # Walk the selected dims; `cursor` indexes into `picked`.
        # Group offsets are non-decreasing and never behind the cursor,
        # because each one was recorded as len(picked) at the time.
        cursor = 0
        for group_start, group_len in groups:
            for k in range(cursor, group_start):
                view_shape.append(sizes[anchor + k])
            cursor = group_start

            merged = 1
            for k in range(cursor, cursor + group_len):
                merged *= sizes[anchor + k]
            view_shape.append(merged)
            cursor += group_len

        # Selected dims after the last group, then the trailing levels.
        for k in range(cursor, n_picked):
            view_shape.append(sizes[anchor + k])
        for k in range(anchor + n_picked, len(remaining)):
            view_shape.append(sizes[k])

        # Merging shrinks the dim count; drop the surplus level entries. The
        # survivors in the selected range are renumbered below.
        surplus = n_picked - n_outputs
        if surplus > 0:
            new_levels = new_levels[:anchor] + new_levels[anchor + surplus :]

        ndata = ndata.reshape(view_shape)

    # Number positional levels -1, -2, ... from the right. Every slot in the
    # selected range is converted to positional as well.
    next_position = 0
    for i in reversed(range(len(new_levels))):
        if new_levels[i].is_positional() or anchor <= i < anchor + n_outputs:
            next_position -= 1
            new_levels[i] = DimEntry(next_position)

    return Tensor.from_positional(ndata, new_levels, has_device)  # type: ignore[return-value]
