# Source code for gunz_cm.metrics.reconstruction.n_way_interactions

# -*- coding: utf-8 -*-
"""
Module for calculating metrics related to genomic interaction data and
3D chromatin reconstructions.


Examples
--------
"""

# =============================================================================
# METADATA
# =============================================================================
# Package metadata consumed by documentation/build tooling.
__author__ = "Yeremia Gunawan Adhisantoso"
__maintainer__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"


# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import pathlib
import typing as t

# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import numpy as np
from pydantic import ConfigDict, validate_call
from scipy.sparse import coo_matrix
from scipy.spatial.distance import pdist, squareform

# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from gunz_cm import consts, loaders, preprocs
from gunz_cm.reconstructions.preprocs.points import (
    downsample_points,
    filter_points,
)

# =============================================================================
# METRICS FUNCTIONS
# =============================================================================

@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def measure_nways_interaction_found(
    #? --- Input Data ---
    points: np.ndarray,
    fpath_fmt: str,
    target_nways_vals: t.List[int],
    #? --- Data Specifiers ---
    region: str,
    resolution: int,
    #? --- Analysis Parameters ---
    ds_ratio: int = 1,
    num_thres_vals: int = 100,
    log_thres: bool = True,
) -> t.Dict[int, np.ndarray]:
    """
    Measures the percentage of n-way interactions found within distance thresholds.

    This function analyzes a 3D point cloud (a chromatin reconstruction) to
    determine how well it captures known n-way genomic interactions. It
    calculates the percentage of true interactions that are found within
    progressively larger distance thresholds in the 3D model.

    Parameters
    ----------
    points : np.ndarray
        The (N, 3) array of 3D coordinates representing the reconstruction.
    fpath_fmt : str
        A format string for the path to the n-way interaction data files.
        Must contain placeholders for {region}, {resolution}, and {nway}.
    target_nways_vals : list[int]
        A list of n-way interaction orders to analyze (e.g., [3, 4, 5]).
    region : str
        The genomic region or chromosome being analyzed (e.g., 'chr1').
    resolution : int
        The resolution of the data in base pairs.
    ds_ratio : int, optional
        The downsampling ratio to apply to the points. Defaults to 1.
    num_thres_vals : int, optional
        The number of distance thresholds to evaluate. Defaults to 100.
    log_thres : bool, optional
        If True, use logarithmically spaced distance thresholds.
        Defaults to True.

    Returns
    -------
    dict[int, np.ndarray]
        A dictionary where keys are the n-way values and values are arrays
        containing the percentage of interactions found at each threshold.

    Examples
    --------
    """
    if ds_ratio > 1:
        points = downsample_points(points, ds_ratio)

    #? Filter out invalid points (e.g., unmapped regions) from the reconstruction
    #? to avoid errors in distance calculation.
    proj_points, points_mask = filter_points(points, ret_mask=True)

    #? Calculate the Euclidean Distance Matrix for the valid points, which will be
    #? used to check which interactions are close in 3D space.
    proj_lr_edm = squareform(pdist(proj_points))
    unique_dist_vals = np.unique(proj_lr_edm)

    #? Generate distance thresholds for the analysis.
    if log_thres:
        #? Use log-spaced bins for thresholds, as distances in 3D structures
        #? often follow a log-normal distribution, and this provides better
        #? resolution at shorter distances. unique_dist_vals[0] is the zero
        #? self-distance from the EDM diagonal, hence the [1:] slice before log2.
        bins = np.histogram(np.log2(unique_dist_vals[1:]), bins=num_thres_vals)[1]
        thres_vals = 2**bins
    else:
        #? np.histogram returns (counts, edges); only the bin edges are needed
        #? as thresholds.
        _, thres_vals = np.histogram(unique_dist_vals, bins=num_thres_vals)

    results = {}
    for nway_val in target_nways_vals:
        #? Construct the path to the specific n-way interaction file. The
        #? effective resolution scales with the downsampling ratio.
        sprite_fpath = pathlib.Path(
            fpath_fmt.format(
                region=region,
                resolution=resolution * ds_ratio,
                nway=nway_val,
            )
        )

        #? Load the interaction data as a sparse COO tuple (rows, cols, vals).
        rows, cols, _ = loaders.load_cm_data(
            fpath=sprite_fpath,
            region1=region,
            resolution=resolution * ds_ratio,
            output_format=consts.DataStructure.COO,
        )

        #? Ensure the data is in upper-triangular format for efficiency
        #? and to avoid double-counting.
        #? NOTE(review): assert is stripped under `python -O`; consider raising
        #? ValueError instead if this is a hard input contract.
        assert (rows < cols).all(), "N-way interaction data must be upper-triangular"

        #? Create a SciPy COO matrix object for efficient sparse operations.
        shape = (points.shape[0], points.shape[0])
        nway_coo_matrix = coo_matrix((np.ones_like(rows), (rows, cols)), shape=shape)
        nway_coo_matrix = preprocs.create_triu_matrix(
            nway_coo_matrix, remove_main_diag=True
        )

        #? Create a dense boolean matrix of interactions and filter it to match
        #? the valid (filtered) points.
        nway_mat = nway_coo_matrix.toarray() > 0
        proj_nway_mat = nway_mat[points_mask, :][:, points_mask]

        num_nways_found = proj_nway_mat.sum()
        if num_nways_found == 0:
            #? No interactions survive filtering: report 0% at every threshold
            #? rather than dividing by zero below.
            results[nway_val] = np.zeros_like(thres_vals, dtype=float)
            continue

        #? Calculate the cumulative percentage of interactions found at each
        #? threshold. Instead of re-scanning the full N x N matrix per threshold
        #? (O(T * N^2)), extract the 3D distances at interacting pairs once and
        #? count how many fall strictly below each threshold via a sorted
        #? array + searchsorted — identical counts, O(K log K + T log K).
        nway_dists = np.sort(proj_lr_edm[proj_nway_mat])
        counts_at_thres = np.searchsorted(nway_dists, thres_vals, side="left")
        results[nway_val] = counts_at_thres / num_nways_found

    return results