Skip to content

Commit

Permalink
fixed some linting
Browse files Browse the repository at this point in the history
  • Loading branch information
sahusiddharth committed Jan 4, 2024
1 parent e7451dc commit c87c374
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 49 deletions.
4 changes: 2 additions & 2 deletions integrations/pgvector/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "pgvector_haystack"
dynamic = ["version"]
description = ''
description = ""
readme = "README.md"
requires-python = ">=3.7"
license = "Apache-2.0"
Expand All @@ -26,7 +26,7 @@ classifiers = [
]
dependencies = [
"haystack-ai",
"vecs=0.4.2",
"vecs>=0.4.2",
]

[project.urls]
Expand Down
4 changes: 2 additions & 2 deletions integrations/pgvector/src/pgvector_haystack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: 2023-present John Doe <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from pgvector_haystack.document_store import pgvectorDocumentStore
from pgvector_haystack.document_store import PGvectorDocumentStore

__all__ = ["pgvectorDocumentStore"]
__all__ = ["PGvectorDocumentStore"]
80 changes: 46 additions & 34 deletions integrations/pgvector/src/pgvector_haystack/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
from copy import copy
from typing import Any, Dict, List, Optional

import vecs
import numpy as np
import vecs
from haystack.dataclasses import Document
from haystack.document_stores.errors import DuplicateDocumentError, MissingDocumentError
from haystack.document_stores.protocol import DuplicatePolicy

logger = logging.getLogger(__name__)


class pgvectorDocumentStore:
class PGvectorDocumentStore:
def __init__(
self,
user:str,
Expand All @@ -27,12 +27,24 @@ def __init__(
**collection_creation_kwargs,
):
"""
Creates a new PGvectorDocumentStore instance.
For more information on connection parameters, see the official vecs documentation: https://supabase.github.io/vecs/0.4/
:param user: The username for connecting to the PostgreSQL database.
:param password: The password for connecting to the PostgreSQL database.
:param host: The host address of the PostgreSQL database server.
:param port: The port number on which the PostgreSQL database server is listening.
:param db_name: The name of the PostgreSQL database.
:param collection_name: The name of the collection or table in the database where vectors will be stored.
:param dimension: The dimensionality of the vectors to be stored in the document store.
:param **collection_creation_kwargs: Additional keyword arguments forwarded to ``get_or_create_collection``; an ``adapter`` entry is required.
"""
self._collection_name = collection_name
self._dummy_vector = [0.0]*dimension
self._adapter = collection_creation_kwargs['adapter']
DB_CONNECTION = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
self._pgvector_client = vecs.create_client(DB_CONNECTION)
self._adapter = collection_creation_kwargs["adapter"]
db_connection = f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
self._pgvector_client = vecs.create_client(db_connection)
self._collection = self._pgvector_client.get_or_create_collection(name=collection_name, dimension=dimension, **collection_creation_kwargs)


Expand All @@ -41,7 +53,7 @@ def count_documents(self) -> int:
Returns how many documents are present in the document store.
"""
return self._collection.__len__()


def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
"""
Expand Down Expand Up @@ -73,11 +85,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
- `<`
- `<=`
- `in`
- `not in`
The `operator` values in Logic dictionaries must be one of:
- `NOT`
- `OR`
- `AND`
Expand Down Expand Up @@ -112,14 +122,14 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
if filters and not isinstance(filters, dict):
msg = "Filter must be a dictionary"
raise ValueError(msg)

filters = self._normalize_filters(filters)

# pgvector store performs vector similarity search
# here we are querying with a dummy vector and the max compatible top_k
documents = self._embedding_retrieval(
query_embedding=self._dummy_vector,
filters=filters,
query_embedding=self._dummy_vector,
filters=filters,
)

return self._convert_query_result_to_documents(documents)
Expand All @@ -131,29 +141,29 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
:param documents: a list of documents.
:param policy: The duplicate policy to use when writing documents.
pgvectorDocumentStore only supports `DuplicatePolicy.OVERWRITE`.
PGvectorDocumentStore only supports `DuplicatePolicy.OVERWRITE`.
:return: None
"""
if policy not in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]:
logger.warning(
f"pgvectorDocumentStore only supports `DuplicatePolicy.OVERWRITE`"
f"PGvectorDocumentStore only supports `DuplicatePolicy.OVERWRITE`"
f"but got {policy}. Overwriting duplicates is enabled by default."
)


for doc in documents:
if not isinstance(doc, Document):
msg = "param 'documents' must contain a list of objects of type Document"
raise ValueError(msg)
if doc.content is None:
logger.warning(
"pgvectorDocumentStore can only store the text field of Documents: "
"PGvectorDocumentStore can only store the text field of Documents: "
"'array', 'dataframe' and 'blob' will be dropped."
)

if self._adapter is not None:
data = (doc.id, doc.content, {'content':doc.content, **doc.meta})
data = (doc.id, doc.content, {"content":doc.content, **doc.meta})
self._collection.upsert(records=[data])
else:
embedding = copy(doc.embedding)
Expand All @@ -164,9 +174,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
)
embedding = self._dummy_vector

data = (doc.id, embedding, {'content':doc.content, **doc.meta})
data = (doc.id, embedding, {"content":doc.content, **doc.meta})
self._collection.upsert(records=[data])


def delete_documents(self, document_ids: List[str]) -> None:
"""
Expand All @@ -176,23 +186,23 @@ def delete_documents(self, document_ids: List[str]) -> None:
"""
self._collection.delete(document_ids)


def _convert_query_result_to_documents(self, result) -> List[Document]:
"""
Helper function to convert pgvector query results into Haystack Documents
"""
documents = []
for i in result:
document_dict: Dict[str, Any] = {'id':i[0]}
document_dict: Dict[str, Any] = {"id":i[0]}
document_dict["embedding"] = np.array(i[1])
metadata = i[2]
document_dict['content'] = metadata['content']
del metadata['content']
document_dict['meta'] = metadata
document_dict["content"] = metadata["content"]
del metadata["content"]
document_dict["meta"] = metadata
documents.append(Document.from_dict(dict))

return documents


def _embedding_retrieval(
self,
Expand Down Expand Up @@ -227,12 +237,13 @@ def _embedding_retrieval(
return self._convert_query_result_to_documents(result=results)


def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
def _normalize_filters(self, filters: Dict[str, Any]) -> Dict[str, Any]:
"""
Translate Haystack filters to pgvector filters. It returns a dictionary.
"""
if filters and not isinstance(filters, dict):
raise ValueError("Filter must be a dictionary")
msg = "Filter must be a dictionary"
raise ValueError(msg)

operator_mapping = {
"==": "$eq",
Expand All @@ -247,17 +258,18 @@ def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]:
}

def convert(filters: Dict[str, Any]) -> Any:
op = filters.get('operator')
op = filters.get("operator")
if op not in operator_mapping:
raise ValueError(f"{op} not supported in pgvector metadata filtering")
msg = f"{op} not supported in pgvector metadata filtering"
raise ValueError(msg)

if 'conditions' in filters:
if "conditions" in filters:
# Recursive call for nested conditions
return {operator_mapping[op]: [convert(cond) for cond in filters['conditions']]}
return {operator_mapping[op]: [convert(cond) for cond in filters["conditions"]]}
else:
# Simple statement
field = filters['field']
value = filters['value']
field = filters["field"]
value = filters["value"]
return {field: {operator_mapping[op]: value}}

return convert(filters)
return convert(filters)
18 changes: 9 additions & 9 deletions integrations/pgvector/src/pgvector_haystack/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,26 @@

from haystack import component

from pgvector_haystack import pgvectorDocumentStore
from pgvector_haystack import PGvectorDocumentStore


@component
class pgvectorQueryRetriever:
class PGvectorQueryRetriever:
"""
A component for retrieving documents from an pgvectorDocumentStore.
A component for retrieving documents from a PGvectorDocumentStore.
"""

def __init__(
self,
self,
*,
document_store: pgvectorDocumentStore,
filters: Optional[Dict[str, Any]] = None,
document_store: PGvectorDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
):
"""
Create an pgvectorRetriever component.
Create a PGvectorQueryRetriever component.
:param document_store: An instance of pgvectorDocumentStore
:param document_store: An instance of PGvectorDocumentStore
:param filters: A dictionary with filters to narrow down the search space (default is None).
:param top_k: The maximum number of documents to retrieve (default is 10).
Expand All @@ -44,4 +44,4 @@ def run(self, _):
:param data: The input data for the retriever. In this case, a list of queries.
:return: The retrieved documents.
"""
return [] # FIXME
return [] # FIXME
2 changes: 1 addition & 1 deletion integrations/pgvector/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# SPDX-FileCopyrightText: 2023-present John Doe <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
2 changes: 1 addition & 1 deletion integrations/pgvector/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
import numpy as np
import pytest
from haystack import Document
from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest
from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest

0 comments on commit c87c374

Please sign in to comment.