#!/usr/bin/env python3
"""
V5 - REGIME TRADING MODELS
===========================
Fixes for the "not trading" problem:
1. ZERO fee during training (learn patterns, not to avoid fees)
2. FORCE trading with idle penalty
3. Continuous position sizing (-1 to +1)
4. All data, regime-weighted rewards
5. Aggressive exploration with entropy bonus

This version will ACTUALLY TRADE.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Tuple, Dict
import warnings
warnings.filterwarnings('ignore')

# ============================================================
# CONFIG
# ============================================================
class Config:
    # Data
    DATA_PATH = '/workspace/prices.csv'

    # Training
    EPISODES = 200
    BATCH_SIZE = 512
    LR = 1e-4
    GAMMA = 0.99
    GAE_LAMBDA = 0.95
    CLIP_EPS = 0.2
    ENTROPY_COEF = 0.05  # High entropy to force exploration

    # Model
    HIDDEN_DIM = 256
    NUM_LAYERS = 3
    DROPOUT = 0.1

    # Trading
    FEE = 0.0  # ZERO fee during training!
    IDLE_PENALTY = 0.001  # Penalty for not having position
    TRADE_BONUS = 0.0005  # Bonus for changing position

    # Features
    LOOKBACK = 32

    # Device
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ============================================================
# FEATURES (Vectorized, GPU)
# ============================================================
def compute_features_gpu(prices: np.ndarray) -> torch.Tensor:
    """Compute features on GPU for speed."""
    prices_t = torch.tensor(prices, dtype=torch.float32, device=Config.DEVICE)
    n = len(prices)

    features = []

    # Returns at multiple scales
    for period in [1, 5, 15, 30, 60]:
        if period < n:
            ret = torch.zeros(n, device=Config.DEVICE)
            ret[period:] = (prices_t[period:] - prices_t[:-period]) / (prices_t[:-period] + 1e-8)
            features.append(ret)

    # Volatility (rolling std of returns)
    ret1 = torch.zeros(n, device=Config.DEVICE)
    ret1[1:] = (prices_t[1:] - prices_t[:-1]) / (prices_t[:-1] + 1e-8)

    for window in [10, 30, 60]:
        vol = torch.zeros(n, device=Config.DEVICE)
        if n > window:
            # Rolling std of the `window` returns ending at the previous step
            vol[window:] = ret1.unfold(0, window, 1)[:-1].std(dim=1)
        features.append(vol)

    # Moving averages
    for period in [10, 30, 60]:
        ma = torch.zeros(n, device=Config.DEVICE)
        if n > period:
            # Rolling mean of the `period` prices ending at the previous step
            ma[period:] = prices_t.unfold(0, period, 1)[:-1].mean(dim=1)
        ma_ratio = (prices_t - ma) / (ma + 1e-8)
        features.append(ma_ratio)

    # RSI-like momentum
    for period in [14, 28]:
        gains = torch.zeros(n, device=Config.DEVICE)
        losses = torch.zeros(n, device=Config.DEVICE)
        if n > period:
            diffs = prices_t[1:] - prices_t[:-1]
            windows = diffs.unfold(0, period, 1)   # `period` price changes ending at each step
            gains[period:] = torch.clamp(windows, min=0).sum(dim=1)
            losses[period:] = torch.clamp(-windows, min=0).sum(dim=1)
        rsi = gains / (gains + losses + 1e-8)
        features.append(rsi)

    # Stack and normalize
    feat_tensor = torch.stack(features, dim=1)

    # Z-score normalization
    mean = feat_tensor.mean(dim=0, keepdim=True)
    std = feat_tensor.std(dim=0, keepdim=True) + 1e-8
    feat_tensor = (feat_tensor - mean) / std

    # Clip outliers
    feat_tensor = torch.clamp(feat_tensor, -5, 5)

    return feat_tensor
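
# Illustrative sanity check (a hedged sketch, not called anywhere in the
# pipeline): build a synthetic random-walk price series and confirm the feature
# tensor has the expected (n, 13) shape and is finite after clipping. The helper
# name `_sanity_check_features` is ours, not part of the original design.
def _sanity_check_features(n: int = 500) -> torch.Tensor:
    rng = np.random.default_rng(0)
    synthetic_prices = (100.0 + np.cumsum(rng.normal(0, 0.1, size=n))).astype(np.float32)
    feats = compute_features_gpu(synthetic_prices)
    assert feats.shape == (n, 13)          # 5 returns + 3 vols + 3 MA ratios + 2 RSIs
    assert torch.isfinite(feats).all()     # z-scored and clipped to [-5, 5]
    return feats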

# ============================================================
# REGIME DETECTION
# ============================================================
def detect_regimes(prices: np.ndarray, features: torch.Tensor) -> np.ndarray:
    """
    Detect market regime for each timestep.
    Returns array of regime names.
    """
    n = len(prices)
    regimes = np.array(['scalper'] * n)  # Default

    # Feature order (all columns are z-score normalized, so the thresholds
    # below are in standard deviations, not raw units):
    # [ret1, ret5, ret15, ret30, ret60, vol10, vol30, vol60, ma10, ma30, ma60, rsi14, rsi28]
    feat_np = features.cpu().numpy()

    for i in range(100, n):
        ret30 = feat_np[i, 3]  # 30-period return (z-scored)
        ret60 = feat_np[i, 4]  # 60-period return (z-scored)
        vol30 = feat_np[i, 6]  # 30-period volatility (z-scored)
        rsi = feat_np[i, 11]   # RSI-like momentum, 14 periods (z-scored)

        # Volatile: high volatility
        if vol30 > 1.5:
            regimes[i] = 'volatile'
        # Bullish: strong uptrend
        elif ret30 > 0.5 and ret60 > 0.3 and rsi > 0.6:
            regimes[i] = 'bullish'
        # Bearish: strong downtrend
        elif ret30 < -0.5 and ret60 < -0.3 and rsi < 0.4:
            regimes[i] = 'bearish'
        # Ranging: sideways
        elif abs(ret60) < 0.2 and vol30 < 0.5:
            regimes[i] = 'ranging'
        # else: scalper (default)

    return regimes
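
# Illustrative debugging helper (a sketch, not used by training): summarize how
# many timesteps fall into each detected regime, mirroring the distribution
# printout in main(). The name `_regime_distribution` is ours.
def _regime_distribution(regimes: np.ndarray) -> Dict[str, int]:
    labels, counts = np.unique(regimes, return_counts=True)
    return {str(label): int(count) for label, count in zip(labels, counts)}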

# ============================================================
# ACTOR-CRITIC MODEL (Continuous actions)
# ============================================================
class ActorCritic(nn.Module):
    def __init__(self, input_dim: int):
        super().__init__()

        # Shared backbone
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, Config.HIDDEN_DIM),
            nn.LayerNorm(Config.HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(Config.HIDDEN_DIM, Config.HIDDEN_DIM),
            nn.LayerNorm(Config.HIDDEN_DIM),
            nn.GELU(),
            nn.Dropout(Config.DROPOUT),
            nn.Linear(Config.HIDDEN_DIM, Config.HIDDEN_DIM),
            nn.LayerNorm(Config.HIDDEN_DIM),
            nn.GELU(),
        )

        # Actor: outputs mean and log_std for continuous action
        self.actor_mean = nn.Linear(Config.HIDDEN_DIM, 1)
        self.actor_log_std = nn.Parameter(torch.zeros(1))  # Learnable std

        # Critic
        self.critic = nn.Linear(Config.HIDDEN_DIM, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        h = self.backbone(x)

        mean = torch.tanh(self.actor_mean(h))  # Position in [-1, 1]
        std = torch.exp(self.actor_log_std).expand_as(mean)
        value = self.critic(h)

        return mean, std, value

    def get_action(self, x: torch.Tensor, deterministic: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        mean, std, value = self.forward(x)

        if deterministic:
            action = mean
        else:
            dist = torch.distributions.Normal(mean, std)
            action = dist.sample()
            action = torch.clamp(action, -1, 1)

        return action, value

    def evaluate(self, x: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mean, std, value = self.forward(x)

        dist = torch.distributions.Normal(mean, std)
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()

        return log_prob, value, entropy
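
# Minimal usage/shape sketch for ActorCritic (assumes 13 market features plus 1
# position feature, matching compute_features_gpu and TradingEnv). Defined but
# never called; the helper name `_actor_critic_demo` is ours.
def _actor_critic_demo() -> None:
    model = ActorCritic(input_dim=14).to(Config.DEVICE)
    obs = torch.randn(1, 14, device=Config.DEVICE)       # one fake observation
    action, value = model.get_action(obs)                 # stochastic action in [-1, 1]
    log_prob, _, entropy = model.evaluate(obs, action)    # terms used in the PPO loss
    print(action.item(), value.item(), log_prob.item(), entropy.item())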

# ============================================================
# TRADING ENVIRONMENT (Continuous positions)
# ============================================================
class TradingEnv:
    def __init__(self, prices: np.ndarray, features: torch.Tensor,
                 regimes: np.ndarray, target_regime: str):
        self.prices = prices
        self.features = features
        self.regimes = regimes
        self.target_regime = target_regime
        self.n = len(prices)

        # Regime weights: 1.0 on the target regime, 0.3 elsewhere
        self.regime_weights = np.where(regimes == target_regime, 1.0, 0.3)

        self.reset()

    def reset(self):
        self.idx = Config.LOOKBACK
        self.position = 0.0  # Continuous position [-1, 1]
        self.pnl = 0.0
        self.trade_count = 0
        self.wins = 0
        return self._get_obs()

    def _get_obs(self) -> torch.Tensor:
        # Current features
        feat = self.features[self.idx]

        # Add position info
        pos_tensor = torch.tensor([self.position], device=Config.DEVICE, dtype=torch.float32)

        return torch.cat([feat, pos_tensor])

    def step(self, action: float) -> Tuple[torch.Tensor, float, bool]:
        """
        action: continuous position target in [-1, 1]
        -1 = full short, 0 = flat, 1 = full long
        """
        old_position = self.position
        new_position = float(action)

        # Price change
        price_now = self.prices[self.idx]
        price_next = self.prices[min(self.idx + 1, self.n - 1)]
        price_return = (price_next - price_now) / price_now

        # PnL from position
        # Average position during step * return
        avg_position = (old_position + new_position) / 2
        pnl = avg_position * price_return

        # Fee on position change (ZERO during training!)
        position_change = abs(new_position - old_position)
        fee_cost = position_change * Config.FEE

        # Track trades: a position change above 0.1 counts as a trade, and a
        # "win" is a trade whose step PnL is positive (rough heuristic)
        if position_change > 0.1:
            self.trade_count += 1
            if pnl > 0:
                self.wins += 1

        # Base reward is PnL
        reward = pnl - fee_cost

        # Regime-weighted reward
        regime_weight = self.regime_weights[self.idx]
        reward *= regime_weight

        # IDLE PENALTY: penalize being flat
        if abs(new_position) < 0.1:
            reward -= Config.IDLE_PENALTY

        # TRADE BONUS: reward for changing position
        reward += position_change * Config.TRADE_BONUS

        # Direction bonus for target regime
        if self.regimes[self.idx] == self.target_regime:
            if self.target_regime == 'bullish' and new_position > 0:
                reward += 0.001 * new_position
            elif self.target_regime == 'bearish' and new_position < 0:
                reward += 0.001 * abs(new_position)  # reward for holding a short
            elif self.target_regime == 'scalper':
                # Reward quick position changes
                reward += position_change * 0.0005

        # Update state
        self.position = new_position
        self.pnl += pnl - fee_cost
        self.idx += 1

        done = self.idx >= self.n - 1

        return self._get_obs(), reward, done
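
# Illustrative random-policy rollout (a hedged sketch, not part of training):
# exercise TradingEnv end-to-end with uniform random position targets to check
# that the reward/PnL bookkeeping runs cleanly. Helper name `_random_rollout` is ours.
def _random_rollout(prices: np.ndarray, features: torch.Tensor,
                    regimes: np.ndarray, target_regime: str = 'scalper') -> float:
    env = TradingEnv(prices, features, regimes, target_regime)
    obs = env.reset()
    done = False
    while not done:
        action = float(np.random.uniform(-1, 1))   # random position target in [-1, 1]
        obs, reward, done = env.step(action)
    return env.pnl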

# ============================================================
# PPO TRAINER
# ============================================================
class PPOTrainer:
    def __init__(self, model: ActorCritic, regime: str):
        self.model = model
        self.regime = regime
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LR)

    def train_episode(self, env: TradingEnv) -> Dict:
        """Train for one episode using PPO."""
        obs = env.reset()

        states = []
        actions = []
        rewards = []
        values = []
        log_probs = []

        # Collect trajectory
        while True:
            states.append(obs)

            with torch.no_grad():
                action, value = self.model.get_action(obs.unsqueeze(0))

                action = action.squeeze()
                value = value.squeeze()

                # Old-policy log prob, kept under no_grad so PPO treats it as a
                # constant (gradients must not flow through the old log probs)
                mean, std, _ = self.model.forward(obs.unsqueeze(0))
                dist = torch.distributions.Normal(mean.squeeze(), std.squeeze())
                log_prob = dist.log_prob(action)

            obs, reward, done = env.step(action.item())

            actions.append(action)
            rewards.append(reward)
            values.append(value)
            log_probs.append(log_prob)

            if done:
                break

        # Convert to tensors
        states = torch.stack(states)
        actions = torch.stack(actions).unsqueeze(-1)
        values = torch.stack(values)
        log_probs = torch.stack(log_probs)
        rewards = torch.tensor(rewards, device=Config.DEVICE, dtype=torch.float32)

        # Compute advantages (GAE)
        advantages = torch.zeros_like(rewards)
        lastgaelam = 0
        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = 0
            else:
                next_value = values[t + 1]

            delta = rewards[t] + Config.GAMMA * next_value - values[t]
            advantages[t] = lastgaelam = delta + Config.GAMMA * Config.GAE_LAMBDA * lastgaelam

        returns = advantages + values
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # PPO update
        for _ in range(4):  # Multiple epochs
            # Shuffle and batch
            indices = torch.randperm(len(states))
            for start in range(0, len(states), Config.BATCH_SIZE):
                end = min(start + Config.BATCH_SIZE, len(states))
                batch_idx = indices[start:end]

                b_states = states[batch_idx]
                b_actions = actions[batch_idx]
                b_log_probs = log_probs[batch_idx]
                b_advantages = advantages[batch_idx]
                b_returns = returns[batch_idx]

                # Evaluate current policy
                new_log_probs, new_values, entropy = self.model.evaluate(b_states, b_actions)

                # PPO loss
                ratio = torch.exp(new_log_probs.squeeze() - b_log_probs)
                surr1 = ratio * b_advantages
                surr2 = torch.clamp(ratio, 1 - Config.CLIP_EPS, 1 + Config.CLIP_EPS) * b_advantages

                actor_loss = -torch.min(surr1, surr2).mean()
                critic_loss = F.mse_loss(new_values.squeeze(), b_returns)
                entropy_loss = -entropy.mean()

                loss = actor_loss + 0.5 * critic_loss + Config.ENTROPY_COEF * entropy_loss

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
                self.optimizer.step()

        # Return stats
        return {
            'return': env.pnl * 100,
            'trades': env.trade_count,
            'win_rate': env.wins / max(env.trade_count, 1) * 100
        }

# ============================================================
# MAIN TRAINING
# ============================================================
def train_regime_model(regime: str, prices: np.ndarray, features: torch.Tensor,
                       regimes: np.ndarray) -> ActorCritic:
    """Train a model for a specific regime."""

    print(f"\n{'='*60}")
    print(f"🎯 Training {regime.upper()} model")
    print(f"{'='*60}")

    # Count regime samples
    regime_count = (regimes == regime).sum()
    print(f"   Regime samples: {regime_count}")

    # Create model
    input_dim = features.shape[1] + 1  # +1 for position
    model = ActorCritic(input_dim).to(Config.DEVICE)

    # Create environment
    env = TradingEnv(prices, features, regimes, regime)
    trainer = PPOTrainer(model, regime)

    best_return = -float('inf')
    best_model_state = None

    for ep in range(Config.EPISODES):
        stats = trainer.train_episode(env)

        if stats['return'] > best_return:
            best_return = stats['return']
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

        if ep % 20 == 0:
            print(f"   Ep {ep:3d} | Ret: {stats['return']:+7.2f}% | Trades: {stats['trades']:5d} | Win: {stats['win_rate']:5.1f}%")

    # Load best model
    if best_model_state is not None:
        model.load_state_dict({k: v.to(Config.DEVICE) for k, v in best_model_state.items()})

    print(f"   ✅ {regime.upper()} complete - Best return: {best_return:+.2f}%")

    return model
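
# Illustrative checkpoint loader (a sketch based on the save format used in
# main() below; the helper name `_load_regime_model` is ours and nothing in this
# script calls it): rebuild an ActorCritic from a saved model_<regime>_v5.pt file.
def _load_regime_model(path: str) -> ActorCritic:
    checkpoint = torch.load(path, map_location=Config.DEVICE)
    model = ActorCritic(checkpoint['input_dim']).to(Config.DEVICE)
    model.load_state_dict(checkpoint['model_state'])
    model.eval()  # disable dropout for deterministic inference
    return model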

def main():
    print("🚀 Device:", Config.DEVICE)

    print("""
    ╔══════════════════════════════════════════════════════════════╗
    ║                                                              ║
    ║   🎯 REGIME TRADING V5 - FORCE TRADING                      ║
    ║                                                              ║
    ║   Key changes:                                               ║
    ║   - ZERO fee during training                                 ║
    ║   - Idle penalty forces positions                           ║
    ║   - Continuous position sizing [-1, 1]                      ║
    ║   - High entropy for exploration                            ║
    ║   - Trade bonus rewards activity                            ║
    ║                                                              ║
    ╚══════════════════════════════════════════════════════════════╝
    """)

    # Load data
    df = pd.read_csv(Config.DATA_PATH)
    if 'price' in df.columns:
        prices = df['price'].values.astype(np.float32)
    elif 'close' in df.columns:
        prices = df['close'].values.astype(np.float32)
    else:
        prices = df.iloc[:, 1].values.astype(np.float32)

    print(f"📂 Loaded {len(prices)} prices")

    # Compute features
    print("⚡ Computing features on GPU...")
    features = compute_features_gpu(prices)
    print(f"   Features shape: {features.shape}")

    # Detect regimes
    print("🔍 Detecting regimes...")
    regimes = detect_regimes(prices, features)

    # Show distribution
    print("\n📊 Regime distribution:")
    for regime in ['bullish', 'bearish', 'ranging', 'volatile', 'scalper']:
        count = (regimes == regime).sum()
        pct = count / len(regimes) * 100
        print(f"   {regime:10s}: {count:6d} ({pct:5.1f}%)")

    # Train models
    models = {}
    for regime in ['bullish', 'bearish', 'ranging', 'volatile', 'scalper']:
        model = train_regime_model(regime, prices, features, regimes)
        models[regime] = model

        # Save model
        save_path = f'/workspace/model_{regime}_v5.pt'
        torch.save({
            'model_state': model.state_dict(),
            'regime': regime,
            'input_dim': features.shape[1] + 1,
            'config': {
                'hidden_dim': Config.HIDDEN_DIM,
                'num_layers': Config.NUM_LAYERS,
            }
        }, save_path)
        print(f"   💾 Saved: {save_path}")

    print("\n" + "="*60)
    print("✅ ALL MODELS TRAINED!")
    print("="*60)

    # Final evaluation with test fee
    print("\n📈 Final evaluation (with 0.4% fee):")
    for regime, model in models.items():
        env = TradingEnv(prices, features, regimes, regime)
        # Temporarily set production fee
        Config.FEE = 0.004
        Config.IDLE_PENALTY = 0
        Config.TRADE_BONUS = 0

        obs = env.reset()
        while True:
            with torch.no_grad():
                action, _ = model.get_action(obs.unsqueeze(0), deterministic=True)
            obs, _, done = env.step(action.item())
            if done:
                break

        print(f"   {regime:10s}: Return {env.pnl*100:+7.2f}% | Trades: {env.trade_count}")

        # Restore training-time settings
        Config.FEE = 0.0
        Config.IDLE_PENALTY = 0.001
        Config.TRADE_BONUS = 0.0005

    print("\n🎉 Done! Models saved to /workspace/model_*_v5.pt")

if __name__ == '__main__':
    main()
