#!/usr/bin/env python3
"""
BetPredictAI - Advanced Model V4
================================

State-of-the-art football prediction using:
1. ELO Rating System (FIFA-style)
2. Dixon-Coles Goal Distribution Correction
3. XGBoost + Neural Network Ensemble
4. Time-Weighted Feature Engineering
5. Poisson-Based Expected Goals

Research Sources:
- ETH Zurich: "From Ratings to Results: Advanced Predictive Modeling"
- Imperial College: "Using ML to predict professional football matches"
- Frontiers in Sports: "EPV vs xG to predict match outcomes"

Target: 65-75% accuracy on high-confidence predictions
"""

import os
import sys
import json
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# ML Libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from scipy.optimize import minimize
from scipy.stats import poisson
import xgboost as xgb

# Resolve directories relative to this file so the script works regardless of
# the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(BASE_DIR, '..', 'data')
MODELS_DIR = os.path.join(BASE_DIR, '..', 'models')

# Ensure the model output directory exists before anything tries to save.
os.makedirs(MODELS_DIR, exist_ok=True)

# Startup banner describing the modelling pipeline.
print("="*70)
print("🧠 BetPredictAI - ADVANCED MODEL V4")
print("="*70)
print("  📚 Methods: ELO + Dixon-Coles + XGBoost + Neural Network Ensemble")
print("  🎯 Target: 65-75% accuracy on high-confidence predictions")
print("="*70)


# =============================================================================
# 1. ELO RATING SYSTEM (FIFA-style)
# =============================================================================

class ELOSystem:
    """FIFA-style ELO ratings for football teams.

    Every team implicitly starts at ``initial_rating``.  After each match the
    winner gains (and the loser loses) points proportional to how surprising
    the result was on the logistic ELO curve, scaled by the goal margin and a
    per-match importance factor.
    """

    def __init__(self, k_factor=32, home_advantage=100, initial_rating=1500):
        # Base K-factor; the effective K per match is scaled by importance
        # and the goal-margin multiplier.
        self.k_factor = k_factor
        # Rating bonus credited to the home side before computing expectations.
        self.home_advantage = home_advantage
        self.initial_rating = initial_rating
        # Unseen teams default to the initial rating on first access.
        self.ratings = defaultdict(lambda: initial_rating)
        # Optional per-team (date, rating) trail, recorded when a date is given.
        self.rating_history = defaultdict(list)

    def expected_result(self, rating_a, rating_b):
        """Expected score for side A on the standard logistic ELO curve."""
        exponent = (rating_b - rating_a) / 400
        return 1 / (1 + 10 ** exponent)

    def update_ratings(self, home_team, away_team, home_goals, away_goals,
                       importance=1.0, date=None):
        """Apply one match result; return (home_change, away_change)."""
        # Effective pre-match ratings: the home side gets the venue bonus.
        rated_home = self.ratings[home_team] + self.home_advantage
        rated_away = self.ratings[away_team]

        exp_home = self.expected_result(rated_home, rated_away)
        exp_away = 1 - exp_home

        # Map the scoreline to actual points: win=1, draw=0.5, loss=0.
        if home_goals == away_goals:
            act_home, act_away = 0.5, 0.5
        elif home_goals > away_goals:
            act_home, act_away = 1, 0
        else:
            act_home, act_away = 0, 1

        # FIFA-style margin multiplier: larger wins move ratings further.
        margin = abs(home_goals - away_goals)
        if margin <= 1:
            G = 1
        elif margin == 2:
            G = 1.5
        else:
            G = (11 + margin) / 8

        # Rating deltas are symmetric (zero-sum) around the expectation.
        effective_k = self.k_factor * importance * G
        delta_home = effective_k * (act_home - exp_home)
        delta_away = effective_k * (act_away - exp_away)

        self.ratings[home_team] += delta_home
        self.ratings[away_team] += delta_away

        # Record the post-match ratings when the caller supplies a date.
        if date:
            for team in (home_team, away_team):
                self.rating_history[team].append((date, self.ratings[team]))

        return delta_home, delta_away

    def get_rating(self, team):
        """Current rating for ``team`` (initial rating if never seen)."""
        return self.ratings[team]

    def get_ratings_for_match(self, home_team, away_team):
        """Pre-match ratings (home bonus applied) plus the home expectation."""
        rated_home = self.ratings[home_team] + self.home_advantage
        rated_away = self.ratings[away_team]
        return rated_home, rated_away, self.expected_result(rated_home, rated_away)


# =============================================================================
# 2. DIXON-COLES MODEL
# =============================================================================

class DixonColesModel:
    """
    Dixon-Coles model for football score prediction.

    Each side's goal count is Poisson with a rate built from team
    attack/defense strengths plus a home advantage; the four low-score
    outcomes (0-0, 0-1, 1-0, 1-1) receive the Dixon-Coles ``tau`` correction,
    which fixes the plain Poisson model's underestimation of draws and tight
    scorelines.  Older matches are down-weighted exponentially when fitting.

    Reference: Dixon & Coles (1997), "Modelling Association Football Scores
    and Inefficiencies in the Football Betting Market".
    """

    def __init__(self, decay_rate=0.0065):
        self.decay_rate = decay_rate  # per-day exponential time-decay factor
        self.teams = {}               # team name -> integer index used in fitting
        self.attack = {}              # team name -> attack strength (log scale)
        self.defense = {}             # team name -> defensive weakness (log scale)
        self.home_adv = 0.25          # log-scale home advantage (refined by fit())
        self.rho = -0.13              # low-score dependence parameter (refined by fit())

    def time_weight(self, days_ago):
        """Exponential decay weight for a match played ``days_ago`` days ago."""
        return np.exp(-self.decay_rate * days_ago)

    def tau(self, x, y, lambda_x, mu_y, rho):
        """Dixon-Coles dependence correction for the scoreline (x, y).

        Only the four low-score cells are adjusted; every other scoreline
        keeps its independent-Poisson probability (factor 1).
        """
        if x == 0 and y == 0:
            return 1 - lambda_x * mu_y * rho
        elif x == 0 and y == 1:
            return 1 + lambda_x * rho
        elif x == 1 and y == 0:
            return 1 + mu_y * rho
        elif x == 1 and y == 1:
            return 1 - rho
        else:
            return 1

    def fit(self, matches_df, reference_date=None):
        """Fit attack/defense/home_adv/rho by time-weighted maximum likelihood.

        Parameters
        ----------
        matches_df : DataFrame with columns 'home_team', 'away_team',
            'home_goals', 'away_goals' and 'date'.
        reference_date : datetime from which time-decay weights are measured
            (defaults to now).
        """
        if reference_date is None:
            reference_date = datetime.now()

        # Index every team appearing on either side of a fixture.
        all_teams = set(matches_df['home_team'].unique()) | set(matches_df['away_team'].unique())
        self.teams = {team: i for i, team in enumerate(all_teams)}
        n_teams = len(all_teams)

        # Parameter vector layout:
        # [attack_0..attack_{n-1}, defense_0..defense_{n-1}, home_adv, rho]
        initial_params = np.zeros(2 * n_teams + 2)
        initial_params[-2] = 0.25  # home advantage starting point
        initial_params[-1] = -0.13  # rho starting point

        # Flatten matches into parallel arrays with one time weight per match.
        home_teams = []
        away_teams = []
        home_goals = []
        away_goals = []
        weights = []

        for _, row in matches_df.iterrows():
            if pd.isna(row['home_team']) or pd.isna(row['away_team']):
                continue

            home_teams.append(self.teams.get(row['home_team'], 0))
            away_teams.append(self.teams.get(row['away_team'], 0))
            home_goals.append(int(row['home_goals']) if pd.notna(row['home_goals']) else 0)
            away_goals.append(int(row['away_goals']) if pd.notna(row['away_goals']) else 0)

            # Older matches contribute less to the likelihood.
            try:
                match_date = pd.to_datetime(row['date'])
                days_ago = (reference_date - match_date).days
                weights.append(self.time_weight(max(0, days_ago)))
            except Exception:
                # Missing/unparseable date: fall back to a neutral half weight.
                # (Narrowed from a bare except, which also swallowed
                # KeyboardInterrupt/SystemExit.)
                weights.append(0.5)

        self.match_data = {
            'home_teams': np.array(home_teams),
            'away_teams': np.array(away_teams),
            'home_goals': np.array(home_goals),
            'away_goals': np.array(away_goals),
            'weights': np.array(weights),
            'n_teams': n_teams
        }

        # Maximize the weighted log-likelihood (minimize its negative).
        result = minimize(
            self._neg_log_likelihood,
            initial_params,
            method='L-BFGS-B',
            options={'maxiter': 100, 'disp': False}
        )

        # Unpack the optimized vector back into named per-team parameters.
        params = result.x
        team_list = list(self.teams.keys())
        self.attack = {team_list[i]: params[i] for i in range(n_teams)}
        self.defense = {team_list[i]: params[n_teams + i] for i in range(n_teams)}
        self.home_adv = params[-2]
        self.rho = params[-1]

        print(f"  Dixon-Coles fitted: home_adv={self.home_adv:.3f}, rho={self.rho:.3f}")

    def _neg_log_likelihood(self, params):
        """Weighted negative log-likelihood of the observed scorelines."""
        n_teams = self.match_data['n_teams']
        attack = params[:n_teams]
        defense = params[n_teams:2*n_teams]
        home_adv = params[-2]
        rho = params[-1]

        home_teams = self.match_data['home_teams']
        away_teams = self.match_data['away_teams']
        home_goals = self.match_data['home_goals']
        away_goals = self.match_data['away_goals']
        weights = self.match_data['weights']

        # Expected goals: exp(attack + opponent defense [+ home advantage]).
        lambda_home = np.exp(attack[home_teams] + defense[away_teams] + home_adv)
        mu_away = np.exp(attack[away_teams] + defense[home_teams])

        # Clip rates for numerical stability during optimization.
        lambda_home = np.clip(lambda_home, 0.01, 10)
        mu_away = np.clip(mu_away, 0.01, 10)

        # Per-match tau corrections.  Clip away non-positive values before
        # taking the log: for extreme rho the correction can go <= 0, and
        # log(<=0) produces NaN/-inf that derails the optimizer.
        tau_vals = np.vectorize(lambda h, a, lh, ma: self.tau(h, a, lh, ma, rho))(
            home_goals, away_goals, lambda_home, mu_away
        )
        tau_vals = np.clip(tau_vals, 1e-10, None)

        log_lik = (
            weights * (
                poisson.logpmf(home_goals, lambda_home) +
                poisson.logpmf(away_goals, mu_away) +
                np.log(tau_vals)
            )
        )

        return -np.sum(log_lik)

    def predict_probs(self, home_team, away_team, max_goals=6):
        """Predict 1X2 probabilities and expected goals for a fixture.

        Unknown teams default to strength 0 (league average).  ``max_goals``
        is the size of the score grid (goals 0..max_goals-1 per side).
        """
        attack_home = self.attack.get(home_team, 0)
        defense_home = self.defense.get(home_team, 0)
        attack_away = self.attack.get(away_team, 0)
        defense_away = self.defense.get(away_team, 0)

        lambda_home = np.exp(attack_home + defense_away + self.home_adv)
        mu_away = np.exp(attack_away + defense_home)

        # Clip to a plausible goals-per-match range for stability.
        lambda_home = np.clip(lambda_home, 0.1, 5)
        mu_away = np.clip(mu_away, 0.1, 5)

        # Joint score grid under independence: P(home=i) x P(away=j).
        home_probs = poisson.pmf(range(max_goals), lambda_home)
        away_probs = poisson.pmf(range(max_goals), mu_away)

        prob_matrix = np.outer(home_probs, away_probs)

        # Apply the Dixon-Coles correction to the four low-score cells.
        for i in range(min(2, max_goals)):
            for j in range(min(2, max_goals)):
                prob_matrix[i, j] *= self.tau(i, j, lambda_home, mu_away, self.rho)

        # Renormalize after truncation and correction.
        prob_matrix /= prob_matrix.sum()

        # Rows are home goals, columns away goals: lower triangle = home win.
        home_win = np.tril(prob_matrix, -1).sum()
        draw = np.trace(prob_matrix)
        away_win = np.triu(prob_matrix, 1).sum()

        return {
            'home_win': home_win,
            'draw': draw,
            'away_win': away_win,
            'expected_home_goals': lambda_home,
            'expected_away_goals': mu_away
        }


# =============================================================================
# 3. ADVANCED NEURAL NETWORK
# =============================================================================

class AdvancedBettingNet(nn.Module):
    """Three-way (away/draw/home) classifier with soft feature attention.

    Architecture:
      * a small gating network (squeeze -> ReLU -> expand -> sigmoid) that
        produces a per-feature weight in (0, 1), multiplied into the input;
      * a stack of Linear -> BatchNorm -> LeakyReLU -> Dropout blocks
        (the last block uses half the dropout rate);
      * a final linear head emitting 3 raw class logits.
    """

    # NOTE: default changed from a list to a tuple — a mutable default
    # argument is shared across calls; a tuple is equivalent and safe.
    def __init__(self, input_size, hidden_sizes=(256, 256, 128, 64), dropout=0.3):
        super().__init__()

        # Feature attention / gating layer.
        self.attention = nn.Sequential(
            nn.Linear(input_size, input_size // 2),
            nn.ReLU(),
            nn.Linear(input_size // 2, input_size),
            nn.Sigmoid()
        )

        # Hidden stack.
        layers = []
        prev_size = input_size

        for i, hidden_size in enumerate(hidden_sizes):
            layers.append(nn.Linear(prev_size, hidden_size))
            layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.LeakyReLU(0.1))
            # Halve dropout on the final hidden block to keep more signal
            # right before classification.
            layers.append(nn.Dropout(dropout if i < len(hidden_sizes) - 1 else dropout * 0.5))
            prev_size = hidden_size

        self.feature_extractor = nn.Sequential(*layers)
        self.classifier = nn.Linear(prev_size, 3)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init (LeakyReLU gain) for linears; zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return raw class logits of shape (batch, 3) for input (batch, input_size)."""
        # Gate each input feature by its learned attention weight.
        attention_weights = self.attention(x)
        x = x * attention_weights

        features = self.feature_extractor(x)

        return self.classifier(features)


# =============================================================================
# 4. ADVANCED FEATURE ENGINEERING
# =============================================================================

def _weighted_avg(values, decay=0.7):
    """Exponentially-weighted mean of ``values``, newest (last) element
    weighted highest.  Returns 0 for an empty list.

    Only the last 10 values enter the sum; callers pass slices of at most 5
    elements, for which this is a plain decayed average.
    """
    if not values:
        return 0
    weights = [decay ** i for i in range(len(values))]
    return sum(v * w for v, w in zip(reversed(values[-10:]), weights)) / sum(weights)


def _streak_length(results, result_type):
    """Length of the current streak, counting back from the latest result.

    ``results`` holds 1 (win), 0.5 (draw), 0 (loss).  ``result_type`` is
    'win' (r == 1), 'unbeaten' (r >= 0.5) or 'loss' (r == 0).
    """
    if not results:
        return 0
    streak = 0
    for r in reversed(results):
        if (result_type == 'win' and r == 1) or \
           (result_type == 'unbeaten' and r >= 0.5) or \
           (result_type == 'loss' and r == 0):
            streak += 1
        else:
            break
    return streak


def calculate_advanced_features(df, elo_system, dixon_coles):
    """
    Build one feature row per match, walking the matches in date order.

    Features include ELO ratings, Dixon-Coles probabilities, time-weighted
    form, venue-specific form, head-to-head balance, streaks and scoring
    rates.  All statistics are computed from matches strictly BEFORE the
    current one (the trackers are updated only after the row is emitted),
    so there is no target leakage from the current match.

    Parameters
    ----------
    df : DataFrame with 'date', 'home_team', 'away_team', 'home_goals',
        'away_goals' (and optionally 'result_code', 'league_code').
    elo_system : object exposing get_ratings_for_match() and update_ratings().
    dixon_coles : object exposing predict_probs(); failures fall back to a
        neutral prior.

    Returns
    -------
    (features_df, elo_system) : feature table and the updated ELO system.
    """

    print("\n📊 Calculating Advanced Features...")

    features_list = []

    # Process chronologically so per-team trackers only ever see the past.
    df = df.sort_values('date').reset_index(drop=True)

    # Rolling per-team statistics, updated after each processed match.
    team_stats = defaultdict(lambda: {
        'goals_scored': [],
        'goals_conceded': [],
        'results': [],  # 1=win, 0.5=draw, 0=loss
        'home_goals_scored': [],
        'home_goals_conceded': [],
        'away_goals_scored': [],
        'away_goals_conceded': [],
        'clean_sheets': 0,
        'failed_to_score': 0,
        'matches_played': 0
    })

    # Head-to-head tallies keyed by the (sorted) team pair; 'home_wins'
    # counts wins by whichever side was at home in that past fixture.
    h2h_stats = defaultdict(lambda: {'home_wins': 0, 'draws': 0, 'away_wins': 0, 'total': 0})

    for idx, row in df.iterrows():
        if idx % 5000 == 0:
            print(f"  Processing match {idx}/{len(df)}...")

        home = row['home_team']
        away = row['away_team']

        if pd.isna(home) or pd.isna(away):
            continue

        home_goals = int(row['home_goals']) if pd.notna(row['home_goals']) else 0
        away_goals = int(row['away_goals']) if pd.notna(row['away_goals']) else 0

        # Pre-match ELO ratings (home advantage already applied).
        home_elo, away_elo, elo_expected = elo_system.get_ratings_for_match(home, away)

        # Dixon-Coles outcome probabilities; fall back to a neutral prior
        # when the model has no parameters for these teams.
        # (Narrowed from a bare except, which also swallowed KeyboardInterrupt.)
        try:
            dc_probs = dixon_coles.predict_probs(home, away)
        except Exception:
            dc_probs = {'home_win': 0.4, 'draw': 0.3, 'away_win': 0.3,
                       'expected_home_goals': 1.3, 'expected_away_goals': 1.0}

        home_stats = team_stats[home]
        away_stats = team_stats[away]

        # Time-weighted recent form over the last 5 matches; sensible league
        # averages are used for teams with no history yet.
        home_form = _weighted_avg(home_stats['results'][-5:]) if home_stats['results'] else 0.5
        away_form = _weighted_avg(away_stats['results'][-5:]) if away_stats['results'] else 0.5

        home_attack = _weighted_avg(home_stats['goals_scored'][-5:]) if home_stats['goals_scored'] else 1.2
        home_defense = _weighted_avg(home_stats['goals_conceded'][-5:]) if home_stats['goals_conceded'] else 1.2
        away_attack = _weighted_avg(away_stats['goals_scored'][-5:]) if away_stats['goals_scored'] else 1.0
        away_defense = _weighted_avg(away_stats['goals_conceded'][-5:]) if away_stats['goals_conceded'] else 1.0

        # Venue-specific form (last 3 home matches / last 3 away matches).
        home_home_attack = _weighted_avg(home_stats['home_goals_scored'][-3:]) if home_stats['home_goals_scored'] else 1.3
        home_home_defense = _weighted_avg(home_stats['home_goals_conceded'][-3:]) if home_stats['home_goals_conceded'] else 1.0
        away_away_attack = _weighted_avg(away_stats['away_goals_scored'][-3:]) if away_stats['away_goals_scored'] else 1.0
        away_away_defense = _weighted_avg(away_stats['away_goals_conceded'][-3:]) if away_stats['away_goals_conceded'] else 1.3

        # Head-to-head balance for this pairing.
        h2h_key = tuple(sorted([home, away]))
        h2h = h2h_stats[h2h_key]
        h2h_home_advantage = (h2h['home_wins'] - h2h['away_wins']) / max(h2h['total'], 1)

        # Current streaks from the full results history.
        home_win_streak = _streak_length(home_stats['results'], 'win')
        home_unbeaten_streak = _streak_length(home_stats['results'], 'unbeaten')
        away_win_streak = _streak_length(away_stats['results'], 'win')
        away_loss_streak = _streak_length(away_stats['results'], 'loss')

        # Clean-sheet and scoring rates over all matches played so far.
        home_clean_sheet_rate = home_stats['clean_sheets'] / max(home_stats['matches_played'], 1)
        home_scoring_rate = 1 - (home_stats['failed_to_score'] / max(home_stats['matches_played'], 1))
        away_clean_sheet_rate = away_stats['clean_sheets'] / max(away_stats['matches_played'], 1)
        away_scoring_rate = 1 - (away_stats['failed_to_score'] / max(away_stats['matches_played'], 1))

        # Expected-goals differential from the Dixon-Coles rates.
        xg_diff = dc_probs['expected_home_goals'] - dc_probs['expected_away_goals']

        # Assemble the feature vector for this match.
        feature_row = {
            # ELO features
            'elo_home': home_elo,
            'elo_away': away_elo,
            'elo_diff': home_elo - away_elo,
            'elo_expected_home': elo_expected,

            # Dixon-Coles features
            'dc_home_prob': dc_probs['home_win'],
            'dc_draw_prob': dc_probs['draw'],
            'dc_away_prob': dc_probs['away_win'],
            'dc_expected_home_goals': dc_probs['expected_home_goals'],
            'dc_expected_away_goals': dc_probs['expected_away_goals'],
            'dc_xg_diff': xg_diff,

            # Form features
            'home_form': home_form,
            'away_form': away_form,
            'form_diff': home_form - away_form,

            # Attack/Defense
            'home_attack': home_attack,
            'home_defense': home_defense,
            'away_attack': away_attack,
            'away_defense': away_defense,
            'attack_vs_defense_home': home_attack - away_defense,
            'attack_vs_defense_away': away_attack - home_defense,

            # Home/Away specific
            'home_home_attack': home_home_attack,
            'home_home_defense': home_home_defense,
            'away_away_attack': away_away_attack,
            'away_away_defense': away_away_defense,

            # H2H (capped/normalized to [0, 1])
            'h2h_advantage': h2h_home_advantage,
            'h2h_matches': min(h2h['total'], 10) / 10,

            # Streaks (capped/normalized to [0, 1])
            'home_win_streak': min(home_win_streak, 5) / 5,
            'home_unbeaten_streak': min(home_unbeaten_streak, 10) / 10,
            'away_win_streak': min(away_win_streak, 5) / 5,
            'away_loss_streak': min(away_loss_streak, 5) / 5,

            # Rates
            'home_clean_sheet_rate': home_clean_sheet_rate,
            'home_scoring_rate': home_scoring_rate,
            'away_clean_sheet_rate': away_clean_sheet_rate,
            'away_scoring_rate': away_scoring_rate,

            # Matches played (experience, capped at 50)
            'home_experience': min(home_stats['matches_played'], 50) / 50,
            'away_experience': min(away_stats['matches_played'], 50) / 50,

            # Target
            'result_code': row['result_code'] if 'result_code' in row else None,

            # Metadata
            'date': row['date'],
            'home_team': home,
            'away_team': away,
            'league_code': row.get('league_code', 'UNK')
        }

        features_list.append(feature_row)

        # ---- Post-row updates: only AFTER the features are emitted, so the
        # current match never leaks into its own feature row. ----

        elo_system.update_ratings(home, away, home_goals, away_goals, date=row['date'])

        home_result = 1 if home_goals > away_goals else (0.5 if home_goals == away_goals else 0)
        away_result = 1 - home_result if home_result != 0.5 else 0.5

        team_stats[home]['goals_scored'].append(home_goals)
        team_stats[home]['goals_conceded'].append(away_goals)
        team_stats[home]['results'].append(home_result)
        team_stats[home]['home_goals_scored'].append(home_goals)
        team_stats[home]['home_goals_conceded'].append(away_goals)
        team_stats[home]['matches_played'] += 1
        if away_goals == 0:
            team_stats[home]['clean_sheets'] += 1
        if home_goals == 0:
            team_stats[home]['failed_to_score'] += 1

        team_stats[away]['goals_scored'].append(away_goals)
        team_stats[away]['goals_conceded'].append(home_goals)
        team_stats[away]['results'].append(away_result)
        team_stats[away]['away_goals_scored'].append(away_goals)
        team_stats[away]['away_goals_conceded'].append(home_goals)
        team_stats[away]['matches_played'] += 1
        if home_goals == 0:
            team_stats[away]['clean_sheets'] += 1
        if away_goals == 0:
            team_stats[away]['failed_to_score'] += 1

        # Update head-to-head tallies for this pairing.
        h2h_stats[h2h_key]['total'] += 1
        if home_goals > away_goals:
            h2h_stats[h2h_key]['home_wins'] += 1
        elif away_goals > home_goals:
            h2h_stats[h2h_key]['away_wins'] += 1
        else:
            h2h_stats[h2h_key]['draws'] += 1

    features_df = pd.DataFrame(features_list)
    print(f"  ✓ Generated {len(features_df)} feature rows with {len(features_df.columns)} columns")

    return features_df, elo_system


# =============================================================================
# 5. ENSEMBLE TRAINING
# =============================================================================

def train_ensemble(X_train, y_train, X_val, y_val, feature_names):
    """Train the XGBoost + neural-network ensemble.

    Parameters
    ----------
    X_train, y_train : training features / integer labels
        (0=away win, 1=draw, 2=home win — matching the report labels).
    X_val, y_val : validation split used for early stopping, model selection
        and the confidence-threshold analysis.
    feature_names : column names aligned with the feature matrices.

    Returns
    -------
    (xgb_model, nn_model, scaler, metrics) where metrics holds the three
    validation accuracies.
    """

    print("\n🎯 Training Ensemble Model...")

    # =========================================================================
    # XGBoost Model
    # =========================================================================
    print("\n  📈 Training XGBoost...")

    xgb_model = xgb.XGBClassifier(
        n_estimators=500,
        max_depth=6,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        min_child_weight=3,
        reg_alpha=0.1,
        reg_lambda=1.0,
        objective='multi:softprob',
        num_class=3,
        eval_metric='mlogloss',
        early_stopping_rounds=50,
        random_state=42
    )

    xgb_model.fit(
        X_train, y_train,
        eval_set=[(X_val, y_val)],
        verbose=False
    )

    xgb_val_probs = xgb_model.predict_proba(X_val)
    xgb_val_pred = xgb_model.predict(X_val)
    xgb_accuracy = accuracy_score(y_val, xgb_val_pred)
    print(f"    XGBoost Validation Accuracy: {xgb_accuracy*100:.2f}%")

    # Report the ten most important features.
    importance = xgb_model.feature_importances_
    top_features = sorted(zip(feature_names, importance), key=lambda x: x[1], reverse=True)[:10]
    print("    Top Features:")
    for feat, imp in top_features:
        print(f"      - {feat}: {imp:.4f}")

    # =========================================================================
    # Neural Network Model
    # =========================================================================
    print("\n  🧠 Training Neural Network...")

    # Standardize features for the NN (fit on train only to avoid leakage).
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    train_dataset = TensorDataset(
        torch.FloatTensor(X_train_scaled),
        torch.LongTensor(y_train)
    )
    val_dataset = TensorDataset(
        torch.FloatTensor(X_val_scaled),
        torch.LongTensor(y_val)
    )

    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=256)

    nn_model = AdvancedBettingNet(X_train.shape[1])

    # Inverse-frequency class weights to counter outcome imbalance.
    # minlength=3 keeps the weight vector aligned with the 3 classes even if
    # one is absent from this split; the maximum() guards division by zero.
    class_counts = np.bincount(y_train, minlength=3)
    class_counts = np.maximum(class_counts, 1)
    class_weights = torch.FloatTensor(len(y_train) / (3 * class_counts))

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.AdamW(nn_model.parameters(), lr=0.001, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=10, factor=0.5)

    def _snapshot(model):
        # state_dict() returns references to the LIVE parameter tensors; a
        # shallow dict copy would be silently overwritten by later training
        # steps, so clone every tensor for a true snapshot.
        return {k: v.detach().clone() for k, v in model.state_dict().items()}

    best_val_acc = 0
    patience_counter = 0
    # Initialize so best_state is always bound, even if validation accuracy
    # never improves over 0.
    best_state = _snapshot(nn_model)

    for epoch in range(200):
        # Training pass.
        nn_model.train()
        train_loss = 0
        for batch_x, batch_y in train_loader:
            optimizer.zero_grad()
            outputs = nn_model(batch_x)
            loss = criterion(outputs, batch_y)
            loss.backward()
            # Gradient clipping stabilizes training on noisy targets.
            torch.nn.utils.clip_grad_norm_(nn_model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()

        # Validation pass.
        nn_model.eval()
        val_preds = []
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                outputs = nn_model(batch_x)
                val_preds.extend(outputs.argmax(dim=1).numpy())

        val_acc = accuracy_score(y_val, val_preds)
        scheduler.step(val_acc)

        # Keep the best-performing weights; early-stop after 30 stale epochs.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = _snapshot(nn_model)
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= 30:
            break

        if (epoch + 1) % 20 == 0:
            print(f"    Epoch {epoch+1}: Val Acc = {val_acc*100:.2f}%")

    # Restore the best snapshot before ensembling.
    nn_model.load_state_dict(best_state)
    print(f"    Neural Network Best Validation Accuracy: {best_val_acc*100:.2f}%")

    # =========================================================================
    # Ensemble Predictions
    # =========================================================================
    print("\n  🔗 Creating Ensemble...")

    nn_model.eval()
    with torch.no_grad():
        nn_val_probs = torch.softmax(nn_model(torch.FloatTensor(X_val_scaled)), dim=1).numpy()

    # Weighted ensemble (XGBoost 60%, NN 40%)
    ensemble_probs = 0.6 * xgb_val_probs + 0.4 * nn_val_probs
    ensemble_preds = ensemble_probs.argmax(axis=1)
    ensemble_accuracy = accuracy_score(y_val, ensemble_preds)

    print(f"\n  📊 ENSEMBLE VALIDATION ACCURACY: {ensemble_accuracy*100:.2f}%")

    # Accuracy on progressively more confident subsets of the validation set.
    max_probs = ensemble_probs.max(axis=1)
    for threshold in [0.45, 0.50, 0.55, 0.60]:
        mask = max_probs >= threshold
        if mask.sum() > 0:
            high_conf_acc = accuracy_score(y_val[mask], ensemble_preds[mask])
            print(f"    Confidence >= {threshold*100:.0f}%: {high_conf_acc*100:.1f}% ({mask.sum()} matches)")

    return xgb_model, nn_model, scaler, {
        'xgb_accuracy': xgb_accuracy,
        'nn_accuracy': best_val_acc,
        'ensemble_accuracy': ensemble_accuracy
    }


# =============================================================================
# MAIN TRAINING PIPELINE
# =============================================================================

def main():
    """End-to-end training pipeline: load data, build features, train the
    XGBoost + NN ensemble, evaluate on a chronological hold-out, and persist
    every artifact (models, scaler, ELO ratings, Dixon-Coles parameters).
    """
    print("\n📂 Loading Data...")

    # Load historical data produced by the collector script.
    data_path = os.path.join(DATA_DIR, 'advanced_historical_matches.csv')
    if not os.path.exists(data_path):
        print("❌ Data file not found! Run advanced_data_collector.py first.")
        return

    df = pd.read_csv(data_path, low_memory=False)
    print(f"  Loaded {len(df)} matches")

    # Filter and clean: drop rows missing the essentials, then order by date
    # so downstream feature generation only ever looks backwards in time.
    df = df.dropna(subset=['home_team', 'away_team', 'home_goals', 'away_goals', 'date'])
    df['date'] = pd.to_datetime(df['date'])
    df = df.sort_values('date').reset_index(drop=True)

    # Ensure result_code exists (2 = home win, 1 = draw, 0 = away win; this
    # matches the AWAY/DRAW/HOME label order used in the reports below).
    if 'result_code' not in df.columns:
        df['result_code'] = df.apply(
            lambda x: 2 if x['home_goals'] > x['away_goals'] else (1 if x['home_goals'] == x['away_goals'] else 0),
            axis=1
        )

    print(f"  After cleaning: {len(df)} matches")
    print(f"  Date range: {df['date'].min()} to {df['date'].max()}")

    # Initialize the rating systems used to generate features.
    print("\n🔧 Initializing ELO and Dixon-Coles...")
    elo_system = ELOSystem(k_factor=32, home_advantage=100)
    dixon_coles = DixonColesModel(decay_rate=0.005)

    # Chronological 80/20 split: Dixon-Coles is fitted on the first 80% only,
    # so the hold-out period never leaks into its parameters.
    train_cutoff = int(len(df) * 0.8)
    train_df = df.iloc[:train_cutoff].copy()
    test_df = df.iloc[train_cutoff:].copy()

    print(f"  Training set: {len(train_df)} matches")
    print(f"  Test set: {len(test_df)} matches")

    dixon_coles.fit(train_df)

    # Generate features over the FULL dataset; ELO and form statistics update
    # sequentially, so each row only uses information from earlier matches.
    features_df, elo_system = calculate_advanced_features(df, elo_system, dixon_coles)

    # Split the feature table at the same chronological cutoff.
    # NOTE(review): this assumes calculate_advanced_features emits exactly one
    # row per df row (holds after the dropna above, since its skip branch only
    # triggers on NaN team names) — verify if the cleaning step changes.
    train_features = features_df.iloc[:train_cutoff].copy()
    test_features = features_df.iloc[train_cutoff:].copy()

    # Remove rows with missing target
    train_features = train_features.dropna(subset=['result_code'])
    test_features = test_features.dropna(subset=['result_code'])

    # Feature columns: everything except the target and metadata columns.
    feature_cols = [c for c in train_features.columns if c not in
                   ['result_code', 'date', 'home_team', 'away_team', 'league_code']]

    print(f"\n  Features: {len(feature_cols)}")

    # Prepare numpy matrices for the models.
    X_train = train_features[feature_cols].values
    y_train = train_features['result_code'].astype(int).values
    X_test = test_features[feature_cols].values
    y_test = test_features['result_code'].astype(int).values

    # Handle any remaining NaN/Inf before feeding the models.
    X_train = np.nan_to_num(X_train, nan=0, posinf=10, neginf=-10)
    X_test = np.nan_to_num(X_test, nan=0, posinf=10, neginf=-10)

    # Random stratified train/val split WITHIN the training period; the test
    # period remains strictly chronological.
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
    )

    print(f"  Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")

    # Train ensemble
    xgb_model, nn_model, scaler, metrics = train_ensemble(X_train, y_train, X_val, y_val, feature_cols)

    # =========================================================================
    # FINAL TEST EVALUATION
    # =========================================================================
    print("\n" + "="*70)
    print("📊 FINAL TEST SET EVALUATION")
    print("="*70)

    # XGBoost predictions
    xgb_test_probs = xgb_model.predict_proba(X_test)

    # Neural Network predictions (same scaler fitted during training).
    X_test_scaled = scaler.transform(X_test)
    nn_model.eval()
    with torch.no_grad():
        nn_test_probs = torch.softmax(nn_model(torch.FloatTensor(X_test_scaled)), dim=1).numpy()

    # Ensemble: same 60/40 blend used during validation.
    ensemble_probs = 0.6 * xgb_test_probs + 0.4 * nn_test_probs
    ensemble_preds = ensemble_probs.argmax(axis=1)

    test_accuracy = accuracy_score(y_test, ensemble_preds)
    print(f"\n  🎯 OVERALL TEST ACCURACY: {test_accuracy*100:.2f}%")

    # Confusion matrix (rows = true class, in label order below).
    print("\n  Confusion Matrix:")
    cm = confusion_matrix(y_test, ensemble_preds)
    labels = ['AWAY', 'DRAW', 'HOME']
    for i, label in enumerate(labels):
        print(f"    {label}: {cm[i]}")

    # Classification report
    print("\n  Classification Report:")
    print(classification_report(y_test, ensemble_preds, target_names=labels))

    # Accuracy on progressively more confident prediction subsets.
    print("\n  📈 HIGH CONFIDENCE PREDICTIONS:")
    max_probs = ensemble_probs.max(axis=1)

    for threshold in [0.40, 0.45, 0.50, 0.55, 0.60, 0.65, 0.70]:
        mask = max_probs >= threshold
        if mask.sum() > 10:
            high_conf_acc = accuracy_score(y_test[mask], ensemble_preds[mask])
            pct = mask.sum() / len(y_test) * 100
            print(f"    Confidence >= {threshold*100:.0f}%: {high_conf_acc*100:.1f}% accuracy ({mask.sum()} matches, {pct:.1f}%)")

    # =========================================================================
    # SAVE MODELS
    # =========================================================================
    print("\n💾 Saving Models...")

    # Save XGBoost in its native JSON format.
    xgb_model.save_model(os.path.join(MODELS_DIR, 'xgb_model_v4.json'))

    # Save Neural Network weights (state dict only, not the module).
    torch.save(nn_model.state_dict(), os.path.join(MODELS_DIR, 'nn_model_v4.pth'))

    # Save scaler parameters.
    # NOTE: np.save pickles this dict; loaders must use
    # np.load(..., allow_pickle=True).item() to read it back.
    np.save(os.path.join(MODELS_DIR, 'scaler_v4.npy'), {
        'mean': scaler.mean_,
        'scale': scaler.scale_
    })

    # Save metadata describing this training run.
    meta = {
        'version': 'V4-Advanced',
        'trained_at': datetime.now().isoformat(),
        'feature_columns': feature_cols,
        'n_features': len(feature_cols),
        'training_matches': len(X_train) + len(X_val),
        'test_matches': len(X_test),
        'metrics': {
            'test_accuracy': float(test_accuracy),
            'xgb_accuracy': float(metrics['xgb_accuracy']),
            'nn_accuracy': float(metrics['nn_accuracy']),
            'ensemble_accuracy': float(metrics['ensemble_accuracy'])
        },
        'ensemble_weights': {'xgb': 0.6, 'nn': 0.4}
    }

    with open(os.path.join(MODELS_DIR, 'model_v4_meta.json'), 'w') as f:
        json.dump(meta, f, indent=2)

    # Save final ELO ratings (plain dict so it is JSON-serializable).
    elo_ratings = {team: rating for team, rating in elo_system.ratings.items()}
    with open(os.path.join(MODELS_DIR, 'elo_ratings.json'), 'w') as f:
        json.dump(elo_ratings, f, indent=2)

    # Save Dixon-Coles parameters for use at prediction time.
    dc_params = {
        'attack': dixon_coles.attack,
        'defense': dixon_coles.defense,
        'home_adv': float(dixon_coles.home_adv),
        'rho': float(dixon_coles.rho)
    }
    with open(os.path.join(MODELS_DIR, 'dixon_coles_params.json'), 'w') as f:
        json.dump(dc_params, f, indent=2)

    print("\n" + "="*70)
    print("✅ TRAINING COMPLETE!")
    print("="*70)
    print(f"  Model: V4-Advanced (ELO + Dixon-Coles + XGBoost + NN Ensemble)")
    print(f"  Test Accuracy: {test_accuracy*100:.2f}%")
    print(f"  High Confidence (>50%): Check results above")
    print("="*70)


# Script entry point: run the full training pipeline when executed directly.
if __name__ == '__main__':
    main()
