#!/usr/bin/env python3
"""
ML Inference Server V7
======================
FastAPI server that loads the trained PyTorch model and provides
inference endpoints for the orchestrator.

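Endpoints:
- POST /predict - Get trading signal from model
- GET /health - Health check
- POST /load-model - Load a new model
- GET /model-info - Get current model info
- POST /clear-history - Clear buffered price history (one pair, or all pairs)
- GET /models - List available model checkpoints in the model directory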
"""

import os
import sys
import json
import torch
import torch.nn as nn
import numpy as np
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any
from collections import deque
import logging

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn

# ============================================================
# CONFIGURATION
# ============================================================

CONFIG = {
    # Deployment settings. Each can be overridden through an environment
    # variable so the same file runs outside the production host; the
    # defaults are the original hard-coded values.
    "host": os.environ.get("ML_INFERENCE_HOST", "0.0.0.0"),
    "port": int(os.environ.get("ML_INFERENCE_PORT", "3058")),
    "model_dir": os.environ.get(
        "ML_INFERENCE_MODEL_DIR",
        "/var/www/html/bestrading.cuttalo.com/models/btc_v7",
    ),
    "default_model": os.environ.get("ML_INFERENCE_DEFAULT_MODEL", "model_best.pt"),
    # Model / feature settings. Used for feature computation and as the
    # architecture fallback when a checkpoint carries no saved "config".
    "lookback": 60,
    "num_features": 12,
    "hidden_dim": 256,
    "num_heads": 4,
    "num_layers": 3,
}

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ml-inference")


# ============================================================
# MODEL ARCHITECTURE (must match training)
# ============================================================

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    The encoding table is registered as buffer ``pe`` with shape
    ``(1, max_len, d_model)``; even columns hold sines, odd columns cosines.
    Because it is a registered buffer it is included in state dicts.
    """

    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # For odd d_model there are ceil(d/2) frequencies but only floor(d/2)
        # odd columns; trim so the shapes match. No-op for even d_model.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)]


class TradingTransformer(nn.Module):
    """Transformer-based trading model.

    Encodes a batch-first input sequence with a TransformerEncoder and feeds the
    encoding of the last time step into two heads: a policy head producing an
    action mean and log-std, and a value head producing a single value.

    The submodule attribute names (input_proj, pos_encoder, transformer,
    policy_head, value_head) become state_dict keys, so they must stay
    identical to the training architecture for checkpoints to load.
    """

    def __init__(self, config: dict):
        """Build the network from a config dict.

        Keys read (with defaults): ``num_features`` (12), ``hidden_dim`` (256),
        ``lookback`` (60), ``num_heads`` (4), ``dropout`` (0.1),
        ``num_layers`` (3).
        """
        super().__init__()
        self.config = config

        input_dim = config.get("num_features", 12) + 1  # features + position
        hidden_dim = config.get("hidden_dim", 256)
        lookback = config.get("lookback", 60)

        self.input_proj = nn.Linear(input_dim, hidden_dim)
        # The positional table only covers `lookback` steps, so input
        # sequences longer than that cannot be encoded.
        self.pos_encoder = PositionalEncoding(hidden_dim, lookback)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=config.get("num_heads", 4),
            dim_feedforward=hidden_dim * 4,
            dropout=config.get("dropout", 0.1),
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            config.get("num_layers", 3)
        )

        # Two raw outputs: [action mean (before tanh), action log-std (before clamp)].
        self.policy_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 2)
        )

        # Single scalar value estimate per sample.
        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x: torch.Tensor):
        """Run the network.

        Args:
            x: Batch-first input of shape (batch, seq_len, num_features + 1),
                with seq_len not exceeding ``lookback``.

        Returns:
            Tuple ``(action_mean, action_log_std, value)``, each of shape
            (batch, 1). action_mean is tanh-squashed into [-1, 1];
            action_log_std is clamped to [-2, 0.5].
        """
        x = self.input_proj(x)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        x = x[:, -1, :]  # keep only the last time step: (batch, hidden_dim)

        policy_out = self.policy_head(x)
        action_mean = torch.tanh(policy_out[:, 0:1])
        action_log_std = torch.clamp(policy_out[:, 1:2], -2, 0.5)

        value = self.value_head(x)

        return action_mean, action_log_std, value

    def get_action(self, x: torch.Tensor, deterministic: bool = True):
        """Get action for inference.

        Args:
            x: Same input as :meth:`forward`.
            deterministic: If True, return the tanh action mean. Otherwise sample
                from Normal(mean, exp(log_std)) and clamp the sample to [-1, 1].

        Returns:
            ``(action, value)`` with size-1 dimensions squeezed away, which
            gives 0-d tensors for a batch of one.
        """
        action_mean, action_log_std, value = self.forward(x)

        if deterministic:
            return action_mean.squeeze(), value.squeeze()

        action_std = torch.exp(action_log_std)
        dist = torch.distributions.Normal(action_mean, action_std)
        action = dist.sample()
        action = torch.clamp(action, -1, 1)

        return action.squeeze(), value.squeeze()


# ============================================================
# INFERENCE ENGINE
# ============================================================

class InferenceEngine:
    """Manages model loading, per-pair price history, and inference.

    Prices arrive one tick at a time through :meth:`predict`. Each pair keeps a
    rolling buffer of its last ``lookback`` prices, and a model prediction is
    produced only once that buffer is full.
    """

    def __init__(self, config: dict):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model: Optional[TradingTransformer] = None
        self.model_info: Dict[str, Any] = {}

        # Rolling price history per trading pair, bounded to `lookback` entries.
        self.price_buffers: Dict[str, deque] = {}
        self.lookback = config.get("lookback", 60)

        logger.info(f"InferenceEngine initialized on {self.device}")

    def load_model(self, model_path: str) -> bool:
        """Load a checkpoint and make it the active model.

        The new model and its metadata are installed only after the whole load
        has succeeded. On any failure the previously loaded model (if any)
        keeps serving, and False is returned.

        Args:
            model_path: Path to a checkpoint dict holding ``"model_state_dict"``
                and optionally ``"config"``, ``"best_reward"`` and ``"episode"``.

        Returns:
            True if the model was loaded and activated, False otherwise.
        """
        try:
            if not os.path.exists(model_path):
                logger.error(f"Model not found: {model_path}")
                return False

            # SECURITY: weights_only=False unpickles arbitrary Python objects,
            # which can execute code. model_path may come straight from the
            # /load-model request, so only point this at trusted checkpoints.
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

            # Architecture config saved with the checkpoint, else server defaults.
            model_config = checkpoint.get("config", self.config)

            # Build into locals. Assigning self.model before load_state_dict
            # succeeds would leave a randomly initialised model serving
            # predictions after a failed load.
            model = TradingTransformer(model_config)
            model.load_state_dict(checkpoint["model_state_dict"])
            model.to(self.device)
            model.eval()

            # Convert numpy/torch scalars to plain Python types for JSON.
            best_reward = checkpoint.get("best_reward", "unknown")
            if hasattr(best_reward, 'item'):
                best_reward = best_reward.item()
            episode = checkpoint.get("episode", "unknown")
            if hasattr(episode, 'item'):
                episode = episode.item()

            model_info = {
                "path": model_path,
                "loaded_at": datetime.now().isoformat(),
                "episode": episode,
                "best_reward": float(best_reward) if isinstance(best_reward, (int, float)) else best_reward,
                "config": {k: str(v) for k, v in model_config.items() if not callable(v)},
            }
        except Exception as e:
            # Load boundary: report with traceback and keep the current model.
            logger.exception(f"Failed to load model: {e}")
            return False

        # Commit atomically only after everything above succeeded.
        self.model = model
        self.model_info = model_info
        logger.info(f"Model loaded: {model_path}")
        return True

    def compute_features(self, prices: List[float], current_position: float = 0) -> np.ndarray:
        """Build the model input feature matrix from a price history.

        The feature set and its per-window normalisation must stay identical
        to what the model was trained on.

        Args:
            prices: Price history, oldest first.
            current_position: Current position, written into the last column.

        Returns:
            float32 array of shape ``(len(prices), num_features + 1)``. Columns
            ``0..num_features-1`` are the engineered features, and the last
            column is ``current_position``.
        """
        n = len(prices)
        num_features = self.config.get("num_features", 12)
        features = np.zeros((n, num_features + 1))  # +1 for position

        if n == 0:
            # Nothing to compute; the code below indexes prices_arr[0].
            return features.astype(np.float32)

        import pandas as pd  # only needed for the rolling-window features

        prices_arr = np.array(prices)

        # Returns at different timeframes (columns 0-4)
        for i, period in enumerate([1, 5, 15, 30, 60]):
            if i < num_features and n > period:
                ret = np.zeros(n)
                ret[period:] = (prices_arr[period:] - prices_arr[:-period]) / prices_arr[:-period]
                features[:, i] = ret

        # Volatility of 1-step returns at different timeframes (columns 5-7).
        # The window for row j is prices[j-period:j], so it excludes price j.
        for i, period in enumerate([5, 15, 30]):
            idx = 5 + i
            if idx < num_features and n > period:
                vol = np.zeros(n)
                for j in range(period, n):
                    vol[j] = np.std(np.diff(prices_arr[j-period:j]) / prices_arr[j-period:j-1])
                features[:, idx] = vol

        # Price momentum: relative SMA(10) - SMA(30) spread (column 8)
        if 8 < num_features and n >= 30:
            sma_short = pd.Series(prices_arr).rolling(10).mean().values
            sma_long = pd.Series(prices_arr).rolling(30).mean().values
            features[:, 8] = np.nan_to_num((sma_short - sma_long) / sma_long)

        # RSI-like indicator over 14 steps, rescaled to [-0.5, 0.5] (column 9)
        if 9 < num_features and n > 14:
            delta = np.diff(prices_arr, prepend=prices_arr[0])
            gain = np.where(delta > 0, delta, 0)
            loss = np.where(delta < 0, -delta, 0)
            avg_gain = pd.Series(gain).rolling(14).mean().values
            avg_loss = pd.Series(loss).rolling(14).mean().values
            rs = np.nan_to_num(avg_gain / (avg_loss + 1e-10))
            features[:, 9] = (100 - 100 / (1 + rs)) / 100 - 0.5

        # Volume proxy: absolute relative 1-step price change (column 10)
        if 10 < num_features:
            features[:, 10] = np.abs(np.diff(prices_arr, prepend=prices_arr[0])) / prices_arr

        # Trend strength: 20-step return (column 11)
        if 11 < num_features and n > 20:
            ret_20 = np.zeros(n)
            ret_20[20:] = (prices_arr[20:] - prices_arr[:-20]) / prices_arr[:-20]
            features[:, 11] = ret_20

        # Normalize each feature by 3x its std within this window, clip to [-1, 1]
        for i in range(num_features):
            col = features[:, i]
            std = np.std(col[~np.isnan(col)])
            if std > 0:
                features[:, i] = np.clip(col / (std * 3), -1, 1)

        # Add position (same value for every time step)
        features[:, -1] = current_position

        return np.nan_to_num(features).astype(np.float32)

    def predict(self, pair: str, price: float, current_position: float = 0) -> Dict[str, Any]:
        """Record a price tick for ``pair`` and return a trading signal.

        Args:
            pair: Trading pair identifier; each pair has its own history buffer.
            price: Latest price. Must be positive and finite.
            current_position: Current position fed to the model. Must be finite.

        Returns:
            A dict with an ``"error"`` key if no model is loaded or the input is
            invalid. Otherwise it holds ``signal`` ("WAIT" while history is
            filling, else "LONG"/"SHORT"/"HOLD"), ``action``, ``confidence``
            in [0, 1], ``value``, the echoed inputs and ``history_length``.
        """
        if self.model is None:
            return {"error": "No model loaded"}

        # Reject bad ticks before they enter the buffer. A zero, negative or
        # non-finite price turns the return-based features into inf/NaN for
        # the next `lookback` ticks. A NaN action would then fail both
        # threshold tests below and silently be reported as HOLD.
        if not np.isfinite(price) or price <= 0:
            return {"error": f"Invalid price for {pair}: {price!r} (must be a positive finite number)"}
        if not np.isfinite(current_position):
            return {"error": f"Invalid position for {pair}: {current_position!r} (must be finite)"}

        # Initialize buffer if needed
        if pair not in self.price_buffers:
            self.price_buffers[pair] = deque(maxlen=self.lookback)

        # Add price to buffer
        self.price_buffers[pair].append(price)

        # Check if we have enough data
        if len(self.price_buffers[pair]) < self.lookback:
            return {
                "signal": "WAIT",
                "action": 0.0,
                "confidence": 0.0,
                "value": 0.0,
                "pair": pair,
                "price": price,
                "position": current_position,
                "history_length": len(self.price_buffers[pair]),
                "message": f"Accumulating data: {len(self.price_buffers[pair])}/{self.lookback}"
            }

        # Compute features
        prices = list(self.price_buffers[pair])
        features = self.compute_features(prices, current_position)

        # Take last lookback points
        features = features[-self.lookback:]

        # Convert to tensor: (1, lookback, num_features + 1)
        x = torch.FloatTensor(features).unsqueeze(0).to(self.device)

        # Get prediction
        with torch.no_grad():
            action, value = self.model.get_action(x, deterministic=True)

        action_val = action.item()
        value_val = value.item()

        # Map action to signal: |action| <= 0.3 is a dead zone (HOLD).
        # Outside it, confidence scales linearly from 0 at 0.3 to 1 at 1.0.
        if action_val > 0.3:
            signal = "LONG"
            confidence = min(1.0, (action_val - 0.3) / 0.7)
        elif action_val < -0.3:
            signal = "SHORT"
            confidence = min(1.0, (-action_val - 0.3) / 0.7)
        else:
            signal = "HOLD"
            confidence = 1.0 - abs(action_val) / 0.3

        return {
            "signal": signal,
            "action": action_val,
            "confidence": confidence,
            "value": value_val,
            "pair": pair,
            "price": price,
            "position": current_position,
            "history_length": len(self.price_buffers[pair]),
        }

    def clear_history(self, pair: Optional[str] = None):
        """Clear the price history of ``pair``, or of every pair when ``pair`` is falsy."""
        if pair:
            if pair in self.price_buffers:
                self.price_buffers[pair].clear()
        else:
            self.price_buffers.clear()


# ============================================================
# FASTAPI APP
# ============================================================

app = FastAPI(
    title="ML Inference Server V7",
    version="7.0.0",
    description="Trading model inference API",
)

# NOTE(review): wildcard origins let any web page call every endpoint below
# (including /load-model) from a browser; tighten before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Global inference engine; created by the startup hook, None until then.
engine: Optional[InferenceEngine] = None


# Request/Response models
class PredictRequest(BaseModel):
    """Body of POST /predict: one price tick for a trading pair."""

    pair: str  # Trading pair identifier; each pair has its own price history buffer
    price: float  # Latest price, appended to that pair's history buffer
    position: float = 0.0  # Current position, passed to the engine as current_position


class PredictResponse(BaseModel):
    """Response of POST /predict, mirroring the dict from InferenceEngine.predict."""

    signal: str  # "WAIT" while history is filling, otherwise "LONG", "SHORT" or "HOLD"
    action: float  # Raw model action (0.0 while waiting)
    confidence: float  # Signal confidence in [0, 1] (0.0 while waiting)
    value: float  # Value-head output (0.0 while waiting)
    pair: str
    price: float
    position: float
    history_length: int  # Number of prices currently buffered for the pair
    message: Optional[str] = None  # Progress text, set only while accumulating data
    error: Optional[str] = None  # Normally unset: /predict turns engine errors into HTTP 400


class LoadModelRequest(BaseModel):
    """Body of POST /load-model."""

    model_path: str  # Checkpoint path, passed unchanged to InferenceEngine.load_model


class HealthResponse(BaseModel):
    """Response of GET /health."""

    status: str  # Always "healthy" when the endpoint answers
    model_loaded: bool  # Whether the engine currently holds a model
    device: str  # Torch device of the engine, or "N/A" before startup
    uptime: str  # str() of the timedelta since module-level startup_time


# Startup time (module import); /health reports uptime relative to this.
startup_time = datetime.now()


@app.on_event("startup")
async def startup_event():
    """Initialize engine and try to load the default model.

    A missing or unloadable default model is logged but does not stop the
    server; a model can still be loaded later through /load-model.
    """
    global engine
    engine = InferenceEngine(CONFIG)

    default_model = os.path.join(CONFIG["model_dir"], CONFIG["default_model"])
    if not os.path.exists(default_model):
        logger.warning(f"Default model not found: {default_model}")
    elif engine.load_model(default_model):
        logger.info(f"Default model loaded: {default_model}")
    else:
        # load_model has already logged the underlying cause.
        logger.error(f"Default model could not be loaded: {default_model}")


@app.get("/health", response_model=HealthResponse)
async def health():
    """Report liveness, model presence, compute device and uptime."""
    if engine:
        loaded = engine.model is not None
        device = str(engine.device)
    else:
        loaded = False
        device = "N/A"

    return {
        "status": "healthy",
        "model_loaded": loaded,
        "device": device,
        "uptime": str(datetime.now() - startup_time),
    }


@app.get("/model-info")
async def model_info():
    """Return metadata of the loaded model, or an error payload (HTTP 200) if none."""
    has_model = bool(engine) and bool(engine.model)
    if has_model:
        return engine.model_info
    return {"error": "No model loaded"}


@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """Feed one price tick to the engine and return its trading signal.

    Raises:
        HTTPException: 500 if the engine is not initialized yet; 400 if the
            engine returned an error (e.g. no model loaded).
    """
    if engine is None:
        raise HTTPException(status_code=500, detail="Engine not initialized")

    outcome = engine.predict(request.pair, request.price, request.position)
    if "error" not in outcome:
        return outcome
    raise HTTPException(status_code=400, detail=outcome["error"])


@app.post("/load-model")
async def load_model(request: LoadModelRequest):
    """Replace the active model with the checkpoint at ``request.model_path``.

    NOTE(security): the path is handed unchanged to InferenceEngine.load_model,
    which calls torch.load(weights_only=False), i.e. unpickling. Any client that
    can reach this endpoint can make the server unpickle any file it can read.
    Keep the endpoint on a trusted network, or restrict paths to
    CONFIG["model_dir"].

    Raises:
        HTTPException: 500 if the engine is not initialized; 400 if loading failed.
    """
    if engine is None:
        raise HTTPException(status_code=500, detail="Engine not initialized")

    path = request.model_path
    if engine.load_model(path):
        return {"success": True, "model_info": engine.model_info}
    raise HTTPException(status_code=400, detail=f"Failed to load model: {path}")


@app.post("/clear-history")
async def clear_history(pair: Optional[str] = None):
    """Drop buffered price history for ``pair``, or for all pairs if omitted or empty."""
    if engine is None:
        raise HTTPException(status_code=500, detail="Engine not initialized")

    engine.clear_history(pair)
    cleared = pair if pair else "all"
    return {"success": True, "cleared": cleared}


@app.get("/models")
async def list_models():
    """List the ``.pt`` checkpoints available in the configured model directory.

    Returns:
        Dict with ``models`` (sorted file names, empty if the directory is
        missing), ``directory`` and ``current`` (path of the loaded model,
        or None).
    """
    model_dir = CONFIG["model_dir"]
    current = engine.model_info.get("path") if engine and engine.model_info else None

    # isdir rather than exists: a plain file at this path would make listdir raise.
    if not os.path.isdir(model_dir):
        return {"models": [], "directory": model_dir, "current": current}

    # os.listdir order is arbitrary; sort so the response is stable.
    models = sorted(f for f in os.listdir(model_dir) if f.endswith(".pt"))
    return {
        "models": models,
        "directory": model_dir,
        "current": current,
    }


# ============================================================
# MAIN
# ============================================================

if __name__ == "__main__":
    print("="*60)
    print("ML INFERENCE SERVER V7")
    print("="*60)
    print(f"Host: {CONFIG['host']}")
    print(f"Port: {CONFIG['port']}")
    print(f"Model dir: {CONFIG['model_dir']}")
    print("="*60)

    # Pass the app object instead of an "inference_server:app" import string.
    # The string form re-imports this file under a hard-coded module name: that
    # breaks if the file is renamed, and it runs all module-level code a second
    # time in a separate module object.
    uvicorn.run(
        app,
        host=CONFIG["host"],
        port=CONFIG["port"],
        reload=False,
        log_level="info",
    )
