# Source code for gunz_cm.datasets.gnz

# -*- coding: utf-8 -*-
"""
Dataset for .gnz unified container.


Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
__version__ = "1.0.0"

import numpy as np
import pandas as pd
try:
    import torch
except ImportError:
    torch = None


import typing as t

from ..exceptions import DatasetError
from ..io.gnz import GnzReader

# Base class for the datasets in this module: PyTorch's ``Dataset`` when the
# optional ``import torch`` above succeeded, otherwise plain ``object`` so the
# module can still be imported without PyTorch installed.
if torch:
    DatasetType = torch.utils.data.Dataset
else:
    DatasetType = object

class GnzSparseDataset(DatasetType):
    """
    Windowed dataset over the matrix stored in a ``.gnz`` container.

    The matrix is split into non-overlapping windows along the diagonal, each
    ``window_size // resolution`` bins wide (the last window may be shorter).
    Item ``i`` is the square block ``[s:e, s:e]`` covering window ``i``.

    Parameters
    ----------
    fpath : str
        Path to the ``.gnz`` file, opened with ``GnzReader``.
    window_size : int
        Window width, in the same units as the container's ``resolution``
        metadata. Must be at least one bin wide.
    output_type : {"sparse", "dense"}, optional
        ``"sparse"`` (default) returns upper-triangle COO records;
        ``"dense"`` returns a ``(1, h, h)`` float32 tensor.
    downsample_ratio : float or (float, float) or None, optional
        Sparse output only. If given, counts are binomially thinned with this
        keep-probability; a 2-sequence ``(low, high)`` draws the probability
        uniformly per item.

    Raises
    ------
    DatasetError
        For an unknown ``output_type`` or layout, a missing/non-positive
        ``block_size`` on block layouts, or a window narrower than one bin.

    Notes
    -----
    Supported storage layouts (``metadata["layout"]``, default ``"dense"``):
    ``"dense"``, ``"tiled"``, ``"csr"`` and ``"block_sparse"``. If the
    container holds an array whose key starts with ``"weights_"``, the first
    one found is used to normalize values as ``x_ij / (w_i * w_j)``; entries
    that become non-finite are zeroed (dense output) or dropped (sparse
    output). NOTE(review): confirm that dividing by the weights matches how
    the weights are stored.
    """

    def __init__(
        self,
        fpath: str,
        window_size: int,
        output_type: str = "sparse",
        downsample_ratio: t.Union[float, t.Tuple[float, float], None] = None,
    ):
        """
        Open the container, bind its layout arrays and build the window index.

        See the class docstring for the meaning of the parameters.

        Raises
        ------
        DatasetError
            If ``output_type`` or the stored layout is unknown, a block layout
            lacks a positive ``block_size``, or ``window_size`` is smaller
            than the resolution.
        """
        # Reject unknown values up front instead of silently producing
        # sparse output for typos.
        if output_type not in ("sparse", "dense"):
            raise DatasetError(
                f"Unknown output_type: {output_type!r} "
                "(expected 'sparse' or 'dense')"
            )

        self.reader = GnzReader(fpath)
        self.metadata = self.reader.get_metadata()
        self.resolution = self.metadata["resolution"]
        self.window_size = window_size
        self.output_type = output_type
        self.downsample_ratio = downsample_ratio
        self.layout = self.metadata.get("layout", "dense")
        self.block_size = self.metadata.get("block_size")

        n_bins = self._open_layout()

        # Optional normalization vector: the first array named "weights_*".
        self.weights = None
        for key in self.reader.keys():
            if key.startswith("weights_"):
                self.weights = self.reader.get_array(key, mode="r")
                break

        # Non-overlapping windows of `step` bins; the last one may be shorter.
        step = window_size // self.resolution
        if step <= 0:
            # np.arange cannot use a zero step, and a negative step would
            # produce no windows at all.
            raise DatasetError(
                f"window_size ({window_size}) must be at least the "
                f"resolution ({self.resolution})"
            )
        starts = np.arange(0, n_bins, step)
        ends = np.minimum(starts + step, n_bins)
        self.index = pd.DataFrame({"start_bin": starts, "end_bin": ends})

    def _open_layout(self):
        """
        Bind the arrays required by ``self.layout`` and return the bin count.

        Arrays are requested with ``mode="r"``. Also sets ``self.dtype`` to the
        dtype of the stored values.

        Returns
        -------
        int
            Number of bins along one matrix axis.

        Raises
        ------
        DatasetError
            If the layout is unknown or a block layout has no positive
            ``block_size`` in its metadata.
        """
        layout = self.layout
        if layout in ("tiled", "block_sparse") and (
            not self.block_size or self.block_size <= 0
        ):
            raise DatasetError(
                f"Layout '{layout}' requires a positive 'block_size' in metadata"
            )

        if layout == "dense":
            self.matrix = self.reader.get_array("matrix", mode="r")
            self.dtype = self.matrix.dtype
            return self.matrix.shape[0]
        if layout == "tiled":
            # Indexed as matrix[block_row, block_col][row, col]. The bin count
            # presumably includes padding of the last block — TODO confirm.
            self.matrix = self.reader.get_array("matrix", mode="r")
            self.dtype = self.matrix.dtype
            return self.matrix.shape[0] * self.block_size
        if layout == "csr":
            self.indptr = self.reader.get_array("indptr", mode="r")
            self.indices = self.reader.get_array("indices", mode="r")
            self.data = self.reader.get_array("data", mode="r")
            self.dtype = self.data.dtype
            return len(self.indptr) - 1
        if layout == "block_sparse":
            # block_index[block_row, block_col] is a slot in block_data, or -1
            # for a block that is not stored.
            self.block_index = self.reader.get_array("block_index", mode="r")
            self.block_data = self.reader.get_array("block_data", mode="r")
            self.dtype = self.block_data.dtype
            return self.block_index.shape[0] * self.block_size
        raise DatasetError(f"Unknown layout: {layout}")

    def __len__(self):
        """Return the number of windows."""
        return len(self.index)

    def _fetch_block(self, br, bc):
        """
        Return block ``(br, bc)`` of a tiled/block_sparse layout, or ``None``.

        ``None`` means the block lies outside the stored block grid or (for
        ``block_sparse``) is marked ``-1`` in ``block_index``.
        """
        grid = self.matrix if self.layout == "tiled" else self.block_index
        if br >= grid.shape[0] or bc >= grid.shape[1]:
            return None
        if self.layout == "tiled":
            return self.matrix[br, bc]
        slot = self.block_index[br, bc]
        if slot == -1:
            return None
        return self.block_data[slot]

    def _get_dense_patch_tiled(self, s, e):
        """
        Assemble the dense ``(e - s, e - s)`` patch for bins ``[s, e)``.

        Works for the ``"tiled"`` and ``"block_sparse"`` layouts by copying the
        overlapping part of every block that intersects the window. Missing
        blocks (see ``_fetch_block``) leave zeros in the patch.
        """
        B = self.block_size
        h = e - s
        patch = np.zeros((h, h), dtype=self.dtype)
        block_ids = range(s // B, (e - 1) // B + 1)
        for br in block_ids:
            row0 = br * B
            r_lo, r_hi = max(s, row0), min(e, row0 + B)
            if r_lo >= r_hi:
                continue
            for bc in block_ids:
                col0 = bc * B
                c_lo, c_hi = max(s, col0), min(e, col0 + B)
                if c_lo >= c_hi:
                    continue
                block = self._fetch_block(br, bc)
                if block is None:
                    continue
                patch[r_lo - s:r_hi - s, c_lo - s:c_hi - s] = block[
                    r_lo - row0:r_hi - row0, c_lo - col0:c_hi - col0
                ]
        return patch

    def _get_csr_coo(self, s, e):
        """
        Extract the CSR entries inside the square window ``[s, e)``.

        Returns
        -------
        rows, cols, vals : np.ndarray
            Row and column offsets relative to ``s`` (both int64) and the
            stored values, in CSR storage order.
        """
        # One contiguous slice of the CSR arrays covers rows s..e-1; expand
        # indptr into per-entry row ids instead of looping row by row.
        ptr = np.asarray(self.indptr[s:e + 1], dtype=np.int64)
        lo, hi = int(ptr[0]), int(ptr[-1])
        cols = np.asarray(self.indices[lo:hi], dtype=np.int64)
        vals = np.asarray(self.data[lo:hi])
        rows = np.repeat(np.arange(e - s, dtype=np.int64), np.diff(ptr))
        inside = (cols >= s) & (cols < e)
        return rows[inside], cols[inside] - s, vals[inside]

    def _draw_downsample_ratio(self):
        """Return the keep-probability for this item (drawn if a range)."""
        ratio = self.downsample_ratio
        if isinstance(ratio, (tuple, list)):
            return np.random.uniform(*ratio)
        return ratio

    def _dense_item(self, s, e):
        """Return window ``[s, e)`` as a ``(1, h, h)`` float32 tensor."""
        if self.layout == "dense":
            patch = self.matrix[s:e, s:e]
        elif self.layout in ("tiled", "block_sparse"):
            patch = self._get_dense_patch_tiled(s, e)
        else:  # csr
            r, c, v = self._get_csr_coo(s, e)
            h = e - s
            patch = np.zeros((h, h), dtype=self.dtype)
            patch[r, c] = v

        if self.weights is not None:
            w_slice = self.weights[s:e]
            # Zero / NaN weights are expected; their results are zeroed below.
            with np.errstate(divide="ignore", invalid="ignore"):
                patch = patch.astype(np.float32) / np.outer(w_slice, w_slice)
            np.nan_to_num(patch, copy=False, nan=0.0, posinf=0.0, neginf=0.0)

        # Copy into a contiguous float32 array: avoids handing torch a
        # read-only view and dtypes torch.from_numpy may not support.
        return torch.from_numpy(
            np.ascontiguousarray(patch, dtype=np.float32)
        ).unsqueeze(0)

    def _sparse_triplets(self, s, e):
        """Return upper-triangle ``(rows, cols, counts)`` for window ``[s, e)``."""
        if self.layout == "csr":
            r, c, counts = self._get_csr_coo(s, e)
        else:
            if self.layout == "dense":
                patch = self.matrix[s:e, s:e]
            else:
                patch = self._get_dense_patch_tiled(s, e)
            r, c = np.nonzero(patch)
            counts = patch[r, c]
        upper = r <= c
        return r[upper], c[upper], counts[upper]

    def _sparse_item(self, s, e):
        """Return window ``[s, e)`` as a dict of sparse COO tensors."""
        r, c, counts = self._sparse_triplets(s, e)

        if self.weights is not None:
            with np.errstate(divide="ignore", invalid="ignore"):
                counts = counts.astype(np.float32) / (
                    self.weights[s + r] * self.weights[s + c]
                )
            finite = np.isfinite(counts)
            r, c, counts = r[finite], c[finite], counts[finite]

        # Augmentation (random downsampling). `target` keeps the values before
        # thinning and is masked together with coords/features so that
        # target[i] always belongs to coords[i].
        target = counts.copy()
        if self.downsample_ratio is not None:
            alpha = self._draw_downsample_ratio()
            # NOTE(review): int32 truncation zeroes fractional values, e.g.
            # weight-normalized counts; confirm downsampling is only used on
            # raw integer counts.
            counts = np.random.binomial(counts.astype(np.int32), alpha)
            kept = counts > 0
            r, c, counts, target = r[kept], c[kept], counts[kept], target[kept]

        coords = np.stack([r, c], axis=1)
        return {
            "coords": torch.from_numpy(coords).long(),
            "features": torch.from_numpy(
                np.asarray(counts, dtype=np.float32)
            ).unsqueeze(1),
            "target": torch.from_numpy(np.asarray(target, dtype=np.float32)),
            "info": {"start": s * self.resolution},
        }

    def __getitem__(self, idx):
        """
        Return window ``idx``.

        Returns
        -------
        torch.Tensor or dict
            For ``output_type == "dense"``: a float32 tensor of shape
            ``(1, h, h)``. Otherwise a dict with

            - ``"coords"``: ``(N, 2)`` long tensor of ``(row, col)`` offsets
              from the window start, upper triangle only (``row <= col``);
            - ``"features"``: ``(N, 1)`` float tensor of (possibly
              downsampled) values;
            - ``"target"``: ``(N,)`` float tensor of the values before
              downsampling, aligned with ``"coords"``;
            - ``"info"``: ``{"start": start_bin * resolution}``.

        Raises
        ------
        ImportError
            If PyTorch is not installed.
        IndexError
            If ``idx`` is out of range.
        """
        if torch is None:
            raise ImportError(
                "GnzSparseDataset.__getitem__ requires PyTorch, "
                "which is not installed"
            )
        row = self.index.iloc[idx]
        s, e = int(row["start_bin"]), int(row["end_bin"])
        if self.output_type == "dense":
            return self._dense_item(s, e)
        return self._sparse_item(s, e)