diff --git a/ctlearn/data_loader.py b/ctlearn/data_loader.py index b11188f6..800da56e 100644 --- a/ctlearn/data_loader.py +++ b/ctlearn/data_loader.py @@ -39,9 +39,10 @@ def __init__( self.event_list, self.obs_list = [], [] # Labels self.prt_pos, self.enr_pos, self.drc_pos = None, None, None + self.drc_unit = None self.prt_labels = [] self.enr_labels = [] - self.alt_labels, self.az_labels = [], [] + self.az_labels, self.alt_labels, self.sep_labels = [], [], [] self.trgpatch_labels = [] self.energy_unit = None @@ -68,6 +69,7 @@ def __init__( self.energy_unit = desc["unit"] elif "direction" in desc["name"]: self.drc_pos = i + self.drc_unit = desc["unit"] elif "event_id" in desc["name"]: self.evt_pos = i elif "obs_id" in desc["name"]: @@ -120,7 +122,7 @@ def __data_generation(self, batch_indices): if self.enr_pos is not None: energy = np.empty((self.batch_size)) if self.drc_pos is not None: - direction = np.empty((self.batch_size, 2)) + direction = np.empty((self.batch_size, 3)) if self.trgpatch_pos is not None: trigger_patches_true_image_sum = np.empty( (self.batch_size, *self.trgpatch_shape) @@ -157,10 +159,11 @@ def __data_generation(self, batch_indices): if self.prt_pos is not None: self.prt_labels.append(np.float32(event[self.prt_pos])) if self.enr_pos is not None: - self.enr_labels.append(np.float32(event[self.enr_pos][0])) + self.enr_labels.append(np.float32(event[self.enr_pos])) if self.drc_pos is not None: - self.alt_labels.append(np.float32(event[self.drc_pos][0])) - self.az_labels.append(np.float32(event[self.drc_pos][1])) + self.az_labels.append(np.float32(event[self.drc_pos][0])) + self.alt_labels.append(np.float32(event[self.drc_pos][1])) + self.sep_labels.append(np.float32(event[self.drc_pos][2])) if self.trgpatch_pos is not None: self.trgpatch_labels.append(np.float32(event[self.trgpatch_pos])) # Save all parameters for the prediction phase diff --git a/ctlearn/default_models/head.py b/ctlearn/default_models/head.py index a3d22f2b..ab681fc6 100644 --- a/ctlearn/default_models/head.py +++ b/ctlearn/default_models/head.py @@ -43,7 +43,7 @@ def standard_head(inputs, tasks, params): logits["direction"] = fully_connect( inputs, standard_head_settings["direction"]["fc_head"], - expected_logits_dimension=2, + expected_logits_dimension=3, name="direction", ) losses["direction"] = tf.keras.losses.MeanAbsoluteError( diff --git a/ctlearn/output_handler.py b/ctlearn/output_handler.py index 990c5c55..a3ab0b38 100644 --- a/ctlearn/output_handler.py +++ b/ctlearn/output_handler.py @@ -3,6 +3,8 @@ import numpy as np import pandas as pd from astropy.table import Table +from astropy.coordinates import SkyCoord +import astropy.units as u def write_output(h5file, data, rest_data, reader, predictions, tasks): @@ -90,43 +92,75 @@ def write_output(h5file, data, rest_data, reader, predictions, tasks): ) reco["true_energy"] = true_energy if "energy" in tasks: - if data.energy_unit == "log(TeV)" or np.min(predictions) < 0.0: + if data.energy_unit == "log(TeV)": reco["reco_energy"] = np.power(10, predictions)[:, 0] else: reco["reco_energy"] = np.array(predictions)[:, 0] # Arrival direction regression if data.drc_pos: - alt = data.alt_labels[data.batch_size :] - az = data.az_labels[data.batch_size :] + true_az_offset = data.az_labels[data.batch_size :] + true_alt_offset = data.alt_labels[data.batch_size :] + true_sep = data.sep_labels[data.batch_size :] if rest_data: - alt = ( - np.concatenate( - ( - alt, - rest_data.alt_labels[rest_data.batch_size :], - ), - axis=0, - ) + true_az_offset = np.concatenate( + ( + true_az_offset, + rest_data.az_labels[rest_data.batch_size :], + ), + axis=0, ) - az = ( - np.concatenate( - ( - az, - rest_data.az_labels[rest_data.batch_size :], - ), - axis=0, - ) + true_alt_offset = np.concatenate( + ( + true_alt_offset, + rest_data.alt_labels[rest_data.batch_size :], + ), + axis=0, ) - reco["true_alt"] = alt if reader.fix_pointing_alt is None else alt + reader.fix_pointing_alt - reco["true_az"] = az if reader.fix_pointing_az is None else az + reader.fix_pointing_az - + true_sep = np.concatenate( + ( + true_sep, + rest_data.sep_labels[rest_data.batch_size :], + ), + axis=0, + ) + reco["true_sep"] = true_sep + if reader.fix_pointing is None: + reco["true_az"] = true_az_offset + reco["true_alt"] = true_alt_offset + else: + true_az, true_alt = [], [] + for az_off, alt_off in zip(true_az_offset, true_alt_offset): + true_direction = reader.fix_pointing.spherical_offsets_by( + u.Quantity(az_off, unit=data.drc_unit), + u.Quantity(alt_off, unit=data.drc_unit), + ) + true_az.append(true_direction.az.to_value(data.drc_unit)) + true_alt.append(true_direction.alt.to_value(data.drc_unit)) + reco["true_az"] = np.array(true_az) + reco["true_alt"] = np.array(true_alt) if "direction" in tasks: - reco["reco_alt"] = np.array(predictions[:, 0]) if reader.fix_pointing_alt is None else np.array(predictions[:, 0]) + reader.fix_pointing_alt - reco["reco_az"] = np.array(predictions[:, 1]) if reader.fix_pointing_az is None else np.array(predictions[:, 1]) + reader.fix_pointing_az - if reader.fix_pointing_alt is not None: - reco["pointing_alt"] = np.full(len(reco["reco_alt"]), reader.fix_pointing_alt) - if reader.fix_pointing_az is not None: - reco["pointing_az"] = np.full(len(reco["reco_az"]), reader.fix_pointing_az) + if reader.fix_pointing is None: + reco["reco_az"] = np.array(predictions[:, 0]) + reco["reco_alt"] = np.array(predictions[:, 1]) + reco["reco_sep"] = np.array(predictions[:, 2]) + else: + reco_az, reco_alt = [], [] + for az_off, alt_off in zip(predictions[:, 0], predictions[:, 1]): + reco_direction = reader.fix_pointing.spherical_offsets_by( + u.Quantity(az_off, unit=data.drc_unit), + u.Quantity(alt_off, unit=data.drc_unit), + ) + reco_az.append(reco_direction.az.to_value(data.drc_unit)) + reco_alt.append(reco_direction.alt.to_value(data.drc_unit)) + reco["reco_az"] = np.array(reco_az) + reco["reco_alt"] = np.array(reco_alt) + reco["reco_sep"] = np.array(predictions[:, 2]) + reco["pointing_az"] = np.full( + len(reco["reco_az"]), reader.fix_pointing.az.to_value(data.drc_unit) + ) + reco["pointing_alt"] = np.full( + len(reco["reco_alt"]), reader.fix_pointing.alt.to_value(data.drc_unit) + ) if data.trgpatch_pos: cherenkov_photons = data.trgpatch_labels[data.batch_size :] diff --git a/ctlearn/utils.py b/ctlearn/utils.py index 3c8e3628..ab4b1ef9 100644 --- a/ctlearn/utils.py +++ b/ctlearn/utils.py @@ -123,11 +123,11 @@ def setup_DL1DataReader(config, mode): and mode == "predict" ): config["Data"]["parameter_settings"] = {"parameter_list": dl1bparameter_names} - if "direction" in tasks or mode == "predict": + if "direction" in tasks: if mc_file: event_info.append("true_alt") event_info.append("true_az") - transformations.append({"name": "DeltaAltAz"}) + transformations.append({"name": "SkyOffsetSeparation"}) if "cherenkov_photons" in tasks: if "trigger_settings" in config["Data"]: config["Data"]["trigger_settings"]["reco_cherenkov_photons"] = True @@ -143,7 +143,7 @@ def setup_DL1DataReader(config, mode): if "energy" in tasks or mode == "predict": if mc_file: event_info.append("true_energy") - transformations.append({"name": "MCEnergy"}) + transformations.append({"name": "LogEnergy"}) stack_telescope_images = config["Input"].get("stack_telescope_images", False) if config["Data"]["mode"] == "stereo" and not stack_telescope_images: