"""
BestTrading V7 - Local CPU Training
Optimized for CPU - smaller batch size, fewer episodes
"""

import json
import os
import random
import time
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

print("=" * 70)
print("   BestTrading V7 - Local CPU Training")
print("=" * 70)

device = torch.device('cpu')
print(f"Device: {device}")

# ============================================================
# 1. LOAD PRE-COMPUTED DATA
# ============================================================
print("\nLoading pre-computed data...")
# Default is the original checkout location; BESTTRADING_DATA_DIR lets the
# script run from another machine/directory without editing the source.
data_dir = os.environ.get(
    'BESTTRADING_DATA_DIR',
    '/var/www/html/bestrading.cuttalo.com/scripts/kaggle-kernel-v2',
)
features_array = np.load(f'{data_dir}/features.npy')
prices_aligned = np.load(f'{data_dir}/prices_aligned.npy')

# Fail fast instead of crashing deep inside training: TradingNN takes exactly
# 24 inputs (and FastTradingEnv writes columns 18-23), and the environment
# reads features_array[i] for every price index i.
if features_array.ndim != 2 or features_array.shape[1] != 24:
    raise ValueError(
        f"features.npy must have shape (n_samples, 24), got {features_array.shape}"
    )
if len(features_array) < len(prices_aligned):
    raise ValueError(
        f"features.npy has {len(features_array)} rows but prices_aligned.npy "
        f"has {len(prices_aligned)} prices; every price needs a feature row"
    )

print(f"Features: {features_array.shape}")
print(f"Prices: {len(prices_aligned):,} (range: {prices_aligned.min():.2f} - {prices_aligned.max():.2f})")

# ============================================================
# 2. NEURAL NETWORK
# ============================================================

class TradingNN(nn.Module):
    """Fully connected Q-network, 24 -> 64 -> 32 -> 3 by default.

    Maps a state vector to one output per action; FastTradingEnv interprets
    actions as 0=FLAT, 1=LONG, 2=SHORT. Hidden layers use LeakyReLU(0.01) and
    the output layer is linear.

    Args:
        input_size: length of the state vector (24 for FastTradingEnv).
        hidden_sizes: widths of the hidden layers, in order.
        output_size: number of actions.
    """

    def __init__(self, input_size=24, hidden_sizes=(64, 32), output_size=3):
        super().__init__()
        layers = []
        in_features = input_size
        for width in hidden_sizes:
            layers += [nn.Linear(in_features, width), nn.LeakyReLU(0.01)]
            in_features = width
        layers.append(nn.Linear(in_features, output_size))
        # A single nn.Sequential named `net` keeps the state_dict keys
        # (net.0.weight, net.2.weight, net.4.weight, ...) that the JSON export
        # writes out, and constructs layers in the same order as before.
        self.net = nn.Sequential(*layers)

        # Xavier-uniform weights scaled by gain 0.5, zero biases.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.5)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return one Q-value per action; the last dim of x must be input_size."""
        return self.net(x)


class FastTradingEnv:
    """Single-asset trading environment driven by pre-computed feature rows.

    Each step consumes one price; actions are 0=FLAT, 1=LONG, 2=SHORT and at
    most one unit position is held. Fees are charged on entry (folded into
    entry_price) and again on exit (subtracted from the trade's return).
    Capital compounds by each closed trade's return. A position still open
    when the data runs out is closed at the last price, so capital, trades
    and wins always account for the whole episode.

    Columns 18-23 of each feature row are overwritten with account state in
    _get_state; the other columns are passed through unchanged.
    """

    def __init__(self, features, prices, fee_rate=0.0026):
        self.base_features = features
        self.prices = prices
        self.fee_rate = fee_rate
        self.n_steps = len(prices)
        self.reset()

    def reset(self):
        """Rewind to step 0 with no position and 10,000 capital; return the first state."""
        self.idx = 0
        self.position = 0
        self.entry_price = 0
        self.capital = 10000
        self.peak_capital = 10000
        self.trades = 0
        self.wins = 0
        return self._get_state()

    def _get_state(self):
        """Return the current feature row with account state, or None past the end."""
        if self.idx >= self.n_steps:
            return None

        state = self.base_features[self.idx].copy()
        state[18] = self.position  # -1 short, 0 flat, 1 long

        # Unrealized return vs. the fee-adjusted entry, clipped to +/-10% and
        # scaled by 10 into [-1, 1].
        if self.position != 0 and self.entry_price > 0:
            price = self.prices[self.idx]
            if self.position > 0:
                upnl = (price - self.entry_price) / self.entry_price
            else:
                upnl = (self.entry_price - price) / self.entry_price
            state[19] = max(-0.1, min(0.1, upnl)) * 10
        else:
            state[19] = 0

        state[20] = 0  # always zero in this environment
        dd = (self.peak_capital - self.capital) / self.peak_capital
        state[21] = min(1, dd * 10)  # drawdown from peak, x10, capped at 1
        state[22] = 0.5 if self.trades == 0 else self.wins / self.trades
        state[23] = (self.capital / 10000) - 1  # return on initial capital

        return state

    def _close_position(self, price):
        """Realize the open position at `price` and return its reward (pnl * 100).

        Updates capital, trades and wins and leaves the book flat. Callers
        must only call this while a position is open.
        """
        if self.position > 0:
            pnl = (price - self.entry_price) / self.entry_price - self.fee_rate
        else:
            pnl = (self.entry_price - price) / self.entry_price - self.fee_rate
        self.capital *= (1 + pnl)
        self.trades += 1
        if pnl > 0:
            self.wins += 1
        self.position = 0
        self.entry_price = 0
        return pnl * 100

    def step(self, action):
        """Apply `action` at the current price and advance one step.

        Returns:
            (next_state, reward, done). next_state is None and done is True
            once the data is exhausted. reward is 100x the return of any trade
            closed this step plus 50x the next bar's move in the direction of
            the position held.
        """
        if self.idx >= self.n_steps:
            return None, 0, True

        price = self.prices[self.idx]
        reward = 0

        if action == 0:  # FLAT
            if self.position != 0:
                reward += self._close_position(price)

        elif action == 1:  # LONG
            if self.position == -1:
                reward += self._close_position(price)
            if self.position != 1:
                self.position = 1
                # Entry fee folded in: a long is filled slightly higher.
                self.entry_price = price * (1 + self.fee_rate)

        elif action == 2:  # SHORT
            if self.position == 1:
                reward += self._close_position(price)
            if self.position != -1:
                self.position = -1
                # Entry fee folded in: a short is filled slightly lower.
                self.entry_price = price * (1 - self.fee_rate)

        # Shaping term: credit holding a position in the direction of the next move.
        if self.position != 0 and self.idx + 1 < self.n_steps:
            next_price = self.prices[self.idx + 1]
            if self.position > 0:
                reward += (next_price - price) / price * 50
            else:
                reward += (price - next_price) / price * 50

        self.idx += 1

        # Out of data: realize any open position at the last price, otherwise
        # its P&L would never reach capital/trades/wins (and the reported returns).
        if self.idx >= self.n_steps and self.position != 0:
            reward += self._close_position(price)

        self.peak_capital = max(self.peak_capital, self.capital)
        next_state = self._get_state()
        done = next_state is None

        return next_state, reward, done


# ============================================================
# 3. TRAINING
# ============================================================

def _replay_update(model, target, optimizer, buffer, batch_size, gamma):
    """Run one DQN gradient step on a random minibatch drawn from `buffer`."""
    batch = random.sample(buffer, batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)

    # Stack into ndarrays first: building a tensor from a list of ndarrays is
    # very slow (PyTorch emits a warning for it).
    states = torch.as_tensor(np.array(states), dtype=torch.float32)
    actions = torch.as_tensor(np.array(actions, dtype=np.int64))
    rewards = torch.as_tensor(np.array(rewards, dtype=np.float32))
    next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32)
    dones = torch.as_tensor(np.array(dones, dtype=np.float32))

    current_q = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_q = target(next_states).max(1)[0]
        # (1 - dones) drops the bootstrap term for terminal transitions.
        target_q = rewards + gamma * next_q * (1 - dones)

    loss = F.smooth_l1_loss(current_q, target_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def _run_greedy_episode(model):
    """Play one full episode choosing argmax actions; return (env, action counts)."""
    env = FastTradingEnv(features_array, prices_aligned)
    state = env.reset()
    actions_count = {0: 0, 1: 0, 2: 0}

    while state is not None:
        with torch.no_grad():
            q = model(torch.as_tensor(state, dtype=torch.float32))
        action = q.argmax().item()
        actions_count[action] += 1
        state, _, _ = env.step(action)

    return env, actions_count


def train_model(model_type, episodes=80):
    """Train one DQN trading model on the CPU and return it with its metrics.

    Each episode is one epsilon-greedy pass over the full price series,
    followed by 5 replay updates. Every 20 episodes the episode's return is
    logged and, if it is the best so far with at least 5 trades, the weights
    are checkpointed. The best checkpoint (if any) is restored and then
    evaluated greedily.

    NOTE(review): model_type is only used in log output. Every regime is
    trained on the same features, prices and reward, so models differ only
    through random initialisation and exploration. Confirm this is intended.

    Args:
        model_type: regime name used in log output.
        episodes: number of training episodes.

    Returns:
        (model, metrics): metrics has 'return' (% on 10,000 starting capital),
        'trades', 'win_rate' (%) and 'actions' (greedy action counts).
    """
    print(f"\n{'='*60}")
    print(f"   Training {model_type.upper()}")
    print(f"{'='*60}")

    model = TradingNN().to(device)
    target = TradingNN().to(device)
    target.load_state_dict(model.state_dict())

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    buffer = deque(maxlen=20000)

    gamma = 0.95
    epsilon = 1.0
    batch_size = 128  # Smaller for CPU

    best_return = float('-inf')
    best_state = None

    start_time = time.time()

    for ep in range(episodes):
        env = FastTradingEnv(features_array, prices_aligned)
        state = env.reset()

        while state is not None:
            if random.random() < epsilon:
                action = random.randint(0, 2)
            else:
                with torch.no_grad():
                    q = model(torch.as_tensor(state, dtype=torch.float32))
                action = q.argmax().item()

            next_state, reward, done = env.step(action)
            # Terminal transitions are stored too (with a zero placeholder
            # next state, masked out by `done`) so the final reward is learned.
            stored_next = next_state if next_state is not None else np.zeros_like(state)
            buffer.append((state, action, reward, stored_next, done))
            state = next_state

        # Only a few updates per episode, to keep CPU time manageable.
        if len(buffer) >= batch_size:
            for _ in range(5):
                _replay_update(model, target, optimizer, buffer, batch_size, gamma)

        epsilon = max(0.05, epsilon * 0.97)

        if ep % 5 == 0:
            target.load_state_dict(model.state_dict())

        if (ep + 1) % 20 == 0:
            # Measured on the epsilon-greedy episode just played, not a greedy run.
            total_return = (env.capital - 10000) / 100
            trades = env.trades
            win_rate = env.wins / trades * 100 if trades > 0 else 0
            elapsed = time.time() - start_time
            # The 5-trade minimum presumably avoids keeping a policy that barely
            # trades; the marker now flags exactly the checkpoints that are kept.
            is_best = total_return > best_return and trades >= 5
            marker = '*' if is_best else ' '
            print(f"{marker} Ep {ep+1:3d}: Return={total_return:+7.2f}%, Trades={trades:3d}, Win={win_rate:5.1f}% ({elapsed:.0f}s)")

            if is_best:
                best_return = total_return
                best_state = {k: v.clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    # Final evaluation with the greedy policy.
    env, actions_count = _run_greedy_episode(model)

    final_return = (env.capital - 10000) / 100
    win_rate = env.wins / env.trades * 100 if env.trades > 0 else 0

    print(f"\n   FINAL {model_type.upper()}: {final_return:+.2f}%, {env.trades} trades, {win_rate:.1f}% win")
    print(f"   Actions: FLAT={actions_count[0]}, LONG={actions_count[1]}, SHORT={actions_count[2]}")

    return model, {
        'return': final_return,
        'trades': env.trades,
        'win_rate': win_rate,
        'actions': actions_count
    }


# ============================================================
# 4. MAIN TRAINING
# ============================================================
# Single source of truth for the episode count, shown in the banner and used
# for training, so the two can no longer disagree.
EPISODES_PER_REGIME = 80
REGIMES = ('bullish', 'bearish', 'ranging', 'volatile', 'scalper')

print("\n" + "=" * 70)
print(f"   Training All Models ({EPISODES_PER_REGIME} episodes each on CPU)")
print("=" * 70)

results = {}     # regime -> metrics dict returned by train_model
all_models = {}  # regime -> trained TradingNN

for regime in REGIMES:
    model, metrics = train_model(regime, episodes=EPISODES_PER_REGIME)
    results[regime] = metrics
    all_models[regime] = model

# ============================================================
# 5. SAVE MODELS
# ============================================================
print("\n" + "=" * 70)
print("   Saving Models")
print("=" * 70)

# Default is the original location; BESTTRADING_OUTPUT_DIR overrides it.
output_dir = os.environ.get('BESTTRADING_OUTPUT_DIR', '/tmp/kaggle-v7-output')
os.makedirs(output_dir, exist_ok=True)

for regime, model in all_models.items():
    state_dict = model.state_dict()

    # Derive the layer sizes from the weights actually exported so the
    # metadata cannot drift from the network definition. nn.Linear weights are
    # (out_features, in_features), and state_dict lists layers in order.
    layer_weights = [t for name, t in state_dict.items() if name.endswith('weight')]

    weights = {
        'input_size': int(layer_weights[0].shape[1]),
        'hidden_sizes': [int(w.shape[0]) for w in layer_weights[:-1]],
        'output_size': int(layer_weights[-1].shape[0]),
        # .cpu() keeps the export working even if the model lives on a GPU.
        'weights': {
            name: tensor.detach().cpu().tolist()
            for name, tensor in state_dict.items()
        },
    }

    filename = f'{output_dir}/model_{regime}_v2.json'
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(weights, f)
    print(f"Saved {filename}")

# ============================================================
# 6. SUMMARY
# ============================================================
print("\n" + "=" * 70)
print("   SUMMARY")
print("=" * 70)

# Buy & hold benchmark: percent change from the first to the last aligned
# price. No fees are applied here, unlike the strategy returns.
start_price = prices_aligned[0]
buy_hold = (prices_aligned[-1] - start_price) / start_price * 100
print(f"Buy & Hold: {buy_hold:+.2f}%")
# NOTE(review): the date range below is a hard-coded label, not derived from
# the data files; update it if the inputs change.
print(f"Data: {len(prices_aligned):,} samples (Nov 1 - Jan 17)\n")

for regime, metrics in results.items():
    strategy_return = metrics['return']
    beat = "BEAT" if strategy_return > buy_hold else "    "
    print(
        f"  {regime.upper():10s}: {strategy_return:+7.2f}% | "
        f"{metrics['trades']:3d} trades | {metrics['win_rate']:5.1f}% win | {beat}"
    )

print("\n" + "=" * 70)
print("   DONE!")
print("=" * 70)
