From 15b0264c005a391d3bed032bf347c6bbb971500e Mon Sep 17 00:00:00 2001 From: Abhishek Bhagwat Date: Thu, 16 May 2024 21:27:06 +0800 Subject: [PATCH] feat: Vertex Check Grounding API integration (#186) * feat: Vertex Check Grounding API integration --- .../langchain_google_community/__init__.py | 4 + .../vertex_check_grounding.py | 241 ++++++++++++++++++ libs/community/pyproject.toml | 1 + .../integration_tests/test_check_grounding.py | 117 +++++++++ .../tests/integration_tests/test_rank.py | 2 +- .../tests/unit_tests/test_check_grounding.py | 129 ++++++++++ libs/community/tests/unit_tests/test_rank.py | 2 +- 7 files changed, 494 insertions(+), 2 deletions(-) create mode 100644 libs/community/langchain_google_community/vertex_check_grounding.py create mode 100644 libs/community/tests/integration_tests/test_check_grounding.py create mode 100644 libs/community/tests/unit_tests/test_check_grounding.py diff --git a/libs/community/langchain_google_community/__init__.py b/libs/community/langchain_google_community/__init__.py index 873272cd..50642894 100644 --- a/libs/community/langchain_google_community/__init__.py +++ b/libs/community/langchain_google_community/__init__.py @@ -24,6 +24,9 @@ VertexAISearchRetriever, VertexAISearchSummaryTool, ) +from langchain_google_community.vertex_check_grounding import ( + VertexAICheckGroundingWrapper, +) from langchain_google_community.vertex_rank import VertexAIRank from langchain_google_community.vision import CloudVisionLoader, CloudVisionParser @@ -52,4 +55,5 @@ "VertexAISearchRetriever", "VertexAISearchSummaryTool", "VertexAIRank", + "VertexAICheckGroundingWrapper", ] diff --git a/libs/community/langchain_google_community/vertex_check_grounding.py b/libs/community/langchain_google_community/vertex_check_grounding.py new file mode 100644 index 00000000..eab24fd2 --- /dev/null +++ b/libs/community/langchain_google_community/vertex_check_grounding.py @@ -0,0 +1,241 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from google.api_core import exceptions as core_exceptions # type: ignore +from google.auth.credentials import Credentials # type: ignore +from langchain_core.documents import Document +from langchain_core.pydantic_v1 import BaseModel, Extra, Field +from langchain_core.runnables import RunnableConfig, RunnableSerializable + +if TYPE_CHECKING: + from google.cloud import discoveryengine_v1alpha # type: ignore + + +class VertexAICheckGroundingWrapper( + RunnableSerializable[str, "VertexAICheckGroundingWrapper.CheckGroundingResponse"] +): + """ + Initializes the Vertex AI CheckGroundingOutputParser with configurable parameters. + + Calls the Check Grounding API to validate the response against a given set of + documents and returns back citations that support the claims along with the cited + chunks. Output is of the type CheckGroundingResponse. + + Attributes: + project_id (str): Google Cloud project ID + location_id (str): Location ID for the ranking service. + grounding_config (str): + Required. The resource name of the grounding config, such as + ``default_grounding_config``. + It is set to ``default_grounding_config`` by default if unspecified + citation_threshold (float): + The threshold (in [0,1]) used for determining whether a fact + must be cited for a claim in the answer candidate. Choosing + a higher threshold will lead to fewer but very strong + citations, while choosing a lower threshold may lead to more + but somewhat weaker citations. If unset, the threshold will + default to 0.6. + credentials (Optional[Credentials]): Google Cloud credentials object. + credentials_path (Optional[str]): Path to the Google Cloud service + account credentials file. + """ + + project_id: str = Field(default=None) + location_id: str = Field(default="global") + grounding_config: str = Field(default="default_grounding_config") + citation_threshold: Optional[float] = Field(default=0.6) + client: Any + credentials: Optional[Credentials] = Field(default=None) + credentials_path: Optional[str] = Field(default=None) + + class CheckGroundingResponse(BaseModel): + support_score: float = 0.0 + cited_chunks: List[Dict[str, Any]] = [] + claims: List[Dict[str, Any]] = [] + answer_with_citations: str = "" + + def __init__(self, **kwargs: Any): + """ + Constructor for CheckGroundingOutputParser. + Initializes the grounding check service client with necessary credentials + and configurations. + """ + super().__init__(**kwargs) + self.client = kwargs.get("client") + if not self.client: + self.client = self._get_check_grounding_service_client() + + def _get_check_grounding_service_client( + self, + ) -> "discoveryengine_v1alpha.GroundedGenerationServiceClient": + """ + Returns a GroundedGenerationServiceClient instance using provided credentials. + Raises ImportError if necessary packages are not installed. + + Returns: + A GroundedGenerationServiceClient instance. + """ + try: + from google.cloud import discoveryengine_v1alpha # type: ignore + except ImportError as exc: + raise ImportError( + "Could not import google-cloud-discoveryengine python package. " + "Please install vertexaisearch dependency group: " + "`pip install langchain-google-community[vertexaisearch]`" + ) from exc + return discoveryengine_v1alpha.GroundedGenerationServiceClient( + credentials=( + self.credentials + or Credentials.from_service_account_file(self.credentials_path) + if self.credentials_path + else None + ) + ) + + def invoke( + self, input: str, config: Optional[RunnableConfig] = None + ) -> CheckGroundingResponse: + """ + Calls the Vertex Check Grounding API for a given answer candidate and a list + of documents (claims) to validate whether the set of claims support the + answer candidate. + + Args: + answer_candidate (str): The candidate answer to be evaluated for grounding. + documents (List[Document]): The documents against which grounding is + checked. This will be converted to facts: + facts (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\ + GroundingFact]): + List of facts for the grounding check. + We support up to 200 facts. + Returns: + Response of the type CheckGroundingResponse + + Attributes: + support_score (float): + The support score for the input answer + candidate. Higher the score, higher is the + fraction of claims that are supported by the + provided facts. This is always set when a + response is returned. + + cited_chunks (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\ + FactChunk]): + List of facts cited across all claims in the + answer candidate. These are derived from the + facts supplied in the request. + + claims (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\ + CheckGroundingResponse.Claim]): + Claim texts and citation info across all + claims in the answer candidate. + + answer_with_citations (str): + Complete formed answer formatted with inline citations + """ + from google.cloud import discoveryengine_v1alpha # type: ignore + + answer_candidate = input + documents = self.extract_documents(config) + + grounding_spec = discoveryengine_v1alpha.CheckGroundingSpec( + citation_threshold=self.citation_threshold, + ) + + facts = [ + discoveryengine_v1alpha.GroundingFact( + fact_text=doc.page_content, + attributes={ + key: value + for key, value in ( + doc.metadata or {} + ).items() # Use an empty dict if metadata is None + if key not in ["id", "relevance_score"] and value is not None + }, + ) + for doc in documents + if doc.page_content # Only check that page_content is not None or empty + ] + + if not facts: + raise ValueError("No valid documents provided for grounding.") + + request = discoveryengine_v1alpha.CheckGroundingRequest( + grounding_config=f"projects/{self.project_id}/locations/{self.location_id}/groundingConfigs/{self.grounding_config}", + answer_candidate=answer_candidate, + facts=facts, + grounding_spec=grounding_spec, + ) + + if self.client is None: + raise ValueError("Client not initialized.") + try: + response = self.client.check_grounding(request=request) + except core_exceptions.GoogleAPICallError as e: + raise RuntimeError( + f"Error in Vertex AI Check Grounding API call: {str(e)}" + ) from e + + support_score = response.support_score + cited_chunks = [ + { + "chunk_text": chunk.chunk_text, + "source": documents[int(chunk.source)], + } + for chunk in response.cited_chunks + ] + claims = [ + { + "start_pos": claim.start_pos, + "end_pos": claim.end_pos, + "claim_text": claim.claim_text, + "citation_indices": list(claim.citation_indices), + } + for claim in response.claims + ] + + answer_with_citations = self.combine_claims_with_citations(claims) + return self.CheckGroundingResponse( + support_score=support_score, + cited_chunks=cited_chunks, + claims=claims, + answer_with_citations=answer_with_citations, + ) + + def extract_documents(self, config: Optional[RunnableConfig]) -> List[Document]: + if not config: + raise ValueError("Configuration is required.") + + potential_documents = config.get("configurable", {}).get("documents", []) + if not isinstance(potential_documents, list) or not all( + isinstance(doc, Document) for doc in potential_documents + ): + raise ValueError("Invalid documents. Each must be an instance of Document.") + + if not potential_documents: + raise ValueError("This wrapper requires documents for processing.") + + return potential_documents + + def combine_claims_with_citations(self, claims: List[Dict[str, Any]]) -> str: + sorted_claims = sorted(claims, key=lambda x: x["start_pos"]) + result = [] + for claim in sorted_claims: + if claim["citation_indices"]: + citations = "".join([f"[{idx}]" for idx in claim["citation_indices"]]) + claim_text = f"{claim['claim_text']}{citations}" + else: + claim_text = claim["claim_text"] + result.append(claim_text) + return " ".join(result).strip() + + @classmethod + def get_lc_namespace(cls) -> List[str]: + return ["langchain", "utilities", "check_grounding"] + + @classmethod + def is_lc_serializable(cls) -> bool: + return False + + class Config: + extra = Extra.ignore + arbitrary_types_allowed = True diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index e112572a..f76d7eaa 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -105,6 +105,7 @@ select = [ [tool.mypy] disallow_untyped_defs = "True" +ignore_missing_imports = "True" [tool.coverage.run] omit = ["tests/*"] diff --git a/libs/community/tests/integration_tests/test_check_grounding.py b/libs/community/tests/integration_tests/test_check_grounding.py new file mode 100644 index 00000000..3bd571eb --- /dev/null +++ b/libs/community/tests/integration_tests/test_check_grounding.py @@ -0,0 +1,117 @@ +import os +from typing import List + +import pytest +from google.cloud import discoveryengine_v1alpha # type: ignore +from langchain_core.documents import Document + +from langchain_google_community.vertex_check_grounding import ( + VertexAICheckGroundingWrapper, +) + + +@pytest.fixture +def input_documents() -> List[Document]: + return [ + Document( + page_content=( + "Born in the German Empire, Einstein moved to Switzerland in 1895, " + "forsaking his German citizenship (as a subject of the Kingdom of " + "Württemberg)[note 1] the following year. In 1897, at the age of " + "seventeen, he enrolled in the mathematics and physics teaching " + "diploma program at the Swiss federal polytechnic school in Zürich, " + "graduating in 1900. In 1901, he acquired Swiss citizenship, which " + "he kept for the rest of his life. In 1903, he secured a permanent " + "position at the Swiss Patent Office in Bern. In 1905, he submitted " + "a successful PhD dissertation to the University of Zurich. In 1914, " + "he moved to Berlin in order to join the Prussian Academy of Sciences " + "and the Humboldt University of Berlin. In 1917, he became director " + "of the Kaiser Wilhelm Institute for Physics; he also became a German " + "citizen again, this time as a subject of the Kingdom of Prussia." + "\nIn 1933, while he was visiting the United States, Adolf Hitler came " + 'to power in Germany. Horrified by the Nazi "war of extermination" ' + "against his fellow Jews,[12] Einstein decided to remain in the US, " + "and was granted American citizenship in 1940.[13] On the eve of World " + "War II, he endorsed a letter to President Franklin D. Roosevelt " + "alerting him to the potential German nuclear weapons program and " + "recommending that the US begin similar research. Einstein supported " + "the Allies but generally viewed the idea of nuclear weapons with " + "great dismay.[14]" + ), + metadata={ + "language": "en", + "source": "https://en.wikipedia.org/wiki/Albert_Einstein", + "title": "Albert Einstein - Wikipedia", + }, + ), + Document( + page_content=( + "Life and career\n" + "Childhood, youth and education\n" + "See also: Einstein family\n" + "Einstein in 1882, age\xa03\n" + "Albert Einstein was born in Ulm,[19] in the Kingdom of Württemberg " + "in the German Empire, on 14 March 1879.[20][21] His parents, secular " + "Ashkenazi Jews, were Hermann Einstein, a salesman and engineer, and " + "Pauline Koch. In 1880, the family moved to Munich's borough of " + "Ludwigsvorstadt-Isarvorstadt, where Einstein's father and his uncle " + "Jakob founded Elektrotechnische Fabrik J. Einstein & Cie, a company " + "that manufactured electrical equipment based on direct current.[19]\n" + "Albert attended a Catholic elementary school in Munich from the age " + "of five. When he was eight, he was transferred to the Luitpold " + "Gymnasium, where he received advanced primary and then secondary " + "school education.[22]" + ), + metadata={ + "language": "en", + "source": "https://en.wikipedia.org/wiki/Albert_Einstein", + "title": "Albert Einstein - Wikipedia", + }, + ), + ] + + +@pytest.fixture +def grounded_generation_service_client() -> ( + discoveryengine_v1alpha.GroundedGenerationServiceClient +): + return discoveryengine_v1alpha.GroundedGenerationServiceClient() + + +@pytest.fixture +def output_parser( + grounded_generation_service_client: ( + discoveryengine_v1alpha.GroundedGenerationServiceClient + ), +) -> VertexAICheckGroundingWrapper: + return VertexAICheckGroundingWrapper( + project_id=os.environ["PROJECT_ID"], + location_id=os.environ.get("REGION", "global"), + grounding_config=os.environ.get("GROUNDING_CONFIG", "default_grounding_config"), + client=grounded_generation_service_client, + ) + + +@pytest.mark.extended +def test_integration_parse( + output_parser: VertexAICheckGroundingWrapper, + input_documents: List[Document], +) -> None: + answer_candidate = "Ulm, in the Kingdom of Württemberg in the German Empire" + response = output_parser.with_config( + configurable={"documents": input_documents} + ).invoke(answer_candidate) + + assert isinstance(response, VertexAICheckGroundingWrapper.CheckGroundingResponse) + assert response.support_score >= 0 and response.support_score <= 1 + assert len(response.cited_chunks) > 0 + for chunk in response.cited_chunks: + assert isinstance(chunk["chunk_text"], str) + assert isinstance(chunk["source"], Document) + assert len(response.claims) > 0 + for claim in response.claims: + assert isinstance(claim["start_pos"], int) + assert isinstance(claim["end_pos"], int) + assert isinstance(claim["claim_text"], str) + assert isinstance(claim["citation_indices"], list) + assert isinstance(response.answer_with_citations, str) diff --git a/libs/community/tests/integration_tests/test_rank.py b/libs/community/tests/integration_tests/test_rank.py index 0aff8c0c..f2b86487 100644 --- a/libs/community/tests/integration_tests/test_rank.py +++ b/libs/community/tests/integration_tests/test_rank.py @@ -3,7 +3,7 @@ from unittest.mock import create_autospec import pytest -from google.cloud import discoveryengine_v1alpha +from google.cloud import discoveryengine_v1alpha # type: ignore from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document from langchain_core.pydantic_v1 import Field diff --git a/libs/community/tests/unit_tests/test_check_grounding.py b/libs/community/tests/unit_tests/test_check_grounding.py new file mode 100644 index 00000000..033c0dfd --- /dev/null +++ b/libs/community/tests/unit_tests/test_check_grounding.py @@ -0,0 +1,129 @@ +from unittest.mock import Mock + +import pytest +from google.cloud import discoveryengine_v1alpha # type: ignore +from langchain_core.documents import Document + +from langchain_google_community.vertex_check_grounding import ( + VertexAICheckGroundingWrapper, +) + + +@pytest.fixture +def mock_check_grounding_service_client() -> Mock: + mock_client = Mock(spec=discoveryengine_v1alpha.GroundedGenerationServiceClient) + mock_client.check_grounding.return_value = discoveryengine_v1alpha.CheckGroundingResponse( # noqa: E501 + support_score=0.9919261932373047, + cited_chunks=[ + discoveryengine_v1alpha.FactChunk( + chunk_text=( + "Life and career\n" + "Childhood, youth and education\n" + "See also: Einstein family\n" + "Einstein in 1882, age\xa03\n" + "Albert Einstein was born in Ulm,[19] in the Kingdom of " + "Württemberg in the German Empire, on 14 March " + "1879.[20][21] His parents, secular Ashkenazi Jews, were " + "Hermann Einstein, a salesman and engineer, and " + "Pauline Koch. In 1880, the family moved to Munich's " + "borough of Ludwigsvorstadt-Isarvorstadt, where " + "Einstein's father and his uncle Jakob founded " + "Elektrotechnische Fabrik J. Einstein & Cie, a company " + "that manufactured electrical equipment based on direct " + "current.[19]\n" + "Albert attended a Catholic elementary school in Munich " + "from the age of five. When he was eight, he was " + "transferred to the Luitpold Gymnasium, where he received " + "advanced primary and then secondary school education.[22]" + ), + source="0", + ), + ], + claims=[ + discoveryengine_v1alpha.CheckGroundingResponse.Claim( + start_pos=0, + end_pos=56, + claim_text="Ulm, in the Kingdom of Württemberg in the German Empire", + citation_indices=[0], + ), + ], + ) + return mock_client + + +def test_parse(mock_check_grounding_service_client: Mock) -> None: + output_parser = VertexAICheckGroundingWrapper( + project_id="test-project", + client=mock_check_grounding_service_client, + ) + documents = [ + Document( + page_content=( + "Life and career\n" + "Childhood, youth and education\n" + "See also: Einstein family\n" + "Einstein in 1882, age\xa03\n" + "Albert Einstein was born in Ulm,[19] in the Kingdom of " + "Württemberg in the German Empire, on 14 March " + "1879.[20][21] His parents, secular Ashkenazi Jews, were " + "Hermann Einstein, a salesman and engineer, and " + "Pauline Koch. In 1880, the family moved to Munich's " + "borough of Ludwigsvorstadt-Isarvorstadt, where " + "Einstein's father and his uncle Jakob founded " + "Elektrotechnische Fabrik J. Einstein & Cie, a company that " + "manufactured electrical equipment based on direct current.[19]\n" + "Albert attended a Catholic elementary school in Munich " + "from the age of five. When he was eight, he was " + "transferred to the Luitpold Gymnasium, where he received " + "advanced primary and then secondary school education.[22]" + ), + metadata={ + "language": "en", + "source": "https://en.wikipedia.org/wiki/Albert_Einstein", + "title": "Albert Einstein - Wikipedia", + }, + ), + ] + answer_candidate = "Ulm, in the Kingdom of Württemberg in the German Empire" + response = output_parser.with_config(configurable={"documents": documents}).invoke( + answer_candidate + ) + + assert response == VertexAICheckGroundingWrapper.CheckGroundingResponse( + support_score=0.9919261932373047, + cited_chunks=[ + { + "chunk_text": ( + "Life and career\n" + "Childhood, youth and education\n" + "See also: Einstein family\n" + "Einstein in 1882, age\xa03\n" + "Albert Einstein was born in Ulm,[19] in the Kingdom of " + "Württemberg in the German Empire, on 14 March " + "1879.[20][21] His parents, secular Ashkenazi Jews, were " + "Hermann Einstein, a salesman and engineer, and " + "Pauline Koch. In 1880, the family moved to Munich's " + "borough of Ludwigsvorstadt-Isarvorstadt, where " + "Einstein's father and his uncle Jakob founded " + "Elektrotechnische Fabrik J. Einstein & Cie, a company that " + "manufactured electrical equipment based on direct current.[19]\n" + "Albert attended a Catholic elementary school in Munich " + "from the age of five. When he was eight, he was " + "transferred to the Luitpold Gymnasium, where he received " + "advanced primary and then secondary school education.[22]" + ), + "source": documents[0], + }, + ], + claims=[ + { + "start_pos": 0, + "end_pos": 56, + "claim_text": "Ulm, in the Kingdom of Württemberg in the German Empire", + "citation_indices": [0], + }, + ], + answer_with_citations=( + "Ulm, in the Kingdom of Württemberg in the German Empire[0]" + ), + ) diff --git a/libs/community/tests/unit_tests/test_rank.py b/libs/community/tests/unit_tests/test_rank.py index f35791a4..81049708 100644 --- a/libs/community/tests/unit_tests/test_rank.py +++ b/libs/community/tests/unit_tests/test_rank.py @@ -1,7 +1,7 @@ from unittest.mock import Mock, patch import pytest -from google.cloud import discoveryengine_v1alpha +from google.cloud import discoveryengine_v1alpha # type: ignore from langchain_core.documents import Document from pytest import approx