feat(llms): add support for HuggingFace models loaded locally (#61)
Showing 9 changed files with 252 additions and 13 deletions.
@@ -0,0 +1,66 @@
# How-To: Use Local LLMs

db-ally includes a ready-to-use implementation for local LLMs called [`LocalLLM`](../../reference/llms/local.md#dbally.llms.local.LocalLLM), which leverages the Hugging Face Transformers library to provide access to various LLMs available on Hugging Face.

## Basic Usage

Install the required dependencies for using local LLMs.

```bash
pip install dbally[local]
```

Integrate db-ally with your local LLM.

First, set up your environment to use a Hugging Face model.

```python
import os

from dbally.llms.local import LocalLLM

os.environ["HUGGINGFACE_API_KEY"] = "your-api-key"

llm = LocalLLM(model_name="meta-llama/Meta-Llama-3-8B-Instruct")
```

Then, use the LLM in your collection.

```python
import dbally

my_collection = dbally.create_collection("my_collection", llm)
response = await my_collection.ask("Which LLM should I use?")
```

## Advanced Usage

For advanced users, you can customize your LLM using [`LocalLLMOptions`](../../reference/llms/local.md#dbally.llms.clients.local.LocalLLMOptions). Here is a list of the available parameters:

- `repetition_penalty`: *float or null (optional)* - Penalizes repeated tokens to avoid repetitions.
- `do_sample`: *bool or null (optional)* - Enables sampling instead of greedy decoding.
- `best_of`: *int or null (optional)* - Generates multiple sequences and returns the one with the highest score.
- `max_new_tokens`: *int (optional)* - The maximum number of new tokens to generate.
- `top_k`: *int or null (optional)* - Limits the next-token choices to the k most probable tokens.
- `top_p`: *float or null (optional)* - Limits the next-token choices to tokens within the top-p probability mass.
- `seed`: *int or null (optional)* - Sets the seed for random number generation to ensure reproducibility.
- `stop_sequences`: *list of strings or null (optional)* - Specifies sequences at which generation should stop.
- `temperature`: *float or null (optional)* - Adjusts the randomness of token selection.

```python
import dbally
from dbally.llms.local import LocalLLM
from dbally.llms.clients.local import LocalLLMOptions

llm = LocalLLM("meta-llama/Meta-Llama-3-8B-Instruct", default_options=LocalLLMOptions(temperature=0.7))
my_collection = dbally.create_collection("my_collection", llm)
```
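
As a sketch, several of these options can also be combined via `default_options`; the values below are illustrative, not recommended settings.

```python
import dbally
from dbally.llms.local import LocalLLM
from dbally.llms.clients.local import LocalLLMOptions

# Illustrative sampling configuration: enable sampling, restrict choices to the
# top-p nucleus, soften the distribution, and cap the completion length.
options = LocalLLMOptions(
    do_sample=True,
    top_p=0.9,
    temperature=0.7,
    max_new_tokens=256,
)

llm = LocalLLM("meta-llama/Meta-Llama-3-8B-Instruct", default_options=options)
my_collection = dbally.create_collection("my_collection", llm)
```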

You can also override any default parameter on the [`ask`](../../reference/collection.md#dbally.Collection.ask) call.

```python
response = await my_collection.ask(
    question="Which LLM should I use?",
    llm_options=LocalLLMOptions(
        temperature=0.65,
    ),
)
```
@@ -0,0 +1,7 @@
# Local

::: dbally.llms.local.LocalLLM

::: dbally.llms.clients.local.LocalLLMClient

::: dbally.llms.clients.local.LocalLLMOptions
@@ -0,0 +1,95 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from dbally.audit.events import LLMEvent
from dbally.llms.clients.base import LLMClient, LLMOptions
from dbally.prompt.template import ChatFormat

from ..._types import NOT_GIVEN, NotGiven


@dataclass
class LocalLLMOptions(LLMOptions):
    """
    Dataclass that represents all available LLM call options for the local LLM client.
    Each of them is described in the [HuggingFace documentation]
    (https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation). # pylint: disable=line-too-long
    """

    repetition_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN
    do_sample: Union[Optional[bool], NotGiven] = NOT_GIVEN
    best_of: Union[Optional[int], NotGiven] = NOT_GIVEN
    max_new_tokens: Union[Optional[int], NotGiven] = NOT_GIVEN
    top_k: Union[Optional[int], NotGiven] = NOT_GIVEN
    top_p: Union[Optional[float], NotGiven] = NOT_GIVEN
    seed: Union[Optional[int], NotGiven] = NOT_GIVEN
    stop_sequences: Union[Optional[List[str]], NotGiven] = NOT_GIVEN
    temperature: Union[Optional[float], NotGiven] = NOT_GIVEN


class LocalLLMClient(LLMClient[LocalLLMOptions]):
    """
    Client for the local LLM that supports Hugging Face models.
    """

    _options_cls = LocalLLMOptions

    def __init__(
        self,
        model_name: str,
        *,
        hf_api_key: Optional[str] = None,
    ) -> None:
        """
        Constructs a new local LLMClient instance.

        Args:
            model_name: Name of the model to use.
            hf_api_key: The Hugging Face API key for authentication.
        """
        super().__init__(model_name)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype=torch.bfloat16, token=hf_api_key
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_api_key)

    async def call(
        self,
        conversation: ChatFormat,
        options: LocalLLMOptions,
        event: LLMEvent,
        json_mode: bool = False,
    ) -> str:
        """
        Makes a call to the local LLM with the provided prompt and options.

        Args:
            conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
            options: Additional settings used by the LLM.
            event: Container with the prompt, LLM response, and call metrics.
            json_mode: Force the response to be in JSON format.

        Returns:
            Response string from LLM.
        """
        # Render the conversation into model-ready token ids using the model's chat template.
        input_ids = self.tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)

        outputs = self.model.generate(
            input_ids,
            eos_token_id=self.tokenizer.eos_token_id,
            **options.dict(),
        )

        # The generated sequence contains the prompt followed by the completion;
        # slice off the prompt so only the newly generated tokens are decoded.
        response = outputs[0][input_ids.shape[-1] :]
        event.completion_tokens = len(response)
        event.prompt_tokens = input_ids.shape[-1]
        event.total_tokens = event.prompt_tokens + event.completion_tokens

        return self.tokenizer.decode(response, skip_special_tokens=True)
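
For context, below is a standalone sketch (plain Transformers, outside db-ally) of the prompt-to-completion flow that `LocalLLMClient.call` wraps; the model name, question, and `max_new_tokens` value are illustrative, and gated models such as Llama 3 require an authorized Hugging Face token.

```python
# Standalone illustration of the flow wrapped by LocalLLMClient.call (assumed example model).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.bfloat16
)

conversation = [{"role": "user", "content": "Which LLM should I use?"}]

# Render the chat into prompt token ids using the model's chat template.
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# Generate, then keep only the newly generated tokens after the prompt.
outputs = model.generate(input_ids, eos_token_id=tokenizer.eos_token_id, max_new_tokens=128)
completion = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(completion, skip_special_tokens=True))
```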
@@ -0,0 +1,60 @@ | ||
from functools import cached_property | ||
from typing import Optional | ||
|
||
from transformers import AutoTokenizer | ||
|
||
from dbally.llms.base import LLM | ||
from dbally.llms.clients.local import LocalLLMClient, LocalLLMOptions | ||
from dbally.prompt.template import PromptTemplate | ||
|
||
|
||
class LocalLLM(LLM[LocalLLMOptions]): | ||
""" | ||
Class for interaction with any LLM available in HuggingFace. | ||
""" | ||
|
||
_options_cls = LocalLLMOptions | ||
|
||
def __init__( | ||
self, | ||
model_name: str, | ||
default_options: Optional[LocalLLMOptions] = None, | ||
*, | ||
api_key: Optional[str] = None, | ||
) -> None: | ||
""" | ||
Constructs a new local LLM instance. | ||
Args: | ||
model_name: Name of the model to use. This should be a model from the CausalLM class. | ||
default_options: Default options for the LLM. | ||
api_key: The API key for Hugging Face authentication. | ||
""" | ||
|
||
super().__init__(model_name, default_options) | ||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key) | ||
self.api_key = api_key | ||
|
||
@cached_property | ||
def client(self) -> LocalLLMClient: | ||
""" | ||
Client for the LLM. | ||
Returns: | ||
The client used to interact with the LLM. | ||
""" | ||
return LocalLLMClient(model_name=self.model_name, hf_api_key=self.api_key) | ||
|
||
def count_tokens(self, prompt: PromptTemplate) -> int: | ||
""" | ||
Counts tokens in the messages. | ||
Args: | ||
prompt: Messages to count tokens for. | ||
Returns: | ||
Number of tokens in the messages. | ||
""" | ||
|
||
input_ids = self.tokenizer.apply_chat_template(prompt.chat) | ||
return len(input_ids) |