ai21[patch]: AI21 Labs bump SDK version (#19114)
Description: Added support for AI21 SDK version 2.1.2
Twitter handle: https://github.com/AI21Labs

---------

Co-authored-by: Asaf Gardin <[email protected]>
Co-authored-by: Erick Friis <[email protected]>
3 people authored Mar 19, 2024
1 parent edf9d1c commit 21c4547
Showing 8 changed files with 154 additions and 70 deletions.
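As an illustrative smoke test of the new SDK range (not part of the commit), a minimal chat invocation could look like the sketch below. It assumes AI21_API_KEY is exported, langchain-ai21 0.1.2 with ai21 >= 2.1.2 is installed, and that the "j2-ultra" model name is available to the account.

# Illustrative sketch only: one ChatAI21 call against the bumped SDK.
# Model name and temperature are example values, not defaults from this commit.
from langchain_ai21 import ChatAI21
from langchain_core.messages import HumanMessage

chat = ChatAI21(model="j2-ultra", temperature=0.2)
reply = chat.invoke([HumanMessage(content="Say hello in one short sentence.")])
print(reply.content)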
4 changes: 2 additions & 2 deletions libs/partners/ai21/langchain_ai21/ai21_base.py
@@ -1,5 +1,5 @@
import os
from typing import Dict, Optional
from typing import Any, Dict, Optional

from ai21 import AI21Client
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
@@ -12,7 +12,7 @@ class AI21Base(BaseModel):
class Config:
arbitrary_types_allowed = True

client: AI21Client = Field(default=None)
client: Any = Field(default=None, exclude=True) #: :meta private:
api_key: Optional[SecretStr] = None
api_host: Optional[str] = None
timeout_sec: Optional[float] = None
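The base class now types the client as Any (rather than the concrete AI21Client) and excludes it from serialized output. A minimal sketch of the exclude=True behaviour, using a made-up ExampleBase model and langchain_core's pydantic v1 shim:

# Illustrative only: a field declared with exclude=True is dropped from
# .dict()/.json() output, so the live SDK client does not leak into
# serialized representations of the model.
from typing import Any, Optional

from langchain_core.pydantic_v1 import BaseModel, Field


class ExampleBase(BaseModel):
    client: Any = Field(default=None, exclude=True)
    api_host: Optional[str] = None


m = ExampleBase(client=object(), api_host="https://example-host")
print(m.dict())  # {'api_host': 'https://example-host'} -- client is excluded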
74 changes: 52 additions & 22 deletions libs/partners/ai21/langchain_ai21/chat_models.py
@@ -1,8 +1,8 @@
import asyncio
from functools import partial
from typing import Any, List, Optional, Tuple, cast
from typing import Any, List, Mapping, Optional, Tuple, cast

from ai21.models import ChatMessage, Penalty, RoleType
from ai21.models import ChatMessage, RoleType
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -109,13 +109,13 @@ class ChatAI21(BaseChatModel, AI21Base):
top_k_return: int = 0
"""The number of top-scoring tokens to consider for each generation step."""

frequency_penalty: Optional[Penalty] = None
frequency_penalty: Optional[Any] = None
"""A penalty applied to tokens that are frequently generated."""

presence_penalty: Optional[Penalty] = None
presence_penalty: Optional[Any] = None
""" A penalty applied to tokens that are already present in the prompt."""

count_penalty: Optional[Penalty] = None
count_penalty: Optional[Any] = None
"""A penalty applied to tokens based on their frequency
in the generated responses."""

@@ -129,31 +129,61 @@ def _llm_type(self) -> str:
"""Return type of chat model."""
return "chat-ai21"

@property
def _default_params(self) -> Mapping[str, Any]:
base_params = {
"model": self.model,
"num_results": self.num_results,
"max_tokens": self.max_tokens,
"min_tokens": self.min_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k_return": self.top_k_return,
}

if self.count_penalty is not None:
base_params["count_penalty"] = self.count_penalty.to_dict()

if self.frequency_penalty is not None:
base_params["frequency_penalty"] = self.frequency_penalty.to_dict()

if self.presence_penalty is not None:
base_params["presence_penalty"] = self.presence_penalty.to_dict()

return base_params

def _build_params_for_request(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Mapping[str, Any]:
params = {}
system, ai21_messages = _convert_messages_to_ai21_messages(messages)

if stop is not None:
if "stop" in kwargs:
raise ValueError("stop is defined in both stop and kwargs")
params["stop_sequences"] = stop

return {
"system": system or "",
"messages": ai21_messages,
**self._default_params,
**params,
**kwargs,
}

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
system, ai21_messages = _convert_messages_to_ai21_messages(messages)
params = self._build_params_for_request(messages=messages, stop=stop, **kwargs)

response = self.client.chat.create(
model=self.model,
messages=ai21_messages,
system=system or "",
num_results=self.num_results,
temperature=self.temperature,
max_tokens=self.max_tokens,
min_tokens=self.min_tokens,
top_p=self.top_p,
top_k_return=self.top_k_return,
stop_sequences=stop,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
count_penalty=self.count_penalty,
**kwargs,
)
response = self.client.chat.create(**params)

outputs = response.outputs
message = AIMessage(content=outputs[0].text)
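The refactor routes every chat request through _build_params_for_request, which layers the stop handling and per-call kwargs over _default_params before a single self.client.chat.create(**params) call. A tiny sketch of that merge order (the dictionary values are illustrative, not SDK defaults):

# Later dictionaries win, mirroring {**self._default_params, **params, **kwargs}:
# per-call kwargs override the stop handling, which overrides the defaults.
defaults = {"model": "j2-ultra", "temperature": 0.7, "top_p": 1.0}
stop_handling = {"stop_sequences": ["\n\n"]}
call_kwargs = {"temperature": 0.1}

merged = {**defaults, **stop_handling, **call_kwargs}
print(merged)
# {'model': 'j2-ultra', 'temperature': 0.1, 'top_p': 1.0, 'stop_sequences': ['\n\n']}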
78 changes: 55 additions & 23 deletions libs/partners/ai21/langchain_ai21/llms.py
@@ -3,10 +3,11 @@
from typing import (
Any,
List,
Mapping,
Optional,
)

from ai21.models import CompletionsResponse, Penalty
from ai21.models import CompletionsResponse
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -47,16 +48,16 @@ class AI21LLM(BaseLLM, AI21Base):
top_p: float = 1
"""A value controlling the diversity of the model's responses."""

top_k_returns: int = 0
top_k_return: int = 0
"""The number of top-scoring tokens to consider for each generation step."""

frequency_penalty: Optional[Penalty] = None
frequency_penalty: Optional[Any] = None
"""A penalty applied to tokens that are frequently generated."""

presence_penalty: Optional[Penalty] = None
presence_penalty: Optional[Any] = None
""" A penalty applied to tokens that are already present in the prompt."""

count_penalty: Optional[Penalty] = None
count_penalty: Optional[Any] = None
"""A penalty applied to tokens based on their frequency
in the generated responses."""

@@ -73,6 +74,51 @@ def _llm_type(self) -> str:
"""Return type of LLM."""
return "ai21-llm"

@property
def _default_params(self) -> Mapping[str, Any]:
base_params = {
"model": self.model,
"num_results": self.num_results,
"max_tokens": self.max_tokens,
"min_tokens": self.min_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k_return": self.top_k_return,
}

if self.count_penalty is not None:
base_params["count_penalty"] = self.count_penalty.to_dict()

if self.custom_model is not None:
base_params["custom_model"] = self.custom_model

if self.epoch is not None:
base_params["epoch"] = self.epoch

if self.frequency_penalty is not None:
base_params["frequency_penalty"] = self.frequency_penalty.to_dict()

if self.presence_penalty is not None:
base_params["presence_penalty"] = self.presence_penalty.to_dict()

return base_params

def _build_params_for_request(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Mapping[str, Any]:
params = {}

if stop is not None:
if "stop" in kwargs:
raise ValueError("stop is defined in both stop and kwargs")
params["stop_sequences"] = stop

return {
**self._default_params,
**params,
**kwargs,
}

def _generate(
self,
prompts: List[str],
@@ -83,10 +129,10 @@ def _generate(
generations: List[List[Generation]] = []
token_count = 0

params = self._build_params_for_request(stop=stop, **kwargs)

for prompt in prompts:
response = self._invoke_completion(
prompt=prompt, model=self.model, stop_sequences=stop, **kwargs
)
response = self._invoke_completion(prompt=prompt, **params)
generation = self._response_to_generation(response)
generations.append(generation)
token_count += self.client.count_tokens(prompt)
@@ -109,25 +155,11 @@ async def _agenerate(
def _invoke_completion(
self,
prompt: str,
model: str,
stop_sequences: Optional[List[str]] = None,
**kwargs: Any,
) -> CompletionsResponse:
return self.client.completion.create(
prompt=prompt,
model=model,
max_tokens=self.max_tokens,
num_results=self.num_results,
min_tokens=self.min_tokens,
temperature=self.temperature,
top_p=self.top_p,
top_k_return=self.top_k_returns,
custom_model=self.custom_model,
stop_sequences=stop_sequences,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
count_penalty=self.count_penalty,
epoch=self.epoch,
**kwargs,
)

def _response_to_generation(
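AI21LLM gets the same treatment: the top_k_returns field is renamed top_k_return for consistency with the chat model, and penalties are still supplied as SDK Penalty objects but converted with .to_dict() before the request. A hedged usage sketch (assumes AI21_API_KEY is set and that the example model name is valid for the account):

# Illustrative only: construct AI21LLM with a count penalty and a stop
# sequence; the integration converts the Penalty object to a dict internally.
from ai21.models import Penalty
from langchain_ai21 import AI21LLM

llm = AI21LLM(
    model="j2-ultra",
    max_tokens=64,
    temperature=0.3,
    count_penalty=Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True),
)
print(llm.invoke("Write one short sentence about diffs.", stop=["\n\n"]))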
14 changes: 7 additions & 7 deletions libs/partners/ai21/poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions libs/partners/ai21/pyproject.toml
@@ -1,14 +1,14 @@
[tool.poetry]
name = "langchain-ai21"
version = "0.1.1"
version = "0.1.2"
description = "An integration package connecting AI21 and LangChain"
authors = []
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.22"
ai21 = "2.0.5"
ai21 = "^2.1.2"

[tool.poetry.group.test]
optional = true
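The dependency pin loosens from the exact ai21 = "2.0.5" to the caret range ^2.1.2 (any 2.x release at or above 2.1.2), and the package itself bumps to 0.1.2. One way to confirm what an environment actually resolved, sketched with only the standard library:

# Illustrative check after `pip install -U langchain-ai21` (or `poetry update ai21`).
from importlib.metadata import version

print(version("ai21"))            # expected >= 2.1.2 after this bump
print(version("langchain-ai21"))  # expected 0.1.2 for this release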
24 changes: 23 additions & 1 deletion libs/partners/ai21/tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from contextlib import contextmanager
from typing import Generator
from unittest.mock import Mock
@@ -31,11 +32,30 @@
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
"count_penalty": Penalty(
scale=0.2, apply_to_punctuation=True, apply_to_emojis=True
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
),
}


BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
"num_results": 3,
"max_tokens": 20,
"min_tokens": 10,
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
"count_penalty": Penalty(
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
).to_dict(),
}


@pytest.fixture
def mocked_completion_response(mocker: MockerFixture) -> Mock:
mocked_response = mocker.MagicMock(spec=CompletionsResponse)
@@ -86,10 +106,12 @@ def temporarily_unset_api_key() -> Generator:
"""
api_key = AI21EnvConfig.api_key
AI21EnvConfig.api_key = None
os.environ.pop("AI21_API_KEY", None)
yield

if api_key is not None:
AI21EnvConfig.api_key = api_key
os.environ["AI21_API_KEY"] = api_key


@pytest.fixture
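The fixture now also pops AI21_API_KEY from os.environ (not only the cached AI21EnvConfig value) and restores it on exit, presumably so the newer SDK cannot rediscover the key from the environment mid-test. A standalone sketch of that save/pop/restore pattern, with a made-up temporarily_unset helper:

# Illustrative only: clear an environment variable for the duration of a
# block and guarantee it is restored, even if the body raises.
import os
from contextlib import contextmanager
from typing import Generator


@contextmanager
def temporarily_unset(var: str) -> Generator[None, None, None]:
    saved = os.environ.pop(var, None)
    try:
        yield
    finally:
        if saved is not None:
            os.environ[var] = saved


with temporarily_unset("AI21_API_KEY"):
    assert "AI21_API_KEY" not in os.environ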