diff --git a/DeepCrazyhouse/src/training/trainer_agent_pytorch.py b/DeepCrazyhouse/src/training/trainer_agent_pytorch.py index 26dff358..eeb177fc 100644 --- a/DeepCrazyhouse/src/training/trainer_agent_pytorch.py +++ b/DeepCrazyhouse/src/training/trainer_agent_pytorch.py @@ -186,8 +186,9 @@ def train(self, cur_it=None): # log the metric values to tensorboard self._log_metrics(train_metric_values, global_step=self.k_steps, prefix="train_") self._log_metrics(val_metric_values, global_step=self.k_steps, prefix="val_") - for dataset_name, metric_values in additional_metric_values.items(): - self._log_metrics(metric_values, global_step=self.k_steps, prefix=f"{dataset_name}_") + if self.additional_loaders is not None: + for dataset_name, metric_values in additional_metric_values.items(): + self._log_metrics(metric_values, global_step=self.k_steps, prefix=f"{dataset_name}_") if self.tc.log_metrics_to_tensorboard and self.tc.export_grad_histograms: grads = [] @@ -332,21 +333,22 @@ def evaluate(self, train_loader): # do additional evaluations based on self.additional_loaders additional_metric_values = dict() - for dataset_name, dataloader in self.additional_loaders.items(): - print(f"starting {dataset_name} eval") - metric_values = evaluate_metrics( - self.to.metrics, - dataloader, - self._model, - nb_batches=None, - ctx=self._ctx, - phase_weights={k: 1.0 for k, v in self.to.phase_weights.items()}, # use no weighting - sparse_policy_label=self.tc.sparse_policy_label, - apply_select_policy_from_plane=self.tc.select_policy_from_plane and not self.tc.is_policy_from_plane_data, - use_wdl=self.tc.use_wdl, - use_plys_to_end=self.tc.use_plys_to_end, - ) - additional_metric_values[dataset_name] = metric_values + if self.additional_loaders is not None: + for dataset_name, dataloader in self.additional_loaders.items(): + print(f"starting {dataset_name} eval") + metric_values = evaluate_metrics( + self.to.metrics, + dataloader, + self._model, + nb_batches=None, + ctx=self._ctx, + phase_weights={k: 1.0 for k, v in self.to.phase_weights.items()}, # use no weighting + sparse_policy_label=self.tc.sparse_policy_label, + apply_select_policy_from_plane=self.tc.select_policy_from_plane and not self.tc.is_policy_from_plane_data, + use_wdl=self.tc.use_wdl, + use_plys_to_end=self.tc.use_plys_to_end, + ) + additional_metric_values[dataset_name] = metric_values self._model.train() # return back to training mode return train_metric_values, val_metric_values, additional_metric_values