diff --git a/simple_einet/einet.py b/simple_einet/einet.py index 030f677..0c28013 100644 --- a/simple_einet/einet.py +++ b/simple_einet/einet.py @@ -275,7 +275,7 @@ def _build_input_distribution(self, num_features_out: int): num_channels=self.config.num_channels, num_leaves=self.config.num_leaves, num_repetitions=self.config.num_repetitions, - **self.config.leaf_kwargs, + **self.config.leaf_kwargs if self.config.leaf_kwargs is not None else {}, ) return FactorizedLeaf(