"""
VESUVIUS V5-QUICK - Pretrained Model + Aggressive Post-Processing
===================================================================
Quick submission while V5-WINNER trains.

Uses a pretrained SegResNet checkpoint (with optional rotation/flip TTA) + topology-aware post-processing.
Target: Competitive score quickly (>0.56)
"""

import gc
import os
import subprocess
import sys
import zipfile
from pathlib import Path

import numpy as np

# =============================================================================
# CONFIGURATION
# =============================================================================

class CFG:
    """Global settings for this inference script: paths, model, post-processing, TTA."""

    # ---- Paths ----
    INPUT_DIR = Path("/kaggle/input/vesuvius-challenge-surface-detection")
    TEST_DIR = INPUT_DIR.joinpath("test_images")
    OUTPUT_DIR = Path("/kaggle/working")

    # ---- Model ----
    # Directory searched for *.ckpt (then *.pth) checkpoint files.
    MODEL_PATH = Path("/kaggle/input/vesuvius-surface-detection-3d-checkpoints/pytorch/default/0")
    # Every volume is resampled to this cubic shape before it is fed to the model.
    MODEL_SIZE = (160,) * 3

    # ---- Post-processing (aggressive, aimed at the topology score) ----
    THRESHOLD = 0.6            # probability cut-off; higher keeps fewer, more confident voxels
    MIN_COMPONENT_SIZE = 5000  # connected components with fewer voxels are dropped
    CLOSING_RADIUS = 4         # dilation/erosion steps used to bridge small gaps
    BORDER_CLEANUP = 8         # voxels zeroed along every face of the output volume

    # ---- Test-time augmentation ----
    USE_TTA = True
    TTA_ROTATIONS = list(range(4))  # quarter turns: 0, 90, 180, 270 degrees
    TTA_FLIPS = [False, True]       # without / with a mirror flip

    DEVICE = "cuda"

CFG.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# =============================================================================
# INSTALLATION
# =============================================================================

print("Installing packages...")
# Install through the running interpreter's pip so the packages land in the
# same environment that imports them below (a bare `pip` on PATH may differ).
_pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "monai", "tifffile", "scipy", "scikit-image", "-q"],
    check=False,
)
if _pip_result.returncode != 0:
    # Best-effort install: the packages may already be present (e.g. on an
    # offline kernel), so report the failure but keep going.
    print(f"WARNING: pip install exited with code {_pip_result.returncode}; "
          "continuing with already-installed packages")

import torch
import torch.nn.functional as F
import tifffile
from tqdm import tqdm
from scipy import ndimage

# =============================================================================
# MODEL
# =============================================================================

def load_model():
    """Build the SegResNet and load pretrained weights into it.

    Checkpoints are searched in ``CFG.MODEL_PATH``: ``*.ckpt`` files first,
    ``*.pth`` only if there are none. The path that sorts last as a string is
    loaded. Weights are loaded with ``strict=False``, which silently skips keys
    that do not match the architecture, so missing/unexpected keys are
    reported here; otherwise a mismatched checkpoint would quietly leave the
    network at its random initialisation.

    Returns:
        The model moved to ``CFG.DEVICE`` and switched to eval mode.
    """
    from monai.networks.nets import SegResNet

    model = SegResNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        init_filters=16,
        dropout_prob=0.2
    )

    # Find checkpoint candidates (*.ckpt preferred, *.pth as fallback).
    ckpt_files = list(CFG.MODEL_PATH.glob("*.ckpt"))
    if not ckpt_files:
        ckpt_files = list(CFG.MODEL_PATH.glob("*.pth"))

    if not ckpt_files:
        print(f"WARNING: no *.ckpt/*.pth checkpoint in {CFG.MODEL_PATH}; "
              "model keeps its random initialisation!")
    else:
        # NOTE: plain string order of the full path, not a parsed val_dice score;
        # name checkpoints so the preferred one sorts last.
        best_ckpt = max(ckpt_files, key=str)
        print(f"Loading checkpoint: {best_ckpt}")

        # weights_only=False unpickles arbitrary objects - only load trusted checkpoints.
        checkpoint = torch.load(best_ckpt, map_location=CFG.DEVICE, weights_only=False)

        # Handle different checkpoint formats
        if isinstance(checkpoint, torch.nn.Module):
            # A whole pickled module: use its weights.
            state_dict = checkpoint.state_dict()
        elif 'state_dict' in checkpoint:
            # Wrapped checkpoint: strip the 'net_module.' prefix from keys.
            state_dict = {k.replace('net_module.', ''): v
                          for k, v in checkpoint['state_dict'].items()}
        elif 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        incompatible = model.load_state_dict(state_dict, strict=False)
        n_missing = len(incompatible.missing_keys)
        n_unexpected = len(incompatible.unexpected_keys)
        if n_missing or n_unexpected:
            print(f"WARNING: {n_missing} missing / {n_unexpected} unexpected keys "
                  f"when loading {best_ckpt.name} (strict=False)")
        if n_missing and n_missing == len(model.state_dict()):
            print("WARNING: no checkpoint weights matched the model - predictions will be random!")

    model = model.to(CFG.DEVICE)
    model.eval()
    return model

# =============================================================================
# PREPROCESSING
# =============================================================================

def normalize_volume(volume):
    """Return a float32 z-scored copy of ``volume``.

    The global mean is subtracted and the result divided by the global standard
    deviation plus 1e-6, so a constant volume maps to all zeros instead of NaN.
    """
    as_float = volume.astype(np.float32)
    centered = as_float - as_float.mean()
    return centered / (as_float.std() + 1e-6)


def resize_volume(volume, target_size):
    """Resample a 3D NumPy array to ``target_size`` via trilinear interpolation.

    The array is cast to float32 and given leading batch and channel axes so
    ``F.interpolate`` accepts it; those axes are dropped again on return.
    """
    batched = torch.from_numpy(volume).float()[None, None]
    scaled = F.interpolate(batched, size=target_size, mode="trilinear", align_corners=False)
    return scaled[0, 0].numpy()

# =============================================================================
# TTA (Test Time Augmentation)
# =============================================================================

def apply_tta_transform(volume, rotation, flip):
    """Apply one TTA view: rotate in the (1, 2) plane, then optionally mirror axis 2.

    ``rotation`` is a count of 90-degree turns; values <= 0 leave the
    orientation untouched. The result is always C-contiguous, so the
    negative strides produced by ``np.flip``/``np.rot90`` never reach
    ``torch.from_numpy``.
    """
    rotated = np.rot90(volume, k=rotation, axes=(1, 2)) if rotation > 0 else volume
    mirrored = np.flip(rotated, axis=2) if flip else rotated
    return np.ascontiguousarray(mirrored)


def reverse_tta_transform(volume, rotation, flip):
    """Invert ``apply_tta_transform`` for the same ``rotation``/``flip`` pair.

    The forward transform rotates first and mirrors second, so the inverse
    mirrors axis 2 first and then rotates by the complementary
    ``4 - rotation`` quarter turns in the (1, 2) plane. The result is always
    C-contiguous.
    """
    restored = np.flip(volume, axis=2) if flip else volume
    if rotation > 0:
        restored = np.rot90(restored, k=4 - rotation, axes=(1, 2))
    return np.ascontiguousarray(restored)

# =============================================================================
# INFERENCE
# =============================================================================

@torch.no_grad()
def predict_volume(model, volume):
    """Return a per-voxel surface probability map with the same shape as ``volume``.

    The volume is z-scored and resampled to ``CFG.MODEL_SIZE``. It is then run
    through ``model`` once per TTA view: every (rotation, flip) pair when
    ``CFG.USE_TTA`` is set, otherwise only the identity view. For each view,
    the class-1 softmax channel is mapped back to the canonical orientation.
    The views are averaged, and the average is resampled to the input's shape.
    """
    target_shape = volume.shape
    model_input = resize_volume(normalize_volume(volume), CFG.MODEL_SIZE)

    if CFG.USE_TTA:
        views = [(rot, flipped) for rot in CFG.TTA_ROTATIONS for flipped in CFG.TTA_FLIPS]
    else:
        views = [(0, False)]

    surface_maps = []
    for rot, flipped in views:
        augmented = apply_tta_transform(model_input, rot, flipped)
        batch = torch.from_numpy(augmented).float()[None, None].to(CFG.DEVICE)
        class_probs = torch.softmax(model(batch), dim=1)
        surface = class_probs[0, 1].cpu().numpy()
        surface_maps.append(reverse_tta_transform(surface, rot, flipped))
        # Release device tensors before the next view.
        del batch, class_probs

    averaged = np.mean(surface_maps, axis=0)
    return resize_volume(averaged, target_shape)

# =============================================================================
# POST-PROCESSING (CRITICAL!)
# =============================================================================

def topology_aware_postprocess(prediction, *, threshold=None, closing_radius=None,
                               min_component_size=None, border_cleanup=None):
    """
    Aggressive topology-aware post-processing.

    Converts a probability volume into a 0/1 mask in four steps:
    threshold -> morphological closing -> connected-component size filter
    -> zeroing a margin along every face of the volume.

    The metric rewards:
    - Voxel accuracy
    - Surface connectivity (no gaps, holes, mergers)

    Args:
        prediction: 3D array of per-voxel foreground probabilities.
        threshold: voxels strictly above this become foreground
            (default ``CFG.THRESHOLD``).
        closing_radius: number of 6-connected dilation steps followed by the
            same number of erosion steps; 0 disables closing
            (default ``CFG.CLOSING_RADIUS``).
        min_component_size: 26-connected components with fewer voxels are
            removed (default ``CFG.MIN_COMPONENT_SIZE``).
        border_cleanup: width in voxels of the margin zeroed on each face;
            0 disables it (default ``CFG.BORDER_CLEANUP``).

    Returns:
        uint8 array with the same shape as ``prediction`` containing 0/1.
    """
    # Unset overrides fall back to the global config, so callers that pass
    # only `prediction` keep the original behaviour.
    if threshold is None:
        threshold = CFG.THRESHOLD
    if closing_radius is None:
        closing_radius = CFG.CLOSING_RADIUS
    if min_component_size is None:
        min_component_size = CFG.MIN_COMPONENT_SIZE
    if border_cleanup is None:
        border_cleanup = CFG.BORDER_CLEANUP

    print(f"  Input voxels: {(prediction > 0).sum():,}")

    # 1. Thresholding
    binary = (prediction > threshold).astype(np.uint8)
    print(f"  After threshold ({threshold}): {binary.sum():,}")

    # 2. Morphological closing (fill gaps): dilate r times, then erode r times.
    if closing_radius > 0:
        struct = ndimage.generate_binary_structure(3, 1)  # 6-connected cross
        # `iterations=` repeats the single-step operation, same as a manual loop.
        binary = ndimage.binary_dilation(binary, struct, iterations=closing_radius)
        binary = ndimage.binary_erosion(binary, struct, iterations=closing_radius).astype(np.uint8)
        print(f"  After closing (r={closing_radius}): {binary.sum():,}")

    # 3. Connected-component filtering
    structure = ndimage.generate_binary_structure(3, 3)  # 26-connectivity
    labeled, num_components = ndimage.label(binary, structure=structure)

    if num_components > 0:
        # Voxel count per label in a single pass (index 0 is background).
        sizes = np.bincount(labeled.ravel(), minlength=num_components + 1)
        keep = sizes >= min_component_size
        keep[0] = False  # background is never turned into foreground
        kept = int(np.count_nonzero(keep))
        # Indexing the per-label lookup table with the label volume filters all
        # components at once, instead of one full-volume comparison per label.
        binary = keep[labeled].astype(np.uint8)

        removed = num_components - kept
        print(f"  CC Filter: {num_components} -> {kept} ({removed} removed)")

    print(f"  After CC filter (min={min_component_size}): {binary.sum():,}")

    # 4. Border Cleanup
    if border_cleanup > 0:
        b = border_cleanup
        binary[:b, :, :] = 0
        binary[-b:, :, :] = 0
        binary[:, :b, :] = 0
        binary[:, -b:, :] = 0
        binary[:, :, :b] = 0
        binary[:, :, -b:] = 0
        print(f"  After border cleanup ({b}px): {binary.sum():,}")

    return binary.astype(np.uint8)

# =============================================================================
# MAIN
# =============================================================================

def main():
    """Run inference on every test volume and package the masks into submission.zip.

    For each ``*.tif`` in ``CFG.TEST_DIR`` the volume is predicted (with TTA)
    and post-processed. The binary mask is written under
    ``CFG.OUTPUT_DIR / "predictions"`` with the same file name. Only the masks
    written by this run are zipped, so stale files left in the predictions
    directory by an earlier run cannot leak into the submission.
    """
    print("=" * 70)
    print("VESUVIUS V5-QUICK - Pretrained + Aggressive Post-Processing")
    print("=" * 70)
    print(f"Threshold: {CFG.THRESHOLD}")
    print(f"Min Component: {CFG.MIN_COMPONENT_SIZE}")
    print(f"Closing Radius: {CFG.CLOSING_RADIUS}")
    # Keep the conditional inside the message; a conditional around the whole
    # f-string would print a bare "Disabled" without the "TTA: " label.
    tta_desc = f"{len(CFG.TTA_ROTATIONS) * len(CFG.TTA_FLIPS)}x" if CFG.USE_TTA else "Disabled"
    print(f"TTA: {tta_desc}")
    print()

    # Load model
    print("Loading model...")
    model = load_model()

    # Get test files
    test_files = sorted(CFG.TEST_DIR.glob("*.tif"))
    print(f"Found {len(test_files)} test files")
    if not test_files:
        print(f"WARNING: no .tif files found in {CFG.TEST_DIR}; the submission will be empty")

    # Process each volume
    predictions_dir = CFG.OUTPUT_DIR / "predictions"
    predictions_dir.mkdir(parents=True, exist_ok=True)

    written_files = []
    for test_path in tqdm(test_files, desc="Processing"):
        print(f"\n{test_path.name}:")

        # Load volume
        volume = tifffile.imread(test_path)
        print(f"  Shape: {volume.shape}")

        # Predict with TTA
        pred_probs = predict_volume(model, volume)

        # Post-process
        pred_binary = topology_aware_postprocess(pred_probs)

        # Save, and remember exactly which files this run produced.
        output_path = predictions_dir / test_path.name
        tifffile.imwrite(output_path, pred_binary)
        written_files.append(output_path)

        # Free memory
        del volume, pred_probs, pred_binary
        gc.collect()
        torch.cuda.empty_cache()

    # Create submission
    print("\nCreating submission...")
    submission_path = CFG.OUTPUT_DIR / "submission.zip"

    with zipfile.ZipFile(submission_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for pred_file in written_files:
            zipf.write(pred_file, pred_file.name)

    print(f"\nDONE! Submission: {submission_path}")
    print(f"Size: {submission_path.stat().st_size / 1024 / 1024:.1f} MB")

if __name__ == "__main__":
    main()
