8-bit inference (#512) #513

Open · wants to merge 6 commits into base: main
64 changes: 64 additions & 0 deletions examples/ppo_sentiments_8bit.py
@@ -0,0 +1,64 @@
# Generates positive movie reviews by tuning a pretrained model on the IMDB dataset
# with a sentiment reward function
import json
import os
import sys
from typing import List

import torch
from datasets import load_dataset
from transformers import pipeline

import trlx
from trlx.data.default_configs import TRLConfig, default_ppo_config


def get_positive_score(scores):
    "Extract the value associated with a positive sentiment from the pipeline's output"
    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


def main(hparams={}):
    # Merge sweep config with default config if given
    config = TRLConfig.update(default_ppo_config().to_dict(), hparams)

    if torch.cuda.is_available():
        device = int(os.environ.get("LOCAL_RANK", 0))
    else:
        device = -1

    sentiment_fn = pipeline(
        "sentiment-analysis",
        "lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        device=device,
    )

    # Configure the model to be loaded in 8-bit precision
    config.model.from_pretrained_kwargs = {
        "load_in_8bit": True,
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
    }

    def reward_fn(samples: List[str], **kwargs) -> List[float]:
        sentiments = list(map(get_positive_score, sentiment_fn(samples)))
        return sentiments

    # Take a few words from each movie review as a prompt
    imdb = load_dataset("imdb", split="train+test")
    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

    trlx.train(
        reward_fn=reward_fn,
        prompts=prompts,
        eval_prompts=["I don't know much about Hungarian underground"] * 256,
        config=config,
    )


if __name__ == "__main__":
    hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])
    main(hparams)
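Note: for readers unfamiliar with the pipeline's output format, a small sketch of what sentiment_fn returns for a single sample and how get_positive_score reduces it to a scalar reward; the numeric scores below are invented, only the label/score structure comes from the transformers text-classification pipeline.

# Illustrative output of the "sentiment-analysis" pipeline with top_k=2 for one sample.
scores = [
    {"label": "NEGATIVE", "score": 0.08},
    {"label": "POSITIVE", "score": 0.92},
]

# dict(map(lambda x: tuple(x.values()), scores)) turns the list of
# {"label", "score"} entries into {"NEGATIVE": 0.08, "POSITIVE": 0.92},
# from which the positive probability is read out as the reward.
positive = dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]
assert positive == 0.92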
17 changes: 11 additions & 6 deletions tests/test_peft.py
@@ -1,5 +1,6 @@
import copy
import gc
import importlib
import os
import sys
import tempfile
@@ -400,7 +401,7 @@ def test_lora_modules_to_save(self):

peft_config = {
"peft_type": PeftType.LORA,
"task_type": CAUSAL,
"task_type": TaskType.CAUSAL_LM,
"r": 8,
"lora_alpha": 32,
"lora_dropout": 0.0,
@@ -436,11 +437,10 @@ def test_lora_modules_to_save(self):
loaded_model_logits = loaded_model(**self.inputs, return_dict=True).logits
self.assertTrue(torch.equal(trained_model_logits, loaded_model_logits))

# @unittest.skipUnless(
# importlib.util.find_spec("bitsandbytes") and torch.cuda.is_available(),
# "bitsandbytes and GPU needed to execute test_8bits",
# )
@unittest.skip("`8-bit` model loading support is not yet fully implemented")
@unittest.skipUnless(
importlib.util.find_spec("bitsandbytes") and torch.cuda.is_available(),
"bitsandbytes and GPU needed to execute test_8bits",
)
def test_8bits(self):
"""Test the behaviour of from_pretrained with 8 bits models"""
from bitsandbytes.nn import Linear8bitLt
@@ -487,3 +487,8 @@ def test_8bits(self):
self.assertEqual(new_nb_trainable_params, initial_nb_trainable_params)

self.assertIsInstance(model_8bit.base_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h, Linear8bitLt)

# Check that forward and generation work
self._create_inputs(model_id, CAUSAL)
model_8bit(**self.inputs)
model_8bit.generate(**self.inputs)
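Note: a minimal sketch of the loading path the re-enabled test exercises; the checkpoint name and LoRA settings below are assumptions for illustration, not values taken from this diff, and running it needs a CUDA device with bitsandbytes installed.

from bitsandbytes.nn import Linear8bitLt
from peft import PeftType, TaskType

from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead

# Assumed GPT-NeoX-style checkpoint and LoRA config, mirroring the dict form
# used elsewhere in test_peft.py.
peft_config = {
    "peft_type": PeftType.LORA,
    "task_type": TaskType.CAUSAL_LM,
    "r": 8,
    "lora_alpha": 32,
}

model_8bit = AutoModelForCausalLMWithHydraValueHead.from_pretrained(
    "EleutherAI/pythia-160m",
    peft_config=peft_config,
    load_in_8bit=True,  # forwarded to the underlying transformers from_pretrained
    device_map="auto",
)

# The quantized MLP projections should now be bitsandbytes Linear8bitLt modules.
layer = model_8bit.base_model.model.gpt_neox.layers[0].mlp.dense_h_to_4h
assert isinstance(layer, Linear8bitLt)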
5 changes: 5 additions & 0 deletions trlx/data/configs.py
@@ -59,12 +59,17 @@ class ModelConfig:

(parameter-efficient fine-tuning was previously done in trlx with OpenDelta, but it is no longer supported)
:type peft_config: Union[peft.PeftConfig, Dict[str, Any]]

:param from_pretrained_kwargs: Any additional arguments for PreTrainedModelWrapper.from_pretrained.
Can be used, for example, to load a model in 8-bit or 16-bit precision.
:type from_pretrained_kwargs: Dict[str, Any]
"""

model_path: str
model_arch_type: str = "causal"
num_layers_unfrozen: int = -1
peft_config: Any = None
from_pretrained_kwargs: Dict[str, Any] = field(default_factory=dict)

@classmethod
def from_dict(cls, config: Dict[str, Any]):
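Note: a minimal sketch of how the new field is meant to be used, mirroring examples/ppo_sentiments_8bit.py above; everything placed in model.from_pretrained_kwargs is forwarded to the model wrapper's from_pretrained call.

import torch

from trlx.data.default_configs import TRLConfig, default_ppo_config

config = TRLConfig.update(default_ppo_config().to_dict(), {})

# Keys here are passed through to PreTrainedModelWrapper.from_pretrained
# (and on to transformers); 8-bit loading additionally requires bitsandbytes.
config.model.from_pretrained_kwargs = {
    "load_in_8bit": True,
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}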
15 changes: 2 additions & 13 deletions trlx/models/modeling_base.py
@@ -70,13 +70,8 @@ def __init__(self, base_model: Optional[transformers.PreTrainedModel] = None, pe
super().__init__()
self.base_model = base_model
# cache `forward` args for general use (avoids incompatible args across architectures)
self.forward_kwargs = inspect.getfullargspec(self.base_model.forward).args
self.forward_kwargs = inspect.getfullargspec(self.base_model.__class__.forward).args
self.is_loaded_in_8bit = getattr(base_model, "is_loaded_in_8bit", False)
if self.is_loaded_in_8bit:
# TODO(glerzing): Fully test and support loading in 8-bit
raise NotImplementedError(
"`is_loaded_in_8bit` is an experimental feature not yet fully supported. Please do not use it."
)
self.peft_config = peft_config
self.peft_type = peft_config.peft_type if peft_config else None

@@ -176,12 +171,6 @@ def from_pretrained( # noqa: max-complexity
else:
is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False)

if is_loaded_in_8bit:
# TODO(glerzing): Fully test and support loading in 8-bit
raise NotImplementedError(
"`is_loaded_in_8bit` is an experimental feature not yet fully supported. Please do not use it."
)

if peft_config is not None:
if not is_peft_available():
raise ModuleNotFoundError("To use the argument peft_config, please install `peft`")
@@ -357,7 +346,7 @@ def post_init(self, *args, **kwargs):
# Don't use the interface of the peft model,
# use the interface of the underlying transformer model instead.
# (peft adds 2 "base_model" layers)
self.forward_kwargs = inspect.getfullargspec(self.base_model.base_model.base_model.forward).args
self.forward_kwargs = inspect.getfullargspec(self.base_model.base_model.base_model.__class__.forward).args

def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]:
"""Filter out arguments not supported by the specific instance of
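Note: introspecting self.base_model.__class__.forward rather than the bound self.base_model.forward keeps the cached argument names correct when the instance's forward has been swapped out for a wrapper, as accelerate does when it attaches dispatch hooks for device_map="auto" / 8-bit loading. A toy sketch of the failure mode, with made-up classes rather than trlx or accelerate code:

import inspect


class ToyModel:
    def forward(self, input_ids=None, attention_mask=None):
        return input_ids


model = ToyModel()


def hooked_forward(*args, **kwargs):
    # Stand-in for the generic wrapper a dispatch hook installs in place of forward.
    return ToyModel.forward(model, *args, **kwargs)


model.forward = hooked_forward

# The instance attribute now only exposes the wrapper's (*args, **kwargs)...
print(inspect.getfullargspec(model.forward).args)  # []
# ...while the class attribute still exposes the real argument names.
print(inspect.getfullargspec(model.__class__.forward).args)  # ['self', 'input_ids', 'attention_mask']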
14 changes: 10 additions & 4 deletions trlx/models/modeling_ppo.py
@@ -252,8 +252,8 @@ class CausalLMOutputWithValue(ModelOutput):
value: Optional[torch.FloatTensor] = None


def make_value_branch(base_model, num_value_layers_unfrozen):
value_head = make_head(hf_get_hidden_size(base_model.config), 1)
def make_value_branch(base_model, num_value_layers_unfrozen, dtype=torch.float32):
value_head = make_head(hf_get_hidden_size(base_model.config), 1, dtype)
if num_value_layers_unfrozen == 0:
return value_head
config = base_model.config
@@ -280,7 +280,10 @@ def __init__(
):
super().__init__(base_model, peft_config=peft_config)
self.num_value_layers_unfrozen = num_value_layers_unfrozen
self.v_head = make_value_branch(base_model, num_value_layers_unfrozen)
parameter = next(hf_get_lm_head(self.base_model).parameters())
dtype = parameter.dtype
device = parameter.device
self.v_head = make_value_branch(base_model, num_value_layers_unfrozen, dtype).to(device)

def forward(
self,
@@ -1209,9 +1212,12 @@ def __init__(
):
super().__init__(base_model, peft_config=peft_config)
# TODO: Support Seq2Seq value branching
parameter = next(hf_get_lm_head(self.base_model).parameters())
dtype = parameter.dtype
device = parameter.device
if num_value_layers_unfrozen > 0:
raise NotImplementedError("Value branches unsupported for Seq2Seq architecture")
self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1)
self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1, dtype).to(device)

def forward(
self,
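Note: a minimal sketch of the dtype/device alignment added above, with toy dimensions and plain nn modules standing in for hf_get_lm_head and make_value_branch. When the backbone is loaded with load_in_8bit=True and torch_dtype=torch.bfloat16, its lm_head typically keeps that dtype, so the freshly created value head is built to match.

import torch
import torch.nn as nn

# Stand-in for the language-model head of a backbone loaded in bfloat16
# (in the PR the real one comes from hf_get_lm_head(self.base_model)).
lm_head = nn.Linear(768, 50257, dtype=torch.bfloat16)

parameter = next(lm_head.parameters())
dtype, device = parameter.dtype, parameter.device

# Build the value head in the same dtype and move it to the same device,
# mirroring make_value_branch(base_model, num_value_layers_unfrozen, dtype).to(device).
v_head = nn.Sequential(
    nn.Linear(768, 768 * 2, dtype=dtype),
    nn.ReLU(),
    nn.Linear(768 * 2, 1, dtype=dtype),
).to(device)

assert next(v_head.parameters()).dtype == torch.bfloat16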
1 change: 1 addition & 0 deletions trlx/trainer/accelerate_ilql_trainer.py
@@ -132,6 +132,7 @@ def get_arch(self, config):
two_qs=config.method.two_qs,
alpha=config.method.alpha,
peft_config=self.config.model.peft_config,
**config.model.from_pretrained_kwargs,
)

def post_backward_callback(self):
1 change: 1 addition & 0 deletions trlx/trainer/accelerate_ppo_trainer.py
@@ -120,6 +120,7 @@ def get_arch(self, config: TRLConfig):
num_layers_unfrozen=config.model.num_layers_unfrozen,
num_value_layers_unfrozen=config.method.num_value_layers_unfrozen,
peft_config=self.config.model.peft_config,
**config.model.from_pretrained_kwargs,
)

def loss(self, batch: PPORLBatch):
2 changes: 1 addition & 1 deletion trlx/trainer/accelerate_sft_trainer.py
@@ -42,7 +42,7 @@ def get_arch(self, config):
if issubclass(type(config.model.model_path), PretrainedConfig):
from_fn = AutoModelForCausalLM.from_config

model = from_fn(config.model.model_path)
model = from_fn(config.model.model_path, **config.model.from_pretrained_kwargs)

if config.model.peft_config is not None:
# Initialize the peft adapter
2 changes: 1 addition & 1 deletion trlx/utils/modeling.py
@@ -10,7 +10,7 @@
import transformers


def make_head(n_embd: int, out: int, dtype: type = torch.float32) -> nn.Sequential:
def make_head(n_embd: int, out: int, dtype: torch.dtype = torch.float32) -> nn.Sequential:
"""Returns a generic sequential MLP head."""
return nn.Sequential(
nn.Linear(n_embd, n_embd * 2, dtype=dtype),
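Note: the annotation change above only corrects the type hint (torch.float32 is a torch.dtype instance, not a Python type); behaviour is unchanged. A small usage sketch, with an assumed hidden size in place of hf_get_hidden_size(config):

import torch

from trlx.utils.modeling import make_head

head = make_head(n_embd=768, out=1, dtype=torch.bfloat16)
print(next(head.parameters()).dtype)  # torch.bfloat16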