# Source code for gunz_cm.preprocs.log_scaler

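"""
Log(1 + v) scaling of dense and sparse matrices.

The public entry point is :func:`log_scale_matrix`. It accepts a NumPy array
or a SciPy COO/CSR matrix, can zero (dense) or remove (sparse) the diagonal,
and can optionally operate in place.

Examples
--------
>>> import numpy as np
>>> from gunz_cm.preprocs.log_scaler import log_scale_matrix
>>> scaled = log_scale_matrix(np.array([[0.0, 1.0], [3.0, 0.0]]))
"""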
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
import functools
import typing as t
from pydantic import validate_call, ConfigDict
from gunz_cm.exceptions import PreprocError
import numpy as np
from scipy.sparse import coo_matrix, csr_matrix, issparse
from ..utils.matrix import _non_diagonal_mask


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
@functools.singledispatch
def log_scale_matrix(
    matrix: t.Union[np.ndarray, coo_matrix, csr_matrix],
    exclude_diagonal: bool = False,
    inplace: bool = False
) -> t.Union[np.ndarray, coo_matrix, csr_matrix]:
    """
    Apply a log(1 + v) transform to a dense or sparse matrix.

    Notes
    -----
    This is the generic fallback of a ``functools.singledispatch`` function.
    The actual work is done by the implementations registered in this module
    for ``np.ndarray``, ``coo_matrix`` and ``csr_matrix``; dispatch happens on
    the type of ``matrix``. Arguments are checked by pydantic's
    ``validate_call`` before dispatch takes place.

    With ``exclude_diagonal=True`` the diagonal is set to zero for dense input
    and its stored entries are removed for sparse input.

    Parameters
    ----------
    matrix : Union[np.ndarray, coo_matrix, csr_matrix]
        Input matrix for log scaling.
    exclude_diagonal : bool, optional
        Zero the diagonal (dense) or remove its entries (sparse),
        default False.
    inplace : bool, optional
        Modify ``matrix`` in place instead of creating a new one,
        default False.

    Returns
    -------
    Union[np.ndarray, coo_matrix, csr_matrix]
        Log-scaled matrix (the input object itself if ``inplace=True``).

    Raises
    ------
    PreprocError
        If no implementation is registered for the type of ``matrix``.
    """
    # Only reached when dispatch found no type-specific implementation.
    unsupported = type(matrix).__name__
    raise PreprocError(f"No implementation for data type: {unsupported}")
@log_scale_matrix.register(np.ndarray)
def _(
    matrix: np.ndarray,
    exclude_diagonal: bool = False,
    inplace: bool = False,
    **kwargs,
) -> np.ndarray:
    """
    Log(1 + v) scaling of a dense matrix.

    Notes
    -----
    With ``inplace=True`` the result is written into ``matrix`` itself
    (``_validate_inputs`` requires a writeable floating-point buffer for
    that). Otherwise a new ``float64`` array is returned and ``matrix`` is
    left untouched.

    Parameters
    ----------
    matrix : np.ndarray
        Input dense matrix.
    exclude_diagonal : bool
        Set the diagonal of the result to zero if True.
    inplace : bool
        Modify ``matrix`` in place if True.
    **kwargs
        Accepted for call compatibility; not used.

    Returns
    -------
    np.ndarray
        Log-scaled dense matrix (``matrix`` itself when ``inplace=True``).

    Raises
    ------
    PreprocError
        If ``_validate_inputs`` rejects the arguments.
    """
    _validate_inputs(matrix, exclude_diagonal, inplace)

    # Write either into the caller's buffer or into a fresh float64 buffer.
    if inplace:
        scaled = matrix
    else:
        scaled = np.empty_like(matrix, dtype=np.float64)

    np.log1p(matrix, out=scaled)
    if exclude_diagonal:
        np.fill_diagonal(scaled, 0.0)
    return scaled


@log_scale_matrix.register(coo_matrix)
@log_scale_matrix.register(csr_matrix)
def _(
    matrix: t.Union[coo_matrix, csr_matrix],
    exclude_diagonal: bool = False,
    inplace: bool = False,
    **kwargs,
) -> t.Union[coo_matrix, csr_matrix]:
    """
    Log(1 + v) scaling of a sparse (COO or CSR) matrix.

    Notes
    -----
    Only the stored entries are transformed. With ``exclude_diagonal=True``
    a new matrix without the diagonal entries is built (this cannot be
    combined with ``inplace=True``). With ``inplace=True`` the ``data``
    buffer of ``matrix`` is overwritten. Otherwise a new matrix is returned
    that owns a new ``data`` array but reuses the index arrays of ``matrix``
    without copying them.

    Integer-valued data is transformed into ``float64`` on the
    non-in-place paths.

    Parameters
    ----------
    matrix : Union[coo_matrix, csr_matrix]
        Input sparse matrix.
    exclude_diagonal : bool
        Remove diagonal entries if True.
    inplace : bool
        Modify ``matrix`` in place if True.
    **kwargs
        Accepted for call compatibility; not used.

    Returns
    -------
    Union[coo_matrix, csr_matrix]
        Log-scaled sparse matrix (``matrix`` itself when ``inplace=True``).

    Raises
    ------
    PreprocError
        If ``_validate_inputs`` rejects the arguments.
    """
    _validate_inputs(matrix, exclude_diagonal, inplace)

    if exclude_diagonal:
        return _process_sparse_exclude_diag(matrix)

    if inplace:
        # Overwrite the stored values directly; the sparsity pattern is kept.
        np.log1p(matrix.data, out=matrix.data)
        return matrix

    data = _log1p_copy(matrix.data)
    if isinstance(matrix, coo_matrix):
        return coo_matrix((data, (matrix.row, matrix.col)), shape=matrix.shape)
    return matrix.__class__(
        (data, matrix.indices, matrix.indptr), shape=matrix.shape
    )


def _log1p_copy(values: np.ndarray) -> np.ndarray:
    """
    Return ``log1p(values)`` as a new array with a dtype that can hold it.

    Notes
    -----
    Floating-point and complex inputs keep their dtype. Every other dtype
    (e.g. integer counts) is converted to ``float64`` first, because writing
    the floating-point result of ``log1p`` into an integer buffer is rejected
    by NumPy's casting rules.

    Parameters
    ----------
    values : np.ndarray
        Values to transform; never modified.

    Returns
    -------
    np.ndarray
        Newly allocated array holding ``log1p(values)``.
    """
    if np.issubdtype(values.dtype, np.inexact):
        return np.log1p(values)
    converted = values.astype(np.float64)
    np.log1p(converted, out=converted)
    return converted


def _process_sparse_exclude_diag(
    matrix: t.Union[coo_matrix, csr_matrix]
) -> t.Union[coo_matrix, csr_matrix]:
    """
    Log-scale a sparse matrix while dropping its diagonal entries.

    Notes
    -----
    The entries to keep are selected with ``_non_diagonal_mask`` (imported
    from ``..utils.matrix``).
    NOTE(review): presumably that helper returns a boolean array over the
    stored entries that is True for off-diagonal ones; confirm against its
    implementation.

    Parameters
    ----------
    matrix : Union[coo_matrix, csr_matrix]
        Input sparse matrix.

    Returns
    -------
    Union[coo_matrix, csr_matrix]
        New sparse matrix holding the log-scaled selected entries.
    """
    mask = _non_diagonal_mask(matrix)
    # Select once, then transform; the selection is already a private copy.
    data = _log1p_copy(matrix.data[mask])
    return _rebuild_sparse(matrix, data, mask)


def _rebuild_sparse(
    matrix: t.Union[coo_matrix, csr_matrix],
    data: np.ndarray,
    mask: np.ndarray
) -> t.Union[coo_matrix, csr_matrix]:
    """
    Build a sparse matrix from the masked entries of ``matrix``.

    Parameters
    ----------
    matrix : Union[coo_matrix, csr_matrix]
        Original sparse matrix providing the index structure and shape.
    data : np.ndarray
        Processed values for the entries selected by ``mask``.
    mask : np.ndarray
        Mask over the stored entries of ``matrix`` selecting those to keep.

    Returns
    -------
    Union[coo_matrix, csr_matrix]
        Reconstructed sparse matrix: COO for COO input, CSR otherwise.
    """
    if isinstance(matrix, coo_matrix):
        return coo_matrix(
            (data, (matrix.row[mask], matrix.col[mask])),
            shape=matrix.shape,
            copy=False
        )

    new_indices = matrix.indices[mask]

    # Count the kept entries per row by summing the mask through a CSR matrix
    # that shares the original index structure. This avoids expanding
    # per-entry row indices with np.repeat.
    kept_per_row = csr_matrix(
        (mask.astype(np.int32), matrix.indices, matrix.indptr),
        shape=matrix.shape
    ).sum(axis=1)

    # Flatten the (n_rows, 1) result and accumulate it into the row pointer.
    new_row_counts = np.asarray(kept_per_row).flatten()
    new_indptr = np.concatenate([[0], np.cumsum(new_row_counts)])

    return csr_matrix(
        (data, new_indices, new_indptr),
        shape=matrix.shape,
        copy=False
    )


def _validate_inputs(
    matrix: t.Union[np.ndarray, coo_matrix, csr_matrix],
    exclude_diag: bool,
    inplace: bool
) -> None:
    """
    Validate the arguments of a log-scaling call.

    Notes
    -----
    Checks, in order: the matrix is not empty (no stored entries for sparse,
    zero size for dense); no value is below -1; the matrix is square when the
    diagonal is to be excluded; and, for in-place operation, that the target
    buffer is writeable and has a dtype able to hold the result. A value of
    exactly -1 passes the range check.

    Parameters
    ----------
    matrix : Union[np.ndarray, coo_matrix, csr_matrix]
        Input matrix.
    exclude_diag : bool
        Exclude diagonal if True.
    inplace : bool
        Modify matrix in-place if True.

    Returns
    -------
    None

    Raises
    ------
    PreprocError
        If any of the checks above fails.
    """
    sparse = issparse(matrix)

    # Matrix emptiness check
    if (sparse and matrix.nnz == 0) or (not sparse and matrix.size == 0):
        raise PreprocError("Cannot process empty matrix")

    # Value range validation (only stored entries for sparse input)
    data = matrix.data if sparse else matrix
    if np.any(data < -1):
        raise PreprocError("Matrix contains values < -1 which cannot be log-scaled")

    # Square matrix validation for diagonal operations; arrays with fewer
    # than two dimensions have no diagonal and are reported the same way.
    if exclude_diag and (
        len(matrix.shape) < 2 or matrix.shape[0] != matrix.shape[1]
    ):
        raise PreprocError(
            f"Matrix must be square for diagonal exclusion. Got shape {matrix.shape}"
        )

    if not inplace:
        return

    # In-place compatibility checks
    if sparse:
        if exclude_diag:
            raise PreprocError("Cannot modify sparse matrix in-place when excluding diagonal")
        if not matrix.data.flags.writeable:
            raise PreprocError("Sparse matrix data buffer is not writeable")
        # Integer/bool buffers cannot receive the floating-point result.
        if not np.issubdtype(matrix.data.dtype, np.inexact):
            raise PreprocError("In-place operations require floating-point dtype")
    else:
        if not matrix.flags.writeable:
            raise PreprocError("Dense matrix buffer is not writeable")
        if not np.issubdtype(matrix.dtype, np.floating):
            raise PreprocError("In-place operations require floating-point dtype")