"""
BestTrading GPU Training V2 - FIXED Feature Extraction
Matches EXACTLY the 24 features used in the production bot
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from collections import deque
import random
import json
import os
from datetime import datetime

print("=" * 70)
print("   BestTrading GPU Training V2 - FIXED 24 Features")
print("=" * 70)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Load data
df = pd.read_csv('/kaggle/input/bestrading-prices/prices.csv')
prices = df['price'].values.tolist()
print(f"\nLoaded {len(prices)} price points")
print(f"Price range: {min(prices):.2f} - {max(prices):.2f}")


class TradingNN(nn.Module):
    """Neural network for trading - 24 inputs, 3 outputs (FLAT, LONG, SHORT)"""
    def __init__(self, input_size=24, hidden_sizes=[64, 32], output_size=3):
        super().__init__()
        layers = []
        prev_size = input_size
        for h in hidden_sizes:
            layers.append(nn.Linear(prev_size, h))
            layers.append(nn.LeakyReLU(0.01))
            layers.append(nn.Dropout(0.1))  # Regularization
            prev_size = h
        layers.append(nn.Linear(prev_size, output_size))
        self.net = nn.Sequential(*layers)

        # Initialize weights with a reduced Xavier gain to keep early Q-value estimates small
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.5)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.net(x)

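    # Epsilon-greedy action selection: with probability epsilon take a random
    # action for exploration, otherwise pick the greedy argmax over Q-values.
    # Action indices follow the environment's convention: 0=FLAT, 1=LONG, 2=SHORT.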
    def get_action(self, state, epsilon=0):
        if random.random() < epsilon:
            return random.randint(0, 2)
        with torch.no_grad():
            state_t = torch.FloatTensor(state).unsqueeze(0).to(device)
            q_values = self(state_t)
            return q_values.argmax(1).item()


class TradingEnvV2:
    """
    Trading environment with the exact 24 features used by the production bot
    """
    def __init__(self, prices, initial_capital=10000, regime='ranging'):
        self.prices = prices
        self.initial_capital = initial_capital
        self.fee_rate = 0.0026  # 0.26% like production
        self.slippage = 0.0005  # 0.05%
        self.regime = regime
        self.max_drawdown_penalty = 0.2  # 20% drawdown triggers penalty
        self.reset()

    def reset(self):
        self.idx = 60  # Need 60 candles for lookback
        self.capital = self.initial_capital
        self.position = 0  # -1=short, 0=flat, 1=long
        self.entry_price = 0
        self.peak_capital = self.initial_capital
        self.trades = 0
        self.wins = 0
        self.total_pnl = 0
        self.consecutive_losses = 0
        return self.get_state()

    def get_state(self):
        """Extract EXACTLY 24 features matching production bot"""
        if self.idx >= len(self.prices):
            return None

        price = self.prices[self.idx]
        features = []
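        # Feature layout (24 total): 4 returns + 3 volatility + 3 flow
        # + 2 spread/liquidity + 2 entropy/quality + 1 MA momentum + 1 RSI
        # + 1 flow momentum + 1 volatility regime = 18 market features,
        # followed by 6 position/account features appended at the end.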

        # Get price history
        hist = self.prices[max(0, self.idx-60):self.idx+1]

        # === 1. Returns at multiple timeframes (4 features) ===
        for period in [1, 5, 10, 20]:
            if len(hist) > period:
                ret = (price - hist[-1-period]) / hist[-1-period]
                features.append(max(-0.1, min(0.1, ret)) * 10)
            else:
                features.append(0)

        # === 2. Volatility features (3 features) ===
        if len(hist) >= 20:
            # Realized volatility over the most recent 20 prices
            recent = hist[-20:]
            returns = [(recent[i] - recent[i-1]) / max(recent[i-1], 0.0001) for i in range(1, len(recent))]
            rv = np.sqrt(np.mean(np.array(returns) ** 2)) * np.sqrt(252 * 24 * 60) * 100
            features.append(min(2, rv / 50))  # rv1m
            features.append(min(2, rv / 50))  # rv5m (approx)
            features.append(min(2, abs(returns[-1]) * 100 / 30))  # vov proxy (latest absolute return)
        else:
            features.extend([0, 0, 0])

        # === 3. Flow features (3 features) - simulated from price action ===
        if len(hist) >= 5:
            flow = (price - hist[-5]) / hist[-5] * 50
            flow = max(-1, min(1, flow))
        else:
            flow = 0
        features.append(flow)  # flowImbalance
        features.append(0)  # burstScore
        features.append(flow * 0.8)  # obi

        # === 4. Spread and liquidity (2 features) ===
        features.append(0.2)  # spreadBps normalized
        features.append(0.5)  # intensity

        # === 5. Entropy and data quality (2 features) ===
        features.append(0.5)  # entropyH
        features.append(1)  # dataQuality
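        # Note: sections 3-5 use price-only proxies or constants because this
        # offline dataset carries no order-book, spread, or trade-flow data;
        # the production bot presumably fills these slots from live market data.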

        # === 6. Momentum - MA crossover (1 feature) ===
        if len(hist) >= 10:
            ma5 = np.mean(hist[-5:])
            ma10 = np.mean(hist[-10:])
            features.append(max(-1, min(1, (ma5 - ma10) / ma10 * 100)))
        else:
            features.append(0)

        # === 7. RSI (1 feature) ===
        if len(hist) >= 15:
            gains, losses = [], []
            # 14 price changes over the most recent 15 prices (standard RSI-14 lookback)
            for i in range(1, 15):
                change = hist[-i] - hist[-i-1]
                if change > 0:
                    gains.append(change)
                elif change < 0:
                    losses.append(abs(change))
            avg_gain = sum(gains) / 14.0 if gains else 0.0
            avg_loss = sum(losses) / 14.0 if losses else 0.0
            # Simple-average RSI (not Wilder-smoothed); guard every division by zero
            if avg_loss == 0:
                rsi = 100.0 if avg_gain > 0 else 50.0
            elif avg_gain == 0:
                rsi = 0.0
            else:
                rs = avg_gain / avg_loss
                rsi = 100.0 - (100.0 / (1.0 + rs))
            features.append((rsi - 50.0) / 50.0)
        else:
            features.append(0)

        # === 8. Flow momentum (1 feature) ===
        if len(hist) >= 6:
            flow_now = (price - hist[-2]) / max(hist[-2], 0.0001) * 50
            flow_5ago = (hist[-5] - hist[-6]) / max(hist[-6], 0.0001) * 50
            features.append(max(-1, min(1, flow_now - flow_5ago)))
        else:
            features.append(0)

        # === 9. Volatility regime (1 feature) ===
        features.append(1)

        # === POSITION INFO (6 features) ===

        # Position normalized
        features.append(self.position)  # Already -1, 0, or 1

        # Unrealized PnL
        if self.position != 0 and self.entry_price > 0:
            if self.position > 0:
                upnl = (price - self.entry_price) / self.entry_price
            else:
                upnl = (self.entry_price - price) / self.entry_price
            features.append(max(-0.1, min(0.1, upnl)) * 10)
        else:
            features.append(0)

        # Time in position (normalized)
        features.append(0.5 if self.position != 0 else 0)

        # Drawdown
        current_capital = self.capital
        if self.position != 0 and self.entry_price > 0:
            if self.position > 0:
                current_capital += self.capital * 0.15 * 5 * (price - self.entry_price) / self.entry_price
            else:
                current_capital += self.capital * 0.15 * 5 * (self.entry_price - price) / self.entry_price
        dd = (self.peak_capital - current_capital) / self.peak_capital if self.peak_capital > 0 else 0
        features.append(min(1, max(0, dd) * 10))

        # Capital utilization
        features.append(abs(self.position))

        # Regime indicator
        regime_val = 0
        if len(hist) >= 20:
            ret20 = (price - hist[-20]) / hist[-20]
            if ret20 > 0.01:
                regime_val = 1  # bullish
            elif ret20 < -0.01:
                regime_val = -1  # bearish
        features.append(regime_val)

        return features

    def step(self, action):
        """
        Execute action: 0=FLAT, 1=LONG, 2=SHORT
        Reward based on PnL with drawdown penalty
        """
        price = self.prices[self.idx]
        reward = 0

        # Current position value
        position_value = self.capital * 0.15 * 5  # 15% equity * 5x leverage
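        # All PnL, fees, and equity marks below are computed against this fixed
        # notional (15% of current capital at 5x leverage); capital compounds as
        # realized PnL is added back after each close.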

        # === CLOSE existing position if needed ===
        if self.position != 0:
            # Calculate PnL
            if self.position > 0:  # Long position
                exit_price = price * (1 - self.slippage)
                pnl_pct = (exit_price - self.entry_price) / self.entry_price
            else:  # Short position
                exit_price = price * (1 + self.slippage)
                pnl_pct = (self.entry_price - exit_price) / self.entry_price

            # Check if we should close
            should_close = (
                action == 0 or  # FLAT requested
                (action == 1 and self.position == -1) or  # LONG but we're short
                (action == 2 and self.position == 1)  # SHORT but we're long
            )
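            # Flipping sides (e.g. LONG requested while short) is treated as a
            # close here followed by a fresh open further below in the same step,
            # paying slippage and a fee on each leg.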

            if should_close:
                # Close position
                gross_pnl = position_value * pnl_pct
                fee = position_value * self.fee_rate
                net_pnl = gross_pnl - fee

                self.capital += net_pnl
                self.total_pnl += net_pnl
                self.trades += 1

                if net_pnl > 0:
                    self.wins += 1
                    self.consecutive_losses = 0
                else:
                    self.consecutive_losses += 1

                # Reward = net PnL percentage * 100
                reward = pnl_pct * 100 - self.fee_rate * 100

                # Bonus for profitable trades
                if net_pnl > 0:
                    reward += 0.5

                # Penalty for consecutive losses
                if self.consecutive_losses >= 3:
                    reward -= 0.3

                self.position = 0
                self.entry_price = 0

        # === OPEN new position if requested and currently flat ===
        if self.position == 0:
            if action == 1:  # LONG
                self.position = 1
                self.entry_price = price * (1 + self.slippage)
                self.capital -= position_value * self.fee_rate
            elif action == 2:  # SHORT
                self.position = -1
                self.entry_price = price * (1 - self.slippage)
                self.capital -= position_value * self.fee_rate

        # Update peak capital for drawdown tracking
        current_equity = self.capital
        if self.position != 0 and self.entry_price > 0:
            if self.position > 0:
                current_equity += position_value * (price - self.entry_price) / self.entry_price
            else:
                current_equity += position_value * (self.entry_price - price) / self.entry_price

        if current_equity > self.peak_capital:
            self.peak_capital = current_equity

        # Drawdown penalty
        dd = (self.peak_capital - current_equity) / self.peak_capital if self.peak_capital > 0 else 0
        if dd > self.max_drawdown_penalty:
            reward -= (dd - self.max_drawdown_penalty) * 10

        # Move to next timestep
        self.idx += 1
        next_state = self.get_state()
        done = next_state is None

        # Close position at end
        if done and self.position != 0:
            final_price = self.prices[-1]
            if self.position > 0:
                pnl_pct = (final_price * (1 - self.slippage) - self.entry_price) / self.entry_price
            else:
                pnl_pct = (self.entry_price - final_price * (1 + self.slippage)) / self.entry_price

            gross_pnl = position_value * pnl_pct
            fee = position_value * self.fee_rate
            net_pnl = gross_pnl - fee
            self.capital += net_pnl
            self.trades += 1
            if net_pnl > 0:
                self.wins += 1
            reward += pnl_pct * 100 - self.fee_rate * 100

        return next_state, reward, done


class PrioritizedReplayBuffer:
    """Replay buffer with priority sampling"""
    def __init__(self, max_size=50000):
        self.buffer = deque(maxlen=max_size)
        self.priorities = deque(maxlen=max_size)

    def add(self, exp, priority=1.0):
        self.buffer.append(exp)
        self.priorities.append(priority)

    def sample(self, n):
        if len(self.buffer) < n:
            return list(self.buffer)

        # Sample with priority
        probs = np.array(self.priorities)
        probs = probs / probs.sum()
        indices = np.random.choice(len(self.buffer), size=n, p=probs, replace=False)
        return [self.buffer[i] for i in indices]

    def __len__(self):
        return len(self.buffer)


def train_model_v2(prices, model_type, episodes=150, lr=0.0005):
    """Train model with improved DQN"""
    print(f"\n{'='*70}")
    print(f"   Training {model_type.upper()} model (V2 - 24 features)")
    print(f"{'='*70}")

    if len(prices) < 1000:
        print("Not enough data (need 1000+)")
        return None, None

    buy_hold = (prices[-1] - prices[0]) / prices[0] * 100
    print(f"Buy & Hold: {buy_hold:+.2f}%")

    # Create models
    model = TradingNN(24, [64, 32], 3).to(device)
    target_model = TradingNN(24, [64, 32], 3).to(device)
    target_model.load_state_dict(model.state_dict())

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.9)
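    # With step_size=100 the learning rate only decays after 100 episodes, so the
    # 50-episode runs launched in the main loop effectively train at a constant lr.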
    buffer = PrioritizedReplayBuffer(50000)

    # Hyperparameters
    gamma = 0.97
    epsilon = 1.0
    epsilon_decay = 0.995
    epsilon_min = 0.05
    batch_size = 128

    best_sharpe = float('-inf')
    best_model_state = None

    # Training loop
    train_every = 100  # Train every N steps (not every step!)
    step_count = 0

    for ep in range(episodes):
        env = TradingEnvV2(prices, regime=model_type)
        state = env.reset()
        episode_reward = 0

        while state is not None:
            action = model.get_action(state, epsilon)
            next_state, reward, done = env.step(action)

            if next_state is not None:
                # Priority based on reward magnitude
                priority = abs(reward) + 0.1
                buffer.add((state, action, reward, next_state, done), priority)

            episode_reward += reward
            step_count += 1

            # Training step - only every N steps!
            if step_count % train_every == 0 and len(buffer) >= batch_size:
                batch = buffer.sample(batch_size)
                states = torch.FloatTensor([e[0] for e in batch]).to(device)
                actions = torch.LongTensor([e[1] for e in batch]).to(device)
                rewards = torch.FloatTensor([e[2] for e in batch]).to(device)
                next_states = torch.FloatTensor([e[3] for e in batch]).to(device)
                dones = torch.FloatTensor([e[4] for e in batch]).to(device)

                # Double DQN
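                # (the online network selects the next action, the target network
                # evaluates it, reducing vanilla DQN's Q-value overestimation)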
                current_q = model(states).gather(1, actions.unsqueeze(1))
                with torch.no_grad():
                    next_actions = model(next_states).argmax(1)
                    next_q = target_model(next_states).gather(1, next_actions.unsqueeze(1)).squeeze()
                    target_q = rewards + gamma * next_q * (1 - dones)

                loss = F.smooth_l1_loss(current_q.squeeze(), target_q)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            state = next_state

        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        scheduler.step()

        # Update target network
        if ep % 10 == 0:
            target_model.load_state_dict(model.state_dict())

        # Evaluation every 30 episodes - greedy policy (epsilon=0)
        if (ep + 1) % 30 == 0:
            eval_env = TradingEnvV2(prices, regime=model_type)
            eval_state = eval_env.reset()
            returns = []

            while eval_state is not None:
                action = model.get_action(eval_state, 0)
                eval_state, r, _ = eval_env.step(action)
                if r != 0:
                    returns.append(r)

            total_return = (eval_env.capital - 10000) / 10000 * 100
            win_rate = eval_env.wins / eval_env.trades * 100 if eval_env.trades > 0 else 0

            # Sharpe-like metric: mean per-close reward / std, scaled by sqrt(trade count); not annualized
            if len(returns) > 1:
                sharpe = np.mean(returns) / (np.std(returns) + 0.001) * np.sqrt(len(returns))
            else:
                sharpe = total_return

            marker = '*' if sharpe > best_sharpe else ' '
            print(f"{marker} Ep {ep+1:3d}: Return={total_return:+7.2f}%, Trades={eval_env.trades:3d}, Win={win_rate:4.1f}%, Sharpe={sharpe:.2f}")

            if sharpe > best_sharpe and eval_env.trades >= 5:
                best_sharpe = sharpe
                best_model_state = {k: v.clone() for k, v in model.state_dict().items()}

    # Load best model
    if best_model_state:
        model.load_state_dict(best_model_state)

    # Final evaluation
    final_env = TradingEnvV2(prices, regime=model_type)
    state = final_env.reset()
    actions_taken = {0: 0, 1: 0, 2: 0}

    while state is not None:
        action = model.get_action(state, 0)
        actions_taken[action] += 1
        state, _, _ = final_env.step(action)

    final_return = (final_env.capital - 10000) / 10000 * 100
    win_rate = final_env.wins / final_env.trades * 100 if final_env.trades > 0 else 0

    print(f"\n{'-'*70}")
    print(f"   FINAL - {model_type.upper()}")
    print(f"{'-'*70}")
    print(f"   Return:     {final_return:+.2f}%")
    print(f"   Buy & Hold: {buy_hold:+.2f}%")
    print(f"   {'✓ BEATS B&H' if final_return > buy_hold else '✗ B&H wins'}")
    print(f"   Trades: {final_env.trades}, Win Rate: {win_rate:.1f}%")
    print(f"   Actions: FLAT={actions_taken[0]}, LONG={actions_taken[1]}, SHORT={actions_taken[2]}")

    return model, {
        'model_type': model_type,
        'return': final_return,
        'buy_hold': buy_hold,
        'trades': final_env.trades,
        'win_rate': win_rate,
        'actions': actions_taken
    }


def export_model(model, model_type, metrics):
    """Export model to JSON format matching production bot"""
    layers = []
    state_dict = model.state_dict()
    keys = list(state_dict.keys())
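    # Only the nn.Linear modules carry parameters (LeakyReLU/Dropout have none),
    # so state_dict keys alternate weight/bias per layer and can be paired in order.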

    i = 0
    while i < len(keys):
        key = keys[i]
        if 'weight' in key:
            weights = state_dict[key].cpu().numpy().tolist()
            bias_key = keys[i+1] if i+1 < len(keys) else None
            biases = state_dict[bias_key].cpu().numpy().tolist() if bias_key and 'bias' in bias_key else []
            layers.append({'weights': weights, 'biases': biases})
            i += 2 if bias_key and 'bias' in bias_key else 1
        else:
            i += 1

    model_data = {
        'type': 'PureNN',
        'network': {'layers': layers},
        'metrics': {
            'totalReturn': metrics['return'],
            'buyHoldReturn': metrics['buy_hold'],
            'winRate': metrics['win_rate'],
            'totalTrades': metrics['trades'],
            'modelType': model_type,
            'trainedAt': datetime.now().isoformat(),
            'trainedOn': 'Kaggle GPU V2'
        }
    }

    filename = f"/kaggle/working/model_{model_type}_v2.json"
    with open(filename, 'w') as f:
        json.dump(model_data, f)
    print(f"✓ Saved {filename}")
    return filename


# === MAIN TRAINING ===
print("\n" + "=" * 70)
print("   Starting Training V2 - All 5 Regime Models")
print("=" * 70)

results = []
model_files = []

# Train each regime model
for regime in ['bullish', 'bearish', 'ranging', 'volatile', 'scalper']:
    trained_model, metrics = train_model_v2(prices[-2000:], regime, episodes=50, lr=0.001)
    if trained_model and metrics:
        filename = export_model(trained_model, regime, metrics)
        model_files.append(filename)
        results.append(metrics)

# Create combined output
print(f"\n{'='*70}")
print("   TRAINING SUMMARY")
print(f"{'='*70}")

total_return = 0
for r in results:
    emoji = '✓' if r['return'] > r['buy_hold'] else '✗'
    print(f"   {emoji} {r['model_type'].upper():10}: {r['return']:+7.2f}% ({r['trades']} trades, {r['win_rate']:.0f}% win)")
    total_return += r['return']

avg_return = total_return / len(results) if results else 0
print(f"\n   Average Return: {avg_return:+.2f}%")

# Create combined model file for import
combined = {
    'pair': 'PAIR_PLACEHOLDER',
    'type': 'regime_models',
    'trainedAt': datetime.now().isoformat(),
    'trainedOn': 'Kaggle GPU V2',
    'regimeModels': {}
}

for regime in ['bullish', 'bearish', 'ranging', 'volatile', 'scalper']:
    try:
        with open(f"/kaggle/working/model_{regime}_v2.json", 'r') as f:
            combined['regimeModels'][regime] = json.load(f)
    except FileNotFoundError:
        pass  # this regime's model was not trained/exported

with open('/kaggle/working/combined_model_v2.json', 'w') as f:
    json.dump(combined, f)

print("\n✓ Created combined_model_v2.json")
print("\nDone!")
