Source code for oineus.diff.top_optimizer

from typing import List, Optional

from .. import _oineus
from .. import _OPT_CLASS_BY_FIL_TYPE
from .diff_filtration import DiffFiltration
from ._reduction_policy import default_dualize_for_filtration


def _opt_class_for_filtration(fil):
    """Pick the C++ TopologyOptimizer instantiation for this cell type."""
    if isinstance(fil, DiffFiltration):
        fil = fil.under_fil
    cls = _OPT_CLASS_BY_FIL_TYPE.get(type(fil))
    if cls is None:
        raise RuntimeError(
            f"unknown filtration type {type(fil).__name__}; expected one of "
            f"{[t.__name__ for t in _OPT_CLASS_BY_FIL_TYPE]}"
        )
    return cls, fil


def make_under_topology_optimizer(fil, *, with_crit_sets, dims_to_restore_elz=None,
                                  n_threads=1, u_strategy=None):
    """Construct the C++ TopologyOptimizer for any filtration cell type."""
    cls, under_fil = _opt_class_for_filtration(fil)
    dims = list(dims_to_restore_elz) if dims_to_restore_elz else []
    if u_strategy is None:
        u_strategy = _oineus.UStrategy.Auto
    return cls(under_fil, with_crit_sets=with_crit_sets,
               dims_to_restore_elz=dims, n_threads=n_threads,
               u_strategy=u_strategy)


[docs] class TopologyOptimizer: """Thin Python wrapper that picks the right C++ TopologyOptimizer instantiation for the filtration's cell type and forwards calls to it. The under_fil (the C++ filtration) is exposed via `self.under_fil` for callers that need it directly. The optimizer is built for one autograd backward; the reduction recipe is fixed at construction time:: with_crit_sets=False -> R only (parallel + clearing). with_crit_sets=True -> R + V + restore_ELZ in the given dims. U is recovered on demand via ensure_has_u_hom / ensure_has_u_coh, unless u_strategy=LegacyInBand, in which case U is built in-band (serial, clearing off). """ def __init__(self, fil, *, with_crit_sets: bool = True, dims_to_restore_elz: Optional[List[int]] = None, n_threads: int = 1, u_strategy=None): cls, under_fil = _opt_class_for_filtration(fil) dims = list(dims_to_restore_elz) if dims_to_restore_elz else [] if u_strategy is None: u_strategy = _oineus.UStrategy.Auto self.under_opt = cls(under_fil, with_crit_sets=with_crit_sets, dims_to_restore_elz=dims, n_threads=n_threads, u_strategy=u_strategy) self.under_fil = under_fil # diagram + simplification def compute_diagram(self, include_inf_points: bool): return self.under_opt.compute_diagram(include_inf_points) def simplify(self, epsilon: float, strategy, dim: int): return self.under_opt.simplify(epsilon, strategy, dim) def get_nth_persistence(self, dim: int, n: int): return self.under_opt.get_nth_persistence(dim=dim, n=n) def match(self, template_dgm, dim: int, wasserstein_q: float = 1.0, wasserstein_delta: float = 0.01, return_wasserstein_distance: bool = False, dualize=None): if dualize is None: dualize = default_dualize_for_filtration(self.under_fil) if dualize: self.ensure_coh_reduced() else: self.ensure_hom_reduced() return self.under_opt.match( template_dgm=template_dgm, dim=dim, q=wasserstein_q, wasserstein_delta=wasserstein_delta, return_wasserstein_distance=return_wasserstein_distance, dualize=dualize, ) # decomposition handles @property def homology_decomposition(self): return self.under_opt.homology_decomposition @property def cohomology_decomposition(self): return self.under_opt.cohomology_decomposition def homology_decomposition_ref(self): return self.under_opt.homology_decomposition_ref() def cohomology_decomposition_ref(self): return self.under_opt.cohomology_decomposition_ref() # reduction control def ensure_hom_built(self): return self.under_opt.ensure_hom_built() def ensure_coh_built(self): return self.under_opt.ensure_coh_built() @property def is_hom_built(self): return self.under_opt.is_hom_built @property def is_coh_built(self): return self.under_opt.is_coh_built def ensure_hom_reduced(self): return self.under_opt.ensure_hom_reduced() def ensure_coh_reduced(self): return self.under_opt.ensure_coh_reduced() def ensure_has_u_hom(self, dim: int, rows_fil, bounds): return self.under_opt.ensure_has_u_hom(dim, rows_fil, bounds) def ensure_has_u_coh(self, dim: int, rows_fil, bounds): return self.under_opt.ensure_has_u_coh(dim, rows_fil, bounds) def reduce_all(self): return self.under_opt.reduce_all() def update(self, new_values, n_threads: int = 1): return self.under_opt.update(new_values, n_threads) # crit-sets primitives def crit_sets_apply(self, indices, values, strategy): return self.under_opt.crit_sets_apply(indices, values, strategy) def singleton(self, index: int, value: float): return self.under_opt.singleton(index, value) def singletons(self, indices, values): return self.under_opt.singletons(indices, values) def combine_loss(self, critical_sets, strategy): return self.under_opt.combine_loss(critical_sets, strategy) def increase_death(self, negative_simplex_idx: int): return self.under_opt.increase_death(negative_simplex_idx) def decrease_death(self, negative_simplex_idx: int): return self.under_opt.decrease_death(negative_simplex_idx) def increase_birth(self, positive_simplex_idx: int): return self.under_opt.increase_birth(positive_simplex_idx) def decrease_birth(self, positive_simplex_idx: int): return self.under_opt.decrease_birth(positive_simplex_idx)