Skip to content

Commit

Permalink
PR feedback David
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Feb 7, 2024
1 parent 199bdc6 commit 11b9b2f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List

from botocore.eventstream import EventStream
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
from transformers import AutoTokenizer, PreTrainedTokenizer

Expand All @@ -25,9 +26,9 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[

def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
    """Extracts the responses from the Amazon Bedrock response.

    :param response_body: The parsed JSON response body returned by Bedrock.
    :return: A list of ChatMessage objects extracted from the response.
    """
    # The key holding the generated text differs per model family (e.g.
    # "completion" vs "generation"), so ask the subclass for it and let the
    # shared extractor do the rest. (The stale single-argument call that the
    # fused diff left as an unreachable duplicate return has been removed —
    # it no longer matches _extract_messages_from_response's signature.)
    return self._extract_messages_from_response(self.response_body_message_key(), response_body)

def get_stream_responses(self, stream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]:
def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]:
tokens: List[str] = []
for event in stream:
chunk = event.get("chunk")
Expand All @@ -43,7 +44,8 @@ def get_stream_responses(self, stream, stream_handler: Callable[[StreamingChunk]
responses = ["".join(tokens).lstrip()]
return responses

def _update_params(self, target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> None:
@staticmethod
def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> None:
"""
Updates target_dict with values from updates_dict. Merges lists instead of overriding them.
Expand All @@ -62,6 +64,10 @@ def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str
"""
Merges params from inference_kwargs with the default params and self.generation_kwargs.
Uses a helper function to merge lists or override values as necessary.
:param inference_kwargs: The inference kwargs to merge.
:param default_params: The default params to start with.
:return: The merged params.
"""
# Start with a copy of default_params
kwargs = default_params.copy()
Expand Down Expand Up @@ -95,9 +101,13 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
:return: A dictionary containing the resized prompt and additional information.
"""

def _extract_messages_from_response(self, message_tag: str, response_body: Dict[str, Any]) -> List[ChatMessage]:
    """Turn an Amazon Bedrock response body into assistant ChatMessages.

    :param message_tag: Key in the response body that holds the generated text.
    :param response_body: The parsed JSON response body returned by Bedrock.
    :return: A single-element list with the assistant ChatMessage; every other
        key of the response body is carried along as message metadata.
    """
    meta = dict(response_body)
    # pop() both retrieves the generated text and strips it from the copy,
    # leaving the remaining entries to serve as metadata. A missing tag
    # raises KeyError, exactly like the direct indexing it replaces.
    text = meta.pop(message_tag)
    return [ChatMessage.from_assistant(text, meta=meta)]

@abstractmethod
# NOTE(review): the two defs below look like the removed and added halves of a
# diff fused together — `_extract_messages_from_response` appears to be the old
# abstract hook and `response_body_message_key` its replacement. As literal
# Python the decorator applies only to the first def; confirm against the
# upstream file before relying on either definition.
def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
    """Extracts the responses from the Amazon Bedrock response."""
def response_body_message_key(self) -> str:
    """Returns the key for the message in the response body."""

@abstractmethod
def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
Expand Down Expand Up @@ -171,9 +181,8 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
def check_prompt(self, prompt: str) -> Dict[str, Any]:
    """Delegate prompt checking to the configured prompt handler.

    :param prompt: The prompt to check (and resize if necessary).
    :return: The handler's result dictionary for the given prompt.
    """
    handler = self.prompt_handler
    return handler(prompt)

def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
metadata = {k: v for (k, v) in response_body.items() if k != "completion"}
return [ChatMessage.from_assistant(response_body["completion"], meta=metadata)]
def response_body_message_key(self) -> str:
    # "completion" is the response-body key this model family uses for the
    # generated text (consumed by the shared message extractor).
    """Return the response-body key that holds the generated text."""
    return "completion"

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
return chunk.get("completion", "")
Expand Down Expand Up @@ -250,9 +259,8 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
def check_prompt(self, prompt: str) -> Dict[str, Any]:
    """Run the prompt through the configured prompt handler.

    :param prompt: The prompt to check (and resize if necessary).
    :return: The handler's result dictionary for the given prompt.
    """
    result = self.prompt_handler(prompt)
    return result

def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
metadata = {k: v for (k, v) in response_body.items() if k != "generation"}
return [ChatMessage.from_assistant(response_body["generation"], meta=metadata)]
def response_body_message_key(self) -> str:
    # "generation" is the response-body key this model family uses for the
    # generated text (consumed by the shared message extractor).
    """Return the response-body key that holds the generated text."""
    return "generation"

def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
return chunk.get("generation", "")
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def __init__(
Note that the AWS credentials are not required if the AWS environment is configured correctly. These are loaded
automatically from the environment or the AWS configuration file and do not need to be provided explicitly via
the constructor. If the AWS environment is not configured, users need to provide the AWS credentials via the
constructor. Aside from model, three required parameters are `aws_access_key_id`, `aws_secret_access_key`,
and `aws_region_name`.
:param model: The model to use for generation. The model must be available in Amazon Bedrock. The model has to
be specified in the format outlined in the Amazon Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html).
Expand Down

0 comments on commit 11b9b2f

Please sign in to comment.