community[minor]: llamafile embeddings support (#17976)
* **Description:** adds `LlamafileEmbeddings` class implementation for generating embeddings using [llamafile](https://github.com/Mozilla-Ocho/llamafile)-based models. Includes related unit tests and a notebook showing example usage.
* **Issue:** N/A
* **Dependencies:** N/A
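As context for the diff below, a minimal usage sketch (assuming a llamafile server is already running locally with the `--embedding` flag, exactly as the notebook below sets up):

from langchain_community.embeddings import LlamafileEmbeddings

# Assumes a llamafile server is listening at the default http://localhost:8080,
# started with the --embedding option.
embedder = LlamafileEmbeddings()
vector = embedder.embed_query("This is a test document.")
print(len(vector))  # dimensionality depends on the model being served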
Showing 6 changed files with 352 additions and 0 deletions.
@@ -0,0 +1,157 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "278b6c63",
   "metadata": {},
   "source": [
    "# llamafile\n",
    "\n",
    "Let's load the [llamafile](https://github.com/Mozilla-Ocho/llamafile) Embeddings class.\n",
    "\n",
    "## Setup\n",
    "\n",
    "First, there are 3 setup steps:\n",
    "\n",
    "1. Download a llamafile. In this notebook, we use `TinyLlama-1.1B-Chat-v1.0.Q5_K_M` but there are many others available on [HuggingFace](https://huggingface.co/models?other=llamafile).\n",
    "2. Make the llamafile executable.\n",
    "3. Start the llamafile in server mode.\n",
    "\n",
    "You can run the following bash script to do all this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43ef6dfa-9cc4-4552-8a53-5df523afae7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# llamafile setup\n",
    "\n",
    "# Step 1: Download a llamafile. The download may take several minutes.\n",
    "wget -nv -nc https://huggingface.co/jartine/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile\n",
    "\n",
    "# Step 2: Make the llamafile executable. Note: if you're on Windows, just append '.exe' to the filename.\n",
    "chmod +x TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile\n",
    "\n",
    "# Step 3: Start llamafile server in background. All the server logs will be written to 'tinyllama.log'.\n",
    "# Alternatively, you can just open a separate terminal outside this notebook and run: \n",
    "# ./TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile --server --nobrowser --embedding\n",
    "./TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile --server --nobrowser --embedding > tinyllama.log 2>&1 &\n",
    "pid=$!\n",
    "echo \"${pid}\" > .llamafile_pid # write the process pid to a file so we can terminate the server later"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3188b22f-879f-47b3-9a27-24412f6fad5f",
   "metadata": {},
   "source": [
    "## Embedding texts using LlamafileEmbeddings\n",
    "\n",
    "Now, we can use the `LlamafileEmbeddings` class to interact with the llamafile server that's currently serving our TinyLlama model at http://localhost:8080."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0be1af71",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.embeddings import LlamafileEmbeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c66e5da",
   "metadata": {},
   "outputs": [],
   "source": [
    "embedder = LlamafileEmbeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01370375",
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"This is a test document.\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a42e4035",
   "metadata": {},
   "source": [
    "To generate embeddings, you can either query an individual text, or you can query a list of texts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91bc875d-829b-4c3d-8e6f-fc2dda30a3bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_result = embedder.embed_query(text)\n",
    "query_result[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4b0d49e-0c73-44b6-aed5-5b426564e085",
   "metadata": {},
   "outputs": [],
   "source": [
    "doc_result = embedder.embed_documents([text])\n",
    "doc_result[0][:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ccc78fc-03ae-411d-ae73-74a4ee91c725",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# cleanup: kill the llamafile server process\n",
    "kill $(cat .llamafile_pid)\n",
    "rm .llamafile_pid"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "e971737741ff4ec9aff7dc6155a1060a59a8a6d52c757dbbe66bf8ee389494b1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
libs/community/langchain_community/embeddings/llamafile.py
119 additions & 0 deletions
@@ -0,0 +1,119 @@
import logging
from typing import List, Optional

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel

logger = logging.getLogger(__name__)


class LlamafileEmbeddings(BaseModel, Embeddings):
    """Llamafile lets you distribute and run large language models with a
    single file.

    To get started, see: https://github.com/Mozilla-Ocho/llamafile

    To use this class, you will need to first:

    1. Download a llamafile.
    2. Make the downloaded file executable: `chmod +x path/to/model.llamafile`
    3. Start the llamafile in server mode with embeddings enabled:

        `./path/to/model.llamafile --server --nobrowser --embedding`

    Example:
        .. code-block:: python

            from langchain_community.embeddings import LlamafileEmbeddings

            embedder = LlamafileEmbeddings()
            doc_embeddings = embedder.embed_documents(
                [
                    "Alpha is the first letter of the Greek alphabet",
                    "Beta is the second letter of the Greek alphabet",
                ]
            )
            query_embedding = embedder.embed_query(
                "What is the second letter of the Greek alphabet"
            )
    """

    base_url: str = "http://localhost:8080"
    """Base url where the llamafile server is listening."""

    request_timeout: Optional[int] = None
    """Timeout for server requests"""

    def _embed(self, text: str) -> List[float]:
        try:
            response = requests.post(
                url=f"{self.base_url}/embedding",
                headers={
                    "Content-Type": "application/json",
                },
                json={
                    "content": text,
                },
                timeout=self.request_timeout,
            )
        except requests.exceptions.ConnectionError:
            raise requests.exceptions.ConnectionError(
                f"Could not connect to Llamafile server. Please make sure "
                f"that a server is running at {self.base_url}."
            )

        # Raise exception if we got a bad (non-200) response status code
        response.raise_for_status()

        contents = response.json()
        if "embedding" not in contents:
            raise KeyError(
                "Unexpected output from /embedding endpoint, output dict "
                "missing 'embedding' key."
            )

        embedding = contents["embedding"]

        # Sanity check the embedding vector:
        # Prior to llamafile v0.6.2, if the server was not started with the
        # `--embedding` option, the embedding endpoint would always return a
        # 0-vector. See issue:
        # https://github.com/Mozilla-Ocho/llamafile/issues/243
        # So here we raise an exception if the vector sums to exactly 0.
        if sum(embedding) == 0.0:
            raise ValueError(
                "Embedding sums to 0, did you start the llamafile server with "
                "the `--embedding` option enabled?"
            )

        return embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents using a llamafile server running at `self.base_url`.
        llamafile server should be started in a separate process before invoking
        this method.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        doc_embeddings = []
        for text in texts:
            doc_embeddings.append(self._embed(text))
        return doc_embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using a llamafile server running at `self.base_url`.
        llamafile server should be started in a separate process before invoking
        this method.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._embed(text)
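The `base_url` and `request_timeout` fields above are the class's only configuration knobs; a short sketch of pointing the embedder at a non-default server (the host, port, and timeout values here are illustrative, not part of the commit):

from langchain_community.embeddings import LlamafileEmbeddings

# Illustrative values: target a llamafile server on another machine and
# give up if a request takes longer than 30 seconds.
embedder = LlamafileEmbeddings(
    base_url="http://192.168.0.10:8080",
    request_timeout=30,
)
doc_vectors = embedder.embed_documents(
    ["Alpha is the first letter of the Greek alphabet"]
)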
libs/community/tests/unit_tests/embeddings/test_llamafile.py
67 additions & 0 deletions
@@ -0,0 +1,67 @@
import json

import numpy as np
import requests
from pytest import MonkeyPatch

from langchain_community.embeddings import LlamafileEmbeddings


def mock_response() -> requests.Response:
    contents = json.dumps({"embedding": np.random.randn(512).tolist()})
    response = requests.Response()
    response.status_code = 200
    response._content = str.encode(contents)
    return response


def test_embed_documents(monkeypatch: MonkeyPatch) -> None:
    """
    Test basic functionality of the `embed_documents` method
    """
    embedder = LlamafileEmbeddings(
        base_url="http://llamafile-host:8080",
    )

    def mock_post(url, headers, json, timeout):  # type: ignore[no-untyped-def]
        assert url == "http://llamafile-host:8080/embedding"
        assert headers == {
            "Content-Type": "application/json",
        }
        assert json == {"content": "Test text"}
        assert timeout is None
        return mock_response()

    monkeypatch.setattr(requests, "post", mock_post)
    out = embedder.embed_documents(["Test text", "Test text"])
    assert isinstance(out, list)
    assert len(out) == 2
    for vec in out:
        assert len(vec) == 512


def test_embed_query(monkeypatch: MonkeyPatch) -> None:
    """
    Test basic functionality of the `embed_query` method
    """
    embedder = LlamafileEmbeddings(
        base_url="http://llamafile-host:8080",
    )

    def mock_post(url, headers, json, timeout):  # type: ignore[no-untyped-def]
        assert url == "http://llamafile-host:8080/embedding"
        assert headers == {
            "Content-Type": "application/json",
        }
        assert json == {"content": "Test text"}
        assert timeout is None
        return mock_response()

    monkeypatch.setattr(requests, "post", mock_post)
    out = embedder.embed_query("Test text")
    assert isinstance(out, list)
    assert len(out) == 512
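The zero-vector sanity check in `_embed` is not exercised by the tests above; a hedged sketch of how it could be covered using the same mocking pattern (the test name and payload are illustrative, not part of the commit):

import pytest
import requests
from pytest import MonkeyPatch

from langchain_community.embeddings import LlamafileEmbeddings


def test_embed_query_zero_vector(monkeypatch: MonkeyPatch) -> None:
    """A 0-vector response (server started without `--embedding`) should raise."""
    embedder = LlamafileEmbeddings(base_url="http://llamafile-host:8080")

    def mock_post(url, headers, json, timeout):  # type: ignore[no-untyped-def]
        # Simulate the pre-v0.6.2 llamafile behavior of returning a 0-vector.
        response = requests.Response()
        response.status_code = 200
        response._content = b'{"embedding": [0.0, 0.0, 0.0]}'
        return response

    monkeypatch.setattr(requests, "post", mock_post)
    with pytest.raises(ValueError):
        embedder.embed_query("Test text")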