#!/usr/bin/env python3
"""
Fast Configuration Optimizer
============================
Pre-computes features and model predictions once, then grid-searches trading parameters with fast backtests.
"""

import pandas as pd
import numpy as np
from datetime import datetime
from pathlib import Path
from itertools import product
import json
import warnings
warnings.filterwarnings('ignore')

import joblib

# Configuration: every path lives under the site's web root.
_BASE_DIR = Path('/var/www/html/pippo.cuttalo.com')
DATA_DIR = _BASE_DIR / 'data'                     # input price CSVs
MODEL_DIR = _BASE_DIR / 'models'                  # trained models ('latest' is resolved at load time)
RESULTS_DIR = _BASE_DIR / 'optimization_results'  # CSV results and best_config.json
RESULTS_DIR.mkdir(exist_ok=True)

# Parameter grid searched exhaustively: 5 * 5 * 4 * 5 = 500 combinations.
PARAM_GRID = dict(
    stop_loss=[0.01, 0.015, 0.02, 0.025, 0.03],
    take_profit=[0.02, 0.03, 0.04, 0.05, 0.06],
    position_size=[0.1, 0.2, 0.3, 0.5],
    confidence_threshold=[0.5, 0.55, 0.6, 0.65, 0.7],
)

INITIAL_CAPITAL = 10_000  # starting cash for every backtest
FEE_RATE = 1e-3           # proportional fee charged on entry and on exit notional
SLIPPAGE = 3e-4           # fractional adverse price adjustment applied to each fill


def _parse_timestamps(ts):
    """Convert a raw ``timestamp`` column to ``datetime64`` values.

    Numeric columns (int *or* float) and digit-only strings are treated as
    Unix epochs: milliseconds when the first value exceeds 1e12, seconds
    otherwise. Anything else is handed to ``pd.to_datetime`` as-is.
    """
    # Float epochs (e.g. "1700000000000.0" read back as float64) must take the
    # numeric path too; pd.to_datetime would otherwise read them as nanoseconds.
    if pd.api.types.is_numeric_dtype(ts) or str(ts.iloc[0]).isdigit():
        numeric = pd.to_numeric(ts)
        unit = 'ms' if numeric.iloc[0] > 1e12 else 's'
        return pd.to_datetime(numeric, unit=unit)
    return pd.to_datetime(ts)


def load_data(data_dir=None):
    """Load and merge every BTC/EUR price CSV into one sorted candle frame.

    Each ``prices_BTC_EUR_*.csv`` file must provide ``timestamp``, ``open``,
    ``high``, ``low``, ``close`` and ``volume`` columns; any other columns are
    dropped. Header-only files are skipped.

    Args:
        data_dir: Directory to scan. Defaults to the module-level ``DATA_DIR``.

    Returns:
        DataFrame with the six OHLCV columns, sorted by timestamp, with
        duplicate timestamps removed and a fresh 0..n-1 index.

    Raises:
        FileNotFoundError: If no non-empty price file was found.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    print("Loading price data...")

    dfs = []
    for f in sorted(Path(data_dir).glob('prices_BTC_EUR_*.csv')):
        df = pd.read_csv(f)
        if df.empty:
            # Nothing to contribute, and the first-row probe would fail.
            continue
        df['timestamp'] = _parse_timestamps(df['timestamp'])
        dfs.append(df[['timestamp', 'open', 'high', 'low', 'close', 'volume']])

    if not dfs:
        raise FileNotFoundError(
            f"No price data found in {data_dir} (expected prices_BTC_EUR_*.csv)")

    df = pd.concat(dfs, ignore_index=True)
    # Files may overlap in time; keep one candle per timestamp.
    df = df.sort_values('timestamp').drop_duplicates(subset='timestamp').reset_index(drop=True)
    print(f"Loaded {len(df):,} candles")
    return df


def load_model():
    """Load the trained model together with its ordered feature list.

    ``MODEL_DIR / 'latest'`` is resolved first (following any symlink), and
    both ``model.joblib`` and ``features.json`` are read from the resolved
    directory. Note that joblib unpickles the model file, so it must come
    from a trusted source.

    Returns:
        Tuple ``(model, feature_names)``.
    """
    resolved_dir = (MODEL_DIR / 'latest').resolve()
    model = joblib.load(resolved_dir / 'model.joblib')

    features_file = resolved_dir / 'features.json'
    with features_file.open() as handle:
        feature_names = json.load(handle)

    print(f"Loaded model from {resolved_dir}")
    return model, feature_names


def create_features_vectorized(df: pd.DataFrame) -> pd.DataFrame:
    """Create all features in a vectorized manner.

    Builds the technical-indicator feature set from OHLCV candles using
    pandas rolling/ewm operations (no per-row Python loop).

    Args:
        df: Candle frame with ``open``, ``high``, ``low``, ``close``,
            ``volume`` and a datetime ``timestamp`` column (``.dt`` accessors
            are used on it).

    Returns:
        DataFrame on ``df.index`` with one column per feature. Leading rows
        are NaN while the rolling windows (up to 200 bars) fill up, and some
        ratio features can be +/-inf when a denominator is zero.

    NOTE(review): the model consumes a subset of these columns by name (see
    ``get_predictions``); the formulas below presumably mirror the training
    pipeline, so changing any of them without retraining would silently shift
    the model's inputs -- confirm before "fixing" anything here.
    """
    print("Creating features...")
    close = df['close']
    high = df['high']
    low = df['low']
    volume = df['volume']
    open_ = df['open']

    # Columns are inserted one by one; insertion order is the column order.
    features = pd.DataFrame(index=df.index)

    # Lookback periods
    LOOKBACK_PERIODS = [5, 10, 20, 50, 100, 200]

    # Price returns at different timeframes.
    # volatility_* is the rolling std of 1-bar returns; momentum_* uses the same
    # ratio as return_* written out explicitly (they may differ only in how
    # pct_change treats missing closes).
    for period in LOOKBACK_PERIODS:
        features[f'return_{period}'] = close.pct_change(period)
        features[f'volatility_{period}'] = close.pct_change().rolling(period).std()
        features[f'momentum_{period}'] = close / close.shift(period) - 1

    # Moving averages: relative distance of the close from each SMA
    for period in [10, 20, 50, 100, 200]:
        sma = close.rolling(period).mean()
        features[f'price_vs_sma_{period}'] = close / sma - 1

    # MA Crossovers: 1 while the faster SMA is above the slower one
    # (a regime flag, not a one-bar crossing event)
    sma_10 = close.rolling(10).mean()
    sma_20 = close.rolling(20).mean()
    sma_50 = close.rolling(50).mean()
    sma_100 = close.rolling(100).mean()
    sma_200 = close.rolling(200).mean()

    features['sma_10_50_cross'] = (sma_10 > sma_50).astype(int)
    features['sma_20_100_cross'] = (sma_20 > sma_100).astype(int)
    features['sma_50_200_cross'] = (sma_50 > sma_200).astype(int)

    # RSI over 14 bars using a simple rolling mean of gains/losses (not Wilder's
    # exponential smoothing). loss == 0 makes rs inf -> RSI 100; gain == loss == 0
    # gives NaN, for which both threshold flags below evaluate to 0.
    delta = close.diff()
    gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss
    features['rsi'] = 100 - (100 / (1 + rs))
    features['rsi_oversold'] = (features['rsi'] < 30).astype(int)
    features['rsi_overbought'] = (features['rsi'] > 70).astype(int)

    # MACD (12/26 EMAs, 9-span signal, adjust=False). The cross flags are 1 only
    # on the bar where MACD moves from one side of its signal line to the other.
    exp1 = close.ewm(span=12, adjust=False).mean()
    exp2 = close.ewm(span=26, adjust=False).mean()
    macd = exp1 - exp2
    macd_signal = macd.ewm(span=9, adjust=False).mean()
    macd_hist = macd - macd_signal
    features['macd'] = macd
    features['macd_signal'] = macd_signal
    features['macd_hist'] = macd_hist
    features['macd_cross_up'] = ((macd > macd_signal) & (macd.shift(1) <= macd_signal.shift(1))).astype(int)
    features['macd_cross_down'] = ((macd < macd_signal) & (macd.shift(1) >= macd_signal.shift(1))).astype(int)

    # Bollinger Bands (20 bars, +/-2 std). The +0.0001 keeps bb_position finite
    # when the bands collapse to a single price.
    bb_mid = close.rolling(20).mean()
    bb_std = close.rolling(20).std()
    bb_upper = bb_mid + (bb_std * 2)
    bb_lower = bb_mid - (bb_std * 2)
    features['bb_position'] = (close - bb_lower) / (bb_upper - bb_lower + 0.0001)
    features['bb_width'] = (bb_upper - bb_lower) / bb_mid
    features['price_vs_bb_upper'] = close / bb_upper - 1
    features['price_vs_bb_lower'] = close / bb_lower - 1

    # ATR: simple 14-bar mean of the true range (max of the three candidates)
    tr1 = high - low
    tr2 = abs(high - close.shift(1))
    tr3 = abs(low - close.shift(1))
    tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
    atr = tr.rolling(14).mean()
    features['atr'] = atr
    features['atr_pct'] = atr / close

    # Volume features. Both can be +/-inf when the rolling mean or the previous
    # bar's volume is zero.
    vol_sma = volume.rolling(20).mean()
    features['volume_ratio'] = volume / vol_sma
    features['volume_change'] = volume.pct_change()

    # Price patterns. The +0.0001 keeps zero-range candles (high == low) finite.
    features['higher_high'] = (high > high.shift(1)).astype(int)
    features['lower_low'] = (low < low.shift(1)).astype(int)
    features['body_size'] = abs(close - open_) / (high - low + 0.0001)
    features['upper_wick'] = (high - pd.concat([close, open_], axis=1).max(axis=1)) / (high - low + 0.0001)
    features['lower_wick'] = (pd.concat([close, open_], axis=1).min(axis=1) - low) / (high - low + 0.0001)

    # ADX-style trend strength.
    # NOTE(review): +DM/-DM are only clamped at zero here; they are not compared
    # against each other as in Wilder's definition, and DX is smoothed with a
    # simple rolling mean. Keep as-is unless the model is retrained.
    plus_dm = high.diff()
    minus_dm = low.diff()
    plus_dm[plus_dm < 0] = 0
    minus_dm[minus_dm > 0] = 0
    tr_sum = tr.rolling(14).sum()
    plus_di = 100 * (plus_dm.rolling(14).sum() / tr_sum)
    minus_di = 100 * (abs(minus_dm.rolling(14).sum()) / tr_sum)
    dx = 100 * abs(plus_di - minus_di) / (plus_di + minus_di + 0.0001)
    features['adx'] = dx.rolling(14).mean()

    # Time features, in whatever timezone the timestamps carry (load_data's
    # epoch parsing yields naive values -- presumably UTC; confirm).
    features['hour'] = df['timestamp'].dt.hour
    features['day_of_week'] = df['timestamp'].dt.dayofweek
    features['is_weekend'] = (features['day_of_week'] >= 5).astype(int)

    print(f"Created {len(features.columns)} features")
    return features


def get_predictions(model, features, feature_names):
    """Score every row of ``features`` with the model.

    Args:
        model: Fitted classifier. If it exposes ``predict_proba``
            (scikit-learn style) that is used; otherwise ``predict`` is called
            and must already return per-class probabilities.
        features: Feature frame, e.g. from ``create_features_vectorized``.
        feature_names: Ordered column names the model was trained on.

    Returns:
        ``(predictions, valid_mask)``:

        - ``predictions`` is an ``(n, 3)`` float array of class probabilities;
          the backtest reads column 0 as "short" and column 2 as "long".
        - ``valid_mask`` is a boolean array marking rows whose selected
          features are all finite.

        Rows outside the mask stay all-zero.
    """
    print("Generating predictions...")

    # Explicit column selection keeps the training-time feature order.
    X = features[feature_names].to_numpy(dtype=float)

    # Exclude NaN rows (rolling warm-up) and +/-inf rows (e.g. zero-volume
    # ratios); most estimators reject or mis-score non-finite input.
    valid_mask = np.isfinite(X).all(axis=1)
    predictions = np.zeros((len(X), 3))

    if valid_mask.any():
        # scikit-learn style `predict` returns 1-D labels, which cannot fill
        # the (n, 3) probability array; prefer predict_proba when available.
        predict = getattr(model, 'predict_proba', None) or model.predict
        predictions[valid_mask] = predict(X[valid_mask])

    return predictions, valid_mask


def _close_position(position, position_amount, entry_price, price):
    """Simulate closing the open position at ``price``.

    Slippage works against the trader (longs sell lower, shorts buy back
    higher), and ``FEE_RATE`` is charged on the exit notional.

    Returns:
        ``(exec_price, net_pnl, proceeds)``. ``proceeds`` is the cash returned
        to capital: the entry notional plus the net PnL.
    """
    if position == 1:
        exec_price = price * (1 - SLIPPAGE)
        gross_pnl = position_amount * (exec_price - entry_price)
    else:
        exec_price = price * (1 + SLIPPAGE)
        gross_pnl = position_amount * (entry_price - exec_price)

    fee = position_amount * exec_price * FEE_RATE
    net_pnl = gross_pnl - fee
    return exec_price, net_pnl, position_amount * entry_price + net_pnl


def fast_backtest(df, predictions, valid_mask, params, warmup=500):
    """Backtest one parameter set over precomputed model predictions.

    This is a plain bar-by-bar loop. It is fast only because features and
    predictions are computed once outside it. At most one position is open
    at a time. On each bar:

    1. An open position is marked to market at the close.
    2. It is exited if its return at the close reaches ``-stop_loss`` or
       ``take_profit``. Intrabar highs/lows are not checked.
    3. When flat and ``valid_mask`` allows it, a long (short) is opened if
       column 2 (column 0) of ``predictions`` exceeds
       ``confidence_threshold`` and the opposite column.

    Entries commit ``position_size`` of current cash, pay ``FEE_RATE`` out of
    that notional, and fill with ``SLIPPAGE`` against the trader. A position
    still open after the last bar is closed at the final close.

    Args:
        df: Frame with a ``close`` column aligned row-for-row with
            ``predictions``.
        predictions: ``(n, 3)`` probability array (0 = short, 2 = long).
        valid_mask: Boolean array; rows where it is False are never entered.
        params: Dict with ``stop_loss``, ``take_profit``, ``position_size``
            and ``confidence_threshold``. Extra keys are ignored.
        warmup: Number of leading bars skipped before trading starts.

    Returns:
        The metrics dict produced by ``calculate_metrics``.
    """
    close = df['close'].values
    n = len(close)

    conf_threshold = params['confidence_threshold']
    stop_loss = params['stop_loss']
    take_profit = params['take_profit']
    position_size = params['position_size']

    capital = INITIAL_CAPITAL  # free cash, excluding notional locked in a position
    position = 0               # +1 long, -1 short, 0 flat
    position_amount = 0        # asset units held
    entry_price = 0            # fill price including entry slippage
    entry_idx = 0

    trades = []
    equity = np.zeros(n)  # bars before `warmup` stay 0; calculate_metrics ignores them

    for i in range(warmup, n):
        price = close[i]

        if position != 0:
            # Mark to market at the close; the same move drives the exit rules.
            if position == 1:
                move = price - entry_price
            else:
                move = entry_price - price
            equity[i] = capital + position_amount * entry_price + position_amount * move

            pnl_pct = move / entry_price
            if pnl_pct <= -stop_loss:
                exit_reason = 'stop_loss'
            elif pnl_pct >= take_profit:
                exit_reason = 'take_profit'
            else:
                exit_reason = None

            if exit_reason is not None:
                exec_price, net_pnl, proceeds = _close_position(
                    position, position_amount, entry_price, price)
                capital += proceeds
                trades.append({
                    'entry_price': entry_price,
                    'exit_price': exec_price,
                    'net_pnl': net_pnl,
                    'pnl_pct': net_pnl / (position_amount * entry_price),
                    'reason': exit_reason,
                    'duration': i - entry_idx,
                })
                position = 0
                position_amount = 0
                # Record the realised value (after exit slippage and fee) on
                # the exit bar itself instead of deferring it to the next bar.
                equity[i] = capital
            # No re-entry on the bar that closed a position, and nothing to
            # enter while still holding one.
            continue

        equity[i] = capital

        if not valid_mask[i]:
            continue

        long_prob = predictions[i][2]
        short_prob = predictions[i][0]
        if long_prob > conf_threshold and long_prob > short_prob:
            signal = 1
        elif short_prob > conf_threshold and short_prob > long_prob:
            signal = -1
        else:
            continue

        trade_value = capital * position_size
        fee = trade_value * FEE_RATE
        if signal == 1:
            exec_price = price * (1 + SLIPPAGE)
        else:
            exec_price = price * (1 - SLIPPAGE)

        # The entry fee is taken out of the committed notional.
        position_amount = (trade_value - fee) / exec_price
        capital -= trade_value
        position = signal
        entry_price = exec_price
        entry_idx = i

    # Close any position still open at the final close.
    if position != 0:
        exec_price, net_pnl, proceeds = _close_position(
            position, position_amount, entry_price, close[-1])
        capital += proceeds
        trades.append({
            'entry_price': entry_price,
            'exit_price': exec_price,
            'net_pnl': net_pnl,
            'pnl_pct': net_pnl / (position_amount * entry_price),
            'reason': 'end_of_test',
            # Closed on the last bar (index n - 1), consistent with in-loop exits.
            'duration': n - 1 - entry_idx,
        })

    if n:
        equity[-1] = capital

    return calculate_metrics(trades, equity)


def calculate_metrics(trades, equity, initial_capital=None, periods_per_year=525600):
    """Summarise a backtest's trades and equity curve.

    Args:
        trades: List of trade dicts; only the ``net_pnl`` key is read.
        equity: NumPy array of per-bar equity. Non-positive entries (such as
            zero-filled warm-up bars) are ignored for Sharpe and drawdown.
        initial_capital: Starting capital. Defaults to ``INITIAL_CAPITAL``.
        periods_per_year: Annualisation factor for the Sharpe ratio. The
            default, 525,600 (minutes per year), assumes 1-minute bars.

    Returns:
        Dict with ``total_return_pct``, ``final_equity``, ``num_trades``,
        ``win_rate``, ``sharpe_ratio``, ``max_drawdown``, ``profit_factor``,
        ``avg_trade_pnl``, ``avg_win``, ``avg_loss``, ``wins`` and ``losses``.
        The same keys, in the same order, are returned when there are no
        trades, so all result rows share one schema.
    """
    if initial_capital is None:
        initial_capital = INITIAL_CAPITAL

    if not trades:
        # No trades means capital never changed.
        return {
            'total_return_pct': 0,
            'final_equity': initial_capital,
            'num_trades': 0,
            'win_rate': 0,
            'sharpe_ratio': 0,
            'max_drawdown': 0,
            'profit_factor': 0,
            'avg_trade_pnl': 0,
            'avg_win': 0,
            'avg_loss': 0,
            'wins': 0,
            'losses': 0,
        }

    final_equity = equity[-1]
    total_return_pct = (final_equity - initial_capital) / initial_capital

    pnls = [t['net_pnl'] for t in trades]
    win_pnls = [p for p in pnls if p > 0]
    loss_pnls = [p for p in pnls if p <= 0]  # break-even trades count as losses
    win_rate = len(win_pnls) / len(pnls)

    # Drop non-positive entries (warm-up bars are left at 0 by the backtest).
    valid_equity = equity[equity > 0]

    # Per-bar Sharpe (no risk-free rate), annualised.
    sharpe = 0
    if len(valid_equity) > 1:
        returns = np.diff(valid_equity) / valid_equity[:-1]
        if returns.std() > 0:
            sharpe = returns.mean() / returns.std() * np.sqrt(periods_per_year)

    if len(valid_equity) > 0:
        running_peak = np.maximum.accumulate(valid_equity)
        max_drawdown = abs(((valid_equity - running_peak) / running_peak).min())
    else:
        max_drawdown = 0

    gross_profit = sum(win_pnls) if win_pnls else 0
    # NOTE(review): with no losing trades (or losses summing to exactly 0) the
    # profit factor falls back to gross_profit, which is in currency units
    # rather than a ratio. Kept as-is because main() clips it before scoring
    # and writes it to best_config.json (where inf would emit non-standard
    # JSON); confirm the intended value.
    gross_loss = abs(sum(loss_pnls)) if loss_pnls else 1
    profit_factor = gross_profit / gross_loss if gross_loss > 0 else gross_profit

    return {
        'total_return_pct': total_return_pct,
        'final_equity': final_equity,
        'num_trades': len(pnls),
        'win_rate': win_rate,
        'sharpe_ratio': sharpe,
        'max_drawdown': max_drawdown,
        'profit_factor': profit_factor,
        'avg_trade_pnl': np.mean(pnls),
        'avg_win': np.mean(win_pnls) if win_pnls else 0,
        'avg_loss': np.mean(loss_pnls) if loss_pnls else 0,
        'wins': len(win_pnls),
        'losses': len(loss_pnls),
    }


def main():
    """Run the full grid search and persist the results.

    Steps:

    1. Load candles and the model.
    2. Trim the candles to the most recent sample.
    3. Compute features and predictions once.
    4. Backtest every ``PARAM_GRID`` combination and rank the results by a
       weighted score.
    5. Print the top 10, write all results to a timestamped CSV, and write
       the winner to ``best_config.json``.

    Returns:
        The best configuration dict (also written to ``best_config.json``).
    """
    banner = "=" * 60
    print(banner)
    print("FAST CONFIGURATION OPTIMIZATION")
    print(banner)

    df = load_data()
    model, feature_names = load_model()

    # Keep only the most recent candles so the search stays quick
    # (300k candles is about 200 days).
    sample_size = 300000
    if len(df) > sample_size:
        df = df.iloc[-sample_size:].reset_index(drop=True)
        print(f"Using last {sample_size:,} candles")

    # Features and predictions do not depend on the trading parameters, so
    # they are computed once and shared by every backtest below.
    features = create_features_vectorized(df)
    predictions, valid_mask = get_predictions(model, features, feature_names)

    param_names = list(PARAM_GRID)
    combinations = list(product(*PARAM_GRID.values()))
    total = len(combinations)
    print(f"\nTesting {total} configurations...")

    results = []
    for idx, combo in enumerate(combinations, start=1):
        params = {**dict(zip(param_names, combo)), 'fee_rate': FEE_RATE}
        metrics = fast_backtest(df, predictions, valid_mask, params)
        results.append({**params, **metrics})

        # Progress line for the first run and every 50th afterwards.
        if idx == 1 or idx % 50 == 0:
            print(f"[{idx}/{total}] SL={params['stop_loss']:.1%}, "
                  f"TP={params['take_profit']:.1%}, Size={params['position_size']:.0%}, "
                  f"Conf={params['confidence_threshold']:.0%} -> "
                  f"Return: {metrics['total_return_pct']*100:+.1f}%, "
                  f"Trades: {metrics['num_trades']}, WR: {metrics['win_rate']*100:.0f}%")

    df_results = pd.DataFrame(results)

    # Weighted blend of return, Sharpe, win rate and profit factor, minus a
    # drawdown penalty. Sharpe and profit factor are clipped so outliers
    # cannot dominate the ranking.
    positive_terms = (
        df_results['total_return_pct'] * 0.3
        + df_results['sharpe_ratio'].clip(lower=-2, upper=2) * 0.25
        + df_results['win_rate'] * 0.2
        + df_results['profit_factor'].clip(upper=5) * 0.15
    )
    df_results['score'] = positive_terms - df_results['max_drawdown'] * 0.1
    df_results = df_results.sort_values('score', ascending=False)

    print("\n" + banner)
    print("TOP 10 CONFIGURATIONS")
    print(banner)

    for rank, (_, row) in enumerate(df_results.head(10).iterrows(), start=1):
        print(f"\n#{rank} Score: {row['score']:.4f}")
        print(f"  SL: {row['stop_loss']:.1%} | TP: {row['take_profit']:.1%} | "
              f"Size: {row['position_size']:.0%} | Conf: {row['confidence_threshold']:.0%}")
        print(f"  Return: {row['total_return_pct']*100:+.2f}% | Trades: {row['num_trades']:.0f} | "
              f"Win Rate: {row['win_rate']*100:.1f}%")
        print(f"  Sharpe: {row['sharpe_ratio']:.2f} | Max DD: {row['max_drawdown']*100:.1f}% | "
              f"PF: {row['profit_factor']:.2f}")

    run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    results_file = RESULTS_DIR / f'optimization_{run_stamp}.csv'
    df_results.to_csv(results_file, index=False)

    # The highest-scoring row becomes the persisted configuration.
    best = df_results.iloc[0]
    best_config = {
        name: float(best[name])
        for name in ('stop_loss', 'take_profit', 'position_size', 'confidence_threshold')
    }
    best_config['fee_rate'] = FEE_RATE
    best_config['metrics'] = {
        **{
            name: float(best[name])
            for name in ('total_return_pct', 'win_rate', 'sharpe_ratio',
                         'max_drawdown', 'profit_factor')
        },
        'num_trades': int(best['num_trades']),
    }

    config_path = RESULTS_DIR / 'best_config.json'
    with open(config_path, 'w') as f:
        json.dump(best_config, f, indent=2)

    print(f"\nResults saved to {results_file}")
    print(f"Best config saved to {config_path}")

    print("\n" + banner)
    print("BEST CONFIGURATION")
    print(banner)
    print(json.dumps(best_config, indent=2))

    return best_config


if __name__ == '__main__':
    # Run the grid search only when executed as a script, not on import.
    main()
