"""
BestTrading V7 - Self-contained GPU Training
Downloads data from Binance, computes features, trains models
No dataset upload needed!
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import json
import requests
from collections import deque
import random

# Startup banner.
banner_rule = "=" * 70
print(banner_rule)
print("   BestTrading V7 - Self-Contained Training")
print(banner_rule)

# Pick the compute device once; everything below moves tensors/models to it.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if cuda_available else torch.device('cpu')
print(f"Device: {device}")
if cuda_available:
    # Report the name of the first visible GPU.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# ============================================================
# 1. DOWNLOAD DATA FROM BINANCE
# ============================================================
print("\n" + banner_rule)
print("   Downloading Historical Data from Binance")
print(banner_rule)

def fetch_binance_klines(symbol, interval, start_time, end_time):
    """Download klines from Binance for the window ``start_time``..``end_time``.

    Pages through the public ``/api/v3/klines`` endpoint 1000 rows at a time.
    After each page the cursor moves to one millisecond past the close time
    (column 6) of the last row received.

    Args:
        symbol: Trading pair symbol, e.g. ``"BTCEUR"``.
        interval: Kline interval string, e.g. ``"1m"``.
        start_time: Window start in epoch milliseconds.
        end_time: Window end in epoch milliseconds.

    Returns:
        A list of raw kline rows exactly as decoded from the JSON responses,
        in the order received. The list is partial (possibly empty) when the
        API answers with an empty page or an error object.

    Raises:
        requests.exceptions.RequestException: On connection failure or when a
            request exceeds the timeout.
    """
    endpoint = "https://api.binance.com/api/v3/klines"
    all_klines = []
    current_start = start_time

    while current_start < end_time:
        query = {
            "symbol": symbol,
            "interval": interval,
            "startTime": current_start,
            # Bound the page on the server side; without endTime the final
            # page can run well past the requested window.
            "endTime": end_time,
            "limit": 1000,
        }
        # A timeout keeps a stalled connection from hanging the whole run.
        response = requests.get(endpoint, params=query, timeout=30)
        data = response.json()

        if isinstance(data, dict):
            # Error payloads arrive as a JSON object instead of a list of rows.
            # Keep the best-effort behaviour (return what we have) but say so.
            print(f"  Binance returned an error, stopping download: {data}")
            break
        if not data:
            break

        all_klines.extend(data)

        # Resume right after the last candle we already hold.
        last_close_time = data[-1][6]
        current_start = last_close_time + 1

        # A short page means the server has nothing more inside the window.
        if len(data) < 1000:
            break

    return all_klines

# Download Nov 1, 2025 to Jan 17, 2026 (UTC)
import datetime

# Binance kline timestamps are epoch milliseconds in UTC. A naive datetime
# would make .timestamp() apply the host's local time zone and shift the
# window, so both bounds are pinned to UTC explicitly.
start_date = datetime.datetime(2025, 11, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
end_date = datetime.datetime(2026, 1, 17, 23, 59, 59, tzinfo=datetime.timezone.utc)
start_ms = int(start_date.timestamp() * 1000)
end_ms = int(end_date.timestamp() * 1000)

print(f"Period: {start_date.date()} to {end_date.date()}")
print("Downloading...")

klines = fetch_binance_klines("BTCEUR", "1m", start_ms, end_ms)

# Column 0 of a kline row is its open time; drop any candle that opens after
# the requested window so the data matches the period reported above.
klines = [k for k in klines if k[0] <= end_ms]

# Feature computation needs 60 bars of history before the first usable bar, so
# fewer than 61 candles cannot produce a single training sample. Fail here with
# a clear message instead of an opaque error further down.
if len(klines) <= 60:
    raise RuntimeError(
        f"Only {len(klines)} klines downloaded; at least 61 are required to compute features."
    )

prices = np.array([float(k[4]) for k in klines], dtype=np.float32)  # Close prices

print(f"Downloaded {len(prices):,} price points")
print(f"Price range: {prices.min():.2f} - {prices.max():.2f}")

# ============================================================
# 2. COMPUTE FEATURES
# ============================================================
print("\n" + "=" * 70)
print("   Computing Features")
print("=" * 70)

def compute_features(prices, idx):
    """Build the 24-element feature vector for the bar at ``idx``.

    The vector is computed from a look-back window holding ``prices[idx]``
    and up to 60 bars before it. Layout of the result:

        0-3    clipped returns over 1, 5, 10 and 20 bars, times 10
        4-6    volatility: scaled RMS return (stored twice) and one scaled
               absolute return
        7-9    flow, a constant 0, and flow * 0.8
        10-13  constants 0.2, 0.5, 0.5, 1
        14     5-bar vs 10-bar mean crossover, clipped to [-1, 1]
        15     RSI-style oscillator rescaled to [-1, 1]
        16     flow momentum, clipped to [-1, 1]
        17     constant 1
        18-23  zeros (placeholders; FastTradingEnv overwrites these slots)

    Args:
        prices: Indexable, sliceable sequence of prices.
        idx: Position of the bar to describe.

    Returns:
        A list of 24 numbers, or ``None`` when ``idx < 60``.
    """
    if idx < 60:
        return None

    def clamp(value, low, high):
        # Same evaluation order as the inline max(low, min(high, value)) idiom.
        return max(low, min(high, value))

    last = prices[idx]
    window = prices[max(0, idx - 60):idx + 1]
    n = len(window)

    # 1. Returns (4 features): fractional change vs. `lag` bars ago, clipped
    #    to +/-10% and scaled by 10.
    momentum = [
        clamp((last - window[-1 - lag]) / window[-1 - lag], -0.1, 0.1) * 10 if n > lag else 0
        for lag in (1, 5, 10, 20)
    ]

    # 2. Volatility (3 features).
    # NOTE(review): the indices run from the START of the window, so these are
    # the 19 oldest bar-to-bar returns, not the most recent ones - confirm
    # this is intended before changing it (trained models depend on it).
    if n >= 20:
        early_returns = [
            (window[k] - window[k - 1]) / max(window[k - 1], 0.0001)
            for k in range(1, min(20, n))
        ]
        rv = np.sqrt(np.mean(np.array(early_returns) ** 2)) * np.sqrt(252 * 24 * 60) * 100
        rv_scaled = min(2, rv / 50)
        first_move = min(2, abs(early_returns[0]) * 100 / 30 if early_returns else 0)
        volatility = [rv_scaled, rv_scaled, first_move]
    else:
        volatility = [0, 0, 0]

    # 3. Flow (3 features): change relative to window[-5], times 50, clipped.
    flow = clamp((last - window[-5]) / window[-5] * 50, -1, 1) if n >= 5 else 0

    # 6. MA crossover (1 feature): relative gap between the 5- and 10-bar means.
    if n >= 10:
        fast_ma = np.mean(window[-5:])
        slow_ma = np.mean(window[-10:])
        ma_cross = clamp((fast_ma - slow_ma) / slow_ma * 100, -1, 1)
    else:
        ma_cross = 0

    # 7. RSI (1 feature).
    # NOTE(review): 13 bar-to-bar changes are collected but the averages
    # divide by 14.0 - looks like an off-by-one; confirm against the
    # inference side before touching it.
    if n >= 15:
        deltas = [window[-k] - window[-k - 1] for k in range(1, min(14, n))]
        ups = [d for d in deltas if d > 0]
        downs = [abs(d) for d in deltas if d < 0]
        mean_up = sum(ups) / 14.0 if ups else 0.0
        mean_down = sum(downs) / 14.0 if downs else 0.0
        if mean_down == 0:
            # No losses: fully overbought if anything rose, neutral if flat.
            rsi = 100.0 if mean_up > 0 else 50.0
        elif mean_up == 0:
            rsi = 0.0
        else:
            rsi = 100.0 - (100.0 / (1.0 + mean_up / mean_down))
        rsi_centered = (rsi - 50.0) / 50.0
    else:
        rsi_centered = 0

    # 8. Flow momentum (1 feature): latest one-bar flow minus an earlier one.
    if n >= 6:
        recent_flow = (last - window[-2]) / max(window[-2], 0.0001) * 50
        earlier_flow = (window[-5] - window[-6]) / max(window[-6], 0.0001) * 50
        flow_accel = clamp(recent_flow - earlier_flow, -1, 1)
    else:
        flow_accel = 0

    return (
        momentum
        + volatility
        + [flow, 0, flow * 0.8]
        + [0.2, 0.5]              # 4. Spread/liquidity (constants)
        + [0.5, 1]                # 5. Entropy/quality (constants)
        + [ma_cross, rsi_centered, flow_accel]
        + [1]                     # 9. Vol regime (constant)
        + [0, 0, 0, 0, 0, 0]      # 10. Position info placeholders
    )

# Compute all features
# Walk every bar that has a full 60-bar history and keep the rows that come
# back as complete 24-element vectors, remembering which price index each
# row belongs to.
print("Pre-computing features...")
all_features = []
valid_indices = []

for step, idx in enumerate(range(60, len(prices))):
    features = compute_features(prices, idx)
    if features and len(features) == 24:
        all_features.append(features)
        valid_indices.append(idx)

    # Progress line every 20,000 bars (step is idx - 60).
    if step % 20000 == 0:
        print(f"  {idx - 60:,}/{len(prices) - 60:,} timesteps...")

features_array = np.array(all_features, dtype=np.float32)
# Prices picked at the same indices, so row i of features_array lines up with
# prices_aligned[i].
prices_aligned = prices[np.array(valid_indices, dtype=np.intp)].astype(np.float32)

print(f"Features: {features_array.shape}")
print(f"Prices: {len(prices_aligned):,}")

# ============================================================
# 3. NEURAL NETWORK
# ============================================================

class TradingNN(nn.Module):
    """Fully connected network 24 -> 64 -> 32 -> 3 with LeakyReLU activations.

    Every linear layer is initialised with Xavier-uniform weights (gain 0.5)
    and zero biases. The layers live in ``self.net`` so parameter names are
    ``net.0.*``, ``net.2.*`` and ``net.4.*``.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(24, 64),
            nn.LeakyReLU(0.01),
            nn.Linear(64, 32),
            nn.LeakyReLU(0.01),
            nn.Linear(32, 3),
        ]
        self.net = nn.Sequential(*layers)

        # Re-initialise only the linear layers; activations have no parameters.
        for layer in layers:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_uniform_(layer.weight, gain=0.5)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return the three raw output values for input ``x``."""
        return self.net(x)


class FastTradingEnv:
    """Fast environment with pre-computed features.

    Steps through ``prices`` one bar at a time. Actions are ``0`` = go flat,
    ``1`` = be long, ``2`` = be short; any other value leaves the position
    untouched. Opening a position shifts the entry price by ``fee_rate``
    against the trader, and closing subtracts ``fee_rate`` from the realised
    fractional P&L.

    Args:
        features: Array with one 24-element row per bar; rows are copied and
            slots 18-23 are overwritten with position/account information.
        prices: Price per bar, aligned with ``features``.
        fee_rate: Fee fraction applied at entry and at exit.
        initial_capital: Starting capital restored by every ``reset()``.

    Raises:
        ValueError: If ``features`` has fewer rows than ``prices`` has bars.
    """

    def __init__(self, features, prices, fee_rate=0.0026, initial_capital=10000):
        # A short feature table would only fail with an IndexError somewhere
        # in the middle of an episode; reject it up front instead.
        if len(features) < len(prices):
            raise ValueError(
                f"features has {len(features)} rows but prices has {len(prices)} entries"
            )
        self.base_features = features
        self.prices = prices
        self.fee_rate = fee_rate
        self.initial_capital = initial_capital
        self.n_steps = len(prices)
        self.reset()

    def reset(self):
        """Rewind to the first bar with no position and return its state."""
        self.idx = 0
        self.position = 0
        self.entry_price = 0
        self.capital = self.initial_capital
        self.peak_capital = self.initial_capital
        self.trades = 0
        self.wins = 0
        return self._get_state()

    def _get_state(self):
        """Return the feature row for the current bar, or None past the end."""
        if self.idx >= self.n_steps:
            return None

        state = self.base_features[self.idx].copy()
        state[18] = self.position

        # Slot 19: unrealised P&L of the open position, clipped to +/-10%
        # and scaled by 10.
        if self.position != 0 and self.entry_price > 0:
            price = self.prices[self.idx]
            if self.position > 0:
                upnl = (price - self.entry_price) / self.entry_price
            else:
                upnl = (self.entry_price - price) / self.entry_price
            state[19] = max(-0.1, min(0.1, upnl)) * 10
        else:
            state[19] = 0

        state[20] = 0
        # Slot 21: drawdown from peak capital, scaled by 10 and capped at 1.
        dd = (self.peak_capital - self.capital) / self.peak_capital
        state[21] = min(1, dd * 10)
        # Slot 22: running win rate (0.5 until the first trade closes).
        state[22] = 0.5 if self.trades == 0 else self.wins / self.trades
        # Slot 23: fractional capital change since reset.
        state[23] = (self.capital / self.initial_capital) - 1

        return state

    def _close_position(self, price):
        """Close the open position at ``price`` and return its net fractional P&L.

        Updates capital, the trade and win counters, and leaves the
        environment flat. Must only be called while a position is open.
        """
        if self.position > 0:
            pnl = (price - self.entry_price) / self.entry_price - self.fee_rate
        else:
            pnl = (self.entry_price - price) / self.entry_price - self.fee_rate
        self.capital *= (1 + pnl)
        self.trades += 1
        if pnl > 0:
            self.wins += 1
        self.position = 0
        self.entry_price = 0
        return pnl

    def step(self, action):
        """Apply ``action`` at the current bar and advance one bar.

        Returns:
            ``(next_state, reward, done)``. ``next_state`` is ``None`` and
            ``done`` is ``True`` once the last bar has been consumed. The
            reward is 100x the realised P&L of any position closed on this
            step plus 50x the next-bar return of the position held afterwards.
        """
        if self.idx >= self.n_steps:
            return None, 0, True

        price = self.prices[self.idx]
        reward = 0

        if action == 0:  # FLAT
            if self.position != 0:
                reward = self._close_position(price) * 100

        elif action in (1, 2):  # LONG / SHORT
            side = 1 if action == 1 else -1
            # Holding the opposite side: realise it before flipping.
            if self.position == -side:
                reward += self._close_position(price) * 100
            if self.position != side:
                self.position = side
                # Longs enter at price * (1 + fee), shorts at price * (1 - fee).
                self.entry_price = price * (1 + side * self.fee_rate)

        self.peak_capital = max(self.peak_capital, self.capital)

        # Shaping term: mark the held position to the next bar's price.
        if self.position != 0 and self.idx + 1 < self.n_steps:
            next_price = self.prices[self.idx + 1]
            if self.position > 0:
                reward += (next_price - price) / price * 50
            else:
                reward += (price - next_price) / price * 50

        self.idx += 1
        next_state = self._get_state()
        done = next_state is None

        return next_state, reward, done


# ============================================================
# 4. TRAINING
# ============================================================

def _greedy_action(model, state):
    """Return the index of the highest-valued action for one state vector."""
    with torch.no_grad():
        q_values = model(torch.as_tensor(state, dtype=torch.float32, device=device))
    return q_values.argmax().item()


def _dqn_update(model, target, optimizer, batch, gamma):
    """Run one Huber-loss gradient step on a batch of stored transitions."""
    states, actions, rewards, next_states, dones = zip(*batch)

    # Stack into NumPy arrays first; building tensors from lists of arrays
    # element by element is much slower.
    states = torch.from_numpy(np.array(states, dtype=np.float32)).to(device)
    actions = torch.from_numpy(np.array(actions, dtype=np.int64)).to(device)
    rewards = torch.from_numpy(np.array(rewards, dtype=np.float32)).to(device)
    next_states = torch.from_numpy(np.array(next_states, dtype=np.float32)).to(device)
    dones = torch.from_numpy(np.array(dones, dtype=np.float32)).to(device)

    # squeeze(1) drops only the gathered column, so a batch of one keeps its
    # batch dimension and still matches target_q's shape.
    current_q = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_q = target(next_states).max(1)[0]
        target_q = rewards + gamma * next_q * (1 - dones)

    loss = F.smooth_l1_loss(current_q, target_q)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def train_model(model_type, episodes=150):
    """Train one model with epsilon-greedy Q-learning and a target network.

    Each episode is one full pass over ``features_array`` / ``prices_aligned``
    in a ``FastTradingEnv``, followed by 10 gradient steps on batches sampled
    from the replay buffer. Every 30th episode the episode's return is
    printed, and the weights are kept as the best checkpoint when that return
    beats the previous best with at least 5 closed trades.

    Args:
        model_type: Label used only in the printed output.
        episodes: Number of training episodes.

    Returns:
        ``(model, metrics)`` where ``model`` holds the best checkpoint (or the
        final weights if none qualified) and ``metrics`` is a dict with the
        keys ``'return'`` (percent), ``'trades'``, ``'win_rate'`` (percent)
        and ``'actions'`` (count per action) from a final greedy run.
    """
    print(f"\n{'='*60}")
    print(f"   Training {model_type.upper()}")
    print(f"{'='*60}")

    model = TradingNN().to(device)
    target = TradingNN().to(device)
    target.load_state_dict(model.state_dict())

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    buffer = deque(maxlen=30000)

    gamma = 0.95
    epsilon = 1.0
    batch_size = 256

    best_return = float('-inf')
    best_state = None

    # One environment is enough: reset() restores all per-episode state.
    env = FastTradingEnv(features_array, prices_aligned)

    for ep in range(episodes):
        state = env.reset()

        while state is not None:
            if random.random() < epsilon:
                action = random.randint(0, 2)
            else:
                action = _greedy_action(model, state)

            next_state, reward, done = env.step(action)

            # NOTE(review): the terminal transition (next_state is None) is
            # never stored, so every sampled `done` flag is False.
            if next_state is not None:
                buffer.append((state, action, reward, next_state, done))

            state = next_state

        if len(buffer) >= batch_size:
            for _ in range(10):
                _dqn_update(model, target, optimizer, random.sample(buffer, batch_size), gamma)

        epsilon = max(0.05, epsilon * 0.98)

        # Sync the target network every 5 episodes.
        if ep % 5 == 0:
            target.load_state_dict(model.state_dict())

        if (ep + 1) % 30 == 0:
            # Starting capital is 10000, so dividing the gain by 100 gives percent.
            total_return = (env.capital - 10000) / 100
            trades = env.trades
            win_rate = env.wins / trades * 100 if trades > 0 else 0
            # Flag the line only when the checkpoint is actually kept.
            improved = total_return > best_return and trades >= 5
            marker = '*' if improved else ' '
            print(f"{marker} Ep {ep+1:3d}: Return={total_return:+7.2f}%, Trades={trades:3d}, Win={win_rate:5.1f}%")

            if improved:
                best_return = total_return
                best_state = {k: v.clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    # Final evaluation: one purely greedy pass over the same data.
    state = env.reset()
    actions_count = {0: 0, 1: 0, 2: 0}

    while state is not None:
        action = _greedy_action(model, state)
        actions_count[action] += 1
        state, _, _ = env.step(action)

    final_return = (env.capital - 10000) / 100
    win_rate = env.wins / env.trades * 100 if env.trades > 0 else 0

    print(f"\n   FINAL {model_type.upper()}: {final_return:+.2f}%, {env.trades} trades, {win_rate:.1f}% win")
    print(f"   Actions: FLAT={actions_count[0]}, LONG={actions_count[1]}, SHORT={actions_count[2]}")

    return model, {
        'return': final_return,
        'trades': env.trades,
        'win_rate': win_rate,
        'actions': actions_count
    }


# ============================================================
# 5. MAIN TRAINING LOOP
# ============================================================
section_rule = "=" * 70
print("\n" + section_rule)
print("   Training All Models (150 episodes each)")
print(section_rule)

# Per-regime outputs, keyed by regime name: evaluation metrics and the
# trained network itself.
results = {}
all_models = {}

regime_names = ('bullish', 'bearish', 'ranging', 'volatile', 'scalper')
for regime in regime_names:
    model, metrics = train_model(regime, episodes=150)
    results[regime], all_models[regime] = metrics, model

# ============================================================
# 6. SAVE MODELS
# ============================================================
from pathlib import Path

print("\n" + "=" * 70)
print("   Saving Models")
print("=" * 70)

# Models go to Kaggle's working directory when it exists. Anywhere else the
# hard-coded path would make open() fail after all training has finished, so
# fall back to the current directory rather than lose the trained weights.
output_dir = Path('/kaggle/working')
if not output_dir.is_dir():
    output_dir = Path.cwd()
    print(f"/kaggle/working not found, saving to {output_dir}")

for regime, model in all_models.items():
    # JSON layout: architecture metadata plus every parameter tensor as
    # nested lists, keyed by its state_dict name.
    weights = {
        'input_size': 24,
        'hidden_sizes': [64, 32],
        'output_size': 3,
        'weights': {
            name: tensor.cpu().numpy().tolist()
            for name, tensor in model.state_dict().items()
        },
    }

    filename = str(output_dir / f'model_{regime}_v2.json')
    with open(filename, 'w') as f:
        json.dump(weights, f)
    print(f"Saved {filename}")

# ============================================================
# 7. SUMMARY
# ============================================================
summary_rule = "=" * 70
print("\n" + summary_rule)
print("   SUMMARY")
print(summary_rule)

# Baseline: percentage change from the first to the last aligned price.
first_price = prices_aligned[0]
last_price = prices_aligned[-1]
buy_hold = (last_price - first_price) / first_price * 100
print(f"Buy & Hold: {buy_hold:+.2f}%")
print(f"Data: {len(prices_aligned):,} samples ({start_date.date()} to {end_date.date()})\n")

# One line per trained model, tagged when its return exceeds the baseline.
for regime, metrics in results.items():
    if metrics['return'] > buy_hold:
        beat = "BEAT"
    else:
        beat = "    "
    print(f"  {regime.upper():10s}: {metrics['return']:+7.2f}% | {metrics['trades']:3d} trades | {metrics['win_rate']:5.1f}% win | {beat}")

print("\n" + summary_rule)
print("   DONE!")
print(summary_rule)
