# -*- coding: utf-8 -*-
"""
Centralized manager for genomic and structural masking logic.
"""
#? Metadata
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.2.0"
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import os
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import numpy as np
from pydantic import validate_call, ConfigDict
from loguru import logger
# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from .. import loaders as cm_loaders
from ..consts import DataFrameSpecs
from .rc_filters import filter_empty_rowcols
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def get_genomic_mask(
    #? --- Target ---
    resolution: int,
    region: str,
    hic_path: t.Union[str, os.PathLike],
    #? --- Config ---
    balancing: str = "KR",
    root: t.Optional[t.Union[str, os.PathLike]] = None
) -> np.ndarray:
    """
    Identifies valid (aligned) bins from Hi-C data by inspecting non-zero contacts.

    Parameters
    ----------
    resolution : int
        Genomic resolution in bp.
    region : str
        Chromosome/region identifier.
    hic_path : str | os.PathLike
        Path to the .hic file. Relative paths are resolved against ``root``.
    balancing : str
        Normalization scheme (e.g., 'KR').
    root : str | os.PathLike | None
        Project root directory used to resolve a relative ``hic_path``.

    Returns
    -------
    np.ndarray
        Boolean mask of valid bins, spanning up to the last observed bin
        (inclusive). Empty boolean array on failure or when no valid
        bins exist.
    """
    # Resolve relative paths against the project root, if one is given.
    if root and not os.path.isabs(hic_path):
        hic_path = os.path.join(root, hic_path)
    try:
        # Load the contact matrix as a dataframe to see which bins carry
        # ANY non-zero contacts at this resolution/normalization.
        df = cm_loaders.load_cm_data(
            hic_path,
            resolution,
            region,
            balancing=[balancing],
            output_format="df",
        )
        row_ids = df[DataFrameSpecs.ROW_IDS].to_numpy()
        col_ids = df[DataFrameSpecs.COL_IDS].to_numpy()
        # With ret_unique_ids=True (and ret_mapping left at its default),
        # the filter returns [new_row_ids, new_col_ids, unique_ids]; the
        # unique ids are the bins participating in at least one contact.
        out = filter_empty_rowcols(
            (row_ids, col_ids),
            is_triu_sym=True,
            ret_unique_ids=True
        )
        valid_bins = out[2]
        if len(valid_bins) == 0:
            logger.warning(f"No valid bins found for {region} at {resolution}bp")
            return np.array([], dtype=bool)
        # Boolean mask spanning the region up to the last observed bin.
        full_len = int(np.max(valid_bins) + 1)
        mask = np.zeros(full_len, dtype=bool)
        mask[valid_bins] = True
        # BUGFIX: `if logger.level("DEBUG"):` was always truthy (it returns a
        # Level record, not an enabled-check), and `opt(lazy=True)` combined
        # with an eager f-string defeats lazy evaluation. Use loguru's
        # deferred {}-formatting instead, which only formats when emitted.
        logger.debug("Loaded genomic mask: {} valid bins out of {}",
                     np.sum(mask), full_len)
        return mask
    except Exception as e:
        # Best-effort contract: callers treat an empty mask as "no data",
        # so we log and degrade rather than propagate.
        logger.error(f"Failed to compute genomic mask: {e}")
        return np.array([], dtype=bool)
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def get_optimization_mask(
    #? --- Coordinates ---
    points: np.ndarray,
    #? --- Threshold ---
    threshold: float = 1e-5
) -> np.ndarray:
    """
    Identifies points that have moved from the origin (stagnant noise filter).

    Parameters
    ----------
    points : np.ndarray
        Point coordinates of shape (N, D); the Euclidean norm is taken
        along axis 1.
    threshold : float
        Minimum distance from the origin for a point to count as "moved".

    Returns
    -------
    np.ndarray
        Boolean mask of shape (N,): True where the point lies strictly
        farther than ``threshold`` from the origin.
    """
    return np.linalg.norm(points, axis=1) > threshold
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def get_unified_mask(
    points: np.ndarray,
    resolution: int,
    region: str,
    hic_path: t.Union[str, os.PathLike],
    balancing: str = "KR",
    root: t.Optional[t.Union[str, os.PathLike]] = None
) -> np.ndarray:
    """
    Combines Genomic (Hi-C) and Optimization (Movement) masks.

    The genomic mask may differ in length from ``points``; it is truncated
    (if longer) or zero-padded with False (if shorter) to ``len(points)``
    before the bitwise AND, so unmatched tail bins count as invalid.

    Parameters
    ----------
    points : np.ndarray
        Point coordinates of shape (N, D).
    resolution, region, hic_path, balancing, root
        Forwarded to :func:`get_genomic_mask`.

    Returns
    -------
    np.ndarray
        Boolean mask of shape (N,): True only where both masks agree.
    """
    m_gen = get_genomic_mask(resolution, region, hic_path, balancing, root)
    # Align the genomic mask to the number of points.
    n_pts = len(points)
    aligned_gen = np.zeros(n_pts, dtype=bool)
    limit = min(n_pts, len(m_gen))
    aligned_gen[:limit] = m_gen[:limit]
    m_opt = get_optimization_mask(points)
    return aligned_gen & m_opt
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def intersect_masks(
    masks: list[np.ndarray]
) -> np.ndarray:
    """Computes bitwise-AND across multiple masks.

    Masks of unequal length are truncated to the shortest length before
    combining, so the result has length ``min(len(m) for m in masks)``.

    Parameters
    ----------
    masks : list[np.ndarray]
        Boolean masks to intersect.

    Returns
    -------
    np.ndarray
        Boolean intersection mask; empty boolean array when ``masks`` is
        empty.
    """
    if not masks:
        return np.array([], dtype=bool)
    min_len = min(len(m) for m in masks)
    intersect = np.ones(min_len, dtype=bool)
    for m in masks:
        intersect &= m[:min_len]
    return intersect
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def expand_with_nans(
    points_filtered: np.ndarray,
    mask: np.ndarray,
    full_length: t.Optional[int] = None
) -> np.ndarray:
    """
    Expands a filtered point cloud back to genomic length, inserting NaNs
    where the mask is False.

    Parameters
    ----------
    points_filtered : np.ndarray
        Points of shape (M, D) that survived masking; assigned in order to
        the True positions of ``mask``. D defaults to 3 when it cannot be
        inferred from a 2-D input.
    mask : np.ndarray
        Boolean mask; truncated or padded with False to ``full_length``.
    full_length : int | None
        Output length; defaults to ``len(mask)``.

    Returns
    -------
    np.ndarray
        Array of shape (full_length, D) holding NaN rows wherever the mask
        is False (or where the mask ran out of points).
    """
    if full_length is None:
        full_length = len(mask)
    # Generalized: infer point dimensionality from the input instead of
    # hard-coding 3 (falls back to 3 for non-2D input, preserving the
    # original default output shape).
    n_dims = points_filtered.shape[1] if points_filtered.ndim == 2 else 3
    full_points = np.full((full_length, n_dims), np.nan)
    # Clip/pad the mask so it is exactly full_length long.
    m = np.zeros(full_length, dtype=bool)
    limit = min(full_length, len(mask))
    m[:limit] = mask[:limit]
    n_valid_in_mask = int(np.sum(m))
    n_points = len(points_filtered)
    if n_valid_in_mask > n_points:
        # More True slots than points: fill only the first n_points slots,
        # leaving the remaining True positions as NaN.
        true_indices = np.where(m)[0][:n_points]
        full_points[true_indices] = points_filtered[:n_points]
    else:
        # Enough points for every True slot; surplus points are ignored.
        full_points[m] = points_filtered[:n_valid_in_mask]
    return full_points