"""
Vesuvius Challenge - Post-processing
Topology-aware post-processing for 3D segmentation
"""

import numpy as np
from scipy import ndimage
from scipy.ndimage import binary_closing, binary_opening, binary_fill_holes
from scipy.ndimage import distance_transform_edt, label
from typing import Tuple, Optional
import warnings

# NOTE(review): this runs at import time and silences every warning for the
# whole process (including callers' code), not just this module. Consider
# scoping it with ``warnings.catch_warnings()`` around the noisy calls.
warnings.simplefilter('ignore')


def remove_small_components(mask: np.ndarray, min_size: int = 100) -> np.ndarray:
    """
    Remove connected components smaller than min_size.

    Components are found with ``scipy.ndimage.label`` (its default
    face-connected structure). A component's size is its voxel count,
    independent of the actual nonzero values stored in ``mask``.

    Args:
        mask: Binary 3D mask (any nonzero voxel is foreground)
        min_size: Minimum component size in voxels; smaller ones are zeroed

    Returns:
        Cleaned copy of ``mask`` (same dtype)
    """
    labeled, num_features = label(mask)
    if num_features == 0:
        return mask.copy()

    # Voxel count per label in a single pass; index 0 is background.
    component_sizes = np.bincount(labeled.ravel(), minlength=num_features + 1)

    # Lookup table: True for labels that must be erased. Indexing it with the
    # label volume replaces a per-component full-volume comparison (O(K*N)).
    remove_lut = component_sizes < min_size
    remove_lut[0] = False  # never touch background

    cleaned = mask.copy()
    cleaned[remove_lut[labeled]] = 0
    return cleaned


def fill_holes_3d(mask: np.ndarray, max_hole_size: int = 1000) -> np.ndarray:
    """
    Fill small enclosed holes in a 3D mask.

    A hole is a connected background region (face connectivity, the SciPy
    default) that does not touch any face of the volume. Holes with at most
    ``max_hole_size`` voxels are set to 1. Larger holes, and background
    connected to the volume boundary, are left untouched.

    Args:
        mask: Binary 3D mask (bool or numeric; any nonzero voxel is foreground)
        max_hole_size: Maximum hole size (in voxels) to fill

    Returns:
        Copy of ``mask`` (same dtype) with small holes filled
    """
    # Explicit boolean mask: ``~`` on a uint8 mask is a bitwise NOT
    # (0 -> 255, 1 -> 254), which would leave no background voxels at all.
    foreground = mask != 0

    # Background voxels that cannot be reached from the volume boundary.
    enclosed = binary_fill_holes(foreground) & ~foreground

    labeled_holes, num_holes = label(enclosed)
    filled = mask.copy()
    if num_holes == 0:
        return filled

    # Size of every hole in one pass, then a per-label "fill?" lookup table.
    hole_sizes = np.bincount(labeled_holes.ravel(), minlength=num_holes + 1)
    fill_lut = hole_sizes <= max_hole_size
    fill_lut[0] = False  # label 0 = everything that is not an enclosed hole

    filled[fill_lut[labeled_holes]] = 1
    return filled


def morphological_smoothing(mask: np.ndarray,
                            closing_iterations: int = 2,
                            opening_iterations: int = 1) -> np.ndarray:
    """
    Apply morphological closing then opening to smooth surfaces.

    Each iteration is one ``binary_closing`` / ``binary_opening`` call with
    SciPy's default (face-connected) structuring element. SciPy treats voxels
    outside the array as background. Without padding, a closing therefore
    strips the foreground lying on the volume faces, and an opening can erase
    foreground along the volume edges. To avoid this, the volume is
    edge-padded before the operations and cropped back afterwards.

    Args:
        mask: Binary 3D mask (any nonzero voxel is foreground)
        closing_iterations: Number of closing passes (fills small gaps)
        opening_iterations: Number of opening passes (removes small protrusions)

    Returns:
        Smoothed mask with the same shape and dtype as ``mask``
    """
    n_close = max(closing_iterations, 0)
    n_open = max(opening_iterations, 0)

    # Each closing/opening is one radius-1 dilation plus one radius-1 erosion,
    # so artifacts from the array boundary move at most this many voxels inward.
    pad = 2 * (n_close + n_open)
    result = np.pad(mask != 0, pad, mode='edge')

    # Closing first (fills small gaps)
    for _ in range(n_close):
        result = binary_closing(result)

    # Opening (removes small protrusions)
    for _ in range(n_open):
        result = binary_opening(result)

    if pad:
        result = result[(slice(pad, -pad),) * result.ndim]
    return result.astype(mask.dtype)


def ensure_surface_continuity(mask: np.ndarray, max_gap: int = 5) -> np.ndarray:
    """
    Bridge small gaps in surfaces using a Euclidean distance transform.

    Every background voxel within ``max_gap / 2`` voxels (Euclidean) of the
    foreground is switched on, so two surfaces separated by a gap of up to
    ``max_gap`` voxels grow into each other. This is a plain dilation, so
    every surface is also thickened by up to ``max_gap / 2`` voxels. A single
    binary closing then smooths the result.

    Args:
        mask: Binary 3D mask (any nonzero voxel is foreground)
        max_gap: Maximum gap size to bridge, in voxels

    Returns:
        Mask with bridged gaps, same shape and dtype as ``mask``
    """
    foreground = mask != 0
    if not foreground.any():
        # Nothing to grow from. The distance transform would also have no
        # zero voxel to measure distances to.
        return np.zeros_like(mask)

    # Distance from each background voxel to the nearest foreground voxel.
    # Boolean NOT on purpose: ``~`` on a uint8 mask is a bitwise NOT and would
    # leave no zero voxels for the transform.
    dist = distance_transform_edt(~foreground)

    # Fill voxels close to surface
    bridged = foreground | (dist <= max_gap / 2)

    # Smooth result. Edge-pad first: SciPy treats out-of-bounds voxels as 0,
    # so an unpadded closing strips foreground lying on the volume faces.
    pad = 2
    bridged = binary_closing(np.pad(bridged, pad, mode='edge'), iterations=1)
    bridged = bridged[(slice(pad, -pad),) * bridged.ndim]

    return bridged.astype(mask.dtype)


def keep_largest_component(mask: np.ndarray, n_largest: int = 1) -> np.ndarray:
    """
    Keep only the N largest connected components.

    Components are found with ``scipy.ndimage.label`` (face connectivity)
    and ranked by voxel count. Ties are broken in favour of the component
    with the higher label.

    Args:
        mask: Binary 3D mask
        n_largest: Number of components to keep; values <= 0 keep nothing

    Returns:
        Mask (same dtype as ``mask``) containing only the largest components
    """
    labeled, num_features = label(mask)
    result = np.zeros_like(mask)
    if num_features == 0 or n_largest <= 0:
        # Without the n_largest guard, ``[-0:]`` below would select *every*
        # component, and negative values would drop the smallest ones.
        return result

    # Voxel count per label (drop index 0 = background).
    component_sizes = np.bincount(labeled.ravel(), minlength=num_features + 1)[1:]
    largest_labels = np.argsort(component_sizes, kind='stable')[-n_largest:] + 1

    result[np.isin(labeled, largest_labels)] = 1
    return result


def surface_aware_threshold(probs: np.ndarray,
                            base_threshold: float = 0.5,
                            gradient_weight: float = 0.1) -> np.ndarray:
    """
    Binarize a probability volume with a threshold that drops near edges.

    The per-voxel cutoff is ``base_threshold - gradient_weight * g``. Here
    ``g`` is the gradient magnitude over axes 0, 1 and 2, rescaled so that
    its maximum is (almost exactly) 1. Voxels on steep probability
    transitions therefore need a lower probability to be kept.

    Args:
        probs: Probability map [0, 1] with at least 3 dimensions
        base_threshold: Cutoff used where the gradient is zero
        gradient_weight: How much the cutoff is lowered at the strongest edge

    Returns:
        Binary mask as uint8
    """
    d0, d1, d2 = np.gradient(probs, axis=(0, 1, 2))
    magnitude = np.sqrt(d0**2 + d1**2 + d2**2)

    # The epsilon keeps a completely flat map (zero gradient) finite.
    normalized = magnitude / (magnitude.max() + 1e-8)

    cutoff = base_threshold - gradient_weight * normalized
    keep = probs >= cutoff
    return keep.astype(np.uint8)


def postprocess_vesuvius(pred: np.ndarray,
                         min_component_size: int = 100,
                         max_hole_size: int = 500,
                         closing_iterations: int = 2,
                         opening_iterations: int = 1,
                         keep_n_largest: Optional[int] = None,
                         bridge_gaps: bool = False,
                         max_gap: int = 3,
                         threshold: float = 0.5) -> np.ndarray:
    """
    Full post-processing pipeline for Vesuvius Challenge.

    The steps run in this order, each skipped when disabled: morphological
    smoothing, small-component removal, hole filling, gap bridging, and
    largest-component selection. The mask is carried as ``bool`` between
    steps and converted to ``uint8`` at the end.

    Args:
        pred: Raw prediction. Floating-point arrays of any precision are
            treated as probabilities and binarized with ``> threshold``.
            Any other dtype is binarized as ``!= 0``.
        min_component_size: Remove components smaller than this (0 disables)
        max_hole_size: Fill holes up to this size (0 disables)
        closing_iterations: Morphological closing iterations
        opening_iterations: Morphological opening iterations
        keep_n_largest: Keep only N largest components (None = keep all)
        bridge_gaps: Whether to bridge small gaps
        max_gap: Maximum gap size to bridge
        threshold: Probability cutoff applied to floating-point ``pred``

    Returns:
        Post-processed binary mask (uint8, values 0/1)
    """
    # Ensure binary. ``np.issubdtype(..., np.floating)`` also covers float16.
    # A float32/float64-only check sent float16 through astype(uint8), which
    # truncated every probability below 1.0 to 0.
    if np.issubdtype(pred.dtype, np.floating):
        result = pred > threshold
    else:
        # Nonzero is foreground; also avoids uint8 wrap-around (256 -> 0).
        result = pred != 0

    # 1. Morphological smoothing
    if closing_iterations > 0 or opening_iterations > 0:
        result = morphological_smoothing(result, closing_iterations, opening_iterations)

    # 2. Remove small components
    if min_component_size > 0:
        result = remove_small_components(result, min_component_size)

    # 3. Fill holes
    if max_hole_size > 0:
        result = fill_holes_3d(result, max_hole_size)

    # 4. Bridge gaps (optional)
    if bridge_gaps:
        result = ensure_surface_continuity(result, max_gap)

    # 5. Keep largest components (optional)
    if keep_n_largest is not None:
        result = keep_largest_component(result, keep_n_largest)

    return result.astype(np.uint8)


def tta_ensemble(predictions: list, mode: str = 'mean') -> np.ndarray:
    """
    Combine a list of test-time-augmentation predictions into one array.

    All predictions are stacked along a new leading axis and reduced over it.

    Args:
        predictions: Prediction arrays of identical shape
        mode: Reduction to apply:
            'mean'   -- voxel-wise average
            'median' -- voxel-wise median
            'vote'   -- majority vote over predictions binarized with
                        ``> 0.5``. A voxel is 1.0 when at least half of the
                        votes are positive (ties count as positive); the
                        result is returned as float32.

    Returns:
        Ensembled prediction with the shape of a single input

    Raises:
        ValueError: If ``mode`` is not one of the values above
    """
    stack = np.stack(predictions, axis=0)

    if mode == 'vote':
        positive = stack > 0.5
        majority = positive.mean(axis=0) >= 0.5
        return majority.astype(np.float32)
    if mode == 'median':
        return np.median(stack, axis=0)
    if mode == 'mean':
        return stack.mean(axis=0)

    raise ValueError(f"Unknown mode: {mode}")


def apply_tta(model_fn, volume: np.ndarray, include_flips: bool = True,
              mode: str = 'mean') -> np.ndarray:
    """
    Apply Test Time Augmentation.

    ``model_fn`` is run on the following inputs:
      * the original volume,
      * its 90/180/270-degree rotations in the plane of axes 1 and 2,
      * optionally, flips along axes 0, 1 and 2.

    Each prediction is un-rotated / un-flipped with the inverse transform
    before ensembling. This assumes ``model_fn`` returns an array with the
    same axis layout as its input.

    Args:
        model_fn: Function that takes volume and returns prediction
        volume: Input 3D volume
        include_flips: Whether to include flip augmentations
        mode: Ensembling mode forwarded to ``tta_ensemble``
            ('mean', 'median', or 'vote')

    Returns:
        TTA-ensembled prediction
    """
    def _predict(view: np.ndarray) -> np.ndarray:
        # np.rot90 / np.flip return negative-stride views, which some
        # frameworks reject (e.g. torch.from_numpy). Hand the model a
        # C-contiguous array; this is a no-op when it already is one.
        return model_fn(np.ascontiguousarray(view))

    # Identity (0-degree) prediction
    predictions = [_predict(volume)]

    # Remaining 90/180/270-degree rotations in the (1, 2) plane
    for k in range(1, 4):
        pred = _predict(np.rot90(volume, k=k, axes=(1, 2)))
        predictions.append(np.rot90(pred, k=-k, axes=(1, 2)))

    if include_flips:
        # One flip per spatial axis; flipping again restores the orientation.
        for axis in range(3):
            pred = _predict(np.flip(volume, axis=axis))
            predictions.append(np.flip(pred, axis=axis))

    return tta_ensemble(predictions, mode=mode)


if __name__ == "__main__":
    # Smoke test on synthetic data. A seeded generator keeps the printed
    # voxel counts identical from run to run.
    print("Testing postprocessing...")
    rng = np.random.default_rng(0)

    # Create noisy test prediction
    pred = rng.random((64, 64, 64)) > 0.6

    # Add some small noise
    noise = rng.random((64, 64, 64)) > 0.95
    pred = pred | noise

    print(f"Before: {pred.sum()} voxels")

    result = postprocess_vesuvius(
        pred,
        min_component_size=50,
        max_hole_size=100,
        closing_iterations=2,
        opening_iterations=1
    )

    print(f"After: {result.sum()} voxels")
    print("Postprocessing test complete!")
