#!/usr/bin/env python3
"""
REGIME-BASED TRADING MODELS V4
==============================
Trains 5 specialized models for the different market regimes:
- bullish: uptrend with flow confirmation
- bearish: downtrend with flow confirmation
- ranging: sideways/calm market
- volatile: high volatility
- scalper: micro movements

Each model uses reward shaping specific to its regime.
"""

import os
import json
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Dict, Tuple
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Device: {device}")

# ============================================================================
# CONFIG
# ============================================================================

CONFIG = {
    # Trading - REDUCED FEE for training (encourages trading)
    'FEE_RATE': 0.002,           # 0.2% for training (0.4% in prod)
    'INITIAL_CAPITAL': 10000,
    'MIN_HOLD_BARS': 2,

    # Model
    'D_MODEL': 128,
    'N_LAYERS': 4,
    'SEQ_LEN': 60,
    'N_FEATURES': 20,

    # Training
    'EPISODES': 150,
    'BATCH_SIZE': 128,
    'LR': 1e-3,
    'GAMMA': 0.99,
}

# ============================================================================
# REGIME DETECTION
# ============================================================================

def detect_regime(prices: np.ndarray, idx: int, window: int = 60) -> str:
    """Rileva il regime di mercato corrente"""
    if idx < window:
        return 'ranging'

    p = prices[idx - window:idx]
    returns = np.diff(p) / p[:-1]

    # Metrics
    trend = (p[-1] / p[0] - 1) * 100
    volatility = np.std(returns) * 100
    momentum = np.mean(returns[-10:]) * 100

    # Classification
    if volatility > 0.8:  # High volatility
        return 'volatile'
    elif abs(trend) < 0.5 and volatility < 0.3:  # Tight range
        return 'scalper'
        return 'scalper'
    elif trend > 1.0 and momentum > 0:  # Trend up
        return 'bullish'
    elif trend < -1.0 and momentum < 0:  # Trend down
        return 'bearish'
    else:
        return 'ranging'
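
# Example: a 60-bar window with a +1.8% total move, a 0.5% per-bar return std
# and positive 10-bar momentum falls through to the 'bullish' branch above.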


def filter_data_by_regime(prices: np.ndarray, features: np.ndarray, regime: str) -> Tuple[np.ndarray, np.ndarray, list]:
    """Filtra i dati per un regime specifico"""
    indices = []
    window = 60

    for i in range(window, len(prices)):
        detected = detect_regime(prices, i, window)
        if detected == regime:
            indices.append(i)

    if len(indices) < 1000:
        # Too few samples for this regime: fall back to all indices
        indices = list(range(window, len(prices)))

    return prices, features, indices


# ============================================================================
# FEATURES
# ============================================================================

def compute_features(prices: np.ndarray) -> np.ndarray:
    """20 features ottimizzate"""
    n = len(prices)
    features = np.zeros((n, CONFIG['N_FEATURES']), dtype=np.float32)
    df = pd.DataFrame({'p': prices})

    # Returns (5)
    for i, period in enumerate([1, 5, 15, 30, 60]):
        features[:, i] = df['p'].pct_change(period).fillna(0).values * 100

    # Volatility (3)
    for i, period in enumerate([10, 30, 60]):
        features[:, 5+i] = df['p'].pct_change().rolling(period).std().fillna(0).values * 100

    # MA distances (3)
    for i, period in enumerate([10, 30, 60]):
        ma = df['p'].rolling(period).mean()
        features[:, 8+i] = ((df['p'] - ma) / (ma + 1e-8) * 100).fillna(0).values

    # RSI (1) - centered on 50 and rescaled to roughly [-5, 5] so the global
    # clip to [-10, 10] below does not saturate it into a near-constant feature
    delta = df['p'].diff()
    gain = delta.where(delta > 0, 0).rolling(14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
    rs = gain / (loss + 1e-8)
    features[:, 11] = ((100 - 100 / (1 + rs)).fillna(50).values - 50) / 10

    # Bollinger (2)
    ma20 = df['p'].rolling(20).mean()
    std20 = df['p'].rolling(20).std()
    features[:, 12] = ((df['p'] - ma20) / (std20 + 1e-8)).fillna(0).values
    features[:, 13] = (std20 / (ma20 + 1e-8) * 100).fillna(0).values

    # Momentum (2)
    features[:, 14] = df['p'].diff(5).fillna(0).values / (df['p'].values + 1e-8) * 100
    features[:, 15] = df['p'].diff(20).fillna(0).values / (df['p'].values + 1e-8) * 100

    # Price position (2)
    roll_max = df['p'].rolling(60).max()
    roll_min = df['p'].rolling(60).min()
    features[:, 16] = ((df['p'] - roll_min) / (roll_max - roll_min + 1e-8)).fillna(0.5).values
    features[:, 17] = ((roll_max - roll_min) / (df['p'] + 1e-8) * 100).fillna(0).values

    # Trend (2)
    ema12 = df['p'].ewm(span=12).mean()
    ema26 = df['p'].ewm(span=26).mean()
    features[:, 18] = ((ema12 - ema26) / (ema26 + 1e-8) * 100).fillna(0).values
    features[:, 19] = ((df['p'] - ema12) / (ema12 + 1e-8) * 100).fillna(0).values

    return np.clip(features, -10, 10)


# ============================================================================
# MODEL
# ============================================================================

class TradingModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        d = config['D_MODEL']

        self.embed = nn.Sequential(
            nn.Linear(config['N_FEATURES'], d),
            nn.LayerNorm(d),
            nn.GELU(),
            nn.Dropout(0.1)
        )

        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d, d * 2),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(d * 2, d),
                nn.LayerNorm(d)
            ) for _ in range(config['N_LAYERS'])
        ])

        self.actor = nn.Sequential(
            nn.Linear(d * config['SEQ_LEN'], d),
            nn.GELU(),
            nn.Linear(d, 3)  # 3 actions: SHORT, FLAT, LONG (env maps action - 1 -> -1/0/+1)
        )

        self.critic = nn.Sequential(
            nn.Linear(d * config['SEQ_LEN'], d),
            nn.GELU(),
            nn.Linear(d, 1)
        )

    def forward(self, x):
        # x: (batch, seq, features)
        x = self.embed(x)
        for layer in self.layers:
            x = x + layer(x)
        x = x.flatten(1)
        return self.actor(x), self.critic(x)


# ============================================================================
# REGIME-SPECIFIC REWARD SHAPING
# ============================================================================

def get_reward_params(regime: str) -> dict:
    """Parametri reward specifici per regime"""
    params = {
        'bullish': {
            'long_bonus': 1.5,      # Multiplier on the long-entry bonus
            'short_penalty': 0.5,   # Multiplier on the short-entry bonus (<1 discourages shorts)
            'flat_penalty': 0.01,   # Per-bar penalty while flat (encourages trading)
            'win_bonus': 2.0,       # Multiplier on the reward of a winning trade
        },
        'bearish': {
            'long_bonus': 0.5,
            'short_penalty': 1.5,   # >1 here: shorts are favoured in a downtrend
            'flat_penalty': 0.01,
            'win_bonus': 2.0,
        },
        'ranging': {
            'long_bonus': 1.0,
            'short_penalty': 1.0,
            'flat_penalty': 0.005,  # Smaller flat penalty
            'win_bonus': 1.5,
        },
        'volatile': {
            'long_bonus': 1.2,
            'short_penalty': 1.2,
            'flat_penalty': 0.002,  # Minimal flat penalty (the model may wait)
            'win_bonus': 3.0,       # Large bonus for catching big moves
        },
        'scalper': {
            'long_bonus': 1.0,
            'short_penalty': 1.0,
            'flat_penalty': 0.02,   # High flat penalty (must trade often)
            'win_bonus': 1.2,       # Small, frequent profits
        }
    }
    return params.get(regime, params['ranging'])
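
# Worked example: with the 'bullish' parameters above, a long closed at a net
# pnl of +0.8% earns 0.8 * win_bonus = 1.6 reward units, the long entry adds
# 0.01 * long_bonus = 0.015, and every bar spent flat costs flat_penalty = 0.01
# (see RegimeTradingEnv.step below).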


# ============================================================================
# TRAINING ENVIRONMENT
# ============================================================================

class RegimeTradingEnv:
    def __init__(self, prices, features, indices, regime, config):
        self.prices = prices
        self.features = features
        self.indices = indices
        self.regime = regime
        self.config = config
        self.seq_len = config['SEQ_LEN']
        self.fee = config['FEE_RATE']
        self.reward_params = get_reward_params(regime)
        self.reset()

    def reset(self):
        # Random start within this regime's indices
        valid_start = [i for i in self.indices if i >= self.seq_len and i < len(self.prices) - 500]
        if not valid_start:
            valid_start = list(range(self.seq_len, len(self.prices) - 500))
        self.start_idx = np.random.choice(valid_start)
        self.step_idx = self.start_idx
        self.end_idx = min(self.start_idx + 500, len(self.prices) - 1)

        self.position = 0
        self.entry_price = 0
        self.entry_step = 0
        self.total_pnl = 0
        self.trades = []
        return self._get_state()

    def _get_state(self):
        state = self.features[self.step_idx - self.seq_len:self.step_idx]
        return torch.tensor(state, dtype=torch.float32)

    def step(self, action):
        price = self.prices[self.step_idx]
        reward = 0
        rp = self.reward_params

        target_pos = action - 1  # -1, 0, 1
        hold_time = self.step_idx - self.entry_step

        if (hold_time >= self.config['MIN_HOLD_BARS'] or self.position == 0) and target_pos != self.position:
            # Close
            if self.position != 0:
                if self.position == 1:
                    pnl = (price / self.entry_price - 1) - self.fee
                else:
                    pnl = (self.entry_price / price - 1) - self.fee

                # Reward shaping per regime
                base_reward = pnl * 100
                if pnl > 0:
                    reward = base_reward * rp['win_bonus']
                else:
                    reward = base_reward

                self.total_pnl += pnl
                self.trades.append({'pnl': pnl, 'side': self.position})

            # Open new
            if target_pos != 0:
                self.entry_price = price
                self.entry_step = self.step_idx

                # Direction bonus/penalty based on the regime
                if target_pos == 1:  # Long
                    reward += 0.01 * rp['long_bonus']
                else:  # Short
                    reward += 0.01 * rp['short_penalty']

            self.position = target_pos

        # Flat penalty
        if self.position == 0:
            reward -= rp['flat_penalty']

        self.step_idx += 1
        done = self.step_idx >= self.end_idx

        # Force close any open position at the end of the episode
        if done and self.position != 0:
            price = self.prices[self.step_idx]
            if self.position == 1:
                pnl = (price / self.entry_price - 1) - self.fee
            else:
                pnl = (self.entry_price / price - 1) - self.fee
            reward += pnl * 100  # accumulate, don't overwrite this step's reward
            self.total_pnl += pnl
            self.trades.append({'pnl': pnl, 'side': self.position})
            self.position = 0

        return self._get_state(), reward, done


# ============================================================================
# TRAINING
# ============================================================================

def train_regime_model(prices, features, regime, config):
    """Addestra un modello per un regime specifico"""
    print(f"\n{'='*60}")
    print(f"🎯 Training {regime.upper()} model")
    print(f"{'='*60}")

    # Filter data for this regime
    _, _, indices = filter_data_by_regime(prices, features, regime)
    print(f"   Data points for {regime}: {len(indices)}")

    # Model
    model = TradingModel(config).to(device)
    optimizer = AdamW(model.parameters(), lr=config['LR'])

    # Best-checkpoint tracking
    best_return = -999

    for ep in range(config['EPISODES']):
        env = RegimeTradingEnv(prices, features, indices, regime, config)
        state = env.reset()

        states, actions, rewards, values, log_probs = [], [], [], [], []

        while True:
            model.eval()
            with torch.no_grad():
                state_t = state.unsqueeze(0).to(device)
                logits, value = model(state_t)
                probs = F.softmax(logits, dim=-1)
                dist = torch.distributions.Categorical(probs)
                action = dist.sample()
                log_prob = dist.log_prob(action)

            next_state, reward, done = env.step(action.item())

            states.append(state)
            actions.append(action.item())
            rewards.append(reward)
            values.append(value.item())
            log_probs.append(log_prob.item())

            state = next_state
            if done:
                break

        # PPO Update
        model.train()
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + config['GAMMA'] * G
            returns.insert(0, G)

        states_t = torch.stack(states).to(device)
        actions_t = torch.tensor(actions).to(device)
        returns_t = torch.tensor(returns, dtype=torch.float32).to(device)
        old_log_probs_t = torch.tensor(log_probs).to(device)

        advantages = returns_t - torch.tensor(values).to(device)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(4):
            logits, values_pred = model(states_t)
            probs = F.softmax(logits, dim=-1)
            dist = torch.distributions.Categorical(probs)
            new_log_probs = dist.log_prob(actions_t)
            entropy = dist.entropy().mean()

            ratio = torch.exp(new_log_probs - old_log_probs_t)
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 0.8, 1.2) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()

            critic_loss = F.mse_loss(values_pred.squeeze(), returns_t)
            loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

        # Metrics
        ret = env.total_pnl * 100
        n_trades = len(env.trades)
        win_rate = sum(1 for t in env.trades if t['pnl'] > 0) / max(n_trades, 1) * 100

        if env.trades:
            pnls = [t['pnl'] for t in env.trades]
            sharpe = np.mean(pnls) / (np.std(pnls) + 1e-8) * np.sqrt(252)
        else:
            sharpe = 0

        if ep % 20 == 0:
            print(f"   Ep {ep:3d} | Ret: {ret:+7.2f}% | Trades: {n_trades:4d} | Win: {win_rate:5.1f}%")

        if ret > best_return and n_trades >= 5:
            best_return = ret
            save_model(model, f'{regime}', {'return': ret, 'trades': n_trades, 'win_rate': win_rate, 'sharpe': sharpe, 'regime': regime})

    print(f"   ✅ {regime.upper()} complete - Best return: {best_return:+.2f}%")
    return model


def save_model(model, name, metrics):
    state = model.state_dict()
    weights = {k: v.cpu().numpy().tolist() for k, v in state.items()}

    output = {
        'type': 'RegimeTrader',
        'regime': metrics.get('regime', name),
        'asset': 'BTC/EUR',
        'weights': weights,
        'metrics': metrics,
        'config': CONFIG,
        'trainedAt': datetime.now().isoformat(),
        'trainedOn': 'RunPod RTX 4090'
    }

    path = f'/workspace/model_{name}.json'
    with open(path, 'w') as f:
        json.dump(output, f)
    torch.save(model.state_dict(), f'/workspace/model_{name}.pt')
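

# ----------------------------------------------------------------------------
# ILLUSTRATIVE HELPERS (not used by the training loop)
# ----------------------------------------------------------------------------
# Minimal sketches of how the exported artifacts could be consumed downstream.
# They assume the /workspace/model_<name>.json files written by save_model()
# above; the production orchestrator's own loader and paper-trading harness
# live elsewhere and are not part of this script.

def load_model(name, config=CONFIG):
    """Rebuild a TradingModel from the JSON written by save_model()."""
    with open(f'/workspace/model_{name}.json') as f:
        payload = json.load(f)
    model = TradingModel(config).to(device)
    state = {k: torch.tensor(np.array(v, dtype=np.float32))
             for k, v in payload['weights'].items()}
    model.load_state_dict(state)
    model.eval()
    return model, payload['metrics']


def select_action(models: Dict[str, nn.Module], prices: np.ndarray,
                  features: np.ndarray, idx: int) -> int:
    """Route the current bar to the specialist for its detected regime and
    return that model's greedy action (0=SHORT, 1=FLAT, 2=LONG, matching
    the env's `target_pos = action - 1`)."""
    regime = detect_regime(prices, idx)
    window = features[idx - CONFIG['SEQ_LEN']:idx]
    x = torch.tensor(window, dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        logits, _ = models[regime](x)
    return int(logits.argmax(dim=-1).item())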


def main():
    print("""
    ╔══════════════════════════════════════════════════════════════╗
    ║                                                              ║
    ║   🎯 REGIME-BASED TRADING MODELS V4                         ║
    ║                                                              ║
    ║   Training 5 specialized models:                            ║
    ║   - bullish  (trend up + flow)                              ║
    ║   - bearish  (trend down + flow)                            ║
    ║   - ranging  (lateral/calm)                                 ║
    ║   - volatile (high volatility)                              ║
    ║   - scalper  (micro movements)                              ║
    ║                                                              ║
    ╚══════════════════════════════════════════════════════════════╝
    """)

    # Load data
    paths = ['/workspace/prices.csv', './prices.csv']
    for p in paths:
        if os.path.exists(p):
            df = pd.read_csv(p)
            prices = df['price'].values.astype(np.float32)
            print(f"📂 Loaded {len(prices)} prices")
            break
    else:
        raise FileNotFoundError(f"No price file found in any of: {paths}")

    # Features
    print("⚡ Computing features...")
    features = compute_features(prices)

    # Count regime distribution
    print("\n📊 Regime distribution:")
    regime_counts = {'bullish': 0, 'bearish': 0, 'ranging': 0, 'volatile': 0, 'scalper': 0}
    for i in range(60, len(prices)):
        r = detect_regime(prices, i)
        regime_counts[r] += 1

    for r, c in regime_counts.items():
        pct = c / sum(regime_counts.values()) * 100
        print(f"   {r:10s}: {c:6d} ({pct:5.1f}%)")

    # Train each regime model
    regimes = ['bullish', 'bearish', 'ranging', 'volatile', 'scalper']
    models = {}

    for regime in regimes:
        models[regime] = train_regime_model(prices, features, regime, CONFIG)

    print("\n" + "="*60)
    print("✅ ALL MODELS TRAINED!")
    print("="*60)
    print("\nSaved models:")
    for regime in regimes:
        print(f"   - model_{regime}.json")
        print(f"   - model_{regime}.pt")

    print("\n📝 Next steps:")
    print("   1. Download models from pod")
    print("   2. Load into orchestrator database")
    print("   3. Test in paper mode")


if __name__ == '__main__':
    main()
