#!/usr/bin/env python3
"""
Template Training per RunPod
Questo file va caricato sul pod ed eseguito
"""

import os
import subprocess
import sys
import torch
from datetime import datetime

# ============================================================
# CONFIGURATION - EDIT THESE PARAMETERS
# ============================================================

CONFIG = {
    # Model
    "model_name": "meta-llama/Llama-2-7b-hf",  # or your own model
    "use_lora": True,  # LoRA to save VRAM
    "load_in_4bit": True,  # 4-bit quantization

    # Dataset
    "dataset_name": "your-dataset",  # HuggingFace dataset name or local path
    "dataset_path": None,  # local path if not using HF

    # Training
    "epochs": 3,
    "batch_size": 4,
    "gradient_accumulation": 4,
    "learning_rate": 2e-4,
    "max_seq_length": 512,

    # LoRA config
    "lora_r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "lora_target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],

    # Output
    "output_dir": "/workspace/output",
    "save_steps": 100,
    "logging_steps": 10,
}
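
# Sanity note on the defaults above: the effective batch size is
# batch_size * gradient_accumulation = 4 * 4 = 16 sequences per optimizer step.
# If you hit out-of-memory errors, lower batch_size and raise
# gradient_accumulation to keep the product roughly constant.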

# ============================================================
# SETUP AND UTILITIES
# ============================================================

def setup_environment():
    """Set up the environment and verify the GPU."""
    print("="*60)
    print("ENVIRONMENT SETUP")
    print("="*60)

    # Check CUDA
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"✓ GPU: {gpu_name}")
        print(f"✓ VRAM: {gpu_memory:.1f} GB")
    else:
        print("✗ CUDA not available!")
        sys.exit(1)

    # Create the output directory
    os.makedirs(CONFIG["output_dir"], exist_ok=True)
    print(f"✓ Output dir: {CONFIG['output_dir']}")

    # HuggingFace token
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("✓ HuggingFace token found")
    else:
        print("! HF_TOKEN not set (may be required for gated models like Llama-2)")

    return True
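
# Note: recent versions of huggingface_hub read HF_TOKEN from the environment
# automatically. On older base images you may need to log in explicitly; a
# minimal sketch:
#
#     from huggingface_hub import login
#     login(token=os.getenv("HF_TOKEN"))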


def install_dependencies():
    """Install the required dependencies."""
    print("\nInstalling dependencies...")

    packages = [
        "transformers>=4.36.0",
        "datasets>=2.15.0",
        "accelerate>=0.25.0",
        "peft>=0.7.0",
        "bitsandbytes>=0.41.0",
        "trl>=0.7.0,<0.8.0",  # pinned: this script uses the 0.7.x SFTTrainer signature
        "wandb",  # optional, for experiment logging
    ]

    # Install via subprocess with an argument list rather than os.system: in a
    # shell, `pip install transformers>=4.36.0` would treat `>` as an output
    # redirection instead of part of the version specifier.
    for pkg in packages:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])

    print("✓ Dependencies installed")


# ============================================================
# TRAINING
# ============================================================

def load_model():
    """Load the model with memory optimizations."""
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

    print("\n" + "="*60)
    print("LOADING MODEL")
    print("="*60)

    # Quantization configuration
    bnb_config = None
    if CONFIG["load_in_4bit"]:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )

    # Load the tokenizer
    print(f"Loading tokenizer: {CONFIG['model_name']}")
    tokenizer = AutoTokenizer.from_pretrained(
        CONFIG["model_name"],
        trust_remote_code=True,
    )
    # Llama-style tokenizers ship without a pad token; reuse EOS for padding
    tokenizer.pad_token = tokenizer.eos_token

    # Load the model
    print(f"Loading model: {CONFIG['model_name']}")
    model = AutoModelForCausalLM.from_pretrained(
        CONFIG["model_name"],
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    # Apply LoRA
    if CONFIG["use_lora"]:
        print("Applying LoRA...")
        model = prepare_model_for_kbit_training(model)

        lora_config = LoraConfig(
            r=CONFIG["lora_r"],
            lora_alpha=CONFIG["lora_alpha"],
            lora_dropout=CONFIG["lora_dropout"],
            target_modules=CONFIG["lora_target_modules"],
            bias="none",
            task_type="CAUSAL_LM",
        )

        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    print("✓ Model loaded")
    return model, tokenizer
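
# Optional sanity check (a minimal sketch, not called by default): generate a
# handful of tokens to confirm the quantized model actually responds before
# committing to a multi-hour training run.
def smoke_test(model, tokenizer, prompt="Hello, world"):
    """Run one short generation as a smoke test."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output[0], skip_special_tokens=True))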


def load_training_dataset():
    """Load and prepare the dataset."""
    # Renamed from load_dataset so it does not shadow the datasets import
    from datasets import load_dataset

    print("\n" + "="*60)
    print("LOADING DATASET")
    print("="*60)

    if CONFIG["dataset_path"]:
        # Local dataset
        dataset = load_dataset("json", data_files=CONFIG["dataset_path"])
    else:
        # Dataset from the HuggingFace Hub
        dataset = load_dataset(CONFIG["dataset_name"])

    print(f"✓ Dataset loaded: {len(dataset['train'])} examples")

    return dataset
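
# The SFTTrainer below reads each example's training text from a "text" field,
# so a local JSONL file would look like this (hypothetical example):
#
#     {"text": "### Instruction: ...\n### Response: ..."}
#
# If your dataset uses different fields, add a formatting step after loading,
# for instance (assuming "prompt"/"completion" fields):
#
#     dataset = dataset.map(
#         lambda ex: {"text": f"{ex['prompt']}\n{ex['completion']}"}
#     )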


def train(model, tokenizer, dataset):
    """Run the training loop."""
    from transformers import TrainingArguments
    from trl import SFTTrainer

    print("\n" + "="*60)
    print("TRAINING")
    print("="*60)

    training_args = TrainingArguments(
        output_dir=CONFIG["output_dir"],
        num_train_epochs=CONFIG["epochs"],
        per_device_train_batch_size=CONFIG["batch_size"],
        gradient_accumulation_steps=CONFIG["gradient_accumulation"],
        learning_rate=CONFIG["learning_rate"],
        fp16=True,
        logging_steps=CONFIG["logging_steps"],
        save_steps=CONFIG["save_steps"],
        save_total_limit=3,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        report_to="none",  # or "wandb" for experiment tracking
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset["train"],
        tokenizer=tokenizer,
        args=training_args,
        max_seq_length=CONFIG["max_seq_length"],
        dataset_text_field="text",  # dataset field containing the training text
    )

    print("Starting training...")
    start_time = datetime.now()

    trainer.train()

    elapsed = datetime.now() - start_time
    print(f"\n✓ Training completed in {elapsed}")

    # Save the final model (with LoRA enabled this saves only the adapter)
    final_path = f"{CONFIG['output_dir']}/final"
    trainer.save_model(final_path)
    tokenizer.save_pretrained(final_path)
    print(f"✓ Model saved to: {final_path}")

    return trainer
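
# Optional post-processing sketch (assumes use_lora=True): trainer.save_model
# above stores only the LoRA adapter, so to serve the model without PEFT you
# can merge the adapter into a fresh fp16 copy of the base model.
def merge_lora(adapter_path, merged_path):
    """Merge a trained LoRA adapter into the base model (sketch)."""
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    # Reload the base in fp16: merging directly into 4-bit weights is lossy
    base = AutoModelForCausalLM.from_pretrained(
        CONFIG["model_name"],
        torch_dtype=torch.float16,
        device_map="auto",
    )
    merged = PeftModel.from_pretrained(base, adapter_path).merge_and_unload()
    merged.save_pretrained(merged_path)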


# ============================================================
# MAIN
# ============================================================

def main():
    print("\n" + "="*60)
    print("TRAINING SCRIPT - RunPod")
    print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("="*60)

    # Setup
    setup_environment()
    install_dependencies()

    # Load model and dataset
    model, tokenizer = load_model()
    dataset = load_training_dataset()

    # Training
    trainer = train(model, tokenizer, dataset)

    print("\n" + "="*60)
    print("DONE!")
    print("="*60)
    print(f"Model saved to: {CONFIG['output_dir']}/final")
    print("\nTo download the model:")
    print("  1. Compress it: tar -czf model.tar.gz /workspace/output/final")
    print("  2. Use runpodctl to transfer it off the pod")


if __name__ == "__main__":
    main()
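
# Typical usage on a pod (assuming this file was copied to /workspace/train.py):
#
#     HF_TOKEN=hf_... nohup python /workspace/train.py > /workspace/train.log 2>&1 &
#     tail -f /workspace/train.log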
