Source code for oineus.diff.sliced_wasserstein

import operator

import torch
import numpy as np

from .wasserstein_utils import _project_to_diagonal, _split_finite_essential, _match_essential_1d


def _sample_unit_directions(n_directions, device, dtype):
    """Sample n_directions unit vectors on the unit circle (actually half-circle is enough)."""
    # Sample angles in [0, pi) for half-circle
    angles = torch.rand(n_directions, device=device, dtype=dtype) * np.pi
    # Convert to unit vectors
    return torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)


def _validate_n_directions(n_directions):
    try:
        n = operator.index(n_directions)
    except TypeError as exc:
        raise TypeError("n_directions must be an integer") from exc
    if n <= 0:
        raise ValueError("n_directions must be positive")
    return n


def _compute_slice_cost_standard(fin1, fin2, u):
    """
    Compute sliced Wasserstein cost for a single direction (standard version).

    Args:
        fin1: (n1, 2) finite points from dgm1
        fin2: (n2, 2) finite points from dgm2
        u: (2,) unit direction vector

    Returns:
        Scalar cost for this slice
    """
    n1, n2 = len(fin1), len(fin2)

    if n1 == 0 and n2 == 0:
        return torch.tensor(0.0, dtype=u.dtype, device=u.device)

    # Project finite points onto direction u
    proj1 = (fin1 @ u) if n1 > 0 else torch.tensor([], dtype=u.dtype, device=u.device)
    proj2 = (fin2 @ u) if n2 > 0 else torch.tensor([], dtype=u.dtype, device=u.device)

    # Add diagonal projections of opposite diagram
    if n2 > 0:
        diag_proj2 = _project_to_diagonal(fin2)
        proj1_diag = diag_proj2 @ u
    else:
        proj1_diag = torch.tensor([], dtype=u.dtype, device=u.device)

    if n1 > 0:
        diag_proj1 = _project_to_diagonal(fin1)
        proj2_diag = diag_proj1 @ u
    else:
        proj2_diag = torch.tensor([], dtype=u.dtype, device=u.device)

    # Augmented 1D measures
    L1 = torch.cat([proj1, proj1_diag]) if n1 > 0 or n2 > 0 else torch.tensor([], dtype=u.dtype, device=u.device)
    L2 = torch.cat([proj2, proj2_diag]) if n2 > 0 or n1 > 0 else torch.tensor([], dtype=u.dtype, device=u.device)

    if len(L1) == 0:
        return torch.tensor(0.0, dtype=u.dtype, device=u.device)

    # Sort
    L1_sorted = torch.sort(L1)[0]
    L2_sorted = torch.sort(L2)[0]

    # Match by rank and compute cost
    return torch.sum(torch.abs(L1_sorted - L2_sorted))


def _compute_slice_cost_corrected(fin1, fin2, u):
    """
    Compute diagonal-corrected sliced Wasserstein cost for a single direction.

    Args:
        fin1: (n1, 2) finite points from dgm1
        fin2: (n2, 2) finite points from dgm2
        u: (2,) unit direction vector

    Returns:
        Scalar cost for this slice
    """
    n1, n2 = len(fin1), len(fin2)

    # Handle empty diagram cases specially to avoid indexing issues
    if n1 == 0 and n2 == 0:
        return torch.tensor(0.0, dtype=u.dtype, device=u.device)

    if n1 == 0:
        # Cost: distance from each point in fin2 to its own diagonal projection
        diag_proj2 = _project_to_diagonal(fin2).detach()
        proj2 = fin2 @ u
        proj2_diag = diag_proj2 @ u
        return torch.sum(torch.abs(proj2 - proj2_diag))

    if n2 == 0:
        # Cost: distance from each point in fin1 to its own diagonal projection
        diag_proj1 = _project_to_diagonal(fin1).detach()
        proj1 = fin1 @ u
        proj1_diag = diag_proj1 @ u
        return torch.sum(torch.abs(proj1 - proj1_diag))

    # Project finite points onto direction u
    proj1 = (fin1 @ u) if n1 > 0 else torch.tensor([], dtype=u.dtype, device=u.device)
    proj2 = (fin2 @ u) if n2 > 0 else torch.tensor([], dtype=u.dtype, device=u.device)

    # Diagonal projections - DETACHED so they don't contribute to gradients
    if n1 > 0:
        diag_proj1 = _project_to_diagonal(fin1).detach()
        proj1_self_diag = diag_proj1 @ u
        proj2_diag_from_1 = diag_proj1 @ u
    else:
        proj1_self_diag = torch.tensor([], dtype=u.dtype, device=u.device)
        proj2_diag_from_1 = torch.tensor([], dtype=u.dtype, device=u.device)

    if n2 > 0:
        diag_proj2 = _project_to_diagonal(fin2).detach()
        proj2_self_diag = diag_proj2 @ u
        proj1_diag_from_2 = diag_proj2 @ u
    else:
        proj2_self_diag = torch.tensor([], dtype=u.dtype, device=u.device)
        proj1_diag_from_2 = torch.tensor([], dtype=u.dtype, device=u.device)

    # Build augmented lists
    L1 = torch.cat([proj1, proj1_diag_from_2]) if n1 > 0 or n2 > 0 else torch.tensor([], dtype=u.dtype, device=u.device)
    L2 = torch.cat([proj2, proj2_diag_from_1]) if n2 > 0 or n1 > 0 else torch.tensor([], dtype=u.dtype, device=u.device)

    if len(L1) == 0:
        return torch.tensor(0.0, dtype=u.dtype, device=u.device)

    # Sort and track indices
    L1_sorted, L1_indices = torch.sort(L1)
    L2_sorted, L2_indices = torch.sort(L2)

    # Vectorized cost computation using torch.where (vmap-compatible)
    # Determine which points are diagonal projections
    is_diag1 = (L1_indices >= n1)  # L1 points that are diagonal projections from dgm2
    is_diag2 = (L2_indices >= n2)  # L2 points that are diagonal projections from dgm1

    # Case 1: both diagonal -> cost = 0
    # Case 2: L1 is diagonal, L2 is real -> cost = |L2_real - L2_self_diag|
    # Case 3: L1 is real, L2 is diagonal -> cost = |L1_real - L1_self_diag|
    # Case 4: both real -> cost = |L1 - L2|

    # Default: standard matching cost (Case 4)
    costs = torch.abs(L1_sorted - L2_sorted)

    # Case 3: L1 is real (from dgm1), L2 is diagonal (from dgm1)
    # Cost should be |L1_real - L1_self_diag|
    # Use torch.where to avoid boolean indexing (vmap-compatible)
    real_idx1 = torch.clamp(L1_indices, 0, n1 - 1)  # Clamp for safe indexing
    cost_case3 = torch.abs(proj1[real_idx1] - proj1_self_diag[real_idx1])
    costs = torch.where(~is_diag1 & is_diag2, cost_case3, costs)

    # Case 2: L1 is diagonal (from dgm2), L2 is real (from dgm2)
    # Cost should be |L2_real - L2_self_diag|
    real_idx2 = torch.clamp(L2_indices, 0, n2 - 1)  # Clamp for safe indexing
    cost_case2 = torch.abs(proj2[real_idx2] - proj2_self_diag[real_idx2])
    costs = torch.where(is_diag1 & ~is_diag2, cost_case2, costs)

    # Case 1: both diagonal -> cost = 0
    costs = torch.where(is_diag1 & is_diag2, torch.zeros_like(costs), costs)

    return torch.sum(costs)


[docs] def sliced_wasserstein_distance(dgm1, dgm2, n_directions=100, ignore_inf_points=False): """ Sliced Wasserstein distance between two persistence diagrams. This is the standard sliced Wasserstein where diagonal projections participate in gradients. When a point p from dgm1 is matched to diag_proj(q) from dgm2, both p and q receive gradients. Args: dgm1: (N, 2) tensor of persistence diagram points (birth, death) dgm2: (M, 2) tensor of persistence diagram points (birth, death) n_directions: Number of random projection directions ignore_inf_points: If True, only consider finite points Returns: Scalar tensor with the sliced Wasserstein distance """ if len(dgm1) == 0 and len(dgm2) == 0: return torch.tensor(0.0, dtype=dgm1.dtype, device=dgm1.device) # Split into finite and essential fin1, ess1 = _split_finite_essential(dgm1) fin2, ess2 = _split_finite_essential(dgm2) total_cost = torch.tensor(0.0, dtype=dgm1.dtype, device=dgm1.device) # Handle essential points if requested if not ignore_inf_points: ess_names = ["(finite, +inf)", "(finite, -inf)", "(+inf, finite)", "(-inf, finite)"] for coords1, coords2, name in zip(ess1, ess2, ess_names): if len(coords1) != len(coords2): raise ValueError( f"Essential point cardinalities must match. " f"Got {len(coords1)} and {len(coords2)} points with {name}." ) if len(coords1) > 0: total_cost = total_cost + _match_essential_1d(coords1, coords2) # Handle finite points if len(fin1) == 0 and len(fin2) == 0: return total_cost # Sample random directions n_directions = _validate_n_directions(n_directions) directions = _sample_unit_directions(n_directions, dgm1.device, dgm1.dtype) # Vectorized computation over directions using vmap slice_costs = torch.vmap(lambda u: _compute_slice_cost_standard(fin1, fin2, u))(directions) total_cost = total_cost + slice_costs.mean() return total_cost
[docs] def sliced_wasserstein_distance_diag_corrected(dgm1, dgm2, n_directions=100, ignore_inf_points=False): """ Diagonal-corrected sliced Wasserstein distance. This variant makes the sliced distance behave like true Wasserstein at the diagonal. The 1D rank-matching used by the standard sliced distance can pair an off-diagonal point p with the diagonal projection of a *different* point p'; true Wasserstein never does this -- such skew edges can always be straightened to ``p <-> diag(p)`` without raising the cost. The correction re-charges those matches: 1. A point matched to a diagonal slot is charged ``|proj(p) - proj(diag(p))|`` -- its distance to *its own* diagonal projection -- not to whichever point's diagonal stand-in the sort aligned it with. 2. ``diag(p)`` is held constant (detached), so the gradient flows only to p, not to the unrelated point whose stand-in it happened to match. 3. A match between two diagonal stand-ins costs zero. Args: dgm1: (N, 2) tensor of persistence diagram points (birth, death) dgm2: (M, 2) tensor of persistence diagram points (birth, death) n_directions: Number of random projection directions ignore_inf_points: If True, only consider finite points Returns: Scalar tensor with the diagonal-corrected sliced Wasserstein distance """ if len(dgm1) == 0 and len(dgm2) == 0: return torch.tensor(0.0, dtype=dgm1.dtype, device=dgm1.device) # Split into finite and essential fin1, ess1 = _split_finite_essential(dgm1) fin2, ess2 = _split_finite_essential(dgm2) total_cost = torch.tensor(0.0, dtype=dgm1.dtype, device=dgm1.device) # Handle essential points if requested if not ignore_inf_points: ess_names = ["(finite, +inf)", "(finite, -inf)", "(+inf, finite)", "(-inf, finite)"] for coords1, coords2, name in zip(ess1, ess2, ess_names): if len(coords1) != len(coords2): raise ValueError( f"Essential point cardinalities must match. " f"Got {len(coords1)} and {len(coords2)} points with {name}." ) if len(coords1) > 0: total_cost = total_cost + _match_essential_1d(coords1, coords2) # Handle finite points if len(fin1) == 0 and len(fin2) == 0: return total_cost # Sample random directions n_directions = _validate_n_directions(n_directions) directions = _sample_unit_directions(n_directions, dgm1.device, dgm1.dtype) # Vectorized computation over directions using vmap slice_costs = torch.vmap(lambda u: _compute_slice_cost_corrected(fin1, fin2, u))(directions) total_cost = total_cost + slice_costs.mean() return total_cost