#!/usr/bin/env python3
"""
Generate verification results using the current model (V3)
Dynamically creates verification data from the most recent matches
"""

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import json
import os

# Directory that contains this script; the folders below are resolved from it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Model artifacts and CSV/JSON data sit in sibling folders one level up.
MODELS_DIR, DATA_DIR = (
    os.path.join(BASE_DIR, '..', folder) for folder in ('models', 'data')
)

# Model definition (same as in training)
class BettingNetV3(nn.Module):
    """Feed-forward classifier that maps a feature vector to 3 class logits.

    The network is a stack of ``Linear -> BatchNorm1d -> LeakyReLU(0.1) ->
    Dropout`` groups (one per entry of ``hidden_sizes``) followed by a final
    ``Linear`` layer with 3 outputs.

    Args:
        input_size: Number of input features per sample.
        hidden_sizes: Width of each hidden layer, in order. Must be a sized
            sequence (``len()`` is used). The default is an immutable tuple so
            it cannot be shared-and-mutated across instances.
        dropout: Dropout probability for every hidden group except the last,
            which uses half of this value.
    """

    def __init__(self, input_size, hidden_sizes=(512, 256, 256, 128, 64), dropout=0.3):
        super().__init__()
        layers = []
        prev_size = input_size
        last_index = len(hidden_sizes) - 1
        for i, hidden_size in enumerate(hidden_sizes):
            # Module order inside the Sequential determines the state_dict
            # keys, so it must stay exactly as it was when weights were saved.
            layers.append(nn.Linear(prev_size, hidden_size))
            layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.LeakyReLU(0.1))
            # The final hidden group uses a lighter dropout (half the rate).
            layers.append(nn.Dropout(dropout if i < last_index else dropout * 0.5))
            prev_size = hidden_size
        self.feature_extractor = nn.Sequential(*layers)
        self.classifier = nn.Linear(prev_size, 3)

    def forward(self, x):
        """Return the raw (un-normalised) 3-class logits for input batch ``x``."""
        features = self.feature_extractor(x)
        return self.classifier(features)


# Class index -> outcome label (0 = away win, 1 = draw, 2 = home win).
RESULT_MAP = dict(enumerate(('AWAY', 'DRAW', 'HOME')))

# League code -> display name. Insertion order is also the reporting order.
LEAGUE_NAMES = dict([
    ('E0', 'Premier League'),
    ('D1', 'Bundesliga'),
    ('I1', 'Serie A'),
    ('SP1', 'La Liga'),
    ('F1', 'Ligue 1'),
])

# League codes included in the verification run (same order as LEAGUE_NAMES).
CORE_LEAGUES = list(LEAGUE_NAMES)


def _to_percent(prob):
    """Convert a probability to a percentage; NaN falls back to 33.3."""
    return 33.3 if np.isnan(prob) else float(prob * 100)


def _build_result(row, pred_code, match_probs):
    """Build the JSON-serialisable verification record for one match.

    Args:
        row: One row of the matches frame; reads ``league_code``, ``date``,
            ``home_team``, ``away_team``, ``home_goals``, ``away_goals`` and
            ``result_code``.
        pred_code: Predicted class index (a key of ``RESULT_MAP``).
        match_probs: Softmax probabilities for this match, indexed by class
            (0 = away, 1 = draw, 2 = home).

    Returns:
        dict holding only built-in ``str``/``int``/``float``/``bool`` values.
    """
    # NOTE(review): a missing result_code is recorded as a draw (1), so such
    # rows still count toward the accuracy figures — confirm this is intended.
    actual_code = int(row['result_code']) if pd.notna(row['result_code']) else 1

    return {
        'league': str(row['league_code']),
        'league_name': LEAGUE_NAMES.get(row['league_code'], str(row['league_code'])),
        # Keep only the first 10 characters of the date string.
        'date': str(row['date'])[:10],
        'home_team': str(row['home_team']),
        'away_team': str(row['away_team']),
        'home_goals': int(row['home_goals']) if pd.notna(row['home_goals']) else 0,
        'away_goals': int(row['away_goals']) if pd.notna(row['away_goals']) else 0,
        'actual': RESULT_MAP[actual_code],
        'actual_code': actual_code,
        'predicted': RESULT_MAP[pred_code],
        'predicted_code': pred_code,
        'confidence': _to_percent(match_probs.max()),
        'home_prob': _to_percent(match_probs[2]),
        'draw_prob': _to_percent(match_probs[1]),
        'away_prob': _to_percent(match_probs[0]),
        'correct': actual_code == pred_code
    }


def _print_summary(results, output_path):
    """Print overall, per-league and high-confidence accuracy for ``results``."""
    correct = sum(1 for r in results if r['correct'])
    total = len(results)
    # With no usable matches there is no accuracy to report; avoid dividing by 0.
    accuracy = f"{100*correct/total:.1f}%" if total else "n/a"

    print(f"\n{'='*50}")
    print("✅ VERIFICA V3 COMPLETATA")
    print(f"{'='*50}")
    print(f"  Totale partite: {total}")
    print(f"  Corrette: {correct}")
    print(f"  Accuracy: {accuracy}")
    print()

    # Per league
    for league in CORE_LEAGUES:
        league_results = [r for r in results if r['league'] == league]
        if league_results:
            league_correct = sum(1 for r in league_results if r['correct'])
            print(f"  {LEAGUE_NAMES[league]}: {league_correct}/{len(league_results)} ({100*league_correct/len(league_results):.0f}%)")

    # High confidence
    high_conf = [r for r in results if r['confidence'] > 50]
    if high_conf:
        high_conf_correct = sum(1 for r in high_conf if r['correct'])
        print(f"\n  High confidence (>50%): {high_conf_correct}/{len(high_conf)} ({100*high_conf_correct/len(high_conf):.0f}%)")

    print(f"\n  Salvato in: {output_path}")


def generate_verification(matches_per_league=20):
    """Generate verification results from recent matches.

    Loads the V3 model, its metadata and scaler from ``MODELS_DIR``, takes the
    last ``matches_per_league`` rows (after sorting by ``date``) of every
    league in ``CORE_LEAGUES`` from ``advanced_historical_matches.csv``,
    predicts each match, writes the records to
    ``DATA_DIR/verification_results.json`` and prints an accuracy summary.

    Args:
        matches_per_league: How many trailing rows to take per league.

    Returns:
        The list of per-match result dicts, or ``None`` when the model
        metadata file does not exist. Rows lacking a home team, away team or
        date are left out of the list.

    Only the metadata file is checked for existence; errors from loading the
    weights, the scaler or the CSV are not caught here.
    """

    print("🔄 Generating verification results with V3 model...")

    # Load meta
    meta_path = os.path.join(MODELS_DIR, 'betting_model_v3_meta.json')
    if not os.path.exists(meta_path):
        print("❌ Model V3 meta not found!")
        return None

    with open(meta_path, encoding='utf-8') as f:
        meta = json.load(f)

    features = meta['feature_columns']
    print(f"  Features: {len(features)}")

    # Load model
    model = BettingNetV3(len(features))
    model.load_state_dict(torch.load(
        os.path.join(MODELS_DIR, 'betting_model_v3.pth'),
        map_location='cpu',
        weights_only=True
    ))
    model.eval()

    # Load scaler
    # SECURITY NOTE: allow_pickle=True unpickles the file's contents, which can
    # execute arbitrary code. Only load scaler files from a trusted source.
    scaler_data = np.load(
        os.path.join(MODELS_DIR, 'scaler_v3.npy'),
        allow_pickle=True
    ).item()

    # Load advanced data
    df = pd.read_csv(
        os.path.join(DATA_DIR, 'advanced_historical_matches.csv'),
        low_memory=False
    )
    # NOTE(review): sorts on the 'date' column as parsed by read_csv;
    # chronological order presumably relies on an ISO-like format — confirm.
    df = df.sort_values('date').reset_index(drop=True)

    # Take the last N matches of each core league, in CORE_LEAGUES order
    df_core = df[df['league_code'].isin(CORE_LEAGUES)]
    test_df = pd.concat([
        df_core[df_core['league_code'] == league].tail(matches_per_league)
        for league in CORE_LEAGUES
    ])
    print(f"  Verification matches: {len(test_df)}")

    # Prepare features in a separate matrix so the reporting columns of
    # test_df stay untouched: absent feature columns and NaNs both become 0,
    # and everything is coerced to float for the tensor conversion.
    X = test_df.reindex(columns=features).fillna(0).to_numpy(dtype=np.float64)

    # Scale
    X_scaled = (X - scaler_data['mean']) / scaler_data['scale']

    # Predict
    with torch.no_grad():
        outputs = model(torch.FloatTensor(X_scaled))
        probs = torch.softmax(outputs, dim=1).numpy()
        preds = outputs.argmax(dim=1).numpy()

    # Create verification results (rows and predictions share positional order)
    results = []
    for (_, row), pred, match_probs in zip(test_df.iterrows(), preds, probs):
        # Skip rows with missing essential data
        if pd.isna(row['home_team']) or pd.isna(row['away_team']) or pd.isna(row['date']):
            continue
        results.append(_build_result(row, int(pred), match_probs))

    # Save
    output_path = os.path.join(DATA_DIR, 'verification_results.json')
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    # Stats
    _print_summary(results, output_path)

    return results


if __name__ == '__main__':
    generate_verification(matches_per_league=20)
