Commit

Add option to show progress bar for llm_generation_chains
s-jse committed May 31, 2024
1 parent f449754 commit 175f71a
Showing 5 changed files with 59 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -159,4 +159,5 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

*.jsonl
API_KEYS
2 changes: 1 addition & 1 deletion README.md
@@ -10,7 +10,7 @@

## Installation

To install ChainLite, use the following steps:
ChainLite has been tested with Python 3.10. To install, do the following:


1. Install ChainLite via pip:
47 changes: 36 additions & 11 deletions chainlite/llm_generate.py
@@ -5,26 +5,26 @@
import json
import logging
import os
from pprint import pprint
import random
import re
from typing import AsyncIterator, Optional, Any
from pprint import pprint
from typing import Any, AsyncIterator, Optional
from uuid import UUID

from langchain_community.chat_models import ChatLiteLLM
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.outputs import LLMResult
from langchain_core.runnables import Runnable, chain

from tqdm.auto import tqdm

from chainlite.llm_config import GlobalVars

from .load_prompt import load_fewshot_prompt_template
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.messages import BaseMessage
from langchain_core.runnables import chain, Runnable

from .utils import get_logger


logging.getLogger("LiteLLM").setLevel(logging.WARNING)
logging.getLogger("LiteLLM Router").setLevel(logging.WARNING)
logging.getLogger("LiteLLM Proxy").setLevel(logging.WARNING)
@@ -123,6 +123,25 @@ async def on_llm_end(
GlobalVars.prompt_logs[run_id]["output"] = llm_output


class ProgbarCallback(AsyncCallbackHandler):
def __init__(self, desc: str, total: int = None):
super().__init__()
self.count = 0
self.progress_bar = tqdm(total=total, desc=desc) # define a progress bar

    # Override the on_llm_end method, which is called after every response from the LLM
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
) -> Any:
self.count += 1
self.progress_bar.update(1)


prompt_log_handler = PromptLogHandler()


@@ -218,6 +237,7 @@ def llm_generation_chain(
output_json: bool = False,
keep_indentation: bool = False,
postprocess: bool = False,
progress_bar_desc: Optional[str] = None,
bind_prompt_values: dict = {},
) -> Runnable:
"""
@@ -233,7 +253,8 @@
output_json (bool, optional): If True, asks the LLM API to output JSON; support for this depends on the underlying model.
For example, GPT-4, GPT-4o and newer GPT-3.5-Turbo models support it, but require the word "json" to be present in the input. Defaults to False.
keep_indentation (bool, optional): If True, will keep indentations at the beginning of each line in the template_file. Defaults to False.
postprocess (bool, optional): If true, postprocessing deletes incomplete sentences from the end of the generation. Defaults to False.
postprocess (bool, optional): If True, postprocessing deletes incomplete sentences from the end of the generation. Defaults to False.
progress_bar_desc (str, optional): If provided, displays a `tqdm` progress bar with this description. Defaults to None.
bind_prompt_values (dict, optional): A dictionary containing {Variable: str : Value}. Binds values to the prompt. Additional variables can be provided when the chain is called. Defaults to {}.
Returns:
@@ -290,6 +311,10 @@ def llm_generation_chain(
if output_json:
model_kwargs["response_format"] = {"type": "json_object"}

callbacks = [prompt_log_handler]
if progress_bar_desc:
cb = ProgbarCallback(progress_bar_desc)
callbacks.append(cb)
llm = ChatLiteLLM(
model_kwargs=model_kwargs,
api_base=llm_resource["api_base"] if "api_base" in llm_resource else None,
@@ -302,7 +327,7 @@
"distillation_instruction": distillation_instruction,
"template_name": os.path.basename(template_file),
}, # for logging to file
callbacks=[prompt_log_handler],
callbacks=callbacks,
)

# for variable, value in bind_prompt_values.keys():
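For reference, a minimal sketch of how the new ProgbarCallback behaves on its own, outside of a chain: each call to on_llm_end advances the bar by one, so a batch of N requests fills a bar of size N. The empty LLMResult and random run_id below are placeholders for illustration only, not real model output.

from uuid import uuid4

from langchain_core.outputs import LLMResult

from chainlite.llm_generate import ProgbarCallback

# Simulate three completed LLM calls; the bar should reach 3/3.
cb = ProgbarCallback(desc="demo", total=3)
for _ in range(3):
    # An empty LLMResult stands in for a real model response.
    cb.on_llm_end(LLMResult(generations=[]), run_id=uuid4())
cb.progress_bar.close()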
4 changes: 2 additions & 2 deletions setup.py
@@ -2,7 +2,7 @@

setup(
name="chainlite",
version="0.1.11",
version="0.1.12",
author="Sina Semnani",
author_email="[email protected]",
description="A Python package that uses LangChain and LiteLLM to call large language model APIs easily",
@@ -24,7 +24,7 @@
"redis[hiredis]",
],
extras_require={
"dev": ["invoke", "pytest", "pytest-asyncio", "setuptools", "wheel", "twine"],
"dev": ["invoke", "pytest", "pytest-asyncio", "setuptools", "wheel", "twine", "isort"],
},
classifiers=[
"Programming Language :: Python :: 3",
22 changes: 19 additions & 3 deletions tests/test_llm_generate.py
@@ -2,14 +2,15 @@

from chainlite import llm_generation_chain, load_config_from_file
from chainlite.llm_config import GlobalVars
from chainlite.llm_generate import write_prompt_logs_to_file
from chainlite.llm_generate import ProgbarCallback, write_prompt_logs_to_file
from chainlite.utils import get_logger

logger = get_logger(__name__)


# load_config_from_file("./llm_config.yaml")


@pytest.mark.asyncio(scope="session")
async def test_llm_generate():
# Check that the config file has been loaded properly
@@ -20,7 +20,7 @@ async def test_llm_generate():
assert GlobalVars.local_engine_set

response = await llm_generation_chain(
template_file="test.prompt", # prompt path relative to one of the paths specified in `prompt_dirs`
template_file="test.prompt", # prompt path relative to one of the paths specified in `prompt_dirs`
engine="gpt-4o",
max_tokens=100,
).ainvoke({})
@@ -37,7 +38,22 @@ async def test_readme_example():
template_file="tests/joke.prompt",
engine="gpt-35-turbo",
max_tokens=100,
temperature=0.1,
progress_bar_desc="test1",
).ainvoke({"topic": "Life as a PhD student"})
logger.info(response)

write_prompt_logs_to_file("llm_input_outputs.jsonl")
write_prompt_logs_to_file("tests/llm_input_outputs.jsonl")


@pytest.mark.asyncio(scope="session")
async def test_batching():
response = await llm_generation_chain(
template_file="tests/joke.prompt",
engine="gpt-35-turbo",
max_tokens=100,
temperature=0.1,
progress_bar_desc="test2",
).abatch([{"topic": "Life as a PhD student"}] * 10)
assert len(response) == 10
logger.info(response)

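Outside of pytest, the same option can be exercised from a plain asyncio entry point. A rough sketch, assuming a ChainLite config has already been loaded (the llm_config.yaml path mirrors the commented-out line above) and that tests/joke.prompt exists as in the tests:

import asyncio

from chainlite import llm_generation_chain, load_config_from_file

load_config_from_file("./llm_config.yaml")

async def main():
    # Ten concurrent generations; the tqdm bar advances once per finished LLM call.
    chain = llm_generation_chain(
        template_file="tests/joke.prompt",
        engine="gpt-35-turbo",
        max_tokens=100,
        temperature=0.1,
        progress_bar_desc="jokes",
    )
    responses = await chain.abatch([{"topic": "Life as a PhD student"}] * 10)
    print(len(responses))

asyncio.run(main())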