NeuralLinearBandit deadlock avoidance
Summary:
When running in a distributed environment, the fix in D65428925 for zero-weight batches can lead to deadlock. This occurs because some workers perform the LinearRegression update while others skip it, even though the update internally issues a `torch.distributed.allreduce` that every worker must join.

I had previously assumed that my remote jobs were hanging due to an unrelated bug, but this seems to be the root cause. Distributed jobs complete successfully with this modification.

Reviewed By: alexnikulkov

Differential Revision: D65556041

fbshipit-source-id: 1a4ea7eb5211622d452ea5b843a2307e73fe2523
Alex Bird authored and facebook-github-bot committed Nov 6, 2024
1 parent f01b97e commit 18e67f2
Showing 1 changed file with 9 additions and 7 deletions.

pearl/policy_learners/contextual_bandits/neural_linear_bandit.py
@@ -199,13 +199,15 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
         loss.backward()
         self._optimizer.step()
 
-        # Optimize linear regression
-        self.model._linear_regression_layer.learn_batch(
-            model_ret["nn_output"].detach(),
-            expected_values,
-            batch_weight,
-        )
-        self._maybe_apply_discounting()
+        # Optimize linear regression
+        # n.b. this is also done for 0-weight batches to ensure parity across workers for
+        # the internal torch.distributed.allreduce; it can otherwise lead to deadlocks.
+        self.model._linear_regression_layer.learn_batch(
+            model_ret["nn_output"].detach(),
+            expected_values,
+            batch_weight,
+        )
+        self._maybe_apply_discounting()
 
         predicted_values = predicted_values.detach()  # detach for logging
         return {
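For illustration only (not code from the Pearl repository): a minimal, self-contained sketch of the failure mode, assuming a standard `torch.distributed` setup with the gloo backend and two local processes. Collectives such as `all_reduce` must be entered by every rank; if one rank skips the call on a zero-weight batch, the ranks that did call it block indefinitely. The function and variable names below are hypothetical.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def unsafe_update(batch_weight: torch.Tensor) -> None:
    # Mirrors the pre-fix behavior: zero-weight batches skip the update entirely,
    # so this rank never reaches the collective below.
    if batch_weight.sum().item() == 0:
        return
    # all_reduce is a collective: every rank must call it, or the callers hang.
    dist.all_reduce(batch_weight)


def safe_update(batch_weight: torch.Tensor) -> None:
    # Mirrors the fix: every rank always enters the collective; a zero-weight
    # batch simply contributes zeros to the reduction.
    dist.all_reduce(batch_weight)


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Rank 0 happens to receive a zero-weight batch; the other rank does not.
    batch_weight = torch.zeros(4) if rank == 0 else torch.ones(4)

    safe_update(batch_weight)        # completes on both ranks
    # unsafe_update(batch_weight)    # rank 1 would block forever here

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

The commit applies the same principle: `learn_batch` is now called unconditionally, so every worker reaches the internal allreduce, and a zero-weight batch is expected to contribute nothing to the update given the earlier zero-weight handling from D65428925.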
