diff --git a/dexpv2/_tests/test_segmentation.py b/dexpv2/_tests/test_segmentation.py index 71ef18e..c904116 100644 --- a/dexpv2/_tests/test_segmentation.py +++ b/dexpv2/_tests/test_segmentation.py @@ -1,8 +1,9 @@ import logging from skimage.data import cells3d +import pytest -from dexpv2.segmentation import detect_foreground +from dexpv2.segmentation import detect_foreground, reconstruction_by_dilation from dexpv2.utils import to_cpu LOG = logging.getLogger(__name__) @@ -32,3 +33,35 @@ def test_foreground_detection(interactive_test: bool) -> None: viewer.add_labels(to_cpu(foreground)) napari.run() + + +def test_foreground_detection_with_float16() -> None: + # Test with float16 data + nuclei = xp.asarray(cells3d()[:, 1]) + nuclei = nuclei / nuclei.max() + nuclei = nuclei.astype(xp.float16) + mask = xp.copy(nuclei) + + # Ensure we are using cupy backend + import numpy as np + + if isinstance(nuclei, np.ndarray): + pytest.skip("Skipping test as cupy is not available.") + + foreground_cp = reconstruction_by_dilation(nuclei, mask, iterations=10) + foreground_cp = to_cpu(foreground_cp) + + nuclei_f32 = nuclei.astype(xp.float32) + mask_f32 = mask.astype(xp.float32) + foreground_f32 = reconstruction_by_dilation(nuclei_f32, mask_f32, iterations=10) + + # Convert to numpy for comparison + # Obs. skimage operations won't work with np.float16 so we need to convert + # to float32 and hope that the conversion doesn't change the result too much + nuclei_np = to_cpu(nuclei_f32) + mask_np = to_cpu(mask_f32) + foreground_np = reconstruction_by_dilation(nuclei_np, mask_np, iterations=10) + + # Check that the output is a binary mask + assert np.allclose(foreground_cp, foreground_np) + assert np.allclose(foreground_cp, foreground_f32) diff --git a/dexpv2/segmentation.py b/dexpv2/segmentation.py index 90effa6..bf00078 100644 --- a/dexpv2/segmentation.py +++ b/dexpv2/segmentation.py @@ -1,4 +1,5 @@ import logging +from typing import Tuple, List import numpy as np from numpy.typing import ArrayLike @@ -10,6 +11,94 @@ LOG = logging.getLogger(__name__) LOG.setLevel(logging.INFO) +try: + import cupy as xp + + LOG.info("cupy found.") +except (ModuleNotFoundError, ImportError): + import numpy as xp + + LOG.info("cupy not found using numpy.") + + +def discretize_multiple_f16_to_u16( + f16_arrays: List[ArrayLike], +) -> Tuple[List[ArrayLike], ArrayLike]: + """ + Discretizes multiple arrays (e.g., CuPy or NumPy) of float16 values to + uint16, preserving order using a global mapping across all arrays. + + Parameters + ---------- + f16_arrays : List[ArrayLike] + List of input arrays to be discretized. Each array must have dtype + float16. + + Returns + ------- + Tuple[List[ArrayLike], ArrayLike]: A tuple containing: + - u16_list (List[ArrayLike]): A list of discretized arrays, + each with dtype uint16, + corresponding to the input + arrays. + - u16_to_f16_lut (ArrayLike): A single lookup table (array of + float16) for all arrays, where the + index is the uint16 value and the + value is the corresponding original + float16 value. + + Raises + ------ + TypeError: If any input array's dtype is not float16. + ValueError: If the list of arrays is empty, or if the total number + of unique values across all arrays exceeds the capacity + of uint16 (65536). + """ + if not f16_arrays: + raise ValueError("Input list of arrays cannot be empty.") + + # Validate input types and collect original shapes and sizes + original_shapes = [] + for i, arr in enumerate(f16_arrays): + if arr.dtype != xp.float16: + raise TypeError( + f"Array at index {i} must be an 'xp.ndarray' with dtype xp.float16. " + f"Got type {type(arr)} with dtype {getattr(arr, 'dtype', 'N/A')}." + ) + if arr.size == 0: + raise ValueError( + f"Array at index {i} is empty. Cannot discretize empty arrays." + ) + original_shapes.append(arr.shape) + + # Collect unique values and their indices + uniques, inverses = [], [] + for arr in f16_arrays: + unq, inv = np.unique(arr, return_inverse=True) + uniques.append(unq) + inverses.append(inv.astype(np.uint16)) + + # Concatenate all unique values and sort them + u16_to_f16_lut = np.sort(np.concatenate(uniques)) + if len(u16_to_f16_lut) > np.iinfo(np.uint16).max: + raise ValueError( + "The total number of unique values across all arrays exceeds " + "the capacity of uint16 (65536)." + ) + + # Fix inverses to preserve order + for unq_k, inv_k in zip(uniques, inverses): + new_idx_k = np.searchsorted(u16_to_f16_lut, unq_k) + inv_k[:] = new_idx_k[inv_k] + + # Reshape inverses to match original shapes + u16_list: list[ArrayLike] = [] + for shape, inv in zip(original_shapes, inverses): + inv_reshaped = inv.reshape(shape) + u16_list.append(inv_reshaped) + + return u16_list, u16_to_f16_lut + def reconstruction_by_dilation( seed: ArrayLike, mask: ArrayLike, iterations: int @@ -34,19 +123,35 @@ def reconstruction_by_dilation( ------- Image reconstructed by dilation. """ - ndi = import_module("scipy", "ndimage") + ndi = import_module("scipy", "ndimage", seed) + + import numpy as np + + cupy_used = np != xp and not isinstance(seed, np.ndarray) + + lut = None + # quick-fix for the issue https://github.com/cupy/cupy/issues/9122 + if cupy_used and seed.dtype == xp.float16: + (seed, mask), lut = discretize_multiple_f16_to_u16([seed, mask]) - seed = np.minimum(seed, mask, out=seed) # just making sure + seed = np.minimum(seed, mask, out=seed) for _ in range(iterations): seed = ndi.grey_dilation(seed, size=3, output=seed, mode="constant") seed = np.minimum(seed, mask, out=seed) + if lut is not None: + # convert back to float16 + seed = xp.take(lut, seed) + return seed def fancy_otsu_threshold( - image: ArrayLike, remove_hist_mode: bool = False, min_foreground: float = 0.0 + image: ArrayLike, + remove_hist_mode: bool = False, + min_foreground: float = 0.0, + max_foreground: float = None, ) -> float: """ Compute Otsu threshold with some additional features. @@ -61,6 +166,8 @@ def fancy_otsu_threshold( Removes histogram mode before computing otsu threshold, useful when background regions are being detected. min_foreground : float, optional Minimum threshold value, by default 0.0 + max_foreground: float, optional + Maximum threshold value, by default max value of image Returns ------- @@ -85,6 +192,13 @@ def fancy_otsu_threshold( LOG.info(f"Histogram with {nbins}") hist, bin_centers = exposure.histogram(image, nbins) + print(len(bin_centers)) + # clip bins and histogram beyond max_foreground value + if max_foreground is not None: + below_threshold_mask = bin_centers < np.sqrt(max_foreground) + bin_centers = bin_centers[below_threshold_mask] + print(bin_centers) + hist = hist[below_threshold_mask] # histogram disconsidering pixels we are sure are background if remove_hist_mode: @@ -98,7 +212,6 @@ def fancy_otsu_threshold( threshold = max(threshold, min_foreground) LOG.info(f"Threshold after minimum filtering {threshold}") - return threshold @@ -148,6 +261,7 @@ def detect_foreground( sigma: float = 15.0, remove_hist_mode: bool = False, min_foreground: float = 0.0, + max_foreground: float = None, ) -> ArrayLike: """ Detect foreground using morphological reconstruction by dilation and thresholding. @@ -183,6 +297,7 @@ def detect_foreground( small_foreground, remove_hist_mode=remove_hist_mode, min_foreground=min_foreground, + max_foreground=max_foreground, ) mask = foreground > threshold