#!/usr/bin/env python3
"""
BTC Trading Model V7 - PPO Training with Entropy Bonus
=======================================================
Fixes from V6:
1. Entropy bonus to prevent policy collapse
2. Small fees during training (realistic)
3. Better reward shaping
4. Separate entry/exit signals
5. Multi-timeframe input

Author: Claude Code
Date: January 2026
"""

import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Tuple
import json

# ============================================================
# CONFIGURATION
# ============================================================

CONFIG = {
    # Training
    "episodes": 800,
    "steps_per_episode": 2000,
    "batch_size": 2048,
    "num_envs": 32,
    "learning_rate": 1e-4,
    "gamma": 0.99,
    "gae_lambda": 0.95,
    "clip_epsilon": 0.2,
    "value_coef": 0.5,
    "max_grad_norm": 0.5,

    # CRITICAL FIX: Entropy bonus to prevent collapse
    "entropy_coef": 0.02,  # Was 0.0 in V6!
    "entropy_decay": 0.9995,  # Slowly decay entropy bonus
    "min_entropy_coef": 0.005,

    # CRITICAL FIX: Realistic fees during training
    "trading_fee": 0.001,  # 0.1% fee (was 0.0 in V6!)
    "slippage": 0.0005,  # 0.05% slippage

    # Reward shaping
    "pnl_scale": 100.0,  # Scale PnL reward
    "hold_bonus": 0.0001,  # Small bonus for holding profitable position
    "win_bonus": 0.01,  # Bonus for profitable trade
    "drawdown_penalty": 0.1,  # Penalty for drawdown

    # Model architecture
    "hidden_dim": 256,
    "num_heads": 4,
    "num_layers": 3,
    "dropout": 0.1,
    "lookback": 60,  # 60 minutes of history

    # Features
    "num_features": 12,  # price, returns, volatility, etc.

    # Output
    "output_dir": "/workspace/models",
    "save_every": 50,
    "log_every": 10,
}
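
# A reduced configuration for a quick smoke run (illustrative only; the name
# SMOKE_CONFIG and the override values are assumptions, not part of the V7 setup).
SMOKE_CONFIG = {
    **CONFIG,
    "episodes": 2,
    "steps_per_episode": 64,
    "num_envs": 4,
    "batch_size": 256,
}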


# ============================================================
# TRADING ENVIRONMENT
# ============================================================

class TradingEnv:
    """
    Vectorized trading environment for PPO training.
    Supports multiple parallel environments.
    """

    def __init__(self, prices: np.ndarray, num_envs: int = 32, config: dict = None):
        self.prices = prices
        self.num_envs = num_envs
        self.config = config or CONFIG
        self.lookback = self.config["lookback"]

        # Precompute features
        self.features = self._compute_features(prices)
        self.max_steps = len(prices) - self.lookback - 1
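        # Note: max_steps is an absolute index into `prices`; an env is flagged done
        # once its step pointer reaches it (see step())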

        # State
        self.positions = np.zeros(num_envs)  # -1, 0, 1
        self.entry_prices = np.zeros(num_envs)
        self.steps = np.zeros(num_envs, dtype=np.int32)
        self.pnls = np.zeros(num_envs)
        self.peak_values = np.ones(num_envs)

    def _compute_features(self, prices: np.ndarray) -> np.ndarray:
        """Compute multi-timeframe features from price data."""
        n = len(prices)
        features = np.zeros((n, self.config["num_features"]))

        # Returns at different timeframes
        for i, period in enumerate([1, 5, 15, 30, 60]):
            if i < self.config["num_features"]:
                ret = np.zeros(n)
                ret[period:] = (prices[period:] - prices[:-period]) / prices[:-period]
                features[:, i] = ret

        # Volatility at different timeframes
        for i, period in enumerate([5, 15, 30]):
            idx = 5 + i
            if idx < self.config["num_features"]:
                vol = np.zeros(n)
                for j in range(period, n):
                    vol[j] = np.std(np.diff(prices[j-period:j]) / prices[j-period:j-1])
                features[:, idx] = vol

        # Price momentum
        if 8 < self.config["num_features"]:
            sma_short = pd.Series(prices).rolling(10).mean().values
            sma_long = pd.Series(prices).rolling(30).mean().values
            features[:, 8] = np.nan_to_num((sma_short - sma_long) / sma_long)

        # RSI-like indicator
        if 9 < self.config["num_features"]:
            delta = np.diff(prices, prepend=prices[0])
            gain = np.where(delta > 0, delta, 0)
            loss = np.where(delta < 0, -delta, 0)
            avg_gain = pd.Series(gain).rolling(14).mean().values
            avg_loss = pd.Series(loss).rolling(14).mean().values
            rs = np.nan_to_num(avg_gain / (avg_loss + 1e-10))
            features[:, 9] = (100 - 100 / (1 + rs)) / 100 - 0.5  # Normalize to [-0.5, 0.5]

        # Volume proxy (price change magnitude)
        if 10 < self.config["num_features"]:
            features[:, 10] = np.abs(np.diff(prices, prepend=prices[0])) / prices

        # Trend strength
        if 11 < self.config["num_features"]:
            ret_20 = np.zeros(n)
            ret_20[20:] = (prices[20:] - prices[:-20]) / prices[:-20]
            features[:, 11] = ret_20

        # Normalize features
        for i in range(features.shape[1]):
            col = features[:, i]
            std = np.std(col[~np.isnan(col)])
            if std > 0:
                features[:, i] = np.clip(col / (std * 3), -1, 1)

        return np.nan_to_num(features)

    def reset(self) -> np.ndarray:
        """Reset environments to random starting points."""
        self.positions = np.zeros(self.num_envs)
        self.entry_prices = np.zeros(self.num_envs)
        self.pnls = np.zeros(self.num_envs)
        self.peak_values = np.ones(self.num_envs)

        # Random start positions (with margin for lookback)
        max_start = self.max_steps - self.config["steps_per_episode"]
        self.steps = np.random.randint(self.lookback, max(self.lookback + 1, max_start), self.num_envs)

        return self._get_obs()

    def _get_obs(self) -> np.ndarray:
        """Get observation for all environments."""
        obs = np.zeros((self.num_envs, self.lookback, self.config["num_features"] + 1))

        for i in range(self.num_envs):
            start = self.steps[i] - self.lookback
            end = self.steps[i]
            obs[i, :, :-1] = self.features[start:end]
            obs[i, :, -1] = self.positions[i]  # Current position

        return obs.astype(np.float32)

    def step(self, actions: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
        """
        Execute actions in all environments.
        Actions: continuous [-1, 1] mapped to position
        """
        rewards = np.zeros(self.num_envs)
        dones = np.zeros(self.num_envs, dtype=bool)

        current_prices = self.prices[self.steps]
        next_prices = self.prices[np.minimum(self.steps + 1, len(self.prices) - 1)]

        # Map actions to target positions (-1, 0, 1)
        target_positions = np.clip(np.round(actions * 1.5), -1, 1)
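        # With actions in [-1, 1]: magnitudes below ~1/3 round to flat (0),
        # larger magnitudes become long (+1) or short (-1)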

        for i in range(self.num_envs):
            old_pos = self.positions[i]
            new_pos = target_positions[i]
            price = current_prices[i]
            next_price = next_prices[i]

            # Calculate trading costs if position changes
            trade_cost = 0
            if old_pos != new_pos:
                trade_size = abs(new_pos - old_pos)
                trade_cost = trade_size * price * (self.config["trading_fee"] + self.config["slippage"])

                # Close old position
                if old_pos != 0:
                    pnl = old_pos * (price - self.entry_prices[i])
                    self.pnls[i] += pnl - trade_cost / 2

                    # Win bonus
                    if pnl > 0:
                        rewards[i] += self.config["win_bonus"]

                # Open new position
                if new_pos != 0:
                    self.entry_prices[i] = price

                self.positions[i] = new_pos

            # Calculate PnL from price movement
            if self.positions[i] != 0:
                price_pnl = self.positions[i] * (next_price - price) / price
                rewards[i] += price_pnl * self.config["pnl_scale"]

                # Hold bonus for profitable positions
                if self.positions[i] * (next_price - self.entry_prices[i]) > 0:
                    rewards[i] += self.config["hold_bonus"]

            # Drawdown penalty
            current_value = 1 + self.pnls[i] / 10000  # Equity proxy: PnL relative to a fixed 10,000 notional
            if current_value > self.peak_values[i]:
                self.peak_values[i] = current_value
            drawdown = (self.peak_values[i] - current_value) / self.peak_values[i]
            rewards[i] -= drawdown * self.config["drawdown_penalty"]

            # Trading cost penalty
            rewards[i] -= trade_cost / price * self.config["pnl_scale"]

        # Advance time
        self.steps += 1

        # Check if done
        dones = self.steps >= self.max_steps

        # Reset done environments
        done_indices = np.where(dones)[0]
        if len(done_indices) > 0:
            max_start = self.max_steps - self.config["steps_per_episode"]
            self.steps[done_indices] = np.random.randint(
                self.lookback, max(self.lookback + 1, max_start), len(done_indices)
            )
            self.positions[done_indices] = 0
            self.entry_prices[done_indices] = 0
            self.pnls[done_indices] = 0
            self.peak_values[done_indices] = 1

        info = {"pnls": self.pnls.copy()}
        return self._get_obs(), rewards.astype(np.float32), dones, info
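
# A minimal usage sketch for TradingEnv (illustrative only, not used by training):
# build a synthetic random-walk price series, reset the vectorized env, and take
# one step with random actions. The helper name _demo_env_rollout and the
# synthetic data are assumptions, not part of the V7 pipeline.
def _demo_env_rollout(num_envs: int = 4, n_prices: int = 5000):
    rng = np.random.default_rng(0)
    # Geometric random walk around a BTC-like price level
    prices = 30000 * np.exp(np.cumsum(rng.normal(0, 1e-3, n_prices)))
    env = TradingEnv(prices, num_envs=num_envs)
    obs = env.reset()                        # (num_envs, lookback, num_features + 1)
    actions = rng.uniform(-1, 1, num_envs)   # continuous actions in [-1, 1]
    next_obs, rewards, dones, info = env.step(actions)
    print(obs.shape, rewards.shape, dones.shape, info["pnls"].shape)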


# ============================================================
# TRANSFORMER MODEL
# ============================================================

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
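        # Standard sinusoidal positional encoding:
        #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))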
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, :x.size(1)]


class TradingTransformer(nn.Module):
    """
    Transformer-based trading model with PPO heads.
    Outputs: action (position), value estimate
    """

    def __init__(self, config: dict = None):
        super().__init__()
        self.config = config or CONFIG

        input_dim = self.config["num_features"] + 1  # features + position
        hidden_dim = self.config["hidden_dim"]

        # Input projection
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Positional encoding
        self.pos_encoder = PositionalEncoding(hidden_dim, self.config["lookback"])

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=self.config["num_heads"],
            dim_feedforward=hidden_dim * 4,
            dropout=self.config["dropout"],
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, self.config["num_layers"])

        # Output heads
        self.policy_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 2)  # mean and log_std
        )

        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass.
        x: (batch, seq_len, features)
        Returns: action_mean, action_log_std, value
        """
        # Input projection
        x = self.input_proj(x)
        x = self.pos_encoder(x)

        # Transformer
        x = self.transformer(x)

        # Use last timestep
        x = x[:, -1, :]

        # Policy head
        policy_out = self.policy_head(x)
        action_mean = torch.tanh(policy_out[:, 0:1])  # Bounded [-1, 1]
        action_log_std = torch.clamp(policy_out[:, 1:2], -2, 0.5)  # Bounded std
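        # exp of the clamped range keeps the sampling std roughly within [0.14, 1.65]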

        # Value head
        value = self.value_head(x)

        return action_mean, action_log_std, value

    def get_action(self, x: torch.Tensor, deterministic: bool = False):
        """Sample an action from the policy (mean action if deterministic)."""
        action_mean, action_log_std, value = self.forward(x)
        action_std = torch.exp(action_log_std)
        dist = torch.distributions.Normal(action_mean, action_std)

        if deterministic:
            action = action_mean
        else:
            action = torch.clamp(dist.sample(), -1, 1)
        log_prob = dist.log_prob(action)

        return action, log_prob, value, action_mean, action_std

    def evaluate_action(self, x: torch.Tensor, action: torch.Tensor):
        """Evaluate log probability and entropy of action."""
        action_mean, action_log_std, value = self.forward(x)
        action_std = torch.exp(action_log_std)

        dist = torch.distributions.Normal(action_mean, action_std)
        log_prob = dist.log_prob(action)
        entropy = dist.entropy()

        return log_prob, entropy, value
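
# A minimal shape check for TradingTransformer (illustrative only; the helper name
# _demo_model_shapes is an assumption, not part of V7): push a dummy batch through
# the three entry points and print the resulting tensor shapes.
def _demo_model_shapes(batch_size: int = 8):
    model = TradingTransformer(CONFIG)
    x = torch.randn(batch_size, CONFIG["lookback"], CONFIG["num_features"] + 1)
    mean, log_std, value = model(x)                     # each (batch_size, 1)
    action, log_prob, val, _, _ = model.get_action(x)   # stochastic sample
    log_p, entropy, v = model.evaluate_action(x, action)
    print(mean.shape, value.shape, action.shape, log_prob.shape, entropy.shape)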


# ============================================================
# PPO TRAINER
# ============================================================

class PPOTrainer:
    """PPO trainer with entropy bonus."""

    def __init__(self, model: TradingTransformer, config: dict = None):
        self.model = model
        self.config = config or CONFIG
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=self.config["learning_rate"]
        )

        self.entropy_coef = self.config["entropy_coef"]

    def compute_gae(self, rewards, values, dones):
        """Compute Generalized Advantage Estimation."""
        advantages = np.zeros_like(rewards)
        last_gae = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = 0
            else:
                next_value = values[t + 1]

            delta = rewards[t] + self.config["gamma"] * next_value * (1 - dones[t]) - values[t]
            advantages[t] = last_gae = delta + self.config["gamma"] * self.config["gae_lambda"] * (1 - dones[t]) * last_gae

        returns = advantages + values
        return advantages, returns

    def update(self, rollout_buffer):
        """Update policy using PPO."""
        obs = torch.FloatTensor(rollout_buffer["obs"]).to(self.device)
        actions = torch.FloatTensor(rollout_buffer["actions"]).to(self.device)
        old_log_probs = torch.FloatTensor(rollout_buffer["log_probs"]).to(self.device)
        advantages = torch.FloatTensor(rollout_buffer["advantages"]).to(self.device)
        returns = torch.FloatTensor(rollout_buffer["returns"]).to(self.device)

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Mini-batch updates
        batch_size = self.config["batch_size"]
        indices = np.arange(len(obs))

        total_loss = 0
        total_policy_loss = 0
        total_value_loss = 0
        total_entropy = 0
        num_updates = 0

        for _ in range(4):  # PPO epochs
            np.random.shuffle(indices)

            for start in range(0, len(obs), batch_size):
                end = start + batch_size
                batch_indices = indices[start:end]

                batch_obs = obs[batch_indices]
                batch_actions = actions[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                batch_advantages = advantages[batch_indices]
                batch_returns = returns[batch_indices]

                # Get new log probs and values
                log_probs, entropy, values = self.model.evaluate_action(batch_obs, batch_actions)

                # Policy loss with clipping
                ratio = torch.exp(log_probs - batch_old_log_probs)
                surr1 = ratio * batch_advantages.unsqueeze(-1)
                surr2 = torch.clamp(ratio, 1 - self.config["clip_epsilon"], 1 + self.config["clip_epsilon"]) * batch_advantages.unsqueeze(-1)
                policy_loss = -torch.min(surr1, surr2).mean()

                # Value loss
                value_loss = F.mse_loss(values.squeeze(-1), batch_returns)

                # Entropy bonus (CRITICAL FIX!)
                entropy_loss = -entropy.mean()

                # Total loss
                loss = policy_loss + self.config["value_coef"] * value_loss + self.entropy_coef * entropy_loss

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
                self.optimizer.step()

                total_loss += loss.item()
                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy.mean().item()
                num_updates += 1

        # Decay entropy coefficient
        self.entropy_coef = max(
            self.config["min_entropy_coef"],
            self.entropy_coef * self.config["entropy_decay"]
        )

        return {
            "loss": total_loss / num_updates,
            "policy_loss": total_policy_loss / num_updates,
            "value_loss": total_value_loss / num_updates,
            "entropy": total_entropy / num_updates,
            "entropy_coef": self.entropy_coef
        }


# ============================================================
# TRAINING LOOP
# ============================================================

def train(prices: np.ndarray, config: dict = None):
    """Main training loop."""
    config = config or CONFIG
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("="*60)
    print("BTC TRADING MODEL V7 - TRAINING")
    print("="*60)
    print(f"Device: {device}")
    print(f"Episodes: {config['episodes']}")
    print(f"Entropy coef: {config['entropy_coef']} (CRITICAL FIX)")
    print(f"Trading fee: {config['trading_fee']} (CRITICAL FIX)")
    print("="*60)

    # Create environment and model
    env = TradingEnv(prices, config["num_envs"], config)
    model = TradingTransformer(config)
    trainer = PPOTrainer(model, config)

    # Training metrics
    best_reward = -float('inf')
    training_log = []

    for episode in range(config["episodes"]):
        obs = env.reset()
        episode_rewards = []

        # Collect rollout
        rollout = {
            "obs": [],
            "actions": [],
            "log_probs": [],
            "rewards": [],
            "values": [],
            "dones": []
        }

        for step in range(config["steps_per_episode"]):
            obs_tensor = torch.FloatTensor(obs).to(device)

            with torch.no_grad():
                action, log_prob, value, _, _ = model.get_action(obs_tensor)

            action_np = action.cpu().numpy().squeeze(-1)
            next_obs, rewards, dones, info = env.step(action_np)

            rollout["obs"].append(obs)
            rollout["actions"].append(action.cpu().numpy())
            rollout["log_probs"].append(log_prob.cpu().numpy())
            rollout["rewards"].append(rewards)
            rollout["values"].append(value.cpu().numpy().squeeze(-1))
            rollout["dones"].append(dones)

            episode_rewards.append(rewards.mean())
            obs = next_obs

        # Convert to arrays
        for key in rollout:
            rollout[key] = np.array(rollout[key])

        # Compute advantages
        advantages, returns = trainer.compute_gae(
            rollout["rewards"].mean(axis=1),
            rollout["values"].mean(axis=1),
            rollout["dones"].any(axis=1)
        )
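
        # Note: GAE is computed on the per-step mean across environments and then
        # broadcast back to each env's transitions below; credit assignment is coarse,
        # but it keeps the flattened rollout buffer layout simple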

        # Reshape for batch processing
        rollout["obs"] = rollout["obs"].reshape(-1, config["lookback"], config["num_features"] + 1)
        rollout["actions"] = rollout["actions"].reshape(-1, 1)
        rollout["log_probs"] = rollout["log_probs"].reshape(-1, 1)
        rollout["advantages"] = np.repeat(advantages, config["num_envs"])
        rollout["returns"] = np.repeat(returns, config["num_envs"])

        # Update policy
        update_info = trainer.update(rollout)

        # Logging
        mean_reward = float(np.mean(episode_rewards))
        mean_pnl = float(info["pnls"].mean())  # cast to plain floats so the JSON log serializes

        log_entry = {
            "episode": episode,
            "mean_reward": mean_reward,
            "mean_pnl": mean_pnl,
            **update_info
        }
        training_log.append(log_entry)

        if episode % config["log_every"] == 0:
            print(f"Episode {episode:4d} | Reward: {mean_reward:8.4f} | PnL: {mean_pnl:8.2f} | "
                  f"Entropy: {update_info['entropy']:.4f} | Coef: {update_info['entropy_coef']:.4f}")

        # Save checkpoint (periodic and best-so-far)
        if episode % config["save_every"] == 0 or mean_reward > best_reward:
            is_best = mean_reward > best_reward
            if is_best:
                best_reward = mean_reward
                checkpoint_path = f"{config['output_dir']}/model_best.pt"
            else:
                checkpoint_path = f"{config['output_dir']}/model_ep{episode}.pt"

            os.makedirs(config["output_dir"], exist_ok=True)
            torch.save({
                "model_state_dict": model.state_dict(),
                "config": config,
                "episode": episode,
                "best_reward": best_reward,
            }, checkpoint_path)

            if is_best:
                print(f"  -> New best model saved: {checkpoint_path}")

    # Save final model
    final_path = f"{config['output_dir']}/model_final.pt"
    torch.save({
        "model_state_dict": model.state_dict(),
        "config": config,
        "episode": config["episodes"],
        "best_reward": best_reward,
    }, final_path)
    print(f"\nFinal model saved: {final_path}")

    # Save training log
    log_path = f"{config['output_dir']}/training_log.json"
    with open(log_path, "w") as f:
        json.dump(training_log, f)
    print(f"Training log saved: {log_path}")

    return model, training_log


# ============================================================
# MAIN
# ============================================================

def main():
    print("\n" + "="*60)
    print("BTC TRADING MODEL V7 - TRAINING SCRIPT")
    print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("="*60)

    # Check GPU
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("WARNING: No GPU detected, training will be slow!")

    # Load price data (prefer the 2025 export, fall back to the generic filename)
    data_path = "/workspace/prices_btc_2025.csv"
    if not os.path.exists(data_path):
        data_path = "/workspace/prices.csv"

    if not os.path.exists(data_path):
        print("ERROR: Price data not found at /workspace/prices_btc_2025.csv or /workspace/prices.csv")
        print("Please upload price data first!")
        sys.exit(1)

    print(f"\nLoading price data from {data_path}...")
    df = pd.read_csv(data_path)

    # Handle different column names
    if "price" in df.columns:
        prices = df["price"].values
    elif "close" in df.columns:
        prices = df["close"].values
    elif "mid" in df.columns:
        prices = df["mid"].values
    else:
        prices = df.iloc[:, 1].values  # Assume second column is price

    print(f"Loaded {len(prices)} price points")
    print(f"Price range: {prices.min():.2f} - {prices.max():.2f}")

    # Train
    model, log = train(prices, CONFIG)

    print("\n" + "="*60)
    print("TRAINING COMPLETE!")
    print("="*60)
    print(f"Models saved in: {CONFIG['output_dir']}")
    print("\nTo download:")
    print("  scp -P PORT root@IP:/workspace/models/* ./")


if __name__ == "__main__":
    main()
