Skip to content

Commit

Permalink
feat: Llama.cpp EF download models from HF
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Sep 10, 2024
1 parent 89fb528 commit ac4fc89
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 7 deletions.
56 changes: 50 additions & 6 deletions chromadbx/embeddings/llamacpp.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,68 @@
import os.path
from enum import Enum
from typing import Optional

from chromadb import EmbeddingFunction, Documents


class PoolingType(int, Enum):
    """Pooling-strategy selector for :class:`LlamaCppEmbeddingFunction`.

    Mirrors the member names of ``llama_embedder.PoolingType``;
    ``LlamaCppEmbeddingFunction.__init__`` maps each member onto the
    corresponding ``llama_embedder`` value by name.
    """

    NONE = 0
    MEAN = 1
    CLS = 2
    LAST = 3


class LlamaCppEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embedding function backed by ``llama-embedder`` (llama.cpp bindings).

    The model is either a local GGUF file or is downloaded from a
    HuggingFace repository on first use.
    """

    def __init__(
        self,
        model_path: str,
        *,
        hf_file_name: Optional[str] = None,
        pooling_type: Optional[PoolingType] = PoolingType.MEAN,
    ) -> None:
        """
        Initialize the LlamaCppEmbeddingFunction.

        :param model_path: Local path to the model, or a HuggingFace
            repository id. For a repository id you must also pass
            ``hf_file_name`` and have the ``huggingface_hub`` package installed.
        :param hf_file_name: Name of the model file inside the HuggingFace
            repository. Only required when ``model_path`` is a repository id.
        :param pooling_type: The pooling type to use; ``None`` disables pooling.
            Default is ``PoolingType.MEAN``.
        :raises ValueError: If a required package is not installed, the model
            cannot be located, or ``pooling_type`` is invalid.
        """
        try:
            from llama_embedder import LlamaEmbedder, PoolingType as PT
        except ImportError:
            raise ValueError(
                "The `llama-embedder` python package is not installed. "
                "Please install it with `pip install llama-embedder`"
            )

        if os.path.exists(model_path):
            # Local model file takes precedence over HF download.
            self._model_file = model_path
        elif model_path and hf_file_name:
            try:
                from huggingface_hub import hf_hub_download
            except ImportError:
                raise ValueError(
                    "The `huggingface_hub` python package is not installed. "
                    "Please install it with `pip install huggingface_hub`"
                )
            self._model_file = hf_hub_download(
                repo_id=model_path, filename=hf_file_name
            )
        else:
            # Covers both "path missing and no HF file given" and the
            # previously unhandled empty-model_path case, which would have
            # left `self._model_file` unset and crashed later.
            raise ValueError(f"Model path {model_path} does not exist")

        # Map our public PoolingType onto llama_embedder's enum by name.
        # NOTE: the old if/elif chain omitted PoolingType.NONE, so passing
        # it explicitly raised "Invalid pooling type"; only `None` worked.
        pooling_map = {
            PoolingType.NONE: PT.NONE,
            PoolingType.MEAN: PT.MEAN,
            PoolingType.CLS: PT.CLS,
            PoolingType.LAST: PT.LAST,
        }
        if pooling_type is None:
            pt = PT.NONE
        elif pooling_type in pooling_map:
            pt = pooling_map[pooling_type]
        else:
            raise ValueError(f"Invalid pooling type: {pooling_type}")

        self._embedder = LlamaEmbedder(model_path=self._model_file, pooling_type=pt)

    def __call__(self, input: Documents) -> Optional[Documents]:
        # Delegates straight to llama_embedder; returns one vector per document.
        return self._embedder.embed(input)
24 changes: 23 additions & 1 deletion experiments/llama-cpp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"source": [
"from chromadbx.embeddings.llamacpp import LlamaCppEmbeddingFunction\n",
"import chromadb\n",
"\n",
"ef = LlamaCppEmbeddingFunction(model_path=\"snowflake-arctic-embed-s/snowflake-arctic-embed-s-f16.GGUF\")\n",
"\n",
"client = chromadb.Client()\n",
Expand All @@ -41,13 +42,34 @@
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": "# With HF model\n",
"id": "296662e89e5e2ce1"
},
{
"metadata": {},
"cell_type": "code",
"source": "",
"source": [
"from chromadbx.embeddings.llamacpp import LlamaCppEmbeddingFunction\n",
"\n",
"ef = LlamaCppEmbeddingFunction(model_path=\"yixuan-chia/snowflake-arctic-embed-s-GGUF\",\n",
" hf_file_name=\"snowflake-arctic-embed-s-F32.gguf\")\n",
"\n",
"ef([\"lorem ipsum...\"])"
],
"id": "e5658111be0f28d8",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "",
"id": "9e0f134be5b72d4c",
"outputs": [],
"execution_count": null
}
],
"metadata": {
Expand Down
11 changes: 11 additions & 0 deletions test/embeddings/test_llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,14 @@ def test_embed(get_model: str) -> None:
assert len(embeddings[0]) == 384
assert len(embeddings[1]) == 384
assert embeddings[0] != embeddings[1]


def test_embed_from_hf_model() -> None:
    """Embeddings produced from a model fetched off the HuggingFace Hub."""
    ef = LlamaCppEmbeddingFunction(
        model_path=DEFAULT_REPO, hf_file_name=DEFAULT_TEST_MODEL
    )
    docs = ["hello world", "goodbye world"]
    vectors = ef(docs)
    assert len(vectors) == 2
    assert all(len(vec) == 384 for vec in vectors)
    assert vectors[0] != vectors[1]

0 comments on commit ac4fc89

Please sign in to comment.