Commit 8b0581f: black
julian-risch committed Dec 29, 2023
1 parent: 6356845

Showing 5 changed files with 113 additions and 227 deletions.
@@ -11,7 +11,8 @@ class AmazonBedrockError(Exception):
     """
 
     def __init__(
-        self, message: Optional[str] = None,
+        self,
+        message: Optional[str] = None,
     ):
         super().__init__()
         if message:
@@ -70,15 +70,15 @@ class AmazonBedrockGenerator:
     }
 
     def __init__(
-            self,
-            model_name_or_path: str,
-            aws_access_key_id: Optional[str] = None,
-            aws_secret_access_key: Optional[str] = None,
-            aws_session_token: Optional[str] = None,
-            aws_region_name: Optional[str] = None,
-            aws_profile_name: Optional[str] = None,
-            max_length: Optional[int] = 100,
-            **kwargs,
+        self,
+        model_name_or_path: str,
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_session_token: Optional[str] = None,
+        aws_region_name: Optional[str] = None,
+        aws_profile_name: Optional[str] = None,
+        max_length: Optional[int] = 100,
+        **kwargs,
     ):
         if model_name_or_path is None or len(model_name_or_path) == 0:
             msg = "model_name_or_path cannot be None or empty string"
@@ -96,8 +96,10 @@ def __init__(
             )
             self.client = session.client("bedrock-runtime")
         except Exception as exception:
-            msg = ("Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
-                   "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration")
+            msg = (
+                "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
+                "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
+            )
             raise AmazonBedrockConfigurationError(msg) from exception
 
         model_input_kwargs = kwargs
@@ -115,23 +117,19 @@ def __init__(
                 max_length=self.max_length or 100,
             )
 
-        model_apapter_cls = self.get_model_adapter(
-            model_name_or_path=model_name_or_path
-        )
+        model_apapter_cls = self.get_model_adapter(model_name_or_path=model_name_or_path)
         if not model_apapter_cls:
             msg = f"AmazonBedrockGenerator doesn't support the model {model_name_or_path}."
             raise AmazonBedrockConfigurationError(msg)
-        self.model_adapter = model_apapter_cls(
-            model_kwargs=model_input_kwargs, max_length=self.max_length
-        )
+        self.model_adapter = model_apapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length)
 
-    def _ensure_token_limit(
-        self, prompt: Union[str, List[Dict[str, str]]]
-    ) -> Union[str, List[Dict[str, str]]]:
+    def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
         # the prompt for this model will be of the type str
         if isinstance(prompt, List):
-            msg = ("AmazonBedrockGenerator only supports a string as a prompt, "
-                   "while currently, the prompt is of type List.")
+            msg = (
+                "AmazonBedrockGenerator only supports a string as a prompt, "
+                "while currently, the prompt is of type List."
+            )
             raise ValueError(msg)
 
         resize_info = self.prompt_handler(prompt)
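
For context, the constructor above wires up a boto3 "bedrock-runtime" client and resolves a model adapter from the model id. A minimal usage sketch, not part of the diff (the import path is hypothetical; valid AWS credentials and Bedrock model access are assumed):

    from amazon_bedrock import AmazonBedrockGenerator  # hypothetical import path

    generator = AmazonBedrockGenerator(
        model_name_or_path="anthropic.claude-v2",  # must match a supported model pattern
        aws_profile_name="default",  # resolved via boto3's standard credential chain
        max_length=100,
    )
    responses = generator.invoke(prompt="What is Amazon Bedrock?")
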
@@ -156,32 +154,29 @@ def supports(cls, model_name_or_path, **kwargs):
         try:
             session = cls.get_aws_session(**kwargs)
             bedrock = session.client("bedrock")
-            foundation_models_response = bedrock.list_foundation_models(
-                byOutputModality="TEXT"
-            )
-            available_model_ids = [
-                entry["modelId"]
-                for entry in foundation_models_response.get("modelSummaries", [])
-            ]
+            foundation_models_response = bedrock.list_foundation_models(byOutputModality="TEXT")
+            available_model_ids = [entry["modelId"] for entry in foundation_models_response.get("modelSummaries", [])]
             model_ids_supporting_streaming = [
                 entry["modelId"]
                 for entry in foundation_models_response.get("modelSummaries", [])
                 if entry.get("responseStreamingSupported", False)
             ]
         except AWSConfigurationError as exception:
-            raise AmazonBedrockConfigurationError(
-                message=exception.message
-            ) from exception
+            raise AmazonBedrockConfigurationError(message=exception.message) from exception
         except Exception as exception:
-            msg = ("Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
-                   "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration")
+            msg = (
+                "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
+                "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
+            )
             raise AmazonBedrockConfigurationError(msg) from exception
 
         model_available = model_name_or_path in available_model_ids
         if not model_available:
-            msg = (f"The model {model_name_or_path} is not available in Amazon Bedrock. "
-                   f"Make sure the model you want to use is available in the configured AWS region and "
-                   f"you have access.")
+            msg = (
+                f"The model {model_name_or_path} is not available in Amazon Bedrock. "
+                f"Make sure the model you want to use is available in the configured AWS region and "
+                f"you have access."
+            )
             raise AmazonBedrockConfigurationError(msg)
 
         stream: bool = kwargs.get("stream", False)
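
The availability check in supports() reduces to a single Bedrock control-plane call. A standalone sketch using only boto3, assuming configured credentials and a region where Bedrock is enabled:

    import boto3

    bedrock = boto3.session.Session().client("bedrock")
    summaries = bedrock.list_foundation_models(byOutputModality="TEXT").get("modelSummaries", [])
    available_model_ids = [entry["modelId"] for entry in summaries]
    streaming_model_ids = [entry["modelId"] for entry in summaries if entry.get("responseStreamingSupported", False)]
    print("anthropic.claude-v2" in available_model_ids)
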
@@ -195,13 +190,13 @@ def supports(cls, model_name_or_path, **kwargs):
     def invoke(self, *args, **kwargs):
         kwargs = kwargs.copy()
         prompt: str = kwargs.pop("prompt", None)
-        stream: bool = kwargs.get(
-            "stream", self.model_adapter.model_kwargs.get("stream", False)
-        )
+        stream: bool = kwargs.get("stream", self.model_adapter.model_kwargs.get("stream", False))
 
         if not prompt or not isinstance(prompt, (str, list)):
-            msg = (f"The model {self.model_name_or_path} requires a valid prompt, but currently, it has no prompt. "
-                   f"Make sure to provide a prompt in the format that the model expects.")
+            msg = (
+                f"The model {self.model_name_or_path} requires a valid prompt, but currently, it has no prompt. "
+                f"Make sure to provide a prompt in the format that the model expects."
+            )
             raise ValueError(msg)
 
         body = self.model_adapter.prepare_body(prompt=prompt, **kwargs)
@@ -216,13 +211,9 @@ def invoke(self, *args, **kwargs):
                 response_stream = response["body"]
                 handler: TokenStreamingHandler = kwargs.get(
                     "stream_handler",
-                    self.model_adapter.model_kwargs.get(
-                        "stream_handler", DefaultTokenStreamingHandler()
-                    ),
+                    self.model_adapter.model_kwargs.get("stream_handler", DefaultTokenStreamingHandler()),
                 )
-                responses = self.model_adapter.get_stream_responses(
-                    stream=response_stream, stream_handler=handler
-                )
+                responses = self.model_adapter.get_stream_responses(stream=response_stream, stream_handler=handler)
             else:
                 response = self.client.invoke_model(
                     body=json.dumps(body),
@@ -231,13 +222,13 @@ def invoke(self, *args, **kwargs):
                     contentType="application/json",
                 )
                 response_body = json.loads(response.get("body").read().decode("utf-8"))
-                responses = self.model_adapter.get_responses(
-                    response_body=response_body
-                )
+                responses = self.model_adapter.get_responses(response_body=response_body)
         except ClientError as exception:
-            msg = (f"Could not connect to Amazon Bedrock model {self.model_name_or_path}. "
-                   f"Make sure your AWS environment is configured correctly, "
-                   f"the model is available in the configured AWS region, and you have access.")
+            msg = (
+                f"Could not connect to Amazon Bedrock model {self.model_name_or_path}. "
+                f"Make sure your AWS environment is configured correctly, "
+                f"the model is available in the configured AWS region, and you have access."
+            )
             raise AmazonBedrockInferenceError(msg) from exception
 
         return responses
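
The non-streaming branch of invoke() maps directly onto the boto3 runtime API. A minimal sketch, not part of the diff; the body fields follow the Anthropic Claude format used elsewhere in this commit, and access to that model is assumed:

    import json

    import boto3

    client = boto3.session.Session().client("bedrock-runtime")
    body = {"prompt": "\n\nHuman: Hello\n\nAssistant:", "max_tokens_to_sample": 100}
    response = client.invoke_model(
        body=json.dumps(body),
        modelId="anthropic.claude-v2",
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(response.get("body").read().decode("utf-8"))
    print(response_body["completion"])
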
@@ -247,9 +238,7 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
         pass
 
     @classmethod
-    def get_model_adapter(
-        cls, model_name_or_path: str
-    ) -> Optional[Type[BedrockModelAdapter]]:
+    def get_model_adapter(cls, model_name_or_path: str) -> Optional[Type[BedrockModelAdapter]]:
         for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items():
             if re.fullmatch(pattern, model_name_or_path):
                 return adapter
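
get_model_adapter() picks an adapter by full-matching the model id against a pattern table. A self-contained sketch of the same mechanism; the pattern and adapter class here are illustrative stand-ins, not the real SUPPORTED_MODEL_PATTERNS:

    import re
    from typing import Dict, Optional, Type


    class AnthropicClaudeAdapter:  # stand-in for the real adapter class
        pass


    SUPPORTED_MODEL_PATTERNS: Dict[str, Type] = {r"anthropic.claude.*": AnthropicClaudeAdapter}


    def get_model_adapter(model_name_or_path: str) -> Optional[Type]:
        for pattern, adapter in SUPPORTED_MODEL_PATTERNS.items():
            if re.fullmatch(pattern, model_name_or_path):
                return adapter
        return None


    print(get_model_adapter("anthropic.claude-v2"))  # -> AnthropicClaudeAdapter
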
@@ -267,13 +256,13 @@ def aws_configured(cls, **kwargs) -> bool:
 
     @classmethod
     def get_aws_session(
-            cls,
-            aws_access_key_id: Optional[str] = None,
-            aws_secret_access_key: Optional[str] = None,
-            aws_session_token: Optional[str] = None,
-            aws_region_name: Optional[str] = None,
-            aws_profile_name: Optional[str] = None,
-            **kwargs,
+        cls,
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_session_token: Optional[str] = None,
+        aws_region_name: Optional[str] = None,
+        aws_profile_name: Optional[str] = None,
+        **kwargs,
     ):
         """
         Creates an AWS Session with the given parameters.
@@ -298,8 +287,6 @@ def get_aws_session(
                 profile_name=aws_profile_name,
             )
         except BotoCoreError as e:
-            provided_aws_config = {
-                k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS
-            }
+            provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS}
             msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}"
             raise AWSConfigurationError(msg) from e
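
get_aws_session() is a thin wrapper over boto3.session.Session: every argument is optional, so boto3 falls back to its standard credential chain (environment variables, shared config, instance profile). An equivalent direct call, with the region name as a placeholder example:

    import boto3

    session = boto3.session.Session(
        aws_access_key_id=None,  # None -> resolved from the credential chain
        aws_secret_access_key=None,
        aws_session_token=None,
        region_name="us-east-1",  # example region
        profile_name=None,
    )
    print(session.region_name)
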
@@ -24,9 +24,7 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[str]:
         responses = [completion.lstrip() for completion in completions]
         return responses
 
-    def get_stream_responses(
-        self, stream, stream_handler: TokenStreamingHandler
-    ) -> List[str]:
+    def get_stream_responses(self, stream, stream_handler: TokenStreamingHandler) -> List[str]:
         tokens: List[str] = []
         for event in stream:
             chunk = event.get("chunk")
@@ -37,9 +35,7 @@ def get_stream_responses(
         responses = ["".join(tokens).lstrip()]
         return responses
 
-    def _get_params(
-        self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]:
         """
         Merges the default params with the inference kwargs and model kwargs.
@@ -54,9 +50,7 @@ def _get_params(
         }
 
     @abstractmethod
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         """Extracts the responses from the Amazon Bedrock response."""
 
     @abstractmethod
@@ -82,9 +76,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params}
         return body
 
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         return [response_body["completion"]]
 
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
@@ -114,9 +106,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         body = {"prompt": prompt, **params}
         return body
 
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         responses = [generation["text"] for generation in response_body["generations"]]
         return responses

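Each adapter unpacks a different provider-specific response shape. The field names below come from the extraction code in this diff; the sample payloads are hypothetical:

    # Anthropic Claude returns a single completion string.
    anthropic_body = {"completion": "Hello!"}
    print([anthropic_body["completion"]])

    # Cohere Command returns a list of generations.
    cohere_body = {"generations": [{"text": "Hello!"}, {"text": "Hi there."}]}
    print([generation["text"] for generation in cohere_body["generations"]])
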
@@ -145,12 +135,8 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         body = {"prompt": prompt, **params}
         return body
 
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
-        responses = [
-            completion["data"]["text"] for completion in response_body["completions"]
-        ]
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
+        responses = [completion["data"]["text"] for completion in response_body["completions"]]
         return responses
 
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
@@ -175,9 +161,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         body = {"inputText": prompt, "textGenerationConfig": params}
         return body
 
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         responses = [result["outputText"] for result in response_body["results"]]
         return responses

@@ -201,9 +185,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
         body = {"prompt": prompt, **params}
         return body
 
-    def _extract_completions_from_response(
-        self, response_body: Dict[str, Any]
-    ) -> List[str]:
+    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
         return [response_body["generation"]]
 
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
@@ -10,9 +10,7 @@ class DefaultPromptHandler:
     are within the model_max_length.
     """
 
-    def __init__(
-        self, model_name_or_path: str, model_max_length: int, max_length: int = 100
-    ):
+    def __init__(self, model_name_or_path: str, model_max_length: int, max_length: int = 100):
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.tokenizer.model_max_length = model_max_length
         self.model_max_length = model_max_length
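
DefaultPromptHandler truncates the prompt so that the prompt tokens plus the reserved answer budget (max_length) fit within model_max_length. A standalone sketch of the same idea with a Hugging Face tokenizer; the tokenizer name is only an example:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer
    model_max_length = 1024
    max_length = 100  # tokens reserved for the generated answer

    tokenized_prompt = tokenizer.tokenize("some very long prompt ...")
    if len(tokenized_prompt) + max_length > model_max_length:
        # keep only as many prompt tokens as the answer budget allows
        tokenized_prompt = tokenized_prompt[: model_max_length - max_length]
    resized_prompt = tokenizer.convert_tokens_to_string(tokenized_prompt)
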
@@ -40,9 +38,7 @@ def __call__(self, prompt: str, **kwargs) -> Dict[str, Union[str, int]]:
             resized_prompt = self.tokenizer.convert_tokens_to_string(
                 tokenized_prompt[: self.model_max_length - self.max_length]
             )
-            new_prompt_length = len(
-                tokenized_prompt[: self.model_max_length - self.max_length]
-            )
+            new_prompt_length = len(tokenized_prompt[: self.model_max_length - self.max_length])
 
         return {
             "resized_prompt": resized_prompt,