Skip to content

Commit

Permalink
fix bug in setting basemodel
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Maik Jablonka committed Oct 2, 2023
1 parent e8f0c9c commit 79ecb5b
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/chemlift/finetune/peftmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@


class ChemLIFTClassifierFactory:
def __init__(self, model_name: str, **kwargs):
def __init__(self, property_name: str, model_name: str, **kwargs):
self.model_name = model_name
self.kwargs = kwargs
self.property_name = property_name

def create_model(self):
if "openai" in self.model_name:
tuner = Tuner(**self.kwargs)
return GPTClassifier(self.model_name, tuner=tuner, **self.kwargs)
if "openai/" in self.model_name:
model = self.model_name.split("/")[-1]
tuner = Tuner(base_model=model, **self.kwargs)
return GPTClassifier(self.property_name, tuner=tuner, **self.kwargs)
else:
return PEFTClassifier(self.model_name, **self.kwargs)
return PEFTClassifier(self.property_name, base_model=self.model_name, **self.kwargs)

def __call__(self):
return self.create_model()
Expand Down

0 comments on commit 79ecb5b

Please sign in to comment.