# Source code for gunz_cm.resolution_enhancements.datasets.ren_memmap_v1

"""
Memmap-backed patch dataset for resolution enhancement (version 1).

Defines :class:`RENMemmapDatasetV1`, a PyTorch ``Dataset`` that cuts square
``win_size x win_size`` patches on a strided grid out of a 2-D
high-resolution (HR) matrix loaded with ``loaders.load_memmap``. Each HR
patch is paired with a low-resolution (LR) counterpart. The LR data comes
either from a second memmap file of identical shape or, when no LR file is
given, from the HR data downscaled on the fly with
``preprocs.downscale_counts``.

Examples
--------
>>> ds = RENMemmapDatasetV1("path/to/hr_matrix", win_size=40,
...                         lr_ds_ratio=16)  # doctest: +SKIP
>>> item = ds[0]  # doctest: +SKIP
>>> sorted(item)  # doctest: +SKIP
['X', 'Y']
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
import typing as t
import numpy as np
from torch.utils.data import Dataset
from ... import loaders as cm_loaders
from .. import preprocs as ren_preprocs

class RENMemmapDatasetV1(Dataset):
    """
    Dataset of paired low-/high-resolution patches cut from 2-D memmaps.

    Square patches of size ``win_size x win_size`` are taken on a regular
    grid with step ``stride_size``. Only grid positions ``(i, j)`` with
    ``|i - j| <= max_diag`` are served. That is a band of patches around the
    main diagonal, chosen so that every element ``(r, c)`` of a served patch
    satisfies ``|r - c| < max_dist``.

    Parameters
    ----------
    hr_memmap_fpath : str
        Path to the high-resolution memmap file.
    win_size : int
        Side length of each (square) patch.
    stride_size : int, optional
        Step between consecutive patches. Defaults to `win_size`
        (non-overlapping patches).
    transformations : callable or sequence of callables, optional
        Applied to every item returned by ``__getitem__``. A callable is
        called with the item dict. A list/tuple of callables is applied in
        order, each receiving the previous one's output. Defaults to None
        (no transformation).
    max_dist : int, optional
        Band width in matrix elements (see above). Defaults to the largest
        matrix dimension, i.e. no band restriction.
    lr_memmap_fpath : str, optional
        Path to the low-resolution memmap file. It must have the same shape
        as the HR data. If None, LR patches are generated from the HR data
        with ``downscale_counts``.
    lr_ds_ratio : int, optional
        Downscale ratio passed to ``downscale_counts`` when LR patches are
        generated on the fly. Required when `lr_memmap_fpath` is None.
    ret_float : bool, optional
        If True (default), both patches are cast to ``np.float32``.

    Attributes
    ----------
    _hr_f : memmap
        High-resolution data as returned by ``load_memmap``.
    _lr_f : memmap
        Low-resolution data. This is the same object as `_hr_f` when LR
        patches are generated on the fly.
    _use_ds_transform : bool
        True when LR patches are derived from HR data via downscaling.
    _ds_ratio : int or None
        Downscale ratio used for on-the-fly downscaling.
    win_size : int
        Patch side length.
    transformations : callable, sequence of callables or None
        Item transformation(s).
    stride_size : int
        Step between patches.
    npatches_in_row : int
        Number of patch positions along the first matrix axis.
    npatches_in_col : int
        Number of patch positions along the second matrix axis.
    max_diag : int
        Largest ``|i - j|`` between patch grid indices that is served.
    i_j_pairs : ndarray of shape (n_items, 2)
        Patch grid indices ``(i, j)`` in row-major order, one per item.
    ret_float : bool
        Whether patches are cast to ``np.float32``.

    Raises
    ------
    ValueError
        If neither `lr_memmap_fpath` nor `lr_ds_ratio` is given, or if the
        LR and HR data shapes differ.
    """

    def __init__(self,
        hr_memmap_fpath: str,
        win_size: int,
        stride_size: t.Optional[int] = None,
        transformations: t.Optional[t.Union[t.Callable, t.Sequence[t.Callable]]] = None,
        max_dist: t.Optional[int] = None,
        lr_memmap_fpath: t.Optional[str] = None,
        lr_ds_ratio: t.Optional[int] = None,
        ret_float: bool = True,
    ):
        """Open the memmap(s) and precompute the served patch positions."""
        # Raise (not assert) so the check survives `python -O`.
        if lr_memmap_fpath is None and lr_ds_ratio is None:
            raise ValueError(
                "either the path to the LR data or the downscale ratio has to be specified!"
            )

        self._hr_f = cm_loaders.load_memmap(hr_memmap_fpath)

        # Without an LR file, LR patches are derived from the HR data, so the
        # already-opened HR memmap is reused instead of being opened twice.
        self._use_ds_transform = lr_memmap_fpath is None
        if self._use_ds_transform:
            self._lr_f = self._hr_f
        else:
            self._lr_f = cm_loaders.load_memmap(lr_memmap_fpath)
        self._ds_ratio = lr_ds_ratio

        hr_nrows, hr_ncols = self._hr_f.shape
        lr_nrows, lr_ncols = self._lr_f.shape
        if (hr_nrows, hr_ncols) != (lr_nrows, lr_ncols):
            raise ValueError("LR data does not have the same shape as the HR data")

        self.win_size = win_size
        self.transformations = transformations
        self.stride_size = stride_size if stride_size is not None else win_size
        self.npatches_in_row = self._num_patches(hr_nrows, self.win_size, self.stride_size)
        self.npatches_in_col = self._num_patches(hr_ncols, self.win_size, self.stride_size)

        # max_dist is measured in matrix elements (it is converted to patch
        # grid units below), so the "no restriction" default is the largest
        # matrix dimension, not the patch count.
        if max_dist is None:
            max_dist = max(hr_nrows, hr_ncols)
        self.max_diag = int(np.floor((max_dist - self.win_size) / self.stride_size))

        self.i_j_pairs = self._band_pairs(
            self.npatches_in_row, self.npatches_in_col, self.max_diag
        )
        self.ret_float = ret_float

    @staticmethod
    def _num_patches(length: int, win_size: int, stride_size: int) -> int:
        """Return how many windows of `win_size` fit in `length` with the given stride (never negative)."""
        return max(0, int((length - win_size) // stride_size + 1))

    @staticmethod
    def _band_pairs(nrows: int, ncols: int, max_diag: int) -> np.ndarray:
        """Return the (i, j) grid indices with ``|i - j| <= max_diag`` in row-major order."""
        # tri(k=max_diag) keeps j <= i + max_diag; negating tri(k=-max_diag-1)
        # keeps j >= i - max_diag. Unlike `mask & mask.T`, this also works for
        # non-square grids and stays a boolean (1 byte/element) mask.
        upper = np.tri(nrows, ncols, k=max_diag, dtype=bool)
        below_band = np.tri(nrows, ncols, k=-max_diag - 1, dtype=bool)
        return np.argwhere(upper & ~below_band)

    @property
    def shape(self) -> t.Tuple[int, int]:
        """Patch grid shape ``(npatches_in_row, npatches_in_col)``."""
        return (self.npatches_in_row, self.npatches_in_col)

    def __len__(self) -> int:
        """Return the number of served patches (length of `i_j_pairs`)."""
        return len(self.i_j_pairs)

    def _load_data_from_mem(self,
        i: int,
        j: int,
    ) -> t.Dict[str, np.ndarray]:
        """
        Cut the LR/HR patch pair at patch grid position ``(i, j)``.

        Parameters
        ----------
        i : int
            Patch row index.
        j : int
            Patch column index.

        Returns
        -------
        dict
            ``{'X': lr_patch, 'Y': hr_patch}``. Each patch has a trailing
            channel axis, i.e. shape ``(win_size, win_size, 1)``, and is
            ``np.float32`` when `ret_float` is True.
        """
        row_start = self.stride_size * i
        col_start = self.stride_size * j
        row_slice = slice(row_start, row_start + self.win_size)
        col_slice = slice(col_start, col_start + self.win_size)

        # np.array copies the window out of the memmap, so later in-place or
        # dtype changes never touch the file.
        hr_patch_mat = np.array(self._hr_f[row_slice, col_slice])
        hr_patch_mat = hr_patch_mat.reshape(*hr_patch_mat.shape, -1)

        lr_patch_mat = np.array(self._lr_f[row_slice, col_slice])
        # Reshape the LR patch itself; reshaping the HR patch here would make
        # 'X' silently carry HR data.
        lr_patch_mat = lr_patch_mat.reshape(*lr_patch_mat.shape, -1)

        if self._use_ds_transform:
            # On-the-fly downscaling; the result is reshaped back to the patch
            # shape (assumes downscale_counts preserves element count).
            lr_patch_mat = ren_preprocs.downscale_counts(
                lr_patch_mat.flatten(), self._ds_ratio
            ).reshape(*hr_patch_mat.shape)

        if self.ret_float:
            lr_patch_mat = lr_patch_mat.astype(np.float32)
            hr_patch_mat = hr_patch_mat.astype(np.float32)

        return {
            'X': lr_patch_mat,
            'Y': hr_patch_mat,
        }

    def load_data(self,
        idx,
    ) -> t.Dict[str, np.ndarray]:
        """
        Load the patch pair for dataset index `idx`.

        Parameters
        ----------
        idx : int
            Index into `i_j_pairs`.

        Returns
        -------
        dict
            ``{'X': lr_patch, 'Y': hr_patch}`` as produced by
            `_load_data_from_mem`.
        """
        i, j = self.i_j_pairs[idx]
        return self._load_data_from_mem(i, j)

    def any_transformation(self) -> bool:
        """Return True if a transformation (or sequence of them) is set."""
        return self.transformations is not None

    def transform_item(self,
        data: t.Dict,
    ) -> t.Dict:
        """
        Apply the configured transformation(s) to a data item.

        A callable is called once with `data`. A sequence of callables is
        applied in order, each receiving the previous result.

        Parameters
        ----------
        data : dict
            Input data item.

        Returns
        -------
        dict
            Transformed data item.
        """
        transformations = self.transformations
        if callable(transformations):
            return transformations(data)
        for transformation in transformations:
            data = transformation(data)
        return data

    def __getitem__(self,
        idx: int,
    ) -> t.Dict:
        """
        Return the (transformed) data item at `idx`.

        Negative indices count from the end.

        Raises
        ------
        IndexError
            If `idx` is out of range. Raising IndexError also lets plain
            ``for item in dataset`` iteration terminate.
        """
        n_items = len(self)
        if not -n_items <= idx < n_items:
            raise IndexError(f"index {idx} is out of range for dataset of length {n_items}")

        data = self.load_data(idx)
        if self.any_transformation():
            data = self.transform_item(data)
        return data

    def set_transformations(self,
        transformations,
    ):
        """
        Replace the item transformation(s).

        Parameters
        ----------
        transformations : callable, sequence of callables or None
            New transformation(s); None disables transformation.
        """
        self.transformations = transformations