# -*- coding: utf-8 -*-
"""
Multi-Dimensional Scaling (MDS) loss functions.

Implements Multi-Dimensional Scaling (MDS) and its weighted variant as loss
functions compatible with both NumPy and PyTorch.

Classes
-------
MultiDimensionalScaling
    Relative error ``term(target - input) / (term(target) + eps)``,
    aggregated by a ``'mean'`` or ``'sum'`` reduction.
WeightedMultiDimensionalScaling
    Error ``term(target - input)`` weighted element-wise by
    ``target ** weight_exp``; the ``'mean'`` reduction normalizes by the
    sum of the weights.

Notes
-----
Both classes derive from ``BaseBinaryOP`` and each provides a
``_call_numpy`` and a ``_call_torch`` implementation.
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__license__ = "Clear BSD"
__version__ = "1.2.0"
__email__ = "adhisant@tnt.uni-hannover.de"
# Public API of this module: the two loss classes defined below.
__all__ = [
"MultiDimensionalScaling",
"WeightedMultiDimensionalScaling"
]
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
from gunz_cm.exceptions import ReconstructionError
import numpy as np
import torch
# =============================================================================
# LOCAL IMPORTS
# =============================================================================
from .consts import AVAIL_NP_F, AVAIL_TORCH_F
from ._base_binary_op import BaseBinaryOP
# =============================================================================
# CONSTANTS
# =============================================================================
# Names accepted for the `term` argument of the loss classes: the
# element-wise function applied to the error.
VALID_TERM: t.Final[t.List[str]] = [
    "abs",
    "square",
]
# Names accepted for the `reduction` argument of the loss classes: how the
# element-wise losses are collapsed into a single value.
VALID_REDUCTION: t.Final[t.List[str]] = [
    "mean",
    "sum",
]
# =============================================================================
# LOSS FUNCTION IMPLEMENTATIONS
# =============================================================================
class MultiDimensionalScaling(BaseBinaryOP):
    """
    Compute the Multi-Dimensional Scaling (MDS) loss between two tensors.

    The loss is the element-wise ratio of the error between `input` and
    `target` to the magnitude of `target`, aggregated by a reduction
    function. Separate implementations are provided for NumPy arrays
    (`_call_numpy`) and PyTorch tensors (`_call_torch`).

    Parameters
    ----------
    term : str, default='abs'
        The element-wise function to apply to the error and the target.
        Must be one of {'abs', 'square'}.
    reduction : str, default='mean'
        The method for reducing the element-wise losses to a single value.
        Must be one of {'mean', 'sum'}.
    eps : float, default=1e-8
        A small epsilon value added to the denominator for numerical
        stability to prevent division by zero.

    Raises
    ------
    ReconstructionError
        If an invalid `term` or `reduction` is specified.

    Notes
    -----
    The formula is::

        Loss = reduction(term(target - input) / (term(target) + eps))

    Choosing between the NumPy and the PyTorch implementation is not done
    in this class; presumably `BaseBinaryOP` dispatches on the input type
    (verify against the base class).
    """

    def __init__(
        self,
        term: str = 'abs',
        reduction: str = 'mean',
        eps: float = 1e-8
    ) -> None:
        """
        Validate the configuration and cache the backend functions.

        Parameters
        ----------
        term : str, default='abs'
            Name of the element-wise function; one of {'abs', 'square'}.
        reduction : str, default='mean'
            Name of the reduction; one of {'mean', 'sum'}.
        eps : float, default=1e-8
            Constant added to the denominator of the loss ratio.

        Raises
        ------
        ReconstructionError
            If `term` is not in `VALID_TERM` or `reduction` is not in
            `VALID_REDUCTION`.
        """
        super().__init__()

        if term not in VALID_TERM:
            raise ReconstructionError(
                f"Invalid term function: '{term}'. "
                f"Must be one of {VALID_TERM}.")
        if reduction not in VALID_REDUCTION:
            raise ReconstructionError(
                f"Invalid reduction function: '{reduction}'. "
                f"Must be one of {VALID_REDUCTION}.")

        self.term = term
        self.reduction = reduction
        self.eps = eps

        # Look the callables up once, at construction time, so that a
        # missing registry entry fails here rather than on the first call.
        self._term_f_np = AVAIL_NP_F[self.term]
        self._reduction_f_np = AVAIL_NP_F[self.reduction]
        self._term_f_torch = AVAIL_TORCH_F[self.term]
        self._reduction_f_torch = AVAIL_TORCH_F[self.reduction]

    def __repr__(self) -> str:
        """Return an unambiguous, constructor-like string representation."""
        return (f"{self.__class__.__name__}("
                f"term='{self.term}', "
                f"reduction='{self.reduction}', "
                f"eps={self.eps})")

    def _call_numpy(
        self,
        input: np.ndarray,
        target: np.ndarray
    ) -> float:
        """
        Compute the MDS loss for NumPy inputs.

        Parameters
        ----------
        input : np.ndarray
            The values being evaluated.
        target : np.ndarray
            The reference values; also used to scale the error.

        Returns
        -------
        float
            The reduced loss, as returned by the configured NumPy
            reduction function.
        """
        numerator = self._term_f_np(target - input)
        # eps keeps the ratio finite where term(target) is zero.
        denominator = self._term_f_np(target) + self.eps
        loss = numerator / denominator
        return self._reduction_f_np(loss)

    def _call_torch(
        self,
        input: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the MDS loss for PyTorch inputs.

        Parameters
        ----------
        input : torch.Tensor
            The values being evaluated.
        target : torch.Tensor
            The reference values; also used to scale the error.

        Returns
        -------
        torch.Tensor
            The reduced loss, as returned by the configured PyTorch
            reduction function.
        """
        numerator = self._term_f_torch(target - input)
        # eps keeps the ratio finite where term(target) is zero.
        denominator = self._term_f_torch(target) + self.eps
        loss = numerator / denominator
        return self._reduction_f_torch(loss)
class WeightedMultiDimensionalScaling(BaseBinaryOP):
    """
    Compute a weighted Multi-Dimensional Scaling (WMDS) loss.

    This loss function applies a weight to the term-wise error, where the
    weight is derived from the target tensor. Separate implementations are
    provided for NumPy arrays (`_call_numpy`) and PyTorch tensors
    (`_call_torch`).

    Parameters
    ----------
    term : str, default='square'
        The element-wise function to apply to the error.
        Must be one of {'abs', 'square'}.
    reduction : str, default='mean'
        The reduction method. Must be one of {'mean', 'sum'}.
    weight_exp : float, default=1.0
        The exponent applied to the target to calculate the weights.
    eps : float, default=1e-8
        A small epsilon value for numerical stability, used in the
        denominator of the weighted mean calculation. It is not used by
        the 'sum' reduction.

    Raises
    ------
    ReconstructionError
        If an invalid `term` or `reduction` is specified.

    Notes
    -----
    The formula is::

        weights = target ** weight_exp
        weighted_loss = weights * term(target - input)

    For 'mean' reduction, the result is
    ``sum(weighted_loss) / (sum(weights) + eps)``.
    For 'sum' reduction, the result is ``sum(weighted_loss)``.

    Choosing between the NumPy and the PyTorch implementation is not done
    in this class; presumably `BaseBinaryOP` dispatches on the input type
    (verify against the base class).
    """

    def __init__(
        self,
        term: str = 'square',
        reduction: str = 'mean',
        weight_exp: float = 1.0,
        eps: float = 1e-8
    ) -> None:
        """
        Validate the configuration and cache the backend functions.

        Parameters
        ----------
        term : str, default='square'
            Name of the element-wise function; one of {'abs', 'square'}.
        reduction : str, default='mean'
            Name of the reduction; one of {'mean', 'sum'}.
        weight_exp : float, default=1.0
            Exponent applied to the target to obtain the weights.
        eps : float, default=1e-8
            Constant added to the total weight in the 'mean' reduction.

        Raises
        ------
        ReconstructionError
            If `term` is not in `VALID_TERM` or `reduction` is not in
            `VALID_REDUCTION`.
        """
        super().__init__()

        if term not in VALID_TERM:
            raise ReconstructionError(
                f"Invalid term function: '{term}'. "
                f"Must be one of {VALID_TERM}.")
        if reduction not in VALID_REDUCTION:
            raise ReconstructionError(
                f"Invalid reduction function: '{reduction}'. "
                f"Must be one of {VALID_REDUCTION}.")

        self.term = term
        self.reduction = reduction
        self.weight_exp = weight_exp
        self.eps = eps

        # Look the callables up once, at construction time. Only 'sum' is
        # fetched as a reduction: the 'mean' case is computed manually as
        # a weighted mean in the call methods.
        self._term_f_np = AVAIL_NP_F[self.term]
        self._power_f_np = AVAIL_NP_F['power']
        self._sum_f_np = AVAIL_NP_F['sum']
        self._term_f_torch = AVAIL_TORCH_F[self.term]
        self._power_f_torch = AVAIL_TORCH_F['power']
        self._sum_f_torch = AVAIL_TORCH_F['sum']

    def __repr__(self) -> str:
        """Return an unambiguous, constructor-like string representation."""
        return (f"{self.__class__.__name__}("
                f"term='{self.term}', "
                f"reduction='{self.reduction}', "
                f"weight_exp={self.weight_exp}, "
                f"eps={self.eps})")

    def _call_numpy(
        self,
        input: np.ndarray,
        target: np.ndarray
    ) -> float:
        """
        Compute the WMDS loss for NumPy inputs.

        Parameters
        ----------
        input : np.ndarray
            The values being evaluated.
        target : np.ndarray
            The reference values; also the base of the weights.

        Returns
        -------
        float
            The weighted mean (for 'mean') or the weighted sum (for
            'sum') of the term-wise error.
        """
        weights = self._power_f_np(target, self.weight_exp)
        term_loss = self._term_f_np(target - input)
        weighted_loss = weights * term_loss

        if self.reduction == 'mean':
            # Weighted mean: normalize by the total weight, with eps
            # guarding against a zero total.
            total_weight = self._sum_f_np(weights)
            return self._sum_f_np(weighted_loss) / (total_weight + self.eps)

        # Case: self.reduction == 'sum'
        return self._sum_f_np(weighted_loss)

    def _call_torch(
        self,
        input: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the WMDS loss for PyTorch inputs.

        Parameters
        ----------
        input : torch.Tensor
            The values being evaluated.
        target : torch.Tensor
            The reference values; also the base of the weights.

        Returns
        -------
        torch.Tensor
            The weighted mean (for 'mean') or the weighted sum (for
            'sum') of the term-wise error.
        """
        weights = self._power_f_torch(target, self.weight_exp)
        term_loss = self._term_f_torch(target - input)
        weighted_loss = weights * term_loss

        if self.reduction == 'mean':
            # Weighted mean: normalize by the total weight, with eps
            # guarding against a zero total.
            total_weight = self._sum_f_torch(weights)
            return self._sum_f_torch(weighted_loss) / (total_weight + self.eps)

        # Case: self.reduction == 'sum'
        return self._sum_f_torch(weighted_loss)