From 3c6e23ad95a0c02e7ba411a5c22020db9b04d148 Mon Sep 17 00:00:00 2001 From: Lawrence Tsang Date: Tue, 5 Mar 2024 04:01:29 -0500 Subject: [PATCH] Introducing Google Semantic Retriever and Attributed Question and Answering (AQA) (#48) --- libs/genai/README.md | 31 +- libs/genai/langchain_google_genai/__init__.py | 15 + .../_genai_extension.py | 618 ++++++++++++++++++ .../genai/langchain_google_genai/genai_aqa.py | 134 ++++ .../google_vector_store.py | 493 ++++++++++++++ libs/genai/poetry.lock | 16 +- libs/genai/pyproject.toml | 1 + libs/genai/tests/unit_tests/test_genai_aqa.py | 95 +++ .../unit_tests/test_google_vector_store.py | 440 +++++++++++++ libs/genai/tests/unit_tests/test_imports.py | 6 + 10 files changed, 1845 insertions(+), 4 deletions(-) create mode 100644 libs/genai/langchain_google_genai/_genai_extension.py create mode 100644 libs/genai/langchain_google_genai/genai_aqa.py create mode 100644 libs/genai/langchain_google_genai/google_vector_store.py create mode 100644 libs/genai/tests/unit_tests/test_genai_aqa.py create mode 100644 libs/genai/tests/unit_tests/test_google_vector_store.py diff --git a/libs/genai/README.md b/libs/genai/README.md index 7de2b6f6..101fed80 100644 --- a/libs/genai/README.md +++ b/libs/genai/README.md @@ -75,4 +75,33 @@ from langchain_google_genai import GoogleGenerativeAIEmbeddings embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001") embeddings.embed_query("hello, world!") -``` \ No newline at end of file +``` + +## Semantic Retrieval + +Enables retrieval augmented generation (RAG) in your application. + +``` +# Create a new store for housing your documents. +corpus_store = GoogleVectorStore.create_corpus(display_name="My Corpus") + +# Create a new document under the above corpus. +document_store = GoogleVectorStore.create_document( + corpus_id=corpus_store.corpus_id, display_name="My Document" +) + +# Upload some texts to the document. +text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=0) +for file in DirectoryLoader(path="data/").load(): + documents = text_splitter.split_documents([file]) + document_store.add_documents(documents) + +# Talk to your entire corpus with possibly many documents. +aqa = corpus_store.as_aqa() +answer = aqa.invoke("What is the meaning of life?") + +# Read the response along with the attributed passages and answerability. +print(response.answer) +print(response.attributed_passages) +print(response.answerable_probability) +``` diff --git a/libs/genai/langchain_google_genai/__init__.py b/libs/genai/langchain_google_genai/__init__.py index 187f7e3e..f7b3ac9a 100644 --- a/libs/genai/langchain_google_genai/__init__.py +++ b/libs/genai/langchain_google_genai/__init__.py @@ -58,12 +58,27 @@ from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory from langchain_google_genai.chat_models import ChatGoogleGenerativeAI from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings +from langchain_google_genai.genai_aqa import ( + AqaInput, + AqaOutput, + GenAIAqa, +) +from langchain_google_genai.google_vector_store import ( + DoesNotExistsException, + GoogleVectorStore, +) from langchain_google_genai.llms import GoogleGenerativeAI __all__ = [ + "AqaInput", + "AqaOutput", "ChatGoogleGenerativeAI", + "DoesNotExistsException", + "GenAIAqa", "GoogleGenerativeAIEmbeddings", "GoogleGenerativeAI", + "GoogleVectorStore", "HarmBlockThreshold", "HarmCategory", + "DoesNotExistsException", ] diff --git a/libs/genai/langchain_google_genai/_genai_extension.py b/libs/genai/langchain_google_genai/_genai_extension.py new file mode 100644 index 00000000..b7b9dc97 --- /dev/null +++ b/libs/genai/langchain_google_genai/_genai_extension.py @@ -0,0 +1,618 @@ +"""Temporary high-level library of the Google GenerativeAI API. + +The content of this file should eventually go into the Python package +google.generativeai. +""" + +import datetime +import logging +import re +from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, MutableSequence, Optional + +import google.ai.generativelanguage as genai +import langchain_core +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as gapi_exception +from google.api_core import gapic_v1 +from google.auth import credentials, exceptions # type: ignore +from google.protobuf import timestamp_pb2 + +_logger = logging.getLogger(__name__) +_DEFAULT_API_ENDPOINT = "generativelanguage.googleapis.com" +_USER_AGENT = f"langchain/{langchain_core.__version__}" +_DEFAULT_PAGE_SIZE = 20 +_DEFAULT_GENERATE_SERVICE_MODEL = "models/aqa" +_MAX_REQUEST_PER_CHUNK = 100 +_NAME_REGEX = re.compile(r"^corpora/([^/]+?)(/documents/([^/]+?)(/chunks/([^/]+?))?)?$") + + +@dataclass +class EntityName: + corpus_id: str + document_id: Optional[str] = None + chunk_id: Optional[str] = None + + def __post_init__(self) -> None: + if self.chunk_id is not None and self.document_id is None: + raise ValueError(f"Chunk must have document ID but found {self}") + + @classmethod + def from_str(cls, encoded: str) -> "EntityName": + matched = _NAME_REGEX.match(encoded) + if not matched: + raise ValueError(f"Invalid entity name: {encoded}") + + return cls( + corpus_id=matched.group(1), + document_id=matched.group(3), + chunk_id=matched.group(5), + ) + + def __repr__(self) -> str: + name = f"corpora/{self.corpus_id}" + if self.document_id is None: + return name + name += f"/documents/{self.document_id}" + if self.chunk_id is None: + return name + name += f"/chunks/{self.chunk_id}" + return name + + def __str__(self) -> str: + return repr(self) + + def is_corpus(self) -> bool: + return self.document_id is None + + def is_document(self) -> bool: + return self.document_id is not None and self.chunk_id is None + + def is_chunk(self) -> bool: + return self.chunk_id is not None + + +@dataclass +class Corpus: + name: str + display_name: Optional[str] + create_time: Optional[timestamp_pb2.Timestamp] + update_time: Optional[timestamp_pb2.Timestamp] + + @property + def corpus_id(self) -> str: + name = EntityName.from_str(self.name) + return name.corpus_id + + @classmethod + def from_corpus(cls, c: genai.Corpus) -> "Corpus": + return cls( + name=c.name, + display_name=c.display_name, + create_time=c.create_time, + update_time=c.update_time, + ) + + +@dataclass +class Document: + name: str + display_name: Optional[str] + create_time: Optional[timestamp_pb2.Timestamp] + update_time: Optional[timestamp_pb2.Timestamp] + custom_metadata: Optional[MutableSequence[genai.CustomMetadata]] + + @property + def corpus_id(self) -> str: + name = EntityName.from_str(self.name) + return name.corpus_id + + @property + def document_id(self) -> str: + name = EntityName.from_str(self.name) + assert isinstance(name.document_id, str) + return name.document_id + + @classmethod + def from_document(cls, d: genai.Document) -> "Document": + return cls( + name=d.name, + display_name=d.display_name, + create_time=d.create_time, + update_time=d.update_time, + custom_metadata=d.custom_metadata, + ) + + +@dataclass +class Config: + """Global configuration for Google Generative AI API. + + Normally, the defaults should work fine. Use this to pass Google Auth credentials + such as using a service account. Refer to for auth credentials documentation: + https://developers.google.com/identity/protocols/oauth2/service-account#creatinganaccount. + + Attributes: + api_endpoint: The Google Generative API endpoint address. + user_agent: The user agent to use for logging. + page_size: For paging RPCs, how many entities to return per RPC. + testing: Are the unit tests running? + auth_credentials: For setting credentials such as using service accounts. + """ + + api_endpoint: str = _DEFAULT_API_ENDPOINT + user_agent: str = _USER_AGENT + page_size: int = _DEFAULT_PAGE_SIZE + testing: bool = False + auth_credentials: Optional[credentials.Credentials] = None + + +def set_config(config: Config) -> None: + """Set global defaults for operations with Google Generative AI API.""" + global _config + _config = config + + +def get_config() -> Config: + return _config + + +_config = Config() + + +class TestCredentials(credentials.Credentials): + """Credentials that do not provide any authentication information. + + Useful for unit tests where the credentials are not used. + """ + + @property + def expired(self) -> bool: + """Returns `False`, test credentials never expire.""" + return False + + @property + def valid(self) -> bool: + """Returns `True`, test credentials are always valid.""" + return True + + def refresh(self, request: Any) -> None: + """Raises :class:``InvalidOperation``, test credentials cannot be + refreshed. + """ + raise exceptions.InvalidOperation("Test credentials cannot be refreshed.") + + def apply(self, headers: Any, token: Any = None) -> None: + """Anonymous credentials do nothing to the request. + + The optional ``token`` argument is not supported. + + Raises: + google.auth.exceptions.InvalidValue: If a token was specified. + """ + if token is not None: + raise exceptions.InvalidValue("Test credentials don't support tokens.") + + def before_request(self, request: Any, method: Any, url: Any, headers: Any) -> None: + """Test credentials do nothing to the request.""" + + +def _get_credentials() -> Optional[credentials.Credentials]: + """Returns credential from config if set or fake credentials for unit testing. + + If _config.testing is True, a fake credential is returned. + Otherwise, we are in a real environment and will use credentials if provided + or None is returned. + + If None is passed to the clients later on, the actual credentials will be + inferred by the rules specified in google.auth package. + """ + if _config.testing: + return TestCredentials() + elif _config.auth_credentials: + return _config.auth_credentials + return None + + +def build_semantic_retriever() -> genai.RetrieverServiceClient: + credentials = _get_credentials() + return genai.RetrieverServiceClient( + credentials=credentials, + client_info=gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT), + client_options=client_options_lib.ClientOptions( + api_endpoint=_config.api_endpoint + ), + ) + + +def build_generative_service() -> genai.GenerativeServiceClient: + credentials = _get_credentials() + return genai.GenerativeServiceClient( + credentials=credentials, + client_info=gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT), + client_options=client_options_lib.ClientOptions( + api_endpoint=_config.api_endpoint + ), + ) + + +def list_corpora( + *, + client: genai.RetrieverServiceClient, +) -> Iterator[Corpus]: + for corpus in client.list_corpora( + genai.ListCorporaRequest(page_size=_config.page_size) + ): + yield Corpus.from_corpus(corpus) + + +def get_corpus( + *, + corpus_id: str, + client: genai.RetrieverServiceClient, +) -> Optional[Corpus]: + try: + corpus = client.get_corpus( + genai.GetCorpusRequest(name=str(EntityName(corpus_id=corpus_id))) + ) + return Corpus.from_corpus(corpus) + except Exception as e: + # If the corpus does not exist, the server returns a permission error. + if not isinstance(e, gapi_exception.PermissionDenied): + raise + _logger.warning(f"Corpus {corpus_id} not found: {e}") + return None + + +def create_corpus( + *, + corpus_id: Optional[str] = None, + display_name: Optional[str] = None, + client: genai.RetrieverServiceClient, +) -> Corpus: + name: Optional[str] + if corpus_id is not None: + name = str(EntityName(corpus_id=corpus_id)) + else: + name = None + + new_display_name = display_name or f"Untitled {datetime.datetime.now()}" + + new_corpus = client.create_corpus( + genai.CreateCorpusRequest( + corpus=genai.Corpus(name=name, display_name=new_display_name) + ) + ) + + return Corpus.from_corpus(new_corpus) + + +def delete_corpus( + *, + corpus_id: str, + client: genai.RetrieverServiceClient, +) -> None: + client.delete_corpus( + genai.DeleteCorpusRequest(name=str(EntityName(corpus_id=corpus_id)), force=True) + ) + + +def list_documents( + *, + corpus_id: str, + client: genai.RetrieverServiceClient, +) -> Iterator[Document]: + for document in client.list_documents( + genai.ListDocumentsRequest( + parent=str(EntityName(corpus_id=corpus_id)), page_size=_DEFAULT_PAGE_SIZE + ) + ): + yield Document.from_document(document) + + +def get_document( + *, + corpus_id: str, + document_id: str, + client: genai.RetrieverServiceClient, +) -> Optional[Document]: + try: + document = client.get_document( + genai.GetDocumentRequest( + name=str(EntityName(corpus_id=corpus_id, document_id=document_id)) + ) + ) + return Document.from_document(document) + except Exception as e: + if not isinstance(e, gapi_exception.NotFound): + raise + _logger.warning(f"Document {document_id} in corpus {corpus_id} not found: {e}") + return None + + +def create_document( + *, + corpus_id: str, + document_id: Optional[str] = None, + display_name: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + client: genai.RetrieverServiceClient, +) -> Document: + name: Optional[str] + if document_id is not None: + name = str(EntityName(corpus_id=corpus_id, document_id=document_id)) + else: + name = None + + new_display_name = display_name or f"Untitled {datetime.datetime.now()}" + new_metadatas = _convert_to_metadata(metadata) if metadata else None + + new_document = client.create_document( + genai.CreateDocumentRequest( + parent=str(EntityName(corpus_id=corpus_id)), + document=genai.Document( + name=name, display_name=new_display_name, custom_metadata=new_metadatas + ), + ) + ) + + return Document.from_document(new_document) + + +def delete_document( + *, + corpus_id: str, + document_id: str, + client: genai.RetrieverServiceClient, +) -> None: + client.delete_document( + genai.DeleteDocumentRequest( + name=str(EntityName(corpus_id=corpus_id, document_id=document_id)), + force=True, + ) + ) + + +def batch_create_chunk( + *, + corpus_id: str, + document_id: str, + texts: List[str], + metadatas: Optional[List[Dict[str, Any]]] = None, + client: genai.RetrieverServiceClient, +) -> List[genai.Chunk]: + if metadatas is None: + metadatas = [{} for _ in texts] + if len(texts) != len(metadatas): + raise ValueError( + f"metadatas's length {len(metadatas)} " + f"and texts's length {len(texts)} are mismatched" + ) + + doc_name = str(EntityName(corpus_id=corpus_id, document_id=document_id)) + + created_chunks: List[genai.Chunk] = [] + + batch_request = genai.BatchCreateChunksRequest( + parent=doc_name, + requests=[], + ) + for text, metadata in zip(texts, metadatas): + batch_request.requests.append( + genai.CreateChunkRequest( + parent=doc_name, + chunk=genai.Chunk( + data=genai.ChunkData(string_value=text), + custom_metadata=_convert_to_metadata(metadata), + ), + ) + ) + + if len(batch_request.requests) >= _MAX_REQUEST_PER_CHUNK: + response = client.batch_create_chunks(batch_request) + created_chunks.extend(list(response.chunks)) + # Prepare a new batch for next round. + batch_request = genai.BatchCreateChunksRequest( + parent=doc_name, + requests=[], + ) + + # Process left over. + if len(batch_request.requests) > 0: + response = client.batch_create_chunks(batch_request) + created_chunks.extend(list(response.chunks)) + + return created_chunks + + +def delete_chunk( + *, + corpus_id: str, + document_id: str, + chunk_id: str, + client: genai.RetrieverServiceClient, +) -> None: + client.delete_chunk( + genai.DeleteChunkRequest( + name=str( + EntityName( + corpus_id=corpus_id, document_id=document_id, chunk_id=chunk_id + ) + ) + ) + ) + + +def query_corpus( + *, + corpus_id: str, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + client: genai.RetrieverServiceClient, +) -> List[genai.RelevantChunk]: + response = client.query_corpus( + genai.QueryCorpusRequest( + name=str(EntityName(corpus_id=corpus_id)), + query=query, + metadata_filters=_convert_filter(filter), + results_count=k, + ) + ) + return list(response.relevant_chunks) + + +def query_document( + *, + corpus_id: str, + document_id: str, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + client: genai.RetrieverServiceClient, +) -> List[genai.RelevantChunk]: + response = client.query_document( + genai.QueryDocumentRequest( + name=str(EntityName(corpus_id=corpus_id, document_id=document_id)), + query=query, + metadata_filters=_convert_filter(filter), + results_count=k, + ) + ) + return list(response.relevant_chunks) + + +@dataclass +class Passage: + text: str + id: str + + +@dataclass +class GroundedAnswer: + answer: str + attributed_passages: List[Passage] + answerable_probability: Optional[float] + + +@dataclass +class GenerateAnswerError(Exception): + finish_reason: genai.Candidate.FinishReason + finish_message: str + safety_ratings: MutableSequence[genai.SafetyRating] + + def __str__(self) -> str: + return ( + f"finish_reason: {self.finish_reason} " + f"finish_message: {self.finish_message} " + f"safety ratings: {self.safety_ratings}" + ) + + +def generate_answer( + *, + prompt: str, + passages: List[str], + answer_style: int = genai.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE, + safety_settings: List[genai.SafetySetting] = [], + temperature: Optional[float] = None, + client: genai.GenerativeServiceClient, +) -> GroundedAnswer: + # TODO: Consider passing in the corpus ID instead of the actual + # passages. + response = client.generate_answer( + genai.GenerateAnswerRequest( + contents=[ + genai.Content(parts=[genai.Part(text=prompt)]), + ], + model=_DEFAULT_GENERATE_SERVICE_MODEL, + answer_style=answer_style, + safety_settings=safety_settings, + temperature=temperature, + inline_passages=genai.GroundingPassages( + passages=[ + genai.GroundingPassage( + # IDs here takes alphanumeric only. No dashes allowed. + id=str(index), + content=genai.Content(parts=[genai.Part(text=chunk)]), + ) + for index, chunk in enumerate(passages) + ] + ), + ) + ) + + if response.answer.finish_reason != genai.Candidate.FinishReason.STOP: + finish_message = _get_finish_message(response.answer) + raise GenerateAnswerError( + finish_reason=response.answer.finish_reason, + finish_message=finish_message, + safety_ratings=response.answer.safety_ratings, + ) + + assert len(response.answer.content.parts) == 1 + return GroundedAnswer( + answer=response.answer.content.parts[0].text, + attributed_passages=[ + Passage( + text=passage.content.parts[0].text, + id=passage.source_id.grounding_passage.passage_id, + ) + for passage in response.answer.grounding_attributions + if len(passage.content.parts) > 0 + ], + answerable_probability=response.answerable_probability, + ) + + +# TODO: Use candidate.finish_message when that field is launched. +# For now, we derive this message from other existing fields. +def _get_finish_message(candidate: genai.Candidate) -> str: + finish_messages: Dict[int, str] = { + genai.Candidate.FinishReason.MAX_TOKENS: "Maximum token in context window reached", # noqa: E501 + genai.Candidate.FinishReason.SAFETY: "Blocked because of safety", + genai.Candidate.FinishReason.RECITATION: "Blocked because of recitation", + } + + finish_reason = candidate.finish_reason + if finish_reason not in finish_messages: + return "Unexpected generation error" + + return finish_messages[finish_reason] + + +def _convert_to_metadata(metadata: Dict[str, Any]) -> List[genai.CustomMetadata]: + cs: List[genai.CustomMetadata] = [] + for key, value in metadata.items(): + if isinstance(value, str): + c = genai.CustomMetadata(key=key, string_value=value) + elif isinstance(value, (float, int)): + c = genai.CustomMetadata(key=key, numeric_value=value) + else: + raise ValueError(f"Metadata value {value} is not supported") + + cs.append(c) + return cs + + +def _convert_filter(fs: Optional[Dict[str, Any]]) -> List[genai.MetadataFilter]: + if fs is None: + return [] + assert isinstance(fs, dict) + + filters: List[genai.MetadataFilter] = [] + for key, value in fs.items(): + if isinstance(value, str): + condition = genai.Condition( + operation=genai.Condition.Operator.EQUAL, string_value=value + ) + elif isinstance(value, (float, int)): + condition = genai.Condition( + operation=genai.Condition.Operator.EQUAL, numeric_value=value + ) + else: + raise ValueError(f"Filter value {value} is not supported") + + filters.append(genai.MetadataFilter(key=key, conditions=[condition])) + + return filters diff --git a/libs/genai/langchain_google_genai/genai_aqa.py b/libs/genai/langchain_google_genai/genai_aqa.py new file mode 100644 index 00000000..c339f4d0 --- /dev/null +++ b/libs/genai/langchain_google_genai/genai_aqa.py @@ -0,0 +1,134 @@ +"""Google GenerativeAI Attributed Question and Answering (AQA) service. + +The GenAI Semantic AQA API is a managed end to end service that allows +developers to create responses grounded on specified passages based on +a user query. For more information visit: +https://developers.generativeai.google/guide +""" + +from typing import Any, List, Optional + +import google.ai.generativelanguage as genai +from langchain_core.pydantic_v1 import BaseModel, PrivateAttr +from langchain_core.runnables import RunnableSerializable +from langchain_core.runnables.config import RunnableConfig + +from . import _genai_extension as genaix + + +class AqaInput(BaseModel): + """Input to `GenAIAqa.invoke`. + + Attributes: + prompt: The user's inquiry. + source_passages: A list of passage that the LLM should use only to + answer the user's inquiry. + """ + + prompt: str + source_passages: List[str] + + +class AqaOutput(BaseModel): + """Output from `GenAIAqa.invoke`. + + Attributes: + answer: The answer to the user's inquiry. + attributed_passages: A list of passages that the LLM used to construct + the answer. + answerable_probability: The probability of the question being answered + from the provided passages. + """ + + answer: str + attributed_passages: List[str] + answerable_probability: float + + +class _AqaModel(BaseModel): + """Wrapper for Google's internal AQA model.""" + + _client: genai.GenerativeServiceClient = PrivateAttr() + _answer_style: int = PrivateAttr() + _safety_settings: List[genai.SafetySetting] = PrivateAttr() + _temperature: Optional[float] = PrivateAttr() + + def __init__( + self, + answer_style: int = genai.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE, + safety_settings: List[genai.SafetySetting] = [], + temperature: Optional[float] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self._client = genaix.build_generative_service() + self._answer_style = answer_style + self._safety_settings = safety_settings + self._temperature = temperature + + def generate_answer( + self, + prompt: str, + passages: List[str], + ) -> genaix.GroundedAnswer: + return genaix.generate_answer( + prompt=prompt, + passages=passages, + client=self._client, + answer_style=self._answer_style, + safety_settings=self._safety_settings, + temperature=self._temperature, + ) + + +class GenAIAqa(RunnableSerializable[AqaInput, AqaOutput]): + """Google's Attributed Question and Answering service. + + Given a user's query and a list of passages, Google's server will return + a response that is grounded to the provided list of passages. It will not + base the response on parametric memory. + + Attributes: + answer_style: keyword-only argument. See + `google.ai.generativelanguage.AnswerStyle` for details. + """ + + # Actual type is .aqa_model.AqaModel. + _client: _AqaModel = PrivateAttr() + + # Actual type is genai.AnswerStyle. + # 1 = ABSTRACTIVE. + # Cannot use the actual type here because user may not have + # google.generativeai installed. + answer_style: int = 1 + + def __init__(self, **kwargs: Any) -> None: + """Construct a Google Generative AI AQA model. + + All arguments are optional. + + Args: + answer_style: See + `google.ai.generativelanguage.GenerateAnswerRequest.AnswerStyle`. + safety_settings: See `google.ai.generativelanguage.SafetySetting`. + temperature: 0.0 to 1.0. + """ + super().__init__(**kwargs) + self._client = _AqaModel(**kwargs) + + def invoke( + self, input: AqaInput, config: Optional[RunnableConfig] = None + ) -> AqaOutput: + """Generates a grounded response using the provided passages.""" + + response = self._client.generate_answer( + prompt=input.prompt, passages=input.source_passages + ) + + return AqaOutput( + answer=response.answer, + attributed_passages=[ + passage.text for passage in response.attributed_passages + ], + answerable_probability=response.answerable_probability or 0.0, + ) diff --git a/libs/genai/langchain_google_genai/google_vector_store.py b/libs/genai/langchain_google_genai/google_vector_store.py new file mode 100644 index 00000000..79c75d15 --- /dev/null +++ b/libs/genai/langchain_google_genai/google_vector_store.py @@ -0,0 +1,493 @@ +"""Google Generative AI Vector Store. + +The GenAI Semantic Retriever API is a managed end-to-end service that allows +developers to create a corpus of documents to perform semantic search on +related passages given a user query. For more information visit: +https://developers.generativeai.google/guide +""" + +import asyncio +from functools import partial +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, +) + +import google.ai.generativelanguage as genai +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, PrivateAttr +from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough +from langchain_core.vectorstores import VectorStore + +from . import _genai_extension as genaix +from .genai_aqa import ( + AqaInput, + AqaOutput, + GenAIAqa, +) + + +class ServerSideEmbedding(Embeddings): + """Do nothing embedding model where the embedding is done by the server.""" + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return [[] for _ in texts] + + def embed_query(self, text: str) -> List[float]: + return [] + + +class DoesNotExistsException(Exception): + def __init__(self, *, corpus_id: str, document_id: Optional[str] = None) -> None: + if document_id is None: + message = f"No such corpus {corpus_id}" + else: + message = f"No such document {document_id} under corpus {corpus_id}" + super().__init__(message) + + +class _SemanticRetriever(BaseModel): + """Wrapper class to Google's internal semantric retriever service.""" + + name: genaix.EntityName + _client: genai.RetrieverServiceClient = PrivateAttr() + + def __init__(self, *, client: genai.RetrieverServiceClient, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._client = client + + @classmethod + def from_ids( + cls, corpus_id: str, document_id: Optional[str] + ) -> "_SemanticRetriever": + name = genaix.EntityName(corpus_id=corpus_id, document_id=document_id) + client = genaix.build_semantic_retriever() + + # Check the entity exists on Google server. + if name.is_corpus(): + if genaix.get_corpus(corpus_id=corpus_id, client=client) is None: + raise DoesNotExistsException(corpus_id=corpus_id) + elif name.is_document(): + assert document_id is not None + if ( + genaix.get_document( + corpus_id=corpus_id, document_id=document_id, client=client + ) + is None + ): + raise DoesNotExistsException( + corpus_id=corpus_id, document_id=document_id + ) + + return cls(name=name, client=client) + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[Dict[str, Any]]] = None, + document_id: Optional[str] = None, + ) -> List[str]: + if self.name.document_id is None and document_id is None: + raise NotImplementedError( + "Adding texts to a corpus directly is not supported. " + "Please provide a document ID under the corpus first. " + "Then add the texts to the document." + ) + if ( + self.name.document_id is not None + and document_id is not None + and self.name.document_id != document_id + ): + raise NotImplementedError( + f"Parameter `document_id` {document_id} does not match the " + f"vector store's `document_id` {self.name.document_id}" + ) + assert self.name.document_id or document_id is not None + new_document_id = self.name.document_id or document_id or "" + + texts = list(texts) + if metadatas is None: + metadatas = [{} for _ in texts] + if len(texts) != len(metadatas): + raise ValueError( + f"metadatas's length {len(metadatas)} and " + f"texts's length {len(texts)} are mismatched" + ) + + chunks = genaix.batch_create_chunk( + corpus_id=self.name.corpus_id, + document_id=new_document_id, + texts=texts, + metadatas=metadatas, + client=self._client, + ) + + return [chunk.name for chunk in chunks if chunk.name] + + def similarity_search( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + ) -> List[Tuple[str, float]]: + if self.name.is_corpus(): + relevant_chunks = genaix.query_corpus( + corpus_id=self.name.corpus_id, + query=query, + k=k, + filter=filter, + client=self._client, + ) + else: + assert self.name.is_document() + assert self.name.document_id is not None + relevant_chunks = genaix.query_document( + corpus_id=self.name.corpus_id, + document_id=self.name.document_id, + query=query, + k=k, + filter=filter, + client=self._client, + ) + + return [ + (chunk.chunk.data.string_value, chunk.chunk_relevance_score) + for chunk in relevant_chunks + ] + + def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: + for id in ids or []: + name = genaix.EntityName.from_str(id) + _delete_chunk( + corpus_id=name.corpus_id, + document_id=name.document_id, + chunk_id=name.chunk_id, + client=self._client, + ) + return True + + +def _delete_chunk( + *, + corpus_id: str, + document_id: Optional[str], + chunk_id: Optional[str], + client: genai.RetrieverServiceClient, +) -> None: + if chunk_id is not None: + if document_id is None: + raise ValueError(f"Chunk {chunk_id} requires a document ID") + genaix.delete_chunk( + corpus_id=corpus_id, + document_id=document_id, + chunk_id=chunk_id, + client=client, + ) + elif document_id is not None: + genaix.delete_document( + corpus_id=corpus_id, document_id=document_id, client=client + ) + else: + genaix.delete_corpus(corpus_id=corpus_id, client=client) + + +class GoogleVectorStore(VectorStore): + """Google GenerativeAI Vector Store. + + Currently, it computes the embedding vectors on the server side. + + Example: Add texts to an existing corpus. + + store = GoogleVectorStore(corpus_id="123") + store.add_documents(documents, document_id="456") + + Example: Create a new corpus. + + store = GoogleVectorStore.create_corpus( + corpus_id="123", display_name="My Google corpus") + + Example: Query the corpus for relevant passages. + + store.as_retriever() \ + .get_relevant_documents("Who caught the gingerbread man?") + + Example: Ask the corpus for grounded responses! + + aqa = store.as_aqa() + response = aqa.invoke("Who caught the gingerbread man?") + print(response.answer) + print(response.attributed_passages) + print(response.answerability_probability) + + You can also operate at Google's Document level. + + Example: Add texts to an existing Google Vector Store Document. + + doc_store = GoogleVectorStore(corpus_id="123", document_id="456") + doc_store.add_documents(documents) + + Example: Create a new Google Vector Store Document. + + doc_store = GoogleVectorStore.create_document( + corpus_id="123", document_id="456", display_name="My Google document") + + Example: Query the Google document. + + doc_store.as_retriever() \ + .get_relevant_documents("Who caught the gingerbread man?") + + For more details, see the class's methods. + """ + + _retriever: _SemanticRetriever + + def __init__( + self, *, corpus_id: str, document_id: Optional[str] = None, **kwargs: Any + ): + """Returns an existing Google Semantic Retriever corpus or document. + + If just the corpus ID is provided, the vector store operates over all + documents within that corpus. + + If the document ID is provided, the vector store operates over just that + document. + + Raises: + DoesNotExistsException if the IDs do not match to anything on Google + server. In this case, consider using `create_corpus` or + `create_document` to create one. + """ + super().__init__(**kwargs) + self._retriever = _SemanticRetriever.from_ids(corpus_id, document_id) + + @classmethod + def create_corpus( + cls, + corpus_id: Optional[str] = None, + display_name: Optional[str] = None, + ) -> "GoogleVectorStore": + """Create a Google Semantic Retriever corpus. + + Args: + corpus_id: The ID to use to create the new corpus. If not provided, + Google server will provide one. + display_name: The title of the new corpus. If not provided, Google + server will provide one. + + Returns: + An instance of vector store that points to the newly created corpus. + """ + client = genaix.build_semantic_retriever() + corpus = genaix.create_corpus( + corpus_id=corpus_id, display_name=display_name, client=client + ) + + n = genaix.EntityName.from_str(corpus.name) + return cls(corpus_id=n.corpus_id) + + @classmethod + def create_document( + cls, + corpus_id: str, + document_id: Optional[str] = None, + display_name: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> "GoogleVectorStore": + """Create a Google Semantic Retriever document. + + Args: + corpus_id: ID of an existing corpus. + document_id: The ID to use to create the new Google Semantic + Retriever document. If not provided, Google server will provide + one. + display_name: The title of the new document. If not provided, Google + server will provide one. + + Returns: + An instance of vector store that points to the newly created + document. + """ + client = genaix.build_semantic_retriever() + document = genaix.create_document( + corpus_id=corpus_id, + document_id=document_id, + display_name=display_name, + metadata=metadata, + client=client, + ) + + assert document.name is not None + d = genaix.EntityName.from_str(document.name) + return cls(corpus_id=d.corpus_id, document_id=d.document_id) + + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Optional[Embeddings] = None, + metadatas: Optional[List[dict[str, Any]]] = None, + *, + corpus_id: Optional[str] = None, # str required + document_id: Optional[str] = None, # str required + **kwargs: Any, + ) -> "GoogleVectorStore": + """Returns a vector store of an existing document with the specified text. + + Args: + corpus_id: REQUIRED. Must be an existing corpus. + document_id: REQUIRED. Must be an existing document. + texts: Texts to be loaded into the vector store. + + Returns: + A vector store pointing to the specified Google Semantic Retriever + Document. + + Raises: + DoesNotExistsException if the IDs do not match to anything at + Google server. + """ + if corpus_id is None or document_id is None: + raise NotImplementedError( + "Must provide an existing corpus ID and document ID" + ) + + doc_store = cls(corpus_id=corpus_id, document_id=document_id, **kwargs) + doc_store.add_texts(texts, metadatas) + + return doc_store + + @property + def name(self) -> str: + """Returns the name of the Google entity. + + You shouldn't need to care about this unless you want to access your + corpus or document via Google Generative AI API. + """ + return str(self._retriever.name) + + @property + def corpus_id(self) -> str: + """Returns the corpus ID managed by this vector store.""" + return self._retriever.name.corpus_id + + @property + def document_id(self) -> Optional[str]: + """Returns the document ID managed by this vector store.""" + return self._retriever.name.document_id + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[Dict[str, Any]]] = None, + *, + document_id: Optional[str] = None, + **kwargs: Any, + ) -> List[str]: + """Add texts to the vector store. + + If the vector store points to a corpus (instead of a document), you must + also provide a `document_id`. + + Returns: + Chunk's names created on Google servers. + """ + return self._retriever.add_texts(texts, metadatas, document_id) + + def similarity_search( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> List[Document]: + """Search the vector store for relevant texts.""" + return [ + document + for document, _ in self.similarity_search_with_score( + query, k, filter, **kwargs + ) + ] + + def similarity_search_with_score( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Run similarity search with distance.""" + return [ + (Document(page_content=text), score) + for text, score in self._retriever.similarity_search(query, k, filter) + ] + + def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: + """Delete chunnks. + + Note that the "ids" are not corpus ID or document ID. Rather, these + are the entity names returned by `add_texts`. + + Returns: + True if successful. Otherwise, you should get an exception anyway. + """ + return self._retriever.delete(ids) + + async def adelete( + self, ids: Optional[List[str]] = None, **kwargs: Any + ) -> Optional[bool]: + return await asyncio.get_running_loop().run_in_executor( + None, partial(self.delete, **kwargs), ids + ) + + def _select_relevance_score_fn(self) -> Callable[[float], float]: + """ + TODO: Check with the team about this! + The underlying vector store already returns a "score proper", + i.e. one in [0, 1] where higher means more *similar*. + """ + return lambda score: score + + def as_aqa(self, **kwargs: Any) -> Runnable[str, AqaOutput]: + """Construct a Google Generative AI AQA engine. + + All arguments are optional. + + Args: + answer_style: See + `google.ai.generativelanguage.GenerateAnswerRequest.AnswerStyle`. + safety_settings: See `google.ai.generativelanguage.SafetySetting`. + temperature: 0.0 to 1.0. + """ + return ( + RunnablePassthrough[str]() + | { + "prompt": RunnablePassthrough(), + "passages": self.as_retriever(), + } + | RunnableLambda(_toAqaInput) + | GenAIAqa(**kwargs) + ) + + +def _toAqaInput(input: Dict[str, Any]) -> AqaInput: + prompt = input["prompt"] + assert isinstance(prompt, str) + + passages = input["passages"] + assert isinstance(passages, list) + + source_passages: List[str] = [] + for passage in passages: + assert isinstance(passage, Document) + source_passages.append(passage.page_content) + + return AqaInput( + prompt=prompt, + source_passages=source_passages, + ) diff --git a/libs/genai/poetry.lock b/libs/genai/poetry.lock index 0436b37f..c07d99f5 100644 --- a/libs/genai/poetry.lock +++ b/libs/genai/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -985,7 +985,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -1184,6 +1183,17 @@ files = [ {file = "types_Pillow-10.2.0.20240213-py3-none-any.whl", hash = "sha256:062c5a0f20301a30f2df4db583f15b3c2a1283a12518d1f9d81396154e12c1af"}, ] +[[package]] +name = "types-protobuf" +version = "4.24.0.20240302" +description = "Typing stubs for protobuf" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-protobuf-4.24.0.20240302.tar.gz", hash = "sha256:f22c00cc0cea9722e71e14d389bba429af9e35a74a949719c167203a5abbe2e4"}, + {file = "types_protobuf-4.24.0.20240302-py3-none-any.whl", hash = "sha256:5c607990f50f14606c2edaf379f8acc7418fef1451b227aa3c6a8a2cbc6ff14a"}, +] + [[package]] name = "types-requests" version = "2.31.0.20240125" @@ -1273,4 +1283,4 @@ images = ["pillow"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "c454ed507b0fb84e27b6e4d4838f0c1490246f7c5fb196e7a4dd7fd1bad3d731" +content-hash = "42b8221a821ed44e8c98e12c6a2279d2aa3e284415cc0417165eea57bd9bc830" diff --git a/libs/genai/pyproject.toml b/libs/genai/pyproject.toml index 6a152db4..0e24c840 100644 --- a/libs/genai/pyproject.toml +++ b/libs/genai/pyproject.toml @@ -55,6 +55,7 @@ mypy = "^0.991" types-requests = "^2.28.11.5" types-google-cloud-ndb = "^2.2.0.1" types-pillow = "^10.1.0.2" +types-protobuf = "^4.24.0.20240302" [tool.poetry.group.dev] optional = true diff --git a/libs/genai/tests/unit_tests/test_genai_aqa.py b/libs/genai/tests/unit_tests/test_genai_aqa.py new file mode 100644 index 00000000..c8cd521b --- /dev/null +++ b/libs/genai/tests/unit_tests/test_genai_aqa.py @@ -0,0 +1,95 @@ +from unittest.mock import MagicMock, patch + +import google.ai.generativelanguage as genai +import pytest + +from langchain_google_genai import ( + AqaInput, + GenAIAqa, +) +from langchain_google_genai import _genai_extension as genaix + +# Make sure the tests do not hit actual production servers. +genaix.set_config( + genaix.Config( + api_endpoint="No-such-endpoint-to-prevent-hitting-real-backend", + testing=True, + ) +) + + +@pytest.mark.requires("google.ai.generativelanguage") +def test_it_can_be_constructed() -> None: + GenAIAqa() + + +@pytest.mark.requires("google.ai.generativelanguage") +@patch("google.ai.generativelanguage.GenerativeServiceClient.generate_answer") +def test_invoke(mock_generate_answer: MagicMock) -> None: + # Arrange + mock_generate_answer.return_value = genai.GenerateAnswerResponse( + answer=genai.Candidate( + content=genai.Content(parts=[genai.Part(text="42")]), + grounding_attributions=[ + genai.GroundingAttribution( + content=genai.Content( + parts=[genai.Part(text="Meaning of life is 42.")] + ), + source_id=genai.AttributionSourceId( + grounding_passage=genai.AttributionSourceId.GroundingPassageId( + passage_id="corpora/123/documents/456/chunks/789", + part_index=0, + ) + ), + ), + ], + finish_reason=genai.Candidate.FinishReason.STOP, + ), + answerable_probability=0.7, + ) + + # Act + aqa = GenAIAqa( + temperature=0.5, + answer_style=genai.GenerateAnswerRequest.AnswerStyle.EXTRACTIVE, + safety_settings=[ + genai.SafetySetting( + category=genai.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold=genai.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + ) + ], + ) + output = aqa.invoke( + input=AqaInput( + prompt="What is the meaning of life?", + source_passages=["It's 42."], + ) + ) + + # Assert + assert output.answer == "42" + assert output.attributed_passages == ["Meaning of life is 42."] + assert output.answerable_probability == pytest.approx(0.7) + + assert mock_generate_answer.call_count == 1 + request = mock_generate_answer.call_args.args[0] + assert request.contents[0].parts[0].text == "What is the meaning of life?" + + assert request.answer_style == genai.GenerateAnswerRequest.AnswerStyle.EXTRACTIVE + + assert len(request.safety_settings) == 1 + assert ( + request.safety_settings[0].category + == genai.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT + ) + assert ( + request.safety_settings[0].threshold + == genai.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE + ) + + assert request.temperature == 0.5 + + passages = request.inline_passages.passages + assert len(passages) == 1 + passage = passages[0] + assert passage.content.parts[0].text == "It's 42." diff --git a/libs/genai/tests/unit_tests/test_google_vector_store.py b/libs/genai/tests/unit_tests/test_google_vector_store.py new file mode 100644 index 00000000..ba6e183d --- /dev/null +++ b/libs/genai/tests/unit_tests/test_google_vector_store.py @@ -0,0 +1,440 @@ +from unittest.mock import MagicMock, patch + +import google.ai.generativelanguage as genai +import pytest +from langchain_core.documents import Document + +from langchain_google_genai import GoogleVectorStore +from langchain_google_genai import _genai_extension as genaix + +# Make sure the tests do not hit actual production servers. +genaix.set_config( + genaix.Config( + api_endpoint="No-such-endpoint-to-prevent-hitting-real-backend", + testing=True, + ) +) + + +@pytest.mark.requires("google.ai.generativelanguage") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_corpus") +def test_load_corpus(mock_get_corpus: MagicMock) -> None: + # Arrange + mock_get_corpus.return_value = genai.Corpus(name="corpora/123") + + # Act + store = GoogleVectorStore(corpus_id="123") + + # Assert + assert store.name == "corpora/123" + assert store.corpus_id == "123" + assert store.document_id is None + + +@pytest.mark.requires("google.ai.generativelanguage") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_document") +def test_load_document(mock_get_document: MagicMock) -> None: + # Arrange + mock_get_document.return_value = genai.Document(name="corpora/123/documents/456") + + # Act + store = GoogleVectorStore(corpus_id="123", document_id="456") + + # Assert + assert store.name == "corpora/123/documents/456" + assert store.corpus_id == "123" + assert store.document_id == "456" + + +@pytest.mark.requires("google.ai.generativelanguage") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_corpus") +@patch("google.ai.generativelanguage.RetrieverServiceClient.create_corpus") +def test_create_corpus( + mock_create_corpus: MagicMock, mock_get_corpus: MagicMock +) -> None: + # Arrange + fake_corpus = genai.Corpus(name="corpora/123", display_name="My Corpus") + mock_create_corpus.return_value = fake_corpus + mock_get_corpus.return_value = fake_corpus + + # Act + store = GoogleVectorStore.create_corpus(display_name="My Corpus") + + # Assert + assert store.name == "corpora/123" + assert store.corpus_id == "123" + assert store.document_id is None + + assert mock_create_corpus.call_count == 1 + + create_request = mock_create_corpus.call_args.args[0] + assert create_request.corpus.name == "" + assert create_request.corpus.display_name == "My Corpus" + + get_request = mock_get_corpus.call_args.args[0] + assert get_request.name == "corpora/123" + + +@pytest.mark.requires("google.ai.generativelanguage") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_document") +@patch("google.ai.generativelanguage.RetrieverServiceClient.create_document") +def test_create_document( + mock_create_document: MagicMock, mock_get_document: MagicMock +) -> None: + # Arrange + fake_document = genai.Document( + name="corpora/123/documents/456", display_name="My Document" + ) + mock_create_document.return_value = fake_document + mock_get_document.return_value = fake_document + + # Act + store = GoogleVectorStore.create_document( + corpus_id="123", display_name="My Document" + ) + + # Assert + assert store.name == "corpora/123/documents/456" + assert store.corpus_id == "123" + assert store.document_id == "456" + + assert mock_create_document.call_count == 1 + + create_request = mock_create_document.call_args.args[0] + assert create_request.parent == "corpora/123" + assert create_request.document.name == "" + assert create_request.document.display_name == "My Document" + + get_request = mock_get_document.call_args.args[0] + assert get_request.name == "corpora/123/documents/456" + + +@pytest.mark.requires("google.ai.generativelanguage") +@patch("google.ai.generativelanguage.RetrieverServiceClient.batch_create_chunks") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_document") +def test_from_texts( + mock_get_document: MagicMock, + mock_batch_create_chunks: MagicMock, +) -> None: + # Arrange + # We will use a max requests per batch to be 2. + # Then, we send 3 requests. + # We expect to have 2 batches where the last batch has only 1 request. + genaix._MAX_REQUEST_PER_CHUNK = 2 + mock_get_document.return_value = genai.Document( + name="corpora/123/documents/456", display_name="My Document" + ) + mock_batch_create_chunks.side_effect = [ + genai.BatchCreateChunksResponse( + chunks=[ + genai.Chunk(name="corpora/123/documents/456/chunks/777"), + genai.Chunk(name="corpora/123/documents/456/chunks/888"), + ] + ), + genai.BatchCreateChunksResponse( + chunks=[ + genai.Chunk(name="corpora/123/documents/456/chunks/999"), + ] + ), + ] + + # Act + store = GoogleVectorStore.from_texts( + texts=[ + "Hello my baby", + "Hello my honey", + "Hello my ragtime gal", + ], + metadatas=[ + {"position": 100}, + {"position": 200}, + {"position": 300}, + ], + corpus_id="123", + document_id="456", + ) + + # Assert + assert store.corpus_id == "123" + assert store.document_id == "456" + + assert mock_batch_create_chunks.call_count == 2 + + first_batch_request = mock_batch_create_chunks.call_args_list[0].args[0] + assert first_batch_request == genai.BatchCreateChunksRequest( + parent="corpora/123/documents/456", + requests=[ + genai.CreateChunkRequest( + parent="corpora/123/documents/456", + chunk=genai.Chunk( + data=genai.ChunkData(string_value="Hello my baby"), + custom_metadata=[ + genai.CustomMetadata( + key="position", + numeric_value=100, + ), + ], + ), + ), + genai.CreateChunkRequest( + parent="corpora/123/documents/456", + chunk=genai.Chunk( + data=genai.ChunkData(string_value="Hello my honey"), + custom_metadata=[ + genai.CustomMetadata( + key="position", + numeric_value=200, + ), + ], + ), + ), + ], + ) + + second_batch_request = mock_batch_create_chunks.call_args_list[1].args[0] + assert second_batch_request == genai.BatchCreateChunksRequest( + parent="corpora/123/documents/456", + requests=[ + genai.CreateChunkRequest( + parent="corpora/123/documents/456", + chunk=genai.Chunk( + data=genai.ChunkData(string_value="Hello my ragtime gal"), + custom_metadata=[ + genai.CustomMetadata( + key="position", + numeric_value=300, + ), + ], + ), + ), + ], + ) + + +@pytest.mark.requires("google.ai.generativelanguage") +@patch("google.ai.generativelanguage.RetrieverServiceClient.query_corpus") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_corpus") +def test_similarity_search_with_score_on_corpus( + mock_get_corpus: MagicMock, + mock_query_corpus: MagicMock, +) -> None: + # Arrange + mock_get_corpus.return_value = genai.Corpus( + name="corpora/123", display_name="My Corpus" + ) + mock_query_corpus.return_value = genai.QueryCorpusResponse( + relevant_chunks=[ + genai.RelevantChunk( + chunk=genai.Chunk( + name="corpora/123/documents/456/chunks/789", + data=genai.ChunkData(string_value="42"), + ), + chunk_relevance_score=0.9, + ) + ] + ) + + # Act + store = GoogleVectorStore(corpus_id="123") + documents_with_scores = store.similarity_search_with_score( + query="What is the meaning of life?", + k=3, + filter={ + "author": "Arthur Schopenhauer", + "year": 1818, + }, + ) + + # Assert + assert len(documents_with_scores) == 1 + document, relevant_score = documents_with_scores[0] + assert document == Document(page_content="42") + assert relevant_score == pytest.approx(0.9) + + assert mock_query_corpus.call_count == 1 + query_corpus_request = mock_query_corpus.call_args.args[0] + assert query_corpus_request == genai.QueryCorpusRequest( + name="corpora/123", + query="What is the meaning of life?", + metadata_filters=[ + genai.MetadataFilter( + key="author", + conditions=[ + genai.Condition( + operation=genai.Condition.Operator.EQUAL, + string_value="Arthur Schopenhauer", + ) + ], + ), + genai.MetadataFilter( + key="year", + conditions=[ + genai.Condition( + operation=genai.Condition.Operator.EQUAL, + numeric_value=1818, + ) + ], + ), + ], + results_count=3, + ) + + +@pytest.mark.requires("google.ai.generativelanguage") +@patch("google.ai.generativelanguage.RetrieverServiceClient.query_document") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_document") +def test_similarity_search_with_score_on_document( + mock_get_document: MagicMock, + mock_query_document: MagicMock, +) -> None: + # Arrange + mock_get_document.return_value = genai.Document( + name="corpora/123/documents/456", display_name="My Document" + ) + mock_query_document.return_value = genai.QueryCorpusResponse( + relevant_chunks=[ + genai.RelevantChunk( + chunk=genai.Chunk( + name="corpora/123/documents/456/chunks/789", + data=genai.ChunkData(string_value="42"), + ), + chunk_relevance_score=0.9, + ) + ] + ) + + # Act + store = GoogleVectorStore(corpus_id="123", document_id="456") + documents_with_scores = store.similarity_search_with_score( + query="What is the meaning of life?", + k=3, + filter={ + "author": "Arthur Schopenhauer", + "year": 1818, + }, + ) + + # Assert + assert len(documents_with_scores) == 1 + document, relevant_score = documents_with_scores[0] + assert document == Document(page_content="42") + assert relevant_score == pytest.approx(0.9) + + assert mock_query_document.call_count == 1 + query_document_request = mock_query_document.call_args.args[0] + assert query_document_request == genai.QueryDocumentRequest( + name="corpora/123/documents/456", + query="What is the meaning of life?", + metadata_filters=[ + genai.MetadataFilter( + key="author", + conditions=[ + genai.Condition( + operation=genai.Condition.Operator.EQUAL, + string_value="Arthur Schopenhauer", + ) + ], + ), + genai.MetadataFilter( + key="year", + conditions=[ + genai.Condition( + operation=genai.Condition.Operator.EQUAL, + numeric_value=1818, + ) + ], + ), + ], + results_count=3, + ) + + +@pytest.mark.requires("google.ai.generativelanguage") +@patch("google.ai.generativelanguage.RetrieverServiceClient.delete_chunk") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_corpus") +def test_delete( + mock_get_corpus: MagicMock, + mock_delete_chunk: MagicMock, +) -> None: + # Arrange + mock_get_corpus.return_value = genai.Corpus(name="corpora/123") + + # Act + store = GoogleVectorStore(corpus_id="123") + store.delete( + ids=[ + "corpora/123/documents/456/chunks/1001", + "corpora/123/documents/456/chunks/1002", + ] + ) + + # Assert + assert mock_delete_chunk.call_count == 2 + delete_chunk_requests = mock_delete_chunk.call_args_list + + delete_chunk_request_1 = delete_chunk_requests[0].args[0] + assert delete_chunk_request_1 == genai.DeleteChunkRequest( + name="corpora/123/documents/456/chunks/1001", + ) + + delete_chunk_request_2 = delete_chunk_requests[1].args[0] + assert delete_chunk_request_2 == genai.DeleteChunkRequest( + name="corpora/123/documents/456/chunks/1002", + ) + + +@pytest.mark.requires("google.ai.generativelanguage") +@patch("google.ai.generativelanguage.GenerativeServiceClient.generate_answer") +@patch("google.ai.generativelanguage.RetrieverServiceClient.query_corpus") +@patch("google.ai.generativelanguage.RetrieverServiceClient.get_corpus") +def test_aqa( + mock_get_corpus: MagicMock, + mock_query_corpus: MagicMock, + mock_generate_answer: MagicMock, +) -> None: + # Arrange + mock_get_corpus.return_value = genai.Corpus(name="corpora/123") + mock_query_corpus.return_value = genai.QueryCorpusResponse( + relevant_chunks=[ + genai.RelevantChunk( + chunk=genai.Chunk( + name="corpora/123/documents/456/chunks/789", + data=genai.ChunkData(string_value="42"), + ), + chunk_relevance_score=0.9, + ) + ] + ) + mock_generate_answer.return_value = genai.GenerateAnswerResponse( + answer=genai.Candidate( + content=genai.Content(parts=[genai.Part(text="42")]), + grounding_attributions=[ + genai.GroundingAttribution( + content=genai.Content( + parts=[genai.Part(text="Meaning of life is 42.")] + ), + source_id=genai.AttributionSourceId( + grounding_passage=genai.AttributionSourceId.GroundingPassageId( + passage_id="corpora/123/documents/456/chunks/789", + part_index=0, + ) + ), + ), + ], + finish_reason=genai.Candidate.FinishReason.STOP, + ), + answerable_probability=0.7, + ) + + # Act + store = GoogleVectorStore(corpus_id="123") + aqa = store.as_aqa(answer_style=genai.GenerateAnswerRequest.AnswerStyle.EXTRACTIVE) + response = aqa.invoke("What is the meaning of life?") + + # Assert + assert response.answer == "42" + assert response.attributed_passages == ["Meaning of life is 42."] + assert response.answerable_probability == pytest.approx(0.7) + + request = mock_generate_answer.call_args.args[0] + assert request.answer_style == genai.GenerateAnswerRequest.AnswerStyle.EXTRACTIVE diff --git a/libs/genai/tests/unit_tests/test_imports.py b/libs/genai/tests/unit_tests/test_imports.py index 8c90cb2b..c968ca0a 100644 --- a/libs/genai/tests/unit_tests/test_imports.py +++ b/libs/genai/tests/unit_tests/test_imports.py @@ -1,11 +1,17 @@ from langchain_google_genai import __all__ EXPECTED_ALL = [ + "AqaInput", + "AqaOutput", "ChatGoogleGenerativeAI", + "DoesNotExistsException", + "GenAIAqa", "GoogleGenerativeAIEmbeddings", "GoogleGenerativeAI", + "GoogleVectorStore", "HarmBlockThreshold", "HarmCategory", + "DoesNotExistsException", ]