Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AnthropicVertexChatGenerator component #1192

Merged
merged 8 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Create adapter class and add VertexAPI
  • Loading branch information
Amnah199 committed Nov 14, 2024
commit 7d54e95d94df73ee64a53acd5887517730ea9d49
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, ClassVar, Dict, Type

from haystack.core.errors import DeserializationError
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret

with LazyImport('Run pip install -U google-cloud-aiplatform "anthropic[vertex]".'):
from anthropic import AnthropicVertex


with LazyImport("Run pip install anthropic."):
from anthropic import Anthropic


@dataclass(frozen=True)
class BaseAdapter(ABC):
"""
Base class for model APIs supported by AnthropicGenerator.
"""

TYPE: ClassVar[str]

def to_dict(self) -> Dict[str, Any]:
"""
Converts the object to a dictionary representation for serialization.
"""
_fields = {}
for _field in fields(self):
if _field.type is Secret:
_fields[_field.name] = getattr(self, _field.name).to_dict()
else:
_fields[_field.name] = getattr(self, _field.name)

return {"type": self.TYPE, "init_parameters": _fields}

@staticmethod
def from_dict(data: Dict[str, Any]) -> "BaseAdapter":
"""
Converts a dictionary representation to an adapter object.
"""

@abstractmethod
def client(self):
"""
Resolves all the secrets and evironment variables and returns the corresponding adapter object.
All subclasses must implement this method.
"""

@abstractmethod
def set_model(self, model):
"""
Sets the model name in the format required by the API.
"""


@dataclass(frozen=True)
class AnthropicAdapter(BaseAdapter):
"""
Model adapter for the Anthropic API. It will load the api_key from the environment variable `ANTHROPIC_API_KEY`.
"""

api_key: Secret
TYPE = "anthropic"

def client(self) -> Anthropic:
return Anthropic(api_key=self.api_key.resolve_value())

def set_model(self, model) -> str:
return model # default model name format is correct for Anthropic API


@dataclass(frozen=True)
class AnthropicVertexAdapter(BaseAdapter):
"""
Model adapter for the Anthropic Vertex API. It authenticate using GCP authentication and select
`REGION` and `PROJECT_ID` from the environment variable.
"""

TYPE = "anthropic_vertex"
region: str
project_id: str

def client(self) -> AnthropicVertex:
return AnthropicVertex(region=self.region, project_id=self.project_id)

def set_model(self, model) -> str:
"""
Converts the model name to the format required by the Anthropic Vertex API.
AnthropicVertex requires model name in the format `claude-3-sonnet@20240229`
instead of `claude-3-sonnet-20240229`.
"""
return model[::-1].replace("-", "@", 1)[::-1]

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import StreamingChunk
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable

from anthropic import Anthropic, Stream
from anthropic import Stream
from anthropic.types import (
ContentBlockDeltaEvent,
Message,
Expand All @@ -14,6 +15,8 @@
MessageStreamEvent,
)

from .adapter import AnthropicAdapter, AnthropicVertexAdapter, BaseAdapter

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -52,7 +55,9 @@ class AnthropicGenerator:

def __init__(
self,
api_key: Secret = Secret.from_env_var("ANTHROPIC_API_KEY"), # noqa: B008
api_key: Optional[Secret] = Secret.from_env_var("ANTHROPIC_API_KEY"), # noqa: B008
region: Optional[str] = None,
project_id: Optional[str] = None,
model: str = "claude-3-sonnet-20240229",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
system_prompt: Optional[str] = None,
Expand All @@ -68,11 +73,25 @@ def __init__(
:param generation_kwargs: Additional keyword arguments for generation.
"""
self.api_key = api_key
self.region = region or os.environ.get("REGION") or None
self.project_id = project_id or os.environ.get("PROJECT_ID")
self.model = model
self.generation_kwargs = generation_kwargs or {}
self.streaming_callback = streaming_callback
self.system_prompt = system_prompt
self.client = Anthropic(api_key=self.api_key.resolve_value())

def get_model_adapter(self) -> "BaseAdapter":
"""
Factory method to select model adapter based on provided secrets.
"""
if self.region and self.project_id:
return AnthropicVertexAdapter(region=self.region, project_id=self.project_id)
elif self.api_key:
return AnthropicAdapter(api_key=self.api_key)
else:
msg = "Failed to select a model:'ANTHROPIC_API_KEY' env variable must be provided for Anthropic,"
"or both 'REGION' and 'PROJECT_ID' env variables must be provided for AnthropicVertex."
raise ValueError(msg)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -96,7 +115,9 @@ def to_dict(self) -> Dict[str, Any]:
streaming_callback=callback_name,
system_prompt=self.system_prompt,
generation_kwargs=self.generation_kwargs,
api_key=self.api_key.to_dict(),
api_key=self.api_key.to_dict() if self.api_key else None,
region=self.region,
project_id=self.project_id,
)

@classmethod
Expand Down Expand Up @@ -126,6 +147,12 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
- `replies`: A list of generated replies.
- `meta`: A list of metadata dictionaries for each reply.
"""

# Select the model adapter based on the provided env variables
model_adapter_cls = self.get_model_adapter()
client = model_adapter_cls.client()
model_name = model_adapter_cls.set_model(self.model) # set the model name in the format required by the API

# update generation kwargs by merging with the generation kwargs passed to the run method
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
filtered_generation_kwargs = {k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS}
Expand All @@ -136,10 +163,10 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
f"Allowed parameters are {self.ALLOWED_PARAMS}."
)

response: Union[Message, Stream[MessageStreamEvent]] = self.client.messages.create(
response: Union[Message, Stream[MessageStreamEvent]] = client.messages.create(
max_tokens=filtered_generation_kwargs.pop("max_tokens", 512),
system=self.system_prompt if self.system_prompt else filtered_generation_kwargs.pop("system", ""),
model=self.model,
model=model_name,
messages=[MessageParam(content=prompt, role="user")],
stream=self.streaming_callback is not None,
**filtered_generation_kwargs,
Expand Down
Loading