#!/usr/bin/env python3
"""
Backtest the V6 per-regime BTC/EUR models.
Trading signals come from the trained PyTorch checkpoints themselves, evaluated candle by candle over historical prices.
"""

import os
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
import psycopg2
import torch
import torch.nn as nn

# Config
MODELS_DIR = Path('/var/www/html/bestrading.cuttalo.com/models/btc_v6')
DATA_FILE = Path('/var/www/html/bestrading.cuttalo.com/scripts/prices_BTC_EUR_2025_full.csv')

# Database connection settings. The password is read from the environment
# instead of being committed to source control.
# NOTE(review): a credential used to be hard-coded here; treat it as leaked
# and rotate it.
DB_CONFIG = {
    'host': 'localhost',
    'port': 5432,
    'dbname': 'bestrading',
    'user': 'bestrading',
    'password': os.environ.get('BESTRADING_DB_PASSWORD'),
}

# Trading parameters
INITIAL_CAPITAL = 10000  # EUR
# Fraction of notional charged as fee on every fill (each buy and each sell),
# so a full round trip costs 2 * FEE_RATE.
# NOTE(review): this was previously described as "0.4% round-trip (0.2% per
# side)"; if that was the intent, the value should be 0.002 — confirm.
FEE_RATE = 0.004
SLIPPAGE = 0.0005  # 0.05% adverse price move applied to each fill
POSITION_SIZE = 0.30  # 30% of equity per trade
MIN_POSITION_CHANGE = 0.5  # Min absolute change of the model signal (range [-1, 1]) before re-trading
MIN_SIGNAL_INTERVAL = 120  # Min candles between trades (2 hours on 1-minute candles)
MIN_EXPECTED_PROFIT = 0.005  # 0.5% minimum expected move to justify fees

# Model architecture (must match training)
class TradingTransformer(nn.Module):
    """Transformer actor-critic policy; inference uses only the actor mean.

    Submodule names and construction order must stay as they are so saved
    ``state_dict`` checkpoints from training load without key mismatches.
    ``actor_log_std`` and ``critic`` are not used by ``forward`` but are part
    of the checkpoint.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 256, num_heads: int = 4, num_layers: int = 2):
        super().__init__()
        # One extra input channel carries the current position next to the features.
        self.embed = nn.Linear(input_dim + 1, hidden_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=4 * hidden_dim,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.actor_mean = nn.Linear(hidden_dim, 1)
        self.actor_log_std = nn.Parameter(torch.zeros(1) - 0.5)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the tanh-squashed action mean for the final step of each sequence."""
        encoded = self.transformer(self.embed(x))
        final_step = encoded[:, -1]
        return torch.tanh(self.actor_mean(final_step))

def compute_features(prices: np.ndarray, idx: int, lookback: int = 20) -> Optional[np.ndarray]:
    """Compute the 9-element feature vector for the candle at ``idx``.

    Feature order (``detect_regime`` reads entries by index):
      0-4  simple returns over 1, 5, 10, 20 and 60 candles
      5    std of the last 10 one-candle returns
      6    std of the last 30 one-candle returns
      7-8  relative distance of the current price from its 10- and 30-candle
           moving average (the average excludes the current candle)
    Every feature is clipped to [-3, 3]; the result is float32.

    Args:
        prices: 1-D array of close prices; only ``prices[:idx + 1]`` is read.
        idx: index of the current candle.
        lookback: unused; kept so existing callers keep working.

    Returns:
        The feature vector, or ``None`` when ``idx < 60`` (not enough history
        for the 60-candle return).
    """
    # All windows below reach back at most 60 candles, so after this guard
    # every index is non-negative.
    if idx < 60:
        return None

    current = prices[idx]
    features = []

    # Returns at multiple scales
    for period in (1, 5, 10, 20, 60):
        past = prices[idx - period]
        features.append((current - past) / (past + 1e-8))

    # Volatility: 31 prices -> 30 one-candle returns
    window = prices[idx - 30:idx + 1]
    one_step = np.diff(window) / (window[:-1] + 1e-8)
    features.append(np.std(one_step[-10:]))
    features.append(np.std(one_step))

    # Distance from moving averages
    for period in (10, 30):
        ma = np.mean(prices[idx - period:idx])
        features.append((current - ma) / (ma + 1e-8))

    # Clip outliers (no further normalization is applied)
    return np.clip(np.array(features, dtype=np.float32), -3, 3)

def detect_regime(features: np.ndarray) -> str:
    """Map a ``compute_features`` vector to a market-regime name.

    Only three entries are read: index 3 (20-candle return), index 4
    (60-candle return) and index 6 (std of the last 30 one-candle returns).
    Rules are tried in priority order and the first match wins. A ``None``
    vector, or one matching no rule, yields ``'scalper'``.

    Returns:
        One of ``'volatile'``, ``'bullish'``, ``'bearish'``, ``'ranging'``,
        ``'scalper'``.
    """
    if features is None:
        return 'scalper'

    short_trend, long_trend, volatility = features[3], features[4], features[6]

    # Thresholds are tuned for minute candles, where moves are small.
    rules = (
        ('volatile', lambda: volatility > 0.02),
        ('bullish', lambda: short_trend > 0.005 and long_trend > 0.003),
        ('bearish', lambda: short_trend < -0.005 and long_trend < -0.003),
        ('ranging', lambda: abs(long_trend) < 0.002 and volatility < 0.005),
    )
    return next((name for name, matches in rules if matches()), 'scalper')

def load_models(models_dir: Optional[Path] = None) -> Dict[str, TradingTransformer]:
    """Load every per-regime V6 checkpoint that exists on disk.

    Args:
        models_dir: directory containing ``model_<regime>_v6.pt`` files;
            defaults to ``MODELS_DIR``.

    Returns:
        Mapping of regime name -> ``TradingTransformer`` on CPU in eval mode.
        Regimes whose checkpoint file is missing are left out of the mapping
        and listed in a printed warning.
    """
    directory = MODELS_DIR if models_dir is None else Path(models_dir)
    device = torch.device('cpu')
    models = {}
    missing = []

    for regime in ('bullish', 'bearish', 'ranging', 'volatile', 'scalper'):
        path = directory / f'model_{regime}_v6.pt'
        if not path.exists():
            missing.append(regime)
            continue

        # NOTE(review): torch.load unpickles arbitrary objects; only point this
        # at trusted checkpoints (consider weights_only=True on newer torch).
        checkpoint = torch.load(path, map_location=device)
        # 9 = length of the compute_features vector, used when the checkpoint
        # does not record its input size.
        input_dim = checkpoint.get('input_dim', 9)

        model = TradingTransformer(input_dim).to(device)
        model.load_state_dict(checkpoint['model_state'])
        model.eval()

        models[regime] = model
        print(f"  Loaded {regime} model")

    if missing:
        print(f"  ⚠️ Missing model checkpoints: {', '.join(missing)} (looked in {directory})")

    return models

def _max_drawdown_pct(equities: np.ndarray) -> float:
    """Return the largest peak-to-trough decline of ``equities`` in percent (<= 0).

    Each point is compared with the running maximum up to that point, so a
    later all-time high does not inflate earlier drawdowns.
    """
    if equities.size == 0:
        return 0.0
    running_peak = np.maximum.accumulate(equities)
    return float(np.min((equities - running_peak) / running_peak) * 100)


def _annualized_sharpe(equities: np.ndarray, periods_per_year: float = 365 * 24 * 60) -> float:
    """Return the annualized Sharpe ratio (zero risk-free rate) of step-to-step equity returns.

    Steps whose starting equity is not positive are skipped. The default
    ``periods_per_year`` assumes one equity sample per minute.
    """
    start = equities[:-1]
    end = equities[1:]
    valid = start > 0
    if not np.any(valid):
        return 0.0
    returns = (end[valid] - start[valid]) / start[valid]
    return float(np.mean(returns) / (np.std(returns) + 1e-8) * np.sqrt(periods_per_year))


def run_backtest():
    """Run the regime-switching backtest, print a report and save the equity curve.

    Returns:
        A dict with ``total_return`` (%), ``max_drawdown`` (%), ``sharpe``,
        ``trades`` and ``total_fees`` (EUR), or ``None`` when models or data
        are insufficient to run.
    """
    print("=" * 70)
    print("🎯 BTC/EUR V6 REGIME MODELS BACKTEST")
    print("=" * 70)

    # Load models
    print("\n📦 Loading models...")
    models = load_models()

    if len(models) < 5:
        print("❌ Not all models loaded!")
        return

    # Load data
    print("\n📈 Loading price data...")
    df = pd.read_csv(DATA_FILE)
    prices = df['close'].values.astype(np.float32)
    timestamps = df['timestamp'].values

    print(f"   Loaded {len(prices):,} candles")

    warmup = 60  # compute_features needs 60 candles of history
    if len(prices) <= warmup + 1:
        print(f"❌ Not enough candles for a backtest: need more than {warmup + 1}, got {len(prices)}")
        return

    print(f"   Date range: {datetime.fromtimestamp(timestamps[0])} - {datetime.fromtimestamp(timestamps[-1])}")
    print(f"   Price range: €{prices.min():.0f} - €{prices.max():.0f}")

    # Trading simulation
    print("\n" + "=" * 70)
    print("📊 RUNNING SIMULATION...")
    print("=" * 70)

    # Cash accounting is done in Python floats; float32 prices would otherwise
    # pull the balance down to float32 precision under NumPy 2 promotion rules.
    capital = float(INITIAL_CAPITAL)
    position = 0.0  # BTC held (negative = short)
    total_fees = 0.0
    trades = []
    equity_curve = [(timestamps[warmup], capital)]

    lookback = 20
    regime_counts = {'bullish': 0, 'bearish': 0, 'ranging': 0, 'volatile': 0, 'scalper': 0}
    last_trade_idx = 0
    current_signal = 0.0  # Signal of the last executed trade

    for i in range(warmup + 1, len(prices)):
        current_price = float(prices[i])

        features = compute_features(prices, i, lookback)
        if features is None:
            continue

        regime = detect_regime(features)
        regime_counts[regime] += 1
        model = models.get(regime, models['scalper'])

        # Model input: features + position.
        # NOTE(review): the position is expressed relative to the BTC the cash
        # balance could buy, and the current feature row is repeated `lookback`
        # times instead of using real history. Both must match how the models
        # were trained — confirm.
        feat_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
        pos_tensor = torch.tensor([position / (capital / current_price + 1e-8)], dtype=torch.float32).unsqueeze(0)
        seq_feat = feat_tensor.unsqueeze(0).expand(1, lookback, -1)
        seq_pos = pos_tensor.unsqueeze(1).expand(1, lookback, 1)
        input_tensor = torch.cat([seq_feat, seq_pos], dim=-1)

        with torch.no_grad():
            target_position = model(input_tensor).item()  # [-1, 1]

        # Size against total equity (cash + marked-to-market position), as
        # POSITION_SIZE is defined; sizing on cash alone shrinks after buys
        # and grows after shorts.
        equity_now = capital + position * current_price
        max_position_btc = equity_now * POSITION_SIZE / current_price
        desired_btc = target_position * max_position_btc

        signal_change = abs(target_position - current_signal)
        time_since_trade = i - last_trade_idx

        # Expected move ~ 2 sigma of 1-candle returns scaled to the trade interval
        recent = prices[i - warmup:i]
        recent_returns = np.diff(recent) / recent[:-1]
        expected_move = np.std(recent_returns) * np.sqrt(MIN_SIGNAL_INTERVAL) * 2

        should_trade = (signal_change > MIN_POSITION_CHANGE and
                        time_since_trade >= MIN_SIGNAL_INTERVAL and
                        expected_move > MIN_EXPECTED_PROFIT)

        position_change = desired_btc - position
        if should_trade and abs(position_change) > 0.0001:  # Min trade size
            trade_value = abs(position_change) * current_price
            fee = trade_value * FEE_RATE
            executed = False

            if position_change > 0:  # Buying
                cost = trade_value * (1 + SLIPPAGE) + fee
                if cost <= capital:
                    capital -= cost
                    executed = True
            else:  # Selling
                capital += trade_value * (1 - SLIPPAGE) - fee
                executed = True

            # Only a filled order moves the position, resets the cooldown and is recorded.
            if executed:
                position = desired_btc
                total_fees += fee
                last_trade_idx = i
                current_signal = target_position
                trades.append({
                    'time': timestamps[i],
                    'price': current_price,
                    'position': position,
                    'capital': capital,
                    'regime': regime,
                    'target': target_position
                })

        equity = capital + position * current_price
        equity_curve.append((timestamps[i], equity))

        if i % 50000 == 0:
            print(f"   Progress: {i:,}/{len(prices):,} | Equity: €{equity:,.0f}")

    # Close final position; slippage is adverse in both directions.
    if position != 0:
        final_price = float(prices[-1])
        notional = abs(position) * final_price
        fee = notional * FEE_RATE
        if position > 0:
            capital += notional * (1 - SLIPPAGE) - fee
        else:
            capital -= notional * (1 + SLIPPAGE) + fee
        total_fees += fee
        position = 0.0

    final_equity = capital

    # Metrics
    equities = np.array([eq for _, eq in equity_curve], dtype=np.float64)
    total_return = (final_equity - INITIAL_CAPITAL) / INITIAL_CAPITAL * 100
    max_drawdown = _max_drawdown_pct(equities)
    sharpe = _annualized_sharpe(equities)

    print("\n" + "=" * 70)
    print("📊 RISULTATI BACKTEST")
    print("=" * 70)

    print(f"""
    Capitale iniziale:    €{INITIAL_CAPITAL:,.0f}
    Capitale finale:      €{final_equity:,.0f}
    ─────────────────────────────────────────
    Return totale:        {total_return:+.2f}%
    Max Drawdown:         {max_drawdown:.2f}%
    Sharpe Ratio (ann.):  {sharpe:.2f}
    ─────────────────────────────────────────
    Trades totali:        {len(trades):,}
    Fee pagate:           €{total_fees:,.0f}
    ─────────────────────────────────────────
    """)

    print("    📈 Regime Distribution:")
    total_regimes = sum(regime_counts.values()) or 1
    for regime, count in sorted(regime_counts.items(), key=lambda x: -x[1]):
        pct = count / total_regimes * 100
        print(f"       {regime:12s}: {count:6,} ({pct:5.1f}%)")

    # Save equity curve, sampled every 100 points
    output_file = MODELS_DIR / 'backtest_results.csv'
    with open(output_file, 'w') as f:
        f.write('timestamp,equity\n')
        f.writelines(f"{ts},{eq:.2f}\n" for ts, eq in equity_curve[::100])
    print(f"\n    💾 Equity curve saved to: {output_file}")

    return {
        'total_return': total_return,
        'max_drawdown': max_drawdown,
        'sharpe': sharpe,
        'trades': len(trades),
        'total_fees': total_fees,
    }


if __name__ == '__main__':
    run_backtest()
