"""
Dataset classes for extracting fixed-size patches from memory-mapped arrays.
Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
import typing as t
import numpy as np
from torch.utils.data import Dataset
from ...loaders import load_memmap
[docs]
class MemmapDatasetV1(Dataset):
"""
A PyTorch Dataset class for loading patches from a memory-mapped file.
Args:
memmap_fpath (str): The path to the memory-mapped file.
win_size (int): The size of the window/patch to extract from the memory-mapped file.
stride_size (int, optional): The stride size to use when extracting patches. Defaults to win_size.
transformations (callable, optional): A function/transform to apply to the data. Defaults to None.
Attributes:
_f (numpy.memmap): The memory-mapped file.
win_size (int): The size of the window/patch to extract from the memory-mapped file.
stride_size (int): The stride size to use when extracting patches.
transformations (callable): A function/transform to apply to the data.
npatches_in_row (int): Number of patch positions along the row axis (axis 0), i.e. the number of rows in the patch grid.
npatches_in_col (int): Number of patch positions along the column axis (axis 1), i.e. the number of columns in the patch grid.
Examples
--------
"""
# TODO: Check for a possible bug in which a patch at the border is smaller than win_size.
def __init__(self,
             memmap_fpath: str,
             win_size: int,
             stride_size: t.Optional[int] = None,
             transformations: t.Optional[t.Callable] = None
             ):
    """
    Initialize the MemmapDataset class.

    Parameters
    ----------
    memmap_fpath : str
        Path to the memory-mapped file, opened with ``load_memmap``. The
        loaded array must be 2-D.
    win_size : int
        Side length of the square patch (window). Must be a positive integer.
    stride_size : int, optional
        Step between the top-left corners of neighbouring patches. If None,
        it is set equal to ``win_size`` (non-overlapping patches). Must be a
        positive integer.
    transformations : callable, optional
        Transformations to apply to the data.

    Raises
    ------
    ValueError
        If ``win_size`` or ``stride_size`` is not a positive integer, or if
        the loaded array is not 2-D.
    """
    # Validate arguments before touching the file so bad input fails fast.
    # NumPy integer scalars are accepted alongside built-in ints.
    if not (isinstance(win_size, (int, np.integer)) and win_size > 0):
        raise ValueError("win_size must be a positive integer.")
    if stride_size is None:
        stride_size = win_size
    if not (isinstance(stride_size, (int, np.integer)) and stride_size > 0):
        #? Mandatory check to prevent downstream injection
        raise ValueError("stride_size must be a positive integer.")

    self._f = load_memmap(memmap_fpath)
    self.win_size = int(win_size)
    self.stride_size = int(stride_size)
    self.transformations = transformations

    if len(self._f.shape) != 2:
        raise ValueError("The memory-mapped array must be 2-D.")
    nrows, ncols = self._f.shape

    # Number of full windows that fit along each axis. Integer floor division
    # avoids float rounding on large arrays. max(0, ...) keeps the count valid
    # when win_size exceeds the axis length; otherwise it would go negative
    # and corrupt __len__.
    self.npatches_in_row = max(0, (nrows - self.win_size) // self.stride_size + 1)
    self.npatches_in_col = max(0, (ncols - self.win_size) // self.stride_size + 1)
def __len__(self) -> int:
"""
Get the total number of patches in the dataset.
Returns
-------
int
Total number of patches.
Examples
--------
"""
return self.npatches_in_row * self.npatches_in_col
@property
def shape(self):
"""
Returns the shape of the dataset as a tuple.
Returns
-------
tuple
The shape of the dataset as a tuple (npatches_in_row, npatches_in_col).
Examples
--------
"""
return (self.npatches_in_row, self.npatches_in_col)
def _load_data_from_mem(self,
i:int,
j:int,
):
"""
Private method to load data from memory based on the given indices.
Parameters
----------
i : int
Row index.
j : int
Column index.
Returns
-------
dict
A dictionary containing the loaded patch data under the key 'X'.
Examples
--------
"""
row_slice = slice(self.stride_size*i, self.stride_size*i+self.win_size)
col_slice = slice(self.stride_size*j, self.stride_size*j+self.win_size)
patch_mat = self._f[row_slice, col_slice]
data = dict()
data['X'] = patch_mat
return data
[docs]
def load_data(self,
idx:int
):
"""
Load a patch from the memory-mapped file based on the given index.
Parameters
----------
idx : int
Index of the patch.
Returns
-------
dict
A dictionary containing the loaded patch data under the key 'X'.
Examples
--------
"""
if idx >= len(self):
#? Mandatory check to prevent downstream injection
raise ValueError("idx is greater than the length!")
i, j = np.divmod(idx, self.npatches_in_col)
return self._load_data_from_mem(i, j)
def __getitem__(self,
                idx: int,
                ):
    """
    Return the sample at ``idx``.

    Parameters
    ----------
    idx : int
        Flat index of the patch.

    Returns
    -------
    dict
        ``{'input': self.load_data(idx)}``, so the patch is reachable as
        ``item['input']['X']``. If ``self.any_transformation()`` is truthy,
        the dict is passed through ``self.transform_item`` and its result
        is returned instead.

    Raises
    ------
    ValueError
        If ``idx >= len(self)``.
    """
    if idx >= len(self):
        #? Mandatory check to prevent downstream injection
        raise ValueError("idx is greater than the length!")
    item = {'input': self.load_data(idx)}
    # NOTE(review): any_transformation / transform_item are not defined in the
    # visible code; presumably they inspect/apply self.transformations. Confirm.
    return self.transform_item(item) if self.any_transformation() else item