"""
BestTrading FAST GPU Training - Uses Pre-computed Features
No feature computation during training = FAST!
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import json
import os
from collections import deque
import random

print("=" * 70)
print("   BestTrading FAST GPU Training - Pre-computed Features")
print("=" * 70)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Load PRE-COMPUTED features (no computation needed!)
print("\nLoading pre-computed data...")
features = np.load('/kaggle/input/bestrading-prices/features.npy')
prices = np.load('/kaggle/input/bestrading-prices/prices_aligned.npy')

# Fail fast on malformed inputs instead of crashing deep inside training:
# FastTradingEnv reads features[idx] for every idx < len(prices) and overwrites
# columns 18-23, TradingNN takes exactly 24 inputs, and prices are used as
# divisors (returns, PnL, buy & hold).
if features.ndim != 2 or features.shape[1] != 24:
    raise ValueError(f"features must have shape (N, 24), got {features.shape}")
if prices.ndim != 1 or prices.size == 0:
    raise ValueError(f"prices must be a non-empty 1-D array, got shape {prices.shape}")
if len(features) < len(prices):
    raise ValueError(
        f"features has {len(features)} rows but prices has {len(prices)} points; "
        "need at least one feature row per price"
    )
if np.any(prices <= 0):
    raise ValueError("prices must be strictly positive")

print(f"Features shape: {features.shape}")
print(f"Prices: {len(prices)} points, range {prices.min():.2f} - {prices.max():.2f}")


class TradingNN(nn.Module):
    """Compact Q-network: 24 state features -> 64 -> 32 -> 3 action values."""

    def __init__(self):
        super().__init__()
        # All layers live in one Sequential named `net`, so state_dict keys are
        # 'net.0.*', 'net.2.*' and 'net.4.*'. The JSON export keys weights by them.
        layers = [
            nn.Linear(24, 64), nn.LeakyReLU(0.01),
            nn.Linear(64, 32), nn.LeakyReLU(0.01),
            nn.Linear(32, 3),
        ]
        self.net = nn.Sequential(*layers)

        # Re-initialise each Linear with a down-scaled (gain 0.5) Xavier-uniform
        # weight and a zero bias, replacing PyTorch's default init.
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight, gain=0.5)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Return one Q-value per action (FLAT, LONG, SHORT) for input ``x``."""
        q_values = self.net(x)
        return q_values


class FastTradingEnv:
    """
    Fast environment using pre-computed features.

    Market features (columns 0-17) are read straight from ``features``. Only
    the six position/account features (columns 18-23) are rewritten each step.
    Actions: 0=FLAT, 1=LONG, 2=SHORT.

    Capital is only updated when a position is closed. A position that is
    still open when the data runs out is therefore never marked to market.
    """

    def __init__(self, features, prices, fee_rate=0.0026, initial_capital=10000):
        """
        Args:
            features: Pre-computed feature rows indexed by step. Needs at least
                ``len(prices)`` rows and at least 24 columns (columns 18-23
                are overwritten).
            prices: Price per step; the episode length is ``len(prices)``.
            fee_rate: Proportional fee charged on entry and again on exit.
            initial_capital: Starting account value. It is also the reference
                for the capital-ratio feature (column 23).
        """
        self.base_features = features  # Pre-computed [N, 24]
        self.prices = prices
        self.fee_rate = fee_rate
        self.initial_capital = initial_capital
        self.n_steps = len(prices)
        self.reset()

    def reset(self):
        """Rewind to step 0 with a flat position and a fresh account; return the first state."""
        self.idx = 0
        self.position = 0  # -1=short, 0=flat, 1=long
        self.entry_price = 0
        self.capital = self.initial_capital
        self.peak_capital = self.initial_capital
        self.trades = 0
        self.wins = 0
        return self._get_state()

    def _get_state(self):
        """Return the observation for the current step, or None once the data is exhausted.

        Columns 0-17 are copied from the pre-computed features. Columns 18-23
        are overwritten as follows:
            18: position (-1/0/1)
            19: unrealized PnL clipped to +/-10%, scaled x10
            20: time in position (always 0; not tracked)
            21: drawdown from peak, scaled x10, capped at 1
            22: win rate (0.5 before any trade)
            23: capital / initial_capital - 1
        """
        if self.idx >= self.n_steps:
            return None

        state = self.base_features[self.idx].copy()
        state[18] = self.position

        if self.position != 0 and self.entry_price > 0:
            price = self.prices[self.idx]
            if self.position > 0:
                upnl = (price - self.entry_price) / self.entry_price
            else:
                upnl = (self.entry_price - price) / self.entry_price
            state[19] = max(-0.1, min(0.1, upnl)) * 10
        else:
            state[19] = 0

        state[20] = 0  # placeholder: time in position is not tracked

        dd = (self.peak_capital - self.capital) / self.peak_capital
        state[21] = min(1, dd * 10)

        state[22] = 0.5 if self.trades == 0 else self.wins / self.trades

        state[23] = (self.capital / self.initial_capital) - 1

        return state

    def _close_position(self, price):
        """Close the open position at ``price`` and book the result.

        PnL is measured against ``entry_price``, which already includes the
        entry fee, and the exit fee is subtracted. Capital, the trade count
        and the win count are updated. ``position`` and ``entry_price`` are
        left for the caller to set.

        Returns:
            The trade reward, i.e. PnL (as a fraction) scaled by 100.
        """
        if self.position > 0:
            pnl = (price - self.entry_price) / self.entry_price - self.fee_rate
        else:
            pnl = (self.entry_price - price) / self.entry_price - self.fee_rate

        self.capital *= (1 + pnl)
        self.trades += 1
        if pnl > 0:
            self.wins += 1
        return pnl * 100

    def step(self, action):
        """Apply ``action`` (0=FLAT, 1=LONG, 2=SHORT) at the current price.

        Returns:
            (next_state, reward, done). ``next_state`` is None and ``done`` is
            True once the last price has been consumed. Calling again after
            that returns (None, 0, True).
        """
        if self.idx >= self.n_steps:
            return None, 0, True

        price = self.prices[self.idx]
        reward = 0

        if action == 0:  # FLAT
            if self.position != 0:
                reward = self._close_position(price)
                self.position = 0
                self.entry_price = 0

        elif action == 1:  # LONG
            if self.position == -1:  # close the short first (a reversal)
                reward += self._close_position(price)
            if self.position != 1:
                self.position = 1
                # Entry fee is folded into the entry price.
                self.entry_price = price * (1 + self.fee_rate)

        elif action == 2:  # SHORT
            if self.position == 1:  # close the long first (a reversal)
                reward += self._close_position(price)
            if self.position != -1:
                self.position = -1
                # A lower short entry price has the same effect as paying the fee.
                self.entry_price = price * (1 - self.fee_rate)

        self.peak_capital = max(self.peak_capital, self.capital)

        # Shaping reward for holding: 50x the move from this price to the next
        # one, in the direction of the position now held.
        if self.position != 0 and self.idx + 1 < self.n_steps:
            next_price = self.prices[self.idx + 1]
            if self.position > 0:
                reward += (next_price - price) / price * 50
            else:
                reward += (price - next_price) / price * 50

        self.idx += 1
        next_state = self._get_state()
        done = next_state is None

        return next_state, reward, done


def _dqn_update(model, target, optimizer, batch, gamma):
    """Apply one DQN gradient step to ``model`` using a sampled minibatch.

    Args:
        model: Online Q-network being optimised.
        target: Target Q-network used (without gradients) for the bootstrap term.
        optimizer: Optimizer over ``model``'s parameters.
        batch: Sequence of (state, action, reward, next_state, done) tuples.
        gamma: Discount factor.
    """
    def _to_tensor(values, dtype):
        # Stack into one ndarray first: building a tensor from a list of
        # ndarrays is very slow in PyTorch.
        return torch.as_tensor(np.asarray(values, dtype=dtype), device=device)

    states, actions, rewards, next_states, dones = zip(*batch)
    states_t = _to_tensor(states, np.float32)
    actions_t = _to_tensor(actions, np.int64)
    rewards_t = _to_tensor(rewards, np.float32)
    next_states_t = _to_tensor(next_states, np.float32)
    dones_t = _to_tensor(dones, np.float32)

    current_q = model(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_q = target(next_states_t).max(1)[0]
        # (1 - done) removes the bootstrap term on terminal transitions.
        target_q = rewards_t + gamma * next_q * (1 - dones_t)

    loss = F.smooth_l1_loss(current_q, target_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def _greedy_eval(model):
    """Run one greedy (argmax-Q) pass over the full data set.

    Returns:
        (env, actions_count). ``env`` is the finished FastTradingEnv, whose
        capital, trades and wins describe the run. ``actions_count`` maps each
        action (0=FLAT, 1=LONG, 2=SHORT) to how often it was chosen.
    """
    env = FastTradingEnv(features, prices)
    state = env.reset()
    actions_count = {0: 0, 1: 0, 2: 0}
    with torch.no_grad():
        while state is not None:
            q = model(torch.as_tensor(state, dtype=torch.float32, device=device))
            action = q.argmax().item()
            actions_count[action] += 1
            state, _, _ = env.step(action)
    return env, actions_count


def train_fast(model_type, episodes=100):
    """Train a DQN agent on the pre-computed features and return it with metrics.

    Each episode is one epsilon-greedy pass over FastTradingEnv. Transitions go
    into a replay buffer, and the network gets 10 minibatch updates after each
    episode. The target network is synced every 5 episodes. Every 20 episodes
    the current weights are scored with a greedy pass. The best-scoring weights
    with at least 5 trades are kept and restored before the final greedy
    evaluation.

    NOTE(review): ``model_type`` only appears in log output. Every call trains
    on the same module-level ``features``/``prices`` with the same
    hyper-parameters.

    Args:
        model_type: Label used in log output (e.g. 'bullish').
        episodes: Number of full passes over the data.

    Returns:
        (model, metrics). ``metrics`` has the keys 'return' (percent),
        'trades', 'win_rate' (percent) and 'actions' (action counts from the
        final greedy pass).
    """
    print(f"\n{'='*60}")
    print(f"   Training {model_type.upper()} (FAST)")
    print(f"{'='*60}")

    model = TradingNN().to(device)
    target = TradingNN().to(device)
    target.load_state_dict(model.state_dict())

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    buffer = deque(maxlen=20000)

    gamma = 0.95
    epsilon = 1.0
    batch_size = 256
    updates_per_episode = 10

    best_return = float('-inf')
    best_state = None

    for ep in range(episodes):
        env = FastTradingEnv(features, prices)
        state = env.reset()

        while state is not None:
            # Epsilon-greedy action selection
            if random.random() < epsilon:
                action = random.randint(0, 2)
            else:
                with torch.no_grad():
                    q = model(torch.as_tensor(state, dtype=torch.float32, device=device))
                    action = q.argmax().item()

            next_state, reward, done = env.step(action)

            # Also store the terminal transition. Its next_state is a zero
            # placeholder that the (1 - done) mask removes from the target.
            stored_next = next_state if next_state is not None else np.zeros_like(state)
            buffer.append((state, action, reward, stored_next, done))

            state = next_state

        if len(buffer) >= batch_size:
            for _ in range(updates_per_episode):
                _dqn_update(model, target, optimizer, random.sample(buffer, batch_size), gamma)

        epsilon = max(0.05, epsilon * 0.98)

        if ep % 5 == 0:
            target.load_state_dict(model.state_dict())

        if (ep + 1) % 20 == 0:
            # Score the current weights greedily. The exploratory episode just
            # played is dominated by random actions while epsilon is high.
            eval_env, _ = _greedy_eval(model)
            total_return = (eval_env.capital - 10000) / 100
            trades = eval_env.trades
            win_rate = eval_env.wins / trades * 100 if trades > 0 else 0

            # The marker flags exactly the checkpoints that get saved.
            is_best = total_return > best_return and trades >= 5
            marker = '*' if is_best else ' '
            print(f"{marker} Ep {ep+1:3d}: Return={total_return:+6.2f}%, Trades={trades:3d}, Win={win_rate:4.1f}%")

            if is_best:
                best_return = total_return
                best_state = {k: v.clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    env, actions_count = _greedy_eval(model)
    final_return = (env.capital - 10000) / 100
    win_rate = env.wins / env.trades * 100 if env.trades > 0 else 0

    print(f"\n   FINAL {model_type.upper()}: {final_return:+.2f}%, {env.trades} trades, {win_rate:.1f}% win")
    print(f"   Actions: FLAT={actions_count[0]}, LONG={actions_count[1]}, SHORT={actions_count[2]}")

    return model, {
        'return': final_return,
        'trades': env.trades,
        'win_rate': win_rate,
        'actions': actions_count
    }


# Main script: train one model per regime label, export each model's weights
# as JSON, then compare every final return against buy & hold.
# NOTE(review): train_fast is given only the label; every regime trains on the
# same module-level data.


def _print_banner(title):
    """Print ``title`` between two 70-char '=' rules, preceded by a blank line."""
    rule = "=" * 70
    print("\n" + rule)
    print(title)
    print(rule)


_print_banner("   Training All Models")

results = {}
all_models = {}

for regime_name in ('bullish', 'bearish', 'ranging', 'volatile', 'scalper'):
    trained_model, regime_metrics = train_fast(regime_name, episodes=100)
    results[regime_name] = regime_metrics
    all_models[regime_name] = trained_model

_print_banner("   Saving Models")

for regime_name, trained_model in all_models.items():
    # Layer sizes plus every parameter tensor as nested lists, keyed by its
    # state_dict name (e.g. 'net.0.weight').
    payload = {
        'input_size': 24,
        'hidden_sizes': [64, 32],
        'output_size': 3,
        'weights': {
            param_name: param.cpu().numpy().tolist()
            for param_name, param in trained_model.state_dict().items()
        },
    }

    out_path = f'/kaggle/working/model_{regime_name}_v2.json'
    with open(out_path, 'w') as fh:
        json.dump(payload, fh)
    print(f"✓ Saved {out_path}")

_print_banner("   SUMMARY")

buy_hold = (prices[-1] - prices[0]) / prices[0] * 100
print(f"Buy & Hold: {buy_hold:+.2f}%\n")

for regime_name, regime_metrics in results.items():
    verdict = "✓" if regime_metrics['return'] > buy_hold else "✗"
    print(f"{regime_name.upper():10s}: {regime_metrics['return']:+6.2f}% | {regime_metrics['trades']:3d} trades | {regime_metrics['win_rate']:5.1f}% win | {verdict}")

_print_banner("   DONE!")
