feat(llms): add support for HuggingFace models loaded locally (#61)
akotyla authored Jul 3, 2024
1 parent 6510bd8 commit 953d8a1
Showing 9 changed files with 252 additions and 13 deletions.
12 changes: 8 additions & 4 deletions benchmark/dbally_benchmark/e2e_benchmark.py
@@ -25,6 +25,7 @@
from dbally.collection.exceptions import NoViewFoundError
from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError
from dbally.llms.litellm import LiteLLM
from dbally.llms.local import LocalLLM
from dbally.view_selection.prompt import VIEW_SELECTION_TEMPLATE


@@ -82,10 +83,13 @@ async def evaluate(cfg: DictConfig) -> Any:

engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}")

llm = LiteLLM(
model_name="gpt-4",
api_key=benchmark_cfg.openai_api_key,
)
if cfg.model_name.startswith("local/"):
llm = LocalLLM(api_key=benchmark_cfg.hf_api_key, model_name=cfg.model_name.split("/", 1)[1])
else:
llm = LiteLLM(
model_name=cfg.model_name,
api_key=benchmark_cfg.openai_api_key,
)

db = dbally.create_collection(cfg.db_name, llm)

11 changes: 5 additions & 6 deletions benchmark/dbally_benchmark/iql_benchmark.py
@@ -23,6 +23,7 @@
from dbally.iql_generator.iql_generator import IQLGenerator
from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError
from dbally.llms.litellm import LiteLLM
from dbally.llms.local import LocalLLM
from dbally.views.structured import BaseStructuredView


@@ -96,13 +97,11 @@ async def evaluate(cfg: DictConfig) -> Any:
engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}")
view = VIEW_REGISTRY[ViewName(view_name)](engine)

if "gpt" in cfg.model_name:
llm = LiteLLM(
model_name=cfg.model_name,
api_key=benchmark_cfg.openai_api_key,
)
if cfg.model_name.startswith("local/"):
llm = LocalLLM(model_name=cfg.model_name.split("/", 1)[1], api_key=benchmark_cfg.hf_api_key)
else:
raise ValueError("Only OpenAI's GPT models are supported for now.")
llm = LiteLLM(api_key=benchmark_cfg.openai_api_key, model_name=cfg.model_name)

iql_generator = IQLGenerator(llm=llm)

run = None
7 changes: 5 additions & 2 deletions benchmark/dbally_benchmark/text2sql_benchmark.py
@@ -22,6 +22,7 @@

from dbally.audit.event_tracker import EventTracker
from dbally.llms.litellm import LiteLLM
from dbally.llms.local import LocalLLM


def _load_db_schema(db_name: str, encoding: Optional[str] = None) -> str:
@@ -84,10 +85,12 @@ async def evaluate(cfg: DictConfig) -> Any:

engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}")

if "gpt" in cfg.model_name:
if cfg.model_name.startswith("local/"):
llm = LocalLLM(model_name=cfg.model_name.split("/", 1)[1], api_key=benchmark_cfg.hf_api_key)
else:
llm = LiteLLM(
model_name=cfg.model_name,
api_key=benchmark_cfg.openai_api_key,
model_name=cfg.model_name,
)

run = None
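All three benchmark scripts now share the same selection convention: a `model_name` that starts with `local/` is routed to the new `LocalLLM`, while any other name keeps using `LiteLLM`. A minimal sketch of that dispatch (the `build_llm` helper and its argument names are illustrative, not part of the commit):

```python
from dbally.llms.litellm import LiteLLM
from dbally.llms.local import LocalLLM


def build_llm(model_name: str, hf_api_key: str, openai_api_key: str):
    """Pick an LLM implementation based on the benchmark's model_name setting."""
    if model_name.startswith("local/"):
        # "local/meta-llama/Meta-Llama-3-8B-Instruct" -> "meta-llama/Meta-Llama-3-8B-Instruct"
        return LocalLLM(model_name=model_name.split("/", 1)[1], api_key=hf_api_key)
    # Any other name (e.g. "gpt-4") keeps using the LiteLLM client.
    return LiteLLM(model_name=model_name, api_key=openai_api_key)
```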
66 changes: 66 additions & 0 deletions docs/how-to/llms/local.md
@@ -0,0 +1,66 @@
# How-To: Use Local LLMs

db-ally includes a ready-to-use implementation for local LLMs called [`LocalLLM`](../../reference/llms/local.md#dbally.llms.local.LocalLLM), which uses the Hugging Face Transformers library to run causal language models from the Hugging Face Hub directly on your machine.

## Basic Usage

Install the required dependencies for using local LLMs.

```bash
pip install dbally[local]
```

Integrate db-ally with your local LLM.

First, set up your environment to use a Hugging Face model.

```python

import os
from dbally.llms.local import LocalLLM

os.environ["HUGGINGFACE_API_KEY"] = "your-api-key"

llm = LocalLLM(model_name="meta-llama/Meta-Llama-3-8B-Instruct")
```

Use the LLM in your collection.

```python
import dbally

my_collection = dbally.create_collection("my_collection", llm)
response = await my_collection.ask("Which LLM should I use?")
```

## Advanced Usage

For advanced users, you can customize your LLM using [`LocalLLMOptions`](../../reference/llms/local.md#dbally.llms.clients.local.LocalLLMOptions). Here is a list of available parameters:

- `repetition_penalty`: *float or null (optional)* - Penalizes repeated tokens to avoid repetitions.
- `do_sample`: *bool or null (optional)* - Enables sampling instead of greedy decoding.
- `best_of`: *int or null (optional)* - Generates multiple sequences and returns the one with the highest score.
- `max_new_tokens`: *int (optional)* - The maximum number of new tokens to generate.
- `top_k`: *int or null (optional)* - Limits the next token choices to the top-k probability tokens.
- `top_p`: *float or null (optional)* - Limits the next token choices to tokens within the top-p probability mass.
- `seed`: *int or null (optional)* - Sets the seed for random number generation to ensure reproducibility.
- `stop_sequences`: *list of strings or null (optional)* - Specifies sequences where the generation should stop.
- `temperature`: *float or null (optional)* - Adjusts the randomness of token selection.

```python
import dbally
from dbally.llms.clients.local import LocalLLMOptions
from dbally.llms.local import LocalLLM

llm = LocalLLM("meta-llama/Meta-Llama-3-8B-Instruct", default_options=LocalLLMOptions(temperature=0.7))
my_collection = dbally.create_collection("my_collection", llm)
```
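The same pattern extends to the other options listed above, for example combining sampling controls with a fixed seed (the values below are illustrative):

```python
llm = LocalLLM(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    default_options=LocalLLMOptions(
        do_sample=True,       # enable sampling instead of greedy decoding
        temperature=0.7,
        top_p=0.9,
        max_new_tokens=256,
        seed=42,              # fix the seed for reproducible generations
    ),
)
my_collection = dbally.create_collection("my_collection", llm)
```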

You can also override any default parameter on the [`ask`](../../reference/collection.md#dbally.Collection.ask) call.

```python
response = await my_collection.ask(
question="Which LLM should I use?",
llm_options=LocalLLMOptions(
temperature=0.65,
),
)
```
7 changes: 7 additions & 0 deletions docs/reference/llms/local.md
@@ -0,0 +1,7 @@
# Local

::: dbally.llms.local.LocalLLM

::: dbally.llms.clients.local.LocalLLMClient

::: dbally.llms.clients.local.LocalLLMOptions
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -24,6 +24,7 @@ nav:
- how-to/views/few-shots.md
- Using LLMs:
- how-to/llms/litellm.md
- how-to/llms/local.md
- how-to/llms/custom.md
- Using similarity indexes:
- how-to/use_custom_similarity_fetcher.md
@@ -60,6 +61,7 @@ nav:
- LLMs:
- reference/llms/index.md
- reference/llms/litellm.md
- reference/llms/local.md
- reference/prompt.md
- Similarity:
- reference/similarity/index.md
5 changes: 4 additions & 1 deletion setup.cfg
@@ -69,7 +69,10 @@ elasticsearch =
gradio =
gradio~=4.31.5
gradio_client~=0.16.4

local =
accelerate~=0.31.0
torch~=2.2.1
transformers~=4.41.2

[options.packages.find]
where = src
95 changes: 95 additions & 0 deletions src/dbally/llms/clients/local.py
@@ -0,0 +1,95 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from dbally.audit.events import LLMEvent
from dbally.llms.clients.base import LLMClient, LLMOptions
from dbally.prompt.template import ChatFormat

from ..._types import NOT_GIVEN, NotGiven


@dataclass
class LocalLLMOptions(LLMOptions):
"""
Dataclass that represents all available LLM call options for the local LLM client.
Each of them is described in the [HuggingFace documentation]
(https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation). # pylint: disable=line-too-long
"""

repetition_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN
do_sample: Union[Optional[bool], NotGiven] = NOT_GIVEN
best_of: Union[Optional[int], NotGiven] = NOT_GIVEN
max_new_tokens: Union[Optional[int], NotGiven] = NOT_GIVEN
top_k: Union[Optional[int], NotGiven] = NOT_GIVEN
top_p: Union[Optional[float], NotGiven] = NOT_GIVEN
seed: Union[Optional[int], NotGiven] = NOT_GIVEN
stop_sequences: Union[Optional[List[str]], NotGiven] = NOT_GIVEN
temperature: Union[Optional[float], NotGiven] = NOT_GIVEN


class LocalLLMClient(LLMClient[LocalLLMOptions]):
"""
Client for the local LLM that supports Hugging Face models.
"""

_options_cls = LocalLLMOptions

def __init__(
self,
model_name: str,
*,
hf_api_key: Optional[str] = None,
) -> None:
"""
Constructs a new local LLMClient instance.
Args:
model_name: Name of the model to use.
hf_api_key: The Hugging Face API key for authentication.
"""

super().__init__(model_name)

self.model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", torch_dtype=torch.bfloat16, token=hf_api_key
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_api_key)

async def call(
self,
conversation: ChatFormat,
options: LocalLLMOptions,
event: LLMEvent,
json_mode: bool = False,
) -> str:
"""
Makes a call to the local LLM with the provided prompt and options.
Args:
conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
options: Additional settings used by the LLM.
event: Container with the prompt, LLM response, and call metrics.
json_mode: Force the response to be in JSON format.
Returns:
Response string from LLM.
"""

input_ids = self.tokenizer.apply_chat_template(
conversation, add_generation_prompt=True, return_tensors="pt"
).to(self.model.device)

outputs = self.model.generate(
input_ids,
eos_token_id=self.tokenizer.eos_token_id,
**options.dict(),
)
response = outputs[0][input_ids.shape[-1] :]
event.completion_tokens = len(response)
event.prompt_tokens = input_ids.shape[-1]
event.total_tokens = event.prompt_tokens + event.completion_tokens
decoded_response = self.tokenizer.decode(response, skip_special_tokens=True)
return decoded_response
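The generation path in `LocalLLMClient.call` follows the standard Transformers chat workflow: apply the model's chat template, generate, and decode only the newly produced tokens. A standalone sketch of that flow outside db-ally (model name, prompt, and `max_new_tokens` value are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.bfloat16
)

conversation = [{"role": "user", "content": "Which LLM should I use?"}]
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(input_ids, max_new_tokens=128, eos_token_id=tokenizer.eos_token_id)
# Decode only the tokens produced after the prompt.
print(tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))
```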
60 changes: 60 additions & 0 deletions src/dbally/llms/local.py
@@ -0,0 +1,60 @@
from functools import cached_property
from typing import Optional

from transformers import AutoTokenizer

from dbally.llms.base import LLM
from dbally.llms.clients.local import LocalLLMClient, LocalLLMOptions
from dbally.prompt.template import PromptTemplate


class LocalLLM(LLM[LocalLLMOptions]):
"""
Class for interaction with any LLM available in HuggingFace.
"""

_options_cls = LocalLLMOptions

def __init__(
self,
model_name: str,
default_options: Optional[LocalLLMOptions] = None,
*,
api_key: Optional[str] = None,
) -> None:
"""
Constructs a new local LLM instance.
Args:
model_name: Name of the model to use. This should be a model from the CausalLM class.
default_options: Default options for the LLM.
api_key: The API key for Hugging Face authentication.
"""

super().__init__(model_name, default_options)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key)
self.api_key = api_key

@cached_property
def client(self) -> LocalLLMClient:
"""
Client for the LLM.
Returns:
The client used to interact with the LLM.
"""
return LocalLLMClient(model_name=self.model_name, hf_api_key=self.api_key)

def count_tokens(self, prompt: PromptTemplate) -> int:
"""
Counts tokens in the messages.
Args:
prompt: Messages to count tokens for.
Returns:
Number of tokens in the messages.
"""

input_ids = self.tokenizer.apply_chat_template(prompt.chat)
return len(input_ids)
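`count_tokens` relies on the same chat template as generation: calling `apply_chat_template` without `return_tensors` returns a plain list of token ids, and its length is the prompt's token count. A small standalone check (model name and messages are illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
chat = [
    {"role": "system", "content": "You translate questions into IQL."},
    {"role": "user", "content": "Which LLM should I use?"},
]
# Without return_tensors, apply_chat_template returns a plain list of token ids.
token_ids = tokenizer.apply_chat_template(chat)
print(len(token_ids))  # prompt token count, including template special tokens
```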
