"""
VESUVIUS V5 WINNER - Training + Topology Post-Processing
=========================================================
Target: Beat 0.575 (current #1) to win $100K

Strategy:
1. nnUNet training (SOTA for 3D medical segmentation)
2. Topology-aware post-processing
3. Connected components filtering
4. Morphological operations for gap filling

Based on analysis of top notebooks:
- jirkaborovec/surface-nnunet-training-inference-with-2xt4
- mayukh18/pytorch-cascaded-unet-inference
- jirkaborovec/surface-train-inference-3d-segm-gpu-augment
"""

import os
import json
import shutil
import subprocess
import zipfile
from pathlib import Path
from typing import Optional, Tuple, List, Union
from functools import partial
from multiprocessing import Pool

# =============================================================================
# CONFIGURATION
# =============================================================================

# Competition data (Kaggle mounts) plus scratch and output locations
INPUT_DIR = Path("/kaggle/input/vesuvius-challenge-surface-detection")
WORKING_DIR = Path("/kaggle/temp")  # intermediate nnUNet raw/preprocessed data
OUTPUT_DIR = Path("/kaggle/working")  # trained models, final predictions, zip

# PRE-PREPARED DATA (saves 1-2 hours!)
PREPARED_PREPROCESSED_PATH, PREPARED_NPZ_PATH = (
    Path("/kaggle/input/vesuvius-surface-nnunet-preprocessed"),
    Path("/kaggle/input/vesuvius-surface-npz"),
)

# nnUNet folder layout: raw + preprocessed under WORKING_DIR, results under OUTPUT_DIR
NNUNET_BASE = WORKING_DIR.joinpath("nnUNet_data")
NNUNET_RAW = NNUNET_BASE.joinpath("nnUNet_raw")
NNUNET_PREPROCESSED = NNUNET_BASE.joinpath("nnUNet_preprocessed")
NNUNET_RESULTS = OUTPUT_DIR.joinpath("nnUNet_results")

DATASET_ID = 100
DATASET_NAME = "Dataset{:03d}_VesuviusSurface".format(DATASET_ID)

# Training configuration
FOLD = "all"  # nnUNet's "all" fold: train on every case
CONFIGURATION = "3d_fullres"  # full-resolution 3D U-Net configuration
PLANNER = "nnUNetPlannerResEncM"  # residual-encoder (ResEnc M) planner
PLANS_NAME = "nnUNetResEncUNetMPlans"  # plans identifier matching PLANNER's output
EPOCHS = 250  # trades training time against quality
NUM_WORKERS = os.cpu_count() or 4  # fall back to 4 when the CPU count is unknown

# Post-processing configuration consumed by topology_aware_postprocess
POST_PROCESS_CONFIG = dict(
    min_component_size=3000,  # components below this size are discarded
    closing_radius=3,  # dilation/erosion iterations used to bridge small gaps
    border_cleanup=5,  # thickness of the frame zeroed on every volume face
    threshold=0.5,  # probability cut-off for the surface class
)

# =============================================================================
# INSTALLATION
# =============================================================================

def install_packages():
    """Install the third-party packages this pipeline imports lazily.

    pip is invoked through the running interpreter (``sys.executable -m pip``)
    so the packages land in the same environment that executes this script.
    Installation is best-effort: a non-zero pip exit is reported as a warning
    instead of being silently followed by a success message.
    """
    import sys

    packages = ["nnunetv2", "nibabel", "tifffile", "tqdm", "scipy"]
    print("Installing packages...")
    result = subprocess.run([sys.executable, "-m", "pip", "install", "-q", *packages])
    if result.returncode != 0:
        print(f"WARNING: pip exited with code {result.returncode}; some packages may be missing.")
        return
    print("Packages installed.")

# =============================================================================
# ENVIRONMENT SETUP
# =============================================================================

def setup_environment():
    """Create the nnUNet directory tree and export the nnUNet path variables.

    Ensures the raw, preprocessed, results and output folders exist, then
    sets ``nnUNet_raw`` / ``nnUNet_preprocessed`` / ``nnUNet_results`` to
    them, plus two runtime flags, and echoes the three paths.
    """
    for folder in (NNUNET_RAW, NNUNET_PREPROCESSED, NNUNET_RESULTS, OUTPUT_DIR):
        folder.mkdir(parents=True, exist_ok=True)

    path_vars = {
        "nnUNet_raw": NNUNET_RAW,
        "nnUNet_preprocessed": NNUNET_PREPROCESSED,
        "nnUNet_results": NNUNET_RESULTS,
    }
    os.environ.update({var: str(location) for var, location in path_vars.items()})
    # NOTE(review): presumably disables torch.compile inside nnUNet - confirm.
    os.environ["nnUNet_compile"] = "false"
    # NOTE(review): presumably selects blosc2 storage for preprocessed data - confirm.
    os.environ["nnUNet_USE_BLOSC2"] = "1"

    for var, location in path_vars.items():
        print(f"{var}: {location}")


def link_prepared_preprocessed() -> bool:
    """
    Link pre-prepared preprocessed data if available.
    This saves 1-2 hours of preprocessing time!

    The source is a ``Dataset*`` sub-folder of ``PREPARED_PREPROCESSED_PATH``.
    An exact ``DATASET_NAME`` match is preferred; otherwise the first folder in
    sorted order is used. If there is no such folder, the root itself is used
    when it holds a ``dataset.json``. Array files (``*.npz``, ``*.npy``,
    ``*.b2nd``) are symlinked into ``NNUNET_PREPROCESSED / DATASET_NAME``, and
    every other file is copied.

    Returns:
        True if the target folder already existed or was populated from the
        prepared data. False if no usable prepared data was found, including
        when it lacks the ``{PLANS_NAME}.json`` plans file that training with
        ``-p PLANS_NAME`` requires. In that case the caller can preprocess
        from scratch instead.
    """
    if not PREPARED_PREPROCESSED_PATH.exists():
        print(f"No pre-prepared data at {PREPARED_PREPROCESSED_PATH}")
        return False

    source_dir = PREPARED_PREPROCESSED_PATH
    # Only directories qualify, and sorting makes the pick deterministic
    # (glob order is filesystem-dependent).
    dataset_folders = sorted(p for p in source_dir.glob("Dataset*") if p.is_dir())
    if dataset_folders:
        exact_matches = [p for p in dataset_folders if p.name == DATASET_NAME]
        source_dir = exact_matches[0] if exact_matches else dataset_folders[0]
        if len(dataset_folders) > 1:
            print(f"Multiple prepared datasets found; using {source_dir.name}")
    elif not (source_dir / "dataset.json").exists():
        print(f"No dataset.json found in {source_dir}")
        return False

    target_dir = NNUNET_PREPROCESSED / DATASET_NAME

    if target_dir.exists():
        print(f"Preprocessed data already exists: {target_dir}")
        return True

    # Training is launched with `-p PLANS_NAME`. Without the matching plans file
    # the linked data is unusable, so report failure and let the caller preprocess.
    plans_file = source_dir / f"{PLANS_NAME}.json"
    if not plans_file.exists():
        print(f"Prepared data in {source_dir} has no {plans_file.name}; not linking it")
        return False

    print(f"Linking preprocessed data from: {source_dir}")
    target_dir.mkdir(parents=True, exist_ok=True)

    # Large array files are symlinked to avoid duplicating them. Everything else
    # (json/pkl/txt metadata) is copied.
    symlink_patterns = ('*.npz', '*.npy', '*.b2nd')

    copied = 0
    linked = 0

    for src_path in source_dir.rglob('*'):
        if src_path.is_dir():
            continue

        dst_path = target_dir / src_path.relative_to(source_dir)
        if dst_path.exists():
            continue
        dst_path.parent.mkdir(parents=True, exist_ok=True)

        if any(src_path.match(pat) for pat in symlink_patterns):
            dst_path.symlink_to(src_path.resolve())
            linked += 1
        else:
            shutil.copy2(src_path, dst_path)
            copied += 1

    print(f"Prepared: {copied} files copied, {linked} files symlinked")
    return True

# =============================================================================
# DATA PREPARATION
# =============================================================================

def create_spacing_json(output_path: Path, shape: tuple, spacing: tuple = (1.0, 1.0, 1.0)):
    """Write the JSON sidecar that carries voxel spacing for a TIFF volume.

    The file contains a single ``"spacing"`` list. ``shape`` is accepted for
    interface compatibility only; it is not written.
    """
    payload = json.dumps({"spacing": [*spacing]})
    with open(output_path, "w") as handle:
        handle.write(payload)


def create_dataset_json(output_dir: Path, num_training: int, file_ending: str = ".tif") -> dict:
    """Write ``output_dir/dataset.json`` describing the nnUNet raw dataset.

    Declares a single ``CT`` input channel and the label map
    background=0, surface=1, ignore=2. It also selects the ``SimpleTiffIO``
    reader/writer and records ``num_training`` and ``file_ending``.

    Returns:
        The dictionary that was serialized to disk.
    """
    label_map = {"background": 0, "surface": 1, "ignore": 2}
    dataset_json = dict(
        channel_names={"0": "CT"},
        labels=label_map,
        numTraining=num_training,
        file_ending=file_ending,
        overwrite_image_reader_writer="SimpleTiffIO",
    )

    json_path = output_dir / "dataset.json"
    json_path.write_text(json.dumps(dataset_json, indent=4))

    summary = (
        f"Created {json_path}",
        f"  - {num_training} training cases",
        "  - Labels: background(0), surface(1), ignore(2)",
    )
    for line in summary:
        print(line)

    return dataset_json


def prepare_single_case(src_path: Path, dest_path: Path, json_path: Path) -> bool:
    """Expose one source TIFF to nnUNet as a symlink plus a spacing sidecar.

    The TIFF is opened with ``tifffile`` to read its shape, which also acts as
    a readability check. ``dest_path`` is symlinked to the resolved
    ``src_path`` unless it already exists. The spacing JSON is then written
    to ``json_path``.

    NOTE(review): a dangling symlink at ``dest_path`` reports
    ``exists() == False``, so ``symlink_to`` raises and the case is reported
    as failed.

    Returns:
        True on success. False if any step raised; the error is printed.
    """
    import tifffile

    try:
        with tifffile.TiffFile(src_path) as tiff:
            pages = tiff.pages
            page_shape = pages[0].shape
            shape = page_shape if len(pages) == 1 else (len(pages), *page_shape)

        if not dest_path.exists():
            dest_path.symlink_to(src_path.resolve())

        create_spacing_json(json_path, shape)
    except Exception as exc:
        print(f"Error processing {src_path.name}: {exc}")
        return False
    return True


def prepare_dataset(input_dir: Path, max_cases: Optional[int] = None):
    """Convert competition data to nnUNet raw format.

    For every ``train_images/*.tif`` that has a same-named file in
    ``train_labels``, this creates ``imagesTr/<case>_0000.tif`` and
    ``labelsTr/<case>.tif`` symlinks with spacing sidecars. It then writes
    ``dataset.json``, whose ``numTraining`` counts only fully prepared cases.

    Args:
        input_dir: Competition data root (contains ``train_images`` and
            ``train_labels``).
        max_cases: If given, only the first ``max_cases`` images (sorted by
            name) are considered. ``0`` means none.

    Returns:
        Path of the nnUNet raw dataset folder.
    """
    from tqdm import tqdm

    dataset_dir = NNUNET_RAW / DATASET_NAME
    images_dir = dataset_dir / "imagesTr"
    labels_dir = dataset_dir / "labelsTr"

    images_dir.mkdir(parents=True, exist_ok=True)
    labels_dir.mkdir(parents=True, exist_ok=True)

    train_images_dir = input_dir / "train_images"
    train_labels_dir = input_dir / "train_labels"

    image_files = sorted(train_images_dir.glob("*.tif"))
    if max_cases is not None:
        image_files = image_files[:max_cases]

    print(f"Found {len(image_files)} training cases")

    num_converted = 0
    num_missing_label = 0
    num_failed = 0
    for img_path in tqdm(image_files, desc="Preparing dataset"):
        case_id = img_path.stem
        label_path = train_labels_dir / img_path.name

        if not label_path.exists():
            num_missing_label += 1
            continue

        img_dest = images_dir / f"{case_id}_0000.tif"
        img_json = images_dir / f"{case_id}_0000.json"
        label_dest = labels_dir / f"{case_id}.tif"
        label_json = labels_dir / f"{case_id}.json"

        # Both are attempted (no short-circuit) so each failure gets logged.
        img_ok = prepare_single_case(img_path, img_dest, img_json)
        label_ok = prepare_single_case(label_path, label_dest, label_json)

        if img_ok and label_ok:
            num_converted += 1
            continue

        # A half-prepared case (image without label or vice versa) would leave
        # imagesTr/labelsTr out of sync with numTraining, so remove it.
        # unlink() on a symlink removes only the link, never the source data.
        for partial in (img_dest, img_json, label_dest, label_json):
            if partial.is_symlink() or partial.exists():
                partial.unlink()
        num_failed += 1

    if num_missing_label or num_failed:
        print(f"Skipped {num_missing_label} cases without labels, {num_failed} failed cases")

    create_dataset_json(dataset_dir, num_converted)
    print(f"Dataset prepared: {num_converted} cases at {dataset_dir}")
    return dataset_dir

# =============================================================================
# NNUNET COMMANDS
# =============================================================================

def run_command(cmd: str, name: str = "Command") -> bool:
    """
    Execute a shell command and report whether it succeeded.

    Output is captured rather than streamed. On success the last 20 lines of
    stdout are echoed. On failure the exit code plus the tails of both stdout
    and stderr are echoed, because a tool may report the cause on either
    stream.

    Args:
        cmd: Command line run through the shell (``shell=True``). Only pass
            trusted strings; callers here build it from module constants.
        name: Label used in the status messages.

    Returns:
        True if the command exited with status 0, False otherwise.
    """
    print(f"Running: {cmd}")
    # errors="replace": undecodable bytes in tool output must not crash the
    # pipeline with a UnicodeDecodeError.
    result = subprocess.run(
        cmd, shell=True, capture_output=True, text=True, errors="replace"
    )
    stdout = result.stdout.strip()

    if result.returncode != 0:
        print(f"{name} FAILED!")
        print(f"Exit code: {result.returncode}")
        if stdout:
            print(f"STDOUT: {stdout[-2000:]}")
        print(f"STDERR: {result.stderr[-2000:]}")
        return False

    print(f"{name} complete!")
    if stdout:
        print('\n'.join(stdout.split('\n')[-20:]))

    return True


def run_preprocessing():
    """Fingerprint, plan and preprocess the raw dataset with nnUNet.

    Runs ``nnUNetv2_plan_and_preprocess`` for ``DATASET_ID`` using
    ``NUM_WORKERS`` processes, the ``PLANNER`` planner and only the
    ``CONFIGURATION`` configuration.

    Returns:
        True if the command succeeded (see ``run_command``).
    """
    args = [
        "nnUNetv2_plan_and_preprocess",
        f"-d {DATASET_ID:03d}",
        f"-np {NUM_WORKERS}",
        f"-pl {PLANNER}",
        f"-c {CONFIGURATION}",
    ]
    return run_command(" ".join(args), "Preprocessing")


def run_training():
    """Train the nnUNet model for ``FOLD`` with the configured plans.

    ``EPOCHS == 1000`` uses the stock ``nnUNetTrainer``. Any other value
    selects ``nnUNetTrainer_<EPOCHS>epochs``; NOTE(review): that trainer
    class must exist in the installed nnunetv2. When more than one CUDA
    device is visible, ``-num_gpus`` is appended so every GPU is used.

    Returns:
        True if training exited successfully (see ``run_command``).
    """
    import torch

    trainer = "nnUNetTrainer" if EPOCHS == 1000 else f"nnUNetTrainer_{EPOCHS}epochs"
    parts = [
        "nnUNetv2_train",
        f"{DATASET_ID:03d}",
        f"{CONFIGURATION}",
        f"{FOLD}",
        "-p", PLANS_NAME,
        "-tr", trainer,
    ]

    gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
    if gpu_count > 1:
        parts += ["-num_gpus", str(gpu_count)]

    return run_command(" ".join(parts), f"Training ({EPOCHS} epochs)")


def run_inference(input_dir: Path, output_dir: Path):
    """Predict every case in ``input_dir`` with the trained model.

    ``output_dir`` is created first. The trainer name follows the same rule
    as training. ``--save_probabilities`` makes nnUNet write per-case
    probability ``.npz`` files next to the segmentations, and these are what
    post-processing consumes.

    Returns:
        True if prediction succeeded (see ``run_command``).
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    trainer = "nnUNetTrainer" if EPOCHS == 1000 else f"nnUNetTrainer_{EPOCHS}epochs"
    cmd = " ".join([
        "nnUNetv2_predict",
        "-d", f"{DATASET_ID:03d}",
        "-c", f"{CONFIGURATION}",
        "-f", f"{FOLD}",
        "-i", f"{input_dir}",
        "-o", f"{output_dir}",
        "-p", PLANS_NAME,
        "-tr", trainer,
        "--save_probabilities",
        "--verbose",
    ])

    return run_command(cmd, "Inference")

# =============================================================================
# POST-PROCESSING (CRITICAL FOR TOPOLOGY SCORE!)
# =============================================================================

def topology_aware_postprocess(prediction: "np.ndarray", config: dict) -> "np.ndarray":
    """
    Apply topology-aware post-processing to improve surface connectivity.

    This is CRITICAL for the topology-aware metric which penalizes:
    - Gaps in surfaces
    - Holes
    - Sheet-switches
    - Mergers

    Steps (each is skipped when its config value is 0):
    1. Drop 26-connected components with fewer than ``min_component_size``
       foreground voxels.
    2. Morphological closing with ``closing_radius`` iterations of a
       6-connected structuring element. The volume is zero-padded first, so
       the closing never removes original foreground near the volume edges.
    3. Zero a frame ``border_cleanup`` voxels thick on every face.

    Args:
        prediction: 3D prediction array (D, H, W). Any non-zero voxel is
            foreground. The input array is never modified.
        config: Post-processing configuration (``min_component_size``,
            ``closing_radius``, ``border_cleanup``).

    Returns:
        A new uint8 array of 0/1 values with the same shape as ``prediction``.
    """
    import numpy as np
    from scipy import ndimage

    # Work on our own boolean mask: the caller's array stays untouched and
    # any non-zero encoding (1, 255, ...) is handled the same way.
    mask = np.asarray(prediction) != 0

    # 1. Connected Components Filtering
    # Remove small isolated components (noise)
    min_size = config.get("min_component_size", 3000)
    structure = ndimage.generate_binary_structure(3, 3)  # 26-connectivity
    labeled_array, num_components = ndimage.label(mask, structure=structure)

    if num_components > 0:
        # Voxel count per label (index 0 = background), not a sum of values.
        component_sizes = np.bincount(labeled_array.ravel(), minlength=num_components + 1)
        keep = component_sizes >= min_size
        keep[0] = False
        num_kept = int(keep.sum())
        # Lookup table: one vectorized pass instead of one full-volume
        # comparison per component.
        mask = keep[labeled_array]

        removed = num_components - num_kept
        print(f"  CC Filter: {num_components} -> {num_kept} components ({removed} removed)")

    # 2. Morphological Closing (fill small gaps)
    # NOTE(review): a large radius can also bridge two distinct but nearby
    # sheets, which the metric counts as a merger. Tune on validation data.
    closing_radius = config.get("closing_radius", 3)
    if closing_radius > 0:
        struct_close = ndimage.generate_binary_structure(3, 1)
        # binary_erosion treats out-of-bounds as background, so without padding
        # the closing would strip foreground touching the volume faces.
        padded = np.pad(mask, closing_radius, mode="constant", constant_values=False)
        padded = ndimage.binary_dilation(padded, struct_close, iterations=closing_radius)
        padded = ndimage.binary_erosion(padded, struct_close, iterations=closing_radius)
        crop = slice(closing_radius, -closing_radius)
        mask = padded[crop, crop, crop]
        print(f"  Morphological closing with radius {closing_radius}")

    # 3. Border Cleanup
    # Remove artifacts at volume edges
    border = config.get("border_cleanup", 5)
    if border > 0:
        mask[:border, :, :] = False
        mask[-border:, :, :] = False
        mask[:, :border, :] = False
        mask[:, -border:, :] = False
        mask[:, :, :border] = False
        mask[:, :, -border:] = False
        print(f"  Border cleanup: {border} pixels")

    return mask.astype(np.uint8)


def process_predictions(pred_dir: Path, output_dir: Path, config: dict):
    """Convert nnUNet predictions to submission TIFFs with post-processing.

    The ``*.npz`` probability files in ``pred_dir`` are preferred. Channel
    ``probabilities[1]`` is thresholded at ``config["threshold"]``
    (default 0.5). If no ``.npz`` exists, the ``*.tif`` segmentations are
    used instead. Every mask goes through ``topology_aware_postprocess`` and
    is written to ``output_dir/<case>.tif``.

    Args:
        pred_dir: Folder holding raw nnUNet predictions.
        output_dir: Destination folder (created if missing).
        config: Post-processing configuration, also forwarded to
            ``topology_aware_postprocess``.
    """
    import numpy as np
    import tifffile
    from tqdm import tqdm

    output_dir.mkdir(parents=True, exist_ok=True)

    # Sorted so cases are processed (and logged) in a stable order.
    npz_files = sorted(pred_dir.glob("*.npz"))
    tif_files = sorted(pred_dir.glob("*.tif"))

    source_files = npz_files if npz_files else tif_files
    print(f"Processing {len(source_files)} prediction files...")
    if not source_files:
        print(f"WARNING: no .npz or .tif predictions found in {pred_dir}")
        return

    threshold = config.get("threshold", 0.5)
    for src_path in tqdm(source_files, desc="Post-processing"):
        case_id = src_path.stem

        if src_path.suffix == ".npz":
            # `with` closes the NpzFile's underlying zip handle. Indexing reads
            # the array fully into memory, so `probs` stays valid afterwards.
            with np.load(src_path) as data:
                probs = data['probabilities']
            # Class 1 is surface
            pred = (probs[1] > threshold).astype(np.uint8)
            del probs  # free the probability volume before post-processing
        else:
            pred = tifffile.imread(str(src_path)).astype(np.uint8)

        print(f"\n{case_id}: shape {pred.shape}, voxels before: {pred.sum():,}")

        # Apply topology-aware post-processing
        pred = topology_aware_postprocess(pred, config)

        print(f"  Voxels after: {pred.sum():,}")

        tifffile.imwrite(output_dir / f"{case_id}.tif", pred)

    print(f"\nPredictions saved to: {output_dir}")

# =============================================================================
# TEST DATA PREPARATION
# =============================================================================

def prepare_test_data(input_dir: Path, output_dir: Path) -> Path:
    """Prepare test TIFF images for nnUNet inference.

    Each ``test_images/<case>.tif`` is exposed in ``output_dir`` as
    ``<case>_0000.tif`` (symlink) plus a spacing sidecar. Cases that fail are
    reported by name, because they will not be predicted and therefore will
    be missing from the submission.

    Args:
        input_dir: Competition data root (contains ``test_images``).
        output_dir: nnUNet inference input folder (created if missing).

    Returns:
        ``output_dir``.
    """
    from tqdm import tqdm

    output_dir.mkdir(parents=True, exist_ok=True)
    test_images_dir = input_dir / "test_images"

    test_files = sorted(test_images_dir.glob("*.tif"))
    print(f"Found {len(test_files)} test cases")

    failed_cases = []
    for img_path in tqdm(test_files, desc="Preparing test data"):
        case_id = img_path.stem
        ok = prepare_single_case(
            img_path,
            output_dir / f"{case_id}_0000.tif",
            output_dir / f"{case_id}_0000.json"
        )
        if not ok:
            failed_cases.append(case_id)

    if failed_cases:
        print(
            f"WARNING: {len(failed_cases)} test case(s) could not be prepared "
            f"and will not be predicted: {', '.join(failed_cases)}"
        )

    return output_dir

# =============================================================================
# SUBMISSION
# =============================================================================

def generate_submission(predictions_dir: Path, output_zip: Path) -> Optional[Path]:
    """Create the submission ZIP from the post-processed predictions.

    Every ``*.tif`` in ``predictions_dir`` is stored (deflate-compressed) at
    the archive root under its file name. The parent folder of
    ``output_zip`` is created if needed.

    Returns:
        ``output_zip`` on success, or None when there is nothing to zip.
    """
    tiff_files = sorted(predictions_dir.glob("*.tif"))

    if not tiff_files:
        print(f"No TIFF files found in {predictions_dir}")
        return None

    # Imported only once there is work to do; tifffile is not needed to zip.
    from tqdm import tqdm

    output_zip.parent.mkdir(parents=True, exist_ok=True)
    print(f"Creating submission ZIP with {len(tiff_files)} files...")

    with zipfile.ZipFile(output_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for tiff_path in tqdm(tiff_files, desc="Zipping"):
            zipf.write(tiff_path, tiff_path.name)

    zip_size_mb = output_zip.stat().st_size / (1024 * 1024)
    print(f"Submission saved: {output_zip} ({zip_size_mb:.1f} MB)")

    return output_zip

# =============================================================================
# MAIN PIPELINE
# =============================================================================

def full_pipeline(
    do_preprocess: bool = True,
    do_train: bool = True,
    do_inference: bool = True,
    max_cases: Optional[int] = None
):
    """
    Run complete V5 WINNER pipeline.

    Stages:
    1. Install packages and set up the nnUNet environment.
    2. Prepare the raw dataset.
    3. Preprocess, or link pre-prepared data.
    4. Train.
    5. Predict the test cases.
    6. Post-process and zip the submission.

    Args:
        do_preprocess: Run (or link) nnUNet preprocessing. When False, linking
            pre-prepared data is still attempted.
        do_train: Train the model.
        do_inference: Predict test cases, post-process and build
            ``OUTPUT_DIR/submission.zip``.
        max_cases: Limit the number of training cases prepared (e.g. debugging).

    Returns:
        True if every requested stage succeeded, False at the first failure
        (preprocessing, training, inference or submission generation).
    """
    print("=" * 70)
    print("VESUVIUS V5 WINNER - Target: Beat 0.575 to win $100K!")
    print("=" * 70)
    print(f"Config: {CONFIGURATION}, Fold: {FOLD}, Epochs: {EPOCHS}")
    print(f"Post-processing: {POST_PROCESS_CONFIG}")
    print()

    # Install packages
    install_packages()

    # Setup environment
    print("\n[1/6] Setting up environment...")
    setup_environment()

    # Prepare raw data
    print("\n[2/6] Preparing training data...")
    raw_dataset_dir = NNUNET_RAW / DATASET_NAME
    # prepare_dataset writes dataset.json as its last step, so that file (not
    # the folder) marks a finished preparation. A run interrupted midway
    # leaves the folder without it and is simply redone.
    if not (raw_dataset_dir / "dataset.json").exists():
        prepare_dataset(INPUT_DIR, max_cases=max_cases)
    else:
        print(f"Dataset already exists: {raw_dataset_dir}")

    # Preprocessing - try to use pre-prepared data first
    if do_preprocess:
        print("\n[3/6] Preprocessing...")
        if link_prepared_preprocessed():
            print("Using pre-prepared preprocessed data!")
        else:
            print("Running preprocessing from scratch (this may take 1-2 hours)...")
            if not run_preprocessing():
                print("Preprocessing failed!")
                return False
    else:
        print("\n[3/6] Skipping preprocessing...")
        link_prepared_preprocessed()  # Still try to link if available

    # Training
    if do_train:
        print(f"\n[4/6] Training nnUNet ({EPOCHS} epochs)...")
        if not run_training():
            print("Training failed!")
            return False
    else:
        print("\n[4/6] Skipping training...")

    # Inference
    if do_inference:
        print("\n[5/6] Running inference on test data...")

        test_input_dir = WORKING_DIR / "test_input"
        prepare_test_data(INPUT_DIR, test_input_dir)

        predictions_dir = WORKING_DIR / "predictions_raw"
        if not run_inference(test_input_dir, predictions_dir):
            print("Inference failed!")
            return False

        print("\n[6/6] Applying topology-aware post-processing...")
        final_predictions_dir = OUTPUT_DIR / "predictions_final"
        process_predictions(predictions_dir, final_predictions_dir, POST_PROCESS_CONFIG)

        submission_path = generate_submission(
            final_predictions_dir,
            OUTPUT_DIR / "submission.zip"
        )
        # Without a submission the run has not achieved its goal; do not
        # report success.
        if submission_path is None:
            print("Submission generation failed!")
            return False

        print(f"\n{'=' * 70}")
        print("V5 WINNER COMPLETE!")
        print(f"Submission: {submission_path}")
        print(f"{'=' * 70}")
    else:
        print("\n[5/6] Skipping inference...")
        print("\n[6/6] Skipping post-processing...")

    return True


# =============================================================================
# RUN
# =============================================================================

if __name__ == "__main__":
    # Run every stage end to end: preprocessing (or linking prepared data),
    # training, inference, post-processing and submission packaging.
    full_pipeline(do_preprocess=True, do_train=True, do_inference=True)
