[update] add support for new models
Immortalise committed Dec 24, 2023
1 parent b767e4e commit 7160c4b
Showing 2 changed files with 117 additions and 5 deletions.
promptbench/models/__init__.py (6 additions, 1 deletion)

@@ -13,6 +13,11 @@
VicunaModel: ['vicuna-7b', 'vicuna-13b', 'vicuna-13b-v1.3'],
UL2Model: ['google/flan-ul2'],
GeminiModel: ['gemini-pro'],
MistralModel: ['mistralai/Mistral-7B-v0.1', 'mistralai/Mistral-7B-Instruct-v0.1'],
MixtralModel: ['mistralai/Mixtral-8x7B-v0.1'],
YiModel: ['01-ai/Yi-6B', '01-ai/Yi-34B', '01-ai/Yi-6B-Chat', '01-ai/Yi-34B-Chat'],
BaichuanModel: ['baichuan-inc/Baichuan2-7B-Base', 'baichuan-inc/Baichuan2-13B-Base',
'baichuan-inc/Baichuan2-7B-Chat', 'baichuan-inc/Baichuan2-13B-Chat'],
}

SUPPORTED_MODELS = [model for model_class in MODEL_LIST.keys() for model in MODEL_LIST[model_class]]
@@ -151,4 +156,4 @@ def _other_concat_prompts(self, prompt_list):

def __call__(self, input_text, **kwargs):
"""Predicts the output based on the given input text using the loaded model."""
-        return self.model.predict(input_text, **kwargs)
\ No newline at end of file
+        return self.model.predict(input_text, **kwargs)
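
For reference, MODEL_LIST maps each wrapper class to the model names it serves, and SUPPORTED_MODELS flattens that mapping. A hypothetical helper (not part of this commit) shows how a supported name resolves back to its wrapper class:

# Hypothetical helper, not in the diff: look a model name up in the
# MODEL_LIST registry shown above and return its wrapper class.
def resolve_model_class(model_name):
    for model_class, names in MODEL_LIST.items():
        if model_name in names:
            return model_class
    raise ValueError(f"{model_name} is not in SUPPORTED_MODELS")

# e.g. resolve_model_class('mistralai/Mistral-7B-v0.1') -> MistralModel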
promptbench/models/models.py (111 additions, 4 deletions)

@@ -40,6 +40,7 @@ def predict(self, input_text, **kwargs):
outputs = self.model.generate(input_ids,
max_new_tokens=self.max_new_tokens,
temperature=self.temperature,
do_sample=True,
**kwargs)

out = self.tokenizer.decode(outputs[0])
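
The one-line change above turns on sampling in the shared predict() path: Hugging Face generate() only honors temperature when do_sample=True (recent transformers versions warn and fall back to greedy decoding otherwise), and the temperature must then be strictly positive. A minimal stand-alone sketch of that behavior, using gpt2 as an illustrative checkpoint that is not part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
# Without do_sample=True, generate() decodes greedily and ignores temperature;
# with it, temperature must be a strictly positive float.
outputs = model.generate(input_ids, max_new_tokens=8,
                         do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0]))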
@@ -49,6 +50,112 @@ def __call__(self, input_text, **kwargs):
return self.predict(input_text, **kwargs)


class BaichuanModel(LMMBaseModel):
"""
Language model class for the Baichuan model.
Inherits from LMMBaseModel and sets up the Baichuan language model for use.
Parameters:
-----------
model : str
The name of the Baichuan model.
max_new_tokens : int
The maximum number of new tokens to be generated.
temperature : float, optional
The temperature for text generation (default is 0).
device: str
The device to use for inference (default is 'auto').
Methods:
--------
predict(input_text, **kwargs)
Generates a prediction based on the input text.
"""
def __init__(self, model_name, max_new_tokens, temperature, device, dtype):
super(BaichuanModel, self).__init__(model_name, max_new_tokens, temperature, device)
from transformers import AutoTokenizer, AutoModelForCausalLM
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, torch_dtype=dtype, device_map=device, use_fast=False, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=dtype, device_map=device, trust_remote_code=True)
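
Unlike the other loaders in this commit, the Baichuan one passes trust_remote_code=True because the Baichuan2 checkpoints ship custom modeling and tokenization code, and use_fast=False to skip the fast-tokenizer path. A minimal stand-alone sketch of the same load (repo id taken from MODEL_LIST above; device_map="auto" assumes accelerate is installed):

from transformers import AutoModelForCausalLM, AutoTokenizer

# trust_remote_code=True lets transformers run the custom code bundled
# with the Baichuan2 checkpoints; use_fast=False forces the slow tokenizer.
tok = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-7B-Chat",
                                    use_fast=False, trust_remote_code=True)
lm = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan2-7B-Chat",
                                          device_map="auto",
                                          trust_remote_code=True)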


class YiModel(LMMBaseModel):
"""
Language model class for the Yi model.
Inherits from LMMBaseModel and sets up the Yi language model for use.
Parameters:
-----------
model : str
The name of the Yi model.
max_new_tokens : int
The maximum number of new tokens to be generated.
temperature : float
The temperature for text generation (default is 0).
device: str
The device to use for inference (default is 'auto').
"""
def __init__(self, model_name, max_new_tokens, temperature, device, dtype):
super(YiModel, self).__init__(model_name, max_new_tokens, temperature, device)
from transformers import AutoTokenizer, AutoModelForCausalLM
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, torch_dtype=dtype, device_map=device)
self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=dtype, device_map=device)


class MixtralModel(LMMBaseModel):
"""
Language model class for the Mixtral model.
Inherits from LMMBaseModel and sets up the Mixtral language model for use.
Parameters:
-----------
model : str
The name of the Mixtral model.
max_new_tokens : int
The maximum number of new tokens to be generated.
temperature : float
The temperature for text generation (default is 0).
device: str
The device to use for inference (default is 'auto').
dtype: str
The dtype to use for inference (default is 'auto').
"""
def __init__(self, model_name, max_new_tokens, temperature, device, dtype):
super(MixtralModel, self).__init__(model_name, max_new_tokens, temperature, device)
from transformers import AutoTokenizer, AutoModelForCausalLM
        # Strip the "-v0.1" version suffix so the tokenizer is loaded from
        # the base repo id.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name[:-len("-v0.1")], torch_dtype=dtype, device_map=device)
self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=dtype, device_map=device)
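
The slice in MixtralModel.__init__ drops the version suffix from the repo id before loading the tokenizer, presumably because the tokenizer was hosted under the base name at the time. The trimming itself is plain string arithmetic:

# Illustration of the suffix trim used above.
name = "mistralai/Mixtral-8x7B-v0.1"
print(name[:-len("-v0.1")])  # mistralai/Mixtral-8x7B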


class MistralModel(LMMBaseModel):
"""
Language model class for the Mistral model.
Inherits from LMMBaseModel and sets up the Mistral language model for use.
Parameters:
-----------
model : str
The name of the Mistral model.
max_new_tokens : int
The maximum number of new tokens to be generated.
temperature : float
The temperature for text generation (default is 0).
device: str
The device to use for inference (default is 'auto').
dtype: str
The dtype to use for inference (default is 'auto').
"""
def __init__(self, model_name, max_new_tokens, temperature, device, dtype):
super(MistralModel, self).__init__(model_name, max_new_tokens, temperature, device)
from transformers import AutoTokenizer, AutoModelForCausalLM
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, torch_dtype=dtype, device_map=device)
self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=dtype, device_map=device)
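
All four wrappers added in this hunk share the (model_name, max_new_tokens, temperature, device, dtype) constructor and inherit predict() and __call__ from LMMBaseModel. A hedged usage sketch, assuming a torch dtype is passed for dtype and that device is forwarded to device_map:

import torch

# Usage sketch, not from the commit: any of the four new wrappers
# can be driven the same way through LMMBaseModel.__call__.
model = MistralModel("mistralai/Mistral-7B-v0.1",
                     max_new_tokens=32, temperature=0.7,
                     device="auto", dtype=torch.float16)
print(model("The capital of France is"))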


class PhiModel(LMMBaseModel):
"""
Language model class for the Phi model.
@@ -337,8 +444,8 @@ class PaLMModel(LMMBaseModel):
The maximum number of new tokens to be generated.
temperature : float, optional
The temperature for text generation (default is 0).
-    model_dir : str, optional
-        The directory containing the model files (default is None).
+    api_key : str, optional
+        The PaLM API key (default is None).
"""
def __init__(self, model, max_new_tokens, temperature=0, api_key=None):
super(PaLMModel, self).__init__(model, max_new_tokens, temperature)
@@ -384,8 +491,8 @@ class GeminiModel(LMMBaseModel):
The maximum number of new tokens to be generated.
temperature : float, optional
The temperature for text generation (default is 0).
-    model_dir : str, optional
-        The directory containing the model files (default is None).
+    gemini_key : str, optional
+        The Gemini API key (default is None).
"""
def __init__(self, model, max_new_tokens, temperature=0, gemini_key=None):
super(GeminiModel, self).__init__(model, max_new_tokens, temperature)
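
With the corrected docstring, a minimal sketch of constructing the Gemini wrapper ("gemini-pro" is the name registered in MODEL_LIST; the key is a placeholder):

# Hedged sketch, not part of the commit: the wrapper takes the API key directly.
model = GeminiModel("gemini-pro", max_new_tokens=64, temperature=0,
                    gemini_key="YOUR_API_KEY")  # placeholder
print(model("Say hello"))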
