Skip to content

Commit

Permalink
feat(AWS Bedrock): Add Cohere Reranker (#1291)
Browse files Browse the repository at this point in the history
* Amazon Bedrock: Add Cohere Rerank model

* Run Lint

* Remove changes to CHANGELOG.md

* Remove var from serialization test

* # noqa: B008 fix test lint, yada yada

* adding BedrockRanker to pydoc

---------

Co-authored-by: David S. Batista <[email protected]>
Co-authored-by: Stefano Fiorucci <[email protected]>
  • Loading branch information
3 people authored Jan 23, 2025
1 parent a5bdb76 commit 01c5385
Show file tree
Hide file tree
Showing 5 changed files with 376 additions and 4 deletions.
32 changes: 32 additions & 0 deletions integrations/amazon_bedrock/examples/bedrock_ranker_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os

from haystack import Document
from haystack.utils import Secret

from haystack_integrations.components.rankers.amazon_bedrock import BedrockRanker

# Resolve the AWS profile and region from the environment, falling back to
# illustrative defaults when the variables are unset or empty.
aws_profile_name = os.environ.get("AWS_PROFILE") or "default"
aws_region_name = os.environ.get("AWS_DEFAULT_REGION") or "eu-central-1"

# Initialize the BedrockRanker with the resolved profile and region
ranker = BedrockRanker(
    model="cohere.rerank-v3-5:0",
    top_k=2,
    aws_profile_name=Secret.from_token(aws_profile_name),
    aws_region_name=Secret.from_token(aws_region_name),
)

# Create some sample documents
docs = [
    Document(content="Paris is the capital of France."),
    Document(content="Berlin is the capital of Germany."),
    Document(content="London is the capital of the United Kingdom."),
    Document(content="Rome is the capital of Italy."),
]

# Define a query
query = "What is the capital of Germany?"

# Run the ranker; it returns at most `top_k` documents, best match first
output = ranker.run(query=query, documents=docs)

# Show the reranked documents so the example produces a visible result
for ranked_doc in output["documents"]:
    print(f"{ranked_doc.score}\t{ranked_doc.content}")
9 changes: 5 additions & 4 deletions integrations/amazon_bedrock/pydoc/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../src]
modules: [
"haystack_integrations.common.amazon_bedrock.errors",
"haystack_integrations.components.embedders.amazon_bedrock.document_embedder",
"haystack_integrations.components.embedders.amazon_bedrock.text_embedder",
"haystack_integrations.components.generators.amazon_bedrock.generator",
"haystack_integrations.components.generators.amazon_bedrock.adapters",
"haystack_integrations.common.amazon_bedrock.errors",
"haystack_integrations.components.generators.amazon_bedrock.handlers",
"haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator",
"haystack_integrations.components.embedders.amazon_bedrock.text_embedder",
"haystack_integrations.components.embedders.amazon_bedrock.document_embedder",
"haystack_integrations.components.generators.amazon_bedrock.handlers",
"haystack_integrations.components.rankers.amazon_bedrock.ranker",
]
ignore_when_discovered: ["__init__"]
processors:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .ranker import BedrockRanker

__all__ = ["BedrockRanker"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
import json
import logging
from typing import Any, Dict, List, Optional

from botocore.exceptions import ClientError
from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace

from haystack_integrations.common.amazon_bedrock.errors import (
AmazonBedrockConfigurationError,
AmazonBedrockInferenceError,
)
from haystack_integrations.common.amazon_bedrock.utils import get_aws_session

logger = logging.getLogger(__name__)

MAX_NUM_DOCS_FOR_BEDROCK_RANKER = 1000


@component
class BedrockRanker:
    """
    Ranks Documents based on their similarity to the query using Amazon Bedrock's Cohere Rerank model.

    Documents are returned ordered from most to least semantically relevant to the query.

    Usage example:
    ```python
    from haystack import Document
    from haystack.utils import Secret
    from haystack_integrations.components.rankers.amazon_bedrock import BedrockRanker

    ranker = BedrockRanker(model="cohere.rerank-v3-5:0", top_k=2, aws_region_name=Secret.from_token("eu-central-1"))

    docs = [Document(content="Paris"), Document(content="Berlin")]
    query = "What is the capital of germany?"
    output = ranker.run(query=query, documents=docs)
    docs = output["documents"]
    ```

    BedrockRanker uses AWS for authentication. You can use the AWS CLI to authenticate through your IAM.
    For more information on setting up an IAM identity-based policy, see [Amazon Bedrock documentation]
    (https://docs.aws.amazon.com/bedrock/latest/userguide/security_iam_id-based-policy-examples.html).

    If the AWS environment is configured correctly, the AWS credentials are not required as they're loaded
    automatically from the environment or the AWS configuration file.
    If the AWS environment is not configured, set `aws_access_key_id`, `aws_secret_access_key`,
    and `aws_region_name` as environment variables or pass them as
    [Secret](https://docs.haystack.deepset.ai/v2.0/docs/secret-management) arguments. Make sure the region you set
    supports Amazon Bedrock.
    """

    def __init__(
        self,
        model: str = "cohere.rerank-v3-5:0",
        top_k: int = 10,
        aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False),  # noqa: B008
        aws_secret_access_key: Optional[Secret] = Secret.from_env_var(  # noqa: B008
            ["AWS_SECRET_ACCESS_KEY"], strict=False
        ),
        aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False),  # noqa: B008
        aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False),  # noqa: B008
        aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False),  # noqa: B008
        max_chunks_per_doc: Optional[int] = None,
        meta_fields_to_embed: Optional[List[str]] = None,
        meta_data_separator: str = "\n",
    ):
        """
        Creates an instance of the 'BedrockRanker'.

        :param model: Amazon Bedrock model name for Cohere Rerank. Default is "cohere.rerank-v3-5:0".
        :param top_k: The maximum number of documents to return.
        :param aws_access_key_id: AWS access key ID.
        :param aws_secret_access_key: AWS secret access key.
        :param aws_session_token: AWS session token.
        :param aws_region_name: AWS region name.
        :param aws_profile_name: AWS profile name.
        :param max_chunks_per_doc: If your document exceeds 512 tokens, this determines the maximum number of
            chunks a document can be split into.
            Note: This parameter is stored and serialized but is not sent to Bedrock by the current
            implementation; it is included for future compatibility.
        :param meta_fields_to_embed: List of meta fields that should be concatenated
            with the document content for reranking.
        :param meta_data_separator: Separator used to concatenate the meta fields
            to the Document content.
        :raises ValueError: If `model` is `None` or an empty string.
        :raises AmazonBedrockConfigurationError: If the AWS session or the Bedrock runtime client
            cannot be created.
        """
        if not model:
            msg = "'model' cannot be None or empty string"
            raise ValueError(msg)

        self.model_name = model
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.aws_session_token = aws_session_token
        self.aws_region_name = aws_region_name
        self.aws_profile_name = aws_profile_name
        self.top_k = top_k
        self.max_chunks_per_doc = max_chunks_per_doc
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.meta_data_separator = meta_data_separator

        def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
            # A missing Secret means "not provided"; pass None through to the session factory
            return secret.resolve_value() if secret else None

        try:
            session = get_aws_session(
                aws_access_key_id=resolve_secret(aws_access_key_id),
                aws_secret_access_key=resolve_secret(aws_secret_access_key),
                aws_session_token=resolve_secret(aws_session_token),
                aws_region_name=resolve_secret(aws_region_name),
                aws_profile_name=resolve_secret(aws_profile_name),
            )
            self._bedrock_client = session.client("bedrock-runtime")
        except Exception as exception:
            msg = (
                "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
                "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
            )
            raise AmazonBedrockConfigurationError(msg) from exception

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            model=self.model_name,
            aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
            aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
            aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
            aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
            aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
            top_k=self.top_k,
            max_chunks_per_doc=self.max_chunks_per_doc,
            meta_fields_to_embed=self.meta_fields_to_embed,
            meta_data_separator=self.meta_data_separator,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BedrockRanker":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from. Its `init_parameters` entry is modified in place
            while the secret fields are deserialized.
        :returns:
            The deserialized component.
        """
        deserialize_secrets_inplace(
            data["init_parameters"],
            ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
        )
        return default_from_dict(cls, data)

    def _prepare_bedrock_input_docs(self, documents: List[Document]) -> List[str]:
        """
        Prepare the input by concatenating the document text with the metadata fields specified.

        Meta fields that are missing or hold a falsy value are skipped; a document without
        content contributes an empty string.

        :param documents: The list of Document objects.
        :return: A list of strings to be given as input to Bedrock model.
        """
        concatenated_input_list = []
        for doc in documents:
            # `.get` returns None for absent keys, so one lookup covers both "missing" and "falsy"
            meta_values_to_embed = [str(doc.meta[key]) for key in self.meta_fields_to_embed if doc.meta.get(key)]
            concatenated_input = self.meta_data_separator.join([*meta_values_to_embed, doc.content or ""])
            concatenated_input_list.append(concatenated_input)

        return concatenated_input_list

    @component.output_types(documents=List[Document])
    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
        """
        Use the Amazon Bedrock Cohere Reranker to re-rank the list of documents based on the query.

        The returned Documents are the same objects that were passed in; their `score` attribute
        is overwritten with the relevance score reported by the model.

        :param query:
            Query string.
        :param documents:
            List of Documents.
        :param top_k:
            The maximum number of Documents you want the Ranker to return.
            If `None`, the value given at initialization is used.
        :returns:
            A dictionary with the following keys:
            - `documents`: List of Documents most similar to the given query in descending order of similarity.
        :raises ValueError: If `top_k` is not > 0.
        :raises AmazonBedrockInferenceError: If the Bedrock call fails or its response cannot be processed.
        """
        # Fall back to the instance default only when nothing was passed; `top_k or self.top_k`
        # would silently swallow an explicit 0 instead of rejecting it below.
        top_k = self.top_k if top_k is None else top_k
        if top_k <= 0:
            msg = f"top_k must be > 0, but got {top_k}"
            raise ValueError(msg)

        # Nothing to rank: skip the remote call entirely
        if not documents:
            return {"documents": []}

        bedrock_input_docs = self._prepare_bedrock_input_docs(documents)
        if len(bedrock_input_docs) > MAX_NUM_DOCS_FOR_BEDROCK_RANKER:
            logger.warning(
                "The Amazon Bedrock reranking endpoint only supports %d documents. "
                "The number of documents has been truncated to %d from %d.",
                MAX_NUM_DOCS_FOR_BEDROCK_RANKER,
                MAX_NUM_DOCS_FOR_BEDROCK_RANKER,
                len(bedrock_input_docs),
            )
            bedrock_input_docs = bedrock_input_docs[:MAX_NUM_DOCS_FOR_BEDROCK_RANKER]

        # Prepare the request body for Amazon Bedrock
        request_body = {"documents": bedrock_input_docs, "query": query, "top_n": top_k, "api_version": 2}

        try:
            # Make the API call to Amazon Bedrock
            response = self._bedrock_client.invoke_model(modelId=self.model_name, body=json.dumps(request_body))

            # Parse the response
            response_body = json.loads(response["body"].read())
            results = response_body["results"]

            # Results reference the input by position; map each one back to its Document
            sorted_docs = []
            for result in results:
                idx = result["index"]
                score = result["relevance_score"]
                doc = documents[idx]
                doc.score = score
                sorted_docs.append(doc)

            return {"documents": sorted_docs}
        except ClientError as exception:
            msg = f"Could not inference Amazon Bedrock model {self.model_name} due: {exception}"
            raise AmazonBedrockInferenceError(msg) from exception
        except KeyError as e:
            msg = f"Unexpected response format from Amazon Bedrock: {e!s}"
            raise AmazonBedrockInferenceError(msg) from e
        except Exception as e:
            msg = f"Error during Amazon Bedrock API call: {e!s}"
            raise AmazonBedrockInferenceError(msg) from e
103 changes: 103 additions & 0 deletions integrations/amazon_bedrock/tests/test_ranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from unittest.mock import MagicMock, patch

import pytest
from haystack import Document
from haystack.utils import Secret

from haystack_integrations.common.amazon_bedrock.errors import (
AmazonBedrockInferenceError,
)
from haystack_integrations.components.rankers.amazon_bedrock import BedrockRanker


@pytest.fixture
def mock_aws_session():
    """Patch `get_aws_session` in the ranker module and yield the mocked Bedrock client."""
    patch_target = "haystack_integrations.components.rankers.amazon_bedrock.ranker.get_aws_session"
    with patch(patch_target) as session_factory:
        bedrock_client = MagicMock()
        # Whatever session the ranker builds will hand back this client
        session_factory.return_value.client.return_value = bedrock_client
        yield bedrock_client


def test_bedrock_ranker_initialization(mock_aws_session):
    """The constructor stores the model name and `top_k` it was given."""
    init_kwargs = {
        "model": "cohere.rerank-v3-5:0",
        "top_k": 2,
        "aws_access_key_id": Secret.from_token("test_access_key"),
        "aws_secret_access_key": Secret.from_token("test_secret_key"),
        "aws_region_name": Secret.from_token("us-east-1"),
    }
    ranker = BedrockRanker(**init_kwargs)

    assert ranker.model_name == "cohere.rerank-v3-5:0"
    assert ranker.top_k == 2


def test_bedrock_ranker_run(mock_aws_session):
    """`run` attaches the relevance scores from the mocked Bedrock response to the documents."""
    ranker = BedrockRanker(
        model="cohere.rerank-v3-5:0",
        top_k=2,
        aws_access_key_id=Secret.from_token("test_access_key"),
        aws_secret_access_key=Secret.from_token("test_secret_key"),
        aws_region_name=Secret.from_token("us-east-1"),
    )

    # Fake streaming body whose `read()` yields the JSON payload
    payload = b'{"results": [{"index": 0, "relevance_score": 0.9},' b' {"index": 1, "relevance_score": 0.7}]}'
    body_stream = MagicMock()
    body_stream.read.return_value = payload
    mock_aws_session.invoke_model.return_value = {"body": body_stream}

    input_docs = [Document(content="Test document 1"), Document(content="Test document 2")]
    ranked = ranker.run(query="test query", documents=input_docs)["documents"]

    assert len(ranked) == 2
    first, second = ranked
    assert first.score == 0.9
    assert second.score == 0.7


@pytest.mark.integration
def test_bedrock_ranker_live_run():
    """Integration test: run the ranker without mocks and check that scored documents come back."""
    live_ranker = BedrockRanker(
        model="cohere.rerank-v3-5:0",
        top_k=2,
        aws_region_name=Secret.from_token("eu-central-1"),
    )
    input_docs = [Document(content="Test document 1"), Document(content="Test document 2")]

    ranked = live_ranker.run(query="test query", documents=input_docs)["documents"]

    assert len(ranked) == 2
    assert isinstance(ranked[0].score, float)


def test_bedrock_ranker_run_inference_error(mock_aws_session):
    """A failure raised by `invoke_model` surfaces as `AmazonBedrockInferenceError`."""
    # Make the mocked client blow up on every invocation
    mock_aws_session.invoke_model.side_effect = Exception("Inference error")

    ranker = BedrockRanker(
        model="cohere.rerank-v3-5:0",
        top_k=2,
        aws_access_key_id=Secret.from_token("test_access_key"),
        aws_secret_access_key=Secret.from_token("test_secret_key"),
        aws_region_name=Secret.from_token("us-east-1"),
    )
    input_docs = [Document(content="Test document 1"), Document(content="Test document 2")]

    with pytest.raises(AmazonBedrockInferenceError):
        ranker.run(query="test query", documents=input_docs)


def test_bedrock_ranker_serialization(mock_aws_session):
    """`to_dict` followed by `from_dict` preserves the model name and `top_k`."""
    expected_model = "cohere.rerank-v3-5:0"
    expected_top_k = 2
    original = BedrockRanker(model=expected_model, top_k=expected_top_k)

    serialized = original.to_dict()
    init_params = serialized["init_parameters"]
    assert init_params["model"] == "cohere.rerank-v3-5:0"
    assert init_params["top_k"] == 2

    restored = BedrockRanker.from_dict(serialized)
    assert isinstance(restored, BedrockRanker)
    assert restored.model_name == "cohere.rerank-v3-5:0"
    assert restored.top_k == 2

0 comments on commit 01c5385

Please sign in to comment.