# -*- coding: utf-8 -*-
"""
Samplers for spatial data locality optimization.
Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
import numpy as np
import pandas as pd
from torch.utils.data import Sampler
import typing as t
# NOTE(review): stray "[docs]" Sphinx link artifact from HTML extraction
# commented out — as bare code it raised NameError at import time.
# [docs]
class SpatialBatchSampler(Sampler[t.List[int]]):
    """A BatchSampler that yields mini-batches ordered by spatial proximity
    to maximize cache hits in compressed files.

    To retain randomness for SGD it performs "block shuffling":

    1. Sort the dataset spatially (chrom, then start).
    2. Group the sorted indices into mega-blocks of ``block_size`` samples.
    3. Shuffle the order of the mega-blocks.
    4. Yield sequential mini-batches from the flattened block sequence.

    Note: because batches are cut from the flattened sequence, a mini-batch
    may straddle two adjacent (shuffled) blocks when a block's length is not
    a multiple of ``batch_size``.
    """

    def __init__(
        self,
        dataset_index: pd.DataFrame,
        batch_size: int,
        block_size: int = 128,
        shuffle: bool = True,
        drop_last: bool = False,
        seed: t.Optional[int] = None,
    ):
        """
        Args:
            dataset_index: DataFrame with 'chrom', 'start', 'end' columns
                (or a 'start_bin' column for MemmapDataset-style indices).
            batch_size: Mini-batch size.
            block_size: Number of samples to keep contiguous (should be
                >> batch_size). Larger = better cache locality, worse
                randomness.
            shuffle: Whether to shuffle the blocks.
            drop_last: Whether to drop the last incomplete batch.
            seed: Optional seed for reproducible block shuffling. When None
                (default), NumPy's global RNG is used, so every epoch gets a
                different permutation (original behavior). When set, the same
                seeded permutation is produced each epoch.
        """
        self.index = dataset_index
        self.batch_size = batch_size
        self.block_size = block_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.seed = seed

        # Pre-compute the spatial sort order once; __iter__ only permutes it.
        cols = self.index.columns
        if 'chrom' in cols and 'start' in cols:
            # lexsort uses the LAST key as primary: sort by chrom, then start.
            # String sort of chrom ("1", "10", "2", ...) is not natural order,
            # but it still clusters each chromosome together, which is all
            # that matters for cache locality.
            self.sorted_indices = np.lexsort(
                (self.index['start'].values, self.index['chrom'].values)
            )
        elif 'start_bin' in cols:  # MemmapDataset style
            self.sorted_indices = np.argsort(self.index['start_bin'].values)
        else:
            # No spatial columns: fall back to the original linear order.
            self.sorted_indices = np.arange(len(self.index))

    def __len__(self) -> int:
        """Return the number of batches yielded per epoch."""
        n = len(self.index)
        if self.drop_last:
            return n // self.batch_size
        # Ceiling division: a trailing partial batch still counts.
        return (n + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> t.Iterator[t.List[int]]:
        """Yield lists of dataset indices, one mini-batch at a time."""
        indices = self.sorted_indices
        # Guard len > 0: np.array_split(x, 0) raises ValueError, so an
        # empty dataset previously crashed here when shuffle=True.
        if self.shuffle and len(indices) > 0:
            # Split the sorted order into mega-blocks and shuffle the block
            # order, keeping samples inside each block spatially contiguous.
            n_blocks = (len(indices) + self.block_size - 1) // self.block_size
            blocks = np.array_split(indices, n_blocks)
            rng = np.random if self.seed is None else np.random.default_rng(self.seed)
            perm = rng.permutation(n_blocks)
            indices = np.concatenate([blocks[p] for p in perm])

        # Batchify the (possibly block-shuffled) index sequence.
        batch: t.List[int] = []
        for idx in indices:
            batch.append(int(idx))
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch