"""
================================================================================
   VESUVIUS V7 - CONSERVATIVE IMPROVEMENT

   Based on V3 (0.461) with small incremental improvements:
   - Same 4x TTA (proven to work)
   - Same overlap 0.52 (proven to work)
   - Slightly lower thresholds for better recall
   - Better topology post-processing

   V3 settings: T_low=0.45, T_high=0.85, z_radius=1, xy_radius=0
   V7 settings: T_low=0.42, T_high=0.82, z_radius=2, xy_radius=1

   Target: Beat 0.461 with minimal risk
================================================================================
"""

import gc
import os
import subprocess
import sys

# Install packages from a local wheel directory (--no-index / --find-links:
# presumably the run has no internet access, hence the offline installer).
var = "/kaggle/input/vesuvius25-packages-offline-installer-v20251226/whls"
print(f"Installing from: {var}")

# `sys.executable -m pip` guarantees the packages land in the interpreter that
# is running this script, rather than whatever `pip` happens to be on PATH.
_pip_result = subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "--quiet",
        f"{var}/keras_nightly-3.12.0.dev2025100703-py3-none-any.whl",
        f"{var}/tifffile-2025.10.16-py3-none-any.whl",
        f"{var}/imagecodecs-2025.11.11-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl",
        f"{var}/medicai-0.0.3-py3-none-any.whl",
        "--no-index",
        "--find-links", var,
    ],
    check=False,
    capture_output=True,
    text=True,
)
if _pip_result.returncode != 0:
    # Stay best-effort (the packages may already be present), but surface the
    # failure: otherwise the only symptom is an ImportError much further down.
    print(f"WARNING: pip install exited with code {_pip_result.returncode}")
    print(_pip_result.stderr[-2000:])

# Protobuf compatibility shim: if this protobuf build's MessageFactory has no
# GetPrototype method, add one that delegates to message_factory.GetMessageClass.
# (Presumably some dependency still calls the old method - confirm.)
try:
    from google.protobuf import message_factory as _message_factory
    if not hasattr(_message_factory.MessageFactory, "GetPrototype"):
        from google.protobuf.message_factory import GetMessageClass

        def _GetPrototype(self, descriptor):
            return GetMessageClass(descriptor)

        _message_factory.MessageFactory.GetPrototype = _GetPrototype
except Exception as _patch_err:
    # Best-effort: a failed shim must not abort the run. Catch Exception
    # (not a bare except) so Ctrl-C / SystemExit still propagate.
    print(f"Protobuf patch skipped: {_patch_err!r}")

# Must be set before `import keras` below; Keras reads it at import time.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from medicai.transforms import Compose, ScaleIntensityRange
from medicai.models import TransUNet
from medicai.utils.inference import SlidingWindowInference

import numpy as np
import pandas as pd
import zipfile
import tifffile
import scipy.ndimage as ndi
from skimage.morphology import remove_small_objects

print(f"Backend: {keras.config.backend()}, Keras: {keras.version()}")

# ============================================================================
# CONFIG - V7 CONSERVATIVE (Based on V3)
# ============================================================================
# Input / output locations.
root_dir = "/kaggle/input/vesuvius-challenge-surface-detection"
test_dir = f"{root_dir}/test_images"
output_dir = "/kaggle/working/submission_masks"
zip_path = "/kaggle/working/submission.zip"
os.makedirs(output_dir, exist_ok=True)

# Trained TransUNet weights.
MODEL_PATH = (
    "/kaggle/input/colab-a-162v4-gpu-transunet-seresnext101-x160/"
    "model.weights.h5"
)

# Sliding-window overlap between neighbouring patches (unchanged from V3).
OVERLAP = 0.52

# Hysteresis thresholds on the foreground probability.
# V3 used T_low=0.45 / T_high=0.85; V7 lowers both by 0.03 for more recall.
T_LOW, T_HIGH = 0.42, 0.82

# Closing radii: Z is axis 0 of the structuring element, XY the in-plane disc.
# V3 used z=1 / xy=0; V7 closes more aggressively for better connectivity.
Z_RADIUS, XY_RADIUS = 2, 1

# Minimum object size (voxels) kept by the dust-removal step (as in V3).
DUST_MIN_SIZE = 100

# Test manifest.
test_df = pd.read_csv(f"{root_dir}/test.csv")
print(f"Test samples: {len(test_df)}")

# Settings banner.
_rule = "=" * 60
print(f"\n{_rule}")
print("V7 CONSERVATIVE - Incremental Improvement over V3")
print(_rule)
for _line in (
    f"  - Overlap: {OVERLAP} (same as V3)",
    "  - TTA: 4x (same as V3)",
    f"  - Threshold: T_low={T_LOW}, T_high={T_HIGH} (V3: 0.45, 0.85)",
    f"  - Closing: z={Z_RADIUS}, xy={XY_RADIUS} (V3: 1, 0)",
):
    print(_line)

# ============================================================================
# TRANSFORMATION
# ============================================================================
def val_transformation(image):
    """Apply the validation-time intensity normalisation to ``image``.

    Runs medicai's ``ScaleIntensityRange`` inside a ``Compose`` pipeline on a
    ``{"image": image}`` dict, mapping the range [0, 255] onto [0, 1] with
    clipping, and returns the transformed ``"image"`` entry.
    """
    rescale = ScaleIntensityRange(
        keys=["image"],
        a_min=0,
        a_max=255,
        b_min=0,
        b_max=1,
        clip=True,
    )
    result = Compose([rescale])({"image": image})
    return result["image"]

# ============================================================================
# MODEL
# ============================================================================
print(f"\nLoading model from: {MODEL_PATH}")

# Patch edge length and class count, shared by the network definition and the
# sliding-window wrapper so the two cannot drift apart.
_PATCH = 160
_NUM_CLASSES = 3

model = TransUNet(
    input_shape=(_PATCH, _PATCH, _PATCH, 1),
    encoder_name="seresnext101",
    classifier_activation="softmax",
    num_classes=_NUM_CLASSES,
)
model.load_weights(MODEL_PATH)
print(f"Model loaded! Params: {model.count_params() / 1e6:.1f}M")

# Sliding-window inference with "gaussian" blending mode, one window per
# batch, and the same overlap as V3.
pred_fn = SlidingWindowInference(
    model,
    roi_size=(_PATCH,) * 3,
    num_classes=_NUM_CLASSES,
    mode="gaussian",
    overlap=OVERLAP,
    sw_batch_size=1,
)

# ============================================================================
# HELPERS (same as V3)
# ============================================================================
def load_volume(path):
    """Read the TIFF stack at ``path`` and return it as a float32 array.

    The result has shape ``(1, *stack.shape, 1)``: a leading singleton axis
    (batch) and a trailing singleton axis (channel), matching the model's
    single-channel input layout.
    """
    stack = tifffile.imread(path)
    return stack.astype(np.float32).reshape((1, *stack.shape, 1))

def rot90_volume(vol, k):
    """Rotate ``vol`` by ``k`` quarter turns in its last two spatial axes.

    The rotation direction is opposite to ``np.rot90(vol, k)``. A 5-D batch of
    shape ``(1, D, H, W, 1)`` is rotated in axes 2 and 3. Any other array,
    e.g. a ``(D, H, W)`` probability map, is rotated in axes 1 and 2. A view
    is returned, not a copy.
    """
    spatial_axes = (2, 3) if vol.ndim == 5 else (1, 2)
    return np.rot90(vol, k=-k, axes=spatial_axes)

def unrot90_volume(vol, k):
    """Undo ``rot90_volume(vol, k)`` by applying the complementary quarter turns."""
    inverse_turns = -k % 4
    return rot90_volume(vol, inverse_turns)

def predict_probs_tta_rot(sample):
    """Average foreground probabilities over 4 in-plane 90-degree rotations (TTA).

    For each k in 0..3 the input is rotated with ``rot90_volume`` and run
    through the sliding-window predictor ``pred_fn``. Channel 1 of the first
    batch item is then rotated back with ``unrot90_volume``, so every map is
    aligned with the original orientation before averaging.

    Args:
        sample: Batched input volume, presumably ``(1, D, H, W, 1)`` as
            produced by ``load_volume`` + ``val_transformation`` - confirm.

    Returns:
        float32 array of shape ``(D, H, W)``: the mean foreground probability.
    """
    n_rot = 4
    total = None
    for k in range(n_rot):
        out = np.asarray(pred_fn(rot90_volume(sample, k)))
        # Channel 1 of the 3-class output is taken as foreground. This is a
        # *view*, so it keeps the whole `out` buffer alive while referenced.
        probs = unrot90_volume(out[0, ..., 1], k)
        if total is None:
            # Contiguous float32 copy; afterwards accumulate in place. Keeping
            # every map and stacking them would pin all 4 full 3-class outputs
            # plus a stacked copy.
            total = np.array(probs, dtype=np.float32)
        else:
            total += probs
    total /= n_rot
    return total

def build_anisotropic_struct(z_radius, xy_radius):
    """Build a boolean structuring element made of stacked, identical discs.

    The element has shape ``(2*z_radius + 1, 2*xy_radius + 1, 2*xy_radius + 1)``.
    Every slice along axis 0 holds the same disc: the in-plane offsets
    ``(dy, dx)`` with ``dy**2 + dx**2 <= xy_radius**2``. So ``xy_radius == 0``
    gives a single-voxel column along axis 0, and ``z_radius == 0`` gives one
    flat disc.

    Args:
        z_radius: Half-extent along axis 0, in voxels; must be >= 0.
        xy_radius: Disc radius on axes 1 and 2, in voxels; must be >= 0.

    Returns:
        The boolean element, or ``None`` when both radii are 0 (nothing to do).

    Raises:
        ValueError: If either radius is negative.
    """
    if z_radius < 0 or xy_radius < 0:
        raise ValueError(
            f"radii must be non-negative, got z_radius={z_radius}, xy_radius={xy_radius}"
        )
    if z_radius == 0 and xy_radius == 0:
        return None
    # Open grids broadcast to the (2r+1, 2r+1) table of squared in-plane distances.
    dy, dx = np.ogrid[-xy_radius:xy_radius + 1, -xy_radius:xy_radius + 1]
    disc = dy * dy + dx * dx <= xy_radius * xy_radius
    return np.repeat(disc[np.newaxis], 2 * z_radius + 1, axis=0)

def _binary_closing_no_border_loss(mask, structure):
    """Binary closing that treats everything outside the array as background.

    ``ndi.binary_closing`` runs its erosion with ``border_value=0``. That
    strips foreground within the structuring element's half-extent of every
    array face: e.g. a mask filling the volume loses its first and last
    ``Z_RADIUS`` slices. Zero-padding by that half-extent first keeps the
    erosion's neighbourhoods inside the array. After cropping, the result is
    the true closing and a superset of ``mask``.
    """
    pad = [(s // 2, s // 2) for s in structure.shape]
    padded = np.pad(mask, pad, mode="constant", constant_values=False)
    closed = ndi.binary_closing(padded, structure=structure)
    core = tuple(slice(before, before + n) for (before, _), n in zip(pad, mask.shape))
    return closed[core]


def topo_postprocess_v7(probs):
    """Turn a foreground-probability volume into a binary uint8 mask.

    Steps:
      1. Hysteresis thresholding. Voxels >= ``T_HIGH`` seed regions that grow,
         with full 26-connectivity, through voxels >= ``T_LOW``.
      2. Anisotropic morphological closing with
         ``build_anisotropic_struct(Z_RADIUS, XY_RADIUS)`` to bridge small
         gaps. It is computed so that it does not erode foreground at the
         volume borders.
      3. ``remove_small_objects`` with ``min_size=DUST_MIN_SIZE`` (if > 0).

    Args:
        probs: 3-D array of foreground probabilities (presumably in [0, 1]).

    Returns:
        uint8 array with the same shape as ``probs`` (0 = background,
        1 = foreground). It is all zeros when no voxel reaches ``T_HIGH``.
    """
    strong = probs >= T_HIGH
    if not strong.any():
        return np.zeros_like(probs, dtype=np.uint8)

    # Hysteresis: propagate strong seeds through the weak mask.
    weak = probs >= T_LOW
    full_connectivity = ndi.generate_binary_structure(3, 3)
    mask = ndi.binary_propagation(strong, mask=weak, structure=full_connectivity)
    if not mask.any():
        return np.zeros_like(probs, dtype=np.uint8)

    # Closing; build_anisotropic_struct returns None when both radii are 0.
    struct_close = build_anisotropic_struct(Z_RADIUS, XY_RADIUS)
    if struct_close is not None:
        mask = _binary_closing_no_border_loss(mask, struct_close)

    # NOTE(review): remove_small_objects uses its default connectivity here,
    # which may differ from the 26-connectivity used for hysteresis - confirm.
    if DUST_MIN_SIZE > 0:
        mask = remove_small_objects(mask.astype(bool), min_size=DUST_MIN_SIZE)

    return mask.astype(np.uint8)

def predict(sample):
    """Full V7 pipeline for one batched volume.

    Averages foreground probabilities over the rotation TTA, then binarises
    them with the topology post-processing; returns the resulting mask.
    """
    return topo_postprocess_v7(predict_probs_tta_rot(sample))

# ============================================================================
# INFERENCE
# ============================================================================
sep = "=" * 60
n_images = len(test_df)

print(f"\n{sep}")
print(f"V7 INFERENCE - {n_images} images")
print(sep)

# Predict each test volume, write its mask as a TIFF, and add it to the zip.
# The temporary TIFF is deleted right after being archived.
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
    for position, image_id in enumerate(test_df["id"], start=1):
        print(f"\n[{position}/{n_images}] {image_id}")

        volume = val_transformation(load_volume(f"{test_dir}/{image_id}.tif"))
        output = predict(volume)

        mask_path = f"{output_dir}/{image_id}.tif"
        tifffile.imwrite(mask_path, output.astype(np.uint8))
        archive.write(mask_path, arcname=f"{image_id}.tif")
        os.remove(mask_path)

        print(f"    Shape: {output.shape}, Voxels: {output.sum():,}")

        # Drop the large arrays before the next volume is loaded.
        del volume, output
        gc.collect()

print(f"\n{sep}")
print(f"V7 CONSERVATIVE DONE! Submission: {zip_path}")
print(sep)
