#!/usr/bin/env python3
"""
Generate verification results using the V4 ensemble model.
Downloads recent results from football-data.co.uk and verifies predictions.
"""

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import json
import os
import requests
import io
import xgboost as xgb
from datetime import datetime, timedelta

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, '..', 'models')
DATA_DIR = os.path.join(BASE_DIR, '..', 'data')
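
# Path layout: 'models/' (trained artifacts) and 'data/' (JSON output) are
# expected one level above this script's directory.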

# Football-data.co.uk URLs for current season
LEAGUE_URLS = {
    'E0': 'https://www.football-data.co.uk/mmz4281/2526/E0.csv',  # Premier League
    'E1': 'https://www.football-data.co.uk/mmz4281/2526/E1.csv',  # Championship
    'D1': 'https://www.football-data.co.uk/mmz4281/2526/D1.csv',  # Bundesliga
    'D2': 'https://www.football-data.co.uk/mmz4281/2526/D2.csv',  # 2. Bundesliga
    'I1': 'https://www.football-data.co.uk/mmz4281/2526/I1.csv',  # Serie A
    'I2': 'https://www.football-data.co.uk/mmz4281/2526/I2.csv',  # Serie B
    'SP1': 'https://www.football-data.co.uk/mmz4281/2526/SP1.csv',  # La Liga
    'SP2': 'https://www.football-data.co.uk/mmz4281/2526/SP2.csv',  # La Liga 2
    'F1': 'https://www.football-data.co.uk/mmz4281/2526/F1.csv',  # Ligue 1
    'F2': 'https://www.football-data.co.uk/mmz4281/2526/F2.csv',  # Ligue 2
    'N1': 'https://www.football-data.co.uk/mmz4281/2526/N1.csv',  # Eredivisie
    'P1': 'https://www.football-data.co.uk/mmz4281/2526/P1.csv',  # Primeira Liga
    'B1': 'https://www.football-data.co.uk/mmz4281/2526/B1.csv',  # Pro League
    'T1': 'https://www.football-data.co.uk/mmz4281/2526/T1.csv',  # Super Lig
    'G1': 'https://www.football-data.co.uk/mmz4281/2526/G1.csv',  # Super League Greece
    'SC0': 'https://www.football-data.co.uk/mmz4281/2526/SC0.csv',  # Scottish Premier
}
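
# Note: the '2526' path segment encodes the 2025/26 season on
# football-data.co.uk; these URLs will need updating when a new season starts.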


def fetch_recent_results(days=30):
    """Download recent match results from football-data.co.uk"""
    print(f"📥 Downloading recent results (last {days} days)...")

    cutoff_date = datetime.now() - timedelta(days=days)
    all_matches = []

    for league_code, url in LEAGUE_URLS.items():
        try:
            response = requests.get(url, timeout=15)
            if response.status_code != 200:
                print(f"  ⚠️ {league_code}: HTTP {response.status_code}")
                continue

            df = pd.read_csv(io.StringIO(response.text))

            # Parse date - football-data uses DD/MM/YYYY
            df['Date'] = pd.to_datetime(df['Date'], format='%d/%m/%Y', errors='coerce')
            df = df.dropna(subset=['Date'])

            # Filter to last N days
            df = df[df['Date'] >= cutoff_date]

            # Only completed matches (have goals)
            df = df.dropna(subset=['FTHG', 'FTAG'])

            if len(df) > 0:
                df['league_code'] = league_code
                all_matches.append(df)
                print(f"  ✓ {LEAGUE_NAMES.get(league_code, league_code)}: {len(df)} partite")

        except Exception as e:
            print(f"  ❌ {league_code}: {e}")

    if not all_matches:
        print("  ❌ Nessun dato scaricato!")
        return None

    combined = pd.concat(all_matches, ignore_index=True)
    print(f"  📊 Totale: {len(combined)} partite negli ultimi {days} giorni")
    return combined
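
# Note: fetch_recent_results returns football-data.co.uk's native columns
# (Date, HomeTeam, AwayTeam, FTHG, FTAG, ...) plus an added 'league_code'.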


class AdvancedBettingNet(nn.Module):
    """Same architecture as training"""
    def __init__(self, input_size, hidden_sizes=[256, 256, 128, 64], dropout=0.3):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_size, input_size // 2),
            nn.ReLU(),
            nn.Linear(input_size // 2, input_size),
            nn.Sigmoid()
        )
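        # Trunk: per hidden layer, Linear -> BatchNorm1d -> LeakyReLU -> Dropout
        # (the final layer gets half the dropout rate).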
        layers = []
        prev_size = input_size
        for i, hidden_size in enumerate(hidden_sizes):
            layers.append(nn.Linear(prev_size, hidden_size))
            layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.LeakyReLU(0.1))
            layers.append(nn.Dropout(dropout if i < len(hidden_sizes) - 1 else dropout * 0.5))
            prev_size = hidden_size
        self.feature_extractor = nn.Sequential(*layers)
        self.classifier = nn.Linear(prev_size, 3)

    def forward(self, x):
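        # Gate each input feature with a learned weight in (0, 1),
        # squeeze-and-excitation style, before the feed-forward trunk.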
        attention_weights = self.attention(x)
        x = x * attention_weights
        features = self.feature_extractor(x)
        return self.classifier(features)


RESULT_MAP = {0: 'AWAY', 1: 'DRAW', 2: 'HOME'}
LEAGUE_NAMES = {
    'E0': 'Premier League',
    'E1': 'Championship',
    'D1': 'Bundesliga',
    'D2': '2. Bundesliga',
    'I1': 'Serie A',
    'I2': 'Serie B',
    'SP1': 'La Liga',
    'SP2': 'La Liga 2',
    'F1': 'Ligue 1',
    'F2': 'Ligue 2',
    'N1': 'Eredivisie',
    'P1': 'Primeira Liga',
    'B1': 'Pro League',
    'T1': 'Super Lig',
    'G1': 'Super League',
    'SC0': 'Scottish Premier',
}
CORE_LEAGUES = list(LEAGUE_NAMES)


def load_v4_models():
    """Load V4 ensemble models"""
    print("📂 Loading V4 Models...")

    # Load metadata
    meta_path = os.path.join(MODELS_DIR, 'model_v4_meta.json')
    if not os.path.exists(meta_path):
        print("❌ V4 model metadata not found!")
        return None

    with open(meta_path) as f:
        meta = json.load(f)

    features = meta['feature_columns']
    print(f"  Features: {len(features)}")

    # Load XGBoost
    xgb_model = xgb.XGBClassifier()
    xgb_model.load_model(os.path.join(MODELS_DIR, 'xgb_model_v4.json'))
    print("  ✓ XGBoost loaded")

    # Load Neural Network
    nn_model = AdvancedBettingNet(len(features))
    nn_model.load_state_dict(torch.load(
        os.path.join(MODELS_DIR, 'nn_model_v4.pth'),
        map_location='cpu',
        weights_only=True
    ))
    nn_model.eval()
    print("  ✓ Neural Network loaded")

    # Load scaler
    scaler_data = np.load(
        os.path.join(MODELS_DIR, 'scaler_v4.npy'),
        allow_pickle=True
    ).item()
    print("  ✓ Scaler loaded")

    # Load ELO ratings
    with open(os.path.join(MODELS_DIR, 'elo_ratings.json')) as f:
        elo_ratings = json.load(f)
    print(f"  ✓ ELO ratings loaded ({len(elo_ratings)} teams)")

    return {
        'meta': meta,
        'features': features,
        'xgb': xgb_model,
        'nn': nn_model,
        'scaler': scaler_data,
        'elo_ratings': elo_ratings
    }


def generate_verification(days=30):
    """Generate verification results from REAL recent matches (downloaded from football-data.co.uk)"""

    print(f"\n🔄 Generating V4 Verification Results (ultimi {days} giorni)...")

    models = load_v4_models()
    if not models:
        return None

    features = models['features']
    xgb_model = models['xgb']
    nn_model = models['nn']
    scaler = models['scaler']

    # Download REAL recent results from football-data.co.uk
    test_df = fetch_recent_results(days=days)
    if test_df is None or len(test_df) == 0:
        print("  ❌ Nessuna partita recente disponibile!")
        return None

    print(f"  ✓ Verification matches: {len(test_df)}")

    # We need to calculate features for these matches
    # Since we don't have all the historical context, we'll use simplified features
    # In production, this would use the full feature calculation pipeline

    # Process matches from football-data.co.uk format
    # Columns: Date, HomeTeam, AwayTeam, FTHG, FTAG, etc.

    results = []
    elo_ratings = models['elo_ratings']

    for _, row in test_df.iterrows():
        # Use football-data.co.uk column names
        if pd.isna(row.get('HomeTeam')) or pd.isna(row.get('AwayTeam')) or pd.isna(row.get('Date')):
            continue

        home = str(row['HomeTeam'])
        away = str(row['AwayTeam'])

        # Get ELO ratings
        home_elo = elo_ratings.get(home, 1500)
        away_elo = elo_ratings.get(away, 1500)
        home_elo_adj = home_elo + 100  # Home advantage

        # Standard ELO expected score for the home side:
        # E_home = 1 / (1 + 10^((R_away - R_home_adj) / 400))
        elo_expected = 1 / (1 + 10 ** ((away_elo - home_elo_adj) / 400))
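        # Worked example: home_elo_adj=1600 vs away_elo=1500 gives
        # 1 / (1 + 10**(-100/400)) ≈ 0.64 expected score for the home side.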

        # Build feature vector with available data
        feature_row = {}

        # ELO features
        feature_row['elo_home'] = home_elo_adj
        feature_row['elo_away'] = away_elo
        feature_row['elo_diff'] = home_elo_adj - away_elo
        feature_row['elo_expected_home'] = elo_expected

        # Dixon-Coles features (approximate from ELO)
        feature_row['dc_home_prob'] = elo_expected * 0.9
        feature_row['dc_draw_prob'] = 0.25
        feature_row['dc_away_prob'] = (1 - elo_expected) * 0.9
        feature_row['dc_expected_home_goals'] = 1.3 + (home_elo - 1500) / 500
        feature_row['dc_expected_away_goals'] = 1.1 + (away_elo - 1500) / 500
        feature_row['dc_xg_diff'] = feature_row['dc_expected_home_goals'] - feature_row['dc_expected_away_goals']
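        # These are rough ELO-derived stand-ins, not fitted Dixon-Coles outputs:
        # e.g. a 1600-rated home side gets 1.3 + 100/500 = 1.5 expected goals.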

        # Form features (football-data.co.uk CSVs don't carry these columns,
        # so the defaults below normally apply)
        feature_row['home_form'] = row.get('home_form_5m', 0.5)
        feature_row['away_form'] = row.get('away_form_5m', 0.5)
        feature_row['form_diff'] = feature_row['home_form'] - feature_row['away_form']

        # Attack/Defense
        feature_row['home_attack'] = row.get('home_goals_scored_avg', 1.3)
        feature_row['home_defense'] = row.get('home_goals_conceded_avg', 1.1)
        feature_row['away_attack'] = row.get('away_goals_scored_avg', 1.1)
        feature_row['away_defense'] = row.get('away_goals_conceded_avg', 1.3)
        feature_row['attack_vs_defense_home'] = feature_row['home_attack'] - feature_row['away_defense']
        feature_row['attack_vs_defense_away'] = feature_row['away_attack'] - feature_row['home_defense']

        # Home/Away specific
        feature_row['home_home_attack'] = row.get('home_home_goals_scored_avg', 1.4)
        feature_row['home_home_defense'] = row.get('home_home_goals_conceded_avg', 1.0)
        feature_row['away_away_attack'] = row.get('away_away_goals_scored_avg', 1.0)
        feature_row['away_away_defense'] = row.get('away_away_goals_conceded_avg', 1.4)

        # H2H
        feature_row['h2h_advantage'] = row.get('h2h_home_win_rate', 0.5) - row.get('h2h_away_win_rate', 0.3)
        feature_row['h2h_matches'] = min(row.get('h2h_matches', 5), 10) / 10

        # Streaks
        feature_row['home_win_streak'] = min(row.get('home_streak_wins', 0), 5) / 5
        feature_row['home_unbeaten_streak'] = min(row.get('home_streak_unbeaten', 0), 10) / 10
        feature_row['away_win_streak'] = min(row.get('away_streak_wins', 0), 5) / 5
        feature_row['away_loss_streak'] = min(row.get('away_streak_losses', 0), 5) / 5
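        # Streaks are capped and rescaled to [0, 1], e.g. a 3-match win
        # streak becomes min(3, 5) / 5 = 0.6.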

        # Rates
        feature_row['home_clean_sheet_rate'] = row.get('home_clean_sheet_rate', 0.3)
        feature_row['home_scoring_rate'] = row.get('home_scoring_rate', 0.75)
        feature_row['away_clean_sheet_rate'] = row.get('away_clean_sheet_rate', 0.2)
        feature_row['away_scoring_rate'] = row.get('away_scoring_rate', 0.7)

        # Experience (constant placeholder; no squad data available here)
        feature_row['home_experience'] = 0.8
        feature_row['away_experience'] = 0.8

        # Create feature vector in correct order
        X = np.array([[feature_row.get(f, 0) for f in features]])
        X = np.nan_to_num(X, nan=0, posinf=10, neginf=-10)

        # Standardize for the neural network using the training-time mean/scale
        X_scaled = (X - scaler['mean']) / scaler['scale']

        # XGBoost prediction (uses the raw, unscaled feature vector)
        xgb_probs = xgb_model.predict_proba(X)

        # Neural Network prediction
        with torch.no_grad():
            nn_probs = torch.softmax(nn_model(torch.FloatTensor(X_scaled)), dim=1).numpy()

        # Ensemble (60% XGB, 40% NN)
        probs = 0.6 * xgb_probs + 0.4 * nn_probs
        pred = int(probs[0].argmax())
        confidence = float(probs[0].max() * 100)
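        # Worked example: xgb=[0.20, 0.30, 0.50], nn=[0.30, 0.30, 0.40]
        # blends to [0.24, 0.30, 0.46] -> pred=2 (HOME), confidence=46.0.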

        # Actual result (football-data.co.uk uses FTHG/FTAG)
        home_goals = int(row['FTHG']) if pd.notna(row.get('FTHG')) else 0
        away_goals = int(row['FTAG']) if pd.notna(row.get('FTAG')) else 0

        if home_goals > away_goals:
            actual_code = 2  # HOME
        elif home_goals < away_goals:
            actual_code = 0  # AWAY
        else:
            actual_code = 1  # DRAW

        # Format date and time (football-data.co.uk columns)
        match_date = pd.to_datetime(row['Date'])
        date_str = match_date.strftime('%Y-%m-%d')
        time_str = row.get('Time', '15:00') if pd.notna(row.get('Time')) else '15:00'

        results.append({
            'league': str(row['league_code']),
            'league_name': LEAGUE_NAMES.get(row['league_code'], str(row['league_code'])),
            'date': date_str,
            'time': str(time_str)[:5] if time_str else '15:00',
            'home_team': home,
            'away_team': away,
            'home_goals': home_goals,
            'away_goals': away_goals,
            'actual': RESULT_MAP[actual_code],
            'actual_code': actual_code,
            'predicted': RESULT_MAP[pred],
            'predicted_code': pred,
            'confidence': round(confidence, 1),
            'home_prob': round(float(probs[0][2] * 100), 1),
            'draw_prob': round(float(probs[0][1] * 100), 1),
            'away_prob': round(float(probs[0][0] * 100), 1),
            'home_elo': round(home_elo),
            'away_elo': round(away_elo),
            'correct': actual_code == pred
        })

    # Sort by date (newest first)
    results = sorted(results, key=lambda x: x['date'], reverse=True)

    # Calculate statistics
    correct = sum(1 for r in results if r['correct'])
    total = len(results)

    # By confidence level
    stats_by_confidence = {}
    for threshold in [40, 45, 50, 55, 60, 65, 70]:
        high_conf = [r for r in results if r['confidence'] >= threshold]
        if high_conf:
            high_conf_correct = sum(1 for r in high_conf if r['correct'])
            stats_by_confidence[threshold] = {
                'total': len(high_conf),
                'correct': high_conf_correct,
                'accuracy': round(100 * high_conf_correct / len(high_conf), 1)
            }

    # By league
    by_league = {}
    for league in CORE_LEAGUES:
        league_results = [r for r in results if r['league'] == league]
        if league_results:
            league_correct = sum(1 for r in league_results if r['correct'])
            by_league[league] = {
                'name': LEAGUE_NAMES[league],
                'total': len(league_results),
                'correct': league_correct,
                'accuracy': round(100 * league_correct / len(league_results), 1)
            }

    # Top verified matches (highest confidence + correct)
    top_verified = sorted(
        [r for r in results if r['correct']],
        key=lambda x: x['confidence'],
        reverse=True
    )[:20]

    # Compile output
    output = {
        'generated_at': datetime.now().isoformat(),
        'model_version': 'V4-Advanced',
        'summary': {
            'total': total,
            'correct': correct,
            'accuracy': round(100 * correct / total, 1) if total > 0 else 0
        },
        'by_confidence': stats_by_confidence,
        'by_league': by_league,
        'results': results,
        'top_verified': top_verified
    }

    # Save
    output_path = os.path.join(DATA_DIR, 'verification_results.json')
    with open(output_path, 'w') as f:
        json.dump(output, f, indent=2)

    print(f"\n{'='*60}")
    print(f"✅ VERIFICA V4 COMPLETATA")
    print(f"{'='*60}")
    print(f"  Totale partite: {total}")
    print(f"  Corrette: {correct}")
    print(f"  Accuracy: {100*correct/total:.1f}%")
    print()

    print("  Per livello di confidenza:")
    for threshold, stats in stats_by_confidence.items():
        print(f"    ≥{threshold}%: {stats['accuracy']:.1f}% ({stats['correct']}/{stats['total']})")

    print()
    for league, stats in by_league.items():
        print(f"  {stats['name']}: {stats['correct']}/{stats['total']} ({stats['accuracy']:.0f}%)")

    print(f"\n  Salvato in: {output_path}")

    return output


if __name__ == '__main__':
    generate_verification()
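
    # A different lookback window can be passed directly, e.g.:
    #   generate_verification(days=14)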
