#!/usr/bin/env python3
"""
V9 Model Training - Complete Pipeline
=====================================
Trains the BTC/EUR trading model on all available historical price data, holding out the most recent 20% of samples as a test set.
"""

import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
import json
import warnings
warnings.filterwarnings('ignore')

# ML Libraries
import lightgbm as lgb
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import joblib

# Configuration
# Absolute paths on the deployment host. DATA_DIR holds the input
# prices_BTC_EUR_*.csv files read by load_data(); MODEL_DIR receives one
# timestamped sub-directory per save_model() call plus a 'latest' symlink.
DATA_DIR = Path('/var/www/html/pippo.cuttalo.com/data')
MODEL_DIR = Path('/var/www/html/pippo.cuttalo.com/models')
# NOTE: runs at import time. Without parents=True this raises
# FileNotFoundError when the parent directory does not exist.
MODEL_DIR.mkdir(exist_ok=True)

# Feature engineering parameters
# Window lengths (in rows, i.e. candles) for the return/volatility/momentum features.
LOOKBACK_PERIODS = [5, 10, 20, 50, 100, 200]
RSI_PERIOD = 14
# MACD fast EMA span, slow EMA span and signal-line EMA span.
MACD_FAST = 12
MACD_SLOW = 26
MACD_SIGNAL = 9
# Bollinger Band window length and band width in standard deviations.
BB_PERIOD = 20
BB_STD = 2


def _parse_timestamps(raw):
    """Convert a raw ``timestamp`` column to datetime64.

    Numeric columns, or columns whose first value is a digit-only string, are
    treated as Unix epochs. The unit is decided from the first value:
    milliseconds if it is above 1e12, seconds otherwise. Anything else is
    parsed as a datetime string.
    """
    first = raw.iloc[0]
    is_epoch = pd.api.types.is_numeric_dtype(raw) or str(first).strip().isdigit()
    if not is_epoch:
        return pd.to_datetime(raw)

    ts = pd.to_numeric(raw)
    unit = 'ms' if ts.iloc[0] > 1e12 else 's'
    return pd.to_datetime(ts, unit=unit)


def load_data(data_dir=None):
    """Load and combine all BTC/EUR price CSV files.

    Every ``prices_BTC_EUR_*.csv`` file in *data_dir* is read in sorted
    filename order. Its ``timestamp`` column is parsed to datetime64 and only
    the OHLCV columns are kept. The frames are then concatenated, sorted by
    time and de-duplicated on ``timestamp``. For duplicates, the row from the
    earliest file in filename order is kept.

    Args:
        data_dir: Directory to search. Defaults to the module-level DATA_DIR.

    Returns:
        DataFrame with columns timestamp, open, high, low, close, volume and
        a fresh RangeIndex.

    Raises:
        FileNotFoundError: if no non-empty matching file is found.
        KeyError: if a file lacks one of the required columns.
    """
    data_dir = DATA_DIR if data_dir is None else Path(data_dir)
    print("Loading price data...")

    dfs = []
    for f in sorted(data_dir.glob('prices_BTC_EUR_*.csv')):
        print(f"  Loading {f.name}...")
        df = pd.read_csv(f)
        if df.empty:
            # A header-only file has no first row to sniff the format from.
            print(f"    Skipping empty file {f.name}")
            continue

        # Handle different timestamp formats (epoch s/ms or datetime strings)
        df['timestamp'] = _parse_timestamps(df['timestamp'])

        # Keep only necessary columns
        dfs.append(df[['timestamp', 'open', 'high', 'low', 'close', 'volume']])

    if not dfs:
        raise FileNotFoundError(
            f"No non-empty prices_BTC_EUR_*.csv files found in {data_dir}"
        )

    df = pd.concat(dfs, ignore_index=True)
    # A stable sort keeps file order among equal timestamps, so
    # drop_duplicates deterministically keeps the earliest file's row.
    df = (
        df.sort_values('timestamp', kind='mergesort')
        .drop_duplicates(subset='timestamp')
        .reset_index(drop=True)
    )

    print(f"Total candles: {len(df):,}")
    print(f"Date range: {df['timestamp'].min()} to {df['timestamp'].max()}")

    return df


def calculate_rsi(prices, period=14):
    """Return the Relative Strength Index of *prices*.

    Gains and losses are averaged with simple rolling means over *period*
    bars, not Wilder smoothing. The result is NaN during the warm-up window
    and wherever both averages are zero. It is 100 when the average loss is
    zero but the average gain is positive.
    """
    change = prices.diff()
    # where() turns diff()'s leading NaN into 0, so the first full window
    # ends at index period - 1.
    up_moves = change.where(change > 0, 0)
    down_moves = -change.where(change < 0, 0)

    avg_gain = up_moves.rolling(window=period).mean()
    avg_loss = down_moves.rolling(window=period).mean()

    strength = avg_gain / avg_loss
    return 100 - 100 / (1 + strength)


def calculate_macd(prices, fast=12, slow=26, signal=9):
    """Return the MACD line, its signal line and the histogram.

    Every average is a recursive EMA (``adjust=False``) with the given span.

    Returns:
        Tuple ``(macd, macd_signal, macd_hist)`` of Series aligned with *prices*.
    """
    def ema(series, span):
        return series.ewm(span=span, adjust=False).mean()

    macd_line = ema(prices, fast) - ema(prices, slow)
    signal_line = ema(macd_line, signal)
    return macd_line, signal_line, macd_line - signal_line


def calculate_bollinger_bands(prices, period=20, std=2):
    """Return Bollinger Bands ``(upper, middle, lower)`` for *prices*.

    The middle band is the *period*-bar simple moving average. The outer
    bands sit *std* rolling sample standard deviations (pandas default,
    ddof=1) above and below it.
    """
    window = prices.rolling(window=period)
    middle = window.mean()
    offset = window.std() * std
    return middle + offset, middle, middle - offset


def calculate_atr(high, low, close, period=14):
    """Return the Average True Range as a simple rolling mean of true range.

    A bar's true range is the largest of high - low, |high - previous close|
    and |low - previous close|. On the first bar the previous close is NaN.
    The row-wise max skips those NaN terms, so the true range there is just
    high - low.
    """
    prev_close = close.shift(1)
    candidates = [
        high - low,
        (high - prev_close).abs(),
        (low - prev_close).abs(),
    ]
    true_range = pd.concat(candidates, axis=1).max(axis=1)
    return true_range.rolling(window=period).mean()


def create_features(df):
    """Create all features for the model.

    Args:
        df: DataFrame with 'open', 'high', 'low', 'close', 'volume' and a
            datetime64 'timestamp' column (hour/day-of-week are read via
            ``.dt``), indexed like the rows to be modelled.

    Returns:
        DataFrame indexed like *df* with one column per feature. Rows inside
        indicator warm-up windows contain NaN; the caller is expected to drop
        them.

    Helper series (simple moving averages, volume SMA) are kept in local
    variables, so they never become feature columns.
    """
    print("Creating features...")

    open_ = df['open']
    close = df['close']
    high = df['high']
    low = df['low']
    volume = df['volume']

    features = pd.DataFrame(index=df.index)

    # Price returns at different timeframes
    one_bar_return = close.pct_change()
    for period in LOOKBACK_PERIODS:
        features[f'return_{period}'] = close.pct_change(period)
        features[f'volatility_{period}'] = one_bar_return.rolling(period).std()
        # NOTE: same formula as return_{period} (pct_change may differ only
        # where close has NaNs); kept so the saved feature list is unchanged.
        features[f'momentum_{period}'] = close / close.shift(period) - 1

    # Distance of price from its simple moving averages. The SMAs themselves
    # are only needed here and for the crossover flags below.
    sma = {}
    for period in [10, 20, 50, 100, 200]:
        sma[period] = close.rolling(period).mean()
        features[f'price_vs_sma_{period}'] = close / sma[period] - 1

    # MA Crossovers: 1 while the shorter SMA is above the longer one
    # (0 during warm-up, where the comparison involves NaN).
    features['sma_10_50_cross'] = (sma[10] > sma[50]).astype(int)
    features['sma_20_100_cross'] = (sma[20] > sma[100]).astype(int)
    features['sma_50_200_cross'] = (sma[50] > sma[200]).astype(int)

    # RSI
    rsi = calculate_rsi(close, RSI_PERIOD)
    features['rsi'] = rsi
    features['rsi_oversold'] = (rsi < 30).astype(int)
    features['rsi_overbought'] = (rsi > 70).astype(int)

    # MACD
    macd, macd_signal, macd_hist = calculate_macd(close, MACD_FAST, MACD_SLOW, MACD_SIGNAL)
    features['macd'] = macd
    features['macd_signal'] = macd_signal
    features['macd_hist'] = macd_hist
    features['macd_cross_up'] = ((macd > macd_signal) & (macd.shift(1) <= macd_signal.shift(1))).astype(int)
    features['macd_cross_down'] = ((macd < macd_signal) & (macd.shift(1) >= macd_signal.shift(1))).astype(int)

    # Bollinger Bands
    bb_upper, bb_mid, bb_lower = calculate_bollinger_bands(close, BB_PERIOD, BB_STD)
    features['bb_position'] = (close - bb_lower) / (bb_upper - bb_lower)
    features['bb_width'] = (bb_upper - bb_lower) / bb_mid
    features['price_vs_bb_upper'] = close / bb_upper - 1
    features['price_vs_bb_lower'] = close / bb_lower - 1

    # ATR
    atr = calculate_atr(high, low, close, 14)
    features['atr'] = atr
    features['atr_pct'] = atr / close

    # Volume features
    volume_sma_20 = volume.rolling(20).mean()
    features['volume_ratio'] = volume / volume_sma_20
    features['volume_change'] = volume.pct_change()

    # Price patterns
    features['higher_high'] = (high > high.shift(1)).astype(int)
    features['lower_low'] = (low < low.shift(1)).astype(int)
    # The 0.0001 keeps the ratios finite on zero-range (high == low) candles.
    candle_range = high - low + 0.0001
    body = pd.concat([close, open_], axis=1)
    body_top = body.max(axis=1)
    body_bottom = body.min(axis=1)
    features['body_size'] = abs(close - open_) / candle_range
    features['upper_wick'] = (high - body_top) / candle_range
    features['lower_wick'] = (body_bottom - low) / candle_range

    # Trend strength
    features['adx'] = calculate_adx(high, low, close, 14)

    # Time features
    features['hour'] = df['timestamp'].dt.hour
    features['day_of_week'] = df['timestamp'].dt.dayofweek
    features['is_weekend'] = (features['day_of_week'] >= 5).astype(int)

    print(f"Created {len(features.columns)} features")
    return features


def calculate_adx(high, low, close, period=14):
    """Calculate ADX (Average Directional Index).

    Steps:
      * +DM is the positive part of ``high.diff()`` and -DM the positive part
        of ``-low.diff()``.
      * Each is summed over *period* bars and divided by the true range
        summed over the same *period* bars, giving +DI / -DI.
      * DX = 100 * |+DI - -DI| / (+DI + -DI).
      * ADX is the *period*-bar simple moving average of DX.

    NOTE: unlike Wilder's definition, +DM and -DM are not made mutually
    exclusive (both can be non-zero on an outside bar), and simple rolling
    sums/means replace Wilder smoothing. ADX can still be NaN when every bar
    in a window has zero true range.
    """
    # NaN from the leading diff() is preserved, so the first full DM window
    # ends at index `period`.
    plus_dm = high.diff().clip(lower=0)
    # Downward move expressed as a positive magnitude (previous low - low).
    minus_dm = (-low.diff()).clip(lower=0)

    # Sum of true range over the same window as the DM sums. Using only the
    # current bar's range (ATR with period 1) would divide by zero on flat
    # bars and make DI inf/NaN.
    tr_sum = calculate_atr(high, low, close, period) * period

    plus_di = 100 * (plus_dm.rolling(period).sum() / tr_sum)
    minus_di = 100 * (minus_dm.rolling(period).sum() / tr_sum)

    dx = 100 * (plus_di - minus_di).abs() / (plus_di + minus_di + 0.0001)
    adx = dx.rolling(period).mean()

    return adx


def create_labels(df, future_periods=60, threshold=0.003):
    """
    Create labels for training.

    For each row t, the closes of the next *future_periods* rows
    (t+1 .. t+future_periods; the current row is excluded) are examined:
      1   = Long opportunity: the future max is at least *threshold* above
            close[t] and the upside exceeds the downside
      -1  = Short opportunity: close / future min - 1 is at least *threshold*
            and the downside exceeds the upside
      0   = No clear signal
      NaN = Unknown: the future window runs past the end of the data (or
            contains NaN), so no outcome exists. Callers must drop these rows.

    Returns:
        float Series aligned with df.index.
    """
    print(f"Creating labels (future={future_periods}m, threshold={threshold*100:.1f}%)...")

    close = df['close']

    # Max/min of the next `future_periods` closes: the rolling window ending
    # at t+N, shifted back to t.
    future_max = close.rolling(future_periods).max().shift(-future_periods)
    future_min = close.rolling(future_periods).min().shift(-future_periods)

    future_up = future_max / close - 1
    future_down = close / future_min - 1

    labels = pd.Series(0.0, index=df.index)

    # Long signal: future max return > threshold and better than down risk
    labels[(future_up >= threshold) & (future_up > future_down)] = 1

    # Short signal: future max drop > threshold and better than up potential
    labels[(future_down >= threshold) & (future_down > future_up)] = -1

    # Without a complete future window the comparisons above are all False,
    # which would silently mark these rows as neutral. Mark them unknown.
    known = future_max.notna() & future_min.notna() & close.notna()
    labels = labels.where(known)

    print(
        f"Labels distribution: Long={(labels == 1).sum():,}, "
        f"Short={(labels == -1).sum():,}, Neutral={(labels == 0).sum():,}, "
        f"Unlabeled={labels.isna().sum():,}"
    )

    return labels


def train_model(X, y, params=None, n_splits=5, gap=0):
    """Train LightGBM model with time series cross-validation.

    Args:
        X: Feature frame, rows in chronological order.
        y: Labels in {-1, 0, 1}, aligned with X. They are shifted to {0, 1, 2}
            for LightGBM's multiclass objective.
        params: LightGBM parameters. A multiclass default set is used when None.
        n_splits: Number of expanding-window TimeSeriesSplit folds.
        gap: Number of rows skipped between each training fold and its
            validation fold. When labels look N rows ahead, gap >= N keeps the
            end of the training fold from seeing validation-period prices.

    Returns:
        (model, scores): the booster trained on the last (largest) fold, and
        the list of per-fold validation accuracies.

    NOTE: early stopping monitors the same validation fold that is then
    scored, so the fold accuracies are somewhat optimistic.
    """
    print("\nTraining model...")

    if params is None:
        params = {
            'objective': 'multiclass',
            'num_class': 3,
            'metric': 'multi_logloss',
            'boosting_type': 'gbdt',
            'num_leaves': 63,
            'learning_rate': 0.05,
            'feature_fraction': 0.8,
            'bagging_fraction': 0.8,
            'bagging_freq': 5,
            'verbose': -1,
            'n_estimators': 500,
            # NOTE(review): combined with the explicit early_stopping callback
            # below; confirm the installed LightGBM does not register two.
            'early_stopping_rounds': 50,
        }

    # Expanding-window split; `gap` rows are left out between train and val.
    tscv = TimeSeriesSplit(n_splits=n_splits, gap=gap)

    scores = []
    models = []

    for fold, (train_idx, val_idx) in enumerate(tscv.split(X)):
        print(f"  Fold {fold + 1}/{n_splits}...")

        X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
        y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]

        # Convert labels from -1,0,1 to 0,1,2
        train_data = lgb.Dataset(X_train, label=y_train + 1)
        val_data = lgb.Dataset(X_val, label=y_val + 1, reference=train_data)

        model = lgb.train(
            params,
            train_data,
            valid_sets=[val_data],
            callbacks=[lgb.early_stopping(50), lgb.log_evaluation(0)]
        )

        # Evaluate: column k of the probabilities corresponds to label k - 1.
        y_pred_proba = model.predict(X_val)
        y_pred = np.argmax(y_pred_proba, axis=1) - 1

        acc = accuracy_score(y_val, y_pred)
        scores.append(acc)
        models.append(model)

        print(f"    Accuracy: {acc:.4f}")

    print(f"\nMean accuracy: {np.mean(scores):.4f} (+/- {np.std(scores):.4f})")

    # Return the model from the last fold: it was trained on the largest
    # window, but it is not necessarily the best-scoring fold.
    return models[-1], scores


def evaluate_model(model, X_test, y_test):
    """Evaluate model performance on a held-out set.

    Args:
        model: Object whose ``predict(X_test)`` returns per-class scores with
            one column per label. Column k maps to label k - 1; this assumes
            an (n_samples, 3) array.
        X_test: Features to predict on.
        y_test: True labels in {-1, 0, 1}.

    Returns:
        dict with 'accuracy' (float) and 'per_class'. 'per_class' maps
        'short'/'neutral'/'long' to precision, recall, f1 and support.
        A class is reported if it occurs in y_test OR in the predictions,
        so false positives on a class absent from y_test are not hidden.
    """
    y_pred_proba = model.predict(X_test)
    y_pred = np.argmax(y_pred_proba, axis=1) - 1
    # Positional comparison with the prediction array, ignoring any index.
    y_true = np.asarray(y_test)

    results = {
        'accuracy': float(accuracy_score(y_true, y_pred)),
        'per_class': {}
    }

    for label, name in [(-1, 'short'), (0, 'neutral'), (1, 'long')]:
        is_true = y_true == label
        is_pred = y_pred == label
        support = int(is_true.sum())
        n_predicted = int(is_pred.sum())
        if support == 0 and n_predicted == 0:
            continue

        tp = int((is_true & is_pred).sum())
        precision = tp / n_predicted if n_predicted else 0.0
        recall = tp / support if support else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

        results['per_class'][name] = {
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1),
            'support': support
        }

    return results


def get_feature_importance(model, feature_names):
    """Return per-feature 'gain' importances, highest first.

    Args:
        model: Object exposing ``feature_importance(importance_type=...)``,
            such as the LightGBM booster produced by train_model().
        feature_names: Names in the same order as the model's features.

    Returns:
        DataFrame with 'feature' and 'importance' columns, sorted by
        importance descending; the original row positions are kept as index.
    """
    gains = model.feature_importance(importance_type='gain')
    table = pd.DataFrame({'feature': feature_names, 'importance': gains})
    return table.sort_values('importance', ascending=False)


def save_model(model, features, metrics, version='v9'):
    """Save model and metadata into a new timestamped directory.

    Writes ``model.joblib``, ``features.json`` (the feature column names in
    order) and ``metrics.json`` into MODEL_DIR/model_<version>_<timestamp>.
    It then points the MODEL_DIR/latest symlink (relative target) at that
    directory.

    Returns:
        Path of the new model directory.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    model_path = MODEL_DIR / f'model_{version}_{timestamp}'
    model_path.mkdir(exist_ok=True)

    # Save model
    joblib.dump(model, model_path / 'model.joblib')

    # Save feature names
    with open(model_path / 'features.json', 'w', encoding='utf-8') as f:
        json.dump(list(features.columns), f)

    # Save metrics
    with open(model_path / 'metrics.json', 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=2)

    # Re-point 'latest' atomically: create the new link under a temporary
    # name and rename it over the old one, so readers never see it missing.
    # Renaming also replaces a dangling 'latest' link. An exists()-then-
    # unlink() check would miss that case, because Path.exists() follows
    # symlinks, and symlink_to() would then fail with FileExistsError.
    latest_path = MODEL_DIR / 'latest'
    tmp_link = MODEL_DIR / f'.latest.{timestamp}.tmp'
    if tmp_link.is_symlink() or tmp_link.exists():
        tmp_link.unlink()
    tmp_link.symlink_to(model_path.name)
    tmp_link.replace(latest_path)

    print(f"\nModel saved to {model_path}")
    return model_path


def main():
    """Run the V9 training pipeline end to end.

    The pipeline:
      * loads prices and builds features and labels;
      * trains with time-series cross-validation;
      * evaluates on the most recent 20% of valid samples;
      * prints a report and saves the model with its metadata.

    Returns:
        (model, feature column Index, metrics dict)
    """
    print("=" * 60)
    print("V9 MODEL TRAINING")
    print("=" * 60)
    print()

    # Label horizon in rows (candles); also used to purge the train/test
    # boundary below.
    future_periods = 60

    # Load data
    df = load_data()

    # Create features
    features = create_features(df)

    # Create labels
    labels = create_labels(df, future_periods=future_periods, threshold=0.003)

    # Remove rows with any missing feature or label
    valid_idx = features.notna().all(axis=1) & labels.notna()
    features = features[valid_idx]
    labels = labels[valid_idx]

    print(f"\nValid samples: {len(features):,}")

    # Split into train/test (last 20% for testing). A label at row t is
    # computed from the next `future_periods` closes. The last
    # `future_periods` training rows would therefore be labelled with
    # test-period prices, so they are purged from the training set.
    split_idx = int(len(features) * 0.8)
    train_end = max(split_idx - future_periods, 0)

    X_train = features.iloc[:train_end]
    y_train = labels.iloc[:train_end]
    X_test = features.iloc[split_idx:]
    y_test = labels.iloc[split_idx:]

    print(f"Train: {len(X_train):,} samples")
    print(f"Purged: {split_idx - train_end:,} samples between train and test")
    print(f"Test: {len(X_test):,} samples")

    # Train model
    model, cv_scores = train_model(X_train, y_train)

    # Evaluate on test set
    print("\nEvaluating on test set...")
    metrics = evaluate_model(model, X_test, y_test)

    print("\nTest Results:")
    print(f"  Overall Accuracy: {metrics['accuracy']:.4f}")
    for name, scores in metrics['per_class'].items():
        print(f"  {name.upper()}:")
        print(f"    Precision: {scores['precision']:.4f}")
        print(f"    Recall: {scores['recall']:.4f}")
        print(f"    F1: {scores['f1']:.4f}")
        print(f"    Support: {scores['support']:,}")

    # Feature importance
    print("\nTop 20 Features:")
    importance = get_feature_importance(model, features.columns)
    for _, row in importance.head(20).iterrows():
        print(f"  {row['feature']}: {row['importance']:.2f}")

    # Save model (plain Python numbers keep metrics JSON-serializable)
    metrics['cv_scores'] = [float(s) for s in cv_scores]
    metrics['cv_mean'] = float(np.mean(cv_scores))
    metrics['cv_std'] = float(np.std(cv_scores))
    metrics['train_samples'] = len(X_train)
    metrics['test_samples'] = len(X_test)
    metrics['features'] = len(features.columns)

    save_model(model, features, metrics)

    print("\n" + "=" * 60)
    print("TRAINING COMPLETE")
    print("=" * 60)

    return model, features.columns, metrics


if __name__ == "__main__":
    # The results stay bound at module level, so they can be inspected when
    # the script is run interactively (e.g. `python -i`).
    model, feature_names, metrics = main()
