
Commit

Fix distillation
Flova committed Jan 28, 2025
1 parent cb936c2 commit a116338
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions ddlitlab2024/ml/training/destill.py
@@ -42,7 +42,7 @@
 num_normalization_samples = 1000
 inference_denosing_timesteps = 30
 num_joints = 20
-checkpoint: str = "trajectory_transformer_model.pth"
+checkpoint: str = "trajectory_transformer_model_500_epoch_xmas.pth"
 
 # Load the dataset (primary for example conditioning)
 logger.info("Create dataset objects")
@@ -66,8 +66,7 @@
     worker_init_fn=worker_init_fn,
 )
 
-# Initialize the Transformer model and optimizer, and move model to device
-teacher_model = End2EndDiffusionTransformer(  # TODO enforce all params to be consistent with the dataset
+model_config = dict(  # TODO enforce all params to be consistent with the dataset
     num_joints=num_joints,
     hidden_dim=hidden_dim,
     use_action_history=True,
@@ -87,19 +86,26 @@
     max_image_context_length=image_context_length,
     num_decoder_layers=4,
     trajectory_prediction_length=trajectory_prediction_length,
-).to(device)
+)
+
+# Initialize the Transformer model and optimizer, and move model to device
+teacher_model = End2EndDiffusionTransformer(**model_config).to(device)
 
 # Utilize an Exponential Moving Average (EMA) for the model to smooth out the training process
 teacher_ema = EMA(teacher_model, beta=0.999)
 
 # Load the model if a checkpoint is provided
-if checkpoint is not None:
-    logger.info(f"Loading model from {checkpoint}")
-    teacher_ema.load_state_dict(torch.load(checkpoint, weights_only=True))
+logger.info(f"Loading model from {checkpoint}")
+teacher_ema.load_state_dict(torch.load(checkpoint, weights_only=True))
 
 # Clone the model
-student_model = teacher_model.clone()
+student_model = End2EndDiffusionTransformer(**model_config).to(device)
+
+# Load the same checkpoint into the student model
+# I load it from disk to avoid any potential issues when copying the model
+student_ema = EMA(student_model, beta=0.999)
+logger.info(f"Loading model from {checkpoint}")
+student_ema.load_state_dict(torch.load(checkpoint, weights_only=True))
 
 # Create optimizer and learning rate scheduler
 optimizer = torch.optim.AdamW(student_model.parameters(), lr=lr)
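
Note: the core of this change is building teacher and student from one shared keyword dict and initializing both from the same on-disk checkpoint, instead of cloning the teacher in memory. A minimal sketch of that pattern, assuming End2EndDiffusionTransformer is importable from the repository and EMA behaves as used above (the config keys and import paths here are illustrative, not verified):

import torch
from ema_pytorch import EMA  # assumption: the EMA wrapper used in the diff
# assumption: End2EndDiffusionTransformer is importable from the repo's model module

device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "trajectory_transformer_model_500_epoch_xmas.pth"

# Illustrative subset of the real config; destill.py passes the full parameter set.
model_config = dict(num_joints=20, hidden_dim=256)

teacher_model = End2EndDiffusionTransformer(**model_config).to(device)
student_model = End2EndDiffusionTransformer(**model_config).to(device)

teacher_ema = EMA(teacher_model, beta=0.999)
student_ema = EMA(student_model, beta=0.999)

# Loading the identical state dict from disk gives both networks bit-identical
# starting weights without relying on a deep copy of the teacher.
state = torch.load(checkpoint, weights_only=True)
teacher_ema.load_state_dict(state)
student_ema.load_state_dict(state)

Because both constructors consume the same **model_config, an architecture change only has to be made in one place and the student cannot silently diverge from the teacher's structure.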
@@ -139,14 +145,15 @@
 with torch.no_grad():
     # Predict the noise residual
     noise_pred = teacher_model.forward_with_context(
-        embedded_input, trajectory, torch.tensor([t], device=device)
+        embedded_input, trajectory, torch.full((joint_targets.size(0),), t, device=device)
     )
 
     # Update the trajectory based on the predicted noise and the current step of the denoising process
     trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample
 
 # Predict the denoised trajectory directly using the student model (null the timestep, as we are doing a single step prediction)
-student_trajectory_prediction = student_model(batch, noisy_trajectory, torch.zeros(joint_targets.size(0), device=device))
+student_trajectory_prediction = student_model.forward_with_context(
+    embedded_input, noisy_trajectory, torch.zeros(joint_targets.size(0), device=device))
 
 # Compute the loss
 loss = F.mse_loss(student_trajectory_prediction, trajectory)
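
Note: this hunk distills the teacher's multi-step denoising into a single student step. The teacher refines the trajectory over the full schedule under no_grad, with torch.full((B,), t) supplying one timestep per batch element (the old torch.tensor([t]) produced a single-element tensor that did not match the batch size), while the student predicts the final trajectory in one call at timestep zero. A condensed sketch, assuming a diffusers-style scheduler and that embedded_input, noisy_trajectory, joint_targets, and device come from the surrounding training loop as in the diff:

import torch
import torch.nn.functional as F

B = joint_targets.size(0)            # batch size
trajectory = noisy_trajectory        # teacher starts from the same noised sample the student sees

with torch.no_grad():
    for t in scheduler.timesteps:    # e.g. 30 inference denoising steps
        timesteps = torch.full((B,), t, device=device)  # one timestep per sample
        noise_pred = teacher_model.forward_with_context(embedded_input, trajectory, timesteps)
        trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample

# Single-step student prediction with the timestep nulled to zero
student_trajectory_prediction = student_model.forward_with_context(
    embedded_input, noisy_trajectory, torch.zeros(B, device=device)
)

# Regress the student's one-shot output onto the teacher's fully denoised result
loss = F.mse_loss(student_trajectory_prediction, trajectory)

Since the teacher runs under no_grad, gradients flow only through the student call, so the AdamW step updates the student's parameters alone.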