# -*- coding: utf-8 -*-
"""
Implements a general-purpose error function that is compatible with both
NumPy and PyTorch, applying specified term and reduction operations.
Examples
--------
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__license__ = "Clear BSD"
__version__ = "1.1.0"
__email__ = "adhisant@tnt.uni-hannover.de"
__all__ = [
"GeneralError"
]
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
from gunz_cm.exceptions import ReconstructionError
import numpy as np
import torch
# =============================================================================
# LOCAL IMPORTS
# =============================================================================
from .consts import AVAIL_NP_F, AVAIL_TORCH_F
from ._base_binary_op import BaseBinaryOP
# =============================================================================
# CONSTANTS
# =============================================================================
# Term names accepted by GeneralError. 'abs' and 'square' are aliases that
# GeneralError normalizes to 'l1' and 'l2' respectively.
VALID_TERM: t.Final[t.List[str]] = [
    "abs",
    "square",
    "l1",
    "l2",
]
# Reduction names accepted by GeneralError, applied after the element-wise term.
VALID_REDUCTION: t.Final[t.List[str]] = [
    "mean",
    "sum",
]
# =============================================================================
# CLASS IMPLEMENTATION
# =============================================================================
class GeneralError(BaseBinaryOP):
    """
    A general error function for computing the loss between an input and a
    target.

    The loss applies an element-wise term function ('l1' or 'l2') to the
    difference between the two operands and then aggregates the result with
    a reduction ('mean' or 'sum'). NumPy arrays are handled by `_call_numpy`
    and PyTorch tensors by `_call_torch`. Choosing between the two is
    presumably done by `BaseBinaryOP` (not shown here).

    Parameters
    ----------
    term : str, default='square'
        The term function to use. Valid options are 'abs', 'square', 'l1',
        'l2'. 'abs' is normalized to 'l1' and 'square' to 'l2'.
    reduction : str, default='mean'
        The reduction applied to the element-wise loss. Valid options are
        'mean' and 'sum'.

    Attributes
    ----------
    term : str
        The normalized term name, either 'l1' or 'l2'.
    reduction : str
        The reduction name, either 'mean' or 'sum'.

    Raises
    ------
    ReconstructionError
        If an invalid `term` or `reduction` is specified.

    Notes
    -----
    For PyTorch, the term callable taken from `AVAIL_TORCH_F` (intended to be
    `F.l1_loss` / `F.mse_loss`) receives the reduction string directly. Term
    and reduction are therefore applied in a single fused call. For NumPy,
    the term and reduction callables taken from `AVAIL_NP_F` are applied one
    after the other.

    Examples
    --------
    >>> GeneralError()
    GeneralError(term='l2', reduction='mean')
    >>> GeneralError(term='abs', reduction='sum')
    GeneralError(term='l1', reduction='sum')
    """

    # User-facing aliases and the canonical term name each one maps to.
    _TERM_ALIASES: t.ClassVar[t.Dict[str, str]] = {
        'abs': 'l1',
        'square': 'l2',
    }

    def __init__(
        self,
        term: str = 'square',
        reduction: str = 'mean'
    ):
        """
        Validate and store the configuration and resolve backend callables.

        Parameters
        ----------
        term : str, default='square'
            One of 'abs', 'square', 'l1', 'l2'. Aliases are normalized
            ('abs' -> 'l1', 'square' -> 'l2') before being stored.
        reduction : str, default='mean'
            One of 'mean', 'sum'.

        Raises
        ------
        ReconstructionError
            If `term` is not in `VALID_TERM` or `reduction` is not in
            `VALID_REDUCTION`.
        """
        super().__init__()
        if term not in VALID_TERM:
            raise ReconstructionError(
                f"Invalid term function: '{term}'. "
                f"Must be one of {VALID_TERM}.")
        if reduction not in VALID_REDUCTION:
            raise ReconstructionError(
                f"Invalid reduction function: '{reduction}'. "
                f"Must be one of {VALID_REDUCTION}.")

        # Normalize aliases so the backend tables only need 'l1' / 'l2' keys.
        self.term = self._TERM_ALIASES.get(term, term)
        self.reduction = reduction

        # Resolve the backend callables now, so a missing table entry fails at
        # construction instead of on the first call. NOTE(review): assumes
        # AVAIL_NP_F / AVAIL_TORCH_F are mappings (a miss would raise KeyError)
        # — confirm in .consts.
        self._term_f_np = AVAIL_NP_F[self.term]
        self._reduction_f_np = AVAIL_NP_F[self.reduction]
        self._term_f_torch = AVAIL_TORCH_F[self.term]

    def __repr__(self) -> str:
        """
        Return an unambiguous string representation for developers.

        The representation shows the normalized term name, so an alias
        passed to the constructor appears as its canonical form.

        Examples
        --------
        >>> GeneralError(term='square')
        GeneralError(term='l2', reduction='mean')
        """
        return (f"{self.__class__.__name__}("
                f"term='{self.term}', "
                f"reduction='{self.reduction}')")

    def _call_numpy(
        self,
        input: np.ndarray,
        target: np.ndarray
    ) -> float:
        """
        Compute the loss for NumPy arrays.

        The term function is applied element-wise to the difference of the
        operands, and the result is aggregated with the reduction function.

        Parameters
        ----------
        input : np.ndarray
            The predicted / reconstructed values.
        target : np.ndarray
            The reference values. Presumably the same shape as `input`, or
            broadcast-compatible with it.

        Returns
        -------
        float
            Whatever the configured NumPy reduction callable returns. For
            'mean' / 'sum' this is presumably a NumPy scalar.
        """
        # The subtraction order does not matter for a symmetric term
        # (abs / square), which is what 'l1' / 'l2' are documented to be.
        residual = target - input
        return self._reduction_f_np(self._term_f_np(residual))

    def _call_torch(
        self,
        input: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the loss for PyTorch tensors.

        The term callable is expected to be PyTorch's fused loss function
        (`l1_loss` / `mse_loss`), which applies the reduction internally.

        Parameters
        ----------
        input : torch.Tensor
            The predicted / reconstructed values.
        target : torch.Tensor
            The reference values.

        Returns
        -------
        torch.Tensor
            The loss as returned by the term callable. With 'mean' / 'sum'
            reduction, PyTorch's loss functions return a 0-dim tensor.
        """
        # The reduction string ('mean' / 'sum') is passed straight through;
        # these values match the `reduction` keyword of PyTorch's loss functions.
        return self._term_f_torch(input, target, reduction=self.reduction)