oineus.plot_diagram_gradient

oineus.plot_diagram_gradient(diagram, gradient=None, *, ax=None, dims=None, descent=False, plot_points=True, scatter_only=True, density_threshold=20000, min_persistence=None, top_k_arrows=None, arrow_overlay_threshold=1000, log_x=False, log_y=False, title=None, axis_bounds=None, inf_line_margin=0.05, diagram_color=None, grad_color=None, cmap=None, quiver_style=None, point_style=None, inf_point_style=None, diagonal_style=None, inf_line_style=None, density_style=None, dim_label_fmt='H{dim}', use_density=<object object>)

Plot a gradient vector field on top of a persistence diagram.

For every diagram point (birth, death), an arrow with components (d/dbirth, d/ddeath) is drawn at that point. This is useful for inspecting where an optimization step would move each persistence pair when minimizing or maximizing a topology-aware loss.
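As a rough illustration of the geometry (not oineus's implementation), the sketch below builds the four arrays a matplotlib quiver call needs. Arrow tails sit at the (birth, death) coordinates, and arrow components come from the two gradient columns, negated when descent=True. The helper name quiver_arrays is hypothetical.

```python
import numpy as np

def quiver_arrays(diagram, gradient, descent=False):
    """Hypothetical helper, not part of oineus.

    Returns (x, y, u, v): one arrow per (birth, death) point, with components
    (d/dbirth, d/ddeath) taken from the matching gradient row.
    """
    diagram = np.asarray(diagram, dtype=float)
    gradient = np.asarray(gradient, dtype=float)
    if diagram.ndim != 2 or diagram.shape[1] != 2 or diagram.shape != gradient.shape:
        raise ValueError("diagram and gradient must both have shape (n, 2)")
    # Arrow tails are the diagram points themselves.
    x, y = diagram[:, 0], diagram[:, 1]
    # descent=True plots -grad, the direction a minimizer would move each point.
    sign = -1.0 if descent else 1.0
    u, v = sign * gradient[:, 0], sign * gradient[:, 1]
    return x, y, u, v

dgm = np.array([[0.0, 1.0], [0.5, 2.0]])
grad = np.array([[0.1, -0.2], [0.0, 0.3]])
x, y, u, v = quiver_arrays(dgm, grad, descent=True)
# u -> [-0.1, -0.0], v -> [0.2, -0.3]
```

The resulting arrays can be passed to matplotlib's Axes.quiver(x, y, u, v); how plot_diagram_gradient itself scales the arrows is not covered by this sketch.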

Parameters:
  • diagram – One of torch.Tensor of shape (n, 2), numpy.ndarray of shape (n, 2), native oineus.Diagrams, differentiable oineus.diff.PersistenceDiagrams, or dict[int, ndarray | torch.Tensor].

  • gradient – Same shape/structure as diagram, or None. When None and the diagram is torch-backed, the gradient is pulled from each tensor’s .grad. For non-torch inputs it is required and must mirror the diagram’s per-dimension layout.

  • descent (bool) – If True, plot -grad (the descent direction). The default plots the gradient as-is (steepest increase).

  • plot_points (bool) – If True (default), the underlying diagram is drawn via plot_diagram before the arrows are overlaid.

  • diagram_color – Color of the diagram markers (forwarded to plot_diagram’s color). Single value, dict[int, color], or list; a single color with multi-dim input emits a UserWarning.

  • grad_color – Color of the quiver arrows. Accepts only a single value, which is used for the arrows of every dimension. Defaults to DEFAULT_DIAGRAM_GRADIENT_GRAD_COLOR.

  • cmap – Forwarded to plot_diagram (controls density-mode colormap).

  • dims (Iterable[int] | None)

  • scatter_only (bool)

  • density_threshold (int)

  • min_persistence (float | None)

  • top_k_arrows (int | None)

  • arrow_overlay_threshold (int)

  • log_x (bool)

  • log_y (bool)

  • title (str | None)

  • axis_bounds (Mapping[str, float] | None)

  • inf_line_margin (float)

  • quiver_style (dict | None)

  • point_style (dict | None)

  • inf_point_style (dict | None)

  • diagonal_style (dict | None)

  • inf_line_style (dict | None)

  • density_style (dict | None)

  • dim_label_fmt (str)
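For the dict[int, ...] input form, the requirement that a non-torch gradient mirror the diagram's per-dimension layout can be pictured as the check below. This is a hypothetical helper (pair_by_dim) that assumes dims selects which homology dimensions to keep; the validation oineus actually performs may differ.

```python
import numpy as np

def pair_by_dim(diagram, gradient, dims=None):
    """Hypothetical helper, not part of oineus: align a per-dimension diagram
    dict with its gradient dict and return {dim: (points, grad)}."""
    if gradient is None:
        # Non-torch inputs have no .grad to fall back on.
        raise ValueError("gradient is required for non-torch diagrams")
    if set(diagram) != set(gradient):
        raise ValueError("gradient must have the same dimension keys as diagram")
    # Assumption: dims restricts the plot to the listed homology dimensions.
    if dims is None:
        selected = sorted(diagram)
    else:
        selected = [d for d in sorted(dims) if d in diagram]
    out = {}
    for d in selected:
        pts = np.asarray(diagram[d], dtype=float).reshape(-1, 2)
        g = np.asarray(gradient[d], dtype=float).reshape(-1, 2)
        if pts.shape != g.shape:
            raise ValueError(f"dim {d}: {pts.shape[0]} points but {g.shape[0]} gradient rows")
        out[d] = (pts, g)
    return out

dgm = {0: [[0.0, 1.0], [0.0, 2.0]], 1: [[0.5, 1.5]]}
grad = {0: [[1.0, 0.0], [0.0, 1.0]], 1: [[0.1, 0.2]]}
pairs = pair_by_dim(dgm, grad, dims=[1])  # only the dimension-1 entry is kept
```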

Points with infinite death are skipped: no arrows are drawn for those rows. The quiver_style kwarg accepts any Axes.quiver keyword argument, but arrow color is no longer set through this dict; use grad_color instead.
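A minimal sketch of the two behaviors just described, assuming plain NumPy arrays: rows with infinite death are masked out before arrows are built, and the quiver options are copied and given the arrow color from grad_color. Both helper names are hypothetical, and letting grad_color override any color key left in quiver_style is an assumption, not documented behavior.

```python
import numpy as np

def finite_arrows(points, grad):
    """Hypothetical helper: drop rows whose death coordinate is infinite."""
    points = np.asarray(points, dtype=float)
    grad = np.asarray(grad, dtype=float)
    keep = np.isfinite(points[:, 1])  # True where death is finite
    return points[keep], grad[keep]

def quiver_kwargs(quiver_style, grad_color):
    """Hypothetical helper: copy the user's Axes.quiver options and set the color.

    Assumption: grad_color overrides any stray "color" key in quiver_style.
    """
    kwargs = dict(quiver_style or {})  # copy so the caller's dict is not mutated
    kwargs["color"] = grad_color
    return kwargs

pts = [[0.0, 1.0], [0.0, np.inf], [1.0, 3.0]]
g = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
fin_pts, fin_g = finite_arrows(pts, g)   # the inf-death row is removed
style = {"width": 0.004}
kw = quiver_kwargs(style, "crimson")     # {"width": 0.004, "color": "crimson"}
```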

Arrow overlay filtering. With scatter_only=True (the default), every finite point gets a scatter marker, so without a cap every one of those points also gets an arrow, and beyond roughly 1000 points the plot becomes an unreadable hairball. Two filters keep the overlay readable:

  • min_persistence: drop arrows on points whose persistence death - birth is below min_persistence. This is the closest analog of “draw arrows only on the high-persistence outliers”. Default None (no threshold).

  • top_k_arrows: keep only the K arrows with the largest gradient magnitude |grad|. When unset and the diagram has at least arrow_overlay_threshold finite points (default 1000), it defaults to DEFAULT_GRADIENT_TOP_K_ARROWS (200), so the overlay stays readable out of the box. Pass an explicit value (or np.inf, for no cap) to override.

Filters compose: min_persistence runs first, top_k_arrows second.
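The filter order can be sketched as a boolean mask over the finite points. This is a hypothetical helper (arrow_mask) that assumes |grad| is the Euclidean norm of each (d/dbirth, d/ddeath) row, that infinite-death rows were already removed, and that the size check counts points before min_persistence is applied; the defaults 1000 and 200 are the values quoted above.

```python
import numpy as np

def arrow_mask(points, grad, min_persistence=None, top_k_arrows=None,
               arrow_overlay_threshold=1000, default_top_k=200):
    """Hypothetical helper, not part of oineus: True where an arrow is drawn."""
    points = np.asarray(points, dtype=float)
    grad = np.asarray(grad, dtype=float)
    keep = np.ones(len(points), dtype=bool)
    # Filter 1: persistence threshold (death - birth >= min_persistence survives).
    if min_persistence is not None:
        keep &= (points[:, 1] - points[:, 0]) >= min_persistence
    # Large diagrams get a default cap when top_k_arrows is unset.
    if top_k_arrows is None and len(points) >= arrow_overlay_threshold:
        top_k_arrows = default_top_k
    # Filter 2: top-K by gradient norm among the survivors; np.inf means no cap.
    if top_k_arrows is not None and np.isfinite(top_k_arrows):
        k = int(top_k_arrows)
        idx = np.flatnonzero(keep)
        if k < idx.size:
            norms = np.linalg.norm(grad[idx], axis=1)
            top = idx[np.argsort(norms)[::-1][:k]]  # indices of the k largest norms
            keep = np.zeros_like(keep)
            keep[top] = True
    return keep

pts = np.array([[0.0, 1.0], [0.0, 0.1], [0.0, 2.0], [0.0, 3.0]])
g = np.array([[1.0, 0.0], [10.0, 0.0], [0.0, 5.0], [0.0, 0.5]])
mask = arrow_mask(pts, g, min_persistence=0.5, top_k_arrows=2)
# mask -> [True, False, True, False]: row 1 fails the persistence test even
# though it has the largest gradient, because min_persistence runs first.
```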

The use_density kwarg is deprecated; pass scatter_only instead, with inverted semantics (use_density=True corresponds to scatter_only=False).
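The <object object> default in the signature suggests a sentinel, which lets the function tell "not passed" apart from an explicit True or False. A minimal sketch of such a deprecation shim follows; resolve_scatter_only and _UNSET are hypothetical names, and letting use_density win when both arguments are given is an assumption.

```python
import warnings

_UNSET = object()  # hypothetical sentinel: marks "use_density was not passed"

def resolve_scatter_only(scatter_only=True, use_density=_UNSET):
    """Hypothetical shim mapping the deprecated use_density onto scatter_only."""
    if use_density is _UNSET:
        return scatter_only
    warnings.warn(
        "use_density is deprecated; pass scatter_only instead "
        "(scatter_only = not use_density)",
        DeprecationWarning,
        stacklevel=2,
    )
    # Inverted semantics: density mode on means scatter-only off.
    return not bool(use_density)
```

Checking identity against the sentinel (is _UNSET) rather than against None keeps None available as an ordinary value if a caller ever passes it.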