Skip to content

Commit

Permalink
ai21: docstrings (#23142)
Browse files Browse the repository at this point in the history
Added missed docstrings. Format docstrings to the consistent format
(used in the API Reference)
  • Loading branch information
leo-gan authored Jun 19, 2024
1 parent 0c2ebe5 commit a70b7a6
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 5 deletions.
2 changes: 2 additions & 0 deletions libs/partners/ai21/langchain_ai21/ai21_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@


class AI21Base(BaseModel):
"""Base class for AI21 models."""

class Config:
arbitrary_types_allowed = True

Expand Down
8 changes: 6 additions & 2 deletions libs/partners/ai21/langchain_ai21/chat/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@


class ChatAdapter(ABC):
"""
Provides a common interface for the different Chat models available in AI21.
"""Common interface for the different Chat models available in AI21.
It converts LangChain messages to AI21 messages.
Calls the appropriate AI21 model API with the converted messages.
"""
Expand Down Expand Up @@ -77,6 +77,8 @@ def _get_system_message_from_message(self, message: BaseMessage) -> str:


class J2ChatAdapter(ChatAdapter):
"""Adapter for J2Chat models."""

def convert_messages(self, messages: List[BaseMessage]) -> Dict[str, Any]:
system_message = ""
converted_messages = [] # type: ignore
Expand Down Expand Up @@ -107,6 +109,8 @@ def call(self, client: Any, **params: Any) -> List[BaseMessage]:


class JambaChatCompletionsAdapter(ChatAdapter):
"""Adapter for Jamba Chat Completions."""

def convert_messages(self, messages: List[BaseMessage]) -> Dict[str, Any]:
return {
"messages": [
Expand Down
8 changes: 8 additions & 0 deletions libs/partners/ai21/langchain_ai21/chat/chat_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@


def create_chat_adapter(model: str) -> ChatAdapter:
"""Create a chat adapter based on the model.
Args:
model: The model to create the chat adapter for.
Returns:
The chat adapter.
"""
if "j2" in model:
return J2ChatAdapter()

Expand Down
4 changes: 4 additions & 0 deletions libs/partners/ai21/langchain_ai21/contextual_answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@


class ContextualAnswerInput(TypedDict):
"""Input for the ContextualAnswers runnable."""

context: ContextType
question: str


class AI21ContextualAnswers(RunnableSerializable[ContextualAnswerInput, str], AI21Base):
"""Runnable for the AI21 Contextual Answers API."""

class Config:
"""Configuration for this pydantic object."""

Expand Down
3 changes: 2 additions & 1 deletion libs/partners/ai21/langchain_ai21/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[Lis


class AI21Embeddings(Embeddings, AI21Base):
"""AI21 Embeddings embedding model.
"""AI21 embedding model.
To use, you should have the 'AI21_API_KEY' environment variable set
or pass as a named parameter to the constructor.
Expand Down
2 changes: 1 addition & 1 deletion libs/partners/ai21/langchain_ai21/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


class AI21LLM(BaseLLM, AI21Base):
"""AI21LLM large language models.
"""AI21 large language models.
Example:
.. code-block:: python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class AI21SemanticTextSplitter(TextSplitter):
"""Splitting text into coherent and readable units,
based on distinct topics and lines
based on distinct topics and lines.
"""

def __init__(
Expand Down

0 comments on commit a70b7a6

Please sign in to comment.