Skip to content

Commit

Permalink
transform back to coordinates when offsets are predicted
Browse files Browse the repository at this point in the history
using astropy function 'spherical_offsets_by'
  • Loading branch information
TjarkMiener committed Jul 4, 2024
1 parent 8cab507 commit c4d2472
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 30 deletions.
2 changes: 1 addition & 1 deletion ctlearn/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def __data_generation(self, batch_indices):
if self.prt_pos is not None:
self.prt_labels.append(np.float32(event[self.prt_pos]))
if self.enr_pos is not None:
self.enr_labels.append(np.float32(event[self.enr_pos][0]))
self.enr_labels.append(np.float32(event[self.enr_pos]))
if self.drc_pos is not None:
self.az_labels.append(np.float32(event[self.drc_pos][0]))
self.alt_labels.append(np.float32(event[self.drc_pos][1]))
Expand Down
90 changes: 62 additions & 28 deletions ctlearn/output_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import pandas as pd
from astropy.table import Table
from astropy.coordinates import SkyCoord
import astropy.units as u


def write_output(h5file, data, rest_data, reader, predictions, tasks):
Expand Down Expand Up @@ -90,43 +92,75 @@ def write_output(h5file, data, rest_data, reader, predictions, tasks):
)
reco["true_energy"] = true_energy
if "energy" in tasks:
if data.energy_unit == "log(TeV)" or np.min(predictions) < 0.0:
if data.energy_unit == "log(TeV)":
reco["reco_energy"] = np.power(10, predictions)[:, 0]
else:
reco["reco_energy"] = np.array(predictions)[:, 0]
# Arrival direction regression
if data.drc_pos:
alt = data.alt_labels[data.batch_size :]
az = data.az_labels[data.batch_size :]
true_az_offset = data.az_labels[data.batch_size :]
true_alt_offset = data.alt_labels[data.batch_size :]
true_sep = data.sep_labels[data.batch_size :]
if rest_data:
alt = (
np.concatenate(
(
alt,
rest_data.alt_labels[rest_data.batch_size :],
),
axis=0,
)
true_az_offset = np.concatenate(
(
true_az_offset,
rest_data.az_labels[rest_data.batch_size :],
),
axis=0,
)
az = (
np.concatenate(
(
az,
rest_data.az_labels[rest_data.batch_size :],
),
axis=0,
)
true_alt_offset = np.concatenate(
(
true_alt_offset,
rest_data.alt_labels[rest_data.batch_size :],
),
axis=0,
)
reco["true_alt"] = alt if reader.fix_pointing_alt is None else alt + reader.fix_pointing_alt
reco["true_az"] = az if reader.fix_pointing_az is None else az + reader.fix_pointing_az

true_sep = np.concatenate(
(
true_sep,
rest_data.sep_labels[rest_data.batch_size :],
),
axis=0,
)
reco["true_sep"] = true_sep
if reader.fix_pointing is None:
reco["true_az"] = true_az_offset
reco["true_alt"] = true_alt_offset
else:
true_az, true_alt = [], []
for az_off, alt_off in zip(true_az_offset, true_alt_offset):
true_direction = reader.fix_pointing.spherical_offsets_by(
u.Quantity(az_off, unit=data.drc_unit),
u.Quantity(alt_off, unit=data.drc_unit),
)
true_az.append(true_direction.az.to_value(data.drc_unit))
true_alt.append(true_direction.alt.to_value(data.drc_unit))
reco["true_az"] = np.array(true_az)
reco["true_alt"] = np.array(true_alt)
if "direction" in tasks:
reco["reco_alt"] = np.array(predictions[:, 0]) if reader.fix_pointing_alt is None else np.array(predictions[:, 0]) + reader.fix_pointing_alt
reco["reco_az"] = np.array(predictions[:, 1]) if reader.fix_pointing_az is None else np.array(predictions[:, 1]) + reader.fix_pointing_az
if reader.fix_pointing_alt is not None:
reco["pointing_alt"] = np.full(len(reco["reco_alt"]), reader.fix_pointing_alt)
if reader.fix_pointing_az is not None:
reco["pointing_az"] = np.full(len(reco["reco_az"]), reader.fix_pointing_az)
if reader.fix_pointing is None:
reco["reco_az"] = np.array(predictions[:, 0])
reco["reco_alt"] = np.array(predictions[:, 1])
reco["reco_sep"] = np.array(predictions[:, 2])
else:
reco_az, reco_alt = [], []
for az_off, alt_off in zip(predictions[:, 0], predictions[:, 1]):
reco_direction = reader.fix_pointing.spherical_offsets_by(
u.Quantity(az_off, unit=data.drc_unit),
u.Quantity(alt_off, unit=data.drc_unit),
)
reco_az.append(reco_direction.az.to_value(data.drc_unit))
reco_alt.append(reco_direction.alt.to_value(data.drc_unit))
reco["reco_az"] = np.array(reco_az)
reco["reco_alt"] = np.array(reco_alt)
reco["reco_sep"] = np.array(predictions[:, 2])
reco["pointing_az"] = np.full(
len(reco["reco_az"]), reader.fix_pointing.az.to_value(data.drc_unit)
)
reco["pointing_alt"] = np.full(
len(reco["reco_alt"]), reader.fix_pointing.alt.to_value(data.drc_unit)
)

if data.trgpatch_pos:
cherenkov_photons = data.trgpatch_labels[data.batch_size :]
Expand Down
2 changes: 1 addition & 1 deletion ctlearn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def setup_DL1DataReader(config, mode):
and mode == "predict"
):
config["Data"]["parameter_settings"] = {"parameter_list": dl1bparameter_names}
if "direction" in tasks or mode == "predict":
if "direction" in tasks:
if mc_file:
event_info.append("true_alt")
event_info.append("true_az")
Expand Down

0 comments on commit c4d2472

Please sign in to comment.