diff --git a/simpeg/optimization.py b/simpeg/optimization.py index 343ff132ab..020e9a0155 100644 --- a/simpeg/optimization.py +++ b/simpeg/optimization.py @@ -1529,7 +1529,7 @@ def __init__( cg_rtol: float = None, cg_atol: float = None, step_active_set: bool = True, - active_set_grad_scale: float = 1e-2, + active_set_grad_scale: float | np.ndarray = 1e-2, **kwargs, ): if (val := kwargs.pop("tolCG", None)) is not None: @@ -1619,9 +1619,16 @@ def active_set_grad_scale(self) -> float: @active_set_grad_scale.setter def active_set_grad_scale(self, value: float): - self._active_set_grad_scale = validate_float( - "active_set_grad_scale", value, min_val=0, inclusive_min=True - ) + try: + value = validate_float("active_set_grad_scale", value) + except TypeError: + value = validate_ndarray_with_shape( + "active_set_grad_scale", value, shape=("*",) + ) + if np.any(value < 0.0): + raise ValueError("active_set_grad_scale must be >= 0.0") + + self._active_set_grad_scale = value @timeIt def findSearchDirection(self):