Skip to content

Commit

Permalink
Epinet in Pearl
Browse files Browse the repository at this point in the history
Summary: Implement epinet in the Pearl open source library.

Reviewed By: xuruiyang

Differential Revision: D59681898

fbshipit-source-id: ac437771c56fcf67c145bacf6eed25de5999da7b
  • Loading branch information
Hong Jun Jeon authored and facebook-github-bot committed Jul 16, 2024
1 parent 1827a08 commit 13e5ffc
Show file tree
Hide file tree
Showing 6 changed files with 366 additions and 18 deletions.
1 change: 1 addition & 0 deletions pearl/neural_networks/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
"ValueNetwork",
"VanillaCNN",
"VanillaValueNetwork",
"Epinet",
]
128 changes: 116 additions & 12 deletions pearl/neural_networks/common/epistemic_neural_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
"""

from abc import ABC, abstractmethod
from typing import List, Optional
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from pearl.neural_networks.common.utils import mlp_block
from pearl.neural_networks.common.utils import init_weights, mlp_block
from torch import Tensor


Expand All @@ -37,7 +37,7 @@ def __init__(
super(EpistemicNeuralNetwork, self).__init__()

@abstractmethod
def forward(self, x: Tensor, z: Optional[Tensor] = None) -> Tensor:
def forward(self, x: Tensor, z: Tensor) -> Tensor:
"""
Input:
x: Feature vector of state action pairs
Expand Down Expand Up @@ -116,22 +116,17 @@ def __init__(

self._resample_epistemic_index()

def forward(
self, x: Tensor, z: Optional[Tensor] = None, persistent: bool = False
) -> Tensor:
def forward(self, x: Tensor, z: Tensor, persistent: bool = False) -> Tensor:
"""
Input:
x: Feature vector of state action pairs
z: Single integer tensor. Ensemble epistemic index
Output:
posterior samples corresponding to z
"""
if z is not None:
assert z.flatten().shape[0] == 1
ensemble_index = int(z.item())
assert ensemble_index >= 0 and ensemble_index < self.ensemble_size
else:
ensemble_index = self.z
assert z.flatten().shape[0] == 1
ensemble_index = int(z.item())
assert ensemble_index >= 0 and ensemble_index < self.ensemble_size

return self.models[ensemble_index](x)

Expand All @@ -140,3 +135,112 @@ def forward(

def _resample_epistemic_index(self) -> None:
    """
    Draw a fresh epistemic index and store it on ``self.z``.

    The index is a one-element integer tensor sampled uniformly from
    ``[0, self.ensemble_size)``.
    """
    self.z = torch.randint(low=0, high=self.ensemble_size, size=(1,))


class Priornet(nn.Module):
    """
    Prior network for epinet. This network contains an ensemble of
    ``index_dim`` randomly initialized models which are held fixed during
    training. The ensemble output is the combination of the member outputs
    weighted by the entries of the epistemic index.
    """

    def __init__(
        self, input_dim: int, hidden_dims: List[int], output_dim: int, index_dim: int
    ) -> None:
        """
        Args:
            input_dim: input width of every member model.
            hidden_dims: hidden layer widths passed to ``mlp_block``.
            output_dim: output width of every member model.
            index_dim: dimension of the epistemic index; also the number of
                member models in the ensemble.
        """
        super(Priornet, self).__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim
        self.index_dim = index_dim
        models = []
        for _ in range(self.index_dim):
            model = mlp_block(self.input_dim, self.hidden_dims, self.output_dim)
            # Xavier uniform initialization
            model.apply(init_weights)
            models.append(model)
        self.models: nn.ModuleList = nn.ModuleList(models)
        # Stacked copies of the members' parameters and buffers.
        # NOTE: the stacked buffers are deliberately NOT stored under the name
        # `buffers`: that would replace the inherited `nn.Module.buffers()`
        # method with a dict, so any `module.buffers()` call on this instance
        # would raise `TypeError: 'dict' object is not callable`.
        self.params: Dict[str, Any]
        self.stacked_buffers: Dict[str, Any]
        self.params, self.stacked_buffers = torch.func.stack_module_state(
            self.models
        )

    def forward(self, x: Tensor, z: Tensor) -> Tensor:
        """
        Perform forward pass on the priornet ensemble and weight by epistemic index.
        x and z are assumed to already be formatted (one row of z per row of x).
        Input:
            x: input tensor; it is fed unchanged to every member model
            z: tensor of epistemic indices, one row per row of x and one
               column per member model
        Output:
            ensemble output of x weighted by epistemic index vector z:
            out[j] = sum_i z[j, i] * models[i](x)[j]
        """
        # Member outputs are stacked along a new leading (member) axis.
        outputs = torch.stack([model(x) for model in self.models], dim=0)
        return torch.einsum("ijk,ji->jk", outputs, z)


class Epinet(EpistemicNeuralNetwork):
    """
    Epinet: a trainable MLP over concatenated (input, epistemic index) rows
    plus a scaled `Priornet` whose output is computed without gradients.
    """

    def __init__(
        self,
        index_dim: int,
        input_dim: int,
        output_dim: int,
        num_indices: int,
        epi_hiddens: List[int],
        prior_hiddens: List[int],
        prior_scale: float,
    ) -> None:
        """
        Args:
            index_dim: dimension of a single epistemic index vector.
            input_dim: dimension of a single input feature vector.
            output_dim: dimension of the network output.
            num_indices: stored on the instance; not read elsewhere in this class.
            epi_hiddens: hidden layer widths of the trainable epinet MLP.
            prior_hiddens: hidden layer widths of each priornet member.
            prior_scale: multiplier applied to the priornet output.
        """
        super().__init__(input_dim, None, output_dim)
        self.index_dim = index_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_indices = num_indices
        self.epi_hiddens = epi_hiddens
        self.prior_hiddens = prior_hiddens
        self.prior_scale = prior_scale

        # Trainable Epinet: consumes [x, z] rows and emits one output vector
        # per index coordinate (index_dim * output_dim values per row).
        self.epinet: nn.Module = mlp_block(
            self.input_dim + self.index_dim,
            self.epi_hiddens,
            self.index_dim * self.output_dim,
        )
        self.epinet.apply(init_weights)

        # Priornet: consumes x rows only.
        self.priornet = Priornet(
            self.input_dim, self.prior_hiddens, self.output_dim, self.index_dim
        )

    def format_xz(self, x: Tensor, z: Tensor) -> Tensor:
        """
        Take cartesian product of x and z and concatenate for forward pass.
        Input:
            x: Feature vectors containing item and user embeddings and interactions
            z: Epinet epistemic indices
        Output:
            xz: Concatenated cartesian product of x and z, with
                batch_size * num_indices rows. Row (b * num_indices + n)
                holds x[b] followed by z[n].
        """
        batch_size, feature_dim = x.shape
        num_indices = z.shape[0]
        assert z.dim() == 2  # mirrors the original 2-way shape unpacking
        pair_shape = (batch_size, num_indices)
        x_tiled = x.unsqueeze(1).expand(*pair_shape, feature_dim)
        z_tiled = z.unsqueeze(0).expand(*pair_shape, self.index_dim)
        paired = torch.cat([x_tiled, z_tiled], dim=-1)
        return paired.view(batch_size * num_indices, feature_dim + self.index_dim)

    def forward(self, x: Tensor, z: Tensor, persistent: bool = False) -> Tensor:
        """
        Input:
            x: Feature vector containing item and user embeddings and interactions
            z: Matrix containing Epinet epistemic indices (one index per row)
            persistent: accepted for interface compatibility; not used here
        Output:
            posterior samples corresponding to z, one row per (x row, z row) pair
        """
        xz = self.format_xz(x, z)
        num_rows, _ = xz.shape
        # Split the concatenated rows back into their feature and index parts.
        x_cartesian = xz[:, : -self.index_dim]
        z_cartesian = xz[:, -self.index_dim :]

        # Trainable part: inputs are detached before entering the epinet MLP.
        raw = self.epinet(xz.detach())
        per_index = raw.view(num_rows, self.output_dim, self.index_dim)
        epinet_out = torch.einsum("ijk,ik->ij", per_index, z_cartesian)

        # Fixed prior: evaluated without tracking gradients.
        with torch.no_grad():
            priornet_out = self.prior_scale * self.priornet(x_cartesian, z_cartesian)

        return epinet_out + priornet_out
Original file line number Diff line number Diff line change
Expand Up @@ -527,17 +527,15 @@ def resample_epistemic_index(self) -> None:
r"""Resamples the epistemic index of the underlying model."""
self._model._resample_epistemic_index()

def forward(
self, x: Tensor, z: Optional[Tensor] = None, persistent: bool = False
) -> Tensor:
def forward(self, x: Tensor, z: Tensor, persistent: bool = False) -> Tensor:
return self._model(x, z=z, persistent=persistent)

def get_q_values(
self,
state_batch: Tensor,
action_batch: Tensor,
z: Tensor,
curr_available_actions_batch: Optional[Tensor] = None,
z: Optional[Tensor] = None,
persistent: bool = False,
) -> Tensor:
x = torch.cat([state_batch, action_batch], dim=-1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def act(

with torch.no_grad():
q_values = self.q_ensemble_network.get_q_values(
state_batch=states_repeated, action_batch=actions, persistent=True
state_batch=states_repeated,
action_batch=actions,
z=self.q_ensemble_network._model.z,
persistent=True,
)
# this does a forward pass since all available
# actions are already stacked together
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from pearl.action_representation_modules.action_representation_module import (
ActionRepresentationModule,
)
from pearl.api.action import Action

from pearl.api.action_space import ActionSpace
from pearl.api.state import SubjectiveState
from pearl.neural_networks.common.utils import update_target_network
from pearl.neural_networks.sequential_decision_making.q_value_networks import (
EnsembleQValueNetwork,
Expand All @@ -32,6 +34,7 @@
TransitionBatch,
TransitionWithBootstrapMaskBatch,
)
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
from torch import optim, Tensor


Expand Down Expand Up @@ -134,9 +137,50 @@ def reset(self, action_space: ActionSpace) -> None:
# Reset the `DeepExploration` module, which will resample the epistemic index.
self._exploration_module.reset()

def act(
    self,
    subjective_state: SubjectiveState,
    available_action_space: ActionSpace,
    exploit: bool = False,
) -> Action:
    """
    Choose an action for `subjective_state` from a discrete action space.

    Q-values for every available action are computed with the model's current
    epistemic index (`self._Q._model.z`). When `exploit` is True the greedy
    (argmax) action is returned directly; otherwise the greedy action and the
    Q-values are handed to the exploration module, which makes the final choice.
    """
    # Only discrete action spaces are supported.
    assert isinstance(available_action_space, DiscreteActionSpace)

    with torch.no_grad():
        # One copy of the state per available action.
        # (action_space_size x state_dim)
        tiled_states = torch.repeat_interleave(
            subjective_state.unsqueeze(0),
            available_action_space.n,
            dim=0,
        )

        # (action_space_size, action_dim)
        action_features = self._action_representation_module(
            available_action_space.actions_batch.to(tiled_states)
        )

        # A single forward pass scores every action, since all available
        # actions are already stacked together.
        q_values = self._Q.get_q_values(
            tiled_states, action_features, z=self._Q._model.z
        )

    greedy_index = torch.argmax(q_values)
    exploit_action = available_action_space.actions[greedy_index]

    if not exploit:
        assert self._exploration_module is not None
        return self._exploration_module.act(
            subjective_state=subjective_state,
            action_space=available_action_space,
            exploit_action=exploit_action,
            values=q_values,
        )

    return exploit_action

@torch.no_grad()
def _get_next_state_values(
self, batch: TransitionBatch, batch_size: int, z: Optional[Tensor] = None
self, batch: TransitionBatch, batch_size: int, z: Tensor
) -> torch.Tensor:
(
next_state,
Expand Down
Loading

0 comments on commit 13e5ffc

Please sign in to comment.