Add AnthropicVertexChatGenerator component #1192
Merged
8 commits:
- `3ec5c42` Created a model adapter (Amnah199)
- `7d54e95` Create adapter class and add VertexAPI (Amnah199)
- `384c154` Add chat generator for Anthropic Vertex (Amnah199)
- `70cddc6` Add tests (Amnah199)
- `e3322f0` Small fix (Amnah199)
- `6590007` Improve doc_strings (Amnah199)
- `bf65c61` Make project_id and region mandatory params (Amnah199)
- `bbeda3d` Small fix (Amnah199)
`...c/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py` (135 additions, 0 deletions)
````python
import os
from typing import Any, Callable, Dict, Optional

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import StreamingChunk
from haystack.utils import deserialize_callable, serialize_callable

from anthropic import AnthropicVertex

from .chat_generator import AnthropicChatGenerator

logger = logging.getLogger(__name__)


@component
class AnthropicVertexChatGenerator(AnthropicChatGenerator):
    """
    Enables text generation using state-of-the-art Claude 3 LLMs via the Anthropic Vertex AI API.
    It supports models such as `Claude 3.5 Sonnet`, `Claude 3 Opus`, `Claude 3 Sonnet`, and `Claude 3 Haiku`,
    accessible through the Vertex AI API endpoint.

    To use AnthropicVertexChatGenerator, you must have a GCP project with Vertex AI enabled.
    Additionally, ensure that the desired Anthropic model is activated in the Vertex AI Model Garden.
    Before making requests, you may need to authenticate with GCP using `gcloud auth login`.
    For more details, refer to the [guide](https://docs.anthropic.com/en/api/claude-on-vertex-ai).

    Any valid text generation parameters for the Anthropic messaging API can be passed to
    the AnthropicVertex API. Users can provide these parameters directly to the component via
    the `generation_kwargs` parameter in `__init__` or the `run` method.

    For more details on the parameters supported by the Anthropic API, refer to the
    Anthropic Message API [documentation](https://docs.anthropic.com/en/api/messages).

    ```python
    from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_user("What's Natural Language Processing?")]
    client = AnthropicVertexChatGenerator(
        model="claude-3-sonnet@20240229",
        project_id="your-project-id", region="your-region"
    )
    response = client.run(messages)
    print(response)

    >> {'replies': [ChatMessage(content='Natural Language Processing (NLP) is a field of artificial intelligence that
    >> focuses on enabling computers to understand, interpret, and generate human language. It involves developing
    >> techniques and algorithms to analyze and process text or speech data, allowing machines to comprehend and
    >> communicate in natural languages like English, Spanish, or Chinese.', role=<ChatRole.ASSISTANT: 'assistant'>,
    >> name=None, meta={'model': 'claude-3-sonnet@20240229', 'index': 0, 'finish_reason': 'end_turn',
    >> 'usage': {'input_tokens': 15, 'output_tokens': 64}})]}
    ```

    For more details on supported models and their capabilities, refer to the Anthropic
    [documentation](https://docs.anthropic.com/claude/docs/intro-to-claude).
    """

    def __init__(
        self,
        region: str,
        project_id: str,
        model: str = "claude-3-5-sonnet@20240620",
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        ignore_tools_thinking_messages: bool = True,
    ):
        """
        Creates an instance of AnthropicVertexChatGenerator.

        :param region: The GCP region where the Anthropic model is deployed, for example, "us-central1".
        :param project_id: The GCP project ID where the Anthropic model is deployed.
        :param model: The name of the model to use.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
        :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
            the AnthropicVertex endpoint. See Anthropic [documentation](https://docs.anthropic.com/claude/reference/messages_post)
            for more details.

            Supported generation_kwargs parameters are:
            - `system`: The system message to be passed to the model.
            - `max_tokens`: The maximum number of tokens to generate.
            - `metadata`: A dictionary of metadata to be passed to the model.
            - `stop_sequences`: A list of strings that the model should stop generating at.
            - `temperature`: The temperature to use for sampling.
            - `top_p`: The top_p value to use for nucleus sampling.
            - `top_k`: The top_k value to use for top-k sampling.
            - `extra_headers`: A dictionary of extra headers to be passed to the model (i.e. for beta features).
        :param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves
            "chain of thought" messages before returning the actual function names and parameters in a message. If
            `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool
            use is detected. See the Anthropic [tools documentation](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use)
            for more details.
        """
        self.region = region or os.environ.get("REGION")
        self.project_id = project_id or os.environ.get("PROJECT_ID")
        self.model = model
        self.generation_kwargs = generation_kwargs or {}
        self.streaming_callback = streaming_callback
        self.client = AnthropicVertex(region=self.region, project_id=self.project_id)
        self.ignore_tools_thinking_messages = ignore_tools_thinking_messages

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            The serialized component as a dictionary.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            region=self.region,
            project_id=self.project_id,
            model=self.model,
            streaming_callback=callback_name,
            generation_kwargs=self.generation_kwargs,
            ignore_tools_thinking_messages=self.ignore_tools_thinking_messages,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AnthropicVertexChatGenerator":
        """
        Deserialize this component from a dictionary.

        :param data: The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)
````
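The `to_dict`/`from_dict` pair relies on Haystack's `serialize_callable`/`deserialize_callable` helpers because a streaming callback is a plain Python callable and cannot be embedded in a JSON or YAML config directly. The underlying idea can be sketched with the standard library alone; the helper names below are illustrative, not Haystack's actual implementation:

```python
import importlib
from typing import Callable


def callable_to_path(fn: Callable) -> str:
    # Record a callable as "module.qualname" so it can live in a serialized config.
    return f"{fn.__module__}.{fn.__qualname__}"


def callable_from_path(path: str) -> Callable:
    # Split off the attribute name and re-import the module to resolve the callable.
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


# Round-trip the built-in `print` function as an example.
path = callable_to_path(print)
restored = callable_from_path(path)
assert path == "builtins.print"
assert restored is print
```

This is also why the lambda test below serializes to `"tests.test_vertex_chat_generator.<lambda>"`: a lambda's qualified name is `<lambda>`, which cannot be resolved back by import, so named functions are the safe choice for serializable callbacks.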
`integrations/anthropic/tests/test_vertex_chat_generator.py` (197 additions, 0 deletions)

```python
import os

import anthropic
import pytest
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, ChatRole

from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator


@pytest.fixture
def chat_messages():
    return [
        ChatMessage.from_system("\nYou are a helpful assistant, be super brief in your responses."),
        ChatMessage.from_user("What's the capital of France?"),
    ]


class TestAnthropicVertexChatGenerator:
    def test_init_default(self):
        component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id")
        assert component.region == "us-central1"
        assert component.project_id == "test-project-id"
        assert component.model == "claude-3-5-sonnet@20240620"
        assert component.streaming_callback is None
        assert not component.generation_kwargs
        assert component.ignore_tools_thinking_messages

    def test_init_with_parameters(self):
        component = AnthropicVertexChatGenerator(
            region="us-central1",
            project_id="test-project-id",
            model="claude-3-5-sonnet@20240620",
            streaming_callback=print_streaming_chunk,
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
            ignore_tools_thinking_messages=False,
        )
        assert component.region == "us-central1"
        assert component.project_id == "test-project-id"
        assert component.model == "claude-3-5-sonnet@20240620"
        assert component.streaming_callback is print_streaming_chunk
        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
        assert component.ignore_tools_thinking_messages is False

    def test_to_dict_default(self):
        component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id")
        data = component.to_dict()
        assert data == {
            "type": (
                "haystack_integrations.components.generators."
                "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator"
            ),
            "init_parameters": {
                "region": "us-central1",
                "project_id": "test-project-id",
                "model": "claude-3-5-sonnet@20240620",
                "streaming_callback": None,
                "generation_kwargs": {},
                "ignore_tools_thinking_messages": True,
            },
        }

    def test_to_dict_with_parameters(self):
        component = AnthropicVertexChatGenerator(
            region="us-central1",
            project_id="test-project-id",
            streaming_callback=print_streaming_chunk,
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
        )
        data = component.to_dict()
        assert data == {
            "type": (
                "haystack_integrations.components.generators."
                "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator"
            ),
            "init_parameters": {
                "region": "us-central1",
                "project_id": "test-project-id",
                "model": "claude-3-5-sonnet@20240620",
                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                "ignore_tools_thinking_messages": True,
            },
        }

    def test_to_dict_with_lambda_streaming_callback(self):
        component = AnthropicVertexChatGenerator(
            region="us-central1",
            project_id="test-project-id",
            model="claude-3-5-sonnet@20240620",
            streaming_callback=lambda x: x,
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
        )
        data = component.to_dict()
        assert data == {
            "type": (
                "haystack_integrations.components.generators."
                "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator"
            ),
            "init_parameters": {
                "region": "us-central1",
                "project_id": "test-project-id",
                "model": "claude-3-5-sonnet@20240620",
                "streaming_callback": "tests.test_vertex_chat_generator.<lambda>",
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                "ignore_tools_thinking_messages": True,
            },
        }

    def test_from_dict(self):
        data = {
            "type": (
                "haystack_integrations.components.generators."
                "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator"
            ),
            "init_parameters": {
                "region": "us-central1",
                "project_id": "test-project-id",
                "model": "claude-3-5-sonnet@20240620",
                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                "ignore_tools_thinking_messages": True,
            },
        }
        component = AnthropicVertexChatGenerator.from_dict(data)
        assert component.model == "claude-3-5-sonnet@20240620"
        assert component.region == "us-central1"
        assert component.project_id == "test-project-id"
        assert component.streaming_callback is print_streaming_chunk
        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

    def test_run(self, chat_messages, mock_chat_completion):
        component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id")
        response = component.run(chat_messages)

        # check that the component returns the correct ChatMessage response
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    def test_run_with_params(self, chat_messages, mock_chat_completion):
        component = AnthropicVertexChatGenerator(
            region="us-central1", project_id="test-project-id", generation_kwargs={"max_tokens": 10, "temperature": 0.5}
        )
        response = component.run(chat_messages)

        # check that the component calls the Anthropic API with the correct parameters
        _, kwargs = mock_chat_completion.call_args
        assert kwargs["max_tokens"] == 10
        assert kwargs["temperature"] == 0.5

        # check that the component returns the correct response
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    @pytest.mark.skipif(
        not (os.environ.get("REGION", None) and os.environ.get("PROJECT_ID", None)),
        reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.",
    )
    @pytest.mark.integration
    def test_live_run_wrong_model(self, chat_messages):
        component = AnthropicVertexChatGenerator(
            model="something-obviously-wrong", region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID")
        )
        with pytest.raises(anthropic.NotFoundError):
            component.run(chat_messages)

    @pytest.mark.skipif(
        not (os.environ.get("REGION", None) and os.environ.get("PROJECT_ID", None)),
        reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.",
    )
    @pytest.mark.integration
    def test_default_inference_params(self, chat_messages):
        client = AnthropicVertexChatGenerator(
            region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID"), model="claude-3-sonnet@20240229"
        )
        response = client.run(chat_messages)

        assert "replies" in response, "Response does not contain 'replies' key"
        replies = response["replies"]
        assert isinstance(replies, list), "Replies is not a list"
        assert len(replies) > 0, "No replies received"

        first_reply = replies[0]
        assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
        assert first_reply.content, "First reply has no content"
        assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
        assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
        assert first_reply.meta, "First reply has no metadata"

    # The Anthropic Messages API is the same for the AnthropicVertex and Anthropic endpoints,
    # so the remaining tests are skipped for AnthropicVertexChatGenerator; they are already
    # covered by the AnthropicChatGenerator tests.
```
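The `mock_chat_completion` fixture used by `test_run` and `test_run_with_params` is defined elsewhere in the test suite and is not part of this diff. The general pattern it relies on, patching the API client so `run` can be exercised without network access, can be sketched with the standard library alone; the `TinyGenerator` class below is a hypothetical stand-in, not the integration's real code or fixture:

```python
from unittest.mock import MagicMock


# A stand-in for a generator component that delegates to an API client.
class TinyGenerator:
    def __init__(self, client):
        self.client = client

    def run(self, messages):
        # Forward generation kwargs to the client and reshape its payload.
        result = self.client.messages.create(messages=messages, max_tokens=512)
        return {"replies": [result["content"]]}


# Replace the real client with a MagicMock that returns a canned response.
mock_client = MagicMock()
mock_client.messages.create.return_value = {"content": "Paris"}

generator = TinyGenerator(client=mock_client)
response = generator.run([{"role": "user", "content": "Capital of France?"}])

# The call reached the mock with the expected keyword arguments...
_, kwargs = mock_client.messages.create.call_args
assert kwargs["max_tokens"] == 512
# ...and the component shaped the canned payload into its output contract.
assert response == {"replies": ["Paris"]}
```

The same inspection of `call_args` is what `test_run_with_params` does to verify that `generation_kwargs` actually reach the Anthropic endpoint.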
Review discussion:

**Reviewer:** These should be `Secret`s, I think. If you do it like this, you can leak them when serializing the component.

**Author:** In this case, `project_id` and `region` do not grant access to GCP resources on their own, as proper authentication and permissions are required. So I thought these values are not sensitive data, but more of configuration parameters. Wdyt?

**Reviewer:** Good point; we treated them as secrets on some other components, though. Good for me to change this. In any case, I would make them mandatory arguments. If the user wants to pass them as env vars, that's up to them.
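The `Secret` type the reviewer refers to is Haystack's wrapper for deferring sensitive values to environment variables so that `to_dict()` never embeds the value itself. The underlying pattern can be sketched with the standard library alone; `EnvSecret` below illustrates the idea and is not Haystack's actual `Secret` implementation:

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class EnvSecret:
    """Defers a value to an environment variable; serializes as a marker, never the value."""

    env_var: str

    def resolve_value(self) -> str:
        # Read the value lazily, at the point of use rather than at construction.
        value = os.environ.get(self.env_var)
        if value is None:
            raise ValueError(f"Environment variable {self.env_var} is not set")
        return value

    def to_dict(self) -> dict:
        # Only the variable NAME is serialized; the value never enters saved configs.
        return {"type": "env_var", "env_var": self.env_var}


# Example: a component could store EnvSecret("PROJECT_ID") and resolve it on demand.
os.environ["PROJECT_ID"] = "test-project-id"
secret = EnvSecret("PROJECT_ID")
assert secret.resolve_value() == "test-project-id"
assert "test-project-id" not in str(secret.to_dict())
```

Compared with the merged `region or os.environ.get("REGION")` fallback, this approach keeps the serialized component free of deployment-specific values while still letting users supply them via the environment.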