import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from collections import deque
import random
import json
import os
from datetime import datetime

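# DQN training script: trains three regime-specific trading agents (ranging,
# volatile, scalper) on a single price series and exports each network as JSON.
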
print("=" * 60)
print("   BestTrading GPU Training on Kaggle")
print("=" * 60)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("WARNING: No GPU available, using CPU")

# Load price history (CSV with a 'price' column)
df = pd.read_csv('/kaggle/input/bestrading-sol-eur-prices/prices.csv')
prices = df['price'].values.tolist()
print(f"\nLoaded {len(prices)} price points")
print(f"Range: {min(prices):.2f} - {max(prices):.2f}")

class TradingNN(nn.Module):
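    """Small fully connected Q-network mapping a feature vector to Q-values for
    three actions: 0 = close/stay flat, 1 = go long, 2 = go short."""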
    def __init__(self, input_size, hidden_sizes, output_size=3):
        super().__init__()
        layers = []
        prev_size = input_size
        for h in hidden_sizes:
            layers.append(nn.Linear(prev_size, h))
            layers.append(nn.LeakyReLU(0.01))
            prev_size = h
        layers.append(nn.Linear(prev_size, output_size))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

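    # Epsilon-greedy action selection: explore with probability `epsilon`,
    # otherwise pick the action with the highest predicted Q-value.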
    def get_action(self, state, epsilon=0):
        if random.random() < epsilon:
            return random.randint(0, 2)
        with torch.no_grad():
            state_t = torch.FloatTensor(state).unsqueeze(0).to(device)
            q_values = self(state_t)
            return q_values.argmax(1).item()

class RangingTradingEnv:
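    """Mean-reversion style environment: state is short-lookback returns, a 20-bar
    z-score and relative std, a 14-period RSI, the current position and open PnL
    (10 features). Actions: 0 = close/stay flat, 1 = long, 2 = short; 0.1% fee per side."""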
    def __init__(self, prices, initial_capital=10000):
        self.prices = prices
        self.initial_capital = initial_capital
        self.fee_rate = 0.001
        self.reset()

    def reset(self):
        self.idx = 30
        self.capital = self.initial_capital
        self.position = 0
        self.entry_price = 0
        self.trades = 0
        self.wins = 0
        return self.get_state()

    def get_state(self):
        if self.idx >= len(self.prices):
            return None
        price = self.prices[self.idx]
        features = []
        for lb in [1, 2, 3, 5, 10]:
            if self.idx >= lb:
                features.append((price - self.prices[self.idx - lb]) / self.prices[self.idx - lb] * 100)
            else:
                features.append(0)
        if self.idx >= 20:
            window = self.prices[self.idx - 20:self.idx]
            mean = np.mean(window)
            std = np.std(window)
            z_score = (price - mean) / std if std > 0 else 0
            features.append(max(-3, min(3, z_score)))
            features.append(std / mean * 100)
        else:
            features.extend([0, 0])
        if self.idx >= 15:
            gains, losses = 0, 0
            for i in range(1, 15):
                change = self.prices[self.idx - i + 1] - self.prices[self.idx - i]
                if change > 0: gains += change
                else: losses -= change
            rs = gains / losses if losses > 0 else 100
            rsi = 100 - (100 / (1 + rs))
            features.append((rsi - 50) / 50)
        else:
            features.append(0)
        features.append(self.position)
        pnl = 0
        if self.position != 0 and self.entry_price > 0:
            pnl = self.position * (price - self.entry_price) / self.entry_price
        features.append(pnl * 100)
        return features

    def step(self, action):
        price = self.prices[self.idx]
        reward = 0
        if self.position != 0:
            pnl = self.position * (price - self.entry_price) / self.entry_price
            if action == 0 or (action == 1 and self.position == -1) or (action == 2 and self.position == 1):
                net_pnl = pnl - self.fee_rate
                self.capital *= (1 + net_pnl)
                self.trades += 1
                if net_pnl > 0: self.wins += 1
                reward = net_pnl * 100
                self.position = 0
                self.entry_price = 0
        if self.position == 0:
            if action == 1:
                self.position = 1
                self.entry_price = price
                self.capital *= (1 - self.fee_rate)
            elif action == 2:
                self.position = -1
                self.entry_price = price
                self.capital *= (1 - self.fee_rate)
        self.idx += 1
        next_state = self.get_state()
        done = next_state is None
        if done and self.position != 0:
            final_price = self.prices[-1]
            pnl = self.position * (final_price - self.entry_price) / self.entry_price
            net_pnl = pnl - self.fee_rate
            self.capital *= (1 + net_pnl)
            self.trades += 1
            if net_pnl > 0: self.wins += 1
            reward += net_pnl * 100
        return next_state, reward, done

class VolatileTradingEnv:
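    """Volatility-oriented environment: longer-lookback returns (up to 20 bars),
    20-bar realised volatility, recent-vs-older volatility change and the MA5/MA10
    spread, plus position and open PnL (10 features). Same actions and fees as
    RangingTradingEnv."""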
    def __init__(self, prices, initial_capital=10000):
        self.prices = prices
        self.initial_capital = initial_capital
        self.fee_rate = 0.001
        self.reset()

    def reset(self):
        self.idx = 30
        self.capital = self.initial_capital
        self.position = 0
        self.entry_price = 0
        self.trades = 0
        self.wins = 0
        return self.get_state()

    def get_state(self):
        if self.idx >= len(self.prices): return None
        price = self.prices[self.idx]
        features = []
        for lb in [1, 3, 5, 10, 20]:
            if self.idx >= lb:
                features.append((price - self.prices[self.idx - lb]) / self.prices[self.idx - lb] * 100)
            else:
                features.append(0)
        if self.idx >= 20:
            rets = [(self.prices[self.idx - i + 1] - self.prices[self.idx - i]) / self.prices[self.idx - i] for i in range(1, 21)]
            volatility = np.sqrt(np.mean(np.array(rets) ** 2)) * 100
            features.append(volatility)
            recent_vol = np.sqrt(np.mean(np.array(rets[:5]) ** 2))
            older_vol = np.sqrt(np.mean(np.array(rets[10:15]) ** 2))
            features.append((recent_vol - older_vol) / older_vol if older_vol > 0 else 0)
        else:
            features.extend([0, 0])
        if self.idx >= 10:
            ma5 = np.mean(self.prices[self.idx - 5:self.idx])
            ma10 = np.mean(self.prices[self.idx - 10:self.idx])
            features.append((ma5 - ma10) / ma10 * 100)
        else:
            features.append(0)
        features.append(self.position)
        pnl = 0
        if self.position != 0 and self.entry_price > 0:
            pnl = self.position * (price - self.entry_price) / self.entry_price
        features.append(pnl * 100)
        return features

    def step(self, action):
        price = self.prices[self.idx]
        reward = 0
        if self.position != 0:
            pnl = self.position * (price - self.entry_price) / self.entry_price
            if action == 0 or (action == 1 and self.position == -1) or (action == 2 and self.position == 1):
                net_pnl = pnl - self.fee_rate
                self.capital *= (1 + net_pnl)
                self.trades += 1
                if net_pnl > 0: self.wins += 1
                reward = net_pnl * 100
                self.position = 0
                self.entry_price = 0
        if self.position == 0:
            if action == 1:
                self.position = 1
                self.entry_price = price
                self.capital *= (1 - self.fee_rate)
            elif action == 2:
                self.position = -1
                self.entry_price = price
                self.capital *= (1 - self.fee_rate)
        self.idx += 1
        next_state = self.get_state()
        done = next_state is None
        if done and self.position != 0:
            final_price = self.prices[-1]
            pnl = self.position * (final_price - self.entry_price) / self.entry_price
            net_pnl = pnl - self.fee_rate
            self.capital *= (1 + net_pnl)
            self.trades += 1
            if net_pnl > 0: self.wins += 1
            reward += net_pnl * 100
        return next_state, reward, done

class ScalperTradingEnv:
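    """Scalping environment: very short lookback returns, 10-bar volatility, distance
    from MA5, position, open PnL and holding-time fraction (9 features). Positions are
    force-closed after `max_holding` bars; lower fee (0.08% per side) and rewards scaled
    by 150 instead of 100."""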
    def __init__(self, prices, initial_capital=10000):
        self.prices = prices
        self.initial_capital = initial_capital
        self.fee_rate = 0.0008
        self.max_holding = 30
        self.reset()

    def reset(self):
        self.idx = 20
        self.capital = self.initial_capital
        self.position = 0
        self.entry_price = 0
        self.entry_idx = 0
        self.trades = 0
        self.wins = 0
        return self.get_state()

    def get_state(self):
        if self.idx >= len(self.prices): return None
        price = self.prices[self.idx]
        features = []
        for lb in [1, 2, 3, 5]:
            if self.idx >= lb:
                features.append((price - self.prices[self.idx - lb]) / self.prices[self.idx - lb] * 100)
            else:
                features.append(0)
        if self.idx >= 10:
            rets = [(self.prices[self.idx - i + 1] - self.prices[self.idx - i]) / self.prices[self.idx - i] for i in range(1, 11)]
            features.append(np.sqrt(np.mean(np.array(rets) ** 2)) * 100)
        else:
            features.append(0)
        if self.idx >= 5:
            ma5 = np.mean(self.prices[self.idx - 5:self.idx])
            features.append((price - ma5) / ma5 * 100)
        else:
            features.append(0)
        features.append(self.position)
        pnl = 0
        if self.position != 0 and self.entry_price > 0:
            pnl = self.position * (price - self.entry_price) / self.entry_price
        features.append(pnl * 100)
        features.append(min(1, (self.idx - self.entry_idx) / self.max_holding) if self.position != 0 else 0)
        return features

    def step(self, action):
        price = self.prices[self.idx]
        reward = 0
        must_exit = self.position != 0 and (self.idx - self.entry_idx) >= self.max_holding
        if self.position != 0:
            pnl = self.position * (price - self.entry_price) / self.entry_price
            if must_exit or action == 0 or (action == 1 and self.position == -1) or (action == 2 and self.position == 1):
                net_pnl = pnl - self.fee_rate
                self.capital *= (1 + net_pnl)
                self.trades += 1
                if net_pnl > 0: self.wins += 1
                reward = net_pnl * 150
                self.position = 0
                self.entry_price = 0
        if self.position == 0 and not must_exit:
            if action == 1:
                self.position = 1
                self.entry_price = price
                self.entry_idx = self.idx
                self.capital *= (1 - self.fee_rate)
            elif action == 2:
                self.position = -1
                self.entry_price = price
                self.entry_idx = self.idx
                self.capital *= (1 - self.fee_rate)
        self.idx += 1
        next_state = self.get_state()
        done = next_state is None
        if done and self.position != 0:
            final_price = self.prices[-1]
            pnl = self.position * (final_price - self.entry_price) / self.entry_price
            net_pnl = pnl - self.fee_rate
            self.capital *= (1 + net_pnl)
            self.trades += 1
            if net_pnl > 0: self.wins += 1
            reward += net_pnl * 150
        return next_state, reward, done

class ReplayBuffer:
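    """Fixed-size FIFO buffer of (state, action, reward, next_state, done) experience tuples."""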
    def __init__(self, max_size=30000):
        self.buffer = deque(maxlen=max_size)
    def add(self, exp):
        self.buffer.append(exp)
    def sample(self, n):
        return random.sample(self.buffer, min(n, len(self.buffer)))
    def __len__(self):
        return len(self.buffer)

def train_model(prices, model_type, EnvClass, input_size, hidden_sizes, episodes=300):
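    """Train a DQN agent on `prices` with the given environment class and network shape.

    Uses a target network synced every 10 episodes, epsilon-greedy exploration with
    exponential decay, and a greedy evaluation pass every 50 episodes whose best
    result determines the weights that are kept. Returns (model, metrics), or
    (None, None) if there is not enough data.
    """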
    print(f"\n{'='*60}")
    print(f"   Training {model_type.upper()} model")
    print(f"{'='*60}")

    if len(prices) < 500:
        print("Not enough data")
        return None, None

    buy_hold = (prices[-1] - prices[0]) / prices[0] * 100
    print(f"Buy & Hold: {buy_hold:+.2f}%")

    model = TradingNN(input_size, hidden_sizes).to(device)
    target_model = TradingNN(input_size, hidden_sizes).to(device)
    target_model.load_state_dict(model.state_dict())

    optimizer = optim.Adam(model.parameters(), lr=0.0005)
    buffer = ReplayBuffer(30000)

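    # DQN hyperparameters: discount factor, epsilon-greedy schedule and minibatch size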
    gamma = 0.95
    epsilon = 1.0
    epsilon_decay = 0.996
    epsilon_min = 0.03
    batch_size = 64

    best_return = float('-inf')
    best_model_state = None

    for ep in range(episodes):
        env = EnvClass(prices)
        state = env.reset()

        while state is not None:
            action = model.get_action(state, epsilon)
            next_state, reward, done = env.step(action)
            if next_state is not None:
                buffer.add((state, action, reward, next_state, done))
            if len(buffer) >= batch_size:
                batch = buffer.sample(batch_size)
                # Build minibatch tensors (convert through np.array to avoid slow per-element copies)
                states = torch.FloatTensor(np.array([e[0] for e in batch])).to(device)
                actions = torch.LongTensor([e[1] for e in batch]).to(device)
                rewards = torch.FloatTensor([e[2] for e in batch]).to(device)
                next_states = torch.FloatTensor(np.array([e[3] for e in batch])).to(device)
                dones = torch.FloatTensor([e[4] for e in batch]).to(device)
                # Standard DQN target: r + gamma * max_a' Q_target(s', a')
                # (terminal transitions are never stored above, so dones is always 0 here)
                current_q = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
                next_q = target_model(next_states).max(1)[0].detach()
                target_q = rewards + gamma * next_q * (1 - dones)
                loss = F.smooth_l1_loss(current_q, target_q)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            state = next_state

        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        if ep % 10 == 0:
            target_model.load_state_dict(model.state_dict())

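        # Periodic greedy evaluation; keep the best-performing weights seen so far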
        if (ep + 1) % 50 == 0:
            eval_env = EnvClass(prices)
            eval_state = eval_env.reset()
            while eval_state is not None:
                action = model.get_action(eval_state, 0)
                eval_state, _, _ = eval_env.step(action)
            ret = (eval_env.capital - eval_env.initial_capital) / eval_env.initial_capital * 100
            win_rate = eval_env.wins / eval_env.trades * 100 if eval_env.trades > 0 else 0
            marker = '*' if ret > best_return else ' '
            print(f"{marker} Ep {ep+1:3d}: Return={ret:+7.2f}%, Trades={eval_env.trades:3d}, Win={win_rate:4.1f}%")
            if ret > best_return:
                best_return = ret
                best_model_state = {k: v.clone() for k, v in model.state_dict().items()}

    if best_model_state:
        model.load_state_dict(best_model_state)

    final_env = EnvClass(prices)
    state = final_env.reset()
    while state is not None:
        action = model.get_action(state, 0)
        state, _, _ = final_env.step(action)

    final_return = (final_env.capital - final_env.initial_capital) / final_env.initial_capital * 100
    win_rate = final_env.wins / final_env.trades * 100 if final_env.trades > 0 else 0

    print(f"\n{'-'*60}")
    print(f"   FINAL - {model_type.upper()}")
    print(f"{'-'*60}")
    print(f"   Return: {final_return:+.2f}%")
    print(f"   B&H: {buy_hold:+.2f}%")
    print(f"   {'BEATS B&H!' if final_return > buy_hold else 'B&H wins'}")
    print(f"   Trades: {final_env.trades}, Win Rate: {win_rate:.1f}%")

    return model, {
        'model_type': model_type,
        'return': final_return,
        'buy_hold': buy_hold,
        'trades': final_env.trades,
        'win_rate': win_rate
    }

def export_model_js(model, model_type, metrics):
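    """Export the trained network and its metrics as plain JSON.

    Walks the state_dict pairing each Linear layer's weights with its bias, so the
    exported `layers` list mirrors TradingNN's Sequential structure; the camelCase
    metric keys and `_js` suffix suggest the file is meant for a JavaScript consumer.
    """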
    layers = []
    state_dict = model.state_dict()
    keys = list(state_dict.keys())
    i = 0
    while i < len(keys):
        if 'weight' in keys[i]:
            weights = state_dict[keys[i]].cpu().numpy().tolist()
            biases = state_dict[keys[i+1]].cpu().numpy().tolist() if i+1 < len(keys) and 'bias' in keys[i+1] else []
            layers.append({'weights': weights, 'biases': biases})
            i += 2
        else:
            i += 1

    model_data = {
        'type': 'PureNN',
        'network': {'layers': layers},
        'metrics': {
            'totalReturn': metrics['return'],
            'buyHoldReturn': metrics['buy_hold'],
            'winRate': metrics['win_rate'],
            'totalTrades': metrics['trades'],
            'modelType': model_type,
            'trainedAt': datetime.now().isoformat(),
            'trainedOn': 'Kaggle GPU'
        }
    }

    filename = f"/kaggle/working/model_{model_type}.json"
    with open(filename, 'w') as f:
        json.dump(model_data, f)
    print(f"Saved {filename}")
    return filename
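
# Illustrative sketch (an assumption, not part of the original pipeline, and never called here):
# reload an exported JSON model and reproduce TradingNN's forward pass in NumPy, to check that
# the exported layout (Linear weights/biases with LeakyReLU(0.01) between hidden layers, no
# activation on the output) matches what the training code saved. The file path and the
# feature vector passed in are up to the caller.
def predict_from_export(json_path, features):
    with open(json_path) as f:
        data = json.load(f)
    x = np.asarray(features, dtype=np.float64)
    layers = data['network']['layers']
    for li, layer in enumerate(layers):
        W = np.asarray(layer['weights'])      # nn.Linear stores (out_features, in_features)
        b = np.asarray(layer['biases'])
        x = W @ x + b
        if li < len(layers) - 1:              # hidden layers use LeakyReLU(0.01)
            x = np.where(x > 0, x, 0.01 * x)
    return int(np.argmax(x))                  # 0 = close/flat, 1 = long, 2 = short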

# Train all models
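# Note: input_size must match the feature count of each env's get_state():
# 10 for ranging/volatile, 9 for scalper.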
results = []

# RANGING
trained_model, metrics = train_model(prices, 'ranging', RangingTradingEnv, 10, [48, 24], 300)
if trained_model:
    export_model_js(trained_model, 'ranging', metrics)
    results.append(metrics)

# VOLATILE
trained_model, metrics = train_model(prices, 'volatile', VolatileTradingEnv, 10, [48, 24], 300)
if trained_model:
    export_model_js(trained_model, 'volatile', metrics)
    results.append(metrics)

# SCALPER
trained_model, metrics = train_model(prices, 'scalper', ScalperTradingEnv, 9, [32, 16], 400)
if trained_model:
    export_model_js(trained_model, 'scalper', metrics)
    results.append(metrics)

# Summary
print(f"\n{'='*60}")
print("   TRAINING SUMMARY")
print(f"{'='*60}")
for r in results:
    print(f"   {r['model_type'].upper():10}: {r['return']:+.2f}% ({r['trades']} trades, {r['win_rate']:.0f}% win)")

# Save combined results
with open('/kaggle/working/all_results.json', 'w') as f:
    json.dump(results, f)

print("\nDone!")
