# Source code for gunz_cm.datasets.hic

# -*- coding: utf-8 -*-
"""
PyTorch Dataset implementation for Fully Sparse Hi-C data loading.
Supports on-the-fly binomial downsampling and genomic window indexing.


Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"

import numpy as np
import pandas as pd
try:
    import torch
except ImportError:
    torch = None


import typing as t
from pydantic import validate_call

from ..loaders import load_cm_data, get_bins, DataStructure, Balancing
from ..utils import intervals

DatasetType = torch.utils.data.Dataset if torch else object

[docs] class HiCSparseDataset(DatasetType): """ A PyTorch Dataset for on-the-fly loading of Hi-C patches from sparse files. Attributes ---------- fpath : str Path to the .hic or .mcool file. resolution : int The resolution to load. window_size : int Size of the genomic window (patch) in BP. index : pd.DataFrame The binnified and filtered index of training windows. downsample_ratio : float or tuple, optional Ratio for binomial subsampling. If tuple (min, max), a random ratio is sampled per item. Examples -------- """ def __init__( self, fpath: str, resolution: int, window_size: int, blacklist: t.Optional[pd.DataFrame] = None, downsample_ratio: t.Union[float, t.Tuple[float, float], None] = None, balancing: t.Optional[Balancing] = Balancing.NONE, output_type: str = "sparse", **kwargs ): """ Function __init__. Parameters ---------- Returns ------- Examples -------- Notes ----- """ self.fpath = fpath self.resolution = resolution self.window_size = window_size self.downsample_ratio = downsample_ratio self.balancing = balancing self.output_type = output_type self.kwargs = kwargs # 1. Generate Index # We use window_size as the binning step for the training windows self.index = get_bins(fpath, window_size) # 2. Filter Index if blacklist is not None: self.index = intervals.subtract(self.index, blacklist) def __len__(self) -> int: """ Function __len__. Parameters ---------- Returns ------- Examples -------- Notes ----- """ return len(self.index) def __getitem__(self, idx: int) -> t.Dict[str, t.Any]: """ Function __getitem__. Parameters ---------- Returns ------- Examples -------- Notes ----- """ row = self.index.iloc[idx] chrom, start, end = row['chrom'], row['start'], row['end'] # 1. 
Fetch Sparse Pixels # Using load_cm_data which dispatches to the correct loader # We request RCV format for direct access to coordinate arrays data = load_cm_data( self.fpath, resolution=self.resolution, region1=f"{chrom}:{start}-{end}", balancing=self.balancing, output_format=DataStructure.RCV, **self.kwargs ) # Unpack RCV (row_ids, col_ids, counts) # Note: result might be a tuple if multiple balancings were asked, # but here we assume single for simplicity in basic dataset. r_ids, c_ids, counts = data # 2. Normalize Coordinates to Local Patch (0 to N-1) region_start_bin = start // self.resolution local_r = r_ids - region_start_bin local_c = c_ids - region_start_bin # 3. On-the-Fly Augmentation: Random Downsampling target_counts = counts.copy() input_counts = counts if self.downsample_ratio is not None: if isinstance(self.downsample_ratio, tuple): alpha = np.random.uniform(*self.downsample_ratio) else: alpha = self.downsample_ratio # Binomial Subsampling input_counts = np.random.binomial(counts.astype(np.int32), alpha) # Filter zeros to maintain sparsity mask = input_counts > 0 local_r = local_r[mask] local_c = local_c[mask] input_counts = input_counts[mask] #? We keep target_counts aligned with original coords or the new ones? #? Usually we want the model to predict the full GT from sparse input. if self.output_type == "dense": n_bins = self.window_size // self.resolution patch = torch.zeros((1, n_bins, n_bins), dtype=torch.float32) if len(local_r) > 0: v = torch.from_numpy(input_counts).float() patch[0, local_r, local_c] = v # Symmetrize (assuming input is Upper Triangular) nondiag = local_r != local_c patch[0, local_c[nondiag], local_r[nondiag]] = v[nondiag] return patch # 4. 
Prepare Sparse Tensor components # Coordinates: (N, 2) coords = np.stack([local_r, local_c], axis=1) return { "coords": torch.from_numpy(coords).long(), "features": torch.from_numpy(input_counts).float().unsqueeze(1), "target": torch.from_numpy(target_counts).float(), # Placeholder "info": {"chrom": chrom, "start": start, "end": end} }
[docs] @validate_call(config=dict(arbitrary_types_allowed=True)) def sparse_collate_fn(batch: t.List[t.Dict[str, t.Any]]) -> t.Dict[str, t.Any]: """ Collate function for Sparse Tensors (MinkowskiEngine style). Prepends batch index to coordinates. Examples -------- """ batch_coords = [] batch_feats = [] for i, item in enumerate(batch): coords = item["coords"] # Add batch index as first column batch_idx = torch.full((coords.shape[0], 1), i, dtype=torch.long) batch_coords.append(torch.cat([batch_idx, coords], dim=1)) batch_feats.append(item["features"]) return { "coords": torch.cat(batch_coords, dim=0), "features": torch.cat(batch_feats, dim=0), "infos": [item["info"] for item in batch] }