diff --git a/blendsql/nl_to_blendsql/nl_to_blendsql.py b/blendsql/nl_to_blendsql/nl_to_blendsql.py index 31d79acf..6b83bc42 100644 --- a/blendsql/nl_to_blendsql/nl_to_blendsql.py +++ b/blendsql/nl_to_blendsql/nl_to_blendsql.py @@ -38,6 +38,7 @@ class ParserProgram(Program): def __call__( self, + model: Model, system_prompt: str, serialized_db: str, question: str, @@ -55,21 +56,22 @@ def __call__( + prompt + Fore.RESET ) - if isinstance(self.model, OllamaLLM): + if isinstance(model, OllamaLLM): # Handle call to ollama return return_ollama_response( - logits_generator=self.model.logits_generator, + logits_generator=model.logits_generator, prompt=prompt, stop=PARSER_STOP_TOKENS, temperature=0.0, ) - generator = outlines.generate.text(self.model.logits_generator) + generator = outlines.generate.text(model.logits_generator) return (generator(prompt, stop_at=PARSER_STOP_TOKENS), prompt) class CorrectionProgram(Program): def __call__( self, + model: Model, system_prompt: str, serialized_db: str, question: str, @@ -77,7 +79,7 @@ def __call__( candidates: List[str], **kwargs, ) -> Tuple[str, str]: - if isinstance(self.model, OllamaLLM): + if isinstance(model, OllamaLLM): raise ValueError("CorrectionProgram can't use OllamaLLM!") prompt = "" prompt += ( @@ -89,7 +91,7 @@ def __call__( prompt += f"BlendSQL:\n" prompt += partial_completion generator = outlines.generate.choice( - self.model.logits_generator, [re.escape(str(i)) for i in candidates] + model.logits_generator, [re.escape(str(i)) for i in candidates] ) return (generator(prompt), prompt)