#!/usr/bin/env python3
"""
Single regime trainer - can be launched in parallel
Usage: python train_single_regime.py <regime_name>
"""

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from typing import Tuple, Dict
import warnings
warnings.filterwarnings('ignore')

# Regime name -> integer id. The id is what detect_regimes() writes into the
# label tensor and what gets stored in the saved checkpoint.
REGIME_MAP = {'bullish': 0, 'bearish': 1, 'ranging': 2, 'volatile': 3, 'scalper': 4}

# The regime to train is the first positional command-line argument.
if len(sys.argv) <= 1:
    print("Usage: python train_single_regime.py <regime_name>")
    print("Regimes: bullish, bearish, ranging, volatile, scalper")
    sys.exit(1)

# Matching is case-insensitive: the argument is lower-cased before lookup.
TARGET_REGIME = sys.argv[1].lower()
TARGET_REGIME_ID = REGIME_MAP.get(TARGET_REGIME)

if TARGET_REGIME_ID is None:
    print(f"Unknown regime: {TARGET_REGIME}")
    sys.exit(1)

# Config
class Config:
    """Static settings for one training run; read as class attributes."""

    # --- Data / runtime -----------------------------------------------------
    DATA_PATH = '/workspace/prices_btc_2025.csv'  # CSV read by main()
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- Environment --------------------------------------------------------
    NUM_ENVS = 64          # number of environments stepped in parallel
    WINDOW_SIZE = 1000     # steps after which an environment reports done
    LOOKBACK = 20          # history rows included before the current step
    FEE = 0.0              # cost per unit of absolute position change
    IDLE_PENALTY = 0.0005  # subtracted from the reward when |position| < 0.1
    TRADE_BONUS = 0.0002   # reward added per unit of absolute position change

    # --- Model --------------------------------------------------------------
    HIDDEN_DIM = 256       # transformer d_model (feed-forward is 4x this)
    NUM_HEADS = 4
    NUM_LAYERS = 2
    DROPOUT = 0.1

    # --- Optimisation -------------------------------------------------------
    EPISODES = 500         # rollout + update iterations in main()
    BATCH_SIZE = 2048      # minibatch size for each PPO update
    LR = 3e-4              # AdamW learning rate
    GAMMA = 0.99           # presumably the reward discount factor
    GAE_LAMBDA = 0.95      # presumably the GAE lambda
    CLIP_EPS = 0.2         # PPO ratio clip range: [1 - eps, 1 + eps]
    ENTROPY_COEF = 0.02    # weight of the entropy bonus in the PPO loss

@torch.jit.script
def compute_returns(prices: torch.Tensor, period: int) -> torch.Tensor:
    """Return the simple return of a 1-D price series over `period` steps.

    Element i (for i >= period) is (prices[i] - prices[i - period]) /
    (prices[i - period] + 1e-8); the first `period` elements are zero.
    If `period` is not in the range 1..len(prices)-1 the result is all zeros.

    Args:
        prices: 1-D tensor of prices.
        period: Look-back distance in steps.

    Returns:
        Tensor with the same shape, dtype and device as `prices`.
    """
    n = prices.shape[0]
    ret = torch.zeros_like(prices)
    # `period > 0` matters: with period == 0, `prices[:-period]` is an empty
    # slice and the subtraction fails on mismatched shapes.
    if period > 0 and period < n:
        prev = prices[:n - period]
        ret[period:] = (prices[period:] - prev) / (prev + 1e-8)
    return ret

def compute_features(prices: np.ndarray) -> torch.Tensor:
    """Build the normalised per-step feature matrix for a price series.

    Columns, in order:
        0-4  simple returns over 1, 5, 10, 20 and 60 steps
        5-6  std of the 1-step returns over the previous 10 and 30 steps
             (the window ends just before the current step)
        7-8  (price - MA) / MA for trailing 10- and 30-step moving averages
             (the window includes the current price)

    Every column is z-scored with the mean/std of the whole series and then
    clamped to [-3, 3]. NOTE: because the statistics span the full series,
    each row's normalisation depends on later data as well.

    Args:
        prices: 1-D array of prices.

    Returns:
        Float32 tensor of shape (len(prices), 9) on Config.DEVICE.
    """
    device = Config.DEVICE
    prices_t = torch.tensor(prices, dtype=torch.float32, device=device)
    n = len(prices)

    features = [compute_returns(prices_t, period) for period in (1, 5, 10, 20, 60)]
    ret1 = features[0]  # 1-step returns, reused for the volatility columns

    for w in (10, 30):
        # A fresh buffer per window: sharing one buffer would leave the
        # shorter window's values in rows w_short..w_long-1 of the longer one.
        vol = torch.zeros(n, device=device)
        if n > w:
            # Row j of the unfolded view is ret1[j:j+w]; vol[i] uses
            # ret1[i-w:i], i.e. row i-w, so the last row is dropped.
            vol[w:] = ret1.unfold(0, w, 1).std(dim=1)[:n - w]
        features.append(vol)

    for period in (10, 30):
        # Warm-up rows use the mean of the first `period` prices.
        ma = torch.empty(n, device=device)
        ma[:period] = prices_t[:period].mean()
        if n >= period:
            # Average exactly `period` prices per window. Averaging each
            # window directly also avoids differencing a float32 cumulative
            # sum of raw prices, which loses precision on long series.
            ma[period - 1:] = prices_t.unfold(0, period, 1).mean(dim=1)
        features.append((prices_t - ma) / (ma + 1e-8))

    feat = torch.stack(features, dim=1)
    mean = feat.mean(dim=0, keepdim=True)
    std = feat.std(dim=0, keepdim=True) + 1e-8
    feat = (feat - mean) / std
    return torch.clamp(feat, -3, 3)

def detect_regimes(prices: np.ndarray, features: torch.Tensor) -> torch.Tensor:
    """Assign a regime id to every step from the normalised feature matrix.

    Uses feature columns 3 (20-step return), 4 (60-step return) and
    6 (30-step volatility). Steps matching no rule keep the default id 4.
    Ids: 0 bullish, 1 bearish, 2 ranging, 3 volatile, 4 default.

    Args:
        prices: Price array; only its length is used.
        features: Feature matrix with at least 7 columns.

    Returns:
        Long tensor of length len(prices) on Config.DEVICE.
    """
    feat_np = features.cpu().numpy()
    ret20, ret60, vol30 = feat_np[:, 3], feat_np[:, 4], feat_np[:, 6]

    is_volatile = vol30 > 1.5
    calm = ~is_volatile

    # (regime id, mask) pairs, applied in order; later rules overwrite
    # earlier ones wherever masks overlap.
    rules = (
        (3, is_volatile),
        (0, (ret20 > 0.3) & (ret60 > 0.2) & calm),
        (1, (ret20 < -0.3) & (ret60 < -0.2) & calm),
        (2, (np.abs(ret60) < 0.15) & (vol30 < 0.5) & calm),
    )

    labels = np.full(len(prices), 4, dtype=np.int64)
    for regime_id, mask in rules:
        labels[mask] = regime_id

    return torch.tensor(labels, dtype=torch.long, device=Config.DEVICE)

class TradingTransformer(nn.Module):
    """Transformer encoder with a Gaussian policy head and a value head.

    The network embeds each time step, runs the sequence through the encoder
    and reads both heads from the encoding of the last time step.
    """

    def __init__(self, input_dim: int):
        """Create the network.

        Args:
            input_dim: Number of market features per step. The embedding
                accepts one extra channel (input_dim + 1) because the
                environment appends the current position to every step.
        """
        super().__init__()
        hidden = Config.HIDDEN_DIM
        self.embed = nn.Linear(input_dim + 1, hidden)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden,
                nhead=Config.NUM_HEADS,
                dim_feedforward=hidden * 4,
                dropout=Config.DROPOUT,
                batch_first=True,
            ),
            num_layers=Config.NUM_LAYERS,
        )
        self.actor_mean = nn.Linear(hidden, 1)
        # State-independent log std of the policy, initialised to -0.5.
        self.actor_log_std = nn.Parameter(torch.full((1,), -0.5))
        self.critic = nn.Linear(hidden, 1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the network on a batch of observation sequences.

        Args:
            x: Tensor of shape (batch, seq_len, input_dim + 1).

        Returns:
            (mean, std, value), each of shape (batch, 1). `mean` is squashed
            into (-1, 1) by tanh; `std` is the same value for every row.
        """
        encoded = self.transformer(self.embed(x))
        last_step = encoded[:, -1]
        mean = torch.tanh(self.actor_mean(last_step))
        std = self.actor_log_std.exp().expand(mean.shape[0], 1)
        value = self.critic(last_step)
        return mean, std, value

class VectorizedEnv:
    """A batch of independent trading environments stepped in lock-step.

    Each environment walks forward through the shared price series from its
    own random start index. The action is the target position in [-1, 1];
    the observation is the trailing feature window with the current position
    appended as an extra channel.
    """

    def __init__(self, prices: torch.Tensor, features: torch.Tensor,
                 regimes: torch.Tensor, target_regime: int, num_envs: int):
        """Store the shared series and initialise all environments.

        Args:
            prices: 1-D price tensor.
            features: (len(prices), num_features) feature tensor.
            regimes: Per-step regime ids, same length as `prices`.
            target_regime: Regime id whose steps get a doubled reward.
            num_envs: Number of parallel environments.

        Raises:
            ValueError: If the series is too short for one episode.
        """
        self.prices = prices
        self.features = features
        self.regimes = regimes
        self.target_regime = target_regime
        self.num_envs = num_envs
        self.n = len(prices)
        self.positions = torch.zeros(num_envs, device=Config.DEVICE)
        self.start_idx = torch.zeros(num_envs, dtype=torch.long, device=Config.DEVICE)
        self.current_idx = torch.zeros(num_envs, dtype=torch.long, device=Config.DEVICE)
        self.pnl = torch.zeros(num_envs, device=Config.DEVICE)
        self.trade_count = torch.zeros(num_envs, dtype=torch.long, device=Config.DEVICE)
        self.reset()

    def reset(self) -> torch.Tensor:
        """Re-sample start indices, clear per-episode state, return the first obs.

        Start indices are drawn from [LOOKBACK, n - WINDOW_SIZE - LOOKBACK),
        so every environment has a full look-back window and room for a full
        episode.

        Raises:
            ValueError: If that range is empty.
        """
        max_start = self.n - Config.WINDOW_SIZE - Config.LOOKBACK
        if max_start <= Config.LOOKBACK:
            # torch.randint needs low < high; fail with an actionable message.
            raise ValueError(
                f"Price series too short: need more than "
                f"{Config.WINDOW_SIZE + 2 * Config.LOOKBACK} points, got {self.n}"
            )
        self.start_idx = torch.randint(Config.LOOKBACK, max_start, (self.num_envs,), device=Config.DEVICE)
        self.current_idx = self.start_idx.clone()
        self.positions = torch.zeros(self.num_envs, device=Config.DEVICE)
        self.pnl = torch.zeros(self.num_envs, device=Config.DEVICE)
        self.trade_count = torch.zeros(self.num_envs, dtype=torch.long, device=Config.DEVICE)
        return self._get_obs()

    def _get_obs(self) -> torch.Tensor:
        """Return observations of shape (num_envs, LOOKBACK + 1, num_features + 1).

        Row k of an environment's window is the feature row at
        current_idx - LOOKBACK + k; rows that would fall before the start of
        the series are zero-filled. The last channel is the environment's
        current position, repeated for every row.
        """
        seq_len = Config.LOOKBACK + 1
        offsets = torch.arange(-Config.LOOKBACK, 1, device=self.current_idx.device)
        window_idx = self.current_idx.unsqueeze(1) + offsets  # (num_envs, seq_len)
        # One gather for all environments instead of a per-env Python loop
        # (which forces a device sync via .item() on every environment).
        before_start = window_idx < 0
        feat_seq = self.features[window_idx.clamp(min=0)]
        feat_seq = feat_seq.masked_fill(before_start.unsqueeze(-1), 0.0)
        pos = self.positions.view(-1, 1, 1).expand(-1, seq_len, 1)
        return torch.cat([feat_seq, pos], dim=2)

    def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance every environment by one step.

        Args:
            actions: (num_envs, 1) target positions; clamped to [-1, 1].

        Returns:
            (observations, rewards, dones). `dones` is True once an
            environment has taken WINDOW_SIZE steps or reached the last price.
        """
        actions = actions.squeeze(-1)
        old_positions = self.positions.clone()
        new_positions = actions.clamp(-1, 1)

        current_prices = self.prices[self.current_idx]
        next_idx = (self.current_idx + 1).clamp(max=self.n - 1)
        next_prices = self.prices[next_idx]
        price_returns = (next_prices - current_prices) / (current_prices + 1e-8)

        # PnL is earned on the average of the old and new position.
        avg_positions = (old_positions + new_positions) / 2
        pnl = avg_positions * price_returns
        position_change = (new_positions - old_positions).abs()
        fee_cost = position_change * Config.FEE
        self.trade_count += (position_change > 0.1).long()

        # Net PnL counts double on steps labelled with the target regime.
        rewards = pnl - fee_cost
        current_regimes = self.regimes[self.current_idx]
        regime_match = (current_regimes == self.target_regime).float()
        rewards = rewards * (1.0 + regime_match)

        # Shaping: discourage near-flat positions, encourage position changes.
        idle_mask = new_positions.abs() < 0.1
        rewards -= idle_mask.float() * Config.IDLE_PENALTY
        rewards += position_change * Config.TRADE_BONUS

        # Directional bonus on target-regime steps: longs for the bullish
        # specialist (id 0), shorts for the bearish specialist (id 1).
        if self.target_regime == 0:
            rewards += (new_positions > 0).float() * 0.0002 * regime_match
        elif self.target_regime == 1:
            rewards += (new_positions < 0).float() * 0.0002 * regime_match

        self.positions = new_positions
        self.pnl += pnl - fee_cost  # unshaped PnL, used for reporting
        self.current_idx = next_idx
        steps_taken = self.current_idx - self.start_idx
        dones = (steps_taken >= Config.WINDOW_SIZE) | (self.current_idx >= self.n - 1)
        return self._get_obs(), rewards, dones

class PPOTrainer:
    """Owns the optimiser and performs clipped-surrogate PPO updates."""

    def __init__(self, model: TradingTransformer):
        """Wrap `model` and create an AdamW optimiser with lr=Config.LR."""
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LR)

    def train_batch(self, states, actions, old_log_probs, advantages, returns):
        """Run one gradient step on a minibatch.

        Args:
            states: Observations fed to the model.
            actions: Actions taken during the rollout, shaped like the
                model's `mean` output.
            old_log_probs: Log-probabilities of `actions` at rollout time.
            advantages: Advantage estimate per sample.
            returns: Regression target for the value head per sample.
        """
        eps = Config.CLIP_EPS
        mean, std, values = self.model(states)
        policy = torch.distributions.Normal(mean, std)

        # Probability ratio between the current and the rollout-time policy.
        new_log_probs = policy.log_prob(actions).squeeze(-1)
        ratio = (new_log_probs - old_log_probs).exp()

        unclipped = ratio * advantages
        clipped = ratio.clamp(1 - eps, 1 + eps) * advantages
        actor_loss = -torch.min(unclipped, clipped).mean()

        critic_loss = F.mse_loss(values.squeeze(-1), returns)
        entropy = policy.entropy().mean()
        loss = actor_loss + 0.5 * critic_loss - Config.ENTROPY_COEF * entropy

        self.optimizer.zero_grad()
        loss.backward()
        # Cap the global gradient norm at 0.5 before stepping.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
        self.optimizer.step()

def _load_prices(path: str) -> np.ndarray:
    """Read the price series from a CSV as float32.

    Uses the 'price' column if present, else 'close', else the second column.
    """
    df = pd.read_csv(path)
    for column in ('price', 'close'):
        if column in df.columns:
            return df[column].values.astype(np.float32)
    return df.iloc[:, 1].values.astype(np.float32)


def _compute_gae(rewards: torch.Tensor, values: torch.Tensor,
                 next_values: torch.Tensor, active: torch.Tensor
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute GAE advantages and value targets for a stacked rollout.

    Args:
        rewards: (T, E) reward received at each step.
        values: (T, E) value estimate of the observation acted on.
        next_values: (T, E) value estimate of the following observation.
        active: (T, E) float mask, 1.0 where the step belongs to the episode.
            Masked-out steps get zero advantage and do not leak into earlier
            steps.

    Returns:
        (advantages, returns), both (T, E); returns = advantages + values.
    """
    gamma, lam = Config.GAMMA, Config.GAE_LAMBDA
    advantages = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[0])
    for t in reversed(range(rewards.shape[0])):
        delta = rewards[t] + gamma * next_values[t] - values[t]
        running = (delta + gamma * lam * running) * active[t]
        advantages[t] = running
    return advantages, advantages + values


def main():
    """Train a model for TARGET_REGIME and save the best checkpoint.

    Each episode collects one rollout from all environments, computes GAE
    advantages and runs one pass of minibatch PPO updates over the rollout.
    The weights that produced the highest mean episode PnL are saved to
    /workspace/model_<regime>_v6.pt.
    """
    print(f"🚀 Training {TARGET_REGIME.upper()} model")
    print(f"   Device: {Config.DEVICE}")

    prices = _load_prices(Config.DATA_PATH)
    print(f"   Loaded {len(prices):,} prices")

    prices_t = torch.tensor(prices, dtype=torch.float32, device=Config.DEVICE)
    features = compute_features(prices)
    regimes = detect_regimes(prices, features)

    regime_count = (regimes == TARGET_REGIME_ID).sum().item()
    print(f"   Regime samples: {regime_count}")

    input_dim = features.shape[1]
    model = TradingTransformer(input_dim).to(Config.DEVICE)
    trainer = PPOTrainer(model)
    env = VectorizedEnv(prices_t, features, regimes, TARGET_REGIME_ID, Config.NUM_ENVS)

    best_return = -float('inf')
    best_state = None

    for ep in range(Config.EPISODES):
        states_list, actions_list, log_probs_list = [], [], []
        values_list, rewards_list, active_list = [], [], []
        obs = env.reset()
        done = torch.zeros(Config.NUM_ENVS, dtype=torch.bool, device=Config.DEVICE)

        # --- Rollout: keep full (E, ...) tensors per step plus an activity
        # mask, so the time structure needed for GAE is preserved.
        while not done.all():
            with torch.no_grad():
                mean, std, value = model(obs)
                dist = torch.distributions.Normal(mean, std)
                action = dist.sample()
                log_prob = dist.log_prob(action).squeeze(-1)

            states_list.append(obs)
            actions_list.append(action)
            log_probs_list.append(log_prob)
            values_list.append(value.squeeze(-1))
            active_list.append(~done)
            obs, rewards, new_done = env.step(action)
            rewards_list.append(rewards)
            done = done | new_done

        if not states_list:
            continue

        # Episodes end by step limit, not by a terminal state, so the tail is
        # bootstrapped from the value of the observation after the last step.
        with torch.no_grad():
            _, _, last_value = model(obs)

        values = torch.stack(values_list)    # (T, E)
        rewards = torch.stack(rewards_list)  # (T, E)
        active = torch.stack(active_list)    # (T, E) bool
        next_values = torch.cat([values[1:], last_value.squeeze(-1).unsqueeze(0)], dim=0)
        advantages, returns = _compute_gae(rewards, values, next_values, active.float())

        # Flatten to per-sample tensors, dropping steps taken after done.
        states = torch.stack(states_list)[active]
        actions = torch.stack(actions_list)[active]
        old_log_probs = torch.stack(log_probs_list)[active]
        advantages = advantages[active]
        returns = returns[active]
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        mean_return = env.pnl.mean().item() * 100
        mean_trades = env.trade_count.float().mean().item()

        # Snapshot BEFORE the update: these are the weights that actually
        # produced `mean_return` in the rollout above.
        if mean_return > best_return:
            best_return = mean_return
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

        # --- PPO update: one shuffled pass over the rollout in minibatches.
        n_samples = len(states)
        indices = torch.randperm(n_samples, device=Config.DEVICE)
        for start in range(0, n_samples, Config.BATCH_SIZE):
            batch_idx = indices[start:start + Config.BATCH_SIZE]
            trainer.train_batch(states[batch_idx], actions[batch_idx], old_log_probs[batch_idx],
                                advantages[batch_idx], returns[batch_idx])

        if ep % 25 == 0:
            print(f"   Ep {ep:3d} | Ret: {mean_return:+7.2f}% | Trades: {mean_trades:5.1f}")

    if best_state is not None:
        model.load_state_dict({k: v.to(Config.DEVICE) for k, v in best_state.items()})

    save_path = f'/workspace/model_{TARGET_REGIME}_v6.pt'
    torch.save({
        'model_state': model.state_dict(),
        'regime': TARGET_REGIME,
        'regime_id': TARGET_REGIME_ID,
        'input_dim': features.shape[1],
        'config': {'hidden_dim': Config.HIDDEN_DIM, 'num_heads': Config.NUM_HEADS, 'num_layers': Config.NUM_LAYERS}
    }, save_path)

    print(f"   ✅ {TARGET_REGIME.upper()} complete - Best: {best_return:+.2f}%")
    print(f"   💾 Saved: {save_path}")


if __name__ == '__main__':
    main()
