From 134bbf9cbd782edab1bf8ab86707077125ccde56 Mon Sep 17 00:00:00 2001
From: Cong
Date: Wed, 14 Jun 2023 11:32:43 +0800
Subject: [PATCH] fix: move hydra heads resize_token_embeddings

move hydra heads and ref_model's resize_token_embeddings function calls
to AcceleratePPOTrainer
---
 trlx/trainer/accelerate_base_trainer.py | 5 -----
 trlx/trainer/accelerate_ppo_trainer.py  | 4 ++++
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index be7718768..d4d0023bf 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -73,11 +73,6 @@ def __init__(self, config, **kwargs):  # noqa: C901
         self.tokenizer.add_tokens(self.additional_tokens)
         # resize the model by-default
         self.model.base_model.resize_token_embeddings(len(self.tokenizer))
-        if hasattr(self.model, "frozen_head"):
-            self.model.frozen_head.resize_token_embeddings(len(self.tokenizer))
-        else:
-            # resize a reference model when hydra heads are not used
-            self.ref_model.resize_token_embeddings(len(self.tokenizer))
         self.tokenizer.padding_side = config.tokenizer.padding_side
         self.tokenizer.truncation_side = config.tokenizer.truncation_side
 
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index 938192067..0191b0869 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -70,8 +70,12 @@ def __init__(self, config: TRLConfig, **kwargs):
         # Setup a reference model when hydra heads are not used
         if not hasattr(self.model, "frozen_head"):
             self.ref_model = self.get_arch(self.config)
+            self.ref_model.resize_token_embeddings(len(self.tokenizer))
             self.ref_model.to(self.accelerator.device)
             self.ref_model.eval()
+        else:
+            # resize hydra heads
+            self.model.frozen_head.resize_token_embeddings(len(self.tokenizer))
 
         # Setup the KL controller
         # This helps prevent large divergences in the controller (policy)