#!/usr/bin/env python3
"""
SOL ULTRA TRADER V3 - GPU MAXIMIZED
===================================
- 32 parallel environments (vectorized)
- Large model (500k+ params)
- Batch size 2048 to saturate the GPU
- Mixed precision (FP16) for ~2x speed
- Gradient accumulation

Target: 90%+ utilization on an RTX 4090
"""

import os
import json
import numpy as np
import pandas as pd
from datetime import datetime
from typing import List, Tuple
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

# Select the compute device once at import time; the classes and functions
# below read this module-level name when they create tensors.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"🚀 Device: {device}")
    # Report which GPU was found and how much memory it exposes.
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print(f"🚀 Device: {device}")

# ============================================================================
# CONFIG - OPTIMIZED FOR GPU
# ============================================================================

# Central hyper-parameter table. It is passed to the model, the environment
# and the agent, and is written verbatim into every saved model file.
CONFIG = {
    # Trading
    'FEE_RATE': 0.004,           # Subtracted once from the PnL of every closed trade
    'INITIAL_CAPITAL': 10000,    # NOTE(review): not read by any code in this file
    'MIN_HOLD_BARS': 3,          # Bars a position must be held before it may change

    # Model - large, to keep the GPU busy
    'D_MODEL': 256,
    'N_HEADS': 8,
    'N_LAYERS': 6,
    'DROPOUT': 0.1,
    'SEQ_LEN': 60,               # Feature rows per observation window
    'N_FEATURES': 24,            # Columns produced by compute_features_vectorized

    # Parallelization
    'N_ENVS': 32,                # 32 parallel environments
    'BATCH_SIZE': 2048,          # Large minibatch for the PPO update
    'ACCUMULATION_STEPS': 4,     # Gradient accumulation; NOTE(review): not read by any code in this file

    # PPO
    'GAMMA': 0.99,
    'GAE_LAMBDA': 0.95,
    'PPO_CLIP': 0.2,
    'PPO_EPOCHS': 4,
    'VALUE_COEF': 0.5,
    'ENTROPY_COEF': 0.01,

    # Training
    'EPISODES': 200,
    'LR': 3e-4,
    'WEIGHT_DECAY': 0.01,
    'MAX_GRAD_NORM': 0.5,

    # Mixed precision
    'USE_AMP': True,
}

# ============================================================================
# FAST VECTORIZED FEATURES
# ============================================================================

def compute_features_vectorized(prices: np.ndarray, n_features=None) -> np.ndarray:
    """Build the per-bar technical-indicator matrix from a price series.

    Columns, in order: 6 percentage returns, 3 rolling volatilities, 4
    moving-average distances, 2 RSI values, 2 Bollinger statistics, 2 MACD
    values, 2 range-position values, 2 momentum values and 1 EMA trend value
    (24 in total). Every indicator only looks backwards in time. Warm-up NaNs
    are replaced by a neutral value and the matrix is clipped to [-10, 10].

    RSI is stored as ``(RSI - 50) / 10`` (range [-5, 5]); storing the raw
    0..100 value would be flattened to a constant 10 by the final clip.

    Args:
        prices: 1-D array of prices, one entry per bar.
        n_features: Width of the returned matrix. Defaults to
            ``CONFIG['N_FEATURES']``. Columns beyond the 24 computed here are
            left at zero.

    Returns:
        float32 array of shape ``(len(prices), n_features)``.

    Raises:
        ValueError: if ``n_features`` is smaller than the 24 computed columns.
    """
    n_builtin = 24  # number of columns this function fills
    if n_features is None:
        n_features = CONFIG['N_FEATURES']
    if n_features < n_builtin:
        raise ValueError(
            f"n_features must be at least {n_builtin}, got {n_features}"
        )

    n = len(prices)
    features = np.zeros((n, n_features), dtype=np.float32)
    price = pd.DataFrame({'p': prices})['p']

    # Returns (6) - multiple timeframes
    for col, period in enumerate([1, 3, 5, 10, 30, 60]):
        features[:, col] = price.pct_change(period).fillna(0).values * 100

    # Volatility (3) - the 1-bar return series is shared by all windows
    bar_returns = price.pct_change()
    for col, period in enumerate([5, 15, 30], start=6):
        features[:, col] = bar_returns.rolling(period).std().fillna(0).values * 100

    # MA distances (4)
    for col, period in enumerate([5, 10, 20, 50], start=9):
        ma = price.rolling(period).mean()
        features[:, col] = ((price - ma) / (ma + 1e-8) * 100).fillna(0).values

    # RSI (2) - centered and scaled so the final clip does not saturate it
    delta = price.diff()
    gain = delta.where(delta > 0, 0)
    loss = -delta.where(delta < 0, 0)
    for col, period in enumerate([7, 14], start=13):
        rs = gain.rolling(period).mean() / (loss.rolling(period).mean() + 1e-8)
        rsi = (100 - 100 / (1 + rs)).fillna(50)
        features[:, col] = ((rsi - 50) / 10).values

    # Bollinger (2): z-score inside the band and relative band width
    ma20 = price.rolling(20).mean()
    std20 = price.rolling(20).std()
    features[:, 15] = ((price - ma20) / (std20 + 1e-8)).fillna(0).values
    features[:, 16] = (std20 / (ma20 + 1e-8) * 100).fillna(0).values

    # MACD (2)
    ema12 = price.ewm(span=12).mean()
    ema26 = price.ewm(span=26).mean()
    macd = ema12 - ema26
    signal = macd.ewm(span=9).mean()
    features[:, 17] = (macd / (price + 1e-8) * 100).fillna(0).values
    features[:, 18] = ((macd - signal) / (price + 1e-8) * 100).fillna(0).values

    # Price position (2) inside the 60-bar range, and the range width
    roll_max = price.rolling(60).max()
    roll_min = price.rolling(60).min()
    features[:, 19] = ((price - roll_min) / (roll_max - roll_min + 1e-8)).fillna(0.5).values
    features[:, 20] = ((roll_max - roll_min) / (price + 1e-8) * 100).fillna(0).values

    # Momentum (2)
    mom5 = price.diff(5)
    mom10 = price.diff(10)
    features[:, 21] = mom5.fillna(0).values / (price.values + 1e-8) * 100
    features[:, 22] = ((mom5 - mom10.shift(5)) / (price + 1e-8) * 100).fillna(0).values

    # Trend (1)
    features[:, 23] = ((ema12 - ema26) / (ema26 + 1e-8) * 100).fillna(0).values

    # Last line of defence: a NaN price can still leak through column 21,
    # and one NaN in the input would poison the whole network.
    np.nan_to_num(features, copy=False, nan=0.0, posinf=10.0, neginf=-10.0)
    return np.clip(features, -10, 10)


# ============================================================================
# TRANSFORMER MODEL - LARGE, SIZED FOR THE GPU
# ============================================================================

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a (batch, seq, d_model) input.

    Even feature columns hold sines and odd columns hold cosines of the
    position scaled by geometrically spaced frequencies. The table is stored
    in the ``pe`` buffer with shape (1, max_len, d_model).

    Args:
        d_model: Feature width of the input; odd widths are supported.
        max_len: Longest sequence length that can be encoded.
    """

    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine half uses only the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]


class TradingTransformer(nn.Module):
    """Actor-critic Transformer encoder over a window of feature rows.

    The input window is embedded per timestep, given positional encodings and
    passed through a stack of encoder layers. The representation of the last
    timestep feeds two MLP heads: the actor (3 action logits) and the critic
    (1 state value).

    Args:
        config: Mapping providing 'D_MODEL', 'N_FEATURES', 'DROPOUT',
            'SEQ_LEN', 'N_HEADS' and 'N_LAYERS'.
    """

    def __init__(self, config):
        super().__init__()
        width = config['D_MODEL']
        p_drop = config['DROPOUT']

        # Per-timestep projection of the raw features to the model width.
        self.embed = nn.Sequential(
            nn.Linear(config['N_FEATURES'], width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Dropout(p_drop),
        )

        self.pos_enc = PositionalEncoding(width, config['SEQ_LEN'])

        # Encoder stack; batch_first so inputs are (batch, seq, width).
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=width,
                nhead=config['N_HEADS'],
                dim_feedforward=width * 4,
                dropout=p_drop,
                activation='gelu',
                batch_first=True,
            ),
            num_layers=config['N_LAYERS'],
        )

        # Policy and value heads share one layout and differ only in output size.
        self.actor = self._make_head(width, p_drop, 3)
        self.critic = self._make_head(width, p_drop, 1)

    @staticmethod
    def _make_head(width, p_drop, n_out):
        """Build a three-layer MLP head that narrows width -> width // 2 -> n_out."""
        return nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(width, width // 2),
            nn.GELU(),
            nn.Linear(width // 2, n_out),
        )

    def forward(self, x):
        """Map a (batch, seq, features) window to (action logits, state value)."""
        hidden = self.transformer(self.pos_enc(self.embed(x)))
        last_step = hidden[:, -1]  # only the most recent timestep feeds the heads
        return self.actor(last_step), self.critic(last_step)


# ============================================================================
# VECTORIZED ENVIRONMENT - 32 PARALLEL ENVIRONMENTS
# ============================================================================

class VectorizedTradingEnv:
    """Batch of ``N_ENVS`` independent trading simulations over one price series.

    All per-environment state lives in 1-D tensors of length ``n_envs`` so one
    ``step`` call advances every environment with tensor operations only.
    Positions are encoded as -1 (short), 0 (flat) and +1 (long). The fee is
    subtracted once from the PnL of each closed trade.

    Args:
        prices: 1-D array of prices, one per bar.
        features: 2-D array of feature rows, aligned with ``prices``.
        config: Mapping providing 'N_ENVS', 'SEQ_LEN', 'FEE_RATE' and
            'MIN_HOLD_BARS'.
    """

    def __init__(self, prices: np.ndarray, features: np.ndarray, config: dict):
        self.n_envs = config['N_ENVS']
        self.seq_len = config['SEQ_LEN']
        self.fee = config['FEE_RATE']
        self.min_hold = config['MIN_HOLD_BARS']

        # Module-level device; every other method uses this attribute.
        self.device = device

        # Keep the whole series on the device so stepping never touches the host.
        self.prices = torch.tensor(prices, dtype=torch.float32, device=self.device)
        self.features = torch.tensor(features, dtype=torch.float32, device=self.device)

        self.max_steps = len(prices) - self.seq_len - 1

        # Per-environment state
        self.step_idx = torch.zeros(self.n_envs, dtype=torch.long, device=self.device)
        self.positions = torch.zeros(self.n_envs, dtype=torch.long, device=self.device)
        self.entry_prices = torch.zeros(self.n_envs, dtype=torch.float32, device=self.device)
        self.entry_steps = torch.zeros(self.n_envs, dtype=torch.long, device=self.device)
        self.total_pnl = torch.zeros(self.n_envs, dtype=torch.float32, device=self.device)
        self.n_trades = torch.zeros(self.n_envs, dtype=torch.long, device=self.device)
        self.n_wins = torch.zeros(self.n_envs, dtype=torch.long, device=self.device)

        self.reset()

    def reset(self) -> torch.Tensor:
        """Reset every environment and return the initial observation windows.

        Each environment starts at a random bar when the series is long
        enough to leave 1000 bars of headroom; otherwise all start at the
        first bar that has a full window behind it.
        """
        max_start = self.max_steps - 1000  # leave room for an episode
        if max_start > self.seq_len:
            self.step_idx = torch.randint(
                self.seq_len, max_start, (self.n_envs,), device=self.device
            )
        else:
            self.step_idx = torch.full(
                (self.n_envs,), self.seq_len, dtype=torch.long, device=self.device
            )

        self.positions.zero_()
        self.entry_prices.zero_()
        self.entry_steps.zero_()
        self.total_pnl.zero_()
        self.n_trades.zero_()
        self.n_wins.zero_()

        return self._get_states()

    def _get_states(self) -> torch.Tensor:
        """Return the (n_envs, seq_len, n_features) windows ending just before ``step_idx``.

        The windows are gathered with one indexing operation, so no
        per-environment host synchronisation is needed.
        """
        offsets = torch.arange(-self.seq_len, 0, device=self.device)
        window_idx = self.step_idx.unsqueeze(1) + offsets  # (n_envs, seq_len)
        return self.features[window_idx]

    def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply one action per environment and advance every environment one bar.

        Args:
            actions: Long tensor of shape (n_envs,) with values 0 (short),
                1 (flat) or 2 (long).

        Returns:
            ``(next_states, rewards, dones)``. ``dones`` marks environments
            that reached the end of the series; the caller must reposition
            those before calling ``step`` again.
        """
        rewards = torch.zeros(self.n_envs, device=self.device)

        # Target positions: 0->SHORT(-1), 1->FLAT(0), 2->LONG(1)
        target_pos = actions - 1

        current_prices = self.prices[self.step_idx]

        # A position may only change after the minimum holding period.
        hold_times = self.step_idx - self.entry_steps
        can_trade = (hold_times >= self.min_hold) | (self.positions == 0)
        will_trade = can_trade & (target_pos != self.positions)

        # Realise PnL for positions that are being closed or flipped.
        closing = will_trade & (self.positions != 0)
        long_pnl = (current_prices / (self.entry_prices + 1e-8) - 1) - self.fee
        short_pnl = (self.entry_prices / (current_prices + 1e-8) - 1) - self.fee
        pnl = torch.where(self.positions == 1, long_pnl, short_pnl)
        pnl = torch.where(closing, pnl, torch.zeros_like(pnl))

        rewards += pnl * 100
        self.total_pnl += pnl
        self.n_trades += closing.long()
        self.n_wins += (closing & (pnl > 0)).long()

        # Open new positions at the current price.
        opening = will_trade & (target_pos != 0)
        self.entry_prices = torch.where(opening, current_prices, self.entry_prices)
        self.entry_steps = torch.where(opening, self.step_idx, self.entry_steps)
        self.positions = torch.where(will_trade, target_pos, self.positions)

        # Small penalty for staying flat
        rewards -= (self.positions == 0).float() * 0.001

        self.step_idx += 1
        last_bar = len(self.prices) - 1
        dones = self.step_idx >= last_bar

        # Positions still open at the end of the data are closed at the last price.
        force_close = dones & (self.positions != 0)
        end_prices = self.prices[torch.clamp(self.step_idx, max=last_bar)]
        end_long_pnl = (end_prices / (self.entry_prices + 1e-8) - 1) - self.fee
        end_short_pnl = (self.entry_prices / (end_prices + 1e-8) - 1) - self.fee
        end_pnl = torch.where(self.positions == 1, end_long_pnl, end_short_pnl)
        end_pnl = torch.where(force_close, end_pnl, torch.zeros_like(end_pnl))

        rewards += end_pnl * 100
        self.total_pnl += end_pnl
        # A forced close is a real trade: count it and flatten the position so
        # it cannot be realised a second time.
        self.n_trades += force_close.long()
        self.n_wins += (force_close & (end_pnl > 0)).long()
        self.positions = torch.where(
            force_close, torch.zeros_like(self.positions), self.positions
        )

        return self._get_states(), rewards, dones

    def get_metrics(self) -> dict:
        """Return mean PnL (percent), mean trade count and mean win rate (percent) across environments."""
        trade_counts = self.n_trades.float()
        return {
            'total_pnl': self.total_pnl.mean().item() * 100,
            'n_trades': trade_counts.mean().item(),
            'win_rate': (self.n_wins.float() / (trade_counts + 1e-8)).mean().item() * 100
        }


# ============================================================================
# PPO AGENT WITH MIXED PRECISION
# ============================================================================

class PPOAgentGPU:
    """PPO learner around an actor-critic model, with optional mixed precision.

    Rollout data is appended step by step with ``store`` and consumed by
    ``update``, which also clears the buffers. Each buffer is a list with one
    tensor per time step, each covering all environments.

    Args:
        model: Module returning ``(logits, values)`` for a batch of states.
        config: Mapping providing 'LR', 'WEIGHT_DECAY', 'USE_AMP', 'GAMMA',
            'GAE_LAMBDA', 'PPO_CLIP', 'PPO_EPOCHS', 'BATCH_SIZE',
            'VALUE_COEF', 'ENTROPY_COEF' and 'MAX_GRAD_NORM'.
    """

    def __init__(self, model: nn.Module, config: dict):
        self.model = model.to(device)
        self.config = config
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=config['LR'],
            weight_decay=config['WEIGHT_DECAY']
        )
        self.scaler = GradScaler() if config['USE_AMP'] else None
        self._clear_buffers()

    def _clear_buffers(self) -> None:
        """Drop all stored rollout data."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []

    def _forward(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the model and return float32 logits (B, 3) and values (B,).

        The network may run under autocast, but its outputs are cast to
        float32 so log-probabilities, ratios and losses never suffer
        half-precision underflow.
        """
        if self.config['USE_AMP']:
            with autocast():
                logits, values = self.model(states)
        else:
            logits, values = self.model(states)
        return logits.float(), values.float().squeeze(-1)

    def select_actions(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample one action per state.

        Returns:
            ``(actions, log_probs, values)``, each of shape (batch,).
        """
        self.model.eval()
        with torch.no_grad():
            logits, values = self._forward(states)
            # Building the distribution from logits avoids the explicit
            # softmax, whose small probabilities can round to zero.
            dist = torch.distributions.Categorical(logits=logits)
            actions = dist.sample()
            log_probs = dist.log_prob(actions)

        return actions, log_probs, values

    def store(self, states, actions, rewards, values, log_probs, dones):
        """Append one time step of rollout data for all environments."""
        self.states.append(states)
        self.actions.append(actions)
        self.rewards.append(rewards)
        self.values.append(values)
        self.log_probs.append(log_probs)
        self.dones.append(dones)

    def compute_gae(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute GAE advantages and value targets from the stored rollout.

        The value after the last stored step is taken as zero, so the rollout
        is treated as ending there. A set ``dones[t]`` stops both the
        bootstrap and the advantage accumulation across step ``t``.

        Returns:
            ``(advantages, returns)``, both flattened from (T, n_envs) in
            time-major order, matching ``torch.cat`` over the step buffers.
        """
        rewards = torch.stack(self.rewards).float()  # (T, n_envs)
        values = torch.stack(self.values).float()    # (T, n_envs)
        dones = torch.stack(self.dones).float()

        gamma = self.config['GAMMA']
        lam = self.config['GAE_LAMBDA']
        n_steps = rewards.shape[0]

        advantages = torch.zeros_like(rewards)
        gae = torch.zeros_like(rewards[0])
        next_val = torch.zeros_like(rewards[0])

        for t in reversed(range(n_steps)):
            not_done = 1 - dones[t]
            delta = rewards[t] + gamma * next_val * not_done - values[t]
            gae = delta + gamma * lam * not_done * gae
            advantages[t] = gae
            next_val = values[t]

        returns = advantages + values
        return advantages.flatten(), returns.flatten()

    def _ppo_loss(self, states, actions, old_log_probs, advantages, returns) -> torch.Tensor:
        """Clipped-surrogate PPO loss plus value loss minus entropy bonus for one minibatch."""
        logits, values = self._forward(states)
        dist = torch.distributions.Categorical(logits=logits)

        new_log_probs = dist.log_prob(actions)
        entropy = dist.entropy().mean()

        clip = self.config['PPO_CLIP']
        ratio = torch.exp(new_log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - clip, 1 + clip) * advantages
        actor_loss = -torch.min(surr1, surr2).mean()

        critic_loss = F.mse_loss(values, returns)

        return (actor_loss
                + self.config['VALUE_COEF'] * critic_loss
                - self.config['ENTROPY_COEF'] * entropy)

    def update(self) -> float:
        """Run the PPO epochs over the stored rollout, then clear the buffers.

        Returns:
            Sum of the minibatch losses over all epochs; 0.0 if nothing was stored.
        """
        if not self.states:
            return 0.0

        self.model.train()

        # Stack all data in time-major order
        states = torch.cat(self.states, dim=0)       # (T*n_envs, seq, feat)
        actions = torch.cat(self.actions, dim=0)     # (T*n_envs,)
        old_log_probs = torch.cat(self.log_probs, dim=0).float()

        advantages, returns = self.compute_gae()
        # The standard deviation of a single sample is NaN; skip normalisation then.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        n_samples = len(states)
        batch_size = self.config['BATCH_SIZE']
        params = list(self.model.parameters())
        # Accumulate on the device; one host sync at the end instead of one per minibatch.
        total_loss = torch.zeros((), device=states.device)

        for _ in range(self.config['PPO_EPOCHS']):
            indices = torch.randperm(n_samples, device=states.device)

            for start in range(0, n_samples, batch_size):
                idx = indices[start:start + batch_size]
                loss = self._ppo_loss(
                    states[idx], actions[idx], old_log_probs[idx],
                    advantages[idx], returns[idx]
                )

                self.optimizer.zero_grad()
                if self.scaler is not None:
                    self.scaler.scale(loss).backward()
                    # Unscale first so the clip threshold applies to true gradients.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(params, self.config['MAX_GRAD_NORM'])
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(params, self.config['MAX_GRAD_NORM'])
                    self.optimizer.step()

                total_loss += loss.detach()

        self._clear_buffers()

        return total_loss.item()


# ============================================================================
# TRAINING
# ============================================================================

def train():
    """Train the PPO agent on the price file, checkpoint the best models, then validate.

    Prices are read from the first existing path among
    ``/workspace/prices.csv`` and ``./prices.csv`` (column ``price``). The
    first 80% of the series is used for training and the rest for validation.

    Raises:
        FileNotFoundError: if none of the candidate price files exists.
    """
    print("\n" + "="*60)
    print("🔥 SOL ULTRA TRADER V3 - GPU MAXIMIZED")
    print("="*60)

    # Load data: the first existing candidate wins.
    paths = ['/workspace/prices.csv', './prices.csv']
    prices = None
    for p in paths:
        if os.path.exists(p):
            df = pd.read_csv(p)
            prices = df['price'].values.astype(np.float32)
            print(f"📂 Loaded {len(prices)} prices")
            break
    if prices is None:
        raise FileNotFoundError(f"No price file found; looked in: {', '.join(paths)}")

    # Features
    print("⚡ Computing features...")
    features = compute_features_vectorized(prices)
    print(f"   Shape: {features.shape}")

    # Chronological split
    train_size = int(len(prices) * 0.8)
    train_prices = prices[:train_size]
    train_features = features[:train_size]
    val_prices = prices[train_size:]
    val_features = features[train_size:]

    print(f"📊 Train: {len(train_prices)} | Val: {len(val_prices)}")

    # Model
    model = TradingTransformer(CONFIG)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"🧠 Model: {n_params:,} parameters")
    print(f"   Layers: {CONFIG['N_LAYERS']} | Heads: {CONFIG['N_HEADS']} | d_model: {CONFIG['D_MODEL']}")

    agent = PPOAgentGPU(model, CONFIG)

    # Vectorized environment
    env = VectorizedTradingEnv(train_prices, train_features, CONFIG)
    print(f"🎮 Environments: {CONFIG['N_ENVS']} parallel")
    print(f"   Batch size: {CONFIG['BATCH_SIZE']}")
    print(f"   Mixed precision: {CONFIG['USE_AMP']}")

    # Training
    print(f"\n🎯 Training {CONFIG['EPISODES']} episodes...")
    print("-"*60)

    best_return = -999
    best_sharpe = -999
    episode_length = 500  # Steps per episode
    # Defined up front so the final save also works when EPISODES is 0.
    ret = trades = win_rate = sharpe = 0.0

    for ep in range(CONFIG['EPISODES']):
        states = env.reset()
        ep_rewards = []

        for _ in range(episode_length):
            actions, log_probs, values = agent.select_actions(states)
            next_states, rewards, dones = env.step(actions)

            agent.store(states, actions, rewards, values, log_probs, dones)
            ep_rewards.append(rewards.mean().item())

            states = next_states

            # Restart environments that ran off the end of the data.
            if dones.any():
                done_idx = dones.nonzero().squeeze(-1)
                max_start = env.max_steps - 1000
                if max_start > env.seq_len:
                    env.step_idx[done_idx] = torch.randint(
                        env.seq_len, max_start, (done_idx.numel(),), device=device
                    )
                else:
                    # Series too short for random starts: go back to the first
                    # valid bar, as env.reset() does, rather than step past the end.
                    env.step_idx[done_idx] = env.seq_len
                env.positions[done_idx] = 0
                env.entry_prices[done_idx] = 0
                env.entry_steps[done_idx] = 0
                # next_states was built before the restart; rebuild it so the
                # next stored observation matches the new position.
                states = env._get_states()

        # Update
        loss = agent.update()

        # Metrics
        metrics = env.get_metrics()
        ret = metrics['total_pnl']
        trades = metrics['n_trades']
        win_rate = metrics['win_rate']

        # Approximate Sharpe from per-step mean rewards
        sharpe = np.mean(ep_rewards) / (np.std(ep_rewards) + 1e-8) * np.sqrt(252)

        if ep % 5 == 0:
            gpu_mem = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
            print(f"Ep {ep:3d} | Ret: {ret:+7.2f}% | Trades: {trades:5.0f} | "
                  f"Win: {win_rate:5.1f}% | Sharpe: {sharpe:+5.2f} | GPU: {gpu_mem:.1f}GB")

        # Save best
        if ret > best_return:
            best_return = ret
            save_model(model, 'best_return_v3', {'return': ret, 'trades': trades, 'win_rate': win_rate, 'sharpe': sharpe})

        if sharpe > best_sharpe and trades >= 20:
            best_sharpe = sharpe
            save_model(model, 'best_sharpe_v3', {'return': ret, 'trades': trades, 'win_rate': win_rate, 'sharpe': sharpe})

    # Final
    save_model(model, 'final_v3', {'return': ret, 'trades': trades, 'win_rate': win_rate, 'sharpe': sharpe})

    # Validation
    print("\n" + "="*60)
    print("📈 VALIDATION")
    print("="*60)

    validate(model, val_prices, val_features, CONFIG)

    print("\n✅ TRAINING COMPLETE!")
    print(f"   Best Return: {best_return:+.2f}%")
    print(f"   Best Sharpe: {best_sharpe:+.2f}")


def validate(model, prices, features, config):
    """Greedy single-path backtest of ``model`` on a held-out series.

    At each bar from ``SEQ_LEN`` to the second-to-last one, the model sees
    the previous ``SEQ_LEN`` feature rows and the arg-max action is traded at
    that bar's price, subject to the minimum holding period. A position still
    open at the end is closed at the last price, so the reported return
    contains no unrealised PnL. Results are printed and also returned.

    Args:
        model: Module returning ``(logits, values)`` for a batch of windows.
        prices: 1-D array of prices.
        features: 2-D array of feature rows aligned with ``prices``.
        config: Mapping providing 'SEQ_LEN', 'USE_AMP', 'MIN_HOLD_BARS',
            'FEE_RATE' and optionally 'BATCH_SIZE' (inference chunk size).

    Returns:
        dict with 'return', 'buy_hold', 'win_rate' and 'alpha' in percent,
        plus 'n_trades'.

    Raises:
        ValueError: if the series has no bar beyond the first window.
    """
    model.eval()

    seq_len = config['SEQ_LEN']
    n_bars = len(prices)
    if n_bars <= seq_len:
        raise ValueError(
            f"need more than SEQ_LEN={seq_len} bars to validate, got {n_bars}"
        )

    # Run on whatever device the model lives on.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')
    batch_size = config.get('BATCH_SIZE', 2048)

    # Observations do not depend on the position held, so every decision can
    # be computed up front in large batches instead of one forward pass per bar.
    n_decisions = n_bars - 1 - seq_len
    actions = []
    with torch.no_grad():
        if n_decisions > 0:
            features_t = torch.as_tensor(features, dtype=torch.float32, device=run_device)
            # windows[k] holds feature rows k .. k + seq_len - 1, i.e. the
            # observation for the decision at bar k + seq_len.
            windows = features_t.unfold(0, seq_len, 1).permute(0, 2, 1)[:n_decisions]
            for start in range(0, n_decisions, batch_size):
                chunk = windows[start:start + batch_size].contiguous()
                if config['USE_AMP']:
                    with autocast():
                        logits, _ = model(chunk)
                else:
                    logits, _ = model(chunk)
                actions.extend(logits.argmax(dim=-1).tolist())

    fee = config['FEE_RATE']
    position = 0
    entry_price = 0
    entry_step = 0
    total_pnl = 0
    trades = []

    for offset, action in enumerate(actions):
        i = seq_len + offset
        target_pos = action - 1  # 0->SHORT(-1), 1->FLAT(0), 2->LONG(1)
        price = prices[i]
        hold_time = i - entry_step

        if (hold_time >= config['MIN_HOLD_BARS'] or position == 0) and target_pos != position:
            # Close
            if position != 0:
                if position == 1:
                    pnl = (price / entry_price - 1) - fee
                else:
                    pnl = (entry_price / price - 1) - fee
                total_pnl += pnl
                trades.append(pnl)

            # Open
            if target_pos != 0:
                entry_price = price
                entry_step = i

            position = target_pos

    # Close whatever is still open at the final price.
    if position != 0:
        final_price = prices[-1]
        if position == 1:
            pnl = (final_price / entry_price - 1) - fee
        else:
            pnl = (entry_price / final_price - 1) - fee
        total_pnl += pnl
        trades.append(pnl)

    val_ret = total_pnl * 100
    n_trades = len(trades)
    win_rate = sum(1 for t in trades if t > 0) / max(n_trades, 1) * 100
    buy_hold = (prices[-1] / prices[seq_len] - 1) * 100

    print(f"   Model Return:  {val_ret:+.2f}%")
    print(f"   Buy & Hold:    {buy_hold:+.2f}%")
    print(f"   Trades:        {n_trades}")
    print(f"   Win Rate:      {win_rate:.1f}%")
    print(f"   Alpha:         {val_ret - buy_hold:+.2f}%")

    return {
        'return': float(val_ret),
        'buy_hold': float(buy_hold),
        'n_trades': n_trades,
        'win_rate': float(win_rate),
        'alpha': float(val_ret - buy_hold),
    }


def save_model(model, name, metrics, out_dir='/workspace', config=None):
    """Write the model as a JSON export and as a PyTorch checkpoint.

    Two files are created in ``out_dir``: ``model_sol_<name>.json`` (weights
    as nested lists plus metrics, config and metadata) and
    ``model_sol_<name>.pt`` (the ``state_dict``).

    Args:
        model: Module whose ``state_dict`` is saved.
        name: Tag inserted into both file names.
        metrics: JSON-serialisable mapping stored under ``'metrics'``.
        out_dir: Target directory; created if it does not exist.
        config: Mapping stored under ``'config'``; defaults to the module
            ``CONFIG``. Callable values are left out.
    """
    cfg = CONFIG if config is None else config
    os.makedirs(out_dir, exist_ok=True)

    state = model.state_dict()
    weights = {key: tensor.detach().cpu().tolist() for key, tensor in state.items()}

    output = {
        'type': 'TransformerTrader_V3',
        'asset': 'SOL/EUR',
        'weights': weights,
        'metrics': metrics,
        'config': {k: v for k, v in cfg.items() if not callable(v)},
        'trainedAt': datetime.now().isoformat(),
        'trainedOn': 'RunPod RTX 4090 - GPU Maximized'
    }

    json_path = os.path.join(out_dir, f'model_sol_{name}.json')
    with open(json_path, 'w') as f:
        json.dump(output, f)

    torch.save(state, os.path.join(out_dir, f'model_sol_{name}.pt'))
    print(f"   💾 Saved: {name}")


if __name__ == '__main__':
    # Script entry point: print the startup banner, then run train().
    print("""
    ╔══════════════════════════════════════════════════════════════╗
    ║                                                              ║
    ║   🔥 SOL ULTRA TRADER V3 - GPU MAXIMIZED                    ║
    ║                                                              ║
    ║   - 32 ambienti paralleli                                   ║
    ║   - Transformer 500k+ params                                ║
    ║   - Batch 2048                                              ║
    ║   - Mixed precision FP16                                    ║
    ║   - Target: RTX 4090 @ 90%+                                 ║
    ║                                                              ║
    ╚══════════════════════════════════════════════════════════════╝
    """)
    train()
