diff --git a/swirl_dynamics/projects/probabilistic_diffusion/colabs/demo.ipynb b/swirl_dynamics/projects/probabilistic_diffusion/colabs/demo.ipynb index 88ca0b7..fc793f3 100644 --- a/swirl_dynamics/projects/probabilistic_diffusion/colabs/demo.ipynb +++ b/swirl_dynamics/projects/probabilistic_diffusion/colabs/demo.ipynb @@ -973,7 +973,7 @@ "num_samples_per_cond = 5\n", "\n", "generate = jax.jit(\n", - " functools.partial(sampler.generate, num_samples_per_cond)\n", + " functools.partial(cond_sampler.generate, num_samples_per_cond)\n", ")" ] },