# -*- coding: utf-8 -*-
"""
Provides functionality to filter matrix data based on raw counts.
This module uses functools.singledispatch and Pydantic's validate_call
to provide a single, robust `filter_by_raw_counts` function that can
handle multiple data types safely.
Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
# =============================================================================
# METADATA
# =============================================================================
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.2.3"
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import functools
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
from gunz_cm.exceptions import PreprocError
import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype
from pydantic import ConfigDict, validate_call
from scipy import sparse as sp
# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from ..consts import DataFrameSpecs
[docs]
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
@functools.singledispatch
def filter_by_raw_counts(
    matrix: t.Union[pd.DataFrame, sp.coo_matrix, sp.csr_matrix, np.ndarray],
    min_val: t.Optional[int] = None,
    max_val: t.Optional[int] = None,
    *,
    raw_counts_colname: str = DataFrameSpecs.RAW_COUNTS,
) -> t.Union[pd.DataFrame, sp.coo_matrix, sp.csr_matrix, np.ndarray]:
    """
    Keep only the matrix entries whose raw count lies in ``[min_val, max_val]``.

    This is the validated, single-dispatch entry point. Pydantic checks the
    argument types first; ``functools.singledispatch`` then selects the
    implementation registered for the runtime type of ``matrix``. The body
    below is only the fallback that runs when no implementation is
    registered for that type.

    Parameters
    ----------
    matrix : pd.DataFrame, sp.coo_matrix, sp.csr_matrix, or np.ndarray
        The data to filter. Must be passed positionally, because single
        dispatch selects the implementation from the first positional
        argument. DataFrames and sparse matrices have out-of-range entries
        removed; NumPy arrays have out-of-range values set to 0.
    min_val : int, optional
        Inclusive lower bound. ``None`` (the default) means no lower bound.
    max_val : int, optional
        Inclusive upper bound. ``None`` (the default) means no upper bound.
    raw_counts_colname : str, optional
        Name of the raw-count column. Only consulted for DataFrame input.
        Defaults to ``DataFrameSpecs.RAW_COUNTS``.

    Returns
    -------
    pd.DataFrame, sp.coo_matrix, sp.csr_matrix, or np.ndarray
        The filtered data, of the same kind as the input.

    Raises
    ------
    pydantic.ValidationError
        If an argument does not match its annotated type.
    PreprocError
        If no implementation is registered for the type of ``matrix``.
        The registered implementations also raise it when
        ``min_val`` > ``max_val`` and, for DataFrames, when
        ``raw_counts_colname`` is missing or not numeric.
    """
    # Fallback branch: reached only when dispatch found no registered handler.
    message = f"No implementation for data type: {type(matrix).__name__}"
    raise PreprocError(message)
@filter_by_raw_counts.register(pd.DataFrame)
def _filter_by_raw_counts_df(
    matrix: pd.DataFrame,
    min_val: t.Optional[int],
    max_val: t.Optional[int],
    *,
    raw_counts_colname: str,
) -> pd.DataFrame:
    """
    Filter the rows of a DataFrame by the value of its raw-count column.

    Parameters
    ----------
    matrix : pd.DataFrame
        Table holding a raw-count column named ``raw_counts_colname``.
    min_val : int or None
        Inclusive lower bound; ``None`` disables the lower check.
    max_val : int or None
        Inclusive upper bound; ``None`` disables the upper check.
    raw_counts_colname : str
        Name of the column to compare against the bounds.

    Returns
    -------
    pd.DataFrame
        A copy containing only the rows whose raw count is within range.

    Raises
    ------
    PreprocError
        If the column is missing, is not numeric, or if
        ``min_val`` > ``max_val``.
    """
    if raw_counts_colname not in matrix.columns:
        raise PreprocError(f"Column '{raw_counts_colname}' not found in DataFrame.")

    counts = matrix[raw_counts_colname]
    if not is_numeric_dtype(counts):
        raise PreprocError(f"Column '{raw_counts_colname}' must be a numeric type.")

    has_lower = min_val is not None
    has_upper = max_val is not None
    if has_lower and has_upper and min_val > max_val:
        raise PreprocError(f"min_val ({min_val}) cannot be greater than max_val ({max_val}).")

    # Nothing to filter; still hand back a copy so the caller never
    # receives the input object itself.
    if matrix.empty:
        return matrix.copy()

    # Start by keeping every row, then narrow with each bound that is set.
    keep = pd.Series(True, index=matrix.index)
    if has_lower:
        keep &= counts >= min_val
    if has_upper:
        keep &= counts <= max_val

    return matrix.loc[keep].copy()
@filter_by_raw_counts.register(sp.coo_matrix)
def _filter_by_raw_counts_coo(
    matrix: sp.coo_matrix,
    min_val: t.Optional[int],
    max_val: t.Optional[int],
    *,
    raw_counts_colname: str,
) -> sp.coo_matrix:
    """
    Filter the stored entries of a COO matrix by value.

    Each stored entry is compared individually against the bounds; entries
    outside ``[min_val, max_val]`` are dropped from the sparse structure.
    The matrix shape is preserved.

    Parameters
    ----------
    matrix : sp.coo_matrix
        Sparse matrix whose stored values are the raw counts.
    min_val : int or None
        Inclusive lower bound; ``None`` disables the lower check.
    max_val : int or None
        Inclusive upper bound; ``None`` disables the upper check.
    raw_counts_colname : str
        Unused here; accepted so every registered implementation shares
        the same signature.

    Returns
    -------
    sp.coo_matrix
        A new matrix with only the in-range entries. The input is never
        returned itself, so mutating the result cannot alter the input.

    Raises
    ------
    PreprocError
        If ``min_val`` > ``max_val``.
    """
    if min_val is not None and max_val is not None and min_val > max_val:
        raise PreprocError(f"min_val ({min_val}) cannot be greater than max_val ({max_val}).")

    # No stored entries: nothing can be removed, but still return a fresh
    # object to match the copy semantics of the DataFrame/ndarray variants.
    if matrix.nnz == 0:
        return matrix.copy()

    mask = np.ones_like(matrix.data, dtype=bool)
    if min_val is not None:
        mask &= matrix.data >= min_val
    if max_val is not None:
        mask &= matrix.data <= max_val

    # Every entry survived: copy instead of aliasing the caller's matrix.
    if mask.all():
        return matrix.copy()

    # Boolean indexing yields new arrays, so the result shares no buffers
    # with the input.
    return sp.coo_matrix(
        (matrix.data[mask], (matrix.row[mask], matrix.col[mask])),
        shape=matrix.shape,
    )
@filter_by_raw_counts.register(sp.csr_matrix)
def _filter_by_raw_counts_csr(
    matrix: sp.csr_matrix,
    min_val: t.Optional[int],
    max_val: t.Optional[int],
    *,
    raw_counts_colname: str,
) -> sp.csr_matrix:
    """
    Filter the stored entries of a CSR matrix by value.

    The matrix is converted to COO, filtered through
    ``filter_by_raw_counts`` and converted back to CSR when at least one
    entry was removed.

    Parameters
    ----------
    matrix : sp.csr_matrix
        Sparse matrix whose stored values are the raw counts.
    min_val : int or None
        Inclusive lower bound; ``None`` disables the lower check.
    max_val : int or None
        Inclusive upper bound; ``None`` disables the upper check.
    raw_counts_colname : str
        Unused here; accepted so every registered implementation shares
        the same signature.

    Returns
    -------
    sp.csr_matrix
        A new CSR matrix with only the in-range entries. The input is
        never returned itself, so mutating the result cannot alter the
        input.

    Raises
    ------
    PreprocError
        If ``min_val`` > ``max_val``.
    """
    # Fail fast on invalid bounds, before paying for the format conversion.
    if min_val is not None and max_val is not None and min_val > max_val:
        raise PreprocError(f"min_val ({min_val}) cannot be greater than max_val ({max_val}).")

    coo_equiv = matrix.tocoo()
    filtered_coo = filter_by_raw_counts(coo_equiv, min_val, max_val)

    # Compare stored-entry counts rather than object identity so this works
    # whether or not the COO path hands back the same object.
    if filtered_coo.nnz < coo_equiv.nnz:
        return filtered_coo.tocsr()

    # Nothing was dropped: copy instead of aliasing the caller's matrix.
    return matrix.copy()
@filter_by_raw_counts.register(np.ndarray)
def _filter_by_raw_counts_np(
    matrix: np.ndarray,
    min_val: t.Optional[int],
    max_val: t.Optional[int],
    *,
    raw_counts_colname: str,
) -> np.ndarray:
    """
    Zero out the values of a dense array that fall outside a range.

    Unlike the sparse and DataFrame variants, no entries are removed: the
    array keeps its shape and every value that is not within
    ``[min_val, max_val]`` is replaced by 0 in a copy of the input.

    Parameters
    ----------
    matrix : np.ndarray
        Dense array of raw counts.
    min_val : int or None
        Inclusive lower bound; ``None`` disables the lower check.
    max_val : int or None
        Inclusive upper bound; ``None`` disables the upper check.
    raw_counts_colname : str
        Unused here; accepted so every registered implementation shares
        the same signature.

    Returns
    -------
    np.ndarray
        A copy of ``matrix`` with out-of-range values set to 0.

    Raises
    ------
    PreprocError
        If ``min_val`` > ``max_val``.
    """
    lower_given = min_val is not None
    upper_given = max_val is not None
    if lower_given and upper_given and min_val > max_val:
        raise PreprocError(f"min_val ({min_val}) cannot be greater than max_val ({max_val}).")

    # Mark the positions to keep. The test is phrased as "inside the range"
    # and negated later, so values that fail both comparisons (e.g. NaN)
    # end up zeroed.
    in_range = np.ones(matrix.shape, dtype=bool)
    if lower_given:
        in_range &= (matrix >= min_val)
    if upper_given:
        in_range &= (matrix <= max_val)

    result = matrix.copy()
    result[~in_range] = 0
    return result