-
Notifications
You must be signed in to change notification settings - Fork 167
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Vertex Check Grounding API integration (#186)
* feat: Vertex Check Grounding API integration
- Loading branch information
1 parent
8793ee8
commit 15b0264
Showing
7 changed files
with
494 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
241 changes: 241 additions & 0 deletions
241
libs/community/langchain_google_community/vertex_check_grounding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,241 @@ | ||
from typing import TYPE_CHECKING, Any, Dict, List, Optional | ||
|
||
from google.api_core import exceptions as core_exceptions # type: ignore | ||
from google.auth.credentials import Credentials # type: ignore | ||
from langchain_core.documents import Document | ||
from langchain_core.pydantic_v1 import BaseModel, Extra, Field | ||
from langchain_core.runnables import RunnableConfig, RunnableSerializable | ||
|
||
if TYPE_CHECKING: | ||
from google.cloud import discoveryengine_v1alpha # type: ignore | ||
|
||
|
||
class VertexAICheckGroundingWrapper( | ||
RunnableSerializable[str, "VertexAICheckGroundingWrapper.CheckGroundingResponse"] | ||
): | ||
""" | ||
Initializes the Vertex AI CheckGroundingOutputParser with configurable parameters. | ||
Calls the Check Grounding API to validate the response against a given set of | ||
documents and returns back citations that support the claims along with the cited | ||
chunks. Output is of the type CheckGroundingResponse. | ||
Attributes: | ||
project_id (str): Google Cloud project ID | ||
location_id (str): Location ID for the ranking service. | ||
grounding_config (str): | ||
Required. The resource name of the grounding config, such as | ||
``default_grounding_config``. | ||
It is set to ``default_grounding_config`` by default if unspecified | ||
citation_threshold (float): | ||
The threshold (in [0,1]) used for determining whether a fact | ||
must be cited for a claim in the answer candidate. Choosing | ||
a higher threshold will lead to fewer but very strong | ||
citations, while choosing a lower threshold may lead to more | ||
but somewhat weaker citations. If unset, the threshold will | ||
default to 0.6. | ||
credentials (Optional[Credentials]): Google Cloud credentials object. | ||
credentials_path (Optional[str]): Path to the Google Cloud service | ||
account credentials file. | ||
""" | ||
|
||
project_id: str = Field(default=None) | ||
location_id: str = Field(default="global") | ||
grounding_config: str = Field(default="default_grounding_config") | ||
citation_threshold: Optional[float] = Field(default=0.6) | ||
client: Any | ||
credentials: Optional[Credentials] = Field(default=None) | ||
credentials_path: Optional[str] = Field(default=None) | ||
|
||
class CheckGroundingResponse(BaseModel): | ||
support_score: float = 0.0 | ||
cited_chunks: List[Dict[str, Any]] = [] | ||
claims: List[Dict[str, Any]] = [] | ||
answer_with_citations: str = "" | ||
|
||
def __init__(self, **kwargs: Any): | ||
""" | ||
Constructor for CheckGroundingOutputParser. | ||
Initializes the grounding check service client with necessary credentials | ||
and configurations. | ||
""" | ||
super().__init__(**kwargs) | ||
self.client = kwargs.get("client") | ||
if not self.client: | ||
self.client = self._get_check_grounding_service_client() | ||
|
||
def _get_check_grounding_service_client( | ||
self, | ||
) -> "discoveryengine_v1alpha.GroundedGenerationServiceClient": | ||
""" | ||
Returns a GroundedGenerationServiceClient instance using provided credentials. | ||
Raises ImportError if necessary packages are not installed. | ||
Returns: | ||
A GroundedGenerationServiceClient instance. | ||
""" | ||
try: | ||
from google.cloud import discoveryengine_v1alpha # type: ignore | ||
except ImportError as exc: | ||
raise ImportError( | ||
"Could not import google-cloud-discoveryengine python package. " | ||
"Please install vertexaisearch dependency group: " | ||
"`pip install langchain-google-community[vertexaisearch]`" | ||
) from exc | ||
return discoveryengine_v1alpha.GroundedGenerationServiceClient( | ||
credentials=( | ||
self.credentials | ||
or Credentials.from_service_account_file(self.credentials_path) | ||
if self.credentials_path | ||
else None | ||
) | ||
) | ||
|
||
def invoke( | ||
self, input: str, config: Optional[RunnableConfig] = None | ||
) -> CheckGroundingResponse: | ||
""" | ||
Calls the Vertex Check Grounding API for a given answer candidate and a list | ||
of documents (claims) to validate whether the set of claims support the | ||
answer candidate. | ||
Args: | ||
answer_candidate (str): The candidate answer to be evaluated for grounding. | ||
documents (List[Document]): The documents against which grounding is | ||
checked. This will be converted to facts: | ||
facts (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\ | ||
GroundingFact]): | ||
List of facts for the grounding check. | ||
We support up to 200 facts. | ||
Returns: | ||
Response of the type CheckGroundingResponse | ||
Attributes: | ||
support_score (float): | ||
The support score for the input answer | ||
candidate. Higher the score, higher is the | ||
fraction of claims that are supported by the | ||
provided facts. This is always set when a | ||
response is returned. | ||
cited_chunks (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\ | ||
FactChunk]): | ||
List of facts cited across all claims in the | ||
answer candidate. These are derived from the | ||
facts supplied in the request. | ||
claims (MutableSequence[google.cloud.discoveryengine_v1alpha.types.\ | ||
CheckGroundingResponse.Claim]): | ||
Claim texts and citation info across all | ||
claims in the answer candidate. | ||
answer_with_citations (str): | ||
Complete formed answer formatted with inline citations | ||
""" | ||
from google.cloud import discoveryengine_v1alpha # type: ignore | ||
|
||
answer_candidate = input | ||
documents = self.extract_documents(config) | ||
|
||
grounding_spec = discoveryengine_v1alpha.CheckGroundingSpec( | ||
citation_threshold=self.citation_threshold, | ||
) | ||
|
||
facts = [ | ||
discoveryengine_v1alpha.GroundingFact( | ||
fact_text=doc.page_content, | ||
attributes={ | ||
key: value | ||
for key, value in ( | ||
doc.metadata or {} | ||
).items() # Use an empty dict if metadata is None | ||
if key not in ["id", "relevance_score"] and value is not None | ||
}, | ||
) | ||
for doc in documents | ||
if doc.page_content # Only check that page_content is not None or empty | ||
] | ||
|
||
if not facts: | ||
raise ValueError("No valid documents provided for grounding.") | ||
|
||
request = discoveryengine_v1alpha.CheckGroundingRequest( | ||
grounding_config=f"projects/{self.project_id}/locations/{self.location_id}/groundingConfigs/{self.grounding_config}", | ||
answer_candidate=answer_candidate, | ||
facts=facts, | ||
grounding_spec=grounding_spec, | ||
) | ||
|
||
if self.client is None: | ||
raise ValueError("Client not initialized.") | ||
try: | ||
response = self.client.check_grounding(request=request) | ||
except core_exceptions.GoogleAPICallError as e: | ||
raise RuntimeError( | ||
f"Error in Vertex AI Check Grounding API call: {str(e)}" | ||
) from e | ||
|
||
support_score = response.support_score | ||
cited_chunks = [ | ||
{ | ||
"chunk_text": chunk.chunk_text, | ||
"source": documents[int(chunk.source)], | ||
} | ||
for chunk in response.cited_chunks | ||
] | ||
claims = [ | ||
{ | ||
"start_pos": claim.start_pos, | ||
"end_pos": claim.end_pos, | ||
"claim_text": claim.claim_text, | ||
"citation_indices": list(claim.citation_indices), | ||
} | ||
for claim in response.claims | ||
] | ||
|
||
answer_with_citations = self.combine_claims_with_citations(claims) | ||
return self.CheckGroundingResponse( | ||
support_score=support_score, | ||
cited_chunks=cited_chunks, | ||
claims=claims, | ||
answer_with_citations=answer_with_citations, | ||
) | ||
|
||
def extract_documents(self, config: Optional[RunnableConfig]) -> List[Document]: | ||
if not config: | ||
raise ValueError("Configuration is required.") | ||
|
||
potential_documents = config.get("configurable", {}).get("documents", []) | ||
if not isinstance(potential_documents, list) or not all( | ||
isinstance(doc, Document) for doc in potential_documents | ||
): | ||
raise ValueError("Invalid documents. Each must be an instance of Document.") | ||
|
||
if not potential_documents: | ||
raise ValueError("This wrapper requires documents for processing.") | ||
|
||
return potential_documents | ||
|
||
def combine_claims_with_citations(self, claims: List[Dict[str, Any]]) -> str: | ||
sorted_claims = sorted(claims, key=lambda x: x["start_pos"]) | ||
result = [] | ||
for claim in sorted_claims: | ||
if claim["citation_indices"]: | ||
citations = "".join([f"[{idx}]" for idx in claim["citation_indices"]]) | ||
claim_text = f"{claim['claim_text']}{citations}" | ||
else: | ||
claim_text = claim["claim_text"] | ||
result.append(claim_text) | ||
return " ".join(result).strip() | ||
|
||
@classmethod | ||
def get_lc_namespace(cls) -> List[str]: | ||
return ["langchain", "utilities", "check_grounding"] | ||
|
||
@classmethod | ||
def is_lc_serializable(cls) -> bool: | ||
return False | ||
|
||
class Config: | ||
extra = Extra.ignore | ||
arbitrary_types_allowed = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
117 changes: 117 additions & 0 deletions
117
libs/community/tests/integration_tests/test_check_grounding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import os | ||
from typing import List | ||
|
||
import pytest | ||
from google.cloud import discoveryengine_v1alpha # type: ignore | ||
from langchain_core.documents import Document | ||
|
||
from langchain_google_community.vertex_check_grounding import ( | ||
VertexAICheckGroundingWrapper, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def input_documents() -> List[Document]: | ||
return [ | ||
Document( | ||
page_content=( | ||
"Born in the German Empire, Einstein moved to Switzerland in 1895, " | ||
"forsaking his German citizenship (as a subject of the Kingdom of " | ||
"Württemberg)[note 1] the following year. In 1897, at the age of " | ||
"seventeen, he enrolled in the mathematics and physics teaching " | ||
"diploma program at the Swiss federal polytechnic school in Zürich, " | ||
"graduating in 1900. In 1901, he acquired Swiss citizenship, which " | ||
"he kept for the rest of his life. In 1903, he secured a permanent " | ||
"position at the Swiss Patent Office in Bern. In 1905, he submitted " | ||
"a successful PhD dissertation to the University of Zurich. In 1914, " | ||
"he moved to Berlin in order to join the Prussian Academy of Sciences " | ||
"and the Humboldt University of Berlin. In 1917, he became director " | ||
"of the Kaiser Wilhelm Institute for Physics; he also became a German " | ||
"citizen again, this time as a subject of the Kingdom of Prussia." | ||
"\nIn 1933, while he was visiting the United States, Adolf Hitler came " | ||
'to power in Germany. Horrified by the Nazi "war of extermination" ' | ||
"against his fellow Jews,[12] Einstein decided to remain in the US, " | ||
"and was granted American citizenship in 1940.[13] On the eve of World " | ||
"War II, he endorsed a letter to President Franklin D. Roosevelt " | ||
"alerting him to the potential German nuclear weapons program and " | ||
"recommending that the US begin similar research. Einstein supported " | ||
"the Allies but generally viewed the idea of nuclear weapons with " | ||
"great dismay.[14]" | ||
), | ||
metadata={ | ||
"language": "en", | ||
"source": "https://en.wikipedia.org/wiki/Albert_Einstein", | ||
"title": "Albert Einstein - Wikipedia", | ||
}, | ||
), | ||
Document( | ||
page_content=( | ||
"Life and career\n" | ||
"Childhood, youth and education\n" | ||
"See also: Einstein family\n" | ||
"Einstein in 1882, age\xa03\n" | ||
"Albert Einstein was born in Ulm,[19] in the Kingdom of Württemberg " | ||
"in the German Empire, on 14 March 1879.[20][21] His parents, secular " | ||
"Ashkenazi Jews, were Hermann Einstein, a salesman and engineer, and " | ||
"Pauline Koch. In 1880, the family moved to Munich's borough of " | ||
"Ludwigsvorstadt-Isarvorstadt, where Einstein's father and his uncle " | ||
"Jakob founded Elektrotechnische Fabrik J. Einstein & Cie, a company " | ||
"that manufactured electrical equipment based on direct current.[19]\n" | ||
"Albert attended a Catholic elementary school in Munich from the age " | ||
"of five. When he was eight, he was transferred to the Luitpold " | ||
"Gymnasium, where he received advanced primary and then secondary " | ||
"school education.[22]" | ||
), | ||
metadata={ | ||
"language": "en", | ||
"source": "https://en.wikipedia.org/wiki/Albert_Einstein", | ||
"title": "Albert Einstein - Wikipedia", | ||
}, | ||
), | ||
] | ||
|
||
|
||
@pytest.fixture | ||
def grounded_generation_service_client() -> ( | ||
discoveryengine_v1alpha.GroundedGenerationServiceClient | ||
): | ||
return discoveryengine_v1alpha.GroundedGenerationServiceClient() | ||
|
||
|
||
@pytest.fixture | ||
def output_parser( | ||
grounded_generation_service_client: ( | ||
discoveryengine_v1alpha.GroundedGenerationServiceClient | ||
), | ||
) -> VertexAICheckGroundingWrapper: | ||
return VertexAICheckGroundingWrapper( | ||
project_id=os.environ["PROJECT_ID"], | ||
location_id=os.environ.get("REGION", "global"), | ||
grounding_config=os.environ.get("GROUNDING_CONFIG", "default_grounding_config"), | ||
client=grounded_generation_service_client, | ||
) | ||
|
||
|
||
@pytest.mark.extended | ||
def test_integration_parse( | ||
output_parser: VertexAICheckGroundingWrapper, | ||
input_documents: List[Document], | ||
) -> None: | ||
answer_candidate = "Ulm, in the Kingdom of Württemberg in the German Empire" | ||
response = output_parser.with_config( | ||
configurable={"documents": input_documents} | ||
).invoke(answer_candidate) | ||
|
||
assert isinstance(response, VertexAICheckGroundingWrapper.CheckGroundingResponse) | ||
assert response.support_score >= 0 and response.support_score <= 1 | ||
assert len(response.cited_chunks) > 0 | ||
for chunk in response.cited_chunks: | ||
assert isinstance(chunk["chunk_text"], str) | ||
assert isinstance(chunk["source"], Document) | ||
assert len(response.claims) > 0 | ||
for claim in response.claims: | ||
assert isinstance(claim["start_pos"], int) | ||
assert isinstance(claim["end_pos"], int) | ||
assert isinstance(claim["claim_text"], str) | ||
assert isinstance(claim["citation_indices"], list) | ||
assert isinstance(response.answer_with_citations, str) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.