Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/pyrecest/filters/euclidean_boxed_particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import copy
from collections.abc import Callable
from numbers import Integral
from typing import Union

import numpy as _np
Expand Down Expand Up @@ -67,7 +68,7 @@ def __init__(
gaussian_sigma_scale: float = 3.0,
max_sampling_iterations: Union[int, int32, int64] = 100,
):
super().__init__(int(n_particles), int(dim))
super().__init__(n_particles, dim)
self.resampling_criterion = resampling_criterion
self.default_boxed_generation_method = self._validate_generation_method(
boxed_generation_method
Expand Down Expand Up @@ -490,10 +491,13 @@ def _coerce_box_result(result):

@staticmethod
def _validate_positive_int(value, name: str):
value = int(value)
if value <= 0:
raise ValueError(f"{name} must be positive")
return value
if (
isinstance(value, bool)
or not isinstance(value, Integral)
or int(value) <= 0
):
raise ValueError(f"{name} must be a positive integer")
return int(value)

@staticmethod
def _validate_positive_float(value, name: str):
Expand Down
39 changes: 39 additions & 0 deletions tests/filters/test_euclidean_boxed_particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,45 @@ def test_filter_state_remains_point_particles(self):
self.assertIsInstance(pf.filter_state, LinearDiracDistribution)
self.assertIs(BoxedParticleFilter, EuclideanBoxedParticleFilter)

def test_rejects_bool_and_nonintegral_particle_counts(self):
invalid_arguments = (
(True, 1),
(1.5, 1),
(2, True),
(2, 1.5),
)

for n_particles, dim in invalid_arguments:
with self.subTest(n_particles=n_particles, dim=dim):
with self.assertRaisesRegex(ValueError, "positive integer"):
EuclideanBoxedParticleFilter(n_particles, dim)

def test_sampling_controls_reject_bool_and_nonintegral_values(self):
invalid_values = (
("batch_size", True),
("batch_size", 1.5),
("max_sampling_iterations", True),
("max_sampling_iterations", 1.5),
("max_tries_per_particle", True),
("max_tries_per_particle", 1.5),
)

for name, value in invalid_values:
with self.subTest(name=name, value=value):
with self.assertRaisesRegex(
ValueError,
f"{name} must be a positive integer",
):
EuclideanBoxedParticleFilter._validate_positive_int(value, name)

self.assertEqual(
EuclideanBoxedParticleFilter._validate_positive_int(
np.int64(3),
"batch_size",
),
3,
)

def test_uniform_generation_places_point_particles_in_box(self):
random.seed(1)
pf = EuclideanBoxedParticleFilter(50, 2)
Expand Down
Loading