Hook prompt length check
vblagoje committed Feb 7, 2024
1 parent 13279db commit 2a62b8c
Showing 3 changed files with 47 additions and 12 deletions.
@@ -72,6 +72,29 @@ def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str

         return kwargs
 
+    def _ensure_token_limit(self, prompt: str) -> str:
+        resize_info = self.check_prompt(prompt)
+        if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
+            logger.warning(
+                "The prompt was truncated from %s tokens to %s tokens so that the prompt length and "
+                "the answer length (%s tokens) fit within the model's max token limit (%s tokens). "
+                "Shorten the prompt or it will be cut off.",
+                resize_info["prompt_length"],
+                max(0, resize_info["model_max_length"] - resize_info["max_length"]),  # type: ignore
+                resize_info["max_length"],
+                resize_info["model_max_length"],
+            )
+        return str(resize_info["resized_prompt"])
+
+    @abstractmethod
+    def check_prompt(self, prompt: str) -> Dict[str, Any]:
+        """
+        Checks the prompt length and resizes it if necessary.
+
+        :param prompt: The prompt to check.
+        :return: A dictionary containing the resized prompt and additional information.
+        """
+
     @abstractmethod
     def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         """Extracts the responses from the Amazon Bedrock response."""
@@ -89,14 +112,14 @@ class AnthropicClaudeChatAdapter(BedrockModelChatAdapter):
     ANTHROPIC_USER_TOKEN = "\n\nHuman:"
     ANTHROPIC_ASSISTANT_TOKEN = "\n\nAssistant:"
 
-    def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
+    def __init__(self, generation_kwargs: Dict[str, Any]):
         super().__init__(generation_kwargs)
 
         # We pop the model_max_length as it is not sent to the model
         # but used to truncate the prompt if needed
         # Anthropic Claude has a limit of at least 100000 tokens
         # https://docs.anthropic.com/claude/reference/input-and-output-sizes
-        model_max_length = self.generation_kwargs.get("model_max_length", 100000)
+        model_max_length = self.generation_kwargs.pop("model_max_length", 100000)
 
         # Truncate prompt if prompt tokens > model_max_length-max_length
         # (max_length is the length of the generated text)
@@ -142,7 +165,11 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
                 invalid_role = f"Invalid role {message.role} for message {message.content}"
                 raise ValueError(invalid_role)
 
-        return "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " "
+        prepared_prompt = "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " "
+        return self._ensure_token_limit(prepared_prompt)
+
+    def check_prompt(self, prompt: str) -> Dict[str, Any]:
+        return self.prompt_handler(prompt)
 
     def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         metadata = {k: v for (k, v) in response_body.items() if k != "completion"}
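Taken together, the Anthropic adapter now truncates at prompt-preparation time. A hedged usage sketch, assuming the class above is importable alongside Haystack's `ChatMessage`:

```python
# Sketch only: how the pieces are assumed to fit together after this change.
from haystack.dataclasses import ChatMessage

adapter = AnthropicClaudeChatAdapter(generation_kwargs={"model_max_length": 100000})
prompt = adapter.prepare_chat_messages(messages=[ChatMessage.from_user("What is Amazon Bedrock?")])

# prompt == "\n\nHuman: What is Amazon Bedrock?\n\nAssistant: ", already passed
# through _ensure_token_limit, so it fits within model_max_length minus the
# tokens reserved for the answer.
```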
@@ -187,9 +214,13 @@ def __init__(self, generation_kwargs: Dict[str, Any]) -> None:
         # We pop the model_max_length as it is not sent to the model
         # but used to truncate the prompt if needed
         # Llama 2 has a context window size of 4096 tokens
-        model_max_length = self.generation_kwargs.get("model_max_length", 4096)
-        # Truncate prompt if prompt tokens > model_max_length-max_length
-        tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("gpt2")
+        # with some exceptions when the context window has been extended
+        model_max_length = self.generation_kwargs.pop("model_max_length", 4096)
+
+        # Use `google/flan-t5-base` as it's also a BPE sentencepiece tokenizer, just like Llama 2:
+        # a) we should get good estimates for the prompt length (empirically close to Llama 2)
+        # b) we can use apply_chat_template with the template above to delineate ChatMessages
+        tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
         tokenizer.bos_token = "<s>"
         tokenizer.eos_token = "</s>"
         tokenizer.unk_token = "<unk>"
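The tokenizer swap matters because Bedrock tokenizes server-side; the local tokenizer is only a length estimator. A small sketch of that estimation, using the standard transformers API:

```python
from transformers import AutoTokenizer

# flan-t5's sentencepiece vocabulary is used only to *estimate* Llama 2 token
# counts locally; Bedrock performs the real tokenization server-side.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
tokenizer.bos_token = "<s>"
tokenizer.eos_token = "</s>"
tokenizer.unk_token = "<unk>"

n_tokens = len(tokenizer.encode("What is Amazon Bedrock?", add_special_tokens=False))
# Per the commit's comment, this count is empirically close to Llama 2's;
# the previous choice, gpt2's byte-level BPE, diverged more.
```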
@@ -214,7 +245,10 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
         prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
             conversation=messages, tokenize=False, chat_template=self.chat_template
         )
-        return prepared_prompt
+        return self._ensure_token_limit(prepared_prompt)
+
+    def check_prompt(self, prompt: str) -> Dict[str, Any]:
+        return self.prompt_handler(prompt)
 
     def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         metadata = {k: v for (k, v) in response_body.items() if k != "generation"}
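Rationale b) above refers to transformers' `apply_chat_template`. A minimal sketch with an illustrative Llama-2-style template; the adapter's actual `self.chat_template` is defined elsewhere in this file and may differ:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

# Illustrative template only, not the adapter's real chat_template.
chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]"
    "{% else %}{{ message['content'] }}{% endif %}"
    "{% endfor %}"
)

prompt = tokenizer.apply_chat_template(
    conversation=[{"role": "user", "content": "What is Amazon Bedrock?"}],
    tokenize=False,
    chat_template=chat_template,
)
# prompt == "[INST] What is Amazon Bedrock? [/INST]"
```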
@@ -153,11 +153,7 @@ def invoke(self, *args, **kwargs):
             response_body = json.loads(response.get("body").read().decode("utf-8"))
             responses = self.model_adapter.get_responses(response_body=response_body)
         except ClientError as exception:
-            msg = (
-                f"Could not connect to Amazon Bedrock model {self.model}. "
-                f"Make sure your AWS environment is configured correctly, "
-                f"the model is available in the configured AWS region, and you have access."
-            )
+            msg = f"Could not run inference on Amazon Bedrock model {self.model} due to: {exception}"
             raise AmazonBedrockInferenceError(msg) from exception
 
         return responses
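The reworked message now surfaces the underlying boto3 error instead of a generic connection hint. A sketch of the failure path from the caller's side; `generator` and `messages` are assumed stand-ins for the Bedrock chat generator defined in this file and its input:

```python
# Sketch only: caller-side view of the new error path.
try:
    replies = generator.invoke(messages=messages)
except AmazonBedrockInferenceError as err:
    print(err)            # "Could not run inference on Amazon Bedrock model ... due to: ..."
    print(err.__cause__)  # the original botocore ClientError, kept via `raise ... from`
```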
@@ -11,6 +11,11 @@ class DefaultPromptHandler:
"""

def __init__(self, tokenizer: Union[str, PreTrainedTokenizerBase], model_max_length: int, max_length: int = 100):
"""
:param tokenizer: The tokenizer to be used to tokenize the prompt.
:param model_max_length: The maximum length of the prompt and answer tokens combined.
:param max_length: The maximum length of the answer tokens.
"""
if isinstance(tokenizer, str):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
elif isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
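The new docstring pins down the handler's contract, including the `Union[str, PreTrainedTokenizerBase]` parameter. A hedged sketch of both documented construction paths, using the class defined above:

```python
from transformers import AutoTokenizer

# Sketch: a string is resolved via AutoTokenizer.from_pretrained,
# while a tokenizer instance is used as-is.
handler_from_name = DefaultPromptHandler(tokenizer="gpt2", model_max_length=4096, max_length=100)

gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
handler_from_instance = DefaultPromptHandler(
    tokenizer=gpt2_tokenizer, model_max_length=4096, max_length=100
)

# In both cases, at most model_max_length - max_length = 3996 prompt tokens
# survive truncation, leaving max_length = 100 tokens for the answer.
```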
