From 18e67f2015ab66da485b381ed6138dfa4eba7192 Mon Sep 17 00:00:00 2001 From: Alex Bird Date: Wed, 6 Nov 2024 15:47:22 -0800 Subject: [PATCH] NeuralLinearBandit deadlock avoidance Summary: When running in a distributed environment, the fix in D65428925 for zero-weight batches can lead to deadlock. This occurs because some workers are performing the LinearRegression update, and some are skipping it, but internally there is a `torch.distributed.allreduce`. I had previously assumed that my remote jobs were hanging due to an unrelated bug, but this seems to be the root cause. Distributed jobs complete successfully with this modification. Reviewed By: alexnikulkov Differential Revision: D65556041 fbshipit-source-id: 1a4ea7eb5211622d452ea5b843a2307e73fe2523 --- .../contextual_bandits/neural_linear_bandit.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py index 3bab843..e15c713 100644 --- a/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py +++ b/pearl/policy_learners/contextual_bandits/neural_linear_bandit.py @@ -199,13 +199,15 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]: loss.backward() self._optimizer.step() - # Optimize linear regression - self.model._linear_regression_layer.learn_batch( - model_ret["nn_output"].detach(), - expected_values, - batch_weight, - ) - self._maybe_apply_discounting() + # Optimize linear regression + # n.b. this is also done for 0-weight batches to ensure parity across workers for the + # the internal torch.distributed.allreduce; it can otherwise lead to deadlocks. + self.model._linear_regression_layer.learn_batch( + model_ret["nn_output"].detach(), + expected_values, + batch_weight, + ) + self._maybe_apply_discounting() predicted_values = predicted_values.detach() # detach for logging return {