diff --git a/sunkit_spex/fitting/fitter.py b/sunkit_spex/fitting/fitter.py new file mode 100644 index 00000000..8edabd8e --- /dev/null +++ b/sunkit_spex/fitting/fitter.py @@ -0,0 +1,211 @@ +""" +This module contains functions to carry out astropy fitting with spectral models +""" + +import astropy.units as u +from astropy.modeling import fitting +from astropy.modeling import models +from matplotlib import pyplot as plt + +import numpy as np + +from sunkit_spex.models.physical.thermal import ThermalEmission +from sunkit_spex.models.physical.nonthermal import ThickTarget +from sunkit_spex.models.physical.albedo import Albedo +from sunkit_spex.models.scaling import InverseSquareFluxScaling +from sunkit_spex.models.instrument_response import MatrixModel +# from sunkit_spex.visualisation.plotter import plot + +__all__ = ["fitter"] + + +class Fitter: + + def __init__( + self, + model, + spectrum_object, + fitting_method = fitting.TRFLSQFitter(calc_uncertainties=True), + fit_range=None): + + self._model = model + self._spectrum_object = spectrum_object + self._fitting_method = fitting_method + self._fit_range = fit_range + self._fitted_model = None + # self._PIPELINE_COMPONENTS = {'SRM', 'Albedo', 'InverseSquareFluxScaling'} + + @property + def model(self): + return self._model + + + + def _set_abledo_angle(self): + + if 'Albedo' in self.model.submodel_names: + + # print(len(self._spectrum_object.meta['ph_axis'])) + + replacement_albedo = Albedo(energy_edges=self._spectrum_object.meta['ph_axis'], + theta=self._spectrum_object.meta['angle']) + replacement_albedo.theta.fixed = True + + self._model = self._model.replace_submodel('Albedo',replacement_albedo) + + def _set_observer_distance(self): + + match = np.where(np.array(self._model.submodel_names)=='InverseSquareFluxScaling')[0] + + if np.shape(match) != 0: + param_names = [f'observer_distance_{str(ind)}' for ind in match] + + for param_name in param_names: + setattr(self._model, param_name, self._spectrum_object.meta['distance']) + getattr(self._model, param_name).fixed = True + + def _set_srm(self): + + if 'SRM' in self.model.submodel_names: + + self._model = self._model.replace_submodel('SRM',MatrixModel(matrix= np.array(self._spectrum_object.meta['srm']), + spectrum_object=self._spectrum_object, + model_spec_units=u.ph * u.keV**-1 * u.s**-1 * u.cm**-2)) + @property + def fitting_method(self): + return self._fitting_method + + @fitting_method.setter + def fitting_method(self, value): + self._fitting_method = value + + @property + def fitted_model(self): + """Return the fitted model. None until do_fit() has been called.""" + if self._fitted_model is None: + raise RuntimeError("No fitted model available — call do_fit() first.") + return self._fitted_model + + @property + def fit_range(self): + return self._fit_range + + + @fit_range.setter + def fit_range(self, value): + """ + value : tuple + (emin, emax) in same units as spectral_axis + """ + + if value is None: + self._fit_range = None + return + + emin, emax = value + edges = self._spectrum_object.spectral_axis.bin_edges + + # Determine bins fully inside range + lower = edges[:-1] + upper = edges[1:] + + indices = np.where((lower >= emin) & (upper <= emax))[0] + + self._fit_range = value + self._fit_mask = indices + + def _apply_fit_range(self): + + if self._fit_range is None: + return + + mask = self._fit_mask + + self._spectrum_object = self._spectrum_object[mask[0]:mask[-1]+1] + + # print(self._spectrum_object.spectral_axis.bin_edges.shape) + + self._spectrum_object.spectral_axis._bin_edges = np.array(self._spectrum_object.spectral_axis.bin_edges[mask[0]:mask[-1]+2]) + + + # print(self._spectrum_object.spectral_axis.bin_edges.shape) + + if 'srm' in self._spectrum_object.meta: + self._spectrum_object.meta['srm'] = \ + self._spectrum_object.meta['srm'][:,mask[0]:mask[-1]+1] + + + def _fit_prep(self): + + self._apply_fit_range() + + self._set_abledo_angle() + self._set_observer_distance() + self._set_srm() + + + + def do_fit(self): + + + self._fit_prep() + + + w = np.array(1/self._spectrum_object.uncertainty.array) << self._spectrum_object.uncertainty.unit + data = np.array(self._spectrum_object.data) << self._spectrum_object.unit + + + # Store on the instance; access via the fitted_model property + self._fitted_model = self._fitting_method( + model=self._model, + x=self._spectrum_object.meta['ph_axis'], + y=data, + weights=w, + estimate_jacobian=True) + + # return fitted_model + + + # def plot_fit_results(self,save=True): + + + # if save: + # plot(self._spectrum_object.spectral_axis._bin_edges*u.keV, + # self._spectrum_object.meta['ph_axis'], + # self._spectrum_object.data << self._spectrum_object.unit, + # self._spectrum_object.uncertainty.array << self._spectrum_object.unit, + # self.fitted_model, + # f'{self._spectrum_object.meta['time_range'][0]}_{self._spectrum_object.meta['time_range'][1]}_sunkit_spex_fit.png', + # f'{self._spectrum_object.meta['time_range'][0]} - {self._spectrum_object.meta['time_range'][1]}', + # self.fitting_method.fit_info['param_cov'], + # self._spectrum_object) + # else: + # plot(self._spectrum_object.spectral_axis._bin_edges*u.keV, + # self._spectrum_object.meta['ph_axis'], + # self._spectrum_object.data << self._spectrum_object.unit, + # self._spectrum_object.uncertainty.array << self._spectrum_object.unit, + # self.fitted_model, + # False, + # f'{self._spectrum_object.meta['time_range'][0]} - {self._spectrum_object.meta['time_range'][1]}', + # self.fitting_method.fit_info['param_cov'], + # self._spectrum_object) + + + + + # 'here we perform the fitting' + + # def plot_fit_results(self): + # 'here we plot the fitting results' + + # def chi_squared(self): + # 'here we calculate the chi^2' + + # def get_fit_results(self): + # 'here we return fit results and uncertainties' + + # def get_fit_components(self): + # 'here we return the fitted components' + + # def run_mcmc(self): + # 'run_mcmc' \ No newline at end of file diff --git a/sunkit_spex/models/instrument_response.py b/sunkit_spex/models/instrument_response.py index f5004b14..7c73934d 100644 --- a/sunkit_spex/models/instrument_response.py +++ b/sunkit_spex/models/instrument_response.py @@ -1,15 +1,156 @@ """Module for model components required for instrument response models.""" +import numpy as np + from astropy.modeling import Fittable1DModel, Parameter +import astropy.units as u __all__ = ["MatrixModel"] class MatrixModel(Fittable1DModel): - def __init__(self, matrix): - self.matrix = Parameter(default=matrix, description="The matrix with which to multiply the input.", fixed=True) - super().__init__() + name = "SRM" + conversion_factor = Parameter(fixed=True) + _input_units_allow_dimensionless = True + + def __init__(self, matrix=None, + model_spec_units=u.dimensionless_unscaled, + data_spec_units=u.dimensionless_unscaled, + conversion_factor=1*u.dimensionless_unscaled, + spectrum_object= None, + spectral_model=True): + + self.spectral_model = spectral_model + + if not self.spectral_model: + + self.model_spec_units = model_spec_units + self.data_spec_units = data_spec_units + self.matrix = matrix + conversion_factor = 1 << (data_spec_units / model_spec_units) + else: + # self.matrix = matrix + self.spectrum_object = spectrum_object + self.model_spec_units = model_spec_units + + if spectrum_object: + self.data_spec_units = spectrum_object.unit + conversion_factor = 1* u.ct / u.ph + # conversion_factor = 1 << (self.data_spec_units / self.model_spec_units) + # print(conversion_factor) + else: + self.data_spec_units = data_spec_units + # conversion_factor = 1 << (data_spec_units / model_spec_units) + conversion_factor = 1* u.ct / u.ph + + super().__init__(conversion_factor=conversion_factor) + + def evaluate(self, x, conversion_factor): + + # matrix = self.matrix + + if self.spectral_model: + + matrix = self.spectrum_object.meta['srm'] + input_axis = np.array(self.spectrum_object.spectral_axis.bin_edges) + input_widths = np.diff(input_axis) + output_widths = np.diff(self.spectrum_object.meta['ph_axis']) + + # print('IR SRM = ',self.spectrum_object.meta['srm'].shape) + # print('IR SRM = ',self.spectrum_object.spectral_axis.bin_edges.shape) + + geo_area = self.spectrum_object.meta['geo_area'] + exposure_time = self.spectrum_object.meta['exposure_time'] + norm = input_widths * exposure_time * geo_area + + # print('input_widths = ',input_widths) + # print('exposure_time = ',exposure_time) + # print('geo_area = ',geo_area) + + # print(x.unit) + # print(conversion_factor.unit) + # print(norm.unit) + + # flux = (x @ matrix) * conversion_factor * norm + flux = (((x*output_widths*exposure_time)@ (matrix*geo_area*u.cm**2)) * conversion_factor ) + + else: + flux = x @ matrix * conversion_factor * (geo_area*u.cm**2) + + # print('HHEERRREEEE') + + if hasattr(conversion_factor,"unit"): + return flux + else: + return flux.value + + def set_spectrum_object(self, new_spectrum_object): + self.spectrum_object = new_spectrum_object + + # @property + # def model_spec_units(self): + # return self._model_spec_units + + # @model_spec_units.setter + # def model_spec_units(self, new_unit): + # self._model_spec_units = new_unit + + # if hasattr(self,"data_spec_units"): + + # if self.data_spec_units != u.dimensionless_unscaled: + + # new_param_unit = self.data_spec_units / new_unit + + # self.conversion_factor = self.conversion_factor.value * new_param_unit + + # else: + + # self.conversion_factor = self.conversion_factor * u.dimensionless_unscaled + + + # @property + # def data_spec_units(self): + # return self._data_spec_units + + # @data_spec_units.setter + # def data_spec_units(self, new_unit): + # self._data_spec_units = new_unit + + # if hasattr(self,"model_spec_units"): + + + # if self.data_spec_units != u.dimensionless_unscaled: + + # new_param_unit = new_unit / self.model_spec_units + + # self.conversion_factor = self.conversion_factor.value * new_param_unit + + # else: + + # self.conversion_factor = self.conversion_factor * u.dimensionless_unscaled + + @property + def input_units(self): + # return {"x": self.model_spec_units }SS + return {"x": u.ph * u.keV**-1 * u.s**-1 * u.cm**-2 } + + @property + def return_units(self): + # return {"y": self.data_spec_units} + return {"y": u.ct} + + def _parameter_units_for_data_units(self, inputs_unit, outputs_unit): + return {"conversion_factor": self.conversion_factor.unit} + + # @property + # def input_units(self): + # # return {"x": self.model_spec_units }SS + # return {"x": u.ph * u.keV**-1 * u.s**-1 * u.cm**-2 } + + # @property + # def return_units(self): + # # return {"y": self.data_spec_units} + # return {"y": u.ct* u.keV**-1 * u.s**-1} - def evaluate(self, model_y): - # Requires input must have a specific dimensionality - return model_y @ self.matrix + # def _parameter_units_for_data_units(self, inputs_unit, outputs_unit): + # return {"conversion_factor": self.conversion_factor.unit} \ No newline at end of file diff --git a/sunkit_spex/models/physical/albedo.py b/sunkit_spex/models/physical/albedo.py index 8fa6e758..345d4950 100644 --- a/sunkit_spex/models/physical/albedo.py +++ b/sunkit_spex/models/physical/albedo.py @@ -88,14 +88,22 @@ class Albedo(FittableModel): _input_units_allow_dimensionless = True - def __init__(self, *args, **kwargs): - self.energy_edges = kwargs.pop("energy_edges") + def __init__(self,energy_edges=None,spectral_model=True, *args, **kwargs): + + self.spectral_model = spectral_model + + # if self.spectral_model: + self.energy_edges = energy_edges + # else: + # self.energy_edges = kwargs.pop("energy_edges") super().__init__(*args, **kwargs) def evaluate(self, spectrum, theta, anisotropy): if not isinstance(theta, Quantity): theta = theta * u.deg + + # print('albedo ee = ',self.energy_edges) albedo_matrix = get_albedo_matrix(self.energy_edges, theta, anisotropy) diff --git a/sunkit_spex/models/physical/nonthermal.py b/sunkit_spex/models/physical/nonthermal.py index 1e95e29a..71e829c4 100644 --- a/sunkit_spex/models/physical/nonthermal.py +++ b/sunkit_spex/models/physical/nonthermal.py @@ -70,17 +70,19 @@ class ThickTarget(FittableModel): of ph s^-1 keV^-1. """ + name = 'ThickTarget' + n_inputs = 1 n_outputs = 1 - p = Parameter(name="p", default=2, description="Slope below break", fixed=False) + p = Parameter(name="p", default=4,min=3, description="Slope below break", fixed=False) - break_energy = Parameter(name="break_energy", default=100, unit=u.keV, description="Break Energy", fixed=False) + break_energy = Parameter(name="break_energy", default=1500, unit=u.keV, description="Break Energy", fixed=True) q = Parameter(name="q", default=5, min=0.01, description="Slope above break", fixed=True) low_e_cutoff = Parameter( - name="low_e_cutoff", default=7, unit=u.keV, description="Low energy electron cut off", fixed=False + name="low_e_cutoff", default=20, unit=u.keV, description="Low energy electron cut off", fixed=False ) high_e_cutoff = Parameter( @@ -140,6 +142,7 @@ def evaluate(self, energy_edges, p, break_energy, q, low_e_cutoff, high_e_cutoff flux = thick_fn( energy_centers, p, break_energy, q, low_e_cutoff, high_e_cutoff, total_eflux, self.integrator ) + # print('nt_flux = ',flux) return flux diff --git a/sunkit_spex/models/physical/thermal.py b/sunkit_spex/models/physical/thermal.py index 73d09897..c2f33ae2 100644 --- a/sunkit_spex/models/physical/thermal.py +++ b/sunkit_spex/models/physical/thermal.py @@ -263,7 +263,12 @@ def evaluate( fe, ) - return line_flux + cont_flux + flux = line_flux + cont_flux + + if hasattr(energy_edges,"unit"): + return flux + else: + return flux.value @property def input_units(self): @@ -1222,7 +1227,7 @@ def _sanitize_inputs(energy_edges, temperature, emission_measure): # came with them attached. # If they were not already Quantities, the parameters get the default units. energy_edges <<= u.keV - temperature <<= u.K + temperature <<= u.MK emission_measure <<= u.cm**-3 energy_edges_keV = energy_edges.to(u.keV) @@ -1239,6 +1244,7 @@ def _sanitize_inputs(energy_edges, temperature, emission_measure): def _error_if_input_outside_valid_range(input_values, grid_range, param_name, param_unit): + # print(input_values) if input_values.min() < grid_range[0] or input_values.max() > grid_range[1]: if param_name == "temperature": message_unit = "MK" @@ -1251,6 +1257,8 @@ def _error_if_input_outside_valid_range(input_values, grid_range, param_name, pa def _warn_if_input_outside_valid_range(input_values, grid_range, param_name, param_unit): + print(input_values) + print(grid_range) if input_values.min() < grid_range[0] or input_values.max() > grid_range[1]: message = ( f"Some input {param_name} values outside valid range of " diff --git a/sunkit_spex/spectrum/spectrum.py b/sunkit_spex/spectrum/spectrum.py index b1ca2d86..a7f45249 100644 --- a/sunkit_spex/spectrum/spectrum.py +++ b/sunkit_spex/spectrum/spectrum.py @@ -1,3 +1,6 @@ +import copy +from copy import deepcopy + import numpy as np from gwcs import WCS as GWCS from gwcs import coordinate_frames as cf @@ -5,53 +8,198 @@ import astropy.units as u from astropy.coordinates import SpectralCoord +from astropy.modeling.mappings import Identity, Mapping from astropy.modeling.tabular import Tabular1D from astropy.utils import lazyproperty +from astropy.wcs.wcsapi import sanitize_slices __all__ = ["SpectralAxis", "Spectrum", "gwcs_from_array"] -__doctest_requires__ = {"Spectrum": ["ndcube>=2.3"]} __doctest_requires__ = {"Spectrum": ["ndcube>=2.3"]} -def gwcs_from_array(array): +class SpectralGWCS(GWCS): + """ + This is a placeholder lookup-table GWCS created when a :class:`~specutils.Spectrum` is + instantiated with a ``spectral_axis`` and no WCS. + """ + + def __init__(self, *args, **kwargs): + self.original_unit = kwargs.pop("original_unit", "") + super().__init__(*args, **kwargs) + + def copy(self): + """ + Return a shallow copy of the object. + + Convenience method so user doesn't have to import the + :mod:`copy` stdlib module. + + .. warning:: + Use `deepcopy` instead of `copy` unless you know why you need a + shallow copy. + """ + return copy.copy(self) + + def deepcopy(self): + """ + Return a deep copy of the object. + + Convenience method so user doesn't have to import the + :mod:`copy` stdlib module. + """ + return copy.deepcopy(self) + + +def gwcs_from_array(array, flux_shape, spectral_axis_index=None): """ Create a new WCS from provided tabular data. This defaults to being - a GWCS object. + a GWCS object with a lookup table for the spectral axis and filler + pixel to pixel identity conversions for spatial axes, if they exist. """ orig_array = u.Quantity(array) - - coord_frame = cf.CoordinateFrame(naxes=1, axes_type=("SPECTRAL",), axes_order=(0,)) - spec_frame = cf.SpectralFrame(unit=array.unit, axes_order=(0,)) + naxes = len(flux_shape) + + if naxes > 1: + if spectral_axis_index is None: + raise ValueError("spectral_axis_index must be set for multidimensional flux arrays") + # Axis order is reversed for WCS from numpy array + spectral_axis_index = naxes - spectral_axis_index - 1 + elif naxes == 1: + spectral_axis_index = 0 + + axes_order = list(np.arange(naxes)) + axes_type = [ + "SPATIAL", + ] * naxes + axes_type[spectral_axis_index] = "SPECTRAL" + + detector_frame = cf.CoordinateFrame( + naxes=naxes, + name="detector", + unit=[ + u.pix, + ] + * naxes, + axes_order=axes_order, + axes_type=axes_type, + ) + + if array.unit in ("", "pix", "pixel"): + # Spectrum was initialized without a wcs or spectral axis + spectral_frame = cf.CoordinateFrame( + naxes=1, + unit=[ + array.unit, + ], + axes_type=[ + "Spectral", + ], + axes_order=(spectral_axis_index,), + ) + else: + phys_types = None + # Note that some units have multiple physical types, so we can't just set the + # axis name to the physical type string. + if array.unit.physical_type == "length": + axes_names = [ + "wavelength", + ] + elif array.unit.physical_type == "frequency": + axes_names = [ + "frequency", + ] + elif array.unit.physical_type == "velocity": + axes_names = [ + "velocity", + ] + phys_types = [ + "spect.dopplerVeloc.optical", + ] + elif array.unit.physical_type == "wavenumber": + axes_names = [ + "wavenumber", + ] + elif array.unit.physical_type == "energy": + axes_names = [ + "energy", + ] + else: + raise ValueError("Spectral axis units must be one of length,frequency, velocity, energy, or wavenumber") + + spectral_frame = cf.SpectralFrame( + unit=array.unit, axes_order=(spectral_axis_index,), axes_names=axes_names, axis_physical_types=phys_types + ) + + if naxes > 1: + axes_order.remove(spectral_axis_index) + spatial_frame = cf.CoordinateFrame( + naxes=naxes - 1, + unit=[ + "", + ] + * (naxes - 1), + axes_type=[ + "Spatial", + ] + * (naxes - 1), + axes_order=axes_order, + ) + output_frame = cf.CompositeFrame(frames=[spatial_frame, spectral_frame]) + else: + output_frame = spectral_frame # In order for the world_to_pixel transformation to automatically convert - # input units, the equivalencies in the lookup table have to be extended + # input units, the equivalencies in the look up table have to be extended # with spectral unit information. - SpectralTabular1D = type("SpectralTabular1D", (Tabular1D,), {"input_units_equivalencies": {"x0": u.spectral()}}) + SpectralTabular1D = type( + "SpectralTabular1D", (Tabular1D,), {"input_units_equivalencies": {"x0": u.spectral()}, "bounds_error": True} + ) + + # We pass through the pixel values of spatial axes with Identity and use a lookup + # table for the spectral axis values. We use Mapping to pipe the values to the correct + # model depending on which axis is the spectral axis + if naxes == 1: + forward_transform = SpectralTabular1D(np.arange(len(array)) * u.pix, lookup_table=array) + else: + axes_order.append(spectral_axis_index) + # WCS axis order is reverse of numpy array order + mapped_axes = axes_order + out_mapping = np.ones(len(mapped_axes)).astype(int) + for i in range(len(mapped_axes)): + out_mapping[mapped_axes[i]] = i + forward_transform = ( + Mapping(mapped_axes) + | Identity(naxes - 1) & SpectralTabular1D(np.arange(len(array)) * u.pix, lookup_table=array) + | Mapping(out_mapping) + ) - forward_transform = SpectralTabular1D(np.arange(len(array)), lookup_table=array) # If our spectral axis is in descending order, we have to flip the lookup # table to be ascending in order for world_to_pixel to work. if len(array) == 0 or array[-1] > array[0]: - forward_transform.inverse = SpectralTabular1D(array, lookup_table=np.arange(len(array))) + forward_transform.inverse = SpectralTabular1D(array, lookup_table=np.arange(len(array)) * u.pix) else: - forward_transform.inverse = SpectralTabular1D(array[::-1], lookup_table=np.arange(len(array))[::-1]) - - class SpectralGWCS(GWCS): - def pixel_to_world(self, *args, **kwargs): - if orig_array.unit == "": - return u.Quantity(super().pixel_to_world_values(*args, **kwargs)) - return super().pixel_to_world(*args, **kwargs).to(orig_array.unit, equivalencies=u.spectral()) - - return SpectralGWCS(forward_transform=forward_transform, input_frame=coord_frame, output_frame=spec_frame) + raise ValueError("Unsupported ") + # forward_transform.inverse = SpectralTabular1D( + # array[::-1], lookup_table=np.arange(len(array))[::-1]) + + tabular_gwcs = SpectralGWCS( + original_unit=orig_array.unit, + forward_transform=forward_transform, + input_frame=detector_frame, + output_frame=output_frame, + ) + tabular_gwcs.bounding_box = None # Store the intended unit from the origin input array # tabular_gwcs._input_unit = orig_array.unit + return tabular_gwcs + class SpectralAxis(SpectralCoord): - """ + r""" Coordinate object representing spectral values corresponding to a specific spectrum. Overloads SpectralCoord with additional information (currently only bin edges). @@ -75,6 +223,9 @@ def __new__(cls, value, *args, bin_specification="centers", **kwargs): ): raise ValueError("u.pix spectral axes should always be ascending") + if bin_specification == "edges" and value.size < 2: + raise ValueError('If bin_specification="centers" have at least two bin edges.') + # Convert to bin centers if bin edges were given, since SpectralCoord # only accepts centers if bin_specification == "edges": @@ -88,43 +239,35 @@ def __new__(cls, value, *args, bin_specification="centers", **kwargs): return obj - @staticmethod - def _edges_from_centers(centers, unit): - """ - Calculates interior bin edges based on the average of each pair of - centers, with the two outer edges based on extrapolated centers added - to the beginning and end of the spectral axis. - """ - a = np.insert(centers, 0, 2 * centers[0] - centers[1]) - b = np.append(centers, 2 * centers[-1] - centers[-2]) - edges = (a + b) / 2 - return edges * unit - @staticmethod def _centers_from_edges(edges): - """ + r""" Calculates the bin centers as the average of each pair of edges """ return (edges[1:] + edges[:-1]) / 2 @lazyproperty def bin_edges(self): - """ + r""" Calculates bin edges if the spectral axis was created with centers specified. """ - if hasattr(self, "_bin_edges"): + if hasattr(self, "_bin_edges") and self._bin_edges is not None: return self._bin_edges - return self._edges_from_centers(self.value, self.unit) + return None + + def __array_finalize__(self, obj): + super().__array_finalize__(obj) + self._bin_edges = getattr(obj, "_bin_edges", None) class Spectrum(NDCube): r""" - Spectrum container for data with one spectral axis. + Spectrum container for data which share a common spectral axis. Note that "1D" in this case refers to the fact that there is only one - spectral axis. `Spectrum` can contain "vector 1D spectra" by having the - ``flux`` have a shape with dimension greater than 1. + spectral axis. `Spectrum` can contain ND data where + ``data`` have a shape with dimension greater than 1. Notes ----- @@ -134,15 +277,15 @@ class Spectrum(NDCube): ---------- data : `~astropy.units.Quantity` The data for this spectrum. This can be a simple `~astropy.units.Quantity`, - or an existing `~Spectrum1D` or `~ndcube.NDCube` object. + or an existing `~Spectrum` or `~ndcube.NDCube` object. uncertainty : `~astropy.nddata.NDUncertainty` Contains uncertainty information along with propagation rules for spectrum arithmetic. Can take a unit, but if none is given, will use - the unit defined in the flux. + the unit defined in the data. spectral_axis : `~astropy.units.Quantity` or `~specutils.SpectralAxis` - Dispersion information with the same shape as the dimension specified by spectral_dimension - of shape plus one if specifying bin edges. - spectral_dimension : `int` default 0 + Dispersion information with the same shape as the dimension specified by spectral_axis_index + or shape plus one if specifying bin edges. + spectral_axis_index : `int` default 0 The dimension of the data which represents the spectral information default to first dimension index 0. mask : `~numpy.ndarray`-like Array where values in the flux to be masked are those that @@ -157,7 +300,7 @@ class Spectrum(NDCube): >>> import numpy as np >>> import astropy.units as u >>> from sunkit_spex.spectrum import Spectrum - >>> spec = Spectrum(np.arange(1, 11)*u.watt, spectral_axis=np.arange(1, 12)*u.keV) + >>> spec = Spectrum(np.arange(1, 11)*u.watt,spectral_axis=np.arange(1, 12)*u.keV) >>> spec 0: + raise ValueError( + "Initializer contains unknown arguments(s): {}.".format(", ".join(map(str, unknown_kwargs))) + ) + + # Handle initializing from NDCube objects + if isinstance(data, NDCube): + if data.unit is None: + raise ValueError("Input NDCube missing unit parameter") + + if spectral_axis is None: + raise ValueError("Spectral axis must be specified") + + # Change the data array from bare ndarray to a Quantity + q_data = data.data << u.Unit(data.unit) + + self.__init__( + q_data, wcs=data.wcs, mask=data.mask, uncertainty=data.uncertainty, spectral_axis=spectral_axis + ) + return + + self._spectral_axis_index = spectral_axis_index + # If here data should be an array or quantity + if spectral_axis_index is None and data is not None: + if data.ndim == 1: + self._spectral_axis_index = 0 + elif data is None: + self._spectral_axis_index = 0 # Ensure that the unit information codified in the quantity object is # the One True Unit. kwargs.setdefault("unit", data.unit if isinstance(data, u.Quantity) else kwargs.get("unit")) - # If flux and spectral axis are both specified, check that their lengths + # If a WCS is provided, determine which axis is the spectral axis + if wcs is not None: + if spectral_axis is None: + raise ValueError("Spectral axis must be specified") + + naxis = None + if hasattr(wcs, "naxis"): + naxis = wcs.naxis + # GWCS doesn't have naxis + elif hasattr(wcs, "world_n_dim"): + naxis = wcs.world_n_dim + + if naxis is not None and naxis > 1: + temp_axes = [] + phys_axes = wcs.world_axis_physical_types + if self._spectral_axis_index is None: + for i in range(len(phys_axes)): + if phys_axes[i] is None: + continue + if phys_axes[i][0:2] == "em" or phys_axes[i][0:5] == "spect" or phys_axes[i][7:12] == "Spect": + temp_axes.append(i) + if len(temp_axes) != 1: + raise ValueError( + f"Input WCS must have exactly one axis with spectral units, found {len(temp_axes)}" + ) + # Due to FITS conventions, the WCS axes are listed in opposite + # order compared to the data array. + self._spectral_axis_index = len(data.shape) - temp_axes[0] - 1 + + else: + if data is not None and data.ndim == 1: + self._spectral_axis_index = 0 + else: + if self.spectral_axis_index is None: + raise ValueError("WCS is 1D but flux is multi-dimensional. Please specify spectral_axis_index.") + + # If data and spectral axis are both specified, check that their lengths # match or are off by one (implying the spectral axis stores bin edges) + bin_specification = "centers" # default value if data is not None and spectral_axis is not None: - if spectral_axis.shape[0] == data.shape[spectral_dimension]: + if spectral_axis.shape[0] == data.shape[self.spectral_axis_index]: bin_specification = "centers" - elif spectral_axis.shape[0] == data.shape[spectral_dimension] + 1: + elif spectral_axis.shape[0] == data.shape[self.spectral_axis_index] + 1: bin_specification = "edges" else: raise ValueError( - f"Spectral axis length ({spectral_axis.shape[0]}) must be the same size or one " - "greater (if specifying bin edges) than that of the spextral" - f"axis ({data.shape[spectral_dimension]})" + f"Spectral axis length ({spectral_axis.shape[0]}) must be the " + "same size or one greater (if specifying bin edges) than that " + f"of the corresponding data axis ({data.shape[self.spectral_axis_index]})" ) # Attempt to parse the spectral axis. If none is given, try instead to # parse a given wcs. This is put into a GWCS object to # then be used behind-the-scenes for all operations. - if spectral_axis is not None: - # Ensure that the spectral axis is an astropy Quantity - if not isinstance(spectral_axis, u.Quantity): - raise ValueError("Spectral axis must be a `Quantity` or `SpectralAxis` object.") - - # If a spectral axis is provided as an astropy Quantity, convert it - # to a SpectralAxis object. - if not isinstance(spectral_axis, SpectralAxis): - if spectral_axis.shape[0] == data.shape[spectral_dimension] + 1: - bin_specification = "edges" - else: - bin_specification = "centers" - self._spectral_axis = SpectralAxis(spectral_axis, bin_specification=bin_specification) - - wcs = gwcs_from_array(self._spectral_axis) - - super().__init__( - data=data.value if isinstance(data, u.Quantity) else data, - wcs=wcs, - mask=mask, - meta=meta, - uncertainty=uncertainty, - **kwargs, - ) + + # Ensure that the spectral axis is an astropy Quantity or SpectralAxis + if not isinstance(spectral_axis, (u.Quantity, SpectralAxis)): + raise ValueError("Spectral axis must be a `Quantity` or `SpectralAxis` object.") + + # If spectral axis is provided as an astropy Quantity, convert it + # to a specutils SpectralAxis object. + if not isinstance(spectral_axis, SpectralAxis): + self._spectral_axis = SpectralAxis(spectral_axis, bin_specification=bin_specification) + # If a SpectralAxis object is provided, we assume it doesn't need + # information from other keywords added + else: + self._spectral_axis = spectral_axis + + # Check the spectral_axis matches the wcs + if wcs is not None: + wsc_coords = None + if hasattr(wcs, "spectral") and getattr(wcs, "is_spectral", False): + wcs_coords = wcs.spectral.pixel_to_world(np.arange(data.shape[self.spectral_axis_index])).to("keV") + elif wcs.pixel_n_dim == 1: + wcs_coords = wcs.pixel_to_world(np.arange(data.shape[self.spectral_axis_index])) + # else: + # array_index = wcs.pixel_n_dim - self._spectral_axis_index - 1 + # pixels = [0] * wcs.pixel_n_dim + # pixels[array_index] = np.arange(data.shape[self.spectral_axis_index]) + # wcs_coords = wcs.pixel_to_world(*pixels)[array_index] + if wsc_coords is not None: + if not u.allclose(self._spectral_axis, wcs_coords): + raise ValueError( + f"Spectral axis {self._spectral_axis} and wcs spectral axis {wcs_coords} must match." + ) + + if wcs is None: + wcs = gwcs_from_array(self._spectral_axis, data.shape, spectral_axis_index=self.spectral_axis_index) + + super().__init__(data=data.value if isinstance(data, u.Quantity) else data, wcs=wcs, **kwargs) + + # make sure that spectral axis is strictly increasing or strictly decreasing + is_strictly_increasing = np.all(self._spectral_axis[1:] > self._spectral_axis[:-1]) + if len(self._spectral_axis) > 1 and not (is_strictly_increasing): + raise ValueError("Spectral axis must be strictly increasing decreasing.") + + if hasattr(self, "uncertainty") and self.uncertainty is not None: + if not data.shape == self.uncertainty.array.shape: + raise ValueError( + f"Data axis ({data.shape}) and uncertainty ({self.uncertainty.array.shape}) shapes must be the " + "same." + ) + + def __getitem__(self, item): + sliced_cube = super().__getitem__(item) + item = tuple(sanitize_slices(item, len(self.shape))) + sliced_spec_axis = self.spectral_axis[item[self.spectral_axis_index]] + return Spectrum(sliced_cube, spectral_axis=sliced_spec_axis) + + def _slice(self, item): + kwargs = super()._slice(item) + item = tuple(sanitize_slices(item, len(self.shape))) + + kwargs["spectral_axis_index"] = self.spectral_axis_index + kwargs["spectral_axis"] = self.spectral_axis[item[self.spectral_axis_index]] + return kwargs + + def _new_instance(self, **kwargs): + keys = ("unit", "wcs", "mask", "meta", "uncertainty", "psf", "spectral_axis") + full_kwargs = {k: deepcopy(getattr(self, k, None)) for k in keys} + # We Explicitly DO NOT deepcopy any data + full_kwargs["data"] = self.data + full_kwargs.update(kwargs) + new_spectrum = type(self)(**full_kwargs) + if self.extra_coords is not None: + new_spectrum._extra_coords = deepcopy(self.extra_coords) + if self.global_coords is not None: + new_spectrum._global_coords = deepcopy(self.global_coords) + return new_spectrum + + @property + def spectral_axis(self): + return self._spectral_axis + + @property + def spectral_axis_index(self): + return self._spectral_axis_index \ No newline at end of file