# -*- coding: utf-8 -*-
"""
Module.
Examples
--------
"""
__version__ = "1.0.0"
__author__ = "Yeremia Gunawan Adhisantoso"
__license__ = "Clear BSD"
__email__ = "adhisant@tnt.uni-hannover.de"
import typing as t
from pydantic import validate_call, ConfigDict
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from ... import consts as cm_consts
from ...preprocs.masks import expand_with_nans
from ...preprocs.rc_filters import filter_empty_rowcols
[docs]
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def downsample_points(
    points: np.ndarray,
    ds_ratio: int,
    def_coor: float = np.nan,
) -> np.ndarray:
    """
    Average consecutive points into low-resolution bins.

    Each point with index ``i`` is assigned to the low-resolution bin
    ``i // ds_ratio``; the output row for a bin is the per-coordinate mean
    of the points assigned to it.

    Notes
    -----
    - Rows whose coordinates are *all* NaN are excluded before averaging.
      Rows with only some NaN coordinates are still passed to the average.
    - The output has ``ceil(len(points) / ds_ratio)`` rows and 3 columns.
    - Bins that received no valid point are filled with `def_coor`.
    - Only columns 0, 1 and 2 of `points` are read.

    Parameters
    ----------
    points : np.ndarray
        2-D array of points; columns 0-2 are used as x, y, z.
    ds_ratio : int
        Downsampling ratio. Must be an integer greater than 1.
    def_coor : float, optional
        Fill value for bins without any valid point, by default np.nan.

    Returns
    -------
    np.ndarray
        Array of shape ``(ceil(len(points) / ds_ratio), 3)``.

    Raises
    ------
    ValueError
        If `ds_ratio` is not an integer greater than 1.

    Authors
    -------
    - Yeremia G. Adhisantoso (adhisant@tnt.uni-hannover.de)
    - Qwen2.5 Coder 32B - 6.5bpw
    """
    ratio_ok = isinstance(ds_ratio, int) and ds_ratio > 1
    if not ratio_ok:
        raise ValueError("Invalid ds_ratio!")

    total = points.shape[0]

    # Keep every row that has at least one non-NaN coordinate.
    keep = ~np.isnan(points).all(axis=1)
    kept_ids = np.arange(total)[keep]
    kept_rows = points[keep]

    # Integer division maps each original index to its low-resolution bin.
    frame = pd.DataFrame({
        'lr_ids': kept_ids // ds_ratio,
        'x': kept_rows[:, 0],
        'y': kept_rows[:, 1],
        'z': kept_rows[:, 2],
    })
    bin_means = frame.groupby('lr_ids').mean()

    # Start from the fill value so that bins without valid points keep it.
    n_bins = int(np.ceil(total / ds_ratio))
    result = np.full((n_bins, 3), def_coor)
    result[bin_means.index.to_numpy()] = bin_means.to_numpy()
    return result
[docs]
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def filter_points(
    points: np.ndarray,
    ret_mask: bool = False,
) -> t.Union[np.ndarray, t.Tuple[np.ndarray, np.ndarray]]:
    """
    Remove every point that has at least one NaN coordinate.

    Notes
    -----
    - If `ret_mask` is True, the function returns both the filtered points
      and the boolean mask used for filtering.
    - If `ret_mask` is False, only the filtered points are returned.

    Parameters
    ----------
    points : np.ndarray
        2-D array of points, one point per row.
    ret_mask : bool, optional
        Whether to also return the mask used for filtering, by default False.

    Returns
    -------
    np.ndarray or Tuple[np.ndarray, np.ndarray]
        The filtered points. When `ret_mask` is True, a tuple
        ``(filtered_points, mask)`` where ``mask[i]`` is True for every row
        of `points` that was kept.

    Examples
    --------
    >>> pts = np.array([[0.0, 1.0, 2.0], [np.nan, 1.0, 2.0]])
    >>> filter_points(pts)
    array([[0., 1., 2.]])

    Authors
    -------
    - Yeremia G. Adhisantoso (adhisant@tnt.uni-hannover.de)
    - Qwen2.5 Coder 32B - 6.5bpw
    """
    # A row is kept only when none of its coordinates is NaN.
    mask = ~np.isnan(points).any(axis=1)
    new_points = points[mask, :]

    if ret_mask:
        return new_points, mask
    # Return the filtered array here too; returning the raw input would
    # hand NaN rows back to the caller.
    return new_points
[docs]
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def filter_valid_points(
    points: np.ndarray,
    cm_df: pd.DataFrame,
    ds_ratio: int = 1,
) -> np.ndarray:
    """
    Select the points referenced by `cm_df`, optionally downsampling them.

    Intended for the case where `points` also holds coordinates for regions
    that are empty in `cm_df`.

    Notes
    -----
    - The row and column ID columns of `cm_df` are passed to
      `filter_empty_rowcols`, and the third element of its result is used
      as the set of IDs to keep.
    - Only the rows of `points` at those IDs are returned.
    - If `ds_ratio` is greater than 1, the kept points are grouped by
      ``id // ds_ratio`` and averaged per group; only columns 0, 1 and 2
      are used in that case.

    Parameters
    ----------
    points : np.ndarray
        2-D array of points, indexed by the IDs found in `cm_df`.
    cm_df : pd.DataFrame
        DataFrame containing the row and column ID columns named by
        ``cm_consts.DataFrameSpecs``.
    ds_ratio : int, optional
        Downsampling ratio, by default 1. Must be a positive integer.

    Returns
    -------
    np.ndarray
        The selected and optionally downsampled points.

    Raises
    ------
    ValueError
        If `ds_ratio` is not a positive integer, or if the largest ID is not
        a valid row index of `points`.

    Authors
    -------
    - Yeremia G. Adhisantoso (adhisant@tnt.uni-hannover.de)
    - Qwen2.5 72B - 4.25bpw
    """
    if not isinstance(ds_ratio, int) or ds_ratio < 1:
        raise ValueError(f"Invalid ds_ratio: {ds_ratio}")

    specs = cm_consts.DataFrameSpecs
    rows = cm_df[specs.ROW_IDS].to_numpy()
    cols = cm_df[specs.COL_IDS].to_numpy()

    # The shared filter helper returns the unique valid IDs at position 2.
    unique_ids = filter_empty_rowcols(
        (rows, cols),
        is_triu_sym=True,
        ret_unique_ids=True,
    )[2]

    if len(unique_ids) > 0:
        max_id = unique_ids.max()
    else:
        max_id = 0

    if max_id >= points.shape[0]:
        raise ValueError(f"Max ID {max_id} is out of bounds for points array of shape {points.shape}")

    selected = points[unique_ids, :]
    if ds_ratio <= 1:
        return selected

    # Average all selected points that share a low-resolution ID.
    frame = pd.DataFrame(data={
        'lr_ids': unique_ids // ds_ratio,
        'x': selected[:, 0],
        'y': selected[:, 1],
        'z': selected[:, 2],
    })
    averaged = frame.groupby(['lr_ids']).mean().reset_index()
    return averaged[['x', 'y', 'z']].to_numpy()
[docs]
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def mask_points(
    points: np.ndarray,
    cm_df: pd.DataFrame,
    ds_ratio: int = 1,
) -> np.ndarray:
    """
    Overwrite points not referenced by `cm_df` with ``np.inf``.

    Notes
    -----
    - The row and column ID columns of `cm_df` are passed to
      `filter_empty_rowcols`, and the third element of its result is used
      as the set of valid IDs.
    - With ``ds_ratio == 1`` a copy of `points` is returned in which every
      row whose index is not a valid ID is set to ``np.inf``.
    - With ``ds_ratio > 1`` the valid points are grouped by
      ``id // ds_ratio`` and averaged. The result has
      ``ceil(len(points) / ds_ratio)`` rows and 3 columns, and rows for
      groups without any valid point are ``np.inf``.

    Parameters
    ----------
    points : np.ndarray
        2-D array of points, indexed by the IDs found in `cm_df`.
    cm_df : pd.DataFrame
        DataFrame containing the row and column ID columns named by
        ``cm_consts.DataFrameSpecs``.
    ds_ratio : int, optional
        Downsampling ratio, by default 1. Must be an integer greater than
        or equal to 1.

    Returns
    -------
    np.ndarray
        The masked (and optionally downsampled) points array.

    Raises
    ------
    ValueError
        If `ds_ratio` is not a positive integer, or if the largest ID is not
        a valid row index of `points`.

    Authors
    -------
    - Yeremia G. Adhisantoso (adhisant@tnt.uni-hannover.de)
    - Qwen2.5 72B - 4.25bpw
    """
    if not isinstance(ds_ratio, int) or ds_ratio < 1:
        raise ValueError(f"Invalid ds_ratio: {ds_ratio}")

    specs = cm_consts.DataFrameSpecs
    rows = cm_df[specs.ROW_IDS].to_numpy()
    cols = cm_df[specs.COL_IDS].to_numpy()

    # The shared filter helper returns the unique valid IDs at position 2.
    unique_ids = filter_empty_rowcols(
        (rows, cols),
        is_triu_sym=True,
        ret_unique_ids=True,
    )[2]

    if len(unique_ids) > 0:
        max_id = unique_ids.max()
    else:
        max_id = 0

    # Unpacking into two names requires `points` to be 2-D.
    num_points, _ = points.shape
    if max_id >= num_points:
        raise ValueError(f"Max ID {max_id} is out of bounds for points array of shape {points.shape}")

    if ds_ratio == 1:
        # Same idea as expand_with_nans, but the rows whose index is not a
        # valid ID are filled with np.inf.
        masked = points.copy()
        is_valid = np.zeros(num_points, dtype=bool)
        is_valid[unique_ids] = True
        masked[~is_valid, :] = np.inf
        return masked

    # ds_ratio > 1: average the valid points that share a low-resolution ID.
    selected = points[unique_ids, :]
    lr_count = int(np.ceil(num_points / ds_ratio))
    frame = pd.DataFrame(data={
        'lr_ids': unique_ids // ds_ratio,
        'x': selected[:, 0],
        'y': selected[:, 1],
        'z': selected[:, 2],
    })
    averaged = frame.groupby(['lr_ids']).mean().reset_index()

    # Start from np.inf so that groups without valid points stay masked.
    masked = np.full((lr_count, 3), np.inf)
    masked[averaged['lr_ids'].to_numpy(), :] = averaged[['x', 'y', 'z']].to_numpy()
    return masked
[docs]
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def plot_points(
    points: np.ndarray,
    cm_df: t.Optional[pd.DataFrame] = None,
    colorscale: str = 'Viridis',
    trace_size: int = 5,
    fig_width: int = 1000,
    fig_height: int = 1000,
    fig_title: str = '3D Reconstruction',
) -> None:
    """
    Show the points as a 3D line-and-marker trace using Plotly.

    Notes
    -----
    If `cm_df` is provided, `points` is first reduced with
    `filter_valid_points` (using its default downsampling ratio). Columns
    0, 1 and 2 of the resulting array are plotted as x, y and z. Both the
    line and the markers are colored by the position of each point in the
    array, mapped through `colorscale`.

    Parameters
    ----------
    points : np.ndarray
        The array of points to plot.
    cm_df : t.Optional[pd.DataFrame], optional
        The DataFrame containing row and column IDs, by default None.
    colorscale : str, optional
        The colorscale to use for the plot, by default 'Viridis'.
    trace_size : int, optional
        The marker size and line width, by default 5.
    fig_width : int, optional
        The width of the figure, by default 1000.
    fig_height : int, optional
        The height of the figure, by default 1000.
    fig_title : str, optional
        The title of the figure, by default '3D Reconstruction'.

    Returns
    -------
    None

    Authors
    -------
    - Yeremia G. Adhisantoso (adhisant@tnt.uni-hannover.de)
    - Qwen2.5 72B - 4.25bpw
    """
    if cm_df is not None:
        points = filter_valid_points(points, cm_df)

    # Color each point by its position along the sequence.
    order = np.arange(len(points))
    line_style = {
        'color': order,
        'colorscale': colorscale,
        'width': trace_size,
    }
    marker_style = {
        'color': order,
        'colorscale': colorscale,
        'size': trace_size,
    }

    trace = go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='lines+markers',
        line=line_style,
        marker=marker_style,
    )

    fig = go.Figure()
    fig.add_trace(trace)
    fig.update_layout(
        width=fig_width,
        height=fig_height,
        title=fig_title,
    )
    fig.show()