#!/usr/bin/env python3
"""
Prepare the data for nnUNet from the Vesuvius Surface Detection competition.
Uses the preprocessed dataset if available, otherwise prepares it from scratch.
"""

import os
import json
import shutil
from pathlib import Path
import tifffile
from tqdm import tqdm

# Configuration
DATASET_ID = 100
DATASET_NAME = f"Dataset{DATASET_ID:03d}_VesuviusSurface"

# Paths
WORKSPACE = Path("/workspace")
DATA_DIR = WORKSPACE / "data"
PREPROCESSED_DIR = WORKSPACE / "nnunet-preprocessed"

NNUNET_RAW = Path(os.environ.get("nnUNet_raw", "/workspace/nnUNet_raw"))
NNUNET_PREPROCESSED = Path(os.environ.get("nnUNet_preprocessed", "/workspace/nnUNet_preprocessed"))
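
# nnUNet v2 reads its data locations from the nnUNet_raw and nnUNet_preprocessed
# environment variables; we fall back to /workspace defaults if they are unset.
# To set them explicitly before running nnUNet commands (shell):
#   export nnUNet_raw=/workspace/nnUNet_raw
#   export nnUNet_preprocessed=/workspace/nnUNet_preprocessed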

def check_preprocessed():
    """Verifica se esiste dataset preprocessato."""
    if PREPROCESSED_DIR.exists():
        # Look for a Dataset* folder
        dataset_dirs = list(PREPROCESSED_DIR.glob("Dataset*"))
        if dataset_dirs:
            print(f"✅ Dataset preprocessato trovato: {dataset_dirs[0]}")
            return dataset_dirs[0]
    return None

def link_preprocessed(source_dir):
    """Collega il dataset preprocessato."""
    target_dir = NNUNET_PREPROCESSED / DATASET_NAME

    if target_dir.exists():
        print(f"Dataset già linkato: {target_dir}")
        return True

    print(f"Linking preprocessato da {source_dir} a {target_dir}")
    target_dir.mkdir(parents=True, exist_ok=True)

    # Copy or symlink every file
    for src_file in source_dir.rglob("*"):
        if src_file.is_file():
            rel_path = src_file.relative_to(source_dir)
            dst_file = target_dir / rel_path
            dst_file.parent.mkdir(parents=True, exist_ok=True)

            # Symlink large array files, copy everything else
            if src_file.suffix in ['.npz', '.npy', '.b2nd']:
                if not dst_file.exists():
                    dst_file.symlink_to(src_file.resolve())
            else:
                if not dst_file.exists():
                    shutil.copy2(src_file, dst_file)

    print("✅ Dataset preprocessato linkato!")
    return True

def prepare_raw_dataset():
    """Prepara il dataset raw per nnUNet."""
    dataset_dir = NNUNET_RAW / DATASET_NAME
    images_tr = dataset_dir / "imagesTr"
    labels_tr = dataset_dir / "labelsTr"

    images_tr.mkdir(parents=True, exist_ok=True)
    labels_tr.mkdir(parents=True, exist_ok=True)

    # Source
    train_images = DATA_DIR / "train_images"
    train_labels = DATA_DIR / "train_labels"

    if not train_images.exists():
        print(f"❌ Training images non trovate: {train_images}")
        return False

    image_files = sorted(train_images.glob("*.tif"))
    print(f"Trovati {len(image_files)} training cases")

    for img_path in tqdm(image_files, desc="Preparing raw data"):
        case_id = img_path.stem
        label_path = train_labels / img_path.name

        if not label_path.exists():
            tqdm.write(f"⚠️ Missing label for {case_id}, skipping")
            continue

        # Symlink image
        img_dst = images_tr / f"{case_id}_0000.tif"
        if not img_dst.exists():
            img_dst.symlink_to(img_path.resolve())

        # Symlink label
        lbl_dst = labels_tr / f"{case_id}.tif"
        if not lbl_dst.exists():
            lbl_dst.symlink_to(label_path.resolve())

    # Create dataset.json
    num_cases = len(list(images_tr.glob("*.tif")))
    dataset_json = {
        "channel_names": {"0": "CT"},
        "labels": {"background": 0, "surface": 1, "ignore": 2},
        "numTraining": num_cases,
        "file_ending": ".tif",
        "overwrite_image_reader_writer": "SimpleTiffIO"
    }

    with open(dataset_dir / "dataset.json", "w") as f:
        json.dump(dataset_json, f, indent=4)

    print(f"✅ Raw dataset preparato: {num_cases} cases")
    return True
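
# Optional sanity check (a minimal sketch, not called by default): spot-check
# that a few image/label pairs have matching array shapes before preprocessing.
# The file naming (<case>_0000.tif for images, <case>.tif for labels) follows
# the nnUNet raw-data convention used in prepare_raw_dataset().
def verify_case_shapes(images_tr, labels_tr, max_cases=5):
    """Compare image and label shapes for the first few prepared cases."""
    for img_path in sorted(images_tr.glob("*_0000.tif"))[:max_cases]:
        case_id = img_path.name[:-len("_0000.tif")]
        lbl_path = labels_tr / f"{case_id}.tif"
        if not lbl_path.exists():
            print(f"⚠️ Missing label for {case_id}")
            continue
        # Read shapes from TIFF metadata without loading the full volumes
        with tifffile.TiffFile(img_path) as img, tifffile.TiffFile(lbl_path) as lbl:
            img_shape = img.series[0].shape
            lbl_shape = lbl.series[0].shape
        if img_shape != lbl_shape:
            print(f"⚠️ Shape mismatch for {case_id}: {img_shape} vs {lbl_shape}")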

def prepare_test_data():
    """Prepara i dati di test."""
    test_input = WORKSPACE / "test_input"
    test_input.mkdir(parents=True, exist_ok=True)

    test_images = DATA_DIR / "test_images"
    if not test_images.exists():
        print(f"❌ Test images non trovate: {test_images}")
        return None

    for img_path in tqdm(sorted(test_images.glob("*.tif")), desc="Preparing test data"):
        case_id = img_path.stem
        dst = test_input / f"{case_id}_0000.tif"
        if not dst.exists():
            dst.symlink_to(img_path.resolve())

    print(f"✅ Test data preparato: {test_input}")
    return test_input
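
# After training, inference on the prepared test folder would look roughly like
# this (a sketch: the output path is arbitrary, and -p/-tr must match the
# values used for training):
#   nnUNetv2_predict -i /workspace/test_input -o /workspace/predictions \
#       -d 100 -c 3d_fullres -f all -p nnUNetResEncUNetMPlans -tr nnUNetTrainer_250epochs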

def main():
    print("=" * 60)
    print("PREPARAZIONE DATI nnUNet")
    print("=" * 60)

    # 1. Check for a preprocessed dataset and link it if present
    preprocessed = check_preprocessed()

    if preprocessed:
        link_preprocessed(preprocessed)

    # 2. Always prepare the raw dataset (also needed for inference)
    prepare_raw_dataset()

    if not preprocessed:
        print("\n⚠️ No preprocessed dataset found; run preprocessing first:")
        print("nnUNetv2_plan_and_preprocess -d 100 -pl nnUNetPlannerResEncM -c 3d_fullres -np 4")

    # 3. Prepare the test data
    prepare_test_data()

    print("\n" + "=" * 60)
    print("✅ PREPARAZIONE COMPLETATA!")
    print("=" * 60)
    print("\nProssimo passo - Training:")
    print("nnUNetv2_train 100 3d_fullres all -p nnUNetResEncUNetMPlans -tr nnUNetTrainer_250epochs")

if __name__ == "__main__":
    main()
