# Source code for gunz_cm.compressions.zstd_decoder

"""
Zstd decoder wrapper for GZCM v3 compression.

Uses zlib as fallback since zstandard may not be installed.

Examples
--------
"""

__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"

import math
import zlib

import numpy as np

try:
    import zstandard as zstd

    HAS_ZSTD = True
except ImportError:
    HAS_ZSTD = False


class ZstdDecoder:
    """Zstd decoder for contact matrix tiles.

    Parameters
    ----------
    tile_size : int, default=256
        Edge length of the square tiles. Every decoded tile has shape
        ``(tile_size, tile_size)``.
    resolution : int, default=50000
        Hi-C resolution in bp. It is stored on the instance and is not used by
        ``decode_tile`` / ``decode_tiles``.
    dtype : np.dtype, default=np.uint32
        Data type of the decoded tile elements.
    use_zstd : bool, default=True
        Decode with zstd when the ``zstandard`` package is importable.
        Otherwise (or when False), payloads are decoded with zlib.

    Examples
    --------
    >>> import zlib
    >>> tile = np.arange(4, dtype=np.uint32).reshape(2, 2)
    >>> decoder = ZstdDecoder(tile_size=2, use_zstd=False)
    >>> decoder.decode_tile(zlib.compress(tile.tobytes())).tolist()
    [[0, 1], [2, 3]]
    """

    def __init__(
        self,
        tile_size: int = 256,
        resolution: int = 50000,
        dtype: np.dtype = np.uint32,
        use_zstd: bool = True,
    ):
        """Store the tile geometry and select the decompression backend."""
        self.tile_size = tile_size
        self.resolution = resolution
        self.dtype = np.dtype(dtype)
        # zstd can only be used when the optional import at module load succeeded.
        self.use_zstd = use_zstd and HAS_ZSTD

    def decode_tile(self, payload: bytes) -> np.ndarray:
        """Decode a single compressed tile.

        Parameters
        ----------
        payload : bytes
            Compressed bitstream holding exactly one tile's raw element bytes.

        Returns
        -------
        np.ndarray
            Decoded tile of shape ``(tile_size, tile_size)`` and dtype
            ``self.dtype``. The array is a read-only view over the
            decompressed ``bytes`` buffer; call ``.copy()`` to modify it.

        Raises
        ------
        ValueError
            If the decompressed size is not
            ``tile_size * tile_size * dtype.itemsize`` bytes.

        Examples
        --------
        >>> import zlib
        >>> decoder = ZstdDecoder(tile_size=2, use_zstd=False)
        >>> raw = np.array([[1, 2], [3, 4]], dtype=np.uint32).tobytes()
        >>> decoder.decode_tile(zlib.compress(raw)).tolist()
        [[1, 2], [3, 4]]
        """
        tile_size = self.tile_size
        expected_nbytes = tile_size * tile_size * self.dtype.itemsize

        if self.use_zstd:
            # max_output_size is needed when the frame header omits the content
            # size (e.g. streamed frames); without it zstandard refuses to decode.
            data = zstd.ZstdDecompressor().decompress(
                payload, max_output_size=expected_nbytes
            )
        else:
            # The output size is known up front, so size the buffer once.
            data = zlib.decompress(payload, bufsize=expected_nbytes)

        if len(data) != expected_nbytes:
            raise ValueError(
                f"decompressed tile has {len(data)} bytes, expected "
                f"{expected_nbytes} ({tile_size}x{tile_size} {self.dtype})"
            )
        return np.frombuffer(data, dtype=self.dtype).reshape(tile_size, tile_size)

    def decode_tiles(
        self,
        payloads: list[bytes],
        *,
        grid_shape: "tuple[int, int] | None" = None,
    ) -> np.ndarray:
        """Decode multiple tiles into a 4D array laid out row-major.

        Parameters
        ----------
        payloads : list[bytes]
            Encoded bitstreams, one per tile, in row-major tile order.
        grid_shape : tuple[int, int], optional
            Explicit ``(n_tile_rows, n_tile_cols)`` layout. When omitted, the
            grid is inferred as ``rows = isqrt(len(payloads))`` and
            ``cols = len(payloads) // rows``.

        Returns
        -------
        np.ndarray
            Array of shape ``(n_tile_rows, n_tile_cols, tile_size, tile_size)``.
            Tile ``k`` is placed at ``[k // n_tile_cols, k % n_tile_cols]``.
            An empty ``payloads`` list yields shape ``(0, 0, tile_size, tile_size)``.

        Raises
        ------
        ValueError
            If the (given or inferred) grid does not hold exactly
            ``len(payloads)`` tiles, so that no tile is silently dropped, or if
            any tile fails to decode to the expected size.

        Examples
        --------
        >>> import zlib
        >>> decoder = ZstdDecoder(tile_size=2, use_zstd=False)
        >>> tiles = [np.full((2, 2), k, dtype=np.uint32) for k in range(6)]
        >>> grid = decoder.decode_tiles([zlib.compress(t.tobytes()) for t in tiles])
        >>> grid.shape
        (2, 3, 2, 2)
        """
        n_tiles = len(payloads)
        tile_size = self.tile_size

        if grid_shape is None:
            # Most-square row-major grid; isqrt(0) == 0 covers the empty case.
            tile_rows = math.isqrt(n_tiles)
            tile_cols = n_tiles // tile_rows if tile_rows else 0
        else:
            tile_rows, tile_cols = (int(v) for v in grid_shape)

        # The inferred grid can be smaller than n_tiles (e.g. 5 -> 2x2).
        # Refuse instead of silently discarding the surplus tiles.
        if tile_rows < 0 or tile_cols < 0 or tile_rows * tile_cols != n_tiles:
            raise ValueError(
                f"cannot arrange {n_tiles} tiles into a {tile_rows}x{tile_cols} "
                "grid; pass grid_shape=(n_tile_rows, n_tile_cols) explicitly"
            )

        if n_tiles == 0:
            return np.empty(
                (tile_rows, tile_cols, tile_size, tile_size), dtype=self.dtype
            )

        # np.stack copies the read-only tile views into one writable array.
        stacked = np.stack([self.decode_tile(p) for p in payloads])
        return stacked.reshape(tile_rows, tile_cols, tile_size, tile_size)