Source code for oineus.diff.weak_alpha

"""Differentiable weak-alpha filtration.

Same combinatorics as the alpha complex (built via diode/CGAL), but each
simplex is assigned the squared length of its longest edge (vertices get 0).
The squared-distance convention matches cech_delaunay_filtration so the two
are directly comparable, and the longest-edge rule mirrors Vietoris-Rips
restricted to the alpha-complex simplices.
"""

import time
import numpy as np
import torch

from .. import _delaunay_combinatorics, _oineus
from .diff_filtration import DiffFiltration


[docs] def weak_alpha_filtration(points, packed: bool = False, print_time: bool = False): """Build a differentiable weak-alpha filtration from a point cloud. Args: points: ``(n, d)`` torch.Tensor with ``d in {2, 3}``. Differentiable. packed: Use the compact bit-packed cell encoding for the Delaunay combinatorics when the vertex ids fit a 64/128-bit word. The values (and gradients) are recomputed here regardless of encoding. print_time: If True, print per-stage timings. Returns: DiffFiltration whose values are squared longest-edge lengths. """ if print_time: start = time.time() points_np = points.detach().cpu().numpy() alpha_fil = _delaunay_combinatorics(points_np, packed=packed) if print_time: elapsed = time.time() - start print(f"alpha_fil construction elapsed: {elapsed:.3f}") n0 = alpha_fil.size_in_dimension(0) values_in_dim = [ torch.zeros(n0, requires_grad=True, device=points.device, dtype=points.dtype) ] for dim in range(1, alpha_fil.max_dim + 1): if print_time: start_dim = time.time() if dim == 1: edges = torch.LongTensor(alpha_fil.get_edges().astype(np.uint64)) values = torch.sum((points[edges[:, 0]] - points[edges[:, 1]]) ** 2, dim=1) elif dim == 2: tri = torch.LongTensor(alpha_fil.get_triangles().astype(np.uint64)) p0 = points[tri[:, 0]] p1 = points[tri[:, 1]] p2 = points[tri[:, 2]] d01 = torch.sum((p0 - p1) ** 2, dim=1) d02 = torch.sum((p0 - p2) ** 2, dim=1) d12 = torch.sum((p1 - p2) ** 2, dim=1) values = torch.amax(torch.stack([d01, d02, d12], dim=0), dim=0) elif dim == 3: tet = torch.LongTensor(alpha_fil.get_tetrahedra().astype(np.uint64)) p0 = points[tet[:, 0]] p1 = points[tet[:, 1]] p2 = points[tet[:, 2]] p3 = points[tet[:, 3]] d01 = torch.sum((p0 - p1) ** 2, dim=1) d02 = torch.sum((p0 - p2) ** 2, dim=1) d03 = torch.sum((p0 - p3) ** 2, dim=1) d12 = torch.sum((p1 - p2) ** 2, dim=1) d13 = torch.sum((p1 - p3) ** 2, dim=1) d23 = torch.sum((p2 - p3) ** 2, dim=1) values = torch.amax(torch.stack([d01, d02, d03, d12, d13, d23], dim=0), dim=0) else: raise RuntimeError(f"weak_alpha_filtration: dim={dim} not supported") if print_time: elapsed = time.time() - start_dim print(f"dim {dim} weak-alpha values elapsed: {elapsed:.3f}") values_in_dim.append(values) if print_time: start = time.time() cd_vals = torch.cat(values_in_dim) cd_vals_list = [float(x) for x in cd_vals.clone().detach().cpu()] alpha_fil.set_values(cd_vals_list) if print_time: elapsed = time.time() - start print(f"set values elapsed: {elapsed:.3f}") if print_time: start = time.time() sorted_vals = torch.cat([torch.sort(vals)[0] for vals in values_in_dim]) if print_time: elapsed = time.time() - start print(f"sort values elapsed: {elapsed:.3f}") alpha_fil.kind = _oineus.FiltrationKind.WeakAlpha return DiffFiltration(alpha_fil, sorted_vals)