Source code for gunz_cm.compressions.cmc_zstd_decoder
"""
CMC Transforms + Zstd Decoder for GZCM v3 compression.
Combines CMC's domain-specific transforms (diagonal transform, binarization)
with Zstd entropy coding for faster decoding than pure CMC.
Examples
--------
"""
__author__ = "Yeremia Gunawan Adhisantoso"
__email__ = "adhisant@tnt.uni-hannover.de"
__license__ = "Clear BSD"
import sys
import pathlib
import zlib
import numpy as np
try:
import zstandard as zstd
HAS_ZSTD = True
except ImportError:
HAS_ZSTD = False
# Absolute, symlink-resolved location of this module on disk.
_FILE_PATH = pathlib.Path(__file__).resolve()

# Climb from this file to its directory and then four more levels up; the
# resulting directory is treated as the workspace root.
_WS_ROOT = _FILE_PATH.parent
for _ in range(4):
    _WS_ROOT = _WS_ROOT.parent

# Location where the CMC checkout is expected relative to the workspace root.
_CMC_PATH = _WS_ROOT.joinpath("3d_recon", "thirdparty", "cmc")
if not _CMC_PATH.exists():
    # Fail at import time with the offending path rather than with a bare
    # ModuleNotFoundError from the import further down.
    raise FileNotFoundError(f"CMC not found at {_CMC_PATH}")

# Put the CMC directory at the front of the module search path (once) so the
# import below can find it.
if sys.path.count(str(_CMC_PATH)) == 0:
    sys.path.insert(0, str(_CMC_PATH))

import cmc.transform  # noqa: E402

# Module-level aliases for the two inverse transforms used by the decoder.
debinarize_rc_bin_split_v2 = cmc.transform.debinarize_rc_bin_split_v2
reverse_diag_transform_mode0 = cmc.transform.reverse_diag_transform_mode0
class CmcZstdDecoder:
    """CMC Transforms + Zstd decoder for contact matrix tiles.

    Decompresses each tile payload (Zstd when the ``zstandard`` package could
    be imported, zlib otherwise) and then reverses CMC's domain-specific
    transforms.

    Parameters
    ----------
    tile_size : int, default=256
        Tile size for block processing.
    resolution : int, default=50000
        Hi-C resolution in bp.
    dtype : np.dtype, default=np.uint32
        Data type for decoded tiles.

    Examples
    --------
    >>> decoder = CmcZstdDecoder(tile_size=256)
    >>> decoder.dtype
    dtype('uint32')
    """

    # Every tile payload starts with a shape header that is parsed as int32
    # values from its first 8 bytes (i.e. two int32 entries).
    _HEADER_NBYTES = 8

    def __init__(
        self,
        tile_size: int = 256,
        resolution: int = 50000,
        dtype: np.dtype = np.uint32,
    ):
        """Store the decoder configuration.

        ``dtype`` is normalized through ``np.dtype`` so that scalar types
        such as ``np.uint32`` and dtype strings are both accepted.
        """
        self.tile_size = tile_size
        self.resolution = resolution
        self.dtype = np.dtype(dtype)

    def decode_tile(self, payload: bytes) -> np.ndarray:
        """Decode a single compressed tile.

        Parameters
        ----------
        payload : bytes
            Compressed bitstream: an 8-byte int32 shape header followed by
            the compressed, binarized tile data.

        Returns
        -------
        np.ndarray
            Decoded contact matrix tile, exactly as returned by the CMC
            inverse transforms (it is not cast to ``self.dtype`` here).

        Raises
        ------
        ValueError
            If ``payload`` is too short to contain the shape header.
        """
        header_nbytes = self._HEADER_NBYTES
        if len(payload) < header_nbytes:
            # Without this check a truncated payload fails later with an
            # obscure frombuffer/decompressor error.
            raise ValueError(
                f"payload too short: expected at least {header_nbytes} bytes "
                f"for the shape header, got {len(payload)}"
            )

        # NOTE(review): the header is read in native byte order; confirm this
        # matches what the encoder writes.
        shape = np.frombuffer(payload[:header_nbytes], dtype=np.int32)
        encoded_data = payload[header_nbytes:]

        if HAS_ZSTD:
            data = zstd.ZstdDecompressor().decompress(encoded_data)
        else:
            # Fallback used when ``zstandard`` could not be imported.
            data = zlib.decompress(encoded_data)

        # The decompressed stream holds one byte per boolean entry.
        bin_mat = np.frombuffer(data, dtype=np.bool_).reshape(shape)
        debinarized = debinarize_rc_bin_split_v2(bin_mat, axis=0)
        return reverse_diag_transform_mode0(debinarized)

    def decode_tiles(self, payloads: list[bytes]) -> np.ndarray:
        """Decode multiple tiles into a 4D array.

        Tiles are laid out row-major on a grid that is as close to square as
        possible while holding every tile: the row count is the largest
        divisor of ``len(payloads)`` not exceeding its square root. For
        perfect squares this is a square grid; for a prime count it is a
        single row.

        Parameters
        ----------
        payloads : list[bytes]
            List of encoded bitstreams.

        Returns
        -------
        np.ndarray
            4D array of decoded tiles with shape
            ``(n_tile_rows, n_tile_cols, *tile_shape)`` and dtype
            ``self.dtype``, where ``tile_shape`` is the shape of the first
            decoded tile. For an empty ``payloads`` the result has shape
            ``(0, 1, tile_size, tile_size)``.
        """
        n_tiles = len(payloads)
        if n_tiles == 0:
            # No tile is available to take the shape from, so fall back to the
            # configured tile size; zero rows and one column mirrors the grid
            # the zero-row guard in earlier versions would have produced.
            return np.empty(
                (0, 1, self.tile_size, self.tile_size), dtype=self.dtype
            )

        decoded = [self.decode_tile(p) for p in payloads]
        tile_shape = decoded[0].shape

        # Start at floor(sqrt(n)) and step down to the nearest divisor so that
        # rows * cols == n_tiles and no decoded tile is silently dropped.
        tile_rows = int(np.sqrt(n_tiles))
        while n_tiles % tile_rows:
            tile_rows -= 1
        tile_cols = n_tiles // tile_rows

        result = np.empty((tile_rows, tile_cols, *tile_shape), dtype=self.dtype)
        for idx, tile in enumerate(decoded):
            # divmod maps the flat index to its (row, col) grid position.
            result[divmod(idx, tile_cols)] = tile
        return result