From 41b6a86bbe030291cf8ee284ed0cd70dd493152b Mon Sep 17 00:00:00 2001
From: Mohammad Mohtashim <45242107+keenborder786@users.noreply.github.com>
Date: Mon, 23 Dec 2024 19:50:22 +0500
Subject: [PATCH] Community: LlamaCppEmbeddings `embed_documents` and
 `embed_query` (#28827)

- **Description:** `embed_documents` and `embed_query` were raising the error
  reported in the issue: the `Llama` client returns embeddings as a nested
  list, which the previous implementation did not account for. (A short
  sketch of the offending shapes follows the patch.)
- **Issue:** #28813

---------

Co-authored-by: Chester Curme
---
 .../embeddings/llamacpp.py                    | 50 ++++++++++++-------
 .../unit_tests/embeddings/test_llamacpp.py    | 40 +++++++++++++++
 2 files changed, 72 insertions(+), 18 deletions(-)
 create mode 100644 libs/community/tests/unit_tests/embeddings/test_llamacpp.py

diff --git a/libs/community/langchain_community/embeddings/llamacpp.py b/libs/community/langchain_community/embeddings/llamacpp.py
index 6487312fd31d0..4adfeb0e52774 100644
--- a/libs/community/langchain_community/embeddings/llamacpp.py
+++ b/libs/community/langchain_community/embeddings/llamacpp.py
@@ -20,7 +20,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
     """
 
     client: Any = None  #: :meta private:
-    model_path: str
+    model_path: str = Field(default="")
 
     n_ctx: int = Field(512, alias="n_ctx")
     """Token context window."""
@@ -88,21 +88,22 @@ def validate_environment(self) -> Self:
         if self.n_gpu_layers is not None:
             model_params["n_gpu_layers"] = self.n_gpu_layers
 
-        try:
-            from llama_cpp import Llama
-
-            self.client = Llama(model_path, embedding=True, **model_params)
-        except ImportError:
-            raise ImportError(
-                "Could not import llama-cpp-python library. "
-                "Please install the llama-cpp-python library to "
-                "use this embedding model: pip install llama-cpp-python"
-            )
-        except Exception as e:
-            raise ValueError(
-                f"Could not load Llama model from path: {model_path}. "
-                f"Received error {e}"
-            )
+        if not self.client:
+            try:
+                from llama_cpp import Llama
+
+                self.client = Llama(model_path, embedding=True, **model_params)
+            except ImportError:
+                raise ImportError(
+                    "Could not import llama-cpp-python library. "
+                    "Please install the llama-cpp-python library to "
+                    "use this embedding model: pip install llama-cpp-python"
+                )
+            except Exception as e:
+                raise ValueError(
+                    f"Could not load Llama model from path: {model_path}. "
+                    f"Received error {e}"
+                )
 
         return self
 
@@ -116,7 +117,17 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
             List of embeddings, one for each text.
         """
         embeddings = self.client.create_embedding(texts)
-        return [list(map(float, e["embedding"])) for e in embeddings["data"]]
+        final_embeddings = []
+        for e in embeddings["data"]:
+            try:
+                if isinstance(e["embedding"][0], list):
+                    for data in e["embedding"]:
+                        final_embeddings.append(list(map(float, data)))
+                else:
+                    final_embeddings.append(list(map(float, e["embedding"])))
+            except (IndexError, TypeError):
+                final_embeddings.append(list(map(float, e["embedding"])))
+        return final_embeddings
 
     def embed_query(self, text: str) -> List[float]:
         """Embed a query using the Llama model.
@@ -128,4 +139,7 @@ def embed_query(self, text: str) -> List[float]:
             Embeddings for the text.
""" embedding = self.client.embed(text) - return list(map(float, embedding)) + if not isinstance(embedding, list): + return list(map(float, embedding)) + else: + return list(map(float, embedding[0])) diff --git a/libs/community/tests/unit_tests/embeddings/test_llamacpp.py b/libs/community/tests/unit_tests/embeddings/test_llamacpp.py new file mode 100644 index 0000000000000..ca2bd758216cf --- /dev/null +++ b/libs/community/tests/unit_tests/embeddings/test_llamacpp.py @@ -0,0 +1,40 @@ +from typing import Generator +from unittest.mock import MagicMock, patch + +import pytest + +from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings + + +@pytest.fixture +def mock_llama_client() -> Generator[MagicMock, None, None]: + with patch( + "langchain_community.embeddings.llamacpp.LlamaCppEmbeddings" + ) as MockLlama: + mock_client = MagicMock() + MockLlama.return_value = mock_client + yield mock_client + + +def test_initialization(mock_llama_client: MagicMock) -> None: + embeddings = LlamaCppEmbeddings(client=mock_llama_client) # type: ignore[call-arg] + assert embeddings.client is not None + + +def test_embed_documents(mock_llama_client: MagicMock) -> None: + mock_llama_client.create_embedding.return_value = { + "data": [{"embedding": [[0.1, 0.2, 0.3]]}, {"embedding": [[0.4, 0.5, 0.6]]}] + } + embeddings = LlamaCppEmbeddings(client=mock_llama_client) # type: ignore[call-arg] + texts = ["Hello world", "Test document"] + result = embeddings.embed_documents(texts) + expected = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + assert result == expected + + +def test_embed_query(mock_llama_client: MagicMock) -> None: + mock_llama_client.embed.return_value = [[0.1, 0.2, 0.3]] + embeddings = LlamaCppEmbeddings(client=mock_llama_client) # type: ignore[call-arg] + result = embeddings.embed_query("Sample query") + expected = [0.1, 0.2, 0.3] + assert result == expected