#!/usr/bin/env python3
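"""
ML Inference Server V8 Lite
============================
FastAPI server for the V8 Lite model with 24 features (18 technical + 6 String Theory);
the current position is appended as a 25th model input column.

Endpoints:
- POST /predict - Get a trading signal from the model
- GET /health - Health check
- POST /load-model - Load a new model
- GET /model-info - Get current model info
- POST /clear-history - Clear buffered price history (one pair, or all pairs)
- GET /models - List available model checkpoints
"""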

import os
import sys
import json
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any
from collections import deque
import logging

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn

# ============================================================
# CONFIGURATION
# ============================================================

# Server settings plus model hyper-parameters. The model keys double as the
# fallback architecture config when a checkpoint does not embed its own "config".
CONFIG = dict(
    host="0.0.0.0",
    port=3058,
    model_dir="/var/www/html/bestrading.cuttalo.com/models/btc_v8_lite",
    default_model="model_best.pt",
    lookback=60,  # number of rows in each model input window
    num_features=24,  # 18 technical + 6 String Theory (position column is added on top)
    hidden_dim=256,
    num_heads=4,
    num_layers=3,
    dropout=0.18,
)

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name="ml-inference-v8")


# ============================================================
# STRING THEORY CALCULATIONS
# ============================================================

class StringTheoryFast:
    """Cheap scalar regime indicators used as the six "String Theory" features.

    Every method is a pure static function of a 1-D price or return array and
    returns a fixed neutral fallback when the window is too short.
    """

    @staticmethod
    def hurst_fast(prices: np.ndarray) -> float:
        """Simplified rescaled-range (R/S) Hurst estimate, clipped to [0.1, 0.9].

        Returns 0.5 for fewer than 20 prices or a degenerate (flat) series.
        """
        n = len(prices)
        if n < 20:
            return 0.5
        # With n >= 20 there are always >= 19 log returns, so no second length guard.
        log_rets = np.diff(np.log(prices + 1e-10))
        walk = np.cumsum(log_rets - np.mean(log_rets))
        spread = np.max(walk) - np.min(walk)
        sigma = np.std(log_rets)
        if sigma < 1e-10 or spread < 1e-10:
            return 0.5
        return np.clip(np.log(spread / sigma) / np.log(n), 0.1, 0.9)

    @staticmethod
    def catastrophe_fast(prices: np.ndarray) -> float:
        """tanh of scaled volatility plus |mean/vol| over the last 20 prices (1.0 if too short)."""
        if len(prices) < 20:
            return 1.0
        tail = prices[-20:]
        rets = np.diff(tail) / tail[:-1]
        vol = np.std(rets)
        drift = np.mean(rets) / (vol + 1e-10)
        return np.tanh(vol * 50 + abs(drift) * 2.5)

    @staticmethod
    def entropy_fast(returns: np.ndarray) -> float:
        """Shannon entropy of a 15-bin return histogram, normalised to [0, 1]."""
        if len(returns) < 10:
            return 0.5
        density, _edges = np.histogram(returns, bins=15, density=True)
        occupied = density[density > 0]
        if occupied.size == 0:
            return 0.5
        probs = occupied / occupied.sum()
        ent = -np.sum(probs * np.log(probs + 1e-10))
        return np.clip(ent / np.log(15), 0, 1)

    @staticmethod
    def kelly_fast(returns: np.ndarray) -> float:
        """Kelly-style fraction mean/variance, clipped to [-0.15, 0.15] (0.0 if undefined)."""
        if len(returns) < 10:
            return 0.0
        variance = np.var(returns)
        if variance < 1e-10:
            return 0.0
        return np.clip(np.mean(returns) / variance, -0.15, 0.15)

    @staticmethod
    def wasserstein_fast(prices: np.ndarray) -> float:
        """Quartile shift between returns of the last 20 and the previous 20 prices, tanh-squashed."""
        if len(prices) < 40:
            return 0.5
        recent = prices[-20:]
        older = prices[-40:-20]
        r_recent = np.diff(recent) / recent[:-1]
        r_older = np.diff(older) / older[:-1]
        quartiles = [25, 50, 75]
        gap = np.abs(np.percentile(r_recent, quartiles) - np.percentile(r_older, quartiles))
        return np.tanh(np.mean(gap) * 100)

    @staticmethod
    def instability_fast(prices: np.ndarray) -> float:
        """Blend of short-window volatility and first/last return change, in [0, 1]."""
        if len(prices) < 20:
            return 0.5
        tail = prices[-20:]
        rets = np.diff(tail) / tail[:-1]
        vol = np.tanh(np.std(rets) * 50)
        mom = 0 if len(rets) <= 1 else np.tanh(abs(rets[-1] - rets[0]) * 100)
        return np.clip(0.6 * vol + 0.4 * mom, 0, 1)


# ============================================================
# MODEL ARCHITECTURE V8
# ============================================================

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    The table is stored as buffer ``pe`` of shape (1, max_len, d_model); even
    columns hold sines, odd columns cosines of the same angles.
    """

    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        steps = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        freqs = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        angles = steps * freqs  # (max_len, ceil(d_model / 2))
        table = torch.zeros(1, max_len, d_model)
        table[0, :, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns: for odd d_model the last
        # frequency has no cosine partner.
        table[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', table)

    def forward(self, x):
        """Return ``x`` plus the first ``x.size(1)`` rows of the encoding table."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]


class TradingTransformerV8(nn.Module):
    """V8 Transformer encoder with attention pooling and actor-critic heads.

    Input is batch-first (batch, seq, num_features + 1); the forward pass
    returns a tanh-squashed action mean, a clamped action log-std and a value.
    """

    def __init__(self, config: dict):
        super().__init__()
        # +1 input column: the current position is appended to the features.
        n_inputs = config.get("num_features", 24) + 1
        width = config.get("hidden_dim", 256)
        p_drop = config.get("dropout", 0.18)

        # NOTE: attribute names and Sequential ordering define the state_dict keys
        # and must stay in sync with saved checkpoints.
        self.input_proj = nn.Sequential(
            nn.Linear(n_inputs, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Dropout(p_drop / 2),
        )
        self.pos_encoder = PositionalEncoding(width, config.get("lookback", 60))

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=width,
                nhead=config.get("num_heads", 4),
                dim_feedforward=width * 4,
                dropout=p_drop,
                activation='gelu',
                batch_first=True,
            ),
            config.get("num_layers", 3),
        )

        # One score per time step, softmax-normalised over the sequence axis.
        self.attention_pool = nn.Sequential(nn.Linear(width, 1), nn.Softmax(dim=1))

        half_width = width // 2
        self.policy_head = self._make_head(width, half_width, p_drop, 2)
        self.value_head = self._make_head(width, half_width, p_drop, 1)

    @staticmethod
    def _make_head(in_dim: int, mid_dim: int, p_drop: float, out_dim: int) -> nn.Sequential:
        """Build a Linear -> LayerNorm -> GELU -> Dropout -> Linear projection head."""
        return nn.Sequential(
            nn.Linear(in_dim, mid_dim),
            nn.LayerNorm(mid_dim),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(mid_dim, out_dim),
        )

    def forward(self, x):
        """Return (action_mean, action_log_std, value), each of shape (batch, 1)."""
        hidden = self.transformer(self.pos_encoder(self.input_proj(x)))
        weights = self.attention_pool(hidden)
        pooled = (hidden * weights).sum(dim=1)

        logits = self.policy_head(pooled)
        mean = torch.tanh(logits[:, 0:1])
        log_std = logits[:, 1:2].clamp(-3.0, 0.5)
        return mean, log_std, self.value_head(pooled)

    def get_action(self, x: torch.Tensor, deterministic: bool = True):
        """Get action for inference.

        Deterministic mode returns the action mean; otherwise a sample from
        Normal(mean, exp(log_std)) clamped to [-1, 1]. Outputs are squeezed.
        """
        mean, log_std, value = self.forward(x)
        if not deterministic:
            sampled = torch.distributions.Normal(mean, log_std.exp()).sample()
            mean = sampled.clamp(-1, 1)
        return mean.squeeze(), value.squeeze()


# ============================================================
# INFERENCE ENGINE V8
# ============================================================

class InferenceEngineV8:
    """Manages model loading and inference for V8 with 24 features.

    Keeps at most one loaded ``TradingTransformerV8`` plus a bounded price
    history per trading pair, from which the model's input window is built.
    """

    def __init__(self, config: dict):
        """Create an engine with no model loaded.

        Args:
            config: Settings dict. ``lookback`` (default 60) is the number of
                rows fed to the model; the dict is also the fallback model
                config for checkpoints that do not carry one.
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model: Optional[TradingTransformerV8] = None
        self.model_info: Dict[str, Any] = {}

        # Price history buffer per pair
        self.price_buffers: Dict[str, deque] = {}
        # NOTE(review): this lookback comes from the engine config, not the
        # checkpoint config; it must not exceed the model's PositionalEncoding length.
        self.lookback = config.get("lookback", 60)

        logger.info(f"InferenceEngineV8 initialized on {self.device}")

    def load_model(self, model_path: str) -> bool:
        """Load a V8 PyTorch checkpoint and make it the active model.

        The new model replaces the current one only after its weights have
        loaded successfully; on any failure the previously loaded model and
        its ``model_info`` are left untouched.

        Args:
            model_path: Path to a checkpoint dict containing ``model_state_dict``
                and optionally ``config``, ``best_reward`` and ``episode``.

        Returns:
            True on success, False on any failure (the error is logged).
        """
        try:
            if not os.path.exists(model_path):
                logger.error(f"Model not found: {model_path}")
                return False

            # SECURITY(review): weights_only=False unpickles arbitrary Python
            # objects, and POST /load-model forwards a client-supplied path here.
            # Only load checkpoints from trusted locations.
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

            # Get config from checkpoint or use defaults
            model_config = checkpoint.get("config", self.config)

            # Build into a local first: if load_state_dict fails (e.g. a shape
            # mismatch) the served model must not become a randomly initialised one.
            model = TradingTransformerV8(model_config)
            model.load_state_dict(checkpoint["model_state_dict"])
            model.to(self.device)
            model.eval()

            # Unwrap values exposing .item() (e.g. tensors) into plain Python numbers.
            best_reward = checkpoint.get("best_reward", "unknown")
            if hasattr(best_reward, 'item'):
                best_reward = best_reward.item()
            episode = checkpoint.get("episode", "unknown")
            if hasattr(episode, 'item'):
                episode = episode.item()

            info = {
                "path": model_path,
                "loaded_at": datetime.now().isoformat(),
                "episode": episode,
                "best_reward": float(best_reward) if isinstance(best_reward, (int, float)) else best_reward,
                "version": "V8_LITE",
                "features": 24,
                "config": {k: str(v) for k, v in model_config.items() if not callable(v)},
            }
        except Exception as e:
            logger.exception(f"Failed to load model: {e}")
            return False

        # Commit only after every step above succeeded.
        self.model = model
        self.model_info = info
        logger.info(f"V8 Model loaded: {model_path}")
        return True

    def compute_features(self, prices: List[float], current_position: float = 0) -> np.ndarray:
        """Compute 24 features from price history, plus the position column.

        Intended to mirror the training feature pipeline exactly; keep the two
        in sync when either changes.

        Args:
            prices: Price history, oldest first.
            current_position: Value written to every row of column 24.

        Returns:
            float32 array of shape (len(prices), 25) with NaN mapped to 0 and
            +/-inf to +/-3. String Theory columns 18-23 are computed only when
            at least 60 prices are available, from the last 60 prices, and are
            copied across the last ``self.lookback`` rows.
        """
        n = len(prices)
        features = np.zeros((n, 25), dtype=np.float32)  # 24 features + 1 position

        prices_arr = np.array(prices, dtype=np.float64)
        price_series = pd.Series(prices_arr)
        pct_change = price_series.pct_change().values

        # Returns (0-4): 1, 5, 15, 30, 60 min
        for i, period in enumerate([1, 5, 15, 30, 60]):
            if n > period:
                ret = np.zeros(n)
                ret[period:] = (prices_arr[period:] - prices_arr[:-period]) / prices_arr[:-period]
                features[:, i] = np.clip(ret * 10, -3, 3)

        # Volatility (5-8): 5, 15, 30, 60 min annualized
        for i, period in enumerate([5, 15, 30, 60]):
            if n > period:
                vol = price_series.pct_change().rolling(period).std().values * np.sqrt(525600)
                features[:, 5 + i] = np.nan_to_num(np.clip(vol / 100, 0, 3))

        # Trend (9-11): SMA crossovers
        sma_10 = price_series.rolling(10).mean().values
        sma_30 = price_series.rolling(30).mean().values
        sma_60 = price_series.rolling(60).mean().values

        features[:, 9] = np.nan_to_num(np.clip((sma_10 - sma_30) / (sma_30 + 1e-10) * 10, -3, 3))
        features[:, 10] = np.nan_to_num(np.clip((sma_30 - sma_60) / (sma_60 + 1e-10) * 10, -3, 3))
        features[:, 11] = np.nan_to_num(np.clip((prices_arr - sma_30) / (sma_30 + 1e-10) * 10, -3, 3))

        # Momentum (12-14): RSI, ROC, acceleration
        delta = np.diff(prices_arr, prepend=prices_arr[0])
        gain = np.where(delta > 0, delta, 0)
        loss = np.where(delta < 0, -delta, 0)
        avg_gain = pd.Series(gain).rolling(14).mean().values
        avg_loss = pd.Series(loss).rolling(14).mean().values
        rs = np.nan_to_num(avg_gain / (avg_loss + 1e-10))
        rsi = 100 - 100 / (1 + rs)
        features[:, 12] = np.clip((rsi - 50) / 50, -1, 1)

        roc = np.zeros(n)
        if n > 10:
            roc[10:] = (prices_arr[10:] - prices_arr[:-10]) / prices_arr[:-10]
        features[:, 13] = np.clip(roc * 20, -3, 3)

        mom = np.zeros(n)
        if n > 5:
            mom[5:] = pct_change[5:] - pct_change[:-5]
        features[:, 14] = np.nan_to_num(np.clip(mom * 100, -3, 3))

        # Regime (15-17): vol ratio, range, zscore
        vol_short = pd.Series(pct_change).rolling(10).std().values
        vol_long = pd.Series(pct_change).rolling(30).std().values
        features[:, 15] = np.nan_to_num(np.clip((vol_short - vol_long) / (vol_long + 1e-10), -3, 3))

        # Windows below exclude the current price (slice ends at j).
        for j in range(20, n):
            features[j, 16] = np.clip((np.max(prices_arr[j-20:j]) - np.min(prices_arr[j-20:j])) / prices_arr[j] * 20, 0, 3)

        for j in range(30, n):
            mean = np.mean(prices_arr[j-30:j])
            std = np.std(prices_arr[j-30:j])
            if std > 0:
                features[j, 17] = np.clip((prices_arr[j] - mean) / std / 2, -2, 2)

        # String Theory (18-23): Compute for last point only (optimization)
        if n >= 60:
            pw = prices_arr[-60:]
            returns = np.diff(pw) / pw[:-1]

            features[-1, 18] = (StringTheoryFast.hurst_fast(pw) - 0.5) * 2
            features[-1, 19] = 1 - StringTheoryFast.catastrophe_fast(pw)
            features[-1, 20] = StringTheoryFast.entropy_fast(returns) * 2 - 1
            features[-1, 21] = StringTheoryFast.kelly_fast(returns) * 6
            features[-1, 22] = StringTheoryFast.wasserstein_fast(pw) * 2 - 1
            features[-1, 23] = StringTheoryFast.instability_fast(pw) * 2 - 1

            # Backfill String Theory for lookback window
            for i in range(max(0, n - self.lookback), n - 1):
                features[i, 18:24] = features[-1, 18:24]

        # Position (24)
        features[:, 24] = current_position

        return np.nan_to_num(features, nan=0.0, posinf=3.0, neginf=-3.0).astype(np.float32)

    def predict(self, pair: str, price: float, current_position: float = 0) -> Dict[str, Any]:
        """Record one price tick for ``pair`` and return the model's signal.

        Returns:
            ``{"error": ...}`` if no model is loaded or ``price`` is not a finite
            positive number (such a price is not buffered). A "WAIT" result is
            returned while fewer than ``self.lookback`` prices are buffered.
            Otherwise the result is "LONG"/"SHORT" when the deterministic action
            exceeds +/-0.3, else "HOLD", with a confidence in [0, 1].
        """
        if self.model is None:
            return {"error": "No model loaded"}

        # Reject bad ticks before they enter the buffer: a NaN/inf/non-positive
        # price would poison every feature window until it ages out (and NaN
        # makes np.histogram in StringTheoryFast.entropy_fast raise).
        if not np.isfinite(price) or price <= 0:
            return {"error": f"Invalid price: {price}"}

        # Initialize buffer if needed
        if pair not in self.price_buffers:
            self.price_buffers[pair] = deque(maxlen=max(120, self.lookback * 2))

        # Add price to buffer
        self.price_buffers[pair].append(price)

        # Check if we have enough data
        if len(self.price_buffers[pair]) < self.lookback:
            return {
                "signal": "WAIT",
                "action": 0.0,
                "confidence": 0.0,
                "value": 0.0,
                "pair": pair,
                "price": price,
                "position": current_position,
                "history_length": len(self.price_buffers[pair]),
                "message": f"Accumulating data: {len(self.price_buffers[pair])}/{self.lookback}"
            }

        # Compute features over the whole buffer, then keep the last lookback rows.
        prices = list(self.price_buffers[pair])
        features = self.compute_features(prices, current_position)
        features = features[-self.lookback:]

        # Batch of one: (1, lookback, 25)
        x = torch.FloatTensor(features).unsqueeze(0).to(self.device)

        with torch.no_grad():
            action, value = self.model.get_action(x, deterministic=True)

        action_val = action.item()
        value_val = value.item()

        # Map action to signal; +/-0.3 is the dead band around flat.
        if action_val > 0.3:
            signal = "LONG"
            confidence = min(1.0, (action_val - 0.3) / 0.7)
        elif action_val < -0.3:
            signal = "SHORT"
            confidence = min(1.0, (-action_val - 0.3) / 0.7)
        else:
            signal = "HOLD"
            confidence = 1.0 - abs(action_val) / 0.3

        return {
            "signal": signal,
            "action": action_val,
            "confidence": confidence,
            "value": value_val,
            "pair": pair,
            "price": price,
            "position": current_position,
            "history_length": len(self.price_buffers[pair]),
            "version": "V8_LITE",
        }

    def clear_history(self, pair: Optional[str] = None):
        """Clear price history buffer.

        A truthy ``pair`` empties only that pair's buffer (if present); a falsy
        value (None or "") drops all buffers.
        """
        if pair:
            if pair in self.price_buffers:
                self.price_buffers[pair].clear()
        else:
            self.price_buffers.clear()


# ============================================================
# FASTAPI APP
# ============================================================

app = FastAPI(
    version="8.0.0",
    title="ML Inference Server V8 Lite",
    description="Trading model inference API with 24 features (String Theory)",
)

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# web origin make credentialed cross-origin calls; tighten allow_origins if this
# API is reachable beyond trusted hosts.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# Global inference engine; stays None until the startup hook creates it.
engine: Optional[InferenceEngineV8] = None


# Request/Response models
class PredictRequest(BaseModel):
    # Body of POST /predict: one price tick for a trading pair.
    pair: str  # key of the engine's per-pair price history buffer
    price: float  # latest price; appended to that pair's history
    position: float = 0.0  # current position; fed to the model as input column 24


class PredictResponse(BaseModel):
    # Response of POST /predict; mirrors the dict returned by InferenceEngineV8.predict.
    signal: str  # "LONG", "SHORT", "HOLD", or "WAIT" while history is still accumulating
    action: float  # deterministic model action in [-1, 1]; 0.0 while waiting
    confidence: float  # in [0, 1]; 0.0 while waiting
    value: float  # value-head estimate; 0.0 while waiting
    pair: str
    price: float
    position: float
    history_length: int  # number of prices currently buffered for the pair
    message: Optional[str] = None  # progress text, set only on WAIT responses
    error: Optional[str] = None  # NOTE(review): /predict raises on errors, so this looks unused
    version: Optional[str] = None  # "V8_LITE" on real predictions


class LoadModelRequest(BaseModel):
    # Body of POST /load-model.
    # NOTE(review): this client-controlled path reaches torch.load(weights_only=False).
    model_path: str  # server-side filesystem path of the checkpoint to load


class HealthResponse(BaseModel):
    # Response of GET /health.
    status: str  # "healthy" whenever the handler runs
    model_loaded: bool
    device: str  # torch device string, or "N/A" before the engine exists
    uptime: str  # str() of the timedelta since startup_time
    version: str


# Startup time: captured once at module import (local, naive datetime); GET /health
# reports uptime relative to it.
startup_time = datetime.now()


@app.on_event("startup")
async def startup_event():
    """Create the global engine and try to load the default model.

    A missing or unloadable default model is only logged; the server still
    starts, and /predict reports "No model loaded" until /load-model succeeds.
    """
    global engine
    engine = InferenceEngineV8(CONFIG)

    default_model = os.path.join(CONFIG["model_dir"], CONFIG["default_model"])
    if not os.path.exists(default_model):
        logger.warning(f"Default model not found: {default_model}")
    elif engine.load_model(default_model):
        # Only report success when load_model actually succeeded.
        logger.info(f"Default V8 model loaded: {default_model}")
    else:
        logger.warning(f"Default model failed to load: {default_model}")


@app.get("/health", response_model=HealthResponse)
async def health():
    """Report liveness, whether a model is loaded, the torch device and uptime."""
    ready = engine is not None
    return {
        "status": "healthy",
        "model_loaded": ready and engine.model is not None,
        "device": str(engine.device) if ready else "N/A",
        "uptime": str(datetime.now() - startup_time),
        "version": "V8_LITE",
    }


@app.get("/model-info")
async def model_info():
    """Return the loaded model's metadata, or an error payload (HTTP 200) if none is loaded."""
    if engine is not None and engine.model is not None:
        return engine.model_info
    return {"error": "No model loaded"}


@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """Feed one price tick to the engine and return its trading signal.

    Raises:
        HTTPException: 500 if the engine is not initialized; 400 carrying the
            engine's message when it returns an error (e.g. no model loaded).
    """
    if engine is None:
        raise HTTPException(status_code=500, detail="Engine not initialized")

    outcome = engine.predict(request.pair, request.price, request.position)
    if "error" not in outcome:
        return outcome
    raise HTTPException(status_code=400, detail=outcome["error"])


@app.post("/load-model")
async def load_model(request: LoadModelRequest):
    """Load the checkpoint at ``request.model_path`` and return its metadata.

    Raises:
        HTTPException: 500 if the engine is not initialized; 400 if loading fails
            (the previously loaded model, if any, is kept by the engine).
    """
    if engine is None:
        raise HTTPException(status_code=500, detail="Engine not initialized")

    path = request.model_path
    if engine.load_model(path):
        return {"success": True, "model_info": engine.model_info}
    raise HTTPException(status_code=400, detail=f"Failed to load model: {path}")


@app.post("/clear-history")
async def clear_history(pair: Optional[str] = None):
    """Clear one pair's price buffer (``?pair=...``), or all buffers when omitted."""
    if engine is None:
        raise HTTPException(status_code=500, detail="Engine not initialized")

    engine.clear_history(pair)
    cleared = pair if pair else "all"
    return {"success": True, "cleared": cleared}


@app.get("/models")
async def list_models():
    """List ``*.pt`` checkpoint files in the configured model directory.

    Returns:
        ``{"models": [...], "directory": ...}`` when the directory is missing
        (or is not a directory); otherwise also ``"current"``, the path of the
        loaded model or None. File names are sorted for a stable listing.
    """
    model_dir = CONFIG["model_dir"]
    # isdir (not exists) so a stray file at this path cannot make listdir raise.
    if not os.path.isdir(model_dir):
        return {"models": [], "directory": model_dir}

    # os.listdir order is arbitrary; sort for deterministic output.
    models = sorted(name for name in os.listdir(model_dir) if name.endswith(".pt"))
    current = engine.model_info.get("path") if engine and engine.model_info else None
    return {
        "models": models,
        "directory": model_dir,
        "current": current,
    }


# ============================================================
# MAIN
# ============================================================

if __name__ == "__main__":
    print("=" * 60)
    print("ML INFERENCE SERVER V8 LITE")
    print("24 Features: 18 Technical + 6 String Theory")
    print("=" * 60)
    print(f"Host: {CONFIG['host']}")
    print(f"Port: {CONFIG['port']}")
    print(f"Model dir: {CONFIG['model_dir']}")
    print("=" * 60)

    # Pass the app object instead of the "inference_server_v8:app" import string:
    # the string only resolves if this file is named inference_server_v8.py, and
    # an import string is needed only for reload/workers (reload is off here).
    uvicorn.run(
        app,
        host=CONFIG["host"],
        port=CONFIG["port"],
        reload=False,
        log_level="info"
    )
