from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Optional, TYPE_CHECKING, Union

import torch

from ._dim_entry import _match_levels, DimEntry
from ._tensor_info import TensorInfo


if TYPE_CHECKING:
    from . import Dim


def _safe_index(lst: list, item: Any) -> Optional[int]:
    """
    Helper function to find index of item in list.

    For DimEntry objects, uses __eq__ comparison which properly handles
    both positional and Dim entries.

    Returns the index if found, None if not found.
    """
    for i, list_item in enumerate(lst):
        # Use == for DimEntry objects as they have proper __eq__ implementation
        if isinstance(item, DimEntry) and isinstance(list_item, DimEntry):
            if list_item == item:
                return i
        elif list_item is item:
            return i
    return None


@dataclass
class IndexingInfo:
    can_call_original: bool = False
    advanced_indexing: bool = False
    self_tensor: Optional[torch.Tensor] = None
    flat_inputs: list[Any] = field(default_factory=list)
    result_levels: list[DimEntry] = field(default_factory=list)
    has_device: bool = False


def has_dims(obj: Any) -> bool:
    """
    Check if an object has first-class dimensions.

    This function checks if the object is either a Dim or a functorch Tensor
    that has first-class dimensions, using the proper check_exact methods.
    """
    from . import Dim, Tensor

    return Dim.check_exact(obj) or Tensor.check_exact(obj)


def _bind_dims_to_size(sz: int, sd: int, dims: list, nsz: list, nsd: list) -> None:
    """
    Bind dimensions to size and calculate proper strides for dim packs.
    """
    from . import DimensionBindError

    rhs_prod = 1
    for i, dim in enumerate(dims):
        if not dim.is_bound:
            # Check for multiple unbound dimensions
            for j in range(i + 1, len(dims)):
                if not dims[j].is_bound:
                    raise DimensionBindError(
                        f"cannot infer the sizes of two dimensions at once {dim!r} and {dims[j]!r}"
                    )
                rhs_prod *= dims[j].size

            # Calculate the size for this unbound dimension
            if sz % rhs_prod != 0:
                tup = tuple(dim.size if dim.is_bound else "?" for dim in dims)
                raise DimensionBindError(
                    f"inferred dimension does not evenly fit into larger dimension: {sz} vs {tup}"
                )

            inferred_size = sz // rhs_prod
            dim.size = inferred_size
            rhs_prod = sz
            break
        else:
            rhs_prod *= dim.size

    # Final validation that dimensions match
    if rhs_prod != sz:
        tup = tuple(dims)
        raise DimensionBindError(
            f"Dimension sizes to do not match ({sz} != {rhs_prod}) when matching dimension pack {tup}"
        )

    # Calculate new sizes and strides for each dimension in the pack
    # First calculate all strides by iterating in reverse
    new_strides = [0] * len(dims)
    current_stride = sd
    for i in reversed(range(len(dims))):
        new_strides[i] = current_stride
        current_stride *= dims[i].size

    # Then append sizes and strides in forward order
    for i in range(len(dims)):
        nsz.append(dims[i].size)
        nsd.append(new_strides[i])


def slice_to_tuple(flat_inputs: list) -> tuple:
    return tuple(flat_inputs)


def extractIndices(index: Any, indices: list) -> bool:
    if isinstance(index, tuple):  # mpy::tuple_view::check
        indices.extend(index)
        return True
    elif isinstance(index, torch.Tensor):  # THPVariable_Check
        indices.append(index)
        return False
    elif not hasattr(index, "__iter__") or isinstance(
        index, (str, bytes)
    ):  # !mpy::is_sequence
        indices.append(index)
        return False

    # Handle sequence case (list)
    if isinstance(index, list):
        if len(index) >= 32:
            indices.extend(index)
            return True

        # Check each item in the sequence
        for item in index:
            if (
                isinstance(item, (torch.Tensor, slice))
                or hasattr(item, "__iter__")
                or item is ...
                or item is None
                or has_dims(item)
            ):
                indices.extend(index)
                return True

        # If we got here, treat as single index
        indices.append(index)
        return False

    # Default case
    indices.append(index)
    return False


def getitem(cls: Any, func: Any, types: Any, args: Any, kwargs: Any) -> Any:
    self = args[0]
    index = args[1]

    iinfo = getsetitem(self, index, has_dims(self))
    if iinfo.can_call_original:
        # Call original tensor __getitem__ directly, bypassing __torch_function__
        return torch.Tensor.__getitem__(self, index)

    return invoke_getitem(iinfo)


def setitem(self: Any, index: Any, rhs: Any) -> None:
    """Set values in tensor using first-class dimensions."""
    from . import DimensionBindError, TensorInfo

    iinfo = getsetitem(self, index, has_dims(self) or has_dims(rhs))

    if iinfo.can_call_original:
        # Call original tensor __setitem__ directly, bypassing __torch_function__
        torch._C.TensorBase.__setitem__(self, index, rhs)
        return

    # Handle RHS tensor with dimensions
    rhs_info = TensorInfo.create(rhs, False, False)

    if rhs_info:
        # Check that rhs dimensions are compatible with result dimensions
        for l in rhs_info.levels:
            if not l.is_positional():
                # Find this dimension in result levels
                found = False
                for result_level in iinfo.result_levels:
                    if (
                        not result_level.is_positional()
                        and result_level.dim() is l.dim()
                    ):
                        found = True
                        break

                if not found:
                    # Create tuple representation of result levels for error message
                    result_dims: list[Union[int, Dim]] = []
                    for rl in iinfo.result_levels:
                        if rl.is_positional():
                            result_dims.append(rl.position())
                        else:
                            result_dims.append(rl.dim())

                    raise DimensionBindError(
                        f"rhs of setitem contains dimension {l.dim()!r} which is not in the dimension on the left "
                        f"({tuple(result_dims)!r})"
                    )

        # Match RHS tensor to result levels
        assert rhs_info.tensor is not None, "Cannot match levels on None tensor"
        matched_rhs = _match_levels(
            rhs_info.tensor, rhs_info.levels, iinfo.result_levels
        )
    else:
        matched_rhs = rhs

    # For advanced indexing with dimensions, we need special handling
    if iinfo.advanced_indexing:
        # Use advanced indexing - the flat_inputs already contain matched tensors
        tup = slice_to_tuple(iinfo.flat_inputs)
        if iinfo.self_tensor is None:
            raise RuntimeError("Cannot setitem on None tensor")
        torch._C.TensorBase.__setitem__(iinfo.self_tensor, tup, matched_rhs)
    else:
        # Simple copy operation
        if iinfo.self_tensor is None:
            raise RuntimeError("Cannot copy to None tensor")
        iinfo.self_tensor.copy_(matched_rhs)


def invoke_getitem(iinfo: IndexingInfo) -> Any:
    if iinfo.advanced_indexing:
        self_tensor = iinfo.self_tensor
        tup = slice_to_tuple(iinfo.flat_inputs)
        if self_tensor is None:
            raise RuntimeError("Cannot getitem on None tensor")
        rtensor = self_tensor[tup]
    else:
        rtensor = iinfo.self_tensor  # type: ignore[assignment]
        if rtensor is None:
            raise RuntimeError("Cannot getitem on None tensor")
        # rtensor is now guaranteed to be not None

    # Create a Tensor with the proper dimensions using the class method
    from . import Tensor

    return Tensor.from_positional(rtensor, iinfo.result_levels, iinfo.has_device)


def getsetitem(self: Any, index: Any, tensors_have_dims: bool) -> IndexingInfo:
    from . import DimList  # Import DimList for type checking

    can_call_original_getitem = not tensors_have_dims

    input_list = []
    if has_dims(index):
        input_list.append(index)
    else:
        is_sequence = extractIndices(index, input_list)
        # nothing about first class dims here, fallback to getitem
        if can_call_original_getitem and not is_sequence:
            return IndexingInfo(can_call_original=True)

    # Calculate how many dimensions have been indexed in order to compute the
    # size of ... or expand a potentially unbound dimension list.
    dims_indexed = 0
    expanding_object = -1
    unbound_dim_list = None
    dimlists = []  # Track DimList positions for later processing

    def check_expanding(i: int) -> None:
        nonlocal expanding_object
        if expanding_object != -1:
            from . import DimensionBindError

            raise DimensionBindError(
                f"at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets "
                f"{expanding_object} and {i}"
            )
        expanding_object = i

    def is_dimpack(s: Any) -> bool:
        from . import Dim

        return (
            isinstance(s, (tuple, list))
            and len(s) > 0
            and all(Dim.check_exact(item) for item in s)
        )

    has_dimpacks_or_none = False
    for i, s in enumerate(input_list):
        if has_dims(s):
            can_call_original_getitem = False
            dims_indexed += 1
        elif s is ...:
            check_expanding(i)
        elif isinstance(s, DimList):
            can_call_original_getitem = False
            if not s.is_bound:
                check_expanding(i)
                unbound_dim_list = s
            else:
                dims_indexed += len(s._dims)
            dimlists.append(i)
        elif s is None:
            has_dimpacks_or_none = True
        elif is_dimpack(s):
            can_call_original_getitem = False
            has_dimpacks_or_none = True
            dims_indexed += 1
        else:
            dims_indexed += 1

    # Early return if we can use original getitem
    if can_call_original_getitem:
        return IndexingInfo(can_call_original=True)

    self_info = TensorInfo.create(self, False, True)
    total_dims = len(self_info.levels)  # Total dimensions (positional + named)
    if dims_indexed > total_dims:
        raise ValueError(
            f"at least {dims_indexed} indices were supplied but the tensor only has {total_dims} dimensions"
        )

    # Expand any unbound dimension list, or expand ... into individual : slices.
    expanding_dims = total_dims - dims_indexed
    if expanding_object != -1:
        if unbound_dim_list is not None:
            # Bind unbound dimension list to the expanding dimensions
            unbound_dim_list.bind_len(expanding_dims)
        else:
            # Expand ... into slice(None) objects
            no_slices = [slice(None)] * expanding_dims
            input_list = (
                input_list[:expanding_object]
                + no_slices
                + input_list[expanding_object + 1 :]
            )

    # Flatten out any dimensions stored in dimlist elements directly into the inputs
    # Process in reverse order to maintain indices
    for i in range(len(dimlists) - 1, -1, -1):
        idx = dimlists[i]

        # We added more elements to input because of ...
        # so we need to also adjust the index to get back to where the
        # dimlist existed
        if (
            unbound_dim_list is None
            and expanding_object != -1
            and idx > expanding_object
        ):
            idx += expanding_dims

        dl = input_list[idx]

        # PRIVATE here naughty
        input_list = input_list[:idx] + dl._dims + input_list[idx + 1 :]

    return getsetitem_flat(self_info, input_list, [], [], has_dimpacks_or_none)


def getsetitem_flat(
    self_info: TensorInfo,
    input_list: list,
    keys: list[DimEntry],
    values: list,
    has_dimpacks_or_none: bool,
) -> IndexingInfo:
    from . import Dim

    # Track dimension usage
    seen_dims: list[Any] = []
    seen_dims_nuses: list[int] = []

    def add_dim(dim: Any) -> None:
        # Use safe indexing to avoid triggering __torch_function__ on Dim objects
        idx = _safe_index(seen_dims, dim)
        if idx is not None:
            seen_dims_nuses[idx] += 1
        else:
            seen_dims.append(dim)
            seen_dims_nuses.append(1)

    flat_inputs = []
    tensor_inputs: list[Any] = []
    device_holding_tensor = None

    def append_flat_handle(handle: Any) -> None:
        flat_inputs.append(handle)
        tensor_inputs.append(None)

    def append_tensor_input(ti: TensorInfo) -> None:
        flat_inputs.append(None)
        tensor_inputs.append(ti)
        nonlocal device_holding_tensor
        if ti.has_device and device_holding_tensor is None:
            device_holding_tensor = ti.tensor

    nsz = []
    nsd = []
    if self_info.tensor is None:
        raise RuntimeError("Cannot get size/stride on None tensor")
    sz = self_info.tensor.size()
    sd = self_info.tensor.stride()

    def append_size(i: int) -> None:
        if has_dimpacks_or_none:
            nsz.append(sz[i])
            nsd.append(sd[i])

    input_it = input_list[:]

    def parse_nones() -> None:
        nonlocal input_it
        while input_it and input_it[0] is None:
            append_flat_handle(slice(None))
            nsz.append(1)
            nsd.append(0)
            input_it = input_it[1:]

    def append_item(i: int, arg: Any) -> None:
        if Dim.check_exact(arg):
            d = arg
            if d._size == -1:
                d.size = sz[i]
            add_dim(d)
            append_size(i)
            append_flat_handle(arg)
            return

        info = TensorInfo.create(arg, False, False)
        if info:
            append_size(i)
            append_tensor_input(info)
            for level in info.levels:
                if not level.is_positional():
                    add_dim(level.dim())
            return

        if has_dimpacks_or_none:
            if isinstance(arg, (tuple, list)) and all(Dim.check_exact(d) for d in arg):
                # dim pack
                dim_pack = list(arg)
                for d in dim_pack:
                    add_dim(d)
                    append_flat_handle(d)
                _bind_dims_to_size(sz[i], sd[i], dim_pack, nsz, nsd)
                return

        append_size(i)
        append_flat_handle(arg)

    # Match indexing expressions with tensor dimensions
    for i, level in enumerate(self_info.levels):
        # Use safe indexing to avoid triggering __torch_function__ on DimEntry comparisons
        idx = _safe_index(keys, level)
        if idx is not None:
            append_item(i, values[idx])
        else:
            if level.is_positional():
                parse_nones()
                if not input_it:
                    append_flat_handle(slice(None))
                    append_size(i)
                else:
                    arg = input_it[0]
                    input_it = input_it[1:]
                    append_item(i, arg)
            else:
                add_dim(level.dim())
                append_flat_handle(level.dim())
                append_size(i)

    parse_nones()

    # Restride tensor if needed
    if has_dimpacks_or_none and nsz:
        if self_info.tensor is None:
            raise RuntimeError("Cannot restride None tensor")
        self_tensor = self_info.tensor.as_strided(
            nsz, nsd, self_info.tensor.storage_offset()
        )
    else:
        self_tensor = self_info.tensor

    # Determine result shape and indexing requirements
    result_levels: list[Any] = []
    index_levels = []
    tensor_insert_point = -1
    requires_getindex = False

    def mark_tensor_index() -> None:
        nonlocal tensor_insert_point
        if tensor_insert_point == -1:
            tensor_insert_point = len(result_levels)
        elif tensor_insert_point != len(result_levels):
            tensor_insert_point = 0

    for i, inp in enumerate(flat_inputs):
        if tensor_inputs[i] is not None:
            requires_getindex = True
            mark_tensor_index()
            for level in tensor_inputs[i].levels:
                if level not in index_levels:
                    index_levels.append(level)
        elif Dim.check_exact(inp):
            d = inp
            # Use safe indexing to avoid triggering __torch_function__
            dim_idx = _safe_index(seen_dims, d)
            assert dim_idx is not None, f"Dim {d} not found in seen_dims"
            if seen_dims_nuses[dim_idx] == 1:
                flat_inputs[i] = slice(None)
                result_levels.append(DimEntry(d))
            else:
                requires_getindex = True
                flat_inputs[i] = None
                tensor_inputs[i] = TensorInfo(
                    d._get_range(), [DimEntry(d)], False, None
                )
                if DimEntry(d) not in index_levels:
                    index_levels.append(DimEntry(d))
                mark_tensor_index()
        else:
            if inp != slice(None):
                requires_getindex = True
            if not isinstance(inp, int):
                result_levels.append(DimEntry(-1))

    # Insert indexing dimensions at first tensor use point
    if tensor_insert_point != -1:
        for level in reversed(index_levels):
            result_levels.insert(tensor_insert_point, level)

    # Match tensors to indexing shape
    if requires_getindex:
        for i in range(len(flat_inputs)):
            if tensor_inputs[i] is not None:
                t = tensor_inputs[i].tensor
                assert t is not None, "TensorInfo should have valid tensor data"
                if (
                    not tensor_inputs[i].has_device
                    and device_holding_tensor is not None
                ):
                    t = t.to(device_holding_tensor.device)
                flat_inputs[i] = _match_levels(t, tensor_inputs[i].levels, index_levels)

    # Number positional dimensions correctly
    seen_positionals = 0
    for i in reversed(range(len(result_levels))):
        if result_levels[i].is_positional():
            seen_positionals += 1
            result_levels[i] = DimEntry(-seen_positionals)

    return IndexingInfo(
        can_call_original=False,
        advanced_indexing=requires_getindex,
        self_tensor=self_tensor,
        flat_inputs=flat_inputs,
        result_levels=result_levels,
        has_device=self_info.has_device,
    )
