Skip to content
53 changes: 40 additions & 13 deletions src/squidpy/gr/_ligrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def test(
# much faster than applymap (tested on 1M interactions)
interactions_ = np.vectorize(lambda g: gene_mapper[g])(interactions.values)

rng = np.random.default_rng(seed)
n_jobs = _get_n_cores(kwargs.pop("n_jobs", None))
start = logg.info(
f"Running `{n_perms}` permutations on `{len(interactions)}` interactions "
Expand All @@ -420,7 +421,7 @@ def test(
clusters_,
threshold=threshold,
n_perms=n_perms,
seed=seed,
rng=rng,
n_jobs=n_jobs,
numba_parallel=numba_parallel,
**kwargs,
Expand Down Expand Up @@ -642,7 +643,17 @@ def ligrec(
copy: bool = False,
key_added: str | None = None,
gene_symbols: str | None = None,
**kwargs: Any,
n_perms: int = 1000,
seed: int | None = None,
clusters: Cluster_t | None = None,
alpha: float = 0.05,
numba_parallel: bool | None = None,
n_jobs: int | None = None,
backend: str = "loky",
show_progress_bar: bool = True,
interactions_params: Mapping[str, Any] = MappingProxyType({}),
transmitter_params: Mapping[str, Any] = MappingProxyType({"categories": "ligand"}),
receiver_params: Mapping[str, Any] = MappingProxyType({"categories": "receptor"}),
) -> Mapping[str, pd.DataFrame] | None:
"""
%(PT_test.full_desc)s
Expand All @@ -664,15 +675,28 @@ def ligrec(
with _genesymbols(adata, key=gene_symbols, use_raw=use_raw, make_unique=False):
return ( # type: ignore[no-any-return]
PermutationTest(adata, use_raw=use_raw)
.prepare(interactions, complex_policy=complex_policy, **kwargs)
.prepare(
interactions,
complex_policy=complex_policy,
interactions_params=interactions_params,
transmitter_params=transmitter_params,
receiver_params=receiver_params,
)
.test(
cluster_key=cluster_key,
clusters=clusters,
n_perms=n_perms,
threshold=threshold,
seed=seed,
corr_method=corr_method,
corr_axis=corr_axis,
alpha=alpha,
copy=copy,
key_added=key_added,
**kwargs,
numba_parallel=numba_parallel,
n_jobs=n_jobs,
backend=backend,
show_progress_bar=show_progress_bar,
)
)

Expand All @@ -682,9 +706,9 @@ def _analysis(
data: pd.DataFrame,
interactions: NDArrayA,
interaction_clusters: NDArrayA,
rng: np.random.Generator,
threshold: float = 0.1,
n_perms: int = 1000,
seed: int | None = None,
n_jobs: int = 1,
numba_parallel: bool | None = None,
**kwargs: Any,
Expand All @@ -705,7 +729,8 @@ def _analysis(
threshold
Percentage threshold for removing lowly expressed genes in clusters.
%(n_perms)s
%(seed)s
rng
NumPy :class:`numpy.random.Generator` for reproducibility.
n_jobs
Number of parallel jobs to launch.
numba_parallel
Expand Down Expand Up @@ -750,6 +775,7 @@ def extractor(res: Sequence[TempResult]) -> TempResult:

# (n_cells, n_genes)
data = np.array(data[data.columns.difference(["clusters"])].values, dtype=np.float64, order="C")
root_seed = rng.integers(np.iinfo(np.int64).max)
# all 3 should be C contiguous
return parallelize( # type: ignore[no-any-return]
_analysis_helper,
Expand All @@ -765,7 +791,7 @@ def extractor(res: Sequence[TempResult]) -> TempResult:
interactions,
interaction_clusters=interaction_clusters,
clustering=clustering,
seed=seed,
root_seed=root_seed,
numba_parallel=numba_parallel,
)

Expand All @@ -778,7 +804,7 @@ def _analysis_helper(
interactions: NDArrayA,
interaction_clusters: NDArrayA,
clustering: NDArrayA,
seed: int | None = None,
root_seed: int,
numba_parallel: bool | None = None,
queue: SigQueue | None = None,
) -> TempResult:
Expand All @@ -788,7 +814,7 @@ def _analysis_helper(
Parameters
----------
perms
Permutation indices. Only used to set the ``seed``.
Permutation indices. Only used to differentiate workers/permutations.
data
Array of shape `(n_cells, n_genes)`.
mean
Expand All @@ -802,8 +828,9 @@ def _analysis_helper(
Array of shape `(n_interaction_clusters, 2)`.
clustering
Array of shape `(n_cells,)` containing the original clustering.
seed
Random seed for :class:`numpy.random.RandomState`.
root_seed
Integer seed derived from the root generator. Each worker creates
an independent stream via ``default_rng([perms[0], root_seed])``.
numba_parallel
Whether to use :func:`numba.prange` or not. If `None`, it's determined automatically.
queue
Expand All @@ -818,7 +845,7 @@ def _analysis_helper(
- `'pvalues'` - array of shape `(n_interactions, n_interaction_clusters)` containing `np.sum(T0 > T)`
where `T0` is the test statistic under null hypothesis and `T` is the true test statistic.
"""
rs = np.random.RandomState(None if seed is None else perms[0] + seed)
rng = np.random.default_rng([perms[0], root_seed])

clustering = clustering.copy()
n_cls = mean.shape[1]
Expand Down Expand Up @@ -847,7 +874,7 @@ def _analysis_helper(
test = _test

for _ in perms:
rs.shuffle(clustering)
rng.shuffle(clustering)
error = test(interactions, interaction_clusters, data, clustering, mean, mask, res=res)
if error:
raise ValueError("In the execution of the numba function, an unhandled case was encountered. ")
Expand Down
Binary file modified tests/_data/ligrec_no_numba.pickle
Binary file not shown.
Binary file added tests/_data/ligrec_pvalues_reference.pickle
Binary file not shown.
Binary file modified tests/_images/Ligrec_pvalue_threshold.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Ligrec_remove_nonsig_interactions.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ def ligrec_no_numba() -> Mapping[str, pd.DataFrame]:
return {"means": data[0], "pvalues": data[1], "metadata": data[2]}


@pytest.fixture(scope="session")
def ligrec_pvalues_reference() -> Mapping[str, pd.DataFrame]:
    """Load the pickled reference `ligrec` result (means/pvalues/metadata frames) used for regression testing."""
    with open("tests/_data/ligrec_pvalues_reference.pickle", "rb") as fh:
        return pickle.load(fh)
Comment on lines +268 to +269
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we store this as a zarr file or hdf5 or even just an anndata?



@pytest.fixture(scope="session")
def ligrec_result() -> Mapping[str, pd.DataFrame]:
adata = _adata.copy()
Expand Down
17 changes: 17 additions & 0 deletions tests/graph/test_ligrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,23 @@ def test_reproducibility_numba_off(
np.testing.assert_allclose(r["pvalues"], ligrec_no_numba["pvalues"])
np.testing.assert_array_equal(np.where(np.isnan(r["pvalues"])), np.where(np.isnan(ligrec_no_numba["pvalues"])))

def test_pvalues_reference(
    self, adata: AnnData, interactions: Interactions_t, ligrec_pvalues_reference: Mapping[str, pd.DataFrame]
):
    """Check that `ligrec` with a fixed seed exactly reproduces the stored reference means and p-values."""
    res = ligrec(
        adata, _CK, interactions=interactions, n_perms=25, copy=True, show_progress_bar=False, seed=42, n_jobs=1
    )
    ref_means = ligrec_pvalues_reference["means"]
    ref_pvals = ligrec_pvalues_reference["pvalues"]

    # same assertion sequence as before: labels first, then values, then NaN placement
    np.testing.assert_array_equal(res["means"].index, ref_means.index)
    np.testing.assert_array_equal(res["means"].columns, ref_means.columns)
    np.testing.assert_array_equal(res["pvalues"].index, ref_pvals.index)
    np.testing.assert_array_equal(res["pvalues"].columns, ref_pvals.columns)

    np.testing.assert_allclose(res["means"], ref_means)
    np.testing.assert_allclose(res["pvalues"], ref_pvals)
    np.testing.assert_array_equal(np.where(np.isnan(res["pvalues"])), np.where(np.isnan(ref_pvals)))

def test_logging(self, adata: AnnData, interactions: Interactions_t, capsys):
s.logfile = sys.stderr
s.verbosity = 4
Expand Down