#!/usr/bin/env python3
"""
================================================================================
   VESUVIUS V8 - nnUNet INFERENCE ON RUNPOD

   Run this after training completes: it generates nnUNet predictions for the
   test volumes and packages them into submission.zip, which you then download
   and upload to Kaggle.
================================================================================
"""

import os
import subprocess
import numpy as np
import zipfile
from pathlib import Path

# =============================================================================
# CONFIGURATION
# =============================================================================
# Filesystem layout: everything lives under /workspace on the pod.
WORKSPACE = Path("/workspace")
RESULTS_DIR = WORKSPACE / "nnunet_results"  # trained models (exported as nnUNet_results)
TEST_DIR = WORKSPACE / "test_images"        # raw test volumes consumed by prepare_test_input()
OUTPUT_DIR = WORKSPACE / "predictions"      # folder nnUNetv2_predict writes into

# nnUNet locates its raw / preprocessed / results roots through these
# environment variables. They are set at import time so that the
# nnUNetv2_predict subprocess (which inherits os.environ) sees them.
os.environ.update({
    "nnUNet_raw": str(WORKSPACE / "nnunet_data" / "nnUNet_raw"),
    "nnUNet_preprocessed": str(WORKSPACE / "nnunet_data" / "nnUNet_preprocessed"),
    "nnUNet_results": str(RESULTS_DIR),
})

# Which trained model to run: dataset 100, 3d_fullres configuration,
# ResEnc-M plans, fold "all".
DATASET_ID = 100
CONFIGURATION = "3d_fullres"
PLANS_NAME = "nnUNetResEncUNetMPlans"
FOLD = "all"
EPOCHS = 250  # training length; selects the trainer class in get_trainer_name()

# Sentinel meaning "no epochs argument given: use the module-level EPOCHS".
# A sentinel is needed because None is itself a meaningful value.
_USE_CONFIGURED_EPOCHS = object()


def get_trainer_name(epochs=_USE_CONFIGURED_EPOCHS):
    """Return the nnUNet trainer class name matching a training epoch budget.

    nnUNet encodes shortened schedules in the trainer class name
    (e.g. ``nnUNetTrainer_250epochs``). Inference passes this name via
    ``-tr`` so it must match the trainer used for training.

    Args:
        epochs: Number of training epochs. When omitted, the module-level
            ``EPOCHS`` setting is used. ``None`` or ``1000`` (nnUNet's default
            schedule) select the stock ``nnUNetTrainer``.

    Returns:
        The trainer class name as a string.
    """
    if epochs is _USE_CONFIGURED_EPOCHS:
        epochs = EPOCHS
    if epochs is None or epochs == 1000:
        return "nnUNetTrainer"
    if epochs == 1:
        # The single-epoch variant uses the singular suffix.
        return "nnUNetTrainer_1epoch"
    return f"nnUNetTrainer_{epochs}epochs"

def prepare_test_input(test_dir=None, input_dir=None):
    """Stage the test volumes in the folder layout nnUNetv2_predict expects.

    For each ``<case>.tif`` in ``test_dir`` this does two things in
    ``input_dir``:

    * symlinks it as ``<case>_0000.tif`` (nnUNet's channel-0 suffix);
    * writes a ``<case>.json`` spacing sidecar with unit spacing.

    Args:
        test_dir: Folder holding the raw test ``*.tif`` volumes. Defaults to
            ``TEST_DIR``.
        input_dir: Folder to populate for nnUNet. Defaults to
            ``WORKSPACE / "test_input"``.

    Returns:
        The populated input folder (a ``Path``), or ``None`` if ``test_dir``
        does not exist.
    """
    import json

    test_dir = TEST_DIR if test_dir is None else Path(test_dir)
    test_input = WORKSPACE / "test_input" if input_dir is None else Path(input_dir)

    if not test_dir.exists():
        print(f"[ERROR] Test images not found: {test_dir}")
        print("Download competition data first:")
        print("  kaggle competitions download vesuvius-challenge-surface-detection -p /workspace/competition --unzip")
        print(f"Then make sure the test .tif volumes are in {test_dir}")
        return None

    test_input.mkdir(parents=True, exist_ok=True)

    for img_path in sorted(test_dir.glob("*.tif")):
        case_id = img_path.stem
        dst = test_input / f"{case_id}_0000.tif"
        # A dangling symlink reports exists() == False but still blocks
        # symlink_to() with FileExistsError, so replace it.
        if dst.is_symlink() and not dst.exists():
            dst.unlink()
        if not dst.exists():
            dst.symlink_to(img_path.resolve())

        # nnUNet's Tiff3DIO reader strips the "_0000.tif" suffix when looking
        # for the spacing sidecar. It therefore expects <case>.json, not
        # <case>_0000.json; without it, it warns and assumes (1, 1, 1).
        with open(test_input / f"{case_id}.json", "w") as f:
            json.dump({"spacing": [1.0, 1.0, 1.0]}, f)

    print(f"Prepared {len(list(test_input.glob('*.tif')))} test images")
    return test_input

def run_inference():
    """Run ``nnUNetv2_predict`` on the staged test volumes.

    Stages inputs via ``prepare_test_input()``, then invokes nnUNet with the
    configured dataset / configuration / fold / plans / trainer. Results go to
    ``OUTPUT_DIR``, including the ``.npz`` probability maps requested by
    ``--save_probabilities``.

    Returns:
        True if nnUNet exited with status 0. False in any of these cases:

        * the test images are missing or none were staged;
        * the ``nnUNetv2_predict`` executable cannot be found;
        * prediction fails.
    """
    test_input = prepare_test_input()
    if test_input is None:
        return False
    if not any(test_input.glob("*.tif")):
        print(f"[ERROR] No .tif test images were staged in {test_input}")
        return False

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    trainer = get_trainer_name()
    # Pass an argv list rather than a shell string. Paths then reach nnUNet
    # verbatim, with no word splitting or shell interpretation.
    cmd = [
        "nnUNetv2_predict",
        "-d", f"{DATASET_ID:03d}",
        "-c", CONFIGURATION,
        "-f", str(FOLD),
        "-i", str(test_input),
        "-o", str(OUTPUT_DIR),
        "-p", PLANS_NAME,
        "-tr", trainer,
        "--save_probabilities",
    ]

    print(f"\n{'='*60}")
    print("RUNNING INFERENCE")
    print(f"{'='*60}")
    print(f"Command: {' '.join(cmd)}")

    try:
        result = subprocess.run(cmd)
    except FileNotFoundError:
        # Without a shell there is no exit code 127; report it explicitly.
        print("[ERROR] nnUNetv2_predict not found on PATH - is nnunetv2 installed?")
        return False
    return result.returncode == 0

def create_submission():
    """Convert nnUNet probability maps into a Kaggle ``submission.zip``.

    For every ``<case>.npz`` in ``OUTPUT_DIR``:

    * its ``probabilities`` array is arg-maxed over axis 0 into a uint8 label
      volume;
    * the result is written to ``WORKSPACE/submission_tiff/<case>.tif``.

    The TIFFs written by this run are then zipped into
    ``WORKSPACE/submission.zip``.

    Returns:
        Path of the written zip, or ``None`` if no ``.npz`` predictions exist.
    """
    import tifffile

    tiff_output = WORKSPACE / "submission_tiff"
    tiff_output.mkdir(parents=True, exist_ok=True)

    npz_paths = sorted(OUTPUT_DIR.glob("*.npz"))
    if not npz_paths:
        print(f"[ERROR] No .npz predictions found in {OUTPUT_DIR} - nothing to submit")
        return None

    # Convert NPZ to TIFF
    written = []
    for npz_path in npz_paths:
        case_id = npz_path.stem
        # np.load on an .npz returns an NpzFile holding an open file handle;
        # use it as a context manager so handles don't pile up across cases.
        with np.load(npz_path) as data:
            probs = data['probabilities']
        # Assumes class-first layout (C, ...), as the original argmax did.
        pred = np.argmax(probs, axis=0).astype(np.uint8)
        del probs  # release the float probability volume before writing
        out_path = tiff_output / f"{case_id}.tif"
        tifffile.imwrite(out_path, pred)
        written.append(out_path)
        print(f"Converted: {case_id}")

    # Create ZIP from this run's TIFFs only, so stale files left in
    # submission_tiff by an earlier run cannot leak into the submission.
    zip_path = WORKSPACE / "submission.zip"
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for tif in written:
            zf.write(tif, tif.name)
            print(f"Added to ZIP: {tif.name}")

    print(f"\n{'='*60}")
    print(f"SUBMISSION READY: {zip_path}")
    print(f"Size: {zip_path.stat().st_size / (1024*1024):.1f} MB")
    print(f"{'='*60}")
    print(f"\nDownload to local: scp -P PORT root@IP:{zip_path} ./")
    return zip_path

if __name__ == "__main__":
    import sys

    print("="*60)
    print("VESUVIUS V8 - nnUNet INFERENCE")
    print("="*60)

    if not run_inference():
        print("[FAILED] Inference did not complete")
        # Exit non-zero so wrapper scripts / job runners can detect the failure.
        sys.exit(1)
    create_submission()
