Skip to content

Commit

Permalink
feat: support for bedrock mistral models and claude messages api (#7543)
Browse files Browse the repository at this point in the history
* feat: support for bedrock mistral models and claude messages api

* feat: support for bedrock mistral models and claude messages api

* fix mypy
  • Loading branch information
tstadel authored and vblagoje committed Apr 23, 2024
1 parent 30e683c commit ec4862a
Show file tree
Hide file tree
Showing 3 changed files with 352 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
- id: ruff

- repo: https://github.com/codespell-project/codespell
rev: v2.2.5
rev: b8ecc9b3acf31690c3cb2fc5bb03a3fbbbc2d7a3
hooks:
- id: codespell
additional_dependencies:
Expand Down
72 changes: 62 additions & 10 deletions haystack/nodes/prompt/invocation_layer/amazon_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,47 @@ class AnthropicClaudeAdapter(BedrockModelAdapter):
Model adapter for Anthropic's Claude models.
"""

def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
default_params = {
"max_tokens_to_sample": self.max_length,
"stop_sequences": ["\n\nHuman:"],
"temperature": None,
"top_p": None,
"top_k": None,
}
params = self._get_params(inference_kwargs, default_params)
def __init__(self, model_kwargs: Dict[str, Any], max_length: Optional[int]) -> None:
    """
    Set up the adapter and record which Claude API flavour to target.

    :param model_kwargs: Keyword arguments for the model; may contain
        `use_messages_api` to toggle between the Messages API (the default)
        and the legacy Text Completions API.
    :param max_length: Maximum number of tokens the model may generate.
    """
    # Default to the newer Messages API unless the caller explicitly opts out.
    messages_api_enabled = model_kwargs.get("use_messages_api", True)
    self.use_messages_api = messages_api_enabled
    super().__init__(model_kwargs, max_length)

body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params}
def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
    """
    Build the request body for a Bedrock Claude invocation.

    Depending on `self.use_messages_api`, targets either the Messages API
    (chat-style payload with a list of messages) or the legacy Text
    Completions API (raw prompt with Human/Assistant turn markers).

    :param prompt: The prompt to send to the model.
    :param inference_kwargs: Optional overrides for the default inference parameters.
    :return: The request body as a dictionary.
    """
    if not self.use_messages_api:
        # Legacy Text Completions API: the prompt must carry the
        # Human/Assistant conversation markers itself.
        completion_defaults = {
            "max_tokens_to_sample": self.max_length,
            "stop_sequences": ["\n\nHuman:"],
            "temperature": None,
            "top_p": None,
            "top_k": None,
        }
        merged = self._get_params(inference_kwargs, completion_defaults)
        return {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **merged}

    # Messages API: the prompt becomes a single user message.
    message_defaults: Dict[str, Any] = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": self.max_length,
        "system": None,
        "stop_sequences": None,
        "temperature": None,
        "top_p": None,
        "top_k": None,
    }
    merged = self._get_params(inference_kwargs, message_defaults)
    return {"messages": [{"role": "user", "content": prompt}], **merged}

def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
if self.use_messages_api:
return [content["text"] for content in response_body["content"]]

return [response_body["completion"]]

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
if self.use_messages_api:
return chunk.get("delta", {}).get("text", "")

return chunk.get("completion", "")


Expand Down Expand Up @@ -197,6 +221,33 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
return chunk.get("generation", "")


class MistralAIAdapter(BedrockModelAdapter):
    """
    Model adapter for Mistral AI models hosted on Amazon Bedrock.
    """

    def prepare_body(self, prompt: str, **inference_kwargs: Any) -> Dict[str, Any]:
        """
        Build the request body for a Bedrock Mistral invocation.

        :param prompt: The prompt to send to the model.
        :param inference_kwargs: Optional overrides for the default inference parameters.
        :return: The request body as a dictionary.
        """
        defaults = {
            "max_tokens": self.max_length,
            "stop": None,
            "temperature": None,
            "top_p": None,
            "top_k": None,
        }
        merged = self._get_params(inference_kwargs, defaults)
        return {"prompt": prompt, **merged}

    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        """
        Pull the generated text out of a (non-streaming) Bedrock Mistral response.

        :param response_body: The parsed JSON body returned by Bedrock.
        :return: A list of completion strings, one per entry in "outputs".
        """
        completions = []
        for output_item in response_body["outputs"]:
            completions.append(output_item["text"])
        return completions

    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        """
        Extract the text token from one streaming-response chunk.

        :param chunk: A single decoded chunk from the Bedrock response stream.
        :return: The token text, or an empty string when the chunk carries none.
        """
        stream_outputs: List[Dict[str, str]] = chunk.get("outputs", [])
        # Only the first output is consulted; absent outputs yield "".
        first_output = next(iter(stream_outputs), {})
        return first_output.get("text", "")


class AmazonBedrockInvocationLayer(AWSBaseInvocationLayer):
"""
Invocation layer for Amazon Bedrock models.
Expand All @@ -208,6 +259,7 @@ class AmazonBedrockInvocationLayer(AWSBaseInvocationLayer):
r"cohere.command.*": CohereCommandAdapter,
r"anthropic.claude.*": AnthropicClaudeAdapter,
r"meta.llama2.*": MetaLlama2ChatAdapter,
r"mistral.mi[sx]tral.*": MistralAIAdapter, # codespell:ignore tral
}

def __init__(
Expand Down
Loading

0 comments on commit ec4862a

Please sign in to comment.