"""Differentiable Wasserstein cost for persistence diagrams.
Single function that calls the new detailed Hera matching once on the full
diagrams (essentials included), then walks the bucketed result and rebuilds
the cost in torch so gradients flow through every matched pair (finite-to-
finite, finite-to-diagonal, and essential-to-essential).
"""
import numpy as np
import torch
from .. import _oineus
from .._dtype import as_real_numpy
# Mapping from essential family attribute name to the index of the finite
# coordinate axis (0 = birth, 1 = death). Used to compute the per-pair
# 1D distance: the infinite axis matches itself with cost 0.
_ESSENTIAL_FINITE_AXIS = (
("inf_death", 0),
("neg_inf_death", 0),
("inf_birth", 1),
("neg_inf_birth", 1),
)
[docs]
def wasserstein_cost(
dgm_a: torch.Tensor,
dgm_b: torch.Tensor,
wasserstein_q: float = 1.0,
wasserstein_delta: float = 0.05,
ignore_inf_points: bool = True,
internal_p: float = float("inf"),
) -> torch.Tensor:
"""Differentiable Wasserstein cost between two persistence diagrams.
Returns ``cost = sum_pair dist(p_a, p_b) ** wasserstein_q`` so that
``Wasserstein_q distance == cost ** (1 / wasserstein_q)``.
Args:
dgm_a: ``(N, 2)`` tensor of (birth, death) points.
dgm_b: ``(M, 2)`` tensor of (birth, death) points.
wasserstein_q: Wasserstein power (default 1.0 → W_1).
wasserstein_delta: Hera relative-error parameter (must be > 0).
ignore_inf_points: If True, drop essential (±inf) points before
matching. If False, every essential family must have equal
cardinalities on both sides; the matching pairs them by
sorted-rank of the finite coordinate.
internal_p: Ground metric in the (birth, death) plane. ``inf``
selects L_∞.
Returns:
Scalar tensor with the cost. Gradients flow through every matched
finite point on both sides (and through the finite coord of every
matched essential). Diagonal projections are detached.
"""
device = dgm_a.device
dtype = dgm_a.dtype
# Convert to numpy and call the bucketed Hera matching once.
dgm_a_np = as_real_numpy(dgm_a.detach().cpu().numpy())
dgm_b_np = as_real_numpy(dgm_b.detach().cpu().numpy())
internal_p_hera = -1.0 if np.isinf(internal_p) else internal_p
matching = _oineus.wasserstein_matching_detailed(
dgm_a_np, dgm_b_np,
q=wasserstein_q,
wasserstein_delta=wasserstein_delta,
internal_p=internal_p_hera,
ignore_inf_points=ignore_inf_points,
)
total = torch.zeros((), dtype=dtype, device=device)
def _pair_dist(pts_a, pts_b):
diff = pts_a - pts_b
if np.isinf(internal_p):
return torch.max(torch.abs(diff), dim=1)[0]
return torch.sum(torch.abs(diff) ** internal_p, dim=1) ** (1.0 / internal_p)
# 1. finite-to-finite
ftf = matching.finite_to_finite
if ftf.shape[0] > 0:
ia = torch.from_numpy(ftf[:, 0]).to(device).long()
ib = torch.from_numpy(ftf[:, 1]).to(device).long()
total = total + torch.sum(_pair_dist(dgm_a[ia], dgm_b[ib]) ** wasserstein_q)
# 2. finite-to-diagonal (a side)
if matching.a_to_diagonal.shape[0] > 0:
ia = torch.from_numpy(matching.a_to_diagonal).to(device).long()
pts = dgm_a[ia]
mid = ((pts[:, 0] + pts[:, 1]) / 2.0).detach()
diag_proj = torch.stack([mid, mid], dim=1)
total = total + torch.sum(_pair_dist(pts, diag_proj) ** wasserstein_q)
# 3. finite-to-diagonal (b side)
if matching.b_to_diagonal.shape[0] > 0:
ib = torch.from_numpy(matching.b_to_diagonal).to(device).long()
pts = dgm_b[ib]
mid = ((pts[:, 0] + pts[:, 1]) / 2.0).detach()
diag_proj = torch.stack([mid, mid], dim=1)
total = total + torch.sum(_pair_dist(pts, diag_proj) ** wasserstein_q)
# 4. essentials, per family — the shared infinite coord contributes 0
# to the ground metric for any internal_p, so cost is just
# |finite_a - finite_b| ** wasserstein_q on the finite axis.
for name, axis in _ESSENTIAL_FINITE_AXIS:
pairs = getattr(matching.essential, name)
if pairs.shape[0] == 0:
continue
ia = torch.from_numpy(pairs[:, 0]).to(device).long()
ib = torch.from_numpy(pairs[:, 1]).to(device).long()
d = torch.abs(dgm_a[ia, axis] - dgm_b[ib, axis])
total = total + torch.sum(d ** wasserstein_q)
return total