Skip to content

Commit

Permalink
Check for additional_loaders not None
Browse files Browse the repository at this point in the history
  • Loading branch information
QueensGambit committed Aug 6, 2024
1 parent 1b05aff commit bc0ad0e
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions DeepCrazyhouse/src/training/trainer_agent_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,9 @@ def train(self, cur_it=None):
# log the metric values to tensorboard
self._log_metrics(train_metric_values, global_step=self.k_steps, prefix="train_")
self._log_metrics(val_metric_values, global_step=self.k_steps, prefix="val_")
for dataset_name, metric_values in additional_metric_values.items():
self._log_metrics(metric_values, global_step=self.k_steps, prefix=f"{dataset_name}_")
if self.additional_loaders is not None:
for dataset_name, metric_values in additional_metric_values.items():
self._log_metrics(metric_values, global_step=self.k_steps, prefix=f"{dataset_name}_")

if self.tc.log_metrics_to_tensorboard and self.tc.export_grad_histograms:
grads = []
Expand Down Expand Up @@ -332,21 +333,22 @@ def evaluate(self, train_loader):

# do additional evaluations based on self.additional_loaders
additional_metric_values = dict()
for dataset_name, dataloader in self.additional_loaders.items():
print(f"starting {dataset_name} eval")
metric_values = evaluate_metrics(
self.to.metrics,
dataloader,
self._model,
nb_batches=None,
ctx=self._ctx,
phase_weights={k: 1.0 for k, v in self.to.phase_weights.items()}, # use no weighting
sparse_policy_label=self.tc.sparse_policy_label,
apply_select_policy_from_plane=self.tc.select_policy_from_plane and not self.tc.is_policy_from_plane_data,
use_wdl=self.tc.use_wdl,
use_plys_to_end=self.tc.use_plys_to_end,
)
additional_metric_values[dataset_name] = metric_values
if self.additional_loaders is not None:
for dataset_name, dataloader in self.additional_loaders.items():
print(f"starting {dataset_name} eval")
metric_values = evaluate_metrics(
self.to.metrics,
dataloader,
self._model,
nb_batches=None,
ctx=self._ctx,
phase_weights={k: 1.0 for k, v in self.to.phase_weights.items()}, # use no weighting
sparse_policy_label=self.tc.sparse_policy_label,
apply_select_policy_from_plane=self.tc.select_policy_from_plane and not self.tc.is_policy_from_plane_data,
use_wdl=self.tc.use_wdl,
use_plys_to_end=self.tc.use_plys_to_end,
)
additional_metric_values[dataset_name] = metric_values

self._model.train() # return back to training mode
return train_metric_values, val_metric_values, additional_metric_values
Expand Down

0 comments on commit bc0ad0e

Please sign in to comment.