diff --git a/llm_core/llm_core/utils/llm_utils.py b/llm_core/llm_core/utils/llm_utils.py
index 6cacd4c0..4637b855 100644
--- a/llm_core/llm_core/utils/llm_utils.py
+++ b/llm_core/llm_core/utils/llm_utils.py
@@ -1,6 +1,7 @@
 from typing import Type, TypeVar, List
 from pydantic import BaseModel
 import tiktoken
+from langchain.chat_models import ChatOpenAI
 from langchain.base_language import BaseLanguageModel
 from langchain.prompts import (
     ChatPromptTemplate,
@@ -65,6 +66,18 @@ def check_prompt_length_and_omit_features_if_necessary(prompt: ChatPromptTemplat
     return prompt_input, False
 
 
+def supports_function_calling(model: BaseLanguageModel):
+    """Returns True if the model supports function calling, False otherwise
+
+    Args:
+        model (BaseLanguageModel): The model to check
+
+    Returns:
+        boolean: True if the model supports function calling, False otherwise
+    """
+    return isinstance(model, ChatOpenAI)
+
+
 def get_chat_prompt_with_formatting_instructions(
     model: BaseLanguageModel,
     system_message: str,
@@ -84,9 +97,14 @@ def get_chat_prompt_with_formatting_instructions(
     Returns:
         ChatPromptTemplate: ChatPromptTemplate with formatting instructions (if necessary)
     """
+    if supports_function_calling(model):
+        system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)
+        human_message_prompt = HumanMessagePromptTemplate.from_template(human_message)
+        return ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
+
     output_parser = PydanticOutputParser(pydantic_object=pydantic_object)
-    system_message_prompt = SystemMessagePromptTemplate.from_template(system_message + "\n\n{format_instructions}")
+    system_message_prompt = SystemMessagePromptTemplate.from_template(system_message + "\n{format_instructions}")
     system_message_prompt.prompt.partial_variables = {"format_instructions": output_parser.get_format_instructions()}
     system_message_prompt.prompt.input_variables.remove("format_instructions")
-    human_message_prompt = HumanMessagePromptTemplate.from_template(human_message)
+    human_message_prompt = HumanMessagePromptTemplate.from_template(human_message + "\n\nJSON response following the provided schema:")
     return ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
\ No newline at end of file
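
For reference, a minimal usage sketch of the changed prompt builder (not part of the patch): the `Assessment` schema, the messages, and the placeholder API key below are hypothetical, and the import path is assumed from the file layout above (llm_core/llm_core/utils/llm_utils.py).

    from pydantic import BaseModel
    from langchain.chat_models import ChatOpenAI

    # Import path assumed from the repository layout shown in the diff header.
    from llm_core.utils.llm_utils import (
        get_chat_prompt_with_formatting_instructions,
        supports_function_calling,
    )


    class Assessment(BaseModel):
        # Hypothetical schema, used only for this sketch.
        score: float
        feedback: str


    model = ChatOpenAI(openai_api_key="sk-...")  # placeholder key

    prompt = get_chat_prompt_with_formatting_instructions(
        model=model,
        system_message="You are a grading assistant.",
        human_message="Grade the following submission:\n{submission}",
        pydantic_object=Assessment,
    )

    # ChatOpenAI supports function calling, so the prompt is built from the raw
    # system/human messages with no {format_instructions} appended. For any
    # other BaseLanguageModel, the Pydantic format instructions and the
    # "JSON response following the provided schema:" suffix are added instead.
    print(supports_function_calling(model))  # True
    print(prompt.format_prompt(submission="...").to_messages())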