#!/usr/bin/env python3
"""
Full Backtest - Test the best configurations on all data
========================================================
"""

import pandas as pd
import numpy as np
from datetime import datetime
from pathlib import Path
import json
import joblib

# Filesystem locations used by this script.
DATA_DIR = Path('/var/www/html/pippo.cuttalo.com/data')  # prices_BTC_EUR_*.csv inputs
MODEL_DIR = Path('/var/www/html/pippo.cuttalo.com/models')  # holds a 'latest' model entry
RESULTS_DIR = Path('/var/www/html/pippo.cuttalo.com/optimization_results')  # JSON output

# Best configurations to test; main() runs them in this order.
#   stop_loss / take_profit: exit once the move from entry reaches
#       -stop_loss / +take_profit (fractions of the entry price)
#   position_size: fraction of free capital committed per trade
#   confidence_threshold: minimum model probability needed to open a trade
CONFIGS = [
    dict(name='V9 Balanced', stop_loss=0.02, take_profit=0.04,
         position_size=0.5, confidence_threshold=0.6),
    dict(name='V9 Conservative', stop_loss=0.015, take_profit=0.03,
         position_size=0.2, confidence_threshold=0.65),
    dict(name='V9 Aggressive', stop_loss=0.025, take_profit=0.05,
         position_size=0.5, confidence_threshold=0.55),
]

INITIAL_CAPITAL = 10_000  # starting capital (printed in EUR by main)
FEE_RATE = 1e-3           # fee as a fraction of traded notional, charged per side
SLIPPAGE = 3e-4           # adverse fractional price adjustment applied to every fill


def _parse_timestamps(raw):
    """Convert a raw ``timestamp`` column to datetime64.

    Numeric columns (int OR float), or text whose first value is all digits,
    are treated as Unix epoch times: milliseconds when the first value exceeds
    1e12, seconds otherwise. Anything else is parsed by ``pd.to_datetime``.
    """
    # Float columns (e.g. "1700000000.0", or an int column with blanks) must be
    # handled here too: plain pd.to_datetime reads bare numbers as nanoseconds.
    if pd.api.types.is_numeric_dtype(raw) or str(raw.iloc[0]).isdigit():
        ts = pd.to_numeric(raw)
        unit = 'ms' if ts.iloc[0] > 1e12 else 's'
        return pd.to_datetime(ts, unit=unit)
    return pd.to_datetime(raw)


def load_data(data_dir=None, pattern='prices_BTC_EUR_*.csv'):
    """Load every matching price CSV and merge them into one time-sorted frame.

    Files are read in sorted filename order. Each must have ``timestamp``,
    ``open``, ``high``, ``low``, ``close`` and ``volume`` columns; only those
    columns are kept. Header-only files are skipped.

    Args:
        data_dir: Directory to search; defaults to ``DATA_DIR``.
        pattern: Glob pattern selecting the CSV files.

    Returns:
        DataFrame sorted by ``timestamp`` with a fresh RangeIndex. When the
        same timestamp occurs more than once, the first occurrence in file
        order (earliest filename first) is kept.

    Raises:
        FileNotFoundError: if no file matches ``pattern``.
        ValueError: if every matching file is empty.
    """
    data_dir = DATA_DIR if data_dir is None else Path(data_dir)
    files = sorted(data_dir.glob(pattern))
    if not files:
        raise FileNotFoundError(f"no files matching {pattern!r} in {data_dir}")

    frames = []
    for path in files:
        df = pd.read_csv(path)
        if df.empty:
            # Header-only file: nothing to contribute (and iloc[0] would fail).
            continue
        df['timestamp'] = _parse_timestamps(df['timestamp'])
        frames.append(df[['timestamp', 'open', 'high', 'low', 'close', 'volume']])

    if not frames:
        raise ValueError(f"all files matching {pattern!r} in {data_dir} are empty")

    df = pd.concat(frames, ignore_index=True)
    # mergesort is stable, so drop_duplicates deterministically keeps the row
    # that came first in file order rather than an arbitrary one.
    df = df.sort_values('timestamp', kind='mergesort')
    return df.drop_duplicates(subset='timestamp').reset_index(drop=True)


def load_model():
    """Load the trained model and its ordered feature-name list.

    ``MODEL_DIR / 'latest'`` is resolved first (presumably a symlink to a
    versioned model directory — confirm), and ``model.joblib`` plus
    ``features.json`` are read from the resolved directory.

    Returns:
        ``(model, feature_names)`` — the unpickled model object and the parsed
        contents of ``features.json`` (used by ``main`` to select feature
        columns by name).
    """
    model_dir = (MODEL_DIR / 'latest').resolve()
    feature_file = model_dir / 'features.json'

    # joblib.load unpickles arbitrary objects: only point this at trusted files.
    model = joblib.load(model_dir / 'model.joblib')
    with feature_file.open() as fh:
        feature_names = json.load(fh)

    return model, feature_names


def create_features(df):
    """Build the model's feature matrix from OHLCV candles.

    Computes return/volatility/momentum, moving-average, RSI, MACD, Bollinger
    Band, ATR, volume, candle-shape, ADX-style and calendar features, one
    column per feature, on the same index as ``df``.

    NOTE(review): ``main`` selects columns from this frame by the names listed
    in the model's features.json, so these formulas presumably must match the
    code that produced the training features exactly — confirm against the
    training pipeline before changing any of them (including the non-standard
    RSI/ADX variants noted below).

    Args:
        df: DataFrame with ``open``, ``high``, ``low``, ``close`` and
            ``volume`` columns and a datetime-typed ``timestamp`` column
            (``.dt`` is used). Rows are assumed to be in chronological order,
            because every rolling/shift/diff below is positional.

    Returns:
        DataFrame indexed like ``df``. Continuous features are NaN until their
        window has enough history (on the order of 200 bars for the longest
        lookbacks); boolean flag features are 0 rather than NaN there, since
        comparisons against NaN evaluate to False.
    """
    close = df['close']
    high = df['high']
    low = df['low']
    volume = df['volume']
    open_ = df['open']

    features = pd.DataFrame(index=df.index)

    LOOKBACK_PERIODS = [5, 10, 20, 50, 100, 200]

    # momentum_{period} is the same quantity as return_{period} when close has
    # no NaNs; presumably both are kept because the trained model expects both.
    for period in LOOKBACK_PERIODS:
        features[f'return_{period}'] = close.pct_change(period)
        features[f'volatility_{period}'] = close.pct_change().rolling(period).std()
        features[f'momentum_{period}'] = close / close.shift(period) - 1

    # Relative distance of price from its simple moving averages.
    for period in [10, 20, 50, 100, 200]:
        sma = close.rolling(period).mean()
        features[f'price_vs_sma_{period}'] = close / sma - 1

    sma_10 = close.rolling(10).mean()
    sma_20 = close.rolling(20).mean()
    sma_50 = close.rolling(50).mean()
    sma_100 = close.rolling(100).mean()
    sma_200 = close.rolling(200).mean()

    # 1 while the faster SMA is above the slower one (a state, not a crossing event).
    features['sma_10_50_cross'] = (sma_10 > sma_50).astype(int)
    features['sma_20_100_cross'] = (sma_20 > sma_100).astype(int)
    features['sma_50_200_cross'] = (sma_50 > sma_200).astype(int)

    # 14-bar RSI using simple rolling means of gains/losses (not Wilder's
    # smoothing). A window with no losses gives rs=inf -> rsi 100; a window
    # with neither gains nor losses gives NaN.
    delta = close.diff()
    gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss
    features['rsi'] = 100 - (100 / (1 + rs))
    features['rsi_oversold'] = (features['rsi'] < 30).astype(int)
    features['rsi_overbought'] = (features['rsi'] > 70).astype(int)

    # MACD (12/26 EMA difference) with a 9-span EMA signal line; the cross
    # flags are 1 only on the bar where the crossing happens.
    exp1 = close.ewm(span=12, adjust=False).mean()
    exp2 = close.ewm(span=26, adjust=False).mean()
    macd = exp1 - exp2
    macd_signal = macd.ewm(span=9, adjust=False).mean()
    features['macd'] = macd
    features['macd_signal'] = macd_signal
    features['macd_hist'] = macd - macd_signal
    features['macd_cross_up'] = ((macd > macd_signal) & (macd.shift(1) <= macd_signal.shift(1))).astype(int)
    features['macd_cross_down'] = ((macd < macd_signal) & (macd.shift(1) >= macd_signal.shift(1))).astype(int)

    # Bollinger Bands: 20-bar SMA +/- 2 rolling standard deviations. The
    # 0.0001 term keeps bb_position finite when the bands collapse.
    bb_mid = close.rolling(20).mean()
    bb_std = close.rolling(20).std()
    bb_upper = bb_mid + (bb_std * 2)
    bb_lower = bb_mid - (bb_std * 2)
    features['bb_position'] = (close - bb_lower) / (bb_upper - bb_lower + 0.0001)
    features['bb_width'] = (bb_upper - bb_lower) / bb_mid
    features['price_vs_bb_upper'] = close / bb_upper - 1
    features['price_vs_bb_lower'] = close / bb_lower - 1

    # True range and its 14-bar simple average (ATR), absolute and relative to price.
    tr1 = high - low
    tr2 = abs(high - close.shift(1))
    tr3 = abs(low - close.shift(1))
    tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
    atr = tr.rolling(14).mean()
    features['atr'] = atr
    features['atr_pct'] = atr / close

    # volume_change can be inf (or NaN) after a bar whose volume was 0.
    vol_sma = volume.rolling(20).mean()
    features['volume_ratio'] = volume / vol_sma
    features['volume_change'] = volume.pct_change()

    # Candle-shape features; 0.0001 avoids division by zero on flat candles.
    features['higher_high'] = (high > high.shift(1)).astype(int)
    features['lower_low'] = (low < low.shift(1)).astype(int)
    features['body_size'] = abs(close - open_) / (high - low + 0.0001)
    features['upper_wick'] = (high - pd.concat([close, open_], axis=1).max(axis=1)) / (high - low + 0.0001)
    features['lower_wick'] = (pd.concat([close, open_], axis=1).min(axis=1) - low) / (high - low + 0.0001)

    # Simplified ADX: +DM/-DM are only clipped at zero (not compared against
    # each other as in Wilder's definition) and all smoothing uses simple
    # 14-bar rolling sums/means.
    plus_dm = high.diff()
    minus_dm = low.diff()
    plus_dm[plus_dm < 0] = 0
    minus_dm[minus_dm > 0] = 0
    tr_sum = tr.rolling(14).sum()
    plus_di = 100 * (plus_dm.rolling(14).sum() / tr_sum)
    minus_di = 100 * (abs(minus_dm.rolling(14).sum()) / tr_sum)
    dx = 100 * abs(plus_di - minus_di) / (plus_di + minus_di + 0.0001)
    features['adx'] = dx.rolling(14).mean()

    # Calendar features, in whatever timezone df['timestamp'] carries
    # (NOTE(review): presumably UTC — confirm against the data source).
    features['hour'] = df['timestamp'].dt.hour
    features['day_of_week'] = df['timestamp'].dt.dayofweek
    features['is_weekend'] = (features['day_of_week'] >= 5).astype(int)

    return features


def _close_position(position, position_amount, entry_price, price):
    """Simulate exiting an open position at ``price``.

    Applies adverse slippage to the exit price and charges ``FEE_RATE`` on the
    exit notional.

    Args:
        position: 1 for long, -1 for short.
        position_amount: Units held.
        entry_price: Execution price at entry (slippage already included).
        price: Reference (close) price at which the exit happens.

    Returns:
        ``(released, net_pnl)``: ``released`` is the amount to add back to
        free capital (entry notional plus net PnL); ``net_pnl`` is the trade's
        profit after exit slippage and exit fee.
    """
    if position == 1:
        exec_price = price * (1 - SLIPPAGE)
        gross_pnl = position_amount * (exec_price - entry_price)
    else:
        exec_price = price * (1 + SLIPPAGE)
        gross_pnl = position_amount * (entry_price - exec_price)

    fee = position_amount * exec_price * FEE_RATE
    net_pnl = gross_pnl - fee
    return position_amount * entry_price + net_pnl, net_pnl


def backtest(df, predictions, valid_mask, config, warmup=500):
    """Simulate one strategy configuration bar by bar over ``df``.

    At most one position is open at a time. While flat, a long (short) is
    opened at the bar's close when the long (short) probability exceeds
    ``config['confidence_threshold']`` and beats the opposite side. An open
    position is closed when its close-to-entry move reaches
    ``-config['stop_loss']`` or ``+config['take_profit']``; no new entry is
    taken on the exit bar. A position still open after the last bar is
    closed at the final close.

    Args:
        df: Candles with ``close`` and datetime64 ``timestamp`` columns, in
            chronological order.
        predictions: Per-bar probabilities; ``predictions[i][0]`` is read as
            the short probability and ``predictions[i][2]`` as the long one
            (presumably classes ordered short/neutral/long — confirm against
            the training labels).
        valid_mask: Boolean per-bar array; entries are only considered where True.
        config: Dict with ``stop_loss``, ``take_profit``, ``position_size``
            (fraction of free capital per trade) and ``confidence_threshold``.
        warmup: Number of leading bars skipped before equity is tracked and
            trading starts.

    Returns:
        The dict built by ``calculate_metrics``, or None when no trade was
        made (including when ``df`` is empty).
    """
    close = df['close'].values
    n = len(close)
    if n == 0:
        return None

    # Month labels are formatted once, vectorised, instead of via a pandas
    # .iloc lookup plus strftime on every bar inside the loop.
    months = df['timestamp'].dt.strftime('%Y-%m').values

    capital = INITIAL_CAPITAL  # free (uncommitted) capital
    position = 0               # 1 = long, -1 = short, 0 = flat
    position_amount = 0        # units held in the open position
    entry_price = 0            # entry execution price, slippage included

    trades = []
    # Bars before `warmup` stay 0; calculate_metrics ignores non-positive equity.
    equity = np.zeros(n)
    monthly_returns = {}
    prev_equity = None

    for i in range(warmup, n):
        price = close[i]
        month = months[i]

        # Mark to market at the close: free capital + committed notional + unrealized PnL.
        if position != 0:
            if position == 1:
                unrealized = position_amount * (price - entry_price)
            else:
                unrealized = position_amount * (entry_price - price)
            equity[i] = capital + position_amount * entry_price + unrealized
        else:
            equity[i] = capital

        if month not in monthly_returns:
            # Anchor the month at the previous bar's equity (the prior month's
            # last value) so the move across the boundary is counted and the
            # monthly returns chain together into the total return.
            start = equity[i] if prev_equity is None else prev_equity
            monthly_returns[month] = {'start': start, 'end': equity[i]}
        monthly_returns[month]['end'] = equity[i]
        prev_equity = equity[i]

        if position != 0:
            if position == 1:
                pnl_pct = (price - entry_price) / entry_price
            else:
                pnl_pct = (entry_price - price) / entry_price

            if pnl_pct <= -config['stop_loss']:
                exit_reason = 'stop_loss'
            elif pnl_pct >= config['take_profit']:
                exit_reason = 'take_profit'
            else:
                exit_reason = None

            if exit_reason:
                released, net_pnl = _close_position(position, position_amount, entry_price, price)
                trades.append({
                    'net_pnl': net_pnl,
                    'pnl_pct': net_pnl / (position_amount * entry_price),
                    'reason': exit_reason,
                })
                capital += released
                position = 0
                position_amount = 0
                # No re-entry on the bar that closed a trade.
                continue

        if position == 0 and valid_mask[i]:
            proba = predictions[i]
            long_prob = proba[2]
            short_prob = proba[0]

            signal = 0
            if long_prob > config['confidence_threshold'] and long_prob > short_prob:
                signal = 1
            elif short_prob > config['confidence_threshold'] and short_prob > long_prob:
                signal = -1

            if signal != 0:
                trade_value = capital * config['position_size']
                fee = trade_value * FEE_RATE
                # Slippage always works against the trade: buy higher, sell lower.
                if signal == 1:
                    exec_price = price * (1 + SLIPPAGE)
                else:
                    exec_price = price * (1 - SLIPPAGE)

                # The entry fee is taken out of the committed notional.
                position_amount = (trade_value - fee) / exec_price
                capital -= trade_value
                position = signal
                entry_price = exec_price

    # Liquidate anything still open at the final close.
    if position != 0:
        released, net_pnl = _close_position(position, position_amount, entry_price, close[-1])
        trades.append({
            'net_pnl': net_pnl,
            'pnl_pct': net_pnl / (position_amount * entry_price),
            'reason': 'end_of_test',
        })
        capital += released

    equity[-1] = capital
    # Let the last month include that final liquidation (exit slippage + fee).
    if monthly_returns:
        monthly_returns[months[-1]]['end'] = capital

    return calculate_metrics(trades, equity, monthly_returns)


def calculate_metrics(trades, equity, monthly_returns, initial_capital=None,
                      periods_per_year=525600):
    """Summarise a backtest run.

    Args:
        trades: List of dicts with at least a ``net_pnl`` key.
        equity: 1-D numpy array of per-bar equity. Entries <= 0 (e.g. the
            untouched warm-up bars from ``backtest``) are ignored for the
            Sharpe ratio and drawdown; ``equity[-1]`` is the final equity.
        monthly_returns: ``{month: {'start': float, 'end': float}}``.
        initial_capital: Starting capital; defaults to ``INITIAL_CAPITAL``.
        periods_per_year: Bars per year used to annualise the Sharpe ratio.
            The default 525600 is the number of minutes in a 365-day year,
            i.e. it assumes 1-minute bars (presumably matching the price
            data — confirm).

    Returns:
        Dict of metrics, or None when ``trades`` is empty. ``profit_factor``
        is gross profit / gross loss; it is ``inf`` when there are profits
        but no losses, and 0.0 when there are neither.
    """
    if not trades:
        return None
    if initial_capital is None:
        initial_capital = INITIAL_CAPITAL

    final_equity = equity[-1]
    total_return_pct = (final_equity - initial_capital) / initial_capital

    # Break-even trades (net_pnl == 0) count as losses.
    wins = [t for t in trades if t['net_pnl'] > 0]
    losses = [t for t in trades if t['net_pnl'] <= 0]
    win_rate = len(wins) / len(trades)

    valid_equity = equity[equity > 0]

    sharpe = 0
    if len(valid_equity) > 1:
        returns = np.diff(valid_equity) / valid_equity[:-1]
        std = returns.std()
        if std > 0:
            sharpe = returns.mean() / std * np.sqrt(periods_per_year)

    if len(valid_equity) > 0:
        cummax = np.maximum.accumulate(valid_equity)
        drawdown = (valid_equity - cummax) / cummax
        max_drawdown = abs(drawdown.min())
    else:
        max_drawdown = 0

    gross_profit = sum(t['net_pnl'] for t in wins)
    gross_loss = abs(sum(t['net_pnl'] for t in losses))
    if gross_loss > 0:
        profit_factor = gross_profit / gross_loss
    else:
        # No losing money at all: the ratio is unbounded if anything was won.
        profit_factor = float('inf') if gross_profit > 0 else 0.0

    monthly_pnl = {
        month: (data['end'] - data['start']) / data['start'] if data['start'] > 0 else 0
        for month, data in monthly_returns.items()
    }

    return {
        'total_return': final_equity - initial_capital,
        'total_return_pct': total_return_pct,
        'final_equity': final_equity,
        'num_trades': len(trades),
        'win_rate': win_rate,
        'sharpe_ratio': sharpe,
        'max_drawdown': max_drawdown,
        'profit_factor': profit_factor,
        'wins': len(wins),
        'losses': len(losses),
        'avg_win': np.mean([t['net_pnl'] for t in wins]) if wins else 0,
        'avg_loss': np.mean([t['net_pnl'] for t in losses]) if losses else 0,
        'monthly_returns': monthly_pnl,
    }


def _predict_proba(model, X):
    """Return the model's per-row class probabilities for ``X`` as an (n, 3) array.

    Uses ``predict_proba`` when the model provides it (scikit-learn style
    classifiers, whose ``predict`` returns class labels) and falls back to
    ``predict`` otherwise (presumably a native booster whose ``predict``
    already returns probabilities — confirm for the deployed model).

    Raises:
        ValueError: if the output is not 2-D with three columns, which
            ``backtest`` relies on (it reads columns 0 and 2).
    """
    predict = getattr(model, 'predict_proba', None)
    if predict is None:
        predict = model.predict
    proba = np.asarray(predict(X))
    if proba.ndim != 2 or proba.shape[1] != 3:
        raise ValueError(
            f"expected (n_rows, 3) class probabilities from the model, got shape {proba.shape}"
        )
    return proba


def _print_metrics(metrics):
    """Print the performance summary and the last six monthly returns."""
    print("\nPerformance:")
    print(f"  Total Return: €{metrics['total_return']:,.2f} ({metrics['total_return_pct']*100:+.2f}%)")
    print(f"  Final Equity: €{metrics['final_equity']:,.2f}")
    print(f"  Trades: {metrics['num_trades']} (Wins: {metrics['wins']}, Losses: {metrics['losses']})")
    print(f"  Win Rate: {metrics['win_rate']*100:.1f}%")
    print(f"  Avg Win: €{metrics['avg_win']:.2f} | Avg Loss: €{metrics['avg_loss']:.2f}")
    print(f"  Sharpe Ratio: {metrics['sharpe_ratio']:.2f}")
    print(f"  Max Drawdown: {metrics['max_drawdown']*100:.1f}%")
    print(f"  Profit Factor: {metrics['profit_factor']:.2f}")

    print("\nMonthly Returns (last 6):")
    for month in sorted(metrics['monthly_returns'])[-6:]:
        ret = metrics['monthly_returns'][month]
        print(f"    {month}: {ret*100:+.2f}%")


def _to_jsonable(obj):
    """Recursively convert numpy scalars inside dicts/lists to built-in types for json."""
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, dict):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_to_jsonable(i) for i in obj]
    return obj


def main():
    """Backtest every configuration in CONFIGS on all price data and save the results.

    Loads the candles and the latest model, builds features, predicts class
    probabilities for every row with no NaN feature, runs ``backtest`` for
    each config, prints a summary, and writes ``full_backtest_results.json``
    under ``RESULTS_DIR`` (creating the directory if it does not exist).
    """
    print("=" * 70)
    print("FULL BACKTEST - ALL DATA")
    print("=" * 70)

    print("\nLoading data...")
    df = load_data()
    print(f"Loaded {len(df):,} candles")
    print(f"Period: {df['timestamp'].min()} to {df['timestamp'].max()}")

    print("\nLoading model...")
    model, feature_names = load_model()

    print("\nCreating features...")
    features = create_features(df)

    print("Generating predictions...")
    X = features[feature_names].values
    # Rows with any NaN feature (e.g. indicator warm-up) are not sent to the
    # model; they keep all-zero probabilities and are also masked in backtest.
    valid_mask = ~np.isnan(X).any(axis=1)
    predictions = np.zeros((len(X), 3))
    if valid_mask.any():
        predictions[valid_mask] = _predict_proba(model, X[valid_mask])

    print("\n" + "=" * 70)
    print("BACKTEST RESULTS")
    print("=" * 70)

    results = []
    for config in CONFIGS:
        print(f"\n{config['name']}")
        print("-" * 50)
        print(f"SL: {config['stop_loss']:.1%} | TP: {config['take_profit']:.1%} | "
              f"Size: {config['position_size']:.0%} | Conf: {config['confidence_threshold']:.0%}")

        metrics = backtest(df, predictions, valid_mask, config)
        if not metrics:
            print("\n  No trades generated - nothing to report.")
            continue

        _print_metrics(metrics)
        results.append({
            'name': config['name'],
            'config': config,
            'metrics': metrics,
        })

    results_path = RESULTS_DIR / 'full_backtest_results.json'
    # Create the output directory up front so a long run is not lost at the final write.
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(_to_jsonable(results), f, indent=2)

    print("\n" + "=" * 70)
    print("BACKTEST COMPLETE")
    print(f"Results saved to {results_path}")
    print("=" * 70)


if __name__ == '__main__':
    # Run the full backtest only when executed as a script, not on import.
    main()
