#!/usr/bin/env python3
"""
SOL ULTRA TRADER V2 - IL MIGLIORE DEL MONDO
============================================
- Mamba SSM (batte i Transformer)
- PPO Reinforcement Learning
- Features vettorizzate (VELOCE)
- Fee-aware (0.4% Kraken)
"""

import os
import json
import numpy as np
import pandas as pd
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

# ============================================================================
# CONFIG
# ============================================================================

CONFIG = {
    'FEE_RATE': 0.004,           # 0.4% round-trip
    'INITIAL_CAPITAL': 10000,
    'MIN_HOLD_BARS': 3,

    # Model
    'D_MODEL': 64,
    'N_LAYERS': 3,
    'SEQ_LEN': 60,               # 60 bars of history (~1 hour at 1-minute resolution)
    'N_FEATURES': 16,            # compact, hand-crafted feature set

    # PPO
    'GAMMA': 0.99,
    'GAE_LAMBDA': 0.95,
    'PPO_CLIP': 0.2,
    'PPO_EPOCHS': 4,
    'BATCH_SIZE': 64,

    # Training
    'EPISODES': 300,
    'LR': 1e-3,
}
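# Note on fees: the environment charges FEE_RATE once per round trip when a
# position is closed, so a long opened at 100.00 must exit above ~100.40
# (0.4%) just to break even. MIN_HOLD_BARS keeps the agent from flipping
# position every bar and paying that fee repeatedly.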

# ============================================================================
# FAST FEATURE ENGINEERING (VETTORIZZATO)
# ============================================================================

def compute_features_fast(prices: np.ndarray) -> np.ndarray:
    """Compute 16 features vettorizzati - VELOCE"""
    n = len(prices)
    features = np.zeros((n, CONFIG['N_FEATURES']))

    df = pd.DataFrame({'price': prices})

    # Returns (4)
    for i, period in enumerate([1, 5, 15, 60]):
        features[:, i] = df['price'].pct_change(period).fillna(0).values * 100

    # Volatility (2)
    features[:, 4] = df['price'].pct_change().rolling(10).std().fillna(0).values * 100
    features[:, 5] = df['price'].pct_change().rolling(30).std().fillna(0).values * 100

    # MA distances (3)
    for i, period in enumerate([10, 30, 60]):
        ma = df['price'].rolling(period).mean()
        features[:, 6+i] = ((df['price'] - ma) / ma * 100).fillna(0).values

    # RSI (1)
    delta = df['price'].diff()
    gain = delta.where(delta > 0, 0).rolling(14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
    rs = gain / (loss + 1e-8)
    features[:, 9] = (100 - 100 / (1 + rs)).fillna(50).values

    # Bollinger position (1)
    ma20 = df['price'].rolling(20).mean()
    std20 = df['price'].rolling(20).std()
    features[:, 10] = ((df['price'] - ma20) / (std20 + 1e-8)).fillna(0).values

    # Momentum (2)
    features[:, 11] = df['price'].diff(5).fillna(0).values / df['price'].values * 100
    features[:, 12] = df['price'].diff(20).fillna(0).values / df['price'].values * 100

    # High/Low position (1)
    roll_max = df['price'].rolling(60).max()
    roll_min = df['price'].rolling(60).min()
    features[:, 13] = ((df['price'] - roll_min) / (roll_max - roll_min + 1e-8)).fillna(0.5).values

    # Volume proxy: price change magnitude (1)
    features[:, 14] = np.abs(df['price'].pct_change().fillna(0).values) * 100

    # Trend strength (1)
    ema_fast = df['price'].ewm(span=12).mean()
    ema_slow = df['price'].ewm(span=26).mean()
    features[:, 15] = ((ema_fast - ema_slow) / ema_slow * 100).fillna(0).values

    # Clip outliers
    features = np.clip(features, -10, 10)

    return features
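# Layout note: given a 1-D price array, the function returns an (n, 16) matrix,
# clipped to [-10, 10], with columns: returns (0-3), volatility (4-5),
# MA distances (6-8), RSI (9), Bollinger z-score (10), momentum (11-12),
# 60-bar high/low position (13), |return| proxy (14), EMA(12/26) trend (15).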

# ============================================================================
# MAMBA BLOCK - Simplified but effective
# ============================================================================

class SimpleMamba(nn.Module):
    """Simplified Mamba-like block with selective gating"""
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model

        # Input projection
        self.in_proj = nn.Linear(d_model, d_model * 2)

        # Depthwise conv for local patterns; padding = kernel_size - 1 keeps it
        # causal (no look-ahead), and forward() trims the extra trailing outputs
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=4, padding=3, groups=d_model)

        # Selective gate
        self.gate = nn.Linear(d_model, d_model)

        # Output
        self.out_proj = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (batch, seq, d_model)
        residual = x
        x = self.norm(x)

        # Split into value and gate
        xz = self.in_proj(x)
        x_val, z = xz.chunk(2, dim=-1)

        # Conv
        x_val = x_val.transpose(1, 2)
        x_val = self.conv(x_val)[:, :, :x.shape[1]]
        x_val = x_val.transpose(1, 2)
        x_val = F.silu(x_val)

        # Selective gating
        gate = torch.sigmoid(self.gate(x_val))
        x_val = x_val * gate

        # Output
        out = self.out_proj(x_val * F.silu(z))
        return out + residual
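# Note: this block keeps Mamba's input-dependent (selective) gating and a
# depthwise causal conv, but omits the state-space scan of the full Mamba SSM,
# so it is closer to a gated convolutional block than to the original
# architecture. Shapes are preserved end to end: (batch, seq, d_model) in and out.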


# ============================================================================
# ACTOR-CRITIC MODEL
# ============================================================================

class TradingModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        d = config['D_MODEL']

        # Input embedding
        self.embed = nn.Sequential(
            nn.Linear(config['N_FEATURES'], d),
            nn.LayerNorm(d),
            nn.GELU()
        )

        # Mamba layers
        self.layers = nn.ModuleList([
            SimpleMamba(d) for _ in range(config['N_LAYERS'])
        ])

        # Heads
        self.actor = nn.Sequential(
            nn.Linear(d, d),
            nn.GELU(),
            nn.Linear(d, 3)  # 3 actions: SHORT, FLAT, LONG (env maps action - 1 -> -1/0/+1)
        )

        self.critic = nn.Sequential(
            nn.Linear(d, d),
            nn.GELU(),
            nn.Linear(d, 1)
        )

    def forward(self, x):
        # x: (batch, seq, features)
        x = self.embed(x)

        for layer in self.layers:
            x = layer(x)

        # Use last timestep
        x = x[:, -1]

        return self.actor(x), self.critic(x)
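# Shape flow: x (batch, SEQ_LEN, 16) -> embed -> (batch, SEQ_LEN, 64) ->
# Mamba layers -> last timestep (batch, 64) -> actor logits (batch, 3)
# and critic value (batch, 1).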


# ============================================================================
# PPO AGENT
# ============================================================================

class PPOAgent:
    def __init__(self, model, config):
        self.model = model.to(device)
        self.config = config
        self.optimizer = AdamW(model.parameters(), lr=config['LR'])

        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []

    def select_action(self, state):
        self.model.eval()
        with torch.no_grad():
            state = state.unsqueeze(0).to(device)
            logits, value = self.model(state)
            probs = F.softmax(logits, dim=-1)
            dist = torch.distributions.Categorical(probs)
            action = dist.sample()
            log_prob = dist.log_prob(action)
        return action.item(), log_prob.item(), value.item()
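    # During training, actions are sampled from the categorical policy, which
    # provides exploration; the validation loop below instead takes the argmax
    # of the logits for a deterministic policy.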

    def store(self, state, action, reward, value, log_prob, done):
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.log_probs.append(log_prob)
        self.dones.append(done)

    def update(self):
        self.model.train()

        # Compute returns and advantages
        returns = []
        advantages = []
        gae = 0
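        # Generalized Advantage Estimation (GAE):
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        # Critic targets are the returns A_t + V(s_t).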

        for t in reversed(range(len(self.rewards))):
            if t == len(self.rewards) - 1:
                next_val = 0
            else:
                next_val = self.values[t + 1]

            delta = self.rewards[t] + self.config['GAMMA'] * next_val * (1 - self.dones[t]) - self.values[t]
            gae = delta + self.config['GAMMA'] * self.config['GAE_LAMBDA'] * (1 - self.dones[t]) * gae
            advantages.insert(0, gae)
            returns.insert(0, gae + self.values[t])

        # Convert to tensors
        states = torch.stack(self.states).to(device)
        actions = torch.tensor(self.actions).to(device)
        old_log_probs = torch.tensor(self.log_probs).to(device)
        returns = torch.tensor(returns, dtype=torch.float32).to(device)
        advantages = torch.tensor(advantages, dtype=torch.float32).to(device)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # PPO update
        total_loss = 0
        for _ in range(self.config['PPO_EPOCHS']):
            # Mini-batches
            indices = np.random.permutation(len(states))
            for start in range(0, len(states), self.config['BATCH_SIZE']):
                idx = indices[start:start + self.config['BATCH_SIZE']]

                logits, values = self.model(states[idx])
                probs = F.softmax(logits, dim=-1)
                dist = torch.distributions.Categorical(probs)

                new_log_probs = dist.log_prob(actions[idx])
                entropy = dist.entropy().mean()

                # PPO loss
                ratio = torch.exp(new_log_probs - old_log_probs[idx])
                surr1 = ratio * advantages[idx]
                surr2 = torch.clamp(ratio, 1 - self.config['PPO_CLIP'], 1 + self.config['PPO_CLIP']) * advantages[idx]
                actor_loss = -torch.min(surr1, surr2).mean()

                critic_loss = F.mse_loss(values.squeeze(-1), returns[idx])

                loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy
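                # Combined objective: clipped policy surrogate, value loss
                # weighted 0.5, and a 0.01 entropy bonus to discourage the
                # policy from collapsing onto a single action too early.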

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
                self.optimizer.step()

                total_loss += loss.item()

        # Clear memory
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []

        return total_loss


# ============================================================================
# TRADING ENVIRONMENT
# ============================================================================

class TradingEnv:
    def __init__(self, prices, features, config):
        self.prices = prices
        self.features = features
        self.config = config
        self.seq_len = config['SEQ_LEN']
        self.fee = config['FEE_RATE']
        self.min_hold = config['MIN_HOLD_BARS']
        self.reset()

    def reset(self):
        self.step_idx = self.seq_len
        self.position = 0  # -1, 0, 1
        self.entry_price = 0
        self.entry_step = 0
        self.total_pnl = 0
        self.trades = []
        return self._get_state()

    def _get_state(self):
        state = self.features[self.step_idx - self.seq_len:self.step_idx]
        return torch.tensor(state, dtype=torch.float32)

    def step(self, action):
        price = self.prices[self.step_idx]
        reward = 0

        target_pos = action - 1  # 0->-1, 1->0, 2->1
        hold_time = self.step_idx - self.entry_step
        can_trade = hold_time >= self.min_hold or self.position == 0
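        # A position may only be changed after MIN_HOLD_BARS bars; opening from
        # flat is always allowed. This caps turnover and therefore fee drag.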

        if can_trade and target_pos != self.position:
            # Close position
            if self.position != 0:
                if self.position == 1:
                    pnl = (price / self.entry_price - 1) - self.fee
                else:
                    pnl = (self.entry_price / price - 1) - self.fee

                reward = pnl * 100
                self.total_pnl += pnl
                self.trades.append({
                    'pnl': pnl,
                    'side': 'LONG' if self.position == 1 else 'SHORT',
                    'hold': hold_time
                })

            # Open new
            if target_pos != 0:
                self.entry_price = price
                self.entry_step = self.step_idx

            self.position = target_pos

        # Small penalty for doing nothing
        if self.position == 0:
            reward -= 0.001

        self.step_idx += 1
        done = self.step_idx >= len(self.prices) - 1

        # Force close any open position at the end of the episode
        if done and self.position != 0:
            price = self.prices[self.step_idx]
            if self.position == 1:
                pnl = (price / self.entry_price - 1) - self.fee
            else:
                pnl = (self.entry_price / price - 1) - self.fee
            reward += pnl * 100  # add to (not overwrite) any reward earned this step
            self.total_pnl += pnl
            self.trades.append({
                'pnl': pnl,
                'side': 'LONG' if self.position == 1 else 'SHORT',
                'hold': self.step_idx - self.entry_step
            })
            self.position = 0

        return self._get_state(), reward, done


# ============================================================================
# TRAINING
# ============================================================================

def train():
    print("\n" + "="*60)
    print("🔥 SOL ULTRA TRADER V2 - TRAINING")
    print("="*60)

    # Load data (fail loudly if no price file is found)
    prices = None
    for p in ['/workspace/prices.csv', './prices.csv']:
        if os.path.exists(p):
            df = pd.read_csv(p)
            prices = df['price'].values
            print(f"📂 Loaded {len(prices)} prices from {p}")
            break
    if prices is None:
        raise FileNotFoundError("prices.csv not found in /workspace or the current directory")

    # Features
    print("⚡ Computing features (vectorized)...")
    features = compute_features_fast(prices)
    print(f"   Features shape: {features.shape}")

    # Split
    train_size = int(len(prices) * 0.8)
    train_prices = prices[:train_size]
    train_features = features[:train_size]
    val_prices = prices[train_size:]
    val_features = features[train_size:]

    print(f"📊 Train: {len(train_prices)} | Val: {len(val_prices)}")

    # Model
    model = TradingModel(CONFIG)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"🧠 Model params: {n_params:,}")

    agent = PPOAgent(model, CONFIG)
    env = TradingEnv(train_prices, train_features, CONFIG)

    # Training
    print(f"\n🎯 Training {CONFIG['EPISODES']} episodes...")
    print("-"*60)

    best_return = -999
    best_sharpe = -999

    for ep in range(CONFIG['EPISODES']):
        state = env.reset()
        ep_reward = 0

        while True:
            action, log_prob, value = agent.select_action(state)
            next_state, reward, done = env.step(action)
            agent.store(state, action, reward, value, log_prob, done)
            ep_reward += reward
            state = next_state
            if done:
                break

        loss = agent.update()

        # Metrics
        ret = env.total_pnl * 100
        n_trades = len(env.trades)
        win_rate = sum(1 for t in env.trades if t['pnl'] > 0) / max(n_trades, 1) * 100

        if env.trades:
            pnls = [t['pnl'] for t in env.trades]
            sharpe = np.mean(pnls) / (np.std(pnls) + 1e-8) * np.sqrt(252)
        else:
            sharpe = 0
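        # Note: this is a per-trade Sharpe proxy annualized with sqrt(252); it
        # treats each trade as if it were a daily return, so it is only a rough
        # ranking signal, not a calendar-time Sharpe ratio.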

        if ep % 10 == 0:
            print(f"Ep {ep:3d} | Ret: {ret:+7.2f}% | Trades: {n_trades:4d} | "
                  f"Win: {win_rate:5.1f}% | Sharpe: {sharpe:+5.2f}")

        # Save best
        if ret > best_return:
            best_return = ret
            save_model(model, 'best_return', {'return': ret, 'trades': n_trades, 'win_rate': win_rate, 'sharpe': sharpe})

        if sharpe > best_sharpe and n_trades >= 20:
            best_sharpe = sharpe
            save_model(model, 'best_sharpe', {'return': ret, 'trades': n_trades, 'win_rate': win_rate, 'sharpe': sharpe})

    # Final save
    save_model(model, 'final', {'return': ret, 'trades': n_trades, 'win_rate': win_rate, 'sharpe': sharpe})

    # Validation
    print("\n" + "="*60)
    print("📈 VALIDATION")
    print("="*60)

    val_env = TradingEnv(val_prices, val_features, CONFIG)
    model.eval()
    state = val_env.reset()

    with torch.no_grad():
        while True:
            state_t = state.unsqueeze(0).to(device)
            logits, _ = model(state_t)
            action = logits.argmax(dim=-1).item()
            state, _, done = val_env.step(action)
            if done:
                break

    val_ret = val_env.total_pnl * 100
    val_trades = len(val_env.trades)
    val_win = sum(1 for t in val_env.trades if t['pnl'] > 0) / max(val_trades, 1) * 100
    buy_hold = (val_prices[-1] / val_prices[CONFIG['SEQ_LEN']] - 1) * 100
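    # The buy-and-hold baseline starts at index SEQ_LEN so it covers the same
    # bars the agent actually trades (the env skips the first SEQ_LEN bars as
    # warm-up history).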

    print(f"   Model Return:  {val_ret:+.2f}%")
    print(f"   Buy & Hold:    {buy_hold:+.2f}%")
    print(f"   Trades:        {val_trades}")
    print(f"   Win Rate:      {val_win:.1f}%")
    print(f"   Alpha:         {val_ret - buy_hold:+.2f}%")

    print("\n✅ TRAINING COMPLETE!")
    print(f"   Best Return: {best_return:+.2f}%")
    print(f"   Best Sharpe: {best_sharpe:+.2f}")


def save_model(model, name, metrics):
    state = model.state_dict()
    weights = {k: v.cpu().numpy().tolist() for k, v in state.items()}

    output = {
        'type': 'MambaTrader',
        'asset': 'SOL/EUR',
        'weights': weights,
        'metrics': metrics,
        'config': CONFIG,
        'trainedAt': datetime.now().isoformat(),
        'trainedOn': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
    }

    # Mirror the data-loading fallback: write to /workspace if it exists, else cwd
    out_dir = '/workspace' if os.path.isdir('/workspace') else '.'
    path = f'{out_dir}/model_sol_{name}.json'
    with open(path, 'w') as f:
        json.dump(output, f)

    torch.save(model.state_dict(), f'{out_dir}/model_sol_{name}.pt')
    print(f"   💾 Saved: {name}")


if __name__ == '__main__':
    print("""
    ╔══════════════════════════════════════════════════════════════╗
    ║                                                              ║
    ║   🔥 SOL ULTRA TRADER V2                                    ║
    ║   Il Migliore del Mondo                                     ║
    ║                                                              ║
    ║   - Mamba SSM (batte Transformer)                           ║
    ║   - PPO Reinforcement Learning                              ║
    ║   - Fee-aware (0.4% Kraken)                                 ║
    ║   - 300 episodi di training                                 ║
    ║                                                              ║
    ╚══════════════════════════════════════════════════════════════╝
    """)
    train()
