"""
VESUVIUS V5-QUICK v2 - TransUNet + Aggressive Topology Post-Processing
=======================================================================
Based on V4 MAXIMA + aggressive topology-aware post-processing.

NO MONAI needed - uses same TransUNet approach as V3/V4.
Target: Beat 0.575 with post-processing optimizations.
"""

import os
import gc
import sys
import zipfile
import numpy as np
from pathlib import Path

# Startup banner for the run log.
print("=" * 70)
print("VESUVIUS V5-QUICK v2 - TransUNet + Topology Post-Processing")
print("=" * 70)

# =============================================================================
# INSTALL REQUIRED PACKAGES (offline from wheels dataset)
# =============================================================================

# Shell fallback chain: first try the offline wheel from the attached dataset
# (installed with --no-deps), then an online `pip install imagecodecs`; the
# trailing `|| true` plus `2>/dev/null` make every failure silent, and the exit
# status returned by os.system is ignored, so the script always continues.
# NOTE(review): presumably imagecodecs is needed by tifffile to decode
# compressed test TIFFs — confirm.
print("Installing imagecodecs...")
os.system("pip install --no-deps /kaggle/input/wheels-for-vesuvius/imagecodecs-2025.11.11-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl -q 2>/dev/null || pip install imagecodecs -q 2>/dev/null || true")

# =============================================================================
# IMPORTS (no MONAI!)
# =============================================================================

import torch
import torch.nn.functional as F

# These are pre-installed on Kaggle
import tifffile
from tqdm import tqdm
from scipy import ndimage

# =============================================================================
# CONFIGURATION
# =============================================================================

class CFG:
    """Run-wide settings, read directly as class attributes (e.g. ``CFG.THRESHOLD``)."""

    # ---- filesystem layout ---------------------------------------------------
    INPUT_DIR = Path("/kaggle/input/vesuvius-challenge-surface-detection")
    TEST_DIR = INPUT_DIR.joinpath("test_images")
    OUTPUT_DIR = Path("/kaggle/working")

    # ---- checkpoint locations (primary dataset first, then a backup) --------
    MODEL_DATASET = Path("/kaggle/input/hideyukizushi/colab-a-162v5-gpu-transunet-seresnext101-x160")
    MODEL_BACKUP = Path("/kaggle/input/vesuvius-transunet-160")

    # ---- sliding-window inference -------------------------------------------
    PATCH_SIZE = (160,) * 3        # cubic (D, H, W) window fed to the model
    OVERLAP = 0.625                # fraction shared by neighbouring windows -> step int(160 * 0.375) = 60

    # ---- post-processing ----------------------------------------------------
    THRESHOLD = 0.5                # probability cut-off for foreground
    MIN_COMPONENT_SIZE = 4000      # connected components smaller than this (in voxels) are dropped
    CLOSING_RADIUS = 3             # iterations of 6-connected dilation, then the same number of erosions
    BORDER_CLEANUP = 5             # thickness (voxels) zeroed on every face of the volume

    # ---- test-time augmentation ---------------------------------------------
    TTA_ROTATIONS = 4              # 90-degree rotations in the (H, W) plane: 0, 90, 180, 270

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Make sure the output directory exists. exist_ok makes re-runs safe; parent
# directories are NOT created (no parents=True), so the parent must already exist.
CFG.OUTPUT_DIR.mkdir(exist_ok=True)

# Echo the key settings so they appear at the top of the run log.
print(f"Device: {CFG.DEVICE}")
print(f"Threshold: {CFG.THRESHOLD}")
print(f"Min Component: {CFG.MIN_COMPONENT_SIZE}")
print(f"Closing Radius: {CFG.CLOSING_RADIUS}")
print(f"TTA Rotations: {CFG.TTA_ROTATIONS}x")
print()

# =============================================================================
# TRANSUNET MODEL ARCHITECTURE
# =============================================================================

# Import segmentation_models_pytorch_3d if available, else fall back to SimpleUNet3D.
_SMP3D_PATH = "/kaggle/input/smp3d"
try:
    # The package may be attached as an input dataset rather than pip-installed,
    # so make that directory importable -- only once, even if this cell re-runs.
    if _SMP3D_PATH not in sys.path:
        sys.path.append(_SMP3D_PATH)
    import segmentation_models_pytorch_3d as smp3d
    HAS_SMP3D = True
    print("Using segmentation_models_pytorch_3d")
except Exception as exc:
    # Any import-time error (not only ImportError -- e.g. a dependency version
    # clash) should trigger the fallback, but unlike the former bare `except:`
    # this no longer swallows KeyboardInterrupt / SystemExit.
    HAS_SMP3D = False
    print("smp3d not available, using fallback model")
    print(f"  ({type(exc).__name__}: {exc})")


class DoubleConv3D(torch.nn.Module):
    """Two stacked (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) stages.

    ``padding=1`` keeps the spatial size unchanged; channels go in_ch -> out_ch.
    The layers live in ``self.conv`` (a Sequential, indices 0-5), which fixes
    the state-dict key names (``conv.0.weight``, ``conv.1.running_mean``, ...).
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv3d(in_ch, out_ch, 3, padding=1, bias=False),
            torch.nn.BatchNorm3d(out_ch),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv3d(out_ch, out_ch, 3, padding=1, bias=False),
            torch.nn.BatchNorm3d(out_ch),
            torch.nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class SimpleUNet3D(torch.nn.Module):
    """Plain 3D U-Net used as the fallback model when smp3d is unavailable.

    The encoder has ``len(features)`` levels; each level halves the spatial
    size with 2x2x2 max-pooling, and the decoder mirrors it with transposed
    convolutions plus skip-connection concatenation.

    ``forward`` returns RAW LOGITS of shape (N, out_channels, D, H, W). No
    activation is applied here, because the smp3d model in this script is built
    with ``activation=None`` and the inference code applies sigmoid/softmax
    itself. Applying sigmoid here as well would squash every probability into
    [0.5, 0.73] and make the 0.5 threshold select almost every voxel.

    Args:
        in_channels: input channels.
        out_channels: output channels (logit maps).
        features: channel width per encoder level (any non-empty sequence).

    Raises:
        ValueError: if ``features`` is empty.
    """
    def __init__(self, in_channels=1, out_channels=1, features=(32, 64, 128, 256)):
        super().__init__()
        # Tuple default instead of a shared mutable list; accept any sequence.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.downs = torch.nn.ModuleList()
        self.ups = torch.nn.ModuleList()
        self.pool = torch.nn.MaxPool3d(2, 2)

        # Encoder
        in_ch = in_channels
        for feature in features:
            self.downs.append(DoubleConv3D(in_ch, feature))
            in_ch = feature

        # Bottleneck
        self.bottleneck = DoubleConv3D(features[-1], features[-1] * 2)

        # Decoder: `ups` alternates [upsample, DoubleConv3D] per level.
        for feature in reversed(features):
            self.ups.append(torch.nn.ConvTranspose3d(feature * 2, feature, 2, 2))
            self.ups.append(DoubleConv3D(feature * 2, feature))

        self.final = torch.nn.Conv3d(features[0], out_channels, 1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip = skip_connections[idx // 2]
            # Odd spatial sizes make pool/upsample round-trips lose a voxel;
            # resize to the skip tensor so concatenation lines up.
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:])
            x = torch.cat([skip, x], dim=1)
            x = self.ups[idx + 1](x)

        # Logits only -- the caller applies the activation.
        return self.final(x)


def _find_checkpoint(model_dirs):
    """Load the first usable checkpoint from ``model_dirs`` (searched in order).

    In each existing directory the lexicographically last ``*.pth``/``*.ckpt``
    file (searched recursively) is tried; if it fails to load, the next
    directory is tried. Returns the loaded object, or ``None``.
    """
    for model_dir in model_dirs:
        if not model_dir.exists():
            continue
        ckpt_files = list(model_dir.rglob("*.pth")) + list(model_dir.rglob("*.ckpt"))
        if not ckpt_files:
            continue
        ckpt_path = max(ckpt_files)
        print(f"Found checkpoint: {ckpt_path}")
        try:
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only point this at trusted checkpoint files.
            checkpoint = torch.load(ckpt_path, map_location=CFG.DEVICE, weights_only=False)
        except Exception as e:
            print(f"Failed to load: {e}")
            continue
        print("Loaded successfully!")
        return checkpoint
    return None


def _extract_state_dict(checkpoint):
    """Return the parameter mapping stored in ``checkpoint``.

    Accepts ``{'state_dict': ...}``, ``{'model_state_dict': ...}``, a bare
    state dict, or a whole pickled ``nn.Module``.
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        if 'state_dict' in checkpoint:
            return checkpoint['state_dict']
        if 'model_state_dict' in checkpoint:
            return checkpoint['model_state_dict']
    return checkpoint


def _strip_wrapper_prefixes(key):
    """Remove LEADING wrapper prefixes (``module.``, ``model.``, ``net.``) from a key.

    Only leading prefixes are stripped (repeatedly, so ``model.module.x`` -> ``x``).
    Nested names such as ``encoder.model.conv1.weight`` are left intact. A
    substring replace would mangle them, and ``strict=False`` would then
    silently drop those weights.
    """
    prefixes = ("module.", "model.", "net.")
    stripped = True
    while stripped:
        stripped = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key


def load_model():
    """Build the segmentation model and load the best available checkpoint.

    Uses the smp3d TransUNet (seresnext101 encoder) when ``HAS_SMP3D`` is
    true, otherwise ``SimpleUNet3D``. Weights are applied with ``strict=False``,
    and the number of missing/unexpected keys is reported so partial loads are visible.

    Returns:
        The model moved to ``CFG.DEVICE`` and put in eval mode. If no checkpoint
        loads, or its weights cannot be applied, it keeps random initialization.
    """
    model_dirs = [
        CFG.MODEL_DATASET,
        CFG.MODEL_BACKUP,
        Path("/kaggle/input/vesuvius-surface-detection-3d-checkpoints"),
    ]
    checkpoint = _find_checkpoint(model_dirs)

    # Create model
    if HAS_SMP3D:
        model = smp3d.create_model(
            'transunet',
            encoder_name='tu-seresnext101_32x4d',
            encoder_weights=None,
            in_channels=1,
            classes=1,
            activation=None,
        )
    else:
        model = SimpleUNet3D(in_channels=1, out_channels=1)

    # Load weights if available
    if checkpoint is not None:
        try:
            state_dict = _extract_state_dict(checkpoint)
            cleaned = {_strip_wrapper_prefixes(k): v for k, v in state_dict.items()}
            result = model.load_state_dict(cleaned, strict=False)
            print("Weights loaded!")

            n_expected = len(model.state_dict())
            n_missing = len(result.missing_keys)
            n_unexpected = len(result.unexpected_keys)
            if n_missing or n_unexpected:
                print(f"  Key mismatch: {n_missing}/{n_expected} model keys missing, "
                      f"{n_unexpected} checkpoint keys unused")
            if n_expected and n_missing == n_expected:
                print("  WARNING: no checkpoint key matched the model - weights are effectively random")
        except Exception as e:
            print(f"Could not load weights: {e}")
            print("Using random initialization")

    model = model.to(CFG.DEVICE)
    model.eval()
    return model

# =============================================================================
# INFERENCE
# =============================================================================

def normalize_volume(volume):
    """Return a float32 copy of ``volume`` standardized to zero mean and unit std.

    A 1e-6 term is added to the standard deviation so that a constant volume
    maps to all zeros instead of dividing by zero.
    """
    arr = volume.astype(np.float32)
    mu, sigma = arr.mean(), arr.std()
    return (arr - mu) / (sigma + 1e-6)


def _window_starts(size, patch, step):
    """Return window start offsets along one axis that fully cover ``[0, size)``.

    Starts are ``0, step, 2*step, ...`` up to ``size - patch``. If that grid does
    not land exactly on ``size - patch``, one extra window flush with the far
    edge is appended so the last ``patch`` voxels are always covered.
    """
    last = max(size - patch, 0)
    starts = list(range(0, last + 1, step))
    if starts[-1] != last:
        starts.append(last)
    return starts


def sliding_window_inference(model, volume, patch_size, overlap, device):
    """Run ``model`` over ``volume`` in overlapping 3D windows and blend them.

    Each window's probabilities are weighted by a Gaussian centred on the
    window, and the weighted sum is divided by the total weight per voxel.

    Args:
        model: callable mapping a (1, 1, pD, pH, pW) tensor to logits of shape
            (1, C, pD, pH, pW); C == 1 -> sigmoid, C > 1 -> softmax and channel 1.
        volume: 3D array (D, H, W); standardized with ``normalize_volume`` here.
        patch_size: (pD, pH, pW) window size.
        overlap: fraction of a window shared with its neighbour (expected in [0, 1)).
        device: torch device (string or ``torch.device``) to run on.

    Returns:
        float32 numpy array of shape (D, H, W) with foreground probabilities.
    """
    D, H, W = volume.shape
    pD, pH, pW = patch_size

    # Step between windows; clamp to >= 1 so an overlap near 1.0 cannot give step 0.
    stride = [max(1, int(p * (1 - overlap))) for p in patch_size]

    # Gaussian importance map over one window.
    sigma = 0.125
    coords = [np.linspace(-1, 1, s) for s in patch_size]
    grid = np.meshgrid(*coords, indexing='ij')
    gaussian = np.exp(-(grid[0]**2 + grid[1]**2 + grid[2]**2) / (2 * sigma**2))
    gaussian = gaussian / gaussian.max()
    # With sigma=0.125 the weight at a window's faces/corners drops to ~1e-14 ..
    # 1e-42. Voxels on the volume boundary are reached only by such window
    # edges, so without a floor their normalised value collapsed towards 0.
    # 1e-20 is far below any interior weight (>= ~1e-6 here), so interior
    # blending is unaffected, and it stays a normal float32.
    min_weight = 1e-20
    gaussian = np.maximum(gaussian, min_weight)
    gaussian_tensor = torch.from_numpy(gaussian).float().to(device)

    # Normalize volume
    volume_norm = normalize_volume(volume)
    volume_tensor = torch.from_numpy(volume_norm).float().to(device)

    # Zero-pad (at the far end) any axis shorter than the window.
    pad_d = max(0, pD - D)
    pad_h = max(0, pH - H)
    pad_w = max(0, pW - W)
    if pad_d > 0 or pad_h > 0 or pad_w > 0:
        volume_tensor = F.pad(volume_tensor, (0, pad_w, 0, pad_h, 0, pad_d))
    Dp, Hp, Wp = volume_tensor.shape

    output = torch.zeros((Dp, Hp, Wp), device=device)
    weight = torch.zeros((Dp, Hp, Wp), device=device)

    # Every window position; the per-axis starts always reach the far edge.
    positions = [
        (d, h, w)
        for d in _window_starts(Dp, pD, stride[0])
        for h in _window_starts(Hp, pH, stride[1])
        for w in _window_starts(Wp, pW, stride[2])
    ]

    # Mixed precision only on CUDA; elsewhere run a disabled (no-op) CPU autocast.
    use_amp = torch.device(device).type == "cuda"
    amp_ctx = torch.amp.autocast("cuda" if use_amp else "cpu", enabled=use_amp)

    with torch.no_grad(), amp_ctx:
        for d, h, w in tqdm(positions, desc="  Patches", leave=False):
            patch = volume_tensor[d:d+pD, h:h+pH, w:w+pW][None, None]

            logits = model(patch)
            # Index explicitly (not squeeze) so size-1 dims are never dropped.
            if logits.shape[1] > 1:
                prob = torch.softmax(logits, dim=1)[0, 1]
            else:
                prob = torch.sigmoid(logits)[0, 0]

            # prob may be fp16 under autocast; the product promotes to fp32.
            output[d:d+pD, h:h+pH, w:w+pW] += prob * gaussian_tensor
            weight[d:d+pD, h:h+pH, w:w+pW] += gaussian_tensor

    # Weighted average. Every voxel is covered with weight >= min_weight, so the
    # clamp only guards against division by zero.
    output = output / weight.clamp_min(min_weight)

    # Remove padding
    return output[:D, :H, :W].cpu().numpy()


def predict_with_tta(model, volume, patch_size, overlap, device, num_rotations=4):
    """Average sliding-window predictions over 90-degree rotations in axes (1, 2).

    Rotation ``k`` (k = 0 .. num_rotations-1) rotates the volume by ``k`` quarter
    turns, predicts, and rotates the prediction back. The results are averaged.
    A running sum is kept instead of a list, so peak memory is about two
    prediction volumes. Previously ``np.mean(list)`` also allocated a stacked
    copy of all of them.

    Args:
        model, patch_size, overlap, device: forwarded to ``sliding_window_inference``.
        volume: 3D array (D, H, W).
        num_rotations: number of rotations to average (must be >= 1).

    Returns:
        float32 array (D, H, W) of averaged probabilities.

    Raises:
        ValueError: if ``num_rotations < 1``. The original returned NaN with a warning.
    """
    if num_rotations < 1:
        raise ValueError(f"num_rotations must be >= 1, got {num_rotations}")

    total = None
    for rot in range(num_rotations):
        print(f"  TTA rotation {rot+1}/{num_rotations}")

        # Rotate volume
        vol_rot = np.ascontiguousarray(np.rot90(volume, k=rot, axes=(1, 2)))

        # Predict
        pred = sliding_window_inference(model, vol_rot, patch_size, overlap, device)

        # Undo the rotation so every prediction is in the original orientation.
        pred = np.rot90(pred, k=-rot, axes=(1, 2))

        if total is None:
            # Own C-contiguous float32 accumulator (pred is a rotated view).
            total = np.array(pred, dtype=np.float32, order="C")
        else:
            total += pred

        # Free memory
        del vol_rot, pred
        gc.collect()
        torch.cuda.empty_cache()

    total /= num_rotations
    return total

# =============================================================================
# POST-PROCESSING (CRITICAL FOR TOPOLOGY SCORE!)
# =============================================================================

def topology_postprocess(prediction, threshold=None, closing_radius=None,
                         min_component_size=None, border_cleanup=None):
    """
    Turn a probability volume into a cleaned binary mask.

    Steps:
      1. threshold (``prediction > threshold``);
      2. morphological closing with a 6-connected structure, ``closing_radius``
         iterations (dilation then erosion; erosion treats outside-volume as 0);
      3. drop 26-connected components smaller than ``min_component_size`` voxels;
      4. zero a ``border_cleanup``-voxel shell on every face.

    Each keyword defaults (when ``None``) to the matching ``CFG`` value, so
    existing calls ``topology_postprocess(prediction)`` behave as before.

    Returns:
        uint8 array (same shape as ``prediction``) of 0/1 values.
    """
    threshold = CFG.THRESHOLD if threshold is None else threshold
    closing_radius = CFG.CLOSING_RADIUS if closing_radius is None else closing_radius
    min_component_size = (CFG.MIN_COMPONENT_SIZE if min_component_size is None
                          else min_component_size)
    border_cleanup = CFG.BORDER_CLEANUP if border_cleanup is None else border_cleanup

    print(f"  Raw prediction range: [{prediction.min():.3f}, {prediction.max():.3f}]")

    # 1. Thresholding
    binary = (prediction > threshold).astype(np.uint8)
    print(f"  After threshold ({threshold}): {binary.sum():,} voxels")

    # 2. Morphological closing (fill small gaps for better surface connectivity).
    # Equivalent to `closing_radius` single dilations followed by as many erosions.
    if closing_radius > 0:
        struct = ndimage.generate_binary_structure(3, 1)
        binary = ndimage.binary_closing(
            binary, structure=struct, iterations=closing_radius
        ).astype(np.uint8)
        print(f"  After closing (r={closing_radius}): {binary.sum():,} voxels")

    # 3. Connected-components filtering (remove noise), 26-connectivity.
    structure = ndimage.generate_binary_structure(3, 3)
    labeled, num_components = ndimage.label(binary, structure=structure)

    if num_components > 0:
        # Voxel count per label, then an O(N) lookup table instead of one full
        # scan of the volume per kept component.
        sizes = np.bincount(labeled.ravel(), minlength=num_components + 1)
        keep = sizes >= min_component_size
        keep[0] = False  # label 0 is background
        binary = keep[labeled].astype(np.uint8)

        kept = int(keep.sum())
        removed = num_components - kept
        print(f"  CC Filter: {num_components} -> {kept} ({removed} removed)")

    print(f"  After CC filter (min={min_component_size}): {binary.sum():,} voxels")

    # 4. Border cleanup (remove edge artifacts)
    if border_cleanup > 0:
        b = border_cleanup
        binary[:b, :, :] = 0
        binary[-b:, :, :] = 0
        binary[:, :b, :] = 0
        binary[:, -b:, :] = 0
        binary[:, :, :b] = 0
        binary[:, :, -b:] = 0
        print(f"  After border cleanup ({b}px): {binary.sum():,} voxels")

    return binary.astype(np.uint8)

# =============================================================================
# MAIN
# =============================================================================

def _run_one_volume(model, test_path, predictions_dir):
    """Predict, post-process and save the binary mask for a single test TIFF."""
    print(f"{'='*60}")
    print(f"Processing: {test_path.name}")

    volume = tifffile.imread(test_path)
    print(f"  Shape: {volume.shape}")

    # Rotation-TTA sliding-window probabilities.
    probs = predict_with_tta(
        model,
        volume,
        CFG.PATCH_SIZE,
        CFG.OVERLAP,
        CFG.DEVICE,
        CFG.TTA_ROTATIONS,
    )

    print("  Post-processing...")
    mask = topology_postprocess(probs)

    target = predictions_dir / test_path.name
    tifffile.imwrite(target, mask)
    print(f"  Saved: {target} ({mask.sum():,} voxels)")


def _write_submission(predictions_dir, submission_path):
    """Zip every ``*.tif`` in ``predictions_dir`` (stored flat, by file name)."""
    with zipfile.ZipFile(submission_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for tif_path in predictions_dir.glob("*.tif"):
            archive.write(tif_path, arcname=tif_path.name)


def main():
    """Run inference on every test volume and package the masks as submission.zip."""
    print("\nLoading model...")
    model = load_model()

    test_files = sorted(CFG.TEST_DIR.glob("*.tif"))
    print(f"Found {len(test_files)} test files\n")

    predictions_dir = CFG.OUTPUT_DIR / "predictions"
    predictions_dir.mkdir(exist_ok=True)

    for test_path in test_files:
        _run_one_volume(model, test_path, predictions_dir)
        # The helper's large locals (volume, probabilities, mask) were released
        # on return; reclaim host and GPU memory before the next volume.
        gc.collect()
        torch.cuda.empty_cache()

    print(f"\n{'='*60}")
    print("Creating submission...")
    submission_path = CFG.OUTPUT_DIR / "submission.zip"
    _write_submission(predictions_dir, submission_path)

    print("\nV5-QUICK COMPLETE!")
    print(f"Submission: {submission_path}")
    size_mb = submission_path.stat().st_size / 1024 / 1024
    print(f"Size: {size_mb:.1f} MB")


if __name__ == "__main__":
    main()
