# Source code for gunz_cm.resolution_enhancements.datasets.memmap_v2

"""
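Memory-mapped patch dataset restricted to a band around the main diagonal.

Defines ``MemmapDatasetV2``, a subclass of ``MemmapDatasetV1`` that only
serves patches whose index pair ``(i, j)`` lies within ``max_diag`` of the
main diagonal.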

Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
import typing as t
import numpy as np

from .memmap_v1 import MemmapDatasetV1

class MemmapDatasetV2(MemmapDatasetV1):
    """Memmap-backed patch dataset limited to a band around the main diagonal.

    Extends ``MemmapDatasetV1`` by enumerating, at construction time, every
    index pair ``(i, j)`` with ``|i - j| <= max_diag`` and serving patches
    only for those pairs:

    - ``__init__`` derives ``max_diag`` from ``max_dist``, ``win_size`` and
      ``stride_size`` and precomputes the admissible pairs.
    - ``__len__`` reports the number of admissible pairs.
    - ``load_data`` maps a flat index to its pair and delegates the actual
      read to ``self._load_data_from_mem``.

    See the ``MemmapDatasetV1`` docstring for the shared attributes and
    methods.

    Attributes
    ----------
    max_diag : int
        ``floor((max_dist - win_size) / stride_size)``; the largest allowed
        offset ``|i - j|`` between the two indices of a pair.
    i_j_pairs : np.ndarray
        Array of shape ``(n_pairs, 2)`` listing the admissible ``(i, j)``
        pairs in row-major order.

    Notes
    -----
    NOTE(review): ``self.shape``, ``self.win_size``, ``self.stride_size`` and
    ``self._load_data_from_mem`` are not defined in this class; they are
    presumably provided by ``MemmapDatasetV1`` -- confirm against that class.
    """

    def __init__(
        self,
        memmap_fpath: str,
        win_size: int,
        stride_size: t.Optional[int] = None,
        transformations: t.Optional[t.Callable] = None,
        max_dist: t.Optional[int] = None,
    ):
        """Set up the parent dataset and precompute the near-diagonal pairs.

        Parameters
        ----------
        memmap_fpath : str
            Path to the memmap file; forwarded to ``MemmapDatasetV1``.
        win_size : int
            Patch size; forwarded to ``MemmapDatasetV1``.
        stride_size : int, optional
            Step from one patch to the next; forwarded to
            ``MemmapDatasetV1`` (per the original documentation it falls
            back to ``win_size`` when omitted -- handled by the parent).
        transformations : callable, optional
            Transformations to apply to the data; forwarded to
            ``MemmapDatasetV1``.
        max_dist : int, optional
            Maximum distance from the main diagonal to consider. When
            ``None``, ``self.shape[0]`` is used (originally documented as the
            number of rows in the memmap file).
        """
        super().__init__(
            memmap_fpath,
            win_size,
            stride_size=stride_size,
            transformations=transformations,
        )

        # Fall back to the first dimension of the parent's shape.
        dist_limit = self.shape[0] if max_dist is None else max_dist

        # Number of whole strides that fit into the allowed distance once one
        # window has been accounted for.
        n_steps = (dist_limit - self.win_size) / self.stride_size
        self.max_diag = int(np.floor(n_steps))

        # np.tri marks entries with j <= i + max_diag; AND-ing it in place
        # with its transpose keeps exactly those with |i - j| <= max_diag.
        # On boolean arrays bitwise_and is the logical AND.
        band = np.tri(*self.shape, dtype=bool, k=self.max_diag)
        np.bitwise_and(band, band.T, out=band)

        # One (i, j) row per True entry of the band, in row-major order.
        self.i_j_pairs = np.argwhere(band)

    def __len__(self) -> int:
        """Return the number of index pairs lying within ``max_diag``.

        Returns
        -------
        int
            Number of rows in ``i_j_pairs``.
        """
        n_pairs = len(self.i_j_pairs)
        return n_pairs

    def load_data(self, idx: int) -> np.ndarray:
        """Load the patch belonging to the ``idx``-th near-diagonal pair.

        Parameters
        ----------
        idx : int
            Position in ``i_j_pairs`` of the patch to load.

        Returns
        -------
        np.ndarray
            Whatever ``self._load_data_from_mem(i, j)`` returns for the
            selected pair (per the original documentation, a patch of size
            ``win_size``).
        """
        row_idx, col_idx = self.i_j_pairs[idx]
        return self._load_data_from_mem(row_idx, col_idx)