Skip to content

Commit

Permalink
anthropic[minor]: package move (#17974)
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis authored Feb 26, 2024
1 parent a2d5fa7 commit 3b5bdbf
Show file tree
Hide file tree
Showing 10 changed files with 886 additions and 57 deletions.
5 changes: 3 additions & 2 deletions libs/partners/anthropic/langchain_anthropic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from langchain_anthropic.chat_models import ChatAnthropicMessages
from langchain_anthropic.chat_models import ChatAnthropic, ChatAnthropicMessages
from langchain_anthropic.llms import Anthropic, AnthropicLLM

__all__ = ["ChatAnthropicMessages"]
__all__ = ["ChatAnthropicMessages", "ChatAnthropic", "Anthropic", "AnthropicLLM"]
50 changes: 38 additions & 12 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple

import anthropic
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
Expand All @@ -14,7 +15,11 @@
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
from langchain_core.utils import (
build_extra_kwargs,
convert_to_secret_str,
get_pydantic_field_names,
)

_message_type_lookups = {"human": "user", "ai": "assistant"}

Expand Down Expand Up @@ -50,7 +55,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
return system, formatted_messages


class ChatAnthropicMessages(BaseChatModel):
class ChatAnthropic(BaseChatModel):
"""ChatAnthropicMessages chat model.
Example:
Expand All @@ -61,13 +66,18 @@ class ChatAnthropicMessages(BaseChatModel):
model = ChatAnthropicMessages()
"""

_client: anthropic.Client = Field(default_factory=anthropic.Client)
_async_client: anthropic.AsyncClient = Field(default_factory=anthropic.AsyncClient)
class Config:
"""Configuration for this pydantic object."""

allow_population_by_field_name = True

_client: anthropic.Client = Field(default=None)
_async_client: anthropic.AsyncClient = Field(default=None)

model: str = Field(alias="model_name")
"""Model name to use."""

max_tokens: int = Field(default=256)
max_tokens: int = Field(default=256, alias="max_tokens_to_sample")
"""Denotes the number of tokens to predict per generation."""

temperature: Optional[float] = None
Expand All @@ -88,16 +98,20 @@ class ChatAnthropicMessages(BaseChatModel):

model_kwargs: Dict[str, Any] = Field(default_factory=dict)

class Config:
"""Configuration for this pydantic object."""

allow_population_by_field_name = True

@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "chat-anthropic-messages"

@root_validator(pre=True)
def build_extra(cls, values: Dict) -> Dict:
extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls)
values["model_kwargs"] = build_extra_kwargs(
extra, values, all_required_field_names
)
return values

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
anthropic_api_key = convert_to_secret_str(
Expand Down Expand Up @@ -130,6 +144,7 @@ def _format_params(
"top_p": self.top_p,
"stop_sequences": stop,
"system": system,
**self.model_kwargs,
}
rtn = {k: v for k, v in rtn.items() if v is not None}

Expand All @@ -145,7 +160,10 @@ def _stream(
params = self._format_params(messages=messages, stop=stop, **kwargs)
with self._client.messages.stream(**params) as stream:
for text in stream.text_stream:
yield ChatGenerationChunk(message=AIMessageChunk(content=text))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
if run_manager:
run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk

async def _astream(
self,
Expand All @@ -157,7 +175,10 @@ async def _astream(
params = self._format_params(messages=messages, stop=stop, **kwargs)
async with self._async_client.messages.stream(**params) as stream:
async for text in stream.text_stream:
yield ChatGenerationChunk(message=AIMessageChunk(content=text))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
if run_manager:
await run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk

def _generate(
self,
Expand Down Expand Up @@ -190,3 +211,8 @@ async def _agenerate(
],
llm_output=data,
)


@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatAnthropic")
class ChatAnthropicMessages(ChatAnthropic):
pass
Loading

0 comments on commit 3b5bdbf

Please sign in to comment.