#!/usr/bin/env python3
"""
================================================================================
   VESUVIUS V8 - nnUNet TRAINING ON RUNPOD

   Strategy: Use nnUNet (SOTA for 3D medical segmentation)
   - Planned and preprocessed with the ResEncM planner (residual-encoder U-Net)
   - Full dataset training (fold=all)
   - 250 epochs optimal for competition

   GPU: A40 46GB VRAM - Can handle 3d_fullres easily
   Expected training time: 4-8 hours

   Target: Beat 0.575 (current #1)
================================================================================
"""

import os
import subprocess
import sys
from pathlib import Path
import json

# =============================================================================
# CONFIGURATION
# =============================================================================
# Base directory; the data and results trees both live underneath it.
WORKSPACE = Path("/workspace")
DATA_DIR = WORKSPACE / "nnunet_data"
RESULTS_DIR = WORKSPACE / "nnunet_results"

# Environment variables for nnUNet. They are set at import time so that every
# child process started by this script inherits them.
os.environ.update(
    {
        "nnUNet_raw": str(DATA_DIR / "nnUNet_raw"),
        "nnUNet_preprocessed": str(DATA_DIR / "nnUNet_preprocessed"),
        "nnUNet_results": str(RESULTS_DIR),
        # torch.compile is switched off for stability.
        "nnUNet_compile": "false",
        # Opt in to BLOSC2 (chosen by the author for better compression).
        "nnUNet_USE_BLOSC2": "1",
    }
)

# Dataset identity; the folder name embeds the zero-padded numeric id.
DATASET_ID = 100
DATASET_NAME = "Dataset{:03d}_VesuviusSurface".format(DATASET_ID)

# Training setup.
CONFIGURATION = "3d_fullres"
# PLANNER is only echoed in the training summary; the CLI receives PLANS_NAME.
PLANNER = "nnUNetPlannerResEncM"
PLANS_NAME = "nnUNetResEncUNetMPlans"
# "all" trains on the full dataset, i.e. without a held-out validation split.
FOLD = "all"
# Epoch budget; get_trainer_name() turns it into a trainer class name.
EPOCHS = 250

# =============================================================================
# HELPER FUNCTIONS
# =============================================================================
def run_cmd(cmd, desc=""):
    """Execute ``cmd`` through the shell after printing a banner.

    Args:
        cmd: Command line, handed to the shell as a single string.
        desc: Optional label for the banner; when empty, ``cmd`` is shown.

    Returns:
        ``True`` when the process exits with status 0. Otherwise the exit
        code is printed and ``False`` is returned.
    """
    rule = "=" * 60
    label = desc or cmd
    print(f"\n{rule}")
    print(f"[RUNNING] {label}")
    print(f"{rule}")

    # Output is not captured, so the child writes straight to our streams.
    completed = subprocess.run(cmd, shell=True, capture_output=False, text=True)
    exit_code = completed.returncode
    if exit_code == 0:
        return True

    print(f"[ERROR] Command failed with code {exit_code}")
    return False

def check_gpu():
    """Print GPU diagnostics and report whether PyTorch can use CUDA.

    ``nvidia-smi`` is run first purely for its human-readable output; its
    result never influences the return value. The actual decision comes from
    ``torch.cuda.is_available()``.

    Returns:
        bool: ``True`` when PyTorch reports a usable CUDA device. ``False``
        when it does not, or when PyTorch cannot be imported.
    """
    print("\n[GPU CHECK]")

    # Diagnostic only. An argv list avoids spawning a shell, and a missing
    # binary is reported instead of aborting the whole check.
    try:
        subprocess.run(["nvidia-smi"], check=False)
    except OSError as err:
        print(f"[WARNING] Could not run nvidia-smi: {err}")

    # Imported lazily so that a missing or broken PyTorch install is reported
    # as a failed check rather than crashing the script with a traceback.
    try:
        import torch
    except ImportError as err:
        print(f"ERROR: PyTorch could not be imported: {err}")
        return False

    if not torch.cuda.is_available():
        print("ERROR: No GPU available!")
        return False

    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"\nPyTorch CUDA: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {total_gb:.1f} GB")
    return True

def setup_directories():
    """Ensure the results folder exists and print the nnUNet folder layout.

    Only ``RESULTS_DIR`` is created here. The raw and preprocessed locations
    are read back from the environment and printed, not created.
    """
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    print("\nDirectories:")
    layout = (
        ("Raw", "nnUNet_raw"),
        ("Preprocessed", "nnUNet_preprocessed"),
        ("Results", "nnUNet_results"),
    )
    for label, env_key in layout:
        print(f"  {label}: {os.environ[env_key]}")

def check_data():
    """Check that the preprocessed dataset folder exists and looks usable.

    The folder ``$nnUNet_preprocessed/DATASET_NAME`` must exist. Inside it,
    ``dataset.json`` and the plans file that training will request
    (``PLANS_NAME.json``, because the train command passes ``-p PLANS_NAME``)
    are looked up. Missing files only produce warnings.

    Returns:
        bool: ``False`` when the dataset folder is missing (the folders that
        do exist are listed to help spot a naming mismatch), ``True``
        otherwise.
    """
    preprocessed_root = Path(os.environ["nnUNet_preprocessed"])
    dataset_dir = preprocessed_root / DATASET_NAME

    if not dataset_dir.exists():
        print(f"\n[ERROR] Preprocessed data not found: {dataset_dir}")
        print("\nAvailable folders:")
        for entry in sorted(preprocessed_root.glob("*")):
            print(f"  {entry}")
        return False

    if not (dataset_dir / "dataset.json").exists():
        print("[WARNING] Missing: dataset.json")

    # Validate the plans file the training command actually asks for. The
    # default nnUNetPlans.json is not a substitute when a different plans
    # identifier is passed via -p.
    plans_file = f"{PLANS_NAME}.json"
    if not (dataset_dir / plans_file).exists():
        print(f"[WARNING] Missing: {plans_file}")
        if (dataset_dir / "nnUNetPlans.json").exists():
            print(
                "[WARNING] Only nnUNetPlans.json is present; training is "
                f"configured to use plans '{PLANS_NAME}'"
            )

    print(f"\n[OK] Preprocessed data found: {dataset_dir}")
    return True

def get_trainer_name():
    """Map the configured ``EPOCHS`` to a trainer class name.

    Returns:
        str: ``"nnUNetTrainer"`` when ``EPOCHS`` is ``None`` or 1000,
        ``"nnUNetTrainer_1epoch"`` for a single epoch, and
        ``"nnUNetTrainer_<N>epochs"`` for any other value.
    """
    if EPOCHS is None or EPOCHS == 1000:
        return "nnUNetTrainer"

    # The single-epoch name uses the singular form.
    suffix = "1epoch" if EPOCHS == 1 else f"{EPOCHS}epochs"
    return f"nnUNetTrainer_{suffix}"

def train():
    """Launch ``nnUNetv2_train`` with the module-level configuration.

    The command is built as an argument list and executed without a shell.
    Its output goes straight to this process's stdout and stderr.

    Returns:
        bool: ``True`` when the training process exits with status 0.
        ``False`` when it exits non-zero or cannot be started at all
        (for example, the executable is not on ``PATH``).
    """
    # Local import: only needed to render the command line for display.
    import shlex

    trainer = get_trainer_name()

    # Build command
    argv = [
        "nnUNetv2_train",
        f"{DATASET_ID:03d}",
        CONFIGURATION,
        str(FOLD),
        "-p",
        PLANS_NAME,
        "-tr",
        trainer,
    ]
    cmd = " ".join(shlex.quote(arg) for arg in argv)

    print(f"\n{'='*60}")
    print("TRAINING CONFIGURATION")
    print(f"{'='*60}")
    print(f"  Dataset: {DATASET_NAME}")
    print(f"  Configuration: {CONFIGURATION}")
    print(f"  Planner: {PLANNER}")
    print(f"  Trainer: {trainer}")
    print(f"  Fold: {FOLD}")
    print(f"  Epochs: {EPOCHS}")
    print(f"\nCommand: {cmd}")
    print(f"{'='*60}\n")

    # Run training. Without a shell, a missing executable surfaces as
    # OSError instead of exit code 127, so report it as a failure here.
    try:
        result = subprocess.run(argv, check=False)
    except OSError as err:
        print(f"\n[ERROR] Could not launch nnUNetv2_train: {err}")
        return False

    if result.returncode != 0:
        print(f"\n[ERROR] Training failed with code {result.returncode}")
        return False

    print("\n[SUCCESS] Training complete!")
    return True

def find_best_checkpoint():
    """Locate the checkpoint to hand off after a training run.

    Looks in
    ``RESULTS_DIR/DATASET_NAME/<trainer>__<plans>__<configuration>/fold_<FOLD>``
    for files matching ``checkpoint_*.pth`` and prints each with its size.

    Returns:
        Path | None: ``checkpoint_best.pth`` when it exists, otherwise the
        alphabetically last match; ``None`` when nothing matches.
    """
    run_folder = f"{get_trainer_name()}__{PLANS_NAME}__{CONFIGURATION}"
    fold_dir = RESULTS_DIR / DATASET_NAME / run_folder / f"fold_{FOLD}"

    candidates = sorted(fold_dir.glob("checkpoint_*.pth"))
    if not candidates:
        print(f"[ERROR] No checkpoints found in {fold_dir}")
        return None

    print("\nCheckpoints found:")
    for path in candidates:
        megabytes = path.stat().st_size / 1e6
        print(f"  {path.name}: {megabytes:.1f} MB")

    preferred = fold_dir / "checkpoint_best.pth"
    return preferred if preferred.exists() else candidates[-1]

# =============================================================================
# MAIN
# =============================================================================
if __name__ == "__main__":
    # Script entry point: pre-flight checks, training, then checkpoint lookup.
    # Every failure path exits with status 1 so that callers chaining on the
    # exit code never mistake an incomplete run for a successful one.
    print("="*60)
    print("VESUVIUS V8 - nnUNet TRAINING")
    print("="*60)

    # Check GPU
    if not check_gpu():
        sys.exit(1)

    # Setup
    setup_directories()

    # Check data
    if not check_data():
        print("\n[INFO] Waiting for data download to complete...")
        print("Run this script again after download finishes.")
        sys.exit(1)

    # Train
    if not train():
        print("\n[FAILED] Training did not complete successfully")
        sys.exit(1)

    checkpoint = find_best_checkpoint()
    if checkpoint is None:
        # find_best_checkpoint() has already printed where it looked.
        # Training without a usable checkpoint is not a success.
        sys.exit(1)

    print(f"\n{'='*60}")
    print("TRAINING COMPLETE!")
    print(f"{'='*60}")
    print(f"Best checkpoint: {checkpoint}")
    print("\nNext steps:")
    print("1. Download checkpoint to local machine")
    print("2. Upload to Kaggle as dataset")
    print("3. Run inference kernel")
