Source code for gunz_cm.reconstructions.preprocs.points

# -*- coding: utf-8 -*-
"""
Preprocessing utilities for point arrays of a 3D reconstruction.

Points are stored as a 2D NumPy array with one row per point. Wherever
coordinates are needed, columns 0, 1 and 2 are read as x, y and z.

- ``downsample_points``: average consecutive points into coarser bins.
- ``filter_points``: drop points that contain NaN values.
- ``filter_valid_points``: keep only the points whose ids occur in a
  DataFrame of row/column ids, optionally downsampled.
- ``mask_points``: like ``filter_valid_points``, but keep the array layout
  and overwrite the missing points with ``np.inf``.
- ``plot_points``: show the points as a 3D Plotly trace.
"""
__version__ = "1.0.0"
__author__ = "Yeremia Gunawan Adhisantoso"
__license__ = "Clear BSD"
__email__ = "adhisant@tnt.uni-hannover.de"

import typing as t
from pydantic import validate_call, ConfigDict
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from ... import consts as cm_consts
from ...preprocs.masks import expand_with_nans
from ...preprocs.rc_filters import filter_empty_rowcols

@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def downsample_points(
    points: np.ndarray,
    ds_ratio: int,
    def_coor: float = np.nan,
) -> np.ndarray:
    """
    Downsample points by averaging every ``ds_ratio`` consecutive rows.

    Row ``i`` of ``points`` is assigned to output bin ``i // ds_ratio``.
    Each output row is the per-coordinate mean of the rows assigned to it.

    Notes
    -----
    - ``ds_ratio`` must be an integer strictly greater than 1.
    - Rows in which every value is NaN are skipped and contribute to no bin.
      Rows that are only partially NaN are kept. Their NaN entries are
      ignored by the mean (pandas' default ``skipna`` behaviour).
    - Only columns 0, 1 and 2 (x, y, z) are averaged.
    - Output bins that receive no valid row are filled with ``def_coor``.

    Parameters
    ----------
    points : np.ndarray
        2D array of points. Columns 0, 1 and 2 are read as x, y and z.
    ds_ratio : int
        Number of consecutive rows merged into one output row. Must be > 1.
    def_coor : float, optional
        Fill value for bins without any valid row, by default ``np.nan``.

    Returns
    -------
    np.ndarray
        Array of shape ``(ceil(len(points) / ds_ratio), 3)``.

    Raises
    ------
    ValueError
        If ``ds_ratio`` is not an integer greater than 1.

    Examples
    --------
    >>> pts = np.array([[0., 0., 0.], [2., 2., 2.], [np.nan, np.nan, np.nan]])
    >>> downsample_points(pts, 2, def_coor=-1.0)
    array([[ 1.,  1.,  1.],
           [-1., -1., -1.]])

    Authors
    -------
    - Yeremia G. Adhisantoso (adhisant@tnt.uni-hannover.de)
    - Qwen2.5 Coder 32B - 6.5bpw
    """
    if not isinstance(ds_ratio, int) or ds_ratio <= 1:
        raise ValueError("Invalid ds_ratio!")

    n_total = points.shape[0]
    # A row is usable as long as at least one of its values is not NaN.
    keep = ~np.isnan(points).all(axis=1)
    kept_ids = np.flatnonzero(keep)

    frame = pd.DataFrame({
        "lr_ids": kept_ids // ds_ratio,
        "x": points[keep, 0],
        "y": points[keep, 1],
        "z": points[keep, 2],
    })
    # groupby sorts the bin ids, so the index lines up with the mean rows.
    bin_means = frame.groupby("lr_ids").mean()

    n_bins = int(np.ceil(n_total / ds_ratio))
    binned = np.full((n_bins, 3), def_coor)
    binned[bin_means.index.to_numpy()] = bin_means.to_numpy()
    return binned
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def filter_points(
    points: np.ndarray,
    ret_mask: bool = False,
) -> t.Union[np.ndarray, t.Tuple[np.ndarray, np.ndarray]]:
    """
    Remove every point (row) that contains at least one NaN value.

    Notes
    -----
    - The filtered points are returned in both modes.
    - If ``ret_mask`` is True, the boolean row mask used for filtering is
      returned as well.

    Parameters
    ----------
    points : np.ndarray
        2D array of points, one row per point.
    ret_mask : bool, optional
        Whether to also return the mask used for filtering, by default False.

    Returns
    -------
    np.ndarray or Tuple[np.ndarray, np.ndarray]
        The filtered points. If ``ret_mask`` is True, a tuple
        ``(filtered_points, mask)`` is returned instead, where ``mask[i]`` is
        True iff row ``i`` was kept.

    Examples
    --------
    >>> pts = np.array([[0.0, 1.0, 2.0], [np.nan, 1.0, 2.0]])
    >>> filter_points(pts)
    array([[0., 1., 2.]])
    >>> _, mask = filter_points(pts, ret_mask=True)
    >>> mask
    array([ True, False])

    Authors
    -------
    - Yeremia G. Adhisantoso (adhisant@tnt.uni-hannover.de)
    - Qwen2.5 Coder 32B - 6.5bpw
    """
    # A row is kept only if none of its values is NaN.
    mask = ~np.isnan(points).any(axis=1)
    new_points = points[mask, :]
    if ret_mask:
        return new_points, mask
    # Return the filtered array; returning ``points`` here would make the
    # filter a no-op whenever the mask is not requested.
    return new_points
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def filter_valid_points(
    points: np.ndarray,
    cm_df: pd.DataFrame,
    ds_ratio: int = 1,
) -> np.ndarray:
    """
    Keep only the points whose ids occur in ``cm_df``, optionally downsampled.

    Use this when ``points`` holds one row per id, including ids that are
    empty in ``cm_df``. Only the rows whose index is one of the valid ids
    are returned.

    Notes
    -----
    - Valid ids are the unique ids returned (third element) by
      ``filter_empty_rowcols`` for the row and column ids of ``cm_df``.
    - ``ds_ratio == 1``: the valid rows are returned with all their columns.
    - ``ds_ratio > 1``: valid rows are grouped into low-resolution bins
      ``id // ds_ratio`` and columns 0, 1 and 2 (x, y, z) are averaged.
      Only occupied bins are returned, in ascending bin order. Empty bins
      are dropped; use ``mask_points`` to keep them as placeholders.

    Parameters
    ----------
    points : np.ndarray
        2D array of points, indexed by id along the first axis.
    cm_df : pd.DataFrame
        DataFrame holding the row- and column-id columns named by
        ``cm_consts.DataFrameSpecs.ROW_IDS`` / ``COL_IDS``.
    ds_ratio : int, optional
        Downsampling ratio, by default 1. Must be a positive integer.

    Returns
    -------
    np.ndarray
        The selected (and, for ``ds_ratio > 1``, averaged) points.

    Raises
    ------
    ValueError
        If ``ds_ratio`` is not a positive integer, or if a valid id lies
        outside ``[0, len(points))``.

    Authors
    -------
    - Yeremia G. Adhisantoso (adhisant@tnt.uni-hannover.de)
    - Qwen2.5 72B - 4.25bpw
    """
    if not (isinstance(ds_ratio, int) and ds_ratio >= 1):
        raise ValueError(f"Invalid ds_ratio: {ds_ratio}")

    row_ids = cm_df[cm_consts.DataFrameSpecs.ROW_IDS].to_numpy()
    col_ids = cm_df[cm_consts.DataFrameSpecs.COL_IDS].to_numpy()

    # Use centralized filter logic; with ret_unique_ids=True the third
    # returned element holds the unique valid ids.
    out = filter_empty_rowcols(
        (row_ids, col_ids), is_triu_sym=True, ret_unique_ids=True
    )
    unique_ids = out[2]

    num_points = points.shape[0]
    # An empty id set is trivially in bounds (even for an empty points
    # array), so only check the range when there is at least one id.
    if len(unique_ids) > 0:
        max_id = unique_ids.max()
        if max_id >= num_points:
            raise ValueError(
                f"Max ID {max_id} is out of bounds for points array of shape {points.shape}"
            )
        min_id = unique_ids.min()
        if min_id < 0:
            # Negative ids would silently index from the end of the array.
            raise ValueError(
                f"Min ID {min_id} is negative for points array of shape {points.shape}"
            )

    new_points = points[unique_ids, :]

    if ds_ratio > 1:
        new_points = (
            pd.DataFrame(data={
                'lr_ids': unique_ids // ds_ratio,
                'x': new_points[:, 0],
                'y': new_points[:, 1],
                'z': new_points[:, 2],
            })
            .groupby('lr_ids')
            .mean()
            .reset_index()
            [['x', 'y', 'z']]
            .to_numpy()
        )

    return new_points
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def mask_points(
    points: np.ndarray,
    cm_df: pd.DataFrame,
    ds_ratio: int = 1,
) -> np.ndarray:
    """
    Overwrite the points whose ids do not occur in ``cm_df`` with ``np.inf``.

    ``filter_valid_points`` drops such points. This function keeps the
    positional layout instead, so row ``i`` of the result still refers to id
    ``i`` (or to bin ``i`` when downsampling).

    Notes
    -----
    - Valid ids are the unique ids returned (third element) by
      ``filter_empty_rowcols`` for the row and column ids of ``cm_df``.
    - ``ds_ratio == 1``: a copy of ``points`` with the same shape is returned.
      Every row whose index is not a valid id is set to ``np.inf``. Boolean
      and integer inputs are converted to ``float64`` first so that
      ``np.inf`` can be stored; other dtypes are preserved.
    - ``ds_ratio > 1``: valid rows are grouped into low-resolution bins
      ``id // ds_ratio`` and columns 0, 1 and 2 (x, y, z) are averaged.
      The result has shape ``(ceil(len(points) / ds_ratio), 3)``. Bins
      without any valid id are set to ``np.inf``.

    Parameters
    ----------
    points : np.ndarray
        2D array of points, indexed by id along the first axis.
    cm_df : pd.DataFrame
        DataFrame holding the row- and column-id columns named by
        ``cm_consts.DataFrameSpecs.ROW_IDS`` / ``COL_IDS``.
    ds_ratio : int, optional
        Downsampling ratio, by default 1. Must be an integer >= 1.

    Returns
    -------
    np.ndarray
        The masked (and, for ``ds_ratio > 1``, downsampled) points.

    Raises
    ------
    ValueError
        If ``ds_ratio`` is not an integer >= 1, if ``points`` is not 2D, or
        if a valid id lies outside ``[0, len(points))``.

    Authors
    -------
    - Yeremia G. Adhisantoso (adhisant@tnt.uni-hannover.de)
    - Qwen2.5 72B - 4.25bpw
    """
    if not (isinstance(ds_ratio, int) and ds_ratio >= 1):
        raise ValueError(f"Invalid ds_ratio: {ds_ratio}")
    if points.ndim != 2:
        raise ValueError(f"points must be a 2D array, got shape {points.shape}")

    row_ids = cm_df[cm_consts.DataFrameSpecs.ROW_IDS].to_numpy()
    col_ids = cm_df[cm_consts.DataFrameSpecs.COL_IDS].to_numpy()

    # Use centralized filter logic; with ret_unique_ids=True the third
    # returned element holds the unique valid ids.
    out = filter_empty_rowcols(
        (row_ids, col_ids), is_triu_sym=True, ret_unique_ids=True
    )
    unique_ids = out[2]

    num_points = points.shape[0]
    # An empty id set is trivially in bounds (even for an empty points
    # array), so only check the range when there is at least one id.
    if len(unique_ids) > 0:
        max_id = unique_ids.max()
        if max_id >= num_points:
            raise ValueError(
                f"Max ID {max_id} is out of bounds for points array of shape {points.shape}"
            )
        min_id = unique_ids.min()
        if min_id < 0:
            # Negative ids would silently index from the end of the array.
            raise ValueError(
                f"Min ID {min_id} is negative for points array of shape {points.shape}"
            )

    if ds_ratio == 1:
        # Storing np.inf raises OverflowError for integer arrays and becomes
        # True for boolean arrays, so promote those dtypes to float64.
        if points.dtype.kind in "biu":
            new_points = points.astype(np.float64)
        else:
            new_points = points.copy()
        # Same idea as expand_with_nans, but rows whose index is NOT a valid
        # id are filled with np.inf instead of NaN.
        mask = np.ones(num_points, dtype=bool)
        mask[unique_ids] = False
        new_points[mask, :] = np.inf
        return new_points

    sel_points = points[unique_ids, :]
    lr_points_df = (
        pd.DataFrame(data={
            'lr_ids': unique_ids // ds_ratio,
            'x': sel_points[:, 0],
            'y': sel_points[:, 1],
            'z': sel_points[:, 2],
        })
        .groupby('lr_ids')
        .mean()
        .reset_index()
    )
    lr_ids = lr_points_df['lr_ids'].to_numpy()
    lr_points = lr_points_df[['x', 'y', 'z']].to_numpy()

    # Ceiling division in integer arithmetic (no float round trip).
    num_lr_points = -(-num_points // ds_ratio)
    new_points = np.full((num_lr_points, 3), np.inf)
    new_points[lr_ids, :] = lr_points
    return new_points
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def plot_points(
    points: np.ndarray,
    cm_df: t.Optional[pd.DataFrame] = None,
    colorscale: str = 'Viridis',
    trace_size: int = 5,
    fig_width: int = 1000,
    fig_height: int = 1000,
    fig_title: str = '3D Reconstruction',
) -> None:
    """
    Show the points as a connected 3D line-and-marker trace with Plotly.

    Notes
    -----
    If ``cm_df`` is provided, ``points`` is first reduced with
    ``filter_valid_points(points, cm_df)``, so only the points whose ids occur
    in ``cm_df`` are drawn. Points are connected in array order. Both the
    line and the markers are coloured by each point's position in the
    array. The figure is displayed with ``fig.show()`` and not returned.

    Parameters
    ----------
    points : np.ndarray
        2D array of points. Columns 0, 1 and 2 are plotted as x, y and z.
    cm_df : t.Optional[pd.DataFrame], optional
        DataFrame containing row and column ids used to select the points,
        by default None (plot all points).
    colorscale : str, optional
        Plotly colorscale for the line and the markers, by default 'Viridis'.
    trace_size : int, optional
        Line width and marker size, by default 5.
    fig_width : int, optional
        Figure width in pixels, by default 1000.
    fig_height : int, optional
        Figure height in pixels, by default 1000.
    fig_title : str, optional
        Figure title, by default '3D Reconstruction'.

    Returns
    -------
    None

    Authors
    -------
    - Yeremia G. Adhisantoso (adhisant@tnt.uni-hannover.de)
    - Qwen2.5 72B - 4.25bpw
    """
    if cm_df is not None:
        points = filter_valid_points(points, cm_df)

    # Colour by position along the array so the ordering of the trace is
    # visible. The same values are shared by the line and the markers.
    order_colors = np.arange(len(points))

    fig = go.Figure()
    fig.add_trace(
        go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='lines+markers',
            line=dict(
                color=order_colors,
                colorscale=colorscale,
                width=trace_size,
            ),
            marker=dict(
                color=order_colors,
                colorscale=colorscale,
                size=trace_size,
            ),
        )
    )
    fig.update_layout(
        width=fig_width,
        height=fig_height,
        title=fig_title,
    )
    fig.show()