Skip to content

Commit

Permalink
ai21[patch]: AI21 Labs Batch Support in Embeddings (#18633)
Browse files Browse the repository at this point in the history
Description: Added support for batching when using AI21 Embeddings model
Twitter handle: https://github.com/AI21Labs

---------

Co-authored-by: Asaf Gardin <[email protected]>
Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
3 people authored Mar 14, 2024
1 parent 321db89 commit 4d7f6fa
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 23 deletions.
60 changes: 49 additions & 11 deletions libs/partners/ai21/langchain_ai21/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from typing import Any, List
from itertools import islice
from typing import Any, Iterator, List, Optional

from ai21.models import EmbedType
from langchain_core.embeddings import Embeddings

from langchain_ai21.ai21_base import AI21Base

_DEFAULT_BATCH_SIZE = 128


def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[List[str]]:
texts_itr = iter(texts)
return iter(lambda: list(islice(texts_itr, batch_size)), [])


class AI21Embeddings(Embeddings, AI21Base):
"""AI21 Embeddings embedding model.
Expand All @@ -20,22 +28,52 @@ class AI21Embeddings(Embeddings, AI21Base):
query_result = embeddings.embed_query("Hello embeddings world!")
"""

def embed_documents(self, texts: List[str], **kwargs: Any) -> List[List[float]]:
batch_size: int = _DEFAULT_BATCH_SIZE
"""Maximum number of texts to embed in each batch"""

def embed_documents(
self,
texts: List[str],
*,
batch_size: Optional[int] = None,
**kwargs: Any,
) -> List[List[float]]:
"""Embed search docs."""
response = self.client.embed.create(
return self._send_embeddings(
texts=texts,
type=EmbedType.SEGMENT,
batch_size=batch_size or self.batch_size,
embed_type=EmbedType.SEGMENT,
**kwargs,
)

return [result.embedding for result in response.results]

def embed_query(self, text: str, **kwargs: Any) -> List[float]:
def embed_query(
self,
text: str,
*,
batch_size: Optional[int] = None,
**kwargs: Any,
) -> List[float]:
"""Embed query text."""
response = self.client.embed.create(
return self._send_embeddings(
texts=[text],
type=EmbedType.QUERY,
batch_size=batch_size or self.batch_size,
embed_type=EmbedType.QUERY,
**kwargs,
)
)[0]

def _send_embeddings(
self, texts: List[str], *, batch_size: int, embed_type: EmbedType, **kwargs: Any
) -> List[List[float]]:
chunks = _split_texts_into_batches(texts, batch_size)
responses = [
self.client.embed.create(
texts=chunk,
type=embed_type,
**kwargs,
)
for chunk in chunks
]

return [result.embedding for result in response.results][0]
return [
result.embedding for response in responses for result in response.results
]
2 changes: 1 addition & 1 deletion libs/partners/ai21/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-ai21"
version = "0.1.0"
version = "0.1.1"
description = "An integration package connecting AI21 and LangChain"
authors = []
readme = "README.md"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
"""Test ChatAI21 chat model."""

from langchain_core.messages import HumanMessage
from langchain_core.outputs import ChatGeneration

from langchain_ai21.chat_models import ChatAI21

_MODEL_NAME = "j2-ultra"


def test_invoke() -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model="j2-ultra")
llm = ChatAI21(model=_MODEL_NAME)

result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)


def test_generation() -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model="j2-ultra")
llm = ChatAI21(model=_MODEL_NAME)
message = HumanMessage(content="Hello")

result = llm.generate([[message], [message]], config=dict(tags=["foo"]))
Expand All @@ -30,7 +33,7 @@ def test_generation() -> None:

async def test_ageneration() -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model="j2-ultra")
llm = ChatAI21(model=_MODEL_NAME)
message = HumanMessage(content="Hello")

result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))
Expand Down
18 changes: 18 additions & 0 deletions libs/partners/ai21/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test AI21 embeddings."""

from langchain_ai21.embeddings import AI21Embeddings


Expand All @@ -17,3 +18,20 @@ def test_langchain_ai21_embedding_query() -> None:
embedding = AI21Embeddings()
output = embedding.embed_query(document)
assert len(output) > 0


def test_langchain_ai21_embedding_documents__with_explicit_chunk_size() -> None:
"""Test AI21 embeddings with chunk size passed as an argument."""
documents = ["foo", "bar"]
embedding = AI21Embeddings()
output = embedding.embed_documents(documents, batch_size=1)
assert len(output) == 2
assert len(output[0]) > 0


def test_langchain_ai21_embedding_query__with_explicit_chunk_size() -> None:
"""Test AI21 embeddings with chunk size passed as an argument."""
documents = "foo bar"
embedding = AI21Embeddings()
output = embedding.embed_query(documents, batch_size=1)
assert len(output) > 0
17 changes: 9 additions & 8 deletions libs/partners/ai21/tests/integration_tests/test_llms.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
"""Test AI21LLM llm."""


from langchain_ai21.llms import AI21LLM

_MODEL_NAME = "j2-mid"


def _generate_llm() -> AI21LLM:
"""
Testing AI21LLm using non default parameters with the following parameters
"""
return AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
max_tokens=2, # Use less tokens for a faster response
temperature=0, # for a consistent response
epoch=1,
Expand All @@ -19,7 +20,7 @@ def _generate_llm() -> AI21LLM:
def test_stream() -> None:
"""Test streaming tokens from AI21."""
llm = AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
)

for token in llm.stream("I'm Pickle Rick"):
Expand All @@ -29,7 +30,7 @@ def test_stream() -> None:
async def test_abatch() -> None:
"""Test streaming tokens from AI21LLM."""
llm = AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
)

result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
Expand All @@ -40,7 +41,7 @@ async def test_abatch() -> None:
async def test_abatch_tags() -> None:
"""Test batch tokens from AI21LLM."""
llm = AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
)

result = await llm.abatch(
Expand All @@ -53,7 +54,7 @@ async def test_abatch_tags() -> None:
def test_batch() -> None:
"""Test batch tokens from AI21LLM."""
llm = AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
)

result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
Expand All @@ -64,7 +65,7 @@ def test_batch() -> None:
async def test_ainvoke() -> None:
"""Test invoke tokens from AI21LLM."""
llm = AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
)

result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
Expand All @@ -74,7 +75,7 @@ async def test_ainvoke() -> None:
def test_invoke() -> None:
"""Test invoke tokens from AI21LLM."""
llm = AI21LLM(
model="j2-ultra",
model=_MODEL_NAME,
)

result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
Expand Down
35 changes: 35 additions & 0 deletions libs/partners/ai21/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Test embedding model integration."""

from typing import List
from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -65,3 +67,36 @@ def test_embed_documents(mock_client_with_embeddings: Mock) -> None:
texts=texts,
type=EmbedType.SEGMENT,
)


@pytest.mark.parametrize(
ids=[
"empty_texts",
"chunk_size_greater_than_texts_length",
"chunk_size_equal_to_texts_length",
"chunk_size_less_than_texts_length",
"chunk_size_one_with_multiple_texts",
"chunk_size_greater_than_texts_length",
],
argnames=["texts", "chunk_size", "expected_internal_embeddings_calls"],
argvalues=[
([], 3, 0),
(["text1", "text2", "text3"], 5, 1),
(["text1", "text2", "text3"], 3, 1),
(["text1", "text2", "text3", "text4", "text5"], 2, 3),
(["text1", "text2", "text3"], 1, 3),
(["text1", "text2", "text3"], 10, 1),
],
)
def test_get_len_safe_embeddings(
mock_client_with_embeddings: Mock,
texts: List[str],
chunk_size: int,
expected_internal_embeddings_calls: int,
) -> None:
llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY)
llm.embed_documents(texts=texts, batch_size=chunk_size)
assert (
mock_client_with_embeddings.embed.create.call_count
== expected_internal_embeddings_calls
)
29 changes: 29 additions & 0 deletions libs/partners/ai21/tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import List

import pytest

from langchain_ai21.embeddings import _split_texts_into_batches


@pytest.mark.parametrize(
ids=[
"when_chunk_size_is_2__should_return_3_chunks",
"when_texts_is_empty__should_return_empty_list",
"when_chunk_size_is_1__should_return_10_chunks",
],
argnames=["input_texts", "chunk_size", "expected_output"],
argvalues=[
(["a", "b", "c", "d", "e"], 2, [["a", "b"], ["c", "d"], ["e"]]),
([], 3, []),
(
["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
1,
[["1"], ["2"], ["3"], ["4"], ["5"], ["6"], ["7"], ["8"], ["9"], ["10"]],
),
],
)
def test_chunked_text_generator(
input_texts: List[str], chunk_size: int, expected_output: List[List[str]]
) -> None:
result = list(_split_texts_into_batches(input_texts, chunk_size))
assert result == expected_output

0 comments on commit 4d7f6fa

Please sign in to comment.