Skip to content

Commit

Permalink
properly treat direction task
Browse files Browse the repository at this point in the history
  • Loading branch information
TjarkMiener committed Jul 3, 2024
1 parent d4edbdc commit 8cab507
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
11 changes: 7 additions & 4 deletions ctlearn/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ def __init__(
self.event_list, self.obs_list = [], []
# Labels
self.prt_pos, self.enr_pos, self.drc_pos = None, None, None
self.drc_unit = None
self.prt_labels = []
self.enr_labels = []
self.alt_labels, self.az_labels = [], []
self.az_labels, self.alt_labels, self.sep_labels = [], [], []
self.trgpatch_labels = []
self.energy_unit = None

Expand All @@ -68,6 +69,7 @@ def __init__(
self.energy_unit = desc["unit"]
elif "direction" in desc["name"]:
self.drc_pos = i
self.drc_unit = desc["unit"]
elif "event_id" in desc["name"]:
self.evt_pos = i
elif "obs_id" in desc["name"]:
Expand Down Expand Up @@ -120,7 +122,7 @@ def __data_generation(self, batch_indices):
if self.enr_pos is not None:
energy = np.empty((self.batch_size))
if self.drc_pos is not None:
direction = np.empty((self.batch_size, 2))
direction = np.empty((self.batch_size, 3))
if self.trgpatch_pos is not None:
trigger_patches_true_image_sum = np.empty(
(self.batch_size, *self.trgpatch_shape)
Expand Down Expand Up @@ -159,8 +161,9 @@ def __data_generation(self, batch_indices):
if self.enr_pos is not None:
self.enr_labels.append(np.float32(event[self.enr_pos][0]))
if self.drc_pos is not None:
self.alt_labels.append(np.float32(event[self.drc_pos][0]))
self.az_labels.append(np.float32(event[self.drc_pos][1]))
self.az_labels.append(np.float32(event[self.drc_pos][0]))
self.alt_labels.append(np.float32(event[self.drc_pos][1]))
self.sep_labels.append(np.float32(event[self.drc_pos][2]))
if self.trgpatch_pos is not None:
self.trgpatch_labels.append(np.float32(event[self.trgpatch_pos]))
# Save all parameters for the prediction phase
Expand Down
2 changes: 1 addition & 1 deletion ctlearn/default_models/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def standard_head(inputs, tasks, params):
logits["direction"] = fully_connect(
inputs,
standard_head_settings["direction"]["fc_head"],
expected_logits_dimension=2,
expected_logits_dimension=3,
name="direction",
)
losses["direction"] = tf.keras.losses.MeanAbsoluteError(
Expand Down
4 changes: 2 additions & 2 deletions ctlearn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def setup_DL1DataReader(config, mode):
if mc_file:
event_info.append("true_alt")
event_info.append("true_az")
transformations.append({"name": "DeltaAltAz"})
transformations.append({"name": "SkyOffsetSeparation"})
if "cherenkov_photons" in tasks:
if "trigger_settings" in config["Data"]:
config["Data"]["trigger_settings"]["reco_cherenkov_photons"] = True
Expand All @@ -143,7 +143,7 @@ def setup_DL1DataReader(config, mode):
if "energy" in tasks or mode == "predict":
if mc_file:
event_info.append("true_energy")
transformations.append({"name": "MCEnergy"})
transformations.append({"name": "LogEnergy"})

stack_telescope_images = config["Input"].get("stack_telescope_images", False)
if config["Data"]["mode"] == "stereo" and not stack_telescope_images:
Expand Down

0 comments on commit 8cab507

Please sign in to comment.