diff --git a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py index 3bab843..e15c713 100644 --- a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py +++ b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py @@ -199,13 +199,15 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: loss.backward() self._optimizer.step() - # Optimize linear regression - self.model._linear_regression_layer.learn_batch( - model_ret["nn_output"].detach(), - expected_values, - batch_weight, - ) - self._maybe_apply_discounting() + # Optimize linear regression + # n.b. this is also done for 0-weight batches to ensure parity across workers for the + # internal torch.distributed.all_reduce; it can otherwise lead to deadlocks. + self.model._linear_regression_layer.learn_batch( + model_ret["nn_output"].detach(), + expected_values, + batch_weight, + ) + self._maybe_apply_discounting() predicted_values = predicted_values.detach() # detach for logging return {