Skip to content

Commit

Permalink
Make marginalized_scope variable name consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
braun-steven committed Nov 8, 2023
1 parent c0f7d4d commit 179cc72
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions simple_einet/einet.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ def __init__(self, config: EinetConfig):
# Construct the architecture
self._build()

def forward(self, x: torch.Tensor, marginalization_mask: torch.Tensor = None) -> torch.Tensor:
def forward(self, x: torch.Tensor, marginalized_scopes: torch.Tensor = None) -> torch.Tensor:
"""
Inference pass for the Einet model.
Args:
x (torch.Tensor): Input data of shape [N, C, D], where C is the number of input channels (useful for images) and D is the number of features/random variables (H*W for images).
marginalized_scope: torch.Tensor: (Default value = None)
marginalized_scopes: torch.Tensor: (Default value = None)
Returns:
Log-likelihood tensor of the input: p(X) or p(X | C) if number of classes > 1.
Expand All @@ -111,10 +111,15 @@ def forward(self, x: torch.Tensor, marginalization_mask: torch.Tensor = None) ->
x = x.view(x.shape[0], self.config.num_channels, -1)

assert x.dim() == 3
assert x.shape[1] == self.config.num_channels
assert (
x.shape[1] == self.config.num_channels
), f"Number of channels in input ({x.shape[1]}) does not match number of channels specified in config ({self.config.num_channels})."
assert (
x.shape[2] == self.config.num_features
), f"Number of features in input ({x.shape[0]}) does not match number of features specified in config ({self.config.num_features})."

# Apply leaf distributions (replace marginalization indicators with 0.0 first)
x = self.leaf(x, marginalization_mask)
x = self.leaf(x, marginalized_scopes)

# Pass through intermediate layers
x = self._forward_layers(x)
Expand Down

0 comments on commit 179cc72

Please sign in to comment.