import math
import os

import torch
from torch._inductor.utils import get_device_tflops, get_gpu_dram_gbps
from torch.utils._ordered_set import OrderedSet

from .flop_counter import flop_registry


aten = torch.ops.aten

_FLOAT_TYPES = OrderedSet(
    [
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    ]
)

# This value is hard-coded here:
# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117
_PYTORCH_MIN_ALLOCATE = (
    2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1
)
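# e.g. with the caching allocator enabled, even a 1-byte allocation occupies a
# full 512-byte block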

# No fallback kernel is needed for view ops (and none exists)
_VIEW_OPS = OrderedSet(
    [
        aten.lift_fresh,
        aten.t,
        aten.transpose,
        aten.view,
        aten.detach,
        aten._unsafe_view,
        aten.split,
        aten.adjoint,
        aten.as_strided,
        aten.diagonal,
        aten.expand,
        aten.expand_as,
        aten.movedim,
        aten.permute,
        aten.select,
        aten.squeeze,
        aten.mT,
        aten.mH,
        aten.real,
        aten.imag,
        aten.view_as,
        aten.unflatten,
        aten.unfold,
        aten.unbind,
        aten.unsqueeze,
        aten.vsplit,
        aten.hsplit,
        aten.split_with_sizes,
        aten.swapaxes,
        aten.swapdims,
        aten.chunk,
    ]
)
# Tensor-creation ops can be skipped when benchmarking
_CREATE_OPS = OrderedSet(
    [
        aten.randint,
        aten.randn,
        aten.rand,
        aten.randn_like,
        aten.rand_like,
        aten.randint_like,
        aten.arange,
        aten.ones_like,
        aten.zeros_like,
    ]
)

_IGNORE_OPS = _VIEW_OPS | _CREATE_OPS


def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float:  # type: ignore[no-untyped-def]
    """
    Estimates the compute time of an aten operator.

    Args:
        func_packet: The operator overload packet.
        args: The arguments to the operator.
        kwargs: The keyword arguments to the operator.
        out: The output of the operator.
        out_dtypes: The output data types.

    Returns:
        float: The estimated compute time in nanoseconds.
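
    Example::

        # A sketch: assumes a CUDA device (so ``get_device_tflops`` is
        # meaningful) and that ``aten.mm`` is in the flop registry; the
        # sizes and dtype are illustrative.
        >>> a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
        >>> b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
        >>> out = torch.mm(a, b)
        >>> t_ns = get_compute_time(aten.mm, (a, b), {}, out, OrderedSet([out.dtype]))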
    """
    if func_packet in flop_registry:
        assert len(out_dtypes) == 1, (
            f"Only support single out dtype got {out_dtypes} for {func_packet}"
        )
        dtype = out_dtypes.pop()
        # get_device_tflops actually returns peta-FLOPs/s, so multiply by 1e15
        # to get FLOPs/s
        peak_gpu_flops = get_device_tflops(dtype) * 1e15
        # Assume roughly 75% of the theoretical peak FLOPs is achievable
        factor = 0.75
        peak_empirical_flops = factor * peak_gpu_flops
        flop_count_func = flop_registry[func_packet]
        # The registry counts each multiply-accumulate (MAC) as 2 FLOPs, so
        # divide by 2 to get the MAC count
        flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2
        # Multiply by 1e9 to convert seconds to nanoseconds
        compute_time = (flop_count / peak_empirical_flops) * 1e9
        return compute_time
    return 0.0


def get_num_bytes(t: torch.Tensor) -> int:
    """
    Calculates the memory consumption of a tensor, rounded up to the
    allocator's minimum block size.

    Args:
        t (torch.Tensor): The input tensor.

    Returns:
        int: The memory consumption of the tensor in bytes.
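
    Example::

        # A sketch assuming the default 512-byte minimum block size: a
        # 400-byte storage is rounded up to one 512-byte block.
        >>> t = torch.empty(100, dtype=torch.float32)  # 100 * 4 = 400 bytes
        >>> get_num_bytes(t)
        512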
    """
    num_bytes = t.untyped_storage().nbytes()
    mem_consumed = math.ceil(num_bytes / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE
    return mem_consumed


def get_transfer_time(flat_args_kwargs, flat_outs) -> float:  # type: ignore[no-untyped-def]
    """
    Estimates the memory transfer time of input and output tensors.

    Args:
        flat_args_kwargs (List[Any]): The flattened arguments and keyword
            arguments; non-tensor entries are ignored.
        flat_outs (List[Any]): The flattened outputs; non-tensor entries are ignored.

    Returns:
        float: The estimated memory transfer time in nanoseconds.
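
    Example::

        # A sketch assuming a CUDA device (so ``get_gpu_dram_gbps`` is
        # defined): reads ~4 MiB for ``x`` and writes ~4 MiB for ``y``.
        >>> x = torch.randn(1 << 20, device="cuda")
        >>> y = torch.relu(x)
        >>> t_ns = get_transfer_time([x], [y])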
    """
    gpu_memory_bandwidth = get_gpu_dram_gbps()
    read_bytes = sum(
        get_num_bytes(t) for t in flat_args_kwargs if isinstance(t, torch.Tensor)
    )
    write_bytes = sum(
        get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor)
    )
    counted_bytes = read_bytes + write_bytes
    # The bandwidth is in GB/s, i.e. bytes per nanosecond, so dividing bytes
    # by it yields nanoseconds
    transfer_time = counted_bytes / gpu_memory_bandwidth
    return transfer_time
