Source code for oineus.diff.cech_delaunay

import time
import numpy as np
import torch

from .. import _delaunay_combinatorics, _oineus
from .diff_filtration import DiffFiltration

def triangle_meb(p0, p1, p2, eps=1e-12):
    """
    Compute minimum enclosing ball center and radius squared for triangles.

    Args:
        p0, p1, p2: Tensor of shape (n, d) for n triangles in d dimensions
        eps: Small value for numerical stability

    Returns:
        centers: Tensor of shape (n, d) - MEB centers
        radii_sq: Tensor of shape (n,) - MEB radii squared
    """
    a = p1 - p0
    b = p2 - p0
    c = p2 - p1

    a_sq = torch.sum(a ** 2, dim=1)
    b_sq = torch.sum(b ** 2, dim=1)
    c_sq = torch.sum(c ** 2, dim=1)

    d = p0.shape[1]
    if d == 2:
        cross = a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0]
        area_2_sq = cross ** 2
    else:
        cross = torch.cross(a, b, dim=1)
        area_2_sq = torch.sum(cross ** 2, dim=1)

    circum_radii_sq = (a_sq * b_sq * c_sq + eps) / (4 * area_2_sq + eps)

    if d == 3:
        cross_ab = torch.cross(a, b, dim=1)
        cross_ab_sq = torch.sum(cross_ab ** 2, dim=1, keepdim=True)
        a_dot_a = a_sq.unsqueeze(1)
        b_dot_b = b_sq.unsqueeze(1)
        b_cross_axb = torch.cross(b, cross_ab, dim=1)
        axb_cross_a = torch.cross(cross_ab, a, dim=1)
        circum_centers = p0 + (a_dot_a * b_cross_axb + b_dot_b * axb_cross_a) / (2 * cross_ab_sq + eps)
    else:
        a_dot_a = a_sq.unsqueeze(1)
        b_dot_b = b_sq.unsqueeze(1)
        D = 2 * (a[:, 0:1] * b[:, 1:2] - a[:, 1:2] * b[:, 0:1])
        ux = (b[:, 1:2] * a_dot_a - a[:, 1:2] * b_dot_b) / (D + eps)
        uy = (a[:, 0:1] * b_dot_b - b[:, 0:1] * a_dot_a) / (D + eps)
        circum_centers = p0 + torch.cat([ux, uy], dim=1)

    abc_sq = torch.stack((a_sq, b_sq, c_sq), dim=0)
    s_abc_sq, sort_idx = torch.sort(abc_sq, dim=0)
    obtuse_mask = s_abc_sq[2, :] > s_abc_sq[0, :] + s_abc_sq[1, :]

    longest_edge_idx = sort_idx[2, :]

    midpoint_a = (p0 + p1) / 2
    midpoint_b = (p0 + p2) / 2
    midpoint_c = (p1 + p2) / 2

    centers = circum_centers.clone()
    radii_sq = circum_radii_sq.clone()

    if obtuse_mask.any():
        obtuse_longest = longest_edge_idx[obtuse_mask]
        mask_a = obtuse_longest == 0
        mask_b = obtuse_longest == 1
        mask_c = obtuse_longest == 2
        obtuse_indices = torch.where(obtuse_mask)[0]

        if mask_a.any():
            centers[obtuse_indices[mask_a]] = midpoint_a[obtuse_mask][mask_a]
        if mask_b.any():
            centers[obtuse_indices[mask_b]] = midpoint_b[obtuse_mask][mask_b]
        if mask_c.any():
            centers[obtuse_indices[mask_c]] = midpoint_c[obtuse_mask][mask_c]

        radii_sq[obtuse_mask] = s_abc_sq[2, obtuse_mask] / 4

    return centers, radii_sq


def tetrahedron_meb(p0, p1, p2, p3, eps=1e-12, return_centers=False):
    """
    Compute minimum enclosing ball center and radius squared for tetrahedra.

    The MEB of a tetrahedron is one of:
    1. The circumsphere (all 4 vertices on boundary)
    2. A face's MEB (if opposite vertex is inside that MEB)
    3. An edge's MEB (if other two vertices are inside that MEB)

    Args:
        p0, p1, p2, p3: Tensor of shape (n, 3) for n tetrahedra
        eps: Small value for numerical stability

    Returns:
        centers: Tensor of shape (n, 3) - MEB centers
        radii_sq: Tensor of shape (n,) - MEB radii squared
    """
    n = p0.shape[0]
    device = p0.device
    dtype = p0.dtype

    # Compute circumsphere of tetrahedron
    a = p1 - p0
    b = p2 - p0
    c = p3 - p0

    a_sq = torch.sum(a ** 2, dim=1, keepdim=True)
    b_sq = torch.sum(b ** 2, dim=1, keepdim=True)
    c_sq = torch.sum(c ** 2, dim=1, keepdim=True)

    cross_bc = torch.cross(b, c, dim=1)
    cross_ca = torch.cross(c, a, dim=1)
    cross_ab = torch.cross(a, b, dim=1)

    volume_6 = torch.sum(a * cross_bc, dim=1)

    numerator_vec = a_sq * cross_bc + b_sq * cross_ca + c_sq * cross_ab
    circum_disp = numerator_vec / (2 * volume_6.unsqueeze(1) + eps)
    circum_center = p0 + circum_disp
    circum_radii_sq = torch.sum(circum_disp ** 2, dim=1)

    # Compute MEB for each of the 4 faces
    face_centers_0, face_radii_sq_0 = triangle_meb(p1, p2, p3, eps)  # opposite to p0
    face_centers_1, face_radii_sq_1 = triangle_meb(p0, p2, p3, eps)  # opposite to p1
    face_centers_2, face_radii_sq_2 = triangle_meb(p0, p1, p3, eps)  # opposite to p2
    face_centers_3, face_radii_sq_3 = triangle_meb(p0, p1, p2, eps)  # opposite to p3

    # Check if opposite vertex is contained in face's MEB
    dist_sq_0 = torch.sum((p0 - face_centers_0) ** 2, dim=1)
    dist_sq_1 = torch.sum((p1 - face_centers_1) ** 2, dim=1)
    dist_sq_2 = torch.sum((p2 - face_centers_2) ** 2, dim=1)
    dist_sq_3 = torch.sum((p3 - face_centers_3) ** 2, dim=1)

    contains_0 = dist_sq_0 <= face_radii_sq_0 + eps
    contains_1 = dist_sq_1 <= face_radii_sq_1 + eps
    contains_2 = dist_sq_2 <= face_radii_sq_2 + eps
    contains_3 = dist_sq_3 <= face_radii_sq_3 + eps

    # Build candidate radii and centers
    inf_val = torch.tensor(float('inf'), dtype=dtype, device=device)

    # Stack all candidates: circumsphere, 4 faces, 6 edges = 11 candidates
    all_radii_sq = torch.stack([
        circum_radii_sq,
        torch.where(contains_0, face_radii_sq_0, inf_val),
        torch.where(contains_1, face_radii_sq_1, inf_val),
        torch.where(contains_2, face_radii_sq_2, inf_val),
        torch.where(contains_3, face_radii_sq_3, inf_val),
    ], dim=0)

    # Find minimum radius for each tetrahedron
    min_radii_sq, min_idx = torch.min(all_radii_sq, dim=0)

    centers = None

    # Gather corresponding centers
    if return_centers:
        all_centers = torch.stack([
            circum_center,
            face_centers_0,
            face_centers_1,
            face_centers_2,
            face_centers_3,
        ], dim=0)
        centers = all_centers[min_idx, torch.arange(n)]

    return centers, min_radii_sq


[docs] def cech_delaunay_filtration(points, eps: float = 0.0, packed: bool = False, print_time: bool = False): """Build a differentiable Cech-Delaunay filtration from a point cloud. The combinatorics of the alpha complex are computed via diode (CGAL); the filtration values are recomputed differentiably as squared minimum enclosing ball radii of each simplex, so gradients flow back to ``points``. Args: points: ``(n, d)`` torch.Tensor with ``d in {2, 3}``. Differentiable. eps: Small value for numerical stability in the MEB computation. packed: Use the compact bit-packed cell encoding for the Delaunay combinatorics when the vertex ids fit a 64/128-bit word. The values (and gradients) are recomputed here regardless of encoding. print_time: If True, print per-stage timings. Returns: DiffFiltration whose values are squared MEB radii. """ if print_time: start = time.time() points_np = points.detach().cpu().numpy() alpha_fil = _delaunay_combinatorics(points_np, packed=packed) if print_time: elapsed = time.time() - start print(f"alpha_fil construction elapsed: {elapsed:.3f}") if print_time: start = time.time() values_in_dim = [torch.zeros(alpha_fil.size_in_dimension(0), requires_grad=True, device=points.device)] if print_time: elapsed = time.time() - start print(f"initialize dim-0 values elapsed: {elapsed:.3f}") for dim in range(1, alpha_fil.max_dim + 1): if print_time: start_dim = time.time() if dim == 1: if print_time: start = time.time() edges = torch.LongTensor(alpha_fil.get_edges().astype(np.uint64)) if print_time: elapsed = time.time() - start print(f"dim 1 get edges elapsed: {elapsed:.3f}") if print_time: start = time.time() sqdists = torch.sum((points[edges[:, 0]] - points[edges[:, 1]]) ** 2, axis=1) radii_sq = 0.25 * sqdists if print_time: elapsed = time.time() - start print(f"dim 1 compute edge radii elapsed: {elapsed:.3f}") assert edges.shape[0] == radii_sq.shape[0] elif dim == 2: if print_time: start = time.time() triangles = torch.LongTensor(alpha_fil.get_triangles().astype(np.uint64)) if print_time: elapsed = time.time() - start print(f"dim 2 get triangles elapsed: {elapsed:.3f}") if print_time: start = time.time() p0 = points[triangles[:, 0]] p1 = points[triangles[:, 1]] p2 = points[triangles[:, 2]] if print_time: elapsed = time.time() - start print(f"dim 2 gather triangle points elapsed: {elapsed:.3f}") # ignore centers if print_time: start = time.time() _, radii_sq = triangle_meb(p0, p1, p2, eps) if print_time: elapsed = time.time() - start print(f"dim 2 triangle_meb elapsed: {elapsed:.3f}") assert triangles.shape[0] == radii_sq.shape[0] elif dim == 3: if print_time: start = time.time() tetra = torch.LongTensor(alpha_fil.get_tetrahedra().astype(np.uint64)) if print_time: elapsed = time.time() - start print(f"dim 3 get tetrahedra elapsed: {elapsed:.3f}") if print_time: start = time.time() p0 = points[tetra[:, 0]] p1 = points[tetra[:, 1]] p2 = points[tetra[:, 2]] p3 = points[tetra[:, 3]] if print_time: elapsed = time.time() - start print(f"dim 3 gather tetra points elapsed: {elapsed:.3f}") if print_time: start = time.time() _, radii_sq = tetrahedron_meb(p0, p1, p2, p3, eps) if print_time: elapsed = time.time() - start print(f"dim 3 tetrahedron_meb elapsed: {elapsed:.3f}") assert tetra.shape[0] == radii_sq.shape[0] else: raise RuntimeError("Dimension not supported") if print_time: elapsed = time.time() - start_dim print(f"dim {dim} total elapsed: {elapsed:.3f}") if print_time: start = time.time() values_in_dim.append(radii_sq) if print_time: elapsed = time.time() - start print(f"append dim {dim} values elapsed: {elapsed:.3f}") # this will sort the simplices in the filtration correctly, cd_vals_np is not monotonic if print_time: start = time.time() cd_vals = torch.cat(values_in_dim) if print_time: elapsed = time.time() - start print(f"concatenate values elapsed: {elapsed:.3f}") if print_time: start = time.time() cd_vals_list = [ float(x) for x in cd_vals.clone().detach().cpu() ] if print_time: elapsed = time.time() - start print(f"convert values to list elapsed: {elapsed:.3f}") # set non-differentiable internal Oineus values: if print_time: start = time.time() alpha_fil.set_values(cd_vals_list) if print_time: elapsed = time.time() - start print(f"set alpha_fil values elapsed: {elapsed:.3f}") # sort values in each dimension independently, in a differentiable way: if print_time: start = time.time() sorted_vals = torch.cat([ torch.sort(vals)[0] for vals in values_in_dim]) if print_time: elapsed = time.time() - start print(f"sort values by dimension elapsed: {elapsed:.3f}") if print_time: start = time.time() alpha_fil.kind = _oineus.FiltrationKind.CechDelaunay result = DiffFiltration(alpha_fil, sorted_vals) if print_time: elapsed = time.time() - start print(f"build DiffFiltration elapsed: {elapsed:.3f}") return result