Skip to content

Commit

Permalink
Fix Priornet Forward in Pearl
Browse files Browse the repository at this point in the history
Summary: Model save tests fail with the current vmap implementation of forward in PriorNet. Revert to for loop.

Reviewed By: rodrigodesalvobraz

Differential Revision: D62661270

fbshipit-source-id: 93f904f81189c35f04be55f661af12f0529e3c96
  • Loading branch information
Hong Jun Jeon authored and facebook-github-bot committed Sep 14, 2024
1 parent 479c959 commit 69577db
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions pearl/neural_networks/common/epistemic_neural_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,14 @@ def forward(self, x: Tensor, z: Tensor) -> Tensor:
Output:
ensemble output of x weighted by epistemic index vector z.
"""
outputs = torch.vmap(self.call_single_model, (0, 0, None))(
self.params, self.buffers, x
)
# vmap is not compatible with torchscript
# outputs = torch.vmap(self.call_single_model, (0, 0, None))(
# self.params, self.buffers, x
# )
outputs = []
for model in self.models:
outputs.append(model(x))
outputs = torch.stack(outputs, dim=0)
return torch.einsum("ijk,ji->jk", outputs, z)


Expand Down

0 comments on commit 69577db

Please sign in to comment.