Skip to content

Commit

Permalink
Update __init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
icecream-and-tea authored Jun 29, 2024
1 parent 36e6305 commit ce1c035
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions promptbench/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ class LLMModel(object):
def model_list():
return SUPPORTED_MODELS

def __init__(self, model: str, max_new_tokens: int=20, temperature: float=0.0, device: str="cuda", dtype: str="auto", model_dir: str=None, system_prompt: str=None, api_key:str =None):
def __init__(self, model: str, max_new_tokens: int=20, temperature: float=0.0, device: str="cuda", dtype: str="auto", model_dir: str=None, system_prompt: str=None, api_key:str =None, **kwargs):
self.model_name = model
self.model = self._create_model(max_new_tokens, temperature, device, dtype, model_dir, system_prompt, api_key)
self.model = self._create_model(max_new_tokens, temperature, device, dtype, model_dir, system_prompt, api_key, **kwargs)

def _create_model(self, max_new_tokens, temperature, device, dtype, model_dir, system_prompt, api_key):
def _create_model(self, max_new_tokens, temperature, device, dtype, model_dir, system_prompt, api_key, **kwargs):
"""Creates and returns the appropriate model based on the model name."""

# Dictionary mapping of model names to their respective classes
Expand Down Expand Up @@ -310,4 +310,4 @@ def _other_concat_prompts(self, prompt_list):

def __call__(self, input_images, input_text, **kwargs):
"""Predicts the output based on the given input text using the loaded model."""
return self.model.predict(input_images, input_text, **kwargs)
return self.model.predict(input_images, input_text, **kwargs)

0 comments on commit ce1c035

Please sign in to comment.