Update session management in vectorstore (#25)
Update session management in the vectorstore
eyurtsev authored Apr 9, 2024
1 parent 8d09a2b commit c587e4c
Showing 3 changed files with 38 additions and 35 deletions.
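
For readers skimming the diffs below, the core of the change is replacing per-call Session(self._bind) construction with a sessionmaker factory that is bound to the engine once and reused by every method. A minimal, self-contained sketch of that pattern follows; it is not code from this repository, and the SQLite URL and ExampleRecord model are illustrative assumptions.

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class ExampleRecord(Base):
    __tablename__ = "example_records"
    id = Column(Integer, primary_key=True)
    name = Column(String)


engine = create_engine("sqlite://")        # engine is created once
Base.metadata.create_all(engine)
session_maker = sessionmaker(bind=engine)  # reusable session factory

# Each operation opens a short-lived session from the factory, mirroring
# the new `with self._session_maker() as session:` blocks in the diff below.
with session_maker() as session:
    session.add(ExampleRecord(name="demo"))
    session.commit()
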
6 changes: 3 additions & 3 deletions .github/workflows/_test.yml
@@ -20,9 +20,9 @@ jobs:
     strategy:
       matrix:
         python-version:
-          # - "3.8"
-          # - "3.9"
-          # - "3.10"
+          - "3.8"
+          - "3.9"
+          - "3.10"
           - "3.11"
     name: Python ${{ matrix.python-version }}
     steps:
38 changes: 18 additions & 20 deletions langchain_postgres/vectorstores.py
@@ -1,14 +1,12 @@
 from __future__ import annotations

-import contextlib
 import enum
 import logging
 import uuid
 from typing import (
     Any,
     Callable,
     Dict,
-    Generator,
     Iterable,
     List,
     Optional,
@@ -21,7 +19,7 @@
 import sqlalchemy
 from sqlalchemy import SQLColumnExpression, cast, delete, func
 from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert
-from sqlalchemy.orm import Session, relationship
+from sqlalchemy.orm import Session, relationship, sessionmaker

 try:
     from sqlalchemy.orm import declarative_base
@@ -288,15 +286,19 @@ def __init__(
         self.override_relevance_score_fn = relevance_score_fn

         if isinstance(connection, str):
-            self._bind = sqlalchemy.create_engine(url=connection, **(engine_args or {}))
+            self._engine = sqlalchemy.create_engine(
+                url=connection, **(engine_args or {})
+            )
         elif isinstance(connection, sqlalchemy.engine.Engine):
-            self._bind = connection
+            self._engine = connection
         else:
             raise ValueError(
                 "connection should be a connection string or an instance of "
                 "sqlalchemy.engine.Engine"
             )

+        self._session_maker = sessionmaker(bind=self._engine)
+
         self.use_jsonb = use_jsonb
         self.create_extension = create_extension

@@ -321,16 +323,16 @@ def __post_init__(
         self.create_collection()

     def __del__(self) -> None:
-        if isinstance(self._bind, sqlalchemy.engine.Connection):
-            self._bind.close()
+        if isinstance(self._engine, sqlalchemy.engine.Connection):
+            self._engine.close()

     @property
     def embeddings(self) -> Embeddings:
         return self.embedding_function

     def create_vector_extension(self) -> None:
         try:
-            with Session(self._bind) as session:  # type: ignore[arg-type]
+            with self._session_maker() as session:  # type: ignore[arg-type]
                 # The advisor lock fixes issue arising from concurrent
                 # creation of the vector extension.
                 # https://github.com/langchain-ai/langchain/issues/12933
@@ -348,36 +350,31 @@ def create_vector_extension(self) -> None:
             raise Exception(f"Failed to create vector extension: {e}") from e

     def create_tables_if_not_exists(self) -> None:
-        with Session(self._bind) as session, session.begin():  # type: ignore[arg-type]
+        with self._session_maker() as session:
             Base.metadata.create_all(session.get_bind())

     def drop_tables(self) -> None:
-        with Session(self._bind) as session, session.begin():  # type: ignore[arg-type]
+        with self._session_maker() as session:
             Base.metadata.drop_all(session.get_bind())

     def create_collection(self) -> None:
         if self.pre_delete_collection:
             self.delete_collection()
-        with Session(self._bind) as session:  # type: ignore[arg-type]
+        with self._session_maker() as session:  # type: ignore[arg-type]
             self.CollectionStore.get_or_create(
                 session, self.collection_name, cmetadata=self.collection_metadata
             )

     def delete_collection(self) -> None:
         self.logger.debug("Trying to delete collection")
-        with Session(self._bind) as session:  # type: ignore[arg-type]
+        with self._session_maker() as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 self.logger.warning("Collection not found")
                 return
             session.delete(collection)
             session.commit()

-    @contextlib.contextmanager
-    def _make_session(self) -> Generator[Session, None, None]:
-        """Create a context manager for the session, bind to _conn string."""
-        yield Session(self._bind)  # type: ignore[arg-type]
-
     def delete(
         self,
         ids: Optional[List[str]] = None,
@@ -390,7 +387,7 @@ def delete(
             ids: List of ids to delete.
             collection_only: Only delete ids in the collection.
         """
-        with Session(self._bind) as session:  # type: ignore[arg-type]
+        with self._session_maker() as session:
             if ids is not None:
                 self.logger.debug(
                     "Trying to delete vectors by ids (represented by the model "
@@ -476,7 +473,7 @@ def add_embeddings(
         if not metadatas:
             metadatas = [{} for _ in texts]

-        with Session(self._bind) as session:  # type: ignore[arg-type]
+        with self._session_maker() as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 raise ValueError("Collection not found")
@@ -901,7 +898,7 @@ def __query_collection(
         filter: Optional[Dict[str, str]] = None,
     ) -> List[Any]:
         """Query the collection."""
-        with Session(self._bind) as session:  # type: ignore[arg-type]
+        with self._session_maker() as session:  # type: ignore[arg-type]
             collection = self.get_collection(session)
             if not collection:
                 raise ValueError("Collection not found")
@@ -1066,6 +1063,7 @@ def from_existing_index(
             embeddings=embedding,
             distance_strategy=distance_strategy,
             pre_delete_collection=pre_delete_collection,
+            **kwargs,
         )

         return store
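
One behavioral detail worth noting in the hunks above: a sessionmaker() context manager closes the session on exit but does not commit, which is why delete_collection keeps its explicit session.commit(), whereas the old "with Session(self._bind) as session, session.begin():" form in create_tables_if_not_exists and drop_tables committed automatically on exit. A rough sketch of the two scopes, using an in-memory SQLite engine as a stand-in (the table name "t" is an illustrative assumption):

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite://")
session_maker = sessionmaker(bind=engine)

# Plain context manager: the session is closed on exit, but nothing is
# committed automatically; that is why the vectorstore methods keep their
# explicit session.commit() calls.
with session_maker() as session:
    session.execute(text("CREATE TABLE IF NOT EXISTS t (x INTEGER)"))
    session.commit()

# The begin() variant commits on success and rolls back on error, which is
# roughly what the removed session.begin() form provided.
with session_maker.begin() as session:
    session.execute(text("INSERT INTO t (x) VALUES (1)"))
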
29 changes: 17 additions & 12 deletions tests/unit_tests/test_vectorstore.py
@@ -1,5 +1,5 @@
 """Test PGVector functionality."""
-
+import contextlib
 from typing import Any, Dict, Generator, List

 import pytest
@@ -18,9 +18,8 @@
     TYPE_4_FILTERING_TEST_CASES,
     TYPE_5_FILTERING_TEST_CASES,
 )
-from tests.utils import VECTORSTORE_CONNECTION_STRING
+from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING

-CONNECTION_STRING = VECTORSTORE_CONNECTION_STRING
 ADA_TOKEN_COUNT = 1536

@@ -159,7 +158,7 @@ def test_pgvector_collection_with_metadata() -> None:
         connection=CONNECTION_STRING,
         pre_delete_collection=True,
     )
-    with pgvector._make_session() as session:
+    with pgvector._session_maker() as session:
         collection = pgvector.get_collection(session)
         if collection is None:
             assert False, "Expected a CollectionStore object but received None"
@@ -182,14 +181,14 @@ def test_pgvector_delete_docs() -> None:
         pre_delete_collection=True,
     )
     vectorstore.delete(["1", "2"])
-    with vectorstore._make_session() as session:
+    with vectorstore._session_maker() as session:
         records = list(session.query(vectorstore.EmbeddingStore).all())
         # ignoring type error since mypy cannot determine whether
         # the list is sortable
         assert sorted(record.id for record in records) == ["3"]  # type: ignore

     vectorstore.delete(["2", "3"])  # Should not raise on missing ids
-    with vectorstore._make_session() as session:
+    with vectorstore._session_maker() as session:
         records = list(session.query(vectorstore.EmbeddingStore).all())
         # ignoring type error since mypy cannot determine whether
         # the list is sortable
@@ -229,7 +228,7 @@ def test_pgvector_index_documents() -> None:
         connection=CONNECTION_STRING,
         pre_delete_collection=True,
     )
-    with vectorstore._make_session() as session:
+    with vectorstore._session_maker() as session:
         records = list(session.query(vectorstore.EmbeddingStore).all())
         # ignoring type error since mypy cannot determine whether
         # the list is sortable
@@ -251,7 +250,7 @@ def test_pgvector_index_documents() -> None:

     vectorstore.add_documents(documents, ids=[doc.metadata["id"] for doc in documents])

-    with vectorstore._make_session() as session:
+    with vectorstore._session_maker() as session:
         records = list(session.query(vectorstore.EmbeddingStore).all())
         ordered_records = sorted(records, key=lambda x: x.id)
         # ignoring type error since mypy cannot determine whether
@@ -408,6 +407,13 @@ def test_pgvector_with_custom_engine_args() -> None:
 @pytest.fixture
 def pgvector() -> Generator[PGVector, None, None]:
     """Create a PGVector instance."""
+    with get_vectorstore() as vector_store:
+        yield vector_store
+
+
+@contextlib.contextmanager
+def get_vectorstore() -> Generator[PGVector, None, None]:
+    """Get a pre-populated-vectorstore"""
     store = PGVector.from_documents(
         documents=DOCUMENTS,
         collection_name="test_collection",
@@ -419,20 +425,19 @@ def pgvector() -> Generator[PGVector, None, None]:
     )
     try:
         yield store
-        # Do clean up
     finally:
         store.drop_tables()


 @pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES)
 def test_pgvector_with_with_metadata_filters_1(
-    pgvector: PGVector,
     test_filter: Dict[str, Any],
     expected_ids: List[int],
 ) -> None:
     """Test end to end construction and search."""
-    docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
-    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+    with get_vectorstore() as pgvector:
+        docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
+        assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter


 @pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES)
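
The test changes follow the same theme: the body of the old pgvector fixture moves into a reusable @contextlib.contextmanager helper (get_vectorstore), so parametrized filter tests can open a fresh, pre-populated store themselves instead of receiving the fixture as an argument. A generic sketch of that fixture-plus-context-manager shape follows; FakeStore is an illustrative stand-in, not PGVector.

import contextlib
from typing import Generator

import pytest


class FakeStore:
    """Stand-in for a vector store that needs teardown."""

    def __init__(self) -> None:
        self.tables = ["collection"]

    def drop_tables(self) -> None:
        self.tables.clear()


@contextlib.contextmanager
def get_store() -> Generator[FakeStore, None, None]:
    """Yield a pre-populated store and always clean up afterwards."""
    store = FakeStore()
    try:
        yield store
    finally:
        store.drop_tables()


@pytest.fixture
def store() -> Generator[FakeStore, None, None]:
    """Fixture for tests that only need one store per test."""
    with get_store() as s:
        yield s


def test_store_is_populated(store: FakeStore) -> None:
    assert store.tables == ["collection"]
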

