From a043d762cec53b631ca64e5fb024d80f4846a5da Mon Sep 17 00:00:00 2001 From: Roman Solomatin <36135455+Samoed@users.noreply.github.com> Date: Sat, 2 Nov 2024 17:10:28 +0300 Subject: [PATCH 1/2] add encode_query kwargs --- .../embeddings/huggingface.py | 41 +++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py index 180a9ed3b5e79..3eee5252c6de1 100644 --- a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py +++ b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py @@ -36,9 +36,14 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): `prompts`, `default_prompt_name`, `revision`, `trust_remote_code`, or `token`. See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer""" encode_kwargs: Dict[str, Any] = Field(default_factory=dict) - """Keyword arguments to pass when calling the `encode` method of the Sentence - Transformer model, such as `prompt_name`, `prompt`, `batch_size`, `precision`, - `normalize_embeddings`, and more. + """Keyword arguments to pass when calling the `encode` method for the documents of + the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`, + `precision`, `normalize_embeddings`, and more. + See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode""" + query_encode_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Keyword arguments to pass when calling the `encode` method for the query of + the Sentence Transformer model, such as `prompt_name`, `prompt`, `batch_size`, + `precision`, `normalize_embeddings`, and more. See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode""" multi_process: bool = False """Run encode() on multiple GPUs.""" @@ -65,11 +70,17 @@ def __init__(self, **kwargs: Any): protected_namespaces=(), ) - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Compute doc embeddings using a HuggingFace transformer model. + def embed( + self, texts: list[str], encode_kwargs: Dict[str, Any] + ) -> List[List[float]]: + """ + Embed a text using the HuggingFace transformer model. Args: texts: The list of texts to embed. + encode_kwargs: Keyword arguments to pass when calling the + `encode` method for the documents of the SentenceTransformer + encode method. Returns: List of embeddings, one for each text. @@ -85,7 +96,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: embeddings = self._client.encode( texts, show_progress_bar=self.show_progress, - **self.encode_kwargs, # type: ignore + **encode_kwargs, # type: ignore ) if isinstance(embeddings, list): @@ -96,6 +107,17 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: return embeddings.tolist() + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute doc embeddings using a HuggingFace transformer model. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + return self.embed(texts, self.encode_kwargs) + def embed_query(self, text: str) -> List[float]: """Compute query embeddings using a HuggingFace transformer model. @@ -105,4 +127,9 @@ def embed_query(self, text: str) -> List[float]: Returns: Embeddings for the text. """ - return self.embed_documents([text])[0] + embed_kwargs = ( + self.query_encode_kwargs + if len(self.query_encode_kwargs) > 0 + else self.encode_kwargs + ) + return self.embed([text], embed_kwargs)[0] From ebf31a4d9e75647d93cfb9ec7bb5d6d2edd611e7 Mon Sep 17 00:00:00 2001 From: Roman Solomatin <36135455+Samoed@users.noreply.github.com> Date: Tue, 5 Nov 2024 18:06:42 +0300 Subject: [PATCH 2/2] make embed private --- .../langchain_huggingface/embeddings/huggingface.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py index 3eee5252c6de1..2bbc551f4e0b1 100644 --- a/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py +++ b/libs/partners/huggingface/langchain_huggingface/embeddings/huggingface.py @@ -70,7 +70,7 @@ def __init__(self, **kwargs: Any): protected_namespaces=(), ) - def embed( + def _embed( self, texts: list[str], encode_kwargs: Dict[str, Any] ) -> List[List[float]]: """ @@ -116,7 +116,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: Returns: List of embeddings, one for each text. """ - return self.embed(texts, self.encode_kwargs) + return self._embed(texts, self.encode_kwargs) def embed_query(self, text: str) -> List[float]: """Compute query embeddings using a HuggingFace transformer model. @@ -132,4 +132,4 @@ def embed_query(self, text: str) -> List[float]: if len(self.query_encode_kwargs) > 0 else self.encode_kwargs ) - return self.embed([text], embed_kwargs)[0] + return self._embed([text], embed_kwargs)[0]