# -*- coding: utf-8 -*-
"""
Dataset for .gnz unified container.
Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"
import numpy as np
import pandas as pd
try:
import torch
except ImportError:
torch = None
import typing as t
from ..exceptions import DatasetError
from ..io.gnz import GnzReader
DatasetType = torch.utils.data.Dataset if torch else object
# [docs]  -- Sphinx cross-reference artifact from documentation scraping; not code.
class GnzSparseDataset(DatasetType):
    """Map-style dataset of diagonal windows over a .gnz contact matrix.

    The container stores a square, binned contact matrix in one of four
    layouts -- ``dense``, ``tiled``, ``csr`` or ``block_sparse``.  Each item
    is a square patch taken along the main diagonal, covering ``window_size``
    genomic units (``window_size // resolution`` bins), returned either as a
    dense tensor or as upper-triangle COO data.

    Parameters
    ----------
    fpath : str
        Path to the .gnz container file.
    window_size : int
        Genomic span of one window, in the same units as the container's
        ``resolution`` metadata.  Must be at least one resolution unit.
    output_type : str, optional
        ``"sparse"`` (default): items are dicts with upper-triangle COO
        coordinates, features and target.  ``"dense"``: items are
        ``(1, H, H)`` float32 tensors.
    downsample_ratio : float or (float, float) tuple, optional
        Binomial downsampling probability used as augmentation on sparse
        output; a tuple means the probability is drawn uniformly from that
        range per item.  The un-downsampled counts are kept as ``target``.

    Raises
    ------
    DatasetError
        On an unknown layout, a tiled/block-sparse container missing its
        ``block_size`` metadata, or ``window_size`` smaller than the
        resolution.

    Notes
    -----
    The base class is ``torch.utils.data.Dataset`` when torch is installed
    and ``object`` otherwise; ``__getitem__`` itself requires torch.
    """

    def __init__(self, fpath: str, window_size: int, output_type: str = "sparse",
                 downsample_ratio: t.Union[float, t.Tuple[float, float], None] = None):
        """Open the container, resolve its storage layout and build the window index."""
        self.reader = GnzReader(fpath)
        self.metadata = self.reader.get_metadata()
        self.resolution = self.metadata["resolution"]
        self.window_size = window_size
        self.output_type = output_type
        self.downsample_ratio = downsample_ratio
        self.layout = self.metadata.get("layout", "dense")
        self.block_size = self.metadata.get("block_size")

        # Blocked layouts cannot be addressed without a block size; fail with
        # a clear error instead of a TypeError on `None * int` below.
        if self.layout in ("tiled", "block_sparse") and not self.block_size:
            raise DatasetError(f"Layout '{self.layout}' requires 'block_size' in metadata")

        # Resolve the arrays backing each layout and the total bin count.
        if self.layout == "dense":
            self.matrix = self.reader.get_array("matrix", mode="r")
            n_bins = self.matrix.shape[0]
            self.dtype = self.matrix.dtype
        elif self.layout == "tiled":
            self.matrix = self.reader.get_array("matrix", mode="r")
            n_bins = self.matrix.shape[0] * self.block_size
            self.dtype = self.matrix.dtype
        elif self.layout == "csr":
            self.indptr = self.reader.get_array("indptr", mode="r")
            self.indices = self.reader.get_array("indices", mode="r")
            self.data = self.reader.get_array("data", mode="r")
            n_bins = len(self.indptr) - 1
            self.dtype = self.data.dtype
        elif self.layout == "block_sparse":
            self.block_index = self.reader.get_array("block_index", mode="r")
            self.block_data = self.reader.get_array("block_data", mode="r")
            n_bins = self.block_index.shape[0] * self.block_size
            self.dtype = self.block_data.dtype
        else:
            raise DatasetError(f"Unknown layout: {self.layout}")

        # Optional per-bin balancing weights: the first "weights_*" array wins.
        self.weights = None
        for key in self.reader.keys():
            if key.startswith("weights_"):
                self.weights = self.reader.get_array(key, mode="r")
                break

        # Window index: one row per diagonal window of `step` bins; the last
        # window is clipped to the matrix edge (so it may be smaller).
        step = window_size // self.resolution
        if step < 1:
            # Original code would crash inside np.arange with step == 0.
            raise DatasetError(
                f"window_size ({window_size}) must be >= resolution ({self.resolution})")
        starts = np.arange(0, n_bins, step)
        ends = np.clip(starts + step, 0, n_bins)
        self.index = pd.DataFrame({'start_bin': starts, 'end_bin': ends})

    def __len__(self):
        """Return the number of diagonal windows."""
        return len(self.index)

    def _get_dense_patch_tiled(self, s, e):
        """Assemble the dense patch ``[s:e, s:e]`` from block storage.

        Handles both the ``tiled`` layout (blocks stored in a 2-D grid of
        dense tiles) and the ``block_sparse`` layout (blocks addressed via
        ``block_index``, where -1 marks an absent block).  Blocks that are
        missing or beyond the stored grid contribute zeros.

        Parameters
        ----------
        s, e : int
            Start (inclusive) and end (exclusive) bin of the patch.

        Returns
        -------
        numpy.ndarray
            Dense ``(e - s, e - s)`` array in the container's dtype.
        """
        B = self.block_size
        b_start = s // B
        b_end = (e - 1) // B + 1
        patch_h = e - s
        patch = np.zeros((patch_h, patch_h), dtype=self.dtype)
        for br in range(b_start, b_end):
            for bc in range(b_start, b_end):
                # Global coordinate span covered by block (br, bc)...
                gs_r, ge_r = br * B, (br + 1) * B
                gs_c, ge_c = bc * B, (bc + 1) * B
                # ...intersected with the requested window.
                is_r, ie_r = max(s, gs_r), min(e, ge_r)
                is_c, ie_c = max(s, gs_c), min(e, ge_c)
                if is_r < ie_r and is_c < ie_c:
                    if self.layout == "tiled":
                        # Blocks past the stored grid are implicit zero padding.
                        if br < self.matrix.shape[0] and bc < self.matrix.shape[1]:
                            block = self.matrix[br, bc]
                        else:
                            continue
                    else:  # block_sparse
                        if br < self.block_index.shape[0] and bc < self.block_index.shape[1]:
                            idx = self.block_index[br, bc]
                            if idx == -1:
                                continue  # absent (all-zero) block
                            block = self.block_data[idx]
                        else:
                            continue
                    patch[is_r-s:ie_r-s, is_c-s:ie_c-s] = block[is_r-gs_r:ie_r-gs_r, is_c-gs_c:ie_c-gs_c]
        return patch

    def _get_csr_coo(self, s, e):
        """Extract the COO triplets of the CSR submatrix ``[s:e, s:e]``.

        Parameters
        ----------
        s, e : int
            Start (inclusive) and end (exclusive) bin of the patch.

        Returns
        -------
        tuple of numpy.ndarray
            ``(rows, cols, values)`` with row/col indices shifted so they are
            relative to ``s``.  Empty int64/int64/dtype arrays when the
            region holds no stored entries.
        """
        r_list, c_list, v_list = [], [], []
        for r in range(s, e):
            p0 = self.indptr[r]
            p1 = self.indptr[r + 1]
            if p1 > p0:
                cols = self.indices[p0:p1]
                vals = self.data[p0:p1]
                # Keep only columns that fall inside the window.
                mask = (cols >= s) & (cols < e)
                if np.any(mask):
                    r_list.append(np.full(mask.sum(), r - s, dtype=np.int64))
                    c_list.append(cols[mask] - s)
                    v_list.append(vals[mask])
        if not r_list:
            return np.array([], dtype=np.int64), np.array([], dtype=np.int64), np.array([], dtype=self.dtype)
        return np.concatenate(r_list), np.concatenate(c_list), np.concatenate(v_list)

    def __getitem__(self, idx):
        """Return window ``idx`` in the configured output format.

        Returns
        -------
        torch.Tensor or dict
            Dense mode: a ``(1, H, H)`` float32 tensor, weight-normalized
            (with non-finite entries zeroed) when balancing weights exist.
            Sparse mode: dict with ``coords`` (N, 2) int64 upper-triangle
            indices, ``features`` (N, 1) float32 (possibly downsampled)
            counts, ``target`` (N,) float32 un-downsampled counts aligned
            with ``coords``, and ``info`` holding the genomic ``start``.
        """
        row = self.index.iloc[idx]
        s, e = int(row['start_bin']), int(row['end_bin'])

        if self.output_type == "dense":
            if self.layout == "dense":
                patch = self.matrix[s:e, s:e]
            elif self.layout in ["tiled", "block_sparse"]:
                patch = self._get_dense_patch_tiled(s, e)
            elif self.layout == "csr":
                r, c, v = self._get_csr_coo(s, e)
                h = e - s
                patch = np.zeros((h, h), dtype=self.dtype)
                if len(r) > 0:
                    patch[r, c] = v
            # Balance by the outer product of per-bin weights; zero-weight
            # bins produce inf/nan which are zeroed out.
            if self.weights is not None:
                w_slice = self.weights[s:e]
                denom = np.outer(w_slice, w_slice)
                patch = patch.astype(np.float32) / denom
                np.nan_to_num(patch, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
            return torch.from_numpy(patch).float().unsqueeze(0)

        # --- Sparse output: upper-triangle COO ---------------------------
        if self.layout == "dense":
            patch = self.matrix[s:e, s:e]
            r, c = np.nonzero(patch)
            mask = r <= c
            r, c = r[mask], c[mask]
            counts = patch[r, c]
        elif self.layout in ["tiled", "block_sparse"]:
            patch = self._get_dense_patch_tiled(s, e)
            r, c = np.nonzero(patch)
            mask = r <= c
            r, c = r[mask], c[mask]
            counts = patch[r, c]
        elif self.layout == "csr":
            r, c, counts = self._get_csr_coo(s, e)
            mask = r <= c
            r, c, counts = r[mask], c[mask], counts[mask]

        # Balance counts and drop entries whose weights were invalid.
        if self.weights is not None:
            w1 = self.weights[s + r]
            w2 = self.weights[s + c]
            counts = counts.astype(np.float32) / (w1 * w2)
            mask_v = np.isfinite(counts)
            if not np.all(mask_v):
                r, c, counts = r[mask_v], c[mask_v], counts[mask_v]

        # --- Augmentation: random binomial downsampling -------------------
        target_counts = counts.copy()  # un-downsampled counts are the target
        if self.downsample_ratio is not None:
            if isinstance(self.downsample_ratio, tuple):
                alpha = np.random.uniform(*self.downsample_ratio)
            else:
                alpha = self.downsample_ratio
            counts = np.random.binomial(counts.astype(np.int32), alpha)
            keep = counts > 0
            r = r[keep]
            c = c[keep]
            counts = counts[keep]
            # BUGFIX: the target must be filtered with the same mask so it
            # stays aligned element-for-element with coords/features.
            target_counts = target_counts[keep]

        coords = np.stack([r, c], axis=1)
        return {
            "coords": torch.from_numpy(coords).long(),
            "features": torch.from_numpy(counts).float().unsqueeze(1),
            "target": torch.from_numpy(target_counts).float(),
            "info": {"start": s * self.resolution}
        }