From 75851264613aa316c5bb5bbe3bcae15d5947ecbd Mon Sep 17 00:00:00 2001 From: "Bruno M. Pacheco" Date: Fri, 13 Jan 2023 17:14:02 -0300 Subject: [PATCH] move requires_grad of env_params (fix #5) only enables grad recording after changing env_params's device --- imitation_nonconvex/il_exp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/imitation_nonconvex/il_exp.py b/imitation_nonconvex/il_exp.py index e7e2557..3d5f092 100755 --- a/imitation_nonconvex/il_exp.py +++ b/imitation_nonconvex/il_exp.py @@ -135,17 +135,16 @@ def __init__(self, **kwargs): if self.learn_dx: if self.env_name == 'pendulum': self.env_params = torch.tensor( - (15., 3., 0.5), requires_grad=True) + (15., 3., 0.5)) elif self.env_name == 'cartpole': self.env_params = torch.tensor( - (9.8, 3.0, 0.1, 1.0), requires_grad=True) + (9.8, 3.0, 0.1, 1.0)) elif self.env_name == 'pendulum-complex': # self.env_params = torch.tensor( - # (15., 3., 0.5), requires_grad=True) + # (15., 3., 0.5)) torch.manual_seed(self.seed) self.env_params = torch.tensor((5., 1., 1.)) + \ torch.tensor((3., 1., 1.))*(torch.rand(3)-0.5) - self.env_params.requires_grad_() # n_hidden = 256 # self.extra_dx = NNDynamics( @@ -155,6 +154,7 @@ def __init__(self, **kwargs): else: self.env_params = self.env.true_dx.params self.env_params = self.env_params.to(self.device) + self.env_params.requires_grad_() else: assert False