From b7c71e2e07037c2b340fd1f6dcefb34dff1e2f92 Mon Sep 17 00:00:00 2001 From: Kate Silverstein Date: Fri, 1 Mar 2024 16:49:18 -0500 Subject: [PATCH] community[minor]: llamafile embeddings support (#17976) * **Description:** adds `LlamafileEmbeddings` class implementation for generating embeddings using [llamafile](https://github.com/Mozilla-Ocho/llamafile)-based models. Includes related unit tests and notebook showing example usage. * **Issue:** N/A * **Dependencies:** N/A --- .../text_embedding/llamafile.ipynb | 157 ++++++++++++++++++ .../embeddings/__init__.py | 2 + .../embeddings/llamafile.py | 119 +++++++++++++ .../langchain_community/llms/__init__.py | 6 + .../unit_tests/embeddings/test_imports.py | 1 + .../unit_tests/embeddings/test_llamafile.py | 67 ++++++++ 6 files changed, 352 insertions(+) create mode 100644 docs/docs/integrations/text_embedding/llamafile.ipynb create mode 100644 libs/community/langchain_community/embeddings/llamafile.py create mode 100644 libs/community/tests/unit_tests/embeddings/test_llamafile.py diff --git a/docs/docs/integrations/text_embedding/llamafile.ipynb b/docs/docs/integrations/text_embedding/llamafile.ipynb new file mode 100644 index 0000000000000..190520fd7cd28 --- /dev/null +++ b/docs/docs/integrations/text_embedding/llamafile.ipynb @@ -0,0 +1,157 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "278b6c63", + "metadata": {}, + "source": [ + "# llamafile\n", + "\n", + "Let's load the [llamafile](https://github.com/Mozilla-Ocho/llamafile) Embeddings class.\n", + "\n", + "## Setup\n", + "\n", + "First, the are 3 setup steps:\n", + "\n", + "1. Download a llamafile. In this notebook, we use `TinyLlama-1.1B-Chat-v1.0.Q5_K_M` but there are many others available on [HuggingFace](https://huggingface.co/models?other=llamafile).\n", + "2. Make the llamafile executable.\n", + "3. Start the llamafile in server mode.\n", + "\n", + "You can run the following bash script to do all this:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43ef6dfa-9cc4-4552-8a53-5df523afae7c", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "# llamafile setup\n", + "\n", + "# Step 1: Download a llamafile. The download may take several minutes.\n", + "wget -nv -nc https://huggingface.co/jartine/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile\n", + "\n", + "# Step 2: Make the llamafile executable. Note: if you're on Windows, just append '.exe' to the filename.\n", + "chmod +x TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile\n", + "\n", + "# Step 3: Start llamafile server in background. All the server logs will be written to 'tinyllama.log'.\n", + "# Alternatively, you can just open a separate terminal outside this notebook and run: \n", + "# ./TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile --server --nobrowser --embedding\n", + "./TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile --server --nobrowser --embedding > tinyllama.log 2>&1 &\n", + "pid=$!\n", + "echo \"${pid}\" > .llamafile_pid # write the process pid to a file so we can terminate the server later" + ] + }, + { + "cell_type": "markdown", + "id": "3188b22f-879f-47b3-9a27-24412f6fad5f", + "metadata": {}, + "source": [ + "## Embedding texts using LlamafileEmbeddings\n", + "\n", + "Now, we can use the `LlamafileEmbeddings` class to interact with the llamafile server that's currently serving our TinyLlama model at http://localhost:8080." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0be1af71", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.embeddings import LlamafileEmbeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c66e5da", + "metadata": {}, + "outputs": [], + "source": [ + "embedder = LlamafileEmbeddings()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01370375", + "metadata": {}, + "outputs": [], + "source": [ + "text = \"This is a test document.\"" + ] + }, + { + "cell_type": "markdown", + "id": "a42e4035", + "metadata": {}, + "source": [ + "To generate embeddings, you can either query an invidivual text, or you can query a list of texts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91bc875d-829b-4c3d-8e6f-fc2dda30a3bd", + "metadata": {}, + "outputs": [], + "source": [ + "query_result = embedder.embed_query(text)\n", + "query_result[:5]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4b0d49e-0c73-44b6-aed5-5b426564e085", + "metadata": {}, + "outputs": [], + "source": [ + "doc_result = embedder.embed_documents([text])\n", + "doc_result[0][:5]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ccc78fc-03ae-411d-ae73-74a4ee91c725", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "# cleanup: kill the llamafile server process\n", + "kill $(cat .llamafile_pid)\n", + "rm .llamafile_pid" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + }, + "vscode": { + "interpreter": { + "hash": "e971737741ff4ec9aff7dc6155a1060a59a8a6d52c757dbbe66bf8ee389494b1" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index d228995afff73..a5076944780f4 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -57,6 +57,7 @@ from langchain_community.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings from langchain_community.embeddings.laser import LaserEmbeddings from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings +from langchain_community.embeddings.llamafile import LlamafileEmbeddings from langchain_community.embeddings.llm_rails import LLMRailsEmbeddings from langchain_community.embeddings.localai import LocalAIEmbeddings from langchain_community.embeddings.minimax import MiniMaxEmbeddings @@ -112,6 +113,7 @@ "JinaEmbeddings", "LaserEmbeddings", "LlamaCppEmbeddings", + "LlamafileEmbeddings", "LLMRailsEmbeddings", "HuggingFaceHubEmbeddings", "MlflowEmbeddings", diff --git a/libs/community/langchain_community/embeddings/llamafile.py b/libs/community/langchain_community/embeddings/llamafile.py new file mode 100644 index 0000000000000..310f61a03e1d0 --- /dev/null +++ b/libs/community/langchain_community/embeddings/llamafile.py @@ -0,0 +1,119 @@ +import logging +from typing import List, Optional + +import requests +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel + +logger = logging.getLogger(__name__) + + +class LlamafileEmbeddings(BaseModel, Embeddings): + """Llamafile lets you distribute and run large language models with a + single file. + + To get started, see: https://github.com/Mozilla-Ocho/llamafile + + To use this class, you will need to first: + + 1. Download a llamafile. + 2. Make the downloaded file executable: `chmod +x path/to/model.llamafile` + 3. Start the llamafile in server mode with embeddings enabled: + + `./path/to/model.llamafile --server --nobrowser --embedding` + + Example: + .. code-block:: python + + from langchain_community.embeddings import LlamafileEmbeddings + embedder = LlamafileEmbeddings() + doc_embeddings = embedder.embed_documents( + [ + "Alpha is the first letter of the Greek alphabet", + "Beta is the second letter of the Greek alphabet", + ] + ) + query_embedding = embedder.embed_query( + "What is the second letter of the Greek alphabet" + ) + + """ + + base_url: str = "http://localhost:8080" + """Base url where the llamafile server is listening.""" + + request_timeout: Optional[int] = None + """Timeout for server requests""" + + def _embed(self, text: str) -> List[float]: + try: + response = requests.post( + url=f"{self.base_url}/embedding", + headers={ + "Content-Type": "application/json", + }, + json={ + "content": text, + }, + timeout=self.request_timeout, + ) + except requests.exceptions.ConnectionError: + raise requests.exceptions.ConnectionError( + f"Could not connect to Llamafile server. Please make sure " + f"that a server is running at {self.base_url}." + ) + + # Raise exception if we got a bad (non-200) response status code + response.raise_for_status() + + contents = response.json() + if "embedding" not in contents: + raise KeyError( + "Unexpected output from /embedding endpoint, output dict " + "missing 'embedding' key." + ) + + embedding = contents["embedding"] + + # Sanity check the embedding vector: + # Prior to llamafile v0.6.2, if the server was not started with the + # `--embedding` option, the embedding endpoint would always return a + # 0-vector. See issue: + # https://github.com/Mozilla-Ocho/llamafile/issues/243 + # So here we raise an exception if the vector sums to exactly 0. + if sum(embedding) == 0.0: + raise ValueError( + "Embedding sums to 0, did you start the llamafile server with " + "the `--embedding` option enabled?" + ) + + return embedding + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed documents using a llamafile server running at `self.base_url`. + llamafile server should be started in a separate process before invoking + this method. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + doc_embeddings = [] + for text in texts: + doc_embeddings.append(self._embed(text)) + return doc_embeddings + + def embed_query(self, text: str) -> List[float]: + """Embed a query using a llamafile server running at `self.base_url`. + llamafile server should be started in a separate process before invoking + this method. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + return self._embed(text) diff --git a/libs/community/langchain_community/llms/__init__.py b/libs/community/langchain_community/llms/__init__.py index b3b508c69a112..4d29156af0a17 100644 --- a/libs/community/langchain_community/llms/__init__.py +++ b/libs/community/langchain_community/llms/__init__.py @@ -295,6 +295,12 @@ def _import_llamacpp() -> Type[BaseLLM]: return LlamaCpp +def _import_llamafile() -> Type[BaseLLM]: + from langchain_community.llms.llamafile import Llamafile + + return Llamafile + + def _import_manifest() -> Type[BaseLLM]: from langchain_community.llms.manifest import ManifestWrapper diff --git a/libs/community/tests/unit_tests/embeddings/test_imports.py b/libs/community/tests/unit_tests/embeddings/test_imports.py index 27d4bc301b605..5ca203d9344d8 100644 --- a/libs/community/tests/unit_tests/embeddings/test_imports.py +++ b/libs/community/tests/unit_tests/embeddings/test_imports.py @@ -17,6 +17,7 @@ "JinaEmbeddings", "LaserEmbeddings", "LlamaCppEmbeddings", + "LlamafileEmbeddings", "LLMRailsEmbeddings", "HuggingFaceHubEmbeddings", "MlflowAIGatewayEmbeddings", diff --git a/libs/community/tests/unit_tests/embeddings/test_llamafile.py b/libs/community/tests/unit_tests/embeddings/test_llamafile.py new file mode 100644 index 0000000000000..e69c9d9c46bea --- /dev/null +++ b/libs/community/tests/unit_tests/embeddings/test_llamafile.py @@ -0,0 +1,67 @@ +import json + +import numpy as np +import requests +from pytest import MonkeyPatch + +from langchain_community.embeddings import LlamafileEmbeddings + + +def mock_response() -> requests.Response: + contents = json.dumps({"embedding": np.random.randn(512).tolist()}) + response = requests.Response() + response.status_code = 200 + response._content = str.encode(contents) + return response + + +def test_embed_documents(monkeypatch: MonkeyPatch) -> None: + """ + Test basic functionality of the `embed_documents` method + """ + embedder = LlamafileEmbeddings( + base_url="http://llamafile-host:8080", + ) + + def mock_post(url, headers, json, timeout): # type: ignore[no-untyped-def] + assert url == "http://llamafile-host:8080/embedding" + assert headers == { + "Content-Type": "application/json", + } + # 'unknown' kwarg should be ignored + assert json == {"content": "Test text"} + # assert stream is False + assert timeout is None + return mock_response() + + monkeypatch.setattr(requests, "post", mock_post) + out = embedder.embed_documents(["Test text", "Test text"]) + assert isinstance(out, list) + assert len(out) == 2 + for vec in out: + assert len(vec) == 512 + + +def test_embed_query(monkeypatch: MonkeyPatch) -> None: + """ + Test basic functionality of the `embed_query` method + """ + embedder = LlamafileEmbeddings( + base_url="http://llamafile-host:8080", + ) + + def mock_post(url, headers, json, timeout): # type: ignore[no-untyped-def] + assert url == "http://llamafile-host:8080/embedding" + assert headers == { + "Content-Type": "application/json", + } + # 'unknown' kwarg should be ignored + assert json == {"content": "Test text"} + # assert stream is False + assert timeout is None + return mock_response() + + monkeypatch.setattr(requests, "post", mock_post) + out = embedder.embed_query("Test text") + assert isinstance(out, list) + assert len(out) == 512