Tests: Import _compare_documents function from langchain-postgres
amotl committed Dec 14, 2024
1 parent 755448d commit ce34353
Showing 1 changed file with 80 additions and 22 deletions.
102 changes: 80 additions & 22 deletions in tests/integration_tests/vectorstore/test_vector_cratedb_single.py
@@ -34,6 +34,14 @@
)


def _compare_documents(left: Sequence[Document], right: Sequence[Document]) -> None:
"""Compare lists of documents, irrespective of IDs."""
assert len(left) == len(right)
for left_doc, right_doc in zip(left, right):
assert left_doc.page_content == right_doc.page_content
assert left_doc.metadata == right_doc.metadata

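For orientation, a minimal usage sketch of the new helper (illustrative only, not part of the diff); the store construction and arguments mirror the `from_texts` fixtures used in the tests below rather than introducing anything new.

# Usage sketch, assuming the same fixtures as the tests below.
docsearch = CrateDBVectorStore.from_texts(
    texts=["foo", "bar", "baz"],
    collection_name="test_collection",
    embedding=FakeEmbeddingsWithAdaDimension(),
    connection=CONNECTION_STRING,
    pre_delete_collection=True,
)
output = docsearch.similarity_search("foo", k=1)
# The helper compares page_content and metadata while ignoring the
# auto-generated document IDs, so the assertion stays stable across runs.
_compare_documents(output, [Document(page_content="foo")])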

def test_cratedb_collection_read_only(session: sa.orm.Session) -> None:
"""
Test using a collection, without adding any embeddings upfront.
@@ -76,7 +84,7 @@ def test_cratedb_texts(engine: sa.Engine) -> None:
)
output = docsearch.similarity_search("foo", k=1)
prune_document_ids(output)
assert output == [Document(metadata={}, page_content="foo")]
_compare_documents(output, [Document(page_content="foo")])


def test_cratedb_embedding_dimension(engine: sa.Engine) -> None:
@@ -112,7 +120,7 @@ def test_cratedb_embeddings(engine: sa.Engine) -> None:
)
output = docsearch.similarity_search("foo", k=1)
prune_document_ids(output)
assert output == [Document(page_content="foo")]
_compare_documents(output, [Document(page_content="foo")])


def test_cratedb_with_metadatas(engine: sa.Engine) -> None:
@@ -129,7 +137,7 @@ def test_cratedb_with_metadatas(engine: sa.Engine) -> None:
)
output = docsearch.similarity_search("foo", k=1)
prune_document_ids(output)
assert output == [Document(page_content="foo", metadata={"page": "0"})]
_compare_documents(output, [Document(page_content="foo", metadata={"page": "0"})])


def test_cratedb_with_metadatas_with_scores(engine: sa.Engine) -> None:
@@ -146,7 +154,7 @@ def test_cratedb_with_metadatas_with_scores(engine: sa.Engine) -> None:
)
output = docsearch.similarity_search_with_score("foo", k=1)
prune_document_ids(output)
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)]
docs, scores = zip(*output)
_compare_documents(docs, [Document(page_content="foo", metadata={"page": "0"})])
# FIXME: Why is it 1.0 instead of 0.0? That certainly can't be right?
# Original score value: 0.0
assert scores == (1.0,)


def test_cratedb_with_filter_match(engine: sa.Engine) -> None:
@@ -165,7 +177,7 @@ def test_cratedb_with_filter_match(engine: sa.Engine) -> None:
# assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"})
prune_document_ids(output)
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)]
docs, scores = zip(*output)
_compare_documents(docs, [Document(page_content="foo", metadata={"page": "0"})])
# FIXME: Why is it 1.0 instead of 0.0? That certainly can't be right?
# Original score value: 0.0
assert scores == (1.0,)


def test_cratedb_with_filter_distant_match(engine: sa.Engine) -> None:
@@ -182,8 +198,10 @@ def test_cratedb_with_filter_distant_match(engine: sa.Engine) -> None:
)
output = docsearch.similarity_search_with_score("foo", k=2, filter={"page": "2"})
prune_document_ids(output)
docs, scores = zip(*output)
_compare_documents(docs, [Document(page_content="baz", metadata={"page": "2"})])
# Original score value: 0.0013003906671379406
assert output == [(Document(page_content="baz", metadata={"page": "2"}), 0.2)]
assert scores == (0.2,)


def test_cratedb_with_filter_no_match(engine: sa.Engine) -> None:
@@ -202,7 +220,7 @@ def test_cratedb_with_filter_no_match(engine: sa.Engine) -> None:
assert output == []


def test_cratedb_collection_delete(engine: sa.Engine, session: sa.orm.Session) -> None:
def test_cratedb_delete_collection(engine: sa.Engine, session: sa.orm.Session) -> None:
"""
Test end to end collection construction and deletion.
Uses two different collections of embeddings.
@@ -309,11 +327,16 @@ def test_cratedb_with_filter_in_set(engine: sa.Engine, operator: str) -> None:
"foo", k=2, filter={"page": {operator: ["0", "2"]}}
)
prune_document_ids(output)
docs, scores = zip(*output)
# Original score values: 0.0, 0.0013003906671379406
assert output == [
(Document(page_content="foo", metadata={"page": "0"}), 1.0),
(Document(page_content="baz", metadata={"page": "2"}), 0.2),
]
_compare_documents(
docs,
[
Document(page_content="foo", metadata={"page": "0"}),
Document(page_content="baz", metadata={"page": "2"}),
],
)
assert scores == (1.0, 0.2)


def test_cratedb_delete_docs(engine: sa.Engine) -> None:
@@ -359,12 +382,17 @@ def test_cratedb_relevance_score(engine: sa.Engine) -> None:

output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
prune_document_ids(output)
docs, scores = zip(*output)
_compare_documents(
docs,
[
Document(page_content="foo", metadata={"page": "0"}),
Document(page_content="bar", metadata={"page": "1"}),
Document(page_content="baz", metadata={"page": "2"}),
],
)
# Original score values: 1.0, 0.9996744261675065, 0.9986996093328621
assert output == [
(Document(page_content="foo", metadata={"page": "0"}), 1.0),
(Document(page_content="bar", metadata={"page": "1"}), 0.5),
(Document(page_content="baz", metadata={"page": "2"}), 0.2),
]
assert scores == (1.0, 0.5, 0.2)


def test_cratedb_retriever_search_threshold(engine: sa.Engine) -> None:
@@ -386,10 +414,13 @@ def test_cratedb_retriever_search_threshold(engine: sa.Engine) -> None:
)
output = retriever.invoke("summer")
prune_document_ids(output)
assert output == [
Document(page_content="foo", metadata={"page": "0"}),
Document(page_content="bar", metadata={"page": "1"}),
]
_compare_documents(
output,
[
Document(page_content="foo", metadata={"page": "0"}),
Document(page_content="bar", metadata={"page": "1"}),
],
)


def test_cratedb_retriever_search_threshold_custom_normalization_fn(
@@ -428,7 +459,7 @@ def test_cratedb_max_marginal_relevance_search(engine: sa.Engine) -> None:
)
output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3)
prune_document_ids(output)
assert output == [Document(page_content="foo")]
_compare_documents(output, [Document(page_content="foo")])


def test_cratedb_max_marginal_relevance_search_with_score(engine: sa.Engine) -> None:
@@ -443,7 +474,34 @@ def test_cratedb_max_marginal_relevance_search_with_score(engine: sa.Engine) -> None:
)
output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3)
prune_document_ids(output)
assert output == [(Document(page_content="foo"), 1.0)]
docs, scores = zip(*output)
_compare_documents(docs, [Document(page_content="foo")])
# FIXME: Why is it 1.0 instead of 0.0? That certainly can't be right?
# Original score value: 0.0
assert scores == (1.0,)


def test_cratedb_with_custom_engine_args() -> None:
"""Test construction using custom engine arguments."""
texts = ["foo", "bar", "baz"]
engine_args = {
"pool_size": 5,
"max_overflow": 10,
"pool_recycle": -1,
"pool_use_lifo": False,
"pool_pre_ping": False,
"pool_timeout": 30,
}
docsearch = CrateDBVectorStore.from_texts(
texts=texts,
collection_name="test_collection",
embedding=FakeEmbeddingsWithAdaDimension(),
connection=CONNECTION_STRING,
pre_delete_collection=True,
engine_args=engine_args,
)
output = docsearch.similarity_search("foo", k=1)
_compare_documents(output, [Document(page_content="foo")])


# We should reuse this test-case across other integrations
