# -*- coding: utf-8 -*-
"""
PyTorch Dataset implementation for Fully Sparse Hi-C data loading.
Supports on-the-fly binomial downsampling and genomic window indexing.
Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
import numpy as np
import pandas as pd
try:
import torch
except ImportError:
torch = None
import typing as t
from pydantic import validate_call
from ..loaders import load_cm_data, get_bins, DataStructure, Balancing
from ..utils import intervals
DatasetType = torch.utils.data.Dataset if torch else object
class HiCSparseDataset(DatasetType):
    """
    A PyTorch Dataset for on-the-fly loading of Hi-C patches from sparse files.

    Attributes
    ----------
    fpath : str
        Path to the .hic or .mcool file.
    resolution : int
        The resolution to load.
    window_size : int
        Size of the genomic window (patch) in BP.
    index : pd.DataFrame
        The binnified and filtered index of training windows.
    downsample_ratio : float or tuple, optional
        Ratio for binomial subsampling. If tuple (min, max), a random ratio
        is sampled per item.
    """

    def __init__(
        self,
        fpath: str,
        resolution: int,
        window_size: int,
        blacklist: t.Optional[pd.DataFrame] = None,
        downsample_ratio: t.Union[float, t.Tuple[float, float], None] = None,
        balancing: t.Optional[Balancing] = Balancing.NONE,
        output_type: str = "sparse",
        **kwargs
    ):
        """
        Build the genomic window index for a Hi-C contact file.

        Parameters
        ----------
        fpath : str
            Path to the .hic or .mcool file.
        resolution : int
            Matrix resolution (in BP) at which pixels are fetched.
        window_size : int
            Size of each training window (patch) in BP. Also used as the
            binning step when generating the window index.
        blacklist : pd.DataFrame, optional
            Interval table of regions to exclude from the index.
        downsample_ratio : float or (float, float), optional
            Binomial subsampling ratio. A tuple is interpreted as a
            (min, max) range from which a ratio is drawn uniformly per item.
        balancing : Balancing, optional
            Balancing scheme forwarded to the loader.
        output_type : str
            ``"sparse"`` (default) returns coordinate/feature tensors;
            ``"dense"`` returns a (1, N, N) float32 patch.
        **kwargs
            Extra keyword arguments forwarded to ``load_cm_data``.

        Raises
        ------
        ImportError
            If PyTorch is not installed (the module-level import fell back
            to ``torch = None``).
        """
        # Fail fast with a clear message instead of an obscure AttributeError
        # the first time __getitem__ touches torch.
        if torch is None:
            raise ImportError(
                "HiCSparseDataset requires PyTorch; please install torch."
            )
        self.fpath = fpath
        self.resolution = resolution
        self.window_size = window_size
        self.downsample_ratio = downsample_ratio
        self.balancing = balancing
        self.output_type = output_type
        self.kwargs = kwargs
        # 1. Generate Index
        # We use window_size as the binning step for the training windows
        self.index = get_bins(fpath, window_size)
        # 2. Filter Index
        if blacklist is not None:
            self.index = intervals.subtract(self.index, blacklist)

    def __len__(self) -> int:
        """Return the number of training windows in the (filtered) index."""
        return len(self.index)

    def _downsample(
        self,
        local_r: np.ndarray,
        local_c: np.ndarray,
        counts: np.ndarray,
    ) -> t.Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Binomially subsample ``counts`` and drop entries that become zero.

        Returns the (row, col, count) arrays restricted to surviving pixels.
        """
        if isinstance(self.downsample_ratio, tuple):
            alpha = np.random.uniform(*self.downsample_ratio)
        else:
            alpha = self.downsample_ratio
        # Binomial subsampling: each read survives independently with prob alpha.
        subsampled = np.random.binomial(counts.astype(np.int32), alpha)
        # Filter zeros to maintain sparsity
        mask = subsampled > 0
        return local_r[mask], local_c[mask], subsampled[mask]

    def _to_dense(
        self,
        local_r: np.ndarray,
        local_c: np.ndarray,
        input_counts: np.ndarray,
    ) -> "torch.Tensor":
        """
        Scatter sparse pixels into a symmetric (1, N, N) dense float32 patch.

        Assumes the incoming coordinates are upper-triangular; off-diagonal
        values are mirrored below the diagonal.
        """
        n_bins = self.window_size // self.resolution
        patch = torch.zeros((1, n_bins, n_bins), dtype=torch.float32)
        if len(local_r) > 0:
            v = torch.from_numpy(input_counts).float()
            patch[0, local_r, local_c] = v
            # Symmetrize (assuming input is Upper Triangular)
            nondiag = local_r != local_c
            patch[0, local_c[nondiag], local_r[nondiag]] = v[nondiag]
        return patch

    def __getitem__(self, idx: int) -> t.Any:
        """
        Load one genomic window as a sparse dict or a dense patch.

        Parameters
        ----------
        idx : int
            Row position in ``self.index``.

        Returns
        -------
        dict or torch.Tensor
            If ``output_type == "dense"``: a (1, N, N) float32 tensor.
            Otherwise a dict with ``coords`` (M, 2) long, ``features``
            (M, 1) float, ``target`` float, and an ``info`` dict with
            chrom/start/end.
        """
        row = self.index.iloc[idx]
        chrom, start, end = row['chrom'], row['start'], row['end']
        # 1. Fetch Sparse Pixels
        # Using load_cm_data which dispatches to the correct loader.
        # We request RCV format for direct access to coordinate arrays.
        data = load_cm_data(
            self.fpath,
            resolution=self.resolution,
            region1=f"{chrom}:{start}-{end}",
            balancing=self.balancing,
            output_format=DataStructure.RCV,
            **self.kwargs
        )
        # Unpack RCV (row_ids, col_ids, counts).
        # NOTE(review): this assumes a single balancing was requested; a
        # multi-balancing result would be a tuple of RCV triples.
        r_ids, c_ids, counts = data
        # 2. Normalize Coordinates to Local Patch (0 to N-1)
        region_start_bin = start // self.resolution
        local_r = r_ids - region_start_bin
        local_c = c_ids - region_start_bin
        # 3. On-the-Fly Augmentation: Random Downsampling
        target_counts = counts.copy()
        input_counts = counts
        if self.downsample_ratio is not None:
            local_r, local_c, input_counts = self._downsample(
                local_r, local_c, counts
            )
            # NOTE(review): target_counts keeps the ORIGINAL (pre-filter)
            # length, so it is no longer aligned with coords after
            # downsampling. The model is usually expected to predict the
            # full GT from the sparse input — confirm the intended pairing
            # with the training loop before relying on "target".
        if self.output_type == "dense":
            return self._to_dense(local_r, local_c, input_counts)
        # 4. Prepare Sparse Tensor components
        # Coordinates: (N, 2)
        coords = np.stack([local_r, local_c], axis=1)
        return {
            "coords": torch.from_numpy(coords).long(),
            "features": torch.from_numpy(input_counts).float().unsqueeze(1),
            "target": torch.from_numpy(target_counts).float(),  # Placeholder
            "info": {"chrom": chrom, "start": start, "end": end}
        }
@validate_call(config=dict(arbitrary_types_allowed=True))
def sparse_collate_fn(batch: t.List[t.Dict[str, t.Any]]) -> t.Dict[str, t.Any]:
    """
    Collate sparse Hi-C samples into one batched sparse representation.

    Follows the MinkowskiEngine convention: the batch index is prepended
    as the first coordinate column, then coordinates and features from
    all samples are concatenated along dim 0.
    """
    coord_chunks: t.List[torch.Tensor] = []
    feat_chunks: t.List[torch.Tensor] = []
    infos: t.List[t.Any] = []
    for sample_idx, sample in enumerate(batch):
        sample_coords = sample["coords"]
        # Column of the sample's batch index, one entry per pixel.
        idx_col = torch.full(
            (sample_coords.shape[0], 1), sample_idx, dtype=torch.long
        )
        coord_chunks.append(torch.cat((idx_col, sample_coords), dim=1))
        feat_chunks.append(sample["features"])
        infos.append(sample["info"])
    return {
        "coords": torch.cat(coord_chunks, dim=0),
        "features": torch.cat(feat_chunks, dim=0),
        "infos": infos
    }