#!/usr/bin/env python3
"""
BTC Trading Model V8 FAST - Pre-computed Features
==================================================

CRITICAL OPTIMIZATION:
Features are pre-computed ONCE per dataset instead of once per environment.
Savings: ~2 hours of feature computation.

Author: Claude Code
Date: January 2026
"""

import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Optional, List, Tuple
import warnings
warnings.filterwarnings('ignore')

# Mixed precision
try:
    from torch.cuda.amp import GradScaler, autocast
    USE_AMP = torch.cuda.is_available()  # AMP is only useful on a CUDA device
except ImportError:
    USE_AMP = False

# ============================================================
# CONFIG
# ============================================================

CONFIG = {
    "episodes": 1000,
    "steps_per_episode": 3000,
    "batch_size": 2048,        # Reduced for VRAM
    "num_envs": 32,            # Reduced for VRAM
    "learning_rate": 5e-5,
    "warmup_episodes": 50,
    "gamma": 0.995,
    "gae_lambda": 0.98,
    "clip_epsilon": 0.1,
    "value_coef": 0.5,
    "max_grad_norm": 0.3,
    "entropy_coef": 0.05,
    "entropy_decay": 0.99995,
    "min_entropy_coef": 0.015,
    "trading_fee": 0.004,
    "spread": 0.001,
    "slippage": 0.0005,
    "pnl_scale": 100.0,
    "hold_bonus": 0.0003,
    "win_bonus": 0.03,
    "drawdown_penalty": 0.5,
    "overtrading_penalty": 0.01,
    "sharpe_bonus": 0.1,
    "hidden_dim": 256,         # Reduced for VRAM (was 512)
    "num_heads": 4,            # Reduced for VRAM (was 8)
    "num_layers": 3,           # Reduced for VRAM (was 5)
    "dropout": 0.18,
    "lookback": 60,            # Reduced for VRAM (was 120)
    "num_features": 24,
    "weight_decay": 0.02,
    "gradient_accumulation": 4, # Increased to compensate smaller batch
    "output_dir": "/workspace/models_v8_fast",
    "save_every": 50,
    "log_every": 10,
}

print("=" * 70)
print("BTC TRADING MODEL V8 FAST - PRE-COMPUTED FEATURES")
print(f"Started: {datetime.now()}")
print("=" * 70)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"Mixed Precision: {USE_AMP}")

# ============================================================
# STRING THEORY MATH (FAST)
# ============================================================

class StringTheoryFast:
    """Optimized String Theory calculations."""

    @staticmethod
    def hurst_fast(prices: np.ndarray) -> float:
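        # Single-window rescaled-range (R/S) estimate of the Hurst exponent:
        # H ~ log(R/S) / log(n), clipped to [0.1, 0.9]. 0.5 ~ random walk,
        # >0.5 trending, <0.5 mean-reverting.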
        n = len(prices)
        if n < 20:
            return 0.5
        returns = np.diff(np.log(prices + 1e-10))
        if len(returns) < 10:
            return 0.5
        mean_adj = returns - np.mean(returns)
        cumsum = np.cumsum(mean_adj)
        R = np.max(cumsum) - np.min(cumsum)
        S = np.std(returns)
        if S < 1e-10 or R < 1e-10:
            return 0.5
        return np.clip(np.log(R / S) / np.log(n), 0.1, 0.9)

    @staticmethod
    def catastrophe_fast(prices: np.ndarray) -> float:
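        # Heuristic regime-stress score: tanh of recent volatility plus absolute
        # directional bias, in [0, 1); higher values flag jump-prone regimes.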
        if len(prices) < 20:
            return 1.0
        returns = np.diff(prices[-20:]) / prices[-20:-1]
        vol = np.std(returns)
        bias = np.mean(returns) / (vol + 1e-10)
        return np.tanh(vol * 50 + abs(bias) * 2.5)

    @staticmethod
    def entropy_fast(returns: np.ndarray) -> float:
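        # Shannon entropy of the return histogram, normalized by log(bins) to [0, 1];
        # low values indicate a concentrated (more predictable) return distribution.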
        if len(returns) < 10:
            return 0.5
        hist, _ = np.histogram(returns, bins=15, density=True)
        hist = hist[hist > 0]
        if len(hist) == 0:
            return 0.5
        hist = hist / hist.sum()
        entropy = -np.sum(hist * np.log(hist + 1e-10))
        return np.clip(entropy / np.log(15), 0, 1)

    @staticmethod
    def kelly_fast(returns: np.ndarray) -> float:
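        # Kelly-style position fraction: mean return / variance of returns,
        # clipped to a conservative +/-0.15.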
        if len(returns) < 10:
            return 0.0
        mean = np.mean(returns)
        var = np.var(returns)
        if var < 1e-10:
            return 0.0
        return np.clip(mean / var, -0.15, 0.15)

    @staticmethod
    def wasserstein_fast(prices: np.ndarray) -> float:
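        # Quantile-based proxy for distribution shift between the last two 20-step
        # return windows (a rough stand-in for a Wasserstein distance), tanh-squashed.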
        if len(prices) < 40:
            return 0.5
        r1 = np.diff(prices[-20:]) / prices[-20:-1]
        r2 = np.diff(prices[-40:-20]) / prices[-40:-21]
        q1 = np.percentile(r1, [25, 50, 75])
        q2 = np.percentile(r2, [25, 50, 75])
        return np.tanh(np.mean(np.abs(q1 - q2)) * 100)

    @staticmethod
    def instability_fast(prices: np.ndarray) -> float:
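        # Blend of tanh-squashed volatility (60%) and the change in return from the
        # start to the end of the window (40%), clipped to [0, 1].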
        if len(prices) < 20:
            return 0.5
        returns = np.diff(prices[-20:]) / prices[-20:-1]
        vol = np.tanh(np.std(returns) * 50)
        mom = np.tanh(abs(returns[-1] - returns[0]) * 100) if len(returns) > 1 else 0
        return np.clip(0.6 * vol + 0.4 * mom, 0, 1)


# ============================================================
# FEATURE ENGINEERING (VECTORIZED)
# ============================================================

def compute_all_features(prices: np.ndarray) -> np.ndarray:
    """
    Compute 24 features for entire price array at once.
    VECTORIZED for speed.
    """
    n = len(prices)
    features = np.zeros((n, 24), dtype=np.float32)
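    # Feature layout: 0-4 multi-horizon returns, 5-8 annualized rolling volatilities,
    # 9-11 SMA trend spreads, 12-14 momentum (RSI, ROC, return change),
    # 15-17 regime (vol ratio, range, z-score), 18-23 String Theory block.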

    price_series = pd.Series(prices)
    pct_change = price_series.pct_change().values

    # Returns (0-4)
    for i, period in enumerate([1, 5, 15, 30, 60]):
        if n > period:
            ret = np.zeros(n)
            ret[period:] = (prices[period:] - prices[:-period]) / prices[:-period]
            features[:, i] = np.clip(ret * 10, -3, 3)

    # Volatility (5-8)
    for i, period in enumerate([5, 15, 30, 60]):
        if n > period:
            vol = price_series.pct_change().rolling(period).std().values * np.sqrt(525600)
            features[:, 5 + i] = np.nan_to_num(np.clip(vol / 100, 0, 3))

    # Trend (9-11)
    sma_10 = price_series.rolling(10).mean().values
    sma_30 = price_series.rolling(30).mean().values
    sma_60 = price_series.rolling(60).mean().values

    features[:, 9] = np.nan_to_num(np.clip((sma_10 - sma_30) / (sma_30 + 1e-10) * 10, -3, 3))
    features[:, 10] = np.nan_to_num(np.clip((sma_30 - sma_60) / (sma_60 + 1e-10) * 10, -3, 3))
    features[:, 11] = np.nan_to_num(np.clip((prices - sma_30) / (sma_30 + 1e-10) * 10, -3, 3))

    # Momentum (12-14)
    delta = np.diff(prices, prepend=prices[0])
    gain = np.where(delta > 0, delta, 0)
    loss = np.where(delta < 0, -delta, 0)
    avg_gain = pd.Series(gain).rolling(14).mean().values
    avg_loss = pd.Series(loss).rolling(14).mean().values
    rs = np.nan_to_num(avg_gain / (avg_loss + 1e-10))
    rsi = 100 - 100 / (1 + rs)
    features[:, 12] = np.clip((rsi - 50) / 50, -1, 1)

    roc = np.zeros(n)
    roc[10:] = (prices[10:] - prices[:-10]) / prices[:-10]
    features[:, 13] = np.clip(roc * 20, -3, 3)

    mom = np.zeros(n)
    mom[5:] = pct_change[5:] - pct_change[:-5]
    features[:, 14] = np.nan_to_num(np.clip(mom * 100, -3, 3))

    # Regime (15-17)
    vol_short = pd.Series(pct_change).rolling(10).std().values
    vol_long = pd.Series(pct_change).rolling(30).std().values
    features[:, 15] = np.nan_to_num(np.clip((vol_short - vol_long) / (vol_long + 1e-10), -3, 3))

    for j in range(20, n):
        features[j, 16] = np.clip((np.max(prices[j-20:j]) - np.min(prices[j-20:j])) / prices[j] * 20, 0, 3)

    for j in range(30, n):
        mean = np.mean(prices[j-30:j])
        std = np.std(prices[j-30:j])
        if std > 0:
            features[j, 17] = np.clip((prices[j] - mean) / std / 2, -2, 2)

    # String Theory (18-23) - computed per index over a rolling 60-step window
    print("    Computing String Theory features...")
    window = 60

    for i in range(window, n):
        if i % 50000 == 0:
            print(f"      Progress: {i}/{n} ({100*i/n:.1f}%)")

        pw = prices[i-window:i]
        returns = np.diff(pw) / pw[:-1]

        features[i, 18] = (StringTheoryFast.hurst_fast(pw) - 0.5) * 2
        features[i, 19] = 1 - StringTheoryFast.catastrophe_fast(pw)
        features[i, 20] = StringTheoryFast.entropy_fast(returns) * 2 - 1
        features[i, 21] = StringTheoryFast.kelly_fast(returns) * 6
        features[i, 22] = StringTheoryFast.wasserstein_fast(pw) * 2 - 1
        features[i, 23] = StringTheoryFast.instability_fast(pw) * 2 - 1

    print("    Features completed!")
    return np.nan_to_num(features, nan=0.0, posinf=3.0, neginf=-3.0)


# ============================================================
# DATA AUGMENTATION
# ============================================================

def create_augmented_datasets(prices: np.ndarray) -> List[np.ndarray]:
    """Create augmented price datasets."""
    datasets = [prices]
    datasets.append(prices[::-1].copy())  # Time reversal

    # Scaled versions
    for scale in [0.7, 0.85, 1.15, 1.3]:
        scaled = prices * scale
        noise = np.random.normal(0, scaled.std() * 0.0005, len(scaled))
        datasets.append(scaled + noise)

    return datasets


# ============================================================
# MODEL
# ============================================================

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
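        # Standard fixed sinusoidal positional encoding (sin on even dims, cos on
        # odd dims), registered as a non-trainable buffer.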
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 0:
            pe[0, :, 1::2] = torch.cos(position * div_term)
        else:
            pe[0, :, 1::2] = torch.cos(position * div_term[:-1])
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class TradingTransformerV8(nn.Module):
    def __init__(self, config):
        super().__init__()
        input_dim = config["num_features"] + 1
        hidden_dim = config["hidden_dim"]
        dropout = config["dropout"]

        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout / 2),
        )
        self.pos_encoder = PositionalEncoding(hidden_dim, config["lookback"])

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=config["num_heads"],
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, config["num_layers"])

        self.attention_pool = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Softmax(dim=1))

        head_hidden = hidden_dim // 2
        self.policy_head = nn.Sequential(
            nn.Linear(hidden_dim, head_hidden), nn.LayerNorm(head_hidden),
            nn.GELU(), nn.Dropout(dropout), nn.Linear(head_hidden, 2)
        )
        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim, head_hidden), nn.LayerNorm(head_hidden),
            nn.GELU(), nn.Dropout(dropout), nn.Linear(head_hidden, 1)
        )

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.input_proj(x)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        attn = self.attention_pool(x)
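        # Attention pooling: weight each timestep by its softmax score and sum over
        # the sequence dimension to get one vector per sample.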
        x = torch.sum(x * attn, dim=1)

        policy = self.policy_head(x)
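        # Gaussian policy parameterization: column 0 is the tanh-bounded action mean,
        # column 1 is the log standard deviation (clamped for numerical stability).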
        action_mean = torch.tanh(policy[:, 0:1])
        action_log_std = torch.clamp(policy[:, 1:2], -3.0, 0.5)
        value = self.value_head(x)

        return action_mean, action_log_std, value


# ============================================================
# ENVIRONMENT (uses pre-computed features)
# ============================================================

class TradingEnvFast:
    """Trading environment with pre-computed features."""

    def __init__(self, features: np.ndarray, prices: np.ndarray, config: dict):
        self.features = features  # Pre-computed!
        self.prices = prices
        self.config = config
        self.lookback = config["lookback"]
        self.fee = config["trading_fee"]
        self.spread = config["spread"]
        self.slippage = config["slippage"]
        self.reset()

    def reset(self, start_idx: Optional[int] = None) -> np.ndarray:
        max_start = len(self.prices) - self.config["steps_per_episode"] - self.lookback - 10
        if start_idx is None:
            self.start_idx = np.random.randint(self.lookback, max(self.lookback + 1, max_start))
        else:
            self.start_idx = start_idx

        self.idx = self.start_idx
        self.position = 0.0
        self.entry_price = 0.0
        self.equity = 100000.0
        self.initial_equity = self.equity
        self.peak_equity = self.equity
        self.total_fees = 0.0
        self.trades = 0
        self.wins = 0
        self.last_trade_idx = 0
        self.equity_history = [self.equity]

        return self._get_obs()

    def _get_obs(self) -> np.ndarray:
        start = self.idx - self.lookback
        features = self.features[start:self.idx].copy()
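        # Append the current position as a constant extra column so the policy
        # observes its own state (num_features + 1 inputs per timestep).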
        position_feature = np.full((self.lookback, 1), self.position, dtype=np.float32)
        return np.concatenate([features, position_feature], axis=1)

    def step(self, action: float) -> Tuple[np.ndarray, float, bool, dict]:
        price = self.prices[self.idx]
        prev_equity = self.equity
        reward = 0.0

        target = 1.0 if action > 0.3 else (-1.0 if action < -0.3 else 0.0)

        if target != self.position:
            if self.position != 0:
                if self.position > 0:
                    exit_price = price * (1 - self.spread / 2 - self.slippage)
                    pnl = (exit_price - self.entry_price) / self.entry_price * self.equity
                else:
                    exit_price = price * (1 + self.spread / 2 + self.slippage)
                    pnl = (self.entry_price - exit_price) / self.entry_price * self.equity

                fee = self.fee * self.equity
                self.total_fees += fee
                self.equity += pnl - fee
                self.trades += 1
                if pnl - fee > 0:
                    self.wins += 1
                    reward += self.config["win_bonus"]

            if target != 0:
                if target > 0:
                    self.entry_price = price * (1 + self.spread / 2 + self.slippage)
                else:
                    self.entry_price = price * (1 - self.spread / 2 - self.slippage)
                fee = self.fee * self.equity
                self.total_fees += fee
                self.equity -= fee
                self.last_trade_idx = self.idx

            self.position = target
        else:
            reward += self.config["hold_bonus"]

        # Current equity
        if self.position != 0:
            if self.position > 0:
                unrealized = (price - self.entry_price) / self.entry_price * self.equity
            else:
                unrealized = (self.entry_price - price) / self.entry_price * self.equity
            current_equity = self.equity + unrealized
        else:
            current_equity = self.equity

        self.equity_history.append(current_equity)
        self.peak_equity = max(self.peak_equity, current_equity)
        drawdown = (self.peak_equity - current_equity) / self.peak_equity

        # Reward shaping: scaled mark-to-market equity change, minus penalties for
        # deep drawdowns and overtrading, plus a bonus for a positive rolling Sharpe.
        equity_change = (current_equity - prev_equity) / self.initial_equity
        reward += equity_change * self.config["pnl_scale"]

        if drawdown > 0.03:
            reward -= (drawdown ** 2) * self.config["drawdown_penalty"]

        if self.trades > 0 and self.idx - self.last_trade_idx < 15:
            reward -= self.config["overtrading_penalty"]

        if len(self.equity_history) > 50:
            recent = np.diff(self.equity_history[-50:]) / np.array(self.equity_history[-50:-1])
            if np.std(recent) > 0:
                sharpe = np.mean(recent) / np.std(recent)
                if sharpe > 0:
                    reward += sharpe * self.config["sharpe_bonus"]

        self.idx += 1
        done = self.idx >= self.start_idx + self.config["steps_per_episode"] - 1
        done = done or self.idx >= len(self.prices) - 1
        done = done or current_equity < self.initial_equity * 0.5

        return self._get_obs(), reward, done, {
            "equity": current_equity, "trades": self.trades,
            "wins": self.wins, "drawdown": drawdown
        }


class VectorizedEnvFast:
    """Vectorized environment with shared pre-computed features."""

    def __init__(self, features_list: List[np.ndarray], prices_list: List[np.ndarray],
                 config: dict, num_envs: int):
        self.num_envs = num_envs
        self.envs = []

        print(f"\nCreating {num_envs} environments (features pre-computed)...")
        for i in range(num_envs):
            idx = i % len(features_list)
            self.envs.append(TradingEnvFast(features_list[idx], prices_list[idx], config))
        print(f"All {num_envs} environments ready!")

    def reset(self) -> np.ndarray:
        return np.stack([env.reset() for env in self.envs])

    def step(self, actions: np.ndarray):
        results = [env.step(a) for env, a in zip(self.envs, actions)]
        obs = np.stack([r[0] for r in results])
        rewards = np.array([r[1] for r in results])
        dones = np.array([r[2] for r in results])
        infos = [r[3] for r in results]

        for i, done in enumerate(dones):
            if done:
                obs[i] = self.envs[i].reset()

        return obs, rewards, dones, infos


# ============================================================
# PPO TRAINER
# ============================================================

class PPOTrainer:
    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.device = next(model.parameters()).device

        self.optimizer = torch.optim.AdamW(
            model.parameters(), lr=config["learning_rate"],
            weight_decay=config["weight_decay"], eps=1e-5
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=100, T_mult=2
        )
        self.entropy_coef = config["entropy_coef"]
        self.scaler = GradScaler() if USE_AMP else None
        self.grad_accum = config["gradient_accumulation"]

    def compute_gae(self, rewards, values, dones, next_value):
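        # Generalized Advantage Estimation:
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        #   A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        # Returns (advantages, returns) with returns = advantages + values.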
        gamma, gae_lambda = self.config["gamma"], self.config["gae_lambda"]
        advantages = np.zeros_like(rewards)
        last_gae = 0
        for t in reversed(range(len(rewards))):
            next_val = next_value if t == len(rewards) - 1 else values[t + 1]
            delta = rewards[t] + gamma * next_val * (1 - dones[t]) - values[t]
            last_gae = delta + gamma * gae_lambda * (1 - dones[t]) * last_gae
            advantages[t] = last_gae
        return advantages, advantages + values

    def update(self, rollout, episode):
        # LR warmup: ramp the learning rate over the first warmup_episodes updates.
        # After warmup, leave the LR to the cosine scheduler instead of overwriting it.
        if episode < self.config["warmup_episodes"]:
            lr_mult = (episode + 1) / self.config["warmup_episodes"]
            for pg in self.optimizer.param_groups:
                pg['lr'] = self.config["learning_rate"] * lr_mult

        obs = torch.FloatTensor(rollout["obs"]).to(self.device)
        actions = torch.FloatTensor(rollout["actions"]).to(self.device)
        old_log_probs = torch.FloatTensor(rollout["log_probs"]).to(self.device)
        advantages = torch.FloatTensor(rollout["advantages"]).to(self.device)
        returns = torch.FloatTensor(rollout["returns"]).to(self.device)

        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        batch_size = self.config["batch_size"]
        indices = np.random.permutation(len(obs))

        total_loss = {"policy": 0, "value": 0, "entropy": 0}
        num_batches = 0

        self.optimizer.zero_grad()

        for batch_idx, start in enumerate(range(0, len(obs), batch_size)):
            batch_indices = indices[start:start + batch_size]
            b_obs = obs[batch_indices]
            b_actions = actions[batch_indices]
            b_old_lp = old_log_probs[batch_indices]
            b_adv = advantages[batch_indices]
            b_ret = returns[batch_indices]

            if USE_AMP:
                with autocast():
                    action_mean, action_log_std, values = self.model(b_obs)
                    action_std = torch.exp(action_log_std)
                    dist = torch.distributions.Normal(action_mean, action_std)
                    new_lp = dist.log_prob(b_actions.unsqueeze(-1)).squeeze(-1)
                    entropy = dist.entropy().mean()
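                    # PPO clipped surrogate (computed below):
                    #   ratio = exp(log pi_new(a) - log pi_old(a))
                    #   L_policy = -E[min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)]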

                    ratio = torch.exp(new_lp - b_old_lp)
                    surr1 = ratio * b_adv
                    surr2 = torch.clamp(ratio, 1 - self.config["clip_epsilon"],
                                       1 + self.config["clip_epsilon"]) * b_adv
                    policy_loss = -torch.min(surr1, surr2).mean()
                    value_loss = F.mse_loss(values.squeeze(), b_ret)

                    loss = (policy_loss + self.config["value_coef"] * value_loss
                           - self.entropy_coef * entropy) / self.grad_accum

                self.scaler.scale(loss).backward()
                if (batch_idx + 1) % self.grad_accum == 0:
                    self.scaler.unscale_(self.optimizer)
                    nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()
            else:
                action_mean, action_log_std, values = self.model(b_obs)
                action_std = torch.exp(action_log_std)
                dist = torch.distributions.Normal(action_mean, action_std)
                new_lp = dist.log_prob(b_actions.unsqueeze(-1)).squeeze(-1)
                entropy = dist.entropy().mean()

                ratio = torch.exp(new_lp - b_old_lp)
                surr1 = ratio * b_adv
                surr2 = torch.clamp(ratio, 1 - self.config["clip_epsilon"],
                                   1 + self.config["clip_epsilon"]) * b_adv
                policy_loss = -torch.min(surr1, surr2).mean()
                value_loss = F.mse_loss(values.squeeze(), b_ret)

                loss = (policy_loss + self.config["value_coef"] * value_loss
                       - self.entropy_coef * entropy) / self.grad_accum
                loss.backward()

                if (batch_idx + 1) % self.grad_accum == 0:
                    nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
                    self.optimizer.step()
                    self.optimizer.zero_grad()

            total_loss["policy"] += policy_loss.item()
            total_loss["value"] += value_loss.item()
            total_loss["entropy"] += entropy.item()
            num_batches += 1

        # Flush any leftover gradients from a final partial accumulation window
        # (they would otherwise be discarded by zero_grad() at the next update).
        if num_batches % self.grad_accum != 0:
            if USE_AMP:
                self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
            if USE_AMP:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()

        self.scheduler.step()
        self.entropy_coef = max(self.config["min_entropy_coef"],
                                self.entropy_coef * self.config["entropy_decay"])

        return {k: v / max(num_batches, 1) for k, v in total_loss.items()} | {
            "entropy_coef": self.entropy_coef,
            "lr": self.optimizer.param_groups[0]["lr"]
        }


# ============================================================
# MAIN
# ============================================================

def main():
    print("\n" + "=" * 70)
    print("LOADING DATA")
    print("=" * 70)

    data_path = "/workspace/prices_btc_2025.csv"
    if not os.path.exists(data_path):
        print(f"ERROR: {data_path} not found")
        sys.exit(1)

    df = pd.read_csv(data_path)
    prices = df['price'].values if 'price' in df.columns else df.iloc[:, 1].values
    prices = prices.astype(np.float64)
    print(f"Loaded {len(prices)} prices ({prices.min():.2f} - {prices.max():.2f})")

    # Augment
    print("\n" + "=" * 70)
    print("DATA AUGMENTATION")
    print("=" * 70)
    prices_list = create_augmented_datasets(prices)
    print(f"Created {len(prices_list)} datasets (including time-reversed)")

    # PRE-COMPUTE FEATURES (the key optimization!)
    print("\n" + "=" * 70)
    print("PRE-COMPUTING FEATURES (this takes ~10-15 min)")
    print("=" * 70)

    features_list = []
    for i, p in enumerate(prices_list):
        print(f"\n  Dataset {i+1}/{len(prices_list)} ({len(p)} points)...")
        f = compute_all_features(p)
        features_list.append(f)
        print(f"  Dataset {i+1} features: {f.shape}")

    print("\nAll features pre-computed!")

    # Create output dir
    os.makedirs(CONFIG["output_dir"], exist_ok=True)

    # Model
    print("\n" + "=" * 70)
    print("INITIALIZING MODEL")
    print("=" * 70)
    model = TradingTransformerV8(CONFIG).to(device)
    trainer = PPOTrainer(model, CONFIG)
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Environment
    print("\n" + "=" * 70)
    print("CREATING ENVIRONMENTS")
    print("=" * 70)
    vec_env = VectorizedEnvFast(features_list, prices_list, CONFIG, CONFIG["num_envs"])

    # Training
    best_reward = float('-inf')

    print("\n" + "=" * 70)
    print("TRAINING STARTED")
    print("=" * 70)

    for episode in range(CONFIG["episodes"]):
        obs = vec_env.reset()
        rollout = {"obs": [], "actions": [], "rewards": [], "dones": [], "values": [], "log_probs": []}

        episode_data = {"rewards": [], "trades": [], "wins": []}

        for _ in range(CONFIG["steps_per_episode"]):
            obs_t = torch.FloatTensor(obs).to(device)

            with torch.no_grad():
                if USE_AMP:
                    with autocast():
                        action_mean, action_log_std, values = model(obs_t)
                else:
                    action_mean, action_log_std, values = model(obs_t)

                action_std = torch.exp(action_log_std)
                dist = torch.distributions.Normal(action_mean, action_std)
                actions = torch.clamp(dist.sample(), -1, 1)
                log_probs = dist.log_prob(actions)

            actions_np = actions.squeeze(-1).cpu().numpy()
            next_obs, rewards, dones, infos = vec_env.step(actions_np)

            rollout["obs"].append(obs)
            rollout["actions"].append(actions_np)
            rollout["rewards"].append(rewards)
            rollout["dones"].append(dones)
            rollout["values"].append(values.squeeze(-1).cpu().numpy())
            rollout["log_probs"].append(log_probs.squeeze(-1).cpu().numpy())

            obs = next_obs
            for info in infos:
                episode_data["rewards"].append(info["equity"] / 100000 - 1)
                episode_data["trades"].append(info["trades"])
                episode_data["wins"].append(info["wins"])

        # Bootstrap the value of the final observation, then compute GAE per environment
        with torch.no_grad():
            if USE_AMP:
                with autocast():
                    _, _, next_values = model(torch.FloatTensor(obs).to(device))
            else:
                _, _, next_values = model(torch.FloatTensor(obs).to(device))
            next_values = next_values.squeeze(-1).cpu().numpy()

        for k in rollout:
            rollout[k] = np.stack(rollout[k])

        T, N = rollout["rewards"].shape
        adv_all, ret_all = np.zeros((T, N)), np.zeros((T, N))
        for env_idx in range(N):
            adv_all[:, env_idx], ret_all[:, env_idx] = trainer.compute_gae(
                rollout["rewards"][:, env_idx],
                rollout["values"][:, env_idx],
                rollout["dones"][:, env_idx],
                next_values[env_idx]
            )

        flat_rollout = {
            "obs": rollout["obs"].reshape(-1, *rollout["obs"].shape[2:]),
            "actions": rollout["actions"].flatten(),
            "log_probs": rollout["log_probs"].flatten(),
            "advantages": adv_all.flatten(),
            "returns": ret_all.flatten(),
        }

        stats = trainer.update(flat_rollout, episode)

        # Metrics
        mean_reward = np.mean(episode_data["rewards"])
        mean_trades = np.mean(episode_data["trades"])
        win_rate = np.mean(episode_data["wins"]) / max(mean_trades, 1)

        if episode % CONFIG["log_every"] == 0:
            pnl = mean_reward * 100000
            print(f"Ep {episode:4d} | R: {mean_reward:7.4f} | PnL: {pnl:8.0f} | "
                  f"Tr: {mean_trades:4.1f} | WR: {win_rate:.2f} | "
                  f"Ent: {stats['entropy']:.3f} | LR: {stats['lr']:.2e}")

        # Save best
        if mean_reward > best_reward:
            best_reward = mean_reward
            torch.save({
                "model_state_dict": model.state_dict(),
                "config": CONFIG,
                "episode": episode,
                "best_reward": float(best_reward),
            }, os.path.join(CONFIG["output_dir"], "model_best.pt"))
            print(f"  -> Best model: {best_reward:.4f}")

        # Checkpoint
        if episode > 0 and episode % CONFIG["save_every"] == 0:
            torch.save({
                "model_state_dict": model.state_dict(),
                "config": CONFIG, "episode": episode,
            }, os.path.join(CONFIG["output_dir"], f"model_ep{episode}.pt"))

    # Final
    torch.save({
        "model_state_dict": model.state_dict(),
        "config": CONFIG,
        "episode": CONFIG["episodes"],
        "best_reward": float(best_reward),
    }, os.path.join(CONFIG["output_dir"], "model_final.pt"))

    print("\n" + "=" * 70)
    print(f"TRAINING COMPLETED - Best: {best_reward:.4f}")
    print("=" * 70)


if __name__ == "__main__":
    main()
