# Source: gunz_cm.reconstructions.objectives.procrustes

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Procrustes Analysis

This code provides functions to perform Procrustes analysis
on two sets of points. 

This version includes standard (L2), robust (L1), and IRLS methods for both
NumPy and PyTorch backends.

The tool uses NumPy for data handling, SciPy for the core Procrustes
algorithm.


Examples
--------
"""

# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "3.10.0"

# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import functools
import typing as t
from dataclasses import dataclass
from enum import Enum

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
from gunz_cm.exceptions import ReconstructionError
from gunz_cm.utils.logger import logger
import numpy as np
import torch
from pydantic import ConfigDict, validate_call
from scipy.linalg import orthogonal_procrustes

# =============================================================================
# DATA STRUCTURES & ERROR CLASSES
# =============================================================================

class ProcrustesError(Exception):
    """Raised when a Procrustes analysis cannot be completed.

    NOTE(review): the module's backends currently raise
    ``ReconstructionError`` for shape problems; this class appears unused
    here — confirm whether callers elsewhere expect it.
    """
class Method(str, Enum):
    """Rotation-estimation schemes supported by :mod:`procrustes`.

    Members
    -------
    VANILLA
        Classic least-squares (L2) orthogonal Procrustes.
    ROBUST_L1
        Iterative scheme with L1-based residual reweighting.
    IRLS
        Iteratively reweighted least squares (L2-based weights).
    """

    VANILLA = "vanilla"
    ROBUST_L1 = "robust-l1"
    IRLS = "irls"
@dataclass
class ProcrustesResult:
    """Container for the outputs of a Procrustes analysis.

    Fields hold NumPy arrays when produced by the NumPy backend and torch
    tensors when produced by the PyTorch backend.
    """

    # Frobenius distance between the scaled input and the aligned target.
    distance: t.Union[float, torch.Tensor]
    # Input points after centering and unit-norm scaling.
    scaled_input: t.Union[np.ndarray, torch.Tensor]
    # Target points after scaling and rotation onto the input.
    aligned_target: t.Union[np.ndarray, torch.Tensor]
    # Input points after mean-centering only (no scaling).
    centered_input: t.Union[np.ndarray, torch.Tensor]
    # Target points after mean-centering only (no scaling).
    centered_target: t.Union[np.ndarray, torch.Tensor]
    # Orthogonal matrix applied to the scaled target during alignment.
    rotation_matrix: t.Union[np.ndarray, torch.Tensor]
# =============================================================================
# CORE IMPLEMENTATION (SINGLE DISPATCH)
# =============================================================================
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
@functools.singledispatch
def procrustes_analysis(
    input: t.Any,
    target: t.Any,
    method: Method = Method.VANILLA,
    max_iter: int = 100,
    tol: float = 1e-6,
) -> ProcrustesResult:
    """Perform a full Procrustes analysis on two equally-shaped point sets.

    Generic entry point: ``functools.singledispatch`` routes the call to the
    backend registered for ``type(input)`` (NumPy arrays or torch tensors).
    This base implementation only executes for unregistered types and
    always raises.

    Parameters
    ----------
    input : array-like, shape (num_points, n_dimensions)
        Reference point set.
    target : array-like, shape (num_points, n_dimensions)
        Point set to be aligned onto ``input``.
    method : Method, optional
        Rotation-estimation scheme (vanilla L2, robust L1, or IRLS).
    max_iter : int, optional
        Iteration cap for the robust/IRLS schemes.
    tol : float, optional
        Convergence tolerance on the rotation update (robust/IRLS only).

    Returns
    -------
    ProcrustesResult
        Distance, aligned/centered point sets, and the rotation matrix.

    Raises
    ------
    NotImplementedError
        If no implementation is registered for ``type(input)``.
    """
    # BUGFIX: the previous log line formatted `input.shape`/`target.shape`,
    # but this base body only runs for *unregistered* types, which may lack
    # a `.shape` attribute — the resulting AttributeError would mask the
    # intended NotImplementedError. Log only the offending type instead.
    logger.error(f"No Procrustes implementation for type {type(input)}")
    raise NotImplementedError(f"No Procrustes implementation for type {type(input)}")
@procrustes_analysis.register(np.ndarray)
def _procrustes_analysis_numpy(
    input: np.ndarray,
    target: np.ndarray,
    method: Method = Method.VANILLA,
    max_iter: int = 100,
    tol: float = 1e-6,
) -> ProcrustesResult:
    """NumPy backend for :func:`procrustes_analysis`.

    Centers and unit-norm-scales both point sets, estimates the orthogonal
    rotation that maps the scaled target onto the scaled input, and returns
    the resulting Procrustes distance and intermediates.

    Parameters
    ----------
    input, target : np.ndarray, shape (num_points, n_dimensions)
        Point sets to compare; must have identical shapes.
    method : Method, optional
        VANILLA (closed-form L2), ROBUST_L1, or IRLS.
    max_iter, tol : int, float, optional
        Iteration cap and convergence tolerance for the robust schemes.

    Raises
    ------
    ReconstructionError
        If the shapes differ or the inputs are not 2-D.
        NOTE(review): raises ReconstructionError although ProcrustesError
        is defined in this module — confirm which one callers catch.
    """
    logger.debug(f"Performing numpy Procrustes analysis using {method.value} method")
    if input.shape != target.shape:
        raise ReconstructionError("Point sets must have the same number of points and dimensions.")
    if input.ndim != 2:
        raise ReconstructionError("Input point sets must be 2D arrays (num_points, n_dimensions).")

    # --- Step 1: Centering and Scaling (common to all methods) ---
    centroid1 = np.mean(input, axis=0)
    centered_input = input - centroid1
    centroid2 = np.mean(target, axis=0)
    centered_target = target - centroid2
    norm1 = np.linalg.norm(centered_input)
    scaled_input = centered_input / (norm1 if norm1 > 1e-9 else 1.0)  # guard degenerate (all-equal) sets
    norm2 = np.linalg.norm(centered_target)
    scaled_target = centered_target / (norm2 if norm2 > 1e-9 else 1.0)

    # --- Step 2: Optimal rotation mapping the scaled target onto the input ---
    if method == Method.VANILLA:
        # BUGFIX: previously this called orthogonal_procrustes(scaled_input,
        # scaled_target), which yields the rotation minimizing
        # ||input @ R - target|| (input -> target), yet the code below applies
        # R to the *target*. Swapping the arguments gives the target -> input
        # rotation, consistent with the torch backend (M = target.T @ input)
        # and actually optimal for the distance computed in Step 3.
        rotation_matrix, _ = orthogonal_procrustes(scaled_target, scaled_input)
    else:
        # ROBUST_L1 and IRLS share the same reweighting loop; only the
        # residual norm used for the weights differs (L1 vs L2).
        residual_ord = 1 if method == Method.ROBUST_L1 else 2
        rotation_matrix = np.eye(input.shape[1])  # start from identity
        for _ in range(max_iter):
            prev_rotation = rotation_matrix.copy()
            residuals = scaled_input - (scaled_target @ rotation_matrix)
            # Inverse-residual weights down-weight outlier points; epsilon
            # avoids division by zero on perfectly matched rows.
            weights = 1.0 / (np.linalg.norm(residuals, ord=residual_ord, axis=1) + 1e-9)
            weighted_cov = scaled_target.T @ np.diag(weights) @ scaled_input
            U, _, Vt = np.linalg.svd(weighted_cov)
            rotation_matrix = U @ Vt
            if np.linalg.norm(rotation_matrix - prev_rotation) < tol:
                break

    # --- Step 3: Final alignment and Procrustes distance ---
    aligned_target = scaled_target @ rotation_matrix
    distance = np.linalg.norm(scaled_input - aligned_target)
    return ProcrustesResult(
        distance=distance,
        scaled_input=scaled_input,
        aligned_target=aligned_target,
        centered_input=centered_input,
        centered_target=centered_target,
        rotation_matrix=rotation_matrix,
    )


@procrustes_analysis.register(torch.Tensor)
def _procrustes_analysis_torch(
    input: torch.Tensor,
    target: torch.Tensor,
    method: Method = Method.VANILLA,
    max_iter: int = 100,
    tol: float = 1e-6,
) -> ProcrustesResult:
    """Differentiable PyTorch backend for :func:`procrustes_analysis`.

    For the robust methods (L1, IRLS) the gradient flows through the final
    affine transformation but not through the iterative optimization that
    finds the rotation matrix. The vanilla method is fully differentiable.

    Parameters
    ----------
    input, target : torch.Tensor, shape (num_points, n_dimensions)
        Point sets to compare; must have identical shapes.
    method : Method, optional
        VANILLA (closed-form L2), ROBUST_L1, or IRLS.
    max_iter, tol : int, float, optional
        Iteration cap and convergence tolerance for the robust schemes.

    Raises
    ------
    ReconstructionError
        If the shapes differ or the inputs are not 2-D.
    """
    logger.debug(f"Performing torch Procrustes analysis using {method.value} method")
    if input.shape != target.shape:
        raise ReconstructionError("Point sets must have the same number of points and dimensions.")
    if input.ndim != 2:
        raise ReconstructionError("Input point sets must be 2D tensors (num_points, n_dimensions).")

    # --- Step 1: Centering and Scaling (common to all methods) ---
    centroid1 = torch.mean(input, dim=0)
    centered_input = input - centroid1
    centroid2 = torch.mean(target, dim=0)
    centered_target = target - centroid2
    norm1 = torch.linalg.norm(centered_input)
    scaled_input = centered_input / (norm1 if norm1 > 1e-9 else 1.0)  # guard degenerate sets
    norm2 = torch.linalg.norm(centered_target)
    scaled_target = centered_target / (norm2 if norm2 > 1e-9 else 1.0)

    # --- Step 2: Optimal rotation mapping the scaled target onto the input ---
    if method == Method.VANILLA:
        # Closed-form solution: SVD of the cross-covariance target.T @ input
        # yields the orthogonal matrix mapping target -> input.
        M = scaled_target.T @ scaled_input
        U, _, Vh = torch.linalg.svd(M)
        rotation_matrix = U @ Vh
    else:  # Robust L1 and IRLS
        rotation_matrix = torch.eye(input.shape[1], device=input.device, dtype=input.dtype)
        # Detach so the iterative solve stays out of the autograd graph.
        p1_no_grad = scaled_input.detach()
        p2_no_grad = scaled_target.detach()
        for _ in range(max_iter):
            with torch.no_grad():
                prev_rotation = rotation_matrix.clone()
                residuals = p1_no_grad - (p2_no_grad @ rotation_matrix)
                if method == Method.ROBUST_L1:
                    weights = 1.0 / (torch.linalg.norm(residuals, ord=1, dim=1) + 1e-9)
                else:  # IRLS
                    weights = 1.0 / (torch.linalg.norm(residuals, ord=2, dim=1) + 1e-9)
                W = torch.diag(weights)
                M = p2_no_grad.T @ W @ p1_no_grad
                U, _, Vh = torch.linalg.svd(M)
                rotation_matrix = U @ Vh
                if torch.linalg.norm(rotation_matrix - prev_rotation) < tol:
                    break

    # --- Step 3: Final alignment and Procrustes distance ---
    aligned_target = scaled_target @ rotation_matrix
    distance = torch.linalg.norm(scaled_input - aligned_target)
    return ProcrustesResult(
        distance=distance,
        scaled_input=scaled_input,
        aligned_target=aligned_target,
        centered_input=centered_input,
        centered_target=centered_target,
        rotation_matrix=rotation_matrix,
    )