From 3b5bdbfee8d12caaa003daf96398b3f39f8e1476 Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Sun, 25 Feb 2024 21:57:26 -0800 Subject: [PATCH] anthropic[minor]: package move (#17974) --- .../anthropic/langchain_anthropic/__init__.py | 5 +- .../langchain_anthropic/chat_models.py | 50 ++- .../anthropic/langchain_anthropic/llms.py | 352 ++++++++++++++++++ libs/partners/anthropic/poetry.lock | 74 ++-- libs/partners/anthropic/pyproject.toml | 5 +- .../integration_tests/test_chat_models.py | 80 +++- .../tests/integration_tests/test_llms.py | 74 ++++ .../anthropic/tests/unit_tests/_utils.py | 255 +++++++++++++ .../tests/unit_tests/test_chat_models.py | 46 ++- .../tests/unit_tests/test_imports.py | 2 +- 10 files changed, 886 insertions(+), 57 deletions(-) create mode 100644 libs/partners/anthropic/langchain_anthropic/llms.py create mode 100644 libs/partners/anthropic/tests/integration_tests/test_llms.py create mode 100644 libs/partners/anthropic/tests/unit_tests/_utils.py diff --git a/libs/partners/anthropic/langchain_anthropic/__init__.py b/libs/partners/anthropic/langchain_anthropic/__init__.py index c801b067e14c1..1ad0801abf14c 100644 --- a/libs/partners/anthropic/langchain_anthropic/__init__.py +++ b/libs/partners/anthropic/langchain_anthropic/__init__.py @@ -1,3 +1,4 @@ -from langchain_anthropic.chat_models import ChatAnthropicMessages +from langchain_anthropic.chat_models import ChatAnthropic, ChatAnthropicMessages +from langchain_anthropic.llms import Anthropic, AnthropicLLM -__all__ = ["ChatAnthropicMessages"] +__all__ = ["ChatAnthropicMessages", "ChatAnthropic", "Anthropic", "AnthropicLLM"] diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index a7382c3cd1b30..9da2a22ed9ea7 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -2,6 +2,7 @@ from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple import anthropic +from langchain_core._api.deprecation import deprecated from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -14,7 +15,11 @@ ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str +from langchain_core.utils import ( + build_extra_kwargs, + convert_to_secret_str, + get_pydantic_field_names, +) _message_type_lookups = {"human": "user", "ai": "assistant"} @@ -50,7 +55,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D return system, formatted_messages -class ChatAnthropicMessages(BaseChatModel): +class ChatAnthropic(BaseChatModel): """ChatAnthropicMessages chat model. Example: @@ -61,13 +66,18 @@ class ChatAnthropicMessages(BaseChatModel): model = ChatAnthropicMessages() """ - _client: anthropic.Client = Field(default_factory=anthropic.Client) - _async_client: anthropic.AsyncClient = Field(default_factory=anthropic.AsyncClient) + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + _client: anthropic.Client = Field(default=None) + _async_client: anthropic.AsyncClient = Field(default=None) model: str = Field(alias="model_name") """Model name to use.""" - max_tokens: int = Field(default=256) + max_tokens: int = Field(default=256, alias="max_tokens_to_sample") """Denotes the number of tokens to predict per generation.""" temperature: Optional[float] = None @@ -88,16 +98,20 @@ class ChatAnthropicMessages(BaseChatModel): model_kwargs: Dict[str, Any] = Field(default_factory=dict) - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True - @property def _llm_type(self) -> str: """Return type of chat model.""" return "chat-anthropic-messages" + @root_validator(pre=True) + def build_extra(cls, values: Dict) -> Dict: + extra = values.get("model_kwargs", {}) + all_required_field_names = get_pydantic_field_names(cls) + values["model_kwargs"] = build_extra_kwargs( + extra, values, all_required_field_names + ) + return values + @root_validator() def validate_environment(cls, values: Dict) -> Dict: anthropic_api_key = convert_to_secret_str( @@ -130,6 +144,7 @@ def _format_params( "top_p": self.top_p, "stop_sequences": stop, "system": system, + **self.model_kwargs, } rtn = {k: v for k, v in rtn.items() if v is not None} @@ -145,7 +160,10 @@ def _stream( params = self._format_params(messages=messages, stop=stop, **kwargs) with self._client.messages.stream(**params) as stream: for text in stream.text_stream: - yield ChatGenerationChunk(message=AIMessageChunk(content=text)) + chunk = ChatGenerationChunk(message=AIMessageChunk(content=text)) + if run_manager: + run_manager.on_llm_new_token(text, chunk=chunk) + yield chunk async def _astream( self, @@ -157,7 +175,10 @@ async def _astream( params = self._format_params(messages=messages, stop=stop, **kwargs) async with self._async_client.messages.stream(**params) as stream: async for text in stream.text_stream: - yield ChatGenerationChunk(message=AIMessageChunk(content=text)) + chunk = ChatGenerationChunk(message=AIMessageChunk(content=text)) + if run_manager: + await run_manager.on_llm_new_token(text, chunk=chunk) + yield chunk def _generate( self, @@ -190,3 +211,8 @@ async def _agenerate( ], llm_output=data, ) + + +@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatAnthropic") +class ChatAnthropicMessages(ChatAnthropic): + pass diff --git a/libs/partners/anthropic/langchain_anthropic/llms.py b/libs/partners/anthropic/langchain_anthropic/llms.py new file mode 100644 index 0000000000000..d5d04a7f3e91b --- /dev/null +++ b/libs/partners/anthropic/langchain_anthropic/llms.py @@ -0,0 +1,352 @@ +import re +import warnings +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, +) + +import anthropic +from langchain_core._api.deprecation import deprecated +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models import BaseLanguageModel +from langchain_core.language_models.llms import LLM +from langchain_core.outputs import GenerationChunk +from langchain_core.prompt_values import PromptValue +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.utils import ( + get_from_dict_or_env, + get_pydantic_field_names, +) +from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str + + +class _AnthropicCommon(BaseLanguageModel): + client: Any = None #: :meta private: + async_client: Any = None #: :meta private: + model: str = Field(default="claude-2", alias="model_name") + """Model name to use.""" + + max_tokens_to_sample: int = Field(default=256, alias="max_tokens") + """Denotes the number of tokens to predict per generation.""" + + temperature: Optional[float] = None + """A non-negative float that tunes the degree of randomness in generation.""" + + top_k: Optional[int] = None + """Number of most likely tokens to consider at each step.""" + + top_p: Optional[float] = None + """Total probability mass of tokens to consider at each step.""" + + streaming: bool = False + """Whether to stream the results.""" + + default_request_timeout: Optional[float] = None + """Timeout for requests to Anthropic Completion API. Default is 600 seconds.""" + + max_retries: int = 2 + """Number of retries allowed for requests sent to the Anthropic Completion API.""" + + anthropic_api_url: Optional[str] = None + + anthropic_api_key: Optional[SecretStr] = None + + HUMAN_PROMPT: Optional[str] = None + AI_PROMPT: Optional[str] = None + count_tokens: Optional[Callable[[str], int]] = None + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + + @root_validator(pre=True) + def build_extra(cls, values: Dict) -> Dict: + extra = values.get("model_kwargs", {}) + all_required_field_names = get_pydantic_field_names(cls) + values["model_kwargs"] = build_extra_kwargs( + extra, values, all_required_field_names + ) + return values + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["anthropic_api_key"] = convert_to_secret_str( + get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY") + ) + # Get custom api url from environment. + values["anthropic_api_url"] = get_from_dict_or_env( + values, + "anthropic_api_url", + "ANTHROPIC_API_URL", + default="https://api.anthropic.com", + ) + + values["client"] = anthropic.Anthropic( + base_url=values["anthropic_api_url"], + api_key=values["anthropic_api_key"].get_secret_value(), + timeout=values["default_request_timeout"], + max_retries=values["max_retries"], + ) + values["async_client"] = anthropic.AsyncAnthropic( + base_url=values["anthropic_api_url"], + api_key=values["anthropic_api_key"].get_secret_value(), + timeout=values["default_request_timeout"], + max_retries=values["max_retries"], + ) + values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT + values["AI_PROMPT"] = anthropic.AI_PROMPT + values["count_tokens"] = values["client"].count_tokens + return values + + @property + def _default_params(self) -> Mapping[str, Any]: + """Get the default parameters for calling Anthropic API.""" + d = { + "max_tokens_to_sample": self.max_tokens_to_sample, + "model": self.model, + } + if self.temperature is not None: + d["temperature"] = self.temperature + if self.top_k is not None: + d["top_k"] = self.top_k + if self.top_p is not None: + d["top_p"] = self.top_p + return {**d, **self.model_kwargs} + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {**{}, **self._default_params} + + def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]: + if not self.HUMAN_PROMPT or not self.AI_PROMPT: + raise NameError("Please ensure the anthropic package is loaded") + + if stop is None: + stop = [] + + # Never want model to invent new turns of Human / Assistant dialog. + stop.extend([self.HUMAN_PROMPT]) + + return stop + + +class AnthropicLLM(LLM, _AnthropicCommon): + """Anthropic large language models. + + To use, you should have the ``anthropic`` python package installed, and the + environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass + it as a named parameter to the constructor. + + Example: + .. code-block:: python + + import anthropic + from langchain_community.llms import Anthropic + + model = Anthropic(model="", anthropic_api_key="my-api-key") + + # Simplest invocation, automatically wrapped with HUMAN_PROMPT + # and AI_PROMPT. + response = model("What are the biggest risks facing humanity?") + + # Or if you want to use the chat mode, build a few-shot-prompt, or + # put words in the Assistant's mouth, use HUMAN_PROMPT and AI_PROMPT: + raw_prompt = "What are the biggest risks facing humanity?" + prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}" + response = model(prompt) + """ + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + arbitrary_types_allowed = True + + @root_validator() + def raise_warning(cls, values: Dict) -> Dict: + """Raise warning that this class is deprecated.""" + warnings.warn( + "This Anthropic LLM is deprecated. " + "Please use `from langchain_community.chat_models import ChatAnthropic` " + "instead" + ) + return values + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "anthropic-llm" + + def _wrap_prompt(self, prompt: str) -> str: + if not self.HUMAN_PROMPT or not self.AI_PROMPT: + raise NameError("Please ensure the anthropic package is loaded") + + if prompt.startswith(self.HUMAN_PROMPT): + return prompt # Already wrapped. + + # Guard against common errors in specifying wrong number of newlines. + corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt) + if n_subs == 1: + return corrected_prompt + + # As a last resort, wrap the prompt ourselves to emulate instruct-style. + return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + r"""Call out to Anthropic's completion endpoint. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + The string generated by the model. + + Example: + .. code-block:: python + + prompt = "What are the biggest risks facing humanity?" + prompt = f"\n\nHuman: {prompt}\n\nAssistant:" + response = model(prompt) + + """ + if self.streaming: + completion = "" + for chunk in self._stream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): + completion += chunk.text + return completion + + stop = self._get_anthropic_stop(stop) + params = {**self._default_params, **kwargs} + response = self.client.completions.create( + prompt=self._wrap_prompt(prompt), + stop_sequences=stop, + **params, + ) + return response.completion + + def convert_prompt(self, prompt: PromptValue) -> str: + return self._wrap_prompt(prompt.to_string()) + + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Call out to Anthropic's completion endpoint asynchronously.""" + if self.streaming: + completion = "" + async for chunk in self._astream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): + completion += chunk.text + return completion + + stop = self._get_anthropic_stop(stop) + params = {**self._default_params, **kwargs} + + response = await self.async_client.completions.create( + prompt=self._wrap_prompt(prompt), + stop_sequences=stop, + **params, + ) + return response.completion + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + r"""Call Anthropic completion_stream and return the resulting generator. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + Returns: + A generator representing the stream of tokens from Anthropic. + Example: + .. code-block:: python + + prompt = "Write a poem about a stream." + prompt = f"\n\nHuman: {prompt}\n\nAssistant:" + generator = anthropic.stream(prompt) + for token in generator: + yield token + """ + stop = self._get_anthropic_stop(stop) + params = {**self._default_params, **kwargs} + + for token in self.client.completions.create( + prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params + ): + chunk = GenerationChunk(text=token.completion) + yield chunk + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + + async def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + r"""Call Anthropic completion_stream and return the resulting generator. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + Returns: + A generator representing the stream of tokens from Anthropic. + Example: + .. code-block:: python + prompt = "Write a poem about a stream." + prompt = f"\n\nHuman: {prompt}\n\nAssistant:" + generator = anthropic.stream(prompt) + for token in generator: + yield token + """ + stop = self._get_anthropic_stop(stop) + params = {**self._default_params, **kwargs} + + async for token in await self.async_client.completions.create( + prompt=self._wrap_prompt(prompt), + stop_sequences=stop, + stream=True, + **params, + ): + chunk = GenerationChunk(text=token.completion) + yield chunk + if run_manager: + await run_manager.on_llm_new_token(chunk.text, chunk=chunk) + + def get_num_tokens(self, text: str) -> int: + """Calculate number of tokens.""" + if not self.count_tokens: + raise NameError("Please ensure the anthropic package is loaded") + return self.count_tokens(text) + + +@deprecated(since="0.1.0", removal="0.2.0", alternative="AnthropicLLM") +class Anthropic(AnthropicLLM): + pass diff --git a/libs/partners/anthropic/poetry.lock b/libs/partners/anthropic/poetry.lock index f06d19b59a0e3..7b886b3590329 100644 --- a/libs/partners/anthropic/poetry.lock +++ b/libs/partners/anthropic/poetry.lock @@ -40,13 +40,13 @@ vertex = ["google-auth (>=2,<3)"] [[package]] name = "anyio" -version = "4.2.0" +version = "4.3.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.8" files = [ - {file = "anyio-4.2.0-py3-none-any.whl", hash = "sha256:745843b39e829e108e518c489b31dc757de7d2131d53fac32bd8df268227bfee"}, - {file = "anyio-4.2.0.tar.gz", hash = "sha256:e1875bb4b4e2de1669f4bc7869b6d3f54231cdced71605e6e64c9be77e3be50f"}, + {file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"}, + {file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"}, ] [package.dependencies] @@ -301,13 +301,13 @@ files = [ [[package]] name = "httpcore" -version = "1.0.3" +version = "1.0.4" description = "A minimal low-level HTTP client." optional = false python-versions = ">=3.8" files = [ - {file = "httpcore-1.0.3-py3-none-any.whl", hash = "sha256:9a6a501c3099307d9fd76ac244e08503427679b1e81ceb1d922485e2f2462ad2"}, - {file = "httpcore-1.0.3.tar.gz", hash = "sha256:5c0f9546ad17dac4d0772b0808856eb616eb8b48ce94f49ed819fd6982a8a544"}, + {file = "httpcore-1.0.4-py3-none-any.whl", hash = "sha256:ac418c1db41bade2ad53ae2f3834a3a0f5ae76b56cf5aa497d2d033384fc7d73"}, + {file = "httpcore-1.0.4.tar.gz", hash = "sha256:cb2839ccfcba0d2d3c1131d3c3e26dfc327326fbe7a5dc0dbfe9f6c9151bb022"}, ] [package.dependencies] @@ -318,17 +318,17 @@ h11 = ">=0.13,<0.15" asyncio = ["anyio (>=4.0,<5.0)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] -trio = ["trio (>=0.22.0,<0.24.0)"] +trio = ["trio (>=0.22.0,<0.25.0)"] [[package]] name = "httpx" -version = "0.26.0" +version = "0.27.0" description = "The next generation HTTP client." optional = false python-versions = ">=3.8" files = [ - {file = "httpx-0.26.0-py3-none-any.whl", hash = "sha256:8915f5a3627c4d47b73e8202457cb28f1266982d1159bd5779d86a80c0eab1cd"}, - {file = "httpx-0.26.0.tar.gz", hash = "sha256:451b55c30d5185ea6b23c2c793abf9bb237d2a7dfb901ced6ff69ad37ec1dfaf"}, + {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, + {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, ] [package.dependencies] @@ -425,7 +425,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.1.23" +version = "0.1.25" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -435,7 +435,7 @@ develop = true [package.dependencies] anyio = ">=3,<5" jsonpatch = "^1.33" -langsmith = "^0.0.87" +langsmith = "^0.1.0" packaging = "^23.2" pydantic = ">=1,<3" PyYAML = ">=5.3" @@ -451,13 +451,13 @@ url = "../../core" [[package]] name = "langsmith" -version = "0.0.87" +version = "0.1.5" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langsmith-0.0.87-py3-none-any.whl", hash = "sha256:8903d3811b9fc89eb18f5961c8e6935fbd2d0f119884fbf30dc70b8f8f4121fc"}, - {file = "langsmith-0.0.87.tar.gz", hash = "sha256:36c4cc47e5b54be57d038036a30fb19ce6e4c73048cd7a464b8f25b459694d34"}, + {file = "langsmith-0.1.5-py3-none-any.whl", hash = "sha256:a1811821a923d90e53bcbacdd0988c3c366aff8f4c120d8777e7af8ecda06268"}, + {file = "langsmith-0.1.5.tar.gz", hash = "sha256:aa7a2861aa3d9ae563a077c622953533800466c4e2e539b0d567b84d5fd5b157"}, ] [package.dependencies] @@ -830,28 +830,28 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "ruff" -version = "0.1.15" +version = "0.2.2" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.1.15-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:5fe8d54df166ecc24106db7dd6a68d44852d14eb0729ea4672bb4d96c320b7df"}, - {file = "ruff-0.1.15-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6f0bfbb53c4b4de117ac4d6ddfd33aa5fc31beeaa21d23c45c6dd249faf9126f"}, - {file = "ruff-0.1.15-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0d432aec35bfc0d800d4f70eba26e23a352386be3a6cf157083d18f6f5881c8"}, - {file = "ruff-0.1.15-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9405fa9ac0e97f35aaddf185a1be194a589424b8713e3b97b762336ec79ff807"}, - {file = "ruff-0.1.15-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c66ec24fe36841636e814b8f90f572a8c0cb0e54d8b5c2d0e300d28a0d7bffec"}, - {file = "ruff-0.1.15-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:6f8ad828f01e8dd32cc58bc28375150171d198491fc901f6f98d2a39ba8e3ff5"}, - {file = "ruff-0.1.15-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86811954eec63e9ea162af0ffa9f8d09088bab51b7438e8b6488b9401863c25e"}, - {file = "ruff-0.1.15-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fd4025ac5e87d9b80e1f300207eb2fd099ff8200fa2320d7dc066a3f4622dc6b"}, - {file = "ruff-0.1.15-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b17b93c02cdb6aeb696effecea1095ac93f3884a49a554a9afa76bb125c114c1"}, - {file = "ruff-0.1.15-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ddb87643be40f034e97e97f5bc2ef7ce39de20e34608f3f829db727a93fb82c5"}, - {file = "ruff-0.1.15-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:abf4822129ed3a5ce54383d5f0e964e7fef74a41e48eb1dfad404151efc130a2"}, - {file = "ruff-0.1.15-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6c629cf64bacfd136c07c78ac10a54578ec9d1bd2a9d395efbee0935868bf852"}, - {file = "ruff-0.1.15-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:1bab866aafb53da39c2cadfb8e1c4550ac5340bb40300083eb8967ba25481447"}, - {file = "ruff-0.1.15-py3-none-win32.whl", hash = "sha256:2417e1cb6e2068389b07e6fa74c306b2810fe3ee3476d5b8a96616633f40d14f"}, - {file = "ruff-0.1.15-py3-none-win_amd64.whl", hash = "sha256:3837ac73d869efc4182d9036b1405ef4c73d9b1f88da2413875e34e0d6919587"}, - {file = "ruff-0.1.15-py3-none-win_arm64.whl", hash = "sha256:9a933dfb1c14ec7a33cceb1e49ec4a16b51ce3c20fd42663198746efc0427360"}, - {file = "ruff-0.1.15.tar.gz", hash = "sha256:f6dfa8c1b21c913c326919056c390966648b680966febcb796cc9d1aaab8564e"}, + {file = "ruff-0.2.2-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:0a9efb032855ffb3c21f6405751d5e147b0c6b631e3ca3f6b20f917572b97eb6"}, + {file = "ruff-0.2.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d450b7fbff85913f866a5384d8912710936e2b96da74541c82c1b458472ddb39"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecd46e3106850a5c26aee114e562c329f9a1fbe9e4821b008c4404f64ff9ce73"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e22676a5b875bd72acd3d11d5fa9075d3a5f53b877fe7b4793e4673499318ba"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1695700d1e25a99d28f7a1636d85bafcc5030bba9d0578c0781ba1790dbcf51c"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:b0c232af3d0bd8f521806223723456ffebf8e323bd1e4e82b0befb20ba18388e"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f63d96494eeec2fc70d909393bcd76c69f35334cdbd9e20d089fb3f0640216ca"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6a61ea0ff048e06de273b2e45bd72629f470f5da8f71daf09fe481278b175001"}, + {file = "ruff-0.2.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e1439c8f407e4f356470e54cdecdca1bd5439a0673792dbe34a2b0a551a2fe3"}, + {file = "ruff-0.2.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:940de32dc8853eba0f67f7198b3e79bc6ba95c2edbfdfac2144c8235114d6726"}, + {file = "ruff-0.2.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:0c126da55c38dd917621552ab430213bdb3273bb10ddb67bc4b761989210eb6e"}, + {file = "ruff-0.2.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3b65494f7e4bed2e74110dac1f0d17dc8e1f42faaa784e7c58a98e335ec83d7e"}, + {file = "ruff-0.2.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:1ec49be4fe6ddac0503833f3ed8930528e26d1e60ad35c2446da372d16651ce9"}, + {file = "ruff-0.2.2-py3-none-win32.whl", hash = "sha256:d920499b576f6c68295bc04e7b17b6544d9d05f196bb3aac4358792ef6f34325"}, + {file = "ruff-0.2.2-py3-none-win_amd64.whl", hash = "sha256:cc9a91ae137d687f43a44c900e5d95e9617cb37d4c989e462980ba27039d239d"}, + {file = "ruff-0.2.2-py3-none-win_arm64.whl", hash = "sha256:c9d15fc41e6054bfc7200478720570078f0b41c9ae4f010bcc16bd6f4d1aacdd"}, + {file = "ruff-0.2.2.tar.gz", hash = "sha256:e62ed7f36b3068a30ba39193a14274cd706bc486fad521276458022f7bccb31d"}, ] [[package]] @@ -1075,13 +1075,13 @@ files = [ [[package]] name = "urllib3" -version = "2.2.0" +version = "2.2.1" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.0-py3-none-any.whl", hash = "sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224"}, - {file = "urllib3-2.2.0.tar.gz", hash = "sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20"}, + {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, + {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, ] [package.extras] @@ -1134,4 +1134,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "c4d03a1586b121b905ea4c0f86d04427cbb3e155e60d67c4b3351186de0d540a" +content-hash = "e88b90ae60758bdab0fe844948c5b9a45bb7f3a96e9ab31ee6d56a7ebd24bfde" diff --git a/libs/partners/anthropic/pyproject.toml b/libs/partners/anthropic/pyproject.toml index 2cdfaba1989d1..2a79c2a4bd6b4 100644 --- a/libs/partners/anthropic/pyproject.toml +++ b/libs/partners/anthropic/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-anthropic" -version = "0.0.2" +version = "0.1.0" description = "An integration package connecting AnthropicMessages and LangChain" authors = [] readme = "README.md" @@ -37,7 +37,8 @@ codespell = "^2.2.0" optional = true [tool.poetry.group.lint.dependencies] -ruff = "^0.1.5" +ruff = ">=0.2.2,<1" +mypy = "^0.991" [tool.poetry.group.typing.dependencies] mypy = "^0.991" diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py index 2a057a587ef6e..a27183941638e 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py @@ -1,7 +1,14 @@ -"""Test ChatAnthropicMessages chat model.""" +"""Test ChatAnthropic chat model.""" + +from typing import List + +from langchain_core.callbacks import CallbackManager +from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage +from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.prompts import ChatPromptTemplate -from langchain_anthropic.chat_models import ChatAnthropicMessages +from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages +from tests.unit_tests._utils import FakeCallbackHandler def test_stream() -> None: @@ -84,3 +91,72 @@ def test_system_invoke() -> None: result = chain.invoke({}) assert isinstance(result.content, str) + + +def test_anthropic_call() -> None: + """Test valid call to anthropic.""" + chat = ChatAnthropic(model="test") + message = HumanMessage(content="Hello") + response = chat([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + + +def test_anthropic_generate() -> None: + """Test generate method of anthropic.""" + chat = ChatAnthropic(model="test") + chat_messages: List[List[BaseMessage]] = [ + [HumanMessage(content="How many toes do dogs have?")] + ] + messages_copy = [messages.copy() for messages in chat_messages] + result: LLMResult = chat.generate(chat_messages) + assert isinstance(result, LLMResult) + for response in result.generations[0]: + assert isinstance(response, ChatGeneration) + assert isinstance(response.text, str) + assert response.text == response.message.content + assert chat_messages == messages_copy + + +def test_anthropic_streaming() -> None: + """Test streaming tokens from anthropic.""" + chat = ChatAnthropic(model="test") + message = HumanMessage(content="Hello") + response = chat.stream([message]) + for token in response: + assert isinstance(token, AIMessageChunk) + assert isinstance(token.content, str) + + +def test_anthropic_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + chat = ChatAnthropic( + model="test", + callback_manager=callback_manager, + verbose=True, + ) + message = HumanMessage(content="Write me a sentence with 10 words.") + for token in chat.stream([message]): + assert isinstance(token, AIMessageChunk) + assert isinstance(token.content, str) + assert callback_handler.llm_streams > 1 + + +async def test_anthropic_async_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + chat = ChatAnthropic( + model="test", + callback_manager=callback_manager, + verbose=True, + ) + chat_messages: List[BaseMessage] = [ + HumanMessage(content="How many toes do dogs have?") + ] + async for token in chat.astream(chat_messages): + assert isinstance(token, AIMessageChunk) + assert isinstance(token.content, str) + assert callback_handler.llm_streams > 1 diff --git a/libs/partners/anthropic/tests/integration_tests/test_llms.py b/libs/partners/anthropic/tests/integration_tests/test_llms.py new file mode 100644 index 0000000000000..b0c5e4f782cbd --- /dev/null +++ b/libs/partners/anthropic/tests/integration_tests/test_llms.py @@ -0,0 +1,74 @@ +"""Test Anthropic API wrapper.""" + +from typing import Generator + +import pytest +from langchain_core.callbacks import CallbackManager +from langchain_core.outputs import LLMResult + +from langchain_anthropic import Anthropic +from tests.unit_tests._utils import FakeCallbackHandler + + +@pytest.mark.requires("anthropic") +def test_anthropic_model_name_param() -> None: + llm = Anthropic(model_name="foo") + assert llm.model == "foo" + + +@pytest.mark.requires("anthropic") +def test_anthropic_model_param() -> None: + llm = Anthropic(model="foo") + assert llm.model == "foo" + + +def test_anthropic_call() -> None: + """Test valid call to anthropic.""" + llm = Anthropic(model="claude-instant-1") + output = llm("Say foo:") + assert isinstance(output, str) + + +def test_anthropic_streaming() -> None: + """Test streaming tokens from anthropic.""" + llm = Anthropic(model="claude-instant-1") + generator = llm.stream("I'm Pickle Rick") + + assert isinstance(generator, Generator) + + for token in generator: + assert isinstance(token, str) + + +def test_anthropic_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + llm = Anthropic( + streaming=True, + callback_manager=callback_manager, + verbose=True, + ) + llm("Write me a sentence with 100 words.") + assert callback_handler.llm_streams > 1 + + +async def test_anthropic_async_generate() -> None: + """Test async generate.""" + llm = Anthropic() + output = await llm.agenerate(["How many toes do dogs have?"]) + assert isinstance(output, LLMResult) + + +async def test_anthropic_async_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + llm = Anthropic( + streaming=True, + callback_manager=callback_manager, + verbose=True, + ) + result = await llm.agenerate(["How many toes do dogs have?"]) + assert callback_handler.llm_streams > 1 + assert isinstance(result, LLMResult) diff --git a/libs/partners/anthropic/tests/unit_tests/_utils.py b/libs/partners/anthropic/tests/unit_tests/_utils.py new file mode 100644 index 0000000000000..3c053d2cc42bd --- /dev/null +++ b/libs/partners/anthropic/tests/unit_tests/_utils.py @@ -0,0 +1,255 @@ +"""A fake callback handler for testing purposes.""" + +from typing import Any, Union + +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.pydantic_v1 import BaseModel + + +class BaseFakeCallbackHandler(BaseModel): + """Base fake callback handler for testing.""" + + starts: int = 0 + ends: int = 0 + errors: int = 0 + text: int = 0 + ignore_llm_: bool = False + ignore_chain_: bool = False + ignore_agent_: bool = False + ignore_retriever_: bool = False + ignore_chat_model_: bool = False + + # to allow for similar callback handlers that are not technicall equal + fake_id: Union[str, None] = None + + # add finer-grained counters for easier debugging of failing tests + chain_starts: int = 0 + chain_ends: int = 0 + llm_starts: int = 0 + llm_ends: int = 0 + llm_streams: int = 0 + tool_starts: int = 0 + tool_ends: int = 0 + agent_actions: int = 0 + agent_ends: int = 0 + chat_model_starts: int = 0 + retriever_starts: int = 0 + retriever_ends: int = 0 + retriever_errors: int = 0 + retries: int = 0 + + +class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler): + """Base fake callback handler mixin for testing.""" + + def on_llm_start_common(self) -> None: + self.llm_starts += 1 + self.starts += 1 + + def on_llm_end_common(self) -> None: + self.llm_ends += 1 + self.ends += 1 + + def on_llm_error_common(self) -> None: + self.errors += 1 + + def on_llm_new_token_common(self) -> None: + self.llm_streams += 1 + + def on_retry_common(self) -> None: + self.retries += 1 + + def on_chain_start_common(self) -> None: + self.chain_starts += 1 + self.starts += 1 + + def on_chain_end_common(self) -> None: + self.chain_ends += 1 + self.ends += 1 + + def on_chain_error_common(self) -> None: + self.errors += 1 + + def on_tool_start_common(self) -> None: + self.tool_starts += 1 + self.starts += 1 + + def on_tool_end_common(self) -> None: + self.tool_ends += 1 + self.ends += 1 + + def on_tool_error_common(self) -> None: + self.errors += 1 + + def on_agent_action_common(self) -> None: + self.agent_actions += 1 + self.starts += 1 + + def on_agent_finish_common(self) -> None: + self.agent_ends += 1 + self.ends += 1 + + def on_chat_model_start_common(self) -> None: + self.chat_model_starts += 1 + self.starts += 1 + + def on_text_common(self) -> None: + self.text += 1 + + def on_retriever_start_common(self) -> None: + self.starts += 1 + self.retriever_starts += 1 + + def on_retriever_end_common(self) -> None: + self.ends += 1 + self.retriever_ends += 1 + + def on_retriever_error_common(self) -> None: + self.errors += 1 + self.retriever_errors += 1 + + +class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): + """Fake callback handler for testing.""" + + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return self.ignore_llm_ + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return self.ignore_chain_ + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return self.ignore_agent_ + + @property + def ignore_retriever(self) -> bool: + """Whether to ignore retriever callbacks.""" + return self.ignore_retriever_ + + def on_llm_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_start_common() + + def on_llm_new_token( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_new_token_common() + + def on_llm_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_end_common() + + def on_llm_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_error_common() + + def on_retry( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retry_common() + + def on_chain_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_start_common() + + def on_chain_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_end_common() + + def on_chain_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_error_common() + + def on_tool_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_start_common() + + def on_tool_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_end_common() + + def on_tool_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_error_common() + + def on_agent_action( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_agent_action_common() + + def on_agent_finish( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_agent_finish_common() + + def on_text( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_text_common() + + def on_retriever_start( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retriever_start_common() + + def on_retriever_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retriever_end_common() + + def on_retriever_error( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_retriever_error_common() + + def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": + return self diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index ebcf0386aeee3..35232dc90c49d 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -1,10 +1,54 @@ """Test chat model integration.""" +import os -from langchain_anthropic.chat_models import ChatAnthropicMessages +import pytest + +from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages + +os.environ["ANTHROPIC_API_KEY"] = "foo" def test_initialization() -> None: """Test chat model initialization.""" ChatAnthropicMessages(model_name="claude-instant-1.2", anthropic_api_key="xyz") ChatAnthropicMessages(model="claude-instant-1.2", anthropic_api_key="xyz") + + +@pytest.mark.requires("anthropic") +def test_anthropic_model_name_param() -> None: + llm = ChatAnthropic(model_name="foo") + assert llm.model == "foo" + + +@pytest.mark.requires("anthropic") +def test_anthropic_model_param() -> None: + llm = ChatAnthropic(model="foo") + assert llm.model == "foo" + + +@pytest.mark.requires("anthropic") +def test_anthropic_model_kwargs() -> None: + llm = ChatAnthropic(model_name="foo", model_kwargs={"foo": "bar"}) + assert llm.model_kwargs == {"foo": "bar"} + + +@pytest.mark.requires("anthropic") +def test_anthropic_invalid_model_kwargs() -> None: + with pytest.raises(ValueError): + ChatAnthropic(model="foo", model_kwargs={"max_tokens_to_sample": 5}) + + +@pytest.mark.requires("anthropic") +def test_anthropic_incorrect_field() -> None: + with pytest.warns(match="not default parameter"): + llm = ChatAnthropic(model="foo", foo="bar") + assert llm.model_kwargs == {"foo": "bar"} + + +@pytest.mark.requires("anthropic") +def test_anthropic_initialization() -> None: + """Test anthropic initialization.""" + # Verify that chat anthropic can be initialized using a secret key provided + # as a parameter rather than an environment variable. + ChatAnthropic(model="test", anthropic_api_key="test") diff --git a/libs/partners/anthropic/tests/unit_tests/test_imports.py b/libs/partners/anthropic/tests/unit_tests/test_imports.py index c37e55b9cd34d..e714099037af7 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_imports.py +++ b/libs/partners/anthropic/tests/unit_tests/test_imports.py @@ -1,6 +1,6 @@ from langchain_anthropic import __all__ -EXPECTED_ALL = ["ChatAnthropicMessages"] +EXPECTED_ALL = ["ChatAnthropicMessages", "ChatAnthropic", "Anthropic", "AnthropicLLM"] def test_all_imports() -> None: