#!/usr/bin/env python3
"""
V6 FAST - REGIME TRADING MODELS
================================
Optimized for SPEED:
1. Windowed episodes (not full dataset)
2. Parallel environments on GPU
3. Vectorized operations
4. Mini-batch training

Target: ~1 second per episode instead of minutes.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Tuple, Dict, List
import warnings
warnings.filterwarnings('ignore')

# ============================================================
# CONFIG
# ============================================================
class Config:
    # Data
    DATA_PATH = '/workspace/prices.csv'

    # Training
    EPISODES = 500
    BATCH_SIZE = 2048
    LR = 3e-4
    GAMMA = 0.99
    GAE_LAMBDA = 0.95
    CLIP_EPS = 0.2
    ENTROPY_COEF = 0.02

    # Environment
    NUM_ENVS = 64  # Parallel environments
    WINDOW_SIZE = 1000  # Steps per episode

    # Model
    HIDDEN_DIM = 256
    NUM_HEADS = 4
    NUM_LAYERS = 2
    DROPOUT = 0.1

    # Trading
    FEE = 0.0  # Zero fee during training
    IDLE_PENALTY = 0.0005  # Per-step penalty while holding a near-zero position
    TRADE_BONUS = 0.0002  # Per-step bonus proportional to position change

    # Features
    LOOKBACK = 20

    # Device
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ============================================================
# FEATURES (Vectorized GPU)
# ============================================================
@torch.jit.script
def compute_returns(prices: torch.Tensor, period: int) -> torch.Tensor:
    """Compute returns with given period."""
    n = prices.shape[0]
    ret = torch.zeros_like(prices)
    if period < n:
        ret[period:] = (prices[period:] - prices[:-period]) / (prices[:-period] + 1e-8)
    return ret

def compute_features(prices: np.ndarray) -> torch.Tensor:
    """Compute all features on GPU."""
    prices_t = torch.tensor(prices, dtype=torch.float32, device=Config.DEVICE)
    n = len(prices)

    features = []

    # Returns at multiple scales
    for period in [1, 5, 10, 20, 60]:
        features.append(compute_returns(prices_t, period))

    # Rolling volatility (std of 1-period returns), vectorized with unfold
    # instead of a per-step Python loop
    ret1 = compute_returns(prices_t, 1)
    for w in [10, 30]:
        vol = torch.zeros(n, device=Config.DEVICE)
        # std of ret1[i-w:i] for every i >= w, computed in one batched call
        vol[w:] = ret1.unfold(0, w, 1)[:-1].std(dim=1)
        features.append(vol)

    # Moving-average ratios: price relative to its trailing moving average
    cumsum = torch.cumsum(prices_t, dim=0)
    for period in [10, 30]:
        ma = torch.zeros(n, device=Config.DEVICE)
        # Trailing mean of the last `period` prices: (cumsum[i] - cumsum[i - period]) / period
        ma[period:] = (cumsum[period:] - cumsum[:-period]) / period
        ma[:period] = prices_t[:period].mean()
        ma_ratio = (prices_t - ma) / (ma + 1e-8)
        features.append(ma_ratio)

    # Stack and normalize
    feat = torch.stack(features, dim=1)

    # Z-score normalization
    mean = feat.mean(dim=0, keepdim=True)
    std = feat.std(dim=0, keepdim=True) + 1e-8
    feat = (feat - mean) / std
    feat = torch.clamp(feat, -3, 3)

    return feat

# ============================================================
# REGIME DETECTION (Vectorized)
# ============================================================
def detect_regimes(prices: np.ndarray, features: torch.Tensor) -> torch.Tensor:
    """Detect regimes for all timesteps."""
    n = len(prices)
    feat = features.cpu().numpy()

    # regime indices: 0=bullish, 1=bearish, 2=ranging, 3=volatile, 4=scalper
    regimes = np.full(n, 4, dtype=np.int64)  # Default: scalper

    ret20 = feat[:, 3]  # 20-period return
    ret60 = feat[:, 4]  # 60-period return
    vol30 = feat[:, 6]  # 30-period volatility

    # Volatile
    volatile_mask = vol30 > 1.5
    regimes[volatile_mask] = 3

    # Bullish
    bullish_mask = (ret20 > 0.3) & (ret60 > 0.2) & ~volatile_mask
    regimes[bullish_mask] = 0

    # Bearish
    bearish_mask = (ret20 < -0.3) & (ret60 < -0.2) & ~volatile_mask
    regimes[bearish_mask] = 1

    # Ranging
    ranging_mask = (np.abs(ret60) < 0.15) & (vol30 < 0.5) & ~volatile_mask
    regimes[ranging_mask] = 2

    return torch.tensor(regimes, dtype=torch.long, device=Config.DEVICE)

# ============================================================
# TRANSFORMER MODEL
# ============================================================
class TradingTransformer(nn.Module):
    def __init__(self, input_dim: int):
        super().__init__()

        self.embed = nn.Linear(input_dim + 1, Config.HIDDEN_DIM)  # +1 for position

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=Config.HIDDEN_DIM,
            nhead=Config.NUM_HEADS,
            dim_feedforward=Config.HIDDEN_DIM * 4,
            dropout=Config.DROPOUT,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=Config.NUM_LAYERS)

        self.actor_mean = nn.Linear(Config.HIDDEN_DIM, 1)
        self.actor_log_std = nn.Parameter(torch.zeros(1) - 0.5)
        self.critic = nn.Linear(Config.HIDDEN_DIM, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # x: (batch, seq, features)
        h = self.embed(x)
        h = self.transformer(h)
        h = h[:, -1]  # Take last timestep

        mean = torch.tanh(self.actor_mean(h))
        std = torch.exp(self.actor_log_std).expand(mean.shape[0], 1)
        value = self.critic(h)

        return mean, std, value
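
    # Shape sketch with the defaults above (illustration only): LOOKBACK = 20 and
    # 9 market features + 1 position channel give x: (num_envs, 21, 10), and
    # mean, std and value each come back as (num_envs, 1).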

# ============================================================
# VECTORIZED ENVIRONMENT
# ============================================================
class VectorizedEnv:
    """Parallel trading environments on GPU."""

    def __init__(self, prices: torch.Tensor, features: torch.Tensor,
                 regimes: torch.Tensor, target_regime: int, num_envs: int):
        self.prices = prices
        self.features = features
        self.regimes = regimes
        self.target_regime = target_regime
        self.num_envs = num_envs
        self.n = len(prices)

        # State
        self.positions = torch.zeros(num_envs, device=Config.DEVICE)
        self.start_idx = torch.zeros(num_envs, dtype=torch.long, device=Config.DEVICE)
        self.current_idx = torch.zeros(num_envs, dtype=torch.long, device=Config.DEVICE)
        self.pnl = torch.zeros(num_envs, device=Config.DEVICE)
        self.trade_count = torch.zeros(num_envs, dtype=torch.long, device=Config.DEVICE)

        self.reset()

    def reset(self) -> torch.Tensor:
        """Reset all environments to random starting points."""
        # Random start positions (leave room for lookback and window)
        max_start = self.n - Config.WINDOW_SIZE - Config.LOOKBACK
        self.start_idx = torch.randint(Config.LOOKBACK, max_start, (self.num_envs,), device=Config.DEVICE)
        self.current_idx = self.start_idx.clone()
        self.positions = torch.zeros(self.num_envs, device=Config.DEVICE)
        self.pnl = torch.zeros(self.num_envs, device=Config.DEVICE)
        self.trade_count = torch.zeros(self.num_envs, dtype=torch.long, device=Config.DEVICE)

        return self._get_obs()

    def _get_obs(self) -> torch.Tensor:
        """Get observations for all environments."""
        # Get feature sequences for all envs
        obs_list = []
        for i in range(self.num_envs):
            idx = self.current_idx[i].item()
            start = max(0, idx - Config.LOOKBACK)
            feat_seq = self.features[start:idx+1]

            # Pad at the front so every sequence has LOOKBACK + 1 rows
            # (matches the unpadded case and keeps torch.stack below valid)
            if feat_seq.shape[0] < Config.LOOKBACK + 1:
                pad = torch.zeros(Config.LOOKBACK + 1 - feat_seq.shape[0], feat_seq.shape[1], device=Config.DEVICE)
                feat_seq = torch.cat([pad, feat_seq], dim=0)

            # Add position info
            pos = self.positions[i].unsqueeze(0).expand(feat_seq.shape[0], 1)
            feat_seq = torch.cat([feat_seq, pos], dim=1)

            obs_list.append(feat_seq)

        return torch.stack(obs_list)  # (num_envs, LOOKBACK + 1, features + 1)
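
    # Optional fully-vectorized variant (a sketch, not called by the training loop):
    # reset() guarantees current_idx >= LOOKBACK, so every window has exactly
    # LOOKBACK + 1 valid rows and the per-environment Python loop above can be
    # replaced by a single batched gather.
    def _get_obs_vectorized(self) -> torch.Tensor:
        offsets = torch.arange(-Config.LOOKBACK, 1, device=Config.DEVICE)
        idx = (self.current_idx.unsqueeze(1) + offsets).clamp(min=0)  # (num_envs, LOOKBACK + 1)
        feat_seq = self.features[idx]  # (num_envs, LOOKBACK + 1, num_features)
        pos = self.positions.view(-1, 1, 1).expand(-1, feat_seq.shape[1], 1)
        return torch.cat([feat_seq, pos], dim=2)  # (num_envs, LOOKBACK + 1, num_features + 1)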

    def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Execute actions in all environments."""
        actions = actions.squeeze(-1)  # (num_envs,)

        old_positions = self.positions.clone()
        new_positions = actions.clamp(-1, 1)

        # Price changes
        current_prices = self.prices[self.current_idx]
        next_idx = (self.current_idx + 1).clamp(max=self.n-1)
        next_prices = self.prices[next_idx]

        price_returns = (next_prices - current_prices) / (current_prices + 1e-8)

        # PnL: use the average of old and new position over the step
        avg_positions = (old_positions + new_positions) / 2
        pnl = avg_positions * price_returns

        # Fee
        position_change = (new_positions - old_positions).abs()
        fee_cost = position_change * Config.FEE

        # Trade count
        self.trade_count += (position_change > 0.1).long()

        # Rewards
        rewards = pnl - fee_cost

        # Regime weighting
        current_regimes = self.regimes[self.current_idx]
        regime_match = (current_regimes == self.target_regime).float()
        rewards = rewards * (1.0 + regime_match)  # Bonus for matching regime

        # Idle penalty
        idle_mask = new_positions.abs() < 0.1
        rewards -= idle_mask.float() * Config.IDLE_PENALTY

        # Trade bonus
        rewards += position_change * Config.TRADE_BONUS

        # Direction bonus: reward positions aligned with the target regime
        if self.target_regime == 0:  # Bullish: favour long positions
            rewards += (new_positions > 0).float() * 0.0002 * regime_match
        elif self.target_regime == 1:  # Bearish: favour short positions
            rewards += (new_positions < 0).float() * 0.0002 * regime_match

        # Update state
        self.positions = new_positions
        self.pnl += pnl - fee_cost
        self.current_idx = next_idx

        # Check done
        steps_taken = self.current_idx - self.start_idx
        dones = (steps_taken >= Config.WINDOW_SIZE) | (self.current_idx >= self.n - 1)

        return self._get_obs(), rewards, dones

# ============================================================
# PPO TRAINER
# ============================================================
class PPOTrainer:
    def __init__(self, model: TradingTransformer):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LR)

    def train_batch(self, states: torch.Tensor, actions: torch.Tensor,
                   old_log_probs: torch.Tensor, advantages: torch.Tensor,
                   returns: torch.Tensor) -> Dict[str, float]:
        """PPO update on a batch."""

        mean, std, values = self.model(states)
        dist = torch.distributions.Normal(mean, std)
        log_probs = dist.log_prob(actions).squeeze(-1)
        entropy = dist.entropy().mean()

        # PPO loss
        ratio = torch.exp(log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - Config.CLIP_EPS, 1 + Config.CLIP_EPS) * advantages

        actor_loss = -torch.min(surr1, surr2).mean()
        critic_loss = F.mse_loss(values.squeeze(-1), returns)

        loss = actor_loss + 0.5 * critic_loss - Config.ENTROPY_COEF * entropy

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
        self.optimizer.step()

        return {
            'actor_loss': actor_loss.item(),
            'critic_loss': critic_loss.item(),
            'entropy': entropy.item()
        }
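
# Optional GAE sketch: the training loop below uses raw one-step rewards as
# returns, so Config.GAMMA and Config.GAE_LAMBDA are otherwise unused. This
# helper shows how they could be applied to a single trajectory of per-step
# rewards and value estimates; it is not wired into train_regime_model and its
# interface is an assumption, not part of the original design.
def compute_gae(rewards: torch.Tensor, values: torch.Tensor,
                last_value: float = 0.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generalized Advantage Estimation over one (T,) trajectory -> (advantages, returns)."""
    advantages = torch.zeros_like(rewards)
    gae = torch.zeros((), device=rewards.device)
    next_value = torch.as_tensor(last_value, device=rewards.device)
    for t in range(rewards.shape[0] - 1, -1, -1):
        delta = rewards[t] + Config.GAMMA * next_value - values[t]
        gae = delta + Config.GAMMA * Config.GAE_LAMBDA * gae
        advantages[t] = gae
        next_value = values[t]
    returns = advantages + values
    return advantages, returns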

# ============================================================
# TRAINING LOOP
# ============================================================
def train_regime_model(regime: int, regime_name: str, prices: torch.Tensor,
                       features: torch.Tensor, regimes: torch.Tensor) -> TradingTransformer:
    """Train model for a specific regime."""

    print(f"\n{'='*60}")
    print(f"🎯 Training {regime_name.upper()} model")
    print(f"{'='*60}")

    regime_count = (regimes == regime).sum().item()
    print(f"   Regime samples: {regime_count}")

    # Model
    input_dim = features.shape[1]
    model = TradingTransformer(input_dim).to(Config.DEVICE)
    trainer = PPOTrainer(model)

    # Environment
    env = VectorizedEnv(prices, features, regimes, regime, Config.NUM_ENVS)

    best_return = -float('inf')
    best_state = None

    for ep in range(Config.EPISODES):
        # Collect trajectories
        states_list = []
        actions_list = []
        rewards_list = []
        values_list = []
        log_probs_list = []

        obs = env.reset()
        done = torch.zeros(Config.NUM_ENVS, dtype=torch.bool, device=Config.DEVICE)

        while not done.all():
            with torch.no_grad():
                mean, std, value = model(obs)
                dist = torch.distributions.Normal(mean, std)
                action = dist.sample()
                log_prob = dist.log_prob(action).squeeze(-1)

            states_list.append(obs[~done])
            actions_list.append(action[~done])
            log_probs_list.append(log_prob[~done])
            values_list.append(value.squeeze(-1)[~done])

            obs, rewards, new_done = env.step(action)
            rewards_list.append(rewards[~done])

            done = done | new_done

        if len(states_list) == 0:
            continue

        # Stack trajectories
        states = torch.cat(states_list)
        actions = torch.cat(actions_list)
        old_log_probs = torch.cat(log_probs_list)
        values = torch.cat(values_list)
        rewards = torch.cat(rewards_list)

        # Compute advantages: one-step rewards stand in for returns here
        # (Config.GAMMA / GAE_LAMBDA are unused; see the compute_gae sketch above)
        returns = rewards
        advantages = returns - values.detach()
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # PPO update
        n_samples = len(states)
        indices = torch.randperm(n_samples, device=Config.DEVICE)

        for start in range(0, n_samples, Config.BATCH_SIZE):
            end = min(start + Config.BATCH_SIZE, n_samples)
            batch_idx = indices[start:end]

            trainer.train_batch(
                states[batch_idx],
                actions[batch_idx],
                old_log_probs[batch_idx],
                advantages[batch_idx],
                returns[batch_idx]
            )

        # Stats
        mean_return = env.pnl.mean().item() * 100
        mean_trades = env.trade_count.float().mean().item()

        if mean_return > best_return:
            best_return = mean_return
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

        if ep % 25 == 0:
            print(f"   Ep {ep:3d} | Ret: {mean_return:+7.2f}% | Trades: {mean_trades:5.1f}")

    # Load best
    if best_state:
        model.load_state_dict({k: v.to(Config.DEVICE) for k, v in best_state.items()})

    print(f"   ✅ {regime_name.upper()} complete - Best: {best_return:+.2f}%")

    return model
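
# Minimal checkpoint-loading sketch (not called anywhere in this script): it
# rebuilds a regime model from a file written by main() below, assuming the
# module-level Config still matches the 'config' dict stored in the checkpoint.
def load_regime_model(path: str) -> TradingTransformer:
    ckpt = torch.load(path, map_location=Config.DEVICE)
    model = TradingTransformer(ckpt['input_dim']).to(Config.DEVICE)
    model.load_state_dict(ckpt['model_state'])
    model.eval()
    return model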

def main():
    print("🚀 Device:", Config.DEVICE)
    if Config.DEVICE.type == 'cuda':
        print(f"   GPU: {torch.cuda.get_device_name(0)}")

    print("""
    ╔══════════════════════════════════════════════════════════════╗
    ║                                                              ║
    ║   🚀 REGIME TRADING V6 - FAST TRAINING                      ║
    ║                                                              ║
    ║   Optimizations:                                             ║
    ║   - 64 parallel environments                                 ║
    ║   - Windowed episodes (1000 steps)                          ║
    ║   - Vectorized operations                                   ║
    ║   - ~1 second per episode                                   ║
    ║                                                              ║
    ╚══════════════════════════════════════════════════════════════╝
    """)

    # Load data
    df = pd.read_csv(Config.DATA_PATH)
    if 'price' in df.columns:
        prices = df['price'].values.astype(np.float32)
    elif 'close' in df.columns:
        prices = df['close'].values.astype(np.float32)
    else:
        prices = df.iloc[:, 1].values.astype(np.float32)

    print(f"📂 Loaded {len(prices):,} prices")

    prices_t = torch.tensor(prices, dtype=torch.float32, device=Config.DEVICE)

    # Features
    print("⚡ Computing features...")
    features = compute_features(prices)
    print(f"   Shape: {features.shape}")

    # Regimes
    print("🔍 Detecting regimes...")
    regimes = detect_regimes(prices, features)

    regime_names = ['bullish', 'bearish', 'ranging', 'volatile', 'scalper']
    print("\n📊 Regime distribution:")
    for i, name in enumerate(regime_names):
        count = (regimes == i).sum().item()
        pct = count / len(regimes) * 100
        print(f"   {name:10s}: {count:6d} ({pct:5.1f}%)")

    # Train models
    models = {}
    for i, name in enumerate(regime_names):
        model = train_regime_model(i, name, prices_t, features, regimes)
        models[name] = model

        # Save
        save_path = f'/workspace/model_{name}_v6.pt'
        torch.save({
            'model_state': model.state_dict(),
            'regime': name,
            'regime_id': i,
            'input_dim': features.shape[1],
            'config': {
                'hidden_dim': Config.HIDDEN_DIM,
                'num_heads': Config.NUM_HEADS,
                'num_layers': Config.NUM_LAYERS,
            }
        }, save_path)
        print(f"   💾 Saved: {save_path}")

    print("\n" + "="*60)
    print("✅ ALL MODELS TRAINED!")
    print("="*60)

    # Final evaluation with fee
    print("\n📈 Final evaluation (with 0.4% fee):")
    original_fee = Config.FEE
    original_idle, original_bonus = Config.IDLE_PENALTY, Config.TRADE_BONUS
    Config.FEE = 0.004
    Config.IDLE_PENALTY = 0
    Config.TRADE_BONUS = 0

    for i, name in enumerate(regime_names):
        model = models[name]
        env = VectorizedEnv(prices_t, features, regimes, i, 1)
        obs = env.reset()

        # Note: reset() picks a random start and step() reports done after WINDOW_SIZE
        # steps, so this evaluates a single window per regime, not the full series
        total_steps = 0
        while total_steps < len(prices) - Config.LOOKBACK - 1:
            with torch.no_grad():
                mean, _, _ = model(obs)
            obs, _, done = env.step(mean)
            total_steps += 1
            if done.all():
                break

        print(f"   {name:10s}: Return {env.pnl[0].item()*100:+7.2f}% | Trades: {env.trade_count[0].item()}")

    Config.FEE = original_fee
    Config.IDLE_PENALTY, Config.TRADE_BONUS = original_idle, original_bonus
    print("\n🎉 Done! Models saved to /workspace/model_*_v6.pt")

if __name__ == '__main__':
    main()
