"""
BET.CUTTALO.COM - NEURAL NETWORK V2
Rete neurale avanzata con feature engineering evoluto
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import joblib
import os
import json
from typing import Tuple, List, Dict
from datetime import datetime

# Locations of trained artefacts and input data, resolved relative to the
# directory containing this file (one level up, in "models" and "data").
MODELS_DIR = os.path.join(os.path.dirname(__file__), '..', 'models')
DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
# Import-time side effect: make sure the models directory exists.
# DATA_DIR is not created here.
os.makedirs(MODELS_DIR, exist_ok=True)


class MatchDataset(Dataset):
    """Torch dataset that pairs one feature vector with one class label per match."""

    def __init__(self, features: np.ndarray, labels: np.ndarray):
        # Both arrays are converted to tensors once, up front, so that item
        # access later is a plain tensor lookup.
        feature_tensor = torch.FloatTensor(features)
        label_tensor = torch.LongTensor(labels)
        self.features = feature_tensor
        self.labels = label_tensor

    def __len__(self):
        """Return the number of samples, measured on the label tensor."""
        sample_count = len(self.labels)
        return sample_count

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target


class BettingNeuralNetworkV2(nn.Module):
    """
    Fully connected classifier made of four stages:

    1. ``input_layer``  - Linear + BatchNorm + LeakyReLU + Dropout projecting
       the raw features to ``hidden_sizes[0]``.
    2. ``res_blocks``   - one ``ResidualBlock`` per consecutive pair of widths
       in ``hidden_sizes``.
    3. ``attention``    - a bottleneck MLP ending in a sigmoid whose output is
       multiplied element-wise with the features (a per-feature gate).
    4. ``output_layer`` - Linear + BatchNorm + ReLU + Dropout + Linear
       producing ``num_classes`` raw logits (no softmax is applied here).

    Args:
        input_size: number of input features per sample.
        hidden_sizes: widths of the hidden stages; ``None`` selects
            ``[256, 512, 256, 128]``. Must contain at least one width.
        num_classes: number of output logits.
        dropout: dropout probability of the hidden stages; the output head
            uses half of it.

    Raises:
        ValueError: if ``hidden_sizes`` is empty.
    """

    def __init__(self, input_size: int, hidden_sizes: List[int] = None,
                 num_classes: int = 3, dropout: float = 0.4):
        super().__init__()

        # ``None`` replaces the former mutable list default, which was shared
        # between every call. The caller's sequence is copied, never aliased.
        if hidden_sizes is None:
            hidden_sizes = [256, 512, 256, 128]
        else:
            hidden_sizes = list(hidden_sizes)
        if not hidden_sizes:
            raise ValueError("hidden_sizes must contain at least one layer width")

        self.input_size = input_size
        self.num_classes = num_classes
        self.hidden_sizes = hidden_sizes

        first_width = hidden_sizes[0]
        last_width = hidden_sizes[-1]

        # Input embedding layer
        self.input_layer = nn.Sequential(
            nn.Linear(input_size, first_width),
            nn.BatchNorm1d(first_width),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout)
        )

        # Residual blocks: one per transition between consecutive widths.
        self.res_blocks = nn.ModuleList(
            ResidualBlock(width_in, width_out, dropout)
            for width_in, width_out in zip(hidden_sizes[:-1], hidden_sizes[1:])
        )

        # Gating layer. The bottleneck is a quarter of the last width but never
        # below 1, so very small widths do not create an empty Linear layer.
        gate_width = max(1, last_width // 4)
        self.attention = nn.Sequential(
            nn.Linear(last_width, gate_width),
            nn.Tanh(),
            nn.Linear(gate_width, last_width),
            nn.Sigmoid()
        )

        # Output head
        self.output_layer = nn.Sequential(
            nn.Linear(last_width, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(dropout / 2),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        """Map a batch of feature rows to a batch of ``num_classes`` logits."""
        x = self.input_layer(x)

        for block in self.res_blocks:
            x = block(x)

        # Scale every feature by its sigmoid gate in (0, 1).
        attn_weights = self.attention(x)
        x = x * attn_weights

        return self.output_layer(x)


class ResidualBlock(nn.Module):
    """Two fully connected layers whose output is summed with a shortcut of the input."""

    def __init__(self, in_features: int, out_features: int, dropout: float):
        super().__init__()

        # Main path: Linear -> BatchNorm -> LeakyReLU -> Dropout -> Linear -> BatchNorm
        main_layers = [
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout),
            nn.Linear(out_features, out_features),
            nn.BatchNorm1d(out_features),
        ]
        self.main = nn.Sequential(*main_layers)

        # Shortcut path: pass-through when the widths already match, otherwise
        # a linear projection to the output width.
        if in_features == out_features:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Linear(in_features, out_features)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Return ``activation(main(x) + skip(x))``."""
        transformed = self.main(x)
        shortcut = self.skip(x)
        return self.activation(transformed + shortcut)


class AdvancedFeatureEngine:
    """
    Builds one feature row per match from a table of historical results.

    For every match only rows of the same league with a strictly earlier
    ``Date`` are used, so a match never sees its own outcome. From that
    history the engine derives recent form, home-only / away-only form,
    same-season averages and head-to-head statistics.

    The input frame must provide the columns ``Date``, ``match_id``,
    ``league_code``, ``season``, ``home_team``, ``away_team``, ``home_goals``,
    ``away_goals`` and ``result``; ``odds_home``, ``odds_draw`` and
    ``odds_away`` are optional.

    NOTE: the history is re-filtered from the whole frame for every row, so
    the running time grows roughly quadratically with the number of matches.
    """

    # Encoding of the ``result`` column assumed by every calculation below.
    RESULT_AWAY_WIN = 0
    RESULT_DRAW = 1
    RESULT_HOME_WIN = 2

    # Stand-in odds for the implied-probability estimate when the draw or
    # away price is missing or unusable.
    DEFAULT_DRAW_ODDS = 3.5
    DEFAULT_AWAY_ODDS = 4

    def __init__(self, n_recent: int = 5):
        """
        Args:
            n_recent: size of the "recent matches" window used for form.
        """
        self.n_recent = n_recent
        # Not read anywhere in this class; kept so existing attribute access
        # from outside keeps working.
        self.team_cache = {}

    def create_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Build the feature table used for training.

        Args:
            df: raw matches (see the class docstring for the columns). The
                frame is copied; the caller's object is not modified.

        Returns:
            A new frame with one row per input match, in chronological order,
            holding identifiers, targets and the engineered features.
        """
        print("Creando feature avanzate...")

        # Parse the dates BEFORE sorting: sorting the raw column would order
        # text dates lexicographically and break every "last N matches"
        # window. The stable sort keeps same-day matches in input order, which
        # makes the result reproducible.
        df = df.copy()
        df['Date'] = pd.to_datetime(df['Date'], errors='coerce')
        df = df.sort_values('Date', kind='mergesort')

        features_list = []
        total = len(df)

        for idx, (_, row) in enumerate(df.iterrows()):
            if idx % 1000 == 0:
                print(f"  Progress: {idx}/{total} ({idx/total*100:.1f}%)")

            home = row['home_team']
            away = row['away_team']
            current_date = row['Date']
            league = row['league_code']

            # History: strictly earlier matches of the same league. A row with
            # an unparseable date (NaT) compares False everywhere and so gets
            # an empty history.
            past_matches = df[df['Date'] < current_date]
            league_matches = past_matches[past_matches['league_code'] == league]

            # Overall recent form (last N matches, home or away)
            home_recent = self._get_team_matches(league_matches, home, self.n_recent)
            away_recent = self._get_team_matches(league_matches, away, self.n_recent)

            # Venue-specific form: home side at home, away side away
            home_home_matches = league_matches[league_matches['home_team'] == home].tail(self.n_recent)
            away_away_matches = league_matches[league_matches['away_team'] == away].tail(self.n_recent)

            # Head to head: last 5 meetings in either direction
            h2h = league_matches[
                ((league_matches['home_team'] == home) & (league_matches['away_team'] == away)) |
                ((league_matches['home_team'] == away) & (league_matches['away_team'] == home))
            ].tail(5)

            home_form = self._calc_form(home_recent, home)
            away_form = self._calc_form(away_recent, away)
            home_home_form = self._calc_home_form(home_home_matches)
            away_away_form = self._calc_away_form(away_away_matches)
            h2h_stats = self._calc_h2h(h2h, home, away)

            # Standing proxy: per-match averages over the current season so far
            season_matches = league_matches[league_matches['season'] == row['season']]
            home_season = self._calc_season_stats(season_matches, home)
            away_season = self._calc_season_stats(season_matches, away)

            features = {
                'match_id': row['match_id'],
                'Date': current_date,
                'league_code': league,
                'home_team': home,
                'away_team': away,
                'season': row['season'],

                # Target
                'result': row['result'],
                'home_goals': row['home_goals'],
                'away_goals': row['away_goals'],

                # Form features (per-match averages over the recent window)
                'home_form_points': home_form['points'],
                'home_form_goals': home_form['goals_for'],
                'home_form_conceded': home_form['goals_against'],
                'home_form_wins': home_form['wins'],
                'home_form_clean_sheets': home_form['clean_sheets'],

                'away_form_points': away_form['points'],
                'away_form_goals': away_form['goals_for'],
                'away_form_conceded': away_form['goals_against'],
                'away_form_wins': away_form['wins'],
                'away_form_clean_sheets': away_form['clean_sheets'],

                # Home/Away specific
                'home_home_ppg': home_home_form['ppg'],
                'home_home_goals': home_home_form['goals'],
                'away_away_ppg': away_away_form['ppg'],
                'away_away_goals': away_away_form['goals'],

                # Season standing proxy
                'home_season_ppg': home_season['ppg'],
                'home_season_gd': home_season['goal_diff'],
                'away_season_ppg': away_season['ppg'],
                'away_season_gd': away_season['goal_diff'],

                # H2H
                'h2h_home_advantage': h2h_stats['home_adv'],
                'h2h_total_goals': h2h_stats['avg_goals'],

                # Derived features
                'form_diff': home_form['points'] - away_form['points'],
                'goal_diff': (home_form['goals_for'] - home_form['goals_against']) -
                            (away_form['goals_for'] - away_form['goals_against']),
                'attack_diff': home_form['goals_for'] - away_form['goals_for'],
                'defense_diff': away_form['goals_against'] - home_form['goals_against'],
            }

            # Odds, when a usable home price is present
            odds_home = self._positive_or_none(row.get('odds_home'))
            if odds_home is not None:
                odds_draw = self._positive_or_none(row.get('odds_draw'))
                odds_away = self._positive_or_none(row.get('odds_away'))

                features['has_odds'] = 1
                features['odds_home'] = odds_home
                features['odds_draw'] = odds_draw if odds_draw is not None else 0
                features['odds_away'] = odds_away if odds_away is not None else 0

                # Implied probability of a home win, normalised by the sum of
                # the three inverse odds. A missing, NaN or non-positive draw
                # or away price falls back to the stand-in odds, so the sum is
                # always finite and positive.
                inverse_total = (
                    1 / odds_home
                    + 1 / (odds_draw if odds_draw is not None else self.DEFAULT_DRAW_ODDS)
                    + 1 / (odds_away if odds_away is not None else self.DEFAULT_AWAY_ODDS)
                )
                features['implied_home'] = (1 / odds_home) / inverse_total
            else:
                features['has_odds'] = 0
                features['odds_home'] = 0
                features['odds_draw'] = 0
                features['odds_away'] = 0
                features['implied_home'] = 0

            features_list.append(features)

        result_df = pd.DataFrame(features_list)
        print(f"  Create {len(result_df)} righe con feature avanzate")
        return result_df

    @staticmethod
    def _positive_or_none(value):
        """Return ``value`` as a float when it is a finite number > 0, else ``None``."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        if np.isfinite(number) and number > 0:
            return number
        return None

    def _team_outcome(self, match, team: str) -> Tuple[int, float, float]:
        """
        Return ``(points, goals_for, goals_against)`` of ``team`` in one match row.

        A row where ``team`` is not the home side is read from the away side's
        perspective; callers only pass rows that involve ``team``.
        """
        if match['home_team'] == team:
            goals_for, goals_against = match['home_goals'], match['away_goals']
            winning_code = self.RESULT_HOME_WIN
        else:
            goals_for, goals_against = match['away_goals'], match['home_goals']
            winning_code = self.RESULT_AWAY_WIN

        if match['result'] == winning_code:
            points = 3
        elif match['result'] == self.RESULT_DRAW:
            points = 1
        else:
            points = 0
        return points, goals_for, goals_against

    def _get_team_matches(self, df: pd.DataFrame, team: str, n: int) -> pd.DataFrame:
        """Return the last ``n`` rows of ``df`` in which ``team`` played (home or away)."""
        return df[(df['home_team'] == team) | (df['away_team'] == team)].tail(n)

    def _calc_form(self, matches: pd.DataFrame, team: str) -> Dict:
        """
        Per-match averages of ``team`` over ``matches``.

        Returns a dict with ``points``, ``goals_for``, ``goals_against``,
        ``wins`` and ``clean_sheets``; every value is 0 when ``matches`` is empty.
        """
        if matches.empty:
            return {'points': 0, 'goals_for': 0, 'goals_against': 0, 'wins': 0, 'clean_sheets': 0}

        points, goals_for, goals_against, wins, clean_sheets = 0, 0, 0, 0, 0

        for _, m in matches.iterrows():
            earned, gf, ga = self._team_outcome(m, team)
            points += earned
            if earned == 3:
                wins += 1
            goals_for += gf
            goals_against += ga
            if ga == 0:
                clean_sheets += 1

        n = len(matches)
        return {
            'points': points / n,
            'goals_for': goals_for / n,
            'goals_against': goals_against / n,
            'wins': wins / n,
            'clean_sheets': clean_sheets / n
        }

    def _calc_home_form(self, matches: pd.DataFrame) -> Dict:
        """
        Points per game and goals per game of the home side over ``matches``.

        Every row is read from the home side's perspective. An empty input
        yields the neutral placeholder ``{'ppg': 1.0, 'goals': 1.0}``.
        """
        if matches.empty:
            return {'ppg': 1.0, 'goals': 1.0}

        points = 0
        goals = 0
        for _, m in matches.iterrows():
            if m['result'] == self.RESULT_HOME_WIN:
                points += 3
            elif m['result'] == self.RESULT_DRAW:
                points += 1
            goals += m['home_goals']

        n = len(matches)
        return {'ppg': points / n, 'goals': goals / n}

    def _calc_away_form(self, matches: pd.DataFrame) -> Dict:
        """
        Points per game and goals per game of the away side over ``matches``.

        Every row is read from the away side's perspective. An empty input
        yields the neutral placeholder ``{'ppg': 1.0, 'goals': 1.0}``.
        """
        if matches.empty:
            return {'ppg': 1.0, 'goals': 1.0}

        points = 0
        goals = 0
        for _, m in matches.iterrows():
            if m['result'] == self.RESULT_AWAY_WIN:
                points += 3
            elif m['result'] == self.RESULT_DRAW:
                points += 1
            goals += m['away_goals']

        n = len(matches)
        return {'ppg': points / n, 'goals': goals / n}

    def _calc_h2h(self, matches: pd.DataFrame, home: str, away: str) -> Dict:
        """
        Head-to-head summary between ``home`` and ``away`` over ``matches``.

        Returns ``home_adv`` (share of meetings won by ``home``, at either
        venue, minus 0.5) and ``avg_goals`` (total goals per meeting). An
        empty input yields ``{'home_adv': 0, 'avg_goals': 2.5}``.
        """
        if matches.empty:
            return {'home_adv': 0, 'avg_goals': 2.5}

        home_wins = 0
        total_goals = 0

        for _, m in matches.iterrows():
            total_goals += m['home_goals'] + m['away_goals']
            # A win counts for `home` whichever venue the meeting was played at.
            if m['home_team'] == home and m['result'] == self.RESULT_HOME_WIN:
                home_wins += 1
            elif m['away_team'] == home and m['result'] == self.RESULT_AWAY_WIN:
                home_wins += 1

        n = len(matches)
        return {
            'home_adv': (home_wins / n) - 0.5,  # 0 = neutral, positive = home advantage
            'avg_goals': total_goals / n
        }

    def _calc_season_stats(self, matches: pd.DataFrame, team: str) -> Dict:
        """
        Points per game and goal difference per game of ``team`` over ``matches``.

        ``matches`` may contain other teams' rows; they are filtered out here.
        When ``team`` has no match the placeholder
        ``{'ppg': 1.0, 'goal_diff': 0}`` is returned.
        """
        team_matches = matches[(matches['home_team'] == team) | (matches['away_team'] == team)]

        if team_matches.empty:
            return {'ppg': 1.0, 'goal_diff': 0}

        points = 0
        goal_diff = 0

        for _, m in team_matches.iterrows():
            earned, gf, ga = self._team_outcome(m, team)
            points += earned
            goal_diff += gf - ga

        n = len(team_matches)
        return {'ppg': points / n, 'goal_diff': goal_diff / n}


class BettingPredictorV2:
    """
    Trains, persists and serves the match-outcome classifier.

    The predictor owns three pieces of state: the network (``model``), the
    feature scaler (``scaler``) and the per-epoch ``training_history``. Model
    weights, scaler, history and a metadata file are stored in ``MODELS_DIR``
    under names derived from ``model_name``.
    """

    FEATURE_COLUMNS = [
        'home_form_points', 'home_form_goals', 'home_form_conceded',
        'home_form_wins', 'home_form_clean_sheets',
        'away_form_points', 'away_form_goals', 'away_form_conceded',
        'away_form_wins', 'away_form_clean_sheets',
        'home_home_ppg', 'home_home_goals',
        'away_away_ppg', 'away_away_goals',
        'home_season_ppg', 'home_season_gd',
        'away_season_ppg', 'away_season_gd',
        'h2h_home_advantage', 'h2h_total_goals',
        'form_diff', 'goal_diff', 'attack_diff', 'defense_diff',
        'has_odds', 'odds_home', 'odds_draw', 'odds_away', 'implied_home'
    ]

    # Position in this list == class code in the ``result`` column and index
    # of the corresponding network output.
    CLASS_LABELS = ['Away Win', 'Draw', 'Home Win']

    def __init__(self, model_name: str = 'betting_model_v2'):
        """
        Args:
            model_name: base name of the files written to / read from ``MODELS_DIR``.
        """
        self.model_name = model_name
        self.model = None
        self.scaler = StandardScaler()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.training_history = []
        self.feature_engine = AdvancedFeatureEngine()

    def _build_arrays(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """
        Extract the UNSCALED feature matrix and the integer labels from ``df``.

        Feature columns absent from ``df`` are added as zeros and missing
        feature values become 0; rows whose ``result`` is missing are dropped.
        ``df`` itself is left untouched.
        """
        feature_frame = df.reindex(columns=self.FEATURE_COLUMNS, fill_value=0).fillna(0)
        X = feature_frame.to_numpy(dtype=float)
        y = df['result'].to_numpy(dtype=float)

        mask = ~np.isnan(X).any(axis=1) & ~np.isnan(y)
        return X[mask], y[mask].astype(int)

    def prepare_data(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return ``(X, y)`` for ``df`` with ``X`` standardised.

        The scaler is (re)fitted on every row of ``df``. ``train`` does not use
        this method, because the scaler must only see the training split.
        """
        X, y = self._build_arrays(df)
        X = self.scaler.fit_transform(X)
        return X, y

    def _save_history(self):
        """Write ``training_history`` to its JSON file in ``MODELS_DIR``."""
        history_path = os.path.join(MODELS_DIR, f'{self.model_name}_history.json')
        with open(history_path, 'w', encoding='utf-8') as f:
            json.dump(self.training_history, f)

    def _train_one_epoch(self, loader, criterion, optimizer, scheduler):
        """Run one optimisation pass over ``loader``; return ``(mean_loss, accuracy_pct)``."""
        self.model.train()
        loss_sum, correct, seen = 0.0, 0, 0

        for features, labels in loader:
            features, labels = features.to(self.device), labels.to(self.device)

            optimizer.zero_grad()
            outputs = self.model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()  # the one-cycle schedule advances once per batch

            loss_sum += loss.item()
            predicted = outputs.argmax(dim=1)
            seen += labels.size(0)
            correct += (predicted == labels).sum().item()

        return loss_sum / len(loader), 100 * correct / seen

    def _evaluate(self, loader, criterion, num_classes: int):
        """
        Score the model on ``loader`` without updating it.

        Returns ``(mean_loss, accuracy_pct, class_correct, class_total)`` where
        the last two are per-class counters of length ``num_classes``.
        """
        self.model.eval()
        loss_sum, correct, seen = 0.0, 0, 0
        class_correct = [0] * num_classes
        class_total = [0] * num_classes

        with torch.no_grad():
            for features, labels in loader:
                features, labels = features.to(self.device), labels.to(self.device)
                outputs = self.model(features)
                loss_sum += criterion(outputs, labels).item()

                predicted = outputs.argmax(dim=1)
                hits = predicted == labels
                seen += labels.size(0)
                correct += hits.sum().item()

                for cls in range(num_classes):
                    in_class = labels == cls
                    class_total[cls] += in_class.sum().item()
                    class_correct[cls] += (hits & in_class).sum().item()

        return loss_sum / len(loader), 100 * correct / seen, class_correct, class_total

    def train(self, df: pd.DataFrame = None, epochs: int = 150, batch_size: int = 64,
              learning_rate: float = 0.001, validation_split: float = 0.2):
        """
        Train the network and keep the weights of the best validation epoch.

        Args:
            df: raw matches; when ``None`` they are read from
                ``historical_matches.csv`` in ``DATA_DIR``.
            epochs: number of passes over the training split.
            batch_size: mini-batch size for both loaders.
            learning_rate: base learning rate; the one-cycle schedule peaks at
                ten times this value.
            validation_split: fraction of samples held out for validation.

        Returns:
            The list of per-epoch dicts (``epoch``, ``train_loss``,
            ``train_acc``, ``val_loss``, ``val_acc``).

        Raises:
            ValueError: when no usable sample remains, when a label is outside
                the known class codes, or when the training split is smaller
                than one batch.

        Side effects: every improvement of the validation accuracy is saved via
        ``save_model``. When training ends the in-memory model holds those best
        weights and the history file is rewritten with all epochs.
        """
        if df is None:
            data_path = os.path.join(DATA_DIR, 'historical_matches.csv')
            df = pd.read_csv(data_path)

        print(f"Dati grezzi: {len(df)} partite")

        # Feature engineering.
        # NOTE(review): an existing training_features.csv is used even when a
        # different `df` was passed in - delete the file to force a rebuild.
        features_path = os.path.join(DATA_DIR, 'training_features.csv')
        if os.path.exists(features_path):
            print("Caricando feature pre-calcolate...")
            features_df = pd.read_csv(features_path)
        else:
            features_df = self.feature_engine.create_features(df)
            features_df.to_csv(features_path, index=False)

        # Drop matches without history (their form points are 0).
        # NOTE(review): this also drops matches whose home side simply earned
        # no points in its recent matches - confirm that is acceptable.
        features_df = features_df[features_df['home_form_points'] > 0]
        print(f"Dati con feature: {len(features_df)} partite")

        print(f"\nPreparando training data...")
        X, y = self._build_arrays(features_df)
        print(f"Shape: X={X.shape}, y={y.shape}")

        num_classes = len(self.CLASS_LABELS)
        if y.size == 0:
            raise ValueError("Nessuna partita utilizzabile per il training")
        if y.min() < 0 or y.max() >= num_classes:
            raise ValueError(
                f"Valori di 'result' fuori dall'intervallo 0..{num_classes - 1}"
            )

        # Stratified split
        X_train, X_val, y_train, y_val = train_test_split(
            X, y, test_size=validation_split, random_state=42, stratify=y
        )

        # Fit the scaler on the training split only, so that no statistic of
        # the validation rows leaks into training.
        X_train = self.scaler.fit_transform(X_train)
        X_val = self.scaler.transform(X_val)

        train_dataset = MatchDataset(X_train, y_train)
        val_dataset = MatchDataset(X_val, y_val)
        # drop_last avoids a trailing undersized training batch.
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)
        if len(train_loader) == 0:
            raise ValueError(
                f"Dati di training insufficienti: servono almeno {batch_size} "
                f"campioni (batch_size), trovati {len(X_train)}"
            )

        input_size = X.shape[1]
        self.model = BettingNeuralNetworkV2(input_size).to(self.device)

        # Inverse-frequency class weights, normalised to sum to 1. minlength
        # keeps one weight per class even when a class is absent; an absent
        # class gets weight 0 instead of a division by zero.
        class_counts = np.bincount(y, minlength=num_classes).astype(float)
        class_weights = np.zeros(num_classes)
        present = class_counts > 0
        class_weights[present] = 1.0 / class_counts[present]
        class_weights = class_weights / class_weights.sum()
        weights_tensor = torch.FloatTensor(class_weights).to(self.device)

        criterion = nn.CrossEntropyLoss(weight=weights_tensor)
        optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=0.01)
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=learning_rate * 10,
            epochs=epochs, steps_per_epoch=len(train_loader)
        )

        best_val_acc = 0
        best_state = None        # copy of the weights of the best epoch
        best_class_stats = None  # per-class counters of the best epoch
        class_correct = [0] * num_classes
        class_total = [0] * num_classes
        self.training_history = []

        print(f"\n{'='*60}")
        print(f"TRAINING su {self.device}")
        print(f"Train: {len(X_train)}, Validation: {len(X_val)}")
        print(f"Input features: {input_size}")
        print(f"{'='*60}")

        for epoch in range(epochs):
            train_loss, train_acc = self._train_one_epoch(
                train_loader, criterion, optimizer, scheduler
            )
            val_loss, val_acc, class_correct, class_total = self._evaluate(
                val_loader, criterion, num_classes
            )

            self.training_history.append({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'train_acc': train_acc,
                'val_loss': val_loss,
                'val_acc': val_acc
            })

            if (epoch + 1) % 10 == 0 or epoch == 0:
                print(f"Epoch {epoch+1:3d}/{epochs} | "
                      f"Train: {train_acc:.1f}% | Val: {val_acc:.1f}% | "
                      f"Loss: {val_loss:.4f}")

            # Save best
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_state = {
                    name: tensor.detach().clone()
                    for name, tensor in self.model.state_dict().items()
                }
                best_class_stats = (class_correct, class_total)
                self.save_model()

        if best_state is not None:
            # Leave the predictor in the same state a fresh load_model() would
            # produce, report the per-class figures of that same epoch, and
            # store the complete history (save_model last ran at the best epoch).
            self.model.load_state_dict(best_state)
            class_correct, class_total = best_class_stats
            self._save_history()

        print(f"\n{'='*60}")
        print(f"TRAINING COMPLETATO!")
        print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
        print(f"{'='*60}")

        print("\nAccuracy per classe:")
        for i, label in enumerate(self.CLASS_LABELS):
            acc = 100 * class_correct[i] / class_total[i] if class_total[i] > 0 else 0
            print(f"  {label}: {acc:.1f}% ({class_correct[i]}/{class_total[i]})")

        return self.training_history

    def predict(self, match_data: Dict) -> Dict:
        """
        Predict the outcome of a single match.

        Args:
            match_data: mapping from feature name to value. Features that are
                absent, ``None`` or NaN count as 0, mirroring training.

        Returns:
            A dict with the three class probabilities in percent, the predicted
            label and class code, a 0-100 ``confidence`` and ``raw_probs``
            (probabilities in class-code order).

        The stored model is loaded on first use when none is in memory.
        """
        if self.model is None:
            self.load_model()

        self.model.eval()

        values = [match_data.get(col, 0) for col in self.FEATURE_COLUMNS]
        features = np.array(values, dtype=float).reshape(1, -1)
        features[np.isnan(features)] = 0.0
        features = self.scaler.transform(features)
        features = torch.FloatTensor(features).to(self.device)

        with torch.no_grad():
            logits = self.model(features)
            probs = torch.softmax(logits, dim=1).cpu().numpy()[0]

        best_idx = int(np.argmax(probs))
        # Rescale the top probability from [0.33, 1] to [0, 100]; the result
        # is clamped below.
        confidence = (float(probs[best_idx]) - 0.33) / 0.67 * 100

        return {
            'home_win_prob': float(probs[2]) * 100,
            'draw_prob': float(probs[1]) * 100,
            'away_win_prob': float(probs[0]) * 100,
            'prediction': self.CLASS_LABELS[best_idx],
            'prediction_code': best_idx,
            'confidence': float(max(0, min(100, confidence))),
            'raw_probs': probs.tolist()
        }

    def save_model(self):
        """Persist weights, scaler, training history and metadata to ``MODELS_DIR``."""
        model_path = os.path.join(MODELS_DIR, f'{self.model_name}.pth')
        scaler_path = os.path.join(MODELS_DIR, f'{self.model_name}_scaler.pkl')
        meta_path = os.path.join(MODELS_DIR, f'{self.model_name}_meta.json')

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'input_size': self.model.input_size,
            'timestamp': datetime.now().isoformat()
        }, model_path)

        joblib.dump(self.scaler, scaler_path)

        self._save_history()

        meta = {
            'model_name': self.model_name,
            'input_size': self.model.input_size,
            'feature_columns': self.FEATURE_COLUMNS,
            'trained_at': datetime.now().isoformat(),
            'best_val_acc': max(h['val_acc'] for h in self.training_history) if self.training_history else 0
        }
        with open(meta_path, 'w', encoding='utf-8') as f:
            json.dump(meta, f, indent=2)

        print(f"Modello salvato: {model_path}")

    def load_model(self):
        """
        Load weights and scaler from ``MODELS_DIR``.

        Raises:
            FileNotFoundError: when the weights file does not exist.
        """
        model_path = os.path.join(MODELS_DIR, f'{self.model_name}.pth')
        scaler_path = os.path.join(MODELS_DIR, f'{self.model_name}_scaler.pkl')

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Modello non trovato: {model_path}")

        # NOTE(review): weights_only=False lets the checkpoint unpickle
        # arbitrary objects, so only load files you trust. save_model stores
        # just tensors, an int and a str, so weights_only=True is presumably
        # sufficient - confirm against existing checkpoints before switching.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        self.model = BettingNeuralNetworkV2(checkpoint['input_size']).to(self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # joblib files are pickles too: same trust requirement as above.
        self.scaler = joblib.load(scaler_path)
        print(f"Modello caricato (trained: {checkpoint.get('timestamp', 'unknown')})")

    def get_model_stats(self) -> Dict:
        """
        Return the stored metadata, plus ``history`` when the history file exists.

        Returns ``{'status': 'not_trained'}`` when no metadata file is found.
        """
        meta_path = os.path.join(MODELS_DIR, f'{self.model_name}_meta.json')
        history_path = os.path.join(MODELS_DIR, f'{self.model_name}_history.json')

        if not os.path.exists(meta_path):
            return {'status': 'not_trained'}

        with open(meta_path, encoding='utf-8') as f:
            meta = json.load(f)

        if os.path.exists(history_path):
            with open(history_path, encoding='utf-8') as f:
                meta['history'] = json.load(f)

        return meta


if __name__ == "__main__":
    # Script entry point: print a banner, then run a full 150-epoch training
    # with the default predictor settings.
    banner_rule = "=" * 70
    print(banner_rule)
    print("BET.CUTTALO.COM - TRAINING RETE NEURALE V2")
    print(banner_rule)

    predictor = BettingPredictorV2()
    predictor.train(epochs=150)
