From 175f71ab60f4b01cd4fcd9ff3ea09d5158965963 Mon Sep 17 00:00:00 2001 From: Sina Date: Fri, 31 May 2024 02:36:20 +0000 Subject: [PATCH] Add option to show progress bar for `llm_generation_chain`s --- .gitignore | 1 + README.md | 2 +- chainlite/llm_generate.py | 47 +++++++++++++++++++++++++++++--------- setup.py | 4 ++-- tests/test_llm_generate.py | 22 +++++++++++++++--- 5 files changed, 59 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 68b9c94..2bc5089 100644 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,5 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +*.jsonl API_KEYS \ No newline at end of file diff --git a/README.md b/README.md index d5225db..bb40a86 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ ## Installation -To install ChainLite, use the following steps: +ChainLite has been tested with Python 3.10. To install, do the following: 1. Install ChainLite via pip: diff --git a/chainlite/llm_generate.py b/chainlite/llm_generate.py index aa87a93..0d09b37 100644 --- a/chainlite/llm_generate.py +++ b/chainlite/llm_generate.py @@ -5,26 +5,26 @@ import json import logging import os -from pprint import pprint import random import re -from typing import AsyncIterator, Optional, Any +from pprint import pprint +from typing import Any, AsyncIterator, Optional from uuid import UUID +from langchain_community.chat_models import ChatLiteLLM +from langchain_core.callbacks import AsyncCallbackHandler +from langchain_core.messages import BaseMessage from langchain_core.output_parsers import StrOutputParser +from langchain_core.outputs import LLMResult +from langchain_core.runnables import Runnable, chain + +from tqdm.auto import tqdm from chainlite.llm_config import GlobalVars from .load_prompt import load_fewshot_prompt_template -from langchain_community.chat_models import ChatLiteLLM -from langchain_core.callbacks import AsyncCallbackHandler -from langchain_core.outputs import LLMResult -from langchain_core.messages import BaseMessage -from langchain_core.runnables import chain, Runnable - from .utils import get_logger - logging.getLogger("LiteLLM").setLevel(logging.WARNING) logging.getLogger("LiteLLM Router").setLevel(logging.WARNING) logging.getLogger("LiteLLM Proxy").setLevel(logging.WARNING) @@ -123,6 +123,25 @@ async def on_llm_end( GlobalVars.prompt_logs[run_id]["output"] = llm_output +class ProgbarCallback(AsyncCallbackHandler): + def __init__(self, desc: str, total: int = None): + super().__init__() + self.count = 0 + self.progress_bar = tqdm(total=total, desc=desc) # define a progress bar + + # Override on_llm_end method. This is called after every response from LLM + def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> Any: + self.count += 1 + self.progress_bar.update(1) + + prompt_log_handler = PromptLogHandler() @@ -218,6 +237,7 @@ def llm_generation_chain( output_json: bool = False, keep_indentation: bool = False, postprocess: bool = False, + progress_bar_desc: Optional[str] = None, bind_prompt_values: dict = {}, ) -> Runnable: """ @@ -233,7 +253,8 @@ def llm_generation_chain( output_json (bool, optional): If True, asks the LLM API to output a JSON. This depends on the underlying model to support. For example, GPT-4, GPT-4o and newer GPT-3.5-Turbo models support it, but require the word "json" to be present in the input. Defaults to False. keep_indentation (bool, optional): If True, will keep indentations at the beginning of each line in the template_file. Defaults to False. - postprocess (bool, optional): If true, postprocessing deletes incomplete sentences from the end of the generation. Defaults to False. + postprocess (bool, optional): If True, postprocessing deletes incomplete sentences from the end of the generation. Defaults to False. + progress_bar_name (str, optional): If provided, will display a `tqdm` progress bar using this name bind_prompt_values (dict, optional): A dictionary containing {Variable: str : Value}. Binds values to the prompt. Additional variables can be provided when the chain is called. Defaults to {}. Returns: @@ -290,6 +311,10 @@ def llm_generation_chain( if output_json: model_kwargs["response_format"] = {"type": "json_object"} + callbacks = [prompt_log_handler] + if progress_bar_desc: + cb = ProgbarCallback(progress_bar_desc) + callbacks.append(cb) llm = ChatLiteLLM( model_kwargs=model_kwargs, api_base=llm_resource["api_base"] if "api_base" in llm_resource else None, @@ -302,7 +327,7 @@ def llm_generation_chain( "distillation_instruction": distillation_instruction, "template_name": os.path.basename(template_file), }, # for logging to file - callbacks=[prompt_log_handler], + callbacks=callbacks, ) # for variable, value in bind_prompt_values.keys(): diff --git a/setup.py b/setup.py index 1f51129..781c0cc 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="chainlite", - version="0.1.11", + version="0.1.12", author="Sina Semnani", author_email="sinaj@cs.stanford.edu", description="A Python package that uses LangChain and LiteLLM to call large language model APIs easily", @@ -24,7 +24,7 @@ "redis[hiredis]", ], extras_require={ - "dev": ["invoke", "pytest", "pytest-asyncio", "setuptools", "wheel", "twine"], + "dev": ["invoke", "pytest", "pytest-asyncio", "setuptools", "wheel", "twine", "isort"], }, classifiers=[ "Programming Language :: Python :: 3", diff --git a/tests/test_llm_generate.py b/tests/test_llm_generate.py index 4f96fc0..dae808d 100644 --- a/tests/test_llm_generate.py +++ b/tests/test_llm_generate.py @@ -2,7 +2,7 @@ from chainlite import llm_generation_chain, load_config_from_file from chainlite.llm_config import GlobalVars -from chainlite.llm_generate import write_prompt_logs_to_file +from chainlite.llm_generate import ProgbarCallback, write_prompt_logs_to_file from chainlite.utils import get_logger logger = get_logger(__name__) @@ -10,6 +10,7 @@ # load_config_from_file("./llm_config.yaml") + @pytest.mark.asyncio(scope="session") async def test_llm_generate(): # Check that the config file has been loaded properly @@ -20,7 +21,7 @@ async def test_llm_generate(): assert GlobalVars.local_engine_set response = await llm_generation_chain( - template_file="test.prompt", # prompt path relative to one of the paths specified in `prompt_dirs` + template_file="test.prompt", # prompt path relative to one of the paths specified in `prompt_dirs` engine="gpt-4o", max_tokens=100, ).ainvoke({}) @@ -37,7 +38,22 @@ async def test_readme_example(): template_file="tests/joke.prompt", engine="gpt-35-turbo", max_tokens=100, + temperature=0.1, + progress_bar_desc="test1", ).ainvoke({"topic": "Life as a PhD student"}) logger.info(response) - write_prompt_logs_to_file("llm_input_outputs.jsonl") \ No newline at end of file + write_prompt_logs_to_file("tests/llm_input_outputs.jsonl") + + +@pytest.mark.asyncio(scope="session") +async def test_batching(): + response = await llm_generation_chain( + template_file="tests/joke.prompt", + engine="gpt-35-turbo", + max_tokens=100, + temperature=0.1, + progress_bar_desc="test2", + ).abatch([{"topic": "Life as a PhD student"}] * 10) + assert len(response) == 10 + logger.info(response)