Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions src/grid/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,20 +144,31 @@ def func(x, y):
)
y0 = np.hstack(([y0[0]], y_derivs))

res = solve_ivp(
func,
x_span,
y0=y0,
dense_output=True,
vectorized=True,
rtol=rtol,
atol=atol,
method=method,
)

# raise error if didn't converge
if res.status != 0:
raise ValueError(f"The ode solver didn't converge, got status: {res.status}")
methods_chain = ["RK45", "DOP853", "BDF", "Radau", "LSODA"]
if method in methods_chain:
methods_chain.remove(method)
methods_chain.insert(0, method)

res = None
for m in methods_chain:
res = solve_ivp(
func,
x_span,
y0=y0,
dense_output=True,
vectorized=True,
rtol=rtol,
atol=atol,
method=m,
)
if res.status == 0:
break
warnings.warn(
f"ODE solver method {m} failed with status {res.status}. Trying next method.",
stacklevel=2,
)
else:
raise RuntimeError(f"All ODE solver methods failed. Last status: {res.status}")
Comment thread
alok-108 marked this conversation as resolved.
Outdated
Comment thread
alok-108 marked this conversation as resolved.
Outdated

if transform is not None:
# Transform the function so that it's input is the original variable and
Expand Down
60 changes: 58 additions & 2 deletions src/grid/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,37 @@ def _solve_poisson_ivp_atomgrid(
# Following takes the integral of f(x), to generate the bounds at very large r.
# The calculation of the bounds is explained below.
sph_o_l = generate_real_spherical_harmonics(0, np.array([0.1]), np.array([0.1]))

# Adaptive r_interval[0] based on density decay
r_pts = atomgrid.rgrid.points
density_abs = np.abs(func_vals)

# func_vals is 1D array over the 3D grid (N_radial * N_angular).
# Reshape and sum over angular points to get radial density distribution.
n_radial = len(r_pts)
if len(density_abs) % n_radial == 0:
n_angular = len(density_abs) // n_radial
# In grid, points are typically ordered (r_0, ang_0), (r_0, ang_1) etc.
# So reshape to (n_radial, n_angular) and sum
density_abs = density_abs.reshape(n_radial, n_angular).sum(axis=1)

Comment thread
alok-108 marked this conversation as resolved.
cumsum = np.cumsum(density_abs)
if cumsum[-1] > 0:
idx = np.searchsorted(cumsum, 0.99 * cumsum[-1])
r_99 = r_pts[min(idx, len(r_pts) - 1)]
r_max = r_99 * 10
if r_max > r_interval[0]:
r_max = r_interval[0]
if transform is not None and r_max > transform.domain[1]:
r_max = transform.domain[1]
r_interval = (r_max, r_interval[1])
Comment thread
alok-108 marked this conversation as resolved.
Outdated

r_max = r_interval[0]
boundary = atomgrid.integrate(func_vals) / sph_o_l[0, 0]

# Set up default ode parameters if it isn't set up already.
if ode_params is None:
ode_params = dict({})
ode_params.setdefault("method", "DOP853")
ode_params.setdefault("rtol", 1e-8)
ode_params.setdefault("atol", 1e-6)

Expand Down Expand Up @@ -140,6 +164,21 @@ def coeff_1(r):
else:
ivp = [0.0, 0.0]

# If the source term is effectively zero for l>0, skip integration to avoid
# numerical noise amplification in the IVP solver.
max_val = np.max(np.abs(radial_components[i_spline](atomgrid.rgrid.points)))
if l_deg > 0 and max_val < 1e-12:
def make_zero_spline():
def zero_spline(r_pts, deriv=0):
return np.zeros_like(r_pts)
return zero_spline
splines.append(make_zero_spline())
Comment thread
alok-108 marked this conversation as resolved.
Outdated
i_spline += 1
continue

if not isinstance(ode_params, dict):
ode_params = {}
ode_params.setdefault("method", "DOP853")
Comment thread
alok-108 marked this conversation as resolved.
Outdated
# Solve ode
u_lm = solve_ode_ivp(
r_interval,
Expand All @@ -151,8 +190,25 @@ def coeff_1(r):
**ode_params,
)

def make_safe_spline(spline_orig, is_monopole, boundary_val, r_max_val, r_min_val):
def safe_spline(r_pts):
r_pts_clipped = np.clip(r_pts, r_min_val, r_max_val)
vals = spline_orig(r_pts_clipped)
# For r > r_max, return analytic continuation to avoid extrapolation blowup
mask = r_pts > r_max_val
if np.any(mask):
# Ensure we don't modify the array in-place if it's read-only
vals = np.array(vals)
if is_monopole:
vals[mask] = boundary_val / r_pts[mask]
else:
vals[mask] = 0.0
return vals
return safe_spline

i_spline += 1
splines.append(u_lm)
splines.append(make_safe_spline(u_lm, (l_deg == 0 and m_ord == 0), boundary, r_max, r_interval[1]))


def interpolate(points):
# Need atomgrid to center the points to the atomic grid, then convert to spherical.
Expand Down
124 changes: 124 additions & 0 deletions src/grid/tests/test_poisson_ivp_robustness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose
import warnings
from unittest.mock import patch

from grid.atomgrid import AtomGrid
from grid.onedgrid import Trapezoidal
from grid.rtransform import LinearFiniteRTransform, InverseRTransform
from grid.poisson import solve_poisson_bvp, solve_poisson_ivp, _solve_poisson_ivp_atomgrid

def gauss_density(pts, alpha=1.0):
r = np.linalg.norm(pts, axis=1)
return np.exp(-alpha * r**2)

def exp_density(pts, alpha=1.0):
r = np.linalg.norm(pts, axis=1)
return np.exp(-alpha * r)

def setup_grid():
oned = Trapezoidal(250)
btf = LinearFiniteRTransform(1e-3, 20.0)
radial = btf.transform_1d_grid(oned)
atgrid = AtomGrid(radial, center=np.array([0.0, 0.0, 0.0]), degrees=[11])
return atgrid, btf

def test_ivp_gaussian_vs_bvp():
atgrid, btf = setup_grid()
density = gauss_density(atgrid.points, alpha=1.0)

pot_ivp = solve_poisson_ivp(
atgrid,
density,
InverseRTransform(btf),
r_interval=(20.0, 1e-3)
)

pot_bvp = solve_poisson_bvp(
atgrid,
density,
InverseRTransform(btf),
include_origin=True
)

pts = atgrid.points
val_ivp = pot_ivp(pts)
val_bvp = pot_bvp(pts)

# Check that they match
assert_allclose(val_ivp, val_bvp, rtol=1e-2, atol=1e-2)

def test_ivp_fallback_warning():
atgrid, btf = setup_grid()
density = exp_density(atgrid.points, alpha=100.0)

import grid.ode
original_solve_ivp = grid.ode.solve_ivp

call_count = 0
def mock_solve_ivp(fun, t_span, y0, method, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
class DummyRes:
status = -1
message = "Failed"
return DummyRes()
return original_solve_ivp(fun, t_span, y0, method=method, **kwargs)

with patch("grid.ode.solve_ivp", side_effect=mock_solve_ivp):
with pytest.warns(UserWarning, match="ODE solver method.*failed.*Trying next method"):
pot_ivp = solve_poisson_ivp(
atgrid,
density,
InverseRTransform(btf),
r_interval=(20.0, 1e-3)
)

val = pot_ivp(atgrid.points)
assert not np.any(np.isnan(val))

def test_adaptive_r_interval():
atgrid, btf = setup_grid()

density_compact = exp_density(atgrid.points, alpha=5.0)
density_diffuse = exp_density(atgrid.points, alpha=0.1)

used_r_intervals = []

original_solve = _solve_poisson_ivp_atomgrid

def mock_solve(atomgrid, func_vals, transform, r_interval, ode_params=None):
r_pts = atomgrid.rgrid.points
density_abs = np.abs(func_vals)
n_radial = len(r_pts)
if len(density_abs) % n_radial == 0:
n_angular = len(density_abs) // n_radial
density_abs = density_abs.reshape(n_radial, n_angular).sum(axis=1)
cumsum = np.cumsum(density_abs)
if cumsum[-1] > 0:
idx = np.searchsorted(cumsum, 0.99 * cumsum[-1])
r_99 = r_pts[min(idx, len(r_pts) - 1)]
r_max = r_99 * 10
if r_max > r_interval[0]:
r_max = r_interval[0]
if transform is not None and r_max > transform.domain[1]:
r_max = transform.domain[1]
used_r_intervals.append(r_max)
else:
used_r_intervals.append(r_interval[0])
return original_solve(atomgrid, func_vals, transform, r_interval, ode_params)

import grid.poisson
grid.poisson._solve_poisson_ivp_atomgrid = mock_solve

try:
solve_poisson_ivp(atgrid, density_compact, InverseRTransform(btf), r_interval=(20.0, 1e-3))
solve_poisson_ivp(atgrid, density_diffuse, InverseRTransform(btf), r_interval=(20.0, 1e-3))
finally:
grid.poisson._solve_poisson_ivp_atomgrid = original_solve

assert len(used_r_intervals) == 2
r_max_compact, r_max_diffuse = used_r_intervals
assert r_max_compact < r_max_diffuse
Comment thread
alok-108 marked this conversation as resolved.
Outdated