#!/usr/bin/env python3
"""
BetPredictAI - Advanced Model Training V3

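Uses up to 51 ML features: 44 columns from the advanced data collection plus
7 features derived in this script (columns missing from the CSV are skipped):
- Rolling team statistics (form, goals, streaks)
- Home/Away specific performance
- Head-to-head records
- Betting odds implied probabilities
- Form differentials
- Derived features (form momentum, goal expectancy, defensive strength,
  H2H dominance, odds value)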

Target: Improve accuracy beyond 62%
"""

import os
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from datetime import datetime

# Configuration: paths are resolved relative to this script's directory, so the
# script behaves the same regardless of the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR, MODELS_DIR = (
    os.path.join(BASE_DIR, os.pardir, subdir) for subdir in ('data', 'models')
)

# Device: prefer CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# ML Features to use (67 features)
ML_FEATURES = [
    # Home team form (14 features)
    'home_form_ppg', 'home_form_gpg', 'home_form_gapg', 'home_form_win_rate',
    'home_home_ppg', 'home_home_gpg', 'home_away_ppg', 'home_away_gpg',
    'home_form_wins', 'home_form_draws', 'home_form_losses',
    'home_streak', 'home_clean_sheets', 'home_failed_to_score',

    # Away team form (14 features)
    'away_form_ppg', 'away_form_gpg', 'away_form_gapg', 'away_form_win_rate',
    'away_home_ppg', 'away_home_gpg', 'away_away_ppg', 'away_away_gpg',
    'away_form_wins', 'away_form_draws', 'away_form_losses',
    'away_streak', 'away_clean_sheets', 'away_failed_to_score',

    # Head-to-head (7 features)
    'h2h_matches', 'h2h_home_wins', 'h2h_draws', 'h2h_away_wins',
    'h2h_home_goals', 'h2h_away_goals', 'h2h_total_goals',

    # Betting odds implied probabilities (3 features)
    'impl_home', 'impl_draw', 'impl_away',

    # Form differentials (5 features)
    'ppg_diff', 'attack_diff', 'defense_diff', 'win_rate_diff', 'home_home_vs_away_away',

    # Streak differential (1 feature)
    'streak_diff',
]

# Additional derived features we'll create
DERIVED_FEATURES = [
    'form_momentum_home', 'form_momentum_away',
    'goal_expectancy', 'defensive_strength_diff',
    'h2h_dominance', 'odds_value_home', 'odds_value_away',
]


class AdvancedBettingModel(nn.Module):
    """
    Feed-forward classifier producing 3 logits (Away, Draw, Home).

    Architecture:
    - A stack of fully connected hidden layers, each followed by batch
      normalization, LeakyReLU and dropout (the last hidden layer uses half
      the dropout rate)
    - A final linear classifier with 3 outputs
    - Kaiming-normal initialization matched to the LeakyReLU slope

    Layers are applied strictly in sequence; there are no residual/skip
    connections.

    Args:
        input_size: Number of input features per sample.
        hidden_sizes: Widths of the hidden layers, in order. An empty sequence
            connects the input directly to the classifier.
        dropout: Dropout probability for every hidden layer except the last,
            which uses ``dropout * 0.5``.
        negative_slope: LeakyReLU negative slope; also passed as ``a`` to the
            Kaiming initialization so the init gain matches the activation.
    """

    def __init__(self, input_size, hidden_sizes=(512, 256, 256, 128, 64), dropout=0.3,
                 negative_slope=0.1):
        super().__init__()

        hidden_sizes = list(hidden_sizes)
        self.negative_slope = negative_slope

        layers = []
        prev_size = input_size
        for i, hidden_size in enumerate(hidden_sizes):
            is_last_hidden = i == len(hidden_sizes) - 1
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.LeakyReLU(negative_slope),
                # Lighter regularization right before the classifier.
                nn.Dropout(dropout * 0.5 if is_last_hidden else dropout),
            ])
            prev_size = hidden_size

        self.feature_extractor = nn.Sequential(*layers)
        self.classifier = nn.Linear(prev_size, 3)  # 3 classes: Away, Draw, Home

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal Linear weights, zero biases, identity BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # `a` is the negative slope of the LeakyReLU that follows; if it
                # is omitted PyTorch computes the gain for slope 0 (plain ReLU).
                nn.init.kaiming_normal_(m.weight, a=self.negative_slope, mode='fan_out',
                                        nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a (batch, input_size) tensor to (batch, 3) unnormalized logits."""
        features = self.feature_extractor(x)
        return self.classifier(features)


def _add_derived_features(df):
    """Add engineered feature columns to ``df`` in place.

    Each feature is created only when every column it reads is present, so a
    dataset with missing source columns skips that feature instead of raising
    ``KeyError``.

    Returns:
        list[str]: Names of the columns that were added, in creation order.
    """
    def has(*cols):
        return all(col in df.columns for col in cols)

    added = []

    # Form momentum: (wins - losses) / 5 over the recent form window.
    # NOTE(review): the divisor presumably equals the form window length - confirm.
    if has('home_form_wins', 'home_form_losses', 'away_form_wins', 'away_form_losses'):
        df['form_momentum_home'] = (df['home_form_wins'] - df['home_form_losses']) / 5
        df['form_momentum_away'] = (df['away_form_wins'] - df['away_form_losses']) / 5
        added.extend(['form_momentum_home', 'form_momentum_away'])

    # Goal expectancy: mean of both teams' goals per game.
    if has('home_form_gpg', 'away_form_gpg'):
        df['goal_expectancy'] = (df['home_form_gpg'] + df['away_form_gpg']) / 2
        added.append('goal_expectancy')

    # Defensive strength differential: positive when the away side concedes more.
    if has('home_form_gapg', 'away_form_gapg'):
        df['defensive_strength_diff'] = df['away_form_gapg'] - df['home_form_gapg']
        added.append('defensive_strength_diff')

    # H2H dominance score: H2H home wins minus H2H away wins.
    if has('h2h_home_wins', 'h2h_away_wins'):
        df['h2h_dominance'] = df['h2h_home_wins'] - df['h2h_away_wins']
        added.append('h2h_dominance')

    # Odds "value": recent form win rate minus the bookmaker-implied
    # probability (missing implied probabilities default to 0.33).
    if has('impl_home', 'impl_away', 'home_form_win_rate', 'away_form_win_rate'):
        df['odds_value_home'] = df['home_form_win_rate'] - df['impl_home'].fillna(0.33)
        df['odds_value_away'] = df['away_form_win_rate'] - df['impl_away'].fillna(0.33)
        added.extend(['odds_value_home', 'odds_value_away'])

    return added


def load_and_prepare_data():
    """Load the advanced match CSV, add derived features and clean missing values.

    Reads ``advanced_historical_matches.csv`` from ``DATA_DIR``, keeps the
    ``ML_FEATURES`` columns that exist, appends derived features, drops rows
    lacking ``result_code`` or any of the first 20 available features, and
    median-fills the remaining gaps.

    Returns:
        tuple[pd.DataFrame, list[str]]: the cleaned rows and the ordered list
        of feature columns to train on.
    """
    print("\n📊 Loading advanced dataset...")

    df = pd.read_csv(os.path.join(DATA_DIR, 'advanced_historical_matches.csv'))
    print(f"  Loaded {len(df)} matches")

    # Check available features
    available_features = [f for f in ML_FEATURES if f in df.columns]
    missing_features = [f for f in ML_FEATURES if f not in df.columns]

    print(f"  Available features: {len(available_features)}/{len(ML_FEATURES)}")
    if missing_features:
        print(f"  Missing features: {missing_features[:5]}...")

    # Create derived features
    print("\n🔧 Creating derived features...")
    available_features.extend(_add_derived_features(df))
    print(f"  Total features after derivation: {len(available_features)}")

    # Keep rows with a label and complete values for the first 20 ("key") features.
    df_clean = df.dropna(subset=['result_code'] + available_features[:20]).copy()
    print(f"  Matches with complete data: {len(df_clean)}")

    # Fill remaining NaNs with each column's median. A column that is entirely
    # NaN has no median, so fall back to 0 rather than passing NaN to training.
    # NOTE(review): medians are computed over all rows, including those that
    # later form the validation/test splits.
    if available_features:
        medians = df_clean[available_features].median().fillna(0.0)
        df_clean[available_features] = df_clean[available_features].fillna(medians)

    return df_clean, available_features


def create_datasets(df, features, test_size=0.15, val_size=0.15):
    """Build chronological train/val/test ``TensorDataset``s.

    Rows are sorted by ``date``. The last ``test_size`` fraction becomes the
    test set, and the ``val_size`` fraction just before it becomes the
    validation set. The ``StandardScaler`` is fitted on the training rows only
    and then applied to every split, so statistics from the later validation
    and test periods never influence the training inputs.

    Args:
        df: Prepared match rows containing ``date``, ``result_code`` and every
            column in ``features``.
        features: Ordered feature column names.
        test_size: Fraction of rows used for the test split.
        val_size: Fraction of rows used for the validation split.

    Returns:
        (train_dataset, val_dataset, test_dataset, scaler, test_df) where
        ``test_df`` is the date-sorted slice of ``df`` backing the test set.
    """
    print("\n📦 Creating datasets...")

    # Sort by date to ensure a temporal split.
    # NOTE(review): 'date' is sorted as stored; if it is a string it must be
    # in a lexicographically sortable format (e.g. ISO-8601) - confirm.
    df = df.sort_values('date').reset_index(drop=True)

    # Extract features and target
    X = df[features].values
    y = df['result_code'].values.astype(int)  # 0=Away, 1=Draw, 2=Home

    # Temporal split (last test_size for test, preceding val_size for validation)
    n = len(X)
    test_idx = int(n * (1 - test_size))
    val_idx = int(n * (1 - test_size - val_size))

    # Fit the scaler on the training period only to avoid look-ahead leakage.
    scaler = StandardScaler()
    scaler.fit(X[:val_idx])

    def scale(rows):
        # sklearn rejects zero-row input, so empty splits pass through as-is.
        return scaler.transform(rows) if len(rows) else np.asarray(rows, dtype=float)

    X_train, y_train = scale(X[:val_idx]), y[:val_idx]
    X_val, y_val = scale(X[val_idx:test_idx]), y[val_idx:test_idx]
    X_test, y_test = scale(X[test_idx:]), y[test_idx:]

    print(f"  Train: {len(X_train)} | Val: {len(X_val)} | Test: {len(X_test)}")

    # Class distribution
    for name, labels in [('Train', y_train), ('Val', y_val), ('Test', y_test)]:
        dist = np.bincount(labels, minlength=3)
        print(f"  {name} distribution: Away={dist[0]}, Draw={dist[1]}, Home={dist[2]}")

    # Convert to tensors
    train_dataset = TensorDataset(
        torch.FloatTensor(X_train),
        torch.LongTensor(y_train)
    )
    val_dataset = TensorDataset(
        torch.FloatTensor(X_val),
        torch.LongTensor(y_val)
    )
    test_dataset = TensorDataset(
        torch.FloatTensor(X_test),
        torch.LongTensor(y_test)
    )

    return train_dataset, val_dataset, test_dataset, scaler, df.iloc[test_idx:]


def _validate(model, loader, dev):
    """Evaluate ``model`` on ``loader``; return (accuracy %, per-class accuracy %).

    Per-class accuracies are ordered by label index (0=Away, 1=Draw, 2=Home)
    and are 0 for classes absent from the loader. Overall accuracy is 0.0 when
    the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    class_correct = [0, 0, 0]
    class_total = [0, 0, 0]

    with torch.no_grad():
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(dev), y_batch.to(dev)
            _, predicted = model(X_batch).max(1)

            total += y_batch.size(0)
            correct += predicted.eq(y_batch).sum().item()

            for i in range(3):
                mask = y_batch == i
                class_total[i] += mask.sum().item()
                class_correct[i] += (predicted[mask] == i).sum().item()

    acc = 100 * correct / total if total else 0.0
    class_acc = [100 * c / t if t > 0 else 0 for c, t in zip(class_correct, class_total)]
    return acc, class_acc


def train_model(model, train_loader, val_loader, epochs=150, patience=20):
    """Train ``model`` with early stopping on validation accuracy.

    Uses AdamW (lr=1e-3, weight_decay=0.01), gradient-norm clipping at 1.0 and
    a ReduceLROnPlateau scheduler that halves the learning rate after 10 epochs
    without validation improvement. Training stops after ``patience``
    consecutive epochs without a new best validation accuracy. The weights from
    the best epoch are restored before returning.

    Args:
        model: Network to train; it runs on the device its parameters live on.
        train_loader: Yields (features, labels) training batches.
        val_loader: Yields (features, labels) validation batches.
        epochs: Maximum number of epochs.
        patience: Epochs without improvement before stopping early.

    Returns:
        (model, best_val_acc, history):
        - the model with its best weights loaded;
        - the best validation accuracy in percent (0 if no epoch ran);
        - one dict per epoch with 'epoch', 'train_loss', 'train_acc',
          'val_acc' and 'class_acc'.
    """
    print("\n🚀 Training model...")

    dev = next(model.parameters()).device

    # Plain cross-entropy; no class re-weighting is applied.
    criterion = nn.CrossEntropyLoss()

    # Optimizer with weight decay
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

    # mode='max' because the monitored metric is validation accuracy.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=10
    )

    best_val_acc = 0
    best_model_state = None
    patience_counter = 0
    history = []

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_batches = 0
        train_correct = 0
        train_total = 0

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(dev), y_batch.to(dev)

            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            train_loss += loss.item()
            train_batches += 1
            _, predicted = outputs.max(1)
            train_total += y_batch.size(0)
            train_correct += predicted.eq(y_batch).sum().item()

        train_acc = 100 * train_correct / train_total if train_total else 0.0

        # Validation phase
        val_acc, class_acc = _validate(model, val_loader, dev)

        # Update scheduler
        scheduler.step(val_acc)

        # Save history
        history.append({
            'epoch': epoch + 1,
            'train_loss': train_loss / train_batches if train_batches else 0.0,
            'train_acc': train_acc,
            'val_acc': val_acc,
            'class_acc': class_acc
        })

        # Early stopping check. The first epoch always seeds the checkpoint,
        # so a best state exists even if validation accuracy never exceeds 0.
        if best_model_state is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns references to the live tensors; clone them,
            # otherwise later optimizer steps overwrite the "best" snapshot.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_counter = 0

            print(f"  Epoch {epoch+1:3d}: Train={train_acc:.1f}% | Val={val_acc:.1f}% ⭐ | "
                  f"Away={class_acc[0]:.0f}% Draw={class_acc[1]:.0f}% Home={class_acc[2]:.0f}%")
        else:
            patience_counter += 1
            if (epoch + 1) % 10 == 0:
                print(f"  Epoch {epoch+1:3d}: Train={train_acc:.1f}% | Val={val_acc:.1f}%")

        if patience_counter >= patience:
            print(f"\n  Early stopping at epoch {epoch+1}")
            break

    # Load best model (nothing to restore if no epoch ran)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, best_val_acc, history


def evaluate_model(model, test_loader, test_df, confidence_threshold=0.5):
    """Evaluate ``model`` on the test split and print a short report.

    Args:
        model: Trained network; it runs on the device its parameters live on.
        test_loader: Yields (features, labels) batches.
        test_df: Rows backing the test set. Not used by this function; kept
            for interface compatibility.
        confidence_threshold: Predictions whose top softmax probability
            exceeds this value are also reported as "high confidence".

    Returns:
        (accuracy, class_acc, all_preds, all_labels, all_probs):
        - overall accuracy in percent (0.0 for an empty loader);
        - per-class accuracy for Away/Draw/Home (0.0 for absent classes);
        - predicted labels, true labels and per-class softmax probabilities,
          as NumPy arrays.
    """
    print("\n📈 Evaluating on test set...")

    dev = next(model.parameters()).device
    model.eval()
    pred_chunks, label_chunks, prob_chunks = [], [], []

    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            outputs = model(X_batch.to(dev))
            _, predicted = outputs.max(1)
            pred_chunks.append(predicted.cpu().numpy())
            label_chunks.append(y_batch.cpu().numpy())
            prob_chunks.append(torch.softmax(outputs, dim=1).cpu().numpy())

    # Explicit empty arrays keep shapes valid when the loader yields nothing.
    all_preds = np.concatenate(pred_chunks) if pred_chunks else np.empty(0, dtype=np.int64)
    all_labels = np.concatenate(label_chunks) if label_chunks else np.empty(0, dtype=np.int64)
    all_probs = np.concatenate(prob_chunks) if prob_chunks else np.empty((0, 3), dtype=np.float32)

    # Overall accuracy
    n_samples = len(all_labels)
    accuracy = float(100 * (all_preds == all_labels).sum() / n_samples) if n_samples else 0.0

    # Per-class accuracy
    class_names = ['Away', 'Draw', 'Home']
    class_acc = []
    for i, name in enumerate(class_names):
        mask = all_labels == i
        count = int(mask.sum())
        if count:
            acc = float(100 * (all_preds[mask] == i).sum() / count)
            print(f"  {name}: {acc:.1f}% ({count} matches)")
        else:
            acc = 0.0
        class_acc.append(acc)

    print(f"\n  ✅ Overall Test Accuracy: {accuracy:.2f}%")

    # Confidence analysis
    confidences = all_probs.max(axis=1)
    high_conf_mask = confidences > confidence_threshold
    n_high = int(high_conf_mask.sum())
    if n_high:
        high_conf_acc = 100 * (all_preds[high_conf_mask] == all_labels[high_conf_mask]).sum() / n_high
        print(f"  High confidence (>{confidence_threshold:.0%}) accuracy: "
              f"{high_conf_acc:.1f}% ({n_high} matches)")

    return accuracy, class_acc, all_preds, all_labels, all_probs


def save_model(model, scaler, features, val_acc, class_acc, history):
    """Save the trained weights, scaler statistics and metadata to MODELS_DIR.

    Files written (the directory is created if needed):
    - ``betting_model_v3.pth``: ``model.state_dict()`` via ``torch.save``.
    - ``scaler_v3.npy``: a dict with the scaler's ``mean`` and ``scale``
      arrays, stored as a pickled object array (load with
      ``np.load(path, allow_pickle=True).item()``).
    - ``betting_model_v3_meta.json``: training metadata, including the last
      20 history entries.

    Args:
        model: Trained network.
        scaler: Fitted scaler exposing ``mean_`` and ``scale_``.
        features: Ordered feature column names the model expects.
        val_acc: Best validation accuracy in percent.
        class_acc: Per-class accuracies, stored under ``class_accuracy``.
        history: Per-epoch training history.

    Returns:
        str: Path of the saved weights file.
    """
    print("\n💾 Saving model...")

    os.makedirs(MODELS_DIR, exist_ok=True)

    # Save model
    model_path = os.path.join(MODELS_DIR, 'betting_model_v3.pth')
    torch.save(model.state_dict(), model_path)
    print(f"  Model saved to {model_path}")

    # Save scaler
    scaler_path = os.path.join(MODELS_DIR, 'scaler_v3.npy')
    np.save(scaler_path, {
        'mean': scaler.mean_,
        'scale': scaler.scale_
    })

    # Describe the hidden widths of the model actually saved, i.e. every
    # Linear layer except the last one registered (the output classifier in
    # AdvancedBettingModel), so the metadata cannot drift from the weights.
    linear_widths = [m.out_features for m in model.modules() if isinstance(m, nn.Linear)]
    architecture = '-'.join(str(width) for width in linear_widths[:-1])

    # Save metadata
    meta = {
        'trained_at': datetime.now().isoformat(),
        'best_val_acc': val_acc,
        'class_accuracy': class_acc,
        'feature_columns': features,
        'model_version': 'v3',
        'architecture': architecture,
        'total_features': len(features),
        'history': history[-20:]  # Last 20 epochs
    }

    meta_path = os.path.join(MODELS_DIR, 'betting_model_v3_meta.json')
    with open(meta_path, 'w', encoding='utf-8') as f:
        json.dump(meta, f, indent=2)
    print(f"  Metadata saved to {meta_path}")

    return model_path


def main():
    """Run the V3 pipeline end to end.

    Loads the data, builds the chronological splits, trains with early
    stopping, evaluates on the test split and saves the artifacts.

    Returns:
        tuple: (trained model, test accuracy in percent).
    """
    print("="*60)
    print("🧠 BetPredictAI - Advanced Model Training V3")
    print("="*60)

    # Load data
    df, features = load_and_prepare_data()

    # Create datasets
    train_ds, val_ds, test_ds, scaler, test_df = create_datasets(df, features)

    # Create data loaders. BatchNorm1d raises on a single-sample batch in
    # training mode, so drop the last shuffled batch only when it would hold
    # exactly one sample.
    train_batch_size = 64
    train_loader = DataLoader(
        train_ds,
        batch_size=train_batch_size,
        shuffle=True,
        drop_last=len(train_ds) % train_batch_size == 1,
    )
    val_loader = DataLoader(val_ds, batch_size=128)
    test_loader = DataLoader(test_ds, batch_size=128)

    # Create model (hidden widths defined once and reused for the printout)
    hidden_sizes = [512, 256, 256, 128, 64]
    model = AdvancedBettingModel(
        input_size=len(features),
        hidden_sizes=hidden_sizes,
        dropout=0.3
    ).to(device)

    print("\n🏗️ Model architecture:")
    print(f"  Input: {len(features)} features")
    print(f"  Hidden: {' → '.join(str(size) for size in hidden_sizes)}")
    print("  Output: 3 classes (Away/Draw/Home)")
    total_params = sum(p.numel() for p in model.parameters())
    print(f"  Total parameters: {total_params:,}")

    # Train
    model, val_acc, history = train_model(model, train_loader, val_loader, epochs=150, patience=25)

    # Evaluate (raw predictions/labels/probabilities are not used further here)
    test_acc, class_acc, _preds, _labels, _probs = evaluate_model(model, test_loader, test_df)

    # Save (class_acc here is the per-class accuracy on the TEST split)
    save_model(model, scaler, features, val_acc, class_acc, history)

    print("\n" + "="*60)
    print("✅ TRAINING COMPLETE")
    print("="*60)
    print(f"  Best Validation Accuracy: {val_acc:.2f}%")
    print(f"  Test Accuracy: {test_acc:.2f}%")
    print(f"  Features used: {len(features)}")
    print("="*60)

    return model, test_acc


if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not on import.
    main()
