"""
================================================================================
   VESUVIUS V3 - WORKING BASELINE
   Using SEResNeXt101 model (LB ~0.460)
================================================================================
"""

import subprocess
import os

# Install packages from the local wheel directory. `--no-index` keeps pip from
# contacting a package index; `--find-links` lets it resolve dependencies from
# the same directory.
var = "/kaggle/input/vesuvius25-packages-offline-installer-v20251226/whls"
print(f"Installing from: {var}")

_install_result = subprocess.run([
    "pip", "install", "--quiet",
    f"{var}/keras_nightly-3.12.0.dev2025100703-py3-none-any.whl",
    f"{var}/tifffile-2025.10.16-py3-none-any.whl",
    f"{var}/imagecodecs-2025.11.11-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl",
    f"{var}/medicai-0.0.3-py3-none-any.whl",
    "--no-index",
    "--find-links", var
], check=False, capture_output=True)

# The install stays best-effort (check=False: the packages may already be
# present), but a failure is reported instead of being silently discarded so
# that a later ImportError can be traced back to its cause.
if _install_result.returncode != 0:
    print(f"WARNING: pip install exited with code {_install_result.returncode}")
    print(_install_result.stderr.decode(errors="replace").strip())

# Protobuf patch: when the installed protobuf's MessageFactory has no
# `GetPrototype` method, add one that delegates to `GetMessageClass`.
try:
    from google.protobuf import message_factory as _message_factory
    if not hasattr(_message_factory.MessageFactory, "GetPrototype"):
        from google.protobuf.message_factory import GetMessageClass

        def _GetPrototype(self, descriptor):
            """Compatibility shim: return the message class for `descriptor`."""
            return GetMessageClass(descriptor)

        _message_factory.MessageFactory.GetPrototype = _GetPrototype
except (ImportError, AttributeError):
    # Best-effort: protobuf is not installed, or it lacks the names the shim
    # relies on. Only these expected failures are ignored, so that
    # KeyboardInterrupt / SystemExit are no longer swallowed.
    pass

# Select the Keras backend; this is set before `keras` is imported.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from medicai.transforms import Compose, ScaleIntensityRange
from medicai.models import TransUNet
from medicai.utils.inference import SlidingWindowInference

import numpy as np
import pandas as pd
import zipfile
import tifffile
import scipy.ndimage as ndi
from skimage.morphology import remove_small_objects

print(f"Backend: {keras.config.backend()}, Keras: {keras.version()}")

# ============================================================================
# CONFIG
# ============================================================================
# Input data location, plus where the per-volume masks and the final archive
# are written.
root_dir = "/kaggle/input/vesuvius-challenge-surface-detection"
test_dir = f"{root_dir}/test_images"
output_dir = "/kaggle/working/submission_masks"
zip_path = "/kaggle/working/submission.zip"
os.makedirs(output_dir, exist_ok=True)

# Model weights - SEResNeXt101 (LB 0.460)
MODEL_PATH = "/kaggle/input/colab-a-162v4-gpu-transunet-seresnext101-x160/model.weights.h5"

# Load test data
test_df = pd.read_csv(f"{root_dir}/test.csv")
print(f"Test samples: {len(test_df)}")

# Debug: list every attached input and up to three entries inside each
# directory, to make a wrong dataset path easy to spot in the log.
print("\nAvailable inputs:")
for dataset_name in os.listdir("/kaggle/input"):
    print(f"  - {dataset_name}")
    dataset_path = f"/kaggle/input/{dataset_name}"
    if not os.path.isdir(dataset_path):
        continue
    for child_name in os.listdir(dataset_path)[:3]:
        print(f"      {child_name}")

# ============================================================================
# TRANSFORMATION
# ============================================================================
def val_transformation(image):
    """Run the inference-time intensity transform on `image`.

    The image is wrapped in a dict under the "image" key, passed through a
    one-step `Compose` pipeline holding `ScaleIntensityRange` (a_min=0,
    a_max=255, b_min=0, b_max=1, clip=True), and the transformed "image"
    entry is returned.
    """
    rescale = ScaleIntensityRange(
        keys=["image"],
        a_min=0,
        a_max=255,
        b_min=0,
        b_max=1,
        clip=True,
    )
    transformed = Compose([rescale])({"image": image})
    return transformed["image"]

# ============================================================================
# MODEL
# ============================================================================
print(f"\nLoading model from: {MODEL_PATH}")
print(f"File exists: {os.path.exists(MODEL_PATH)}")

# The network input patch and the sliding-window ROI share the same spatial
# size, and both use the same number of output classes.
_PATCH_SIZE = (160, 160, 160)
_NUM_CLASSES = 3

model = TransUNet(
    input_shape=(*_PATCH_SIZE, 1),
    encoder_name='seresnext101',
    classifier_activation='softmax',
    num_classes=_NUM_CLASSES,
)
model.load_weights(MODEL_PATH)
print(f"Model loaded! Params: {model.count_params() / 1e6:.1f}M")

# Sliding window inference
pred_fn = SlidingWindowInference(
    model,
    roi_size=_PATCH_SIZE,
    num_classes=_NUM_CLASSES,
    mode="gaussian",
    overlap=0.52,
    sw_batch_size=1
)

# ============================================================================
# HELPERS
# ============================================================================
def load_volume(path):
    """Read a TIFF from `path` as float32 with added leading and trailing axes.

    The array returned by `tifffile.imread` is cast to float32, then a
    length-1 axis is prepended and another appended.
    """
    raw = tifffile.imread(path)
    as_float = raw.astype(np.float32)
    return as_float[np.newaxis, ..., np.newaxis]

def rot90_volume(vol, k):
    """Rotate `vol` by `k` quarter turns in a fixed pair of axes.

    5-D input is rotated in axes (2, 3); any other input is rotated in axes
    (1, 2). `np.rot90` is called with `-k`, i.e. opposite to its default
    direction for positive `k`.
    """
    plane = (2, 3) if vol.ndim == 5 else (1, 2)
    return np.rot90(vol, k=-k, axes=plane)

def unrot90_volume(vol, k):
    """Undo `rot90_volume(vol, k)` by applying the complementary number of turns."""
    remaining_turns = (4 - k) % 4
    return rot90_volume(vol, remaining_turns)

def predict_probs_tta_rot(sample):
    """Average class-1 predictions over four 90-degree rotations (TTA).

    For each k in 0..3 the sample is rotated with `rot90_volume`, passed
    through `pred_fn`, and channel 1 of the first batch element is rotated
    back with `unrot90_volume`. The element-wise mean of the four realigned
    maps is returned.
    """
    realigned = []
    for turns in range(4):
        prediction = np.asarray(pred_fn(rot90_volume(sample, turns)))
        channel_one = prediction[0, ..., 1]
        realigned.append(unrot90_volume(channel_one, turns))
    return np.mean(realigned, axis=0)

def build_anisotropic_struct(z_radius, xy_radius):
    """Build a boolean 3-D structuring element: a disk extruded along axis 0.

    Args:
        z_radius: Half-extent along axis 0; values <= 0 give a depth of 1.
        xy_radius: Disk radius in the last two axes; values <= 0 give a
            single-voxel footprint.

    Returns:
        A bool array of shape (2*z + 1, 2*r + 1, 2*r + 1), with z and r the
        radii clamped at 0, where an element is True when its in-plane offset
        (dy, dx) satisfies dy**2 + dx**2 <= r**2. Returns None when both radii
        are exactly 0.
    """
    if z_radius == 0 and xy_radius == 0:
        return None

    half_depth = max(z_radius, 0)
    half_width = max(xy_radius, 0)

    # In-plane disk mask, computed once from broadcast squared offsets.
    offsets = np.arange(-half_width, half_width + 1)
    disk = offsets[:, None] ** 2 + offsets[None, :] ** 2 <= half_width ** 2

    # Every slice along axis 0 carries the same disk.
    side = 2 * half_width + 1
    struct = np.zeros((2 * half_depth + 1, side, side), dtype=bool)
    struct[:] = disk
    return struct

def topo_postprocess(probs, T_low=0.45, T_high=0.85, z_radius=1, xy_radius=0, dust_min_size=100):
    """Turn a 3-D probability map into a binary uint8 mask.

    Steps:
      1. Hysteresis thresholding: voxels >= `T_high` are seeds, which are
         propagated (26-connectivity) through voxels >= `T_low`.
      2. Morphological closing with the element from
         `build_anisotropic_struct(z_radius, xy_radius)`.
      3. Removal of connected components smaller than `dust_min_size`.

    Args:
        probs: 3-D array of probabilities.
        T_low: Lower (weak) hysteresis threshold.
        T_high: Upper (strong / seed) hysteresis threshold.
        z_radius: Closing half-extent along axis 0 (0 disables that axis).
        xy_radius: Closing disk radius in the last two axes.
        dust_min_size: Minimum component size kept; <= 0 disables the step.

    Returns:
        A uint8 array with the shape of `probs`, holding 0/1 values.
    """
    strong = probs >= T_high
    weak = probs >= T_low

    if not strong.any():
        return np.zeros_like(probs, dtype=np.uint8)

    struct_hyst = ndi.generate_binary_structure(3, 3)
    mask = ndi.binary_propagation(strong, mask=weak, structure=struct_hyst)

    if not mask.any():
        return np.zeros_like(probs, dtype=np.uint8)

    if z_radius > 0 or xy_radius > 0:
        struct_close = build_anisotropic_struct(z_radius, xy_radius)
        if struct_close is not None:
            # binary_closing treats everything outside the array as background
            # (border_value=0), so its erosion step deletes foreground that
            # touches the volume boundary - with z_radius=1 the first and last
            # slices would always come out empty. Padding by the element's
            # radius moves that artefact into the padding, which is cropped
            # away again afterwards.
            pad_z = max(z_radius, 0)
            pad_xy = max(xy_radius, 0)
            depth, height, width = mask.shape
            padded = np.pad(
                mask,
                ((pad_z, pad_z), (pad_xy, pad_xy), (pad_xy, pad_xy)),
                mode="constant",
                constant_values=False,
            )
            closed = ndi.binary_closing(padded, structure=struct_close)
            mask = closed[
                pad_z:pad_z + depth,
                pad_xy:pad_xy + height,
                pad_xy:pad_xy + width,
            ]

    if dust_min_size > 0:
        mask = remove_small_objects(mask.astype(bool), min_size=dust_min_size)

    return mask.astype(np.uint8)

def predict(sample):
    """Return the post-processed mask for `sample`.

    Chains `predict_probs_tta_rot` (rotation-averaged probabilities) into
    `topo_postprocess` with its default parameters.
    """
    return topo_postprocess(predict_probs_tta_rot(sample))

# ============================================================================
# INFERENCE
# ============================================================================
banner = "=" * 60
num_images = len(test_df)

print(f"\n{banner}")
print(f"INFERENCE - {num_images} images")
print(banner)

with zipfile.ZipFile(zip_path, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
    for position, image_id in enumerate(test_df["id"], start=1):
        member_name = f"{image_id}.tif"
        print(f"\n[{position}/{num_images}] {image_id}")

        # Load -> intensity transform -> TTA prediction + post-processing.
        volume = val_transformation(load_volume(f"{test_dir}/{member_name}"))
        output = predict(volume)

        # Stage the mask on disk, add it to the archive under its bare file
        # name, then delete the staged copy.
        staged_path = f"{output_dir}/{member_name}"
        tifffile.imwrite(staged_path, output.astype(np.uint8))
        archive.write(staged_path, arcname=member_name)
        os.remove(staged_path)

        print(f"    Shape: {output.shape}, Voxels: {output.sum():,}")

print(f"\n{banner}")
print(f"DONE! Submission: {zip_path}")
print(banner)
