From d7bc26ab2ccb896614b74024b820bc32c9dd396e Mon Sep 17 00:00:00 2001 From: Jack Klika Date: Thu, 21 Mar 2024 14:14:55 -0500 Subject: [PATCH] Add credentials parameter to GoogleGenerativeAI, ChatGoogleGenerativeAI, and GoogleGenerativeAIEmbedded (#78) * Adding credentials to GoogleGenerativeAI and GoogleGenerativeAIEmbeddings --------- Co-authored-by: jack --- .../langchain_google_genai/chat_models.py | 27 ++++++++----- .../langchain_google_genai/embeddings.py | 38 +++++++++++++------ libs/genai/langchain_google_genai/llms.py | 34 +++++++++++------ 3 files changed, 65 insertions(+), 34 deletions(-) diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py index 4f0e9f0a..3abf688f 100644 --- a/libs/genai/langchain_google_genai/chat_models.py +++ b/libs/genai/langchain_google_genai/chat_models.py @@ -483,17 +483,24 @@ def is_lc_serializable(self) -> bool: @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validates params and passes them to google-generativeai package.""" - google_api_key = get_from_dict_or_env( - values, "google_api_key", "GOOGLE_API_KEY" - ) - if isinstance(google_api_key, SecretStr): - google_api_key = google_api_key.get_secret_value() + if values.get("credentials"): + genai.configure( + credentials=values.get("credentials"), + transport=values.get("transport"), + client_options=values.get("client_options"), + ) + else: + google_api_key = get_from_dict_or_env( + values, "google_api_key", "GOOGLE_API_KEY" + ) + if isinstance(google_api_key, SecretStr): + google_api_key = google_api_key.get_secret_value() - genai.configure( - api_key=google_api_key, - transport=values.get("transport"), - client_options=values.get("client_options"), - ) + genai.configure( + api_key=google_api_key, + transport=values.get("transport"), + client_options=values.get("client_options"), + ) if ( values.get("temperature") is not None and not 0 <= values["temperature"] <= 1 diff --git a/libs/genai/langchain_google_genai/embeddings.py b/libs/genai/langchain_google_genai/embeddings.py index 5e61581e..09142f24 100644 --- a/libs/genai/langchain_google_genai/embeddings.py +++ b/libs/genai/langchain_google_genai/embeddings.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional # TODO: remove ignore once the google package is published with types import google.generativeai as genai # type: ignore[import] @@ -43,6 +43,13 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings): description="The Google API key to use. If not provided, " "the GOOGLE_API_KEY environment variable will be used.", ) + credentials: Any = Field( + default=None, + exclude=True, + description="The default custom credentials " + "(google.auth.credentials.Credentials) to use when making API calls. If not " + "provided, credentials will be ascertained from the GOOGLE_API_KEY envvar", + ) client_options: Optional[Dict] = Field( None, description=( @@ -58,17 +65,24 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validates params and passes them to google-generativeai package.""" - google_api_key = get_from_dict_or_env( - values, "google_api_key", "GOOGLE_API_KEY" - ) - if isinstance(google_api_key, SecretStr): - google_api_key = google_api_key.get_secret_value() - - genai.configure( - api_key=google_api_key, - transport=values.get("transport"), - client_options=values.get("client_options"), - ) + if values.get("credentials"): + genai.configure( + credentials=values.get("credentials"), + transport=values.get("transport"), + client_options=values.get("client_options"), + ) + else: + google_api_key = get_from_dict_or_env( + values, "google_api_key", "GOOGLE_API_KEY" + ) + if isinstance(google_api_key, SecretStr): + google_api_key = google_api_key.get_secret_value() + + genai.configure( + api_key=google_api_key, + transport=values.get("transport"), + client_options=values.get("client_options"), + ) return values def _embed( diff --git a/libs/genai/langchain_google_genai/llms.py b/libs/genai/langchain_google_genai/llms.py index 9b483873..51b3bc3c 100644 --- a/libs/genai/langchain_google_genai/llms.py +++ b/libs/genai/langchain_google_genai/llms.py @@ -122,6 +122,10 @@ class _BaseGoogleGenerativeAI(BaseModel): ) """Model name to use.""" google_api_key: Optional[SecretStr] = None + credentials: Any = None + "The default custom credentials (google.auth.credentials.Credentials) to use " + "when making API calls. If not provided, credentials will be ascertained from " + "the GOOGLE_API_KEY envvar" temperature: float = 0.7 """Run inference with this temperature. Must by in the closed interval [0.0, 1.0].""" @@ -203,22 +207,28 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validates params and passes them to google-generativeai package.""" - google_api_key = get_from_dict_or_env( - values, "google_api_key", "GOOGLE_API_KEY" - ) + if values.get("credentials"): + genai.configure( + credentials=values.get("credentials"), + transport=values.get("transport"), + client_options=values.get("client_options"), + ) + else: + google_api_key = get_from_dict_or_env( + values, "google_api_key", "GOOGLE_API_KEY" + ) + if isinstance(google_api_key, SecretStr): + google_api_key = google_api_key.get_secret_value() + genai.configure( + api_key=google_api_key, + transport=values.get("transport"), + client_options=values.get("client_options"), + ) + model_name = values["model"] safety_settings = values["safety_settings"] - if isinstance(google_api_key, SecretStr): - google_api_key = google_api_key.get_secret_value() - - genai.configure( - api_key=google_api_key, - transport=values.get("transport"), - client_options=values.get("client_options"), - ) - if safety_settings and ( not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI ):