"""Differentiable persistence diagrams.
Two `gradient_method` options:
- "dgm-loss": gradient flows only through the (birth_simplex,
death_simplex) pair defining each diagram point. The
forward path reduces one decomposition (hom or coh,
chosen by `dualize`) with the cheapest recipe -
parallel + clearing, R only - and the backward is a
scatter into the fil_values gradient.
- "crit-sets": gradient propagates through the full critical set of
each moved pair, with conflicts resolved by the
selected `conflict_strategy`. The forward reduces one
side with parallel + clearing + V + restore_ELZ in
`dims_to_backprop` so the backward can recover U on
demand without re-reducing. The other decomposition
is reduced lazily in backward only if
`determine_needed_matrices` says we need it.
Phase 1 of the refactor only supports `include_inf_points=False`.
Phase 2 will expose a split index-diagram from C++ for the
inf-points case so it can be handled without per-entry validity
masks.
"""
from .. import _oineus
import torch
import numpy as np
from .diff_filtration import DiffFiltration
from ._reduction_policy import default_dualize_for_filtration
from .top_optimizer import TopologyOptimizer
# User-facing (lower-case) names of the conflict-resolution strategies
# accepted by `_resolve_strategy`.
_STRATEGY_MAP = dict(
    avg=_oineus.ConflictStrategy.Avg,
    max=_oineus.ConflictStrategy.Max,
    sum=_oineus.ConflictStrategy.Sum,
    fca=_oineus.ConflictStrategy.FixCritAvg,
)

# User-facing (lower-case) names of the U-matrix strategies accepted by
# `_resolve_u_strategy`.
_U_STRATEGY_MAP = dict(
    auto=_oineus.UStrategy.Auto,
    row_partial=_oineus.UStrategy.RowPartial,
    legacy_in_band=_oineus.UStrategy.LegacyInBand,
)
def _resolve_strategy(strategy):
    """Normalize `strategy` to an `_oineus.ConflictStrategy` value.

    Args:
        strategy: either an `_oineus.ConflictStrategy` (returned
            unchanged) or a name that, once lower-cased, is a key of
            `_STRATEGY_MAP`.

    Returns:
        The corresponding `_oineus.ConflictStrategy` value.

    Raises:
        ValueError: if `strategy` is not a known name. Objects without
            a `.lower()` method are reported the same way.
    """
    if isinstance(strategy, _oineus.ConflictStrategy):
        return strategy
    try:
        return _STRATEGY_MAP[strategy.lower()]
    except (AttributeError, KeyError):
        # The failed lookup is an implementation detail; `from None`
        # keeps it out of the traceback the caller sees.
        raise ValueError(
            f"unknown conflict_strategy {strategy!r}; expected one of "
            f"{sorted(_STRATEGY_MAP)} or an _oineus.ConflictStrategy"
        ) from None
def _resolve_u_strategy(u_strategy):
    """Normalize `u_strategy` to an `_oineus.UStrategy` value.

    Args:
        u_strategy: `None` (meaning `_oineus.UStrategy.Auto`), an
            `_oineus.UStrategy` (returned unchanged), or a name that,
            once lower-cased, is a key of `_U_STRATEGY_MAP`.

    Returns:
        The corresponding `_oineus.UStrategy` value.

    Raises:
        ValueError: if `u_strategy` is not a known name. Objects
            without a `.lower()` method are reported the same way.
    """
    if u_strategy is None:
        return _oineus.UStrategy.Auto
    if isinstance(u_strategy, _oineus.UStrategy):
        return u_strategy
    try:
        return _U_STRATEGY_MAP[u_strategy.lower()]
    except (AttributeError, KeyError):
        # The failed lookup is an implementation detail; `from None`
        # keeps it out of the traceback the caller sees.
        raise ValueError(
            f"unknown u_strategy {u_strategy!r}; expected one of "
            f"{sorted(_U_STRATEGY_MAP)} or an _oineus.UStrategy"
        ) from None
def determine_needed_matrices(dgm_grad, negate: bool):
    """Decide which matrices the crit-sets backward has to have.

    `dgm_grad` is the gradient with respect to one diagram; column 0
    holds the birth gradients and column 1 the death gradients. Since
    the loss is minimized, a positive entry asks for the value to go
    down and a negative entry asks for it to go up. The requested
    moves translate to matrices as follows (paper):

        decrease_birth -> U_coh      increase_birth -> V_coh
        increase_death -> U_hom      decrease_death -> V_hom

    With `negate=True` value-direction and filtration-direction are
    opposite, so "up" and "down" trade places in the table.

    The signs are collected with a single `torch.aminmax` over the
    rows, which keeps the check cheap and usable on GPU tensors.

    Returns:
        Tuple of four bools `(v_hom, u_hom, v_coh, u_coh)`. All four
        are False for an empty gradient.
    """
    if dgm_grad.numel() == 0:
        return (False,) * 4

    col_min, col_max = torch.aminmax(dgm_grad, dim=0)
    # After the concat the flags read: birth has a positive entry,
    # death has a positive entry, birth has a negative entry, death
    # has a negative entry.
    birth_pos, death_pos, birth_neg, death_neg = torch.cat(
        [col_max > 0, col_min < 0]).tolist()

    if negate:
        v_coh, u_hom, u_coh, v_hom = (birth_pos, death_pos,
                                      birth_neg, death_neg)
    else:
        u_coh, v_hom, v_coh, u_hom = (birth_pos, death_pos,
                                      birth_neg, death_neg)
    return v_hom, u_hom, v_coh, u_coh
def _select_u_moves(idx_t, cur_t, tgt_t, move_t, *, side, negate):
    """Collect the rows and bounds for the U-side walker on `side`.

    Only moves that go in the direction that needs U are kept:

        hom side, non-negate: increase_death (tgt > cur)
        coh side, non-negate: decrease_birth (tgt < cur)

    and `negate` reverses the comparison on both sides. Any `side`
    other than "hom" is handled as the coh side.

    Args:
        idx_t: filtration indices of the candidate simplices.
        cur_t: their current values.
        tgt_t: their target values.
        move_t: boolean mask of the candidates that move at all.
        side: "hom" or "coh".
        negate: whether the filtration is negated.

    Returns:
        `(rows, bounds)` as plain Python lists (indices converted
        through `np.uintp`, bounds through float64), ready for
        ensure_has_u_hom / ensure_has_u_coh. Both lists are empty when
        no move qualifies.
    """
    # "Target above current" is the U direction exactly when one of
    # (side is hom, negate is set) holds, but not both.
    if side == "hom":
        wants_increase = not negate
    else:
        wants_increase = bool(negate)

    direction_mask = (tgt_t > cur_t) if wants_increase else (tgt_t < cur_t)
    chosen = move_t & direction_mask
    if not chosen.any().item():
        return [], []

    picked_rows = idx_t[chosen].detach().cpu().numpy()
    picked_bounds = tgt_t[chosen].detach().cpu().to(torch.float64).numpy()
    return picked_rows.astype(np.uintp).tolist(), picked_bounds.tolist()
def _forward_reduce(top_opt, dualize):
    """Make sure one side is reduced and hand back its decomposition.

    The reduction recipe is not chosen here; per the optimizer's
    design it was fixed when `top_opt` was constructed.

    Args:
        top_opt: the optimizer that owns both decompositions.
        dualize: truthy selects the cohomology side, falsy the
            homology side.

    Returns:
        Whatever `cohomology_decomposition_ref()` /
        `homology_decomposition_ref()` returns for the selected side.
    """
    if not dualize:
        top_opt.ensure_hom_reduced()
        return top_opt.homology_decomposition_ref()
    top_opt.ensure_coh_reduced()
    return top_opt.cohomology_decomposition_ref()
def _backward_dgm_loss(ctx, grad_output):
    """Backward pass for the "dgm-loss" gradient method.

    Each diagram point passes its birth gradient to the filtration
    value at its birth index and its death gradient to the value at
    its death index. Contributions that land on the same index are
    accumulated.

    Args:
        ctx: autograd context whose `saved_tensors` are
            `(index_dgm, fil_values)`.
        grad_output: gradient w.r.t. the diagram; column 0 is birth,
            column 1 is death.

    Returns:
        An 8-tuple: the gradient for `fil_values` followed by seven
        `None`s for the remaining forward arguments.
    """
    pairs, values = ctx.saved_tensors
    grad_values = torch.zeros_like(values)
    for column in (0, 1):  # 0 = birth, 1 = death
        grad_values.scatter_add_(0, pairs[:, column], grad_output[:, column])
    return (grad_values, *([None] * 7))
def _backward_crit_sets(ctx, grad_output):
    """Backward pass for the "crit-sets" gradient method.

    Steps:
      1. Turn the diagram gradient into target values,
         `target = current - step_size * grad`, separately for births
         and deaths, and note which entries actually move.
      2. Ask `determine_needed_matrices` which matrices are required,
         reduce the corresponding decomposition(s), and request the
         U rows for the moves selected by `_select_u_moves`.
      3. Pass all moved (index, target) pairs to
         `top_opt.crit_sets_apply` with the configured conflict
         strategy.
      4. Write `fil_values[idx] - target` into the gradient at the
         returned indices: accumulated for the Sum strategy, plainly
         assigned otherwise.

    Args:
        ctx: autograd context carrying `saved_tensors`
            `(index_dgm, fil_values)` plus `top_opt`, `negate`,
            `step_size`, `dim` and `strategy`.
        grad_output: gradient w.r.t. the diagram; column 0 is birth,
            column 1 is death.

    Returns:
        An 8-tuple: the gradient for `fil_values` followed by seven
        `None`s for the remaining forward arguments.
    """
    pairs, values = ctx.saved_tensors
    optimizer = ctx.top_opt
    negate = ctx.negate
    no_grads = (None,) * 7  # one per non-tensor forward argument

    birth_idx = pairs[:, 0]
    death_idx = pairs[:, 1]
    birth_now = values[birth_idx]
    death_now = values[death_idx]
    birth_goal = birth_now - ctx.step_size * grad_output[:, 0]
    death_goal = death_now - ctx.step_size * grad_output[:, 1]
    birth_moves = birth_goal != birth_now
    death_moves = death_goal != death_now

    need_v_hom, need_u_hom, need_v_coh, need_u_coh = \
        determine_needed_matrices(grad_output, negate)

    # Reduce lazily: a side is touched only if one of its matrices is
    # needed.
    if need_v_hom or need_u_hom:
        optimizer.ensure_hom_reduced()
    if need_v_coh or need_u_coh:
        optimizer.ensure_coh_reduced()

    if need_u_hom:
        u_rows, u_bounds = _select_u_moves(
            death_idx, death_now, death_goal, death_moves,
            side="hom", negate=negate)
        optimizer.ensure_has_u_hom(ctx.dim, u_rows, u_bounds)
    if need_u_coh:
        u_rows, u_bounds = _select_u_moves(
            birth_idx, birth_now, birth_goal, birth_moves,
            side="coh", negate=negate)
        optimizer.ensure_has_u_coh(ctx.dim, u_rows, u_bounds)

    # Births first, then deaths; only entries that really move.
    moved_idx = torch.cat([birth_idx[birth_moves], death_idx[death_moves]])
    moved_goal = torch.cat([birth_goal[birth_moves], death_goal[death_moves]])

    grad_vals = torch.zeros_like(values)
    if moved_idx.numel() == 0:
        return (grad_vals,) + no_grads

    moved_idx_list = moved_idx.detach().cpu().numpy().astype(np.uintp).tolist()
    moved_goal_list = (
        moved_goal.detach().cpu().to(torch.float64).numpy().tolist())

    # crit_sets_apply handles the dispatch reduction (ensure_hom_reduced)
    # internally and raises if the optimizer is dgm-loss only.
    applied = optimizer.crit_sets_apply(
        moved_idx_list, moved_goal_list, ctx.strategy)
    new_idx_np = np.asarray(applied.indices_array(), copy=True)
    new_goal_np = np.asarray(applied.values_array(), copy=True)
    if new_idx_np.size == 0:
        return (grad_vals,) + no_grads

    new_idx = torch.from_numpy(
        new_idx_np.astype(np.int64)).to(device=values.device)
    new_goal = torch.from_numpy(new_goal_np).to(
        dtype=values.dtype, device=values.device)

    delta = values[new_idx] - new_goal
    if ctx.strategy == _oineus.ConflictStrategy.Sum:
        # Sum: repeated indices add up.
        grad_vals.scatter_add_(0, new_idx, delta)
    else:
        grad_vals[new_idx] = delta
    return (grad_vals,) + no_grads
class _PDHelper(torch.autograd.Function):
    """Autograd bridge for one (dim, gradient_method) pair.

    Forward gathers `fil_values` at the birth/death indices of the
    diagram in dimension `dim`. Backward hands the incoming gradient
    to the dgm-loss scatter or to the crit-sets routine, depending on
    the `gradient_method` stored on the context.
    """

    @staticmethod
    def forward(ctx, fil_values, top_opt, nondiff_dgms, dim,
                gradient_method, step_size, strategy, negate):
        """Build the diagram tensor for `dim` and stash what backward needs.

        Returns a tensor of shape (N, 2) with `fil_values`' dtype and
        device; an empty index diagram yields a (0, 2) tensor.
        """
        pairs_np = nondiff_dgms.index_diagram_in_dimension(dim, as_numpy=True)
        index_dgm = torch.from_numpy(
            pairs_np.astype(np.int64)).to(fil_values.device)

        if index_dgm.numel() > 0:
            diagram = fil_values[index_dgm]
        else:
            # No points in this dimension: return an explicitly shaped
            # empty diagram.
            diagram = torch.zeros((0, 2), dtype=fil_values.dtype,
                                  device=fil_values.device)

        ctx.save_for_backward(index_dgm, fil_values)
        # Non-tensor state is kept as plain attributes on the context.
        ctx.top_opt = top_opt
        ctx.dim = dim
        ctx.gradient_method = gradient_method
        ctx.step_size = step_size
        ctx.strategy = strategy
        ctx.negate = negate
        return diagram

    @staticmethod
    def backward(ctx, grad_output):
        """Dispatch on `ctx.gradient_method`.

        Raises:
            RuntimeError: if the stored method is neither "dgm-loss"
                nor "crit-sets".
        """
        method = ctx.gradient_method
        if method == "dgm-loss":
            return _backward_dgm_loss(ctx, grad_output)
        if method == "crit-sets":
            return _backward_crit_sets(ctx, grad_output)
        raise RuntimeError(f"Unknown gradient method: {ctx.gradient_method}")
class PersistenceDiagrams:
    """Container for differentiable persistence diagrams in all dimensions.

    Behaves like a read-only mapping from homology dimension to a
    tensor of shape (N, 2) holding (birth, death) values. Gradients of
    anything computed from those tensors flow back to `fil.values`.

    Usage:
        dgms = persistence_diagram(fil)
        dgm1 = dgms[1]  # H1 diagram as tensor (N, 2)
        loss = (dgm1[:, 1] - dgm1[:, 0]).pow(2).sum()
        loss.backward()
    """

    # The only values of `gradient_method` that `_PDHelper.backward`
    # knows how to dispatch.
    _GRADIENT_METHODS = ("dgm-loss", "crit-sets")

    def __init__(self, fil: "DiffFiltration", *, dualize, include_inf_points,
                 gradient_method, step_size, conflict_strategy,
                 n_threads, u_strategy, dims_to_backprop):
        """Reduce the filtration and build one diagram tensor per dimension.

        Args:
            fil: filtration whose `values` is a torch.Tensor.
            dualize: True for cohomology, False for homology, None to
                let `default_dualize_for_filtration` decide.
            include_inf_points: must be falsy (Phase 1 limitation).
            gradient_method: "dgm-loss" or "crit-sets".
            step_size: factor applied to the diagram gradient when
                crit-sets computes target values.
            conflict_strategy: name or `_oineus.ConflictStrategy`.
            n_threads: thread count; None or anything below 1 means 1.
            u_strategy: name, `_oineus.UStrategy`, or None.
            dims_to_backprop: dims passed to the optimizer as
                `dims_to_restore_elz`; None means
                `range(fil.max_dim + 1)`.

        Raises:
            ValueError: unknown `gradient_method`, `conflict_strategy`
                or `u_strategy`.
            TypeError: `fil.values` is not a torch.Tensor.
            NotImplementedError: `include_inf_points` is truthy.
        """
        # Reject an unsupported method up front. Otherwise a typo such
        # as "crit_sets" would quietly build a dgm-loss-only optimizer
        # and only fail later, inside backward.
        if gradient_method not in self._GRADIENT_METHODS:
            raise ValueError(
                f"unknown gradient_method {gradient_method!r}; expected "
                f"one of {list(self._GRADIENT_METHODS)}")
        if not isinstance(fil.values, torch.Tensor):
            raise TypeError("fil.values must be a torch.Tensor for "
                            "differentiable diagrams")
        if include_inf_points:
            raise NotImplementedError(
                "include_inf_points=True is deferred to Phase 2 of the "
                "differentiable-diagram refactor (will need a split "
                "index-diagram return type from C++ to avoid per-entry "
                "validity masks). For now, request finite points only.")

        if dualize is None:
            dualize = default_dualize_for_filtration(fil.under_fil)
        if dims_to_backprop is None:
            # Cover all simplex dims so partial-U is admissible
            # everywhere. For H_k pairs the birth simplex has dim k
            # and the death simplex has dim k+1, so we need
            # range(max_dim + 1).
            dims_to_backprop = list(range(fil.max_dim + 1))
        n_threads = max(1, int(n_threads) if n_threads is not None else 1)
        strategy = _resolve_strategy(conflict_strategy)
        u_strategy_enum = _resolve_u_strategy(u_strategy)
        negate = bool(fil.negate)
        with_crit_sets = gradient_method == "crit-sets"

        top_opt = TopologyOptimizer(
            fil,
            with_crit_sets=with_crit_sets,
            dims_to_restore_elz=dims_to_backprop,
            n_threads=n_threads,
            u_strategy=u_strategy_enum,
        )
        decmp = _forward_reduce(top_opt, dualize)
        nondiff_dgms = decmp.diagram(fil.under_fil,
                                     include_inf_points=False)

        self._fil = fil
        self._top_opt = top_opt
        self._dualize = dualize
        self._gradient_method = gradient_method
        # Dimensions 0 .. max_dim - 1 only. NOTE(review): presumably
        # because, with finite points only, the top dimension has no
        # pairs - confirm.
        self._diagrams = {
            dim: _PDHelper.apply(
                fil.values, top_opt, nondiff_dgms, dim,
                gradient_method, step_size, strategy, negate,
            )
            for dim in range(fil.max_dim)
        }

    def __getitem__(self, dim: int) -> torch.Tensor:
        """Return the diagram in dimension `dim`; KeyError if absent."""
        if dim not in self._diagrams:
            raise KeyError(
                f"No diagram for dimension {dim}. "
                f"Available: {list(self._diagrams.keys())}")
        return self._diagrams[dim]

    def __contains__(self, dim: int) -> bool:
        return dim in self._diagrams

    def __len__(self) -> int:
        return len(self._diagrams)

    def __iter__(self):
        # Iterates over dimensions, like a dict iterates over keys.
        return iter(self._diagrams)

    def keys(self):
        return self._diagrams.keys()

    def values(self):
        return self._diagrams.values()

    def items(self):
        return self._diagrams.items()

    def in_dimension(self, dim: int) -> torch.Tensor:
        """Method-call spelling of `self[dim]`."""
        return self[dim]

    @property
    def max_dim(self) -> int:
        """Largest dimension that has a diagram.

        Raises ValueError (from `max`) when there are no diagrams.
        """
        return max(self._diagrams.keys())
def persistence_diagram(
    fil: DiffFiltration,
    dualize=None,
    include_inf_points: bool = False,
    gradient_method: str = "dgm-loss",
    step_size: float = 1.0,
    conflict_strategy="avg",
    n_threads=None,
    u_strategy=None,
    dims_to_backprop=None,
) -> PersistenceDiagrams:
    """Build differentiable persistence diagrams for a DiffFiltration.

    Convenience front end: every argument is forwarded unchanged to
    `PersistenceDiagrams`.

    Args:
        fil: DiffFiltration whose `values` tensor is differentiable.
        dualize: True selects cohomology, False homology. With None
            (the default) the FiltrationKind reduction policy decides;
            currently that is cohomology for VR and homology
            otherwise.
        include_inf_points: only False is supported in Phase 1; True
            raises NotImplementedError.
        gradient_method: "dgm-loss" or "crit-sets".
        step_size: for crit-sets, the diagram gradient is converted to
            a target diagram as
            target = current - step_size * grad_output.
            Ignored for dgm-loss.
        conflict_strategy: "avg", "max", "sum", or "fca", or any
            _oineus.ConflictStrategy. Used only for crit-sets.
        n_threads: parallelism for the forward reduction and the
            partial-U pass in backward.
        u_strategy: "auto" (default), "row_partial", or
            "legacy_in_band", or any _oineus.UStrategy. Used only for
            crit-sets.
        dims_to_backprop: geometric dims in which ELZ is restored
            during the forward reduction; None means all dims of the
            filtration. Used only for crit-sets.

    Returns:
        PersistenceDiagrams: dict-like, dim -> Tensor (N, 2). Gradients
        flow back to fil.values.
    """
    options = dict(
        dualize=dualize,
        include_inf_points=include_inf_points,
        gradient_method=gradient_method,
        step_size=step_size,
        conflict_strategy=conflict_strategy,
        n_threads=n_threads,
        u_strategy=u_strategy,
        dims_to_backprop=dims_to_backprop,
    )
    return PersistenceDiagrams(fil, **options)