Skip to content

Commit

Permalink
set deterministic=True when sampling in diffusion evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
allenzren committed Sep 26, 2024
1 parent 4962bbc commit dd14c58
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion agent/eval/eval_diffusion_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def run(self):
.float()
.to(self.device)
}
samples = self.model(cond=cond)
samples = self.model(cond=cond, deterministic=True)
output_venv = (
samples.trajectories.cpu().numpy()
) # n_env x horizon x act
Expand Down
2 changes: 1 addition & 1 deletion agent/eval/eval_diffusion_img_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def run(self):
key: torch.from_numpy(prev_obs_venv[key]).float().to(self.device)
for key in self.obs_dims
} # batch each type of obs and put into dict
samples = self.model(cond=cond)
samples = self.model(cond=cond, deterministic=True)
output_venv = (
samples.trajectories.cpu().numpy()
) # n_env x horizon x act
Expand Down

0 comments on commit dd14c58

Please sign in to comment.