Skip to content

Commit

Permalink
adding custom from_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
davidsbatista committed Feb 13, 2024
1 parent 7b87fe0 commit cc92e1b
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 12 deletions.
3 changes: 2 additions & 1 deletion integrations/pgvector/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ pip install pgvector-haystack

## Testing

Ensure you have PostgreSQL installed with the `pgvector` extension, for a quick setup using Docker:
Ensure that you have a PostgreSQL instance running with the `pgvector` extension. For a quick setup using Docker, run:
```
docker run -d -p 5432:5432 -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=postgres ankane/pgvector
```

then run the tests:

```console
hatch run test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "PgvectorEmbeddingRetriever":
data["init_parameters"]["document_store"] = default_from_dict(
PgvectorDocumentStore, data["init_parameters"]["document_store"]
)
doc_store_params = data["init_parameters"]["document_store"]
data["init_parameters"]["document_store"] = PgvectorDocumentStore.from_dict(doc_store_params)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from typing import Any, Dict, List, Literal, Optional

from haystack import default_to_dict
from haystack import default_to_dict, default_from_dict
from haystack.dataclasses.document import ByteStream, Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
Expand Down Expand Up @@ -163,6 +163,13 @@ def to_dict(self) -> Dict[str, Any]:
hnsw_ef_search=self.hnsw_ef_search,
)

@classmethod
def from_dict(cls, init_parameters: Dict[str, Any]) -> "PgvectorDocumentStore":
    """
    Deserializes a PgvectorDocumentStore from its dictionary representation.

    Despite its name, `init_parameters` is the complete serialized dictionary: the constructor
    arguments are read from its nested "init_parameters" entry.

    The serialized "connection_string" is converted back with `Secret.from_dict`, unless it is None,
    in which case it is passed on as None. The input dictionary is left unmodified, so the same
    dictionary can be deserialized more than once.

    :param init_parameters: Serialized document store holding an "init_parameters" mapping with a
        "connection_string" entry.
    :returns: The instance built by `default_from_dict`.
    :raises KeyError: If "init_parameters" or its "connection_string" entry is missing.
    """
    # Copy the two dictionary levels that get rewritten so the caller's data is never mutated;
    # otherwise a second call on the same dictionary would feed a Secret object to Secret.from_dict.
    data = dict(init_parameters)
    params = dict(data["init_parameters"])

    conn_str_data = params["connection_string"]
    params["connection_string"] = Secret.from_dict(conn_str_data) if conn_str_data is not None else None

    data["init_parameters"] = params
    return default_from_dict(cls, data)

def _execute_sql(
self, sql_query: Query, params: Optional[tuple] = None, error_msg: str = "", cursor: Optional[Cursor] = None
):
Expand Down Expand Up @@ -222,15 +229,15 @@ def _handle_hnsw(self):
)
self._execute_sql(sql_set_hnsw_ef_search, error_msg="Could not set hnsw.ef_search")

index_esists = bool(
index_exists = bool(
self._execute_sql(
"SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s",
(self.table_name, HNSW_INDEX_NAME),
"Could not check if HNSW index exists",
).fetchone()
)

if index_esists and not self.hnsw_recreate_index_if_exists:
if index_exists and not self.hnsw_recreate_index_if_exists:
logger.warning(
"HNSW index already exists and won't be recreated. "
"If you want to recreate it, pass 'hnsw_recreate_index_if_exists=True' to the "
Expand Down Expand Up @@ -374,7 +381,8 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D

return written_docs

def _from_haystack_to_pg_documents(self, documents: List[Document]) -> List[Dict[str, Any]]:
@staticmethod
def _from_haystack_to_pg_documents(documents: List[Document]) -> List[Dict[str, Any]]:
"""
Internal method to convert a list of Haystack Documents to a list of dictionaries that can be used to insert
documents into the PgvectorDocumentStore.
Expand All @@ -396,7 +404,8 @@ def _from_haystack_to_pg_documents(self, documents: List[Document]) -> List[Dict

return db_documents

def _from_pg_to_haystack_documents(self, documents: List[Dict[str, Any]]) -> List[Document]:
@staticmethod
def _from_pg_to_haystack_documents(documents: List[Dict[str, Any]]) -> List[Document]:
"""
Internal method to convert a list of dictionaries from pgvector to a list of Haystack Documents.
"""
Expand Down
9 changes: 6 additions & 3 deletions integrations/pgvector/tests/test_retriever.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import Mock

from haystack.dataclasses import Document
from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever
from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore

from haystack.utils.auth import EnvVarSecret


class TestRetriever:
def test_init_default(self, document_store: PgvectorDocumentStore):
Expand Down Expand Up @@ -37,7 +40,7 @@ def test_to_dict(self, document_store: PgvectorDocumentStore):
"document_store": {
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
"init_parameters": {
"connection_string": "postgresql://postgres:postgres@localhost:5432/postgres",
"connection_string": {'env_vars': ['PG_CONN_STR'], 'strict': True, 'type': 'env_var'},
"table_name": "haystack_test_to_dict",
"embedding_dimension": 768,
"vector_function": "cosine_similarity",
Expand All @@ -62,7 +65,7 @@ def test_from_dict(self):
"document_store": {
"type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore",
"init_parameters": {
"connection_string": "postgresql://postgres:postgres@localhost:5432/postgres",
"connection_string": {'env_vars': ['PG_CONN_STR'], 'strict': True, 'type': 'env_var'},
"table_name": "haystack_test_to_dict",
"embedding_dimension": 768,
"vector_function": "cosine_similarity",
Expand All @@ -83,7 +86,7 @@ def test_from_dict(self):
document_store = retriever.document_store

assert isinstance(document_store, PgvectorDocumentStore)
assert document_store.connection_string == "postgresql://postgres:postgres@localhost:5432/postgres"
assert isinstance(document_store.connection_string, EnvVarSecret)
assert document_store.table_name == "haystack_test_to_dict"
assert document_store.embedding_dimension == 768
assert document_store.vector_function == "cosine_similarity"
Expand Down

0 comments on commit cc92e1b

Please sign in to comment.