Add support for function calling
s-jse committed Oct 27, 2024
1 parent 849b125 commit d384951
Showing 8 changed files with 178 additions and 51 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -49,8 +49,9 @@ llm_generation_chain(
     output_json: bool = False,
     template_blocks: list[tuple[str]]=None,
     keep_indentation: bool = False,
-    postprocess: bool = False,
     progress_bar_desc: Optional[str] = None,
+    tools: Optional[List[Callable]] = None,
+    force_tool_calling: bool = False,
     additional_postprocessing_runnable: Runnable = None,
     bind_prompt_values: Dict = {},
 ) # returns a LangChain chain that accepts inputs and returns a string as output
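With the two new parameters, a tool-enabled chain can be built and invoked roughly as follows. This is a minimal sketch, not the package's documented example: the engine name, config path, and prompt file are placeholders for whatever is configured locally, and the exact return shape is exercised in tests/test_function_calling.py further down.

```python
import asyncio

from chainlite import llm_generation_chain, load_config_from_file


def add(a: int, b: int) -> int:
    """Adds a and b."""
    return a + b


async def main():
    # Assumption: load_config_from_file takes the path to llm_config.yaml and an
    # engine named "gpt-4o" is configured there.
    load_config_from_file("./llm_config.yaml")
    chain = llm_generation_chain(
        "tool.prompt",  # prompt template file (placeholder)
        engine="gpt-4o",
        max_tokens=100,
        tools=[add],  # plain Python callables, as in the tests below
    )
    # With `tools` set and force_tool_calling left False, ainvoke returns a
    # (text, tool_outputs) pair; tool_outputs is a list of ToolOutput objects.
    text_output, tool_outputs = await chain.ainvoke({"message": "What is 2 + 3?"})
    print(text_output, tool_outputs)


asyncio.run(main())
```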
3 changes: 3 additions & 0 deletions chainlite/__init__.py
@@ -1,5 +1,6 @@
 from langchain_core.runnables import chain
 
+
 from .llm_config import (
     get_total_cost,
     load_config_from_file,
@@ -9,6 +10,7 @@
     llm_generation_chain,
     pprint_chain,
     write_prompt_logs_to_file,
+    ToolOutput,
 )
 from .load_prompt import register_prompt_constants
 from .utils import get_logger
@@ -22,6 +24,7 @@
"write_prompt_logs_to_file",
"get_total_cost",
"chain",
"ToolOutput",
"get_all_configured_engines",
"register_prompt_constants",
]
11 changes: 10 additions & 1 deletion chainlite/llm_config.py
@@ -19,7 +19,7 @@

 class CustomAsyncRedisCache(AsyncRedisCache):
     """This class fixes langchain 0.2.*'s cache issue with LiteLLM
-    The core of the problem is that LiteLLM's Usage class should inherit from LangChain's Serializable class, but doesn't.
+    The core of the problem is that LiteLLM's `Usage` and `ChatCompletionMessageToolCall` classes should inherit from LangChain's Serializable class, but don't.
     This class is the minimal fix to make it work.
     """
 
@@ -28,6 +28,15 @@ def _configure_pipeline_for_update(
         key: str, pipe: Any, return_val: RETURN_VAL_TYPE, ttl: Optional[int] = None
     ) -> None:
         for r in return_val:
+            if (
+                hasattr(r.message, "additional_kwargs")
+                and "tool_calls" in r.message.additional_kwargs
+            ):
+                r.message.additional_kwargs["tool_calls"] = [
+                    tool_call.dict()
+                    for tool_call in r.message.additional_kwargs["tool_calls"]
+                ]
+
             if (
                 hasattr(r.message, "response_metadata")
                 and "token_usage" in r.message.response_metadata
121 changes: 73 additions & 48 deletions chainlite/llm_generate.py
@@ -9,13 +9,14 @@
 import re
 from datetime import datetime
 from rich import print as pprint
-from typing import Any, AsyncIterator, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 from uuid import UUID
 
 from langchain_community.chat_models import ChatLiteLLM
 from langchain_core.callbacks import AsyncCallbackHandler
 from langchain_core.messages import BaseMessage
 from langchain_core.output_parsers import StrOutputParser
+
 from langchain_core.outputs import LLMResult
 from langchain_core.runnables import Runnable, chain
 from tqdm.auto import tqdm
@@ -137,6 +138,22 @@ async def on_chain_end(
         run_id = str(run_id)
         if run_id in GlobalVars.prompt_logs:
             # this is the final response in the entire chain
+            if (
+                isinstance(response, tuple)
+                and len(response) == 2
+                and isinstance(response[1], ToolOutput)
+            ):
+                response = list(response)
+                response[1] = str(response[1])
+            elif isinstance(response, ToolOutput):
+                response = str(response)
+            if isinstance(response, tuple) and len(response) == 2:
+                response = list(response)
+                # if exactly one is not None/empty, then we want to log that one
+                if response[0] and not response[1]:
+                    response = response[0]
+                elif not response[0] and response[1]:
+                    response = response[1]
             GlobalVars.prompt_logs[run_id]["output"] = str(
                 response
             )  # convert to str because output might be a Pydantic object (if `pydantic_class` is provided in `llm_generation_chain()`)
@@ -171,45 +188,6 @@ async def on_llm_end(
         self.progress_bar.update(1)
 
 
-async def strip(input_: AsyncIterator[str]) -> AsyncIterator[str]:
-    """
-    Strips whitespace from a string, but supports streaming in a LangChain chain
-    """
-    prev_chunk = (await input_.__anext__()).lstrip()
-    while True:
-        try:
-            current_chunk = await input_.__anext__()
-        except StopAsyncIteration as e:
-            yield prev_chunk.rstrip()
-            break
-        yield prev_chunk
-        prev_chunk = current_chunk
-
-
-def extract_until_last_full_sentence(text):
-    match = partial_sentence_regex.search(text)
-    if match:
-        # Return the matched group, which should include punctuation and an optional quotation.
-        return match.group(1)
-    else:
-        return ""
-
-
-async def postprocess_generations(input_: AsyncIterator[str]) -> AsyncIterator[str]:
-    buffer = ""
-    yielded = False
-    async for chunk in input_:
-        buffer += chunk
-        until_last_full_sentence = extract_until_last_full_sentence(buffer)
-        if len(until_last_full_sentence) > 0:
-            yield buffer[: len(until_last_full_sentence)]
-            yielded = True
-            buffer = buffer[len(until_last_full_sentence) :]
-    if not yielded:
-        # yield the entire input so that the output is not None
-        yield buffer
-
-
 @chain
 def string_to_pydantic_object(llm_output: str, pydantic_class: BaseModel):
     try:
@@ -288,6 +266,38 @@ def _ensure_strict_json_schema(
     return json_schema
 
 
+class ToolOutput(BaseModel):
+    function: Callable
+    kwargs: dict
+
+    def __repr__(self):
+        return (
+            f"{self.function.__name__}("
+            + ", ".join([f"{k}={repr(v)}" for k, v in self.kwargs.items()])
+            + ")"
+        )
+
+
+@chain
+async def return_response_and_tool(
+    llm_output, tools: list[Callable], force_tool_calling: bool
+):
+    response = await StrOutputParser().ainvoke(input=llm_output)
+    tool_output_in_json_format = llm_output.tool_calls
+
+    tool_outputs = []
+    for t in tool_output_in_json_format:
+        tool_name = t["name"]
+        matching_tool = next(
+            (tool for tool in tools if tool.__name__ == tool_name), None
+        )
+        if matching_tool:
+            tool_outputs.append(ToolOutput(function=matching_tool, kwargs=t["args"]))
+    if force_tool_calling:
+        return tool_outputs
+    return response, tool_outputs
+
+
 def llm_generation_chain(
     template_file: str,
     engine: str,
@@ -299,9 +309,10 @@
     pydantic_class: BaseModel = None,
     template_blocks: list[tuple[str]] = None,
     keep_indentation: bool = False,
-    postprocess: bool = False,
     progress_bar_desc: Optional[str] = None,
     additional_postprocessing_runnable: Runnable = None,
+    tools: Optional[list[Callable]] = None,
+    force_tool_calling: bool = False,
     bind_prompt_values: Dict = {},
     force_skip_cache: bool = False,
 ) -> Runnable:
@@ -321,8 +332,9 @@
             and newer are supported
         template_blocks: If provided, will use this instead of `template_file`. The format is [(role, string)] where role is one of "instruction", "input", "output"
         keep_indentation (bool, optional): If True, will keep indentations at the beginning of each line in the template_file. Defaults to False.
-        postprocess (bool, optional): If True, postprocessing deletes incomplete sentences from the end of the generation. Defaults to False.
         progress_bar_desc (str, optional): If provided, will display a `tqdm` progress bar using this description
+        tools (List[Callable], optional): If provided, will be made available to the underlying LLM, which can optionally call them (function calling). Defaults to None.
+        force_tool_calling (bool, optional): If True, will force the LLM to call one of the provided tools. Defaults to False.
         additional_postprocessing_runnable (Runnable, optional): If provided, will be applied to the output of LLM generation, and the final output will be logged
         bind_prompt_values (Dict, optional): A dictionary containing {Variable: str : Value}. Binds values to the prompt. Additional variables can be provided when the chain is called. Defaults to {}.
@@ -347,9 +359,13 @@
         raise IndexError(
             f"Could not find any matching engines for {engine}. Please check that llm_config.yaml is configured correctly and that the API key is set in the terminal before running this script."
         )
-    if pydantic_class and output_json:
+    if (
+        (pydantic_class and tools)
+        or (pydantic_class and output_json)
+        or (output_json and tools)
+    ):
         raise ValueError(
-            "At most one of `output_json` and `pydantic_class` can be used."
+            "At most one of `pydantic_class`, `output_json` and `tools` can be used."
         )
     llm_resource = random.choice(potential_llm_resources)
 
@@ -414,20 +430,29 @@
         temperature=temperature,
         callbacks=callbacks,
     )
+    if tools:
+        if force_tool_calling:
+            llm = llm.bind_tools(tools=tools, tool_choice="required")
+        else:
+            llm = llm.bind_tools(tools=tools)
+
     # for variable, value in bind_prompt_values.keys():
     if len(bind_prompt_values) > 0:
         prompt = prompt.partial(**bind_prompt_values)
-    llm_generation_chain = prompt | llm | StrOutputParser()
-    if postprocess:
-        llm_generation_chain = llm_generation_chain | postprocess_generations
+
+    llm_generation_chain = prompt | llm
+    if tools:
+        llm_generation_chain = llm_generation_chain | return_response_and_tool.bind(
+            tools=tools, force_tool_calling=force_tool_calling
+        )
     else:
-        llm_generation_chain = llm_generation_chain | strip
+        llm_generation_chain = llm_generation_chain | StrOutputParser()
 
     if pydantic_class:
         llm_generation_chain = llm_generation_chain | string_to_pydantic_object.bind(
             pydantic_class=pydantic_class
         )
 
     if additional_postprocessing_runnable:
         llm_generation_chain = llm_generation_chain | additional_postprocessing_runnable
     return llm_generation_chain.with_config(
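Given the pieces above, a chain built with `tools` ends in `return_response_and_tool`, so it yields either a `(text, tool_outputs)` pair or, when `force_tool_calling=True`, just the list of `ToolOutput` objects. A `ToolOutput` only records which function the model wants to call and with which arguments; running it is left to the caller. A minimal sketch, assuming `tool_chain` was built with `llm_generation_chain(..., tools=[...])`:

```python
async def answer_with_tools(tool_chain, message: str) -> str:
    # The chain returns the model's text plus the tool calls it requested.
    text_output, tool_outputs = await tool_chain.ainvoke({"message": message})
    for call in tool_outputs:  # each `call` is a ToolOutput
        result = call.function(**call.kwargs)  # execute the selected Python function
        print(repr(call), "->", result)  # e.g. get_current_weather(location='Boston') -> "The weather is 12F"
    return text_output
```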
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

 setup(
     name="chainlite",
-    version="0.3.0",
+    version="0.3.1",
     author="Sina Semnani",
     author_email="[email protected]",
     description="A Python package that uses LangChain and LiteLLM to call large language model APIs easily",
1 change: 1 addition & 0 deletions tasks/main.py
@@ -48,6 +48,7 @@ def tests(c, log_level="info", parallel=False):

     test_files = [
         "./tests/test_llm_generate.py",
+        "./tests/test_function_calling.py",
     ]
 
     pytest_command = (
86 changes: 86 additions & 0 deletions tests/test_function_calling.py
@@ -0,0 +1,86 @@
+import pytest
+
+from chainlite import (
+    llm_generation_chain,
+    write_prompt_logs_to_file,
+)
+from chainlite.llm_generate import ToolOutput
+
+test_engine = "gpt-4o-august"
+
+
+# @tool
+def get_current_weather(location: str):
+    """Get the current weather in a given location"""
+    if "boston" in location.lower():
+        return "The weather is 12F"
+
+
+# @tool
+def add(a: int, b: int) -> int:
+    """Adds a and b."""
+    return a + b
+
+
+@pytest.mark.asyncio(scope="session")
+async def test_function_calling():
+    test_tool_chain = llm_generation_chain(
+        "tool.prompt",
+        engine=test_engine,
+        max_tokens=100,
+        tools=[get_current_weather, add],
+        # force_skip_cache=True,
+    )
+    # No function calling done, just output text
+    text_output, tool_outputs = await test_tool_chain.ainvoke(
+        {"message": "What tools do you have available?"}
+    )
+    assert "functions.get_current_weather" in text_output
+    assert "functions.add" in text_output
+    assert tool_outputs == []
+
+    # Function calling needed
+    text_output, tool_outputs = await test_tool_chain.ainvoke(
+        {"message": "What is the weather like in Boston?"}
+    )
+
+    assert text_output == ""
+    assert str(tool_outputs) == "[get_current_weather(location='Boston')]"
+
+    text_output, tool_outputs = await test_tool_chain.ainvoke(
+        {"message": "What is 1021 + 9573?"}
+    )
+    assert text_output == ""
+    assert str(tool_outputs) == "[add(a=1021, b=9573)]"
+
+    write_prompt_logs_to_file("tests/llm_input_outputs.jsonl")
+
+
+@pytest.mark.asyncio(scope="session")
+async def test_forced_function_calling():
+    test_tool_chain = llm_generation_chain(
+        "tool.prompt",
+        engine=test_engine,
+        max_tokens=100,
+        tools=[get_current_weather, add],
+        # force_skip_cache=True,
+        force_tool_calling=True,
+    )
+
+    # Forcing function call when it is already needed
+    tool_outputs = await test_tool_chain.ainvoke(
+        {"message": "What is the weather like in New York City?"}
+    )
+
+    assert isinstance(tool_outputs, list)
+    assert str(tool_outputs) == "[get_current_weather(location='New York City')]"
+    print(tool_outputs)
+
+    # Forcing function call when it is not needed
+    tool_outputs = await test_tool_chain.ainvoke({"message": "What is your name?"})
+    print(tool_outputs)
+    assert isinstance(tool_outputs, list)
+    assert len(tool_outputs) > 0
+    assert isinstance(tool_outputs[0], ToolOutput)
+
+    write_prompt_logs_to_file("tests/llm_input_outputs.jsonl")
2 changes: 2 additions & 0 deletions tests/tool.prompt
@@ -0,0 +1,2 @@
+# input
+{{ message }}
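The two added lines above are the entire tool.prompt template used by the new tests: a single `# input` block that receives the `message` variable. A slightly richer template might also carry an instruction block; this is only a sketch, with the `# instruction` block name inferred from the roles ("instruction", "input", "output") listed in the `template_blocks` docstring earlier, and illustrative wording:

```
# instruction
You are a helpful assistant. Call one of the provided tools when one is needed to answer the user.

# input
{{ message }}
```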
