From 693e54454c58b7a47f6fcb9e0d0e588b3f30ff2b Mon Sep 17 00:00:00 2001
From: parkervg
Date: Fri, 24 May 2024 19:11:56 -0400
Subject: [PATCH] forgot to update nl_to_blendsql.py with model change from ccf0ecf

---
 blendsql/nl_to_blendsql/nl_to_blendsql.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/blendsql/nl_to_blendsql/nl_to_blendsql.py b/blendsql/nl_to_blendsql/nl_to_blendsql.py
index 31d79acf..6b83bc42 100644
--- a/blendsql/nl_to_blendsql/nl_to_blendsql.py
+++ b/blendsql/nl_to_blendsql/nl_to_blendsql.py
@@ -38,6 +38,7 @@ class ParserProgram(Program):
 
     def __call__(
         self,
+        model: Model,
         system_prompt: str,
         serialized_db: str,
         question: str,
@@ -55,21 +56,22 @@ def __call__(
             + prompt
             + Fore.RESET
         )
-        if isinstance(self.model, OllamaLLM):
+        if isinstance(model, OllamaLLM):
             # Handle call to ollama
             return return_ollama_response(
-                logits_generator=self.model.logits_generator,
+                logits_generator=model.logits_generator,
                 prompt=prompt,
                 stop=PARSER_STOP_TOKENS,
                 temperature=0.0,
             )
-        generator = outlines.generate.text(self.model.logits_generator)
+        generator = outlines.generate.text(model.logits_generator)
         return (generator(prompt, stop_at=PARSER_STOP_TOKENS), prompt)
 
 
 class CorrectionProgram(Program):
     def __call__(
         self,
+        model: Model,
         system_prompt: str,
         serialized_db: str,
         question: str,
@@ -77,7 +79,7 @@ def __call__(
         candidates: List[str],
         **kwargs,
     ) -> Tuple[str, str]:
-        if isinstance(self.model, OllamaLLM):
+        if isinstance(model, OllamaLLM):
             raise ValueError("CorrectionProgram can't use OllamaLLM!")
         prompt = ""
         prompt += (
@@ -89,7 +91,7 @@ def __call__(
         prompt += f"BlendSQL:\n"
         prompt += partial_completion
         generator = outlines.generate.choice(
-            self.model.logits_generator, [re.escape(str(i)) for i in candidates]
+            model.logits_generator, [re.escape(str(i)) for i in candidates]
        )
         return (generator(prompt), prompt)
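
The patch changes the `Program` call convention: the `Model` is now passed as an argument to `__call__` rather than read from `self.model`. Below is a minimal sketch of how a caller such as `nl_to_blendsql()` might invoke `ParserProgram` after this change. The model constructor, the `system_prompt` string, and the `serialized_db`/`question` values are placeholders, and the hunks above omit part of the real `__call__` signature, so additional keyword arguments may be required in practice.

```python
# Sketch only: assumes ParserProgram() takes no __init__ arguments and that the
# parameters visible in the hunks above are sufficient for a call.
from blendsql.models import TransformersLLM  # any non-Ollama Model subclass

model = TransformersLLM("HuggingFaceTB/SmolLM-135M")  # hypothetical model choice

completion, prompt = ParserProgram()(
    model=model,                      # now passed explicitly on every call
    system_prompt="Translate the question into a BlendSQL query.",  # placeholder
    serialized_db=serialized_db,      # hypothetical: serialized schema of the target db
    question="Which team won in 2008?",
)

# CorrectionProgram follows the same convention, but raises ValueError when given
# an OllamaLLM, since outlines.generate.choice() needs access to the logits.
```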