Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug in the temperature bin lengths #32

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions sunkit_dem/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,12 @@
cls._registry[cls] = cls.defines_model_for

@u.quantity_input
def __init__(self, data, kernel, temperature_bin_edges: u.K, **kwargs):
def __init__(self, data, kernel, temperature_bin_edges: u.K, kernel_temperatures=None, **kwargs):

Check warning on line 61 in sunkit_dem/base_model.py

View check run for this annotation

Codecov / codecov/patch

sunkit_dem/base_model.py#L61

Added line #L61 was not covered by tests
self.temperature_bin_edges = temperature_bin_edges
self.data = data
self.kernel_temperatures = kernel_temperatures
if self.kernel_temperatures is None:

Check warning on line 65 in sunkit_dem/base_model.py

View check run for this annotation

Codecov / codecov/patch

sunkit_dem/base_model.py#L64-L65

Added lines #L64 - L65 were not covered by tests
self.kernel_temperatures = self.temperature_bin_centers
self.kernel = kernel

@property
Expand All @@ -72,9 +75,7 @@
@property
@u.quantity_input
def temperature_bin_centers(self) -> u.K:
log_temperature = np.log10(self.temperature_bin_edges.value)
log_temperature_centers = (log_temperature[1:] + log_temperature[:-1])/2.
return u.Quantity(10.**log_temperature_centers, self.temperature_bin_edges.unit)
return (self.temperature_bin_edges[1:] + self.temperature_bin_edges[:-1])/2.

@property
def data(self) -> ndcube.NDCollection:
Expand All @@ -89,9 +90,23 @@
if not isinstance(data, ndcube.NDCollection):
raise ValueError('Input data must be an NDCollection')
if not all([hasattr(data[k], 'unit') for k in data]):
raise u.UnitsError('Each NDCube in NDCubeSequence must have units')
raise u.UnitsError('Each NDCube in NDCollection must have units')
self._data = data

@property
def combined_mask(self):
"""
Combined mask of all members of ``data``. Will be True if any member is masked.
This is propagated to the final DEM result
"""
combined_mask = []
for k in self._keys:

Check warning on line 103 in sunkit_dem/base_model.py

View check run for this annotation

Codecov / codecov/patch

sunkit_dem/base_model.py#L100-L103

Added lines #L100 - L103 were not covered by tests
if self.data[k].mask is not None:
combined_mask.append(self.data[k].mask)
else:

Check warning on line 106 in sunkit_dem/base_model.py

View check run for this annotation

Codecov / codecov/patch

sunkit_dem/base_model.py#L105-L106

Added lines #L105 - L106 were not covered by tests
combined_mask.append(np.full(self.data[k].shape, False))
return np.any(combined_mask, axis=0)

@property
def kernel(self):
return self._kernel
Expand All @@ -100,20 +115,20 @@
def kernel(self, kernel):
if len(kernel) != len(self.data):
raise ValueError('Number of kernels must be equal to length of wavelength dimension.')
if not all([v.shape == self.temperature_bin_centers.shape for _,v in kernel.items()]):
if not all([v.shape == self.kernel_temperatures.shape for _, v in kernel.items()]):

Check warning on line 118 in sunkit_dem/base_model.py

View check run for this annotation

Codecov / codecov/patch

sunkit_dem/base_model.py#L118

Added line #L118 was not covered by tests
raise ValueError('Temperature bin centers and kernels must have the same shape.')
self._kernel = kernel

@property
def data_matrix(self):
return np.stack([self.data[k].data*self.data[k].unit for k in self._keys])
return np.stack([self.data[k].data for k in self._keys])

@property
def kernel_matrix(self):
return np.stack([self.kernel[k] for k in self._keys])
return np.stack([self.kernel[k].value for k in self._keys])

def fit(self, *args, **kwargs):
"""
r"""
Apply inversion procedure to data.

Returns
Expand All @@ -126,9 +141,13 @@
dem_dict = self._model(*args, **kwargs)
wcs = self._make_dem_wcs()
meta = self._make_dem_meta()
dem = ndcube.NDCube(dem_dict.pop('dem'),
dem_data = dem_dict.pop('dem')
mask = np.full(dem_data.shape, False)

Check warning on line 145 in sunkit_dem/base_model.py

View check run for this annotation

Codecov / codecov/patch

sunkit_dem/base_model.py#L144-L145

Added lines #L144 - L145 were not covered by tests
mask[:,...] = self.combined_mask
dem = ndcube.NDCube(dem_data,
wcs,
meta=meta,
mask=mask,

Check warning on line 150 in sunkit_dem/base_model.py

View check run for this annotation

Codecov / codecov/patch

sunkit_dem/base_model.py#L150

Added line #L150 was not covered by tests
uncertainty=StdDevUncertainty(dem_dict.pop('uncertainty')))
cubes = [('dem', dem),]
for k in dem_dict:
Expand Down