oineus.plot_diagram_gradient
oineus.plot_diagram_gradient(diagram, gradient=None, *, ax=None, dims=None, descent=False, plot_points=True, scatter_only=True, density_threshold=20000, min_persistence=None, top_k_arrows=None, arrow_overlay_threshold=1000, log_x=False, log_y=False, title=None, axis_bounds=None, inf_line_margin=0.05, diagram_color=None, grad_color=None, cmap=None, quiver_style=None, point_style=None, inf_point_style=None, diagonal_style=None, inf_line_style=None, density_style=None, dim_label_fmt='H{dim}', use_density=<object object>)
Plot a gradient vector field on top of a persistence diagram.
For every diagram point at (birth, death), an arrow whose components are the gradient entries (d/dbirth, d/ddeath) is drawn at that point. This is useful for inspecting where an optimization step would move each persistence pair when minimizing or maximizing a topology-aware loss.
Parameters:
diagram – One of: torch.Tensor of shape (n, 2), numpy.ndarray of shape (n, 2), native oineus.Diagrams, differentiable oineus.diff.PersistenceDiagrams, or dict[int, ndarray | torch.Tensor].
gradient – Same shape/structure as diagram, or None. When None and the diagram is torch-backed, the gradient is pulled from each tensor's .grad. For non-torch inputs it is required and must mirror the diagram's per-dimension layout.
descent (bool) – If True, plot -grad (the descent direction). The default plots the gradient as-is (the direction of steepest increase).
plot_points (bool) – If True (default), the underlying diagram is drawn via plot_diagram before the arrows are overlaid.
diagram_color – Color of the diagram markers (forwarded to plot_diagram's color). A single value, a dict[int, color], or a list; a single color with multi-dim input emits a UserWarning.
grad_color – Color of the quiver arrows. A single value, used for the arrows of all dimensions. Defaults to DEFAULT_DIAGRAM_GRADIENT_GRAD_COLOR.
cmap – Forwarded to plot_diagram (controls the density-mode colormap).
scatter_only (bool)
density_threshold (int)
min_persistence (float | None)
top_k_arrows (int | None)
arrow_overlay_threshold (int)
log_x (bool)
log_y (bool)
title (str | None)
inf_line_margin (float)
quiver_style (dict | None)
point_style (dict | None)
inf_point_style (dict | None)
diagonal_style (dict | None)
inf_line_style (dict | None)
density_style (dict | None)
dim_label_fmt (str)
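To make the arrow encoding concrete, the sketch below computes what a quiver overlay needs from a single-dimension (n, 2) diagram and a gradient array of the same shape: arrow tails at (birth, death) and components (d/dbirth, d/ddeath), negated when the descent direction is requested. It is a NumPy-only illustration under stated assumptions, not oineus's implementation. arrow_components is a hypothetical helper, and it drops every row with a non-finite coordinate, a slight generalization of the inf-death skipping described for this function.

```python
import numpy as np


def arrow_components(points, grad, descent=False):
    """Hypothetical helper (not part of oineus): quiver tails and components.

    points, grad: array-likes of shape (n, 2); grad[i] = (d/dbirth, d/ddeath).
    Returns (x, y, u, v) for the rows that can carry an arrow.
    """
    points = np.asarray(points, dtype=float)
    grad = np.asarray(grad, dtype=float)
    if points.ndim != 2 or points.shape[1] != 2 or points.shape != grad.shape:
        raise ValueError("points and grad must both have shape (n, 2)")

    # Assumption: any non-finite coordinate (e.g. an inf death) means no arrow.
    keep = np.isfinite(points).all(axis=1)
    x, y = points[keep, 0], points[keep, 1]  # arrow tail = (birth, death)
    u, v = grad[keep, 0], grad[keep, 1]      # arrow = (d/dbirth, d/ddeath)

    if descent:
        # Descent direction: the negated gradient.
        u, v = -u, -v
    return x, y, u, v


pts = np.array([[0.0, 1.0], [0.5, np.inf], [0.2, 0.7]])
g = np.array([[0.1, -0.2], [1.0, 1.0], [0.0, 0.3]])
x, y, u, v = arrow_components(pts, g, descent=True)
print(x, y, u, v)  # the inf-death row is gone; arrows point along -grad
```

If matplotlib is available, the four arrays can be passed straight to Axes.quiver(x, y, u, v), the same call whose keywords quiver_style accepts.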
Inf-death points are skipped: no arrow is drawn for those rows. The quiver_style kwarg accepts any Axes.quiver keyword; color is no longer read from this dict (use grad_color instead).
Arrow overlay filtering. With scatter_only=True (the default), every finite point gets a scatter marker, so without a cap an arrow lands on every one of them and the overlay becomes an unreadable hairball above roughly 1000 points. Two filters keep the overlay readable:
- min_persistence: drop arrows on points whose persistence, death - birth, is below min_persistence. This is the closest analog of "draw arrows only on the high-persistence outliers". Default None (no threshold).
- top_k_arrows: keep only the K arrows with the largest |grad|. When unset and the diagram has at least arrow_overlay_threshold finite points (default 1000), it defaults to DEFAULT_GRADIENT_TOP_K_ARROWS (200), so the overlay stays readable out of the box. Pass an explicit value (or np.inf) to override.
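The two overlay filters can be sketched in NumPy as follows, applying min_persistence first and top_k_arrows second. This is an illustration under assumptions, not the library's code: select_arrow_rows is hypothetical, |grad| is assumed to be the Euclidean norm of each (d/dbirth, d/ddeath) row, the automatic cap is assumed to count finite points before min_persistence is applied, and the defaults 1000 and 200 mirror arrow_overlay_threshold and DEFAULT_GRADIENT_TOP_K_ARROWS as documented on this page.

```python
import numpy as np


def select_arrow_rows(points, grad, min_persistence=None, top_k_arrows=None,
                      arrow_overlay_threshold=1000, default_top_k=200):
    """Hypothetical helper: indices of diagram rows that keep an arrow."""
    points = np.asarray(points, dtype=float)
    grad = np.asarray(grad, dtype=float)

    # Only finite (birth, death) pairs can carry an arrow.
    idx = np.flatnonzero(np.isfinite(points).all(axis=1))

    # Automatic cap (assumption: decided from the finite-point count).
    if top_k_arrows is None and idx.size >= arrow_overlay_threshold:
        top_k_arrows = default_top_k

    # Filter 1: drop points with death - birth < min_persistence.
    if min_persistence is not None:
        persistence = points[idx, 1] - points[idx, 0]
        idx = idx[persistence >= min_persistence]

    # Filter 2: keep the K largest gradients; np.inf means "no cap".
    if (top_k_arrows is not None and np.isfinite(top_k_arrows)
            and idx.size > top_k_arrows):
        norms = np.linalg.norm(grad[idx], axis=1)  # assumed meaning of |grad|
        largest = np.argsort(-norms, kind="stable")[: int(top_k_arrows)]
        idx = np.sort(idx[largest])
    return idx


pts = np.array([[0.0, 1.0], [0.0, 0.1], [0.2, np.inf], [0.5, 3.0]])
g = np.array([[1.0, 0.0], [5.0, 5.0], [9.0, 9.0], [0.0, 2.0]])
print(select_arrow_rows(pts, g, min_persistence=0.5, top_k_arrows=1))
```

The order matters: in the example, the short-lived point at row 1 has the largest gradient, so running top_k_arrows first would pick it and min_persistence would then discard it, leaving no arrow at all. Filtering by persistence first lets the top-K budget go to the persistent points.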
Filters compose: min_persistence runs first, top_k_arrows second.
The use_density kwarg is deprecated; pass scatter_only instead. The semantics are inverted: use_density=True corresponds to scatter_only=False.
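The <object object> default of use_density in the signature suggests the common sentinel pattern, which lets a function tell "not passed" apart from an explicit True or False. A minimal sketch of how such a deprecated, inverted flag could be mapped onto scatter_only, assuming that pattern; resolve_scatter_only, the warning text, and the choice of DeprecationWarning are hypothetical, not taken from oineus.

```python
import warnings

# Sentinel: a unique object, so "argument omitted" is distinguishable
# from every real value, including None, True and False.
_UNSET = object()


def resolve_scatter_only(scatter_only=True, use_density=_UNSET):
    """Hypothetical helper: map the deprecated use_density onto scatter_only."""
    if use_density is _UNSET:
        return scatter_only
    warnings.warn(
        "use_density is deprecated; pass scatter_only instead",
        DeprecationWarning,
        stacklevel=2,
    )
    # Inverted semantics: density mode on means scatter-only off.
    # In this sketch an explicit use_density overrides scatter_only.
    return not bool(use_density)


print(resolve_scatter_only(use_density=True))  # warns, then scatter_only=False
```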