Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 47 additions & 14 deletions src/grid/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def solve_ode_ivp(
y0: list | np.ndarray,
transform: BaseTransform = None,
method: str = "DOP853",
fallback: bool = True,
no_derivatives: bool = False,
rtol: float = 1e-8,
atol: float = 1e-6,
Expand All @@ -81,6 +82,10 @@ def solve_ode_ivp(
method : str
The method used to solve the ode by scipy.
See `scipy.integrate.solve_ivp` function for more info.
fallback : bool, optional
If True (default), upon failure of the chosen method, a sequence of
alternative solvers (RK45, DOP853, BDF, Radau, LSODA) is tried automatically.
If False, an exception is raised immediately.
no_derivatives : bool, optional
If true, when transform is used then it only returns the solution :math:`y(x)` rather
than its derivative. If false, it includes the derivatives up to :math:`P-1`.
Expand Down Expand Up @@ -144,20 +149,48 @@ def func(x, y):
)
y0 = np.hstack(([y0[0]], y_derivs))

res = solve_ivp(
func,
x_span,
y0=y0,
dense_output=True,
vectorized=True,
rtol=rtol,
atol=atol,
method=method,
)

# raise error if didn't converge
if res.status != 0:
raise ValueError(f"The ode solver didn't converge, got status: {res.status}")
if not fallback:
res = solve_ivp(
func,
x_span,
y0=y0,
dense_output=True,
vectorized=True,
rtol=rtol,
atol=atol,
method=method,
)
if res.status != 0:
raise ValueError(f"The ode solver didn't converge, got status: {res.status}")
else:
methods_chain = ["RK45", "DOP853", "BDF", "Radau", "LSODA"]
if method in methods_chain:
methods_chain.remove(method)
methods_chain.insert(0, method)

res = None
last_status = None
for idx, m in enumerate(methods_chain):
res = solve_ivp(
func,
x_span,
y0=y0,
dense_output=True,
vectorized=True,
rtol=rtol,
atol=atol,
method=m,
)
if res.status == 0:
break
last_status = res.status
if idx < len(methods_chain) - 1:
warnings.warn(
f"ODE solver method {m} failed with status {res.status}. Trying next method.",
stacklevel=2,
)
else:
raise RuntimeError(f"All ODE solver methods failed. Last status: {last_status}")

if transform is not None:
# Transform the function so that it's input is the original variable and
Expand Down
46 changes: 45 additions & 1 deletion src/grid/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,26 @@ def _solve_poisson_ivp_atomgrid(
# Following takes the integral of f(x), to generate the bounds at very large r.
# The calculation of the bounds is explained below.
sph_o_l = generate_real_spherical_harmonics(0, np.array([0.1]), np.array([0.1]))

# Adaptive r_interval[0] based on density decay
r_pts = atomgrid.rgrid.points
radial_weights = atomgrid.rgrid.weights

# Use the 0th radial component (spherically averaged density) with integration weights
radial_density_abs = np.abs(radial_components[0](r_pts))
integrand = radial_density_abs * (r_pts**2) * radial_weights

Comment thread
alok-108 marked this conversation as resolved.
cumsum = np.cumsum(integrand)
if cumsum[-1] > 0:
idx = np.searchsorted(cumsum, 0.99 * cumsum[-1])
r_99 = r_pts[min(idx, len(r_pts) - 1)]
r_max = r_99 * 10
if r_max > r_interval[0]:
r_max = r_interval[0]
if transform is not None and r_max > transform.domain[1]:
r_max = transform.domain[1]
r_interval = (r_max, r_interval[1])

r_max = r_interval[0]
boundary = atomgrid.integrate(func_vals) / sph_o_l[0, 0]

Expand Down Expand Up @@ -140,6 +160,13 @@ def coeff_1(r):
else:
ivp = [0.0, 0.0]

# If the source term is effectively zero for l>0, skip integration to avoid
# numerical noise amplification in the IVP solver.
max_val = np.max(np.abs(radial_components[i_spline](atomgrid.rgrid.points)))
if l_deg > 0 and max_val < 1e-12:
splines.append(lambda r_pts: np.zeros_like(r_pts))
i_spline += 1
continue
# Solve ode
u_lm = solve_ode_ivp(
r_interval,
Expand All @@ -151,8 +178,25 @@ def coeff_1(r):
**ode_params,
)

def make_safe_spline(spline_orig, is_monopole, boundary_val, r_max_val, r_min_val):
def safe_spline(r_pts):
r_pts_clipped = np.clip(r_pts, r_min_val, r_max_val)
vals = spline_orig(r_pts_clipped)
# For r > r_max, return analytic continuation to avoid extrapolation blowup
mask = r_pts > r_max_val
if np.any(mask):
# Ensure we don't modify the array in-place if it's read-only
vals = np.array(vals)
if is_monopole:
vals[mask] = boundary_val / r_pts[mask]
else:
vals[mask] = 0.0
return vals
return safe_spline

i_spline += 1
splines.append(u_lm)
splines.append(make_safe_spline(u_lm, (l_deg == 0 and m_ord == 0), boundary, r_max, r_interval[1]))


def interpolate(points):
# Need atomgrid to center the points to the atomic grid, then convert to spherical.
Expand Down
103 changes: 103 additions & 0 deletions src/grid/tests/test_poisson_ivp_robustness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose
import warnings
from unittest.mock import patch

from grid.atomgrid import AtomGrid
from grid.onedgrid import Trapezoidal
from grid.rtransform import LinearFiniteRTransform, InverseRTransform
from grid.poisson import solve_poisson_bvp, solve_poisson_ivp, _solve_poisson_ivp_atomgrid

def gauss_density(pts, alpha=1.0):
r = np.linalg.norm(pts, axis=1)
return np.exp(-alpha * r**2)

def exp_density(pts, alpha=1.0):
r = np.linalg.norm(pts, axis=1)
return np.exp(-alpha * r)

def setup_grid():
oned = Trapezoidal(250)
btf = LinearFiniteRTransform(1e-3, 20.0)
radial = btf.transform_1d_grid(oned)
atgrid = AtomGrid(radial, center=np.array([0.0, 0.0, 0.0]), degrees=[11])
return atgrid, btf

def test_ivp_gaussian_vs_bvp():
atgrid, btf = setup_grid()
density = gauss_density(atgrid.points, alpha=1.0)

pot_ivp = solve_poisson_ivp(
atgrid,
density,
InverseRTransform(btf),
r_interval=(20.0, 1e-3)
)

pot_bvp = solve_poisson_bvp(
atgrid,
density,
InverseRTransform(btf),
include_origin=True
)

pts = atgrid.points
val_ivp = pot_ivp(pts)
val_bvp = pot_bvp(pts)

# Check that they match
assert_allclose(val_ivp, val_bvp, rtol=1e-2, atol=1e-2)

def test_ivp_fallback_warning():
atgrid, btf = setup_grid()
density = exp_density(atgrid.points, alpha=100.0)

import grid.ode
original_solve_ivp = grid.ode.solve_ivp

call_count = 0
def mock_solve_ivp(fun, t_span, y0, method, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
class DummyRes:
status = -1
message = "Failed"
return DummyRes()
return original_solve_ivp(fun, t_span, y0, method=method, **kwargs)

with patch("grid.ode.solve_ivp", side_effect=mock_solve_ivp):
with pytest.warns(UserWarning, match="ODE solver method.*failed.*Trying next method"):
pot_ivp = solve_poisson_ivp(
atgrid,
density,
InverseRTransform(btf),
r_interval=(20.0, 1e-3)
)

val = pot_ivp(atgrid.points)
assert not np.any(np.isnan(val))

def test_adaptive_r_interval():
atgrid, btf = setup_grid()

density_compact = exp_density(atgrid.points, alpha=5.0)
density_diffuse = exp_density(atgrid.points, alpha=0.1)

used_r_intervals = []

import grid.ode
original_solve_ode_ivp = grid.ode.solve_ode_ivp
def mock_solve_ode_ivp(r_interval, *args, **kwargs):
used_r_intervals.append(r_interval)
return original_solve_ode_ivp(r_interval, *args, **kwargs)

with patch("grid.poisson.solve_ode_ivp", side_effect=mock_solve_ode_ivp):
solve_poisson_ivp(atgrid, density_compact, InverseRTransform(btf), r_interval=(20.0, 1e-3))
solve_poisson_ivp(atgrid, density_diffuse, InverseRTransform(btf), r_interval=(20.0, 1e-3))

assert len(used_r_intervals) >= 2
r_max_compact = used_r_intervals[0][0]
r_max_diffuse = used_r_intervals[-1][0]
assert r_max_compact < r_max_diffuse