# Source code for gunz_cm.reconstructions.objectives.mds

# -*- coding: utf-8 -*-
"""
Implements Multi-Dimensional Scaling (MDS) and its weighted variant as loss
functions compatible with both NumPy and PyTorch.


Examples
--------
"""

# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__license__ = "Clear BSD"
__version__ = "1.2.0"
__email__ = "adhisant@tnt.uni-hannover.de"

# Public API of this module: the two loss-function classes defined below.
__all__ = [
    "MultiDimensionalScaling",
    "WeightedMultiDimensionalScaling"
]


# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t


# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
from gunz_cm.exceptions import ReconstructionError
import numpy as np
import torch


# =============================================================================
# LOCAL IMPORTS
# =============================================================================
from .consts import AVAIL_NP_F, AVAIL_TORCH_F
from ._base_binary_op import BaseBinaryOP


# =============================================================================
# CONSTANTS
# =============================================================================
# Accepted option names for the element-wise term function; the actual
# callables are looked up under these keys in AVAIL_NP_F / AVAIL_TORCH_F
# (defined in .consts).
VALID_TERM: t.Final[t.List[str]] = ['abs', 'square']
# Accepted option names for collapsing the element-wise losses to a scalar;
# resolved the same way via AVAIL_NP_F / AVAIL_TORCH_F.
VALID_REDUCTION: t.Final[t.List[str]] = ['mean', 'sum']


# =============================================================================
# LOSS FUNCTION IMPLEMENTATIONS
# =============================================================================

class MultiDimensionalScaling(BaseBinaryOP):
    """
    Computes the Multi-Dimensional Scaling (MDS) loss between two tensors.

    This loss is calculated as the ratio of the error between input and
    target to the magnitude of the target, aggregated by a reduction
    function. It supports both NumPy arrays and PyTorch tensors.

    Parameters
    ----------
    term : str, default='abs'
        The element-wise function to apply to the error and the target.
        Must be one of {'abs', 'square'}.
    reduction : str, default='mean'
        The method for reducing the element-wise losses to a single value.
        Must be one of {'mean', 'sum'}.
    eps : float, default=1e-8
        A small epsilon value added to the denominator for numerical
        stability to prevent division by zero.

    Raises
    ------
    ReconstructionError
        If an invalid `term` or `reduction` is specified.

    Notes
    -----
    The formula is:
    `Loss = reduction(term(target - input) / term(target))`
    """

    def __init__(
            self,
            term: str = 'abs',
            reduction: str = 'mean',
            eps: float = 1e-8
    ):
        """
        Validates the configuration and caches the backend functions.

        Parameters
        ----------
        term : str, default='abs'
            Element-wise function name; one of {'abs', 'square'}.
        reduction : str, default='mean'
            Reduction name; one of {'mean', 'sum'}.
        eps : float, default=1e-8
            Stabilizer added to the denominator.

        Raises
        ------
        ReconstructionError
            If `term` or `reduction` is not a supported option.
        """
        super().__init__()
        if term not in VALID_TERM:
            raise ReconstructionError(
                f"Invalid term function: '{term}'. "
                f"Must be one of {VALID_TERM}.")
        if reduction not in VALID_REDUCTION:
            raise ReconstructionError(
                f"Invalid reduction function: '{reduction}'. "
                f"Must be one of {VALID_REDUCTION}.")

        self.term = term
        self.reduction = reduction
        self.eps = eps

        # Retrieve functions during initialization to fail early.
        self._term_f_np = AVAIL_NP_F[self.term]
        self._reduction_f_np = AVAIL_NP_F[self.reduction]
        self._term_f_torch = AVAIL_TORCH_F[self.term]
        self._reduction_f_torch = AVAIL_TORCH_F[self.reduction]

    def __repr__(self) -> str:
        """Provides an unambiguous string representation for developers."""
        return (f"{self.__class__.__name__}("
                f"term='{self.term}', "
                f"reduction='{self.reduction}', "
                f"eps={self.eps})")

    def _call_numpy(
            self,
            input: np.ndarray,
            target: np.ndarray
    ) -> float:
        """Computes the MDS loss for NumPy inputs.

        Parameters
        ----------
        input : np.ndarray
            The reconstructed values.
        target : np.ndarray
            The reference values.

        Returns
        -------
        float
            The reduced MDS loss.
        """
        numerator = self._term_f_np(target - input)
        # eps keeps the ratio finite when term(target) is zero.
        denominator = self._term_f_np(target) + self.eps
        loss = numerator / denominator
        return self._reduction_f_np(loss)

    def _call_torch(
            self,
            input: torch.Tensor,
            target: torch.Tensor
    ) -> torch.Tensor:
        """Computes the MDS loss for PyTorch inputs.

        Parameters
        ----------
        input : torch.Tensor
            The reconstructed values.
        target : torch.Tensor
            The reference values.

        Returns
        -------
        torch.Tensor
            The reduced MDS loss (scalar tensor under the configured
            reduction).
        """
        numerator = self._term_f_torch(target - input)
        # eps keeps the ratio finite when term(target) is zero.
        denominator = self._term_f_torch(target) + self.eps
        loss = numerator / denominator
        return self._reduction_f_torch(loss)
class WeightedMultiDimensionalScaling(BaseBinaryOP):
    """
    Computes a weighted Multi-Dimensional Scaling (WMDS) loss.

    This loss function applies a weight to the term-wise error, where the
    weight is derived from the target tensor.

    Parameters
    ----------
    term : str, default='square'
        The element-wise function to apply to the error.
        Must be one of {'abs', 'square'}.
    reduction : str, default='mean'
        The reduction method. Must be one of {'mean', 'sum'}.
    weight_exp : float, default=1.0
        The exponent applied to the target to calculate the weights.
    eps : float, default=1e-8
        A small epsilon value for numerical stability, used in the
        denominator of the weighted mean calculation.

    Raises
    ------
    ReconstructionError
        If an invalid `term` or `reduction` is specified.

    Notes
    -----
    The formula is:

        weights = target ** weight_exp
        weighted_loss = weights * term(target - input)

    For 'mean' reduction, the result is `sum(weighted_loss) / sum(weights)`.
    For 'sum' reduction, the result is `sum(weighted_loss)`.
    """

    def __init__(
            self,
            term: str = 'square',
            reduction: str = 'mean',
            weight_exp: float = 1.0,
            eps: float = 1e-8
    ):
        """
        Validates the configuration and caches the backend functions.

        Parameters
        ----------
        term : str, default='square'
            Element-wise function name; one of {'abs', 'square'}.
        reduction : str, default='mean'
            Reduction name; one of {'mean', 'sum'}.
        weight_exp : float, default=1.0
            Exponent applied to the target to form the weights.
        eps : float, default=1e-8
            Stabilizer added to the weight sum in 'mean' reduction.

        Raises
        ------
        ReconstructionError
            If `term` or `reduction` is not a supported option.
        """
        super().__init__()
        if term not in VALID_TERM:
            raise ReconstructionError(
                f"Invalid term function: '{term}'. "
                f"Must be one of {VALID_TERM}.")
        if reduction not in VALID_REDUCTION:
            raise ReconstructionError(
                f"Invalid reduction function: '{reduction}'. "
                f"Must be one of {VALID_REDUCTION}.")

        self.term = term
        self.reduction = reduction
        self.weight_exp = weight_exp
        self.eps = eps

        # Retrieve functions during initialization.
        self._term_f_np = AVAIL_NP_F[self.term]
        self._power_f_np = AVAIL_NP_F['power']
        self._sum_f_np = AVAIL_NP_F['sum']
        self._term_f_torch = AVAIL_TORCH_F[self.term]
        self._power_f_torch = AVAIL_TORCH_F['power']
        self._sum_f_torch = AVAIL_TORCH_F['sum']

    def __repr__(self) -> str:
        """Provides an unambiguous string representation for developers."""
        return (f"{self.__class__.__name__}("
                f"term='{self.term}', "
                f"reduction='{self.reduction}', "
                f"weight_exp={self.weight_exp}, "
                f"eps={self.eps})")

    def _call_numpy(
            self,
            input: np.ndarray,
            target: np.ndarray
    ) -> float:
        """Computes the WMDS loss for NumPy inputs.

        Parameters
        ----------
        input : np.ndarray
            The reconstructed values.
        target : np.ndarray
            The reference values; also the source of the weights.

        Returns
        -------
        float
            The reduced weighted loss.
        """
        weights = self._power_f_np(target, self.weight_exp)
        term_loss = self._term_f_np(target - input)
        weighted_loss = weights * term_loss

        if self.reduction == 'mean':
            # Weighted mean: normalize by the total weight, eps-stabilized.
            total_weight = self._sum_f_np(weights)
            return self._sum_f_np(weighted_loss) / (total_weight + self.eps)

        # Case: self.reduction == 'sum'
        return self._sum_f_np(weighted_loss)

    def _call_torch(
            self,
            input: torch.Tensor,
            target: torch.Tensor
    ) -> torch.Tensor:
        """Computes the WMDS loss for PyTorch inputs.

        Parameters
        ----------
        input : torch.Tensor
            The reconstructed values.
        target : torch.Tensor
            The reference values; also the source of the weights.

        Returns
        -------
        torch.Tensor
            The reduced weighted loss (scalar tensor).
        """
        weights = self._power_f_torch(target, self.weight_exp)
        term_loss = self._term_f_torch(target - input)
        weighted_loss = weights * term_loss

        if self.reduction == 'mean':
            # Weighted mean: normalize by the total weight, eps-stabilized.
            total_weight = self._sum_f_torch(weights)
            return self._sum_f_torch(weighted_loss) / (total_weight + self.eps)

        # Case: self.reduction == 'sum'
        return self._sum_f_torch(weighted_loss)