diff --git a/simple_einet/einet.py b/simple_einet/einet.py index 5b5c471..a4f502d 100644 --- a/simple_einet/einet.py +++ b/simple_einet/einet.py @@ -91,13 +91,13 @@ def __init__(self, config: EinetConfig): # Construct the architecture self._build() - def forward(self, x: torch.Tensor, marginalization_mask: torch.Tensor = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, marginalized_scopes: torch.Tensor = None) -> torch.Tensor: """ Inference pass for the Einet model. Args: x (torch.Tensor): Input data of shape [N, C, D], where C is the number of input channels (useful for images) and D is the number of features/random variables (H*W for images). - marginalized_scope: torch.Tensor: (Default value = None) + marginalized_scopes: torch.Tensor: (Default value = None) Returns: Log-likelihood tensor of the input: p(X) or p(X | C) if number of classes > 1. @@ -111,10 +111,15 @@ def forward(self, x: torch.Tensor, marginalization_mask: torch.Tensor = None) -> x = x.view(x.shape[0], self.config.num_channels, -1) assert x.dim() == 3 - assert x.shape[1] == self.config.num_channels + assert ( + x.shape[1] == self.config.num_channels + ), f"Number of channels in input ({x.shape[1]}) does not match number of channels specified in config ({self.config.num_channels})." + assert ( + x.shape[2] == self.config.num_features + ), f"Number of features in input ({x.shape[0]}) does not match number of features specified in config ({self.config.num_features})." # Apply leaf distributions (replace marginalization indicators with 0.0 first) - x = self.leaf(x, marginalization_mask) + x = self.leaf(x, marginalized_scopes) # Pass through intermediate layers x = self._forward_layers(x)