diff --git a/integrations/astra/README.md b/integrations/astra/README.md
index d14544df4..f8b6f7c31 100644
--- a/integrations/astra/README.md
+++ b/integrations/astra/README.md
@@ -3,6 +3,12 @@
 # Astra Store

 ## Installation
+
+```bash
+pip install astra-haystack
+```
+
+### Local Development

 install astra-haystack package locally to run integration tests:

 Open in gitpod:
@@ -46,8 +53,8 @@ This package includes Astra Document Store and Astra Embedding Retriever classes

 Import the Document Store:

 ```
-from astra_store.document_store import AstraDocumentStore
-from haystack.preview.document_stores import DuplicatePolicy
+from haystack_integrations.document_stores.astra import AstraDocumentStore
+from haystack.document_stores.types.policy import DuplicatePolicy
 ```

 Load in environment variables:
@@ -76,7 +83,7 @@ Then you can use the document store functions like count_document below:

 Create the Document Store object like above, then import and create the Pipeline:

 ```
-from haystack.preview import Pipeline
+from haystack import Pipeline
 pipeline = Pipeline()
 ```
diff --git a/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py
index 7236b5749..2b9ac7d28 100644
--- a/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py
+++ b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py
@@ -13,15 +13,29 @@ class AstraEmbeddingRetriever:
     """
     A component for retrieving documents from an AstraDocumentStore.
+
+    Usage example:
+    ```python
+    from haystack.document_stores.types.policy import DuplicatePolicy
+    from haystack_integrations.document_stores.astra import AstraDocumentStore
+    from haystack_integrations.components.retrievers.astra import AstraEmbeddingRetriever
+
+    document_store = AstraDocumentStore(
+        api_endpoint=api_endpoint,
+        token=token,
+        collection_name=collection_name,
+        duplicates_policy=DuplicatePolicy.SKIP,
+        embedding_dim=384,
+    )
+
+    retriever = AstraEmbeddingRetriever(document_store=document_store)
+    ```
     """

     def __init__(self, document_store: AstraDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10):
         """
-        Create an AstraEmbeddingRetriever component. Usually you pass some basic configuration
-        parameters to the constructor.
-
-        :param filters: A dictionary with filters to narrow down the search space (default is None).
-        :param top_k: The maximum number of documents to retrieve (default is 10).
+        :param filters: a dictionary with filters to narrow down the search space.
+        :param top_k: the maximum number of documents to retrieve.
         """
         self.filters = filters
         self.top_k = top_k
@@ -33,13 +46,13 @@ def __init__(self, document_store: AstraDocumentStore, filters: Optional[Dict[st

     @component.output_types(documents=List[Document])
     def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
-        """Run the retriever on the given list of queries.
+        """Retrieve documents from the AstraDocumentStore.

-        Args:
-            query_embedding (List[str]): An input list of queries
-            filters (Optional[Dict[str, Any]], optional): A dictionary with filters to narrow down the search space.
-                Defaults to None.
-            top_k (Optional[int], optional): The maximum number of documents to retrieve. Defaults to None.
+        :param query_embedding: a list of floats representing the query embedding.
+        :param filters: filters to narrow down the search space.
+        :param top_k: the maximum number of documents to retrieve.
+        :returns: a dictionary with the following keys:
+            - documents: A list of documents retrieved from the AstraDocumentStore.
         """
         if not top_k:
@@ -51,6 +64,12 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] =
         return {"documents": self.document_store.search(query_embedding, top_k, filters=filters)}

     def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+
+        :returns:
+            Dictionary with serialized data.
+        """
         return default_to_dict(
             self,
             filters=self.filters,
@@ -60,6 +79,14 @@
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "AstraEmbeddingRetriever":
+        """
+        Deserializes the component from a dictionary.
+
+        :param data:
+            Dictionary to deserialize from.
+        :returns:
+            Deserialized component.
+        """
         document_store = AstraDocumentStore.from_dict(data["init_parameters"]["document_store"])
         data["init_parameters"]["document_store"] = document_store
         return default_from_dict(cls, data)
diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py
index 6e7b1a33f..c1eb1f6a7 100644
--- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py
+++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py
@@ -43,6 +43,19 @@ def __init__(
         similarity_function: str,
         namespace: Optional[str] = None,
     ):
+        """
+        The connection to Astra DB is established and managed through the JSON API.
+        The required credentials (API endpoint and application token) can be generated
+        through the UI by clicking on the Connect tab, and then selecting JSON API and
+        Generate Configuration.
+
+        :param api_endpoint: the Astra DB API endpoint.
+        :param token: the Astra DB application token.
+        :param collection_name: the collection in the keyspace of the current Astra DB.
+        :param embedding_dimension: the dimension of the embedding vector.
+        :param similarity_function: the similarity function to use for the index.
+        :param namespace: the namespace to use for the collection.
+        """
         self.api_endpoint = api_endpoint
         self.token = token
         self.collection_name = collection_name
@@ -119,23 +132,17 @@ def query(
         include_values: Optional[bool] = None,
     ) -> QueryResponse:
         """
-        The Query operation searches a namespace, using a query vector.
-        It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
-
-        Args:
-            vector (List[float]): The query vector. This should be the same length as the dimension of the index
-                being queried. Each `query()` request can contain only one of the parameters
-                `queries`, `id` or `vector`... [optional]
-            top_k (int): The number of results to return for each query. Must be an integer greater than 1.
-            query_filter (Dict[str, Union[str, float, int, bool, List, dict]):
-                The filter to apply. You can use vector metadata to limit your search. [optional]
-            include_metadata (bool): Indicates whether metadata is included in the response as well as the ids.
-                If omitted the server will use the default value of False [optional]
-            include_values (bool): Indicates whether values/vector is included in the response as well as the ids.
-                If omitted the server will use the default value of False [optional]
-
-        Returns: object which contains the list of the closest vectors as ScoredVector objects,
-            and namespace name.
+        Search the Astra index using a query vector.
+
+        :param vector: the query vector.
This should be the same length as the dimension of the index being queried. + Each `query()` request can contain only one of the parameters `queries`, `id` or `vector`. + :param query_filter: the filter to apply. You can use vector metadata to limit your search. + :param top_k: the number of results to return for each query. Must be an integer greater than 1. + :param include_metadata: indicates whether metadata is included in the response as well as the ids. + If omitted the server will use the default value of `False`. + :param include_values: indicates whether values/vector is included in the response as well as the ids. + If omitted the server will use the default value of `False`. + :returns: object which contains the list of the closest vectors as ScoredVector objects, and namespace name. """ # get vector data and scores if vector is None: @@ -183,6 +190,12 @@ def _query(self, vector, top_k, filters=None): return result def find_documents(self, find_query): + """ + Find documents in the Astra index. + + :param find_query: a dictionary with the query options + :returns: the documents found in the index + """ response_dict = self._astra_db_collection.find( filter=find_query.get("filter"), sort=find_query.get("sort"), @@ -195,6 +208,13 @@ def find_documents(self, find_query): logger.warning(f"No documents found: {response_dict}") def get_documents(self, ids: List[str], batch_size: int = 20) -> QueryResponse: + """ + Get documents from the Astra index by their ids. + + :param ids: a list of document ids + :param batch_size: the batch size to use when querying the index + :returns: the documents found in the index + """ document_batch = [] def batch_generator(chunks, batch_size): @@ -213,6 +233,12 @@ def batch_generator(chunks, batch_size): return formatted_docs def insert(self, documents: List[Dict]): + """ + Insert documents into the Astra index. + + :param documents: a list of documents to insert + :returns: the IDs of the inserted documents + """ response_dict = self._astra_db_collection.insert_many(documents=documents) inserted_ids = ( @@ -226,6 +252,13 @@ def insert(self, documents: List[Dict]): return inserted_ids def update_document(self, document: Dict, id_key: str): + """ + Update a document in the Astra index. + + :param document: the document to update + :param id_key: the key to use as the document id + :returns: whether the document was updated successfully + """ document_id = document.pop(id_key) response_dict = self._astra_db_collection.find_one_and_update( @@ -251,6 +284,13 @@ def delete( delete_all: Optional[bool] = None, filters: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, ) -> int: + """Delete documents from the Astra index. + + :param ids: the ids of the documents to delete + :param delete_all: if `True`, delete all documents from the index + :param filters: additional filters to apply when deleting documents + :returns: the number of documents deleted + """ if delete_all: query = {"deleteMany": {}} # type: dict if ids is not None: @@ -276,7 +316,8 @@ def delete( def count_documents(self) -> int: """ - Returns how many documents are present in the document store. + Count the number of documents in the Astra index. 
+        :returns: the number of documents in the index
         """
         documents_count = self._astra_db_collection.count_documents()
diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py
index 1a4ec9d17..2f8f0928d 100644
--- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py
+++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py
@@ -32,6 +32,20 @@ def _batches(input_list, batch_size):
 class AstraDocumentStore:
     """
     An AstraDocumentStore document store for Haystack.
+
+    Usage example:
+    ```python
+    from haystack.document_stores.types.policy import DuplicatePolicy
+    from haystack_integrations.document_stores.astra import AstraDocumentStore
+
+    document_store = AstraDocumentStore(
+        api_endpoint=api_endpoint,
+        token=token,
+        collection_name=collection_name,
+        duplicates_policy=DuplicatePolicy.SKIP,
+        embedding_dim=384,
+    )
+    ```
     """

     def __init__(
@@ -45,22 +58,24 @@
     ):
         """
         The connection to Astra DB is established and managed through the JSON API.
-        The required credentials (api endpoint andapplication token) can be generated
+        The required credentials (API endpoint and application token) can be generated
         through the UI by clicking and the connect tab, and then selecting JSON API and
         Generate Configuration.

-        :param api_endpoint: The Astra DB API endpoint.
-        :param token: The Astra DB application token.
-        :param collection_name: The current collection in the keyspace in the current Astra DB.
-        :param embedding_dimension: Dimension of embedding vector.
-        :param duplicates_policy: Handle duplicate documents based on DuplicatePolicy parameter options.
-            Parameter options : (SKIP, OVERWRITE, FAIL, NONE)
-            - `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists,
+        :param api_endpoint: the Astra DB API endpoint.
+        :param token: the Astra DB application token.
+        :param collection_name: the collection in the keyspace of the current Astra DB.
+        :param embedding_dimension: the dimension of the embedding vector.
+        :param duplicates_policy: handle duplicate documents based on DuplicatePolicy parameter options.
+            Parameter options: (`SKIP`, `OVERWRITE`, `FAIL`, `NONE`)
+            - `DuplicatePolicy.NONE`: the default policy. If a Document with the same ID already exists,
             it is skipped and not written.
-        - `DuplicatePolicy.SKIP`: If a Document with the same id already exists, it is skipped and not written.
-        - `DuplicatePolicy.OVERWRITE`: If a Document with the same id already exists, it is overwritten.
-        - `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised.
-        :param similarity: The similarity function used to compare document vectors.
+        - `DuplicatePolicy.SKIP`: if a Document with the same ID already exists, it is skipped and not written.
+        - `DuplicatePolicy.OVERWRITE`: if a Document with the same ID already exists, it is overwritten.
+        - `DuplicatePolicy.FAIL`: if a Document with the same ID already exists, an error is raised.
+        :param similarity: the similarity function used to compare document vectors.
+
+        :raises ValueError: if the API endpoint or token is not set.
         """
         resolved_api_endpoint = api_endpoint.resolve_value()
         if resolved_api_endpoint is None:
@@ -95,10 +110,24 @@
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "AstraDocumentStore":
+        """
+        Deserializes the component from a dictionary.
+
+        :param data:
+            Dictionary to deserialize from.
+        :returns:
+            Deserialized component.
+ """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_endpoint", "token"]) return default_from_dict(cls, data) def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ return default_to_dict( self, api_endpoint=self.api_endpoint.to_dict(), @@ -118,15 +147,18 @@ def write_documents( Indexes documents for later queries. :param documents: a list of Haystack Document objects. - :param policy: Handle duplicate documents based on DuplicatePolicy parameter options. - Parameter options : (SKIP, OVERWRITE, FAIL, NONE) - - `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists, - it is skipped and not written. - - `DuplicatePolicy.SKIP`: If a Document with the same id already exists, - it is skipped and not written. - - `DuplicatePolicy.OVERWRITE`: If a Document with the same id already exists, it is overwritten. - - `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised. - :return: int + :param policy: handle duplicate documents based on DuplicatePolicy parameter options. + Parameter options : (`SKIP`, `OVERWRITE`, `FAIL`, `NONE`) + - `DuplicatePolicy.NONE`: Default policy, If a Document with the same ID already exists, + it is skipped and not written. + - `DuplicatePolicy.SKIP`: If a Document with the same ID already exists, + it is skipped and not written. + - `DuplicatePolicy.OVERWRITE`: If a Document with the same ID already exists, it is overwritten. + - `DuplicatePolicy.FAIL`: If a Document with the same ID already exists, an error is raised. + :returns: number of documents written. + :raises ValueError: if the documents are not of type Document or dict. + :raises DuplicateDocumentError: if a document with the same ID already exists and policy is set to FAIL. + :raises Exception: if the document ID is not a string or if `id` and `_id` are both present in the document. """ if policy is None or policy == DuplicatePolicy.NONE: if self.duplicates_policy is not None and self.duplicates_policy != DuplicatePolicy.NONE: @@ -226,21 +258,19 @@ def _convert_input_document(document: Union[dict, Document]): def count_documents(self) -> int: """ - Returns how many documents are present in the document store. + Counts the number of documents in the document store. + + :returns: the number of documents in the document store. """ return self.index.count_documents() def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: - """Returns at most 1000 documents that match the filter - - Args: - filters (Optional[Dict[str, Any]], optional): Filters to apply. Defaults to None. - - Raises: - AstraDocumentStoreFilterError: If the filter is invalid or not supported by this class. + """ + Returns at most 1000 documents that match the filter. - Returns: - List[Document]: A list of matching documents. + :param filters: filters to apply. + :returns: matching documents. + :raises AstraDocumentStoreFilterError: if the filter is invalid or not supported by this class. """ if not isinstance(filters, dict) and filters is not None: msg = "Filters must be a dictionary or None" @@ -299,7 +329,10 @@ def _get_result_to_documents(results) -> List[Document]: def get_documents_by_id(self, ids: List[str]) -> List[Document]: """ - Returns documents with given ids. + Gets documents by their IDs. + + :param ids: the IDs of the documents to retrieve. + :returns: the matching documents. 
""" results = self.index.get_documents(ids=ids) ret = self._get_result_to_documents(results) @@ -307,9 +340,11 @@ def get_documents_by_id(self, ids: List[str]) -> List[Document]: def get_document_by_id(self, document_id: str) -> Document: """ - :param document_id: id of the document to retrieve - Returns documents with given ids. - Raises MissingDocumentError when document_id does not exist in document store + Gets a document by its ID. + + :param document_id: the ID to filter by + :returns: the found document + :raises MissingDocumentError: if the document is not found """ document = self.index.get_documents(ids=[document_id]) ret = self._get_result_to_documents(document) @@ -321,15 +356,13 @@ def get_document_by_id(self, document_id: str) -> Document: def search( self, query_embedding: List[float], top_k: int, filters: Optional[Dict[str, Any]] = None ) -> List[Document]: - """Perform a search for a list of queries. - - Args: - query_embedding (List[float]): A list of query embeddings. - top_k (int): The number of results to return. - filters (Optional[Dict[str, Any]], optional): Filters to apply during search. Defaults to None. + """ + Perform a search for a list of queries. - Returns: - List[Document]: A list of matching documents. + :param query_embedding: a list of query embeddings. + :param top_k: the number of results to return. + :param filters: filters to apply during search. + :returns: matching documents. """ converted_filters = _convert_filters(filters) @@ -348,13 +381,12 @@ def search( def delete_documents(self, document_ids: Optional[List[str]] = None, delete_all: Optional[bool] = None) -> None: """ - Deletes all documents with a matching document_ids from the document store. - Fails with `MissingDocumentError` if no document with this id is present in the store. + Deletes documents from the document store. - :param document_ids: the document_ids to delete. - :param delete_all: delete all documents. + :param document_ids: IDs of the documents to delete. + :param delete_all: if `True`, delete all documents. + :raises MissingDocumentError: if no document was deleted but document IDs were provided. 
""" - deletion_counter = 0 if self.index.count_documents() > 0: if document_ids is not None: diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py b/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py index 186a8fef2..493f62917 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py @@ -6,12 +6,18 @@ class AstraDocumentStoreError(DocumentStoreError): + """Parent class for all AstraDocumentStore errors.""" + pass class AstraDocumentStoreFilterError(FilterError): + """Raised when an invalid filter is passed to AstraDocumentStore.""" + pass class AstraDocumentStoreConfigError(AstraDocumentStoreError): + """Raised when an invalid configuration is passed to AstraDocumentStore.""" + pass diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py index 6b628486b..44cac25e6 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py @@ -19,7 +19,7 @@ def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]: def _convert_filters(filters: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]: """ - Convert haystack filters to astra filterstring capturing all boolean operators + Convert haystack filters to astra filter string capturing all boolean operators """ if not filters: return None diff --git a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py index 2712f8c17..e19d4acbe 100644 --- a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py +++ b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py @@ -10,16 +10,42 @@ @component class ChromaQueryTextRetriever: """ - A component for retrieving documents from an ChromaDocumentStore using the `query` API. + A component for retrieving documents from a [Chroma database](https://docs.trychroma.com/) using the `query` API. + + Example usage: + ```python + from haystack import Pipeline + from haystack.components.converters import TextFileToDocument + from haystack.components.writers import DocumentWriter + + from haystack_integrations.document_stores.chroma import ChromaDocumentStore + from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever + + file_paths = ... + + # Chroma is used in-memory so we use the same instances in the two pipelines below + document_store = ChromaDocumentStore() + + indexing = Pipeline() + indexing.add_component("converter", TextFileToDocument()) + indexing.add_component("writer", DocumentWriter(document_store)) + indexing.connect("converter", "writer") + indexing.run({"converter": {"sources": file_paths}}) + + querying = Pipeline() + querying.add_component("retriever", ChromaQueryTextRetriever(document_store)) + results = querying.run({"retriever": {"query": "Variable declarations", "top_k": 3}}) + + for d in results["retriever"]["documents"]: + print(d.meta, d.score) + ``` """ def __init__(self, document_store: ChromaDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10): """ - Create a ChromaQueryTextRetriever component. 
- - :param document_store: An instance of ChromaDocumentStore. - :param filters: A dictionary with filters to narrow down the search space (default is None). - :param top_k: The maximum number of documents to retrieve (default is 10). + :param document_store: an instance of `ChromaDocumentStore`. + :param filters: filters to narrow down the search space. + :param top_k: the maximum number of documents to retrieve. """ self.filters = filters self.top_k = top_k @@ -36,7 +62,10 @@ def run( Run the retriever on the given input data. :param query: The input data for the retriever. In this case, a plain-text query. - :return: The retrieved documents. + :param top_k: The maximum number of documents to retrieve. + If not specified, the default value from the constructor is used. + :return: A dictionary with the following keys: + - "documents": List of documents returned by the search engine. :raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance. """ @@ -46,7 +75,10 @@ def run( def to_dict(self) -> Dict[str, Any]: """ - Override the default serializer in order to manage the Chroma client string representation + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. """ d = default_to_dict(self, filters=self.filters, top_k=self.top_k) d["init_parameters"]["document_store"] = self.document_store.to_dict() @@ -55,6 +87,14 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ document_store = ChromaDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store return default_from_dict(cls, data) @@ -62,6 +102,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever": @component class ChromaEmbeddingRetriever(ChromaQueryTextRetriever): + """ + A component for retrieving documents from a [Chroma database](https://docs.trychroma.com/) using embeddings. + """ + @component.output_types(documents=List[Document]) def run( self, @@ -72,10 +116,9 @@ def run( """ Run the retriever on the given input data. - :param queries: The input data for the retriever. In this case, a list of queries. - :return: The retrieved documents. - - :raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance. + :param query_embedding: the query embeddings. + :return: a dictionary with the following keys: + - "documents": List of documents returned by the search engine. """ top_k = top_k or self.top_k diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 3249ec4f8..0da89c87e 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -19,6 +19,8 @@ class ChromaDocumentStore: """ + A document store using [Chroma](https://docs.trychroma.com/) as the backend. + We use the `collection.get` API to implement the document store protocol, the `collection.search` API will be used in the retriever instead. 
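+
+    Usage example (a minimal sketch; it assumes the default in-memory Chroma setup
+    and illustrative document contents):
+    ```python
+    from haystack import Document
+    from haystack_integrations.document_stores.chroma import ChromaDocumentStore
+
+    document_store = ChromaDocumentStore()
+    document_store.write_documents([
+        Document(content="This contains variable declarations"),
+        Document(content="This contains just text"),
+    ])
+    print(document_store.count_documents())
+    ```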
""" @@ -38,6 +40,11 @@ def __init__( Note: for the component to be part of a serializable pipeline, the __init__ parameters must be serializable, reason why we use a registry to configure the embedding function passing a string. + + :param collection_name: the name of the collection to use in the database. + :param embedding_function: the name of the embedding function to use to embed the query + :param persist_path: where to store the database. If None, the database will be `in-memory`. + :param embedding_function_params: additional parameters to pass to the embedding function. """ # Store the params for marshalling self._collection_name = collection_name @@ -56,6 +63,8 @@ def __init__( def count_documents(self) -> int: """ Returns how many documents are present in the document store. + + :returns: how many documents are present in the document store. """ return self._collection.count() @@ -128,7 +137,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc ``` :param filters: the filters to apply to the document list. - :return: a list of Documents that match the given filters. + :returns: a list of Documents that match the given filters. """ if filters: ids, where, where_document = self._normalize_filters(filters) @@ -152,7 +161,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D :param documents: a list of documents. :param policy: not supported at the moment :raises DuplicateDocumentError: Exception trigger on duplicate document if `policy=DuplicatePolicy.FAIL` - :return: None + :returns: None """ for doc in documents: if not isinstance(doc, Document): @@ -177,15 +186,17 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D def delete_documents(self, document_ids: List[str]) -> None: """ Deletes all documents with a matching document_ids from the document store. - Fails with `MissingDocumentError` if no document with this id is present in the store. :param document_ids: the object_ids to delete """ self._collection.delete(ids=document_ids) def search(self, queries: List[str], top_k: int) -> List[List[Document]]: - """ - Perform vector search on the stored documents + """Search the documents in the store using the provided text queries. + + :param queries: the list of queries to search for. + :param top_k: top_k documents to return for each query. + :return: matching documents for each query. """ results = self._collection.query( query_texts=queries, n_results=top_k, include=["embeddings", "documents", "metadatas", "distances"] @@ -196,10 +207,14 @@ def search_embeddings( self, query_embeddings: List[List[float]], top_k: int, filters: Optional[Dict[str, Any]] = None ) -> List[List[Document]]: """ - Perform vector search on the stored document, pass the embeddings of the queries - instead of their text. + Perform vector search on the stored document, pass the embeddings of the queries instead of their text. + + :param query_embeddings: a list of embeddings to use as queries. + :param top_k: the maximum number of documents to retrieve. + :param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format. + + :returns: a list of lists of documents that match the given filters. - Accepts filters in haystack format. """ if filters is None: results = self._collection.query( @@ -221,9 +236,23 @@ def search_embeddings( @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ChromaDocumentStore": + """ + Deserializes the component from a dictionary. 
+ + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ return ChromaDocumentStore(**data) def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ return { "collection_name": self._collection_name, "embedding_function": self._embedding_function, diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/errors.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/errors.py index aeb0230cd..7596a265f 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/errors.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/errors.py @@ -6,12 +6,18 @@ class ChromaDocumentStoreError(DocumentStoreError): + """Parent class for all ChromaDocumentStore exceptions.""" + pass class ChromaDocumentStoreFilterError(FilterError, ValueError): + """Raised when a filter is not valid for a ChromaDocumentStore.""" + pass class ChromaDocumentStoreConfigError(ChromaDocumentStoreError): + """Raised when a configuration is not valid for a ChromaDocumentStore.""" + pass diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/utils.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/utils.py index b1a354af4..08d6db618 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/utils.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/utils.py @@ -32,6 +32,13 @@ def get_embedding_function(function_name: str, **kwargs) -> EmbeddingFunction: + """Load an embedding function by name. + + :param function_name: the name of the embedding function. + :param kwargs: additional arguments to pass to the embedding function. + :returns: the loaded embedding function. + :raises ChromaDocumentStoreConfigError: if the function name is invalid. + """ try: return FUNCTION_REGISTRY[function_name](**kwargs) except KeyError: diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py index b09258128..c7a249f6c 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py @@ -15,9 +15,10 @@ class CohereDocumentEmbedder: """ A component for computing Document embeddings using Cohere models. + The embedding of each Document is stored in the `embedding` field of the Document. - Usage Example: + Usage example: ```python from haystack import Document from cohere_haystack.embedders.document_embedder import CohereDocumentEmbedder @@ -49,32 +50,30 @@ def __init__( embedding_separator: str = "\n", ): """ - Create a CohereDocumentEmbedder component. - - :param api_key: The Cohere API key. - :param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are: + :param api_key: the Cohere API key. + :param model: the name of the model to use. Supported Models are: `"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`, `"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`, `"embed-multilingual-v2.0"`. This list of all supported models can be found in the [model documentation](https://docs.cohere.com/docs/models#representation). 
- :param input_type: Specifies the type of input you're giving to the model. Supported values are - "search_document", "search_query", "classification" and "clustering". Defaults to "search_document". Not - required for older versions of the embedding models (meaning anything lower than v3), but is required for more - recent versions (meaning anything bigger than v2). - :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. - :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to - `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both + :param input_type: specifies the type of input you're giving to the model. Supported values are + "search_document", "search_query", "classification" and "clustering". Not + required for older versions of the embedding models (meaning anything lower than v3), but is required for more + recent versions (meaning anything bigger than v2). + :param api_base_url: the Cohere API Base url. + :param truncate: truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"). + Passing "START" will discard the start of the input. "END" will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. - If NONE is selected, when the input exceeds the maximum input token length an error will be returned. - :param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use + If "NONE" is selected, when the input exceeds the maximum input token length an error will be returned. + :param use_async_client: flag to select the AsyncClient. It is recommended to use AsyncClient for applications with many concurrent calls. - :param max_retries: maximal number of retries for requests, defaults to `3`. - :param timeout: request timeout in seconds, defaults to `120`. - :param batch_size: Number of Documents to encode at once. - :param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments + :param max_retries: maximal number of retries for requests. + :param timeout: request timeout in seconds. + :param batch_size: number of Documents to encode at once. + :param progress_bar: whether to show a progress bar or not. Can be helpful to disable in production deployments to keep the logs clean. - :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text. - :param embedding_separator: Separator used to concatenate the meta fields to the Document text. + :param meta_fields_to_embed: list of meta fields that should be embedded along with the Document text. + :param embedding_separator: separator used to concatenate the meta fields to the Document text. """ self.api_key = api_key @@ -92,7 +91,10 @@ def __init__( def to_dict(self) -> Dict[str, Any]: """ - Serialize this component to a dictionary omitting the api_key field. + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. """ return default_to_dict( self, @@ -113,9 +115,12 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "CohereDocumentEmbedder": """ - Deserialize this component from a dictionary. - :param data: The dictionary representation of this component. - :return: The deserialized component instance. + Deserializes the component from a dictionary. 
+
+        :param data:
+            Dictionary to deserialize from.
+        :returns:
+            Deserialized component.
         """
         init_params = data.get("init_parameters", {})
         deserialize_secrets_inplace(init_params, ["api_key"])
@@ -137,13 +142,14 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:

     @component.output_types(documents=List[Document], meta=Dict[str, Any])
     def run(self, documents: List[Document]):
-        """
-        Embed a list of Documents.
-        The embedding of each Document is stored in the `embedding` field of the Document.
+        """Embed a list of `Documents`.

-        :param documents: A list of Documents to embed.
+        :param documents: documents to embed.
+        :returns: A dictionary with the following keys:
+            - `documents`: documents with the `embedding` field set.
+            - `meta`: metadata about the embedding process.
+        :raises TypeError: if the input is not a list of `Documents`.
         """
-
         if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
             msg = (
                 "CohereDocumentEmbedder expects a list of Documents as input."
diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py
index 448a49dec..2aa779771 100644
--- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py
+++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py
@@ -16,9 +16,9 @@ class CohereTextEmbedder:
     """
     A component for embedding strings using Cohere models.

-    Usage Example:
+    Usage example:
     ```python
-    from cohere_haystack.embedders.text_embedder import CohereTextEmbedder
+    from haystack_integrations.components.embedders.cohere import CohereTextEmbedder

     text_to_embed = "I love pizza!"
     text_embedder = CohereTextEmbedder()
@@ -43,27 +43,25 @@ def __init__(
         timeout: int = 120,
     ):
         """
-        Create a CohereTextEmbedder component.
-
-        :param api_key: The Cohere API key.
-        :param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
+        :param api_key: the Cohere API key.
+        :param model: the name of the model to use. Supported Models are:
             `"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`,
             `"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`,
             `"embed-multilingual-v2.0"`. This list of all supported models can be found in the
             [model documentation](https://docs.cohere.com/docs/models#representation).
-        :param input_type: Specifies the type of input you're giving to the model. Supported values are
-            "search_document", "search_query", "classification" and "clustering". Defaults to "search_document". Not
-            required for older versions of the embedding models (meaning anything lower than v3), but is required for more
-            recent versions (meaning anything bigger than v2).
-        :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`.
-        :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to
-            `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both
+        :param input_type: specifies the type of input you're giving to the model. Supported values are
+            "search_document", "search_query", "classification" and "clustering". Not
+            required for older versions of the embedding models (meaning anything lower than v3), but is required for more
+            recent versions (meaning anything bigger than v2).
+ :param api_base_url: the Cohere API Base url. + :param truncate: truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to + `"END"`. Passing "START" will discard the start of the input. "END" will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. - If NONE is selected, when the input exceeds the maximum input token length an error will be returned. - :param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use + If "NONE" is selected, when the input exceeds the maximum input token length an error will be returned. + :param use_async_client: flag to select the AsyncClient. It is recommended to use AsyncClient for applications with many concurrent calls. - :param max_retries: Maximum number of retries for requests, defaults to `3`. - :param timeout: Request timeout in seconds, defaults to `120`. + :param max_retries: maximum number of retries for requests. + :param timeout: request timeout in seconds. """ self.api_key = api_key @@ -77,7 +75,10 @@ def __init__( def to_dict(self) -> Dict[str, Any]: """ - Serialize this component to a dictionary omitting the api_key field. + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. """ return default_to_dict( self, @@ -94,9 +95,12 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "CohereTextEmbedder": """ - Deserialize this component from a dictionary. - :param data: The dictionary representation of this component. - :return: The deserialized component instance. + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. """ init_params = data.get("init_parameters", {}) deserialize_secrets_inplace(init_params, ["api_key"]) @@ -104,7 +108,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereTextEmbedder": @component.output_types(embedding=List[float], meta=Dict[str, Any]) def run(self, text: str): - """Embed a string.""" + """Embed text. + + :param text: the text to embed. + :returns: A dictionary with the following keys: + - "embedding": the embedding of the text. + - "meta": metadata about the request. + :raises TypeError: If the input is not a string. + """ if not isinstance(text, str): msg = ( "CohereTextEmbedder expects a string as input." diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py index 9c16ecee7..21a65e3da 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py @@ -9,6 +9,19 @@ async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, input_type, truncate): + """Embeds a list of texts asynchronously using the Cohere API. + + :param cohere_async_client: the Cohere `AsyncClient` + :param texts: the texts to embed + :param model_name: the name of the model to use + :param input_type: one of "classification", "clustering", "search_document", "search_query". + The type of input text provided to embed. + :param truncate: one of "NONE", "START", "END". How the API handles text longer than the maximum token length. + + :returns: A tuple of the embeddings and metadata. 
+ + :raises ValueError: If an error occurs while querying the Cohere API. + """ all_embeddings: List[List[float]] = [] metadata: Dict[str, Any] = {} try: @@ -30,9 +43,22 @@ async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], def get_response( cohere_client: Client, texts: List[str], model_name, input_type, truncate, batch_size=32, progress_bar=False ) -> Tuple[List[List[float]], Dict[str, Any]]: + """Embeds a list of texts using the Cohere API. + + :param cohere_client: the Cohere `Client` + :param texts: the texts to embed + :param model_name: the name of the model to use + :param input_type: one of "classification", "clustering", "search_document", "search_query". + The type of input text provided to embed. + :param truncate: one of "NONE", "START", "END". How the API handles text longer than the maximum token length. + :param batch_size: the batch size to use + :param progress_bar: if `True`, show a progress bar + + :returns: A tuple of the embeddings and metadata. + + :raises ValueError: If an error occurs while querying the Cohere API. """ - We support batching with the sync client. - """ + all_embeddings: List[List[float]] = [] metadata: Dict[str, Any] = {} diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index bbbdb7e47..dcee45d10 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -14,8 +14,10 @@ @component class CohereChatGenerator: - """Enables text generation using Cohere's chat endpoint. This component is designed to inference - Cohere's chat models. + """ + Enables text generation using Cohere's chat endpoint. + + This component is designed to inference Cohere's chat models. Users can pass any text generation parameters valid for the `cohere.Client,chat` method directly to this component via the `**generation_kwargs` parameter in __init__ or the `**generation_kwargs` @@ -23,6 +25,16 @@ class CohereChatGenerator: Invocations are made using 'cohere' package. See [Cohere API](https://docs.cohere.com/reference/chat) for more details. + + Example usage: + ```python + from haystack_integrations.components.generators.cohere import CohereChatGenerator + + component = CohereChatGenerator(api_key=Secret.from_token("test-api-key")) + response = component.run(chat_messages) + + assert response["replies"] + ``` """ def __init__( @@ -37,12 +49,12 @@ def __init__( """ Initialize the CohereChatGenerator instance. - :param api_key: The API key for the Cohere API. + :param api_key: the API key for the Cohere API. :param model: The name of the model to use. Available models are: [command, command-light, command-nightly, - command-nightly-light]. Defaults to "command". - :param streaming_callback: A callback function to be called with the streaming response. Defaults to None. - :param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai". - :param generation_kwargs: Additional model parameters. These will be used during generation. Refer to + command-nightly-light]. + :param streaming_callback: a callback function to be called with the streaming response. + :param api_base_url: the base URL of the Cohere API. + :param generation_kwargs: additional model parameters. These will be used during generation. 
Refer to https://docs.cohere.com/reference/chat for more details. Some of the parameters are: - 'chat_history': A list of previous messages between the user and the model, meant to give the model @@ -89,8 +101,10 @@ def _get_telemetry_data(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]: """ - Serialize this component to a dictionary. - :return: The serialized component as a dictionary. + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. """ callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None return default_to_dict( @@ -105,9 +119,12 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "CohereChatGenerator": """ - Deserialize this component from a dictionary. - :param data: The dictionary representation of this component. - :return: The deserialized component instance. + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. """ init_params = data.get("init_parameters", {}) deserialize_secrets_inplace(init_params, ["api_key"]) @@ -126,12 +143,13 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, """ Invoke the text generation inference based on the provided messages and generation parameters. - :param messages: A list of ChatMessage instances representing the input messages. - :param generation_kwargs: Additional keyword arguments for text generation. These parameters will - potentially override the parameters passed in the __init__ method. - For more details on the parameters supported by the Cohere API, refer to the - Cohere [documentation](https://docs.cohere.com/reference/chat). - :return: A list containing the generated responses as ChatMessage instances. + :param messages: list of `ChatMessage` instances representing the input messages. + :param generation_kwargs: additional keyword arguments for text generation. These parameters will + potentially override the parameters passed in the __init__ method. + For more details on the parameters supported by the Cohere API, refer to the + Cohere [documentation](https://docs.cohere.com/reference/chat). + :returns: A dictionary with the following keys: + - "replies": a list of `ChatMessage` instances representing the generated responses. """ # update generation kwargs by merging with the generation kwargs passed to the run method generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py index 4811d3581..6a25727dd 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py @@ -25,7 +25,8 @@ class CohereGenerator: Example usage: ```python - from haystack.generators import CohereGenerator + from haystack_integrations.components.generators.cohere import CohereGenerator + generator = CohereGenerator(api_key="test-api-key") generator.run(prompt="What's the capital of France?") ``` @@ -42,12 +43,12 @@ def __init__( """ Instantiates a `CohereGenerator` component. - :param api_key: The API key for the Cohere API. - :param model: The name of the model to use. Available models are: [command, command-light, command-nightly, - command-nightly-light]. 
Defaults to "command". - :param streaming_callback: A callback function to be called with the streaming response. Defaults to None. - :param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai". - :param kwargs: Additional model parameters. These will be used during generation. Refer to + :param api_key: the API key for the Cohere API. + :param model: the name of the model to use. Available models are: [command, command-light, command-nightly, + command-nightly-light]. + :param streaming_callback: A callback function to be called with the streaming response. + :param api_base_url: the base URL of the Cohere API. + :param kwargs: additional model parameters. These will be used during generation. Refer to https://docs.cohere.com/reference/generate for more details. Some of the parameters are: - 'max_tokens': The maximum number of tokens to be generated. Defaults to 1024. @@ -87,7 +88,10 @@ def __init__( def to_dict(self) -> Dict[str, Any]: """ - Serialize this component to a dictionary. + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. """ return default_to_dict( self, @@ -101,7 +105,12 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator": """ - Deserialize this component from a dictionary. + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. """ init_params = data.get("init_parameters", {}) deserialize_secrets_inplace(init_params, ["api_key"]) @@ -115,7 +124,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator": def run(self, prompt: str): """ Queries the LLM with the prompts to produce replies. - :param prompt: The prompt to be sent to the generative model. + + :param prompt: the prompt to be sent to the generative model. + :returns: A dictionary with the following keys: + - "replies": the list of replies generated by the model. + - "meta": metadata about the request. """ response = self.client.generate( model=self.model, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters diff --git a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py index 2ecc01bc3..6bcd94220 100644 --- a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py +++ b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py @@ -20,12 +20,14 @@ class JinaDocumentEmbedder: Usage example: ```python from haystack import Document - from jina_haystack import JinaDocumentEmbedder + from haystack_integrations.components.embedders.jina import JinaDocumentEmbedder - doc = Document(content="I love pizza!") + # Make sure that the environment variable JINA_API_KEY is set document_embedder = JinaDocumentEmbedder() + doc = Document(content="I love pizza!") + result = document_embedder.run([doc]) print(result['documents'][0].embedding) @@ -46,8 +48,10 @@ def __init__( ): """ Create a JinaDocumentEmbedder component. + :param api_key: The Jina API key. - :param model: The name of the Jina model to use. Check the list of available models on `https://jina.ai/embeddings/` + :param model: The name of the Jina model to use. + Check the list of available models on [Jina documentation](https://jina.ai/embeddings/). :param prefix: A string to add to the beginning of each text. 
:param suffix: A string to add to the end of each text. :param batch_size: Number of Documents to encode at once. @@ -83,8 +87,9 @@ def _get_telemetry_data(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]: """ - This method overrides the default serializer in order to avoid leaking the `api_key` value passed - to the constructor. + Serializes the component to a dictionary. + :returns: + Dictionary with serialized data. """ return default_to_dict( self, @@ -100,6 +105,13 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "JinaDocumentEmbedder": + """ + Deserializes the component from a dictionary. + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) @@ -151,10 +163,13 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List @component.output_types(documents=List[Document], meta=Dict[str, Any]) def run(self, documents: List[Document]): """ - Embed a list of Documents. - The embedding of each Document is stored in the `embedding` field of the Document. + Compute the embeddings for a list of Documents. :param documents: A list of Documents to embed. + :returns: A dictionary with following keys: + - `documents`: List of Documents, each with an `embedding` field containing the computed embedding. + - `meta`: A dictionary with metadata including the model name and usage statistics. + :raises TypeError: If the input is not a list of Documents. """ if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): msg = ( diff --git a/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py index f99882f13..6398122a4 100644 --- a/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py +++ b/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py @@ -13,16 +13,18 @@ @component class JinaTextEmbedder: """ - A component for embedding strings using Jina models. + A component for embedding strings using Jina AI models. Usage example: ```python - from jina_haystack import JinaTextEmbedder + from haystack_integrations.components.embedders.jina import JinaTextEmbedder - text_to_embed = "I love pizza!" + # Make sure that the environment variable JINA_API_KEY is set text_embedder = JinaTextEmbedder() + text_to_embed = "I love pizza!" + print(text_embedder.run(text_to_embed)) # {'embedding': [0.017020374536514282, -0.023255806416273117, ...], @@ -39,11 +41,12 @@ def __init__( suffix: str = "", ): """ - Create an JinaTextEmbedder component. + Create a JinaTextEmbedder component. :param api_key: The Jina API key. It can be explicitly provided or automatically read from the - environment variable JINA_API_KEY (recommended). - :param model: The name of the Jina model to use. Check the list of available models on `https://jina.ai/embeddings/` + environment variable `JINA_API_KEY` (recommended). + :param model: The name of the Jina model to use. + Check the list of available models on [Jina documentation](https://jina.ai/embeddings/). :param prefix: A string to add to the beginning of each text. :param suffix: A string to add to the end of each text. 
""" @@ -71,22 +74,37 @@ def _get_telemetry_data(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]: """ - This method overrides the default serializer in order to avoid leaking the `api_key` value passed - to the constructor. + Serializes the component to a dictionary. + :returns: + Dictionary with serialized data. """ - return default_to_dict( self, api_key=self.api_key.to_dict(), model=self.model_name, prefix=self.prefix, suffix=self.suffix ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "JinaTextEmbedder": + """ + Deserializes the component from a dictionary. + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) @component.output_types(embedding=List[float], meta=Dict[str, Any]) def run(self, text: str): - """Embed a string.""" + """ + Embed a string. + + :param text: The string to embed. + :returns: A dictionary with following keys: + - `embedding`: The embedding of the input string. + - `meta`: A dictionary with metadata including the model name and usage statistics. + :raises TypeError: If the input is not a string. + """ if not isinstance(text, str): msg = ( "JinaTextEmbedder expects a string as an input." diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py index 6e3273e1c..b5783c611 100644 --- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py @@ -7,6 +7,23 @@ @component class OllamaDocumentEmbedder: + """ + Computes the embeddings of a list of Documents and stores the obtained vectors in the embedding field of each + Document. It uses embedding models compatible with the Ollama Library. + + Usage example: + ```python + from haystack import Document + from haystack_integrations.components.embedders.ollama import OllamaDocumentEmbedder + + doc = Document(content="What do llamas say once you have thanked them? No probllama!") + document_embedder = OllamaDocumentEmbedder() + + result = document_embedder.run([doc]) + print(result['documents'][0].embedding) + ``` + """ + def __init__( self, model: str = "nomic-embed-text", @@ -20,15 +37,16 @@ def __init__( embedding_separator: str = "\n", ): """ - :param model: The name of the model to use. The model should be available in the running Ollama instance. - Default is "nomic-embed-text". "https://ollama.com/library/nomic-embed-text" - :param url: The URL of the chat endpoint of a running Ollama instance. - Default is "http://localhost:11434/api/embeddings". - :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature, - top_p, and others. See the available arguments in + :param model: + The name of the model to use. The model should be available in the running Ollama instance. + :param url: + The URL of the chat endpoint of a running Ollama instance. + :param generation_kwargs: + Optional arguments to pass to the Ollama generation endpoint, such as temperature, top_p, and others. + See the available arguments in [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). - :param timeout: The number of seconds before throwing a timeout error from the Ollama API. - Default is 120 seconds. 
+        :param timeout:
+            The number of seconds before throwing a timeout error from the Ollama API.
         """
         self.timeout = timeout
         self.generation_kwargs = generation_kwargs or {}
@@ -44,15 +62,12 @@ def __init__(
     def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
         """
         Returns A dictionary of JSON arguments for a POST request to an Ollama service
-        :param text: Text that is to be converted to an embedding
-        :param generation_kwargs:
-        :return: A dictionary of arguments for a POST request to an Ollama service
         """
         return {"model": self.model, "prompt": text, "options": {**self.generation_kwargs, **(generation_kwargs or {})}}

     def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
         """
-        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
+        Prepares the texts to embed by concatenating the Document text with the metadata fields to embed.
         """
         texts_to_embed = []
         for doc in documents:
@@ -101,12 +116,17 @@ def _embed_batch(

     @component.output_types(documents=List[Document], meta=Dict[str, Any])
     def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
         """
-        Run an Ollama Model on a provided documents.
-        :param documents: Documents to be converted to an embedding.
-        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+        Runs an Ollama Model to compute embeddings of the provided documents.
+
+        :param documents:
+            Documents to be converted to an embedding.
+        :param generation_kwargs:
+            Optional arguments to pass to the Ollama generation endpoint, such as temperature,
             top_p, etc. See the
             [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
-        :return: Documents with embedding information attached and metadata in a dictionary
+        :returns: A dictionary with the following keys:
+            - `documents`: Documents with embedding information attached
+            - `meta`: The metadata collected during the embedding process
         """
         if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
             msg = (
diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py
index e2ef136b4..5a28ba393 100644
--- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py
+++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py
@@ -6,6 +6,20 @@

 @component
 class OllamaTextEmbedder:
+    """
+    Computes the embedding of a string using embedding models compatible with the Ollama Library.
+
+    Usage example:
+    ```python
+    from haystack_integrations.components.embedders.ollama import OllamaTextEmbedder
+
+    embedder = OllamaTextEmbedder()
+    result = embedder.run(text="What do llamas say once you have thanked them? No probllama!")
+    print(result['embedding'])
+    ```
+    """
+
    def __init__(
        self,
        model: str = "nomic-embed-text",
@@ -14,15 +28,16 @@ def __init__(
        timeout: int = 120,
    ):
        """
-        :param model: The name of the model to use. The model should be available in the running Ollama instance.
-            Default is "nomic-embed-text". "https://ollama.com/library/nomic-embed-text"
-        :param url: The URL of the chat endpoint of a running Ollama instance.
- Default is "http://localhost:11434/api/embeddings". - :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature, + :param model: + The name of the model to use. The model should be available in the running Ollama instance. + :param url: + The URL of the chat endpoint of a running Ollama instance. + :param generation_kwargs: + Optional arguments to pass to the Ollama generation endpoint, such as temperature, top_p, and others. See the available arguments in [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). - :param timeout: The number of seconds before throwing a timeout error from the Ollama API. - Default is 120 seconds. + :param timeout: + The number of seconds before throwing a timeout error from the Ollama API. """ self.timeout = timeout self.generation_kwargs = generation_kwargs or {} @@ -32,21 +47,23 @@ def __init__( def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]: """ Returns A dictionary of JSON arguments for a POST request to an Ollama service - :param text: Text that is to be converted to an embedding - :param generation_kwargs: - :return: A dictionary of arguments for a POST request to an Ollama service """ return {"model": self.model, "prompt": text, "options": {**self.generation_kwargs, **(generation_kwargs or {})}} @component.output_types(embedding=List[float], meta=Dict[str, Any]) def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None): """ - Run an Ollama Model on a given chat history. - :param text: Text to be converted to an embedding. - :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature, + Runs an Ollama Model to compute embeddings of the provided text. + + :param text: + Text to be converted to an embedding. + :param generation_kwargs: + Optional arguments to pass to the Ollama generation endpoint, such as temperature, top_p, etc. See the [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). - :return: A dictionary with the key "embedding" and a list of floats as the value + :returns: A dictionary with the following keys: + - `embedding`: The computed embeddings + - `meta`: The metadata collected during the embedding process """ payload = self._create_json_payload(text, generation_kwargs) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index 6a8c5493b..2abf3066b 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -9,8 +9,26 @@ @component class OllamaChatGenerator: """ - Chat Generator based on Ollama. Ollama is a library for easily running LLMs locally. - This component provides an interface to generate text using a LLM running in Ollama. + Supports models running on Ollama, such as llama2 and mixtral. Find the full list of supported models + [here](https://ollama.ai/library). 
+
+    Usage example:
+    ```python
+    from haystack_integrations.components.generators.ollama import OllamaChatGenerator
+    from haystack.dataclasses import ChatMessage
+
+    generator = OllamaChatGenerator(model="zephyr",
+                url="http://localhost:11434/api/chat",
+                generation_kwargs={
+                  "num_predict": 100,
+                  "temperature": 0.9,
+                })
+
+    messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"),
+    ChatMessage.from_user("What's Natural Language Processing?")]
+
+    print(generator.run(messages=messages))
+    ```
     """

     def __init__(
@@ -22,16 +40,18 @@ def __init__(
         timeout: int = 120,
     ):
         """
-        :param model: The name of the model to use. The model should be available in the running Ollama instance.
-            Default is "orca-mini".
-        :param url: The URL of the chat endpoint of a running Ollama instance.
-            Default is "http://localhost:11434/api/chat".
-        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+        :param model:
+            The name of the model to use. The model should be available in the running Ollama instance.
+        :param url:
+            The URL of the chat endpoint of a running Ollama instance.
+        :param generation_kwargs:
+            Optional arguments to pass to the Ollama generation endpoint, such as temperature,
             top_p, and others. See the available arguments in
             [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
-        :param template: The full prompt template (overrides what is defined in the Ollama Modelfile).
-        :param timeout: The number of seconds before throwing a timeout error from the Ollama API.
-            Default is 120 seconds.
+        :param template:
+            The full prompt template (overrides what is defined in the Ollama Modelfile).
+        :param timeout:
+            The number of seconds before throwing a timeout error from the Ollama API.
         """

         self.timeout = timeout
@@ -46,9 +66,6 @@ def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]:
     def _create_json_payload(self, messages: List[ChatMessage], generation_kwargs=None) -> Dict[str, Any]:
         """
         Returns A dictionary of JSON arguments for a POST request to an Ollama service
-        :param messages: A history/list of chat messages
-        :param generation_kwargs:
-        :return: A dictionary of arguments for a POST request to an Ollama service
         """
         generation_kwargs = generation_kwargs or {}
         return {
@@ -62,8 +79,6 @@ def _create_json_payload(self, messages: List[ChatMessage], generation_kwargs=No
     def _build_message_from_ollama_response(self, ollama_response: Response) -> ChatMessage:
         """
         Converts the non-streaming response from the Ollama API to a ChatMessage.
-        :param ollama_response: The completion returned by the Ollama API.
-        :return: The ChatMessage.
         """
         json_content = ollama_response.json()
         message = ChatMessage.from_assistant(content=json_content["message"]["content"])
@@ -77,12 +92,16 @@ def run(
         generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
-        Run an Ollama Model on a given chat history.
-        :param messages: A list of ChatMessage instances representing the input messages.
-        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+        Runs an Ollama Model on a given chat history.
+
+        :param messages:
+            A list of ChatMessage instances representing the input messages.
+        :param generation_kwargs:
+            Optional arguments to pass to the Ollama generation endpoint, such as temperature,
             top_p, etc.
            See the [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
-        :return: A dictionary of the replies containing their metadata
+        :returns: A dictionary with the following keys:
+            - `replies`: The responses from the model
         """
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py
index f3ab86282..bbd7b05ca 100644
--- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py
+++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py
@@ -11,8 +11,21 @@
 @component
 class OllamaGenerator:
     """
-    Generator based on Ollama. Ollama is a library for easily running LLMs locally.
-    This component provides an interface to generate text using a LLM running in Ollama.
+    Provides an interface to generate text using an LLM running on Ollama.
+
+    Usage example:
+    ```python
+    from haystack_integrations.components.generators.ollama import OllamaGenerator
+
+    generator = OllamaGenerator(model="zephyr",
+                url="http://localhost:11434/api/generate",
+                generation_kwargs={
+                  "num_predict": 100,
+                  "temperature": 0.9,
+                })
+
+    print(generator.run("Who is the best American actor?"))
+    ```
     """

     def __init__(
@@ -27,20 +40,25 @@ def __init__(
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
     ):
         """
-        :param model: The name of the model to use. The model should be available in the running Ollama instance.
-            Default is "orca-mini".
-        :param url: The URL of the generation endpoint of a running Ollama instance.
-            Default is "http://localhost:11434/api/generate".
-        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+        :param model:
+            The name of the model to use. The model should be available in the running Ollama instance.
+        :param url:
+            The URL of the generation endpoint of a running Ollama instance.
+        :param generation_kwargs:
+            Optional arguments to pass to the Ollama generation endpoint, such as temperature,
             top_p, and others. See the available arguments in
             [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
-        :param system_prompt: Optional system message (overrides what is defined in the Ollama Modelfile).
-        :param template: The full prompt template (overrides what is defined in the Ollama Modelfile).
-        :param raw: If True, no formatting will be applied to the prompt. You may choose to use the raw parameter
+        :param system_prompt:
+            Optional system message (overrides what is defined in the Ollama Modelfile).
+        :param template:
+            The full prompt template (overrides what is defined in the Ollama Modelfile).
+        :param raw:
+            If True, no formatting will be applied to the prompt. You may choose to use the raw parameter
             if you are specifying a full templated prompt in your API request.
-        :param timeout: The number of seconds before throwing a timeout error from the Ollama API.
-            Default is 120 seconds.
-        :param streaming_callback: A callback function that is called when a new token is received from the stream.
+        :param timeout:
+            The number of seconds before throwing a timeout error from the Ollama API.
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
         """
         self.timeout = timeout
@@ -54,8 +72,10 @@ def __init__(

     def to_dict(self) -> Dict[str, Any]:
         """
-        Serialize this component to a dictionary.
-        :return: The serialized component as a dictionary.
+        Serializes the component to a dictionary.
+
+        :returns:
+            Dictionary with serialized data.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
         return default_to_dict(
@@ -73,9 +93,12 @@ def to_dict(self) -> Dict[str, Any]:
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator":
         """
-        Deserialize this component from a dictionary.
-        :param data: The dictionary representation of this component.
-        :return: The deserialized component instance.
+        Deserializes the component from a dictionary.
+
+        :param data:
+            Dictionary to deserialize from.
+        :returns:
+            Deserialized component.
         """
         init_params = data.get("init_parameters", {})
         serialized_callback_handler = init_params.get("streaming_callback")
@@ -86,11 +109,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator":
     def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None) -> Dict[str, Any]:
         """
         Returns a dictionary of JSON arguments for a POST request to an Ollama service.
-        :param prompt: The prompt to generate a response for.
-        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
-            top_p, and others. See the available arguments in
-            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
-        :return: A dictionary of arguments for a POST request to an Ollama service.
         """
         generation_kwargs = generation_kwargs or {}
         return {
@@ -105,9 +123,7 @@ def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None

     def _convert_to_response(self, ollama_response: Response) -> Dict[str, List[Any]]:
         """
-        Convert a response from the Ollama API to the required Haystack format.
-        :param ollama_response: A response (requests library) from the Ollama API.
-        :return: A dictionary of the returned responses and metadata.
+        Converts a response from the Ollama API to the required Haystack format.
         """

         resp_dict = ollama_response.json()
@@ -119,9 +135,7 @@ def _convert_to_response(self, ollama_response: Response) -> Dict[str, List[Any]

     def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
         """
-        Convert a list of chunks response required Haystack format.
-        :param chunks: List of StreamingChunks
-        :return: A dictionary of the returned responses and metadata.
+        Converts a list of chunks into the required Haystack response format.
         """

         replies = ["".join([c.content for c in chunks])]
@@ -130,10 +144,8 @@ def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[s
         return {"replies": replies, "meta": [meta]}

     def _handle_streaming_response(self, response) -> List[StreamingChunk]:
-        """Handles Streaming response case
-
-        :param response: streaming response from ollama api.
-        :return: The List[StreamingChunk].
+        """
+        Handles streaming response cases.
         """
         chunks: List[StreamingChunk] = []
         for chunk in response.iter_lines():
@@ -146,8 +158,6 @@ def _handle_streaming_response(self, response) -> List[StreamingChunk]:
     def _build_chunk(self, chunk_response: Any) -> StreamingChunk:
         """
         Converts the response from the Ollama API to a StreamingChunk.
-        :param chunk: The chunk returned by the Ollama API.
-        :return: The StreamingChunk.
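+        Each streamed line is expected to be a JSON object in the Ollama streaming format,
+        e.g. `{"model": "zephyr", "response": "...", "done": false}` (illustrative values),
+        where the `response` field carries the text content of the chunk.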
""" decoded_chunk = json.loads(chunk_response.decode("utf-8")) @@ -164,12 +174,17 @@ def run( generation_kwargs: Optional[Dict[str, Any]] = None, ): """ - Run an Ollama Model on the given prompt. - :param prompt: The prompt to generate a response for. - :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature, + Runs an Ollama Model on the given prompt. + + :param prompt: + The prompt to generate a response for. + :param generation_kwargs: + Optional arguments to pass to the Ollama generation endpoint, such as temperature, top_p, and others. See the available arguments in [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). - :return: A dictionary of the response and returned metadata + :returns: A dictionary with the following keys: + - `replies`: The responses from the model + - `meta`: The metadata collected during the run """ generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} diff --git a/integrations/pgvector/examples/example.py b/integrations/pgvector/examples/example.py index 764c915d1..37ea88929 100644 --- a/integrations/pgvector/examples/example.py +++ b/integrations/pgvector/examples/example.py @@ -11,7 +11,6 @@ # git clone https://github.com/anakin87/neural-search-pills import glob -import os from haystack import Pipeline from haystack.components.converters import MarkdownToDocument @@ -21,7 +20,8 @@ from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore -os.environ["PG_CONN_STR"] = "postgresql://postgres:postgres@localhost:5432/postgres" +# Set an environment variable `PG_CONN_STR` with the connection string to your PostgreSQL database. +# e.g., "postgresql://USER:PASSWORD@HOST:PORT/DB_NAME" # Initialize PgvectorDocumentStore document_store = PgvectorDocumentStore( diff --git a/integrations/pgvector/pydoc/config.yml b/integrations/pgvector/pydoc/config.yml index ea354c14b..449937629 100644 --- a/integrations/pgvector/pydoc/config.yml +++ b/integrations/pgvector/pydoc/config.yml @@ -4,7 +4,6 @@ loaders: modules: [ "haystack_integrations.components.retrievers.pgvector.embedding_retriever", "haystack_integrations.document_stores.pgvector.document_store", - "haystack_integrations.document_stores.pgvector.filters", ] ignore_when_discovered: ["__init__"] processors: diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py index 4b8df868b..6085545cb 100644 --- a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py @@ -12,9 +12,47 @@ @component class PgvectorEmbeddingRetriever: """ - Retrieves documents from the PgvectorDocumentStore, based on their dense embeddings. + Retrieves documents from the `PgvectorDocumentStore`, based on their dense embeddings. - Needs to be connected to the PgvectorDocumentStore. 
+
+    Example usage:
+    ```python
+    from haystack.document_stores.types import DuplicatePolicy
+    from haystack import Document, Pipeline
+    from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
+
+    from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
+    from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever
+
+    # Set an environment variable `PG_CONN_STR` with the connection string to your PostgreSQL database.
+    # e.g., "postgresql://USER:PASSWORD@HOST:PORT/DB_NAME"
+
+    document_store = PgvectorDocumentStore(
+        embedding_dimension=768,
+        vector_function="cosine_similarity",
+        recreate_table=True,
+    )
+
+    documents = [Document(content="There are over 7,000 languages spoken around the world today."),
+                 Document(content="Elephants have been observed to behave in a way that indicates..."),
+                 Document(content="In certain places, you can witness the phenomenon of bioluminescent waves.")]
+
+    document_embedder = SentenceTransformersDocumentEmbedder()
+    document_embedder.warm_up()
+    documents_with_embeddings = document_embedder.run(documents)
+
+    document_store.write_documents(documents_with_embeddings.get("documents"), policy=DuplicatePolicy.OVERWRITE)
+
+    query_pipeline = Pipeline()
+    query_pipeline.add_component("text_embedder", SentenceTransformersTextEmbedder())
+    query_pipeline.add_component("retriever", PgvectorEmbeddingRetriever(document_store=document_store))
+    query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
+
+    query = "How many languages are there?"
+
+    res = query_pipeline.run({"text_embedder": {"text": query}})
+
+    assert res['retriever']['documents'][0].content == "There are over 7,000 languages spoken around the world today."
+    ```
     """

     def __init__(
@@ -26,23 +64,20 @@ def __init__(
         vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None,
     ):
         """
-        Create the PgvectorEmbeddingRetriever component.
-
-        :param document_store: An instance of PgvectorDocumentStore.
-        :param filters: Filters applied to the retrieved Documents. Defaults to None.
-        :param top_k: Maximum number of Documents to return, defaults to 10.
+        :param document_store: An instance of `PgvectorDocumentStore`.
+        :param filters: Filters applied to the retrieved Documents.
+        :param top_k: Maximum number of Documents to return.
         :param vector_function: The similarity function to use when searching for similar embeddings.
             Defaults to the one set in the `document_store` instance.
-            "cosine_similarity" and "inner_product" are similarity functions and
+            `"cosine_similarity"` and `"inner_product"` are similarity functions and
             higher scores indicate greater similarity between the documents.
-            "l2_distance" returns the straight-line distance between vectors,
+            `"l2_distance"` returns the straight-line distance between vectors,
             and the most similar documents are the ones with the smallest score.
-
-            Important: if the document store is using the "hnsw" search strategy, the vector function
+            **Important**: if the document store is using the `"hnsw"` search strategy, the vector function
             should match the one utilized during index creation to take advantage of the index.
-        :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"]
-        :raises ValueError: If `document_store` is not an instance of PgvectorDocumentStore.
+ :raises ValueError: If `document_store` is not an instance of `PgvectorDocumentStore` or if `vector_function` + is not one of the valid options. """ if not isinstance(document_store, PgvectorDocumentStore): msg = "document_store must be an instance of PgvectorDocumentStore" @@ -58,6 +93,12 @@ def __init__( self.vector_function = vector_function or document_store.vector_function def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ return default_to_dict( self, filters=self.filters, @@ -68,6 +109,14 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "PgvectorEmbeddingRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ doc_store_params = data["init_parameters"]["document_store"] data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params) return default_from_dict(cls, data) @@ -81,14 +130,14 @@ def run( vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, ): """ - Retrieve documents from the PgvectorDocumentStore, based on their embeddings. + Retrieve documents from the `PgvectorDocumentStore`, based on their embeddings. :param query_embedding: Embedding of the query. :param filters: Filters applied to the retrieved Documents. :param top_k: Maximum number of Documents to return. :param vector_function: The similarity function to use when searching for similar embeddings. - :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] - :return: List of Documents similar to `query_embedding`. + + :returns: List of Documents similar to `query_embedding`. """ filters = filters or self.filters top_k = top_k or self.top_k diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index 798c75276..3396c15ea 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -67,6 +67,10 @@ class PgvectorDocumentStore: + """ + A Document Store using PostgreSQL with the [pgvector extension](https://github.com/pgvector/pgvector) installed. + """ + def __init__( self, *, @@ -86,36 +90,33 @@ def __init__( A specific table to store Haystack documents will be created if it doesn't exist yet. :param connection_string: The connection string to use to connect to the PostgreSQL database, defined as an - environment variable, e.g.: PG_CONN_STR="postgresql://USER:PASSWORD@HOST:PORT/DB_NAME" - :param table_name: The name of the table to use to store Haystack documents. Defaults to "haystack_documents". - :param embedding_dimension: The dimension of the embedding. Defaults to 768. + environment variable, e.g.: `PG_CONN_STR="postgresql://USER:PASSWORD@HOST:PORT/DB_NAME"` + :param table_name: The name of the table to use to store Haystack documents. + :param embedding_dimension: The dimension of the embedding. :param vector_function: The similarity function to use when searching for similar embeddings. - Defaults to "cosine_similarity". 
"cosine_similarity" and "inner_product" are similarity functions and + `"cosine_similarity"` and `"inner_product"` are similarity functions and higher scores indicate greater similarity between the documents. - "l2_distance" returns the straight-line distance between vectors, + `"l2_distance"` returns the straight-line distance between vectors, and the most similar documents are the ones with the smallest score. - - Important: when using the "hnsw" search strategy, an index will be created that depends on the + **Important**: when using the `"hnsw"` search strategy, an index will be created that depends on the `vector_function` passed here. Make sure subsequent queries will keep using the same vector similarity function in order to take advantage of the index. - :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] - :param recreate_table: Whether to recreate the table if it already exists. Defaults to False. + :param recreate_table: Whether to recreate the table if it already exists. :param search_strategy: The search strategy to use when searching for similar embeddings. - Defaults to "exact_nearest_neighbor". "hnsw" is an approximate nearest neighbor search strategy, + `"exact_nearest_neighbor"` provides perfect recall but can be slow for large numbers of documents. + `"hnsw"` is an approximate nearest neighbor search strategy, which trades off some accuracy for speed; it is recommended for large numbers of documents. - - Important: when using the "hnsw" search strategy, an index will be created that depends on the + **Important**: when using the `"hnsw"` search strategy, an index will be created that depends on the `vector_function` passed here. Make sure subsequent queries will keep using the same vector similarity function in order to take advantage of the index. - :type search_strategy: Literal["exact_nearest_neighbor", "hnsw"] :param hnsw_recreate_index_if_exists: Whether to recreate the HNSW index if it already exists. - Defaults to False. Only used if search_strategy is set to "hnsw". + Only used if search_strategy is set to `"hnsw"`. :param hnsw_index_creation_kwargs: Additional keyword arguments to pass to the HNSW index creation. - Only used if search_strategy is set to "hnsw". You can find the list of valid arguments in the - pgvector documentation: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw - :param hnsw_ef_search: The ef_search parameter to use at query time. Only used if search_strategy is set to - "hnsw". You can find more information about this parameter in the pgvector documentation: - https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw + Only used if search_strategy is set to `"hnsw"`. You can find the list of valid arguments in the + [pgvector documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw) + :param hnsw_ef_search: The `ef_search` parameter to use at query time. Only used if search_strategy is set to + `"hnsw"`. You can find more information about this parameter in the + [pgvector documentation](https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw) """ self.connection_string = connection_string @@ -150,6 +151,12 @@ def __init__( self._handle_hnsw() def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. 
+ """ return default_to_dict( self, connection_string=self.connection_string.to_dict(), @@ -165,6 +172,14 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "PgvectorDocumentStore": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ deserialize_secrets_inplace(data["init_parameters"], ["connection_string"]) return default_from_dict(cls, data) @@ -209,6 +224,7 @@ def _create_table_if_not_exists(self): def delete_table(self): """ Deletes the table used to store Haystack documents. + The name of the table (`table_name`) is defined when initializing the `PgvectorDocumentStore`. """ delete_sql = SQL("DROP TABLE IF EXISTS {table_name}").format(table_name=Identifier(self.table_name)) @@ -218,7 +234,7 @@ def delete_table(self): def _handle_hnsw(self): """ Internal method to handle the HNSW index creation. - It also sets the hnsw.ef_search parameter for queries if it is specified. + It also sets the `hnsw.ef_search` parameter for queries if it is specified. """ if self.hnsw_ef_search: @@ -295,7 +311,8 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc refer to the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering) :param filters: The filters to apply to the document list. - :return: A list of Documents that match the given filters. + :raises TypeError: If `filters` is not a dictionary. + :returns: A list of Documents that match the given filters. """ if filters: if not isinstance(filters, dict): @@ -324,13 +341,13 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: """ - Writes documents into to PgvectorDocumentStore. + Writes documents to the document store. :param documents: A list of Documents to write to the document store. :param policy: The duplicate policy to use when writing documents. :raises DuplicateDocumentError: If a document with the same id already exists in the document store - and the policy is set to DuplicatePolicy.FAIL (or not specified). - :return: The number of documents written to the document store. + and the policy is set to `DuplicatePolicy.FAIL` (or not specified). + :returns: The number of documents written to the document store. """ if len(documents) > 0: @@ -432,7 +449,7 @@ def _from_pg_to_haystack_documents(documents: List[Dict[str, Any]]) -> List[Docu def delete_documents(self, document_ids: List[str]) -> None: """ - Deletes all documents with a matching document_ids from the document store. + Deletes documents that match the provided `document_ids` from the document store. :param document_ids: the document ids to delete """ @@ -462,8 +479,7 @@ def _embedding_retrieval( This method is not meant to be part of the public interface of `PgvectorDocumentStore` and it should not be called directly. `PgvectorEmbeddingRetriever` uses this method directly and is the public interface for it. 
- :raises ValueError - :return: List of Documents that are most similar to `query_embedding` + :returns: List of Documents that are most similar to `query_embedding` """ if not query_embedding: diff --git a/integrations/pinecone/pydoc/config.yml b/integrations/pinecone/pydoc/config.yml index 51ef0ee15..fff835877 100644 --- a/integrations/pinecone/pydoc/config.yml +++ b/integrations/pinecone/pydoc/config.yml @@ -4,9 +4,7 @@ loaders: modules: [ "haystack_integrations.components.retrievers.pinecone.embedding_retriever", - "haystack_integrations.document_stores.pinecone.document_store", - "haystack_integrations.document_stores.pinecone.errors", - "haystack_integrations.document_stores.pinecone.filters", + "haystack_integrations.document_stores.pinecone.document_store" ] ignore_when_discovered: ["__init__"] processors: diff --git a/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/embedding_retriever.py b/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/embedding_retriever.py index 840c9e1f6..02c4a3a87 100644 --- a/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/embedding_retriever.py +++ b/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/embedding_retriever.py @@ -12,9 +12,41 @@ @component class PineconeEmbeddingRetriever: """ - Retrieves documents from the PineconeDocumentStore, based on their dense embeddings. + Retrieves documents from the `PineconeDocumentStore`, based on their dense embeddings. - Needs to be connected to the PineconeDocumentStore. + Usage example: + ```python + import os + from haystack.document_stores.types import DuplicatePolicy + from haystack import Document + from haystack import Pipeline + from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder + from haystack_integrations.components.retrievers.pinecone import PineconeEmbeddingRetriever + from haystack_integrations.document_stores.pinecone import PineconeDocumentStore + + os.environ["PINECONE_API_KEY"] = "YOUR_PINECONE_API_KEY" + document_store = PineconeDocumentStore(index="my_index", namespace="my_namespace", dimension=768) + + documents = [Document(content="There are over 7,000 languages spoken around the world today."), + Document(content="Elephants have been observed to behave in a way that indicates..."), + Document(content="In certain places, you can witness the phenomenon of bioluminescent waves.")] + + document_embedder = SentenceTransformersDocumentEmbedder() + document_embedder.warm_up() + documents_with_embeddings = document_embedder.run(documents) + + document_store.write_documents(documents_with_embeddings.get("documents"), policy=DuplicatePolicy.OVERWRITE) + + query_pipeline = Pipeline() + query_pipeline.add_component("text_embedder", SentenceTransformersTextEmbedder()) + query_pipeline.add_component("retriever", PineconeEmbeddingRetriever(document_store=document_store)) + query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") + + query = "How many languages are there?" + + res = query_pipeline.run({"text_embedder": {"text": query}}) + assert res['retriever']['documents'][0].content == "There are over 7,000 languages spoken around the world today." + ``` """ def __init__( @@ -25,13 +57,11 @@ def __init__( top_k: int = 10, ): """ - Create the PineconeEmbeddingRetriever component. - - :param document_store: An instance of PineconeDocumentStore. - :param filters: Filters applied to the retrieved Documents. 
            Defaults to None.
-        :param top_k: Maximum number of Documents to return, defaults to 10.
+        :param document_store: The Pinecone Document Store.
+        :param filters: Filters applied to the retrieved Documents.
+        :param top_k: Maximum number of Documents to return.

-        :raises ValueError: If `document_store` is not an instance of PineconeDocumentStore.
+        :raises ValueError: If `document_store` is not an instance of `PineconeDocumentStore`.
         """
         if not isinstance(document_store, PineconeDocumentStore):
             msg = "document_store must be an instance of PineconeDocumentStore"
@@ -42,6 +72,11 @@ def __init__(
         self.top_k = top_k

     def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+        :returns:
+            Dictionary with serialized data.
+        """
         return default_to_dict(
             self,
             filters=self.filters,
@@ -51,6 +86,13 @@ def to_dict(self) -> Dict[str, Any]:

     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "PineconeEmbeddingRetriever":
+        """
+        Deserializes the component from a dictionary.
+        :param data:
+            Dictionary to deserialize from.
+        :returns:
+            Deserialized component.
+        """
         data["init_parameters"]["document_store"] = default_from_dict(
             PineconeDocumentStore, data["init_parameters"]["document_store"]
         )
@@ -59,10 +101,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "PineconeEmbeddingRetriever":
     @component.output_types(documents=List[Document])
     def run(self, query_embedding: List[float]):
         """
-        Retrieve documents from the PineconeDocumentStore, based on their dense embeddings.
+        Retrieve documents from the `PineconeDocumentStore`, based on their dense embeddings.

         :param query_embedding: Embedding of the query.
-        :return: List of Document similar to `query_embedding`.
+        :returns: List of Documents similar to `query_embedding`.
         """
         docs = self.document_store._embedding_retrieval(
             query_embedding=query_embedding,
diff --git a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py
index 91364d7bf..a23cf80f6 100644
--- a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py
+++ b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py
@@ -26,6 +26,10 @@

 class PineconeDocumentStore:
+    """
+    A Document Store using the [Pinecone vector database](https://www.pinecone.io/).
+    """
+
     def __init__(
         self,
         *,
@@ -42,20 +46,17 @@ def __init__(
         It is meant to be connected to a Pinecone index and namespace.

         :param api_key: The Pinecone API key. It can be explicitly provided or automatically read from the
-            environment variable PINECONE_API_KEY (recommended).
-        :param environment: The Pinecone environment to connect to. Defaults to "us-west1-gcp".
+            environment variable `PINECONE_API_KEY` (recommended).
+        :param environment: The Pinecone environment to connect to.
         :param index: The Pinecone index to connect to. If the index does not exist, it will be created.
-            Defaults to "default".
         :param namespace: The Pinecone namespace to connect to. If the namespace does not exist, it will be created
-            at the first write. Defaults to "default".
-        :param batch_size: The number of documents to write in a single batch. Defaults to 100, as recommended by
-            Pinecone.
+            at the first write.
+        :param batch_size: The number of documents to write in a single batch. When setting this parameter,
+            consider [documented Pinecone limits](https://docs.pinecone.io/docs/limits).
         :param dimension: The dimension of the embeddings.
            This parameter is only used when creating a new index.
-            Defaults to 768.
         :param index_creation_kwargs: Additional keyword arguments to pass to the index creation method.
-            For example, you can specify `metric`, `pods`, `replicas`...
             You can find the full list of supported arguments in the
-            [API reference](https://docs.pinecone.io/reference/create_index-1).
+            [API reference](https://docs.pinecone.io/reference/create_index).
         """

         resolved_api_key = api_key.resolve_value()
@@ -95,10 +96,22 @@ def __init__(

     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "PineconeDocumentStore":
+        """
+        Deserializes the component from a dictionary.
+        :param data:
+            Dictionary to deserialize from.
+        :returns:
+            Deserialized component.
+        """
         deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
         return default_from_dict(cls, data)

     def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+        :returns:
+            Dictionary with serialized data.
+        """
         return default_to_dict(
             self,
             api_key=self.api_key.to_dict(),
@@ -128,7 +141,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D

         :param policy: The duplicate policy to use when writing documents.
             PineconeDocumentStore only supports `DuplicatePolicy.OVERWRITE`.
-        :return: The number of documents written to the document store.
+        :returns: The number of documents written to the document store.
         """
         if len(documents) > 0 and not isinstance(documents[0], Document):
             msg = "param 'documents' must contain a list of objects of type Document"
@@ -157,7 +170,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
         refer to the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering)

         :param filters: The filters to apply to the document list.
-        :return: A list of Documents that match the given filters.
+        :returns: A list of Documents that match the given filters.
         """

         # Pinecone only performs vector similarity search
@@ -178,7 +191,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc

     def delete_documents(self, document_ids: List[str]) -> None:
         """
-        Deletes all documents with a matching document_ids from the document store.
+        Deletes documents that match the provided `document_ids` from the document store.

         :param document_ids: the document ids to delete
         """
@@ -197,14 +210,14 @@ def _embedding_retrieval(

         This method is not meant to be part of the public interface of
         `PineconeDocumentStore` nor called directly.
-        `PineconeDenseRetriever` uses this method directly and is the public interface for it.
+        `PineconeEmbeddingRetriever` uses this method directly and is the public interface for it.

         :param query_embedding: Embedding of the query.
         :param namespace: Pinecone namespace to query. Defaults to the namespace of the document store.
-        :param filters: Filters applied to the retrieved Documents. Defaults to None.
-        :param top_k: Maximum number of Documents to return, defaults to 10
+        :param filters: Filters applied to the retrieved Documents.
+        :param top_k: Maximum number of Documents to return.
-        :return: List of Document that are most similar to `query_embedding`
+        :returns: List of Documents that are most similar to `query_embedding`
         """

         if not query_embedding:
diff --git a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/errors.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/errors.py
deleted file mode 100644
index 994f34cf0..000000000
--- a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/errors.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from haystack.document_stores.errors import DocumentStoreError
-from haystack.errors import FilterError
-
-
-class PineconeDocumentStoreError(DocumentStoreError):
-    pass
-
-
-class PineconeDocumentStoreFilterError(FilterError):
-    pass