# -*- coding: utf-8 -*-
"""
Module for calculating metrics related to genomic interaction data and
3D chromatin reconstructions.
Examples
--------
"""
# =============================================================================
# METADATA
# =============================================================================
__author__ = "Yeremia Gunawan Adhisantoso"
__maintainer__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
# =============================================================================
# STANDARD LIBRARY IMPORTS
# =============================================================================
import pathlib
import typing as t
# =============================================================================
# THIRD-PARTY IMPORTS
# =============================================================================
import numpy as np
from pydantic import ConfigDict, validate_call
from scipy.sparse import coo_matrix
from scipy.spatial.distance import pdist, squareform
# =============================================================================
# LOCAL APPLICATION IMPORTS
# =============================================================================
from gunz_cm import consts, loaders, preprocs
from gunz_cm.reconstructions.preprocs.points import (
downsample_points,
filter_points,
)
# =============================================================================
# METRICS FUNCTIONS
# =============================================================================
#? NOTE: removed a stray "[docs]" token here — an artifact of scraping the
#? rendered Sphinx HTML; as code it raised NameError at import time.
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def measure_nways_interaction_found(
    #? --- Input Data ---
    points: np.ndarray,
    fpath_fmt: str,
    target_nways_vals: t.List[int],
    #? --- Data Specifiers ---
    region: str,
    resolution: int,
    #? --- Analysis Parameters ---
    ds_ratio: int = 1,
    num_thres_vals: int = 100,
    log_thres: bool = True,
) -> t.Dict[int, np.ndarray]:
    """
    Measure the fraction of n-way interactions found within distance thresholds.

    This function analyzes a 3D point cloud (a chromatin reconstruction) to
    determine how well it captures known n-way genomic interactions. It
    calculates the fraction of true interactions that are found within
    progressively larger distance thresholds in the 3D model.

    Parameters
    ----------
    points : np.ndarray
        The (N, 3) array of 3D coordinates representing the reconstruction.
    fpath_fmt : str
        A format string for the path to the n-way interaction data files.
        Must contain placeholders for {region}, {resolution}, and {nway}.
    target_nways_vals : list[int]
        A list of n-way interaction orders to analyze (e.g., [3, 4, 5]).
    region : str
        The genomic region or chromosome being analyzed (e.g., 'chr1').
    resolution : int
        The resolution of the data in base pairs.
    ds_ratio : int, optional
        The downsampling ratio to apply to the points. Defaults to 1.
    num_thres_vals : int, optional
        The number of distance thresholds to evaluate. Defaults to 100.
    log_thres : bool, optional
        If True, use logarithmically spaced distance thresholds. Defaults to True.

    Returns
    -------
    dict[int, np.ndarray]
        A dictionary where keys are the n-way values and values are arrays
        containing the fraction (in [0.0, 1.0]) of interactions found at
        each threshold.

    Raises
    ------
    ValueError
        If an n-way interaction file is not in upper-triangular format.

    Examples
    --------
    """
    if ds_ratio > 1:
        points = downsample_points(points, ds_ratio)
    #? Filter out invalid points (e.g., unmapped regions) from the reconstruction
    #? to avoid errors in distance calculation.
    proj_points, points_mask = filter_points(
        points, ret_mask=True
    )
    #? Calculate the Euclidean Distance Matrix for the valid points, which will be
    #? used to check which interactions are close in 3D space.
    proj_lr_edm = squareform(pdist(proj_points))
    unique_dist_vals = np.unique(proj_lr_edm)
    #? Generate distance thresholds for the analysis.
    if log_thres:
        #? Use log-spaced bins for thresholds, as distances in 3D structures
        #? often follow a log-normal distribution, and this provides better
        #? resolution at shorter distances. unique_dist_vals[0] is the zero
        #? self-distance from the diagonal and is excluded before the log.
        bins = np.histogram(np.log2(unique_dist_vals[1:]), bins=num_thres_vals)[1]
        thres_vals = 2**bins
    else:
        _, thres_vals = np.histogram(unique_dist_vals, bins=num_thres_vals)
    results = {}
    for nway_val in target_nways_vals:
        #? Construct the path to the specific n-way interaction file
        sprite_fpath = pathlib.Path(
            fpath_fmt.format(
                region=region,
                resolution=resolution * ds_ratio,
                nway=nway_val,
            )
        )
        #? Load the interaction data as a sparse COO tuple (rows, cols, vals)
        rows, cols, _ = loaders.load_cm_data(
            fpath=sprite_fpath,
            region1=region,
            resolution=resolution * ds_ratio,
            output_format=consts.DataStructure.COO,
        )
        #? Ensure the data is in upper-triangular format for efficiency
        #? and to avoid double-counting. Raise rather than `assert`: asserts
        #? are stripped under `python -O`, which would silently let malformed
        #? data through.
        if not (rows < cols).all():
            raise ValueError("N-way interaction data must be upper-triangular")
        #? Create a SciPy COO matrix object for efficient sparse operations.
        shape = (points.shape[0], points.shape[0])
        nway_coo_matrix = coo_matrix((np.ones_like(rows), (rows, cols)), shape=shape)
        nway_coo_matrix = preprocs.create_triu_matrix(
            nway_coo_matrix, remove_main_diag=True
        )
        #? Create a dense boolean matrix of interactions and filter it to match the points
        nway_mat = nway_coo_matrix.toarray() > 0
        proj_nway_mat = nway_mat[points_mask, :][:, points_mask]
        num_nways_found = proj_nway_mat.sum()
        if num_nways_found == 0:
            #? No interactions survive filtering: report all-zero fractions.
            results[nway_val] = np.zeros_like(thres_vals, dtype=float)
            continue
        #? Calculate the cumulative fraction of interactions found at each
        #? threshold. Extract the 3D distances at the interacting pairs once,
        #? sort them, and count the ones strictly below each threshold via
        #? binary search (`side="left"` returns the count of elements
        #? < threshold). This replaces an O(T * N^2) loop over thresholds
        #? with O(M log M + T log M).
        interaction_dists = np.sort(proj_lr_edm[proj_nway_mat])
        counts_at_thres = np.searchsorted(interaction_dists, thres_vals, side="left")
        results[nway_val] = counts_at_thres / num_nways_found
    return results