Skip to content

Commit

Permalink
docs: fix docstrings (#586)
Browse files Browse the repository at this point in the history
* fix docstrings for retrievers

* document store

* fix test

* actual fix
  • Loading branch information
masci authored Mar 15, 2024
1 parent f4730e5 commit d2780c9
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,17 @@
@component
class WeaviateBM25Retriever:
"""
Retriever that uses BM25 to find the most promising documents for a given query.
A component for retrieving documents from Weaviate using the BM25 algorithm.
Example usage:
```python
from haystack_integrations.document_stores.weaviate.document_store import WeaviateDocumentStore
from haystack_integrations.components.retrievers.weaviate.bm25_retriever import WeaviateBM25Retriever
document_store = WeaviateDocumentStore(url="http://localhost:8080")
retriever = WeaviateBM25Retriever(document_store=document_store)
retriever.run(query="How to make a pizza", top_k=3)
```
"""

def __init__(
Expand All @@ -20,15 +30,24 @@ def __init__(
"""
Create a new instance of WeaviateBM25Retriever.
:param document_store: Instance of WeaviateDocumentStore that will be associated with this retriever.
:param filters: Custom filters applied when running the retriever, defaults to None
:param top_k: Maximum number of documents to return, defaults to 10
:param document_store:
Instance of WeaviateDocumentStore that will be used by this retriever.
:param filters:
Custom filters applied when running the retriever.
:param top_k:
Maximum number of documents to return.
"""
self._document_store = document_store
self._filters = filters or {}
self._top_k = top_k

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
filters=self._filters,
Expand All @@ -38,13 +57,31 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "WeaviateBM25Retriever":
    """
    Deserializes the component from a dictionary.

    The serialized document store found under
    `data["init_parameters"]["document_store"]` is replaced in place with a
    `WeaviateDocumentStore` built from it before the retriever is created.

    :param data:
        Dictionary to deserialize from.
    :returns:
        Deserialized component.
    """
    init_parameters = data["init_parameters"]
    # Rebuild the nested document store first: the retriever needs an actual
    # WeaviateDocumentStore instance, not its dictionary representation.
    serialized_store = init_parameters["document_store"]
    init_parameters["document_store"] = WeaviateDocumentStore.from_dict(serialized_store)
    return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
"""
Retrieves documents from Weaviate using the BM25 algorithm.
:param query:
The query text.
:param filters:
Filters to use when running the retriever.
:param top_k:
The maximum number of documents to return.
"""
filters = filters or self._filters
top_k = top_k or self._top_k
documents = self._document_store._bm25_retrieval(query=query, filters=filters, top_k=top_k)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,22 @@ def __init__(
certainty: Optional[float] = None,
):
"""
Create a new instance of WeaviateEmbeddingRetriever.
Raises ValueError if both `distance` and `certainty` are provided.
See the official Weaviate documentation to learn more about the `distance` and `certainty` parameters:
https://weaviate.io/developers/weaviate/api/graphql/search-operators#variables
Creates a new instance of WeaviateEmbeddingRetriever.
:param document_store: Instance of WeaviateDocumentStore that will be associated with this retriever.
:param filters: Custom filters applied when running the retriever, defaults to None
:param top_k: Maximum number of documents to return, defaults to 10
:param distance: The maximum allowed distance between Documents' embeddings, defaults to None
:param certainty: Normalized distance between the result item and the search vector, defaults to None
:param document_store:
Instance of WeaviateDocumentStore that will be used by this retriever.
:param filters:
Custom filters applied when running the retriever.
:param top_k:
Maximum number of documents to return.
:param distance:
The maximum allowed distance between Documents' embeddings.
:param certainty:
Normalized distance between the result item and the search vector.
:raises ValueError:
If both `distance` and `certainty` are provided.
See https://weaviate.io/developers/weaviate/api/graphql/search-operators#variables to learn more about
the `distance` and `certainty` parameters.
"""
if distance is not None and certainty is not None:
msg = "Can't use 'distance' and 'certainty' parameters together"
Expand All @@ -42,6 +48,12 @@ def __init__(
self._certainty = certainty

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
filters=self._filters,
Expand All @@ -53,6 +65,14 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "WeaviateEmbeddingRetriever":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
Expand All @@ -67,10 +87,33 @@ def run(
distance: Optional[float] = None,
certainty: Optional[float] = None,
):
"""
Retrieves documents from Weaviate using vector search.
:param query_embedding:
Embedding of the query.
:param filters:
Filters to use when running the retriever.
:param top_k:
The maximum number of documents to return.
:param distance:
The maximum allowed distance between Documents' embeddings.
:param certainty:
Normalized distance between the result item and the search vector.
:raises ValueError:
If both `distance` and `certainty` are provided.
See https://weaviate.io/developers/weaviate/api/graphql/search-operators#variables to learn more about
the `distance` and `certainty` parameters.
"""
filters = filters or self._filters
top_k = top_k or self._top_k

distance = distance or self._distance
certainty = certainty or self._certainty
if distance is not None and certainty is not None:
msg = "Can't use 'distance' and 'certainty' parameters together"
raise ValueError(msg)

documents = self._document_store._embedding_retrieval(
query_embedding=query_embedding,
filters=filters,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ def __init__(
"""
Creates a new instance of WeaviateDocumentStore and connects to the Weaviate instance.
:param url: The URL to the weaviate instance, defaults to None.
:param collection_settings: The collection settings to use, defaults to None.
If None it will use a collection named `default` with the following properties:
:param url:
The URL to the Weaviate instance.
:param collection_settings:
The collection settings to use. If `None`, it will use a collection named `default` with the following
properties:
- _original_id: text
- content: text
- dataframe: text
Expand All @@ -93,22 +95,27 @@ def __init__(
production use.
See the official `Weaviate documentation <https://weaviate.io/developers/weaviate/manage-data/collections>`_
for more information on collections and their properties.
:param auth_client_secret: Authentication credentials, defaults to None.
Can be one of the following types depending on the authentication mode:
:param auth_client_secret:
Authentication credentials. Can be one of the following types depending on the authentication mode:
- `AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens
- `AuthClientPassword` to use username and password for OIDC Resource Owner Password flow
- `AuthClientCredentials` to use a client secret for OIDC client credential flow
- `AuthApiKey` to use an API key
:param additional_headers: Additional headers to include in the requests, defaults to None.
Can be used to set OpenAI/HuggingFace keys. OpenAI/HuggingFace key looks like this:
:param additional_headers:
Additional headers to include in the requests. Can be used to set OpenAI/HuggingFace keys.
OpenAI/HuggingFace keys look like this:
```
{"X-OpenAI-Api-Key": "<THE-KEY>"}, {"X-HuggingFace-Api-Key": "<THE-KEY>"}
```
:param embedded_options: If set create an embedded Weaviate cluster inside the client, defaults to None.
For a full list of options see `weaviate.embedded.EmbeddedOptions`.
:param additional_config: Additional and advanced configuration options for weaviate, defaults to None.
:param grpc_port: The port to use for the gRPC connection, defaults to 50051.
:param grpc_secure: Whether to use a secure channel for the underlying gRPC API.
:param embedded_options:
If set, create an embedded Weaviate cluster inside the client. For a full list of options see
`weaviate.embedded.EmbeddedOptions`.
:param additional_config:
Additional and advanced configuration options for Weaviate.
:param grpc_port:
The port to use for the gRPC connection.
:param grpc_secure:
Whether to use a secure channel for the underlying gRPC API.
"""
# proxies, timeout_config, trust_env are part of additional_config now
# startup_period has been removed
Expand Down Expand Up @@ -153,6 +160,12 @@ def __init__(
self._collection = self._client.collections.get(collection_settings["class"])

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
embedded_options = asdict(self._embedded_options) if self._embedded_options else None
additional_config = (
json.loads(self._additional_config.model_dump_json(by_alias=True)) if self._additional_config else None
Expand All @@ -170,6 +183,14 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "WeaviateDocumentStore":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
if (auth_client_secret := data["init_parameters"].get("auth_client_secret")) is not None:
data["init_parameters"]["auth_client_secret"] = AuthCredentials.from_dict(auth_client_secret)
if (embedded_options := data["init_parameters"].get("embedded_options")) is not None:
Expand All @@ -187,7 +208,7 @@ def count_documents(self) -> int:

def _to_data_object(self, document: Document) -> Dict[str, Any]:
"""
Convert a Document to a Weaviate data object ready to be saved.
Converts a Document to a Weaviate data object ready to be saved.
"""
data = document.to_dict()
# Weaviate forces a UUID as an id.
Expand All @@ -207,7 +228,7 @@ def _to_data_object(self, document: Document) -> Dict[str, Any]:

def _to_document(self, data: DataObject[Dict[str, Any], None]) -> Document:
"""
Convert a data object read from Weaviate into a Document.
Converts a data object read from Weaviate into a Document.
"""
document_data = data.properties
document_data["id"] = document_data.pop("_original_id")
Expand Down
4 changes: 2 additions & 2 deletions integrations/weaviate/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_run(mock_document_store):
retriever = WeaviateEmbeddingRetriever(document_store=mock_document_store)
query_embedding = [0.1, 0.1, 0.1, 0.1]
filters = {"field": "content", "operator": "==", "value": "Some text"}
retriever.run(query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1, certainty=0.1)
retriever.run(query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1)
mock_document_store._embedding_retrieval.assert_called_once_with(
query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1, certainty=0.1
query_embedding=query_embedding, filters=filters, top_k=5, distance=0.1, certainty=None
)

0 comments on commit d2780c9

Please sign in to comment.