Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[draft]: propagate score through retriever #20800

Closed
wants to merge 17 commits into from
12 changes: 10 additions & 2 deletions libs/community/langchain_community/vectorstores/redis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1444,7 +1444,11 @@ class Config:
arbitrary_types_allowed = True

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
Expand Down Expand Up @@ -1472,7 +1476,11 @@ def _get_relevant_documents(
return docs

async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import itertools
import random
import uuid
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, cast
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.documents import DocumentSearchHit

from langchain_community.vectorstores import DatabricksVectorSearch
from tests.integration_tests.vectorstores.fake_embeddings import (
Expand Down Expand Up @@ -598,6 +599,13 @@ def test_similarity_score_threshold(index_details: dict, threshold: float) -> No
assert len(search_result) == len(fake_texts)
else:
assert len(search_result) == 0
result_with_scores = cast(
List[DocumentSearchHit], retriever.invoke(query, include_score=True)
)
for idx, result in enumerate(result_with_scores):
assert result.score >= threshold
assert result.page_content == search_result[idx].page_content
assert result.metadata == search_result[idx].metadata


@pytest.mark.requires("databricks", "databricks.vector_search")
Expand Down
9 changes: 7 additions & 2 deletions libs/core/langchain_core/documents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
and their transformations.

"""
from langchain_core.documents.base import Document
from langchain_core.documents.base import Document, DocumentSearchHit
from langchain_core.documents.compressor import BaseDocumentCompressor
from langchain_core.documents.transformers import BaseDocumentTransformer

__all__ = ["Document", "BaseDocumentTransformer", "BaseDocumentCompressor"]
__all__ = [
"Document",
"DocumentSearchHit",
"BaseDocumentTransformer",
"BaseDocumentCompressor",
]
18 changes: 18 additions & 0 deletions libs/core/langchain_core/documents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,21 @@ def is_lc_serializable(cls) -> bool:
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "document"]


class DocumentSearchHit(Document):
    """Class for storing a document and fields associated with retrieval.

    Extends ``Document`` with the similarity score reported by the vector
    store for a given query.
    """

    score: float
    """Score associated with the document's relevance to a query."""
    # NOTE(review): whether this is a raw distance or a normalized relevance
    # score depends on the vector store that produced it — confirm per backend.
    type: Literal["DocumentSearchHit"] = "DocumentSearchHit"  # type: ignore[assignment] # noqa: E501

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        # Opt in to langchain's serialization protocol (mirrors ``Document``).
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        # Namespace used as the serialized identifier; must match the entries
        # registered in the load mapping.
        return ["langchain", "schema", "document_search_hit"]
12 changes: 12 additions & 0 deletions libs/core/langchain_core/load/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,12 @@
"base",
"Document",
),
("langchain", "schema", "document_search_hit", "DocumentSearchHit"): (
"langchain_core",
"documents",
"base",
"DocumentSearchHit",
),
("langchain", "output_parsers", "fix", "OutputFixingParser"): (
"langchain",
"output_parsers",
Expand Down Expand Up @@ -666,6 +672,12 @@
"base",
"Document",
),
("langchain_core", "documents", "base", "DocumentSearchHit"): (
"langchain_core",
"documents",
"base",
"DocumentSearchHit",
),
("langchain_core", "prompts", "chat", "AIMessagePromptTemplate"): (
"langchain_core",
"prompts",
Expand Down
41 changes: 39 additions & 2 deletions libs/core/langchain_core/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
TypeVar,
)

from langchain_core.documents import DocumentSearchHit
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
Expand Down Expand Up @@ -690,8 +691,17 @@ def validate_search_type(cls, values: Dict) -> Dict:
return values

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
if include_score and self.search_type != "similarity_score_threshold":
raise ValueError(
"include_score is only supported "
"for search_type=similarity_score_threshold"
)
if self.search_type == "similarity":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this case feels like it could lead to some confusion. Should we throw an error if include_score is set to true in this case with a message to switch search type?

Also do we want to use relevance scores for this, or do we want to just return raw scores? My goal was to deprecate the relevance score stuff as-is in #20302

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might also be good to add a typing override such that when include_score is set to true, this returns DocumentSearchHit. Alternative is to just add the score stuff as an optional to Document.

docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
elif self.search_type == "similarity_score_threshold":
Expand All @@ -700,6 +710,15 @@ def _get_relevant_documents(
query, **self.search_kwargs
)
)
if include_score:
return [
DocumentSearchHit(
page_content=doc.page_content,
metadata=doc.metadata,
score=score,
)
for doc, score in docs_and_similarities
]
docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search(
Expand All @@ -710,8 +729,17 @@ def _get_relevant_documents(
return docs

async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
if include_score and self.search_type != "similarity_score_threshold":
raise ValueError(
"include_score is only supported "
"for search_type=similarity_score_threshold"
)
if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search(
query, **self.search_kwargs
Expand All @@ -722,6 +750,15 @@ async def _aget_relevant_documents(
query, **self.search_kwargs
)
)
if include_score:
return [
DocumentSearchHit(
page_content=doc.page_content,
metadata=doc.metadata,
score=score,
)
for doc, score in docs_and_similarities
]
docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr":
docs = await self.vectorstore.amax_marginal_relevance_search(
Expand Down
7 changes: 6 additions & 1 deletion libs/core/tests/unit_tests/documents/test_imports.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from langchain_core.documents import __all__

EXPECTED_ALL = ["Document", "BaseDocumentTransformer", "BaseDocumentCompressor"]
# Public names expected to be exported by ``langchain_core.documents``;
# compared against that package's ``__all__`` in ``test_all_imports`` below.
EXPECTED_ALL = [
    "Document",
    "DocumentSearchHit",
    "BaseDocumentTransformer",
    "BaseDocumentCompressor",
]


def test_all_imports() -> None:
Expand Down
54 changes: 45 additions & 9 deletions libs/langchain/langchain/retrievers/self_query/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.documents import Document, DocumentSearchHit
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
Expand Down Expand Up @@ -192,19 +192,47 @@ def _prepare_query(
return new_query, search_kwargs

def _get_docs_with_query(
self, query: str, search_kwargs: Dict[str, Any]
self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False
) -> List[Document]:
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
if include_score:
docs_and_scores = self.vectorstore.similarity_search_with_score(
query, **search_kwargs
)
return [
DocumentSearchHit(
page_content=doc.page_content, metadata=doc.metadata, score=score
)
for doc, score in docs_and_scores
]
else:
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs

async def _aget_docs_with_query(
self, query: str, search_kwargs: Dict[str, Any]
self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False
) -> List[Document]:
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
if include_score:
docs_and_scores = await self.vectorstore.asimilarity_search_with_score(
query, **search_kwargs
)
return [
DocumentSearchHit(
page_content=doc.page_content, metadata=doc.metadata, score=score
)
for doc, score in docs_and_scores
]
else:
docs = await self.vectorstore.asearch(
query, self.search_type, **search_kwargs
)
return docs

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
"""Get documents relevant for a query.

Expand All @@ -220,11 +248,17 @@ def _get_relevant_documents(
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = self._get_docs_with_query(new_query, search_kwargs)
docs = self._get_docs_with_query(
new_query, search_kwargs, include_score=include_score
)
return docs

async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
"""Get documents relevant for a query.

Expand All @@ -240,7 +274,9 @@ async def _aget_relevant_documents(
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = await self._aget_docs_with_query(new_query, search_kwargs)
docs = await self._aget_docs_with_query(
new_query, search_kwargs, include_score=include_score
)
return docs

@classmethod
Expand Down
12 changes: 9 additions & 3 deletions libs/langchain/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion libs/langchain/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ langchain-server = "langchain.server:main"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.48"
langchain-core = "^0.1.52"
langchain-text-splitters = ">=0.0.1,<0.1"
langchain-community = ">=0.0.37,<0.1"
langsmith = "^0.1.17"
Expand Down
Loading