Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion dexpv2/_tests/test_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging

from skimage.data import cells3d
import pytest

from dexpv2.segmentation import detect_foreground
from dexpv2.segmentation import detect_foreground, reconstruction_by_dilation
from dexpv2.utils import to_cpu

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -32,3 +33,35 @@ def test_foreground_detection(interactive_test: bool) -> None:
viewer.add_labels(to_cpu(foreground))

napari.run()


def test_foreground_detection_with_float16() -> None:
    """
    Check that `reconstruction_by_dilation` on float16 device arrays gives
    the same result as on float32 device arrays and on float32 host arrays.

    The test is skipped when `xp` resolves to numpy, since the comparison is
    only meaningful when the arrays live on the cupy backend.
    """
    import numpy as np

    nuclei = xp.asarray(cells3d()[:, 1])

    # Ensure we are using cupy backend; bail out before doing any more work.
    if isinstance(nuclei, np.ndarray):
        pytest.skip("Skipping test as cupy is not available.")

    # Test with float16 data
    nuclei = nuclei / nuclei.max()
    nuclei = nuclei.astype(xp.float16)
    mask = xp.copy(nuclei)

    foreground_cp = reconstruction_by_dilation(nuclei, mask, iterations=10)
    foreground_cp = to_cpu(foreground_cp)

    nuclei_f32 = nuclei.astype(xp.float32)
    mask_f32 = mask.astype(xp.float32)
    foreground_f32 = reconstruction_by_dilation(nuclei_f32, mask_f32, iterations=10)
    # Bring the float32 device result to the host as well: `np.allclose`
    # cannot compare a numpy array against a cupy array.
    foreground_f32 = to_cpu(foreground_f32)

    # Convert to numpy for comparison
    # Obs. skimage operations won't work with np.float16 so we need to convert
    # to float32 and hope that the conversion doesn't change the result too much
    nuclei_np = to_cpu(nuclei_f32)
    mask_np = to_cpu(mask_f32)
    foreground_np = reconstruction_by_dilation(nuclei_np, mask_np, iterations=10)

    # Check that the float16 result matches both float32 references
    assert np.allclose(foreground_cp, foreground_np)
    assert np.allclose(foreground_cp, foreground_f32)
123 changes: 119 additions & 4 deletions dexpv2/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Tuple, List

import numpy as np
from numpy.typing import ArrayLike
Expand All @@ -10,6 +11,94 @@
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)

try:
import cupy as xp

LOG.info("cupy found.")
except (ModuleNotFoundError, ImportError):
import numpy as xp

LOG.info("cupy not found using numpy.")


def discretize_multiple_f16_to_u16(
    f16_arrays: List[ArrayLike],
) -> Tuple[List[ArrayLike], ArrayLike]:
    """
    Discretizes multiple arrays (e.g., CuPy or NumPy) of float16 values to
    uint16, preserving order using a global mapping across all arrays.

    A value that occurs in several input arrays is stored once in the lookup
    table, so it receives the same uint16 code everywhere and codes are
    contiguous in ``[0, len(lut))``.

    Parameters
    ----------
    f16_arrays : List[ArrayLike]
        List of input arrays to be discretized. Each array must have dtype
        float16 and must not be empty.

    Returns
    -------
    u16_list : List[ArrayLike]
        Discretized arrays with dtype uint16, one per input array and with
        the same shape as the corresponding input.
    u16_to_f16_lut : ArrayLike
        Sorted float16 lookup table shared by all arrays; indexing it with a
        uint16 code yields the original float16 value.

    Raises
    ------
    TypeError
        If any input does not have dtype float16 (including objects without
        a ``dtype`` attribute).
    ValueError
        If the list of arrays is empty, if any array is empty, or if the
        number of distinct values across all arrays exceeds the capacity of
        uint16 (65536).
    """
    if not f16_arrays:
        raise ValueError("Input list of arrays cannot be empty.")

    # Validate inputs up front so no work is done on a bad list.
    for i, arr in enumerate(f16_arrays):
        # `getattr` keeps the documented TypeError for non-array inputs
        # instead of leaking an AttributeError.
        if getattr(arr, "dtype", None) != xp.float16:
            raise TypeError(
                f"Array at index {i} must be an 'xp.ndarray' with dtype xp.float16. "
                f"Got type {type(arr)} with dtype {getattr(arr, 'dtype', 'N/A')}."
            )
        if arr.size == 0:
            raise ValueError(
                f"Array at index {i} is empty. Cannot discretize empty arrays."
            )

    # Per-array unique values and the index of each element into them.
    uniques, inverses = [], []
    for arr in f16_arrays:
        unq, inv = np.unique(arr, return_inverse=True)
        uniques.append(unq)
        inverses.append(inv)

    # Global sorted table of *distinct* values. Using `unique` rather than
    # `sort` avoids duplicate entries for values shared between arrays.
    u16_to_f16_lut = np.unique(np.concatenate(uniques))

    # uint16 can address indices 0..65535, i.e. 65536 entries.
    capacity = int(np.iinfo(np.uint16).max) + 1
    if len(u16_to_f16_lut) > capacity:
        raise ValueError(
            "The total number of unique values across all arrays exceeds "
            "the capacity of uint16 (65536)."
        )

    # Remap each array's local indices to positions in the global table and
    # restore the input shape (the inverse may be returned flattened).
    u16_list: List[ArrayLike] = []
    for arr, unq, inv in zip(f16_arrays, uniques, inverses):
        global_idx = np.searchsorted(u16_to_f16_lut, unq).astype(np.uint16)
        u16_list.append(global_idx[inv].reshape(arr.shape))

    return u16_list, u16_to_f16_lut


def reconstruction_by_dilation(
seed: ArrayLike, mask: ArrayLike, iterations: int
Expand All @@ -34,19 +123,35 @@ def reconstruction_by_dilation(
-------
Image reconstructed by dilation.
"""
ndi = import_module("scipy", "ndimage")
ndi = import_module("scipy", "ndimage", seed)

import numpy as np

cupy_used = np != xp and not isinstance(seed, np.ndarray)

lut = None
# quick-fix for the issue https://github.com/cupy/cupy/issues/9122
if cupy_used and seed.dtype == xp.float16:
(seed, mask), lut = discretize_multiple_f16_to_u16([seed, mask])

seed = np.minimum(seed, mask, out=seed) # just making sure
seed = np.minimum(seed, mask, out=seed)

for _ in range(iterations):
seed = ndi.grey_dilation(seed, size=3, output=seed, mode="constant")
seed = np.minimum(seed, mask, out=seed)

if lut is not None:
# convert back to float16
seed = xp.take(lut, seed)

return seed


def fancy_otsu_threshold(
image: ArrayLike, remove_hist_mode: bool = False, min_foreground: float = 0.0
image: ArrayLike,
remove_hist_mode: bool = False,
min_foreground: float = 0.0,
max_foreground: float = None,
) -> float:
"""
Compute Otsu threshold with some additional features.
Expand All @@ -61,6 +166,8 @@ def fancy_otsu_threshold(
Removes histogram mode before computing otsu threshold, useful when background regions are being detected.
min_foreground : float, optional
Minimum threshold value, by default 0.0
max_foreground: float, optional
Maximum threshold value, by default max value of image

Returns
-------
Expand All @@ -85,6 +192,13 @@ def fancy_otsu_threshold(
LOG.info(f"Histogram with {nbins}")

hist, bin_centers = exposure.histogram(image, nbins)
print(len(bin_centers))
# clip bins and histogram beyond max_foreground value
if max_foreground is not None:
below_threshold_mask = bin_centers < np.sqrt(max_foreground)
bin_centers = bin_centers[below_threshold_mask]
print(bin_centers)
hist = hist[below_threshold_mask]

# histogram disconsidering pixels we are sure are background
if remove_hist_mode:
Expand All @@ -98,7 +212,6 @@ def fancy_otsu_threshold(

threshold = max(threshold, min_foreground)
LOG.info(f"Threshold after minimum filtering {threshold}")

return threshold


Expand Down Expand Up @@ -148,6 +261,7 @@ def detect_foreground(
sigma: float = 15.0,
remove_hist_mode: bool = False,
min_foreground: float = 0.0,
max_foreground: float = None,
) -> ArrayLike:
"""
Detect foreground using morphological reconstruction by dilation and thresholding.
Expand Down Expand Up @@ -183,6 +297,7 @@ def detect_foreground(
small_foreground,
remove_hist_mode=remove_hist_mode,
min_foreground=min_foreground,
max_foreground=max_foreground,
)

mask = foreground > threshold
Expand Down