#!/usr/bin/env python3
"""
BetPredictAI - Robust V4 Model Training
Addestramento robusto con holdout test per verifica reale

Features:
1. ELO Rating System
2. XGBoost Classifier
3. Neural Network con attention
4. Ensemble (60% XGB + 40% NN)

Lascia 400 partite recenti per test/verifica
"""

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import xgboost as xgb
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from collections import defaultdict
import json
import os
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Folder that contains this script; data/ and models/ are resolved relative to it.
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
DATA_DIR, MODELS_DIR = (
    os.path.join(BASE_DIR, '..', folder) for folder in ('data', 'models')
)

# Create the model output folder at import time so later saves find it in place.
os.makedirs(MODELS_DIR, exist_ok=True)

# League codes the training data is restricted to.
CORE_LEAGUES = ['E0', 'D1', 'I1', 'SP1', 'F1']


class ELOSystem:
    """Elo-style rating tracker with a home bonus and a goal-margin weighted K.

    Ratings live in ``self.ratings`` (team name -> rating); a team that has
    never been seen is created at 1500 the first time it is looked up.
    """

    def __init__(self, k_factor=32, home_advantage=100):
        self.k_factor = k_factor
        self.home_advantage = home_advantage
        self.ratings = defaultdict(lambda: 1500)

    def expected_score(self, rating_a, rating_b):
        """Return the expected score (0..1) of side A against side B."""
        gap = (rating_b - rating_a) / 400
        return 1 / (1 + 10 ** gap)

    def update(self, home_team, away_team, home_goals, away_goals):
        """Apply one match result to both teams' stored ratings.

        Returns the ratings used for the prediction, i.e. the values *before*
        this update: ``(home rating incl. home advantage, away rating,
        expected home score)``.
        """
        rated_home = self.ratings[home_team] + self.home_advantage
        rated_away = self.ratings[away_team]

        p_home = self.expected_score(rated_home, rated_away)
        p_away = 1 - p_home

        # Score from the home side's point of view: win 1, loss 0, otherwise 0.5.
        if home_goals > away_goals:
            outcome = 1
        elif home_goals < away_goals:
            outcome = 0
        else:
            outcome = 0.5
        score_home, score_away = outcome, 1 - outcome

        # Bigger winning margins move the ratings further (log-scaled).
        margin = abs(home_goals - away_goals)
        step = self.k_factor * (1 + np.log10(margin + 1) * 0.5)

        self.ratings[home_team] += step * (score_home - p_home)
        self.ratings[away_team] += step * (score_away - p_away)

        return rated_home, rated_away, p_home


class AdvancedBettingNet(nn.Module):
    """Feed-forward 3-class classifier with a learned per-feature gate.

    The "attention" part is a small bottleneck MLP ending in a sigmoid; its
    output has the same width as the input and is multiplied element-wise
    with the input before the main network runs.

    Args:
        input_size: Number of input features per sample.
        hidden_sizes: Widths of the hidden layers, in order. Each hidden layer
            is Linear -> BatchNorm1d -> LeakyReLU(0.1) -> Dropout.
        dropout: Dropout probability for every hidden layer except the last,
            which uses half of it.

    ``forward`` returns raw logits of shape ``(batch, 3)``; no softmax is
    applied here.
    """
    # The default is a tuple (not a list) so it cannot be mutated across instances.
    def __init__(self, input_size, hidden_sizes=(256, 256, 128, 64), dropout=0.3):
        super().__init__()
        hidden_sizes = list(hidden_sizes)

        # Attention layer: squeeze to half width, expand back, gate in (0, 1)
        self.attention = nn.Sequential(
            nn.Linear(input_size, input_size // 2),
            nn.ReLU(),
            nn.Linear(input_size // 2, input_size),
            nn.Sigmoid()
        )

        # Main network
        layers = []
        prev_size = input_size
        last_index = len(hidden_sizes) - 1
        for i, hidden_size in enumerate(hidden_sizes):
            layers.append(nn.Linear(prev_size, hidden_size))
            layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.LeakyReLU(0.1))
            # Lighter dropout right before the classifier head.
            layers.append(nn.Dropout(dropout if i < last_index else dropout * 0.5))
            prev_size = hidden_size

        self.feature_extractor = nn.Sequential(*layers)
        self.classifier = nn.Linear(prev_size, 3)

    def forward(self, x):
        """Gate the input features, run the hidden stack, return class logits."""
        attention_weights = self.attention(x)
        x = x * attention_weights
        features = self.feature_extractor(x)
        return self.classifier(features)


def _feature_value(row, column, default):
    """Return ``row[column]``, or *default* when the column is absent or the cell is NaN/None."""
    value = row.get(column, default)
    # Series.get only falls back when the label is missing; a NaN cell would
    # otherwise slip through and never receive the neutral default.
    return default if pd.isna(value) else value


def prepare_features(df, elo_system):
    """Build the model feature table for the matches in *df*.

    Args:
        df: One row per match. ``home_team`` and ``away_team`` are required;
            the form / head-to-head / implied-probability columns are
            optional, and a missing column or NaN cell is replaced by a
            fixed neutral default.
        elo_system: Object exposing ``ratings`` (team -> rating),
            ``home_advantage`` and ``expected_score(a, b)``. Its ratings are
            read as they currently stand and are not updated here. With the
            ``defaultdict`` used by ``ELOSystem``, looking up an unseen team
            adds it at the default rating.

    Returns:
        A DataFrame with one row per input row (fresh 0..n-1 index) and the
        feature columns in a fixed order.
    """
    features = []

    for _, row in df.iterrows():
        home = str(row['home_team'])
        away = str(row['away_team'])

        # ELO features (the home side gets the home-advantage bonus)
        home_elo_adj = elo_system.ratings[home] + elo_system.home_advantage
        away_elo = elo_system.ratings[away]
        elo_diff = home_elo_adj - away_elo
        elo_expected = elo_system.expected_score(home_elo_adj, away_elo)

        # Form features (from CSV)
        home_form = _feature_value(row, 'home_form_ppg', 1.3)
        away_form = _feature_value(row, 'away_form_ppg', 1.3)
        home_gpg = _feature_value(row, 'home_form_gpg', 1.3)
        away_gpg = _feature_value(row, 'away_form_gpg', 1.1)
        home_gapg = _feature_value(row, 'home_form_gapg', 1.1)
        away_gapg = _feature_value(row, 'away_form_gapg', 1.3)

        # Win rates
        home_win_rate = _feature_value(row, 'home_form_win_rate', 0.4)
        away_win_rate = _feature_value(row, 'away_form_win_rate', 0.35)

        # Home/Away specific
        home_home_ppg = _feature_value(row, 'home_home_ppg', 1.5)
        away_away_ppg = _feature_value(row, 'away_away_ppg', 1.0)

        # H2H
        h2h_home_wins = _feature_value(row, 'h2h_home_wins', 0.33)
        h2h_matches = _feature_value(row, 'h2h_matches', 0)

        # Implied probabilities from odds
        impl_home = _feature_value(row, 'impl_home', 0.45)
        impl_draw = _feature_value(row, 'impl_draw', 0.27)
        impl_away = _feature_value(row, 'impl_away', 0.28)

        # Key order below defines the column order of the returned frame.
        features.append({
            'elo_home': home_elo_adj,
            'elo_away': away_elo,
            'elo_diff': elo_diff,
            'elo_expected_home': elo_expected,

            'home_form': home_form,
            'away_form': away_form,
            'form_diff': home_form - away_form,

            'home_gpg': home_gpg,
            'away_gpg': away_gpg,
            'gpg_diff': home_gpg - away_gpg,

            'home_gapg': home_gapg,
            'away_gapg': away_gapg,
            'defense_diff': away_gapg - home_gapg,  # Higher is better for home

            'home_win_rate': home_win_rate,
            'away_win_rate': away_win_rate,
            'win_rate_diff': home_win_rate - away_win_rate,

            'home_home_ppg': home_home_ppg,
            'away_away_ppg': away_away_ppg,
            'home_away_diff': home_home_ppg - away_away_ppg,

            'h2h_home_wins': h2h_home_wins,
            # Number of meetings, scaled so that 10 or more saturates at 1
            'h2h_matches': min(h2h_matches / 10, 1),

            'impl_home': impl_home,
            'impl_draw': impl_draw,
            'impl_away': impl_away,

            # Attack vs Defense
            'attack_vs_defense_home': home_gpg - away_gapg,
            'attack_vs_defense_away': away_gpg - home_gapg,

            # Composite features
            'strength_composite': (elo_expected * 0.4 + impl_home * 0.4 + home_win_rate * 0.2),
            'form_momentum': (home_form - 1.3) - (away_form - 1.3),
        })

    return pd.DataFrame(features)


def _train_neural_network(X_train_scaled, y_train, n_features, device,
                          batch_size=64, max_epochs=100, early_stop_patience=20):
    """Fit an ``AdvancedBettingNet`` and return it with its best weights loaded.

    "Best" means the epoch with the lowest average *training* loss; the same
    loss drives the LR scheduler and the early stopping (there is no separate
    validation split here).

    Args:
        X_train_scaled: 2-D float array of standardised training features.
        y_train: 1-D integer array of class labels (0, 1 or 2).
        n_features: Width of the feature vector, passed to the network.
        device: ``torch.device`` to train on.
        batch_size: Mini-batch size.
        max_epochs: Upper bound on training epochs.
        early_stop_patience: Stop after this many epochs without improvement.

    Returns:
        The trained model, in eval mode, on *device*.

    Raises:
        ValueError: If fewer than two training samples are supplied.
    """
    train_dataset = TensorDataset(
        torch.FloatTensor(X_train_scaled),
        torch.LongTensor(y_train),
    )
    if len(train_dataset) < 2:
        raise ValueError("At least 2 training samples are required to train the neural network")

    # BatchNorm1d cannot normalise a single-sample batch in training mode, so
    # the trailing batch is dropped only when it would hold exactly one row.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=(len(train_dataset) % batch_size == 1),
    )

    nn_model = AdvancedBettingNet(n_features).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(nn_model.parameters(), lr=0.001, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.5)

    def snapshot_weights():
        # state_dict() tensors share storage with the live parameters, so a
        # dict-level .copy() would keep tracking later optimizer steps. Clone
        # every tensor to freeze the values.
        return {name: tensor.detach().clone() for name, tensor in nn_model.state_dict().items()}

    best_loss = float('inf')
    # Start from the initial weights so there is always something to restore,
    # even if the loss never improves (e.g. it is NaN from the first epoch).
    best_state = snapshot_weights()
    patience_counter = 0

    for epoch in range(max_epochs):
        nn_model.train()
        total_loss = 0

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            outputs = nn_model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        scheduler.step(avg_loss)

        if avg_loss < best_loss:
            best_loss = avg_loss
            patience_counter = 0
            best_state = snapshot_weights()
        else:
            patience_counter += 1
            if patience_counter >= early_stop_patience:
                break

        if (epoch + 1) % 20 == 0:
            print(f"    Epoch {epoch+1}: Loss = {avg_loss:.4f}")

    nn_model.load_state_dict(best_state)
    nn_model.eval()
    return nn_model


def train_model():
    """Train the V4 ensemble and write models plus a holdout verification report.

    Steps:
      1. Load ``advanced_historical_matches.csv`` from ``DATA_DIR``, keep the
         core leagues, drop rows without teams, goals or a date, and sort
         chronologically.
      2. Hold out the 400 most recent matches as the test set.
      3. Replay the training matches through ``ELOSystem`` and build features.
      4. Fit an XGBoost classifier (unscaled features) and the neural network
         (standardised features); blend probabilities 60% / 40%.
      5. Save the models, scaler, ELO ratings and metadata to ``MODELS_DIR``
         and write ``verification_results.json`` to ``DATA_DIR``.

    Target encoding: 0 = away win, 1 = draw, 2 = home win.

    Raises:
        ValueError: If there are not more clean matches than the holdout size.
    """
    TEST_SIZE = 400

    print("=" * 70)
    print("BetPredictAI - V4 ROBUST TRAINING")
    print("=" * 70)

    # Load data
    print("\n Loading data...")
    df = pd.read_csv(os.path.join(DATA_DIR, 'advanced_historical_matches.csv'), low_memory=False)
    df['date'] = pd.to_datetime(df['date'], format='mixed', dayfirst=True)
    df = df.sort_values('date').reset_index(drop=True)

    # Filter to core leagues
    df = df[df['league_code'].isin(CORE_LEAGUES)].copy()
    print(f"  Core leagues: {len(df)} matches")

    # Clean data. Rows without a date are dropped too: they sort to the very
    # end, so they would land in the "most recent" holdout and later fail on
    # NaT.strftime().
    df = df.dropna(subset=['date', 'home_team', 'away_team', 'home_goals', 'away_goals'])
    df['home_goals'] = df['home_goals'].astype(int)
    df['away_goals'] = df['away_goals'].astype(int)

    # Create target
    df['result_code'] = df.apply(
        lambda r: 2 if r['home_goals'] > r['away_goals'] else (1 if r['home_goals'] == r['away_goals'] else 0),
        axis=1
    )

    print(f"  Clean data: {len(df)} matches")
    if len(df) <= TEST_SIZE:
        raise ValueError(
            f"Need more than {TEST_SIZE} clean matches to hold out a test set, got {len(df)}"
        )
    print(f"  Date range: {df['date'].min().strftime('%Y-%m-%d')} to {df['date'].max().strftime('%Y-%m-%d')}")

    # Split: leave last TEST_SIZE matches for testing
    train_df = df.iloc[:-TEST_SIZE].copy()
    test_df = df.iloc[-TEST_SIZE:].copy()

    print(f"\n  Training: {len(train_df)} matches")
    print(f"  Testing:  {len(test_df)} matches (for verification)")

    # Initialize ELO and train on training data
    print("\n Building ELO ratings...")
    elo = ELOSystem(k_factor=32, home_advantage=100)

    for _, row in train_df.iterrows():
        elo.update(
            str(row['home_team']),
            str(row['away_team']),
            int(row['home_goals']),
            int(row['away_goals'])
        )

    print(f"  Teams rated: {len(elo.ratings)}")

    # Prepare features
    # NOTE(review): the ELO features of the *training* rows are read from the
    # ratings as they stand after every training match has been replayed, so
    # each row sees ratings that already include its own result. Test rows
    # use the same frozen ratings. Confirm this look-ahead is intended.
    print("\n Preparing features...")
    X_train_df = prepare_features(train_df, elo)
    X_test_df = prepare_features(test_df, elo)

    feature_cols = list(X_train_df.columns)
    print(f"  Features: {len(feature_cols)}")

    # Convert to numpy
    X_train = X_train_df.values.astype(np.float32)
    X_test = X_test_df.values.astype(np.float32)
    y_train = train_df['result_code'].values
    y_test = test_df['result_code'].values

    # Handle NaN/Inf
    X_train = np.nan_to_num(X_train, nan=0, posinf=10, neginf=-10)
    X_test = np.nan_to_num(X_test, nan=0, posinf=10, neginf=-10)

    # Scale (fit on training data only; used by the neural network)
    print("\n Scaling features...")
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # ===== TRAIN XGBOOST =====
    print("\n Training XGBoost...")
    xgb_model = xgb.XGBClassifier(
        n_estimators=300,
        max_depth=6,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        min_child_weight=3,
        reg_alpha=0.1,
        reg_lambda=1.0,
        random_state=42,
        use_label_encoder=False,
        eval_metric='mlogloss'
    )

    xgb_model.fit(X_train, y_train, verbose=False)
    xgb_train_acc = (xgb_model.predict(X_train) == y_train).mean()
    xgb_test_acc = (xgb_model.predict(X_test) == y_test).mean()
    print(f"  XGBoost Train Accuracy: {xgb_train_acc*100:.1f}%")
    print(f"  XGBoost Test Accuracy:  {xgb_test_acc*100:.1f}%")

    # ===== TRAIN NEURAL NETWORK =====
    print("\n Training Neural Network...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    nn_model = _train_neural_network(X_train_scaled, y_train, len(feature_cols), device)

    with torch.no_grad():
        X_test_tensor = torch.FloatTensor(X_test_scaled).to(device)
        nn_probs = torch.softmax(nn_model(X_test_tensor), dim=1).cpu().numpy()
    nn_preds = nn_probs.argmax(axis=1)
    nn_test_acc = (nn_preds == y_test).mean()

    print(f"  Neural Network Test Accuracy: {nn_test_acc*100:.1f}%")

    # ===== ENSEMBLE =====
    print("\n Creating Ensemble (60% XGB + 40% NN)...")
    xgb_probs = xgb_model.predict_proba(X_test)
    ensemble_probs = 0.6 * xgb_probs + 0.4 * nn_probs
    ensemble_preds = ensemble_probs.argmax(axis=1)
    # Confidence = probability of the predicted class, in percent
    ensemble_confidence = ensemble_probs.max(axis=1) * 100

    ensemble_acc = (ensemble_preds == y_test).mean()
    print(f"  Ensemble Test Accuracy: {ensemble_acc*100:.1f}%")

    # Accuracy by confidence
    print("\n Accuracy by Confidence Level:")
    for threshold in [40, 45, 50, 55, 60]:
        mask = ensemble_confidence >= threshold
        if mask.sum() > 0:
            high_conf_acc = (ensemble_preds[mask] == y_test[mask]).mean()
            print(f"    >= {threshold}%: {high_conf_acc*100:.1f}% ({mask.sum()} matches)")

    # ===== SAVE MODELS =====
    print("\n Saving models...")

    # Save XGBoost
    xgb_model.save_model(os.path.join(MODELS_DIR, 'xgb_model_v4.json'))

    # Save Neural Network
    torch.save(nn_model.state_dict(), os.path.join(MODELS_DIR, 'nn_model_v4.pth'))

    # Save scaler.
    # NOTE: np.save stores a dict as a pickled object array, so readers must
    # call np.load(..., allow_pickle=True).item(). Format kept unchanged for
    # compatibility with existing loaders.
    np.save(os.path.join(MODELS_DIR, 'scaler_v4.npy'), {
        'mean': scaler.mean_,
        'scale': scaler.scale_
    })

    # Save ELO ratings
    with open(os.path.join(MODELS_DIR, 'elo_ratings.json'), 'w') as f:
        json.dump(dict(elo.ratings), f, indent=2)

    # Save metadata
    meta = {
        'trained_at': datetime.now().isoformat(),
        'model_version': 'V4-Robust',
        'training_matches': len(train_df),
        'test_matches': len(test_df),
        'feature_columns': feature_cols,
        'accuracy': {
            'xgb': float(xgb_test_acc),
            'nn': float(nn_test_acc),
            'ensemble': float(ensemble_acc)
        }
    }

    with open(os.path.join(MODELS_DIR, 'model_v4_meta.json'), 'w') as f:
        json.dump(meta, f, indent=2)

    # ===== GENERATE VERIFICATION =====
    print("\n Generating verification results...")

    RESULT_MAP = {0: 'AWAY', 1: 'DRAW', 2: 'HOME'}
    LEAGUE_NAMES = {
        'E0': 'Premier League',
        'D1': 'Bundesliga',
        'I1': 'Serie A',
        'SP1': 'La Liga',
        'F1': 'Ligue 1'
    }

    verification_results = []

    # Positional index i lines each test row up with the prediction arrays.
    for i, (_, row) in enumerate(test_df.iterrows()):
        pred = int(ensemble_preds[i])
        actual = int(y_test[i])
        conf = float(ensemble_confidence[i])

        verification_results.append({
            'league': row['league_code'],
            'league_name': LEAGUE_NAMES.get(row['league_code'], row['league_code']),
            'date': row['date'].strftime('%Y-%m-%d'),
            'time': str(row.get('Time', '15:00'))[:5] if pd.notna(row.get('Time')) else '15:00',
            'home_team': str(row['home_team']),
            'away_team': str(row['away_team']),
            'home_goals': int(row['home_goals']),
            'away_goals': int(row['away_goals']),
            'actual': RESULT_MAP[actual],
            'actual_code': actual,
            'predicted': RESULT_MAP[pred],
            'predicted_code': pred,
            'confidence': round(conf, 1),
            'home_prob': round(float(ensemble_probs[i][2] * 100), 1),
            'draw_prob': round(float(ensemble_probs[i][1] * 100), 1),
            'away_prob': round(float(ensemble_probs[i][0] * 100), 1),
            'correct': actual == pred
        })

    # Sort by date (newest first)
    verification_results = sorted(verification_results, key=lambda x: x['date'], reverse=True)

    # Calculate statistics
    correct = sum(1 for r in verification_results if r['correct'])
    total = len(verification_results)

    by_confidence = {}
    for threshold in [40, 45, 50, 55, 60, 65, 70]:
        high_conf = [r for r in verification_results if r['confidence'] >= threshold]
        if high_conf:
            high_conf_correct = sum(1 for r in high_conf if r['correct'])
            by_confidence[threshold] = {
                'total': len(high_conf),
                'correct': high_conf_correct,
                'accuracy': round(100 * high_conf_correct / len(high_conf), 1)
            }

    by_league = {}
    for league in CORE_LEAGUES:
        league_results = [r for r in verification_results if r['league'] == league]
        if league_results:
            league_correct = sum(1 for r in league_results if r['correct'])
            by_league[league] = {
                'name': LEAGUE_NAMES[league],
                'total': len(league_results),
                'correct': league_correct,
                'accuracy': round(100 * league_correct / len(league_results), 1)
            }

    # The 20 most confident predictions that turned out to be right
    top_verified = sorted(
        [r for r in verification_results if r['correct']],
        key=lambda x: x['confidence'],
        reverse=True
    )[:20]

    verification_output = {
        'generated_at': datetime.now().isoformat(),
        'model_version': 'V4-Robust',
        'training_info': {
            'total_matches': len(df),
            'training_matches': len(train_df),
            'test_matches': len(test_df)
        },
        'summary': {
            'total': total,
            'correct': correct,
            'accuracy': round(100 * correct / total, 1)
        },
        'by_confidence': by_confidence,
        'by_league': by_league,
        'results': verification_results,
        'top_verified': top_verified
    }

    with open(os.path.join(DATA_DIR, 'verification_results.json'), 'w') as f:
        json.dump(verification_output, f, indent=2)

    # Print final summary
    print("\n" + "=" * 70)
    print("TRAINING COMPLETE")
    print("=" * 70)
    print(f"  Total matches trained: {len(train_df)}")
    print(f"  Test matches verified: {len(test_df)}")
    print(f"\n  Overall Accuracy: {100*correct/total:.1f}%")
    print(f"\n  By Confidence:")
    for thresh, stats in by_confidence.items():
        print(f"    >= {thresh}%: {stats['accuracy']:.1f}% ({stats['correct']}/{stats['total']})")

    print(f"\n  By League:")
    for league, stats in by_league.items():
        print(f"    {stats['name']}: {stats['accuracy']:.1f}% ({stats['correct']}/{stats['total']})")

    print("\n  Models saved to:", MODELS_DIR)
    print("  Verification saved to:", os.path.join(DATA_DIR, 'verification_results.json'))


if __name__ == '__main__':
    train_model()
