Skip to content

Commit

Permalink
community[minor]: llamafile embeddings support (#17976)
Browse files Browse the repository at this point in the history
* **Description:** adds `LlamafileEmbeddings` class implementation for
generating embeddings using
[llamafile](https://github.com/Mozilla-Ocho/llamafile)-based models.
Includes related unit tests and notebook showing example usage.
* **Issue:** N/A
* **Dependencies:** N/A
  • Loading branch information
k8si authored Mar 1, 2024
1 parent c3c987d commit b7c71e2
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 0 deletions.
157 changes: 157 additions & 0 deletions docs/docs/integrations/text_embedding/llamafile.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "278b6c63",
"metadata": {},
"source": [
"# llamafile\n",
"\n",
"Let's load the [llamafile](https://github.com/Mozilla-Ocho/llamafile) Embeddings class.\n",
"\n",
"## Setup\n",
"\n",
"First, the are 3 setup steps:\n",
"\n",
"1. Download a llamafile. In this notebook, we use `TinyLlama-1.1B-Chat-v1.0.Q5_K_M` but there are many others available on [HuggingFace](https://huggingface.co/models?other=llamafile).\n",
"2. Make the llamafile executable.\n",
"3. Start the llamafile in server mode.\n",
"\n",
"You can run the following bash script to do all this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43ef6dfa-9cc4-4552-8a53-5df523afae7c",
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"# llamafile setup\n",
"\n",
"# Step 1: Download a llamafile. The download may take several minutes.\n",
"wget -nv -nc https://huggingface.co/jartine/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile\n",
"\n",
"# Step 2: Make the llamafile executable. Note: if you're on Windows, just append '.exe' to the filename.\n",
"chmod +x TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile\n",
"\n",
"# Step 3: Start llamafile server in background. All the server logs will be written to 'tinyllama.log'.\n",
"# Alternatively, you can just open a separate terminal outside this notebook and run: \n",
"# ./TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile --server --nobrowser --embedding\n",
"./TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile --server --nobrowser --embedding > tinyllama.log 2>&1 &\n",
"pid=$!\n",
"echo \"${pid}\" > .llamafile_pid # write the process pid to a file so we can terminate the server later"
]
},
{
"cell_type": "markdown",
"id": "3188b22f-879f-47b3-9a27-24412f6fad5f",
"metadata": {},
"source": [
"## Embedding texts using LlamafileEmbeddings\n",
"\n",
"Now, we can use the `LlamafileEmbeddings` class to interact with the llamafile server that's currently serving our TinyLlama model at http://localhost:8080."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0be1af71",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.embeddings import LlamafileEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c66e5da",
"metadata": {},
"outputs": [],
"source": [
"embedder = LlamafileEmbeddings()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01370375",
"metadata": {},
"outputs": [],
"source": [
"text = \"This is a test document.\""
]
},
{
"cell_type": "markdown",
"id": "a42e4035",
"metadata": {},
"source": [
"To generate embeddings, you can either query an invidivual text, or you can query a list of texts."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91bc875d-829b-4c3d-8e6f-fc2dda30a3bd",
"metadata": {},
"outputs": [],
"source": [
"query_result = embedder.embed_query(text)\n",
"query_result[:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4b0d49e-0c73-44b6-aed5-5b426564e085",
"metadata": {},
"outputs": [],
"source": [
"doc_result = embedder.embed_documents([text])\n",
"doc_result[0][:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ccc78fc-03ae-411d-ae73-74a4ee91c725",
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"# cleanup: kill the llamafile server process\n",
"kill $(cat .llamafile_pid)\n",
"rm .llamafile_pid"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
},
"vscode": {
"interpreter": {
"hash": "e971737741ff4ec9aff7dc6155a1060a59a8a6d52c757dbbe66bf8ee389494b1"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 2 additions & 0 deletions libs/community/langchain_community/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
from langchain_community.embeddings.laser import LaserEmbeddings
from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings
from langchain_community.embeddings.llamafile import LlamafileEmbeddings
from langchain_community.embeddings.llm_rails import LLMRailsEmbeddings
from langchain_community.embeddings.localai import LocalAIEmbeddings
from langchain_community.embeddings.minimax import MiniMaxEmbeddings
Expand Down Expand Up @@ -112,6 +113,7 @@
"JinaEmbeddings",
"LaserEmbeddings",
"LlamaCppEmbeddings",
"LlamafileEmbeddings",
"LLMRailsEmbeddings",
"HuggingFaceHubEmbeddings",
"MlflowEmbeddings",
Expand Down
119 changes: 119 additions & 0 deletions libs/community/langchain_community/embeddings/llamafile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import logging
from typing import List, Optional

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel

logger = logging.getLogger(__name__)


class LlamafileEmbeddings(BaseModel, Embeddings):
"""Llamafile lets you distribute and run large language models with a
single file.
To get started, see: https://github.com/Mozilla-Ocho/llamafile
To use this class, you will need to first:
1. Download a llamafile.
2. Make the downloaded file executable: `chmod +x path/to/model.llamafile`
3. Start the llamafile in server mode with embeddings enabled:
`./path/to/model.llamafile --server --nobrowser --embedding`
Example:
.. code-block:: python
from langchain_community.embeddings import LlamafileEmbeddings
embedder = LlamafileEmbeddings()
doc_embeddings = embedder.embed_documents(
[
"Alpha is the first letter of the Greek alphabet",
"Beta is the second letter of the Greek alphabet",
]
)
query_embedding = embedder.embed_query(
"What is the second letter of the Greek alphabet"
)
"""

base_url: str = "http://localhost:8080"
"""Base url where the llamafile server is listening."""

request_timeout: Optional[int] = None
"""Timeout for server requests"""

def _embed(self, text: str) -> List[float]:
try:
response = requests.post(
url=f"{self.base_url}/embedding",
headers={
"Content-Type": "application/json",
},
json={
"content": text,
},
timeout=self.request_timeout,
)
except requests.exceptions.ConnectionError:
raise requests.exceptions.ConnectionError(
f"Could not connect to Llamafile server. Please make sure "
f"that a server is running at {self.base_url}."
)

# Raise exception if we got a bad (non-200) response status code
response.raise_for_status()

contents = response.json()
if "embedding" not in contents:
raise KeyError(
"Unexpected output from /embedding endpoint, output dict "
"missing 'embedding' key."
)

embedding = contents["embedding"]

# Sanity check the embedding vector:
# Prior to llamafile v0.6.2, if the server was not started with the
# `--embedding` option, the embedding endpoint would always return a
# 0-vector. See issue:
# https://github.com/Mozilla-Ocho/llamafile/issues/243
# So here we raise an exception if the vector sums to exactly 0.
if sum(embedding) == 0.0:
raise ValueError(
"Embedding sums to 0, did you start the llamafile server with "
"the `--embedding` option enabled?"
)

return embedding

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed documents using a llamafile server running at `self.base_url`.
llamafile server should be started in a separate process before invoking
this method.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
doc_embeddings = []
for text in texts:
doc_embeddings.append(self._embed(text))
return doc_embeddings

def embed_query(self, text: str) -> List[float]:
"""Embed a query using a llamafile server running at `self.base_url`.
llamafile server should be started in a separate process before invoking
this method.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self._embed(text)
6 changes: 6 additions & 0 deletions libs/community/langchain_community/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,12 @@ def _import_llamacpp() -> Type[BaseLLM]:
return LlamaCpp


def _import_llamafile() -> Type[BaseLLM]:
from langchain_community.llms.llamafile import Llamafile

return Llamafile


def _import_manifest() -> Type[BaseLLM]:
from langchain_community.llms.manifest import ManifestWrapper

Expand Down
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/embeddings/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"JinaEmbeddings",
"LaserEmbeddings",
"LlamaCppEmbeddings",
"LlamafileEmbeddings",
"LLMRailsEmbeddings",
"HuggingFaceHubEmbeddings",
"MlflowAIGatewayEmbeddings",
Expand Down
67 changes: 67 additions & 0 deletions libs/community/tests/unit_tests/embeddings/test_llamafile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import json

import numpy as np
import requests
from pytest import MonkeyPatch

from langchain_community.embeddings import LlamafileEmbeddings


def mock_response() -> requests.Response:
contents = json.dumps({"embedding": np.random.randn(512).tolist()})
response = requests.Response()
response.status_code = 200
response._content = str.encode(contents)
return response


def test_embed_documents(monkeypatch: MonkeyPatch) -> None:
"""
Test basic functionality of the `embed_documents` method
"""
embedder = LlamafileEmbeddings(
base_url="http://llamafile-host:8080",
)

def mock_post(url, headers, json, timeout): # type: ignore[no-untyped-def]
assert url == "http://llamafile-host:8080/embedding"
assert headers == {
"Content-Type": "application/json",
}
# 'unknown' kwarg should be ignored
assert json == {"content": "Test text"}
# assert stream is False
assert timeout is None
return mock_response()

monkeypatch.setattr(requests, "post", mock_post)
out = embedder.embed_documents(["Test text", "Test text"])
assert isinstance(out, list)
assert len(out) == 2
for vec in out:
assert len(vec) == 512


def test_embed_query(monkeypatch: MonkeyPatch) -> None:
"""
Test basic functionality of the `embed_query` method
"""
embedder = LlamafileEmbeddings(
base_url="http://llamafile-host:8080",
)

def mock_post(url, headers, json, timeout): # type: ignore[no-untyped-def]
assert url == "http://llamafile-host:8080/embedding"
assert headers == {
"Content-Type": "application/json",
}
# 'unknown' kwarg should be ignored
assert json == {"content": "Test text"}
# assert stream is False
assert timeout is None
return mock_response()

monkeypatch.setattr(requests, "post", mock_post)
out = embedder.embed_query("Test text")
assert isinstance(out, list)
assert len(out) == 512

0 comments on commit b7c71e2

Please sign in to comment.