# pylint: disable=useless-parent-delegation
from __future__ import annotations

from typing import Optional, Union
from typing_extensions import Self

import torch


# Type of the opaque graph memory-pool token: a pair of ints. It is what
# ``graph_pool_handle()`` and ``MTIAGraph.pool()`` are annotated to return, and
# what ``MTIAGraph.capture_begin`` / ``graph(pool=...)`` are annotated to accept.
_POOL_HANDLE = tuple[int, int]


def graph_pool_handle() -> _POOL_HANDLE:
    """
    Obtain an opaque token that identifies an MTIA graph memory pool.

    The token should be treated as opaque and only handed to APIs that take a
    ``pool`` argument of the same type, such as ``MTIAGraph.capture_begin`` or
    the ``graph`` context manager.
    """
    token = torch._C._mtia_graphPoolHandle()
    return token


class MTIAGraph(torch._C._MTIAGraph):
    """
    Python-level wrapper for an MTIA graph.

    Every method below forwards directly to the native ``torch._C._MTIAGraph``
    base class; no additional logic is performed on the Python side.
    """

    def __new__(cls, keep_graph: bool = False) -> Self:
        # ``keep_graph`` is handed unchanged to the native constructor.
        instance = super().__new__(cls, keep_graph)
        return instance

    def capture_begin(self, pool: _POOL_HANDLE) -> None:
        """Start capturing MTIA work into this graph, associated with memory pool ``pool``."""
        super().capture_begin(pool)

    def capture_end(self) -> None:
        """Stop the capture that was started by :meth:`capture_begin`."""
        super().capture_end()

    def instantiate(self) -> None:
        """Instantiate the graph that was captured."""
        super().instantiate()

    def replay(self) -> None:
        """Run the captured MTIA graph again."""
        super().replay()

    def reset(self) -> None:
        """Tear down the captured graph and reset this object's internal state."""
        super().reset()

    def pool(self) -> _POOL_HANDLE:
        """Return the opaque token identifying the memory pool used by this graph."""
        handle = super().pool()
        return handle


class graph:
    """
    Context manager that captures MTIA work issued inside a ``with`` block into
    an :class:`MTIAGraph`.

    On entry it synchronizes MTIA, empties the MTIA cache, switches to the
    capture stream and calls ``mtia_graph.capture_begin``. On exit it calls
    ``mtia_graph.capture_end`` and leaves the stream context again. The stream
    context is left even if beginning or ending the capture raises.

    Args:
        mtia_graph: Graph object that receives the capture.
        pool: Optional memory-pool token (e.g. from ``graph_pool_handle()`` or
            ``MTIAGraph.pool()``). When omitted, ``(0, 0)`` is passed to
            ``capture_begin``; presumably this means "no shared pool" (confirm
            against the native implementation).
        stream: Stream to capture on. Defaults to ``default_capture_stream``.
    """

    # Shared by all instances (class attribute). It is lazily set to the MTIA
    # stream that is current when the first instance is constructed while it is
    # still None.
    default_capture_stream: Optional[torch.mtia.Stream] = None

    def __init__(
        self,
        mtia_graph: MTIAGraph,
        pool: Optional[_POOL_HANDLE] = None,
        stream: Optional[torch.mtia.Stream] = None,
    ):
        if self.__class__.default_capture_stream is None:
            self.__class__.default_capture_stream = torch.mtia.current_stream()

        # Stored as an empty or 1-tuple so "no pool given" is distinguishable.
        self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = (
            () if pool is None else (pool,)
        )
        self.capture_stream = (
            stream if stream is not None else self.__class__.default_capture_stream
        )
        assert self.capture_stream is not None
        self.stream_ctx = torch.mtia.stream(self.capture_stream)
        self.mtia_graph = mtia_graph

    def __enter__(self) -> None:
        torch.mtia.synchronize()
        torch.mtia.empty_cache()

        self.stream_ctx.__enter__()

        pool_arg = self.pool[0] if self.pool else (0, 0)
        try:
            self.mtia_graph.capture_begin(pool_arg)
        except BaseException as exc:
            # Python does not call __exit__ when __enter__ raises, so undo the
            # stream switch here; otherwise the stream context stays active.
            self.stream_ctx.__exit__(type(exc), exc, exc.__traceback__)
            raise

    def __exit__(self, *args: object) -> None:
        # Always leave the stream context, even if ending the capture fails.
        # Returning None lets any exception (from the with-body, capture_end,
        # or the stream context) propagate to the caller.
        try:
            self.mtia_graph.capture_end()
        finally:
            self.stream_ctx.__exit__(*args)


# Names exported by ``from <this module> import *``.
__all__ = ["MTIAGraph", "graph", "graph_pool_handle"]
