From 86e58933b4e53191fe3f3162d5105a62897a857c Mon Sep 17 00:00:00 2001 From: alok-108 Date: Thu, 28 May 2026 02:25:21 +0530 Subject: [PATCH 1/2] fix: improve Poisson solver IVP robustness (#162) - Adaptive r_interval based on density cumulative integral - Fallback ODE solver methods chain - Added test suite for IVP robustness Closes #162 --- src/grid/ode.py | 39 ++++-- src/grid/poisson.py | 60 ++++++++- src/grid/tests/test_poisson_ivp_robustness.py | 124 ++++++++++++++++++ 3 files changed, 207 insertions(+), 16 deletions(-) create mode 100644 src/grid/tests/test_poisson_ivp_robustness.py diff --git a/src/grid/ode.py b/src/grid/ode.py index a5219bd8f..3d6276ff4 100644 --- a/src/grid/ode.py +++ b/src/grid/ode.py @@ -144,20 +144,31 @@ def func(x, y): ) y0 = np.hstack(([y0[0]], y_derivs)) - res = solve_ivp( - func, - x_span, - y0=y0, - dense_output=True, - vectorized=True, - rtol=rtol, - atol=atol, - method=method, - ) - - # raise error if didn't converge - if res.status != 0: - raise ValueError(f"The ode solver didn't converge, got status: {res.status}") + methods_chain = ["RK45", "DOP853", "BDF", "Radau", "LSODA"] + if method in methods_chain: + methods_chain.remove(method) + methods_chain.insert(0, method) + + res = None + for m in methods_chain: + res = solve_ivp( + func, + x_span, + y0=y0, + dense_output=True, + vectorized=True, + rtol=rtol, + atol=atol, + method=m, + ) + if res.status == 0: + break + warnings.warn( + f"ODE solver method {m} failed with status {res.status}. Trying next method.", + stacklevel=2, + ) + else: + raise RuntimeError(f"All ODE solver methods failed. Last status: {res.status}") if transform is not None: # Transform the function so that it's input is the original variable and diff --git a/src/grid/poisson.py b/src/grid/poisson.py index 8b7871ee1..06c099d64 100644 --- a/src/grid/poisson.py +++ b/src/grid/poisson.py @@ -102,13 +102,37 @@ def _solve_poisson_ivp_atomgrid( # Following takes the integral of f(x), to generate the bounds at very large r. # The calculation of the bounds is explained below. sph_o_l = generate_real_spherical_harmonics(0, np.array([0.1]), np.array([0.1])) + + # Adaptive r_interval[0] based on density decay + r_pts = atomgrid.rgrid.points + density_abs = np.abs(func_vals) + + # func_vals is 1D array over the 3D grid (N_radial * N_angular). + # Reshape and sum over angular points to get radial density distribution. + n_radial = len(r_pts) + if len(density_abs) % n_radial == 0: + n_angular = len(density_abs) // n_radial + # In grid, points are typically ordered (r_0, ang_0), (r_0, ang_1) etc. + # So reshape to (n_radial, n_angular) and sum + density_abs = density_abs.reshape(n_radial, n_angular).sum(axis=1) + + cumsum = np.cumsum(density_abs) + if cumsum[-1] > 0: + idx = np.searchsorted(cumsum, 0.99 * cumsum[-1]) + r_99 = r_pts[min(idx, len(r_pts) - 1)] + r_max = r_99 * 10 + if r_max > r_interval[0]: + r_max = r_interval[0] + if transform is not None and r_max > transform.domain[1]: + r_max = transform.domain[1] + r_interval = (r_max, r_interval[1]) + r_max = r_interval[0] boundary = atomgrid.integrate(func_vals) / sph_o_l[0, 0] # Set up default ode parameters if it isn't set up already. if ode_params is None: ode_params = dict({}) - ode_params.setdefault("method", "DOP853") ode_params.setdefault("rtol", 1e-8) ode_params.setdefault("atol", 1e-6) @@ -140,6 +164,21 @@ def coeff_1(r): else: ivp = [0.0, 0.0] + # If the source term is effectively zero for l>0, skip integration to avoid + # numerical noise amplification in the IVP solver. + max_val = np.max(np.abs(radial_components[i_spline](atomgrid.rgrid.points))) + if l_deg > 0 and max_val < 1e-12: + def make_zero_spline(): + def zero_spline(r_pts, deriv=0): + return np.zeros_like(r_pts) + return zero_spline + splines.append(make_zero_spline()) + i_spline += 1 + continue + + if not isinstance(ode_params, dict): + ode_params = {} + ode_params.setdefault("method", "DOP853") # Solve ode u_lm = solve_ode_ivp( r_interval, @@ -151,8 +190,25 @@ def coeff_1(r): **ode_params, ) + def make_safe_spline(spline_orig, is_monopole, boundary_val, r_max_val, r_min_val): + def safe_spline(r_pts): + r_pts_clipped = np.clip(r_pts, r_min_val, r_max_val) + vals = spline_orig(r_pts_clipped) + # For r > r_max, return analytic continuation to avoid extrapolation blowup + mask = r_pts > r_max_val + if np.any(mask): + # Ensure we don't modify the array in-place if it's read-only + vals = np.array(vals) + if is_monopole: + vals[mask] = boundary_val / r_pts[mask] + else: + vals[mask] = 0.0 + return vals + return safe_spline + i_spline += 1 - splines.append(u_lm) + splines.append(make_safe_spline(u_lm, (l_deg == 0 and m_ord == 0), boundary, r_max, r_interval[1])) + def interpolate(points): # Need atomgrid to center the points to the atomic grid, then convert to spherical. diff --git a/src/grid/tests/test_poisson_ivp_robustness.py b/src/grid/tests/test_poisson_ivp_robustness.py new file mode 100644 index 000000000..238838958 --- /dev/null +++ b/src/grid/tests/test_poisson_ivp_robustness.py @@ -0,0 +1,124 @@ +import numpy as np +import pytest +from numpy.testing import assert_allclose +import warnings +from unittest.mock import patch + +from grid.atomgrid import AtomGrid +from grid.onedgrid import Trapezoidal +from grid.rtransform import LinearFiniteRTransform, InverseRTransform +from grid.poisson import solve_poisson_bvp, solve_poisson_ivp, _solve_poisson_ivp_atomgrid + +def gauss_density(pts, alpha=1.0): + r = np.linalg.norm(pts, axis=1) + return np.exp(-alpha * r**2) + +def exp_density(pts, alpha=1.0): + r = np.linalg.norm(pts, axis=1) + return np.exp(-alpha * r) + +def setup_grid(): + oned = Trapezoidal(250) + btf = LinearFiniteRTransform(1e-3, 20.0) + radial = btf.transform_1d_grid(oned) + atgrid = AtomGrid(radial, center=np.array([0.0, 0.0, 0.0]), degrees=[11]) + return atgrid, btf + +def test_ivp_gaussian_vs_bvp(): + atgrid, btf = setup_grid() + density = gauss_density(atgrid.points, alpha=1.0) + + pot_ivp = solve_poisson_ivp( + atgrid, + density, + InverseRTransform(btf), + r_interval=(20.0, 1e-3) + ) + + pot_bvp = solve_poisson_bvp( + atgrid, + density, + InverseRTransform(btf), + include_origin=True + ) + + pts = atgrid.points + val_ivp = pot_ivp(pts) + val_bvp = pot_bvp(pts) + + # Check that they match + assert_allclose(val_ivp, val_bvp, rtol=1e-2, atol=1e-2) + +def test_ivp_fallback_warning(): + atgrid, btf = setup_grid() + density = exp_density(atgrid.points, alpha=100.0) + + import grid.ode + original_solve_ivp = grid.ode.solve_ivp + + call_count = 0 + def mock_solve_ivp(fun, t_span, y0, method, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + class DummyRes: + status = -1 + message = "Failed" + return DummyRes() + return original_solve_ivp(fun, t_span, y0, method=method, **kwargs) + + with patch("grid.ode.solve_ivp", side_effect=mock_solve_ivp): + with pytest.warns(UserWarning, match="ODE solver method.*failed.*Trying next method"): + pot_ivp = solve_poisson_ivp( + atgrid, + density, + InverseRTransform(btf), + r_interval=(20.0, 1e-3) + ) + + val = pot_ivp(atgrid.points) + assert not np.any(np.isnan(val)) + +def test_adaptive_r_interval(): + atgrid, btf = setup_grid() + + density_compact = exp_density(atgrid.points, alpha=5.0) + density_diffuse = exp_density(atgrid.points, alpha=0.1) + + used_r_intervals = [] + + original_solve = _solve_poisson_ivp_atomgrid + + def mock_solve(atomgrid, func_vals, transform, r_interval, ode_params=None): + r_pts = atomgrid.rgrid.points + density_abs = np.abs(func_vals) + n_radial = len(r_pts) + if len(density_abs) % n_radial == 0: + n_angular = len(density_abs) // n_radial + density_abs = density_abs.reshape(n_radial, n_angular).sum(axis=1) + cumsum = np.cumsum(density_abs) + if cumsum[-1] > 0: + idx = np.searchsorted(cumsum, 0.99 * cumsum[-1]) + r_99 = r_pts[min(idx, len(r_pts) - 1)] + r_max = r_99 * 10 + if r_max > r_interval[0]: + r_max = r_interval[0] + if transform is not None and r_max > transform.domain[1]: + r_max = transform.domain[1] + used_r_intervals.append(r_max) + else: + used_r_intervals.append(r_interval[0]) + return original_solve(atomgrid, func_vals, transform, r_interval, ode_params) + + import grid.poisson + grid.poisson._solve_poisson_ivp_atomgrid = mock_solve + + try: + solve_poisson_ivp(atgrid, density_compact, InverseRTransform(btf), r_interval=(20.0, 1e-3)) + solve_poisson_ivp(atgrid, density_diffuse, InverseRTransform(btf), r_interval=(20.0, 1e-3)) + finally: + grid.poisson._solve_poisson_ivp_atomgrid = original_solve + + assert len(used_r_intervals) == 2 + r_max_compact, r_max_diffuse = used_r_intervals + assert r_max_compact < r_max_diffuse From 8796e47e6471d8b667051b5022c53be5632e7981 Mon Sep 17 00:00:00 2001 From: alok-108 Date: Thu, 28 May 2026 11:13:14 +0530 Subject: [PATCH 2/2] Address review comments: fix r_max calculation, clean up dead code, improve tests --- src/grid/ode.py | 52 +++++++++++++------ src/grid/poisson.py | 26 +++------- src/grid/tests/test_poisson_ivp_robustness.py | 43 ++++----------- 3 files changed, 55 insertions(+), 66 deletions(-) diff --git a/src/grid/ode.py b/src/grid/ode.py index 3d6276ff4..46761ce78 100644 --- a/src/grid/ode.py +++ b/src/grid/ode.py @@ -56,6 +56,7 @@ def solve_ode_ivp( y0: list | np.ndarray, transform: BaseTransform = None, method: str = "DOP853", + fallback: bool = True, no_derivatives: bool = False, rtol: float = 1e-8, atol: float = 1e-6, @@ -81,6 +82,10 @@ def solve_ode_ivp( method : str The method used to solve the ode by scipy. See `scipy.integrate.solve_ivp` function for more info. + fallback : bool, optional + If True (default), upon failure of the chosen method, a sequence of + alternative solvers (RK45, DOP853, BDF, Radau, LSODA) is tried automatically. + If False, an exception is raised immediately. no_derivatives : bool, optional If true, when transform is used then it only returns the solution :math:`y(x)` rather than its derivative. If false, it includes the derivatives up to :math:`P-1`. @@ -144,13 +149,7 @@ def func(x, y): ) y0 = np.hstack(([y0[0]], y_derivs)) - methods_chain = ["RK45", "DOP853", "BDF", "Radau", "LSODA"] - if method in methods_chain: - methods_chain.remove(method) - methods_chain.insert(0, method) - - res = None - for m in methods_chain: + if not fallback: res = solve_ivp( func, x_span, @@ -159,16 +158,39 @@ def func(x, y): vectorized=True, rtol=rtol, atol=atol, - method=m, - ) - if res.status == 0: - break - warnings.warn( - f"ODE solver method {m} failed with status {res.status}. Trying next method.", - stacklevel=2, + method=method, ) + if res.status != 0: + raise ValueError(f"The ode solver didn't converge, got status: {res.status}") else: - raise RuntimeError(f"All ODE solver methods failed. Last status: {res.status}") + methods_chain = ["RK45", "DOP853", "BDF", "Radau", "LSODA"] + if method in methods_chain: + methods_chain.remove(method) + methods_chain.insert(0, method) + + res = None + last_status = None + for idx, m in enumerate(methods_chain): + res = solve_ivp( + func, + x_span, + y0=y0, + dense_output=True, + vectorized=True, + rtol=rtol, + atol=atol, + method=m, + ) + if res.status == 0: + break + last_status = res.status + if idx < len(methods_chain) - 1: + warnings.warn( + f"ODE solver method {m} failed with status {res.status}. Trying next method.", + stacklevel=2, + ) + else: + raise RuntimeError(f"All ODE solver methods failed. Last status: {last_status}") if transform is not None: # Transform the function so that it's input is the original variable and diff --git a/src/grid/poisson.py b/src/grid/poisson.py index 06c099d64..84b85d9e9 100644 --- a/src/grid/poisson.py +++ b/src/grid/poisson.py @@ -105,18 +105,13 @@ def _solve_poisson_ivp_atomgrid( # Adaptive r_interval[0] based on density decay r_pts = atomgrid.rgrid.points - density_abs = np.abs(func_vals) + radial_weights = atomgrid.rgrid.weights - # func_vals is 1D array over the 3D grid (N_radial * N_angular). - # Reshape and sum over angular points to get radial density distribution. - n_radial = len(r_pts) - if len(density_abs) % n_radial == 0: - n_angular = len(density_abs) // n_radial - # In grid, points are typically ordered (r_0, ang_0), (r_0, ang_1) etc. - # So reshape to (n_radial, n_angular) and sum - density_abs = density_abs.reshape(n_radial, n_angular).sum(axis=1) + # Use the 0th radial component (spherically averaged density) with integration weights + radial_density_abs = np.abs(radial_components[0](r_pts)) + integrand = radial_density_abs * (r_pts**2) * radial_weights - cumsum = np.cumsum(density_abs) + cumsum = np.cumsum(integrand) if cumsum[-1] > 0: idx = np.searchsorted(cumsum, 0.99 * cumsum[-1]) r_99 = r_pts[min(idx, len(r_pts) - 1)] @@ -133,6 +128,7 @@ def _solve_poisson_ivp_atomgrid( # Set up default ode parameters if it isn't set up already. if ode_params is None: ode_params = dict({}) + ode_params.setdefault("method", "DOP853") ode_params.setdefault("rtol", 1e-8) ode_params.setdefault("atol", 1e-6) @@ -168,17 +164,9 @@ def coeff_1(r): # numerical noise amplification in the IVP solver. max_val = np.max(np.abs(radial_components[i_spline](atomgrid.rgrid.points))) if l_deg > 0 and max_val < 1e-12: - def make_zero_spline(): - def zero_spline(r_pts, deriv=0): - return np.zeros_like(r_pts) - return zero_spline - splines.append(make_zero_spline()) + splines.append(lambda r_pts: np.zeros_like(r_pts)) i_spline += 1 continue - - if not isinstance(ode_params, dict): - ode_params = {} - ode_params.setdefault("method", "DOP853") # Solve ode u_lm = solve_ode_ivp( r_interval, diff --git a/src/grid/tests/test_poisson_ivp_robustness.py b/src/grid/tests/test_poisson_ivp_robustness.py index 238838958..30f75c98d 100644 --- a/src/grid/tests/test_poisson_ivp_robustness.py +++ b/src/grid/tests/test_poisson_ivp_robustness.py @@ -87,38 +87,17 @@ def test_adaptive_r_interval(): used_r_intervals = [] - original_solve = _solve_poisson_ivp_atomgrid - - def mock_solve(atomgrid, func_vals, transform, r_interval, ode_params=None): - r_pts = atomgrid.rgrid.points - density_abs = np.abs(func_vals) - n_radial = len(r_pts) - if len(density_abs) % n_radial == 0: - n_angular = len(density_abs) // n_radial - density_abs = density_abs.reshape(n_radial, n_angular).sum(axis=1) - cumsum = np.cumsum(density_abs) - if cumsum[-1] > 0: - idx = np.searchsorted(cumsum, 0.99 * cumsum[-1]) - r_99 = r_pts[min(idx, len(r_pts) - 1)] - r_max = r_99 * 10 - if r_max > r_interval[0]: - r_max = r_interval[0] - if transform is not None and r_max > transform.domain[1]: - r_max = transform.domain[1] - used_r_intervals.append(r_max) - else: - used_r_intervals.append(r_interval[0]) - return original_solve(atomgrid, func_vals, transform, r_interval, ode_params) - - import grid.poisson - grid.poisson._solve_poisson_ivp_atomgrid = mock_solve - - try: + import grid.ode + original_solve_ode_ivp = grid.ode.solve_ode_ivp + def mock_solve_ode_ivp(r_interval, *args, **kwargs): + used_r_intervals.append(r_interval) + return original_solve_ode_ivp(r_interval, *args, **kwargs) + + with patch("grid.poisson.solve_ode_ivp", side_effect=mock_solve_ode_ivp): solve_poisson_ivp(atgrid, density_compact, InverseRTransform(btf), r_interval=(20.0, 1e-3)) solve_poisson_ivp(atgrid, density_diffuse, InverseRTransform(btf), r_interval=(20.0, 1e-3)) - finally: - grid.poisson._solve_poisson_ivp_atomgrid = original_solve - - assert len(used_r_intervals) == 2 - r_max_compact, r_max_diffuse = used_r_intervals + + assert len(used_r_intervals) >= 2 + r_max_compact = used_r_intervals[0][0] + r_max_diffuse = used_r_intervals[-1][0] assert r_max_compact < r_max_diffuse