diff --git a/scripts/configs/optim.py b/scripts/configs/optim.py index d9b9b113..31a8a43d 100644 --- a/scripts/configs/optim.py +++ b/scripts/configs/optim.py @@ -25,14 +25,21 @@ def get_params( "params": [ val for key, val in parametrization.parameters.items() - if key != "logZ" + if "logZ" not in key ], "lr": self.lr, } ] if "logZ" in parametrization.parameters: params.append( - {"params": [parametrization.parameters["logZ"]], "lr": self.lr_Z} + { + "params": [ + val + for key, val in parametrization.parameters.items() + if "logZ" in key + ], + "lr": self.lr_Z, + } ) else: self.lr_Z = None