Update session management in vectorstore #25

Merged: 8 commits on Apr 9, 2024
6 changes: 3 additions & 3 deletions .github/workflows/_test.yml
@@ -20,9 +20,9 @@ jobs:
strategy:
matrix:
python-version:
-  # - "3.8"
-  # - "3.9"
-  # - "3.10"
+  - "3.8"
+  - "3.9"
+  - "3.10"
- "3.11"
name: Python ${{ matrix.python-version }}
steps:
38 changes: 18 additions & 20 deletions langchain_postgres/vectorstores.py
@@ -1,14 +1,12 @@
from __future__ import annotations

-import contextlib
import enum
import logging
import uuid
from typing import (
Any,
Callable,
Dict,
-Generator,
Iterable,
List,
Optional,
@@ -21,7 +19,7 @@
import sqlalchemy
from sqlalchemy import SQLColumnExpression, cast, delete, func
from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert
-from sqlalchemy.orm import Session, relationship
+from sqlalchemy.orm import Session, relationship, sessionmaker

try:
from sqlalchemy.orm import declarative_base
@@ -288,15 +286,19 @@ def __init__(
self.override_relevance_score_fn = relevance_score_fn

if isinstance(connection, str):
-self._bind = sqlalchemy.create_engine(url=connection, **(engine_args or {}))
+self._engine = sqlalchemy.create_engine(
+    url=connection, **(engine_args or {})
+)
elif isinstance(connection, sqlalchemy.engine.Engine):
-self._bind = connection
+self._engine = connection
else:
raise ValueError(
"connection should be a connection string or an instance of "
"sqlalchemy.engine.Engine"
)

+self._session_maker = sessionmaker(bind=self._engine)

self.use_jsonb = use_jsonb
self.create_extension = create_extension

@@ -321,16 +323,16 @@ def __post_init__(
self.create_collection()

def __del__(self) -> None:
-if isinstance(self._bind, sqlalchemy.engine.Connection):
-    self._bind.close()
+if isinstance(self._engine, sqlalchemy.engine.Connection):
+    self._engine.close()

@property
def embeddings(self) -> Embeddings:
return self.embedding_function

def create_vector_extension(self) -> None:
try:
-with Session(self._bind) as session: # type: ignore[arg-type]
+with self._session_maker() as session: # type: ignore[arg-type]
# The advisor lock fixes issue arising from concurrent
# creation of the vector extension.
# https://github.com/langchain-ai/langchain/issues/12933
@@ -348,36 +350,31 @@ def create_vector_extension(self) -> None:
raise Exception(f"Failed to create vector extension: {e}") from e

def create_tables_if_not_exists(self) -> None:
-with Session(self._bind) as session, session.begin(): # type: ignore[arg-type]
+with self._session_maker() as session:
Base.metadata.create_all(session.get_bind())

def drop_tables(self) -> None:
-with Session(self._bind) as session, session.begin(): # type: ignore[arg-type]
+with self._session_maker() as session:
Base.metadata.drop_all(session.get_bind())

def create_collection(self) -> None:
if self.pre_delete_collection:
self.delete_collection()
-with Session(self._bind) as session: # type: ignore[arg-type]
+with self._session_maker() as session: # type: ignore[arg-type]
self.CollectionStore.get_or_create(
session, self.collection_name, cmetadata=self.collection_metadata
)

def delete_collection(self) -> None:
self.logger.debug("Trying to delete collection")
-with Session(self._bind) as session: # type: ignore[arg-type]
+with self._session_maker() as session: # type: ignore[arg-type]
collection = self.get_collection(session)
if not collection:
self.logger.warning("Collection not found")
return
session.delete(collection)
session.commit()

-@contextlib.contextmanager
-def _make_session(self) -> Generator[Session, None, None]:
-    """Create a context manager for the session, bind to _conn string."""
-    yield Session(self._bind)  # type: ignore[arg-type]

def delete(
self,
ids: Optional[List[str]] = None,
@@ -390,7 +387,7 @@ def delete(
ids: List of ids to delete.
collection_only: Only delete ids in the collection.
"""
-with Session(self._bind) as session: # type: ignore[arg-type]
+with self._session_maker() as session:
if ids is not None:
self.logger.debug(
"Trying to delete vectors by ids (represented by the model "
@@ -476,7 +473,7 @@ def add_embeddings(
if not metadatas:
metadatas = [{} for _ in texts]

-with Session(self._bind) as session: # type: ignore[arg-type]
+with self._session_maker() as session: # type: ignore[arg-type]
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")
@@ -901,7 +898,7 @@ def __query_collection(
filter: Optional[Dict[str, str]] = None,
) -> List[Any]:
"""Query the collection."""
-with Session(self._bind) as session: # type: ignore[arg-type]
+with self._session_maker() as session: # type: ignore[arg-type]
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")
@@ -1066,6 +1063,7 @@ def from_existing_index(
embeddings=embedding,
distance_strategy=distance_strategy,
pre_delete_collection=pre_delete_collection,
+**kwargs,
)

return store
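The diff above boils down to one pattern: `PGVector` now keeps a single `Engine` in `self._engine`, builds one `sessionmaker` from it, and opens every session through `self._session_maker()` instead of wrapping the bind in `Session(...)` at each call site. Below is a minimal, self-contained sketch of that pattern; the `ExampleCollection` table, the `Store` class name, and the SQLite URL are placeholders for illustration only, not part of `langchain_postgres`.

```python
import sqlalchemy
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class ExampleCollection(Base):
    """Placeholder table standing in for the library's real models."""

    __tablename__ = "example_collection"
    id = Column(Integer, primary_key=True)
    name = Column(String)


class Store:
    """Sketch of the engine + sessionmaker wiring adopted in this PR."""

    def __init__(self, connection, engine_args=None):
        # Accept either a connection string or an existing Engine,
        # mirroring the branching in PGVector.__init__ above.
        if isinstance(connection, str):
            self._engine = sqlalchemy.create_engine(
                url=connection, **(engine_args or {})
            )
        elif isinstance(connection, sqlalchemy.engine.Engine):
            self._engine = connection
        else:
            raise ValueError(
                "connection should be a connection string or an instance of "
                "sqlalchemy.engine.Engine"
            )
        # One factory bound to the engine; each call opens a fresh Session.
        self._session_maker = sessionmaker(bind=self._engine)

    def create_tables_if_not_exists(self):
        # Sessions come from the factory, never from Session(self._bind).
        with self._session_maker() as session:
            Base.metadata.create_all(session.get_bind())


# Usage: an in-memory SQLite URL keeps the sketch self-contained.
store = Store("sqlite:///:memory:")
store.create_tables_if_not_exists()
```

Binding a single `sessionmaker` up front is also what allows the ad-hoc `_make_session` context manager to be deleted: callers get the same `with ...` ergonomics directly from the factory.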
29 changes: 17 additions & 12 deletions tests/unit_tests/test_vectorstore.py
@@ -1,5 +1,5 @@
"""Test PGVector functionality."""

+import contextlib
from typing import Any, Dict, Generator, List

import pytest
@@ -18,9 +18,8 @@
TYPE_4_FILTERING_TEST_CASES,
TYPE_5_FILTERING_TEST_CASES,
)
-from tests.utils import VECTORSTORE_CONNECTION_STRING
+from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING

-CONNECTION_STRING = VECTORSTORE_CONNECTION_STRING
ADA_TOKEN_COUNT = 1536


@@ -159,7 +158,7 @@ def test_pgvector_collection_with_metadata() -> None:
connection=CONNECTION_STRING,
pre_delete_collection=True,
)
-with pgvector._make_session() as session:
+with pgvector._session_maker() as session:
collection = pgvector.get_collection(session)
if collection is None:
assert False, "Expected a CollectionStore object but received None"
@@ -182,14 +181,14 @@ def test_pgvector_delete_docs() -> None:
pre_delete_collection=True,
)
vectorstore.delete(["1", "2"])
-with vectorstore._make_session() as session:
+with vectorstore._session_maker() as session:
records = list(session.query(vectorstore.EmbeddingStore).all())
# ignoring type error since mypy cannot determine whether
# the list is sortable
assert sorted(record.id for record in records) == ["3"] # type: ignore

vectorstore.delete(["2", "3"]) # Should not raise on missing ids
-with vectorstore._make_session() as session:
+with vectorstore._session_maker() as session:
records = list(session.query(vectorstore.EmbeddingStore).all())
# ignoring type error since mypy cannot determine whether
# the list is sortable
@@ -229,7 +228,7 @@ def test_pgvector_index_documents() -> None:
connection=CONNECTION_STRING,
pre_delete_collection=True,
)
-with vectorstore._make_session() as session:
+with vectorstore._session_maker() as session:
records = list(session.query(vectorstore.EmbeddingStore).all())
# ignoring type error since mypy cannot determine whether
# the list is sortable
@@ -251,7 +250,7 @@ def test_pgvector_index_documents() -> None:

vectorstore.add_documents(documents, ids=[doc.metadata["id"] for doc in documents])

-with vectorstore._make_session() as session:
+with vectorstore._session_maker() as session:
records = list(session.query(vectorstore.EmbeddingStore).all())
ordered_records = sorted(records, key=lambda x: x.id)
# ignoring type error since mypy cannot determine whether
@@ -408,6 +407,13 @@ def test_pgvector_with_custom_engine_args() -> None:
@pytest.fixture
def pgvector() -> Generator[PGVector, None, None]:
"""Create a PGVector instance."""
+with get_vectorstore() as vector_store:
+    yield vector_store
+
+
+@contextlib.contextmanager
+def get_vectorstore() -> Generator[PGVector, None, None]:
+    """Get a pre-populated-vectorstore"""
store = PGVector.from_documents(
documents=DOCUMENTS,
collection_name="test_collection",
@@ -419,20 +425,19 @@ def pgvector() -> Generator[PGVector, None, None]:
)
try:
yield store
-# Do clean up
finally:
store.drop_tables()


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES)
def test_pgvector_with_with_metadata_filters_1(
-pgvector: PGVector,
test_filter: Dict[str, Any],
expected_ids: List[int],
) -> None:
"""Test end to end construction and search."""
-docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
-assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+with get_vectorstore() as pgvector:
+    docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter


@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES)
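The test-side change mirrors the library change: the `pgvector` fixture now delegates to a new `get_vectorstore()` context manager, which the parametrized filter tests also open inline so that every case builds its own store and always drops its tables afterwards. Here is a minimal sketch of that fixture-plus-context-manager pattern; `FakeStore` and its methods are placeholders used only to keep the example self-contained, not part of the real test suite.

```python
import contextlib
from typing import Generator, List

import pytest


class FakeStore:
    """Placeholder standing in for PGVector; tracks teardown only."""

    def __init__(self) -> None:
        self.dropped = False

    def drop_tables(self) -> None:
        self.dropped = True

    def similarity_search(self, query: str, k: int = 4) -> List[str]:
        return []


@contextlib.contextmanager
def get_vectorstore() -> Generator[FakeStore, None, None]:
    """Build a store and always clean it up, even if a test fails."""
    store = FakeStore()
    try:
        yield store
    finally:
        store.drop_tables()


@pytest.fixture
def store() -> Generator[FakeStore, None, None]:
    # The fixture just delegates to the context manager, so fixture-based
    # and inline usage share a single setup/teardown path.
    with get_vectorstore() as s:
        yield s


@pytest.mark.parametrize("query", ["meow", "woof"])
def test_inline_usage(query: str) -> None:
    # Opening the context manager inside the test body gives each
    # parametrized case its own freshly created (and dropped) store.
    with get_vectorstore() as s:
        assert s.similarity_search(query, k=5) == []
```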