diff --git a/libs/community/langchain_community/chat_models/javelin_ai_gateway.py b/libs/community/langchain_community/chat_models/javelin_ai_gateway.py index 6b7001b62604a..906c5ffbbdf08 100644 --- a/libs/community/langchain_community/chat_models/javelin_ai_gateway.py +++ b/libs/community/langchain_community/chat_models/javelin_ai_gateway.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Mapping, Optional, cast +from typing import Any, Dict, List, Mapping, Optional from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -19,6 +19,7 @@ ChatResult, ) from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr +from langchain_core.utils import extract_secret_value logger = logging.getLogger(__name__) @@ -85,7 +86,7 @@ def __init__(self, **kwargs: Any): try: self.client = JavelinClient( base_url=self.gateway_uri, - api_key=cast(SecretStr, self.javelin_api_key).get_secret_value(), + api_key=extract_secret_value(self.javelin_api_key), ) except UnauthorizedError as e: raise ValueError("Javelin: Incorrect API Key.") from e @@ -94,7 +95,7 @@ def __init__(self, **kwargs: Any): def _default_params(self) -> Dict[str, Any]: params: Dict[str, Any] = { "gateway_uri": self.gateway_uri, - "javelin_api_key": cast(SecretStr, self.javelin_api_key).get_secret_value(), + "javelin_api_key": extract_secret_value(self.javelin_api_key), "route": self.route, **(self.params.dict() if self.params else {}), } diff --git a/libs/community/langchain_community/chat_models/jinachat.py b/libs/community/langchain_community/chat_models/jinachat.py index b234c1e01db92..21c8377a7d1ad 100644 --- a/libs/community/langchain_community/chat_models/jinachat.py +++ b/libs/community/langchain_community/chat_models/jinachat.py @@ -42,6 +42,7 @@ from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.utils import ( convert_to_secret_str, + extract_secret_value, get_from_dict_or_env, get_pydantic_field_names, ) @@ -397,8 +398,7 @@ async def _agenerate( def _invocation_params(self) -> Mapping[str, Any]: """Get the parameters used to invoke the model.""" jinachat_creds: Dict[str, Any] = { - "api_key": self.jinachat_api_key - and self.jinachat_api_key.get_secret_value(), + "api_key": extract_secret_value(self.jinachat_api_key), "api_base": "https://api.chat.jina.ai/v1", "model": "jinachat", } diff --git a/libs/community/langchain_community/embeddings/javelin_ai_gateway.py b/libs/community/langchain_community/embeddings/javelin_ai_gateway.py index c91a003291d6b..d6d7356bfbc94 100644 --- a/libs/community/langchain_community/embeddings/javelin_ai_gateway.py +++ b/libs/community/langchain_community/embeddings/javelin_ai_gateway.py @@ -3,7 +3,7 @@ from typing import Any, Iterator, List, Optional from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel, SecretStr def _chunk(texts: List[str], size: int) -> Iterator[List[str]]: @@ -38,7 +38,7 @@ class JavelinAIGatewayEmbeddings(Embeddings, BaseModel): gateway_uri: Optional[str] = None """The URI for the Javelin AI Gateway API.""" - javelin_api_key: Optional[str] = None + javelin_api_key: Optional[SecretStr] = None """The API key for the Javelin AI Gateway API.""" def __init__(self, **kwargs: Any): diff --git a/libs/community/langchain_community/embeddings/voyageai.py b/libs/community/langchain_community/embeddings/voyageai.py index 93109d45c65b6..d6848722b836b 100644 --- a/libs/community/langchain_community/embeddings/voyageai.py +++ b/libs/community/langchain_community/embeddings/voyageai.py @@ -10,13 +10,16 @@ Optional, Tuple, Union, - cast, ) import requests from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import ( + convert_to_secret_str, + extract_secret_value, + get_from_dict_or_env, +) from tenacity import ( before_sleep_log, retry, @@ -103,7 +106,7 @@ def validate_environment(cls, values: Dict) -> Dict: def _invocation_params( self, input: List[str], input_type: Optional[str] = None ) -> Dict: - api_key = cast(SecretStr, self.voyage_api_key).get_secret_value() + api_key = extract_secret_value(self.voyage_api_key) params = { "url": self.voyage_api_base, "headers": {"Authorization": f"Bearer {api_key}"}, diff --git a/libs/community/langchain_community/llms/ai21.py b/libs/community/langchain_community/llms/ai21.py index dd86ba516aec8..9ba5dcb462768 100644 --- a/libs/community/langchain_community/llms/ai21.py +++ b/libs/community/langchain_community/llms/ai21.py @@ -1,10 +1,14 @@ -from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, List, Optional import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import ( + convert_to_secret_str, + extract_secret_value, + get_from_dict_or_env, +) class AI21PenaltyData(BaseModel): @@ -142,10 +146,12 @@ def _call( else: base_url = "https://api.ai21.com/studio/v1" params = {**self._default_params, **kwargs} - self.ai21_api_key = cast(SecretStr, self.ai21_api_key) + self.ai21_api_key = convert_to_secret_str(self.ai21_api_key) response = requests.post( url=f"{base_url}/{self.model}/complete", - headers={"Authorization": f"Bearer {self.ai21_api_key.get_secret_value()}"}, + headers={ + "Authorization": f"Bearer {extract_secret_value(self.ai21_api_key)}" + }, json={"prompt": prompt, "stopSequences": stop, **params}, ) if response.status_code != 200: diff --git a/libs/community/langchain_community/llms/anyscale.py b/libs/community/langchain_community/llms/anyscale.py index b994c425ab9c5..7c717dbfb607d 100644 --- a/libs/community/langchain_community/llms/anyscale.py +++ b/libs/community/langchain_community/llms/anyscale.py @@ -9,7 +9,6 @@ Optional, Set, Tuple, - cast, ) from langchain_core.callbacks import ( @@ -18,7 +17,11 @@ ) from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import ( + convert_to_secret_str, + extract_secret_value, + get_from_dict_or_env, +) from langchain_community.llms.openai import ( BaseOpenAI, @@ -132,7 +135,7 @@ def _identifying_params(self) -> Mapping[str, Any]: def _invocation_params(self) -> Dict[str, Any]: """Get the parameters used to invoke the model.""" openai_creds: Dict[str, Any] = { - "api_key": cast(SecretStr, self.anyscale_api_key).get_secret_value(), + "api_key": extract_secret_value(self.anyscale_api_key), "api_base": self.anyscale_api_base, } return {**openai_creds, **{"model": self.model_name}, **super()._default_params} diff --git a/libs/community/langchain_community/llms/arcee.py b/libs/community/langchain_community/llms/arcee.py index cab21c60e6811..1424b09542bc0 100644 --- a/libs/community/langchain_community/llms/arcee.py +++ b/libs/community/langchain_community/llms/arcee.py @@ -1,9 +1,13 @@ -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Optional, Union from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import ( + convert_to_secret_str, + extract_secret_value, + get_from_dict_or_env, +) from langchain_community.utilities.arcee import ArceeWrapper, DALMFilter @@ -66,7 +70,7 @@ def __init__(self, **data: Any) -> None: """Initializes private fields.""" super().__init__(**data) - api_key = cast(SecretStr, self.arcee_api_key) + api_key = extract_secret_value(self.arcee_api_key) self._client = ArceeWrapper( arcee_api_key=api_key, arcee_api_url=self.arcee_api_url, diff --git a/libs/community/langchain_community/llms/cerebriumai.py b/libs/community/langchain_community/llms/cerebriumai.py index c9e219995ae1b..9ffc498a80769 100644 --- a/libs/community/langchain_community/llms/cerebriumai.py +++ b/libs/community/langchain_community/llms/cerebriumai.py @@ -1,11 +1,15 @@ import logging -from typing import Any, Dict, List, Mapping, Optional, cast +from typing import Any, Dict, List, Mapping, Optional import requests from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import ( + convert_to_secret_str, + extract_secret_value, + get_from_dict_or_env, +) from langchain_community.llms.utils import enforce_stop_tokens @@ -92,9 +96,7 @@ def _call( **kwargs: Any, ) -> str: headers: Dict = { - "Authorization": cast( - SecretStr, self.cerebriumai_api_key - ).get_secret_value(), + "Authorization": extract_secret_value(self.cerebriumai_api_key), "Content-Type": "application/json", } params = self.model_kwargs or {} diff --git a/libs/community/langchain_community/utilities/google_finance.py b/libs/community/langchain_community/utilities/google_finance.py index 95e5b0c2e1552..70d333ce9ac3d 100644 --- a/libs/community/langchain_community/utilities/google_finance.py +++ b/libs/community/langchain_community/utilities/google_finance.py @@ -1,8 +1,12 @@ """Util that calls Google Finance Search.""" -from typing import Any, Dict, Optional, cast +from typing import Any, Dict, Optional from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import ( + convert_to_secret_str, + extract_secret_value, + get_from_dict_or_env, +) class GoogleFinanceAPIWrapper(BaseModel): @@ -52,10 +56,9 @@ def validate_environment(cls, values: Dict) -> Dict: def run(self, query: str) -> str: """Run query through Google Finance with Serpapi""" - serpapi_api_key = cast(SecretStr, self.serp_api_key) params = { "engine": "google_finance", - "api_key": serpapi_api_key.get_secret_value(), + "api_key": extract_secret_value(self.serpapi_api_key), "q": query, } diff --git a/libs/community/langchain_community/utilities/google_jobs.py b/libs/community/langchain_community/utilities/google_jobs.py index d532cbffd022b..7b1bc0e10f60a 100644 --- a/libs/community/langchain_community/utilities/google_jobs.py +++ b/libs/community/langchain_community/utilities/google_jobs.py @@ -1,8 +1,12 @@ """Util that calls Google Scholar Search.""" -from typing import Any, Dict, Optional, cast +from typing import Any, Dict, Optional from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import ( + convert_to_secret_str, + extract_secret_value, + get_from_dict_or_env, +) class GoogleJobsAPIWrapper(BaseModel): @@ -54,10 +58,9 @@ def run(self, query: str) -> str: """Run query through Google Trends with Serpapi""" # set up query - serpapi_api_key = cast(SecretStr, self.serp_api_key) params = { "engine": "google_jobs", - "api_key": serpapi_api_key.get_secret_value(), + "api_key": extract_secret_value(self.serpapi_api_key), "q": query, } diff --git a/libs/community/langchain_community/utilities/google_lens.py b/libs/community/langchain_community/utilities/google_lens.py index f8419bdb7585d..f3e60e78911b0 100644 --- a/libs/community/langchain_community/utilities/google_lens.py +++ b/libs/community/langchain_community/utilities/google_lens.py @@ -1,9 +1,13 @@ """Util that calls Google Lens Search.""" -from typing import Any, Dict, Optional, cast +from typing import Any, Dict, Optional import requests from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import ( + convert_to_secret_str, + extract_secret_value, + get_from_dict_or_env, +) class GoogleLensAPIWrapper(BaseModel): @@ -45,11 +49,9 @@ def validate_environment(cls, values: Dict) -> Dict: def run(self, query: str) -> str: """Run query through Google Trends with Serpapi""" - serpapi_api_key = cast(SecretStr, self.serp_api_key) - params = { "engine": "google_lens", - "api_key": serpapi_api_key.get_secret_value(), + "api_key": extract_secret_value(self.serpapi_api_key), "url": query, } queryURL = f"https://serpapi.com/search?engine={params['engine']}&api_key={params['api_key']}&url={params['url']}" diff --git a/libs/community/langchain_community/utilities/google_trends.py b/libs/community/langchain_community/utilities/google_trends.py index f0f15000c8a1c..f7ed40c52bf63 100644 --- a/libs/community/langchain_community/utilities/google_trends.py +++ b/libs/community/langchain_community/utilities/google_trends.py @@ -1,8 +1,12 @@ """Util that calls Google Scholar Search.""" -from typing import Any, Dict, Optional, cast +from typing import Any, Dict, Optional from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import ( + convert_to_secret_str, + extract_secret_value, + get_from_dict_or_env, +) class GoogleTrendsAPIWrapper(BaseModel): @@ -56,10 +60,10 @@ def validate_environment(cls, values: Dict) -> Dict: def run(self, query: str) -> str: """Run query through Google Trends with Serpapi""" - serpapi_api_key = cast(SecretStr, self.serp_api_key) + api_key = extract_secret_value(self.serpapi_api_key) params = { "engine": "google_trends", - "api_key": serpapi_api_key.get_secret_value(), + "api_key": api_key, "q": query, } @@ -86,7 +90,7 @@ def run(self, query: str) -> str: params = { "engine": "google_trends", - "api_key": serpapi_api_key.get_secret_value(), + "api_key": api_key, "data_type": "RELATED_QUERIES", "q": query, } diff --git a/libs/community/tests/integration_tests/chat_models/test_jinachat.py b/libs/community/tests/integration_tests/chat_models/test_jinachat.py index 0d704ce386726..86ac3bd68f8ca 100644 --- a/libs/community/tests/integration_tests/chat_models/test_jinachat.py +++ b/libs/community/tests/integration_tests/chat_models/test_jinachat.py @@ -1,7 +1,5 @@ """Test JinaChat wrapper.""" -from typing import cast - import pytest from langchain_core.callbacks import CallbackManager from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage @@ -13,7 +11,7 @@ from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler -def test_jinachat_api_key_is_secret_string() -> None: +def test_jinachat_api_key_is_secretstr() -> None: llm = JinaChat(jinachat_api_key="secret-api-key") assert isinstance(llm.jinachat_api_key, SecretStr) @@ -44,7 +42,7 @@ def test_jinachat_api_key_masked_when_passed_via_constructor( def test_uses_actual_secret_value_from_secretstr() -> None: """Test that actual secret is retrieved using `.get_secret_value()`.""" llm = JinaChat(jinachat_api_key="secret-api-key") - assert cast(SecretStr, llm.jinachat_api_key).get_secret_value() == "secret-api-key" + assert llm.jinachat_api_key.get_secret_value() == "secret-api-key" def test_jinachat() -> None: diff --git a/libs/community/tests/integration_tests/llms/test_nlpcloud.py b/libs/community/tests/integration_tests/llms/test_nlpcloud.py index 25aff55f7580f..c18b578ebdccd 100644 --- a/libs/community/tests/integration_tests/llms/test_nlpcloud.py +++ b/libs/community/tests/integration_tests/llms/test_nlpcloud.py @@ -1,7 +1,6 @@ """Test NLPCloud API wrapper.""" from pathlib import Path -from typing import cast from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture, MonkeyPatch @@ -34,8 +33,7 @@ def test_nlpcloud_api_key(monkeypatch: MonkeyPatch, capsys: CaptureFixture) -> N monkeypatch.setenv("NLPCLOUD_API_KEY", "secret-api-key") llm = NLPCloud() assert isinstance(llm.nlpcloud_api_key, SecretStr) - - assert cast(SecretStr, llm.nlpcloud_api_key).get_secret_value() == "secret-api-key" + assert llm.nlpcloud_api_key.get_secret_value() == "secret-api-key" print(llm.nlpcloud_api_key, end="") captured = capsys.readouterr() diff --git a/libs/community/tests/unit_tests/chat_models/test_azureml_endpoint.py b/libs/community/tests/unit_tests/chat_models/test_azureml_endpoint.py index 0cbe2ab9b1749..2b651603ef1cb 100644 --- a/libs/community/tests/unit_tests/chat_models/test_azureml_endpoint.py +++ b/libs/community/tests/unit_tests/chat_models/test_azureml_endpoint.py @@ -37,7 +37,7 @@ def api_passed_via_constructor_fixture() -> AzureMLChatOnlineEndpoint: ["api_passed_via_constructor_fixture", "api_passed_via_environment_fixture"], ) class TestAzureMLChatOnlineEndpoint: - def test_api_key_is_secret_string( + def test_api_key_is_secretstr( self, fixture_name: str, request: FixtureRequest ) -> None: """Test that the API key is a SecretStr instance""" diff --git a/libs/community/tests/unit_tests/chat_models/test_baichuan.py b/libs/community/tests/unit_tests/chat_models/test_baichuan.py index 6a4b4d2009cf4..e1c4a281135ac 100644 --- a/libs/community/tests/unit_tests/chat_models/test_baichuan.py +++ b/libs/community/tests/unit_tests/chat_models/test_baichuan.py @@ -1,4 +1,4 @@ -from typing import cast +"""Test ChatBaichuan wrapper.""" import pytest from langchain_core.messages import ( @@ -137,13 +137,10 @@ def test_baichuan_key_masked_when_passed_via_constructor( assert captured.out == "**********" -def test_uses_actual_secret_value_from_secret_str() -> None: +def test_uses_actual_secret_value_from_secretstr() -> None: """Test that actual secret is retrieved using `.get_secret_value()`.""" chat = ChatBaichuan( baichuan_api_key="test-api-key", baichuan_secret_key="test-secret-key" ) - assert cast(SecretStr, chat.baichuan_api_key).get_secret_value() == "test-api-key" - assert ( - cast(SecretStr, chat.baichuan_secret_key).get_secret_value() - == "test-secret-key" - ) + assert chat.baichuan_api_key.get_secret_value() == "test-api-key" + assert chat.baichuan_secret_key.get_secret_value() == "test-secret-key" diff --git a/libs/community/tests/unit_tests/chat_models/test_javelin_ai_gateway.py b/libs/community/tests/unit_tests/chat_models/test_javelin_ai_gateway.py index 7c4500d340d5f..1a30c41ccfcb3 100644 --- a/libs/community/tests/unit_tests/chat_models/test_javelin_ai_gateway.py +++ b/libs/community/tests/unit_tests/chat_models/test_javelin_ai_gateway.py @@ -7,7 +7,7 @@ @pytest.mark.requires("javelin_sdk") -def test_api_key_is_secret_string() -> None: +def test_api_key_is_secretstr() -> None: llm = ChatJavelinAIGateway( gateway_uri="", route="", diff --git a/libs/community/tests/unit_tests/llms/test_ai21.py b/libs/community/tests/unit_tests/llms/test_ai21.py index 60146a973d20e..bfc8c139e70fd 100644 --- a/libs/community/tests/unit_tests/llms/test_ai21.py +++ b/libs/community/tests/unit_tests/llms/test_ai21.py @@ -1,5 +1,4 @@ """Test AI21 llm""" -from typing import cast from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture, MonkeyPatch @@ -7,7 +6,7 @@ from langchain_community.llms.ai21 import AI21 -def test_api_key_is_secret_string() -> None: +def test_api_key_is_secretstr() -> None: llm = AI21(ai21_api_key="secret-api-key") assert isinstance(llm.ai21_api_key, SecretStr) @@ -38,4 +37,4 @@ def test_api_key_masked_when_passed_via_constructor( def test_uses_actual_secret_value_from_secretstr() -> None: """Test that actual secret is retrieved using `.get_secret_value()`.""" llm = AI21(ai21_api_key="secret-api-key") - assert cast(SecretStr, llm.ai21_api_key).get_secret_value() == "secret-api-key" + assert llm.ai21_api_key.get_secret_value() == "secret-api-key" diff --git a/libs/community/tests/unit_tests/llms/test_aleph_alpha.py b/libs/community/tests/unit_tests/llms/test_aleph_alpha.py index c08d0e7292428..98afd7f53392f 100644 --- a/libs/community/tests/unit_tests/llms/test_aleph_alpha.py +++ b/libs/community/tests/unit_tests/llms/test_aleph_alpha.py @@ -8,7 +8,7 @@ @pytest.mark.requires("aleph_alpha_client") -def test_api_key_is_secret_string() -> None: +def test_api_key_is_secretstr() -> None: llm = AlephAlpha(aleph_alpha_api_key="secret-api-key") assert isinstance(llm.aleph_alpha_api_key, SecretStr) diff --git a/libs/community/tests/unit_tests/llms/test_anyscale.py b/libs/community/tests/unit_tests/llms/test_anyscale.py index 9dc38b7c20d0e..2e4d981330f71 100644 --- a/libs/community/tests/unit_tests/llms/test_anyscale.py +++ b/libs/community/tests/unit_tests/llms/test_anyscale.py @@ -7,7 +7,7 @@ @pytest.mark.requires("openai") -def test_api_key_is_secret_string() -> None: +def test_api_key_is_secretstr() -> None: llm = Anyscale( anyscale_api_key="secret-api-key", anyscale_api_base="test", model_name="test" ) diff --git a/libs/community/tests/unit_tests/llms/test_cerebriumai.py b/libs/community/tests/unit_tests/llms/test_cerebriumai.py index ff9da5745f2a0..b9136419d8d40 100644 --- a/libs/community/tests/unit_tests/llms/test_cerebriumai.py +++ b/libs/community/tests/unit_tests/llms/test_cerebriumai.py @@ -7,7 +7,7 @@ from langchain_community.llms.cerebriumai import CerebriumAI -def test_api_key_is_secret_string() -> None: +def test_api_key_is_secretstr() -> None: llm = CerebriumAI(cerebriumai_api_key="test-cerebriumai-api-key") assert isinstance(llm.cerebriumai_api_key, SecretStr) diff --git a/libs/community/tests/unit_tests/llms/test_forefrontai.py b/libs/community/tests/unit_tests/llms/test_forefrontai.py index 567b78180c2ff..90fda587860ec 100644 --- a/libs/community/tests/unit_tests/llms/test_forefrontai.py +++ b/libs/community/tests/unit_tests/llms/test_forefrontai.py @@ -1,5 +1,4 @@ """Test ForeFrontAI LLM""" -from typing import cast from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture, MonkeyPatch @@ -7,7 +6,7 @@ from langchain_community.llms.forefrontai import ForefrontAI -def test_forefrontai_api_key_is_secret_string() -> None: +def test_forefrontai_api_key_is_secretstr() -> None: """Test that the API key is stored as a SecretStr.""" llm = ForefrontAI(forefrontai_api_key="secret-api-key", temperature=0.2) assert isinstance(llm.forefrontai_api_key, SecretStr) @@ -45,6 +44,4 @@ def test_forefrontai_uses_actual_secret_value_from_secretstr() -> None: forefrontai_api_key="secret-api-key", temperature=0.2, ) - assert ( - cast(SecretStr, llm.forefrontai_api_key).get_secret_value() == "secret-api-key" - ) + assert llm.forefrontai_api_key.get_secret_value() == "secret-api-key" diff --git a/libs/community/tests/unit_tests/llms/test_gooseai.py b/libs/community/tests/unit_tests/llms/test_gooseai.py index e87dd6f40d39c..4176bae655627 100644 --- a/libs/community/tests/unit_tests/llms/test_gooseai.py +++ b/libs/community/tests/unit_tests/llms/test_gooseai.py @@ -16,7 +16,7 @@ def _openai_v1_installed() -> bool: @pytest.mark.requires("openai") -def test_api_key_is_secret_string() -> None: +def test_api_key_is_secretstr() -> None: llm = GooseAI(gooseai_api_key="secret-api-key") assert isinstance(llm.gooseai_api_key, SecretStr) assert llm.gooseai_api_key.get_secret_value() == "secret-api-key" diff --git a/libs/community/tests/unit_tests/llms/test_minimax.py b/libs/community/tests/unit_tests/llms/test_minimax.py index 24e7fa972c474..38d517bfb5ea9 100644 --- a/libs/community/tests/unit_tests/llms/test_minimax.py +++ b/libs/community/tests/unit_tests/llms/test_minimax.py @@ -1,5 +1,4 @@ """Test Minimax llm""" -from typing import cast from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture, MonkeyPatch @@ -7,7 +6,7 @@ from langchain_community.llms.minimax import Minimax -def test_api_key_is_secret_string() -> None: +def test_api_key_is_secretstr() -> None: llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id") assert isinstance(llm.minimax_api_key, SecretStr) @@ -39,4 +38,4 @@ def test_api_key_masked_when_passed_via_constructor( def test_uses_actual_secret_value_from_secretstr() -> None: """Test that actual secret is retrieved using `.get_secret_value()`.""" llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id") - assert cast(SecretStr, llm.minimax_api_key).get_secret_value() == "secret-api-key" + assert llm.minimax_api_key.get_secret_value() == "secret-api-key" diff --git a/libs/community/tests/unit_tests/llms/test_symblai_nebula.py b/libs/community/tests/unit_tests/llms/test_symblai_nebula.py index fb75a6ad6998c..673036e49cbb6 100644 --- a/libs/community/tests/unit_tests/llms/test_symblai_nebula.py +++ b/libs/community/tests/unit_tests/llms/test_symblai_nebula.py @@ -6,7 +6,7 @@ from langchain_community.llms.symblai_nebula import Nebula -def test_api_key_is_secret_string() -> None: +def test_api_key_is_secretstr() -> None: llm = Nebula(nebula_api_key="secret-api-key") assert isinstance(llm.nebula_api_key, SecretStr) assert llm.nebula_api_key.get_secret_value() == "secret-api-key" diff --git a/libs/community/tests/unit_tests/llms/test_together.py b/libs/community/tests/unit_tests/llms/test_together.py index 772cde4050ee2..3d52cd89300e8 100644 --- a/libs/community/tests/unit_tests/llms/test_together.py +++ b/libs/community/tests/unit_tests/llms/test_together.py @@ -1,5 +1,4 @@ """Test Together LLM""" -from typing import cast from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture, MonkeyPatch @@ -7,7 +6,7 @@ from langchain_community.llms.together import Together -def test_together_api_key_is_secret_string() -> None: +def test_together_api_key_is_secretstr() -> None: """Test that the API key is stored as a SecretStr.""" llm = Together( together_api_key="secret-api-key", @@ -58,4 +57,4 @@ def test_together_uses_actual_secret_value_from_secretstr() -> None: temperature=0.2, max_tokens=250, ) - assert cast(SecretStr, llm.together_api_key).get_secret_value() == "secret-api-key" + assert llm.together_api_key.get_secret_value() == "secret-api-key" diff --git a/libs/core/langchain_core/utils/__init__.py b/libs/core/langchain_core/utils/__init__.py index 6491a85f17fb7..4ea744deedffe 100644 --- a/libs/core/langchain_core/utils/__init__.py +++ b/libs/core/langchain_core/utils/__init__.py @@ -18,6 +18,7 @@ build_extra_kwargs, check_package_version, convert_to_secret_str, + extract_secret_value, get_pydantic_field_names, guard_import, mock_now, @@ -29,6 +30,7 @@ "StrictFormatter", "check_package_version", "convert_to_secret_str", + "extract_secret_value", "formatter", "get_bolded_text", "get_color_mapping", diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index 9b63ddf3ea63f..d00dc4f159c39 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -173,8 +173,13 @@ def build_extra_kwargs( return extra_kwargs -def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr: +def convert_to_secret_str(value: Union[SecretStr, str, None]) -> SecretStr: """Convert a string to a SecretStr if needed.""" if isinstance(value, SecretStr): return value - return SecretStr(value) + return SecretStr(str(value)) + + +def extract_secret_value(secret: Union[SecretStr, str, None]) -> str: + """Extract the SecretStr from all types.""" + return convert_to_secret_str(secret).get_secret_value() diff --git a/libs/core/tests/unit_tests/utils/test_imports.py b/libs/core/tests/unit_tests/utils/test_imports.py index ce56c02026f30..a1dfeb3a77fe4 100644 --- a/libs/core/tests/unit_tests/utils/test_imports.py +++ b/libs/core/tests/unit_tests/utils/test_imports.py @@ -4,6 +4,7 @@ "StrictFormatter", "check_package_version", "convert_to_secret_str", + "extract_secret_value", "formatter", "get_bolded_text", "get_color_mapping",