#!/usr/bin/env python3
"""
GPU Batch Training - TradingAI LSTM Models
Trains all missing models on GPU and exports in TF.js format.
Run on RunPod GPU pod.
"""
import os, sys, time, json, struct, shutil
import numpy as np
import requests

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf

# Check GPU: list the devices TensorFlow can see. `gpus` stays a module-level
# name because main() reads it for its start-up banner.
gpus = tf.config.list_physical_devices('GPU')
print("GPUs available:", len(gpus))
for g in gpus:
    print(" ", g)

# ─── Config ───
BINANCE_BASE = "https://api.binance.com/api/v3"  # Binance spot REST API root
OUTPUT_DIR = "/workspace/models"                  # one sub-directory per exported model
FEATURES_COUNT = 14                               # width of each calculate_features() row / LSTM input
LOG_RETURN_INDEX = 0                              # feature column used as the training target

SYMBOLS = [
    'BTC/EUR', 'ETH/EUR', 'SOL/EUR', 'ADA/EUR', 'XRP/EUR',
    'DOT/EUR', 'LINK/EUR', 'AVAX/EUR', 'DOGE/EUR', 'POL/EUR',
]
TIMEFRAMES = ['5m', '15m', '1h', '4h', '6h', '8h', '1d', '2d']

# Binance spot symbol = the pair with the separator removed ('BTC/EUR' -> 'BTCEUR').
PAIR_MAP = {symbol: symbol.replace('/', '') for symbol in SYMBOLS}

# Local timeframe -> Binance kline interval. Binance offers no 2d interval, so
# the '2d' timeframe is served with 3d klines; every other name passes through.
TF_MAP = {name: ('3d' if name == '2d' else name) for name in TIMEFRAMES}

# Nominal candle length in milliseconds per local timeframe.
# Note: '2d' is two days here even though TF_MAP requests 3d klines for it.
CANDLE_MS = {
    '5m': 5 * 60 * 1000,
    '15m': 15 * 60 * 1000,
    '1h': 1 * 3600 * 1000,
    '4h': 4 * 3600 * 1000,
    '6h': 6 * 3600 * 1000,
    '8h': 8 * 3600 * 1000,
    '1d': 24 * 3600 * 1000,
    '2d': 48 * 3600 * 1000,
}

# Hyper-parameters used for every job started by main().
DEFAULT_HP = dict(
    sequenceLength=30,
    lstmUnits=32,
    dropoutRate=0.15,
    learningRate=0.001,
    batchSize=64,
    epochs=25,
    mcSamples=20,
)

# ─── Binance Data ───
def fetch_candles(symbol, timeframe, months=6, max_candles=5000):
    """Download the most recent klines for ``symbol`` at ``timeframe`` from Binance.

    Pages backwards from the current time through ``GET /klines`` (at most
    1000 candles per request) until roughly ``months`` 30-day months of
    history or ``max_candles`` candles are collected, or Binance returns no
    older data.

    Args:
        symbol: key of PAIR_MAP, e.g. ``'BTC/EUR'``.
        timeframe: local timeframe name, translated through TF_MAP.
        months: history depth to aim for, in 30-day months.
        max_candles: hard cap on the number of candles returned.

    Returns:
        List of ``{'timestamp', 'open', 'high', 'low', 'close', 'volume'}``
        dicts (timestamp = kline open time in ms), sorted ascending,
        de-duplicated by timestamp, holding at most the newest ``target``
        candles.

    Raises:
        ValueError: ``symbol`` is not in PAIR_MAP.
        requests.RequestException: network failure or non-2xx response.
    """
    pair = PAIR_MAP.get(symbol)
    if not pair:
        raise ValueError(f"Unknown symbol: {symbol}")
    binance_tf = TF_MAP.get(timeframe, timeframe)

    # Size request windows by the interval Binance actually serves. TF_MAP
    # sends '2d' as '3d' (Binance has no 2d klines); sizing with
    # CANDLE_MS['2d'] made every window a third too short and the target too
    # large. Fall back to CANDLE_MS for intervals this parser does not know.
    unit_ms = {'m': 60000, 'h': 3600000, 'd': 86400000, 'w': 604800000}
    try:
        interval_ms = int(binance_tf[:-1]) * unit_ms[binance_tf[-1]]
    except (ValueError, KeyError, IndexError, TypeError):
        interval_ms = CANDLE_MS.get(timeframe, 3600000)

    target = min(int((months * 30 * 24 * 3600 * 1000) / interval_ms), max_candles)
    all_c = []
    end_time = int(time.time() * 1000)
    while len(all_c) < target:
        batch = min(1000, target - len(all_c))
        start_time = end_time - batch * interval_ms
        resp = requests.get(
            f"{BINANCE_BASE}/klines",
            params={
                'symbol': pair,
                'interval': binance_tf,
                'startTime': start_time,
                'endTime': end_time,
                'limit': batch,
            },
            timeout=15,
        )
        resp.raise_for_status()
        data = resp.json()
        if not data:
            break  # nothing older is available
        for k in data:
            all_c.append({'timestamp': int(k[0]), 'open': float(k[1]), 'high': float(k[2]),
                          'low': float(k[3]), 'close': float(k[4]), 'volume': float(k[5])})
        # Continue just before the oldest candle received. A short page (for
        # example a maintenance gap inside the window) is NOT the end of
        # history, so keep paging until the target is met or Binance returns
        # nothing. end_time strictly decreases and all_c grows each pass.
        end_time = int(data[0][0]) - 1
        time.sleep(0.1)

    all_c.sort(key=lambda c: c['timestamp'])
    seen = set()
    unique = []
    for c in all_c:
        if c['timestamp'] not in seen:
            seen.add(c['timestamp'])
            unique.append(c)
    return unique[-target:]

# ─── Features ───
def _ema(values, period):
    result = [0.0] * len(values)
    m = 2.0 / (period + 1)
    result[0] = values[0]
    for i in range(1, len(values)):
        result[i] = (values[i] - result[i-1]) * m + result[i-1]
    return result

def _rsi(closes, idx, period=14):
    if idx < period: return 50.0
    gains, losses = [], []
    for i in range(idx - period + 1, idx + 1):
        ch = closes[i] - closes[i-1]
        gains.append(max(0, ch)); losses.append(max(0, -ch))
    ag = sum(gains)/period; al = sum(losses)/period
    if al == 0: return 100.0
    return 100.0 - (100.0 / (1 + ag/al))

def _bollinger_pct_b(closes, idx, period=20):
    if idx < period - 1: return 0.5
    w = closes[idx-period+1:idx+1]
    mean = sum(w)/len(w); std = float(np.std(w))
    if std == 0: return 0.5
    upper = mean + 2*std; lower = mean - 2*std
    return (closes[idx]-lower)/(upper-lower) if (upper-lower) > 0 else 0.5

def _atr_normalized(candles, idx, period=14):
    if idx < 1: return 0.0
    trs = []
    for i in range(max(1, idx-period+1), idx+1):
        tr = max(candles[i]['high']-candles[i]['low'],
                 abs(candles[i]['high']-candles[i-1]['close']),
                 abs(candles[i]['low']-candles[i-1]['close']))
        trs.append(tr)
    atr = sum(trs)/len(trs) if trs else 0
    return atr / candles[idx]['close'] if candles[idx]['close'] > 0 else 0

def _obv_change(closes, volumes, idx, period=10):
    if idx < period: return 0.0
    obv = 0; obv_prev = 0
    for i in range(idx-period, idx+1):
        if i > 0:
            if closes[i] > closes[i-1]: obv += volumes[i]
            elif closes[i] < closes[i-1]: obv -= volumes[i]
        if i == idx-1: obv_prev = obv
    return (obv - obv_prev)/abs(obv_prev) if obv_prev != 0 else 0

def calculate_features(candles):
    """Build the per-candle feature matrix used as model input.

    Args:
        candles: list of dicts with 'high', 'low', 'close' and 'volume' keys
            (the shape fetch_candles returns). Expected in chronological
            order: index i-1 is treated as the candle before index i.

    Returns:
        A list with one 14-element row per candle, columns in this order:
          0 log_return   ln(close / previous close)
          1 volatility   std of the log returns of candles i-9..i (0.0 for i < 10)
          2 vol_change   relative change of volume vs. the previous candle
          3 rsi          14-period RSI (simple averages; 50.0 during warm-up)
          4 macd_line    EMA12 - EMA26 of closes
          5 macd_signal  EMA9 of the macd_line series
          6 macd_hist    macd_line - macd_signal
          7 bb_pct       Bollinger %B (20 periods, +/- 2 std)
          8 atr          14-period average true range divided by close
          9 obv_ch       OBV change at this candle relative to the OBV
                         accumulated over the preceding 10-candle window
         10 momentum     10-candle rate of change of close
         11 price_pos    close position within this candle's high-low range
         12 sma20_dist   relative distance of close from its 20-period SMA
         13 vol_ratio    volume / 20-period mean volume (1.0 during warm-up)
        Rows whose look-back window is not yet full get the fixed defaults
        coded below instead of NaN.

    NOTE(review): the column layout must stay in sync with FEATURES_COUNT and
    LOG_RETURN_INDEX, and presumably with whatever computes features at
    inference time (using the scaler in metadata.json). Confirm before
    changing any formula.
    """
    n = len(candles)
    closes = [c['close'] for c in candles]
    highs = [c['high'] for c in candles]
    lows = [c['low'] for c in candles]
    volumes = [c['volume'] for c in candles]
    # MACD inputs are computed once for the whole series; the signal line is an
    # EMA9 over the full MACD-line series.
    ema12 = _ema(closes, 12); ema26 = _ema(closes, 26)
    ema9_of_macd = _ema([ema12[i]-ema26[i] for i in range(n)], 9)
    features = []
    for i in range(n):
        # The first candle uses itself as "previous", giving a zero log return.
        close = closes[i]; prev_close = closes[i-1] if i > 0 else close
        log_return = np.log(close/prev_close) if prev_close > 0 and close > 0 else 0.0
        # 10 log returns (j = i-9..i); i >= 10 keeps j-1 >= 0.
        volatility = float(np.std([np.log(closes[j]/closes[j-1]) if closes[j-1]>0 else 0 for j in range(i-9,i+1)])) if i >= 10 else 0.0
        prev_vol = volumes[i-1] if i > 0 else volumes[i]
        vol_change = (volumes[i]-prev_vol)/prev_vol if prev_vol > 0 else 0.0
        rsi = _rsi(closes, i, 14)
        macd_line = ema12[i]-ema26[i]; macd_signal = ema9_of_macd[i]; macd_hist = macd_line-macd_signal
        bb_pct = _bollinger_pct_b(closes, i, 20)
        atr = _atr_normalized(candles, i, 14)
        obv_ch = _obv_change(closes, volumes, i, 10)
        momentum = (close-closes[i-10])/closes[i-10] if i >= 10 and closes[i-10] > 0 else 0.0
        # 0.5 when the candle has no range (high == low).
        h = highs[i]; l = lows[i]; price_pos = (close-l)/(h-l) if (h-l) > 0 else 0.5
        # 20-period SMA of closes i-19..i; (close - sma) / sma.
        sma20_dist = ((close-sum(closes[i-19:i+1])/20)/(sum(closes[i-19:i+1])/20)) if i >= 19 else 0.0
        vol_sma = sum(volumes[i-19:i+1])/20 if i >= 19 else 0
        vol_ratio = volumes[i]/vol_sma if i >= 19 and vol_sma > 0 else 1.0
        # Column order here defines the model's input layout (see docstring).
        features.append([log_return,volatility,vol_change,rsi,macd_line,macd_signal,macd_hist,bb_pct,atr,obv_ch,momentum,price_pos,sma20_dist,vol_ratio])
    return features

def normalize(features):
    """Z-score every feature column.

    Returns ``(scaled, mean, std)``:
      * ``scaled`` is a float32 array with the same shape as ``features``;
      * ``mean`` and ``std`` are per-column lists of Python floats (population
        std). A zero-variance column gets std 1.0, so it is centred rather
        than divided by zero.
    """
    matrix = np.array(features, dtype=np.float32)
    centre = matrix.mean(axis=0)
    scale = matrix.std(axis=0)
    scale[scale == 0] = 1.0  # constant columns: avoid division by zero
    scaled = (matrix - centre) / scale
    return scaled, centre.tolist(), scale.tolist()

# ─── Model ───
def create_model(hp):
    """Build and compile the forecasting network described by ``hp``.

    Architecture: LSTM(lstmUnits, full sequence) -> Dropout ->
    LSTM(max(8, lstmUnits // 2), last step) -> Dropout -> Dense(16, relu) ->
    Dropout -> Dense(1, linear). The input shape is
    (sequenceLength, FEATURES_COUNT). The model is compiled with Adam(learningRate)
    and MSE loss, and tracks MAE.

    NOTE: a non-zero ``recurrent_dropout`` makes Keras use the generic LSTM
    kernel instead of cuDNN, which is considerably slower on GPU. It is kept
    as-is because it is part of the trained model's behaviour (presumably
    relied on for MC-dropout sampling, cf. hp['mcSamples'] -- not visible here).
    """
    rate = hp['dropoutRate']
    units = hp['lstmUnits']
    layers = tf.keras.layers
    stack = [
        layers.LSTM(units, input_shape=(hp['sequenceLength'], FEATURES_COUNT),
                    return_sequences=True, dropout=rate, recurrent_dropout=rate),
        layers.Dropout(rate),
        layers.LSTM(max(8, units // 2), return_sequences=False,
                    dropout=rate, recurrent_dropout=rate),
        layers.Dropout(rate),
        layers.Dense(16, activation='relu'),
        layers.Dropout(rate),
        layers.Dense(1),
    ]
    model = tf.keras.Sequential(stack)
    optimizer = tf.keras.optimizers.Adam(learning_rate=hp['learningRate'])
    model.compile(optimizer=optimizer, loss='mse', metrics=['mae'])
    return model

# ─── TF.js Export ───
def export_tfjs(model, scaler_mean, scaler_std, hp, output_dir):
    """Export model in TF.js format (model.json + weights.bin + metadata.json).

    Writes into ``output_dir`` (created if missing):
      * ``weights.bin``: every tensor of ``model.weights`` as little-endian
        float32, concatenated in ``model.weights`` order;
      * ``model.json``: ``{'format': 'layers-model', 'modelTopology': ...,
        'weightsManifest': [...]}``. The topology is ``model.to_json()`` and
        the manifest lists name/shape/dtype in the same order as weights.bin;
      * ``metadata.json``: model key, scaler statistics, ``hp`` and the
        export time in epoch milliseconds.

    Args:
        model: trained Keras model (anything exposing ``.weights``, whose
            items have ``.name`` and ``.numpy()``, and ``.to_json()``).
        scaler_mean, scaler_std: per-feature normalisation statistics
            (must be JSON-serialisable, e.g. lists of floats).
        hp: hyper-parameter dict, stored verbatim in the metadata.
        output_dir: destination directory. Its basename (e.g. ``BTC_EUR_4h``)
            becomes the metadata key ``BTC/EUR_4h``.

    NOTE(review): ``model.to_json()`` from Keras 3 uses a different config
    schema (and different variable naming) than Keras 2; confirm that the
    consuming TF.js version can parse the topology produced here.
    """
    os.makedirs(output_dir, exist_ok=True)

    chunks = []
    weight_specs = []
    for w in model.weights:
        # TF.js reads weight shards as little-endian float32 regardless of host.
        w_np = np.asarray(w.numpy(), dtype='<f4')
        name = w.name
        # Keras 2 variable names end in the ':0' output slot, which TF.js
        # weight names do not have (tfjs-converter strips it too). Without
        # this, a strict tf.loadLayersModel cannot match the weights.
        if name.endswith(':0'):
            name = name[:-2]
        weight_specs.append({
            'name': name,
            'shape': list(w_np.shape),
            'dtype': 'float32',
        })
        chunks.append(w_np.tobytes())

    # Single join instead of repeated bytes += (which copies every time).
    with open(os.path.join(output_dir, 'weights.bin'), 'wb') as f:
        f.write(b''.join(chunks))

    model_json = {
        'format': 'layers-model',
        'modelTopology': json.loads(model.to_json()),
        'weightsManifest': [{
            'paths': ['weights.bin'],
            'weights': weight_specs,
        }],
    }
    with open(os.path.join(output_dir, 'model.json'), 'w') as f:
        json.dump(model_json, f)

    # normpath so a trailing separator does not yield an empty basename/key.
    dir_base = os.path.basename(os.path.normpath(output_dir))
    metadata = {
        'key': dir_base.replace('_', '/', 1),  # BTC_EUR_4h -> BTC/EUR_4h
        'scaler': {'mean': scaler_mean, 'std': scaler_std},
        'hyperParams': hp,
        'lastTrained': int(time.time() * 1000),
    }
    with open(os.path.join(output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f)

# ─── Train Single Model ───
def train_model(symbol, timeframe, hp):
    """Fetch data, train one model for ``symbol``/``timeframe`` and export it.

    Pipeline: fetch up to 5000 Binance candles (6 months) -> compute and
    z-score features -> build sliding windows -> chronological 80/20 split ->
    fit -> print validation directional accuracy -> export TF.js artifacts to
    ``OUTPUT_DIR/<BASE>_<QUOTE>_<timeframe>``.

    Args:
        symbol: pair key such as ``'BTC/EUR'``.
        timeframe: local timeframe name such as ``'4h'``.
        hp: hyper-parameter dict with the DEFAULT_HP keys.

    Returns:
        True on success. False if there was too little data or any step
        raised; the error is printed rather than re-raised so a batch run
        keeps going.
    """
    tag = f"{symbol} {timeframe}"
    dir_name = f"{symbol.replace('/','_')}_{timeframe}"
    output_dir = os.path.join(OUTPUT_DIR, dir_name)
    start = time.time()

    try:
        # 1. Fetch data
        candles = fetch_candles(symbol, timeframe, months=6, max_candles=5000)
        if len(candles) < hp['sequenceLength'] + 20:
            print(f"  ⚠ {tag}: Not enough data ({len(candles)} candles)")
            return False

        # 2. Features & normalize.
        # NOTE: the scaler statistics cover all rows, including the
        # validation period.
        features = calculate_features(candles)
        normalized, mean, std = normalize(features)

        # 3. Sliding windows: feature rows [end - seq_len, end) predict the
        #    log return of row `end`, i.e. the very next candle. The previous
        #    indexing paired rows [i - seq_len, i) with row i + 1, silently
        #    skipping one candle and training a two-step-ahead model.
        seq_len = hp['sequenceLength']
        X = np.array([normalized[end - seq_len:end]
                      for end in range(seq_len, len(normalized))], dtype=np.float32)
        y = np.asarray(normalized[seq_len:, LOG_RETURN_INDEX], dtype=np.float32).reshape(-1, 1)

        # Chronological split: the most recent 20% of samples are validation.
        split = int(len(X) * 0.8)
        X_train, X_val = X[:split], X[split:]
        y_train, y_val = y[:split], y[split:]

        # 4. Train
        model = create_model(hp)
        model.fit(X_train, y_train, validation_data=(X_val, y_val),
                  epochs=hp['epochs'], batch_size=hp['batchSize'], shuffle=True, verbose=0)

        # 5. Directional accuracy: share of validation samples whose predicted
        #    and actual de-normalised log returns have the same sign
        #    (zero counts as non-negative).
        lr_mean, lr_std = mean[LOG_RETURN_INDEX], std[LOG_RETURN_INDEX]
        pred_lr = model.predict(X_val, verbose=0).flatten() * lr_std + lr_mean
        actual_lr = y_val.flatten() * lr_std + lr_mean
        da = float(np.mean((pred_lr >= 0) == (actual_lr >= 0)))

        # 6. Export TF.js
        export_tfjs(model, mean, std, hp, output_dir)

        elapsed = time.time() - start
        print(f"  ✅ {tag:14s} {len(candles):5d} candles  DA={da*100:.1f}%  {elapsed:.1f}s")
        return True

    except Exception as e:
        # Batch boundary: report and let main() continue with the next job.
        elapsed = time.time() - start
        print(f"  ❌ {tag:14s} ERROR: {e}  ({elapsed:.1f}s)")
        return False

    finally:
        # Reset Keras global state on every path. Previously only the success
        # path did this, so each failed job leaked its graph/GPU memory into
        # the next one.
        tf.keras.backend.clear_session()

# ─── Main ───
def main():
    """Train every SYMBOLS x TIMEFRAMES model with DEFAULT_HP and export it to OUTPUT_DIR.

    Usage: ``script.py [--skip-existing DIR]``. With ``--skip-existing``, any
    job whose directory name (e.g. ``BTC_EUR_4h``) already exists in DIR is
    skipped. Jobs run sequentially; a failed job is counted and the batch
    continues. Prints a summary and the exported model directories.
    """
    # Check which models to skip (already exist)
    skip = set()
    if len(sys.argv) > 1 and sys.argv[1] == '--skip-existing':
        existing_dir = sys.argv[2] if len(sys.argv) > 2 else ''
        if existing_dir and os.path.isdir(existing_dir):
            skip = set(os.listdir(existing_dir))
            print(f"Skipping {len(skip)} existing models from {existing_dir}")
        else:
            # Previously ignored silently; make it obvious nothing is skipped.
            print(f"⚠ --skip-existing: {existing_dir!r} is not a directory; nothing will be skipped")

    # Build job list
    jobs = []
    for symbol in SYMBOLS:
        for tf_name in TIMEFRAMES:
            dir_name = f"{symbol.replace('/','_')}_{tf_name}"
            if dir_name not in skip:
                jobs.append((symbol, tf_name))

    print(f"\n{'='*60}")
    print(" TradingAI GPU Batch Training")
    print(f" Models to train: {len(jobs)}")
    print(f" Output: {OUTPUT_DIR}")
    print(f" GPU: {gpus[0].name if gpus else 'CPU'}")
    print(f"{'='*60}\n")

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    ok = 0
    fail = 0
    total_start = time.time()

    for i, (symbol, tf_name) in enumerate(jobs):
        print(f"[{i+1}/{len(jobs)}] Training {symbol} {tf_name}...")
        if train_model(symbol, tf_name, DEFAULT_HP):
            ok += 1
        else:
            fail += 1

    total_time = time.time() - total_start
    print(f"\n{'='*60}")
    print(f" COMPLETE in {total_time:.0f}s ({total_time/60:.1f} min)")
    print(f" Success: {ok}/{len(jobs)}")
    print(f" Failed:  {fail}/{len(jobs)}")
    print(f" Output:  {OUTPUT_DIR}")
    print(f"{'='*60}")

    # List outputs. Only directories are model exports; a stray plain file in
    # OUTPUT_DIR used to crash os.listdir() here with NotADirectoryError.
    if os.path.isdir(OUTPUT_DIR):
        dirs = sorted(d for d in os.listdir(OUTPUT_DIR)
                      if os.path.isdir(os.path.join(OUTPUT_DIR, d)))
        print(f"\nExported models ({len(dirs)}):")
        for d in dirs:
            files = sorted(os.listdir(os.path.join(OUTPUT_DIR, d)))
            print(f"  {d}: {', '.join(files)}")

if __name__ == '__main__':
    # Run the batch only when executed as a script. Importing the module still
    # executes its top-level code (e.g. the TensorFlow import and GPU listing).
    main()
