Skip to content

Commit

Permalink
Add destilled model to plot
Browse files Browse the repository at this point in the history
  • Loading branch information
Flova committed Jan 30, 2025
1 parent a116338 commit 2226934
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions ddlitlab2024/ml/inference/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
joint_state_context_length = 100
num_normalization_samples = 50
num_joints = 20
checkpoint = "/homes/17vahl/ddlitlab2024/ddlitlab2024/ml/training/trajectory_transformer_model.pth"
checkpoint = "/homes/17vahl/ddlitlab2024/ddlitlab2024/ml/training/destilled_trajectory_transformer_model.pth"
distilled = True

logger.info("Load model")
model = End2EndDiffusionTransformer(
Expand Down Expand Up @@ -110,15 +111,20 @@
noisy_trajectory = torch.randn_like(joint_targets).to(device)
trajectory = noisy_trajectory

# Perform the denoising process
scheduler.set_timesteps(inference_denosing_timesteps)
for t in scheduler.timesteps:
if distilled:
# Directly predict the trajectory based on the noise
with torch.no_grad():
# Predict the noise residual
noise_pred = model(batch, trajectory, torch.tensor([t], device=device))

# Update the trajectory based on the predicted noise and the current step of the denoising process
trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample
trajectory = model(batch, noisy_trajectory, torch.tensor([0], device=device))
else:
# Perform the denoising process
scheduler.set_timesteps(inference_denosing_timesteps)
for t in scheduler.timesteps:
with torch.no_grad():
# Predict the noise residual
noise_pred = model(batch, trajectory, torch.tensor([t], device=device))

# Update the trajectory based on the predicted noise and the current step of the denoising process
trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample

# Undo the normalization
print(normalizer.mean)
Expand Down

0 comments on commit 2226934

Please sign in to comment.