Skip to content

Commit

Permalink
Update Pearl priornet with vmap
Browse files Browse the repository at this point in the history
Summary: Update the priornet implementation in pearl to use vmap for faster, parallelized ensemble forward passes.

Reviewed By: BillMatrix

Differential Revision: D60598963

fbshipit-source-id: 31f0fcd6c78993c866fd99999f87a62dad43f3cd
  • Loading branch information
Hong Jun Jeon authored and facebook-github-bot committed Aug 1, 2024
1 parent d1b22eb commit 55c5f54
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions pearl/neural_networks/common/epistemic_neural_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,31 @@ def __init__(
# Xavier uniform initialization
model.apply(init_weights)
models.append(model)
self.base_model: nn.Module = mlp_block(
self.input_dim, self.hidden_dims, self.output_dim
)
self.base_model.apply(init_weights)
self.base_model = self.base_model.to("meta")
self.models: nn.ModuleList = nn.ModuleList(models)

self.params: Dict[str, Any]
self.buffers: Dict[str, Any]
self.generate_params_buffers()

def generate_params_buffers(self) -> None:
    """
    Stack the per-model parameters and buffers of ``self.models`` so the
    whole priornet ensemble can be evaluated in a single vmapped call.

    Populates ``self.params`` and ``self.buffers`` with dicts of tensors
    whose leading dimension indexes the ensemble members.
    """
    stacked_params, stacked_buffers = torch.func.stack_module_state(self.models)
    self.params = stacked_params
    self.buffers = stacked_buffers

def call_single_model(
    self, params: Dict[str, Any], buffers: Dict[str, Any], data: Tensor
) -> Tensor:
    """
    Evaluate one ensemble member functionally for use with ``torch.vmap``.

    Runs ``self.base_model`` (the stateless template) with the supplied
    ``params`` and ``buffers`` via ``torch.func.functional_call``, so the
    call can be mapped over stacked ensemble state.
    """
    template = self.base_model
    return torch.func.functional_call(template, (params, buffers), (data,))

def forward(self, x: Tensor, z: Tensor) -> Tensor:
    """
    Perform forward pass on the priornet ensemble and weight by epistemic index.

    The ensemble is evaluated in one batched call: ``torch.vmap`` maps
    ``call_single_model`` over dim 0 of the stacked parameters and buffers
    while broadcasting the shared input ``x`` to every member.

    Args:
        x: input tensor, shared across all ensemble members.
        z: epistemic index tensor used to weight the per-model outputs.
    Output:
        ensemble output of x weighted by epistemic index vector z.
    """
    # in_dims=(0, 0, None): map over the ensemble dimension of params and
    # buffers; pass the same input batch to every member.
    outputs = torch.vmap(self.call_single_model, (0, 0, None))(
        self.params, self.buffers, x
    )
    # Per the einsum spec: outputs is (models i, batch j, out k) and z is
    # (batch j, models i); contract over i -> weighted sum of shape (j, k).
    return torch.einsum("ijk,ji->jk", outputs, z)


Expand Down

0 comments on commit 55c5f54

Please sign in to comment.