use structured output
Enea_Gore committed Oct 7, 2024
1 parent 85ea8a2 commit 7918ffd
Showing 1 changed file with 9 additions and 18 deletions.
27 changes: 9 additions & 18 deletions modules/llm_core/llm_core/utils/llm_utils.py
@@ -1,18 +1,15 @@
 from typing import Optional, Type, TypeVar, List
 from pydantic import BaseModel, ValidationError
 import tiktoken
 
-from langchain.chains import LLMChain
 from langchain.chat_models import ChatOpenAI
 from langchain.base_language import BaseLanguageModel
 from langchain.prompts import (
     ChatPromptTemplate,
     SystemMessagePromptTemplate,
     HumanMessagePromptTemplate,
 )
-from langchain.chains.openai_functions import create_structured_output_chain
-from langchain.output_parsers import PydanticOutputParser
-from langchain.schema import OutputParserException
+from langchain_core.runnables import RunnableSequence
 
 from athena import emit_meta, get_experiment_environment
 
@@ -144,19 +141,13 @@ async def predict_and_parse(
     if experiment.run_id is not None:
         tags.append(f"run-{experiment.run_id}")
 
-    if supports_function_calling(model):
-        chain = create_structured_output_chain(pydantic_object, llm=model, prompt=chat_prompt, tags=tags)
-
-        try:
-            return await chain.arun(**prompt_input)
-        except (OutputParserException, ValidationError):
-            # In the future, we should probably have some recovery mechanism here (i.e. fix the output with another prompt)
-            return None
+    structured_output_llm = model.with_structured_output(pydantic_object, method="json_mode")
+    chain = RunnableSequence(
+        chat_prompt,
+        structured_output_llm
+    )
 
-    output_parser = PydanticOutputParser(pydantic_object=pydantic_object)
-    chain = LLMChain(llm=model, prompt=chat_prompt, output_parser=output_parser, tags=tags)
     try:
-        return await chain.arun(**prompt_input)
-    except (OutputParserException, ValidationError):
-        # In the future, we should probably have some recovery mechanism here (i.e. fix the output with another prompt)
-        return None
+        return await chain.ainvoke(prompt_input, config={"tags": tags})
+    except ValidationError as e:
+        raise ValueError(f"Could not parse output: {e}") from e
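
In short, the commit drops the `supports_function_calling` branch entirely: every model now goes through `with_structured_output(..., method="json_mode")`, and a parse failure raises `ValueError` instead of silently returning `None`. Below is a minimal, self-contained sketch of the same pattern outside the repo. The model name, the `Assessment` schema, and the prompt text are illustrative assumptions, not part of the commit; it also assumes a recent `langchain-core`/`langchain-openai` where `with_structured_output` and the positional `RunnableSequence(*steps)` constructor are available.

```python
from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableSequence
from langchain_openai import ChatOpenAI


class Assessment(BaseModel):
    """Hypothetical schema standing in for `pydantic_object`."""
    score: float
    feedback: str


model = ChatOpenAI(model="gpt-4o-mini")  # assumed model name

# method="json_mode" uses the provider's JSON mode rather than function
# calling; OpenAI's JSON mode requires the prompt itself to mention JSON
# and describe the expected fields, since the schema is not injected.
structured_output_llm = model.with_structured_output(Assessment, method="json_mode")

chat_prompt = ChatPromptTemplate.from_messages([
    ("system", "Grade the submission. Reply as JSON with keys 'score' and 'feedback'."),
    ("human", "{submission}"),
])

# Equivalent to `chat_prompt | structured_output_llm`.
chain = RunnableSequence(chat_prompt, structured_output_llm)

result = chain.invoke(
    {"submission": "print('hello world')"},
    config={"tags": ["run-demo"]},  # tags now travel via the invoke config
)
print(result)  # -> Assessment(score=..., feedback=...)
```

One detail worth noting: `create_structured_output_chain` and `LLMChain` accepted `tags` at construction time, whereas a runnable pipeline takes them per call, which is why the `tags` list built from the experiment environment moves into the `ainvoke` call as `config={"tags": tags}`.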
