From 7918ffd0a004797802abdcd9397f49eca89a093f Mon Sep 17 00:00:00 2001 From: = Enea_Gore Date: Mon, 7 Oct 2024 17:55:24 +0200 Subject: [PATCH] use structured output --- modules/llm_core/llm_core/utils/llm_utils.py | 27 +++++++------------- 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/modules/llm_core/llm_core/utils/llm_utils.py b/modules/llm_core/llm_core/utils/llm_utils.py index e7852bda0..f779cb55b 100644 --- a/modules/llm_core/llm_core/utils/llm_utils.py +++ b/modules/llm_core/llm_core/utils/llm_utils.py @@ -1,8 +1,6 @@ from typing import Optional, Type, TypeVar, List from pydantic import BaseModel, ValidationError import tiktoken - -from langchain.chains import LLMChain from langchain.chat_models import ChatOpenAI from langchain.base_language import BaseLanguageModel from langchain.prompts import ( @@ -10,9 +8,8 @@ SystemMessagePromptTemplate, HumanMessagePromptTemplate, ) -from langchain.chains.openai_functions import create_structured_output_chain from langchain.output_parsers import PydanticOutputParser -from langchain.schema import OutputParserException +from langchain_core.runnables import RunnableSequence from athena import emit_meta, get_experiment_environment @@ -144,19 +141,13 @@ async def predict_and_parse( if experiment.run_id is not None: tags.append(f"run-{experiment.run_id}") - if supports_function_calling(model): - chain = create_structured_output_chain(pydantic_object, llm=model, prompt=chat_prompt, tags=tags) - - try: - return await chain.arun(**prompt_input) - except (OutputParserException, ValidationError): - # In the future, we should probably have some recovery mechanism here (i.e. fix the output with another prompt) - return None + structured_output_llm = model.with_structured_output(pydantic_object, method="json_mode") + chain = RunnableSequence( + chat_prompt, + structured_output_llm + ) - output_parser = PydanticOutputParser(pydantic_object=pydantic_object) - chain = LLMChain(llm=model, prompt=chat_prompt, output_parser=output_parser, tags=tags) try: - return await chain.arun(**prompt_input) - except (OutputParserException, ValidationError): - # In the future, we should probably have some recovery mechanism here (i.e. fix the output with another prompt) - return None + return await chain.ainvoke(prompt_input, config={"tags": tags}) + except ValidationError as e: + raise ValueError(f"Could not parse output: {e}") from e