# Source code for gunz_cm.samplers.spatial

# -*- coding: utf-8 -*-
"""
Samplers for spatial data locality optimization.


Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"

import numpy as np
import pandas as pd
from torch.utils.data import Sampler
import typing as t

class SpatialBatchSampler(Sampler[t.List[int]]):
    """A BatchSampler that yields mini-batches ordered by spatial proximity
    to maximize cache hits in compressed files.

    To maintain randomness for SGD, it performs 'Block Shuffling':

    1. Sorts the dataset spatially.
    2. Groups indices into 'mega-blocks' (e.g., 50-100 samples).
    3. Shuffles the order of mega-blocks.
    4. Yields sequential mini-batches from within each mega-block.
    """

    def __init__(
        self,
        dataset_index: pd.DataFrame,
        batch_size: int,
        block_size: int = 128,
        shuffle: bool = True,
        drop_last: bool = False,
    ):
        """
        Args:
            dataset_index: DataFrame describing sample locations. Uses
                'chrom' + 'start' columns when present, falling back to a
                'start_bin' column (MemmapDataset style), and finally to
                plain linear order when neither is available.
            batch_size: Mini-batch size.
            block_size: Number of samples to keep contiguous (should be
                >> batch_size). Larger = better cache locality, worse
                randomness.
            shuffle: Whether to shuffle the order of the mega-blocks.
            drop_last: Whether to drop the last incomplete batch.
        """
        self.index = dataset_index
        self.batch_size = batch_size
        self.block_size = block_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        # Pre-calculate spatial sort order: primary key 'chrom',
        # secondary key 'start' (np.lexsort treats the LAST key as primary).
        if 'chrom' in self.index.columns and 'start' in self.index.columns:
            # String sort puts 'chr10' before 'chr2', but exact natural
            # chromosome order is irrelevant for clustering purposes.
            self.sorted_indices = np.lexsort(
                (self.index['start'], self.index['chrom'])
            )
        elif 'start_bin' in self.index.columns:
            # MemmapDataset style: a single precomputed spatial bin column.
            self.sorted_indices = np.argsort(self.index['start_bin'].values)
        else:
            # Fallback to linear order if no spatial columns are present.
            self.sorted_indices = np.arange(len(self.index))

    def __len__(self) -> int:
        """Return the number of batches yielded per epoch.

        Counts the trailing partial batch unless ``drop_last`` is set.
        """
        if self.drop_last:
            return len(self.index) // self.batch_size
        # Ceiling division to include the final partial batch.
        return (len(self.index) + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> t.Iterator[t.List[int]]:
        """Yield mini-batches (lists of int indices) in block-shuffled order.

        Each yielded batch is spatially contiguous; only the order of the
        mega-blocks is randomized (via the global NumPy RNG).
        """
        indices = self.sorted_indices
        # Guard len > 0: np.array_split raises on zero sections, so an
        # empty dataset must skip the shuffling step entirely.
        if self.shuffle and len(indices) > 0:
            # Split the sorted order into ~block_size chunks; array_split
            # tolerates lengths that are not a multiple of block_size.
            n_blocks = (len(indices) + self.block_size - 1) // self.block_size
            blocks = np.array_split(indices, n_blocks)
            # Permute chunk order while keeping each chunk internally sorted.
            order = np.random.permutation(n_blocks)
            indices = np.concatenate([blocks[i] for i in order])

        # Batchify: emit plain Python ints so downstream code is not
        # surprised by numpy integer types.
        batch: t.List[int] = []
        for idx in indices:
            batch.append(int(idx))
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch