#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Procrustes Analysis
This code provides functions to perform Procrustes analysis
on two sets of points.
This version includes standard (L2), robust (L1), and IRLS methods for both
NumPy and PyTorch backends.
The tool uses NumPy for data handling, SciPy for the core Procrustes
algorithm.
Examples
--------
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "3.10.0"
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import functools
import typing as t
from dataclasses import dataclass
from enum import Enum
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
from gunz_cm.exceptions import ReconstructionError
from gunz_cm.utils.logger import logger
import numpy as np
import torch
from pydantic import ConfigDict, validate_call
from scipy.linalg import orthogonal_procrustes
# =============================================================================
# DATA STRUCTURES & ERROR CLASSES
# =============================================================================
[docs]
class ProcrustesError(Exception):
    """Custom exception for errors during Procrustes analysis.

    A plain ``Exception`` subclass: it adds no attributes or behaviour of
    its own and accepts the usual exception arguments.
    """
[docs]
class Method(str, Enum):
    """Algorithm used to estimate the Procrustes rotation.

    Members are also ``str`` instances, so each one compares equal to its
    string value and can be looked up from it, e.g. ``Method("irls")``.
    """

    # Single closed-form orthogonal Procrustes (least-squares) solve.
    VANILLA = "vanilla"
    # Iteratively reweighted solve; per-point weights use the L1 residual norm.
    ROBUST_L1 = "robust-l1"
    # Iteratively reweighted solve; per-point weights use the L2 residual norm.
    IRLS = "irls"
[docs]
@dataclass
class ProcrustesResult:
    """Container for the outputs of one Procrustes analysis run.

    The field types follow the backend that produced the result: NumPy
    values for array inputs, tensors for Torch inputs.
    """

    # Frobenius norm of ``scaled_input - aligned_target`` (scalar).
    distance: t.Union[float, torch.Tensor]
    # Centered input divided by its norm (left unscaled if the norm is ~0).
    scaled_input: t.Union[np.ndarray, torch.Tensor]
    # Scaled target after right-multiplication by ``rotation_matrix``.
    aligned_target: t.Union[np.ndarray, torch.Tensor]
    # Input points with their per-dimension mean subtracted.
    centered_input: t.Union[np.ndarray, torch.Tensor]
    # Target points with their per-dimension mean subtracted.
    centered_target: t.Union[np.ndarray, torch.Tensor]
    # Square matrix applied on the right of the scaled target.
    rotation_matrix: t.Union[np.ndarray, torch.Tensor]
# =============================================================================
# CORE IMPLEMENTATION (SINGLE DISPATCH)
# =============================================================================
[docs]
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
@functools.singledispatch
def procrustes_analysis(
    input: t.Any,
    target: t.Any,
    method: Method = Method.VANILLA,
    max_iter: int = 100,
    tol: float = 1e-6,
) -> ProcrustesResult:
    """
    Perform a full Procrustes analysis on two sets of points.

    This is a generic function that uses ``functools.singledispatch`` to call
    the appropriate implementation based on the type of ``input``. The body
    below is only the fallback that runs when no implementation has been
    registered for that type.

    Parameters
    ----------
    input : Any
        Reference point set; its type selects the implementation.
    target : Any
        Point set to be aligned onto ``input``.
    method : Method, optional
        Algorithm used to estimate the rotation. Defaults to
        ``Method.VANILLA``.
    max_iter : int, optional
        Iteration cap forwarded to the iterative methods. Defaults to 100.
    tol : float, optional
        Convergence tolerance forwarded to the iterative methods. Defaults
        to 1e-6.

    Returns
    -------
    ProcrustesResult
        Produced by the registered implementation; this fallback never
        returns.

    Raises
    ------
    NotImplementedError
        If no implementation is registered for ``type(input)``.
    """
    # This body is reached only for unsupported input types, which in general
    # have no ``.shape``. Touching ``input.shape`` here (as an earlier log
    # statement did) would raise AttributeError and hide the intended error.
    raise NotImplementedError(f"No Procrustes implementation for type {type(input)}")
@procrustes_analysis.register(np.ndarray)
def _procrustes_analysis_numpy(
    input: np.ndarray,
    target: np.ndarray,
    method: Method = Method.VANILLA,
    max_iter: int = 100,
    tol: float = 1e-6,
) -> ProcrustesResult:
    """
    NumPy implementation for Procrustes analysis.

    Both point sets are centered on their mean and divided by their
    Frobenius norm. An orthogonal matrix ``R`` (which may include a
    reflection) is then estimated so that ``scaled_target @ R`` is as close
    as possible to ``scaled_input``.

    Parameters
    ----------
    input : np.ndarray
        Reference points, shape ``(num_points, n_dimensions)``.
    target : np.ndarray
        Points to align onto ``input``; must have the same shape.
    method : Method, optional
        ``VANILLA`` solves the least-squares problem in closed form.
        ``ROBUST_L1`` and ``IRLS`` iteratively re-solve a weighted problem
        whose per-point weights are the inverse L1 (respectively L2) norm of
        the current residual.
    max_iter : int, optional
        Maximum number of reweighting iterations (iterative methods only).
    tol : float, optional
        Iteration stops once the Frobenius norm of the change in the
        rotation matrix falls below this value.

    Returns
    -------
    ProcrustesResult
        Centered and scaled point sets, the rotation matrix, the aligned
        target and the Frobenius distance between ``scaled_input`` and the
        aligned target.

    Raises
    ------
    ReconstructionError
        If the shapes differ or the inputs are not 2D.
    """
    logger.debug(f"Performing numpy Procrustes analysis using {method.value} method")

    if input.shape != target.shape:
        raise ReconstructionError("Point sets must have the same number of points and dimensions.")
    if input.ndim != 2:
        raise ReconstructionError("Input point sets must be 2D arrays (num_points, n_dimensions).")

    # --- Step 1: Centering and Scaling (common to all methods) ---
    centered_input = input - np.mean(input, axis=0)
    centered_target = target - np.mean(target, axis=0)

    # A (near-)zero norm means all points coincide; skip scaling in that case
    # instead of dividing by ~0.
    norm1 = np.linalg.norm(centered_input)
    scaled_input = centered_input / (norm1 if norm1 > 1e-9 else 1.0)
    norm2 = np.linalg.norm(centered_target)
    scaled_target = centered_target / (norm2 if norm2 > 1e-9 else 1.0)

    # --- Step 2: Find Optimal Rotation based on method ---
    if method == Method.VANILLA:
        # orthogonal_procrustes(A, B) minimises ||A @ R - B||_F. The rotation
        # is applied to the *target* in Step 3, so the target must be A and
        # the input B; the reverse order yields the transposed rotation.
        rotation_matrix, _ = orthogonal_procrustes(scaled_target, scaled_input)
    else:  # Robust L1 and IRLS share the loop and differ only in the norm
        residual_ord = 1 if method == Method.ROBUST_L1 else 2
        rotation_matrix = np.eye(input.shape[1])  # Start with identity
        for _iteration in range(max_iter):
            prev_R = rotation_matrix
            residuals = scaled_input - (scaled_target @ rotation_matrix)
            # Points with large residuals get small weights; the epsilon
            # guards against division by zero for perfectly aligned points.
            weights = 1.0 / (np.linalg.norm(residuals, ord=residual_ord, axis=1) + 1e-9)
            # Scaling the rows of the target equals target.T @ diag(w) @ input
            # without materialising the N x N diagonal matrix.
            M = (scaled_target * weights[:, None]).T @ scaled_input
            U, _singular_values, Vt = np.linalg.svd(M)
            rotation_matrix = U @ Vt
            if np.linalg.norm(rotation_matrix - prev_R) < tol:
                break

    # --- Step 3: Final Alignment and Distance ---
    aligned_target = scaled_target @ rotation_matrix
    distance = np.linalg.norm(scaled_input - aligned_target)

    return ProcrustesResult(
        distance=distance,
        scaled_input=scaled_input,
        aligned_target=aligned_target,
        centered_input=centered_input,
        centered_target=centered_target,
        rotation_matrix=rotation_matrix,
    )
@procrustes_analysis.register(torch.Tensor)
def _procrustes_analysis_torch(
    input: torch.Tensor,
    target: torch.Tensor,
    method: Method = Method.VANILLA,
    max_iter: int = 100,
    tol: float = 1e-6,
) -> ProcrustesResult:
    """
    Differentiable PyTorch implementation for Procrustes analysis.

    Both point sets are centered on their mean and divided by their
    Frobenius norm. An orthogonal matrix ``R`` is then estimated so that
    ``scaled_target @ R`` is as close as possible to ``scaled_input``.

    For robust methods (L1, IRLS), the gradient flows through the final
    affine transformation but not through the iterative optimization that
    finds the rotation matrix (it runs on detached tensors under
    ``torch.no_grad()``). The vanilla method computes the rotation from an
    SVD of the non-detached tensors.

    Parameters
    ----------
    input : torch.Tensor
        Reference points, shape ``(num_points, n_dimensions)``.
    target : torch.Tensor
        Points to align onto ``input``; must have the same shape.
    method : Method, optional
        ``VANILLA`` uses a single SVD. ``ROBUST_L1`` weights each point by
        the inverse L1 norm of its residual; any other value is treated as
        IRLS and uses the inverse L2 norm.
    max_iter : int, optional
        Maximum number of reweighting iterations (iterative methods only).
    tol : float, optional
        Iteration stops once the Frobenius norm of the change in the
        rotation matrix falls below this value.

    Returns
    -------
    ProcrustesResult
        Centered and scaled point sets, the rotation matrix, the aligned
        target and the Frobenius distance between ``scaled_input`` and the
        aligned target, all as tensors.

    Raises
    ------
    ReconstructionError
        If the shapes differ or the inputs are not 2D.
    """
    logger.debug(f"Performing torch Procrustes analysis using {method.value} method")

    if input.shape != target.shape:
        raise ReconstructionError("Point sets must have the same number of points and dimensions.")
    if input.ndim != 2:
        raise ReconstructionError("Input point sets must be 2D tensors (num_points, n_dimensions).")

    # --- Step 1: Centering and Scaling (common to all methods) ---
    centered_input = input - torch.mean(input, dim=0)
    centered_target = target - torch.mean(target, dim=0)

    # A (near-)zero norm means all points coincide; skip scaling in that case
    # instead of dividing by ~0.
    norm1 = torch.linalg.norm(centered_input)
    scaled_input = centered_input / (norm1 if norm1 > 1e-9 else 1.0)
    norm2 = torch.linalg.norm(centered_target)
    scaled_target = centered_target / (norm2 if norm2 > 1e-9 else 1.0)

    # --- Step 2: Find Optimal Rotation based on method ---
    if method == Method.VANILLA:
        M = scaled_target.T @ scaled_input
        U, _, Vh = torch.linalg.svd(M)
        rotation_matrix = U @ Vh
    else:  # Robust L1 and IRLS
        # The norm used for the weights is fixed per call, so pick it once.
        residual_ord = 1 if method == Method.ROBUST_L1 else 2
        p1_no_grad = scaled_input.detach()
        p2_no_grad = scaled_target.detach()
        rotation_matrix = torch.eye(input.shape[1], device=input.device, dtype=input.dtype)
        with torch.no_grad():
            for _iteration in range(max_iter):
                prev_R = rotation_matrix
                residuals = p1_no_grad - (p2_no_grad @ rotation_matrix)
                # Points with large residuals get small weights; the epsilon
                # guards against division by zero.
                weights = 1.0 / (torch.linalg.norm(residuals, ord=residual_ord, dim=1) + 1e-9)
                # Scaling the rows of the target equals
                # target.T @ diag(w) @ input without allocating the dense
                # N x N diagonal matrix.
                M = (p2_no_grad * weights.unsqueeze(1)).T @ p1_no_grad
                U, _, Vh = torch.linalg.svd(M)
                rotation_matrix = U @ Vh
                if torch.linalg.norm(rotation_matrix - prev_R) < tol:
                    break

    # --- Step 3: Final Alignment and Distance ---
    aligned_target = scaled_target @ rotation_matrix
    distance = torch.linalg.norm(scaled_input - aligned_target)

    return ProcrustesResult(
        distance=distance,
        scaled_input=scaled_input,
        aligned_target=aligned_target,
        centered_input=centered_input,
        centered_target=centered_target,
        rotation_matrix=rotation_matrix,
    )