Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Bedrock class to support Claude-3 model and improve error handling #720

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 52 additions & 20 deletions dsp/modules/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
from __future__ import annotations

import json
from typing import Any, Optional
from dataclasses import dataclass
from typing import Any

from dsp.modules.aws_lm import AWSLM


@dataclass
class ChatMessage:
    """One chat turn for the Bedrock/Anthropic messages API.

    Instances are serialized with ``vars()`` into ``{"role": ..., "content": ...}``
    before being placed in the request body, so this must remain a plain
    (non-slots) dataclass with exactly these two fields.
    """

    # Conversation role, e.g. "system", "user", or "assistant".
    role: str
    # The message text for this turn.
    content: str

class Bedrock(AWSLM):
def __init__(
self,
region_name: str,
model: str,
profile_name: Optional[str] = None,
input_output_ratio: int = 3,
max_new_tokens: int = 1500,
self,
region_name: str,
model: str,
profile_name: str | None = None,
input_output_ratio: int = 3,
max_new_tokens: int = 1500,
) -> None:
"""Use an AWS Bedrock language model.
NOTE: You must first configure your AWS credentials with the AWS CLI before using this model!
Expand All @@ -37,26 +44,36 @@ def __init__(
)
self._validate_model(model)
self.provider = "claude" if "claude" in model.lower() else "bedrock"
self.use_messages = "claude-3" in model.lower()

def _validate_model(self, model: str) -> None:
if "claude" not in model.lower():
raise NotImplementedError("Only claude models are supported as of now")

def _create_body(self, prompt: str, system_prompt: str | None = None, **kwargs) -> dict[str, Any]:
    """Build the JSON-serializable request body for a Bedrock InvokeModel call.

    Args:
        prompt: The user prompt text.
        system_prompt: Optional system instruction; when omitted, a generic
            "You are a helpful AI assistant." default is used (messages path only).
        **kwargs: Extra generation parameters; they override the defaults
            below and are then filtered through ``self._sanitize_kwargs``.

    Returns:
        A dict in Anthropic "messages" format when ``self.use_messages`` is
        set (Claude-3 family), otherwise in the legacy prompt/completion
        format with the Human/Assistant framing.
    """
    base_args: dict[str, Any] = {
        "max_tokens_to_sample": self._max_new_tokens,
        "anthropic_version": "bedrock-2023-05-31",
    }
    # Caller-supplied kwargs win over the defaults above.
    base_args.update(kwargs)
    query_args: dict[str, Any] = self._sanitize_kwargs(base_args)

    if self.use_messages:
        # Claude-3 messages API: the system instruction travels in a
        # top-level "system" field, not as a chat message.
        system_message = system_prompt if system_prompt else "You are a helpful AI assistant."
        query_args["messages"] = [vars(ChatMessage(role="user", content=prompt))]
        query_args["system"] = system_message
        # NOTE(review): "max_tokens_to_sample" from base_args may still be
        # present alongside "max_tokens" here unless _sanitize_kwargs drops
        # it — confirm Bedrock accepts the combination for Claude-3.
        query_args["max_tokens"] = self._max_new_tokens
    else:
        # Legacy completion API expects the "\n\nHuman: ...\n\nAssistant:" framing.
        query_args["prompt"] = self._format_prompt(prompt)
        query_args["max_tokens_to_sample"] = self._max_new_tokens

    return query_args

def _call_model(self, body: str) -> str:
Expand All @@ -67,13 +84,28 @@ def _call_model(self, body: str) -> str:
contentType="application/json",
)
response_body = json.loads(response["body"].read())
completion = response_body["completion"]

if self.use_messages: # Claude-3 model
try:
completion = response_body['content'][0]['text']
except (KeyError, IndexError):
raise ValueError("Unexpected response format from the Claude-3 model.")
else: # Other models
expected_keys = ["completion", "text"]
found_key = next((key for key in expected_keys if key in response_body), None)

if found_key:
completion = response_body[found_key]
else:
raise ValueError(
f"Unexpected response format from the model. Expected one of {', '.join(expected_keys)} keys.")

return completion

def _extract_input_parameters(
self, body: dict[Any, Any],
self, body: dict[Any, Any],
) -> dict[str, str | float | int]:
return body

def _format_prompt(self, raw_prompt: str) -> str:
return "\n\nHuman: " + raw_prompt + "\n\nAssistant:"
return "\n\nHuman: " + raw_prompt + "\n\nAssistant:"
Loading