"""
================================================================================
   VESUVIUS V2 - BASELINE
   Based on: LB 0.537 notebook approach
================================================================================
"""

# Install packages from the offline wheel dataset (pip runs with --no-index,
# so everything must resolve from the local directory).
var = "/kaggle/input/vesuvius25-packages-offline-installer-v20251226/whls"
import subprocess
import os
import sys

# Check what's available
print("Checking available packages...")
if os.path.exists(var):
    print(f"Found package dir: {var}")
    print(os.listdir(var))
else:
    print(f"Package dir not found: {var}")
    # Try the dataset root in case the wheels are not under "whls".
    var = "/kaggle/input/vesuvius25-packages-offline-installer-v20251226"
    if os.path.exists(var):
        print(f"Found alternative: {var}")
        print(os.listdir(var))
    else:
        print(f"Alternative package dir not found either: {var}")

# Wheels to install, resolved relative to whichever directory `var` ended up
# pointing at.
_wheel_names = [
    "keras_nightly-3.12.0.dev2025100703-py3-none-any.whl",
    "tifffile-2025.10.16-py3-none-any.whl",
    "imagecodecs-2025.11.11-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl",
    "medicai-0.0.3-py3-none-any.whl",
]

# Install. pip is invoked through the running interpreter so the wheels are
# installed into the environment that will import them.
_pip_result = subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "--quiet",
        *(f"{var}/{name}" for name in _wheel_names),
        "--no-index",
        "--find-links", var,
    ],
    check=False,
    capture_output=True,
    text=True,
)

# Still best-effort (the packages may already be present), but a failed
# install is reported instead of being silently discarded.
if _pip_result.returncode != 0:
    print(f"pip install failed (exit code {_pip_result.returncode}):")
    print(_pip_result.stderr.strip())

# Protobuf compatibility patch.
# Best-effort shim: when the installed protobuf's MessageFactory has no
# GetPrototype attribute, attach one that delegates to GetMessageClass.
# NOTE(review): presumably a library imported further down still calls
# MessageFactory.GetPrototype — confirm which one requires this.
try:
    from google.protobuf import message_factory as _message_factory
    if not hasattr(_message_factory.MessageFactory, "GetPrototype"):
        # Imported only on the path that actually needs the replacement.
        from google.protobuf.message_factory import GetMessageClass
        def _GetPrototype(self, descriptor):
            # `self` is accepted so this can be bound as a method; only the
            # descriptor is forwarded.
            return GetMessageClass(descriptor)
        _message_factory.MessageFactory.GetPrototype = _GetPrototype
        print("Patched protobuf")
except Exception as e:
    # Any failure here (including a failed import) is printed and ignored so
    # the rest of the script still runs.
    print(f"Protobuf patch: {e}")

# Select the Keras backend. This has to be set before `import keras` below:
# Keras reads KERAS_BACKEND when the package is first imported.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from medicai.transforms import Compose, ScaleIntensityRange
from medicai.models import TransUNet
from medicai.utils.inference import SlidingWindowInference

import numpy as np
import pandas as pd
import zipfile
import tifffile
import scipy.ndimage as ndi
from skimage.morphology import remove_small_objects

# Log which backend and Keras version were actually picked up.
print(f"Backend: {keras.config.backend()}, Keras: {keras.version()}")

# ============================================================================
# CONFIG
# ============================================================================
# Competition dataset root and the folder holding the test volumes.
root_dir = "/kaggle/input/vesuvius-challenge-surface-detection"
test_dir = f"{root_dir}/test_images"
# Scratch directory for per-volume mask TIFFs; the inference loop below writes
# each mask here, adds it to the zip, then deletes it.
output_dir = "/kaggle/working/submission_masks"
# Final archive that collects one "<id>.tif" entry per processed volume.
zip_path = "/kaggle/working/submission.zip"
os.makedirs(output_dir, exist_ok=True)

# Load test data: the "id" column drives which volumes are processed.
test_df = pd.read_csv(f"{root_dir}/test.csv")
print(f"Test samples: {len(test_df)}")
print(f"Test IDs: {list(test_df['id'])}")

# List test images directly (diagnostic only; `test_files` is assigned only
# when the directory exists).
print(f"\nTest images in folder:")
if os.path.exists(test_dir):
    test_files = os.listdir(test_dir)
    print(f"  Found {len(test_files)} files: {test_files[:5]}...")
else:
    print(f"  Test dir not found!")

# ============================================================================
# TRANSFORMATION
# ============================================================================
def val_transformation(image):
    """Apply the inference-time intensity scaling to ``image``.

    The image is wrapped in a ``{"image": ...}`` dict, passed through a
    one-step ``Compose`` pipeline containing ``ScaleIntensityRange``
    (a_min=0, a_max=255, b_min=0, b_max=1, clip=True), and the transformed
    ``"image"`` entry is returned.
    """
    # Presumably maps raw 0..255 intensities onto 0..1 — confirm against the
    # medicai ScaleIntensityRange documentation.
    intensity_scaler = ScaleIntensityRange(
        keys=["image"],
        a_min=0,
        a_max=255,
        b_min=0,
        b_max=1,
        clip=True,
    )
    transform = Compose([intensity_scaler])
    transformed = transform({"image": image})
    return transformed["image"]

# ============================================================================
# MODEL - TransUNet SEResNeXt50
# ============================================================================
print("\nLoading model...")

# Candidate weight files, checked in order; the first one that exists is used.
model_paths = [
    "/kaggle/input/train-vesuvius-surface-3d-detection-on-tpu/model.weights.h5",
    "/kaggle/input/colab-a-162v4-gpu-transunet-seresnext101-x160/model.weights.h5",
]

model_path = next(
    (candidate for candidate in model_paths if os.path.exists(candidate)),
    None,
)

# Without weights there is nothing to run: list the mounted inputs to help
# diagnose the missing dataset, then abort.
if model_path is None:
    print("ERROR: No model weights found!")
    print("Available inputs:")
    for item in os.listdir("/kaggle/input"):
        print(f"  - {item}")
    raise FileNotFoundError("No model weights!")

# The encoder is inferred from the weights path: 'seresnext50' unless the path
# mentions 'seresnext101'.
encoder_name = 'seresnext101' if 'seresnext101' in model_path else 'seresnext50'
print(f"  Found weights: {model_path}")

# Build the network with a fixed 160^3 single-channel input and three softmax
# output classes, then restore the trained weights.
model = TransUNet(
    input_shape=(160, 160, 160, 1),
    encoder_name=encoder_name,
    classifier_activation='softmax',
    num_classes=3,
)
model.load_weights(model_path)
print(f"Model params: {model.count_params() / 1e6:.1f}M")

# Sliding window inference: 160^3 windows, 50% overlap, gaussian mode,
# one window per batch.
pred_fn = SlidingWindowInference(
    model,
    roi_size=(160, 160, 160),
    num_classes=3,
    mode="gaussian",
    overlap=0.50,
    sw_batch_size=1,
)

# ============================================================================
# HELPERS
# ============================================================================
def load_volume(path):
    """Read the TIFF at ``path`` as float32 with added singleton axes.

    One axis of length 1 is prepended and one is appended to whatever array
    ``tifffile.imread`` returns.
    """
    raw = tifffile.imread(path)
    return raw.astype(np.float32)[np.newaxis, ..., np.newaxis]

def rot90_volume(vol, k):
    """Rotate ``vol`` by ``k`` quarter turns via ``np.rot90`` with ``k`` negated.

    5-D inputs are rotated in the plane of axes (2, 3); any other rank is
    rotated in the plane of axes (1, 2).
    """
    plane = (2, 3) if vol.ndim == 5 else (1, 2)
    return np.rot90(vol, k=-k, axes=plane)

def unrot90_volume(vol, k):
    """Undo ``rot90_volume(vol, k)`` by applying the complementary rotation."""
    # k quarter turns followed by (4 - k) % 4 more make a full revolution.
    complementary_turns = (4 - k) % 4
    return rot90_volume(vol, complementary_turns)

def predict_probs_tta_rot(sample):
    """4x rotation TTA.

    Runs ``pred_fn`` on ``sample`` rotated by 0, 1, 2 and 3 quarter turns,
    takes channel 1 of the first batch element of each output, rotates it
    back, and returns the element-wise mean of the four results.
    """
    def _predict_rotated(turns):
        # Predict on the rotated input, then undo the rotation on the
        # channel-1 probabilities so all four maps are aligned.
        prediction = np.asarray(pred_fn(rot90_volume(sample, turns)))
        return unrot90_volume(prediction[0, ..., 1], turns)

    aligned = [_predict_rotated(turns) for turns in range(4)]
    return np.mean(aligned, axis=0)

def build_anisotropic_struct(z_radius, xy_radius):
    """Build a boolean structuring element: a disk in the last two axes, extruded along the first.

    Args:
        z_radius: Half-extent along the first axis; the result has
            ``2 * z_radius + 1`` slices.
        xy_radius: Radius of the disk in the last two axes; each slice is
            ``(2 * xy_radius + 1)`` square, True where ``dy**2 + dx**2 <= xy_radius**2``.

    Returns:
        ``None`` when both radii are 0; otherwise a bool array of shape
        ``(2 * z_radius + 1, 2 * xy_radius + 1, 2 * xy_radius + 1)`` in which
        every slice along the first axis holds the same disk.

    Raises:
        ValueError: If either radius is negative.
    """
    z, r = z_radius, xy_radius
    if z < 0 or r < 0:
        raise ValueError(
            f"radii must be non-negative, got z_radius={z}, xy_radius={r}"
        )
    if z == 0 and r == 0:
        return None

    # Disk mask from squared distances to the centre. With r == 0 this is the
    # single centre element, so the "column only" case needs no special path.
    offsets = np.arange(-r, r + 1)
    disk = offsets[:, None] ** 2 + offsets[None, :] ** 2 <= r * r

    # Stack the same disk 2*z + 1 times along the first axis (z == 0 gives one slice).
    return np.repeat(disk[None, :, :], 2 * z + 1, axis=0)

def topo_postprocess(probs, T_low=0.45, T_high=0.85, z_radius=1, xy_radius=0, dust_min_size=100):
    """Turn a probability volume into a binary uint8 mask.

    Steps, in order:
      1. Hysteresis: voxels with ``probs >= T_high`` are seeds; they are grown
         with ``ndi.binary_propagation`` (full 3x3x3 connectivity) inside the
         region where ``probs >= T_low``.
      2. Optional binary closing with the element from
         ``build_anisotropic_struct(z_radius, xy_radius)``, applied only when
         at least one radius is positive.
      3. Optional removal of small objects via ``remove_small_objects`` with
         ``min_size=dust_min_size``, applied only when ``dust_min_size > 0``.

    Returns an all-zero uint8 array shaped like ``probs`` when there are no
    seeds or the propagated mask is empty.
    """
    seeds = probs >= T_high
    if not seeds.any():
        return np.zeros_like(probs, dtype=np.uint8)

    candidates = probs >= T_low
    full_connectivity = ndi.generate_binary_structure(3, 3)
    grown = ndi.binary_propagation(seeds, mask=candidates, structure=full_connectivity)
    if not grown.any():
        return np.zeros_like(probs, dtype=np.uint8)

    # The structuring element is only built when closing was requested.
    wants_closing = z_radius > 0 or xy_radius > 0
    closing_struct = (
        build_anisotropic_struct(z_radius, xy_radius) if wants_closing else None
    )
    if closing_struct is not None:
        grown = ndi.binary_closing(grown, structure=closing_struct)

    if dust_min_size > 0:
        grown = remove_small_objects(grown.astype(bool), min_size=dust_min_size)

    return grown.astype(np.uint8)

def predict(sample):
    """Return the post-processed mask for ``sample``.

    Chains ``predict_probs_tta_rot`` (rotation-averaged probabilities) into
    ``topo_postprocess`` using its default parameters.
    """
    return topo_postprocess(predict_probs_tta_rot(sample))

# ============================================================================
# INFERENCE
# ============================================================================
print(f"\n{'='*60}")
print(f"Running inference on {len(test_df)} images...")
print(f"{'='*60}")

# Each mask is written to a temporary TIFF, added to the archive under
# "<id>.tif", and the temporary file is deleted straight away.
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
    for position, image_id in enumerate(test_df["id"], start=1):
        tif_path = f"{test_dir}/{image_id}.tif"
        tif_exists = os.path.exists(tif_path)
        print(f"\n[{position}/{len(test_df)}] Processing {image_id}...")
        print(f"    Path: {tif_path}")
        print(f"    Exists: {tif_exists}")

        # A missing input is reported and skipped; no entry is written for it.
        if not tif_exists:
            print(f"    ERROR: File not found!")
            continue

        volume = load_volume(tif_path)
        print(f"    Volume shape: {volume.shape}")

        output = predict(val_transformation(volume))

        out_path = f"{output_dir}/{image_id}.tif"
        tifffile.imwrite(out_path, output.astype(np.uint8))
        archive.write(out_path, arcname=f"{image_id}.tif")
        os.remove(out_path)

        print(f"    Output shape: {output.shape}, Voxels: {output.sum():,}")

print(f"\n{'='*60}")
print(f"Submission saved: {zip_path}")
print(f"{'='*60}")
