"""
Memmap-backed PyTorch dataset that serves paired low-/high-resolution patches.
Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
import typing as t
import numpy as np
from torch.utils.data import Dataset
from ... import loaders as cm_loaders
from .. import preprocs as ren_preprocs
class RENMemmapDatasetV1(Dataset):
    """
    Dataset of paired low-/high-resolution patches cut from 2-D memmap data.

    The high-resolution (HR) matrix is tiled with square windows of
    ``win_size`` placed every ``stride_size`` entries.  Only windows whose
    grid coordinates ``(i, j)`` satisfy ``|i - j| <= max_diag`` are kept.
    Each item is a dictionary with the low-resolution (LR) patch under
    ``'X'`` and the HR patch under ``'Y'``.

    The LR patch either comes from a second memmap of the same shape
    (``lr_memmap_fpath``) or, when no such file is given, from passing the
    HR patch through ``ren_preprocs.downscale_counts`` with ``lr_ds_ratio``.

    Parameters
    ----------
    hr_memmap_fpath : str
        Path to the high-resolution memmap file.
    win_size : int
        Window size for patch extraction.
    stride_size : int, optional
        Stride size for patch extraction. Defaults to `win_size`.
    transformations : list, optional
        List of transformations to apply to the data. Defaults to None.
    max_dist : int, optional
        Maximum distance between patches. Defaults to ``self.shape[0]``
        (the number of patches per row).
    lr_memmap_fpath : str, optional
        Path to the low-resolution memmap file. Defaults to None.
    lr_ds_ratio : int, optional
        Downscale ratio for the low-resolution data. Defaults to None.
    ret_float : bool, optional
        Whether to return floating-point (``float32``) data. Defaults to True.

    Attributes
    ----------
    _hr_f : memmap
        High-resolution memmap file.
    _lr_f : memmap
        Low-resolution memmap file (a second handle on the HR file when the
        on-the-fly downscaling is used).
    win_size : int
        Window size for patch extraction.
    transformations : list
        List of transformations to apply to the data.
    stride_size : int
        Stride size for patch extraction.
    npatches_in_row : int
        Number of patch positions along the first axis.
    npatches_in_col : int
        Number of patch positions along the second axis.
    max_diag : int
        Largest allowed ``|i - j|`` between the grid coordinates of a patch.
    i_j_pairs : ndarray
        ``(N, 2)`` array of the ``(i, j)`` grid coordinates of all patches.
    ret_float : bool
        Whether to return floating-point data.
    """

    def __init__(self,
                 hr_memmap_fpath: str,
                 win_size: int,
                 stride_size: t.Optional[int] = None,
                 transformations: t.Optional[t.List] = None,
                 max_dist: t.Optional[int] = None,
                 lr_memmap_fpath: t.Optional[str] = None,
                 lr_ds_ratio: t.Optional[int] = None,
                 ret_float: bool = True,
                 ):
        """
        Initialize the RENMemmapDatasetV1 instance.

        Parameters
        ----------
        hr_memmap_fpath : str
            Path to the high-resolution memmap file.
        win_size : int
            Window size for patch extraction.
        stride_size : int, optional
            Stride size for patch extraction. Defaults to `win_size`.
        transformations : list, optional
            List of transformations to apply to the data. Defaults to None.
        max_dist : int, optional
            Maximum distance between patches. Defaults to ``self.shape[0]``.
        lr_memmap_fpath : str, optional
            Path to the low-resolution memmap file. Defaults to None.
        lr_ds_ratio : int, optional
            Downscale ratio for the low-resolution data. Defaults to None.
        ret_float : bool, optional
            Whether to return floating-point data. Defaults to True.

        Raises
        ------
        AssertionError
            If neither `lr_memmap_fpath` nor `lr_ds_ratio` is given, or if
            the LR and HR memmaps differ in shape.
        """
        assert lr_memmap_fpath is not None or lr_ds_ratio is not None, \
            "either the path to the LR data or the downscale ratio has to be specified!"

        super().__init__()

        #? Without a dedicated LR file the LR patch is derived on the fly
        #? from the HR data, so the "LR" handle points to the HR file as well
        self._use_ds_transform = lr_memmap_fpath is None
        self._ds_ratio = lr_ds_ratio

        self._hr_f = cm_loaders.load_memmap(hr_memmap_fpath)
        self._lr_f = cm_loaders.load_memmap(
            hr_memmap_fpath if self._use_ds_transform else lr_memmap_fpath
        )

        hr_nrows, hr_ncols = self._hr_f.shape
        lr_nrows, lr_ncols = self._lr_f.shape
        assert hr_nrows == lr_nrows and hr_ncols == lr_ncols, \
            "LR data does not have the same shape as the HR data"

        self.win_size = win_size
        self.transformations = transformations
        self.stride_size = stride_size if stride_size is not None else win_size
        self.ret_float = ret_float

        #? Number of complete windows that fit along each axis
        self.npatches_in_row = int(np.floor((hr_nrows - self.win_size) / self.stride_size + 1))
        self.npatches_in_col = int(np.floor((hr_ncols - self.win_size) / self.stride_size + 1))

        if max_dist is None:
            # NOTE(review): self.shape is the patch grid, so this default is a
            # patch count, while the formula below treats max_dist in the same
            # units as win_size -- confirm that this is intended.
            max_dist = self.shape[0]
        self.max_diag = int(np.floor((max_dist - self.win_size) / self.stride_size))

        #? Band mask around the main diagonal: lower triangle up to the
        #? max_diag-th upper diagonal, intersected with its own transpose.
        # NOTE(review): the in-place `&=` with the transpose requires a square
        # patch grid (npatches_in_row == npatches_in_col).
        mask = np.tri(*self.shape, dtype=bool, k=self.max_diag)
        mask &= mask.T
        self.i_j_pairs = np.argwhere(mask)

    @property
    def shape(self) -> t.Tuple[int, int]:
        """
        Shape of the patch grid.

        Returns
        -------
        tuple of int
            ``(npatches_in_row, npatches_in_col)``.
        """
        return (self.npatches_in_row, self.npatches_in_col)

    def __len__(self) -> int:
        """
        Return the total number of patches in a dataset.

        Returns
        -------
        int:
            The total number of patches, i.e. the number of rows of
            `i_j_pairs`.
        """
        return len(self.i_j_pairs)

    def _load_data_from_mem(self,
                            i: int,
                            j: int,
                            ) -> t.Dict[str, np.ndarray]:
        """
        Load the LR/HR patch pair located at grid position ``(i, j)``.

        Parameters
        ----------
        i : int
            Row index in the patch grid.
        j : int
            Column index in the patch grid.

        Returns
        -------
        dict
            ``{'X': lr_patch, 'Y': hr_patch}``.  Both patches carry an extra
            trailing axis of length one and are cast to ``float32`` when
            `ret_float` is set.
        """
        row_start = self.stride_size * i
        col_start = self.stride_size * j
        row_slice = slice(row_start, row_start + self.win_size)
        col_slice = slice(col_start, col_start + self.win_size)

        #? np.array copies the window out of the memmap; the new trailing
        #? axis has length one
        hr_patch_mat = np.array(self._hr_f[row_slice, col_slice])[..., np.newaxis]
        #? The LR patch has to be built from the LR data itself (it used to be
        #? overwritten with the HR patch at this point)
        lr_patch_mat = np.array(self._lr_f[row_slice, col_slice])[..., np.newaxis]

        #? If On-the-fly downscaling method is used
        if self._use_ds_transform:
            lr_patch_mat = (
                ren_preprocs.downscale_counts(
                    lr_patch_mat.flatten(),
                    self._ds_ratio
                )
                .reshape(*hr_patch_mat.shape)
            )

        if self.ret_float:
            lr_patch_mat = lr_patch_mat.astype(np.float32)
            hr_patch_mat = hr_patch_mat.astype(np.float32)

        return {
            'X': lr_patch_mat,
            'Y': hr_patch_mat,
        }

    def load_data(self,
                  idx
                  ) -> t.Dict[str, np.ndarray]:
        """
        Load the untransformed patch pair belonging to the flat index `idx`.

        Parameters
        ----------
        idx : int
            The index of the patch to be loaded (row of `i_j_pairs`).

        Returns
        -------
        dict
            ``{'X': lr_patch, 'Y': hr_patch}`` as produced by
            `_load_data_from_mem`.
        """
        i, j = self.i_j_pairs[idx]
        return self._load_data_from_mem(i, j)

    def __getitem__(self,
                    idx: int
                    ) -> t.Dict:
        """
        Returns the data item at the given index.

        Parameters
        ----------
        idx : int
            Index of the data in the dataset.

        Returns
        -------
        dict
            (Transformed) data item.

        Raises
        ------
        IndexError
            If `idx` is not smaller than ``len(self)``.  An ``IndexError``
            (rather than an assertion) is what lets a plain
            ``for item in dataset`` loop terminate.
        """
        if idx >= len(self):
            raise IndexError(
                f"index {idx} is out of range for a dataset with {len(self)} patches"
            )

        data = self.load_data(idx)

        # NOTE(review): any_transformation() and transform_item() are not
        # defined in the part of the class shown here -- presumably they are
        # provided elsewhere (later in the class or by a subclass); confirm.
        if self.any_transformation():
            data = self.transform_item(data)

        return data