#!/usr/bin/env python3
"""
SOL MAMBA TRADING - ULTRA ADVANCED TRAINING
============================================
Architecture: SAMBA (Graph-Mamba) + PPO Reinforcement Learning
- Mamba State Space Models (beat Transformers)
- Graph Neural Networks for multi-timeframe correlations
- PPO for optimal trading decisions
- Fee-aware reward shaping

Target: SOL/EUR on data from 1 Nov 2025 to 17 Jan 2026
"""

import os
import json
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Pick the compute device once at import time and report what was found.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"🚀 Device: {device}")
if device.type == 'cuda':
    # Name and total memory of GPU index 0, for the training log.
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# ============================================================================
# CONFIGURATION
# ============================================================================

# Global hyper-parameters shared by the model, the PPO agent and the
# trading environment. Keys are read by name throughout this file.
CONFIG = {
    # Trading
    'FEE_RATE': 0.004,           # 0.4% Kraken taker fee round-trip
    'INITIAL_CAPITAL': 10000,
    'MAX_POSITION': 1.0,         # 100% max position
    'MIN_HOLD_BARS': 3,          # Minimum hold of 3 bars (3 minutes, assuming 1-minute bars)

    # Model - Mamba
    'D_MODEL': 128,              # Hidden state size
    'D_STATE': 64,               # SSM state dimension
    'D_CONV': 4,                 # Convolution kernel size
    'EXPAND': 2,                 # Expansion factor
    'N_LAYERS': 4,               # Number of Mamba layers

    # Model - Graph
    'N_NODES': 5,                # 5 timeframes (1m, 5m, 15m, 1h, 4h)
    'GNN_HIDDEN': 64,
    'GNN_LAYERS': 2,

    # Model - PPO
    'PPO_EPOCHS': 10,
    'PPO_CLIP': 0.2,
    'GAMMA': 0.99,
    'GAE_LAMBDA': 0.95,
    'VALUE_COEF': 0.5,
    'ENTROPY_COEF': 0.01,

    # Training
    'BATCH_SIZE': 256,
    'LEARNING_RATE': 3e-4,
    'WEIGHT_DECAY': 0.01,
    'EPISODES': 500,
    'WARMUP_EPISODES': 50,
    'SEQUENCE_LENGTH': 120,      # 2 hours of history (assuming 1-minute bars)

    # Features
    'N_FEATURES': 32,            # Features per timeframe
}

# ============================================================================
# MAMBA BLOCK - State Space Model
# ============================================================================

class MambaBlock(nn.Module):
    """
    Residual selective state-space block in the style of Mamba.

    Reference: "Mamba: Linear-Time Sequence Modeling with Selective State
    Spaces", https://arxiv.org/abs/2312.00752

    Pipeline: LayerNorm -> input projection (signal + gate) -> causal
    depthwise convolution -> SiLU -> selective scan -> SiLU gate -> output
    projection -> residual add. The scan is a Python loop over time steps,
    so cost grows linearly with the sequence length.
    """
    def __init__(self, d_model: int, d_state: int = 64, d_conv: int = 4, expand: int = 2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = d_model * expand
        inner = self.d_inner

        # One projection yields both the signal path and the gate path.
        self.in_proj = nn.Linear(d_model, 2 * inner, bias=False)

        # Depthwise conv; the left/right padding is trimmed in forward()
        # so that each output step only sees current and past inputs.
        self.conv1d = nn.Conv1d(
            inner,
            inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            groups=inner,
        )

        # Input-dependent SSM terms: 1 value for dt, d_state for B, d_state for C.
        self.x_proj = nn.Linear(inner, 2 * d_state + 1, bias=False)

        # A is stored as log(1..d_state), replicated for every inner channel.
        state_index = torch.arange(1, d_state + 1, dtype=torch.float32)
        self.A_log = nn.Parameter(torch.log(state_index.repeat(inner, 1)))
        # Per-channel skip weight applied to the scan input.
        self.D = nn.Parameter(torch.ones(inner))

        self.out_proj = nn.Linear(inner, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the block.

        x: (batch, seq_len, d_model)
        Returns a tensor of the same shape (residual connection included).
        """
        _, seq_len, _ = x.shape
        skip = x

        signal, gate = self.in_proj(self.norm(x)).chunk(2, dim=-1)

        # Conv1d wants (batch, channels, time); keep only the first seq_len
        # outputs so the convolution stays causal.
        conv_out = self.conv1d(signal.transpose(1, 2))[:, :, :seq_len]
        signal = F.silu(conv_out.transpose(1, 2))

        mixed = self.ssm(signal) * F.silu(gate)
        return self.out_proj(mixed) + skip

    def ssm(self, x: torch.Tensor) -> torch.Tensor:
        """
        Selective scan over the time axis.

        x: (batch, seq_len, d_inner)
        Returns y with the same shape, where for every step
        h <- exp(dt * A) * h + dt * B * x and y = sum(h * C) + D * x.
        """
        batch, seq_len, d_inner = x.shape
        n_state = self.d_state

        # Split the projection into dt (1), B (d_state) and C (d_state).
        dt_raw, B, C = self.x_proj(x).split([1, n_state, n_state], dim=-1)

        # Softplus keeps the step size positive; shape (batch, seq).
        dt = F.softplus(dt_raw).squeeze(-1)

        # Negative A so that exp(dt * A) is a decay; shape (1, d_inner, d_state).
        A_b = (-torch.exp(self.A_log)).unsqueeze(0)

        y = torch.zeros_like(x)
        h = torch.zeros(batch, d_inner, n_state, device=x.device)

        for t in range(seq_len):
            step = dt[:, t].view(batch, 1, 1)
            u_t = x[:, t]  # (batch, d_inner)

            decay = torch.exp(step * A_b)                          # (batch, d_inner, d_state)
            drive = step * B[:, t].unsqueeze(1) * u_t.unsqueeze(-1)  # (batch, d_inner, d_state)

            h = decay * h + drive

            readout = (h * C[:, t].unsqueeze(1)).sum(-1)
            y[:, t] = readout + self.D * u_t

        return y


# ============================================================================
# GRAPH ATTENTION NETWORK - Multi-Timeframe Correlations
# ============================================================================

class GraphAttention(nn.Module):
    """
    Multi-head graph attention layer used to mix information between
    timeframe nodes.

    Every node attends to the nodes allowed by the adjacency matrix. A node
    whose adjacency row is entirely zero (no neighbours, not even itself)
    receives an all-zero output instead of NaN.

    Raises:
        ValueError: if ``out_features`` is not divisible by ``n_heads``.
    """
    def __init__(self, in_features: int, out_features: int, n_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        if n_heads < 1 or out_features % n_heads != 0:
            # Otherwise the per-head reshape in forward() fails with an
            # unrelated-looking shape error.
            raise ValueError(
                f"out_features ({out_features}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.head_dim = out_features // n_heads

        self.W = nn.Linear(in_features, out_features, bias=False)
        # Per-head attention vector: first half scores the source node,
        # second half scores the neighbour.
        self.a = nn.Parameter(torch.randn(n_heads, 2 * self.head_dim))
        self.dropout = nn.Dropout(dropout)
        self.leaky_relu = nn.LeakyReLU(0.2)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, n_nodes, in_features)
        adj: (n_nodes, n_nodes) adjacency matrix; entry (i, j) != 0 lets
             node i attend to node j
        Returns: (batch, n_nodes, out_features)
        """
        batch, n_nodes, _ = x.shape

        # Linear transformation, split into heads: (batch, n, heads, head_dim)
        h = self.W(x).view(batch, n_nodes, self.n_heads, self.head_dim)

        # a . [h_i || h_j] decomposes into a_src . h_i + a_dst . h_j, which
        # avoids materialising the (batch, n, n, heads, 2*head_dim) concat.
        a_src, a_dst = self.a.split(self.head_dim, dim=-1)
        score_src = (h * a_src).sum(-1)  # (batch, n, heads)
        score_dst = (h * a_dst).sum(-1)  # (batch, n, heads)
        e = self.leaky_relu(score_src.unsqueeze(2) + score_dst.unsqueeze(1))  # (batch, n, n, heads)

        # Disallowed edges get -inf so softmax assigns them zero weight.
        mask = (adj == 0).unsqueeze(0).unsqueeze(-1)
        e = e.masked_fill(mask, float('-inf'))

        alpha = F.softmax(e, dim=2)
        # A fully masked row makes softmax return NaN everywhere; zeroing the
        # masked entries turns that into "no incoming messages".
        alpha = alpha.masked_fill(mask, 0.0)
        alpha = self.dropout(alpha)

        # Weighted sum of neighbour features per head.
        out = torch.einsum('bijh,bjhd->bihd', alpha, h)
        return out.reshape(batch, n_nodes, -1)


class GNN(nn.Module):
    """
    Stack of GraphAttention layers, each followed by LayerNorm and ELU.

    Layer widths are in_features -> hidden -> ... -> hidden -> out_features,
    so the final width is always ``out_features`` (including the
    single-layer case, which maps in_features -> out_features directly).

    Raises:
        ValueError: if ``n_layers`` is smaller than 1.
    """
    def __init__(self, in_features: int, hidden: int, out_features: int, n_layers: int = 2):
        super().__init__()
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")

        # Width at every layer boundary; consecutive pairs define one layer.
        widths = [in_features] + [hidden] * (n_layers - 1) + [out_features]

        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            self.layers.append(GraphAttention(width_in, width_out))
            self.norms.append(nn.LayerNorm(width_out))

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, n_nodes, in_features)
        adj: (n_nodes, n_nodes) adjacency matrix passed to every layer
        Returns: (batch, n_nodes, out_features)
        """
        for layer, norm in zip(self.layers, self.norms):
            x = F.elu(norm(layer(x, adj)))
        return x


# ============================================================================
# SAMBA: MAMBA + GRAPH NEURAL NETWORK
# ============================================================================

class SAMBA(nn.Module):
    """
    SAMBA: Graph-Mamba actor-critic network.

    Each timeframe sequence is embedded and run through the shared Mamba
    stack; the last-step vectors of all timeframes become graph nodes for
    the GNN. The last timeframe's vector and the flattened GNN output are
    fused and fed to an action head (3 logits) and a value head (1 scalar).
    """
    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        d_model = config['D_MODEL']
        gnn_hidden = config['GNN_HIDDEN']
        n_nodes = config['N_NODES']

        # Per-step feature embedding, shared across timeframes.
        self.feature_embed = nn.Sequential(
            nn.Linear(config['N_FEATURES'], d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
        )

        # Temporal encoder: the same Mamba stack is applied to every timeframe.
        mamba_kwargs = dict(
            d_model=d_model,
            d_state=config['D_STATE'],
            d_conv=config['D_CONV'],
            expand=config['EXPAND'],
        )
        self.mamba_layers = nn.ModuleList(
            [MambaBlock(**mamba_kwargs) for _ in range(config['N_LAYERS'])]
        )

        # Cross-timeframe mixing.
        self.gnn = GNN(
            in_features=d_model,
            hidden=gnn_hidden,
            out_features=gnn_hidden,
            n_layers=config['GNN_LAYERS'],
        )

        # Fully connected graph between timeframes; stored as a buffer so it
        # follows the module across devices and is saved in the state dict.
        self.register_buffer('adj', torch.ones(n_nodes, n_nodes))

        # Fusion MLP over [last-timeframe vector, flattened GNN output].
        fused_width = d_model + gnn_hidden * n_nodes
        self.fusion = nn.Sequential(
            nn.Linear(fused_width, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
        )

        # Output heads
        self.action_head = nn.Linear(128, 3)  # FLAT, LONG, SHORT
        self.value_head = nn.Linear(128, 1)

    def _encode_timeframe(self, series: torch.Tensor) -> torch.Tensor:
        """Embed one (batch, seq, n_features) series and return its last-step Mamba state."""
        hidden = self.feature_embed(series)  # (batch, seq, d_model)
        for block in self.mamba_layers:
            hidden = block(hidden)
        return hidden[:, -1]

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        x: (batch, n_timeframes, seq_len, n_features)
        Returns: action_logits (batch, 3), values (batch, 1)
        """
        batch, n_tf, seq_len, n_feat = x.shape

        # One summary vector per timeframe.
        tf_outputs = [self._encode_timeframe(x[:, tf]) for tf in range(n_tf)]

        # Timeframes become graph nodes: (batch, n_tf, d_model).
        nodes = torch.stack(tf_outputs, dim=1)
        gnn_flat = self.gnn(nodes, self.adj).flatten(1)  # (batch, n_tf * gnn_hidden)

        # Fuse the last timeframe's own vector with the graph view.
        features = self.fusion(torch.cat([tf_outputs[-1], gnn_flat], dim=1))

        return self.action_head(features), self.value_head(features)


# ============================================================================
# PPO AGENT
# ============================================================================

class PPOMemory:
    """
    Rollout buffer: six parallel Python lists, one entry per transition.

    Attributes: states, actions, rewards, values, log_probs, dones.
    """
    _FIELDS = ('states', 'actions', 'rewards', 'values', 'log_probs', 'dones')

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, [])

    def store(self, state, action, reward, value, log_prob, done):
        """Append one transition to the buffer."""
        record = (state, action, reward, value, log_prob, done)
        for field, item in zip(self._FIELDS, record):
            getattr(self, field).append(item)

    def clear(self):
        """Drop every stored transition (each list is replaced by a new empty one)."""
        for field in self._FIELDS:
            setattr(self, field, [])

    def get_batches(self, batch_size: int):
        """
        Yield shuffled mini-batches covering the whole buffer once.

        Each batch is a tuple
        (states, actions, rewards, values, log_probs, dones): states are
        stacked, actions keep torch's default integer dtype, and the
        remaining four are float32 tensors. The last batch may be smaller
        than ``batch_size``. The shuffled indices themselves are not
        returned.
        """
        total = len(self.states)
        order = np.arange(total)
        np.random.shuffle(order)

        def as_float(column, picks):
            return torch.tensor([column[i] for i in picks], dtype=torch.float32)

        for offset in range(0, total, batch_size):
            picks = order[offset:offset + batch_size]
            yield (
                torch.stack([self.states[i] for i in picks]),
                torch.tensor([self.actions[i] for i in picks]),
                as_float(self.rewards, picks),
                as_float(self.values, picks),
                as_float(self.log_probs, picks),
                as_float(self.dones, picks),
            )


class PPOAgent:
    """
    PPO agent wrapping a SAMBA actor-critic model.

    Holds the rollout memory, an AdamW optimizer and a cosine warm-restart
    learning-rate schedule (stepped once per call to ``update``).
    """
    def __init__(self, model: SAMBA, config: dict):
        self.model = model.to(device)
        self.config = config
        self.memory = PPOMemory()

        self.optimizer = AdamW(
            model.parameters(),
            lr=config['LEARNING_RATE'],
            weight_decay=config['WEIGHT_DECAY']
        )
        self.scheduler = CosineAnnealingWarmRestarts(
            self.optimizer, T_0=50, T_mult=2
        )

    def select_action(self, state: torch.Tensor) -> Tuple[int, float, float]:
        """
        Sample an action from the current policy.

        state: a single un-batched observation; a batch axis is added here.
        Returns: (action index, log-probability of that action, value estimate)
        """
        self.model.eval()
        with torch.no_grad():
            state = state.unsqueeze(0).to(device)
            action_logits, value = self.model(state)

            # Building the distribution from logits avoids the extra
            # softmax/log round trip and is numerically more stable.
            dist = torch.distributions.Categorical(logits=action_logits)
            action = dist.sample()
            log_prob = dist.log_prob(action)

        return action.item(), log_prob.item(), value.squeeze().item()

    def compute_gae(self, rewards, values, dones, next_value):
        """
        Generalized Advantage Estimation.

        rewards, values, dones: equal-length sequences, one entry per step.
        next_value: bootstrap value used after the final step (ignored when
                    the final step is terminal).
        Returns: (advantages, returns) as float32 tensors, where
                 returns = advantages + values.
        """
        gamma = self.config['GAMMA']
        lam = self.config['GAE_LAMBDA']
        n_steps = len(rewards)

        # Filled back to front; indexing avoids the O(n^2) cost of
        # repeatedly inserting at the head of a list.
        advantages = [0.0] * n_steps
        gae = 0.0
        for t in reversed(range(n_steps)):
            next_val = next_value if t == n_steps - 1 else values[t + 1]
            not_done = 1.0 - float(dones[t])

            delta = rewards[t] + gamma * next_val * not_done - values[t]
            gae = delta + gamma * lam * not_done * gae
            advantages[t] = gae

        advantages = torch.tensor(advantages, dtype=torch.float32)
        returns = advantages + torch.tensor(values, dtype=torch.float32)

        return advantages, returns

    def update(self):
        """
        Run the PPO update on everything in memory, then clear the memory.

        Returns the sum of the mini-batch losses over all PPO epochs
        (0.0 when the memory is empty).
        """
        n_samples = len(self.memory.states)
        if n_samples == 0:
            return 0.0

        # Bootstrap value for the step after the rollout. The rollout only
        # stores visited states, so the last stored state stands in for its
        # successor; the term is multiplied by zero when the rollout ended
        # on a terminal step.
        self.model.eval()
        with torch.no_grad():
            last_state = self.memory.states[-1].unsqueeze(0).to(device)
            _, next_value = self.model(last_state)
            next_value = next_value.squeeze().item()

        advantages, returns = self.compute_gae(
            self.memory.rewards,
            self.memory.values,
            self.memory.dones,
            next_value
        )

        # Normalize advantages (std of a single sample is undefined).
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        actions_all = torch.tensor(self.memory.actions, dtype=torch.long)
        old_log_probs_all = torch.tensor(self.memory.log_probs, dtype=torch.float32)

        clip = self.config['PPO_CLIP']
        batch_size = self.config['BATCH_SIZE']

        self.model.train()
        total_loss = 0
        for _ in range(self.config['PPO_EPOCHS']):
            # Shuffle here and index every per-step tensor with the SAME
            # indices, so each state is paired with its own advantage,
            # return, action and old log-probability.
            order = np.random.permutation(n_samples)
            for start in range(0, n_samples, batch_size):
                picks = order[start:start + batch_size]
                pick_idx = torch.as_tensor(picks, dtype=torch.long)

                # States are stacked per batch to avoid duplicating the
                # whole rollout in memory.
                states = torch.stack([self.memory.states[i] for i in picks]).to(device)
                actions = actions_all[pick_idx].to(device)
                batch_adv = advantages[pick_idx].to(device)
                batch_ret = returns[pick_idx].to(device)
                old_log_probs = old_log_probs_all[pick_idx].to(device)

                # Forward pass
                action_logits, values = self.model(states)
                dist = torch.distributions.Categorical(logits=action_logits)

                new_log_probs = dist.log_prob(actions)
                entropy = dist.entropy().mean()

                # PPO clipped objective
                ratio = torch.exp(new_log_probs - old_log_probs)
                surr1 = ratio * batch_adv
                surr2 = torch.clamp(ratio, 1 - clip, 1 + clip) * batch_adv
                actor_loss = -torch.min(surr1, surr2).mean()

                # Value loss; squeeze only the trailing axis so a batch of
                # one keeps shape (1,) and matches batch_ret.
                value_loss = F.mse_loss(values.squeeze(-1), batch_ret)

                # Total loss
                loss = (
                    actor_loss
                    + self.config['VALUE_COEF'] * value_loss
                    - self.config['ENTROPY_COEF'] * entropy
                )

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
                self.optimizer.step()

                total_loss += loss.item()

        self.scheduler.step()
        self.memory.clear()

        return total_loss


# ============================================================================
# TRADING ENVIRONMENT
# ============================================================================

class TradingEnv:
    """
    Fee-aware single-asset trading environment.

    The agent holds at most one position (long, short or flat). A
    round-trip fee (``FEE_RATE``) is charged when a position is closed, and
    a position can only be changed after ``MIN_HOLD_BARS`` bars.

    Observations are windows of pre-computed features with shape
    (n_timeframes, SEQUENCE_LENGTH, N_FEATURES). The "timeframes" are not
    resampled bars: every timeframe is computed from the same price series
    with a longer lookback and longer return horizons.

    Raises:
        ValueError: if the price series is too short for one step, or if
            ``N_NODES`` exceeds the number of known timeframe multipliers.
    """

    # Bar multipliers for the timeframes: 1m, 5m, 15m, 1h, 4h
    TF_MULTIPLIERS = (1, 5, 15, 60, 240)

    # Action index -> target position, in the order of the policy head
    # (0=FLAT, 1=LONG, 2=SHORT).
    ACTION_TO_POSITION = {0: 0, 1: 1, 2: -1}

    def __init__(self, prices: np.ndarray, config: dict):
        self.prices = np.asarray(prices, dtype=float)
        self.config = config
        self.seq_len = config['SEQUENCE_LENGTH']
        self.fee_rate = config['FEE_RATE']
        self.min_hold = config['MIN_HOLD_BARS']

        # reset() starts at index seq_len and step() reads one bar further
        # when it force-closes, so at least seq_len + 2 prices are needed.
        if len(self.prices) < self.seq_len + 2:
            raise ValueError(
                f"Need at least {self.seq_len + 2} prices, got {len(self.prices)}"
            )
        if config['N_NODES'] > len(self.TF_MULTIPLIERS):
            raise ValueError(
                f"N_NODES={config['N_NODES']} exceeds the "
                f"{len(self.TF_MULTIPLIERS)} supported timeframes"
            )

        # Pre-compute multi-timeframe features
        self.features = self._compute_all_features()

        self.reset()

    def _compute_all_features(self) -> np.ndarray:
        """Compute the feature tensor of shape (n_timeframes, n_bars, n_features)."""
        n = len(self.prices)
        n_tf = self.config['N_NODES']
        n_feat = self.config['N_FEATURES']

        features = np.zeros((n_tf, n, n_feat))

        for tf_idx, mult in enumerate(self.TF_MULTIPLIERS[:n_tf]):
            for i in range(n):
                features[tf_idx, i] = self._compute_features(i, mult)

        return features

    def _compute_features(self, idx: int, tf_mult: int) -> np.ndarray:
        """
        Compute the feature vector for bar ``idx`` at one timeframe.

        Only prices up to and including ``idx`` are used. Returns all zeros
        while fewer than 5 bars of history are available. Every feature is
        clipped to [-10, 10]. The index layout assumes N_FEATURES >= 32.
        """
        feat = np.zeros(self.config['N_FEATURES'])

        # Safety bounds
        lookback = min(idx, 100 * tf_mult)
        if lookback < 5:
            return feat

        prices = self.prices[max(0, idx - lookback):idx + 1]

        # 1-8: Returns at different horizons (scaled by the timeframe)
        for i, horizon in enumerate([1, 2, 5, 10, 20, 30, 60, 120]):
            h = min(horizon * tf_mult, len(prices) - 1)
            if h > 0 and len(prices) > h:
                feat[i] = (prices[-1] / prices[-1-h] - 1) * 100

        # 9-12: Volatility measures
        if len(prices) > 5:
            returns = np.diff(prices) / prices[:-1]
            feat[8] = np.std(returns[-5:]) * 100 if len(returns) >= 5 else 0
            feat[9] = np.std(returns[-20:]) * 100 if len(returns) >= 20 else 0
            feat[10] = np.std(returns[-60:]) * 100 if len(returns) >= 60 else 0
            # Volatility of volatility
            if len(returns) >= 20:
                vol_of_vol = pd.Series(returns).rolling(5).std().std()
                feat[11] = vol_of_vol * 100 if not pd.isna(vol_of_vol) else 0

        # 13-16: Distance from moving averages
        for i, period in enumerate([5, 10, 20, 50]):
            if len(prices) > period:
                ma = np.mean(prices[-period:])
                feat[12 + i] = (prices[-1] / ma - 1) * 100

        # 17-20: RSI variants
        if len(prices) > 14:
            delta = np.diff(prices)
            gains = np.where(delta > 0, delta, 0)
            losses = np.where(delta < 0, -delta, 0)
            for i, period in enumerate([7, 14, 21, 28]):
                if len(gains) >= period:
                    avg_gain = np.mean(gains[-period:])
                    avg_loss = np.mean(losses[-period:])
                    if avg_loss > 0:
                        rs = avg_gain / avg_loss
                        feat[16 + i] = 100 - (100 / (1 + rs))
                    else:
                        feat[16 + i] = 100

        # 21-24: MACD components
        if len(prices) >= 26:
            series = pd.Series(prices)
            ema12_series = series.ewm(span=12).mean()
            ema26_series = series.ewm(span=26).mean()
            macd_series = ema12_series - ema26_series

            ema12 = ema12_series.iloc[-1]
            ema26 = ema26_series.iloc[-1]
            macd = macd_series.iloc[-1]
            # The signal line is the 9-span EMA of the MACD line itself
            # (not of the slow EMA, which sits at price level).
            signal = macd_series.ewm(span=9).mean().iloc[-1]

            feat[20] = macd
            feat[21] = macd - signal
            feat[22] = (ema12 / ema26 - 1) * 100
            feat[23] = macd / prices[-1] * 100

        # 25-28: Bollinger Bands
        if len(prices) >= 20:
            ma20 = np.mean(prices[-20:])
            std20 = np.std(prices[-20:])
            upper = ma20 + 2 * std20
            lower = ma20 - 2 * std20
            feat[24] = (prices[-1] - lower) / (upper - lower) if upper != lower else 0.5
            feat[25] = (upper - lower) / ma20 * 100
            feat[26] = (prices[-1] - ma20) / std20 if std20 > 0 else 0
            feat[27] = std20 / ma20 * 100

        # 29-32: Price patterns over the last 5 bars
        if len(prices) >= 5:
            recent = prices[-5:]
            recent_range = np.max(recent) - np.min(recent)
            feat[28] = (prices[-1] - np.min(recent)) / (recent_range + 1e-8)
            feat[29] = recent_range / prices[-1] * 100
            # Momentum
            feat[30] = np.mean(np.diff(recent)) / prices[-1] * 100
            # Acceleration
            if len(prices) >= 10:
                mom_now = np.mean(np.diff(recent))
                mom_prev = np.mean(np.diff(prices[-10:-5]))
                feat[31] = (mom_now - mom_prev) / prices[-1] * 100

        # Bound every feature to a fixed range
        feat = np.clip(feat, -10, 10)

        return feat

    def reset(self) -> torch.Tensor:
        """Restart the episode at the first bar with a full history window and return its state."""
        self.current_step = self.seq_len
        self.position = 0  # -1 short, 0 flat, 1 long
        self.entry_price = 0
        self.entry_step = 0
        self.capital = self.config['INITIAL_CAPITAL']
        self.total_pnl = 0
        self.trades = []

        return self._get_state()

    def _get_state(self) -> torch.Tensor:
        """Get current state: (n_timeframes, seq_len, n_features)"""
        idx = self.current_step
        # Window ends just before the current bar, so the observation never
        # contains the price the next action trades at.
        state = self.features[:, idx - self.seq_len:idx, :]
        return torch.tensor(state, dtype=torch.float32)

    def _close_position(self, price: float, hold_time: int) -> float:
        """Close the open position at ``price``, record the trade and return its net PnL (fraction)."""
        if self.position == 1:  # Long
            pnl = (price / self.entry_price - 1) - self.fee_rate
        else:  # Short
            pnl = (self.entry_price / price - 1) - self.fee_rate

        self.total_pnl += pnl
        self.trades.append({
            'entry': self.entry_price,
            'exit': price,
            'side': 'LONG' if self.position == 1 else 'SHORT',
            'pnl': pnl,
            'hold_time': hold_time
        })
        return pnl

    def step(self, action: int) -> Tuple[torch.Tensor, float, bool]:
        """
        Advance the environment by one bar.

        action: 0=FLAT, 1=LONG, 2=SHORT
        Returns: next_state, reward, done

        The reward is the net PnL (in percent) of any position closed on
        this step, minus a small penalty while flat.

        Raises:
            ValueError: if ``action`` is not 0, 1 or 2.
        """
        if action not in self.ACTION_TO_POSITION:
            raise ValueError(f"Invalid action {action!r}; expected 0, 1 or 2")
        target_position = self.ACTION_TO_POSITION[action]

        current_price = self.prices[self.current_step]
        reward = 0.0

        # Check minimum hold time
        hold_time = self.current_step - self.entry_step
        can_exit = self.position == 0 or hold_time >= self.min_hold

        if can_exit and target_position != self.position:
            # Close existing position
            if self.position != 0:
                reward += self._close_position(current_price, hold_time) * 100  # Scale reward

            # Open new position
            if target_position != 0:
                self.entry_price = current_price
                self.entry_step = self.current_step

            self.position = target_position

        # Small penalty for staying flat (opportunity cost)
        if self.position == 0:
            reward -= 0.001

        # Move forward
        self.current_step += 1
        done = self.current_step >= len(self.prices) - 1

        if done and self.position != 0:
            # Force close at the final bar; the trade is recorded and its
            # PnL is added to whatever this step already earned.
            final_price = self.prices[self.current_step]
            final_hold = self.current_step - self.entry_step
            reward += self._close_position(final_price, final_hold) * 100
            self.position = 0

        return self._get_state(), reward, done


# ============================================================================
# DATA LOADING
# ============================================================================

def load_sol_data() -> np.ndarray:
    """
    Load the SOL/EUR price series from the first prices.csv that exists.

    The CSV must contain a ``price`` column (returned as an array) and a
    ``timestamp`` column (used only for the printed date range).

    Raises:
        FileNotFoundError: if none of the candidate paths exists.
    """
    # Candidate locations, checked in order.
    candidates = (
        '/workspace/prices.csv',
        '/kaggle/input/bestrading-sol-eur-prices/prices.csv',
        './prices.csv',
        '/var/www/html/bestrading.cuttalo.com/scripts/kaggle-dataset-sol-eur/prices.csv',
    )

    found = next((p for p in candidates if os.path.exists(p)), None)
    if found is None:
        raise FileNotFoundError("Could not find prices.csv in any expected location")

    print(f"📂 Loading data from: {found}")
    frame = pd.read_csv(found)
    prices = frame['price'].values
    stamps = frame['timestamp']
    print(f"   Loaded {len(prices)} price points")
    print(f"   Date range: {stamps.iloc[0]} to {stamps.iloc[-1]}")
    print(f"   Price range: {prices.min():.2f} - {prices.max():.2f}")
    return prices


# ============================================================================
# TRAINING LOOP
# ============================================================================

def train():
    """
    Main training function.

    Loads prices, splits them 80/20 chronologically into train/validation,
    trains the PPO agent for ``CONFIG['EPISODES']`` full passes over the
    training series (one policy update per episode), saves the best and
    final models, runs validation and writes the per-episode history to
    /workspace/training_history.json.
    """
    print("\n" + "="*60)
    print("🚀 SOL MAMBA TRADING - TRAINING START")
    print("="*60)

    # Load data
    prices = load_sol_data()

    # Split train/val (chronological, no shuffling)
    train_size = int(len(prices) * 0.8)
    train_prices = prices[:train_size]
    val_prices = prices[train_size:]

    print(f"\n📊 Data split:")
    print(f"   Train: {len(train_prices)} bars")
    print(f"   Val: {len(val_prices)} bars")

    # Create model and agent
    model = SAMBA(CONFIG)
    agent = PPOAgent(model, CONFIG)

    # Count parameters
    n_params = sum(p.numel() for p in model.parameters())
    print(f"\n🧠 Model: SAMBA (Mamba + GNN)")
    print(f"   Parameters: {n_params:,}")
    print(f"   Mamba layers: {CONFIG['N_LAYERS']}")
    print(f"   GNN layers: {CONFIG['GNN_LAYERS']}")

    # Training environment
    env = TradingEnv(train_prices, CONFIG)

    # Metrics
    best_return = -float('inf')
    best_sharpe = -float('inf')
    history = []

    print(f"\n🎯 Training for {CONFIG['EPISODES']} episodes...")
    print("-"*60)

    for episode in range(CONFIG['EPISODES']):
        state = env.reset()
        episode_reward = 0

        # Roll out one full episode with the current policy.
        while True:
            action, log_prob, value = agent.select_action(state)
            next_state, reward, done = env.step(action)

            agent.memory.store(state, action, reward, value, log_prob, done)

            episode_reward += reward
            state = next_state

            if done:
                break

        # Update policy
        loss = agent.update()

        # Calculate metrics
        total_return = env.total_pnl * 100
        n_trades = len(env.trades)
        win_rate = sum(1 for t in env.trades if t['pnl'] > 0) / max(n_trades, 1) * 100

        # Sharpe ratio (simplified): computed over per-trade PnL and scaled
        # by sqrt(252), so it is only a relative score between episodes.
        if env.trades:
            returns = [t['pnl'] for t in env.trades]
            sharpe = np.mean(returns) / (np.std(returns) + 1e-8) * np.sqrt(252)
        else:
            sharpe = 0

        history.append({
            'episode': episode,
            'return': total_return,
            'trades': n_trades,
            'win_rate': win_rate,
            'sharpe': sharpe,
            'loss': loss
        })

        # Logging
        if episode % 10 == 0 or episode == CONFIG['EPISODES'] - 1:
            print(f"Ep {episode:4d} | Return: {total_return:+7.2f}% | "
                  f"Trades: {n_trades:4d} | Win: {win_rate:5.1f}% | "
                  f"Sharpe: {sharpe:+5.2f} | Loss: {loss:.4f}")

        # Save best model
        if total_return > best_return:
            best_return = total_return
            save_model(model, 'best_return', history[-1])

        if sharpe > best_sharpe and n_trades >= 10:
            best_sharpe = sharpe
            save_model(model, 'best_sharpe', history[-1])

    # Final save (nothing to report if no episode was run)
    if history:
        save_model(model, 'final', history[-1])

    # Validation
    print("\n" + "="*60)
    print("📈 VALIDATION")
    print("="*60)

    validate(model, val_prices, CONFIG)

    # Save history; create the directory first so a finished run is not
    # lost on machines where /workspace does not exist yet.
    os.makedirs('/workspace', exist_ok=True)
    with open('/workspace/training_history.json', 'w') as f:
        json.dump(history, f, indent=2)

    print("\n✅ Training complete!")
    print(f"   Best return: {best_return:+.2f}%")
    print(f"   Best Sharpe: {best_sharpe:+.2f}")


def validate(model: SAMBA, prices: np.ndarray, config: dict):
    """
    Run the model greedily (argmax action) over ``prices`` and print a
    summary: total return, buy-and-hold benchmark, trade count, win rate
    and, when trades exist, average hold time and average PnL.
    """
    env = TradingEnv(prices, config)
    model.eval()

    state = env.reset()
    finished = False

    with torch.no_grad():
        while not finished:
            logits, _ = model(state.unsqueeze(0).to(device))
            greedy_action = logits.argmax(dim=-1).item()
            state, _, finished = env.step(greedy_action)

    # Results
    trades = env.trades
    n_trades = len(trades)
    winners = len([t for t in trades if t['pnl'] > 0])
    win_rate = winners / max(n_trades, 1) * 100
    total_return = env.total_pnl * 100

    # Benchmark: hold from the first tradable bar to the last bar.
    first_price = prices[config['SEQUENCE_LENGTH']]
    buy_hold = (prices[-1] / first_price - 1) * 100

    print(f"   Model return: {total_return:+.2f}%")
    print(f"   Buy & Hold:   {buy_hold:+.2f}%")
    print(f"   Trades:       {n_trades}")
    print(f"   Win rate:     {win_rate:.1f}%")

    if not trades:
        return

    avg_hold = np.mean([t['hold_time'] for t in trades])
    avg_pnl = np.mean([t['pnl'] for t in trades]) * 100
    print(f"   Avg hold:     {avg_hold:.1f} bars")
    print(f"   Avg PnL:      {avg_pnl:+.3f}%")


def save_model(model: SAMBA, name: str, metrics: dict):
    """
    Save the model as JSON (weights as nested lists) and as a PyTorch
    checkpoint under /workspace.

    model: network whose ``state_dict`` is exported
    name: suffix used in the file names (model_sol_<name>.json / .pt)
    metrics: dict with the keys 'return', 'trades', 'win_rate' and 'sharpe'
    """
    state_dict = model.state_dict()

    # Convert to serializable format
    weights = {key: tensor.cpu().numpy().tolist() for key, tensor in state_dict.items()}

    output = {
        'type': 'SAMBA',
        'architecture': {
            'd_model': CONFIG['D_MODEL'],
            'd_state': CONFIG['D_STATE'],
            'n_layers': CONFIG['N_LAYERS'],
            'n_nodes': CONFIG['N_NODES'],
            'gnn_hidden': CONFIG['GNN_HIDDEN']
        },
        'weights': weights,
        'metrics': {
            'totalReturn': metrics['return'],
            'trades': metrics['trades'],
            'winRate': metrics['win_rate'],
            'sharpe': metrics['sharpe'],
            'modelType': name,
            'trainedAt': datetime.now().isoformat(),
            'trainedOn': 'RunPod RTX 4090',
            'asset': 'SOL/EUR'
        },
        'config': CONFIG
    }

    # The data loader accepts several locations, so /workspace may not
    # exist on the machine running the training; create it before writing.
    os.makedirs('/workspace', exist_ok=True)

    path = f'/workspace/model_sol_{name}.json'
    with open(path, 'w') as f:
        json.dump(output, f, indent=2)

    # Also save PyTorch checkpoint
    torch.save({
        'model_state_dict': state_dict,
        'config': CONFIG,
        'metrics': metrics
    }, f'/workspace/model_sol_{name}.pt')

    print(f"   💾 Saved: {path}")


# ============================================================================
# MAIN
# ============================================================================

# Script entry point: print the banner, then run the full training pipeline.
if __name__ == '__main__':
    print("""
    ╔═══════════════════════════════════════════════════════════════╗
    ║                                                               ║
    ║   🧠 SOL MAMBA TRADING - ULTRA ADVANCED                      ║
    ║                                                               ║
    ║   Architecture: SAMBA (Graph-Mamba + PPO)                    ║
    ║   - Mamba State Space Models (beats Transformers)            ║
    ║   - Graph Neural Networks (multi-timeframe correlation)      ║
    ║   - PPO Reinforcement Learning (optimal decisions)           ║
    ║   - Fee-aware reward shaping (0.4% round-trip)              ║
    ║                                                               ║
    ╚═══════════════════════════════════════════════════════════════╝
    """)

    train()
