diff --git a/changelog/163.feature.rst b/changelog/163.feature.rst new file mode 100644 index 00000000..d27e5bbc --- /dev/null +++ b/changelog/163.feature.rst @@ -0,0 +1 @@ +Adds unit handling capability to class::MatrixModel to support Astropy fitting. diff --git a/examples/fitting_simulated_data.py b/examples/fitting_simulated_data.py index 914bf9bc..1839df85 100644 --- a/examples/fitting_simulated_data.py +++ b/examples/fitting_simulated_data.py @@ -19,7 +19,9 @@ import numpy as np from matplotlib.colors import LogNorm +import astropy.units as u from astropy.modeling import fitting +from astropy.visualization import quantity_support from sunkit_spex.data.simulated_data import simulate_square_response_matrix from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func @@ -27,6 +29,8 @@ from sunkit_spex.fitting.statistics.gaussian import chi_squared from sunkit_spex.models.instrument_response import MatrixModel from sunkit_spex.models.models import GaussianModel, StraightLineModel +from sunkit_spex.spectrum import Spectrum +from sunkit_spex.spectrum.spectrum import SpectralAxis ##################################################### # @@ -37,77 +41,100 @@ start, inc = 1.6, 0.04 stop = 80 + inc / 2 -ph_energies = np.arange(start, stop, inc) +ph_energies = np.arange(start, stop, inc) * u.keV +ph_energies_centers = ph_energies[:-1] + 0.5 * np.diff(ph_energies) ##################################################### # # Let's start making a simulated photon spectrum -sim_cont = {"edges": False, "slope": -1, "intercept": 100} -sim_line = {"edges": False, "amplitude": 100, "mean": 30, "stddev": 2} +sim_cont = {"slope": -1 * u.ph / u.keV**2, "intercept": 100 * u.ph / u.keV} +sim_line = {"amplitude": 100 * u.ph / u.keV, "mean": 30 * u.keV, "stddev": 2 * u.keV} # use a straight line model for a continuum, Gaussian for a line ph_model = StraightLineModel(**sim_cont) + GaussianModel(**sim_line) -plt.figure() -plt.plot(ph_energies, ph_model(ph_energies)) -plt.xlabel("Energy [keV]") -plt.ylabel("ph s$^{-1}$ cm$^{-2}$ keV$^{-1}$") -plt.title("Simulated Photon Spectrum") -plt.show() +with quantity_support(): + plt.figure() + plt.plot(ph_energies_centers, ph_model(ph_energies)) + plt.xlabel(f"Energy [{ph_energies.unit}]") + plt.title("Simulated Photon Spectrum") + plt.show() ##################################################### # # Now want a response matrix -srm = simulate_square_response_matrix(ph_energies.size) -srm_model = MatrixModel(matrix=srm) - -plt.figure() -plt.imshow( - srm, origin="lower", extent=[ph_energies[0], ph_energies[-1], ph_energies[0], ph_energies[-1]], norm=LogNorm() +srm = simulate_square_response_matrix(ph_energies.size - 1) +srm_model = MatrixModel( + matrix=srm, + input_axis=SpectralAxis(ph_energies), + output_axis=SpectralAxis(ph_energies), + c=1 * u.ct / u.ph, + _input_units={"x": u.ph * u.keV**-1}, + _output_units={"y": u.ct * u.keV**-1}, ) -plt.ylabel("Photon Energies [keV]") -plt.xlabel("Count Energies [keV]") -plt.title("Simulated SRM") -plt.show() +# srm_model.input_units = {"x": u.ph} + + +with quantity_support(): + plt.figure() + plt.imshow( + srm_model.matrix, + origin="lower", + extent=( + srm_model.input_axis[0].value, + srm_model.input_axis[-1].value, + srm_model.output_axis[0].value, + srm_model.output_axis[-1].value, + ), + norm=LogNorm(), + ) + plt.ylabel(f"Photon Energies [{srm_model.input_axis.unit}]") + plt.xlabel(f"Count Energies [{srm_model.output_axis.unit}]") + plt.title("Simulated SRM") + plt.show() ##################################################### # # Start work on a count model -sim_gauss = {"edges": False, "amplitude": 70, "mean": 40, "stddev": 2} +sim_gauss = {"amplitude": 70 * u.ct / u.keV, "mean": 40 * u.keV, "stddev": 2 * u.keV} # the brackets are very necessary -ct_model = (ph_model | srm_model) + GaussianModel(**sim_gauss) +ct_model = ph_model | srm_model ##################################################### # # Generate simulated count data to (almost) fit -sim_count_model = ct_model(ph_energies) - +sim_count_model = ct_model(SpectralAxis(ph_energies)) ##################################################### # # Add some noise np_rand = np.random.default_rng(seed=10) -sim_count_model_wn = sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model) +sim_count_model_wn = ( + sim_count_model + (2 * np_rand.random(sim_count_model.size)) * np.sqrt(sim_count_model.value) * u.ct / u.keV +) + +obs_spec = Spectrum(sim_count_model_wn.reshape(-1), spectral_axis=ph_energies) + ##################################################### # # Can plot all the different components in the simulated count spectrum -plt.figure() -plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features") -plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature") -plt.plot(ph_energies, sim_count_model, label="total sim. spectrum") -plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise", lw=0.5) -plt.xlabel("Energy [keV]") -plt.ylabel("cts s$^{-1}$ keV$^{-1}$") -plt.title("Simulated Count Spectrum") -plt.legend() +with quantity_support(): + plt.figure() + plt.plot(ph_energies_centers, (ph_model | srm_model)(ph_energies), label="photon model features") + plt.plot(ph_energies_centers, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature") + plt.plot(ph_energies_centers, sim_count_model, label="total sim. spectrum") + plt.plot(obs_spec._spectral_axis, obs_spec.data, label="total sim. spectrum + noise", lw=0.5) + plt.xlabel(f"Energy [{ph_energies.unit}]") + plt.title("Simulated Count Spectrum") + plt.legend() -plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold") -plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold") -plt.show() + plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold") + plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold") + plt.show() ##################################################### # @@ -115,9 +142,9 @@ # # Get some initial guesses that are off from the simulated data above -guess_cont = {"edges": False, "slope": -0.5, "intercept": 80} -guess_line = {"edges": False, "amplitude": 150, "mean": 32, "stddev": 5} -guess_gauss = {"edges": False, "amplitude": 350, "mean": 39, "stddev": 0.5} +guess_cont = {"slope": -0.5 * u.ph / u.keV**2, "intercept": 80 * u.ph / u.keV} +guess_line = {"amplitude": 150 * u.ph / u.keV, "mean": 32 * u.keV, "stddev": 5 * u.keV} +guess_gauss = {"amplitude": 350 * u.ct / u.keV, "mean": 39 * u.keV, "stddev": 0.5 * u.keV} ##################################################### # @@ -126,22 +153,24 @@ ph_mod_4fit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line) count_model_4fit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss) -##################################################### -# -# Let's fit the simulated data and plot the result -opt_res = scipy_minimize( - minimize_func, count_model_4fit.parameters, (sim_count_model_wn, ph_energies, count_model_4fit, chi_squared) -) +# print(ph_mod_4fit(ph_energies).size) +# print(count_model_4fit(obs_spec.data).size) +# ##################################################### +# # +# # Let's fit the simulated data and plot the result -plt.figure() -plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise") -plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies, *opt_res.x), ls=":", label="model fit") -plt.xlabel("Energy [keV]") -plt.ylabel("cts s$^{-1}$ keV$^{-1}$") -plt.title("Simulated Count Spectrum Fit with Scipy") -plt.legend() -plt.show() + +opt_res = scipy_minimize(minimize_func, count_model_4fit.parameters, (obs_spec, count_model_4fit, chi_squared)) + +with quantity_support(): + plt.figure() + plt.plot(ph_energies_centers, sim_count_model_wn, label="total sim. spectrum + noise") + plt.plot(ph_energies_centers, count_model_4fit.evaluate(ph_energies.value, *opt_res.x), ls=":", label="model fit") + plt.xlabel(f"Energy [{ph_energies.unit}]") + plt.title("Simulated Count Spectrum Fit with Scipy") + plt.legend() + plt.show() ##################################################### @@ -150,18 +179,31 @@ # # Try and ensure we start fresh with new model definitions +guess_cont = {"slope": -0.5 * u.ph / u.keV**2, "intercept": 80 * u.ph / u.keV} +guess_line = {"amplitude": 150 * u.ph / u.keV, "mean": 32 * u.keV, "stddev": 5 * u.keV} + ph_mod_4astropyfit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line) -count_model_4astropyfit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss) -astropy_fit = fitting.LevMarLSQFitter() +cgauss = GaussianModel(**guess_gauss) + + +count_model_4astropyfit = (ph_mod_4astropyfit | srm_model) + cgauss + -astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, sim_count_model_wn) +astropy_fit = fitting.LevMarLSQFitter() +astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, obs_spec.data << obs_spec.unit) plt.figure() -plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise") -plt.plot(ph_energies, astropy_fitted_result(ph_energies), ls=":", label="model fit") +plt.plot(ph_energies_centers, sim_count_model_wn, label="total sim. spectrum + noise") +plt.plot( + ph_energies_centers, + count_model_4astropyfit.evaluate(ph_energies.value, *astropy_fitted_result.parameters), + ls=":", + label="model fit", +) + plt.xlabel("Energy [keV]") -plt.ylabel("cts s$^{-1}$ keV$^{-1}$") +plt.ylabel("ct keV$^{-1}$") plt.title("Simulated Count Spectrum Fit with Astropy") plt.legend() plt.show() @@ -170,24 +212,38 @@ # # Display a table of the fitted results -plt.figure(layout="constrained") +# plt.figure(layout="constrained") + +# row_labels = ( +# tuple(sim_cont)[-2:] + tuple(f"{p}1" for p in tuple(sim_line)[-3:]) + tuple(f"{p}2" for p in tuple(sim_gauss)[-3:]) +# ) +# column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit") +# true_vals = np.array(tuple(sim_cont.values())[-2:] + tuple(sim_line.values())[-3:] + tuple(sim_gauss.values())[-3:]) +# guess_vals = np.array( +# tuple(guess_cont.values())[-2:] + tuple(guess_line.values())[-3:] + tuple(guess_gauss.values())[-3:] +# ) +# scipy_vals = opt_res.x +# astropy_vals = astropy_fitted_result.parameters + +# print(np.shape(scipy_vals)) +# print(np.shape(astropy_vals)) +# print(np.shape(true_vals)) +# print(np.shape(guess_vals)) + +plt.figure(layout="constrained") row_labels = ( - tuple(sim_cont)[-2:] + tuple(f"{p}1" for p in tuple(sim_line)[-3:]) + tuple(f"{p}2" for p in tuple(sim_gauss)[-3:]) + tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + ("C",) + tuple(f"{p}2" for p in tuple(sim_gauss)) ) column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit") -true_vals = np.array(tuple(sim_cont.values())[-2:] + tuple(sim_line.values())[-3:] + tuple(sim_gauss.values())[-3:]) -guess_vals = np.array( - tuple(guess_cont.values())[-2:] + tuple(guess_line.values())[-3:] + tuple(guess_gauss.values())[-3:] -) +true_vals = tuple(sim_cont.values()) + tuple(sim_line.values()) + (1 * u.m,) + tuple(sim_gauss.values()) +true_vals = [t.value for t in true_vals] +guess_vals = tuple(guess_cont.values()) + tuple(guess_line.values()) + (1 * u.m,) + tuple(guess_gauss.values()) +guess_vals = [g.value for g in guess_vals] scipy_vals = opt_res.x astropy_vals = astropy_fitted_result.parameters -print(np.shape(scipy_vals)) -print(np.shape(astropy_vals)) -print(np.shape(true_vals)) -print(np.shape(guess_vals)) cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str) diff --git a/pyproject.toml b/pyproject.toml index 9a7c4493..6004a522 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ authors = [ { name = "The SunPy Community", email = "sunpy@googlegroups.com" }, ] dependencies = [ + "astropy @ git+https://github.com/jajmitchell/astropy.git@astropy_sunkit-spex", "corner>=2.2", "emcee>=3.1", "matplotlib>=3.7", diff --git a/pytest.ini b/pytest.ini index c9c1945e..4a813084 100644 --- a/pytest.ini +++ b/pytest.ini @@ -40,3 +40,4 @@ filterwarnings = # Oldestdeps issues ignore:`finfo.machar` is deprecated:DeprecationWarning ignore:Please use `convolve1d` from the `scipy.ndimage` namespace, the `scipy.ndimage.filters` namespace is deprecated.:DeprecationWarning + ignore::astropy.utils.exceptions.AstropyDeprecationWarning diff --git a/sunkit_spex/fitting/objective_functions/optimising_functions.py b/sunkit_spex/fitting/objective_functions/optimising_functions.py index 4b631264..35d2599d 100644 --- a/sunkit_spex/fitting/objective_functions/optimising_functions.py +++ b/sunkit_spex/fitting/objective_functions/optimising_functions.py @@ -5,7 +5,7 @@ __all__ = ["minimize_func"] -def minimize_func(params, data_y, model_x, model_func, statistic_func): +def minimize_func(params, obs_spec, model_func, statistic_func): """ Minimization function. @@ -31,6 +31,12 @@ def minimize_func(params, data_y, model_x, model_func, statistic_func): ------- `float` The value to be optimized that compares the model to the data. + """ - model_y = model_func.evaluate(model_x, *params) - return statistic_func(data_y, model_y) + + if obs_spec._spectral_axis._bin_edges is not None: + model_y = model_func.evaluate(obs_spec._spectral_axis._bin_edges.value, *params) + else: + model_y = model_func.evaluate(obs_spec._spectral_axis.value, *params) + + return statistic_func(obs_spec.data, model_y) diff --git a/sunkit_spex/fitting/tests/test_objective_functions.py b/sunkit_spex/fitting/tests/test_objective_functions.py index c72024ee..542072d0 100644 --- a/sunkit_spex/fitting/tests/test_objective_functions.py +++ b/sunkit_spex/fitting/tests/test_objective_functions.py @@ -4,33 +4,34 @@ import numpy as np +import astropy.units as u + from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func from sunkit_spex.fitting.statistics.gaussian import chi_squared from sunkit_spex.models.models import StraightLineModel +from sunkit_spex.spectrum import Spectrum def test_minimize_func(): """Test the `minimize_func` function against known outputs.""" - sim_x0 = np.arange(3) + sim_x0 = np.arange(3) * u.keV model_params0 = {"slope": 1, "intercept": 0} sim_model0 = StraightLineModel(edges=False, **model_params0) sim_data0 = sim_model0.evaluate(sim_x0, **model_params0) res0 = minimize_func( params=tuple(model_params0.values()), - data_y=sim_data0, - model_x=sim_x0, + obs_spec=Spectrum(sim_data0, spectral_axis=sim_x0), model_func=sim_model0, statistic_func=chi_squared, ) - sim_x1 = np.arange(3) + sim_x1 = np.arange(3) * u.keV model_params1 = {"slope": 1, "intercept": 0} sim_model1 = StraightLineModel(edges=False, **model_params1) sim_data1 = sim_model1.evaluate(sim_x1, **model_params1)[::-1] res1 = minimize_func( params=tuple(model_params1.values()), - data_y=sim_data1, - model_x=sim_x1, + obs_spec=Spectrum(sim_data1, spectral_axis=sim_x1), model_func=sim_model1, statistic_func=chi_squared, ) diff --git a/sunkit_spex/fitting/tests/test_optimizer_tools.py b/sunkit_spex/fitting/tests/test_optimizer_tools.py index ed2f8fb7..8da10606 100644 --- a/sunkit_spex/fitting/tests/test_optimizer_tools.py +++ b/sunkit_spex/fitting/tests/test_optimizer_tools.py @@ -5,10 +5,13 @@ import numpy as np from numpy.testing import assert_allclose +import astropy.units as u + from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func from sunkit_spex.fitting.optimizer_tools.minimizer_tools import scipy_minimize from sunkit_spex.fitting.statistics.gaussian import chi_squared from sunkit_spex.models.models import StraightLineModel +from sunkit_spex.spectrum import Spectrum def test_scipy_minimize(): @@ -18,14 +21,22 @@ def test_scipy_minimize(): model_param_values0 = tuple(model_params0.values()) sim_model0 = StraightLineModel(edges=False, **model_params0) sim_data0 = sim_model0.evaluate(sim_x0, **model_params0) - opt_res0 = scipy_minimize(minimize_func, model_param_values0, (sim_data0, sim_x0, sim_model0, chi_squared)) + opt_res0 = scipy_minimize( + minimize_func, + model_param_values0, + (Spectrum(sim_data0 * u.dimensionless_unscaled, spectral_axis=sim_x0 * u.keV), sim_model0, chi_squared), + ) sim_x1 = np.arange(3) model_params1 = {"slope": 8, "intercept": 5} model_param_values1 = tuple(model_params1.values()) sim_model1 = StraightLineModel(edges=False, **model_params1) sim_data1 = sim_model1.evaluate(sim_x1, **model_params1) - opt_res1 = scipy_minimize(minimize_func, model_param_values1, (sim_data1, sim_x1, sim_model1, chi_squared)) + opt_res1 = scipy_minimize( + minimize_func, + model_param_values1, + (Spectrum(sim_data1 * u.dimensionless_unscaled, spectral_axis=sim_x1 * u.keV), sim_model1, chi_squared), + ) assert_allclose(opt_res0.x, model_param_values0, rtol=1e-3) assert_allclose(opt_res1.x, model_param_values1, rtol=1e-3) diff --git a/sunkit_spex/models/instrument_response.py b/sunkit_spex/models/instrument_response.py index f5004b14..1e15cc10 100644 --- a/sunkit_spex/models/instrument_response.py +++ b/sunkit_spex/models/instrument_response.py @@ -1,15 +1,61 @@ """Module for model components required for instrument response models.""" +import numpy as np + from astropy.modeling import Fittable1DModel, Parameter __all__ = ["MatrixModel"] class MatrixModel(Fittable1DModel): - def __init__(self, matrix): - self.matrix = Parameter(default=matrix, description="The matrix with which to multiply the input.", fixed=True) - super().__init__() + name = "SRM" + + c = Parameter(fixed=True) + + def __init__(self, matrix, input_axis, output_axis, _input_units, _output_units, c): + self._input_units = _input_units + self._output_units = _output_units + self.input_axis = input_axis + self.output_axis = output_axis + self.matrix = matrix + super().__init__(c) + # self.matrix.value = self.matrix.value.flatten() - def evaluate(self, model_y): + _input_units_allow_dimensionless = True + + def evaluate(self, x, c): # Requires input must have a specific dimensionality - return model_y @ self.matrix + + input_widths = np.diff(self.input_axis) + output_widths = np.diff(self.output_axis) + + # print(x) + # print(input_widths) + # print(self.matrix) + # print(c) + # print(output_widths) + + # print(self.input_axis.size) + # print(self.output_axis.size) + # print(x.size) + # print(input_widths.size) + # print(self.matrix.size) + # print(c.size) + # print(output_widths.size) + + flux = ((x * input_widths) @ self.matrix * c) / output_widths + + if hasattr(c, "unit"): + return flux + return flux.value + + @property + def input_units(self): + return self._input_units + + @property + def return_units(self): + return self._output_units + + def _parameter_units_for_data_units(self, inputs_unit, outputs_unit): + return {"c": self._output_units["y"] / self._input_units["x"]} diff --git a/sunkit_spex/models/models.py b/sunkit_spex/models/models.py index bbd82769..42c6c118 100644 --- a/sunkit_spex/models/models.py +++ b/sunkit_spex/models/models.py @@ -2,7 +2,6 @@ import numpy as np -import astropy.units as u from astropy.modeling import FittableModel, Parameter from astropy.units import Quantity @@ -15,7 +14,7 @@ class StraightLineModel(FittableModel): _input_units_allow_dimensionless = True - input_units_equivalencies = {"keV": u.spectral()} + # input_units_equivalencies = {"keV": u.spectral()} slope = Parameter(default=1, description="Gradient of a straight line model.") intercept = Parameter(default=0, description="Y-intercept of a straight line model.") @@ -45,7 +44,7 @@ def return_units(self): return None def _parameter_units_for_data_units(self, input_units, output_units): - return {"slope": output_units["y"] / input_units["x"], "intercept": output_units["y"]} + return {"slope": self.slope.unit, "intercept": self.intercept.unit} class GaussianModel(FittableModel): @@ -84,4 +83,4 @@ def return_units(self): return None def _parameter_units_for_data_units(self, input_units, output_units): - return {"mean": input_units["x"], "stddev": input_units["x"], "amplitude": output_units["y"]} + return {"mean": self.mean.unit, "stddev": self.mean.unit, "amplitude": self.amplitude.unit} diff --git a/sunkit_spex/models/physical/nonthermal.py b/sunkit_spex/models/physical/nonthermal.py index 1e95e29a..a0ad4b00 100644 --- a/sunkit_spex/models/physical/nonthermal.py +++ b/sunkit_spex/models/physical/nonthermal.py @@ -70,6 +70,8 @@ class ThickTarget(FittableModel): of ph s^-1 keV^-1. """ + name = "ThickTarget" + n_inputs = 1 n_outputs = 1 @@ -80,7 +82,7 @@ class ThickTarget(FittableModel): q = Parameter(name="q", default=5, min=0.01, description="Slope above break", fixed=True) low_e_cutoff = Parameter( - name="low_e_cutoff", default=7, unit=u.keV, description="Low energy electron cut off", fixed=False + name="low_e_cutoff", default=7, unit=u.keV, description="Low energy electron cut off", fixed=False, min=1 ) high_e_cutoff = Parameter( @@ -202,6 +204,8 @@ class ThinTarget(FittableModel): of ph s^-1 keV^-1. """ + name = "ThinTarget" + n_inputs = 1 n_outputs = 1 diff --git a/sunkit_spex/models/physical/thermal.py b/sunkit_spex/models/physical/thermal.py index ae9f206d..c2b450d2 100644 --- a/sunkit_spex/models/physical/thermal.py +++ b/sunkit_spex/models/physical/thermal.py @@ -141,6 +141,7 @@ class ThermalEmission(FittableModel): name="emission_measure", default=1, unit=(u.cm ** (-3)), + min=1e-9, description="Emission measure of the observer", fixed=False, ) diff --git a/sunkit_spex/spectrum/spectrum.py b/sunkit_spex/spectrum/spectrum.py index b1ca2d86..a4be85a0 100644 --- a/sunkit_spex/spectrum/spectrum.py +++ b/sunkit_spex/spectrum/spectrum.py @@ -85,6 +85,8 @@ def __new__(cls, value, *args, bin_specification="centers", **kwargs): if bin_specification == "edges": obj._bin_edges = bin_edges + elif bin_specification == "centers": + obj._bin_edges = None return obj @@ -193,14 +195,14 @@ def __init__( if data is not None and spectral_axis is not None: if spectral_axis.shape[0] == data.shape[spectral_dimension]: bin_specification = "centers" - elif spectral_axis.shape[0] == data.shape[spectral_dimension] + 1: + elif spectral_axis.shape[0] >= data.shape[spectral_dimension] + 1: bin_specification = "edges" - else: - raise ValueError( - f"Spectral axis length ({spectral_axis.shape[0]}) must be the same size or one " - "greater (if specifying bin edges) than that of the spextral" - f"axis ({data.shape[spectral_dimension]})" - ) + # else: + # raise ValueError( + # f"Spectral axis length ({spectral_axis.shape[0]}) must be the same size or one " + # "greater (if specifying bin edges) than that of the spextral" + # f"axis ({data.shape[spectral_dimension]})" + # ) # Attempt to parse the spectral axis. If none is given, try instead to # parse a given wcs. This is put into a GWCS object to @@ -215,6 +217,11 @@ def __init__( if not isinstance(spectral_axis, SpectralAxis): if spectral_axis.shape[0] == data.shape[spectral_dimension] + 1: bin_specification = "edges" + elif len(spectral_axis.shape) > 1: + if spectral_axis.shape[1] == 2: + spectral_axis = np.concatenate([spectral_axis[:, 0], spectral_axis[:, 1][-1:]]) + bin_specification = "edges" + else: bin_specification = "centers" self._spectral_axis = SpectralAxis(spectral_axis, bin_specification=bin_specification) diff --git a/sunkit_spex/tests/test_models.py b/sunkit_spex/tests/test_models.py index ce24b0e6..0763a540 100644 --- a/sunkit_spex/tests/test_models.py +++ b/sunkit_spex/tests/test_models.py @@ -5,6 +5,8 @@ import numpy as np from numpy.testing import assert_allclose, assert_array_equal +from astropy import units as u + from sunkit_spex.data.simulated_data import simulate_square_response_matrix from sunkit_spex.models.instrument_response import MatrixModel from sunkit_spex.models.models import GaussianModel, StraightLineModel @@ -67,16 +69,27 @@ def test_GaussianModel_edges(): def test_MatrixModel(): """Test the matrix model contents and compound model behaviour.""" size0 = 3 + sim_x0 = np.arange(size0 + 1) + srm0 = simulate_square_response_matrix(size0) - srm_model0 = MatrixModel(matrix=srm0) + + srm_model0 = MatrixModel( + matrix=srm0, + c=1 * u.dimensionless_unscaled, + input_axis=sim_x0, + output_axis=sim_x0, + _input_units={"x": u.dimensionless_unscaled}, + _output_units={"y": u.dimensionless_unscaled}, + ) assert_array_equal(srm_model0.matrix, srm0) - sim_x0 = np.arange(size0) - model_params0_init = {"edges": False, "slope": 1, "intercept": 0} + model_params0_init = {"edges": True, "slope": 1, "intercept": 0} sim_model0 = StraightLineModel(**model_params0_init) comp_model0 = sim_model0 | srm_model0 comp_res0 = comp_model0(sim_x0) - exp_res0 = [0.00682338, 1.00348448, 1.98969213] + + # exp_res0 = [0.00682338, 1.00348448, 1.98969213] + exp_res0 = [0.509719, 1.503166, 2.487115] assert_allclose(comp_res0, exp_res0, rtol=1e-6)