Skip to content

Commit

Permalink
feat: Vertex Check Grounding API integration (#186)
Browse files Browse the repository at this point in the history
* feat: Vertex Check Grounding API integration
  • Loading branch information
Abhishekbhagwat authored May 16, 2024
1 parent 8793ee8 commit 15b0264
Show file tree
Hide file tree
Showing 7 changed files with 494 additions and 2 deletions.
4 changes: 4 additions & 0 deletions libs/community/langchain_google_community/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
VertexAISearchRetriever,
VertexAISearchSummaryTool,
)
from langchain_google_community.vertex_check_grounding import (
VertexAICheckGroundingWrapper,
)
from langchain_google_community.vertex_rank import VertexAIRank
from langchain_google_community.vision import CloudVisionLoader, CloudVisionParser

Expand Down Expand Up @@ -52,4 +55,5 @@
"VertexAISearchRetriever",
"VertexAISearchSummaryTool",
"VertexAIRank",
"VertexAICheckGroundingWrapper",
]
241 changes: 241 additions & 0 deletions libs/community/langchain_google_community/vertex_check_grounding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from google.api_core import exceptions as core_exceptions # type: ignore
from google.auth.credentials import Credentials # type: ignore
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, Extra, Field
from langchain_core.runnables import RunnableConfig, RunnableSerializable

if TYPE_CHECKING:
from google.cloud import discoveryengine_v1alpha # type: ignore


class VertexAICheckGroundingWrapper(
RunnableSerializable[str, "VertexAICheckGroundingWrapper.CheckGroundingResponse"]
):
"""
Initializes the Vertex AI CheckGroundingOutputParser with configurable parameters.
Calls the Check Grounding API to validate the response against a given set of
documents and returns back citations that support the claims along with the cited
chunks. Output is of the type CheckGroundingResponse.
Attributes:
project_id (str): Google Cloud project ID
location_id (str): Location ID for the ranking service.
grounding_config (str):
Required. The resource name of the grounding config, such as
``default_grounding_config``.
It is set to ``default_grounding_config`` by default if unspecified
citation_threshold (float):
The threshold (in [0,1]) used for determining whether a fact
must be cited for a claim in the answer candidate. Choosing
a higher threshold will lead to fewer but very strong
citations, while choosing a lower threshold may lead to more
but somewhat weaker citations. If unset, the threshold will
default to 0.6.
credentials (Optional[Credentials]): Google Cloud credentials object.
credentials_path (Optional[str]): Path to the Google Cloud service
account credentials file.
"""

project_id: str = Field(default=None)
location_id: str = Field(default="global")
grounding_config: str = Field(default="default_grounding_config")
citation_threshold: Optional[float] = Field(default=0.6)
client: Any
credentials: Optional[Credentials] = Field(default=None)
credentials_path: Optional[str] = Field(default=None)

class CheckGroundingResponse(BaseModel):
support_score: float = 0.0
cited_chunks: List[Dict[str, Any]] = []
claims: List[Dict[str, Any]] = []
answer_with_citations: str = ""

def __init__(self, **kwargs: Any):
"""
Constructor for CheckGroundingOutputParser.
Initializes the grounding check service client with necessary credentials
and configurations.
"""
super().__init__(**kwargs)
self.client = kwargs.get("client")
if not self.client:
self.client = self._get_check_grounding_service_client()

def _get_check_grounding_service_client(
self,
) -> "discoveryengine_v1alpha.GroundedGenerationServiceClient":
"""
Returns a GroundedGenerationServiceClient instance using provided credentials.
Raises ImportError if necessary packages are not installed.
Returns:
A GroundedGenerationServiceClient instance.
"""
try:
from google.cloud import discoveryengine_v1alpha # type: ignore
except ImportError as exc:
raise ImportError(
"Could not import google-cloud-discoveryengine python package. "
"Please install vertexaisearch dependency group: "
"`pip install langchain-google-community[vertexaisearch]`"
) from exc
return discoveryengine_v1alpha.GroundedGenerationServiceClient(
credentials=(
self.credentials
or Credentials.from_service_account_file(self.credentials_path)
if self.credentials_path
else None
)
)

def invoke(
self, input: str, config: Optional[RunnableConfig] = None
) -> CheckGroundingResponse:
"""
Calls the Vertex Check Grounding API for a given answer candidate and a list
of documents (claims) to validate whether the set of claims support the
answer candidate.
Args:
answer_candidate (str): The candidate answer to be evaluated for grounding.
documents (List[Document]): The documents against which grounding is
checked. This will be converted to facts:
facts (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\
GroundingFact]):
List of facts for the grounding check.
We support up to 200 facts.
Returns:
Response of the type CheckGroundingResponse
Attributes:
support_score (float):
The support score for the input answer
candidate. Higher the score, higher is the
fraction of claims that are supported by the
provided facts. This is always set when a
response is returned.
cited_chunks (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\
FactChunk]):
List of facts cited across all claims in the
answer candidate. These are derived from the
facts supplied in the request.
claims (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\
CheckGroundingResponse.Claim]):
Claim texts and citation info across all
claims in the answer candidate.
answer_with_citations (str):
Complete formed answer formatted with inline citations
"""
from google.cloud import discoveryengine_v1alpha # type: ignore

answer_candidate = input
documents = self.extract_documents(config)

grounding_spec = discoveryengine_v1alpha.CheckGroundingSpec(
citation_threshold=self.citation_threshold,
)

facts = [
discoveryengine_v1alpha.GroundingFact(
fact_text=doc.page_content,
attributes={
key: value
for key, value in (
doc.metadata or {}
).items() # Use an empty dict if metadata is None
if key not in ["id", "relevance_score"] and value is not None
},
)
for doc in documents
if doc.page_content # Only check that page_content is not None or empty
]

if not facts:
raise ValueError("No valid documents provided for grounding.")

request = discoveryengine_v1alpha.CheckGroundingRequest(
grounding_config=f"projects/{self.project_id}/locations/{self.location_id}/groundingConfigs/{self.grounding_config}",
answer_candidate=answer_candidate,
facts=facts,
grounding_spec=grounding_spec,
)

if self.client is None:
raise ValueError("Client not initialized.")
try:
response = self.client.check_grounding(request=request)
except core_exceptions.GoogleAPICallError as e:
raise RuntimeError(
f"Error in Vertex AI Check Grounding API call: {str(e)}"
) from e

support_score = response.support_score
cited_chunks = [
{
"chunk_text": chunk.chunk_text,
"source": documents[int(chunk.source)],
}
for chunk in response.cited_chunks
]
claims = [
{
"start_pos": claim.start_pos,
"end_pos": claim.end_pos,
"claim_text": claim.claim_text,
"citation_indices": list(claim.citation_indices),
}
for claim in response.claims
]

answer_with_citations = self.combine_claims_with_citations(claims)
return self.CheckGroundingResponse(
support_score=support_score,
cited_chunks=cited_chunks,
claims=claims,
answer_with_citations=answer_with_citations,
)

def extract_documents(self, config: Optional[RunnableConfig]) -> List[Document]:
if not config:
raise ValueError("Configuration is required.")

potential_documents = config.get("configurable", {}).get("documents", [])
if not isinstance(potential_documents, list) or not all(
isinstance(doc, Document) for doc in potential_documents
):
raise ValueError("Invalid documents. Each must be an instance of Document.")

if not potential_documents:
raise ValueError("This wrapper requires documents for processing.")

return potential_documents

def combine_claims_with_citations(self, claims: List[Dict[str, Any]]) -> str:
sorted_claims = sorted(claims, key=lambda x: x["start_pos"])
result = []
for claim in sorted_claims:
if claim["citation_indices"]:
citations = "".join([f"[{idx}]" for idx in claim["citation_indices"]])
claim_text = f"{claim['claim_text']}{citations}"
else:
claim_text = claim["claim_text"]
result.append(claim_text)
return " ".join(result).strip()

@classmethod
def get_lc_namespace(cls) -> List[str]:
return ["langchain", "utilities", "check_grounding"]

@classmethod
def is_lc_serializable(cls) -> bool:
return False

class Config:
extra = Extra.ignore
arbitrary_types_allowed = True
1 change: 1 addition & 0 deletions libs/community/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ select = [

[tool.mypy]
disallow_untyped_defs = "True"
ignore_missing_imports = "True"

[tool.coverage.run]
omit = ["tests/*"]
Expand Down
117 changes: 117 additions & 0 deletions libs/community/tests/integration_tests/test_check_grounding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import os
from typing import List

import pytest
from google.cloud import discoveryengine_v1alpha # type: ignore
from langchain_core.documents import Document

from langchain_google_community.vertex_check_grounding import (
VertexAICheckGroundingWrapper,
)


@pytest.fixture
def input_documents() -> List[Document]:
return [
Document(
page_content=(
"Born in the German Empire, Einstein moved to Switzerland in 1895, "
"forsaking his German citizenship (as a subject of the Kingdom of "
"Württemberg)[note 1] the following year. In 1897, at the age of "
"seventeen, he enrolled in the mathematics and physics teaching "
"diploma program at the Swiss federal polytechnic school in Zürich, "
"graduating in 1900. In 1901, he acquired Swiss citizenship, which "
"he kept for the rest of his life. In 1903, he secured a permanent "
"position at the Swiss Patent Office in Bern. In 1905, he submitted "
"a successful PhD dissertation to the University of Zurich. In 1914, "
"he moved to Berlin in order to join the Prussian Academy of Sciences "
"and the Humboldt University of Berlin. In 1917, he became director "
"of the Kaiser Wilhelm Institute for Physics; he also became a German "
"citizen again, this time as a subject of the Kingdom of Prussia."
"\nIn 1933, while he was visiting the United States, Adolf Hitler came "
'to power in Germany. Horrified by the Nazi "war of extermination" '
"against his fellow Jews,[12] Einstein decided to remain in the US, "
"and was granted American citizenship in 1940.[13] On the eve of World "
"War II, he endorsed a letter to President Franklin D. Roosevelt "
"alerting him to the potential German nuclear weapons program and "
"recommending that the US begin similar research. Einstein supported "
"the Allies but generally viewed the idea of nuclear weapons with "
"great dismay.[14]"
),
metadata={
"language": "en",
"source": "https://en.wikipedia.org/wiki/Albert_Einstein",
"title": "Albert Einstein - Wikipedia",
},
),
Document(
page_content=(
"Life and career\n"
"Childhood, youth and education\n"
"See also: Einstein family\n"
"Einstein in 1882, age\xa03\n"
"Albert Einstein was born in Ulm,[19] in the Kingdom of Württemberg "
"in the German Empire, on 14 March 1879.[20][21] His parents, secular "
"Ashkenazi Jews, were Hermann Einstein, a salesman and engineer, and "
"Pauline Koch. In 1880, the family moved to Munich's borough of "
"Ludwigsvorstadt-Isarvorstadt, where Einstein's father and his uncle "
"Jakob founded Elektrotechnische Fabrik J. Einstein & Cie, a company "
"that manufactured electrical equipment based on direct current.[19]\n"
"Albert attended a Catholic elementary school in Munich from the age "
"of five. When he was eight, he was transferred to the Luitpold "
"Gymnasium, where he received advanced primary and then secondary "
"school education.[22]"
),
metadata={
"language": "en",
"source": "https://en.wikipedia.org/wiki/Albert_Einstein",
"title": "Albert Einstein - Wikipedia",
},
),
]


@pytest.fixture
def grounded_generation_service_client() -> (
discoveryengine_v1alpha.GroundedGenerationServiceClient
):
return discoveryengine_v1alpha.GroundedGenerationServiceClient()


@pytest.fixture
def output_parser(
grounded_generation_service_client: (
discoveryengine_v1alpha.GroundedGenerationServiceClient
),
) -> VertexAICheckGroundingWrapper:
return VertexAICheckGroundingWrapper(
project_id=os.environ["PROJECT_ID"],
location_id=os.environ.get("REGION", "global"),
grounding_config=os.environ.get("GROUNDING_CONFIG", "default_grounding_config"),
client=grounded_generation_service_client,
)


@pytest.mark.extended
def test_integration_parse(
output_parser: VertexAICheckGroundingWrapper,
input_documents: List[Document],
) -> None:
answer_candidate = "Ulm, in the Kingdom of Württemberg in the German Empire"
response = output_parser.with_config(
configurable={"documents": input_documents}
).invoke(answer_candidate)

assert isinstance(response, VertexAICheckGroundingWrapper.CheckGroundingResponse)
assert response.support_score >= 0 and response.support_score <= 1
assert len(response.cited_chunks) > 0
for chunk in response.cited_chunks:
assert isinstance(chunk["chunk_text"], str)
assert isinstance(chunk["source"], Document)
assert len(response.claims) > 0
for claim in response.claims:
assert isinstance(claim["start_pos"], int)
assert isinstance(claim["end_pos"], int)
assert isinstance(claim["claim_text"], str)
assert isinstance(claim["citation_indices"], list)
assert isinstance(response.answer_with_citations, str)
2 changes: 1 addition & 1 deletion libs/community/tests/integration_tests/test_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from unittest.mock import create_autospec

import pytest
from google.cloud import discoveryengine_v1alpha
from google.cloud import discoveryengine_v1alpha # type: ignore
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field
Expand Down
Loading

0 comments on commit 15b0264

Please sign in to comment.