From 7283e51e4789b0ddc19050239428ae1c88d2ed60 Mon Sep 17 00:00:00 2001 From: Sina Date: Fri, 1 Nov 2024 08:11:42 +0000 Subject: [PATCH] Refactor: relocate ThreadSafeDict/pprint_chain; expose logprobs --- README.md | 3 ++- chainlite/__init__.py | 3 +-- chainlite/chat_lite_llm.py | 7 +++++++ chainlite/llm_config.py | 40 ++---------------------------------- chainlite/llm_generate.py | 13 ------------ chainlite/load_prompt.py | 4 ++++ chainlite/threadsafe_dict.py | 38 ++++++++++++++++++++++++++++++++++ chainlite/utils.py | 14 +++++++++++-- 8 files changed, 66 insertions(+), 56 deletions(-) create mode 100644 chainlite/threadsafe_dict.py diff --git a/README.md b/README.md index bd918a4..438bd55 100644 --- a/README.md +++ b/README.md @@ -50,9 +50,10 @@ llm_generation_chain( template_blocks: list[tuple[str]]=None, keep_indentation: bool = False, progress_bar_desc: Optional[str] = None, + additional_postprocessing_runnable: Runnable = None, tools: Optional[List[Callable]] = None, force_tool_calling: bool = False, - additional_postprocessing_runnable: Runnable = None, + return_top_logprobs: int = 0, bind_prompt_values: Dict = {}, ) # returns a LangChain chain the accepts inputs and returns a string as output load_config_from_file(config_file: str) diff --git a/chainlite/__init__.py b/chainlite/__init__.py index 9347beb..f84ac96 100644 --- a/chainlite/__init__.py +++ b/chainlite/__init__.py @@ -8,12 +8,11 @@ ) from .llm_generate import ( llm_generation_chain, - pprint_chain, write_prompt_logs_to_file, ToolOutput, ) from .load_prompt import register_prompt_constants -from .utils import get_logger +from .utils import get_logger, pprint_chain __all__ = [ diff --git a/chainlite/chat_lite_llm.py b/chainlite/chat_lite_llm.py index bd52a51..24ed614 100644 --- a/chainlite/chat_lite_llm.py +++ b/chainlite/chat_lite_llm.py @@ -305,6 +305,13 @@ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: generations.append(gen) token_usage = response.get("usage", {}) llm_output = {"token_usage": token_usage, "model": 
self.model} + if ( + "logprobs" in response["choices"][0] + and "content" in response["choices"][0]["logprobs"] + ): + llm_output["logprobs"] = ( + response["choices"][0].get("logprobs").get("content") + ) return ChatResult(generations=generations, llm_output=llm_output) def _create_message_dicts( diff --git a/chainlite/llm_config.py b/chainlite/llm_config.py index b1fbd38..b234842 100644 --- a/chainlite/llm_config.py +++ b/chainlite/llm_config.py @@ -10,6 +10,8 @@ from langchain_core.caches import RETURN_VAL_TYPE from langchain_core.load.dump import dumps +from chainlite.threadsafe_dict import ThreadSafeDict + from .load_prompt import initialize_jinja_environment # TODO move cache setting to the config file @@ -65,44 +67,6 @@ def _configure_pipeline_for_update( ) -class ThreadSafeDict: - def __init__(self): - self._dict = {} - self._lock = threading.Lock() - - def __setitem__(self, key, value): - with self._lock: - self._dict[key] = value - - def __getitem__(self, key): - with self._lock: - return self._dict[key] - - def __delitem__(self, key): - with self._lock: - del self._dict[key] - - def get(self, key, default=None): - with self._lock: - return self._dict.get(key, default) - - def __contains__(self, key): - with self._lock: - return key in self._dict - - def items(self): - with self._lock: - return list(self._dict.items()) - - def keys(self): - with self._lock: - return list(self._dict.keys()) - - def values(self): - with self._lock: - return list(self._dict.values()) - - class GlobalVars: prompt_logs = ThreadSafeDict() all_llm_endpoints = None diff --git a/chainlite/llm_generate.py b/chainlite/llm_generate.py index 4e14ded..0541576 100644 --- a/chainlite/llm_generate.py +++ b/chainlite/llm_generate.py @@ -28,10 +28,6 @@ from .load_prompt import load_fewshot_prompt_template from .utils import get_logger -logging.getLogger("LiteLLM").setLevel(logging.WARNING) -logging.getLogger("LiteLLM Router").setLevel(logging.WARNING) -logging.getLogger("LiteLLM 
Proxy").setLevel(logging.WARNING) -logging.getLogger("httpx").setLevel(logging.WARNING) logger = get_logger(__name__) # This regex pattern aims to capture up through the last '.', '!', '?', @@ -41,15 +37,6 @@ partial_sentence_regex = re.compile(r'([\s\S]*?[.!?]"?)(?=(?:[^.!?]*$))') -@chain -def pprint_chain(_dict: Any) -> Any: - """ - Print intermediate results for debugging - """ - pprint(_dict) - return _dict - - def is_same_prompt(template_name_1: str, template_name_2: str) -> bool: return os.path.basename(template_name_1) == os.path.basename(template_name_2) diff --git a/chainlite/load_prompt.py b/chainlite/load_prompt.py index a46e82f..c374c51 100644 --- a/chainlite/load_prompt.py +++ b/chainlite/load_prompt.py @@ -1,3 +1,7 @@ +""" +Functionality to work with .prompt files in Jinja2 format. +""" + import re from datetime import datetime from functools import lru_cache diff --git a/chainlite/threadsafe_dict.py b/chainlite/threadsafe_dict.py new file mode 100644 index 0000000..8de80b8 --- /dev/null +++ b/chainlite/threadsafe_dict.py @@ -0,0 +1,38 @@ +import threading + +class ThreadSafeDict: + def __init__(self): + self._dict = {} + self._lock = threading.Lock() + + def __setitem__(self, key, value): + with self._lock: + self._dict[key] = value + + def __getitem__(self, key): + with self._lock: + return self._dict[key] + + def __delitem__(self, key): + with self._lock: + del self._dict[key] + + def get(self, key, default=None): + with self._lock: + return self._dict.get(key, default) + + def __contains__(self, key): + with self._lock: + return key in self._dict + + def items(self): + with self._lock: + return list(self._dict.items()) + + def keys(self): + with self._lock: + return list(self._dict.keys()) + + def values(self): + with self._lock: + return list(self._dict.values()) \ No newline at end of file diff --git a/chainlite/utils.py b/chainlite/utils.py index ad7b8c5..d1499af 100644 --- a/chainlite/utils.py +++ b/chainlite/utils.py @@ -1,8 +1,9 @@ import 
asyncio import logging -from typing import Optional - +from typing import Any, Optional +import rich from tqdm import tqdm +from langchain_core.runnables import chain logging.getLogger("LiteLLM").setLevel(logging.WARNING) logging.getLogger("LiteLLM Router").setLevel(logging.WARNING) @@ -61,3 +62,12 @@ async def async_function_with_semaphore(f, i, original_index) -> tuple: ret[original_index] = result return ret + + +@chain +def pprint_chain(_dict: Any) -> Any: + """ + Print intermediate results for debugging + """ + rich.print(_dict) + return _dict