"""Matplotlib implementations of the oineus plotting helpers.
The four user-facing entry points are ``plot_diagram``, ``plot_matching``,
``plot_diagram_gradient`` and ``plot_chain``. Backend-agnostic helpers live
in ``_common``; default style dicts in ``_styles``.
"""
from __future__ import annotations
import typing
import warnings
import numpy as np
try:
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection, PolyCollection
_HAS_MATPLOTLIB = True
except Exception:
_HAS_MATPLOTLIB = False
try:
import mpl_scatter_density # noqa: F401
_HAS_MPL_SCATTER_DENSITY = True
except Exception:
_HAS_MPL_SCATTER_DENSITY = False
from . import _common
import warnings
from . import _styles
from ._common import (
_array_diagram,
_build_diagram_arrays,
_coerce_chain,
_coerce_diagram_with_grad,
_compute_plot_limits,
_id_to_grid,
_point_coords_for_edge,
_resolve_color,
_resolve_style,
_shift_for_log,
_split_near_diagonal,
_to_dim_diagrams,
)
from ._styles import (
DEFAULT_CHAIN_COLOR,
DEFAULT_DENSITY_STYLE,
DEFAULT_DENSITY_THRESHOLD,
DEFAULT_DIAGRAM_A_COLOR,
DEFAULT_DIAGRAM_B_COLOR,
DEFAULT_DIAGRAM_GRADIENT_DIAGRAM_COLOR,
DEFAULT_DIAGRAM_GRADIENT_GRAD_COLOR,
DEFAULT_GRADIENT_TOP_K_ARROWS,
DEFAULT_MATCHING_EDGE_COLOR,
DEFAULT_MATCHING_EDGE_QUANTILE,
DEFAULT_POINT_CLOUD_COLOR,
default_chain_edge_style,
default_chain_tetrahedron_style,
default_chain_triangle_style,
default_chain_vertex_style,
default_density_style,
default_diagonal_style,
default_diagonal_projection_a_style,
default_diagonal_projection_b_style,
default_diagram_a_point_style,
default_diagram_b_point_style,
default_diagram_gradient_style,
default_grid_style,
default_inf_line_style,
default_inf_point_style,
default_longest_edge_style,
default_matching_edge_style,
default_point_cloud_style,
default_point_style,
)
# Sentinel meaning "the caller did not pass ``use_density`` at all"; it lets
# the shim tell an omitted kwarg apart from an explicit ``None``/``False``.
_UNSET = object()


def _resolve_scatter_only(scatter_only, use_density):
    """Fold the deprecated ``use_density`` kwarg into ``scatter_only``.

    When ``use_density`` is the ``_UNSET`` sentinel, ``scatter_only`` is
    returned unchanged. Otherwise a ``DeprecationWarning`` is emitted and the
    logical negation of ``use_density`` is returned, because the two flags
    have inverted meaning.
    """
    legacy_kwarg_given = use_density is not _UNSET
    if legacy_kwarg_given:
        # stacklevel=3 skips this shim and the plot_* function that called it.
        warnings.warn(
            "the `use_density` kwarg is deprecated; pass `scatter_only=...` "
            "instead. `use_density=False` -> `scatter_only=True`, "
            "`use_density=True` -> `scatter_only=False`.",
            DeprecationWarning,
            stacklevel=3,
        )
        return not use_density
    return scatter_only
def _is_solo_color(color):
    """Return True when ``color`` is a single-color specification.

    ``None``, dicts and lists are never solo. Strings are always solo. A
    tuple is solo only when it has 3 or 4 entries that are all ``int`` or
    ``float`` (an rgb / rgba triple or quadruple). Any other object is
    treated as solo.
    """
    if color is None or isinstance(color, dict):
        return False
    if isinstance(color, str):
        # Named colors and hex strings.
        return True
    if isinstance(color, list):
        # Lists are interpreted as a per-dimension cycle.
        return False
    if isinstance(color, tuple):
        looks_like_rgba = len(color) in (3, 4) and all(
            isinstance(channel, (int, float)) for channel in color
        )
        return looks_like_rgba
    return True
def _warn_if_solo_color_multi_dim(color, dims_sorted, where):
    """Emit a ``UserWarning`` when one color is applied to several dims.

    ``where`` is the name of the public function, used as the message prefix.
    Nothing happens unless ``color`` is a solo color and ``dims_sorted`` holds
    more than one dimension.
    """
    if not _is_solo_color(color):
        return
    if len(dims_sorted) <= 1:
        return
    message = (
        f"{where}: a single color was passed but the diagram has "
        f"{len(dims_sorted)} dimensions ({list(dims_sorted)}); this "
        "collapses the per-dim cycle to one color. Pass color={dim: c, ...} "
        "or color=[c0, c1, ...] to color dimensions individually."
    )
    # stacklevel=3 skips this helper and the plot_* function that called it.
    warnings.warn(message, UserWarning, stacklevel=3)
# ---------------------------------------------------------------------------
# mpl_scatter_density bridge
# ---------------------------------------------------------------------------
def _require_scatter_density():
    """Raise ``ImportError`` if ``mpl_scatter_density`` could not be imported.

    The message points at the current opt-out (``scatter_only=True``) rather
    than the deprecated ``use_density`` kwarg, so following the advice does
    not itself trigger a ``DeprecationWarning``.

    Raises:
        ImportError: when the module-level ``_HAS_MPL_SCATTER_DENSITY`` flag
            is false.
    """
    if not _HAS_MPL_SCATTER_DENSITY:
        raise ImportError(
            "Density rendering requires the mpl_scatter_density package. "
            "Install it via `pip install mpl-scatter-density`, or disable "
            "density rendering with scatter_only=True / "
            "density_threshold=<larger value>."
        )
def _add_density_artist(ax, x, y, *, color=None, style=None, norm=None):
    """Attach a ScatterDensityArtist to a regular matplotlib Axes.

    Uses ``ScatterDensityArtist`` directly so the caller is not forced to
    construct the Axes with ``projection='scatter_density'``.

    Args:
        ax: Axes the artist is added to.
        x, y: Point coordinates handed to ``ScatterDensityArtist``.
        color: When not ``None``, selects the monochromatic
            fade-to-transparent rendering (used when overlaying multiple
            diagrams in the same plot, e.g. matching); any ``cmap`` entry in
            the style is dropped. When ``None`` the artist uses
            ``style['cmap']``.
        style: Extra keyword arguments for the artist. Defaults to a copy of
            ``DEFAULT_DENSITY_STYLE``. The dict is copied, never mutated.
        norm: Normalization for the density image. An explicit value wins
            over a ``norm`` entry in ``style``; if neither is given,
            ``PowerNorm(gamma=0.5)`` is used.

    Returns:
        The created artist.

    Raises:
        ImportError: if ``mpl_scatter_density`` is not installed.
    """
    _require_scatter_density()
    from mpl_scatter_density import ScatterDensityArtist
    import matplotlib.colors as mcolors

    kwargs = dict(style if style is not None else DEFAULT_DENSITY_STYLE)
    # Route ``norm`` through the kwargs dict so that a ``norm`` key supplied
    # in ``style`` cannot collide with a separately passed keyword argument
    # (which would raise "got multiple values for keyword argument 'norm'").
    if norm is not None:
        kwargs["norm"] = norm
    elif kwargs.get("norm") is None:
        kwargs["norm"] = mcolors.PowerNorm(gamma=0.5)
    if color is not None:
        kwargs["color"] = color
        kwargs.pop("cmap", None)
    artist = ScatterDensityArtist(ax, x, y, **kwargs)
    ax.add_artist(artist)
    return artist
# ---------------------------------------------------------------------------
# plot_diagram
# ---------------------------------------------------------------------------
def plot_diagram(
    diagrams,
    ax=None,
    *,
    color=None,
    cmap=None,
    log_x: bool = False,
    log_y: bool = False,
    title: typing.Optional[str] = None,
    suptitle: typing.Optional[str] = None,
    axis_bounds: typing.Optional[typing.Mapping[str, float]] = None,
    dims: typing.Optional[typing.Iterable[int]] = None,
    max_dimension: typing.Optional[int] = None,
    scatter_only: bool = True,
    density_threshold: int = DEFAULT_DENSITY_THRESHOLD,
    near_diagonal_fraction: float = 0.03,
    density_style: typing.Optional[dict] = None,
    inf_line_margin: float = 0.05,
    point_style: typing.Optional[dict] = None,
    inf_point_style: typing.Optional[dict] = None,
    diagonal_style: typing.Optional[dict] = None,
    inf_line_style: typing.Optional[dict] = None,
    grid_style: typing.Union[dict, bool, None] = None,
    dim_label_fmt: str = "H{dim}",
    use_density=_UNSET,
):
    """Plot one or more persistence diagrams.

    Default rendering is **pure scatter**: every finite point gets its own
    marker, so the visual density of overlapping markers faithfully tracks
    the actual point density. Pass ``scatter_only=False`` to opt into the
    hybrid density-bulk + outlier-scatter mode (suitable for very large
    diagrams where the near-diagonal noise band saturates the scatter).
    In that hybrid mode the bulk near the diagonal is rendered as a 2D
    density (via mpl_scatter_density) while points further than
    ``near_diagonal_fraction`` of the axis range from the diagonal are
    still drawn as crisp scatter, so high-persistence (topologically
    meaningful) features are never aggregated. The hybrid mode is only
    used when the number of finite points is at least ``density_threshold``
    and at least one point lies in the near-diagonal band.

    Color of the diagram points is taken from the top-level ``color``
    argument: pass a single color (``"red"``, ``"#012345"``,
    ``(r, g, b[, a])``) to override every dim, a ``dict`` mapping
    dim -> color, or a ``list`` cycled through dims in sort order. ``None``
    (default) leaves the choice to matplotlib's per-call cycle. Passing a
    single color when the input has more than one dim emits a
    ``UserWarning``.

    When ``color`` is set and density mode is active, the density artist
    uses a single-hue fade-to-transparent ramp keyed off ``color`` rather
    than the cmap-based default; this is what makes 2-diagram overlays of
    the form
    ``plot_diagram(dgm_a, color="C0"); plot_diagram(dgm_b, color="C1")``
    on the same axes read as two distinct distributions.

    ``cmap`` is forwarded to the density-aggregation path when ``color`` is
    None; pass ``None`` to keep the default (``viridis``).

    Style-dict kwargs (``point_style``, ``inf_point_style``,
    ``diagonal_style``, ``inf_line_style``, ``density_style``) default to
    copies of the module-level ``DEFAULT_*_STYLE`` dicts and no longer carry
    a ``color`` / ``c`` key. ``inf_point_style`` controls the markers placed
    on the horizontal inf-line (death = +-inf); the default uses an upward
    triangle so essentials read distinctly from finite points.
    ``grid_style=False`` disables the grid.

    ``axis_bounds`` may contain any of the keys ``"xmin"``, ``"xmax"``,
    ``"ymin"``, ``"ymax"``. When ``"ymax"`` / ``"ymin"`` is given, the
    corresponding inf-line is placed at ``0.9`` times that bound instead of
    ``inf_line_margin`` times the death range beyond the data.

    Points with a finite birth and a ``NaN`` death are drawn on the +inf
    line, together with the ``+inf`` deaths.

    The ``use_density`` kwarg is deprecated; pass ``scatter_only`` instead.
    ``use_density=False`` corresponds to ``scatter_only=True`` and vice
    versa.

    Returns:
        The matplotlib Axes that was drawn on (``ax`` if given, otherwise a
        newly created one).

    Raises:
        ImportError: if matplotlib is missing, or if density mode is
            selected and ``mpl_scatter_density`` is missing.
    """
    if not _HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for plot_diagram.")

    scatter_only = _resolve_scatter_only(scatter_only, use_density)

    point_style = _resolve_style(point_style, default_point_style)
    inf_point_style = _resolve_style(inf_point_style, default_inf_point_style)
    diagonal_style = _resolve_style(diagonal_style, default_diagonal_style)
    inf_line_style = _resolve_style(inf_line_style, default_inf_line_style)
    if grid_style is False:
        resolved_grid_style = None
    else:
        resolved_grid_style = _resolve_style(grid_style, default_grid_style)

    bounds = {} if axis_bounds is None else dict(axis_bounds)

    dgms = _to_dim_diagrams(diagrams, dims=dims, max_dimension=max_dimension)
    dims_sorted = sorted(dgms.keys())
    _warn_if_solo_color_multi_dim(color, dims_sorted, "plot_diagram")

    # Split every dimension into finite points and essential (+inf / -inf
    # death) points, and gather the values that drive the axis layout.
    finite_by_dim = {}
    pos_inf_birth_by_dim = {}
    neg_inf_birth_by_dim = {}
    all_finite_births = []
    all_finite_deaths = []
    all_births_for_limits = []

    for dim in dims_sorted:
        arr = dgms[dim]
        births = arr[:, 0] if arr.shape[0] else np.empty((0,), dtype=float)
        deaths = arr[:, 1] if arr.shape[0] else np.empty((0,), dtype=float)

        finite_mask = np.isfinite(births) & np.isfinite(deaths)
        # NaN deaths are grouped with +inf deaths.
        pos_inf_mask = np.isfinite(births) & (np.isposinf(deaths) | np.isnan(deaths))
        neg_inf_mask = np.isfinite(births) & np.isneginf(deaths)

        finite_births = births[finite_mask]
        finite_deaths = deaths[finite_mask]
        pos_inf_births = births[pos_inf_mask]
        neg_inf_births = births[neg_inf_mask]

        finite_by_dim[dim] = (finite_births, finite_deaths)
        pos_inf_birth_by_dim[dim] = pos_inf_births
        neg_inf_birth_by_dim[dim] = neg_inf_births

        if finite_births.size:
            all_finite_births.append(finite_births)
            all_finite_deaths.append(finite_deaths)
            all_births_for_limits.append(finite_births)
        if pos_inf_births.size:
            all_births_for_limits.append(pos_inf_births)
        if neg_inf_births.size:
            all_births_for_limits.append(neg_inf_births)

    all_finite_births = (
        np.concatenate(all_finite_births) if all_finite_births else np.empty((0,), dtype=float)
    )
    all_finite_deaths = (
        np.concatenate(all_finite_deaths) if all_finite_deaths else np.empty((0,), dtype=float)
    )
    all_births_for_limits = (
        np.concatenate(all_births_for_limits) if all_births_for_limits else np.empty((0,), dtype=float)
    )

    any_pos_inf = any(pos_inf_birth_by_dim[d].size > 0 for d in dims_sorted)
    any_neg_inf = any(neg_inf_birth_by_dim[d].size > 0 for d in dims_sorted)

    # Vertical extent: finite deaths if available, otherwise births,
    # otherwise a fixed [-1, 1] placeholder range.
    if all_finite_deaths.size:
        y_min = float(np.min(all_finite_deaths))
        y_max = float(np.max(all_finite_deaths))
        y_span = y_max - y_min
    elif all_births_for_limits.size:
        y_min = float(np.min(all_births_for_limits))
        y_max = float(np.max(all_births_for_limits))
        y_span = y_max - y_min
    else:
        y_min = -1.0
        y_max = 1.0
        y_span = 2.0
    if y_span <= 0.0:
        # Degenerate range (all values equal): fall back to a non-zero span.
        y_span = max(abs(y_max), abs(y_min), 1.0)

    # Heights at which essential points (and the inf-lines) are drawn.
    if "ymax" in bounds:
        inf_y_pos = 0.9 * float(bounds["ymax"])
    else:
        inf_y_pos = y_max + inf_line_margin * y_span
    if "ymin" in bounds:
        inf_y_neg = 0.9 * float(bounds["ymin"])
    else:
        inf_y_neg = y_min - inf_line_margin * y_span

    if all_births_for_limits.size:
        x_span = float(np.ptp(all_births_for_limits))
    else:
        x_span = 1.0
    if x_span <= 0.0:
        x_span = 1.0

    # Points whose persistence is at most this threshold form the
    # "near-diagonal" bulk that density mode aggregates.
    near_thr = near_diagonal_fraction * max(x_span, y_span)

    near_x_parts = []
    near_y_parts = []
    near_mask_by_dim = {}
    for dim in dims_sorted:
        births, deaths = finite_by_dim[dim]
        if births.size == 0:
            near_mask = np.zeros((0,), dtype=bool)
        else:
            near_mask = np.abs(deaths - births) <= near_thr
        if np.any(near_mask):
            near_x_parts.append(births[near_mask])
            near_y_parts.append(deaths[near_mask])
        near_mask_by_dim[dim] = near_mask

    near_x = np.concatenate(near_x_parts) if near_x_parts else np.empty((0,), dtype=float)
    near_y = np.concatenate(near_y_parts) if near_y_parts else np.empty((0,), dtype=float)

    use_density_plot = (
        not scatter_only
        and all_finite_births.size >= density_threshold
        and near_x.size > 0
    )
    if use_density_plot:
        # Fail before any Axes is created.
        _require_scatter_density()

    if ax is None:
        _, ax = plt.subplots()

    if resolved_grid_style is not None:
        ax.grid(True, **resolved_grid_style)

    y_values_for_shift = all_finite_deaths
    if any_pos_inf:
        y_values_for_shift = np.concatenate([y_values_for_shift, np.asarray([inf_y_pos])])
    if any_neg_inf:
        y_values_for_shift = np.concatenate([y_values_for_shift, np.asarray([inf_y_neg])])

    x_shift = _shift_for_log(all_births_for_limits, log_x)
    y_shift = _shift_for_log(y_values_for_shift, log_y)

    if use_density_plot:
        density_style_resolved = _resolve_style(density_style, default_density_style)
        if cmap is not None:
            density_style_resolved["cmap"] = cmap
        # When the caller passes a solo color (single hue per density layer),
        # use the color-keyed fade-to-transparent rendering so two
        # plot_diagram calls with distinct colors stack readably on the same
        # axes. cmap-keyed rendering only kicks in when no color is given.
        density_color = color if _is_solo_color(color) else None
        _add_density_artist(
            ax,
            near_x + x_shift,
            near_y + y_shift,
            color=density_color,
            style=density_style_resolved,
        )

    # When a per-dim color override is supplied, it wins over point_style's "c".
    base_scatter_kwargs = dict(point_style)
    base_color = base_scatter_kwargs.pop("c", None)
    base_inf_kwargs = dict(inf_point_style)
    base_inf_color = base_inf_kwargs.pop("c", None)

    for dim_idx, dim in enumerate(dims_sorted):
        dim_color = _resolve_color(color, dim, dim_idx)
        effective_c = dim_color if dim_color is not None else base_color
        effective_inf_c = dim_color if dim_color is not None else base_inf_color

        # If neither finite-point nor inf-point styles pin a color, share
        # the matplotlib cycle index across all of this dim's scatter calls
        # so the +inf, -inf, and finite markers match instead of stepping
        # through three consecutive cycle entries.
        if effective_c is None and effective_inf_c is None:
            shared = f"C{dim_idx % 10}"
            effective_c = shared
            effective_inf_c = shared

        scatter_kwargs = dict(base_scatter_kwargs)
        if effective_c is not None:
            scatter_kwargs["c"] = effective_c
        inf_scatter_kwargs = dict(base_inf_kwargs)
        if effective_inf_c is not None:
            inf_scatter_kwargs["c"] = effective_inf_c

        label = dim_label_fmt.format(dim=dim)
        # The legend label goes on the first scatter actually drawn for this
        # dim. Tracking "drawn" rather than "has finite points" matters in
        # density mode, where all finite points of a dim may sit in the
        # near-diagonal band and therefore produce no scatter call.
        label_pending = True

        births, deaths = finite_by_dim[dim]
        if births.size:
            mask = ~near_mask_by_dim[dim] if use_density_plot else np.ones_like(births, dtype=bool)
            if np.any(mask):
                ax.scatter(
                    births[mask] + x_shift,
                    deaths[mask] + y_shift,
                    label=label,
                    **scatter_kwargs,
                )
                label_pending = False

        pos_inf_births = pos_inf_birth_by_dim[dim]
        if pos_inf_births.size:
            ax.scatter(
                pos_inf_births + x_shift,
                np.full_like(pos_inf_births, inf_y_pos + y_shift),
                label=label if label_pending else None,
                **inf_scatter_kwargs,
            )
            label_pending = False

        neg_inf_kwargs = dict(inf_scatter_kwargs)
        # Down-pointing triangle for -inf if the default upward-triangle is
        # in use; a user-supplied marker is left untouched.
        if neg_inf_kwargs.get("marker") == "^":
            neg_inf_kwargs["marker"] = "v"
        neg_inf_births = neg_inf_birth_by_dim[dim]
        if neg_inf_births.size:
            ax.scatter(
                neg_inf_births + x_shift,
                np.full_like(neg_inf_births, inf_y_neg + y_shift),
                label=label if label_pending else None,
                **neg_inf_kwargs,
            )

    if any_pos_inf:
        ax.axhline(inf_y_pos + y_shift, **inf_line_style)
    if any_neg_inf:
        ax.axhline(inf_y_neg + y_shift, **inf_line_style)

    # Diagonal: spans the common range of the plotted x and y values
    # (including the inf-lines), drawn only if there is anything to plot.
    if all_births_for_limits.size or all_finite_deaths.size or any_pos_inf or any_neg_inf:
        if all_births_for_limits.size:
            x_vals = all_births_for_limits + x_shift
        else:
            x_vals = np.asarray([0.0 + x_shift, 1.0 + x_shift])
        if all_finite_deaths.size:
            y_vals = all_finite_deaths + y_shift
        else:
            y_vals = np.asarray([0.0 + y_shift, 1.0 + y_shift])
        lo = min(float(np.min(x_vals)), float(np.min(y_vals)))
        hi = max(float(np.max(x_vals)), float(np.max(y_vals)))
        if any_pos_inf:
            hi = max(hi, float(inf_y_pos + y_shift))
        if any_neg_inf:
            lo = min(lo, float(inf_y_neg + y_shift))
        if hi <= lo:
            hi = lo + 1.0
        ax.plot([lo, hi], [lo, hi], **diagonal_style)

    if log_x:
        ax.set_xscale("log")
    if log_y:
        ax.set_yscale("log")

    # User-supplied bounds are given in data coordinates, so they receive
    # the same log shift as the data.
    x_left = None if "xmin" not in bounds else float(bounds["xmin"]) + x_shift
    x_right = None if "xmax" not in bounds else float(bounds["xmax"]) + x_shift
    y_bottom = None if "ymin" not in bounds else float(bounds["ymin"]) + y_shift
    y_top = None if "ymax" not in bounds else float(bounds["ymax"]) + y_shift
    if x_left is not None or x_right is not None:
        ax.set_xlim(left=x_left, right=x_right)
    if y_bottom is not None or y_top is not None:
        ax.set_ylim(bottom=y_bottom, top=y_top)

    ax.set_xlabel("birth" if x_shift == 0 else f"birth (shifted by +{x_shift:.3g})")
    ax.set_ylabel("death" if y_shift == 0 else f"death (shifted by +{y_shift:.3g})")
    if title is not None:
        ax.set_title(title)
    if suptitle is not None:
        ax.figure.suptitle(suptitle)

    # A legend is only useful when several dimensions share the axes.
    if len(dims_sorted) > 1:
        handles, labels = ax.get_legend_handles_labels()
        if labels:
            # Deduplicate by label (last handle wins), preserving order.
            uniq = dict(zip(labels, handles))
            ax.legend(uniq.values(), uniq.keys(), title="dimension")

    return ax
# ---------------------------------------------------------------------------
# plot_diagram_gradient
# ---------------------------------------------------------------------------
def plot_diagram_gradient(
    diagram,
    gradient=None,
    *,
    ax=None,
    dims: typing.Optional[typing.Iterable[int]] = None,
    descent: bool = False,
    plot_points: bool = True,
    scatter_only: bool = True,
    density_threshold: int = DEFAULT_DENSITY_THRESHOLD,
    min_persistence: typing.Optional[float] = None,
    top_k_arrows: typing.Optional[int] = None,
    arrow_overlay_threshold: int = 1000,
    log_x: bool = False,
    log_y: bool = False,
    title: typing.Optional[str] = None,
    axis_bounds: typing.Optional[typing.Mapping[str, float]] = None,
    inf_line_margin: float = 0.05,
    diagram_color=None,
    grad_color=None,
    cmap=None,
    quiver_style: typing.Optional[dict] = None,
    point_style: typing.Optional[dict] = None,
    inf_point_style: typing.Optional[dict] = None,
    diagonal_style: typing.Optional[dict] = None,
    inf_line_style: typing.Optional[dict] = None,
    density_style: typing.Optional[dict] = None,
    dim_label_fmt: str = "H{dim}",
    use_density=_UNSET,
):
    """Plot a gradient vector field on top of a persistence diagram.

    For every diagram point at ``(birth, death)`` an arrow with components
    ``(d/dbirth, d/ddeath)`` is drawn at that point. Useful for inspecting
    where an optimization step would move each persistence pair when
    minimizing or maximizing a topology-aware loss.

    Args:
        diagram: One of ``torch.Tensor`` of shape ``(n, 2)``,
            ``numpy.ndarray`` of shape ``(n, 2)``, native
            ``oineus.Diagrams``, differentiable
            ``oineus.diff.PersistenceDiagrams``, or
            ``dict[int, ndarray | torch.Tensor]``.
        gradient: Same shape/structure as ``diagram``, or ``None``. When
            ``None`` and the diagram is torch-backed, the gradient is pulled
            from each tensor's ``.grad``. For non-torch inputs it is
            required and must mirror the diagram's per-dimension layout.
        descent: If ``True``, plot ``-grad`` (the descent direction). The
            default plots the gradient as-is (steepest *increase*).
        plot_points: If ``True`` (default), the underlying diagram is drawn
            via ``plot_diagram`` before the arrows are overlaid.
        diagram_color: Color of the diagram markers (forwarded to
            ``plot_diagram``'s ``color``). Single value, ``dict[int, color]``,
            or ``list``; a single color with multi-dim input emits a
            ``UserWarning``.
        grad_color: Color of the quiver arrows. Single value (single arrow
            color across all dims). Defaults to
            ``DEFAULT_DIAGRAM_GRADIENT_GRAD_COLOR``.
        cmap: Forwarded to ``plot_diagram`` (controls density-mode colormap).

    Rows with a non-finite coordinate or a non-finite gradient component are
    skipped (neither a marker nor an arrow is drawn for them). The
    ``quiver_style`` kwarg accepts any ``Axes.quiver`` keyword; ``color``
    no longer lives in the dict (use ``grad_color``).

    Arrow overlay filtering. With ``scatter_only=True`` (the default) every
    finite point gets a scatter marker, so without a cap an arrow lands on
    every one of them and the picture turns into a hairball at >~1000
    points. Two filters keep the overlay readable:

    - ``min_persistence``: drop arrows on points with persistence
      ``death - birth < min_persistence``. Closest analog of "draw arrows
      only on the high-persistence outliers". Default ``None`` (no
      threshold).
    - ``top_k_arrows``: keep only the top-K by ``|grad|``. When unset and
      the diagram has at least ``arrow_overlay_threshold`` finite points
      (default 1000), defaults to ``DEFAULT_GRADIENT_TOP_K_ARROWS`` (200)
      so the overlay isn't a hairball out of the box. Pass an explicit
      value (or ``np.inf``) to override. ``0`` (or a negative value) draws
      no arrows.

    Filters compose: ``min_persistence`` runs first, ``top_k_arrows``
    second. Filtering only affects arrows; with ``plot_points=True`` all
    finite points are still drawn as markers.

    The ``use_density`` kwarg is deprecated; pass ``scatter_only`` instead
    (semantics inverted).

    Returns:
        The matplotlib Axes that was drawn on.

    Raises:
        ImportError: if matplotlib is missing.
        ValueError: if there are no diagram points, or if diagram and
            gradient row counts disagree for some dimension.
    """
    if not _HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for plot_diagram_gradient.")

    # Legacy use_density=False meant "scatter mode, draw an arrow on every
    # finite point" -- no overlay cap. The new scatter_only=True is the
    # cap's primary trigger, so the shim has to also disable the cap when
    # the caller explicitly opted into the old behaviour. Honour caller-set
    # top_k_arrows / min_persistence.
    if use_density is False and top_k_arrows is None and min_persistence is None:
        top_k_arrows = np.inf

    scatter_only = _resolve_scatter_only(scatter_only, use_density)

    quiver_style = _resolve_style(quiver_style, default_diagram_gradient_style)
    # An explicit grad_color always wins; otherwise keep a color pinned in
    # the style dict and fall back to the module default.
    if grad_color is not None:
        quiver_style["color"] = grad_color
    else:
        quiver_style.setdefault("color", DEFAULT_DIAGRAM_GRADIENT_GRAD_COLOR)

    points_by_dim, grad_by_dim = _coerce_diagram_with_grad(diagram, gradient, dims)
    if not points_by_dim:
        raise ValueError("No diagram points to plot.")

    sign = -1.0 if descent else 1.0

    finite_by_dim: typing.Dict[int, typing.Tuple[np.ndarray, np.ndarray]] = {}
    for dim, pts in points_by_dim.items():
        g = grad_by_dim[dim]
        if pts.shape[0] != g.shape[0]:
            raise ValueError(
                f"Diagram and gradient row counts disagree for dim {dim}: "
                f"{pts.shape[0]} vs {g.shape[0]}."
            )
        finite_mask = np.isfinite(pts).all(axis=1) & np.isfinite(g).all(axis=1)
        finite_by_dim[dim] = (pts[finite_mask], sign * g[finite_mask])

    total_finite = sum(pts.shape[0] for pts, _ in finite_by_dim.values())

    # The log shift must be derived from the same point set plot_diagram
    # sees (all finite points, i.e. before the arrow filters below).
    # Deriving it from the filtered subset could yield a different shift
    # and detach the arrows from their markers on log axes.
    shift_births_parts = [pts[:, 0] for pts, _ in finite_by_dim.values() if pts.size]
    shift_deaths_parts = [pts[:, 1] for pts, _ in finite_by_dim.values() if pts.size]
    all_births = (
        np.concatenate(shift_births_parts) if shift_births_parts else np.empty((0,), dtype=float)
    )
    all_deaths = (
        np.concatenate(shift_deaths_parts) if shift_deaths_parts else np.empty((0,), dtype=float)
    )
    x_shift = _shift_for_log(all_births, log_x)
    y_shift = _shift_for_log(all_deaths, log_y)

    if plot_points:
        finite_dgms = {dim: pts for dim, (pts, _) in finite_by_dim.items()}
        ax = plot_diagram(
            finite_dgms,
            ax=ax,
            color=diagram_color,
            cmap=cmap,
            log_x=log_x,
            log_y=log_y,
            title=title,
            axis_bounds=axis_bounds,
            inf_line_margin=inf_line_margin,
            point_style=point_style,
            inf_point_style=inf_point_style,
            diagonal_style=diagonal_style,
            inf_line_style=inf_line_style,
            density_style=density_style,
            scatter_only=scatter_only,
            density_threshold=density_threshold,
            dim_label_fmt=dim_label_fmt,
        )
    elif ax is None:
        _, ax = plt.subplots()

    # Filter 1: persistence threshold (drops noise).
    if min_persistence is not None and min_persistence > 0.0:
        new_finite = {}
        for dim, (pts, g) in finite_by_dim.items():
            if pts.shape[0] == 0:
                new_finite[dim] = (pts, g)
                continue
            mask = (pts[:, 1] - pts[:, 0]) >= min_persistence
            new_finite[dim] = (pts[mask], g[mask])
        finite_by_dim = new_finite
        total_finite = sum(pts.shape[0] for pts, _ in finite_by_dim.values())

    # Filter 2: cap by gradient magnitude. Default kicks in once the
    # remaining count is large enough for the overlay to read as a hairball.
    effective_top_k = top_k_arrows
    if effective_top_k is None and total_finite >= arrow_overlay_threshold:
        effective_top_k = DEFAULT_GRADIENT_TOP_K_ARROWS
        warnings.warn(
            f"plot_diagram_gradient: capping arrows to top-{effective_top_k} "
            f"by |grad| (out of {total_finite} finite pairs). Pass "
            "top_k_arrows=N or min_persistence=p to override.",
            UserWarning,
            stacklevel=2,
        )

    if effective_top_k is not None and np.isfinite(effective_top_k):
        # Clamp at zero: with k == 0 the expressions ``-k`` / ``[-k:]``
        # degenerate to ``0`` / ``[0:]`` and would select every row instead
        # of none, so the k <= 0 case is handled explicitly below.
        k = max(int(effective_top_k), 0)

        # Flatten all dims into parallel arrays (dim id, row index within
        # the dim, gradient magnitude) so the top-k is taken globally.
        dim_arrs, local_arrs, mag_arrs = [], [], []
        for dim, (pts, g) in finite_by_dim.items():
            n = pts.shape[0]
            if n == 0:
                continue
            dim_arrs.append(np.full(n, dim, dtype=np.int64))
            local_arrs.append(np.arange(n, dtype=np.int64))
            mag_arrs.append(np.hypot(g[:, 0], g[:, 1]))
        if dim_arrs:
            dim_concat = np.concatenate(dim_arrs)
            local_concat = np.concatenate(local_arrs)
            mag_concat = np.concatenate(mag_arrs)
            if mag_concat.size > k:
                if k > 0:
                    keep = np.argpartition(mag_concat, -k)[-k:]
                else:
                    keep = np.empty((0,), dtype=np.int64)
                keep_dim = dim_concat[keep]
                keep_local = local_concat[keep]
                new_finite = {}
                for dim, (pts, g) in finite_by_dim.items():
                    mask = np.zeros(pts.shape[0], dtype=bool)
                    mask[keep_local[keep_dim == dim]] = True
                    new_finite[dim] = (pts[mask], g[mask])
                finite_by_dim = new_finite

    for dim in sorted(finite_by_dim.keys()):
        pts, grads = finite_by_dim[dim]
        if pts.shape[0] == 0:
            continue
        ax.quiver(
            pts[:, 0] + x_shift,
            pts[:, 1] + y_shift,
            grads[:, 0],
            grads[:, 1],
            **quiver_style,
        )

    return ax
# ---------------------------------------------------------------------------
# plot_matching
# ---------------------------------------------------------------------------
[docs]
def plot_matching(
dgm_a,
dgm_b,
matching,
ax=None,
*,
plot_finite_to_finite: typing.Optional[bool] = None,
plot_a_to_diagonal: typing.Optional[bool] = None,
plot_b_to_diagonal: typing.Optional[bool] = None,
plot_essential: typing.Optional[bool] = None,
highlight_longest: typing.Optional[bool] = None,
plot_points: bool = True,
plot_diagonal_projections: bool = False,
plot_diagonal: bool = True,
scatter_only: bool = True,
density_threshold: int = DEFAULT_DENSITY_THRESHOLD,
near_diagonal_fraction: float = 0.03,
edge_quantile: float = DEFAULT_MATCHING_EDGE_QUANTILE,
min_persistence: typing.Optional[float] = None,
top_k_pairs: typing.Optional[int] = None,
pair_filter: str = "either",
pair_overlay_threshold: int = 1000,
color_dgm_a=None,
color_dgm_b=None,
match_color=None,
cmap_a=None,
cmap_b=None,
density_style: typing.Optional[dict] = None,
dgm_a_point_style: typing.Optional[dict] = None,
dgm_b_point_style: typing.Optional[dict] = None,
ordinary_edge_style: typing.Optional[dict] = None,
longest_edge_style: typing.Optional[dict] = None,
diagonal_style: typing.Optional[dict] = None,
diagonal_projection_a_style: typing.Optional[dict] = None,
diagonal_projection_b_style: typing.Optional[dict] = None,
inf_line_style: typing.Optional[dict] = None,
dgm_a_label: str = "Diagram A",
dgm_b_label: str = "Diagram B",
title: typing.Optional[str] = None,
axis_bounds: typing.Optional[typing.Mapping[str, float]] = None,
inf_line_margin: float = 0.05,
use_density=_UNSET,
):
"""Plot a matching between two persistence diagrams.
Color of the per-side scatter is controlled by ``color_dgm_a`` and
``color_dgm_b`` (single colors; ``None`` uses ``DEFAULT_DIAGRAM_A_COLOR``
/ ``DEFAULT_DIAGRAM_B_COLOR``). The matching edges use ``match_color``
(single color; ``None`` uses ``DEFAULT_MATCHING_EDGE_COLOR``). Per-side
density colormaps are ``cmap_a`` and ``cmap_b``; ``None`` keeps the
monochromatic-fade-from-color default driven by the side's color.
Dispatches on ``matching`` type: for Wasserstein (``DiagramMatching``)
all edge categories are drawn by default; for ``BottleneckMatching`` only
finite-to-finite edges are drawn and the longest edge(s) are overlaid in
the highlight style.
``dgm_a`` and ``dgm_b`` must be 2D numpy arrays (one homology dimension).
Edge filtering. Three knobs, applied in order:
- ``min_persistence``: drop edges where neither (or both, if
``pair_filter='both'``) endpoint has persistence
``death - birth >= min_persistence``. Default ``None`` (no filter).
- ``top_k_pairs``: keep only the K edges with the largest endpoint
persistence. When unset and the total ordinary-edge count is at
least ``pair_overlay_threshold`` (default 1000), defaults to 200
and emits a ``UserWarning`` listing the cap.
- ``edge_quantile``: legacy length-based filter, only active when
``scatter_only=False`` (i.e. density mode). Kept for back-compat;
``min_persistence`` / ``top_k_pairs`` are the recommended controls.
With ``scatter_only=False`` (opt-in density mode) the diagram is
rendered as a density background (near-diagonal points only) plus
crisp scatter for high-persistence outliers; the ``edge_quantile``
filter then runs in addition.
The ``use_density`` kwarg is deprecated; pass ``scatter_only`` instead
(semantics inverted).
"""
if not _HAS_MATPLOTLIB:
raise ImportError("matplotlib is required for plot_matching.")
scatter_only = _resolve_scatter_only(scatter_only, use_density)
if pair_filter not in ("either", "both"):
raise ValueError("pair_filter must be 'either' or 'both'.")
# Avoid circular import
from ..matching import (
BottleneckMatching,
DiagramMatching,
point_to_diagonal,
)
if not isinstance(matching, DiagramMatching):
raise TypeError("matching must be a DiagramMatching or BottleneckMatching instance.")
is_bottleneck = isinstance(matching, BottleneckMatching)
# Type-aware category-flag defaults
if plot_finite_to_finite is None:
plot_finite_to_finite = True
if plot_a_to_diagonal is None:
plot_a_to_diagonal = not is_bottleneck
if plot_b_to_diagonal is None:
plot_b_to_diagonal = not is_bottleneck
if plot_essential is None:
plot_essential = False
if highlight_longest is None:
highlight_longest = is_bottleneck
# Resolve style dicts
dgm_a_point_style = _resolve_style(dgm_a_point_style, default_diagram_a_point_style)
dgm_b_point_style = _resolve_style(dgm_b_point_style, default_diagram_b_point_style)
ordinary_edge_style = _resolve_style(ordinary_edge_style, default_matching_edge_style)
longest_edge_style = _resolve_style(longest_edge_style, default_longest_edge_style)
diagonal_style = _resolve_style(diagonal_style, default_diagonal_style)
diagonal_projection_a_style = _resolve_style(
diagonal_projection_a_style, default_diagonal_projection_a_style)
diagonal_projection_b_style = _resolve_style(
diagonal_projection_b_style, default_diagonal_projection_b_style)
# Diagonal-projection markers inherit the per-side color unless the
# caller pinned one in the style dict.
if "c" not in diagonal_projection_a_style:
diagonal_projection_a_style["c"] = (
color_dgm_a if color_dgm_a is not None else DEFAULT_DIAGRAM_A_COLOR
)
if "c" not in diagonal_projection_b_style:
diagonal_projection_b_style["c"] = (
color_dgm_b if color_dgm_b is not None else DEFAULT_DIAGRAM_B_COLOR
)
inf_line_style = _resolve_style(inf_line_style, default_inf_line_style)
# Inject the top-level color arguments. The style dicts no longer
# carry color keys, so a `None` user input falls back to the module-
# level defaults defined in _styles.py.
color_a = color_dgm_a if color_dgm_a is not None else DEFAULT_DIAGRAM_A_COLOR
color_b = color_dgm_b if color_dgm_b is not None else DEFAULT_DIAGRAM_B_COLOR
dgm_a_point_style.setdefault("c", color_a)
dgm_b_point_style.setdefault("c", color_b)
ordinary_edge_style.setdefault(
"color",
match_color if match_color is not None else DEFAULT_MATCHING_EDGE_COLOR,
)
bounds = {} if axis_bounds is None else dict(axis_bounds)
dgm_a = _build_diagram_arrays(dgm_a)
dgm_b = _build_diagram_arrays(dgm_b)
# Collect finite coordinates for layout
def _finite_parts(dgm):
if dgm.shape[0] == 0:
return np.empty((0,), dtype=float), np.empty((0,), dtype=float)
finite = np.isfinite(dgm[:, 0]) & np.isfinite(dgm[:, 1])
return dgm[finite, 0], dgm[finite, 1]
a_fin_b, a_fin_d = _finite_parts(dgm_a)
b_fin_b, b_fin_d = _finite_parts(dgm_b)
# Essential births / deaths for layout (we want the plot to include them).
a_pos_inf_b = dgm_a[np.isfinite(dgm_a[:, 0]) & np.isposinf(dgm_a[:, 1]), 0] if dgm_a.size else np.empty((0,))
b_pos_inf_b = dgm_b[np.isfinite(dgm_b[:, 0]) & np.isposinf(dgm_b[:, 1]), 0] if dgm_b.size else np.empty((0,))
a_neg_inf_b = dgm_a[np.isfinite(dgm_a[:, 0]) & np.isneginf(dgm_a[:, 1]), 0] if dgm_a.size else np.empty((0,))
b_neg_inf_b = dgm_b[np.isfinite(dgm_b[:, 0]) & np.isneginf(dgm_b[:, 1]), 0] if dgm_b.size else np.empty((0,))
a_pos_inf_d = dgm_a[np.isposinf(dgm_a[:, 0]) & np.isfinite(dgm_a[:, 1]), 1] if dgm_a.size else np.empty((0,))
b_pos_inf_d = dgm_b[np.isposinf(dgm_b[:, 0]) & np.isfinite(dgm_b[:, 1]), 1] if dgm_b.size else np.empty((0,))
a_neg_inf_d = dgm_a[np.isneginf(dgm_a[:, 0]) & np.isfinite(dgm_a[:, 1]), 1] if dgm_a.size else np.empty((0,))
b_neg_inf_d = dgm_b[np.isneginf(dgm_b[:, 0]) & np.isfinite(dgm_b[:, 1]), 1] if dgm_b.size else np.empty((0,))
(x_min, x_max, y_min, y_max, x_span, y_span,
inf_x_pos, inf_x_neg, inf_y_pos, inf_y_neg) = _compute_plot_limits(
np.concatenate([a_fin_b, b_fin_b]) if a_fin_b.size or b_fin_b.size else np.empty((0,)),
np.concatenate([a_fin_d, b_fin_d]) if a_fin_d.size or b_fin_d.size else np.empty((0,)),
extra_xs=[a_pos_inf_b, b_pos_inf_b, a_neg_inf_b, b_neg_inf_b],
extra_ys=[a_pos_inf_d, b_pos_inf_d, a_neg_inf_d, b_neg_inf_d],
inf_line_margin=inf_line_margin,
)
any_pos_inf_d = (a_pos_inf_b.size + b_pos_inf_b.size) > 0
any_neg_inf_d = (a_neg_inf_b.size + b_neg_inf_b.size) > 0
any_pos_inf_b = (a_pos_inf_d.size + b_pos_inf_d.size) > 0
any_neg_inf_b = (a_neg_inf_d.size + b_neg_inf_d.size) > 0
if ax is None:
_, ax = plt.subplots()
# Decide whether to switch to density mode for the bulk near the diagonal.
n_finite_total = a_fin_b.size + b_fin_b.size
use_density_plot = (
not scatter_only
and plot_points
and n_finite_total >= density_threshold
)
if use_density_plot:
_require_scatter_density()
near_thr = near_diagonal_fraction * max(x_span, y_span)
# Diagonal
if plot_diagonal:
lo = min(x_min, y_min)
hi = max(x_max, y_max)
if any_pos_inf_d:
hi = max(hi, inf_y_pos)
if any_neg_inf_d:
lo = min(lo, inf_y_neg)
if any_pos_inf_b:
hi = max(hi, inf_x_pos)
if any_neg_inf_b:
lo = min(lo, inf_x_neg)
if hi <= lo:
hi = lo + 1.0
ax.plot([lo, hi], [lo, hi], **diagonal_style)
# Inf lines
if any_pos_inf_d:
ax.axhline(inf_y_pos, **inf_line_style)
if any_neg_inf_d:
ax.axhline(inf_y_neg, **inf_line_style)
if any_pos_inf_b:
ax.axvline(inf_x_pos, **inf_line_style)
if any_neg_inf_b:
ax.axvline(inf_x_neg, **inf_line_style)
# Diagram points: split near-diagonal bulk (density when enabled) from
# outliers (always scatter so high-persistence features stay crisp).
def _draw_diagram_points(dgm, point_style, label, *, side_cmap):
if dgm.shape[0] == 0:
return
coords = np.array([
_point_coords_for_edge(dgm, i, inf_x_pos, inf_x_neg, inf_y_pos, inf_y_neg)
for i in range(dgm.shape[0])
])
if not use_density_plot:
ax.scatter(coords[:, 0], coords[:, 1], label=label, **point_style)
return
finite_mask = np.isfinite(coords[:, 0]) & np.isfinite(coords[:, 1])
finite = coords[finite_mask]
non_finite = coords[~finite_mask]
near_b, near_d, far_b, far_d = _split_near_diagonal(
finite[:, 0], finite[:, 1], near_thr)
if near_b.size:
density_kwargs = {}
if side_cmap is not None:
# Explicit cmap wins over the monochromatic-fade default.
resolved = _resolve_style(density_style, default_density_style)
resolved["cmap"] = side_cmap
density_kwargs["style"] = resolved
else:
density_kwargs["color"] = point_style.get("c")
density_kwargs["style"] = _resolve_style(
density_style, default_density_style,
)
_add_density_artist(ax, near_b, near_d, **density_kwargs)
scatter_label = label if (far_b.size or non_finite.size) else None
if far_b.size:
ax.scatter(far_b, far_d, label=scatter_label, **point_style)
scatter_label = None
if non_finite.size:
ax.scatter(non_finite[:, 0], non_finite[:, 1], label=scatter_label, **point_style)
if plot_points:
_draw_diagram_points(dgm_a, dgm_a_point_style, dgm_a_label, side_cmap=cmap_a)
_draw_diagram_points(dgm_b, dgm_b_point_style, dgm_b_label, side_cmap=cmap_b)
# Gather all edges to draw, grouped by category.
ordinary_segments: list = []
if plot_finite_to_finite:
for ia, ib in matching.finite_to_finite:
pa = _point_coords_for_edge(dgm_a, ia, inf_x_pos, inf_x_neg, inf_y_pos, inf_y_neg)
pb = _point_coords_for_edge(dgm_b, ib, inf_x_pos, inf_x_neg, inf_y_pos, inf_y_neg)
ordinary_segments.append((pa, pb))
diag_proj_a_coords = []
diag_proj_b_coords = []
if plot_a_to_diagonal and len(matching.a_to_diagonal) > 0:
projs = point_to_diagonal(dgm_a, indices=matching.a_to_diagonal)
for local_i, ia in enumerate(matching.a_to_diagonal):
pa = _point_coords_for_edge(dgm_a, int(ia), inf_x_pos, inf_x_neg, inf_y_pos, inf_y_neg)
pproj = (float(projs[local_i, 0]), float(projs[local_i, 1]))
ordinary_segments.append((pa, pproj))
diag_proj_a_coords.append(pproj)
if plot_b_to_diagonal and len(matching.b_to_diagonal) > 0:
projs = point_to_diagonal(dgm_b, indices=matching.b_to_diagonal)
for local_i, ib in enumerate(matching.b_to_diagonal):
pb = _point_coords_for_edge(dgm_b, int(ib), inf_x_pos, inf_x_neg, inf_y_pos, inf_y_neg)
pproj = (float(projs[local_i, 0]), float(projs[local_i, 1]))
ordinary_segments.append((pb, pproj))
diag_proj_b_coords.append(pproj)
if plot_essential:
for _kind, pairs in matching.essential.items():
for ia, ib in pairs:
pa = _point_coords_for_edge(dgm_a, int(ia), inf_x_pos, inf_x_neg, inf_y_pos, inf_y_neg)
pb = _point_coords_for_edge(dgm_b, int(ib), inf_x_pos, inf_x_neg, inf_y_pos, inf_y_neg)
ordinary_segments.append((pa, pb))
# Persistence-aware filters (run before the legacy edge-length filter).
if ordinary_segments and (min_persistence is not None or top_k_pairs is not None
or len(ordinary_segments) >= pair_overlay_threshold):
seg_arr = np.array(ordinary_segments, dtype=float)
# Persistence per endpoint = abs(death - birth). Inf-clamped points
# appear at finite inf_*_pos coords, so the value here is only an
# approximation for essentials -- but they have y near the inf-line
# and x at the birth, so abs is still large and they pass any
# reasonable threshold.
pers_a = np.abs(seg_arr[:, 0, 1] - seg_arr[:, 0, 0])
pers_b = np.abs(seg_arr[:, 1, 1] - seg_arr[:, 1, 0])
if pair_filter == "either":
pers_pair = np.maximum(pers_a, pers_b)
pers_for_filter = pers_pair
else: # "both"
pers_pair = np.maximum(pers_a, pers_b) # ranking proxy
pers_for_filter = np.minimum(pers_a, pers_b)
keep_mask = np.ones(len(ordinary_segments), dtype=bool)
if min_persistence is not None and min_persistence > 0.0:
keep_mask &= pers_for_filter >= min_persistence
kept_count = int(keep_mask.sum())
effective_top_k = top_k_pairs
if effective_top_k is None and kept_count >= pair_overlay_threshold:
effective_top_k = 200
warnings.warn(
f"plot_matching: capping ordinary edges to top-{effective_top_k} "
f"by endpoint persistence (out of {kept_count}). Pass "
"top_k_pairs=N or min_persistence=p to override.",
UserWarning,
stacklevel=2,
)
if effective_top_k is not None and np.isfinite(effective_top_k) \
and kept_count > effective_top_k:
ranks = np.where(keep_mask, pers_pair, -np.inf)
keep_idx = np.argpartition(ranks, -int(effective_top_k))[-int(effective_top_k):]
new_mask = np.zeros_like(keep_mask)
new_mask[keep_idx] = True
keep_mask = new_mask
ordinary_segments = [
s for s, k in zip(ordinary_segments, keep_mask) if k
]
if ordinary_segments and use_density_plot and 0.0 < edge_quantile < 1.0:
# Legacy length-based filter (only when density mode active).
# Most matchings of large diagrams are dominated by short noise-to-noise
# pairs near the diagonal that pile up into a featureless gray hairball.
seg_arr = np.array(ordinary_segments, dtype=float)
lengths = np.hypot(
seg_arr[:, 1, 0] - seg_arr[:, 0, 0],
seg_arr[:, 1, 1] - seg_arr[:, 0, 1],
)
threshold = float(np.quantile(lengths, edge_quantile))
keep = lengths >= threshold
ordinary_segments = [s for s, k in zip(ordinary_segments, keep) if k]
if ordinary_segments:
ax.add_collection(LineCollection(ordinary_segments, **ordinary_edge_style))
# Diagonal projection markers (after edges so they sit on top).
if plot_diagonal_projections:
if diag_proj_a_coords:
arr = np.array(diag_proj_a_coords)
ax.scatter(arr[:, 0], arr[:, 1], **diagonal_projection_a_style)
if diag_proj_b_coords:
arr = np.array(diag_proj_b_coords)
ax.scatter(arr[:, 0], arr[:, 1], **diagonal_projection_b_style)
# Highlight longest edges for bottleneck.
if highlight_longest and is_bottleneck:
longest_segments = []
for e in matching.longest.finite:
longest_segments.append((e.point_a, e.point_b))
for _kind, edges in matching.longest.essential.items():
for e in edges:
pa = _point_coords_for_edge(dgm_a, e.idx_a, inf_x_pos, inf_x_neg, inf_y_pos, inf_y_neg)
pb = _point_coords_for_edge(dgm_b, e.idx_b, inf_x_pos, inf_x_neg, inf_y_pos, inf_y_neg)
longest_segments.append((pa, pb))
if longest_segments:
ax.add_collection(LineCollection(longest_segments, **longest_edge_style))
# Axis limits
x_left = None if "xmin" not in bounds else float(bounds["xmin"])
x_right = None if "xmax" not in bounds else float(bounds["xmax"])
y_bottom = None if "ymin" not in bounds else float(bounds["ymin"])
y_top = None if "ymax" not in bounds else float(bounds["ymax"])
if x_left is not None or x_right is not None:
ax.set_xlim(left=x_left, right=x_right)
if y_bottom is not None or y_top is not None:
ax.set_ylim(bottom=y_bottom, top=y_top)
ax.set_xlabel("birth")
ax.set_ylabel("death")
if title is not None:
ax.set_title(title)
# Legend: only the diagram-point labels are registered, so this is safe.
handles, labels = ax.get_legend_handles_labels()
if labels:
ax.legend()
return ax
# ---------------------------------------------------------------------------
# plot_chain
# ---------------------------------------------------------------------------
def _cube_filtration_types():
    """Collect the cubical filtration classes from the compiled ``_oineus`` module.

    Returns
    -------
    tuple
        ``(_CubeFiltration_1D, _CubeFiltration_2D, _CubeFiltration_3D)`` when
        the module imports and exposes all three attributes; an empty tuple
        if the import fails or any attribute is missing.
    """
    try:
        from .. import _oineus
        # A missing attribute raises AttributeError out of the generator and
        # is handled below exactly like a failed import.
        available = tuple(
            getattr(_oineus, f"_CubeFiltration_{ndim}D") for ndim in (1, 2, 3)
        )
    except (ImportError, AttributeError):
        available = ()
    return available
def _is_cubical_filtration(filtration):
    """Report whether ``filtration`` is an instance of a cubical filtration class.

    Always ``False`` when ``_cube_filtration_types()`` yields no classes.
    """
    known_types = _cube_filtration_types()
    if not known_types:
        return False
    return isinstance(filtration, known_types)
def _resolve_source_kind(filtration, override):
    """Choose the rendering mode, ``"points"`` or ``"field"``, for a filtration.

    An explicit ``override`` wins after validation. Otherwise the
    ``filtration.kind`` tag is consulted (Cubical / Freudenthal map to
    ``"field"``; Vr / Alpha / WeakAlpha / CechDelaunay map to ``"points"``).
    If the tag is absent, unrecognised, or ``FiltrationKind`` cannot be
    loaded, the decision falls back to a type check: cubical filtrations are
    ``"field"``, everything else ``"points"``.

    Raises
    ------
    ValueError
        If ``override`` is given but is neither ``"points"`` nor ``"field"``.
    """
    if override is not None:
        if override in ("points", "field"):
            return override
        raise ValueError(
            f"source_kind must be 'points' or 'field', got {override!r}."
        )

    # The constructor-set tag is the preferred signal; hand-built filtrations
    # may carry an unrecognised kind and fall through to the type check below.
    tag = getattr(filtration, "kind", None)
    if tag is not None:
        try:
            from .. import _oineus
            kinds = _oineus.FiltrationKind
            if tag in (kinds.Cubical, kinds.Freudenthal):
                return "field"
            if tag in (kinds.Vr, kinds.Alpha, kinds.WeakAlpha, kinds.CechDelaunay):
                return "points"
        except (ImportError, AttributeError):
            pass

    return "field" if _is_cubical_filtration(filtration) else "points"
def _square_corners_cyclic(corners):
    """Return the bounding-box corners of an axis-aligned square in cyclic order.

    ``corners`` holds four length-2 ``(i, j)`` entries in any order. The
    result walks the box as ``(i_min, j_min)``, ``(i_min, j_max)``,
    ``(i_max, j_max)``, ``(i_max, j_min)`` -- i.e. BL, BR, TR, TL when ``j``
    is drawn along x and ``i`` along y -- as tuples of Python floats.
    """
    grid = np.asarray(corners, dtype=float)
    rows, cols = grid[:, 0], grid[:, 1]
    bottom, top = float(rows.min()), float(rows.max())
    left, right = float(cols.min()), float(cols.max())
    return [(bottom, left), (bottom, right), (top, right), (top, left)]
def _cube_3d_face_polys(corners):
    """Build the six axis-aligned square faces of a 3-cube.

    ``corners`` is the list of eight ``(i, j, k)`` corners; only their
    per-axis minimum and maximum are used. Faces come out in the order
    k=min, k=max, j=min, j=max, i=min, i=max, each as a list of four
    ``(i, j, k)`` tuples walking the face cyclically.
    """
    grid = np.asarray(corners, dtype=float)
    (i_lo, j_lo, k_lo), (i_hi, j_hi, k_hi) = grid.min(axis=0), grid.max(axis=0)
    extent = ((i_lo, i_hi), (j_lo, j_hi), (k_lo, k_hi))

    # Cyclic walk over the two free axes: (lo, lo) -> (hi, lo) -> (hi, hi) -> (lo, hi).
    walk = ((0, 0), (1, 0), (1, 1), (0, 1))

    faces = []
    for fixed in (2, 1, 0):
        free_a, free_b = (axis for axis in range(3) if axis != fixed)
        for side in (0, 1):
            quad = []
            for step_a, step_b in walk:
                corner = [None, None, None]
                corner[fixed] = extent[fixed][side]
                corner[free_a] = extent[free_a][step_a]
                corner[free_b] = extent[free_b][step_b]
                quad.append(tuple(corner))
            faces.append(quad)
    return faces
def _tet_face_polys(verts):
    """List the four triangular faces of a tetrahedron.

    ``verts`` is an iterable of the four vertex coordinates. Each face is a
    list of three of those vertices, taken in the index order
    (0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3).
    """
    corners = list(verts)
    face_indices = ((0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3))
    return [[corners[a], corners[b], corners[c]] for a, b, c in face_indices]
[docs]
def plot_chain(
    source,
    filtration,
    chain,
    *,
    ax=None,
    source_kind: typing.Optional[str] = None,
    dualize=False,
    chain_color=None,
    point_cloud_color=None,
    edge_style: typing.Optional[dict] = None,
    triangle_style: typing.Optional[dict] = None,
    tetrahedron_style: typing.Optional[dict] = None,
    vertex_style: typing.Optional[dict] = None,
    point_style: typing.Optional[dict] = None,
    title: typing.Optional[str] = None,
    plot_source: bool = True,
    field_cmap: str = "viridis",
):
    """Render a chain of cells over its underlying source.

    ``source`` is one of:

    - a 2D point cloud as ``(N, 2)`` ndarray (simplicial filtration),
    - a 3D point cloud as ``(N, 3)`` ndarray (simplicial filtration),
    - a 2D scalar field as ``(H, W)`` ndarray (cubical or Freudenthal),
    - a 3D scalar field as ``(D, H, W)`` ndarray (cubical or Freudenthal).

    The dispatch is driven by ``source_kind`` (``"points"`` or ``"field"``)
    and then by the array shape: ``source.shape[1]`` (2 vs 3) for
    ``"points"``, ``source.ndim`` (2 vs 3) for ``"field"``. When
    ``source_kind`` is ``None`` the filtration's ``kind`` tag is consulted
    first (Cubical / Freudenthal -> ``"field"``; Vr / Alpha / WeakAlpha /
    CechDelaunay -> ``"points"``); untagged filtrations route
    ``CubeFiltration_*D`` to ``"field"`` and everything else to
    ``"points"``.

    The chain may be a list, ndarray, range, or scipy.sparse column / row
    slice (e.g. ``dcmp.v_as_csc()[:, j]``). Each entry is normally a
    *filtration sorted-id*; cells of dim 0/1/2/3 render as vertices /
    edges / triangles-or-squares / tetrahedra-or-voxels respectively.

    The matrices ``v_data``, ``r_data``, ``u_data_t`` of a
    ``Decomposition(fil, dualize=True)`` are indexed in *cohomology matrix
    space* rather than filtration space (matrix index = ``n - 1 -
    filtration_id``). Pass ``dualize=True`` (or ``dualize=dcmp`` and we
    will read ``dcmp.dualize`` for you) so plot_chain translates the
    chain back to filtration ids before looking up cells. The default
    ``False`` matches homology decompositions and indices that have
    already been translated (e.g. those returned by
    ``TopologyOptimizer.increase_birth``).

    Colors:

    ``chain_color`` -- single color or ``dict[int, color]`` (keys are
        cell dimensions: 0=vertex, 1=edge, 2=triangle, 3=tetrahedron);
        applied to whatever the matching style dict didn't already
        override. ``None`` (default), or a dict lacking the dimension,
        uses ``DEFAULT_CHAIN_COLOR``.
    ``point_cloud_color`` -- single color for the underlying
        point-cloud / field background. ``None`` (default) uses
        ``DEFAULT_POINT_CLOUD_COLOR``.

    Other parameters:

    ``ax`` -- Axes to draw on; ``None`` lets the renderer create one.
    ``vertex_style`` / ``edge_style`` / ``triangle_style`` /
        ``tetrahedron_style`` / ``point_style`` -- per-stratum style
        overrides resolved against the module defaults.
    ``title`` -- optional Axes title.
    ``plot_source`` -- whether to draw the underlying point cloud / field.
    ``field_cmap`` -- colormap for the 2D field background. It is only
        forwarded to the 2D field renderer.

    Returns the Axes produced by the selected renderer.

    Raises ``ImportError`` when matplotlib is unavailable and ``ValueError``
    when ``source_kind`` is invalid or ``source`` has an unsupported shape.
    """
    if not _HAS_MATPLOTLIB:
        raise ImportError("matplotlib is required for plot_chain.")

    kind = _resolve_source_kind(filtration, source_kind)
    chain_ids = _coerce_chain(chain)

    if hasattr(dualize, "dualize"):
        # Decomposition (or duck-type with a .dualize attr) -- read the flag.
        dualize_flag = bool(dualize.dualize)
    else:
        dualize_flag = bool(dualize)
    if dualize_flag and chain_ids.size:
        # Cohomology matrix index -> filtration sorted-id.
        chain_ids = (filtration.size() - 1) - chain_ids

    vertex_style = _resolve_style(vertex_style, default_chain_vertex_style)
    edge_style = _resolve_style(edge_style, default_chain_edge_style)
    triangle_style = _resolve_style(triangle_style, default_chain_triangle_style)
    tetrahedron_style = _resolve_style(tetrahedron_style, default_chain_tetrahedron_style)
    point_style = _resolve_style(point_style, default_point_cloud_style)

    # Inject chain_color into each stratum style. The style dicts no
    # longer carry a default color, so a None user input falls back to
    # the module-level constant defined in _styles.py.
    def _color_for_dim(dim):
        requested = chain_color.get(dim) if isinstance(chain_color, dict) else chain_color
        # Compare against None explicitly: truth-testing the color raises for
        # array-valued colors (e.g. an RGB ndarray) and would silently discard
        # falsy-but-intentional values.
        return DEFAULT_CHAIN_COLOR if requested is None else requested

    vertex_color = _color_for_dim(0)
    edge_color = _color_for_dim(1)
    triangle_color = _color_for_dim(2)
    tetrahedron_color = _color_for_dim(3)

    # setdefault keeps any color the caller pinned in the style dict.
    vertex_style.setdefault("c", vertex_color)
    edge_style.setdefault("color", edge_color)
    triangle_style.setdefault("facecolor", triangle_color)
    triangle_style.setdefault("edgecolor", triangle_color)
    tetrahedron_style.setdefault("facecolor", tetrahedron_color)
    tetrahedron_style.setdefault("edgecolor", tetrahedron_color)
    point_style.setdefault(
        "c",
        point_cloud_color if point_cloud_color is not None else DEFAULT_POINT_CLOUD_COLOR,
    )

    if kind == "points":
        points = np.asarray(source, dtype=float)
        if points.ndim != 2 or points.shape[1] not in (2, 3):
            raise ValueError(
                f"For source_kind='points', expected an (N, 2) or (N, 3) "
                f"array; got shape {points.shape}."
            )
        if points.shape[1] == 2:
            return _render_chain_points_2d(
                ax, points, filtration, chain_ids,
                vertex_style=vertex_style,
                edge_style=edge_style,
                triangle_style=triangle_style,
                point_style=point_style,
                plot_source=plot_source,
                title=title,
            )
        return _render_chain_points_3d(
            ax, points, filtration, chain_ids,
            vertex_style=vertex_style,
            edge_style=edge_style,
            triangle_style=triangle_style,
            tetrahedron_style=tetrahedron_style,
            point_style=point_style,
            plot_source=plot_source,
            title=title,
        )

    # kind == "field"
    field = np.asarray(source)
    if field.ndim == 2:
        return _render_chain_field_2d(
            ax, field, filtration, chain_ids,
            vertex_style=vertex_style,
            edge_style=edge_style,
            triangle_style=triangle_style,
            field_cmap=field_cmap,
            plot_source=plot_source,
            title=title,
        )
    if field.ndim == 3:
        return _render_chain_field_3d(
            ax, field, filtration, chain_ids,
            vertex_style=vertex_style,
            edge_style=edge_style,
            triangle_style=triangle_style,
            tetrahedron_style=tetrahedron_style,
            point_style=point_style,
            plot_source=plot_source,
            title=title,
        )
    raise ValueError(
        f"For source_kind='field', expected a 2D (H, W) or 3D (D, H, W) "
        f"array; got shape {field.shape}."
    )
# ---------------------------------------------------------------------------
# plot_chain renderers
# ---------------------------------------------------------------------------
def _render_chain_points_2d(
    ax, points, filtration, chain_ids,
    *, vertex_style, edge_style, triangle_style,
    point_style, plot_source, title,
):
    """Draw a chain of simplices over a 2D point cloud.

    Each id in ``chain_ids`` is looked up as ``filtration[id]`` and its
    ``vertices`` index into ``points``: one vertex renders as a marker, two
    as a line segment, three as a filled triangle. Anything else is counted
    and reported in a single warning. Triangles are added first, then edges,
    then vertex markers.

    Returns the Axes drawn on (created here when ``ax`` is ``None``).
    """
    if ax is None:
        _, ax = plt.subplots()
    if plot_source:
        ax.scatter(points[:, 0], points[:, 1], **point_style)

    vertex_coords = []
    edge_segments = []
    triangle_polys = []
    skipped_high_dim = 0
    for cell_id in chain_ids:
        verts = list(filtration[int(cell_id)].vertices)
        n_verts = len(verts)
        if n_verts == 1:
            vertex_coords.append(points[verts[0]])
        elif n_verts == 2:
            edge_segments.append((points[verts[0]], points[verts[1]]))
        elif n_verts == 3:
            triangle_polys.append([points[v] for v in verts])
        else:
            skipped_high_dim += 1

    if skipped_high_dim:
        # stacklevel=3 skips this renderer and its caller.
        warnings.warn(
            f"plot_chain: skipped {skipped_high_dim} cells of dim >= 3 "
            f"(2D point-cloud rendering only supports vertices, edges, "
            f"triangles).",
            stacklevel=3,
        )

    if triangle_polys:
        ax.add_collection(PolyCollection(triangle_polys, **triangle_style))
    if edge_segments:
        ax.add_collection(LineCollection(edge_segments, **edge_style))
    if triangle_polys or edge_segments:
        # add_collection updates the data limits but, on older Matplotlib,
        # not the view limits. Without this the chain can sit outside the
        # default (0, 1) view when no scatter call triggers autoscaling
        # (plot_source=False and no dim-0 cells). No-op when the caller has
        # fixed the limits (autoscaling turned off).
        ax.autoscale_view()
    if vertex_coords:
        arr = np.asarray(vertex_coords)
        ax.scatter(arr[:, 0], arr[:, 1], **vertex_style)

    if title is not None:
        ax.set_title(title)
    ax.set_aspect("equal", adjustable="datalim")
    return ax
def _ensure_3d_axes(ax):
    """Return a 3D Axes: ``ax`` itself if it is one, or a freshly created one.

    With ``ax=None`` a new figure is created and a ``projection="3d"``
    subplot is returned. Any other object must report ``name == "3d"``.

    Raises
    ------
    ValueError
        If ``ax`` is provided but its ``name`` attribute is not ``"3d"``.
    """
    if ax is not None:
        if getattr(ax, "name", "") == "3d":
            return ax
        raise ValueError(
            "3D plot_chain rendering requires a 3D Axes; pass ax=None to "
            "auto-create one or build with projection='3d'."
        )
    new_figure = plt.figure()
    return new_figure.add_subplot(111, projection="3d")
def _render_chain_points_3d(
    ax, points, filtration, chain_ids,
    *, vertex_style, edge_style, triangle_style, tetrahedron_style,
    point_style, plot_source, title,
):
    """Draw a chain of simplices over a 3D point cloud.

    Cells with one to four vertices render as markers, segments, triangles
    and tetrahedra (drawn as their four triangular faces). Cells with any
    other vertex count are counted and reported in a single warning.
    Tetrahedron faces are added first, then triangles, edges and finally
    vertex markers.

    Returns the 3D Axes drawn on; a non-3D ``ax`` is rejected by
    ``_ensure_3d_axes``.
    """
    from mpl_toolkits.mplot3d.art3d import Line3DCollection, Poly3DCollection

    ax = _ensure_3d_axes(ax)
    if plot_source:
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], **point_style)

    markers, segments, triangles, tet_faces = [], [], [], []
    n_skipped = 0
    for cid in chain_ids:
        simplex = list(filtration[int(cid)].vertices)
        n_verts = len(simplex)
        if not 1 <= n_verts <= 4:
            n_skipped += 1
            continue
        located = [points[v] for v in simplex]
        if n_verts == 1:
            markers.append(located[0])
        elif n_verts == 2:
            segments.append(tuple(located))
        elif n_verts == 3:
            triangles.append(located)
        else:
            tet_faces.extend(_tet_face_polys(located))

    if n_skipped:
        warnings.warn(
            f"plot_chain: skipped {n_skipped} cells of dim >= 4 "
            f"(3D point-cloud rendering only supports cells up to "
            f"tetrahedra).",
            stacklevel=3,
        )

    # Each stratum gets its own collection, highest dimension first.
    for polygons, style in ((tet_faces, tetrahedron_style), (triangles, triangle_style)):
        if polygons:
            ax.add_collection3d(Poly3DCollection(polygons, **style))
    if segments:
        ax.add_collection3d(Line3DCollection(segments, **edge_style))
    if markers:
        xyz = np.asarray(markers)
        ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], **vertex_style)

    if title is not None:
        ax.set_title(title)
    return ax
def _render_chain_field_2d(
    ax, field, filtration, chain_ids,
    *, vertex_style, edge_style, triangle_style,
    field_cmap, plot_source, title,
):
    """Draw a chain of cells over a 2D scalar field.

    The field is shown with ``imshow`` (``origin="lower"``), so grid index
    ``(i, j)`` is plotted at ``x = j, y = i``. For cubical filtrations each
    cell's ``vertices`` are ``(i, j)`` corners (1 = vertex, 2 = edge,
    4 = square); otherwise vertex ids are converted with ``_id_to_grid``
    over ``(H, W)`` (1 = vertex, 2 = edge, 3 = triangle). Cells with any
    other vertex count are counted and reported in a single warning.
    Squares and triangles share ``triangle_style``.

    Returns the Axes drawn on (created here when ``ax`` is ``None``).
    """
    if ax is None:
        _, ax = plt.subplots()
    H, W = field.shape
    if plot_source:
        # origin='lower' so that array index i maps to plot y = i, j to x = j.
        ax.imshow(field, origin="lower", cmap=field_cmap,
                  extent=(-0.5, W - 0.5, -0.5, H - 0.5), zorder=0)

    cubical = _is_cubical_filtration(filtration)
    vertex_xy = []
    edge_segments = []
    polys = []
    skipped_high_dim = 0

    if cubical:
        for cell_id in chain_ids:
            cell = filtration[int(cell_id)]
            corners = [tuple(c) for c in cell.vertices]  # list of (i, j)
            if len(corners) == 1:
                i, j = corners[0]
                vertex_xy.append((j, i))
            elif len(corners) == 2:
                (i0, j0), (i1, j1) = corners
                edge_segments.append(((j0, i0), (j1, i1)))
            elif len(corners) == 4:
                rect = _square_corners_cyclic(corners)  # [(i, j), ...]
                polys.append([(j, i) for (i, j) in rect])
            else:
                skipped_high_dim += 1
    else:
        # Simplicial Freudenthal: vertex IDs ravel C-order over (H, W).
        for cell_id in chain_ids:
            cell = filtration[int(cell_id)]
            verts = list(cell.vertices)
            if len(verts) == 1:
                i, j = _id_to_grid(verts[0], (H, W))
                vertex_xy.append((j, i))
            elif len(verts) == 2:
                (i0, j0) = _id_to_grid(verts[0], (H, W))
                (i1, j1) = _id_to_grid(verts[1], (H, W))
                edge_segments.append(((j0, i0), (j1, i1)))
            elif len(verts) == 3:
                tri = [_id_to_grid(v, (H, W)) for v in verts]
                polys.append([(j, i) for (i, j) in tri])
            else:
                skipped_high_dim += 1

    if skipped_high_dim:
        # stacklevel=3 skips this renderer and its caller.
        warnings.warn(
            f"plot_chain: skipped {skipped_high_dim} cells of dim >= 3 "
            f"(2D field rendering only supports vertices, edges, faces).",
            stacklevel=3,
        )

    if polys:
        ax.add_collection(PolyCollection(polys, **triangle_style))
    if edge_segments:
        ax.add_collection(LineCollection(edge_segments, **edge_style))
    if polys or edge_segments:
        # add_collection updates the data limits but, on older Matplotlib,
        # not the view limits. Without this the chain can sit outside the
        # default view when neither imshow nor scatter runs
        # (plot_source=False and no dim-0 cells). No-op when the caller has
        # fixed the limits (autoscaling turned off).
        ax.autoscale_view()
    if vertex_xy:
        arr = np.asarray(vertex_xy, dtype=float)
        ax.scatter(arr[:, 0], arr[:, 1], **vertex_style)

    if title is not None:
        ax.set_title(title)
    ax.set_aspect("equal", adjustable="box")
    return ax
def _render_chain_field_3d(
    ax, field, filtration, chain_ids,
    *, vertex_style, edge_style, triangle_style, tetrahedron_style,
    point_style, plot_source, title,
):
    """Draw a chain of cells over a 3D scalar field.

    Grid index ``(i, j, k)`` is plotted at ``x = k, y = j, z = i``. For
    cubical filtrations each cell's ``vertices`` are ``(i, j, k)`` corners
    (1 = vertex, 2 = edge, 4 = square face, 8 = voxel drawn as six faces);
    otherwise vertex ids are converted with ``_id_to_grid`` over
    ``(D, H, W)`` (1 = vertex, 2 = edge, 3 = triangle, 4 = tetrahedron drawn
    as four faces). Cells with any other vertex count are counted and
    reported in a single warning. Voxels and tetrahedra use
    ``tetrahedron_style``; squares and triangles use ``triangle_style``.

    When ``plot_source`` is true every grid point is scattered, coloured by
    field value with a fixed ``"viridis"`` colormap; only the ``"s"`` and
    ``"alpha"`` entries of ``point_style`` are consulted for that scatter.

    Returns the 3D Axes drawn on; a non-3D ``ax`` is rejected by
    ``_ensure_3d_axes``.
    """
    from mpl_toolkits.mplot3d.art3d import Line3DCollection, Poly3DCollection

    ax = _ensure_3d_axes(ax)
    D, H, W = field.shape
    if plot_source:
        # Render every grid point as a scatter with color = field value.
        # This is overkill at large sizes, but keeps the demo
        # straightforward; users can disable with plot_source=False.
        ii, jj, kk = np.mgrid[0:D, 0:H, 0:W]
        ax.scatter(
            kk.ravel(), jj.ravel(), ii.ravel(),
            c=field.ravel(),
            cmap="viridis",
            s=point_style.get("s", 12.0),
            alpha=point_style.get("alpha", 0.4),
        )

    cubical = _is_cubical_filtration(filtration)
    vertex_xyz = []
    edge_segments = []
    triangle_polys = []
    cube_face_polys = []
    skipped_high_dim = 0

    def _ijk_to_xyz(ijk):
        # Plot convention: x=k, y=j, z=i so the third array dim runs along x.
        i, j, k = ijk
        return (k, j, i)

    if cubical:
        for cell_id in chain_ids:
            cell = filtration[int(cell_id)]
            corners = [tuple(c) for c in cell.vertices]
            if len(corners) == 1:
                vertex_xyz.append(_ijk_to_xyz(corners[0]))
            elif len(corners) == 2:
                edge_segments.append((_ijk_to_xyz(corners[0]), _ijk_to_xyz(corners[1])))
            elif len(corners) == 4:
                # 2-cube (square face). Reorder to cyclic by bounding-box.
                arr = np.asarray(corners, dtype=float)
                # The axis with the smallest extent is the constant one.
                # np.ptp is used because the ndarray.ptp method was removed
                # in NumPy 2.0.
                axis = int(np.argmin(np.ptp(arr, axis=0)))
                others = [d for d in range(3) if d != axis]
                a0, a1 = arr[:, others[0]].min(), arr[:, others[0]].max()
                b0, b1 = arr[:, others[1]].min(), arr[:, others[1]].max()
                fixed_value = arr[0, axis]
                ordered = []
                for (a, b) in [(a0, b0), (a1, b0), (a1, b1), (a0, b1)]:
                    ijk = [None] * 3
                    ijk[axis] = fixed_value
                    ijk[others[0]] = a
                    ijk[others[1]] = b
                    ordered.append(_ijk_to_xyz(tuple(ijk)))
                triangle_polys.append(ordered)
            elif len(corners) == 8:
                for face in _cube_3d_face_polys(corners):
                    cube_face_polys.append([_ijk_to_xyz(corner) for corner in face])
            else:
                skipped_high_dim += 1
    else:
        # Simplicial Freudenthal in 3D.
        for cell_id in chain_ids:
            cell = filtration[int(cell_id)]
            verts = list(cell.vertices)
            coords = [_ijk_to_xyz(_id_to_grid(v, (D, H, W))) for v in verts]
            if len(coords) == 1:
                vertex_xyz.append(coords[0])
            elif len(coords) == 2:
                edge_segments.append((coords[0], coords[1]))
            elif len(coords) == 3:
                triangle_polys.append(coords)
            elif len(coords) == 4:
                cube_face_polys.extend(_tet_face_polys(coords))
            else:
                skipped_high_dim += 1

    if skipped_high_dim:
        # stacklevel=3 skips this renderer and its caller.
        warnings.warn(
            f"plot_chain: skipped {skipped_high_dim} cells of dim >= 4 "
            f"(3D field rendering only supports cells up to 3-cubes / tets).",
            stacklevel=3,
        )

    if cube_face_polys:
        ax.add_collection3d(Poly3DCollection(cube_face_polys, **tetrahedron_style))
    if triangle_polys:
        ax.add_collection3d(Poly3DCollection(triangle_polys, **triangle_style))
    if edge_segments:
        ax.add_collection3d(Line3DCollection(edge_segments, **edge_style))
    if vertex_xyz:
        arr = np.asarray(vertex_xyz, dtype=float)
        ax.scatter(arr[:, 0], arr[:, 1], arr[:, 2], **vertex_style)

    if title is not None:
        ax.set_title(title)
    return ax