feat: add dpo
EdanToledo committed Feb 21, 2024
1 parent 0cf2e29 commit 4148d5e
Showing 5 changed files with 608 additions and 12 deletions.
7 changes: 7 additions & 0 deletions stoix/configs/default_ff_dpo_continuous.yaml
@@ -0,0 +1,7 @@
defaults:
- logger: ff_ppo
- arch: anakin
- system: ff_dpo
- network: mlp_continuous
- env: brax/ant
- _self_
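
The defaults list above composes the experiment from existing Stoix config groups (the ff_ppo logger, the anakin architecture, the continuous MLP network and the brax/ant environment), swapping in only the new ff_dpo system config. A minimal sketch, assuming Hydra is used for composition as the defaults list suggests, of how the merged config could be inspected from the repository root (the snippet is illustrative and not part of this commit):

from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the DPO experiment config the same way a training entry point would.
with initialize(version_base=None, config_path="stoix/configs"):
    cfg = compose(config_name="default_ff_dpo_continuous")

# Print the fully merged configuration, including the ff_dpo system block.
print(OmegaConf.to_yaml(cfg))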
23 changes: 23 additions & 0 deletions stoix/configs/system/ff_dpo.yaml
@@ -0,0 +1,23 @@
# --- Defaults FF-DPO ---

total_timesteps: 1e8 # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: ~ # Number of updates
seed: 42

# --- RL hyperparameters ---
actor_lr: 3e-4 # Learning rate for actor network
critic_lr: 3e-4 # Learning rate for critic network
update_batch_size: 1 # Number of vectorised gradient updates per device.
rollout_length: 16 # Number of environment steps per vectorised environment.
ppo_epochs: 4 # Number of ppo epochs per training data batch.
num_minibatches: 16 # Number of minibatches per ppo epoch.
gamma: 0.99 # Discounting factor.
gae_lambda: 0.95 # Lambda value for GAE computation.
clip_eps: 0.2 # Clipping value for PPO updates and value function.
ent_coef: 0.001 # Entropy regularisation term for loss function.
vf_coef: 1.0 # Critic (value function) weight in the total loss.
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
alpha: 2.0 # Alpha parameter of the DPO drift function.
beta: 0.6 # Beta parameter of the DPO drift function.
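
The alpha and beta defaults above (2.0 and 0.6) match the drift function from Discovered Policy Optimisation (Lu et al., 2022), which replaces PPO's ratio clipping with a learned drift penalty. A minimal sketch of that surrogate loss in JAX, assuming this is the variant implemented in the commit (the function name dpo_policy_loss is illustrative, not taken from the diff):

import jax.numpy as jnp
from jax.nn import relu

def dpo_policy_loss(ratio, advantage, alpha=2.0, beta=0.6):
    # Drift term applied when the advantage is positive.
    drift_pos = relu((ratio - 1.0) * advantage
                     - alpha * jnp.tanh((ratio - 1.0) * advantage / alpha))
    # Drift term applied when the advantage is negative.
    drift_neg = relu(jnp.log(ratio) * advantage
                     - beta * jnp.tanh(jnp.log(ratio) * advantage / beta))
    drift = jnp.where(advantage >= 0.0, drift_pos, drift_neg)
    # Maximise the drift-corrected surrogate, i.e. minimise its negation.
    return -jnp.mean(ratio * advantage - drift)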