diff --git a/blendsql/models/_model.py b/blendsql/models/_model.py index 51570999..14f2b134 100644 --- a/blendsql/models/_model.py +++ b/blendsql/models/_model.py @@ -119,6 +119,7 @@ def predict(self, program: Type[Program], **kwargs) -> dict: return self.cache.get(key) # Modify fields used for tracking Model usage response, prompt = program(model=self, **kwargs) + self.prompts.insert(-1, self.format_prompt(response, **kwargs)) self.num_calls += 1 if self.tokenizer is not None: self.prompt_tokens += len(self.tokenizer.encode(prompt)) @@ -155,8 +156,8 @@ def _create_key(self, program: Program, **kwargs) -> str: return hasher.hexdigest() @staticmethod - def format_prompt(res, **kwargs) -> dict: - d = {"answer": res} + def format_prompt(response: str, **kwargs) -> dict: + d = {"answer": response} if IngredientKwarg.QUESTION in kwargs: d[IngredientKwarg.QUESTION] = kwargs.get(IngredientKwarg.QUESTION) if IngredientKwarg.CONTEXT in kwargs: