# Source code for gunz_cm.reconstructions.objectives.general_error

# -*- coding: utf-8 -*-
"""
Implements a general-purpose error function that is compatible with both
NumPy and PyTorch, applying specified term and reduction operations.


Examples
--------
"""

# =============================================================================
# METADATA
# =============================================================================
# Module authorship / packaging metadata.
__author__ = "Yeremia Gunawan Adhisantoso"
__license__ = "Clear BSD"
__version__ = "1.1.0"
__email__ = "adhisant@tnt.uni-hannover.de"

# Explicit public API of this module.
__all__ = [
    "GeneralError"
]


# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t


# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
from gunz_cm.exceptions import ReconstructionError
import numpy as np
import torch


# =============================================================================
# LOCAL IMPORTS
# =============================================================================
from .consts import AVAIL_NP_F, AVAIL_TORCH_F
from ._base_binary_op import BaseBinaryOP


# =============================================================================
# CONSTANTS
# =============================================================================
# Accepted element-wise term functions. 'abs' and 'square' are user-facing
# aliases that GeneralError.__init__ normalizes to 'l1' and 'l2' respectively.
VALID_TERM: t.Final[t.List[str]] = ['abs', 'square', 'l1', 'l2']
# Accepted aggregation methods applied to the element-wise loss.
VALID_REDUCTION: t.Final[t.List[str]] = ['mean', 'sum']


# =============================================================================
# CLASS IMPLEMENTATION
# =============================================================================

class GeneralError(BaseBinaryOP):
    """A general error function for computing loss between an input and a target.

    This class provides a flexible way to calculate errors by combining
    different term functions ('l1', 'l2') and reduction methods
    ('mean', 'sum'). It handles both NumPy and PyTorch tensors.

    Parameters
    ----------
    term : str, default='square'
        The term function to use. Valid options are 'abs', 'square',
        'l1', 'l2'. 'abs' is mapped to 'l1', and 'square' is mapped
        to 'l2'.
    reduction : str, default='mean'
        The reduction method to apply to the loss. Valid options are
        'mean', 'sum'.

    Raises
    ------
    ReconstructionError
        If an invalid `term` or `reduction` is specified.

    Notes
    -----
    For PyTorch, this class leverages `F.l1_loss` and `F.mse_loss`,
    which are efficient as they combine the term and reduction steps.
    For NumPy, the operations are performed sequentially.

    Examples
    --------
    """

    def __init__(
        self,
        term: str = 'square',
        reduction: str = 'mean'
    ):
        """Validate and store the term/reduction configuration.

        Parameters
        ----------
        term : str, default='square'
            Element-wise term function; must be in ``VALID_TERM``.
        reduction : str, default='mean'
            Aggregation method; must be in ``VALID_REDUCTION``.

        Raises
        ------
        ReconstructionError
            If `term` or `reduction` is not a recognized option.
        """
        super().__init__()

        if term not in VALID_TERM:
            raise ReconstructionError(
                f"Invalid term function: '{term}'. "
                f"Must be one of {VALID_TERM}.")
        if reduction not in VALID_REDUCTION:
            raise ReconstructionError(
                f"Invalid reduction function: '{reduction}'. "
                f"Must be one of {VALID_REDUCTION}.")

        # Normalize term names ('abs' -> 'l1', 'square' -> 'l2') so the
        # lookup tables below only need the canonical keys.
        self.term = 'l1' if term == 'abs' else 'l2' if term == 'square' else term
        self.reduction = reduction

        # Retrieve functions during initialization to fail early rather
        # than on the first call.
        self._term_f_np = AVAIL_NP_F[self.term]
        self._reduction_f_np = AVAIL_NP_F[self.reduction]
        self._term_f_torch = AVAIL_TORCH_F[self.term]

    def __repr__(self) -> str:
        """Provides an unambiguous string representation for developers.

        Examples
        --------
        """
        return (f"{self.__class__.__name__}("
                f"term='{self.term}', "
                f"reduction='{self.reduction}')")

    def _call_numpy(
        self,
        input: np.ndarray,
        target: np.ndarray
    ) -> float:
        """
        Computes the loss using NumPy arrays.

        The term function (e.g., abs, square) is applied element-wise to
        the difference, and the result is then aggregated using the
        reduction function (e.g., mean, sum).

        Examples
        --------
        """
        loss = self._term_f_np(target - input)
        return self._reduction_f_np(loss)

    def _call_torch(
        self,
        input: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the loss using PyTorch tensors.

        This method directly uses PyTorch's optimized loss functions
        (`l1_loss`, `mse_loss`) which handle the reduction internally.

        Examples
        --------
        """
        # PyTorch's F.l1_loss and F.mse_loss are special cases as they
        # combine the term and reduction steps. We pass the reduction
        # string directly to them.
        return self._term_f_torch(input, target, reduction=self.reduction)