"""
================================================================================
   🏆 VESUVIUS CHALLENGE - ENSEMBLE V1 🏆
   Combining two TransUNet models (SEResNeXt101 + SEResNeXt50) for maximum performance
   Target: Beat 0.575 (current #1)
================================================================================
"""

# ============================================================================
# STRATEGY: Use pre-trained models (minimal GPU usage)
# - Model 1: TransUNet SEResNeXt101 (LB ~0.46)
# - Model 2: TransUNet SEResNeXt50 (LB ~0.50)
# - Ensemble + TTA + Optimized post-processing
# ============================================================================

import os
# Select the JAX backend; this is set ahead of the `import keras` further down.
os.environ["KERAS_BACKEND"] = "jax"

# Install dependencies
print("[1/7] Installing dependencies...")
import subprocess
import sys

# Local wheel directory; `--no-index` keeps pip from contacting a package index.
_WHEEL_DIR = "/kaggle/input/vesuvius25-packages-offline-installer-v20251226/whls"
_WHEEL_FILES = [
    "keras_nightly-3.12.0.dev2025100703-py3-none-any.whl",
    "tifffile-2025.10.16-py3-none-any.whl",
    "imagecodecs-2025.11.11-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl",
    "medicai-0.0.3-py3-none-any.whl",
]

try:
    subprocess.run(
        [
            # Run pip through the current interpreter so the wheels are installed
            # into the environment this script is executing in.
            sys.executable, "-m", "pip", "install", "--quiet",
            *(f"{_WHEEL_DIR}/{wheel_name}" for wheel_name in _WHEEL_FILES),
            "--no-index",
            "--find-links", _WHEEL_DIR,
        ],
        check=True,
        capture_output=True,
        text=True,
    )
except subprocess.CalledProcessError as err:
    # Output is captured to keep the log quiet, so surface it when pip fails;
    # otherwise the traceback only reports the non-zero exit status.
    print(err.stdout, file=sys.stderr)
    print(err.stderr, file=sys.stderr)
    raise

import warnings
warnings.filterwarnings('ignore')

import keras
from medicai.transforms import Compose, ScaleIntensityRange
from medicai.models import TransUNet
from medicai.utils.inference import SlidingWindowInference

import numpy as np
import pandas as pd
import zipfile
import tifffile
import scipy.ndimage as ndi
from skimage.morphology import remove_small_objects

# Log the backend Keras reports as active and the Keras version in use.
print(f"  Backend: {keras.config.backend()}, Keras: {keras.version()}")

# ============================================================================
# CONFIG
# ============================================================================
print("\n[2/7] Configuration...")

# Competition input location: test.csv is read from ROOT_DIR and the per-sample
# .tif volumes are read from TEST_DIR.
ROOT_DIR = "/kaggle/input/vesuvius-challenge-surface-detection"
TEST_DIR = f"{ROOT_DIR}/test_images"
# OUTPUT_DIR holds each mask TIFF only until it has been added to the archive;
# ZIP_PATH is the submission archive that is written during inference.
OUTPUT_DIR = "/kaggle/working/submission_masks"
ZIP_PATH = "/kaggle/working/submission.zip"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Models to ensemble (add more as available).
# Each entry provides:
#   "name"        - label used in the loading log messages
#   "weights"     - path to the weights file passed to model.load_weights()
#   "encoder"     - forwarded to TransUNet as encoder_name
#   "input_shape" - forwarded to TransUNet as input_shape
#   "weight"      - relative ensemble weight; the weights of the models that
#                   load successfully are normalised to sum to 1
MODELS_CONFIG = [
    {
        "name": "TransUNet_SEResNeXt101",
        "weights": "/kaggle/input/colab-a-162v4-gpu-transunet-seresnext101-x160/model.weights.h5",
        "encoder": "seresnext101",
        "input_shape": (160, 160, 160, 1),
        # NOTE(review): the strategy header lists this model at LB ~0.46 and the
        # SEResNeXt50 model at LB ~0.50, yet this one gets the larger weight as
        # the "better model" - confirm the weighting is intentional.
        "weight": 0.6  # Higher weight for better model
    },
    {
        "name": "TransUNet_SEResNeXt50",
        "weights": "/kaggle/input/train-vesuvius-surface-3d-detection-on-tpu/model.weights.h5",
        "encoder": "seresnext50",
        "input_shape": (160, 160, 160, 1),
        "weight": 0.4
    },
]

# Post-processing params (optimized); consumed by topo_postprocess().
POST_CONFIG = {
    "T_low": 0.40,           # Lower threshold for hysteresis (weak voxels that strong seeds may grow into)
    "T_high": 0.80,          # Higher threshold for hysteresis (strong seed voxels)
    "z_radius": 2,           # Z-axis closing radius (structuring element depth is 2*z_radius + 1)
    "xy_radius": 1,          # XY closing radius (disk of this radius in every slice)
    "dust_min_size": 150,    # Passed to remove_small_objects as min_size; 0 disables dust removal
}

# TTA config
# With both flags enabled each model runs 4 + 3 = 7 forward passes per volume:
# the rotation loop covers k = 0..3 (k = 0 is the unrotated input) and the flip
# loop adds one pass per spatial axis. Rotations and flips are not combined.
TTA_ROTATIONS = True  # 4 passes: 0/90/180/270 degree rotations in the HW plane
TTA_FLIPS = True      # 3 extra passes: one flip along each of D, H, W

# ============================================================================
# LOAD TEST DATA
# ============================================================================
print("\n[3/7] Loading test data...")
# test.csv supplies the `id` column that the inference loop iterates over.
test_csv_path = f"{ROOT_DIR}/test.csv"
test_df = pd.read_csv(test_csv_path)
sample_count = len(test_df)
print(f"  Test samples: {sample_count}")

# ============================================================================
# TRANSFORMATION
# ============================================================================
def val_transformation(image):
    """Apply the inference-time intensity scaling to ``image``.

    The volume is wrapped in a ``{"image": ...}`` dict, passed through a
    one-step ``Compose`` pipeline containing ``ScaleIntensityRange``
    (a_min=0, a_max=255, b_min=0, b_max=1, clip=True - presumably mapping
    0-255 intensities onto 0-1), and the transformed ``"image"`` entry is
    returned.
    """
    rescale = ScaleIntensityRange(
        keys=["image"],
        a_min=0, a_max=255,
        b_min=0, b_max=1,
        clip=True,
    )
    transformed = Compose([rescale])({"image": image})
    return transformed["image"]

# ============================================================================
# LOAD MODELS
# ============================================================================
print("\n[4/7] Loading models...")

models = []
model_weights = []
# Sliding-window size per loaded model, taken from the spatial part of its
# configured input_shape so the window always matches what the model was built for.
model_roi_sizes = []

for cfg in MODELS_CONFIG:
    # Best-effort loading: a model that is missing or fails to load is reported
    # and skipped so the remaining models can still form an ensemble.
    try:
        if os.path.exists(cfg["weights"]):
            model = TransUNet(
                input_shape=cfg["input_shape"],
                encoder_name=cfg["encoder"],
                classifier_activation='softmax',
                num_classes=3,
            )
            model.load_weights(cfg["weights"])
            models.append(model)
            model_weights.append(cfg["weight"])
            model_roi_sizes.append(tuple(cfg["input_shape"][:3]))
            print(f"  ✓ Loaded {cfg['name']} (weight={cfg['weight']})")
        else:
            print(f"  ✗ Weights not found: {cfg['name']}")
    except Exception as e:
        print(f"  ✗ Error loading {cfg['name']}: {e}")

# Fail fast: with no models the ensemble would return a scalar and the script
# would go on to write meaningless masks.
if not models:
    raise RuntimeError(
        "No models could be loaded; check the weight paths in MODELS_CONFIG"
    )

# Normalize weights
total_weight = sum(model_weights)
if total_weight <= 0:
    raise ValueError(
        f"Ensemble weights must sum to a positive value, got {total_weight}"
    )
model_weights = [w / total_weight for w in model_weights]
print(f"  Total models: {len(models)}")

# Create sliding window inference for each model
swi_list = []
for model, roi_size in zip(models, model_roi_sizes):
    swi = SlidingWindowInference(
        model,
        num_classes=3,
        roi_size=roi_size,
        sw_batch_size=1,
        mode='gaussian',
        overlap=0.55,  # Slightly higher overlap for better predictions
    )
    swi_list.append(swi)

# ============================================================================
# TTA HELPERS
# ============================================================================
def rot90_volume(vol, k):
    """Rotate ``vol`` by ``k`` quarter turns clockwise in the HW plane.

    A 5-D input is rotated over axes (2, 3); any other input is rotated
    over axes (1, 2).
    """
    hw_axes = (2, 3) if vol.ndim == 5 else (1, 2)
    # np.rot90 turns counter-clockwise for positive k, hence the negation.
    return np.rot90(vol, k=-k, axes=hw_axes)

def unrot90_volume(vol, k):
    """Undo ``rot90_volume(vol, k)`` by applying the remaining quarter turns of a full rotation."""
    remaining_turns = (4 - k) % 4
    return rot90_volume(vol, remaining_turns)

def flip_volume(vol, axis):
    """Reverse ``vol`` along spatial ``axis``.

    For a 5-D input the axis index is shifted by one to skip the leading
    batch dimension; any other input is flipped along ``axis`` directly.
    """
    batch_offset = 1 if vol.ndim == 5 else 0
    return np.flip(vol, axis=axis + batch_offset)

# ============================================================================
# POST-PROCESSING
# ============================================================================
def build_anisotropic_struct(z_radius: int, xy_radius: int):
    """Build the boolean structuring element used for anisotropic 3D closing.

    The element has shape ``(2*z_radius + 1, 2*xy_radius + 1, 2*xy_radius + 1)``
    and every z-layer holds the same filled disk of radius ``xy_radius``
    (offsets with ``dy**2 + dx**2 <= xy_radius**2``).

    Returns ``None`` when both radii are zero.
    """
    if z_radius == 0 and xy_radius == 0:
        return None

    # Squared in-plane distance of every offset from the element centre.
    offsets = np.arange(-xy_radius, xy_radius + 1)
    disk = offsets[:, None] ** 2 + offsets[None, :] ** 2 <= xy_radius ** 2

    layers = 2 * z_radius + 1
    side = 2 * xy_radius + 1
    struct = np.zeros((layers, side, side), dtype=bool)
    struct[:] = disk  # same disk repeated in every z-layer
    return struct

def topo_postprocess(probs, cfg=POST_CONFIG):
    """Turn a 3-D foreground probability volume into a cleaned binary mask.

    Steps:
      1. 3-D hysteresis thresholding: voxels >= ``T_high`` are seeds that are
         grown (26-connectivity) through voxels >= ``T_low``.
      2. Anisotropic 3-D binary closing with the element from
         ``build_anisotropic_struct(z_radius, xy_radius)``.
      3. Hole filling, applied independently to each slice along axis 0.
      4. Removal of connected components smaller than ``dust_min_size``.

    Args:
        probs: 3-D array of foreground probabilities.
        cfg: Mapping with keys ``T_low``, ``T_high``, ``z_radius``,
            ``xy_radius`` and ``dust_min_size``. It is only read, never mutated.

    Returns:
        ``np.uint8`` array with the same shape as ``probs``; all zeros when no
        voxel reaches ``T_high``.
    """
    T_low = cfg["T_low"]
    T_high = cfg["T_high"]
    z_radius = cfg["z_radius"]
    xy_radius = cfg["xy_radius"]
    dust_min_size = cfg["dust_min_size"]

    # Step 1: 3D Hysteresis thresholding
    strong = probs >= T_high
    weak = probs >= T_low

    if not strong.any():
        return np.zeros_like(probs, dtype=np.uint8)

    struct_hyst = ndi.generate_binary_structure(3, 3)
    mask = ndi.binary_propagation(strong, mask=weak, structure=struct_hyst)

    if not mask.any():
        return np.zeros_like(probs, dtype=np.uint8)

    # Step 2: 3D Anisotropic closing
    if z_radius > 0 or xy_radius > 0:
        struct_close = build_anisotropic_struct(z_radius, xy_radius)
        if struct_close is not None:
            # binary_closing treats everything outside the array as background
            # (border_value=0), so its erosion pass strips foreground lying
            # within one radius of the volume faces. Padding by the element
            # radius moves that artificial border away from the real data;
            # cropping afterwards restores the original shape.
            depth, height, width = mask.shape
            padded = np.pad(
                mask,
                ((z_radius, z_radius), (xy_radius, xy_radius), (xy_radius, xy_radius)),
                mode="constant",
                constant_values=False,
            )
            closed = ndi.binary_closing(padded, structure=struct_close)
            mask = closed[
                z_radius:z_radius + depth,
                xy_radius:xy_radius + height,
                xy_radius:xy_radius + width,
            ]

    # Step 3: Hole filling (layer by layer)
    for i in range(mask.shape[0]):
        mask[i] = ndi.binary_fill_holes(mask[i])

    # Step 4: Dust removal
    if dust_min_size > 0:
        mask = remove_small_objects(mask.astype(bool), min_size=dust_min_size)

    return mask.astype(np.uint8)

# ============================================================================
# PREDICTION WITH ENSEMBLE + TTA
# ============================================================================
def predict_ensemble_tta(sample):
    """Predict foreground probabilities with the model ensemble plus TTA.

    For every sliding-window inferer in ``swi_list`` the sample is predicted
    under each enabled test-time augmentation (rotations k = 0..3 in the HW
    plane when ``TTA_ROTATIONS`` is set, otherwise only k = 0; plus one flip
    per spatial axis when ``TTA_FLIPS`` is set). Each prediction is mapped back
    to the original orientation, the augmented predictions are averaged per
    model, and the per-model averages are combined using ``model_weights``.

    Args:
        sample: input volume of shape (1, D, H, W, 1).

    Returns:
        Weighted-sum array of the channel-1 (foreground) outputs with the batch
        and channel axes dropped - presumably shape (D, H, W); confirm against
        the inferer's output layout.

    Raises:
        RuntimeError: if no models are available.
        ValueError: if ``swi_list`` and ``model_weights`` differ in length.
    """
    # Without this guard an empty ensemble would silently yield the scalar 0.0.
    if not swi_list:
        raise RuntimeError("No models available for ensemble prediction")
    # zip() would otherwise silently drop the unmatched models or weights.
    if len(swi_list) != len(model_weights):
        raise ValueError(
            f"Got {len(swi_list)} models but {len(model_weights)} ensemble weights"
        )

    rotations = range(4) if TTA_ROTATIONS else [0]
    flip_axes = (0, 1, 2) if TTA_FLIPS else ()  # D, H, W

    weighted_probs = []
    for swi, weight in zip(swi_list, model_weights):
        tta_probs = []

        # Rotation TTA: predict on the rotated input, then rotate the result back.
        for k in rotations:
            out = np.asarray(swi(rot90_volume(sample, k)))
            fg = out[0, ..., 1]  # Foreground class
            tta_probs.append(unrot90_volume(fg, k))

        # Flip TTA: a flip is its own inverse, so the same call restores orientation.
        for axis in flip_axes:
            out = np.asarray(swi(flip_volume(sample, axis)))
            fg = out[0, ..., 1]
            tta_probs.append(flip_volume(fg, axis))

        # Average TTA predictions for this model, scaled by its ensemble weight.
        weighted_probs.append(np.mean(tta_probs, axis=0) * weight)

    # Weighted ensemble
    return np.sum(weighted_probs, axis=0)

def load_volume(path):
    """Read the TIFF at ``path`` as float32 and add a leading and a trailing singleton axis."""
    raw = tifffile.imread(path)
    as_float = raw.astype(np.float32)
    # Leading axis acts as the batch dimension, trailing axis as the channel.
    return as_float[np.newaxis, ..., np.newaxis]

def predict(sample):
    """Run ensemble + TTA inference on ``sample`` and return the post-processed mask."""
    ensemble_probs = predict_ensemble_tta(sample)
    return topo_postprocess(ensemble_probs, POST_CONFIG)

# ============================================================================
# MAIN INFERENCE
# ============================================================================
print("\n[5/7] Running inference...")

total_samples = len(test_df)
with zipfile.ZipFile(ZIP_PATH, "w", compression=zipfile.ZIP_DEFLATED) as archive:
    for sample_no, image_id in enumerate(test_df["id"], start=1):
        print(f"  Processing {sample_no}/{total_samples}: {image_id}")

        # Read the raw volume and apply the inference-time intensity scaling.
        source_tif = f"{TEST_DIR}/{image_id}.tif"
        volume = val_transformation(load_volume(source_tif))

        # Ensemble inference followed by post-processing.
        output = predict(volume)

        # Stage the mask as a TIFF on disk, add it to the archive, then delete
        # the staged file so only the zip remains.
        member_name = f"{image_id}.tif"
        staged_tif = f"{OUTPUT_DIR}/{member_name}"
        tifffile.imwrite(staged_tif, output.astype(np.uint8))
        archive.write(staged_tif, arcname=member_name)
        os.remove(staged_tif)

        # Per-sample statistics: mask shape and number of foreground voxels.
        print(f"    Shape: {output.shape}, Voxels: {output.sum():,}")

print(f"\n[6/7] Submission saved: {ZIP_PATH}")

# ============================================================================
# SUMMARY
# ============================================================================
# Final run summary: model count, TTA switches and hysteresis thresholds.
banner = "=" * 60
print("\n" + banner)
print("   🏆 VESUVIUS ENSEMBLE V1 COMPLETE 🏆")
print(banner)
loaded_model_count = len(models)
print(f"   Models: {loaded_model_count}")
print(f"   TTA: Rotations={TTA_ROTATIONS}, Flips={TTA_FLIPS}")
print(f"   Post-processing: T_low={POST_CONFIG['T_low']}, T_high={POST_CONFIG['T_high']}")
print(banner)
