diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index fe5140171b..4172e8c1c4 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -1,7 +1,7 @@ from __future__ import annotations -from functools import partial, singledispatch -from typing import TYPE_CHECKING, Literal, TypedDict, get_args +from functools import singledispatch +from typing import TYPE_CHECKING, Literal, TypedDict import numpy as np import pandas as pd @@ -71,17 +71,32 @@ def __init__( indicator_matrix: CSRBase | sparse.coo_array data: ArrayT - def count_nonzero(self) -> NDArray[np.integer]: + def count_nonzero(self, *, keep_sparse: bool = True) -> NDArray[np.integer]: """Count the number of observations in each group. + Parameters + ---------- + keep_sparse + If True and the input data is a sparse matrix, return a sparse matrix + of the same type for the aggregated counts. If False, always return + a dense :class:`numpy.ndarray`. + Returns ------- Array of counts. """ - return self._sum(data=(self.data != 0).astype("uint8")) + data = self.data + if isinstance(data, CSBase): + data = type(data)( + (np.ones(data.nnz, dtype="uint8"), data.indices, data.indptr), + shape=data.shape, + ) + else: + data = (data != 0).astype("uint8") + return self._sum(data=data, keep_sparse=keep_sparse) - def _sum(self, data: ArrayT): + def _sum(self, data: ArrayT, *, keep_sparse: bool): if isinstance(data, np.ndarray): res = self.indicator_matrix @ data if isinstance(res, CSBase): @@ -92,17 +107,26 @@ def _sum(self, data: ArrayT): (agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)( self.indicator_matrix, data, out ) + if keep_sparse and isinstance(data, CSBase): + return type(data)(out) # convert to sparse type of input return out - def sum(self) -> np.ndarray: + def sum(self, *, keep_sparse: bool = True) -> np.ndarray: """Compute the sum per feature per group of observations. 
+ Parameters + ---------- + keep_sparse + If True and the input data is a sparse matrix, return a sparse matrix + of the same type for the aggregated sums. If False, always return + a dense :class:`numpy.ndarray`. + Returns ------- Array of sum. """ - return self._sum(self.data) + return self._sum(self.data, keep_sparse=keep_sparse) def mean(self) -> Array: """Compute the mean per feature per group of observations. @@ -112,7 +136,7 @@ def mean(self) -> Array: Array of mean. """ - return self.sum() / np.bincount(self.groupby.codes)[:, None] + return self.sum(keep_sparse=False) / np.bincount(self.groupby.codes)[:, None] def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: """Compute the count, as well as mean and variance per feature, per group of observations. @@ -139,7 +163,10 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: if isinstance(self.data, np.ndarray): mean_ = self.mean() # sparse matrices do not support ** for elementwise power. - mean_sq = self._sum(_power(self.data, 2)) / group_counts[:, None] + mean_sq = ( + self._sum(_power(self.data, 2), keep_sparse=False) + / group_counts[:, None] + ) sq_mean = mean_**2 var_ = mean_sq - sq_mean else: @@ -205,6 +232,7 @@ def aggregate( # noqa: PLR0912 axis: Literal["obs", 0, "var", 1] | None = None, mask: NDArray[np.bool] | str | None = None, dof: int = 1, + keep_sparse: bool = True, layer: str | None = None, obsm: str | None = None, varm: str | None = None, @@ -235,6 +263,10 @@ def aggregate( # noqa: PLR0912 Boolean mask (or key to column containing mask) to apply along the axis. dof Degrees of freedom for variance. Defaults to 1. + keep_sparse + If True and the input data is a sparse matrix, preserve sparse outputs + for metrics that support it (for example, ``sum`` and ``count_nonzero``). + If False, force dense :class:`numpy.ndarray` outputs. Defaults to True. layer If not None, key for aggregation data. 
obsm @@ -324,6 +356,7 @@ def aggregate( # noqa: PLR0912 func=func, mask=mask, dof=dof, + keep_sparse=keep_sparse, ) # Define new var dataframe @@ -354,6 +387,7 @@ def _aggregate( *, mask: NDArray[np.bool] | None = None, dof: int = 1, + keep_sparse: bool = True, ) -> dict[AggType, np.ndarray | DaskArray]: msg = f"Data type {type(data)} not supported for aggregation" raise NotImplementedError(msg) @@ -370,9 +404,14 @@ def aggregate_dask_mean_var( *, mask: NDArray[np.bool] | None = None, dof: int = 1, + keep_sparse: bool = False, ) -> MeanVarDict: - mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"] - sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"] + mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof, keep_sparse=False)[ + "mean" + ] + sq_mean = aggregate_dask( + fau_power(data, 2), by, "mean", mask=mask, dof=dof, keep_sparse=False + )["mean"] # TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse. if isinstance(data._meta, CSRBase): sq_mean = sq_mean.compute() @@ -391,7 +430,10 @@ def aggregate_dask( *, mask: NDArray[np.bool] | None = None, dof: int = 1, + keep_sparse: bool = True, ) -> dict[AggType, DaskArray]: + import dask + if not isinstance(data._meta, CSBase | np.ndarray): msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported." 
raise ValueError(msg) @@ -399,25 +441,7 @@ def aggregate_dask( (0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0) ) if data.chunksize[unchunked_axis] != data.shape[unchunked_axis]: - msg = "Feature axis must be unchunked" - raise ValueError(msg) - - def aggregate_chunk_sum_or_count_nonzero( - chunk: Array, *, func: Literal["count_nonzero", "sum"], block_info=None - ): - # only subset the mask and by if we need to i.e., - # there is chunking along the same axis as by and mask - if chunked_axis == 0: - # See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html - # for what is contained in `block_info`. - subset = slice(*block_info[0]["array-location"][0]) - by_subsetted = by[subset] - mask_subsetted = mask[subset] if mask is not None else mask - else: - by_subsetted = by - mask_subsetted = mask - res = _aggregate(chunk, by_subsetted, func, mask=mask_subsetted, dof=dof)[func] - return res[None, :] if unchunked_axis == 1 else res + data = data.rechunk({unchunked_axis: -1}) funcs = set([func] if isinstance(func, str) else func) if "median" in funcs: @@ -425,35 +449,72 @@ def aggregate_chunk_sum_or_count_nonzero( raise NotImplementedError(msg) has_mean, has_var = (v in funcs for v in ["mean", "var"]) funcs_no_var_or_mean = funcs - {"var", "mean"} - # aggregate each row chunk or column chunk individually, - # producing a #chunks × #categories × #features or a #categories × #chunks array, - # then aggregate the per-chunk results. 
- chunks = ( - ((1,) * data.blocks.size, (len(by.categories),), data.shape[1]) - if unchunked_axis == 1 - else (len(by.categories), data.chunks[1]) - ) + + @dask.delayed + def aggregate_chunk(block, block_idx): + subset = slice(block_idx[0], block_idx[1]) + by_subsetted = ( + pd.Categorical.from_codes(by.codes[subset], categories=by.categories) + if chunked_axis == 0 + else by + ) + mask_subsetted = ( + mask[subset] if (mask is not None and chunked_axis == 0) else mask + ) + return { + f: _aggregate( + block, + by_subsetted, + f, + mask=mask_subsetted, + dof=dof, + keep_sparse=keep_sparse, + )[f] + for f in funcs_no_var_or_mean + } + + @dask.delayed + def combine_aggs(a, b): + if chunked_axis == 0: + return {f: a[f] + b[f] for f in funcs_no_var_or_mean} + else: + return { + f: sparse.hstack([a[f], b[f]]) + if isinstance(a[f], CSBase) + else np.concatenate([a[f], b[f]], axis=1) + for f in funcs_no_var_or_mean + } + + offset = 0 + delayed_chunks = [] + blocks = data.to_delayed().ravel() + for i, block in enumerate(blocks): + block_idx = (offset, offset + data.chunks[chunked_axis][i]) + delayed_chunks.append(aggregate_chunk(block, block_idx)) + offset += data.chunks[chunked_axis][i] + + while len(delayed_chunks) > 1: + delayed_chunks = [ + combine_aggs(delayed_chunks[i], delayed_chunks[i + 1]) + if i + 1 < len(delayed_chunks) + else delayed_chunks[i] + for i in range(0, len(delayed_chunks), 2) + ] + aggregated = { - f: data.map_blocks( - partial(aggregate_chunk_sum_or_count_nonzero, func=func), - new_axis=(1,) if unchunked_axis == 1 else None, - chunks=chunks, - meta=np.array( - [], - dtype=np.float64 - if func not in get_args(ConstantDtypeAgg) - else data.dtype, # TODO: figure out best dtype for aggs like sum where dtype can change from original - ), + f: dask.array.from_delayed( + dask.delayed(lambda r, f=f: r[f])(delayed_chunks[0]), + shape=(len(by.categories), data.shape[1]), + dtype=np.float64, ) for f in funcs_no_var_or_mean } - # If we have row chunking, we 
need to handle the extra axis by summing over all category × feature matrices. - # Otherwise, dask internally concatenates the #categories × #chunks arrays i.e., the column chunks are concatenated together to get a #categories × #features matrix. - if unchunked_axis == 1: - for k, v in aggregated.items(): - aggregated[k] = v.sum(axis=chunked_axis) + if has_var: - aggredated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof) + # mean/var must be dense regardless of `keep_sparse` + aggredated_mean_var = aggregate_dask_mean_var( + data, by, mask=mask, dof=dof, keep_sparse=False + ) aggregated["var"] = aggredated_mean_var["var"] if has_mean: aggregated["mean"] = aggredated_mean_var["mean"] @@ -461,16 +522,23 @@ def aggregate_chunk_sum_or_count_nonzero( # i.e., we can't just call map blocks over the mean function. elif has_mean: group_counts = np.bincount(by.codes) + # compute sum then divide; force sum to be dense here for mean aggregated["mean"] = ( - aggregate_dask(data, by, "sum", mask=mask, dof=dof)["sum"] + aggregate_dask(data, by, "sum", mask=mask, dof=dof, keep_sparse=False)[ + "sum" + ] / group_counts[:, None] ) return aggregated @_aggregate.register(pd.DataFrame) -def aggregate_df(data, by, func, *, mask=None, dof=1) -> dict[AggType, np.ndarray]: - return _aggregate(data.values, by, func, mask=mask, dof=dof) +def aggregate_df( + data, by, func, *, mask=None, dof=1, keep_sparse=False +) -> dict[AggType, np.ndarray]: + return _aggregate( + data.values, by, func, mask=mask, dof=dof, keep_sparse=keep_sparse + ) @_aggregate.register(np.ndarray) @@ -482,6 +550,7 @@ def aggregate_array( *, mask: NDArray[np.bool] | None = None, dof: int = 1, + keep_sparse: bool = True, ) -> dict[AggType, np.ndarray]: groupby = Aggregate(groupby=by, data=data, mask=mask) result = {} @@ -492,22 +561,19 @@ def aggregate_array( raise ValueError(msg) if "sum" in funcs: # sum is calculated separately from the rest - agg = groupby.sum() - result["sum"] = agg + result["sum"] = 
groupby.sum(keep_sparse=keep_sparse) # here and below for count, if var is present, these can be calculate alongside var if "mean" in funcs and "var" not in funcs: - agg = groupby.mean() - result["mean"] = agg + result["mean"] = groupby.mean() if "count_nonzero" in funcs: - result["count_nonzero"] = groupby.count_nonzero() + result["count_nonzero"] = groupby.count_nonzero(keep_sparse=keep_sparse) if "var" in funcs: mean_, var_ = groupby.mean_var(dof) result["var"] = var_ if "mean" in funcs: result["mean"] = mean_ if "median" in funcs: - agg = groupby.median() - result["median"] = agg + result["median"] = groupby.median() return result diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py index 979cbdc536..71d004eed2 100644 --- a/tests/test_aggregated.py +++ b/tests/test_aggregated.py @@ -9,13 +9,14 @@ from scipy import sparse import scanpy as sc -from scanpy._compat import DaskArray +from scanpy._compat import CSBase, DaskArray from scanpy._utils import _resolve_axis, get_literal_vals from scanpy.get._aggregated import AggType from testing.scanpy._helpers import assert_equal from testing.scanpy._helpers.data import pbmc3k_processed from testing.scanpy._pytest.marks import needs from testing.scanpy._pytest.params import ARRAY_TYPES as ARRAY_TYPES_ALL +from testing.scanpy._pytest.params import ARRAY_TYPES_DASK if TYPE_CHECKING: from collections.abc import Callable @@ -201,8 +202,10 @@ def test_aggregate_bad_dask_array( ) -> None: adata = pbmc3k_processed().raw.to_adata() adata.X = func(adata.X) - with pytest.raises(ValueError, match=error_msg): - sc.get.aggregate(adata, ["louvain"], "sum") + # The implementation now rechunks the array to make the feature axis unchunked + # instead of raising; ensure aggregation completes and returns a dask layer. 
+ result = sc.get.aggregate(adata, ["louvain"], "sum") + assert isinstance(result.layers["sum"], DaskArray) @pytest.mark.parametrize("axis_name", ["obs", "var"]) @@ -414,8 +417,9 @@ def test_combine_categories( @pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES) +@pytest.mark.parametrize("keep_sparse", [True, False]) def test_aggregate_arraytype( - array_type, metric: AggType, request: pytest.FixtureRequest + array_type, metric: AggType, *, keep_sparse: bool, request: pytest.FixtureRequest ) -> None: adata = pbmc3k_processed().raw.to_adata() adata = adata[ @@ -423,11 +427,29 @@ ].copy() adata.X = array_type(adata.X) xfail_dask_median(adata, metric, request) - aggregate = sc.get.aggregate(adata, ["louvain"], metric) - assert isinstance( - aggregate.layers[metric], - DaskArray if isinstance(adata.X, DaskArray) else np.ndarray, - ) + aggregate = sc.get.aggregate(adata, ["louvain"], metric, keep_sparse=keep_sparse) + + # Resolve dask if present for type assertions + layer = aggregate.layers[metric] + + if isinstance(adata.X, DaskArray): + assert isinstance(layer, DaskArray) + layer = layer.compute() + adata.X = adata.X.compute() + + # Determine the expected sparsity of the aggregated layer + if metric in {"mean", "var"}: + expected_sparse = False + elif metric in {"count_nonzero", "sum"}: + expected_sparse = isinstance(adata.X, CSBase) and keep_sparse + # sparse output is only expected for sparse input with keep_sparse=True + else: + expected_sparse = False + + if expected_sparse: + assert isinstance(layer, CSBase) + else: + assert isinstance(layer, np.ndarray) def test_aggregate_obsm_varm() -> None: @@ -544,3 +566,42 @@ def test_nan() -> None: "s2_control_C", ] assert adata_agg.obs["n_obs_aggregated"].tolist() == [1, 2, 1] + + +def _to_dense_array(x): + """Normalize various array-like objects to a dense numpy array for comparison. + + Handles `DaskArray` by computing, sparse matrices by converting to dense, + and ensures a numpy ndarray is returned. 
+ """ + if isinstance(x, DaskArray): + x = x.compute() + if isinstance(x, CSBase): + x = x.toarray() + return np.asarray(x) + + +@needs.dask +@pytest.mark.parametrize("array_type", ARRAY_TYPES_DASK) +def test_aggregate_dask_vs_regular( + array_type, metric: AggType, request: pytest.FixtureRequest +): + adata = pbmc3k_processed().raw.to_adata() + adata = adata[ + adata.obs["louvain"].isin(adata.obs["louvain"].cat.categories[:5]), :1_000 + ].copy() + + # expected result + expected = sc.get.aggregate(adata, ["louvain"], metric) + + # create dask array + adata.X = array_type(adata.X) + xfail_dask_median(adata, metric, request) + + # dask result + dask_res = sc.get.aggregate(adata, ["louvain"], metric) + + # check results + a = _to_dense_array(expected.layers[metric]) + b = _to_dense_array(dask_res.layers[metric]) + np.testing.assert_allclose(a, b, atol=1e-6)