# Training Notebooks

- [Vesuvius Surface 3D Detection in Keras-JAX](https://www.kaggle.com/code/ipythonx/vesuvius-surface-3d-detection-in-jax)
- [Vesuvius Surface 3D Detection in PyTorch](https://www.kaggle.com/code/ipythonx/vesuvius-surface-3d-detection-in-pytorch)
- [Vesuvius Surface 3D Detection in PyTorch Lightning](https://www.kaggle.com/code/ipythonx/train-vesuvius-surface-3d-detection-in-lightning)
- [[WIP] Vesuvius Surface 2.5D Detection](https://www.kaggle.com/code/ipythonx/wip-vesuvius-surface-2-5d-detection)

**Note**
1. The inference code below is adapted from the **Keras-JAX** version; the PyTorch and Lightning implementations follow the same workflow. Training was performed on a single Tesla T4 (16 GB VRAM) for an extended number of epochs.
2. Both the training and inference pipelines are built with [`medicai`](https://github.com/innat/medic-ai), a **Keras 3**-based, multi-backend medical ML library for 2D and 3D classification and segmentation. Note that `medicai` is still a young project and is actively evolving.

# Inference

```python
# Install pinned wheels offline from the attached dataset (--no-index disables PyPI).
var="/kaggle/input/vsdetection-packages-offline-installer-only/whls"
!pip install \
    "$var"/keras_nightly-3.12.0.dev2025100703-py3-none-any.whl \
    "$var"/tifffile-2025.12.12-py3-none-any.whl \
    "$var"/imagecodecs-2025.11.11-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl \
    "$var"/medicai-0.0.3-py3-none-any.whl \
    --no-index \
    --find-links "$var"
```

```python
import os
# The Keras backend must be selected before Keras is imported.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from medicai.transforms import (
    Compose,
    ScaleIntensityRange,
)
from medicai.models import SegFormer, TransUNet
from medicai.utils.inference import SlidingWindowInference

import numpy as np
import pandas as pd
import zipfile
import tifffile
from matplotlib import pyplot as plt

# Confirm the active backend and Keras version.
keras.config.backend(), keras.version()
```

**Dataset**

```python
root_dir = "/kaggle/input/vesuvius-challenge-surface-detection"
test_dir = f"{root_dir}/test_images"
output_dir = "/kaggle/working/submission_masks"
zip_path = "/kaggle/working/submission.zip"
os.makedirs(output_dir, exist_ok=True)
```

```python
test_df = pd.read_csv(f"{root_dir}/test.csv")
test_df.head()
```
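Before starting a long inference loop, it can help to confirm that every `id` listed in `test.csv` has a matching TIFF in `test_images`. The helper below is a hypothetical sketch (not part of `medicai`) that assumes the `{id}.tif` naming used by the prediction loop in this notebook.

```python
import os

def find_missing_volumes(ids, test_dir):
    """Return the ids whose '{id}.tif' file is absent from test_dir (hypothetical helper)."""
    return [i for i in ids if not os.path.isfile(os.path.join(test_dir, f"{i}.tif"))]
```

In the notebook, `find_missing_volumes(test_df["id"], test_dir)` should return an empty list when all test volumes are present.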

**Transformation**

```python
def val_transformation(image):
    # Linearly rescale raw intensities from [0, 255] to [0, 1], clipping out-of-range values.
    data = {"image": image}
    pipeline = Compose([
        ScaleIntensityRange(
            keys=["image"],
            a_min=0,
            a_max=255,
            b_min=0,
            b_max=1,
            clip=True,
        ),
    ])
    result = pipeline(data)
    return result["image"]
```
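For reference, `ScaleIntensityRange` is assumed here to follow the usual MONAI-style definition: a linear map from `[a_min, a_max]` onto `[b_min, b_max]`, followed by optional clipping to the target range. The NumPy sketch below reproduces that behaviour for the parameters used above; the function name `scale_intensity_range` is hypothetical and not part of `medicai`.

```python
import numpy as np

def scale_intensity_range(x, a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True):
    """Hypothetical NumPy equivalent of ScaleIntensityRange (MONAI-style semantics assumed)."""
    x = np.asarray(x, dtype=np.float32)
    scaled = (x - a_min) / (a_max - a_min)     # map [a_min, a_max] onto [0, 1]
    scaled = scaled * (b_max - b_min) + b_min  # stretch onto [b_min, b_max]
    if clip:
        scaled = np.clip(scaled, b_min, b_max)  # clamp values that fell outside the input range
    return scaled.astype(np.float32)

print(scale_intensity_range([0.0, 51.0, 255.0, 300.0]))
```

With `clip=True`, the out-of-range input `300` is clamped to `1.0`, while `51` maps to `51 / 255 = 0.2`.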

**Model**

```python
# Root of the attached Kaggle model (no trailing slash, so f-string joins don't produce "//").
kaggle_model_path = "/kaggle/input/vsd-model/keras"
```

```python
def get_model():
    ## LB: 0.486
    # model = SegFormer(
    #     input_shape=(128, 128, 128, 1),
    #     encoder_name='mit_b2',
    #     classifier_activation='softmax',
    #     num_classes=2,
    # )
    # model.load_weights(
    #     "/kaggle/input/vsd-model/keras/segformer.mit.b2/2/segformer.mit.b2.weights.h5"
    # )

    ## LB: 0.5
    # model = TransUNet(
    #     input_shape=(128, 128, 128, 1),
    #     encoder_name='seresnext50',
    #     classifier_activation='softmax',
    #     num_classes=2,
    # )
    # model.load_weights(
    #     f"{kaggle_model_path}/transunet/2/transunet.seresnext50.128px.weights.h5"
    # )

    ## LB: 0.505
    model = TransUNet(
        input_shape=(160, 160, 160, 1),
        encoder_name='seresnext50',
        classifier_activation='softmax',
        num_classes=3,
    )
    model.load_weights(
        f"{kaggle_model_path}/transunet/2/transunet.seresnext50.160px.weights.h5"
    )
    return model
```

```python
model = get_model()
model.count_params() / 1e6
```

```python
model.instance_describe()
```

**Sliding Window Inference**

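```python
swi = SlidingWindowInference(
    model,
    num_classes=3,             # must match the model's num_classes
    roi_size=(160, 160, 160),  # window size; matches the model's input_shape
    sw_batch_size=1,           # windows per forward pass
    mode='gaussian',           # Gaussian-weighted blending of overlapping windows
    overlap=0.5,               # 50% overlap between neighbouring windows
)
```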
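The sketch below illustrates what Gaussian-mode sliding-window inference typically does: tile the volume with overlapping `roi_size` windows, predict each window, and average the overlapping predictions with a weight that peaks at the window centre, so voxels near window borders contribute less. It is a NumPy illustration only, assuming MONAI-style behaviour (stride `roi * (1 - overlap)`, last window clamped to the border, Gaussian `sigma = 0.125 * roi`); `medicai`'s `SlidingWindowInference` may differ in details, and this sketch does not pad volumes smaller than the window. All function names are hypothetical.

```python
import math
import numpy as np

def window_starts(dim, roi, overlap):
    """Start indices of windows along one axis (MONAI-style grid, assumed)."""
    if dim <= roi:
        return [0]
    stride = max(int(roi * (1 - overlap)), 1)
    n = math.ceil((dim - roi) / stride) + 1
    # Clamp the last window so it ends exactly at the volume border.
    return [min(i * stride, dim - roi) for i in range(n)]

def gaussian_importance_map(roi, sigma_scale=0.125):
    """Separable 3D Gaussian weight peaking at the window centre, normalised to max 1."""
    axes = []
    for size in roi:
        center = (size - 1) / 2
        sigma = size * sigma_scale
        coords = np.arange(size)
        axes.append(np.exp(-0.5 * ((coords - center) / sigma) ** 2))
    w = axes[0][:, None, None] * axes[1][None, :, None] * axes[2][None, None, :]
    return (w / w.max()).astype(np.float32)

def sliding_window_blend(volume, roi, overlap, predict_patch, num_classes):
    """Weighted average of per-window predictions; assumes every axis is >= roi."""
    shape = volume.shape
    acc = np.zeros(shape + (num_classes,), dtype=np.float32)  # weighted sum of predictions
    norm = np.zeros(shape + (1,), dtype=np.float32)           # sum of weights per voxel
    weight = gaussian_importance_map(roi)[..., None]
    for z in window_starts(shape[0], roi[0], overlap):
        for y in window_starts(shape[1], roi[1], overlap):
            for x in window_starts(shape[2], roi[2], overlap):
                sl = (slice(z, z + roi[0]), slice(y, y + roi[1]), slice(x, x + roi[2]))
                probs = predict_patch(volume[sl])  # expected shape: roi + (num_classes,)
                acc[sl] += probs * weight
                norm[sl] += weight
    return acc / norm
```

With `roi_size=(160, 160, 160)` and `overlap=0.5`, the stride is 80 voxels, so a hypothetical 320-voxel axis gives `window_starts(320, 160, 0.5) == [0, 80, 160]`: 3 windows per axis, or 27 forward passes for a 320³ volume at `sw_batch_size=1`.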

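```python
def load_volume(path):
    # Read the TIFF stack as (D, H, W) and add batch + channel axes -> (1, D, H, W, 1).
    vol = tifffile.imread(path)
    vol = vol.astype(np.float32)
    vol = vol[None, ..., None]
    return vol

def predict(sample):
    # Per-voxel class scores from sliding-window inference, channels-last.
    mask = swi(sample)
    # Most likely class per voxel; drop only the batch axis -> (D, H, W) uint8.
    output = np.asarray(mask).argmax(-1).astype(np.uint8).squeeze(0)
    return output
```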
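The shape bookkeeping in `load_volume` and `predict` can be checked without the model. The sketch below uses random logits as a stand-in for the model output and assumes channels-last `(1, D, H, W, C)` predictions, as the `argmax(-1)` above implies. It also shows that the argmax of softmax probabilities equals the argmax of the raw logits (softmax is monotonic), and why `squeeze(0)` is safer than a bare `squeeze()` when a spatial axis has size 1. The helper names are hypothetical.

```python
import numpy as np

def to_model_input(vol):
    """(D, H, W) -> (1, D, H, W, 1): add batch and channel axes (channels-last)."""
    return np.asarray(vol, dtype=np.float32)[None, ..., None]

def softmax(logits, axis=-1):
    """Numerically stable softmax, standing in for classifier_activation='softmax'."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def to_label_mask(probs):
    """(1, D, H, W, C) -> (D, H, W) uint8 class indices; only the batch axis is dropped."""
    return np.asarray(probs).argmax(-1).astype(np.uint8).squeeze(0)

# Stand-in for the model output on a tiny 4x5x6 volume with 3 classes.
rng = np.random.default_rng(0)
x = to_model_input(np.zeros((4, 5, 6), dtype=np.uint8))
logits = rng.normal(size=x.shape[:-1] + (3,)).astype(np.float32)
mask = to_label_mask(softmax(logits))
print(x.shape, mask.shape, mask.dtype)  # (1, 4, 5, 6, 1) (4, 5, 6) uint8

# A volume with a singleton depth axis: squeeze(0) keeps it, a bare squeeze() would not.
thin = softmax(rng.normal(size=(1, 1, 5, 6, 3)))
print(to_label_mask(thin).shape, thin.argmax(-1).squeeze().shape)  # (1, 5, 6) (5, 6)
```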

**Prediction and Zip Submission**

```python
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as z:
    for image_id in test_df["id"]:
        tif_path = f"{test_dir}/{image_id}.tif"

        # Load -> rescale to [0, 1] -> sliding-window prediction (uint8 mask).
        volume = load_volume(tif_path)
        volume = val_transformation(volume)
        output = predict(volume)

        # Write the mask to disk, add it to the archive, then delete the temporary file.
        out_path = f"{output_dir}/{image_id}.tif"
        tifffile.imwrite(out_path, output)
        z.write(out_path, arcname=f"{image_id}.tif")
        os.remove(out_path)

print("Submission ZIP:", zip_path)
```
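As a final sanity check, the archive can be reopened and its member names compared against the ids in `test.csv`. This hypothetical helper assumes the `{id}.tif` naming used when writing the archive above; it only checks file names, not mask contents or shapes.

```python
import zipfile

def check_submission_zip(zip_path, expected_ids):
    """Return (missing, extra) member names relative to the expected '{id}.tif' set."""
    with zipfile.ZipFile(zip_path) as z:
        names = set(z.namelist())
    expected = {f"{i}.tif" for i in expected_ids}
    return sorted(expected - names), sorted(names - expected)
```

In the notebook, `check_submission_zip(zip_path, test_df["id"])` should return two empty lists when every test volume produced exactly one mask.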