"""
BestTrading V9 - FEE-AWARE Training
=====================================
Key improvements:
1. Fee-aware reward: penalizes EVERY trade by 0.4% round-trip
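2. Minimum hold time bonus: closing after >= MIN_HOLD_BARS bars earns a small
   extra reward to discourage overtrading (the hold time is not enforced)
3. Sharpe-ratio based selection of the best training checkpoint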
4. Higher penalty for losses than reward for wins
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import json
import requests
from collections import deque
import random
import datetime

print(
    "=" * 70,
    "   BestTrading V9 - FEE-AWARE Training",
    "   Designed to be PROFITABLE AFTER FEES",
    "=" * 70,
    sep="\n",
)

# Train on the first CUDA GPU when one is present, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"Device: {device}")
if use_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# ============================================================
# CONFIGURATION - FEE AWARE
# ============================================================
# Exchange fee charged on each side (entry and exit) of a trade. The round-trip
# cost is derived from it so the two settings cannot drift apart.
FEE_RATE_PER_SIDE = 0.002

CONFIG = {
    'FEE_RATE': FEE_RATE_PER_SIDE,            # 0.2% per side (Kraken taker)
    'FEE_ROUND_TRIP': 2 * FEE_RATE_PER_SIDE,  # 0.4% total cost per trade (entry + exit)
    'MIN_HOLD_BARS': 5,          # Closing to FLAT after >= 5 one-minute bars earns a small bonus (not enforced)
    'MIN_PROFIT_TARGET': 0.005,  # 0.5% minimum profit target (not referenced elsewhere in this script)
    'LOSS_PENALTY_MULT': 1.5,    # Penalize losses 1.5x more than wins
    'TRADE_PENALTY': -0.5,       # Reward penalty each time a position is opened (discourage overtrading)
    'EPISODES': 200,             # Full passes over the data per model
    'BATCH_SIZE': 512,           # Replay batch size per gradient step
    'BUFFER_SIZE': 50000,        # Replay buffer capacity (transitions)
}

print("\nFee-Aware Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

# ============================================================
# 1. DOWNLOAD DATA FROM BINANCE
# ============================================================
print("\n" + "=" * 70, "   Downloading Historical Data from Binance", "=" * 70, sep="\n")

def fetch_binance_klines(symbol, interval, start_time, end_time):
    """Download klines (candlesticks) from the Binance public REST API.

    Pages through ``/api/v3/klines`` 1000 candles at a time, starting at
    ``start_time`` and advancing past the last candle's close time.

    Args:
        symbol: Trading pair, e.g. ``"BTCEUR"``.
        interval: Kline interval, e.g. ``"1m"``.
        start_time: Window start, epoch milliseconds (inclusive).
        end_time: Window end, epoch milliseconds (exclusive, matching the
            loop condition below).

    Returns:
        List of raw kline rows exactly as returned by Binance. The caller reads
        index 4 (close), 5 (volume) and this function reads 6 (close time).
        Download stops early, returning what was collected so far, when
        Binance returns an empty page or an error object.
    """
    url = "https://api.binance.com/api/v3/klines"
    all_klines = []
    current_start = start_time

    while current_start < end_time:
        params = {
            "symbol": symbol,
            "interval": interval,
            "startTime": current_start,
            # Bound the page so the last request cannot return candles
            # beyond the requested window.
            "endTime": end_time - 1,
            "limit": 1000,
        }
        # A timeout keeps a stalled connection from hanging the run forever.
        response = requests.get(url, params=params, timeout=30)
        data = response.json()

        if isinstance(data, dict):
            # Binance reports errors as a JSON object (e.g. {"code": ..., "msg": ...}).
            print(f"  Binance returned an error, stopping download: {data}")
            break
        if not data:
            break

        all_klines.extend(data)
        last_close_time = data[-1][6]
        current_start = last_close_time + 1

        # A short page means the window has been exhausted.
        if len(data) < 1000:
            break

    return all_klines

# Download ~3 months of data for robust training.
# Binance timestamps are UTC epoch milliseconds, so the window is pinned to UTC;
# a naive datetime's .timestamp() would shift it by the host's local offset.
start_date = datetime.datetime(2025, 10, 15, 0, 0, 0, tzinfo=datetime.timezone.utc)
end_date = datetime.datetime(2026, 1, 20, 0, 0, 0, tzinfo=datetime.timezone.utc)
start_ms = int(start_date.timestamp() * 1000)
end_ms = int(end_date.timestamp() * 1000)

print(f"Period: {start_date.date()} to {end_date.date()}")
print("Downloading BTCEUR 1m data...")

klines = fetch_binance_klines("BTCEUR", "1m", start_ms, end_ms)
# Feature computation needs a 60-bar lookback, so fail fast with a clear
# message instead of crashing later on empty arrays.
if len(klines) <= 60:
    raise SystemExit(
        f"Only {len(klines)} klines downloaded; need more than 60 to compute features."
    )

# Kline row layout: index 4 = close price, index 5 = base-asset volume.
prices = np.array([float(k[4]) for k in klines], dtype=np.float32)
volumes = np.array([float(k[5]) for k in klines], dtype=np.float32)

print(f"Downloaded {len(prices):,} price points")
print(f"Price range: EUR {prices.min():.2f} - EUR {prices.max():.2f}")

# ============================================================
# 2. COMPUTE FEATURES
# ============================================================
print("\n" + "=" * 70, "   Computing Features", "=" * 70, sep="\n")

def compute_features(prices, idx):
    """Compute the 24-element feature vector for bar ``idx``.

    Only ``prices[idx-60 : idx+1]`` (the current bar plus 60 bars of history)
    is used, so no future data leaks in. Returns ``None`` when ``idx < 60``.

    NOTE(review): the models trained on these features are loaded by a
    production orchestrator; it must compute the same features in the same
    order, so keep the two implementations in sync.

    Layout (index: meaning):
      0-3   returns over 1/5/10/20 bars, clipped to +/-10% and scaled x10
      4-5   realized volatility of the 19 most recent one-bar returns
            (annualized %, /50, capped at 2) - the same value twice
      6     |most recent one-bar return| in %, /30, capped at 2
      7     4-bar "flow" (return vs hist[-5], x50, clipped to [-1, 1])
      8     constant 0
      9     flow * 0.8
      10-13 constant placeholders (0.2, 0.5, 0.5, 1)
      14    MA5 vs MA10 spread in %, clipped to [-1, 1]
      15    14-period RSI rescaled to [-1, 1]
      16    flow momentum, clipped to [-1, 1]
      17    constant 1 (volatility-regime placeholder)
      18-23 zeros; position/account slots overwritten by the environment
    """
    if idx < 60:
        return None

    price = prices[idx]
    hist = prices[max(0, idx-60):idx+1]
    features = []

    # 1. Returns (4 features)
    for period in [1, 5, 10, 20]:
        if len(hist) > period:
            ret = (price - hist[-1-period]) / hist[-1-period]
            features.append(max(-0.1, min(0.1, ret)) * 10)
        else:
            features.append(0)

    # 2. Volatility (3 features) - computed from the MOST RECENT 20 prices
    # (19 one-bar returns), not from the oldest part of the lookback window.
    if len(hist) >= 20:
        recent = hist[-20:]
        returns = [(recent[i] - recent[i-1]) / max(recent[i-1], 0.0001) for i in range(1, len(recent))]
        # Annualize per-minute volatility and express it in percent.
        rv = np.sqrt(np.mean(np.array(returns) ** 2)) * np.sqrt(252 * 24 * 60) * 100
        features.append(min(2, rv / 50))
        features.append(min(2, rv / 50))
        # Latest one-bar return magnitude.
        features.append(min(2, abs(returns[-1]) * 100 / 30))
    else:
        features.extend([0, 0, 0])

    # 3. Flow (3 features)
    if len(hist) >= 5:
        flow = (price - hist[-5]) / hist[-5] * 50
        flow = max(-1, min(1, flow))
    else:
        flow = 0
    features.append(flow)
    features.append(0)
    features.append(flow * 0.8)

    # 4. Spread/liquidity (2 features) - constant placeholders
    features.append(0.2)
    features.append(0.5)

    # 5. Entropy/quality (2 features) - constant placeholders
    features.append(0.5)
    features.append(1)

    # 6. MA crossover (1 feature)
    if len(hist) >= 10:
        ma5 = np.mean(hist[-5:])
        ma10 = np.mean(hist[-10:])
        features.append(max(-1, min(1, (ma5 - ma10) / ma10 * 100)))
    else:
        features.append(0)

    # 7. RSI (1 feature) - simple-average RSI over the last 14 one-bar changes
    if len(hist) >= 15:
        gains, losses = [], []
        for i in range(1, 15):
            change = hist[-i] - hist[-i-1]
            if change > 0:
                gains.append(change)
            elif change < 0:
                losses.append(abs(change))
        avg_gain = sum(gains) / 14.0 if gains else 0.0
        avg_loss = sum(losses) / 14.0 if losses else 0.0
        if avg_loss == 0:
            rsi = 100.0 if avg_gain > 0 else 50.0
        elif avg_gain == 0:
            rsi = 0.0
        else:
            rsi = 100.0 - (100.0 / (1.0 + avg_gain / avg_loss))
        features.append((rsi - 50.0) / 50.0)
    else:
        features.append(0)

    # 8. Flow momentum (1 feature)
    # NOTE(review): compares the latest one-bar return with the one-bar return
    # ending 4 bars earlier (hist[-5] vs hist[-6]); confirm this offset matches
    # the orchestrator.
    if len(hist) >= 6:
        flow_now = (price - hist[-2]) / max(hist[-2], 0.0001) * 50
        flow_5ago = (hist[-5] - hist[-6]) / max(hist[-6], 0.0001) * 50
        features.append(max(-1, min(1, flow_now - flow_5ago)))
    else:
        features.append(0)

    # 9. Vol regime (1 feature) - constant placeholder
    features.append(1)

    # 10. Position info (6 features) - placeholders, filled by env
    features.extend([0, 0, 0, 0, 0, 0])

    return features

# Pre-compute the feature matrix once; every training episode reuses it.
print("Pre-computing features...")
all_features = []
valid_indices = []

n_prices = len(prices)
for idx in range(60, n_prices):
    row = compute_features(prices, idx)
    if row and len(row) == 24:
        all_features.append(row)
        valid_indices.append(idx)

    offset = idx - 60
    if offset % 20000 == 0:
        print(f"  {offset:,}/{n_prices - 60:,} timesteps...")

features_array = np.array(all_features, dtype=np.float32)
# Close prices aligned row-for-row with features_array.
prices_aligned = np.asarray(prices[valid_indices], dtype=np.float32)

print(f"Features: {features_array.shape}")
print(f"Prices: {len(prices_aligned):,}")

# ============================================================
# 3. NEURAL NETWORK
# ============================================================

class TradingNN(nn.Module):
    """Q-network: 24 input features -> Q-values for 3 actions (FLAT/LONG/SHORT).

    Layout 24 -> 128 -> 64 -> 32 -> 3 with LeakyReLU(0.01) after every hidden
    layer and Dropout(0.1) after the first two. The layers live in ``self.net``
    so the exported state_dict keys are ``net.0``, ``net.3``, ``net.6`` and
    ``net.8``.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(24, 128), nn.LeakyReLU(0.01), nn.Dropout(0.1),
            nn.Linear(128, 64), nn.LeakyReLU(0.01), nn.Dropout(0.1),
            nn.Linear(64, 32), nn.LeakyReLU(0.01),
            nn.Linear(32, 3),
        ]
        self.net = nn.Sequential(*layers)
        # Small-gain Xavier weights and zero biases keep initial Q-values near 0.
        self.net.apply(self._init_linear)

    @staticmethod
    def _init_linear(module):
        """Initialize a Linear layer in place; other modules are left untouched."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=0.5)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        return self.net(x)


class FeeAwareTradingEnv:
    """
    Fee-aware single-asset trading environment over precomputed features.

    Actions: 0 = FLAT (close any open position), 1 = LONG, 2 = SHORT.

    - Every closed trade is charged CONFIG['FEE_ROUND_TRIP'] against its return.
    - Opening a position adds CONFIG['TRADE_PENALTY'] to the step reward.
    - Losing trades are scaled by CONFIG['LOSS_PENALTY_MULT'] (asymmetric reward).
    - Closing to FLAT after at least CONFIG['MIN_HOLD_BARS'] bars earns a small
      bonus; the minimum hold time is encouraged, not enforced.
    - A position still open when the data runs out is closed at the last price,
      so capital, fees and trade statistics include it.
    """

    INITIAL_CAPITAL = 10000

    def __init__(self, features, prices):
        # assumes features is a 2-D array with one row (>= 24 columns) per price
        self.base_features = features
        self.prices = prices
        self.n_steps = len(prices)
        self.reset()

    def reset(self):
        """Reset account and position state; return the first observation."""
        self.idx = 0
        self.position = 0  # -1, 0, 1
        self.entry_price = 0
        self.entry_idx = 0  # Bar index at which the current position was opened
        self.capital = self.INITIAL_CAPITAL
        self.peak_capital = self.INITIAL_CAPITAL
        self.trades = 0
        self.wins = 0
        self.total_fees = 0
        self.gross_pnl = 0
        self.returns_history = []  # Per-trade net returns, used by get_sharpe()
        return self._get_state()

    def _get_state(self):
        """Return the observation for the current bar, or None past the last bar.

        A copy of the feature row with slots 18-23 overwritten by position,
        unrealized PnL, time in position, drawdown, win rate and capital change.
        """
        if self.idx >= self.n_steps:
            return None

        state = self.base_features[self.idx].copy()
        state[18] = self.position

        if self.position != 0 and self.entry_price > 0:
            price = self.prices[self.idx]
            if self.position > 0:
                upnl = (price - self.entry_price) / self.entry_price
            else:
                upnl = (self.entry_price - price) / self.entry_price
            state[19] = max(-0.1, min(0.1, upnl)) * 10
        else:
            state[19] = 0

        # Time in position, normalized so 60 bars (~1 hour of 1m data) == 1.
        if self.position != 0:
            bars_in_position = self.idx - self.entry_idx
            state[20] = min(1, bars_in_position / 60)
        else:
            state[20] = 0

        dd = (self.peak_capital - self.capital) / self.peak_capital
        state[21] = min(1, dd * 10)
        state[22] = 0.5 if self.trades == 0 else self.wins / max(self.trades, 1)
        state[23] = (self.capital / self.INITIAL_CAPITAL) - 1

        return state

    def _close_position(self, price):
        """Realize the open position at ``price`` and return its reward.

        Charges the round-trip fee, updates capital, fee/PnL accumulators and
        trade counters, records the net return and leaves the account flat.
        """
        if self.position > 0:
            gross_pnl = (price - self.entry_price) / self.entry_price
        else:
            gross_pnl = (self.entry_price - price) / self.entry_price
        net_pnl = gross_pnl - CONFIG['FEE_ROUND_TRIP']

        # The fee is charged on the capital committed to the trade, i.e.
        # before this trade's PnL is applied.
        self.total_fees += CONFIG['FEE_ROUND_TRIP'] * self.capital
        self.capital *= (1 + net_pnl)
        self.gross_pnl += gross_pnl
        self.trades += 1
        self.returns_history.append(net_pnl)

        self.position = 0
        self.entry_price = 0

        if net_pnl > 0:
            self.wins += 1
            return net_pnl * 100
        # Asymmetric: losses hurt more than equally sized wins help.
        return net_pnl * 100 * CONFIG['LOSS_PENALTY_MULT']

    def step(self, action):
        """Apply ``action`` at the current bar and advance one bar.

        Returns ``(next_state, reward, done)``; ``next_state`` is None and
        ``done`` is True once the last bar has been processed.
        """
        if self.idx >= self.n_steps:
            return None, 0, True

        price = self.prices[self.idx]
        reward = 0
        bars_held = self.idx - self.entry_idx if self.position != 0 else 0

        if action == 0:  # FLAT - close position
            if self.position != 0:
                reward += self._close_position(price)
                # Small bonus for patient trades (discourages overtrading).
                if bars_held >= CONFIG['MIN_HOLD_BARS']:
                    reward += 0.1
        elif action in (1, 2):  # LONG / SHORT
            target = 1 if action == 1 else -1
            if self.position == -target:
                # Reversal: close the opposite position first.
                reward += self._close_position(price)
            if self.position != target:
                # Penalty for opening a new position (fee/overtrading cost).
                reward += CONFIG['TRADE_PENALTY']
                self.position = target
                self.entry_price = price
                self.entry_idx = self.idx

        self.peak_capital = max(self.peak_capital, self.capital)

        # Small reward for holding a position that gains over the next bar.
        if self.position != 0 and self.idx + 1 < self.n_steps:
            next_price = self.prices[self.idx + 1]
            if self.position > 0:
                hold_reward = (next_price - price) / price * 30
            else:
                hold_reward = (price - next_price) / price * 30
            reward += hold_reward

        self.idx += 1
        next_state = self._get_state()
        done = next_state is None

        if done and self.position != 0:
            # Out of data: liquidate at the last price so the episode's
            # capital, fees and trade counts account for the open position.
            reward += self._close_position(price)
            self.peak_capital = max(self.peak_capital, self.capital)

        return next_state, reward, done

    def get_sharpe(self):
        """Return mean/std of per-trade net returns scaled by sqrt(252).

        Returns 0 with fewer than two trades or zero dispersion.
        NOTE(review): sqrt(252) is a daily annualization factor applied to
        per-trade returns, so treat the value as a relative score only.
        """
        if len(self.returns_history) < 2:
            return 0
        returns = np.array(self.returns_history)
        if returns.std() == 0:
            return 0
        return returns.mean() / returns.std() * np.sqrt(252)


# ============================================================
# 4. TRAINING
# ============================================================

def _greedy_action(model, state):
    """Return the action with the highest Q-value for a single state vector."""
    with torch.no_grad():
        q = model(torch.as_tensor(state, dtype=torch.float32, device=device))
    return q.argmax().item()


def _batch_tensor(batch, column, dtype):
    """Stack field ``column`` of every transition in ``batch`` into a tensor on ``device``."""
    return torch.from_numpy(np.array([b[column] for b in batch], dtype=dtype)).to(device)


def _dqn_update(model, target, optimizer, batch, gamma):
    """Apply one Double-DQN gradient step to ``model``.

    ``batch`` holds ``(state, action, reward, next_state, done)`` tuples.
    """
    states = _batch_tensor(batch, 0, np.float32)
    actions = _batch_tensor(batch, 1, np.int64)
    rewards = _batch_tensor(batch, 2, np.float32)
    next_states = _batch_tensor(batch, 3, np.float32)
    dones = _batch_tensor(batch, 4, np.float32)

    current_q = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # Double DQN: the online network selects the next action and the
        # target network evaluates it.
        next_actions = model(next_states).argmax(1, keepdim=True)
        next_q = target(next_states).gather(1, next_actions).squeeze(1)
        target_q = rewards + gamma * next_q * (1 - dones)

    loss = F.smooth_l1_loss(current_q, target_q)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()


def _evaluate_greedy(model):
    """Run one greedy (epsilon = 0) pass over the data.

    Returns the finished environment and a count of each action taken.
    """
    env = FeeAwareTradingEnv(features_array, prices_aligned)
    state = env.reset()
    actions_count = {0: 0, 1: 0, 2: 0}
    while state is not None:
        action = _greedy_action(model, state)
        actions_count[action] += 1
        state, _, _ = env.step(action)
    return env, actions_count


def train_model(model_type, episodes=200):
    """Train one Double-DQN trading model with fee-aware rewards.

    Args:
        model_type: Label used in log output. NOTE(review): it does not change
            the data, environment or hyper-parameters, so every regime is
            trained identically on the same data and differs only by randomness.
        episodes: Number of full passes over the dataset.

    Returns:
        ``(model, metrics)``: the best checkpoint (by Sharpe among profitable
        logged episodes with >= 10 trades) or the last model if none
        qualified, plus its greedy-evaluation metrics (``net_return``,
        ``gross_return``, ``fees``, ``trades``, ``win_rate``, ``sharpe``,
        ``actions``).
    """
    print(f"\n{'='*60}")
    print(f"   Training {model_type.upper()} (FEE-AWARE)")
    print(f"{'='*60}")

    model = TradingNN().to(device)
    target = TradingNN().to(device)
    target.load_state_dict(model.state_dict())
    # The target network is only used for inference: keep Dropout off.
    target.eval()

    optimizer = optim.Adam(model.parameters(), lr=0.0005)  # Lower LR for stability
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
    buffer = deque(maxlen=CONFIG['BUFFER_SIZE'])

    gamma = 0.95
    epsilon = 1.0
    batch_size = CONFIG['BATCH_SIZE']

    best_sharpe = float('-inf')
    best_return = float('-inf')
    best_state = None

    for ep in range(episodes):
        env = FeeAwareTradingEnv(features_array, prices_aligned)
        state = env.reset()

        # Act without Dropout noise; Dropout is re-enabled only for updates.
        model.eval()
        while state is not None:
            if random.random() < epsilon:
                action = random.randint(0, 2)
            else:
                action = _greedy_action(model, state)

            next_state, reward, done = env.step(action)

            # Terminal steps (next_state is None) are not stored.
            if next_state is not None:
                buffer.append((state, action, reward, next_state, done))

            state = next_state

        # Training: several gradient steps per episode.
        if len(buffer) >= batch_size:
            model.train()
            for _ in range(15):
                _dqn_update(model, target, optimizer, random.sample(buffer, batch_size), gamma)

        epsilon = max(0.02, epsilon * 0.985)
        scheduler.step()

        if ep % 5 == 0:
            target.load_state_dict(model.state_dict())

        if (ep + 1) % 25 == 0:
            net_return = (env.capital - 10000) / 100
            gross_return = env.gross_pnl * 100
            fees_pct = (env.total_fees / 10000) * 100
            trades = env.trades
            win_rate = env.wins / trades * 100 if trades > 0 else 0
            sharpe = env.get_sharpe()

            # Track best by Sharpe ratio (more robust than raw return)
            is_best = sharpe > best_sharpe and trades >= 10 and net_return > 0
            marker = '***' if is_best else '   '

            print(f"{marker} Ep {ep+1:3d}: Net={net_return:+6.2f}% Gross={gross_return:+6.2f}% Fees={fees_pct:.2f}% | {trades:3d} trades | Win={win_rate:5.1f}% | Sharpe={sharpe:.2f}")

            if is_best:
                best_sharpe = sharpe
                best_return = net_return
                best_state = {k: v.clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"\n   Loaded best model: Sharpe={best_sharpe:.2f}, Return={best_return:.2f}%")

    # Final evaluation with Dropout disabled so greedy actions are deterministic.
    model.eval()
    env, actions_count = _evaluate_greedy(model)

    net_return = (env.capital - 10000) / 100
    gross_return = env.gross_pnl * 100
    fees_paid = env.total_fees
    win_rate = env.wins / env.trades * 100 if env.trades > 0 else 0
    sharpe = env.get_sharpe()

    print(f"\n   FINAL {model_type.upper()}:")
    print(f"   Net Return:   {net_return:+.2f}%")
    print(f"   Gross Return: {gross_return:+.2f}%")
    print(f"   Fees Paid:    EUR {fees_paid:.2f}")
    print(f"   Trades:       {env.trades}")
    print(f"   Win Rate:     {win_rate:.1f}%")
    print(f"   Sharpe Ratio: {sharpe:.2f}")
    print(f"   Actions: FLAT={actions_count[0]}, LONG={actions_count[1]}, SHORT={actions_count[2]}")

    return model, {
        'net_return': net_return,
        'gross_return': gross_return,
        'fees': fees_paid,
        'trades': env.trades,
        'win_rate': win_rate,
        'sharpe': sharpe,
        'actions': actions_count
    }


# ============================================================
# 5. MAIN TRAINING LOOP
# ============================================================
print(
    "\n" + "=" * 70,
    f"   Training All Models ({CONFIG['EPISODES']} episodes each)",
    "=" * 70,
    sep="\n",
)

# One model per regime label; metrics and trained networks are keyed by label.
results = {}
all_models = {}

for regime_name in ('bullish', 'bearish', 'ranging', 'volatile', 'scalper'):
    trained_model, regime_metrics = train_model(regime_name, episodes=CONFIG['EPISODES'])
    results[regime_name] = regime_metrics
    all_models[regime_name] = trained_model

# ============================================================
# 6. SAVE MODELS
# ============================================================
print("\n" + "=" * 70)
print("   Saving Models")
print("=" * 70)

import os
os.makedirs('/kaggle/working', exist_ok=True)

for regime, model in all_models.items():
    # Derive the architecture metadata from the model itself so the exported
    # JSON cannot drift from TradingNN if its layer sizes change.
    linear_layers = [m for m in model.modules() if isinstance(m, nn.Linear)]

    # Convert to format expected by orchestrator
    weights = {
        'input_size': linear_layers[0].in_features,
        'hidden_sizes': [layer.out_features for layer in linear_layers[:-1]],
        'output_size': linear_layers[-1].out_features,
        'weights': {
            name: tensor.cpu().numpy().tolist()
            for name, tensor in model.state_dict().items()
        },
    }

    filename = f'/kaggle/working/model_{regime}_v9.json'
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(weights, f)
    print(f"Saved {filename}")

# ============================================================
# 7. AUTO-UPLOAD TO PRODUCTION
# ============================================================
print("\n" + "=" * 70)
print("   Uploading to Production")
print("=" * 70)

UPLOAD_URL = "https://cuttalo.com/upload-model.php"

# NOTE(review): every trained model is uploaded regardless of its evaluation
# metrics (a model with negative net return is still pushed); consider gating
# on results before deploying.
# NOTE(review): the request carries no authentication; confirm the endpoint
# restricts who can replace production models.
for regime in all_models:
    filename = f'/kaggle/working/model_{regime}_v9.json'
    with open(filename, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Upload is best-effort: network failures are reported and skipped, while
    # unexpected programming errors still surface.
    try:
        # json= serializes the body and sets Content-Type: application/json.
        response = requests.post(
            UPLOAD_URL,
            params={'regime': regime},
            json=data,
            timeout=30
        )
    except requests.RequestException as e:
        print(f"  {regime}: ERROR - {e}")
        continue

    if response.status_code == 200:
        print(f"  {regime}: UPLOADED")
    else:
        print(f"  {regime}: FAILED ({response.status_code})")

# ============================================================
# 8. SUMMARY
# ============================================================
print("\n" + "=" * 70, "   SUMMARY - FEE-AWARE MODELS", "=" * 70, sep="\n")

# Passive benchmark: buy at the first aligned price, sell at the last.
first_price, last_price = prices_aligned[0], prices_aligned[-1]
buy_hold = (last_price - first_price) / first_price * 100
print(f"Buy & Hold: {buy_hold:+.2f}%")
print(f"Data: {len(prices_aligned):,} samples ({start_date.date()} to {end_date.date()})")
print(f"Fee Rate: {CONFIG['FEE_ROUND_TRIP']*100:.2f}% round-trip\n")

print(f"{'Model':<12} {'Net Return':>12} {'Gross':>10} {'Trades':>8} {'Win Rate':>10} {'Sharpe':>8} {'Status':>10}")
print("-" * 75)

for name, row in results.items():
    status = "LOSS" if not row['net_return'] > 0 else "PROFIT"
    print(f"{name.upper():<12} {row['net_return']:>+11.2f}% {row['gross_return']:>+9.2f}% {row['trades']:>8d} {row['win_rate']:>9.1f}% {row['sharpe']:>8.2f} {status:>10}")

net_returns = [row['net_return'] for row in results.values()]
sharpes = [row['sharpe'] for row in results.values()]
avg_return = np.mean(net_returns)
avg_sharpe = np.mean(sharpes)
print("-" * 75)
print(f"{'AVERAGE':<12} {avg_return:>+11.2f}% {'':>10} {'':>8} {'':>10} {avg_sharpe:>8.2f}")

print(
    "\n" + "=" * 70,
    "   TRAINING COMPLETE!",
    "   Models uploaded to: cuttalo.com/upload-model.php",
    "   Restart orchestrator to load new models",
    "=" * 70,
    sep="\n",
)
