From 4d7f6fa968cd8361c10794e5f086885a66a528a3 Mon Sep 17 00:00:00 2001 From: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> Date: Fri, 15 Mar 2024 01:10:23 +0200 Subject: [PATCH] ai21[patch]: AI21 Labs Batch Support in Embeddings (#18633) Description: Added support for batching when using AI21 Embeddings model Twitter handle: https://github.com/AI21Labs --------- Co-authored-by: Asaf Gardin Co-authored-by: Erick Friis --- .../ai21/langchain_ai21/embeddings.py | 60 +++++++++++++++---- libs/partners/ai21/pyproject.toml | 2 +- .../integration_tests/test_chat_models.py | 9 ++- .../integration_tests/test_embeddings.py | 18 ++++++ .../ai21/tests/integration_tests/test_llms.py | 17 +++--- .../ai21/tests/unit_tests/test_embeddings.py | 35 +++++++++++ .../ai21/tests/unit_tests/test_utils.py | 29 +++++++++ 7 files changed, 147 insertions(+), 23 deletions(-) create mode 100644 libs/partners/ai21/tests/unit_tests/test_utils.py diff --git a/libs/partners/ai21/langchain_ai21/embeddings.py b/libs/partners/ai21/langchain_ai21/embeddings.py index 59fad67b5c876..97b7d68242149 100644 --- a/libs/partners/ai21/langchain_ai21/embeddings.py +++ b/libs/partners/ai21/langchain_ai21/embeddings.py @@ -1,10 +1,18 @@ -from typing import Any, List +from itertools import islice +from typing import Any, Iterator, List, Optional from ai21.models import EmbedType from langchain_core.embeddings import Embeddings from langchain_ai21.ai21_base import AI21Base +_DEFAULT_BATCH_SIZE = 128 + + +def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[List[str]]: + texts_itr = iter(texts) + return iter(lambda: list(islice(texts_itr, batch_size)), []) + class AI21Embeddings(Embeddings, AI21Base): """AI21 Embeddings embedding model. @@ -20,22 +28,52 @@ class AI21Embeddings(Embeddings, AI21Base): query_result = embeddings.embed_query("Hello embeddings world!") """ - def embed_documents(self, texts: List[str], **kwargs: Any) -> List[List[float]]: + batch_size: int = _DEFAULT_BATCH_SIZE + """Maximum number of texts to embed in each batch""" + + def embed_documents( + self, + texts: List[str], + *, + batch_size: Optional[int] = None, + **kwargs: Any, + ) -> List[List[float]]: """Embed search docs.""" - response = self.client.embed.create( + return self._send_embeddings( texts=texts, - type=EmbedType.SEGMENT, + batch_size=batch_size or self.batch_size, + embed_type=EmbedType.SEGMENT, **kwargs, ) - return [result.embedding for result in response.results] - - def embed_query(self, text: str, **kwargs: Any) -> List[float]: + def embed_query( + self, + text: str, + *, + batch_size: Optional[int] = None, + **kwargs: Any, + ) -> List[float]: """Embed query text.""" - response = self.client.embed.create( + return self._send_embeddings( texts=[text], - type=EmbedType.QUERY, + batch_size=batch_size or self.batch_size, + embed_type=EmbedType.QUERY, **kwargs, - ) + )[0] + + def _send_embeddings( + self, texts: List[str], *, batch_size: int, embed_type: EmbedType, **kwargs: Any + ) -> List[List[float]]: + chunks = _split_texts_into_batches(texts, batch_size) + responses = [ + self.client.embed.create( + texts=chunk, + type=embed_type, + **kwargs, + ) + for chunk in chunks + ] - return [result.embedding for result in response.results][0] + return [ + result.embedding for response in responses for result in response.results + ] diff --git a/libs/partners/ai21/pyproject.toml b/libs/partners/ai21/pyproject.toml index 6ea612d1a2c30..c448bc532a847 100644 --- a/libs/partners/ai21/pyproject.toml +++ b/libs/partners/ai21/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-ai21" -version = "0.1.0" +version = "0.1.1" description = "An integration package connecting AI21 and LangChain" authors = [] readme = "README.md" diff --git a/libs/partners/ai21/tests/integration_tests/test_chat_models.py b/libs/partners/ai21/tests/integration_tests/test_chat_models.py index 37efd41e520c4..3cfdb9fc9468a 100644 --- a/libs/partners/ai21/tests/integration_tests/test_chat_models.py +++ b/libs/partners/ai21/tests/integration_tests/test_chat_models.py @@ -1,13 +1,16 @@ """Test ChatAI21 chat model.""" + from langchain_core.messages import HumanMessage from langchain_core.outputs import ChatGeneration from langchain_ai21.chat_models import ChatAI21 +_MODEL_NAME = "j2-ultra" + def test_invoke() -> None: """Test invoke tokens from AI21.""" - llm = ChatAI21(model="j2-ultra") + llm = ChatAI21(model=_MODEL_NAME) result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) assert isinstance(result.content, str) @@ -15,7 +18,7 @@ def test_invoke() -> None: def test_generation() -> None: """Test invoke tokens from AI21.""" - llm = ChatAI21(model="j2-ultra") + llm = ChatAI21(model=_MODEL_NAME) message = HumanMessage(content="Hello") result = llm.generate([[message], [message]], config=dict(tags=["foo"])) @@ -30,7 +33,7 @@ def test_generation() -> None: async def test_ageneration() -> None: """Test invoke tokens from AI21.""" - llm = ChatAI21(model="j2-ultra") + llm = ChatAI21(model=_MODEL_NAME) message = HumanMessage(content="Hello") result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"])) diff --git a/libs/partners/ai21/tests/integration_tests/test_embeddings.py b/libs/partners/ai21/tests/integration_tests/test_embeddings.py index 8434234e56a4b..4dac3158af34a 100644 --- a/libs/partners/ai21/tests/integration_tests/test_embeddings.py +++ b/libs/partners/ai21/tests/integration_tests/test_embeddings.py @@ -1,4 +1,5 @@ """Test AI21 embeddings.""" + from langchain_ai21.embeddings import AI21Embeddings @@ -17,3 +18,20 @@ def test_langchain_ai21_embedding_query() -> None: embedding = AI21Embeddings() output = embedding.embed_query(document) assert len(output) > 0 + + +def test_langchain_ai21_embedding_documents__with_explicit_chunk_size() -> None: + """Test AI21 embeddings with chunk size passed as an argument.""" + documents = ["foo", "bar"] + embedding = AI21Embeddings() + output = embedding.embed_documents(documents, batch_size=1) + assert len(output) == 2 + assert len(output[0]) > 0 + + +def test_langchain_ai21_embedding_query__with_explicit_chunk_size() -> None: + """Test AI21 embeddings with chunk size passed as an argument.""" + documents = "foo bar" + embedding = AI21Embeddings() + output = embedding.embed_query(documents, batch_size=1) + assert len(output) > 0 diff --git a/libs/partners/ai21/tests/integration_tests/test_llms.py b/libs/partners/ai21/tests/integration_tests/test_llms.py index fe4a812552723..2219f592921e9 100644 --- a/libs/partners/ai21/tests/integration_tests/test_llms.py +++ b/libs/partners/ai21/tests/integration_tests/test_llms.py @@ -1,15 +1,16 @@ """Test AI21LLM llm.""" - from langchain_ai21.llms import AI21LLM +_MODEL_NAME = "j2-mid" + def _generate_llm() -> AI21LLM: """ Testing AI21LLm using non default parameters with the following parameters """ return AI21LLM( - model="j2-ultra", + model=_MODEL_NAME, max_tokens=2, # Use less tokens for a faster response temperature=0, # for a consistent response epoch=1, @@ -19,7 +20,7 @@ def _generate_llm() -> AI21LLM: def test_stream() -> None: """Test streaming tokens from AI21.""" llm = AI21LLM( - model="j2-ultra", + model=_MODEL_NAME, ) for token in llm.stream("I'm Pickle Rick"): @@ -29,7 +30,7 @@ def test_stream() -> None: async def test_abatch() -> None: """Test streaming tokens from AI21LLM.""" llm = AI21LLM( - model="j2-ultra", + model=_MODEL_NAME, ) result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"]) @@ -40,7 +41,7 @@ async def test_abatch() -> None: async def test_abatch_tags() -> None: """Test batch tokens from AI21LLM.""" llm = AI21LLM( - model="j2-ultra", + model=_MODEL_NAME, ) result = await llm.abatch( @@ -53,7 +54,7 @@ async def test_abatch_tags() -> None: def test_batch() -> None: """Test batch tokens from AI21LLM.""" llm = AI21LLM( - model="j2-ultra", + model=_MODEL_NAME, ) result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"]) @@ -64,7 +65,7 @@ def test_batch() -> None: async def test_ainvoke() -> None: """Test invoke tokens from AI21LLM.""" llm = AI21LLM( - model="j2-ultra", + model=_MODEL_NAME, ) result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]}) @@ -74,7 +75,7 @@ async def test_ainvoke() -> None: def test_invoke() -> None: """Test invoke tokens from AI21LLM.""" llm = AI21LLM( - model="j2-ultra", + model=_MODEL_NAME, ) result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"])) diff --git a/libs/partners/ai21/tests/unit_tests/test_embeddings.py b/libs/partners/ai21/tests/unit_tests/test_embeddings.py index a366b32dd331f..84677734085ed 100644 --- a/libs/partners/ai21/tests/unit_tests/test_embeddings.py +++ b/libs/partners/ai21/tests/unit_tests/test_embeddings.py @@ -1,4 +1,6 @@ """Test embedding model integration.""" + +from typing import List from unittest.mock import Mock import pytest @@ -65,3 +67,36 @@ def test_embed_documents(mock_client_with_embeddings: Mock) -> None: texts=texts, type=EmbedType.SEGMENT, ) + + +@pytest.mark.parametrize( + ids=[ + "empty_texts", + "chunk_size_greater_than_texts_length", + "chunk_size_equal_to_texts_length", + "chunk_size_less_than_texts_length", + "chunk_size_one_with_multiple_texts", + "chunk_size_greater_than_texts_length", + ], + argnames=["texts", "chunk_size", "expected_internal_embeddings_calls"], + argvalues=[ + ([], 3, 0), + (["text1", "text2", "text3"], 5, 1), + (["text1", "text2", "text3"], 3, 1), + (["text1", "text2", "text3", "text4", "text5"], 2, 3), + (["text1", "text2", "text3"], 1, 3), + (["text1", "text2", "text3"], 10, 1), + ], +) +def test_get_len_safe_embeddings( + mock_client_with_embeddings: Mock, + texts: List[str], + chunk_size: int, + expected_internal_embeddings_calls: int, +) -> None: + llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY) + llm.embed_documents(texts=texts, batch_size=chunk_size) + assert ( + mock_client_with_embeddings.embed.create.call_count + == expected_internal_embeddings_calls + ) diff --git a/libs/partners/ai21/tests/unit_tests/test_utils.py b/libs/partners/ai21/tests/unit_tests/test_utils.py new file mode 100644 index 0000000000000..5a1e676cef80c --- /dev/null +++ b/libs/partners/ai21/tests/unit_tests/test_utils.py @@ -0,0 +1,29 @@ +from typing import List + +import pytest + +from langchain_ai21.embeddings import _split_texts_into_batches + + +@pytest.mark.parametrize( + ids=[ + "when_chunk_size_is_2__should_return_3_chunks", + "when_texts_is_empty__should_return_empty_list", + "when_chunk_size_is_1__should_return_10_chunks", + ], + argnames=["input_texts", "chunk_size", "expected_output"], + argvalues=[ + (["a", "b", "c", "d", "e"], 2, [["a", "b"], ["c", "d"], ["e"]]), + ([], 3, []), + ( + ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + 1, + [["1"], ["2"], ["3"], ["4"], ["5"], ["6"], ["7"], ["8"], ["9"], ["10"]], + ), + ], +) +def test_chunked_text_generator( + input_texts: List[str], chunk_size: int, expected_output: List[List[str]] +) -> None: + result = list(_split_texts_into_batches(input_texts, chunk_size)) + assert result == expected_output