# -*- coding: utf-8 -*-
"""
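Create band matrices from dense, sparse, and tabular matrix representations.

This module provides a polymorphic function, `create_band_matrix`, that
filters a matrix to retain only the entries whose row and column indices
differ by at most a given distance (i.e. entries close to the main
diagonal). Implementations are registered for NumPy arrays, SciPy COO
matrices, and pandas DataFrames that list entries by row-ID and column-ID
columns.

Examples
--------
>>> import numpy as np
>>> create_band_matrix(np.arange(9).reshape(3, 3), max_k=1)
array([[0, 1, 0],
       [3, 4, 5],
       [0, 7, 8]])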
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.4.1"
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import functools
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
from gunz_cm.exceptions import PreprocError
import numpy as np
import pandas as pd
from pydantic import ConfigDict, validate_call
from scipy import sparse as sp
# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from ..consts import DataFrameSpecs
from .commons import _create_diag_mask_helper
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
@functools.singledispatch
def create_band_matrix(
    matrix: t.Union[np.ndarray, sp.coo_matrix, pd.DataFrame],
    max_k: t.Optional[int] = None,
    remove_main_diag: bool = False,
    *,
    row_ids_colname: str = DataFrameSpecs.ROW_IDS,
    col_ids_colname: str = DataFrameSpecs.COL_IDS,
) -> t.Union[np.ndarray, sp.coo_matrix, pd.DataFrame]:
    """
    Creates a band matrix by keeping elements near the main diagonal.

    This function filters a matrix to retain only the elements where the
    absolute difference between the row and column index is less than or
    equal to `max_k`. The concrete implementation is chosen by
    `functools.singledispatch` from the type of the first positional
    argument, so `matrix` should be passed positionally.

    Parameters
    ----------
    matrix : np.ndarray, sp.coo_matrix, or pd.DataFrame
        The input matrix to filter.
    max_k : int, optional
        The maximum distance from the main diagonal to keep. If None, all
        elements are kept (no filtering by distance). Defaults to None.
    remove_main_diag : bool, optional
        If True, elements on the main diagonal (k=0) are removed.
        Defaults to False.
    row_ids_colname : str, optional
        Column name for row IDs (for DataFrame input).
    col_ids_colname : str, optional
        Column name for column IDs (for DataFrame input).

    Returns
    -------
    np.ndarray, sp.coo_matrix, or pd.DataFrame
        A new matrix of the same type as the input, containing only the
        elements within the specified band. The input is not modified.

    Raises
    ------
    PreprocError
        If no implementation is registered for the type of `matrix`, or if
        the type-specific implementation rejects the input (a NumPy array
        that is not 2-D, or a DataFrame missing the row/column ID columns).

    Examples
    --------
    >>> import numpy as np
    >>> m = np.arange(9).reshape(3, 3)
    >>> create_band_matrix(m, max_k=1)
    array([[0, 1, 0],
           [3, 4, 5],
           [0, 7, 8]])
    >>> create_band_matrix(m, max_k=1, remove_main_diag=True)
    array([[0, 1, 0],
           [3, 0, 5],
           [0, 7, 0]])
    """
    # Defensive fallback. `validate_call` only admits the three annotated
    # types, and each of them has a registered implementation below, so this
    # is not expected to be reached in normal use.
    raise PreprocError(f"No implementation for data type: {type(matrix).__name__}")
@create_band_matrix.register(np.ndarray)
def _create_band_matrix_np(
    matrix: np.ndarray,
    max_k: t.Optional[int],
    remove_main_diag: bool,
    **kwargs,
) -> np.ndarray:
    """NumPy array-specific implementation for creating a band matrix.

    Entries outside the band ``|col - row| <= max_k`` are set to zero, and,
    optionally, so is the main diagonal. The result has the same shape as
    `matrix`, and `matrix` itself is never modified.

    Parameters
    ----------
    matrix : np.ndarray
        2-D input array (rectangular shapes are accepted).
    max_k : int or None
        Half-width of the band to keep. ``None`` disables distance filtering.
    remove_main_diag : bool
        If True, the main-diagonal entries ``(i, i)`` are set to zero.
    **kwargs
        Ignored. Absorbs the DataFrame-only options (``row_ids_colname``,
        ``col_ids_colname``) that the dispatcher forwards to every
        implementation.

    Returns
    -------
    np.ndarray
        A new array containing only the in-band entries.

    Raises
    ------
    PreprocError
        If `matrix` is not 2-dimensional.
    """
    if matrix.ndim != 2:
        raise PreprocError("Input NumPy array must be 2-dimensional.")

    if max_k is None:
        # No distance filtering: start from a plain copy. At most the
        # diagonal is cleared below.
        band_matrix = matrix.copy()
    else:
        # |col - row| <= max_k  <=>  row - max_k <= col <= row + max_k
        #   col >= row - max_k  -> np.triu(..., k=-max_k)
        #   col <= row + max_k  -> np.tril(..., k=+max_k)
        # np.triu/np.tril each still build a boolean (rows x cols) mask
        # internally. That is cheaper than the two integer index grids
        # np.indices would allocate, but it is not O(1) extra memory. Both
        # calls return new arrays, so `matrix` is left untouched.
        band_matrix = np.tril(np.triu(matrix, k=-max_k), k=max_k)

    if remove_main_diag:
        # Safe to write in place: band_matrix is always a fresh array here.
        np.fill_diagonal(band_matrix, 0)

    return band_matrix
@create_band_matrix.register(sp.coo_matrix)
def _create_band_matrix_coo(
    matrix: sp.coo_matrix,
    max_k: t.Optional[int],
    remove_main_diag: bool,
    **kwargs,
) -> sp.coo_matrix:
    """COO-matrix-specific implementation for creating a band matrix.

    Stored entries outside the requested band (and, optionally, on the main
    diagonal) are dropped. The shape of the matrix is preserved and the
    input is not modified.

    Parameters
    ----------
    matrix : sp.coo_matrix
        Sparse input matrix in coordinate format.
    max_k : int or None
        Half-width of the band to keep. ``None`` disables distance filtering.
    remove_main_diag : bool
        If True, stored entries with ``row == col`` are dropped.
    **kwargs
        Ignored. Absorbs DataFrame-only options forwarded by the dispatcher.

    Returns
    -------
    sp.coo_matrix
        A new COO matrix with the same shape as `matrix`.
    """
    wants_filtering = max_k is not None or remove_main_diag
    if matrix.nnz == 0 or not wants_filtering:
        return matrix.copy()

    rows, cols = matrix.row, matrix.col
    if max_k is not None:
        # NOTE(review): presumably keeps |row - col| <= max_k and drops the
        # diagonal when requested. The exact semantics live in
        # `.commons._create_diag_mask_helper`; confirm there.
        keep = _create_diag_mask_helper(
            rows,
            cols,
            abs_k=True,
            min_k=None,
            max_k=max_k,
            remove_main_diag=remove_main_diag,
        )
    else:
        # Only reachable with remove_main_diag=True: drop the (i, i) entries.
        keep = rows != cols

    if np.all(keep):
        # Nothing to drop; return an independent copy.
        return matrix.copy()

    return sp.coo_matrix(
        (matrix.data[keep], (rows[keep], cols[keep])),
        shape=matrix.shape,
    )
@create_band_matrix.register(pd.DataFrame)
def _create_band_matrix_df(
    matrix: pd.DataFrame,
    max_k: t.Optional[int],
    remove_main_diag: bool,
    *,
    row_ids_colname: str,
    col_ids_colname: str,
) -> pd.DataFrame:
    """DataFrame-specific implementation for creating a band matrix.

    Each DataFrame row is treated as one matrix entry, addressed by its
    row-ID and column-ID columns. Rows outside the requested band (and,
    optionally, on the main diagonal) are dropped. The input is not modified.

    Parameters
    ----------
    matrix : pd.DataFrame
        Table of entries that contains `row_ids_colname` and
        `col_ids_colname`.
    max_k : int or None
        Half-width of the band to keep. ``None`` disables distance filtering.
    remove_main_diag : bool
        If True, rows whose row ID equals their column ID are dropped.
    row_ids_colname : str
        Name of the column holding row IDs.
    col_ids_colname : str
        Name of the column holding column IDs.

    Returns
    -------
    pd.DataFrame
        A filtered copy of `matrix`; the original index labels are kept.

    Raises
    ------
    PreprocError
        If either ID column is missing from `matrix`.
    """
    absent = [
        name
        for name in (row_ids_colname, col_ids_colname)
        if name not in matrix.columns
    ]
    if absent:
        raise PreprocError(f"Missing required columns: {', '.join(absent)}")

    nothing_to_filter = max_k is None and not remove_main_diag
    if matrix.empty or nothing_to_filter:
        return matrix.copy()

    row_ids = matrix[row_ids_colname]
    col_ids = matrix[col_ids_colname]

    if max_k is not None:
        # NOTE(review): mask semantics are defined in
        # `.commons._create_diag_mask_helper`; presumably |row - col| <= max_k.
        keep = _create_diag_mask_helper(
            row_ids.values,
            col_ids.values,
            abs_k=True,
            min_k=None,
            max_k=max_k,
            remove_main_diag=remove_main_diag,
        )
    else:
        # Only reachable with remove_main_diag=True: drop diagonal rows.
        keep = row_ids != col_ids

    return matrix.loc[keep].copy()