oineus.diff.sliced_wasserstein_distance_diag_corrected

oineus.diff.sliced_wasserstein_distance_diag_corrected(dgm1, dgm2, n_directions=100, ignore_inf_points=False)

Diagonal-corrected sliced Wasserstein distance.

This variant makes the sliced distance treat the diagonal the way the true Wasserstein distance does. The 1D rank matching used by the standard sliced distance can pair an off-diagonal point p with the diagonal projection of a different point p’. An optimal Wasserstein matching never needs such an edge: diag(p), the orthogonal projection of p onto the diagonal, is the diagonal point closest to p, so a skew edge p <-> diag(p’) can always be straightened to p <-> diag(p) without raising the cost. The correction re-charges these matches as follows:
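The straightening argument can be checked numerically. The sketch below assumes a Euclidean ground metric on (birth, death) pairs; the helpers `diag` and `dist` are hypothetical illustrations, not part of oineus. It compares two skew edges, where each point is sent to the other point's diagonal stand-in, with their straightened counterparts.

```python
import math

def diag(p):
    # Orthogonal projection of a (birth, death) point onto the diagonal
    # birth == death; this is also the diagonal point closest to p.
    m = (p[0] + p[1]) / 2.0
    return (m, m)

def dist(a, b):
    # Euclidean ground metric (an assumption made for this sketch).
    return math.hypot(a[0] - b[0], a[1] - b[1])

p, q = (0.0, 10.0), (8.0, 9.0)

# Skew edges: each point is charged against the *other* point's stand-in.
skew_cost = dist(p, diag(q)) + dist(q, diag(p))
# Straightened edges: each point is charged against its own projection.
straight_cost = dist(p, diag(p)) + dist(q, diag(q))

print(f"skew: {skew_cost:.3f}, straightened: {straight_cost:.3f}")
```

Here the skew edges cost about 13.63 and the straightened ones about 7.78, consistent with the claim that straightening never raises the cost.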

  1. A point p matched to a diagonal slot is charged |proj(p) - proj(diag(p))|, its projected distance to its own diagonal projection, rather than its distance to whichever point’s diagonal stand-in the sort aligned it with.

  2. diag(p) is held constant (detached), so the gradient flows only to p, not to the unrelated point whose stand-in it happened to match.

  3. A match between two diagonal stand-ins costs zero.
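Rules 1 and 3 can be illustrated for a single projection direction with a pure-Python sketch. It is one reading of the rules above, not oineus's implementation: diagrams are lists of (birth, death) tuples, each side is augmented with the other side's diagonal stand-ins, and the names `rank_matching`, `plain_cost` and `corrected_cost` are hypothetical. Rule 2 concerns gradients and has no counterpart in plain floats; in a PyTorch version one would presumably call `.detach()` on diag(p).

```python
import math

def diag(p):
    # Orthogonal projection of a (birth, death) point onto the diagonal.
    m = (p[0] + p[1]) / 2.0
    return (m, m)

def proj(x, theta):
    # Scalar projection onto the unit direction (cos theta, sin theta).
    return x[0] * math.cos(theta) + x[1] * math.sin(theta)

def rank_matching(dgm1, dgm2, theta):
    # Augment each diagram with the other's diagonal stand-ins, tagging every
    # entry as (point, is_stand_in), then pair the two sides by sorted rank.
    a = [(x, False) for x in dgm1] + [(diag(y), True) for y in dgm2]
    b = [(y, False) for y in dgm2] + [(diag(x), True) for x in dgm1]

    def key(e):
        return proj(e[0], theta)

    return list(zip(sorted(a, key=key), sorted(b, key=key)))

def plain_cost(dgm1, dgm2, theta):
    # Standard sliced cost: every rank-matched pair is charged at face value.
    return sum(abs(proj(x, theta) - proj(y, theta))
               for (x, _), (y, _) in rank_matching(dgm1, dgm2, theta))

def corrected_cost(dgm1, dgm2, theta):
    total = 0.0
    for (x, x_sd), (y, y_sd) in rank_matching(dgm1, dgm2, theta):
        if x_sd and y_sd:
            continue  # rule 3: stand-in vs stand-in costs zero
        if x_sd or y_sd:
            p = y if x_sd else x  # the real point of the pair
            total += abs(proj(p, theta) - proj(diag(p), theta))  # rule 1
        else:
            total += abs(proj(x, theta) - proj(y, theta))  # two real points
    return total

dgm1 = [(0.0, 10.0), (2.0, 3.0)]
dgm2 = [(1.0, 1.5)]
theta = 0.0  # project onto the birth axis

plain = plain_cost(dgm1, dgm2, theta)
corrected = corrected_cost(dgm1, dgm2, theta)
print("plain:", plain, "corrected:", corrected)
```

With theta = 0 the sort pairs (2, 3) with the stand-in diag((0, 10)) = (5, 5), a skew edge charged 3 by the plain cost but only |2 - 2.5| = 0.5 after the correction, and it pairs the two stand-ins diag((1, 1.5)) and diag((2, 3)), charged 1.25 plainly but 0 after the correction. The sketch therefore reports 5.25 for the plain cost and 1.5 for the corrected one.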

Parameters:
  • dgm1 – (N, 2) tensor of persistence diagram points (birth, death)

  • dgm2 – (M, 2) tensor of persistence diagram points (birth, death)

  • n_directions – Number of random projection directions

  • ignore_inf_points – If True, only consider finite points

Returns:

Scalar tensor with the diagonal-corrected sliced Wasserstein distance
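The parameters suggest averaging a per-direction cost over n_directions random directions, optionally after dropping infinite points. The pure-Python analogue below is a sketch under stated assumptions, not the library's code: directions are drawn uniformly from the half-circle [0, π), the result is a plain average with no extra normalization constant, and the `seed` argument and the function name `sliced_wd_diag_corrected_sketch` are hypothetical additions for reproducibility. How oineus samples directions, normalizes the result, or treats infinite points when ignore_inf_points is False is not stated here; this sketch simply does not handle infinite points in that case.

```python
import math
import random

def diag(p):
    # Orthogonal projection of a (birth, death) point onto the diagonal.
    m = (p[0] + p[1]) / 2.0
    return (m, m)

def corrected_cost_1d(dgm1, dgm2, c, s):
    # Diagonal-corrected rank-matching cost along the unit direction (c, s).
    def proj(x):
        return x[0] * c + x[1] * s

    a = sorted([(x, False) for x in dgm1] + [(diag(y), True) for y in dgm2],
               key=lambda e: proj(e[0]))
    b = sorted([(y, False) for y in dgm2] + [(diag(x), True) for x in dgm1],
               key=lambda e: proj(e[0]))
    total = 0.0
    for (x, x_sd), (y, y_sd) in zip(a, b):
        if x_sd and y_sd:
            continue  # two diagonal stand-ins: free
        if x_sd or y_sd:
            p = y if x_sd else x  # real point matched to a diagonal slot
            total += abs(proj(p) - proj(diag(p)))
        else:
            total += abs(proj(x) - proj(y))  # two real points
    return total

def sliced_wd_diag_corrected_sketch(dgm1, dgm2, n_directions=100,
                                    ignore_inf_points=False, seed=0):
    # Hypothetical pure-Python analogue of the documented function.
    if ignore_inf_points:
        dgm1 = [p for p in dgm1 if math.isfinite(p[0]) and math.isfinite(p[1])]
        dgm2 = [p for p in dgm2 if math.isfinite(p[0]) and math.isfinite(p[1])]
    rng = random.Random(seed)  # fixed seed: repeated calls use the same directions
    total = 0.0
    for _ in range(n_directions):
        theta = rng.uniform(0.0, math.pi)  # direction sampling is an assumption
        total += corrected_cost_1d(dgm1, dgm2, math.cos(theta), math.sin(theta))
    return total / n_directions
```

For example, `sliced_wd_diag_corrected_sketch([(0.0, 1.0), (0.5, 2.0)], [(0.1, 1.2)], n_directions=50)` returns a non-negative float. Because the same seed reproduces the same directions, swapping the two diagrams gives the same value, and an essential point such as (0.0, math.inf) has no effect once ignore_inf_points=True.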