From d269dd2e2fedc2fa843ae3c9713adb8dea94f9fa Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Tue, 17 Oct 2023 02:05:12 +0200 Subject: [PATCH] added a multiturn search based on Vertex AI Search (#11885) Replace this entire comment with: - **Description:** Added a retriever based on multi-turn Vertex AI Search - **Twitter handle:** lkuligin --- .../retrievers/google_vertex_ai_search.ipynb | 33 +- .../langchain/retrievers/__init__.py | 2 + .../retrievers/google_vertex_ai_search.py | 289 ++++++++++-------- .../test_google_vertex_ai_search.py | 22 +- 4 files changed, 220 insertions(+), 126 deletions(-) diff --git a/docs/docs/integrations/retrievers/google_vertex_ai_search.ipynb b/docs/docs/integrations/retrievers/google_vertex_ai_search.ipynb index 16ebc271d1037..d3f3300e6e197 100644 --- a/docs/docs/integrations/retrievers/google_vertex_ai_search.ipynb +++ b/docs/docs/integrations/retrievers/google_vertex_ai_search.ipynb @@ -161,7 +161,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.retrievers import GoogleVertexAISearchRetriever\n", + "from langchain.retrievers import GoogleVertexAISearchRetriever, GoogleVertexAIMultiTurnSearchRetriever\n", "\n", "PROJECT_ID = \"\" # Set to your Project ID\n", "LOCATION_ID = \"\" # Set to your data store location\n", @@ -247,6 +247,37 @@ "for doc in result:\n", " print(doc)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configure and use the retrieve for multi-turn search" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Search with follow-ups is [based](https://cloud.google.com/generative-ai-app-builder/docs/multi-turn-search) on generative AI models and it is different from the regular unstructured data search." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "retriever = GoogleVertexAIMultiTurnSearchRetriever(\n", + " project_id=PROJECT_ID,\n", + " location_id=LOCATION_ID,\n", + " data_store_id=DATA_STORE_ID\n", + ")\n", + "\n", + "result = retriever.get_relevant_documents(query)\n", + "for doc in result:\n", + " print(doc)" + ] } ], "metadata": { diff --git a/libs/langchain/langchain/retrievers/__init__.py b/libs/langchain/langchain/retrievers/__init__.py index e1c3fa3da86f4..ea43611894de6 100644 --- a/libs/langchain/langchain/retrievers/__init__.py +++ b/libs/langchain/langchain/retrievers/__init__.py @@ -35,6 +35,7 @@ GoogleCloudEnterpriseSearchRetriever, ) from langchain.retrievers.google_vertex_ai_search import ( + GoogleVertexAIMultiTurnSearchRetriever, GoogleVertexAISearchRetriever, ) from langchain.retrievers.kay import KayAiRetriever @@ -79,6 +80,7 @@ "ElasticSearchBM25Retriever", "GoogleDocumentAIWarehouseRetriever", "GoogleCloudEnterpriseSearchRetriever", + "GoogleVertexAIMultiTurnSearchRetriever", "GoogleVertexAISearchRetriever", "KayAiRetriever", "KNNRetriever", diff --git a/libs/langchain/langchain/retrievers/google_vertex_ai_search.py b/libs/langchain/langchain/retrievers/google_vertex_ai_search.py index 9c4a5d0575b20..25f1478ceb249 100644 --- a/libs/langchain/langchain/retrievers/google_vertex_ai_search.py +++ b/libs/langchain/langchain/retrievers/google_vertex_ai_search.py @@ -4,88 +4,32 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence from langchain.callbacks.manager import CallbackManagerForRetrieverRun -from langchain.pydantic_v1 import Extra, Field, root_validator +from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain.schema import BaseRetriever, Document from langchain.utils import get_from_dict_or_env if TYPE_CHECKING: + from google.api_core.client_options import ClientOptions from google.cloud.discoveryengine_v1beta import ( + ConversationalSearchServiceClient, SearchRequest, SearchResult, SearchServiceClient, ) -class GoogleVertexAISearchRetriever(BaseRetriever): - """`Google Vertex AI Search` retriever. - - For a detailed explanation of the Vertex AI Search concepts - and configuration parameters, refer to the product documentation. - https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction - """ - +class _BaseGoogleVertexAISearchRetriever(BaseModel): project_id: str """Google Cloud Project ID.""" data_store_id: str """Vertex AI Search data store ID.""" - serving_config_id: str = "default_config" - """Vertex AI Search serving config ID.""" location_id: str = "global" """Vertex AI Search data store location.""" - filter: Optional[str] = None - """Filter expression.""" - get_extractive_answers: bool = False - """If True return Extractive Answers, otherwise return Extractive Segments.""" - max_documents: int = Field(default=5, ge=1, le=100) - """The maximum number of documents to return.""" - max_extractive_answer_count: int = Field(default=1, ge=1, le=5) - """The maximum number of extractive answers returned in each search result. - At most 5 answers will be returned for each SearchResult. - """ - max_extractive_segment_count: int = Field(default=1, ge=1, le=1) - """The maximum number of extractive segments returned in each search result. - Currently one segment will be returned for each SearchResult. - """ - query_expansion_condition: int = Field(default=1, ge=0, le=2) - """Specification to determine under which conditions query expansion should occur. - 0 - Unspecified query expansion condition. In this case, server behavior defaults - to disabled - 1 - Disabled query expansion. Only the exact search query is used, even if - SearchResponse.total_size is zero. - 2 - Automatic query expansion built by the Search API. - """ - spell_correction_mode: int = Field(default=2, ge=0, le=2) - """Specification to determine under which conditions query expansion should occur. - 0 - Unspecified spell correction mode. In this case, server behavior defaults - to auto. - 1 - Suggestion only. Search API will try to find a spell suggestion if there is any - and put in the `SearchResponse.corrected_query`. - The spell suggestion will not be used as the search query. - 2 - Automatic spell correction built by the Search API. - Search will be based on the corrected query if found. - """ credentials: Any = None """The default custom credentials (google.auth.credentials.Credentials) to use when making API calls. If not provided, credentials will be ascertained from the environment.""" - # TODO: Add extra data type handling for type website - engine_data_type: int = Field(default=0, ge=0, le=1) - """ Defines the Vertex AI Search data type - 0 - Unstructured data - 1 - Structured data - """ - - _client: SearchServiceClient - _serving_config: str - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.ignore - arbitrary_types_allowed = True - underscore_attrs_are_private = True - @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: """Validates the environment.""" @@ -94,9 +38,9 @@ def validate_environment(cls, values: Dict) -> Dict: except ImportError as exc: raise ImportError( "google.cloud.discoveryengine is not installed." - "Please install it with pip install google-cloud-discoveryengine" + "Please install it with pip install " + "google-cloud-discoveryengine>=0.11.0" ) from exc - try: from google.api_core.exceptions import InvalidArgument # noqa: F401 except ImportError as exc: @@ -130,47 +74,42 @@ def validate_environment(cls, values: Dict) -> Dict: return values - def __init__(self, **data: Any) -> None: - """Initializes private fields.""" - try: - from google.cloud.discoveryengine_v1beta import SearchServiceClient - except ImportError as exc: - raise ImportError( - "google.cloud.discoveryengine is not installed." - "Please install it with pip install google-cloud-discoveryengine" - ) from exc - try: - from google.api_core.client_options import ClientOptions - except ImportError as exc: - raise ImportError( - "google.api_core.client_options is not installed." - "Please install it with pip install google-api-core" - ) from exc - - super().__init__(**data) + @property + def client_options(self) -> "ClientOptions": + from google.api_core.client_options import ClientOptions - # For more information, refer to: - # https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store - api_endpoint = ( - "discoveryengine.googleapis.com" - if self.location_id == "global" - else f"{self.location_id}-discoveryengine.googleapis.com" + return ClientOptions( + api_endpoint=f"{self.location_id}-discoveryengine.googleapis.com" + if self.location_id != "global" + else None ) - self._client = SearchServiceClient( - credentials=self.credentials, - client_options=ClientOptions(api_endpoint=api_endpoint), - ) + def _convert_structured_search_response( + self, results: Sequence[SearchResult] + ) -> List[Document]: + """Converts a sequence of search results to a list of LangChain documents.""" + import json - self._serving_config = self._client.serving_config_path( - project=self.project_id, - location=self.location_id, - data_store=self.data_store_id, - serving_config=self.serving_config_id, - ) + from google.protobuf.json_format import MessageToDict + + documents: List[Document] = [] + + for result in results: + document_dict = MessageToDict( + result.document._pb, preserving_proto_field_name=True + ) + + documents.append( + Document( + page_content=json.dumps(document_dict.get("struct_data", {})), + metadata={"id": document_dict["id"], "name": document_dict["name"]}, + ) + ) + + return documents def _convert_unstructured_search_response( - self, results: Sequence[SearchResult] + self, results: Sequence[SearchResult], chunk_type: str ) -> List[Document]: """Converts a sequence of search results to a list of LangChain documents.""" from google.protobuf.json_format import MessageToDict @@ -188,12 +127,6 @@ def _convert_unstructured_search_response( doc_metadata = document_dict.get("struct_data", {}) doc_metadata["id"] = document_dict["id"] - chunk_type = ( - "extractive_answers" - if self.get_extractive_answers - else "extractive_segments" - ) - if chunk_type not in derived_struct_data: continue @@ -211,29 +144,91 @@ def _convert_unstructured_search_response( return documents - def _convert_structured_search_response( - self, results: Sequence[SearchResult] - ) -> List[Document]: - """Converts a sequence of search results to a list of LangChain documents.""" - import json - from google.protobuf.json_format import MessageToDict +class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetriever): + """`Google Vertex AI Search` retriever. - documents: List[Document] = [] + For a detailed explanation of the Vertex AI Search concepts + and configuration parameters, refer to the product documentation. + https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction + """ - for result in results: - document_dict = MessageToDict( - result.document._pb, preserving_proto_field_name=True - ) + serving_config_id: str = "default_config" + """Vertex AI Search serving config ID.""" + filter: Optional[str] = None + """Filter expression.""" + get_extractive_answers: bool = False + """If True return Extractive Answers, otherwise return Extractive Segments.""" + max_documents: int = Field(default=5, ge=1, le=100) + """The maximum number of documents to return.""" + max_extractive_answer_count: int = Field(default=1, ge=1, le=5) + """The maximum number of extractive answers returned in each search result. + At most 5 answers will be returned for each SearchResult. + """ + max_extractive_segment_count: int = Field(default=1, ge=1, le=1) + """The maximum number of extractive segments returned in each search result. + Currently one segment will be returned for each SearchResult. + """ + query_expansion_condition: int = Field(default=1, ge=0, le=2) + """Specification to determine under which conditions query expansion should occur. + 0 - Unspecified query expansion condition. In this case, server behavior defaults + to disabled + 1 - Disabled query expansion. Only the exact search query is used, even if + SearchResponse.total_size is zero. + 2 - Automatic query expansion built by the Search API. + """ + spell_correction_mode: int = Field(default=2, ge=0, le=2) + """Specification to determine under which conditions query expansion should occur. + 0 - Unspecified spell correction mode. In this case, server behavior defaults + to auto. + 1 - Suggestion only. Search API will try to find a spell suggestion if there is any + and put in the `SearchResponse.corrected_query`. + The spell suggestion will not be used as the search query. + 2 - Automatic spell correction built by the Search API. + Search will be based on the corrected query if found. + """ - documents.append( - Document( - page_content=json.dumps(document_dict.get("struct_data", {})), - metadata={"id": document_dict["id"], "name": document_dict["name"]}, - ) - ) + # TODO: Add extra data type handling for type website + engine_data_type: int = Field(default=0, ge=0, le=1) + """ Defines the Vertex AI Search data type + 0 - Unstructured data + 1 - Structured data + """ - return documents + _client: SearchServiceClient + _serving_config: str + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.ignore + arbitrary_types_allowed = True + underscore_attrs_are_private = True + + def __init__(self, **kwargs: Any) -> None: + """Initializes private fields.""" + try: + from google.cloud.discoveryengine_v1beta import SearchServiceClient + except ImportError as exc: + raise ImportError( + "google.cloud.discoveryengine is not installed." + "Please install it with pip install google-cloud-discoveryengine" + ) from exc + + super().__init__(**kwargs) + + # For more information, refer to: + # https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store + self._client = SearchServiceClient( + credentials=self.credentials, client_options=self.client_options + ) + + self._serving_config = self._client.serving_config_path( + project=self.project_id, + location=self.location_id, + data_store=self.data_store_id, + serving_config=self.serving_config_id, + ) def _create_search_request(self, query: str) -> SearchRequest: """Prepares a SearchRequest object.""" @@ -300,7 +295,14 @@ def _get_relevant_documents( ) if self.engine_data_type == 0: - documents = self._convert_unstructured_search_response(response.results) + chunk_type = ( + "extractive_answers" + if self.get_extractive_answers + else "extractive_segments" + ) + documents = self._convert_unstructured_search_response( + response.results, chunk_type + ) elif self.engine_data_type == 1: documents = self._convert_structured_search_response(response.results) else: @@ -312,3 +314,46 @@ def _get_relevant_documents( ) return documents + + +class GoogleVertexAIMultiTurnSearchRetriever( + BaseRetriever, _BaseGoogleVertexAISearchRetriever +): + _client: ConversationalSearchServiceClient + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.ignore + arbitrary_types_allowed = True + underscore_attrs_are_private = True + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + from google.cloud.discoveryengine_v1beta import ( + ConversationalSearchServiceClient, + ) + + self._client = ConversationalSearchServiceClient( + credentials=self.credentials, client_options=self.client_options + ) + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + """Get documents relevant for a query.""" + from google.cloud.discoveryengine_v1beta import ( + ConverseConversationRequest, + TextInput, + ) + + request = ConverseConversationRequest( + name=self._client.conversation_path( + self.project_id, self.location_id, self.data_store_id, "-" + ), + query=TextInput(input=query), + ) + response = self._client.converse_conversation(request) + return self._convert_unstructured_search_response( + response.search_results, "extractive_answers" + ) diff --git a/libs/langchain/tests/integration_tests/retrievers/test_google_vertex_ai_search.py b/libs/langchain/tests/integration_tests/retrievers/test_google_vertex_ai_search.py index 363eadbe4bf6e..57d097c070ce1 100644 --- a/libs/langchain/tests/integration_tests/retrievers/test_google_vertex_ai_search.py +++ b/libs/langchain/tests/integration_tests/retrievers/test_google_vertex_ai_search.py @@ -7,8 +7,8 @@ to set up the app and configure authentication. Set the following environment variables before the tests: -PROJECT_ID - set to your Google Cloud project ID -DATA_STORE_ID - the ID of the search engine to use for the test +export PROJECT_ID=... - set to your Google Cloud project ID +export DATA_STORE_ID=... - the ID of the search engine to use for the test """ import os @@ -18,7 +18,10 @@ from langchain.retrievers.google_cloud_enterprise_search import ( GoogleCloudEnterpriseSearchRetriever, ) -from langchain.retrievers.google_vertex_ai_search import GoogleVertexAISearchRetriever +from langchain.retrievers.google_vertex_ai_search import ( + GoogleVertexAIMultiTurnSearchRetriever, + GoogleVertexAISearchRetriever, +) from langchain.schema import Document @@ -35,6 +38,19 @@ def test_google_vertex_ai_search_get_relevant_documents() -> None: assert doc.metadata["source"] +@pytest.mark.requires("google_api_core") +def test_google_vertex_ai_multiturnsearch_get_relevant_documents() -> None: + """Test the get_relevant_documents() method.""" + retriever = GoogleVertexAIMultiTurnSearchRetriever() + documents = retriever.get_relevant_documents("What are Alphabet's Other Bets?") + assert len(documents) > 0 + for doc in documents: + assert isinstance(doc, Document) + assert doc.page_content + assert doc.metadata["id"] + assert doc.metadata["source"] + + @pytest.mark.requires("google_api_core") def test_google_vertex_ai_search_enterprise_search_deprecation() -> None: """Test the deprecation of GoogleCloudEnterpriseSearchRetriever."""