# Source code for gunz_cm.preprocs.band_matrix

# -*- coding: utf-8 -*-
"""
Module for creating a band matrix from various data structures.

This module provides a polymorphic function, `create_band_matrix`, that
filters a matrix to retain only the elements within a specified distance
from the main diagonal.


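Examples
--------
>>> import numpy as np
>>> from gunz_cm.preprocs.band_matrix import create_band_matrix
>>> create_band_matrix(np.ones((3, 3)), max_k=0)
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]])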
"""
__author__ = "Yeremia Gunawan Adhisantoso"

# =============================================================================
# METADATA
# =============================================================================
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.4.1"

# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import functools
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
from gunz_cm.exceptions import PreprocError
import numpy as np
import pandas as pd
from pydantic import ConfigDict, validate_call
from scipy import sparse as sp

# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from ..consts import DataFrameSpecs
from .commons import _create_diag_mask_helper


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
@functools.singledispatch
def create_band_matrix(
    matrix: t.Union[np.ndarray, sp.coo_matrix, pd.DataFrame],
    max_k: t.Optional[int] = None,
    remove_main_diag: bool = False,
    *,
    row_ids_colname: str = DataFrameSpecs.ROW_IDS,
    col_ids_colname: str = DataFrameSpecs.COL_IDS,
) -> t.Union[np.ndarray, sp.coo_matrix, pd.DataFrame]:
    """
    Restrict a matrix to a band around its main diagonal.

    An element at (row, col) survives only if ``abs(col - row) <= max_k``.
    The concrete behaviour is chosen by ``functools.singledispatch`` from the
    type of ``matrix``:

    * ``np.ndarray`` (must be 2-D): out-of-band entries are set to zero in a
      new array of the same shape.
    * ``sp.coo_matrix``: out-of-band stored entries are dropped from a new
      COO matrix of the same shape.
    * ``pd.DataFrame``: rows whose row/column ID values lie outside the band
      are dropped from a copy of the frame.

    Arguments are validated by pydantic's ``validate_call`` before dispatch.

    Parameters
    ----------
    matrix : np.ndarray, sp.coo_matrix, or pd.DataFrame
        Data to filter.
    max_k : int, optional
        Largest distance from the main diagonal that is kept. ``None``
        (the default) disables distance filtering.
    remove_main_diag : bool, optional
        When True, the main diagonal (distance 0) is discarded as well.
        Defaults to False.
    row_ids_colname : str, optional
        Name of the row-ID column; only used for DataFrame input.
    col_ids_colname : str, optional
        Name of the column-ID column; only used for DataFrame input.

    Returns
    -------
    np.ndarray, sp.coo_matrix, or pd.DataFrame
        A new object of the same type as ``matrix`` holding only the
        in-band elements.

    Raises
    ------
    PreprocError
        If no implementation is registered for ``type(matrix)``, or if a
        type-specific implementation rejects its input (e.g. a non-2-D
        array or a DataFrame missing the ID columns). Because
        ``validate_call`` checks ``matrix`` against the annotated union,
        other types are likely rejected by pydantic before this fallback
        runs.

    Examples
    --------
    >>> import numpy as np
    >>> create_band_matrix(np.arange(1, 10).reshape(3, 3), max_k=1)
    array([[1, 2, 0],
           [4, 5, 6],
           [0, 8, 9]])
    """
    # Fallback for types that have no registered implementation.
    unsupported_type_name = type(matrix).__name__
    raise PreprocError(f"No implementation for data type: {unsupported_type_name}")
@create_band_matrix.register(np.ndarray)
def _create_band_matrix_np(
    matrix: np.ndarray,
    max_k: t.Optional[int] = None,
    remove_main_diag: bool = False,
    **kwargs,
) -> np.ndarray:
    """
    NumPy array-specific implementation for creating a band matrix.

    Parameters
    ----------
    matrix : np.ndarray
        2-D array to filter.
    max_k : int, optional
        Keep only entries with ``abs(col - row) <= max_k``. ``None`` disables
        the distance filter.
    remove_main_diag : bool, optional
        If True, the main diagonal is set to zero.
    **kwargs
        Ignored. Absorbs the DataFrame-only column-name options the
        dispatcher forwards.

    Returns
    -------
    np.ndarray
        A new array of the same shape and dtype with out-of-band entries
        set to zero. The input is never modified.

    Raises
    ------
    PreprocError
        If ``matrix`` is not 2-dimensional.
    """
    # Defaults mirror create_band_matrix so this implementation also works
    # when called directly (or via ``create_band_matrix.dispatch``) instead
    # of relying on the validate_call wrapper to forward every default.
    if matrix.ndim != 2:
        raise PreprocError("Input NumPy array must be 2-dimensional.")

    if max_k is None:
        # No distance filter: at most the main diagonal is cleared below.
        band_matrix = matrix.copy()
    else:
        # abs(col - row) <= max_k  <=>  row - max_k <= col <= row + max_k.
        #   np.triu(k=-max_k) keeps col >= row - max_k
        #   np.tril(k=+max_k) keeps col <= row + max_k
        # Both return new arrays. They still allocate a full-size boolean
        # mask and output internally, so this is not allocation-free; it
        # only avoids the much larger int64 index grids of np.indices.
        band_matrix = np.tril(np.triu(matrix, k=-max_k), k=max_k)

    if remove_main_diag:
        # band_matrix is always a fresh array here, so in-place is safe.
        np.fill_diagonal(band_matrix, 0)
    return band_matrix


@create_band_matrix.register(sp.coo_matrix)
def _create_band_matrix_coo(
    matrix: sp.coo_matrix,
    max_k: t.Optional[int] = None,
    remove_main_diag: bool = False,
    **kwargs,
) -> sp.coo_matrix:
    """
    COO-matrix-specific implementation for creating a band matrix.

    Parameters
    ----------
    matrix : sp.coo_matrix
        Sparse matrix to filter.
    max_k : int, optional
        Maximum distance from the main diagonal to keep. ``None`` disables
        the distance filter.
    remove_main_diag : bool, optional
        If True, stored entries on the main diagonal are dropped.
    **kwargs
        Ignored. Absorbs the DataFrame-only column-name options the
        dispatcher forwards.

    Returns
    -------
    sp.coo_matrix
        A new COO matrix with the same shape, containing only the stored
        entries that pass the filter. The input is not modified.
    """
    # Nothing to filter: either no stored entries or no filter requested.
    if matrix.nnz == 0 or (max_k is None and not remove_main_diag):
        return matrix.copy()

    if max_k is None:
        # Only reachable with remove_main_diag=True: drop diagonal entries.
        mask = matrix.row != matrix.col
    else:
        # The helper's result is used below as a keep-mask over the stored
        # entries (True = keep).
        mask = _create_diag_mask_helper(
            matrix.row,
            matrix.col,
            abs_k=True,
            min_k=None,
            max_k=max_k,
            remove_main_diag=remove_main_diag,
        )

    if np.all(mask):
        # Every entry survives; skip rebuilding the matrix.
        return matrix.copy()

    return sp.coo_matrix(
        (matrix.data[mask], (matrix.row[mask], matrix.col[mask])),
        shape=matrix.shape,
    )


@create_band_matrix.register(pd.DataFrame)
def _create_band_matrix_df(
    matrix: pd.DataFrame,
    max_k: t.Optional[int] = None,
    remove_main_diag: bool = False,
    *,
    row_ids_colname: str = DataFrameSpecs.ROW_IDS,
    col_ids_colname: str = DataFrameSpecs.COL_IDS,
) -> pd.DataFrame:
    """
    DataFrame-specific implementation for creating a band matrix.

    Rows are kept or dropped by comparing the values in the row-ID and
    column-ID columns (not the DataFrame index).

    Parameters
    ----------
    matrix : pd.DataFrame
        Frame with one row per matrix entry.
    max_k : int, optional
        Maximum distance between row and column IDs to keep. ``None``
        disables the distance filter.
    remove_main_diag : bool, optional
        If True, rows whose row ID equals their column ID are dropped.
    row_ids_colname : str, optional
        Name of the row-ID column.
    col_ids_colname : str, optional
        Name of the column-ID column.

    Returns
    -------
    pd.DataFrame
        A copy of ``matrix`` restricted to the rows that pass the filter.

    Raises
    ------
    PreprocError
        If either ID column is missing from ``matrix``.
    """
    required_cols = [row_ids_colname, col_ids_colname]
    missing = [col for col in required_cols if col not in matrix.columns]
    if missing:
        raise PreprocError(f"Missing required columns: {', '.join(missing)}")

    # Nothing to filter: either no rows or no filter requested.
    if matrix.empty or (max_k is None and not remove_main_diag):
        return matrix.copy()

    row_ids = matrix[row_ids_colname]
    col_ids = matrix[col_ids_colname]
    if max_k is None:
        # Only reachable with remove_main_diag=True: drop diagonal rows.
        mask = row_ids != col_ids
    else:
        # `.values` (not `.to_numpy()`) is kept deliberately so extension
        # dtypes reach the helper exactly as before.
        mask = _create_diag_mask_helper(
            row_ids.values,
            col_ids.values,
            abs_k=True,
            min_k=None,
            max_k=max_k,
            remove_main_diag=remove_main_diag,
        )

    return matrix.loc[mask].copy()