
Commit

feat: support Claude v3, Llama3 and Command R models on Amazon Bedrock (#809)

* feat: support Claude v3 and Cohere Command R models on Amazon Bedrock

* revert chat pattern change

* rename llama adapter

* fix tests after llama adapter rename
tstadel authored Jun 14, 2024
1 parent 590e2b0 commit 5e66f1d
Showing 3 changed files with 413 additions and 32 deletions.
@@ -98,6 +98,10 @@ class AnthropicClaudeAdapter(BedrockModelAdapter):
     Adapter for the Anthropic Claude models.
     """
 
+    def __init__(self, model_kwargs: Dict[str, Any], max_length: Optional[int]) -> None:
+        self.use_messages_api = model_kwargs.get("use_messages_api", True)
+        super().__init__(model_kwargs, max_length)
+
     def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         """
         Prepares the body for the Claude model
@@ -108,16 +112,30 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
             - `prompt`: The prompt to be sent to the model.
             - specified inference parameters.
         """
-        default_params = {
-            "max_tokens_to_sample": self.max_length,
-            "stop_sequences": ["\n\nHuman:"],
-            "temperature": None,
-            "top_p": None,
-            "top_k": None,
-        }
-        params = self._get_params(inference_kwargs, default_params)
-
-        body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params}
+        if self.use_messages_api:
+            default_params: Dict[str, Any] = {
+                "anthropic_version": "bedrock-2023-05-31",
+                "max_tokens": self.max_length,
+                "system": None,
+                "stop_sequences": None,
+                "temperature": None,
+                "top_p": None,
+                "top_k": None,
+            }
+            params = self._get_params(inference_kwargs, default_params)
+
+            body = {"messages": [{"role": "user", "content": prompt}], **params}
+        else:
+            default_params = {
+                "max_tokens_to_sample": self.max_length,
+                "stop_sequences": ["\n\nHuman:"],
+                "temperature": None,
+                "top_p": None,
+                "top_k": None,
+            }
+            params = self._get_params(inference_kwargs, default_params)
+
+            body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params}
         return body
 
     def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
@@ -127,6 +145,9 @@ def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         :param response_body: The response body from the Amazon Bedrock request.
         :returns: A list of string responses.
         """
+        if self.use_messages_api:
+            return [content["text"] for content in response_body["content"]]
+
         return [response_body["completion"]]
 
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
@@ -136,6 +157,9 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
         :param chunk: The streaming chunk.
         :returns: A string token.
         """
+        if self.use_messages_api:
+            return chunk.get("delta", {}).get("text", "")
+
        return chunk.get("completion", "")
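
For illustration, here is a minimal sketch of the two request shapes the new toggle selects between. The prompt, kwargs, and printed bodies below are invented for the example, and it assumes _get_params merges the inference kwargs over the defaults and drops entries left as None:

adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=256)
adapter.prepare_body("Hello!", temperature=0.5)
# Messages API (the default, and the format Claude 3 requires):
# {"messages": [{"role": "user", "content": "Hello!"}],
#  "anthropic_version": "bedrock-2023-05-31", "max_tokens": 256, "temperature": 0.5}

legacy = AnthropicClaudeAdapter(model_kwargs={"use_messages_api": False}, max_length=256)
legacy.prepare_body("Hello!", temperature=0.5)
# Legacy text-completions API:
# {"prompt": "\n\nHuman: Hello!\n\nAssistant:", "max_tokens_to_sample": 256,
#  "stop_sequences": ["\n\nHuman:"], "temperature": 0.5}

The same flag drives response parsing: the messages API returns completions under response_body["content"] and streams tokens under chunk["delta"]["text"], while the legacy API uses response_body["completion"] and chunk["completion"].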


@@ -240,6 +264,66 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
         return chunk.get("text", "")
 
 
+class CohereCommandRAdapter(BedrockModelAdapter):
+    """
+    Adapter for the Cohere Command R models.
+    """
+
+    def prepare_body(self, prompt: str, **inference_kwargs: Any) -> Dict[str, Any]:
+        """
+        Prepares the body for the Command model
+
+        :param prompt: The prompt to be sent to the model.
+        :param inference_kwargs: Additional keyword arguments passed to the handler.
+        :returns: A dictionary with the following keys:
+            - `prompt`: The prompt to be sent to the model.
+            - specified inference parameters.
+        """
+        default_params = {
+            "chat_history": None,
+            "documents": None,
+            "search_query_only": None,
+            "preamble": None,
+            "max_tokens": self.max_length,
+            "temperature": None,
+            "p": None,
+            "k": None,
+            "prompt_truncation": None,
+            "frequency_penalty": None,
+            "presence_penalty": None,
+            "seed": None,
+            "return_prompt": None,
+            "tools": None,
+            "tool_results": None,
+            "stop_sequences": None,
+            "raw_prompting": None,
+        }
+        params = self._get_params(inference_kwargs, default_params)
+
+        body = {"message": prompt, **params}
+        return body
+
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
+        """
+        Extracts the responses from the Cohere Command model response.
+
+        :param response_body: The response body from the Amazon Bedrock request.
+        :returns: A list of string responses.
+        """
+        responses = [response_body["text"]]
+        return responses
+
+    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
+        """
+        Extracts the token from a streaming chunk.
+
+        :param chunk: The streaming chunk.
+        :returns: A string token.
+        """
+        token: str = chunk.get("text", "")
+        return token
+
+
 class AI21LabsJurassic2Adapter(BedrockModelAdapter):
     """
     Model adapter for AI21 Labs' Jurassic 2 models.
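
As a sketch of what the new Command R adapter sends, again assuming _get_params discards the None-valued defaults (the prompt and kwargs below are invented):

adapter = CohereCommandRAdapter(model_kwargs={}, max_length=512)
adapter.prepare_body("Summarize the attached report.", temperature=0.3, k=50)
# {"message": "Summarize the attached report.", "max_tokens": 512,
#  "temperature": 0.3, "k": 50}

Note that Command R takes the user input under a "message" key plus an optional "chat_history", unlike the single "prompt" field of the older Command models, which is why it gets its own adapter rather than reusing CohereCommandAdapter.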
@@ -324,7 +408,7 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
         return chunk.get("outputText", "")
 
 
-class MetaLlama2ChatAdapter(BedrockModelAdapter):
+class MetaLlamaAdapter(BedrockModelAdapter):
     """
     Adapter for Meta's Llama2 models.
     """
@@ -19,7 +19,8 @@
     AnthropicClaudeAdapter,
     BedrockModelAdapter,
     CohereCommandAdapter,
-    MetaLlama2ChatAdapter,
+    CohereCommandRAdapter,
+    MetaLlamaAdapter,
     MistralAdapter,
 )
 from .handlers import (
@@ -56,9 +57,10 @@ class AmazonBedrockGenerator:
     SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelAdapter]]] = {
         r"amazon.titan-text.*": AmazonTitanAdapter,
         r"ai21.j2.*": AI21LabsJurassic2Adapter,
-        r"cohere.command.*": CohereCommandAdapter,
+        r"cohere.command-[^r].*": CohereCommandAdapter,
+        r"cohere.command-r.*": CohereCommandRAdapter,
         r"anthropic.claude.*": AnthropicClaudeAdapter,
-        r"meta.llama2.*": MetaLlama2ChatAdapter,
+        r"meta.llama.*": MetaLlamaAdapter,
         r"mistral.*": MistralAdapter,
     }
 
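A quick, illustrative check of how the tightened patterns route model IDs. The IDs below are examples, and the first-match lookup over the pattern dict is an assumption about how the generator resolves adapters:

import re

patterns = {
    r"cohere.command-[^r].*": "CohereCommandAdapter",
    r"cohere.command-r.*": "CohereCommandRAdapter",
    r"meta.llama.*": "MetaLlamaAdapter",
}
for model_id in (
    "cohere.command-text-v14",        # plain Command
    "cohere.command-r-plus-v1:0",     # Command R
    "meta.llama3-70b-instruct-v1:0",  # Llama 3
):
    adapter = next(name for pat, name in patterns.items() if re.fullmatch(pat, model_id))
    print(model_id, "->", adapter)
# cohere.command-text-v14 -> CohereCommandAdapter
# cohere.command-r-plus-v1:0 -> CohereCommandRAdapter
# meta.llama3-70b-instruct-v1:0 -> MetaLlamaAdapter

The "[^r]" guard is needed because "cohere.command.*" would otherwise also match every Command R model ID, and the relaxed "meta.llama.*" pattern now covers both Llama 2 and Llama 3 IDs.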
[Diff for the third changed file is not shown.]
