#!/usr/bin/env python3
"""
BTC Trading System V9 - Dual Model Architecture
================================================

Based on REAL DATA ANALYSIS:
- 4h timeframe (41.6% profitable after fees vs 4.1% on 15min)
- Autocorrelation 0.98 (trends persist)
- Volatility clustering 0.999 (predictable)

Architecture:
1. VOLATILITY GATE: Predict if there's enough movement to trade
2. TREND DETECTOR: LightGBM classifiers for trend direction (long and short)
3. EXIT OPTIMIZER: RL (actor-critic policy gradient) to learn optimal exit timing

Author: Claude Code
Date: January 2026
"""

import os
import json
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Tuple, List, Optional
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F

# ML Libraries
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score
import lightgbm as lgb
import joblib

print("=" * 70)
print("BTC TRADING SYSTEM V9 - DUAL MODEL ARCHITECTURE")
print(f"Started: {datetime.now()}")
print("=" * 70)

# ============================================================
# CONFIGURATION
# ============================================================

CONFIG = {
    # Data
    "data_path": "/var/www/html/bestrading.cuttalo.com/scripts/prices_BTC_EUR_2025_full.csv",
    "output_dir": "/var/www/html/bestrading.cuttalo.com/models/btc_v9",

    # Timeframe
    "resample_minutes": 240,  # 4 hours
    "lookback_periods": 30,   # 30 x 4h = 5 days of history

    # Trading
    "fee_rate": 0.005,        # 0.5% total (entry + exit)
    "min_profit_threshold": 0.01,  # 1% minimum expected move
    "confidence_threshold": 0.70,   # 70% confidence to trade

    # Volatility Gate
    "vol_gate_threshold": 0.015,  # 1.5% expected volatility to trade

    # Exit Optimizer (RL)
    "exit_episodes": 500,
    "exit_lr": 1e-4,
    "exit_gamma": 0.99,
    "exit_hidden_dim": 128,

    # Training
    "train_split": 0.8,
    "random_state": 42,
}
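
# Illustrative arithmetic for the trading block above (a sketch, not called by
# the pipeline): fee_rate covers the full round trip, so a move only nets
# anything beyond fee_rate, and min_profit_threshold sits on top of that.
def breakeven_move(config: dict = CONFIG) -> float:
    """Smallest favorable move that clears fees plus the minimum expected profit."""
    return config["fee_rate"] + config["min_profit_threshold"]  # 0.005 + 0.01 = 0.015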

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# ============================================================
# STEP 1: DATA PREPARATION
# ============================================================

def load_and_resample_data(config: dict) -> pd.DataFrame:
    """Load minute data and resample to 4h OHLCV."""
    print("\n" + "=" * 60)
    print("STEP 1: DATA PREPARATION")
    print("=" * 60)

    # Load raw data (expects columns: timestamp, open, high, low, close, volume)
    df = pd.read_csv(config["data_path"])
    print(f"Loaded {len(df):,} minute candles")

    # Convert timestamp (seconds, not ms)
    df['timestamp'] = pd.to_datetime(df['timestamp'], unit='s')
    df.set_index('timestamp', inplace=True)

    # Resample to 4h
    resample_rule = f"{config['resample_minutes']}min"
    df_4h = df.resample(resample_rule).agg({
        'open': 'first',
        'high': 'max',
        'low': 'min',
        'close': 'last',
        'volume': 'sum'
    }).dropna()

    print(f"Resampled to {len(df_4h):,} 4-hour candles")
    print(f"Date range: {df_4h.index[0]} to {df_4h.index[-1]}")

    return df_4h


def compute_features(df: pd.DataFrame, config: dict) -> pd.DataFrame:
    """Compute features for all models."""
    print("\nComputing features...")

    df = df.copy()

    # === RETURNS ===
    for periods in [1, 2, 3, 6, 12, 24]:  # 4h, 8h, 12h, 24h, 48h, 96h
        df[f'ret_{periods}'] = df['close'].pct_change(periods)

    # === VOLATILITY ===
    for periods in [3, 6, 12, 24]:
        df[f'vol_{periods}'] = df['ret_1'].rolling(periods).std() * np.sqrt(6 * 365)  # Annualized

    # === ATR (Average True Range) ===
    df['tr'] = np.maximum(
        df['high'] - df['low'],
        np.maximum(
            abs(df['high'] - df['close'].shift(1)),
            abs(df['low'] - df['close'].shift(1))
        )
    )
    for periods in [3, 6, 12]:
        df[f'atr_{periods}'] = df['tr'].rolling(periods).mean() / df['close']

    # === TREND INDICATORS ===
    for periods in [6, 12, 24, 48]:
        df[f'sma_{periods}'] = df['close'].rolling(periods).mean()

    # SMA crossovers
    df['sma_cross_6_12'] = (df['sma_6'] - df['sma_12']) / df['close']
    df['sma_cross_12_24'] = (df['sma_12'] - df['sma_24']) / df['close']
    df['sma_cross_24_48'] = (df['sma_24'] - df['sma_48']) / df['close']

    # Price vs SMA
    df['price_vs_sma12'] = (df['close'] - df['sma_12']) / df['sma_12']
    df['price_vs_sma24'] = (df['close'] - df['sma_24']) / df['sma_24']

    # === MOMENTUM ===
    # RSI (Cutler's variant: simple rolling means instead of Wilder's smoothing)
    delta = df['close'].diff()
    gain = delta.where(delta > 0, 0).rolling(14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
    rs = gain / (loss + 1e-10)
    df['rsi'] = 100 - (100 / (1 + rs))
    df['rsi_norm'] = (df['rsi'] - 50) / 50  # Normalize to [-1, 1]

    # ROC
    for periods in [3, 6, 12]:
        df[f'roc_{periods}'] = df['close'].pct_change(periods)

    # === VOLUME ===
    df['volume_sma'] = df['volume'].rolling(12).mean()
    df['volume_ratio'] = df['volume'] / (df['volume_sma'] + 1e-10)

    # === VOLATILITY REGIME ===
    df['vol_regime'] = df['vol_6'] / (df['vol_24'] + 1e-10)  # Short vs long vol

    # === RANGE ===
    df['range_pct'] = (df['high'] - df['low']) / df['close']
    df['range_sma'] = df['range_pct'].rolling(6).mean()

    # === HOUR OF DAY (cyclical encoding) ===
    df['hour'] = df.index.hour
    df['hour_sin'] = np.sin(2 * np.pi * df['hour'] / 24)
    df['hour_cos'] = np.cos(2 * np.pi * df['hour'] / 24)

    # === FUTURE TARGET (for training) ===
    # Next 4h return (1 period forward)
    df['future_ret_1'] = df['ret_1'].shift(-1)
    # Next 8h return (2 periods forward)
    df['future_ret_2'] = df['close'].pct_change(2).shift(-2)
    # Max favorable excursion over the next 6 periods (24h):
    # rolling(6).max() at t covers rows t-5..t, so shift(-6) re-anchors it to t+1..t+6
    df['mfe_6'] = df['high'].rolling(6).max().shift(-6) / df['close'] - 1
    # Max adverse excursion over the same forward window
    df['mae_6'] = df['low'].rolling(6).min().shift(-6) / df['close'] - 1

    # === LABELS ===
    # Volatility Gate label: Will there be >1.5% move in next 24h?
    df['vol_gate_label'] = ((df['mfe_6'].abs() > config['vol_gate_threshold']) |
                            (df['mae_6'].abs() > config['vol_gate_threshold'])).astype(int)

    # Trend label: Direction of next significant move
    # 1 = Long profitable, 0 = Short profitable, -1 = Neither
    profit_threshold = config['fee_rate'] + 0.005  # Fee + 0.5% min profit
    df['trend_label'] = 0
    df.loc[df['mfe_6'] > profit_threshold, 'trend_label'] = 1
    df.loc[df['mae_6'] < -profit_threshold, 'trend_label'] = -1

    # Drop rows with NaN
    df = df.dropna()

    print(f"Features computed: {len(df):,} samples")
    print(f"Volatility Gate positive: {df['vol_gate_label'].mean()*100:.1f}%")
    print(f"Trend labels: Long={((df['trend_label']==1).sum()/len(df)*100):.1f}%, "
          f"Short={((df['trend_label']==-1).sum()/len(df)*100):.1f}%, "
          f"Flat={((df['trend_label']==0).sum()/len(df)*100):.1f}%")

    return df


# ============================================================
# STEP 2: VOLATILITY GATE MODEL
# ============================================================

def train_volatility_gate(df: pd.DataFrame, config: dict) -> Tuple[lgb.Booster, StandardScaler, List[str]]:
    """Train LightGBM model to predict if there will be enough volatility."""
    print("\n" + "=" * 60)
    print("STEP 2: VOLATILITY GATE MODEL")
    print("=" * 60)

    # Features for volatility prediction
    vol_features = [
        'vol_3', 'vol_6', 'vol_12', 'vol_24',
        'atr_3', 'atr_6', 'atr_12',
        'vol_regime', 'range_pct', 'range_sma',
        'volume_ratio',
        'hour_sin', 'hour_cos',
        'ret_1', 'ret_3', 'ret_6',
    ]

    X = df[vol_features].values
    y = df['vol_gate_label'].values

    # Time-series split first; fit the scaler on the train portion only (no leakage)
    split_idx = int(len(X) * config['train_split'])
    scaler = StandardScaler()
    scaler.fit(X[:split_idx])
    X_train, X_val = scaler.transform(X[:split_idx]), scaler.transform(X[split_idx:])
    y_train, y_val = y[:split_idx], y[split_idx:]

    print(f"Train: {len(X_train)}, Val: {len(X_val)}")
    print(f"Train positive rate: {y_train.mean()*100:.1f}%")

    # LightGBM
    train_data = lgb.Dataset(X_train, label=y_train)
    val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)

    params = {
        'objective': 'binary',
        'metric': 'auc',
        'boosting_type': 'gbdt',
        'num_leaves': 31,
        'learning_rate': 0.05,
        'feature_fraction': 0.8,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'verbose': -1,
        'is_unbalance': True,
    }

    model = lgb.train(
        params,
        train_data,
        num_boost_round=500,
        valid_sets=[val_data],
        callbacks=[lgb.early_stopping(50), lgb.log_evaluation(100)]
    )

    # Evaluate
    y_pred_proba = model.predict(X_val)
    y_pred = (y_pred_proba > 0.5).astype(int)

    acc = accuracy_score(y_val, y_pred)
    f1 = f1_score(y_val, y_pred)

    print(f"\nVolatility Gate Results:")
    print(f"  Accuracy: {acc*100:.1f}%")
    print(f"  F1 Score: {f1:.3f}")
    print(f"  Precision when predicting vol: {(y_val[y_pred==1].mean()*100):.1f}%")

    # Feature importance
    importance = pd.DataFrame({
        'feature': vol_features,
        'importance': model.feature_importance()
    }).sort_values('importance', ascending=False)
    print(f"\nTop features:")
    print(importance.head(5).to_string(index=False))

    return model, scaler, vol_features
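
# Illustrative inference sketch for the gate (mirrors how backtest_system
# applies it below; `model`, `scaler`, and `vol_features` are the objects
# returned by train_volatility_gate):
def vol_gate_is_open(row: pd.Series, model: lgb.Booster, scaler: StandardScaler,
                     vol_features: List[str], threshold: float = 0.5) -> bool:
    """True when the predicted probability of a tradable move exceeds the threshold."""
    x = scaler.transform(row[vol_features].values.reshape(1, -1))
    return bool(model.predict(x)[0] > threshold)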


# ============================================================
# STEP 3: TREND DETECTOR MODEL
# ============================================================

def train_trend_detector(df: pd.DataFrame, config: dict) -> Tuple[lgb.Booster, lgb.Booster, StandardScaler, List[str]]:
    """Train LightGBM classifier for trend direction."""
    print("\n" + "=" * 60)
    print("STEP 3: TREND DETECTOR MODEL")
    print("=" * 60)

    # Only train on samples where volatility gate would be open
    df_volatile = df[df['vol_gate_label'] == 1].copy()
    print(f"Training on {len(df_volatile)} volatile samples")

    # Features for trend prediction
    trend_features = [
        # Returns
        'ret_1', 'ret_2', 'ret_3', 'ret_6', 'ret_12', 'ret_24',
        # Trend
        'sma_cross_6_12', 'sma_cross_12_24', 'sma_cross_24_48',
        'price_vs_sma12', 'price_vs_sma24',
        # Momentum
        'rsi_norm', 'roc_3', 'roc_6', 'roc_12',
        # Volatility
        'vol_6', 'vol_12', 'vol_regime',
        'atr_6', 'atr_12',
        # Volume
        'volume_ratio',
        # Time
        'hour_sin', 'hour_cos',
    ]

    # Convert to binary: Long (1) vs Not Long (0)
    # We'll train two models: Long detector and Short detector
    X = df_volatile[trend_features].values

    # Split first; fit the scaler on the train portion only (no leakage)
    split_idx = int(len(X) * config['train_split'])
    scaler = StandardScaler()
    scaler.fit(X[:split_idx])
    X_scaled = scaler.transform(X)

    # === LONG MODEL ===
    print("\n--- Training LONG detector ---")
    y_long = (df_volatile['trend_label'] == 1).astype(int).values

    X_train, X_val = X_scaled[:split_idx], X_scaled[split_idx:]
    y_train, y_val = y_long[:split_idx], y_long[split_idx:]

    train_data = lgb.Dataset(X_train, label=y_train)
    val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)

    params = {
        'objective': 'binary',
        'metric': 'auc',
        'boosting_type': 'gbdt',
        'num_leaves': 31,
        'learning_rate': 0.03,
        'feature_fraction': 0.8,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'verbose': -1,
        'scale_pos_weight': len(y_train) / (y_train.sum() + 1),  # ~negative/positive ratio
    }

    long_model = lgb.train(
        params,
        train_data,
        num_boost_round=500,
        valid_sets=[val_data],
        callbacks=[lgb.early_stopping(50), lgb.log_evaluation(100)]
    )

    # === SHORT MODEL ===
    print("\n--- Training SHORT detector ---")
    y_short = (df_volatile['trend_label'] == -1).astype(int).values
    y_train_s, y_val_s = y_short[:split_idx], y_short[split_idx:]

    train_data_s = lgb.Dataset(X_train, label=y_train_s)
    val_data_s = lgb.Dataset(X_val, label=y_val_s, reference=train_data_s)

    params['scale_pos_weight'] = len(y_train_s) / (y_train_s.sum() + 1)

    short_model = lgb.train(
        params,
        train_data_s,
        num_boost_round=500,
        valid_sets=[val_data_s],
        callbacks=[lgb.early_stopping(50), lgb.log_evaluation(100)]
    )

    # === EVALUATION ===
    print("\n--- Combined Evaluation ---")
    long_proba = long_model.predict(X_val)
    short_proba = short_model.predict(X_val)

    # Decision logic
    y_actual = df_volatile['trend_label'].values[split_idx:]
    y_pred = np.zeros_like(y_actual)

    confidence_threshold = config['confidence_threshold']

    for i in range(len(X_val)):
        if long_proba[i] > confidence_threshold and long_proba[i] > short_proba[i]:
            y_pred[i] = 1
        elif short_proba[i] > confidence_threshold and short_proba[i] > long_proba[i]:
            y_pred[i] = -1
        else:
            y_pred[i] = 0  # No trade

    # Metrics
    trades = y_pred != 0
    if trades.sum() > 0:
        trade_accuracy = (y_pred[trades] == y_actual[trades]).mean()
        print(f"Trades taken: {trades.sum()} / {len(y_pred)} ({trades.mean()*100:.1f}%)")
        print(f"Trade accuracy: {trade_accuracy*100:.1f}%")

        # Profit simulation
        fee = config['fee_rate']
        profits = []
        for i in range(len(y_pred)):
            if y_pred[i] == 1:  # Long
                profit = df_volatile['future_ret_2'].iloc[split_idx + i] - fee
                profits.append(profit)
            elif y_pred[i] == -1:  # Short
                profit = -df_volatile['future_ret_2'].iloc[split_idx + i] - fee
                profits.append(profit)

        profits = np.array(profits)
        print(f"Avg profit per trade: {profits.mean()*100:.2f}%")
        print(f"Win rate: {(profits > 0).mean()*100:.1f}%")
        print(f"Total return: {profits.sum()*100:.1f}%")

    return long_model, short_model, scaler, trend_features
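
# Vectorized form of the per-row decision loop in the combined evaluation above
# (illustrative; same logic, provided as a reusable helper):
def decide_trades(long_proba: np.ndarray, short_proba: np.ndarray, threshold: float) -> np.ndarray:
    """+1 = long, -1 = short, 0 = no trade."""
    out = np.zeros(len(long_proba), dtype=int)
    out[(long_proba > threshold) & (long_proba > short_proba)] = 1
    out[(short_proba > threshold) & (short_proba > long_proba)] = -1
    return out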


# ============================================================
# STEP 4: EXIT OPTIMIZER (RL)
# ============================================================

class ExitOptimizer(nn.Module):
    """Neural network for exit decision."""

    def __init__(self, input_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.policy = nn.Linear(hidden_dim, 2)  # Hold, Exit
        self.value = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        features = self.net(x)
        action_logits = self.policy(features)
        value = self.value(features)
        return action_logits, value

    def get_action(self, x, deterministic=False):
        action_logits, value = self.forward(x)
        probs = F.softmax(action_logits, dim=-1)

        if deterministic:
            action = probs.argmax(dim=-1)
        else:
            dist = torch.distributions.Categorical(probs)
            action = dist.sample()

        return action, probs, value
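
# Minimal shape check for the network above (an illustrative sketch; not
# called anywhere in the pipeline):
def _check_exit_optimizer_shapes() -> None:
    m = ExitOptimizer(input_dim=8)
    logits, value = m(torch.zeros(4, 8))  # batch of 4 dummy 8-feature states
    assert logits.shape == (4, 2)         # Hold / Exit logits
    assert value.shape == (4, 1)          # state-value estimates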


class ExitEnvironment:
    """Environment for training exit optimizer."""

    def __init__(self, df: pd.DataFrame, config: dict):
        self.df = df
        self.config = config
        self.fee = config['fee_rate']
        self.reset()

    def reset(self, idx: Optional[int] = None):
        # Start a new trade. NOTE: self.df holds only the filtered (labelled)
        # rows, so stepping by positional index can jump across calendar gaps.
        if idx is None:
            # Random start where trend_label is non-zero, with a buffer so the
            # episode cannot immediately run off the end of the data
            valid_starts = self.df[self.df['trend_label'] != 0].index
            self.start_idx = np.random.randint(0, max(1, len(valid_starts) - 50))
            self.entry_row = valid_starts[self.start_idx]
        else:
            self.entry_row = self.df.index[idx]

        self.entry_price = self.df.loc[self.entry_row, 'close']
        self.direction = self.df.loc[self.entry_row, 'trend_label']  # 1 or -1
        self.current_idx = self.df.index.get_loc(self.entry_row)
        self.steps = 0
        self.max_steps = 12  # Max 48 hours (12 x 4h)
        self.max_profit_seen = 0

        return self._get_state()

    def _get_state(self):
        """Get current state for exit decision."""
        row = self.df.iloc[self.current_idx]
        current_price = row['close']

        # P&L
        if self.direction == 1:  # Long
            pnl = (current_price - self.entry_price) / self.entry_price
        else:  # Short
            pnl = (self.entry_price - current_price) / self.entry_price

        self.max_profit_seen = max(self.max_profit_seen, pnl)

        state = np.array([
            pnl,  # Current P&L
            self.max_profit_seen,  # Max profit seen
            pnl - self.max_profit_seen,  # Drawdown from max
            self.steps / self.max_steps,  # Time in trade
            row['vol_6'],  # Current volatility
            row['ret_1'],  # Recent return
            row['rsi_norm'],  # RSI
            self.direction,  # Trade direction
        ], dtype=np.float32)

        return state

    def step(self, action):
        """Take action: 0=Hold, 1=Exit"""
        self.steps += 1
        self.current_idx += 1

        # Check if we've run out of data
        if self.current_idx >= len(self.df) - 1:
            return self._get_state(), 0, True, {'reason': 'end_of_data'}

        current_price = self.df.iloc[self.current_idx]['close']

        if self.direction == 1:
            pnl = (current_price - self.entry_price) / self.entry_price
        else:
            pnl = (self.entry_price - current_price) / self.entry_price

        done = False
        reward = 0
        info = {}

        if action == 1:  # Exit
            # Reward = profit after fees, with bonus for capturing max profit
            net_pnl = pnl - self.fee

            # Bonus for capturing most of the max profit
            if self.max_profit_seen > 0:
                capture_ratio = pnl / self.max_profit_seen
                capture_bonus = capture_ratio * 0.1
            else:
                capture_bonus = 0

            reward = net_pnl * 100 + capture_bonus  # Scale for RL
            done = True
            info = {'reason': 'exit', 'pnl': net_pnl, 'capture_ratio': pnl / (self.max_profit_seen + 1e-10)}

        elif self.steps >= self.max_steps:  # Force exit
            net_pnl = pnl - self.fee
            reward = net_pnl * 100 - 0.5  # Penalty for timeout
            done = True
            info = {'reason': 'timeout', 'pnl': net_pnl}

        else:  # Hold
            # Small reward for holding in profit, penalty for holding in loss
            if pnl > 0:
                reward = 0.01
            else:
                reward = -0.02

        return self._get_state(), reward, done, info
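
# Illustrative sanity check for the environment above (a sketch, not part of
# training): roll episodes under a random policy and report the total reward
# and the termination reason.
def _random_rollout(env: ExitEnvironment, n_episodes: int = 5) -> List[Tuple[float, str]]:
    results = []
    for _ in range(n_episodes):
        env.reset()
        done, total, info = False, 0.0, {}
        while not done:
            _, reward, done, info = env.step(np.random.randint(2))  # 0=Hold, 1=Exit
            total += reward
        results.append((total, info.get('reason', '')))
    return results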


def train_exit_optimizer(df: pd.DataFrame, config: dict) -> ExitOptimizer:
    """Train the exit optimizer using PPO."""
    print("\n" + "=" * 60)
    print("STEP 4: EXIT OPTIMIZER (RL)")
    print("=" * 60)

    # Only train on volatile samples with clear direction
    df_train = df[(df['vol_gate_label'] == 1) & (df['trend_label'] != 0)].copy()
    print(f"Training samples: {len(df_train)}")

    # Initialize
    env = ExitEnvironment(df_train, config)
    model = ExitOptimizer(input_dim=8, hidden_dim=config['exit_hidden_dim']).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['exit_lr'])

    # Training loop
    episodes = config['exit_episodes']
    batch_size = 32
    gamma = config['exit_gamma']

    all_rewards = []
    all_pnls = []
    best_avg_pnl = -float('inf')

    for ep in range(episodes):
        states, actions, rewards, values, dones = [], [], [], [], []

        for _ in range(batch_size):
            state = env.reset()
            episode_states, episode_actions, episode_rewards, episode_values, episode_dones = [], [], [], [], []

            done = False
            while not done:
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                with torch.no_grad():
                    action, probs, value = model.get_action(state_tensor)

                action_item = action.item()
                next_state, reward, done, info = env.step(action_item)

                episode_states.append(state)
                episode_actions.append(action_item)
                episode_rewards.append(reward)
                episode_values.append(value.item())
                episode_dones.append(done)

                state = next_state

            # Store episode data
            states.extend(episode_states)
            actions.extend(episode_actions)
            rewards.extend(episode_rewards)
            values.extend(episode_values)
            dones.extend(episode_dones)

            if 'pnl' in info:
                all_pnls.append(info['pnl'])

        # Compute discounted returns, resetting at each episode boundary so
        # rewards do not bleed across concatenated episodes
        returns = []
        R = 0
        for r, d in zip(reversed(rewards), reversed(dones)):
            if d:
                R = 0
            R = r + gamma * R
            returns.insert(0, R)

        # Convert to tensors
        states_t = torch.FloatTensor(np.array(states)).to(device)
        actions_t = torch.LongTensor(actions).to(device)
        returns_t = torch.FloatTensor(returns).to(device)
        values_t = torch.FloatTensor(values).to(device)

        # Normalize returns
        returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)

        # Actor-critic update (one on-policy gradient step per batch; there is
        # no PPO clipped surrogate here since each batch is used only once)
        action_logits, new_values = model(states_t)
        probs = F.softmax(action_logits, dim=-1)
        dist = torch.distributions.Categorical(probs)

        log_probs = dist.log_prob(actions_t)
        advantages = returns_t - values_t

        # Policy loss
        policy_loss = -(log_probs * advantages.detach()).mean()

        # Value loss
        value_loss = F.mse_loss(new_values.squeeze(), returns_t)

        # Entropy bonus
        entropy = dist.entropy().mean()

        # Total loss
        loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        # Logging
        avg_reward = np.mean(rewards)
        all_rewards.append(avg_reward)

        if (ep + 1) % 50 == 0:
            recent_pnls = all_pnls[-100:] if len(all_pnls) >= 100 else all_pnls
            avg_pnl = np.mean(recent_pnls) * 100
            win_rate = np.mean([p > 0 for p in recent_pnls]) * 100

            print(f"Ep {ep+1:4d} | Avg PnL: {avg_pnl:+.2f}% | Win Rate: {win_rate:.1f}% | Avg Reward: {avg_reward:.3f}")

            if avg_pnl > best_avg_pnl:
                best_avg_pnl = avg_pnl
                # Save best model
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'avg_pnl': avg_pnl,
                    'win_rate': win_rate,
                    'episode': ep + 1,
                }, os.path.join(config['output_dir'], 'exit_optimizer_best.pt'))

    print(f"\nBest Avg PnL: {best_avg_pnl:.2f}%")

    return model
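
# Illustrative sketch for restoring a saved checkpoint (the path layout and
# the 'model_state_dict' key match the torch.save calls in this file):
def load_exit_optimizer(path: str, config: dict = CONFIG) -> ExitOptimizer:
    model = ExitOptimizer(input_dim=8, hidden_dim=config['exit_hidden_dim']).to(device)
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # deterministic inference (disables dropout)
    return model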


# ============================================================
# STEP 5: BACKTEST
# ============================================================

def backtest_system(df: pd.DataFrame, vol_gate_model, long_model, short_model,
                   exit_model, vol_scaler, trend_scaler,
                   vol_features, trend_features, config: dict):
    """Full system backtest."""
    print("\n" + "=" * 60)
    print("STEP 5: FULL SYSTEM BACKTEST")
    print("=" * 60)

    # Use last 20% as test set
    split_idx = int(len(df) * config['train_split'])
    df_test = df.iloc[split_idx:].copy()

    print(f"Test period: {df_test.index[0]} to {df_test.index[-1]}")
    print(f"Test samples: {len(df_test)}")

    # Simulation
    capital = 10000
    position = 0  # 0 = flat, 1 = long, -1 = short
    entry_price = 0
    entry_idx = 0
    max_pnl_seen = 0.0  # Peak P&L of the open position
    trades = []
    equity_curve = [capital]

    confidence_threshold = config['confidence_threshold']
    fee = config['fee_rate']

    i = 0
    while i < len(df_test) - 1:
        row = df_test.iloc[i]

        if position == 0:  # No position - look for entry
            # Volatility Gate
            vol_X = vol_scaler.transform(row[vol_features].values.reshape(1, -1))
            vol_proba = vol_gate_model.predict(vol_X)[0]

            if vol_proba > 0.5:  # Gate open
                # Trend detection
                trend_X = trend_scaler.transform(row[trend_features].values.reshape(1, -1))
                long_proba = long_model.predict(trend_X)[0]
                short_proba = short_model.predict(trend_X)[0]

                if long_proba > confidence_threshold and long_proba > short_proba:
                    position = 1
                    entry_price = row['close']
                    entry_idx = i
                elif short_proba > confidence_threshold and short_proba > long_proba:
                    position = -1
                    entry_price = row['close']
                    entry_idx = i

        else:  # In position - check exit
            current_price = row['close']

            # Calculate P&L
            if position == 1:
                pnl = (current_price - entry_price) / entry_price
            else:
                pnl = (entry_price - current_price) / entry_price

            # Exit state (mirrors ExitEnvironment._get_state)
            max_pnl_seen = max(max_pnl_seen, pnl)
            exit_state = np.array([
                pnl,
                max_pnl_seen,            # Max profit seen this trade
                pnl - max_pnl_seen,      # Drawdown from peak
                (i - entry_idx) / 12,    # Time in trade
                row['vol_6'],
                row['ret_1'],
                row['rsi_norm'],
                position,
            ], dtype=np.float32)

            # Exit decision
            state_tensor = torch.FloatTensor(exit_state).unsqueeze(0).to(device)
            with torch.no_grad():
                action, _, _ = exit_model.get_action(state_tensor, deterministic=True)

            # Force exit after 12 periods (48h)
            if action.item() == 1 or (i - entry_idx) >= 12:
                # Close position
                net_pnl = pnl - fee
                capital *= (1 + net_pnl)

                trades.append({
                    'entry_time': df_test.index[entry_idx],
                    'exit_time': df_test.index[i],
                    'direction': 'LONG' if position == 1 else 'SHORT',
                    'entry_price': entry_price,
                    'exit_price': current_price,
                    'pnl': net_pnl,
                    'duration': i - entry_idx,
                })

                position = 0
                max_pnl_seen = 0.0

        equity_curve.append(capital)
        i += 1

    # Results
    trades_df = pd.DataFrame(trades)

    if len(trades_df) > 0:
        print(f"\n=== BACKTEST RESULTS ===")
        print(f"Total trades: {len(trades_df)}")
        print(f"Win rate: {(trades_df['pnl'] > 0).mean() * 100:.1f}%")
        print(f"Avg trade PnL: {trades_df['pnl'].mean() * 100:.2f}%")
        print(f"Best trade: {trades_df['pnl'].max() * 100:.2f}%")
        print(f"Worst trade: {trades_df['pnl'].min() * 100:.2f}%")
        print(f"Avg duration: {trades_df['duration'].mean():.1f} periods ({trades_df['duration'].mean()*4:.0f}h)")
        print(f"\nFinal capital: ${capital:,.2f} ({(capital/10000-1)*100:+.1f}%)")

        # Sharpe, annualized in this file's own convention (4h bars: 6 per day x 365 days)
        if len(trades_df) > 1:
            periods_per_year = 6 * 365
            sharpe = (trades_df['pnl'].mean() / (trades_df['pnl'].std() + 1e-10)
                      * np.sqrt(periods_per_year / trades_df['duration'].mean()))
            print(f"Sharpe ratio: {sharpe:.2f}")

        # Save trades
        trades_df.to_csv(os.path.join(config['output_dir'], 'backtest_trades.csv'), index=False)
    else:
        print("No trades executed!")

    return trades_df, equity_curve
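
# Companion metric sketch (illustrative; not called by main): maximum
# drawdown of the equity curve returned by backtest_system.
def max_drawdown(equity_curve: List[float]) -> float:
    """Worst peak-to-trough decline, as a negative fraction (e.g. -0.25 = -25%)."""
    peak, worst = float('-inf'), 0.0
    for value in equity_curve:
        peak = max(peak, value)
        worst = min(worst, value / peak - 1)
    return worst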


# ============================================================
# MAIN
# ============================================================

def main():
    """Main training pipeline."""

    # Step 1: Load and prepare data
    df = load_and_resample_data(CONFIG)
    df = compute_features(df, CONFIG)

    # Save processed data
    df.to_csv(os.path.join(CONFIG['output_dir'], 'data_4h_features.csv'))

    # Step 2: Train Volatility Gate
    vol_gate_model, vol_scaler, vol_features = train_volatility_gate(df, CONFIG)

    # Save volatility gate
    vol_gate_model.save_model(os.path.join(CONFIG['output_dir'], 'vol_gate_model.txt'))
    joblib.dump(vol_scaler, os.path.join(CONFIG['output_dir'], 'vol_scaler.pkl'))
    joblib.dump(vol_features, os.path.join(CONFIG['output_dir'], 'vol_features.pkl'))

    # Step 3: Train Trend Detector
    long_model, short_model, trend_scaler, trend_features = train_trend_detector(df, CONFIG)

    # Save trend models
    long_model.save_model(os.path.join(CONFIG['output_dir'], 'long_model.txt'))
    short_model.save_model(os.path.join(CONFIG['output_dir'], 'short_model.txt'))
    joblib.dump(trend_scaler, os.path.join(CONFIG['output_dir'], 'trend_scaler.pkl'))
    joblib.dump(trend_features, os.path.join(CONFIG['output_dir'], 'trend_features.pkl'))

    # Step 4: Train Exit Optimizer
    exit_model = train_exit_optimizer(df, CONFIG)

    # Save final exit model
    torch.save({
        'model_state_dict': exit_model.state_dict(),
        'config': CONFIG,
    }, os.path.join(CONFIG['output_dir'], 'exit_optimizer_final.pt'))

    # Step 5: Backtest
    trades_df, equity_curve = backtest_system(
        df, vol_gate_model, long_model, short_model, exit_model,
        vol_scaler, trend_scaler, vol_features, trend_features, CONFIG
    )

    # Save config
    with open(os.path.join(CONFIG['output_dir'], 'config.json'), 'w') as f:
        json.dump(CONFIG, f, indent=2)

    print("\n" + "=" * 60)
    print("TRAINING COMPLETE")
    print(f"Models saved to: {CONFIG['output_dir']}")
    print("=" * 60)


if __name__ == "__main__":
    main()
