# Source code for gunz_cm.resolution_enhancements.datasets.memmap_v1

"""
Module.

Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
import typing as t
import numpy as np
from torch.utils.data import Dataset

from ...loaders import load_memmap

class _PatchIndexError(IndexError, ValueError):
    """Raised when a patch index is outside ``[-len(dataset), len(dataset))``.

    Subclasses ``IndexError`` so Python's sequence-iteration protocol
    (``for x in dataset``, ``list(dataset)``) stops cleanly at the end of the
    dataset. Subclasses ``ValueError`` so callers that catch the error type
    this module raised previously keep working.
    """


class MemmapDatasetV1(Dataset):
    """
    A PyTorch Dataset that serves square patches of a memory-mapped matrix.

    The matrix is tiled with ``win_size x win_size`` windows. Their top-left
    corners lie on a grid with spacing ``stride_size`` along both axes. Only
    windows that fit completely inside the matrix are produced, so every patch
    has the full ``win_size x win_size`` shape. Trailing rows/columns that
    cannot hold a further full window are not covered. Patches are numbered
    row-major over the patch grid.

    Parameters
    ----------
    memmap_fpath : str
        Path to the memory-mapped file, opened with ``load_memmap``.
    win_size : int
        Side length of each square patch. Must be a positive integer.
    stride_size : int, optional
        Step between consecutive patch origins. Must be a positive integer.
        Defaults to ``win_size`` (non-overlapping patches).
    transformations : callable, optional
        Applied to each item dict in ``__getitem__``. Defaults to None.

    Attributes
    ----------
    _f : array-like
        Matrix returned by ``load_memmap``. It is expected to be a 2-D
        ``numpy.memmap``.
    win_size : int
        Side length of each square patch.
    stride_size : int
        Step between consecutive patch origins.
    transformations : callable or None
        Optional transform applied to each item.
    npatches_in_row : int
        Number of patch positions along the first (row) axis.
    npatches_in_col : int
        Number of patch positions along the second (column) axis.
    """

    def __init__(
        self,
        memmap_fpath: str,
        win_size: int,
        stride_size: t.Optional[int] = None,
        transformations: t.Optional[t.Callable] = None,
    ):
        """
        Initialize the dataset and compute the patch grid.

        Parameters
        ----------
        memmap_fpath : str
            Path to the memory-mapped file.
        win_size : int
            Side length of each square patch.
        stride_size : int, optional
            Stride between patch origins. If None, it equals ``win_size``.
        transformations : callable, optional
            Transformations to apply to each data item.

        Raises
        ------
        ValueError
            If ``win_size`` or ``stride_size`` is not a positive integer.
        """
        if stride_size is None:
            stride_size = win_size
        # Validate before opening the file so bad arguments fail fast.
        self.win_size = self._check_positive_int("win_size", win_size)
        self.stride_size = self._check_positive_int("stride_size", stride_size)
        self.transformations = transformations

        self._f = load_memmap(memmap_fpath)
        # NOTE(review): the unpacking below requires load_memmap to return a
        # 2-D array; confirm against load_memmap.
        nrows, ncols = self._f.shape

        self.npatches_in_row = self._count_patches(nrows, self.win_size, self.stride_size)
        self.npatches_in_col = self._count_patches(ncols, self.win_size, self.stride_size)

    @staticmethod
    def _check_positive_int(name: str, value) -> int:
        """Return ``value`` as a Python ``int``, or raise if it is not a positive integer."""
        if not (isinstance(value, (int, np.integer)) and value > 0):
            raise ValueError(f"{name} must be a positive integer.")
        return int(value)

    @staticmethod
    def _count_patches(length: int, win_size: int, stride_size: int) -> int:
        """
        Count the full windows that fit along one axis.

        Window ``k`` spans ``[k * stride_size, k * stride_size + win_size)``.
        The result is the number of such windows whose end is ``<= length``.
        """
        # An axis shorter than one window holds no patches. Without this clamp
        # the count goes negative, and two negative axis counts would multiply
        # into a bogus positive dataset length.
        if length < win_size:
            return 0
        return (length - win_size) // stride_size + 1

    def __len__(self) -> int:
        """
        Get the total number of patches in the dataset.

        Returns
        -------
        int
            ``npatches_in_row * npatches_in_col``.
        """
        return self.npatches_in_row * self.npatches_in_col

    @property
    def shape(self):
        """
        Shape of the patch grid.

        Returns
        -------
        tuple
            ``(npatches_in_row, npatches_in_col)``.
        """
        return (self.npatches_in_row, self.npatches_in_col)

    def any_transformation(self):
        """
        Check whether a transformation is configured.

        Returns
        -------
        bool
            True if ``transformations`` is not None, False otherwise.
        """
        return self.transformations is not None

    def _normalize_idx(self, idx: int) -> int:
        """
        Map a possibly negative flat patch index onto ``[0, len(self))``.

        Raises
        ------
        IndexError, ValueError
            If ``idx`` is out of range. The raised exception is an instance
            of both types.
        """
        n_patches = len(self)
        normalized = idx + n_patches if idx < 0 else idx
        if not 0 <= normalized < n_patches:
            raise _PatchIndexError(
                f"idx {idx} is out of range for dataset of length {n_patches}!"
            )
        return normalized

    def _load_data_from_mem(self, i: int, j: int):
        """
        Slice the patch at patch-grid position ``(i, j)`` out of the matrix.

        Parameters
        ----------
        i : int
            Patch-grid row index. The patch starts at matrix row ``stride_size * i``.
        j : int
            Patch-grid column index. The patch starts at matrix column ``stride_size * j``.

        Returns
        -------
        dict
            ``{'X': patch}``, where ``patch`` is ``_f`` sliced to a
            ``win_size x win_size`` window. For a NumPy memmap, basic slicing
            like this yields a view rather than a copy.
        """
        row_slice = slice(self.stride_size * i, self.stride_size * i + self.win_size)
        col_slice = slice(self.stride_size * j, self.stride_size * j + self.win_size)
        patch_mat = self._f[row_slice, col_slice]
        data = dict()
        data['X'] = patch_mat
        return data

    def load_data(self, idx: int):
        """
        Load the patch with flat (row-major) index ``idx``.

        Parameters
        ----------
        idx : int
            Flat patch index. Negative values count from the end, as for
            Python sequences.

        Returns
        -------
        dict
            ``{'X': patch}``.

        Raises
        ------
        IndexError, ValueError
            If ``idx`` is out of range.
        """
        idx = self._normalize_idx(idx)
        i, j = divmod(idx, self.npatches_in_col)
        return self._load_data_from_mem(i, j)

    def transform_item(self, data: t.Dict) -> t.Dict:
        """
        Apply the configured transformation to a data item.

        Parameters
        ----------
        data : dict
            Input data item.

        Returns
        -------
        dict
            Whatever ``self.transformations(data)`` returns.
        """
        return self.transformations(data)

    def __getitem__(self, idx: int):
        """
        Get the data item at the given index.

        Parameters
        ----------
        idx : int
            Flat patch index. Negative values count from the end.

        Returns
        -------
        dict
            ``{'input': {'X': patch}}``, passed through ``transformations``
            when one is configured.

        Raises
        ------
        IndexError, ValueError
            If ``idx`` is out of range. Raising ``IndexError`` lets plain
            iteration over the dataset terminate.
        """
        data = {'input': self.load_data(idx)}
        if self.any_transformation():
            data = self.transform_item(data)
        return data