#!/usr/bin/env python3
"""
V9 Realistic Backtest
=====================

Simula ESATTAMENTE il comportamento del paper trader V9:
- Candele 4h
- Volatility Gate + Trend Detection + Exit Optimizer
- Fee realistiche (0.5% per trade = 1% round-trip)
- Slippage simulato
- Periodi casuali di 10 giorni

Autore: Claude Code
"""

import pandas as pd
import numpy as np
import lightgbm as lgb
import torch
import torch.nn as nn
import joblib
from datetime import datetime
import random
import warnings
warnings.filterwarnings('ignore')

# Paths to the trained model artifacts and the pre-computed feature table.
MODEL_DIR = '/var/www/html/bestrading.cuttalo.com/models/btc_v9'
FEATURES_PATH = f'{MODEL_DIR}/data_4h_features.csv'

# Trading parameters (EXACTLY as in the paper trader)
INITIAL_CAPITAL = 10000
POSITION_SIZE = 0.20  # 20% of capital per trade
FEE_RATE = 0.005  # 0.5% per side (Kraken taker)
SLIPPAGE_BPS = 2  # 2 basis points of slippage
CONFIDENCE_THRESHOLD = 0.70  # minimum 70% confidence to enter

# Feature lists (from training)
# Columns selected from each candle row, in this order, and passed through
# the volatility scaler before the volatility gate model.
VOL_FEATURES = [
    'vol_3', 'vol_6', 'vol_12', 'vol_24',
    'atr_3', 'atr_6', 'atr_12',
    'vol_regime', 'range_pct', 'range_sma',
    'volume_ratio',
    'hour_sin', 'hour_cos',
    'ret_1', 'ret_3', 'ret_6',
]

# Columns selected from each candle row, in this order, and passed through
# the trend scaler before the long and short models.
TREND_FEATURES = [
    'ret_1', 'ret_2', 'ret_3', 'ret_6', 'ret_12', 'ret_24',
    'sma_cross_6_12', 'sma_cross_12_24', 'sma_cross_24_48',
    'price_vs_sma12', 'price_vs_sma24',
    'rsi_norm', 'roc_3', 'roc_6', 'roc_12',
    'vol_6', 'vol_12', 'vol_regime',
    'atr_6', 'atr_12',
    'volume_ratio',
    'hour_sin', 'hour_cos',
]


# Exit Optimizer Network (EXACT same as training)
class ExitOptimizerNet(nn.Module):
    """Two-headed MLP that scores "hold" versus "exit" for an open position.

    A shared trunk of two Linear -> LayerNorm -> ReLU -> Dropout stages feeds
    a 2-way policy head (index 0 = hold, index 1 = exit) and a scalar value
    head.
    """

    def __init__(self, input_dim=8, hidden_dim=128):
        super().__init__()
        # The layer order determines the state_dict keys (net.0, net.1, ...)
        # that saved checkpoints rely on, so it must not be rearranged.
        trunk_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        ]
        self.net = nn.Sequential(*trunk_layers)
        self.policy = nn.Linear(hidden_dim, 2)  # Hold, Exit
        self.value = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return ``(action_logits, value)`` for a batch of state vectors."""
        hidden = self.net(x)
        return self.policy(hidden), self.value(hidden)

    def get_action(self, x, deterministic=True):
        """Pick an action for each state in the batch.

        Args:
            x: Batch of state vectors accepted by ``forward``.
            deterministic: When true take the most probable action,
                otherwise sample one from the policy distribution.

        Returns:
            Tuple ``(action, probs, value)``: the chosen action indices, the
            softmax probabilities over the two actions, and the value head
            output.
        """
        action_logits, value = self.forward(x)
        probs = torch.softmax(action_logits, dim=-1)
        if not deterministic:
            sampled = torch.multinomial(probs, 1).squeeze(-1)
            return sampled, probs, value
        return probs.argmax(dim=-1), probs, value


def load_models():
    """Load every V9 model and scaler from ``MODEL_DIR``.

    Returns:
        Tuple ``(vol_gate, vol_scaler, long_model, short_model, trend_scaler,
        exit_model)``. The exit model is returned in eval mode.
    """
    print("Loading V9 models...")

    # Volatility gate and its scaler.
    vol_gate = lgb.Booster(model_file=f'{MODEL_DIR}/vol_gate_model.txt')
    vol_scaler = joblib.load(f'{MODEL_DIR}/vol_scaler.pkl')

    # Long / short trend detectors share one scaler.
    long_model = lgb.Booster(model_file=f'{MODEL_DIR}/long_model.txt')
    short_model = lgb.Booster(model_file=f'{MODEL_DIR}/short_model.txt')
    trend_scaler = joblib.load(f'{MODEL_DIR}/trend_scaler.pkl')

    # Exit optimizer: the checkpoint is either a dict wrapping the weights
    # under 'model_state_dict' or the bare state dict itself.
    exit_model = ExitOptimizerNet()
    checkpoint = torch.load(
        f'{MODEL_DIR}/exit_optimizer_best.pt',
        map_location='cpu',
        weights_only=False,
    )
    wrapped = 'model_state_dict' in checkpoint
    exit_model.load_state_dict(checkpoint['model_state_dict'] if wrapped else checkpoint)
    exit_model.eval()

    print("  All models and scalers loaded successfully")
    return vol_gate, vol_scaler, long_model, short_model, trend_scaler, exit_model


def load_data():
    """Read the pre-computed feature table and index it by timestamp.

    Returns:
        DataFrame read from ``FEATURES_PATH`` whose index is the parsed
        ``timestamp`` column.
    """
    print("Loading pre-computed features...")
    raw = pd.read_csv(FEATURES_PATH)
    raw['timestamp'] = pd.to_datetime(raw['timestamp'])
    df = raw.set_index('timestamp')

    first_ts, last_ts = df.index[0], df.index[-1]
    print(f"  Total candles: {len(df)}")
    print(f"  Date range: {first_ts} to {last_ts}")

    return df


def get_signal(vol_gate, vol_scaler, long_model, short_model, trend_scaler, row):
    """Turn one candle's features into a trading signal.

    The volatility gate is evaluated first; only when it is open (probability
    of at least 0.5) are the long and short trend models consulted.

    Args:
        vol_gate: Model exposing ``predict`` for the volatility gate.
        vol_scaler: Scaler exposing ``transform`` for ``VOL_FEATURES``.
        long_model: Model exposing ``predict`` for the long side.
        short_model: Model exposing ``predict`` for the short side.
        trend_scaler: Scaler exposing ``transform`` for ``TREND_FEATURES``.
        row: One row of the feature table.

    Returns:
        Tuple ``(signal, confidence, reason, vol_prob)`` where ``signal`` is
        ``'LONG'``, ``'SHORT'`` or ``'HOLD'``.
    """
    # Stage 1: volatility gate.
    vol_inputs = vol_scaler.transform(row[VOL_FEATURES].values.reshape(1, -1))
    vol_prob = vol_gate.predict(vol_inputs)[0]
    if vol_prob < 0.5:
        return 'HOLD', 0, f'Vol gate closed ({vol_prob:.1%})', vol_prob

    # Stage 2: directional models on a shared scaled feature vector.
    trend_inputs = trend_scaler.transform(row[TREND_FEATURES].values.reshape(1, -1))
    long_prob = long_model.predict(trend_inputs)[0]
    short_prob = short_model.predict(trend_inputs)[0]

    # A side wins only if it clears the threshold and strictly beats the other.
    if long_prob >= CONFIDENCE_THRESHOLD and long_prob > short_prob:
        return 'LONG', long_prob, f'Long signal ({long_prob:.1%})', vol_prob
    if short_prob >= CONFIDENCE_THRESHOLD and short_prob > long_prob:
        return 'SHORT', short_prob, f'Short signal ({short_prob:.1%})', vol_prob

    best_prob = max(long_prob, short_prob)
    return 'HOLD', best_prob, f'No confident signal (L:{long_prob:.1%} S:{short_prob:.1%})', vol_prob


def should_exit(exit_model, position, entry_price, current_price, bars_held, max_profit_seen, row):
    """Ask the exit optimizer whether the open position should be closed.

    Args:
        exit_model: Object exposing ``get_action(state, deterministic=...)``.
        position: 1 for long, -1 for short, 0 when flat.
        entry_price: Price at which the position was opened.
        current_price: Price of the current candle.
        bars_held: Number of candles the position has been open.
        max_profit_seen: Highest fractional PnL observed so far.
        row: Mapping-like candle row; ``vol_6``, ``ret_1`` and ``rsi_norm``
            are read with defaults when absent.

    Returns:
        Tuple ``(should_exit_now, exit_probability, max_profit_seen)``. When
        flat this is ``(False, 0, max_profit_seen)`` and the model is not
        called.
    """
    if position == 0:
        return False, 0, max_profit_seen

    # Fractional PnL of the open position.
    if position == 1:
        pnl_pct = (current_price / entry_price - 1)
    else:
        pnl_pct = (entry_price / current_price - 1)

    # Running peak of the PnL.
    max_profit_seen = max(max_profit_seen, pnl_pct)

    # State vector (EXACTLY like inference server).
    features = [
        pnl_pct,                          # Current PnL
        max_profit_seen,                  # Max profit seen
        pnl_pct - max_profit_seen,        # Drawdown from max
        bars_held / 12,                   # Normalized time (12 = 48h)
        row.get('vol_6', 0.5),            # Volatility
        row.get('ret_1', 0),              # Recent return
        row.get('rsi_norm', 0),           # RSI normalized
        position,                         # Direction
    ]
    state = torch.tensor([features], dtype=torch.float32)

    with torch.no_grad():
        action, probs, _value = exit_model.get_action(state, deterministic=True)

    # The model's "exit" action, or a forced exit after 12 candles (48h).
    timed_out = bars_held >= 12
    exit_now = action.item() == 1 or timed_out

    return exit_now, probs[0, 1].item(), max_profit_seen


def _open_position(capital, current_price, direction):
    """Size and price a new position of ``direction`` (1 long, -1 short).

    Returns:
        Tuple ``(entry_price, btc_amount, entry_fee)``.
    """
    trade_capital = capital * POSITION_SIZE
    # Slippage works against us: longs fill higher, shorts fill lower.
    entry_price = current_price * (1 + direction * SLIPPAGE_BPS / 10000)
    entry_fee = trade_capital * FEE_RATE
    # The entry fee is paid out of the allocated capital, so it reduces size.
    btc_amount = (trade_capital - entry_fee) / entry_price
    return entry_price, btc_amount, entry_fee


def _close_position(position, entry_price, btc_amount, entry_fee, current_price):
    """Price the exit of an open position and compute its net PnL.

    Returns:
        Tuple ``(exit_price, pnl)`` where ``pnl`` is net of BOTH the entry
        and the exit fee.
    """
    # Slippage works against us: longs sell lower, shorts buy back higher.
    exit_price = current_price * (1 - position * SLIPPAGE_BPS / 10000)
    exit_fee = abs(btc_amount * exit_price) * FEE_RATE
    gross_pnl = position * btc_amount * (exit_price - entry_price)
    return exit_price, gross_pnl - exit_fee - entry_fee


def _trade_record(entry_time, exit_time, position, entry_price, exit_price, pnl, bars_held):
    """Build the dict describing one closed trade."""
    return {
        'entry_time': entry_time,
        'exit_time': exit_time,
        'direction': 'LONG' if position == 1 else 'SHORT',
        'entry_price': entry_price,
        'exit_price': exit_price,
        'pnl': pnl,
        'pnl_pct': pnl / INITIAL_CAPITAL * 100,
        'bars_held': bars_held,
    }


def run_backtest(df, vol_gate, vol_scaler, long_model, short_model, trend_scaler, exit_model, start_idx, num_candles=60):
    """Run the strategy over ``num_candles`` candles starting at ``start_idx``.

    On every candle an open position is first checked for exit; a new
    position is only opened on a candle where the account was already flat
    (no re-entry on the candle that closed a trade). Any position still open
    on the last candle is closed there and flagged with ``forced_close``.

    Capital is never debited at entry; it is only adjusted by each trade's
    net PnL (gross move minus entry and exit fees) when the trade closes.

    Args:
        df: Feature table with a ``close`` column, one row per candle.
        vol_gate, vol_scaler, long_model, short_model, trend_scaler:
            Passed through to ``get_signal``.
        exit_model: Passed through to ``should_exit``.
        start_idx: Positional index of the first candle of the window.
        num_candles: Window length in candles.

    Returns:
        ``None`` when the requested window does not fit inside ``df`` (or is
        empty); otherwise a dict with the keys ``period_start``,
        ``period_end``, ``num_trades``, ``wins``, ``losses``, ``win_rate``,
        ``total_pnl``, ``total_pnl_pct``, ``avg_pnl``, ``final_capital`` and
        ``trades``.
    """
    end_idx = start_idx + num_candles
    # Reject windows that are empty or fall outside the data; a negative
    # start would otherwise silently wrap around via iloc.
    if start_idx < 0 or num_candles <= 0 or end_idx > len(df):
        return None

    period_df = df.iloc[start_idx:end_idx].copy()

    # Trading state
    capital = INITIAL_CAPITAL
    position = 0  # 0=flat, 1=long, -1=short
    entry_price = 0
    entry_fee = 0
    entry_time = None
    btc_amount = 0
    bars_held = 0
    max_profit_seen = 0
    trades = []

    for i, timestamp in enumerate(period_df.index):
        row = period_df.iloc[i]
        current_price = row['close']

        if position != 0:
            bars_held += 1
            should_exit_now, _exit_prob, max_profit_seen = should_exit(
                exit_model, position, entry_price, current_price, bars_held,
                max_profit_seen, row
            )
            if should_exit_now:
                exit_price, pnl = _close_position(
                    position, entry_price, btc_amount, entry_fee, current_price
                )
                # Same rule for longs and shorts: equity moves by the net PnL.
                capital += pnl
                trades.append(_trade_record(
                    entry_time, timestamp, position, entry_price, exit_price,
                    pnl, bars_held
                ))
                position = 0
                btc_amount = 0
                entry_price = 0
                entry_fee = 0
                entry_time = None
                bars_held = 0
                max_profit_seen = 0
            # Whether we kept holding or just closed, no entry on this candle.
            continue

        # Entry logic (only reached when flat at the start of the candle)
        signal, _confidence, _reason, _vol_prob = get_signal(
            vol_gate, vol_scaler, long_model, short_model, trend_scaler, row
        )
        if signal == 'LONG':
            direction = 1
        elif signal == 'SHORT':
            direction = -1
        else:
            continue

        entry_price, btc_amount, entry_fee = _open_position(
            capital, current_price, direction
        )
        position = direction
        entry_time = timestamp
        bars_held = 0
        max_profit_seen = 0

    # Close any position still open at the end of the window.
    if position != 0:
        exit_price, pnl = _close_position(
            position, entry_price, btc_amount, entry_fee,
            period_df.iloc[-1]['close']
        )
        capital += pnl
        record = _trade_record(
            entry_time, period_df.index[-1], position, entry_price,
            exit_price, pnl, bars_held
        )
        record['forced_close'] = True
        trades.append(record)

    # Metrics (same schema whether or not any trade happened)
    num_trades = len(trades)
    wins = sum(1 for t in trades if t['pnl'] > 0)
    total_pnl = sum(t['pnl'] for t in trades)

    return {
        'period_start': period_df.index[0],
        'period_end': period_df.index[-1],
        'num_trades': num_trades,
        'wins': wins,
        'losses': num_trades - wins,
        'win_rate': wins / num_trades * 100 if num_trades else 0,
        'total_pnl': total_pnl,
        'total_pnl_pct': total_pnl / INITIAL_CAPITAL * 100 if num_trades else 0,
        'avg_pnl': total_pnl / num_trades if num_trades else 0,
        'final_capital': capital,
        'trades': trades,
    }


def main(num_tests=20, period_days=10, seed=None):
    """Run a batch of random-window backtests and print a summary.

    Args:
        num_tests: Number of random windows to backtest.
        period_days: Length of each window in days (6 four-hour candles per
            day).
        seed: Optional seed for the ``random`` module so the chosen windows
            are reproducible; when ``None`` the generator state is left
            untouched.

    Returns:
        List of result dicts as returned by ``run_backtest``, one per
        completed window. Empty when the dataset is shorter than one window.
    """
    print("=" * 70)
    print("V9 REALISTIC BACKTEST")
    print("=" * 70)
    print(f"Initial Capital: €{INITIAL_CAPITAL:,.2f}")
    print(f"Position Size: {POSITION_SIZE*100:.0f}%")
    print(f"Fee Rate: {FEE_RATE*100:.2f}% per side")
    print(f"Slippage: {SLIPPAGE_BPS} bps")
    print(f"Confidence Threshold: {CONFIDENCE_THRESHOLD*100:.0f}%")
    print("=" * 70)

    # Load models
    vol_gate, vol_scaler, long_model, short_model, trend_scaler, exit_model = load_models()

    # Load data
    df = load_data()

    candles_per_period = period_days * 6  # 6 candles per day (4h)

    # Last start index that still leaves a full window inside the data.
    max_start = len(df) - candles_per_period
    if max_start < 0:
        # random.randint would raise ValueError on an empty range.
        print(f"\nNot enough data: need {candles_per_period} candles, have {len(df)}")
        print("=" * 70)
        return []

    if seed is not None:
        random.seed(seed)

    print(f"\nRunning {num_tests} random {period_days}-day backtests...")
    print("-" * 70)

    all_results = []

    for test_num in range(num_tests):
        # Random start (randint is inclusive on both ends)
        start_idx = random.randint(0, max_start)

        result = run_backtest(df, vol_gate, vol_scaler, long_model, short_model, trend_scaler, exit_model, start_idx, candles_per_period)

        if not result:
            continue
        all_results.append(result)

        # ANSI green for a non-negative PnL, red otherwise.
        pnl_color = '\033[92m' if result['total_pnl'] >= 0 else '\033[91m'
        reset = '\033[0m'

        print(f"Test {test_num+1:2d}: {result['period_start'].strftime('%Y-%m-%d')} to {result['period_end'].strftime('%Y-%m-%d')} | "
              f"Trades: {result['num_trades']:2d} | "
              f"Win Rate: {result['win_rate']:5.1f}% | "
              f"PnL: {pnl_color}€{result['total_pnl']:+8.2f}{reset} ({result['total_pnl_pct']:+5.2f}%)")

    # Summary
    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)

    if all_results:
        total_trades = sum(r['num_trades'] for r in all_results)
        # Results for windows without trades may not carry a 'wins' key.
        total_wins = sum(r.get('wins', 0) for r in all_results)
        total_pnl = sum(r['total_pnl'] for r in all_results)
        avg_pnl_per_test = total_pnl / len(all_results)
        profitable_periods = sum(1 for r in all_results if r['total_pnl'] > 0)

        print(f"Tests run:          {len(all_results)}")
        print(f"Total trades:       {total_trades}")
        if total_trades > 0:
            print(f"Total wins:         {total_wins} ({total_wins/total_trades*100:.1f}% win rate)")
            print(f"Avg trades/period:  {total_trades/len(all_results):.1f}")
        print(f"Profitable periods: {profitable_periods}/{len(all_results)} ({profitable_periods/len(all_results)*100:.1f}%)")
        print(f"Total PnL:          €{total_pnl:+.2f}")
        print(f"Avg PnL per test:   €{avg_pnl_per_test:+.2f}")
        if total_trades > 0:
            print(f"Avg PnL per trade:  €{total_pnl/total_trades:+.2f}")

        # Best and worst
        best = max(all_results, key=lambda x: x['total_pnl'])
        worst = min(all_results, key=lambda x: x['total_pnl'])

        print(f"\nBest period:  {best['period_start'].strftime('%Y-%m-%d')} to {best['period_end'].strftime('%Y-%m-%d')} | €{best['total_pnl']:+.2f}")
        print(f"Worst period: {worst['period_start'].strftime('%Y-%m-%d')} to {worst['period_end'].strftime('%Y-%m-%d')} | €{worst['total_pnl']:+.2f}")

    print("=" * 70)

    return all_results


# Script entry point: run the backtest batch only when executed directly,
# keeping the list of per-period results in the module-level name `results`.
if __name__ == '__main__':
    results = main()
