Tutorial: PyTorch Dataset Zoo

version: 2.25.0

This tutorial covers the PyTorch Dataset classes in gunz_cm.datasets and the spatial batch sampler in gunz_cm.samplers. After working through it you will know how to:

  • Wrap a .npdat + .json memmap as MemmapSparseDataset for streaming iteration.

  • Wrap a .hic / .cool file as HiCSparseDataset for sparse-patch iteration.

  • Wrap a .gzcm container as GzcmDataset for compressed-tile iteration with LRU caching.

  • Use SpatialBatchSampler for cache-friendly block-shuffled batches.

All examples use synthetic data only (fixed RNG seed for reproducibility).

import json
import pathlib
import tempfile
import numpy as np
import pandas as pd
import torch

from gunz_cm.consts import DataFrameSpecs, Balancing
from gunz_cm.converters import convert_to_memmap
from gunz_cm.datasets.memmap import MemmapSparseDataset

OUT = pathlib.Path(tempfile.mkdtemp())
print(f"Working in: {OUT}")

Working in: /tmp/tmpsfr8b812

1. Build a synthetic memmap

MemmapSparseDataset reads a .npdat + .json pair (the format produced by convert_to_memmap). For the tutorial we’ll build one synthetically.

# Draw random (row, col) pairs and keep only the strict upper triangle (row < col),
# so that mirroring the entries yields a valid symmetric matrix
rng = np.random.default_rng(42)
n_rows = 100
edges = rng.integers(0, n_rows, size=(200, 2))
mask = edges[:, 0] < edges[:, 1]
rows, cols = edges[mask, 0], edges[mask, 1]
counts = rng.integers(1, 100, len(rows)).astype(np.float32)

df = pd.DataFrame({
    DataFrameSpecs.ROW_IDS: rows,
    DataFrameSpecs.COL_IDS: cols,
    DataFrameSpecs.COUNTS: counts,
})

memmap_path = OUT / "synthetic_chr1"
convert_to_memmap(df, memmap_path, check_output=False)

# MemmapSparseDataset requires 'resolution' in the JSON metadata
json_path = memmap_path.with_suffix('.json')
with open(json_path) as f:
    meta = json.load(f)
meta['resolution'] = 50_000  # bp
with open(json_path, 'w') as f:
    json.dump(meta, f, indent=4, sort_keys=True)
print(f"Memmap: {memmap_path}.npdat ({memmap_path.with_suffix('.npdat').stat().st_size} bytes)")
print(f"JSON:   {json_path}")
print(f"Shape:  {meta['shape']}, resolution: {meta['resolution']} bp")

Memmap: /tmp/tmpsfr8b812/synthetic_chr1.npdat (40000 bytes)
JSON:   /tmp/tmpsfr8b812/synthetic_chr1.json
Shape:  [100, 100], resolution: 50000 bp
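
The 40,000-byte file size reported above is what a dense 100 x 100 array of 4-byte float32 values would occupy. The sketch below illustrates that idea with the standard library only. The row-major float32 layout, the toy file names, and the dtype key are assumptions for illustration, not a description of what convert_to_memmap actually writes.

```python
import json
import mmap
import pathlib
import struct
import tempfile

# Assumed layout (hypothetical): dense, row-major float32 plus a JSON sidecar.
n_bins = 100
itemsize = struct.calcsize("<f")  # 4 bytes per float32 value

workdir = pathlib.Path(tempfile.mkdtemp())
data_path = workdir / "toy.npdat"
meta_path = workdir / "toy.json"

# Write an all-zero matrix, then store one contact: counts[3, 7] = 12.0
with open(data_path, "wb") as f:
    f.write(bytes(n_bins * n_bins * itemsize))
with open(data_path, "r+b") as f:
    f.seek((3 * n_bins + 7) * itemsize)
    f.write(struct.pack("<f", 12.0))
meta_path.write_text(json.dumps({"shape": [n_bins, n_bins], "dtype": "float32"}))

# Read one value back through a memory map, without loading the whole file
n_rows, n_cols = json.loads(meta_path.read_text())["shape"]
with open(data_path, "rb") as f:
    with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        offset = (3 * n_cols + 7) * itemsize
        (value,) = struct.unpack("<f", mm[offset:offset + itemsize])

file_size = data_path.stat().st_size
print(file_size, value)
```

The point of the memory map is that only the pages touched by the slice are read from disk, which is what makes windowed access to a matrix larger than RAM practical.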

2. MemmapSparseDataset: streaming iteration over a memmap

Use this dataset when you have a large Hi-C matrix that doesn’t fit in memory. The dataset reads windows of the memmap on-the-fly during training.

Each __getitem__ returns a dict with:

  • coords: (N, 2) int64 tensor of (row_id, col_id)

  • features: (N, 1) float32 tensor of raw counts

  • target: (N,) float32 tensor of the target value (balanced counts if balancing is set)

  • info: dict with {chrom, start, end} for the window
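
To make the layout concrete, here is a minimal sketch that builds the same four keys from a small dense patch, using plain Python lists in place of tensors. The helper name dense_patch_to_sample and the use of raw counts as the target are assumptions for illustration; the library's own extraction logic is not shown in this tutorial.

```python
def dense_patch_to_sample(dense, row_offset, col_offset, chrom, start, end):
    """Convert a dense 2-D patch (list of lists) into the sparse sample layout."""
    coords, features, target = [], [], []
    for i, row in enumerate(dense):
        for j, count in enumerate(row):
            if count != 0:
                coords.append([row_offset + i, col_offset + j])  # one (row_id, col_id) pair
                features.append([float(count)])                  # layout (N, 1)
                target.append(float(count))                      # layout (N,)
    info = {"chrom": chrom, "start": start, "end": end}
    return {"coords": coords, "features": features, "target": target, "info": info}


dense = [
    [0, 5, 0, 0],
    [0, 0, 0, 2],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
]
toy_sample = dense_patch_to_sample(dense, row_offset=4, col_offset=4,
                                   chrom="chr1", start=200_000, end=400_000)
empty_sample = dense_patch_to_sample([[0] * 4 for _ in range(4)], 0, 0,
                                     "chr1", 0, 200_000)
print(toy_sample["coords"], len(empty_sample["coords"]))
```

A window without any contacts yields N = 0, which is how a coords shape of (0, 2) arises.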

ds = MemmapSparseDataset(
    fpath=str(memmap_path),
    window_size=200_000,  # 200 kb window = 4 bins at 50 kb
    balancing=Balancing.NONE,  # use raw counts as target
)
print(f"Dataset length: {len(ds)} (one item per genomic window)")
print(f"Window size: {ds.window_size} bp")

sample = ds[0]
print(f"\nFirst sample:")
for k, v in sample.items():
    if hasattr(v, 'shape'):
        print(f"  {k}: shape={tuple(v.shape)}, dtype={v.dtype}")
    else:
        print(f"  {k}: {v}")

Dataset length: 25 (one item per genomic window)
Window size: 200000 bp

First sample:
  coords: shape=(0, 2), dtype=torch.int64
  features: shape=(0, 1), dtype=torch.float32
  target: shape=(0,), dtype=torch.float32
  info: {'chrom': 'unknown', 'start': 0, 'end': 200000}

# Inspect the first three samples to check that shapes are consistent
for i in range(min(3, len(ds))):
    s = ds[i]
    print(f"Sample {i}: coords={tuple(s['coords'].shape)}, "
          f"features={tuple(s['features'].shape)}, target={tuple(s['target'].shape)}, "
          f"info={s['info']}")

Sample 0: coords=(0, 2), features=(0, 1), target=(0,), info={'chrom': 'unknown', 'start': 0, 'end': 200000}
Sample 1: coords=(0, 2), features=(0, 1), target=(0,), info={'chrom': 'unknown', 'start': 200000, 'end': 400000}
Sample 2: coords=(0, 2), features=(0, 1), target=(0,), info={'chrom': 'unknown', 'start': 400000, 'end': 600000}
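
The dataset length and the start/end values printed above are consistent with non-overlapping windows that tile the axis: 100 bins at 50 kb, grouped 4 bins at a time, give 25 windows. The sketch below reproduces that arithmetic. It also estimates, under the further assumption (not confirmed by the source) that each window covers only the square diagonal block of bins it spans, what share of the upper-triangle pairs falls inside any window.

```python
n_bins = 100            # matrix side length, as in meta['shape']
resolution = 50_000     # bp per bin
window_size = 200_000   # bp per window

bins_per_window = window_size // resolution
# Exact division here; a real implementation has to decide how to treat a partial last window
n_windows = n_bins // bins_per_window
windows = [(i * window_size, (i + 1) * window_size) for i in range(n_windows)]

# Assumption: a window only sees the diagonal block of bins it spans.
# Count the (row < col) pairs that fall inside some window versus all such pairs.
pairs_per_window = bins_per_window * (bins_per_window - 1) // 2
pairs_in_windows = n_windows * pairs_per_window
pairs_total = n_bins * (n_bins - 1) // 2
fraction = pairs_in_windows / pairs_total

print(bins_per_window, n_windows, windows[:3])
print(pairs_in_windows, pairs_total, round(fraction, 4))
```

Under that assumption only about 3% of upper-triangle pairs lie inside any window, so with roughly a hundred random contacts most windows would hold none, which would be consistent with the empty samples above.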

3. HiCSparseDataset: sparse patches directly from a Hi-C file

Use this dataset when you want to iterate patches directly from a .hic / .cool file without converting to memmap first. Each item is a sparse patch of the matrix.

Note: __getitem__ requires a real Hi-C file. The constructor normally reads the file's bin index through get_bins, so here we patch get_bins with a synthetic index and demonstrate only the constructor call and the resulting attributes.

from unittest.mock import patch

from gunz_cm.datasets import HiCSparseDataset

# Mock the get_bins call inside HiCSparseDataset.__init__ since it would
# try to read the .hic file's bin index. We provide a synthetic index.
synthetic_index = pd.DataFrame({
    "chrom": ["chr1"] * 5,
    "start": [0, 200_000, 400_000, 600_000, 800_000],
    "end":   [200_000, 400_000, 600_000, 800_000, 1_000_000],
})

PLACEHOLDER = OUT / "placeholder.hic"
PLACEHOLDER.touch()
try:
    with patch("gunz_cm.datasets.hic.get_bins", return_value=synthetic_index):
        ds_hic = HiCSparseDataset(
            fpath=str(PLACEHOLDER),
            bin_size_bp=50_000,
            window_size=200_000,
            balancing=Balancing.KR,
            output_type="sparse",
        )
    print(f"\nDataset object created (init skipped real file read via mock).")
    print(f"fpath: {ds_hic.fpath}")
    print(f"bin_size_bp: {ds_hic.bin_size_bp}")
    print(f"window_size: {ds_hic.window_size}")
    print(f"synthetic index len: {len(ds_hic.index)}")
finally:
    PLACEHOLDER.unlink()

Dataset object created (init skipped real file read via mock).
fpath: /tmp/tmpsfr8b812/placeholder.hic
bin_size_bp: 50000
window_size: 200000
synthetic index len: 5

4. GzcmDataset: compressed-tile iteration with LRU cache

Use this dataset when you have a GZCM v1/v2/v3/v4 container (.gzcm file). Each item is a window of the matrix; the dataset caches recently-decoded tiles in an LRU cache (tile_cache_size parameter, default 256).

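Note: a real GZCM file is needed both to construct the dataset and to call __getitem__. Because __init__ reads the file, the example below replaces __init__ with a mock and sets the attributes by hand, so it demonstrates only the call signature.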

from gunz_cm.datasets import GzcmDataset
import inspect

print(f"GzcmDataset.__init__ signature:")
print(inspect.signature(GzcmDataset.__init__))

PLACEHOLDER_GZCM = OUT / "placeholder.gzcm"
PLACEHOLDER_GZCM.touch()
try:
    # GzcmDataset reads the file in __init__, so patch __init__ out entirely
    from unittest.mock import patch
    with patch.object(GzcmDataset, '__init__', return_value=None):
        ds_gzcm = GzcmDataset(
            fpath=str(PLACEHOLDER_GZCM),
            window_size=200_000,
            output_type='sparse',
            downsample_ratio=None,
            decompress=True,
            tile_cache_size=256,
        )
    # Manually set the attributes the constructor would have set
    ds_gzcm.fpath = str(PLACEHOLDER_GZCM)
    ds_gzcm.window_size = 200_000
    ds_gzcm.tile_cache_size = 256
    print(f"\nGzcmDataset created (init skipped real file read via mock).")
    print(f"tile_cache_size={ds_gzcm.tile_cache_size}")
    print(f"window_size={ds_gzcm.window_size}")
finally:
    PLACEHOLDER_GZCM.unlink()

GzcmDataset.__init__ signature:
(self, fpath: str, window_size: int, output_type: str = 'sparse', downsample_ratio: float | tuple[float, float] | None = None, decompress: bool = True, tile_cache_size: int = 256)

GzcmDataset created (init skipped real file read via mock).
tile_cache_size=256
window_size=200000
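
Since the mocked constructor cannot show the cache in action, here is a minimal sketch of what an LRU tile cache does in general. The class TileLRU, its decode_fn callback, and the (tile_row, tile_col) keys are hypothetical names for illustration; this is not GzcmDataset's actual implementation.

```python
from collections import OrderedDict


class TileLRU:
    """Toy LRU cache for decoded tiles (hypothetical, not the gunz_cm class)."""

    def __init__(self, capacity, decode_fn):
        self.capacity = capacity
        self.decode_fn = decode_fn
        self._tiles = OrderedDict()  # oldest entry first, newest last
        self.misses = 0

    def get(self, key):
        if key in self._tiles:
            self._tiles.move_to_end(key)  # mark as most recently used
            return self._tiles[key]
        self.misses += 1
        tile = self.decode_fn(key)  # the expensive decompression step
        self._tiles[key] = tile
        if len(self._tiles) > self.capacity:
            self._tiles.popitem(last=False)  # evict the least recently used tile
        return tile


cache = TileLRU(capacity=2, decode_fn=lambda key: f"decoded{key}")
for key in [(0, 0), (0, 1), (0, 0), (1, 1), (0, 1)]:
    cache.get(key)
print(cache.misses, list(cache._tiles))
```

With a capacity of 2, the third request is a hit, but tile (0, 1) is evicted before it is requested again, so four of the five requests decode a tile. Access order matters as much as cache size, which is the motivation for a locality-aware sampler.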

5. SpatialBatchSampler: cache-friendly block-shuffled batches

When your dataset has spatial locality (e.g. consecutive windows of a chromosome), a SpatialBatchSampler groups nearby items into the same batch to maximize tile-cache reuse.

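It takes a dataset_index DataFrame with at least a chrom column and a start (or pos) column, and uses them to place items that lie close together on the same chromosome in the same batch where possible. A batch can still span several chromosomes, as the printed batches in this section show.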

from gunz_cm.samplers import SpatialBatchSampler

# Build a synthetic dataset_index with chromosome + position
n_samples = 200
dataset_index = pd.DataFrame({
    'chrom': rng.choice(['chr1', 'chr2', 'chr3'], n_samples),
    'pos': rng.integers(0, 5_000_000, n_samples),
    'window_idx': np.arange(n_samples),
})
print(f"dataset_index: {dataset_index.shape}")
print(dataset_index.head())

sampler = SpatialBatchSampler(
    dataset_index=dataset_index,
    batch_size=32,
    block_size=128,  # group samples within 128 bins for tile locality
    shuffle=True,
    drop_last=False,
)
print(f"\nSampler length: {len(sampler)} batches")

# Iterate the first 3 batches
for i, batch in enumerate(sampler):
    if i >= 3:
        break
    batch_chroms = dataset_index.iloc[batch]['chrom'].value_counts().to_dict()
    print(f"Batch {i}: {len(batch)} samples, chroms: {batch_chroms}")

dataset_index: (200, 3)
  chrom      pos  window_idx
0  chr1  3609282           0
1  chr2  2950379           1
2  chr3   250730           2
3  chr2   475245           3
4  chr1  2602736           4

Sampler length: 7 batches
Batch 0: 32 samples, chroms: {'chr2': 15, 'chr3': 9, 'chr1': 8}
Batch 1: 32 samples, chroms: {'chr2': 12, 'chr3': 12, 'chr1': 8}
Batch 2: 32 samples, chroms: {'chr3': 14, 'chr2': 10, 'chr1': 8}
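
# Check that the 'pos' values within each batch are clustered per chromosome.
# Caveat: the threshold is 128 * 50_000 = 6,400,000 bp, but 'pos' is drawn
# from [0, 5,000,000), so this check cannot fail on this synthetic index.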
all_within_block = True
for i, batch in enumerate(sampler):
    if i >= 5:
        break
    sub = dataset_index.iloc[batch]
    for chrom, group in sub.groupby('chrom'):
        pos_range = group['pos'].max() - group['pos'].min()
        if pos_range > 128 * 50_000:  # 128 bins * 50 kb bin size
            all_within_block = False
            print(f"  Batch {i} chrom={chrom} spans {pos_range:,} bp (exceeds 128 bins)")
print(f"\nAll sampled batches are spatially-coherent (within 128 bins): {all_within_block}")

All sampled batches are spatially-coherent (within 128 bins): True
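
For intuition about what "block-shuffled" can mean, the sketch below shows one plausible scheme in plain Python: sort indices by a locality key, cut the sorted order into blocks, shuffle the block order and the order within each block, then slice each block into batches. The function name and the choice to measure block_size in samples rather than bins are assumptions for illustration; SpatialBatchSampler's actual algorithm is not shown in this tutorial and may differ.

```python
import random


def block_shuffled_batches(keys, batch_size, block_size, seed=0):
    """Return batches of indices; `keys` are sortable locality keys such as (chrom, pos)."""
    rng = random.Random(seed)
    order = sorted(range(len(keys)), key=lambda i: keys[i])
    # Blocks of neighbouring items in sorted order (block_size counts samples here)
    blocks = [order[i:i + block_size] for i in range(0, len(order), block_size)]
    rng.shuffle(blocks)              # randomise the order of the blocks
    batches = []
    for block in blocks:
        rng.shuffle(block)           # randomise the order inside each block
        batches.extend(block[i:i + batch_size] for i in range(0, len(block), batch_size))
    return batches


key_rng = random.Random(42)
keys = [(key_rng.choice(["chr1", "chr2", "chr3"]), key_rng.randrange(5_000_000))
        for _ in range(200)]
batches = block_shuffled_batches(keys, batch_size=32, block_size=128)

# Rank of every index in the sorted order, used to check locality
rank = {idx: r for r, idx in enumerate(sorted(range(len(keys)), key=lambda i: keys[i]))}
max_rank_span = max(max(rank[i] for i in b) - min(rank[i] for i in b) for b in batches)
print(len(batches), max_rank_span < 128)
```

Every batch is drawn from a single block, so its members are never more than block_size positions apart in sorted order, while the two shuffles keep the batch composition random from epoch to epoch if the seed changes.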

Summary

Quick reference for choosing a dataset class or sampler:

| Use case | Class | Notes |
|---|---|---|
| Large Hi-C matrix that doesn’t fit in memory | MemmapSparseDataset | Convert first with convert_to_memmap |
| Want sparse patches directly from .hic / .cool | HiCSparseDataset | No pre-conversion; reads from the source file each __getitem__ |
| Have a GZCM container (compressed, partial-read friendly) | GzcmDataset | Built-in LRU tile cache (default 256 entries) |
| Want cache-friendly batches (block-shuffle) | SpatialBatchSampler | Pass a dataset_index with chrom + pos columns |

All three dataset classes return a dict with coords (int64), features (float32), target (float32), and info (dict with chrom/start/end).