Skip to content

Commit

Permalink
let args be strings
Browse files Browse the repository at this point in the history
  • Loading branch information
mertyg committed Jun 11, 2024
1 parent 3e07bd8 commit 7f1cb71
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion textgrad/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class TextLoss(Module):
def __init__(self,
eval_system_prompt: Variable,
eval_system_prompt: Union[Variable, str],
engine: Union[EngineLM, str] = None):
"""
A vanilla loss function to evaluate a response.
Expand All @@ -29,6 +29,8 @@ def __init__(self,
>>> response_evaluator(response)
"""
super().__init__()
if isinstance(eval_system_prompt, str):
eval_system_prompt = Variable(eval_system_prompt, requires_grad=False, role_description="system prompt for the evaluation")
self.eval_system_prompt = eval_system_prompt
if ((engine is None) and (SingletonBackwardEngine().get_engine() is None)):
raise Exception("No engine provided. Either provide an engine as the argument to this call, or use `textgrad.set_backward_engine(engine)` to set the backward engine.")
Expand Down
4 changes: 3 additions & 1 deletion textgrad/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .config import SingletonBackwardEngine

class BlackboxLLM(Module):
def __init__(self, engine: Union[EngineLM, str] = None, system_prompt: Variable = None):
def __init__(self, engine: Union[EngineLM, str] = None, system_prompt: Union[Variable, str] = None):
"""
Initialize the LLM module.
Expand All @@ -24,6 +24,8 @@ def __init__(self, engine: Union[EngineLM, str] = None, system_prompt: Variable
if isinstance(engine, str):
engine = get_engine(engine)
self.engine = engine
if isinstance(system_prompt, str):
system_prompt = Variable(system_prompt, requires_grad=False, role_description="system prompt for the language model")
self.system_prompt = system_prompt
self.llm_call = LLMCall(self.engine, self.system_prompt)

Expand Down

0 comments on commit 7f1cb71

Please sign in to comment.