From 6dcb7135300e6072d1566d9cb0546695047cc7c3 Mon Sep 17 00:00:00 2001 From: jack Date: Tue, 19 Mar 2024 16:13:05 -0500 Subject: [PATCH 1/4] Adding credentials to GoogleGenerativeAI and GoogleGenerativeAIEmbeddings --- .../langchain_google_genai/embeddings.py | 38 +++++++++++++------ libs/genai/langchain_google_genai/llms.py | 35 +++++++++++------ 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/libs/genai/langchain_google_genai/embeddings.py b/libs/genai/langchain_google_genai/embeddings.py index 5e61581e..5da34180 100644 --- a/libs/genai/langchain_google_genai/embeddings.py +++ b/libs/genai/langchain_google_genai/embeddings.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Any # TODO: remove ignore once the google package is published with types import google.generativeai as genai # type: ignore[import] @@ -43,6 +43,13 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings): description="The Google API key to use. If not provided, " "the GOOGLE_API_KEY environment variable will be used.", ) + credentials: Any = Field( + default=None, + exclude=True, + description="The default custom credentials (google.auth.credentials.Credentials) " + "to use when making API calls. If not provided, credentials will be ascertained from " + "the GOOGLE_API_KEY envvar" + ) client_options: Optional[Dict] = Field( None, description=( @@ -58,17 +65,24 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validates params and passes them to google-generativeai package.""" - google_api_key = get_from_dict_or_env( - values, "google_api_key", "GOOGLE_API_KEY" - ) - if isinstance(google_api_key, SecretStr): - google_api_key = google_api_key.get_secret_value() - - genai.configure( - api_key=google_api_key, - transport=values.get("transport"), - client_options=values.get("client_options"), - ) + if values.get("credentials"): + genai.configure( + credentials=values.get("credentials"), + transport=values.get("transport"), + client_options=values.get("client_options"), + ) + else: + google_api_key = get_from_dict_or_env( + values, "google_api_key", "GOOGLE_API_KEY" + ) + if isinstance(google_api_key, SecretStr): + google_api_key = google_api_key.get_secret_value() + + genai.configure( + api_key=google_api_key, + transport=values.get("transport"), + client_options=values.get("client_options"), + ) return values def _embed( diff --git a/libs/genai/langchain_google_genai/llms.py b/libs/genai/langchain_google_genai/llms.py index 9b483873..d0a6243d 100644 --- a/libs/genai/langchain_google_genai/llms.py +++ b/libs/genai/langchain_google_genai/llms.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Union import google.api_core +import google.cloud.aiplatform import google.generativeai as genai # type: ignore[import] from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -122,6 +123,10 @@ class _BaseGoogleGenerativeAI(BaseModel): ) """Model name to use.""" google_api_key: Optional[SecretStr] = None + credentials: Any = None + "The default custom credentials (google.auth.credentials.Credentials) to use " + "when making API calls. If not provided, credentials will be ascertained from " + "the GOOGLE_API_KEY envvar" temperature: float = 0.7 """Run inference with this temperature. Must by in the closed interval [0.0, 1.0].""" @@ -203,22 +208,28 @@ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validates params and passes them to google-generativeai package.""" - google_api_key = get_from_dict_or_env( - values, "google_api_key", "GOOGLE_API_KEY" - ) + if values.get("credentials"): + genai.configure( + credentials=values.get("credentials"), + transport=values.get("transport"), + client_options=values.get("client_options"), + ) + else: + google_api_key = get_from_dict_or_env( + values, "google_api_key", "GOOGLE_API_KEY" + ) + if isinstance(google_api_key, SecretStr): + google_api_key = google_api_key.get_secret_value() + genai.configure( + api_key=google_api_key, + transport=values.get("transport"), + client_options=values.get("client_options"), + ) + model_name = values["model"] safety_settings = values["safety_settings"] - if isinstance(google_api_key, SecretStr): - google_api_key = google_api_key.get_secret_value() - - genai.configure( - api_key=google_api_key, - transport=values.get("transport"), - client_options=values.get("client_options"), - ) - if safety_settings and ( not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI ): From 2fec60a9710391ba2fd4e8ffd3cb93f582cc15eb Mon Sep 17 00:00:00 2001 From: jack Date: Tue, 19 Mar 2024 18:11:09 -0500 Subject: [PATCH 2/4] Adding credentials to ChatGoogleGenerativeAI class --- .../langchain_google_genai/chat_models.py | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py index 4f0e9f0a..3abf688f 100644 --- a/libs/genai/langchain_google_genai/chat_models.py +++ b/libs/genai/langchain_google_genai/chat_models.py @@ -483,17 +483,24 @@ def is_lc_serializable(self) -> bool: @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validates params and passes them to google-generativeai package.""" - google_api_key = get_from_dict_or_env( - values, "google_api_key", "GOOGLE_API_KEY" - ) - if isinstance(google_api_key, SecretStr): - google_api_key = google_api_key.get_secret_value() + if values.get("credentials"): + genai.configure( + credentials=values.get("credentials"), + transport=values.get("transport"), + client_options=values.get("client_options"), + ) + else: + google_api_key = get_from_dict_or_env( + values, "google_api_key", "GOOGLE_API_KEY" + ) + if isinstance(google_api_key, SecretStr): + google_api_key = google_api_key.get_secret_value() - genai.configure( - api_key=google_api_key, - transport=values.get("transport"), - client_options=values.get("client_options"), - ) + genai.configure( + api_key=google_api_key, + transport=values.get("transport"), + client_options=values.get("client_options"), + ) if ( values.get("temperature") is not None and not 0 <= values["temperature"] <= 1 From bad13b63a4f637995e47164b0e6e259f6581c55f Mon Sep 17 00:00:00 2001 From: jack Date: Wed, 20 Mar 2024 21:27:07 -0500 Subject: [PATCH 3/4] Remove errant google.cloud.aiplatform dependency --- libs/genai/langchain_google_genai/llms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/genai/langchain_google_genai/llms.py b/libs/genai/langchain_google_genai/llms.py index d0a6243d..51b3bc3c 100644 --- a/libs/genai/langchain_google_genai/llms.py +++ b/libs/genai/langchain_google_genai/llms.py @@ -4,7 +4,6 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Union import google.api_core -import google.cloud.aiplatform import google.generativeai as genai # type: ignore[import] from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, From 6f9fa7e6003979d4fe3ab3a4f423378c1ded1820 Mon Sep 17 00:00:00 2001 From: jack Date: Wed, 20 Mar 2024 21:31:01 -0500 Subject: [PATCH 4/4] Embeddings formatting --- libs/genai/langchain_google_genai/embeddings.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/genai/langchain_google_genai/embeddings.py b/libs/genai/langchain_google_genai/embeddings.py index 5da34180..09142f24 100644 --- a/libs/genai/langchain_google_genai/embeddings.py +++ b/libs/genai/langchain_google_genai/embeddings.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Any +from typing import Any, Dict, List, Optional # TODO: remove ignore once the google package is published with types import google.generativeai as genai # type: ignore[import] @@ -46,9 +46,9 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings): credentials: Any = Field( default=None, exclude=True, - description="The default custom credentials (google.auth.credentials.Credentials) " - "to use when making API calls. If not provided, credentials will be ascertained from " - "the GOOGLE_API_KEY envvar" + description="The default custom credentials " + "(google.auth.credentials.Credentials) to use when making API calls. If not " + "provided, credentials will be ascertained from the GOOGLE_API_KEY envvar", ) client_options: Optional[Dict] = Field( None,