Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add AWS bedrock embeddings to embedding encoder #6406

Merged
merged 29 commits into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2c5e5a7
feat: add bedrock embeddings to embedding encoder
jlonge4 Nov 23, 2023
bc370b8
reno & black
jlonge4 Nov 24, 2023
b28fbc3
feat: refactoring for bedrock embedding encoder
jlonge4 Nov 24, 2023
7eb3855
feat: bedrock embedding encoder
jlonge4 Nov 24, 2023
928920c
feat: bedrock refactoring
jlonge4 Nov 24, 2023
cfa06d0
feat: bedrock refactoring
jlonge4 Nov 24, 2023
98f997a
feat: bedrock refactoring
jlonge4 Nov 24, 2023
6d11af3
feat: bedrock refactoring
jlonge4 Nov 24, 2023
5739b51
feat: bedrock refactoring
jlonge4 Nov 26, 2023
b80d039
feat: bedrock refactoring
jlonge4 Nov 26, 2023
78a506d
feat: bedrock refactoring, add cohere
jlonge4 Nov 26, 2023
998cfe3
feat: bedrock refactoring, add cohere
jlonge4 Nov 26, 2023
a87395c
pylint: disable too-many-return-statements in method
anakin87 Dec 5, 2023
5b6e2c3
feat: bedrock refactoring
jlonge4 Dec 7, 2023
fb55724
Merge branch 'feature-sbx' of https://github.com/jlonge4/haystack int…
jlonge4 Dec 7, 2023
d50a03e
feat: bedrock refactoring
jlonge4 Dec 7, 2023
8760b23
feat: bedrock refactoring
jlonge4 Dec 7, 2023
1a1bbce
feat: bedrock refactoring
jlonge4 Dec 7, 2023
92b410d
feat: bedrock refactoring
jlonge4 Dec 7, 2023
b5cf460
Merge branch 'v1.x' into feature-sbx
anakin87 Dec 7, 2023
b813643
feat: bedrock refactoring
jlonge4 Dec 7, 2023
41eb0b1
Merge branch 'feature-sbx' of https://github.com/jlonge4/haystack int…
jlonge4 Dec 7, 2023
4bb7ffa
Merge branch 'v1.x' into feature-sbx
anakin87 Dec 15, 2023
7125d9b
fix mypy and pylint errors
anakin87 Dec 15, 2023
8249685
manually run precommit
anakin87 Dec 15, 2023
badc4f9
refactor init
tstadel Dec 17, 2023
1315ff7
fix cohere truncate and refactor embed
tstadel Dec 17, 2023
6be5163
fix mypy
tstadel Dec 17, 2023
9fdcb93
proper exception handing
tstadel Dec 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions haystack/nodes/retriever/_embedding_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from haystack.modeling.infer import Inferencer
from haystack.nodes.retriever._losses import _TRAINING_LOSSES

# boto3 is an optional dependency: the import error is deferred and only
# surfaced when a Bedrock encoder actually calls boto3_import.check().
with LazyImport(message="Run 'pip install boto3'") as boto3_import:
    import boto3

# Cohere API request timeout and retry backoff (seconds), overridable via env vars.
COHERE_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
COHERE_BACKOFF = int(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
Expand Down Expand Up @@ -434,11 +436,79 @@ def save(self, save_dir: Union[Path, str]):
raise NotImplementedError(f"Saving is not implemented for {self.__class__}")


class _BedrockEmbeddingEncoder(_BaseEmbeddingEncoder):
    """Embedding encoder backed by Amazon Bedrock's Titan text embedding model."""

    def __init__(self, retriever: "EmbeddingRetriever"):
        """
        :param retriever: The EmbeddingRetriever holding the AWS configuration.
            Its `aws_config` must contain the keys `aws_access_key_id`,
            `aws_secret_access_key`, and `region`.
        :raises ValueError: If no AWS configuration was provided or the boto3
            client cannot be created.
        """
        boto3_import.check()

        # See https://docs.aws.amazon.com/bedrock/latest/userguide/embeddings.html for more details
        # The maximum input text is 8K tokens and the maximum output vector length is 1536
        # Bedrock embeddings do not support batch operations
        self.model: str = "amazon.titan-embed-text-v1"
        self.aws_config = retriever.aws_config
        self.client = self.initialize_boto3_client()

    def initialize_boto3_client(self):
        """Create the `bedrock-runtime` boto3 client from `self.aws_config`.

        :raises ValueError: If `aws_config` is missing or client creation fails.
        """
        if not self.aws_config:
            raise ValueError("Please pass boto3.client(bedrock-runtime) credentials configuration")

        access_key_id = self.aws_config.get("aws_access_key_id")
        secret_access_key = self.aws_config.get("aws_secret_access_key")
        region = self.aws_config.get("region")
        try:
            return boto3.client(
                "bedrock-runtime",
                aws_access_key_id=access_key_id,
                aws_secret_access_key=secret_access_key,
                region_name=region,
            )
        except Exception as e:
            # Preserve the original boto3 error as the cause for easier debugging.
            raise ValueError(f"AWS client error {e}") from e

    def embed(self, text: str) -> np.ndarray:
        """Embed a single text via Bedrock's InvokeModel API.

        :param text: The text to embed (Bedrock does not support batching).
        :return: 1-D embedding vector for `text`.
        """
        body = json.dumps({"inputText": text})
        response = self.client.invoke_model(
            body=body, modelId=self.model, accept="application/json", contentType="application/json"
        )

        response_body = json.loads(response.get("body").read())
        return np.array(response_body.get("embedding"))

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """Embed each query separately (no batch support in Bedrock).

        :param queries: Texts to embed.
        :return: Array of shape (len(queries), embedding_dim).
        """
        all_embeddings = [self.embed(q) for q in queries]
        # np.array (not np.concatenate) so the result is a 2-D matrix of
        # shape (n_queries, dim), consistent with the other encoders.
        return np.array(all_embeddings)

    def embed_documents(self, docs: List[Document]) -> np.ndarray:
        """Embed document contents; returns shape (len(docs), embedding_dim)."""
        return self.embed_queries([d.content for d in docs])

    def train(
        self,
        training_data: List[Dict[str, Any]],
        learning_rate: float = 2e-5,
        n_epochs: int = 1,
        num_warmup_steps: Optional[int] = None,
        batch_size: int = 16,
        train_loss: Literal["mnrl", "margin_mse"] = "mnrl",
        num_workers: int = 0,
        use_amp: bool = False,
        **kwargs,
    ):
        raise NotImplementedError(f"Training is not implemented for {self.__class__}")

    def save(self, save_dir: Union[Path, str]):
        raise NotImplementedError(f"Saving is not implemented for {self.__class__}")


# Maps an EmbeddingRetriever model_format string to the encoder class that
# implements it (see _infer_model_format in dense.py for how the key is chosen).
_EMBEDDING_ENCODERS: Dict[str, Callable] = {
    "farm": _DefaultEmbeddingEncoder,
    "transformers": _DefaultEmbeddingEncoder,
    "sentence_transformers": _SentenceTransformersEmbeddingEncoder,
    "retribert": _RetribertEmbeddingEncoder,
    "openai": _OpenAIEmbeddingEncoder,
    "cohere": _CohereEmbeddingEncoder,
    "bedrock": _BedrockEmbeddingEncoder,
}
5 changes: 5 additions & 0 deletions haystack/nodes/retriever/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,6 +1468,7 @@ def __init__(
azure_deployment_name: Optional[str] = None,
api_base: str = "https://api.openai.com/v1",
openai_organization: Optional[str] = None,
aws_config: Optional[Dict[str, Any]] = None,
):
"""
:param document_store: An instance of DocumentStore from which to retrieve documents.
Expand Down Expand Up @@ -1532,6 +1533,7 @@ def __init__(
will not be used.
:param api_base: The OpenAI API base URL, defaults to `"https://api.openai.com/v1"`.
:param openai_organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
:param aws_config: Dictionary with the keys `aws_access_key_id`, `aws_secret_access_key`, and `region`, used to configure the boto3 client for an AWS Bedrock retriever. Defaults to `None`.
[documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
"""
torch_and_transformers_import.check()
Expand Down Expand Up @@ -1565,6 +1567,7 @@ def __init__(
self.azure_base_url = azure_base_url
self.azure_deployment_name = azure_deployment_name
self.openai_organization = openai_organization
self.aws_config = aws_config
self.model_format = (
self._infer_model_format(model_name_or_path=embedding_model, use_auth_token=use_auth_token)
if model_format is None
Expand Down Expand Up @@ -1892,6 +1895,8 @@ def _infer_model_format(model_name_or_path: str, use_auth_token: Optional[Union[
return "openai"
if model_name_or_path in COHERE_EMBEDDING_MODELS:
return "cohere"
if model_name_or_path == "bedrock":
return "bedrock"
# Check if model name is a local directory with sentence transformers config file in it
if Path(model_name_or_path).exists():
if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Add the Bedrock Embeddings Encoder for use with the EmbeddingRetriever.