"""
Vesuvius Challenge - Metrics
Surface Dice, TopoScore, VOI implementations
"""

import numpy as np
from scipy import ndimage
from typing import Tuple, Dict

def surface_dice(pred: np.ndarray, target: np.ndarray, tolerance: float = 1.0) -> float:
    """
    Surface Dice coefficient.
    Measures overlap between predicted and ground truth surfaces: the
    fraction of surface voxels of either mask that lie within ``tolerance``
    (Euclidean distance, in voxels) of the other mask's surface.

    Args:
        pred: Binary prediction mask (H, W, D); any nonzero value is foreground.
        target: Binary ground truth mask, same shape as ``pred``.
        tolerance: Distance tolerance in voxels.

    Returns:
        Surface Dice score in [0, 1]. 1.0 if neither mask has any surface
        voxels, 0.0 if exactly one of them has none.

    Raises:
        ValueError: If ``pred`` and ``target`` have different shapes.
    """
    # Work on booleans. ``~`` must be a logical NOT for the distance
    # transform, and the surfaces are used as boolean index masks below.
    # On integer input ``~`` is bitwise, and indexing with an int array is
    # fancy indexing, which silently gives wrong results.
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    if pred.shape != target.shape:
        raise ValueError(
            f"pred and target must have the same shape, got {pred.shape} and {target.shape}"
        )

    # Get surfaces (boundary voxels)
    pred_surface = get_surface(pred)
    target_surface = get_surface(target)

    n_pred = int(np.count_nonzero(pred_surface))
    n_target = int(np.count_nonzero(target_surface))
    if n_pred == 0 and n_target == 0:
        return 1.0
    if n_pred == 0 or n_target == 0:
        return 0.0

    # distance_transform_edt gives every nonzero voxel its distance to the
    # nearest zero voxel. Inverting a surface therefore yields, everywhere,
    # the distance to that surface.
    dist_to_pred = ndimage.distance_transform_edt(~pred_surface)
    dist_to_target = ndimage.distance_transform_edt(~target_surface)

    # Target surface voxels close to the predicted surface, and vice versa.
    target_hits = int(np.count_nonzero(dist_to_pred[target_surface] <= tolerance))
    pred_hits = int(np.count_nonzero(dist_to_target[pred_surface] <= tolerance))

    return (target_hits + pred_hits) / (n_pred + n_target)


def get_surface(mask: np.ndarray) -> np.ndarray:
    """Extract surface voxels from a binary mask.

    A voxel is on the surface if it is foreground but does not survive a
    binary erosion with scipy's default structuring element (face
    connectivity). The erosion treats everything outside the array as
    background, so foreground voxels on the array border always count as
    surface.

    Args:
        mask: Binary mask of any dimensionality; any nonzero value is foreground.

    Returns:
        Boolean array of the same shape, True on surface voxels.
    """
    # Coerce first. With an integer mask, ``mask & ~eroded`` keeps the
    # integer dtype, so the result is not a usable boolean mask. With a
    # float mask, it raises TypeError.
    mask = np.asarray(mask, dtype=bool)
    eroded = ndimage.binary_erosion(mask)
    return mask & ~eroded


def dice_coefficient(pred: np.ndarray, target: np.ndarray) -> float:
    """Standard (volumetric) Dice coefficient.

    Args:
        pred: Binary prediction mask; any nonzero value is foreground.
        target: Binary ground truth mask; any nonzero value is foreground.

    Returns:
        ``2 * |pred & target| / (|pred| + |target|)`` as a Python float,
        or 1.0 when both masks are empty.
    """
    # Coerce to bool. ``&`` on float arrays raises TypeError, and on
    # non-0/1 integer labels it is bitwise (e.g. 2 & 1 == 0).
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)

    total = int(np.count_nonzero(pred)) + int(np.count_nonzero(target))
    if total == 0:
        return 1.0
    intersection = int(np.count_nonzero(pred & target))
    return 2.0 * intersection / total


def topo_score(pred: np.ndarray, target: np.ndarray) -> float:
    """
    Topological score.
    Compares the coarse topology of prediction and ground truth using three
    counts: foreground connected components, background connected
    components, and the Euler characteristic.

    Each count difference d is mapped to (0, 1] via 1 / (1 + d). The Euler
    difference is scaled by 0.1 first. The three terms are averaged.

    Args:
        pred: Binary prediction mask; any nonzero value is foreground.
        target: Binary ground truth mask; any nonzero value is foreground.

    Returns:
        TopoScore in (0, 1]; 1.0 when all three counts match.
    """
    # Coerce to bool so that ``~`` below is a logical NOT. On integer masks
    # ``~`` is bitwise, so 0 and 1 both become nonzero and the
    # "background" would cover the whole array.
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)

    # Foreground connected components (ndimage.label's default structuring
    # element uses face connectivity).
    _, pred_num = ndimage.label(pred)
    _, target_num = ndimage.label(target)
    component_score = 1.0 / (1.0 + abs(pred_num - target_num))

    # Background connected components: enclosed cavities plus the exterior
    # region(s). Comparing these counts approximates comparing hole counts.
    _, pred_bg_num = ndimage.label(~pred)
    _, target_bg_num = ndimage.label(~target)
    hole_score = 1.0 / (1.0 + abs(pred_bg_num - target_bg_num))

    # Euler characteristic
    euler_diff = abs(euler_number_3d(pred) - euler_number_3d(target))
    euler_score = 1.0 / (1.0 + euler_diff * 0.1)

    return float((component_score + hole_score + euler_score) / 3)


def euler_number_3d(mask: np.ndarray) -> int:
    """
    Compute the Euler characteristic of a binary mask.

    Foreground voxels are treated as the vertices of a cubical complex:
    - two face-adjacent foreground voxels span an edge;
    - four foreground voxels forming an axis-aligned 2x2 square span a face;
    - eight forming a 2x2x2 block span a cube.

    For 3D input the Euler characteristic is then

        chi = V - E + F - C

    (vertices - edges + faces - cubes). This corresponds to face
    (6-)connectivity for the foreground, the same connectivity
    ``ndimage.label`` uses by default.

    Despite the name, the same alternating sum works for any number of
    dimensions (e.g. V - E + F for 2D input).

    Examples: a single voxel or a solid block gives 1, a hollow 3x3x3 shell
    gives 2, and a one-voxel-thick ring gives 0.

    Args:
        mask: Binary mask; any nonzero value is foreground.

    Returns:
        Euler characteristic as a Python int.
    """
    mask = np.asarray(mask, dtype=bool)
    ndim = mask.ndim

    euler = 0
    # Each bit pattern selects a subset of axes. The cells spanning exactly
    # those axes are the positions where all 2**k corners are foreground.
    # ANDing the mask with itself shifted by one along each chosen axis
    # builds that per-cell test.
    for subset in range(1 << ndim):
        cells = mask
        n_axes = 0
        for axis in range(ndim):
            if (subset >> axis) & 1:
                n_axes += 1
                lo = [slice(None)] * ndim
                hi = [slice(None)] * ndim
                lo[axis] = slice(None, -1)
                hi[axis] = slice(1, None)
                cells = cells[tuple(lo)] & cells[tuple(hi)]
        count = int(np.count_nonzero(cells))
        # Cells of odd dimension are subtracted, even ones added.
        euler += -count if n_axes % 2 else count
    return euler


def variation_of_information(pred: np.ndarray, target: np.ndarray) -> float:
    """
    Variation of Information (VOI).
    Information-theoretic measure of segmentation quality.
    Lower is better, normalized to [0, 1] where 0 is perfect.

    VOI = H(pred) + H(target) - 2 * I(pred; target), divided by
    H(pred) + H(target). Entropies are in bits. The score depends only on
    how voxels are grouped, not on the label values themselves.

    Args:
        pred: Labeled prediction (any label values, any shape).
        target: Labeled ground truth with the same number of elements.

    Returns:
        VOI score [0, 1] (lower is better). 0.0 for empty input, or when both
        labelings have a single label.

    Raises:
        ValueError: If ``pred`` and ``target`` have different sizes.
    """
    pred_flat = np.asarray(pred).ravel()
    target_flat = np.asarray(target).ravel()
    if pred_flat.size != target_flat.size:
        raise ValueError(
            "pred and target must have the same number of elements, "
            f"got {pred_flat.size} and {target_flat.size}"
        )
    n = pred_flat.size
    if n == 0:
        return 0.0

    # Map arbitrary label values to dense indices 0..k-1. Every index occurs
    # at least once, so none of the probabilities below are zero.
    _, pred_idx = np.unique(pred_flat, return_inverse=True)
    _, target_idx = np.unique(target_flat, return_inverse=True)
    pred_idx = pred_idx.ravel().astype(np.int64)
    target_idx = target_idx.ravel().astype(np.int64)

    h_pred = _entropy_bits(np.bincount(pred_idx) / n)
    h_target = _entropy_bits(np.bincount(target_idx) / n)

    # Joint distribution over the (pred, target) label pairs that actually
    # occur. This avoids materializing a dense k_pred x k_target table,
    # which can be huge for instance labelings.
    n_target_labels = int(target_idx.max()) + 1
    _, joint_counts = np.unique(
        pred_idx * n_target_labels + target_idx, return_counts=True
    )
    h_joint = _entropy_bits(joint_counts / n)

    # Mutual information, then VOI = H(pred) + H(target) - 2*MI
    mi = h_pred + h_target - h_joint
    voi = h_pred + h_target - 2 * mi

    # Normalize to [0, 1]. Clamp to absorb floating-point round-off
    # (e.g. -1e-16 for identical labelings).
    max_voi = h_pred + h_target
    if max_voi <= 0:
        return 0.0
    return float(min(max(voi / max_voi, 0.0), 1.0))


def _entropy_bits(probs: np.ndarray) -> float:
    """Shannon entropy in bits of a probability vector with strictly positive entries."""
    return float(-np.sum(probs * np.log2(probs)))


def vesuvius_score(pred: np.ndarray, target: np.ndarray,
                   weights: Tuple[float, float, float] = (0.5, 0.3, 0.2)) -> Dict[str, float]:
    """
    Combined Vesuvius Challenge score.

    Args:
        pred: Prediction mask; any nonzero value is foreground.
        target: Ground truth mask, same shape as ``pred``.
        weights: (surface_dice_weight, topo_weight, voi_weight)

    Returns:
        Dict with individual scores and combined score:
        'surface_dice', 'topo_score', 'voi' (lower is better),
        'voi_score' (= 1 - voi) and 'combined' (weighted sum of
        surface_dice, topo_score and voi_score).

    Raises:
        ValueError: If ``pred`` and ``target`` have different shapes.
    """
    # Binarize once so every metric sees the same masks. Float masks would
    # otherwise crash in the boolean ops of the individual metrics, while
    # ``astype(int)`` for VOI would truncate e.g. 0.7 to 0.
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    if pred.shape != target.shape:
        raise ValueError(
            f"pred and target must have the same shape, got {pred.shape} and {target.shape}"
        )

    sd = float(surface_dice(pred, target))
    ts = float(topo_score(pred, target))
    # VOI treats the binary masks as 0/1 label maps.
    voi = float(variation_of_information(pred.astype(int), target.astype(int)))

    # VOI is lower-is-better, convert to higher-is-better
    voi_score = 1 - voi

    combined = weights[0] * sd + weights[1] * ts + weights[2] * voi_score

    return {
        'surface_dice': sd,
        'topo_score': ts,
        'voi': voi,
        'voi_score': voi_score,
        'combined': float(combined)
    }


def rmsle(pred: np.ndarray, target: np.ndarray) -> float:
    """
    Root Mean Squared Logarithmic Error.

    Computes sqrt(mean((log(1 + pred) - log(1 + target)) ** 2)). Negative
    entries in either input are clipped to 0 first, so they contribute as
    if they were 0.
    """
    log_pred = np.log1p(np.clip(pred, 0, None))
    log_target = np.log1p(np.clip(target, 0, None))
    squared_log_error = np.square(log_pred - log_target)
    return np.sqrt(np.mean(squared_log_error))


if __name__ == "__main__":
    # Smoke test on random noise. A seeded generator makes repeated runs
    # print the same numbers, so regressions are easy to spot.
    rng = np.random.default_rng(0)
    pred = rng.random((64, 64, 64)) > 0.5
    target = rng.random((64, 64, 64)) > 0.5

    scores = vesuvius_score(pred, target)
    print("Test scores:")
    for k, v in scores.items():
        print(f"  {k}: {v:.4f}")
