"""
BET.CUTTALO.COM - NEURAL NETWORK
Rete neurale per predizione risultati calcistici
"""

import os
from datetime import datetime
from typing import Dict, List, Optional, Sequence, Tuple

import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader

# Directory where trained models, scalers and training histories are
# persisted (../models relative to this file). Created eagerly as an
# import-time side effect so save_model()/load_model() can assume it exists.
MODELS_DIR = os.path.join(os.path.dirname(__file__), '..', 'models')
os.makedirs(MODELS_DIR, exist_ok=True)


class MatchDataset(Dataset):
    """PyTorch dataset wrapping pre-computed match features and labels.

    Features are stored as float32 and labels as int64 — the dtypes
    expected by linear layers and by NLL/cross-entropy losses respectively.
    """

    def __init__(self, features: np.ndarray, labels: np.ndarray):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self) -> int:
        # One sample per label row.
        return self.labels.shape[0]

    def __getitem__(self, idx):
        # Return the (feature vector, label) pair for one match.
        return self.features[idx], self.labels[idx]


class BettingNeuralNetwork(nn.Module):
    """
    Deep feed-forward network for match outcome prediction.

    Architecture: input projection -> stack of Linear/BatchNorm/ReLU/Dropout
    blocks -> softmax head over the outcome classes, plus a linear skip
    connection from the first hidden representation to the last one.

    NOTE(review): the final Softmax means forward() returns *probabilities*,
    not logits. Do NOT feed this output to nn.CrossEntropyLoss (it applies
    log_softmax internally, which would double-apply the normalization and
    squash gradients); pair it with nn.NLLLoss over log(probs), or drop the
    Softmax and return logits if CrossEntropyLoss is desired.
    """

    def __init__(self, input_size: int,
                 hidden_sizes: Optional[Sequence[int]] = None,
                 num_classes: int = 3, dropout: float = 0.3):
        """
        Args:
            input_size: number of input features per sample.
            hidden_sizes: widths of the hidden layers; defaults to
                (128, 256, 128, 64). (None/tuple default replaces the
                previous mutable-list default argument.)
            num_classes: number of output classes (3 = away/draw/home).
            dropout: dropout probability applied after each hidden block.
        """
        super().__init__()
        if hidden_sizes is None:
            # BUGFIX: the original used a mutable list default argument.
            hidden_sizes = (128, 256, 128, 64)

        self.input_size = input_size
        self.num_classes = num_classes

        # Input projection block.
        self.input_layer = nn.Sequential(
            nn.Linear(input_size, hidden_sizes[0]),
            nn.BatchNorm1d(hidden_sizes[0]),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

        # Hidden blocks (sequential widths from hidden_sizes).
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_sizes) - 1):
            self.hidden_layers.append(nn.Sequential(
                nn.Linear(hidden_sizes[i], hidden_sizes[i + 1]),
                nn.BatchNorm1d(hidden_sizes[i + 1]),
                nn.ReLU(),
                nn.Dropout(dropout)
            ))

        # Output head: per-class probabilities (see class-level NOTE).
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_sizes[-1], num_classes),
            nn.Softmax(dim=1)
        )

        # Linear skip from the first hidden width to the last, to improve
        # gradient flow through the deep stack.
        self.skip_connection = nn.Linear(hidden_sizes[0], hidden_sizes[-1])

    def forward(self, x):
        """Return per-class probabilities of shape (batch, num_classes)."""
        x = self.input_layer(x)
        skip = self.skip_connection(x)

        for layer in self.hidden_layers:
            x = layer(x)

        # Residual add before the output head.
        x = x + skip

        return self.output_layer(x)


class BettingPredictor:
    """
    End-to-end trainer / predictor for football match outcomes.

    Owns the network, the feature scaler and on-disk persistence under
    MODELS_DIR. Label encoding assumed throughout: 0 = Away Win, 1 = Draw,
    2 = Home Win — must match the 'result' column produced by the data
    collector (TODO confirm against data_collector).
    """

    # Column order is part of the model contract: predict() builds its
    # input vector in exactly this order.
    FEATURE_COLUMNS = [
        'home_position', 'home_points', 'home_won', 'home_draw', 'home_lost',
        'home_goals_for', 'home_goals_against', 'home_goal_diff',
        'away_position', 'away_points', 'away_won', 'away_draw', 'away_lost',
        'away_goals_for', 'away_goals_against', 'away_goal_diff',
        'position_diff', 'points_diff'
    ]

    def __init__(self, model_name: str = 'betting_model'):
        """
        Args:
            model_name: base filename for the model/scaler/history artifacts.
        """
        self.model_name = model_name
        self.model = None  # BettingNeuralNetwork, built in train()/load_model()
        self.scaler = StandardScaler()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.training_history: List[Dict] = []

    def prepare_data(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Extract (scaled features, labels) from a training dataframe.

        Note: this FITS the scaler on df, so it must only be called from
        train(); inference goes through the already-fitted scaler via
        transform() in predict().
        """
        X = df[self.FEATURE_COLUMNS].values
        y = df['result'].values
        X = self.scaler.fit_transform(X)
        return X, y

    @staticmethod
    def _cross_entropy_on_probs(criterion: nn.NLLLoss, probs: torch.Tensor,
                                labels: torch.Tensor) -> torch.Tensor:
        """Cross entropy for a model that already outputs softmax probabilities.

        BUGFIX: the network ends in Softmax, so nn.CrossEntropyLoss (which
        applies log_softmax internally) must not be used on its output —
        the original double-applied the normalization, squashing gradients.
        NLLLoss over log(probs) is the mathematically correct cross entropy;
        the clamp guards against log(0).
        """
        return criterion(torch.log(probs.clamp(min=1e-9)), labels)

    def _run_train_epoch(self, loader, criterion, optimizer) -> Tuple[float, float]:
        """One optimization pass over loader; returns (mean loss, accuracy %)."""
        self.model.train()
        total_loss, correct, total = 0.0, 0, 0
        for features, labels in loader:
            features, labels = features.to(self.device), labels.to(self.device)

            optimizer.zero_grad()
            outputs = self.model(features)
            loss = self._cross_entropy_on_probs(criterion, outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
        return total_loss / len(loader), 100 * correct / total

    def _run_eval_epoch(self, loader, criterion) -> Tuple[float, float]:
        """One no-grad pass over loader; returns (mean loss, accuracy %)."""
        self.model.eval()
        total_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for features, labels in loader:
                features, labels = features.to(self.device), labels.to(self.device)
                outputs = self.model(features)
                total_loss += self._cross_entropy_on_probs(criterion, outputs, labels).item()
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)
        return total_loss / len(loader), 100 * correct / total

    def train(self, df: pd.DataFrame, epochs: int = 100, batch_size: int = 32,
              learning_rate: float = 0.001, validation_split: float = 0.2):
        """
        Train the network on historical matches.

        Args:
            df: dataframe with FEATURE_COLUMNS plus a 'result' label column.
            epochs: number of full passes over the training split.
            batch_size: minibatch size for both loaders.
            learning_rate: initial Adam learning rate.
            validation_split: fraction of samples held out for validation.

        Returns:
            List of per-epoch dicts (train/val loss and accuracy).

        Side effects: checkpoints the best model (by validation accuracy)
        to disk via save_model() and leaves those best weights loaded in
        self.model when training finishes.
        """
        print(f"Preparando dati ({len(df)} samples)...")
        X, y = self.prepare_data(df)

        # Stratified split keeps the class balance in both sets.
        X_train, X_val, y_train, y_val = train_test_split(
            X, y, test_size=validation_split, random_state=42, stratify=y
        )

        train_loader = DataLoader(MatchDataset(X_train, y_train),
                                  batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(MatchDataset(X_val, y_val), batch_size=batch_size)

        input_size = X.shape[1]
        self.model = BettingNeuralNetwork(input_size).to(self.device)

        # NLLLoss paired with log(probs) — see _cross_entropy_on_probs.
        criterion = nn.NLLLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.5)

        best_val_acc = 0.0
        best_state = None  # best weights kept in memory, mirrored on disk
        self.training_history = []

        print(f"\nTraining su {self.device}...")
        print(f"Train: {len(X_train)}, Validation: {len(X_val)}")
        print("-" * 50)

        for epoch in range(epochs):
            train_loss, train_acc = self._run_train_epoch(train_loader, criterion, optimizer)
            val_loss, val_acc = self._run_eval_epoch(val_loader, criterion)

            # Reduce LR when validation loss plateaus.
            scheduler.step(val_loss)

            self.training_history.append({
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'train_acc': train_acc,
                'val_loss': val_loss,
                'val_acc': val_acc
            })

            # Progress log every 10 epochs.
            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs} | "
                      f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
                      f"Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}%")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_state = {k: v.detach().clone()
                              for k, v in self.model.state_dict().items()}
                self.save_model()

        # BUGFIX: restore the best checkpoint so predictions made right after
        # train() use the same weights that were saved to disk (the original
        # silently left the last-epoch weights loaded).
        if best_state is not None:
            self.model.load_state_dict(best_state)

        print("-" * 50)
        print(f"Training completato! Best validation accuracy: {best_val_acc:.2f}%")

        return self.training_history

    def predict(self, match_data: Dict) -> Dict:
        """
        Predict the outcome of a single match.

        Args:
            match_data: dict keyed by FEATURE_COLUMNS. Missing keys silently
                default to 0 (kept for backward compatibility —
                NOTE(review): consider raising on missing features instead,
                since a zeroed feature is scaled to a misleading value).

        Returns:
            Dict with per-outcome probabilities (percent), the predicted
            label, a 0-100 confidence score and the raw probability vector.
        """
        if self.model is None:
            self.load_model()

        self.model.eval()

        # Build the feature vector in the canonical column order.
        features = np.array([match_data.get(col, 0) for col in self.FEATURE_COLUMNS]).reshape(1, -1)
        features = self.scaler.transform(features)
        batch = torch.FloatTensor(features).to(self.device)

        with torch.no_grad():
            probs = self.model(batch).cpu().numpy()[0]

        # Map the margin over a uniform 3-way guess (~33%) onto 0-100%.
        max_prob = float(probs.max())
        confidence = (max_prob - 0.33) / 0.67 * 100

        # Index -> outcome mapping assumed: 0=Away Win, 1=Draw, 2=Home Win.
        return {
            'home_win_prob': float(probs[2]) * 100,
            'draw_prob': float(probs[1]) * 100,
            'away_win_prob': float(probs[0]) * 100,
            'prediction': ['Away Win', 'Draw', 'Home Win'][int(np.argmax(probs))],
            'confidence': float(max(0, min(100, confidence))),
            'raw_probs': probs.tolist()
        }

    def predict_batch(self, matches: List[Dict]) -> List[Dict]:
        """Predict a list of matches.

        Runs one forward pass per match (simple, not vectorized); each
        result dict additionally carries the input under the 'match' key.
        """
        predictions = []
        for match in matches:
            pred = self.predict(match)
            pred['match'] = match
            predictions.append(pred)
        return predictions

    def save_model(self):
        """Persist model weights, scaler and training history under MODELS_DIR.

        Requires a trained model in self.model.
        """
        model_path = os.path.join(MODELS_DIR, f'{self.model_name}.pth')
        scaler_path = os.path.join(MODELS_DIR, f'{self.model_name}_scaler.pkl')
        history_path = os.path.join(MODELS_DIR, f'{self.model_name}_history.pkl')

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'input_size': self.model.input_size,
            'timestamp': datetime.now().isoformat()
        }, model_path)

        # The scaler must travel with the weights: predictions are only
        # valid with the statistics the model was trained against.
        joblib.dump(self.scaler, scaler_path)
        joblib.dump(self.training_history, history_path)

        print(f"Modello salvato in {model_path}")

    def load_model(self):
        """Load the model weights and scaler saved by save_model().

        Raises:
            FileNotFoundError: if no checkpoint exists for this model name.
        """
        model_path = os.path.join(MODELS_DIR, f'{self.model_name}.pth')
        scaler_path = os.path.join(MODELS_DIR, f'{self.model_name}_scaler.pkl')

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Modello non trovato: {model_path}")

        # NOTE(review): torch.load / joblib.load unpickle arbitrary objects —
        # only load checkpoints from trusted locations.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model = BettingNeuralNetwork(checkpoint['input_size']).to(self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        self.scaler = joblib.load(scaler_path)

        print(f"Modello caricato (trained: {checkpoint.get('timestamp', 'unknown')})")

    def get_model_stats(self) -> Dict:
        """Summarize the persisted training history.

        Returns:
            Dict with epoch counts and train/val metrics, or
            {'status': 'no_training_data'} when no history file exists.
        """
        history_path = os.path.join(MODELS_DIR, f'{self.model_name}_history.pkl')

        if os.path.exists(history_path):
            history = joblib.load(history_path)
            if history:
                last = history[-1]
                return {
                    'epochs_trained': len(history),
                    'final_train_acc': last['train_acc'],
                    'final_val_acc': last['val_acc'],
                    'best_val_acc': max(h['val_acc'] for h in history),
                    'last_train_loss': last['train_loss'],
                    'history': history
                }

        return {'status': 'no_training_data'}


if __name__ == "__main__":
    # Smoke-test: train on historical data, then run one sample prediction.
    # (BUGFIX: removed the duplicate `import os` — os is imported at module
    # level — and deferred collector construction to the branch that uses it.)
    import sys
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from data_collector import FootballDataCollector

    data_path = os.path.join(os.path.dirname(__file__), '..', 'data', 'historical_matches.csv')

    # Reuse the cached CSV when present; only hit the collector otherwise.
    if os.path.exists(data_path):
        df = pd.read_csv(data_path)
    else:
        collector = FootballDataCollector()
        df = collector.collect_historical_data(seasons=[2023, 2024])

    # Train the model.
    predictor = BettingPredictor()
    predictor.train(df, epochs=100)

    # Example feature dict: a top-3 home side hosting a mid-table away side.
    test_match = {
        'home_position': 3,
        'home_points': 45,
        'home_won': 14,
        'home_draw': 3,
        'home_lost': 5,
        'home_goals_for': 42,
        'home_goals_against': 20,
        'home_goal_diff': 22,
        'away_position': 8,
        'away_points': 32,
        'away_won': 9,
        'away_draw': 5,
        'away_lost': 8,
        'away_goals_for': 28,
        'away_goals_against': 25,
        'away_goal_diff': 3,
        'position_diff': -5,
        'points_diff': 13
    }

    print("\nTest predizione:")
    pred = predictor.predict(test_match)
    print(f"Predizione: {pred['prediction']}")
    print(f"Confidence: {pred['confidence']:.1f}%")
    print(f"Home Win: {pred['home_win_prob']:.1f}%")
    print(f"Draw: {pred['draw_prob']:.1f}%")
    print(f"Away Win: {pred['away_win_prob']:.1f}%")
