diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index 8481f0f54..5c1932ded 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -18,7 +18,15 @@ output_directory = Path("outputs/eval/example_pusht_diffusion") output_directory.mkdir(parents=True, exist_ok=True) -device = torch.device("cuda") +# Check if GPU is available +if torch.cuda.is_available(): + device = torch.device("cuda") + print("GPU is available. Device set to:", device) +else: + device = torch.device("cpu") + print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.") + # Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed) + policy.diffusion.num_inference_steps = 10 # Download the diffusion policy for pusht environment pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))