Commit: Refactor code
s-jse committed Nov 1, 2024
1 parent 0518517 commit 7283e51
Showing 8 changed files with 66 additions and 56 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -50,9 +50,10 @@ llm_generation_chain(
template_blocks: list[tuple[str]]=None,
keep_indentation: bool = False,
progress_bar_desc: Optional[str] = None,
additional_postprocessing_runnable: Runnable = None,
tools: Optional[List[Callable]] = None,
force_tool_calling: bool = False,
additional_postprocessing_runnable: Runnable = None,
return_top_logprobs: int = 0,
bind_prompt_values: Dict = {},
) # returns a LangChain chain that accepts inputs and returns a string as output
load_config_from_file(config_file: str)
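
For context, a minimal usage sketch of this API. The collapsed upper part of the signature is not shown in the diff, so the `template_file`, `engine`, and `max_tokens` arguments below are assumptions, and `llm_config.yaml` and `joke.prompt` are hypothetical, pre-configured files:

```python
import asyncio

from chainlite import llm_generation_chain, load_config_from_file

load_config_from_file("./llm_config.yaml")  # hypothetical endpoint/prompt config

# Build a LangChain chain that accepts a dict of template variables
# and returns the LLM output as a string.
chain = llm_generation_chain(
    template_file="joke.prompt",  # hypothetical .prompt file
    engine="gpt-4o",              # assumed parameter name
    max_tokens=100,
)
print(asyncio.run(chain.ainvoke({"topic": "code refactoring"})))
```
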
3 changes: 1 addition & 2 deletions chainlite/__init__.py
@@ -8,12 +8,11 @@
)
from .llm_generate import (
llm_generation_chain,
pprint_chain,
write_prompt_logs_to_file,
ToolOutput,
)
from .load_prompt import register_prompt_constants
from .utils import get_logger
from .utils import get_logger, pprint_chain


__all__ = [
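
The net effect is that `pprint_chain` now lives in `chainlite.utils`, but the package-level import path is unchanged. A sketch, assuming these names remain listed in `__all__`:

```python
# Package-level imports are unaffected by the move.
from chainlite import pprint_chain, ToolOutput, llm_generation_chain
```
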
7 changes: 7 additions & 0 deletions chainlite/chat_lite_llm.py
@@ -305,6 +305,13 @@ def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations.append(gen)
token_usage = response.get("usage", {})
llm_output = {"token_usage": token_usage, "model": self.model}
if (
"logprobs" in response["choices"][0]
and "content" in response["choices"][0]["logprobs"]
):
llm_output["logprobs"] = (
response["choices"][0].get("logprobs").get("content")
)
return ChatResult(generations=generations, llm_output=llm_output)

def _create_message_dicts(
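
The new branch copies OpenAI-style token logprobs into `llm_output` when they are present. A standalone sketch of the response shape it guards against, with made-up values rather than a real API response:

```python
# Illustrative OpenAI-style response dict; only the keys the new code
# touches are shown.
response = {
    "choices": [
        {
            "logprobs": {
                "content": [
                    {"token": "Hello", "logprob": -0.08},
                    {"token": "!", "logprob": -0.51},
                ]
            }
        }
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 2},
}

llm_output = {"token_usage": response.get("usage", {}), "model": "gpt-4o"}
if (
    "logprobs" in response["choices"][0]
    and "content" in response["choices"][0]["logprobs"]
):
    llm_output["logprobs"] = response["choices"][0]["logprobs"]["content"]

print(llm_output["logprobs"][0])  # {'token': 'Hello', 'logprob': -0.08}
```
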
40 changes: 2 additions & 38 deletions chainlite/llm_config.py
@@ -10,6 +10,8 @@
from langchain_core.caches import RETURN_VAL_TYPE
from langchain_core.load.dump import dumps

from chainlite.threadsafe_dict import ThreadSafeDict

from .load_prompt import initialize_jinja_environment

# TODO move cache setting to the config file
@@ -65,44 +67,6 @@ def _configure_pipeline_for_update(
)


class ThreadSafeDict:
def __init__(self):
self._dict = {}
self._lock = threading.Lock()

def __setitem__(self, key, value):
with self._lock:
self._dict[key] = value

def __getitem__(self, key):
with self._lock:
return self._dict[key]

def __delitem__(self, key):
with self._lock:
del self._dict[key]

def get(self, key, default=None):
with self._lock:
return self._dict.get(key, default)

def __contains__(self, key):
with self._lock:
return key in self._dict

def items(self):
with self._lock:
return list(self._dict.items())

def keys(self):
with self._lock:
return list(self._dict.keys())

def values(self):
with self._lock:
return list(self._dict.values())


class GlobalVars:
prompt_logs = ThreadSafeDict()
all_llm_endpoints = None
13 changes: 0 additions & 13 deletions chainlite/llm_generate.py
@@ -28,10 +28,6 @@
from .load_prompt import load_fewshot_prompt_template
from .utils import get_logger

logging.getLogger("LiteLLM").setLevel(logging.WARNING)
logging.getLogger("LiteLLM Router").setLevel(logging.WARNING)
logging.getLogger("LiteLLM Proxy").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = get_logger(__name__)

# This regex pattern aims to capture up through the last '.', '!', '?',
@@ -41,15 +37,6 @@
partial_sentence_regex = re.compile(r'([\s\S]*?[.!?]"?)(?=(?:[^.!?]*$))')


@chain
def pprint_chain(_dict: Any) -> Any:
"""
Print intermediate results for debugging
"""
pprint(_dict)
return _dict


def is_same_prompt(template_name_1: str, template_name_2: str) -> bool:
return os.path.basename(template_name_1) == os.path.basename(template_name_2)

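
To illustrate the retained `partial_sentence_regex`: it lazily matches up through the last '.', '!', or '?' (optionally followed by a closing quote), which makes it useful for trimming an unfinished trailing sentence from streamed output. A standalone sketch:

```python
import re

partial_sentence_regex = re.compile(r'([\s\S]*?[.!?]"?)(?=(?:[^.!?]*$))')

streamed = "First sentence. Second sentence! And an unfinished thou"
match = partial_sentence_regex.match(streamed)
if match:
    # Keeps everything up through the last complete sentence.
    print(match.group(1))  # First sentence. Second sentence!
```
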
4 changes: 4 additions & 0 deletions chainlite/load_prompt.py
@@ -1,3 +1,7 @@
"""
Functionality to work with .prompt files in Jinja2 format.
"""

import re
from datetime import datetime
from functools import lru_cache
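
As background for the new module docstring, plain Jinja2 rendering looks like this. This is generic `jinja2` usage, not chainlite's actual loader, and the template text is made up:

```python
from jinja2 import Environment

env = Environment()
template = env.from_string(
    "Today is {{ today }}. Write a one-line joke about {{ topic }}."
)
print(template.render(today="November 1, 2024", topic="refactoring"))
```
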
38 changes: 38 additions & 0 deletions chainlite/threadsafe_dict.py
@@ -0,0 +1,38 @@
import threading

class ThreadSafeDict:
def __init__(self):
self._dict = {}
self._lock = threading.Lock()

def __setitem__(self, key, value):
with self._lock:
self._dict[key] = value

def __getitem__(self, key):
with self._lock:
return self._dict[key]

def __delitem__(self, key):
with self._lock:
del self._dict[key]

def get(self, key, default=None):
with self._lock:
return self._dict.get(key, default)

def __contains__(self, key):
with self._lock:
return key in self._dict

def items(self):
with self._lock:
return list(self._dict.items())

def keys(self):
with self._lock:
return list(self._dict.keys())

def values(self):
with self._lock:
return list(self._dict.values())
14 changes: 12 additions & 2 deletions chainlite/utils.py
@@ -1,8 +1,9 @@
import asyncio
import logging
from typing import Optional

from typing import Any, Optional
import rich
from tqdm import tqdm
from langchain_core.runnables import chain

logging.getLogger("LiteLLM").setLevel(logging.WARNING)
logging.getLogger("LiteLLM Router").setLevel(logging.WARNING)
@@ -61,3 +62,12 @@ async def async_function_with_semaphore(f, i, original_index) -> tuple:
ret[original_index] = result

return ret


@chain
def pprint_chain(_dict: Any) -> Any:
"""
Print intermediate results for debugging
"""
rich.print(_dict)
return _dict
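
Because `pprint_chain` is decorated with `@chain`, it is a Runnable and composes with `|`. A minimal sketch that needs no LLM endpoint:

```python
from langchain_core.runnables import RunnablePassthrough

from chainlite import pprint_chain

# pprint_chain pretty-prints whatever flows through it (now via rich)
# and passes the value along unchanged.
debug_chain = RunnablePassthrough() | pprint_chain
debug_chain.invoke({"stage": "inspecting intermediate output"})
```
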
