diff --git a/pearl/neural_networks/common/epistemic_neural_networks.py b/pearl/neural_networks/common/epistemic_neural_networks.py
index d9921aa1..b99ded4c 100644
--- a/pearl/neural_networks/common/epistemic_neural_networks.py
+++ b/pearl/neural_networks/common/epistemic_neural_networks.py
@@ -157,11 +157,31 @@ def __init__(
             # Xavier uniform initalization
             model.apply(init_weights)
             models.append(model)
+        self.base_model: nn.Module = mlp_block(
+            self.input_dim, self.hidden_dims, self.output_dim
+        )
+        self.base_model.apply(init_weights)
+        self.base_model = self.base_model.to("meta")
         self.models: nn.ModuleList = nn.ModuleList(models)
+        self.params: Dict[str, Any]
+        self.buffers: Dict[str, Any]
+        self.generate_params_buffers()
+
+    def generate_params_buffers(self) -> None:
+        """
+        Generate parameters and buffers for the priornet.
+        """
+        self.params, self.buffers = torch.func.stack_module_state(self.models)
+
+    def call_single_model(
+        self, params: Dict[str, Any], buffers: Dict[str, Any], data: Tensor
+    ) -> Tensor:
+        """
+        Method for parallelizing priornet forward passes with torch.vmap.
+        """
+        return torch.func.functional_call(self.base_model, (params, buffers), (data,))
 
     def forward(self, x: Tensor, z: Tensor) -> Tensor:
         """
         Perform forward pass on the priornet ensemble and weight by epistemic index
@@ -172,10 +192,9 @@ def forward(self, x: Tensor, z: Tensor) -> Tensor:
         Output:
             ensemble output of x weighted by epistemic index vector z.
         """
-        outputs = []
-        for model in self.models:
-            outputs.append(model(x))
-        outputs = torch.stack(outputs, dim=0)
+        outputs = torch.vmap(self.call_single_model, (0, 0, None))(
+            self.params, self.buffers, x
+        )
         return torch.einsum("ijk,ji->jk", outputs, z)
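
The new forward path follows PyTorch's standard torch.func model-ensembling recipe: per-member parameters and buffers are stacked along a new leading dimension with torch.func.stack_module_state, a copy of the architecture moved to the "meta" device acts as a stateless template, and torch.vmap maps torch.func.functional_call over the stacked state so every ensemble member runs in one batched call instead of a Python loop. Below is a minimal, self-contained sketch of that recipe; the nn.Linear members, ensemble_size, and dimensions are illustrative stand-ins, not Pearl's mlp_block or code from this PR.

import copy

import torch
import torch.nn as nn

ensemble_size, input_dim, output_dim = 4, 8, 2
models = [nn.Linear(input_dim, output_dim) for _ in range(ensemble_size)]

# Stack each member's parameters/buffers along a new leading ensemble dimension.
params, buffers = torch.func.stack_module_state(models)

# A "meta"-device copy carries only the architecture; weights are supplied per call.
base_model = copy.deepcopy(models[0]).to("meta")

def call_single_model(p, b, x):
    # Run the shared architecture with one slice of the stacked state.
    return torch.func.functional_call(base_model, (p, b), (x,))

x = torch.randn(16, input_dim)
# vmap over the ensemble dimension of params/buffers; broadcast x (in_dims=None).
outputs = torch.vmap(call_single_model, (0, 0, None))(params, buffers, x)
print(outputs.shape)  # torch.Size([4, 16, 2]): (ensemble_size, batch, output_dim)

In the Priornet change above, the same pattern produces outputs of shape (ensemble_size, batch, output_dim), which the existing einsum then weights by the epistemic index vector z.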