# -*- coding: utf-8 -*-
"""
PyTorch Dataset implementation for Memory-Mapped Hi-C data loading.
Offers extreme throughput by bypassing decompression.
Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
import numpy as np
import pandas as pd
try:
import torch
except ImportError:
torch = None
import typing as t
from ..loaders import load_memmap
from ..loaders import hic_loader
from ..matrix import ContactMatrix
from ..exceptions import DatasetError
from ..consts import Balancing
# Fall back to a plain ``object`` base when torch is unavailable so this
# module still imports; instantiating the dataset is what needs torch.
DatasetType = object if torch is None else torch.utils.data.Dataset
class MemmapSparseDataset(DatasetType):
    """
    A PyTorch Dataset for ultra-fast loading of Hi-C patches from uncompressed
    memory-mapped files (.npdat).

    Square, on-diagonal (intra-chromosomal) windows are sliced directly from
    the memory map, so the hot path involves no decompression.

    Attributes
    ----------
    fpath : str
        Path to the .npdat file (or base path).
    window_size : int
        Size of the genomic window (patch) in BP.
    downsample_ratio : float or tuple, optional
        Fixed ratio, or a (low, high) range sampled uniformly per item,
        for binomial subsampling of the input counts.
    weights : numpy.ndarray or None
        Divisive balancing weights aligned with the memmap rows, or ``None``
        when no on-the-fly balancing was requested.
    """

    def __init__(
        self,
        fpath: str,
        window_size: int,
        downsample_ratio: t.Union[float, t.Tuple[float, float], None] = None,
        balancing: t.Optional[Balancing] = None,
        hic_path: t.Optional[str] = None,
        output_type: str = "sparse",
    ):
        """
        Open the memory map, optionally load balancing weights, and build
        the per-window index.

        Parameters
        ----------
        fpath : str
            Path to the .npdat memmap file.
        window_size : int
            Genomic window size in base pairs; must be at least one bin
            (i.e. >= the matrix resolution).
        downsample_ratio : float or (float, float), optional
            Ratio for binomial subsampling of the input counts; a tuple is
            treated as a uniform sampling range per item.
        balancing : Balancing, optional
            If set (and not ``Balancing.NONE``), divisive weights are
            fetched from ``hic_path`` and applied on the fly.
        hic_path : str, optional
            Source Hi-C file for the normalization weights. Required when
            ``balancing`` is requested.
        output_type : str
            ``"sparse"`` (dict of COO tensors) or ``"dense"``
            (``(1, H, W)`` tensor).

        Raises
        ------
        DatasetError
            If torch is unavailable, metadata is incomplete, the window is
            smaller than one bin, or weights cannot be loaded.
        """
        if torch is None:
            # __getitem__ builds torch tensors; fail early with a clear
            # error instead of a NameError deep inside the data loader.
            raise DatasetError("PyTorch is required to use MemmapSparseDataset")

        self.fpath = fpath
        self.window_size = window_size
        self.downsample_ratio = downsample_ratio
        self.balancing = balancing
        self.output_type = output_type
        self.weights = None

        # 1. Load memmap lazily: a ContactMatrix wrapper around np.memmap.
        self.matrix_obj: ContactMatrix = load_memmap(fpath, mode='r')
        self.data: np.memmap = self.matrix_obj.data
        self.metadata = self.matrix_obj.metadata

        self.resolution = self.metadata.get("resolution")
        if not self.resolution:
            raise DatasetError("Memmap metadata missing 'resolution'")
        self.chrom = self.metadata.get("chromosome1", "unknown")

        # 2. Load weights if on-the-fly balancing was requested.
        if self.balancing and self.balancing != Balancing.NONE:
            if not hic_path:
                raise DatasetError("hic_path is required for on-the-fly balancing")
            self._load_weights(hic_path)

        # 3. Generate the window index. Assumes a square matrix covering a
        #    single chromosome (or region).
        n_bins = self.data.shape[0]
        step = window_size // self.resolution
        if step < 1:
            # Prevents an empty/degenerate arange below.
            raise DatasetError(
                f"window_size ({window_size}) is smaller than one bin "
                f"(resolution {self.resolution})"
            )
        starts = np.arange(0, n_bins, step)
        # Clip so the trailing (partial) window stays in bounds; partial
        # windows are kept and sliced safely in __getitem__.
        ends = np.clip(starts + step, 0, n_bins)
        self.index = pd.DataFrame({
            'start_bin': starts,
            'end_bin': ends
        })

    def _load_weights(self, hic_path: str) -> None:
        """Fetch divisive balancing weights aligned with the memmap rows.

        Uses ``hic_loader._get_hic_file`` (private but cached) to avoid
        reopening the file; weight fetching is not yet exposed publicly by
        the loaders module. The global weight vector is sliced at this
        chromosome's bin offset so indices match the memmap.
        """
        try:
            f = hic_loader._get_hic_file(hic_path, self.resolution)
            bins = f.bins()
            # Assumes the memmap starts at position 0 of this chromosome.
            offset = bins.get_id(self.chrom, 0)
            # weights() returns GLOBAL (genome-wide) weights.
            global_weights = f.weights(self.balancing.value, divisive=True)
            n_bins = self.data.shape[0]
            if offset + n_bins > len(global_weights):
                raise DatasetError(f"Weights vector shorter than expected. Offset={offset}, N={n_bins}, W={len(global_weights)}")
            self.weights = global_weights[offset : offset + n_bins]
        except Exception as e:
            raise DatasetError(f"Failed to load normalization weights from {hic_path}: {e}") from e

    def __len__(self) -> int:
        """Return the number of genomic windows in the index."""
        return len(self.index)

    def _dense_patch(self, patch: np.ndarray, start_bin: int, end_bin: int):
        """Return the window as a ``(1, H, W)`` float tensor.

        Applies divisive balancing via an outer product of the weight slice
        when weights are loaded; invalid results (NaN/Inf from bad weights)
        are zeroed in place.
        """
        if self.weights is not None:
            w_slice = self.weights[start_bin:end_bin]
            denom = np.outer(w_slice, w_slice)
            patch_bal = patch.astype(np.float32) / denom
            np.nan_to_num(patch_bal, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
            out_tensor = torch.from_numpy(patch_bal)
        else:
            out_tensor = torch.from_numpy(patch).float()
        return out_tensor.unsqueeze(0)  # (1, H, W)

    def __getitem__(self, idx: int) -> t.Any:
        """Load one genomic window.

        Parameters
        ----------
        idx : int
            Row index into ``self.index``.

        Returns
        -------
        torch.Tensor or dict
            For ``output_type == "dense"``: a ``(1, H, W)`` float tensor.
            Otherwise a dict with ``coords`` (K, 2) long, ``features``
            (K, 1) float, ``target`` (K,) float and an ``info`` mapping
            holding chromosome name plus start/end in BP.
        """
        row = self.index.iloc[idx]
        start_bin, end_bin = int(row['start_bin']), int(row['end_bin'])

        # 1. Slice the diagonal block (zero-copy view into the memmap).
        patch = self.data[start_bin:end_bin, start_bin:end_bin]

        # Dense output (fast path).
        if self.output_type == "dense":
            return self._dense_patch(patch, start_bin, end_bin)

        # 2. Sparsify on-the-fly; keep the upper triangle to match the
        #    standard HiCSparseDataset output.
        r_ids, c_ids = np.nonzero(patch)
        mask_ut = r_ids <= c_ids
        r_ids = r_ids[mask_ut]
        c_ids = c_ids[mask_ut]
        counts = patch[r_ids, c_ids]

        # On-the-fly normalization: divisive weights, balanced = raw/(w1*w2).
        if self.weights is not None:
            # Shift patch-local indices to chromosome level for weight lookup.
            w1 = self.weights[start_bin + r_ids]
            w2 = self.weights[start_bin + c_ids]
            counts = counts.astype(np.float32) / (w1 * w2)
            # Drop entries invalidated by bad weights (NaN/Inf).
            mask_valid = np.isfinite(counts)
            if not np.all(mask_valid):
                r_ids = r_ids[mask_valid]
                c_ids = c_ids[mask_valid]
                counts = counts[mask_valid]

        # 3. Augmentation: binomial downsampling of the input channel.
        target_counts = counts.copy()
        input_counts = counts
        if self.downsample_ratio is not None:
            if isinstance(self.downsample_ratio, tuple):
                alpha = np.random.uniform(*self.downsample_ratio)
            else:
                alpha = self.downsample_ratio
            # NOTE(review): balanced counts are truncated to int before
            # subsampling — confirm this matches HiCSparseDataset.
            input_counts = np.random.binomial(counts.astype(np.int32), alpha)
            mask = input_counts > 0
            r_ids = r_ids[mask]
            c_ids = c_ids[mask]
            input_counts = input_counts[mask]
            # Bug fix: keep target aligned with coords/features. Previously
            # target retained entries whose input was subsampled to zero,
            # so its length disagreed with coords.
            target_counts = target_counts[mask]

        # 4. Prepare tensors (coordinates are local to the patch).
        coords = np.stack([r_ids, c_ids], axis=1)
        info_start = start_bin * self.resolution
        info_end = end_bin * self.resolution
        return {
            "coords": torch.from_numpy(coords).long(),
            "features": torch.from_numpy(input_counts).float().unsqueeze(1),
            "target": torch.from_numpy(target_counts).float(),
            "info": {"chrom": self.chrom, "start": info_start, "end": info_end}
        }