#!/usr/bin/env python3
"""
BTC Trading Model V8 - Professional Multi-Regime Ensemble
==========================================================

Key Improvements over V7:
1. Data augmentation (time reversal for bear market simulation)
2. Realistic fees (0.4% taker + spread)
3. Risk-adjusted reward function
4. Higher entropy floor to prevent collapse
5. Multi-regime training with specialist models
6. Uncertainty estimation via ensemble

Author: Claude Code
Date: January 2026
"""

import os
import sys
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Tuple, Any
from dataclasses import dataclass
from collections import deque
import warnings
warnings.filterwarnings('ignore')

# ============================================================
# CONFIGURATION V8
# ============================================================

CONFIG = {
    # Training
    "episodes": 800,
    "steps_per_episode": 2500,
    "batch_size": 4096,
    "num_envs": 48,

    # Learning (more conservative)
    "learning_rate": 5e-5,
    "gamma": 0.995,
    "gae_lambda": 0.97,
    "clip_epsilon": 0.15,
    "value_coef": 0.5,
    "max_grad_norm": 0.5,

    # Entropy (CRITICAL - prevent collapse)
    "entropy_coef": 0.05,          # Higher initial
    "entropy_decay": 0.99995,      # Much slower decay
    "min_entropy_coef": 0.015,     # Higher floor

    # REALISTIC COSTS (Kraken)
    "trading_fee": 0.004,          # 0.4% taker fee
    "spread": 0.001,               # 0.1% typical spread
    "slippage": 0.0005,            # 0.05% slippage
    "min_profitable_move": 0.012,  # 1.2% min to be profitable

    # Reward shaping
    "pnl_scale": 50.0,             # Scale PnL reward
    "hold_bonus": 0.0002,          # Bonus for holding (reduce overtrading)
    "win_bonus": 0.02,             # Bonus for profitable trades
    "drawdown_penalty": 0.3,       # Penalty for drawdown
    "overtrading_penalty": 0.005,  # Penalty for frequent trading

    # Architecture (larger)
    "hidden_dim": 384,
    "num_heads": 6,
    "num_layers": 4,
    "dropout": 0.12,
    "lookback": 90,                # 90 minutes of history
    "num_features": 18,            # More features

    # Output
    "output_dir": "/workspace/models_v8",
    "save_every": 50,
    "log_every": 10,
    "validate_every": 50,
}
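
# Illustrative sanity check (an editorial sketch, not part of the original
# pipeline; the helper name _ROUND_TRIP_COST is hypothetical): a full round trip
# pays the taker fee, half the spread, and slippage on both entry and exit,
# i.e. roughly 2 * (0.004 + 0.001 / 2 + 0.0005) = 1.0%, so the 1.2%
# "min_profitable_move" leaves a small margin on top of costs.
_ROUND_TRIP_COST = 2 * (CONFIG["trading_fee"] + CONFIG["spread"] / 2 + CONFIG["slippage"])
assert CONFIG["min_profitable_move"] > _ROUND_TRIP_COST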

print("=" * 60)
print("BTC TRADING MODEL V8 - MULTI-REGIME PROFESSIONAL")
print(f"Started: {datetime.now()}")
print("=" * 60)

# Check GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected, training will be slow!")

# ============================================================
# FEATURE ENGINEERING
# ============================================================

class FeatureEngineering:
    """Compute advanced trading features."""

    @staticmethod
    def compute_features(prices: np.ndarray, volumes: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Compute 18 features from price history.

        Features:
        0-4: Returns at different timeframes
        5-8: Volatility measures
        9-11: Trend indicators
        12-14: Momentum indicators
        15-17: Market regime indicators
        """
        n = len(prices)
        num_features = CONFIG["num_features"]
        features = np.zeros((n, num_features))

        # === RETURNS (5 features) ===
        for i, period in enumerate([1, 5, 15, 30, 60]):
            if n > period:
                ret = np.zeros(n)
                ret[period:] = (prices[period:] - prices[:-period]) / prices[:-period]
                features[:, i] = np.clip(ret * 10, -3, 3)  # Scale and clip

        # === VOLATILITY (4 features) ===
        for i, period in enumerate([5, 15, 30, 60]):
            idx = 5 + i
            if n > period:
                vol = np.zeros(n)
                for j in range(period, n):
                    returns = np.diff(prices[j-period:j]) / prices[j-period:j-1]
                    vol[j] = np.std(returns) * np.sqrt(252 * 24 * 60)  # Annualized
                features[:, idx] = np.clip(vol / 100, 0, 3)  # Scale

        # === TREND (3 features) ===
        # SMA crossover 10/30
        if n >= 30:
            sma_10 = pd.Series(prices).rolling(10).mean().values
            sma_30 = pd.Series(prices).rolling(30).mean().values
            features[:, 9] = np.nan_to_num((sma_10 - sma_30) / sma_30 * 10)

        # SMA crossover 30/60
        if n >= 60:
            sma_60 = pd.Series(prices).rolling(60).mean().values
            features[:, 10] = np.nan_to_num((sma_30 - sma_60) / sma_60 * 10)

        # Trend strength (price vs SMA)
        if n >= 30:
            features[:, 11] = np.nan_to_num((prices - sma_30) / sma_30 * 10)

        # === MOMENTUM (3 features) ===
        # RSI-like
        if n > 14:
            delta = np.diff(prices, prepend=prices[0])
            gain = np.where(delta > 0, delta, 0)
            loss = np.where(delta < 0, -delta, 0)
            avg_gain = pd.Series(gain).rolling(14).mean().values
            avg_loss = pd.Series(loss).rolling(14).mean().values
            rs = np.nan_to_num(avg_gain / (avg_loss + 1e-10))
            rsi = 100 - 100 / (1 + rs)
            features[:, 12] = (rsi - 50) / 50  # Normalize to -1, 1

        # Rate of change
        if n > 10:
            roc = np.zeros(n)
            roc[10:] = (prices[10:] - prices[:-10]) / prices[:-10]
            features[:, 13] = np.clip(roc * 20, -3, 3)

        # Momentum (price acceleration)
        if n > 5:
            mom = np.zeros(n)
            ret_1 = np.diff(prices, prepend=prices[0]) / prices
            mom[5:] = ret_1[5:] - ret_1[:-5]
            features[:, 14] = np.clip(mom * 100, -3, 3)

        # === REGIME (3 features) ===
        # Volatility regime
        if n > 30:
            vol_short = pd.Series(prices).pct_change().rolling(10).std().values
            vol_long = pd.Series(prices).pct_change().rolling(30).std().values
            features[:, 15] = np.nan_to_num((vol_short - vol_long) / (vol_long + 1e-10))

        # Trend regime (ADX-like)
        if n > 20:
            high_low_range = np.zeros(n)
            for j in range(20, n):
                high_low_range[j] = (np.max(prices[j-20:j]) - np.min(prices[j-20:j])) / prices[j]
            features[:, 16] = np.clip(high_low_range * 20, 0, 3)

        # Mean reversion regime
        if n > 30:
            z_score = np.zeros(n)
            for j in range(30, n):
                mean = np.mean(prices[j-30:j])
                std = np.std(prices[j-30:j])
                if std > 0:
                    z_score[j] = (prices[j] - mean) / std
            features[:, 17] = np.clip(z_score / 2, -2, 2)

        # Replace NaN with 0
        features = np.nan_to_num(features, nan=0.0, posinf=3.0, neginf=-3.0)

        return features.astype(np.float32)
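

# Illustrative usage sketch (editorial addition; the function name and the
# synthetic random-walk prices are hypothetical, and it is never called by the
# training pipeline, which computes features inside the environment).
def _feature_smoke_test() -> None:
    demo_prices = 30000.0 * np.exp(np.cumsum(np.random.normal(0, 1e-3, 500)))
    demo_features = FeatureEngineering.compute_features(demo_prices)
    assert demo_features.shape == (500, CONFIG["num_features"])
    assert np.isfinite(demo_features).all()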


# ============================================================
# DATA AUGMENTATION
# ============================================================

class DataAugmentation:
    """Augment price data for more robust training."""

    @staticmethod
    def time_reversal(prices: np.ndarray) -> np.ndarray:
        """Reverse time to simulate bear market from bull market."""
        return prices[::-1].copy()

    @staticmethod
    def add_noise(prices: np.ndarray, noise_std: float = 0.001) -> np.ndarray:
        """Add small random noise to prices."""
        noise = np.random.normal(0, prices.std() * noise_std, len(prices))
        return prices + noise

    @staticmethod
    def scale_prices(prices: np.ndarray, scale_range: Tuple[float, float] = (0.8, 1.2)) -> np.ndarray:
        """Scale prices to different levels."""
        scale = np.random.uniform(*scale_range)
        return prices * scale

    @staticmethod
    def create_augmented_dataset(prices: np.ndarray, num_augments: int = 3) -> List[np.ndarray]:
        """Create multiple augmented versions of the data."""
        datasets = [prices]  # Original

        # Time reversal (bear market simulation)
        datasets.append(DataAugmentation.time_reversal(prices))

        # Scaled versions
        for _ in range(num_augments - 1):
            aug = DataAugmentation.scale_prices(prices)
            aug = DataAugmentation.add_noise(aug)
            datasets.append(aug)

        return datasets
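

# Example (illustrative): with the call used later in main(),
#     DataAugmentation.create_augmented_dataset(prices, num_augments=4)
# returns 5 arrays: the original series, its time reversal, and 3 scaled/noised
# copies.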


# ============================================================
# MODEL ARCHITECTURE
# ============================================================

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 0:
            pe[0, :, 1::2] = torch.cos(position * div_term)
        else:
            pe[0, :, 1::2] = torch.cos(position * div_term[:-1])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, :x.size(1)]


class TradingTransformerV8(nn.Module):
    """
    Enhanced Transformer for trading with uncertainty estimation.

    Output:
    - action_mean: Trading signal (-1 to 1)
    - action_log_std: Uncertainty in the signal
    - value: State value estimation
    - expected_return: Predicted return for the trade
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

        input_dim = config.get("num_features", 18) + 1  # +1 for position
        hidden_dim = config.get("hidden_dim", 384)
        lookback = config.get("lookback", 90)

        # Input projection
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )

        # Positional encoding
        self.pos_encoder = PositionalEncoding(hidden_dim, lookback)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=config.get("num_heads", 6),
            dim_feedforward=hidden_dim * 4,
            dropout=config.get("dropout", 0.12),
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            config.get("num_layers", 4)
        )

        # Output heads
        head_hidden = hidden_dim // 2

        # Policy head (action + uncertainty)
        self.policy_head = nn.Sequential(
            nn.Linear(hidden_dim, head_hidden),
            nn.LayerNorm(head_hidden),
            nn.GELU(),
            nn.Dropout(config.get("dropout", 0.12)),
            nn.Linear(head_hidden, 2)  # mean, log_std
        )

        # Value head
        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim, head_hidden),
            nn.LayerNorm(head_hidden),
            nn.GELU(),
            nn.Dropout(config.get("dropout", 0.12)),
            nn.Linear(head_hidden, 1)
        )

        # Expected return head (auxiliary output; not supervised by the PPO loss in this script)
        self.return_head = nn.Sequential(
            nn.Linear(hidden_dim, head_hidden),
            nn.LayerNorm(head_hidden),
            nn.GELU(),
            nn.Linear(head_hidden, 1)
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize weights with small values for stable training."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # Project input
        x = self.input_proj(x)
        x = self.pos_encoder(x)

        # Transformer
        x = self.transformer(x)

        # Take last timestep
        x = x[:, -1, :]

        # Policy
        policy_out = self.policy_head(x)
        action_mean = torch.tanh(policy_out[:, 0:1])
        action_log_std = torch.clamp(policy_out[:, 1:2], -2.5, 0.5)

        # Value
        value = self.value_head(x)

        # Expected return
        expected_return = self.return_head(x) * 0.1  # Scale to reasonable %

        return action_mean, action_log_std, value, expected_return

    def get_action(self, x: torch.Tensor, deterministic: bool = True):
        """Get action for inference."""
        action_mean, action_log_std, value, expected_return = self.forward(x)

        if deterministic:
            return action_mean.squeeze(), value.squeeze(), expected_return.squeeze()

        action_std = torch.exp(action_log_std)
        dist = torch.distributions.Normal(action_mean, action_std)
        action = dist.sample()
        action = torch.clamp(action, -1, 1)

        return action.squeeze(), value.squeeze(), expected_return.squeeze()
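

# Illustrative shape check (editorial addition; the function name is hypothetical
# and it is never called by the training pipeline).
def _model_smoke_test() -> None:
    model = TradingTransformerV8(CONFIG)
    x = torch.zeros(2, CONFIG["lookback"], CONFIG["num_features"] + 1)
    with torch.no_grad():
        action_mean, action_log_std, value, expected_return = model(x)
    assert action_mean.shape == (2, 1) and value.shape == (2, 1)
    assert action_mean.abs().max() <= 1.0  # tanh-bounded signal
    assert action_log_std.min() >= -2.5 and action_log_std.max() <= 0.5  # clamped log_std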


# ============================================================
# TRADING ENVIRONMENT
# ============================================================

class TradingEnvironmentV8:
    """
    Realistic trading environment with proper fee handling.
    """

    def __init__(self, prices: np.ndarray, config: dict):
        self.original_prices = prices
        self.config = config
        self.lookback = config.get("lookback", 90)
        self.fee = config.get("trading_fee", 0.004)
        self.spread = config.get("spread", 0.001)
        self.slippage = config.get("slippage", 0.0005)
        self.min_profitable = config.get("min_profitable_move", 0.012)

        # Precompute features
        self.features = FeatureEngineering.compute_features(prices)

        # Environment state
        self.reset()

    def reset(self, start_idx: Optional[int] = None) -> np.ndarray:
        """Reset environment to initial state."""
        max_start = len(self.original_prices) - self.config["steps_per_episode"] - self.lookback - 10

        if start_idx is None:
            self.start_idx = np.random.randint(self.lookback, max(self.lookback + 1, max_start))
        else:
            self.start_idx = start_idx

        self.idx = self.start_idx
        self.position = 0.0
        self.entry_price = 0.0
        self.equity = 100000.0
        self.initial_equity = self.equity
        self.peak_equity = self.equity
        self.prev_mtm_equity = self.equity  # mark-to-market equity at the previous step
        self.total_fees = 0.0
        self.trades = 0
        self.wins = 0
        self.last_trade_idx = 0
        self.pnl_history = []

        return self._get_observation()

    def _get_observation(self) -> np.ndarray:
        """Get current observation (features + position)."""
        start = self.idx - self.lookback
        end = self.idx

        features = self.features[start:end].copy()

        # Add position as last feature
        position_feature = np.full((self.lookback, 1), self.position)
        obs = np.concatenate([features, position_feature], axis=1)

        return obs

    def step(self, action: float) -> Tuple[np.ndarray, float, bool, dict]:
        """
        Execute one step in the environment.

        Action: -1 (full short) to +1 (full long), 0 = no position
        """
        current_price = self.original_prices[self.idx]
        prev_equity = self.prev_mtm_equity  # mark-to-market equity from the previous step

        reward = 0.0
        trade_pnl = 0.0
        fee_paid = 0.0

        # Discretize action
        if action > 0.3:
            target_position = 1.0
        elif action < -0.3:
            target_position = -1.0
        else:
            target_position = 0.0

        # Execute trade if the position changes (flags captured before
        # self.position and last_trade_idx are updated below)
        position_changed = target_position != self.position
        steps_since_last_trade = self.idx - self.last_trade_idx
        if position_changed:
            # Close existing position
            if self.position != 0:
                # Calculate PnL
                if self.position > 0:
                    exit_price = current_price * (1 - self.spread / 2 - self.slippage)
                    trade_pnl = (exit_price - self.entry_price) / self.entry_price * self.equity * abs(self.position)
                else:
                    exit_price = current_price * (1 + self.spread / 2 + self.slippage)
                    trade_pnl = (self.entry_price - exit_price) / self.entry_price * self.equity * abs(self.position)

                # Pay fee
                fee_paid = self.fee * self.equity * abs(self.position)
                self.total_fees += fee_paid

                # Update equity
                self.equity += trade_pnl - fee_paid

                # Track trade outcome
                self.trades += 1
                if trade_pnl - fee_paid > 0:
                    self.wins += 1
                    reward += self.config.get("win_bonus", 0.02)

                self.pnl_history.append(trade_pnl - fee_paid)

            # Open new position
            if target_position != 0:
                if target_position > 0:
                    self.entry_price = current_price * (1 + self.spread / 2 + self.slippage)
                else:
                    self.entry_price = current_price * (1 - self.spread / 2 - self.slippage)

                # Pay fee for opening
                fee_paid = self.fee * self.equity * abs(target_position)
                self.total_fees += fee_paid
                self.equity -= fee_paid

            self.position = target_position
            self.last_trade_idx = self.idx
        else:
            # Holding bonus (reduce overtrading)
            reward += self.config.get("hold_bonus", 0.0002)

        # Update unrealized PnL
        if self.position != 0:
            if self.position > 0:
                unrealized = (current_price - self.entry_price) / self.entry_price * self.equity * self.position
            else:
                unrealized = (self.entry_price - current_price) / self.entry_price * self.equity * abs(self.position)

            current_equity = self.equity + unrealized
        else:
            current_equity = self.equity

        # Track peak and drawdown
        self.peak_equity = max(self.peak_equity, current_equity)
        drawdown = (self.peak_equity - current_equity) / self.peak_equity

        # === COMPUTE REWARD ===
        # 1. Step change in mark-to-market equity (main reward)
        equity_change = (current_equity - prev_equity) / self.initial_equity
        self.prev_mtm_equity = current_equity
        reward += equity_change * self.config.get("pnl_scale", 50.0)

        # 2. Drawdown penalty
        if drawdown > 0.05:
            reward -= drawdown * self.config.get("drawdown_penalty", 0.3)

        # 3. Overtrading penalty: the position was changed again within 10 steps
        # of the previous trade (uses the pre-trade flags captured above)
        if position_changed and self.trades > 0 and steps_since_last_trade < 10:
            reward -= self.config.get("overtrading_penalty", 0.005)

        # Move to next timestep
        self.idx += 1
        done = self.idx >= self.start_idx + self.config["steps_per_episode"] - 1
        done = done or self.idx >= len(self.original_prices) - 1
        done = done or current_equity < self.initial_equity * 0.5  # Stop if 50% loss

        info = {
            "equity": current_equity,
            "position": self.position,
            "trades": self.trades,
            "wins": self.wins,
            "drawdown": drawdown,
            "fees": self.total_fees,
        }

        return self._get_observation(), reward, done, info
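

# Illustrative environment smoke test (editorial addition; the function name and
# the synthetic random-walk prices are hypothetical, and it is never called by
# the training pipeline).
def _env_smoke_test() -> None:
    rng = np.random.default_rng(0)
    n_points = CONFIG["steps_per_episode"] + CONFIG["lookback"] + 20
    demo_prices = 30000.0 * np.exp(np.cumsum(rng.normal(0, 1e-3, n_points)))
    env = TradingEnvironmentV8(demo_prices, CONFIG)
    obs = env.reset()
    assert obs.shape == (CONFIG["lookback"], CONFIG["num_features"] + 1)
    obs, reward, done, info = env.step(0.5)  # above 0.3 -> discretized to full long
    assert info["position"] == 1.0 and info["equity"] > 0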


# ============================================================
# VECTORIZED ENVIRONMENT
# ============================================================

class VectorizedEnvV8:
    """Vectorized environment for parallel training."""

    def __init__(self, prices_list: List[np.ndarray], config: dict, num_envs: int):
        self.num_envs = num_envs
        self.config = config

        # Create environments with different price datasets
        self.envs = []
        for i in range(num_envs):
            prices = prices_list[i % len(prices_list)]
            self.envs.append(TradingEnvironmentV8(prices, config))

    def reset(self) -> np.ndarray:
        """Reset all environments."""
        obs = [env.reset() for env in self.envs]
        return np.stack(obs)

    def step(self, actions: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[dict]]:
        """Step all environments."""
        results = [env.step(a) for env, a in zip(self.envs, actions)]

        obs = np.stack([r[0] for r in results])
        rewards = np.array([r[1] for r in results])
        dones = np.array([r[2] for r in results])
        infos = [r[3] for r in results]

        # Auto-reset done environments
        for i, done in enumerate(dones):
            if done:
                obs[i] = self.envs[i].reset()

        return obs, rewards, dones, infos


# ============================================================
# PPO TRAINER
# ============================================================

class PPOTrainerV8:
    """PPO trainer with improved stability."""

    def __init__(self, model: TradingTransformerV8, config: dict):
        self.model = model
        self.config = config
        self.device = next(model.parameters()).device

        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=0.01,
            eps=1e-5
        )

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=100, T_mult=2
        )

        self.entropy_coef = config["entropy_coef"]

    def compute_gae(self, rewards, values, dones, next_value):
        """Compute Generalized Advantage Estimation."""
        gamma = self.config["gamma"]
        gae_lambda = self.config["gae_lambda"]

        advantages = np.zeros_like(rewards)
        last_gae = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_val = next_value
            else:
                next_val = values[t + 1]

            delta = rewards[t] + gamma * next_val * (1 - dones[t]) - values[t]
            last_gae = delta + gamma * gae_lambda * (1 - dones[t]) * last_gae
            advantages[t] = last_gae

        returns = advantages + values
        return advantages, returns

    def update(self, rollout_data: dict) -> dict:
        """Update policy using PPO."""
        obs = torch.FloatTensor(rollout_data["obs"]).to(self.device)
        actions = torch.FloatTensor(rollout_data["actions"]).to(self.device)
        old_log_probs = torch.FloatTensor(rollout_data["log_probs"]).to(self.device)
        advantages = torch.FloatTensor(rollout_data["advantages"]).to(self.device)
        returns = torch.FloatTensor(rollout_data["returns"]).to(self.device)

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Mini-batch updates
        batch_size = self.config["batch_size"]
        indices = np.arange(len(obs))
        np.random.shuffle(indices)

        total_policy_loss = 0
        total_value_loss = 0
        total_entropy = 0
        num_batches = 0

        for start in range(0, len(obs), batch_size):
            end = start + batch_size
            batch_idx = indices[start:end]

            batch_obs = obs[batch_idx]
            batch_actions = actions[batch_idx]
            batch_old_log_probs = old_log_probs[batch_idx]
            batch_advantages = advantages[batch_idx]
            batch_returns = returns[batch_idx]

            # Forward pass
            action_mean, action_log_std, values, _ = self.model(batch_obs)
            action_std = torch.exp(action_log_std)

            # Compute new log probs
            dist = torch.distributions.Normal(action_mean, action_std)
            new_log_probs = dist.log_prob(batch_actions.unsqueeze(-1)).squeeze(-1)
            entropy = dist.entropy().mean()

            # PPO loss
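            # Clipped surrogate objective:
            #     L_policy = -E[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ]
            # where r_t = exp(new_log_probs - old_log_probs) is the probability ratio.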
            ratio = torch.exp(new_log_probs - batch_old_log_probs)
            clip_epsilon = self.config["clip_epsilon"]

            surr1 = ratio * batch_advantages
            surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * batch_advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            # Value loss
            value_loss = F.mse_loss(values.squeeze(-1), batch_returns)

            # Total loss
            loss = (
                policy_loss
                + self.config["value_coef"] * value_loss
                - self.entropy_coef * entropy
            )

            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
            self.optimizer.step()

            total_policy_loss += policy_loss.item()
            total_value_loss += value_loss.item()
            total_entropy += entropy.item()
            num_batches += 1

        self.scheduler.step()

        # Decay entropy coefficient
        self.entropy_coef = max(
            self.config["min_entropy_coef"],
            self.entropy_coef * self.config["entropy_decay"]
        )

        return {
            "policy_loss": total_policy_loss / num_batches,
            "value_loss": total_value_loss / num_batches,
            "entropy": total_entropy / num_batches,
            "entropy_coef": self.entropy_coef,
        }


# ============================================================
# MAIN TRAINING LOOP
# ============================================================

def main():
    print("\n" + "=" * 60)
    print("BTC TRADING MODEL V8 - TRAINING")
    print("=" * 60)
    print(f"Device: {device}")
    print(f"Episodes: {CONFIG['episodes']}")
    print(f"Entropy coef: {CONFIG['entropy_coef']} (CRITICAL FIX)")
    print(f"Trading fee: {CONFIG['trading_fee']} (REALISTIC)")
    print(f"Min profitable move: {CONFIG['min_profitable_move']}")
    print("=" * 60)

    # Load data
    data_path = "/workspace/prices_btc_2025.csv"
    if not os.path.exists(data_path):
        print(f"ERROR: Data file not found: {data_path}")
        sys.exit(1)

    print(f"\nLoading price data from {data_path}...")
    df = pd.read_csv(data_path)

    if 'price' in df.columns:
        prices = df['price'].values
    elif 'close' in df.columns:
        prices = df['close'].values
    else:
        prices = df.iloc[:, 1].values

    prices = prices.astype(np.float64)
    print(f"Loaded {len(prices)} price points")
    print(f"Price range: {prices.min():.2f} - {prices.max():.2f}")

    # Create augmented datasets
    print("\nCreating augmented datasets...")
    augmented_prices = DataAugmentation.create_augmented_dataset(prices, num_augments=4)
    print(f"Created {len(augmented_prices)} datasets (including reversed for bear market)")

    # Create output directory
    os.makedirs(CONFIG["output_dir"], exist_ok=True)

    # Initialize model
    model = TradingTransformerV8(CONFIG).to(device)
    trainer = PPOTrainerV8(model, CONFIG)

    # Create vectorized environment
    vec_env = VectorizedEnvV8(augmented_prices, CONFIG, CONFIG["num_envs"])

    # Training state
    best_reward = float('-inf')
    best_sharpe = float('-inf')

    print("\n" + "=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)

    for episode in range(CONFIG["episodes"]):
        # Collect rollout
        obs = vec_env.reset()
        rollout = {
            "obs": [],
            "actions": [],
            "rewards": [],
            "dones": [],
            "values": [],
            "log_probs": [],
        }

        episode_rewards = []
        episode_trades = []
        episode_wins = []

        for step in range(CONFIG["steps_per_episode"]):
            obs_tensor = torch.FloatTensor(obs).to(device)

            with torch.no_grad():
                action_mean, action_log_std, values, _ = model(obs_tensor)
                action_std = torch.exp(action_log_std)
                dist = torch.distributions.Normal(action_mean, action_std)
                actions = dist.sample()
                actions = torch.clamp(actions, -1, 1)
                log_probs = dist.log_prob(actions)

            actions_np = actions.squeeze(-1).cpu().numpy()
            next_obs, rewards, dones, infos = vec_env.step(actions_np)

            rollout["obs"].append(obs)
            rollout["actions"].append(actions_np)
            rollout["rewards"].append(rewards)
            rollout["dones"].append(dones)
            rollout["values"].append(values.squeeze(-1).cpu().numpy())
            rollout["log_probs"].append(log_probs.squeeze(-1).cpu().numpy())

            obs = next_obs

            for info in infos:
                episode_rewards.append(info.get("equity", 100000) / 100000 - 1)
                episode_trades.append(info.get("trades", 0))
                episode_wins.append(info.get("wins", 0))

        # Compute returns and advantages
        with torch.no_grad():
            obs_tensor = torch.FloatTensor(obs).to(device)
            _, _, next_values, _ = model(obs_tensor)
            next_values = next_values.squeeze(-1).cpu().numpy()

        # Stack rollout data
        for key in rollout:
            rollout[key] = np.stack(rollout[key])

        # Reshape for GAE
        T, N = rollout["rewards"].shape
        advantages_all = np.zeros((T, N))
        returns_all = np.zeros((T, N))

        for env_idx in range(N):
            adv, ret = trainer.compute_gae(
                rollout["rewards"][:, env_idx],
                rollout["values"][:, env_idx],
                rollout["dones"][:, env_idx],
                next_values[env_idx]
            )
            advantages_all[:, env_idx] = adv
            returns_all[:, env_idx] = ret

        # Flatten for training
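        # (each flattened array holds T * N = steps_per_episode * num_envs samples;
        # observations keep their trailing (lookback, num_features + 1) shape)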
        flat_rollout = {
            "obs": rollout["obs"].reshape(-1, *rollout["obs"].shape[2:]),
            "actions": rollout["actions"].flatten(),
            "log_probs": rollout["log_probs"].flatten(),
            "advantages": advantages_all.flatten(),
            "returns": returns_all.flatten(),
        }

        # Update policy
        train_stats = trainer.update(flat_rollout)

        # Compute metrics
        mean_reward = np.mean(episode_rewards)
        mean_trades = np.mean(episode_trades)
        mean_wins = np.mean(episode_wins)
        win_rate = mean_wins / max(mean_trades, 1)

        # Calculate Sharpe-like metric
        if len(episode_rewards) > 1:
            sharpe = np.mean(episode_rewards) / (np.std(episode_rewards) + 1e-8) * np.sqrt(252)
        else:
            sharpe = 0

        # Logging
        if episode % CONFIG["log_every"] == 0:
            pnl = mean_reward * 100000
            print(f"Episode {episode:4d} | Reward: {mean_reward:8.4f} | PnL: {pnl:10.2f} | "
                  f"Trades: {mean_trades:5.1f} | WinRate: {win_rate:5.2f} | "
                  f"Entropy: {train_stats['entropy']:.4f} | Coef: {train_stats['entropy_coef']:.4f}")

        # Save best model
        if mean_reward > best_reward:
            best_reward = mean_reward
            save_path = os.path.join(CONFIG["output_dir"], "model_best.pt")
            torch.save({
                "model_state_dict": model.state_dict(),
                "config": CONFIG,
                "episode": episode,
                "best_reward": best_reward,
                "entropy": train_stats['entropy'],
            }, save_path)
            print(f"  -> New best model saved: {save_path}")

        # Save checkpoint
        if episode > 0 and episode % CONFIG["save_every"] == 0:
            save_path = os.path.join(CONFIG["output_dir"], f"model_ep{episode}.pt")
            torch.save({
                "model_state_dict": model.state_dict(),
                "config": CONFIG,
                "episode": episode,
                "reward": mean_reward,
                "entropy": train_stats['entropy'],
            }, save_path)

    # Save final model
    save_path = os.path.join(CONFIG["output_dir"], "model_final.pt")
    torch.save({
        "model_state_dict": model.state_dict(),
        "config": CONFIG,
        "episode": CONFIG["episodes"],
        "best_reward": best_reward,
    }, save_path)
    print(f"\nFinal model saved: {save_path}")

    print("\n" + "=" * 60)
    print("TRAINING COMPLETED")
    print(f"Best reward: {best_reward:.4f}")
    print("=" * 60)


if __name__ == "__main__":
    main()
