Skip to content

Commit

Permalink
ai21: apply rate limiter in integration tests (langchain-ai#24717)
Browse files: browse the repository at this point in the history
Apply rate limiter in integration tests
  • Loading branch information
eyurtsev authored and olgamurraft committed Aug 16, 2024
1 parent 7bda058 commit d284c53
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import pytest
from langchain_core.messages import AIMessageChunk, HumanMessage
from langchain_core.outputs import ChatGeneration
from langchain_core.rate_limiters import InMemoryRateLimiter

from langchain_ai21.chat_models import ChatAI21
from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME

rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)


@pytest.mark.parametrize(
ids=[
Expand All @@ -21,7 +24,7 @@
)
def test_invoke(model: str) -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model=model) # type: ignore[call-arg]
llm = ChatAI21(model=model, rate_limiter=rate_limiter) # type: ignore[call-arg]

result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
Expand All @@ -48,7 +51,7 @@ def test_generation(model: str, num_results: int) -> None:
config_key = "n" if model == JAMBA_CHAT_MODEL_NAME else "num_results"

# Create the model instance using the appropriate key for the result count
llm = ChatAI21(model=model, **{config_key: num_results}) # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
llm = ChatAI21(model=model, rate_limiter=rate_limiter, **{config_key: num_results}) # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]

message = HumanMessage(content="Hello, this is a test. Can you help me please?")

Expand All @@ -75,7 +78,7 @@ def test_generation(model: str, num_results: int) -> None:
)
async def test_ageneration(model: str) -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model=model) # type: ignore[call-arg]
llm = ChatAI21(model=model, rate_limiter=rate_limiter) # type: ignore[call-arg]
message = HumanMessage(content="Hello")

result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))
Expand Down
8 changes: 5 additions & 3 deletions libs/partners/ai21/tests/integration_tests/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found]
ChatModelIntegrationTests, # type: ignore[import-not-found]
)
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

from langchain_ai21 import ChatAI21

rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)


class BaseTestAI21(ChatModelIntegrationTests):
def teardown(self) -> None:
Expand All @@ -31,6 +32,7 @@ class TestAI21J2(BaseTestAI21):
# Return the keyword parameters for the chat model exercised by this test class:
# the "j2-ultra" model name plus the shared rate limiter.
# NOTE(review): presumably consumed by the ChatModelIntegrationTests base class to
# construct the model under test — confirm against langchain_standard_tests.
def chat_model_params(self) -> dict:
return {
"model": "j2-ultra",
# Module-level InMemoryRateLimiter(requests_per_second=0.5) defined in this file;
# passing the same instance everywhere makes all tests share one limiter.
"rate_limiter": rate_limiter,
}

@pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")
Expand Down

0 comments on commit d284c53

Please sign in to comment.