# Source code for gunz_cm.preprocs.count_filters

# -*- coding: utf-8 -*-
"""
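Filter matrix data by raw interaction counts.

This module combines ``functools.singledispatch`` with Pydantic's
``validate_call`` to expose a single, validated ``filter_by_raw_counts``
function that accepts pandas DataFrames, SciPy COO/CSR matrices, and
dense NumPy arrays.

Examples
--------
>>> import numpy as np
>>> filter_by_raw_counts(np.array([[1, 5], [10, 3]]), min_val=2, max_val=9)
array([[0, 5],
       [0, 3]])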
"""
__author__ = "Yeremia Gunawan Adhisantoso"

# =============================================================================
# METADATA
# =============================================================================
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.2.3"


# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import functools
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
from gunz_cm.exceptions import PreprocError
import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype
from pydantic import ConfigDict, validate_call
from scipy import sparse as sp

# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from ..consts import DataFrameSpecs


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
@functools.singledispatch
def filter_by_raw_counts(
    matrix: t.Union[pd.DataFrame, sp.coo_matrix, sp.csr_matrix, np.ndarray],
    min_val: t.Optional[int] = None,
    max_val: t.Optional[int] = None,
    *,
    raw_counts_colname: str = DataFrameSpecs.RAW_COUNTS,
) -> t.Union[pd.DataFrame, sp.coo_matrix, sp.csr_matrix, np.ndarray]:
    """
    Keep only the entries whose raw count lies in ``[min_val, max_val]``.

    Arguments are first validated by Pydantic (``validate_call``); the call
    is then routed by ``functools.singledispatch`` to the implementation
    registered for ``type(matrix)``.

    Parameters
    ----------
    matrix : pd.DataFrame, sp.coo_matrix, sp.csr_matrix, or np.ndarray
        Data to filter. DataFrame rows and stored sparse entries that fall
        outside the range are dropped; dense NumPy entries outside the
        range are set to 0 instead.
    min_val : int, optional
        Inclusive lower bound on the raw count. ``None`` (default) means
        no lower bound.
    max_val : int, optional
        Inclusive upper bound on the raw count. ``None`` (default) means
        no upper bound.
    raw_counts_colname : str, optional
        Name of the raw-count column; only consulted for DataFrame input.
        Defaults to ``DataFrameSpecs.RAW_COUNTS``.

    Returns
    -------
    pd.DataFrame, sp.coo_matrix, sp.csr_matrix, or np.ndarray
        The filtered data, of the same type as the input.

    Raises
    ------
    pydantic.ValidationError
        If an argument cannot be validated against its annotated type.
    PreprocError
        If ``min_val > max_val``, if ``raw_counts_colname`` is missing
        from a DataFrame or is not numeric, or if no implementation is
        registered for ``type(matrix)``.
    """
    # Reached only when singledispatch finds no registered implementation.
    unsupported = type(matrix).__name__
    raise PreprocError(f"No implementation for data type: {unsupported}")
def _check_bounds(min_val: t.Optional[int], max_val: t.Optional[int]) -> None:
    """Raise PreprocError if both bounds are given and ``min_val > max_val``."""
    if min_val is not None and max_val is not None and min_val > max_val:
        raise PreprocError(f"min_val ({min_val}) cannot be greater than max_val ({max_val}).")


def _in_range_mask(
    values: np.ndarray,
    min_val: t.Optional[int],
    max_val: t.Optional[int],
) -> np.ndarray:
    """Return a boolean array, shaped like ``values``, marking elements in ``[min_val, max_val]``."""
    mask = np.ones(values.shape, dtype=bool)
    if min_val is not None:
        mask &= values >= min_val
    if max_val is not None:
        mask &= values <= max_val
    return mask


@filter_by_raw_counts.register(pd.DataFrame)
def _filter_by_raw_counts_df(
    matrix: pd.DataFrame,
    min_val: t.Optional[int] = None,
    max_val: t.Optional[int] = None,
    *,
    raw_counts_colname: str = DataFrameSpecs.RAW_COUNTS,
) -> pd.DataFrame:
    """
    DataFrame implementation: drop rows whose raw count is out of range.

    Parameters
    ----------
    matrix : pd.DataFrame
        Input table containing a ``raw_counts_colname`` column.
    min_val, max_val : int, optional
        Inclusive bounds; ``None`` disables the corresponding bound.
    raw_counts_colname : str, optional
        Column holding the raw counts.

    Returns
    -------
    pd.DataFrame
        A copy containing only the rows within the range.

    Raises
    ------
    PreprocError
        If the column is missing, is not numeric, or ``min_val > max_val``.
    """
    if raw_counts_colname not in matrix.columns:
        raise PreprocError(f"Column '{raw_counts_colname}' not found in DataFrame.")
    counts = matrix[raw_counts_colname]
    if not is_numeric_dtype(counts):
        raise PreprocError(f"Column '{raw_counts_colname}' must be a numeric type.")
    _check_bounds(min_val, max_val)

    if matrix.empty:
        return matrix.copy()

    # Stay in pandas here (rather than _in_range_mask) so nullable dtypes
    # keep pandas' own comparison semantics.
    mask = pd.Series(True, index=matrix.index)
    if min_val is not None:
        mask &= counts >= min_val
    if max_val is not None:
        mask &= counts <= max_val
    return matrix.loc[mask].copy()


@filter_by_raw_counts.register(sp.coo_matrix)
def _filter_by_raw_counts_coo(
    matrix: sp.coo_matrix,
    min_val: t.Optional[int] = None,
    max_val: t.Optional[int] = None,
    *,
    raw_counts_colname: str = DataFrameSpecs.RAW_COUNTS,
) -> sp.coo_matrix:
    """
    COO implementation: drop stored entries whose value is out of range.

    ``raw_counts_colname`` is accepted for signature compatibility and ignored.

    Returns
    -------
    sp.coo_matrix
        A new matrix with the same shape; it never aliases ``matrix``.

    Raises
    ------
    PreprocError
        If ``min_val > max_val``.
    """
    _check_bounds(min_val, max_val)

    mask = _in_range_mask(matrix.data, min_val, max_val)
    if mask.all():
        # Nothing to remove (this also covers nnz == 0). Still return a copy
        # so callers can mutate the result without touching the input.
        return matrix.copy()
    return sp.coo_matrix(
        (matrix.data[mask], (matrix.row[mask], matrix.col[mask])),
        shape=matrix.shape,
    )


@filter_by_raw_counts.register(sp.csr_matrix)
def _filter_by_raw_counts_csr(
    matrix: sp.csr_matrix,
    min_val: t.Optional[int] = None,
    max_val: t.Optional[int] = None,
    *,
    raw_counts_colname: str = DataFrameSpecs.RAW_COUNTS,
) -> sp.csr_matrix:
    """
    CSR implementation: filter via the COO implementation and convert back.

    ``raw_counts_colname`` is accepted for signature compatibility and ignored.

    Returns
    -------
    sp.csr_matrix
        A new matrix with the same shape; it never aliases ``matrix``.

    Raises
    ------
    PreprocError
        If ``min_val > max_val``.
    """
    # Validate before paying for the format conversion.
    _check_bounds(min_val, max_val)

    coo_equiv = matrix.tocoo()
    # Call the COO implementation directly instead of re-entering the
    # validated public dispatcher.
    filtered_coo = _filter_by_raw_counts_coo(
        coo_equiv, min_val, max_val, raw_counts_colname=raw_counts_colname
    )
    if filtered_coo.nnz < coo_equiv.nnz:
        return filtered_coo.tocsr()
    return matrix.copy()


@filter_by_raw_counts.register(np.ndarray)
def _filter_by_raw_counts_np(
    matrix: np.ndarray,
    min_val: t.Optional[int] = None,
    max_val: t.Optional[int] = None,
    *,
    raw_counts_colname: str = DataFrameSpecs.RAW_COUNTS,
) -> np.ndarray:
    """
    Dense NumPy implementation: zero out values outside the range.

    The array keeps its shape; out-of-range elements of a copy are set to 0.
    ``raw_counts_colname`` is accepted for signature compatibility and ignored.

    Raises
    ------
    PreprocError
        If ``min_val > max_val``.
    """
    _check_bounds(min_val, max_val)

    filtered_matrix = matrix.copy()
    filtered_matrix[~_in_range_mask(matrix, min_val, max_val)] = 0
    return filtered_matrix