"""
================================================================================
   VESUVIUS V4 - MAXIMA EDITION
   Maximum optimization without additional models:
   - 8x TTA (4 rotation + flip)
   - Higher overlap (0.625)
   - Optimized post-processing
   - Multi-threshold ensemble
================================================================================
"""

import subprocess
import os

import sys

# Install the bundled wheels offline: --no-index keeps pip from reaching the
# network and --find-links makes it resolve dependencies from `var` only.
var = "/kaggle/input/vesuvius25-packages-offline-installer-v20251226/whls"
print(f"Installing from: {var}")

# `sys.executable -m pip` installs into the interpreter running this script,
# rather than whichever `pip` happens to be first on PATH.
_pip_result = subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "--quiet",
        f"{var}/keras_nightly-3.12.0.dev2025100703-py3-none-any.whl",
        f"{var}/tifffile-2025.10.16-py3-none-any.whl",
        f"{var}/imagecodecs-2025.11.11-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl",
        f"{var}/medicai-0.0.3-py3-none-any.whl",
        "--no-index",
        "--find-links", var,
    ],
    check=False,
    capture_output=True,
    text=True,
)
if _pip_result.returncode != 0:
    # Stay best-effort (the packages may already be installed), but surface the
    # failure instead of discarding it; otherwise a missing package only shows
    # up later as an unexplained ImportError.
    print(f"WARNING: pip install failed (exit code {_pip_result.returncode}):")
    print(_pip_result.stderr)

# Protobuf compatibility shim: if MessageFactory no longer provides
# GetPrototype, recreate it on top of GetMessageClass.
# NOTE(review): presumably some imported library still calls GetPrototype;
# confirm which one before removing this.
try:
    from google.protobuf import message_factory as _message_factory
    if not hasattr(_message_factory.MessageFactory, "GetPrototype"):
        from google.protobuf.message_factory import GetMessageClass

        def _GetPrototype(self, descriptor):
            return GetMessageClass(descriptor)

        _message_factory.MessageFactory.GetPrototype = _GetPrototype
except (ImportError, AttributeError):
    # protobuf is absent or lacks these symbols: nothing to patch. Catching
    # only these (instead of a bare `except:`) keeps KeyboardInterrupt,
    # SystemExit and genuine bugs visible.
    pass

# Select the JAX backend; this has to happen before `import keras` below.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from medicai.transforms import Compose, ScaleIntensityRange
from medicai.models import TransUNet
from medicai.utils.inference import SlidingWindowInference

import numpy as np
import pandas as pd
import zipfile
import tifffile
import scipy.ndimage as ndi
from skimage.morphology import remove_small_objects

print(f"Backend: {keras.config.backend()}, Keras: {keras.version()}")

# ============================================================================
# CONFIG - V4 MAXIMA
# ============================================================================
root_dir = "/kaggle/input/vesuvius-challenge-surface-detection"
test_dir = os.path.join(root_dir, "test_images")
output_dir = "/kaggle/working/submission_masks"
zip_path = "/kaggle/working/submission.zip"
os.makedirs(output_dir, exist_ok=True)

# Model weights - SEResNeXt101 (LB 0.460)
MODEL_PATH = "/kaggle/input/colab-a-162v4-gpu-transunet-seresnext101-x160/model.weights.h5"

# V4 MAXIMA SETTINGS
OVERLAP = 0.625       # sliding-window tile overlap (V3 used 0.52)
USE_FLIP_TTA = True   # add a flipped copy to the rotations: 8 TTA passes instead of 4

# The test split is listed in test.csv; each "id" names a TIFF in test_dir.
test_df = pd.read_csv(os.path.join(root_dir, "test.csv"))
print(f"Test samples: {len(test_df)}")
print("\nV4 MAXIMA CONFIG:")
print(f"  - Overlap: {OVERLAP}")
if USE_FLIP_TTA:
    print("  - TTA: 8x (4 rotation + flip)")
else:
    print("  - TTA: 4x (rotation only)")

# ============================================================================
# TRANSFORMATION
# ============================================================================
def val_transformation(image):
    """
    Map raw intensities from [0, 255] onto [0, 1] with clipping enabled.

    The array is wrapped in the ``{"image": ...}`` dict that medicai
    transforms operate on, and the transformed "image" entry is returned.
    """
    rescale = ScaleIntensityRange(
        keys=["image"],
        a_min=0,
        a_max=255,
        b_min=0,
        b_max=1,
        clip=True,
    )
    return Compose([rescale])({"image": image})["image"]

# ============================================================================
# MODEL
# ============================================================================
print(f"\nLoading model from: {MODEL_PATH}")

# Rebuild the 3-class TransUNet / SEResNeXt101 architecture, then load the
# trained weights. The 160^3 spatial input matches roi_size used below.
model = TransUNet(
    encoder_name="seresnext101",
    num_classes=3,
    classifier_activation="softmax",
    input_shape=(160, 160, 160, 1),
)
model.load_weights(MODEL_PATH)
print(f"Model loaded! Params: {model.count_params() / 1e6:.1f}M")

# Tile full volumes with 160^3 windows, one window per batch. Overlap between
# tiles comes from OVERLAP; mode="gaussian" presumably blends the overlapping
# tiles with Gaussian weights (confirm against medicai docs).
pred_fn = SlidingWindowInference(
    model,
    sw_batch_size=1,
    overlap=OVERLAP,
    mode="gaussian",
    num_classes=3,
    roi_size=(160, 160, 160),
)

# ============================================================================
# HELPERS
# ============================================================================
def load_volume(path):
    """
    Read a TIFF file into a float32 array with singleton batch and channel axes.

    A 3-D stack of shape (D, H, W) comes back as (1, D, H, W, 1).
    """
    stack = tifffile.imread(path).astype(np.float32)
    return stack[np.newaxis, ..., np.newaxis]

def rot90_volume(vol, k):
    """
    Rotate ``vol`` by k quarter turns (np.rot90 with ``-k``) in its H/W plane.

    5-D input is treated as (B, D, H, W, C) and rotated over axes (2, 3).
    Any other rank is treated as (D, H, W) and rotated over axes (1, 2).
    """
    plane = (2, 3) if vol.ndim == 5 else (1, 2)
    return np.rot90(vol, k=-k, axes=plane)

def unrot90_volume(vol, k):
    """Invert ``rot90_volume(vol, k)`` by rotating the remaining quarter turns."""
    # For integer k, -k % 4 equals (4 - k) % 4: the number of extra quarter
    # turns that brings the total back to a full revolution.
    return rot90_volume(vol, -k % 4)

def flip_volume(vol, axis):
    """
    Reverse ``vol`` along spatial axis ``axis`` (0=D, 1=H, 2=W).

    A 5-D array is assumed to carry a leading batch axis, so the index is
    shifted by one; other ranks use ``axis`` as-is.
    """
    batch_offset = 1 if vol.ndim == 5 else 0
    return np.flip(vol, axis=axis + batch_offset)

def predict_probs_tta_8x(sample):
    """
    V4 MAXIMA: 8x test-time augmentation.

    Runs ``pred_fn`` on the 4 in-plane rotations (0/90/180/270 degrees, via
    ``rot90_volume``) of the input. When ``USE_FLIP_TTA`` is set, it also runs
    them on a copy flipped along spatial axis 2 (W). Each prediction is mapped
    back to the original orientation, and the foreground probabilities are
    averaged: 8 predictions with flips, 4 without.

    Args:
        sample: model input volume. ``load_volume`` builds it as
            (1, D, H, W, 1); assumes ``val_transformation`` keeps that shape.

    Returns:
        float32 array of shape (D, H, W) with the mean class-1 (foreground)
        probability.
    """
    flip_axis = 2  # spatial W axis; flip_volume adds the batch offset for 5-D input

    variants = [(sample, False)]
    if USE_FLIP_TTA:
        variants.append((flip_volume(sample, axis=flip_axis), True))

    probs_sum = None
    n_preds = 0
    for variant, flipped in variants:
        for k in range(4):
            out = np.asarray(pred_fn(rot90_volume(variant, k)))
            probs = out[0, ..., 1]  # drop the batch dim, keep class 1 (foreground)

            # Undo the augmentations in reverse order: rotation first, then flip.
            probs = unrot90_volume(probs, k)
            if flipped:
                # probs is 3-D (D, H, W), so flip_volume(..., axis=2) flips W.
                # That is the same axis flip_volume(sample, axis=2) flipped on
                # the 5-D input. np.flip(probs, axis=1) would flip H instead and
                # misalign every flipped prediction with the unflipped ones.
                probs = flip_volume(probs, axis=flip_axis)

            # Keep a running sum instead of storing every full-size prediction.
            if probs_sum is None:
                probs_sum = probs.astype(np.float32)
            else:
                probs_sum += probs
            n_preds += 1

    probs_sum /= n_preds
    return probs_sum

def build_anisotropic_struct(z_radius, xy_radius):
    """
    Build a boolean structuring element: a y/x disk repeated along z.

    The result has shape (2*z_radius + 1, 2*xy_radius + 1, 2*xy_radius + 1).
    Every z-slice holds the same disk of points with dy^2 + dx^2 <= xy_radius^2.
    Negative radii are treated as 0. Returns None when both radii are exactly 0.
    """
    if z_radius == 0 and xy_radius == 0:
        return None

    half_depth = max(z_radius, 0)
    half_width = max(xy_radius, 0)

    # Broadcast row/column offsets to evaluate the disk test in one shot.
    dy, dx = np.ogrid[-half_width:half_width + 1, -half_width:half_width + 1]
    disk = (dy * dy + dx * dx) <= half_width * half_width

    return np.repeat(disk[np.newaxis, :, :], 2 * half_depth + 1, axis=0)

def topo_postprocess_v4(probs, T_low=0.40, T_high=0.80, z_radius=2, xy_radius=1, dust_min_size=150):
    """
    V4 MAXIMA: turn a foreground-probability volume into a binary mask.

    Steps:
      1. Hysteresis: keep voxels >= T_low that are 26-connected, through
         voxels >= T_low, to a seed voxel >= T_high.
      2. Morphological closing with ``build_anisotropic_struct(z_radius,
         xy_radius)`` to fill small gaps. The mask is zero-padded first, so
         foreground touching the volume border is not eroded away.
      3. Dust removal: drop connected components smaller than
         ``dust_min_size`` voxels (skimage ``remove_small_objects``).

    Args:
        probs: 3-D array of foreground probabilities.
        T_low: weak (growth) threshold.
        T_high: strong (seed) threshold.
        z_radius, xy_radius: closing radii; both 0 disables the closing.
        dust_min_size: minimum component size; 0 disables dust removal.

    Returns:
        uint8 array of 0/1 with the same shape as ``probs``.
    """
    strong = probs >= T_high
    weak = probs >= T_low

    if not strong.any():
        return np.zeros_like(probs, dtype=np.uint8)

    # Hysteresis thresholding: grow the strong seeds through the weak region.
    struct_hyst = ndi.generate_binary_structure(3, 3)
    mask = ndi.binary_propagation(strong, mask=weak, structure=struct_hyst)

    if not mask.any():
        return np.zeros_like(probs, dtype=np.uint8)

    # Morphological closing (fill small holes).
    if z_radius > 0 or xy_radius > 0:
        struct_close = build_anisotropic_struct(z_radius, xy_radius)
        if struct_close is not None:
            mask = _closing_with_padding(mask, struct_close)

    # Remove small objects (dust).
    if dust_min_size > 0:
        mask = remove_small_objects(mask.astype(bool), min_size=dust_min_size)

    return mask.astype(np.uint8)


def _closing_with_padding(mask, structure):
    """
    Binary closing that does not erode foreground at the array border.

    ``ndi.binary_closing`` runs its erosion with ``border_value=0``. That clears
    every voxel whose structuring element reaches past the array edge, even if
    the voxel was foreground in the input. Zero-padding by the structure's
    extent and cropping afterwards gives the closing on an unbounded empty
    background instead, so input foreground is kept.

    Args:
        mask: boolean array.
        structure: boolean structuring element with the same rank as ``mask``.

    Returns:
        Boolean array with the same shape as ``mask``.
    """
    pad = [(n, n) for n in structure.shape]
    padded = np.pad(mask, pad, mode="constant", constant_values=False)
    closed = ndi.binary_closing(padded, structure=structure)
    core = tuple(slice(n, n + size) for n, size in zip(structure.shape, mask.shape))
    return closed[core]

def multi_threshold_ensemble(probs, threshold_configs=None, min_votes=None):
    """
    V4 MAXIMA: vote over ``topo_postprocess_v4`` masks at several thresholds.

    Each (T_low, T_high) pair yields one binary mask. A voxel is kept when at
    least ``min_votes`` masks contain it.

    NOTE: every post-processing step is monotone in the thresholds:
    hysteresis, closing, and dust removal (a component can only grow when the
    mask grows). So when both thresholds increase together across the configs,
    as in the defaults, the masks are nested. A 2-of-3 majority over nested
    masks is exactly the middle (0.40, 0.80) mask. To get a real ensemble
    effect, pass non-nested configs or a different ``min_votes``
    (1 = union / most aggressive, len(configs) = intersection / most
    conservative).

    Args:
        probs: 3-D foreground-probability volume.
        threshold_configs: iterable of (T_low, T_high) pairs. None uses the
            three V4 defaults below.
        min_votes: votes needed to keep a voxel. None means a strict majority,
            ``len(threshold_configs) // 2 + 1`` (2 for the defaults).

    Returns:
        uint8 array of 0/1 with the same shape as ``probs``.

    Raises:
        ValueError: if ``threshold_configs`` is empty.
    """
    if threshold_configs is None:
        threshold_configs = (
            (0.35, 0.75),  # More aggressive
            (0.40, 0.80),  # Balanced
            (0.45, 0.85),  # Conservative (original V3)
        )
    threshold_configs = tuple(threshold_configs)
    if not threshold_configs:
        raise ValueError("threshold_configs must contain at least one (T_low, T_high) pair")

    if min_votes is None:
        min_votes = len(threshold_configs) // 2 + 1

    # Accumulate votes in a single counter instead of holding every mask.
    votes = np.zeros(np.shape(probs), dtype=np.uint16)
    for t_low, t_high in threshold_configs:
        votes += topo_postprocess_v4(probs, T_low=t_low, T_high=t_high)

    return (votes >= min_votes).astype(np.uint8)

def predict(sample):
    """
    V4 MAXIMA pipeline for one preprocessed volume.

    Averages the TTA foreground probabilities (``predict_probs_tta_8x``) and
    binarizes them with the thresholded vote (``multi_threshold_ensemble``).
    """
    return multi_threshold_ensemble(predict_probs_tta_8x(sample))

# ============================================================================
# INFERENCE
# ============================================================================
print("\n" + "=" * 60)
print(f"V4 MAXIMA INFERENCE - {len(test_df)} images")
print("=" * 60)

# For every test id: load, normalize, predict, and write the uint8 mask TIFF
# into the submission zip. Each mask passes briefly through output_dir on disk
# and is deleted once archived.
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
    for position, image_id in enumerate(test_df["id"], start=1):
        print(f"\n[{position}/{len(test_df)}] {image_id}")

        volume = val_transformation(load_volume(f"{test_dir}/{image_id}.tif"))
        output = predict(volume)

        mask_file = f"{output_dir}/{image_id}.tif"
        tifffile.imwrite(mask_file, output.astype(np.uint8))
        zf.write(mask_file, arcname=f"{image_id}.tif")
        os.remove(mask_file)

        print(f"    Shape: {output.shape}, Voxels: {output.sum():,}")

print("\n" + "=" * 60)
print(f"V4 MAXIMA DONE! Submission: {zip_path}")
print("=" * 60)
