# Source code for gunz_cm.datasets.memmap

# -*- coding: utf-8 -*-
"""
PyTorch Dataset implementation for Memory-Mapped Hi-C data loading.
Offers extreme throughput by bypassing decompression.


Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"

import numpy as np
import pandas as pd
try:
    import torch
except ImportError:
    torch = None


import typing as t

from ..loaders import load_memmap
from ..loaders import hic_loader
from ..matrix import ContactMatrix
from ..exceptions import DatasetError
from ..consts import Balancing

# Fall back to a plain ``object`` base when torch is not installed so this
# module stays importable; DataLoader integration then simply isn't available.
DatasetType = torch.utils.data.Dataset if torch else object

class MemmapSparseDataset(DatasetType):
    """A PyTorch Dataset for ultra-fast loading of Hi-C patches from
    uncompressed memory-mapped files (.npdat).

    Windows are square diagonal (intra-chromosomal) blocks of the memmapped
    contact matrix, sliced lazily in ``__getitem__``. Optionally applies
    on-the-fly matrix balancing (weights fetched from a companion .hic file
    via hictk) and binomial downsampling augmentation.

    Attributes
    ----------
    fpath : str
        Path to the .npdat file (or base path).
    window_size : int
        Size of the genomic window (patch) in BP.
    downsample_ratio : float or tuple, optional
        Ratio for binomial subsampling; a tuple means a uniform-random
        ratio is drawn per sample.
    """

    def __init__(
        self,
        fpath: str,
        window_size: int,
        downsample_ratio: t.Union[float, t.Tuple[float, float], None] = None,
        balancing: t.Optional[Balancing] = None,
        hic_path: t.Optional[str] = None,
        output_type: str = "sparse",
    ):
        """Open the memmap, optionally load balancing weights, and build the
        window index.

        Parameters
        ----------
        fpath : str
            Path to the .npdat memmap file.
        window_size : int
            Genomic window size in base pairs; converted to bins using the
            resolution stored in the memmap metadata.
        downsample_ratio : float or (float, float), optional
            Binomial subsampling ratio (or a range to sample from).
        balancing : Balancing, optional
            Normalization scheme; anything other than ``None`` /
            ``Balancing.NONE`` triggers weight loading from ``hic_path``.
        hic_path : str, optional
            Path to the .hic file the weights are read from. Required when
            balancing is requested.
        output_type : str
            ``"sparse"`` (default) yields COO-style dicts; ``"dense"`` yields
            a ``(1, H, W)`` tensor per window.

        Raises
        ------
        DatasetError
            If the memmap metadata lacks a resolution, if balancing is
            requested without ``hic_path``, or if weight loading fails.
        """
        self.fpath = fpath
        self.window_size = window_size
        self.downsample_ratio = downsample_ratio
        self.balancing = balancing
        self.output_type = output_type
        self.weights = None

        # 1. Load Memmap (lazy): returns a ContactMatrix wrapper around an
        # np.memmap, so no pixel data is read until a window is sliced.
        self.matrix_obj: ContactMatrix = load_memmap(fpath, mode='r')
        self.data: np.memmap = self.matrix_obj.data
        self.metadata = self.matrix_obj.metadata
        self.resolution = self.metadata.get("resolution")
        if not self.resolution:
            raise DatasetError("Memmap metadata missing 'resolution'")
        self.chrom = self.metadata.get("chromosome1", "unknown")

        # 2. Load weights if on-the-fly balancing was requested.
        if self.balancing and self.balancing != Balancing.NONE:
            if not hic_path:
                # Could try to infer a path from metadata, but fail loudly
                # instead of guessing.
                raise DatasetError("hic_path is required for on-the-fly balancing")
            # NOTE(review): uses hic_loader's private cached accessor
            # ``_get_hic_file`` to avoid reopening the .hic file per dataset.
            # This is efficient but couples us to a private member; ideally
            # weight fetching would be exposed publicly in the loaders module.
            try:
                f = hic_loader._get_hic_file(hic_path, self.resolution)
                # ``weights()`` returns GLOBAL (genome-wide) weights, while the
                # memmap covers a single chromosome/region, so we must shift by
                # this chromosome's offset in the global bin table.
                bins = f.bins()
                # Assumes the memmap starts at position 0 of the chromosome.
                offset = bins.get_id(self.chrom, 0)
                # divisive=True: balanced = raw / (w_i * w_j) downstream.
                global_weights = f.weights(self.balancing.value, divisive=True)
                # Slice the weights for this chromosome; length must match
                # the memmap's first dimension.
                n_bins = self.data.shape[0]
                if offset + n_bins > len(global_weights):
                    raise DatasetError(f"Weights vector shorter than expected. Offset={offset}, N={n_bins}, W={len(global_weights)}")
                self.weights = global_weights[offset : offset + n_bins]
            except Exception as e:
                # Wrap any hictk/IO failure in the package's exception type,
                # preserving the original cause for debugging.
                raise DatasetError(f"Failed to load normalization weights from {hic_path}: {e}") from e

        # 3. Generate the window index. Assumes a square matrix for a single
        # chromosome or region; windows tile the diagonal without overlap.
        n_bins = self.data.shape[0]
        step = window_size // self.resolution
        starts = np.arange(0, n_bins, step)
        ends = starts + step
        # Clip the last window so it never runs past the matrix edge.
        ends = np.clip(ends, 0, n_bins)
        self.index = pd.DataFrame({
            'start_bin': starts,
            'end_bin': ends
        })
        # Partial (clipped) trailing windows are kept deliberately; slicing
        # below is bounds-safe either way.

    def __len__(self) -> int:
        """Return the number of diagonal windows in the index."""
        return len(self.index)

    def __getitem__(self, idx: int) -> t.Dict[str, t.Any]:
        """Load one diagonal window, either as a dense tensor or as a
        sparse COO-style sample dict.

        Parameters
        ----------
        idx : int
            Row position in the window index.

        Returns
        -------
        dict or torch.Tensor
            For ``output_type == "dense"``: a ``(1, H, W)`` float tensor.
            Otherwise a dict with ``coords`` (N, 2) long tensor of
            patch-local upper-triangular positions, ``features`` (N, 1)
            float inputs, ``target`` float targets, and an ``info`` dict
            with chrom/start/end in BP.
        """
        row = self.index.iloc[idx]
        start_bin, end_bin = int(row['start_bin']), int(row['end_bin'])

        # 1. Slice the dense patch — a zero-copy view into the memmap of the
        # diagonal block (intra-chromosomal window).
        patch = self.data[start_bin:end_bin, start_bin:end_bin]

        # Dense output (fast path).
        if self.output_type == "dense":
            if self.weights is not None:
                # Apply weights via outer-product broadcasting (vectorized).
                w_slice = self.weights[start_bin:end_bin]
                denom = np.outer(w_slice, w_slice)
                patch_bal = patch.astype(np.float32) / denom
                # Bad weights yield NaN/Inf; zero them in place.
                np.nan_to_num(patch_bal, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
                out_tensor = torch.from_numpy(patch_bal)
            else:
                # NOTE(review): ``patch`` is a read-only memmap view here, so
                # torch.from_numpy emits a non-writable-array warning before
                # ``.float()`` copies it — confirm this is acceptable.
                out_tensor = torch.from_numpy(patch).float()
            return out_tensor.unsqueeze(0)  # (1, H, W)

        # 2. Sparsify on the fly; np.nonzero is fast for cache-resident
        # patches and returns (row_indices, col_indices).
        r_ids, c_ids = np.nonzero(patch)
        # Keep only the upper triangle to match the standard
        # HiCSparseDataset output convention.
        mask_ut = r_ids <= c_ids
        r_ids = r_ids[mask_ut]
        c_ids = c_ids[mask_ut]
        counts = patch[r_ids, c_ids]

        # Apply on-the-fly normalization.
        if self.weights is not None:
            # Shift patch-local indices to chromosome-level for weight lookup.
            w1 = self.weights[start_bin + r_ids]
            w2 = self.weights[start_bin + c_ids]
            # Balanced = raw / (w1 * w2), as float32 (divisive weights).
            counts = counts.astype(np.float32) / (w1 * w2)
            # Drop entries made NaN/Inf by bad (e.g. zero) weights.
            mask_valid = np.isfinite(counts)
            if not np.all(mask_valid):
                r_ids = r_ids[mask_valid]
                c_ids = c_ids[mask_valid]
                counts = counts[mask_valid]

        # Coordinates are already local to the patch (0..window_size).

        # 3. Augmentation (same scheme as HiCSparseDataset): the target keeps
        # the full counts, the input is binomially thinned.
        target_counts = counts.copy()
        input_counts = counts
        if self.downsample_ratio is not None:
            if isinstance(self.downsample_ratio, tuple):
                alpha = np.random.uniform(*self.downsample_ratio)
            else:
                alpha = self.downsample_ratio
            # NOTE(review): when balancing was applied, ``counts`` are floats
            # and the int32 cast truncates before sampling — confirm intended.
            input_counts = np.random.binomial(counts.astype(np.int32), alpha)
            mask = input_counts > 0
            r_ids = r_ids[mask]
            c_ids = c_ids[mask]
            input_counts = input_counts[mask]
            # NOTE(review): ``target_counts`` is NOT filtered by ``mask``, so
            # after downsampling it can be longer than coords/features —
            # verify downstream consumers expect this asymmetry.

        # 4. Prepare tensors; info carries BP coordinates for debug/mapping.
        coords = np.stack([r_ids, c_ids], axis=1)
        info_start = start_bin * self.resolution
        info_end = end_bin * self.resolution

        return {
            "coords": torch.from_numpy(coords).long(),
            "features": torch.from_numpy(input_counts).float().unsqueeze(1),
            "target": torch.from_numpy(target_counts).float(),
            "info": {"chrom": self.chrom, "start": info_start, "end": info_end}
        }