#!/usr/bin/env python3
"""
Simple V6 Backtest - Long Only, Fixed Position Size
More realistic simulation without compounding bugs.
"""

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime

# --- Input locations ---------------------------------------------------------
# Directory holding the per-regime V6 checkpoints (model_<regime>_v6.pt).
MODELS_DIR = Path("/var/www/html/bestrading.cuttalo.com/models/btc_v6")
# CSV price history; run_backtest reads its 'close' and 'timestamp' columns.
DATA_FILE = Path("/var/www/html/bestrading.cuttalo.com/scripts/prices_BTC_EUR_2025_full.csv")

# --- Trading assumptions -----------------------------------------------------
INITIAL_CAPITAL = 10_000   # starting cash, EUR
FEE_RATE = 0.002           # 0.2% per fill (Kraken taker tier)
SLIPPAGE = 0.0003          # 0.03% per fill
FIXED_TRADE_SIZE = 1_500   # EUR committed per entry; gains are never reinvested
MIN_REGIME_HOLD = 2 * 60   # minimum holding period in candles (120 = 2 hours, assuming 1-minute candles)

class TradingTransformer(nn.Module):
    """Transformer encoder with actor/critic heads; ``forward`` returns only the actor mean.

    Each timestep token carries ``input_dim`` features plus one extra channel
    (presumably the position state supplied by the caller), hence the
    ``input_dim + 1`` embedding width. ``actor_log_std`` and ``critic`` are not
    used by ``forward``, but they must stay registered so that checkpoints that
    contain them load through ``load_state_dict`` (strict by default).
    """

    def __init__(self, input_dim: int, hidden_dim: int = 256, num_heads: int = 4, num_layers: int = 2):
        super().__init__()
        # Per-timestep projection of [features..., extra channel] to model width.
        self.embed = nn.Linear(input_dim + 1, hidden_dim)
        block = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=4 * hidden_dim,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(block, num_layers=num_layers)
        self.actor_mean = nn.Linear(hidden_dim, 1)
        # Learned log standard deviation, initialised to -0.5 (unused at inference).
        self.actor_log_std = nn.Parameter(torch.full((1,), -0.5))
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return tanh-squashed actor means of shape (batch, 1).

        ``x`` is expected as (batch, seq_len, input_dim + 1) because the encoder
        is built with ``batch_first=True``; only the last timestep is read out.
        """
        encoded = self.transformer(self.embed(x))
        final_token = encoded[:, -1]
        return self.actor_mean(final_token).tanh()

def compute_features(prices: np.ndarray, idx: int) -> np.ndarray:
    """Build the 9-element float32 feature vector for candle ``idx``.

    Layout (every value clipped to [-3, 3]):
      0-4  simple returns over 1, 5, 10, 20 and 60 candles
      5    std of the last 10 one-candle returns in the trailing window
      6    std of all one-candle returns in the trailing window
           (up to 31 prices ending at ``idx``)
      7-8  relative gap between the current price and the mean of the
           preceding 10 / 30 prices (current price excluded)

    Returns None when ``idx < 60`` (not enough history), despite the
    ``np.ndarray`` annotation.
    """
    if idx < 60:
        return None

    eps = 1e-8
    current = prices[idx]
    vector = []

    # Momentum: simple returns over several look-backs.
    for lag in (1, 5, 10, 20, 60):
        if idx < lag:
            vector.append(0)
            continue
        base = prices[idx - lag]
        vector.append((current - base) / (base + eps))

    # Volatility of one-candle returns over the trailing window.
    window = prices[max(0, idx - 30):idx + 1]
    if len(window) > 1:
        step_returns = np.diff(window) / (window[:-1] + eps)
        # A [-10:] slice of a shorter array is the whole array, which matches
        # a "last 10, or all of them if fewer" rule.
        vector.extend([np.std(step_returns[-10:]), np.std(step_returns)])
    else:
        vector.extend([0, 0])

    # Distance from short/medium moving averages.
    for span in (10, 30):
        if idx < span:
            vector.append(0)
            continue
        avg = np.mean(prices[idx - span:idx])
        vector.append((current - avg) / (avg + eps))

    return np.clip(np.asarray(vector, dtype=np.float32), -3, 3)

def detect_regime(features: np.ndarray) -> str:
    """Classify the market regime from a ``compute_features`` vector.

    Reads the 20- and 60-candle returns (indices 3 and 4) and the trailing
    30-return volatility (index 6). Rules are checked in priority order:
    'volatile', 'bullish', 'bearish', 'ranging'; anything else, including a
    missing vector, is 'scalper'.
    """
    if features is None:
        return 'scalper'

    ret20, ret60, vol30 = features[3], features[4], features[6]

    trending_up = ret20 > 0.002 and ret60 > 0.001
    trending_down = ret20 < -0.002 and ret60 < -0.001
    quiet = abs(ret60) < 0.001 and vol30 < 0.004

    # High volatility overrides any trend reading.
    if vol30 > 0.012:
        return 'volatile'
    if trending_up:
        return 'bullish'
    if trending_down:
        return 'bearish'
    if quiet:
        return 'ranging'
    return 'scalper'

def load_models():
    """Load every available per-regime V6 checkpoint from MODELS_DIR onto the CPU.

    Returns a dict mapping regime name to a ``TradingTransformer`` in eval
    mode. Regimes whose ``model_<regime>_v6.pt`` file does not exist are simply
    absent, so the dict may be empty.
    """
    cpu = torch.device('cpu')
    loaded = {}
    for name in ('bullish', 'bearish', 'ranging', 'volatile', 'scalper'):
        ckpt_path = MODELS_DIR / f'model_{name}_v6.pt'
        if not ckpt_path.exists():
            continue
        # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
        # arbitrary objects, which can execute code. Only load checkpoints from
        # a trusted source; consider weights_only=True if the checkpoints hold
        # only tensors and primitive values.
        ckpt = torch.load(ckpt_path, map_location=cpu, weights_only=False)
        # Architecture defaults are assumed to match training; only input_dim
        # is read from the checkpoint (default 9 = compute_features length).
        net = TradingTransformer(ckpt.get('input_dim', 9)).to(cpu)
        net.load_state_dict(ckpt['model_state'])
        net.eval()
        loaded[name] = net
    return loaded

def _entry_signal(model, history):
    """Return the model's tanh-squashed output for a flat (no-position) state.

    ``history`` is a sequence of per-candle ``compute_features`` vectors,
    oldest first. A 0.0 extra channel (flat position) is appended to every
    step, giving an input of shape (1, len(history), n_features + 1).
    """
    window = torch.from_numpy(np.stack(history))
    flat_position = torch.zeros(window.shape[0], 1, dtype=window.dtype)
    batch = torch.cat([window, flat_position], dim=-1).unsqueeze(0)
    with torch.no_grad():
        return model(batch).item()


def _close_long(btc_amount, price):
    """Return (net EUR proceeds, PnL vs FIXED_TRADE_SIZE) for selling ``btc_amount`` at ``price``."""
    proceeds = btc_amount * price * (1 - FEE_RATE - SLIPPAGE)
    return proceeds, proceeds - FIXED_TRADE_SIZE


def _max_drawdown_pct(equity_values):
    """Return the largest peak-to-trough decline of ``equity_values`` in percent (<= 0); 0 if empty."""
    values = np.asarray(equity_values, dtype=np.float64)
    if values.size == 0:
        return 0
    # Each point is compared with the highest equity seen *before or at* it,
    # not with the global maximum (which may occur later in time).
    running_peak = np.maximum.accumulate(values)
    return float(np.min((values - running_peak) / running_peak)) * 100


def run_backtest():
    """Run the long-only V6 backtest over DATA_FILE and print a summary.

    Strategy (at most one position, fixed EUR size, no compounding):
      * Entry: when flat, the regime has been 'bullish' for >= 30 candles, and
        the bullish model (the scalper model if no bullish one is loaded)
        outputs > 0.2 for the last 20 candles of real feature history.
      * Exit: a -2% stop loss at any time; once MIN_REGIME_HOLD candles have
        passed, also on a 'bearish'/'volatile' regime lasting >= 30 candles or
        on a +3% take profit.
      * Any position still open is closed at the last price.

    Fee and slippage are charged on both fills as a proportional haircut.

    Returns:
        dict with 'total_return' (%), 'max_drawdown' (%, <= 0), 'win_rate' (%)
        and 'trades' (number of closed round trips).

    Raises:
        FileNotFoundError: if neither a 'bullish' nor a 'scalper' model loads.
    """
    from collections import deque

    seq_len = 20            # timesteps per model input
    entry_threshold = 0.2   # minimum model output that confirms a bullish entry

    print("=" * 70)
    print("V6 SIMPLE BACKTEST - LONG ONLY")
    print("=" * 70)

    models = load_models()
    print(f"  Loaded {len(models)} models")

    # Resolve the entry model once. (dict.get evaluates its default eagerly,
    # so get('bullish', models['scalper']) raised KeyError whenever the scalper
    # checkpoint was missing, even if a bullish model was available.)
    entry_model = models.get('bullish')
    if entry_model is None:
        entry_model = models.get('scalper')
    if entry_model is None:
        raise FileNotFoundError(f"No 'bullish' or 'scalper' V6 model found in {MODELS_DIR}")

    df = pd.read_csv(DATA_FILE)
    # Features are computed from float32 prices; accounting below uses Python
    # floats so cash and PnL accumulate in double precision.
    prices = df['close'].values.astype(np.float32)
    timestamps = df['timestamp'].values

    print(f"  Data: {len(prices):,} candles")
    print(f"  Range: €{prices.min():.0f} - €{prices.max():.0f}")

    # Simulation state
    capital = INITIAL_CAPITAL
    btc_held = 0.0
    entry_price = 0.0
    entry_idx = 0
    trades = []
    equity_values = []
    # Last `seq_len` feature vectors, fed to the model as a real time sequence.
    # (Repeating one vector 20 times is equivalent to a length-1 sequence for a
    # transformer without positional encoding.) Presumably matches the
    # training windows - verify against the training code.
    history = deque(maxlen=seq_len)

    current_regime = 'scalper'
    regime_change_idx = 61

    winning_trades = 0
    losing_trades = 0

    for i in range(61, len(prices)):
        price = float(prices[i])
        features = compute_features(prices, i)
        if features is None:
            continue
        history.append(features)

        regime = detect_regime(features)
        if regime != current_regime:
            current_regime = regime
            regime_change_idx = i
        time_in_regime = i - regime_change_idx

        action = None
        if btc_held == 0:
            # Entry: bullish regime confirmed for 30 candles and by the model.
            if (current_regime == 'bullish' and time_in_regime >= 30
                    and len(history) == seq_len
                    and _entry_signal(entry_model, history) > entry_threshold):
                action = 'BUY'
        elif price < entry_price * 0.98:
            # Stop loss: protective, so it is NOT delayed by MIN_REGIME_HOLD.
            action = 'SELL'
        elif i - entry_idx >= MIN_REGIME_HOLD:
            # Regime-driven risk-off exit, or take profit above +3%.
            regime_exit = current_regime in ('bearish', 'volatile') and time_in_regime >= 30
            take_profit = price > entry_price * 1.03
            if regime_exit or take_profit:
                action = 'SELL'

        # Execute trade
        if action == 'BUY' and capital >= FIXED_TRADE_SIZE:
            btc_held = FIXED_TRADE_SIZE * (1 - FEE_RATE - SLIPPAGE) / price
            capital -= FIXED_TRADE_SIZE
            entry_price = price
            entry_idx = i

            trades.append({
                'time': timestamps[i],
                'action': 'BUY',
                'price': price,
                'btc': btc_held,
                'regime': current_regime
            })

        elif action == 'SELL' and btc_held > 0:
            proceeds, pnl = _close_long(btc_held, price)
            capital += proceeds

            if pnl > 0:
                winning_trades += 1
            else:
                losing_trades += 1

            trades.append({
                'time': timestamps[i],
                'action': 'SELL',
                'price': price,
                'btc': btc_held,
                'pnl': pnl,
                'regime': current_regime
            })

            btc_held = 0.0
            entry_price = 0.0
            entry_idx = 0

        # Mark-to-market equity (open BTC valued at the current price, before exit costs).
        equity = capital + btc_held * price
        equity_values.append(equity)

        if i % 100000 == 0:
            print(f"  Progress: {i:,}/{len(prices):,} | Equity: €{equity:,.0f} | Trades: {len(trades)}")

    # Close any open position at the last price.
    if btc_held > 0:
        last_close = float(prices[-1])
        proceeds, pnl = _close_long(btc_held, last_close)
        capital += proceeds
        if pnl > 0:
            winning_trades += 1
        else:
            losing_trades += 1
        trades.append({
            'time': timestamps[-1],
            'action': 'SELL (CLOSE)',
            'price': last_close,
            'btc': btc_held,
            'pnl': pnl,
            'regime': current_regime
        })
        btc_held = 0.0
        # Include the post-close equity so drawdown reflects the exit costs.
        equity_values.append(capital)

    final_equity = capital
    total_return = (final_equity - INITIAL_CAPITAL) / INITIAL_CAPITAL * 100

    # Calculate metrics
    max_drawdown = _max_drawdown_pct(equity_values)
    # Every closed round trip (including the forced close) is counted exactly once.
    total_trades = winning_trades + losing_trades
    win_rate = winning_trades / total_trades * 100 if total_trades > 0 else 0

    # Print results
    print("\n" + "=" * 70)
    print("RESULTS")
    print("=" * 70)

    print(f"""
    Initial Capital:      €{INITIAL_CAPITAL:,.0f}
    Final Capital:        €{final_equity:,.0f}
    ─────────────────────────────────────────
    Total Return:         {total_return:+.2f}%
    Max Drawdown:         {max_drawdown:.2f}%
    ─────────────────────────────────────────
    Total Trades:         {total_trades}
    Winning Trades:       {winning_trades}
    Losing Trades:        {losing_trades}
    Win Rate:             {win_rate:.1f}%
    ─────────────────────────────────────────
    """)

    # Show trade history
    if trades:
        print("    Trade History (last 20):")
        for t in trades[-20:]:
            # NOTE(review): assumes 'timestamp' holds Unix seconds; rendered in
            # the machine's local timezone.
            dt = datetime.fromtimestamp(t['time'])
            pnl_str = f"PnL: €{t.get('pnl', 0):+.2f}" if 'pnl' in t else ""
            print(f"       {dt.strftime('%Y-%m-%d %H:%M')} | {t['action']:12s} | €{t['price']:.0f} | {pnl_str}")

    # BTC price change for comparison. Buy & hold is fully invested and
    # fee-free, while the strategy risks at most FIXED_TRADE_SIZE of
    # INITIAL_CAPITAL, so this "alpha" is not exposure-adjusted.
    first_close = float(prices[0])
    btc_return = (float(prices[-1]) - first_close) / first_close * 100
    print(f"\n    BTC Buy & Hold:       {btc_return:+.2f}%")
    print(f"    Strategy Return:      {total_return:+.2f}%")
    print(f"    Alpha (vs B&H):       {total_return - btc_return:+.2f}%")

    return {'total_return': total_return, 'max_drawdown': max_drawdown, 'win_rate': win_rate, 'trades': total_trades}

if __name__ == "__main__":
    # Running the file directly performs one full backtest; importing it does not.
    run_backtest()
