From a7b5067b739d0b324835366cc5b505e20be30d0a Mon Sep 17 00:00:00 2001
From: Alain
Date: Sat, 24 Jun 2023 06:24:53 +0200
Subject: [PATCH 1/6] 8-bit inference (#512)

---
 tests/test_peft.py           | 17 +++++++++++------
 trlx/models/modeling_base.py | 15 ++-------------
 2 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/tests/test_peft.py b/tests/test_peft.py
index ffe6f8bcb..73c6e40f4 100644
--- a/tests/test_peft.py
+++ b/tests/test_peft.py
@@ -1,5 +1,6 @@
 import copy
 import gc
+import importlib
 import os
 import sys
 import tempfile
@@ -400,7 +401,7 @@ def test_lora_modules_to_save(self):
         peft_config = {
             "peft_type": PeftType.LORA,
-            "task_type": CAUSAL,
+            "task_type": TaskType.CAUSAL_LM,
             "r": 8,
             "lora_alpha": 32,
             "lora_dropout": 0.0,
@@ -436,11 +437,10 @@ def test_lora_modules_to_save(self):
         loaded_model_logits = loaded_model(**self.inputs, return_dict=True).logits
         self.assertTrue(torch.equal(trained_model_logits, loaded_model_logits))
 
-    # @unittest.skipUnless(
-    #     importlib.util.find_spec("bitsandbytes") and torch.cuda.is_available(),
-    #     "bitsandbytes and GPU needed to execute test_8bits",
-    # )
-    @unittest.skip("`8-bit` model loading support is not yet fully implemented")
+    @unittest.skipUnless(
+        importlib.util.find_spec("bitsandbytes") and torch.cuda.is_available(),
+        "bitsandbytes and GPU needed to execute test_8bits",
+    )
     def test_8bits(self):
         """Test the behaviour of from_pretrained with 8-bit models"""
         from bitsandbytes.nn import Linear8bitLt
@@ -487,3 +487,8 @@ def test_8bits(self):
 
         self.assertEqual(new_nb_trainable_params, initial_nb_trainable_params)
         self.assertIsInstance(model_8bit.base_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h, Linear8bitLt)
+
+        # Check that forward and generation work
+        self._create_inputs(model_id, CAUSAL)
+        model_8bit(**self.inputs)
+        model_8bit.generate(**self.inputs)
diff --git a/trlx/models/modeling_base.py b/trlx/models/modeling_base.py
index 3a6f2e062..01e95d38b 100644
--- a/trlx/models/modeling_base.py
+++ b/trlx/models/modeling_base.py
@@ -70,13 +70,8 @@ def __init__(self, base_model: Optional[transformers.PreTrainedModel] = None, pe
         super().__init__()
         self.base_model = base_model
         # cache `forward` args for general use (avoids incompatible args across architectures)
-        self.forward_kwargs = inspect.getfullargspec(self.base_model.forward).args
+        self.forward_kwargs = inspect.getfullargspec(self.base_model.__class__.forward).args
         self.is_loaded_in_8bit = getattr(base_model, "is_loaded_in_8bit", False)
-        if self.is_loaded_in_8bit:
-            # TODO(glerzing): Fully test and support loading in 8-bit
-            raise NotImplementedError(
-                "`is_loaded_in_8bit` is an experimental feature not yet fully supported. Please do not use it."
-            )
 
         self.peft_config = peft_config
         self.peft_type = peft_config.peft_type if peft_config else None
@@ -176,12 +171,6 @@ def from_pretrained(  # noqa: max-complexity
         else:
             is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False)
 
-        if is_loaded_in_8bit:
-            # TODO(glerzing): Fully test and support loading in 8-bit
-            raise NotImplementedError(
-                "`is_loaded_in_8bit` is an experimental feature not yet fully supported. Please do not use it."
-            )
-
         if peft_config is not None:
             if not is_peft_available():
                 raise ModuleNotFoundError("To use the argument peft_config, please install `peft`")
@@ -357,7 +346,7 @@ def post_init(self, *args, **kwargs):
             # Don't use the interface of the peft model,
             # use the interface of the underlying transformer model instead.
            # (peft adds 2 "base_model" layers)
-            self.forward_kwargs = inspect.getfullargspec(self.base_model.base_model.base_model.forward).args
+            self.forward_kwargs = inspect.getfullargspec(self.base_model.base_model.base_model.__class__.forward).args
 
     def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]:
         """Filter out arguments not supported by the specific instance of

From a81428a0446004f1b19faee202c338e60a7d6118 Mon Sep 17 00:00:00 2001
From: Alain
Date: Wed, 12 Jul 2023 14:19:37 +0200
Subject: [PATCH 2/6] Add an example + the ability to configure the from_pretrained arguments

---
 examples/ppo_sentiments_8bits.py        | 63 +++++++++++++++++++++++++
 trlx/data/configs.py                    |  5 ++
 trlx/models/modeling_ppo.py             | 10 ++--
 trlx/trainer/accelerate_ilql_trainer.py |  1 +
 trlx/trainer/accelerate_ppo_trainer.py  |  1 +
 trlx/trainer/accelerate_sft_trainer.py  |  2 +-
 6 files changed, 77 insertions(+), 5 deletions(-)
 create mode 100644 examples/ppo_sentiments_8bits.py

diff --git a/examples/ppo_sentiments_8bits.py b/examples/ppo_sentiments_8bits.py
new file mode 100644
index 000000000..e8fffa566
--- /dev/null
+++ b/examples/ppo_sentiments_8bits.py
@@ -0,0 +1,63 @@
+# Generates positive movie reviews by tuning a pretrained model on the IMDB dataset
+# with a sentiment reward function
+import json
+import os
+import sys
+from typing import List
+
+import torch
+from datasets import load_dataset
+from transformers import pipeline
+
+import trlx
+from trlx.data.default_configs import TRLConfig, default_ppo_config
+
+
+def get_positive_score(scores):
+    "Extract the value associated with a positive sentiment from the pipeline's output"
+    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]
+
+
+def main(hparams={}):
+    # Merge sweep config with default config if given
+    config = TRLConfig.update(default_ppo_config().to_dict(), hparams)
+
+    if torch.cuda.is_available():
+        device = int(os.environ.get("LOCAL_RANK", 0))
+    else:
+        device = -1
+
+    sentiment_fn = pipeline(
+        "sentiment-analysis",
+        "lvwerra/distilbert-imdb",
+        top_k=2,
+        truncation=True,
+        batch_size=256,
+        device=device,
+    )
+
+    # Configure the model to load in 8-bit
+    config.model.from_pretrained_kwargs = {
+        "load_in_8bit": True,
+        "device_map": "auto",
+    }
+
+    def reward_fn(samples: List[str], **kwargs) -> List[float]:
+        sentiments = list(map(get_positive_score, sentiment_fn(samples)))
+        return sentiments
+
+    # Take a few words off of movie reviews as prompts
+    imdb = load_dataset("imdb", split="train+test")
+    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
+
+    trlx.train(
+        reward_fn=reward_fn,
+        prompts=prompts,
+        eval_prompts=["I don't know much about Hungarian underground"] * 256,
+        config=config,
+    )
+
+
+if __name__ == "__main__":
+    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
+    main(hparams)
diff --git a/trlx/data/configs.py b/trlx/data/configs.py
index d2cb621e2..3c3b4b95a 100644
--- a/trlx/data/configs.py
+++ b/trlx/data/configs.py
@@ -59,12 +59,17 @@ class ModelConfig:
         (parameter-efficient fine-tuning was previously done in trlx with OpenDelta, but it is no longer supported)
 
     :type peft_config: Union[peft.PeftConfig, Dict[str, Any]]
+
+    :param from_pretrained_kwargs: Any additional arguments for PreTrainedModelWrapper.from_pretrained.
+        Can be used, for example, to load a model in 8-bit or 16-bit precision.
+    :type from_pretrained_kwargs: Dict[str, Any]
     """
 
     model_path: str
     model_arch_type: str = "causal"
     num_layers_unfrozen: int = -1
     peft_config: Any = None
+    from_pretrained_kwargs: Dict[str, Any] = field(default_factory=dict)
 
     @classmethod
     def from_dict(cls, config: Dict[str, Any]):
diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py
index ec377c062..8a0b5faf5 100644
--- a/trlx/models/modeling_ppo.py
+++ b/trlx/models/modeling_ppo.py
@@ -252,8 +252,8 @@ class CausalLMOutputWithValue(ModelOutput):
     value: Optional[torch.FloatTensor] = None
 
 
-def make_value_branch(base_model, num_value_layers_unfrozen):
-    value_head = make_head(hf_get_hidden_size(base_model.config), 1)
+def make_value_branch(base_model, num_value_layers_unfrozen, dtype):
+    value_head = make_head(hf_get_hidden_size(base_model.config), 1, dtype)
     if num_value_layers_unfrozen == 0:
         return value_head
     config = base_model.config
@@ -279,8 +279,10 @@ def __init__(
         num_value_layers_unfrozen=0,
     ):
         super().__init__(base_model, peft_config=peft_config)
-        self.num_value_layers_unfrozen = num_value_layers_unfrozen
-        self.v_head = make_value_branch(base_model, num_value_layers_unfrozen)
+        parameter = next(hf_get_lm_head(self.base_model).parameters())
+        dtype = parameter.dtype
+        device = parameter.device
+        self.v_head = make_value_branch(base_model, num_value_layers_unfrozen, dtype).to(device)
 
     def forward(
         self,
diff --git a/trlx/trainer/accelerate_ilql_trainer.py b/trlx/trainer/accelerate_ilql_trainer.py
index 60001ee55..117cc7122 100644
--- a/trlx/trainer/accelerate_ilql_trainer.py
+++ b/trlx/trainer/accelerate_ilql_trainer.py
@@ -132,6 +132,7 @@ def get_arch(self, config):
             two_qs=config.method.two_qs,
             alpha=config.method.alpha,
             peft_config=self.config.model.peft_config,
+            **config.model.from_pretrained_kwargs,
         )
 
     def post_backward_callback(self):
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index a7fcbb447..68da61c3f 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -120,6 +120,7 @@ def get_arch(self, config: TRLConfig):
             num_layers_unfrozen=config.model.num_layers_unfrozen,
             num_value_layers_unfrozen=config.method.num_value_layers_unfrozen,
             peft_config=self.config.model.peft_config,
+            **config.model.from_pretrained_kwargs,
         )
 
     def loss(self, batch: PPORLBatch):
diff --git a/trlx/trainer/accelerate_sft_trainer.py b/trlx/trainer/accelerate_sft_trainer.py
index 11c88a1c9..8b065c6b2 100644
--- a/trlx/trainer/accelerate_sft_trainer.py
+++ b/trlx/trainer/accelerate_sft_trainer.py
@@ -42,7 +42,7 @@ def get_arch(self, config):
         if issubclass(type(config.model.model_path), PretrainedConfig):
             from_fn = AutoModelForCausalLM.from_config
 
-        model = from_fn(config.model.model_path)
+        model = from_fn(config.model.model_path, **config.model.from_pretrained_kwargs)
 
         if config.model.peft_config is not None:
             # Initialize the peft adapter

From 44673376d2b171e1c02d554b1eee778d4dbc000a Mon Sep 17 00:00:00 2001
From: Alain
Date: Mon, 17 Jul 2023 23:53:57 +0200
Subject: [PATCH 3/6] Minor renaming

---
 examples/{ppo_sentiments_8bits.py => ppo_sentiments_8bit.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename examples/{ppo_sentiments_8bits.py => ppo_sentiments_8bit.py} (100%)

diff --git a/examples/ppo_sentiments_8bits.py b/examples/ppo_sentiments_8bit.py
similarity index 100%
rename from examples/ppo_sentiments_8bits.py
rename to examples/ppo_sentiments_8bit.py

From 4a6896bc39a335ea3b128b8f8c2a8be7b42837f6 Mon Sep 17 00:00:00 2001
From: Alain
Date: Sat, 22 Jul 2023 16:49:14 +0200
Subject: [PATCH 4/6] Setting torch_dtype to bfloat16

---
 examples/ppo_sentiments_8bit.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/ppo_sentiments_8bit.py b/examples/ppo_sentiments_8bit.py
index e8fffa566..9f5a1e4f3 100644
--- a/examples/ppo_sentiments_8bit.py
+++ b/examples/ppo_sentiments_8bit.py
@@ -39,6 +39,7 @@ def main(hparams={}):
     # Configure the model to load in 8-bit
     config.model.from_pretrained_kwargs = {
         "load_in_8bit": True,
+        "torch_dtype": torch.bfloat16,
         "device_map": "auto",
     }
 

From 9dae10c18b7a42bae20a17eac3c4e4f5aa99b268 Mon Sep 17 00:00:00 2001
From: Alain
Date: Mon, 7 Aug 2023 23:17:20 +0200
Subject: [PATCH 5/6] Fixes Seq2seq PPO + dtype default value

---
 trlx/models/modeling_ppo.py | 7 +++++--
 trlx/utils/modeling.py      | 2 +-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py
index 8a0b5faf5..f47e761fd 100644
--- a/trlx/models/modeling_ppo.py
+++ b/trlx/models/modeling_ppo.py
@@ -252,7 +252,7 @@ class CausalLMOutputWithValue(ModelOutput):
     value: Optional[torch.FloatTensor] = None
 
 
-def make_value_branch(base_model, num_value_layers_unfrozen, dtype):
+def make_value_branch(base_model, num_value_layers_unfrozen, dtype=torch.float32):
     value_head = make_head(hf_get_hidden_size(base_model.config), 1, dtype)
     if num_value_layers_unfrozen == 0:
         return value_head
@@ -1211,9 +1211,12 @@ def __init__(
     ):
         super().__init__(base_model, peft_config=peft_config)
         # TODO: Support Seq2Seq value branching
+        parameter = next(hf_get_lm_head(self.base_model).parameters())
+        dtype = parameter.dtype
+        device = parameter.device
         if num_value_layers_unfrozen > 0:
             raise NotImplementedError("Value branches unsupported for Seq2Seq architecture")
-        self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1)
+        self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1, dtype).to(device)
 
     def forward(
         self,
diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py
index 6e737c080..84134642b 100644
--- a/trlx/utils/modeling.py
+++ b/trlx/utils/modeling.py
@@ -10,7 +10,7 @@
 import transformers
 
 
-def make_head(n_embd: int, out: int, dtype: type = torch.float32) -> nn.Sequential:
+def make_head(n_embd: int, out: int, dtype: torch.dtype = torch.float32) -> nn.Sequential:
     """Returns a generic sequential MLP head."""
     return nn.Sequential(
         nn.Linear(n_embd, n_embd * 2, dtype=dtype),

From 5cc4a17550ae3f9a9b12158aad572afe3384ac88 Mon Sep 17 00:00:00 2001
From: Alain
Date: Tue, 8 Aug 2023 02:51:39 +0200
Subject: [PATCH 6/6] Missing attribute

---
 trlx/models/modeling_ppo.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py
index f47e761fd..5f866932d 100644
--- a/trlx/models/modeling_ppo.py
+++ b/trlx/models/modeling_ppo.py
@@ -279,6 +279,7 @@ def __init__(
         num_value_layers_unfrozen=0,
     ):
         super().__init__(base_model, peft_config=peft_config)
+        self.num_value_layers_unfrozen = num_value_layers_unfrozen
        parameter = next(hf_get_lm_head(self.base_model).parameters())
        dtype = parameter.dtype
        device = parameter.device
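
Usage note: after this series, loading the policy in 8-bit only requires setting
`from_pretrained_kwargs` on the model config; each trainer's `get_arch()` forwards
those kwargs to the wrapper's `from_pretrained`. A minimal sketch, not part of the
patches (the reward function and prompts below are placeholders, and `load_in_8bit`
needs `bitsandbytes` and a CUDA GPU):

import torch

import trlx
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
# Forwarded to the model's from_pretrained by the trainer's get_arch()
config.model.from_pretrained_kwargs = {
    "load_in_8bit": True,           # quantize eligible nn.Linear layers to int8
    "torch_dtype": torch.bfloat16,  # dtype for the modules left unquantized
    "device_map": "auto",           # let accelerate place the weights
}

# Placeholder reward (a real run would score sentiment, as in ppo_sentiments_8bit.py)
trlx.train(
    reward_fn=lambda samples, **kwargs: [float(len(s)) for s in samples],
    prompts=["The movie was"] * 8,
    config=config,
)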
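
For reference, the test enabled in PATCH 1/6 relies on `load_in_8bit=True` replacing
eligible `nn.Linear` modules with bitsandbytes' `Linear8bitLt`. A standalone sketch of
that check, assuming `bitsandbytes`, a CUDA GPU, and an illustrative GPT-NeoX
checkpoint (the test itself uses its own small fixture model):

from bitsandbytes.nn import Linear8bitLt
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-160m",  # illustrative model id, any GPT-NeoX checkpoint works
    load_in_8bit=True,
    device_map="auto",
)

# The MLP projection is now a quantized linear layer holding int8 weights
layer = model.gpt_neox.layers[0].mlp.dense_h_to_4h
assert isinstance(layer, Linear8bitLt)
print(layer.weight.dtype)  # torch.int8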
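
The value-head changes in PATCH 2/6, 5/6, and 6/6 all follow one pattern: read the
dtype and device off the base model's LM-head parameters and build the new head to
match, so an 8-bit or bf16 base model does not end up with a float32 value head on the
wrong device. A toy illustration of the pattern (stand-in modules, not the trlx code):

import torch
from torch import nn

lm_head = nn.Linear(64, 64, dtype=torch.bfloat16)  # stand-in for hf_get_lm_head(...)

# Match the new head's dtype and device to the existing parameters
parameter = next(lm_head.parameters())
v_head = nn.Sequential(
    nn.Linear(64, 128, dtype=parameter.dtype),
    nn.ReLU(),
    nn.Linear(128, 1, dtype=parameter.dtype),
).to(parameter.device)

assert next(v_head.parameters()).dtype == torch.bfloat16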