import sys
from functools import wraps

from sentry_sdk.integrations import DidNotEnable
from sentry_sdk.utils import reraise
from ..spans import (
    invoke_agent_span,
    end_invoke_agent_span,
    handoff_span,
)
from ..utils import _record_exception_on_span

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any, Optional

    from sentry_sdk.tracing import Span

try:
    import agents
except ImportError:
    raise DidNotEnable("OpenAI Agents not installed")


def _patch_agent_run() -> None:
    """
    Patches AgentRunner methods to create agent invocation spans.
    This directly patches the execution flow to track when agents start and stop.
    """

    # Store original methods
    original_run_single_turn = agents.run.AgentRunner._run_single_turn
    original_execute_handoffs = agents._run_impl.RunImpl.execute_handoffs
    original_execute_final_output = agents._run_impl.RunImpl.execute_final_output

    def _start_invoke_agent_span(
        context_wrapper: "agents.RunContextWrapper",
        agent: "agents.Agent",
        kwargs: "dict[str, Any]",
    ) -> "Span":
        """Start an agent invocation span"""
        # Store the agent on the context wrapper so we can access it later
        context_wrapper._sentry_current_agent = agent
        span = invoke_agent_span(context_wrapper, agent, kwargs)
        context_wrapper._sentry_agent_span = span

        return span

    def _has_active_agent_span(context_wrapper: "agents.RunContextWrapper") -> bool:
        """Check if there's an active agent span for this context"""
        return getattr(context_wrapper, "_sentry_current_agent", None) is not None

    def _get_current_agent(
        context_wrapper: "agents.RunContextWrapper",
    ) -> "Optional[agents.Agent]":
        """Get the current agent from context wrapper"""
        return getattr(context_wrapper, "_sentry_current_agent", None)

    @wraps(
        original_run_single_turn.__func__
        if hasattr(original_run_single_turn, "__func__")
        else original_run_single_turn
    )
    async def patched_run_single_turn(
        cls: "agents.Runner", *args: "Any", **kwargs: "Any"
    ) -> "Any":
        """Patched _run_single_turn that creates agent invocation spans"""
        agent = kwargs.get("agent")
        context_wrapper = kwargs.get("context_wrapper")
        should_run_agent_start_hooks = kwargs.get("should_run_agent_start_hooks")

        span = getattr(context_wrapper, "_sentry_agent_span", None)
        # Start agent span when agent starts (but only once per agent)
        if should_run_agent_start_hooks and agent and context_wrapper:
            # End any existing span for a different agent
            if _has_active_agent_span(context_wrapper):
                current_agent = _get_current_agent(context_wrapper)
                if current_agent and current_agent != agent:
                    end_invoke_agent_span(context_wrapper, current_agent)

            span = _start_invoke_agent_span(context_wrapper, agent, kwargs)
            agent._sentry_agent_span = span

        # Call original method with all the correct parameters
        try:
            result = await original_run_single_turn(*args, **kwargs)
        except Exception as exc:
            if span is not None and span.timestamp is None:
                _record_exception_on_span(span, exc)
                end_invoke_agent_span(context_wrapper, agent)

            reraise(*sys.exc_info())

        return result

    @wraps(
        original_execute_handoffs.__func__
        if hasattr(original_execute_handoffs, "__func__")
        else original_execute_handoffs
    )
    async def patched_execute_handoffs(
        cls: "agents.Runner", *args: "Any", **kwargs: "Any"
    ) -> "Any":
        """Patched execute_handoffs that creates handoff spans and ends agent span for handoffs"""

        context_wrapper = kwargs.get("context_wrapper")
        run_handoffs = kwargs.get("run_handoffs")
        agent = kwargs.get("agent")

        # Create Sentry handoff span for the first handoff (agents library only processes the first one)
        if run_handoffs:
            first_handoff = run_handoffs[0]
            handoff_agent_name = first_handoff.handoff.agent_name
            handoff_span(context_wrapper, agent, handoff_agent_name)

        # Call original method with all parameters
        try:
            result = await original_execute_handoffs(*args, **kwargs)

        finally:
            # End span for current agent after handoff processing is complete
            if agent and context_wrapper and _has_active_agent_span(context_wrapper):
                end_invoke_agent_span(context_wrapper, agent)

        return result

    @wraps(
        original_execute_final_output.__func__
        if hasattr(original_execute_final_output, "__func__")
        else original_execute_final_output
    )
    async def patched_execute_final_output(
        cls: "agents.Runner", *args: "Any", **kwargs: "Any"
    ) -> "Any":
        """Patched execute_final_output that ends agent span for final outputs"""

        agent = kwargs.get("agent")
        context_wrapper = kwargs.get("context_wrapper")
        final_output = kwargs.get("final_output")

        # Call original method with all parameters
        try:
            result = await original_execute_final_output(*args, **kwargs)
        finally:
            # End span for current agent after final output processing is complete
            if agent and context_wrapper and _has_active_agent_span(context_wrapper):
                end_invoke_agent_span(context_wrapper, agent, final_output)

        return result

    # Apply patches
    agents.run.AgentRunner._run_single_turn = classmethod(patched_run_single_turn)
    agents._run_impl.RunImpl.execute_handoffs = classmethod(patched_execute_handoffs)
    agents._run_impl.RunImpl.execute_final_output = classmethod(
        patched_execute_final_output
    )
