Skip to content

Commit

Permalink
forgot to update nl_to_blendsql.py with model change from ccf0ecf
Browse files Browse the repository at this point in the history
  • Loading branch information
parkervg committed May 24, 2024
1 parent 0586538 commit 693e544
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions blendsql/nl_to_blendsql/nl_to_blendsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
class ParserProgram(Program):
def __call__(
self,
model: Model,
system_prompt: str,
serialized_db: str,
question: str,
Expand All @@ -55,29 +56,30 @@ def __call__(
+ prompt
+ Fore.RESET
)
if isinstance(self.model, OllamaLLM):
if isinstance(model, OllamaLLM):
# Handle call to ollama
return return_ollama_response(
logits_generator=self.model.logits_generator,
logits_generator=model.logits_generator,
prompt=prompt,
stop=PARSER_STOP_TOKENS,
temperature=0.0,
)
generator = outlines.generate.text(self.model.logits_generator)
generator = outlines.generate.text(model.logits_generator)
return (generator(prompt, stop_at=PARSER_STOP_TOKENS), prompt)


class CorrectionProgram(Program):
def __call__(
self,
model: Model,
system_prompt: str,
serialized_db: str,
question: str,
partial_completion: str,
candidates: List[str],
**kwargs,
) -> Tuple[str, str]:
if isinstance(self.model, OllamaLLM):
if isinstance(model, OllamaLLM):
raise ValueError("CorrectionProgram can't use OllamaLLM!")
prompt = ""
prompt += (
Expand All @@ -89,7 +91,7 @@ def __call__(
prompt += f"BlendSQL:\n"
prompt += partial_completion
generator = outlines.generate.choice(
self.model.logits_generator, [re.escape(str(i)) for i in candidates]
model.logits_generator, [re.escape(str(i)) for i in candidates]
)
return (generator(prompt), prompt)

Expand Down

0 comments on commit 693e544

Please sign in to comment.