#!/usr/bin/env python3
"""
BTC Trading Model V8 ENTERPRISE - String Theory Integration
============================================================

24-FEATURE ARCHITECTURE integrating:
1. Traditional technical features (18)
2. String Theory math (6 new):
   - Hurst Exponent (Differential Geometry)
   - Catastrophe Distance (Catastrophe Theory - René Thom)
   - Shannon Entropy (Information Theory)
   - Kelly Fraction (Ergodic Economics - Ole Peters)
   - Wasserstein Tension (Optimal Transport - Cédric Villani)
   - Surface Instability (Enhanced Surface)

KEY IMPROVEMENTS:
- Data augmentation with time reversal to simulate bear markets
- Realistic Kraken fees (0.4% taker + spread)
- Risk-adjusted reward function with a Sharpe bonus
- High entropy floor (0.015) to prevent policy collapse
- Multi-regime training

Author: Claude Code Enterprise
Date: January 2026
"""

import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Optional, List, Tuple
import warnings
warnings.filterwarnings('ignore')

# ============================================================
# CONFIGURATION V8 ENTERPRISE - 24 FEATURES
# ============================================================

CONFIG = {
    # Training
    "episodes": 1000,
    "steps_per_episode": 3000,
    "batch_size": 4096,
    "num_envs": 64,

    # Learning (conservative for stability)
    "learning_rate": 3e-5,
    "gamma": 0.995,
    "gae_lambda": 0.98,
    "clip_epsilon": 0.1,
    "value_coef": 0.5,
    "max_grad_norm": 0.5,

    # Entropy (CRITICAL - prevent policy collapse)
    "entropy_coef": 0.05,          # High initial
    "entropy_decay": 0.99995,      # Very slow decay
    "min_entropy_coef": 0.015,     # High floor

    # REALISTIC COSTS (Kraken verified)
    "trading_fee": 0.004,          # 0.4% taker fee
    "spread": 0.001,               # 0.1% typical spread
    "slippage": 0.0005,            # 0.05% slippage
    "min_profitable_move": 0.012,  # 1.2% min to be profitable after costs

    # Reward shaping (risk-adjusted)
    "pnl_scale": 100.0,            # Scale PnL reward
    "hold_bonus": 0.0003,          # Bonus for holding (reduce overtrading)
    "win_bonus": 0.03,             # Bonus for profitable trades
    "drawdown_penalty": 0.5,       # Strong penalty for drawdown
    "overtrading_penalty": 0.01,   # Penalty for frequent trading
    "sharpe_bonus": 0.1,           # Bonus for consistent returns

    # Architecture (24 features)
    "hidden_dim": 512,             # Larger for more features
    "num_heads": 8,
    "num_layers": 5,
    "dropout": 0.15,
    "lookback": 120,               # 2 hours of history
    "num_features": 24,            # STRING THEORY INTEGRATED

    # Output
    "output_dir": "/workspace/models_v8_enterprise",
    "save_every": 50,
    "log_every": 10,
    "validate_every": 100,
}
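
# Quick sanity check of the cost model (a minimal sketch: assumes one taker fee plus half
# the spread plus slippage is paid on each side of a round trip, matching how
# TradingEnvironmentV8Enterprise charges costs below).
_round_trip_cost = 2 * CONFIG["trading_fee"] + CONFIG["spread"] + 2 * CONFIG["slippage"]  # = 0.010 (1.0%)
assert CONFIG["min_profitable_move"] > _round_trip_cost, \
    "min_profitable_move should exceed the estimated round-trip cost"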

print("=" * 70)
print("BTC TRADING MODEL V8 ENTERPRISE - STRING THEORY INTEGRATION")
print(f"Started: {datetime.now()}")
print("=" * 70)

# Check GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected, training will be slow!")

# ============================================================
# STRING THEORY MATH MODULES (Ported from tradingstringhe)
# ============================================================

class StringTheoryMath:
    """
    Mathematical modules from String Theory Trading.
    Based on work by:
    - Cédric Villani (Optimal Transport, Fields Medal 2010)
    - René Thom (Catastrophe Theory)
    - Ole Peters (Ergodic Economics)
    - Shannon (Information Theory)
    - Riemann (Differential Geometry)
    """

    # ========== DIFFERENTIAL GEOMETRY ==========
    @staticmethod
    def hurst_exponent(prices: np.ndarray, max_lag: int = 20) -> float:
        """
        Calculate Hurst exponent using R/S analysis.

        H > 0.5: Trending (persistent)
        H = 0.5: Random walk
        H < 0.5: Mean-reverting (anti-persistent)

        Reference: Mandelbrot, B. (1971). "Analysis of long-run dependence in economics"
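
        Illustrative example (a sketch on synthetic data, not the project's dataset):
            rng = np.random.default_rng(0)
            random_walk = 100 * np.exp(np.cumsum(rng.normal(0, 0.01, 400)))
            h = StringTheoryMath.hurst_exponent(random_walk)
            # expected to land near 0.5 for a pure random walk (small-sample bias aside)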
        """
        if len(prices) < max_lag * 2:
            return 0.5  # Default to random walk

        lags = range(2, max_lag)
        tau = []
        rs_values = []

        for lag in lags:
            try:
                # Get subseries
                subseries = []
                for i in range(0, len(prices) - lag, lag):
                    subseries.append(prices[i:i+lag])

                if len(subseries) == 0:
                    continue

                # Calculate R/S for each subseries
                rs_list = []
                for series in subseries:
                    returns = np.diff(series) / series[:-1]
                    if len(returns) < 2:
                        continue

                    mean_adj = returns - np.mean(returns)
                    cumsum = np.cumsum(mean_adj)
                    R = np.max(cumsum) - np.min(cumsum)
                    S = np.std(returns)

                    if S > 0:
                        rs_list.append(R / S)

                if rs_list:
                    tau.append(lag)
                    rs_values.append(np.mean(rs_list))
            except Exception:
                continue

        if len(tau) < 3:
            return 0.5

        # Linear regression on log-log scale
        try:
            log_tau = np.log(tau)
            log_rs = np.log(rs_values)

            # Least-squares slope of log(R/S) vs log(lag)
            slope = np.polyfit(log_tau, log_rs, 1)[0]
            return np.clip(slope, 0, 1)
        except Exception:
            return 0.5

    @staticmethod
    def price_curvature(prices: np.ndarray, window: int = 20) -> np.ndarray:
        """
        Calculate local curvature of price path.

        High curvature = potential reversal point
        Low curvature = trending
        """
        n = len(prices)
        curvature = np.zeros(n)

        if n < window + 2:
            return curvature

        for i in range(window, n - 1):
            segment = prices[i-window:i+1]
            # Second derivative approximation
            first_diff = np.diff(segment)
            second_diff = np.diff(first_diff)

            # Curvature formula: |f''| / (1 + f'^2)^(3/2)
            if len(second_diff) > 0:
                avg_first = np.mean(np.abs(first_diff))
                avg_second = np.mean(np.abs(second_diff))
                denom = (1 + avg_first**2) ** 1.5
                curvature[i] = avg_second / (denom + 1e-10)

        return curvature

    # ========== CATASTROPHE THEORY (René Thom) ==========
    @staticmethod
    def detect_cusp_catastrophe(prices: np.ndarray, window: int = 30) -> Tuple[float, float]:
        """
        Detect cusp catastrophe using volatility and bias as control parameters.

        Cusp catastrophe potential: V(x) = x^4/4 + a*x^2/2 + b*x

        Returns:
        - distance_to_bifurcation: 0 = at catastrophe point, 1 = far
        - jump_probability: Probability of sudden jump
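
        Illustrative example (a sketch on synthetic data):
            rng = np.random.default_rng(1)
            prices = 100 * np.exp(np.cumsum(rng.normal(0, 0.02, 60)))
            dist, prob = StringTheoryMath.detect_cusp_catastrophe(prices)
            # dist lies in [0, 1) via tanh; prob is clipped to [0, 1]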
        """
        if len(prices) < window:
            return 1.0, 0.0

        recent = prices[-window:]
        returns = np.diff(recent) / recent[:-1]

        # Control parameter a: Volatility/tension
        volatility = np.std(returns)

        # Control parameter b: Directional bias
        mean_return = np.mean(returns)
        bias = mean_return / (volatility + 1e-10)

        # Cusp bifurcation set: 4a^3 + 27b^2 = 0
        # Distance to bifurcation
        a = volatility * 100  # Scale
        b = bias * 10

        # Normalized distance (0 = at bifurcation, 1 = far)
        bifurcation_measure = 4 * (a ** 3) + 27 * (b ** 2)
        distance = np.tanh(abs(bifurcation_measure) * 0.1)

        # Jump probability increases near bifurcation
        jump_prob = 1 - distance
        jump_prob *= min(1, volatility * 20)  # Scale by volatility

        return distance, np.clip(jump_prob, 0, 1)

    # ========== INFORMATION THEORY (Shannon) ==========
    @staticmethod
    def shannon_entropy(returns: np.ndarray, bins: int = 20) -> float:
        """
        Calculate Shannon entropy of return distribution.

        H(X) = -Σ p(x) * log(p(x))

        High entropy = chaotic/unpredictable market
        Low entropy = ordered/predictable market
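
        Illustrative example (a sketch on synthetic returns):
            rng = np.random.default_rng(2)
            noisy = rng.normal(0, 0.01, 200)   # spread across many bins
            calm = np.full(200, 1e-4)          # collapses into a single bin
            # shannon_entropy(noisy) should come out well above shannon_entropy(calm) (~0)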
        """
        if len(returns) < bins:
            return 0.5

        # Discretize returns into bins
        hist, _ = np.histogram(returns, bins=bins, density=True)
        hist = hist[hist > 0]  # Remove zero bins

        if len(hist) == 0:
            return 0.5

        # Normalize to probabilities
        hist = hist / hist.sum()

        # Shannon entropy
        entropy = -np.sum(hist * np.log(hist + 1e-10))

        # Normalize by max entropy (uniform distribution)
        max_entropy = np.log(bins)
        normalized = entropy / (max_entropy + 1e-10)

        return np.clip(normalized, 0, 1)

    @staticmethod
    def kl_divergence(returns: np.ndarray, bins: int = 20) -> float:
        """
        KL divergence from normal distribution.

        D_KL(P||Q) = Σ P(x) * log(P(x)/Q(x))

        High KL = returns are very non-normal (fat tails, skew)
        Low KL = returns are near-normal
        """
        if len(returns) < bins:
            return 0.0

        # Actual distribution
        hist, edges = np.histogram(returns, bins=bins, density=True)

        # Normal distribution with same mean/std
        mean, std = np.mean(returns), np.std(returns)
        if std < 1e-10:
            return 0.0

        # Evaluate normal PDF at bin centers
        centers = (edges[:-1] + edges[1:]) / 2
        normal_pdf = np.exp(-0.5 * ((centers - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

        # Normalize
        hist = hist / (hist.sum() + 1e-10)
        normal_pdf = normal_pdf / (normal_pdf.sum() + 1e-10)

        # KL divergence (with smoothing to avoid log(0))
        epsilon = 1e-10
        kl = np.sum(hist * np.log((hist + epsilon) / (normal_pdf + epsilon)))

        return np.clip(kl, 0, 5)  # Cap at 5 for normalization

    # ========== ERGODIC ECONOMICS (Ole Peters) ==========
    @staticmethod
    def kelly_fraction(returns: np.ndarray) -> float:
        """
        Calculate Kelly fraction for optimal position sizing.

        f* = μ / σ² (continuous version)

        Reference: Peters, O. (2019). "The ergodicity problem in economics"
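
        Worked example (assumed numbers, for illustration only):
            mean return μ = 0.00002 per bar and variance σ² = 0.0001 give
            f* = 0.00002 / 0.0001 = 0.2, which sits inside the ±0.25 cap applied below.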
        """
        if len(returns) < 10:
            return 0.0

        mean_return = np.mean(returns)
        variance = np.var(returns)

        if variance < 1e-10:
            return 0.0

        # Kelly fraction
        kelly = mean_return / variance

        # Cap at 25% (professional standard)
        return np.clip(kelly, -0.25, 0.25)

    @staticmethod
    def ergodic_penalty(returns: np.ndarray) -> float:
        """
        Calculate the ergodic penalty: σ²/2

        This is what you "lose" to volatility in multiplicative processes.
        """
        volatility = np.std(returns)
        return (volatility ** 2) / 2

    @staticmethod
    def time_average_growth_rate(returns: np.ndarray) -> float:
        """
        Calculate time-average growth rate (what actually matters).

        g = E[log(1 + r)] ≈ μ - σ²/2
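
        Worked example (the classic ergodicity illustration, assumed returns):
            returns = [+0.5, -0.5] has arithmetic mean 0, yet
            g = (log(1.5) + log(0.5)) / 2 ≈ -0.144, so the sequence shrinks wealth over time.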
        """
        if len(returns) < 2:
            return 0.0

        # Exact calculation
        growth_rates = []
        for r in returns:
            if 1 + r > 0:
                growth_rates.append(np.log(1 + r))
            else:
                growth_rates.append(-10)  # Ruin penalty

        return np.mean(growth_rates)

    # ========== OPTIMAL TRANSPORT (Cédric Villani) ==========
    @staticmethod
    def wasserstein_tension(prices: np.ndarray, window: int = 30) -> float:
        """
        Calculate Wasserstein-inspired tension from price distribution changes.

        High tension = distribution is far from equilibrium
        Low tension = market is in equilibrium
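
        The approximation used below is the 1-D quantile form of Wasserstein-1:
            W1(P, Q) ≈ mean(|sort(x_P) - sort(x_Q)|)
        which equals the exact W1 distance for two equally sized, equally weighted samples.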
        """
        if len(prices) < window * 2:
            return 0.5

        # Compare recent vs older distribution
        recent = prices[-window:]
        older = prices[-window*2:-window]

        # Calculate returns distributions
        recent_returns = np.diff(recent) / recent[:-1]
        older_returns = np.diff(older) / older[:-1]

        # Simple Wasserstein-1 approximation using sorted values
        sorted_recent = np.sort(recent_returns)
        sorted_older = np.sort(older_returns)

        # Interpolate to same length if needed
        min_len = min(len(sorted_recent), len(sorted_older))
        sorted_recent = sorted_recent[:min_len]
        sorted_older = sorted_older[:min_len]

        # Wasserstein-1 distance
        w1_distance = np.mean(np.abs(sorted_recent - sorted_older))

        # Normalize to 0-1
        tension = np.tanh(w1_distance * 100)

        return tension

    # ========== ENHANCED SURFACE ==========
    @staticmethod
    def surface_instability(prices: np.ndarray, window: int = 30) -> float:
        """
        Calculate market surface instability score.

        Combines multiple tension sources:
        - Volatility tension
        - Momentum tension
        - Mean reversion tension
        """
        if len(prices) < window:
            return 0.5

        recent = prices[-window:]
        returns = np.diff(recent) / recent[:-1]

        # Volatility component
        vol = np.std(returns)
        vol_tension = np.tanh(vol * 50)

        # Momentum component (acceleration)
        if len(returns) > 5:
            momentum = np.mean(returns[-5:]) - np.mean(returns[:-5])
            mom_tension = np.tanh(abs(momentum) * 200)
        else:
            mom_tension = 0

        # Mean reversion component (distance from SMA)
        sma = np.mean(recent)
        deviation = abs(recent[-1] - sma) / sma
        mr_tension = np.tanh(deviation * 20)

        # Combine (weighted average)
        instability = 0.4 * vol_tension + 0.35 * mom_tension + 0.25 * mr_tension

        return np.clip(instability, 0, 1)


# ============================================================
# FEATURE ENGINEERING - 24 FEATURES
# ============================================================

class FeatureEngineeringEnterprise:
    """
    Compute 24 advanced trading features including String Theory modules.

    FEATURE BREAKDOWN:

    TRADITIONAL TECHNICAL (18):
    - Returns (5): 1m, 5m, 15m, 30m, 60m
    - Volatility (4): 5m, 15m, 30m, 60m
    - Trend (3): SMA cross 10/30, 30/60, trend strength
    - Momentum (3): RSI, ROC, acceleration
    - Regime (3): vol regime, trend regime, mean reversion z-score

    STRING THEORY (6):
    - Hurst exponent (Geometry)
    - Catastrophe distance (Catastrophe Theory)
    - Shannon entropy (Information Theory)
    - Kelly fraction (Ergodic Economics)
    - Wasserstein tension (Optimal Transport)
    - Surface instability (Enhanced Surface)
    """

    @staticmethod
    def compute_features(prices: np.ndarray, volumes: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Compute all 24 features from price history.
        """
        n = len(prices)
        num_features = CONFIG["num_features"]  # 24
        features = np.zeros((n, num_features))

        # === RETURNS (5 features: 0-4) ===
        for i, period in enumerate([1, 5, 15, 30, 60]):
            if i < 5 and n > period:
                ret = np.zeros(n)
                ret[period:] = (prices[period:] - prices[:-period]) / prices[:-period]
                features[:, i] = np.clip(ret * 10, -3, 3)

        # === VOLATILITY (4 features: 5-8) ===
        for i, period in enumerate([5, 15, 30, 60]):
            idx = 5 + i
            if idx < 9 and n > period:
                vol = np.zeros(n)
                for j in range(period, n):
                    returns = np.diff(prices[j-period:j]) / prices[j-period:j-1]
                    if len(returns) > 0:
                        vol[j] = np.std(returns) * np.sqrt(252 * 24 * 60)
                features[:, idx] = np.clip(vol / 100, 0, 3)

        # === TREND (3 features: 9-11) ===
        if n >= 30:
            sma_10 = pd.Series(prices).rolling(10).mean().values
            sma_30 = pd.Series(prices).rolling(30).mean().values
            features[:, 9] = np.nan_to_num((sma_10 - sma_30) / sma_30 * 10)

        if n >= 60:
            sma_60 = pd.Series(prices).rolling(60).mean().values
            features[:, 10] = np.nan_to_num((sma_30 - sma_60) / sma_60 * 10)

        if n >= 30:
            features[:, 11] = np.nan_to_num((prices - sma_30) / sma_30 * 10)

        # === MOMENTUM (3 features: 12-14) ===
        # RSI-like
        if n > 14:
            delta = np.diff(prices, prepend=prices[0])
            gain = np.where(delta > 0, delta, 0)
            loss = np.where(delta < 0, -delta, 0)
            avg_gain = pd.Series(gain).rolling(14).mean().values
            avg_loss = pd.Series(loss).rolling(14).mean().values
            rs = np.nan_to_num(avg_gain / (avg_loss + 1e-10))
            rsi = 100 - 100 / (1 + rs)
            features[:, 12] = (rsi - 50) / 50

        # Rate of change
        if n > 10:
            roc = np.zeros(n)
            roc[10:] = (prices[10:] - prices[:-10]) / prices[:-10]
            features[:, 13] = np.clip(roc * 20, -3, 3)

        # Momentum acceleration
        if n > 5:
            mom = np.zeros(n)
            ret_1 = np.diff(prices, prepend=prices[0]) / prices
            mom[5:] = ret_1[5:] - ret_1[:-5]
            features[:, 14] = np.clip(mom * 100, -3, 3)

        # === REGIME (3 features: 15-17) ===
        # Volatility regime
        if n > 30:
            vol_short = pd.Series(prices).pct_change().rolling(10).std().values
            vol_long = pd.Series(prices).pct_change().rolling(30).std().values
            features[:, 15] = np.nan_to_num((vol_short - vol_long) / (vol_long + 1e-10))

        # Trend regime (high-low range)
        if n > 20:
            high_low_range = np.zeros(n)
            for j in range(20, n):
                high_low_range[j] = (np.max(prices[j-20:j]) - np.min(prices[j-20:j])) / prices[j]
            features[:, 16] = np.clip(high_low_range * 20, 0, 3)

        # Mean reversion z-score
        if n > 30:
            z_score = np.zeros(n)
            for j in range(30, n):
                mean = np.mean(prices[j-30:j])
                std = np.std(prices[j-30:j])
                if std > 0:
                    z_score[j] = (prices[j] - mean) / std
            features[:, 17] = np.clip(z_score / 2, -2, 2)

        # ============================================
        # STRING THEORY FEATURES (6 features: 18-23)
        # ============================================

        window = 30

        for i in range(window, n):
            price_window = prices[max(0, i-60):i]
            returns = np.diff(price_window) / price_window[:-1] if len(price_window) > 1 else np.array([0])

            # Feature 18: Hurst Exponent (Geometry)
            hurst = StringTheoryMath.hurst_exponent(price_window)
            features[i, 18] = (hurst - 0.5) * 2  # Normalize: -1 (mean-rev) to +1 (trending)

            # Feature 19: Catastrophe Distance (Catastrophe Theory)
            cat_dist, cat_prob = StringTheoryMath.detect_cusp_catastrophe(price_window)
            features[i, 19] = 1 - cat_dist  # 0 = far, 1 = at catastrophe

            # Feature 20: Shannon Entropy (Information Theory)
            entropy = StringTheoryMath.shannon_entropy(returns)
            features[i, 20] = entropy * 2 - 1  # Normalize to -1 (low) to +1 (high)

            # Feature 21: Kelly Fraction (Ergodic Economics)
            kelly = StringTheoryMath.kelly_fraction(returns)
            features[i, 21] = kelly * 4  # Scale: -1 to +1 range

            # Feature 22: Wasserstein Tension (Optimal Transport)
            w_tension = StringTheoryMath.wasserstein_tension(price_window)
            features[i, 22] = w_tension * 2 - 1  # -1 (low) to +1 (high)

            # Feature 23: Surface Instability
            instability = StringTheoryMath.surface_instability(price_window)
            features[i, 23] = instability * 2 - 1  # -1 (stable) to +1 (unstable)

        # Replace NaN with 0
        features = np.nan_to_num(features, nan=0.0, posinf=3.0, neginf=-3.0)

        return features.astype(np.float32)
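

# Illustrative usage of the feature pipeline (a sketch; the synthetic prices below are
# an assumption, not the project's dataset):
#   prices = 40000 + np.cumsum(np.random.normal(0, 20, 5000))
#   feats = FeatureEngineeringEnterprise.compute_features(prices)
#   feats.shape  # -> (5000, 24); NaN/inf already sanitized by compute_features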


# ============================================================
# DATA AUGMENTATION (Bear Market Simulation)
# ============================================================

class DataAugmentation:
    """Augment price data for multi-regime training."""

    @staticmethod
    def time_reversal(prices: np.ndarray) -> np.ndarray:
        """Reverse time to simulate bear market from bull market."""
        return prices[::-1].copy()

    @staticmethod
    def add_noise(prices: np.ndarray, noise_std: float = 0.001) -> np.ndarray:
        """Add small random noise to prices."""
        noise = np.random.normal(0, prices.std() * noise_std, len(prices))
        return prices + noise

    @staticmethod
    def scale_prices(prices: np.ndarray, scale_range: Tuple[float, float] = (0.8, 1.2)) -> np.ndarray:
        """Scale prices to different levels."""
        scale = np.random.uniform(*scale_range)
        return prices * scale

    @staticmethod
    def create_augmented_dataset(prices: np.ndarray, num_augments: int = 5) -> List[np.ndarray]:
        """
        Create multiple augmented versions for robust training.

        Returns:
        - Original data
        - Time-reversed (bear market)
        - Multiple scaled/noisy versions
        """
        datasets = [prices]  # Original

        # Time reversal (CRITICAL for bear market)
        datasets.append(DataAugmentation.time_reversal(prices))

        # Time-reversed with noise
        reversed_noisy = DataAugmentation.add_noise(DataAugmentation.time_reversal(prices))
        datasets.append(reversed_noisy)

        # Scaled versions
        for _ in range(num_augments - 2):
            aug = DataAugmentation.scale_prices(prices)
            aug = DataAugmentation.add_noise(aug)
            datasets.append(aug)

        return datasets


# ============================================================
# MODEL ARCHITECTURE - TRANSFORMER V8 ENTERPRISE
# ============================================================

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 0:
            pe[0, :, 1::2] = torch.cos(position * div_term)
        else:
            pe[0, :, 1::2] = torch.cos(position * div_term[:-1])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, :x.size(1)]


class TradingTransformerV8Enterprise(nn.Module):
    """
    Enterprise Transformer with 24 features + uncertainty estimation.

    Outputs:
    - action_mean: Trading signal (-1 to 1)
    - action_log_std: Uncertainty
    - value: State value
    - expected_return: Predicted % return
    - regime: Market regime classification
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

        input_dim = config.get("num_features", 24) + 1  # +1 for position
        hidden_dim = config.get("hidden_dim", 512)
        lookback = config.get("lookback", 120)

        # Input projection with layer norm
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(config.get("dropout", 0.15) / 2),
        )

        # Positional encoding
        self.pos_encoder = PositionalEncoding(hidden_dim, lookback)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=config.get("num_heads", 8),
            dim_feedforward=hidden_dim * 4,
            dropout=config.get("dropout", 0.15),
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            config.get("num_layers", 5)
        )

        # Global attention pooling
        self.attention_pool = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Softmax(dim=1)
        )

        # Output heads
        head_hidden = hidden_dim // 2

        # Policy head (action + uncertainty)
        self.policy_head = nn.Sequential(
            nn.Linear(hidden_dim, head_hidden),
            nn.LayerNorm(head_hidden),
            nn.GELU(),
            nn.Dropout(config.get("dropout", 0.15)),
            nn.Linear(head_hidden, 2)  # mean, log_std
        )

        # Value head
        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim, head_hidden),
            nn.LayerNorm(head_hidden),
            nn.GELU(),
            nn.Dropout(config.get("dropout", 0.15)),
            nn.Linear(head_hidden, 1)
        )

        # Expected return head
        self.return_head = nn.Sequential(
            nn.Linear(hidden_dim, head_hidden),
            nn.LayerNorm(head_hidden),
            nn.GELU(),
            nn.Linear(head_hidden, 1)
        )

        # Regime classification head (5 regimes)
        self.regime_head = nn.Sequential(
            nn.Linear(hidden_dim, head_hidden),
            nn.LayerNorm(head_hidden),
            nn.GELU(),
            nn.Linear(head_hidden, 5)  # bullish, bearish, ranging, volatile, uncertain
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stable training."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # Project input
        x = self.input_proj(x)
        x = self.pos_encoder(x)

        # Transformer
        x = self.transformer(x)

        # Attention pooling over sequence
        attn_weights = self.attention_pool(x)
        x = torch.sum(x * attn_weights, dim=1)

        # Policy
        policy_out = self.policy_head(x)
        action_mean = torch.tanh(policy_out[:, 0:1])
        action_log_std = torch.clamp(policy_out[:, 1:2], -3.0, 0.5)

        # Value
        value = self.value_head(x)

        # Expected return
        expected_return = self.return_head(x) * 0.1

        # Regime
        regime_logits = self.regime_head(x)

        return action_mean, action_log_std, value, expected_return, regime_logits

    def get_action(self, x: torch.Tensor, deterministic: bool = True):
        """Get action for inference."""
        action_mean, action_log_std, value, expected_return, regime_logits = self.forward(x)

        if deterministic:
            return action_mean.squeeze(), value.squeeze(), expected_return.squeeze()

        action_std = torch.exp(action_log_std)
        dist = torch.distributions.Normal(action_mean, action_std)
        action = dist.sample()
        action = torch.clamp(action, -1, 1)

        return action.squeeze(), value.squeeze(), expected_return.squeeze()
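
# Minimal shape sanity check for the model above (a sketch, shown as comments so it does
# not run at import time; shapes assume the default CONFIG: lookback=120, 24 features + 1
# position channel -> input dim 25):
#   model = TradingTransformerV8Enterprise(CONFIG)
#   dummy = torch.zeros(2, CONFIG["lookback"], CONFIG["num_features"] + 1)
#   mean, log_std, value, exp_ret, regime = model(dummy)
#   # mean, log_std, value, exp_ret -> (2, 1); regime -> (2, 5)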


# ============================================================
# TRADING ENVIRONMENT V8 ENTERPRISE
# ============================================================

class TradingEnvironmentV8Enterprise:
    """
    Realistic trading environment with String Theory features.
    """

    def __init__(self, prices: np.ndarray, config: dict):
        self.original_prices = prices
        self.config = config
        self.lookback = config.get("lookback", 120)
        self.fee = config.get("trading_fee", 0.004)
        self.spread = config.get("spread", 0.001)
        self.slippage = config.get("slippage", 0.0005)
        self.min_profitable = config.get("min_profitable_move", 0.012)

        # Precompute 24 features
        print(f"  Computing 24 features for {len(prices)} prices...")
        self.features = FeatureEngineeringEnterprise.compute_features(prices)
        print(f"  Features shape: {self.features.shape}")

        self.reset()

    def reset(self, start_idx: Optional[int] = None) -> np.ndarray:
        """Reset environment."""
        max_start = len(self.original_prices) - self.config["steps_per_episode"] - self.lookback - 10

        if start_idx is None:
            self.start_idx = np.random.randint(self.lookback, max(self.lookback + 1, max_start))
        else:
            self.start_idx = start_idx

        self.idx = self.start_idx
        self.position = 0.0
        self.entry_price = 0.0
        self.equity = 100000.0
        self.initial_equity = self.equity
        self.peak_equity = self.equity
        self.total_fees = 0.0
        self.trades = 0
        self.wins = 0
        self.last_trade_idx = 0
        self.pnl_history = []
        self.equity_history = [self.equity]

        return self._get_observation()

    def _get_observation(self) -> np.ndarray:
        """Get observation with 24 features + position."""
        start = self.idx - self.lookback
        end = self.idx

        features = self.features[start:end].copy()

        # Add position as feature
        position_feature = np.full((self.lookback, 1), self.position)
        obs = np.concatenate([features, position_feature], axis=1)

        return obs

    def step(self, action: float) -> Tuple[np.ndarray, float, bool, dict]:
        """Execute trading step with risk-adjusted reward."""
        current_price = self.original_prices[self.idx]
        prev_equity = self.equity

        reward = 0.0
        trade_pnl = 0.0

        # Discretize action
        if action > 0.3:
            target_position = 1.0
        elif action < -0.3:
            target_position = -1.0
        else:
            target_position = 0.0

        # Execute trade. Record whether the position changes and the distance to the
        # previous trade before updating state; both feed the overtrading penalty below.
        position_changed = target_position != self.position
        steps_since_last_trade = self.idx - self.last_trade_idx
        if position_changed:
            # Close existing position
            if self.position != 0:
                if self.position > 0:
                    exit_price = current_price * (1 - self.spread / 2 - self.slippage)
                    trade_pnl = (exit_price - self.entry_price) / self.entry_price * self.equity * abs(self.position)
                else:
                    exit_price = current_price * (1 + self.spread / 2 + self.slippage)
                    trade_pnl = (self.entry_price - exit_price) / self.entry_price * self.equity * abs(self.position)

                # Pay fee
                fee_paid = self.fee * self.equity * abs(self.position)
                self.total_fees += fee_paid
                self.equity += trade_pnl - fee_paid

                # Track
                self.trades += 1
                if trade_pnl - fee_paid > 0:
                    self.wins += 1
                    reward += self.config.get("win_bonus", 0.03)

                self.pnl_history.append(trade_pnl - fee_paid)

            # Open new position
            if target_position != 0:
                if target_position > 0:
                    self.entry_price = current_price * (1 + self.spread / 2 + self.slippage)
                else:
                    self.entry_price = current_price * (1 - self.spread / 2 - self.slippage)

                fee_paid = self.fee * self.equity * abs(target_position)
                self.total_fees += fee_paid
                self.equity -= fee_paid

                self.last_trade_idx = self.idx

            self.position = target_position
        else:
            # Hold bonus
            reward += self.config.get("hold_bonus", 0.0003)

        # Calculate current equity (including unrealized)
        if self.position != 0:
            if self.position > 0:
                unrealized = (current_price - self.entry_price) / self.entry_price * self.equity * self.position
            else:
                unrealized = (self.entry_price - current_price) / self.entry_price * self.equity * abs(self.position)
            current_equity = self.equity + unrealized
        else:
            current_equity = self.equity

        self.equity_history.append(current_equity)

        # Track peak and drawdown
        self.peak_equity = max(self.peak_equity, current_equity)
        drawdown = (self.peak_equity - current_equity) / self.peak_equity

        # === RISK-ADJUSTED REWARD ===

        # 1. Equity change (main reward)
        equity_change = (current_equity - prev_equity) / self.initial_equity
        reward += equity_change * self.config.get("pnl_scale", 100.0)

        # 2. Drawdown penalty (quadratic)
        if drawdown > 0.03:
            reward -= (drawdown ** 2) * self.config.get("drawdown_penalty", 0.5)

        # 3. Overtrading penalty (position changed again within 15 steps of the previous trade)
        if position_changed and self.trades > 0 and steps_since_last_trade < 15:
            reward -= self.config.get("overtrading_penalty", 0.01)

        # 4. Sharpe bonus (consistency)
        if len(self.equity_history) > 50:
            recent_returns = np.diff(self.equity_history[-50:]) / np.array(self.equity_history[-50:-1])
            if len(recent_returns) > 0 and np.std(recent_returns) > 0:
                sharpe = np.mean(recent_returns) / np.std(recent_returns)
                if sharpe > 0:
                    reward += sharpe * self.config.get("sharpe_bonus", 0.1)

        # Move forward
        self.idx += 1
        done = self.idx >= self.start_idx + self.config["steps_per_episode"] - 1
        done = done or self.idx >= len(self.original_prices) - 1
        done = done or current_equity < self.initial_equity * 0.5

        info = {
            "equity": current_equity,
            "position": self.position,
            "trades": self.trades,
            "wins": self.wins,
            "drawdown": drawdown,
            "fees": self.total_fees,
        }

        return self._get_observation(), reward, done, info


# ============================================================
# VECTORIZED ENVIRONMENT
# ============================================================

class VectorizedEnvV8Enterprise:
    """Vectorized environment for parallel training."""

    def __init__(self, prices_list: List[np.ndarray], config: dict, num_envs: int):
        self.num_envs = num_envs
        self.config = config

        print(f"\nCreating {num_envs} parallel environments...")
        self.envs = []
        for i in range(num_envs):
            prices = prices_list[i % len(prices_list)]
            self.envs.append(TradingEnvironmentV8Enterprise(prices, config))
        print(f"All environments created successfully!")

    def reset(self) -> np.ndarray:
        obs = [env.reset() for env in self.envs]
        return np.stack(obs)

    def step(self, actions: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[dict]]:
        results = [env.step(a) for env, a in zip(self.envs, actions)]

        obs = np.stack([r[0] for r in results])
        rewards = np.array([r[1] for r in results])
        dones = np.array([r[2] for r in results])
        infos = [r[3] for r in results]

        for i, done in enumerate(dones):
            if done:
                obs[i] = self.envs[i].reset()

        return obs, rewards, dones, infos


# ============================================================
# PPO TRAINER
# ============================================================

class PPOTrainerV8Enterprise:
    """PPO trainer with entropy control and learning rate scheduling."""

    def __init__(self, model: TradingTransformerV8Enterprise, config: dict):
        self.model = model
        self.config = config
        self.device = next(model.parameters()).device

        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config["learning_rate"],
            weight_decay=0.01,
            eps=1e-5
        )

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=100, T_mult=2
        )

        self.entropy_coef = config["entropy_coef"]

    def compute_gae(self, rewards, values, dones, next_value):
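        """
        Generalized Advantage Estimation (Schulman et al., 2016):

            delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
            returns = A_t + V(s_t)
        """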
        gamma = self.config["gamma"]
        gae_lambda = self.config["gae_lambda"]

        advantages = np.zeros_like(rewards)
        last_gae = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_val = next_value
            else:
                next_val = values[t + 1]

            delta = rewards[t] + gamma * next_val * (1 - dones[t]) - values[t]
            last_gae = delta + gamma * gae_lambda * (1 - dones[t]) * last_gae
            advantages[t] = last_gae

        returns = advantages + values
        return advantages, returns

    def update(self, rollout_data: dict) -> dict:
        obs = torch.FloatTensor(rollout_data["obs"]).to(self.device)
        actions = torch.FloatTensor(rollout_data["actions"]).to(self.device)
        old_log_probs = torch.FloatTensor(rollout_data["log_probs"]).to(self.device)
        advantages = torch.FloatTensor(rollout_data["advantages"]).to(self.device)
        returns = torch.FloatTensor(rollout_data["returns"]).to(self.device)

        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        batch_size = self.config["batch_size"]
        indices = np.arange(len(obs))
        np.random.shuffle(indices)

        total_policy_loss = 0
        total_value_loss = 0
        total_entropy = 0
        num_batches = 0

        for start in range(0, len(obs), batch_size):
            end = start + batch_size
            batch_idx = indices[start:end]

            batch_obs = obs[batch_idx]
            batch_actions = actions[batch_idx]
            batch_old_log_probs = old_log_probs[batch_idx]
            batch_advantages = advantages[batch_idx]
            batch_returns = returns[batch_idx]

            action_mean, action_log_std, values, _, _ = self.model(batch_obs)
            action_std = torch.exp(action_log_std)

            dist = torch.distributions.Normal(action_mean, action_std)
            new_log_probs = dist.log_prob(batch_actions.unsqueeze(-1)).squeeze(-1)
            entropy = dist.entropy().mean()

            ratio = torch.exp(new_log_probs - batch_old_log_probs)
            clip_epsilon = self.config["clip_epsilon"]
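            # PPO clipped surrogate objective (Schulman et al., 2017):
            #   L_clip = E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)]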

            surr1 = ratio * batch_advantages
            surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * batch_advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            value_loss = F.mse_loss(values.squeeze(-1), batch_returns)

            loss = (
                policy_loss
                + self.config["value_coef"] * value_loss
                - self.entropy_coef * entropy
            )

            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
            self.optimizer.step()

            total_policy_loss += policy_loss.item()
            total_value_loss += value_loss.item()
            total_entropy += entropy.item()
            num_batches += 1

        self.scheduler.step()

        self.entropy_coef = max(
            self.config["min_entropy_coef"],
            self.entropy_coef * self.config["entropy_decay"]
        )

        return {
            "policy_loss": total_policy_loss / num_batches,
            "value_loss": total_value_loss / num_batches,
            "entropy": total_entropy / num_batches,
            "entropy_coef": self.entropy_coef,
            "lr": self.optimizer.param_groups[0]["lr"],
        }


# ============================================================
# MAIN TRAINING LOOP
# ============================================================

def main():
    print("\n" + "=" * 70)
    print("BTC TRADING MODEL V8 ENTERPRISE - STRING THEORY INTEGRATION")
    print("=" * 70)
    print(f"Device: {device}")
    print(f"Episodes: {CONFIG['episodes']}")
    print(f"Features: {CONFIG['num_features']} (18 Technical + 6 String Theory)")
    print(f"Entropy coef: {CONFIG['entropy_coef']} → {CONFIG['min_entropy_coef']} (floor)")
    print(f"Trading fee: {CONFIG['trading_fee']} (REALISTIC)")
    print(f"Min profitable move: {CONFIG['min_profitable_move']}")
    print("=" * 70)

    # Load data
    data_path = "/workspace/prices_btc_2025.csv"
    if not os.path.exists(data_path):
        print(f"ERROR: Data file not found: {data_path}")
        sys.exit(1)

    print(f"\nLoading price data from {data_path}...")
    df = pd.read_csv(data_path)

    if 'price' in df.columns:
        prices = df['price'].values
    elif 'close' in df.columns:
        prices = df['close'].values
    else:
        prices = df.iloc[:, 1].values

    prices = prices.astype(np.float64)
    print(f"Loaded {len(prices)} price points")
    print(f"Price range: {prices.min():.2f} - {prices.max():.2f}")

    # Create augmented datasets
    print("\n" + "=" * 70)
    print("CREATING AUGMENTED DATASETS")
    print("=" * 70)
    augmented_prices = DataAugmentation.create_augmented_dataset(prices, num_augments=6)
    print(f"Created {len(augmented_prices)} datasets:")
    print("  - Original (bullish)")
    print("  - Time-reversed (bear market simulation)")
    print("  - Time-reversed with noise")
    print("  - Multiple scaled/noisy versions")

    # Create output directory
    os.makedirs(CONFIG["output_dir"], exist_ok=True)

    # Initialize model
    print("\n" + "=" * 70)
    print("INITIALIZING MODEL")
    print("=" * 70)
    model = TradingTransformerV8Enterprise(CONFIG).to(device)
    trainer = PPOTrainerV8Enterprise(model, CONFIG)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")

    # Create vectorized environment
    print("\n" + "=" * 70)
    print("CREATING ENVIRONMENTS")
    print("=" * 70)
    vec_env = VectorizedEnvV8Enterprise(augmented_prices, CONFIG, CONFIG["num_envs"])

    # Training state
    best_reward = float('-inf')
    best_sharpe = float('-inf')

    print("\n" + "=" * 70)
    print("STARTING TRAINING")
    print("=" * 70)

    for episode in range(CONFIG["episodes"]):
        # Collect rollout
        obs = vec_env.reset()
        rollout = {
            "obs": [],
            "actions": [],
            "rewards": [],
            "dones": [],
            "values": [],
            "log_probs": [],
        }

        episode_rewards = []
        episode_trades = []
        episode_wins = []

        for step in range(CONFIG["steps_per_episode"]):
            obs_tensor = torch.FloatTensor(obs).to(device)

            with torch.no_grad():
                action_mean, action_log_std, values, _, _ = model(obs_tensor)
                action_std = torch.exp(action_log_std)
                dist = torch.distributions.Normal(action_mean, action_std)
                actions = dist.sample()
                actions = torch.clamp(actions, -1, 1)
                log_probs = dist.log_prob(actions)

            actions_np = actions.squeeze(-1).cpu().numpy()
            next_obs, rewards, dones, infos = vec_env.step(actions_np)

            rollout["obs"].append(obs)
            rollout["actions"].append(actions_np)
            rollout["rewards"].append(rewards)
            rollout["dones"].append(dones)
            rollout["values"].append(values.squeeze(-1).cpu().numpy())
            rollout["log_probs"].append(log_probs.squeeze(-1).cpu().numpy())

            obs = next_obs

            for info in infos:
                episode_rewards.append(info.get("equity", 100000) / 100000 - 1)
                episode_trades.append(info.get("trades", 0))
                episode_wins.append(info.get("wins", 0))

        # Compute returns and advantages
        with torch.no_grad():
            obs_tensor = torch.FloatTensor(obs).to(device)
            _, _, next_values, _, _ = model(obs_tensor)
            next_values = next_values.squeeze(-1).cpu().numpy()

        for key in rollout:
            rollout[key] = np.stack(rollout[key])

        T, N = rollout["rewards"].shape
        advantages_all = np.zeros((T, N))
        returns_all = np.zeros((T, N))

        for env_idx in range(N):
            adv, ret = trainer.compute_gae(
                rollout["rewards"][:, env_idx],
                rollout["values"][:, env_idx],
                rollout["dones"][:, env_idx],
                next_values[env_idx]
            )
            advantages_all[:, env_idx] = adv
            returns_all[:, env_idx] = ret

        flat_rollout = {
            "obs": rollout["obs"].reshape(-1, *rollout["obs"].shape[2:]),
            "actions": rollout["actions"].flatten(),
            "log_probs": rollout["log_probs"].flatten(),
            "advantages": advantages_all.flatten(),
            "returns": returns_all.flatten(),
        }

        train_stats = trainer.update(flat_rollout)

        # Metrics
        mean_reward = np.mean(episode_rewards)
        mean_trades = np.mean(episode_trades)
        mean_wins = np.mean(episode_wins)
        win_rate = mean_wins / max(mean_trades, 1)

        if len(episode_rewards) > 1:
            sharpe = np.mean(episode_rewards) / (np.std(episode_rewards) + 1e-8) * np.sqrt(252)
        else:
            sharpe = 0

        # Logging
        if episode % CONFIG["log_every"] == 0:
            pnl = mean_reward * 100000
            print(f"Ep {episode:4d} | Reward: {mean_reward:8.4f} | PnL: {pnl:10.0f} | "
                  f"Trades: {mean_trades:5.1f} | WR: {win_rate:5.2f} | "
                  f"Entropy: {train_stats['entropy']:.4f} | Coef: {train_stats['entropy_coef']:.4f}")

        # Save best model
        if mean_reward > best_reward:
            best_reward = mean_reward
            save_path = os.path.join(CONFIG["output_dir"], "model_best.pt")
            torch.save({
                "model_state_dict": model.state_dict(),
                "config": CONFIG,
                "episode": episode,
                "best_reward": float(best_reward),
                "entropy": float(train_stats['entropy']),
                "win_rate": float(win_rate),
            }, save_path)
            print(f"  -> New best model saved: reward={best_reward:.4f}")

        # Checkpoint
        if episode > 0 and episode % CONFIG["save_every"] == 0:
            save_path = os.path.join(CONFIG["output_dir"], f"model_ep{episode}.pt")
            torch.save({
                "model_state_dict": model.state_dict(),
                "config": CONFIG,
                "episode": episode,
                "reward": float(mean_reward),
                "entropy": float(train_stats['entropy']),
            }, save_path)

    # Final model
    save_path = os.path.join(CONFIG["output_dir"], "model_final.pt")
    torch.save({
        "model_state_dict": model.state_dict(),
        "config": CONFIG,
        "episode": CONFIG["episodes"],
        "best_reward": float(best_reward),
    }, save_path)
    print(f"\nFinal model saved: {save_path}")

    print("\n" + "=" * 70)
    print("TRAINING COMPLETED")
    print(f"Best reward: {best_reward:.4f}")
    print("=" * 70)


if __name__ == "__main__":
    main()
