From 624d9e00abe230aaffd0537f12c3510f29281af5 Mon Sep 17 00:00:00 2001 From: Hiroki Yoneda Date: Thu, 3 Oct 2024 17:36:15 +0200 Subject: [PATCH] Changed ExtendedSourceResponse as a child class of Histogram --- cosipy/response/ExtendedSourceResponse.py | 112 +++++----------------- 1 file changed, 24 insertions(+), 88 deletions(-) diff --git a/cosipy/response/ExtendedSourceResponse.py b/cosipy/response/ExtendedSourceResponse.py index 7b8cf098..425c4d62 100644 --- a/cosipy/response/ExtendedSourceResponse.py +++ b/cosipy/response/ExtendedSourceResponse.py @@ -1,82 +1,36 @@ from histpy import Histogram, Axes, Axis -import h5py as h5 import numpy as np -import sys -import astropy.units as u from .functions import get_integrated_extended_model -class ExtendedSourceResponse(object): +class ExtendedSourceResponse(Histogram): """ A class to represent and manipulate extended source response data. This class provides methods to load data from HDF5 files, access contents, units, and axes information, and calculate expectations based on sky models. - Attributes - ---------- - _contents : astropy.units.Quantity - The contents of the extended source response as a Quantity array - (numpy array with astropy units). - _unit : astropy.units.Unit - The unit of the contents. - _axes : Axes - The axes object representing the dimensions of the data. - Methods ------- - open(filename, name='hist') - Load data from an HDF5 file. get_expectation(allsky_image_model) Calculate expectation based on an all-sky image model. get_expectation_from_astromodel(source) Calculate expectation from an astronomical model source. + + Notes + ----- + Currently, the axes of the response must be ['NuLambda', 'Ei', 'Em', 'Phi', 'PsiChi']. """ - def __init__(self, contents = None, unit = None, axes = None): + def __init__(self, *args, **kwargs): """ Initialize an ExtendedSourceResponse object. """ - self._contents = contents - self._unit = unit - self._axes = axes - - @property - def contents(self): - """ - Get the contents of the extended source response. - - Returns - ------- - astropy.units.Quantity - The contents of the extended source response as a Quantity array - (numpy array with astropy units). - """ - return self._contents - - @property - def unit(self): - """ - Get the unit of the contents. - - Returns - ------- - astropy.units.Unit - The unit of the contents. - """ - return self._unit - - @property - def axes(self): - """ - Get the axes object. - - Returns - ------- - Axes - The axes object representing the dimensions of the data. - """ - return self._axes + super().__init__(*args, **kwargs) + + if not np.all(self.axes.labels == ['NuLambda', 'Ei', 'Em', 'Phi', 'PsiChi']): + # 'NuLambda' should be 'lb' if it is in the gal. coordinates? + raise ValueError(f"The input axes {self.axes.labels} is not supported by ExtendedSourceResponse class.") @classmethod def open(cls, filename, name='hist'): @@ -100,38 +54,20 @@ def open(cls, filename, name='hist'): ValueError If the shape of the contents does not match the axes. """ - new = cls() + hist = super().open(filename, name) + + axes = hist.axes + contents = hist[:] + sumw2 = hist.sumw2 + unit = hist.unit + track_overflow = False + + new = cls(axes, contents = contents, + sumw2 = sumw2, + unit = unit, + track_overflow = track_overflow) - with h5.File(filename, 'r') as f: - hist_group = f[name] - - # load axes - axes_group = hist_group['axes'] - - axes = [] - for axis in axes_group.values(): - if '__class__' in axis.attrs: - class_module, class_name = axis.attrs['__class__'] - axis_cls = getattr(sys.modules[class_module], class_name) - axes += [axis_cls._open(axis)] - - new._axes = Axes(axes) - - # load unit - if 'unit' in hist_group.attrs: - new._unit = u.Unit(hist_group.attrs['unit']) - - # load contents - contents = np.zeros(new.axes.nbins) - if np.all(new.axes.nbins == hist_group['contents'].shape): - contents = hist_group['contents'][:] - elif np.all(new.axes.nbins + 2 == hist_group['contents'].shape): - contents = hist_group['contents'][tuple(slice(1, -1) for _ in range(len(new.axes)))] - else: - raise ValueError - - new._contents = contents * new.unit - + del hist return new def get_expectation(self, allsky_image_model):