Source code for diffsim.collision

"""
Collision detection and response

This module implements two collision handling approaches:

1. **IPC (Incremental Potential Contact)**: Rigorous collision handling using
   barrier potentials with provable non-penetration guarantees. Based on Li et al.
   "Incremental Potential Contact" (2020).

2. **Simplified**: Fast approximate collision detection using distance-based
   repulsion forces for real-time simulation.

The IPC barrier function is:

.. math::

    b(d) = -(d - \\hat{d})^2 \\log(d/\\hat{d}) \\text{ for } 0 < d < \\hat{d}, \\qquad b(d) = 0 \\text{ for } d \\ge \\hat{d}

where :math:`d` is the distance between primitives and :math:`\\hat{d}` is the
activation distance. This creates smooth repulsive forces that prevent penetration.
"""

import torch
import numpy as np


class IPCCollisionHandler:
    """
    IPC (Incremental Potential Contact) collision handler

    Implements collision handling using smooth barrier potentials as described in:

    Li, M., Ferguson, Z., Schneider, T., Langlois, T. R., Zorin, D., Panozzo, D.,
    ... & Jiang, C. (2020). Incremental potential contact: intersection-and
    inversion-free, large-deformation dynamics. ACM Trans. Graph., 39(4), 49.

    The barrier function creates smooth repulsive forces that activate when
    primitives approach within a threshold distance :math:`\\hat{d}`, preventing
    penetration while maintaining continuity for gradient-based optimization.

    Parameters:
        barrier_stiffness (float): Stiffness coefficient :math:`\\kappa` (default: 1e3)
        dhat (float): Barrier activation distance :math:`\\hat{d}` (default: 1e-3)
        friction_mu (float): Friction coefficient (default: 0.3)

    Attributes:
        kappa (float): Barrier stiffness
        dhat (float): Activation distance
        dhat_squared (float): Squared activation distance for efficiency
        friction_mu (float): Friction coefficient (stored only; no method of
            this class reads it)
    """

    # Upper bound on the number of vertices and of faces sampled per call of
    # ``compute_self_collision_forces`` (at most 100 x 100 pairs are tested).
    _MAX_SAMPLES = 100

    def __init__(self, barrier_stiffness=1e3, dhat=1e-3, friction_mu=0.3):
        """
        Initialize IPC collision handler

        Args:
            barrier_stiffness: stiffness of barrier potential
            dhat: activation distance for barrier (collision threshold)
            friction_mu: friction coefficient
        """
        self.kappa = barrier_stiffness
        self.dhat = dhat
        self.dhat_squared = dhat * dhat
        self.friction_mu = friction_mu

    def barrier_function(self, d_squared):
        """
        IPC barrier function :math:`b(d) = -(d - \\hat{d})^2 \\log(d/\\hat{d})`

        Args:
            d_squared: tensor of squared distances (any shape)

        Returns:
            Tensor of the same shape holding the barrier value; exactly zero
            wherever ``d_squared >= dhat_squared``.
        """
        # Clamp so that sqrt/log stay finite for coincident primitives.
        d_squared = torch.clamp(d_squared, min=1e-12)
        d = torch.sqrt(d_squared)

        barrier = -((d - self.dhat) ** 2) * torch.log(d / self.dhat + 1e-12)
        # The barrier has compact support: it is only active for d < dhat.
        return torch.where(
            d_squared < self.dhat_squared, barrier, torch.zeros_like(barrier)
        )

    def barrier_gradient(self, d_squared):
        """
        Derivative of the barrier with respect to the distance ``d``

        Note that this is the *signed* derivative :math:`db/dd`, which is
        non-positive inside the active region (the barrier grows as the
        distance shrinks). A repulsive force is therefore ``-kappa * db/dd``
        along the direction of increasing distance.

        Args:
            d_squared: tensor of squared distances (any shape)

        Returns:
            Tensor of the same shape holding :math:`db/dd`; exactly zero
            wherever ``d_squared >= dhat_squared``.
        """
        d_squared = torch.clamp(d_squared, min=1e-12)
        d = torch.sqrt(d_squared)
        gap = d - self.dhat

        # d/dd[ -(d - dhat)^2 * log(d/dhat) ]
        #   = -2 (d - dhat) log(d/dhat) - (d - dhat)^2 / d
        grad = -2 * gap * torch.log(d / self.dhat + 1e-12) - gap**2 / (d + 1e-12)
        return torch.where(
            d_squared < self.dhat_squared, grad, torch.zeros_like(grad)
        )

    @staticmethod
    def _closest_point_on_segment(p, start, end):
        """Closest point to ``p`` on the segment ``start``-``end`` (broadcasting)."""
        seg = end - start
        length_sq = (seg * seg).sum(dim=-1, keepdim=True)
        # clamp_min keeps zero-length segments finite: 0 / tiny == 0 -> `start`.
        tiny = torch.finfo(length_sq.dtype).tiny
        t = ((p - start) * seg).sum(dim=-1, keepdim=True) / length_sq.clamp_min(tiny)
        return start + t.clamp(0.0, 1.0) * seg

    def point_triangle_distance(self, p, v0, v1, v2):
        """
        Compute squared distance from point p to triangle (v0, v1, v2)

        The closest point is the orthogonal projection onto the triangle's
        plane when that projection lies inside the triangle, and otherwise the
        closest point on the triangle's boundary (its three edges). Degenerate
        (zero-area) triangles are handled through the edge fallback.

        Args:
            p: :math:`(N, 3)` points
            v0, v1, v2: :math:`(M, 3)` triangle vertices

        Returns:
            distances: :math:`(N, M)` squared distances
            closest_points: :math:`(N, M, 3)` closest points on triangles
        """
        # Expand dimensions for broadcasting
        p = p.unsqueeze(1)  # (N, 1, 3)
        a = v0.unsqueeze(0)  # (1, M, 3)
        b = v1.unsqueeze(0)
        c = v2.unsqueeze(0)

        ab = b - a
        ac = c - a
        ap = p - a  # (N, M, 3)

        # Normal (not normalized); its squared norm is (2 * area)^2.
        normal = torch.cross(ab, ac, dim=-1)  # (1, M, 3)
        normal_norm_sq = (normal * normal).sum(dim=-1)  # (1, M)

        d00 = (ab * ab).sum(dim=-1)
        d01 = (ab * ac).sum(dim=-1)
        d11 = (ac * ac).sum(dim=-1)

        # A triangle is usable for plane projection only if it is not
        # (numerically) collinear; the tolerance is relative to edge lengths.
        rel_tol = 64.0 * torch.finfo(normal_norm_sq.dtype).eps
        proper = normal_norm_sq > rel_tol * d00 * d11  # (1, M)
        safe_nn = torch.where(proper, normal_norm_sq, torch.ones_like(normal_norm_sq))

        # Barycentric coordinates (s, t) of the projection: a + s*ab + t*ac.
        d20 = (ap * ab).sum(dim=-1)  # (N, M)
        d21 = (ap * ac).sum(dim=-1)
        s = (d11 * d20 - d01 * d21) / safe_nn
        t = (d00 * d21 - d01 * d20) / safe_nn
        inside = proper & (s >= 0) & (t >= 0) & (s + t <= 1)  # (N, M)

        # Candidate 1: projection onto the supporting plane.
        height = (ap * normal).sum(dim=-1)  # signed distance scaled by |normal|
        plane_d2 = height**2 / safe_nn
        plane_pt = p - normal * (height / safe_nn).unsqueeze(-1)

        # Candidate 2: nearest point on the triangle boundary.
        edge_pt = None
        edge_d2 = None
        for start, end in ((a, b), (b, c), (c, a)):
            pt = self._closest_point_on_segment(p, start, end)  # (N, M, 3)
            d2 = ((p - pt) ** 2).sum(dim=-1)
            if edge_pt is None:
                edge_pt, edge_d2 = pt, d2
            else:
                closer = d2 < edge_d2
                edge_pt = torch.where(closer.unsqueeze(-1), pt, edge_pt)
                edge_d2 = torch.where(closer, d2, edge_d2)

        d_squared = torch.where(inside, plane_d2, edge_d2)
        closest = torch.where(inside.unsqueeze(-1), plane_pt, edge_pt)
        return d_squared, closest

    def compute_self_collision_forces(self, mesh, positions):
        """
        Compute self-collision forces using IPC barrier potentials

        This is a simplified version that checks vertex-face distances for a
        random sample of at most 100 vertices and 100 surface faces per call
        (all of them when the mesh is smaller). Only the vertex of each
        vertex-face pair receives a force; no reaction force is applied to the
        face's vertices.

        Args:
            mesh: TetrahedralMesh (only its ``tetrahedra`` index tensor is read)
            positions: :math:`(N, 3)` current positions

        Returns:
            forces: :math:`(N, 3)` collision forces
        """
        device = positions.device
        forces = torch.zeros_like(positions)

        # Extract surface triangles (boundary faces)
        surface_faces = self._extract_surface_faces(mesh)
        if surface_faces is None or len(surface_faces) == 0:
            return forces
        surface_faces = surface_faces.to(device)

        num_vertices = positions.shape[0]
        num_faces = surface_faces.shape[0]

        # Sample subset for efficiency (brute force over the sampled pairs;
        # a production implementation would use a BVH or spatial hash).
        vertex_sample = torch.randperm(num_vertices, device=device)[
            : min(self._MAX_SAMPLES, num_vertices)
        ]
        face_sample = torch.randperm(num_faces, device=device)[
            : min(self._MAX_SAMPLES, num_faces)
        ]
        if vertex_sample.numel() == 0:
            return forces

        faces = surface_faces[face_sample]  # (F, 3)
        points = positions[vertex_sample]  # (V, 3)
        corners = positions[faces]  # (F, 3, 3)

        d_squared, closest = self.point_triangle_distance(
            points, corners[:, 0], corners[:, 1], corners[:, 2]
        )  # (V, F), (V, F, 3)

        # Skip pairs where the vertex is part of the face.
        incident = (faces.unsqueeze(0) == vertex_sample.view(-1, 1, 1)).any(dim=-1)
        active = (d_squared < self.dhat_squared) & ~incident
        if not active.any():
            return forces

        # Unit direction of increasing distance (from the face to the vertex).
        dist = torch.sqrt(torch.clamp(d_squared, min=1e-12)).unsqueeze(-1)
        direction = (points.unsqueeze(1) - closest) / dist

        # Force = -dE/dx = -kappa * b'(d) * dd/dx. Since b'(d) <= 0 inside the
        # barrier, the minus sign is what makes the force repulsive.
        grad = self.barrier_gradient(d_squared).unsqueeze(-1)
        pair_forces = -self.kappa * grad * direction
        pair_forces = torch.where(
            active.unsqueeze(-1), pair_forces, torch.zeros_like(pair_forces)
        )

        forces.index_add_(0, vertex_sample, pair_forces.sum(dim=1))
        return forces

    def _extract_surface_faces(self, mesh):
        """
        Extract surface triangles from tetrahedral mesh

        A face is on the surface when it belongs to exactly one tetrahedron.
        Faces keep the vertex order in which they were generated from their
        tetrahedron and are returned in lexicographic order of their sorted
        vertex indices.

        Returns:
            surface_faces: :math:`(F, 3)` tensor of surface triangle indices,
            or ``None`` when there are no surface faces
        """
        tets = mesh.tetrahedra
        if tets.shape[0] == 0:
            return None

        # Each tet has 4 faces
        all_faces = torch.cat(
            [
                tets[:, [0, 2, 1]],
                tets[:, [0, 1, 3]],
                tets[:, [0, 3, 2]],
                tets[:, [1, 2, 3]],
            ],
            dim=0,
        )

        # Sort each face so that shared faces become identical rows.
        sorted_faces, _ = torch.sort(all_faces, dim=1)
        _, inverse_indices, counts = torch.unique(
            sorted_faces, dim=0, return_inverse=True, return_counts=True
        )

        # Boundary faces appear exactly once; select their rows directly
        # through the inverse mapping instead of searching for each one.
        boundary_rows = torch.nonzero(counts[inverse_indices] == 1).squeeze(1)
        if boundary_rows.numel() == 0:
            return None

        # Order rows by their unique-face id (ids are distinct for boundary
        # faces) to return faces in the sorted order of ``torch.unique``.
        order = torch.argsort(inverse_indices[boundary_rows])
        return all_faces[boundary_rows[order]]
class SimplifiedCollisionHandler:
    """
    Fast simplified collision detection for real-time simulation

    Uses distance-based repulsion between randomly sampled vertex pairs,
    without full CCD.
    """

    def __init__(
        self,
        collision_distance=0.02,
        repulsion_stiffness=1e4,
        max_checks_per_frame=1000,
    ):
        """
        Args:
            collision_distance: minimum distance threshold
            repulsion_stiffness: strength of repulsion forces
            max_checks_per_frame: maximum collision checks per frame (for performance)
        """
        self.d_min = collision_distance
        self.k_repulsion = repulsion_stiffness
        self.max_checks = max_checks_per_frame

    def compute_self_collision_forces(self, mesh, positions):
        """
        Fast self-collision detection using simple distance checks

        Draws ``min(max_checks_per_frame, 5 * N)`` random vertex pairs (with
        replacement) and applies equal and opposite spring-like repulsion to
        every pair closer than ``collision_distance``.

        Args:
            mesh: TetrahedralMesh (not read by this method)
            positions: :math:`(N, 3)` current positions

        Returns:
            forces: :math:`(N, 3)` repulsion forces
        """
        forces = torch.zeros_like(positions)
        num_vertices = positions.shape[0]

        # Limit checks for performance - sample vertex pairs
        num_samples = min(self.max_checks, num_vertices * 5)
        if num_samples <= 0:
            return forces

        # Generate random pairs in batch
        device = positions.device
        indices_i = torch.randint(0, num_vertices, (num_samples,), device=device)
        indices_j = torch.randint(0, num_vertices, (num_samples,), device=device)

        # Filter out self-pairs
        valid = indices_i != indices_j
        indices_i = indices_i[valid]
        indices_j = indices_j[valid]
        if indices_i.numel() == 0:
            return forces

        # Compute distances in batch. Distances are kept 1-D (one entry per
        # pair) so that boolean masking stays well-defined even when a single
        # pair remains; squeezing a (1, 1) tensor to 0-dim would break it.
        diff = positions[indices_i] - positions[indices_j]
        dist = torch.norm(diff, dim=1) + 1e-12  # epsilon avoids division by zero

        # Find colliding pairs
        colliding = dist < self.d_min
        if not colliding.any():
            return forces

        # Repulsion proportional to the penetration depth, directed along the
        # unit vector from j to i (diff / dist).
        dist_hit = dist[colliding]
        penetration = self.d_min - dist_hit
        force_mag = self.k_repulsion * penetration / dist_hit
        force = force_mag.unsqueeze(1) * diff[colliding]

        # Accumulate equal and opposite forces; index_add_ sums contributions
        # for vertices that occur in several pairs.
        forces.index_add_(0, indices_i[colliding], force)
        forces.index_add_(0, indices_j[colliding], -force)
        return forces