From 22269341e597a235931f5bdebe700fae9a76879a Mon Sep 17 00:00:00 2001 From: Florian Vahl <7vahl@informatik.uni-hamburg.de> Date: Thu, 30 Jan 2025 11:29:33 +0100 Subject: [PATCH] Add destilled model to plot --- ddlitlab2024/ml/inference/plot.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/ddlitlab2024/ml/inference/plot.py b/ddlitlab2024/ml/inference/plot.py index 99411ce..5e9bba5 100644 --- a/ddlitlab2024/ml/inference/plot.py +++ b/ddlitlab2024/ml/inference/plot.py @@ -36,7 +36,8 @@ joint_state_context_length = 100 num_normalization_samples = 50 num_joints = 20 - checkpoint = "/homes/17vahl/ddlitlab2024/ddlitlab2024/ml/training/trajectory_transformer_model.pth" + checkpoint = "/homes/17vahl/ddlitlab2024/ddlitlab2024/ml/training/destilled_trajectory_transformer_model.pth" + distilled = True logger.info("Load model") model = End2EndDiffusionTransformer( @@ -110,15 +111,20 @@ noisy_trajectory = torch.randn_like(joint_targets).to(device) trajectory = noisy_trajectory - # Perform the denoising process - scheduler.set_timesteps(inference_denosing_timesteps) - for t in scheduler.timesteps: + if distilled: + # Directly predict the trajectory based on the noise with torch.no_grad(): - # Predict the noise residual - noise_pred = model(batch, trajectory, torch.tensor([t], device=device)) - - # Update the trajectory based on the predicted noise and the current step of the denoising process - trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample + trajectory = model(batch, noisy_trajectory, torch.tensor([0], device=device)) + else: + # Perform the denoising process + scheduler.set_timesteps(inference_denosing_timesteps) + for t in scheduler.timesteps: + with torch.no_grad(): + # Predict the noise residual + noise_pred = model(batch, trajectory, torch.tensor([t], device=device)) + + # Update the trajectory based on the predicted noise and the current step of the denoising process + trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample # Undo the normalization print(normalizer.mean)