#!/usr/bin/env python3
"""
Backtest V6 Regime Models - Regime Transition Strategy
Only trades when regime changes, not on every signal.
"""

from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

# Config
# NOTE(review): absolute, deployment-specific paths — adjust when running elsewhere.
# MODELS_DIR holds the per-regime checkpoints (model_<regime>_v6.pt); the backtest's
# equity-curve CSV is also written there.
MODELS_DIR = Path('/var/www/html/bestrading.cuttalo.com/models/btc_v6')
# Price history CSV; the backtest reads its 'timestamp' and 'close' columns.
DATA_FILE = Path('/var/www/html/bestrading.cuttalo.com/scripts/prices_BTC_EUR_2025_full.csv')

# Trading parameters
INITIAL_CAPITAL = 10000  # EUR
FEE_RATE = 0.004  # 0.4% of traded notional, charged on every fill (entry and exit each pay it)
SLIPPAGE = 0.0005  # 0.05% adverse price adjustment applied to fills
POSITION_SIZE = 0.50  # fraction of the account used to size each entry (larger positions, fewer trades)
# A regime transition is only traded if the regime being LEFT lasted at least this many
# bars (60 minutes if the data is 1-minute candles).
MIN_REGIME_DURATION = 60

# Model architecture
class TradingTransformer(nn.Module):
    """Transformer policy network whose forward pass yields only the actor mean.

    Every timestep carries ``input_dim`` market features plus one extra position
    column, hence the ``input_dim + 1`` embedding width.  ``actor_log_std`` and
    ``critic`` are never touched by ``forward``; presumably they are kept so that
    checkpoints containing those keys load with ``load_state_dict`` (strict by
    default) — TODO confirm against the training code.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 256, num_heads: int = 4, num_layers: int = 2):
        super().__init__()
        self.embed = nn.Linear(input_dim + 1, hidden_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=4 * hidden_dim,
                dropout=0.1,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        self.actor_mean = nn.Linear(hidden_dim, 1)
        self.actor_log_std = nn.Parameter(torch.zeros(1) - 0.5)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch-first sequence to a tanh-squashed action of shape (batch, 1)."""
        encoded = self.transformer(self.embed(x))
        final_step = encoded[:, -1]  # representation of the last timestep only
        return torch.tanh(self.actor_mean(final_step))

def compute_features(prices: np.ndarray, idx: int, lookback: int = 20) -> Optional[np.ndarray]:
    """Build the 9-element feature vector for bar ``idx`` from past closing prices.

    Layout (``detect_regime`` reads indices 3, 4 and 6):
        0-4  simple returns over 1, 5, 10, 20 and 60 bars
        5    std of the last 10 one-bar returns
        6    std of the last 30 one-bar returns
        7-8  relative distance of the price from its 10- and 30-bar mean,
             where the mean covers the bars *before* ``idx``

    Every value is clipped to [-3, 3] and returned as float32.

    Args:
        prices: 1-D array of closing prices; only ``prices[:idx + 1]`` is read.
        idx: index of the bar to featurize.
        lookback: unused; kept so existing call sites keep working.

    Returns:
        The feature vector, or None when ``idx < 60`` (not enough history for
        the 60-bar return).
    """
    if idx < 60:
        return None

    # With idx >= 60 every look-back below is fully in range, so no per-feature
    # boundary checks are needed.
    current = prices[idx]
    features = []

    # Returns at multiple scales
    for period in (1, 5, 10, 20, 60):
        past = prices[idx - period]
        features.append((current - past) / (past + 1e-8))

    # Volatility: std of one-bar returns over the last 10 and the last 30 bars
    window = prices[idx - 30:idx + 1]
    one_bar = np.diff(window) / (window[:-1] + 1e-8)
    features.append(np.std(one_bar[-10:]))
    features.append(np.std(one_bar))

    # MA ratios (the moving average excludes the current bar)
    for period in (10, 30):
        ma = np.mean(prices[idx - period:idx])
        features.append((current - ma) / (ma + 1e-8))

    return np.clip(np.array(features, dtype=np.float32), -3, 3)

def detect_regime(features: np.ndarray) -> str:
    """Classify the market regime from a ``compute_features`` vector.

    Reads the 20-bar return (index 3), the 60-bar return (index 4) and the
    30-bar volatility (index 6).  Rules are checked in order and the first
    match wins; a missing vector or no match yields 'scalper'.
    """
    if features is None:
        return 'scalper'

    ret20, ret60, vol30 = features[3], features[4], features[6]

    # Order matters: high volatility overrides any trend signal.
    rules = (
        ('volatile', vol30 > 0.015),
        ('bullish', ret20 > 0.003 and ret60 > 0.002),
        ('bearish', ret20 < -0.003 and ret60 < -0.002),
        ('ranging', abs(ret60) < 0.001 and vol30 < 0.003),
    )
    return next((name for name, matched in rules if matched), 'scalper')

def load_models():
    """Load each per-regime checkpoint found in MODELS_DIR onto the CPU.

    Returns a dict mapping regime name to a TradingTransformer in eval mode.
    Regimes whose ``model_<regime>_v6.pt`` file does not exist are simply left
    out of the dict.
    """
    cpu = torch.device('cpu')
    loaded = {}

    for regime in ('bullish', 'bearish', 'ranging', 'volatile', 'scalper'):
        ckpt_path = MODELS_DIR / f'model_{regime}_v6.pt'
        if not ckpt_path.exists():
            continue

        # NOTE: torch.load deserializes via pickle (subject to the installed torch
        # version's weights_only default) — only point MODELS_DIR at trusted files.
        ckpt = torch.load(ckpt_path, map_location=cpu)
        net = TradingTransformer(ckpt.get('input_dim', 9)).to(cpu)
        net.load_state_dict(ckpt['model_state'])
        net.eval()

        loaded[regime] = net
        print(f"  Loaded {regime} model")

    return loaded

def _execute_fill(
    capital: float,
    position: float,
    target_btc: float,
    price: float,
    require_cash: bool = False,
) -> Optional[Tuple[float, float]]:
    """Move the BTC position to ``target_btc`` at ``price``; return ``(capital, position)``.

    Buying pays ``price * (1 + SLIPPAGE)`` and selling receives ``price * (1 - SLIPPAGE)``;
    both pay ``FEE_RATE`` on the traded notional.  A negative position is a short:
    opening it is a sale (cash credited) and covering it is a purchase (cash debited),
    so ``capital + position * price`` is always the marked-to-market equity.

    Returns None (caller keeps its state) only when ``require_cash`` is set and a
    purchase would cost more than ``capital``; otherwise always returns a tuple.
    """
    delta = target_btc - position
    notional = abs(delta) * price
    fee = notional * FEE_RATE
    if delta > 0:  # buying (opening/adding a long or covering a short)
        cost = notional * (1 + SLIPPAGE) + fee
        if require_cash and cost > capital:
            return None
        return capital - cost, target_btc
    if delta < 0:  # selling (closing a long or opening a short)
        return capital + notional * (1 - SLIPPAGE) - fee, target_btc
    return capital, position


def _transition_target(regime: str, last_trade_regime: str) -> Optional[float]:
    """Return the target direction for a regime transition, or None for no trade.

    1.0 = go long, -1.0 = go short, 0.0 = flatten.
    """
    if regime == 'bullish' and last_trade_regime != 'bullish':
        return 1.0
    if regime == 'bearish' and last_trade_regime != 'bearish':
        return -1.0
    if last_trade_regime in ('bullish', 'bearish') and regime in ('ranging', 'scalper', 'volatile'):
        return 0.0
    return None


def _model_signal(model: nn.Module, features: np.ndarray, exposure: float, seq_len: int = 20) -> float:
    """Run ``model`` on the current snapshot and return its scalar output.

    The input is ``features`` with ``exposure`` appended as one extra column,
    repeated ``seq_len`` times into a (1, seq_len, n_features + 1) batch.
    NOTE(review): the model therefore sees no real price history — confirm this
    matches how the checkpoints were trained.
    """
    step = torch.cat([
        torch.as_tensor(features, dtype=torch.float32),
        torch.tensor([exposure], dtype=torch.float32),
    ])
    batch = step.expand(seq_len, -1).unsqueeze(0)
    with torch.no_grad():
        return model(batch).item()


def _max_drawdown_pct(equity_values) -> float:
    """Return the worst peak-to-trough decline in percent (<= 0), using the running peak.

    Assumes the first value is positive (the curve starts at INITIAL_CAPITAL), so every
    running peak is positive.
    """
    equity = np.asarray(equity_values, dtype=np.float64)
    peaks = np.maximum.accumulate(equity)
    return float(np.min((equity - peaks) / peaks)) * 100


def run_backtest():
    """Simulate the regime-transition strategy over DATA_FILE and report the results.

    A trade is considered only when the detected regime changes AND the regime being
    left lasted at least MIN_REGIME_DURATION bars:
      * entering 'bullish' / 'bearish' opens a long / short sized at POSITION_SIZE of
        equity, but only if the regime's model (falling back to 'scalper') agrees with
        the direction and, for purchases, cash covers the cost;
      * leaving a traded trend for 'ranging' / 'scalper' / 'volatile' flattens.
    Any open position is closed at the final price.  Prints a report, writes every
    100th equity-curve point to MODELS_DIR / 'backtest_regime_transition.csv', and
    returns a dict with 'total_return' (%), 'max_drawdown' (%, <= 0), 'sharpe'
    (annualized from per-bar returns assuming 1-minute bars) and 'trades' (count).

    Raises:
        ValueError: if DATA_FILE has too few candles to backtest.
        RuntimeError: if an entry needs a model and neither the regime's own model
            nor the 'scalper' fallback was loaded.
    """
    print("=" * 70)
    print("REGIME TRANSITION STRATEGY BACKTEST")
    print("=" * 70)

    models = load_models()

    # Load data
    df = pd.read_csv(DATA_FILE)
    prices = df['close'].values.astype(np.float32)
    timestamps = df['timestamp'].values
    first_bar = 61  # compute_features needs 60 bars of history; bar 60 seeds the equity curve
    if len(prices) <= first_bar:
        raise ValueError(f"Need more than {first_bar} candles to backtest, got {len(prices)} in {DATA_FILE}")

    print(f"\n   Data: {len(prices):,} candles")
    print(f"   Period: {datetime.fromtimestamp(timestamps[0])} - {datetime.fromtimestamp(timestamps[-1])}")

    # Trading simulation.  Cash and prices are handled as Python floats: with float32
    # prices, NumPy 2 scalar promotion would otherwise keep the cash ledger in float32.
    capital = float(INITIAL_CAPITAL)
    position = 0.0  # BTC held; negative means short
    trades = []
    equity_curve = [(timestamps[first_bar - 1], capital)]

    current_regime = 'scalper'
    regime_start_idx = first_bar
    last_trade_regime = 'scalper'

    regime_counts = {'bullish': 0, 'bearish': 0, 'ranging': 0, 'volatile': 0, 'scalper': 0}

    for i in range(first_bar, len(prices)):
        current_price = float(prices[i])
        features = compute_features(prices, i, 20)
        if features is None:
            continue

        regime = detect_regime(features)
        regime_counts[regime] += 1

        if regime != current_regime:
            # Only act when the regime being left lasted long enough (filters flicker).
            if i - regime_start_idx >= MIN_REGIME_DURATION:
                target = _transition_target(regime, last_trade_regime)
                equity = capital + position * current_price

                if target == 0.0:
                    if position != 0:
                        # Exits are unconditional and must cover shorts even if cash is short.
                        capital, position = _execute_fill(capital, position, 0.0, current_price)
                        last_trade_regime = regime
                        trades.append({
                            'time': timestamps[i],
                            'price': current_price,
                            'position': position,
                            'regime': regime,
                            'type': 'EXIT',
                        })

                elif target is not None:
                    model = models.get(regime)
                    if model is None:
                        model = models.get('scalper')
                    if model is None:
                        raise RuntimeError(
                            f"No model for regime {regime!r} and no 'scalper' fallback in {MODELS_DIR}"
                        )

                    # Current exposure as a signed fraction of equity (0 when flat).
                    exposure = position * current_price / equity if equity > 0 else 0.0
                    model_signal = _model_signal(model, features, exposure)

                    # Enter only if the model agrees with the regime's direction.
                    if equity > 0 and model_signal * target > 0:
                        desired_btc = target * equity * POSITION_SIZE / current_price
                        if desired_btc != position:
                            fill = _execute_fill(capital, position, desired_btc, current_price, require_cash=True)
                            if fill is not None:  # None = purchase not affordable; no trade
                                capital, position = fill
                                last_trade_regime = regime
                                trades.append({
                                    'time': timestamps[i],
                                    'price': current_price,
                                    'position': position,
                                    'regime': regime,
                                    'type': 'LONG ENTRY' if target > 0 else 'SHORT ENTRY',
                                })

            current_regime = regime
            regime_start_idx = i

        # Track equity
        equity = capital + position * current_price
        equity_curve.append((timestamps[i], equity))

        if i % 50000 == 0:
            print(f"   Progress: {i:,}/{len(prices):,} | Equity: €{equity:,.0f} | Trades: {len(trades)}")

    # Close any final position at the last price (not recorded as a trade).
    if position != 0:
        capital, position = _execute_fill(capital, position, 0.0, float(prices[-1]))

    final_equity = capital

    # Calculate metrics
    total_return = (final_equity - INITIAL_CAPITAL) / INITIAL_CAPITAL * 100
    equity_values = [eq for _, eq in equity_curve]
    max_drawdown = _max_drawdown_pct(equity_values)

    returns = [(cur - prev) / prev for prev, cur in zip(equity_values, equity_values[1:]) if prev > 0]
    sharpe = np.mean(returns) / (np.std(returns) + 1e-8) * np.sqrt(365 * 24 * 60) if returns else 0

    # Average over the actual data span (timestamps are in seconds) rather than assuming 12 months.
    months = float(timestamps[-1] - timestamps[0]) / (365.25 / 12 * 86400)
    trades_per_month = len(trades) / months if months > 0 else 0.0

    # Print results
    print("\n" + "=" * 70)
    print("RESULTS")
    print("=" * 70)

    print(f"""
    Initial Capital:      €{INITIAL_CAPITAL:,.0f}
    Final Capital:        €{final_equity:,.0f}
    ─────────────────────────────────────────
    Total Return:         {total_return:+.2f}%
    Max Drawdown:         {max_drawdown:.2f}%
    Sharpe Ratio (ann.):  {sharpe:.2f}
    ─────────────────────────────────────────
    Total Trades:         {len(trades):,}
    Avg Trades/Month:     {trades_per_month:.1f}
    ─────────────────────────────────────────
    """)

    print("    Regime Distribution:")
    total_bars = sum(regime_counts.values())
    for regime, count in sorted(regime_counts.items(), key=lambda x: -x[1]):
        pct = count / total_bars * 100
        print(f"       {regime:12s}: {count:6,} ({pct:5.1f}%)")

    if trades:
        print("\n    Recent Trades:")
        for t in trades[-10:]:
            dt = datetime.fromtimestamp(t['time'])
            print(f"       {dt.strftime('%Y-%m-%d %H:%M')} | {t['type']:12s} | {t['regime']:10s} | €{t['price']:.0f}")

    # Save a downsampled equity curve (every 100th point)
    output_file = MODELS_DIR / 'backtest_regime_transition.csv'
    with open(output_file, 'w') as f:
        f.write('timestamp,equity\n')
        for ts, eq in equity_curve[::100]:
            f.write(f"{ts},{eq:.2f}\n")
    print(f"\n    Saved to: {output_file}")

    return {'total_return': total_return, 'max_drawdown': max_drawdown, 'sharpe': sharpe, 'trades': len(trades)}

if __name__ == '__main__':
    # Run the backtest only when executed as a script, not when imported.
    run_backtest()
