From 8dc524f321b9c122b46fe9d5b1e8205f9568c038 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 15 Aug 2024 13:59:47 +0100 Subject: [PATCH] fix bug in example 2 (#361) --- examples/2_evaluate_pretrained_policy.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index 5c1932ded..b2fe1dba1 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -18,6 +18,14 @@ output_directory = Path("outputs/eval/example_pusht_diffusion") output_directory.mkdir(parents=True, exist_ok=True) +# Download the diffusion policy for pusht environment +pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht")) +# OR uncomment the following to evaluate a policy from the local outputs/train folder. +# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion") + +policy = DiffusionPolicy.from_pretrained(pretrained_policy_path) +policy.eval() + # Check if GPU is available if torch.cuda.is_available(): device = torch.device("cuda") @@ -28,13 +36,6 @@ # Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed) policy.diffusion.num_inference_steps = 10 -# Download the diffusion policy for pusht environment -pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht")) -# OR uncomment the following to evaluate a policy from the local outputs/train folder. -# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion") - -policy = DiffusionPolicy.from_pretrained(pretrained_policy_path) -policy.eval() policy.to(device) # Initialize evaluation environment to render two observation types: