"""
================================================================================
   VESUVIUS V6 - WINNER EDITION

   Nobel Prize-level approach combining:
   1. TransUNet neural network (baseline prediction)
   2. 3D Sobel gradient refinement (ThaumatoAnakalyptor inspired)
   3. Directional filtering toward scroll center
   4. Advanced topology-aware post-processing
   5. Hysteresis thresholding + morphological optimization

   Target: Beat 0.575 to win $100K!
================================================================================
"""

import gc
import os
import subprocess
import sys

# Install packages from the local wheel directory only (--no-index), using the
# interpreter that runs this script so the wheels land in its environment.
var = "/kaggle/input/vesuvius25-packages-offline-installer-v20251226/whls"
print(f"Installing from: {var}")

_pip_result = subprocess.run([
    sys.executable, "-m", "pip", "install", "--quiet",
    f"{var}/keras_nightly-3.12.0.dev2025100703-py3-none-any.whl",
    f"{var}/tifffile-2025.10.16-py3-none-any.whl",
    f"{var}/imagecodecs-2025.11.11-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl",
    f"{var}/medicai-0.0.3-py3-none-any.whl",
    "--no-index",
    "--find-links", var
], check=False, capture_output=True, text=True)

# Best-effort install: do not abort here, but surface pip's error instead of
# discarding it -- otherwise a failure only shows up later as an unexplained
# ImportError.
if _pip_result.returncode != 0:
    print(f"WARNING: pip install exited with code {_pip_result.returncode}")
    print(_pip_result.stderr)

# Protobuf compatibility shim: if this protobuf's MessageFactory has no
# GetPrototype method, add one built on GetMessageClass (presumably for
# dependencies that still call the old API -- confirm which one needs it).
try:
    from google.protobuf import message_factory as _message_factory
    if not hasattr(_message_factory.MessageFactory, "GetPrototype"):
        from google.protobuf.message_factory import GetMessageClass

        def _GetPrototype(self, descriptor):
            return GetMessageClass(descriptor)

        _message_factory.MessageFactory.GetPrototype = _GetPrototype
except (ImportError, AttributeError, TypeError) as exc:
    # Best-effort: a missing/incompatible protobuf should not stop the run,
    # but say so instead of silently swallowing every exception (a bare
    # `except:` would also eat KeyboardInterrupt and SystemExit).
    print(f"Protobuf patch skipped: {exc!r}")

# Select the JAX backend; this must happen before `import keras` below.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from medicai.transforms import Compose, ScaleIntensityRange
from medicai.models import TransUNet
from medicai.utils.inference import SlidingWindowInference

import numpy as np
import pandas as pd
import zipfile
import tifffile
import scipy.ndimage as ndi
from skimage.morphology import remove_small_objects, binary_closing, ball

print(f"Backend: {keras.config.backend()}, Keras: {keras.version()}")

# ============================================================================
# CONFIG - V6 WINNER
# ============================================================================
root_dir = "/kaggle/input/vesuvius-challenge-surface-detection"
test_dir = f"{root_dir}/test_images"
output_dir = "/kaggle/working/submission_masks"
zip_path = "/kaggle/working/submission.zip"
os.makedirs(output_dir, exist_ok=True)

# Model weights - TransUNet with SEResNeXt101 encoder
MODEL_PATH = "/kaggle/input/colab-a-162v4-gpu-transunet-seresnext101-x160/model.weights.h5"

# V6 settings
OVERLAP = 0.65  # Sliding-window overlap fraction passed to SlidingWindowInference
USE_FLIP_TTA = True  # Also predict on flipped copies: 8 passes instead of 4
USE_SOBEL_REFINEMENT = True  # Compute a 3D Sobel surface mask for post-processing
USE_DIRECTIONAL_FILTER = True  # Keep only Sobel gradients facing the volume center

# Load test data
test_df = pd.read_csv(f"{root_dir}/test.csv")
print(f"Test samples: {len(test_df)}")

print(f"\n{'='*60}")
print("V6 WINNER - Nobel Prize Edition")
print(f"{'='*60}")
print(f"  - Overlap: {OVERLAP}")
# Report the TTA that will actually run: 8 passes with flips, 4 without.
tta_label = "8x (4 rotation + flip)" if USE_FLIP_TTA else "4x (4 rotation)"
print(f"  - TTA: {tta_label}")
print(f"  - 3D Sobel Refinement: {USE_SOBEL_REFINEMENT}")
print(f"  - Directional Filtering: {USE_DIRECTIONAL_FILTER}")

# ============================================================================
# TRANSFORMATION
# ============================================================================
def val_transformation(image):
    """
    Rescale raw intensities for the network.

    Applies medicai's ScaleIntensityRange to map the 0-255 range onto 0-1
    with clipping enabled, and returns the transformed array.
    """
    rescale = ScaleIntensityRange(
        keys=["image"],
        a_min=0, a_max=255,
        b_min=0, b_max=1,
        clip=True,
    )
    transformed = Compose([rescale])({"image": image})
    return transformed["image"]

# ============================================================================
# MODEL
# ============================================================================
print(f"\nLoading model from: {MODEL_PATH}")

# Patch size and class count are shared by the network definition and the
# sliding-window driver; define them once so the two cannot drift apart.
# The 160^3 patch presumably has to match the training patch of the weights
# (the weights path mentions "x160") -- confirm before changing it.
ROI_SIZE = (160, 160, 160)
NUM_CLASSES = 3

model = TransUNet(
    input_shape=(*ROI_SIZE, 1),
    encoder_name='seresnext101',
    classifier_activation='softmax',
    num_classes=NUM_CLASSES,
)
model.load_weights(MODEL_PATH)
print(f"Model loaded! Params: {model.count_params() / 1e6:.1f}M")

# Sliding-window inference over whole volumes, one ROI_SIZE patch per step,
# with "gaussian" blending mode and the configured overlap.
pred_fn = SlidingWindowInference(
    model,
    roi_size=ROI_SIZE,
    num_classes=NUM_CLASSES,
    mode="gaussian",
    overlap=OVERLAP,
    sw_batch_size=1,
)

# ============================================================================
# 3D SOBEL GRADIENT (ThaumatoAnakalyptor inspired)
# ============================================================================
def compute_3d_sobel(volume):
    """
    Compute the 3D Sobel gradient of a volume and its normalized magnitude.

    Args:
        volume: 3D array. Non-floating input (e.g. raw uint8/uint16 CT) is
            cast to float32 first: ``ndi.sobel`` writes into an array of the
            input dtype by default, so integer input would wrap/truncate
            negative gradients.

    Returns:
        ``(magnitude, (gx, gy, gz))`` where ``gx``, ``gy`` and ``gz`` are the
        raw (un-normalized) Sobel derivatives along array axes 0, 1 and 2, and
        ``magnitude`` is their Euclidean norm divided by its maximum, so it
        lies in [0, 1] (all zeros for a constant volume).
    """
    volume = np.asarray(volume)
    if not np.issubdtype(volume.dtype, np.floating):
        volume = volume.astype(np.float32)

    # Derivative along each array axis; 'reflect' mirrors the edges so a
    # constant region touching the border yields zero gradient.
    gx = ndi.sobel(volume, axis=0, mode='reflect')
    gy = ndi.sobel(volume, axis=1, mode='reflect')
    gz = ndi.sobel(volume, axis=2, mode='reflect')

    magnitude = np.sqrt(gx**2 + gy**2 + gz**2)

    # Scale into [0, 1]; the epsilon keeps a constant volume at 0, not NaN.
    magnitude = magnitude / (magnitude.max() + 1e-8)

    return magnitude, (gx, gy, gz)


def compute_scroll_center_direction(volume_shape):
    """
    Compute direction vectors pointing toward scroll center.
    Assumes scroll is roughly centered in the volume.
    """
    D, H, W = volume_shape
    center = np.array([D/2, H/2, W/2])

    # Create coordinate grids
    z, y, x = np.mgrid[0:D, 0:H, 0:W]

    # Direction toward center
    dz = center[0] - z
    dy = center[1] - y
    dx = center[2] - x

    # Normalize
    magnitude = np.sqrt(dz**2 + dy**2 + dx**2) + 1e-8
    dz = dz / magnitude
    dy = dy / magnitude
    dx = dx / magnitude

    return dz, dy, dx


def directional_filter(gradients, center_direction, threshold=0.3):
    """
    Keep voxels whose intensity gradient points toward the scroll center.

    Alignment is measured as the cosine between the gradient and the center
    direction. ``threshold`` is therefore a cosine in [-1, 1], and the result
    does not depend on gradient strength (edge strength is thresholded
    separately on the Sobel magnitude in sobel_surface_detection).

    Args:
        gradients: ``(g0, g1, g2)`` gradient components along axes 0, 1, 2,
            e.g. as returned by compute_3d_sobel.
        center_direction: ``(c0, c1, c2)`` direction components along the
            same axes. Assumed unit-length, as produced by
            compute_scroll_center_direction (zero exactly at the center).
        threshold: minimum cosine for a voxel to be kept.

    Returns:
        Boolean array, True where cos(gradient, center_direction) > threshold.
        Voxels with zero gradient have cosine 0.
    """
    g0, g1, g2 = gradients
    c0, c1, c2 = center_direction

    dot = g0 * c0 + g1 * c1 + g2 * c2

    # Raw Sobel responses scale with intensity contrast, so comparing the
    # unnormalized dot product to a small threshold such as 0.3 would keep
    # essentially every voxel with a positive projection. Dividing by the
    # gradient norm turns it into a true alignment (cosine) measure.
    grad_norm = np.sqrt(g0 * g0 + g1 * g1 + g2 * g2)
    cosine = dot / (grad_norm + 1e-8)

    return cosine > threshold


def sobel_surface_detection(volume, gradient_threshold=0.1, directional_threshold=0.2):
    """
    Build a candidate surface mask from 3D Sobel gradients.

    A voxel is kept when its normalized gradient magnitude exceeds
    ``gradient_threshold``. When USE_DIRECTIONAL_FILTER is set, it must also
    pass directional_filter against the per-voxel direction to the volume
    center, using ``directional_threshold``.

    Returns:
        ``(mask, magnitude)``: the mask as float32 with values in {0, 1}, and
        the normalized gradient magnitude from compute_3d_sobel.
    """
    magnitude, gradients = compute_3d_sobel(volume)
    keep = magnitude > gradient_threshold

    if USE_DIRECTIONAL_FILTER:
        toward_center = compute_scroll_center_direction(volume.shape)
        facing = directional_filter(gradients, toward_center, directional_threshold)
        keep = keep & facing

    return keep.astype(np.float32), magnitude

# ============================================================================
# ROTATION/FLIP HELPERS
# ============================================================================
def rot90_volume(vol, k):
    """
    Rotate ``vol`` by ``k`` quarter turns in its two in-plane spatial axes.

    A 5-D array is taken as (batch, D, H, W, channel) -- the layout predict_v6
    builds with ``volume[None, ..., None]`` -- and rotated in axes 2 and 3.
    Any other array is taken as a 3-D (D, H, W) map and rotated in axes 1
    and 2. Calls ``np.rot90`` with ``-k``, so the result is a view.
    """
    if vol.ndim == 5:
        return np.rot90(vol, k=-k, axes=(2, 3))
    return np.rot90(vol, k=-k, axes=(1, 2))


def unrot90_volume(vol, k):
    """Undo ``rot90_volume(vol, k)`` by rotating the complementary (4 - k) % 4 turns."""
    return rot90_volume(vol, (4 - k) % 4)


def flip_volume(vol, axis):
    """
    Flip along spatial ``axis`` (0, 1 or 2 for D, H, W).

    For a 5-D batch, the leading batch axis is skipped (array axis
    ``axis + 1``). For a 3-D map, ``axis`` is used directly. The same
    ``axis`` value therefore names the same spatial direction in both
    layouts.
    """
    if vol.ndim == 5:
        return np.flip(vol, axis=axis + 1)
    return np.flip(vol, axis=axis)

# ============================================================================
# TTA PREDICTION
# ============================================================================
def predict_probs_tta_8x(sample, predict=None, use_flip=None):
    """
    Average class-1 probabilities over rotation (and optional flip) TTA.

    Each of the 4 in-plane rotations is run on ``sample``. With flips
    enabled, the 4 rotations are also run on a copy flipped along spatial
    axis 2 (W). Every prediction is mapped back to the original orientation
    before averaging.

    Args:
        sample: 5-D batch (1, D, H, W, C), as built in predict_v6.
        predict: callable mapping a 5-D batch to per-class probabilities
            indexed as ``out[0, ..., class]``. Defaults to the module-level
            ``pred_fn``.
        use_flip: whether to add the flipped passes (8 passes instead of 4).
            Defaults to ``USE_FLIP_TTA``.

    Returns:
        (D, H, W) array with the mean class-1 probability across passes.
    """
    if predict is None:
        predict = pred_fn
    if use_flip is None:
        use_flip = USE_FLIP_TTA

    variants = [(sample, False)]
    if use_flip:
        variants.append((flip_volume(sample, axis=2), True))

    # Running sum instead of a list of volumes: the list plus the stacked
    # copy made by np.mean held ~2x8 full probability maps at once.
    total = None
    n_passes = 0
    for variant, flipped in variants:
        for k in range(4):
            out = np.asarray(predict(rot90_volume(variant, k)))
            probs = unrot90_volume(out[0, ..., 1], k)
            if flipped:
                # Undo the input flip on the SAME spatial axis (W). For this
                # 3-D map that is array axis 2; flipping axis 1 (H) instead
                # would leave every flipped pass mirrored in both H and W.
                probs = flip_volume(probs, axis=2)

            if total is None:
                acc_dtype = np.float64 if probs.dtype == np.float64 else np.float32
                total = np.array(probs, dtype=acc_dtype)
            else:
                total += probs
            n_passes += 1

    return total / n_passes

# ============================================================================
# ADVANCED TOPOLOGY-AWARE POST-PROCESSING
# ============================================================================
def build_anisotropic_struct(z_radius, xy_radius):
    """
    Build a boolean cylindrical structuring element.

    The element is a disk of radius ``xy_radius`` in axes 1-2, repeated
    ``2 * z_radius + 1`` times along axis 0. A non-positive radius collapses
    that extent to a single voxel. Returns None when both radii are exactly
    0.
    """
    if z_radius == 0 and xy_radius == 0:
        return None

    half_depth = z_radius if z_radius > 0 else 0
    half_width = xy_radius if xy_radius > 0 else 0

    offsets_y, offsets_x = np.ogrid[-half_width:half_width + 1,
                                    -half_width:half_width + 1]
    disk = offsets_y * offsets_y + offsets_x * offsets_x <= half_width * half_width

    # Every slice along axis 0 carries the same disk.
    return np.repeat(disk[np.newaxis], 2 * half_depth + 1, axis=0)


def topology_postprocess_v6(probs, sobel_mask=None,
                            T_low=0.35, T_high=0.75,
                            z_radius=3, xy_radius=2,
                            dust_min_size=200,
                            fill_holes=True,
                            border=5):
    """
    Convert a 3D probability map into a cleaned binary surface mask.

    Steps, in order:
      1. Hysteresis thresholding: voxels >= T_high seed regions that grow
         through 26-connected voxels >= T_low.
      2. Sobel combination, only when ``sobel_mask`` is given and
         USE_SOBEL_REFINEMENT is set (see the NOTE in the body).
      3. Morphological closing with build_anisotropic_struct(z_radius,
         xy_radius); skipped when neither radius is positive.
      4. Optional filling of background cavities enclosed by the mask.
      5. Removal of connected components smaller than ``dust_min_size``.
      6. Zeroing a ``border``-voxel margin on every face of the volume.

    Args:
        probs: 3D array of foreground probabilities.
        sobel_mask: optional array combined element-wise with the mask.
        T_low, T_high: hysteresis thresholds.
        z_radius, xy_radius: closing radii along axis 0 and within axes 1-2.
        dust_min_size: minimum component size in voxels; 0 disables removal.
        fill_holes: whether to run ndi.binary_fill_holes.
        border: margin width in voxels cleared at the end. 0 (or negative)
            keeps the border; the default of 5 matches the former fixed value.

    Returns:
        uint8 array shaped like ``probs`` with values in {0, 1}. All zeros
        when no voxel reaches T_high.
    """
    # 1. Hysteresis thresholding
    strong = probs >= T_high
    if not strong.any():
        return np.zeros_like(probs, dtype=np.uint8)
    weak = probs >= T_low

    # Full 3x3x3 structure -> growth through 26-connected neighbours.
    struct_hyst = ndi.generate_binary_structure(3, 3)
    mask = ndi.binary_propagation(strong, mask=weak, structure=struct_hyst)

    # Can only be empty when T_high < T_low (seeds outside the weak set).
    if not mask.any():
        return np.zeros_like(probs, dtype=np.uint8)

    # 2. Sobel combination
    if sobel_mask is not None and USE_SOBEL_REFINEMENT:
        # NOTE(review): `mask` is boolean here, and sobel_surface_detection
        # produces a {0, 1} float mask. So `mask + 0.3 * sobel_mask` takes
        # values in {0, 0.3, 1.0, 1.3}, and `> 0.8` reproduces `mask`
        # exactly: with such inputs this step changes nothing. For Sobel
        # evidence to have any effect it would have to enter before
        # thresholding (e.g. on `probs`) -- confirm the intended design.
        combined = mask.astype(np.float32) + 0.3 * sobel_mask
        mask = combined > 0.8

    # 3. Morphological closing to bridge small gaps between fragments.
    if z_radius > 0 or xy_radius > 0:
        struct_close = build_anisotropic_struct(z_radius, xy_radius)
        if struct_close is not None:
            mask = ndi.binary_closing(mask, structure=struct_close)

    # 4. Fill background cavities fully enclosed by the mask.
    if fill_holes:
        mask = ndi.binary_fill_holes(mask)

    # 5. Remove small connected components.
    if dust_min_size > 0:
        mask = remove_small_objects(mask.astype(bool), min_size=dust_min_size)

    # 6. Border cleanup. The guard matters: with border == 0, `mask[-0:]`
    # would select, and clear, the entire volume.
    if border > 0:
        mask[:border, :, :] = 0
        mask[-border:, :, :] = 0
        mask[:, :border, :] = 0
        mask[:, -border:, :] = 0
        mask[:, :, :border] = 0
        mask[:, :, -border:] = 0

    return mask.astype(np.uint8)


def multi_threshold_ensemble_v6(probs, sobel_mask=None):
    """
    Majority vote over four hysteresis settings of topology_postprocess_v6.

    Each (T_low, T_high) pair below is run with hole filling enabled. A voxel
    is kept when at least 2 of the 4 resulting masks contain it.

    Returns:
        uint8 mask with values in {0, 1}.
    """
    # (T_low, T_high), from most permissive to most conservative.
    hysteresis_pairs = (
        (0.30, 0.70),  # Aggressive (high recall)
        (0.35, 0.75),  # Balanced
        (0.40, 0.80),  # Conservative (high precision)
        (0.45, 0.85),  # Very conservative
    )

    # Count votes incrementally instead of stacking all candidate masks;
    # at most 4 votes per voxel, so uint8 cannot overflow.
    votes = None
    for low, high in hysteresis_pairs:
        candidate = topology_postprocess_v6(
            probs, sobel_mask,
            T_low=low, T_high=high,
            fill_holes=True,
        )
        votes = candidate.astype(np.uint8) if votes is None else votes + candidate

    return (votes >= 2).astype(np.uint8)

# ============================================================================
# MAIN PREDICTION
# ============================================================================
def load_volume(path):
    """Read the TIFF stack at ``path`` and return it as a float32 array."""
    return tifffile.imread(path).astype(np.float32)

def predict_v6(volume):
    """
    Run the full V6 pipeline on one CT volume.

    The volume gets batch and channel axes and is rescaled by
    val_transformation. It is then passed through rotation/flip TTA
    (predict_probs_tta_8x). When USE_SOBEL_REFINEMENT is set, a Sobel surface
    mask is computed from the raw, unscaled volume and forwarded to the
    multi-threshold ensemble along with the probabilities.

    Returns:
        ``(final, probs_nn)``: the uint8 ensemble mask and the averaged
        class-1 probability map.
    """
    network_input = val_transformation(volume[None, ..., None])
    probs_nn = predict_probs_tta_8x(network_input)

    if USE_SOBEL_REFINEMENT:
        sobel_mask = sobel_surface_detection(
            volume,
            gradient_threshold=0.1,
            directional_threshold=0.2,
        )[0]
    else:
        sobel_mask = None

    return multi_threshold_ensemble_v6(probs_nn, sobel_mask), probs_nn

# ============================================================================
# INFERENCE
# ============================================================================
banner = "=" * 60
n_images = len(test_df)

print(f"\n{banner}")
print(f"V6 WINNER INFERENCE - {n_images} images")
print(banner)

with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
    for position, image_id in enumerate(test_df["id"], start=1):
        print(f"\n[{position}/{n_images}] {image_id}")

        volume = load_volume(f"{test_dir}/{image_id}.tif")
        print(f"    Volume shape: {volume.shape}")

        output, probs = predict_v6(volume)
        print(f"    Output shape: {output.shape}")
        print(f"    Voxels detected: {output.sum():,}")
        print(f"    Probability range: [{probs.min():.3f}, {probs.max():.3f}]")

        # Stage the mask as a TIFF on disk, store it in the archive under its
        # bare file name, then delete the staged copy.
        member = f"{image_id}.tif"
        staged = f"{output_dir}/{member}"
        tifffile.imwrite(staged, output.astype(np.uint8))
        archive.write(staged, arcname=member)
        os.remove(staged)

        # Drop per-volume arrays before loading the next volume.
        del volume, output, probs
        gc.collect()

print(f"\n{banner}")
print("V6 WINNER COMPLETE!")
print(f"Submission: {zip_path}")
print(banner)
