Skip to content

Commit

Permalink
fix: move hydra heads resize_token_embeddings
Browse files Browse the repository at this point in the history
Move the hydra heads' and ref_model's resize_token_embeddings function calls to AcceleratePPOTrainer.
  • Loading branch information
congchan committed Jun 14, 2023
1 parent 23cffd4 commit 134bbf9
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
5 changes: 0 additions & 5 deletions trlx/trainer/accelerate_base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,6 @@ def __init__(self, config, **kwargs): # noqa: C901
self.tokenizer.add_tokens(self.additional_tokens)
# resize the model by-default
self.model.base_model.resize_token_embeddings(len(self.tokenizer))
if hasattr(self.model, "frozen_head"):
self.model.frozen_head.resize_token_embeddings(len(self.tokenizer))
else:
# resize a reference model when hydra heads are not used
self.ref_model.resize_token_embeddings(len(self.tokenizer))

self.tokenizer.padding_side = config.tokenizer.padding_side
self.tokenizer.truncation_side = config.tokenizer.truncation_side
Expand Down
4 changes: 4 additions & 0 deletions trlx/trainer/accelerate_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,12 @@ def __init__(self, config: TRLConfig, **kwargs):
# Setup a reference model when hydra heads are not used
if not hasattr(self.model, "frozen_head"):
self.ref_model = self.get_arch(self.config)
self.ref_model.resize_token_embeddings(len(self.tokenizer))
self.ref_model.to(self.accelerator.device)
self.ref_model.eval()
else:
# resize hydra heads
self.model.frozen_head.resize_token_embeddings(len(self.tokenizer))

# Setup the KL controller
# This helps prevent large divergences in the controller (policy)
Expand Down

0 comments on commit 134bbf9

Please sign in to comment.