Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: redundant casting to SecretStr API keys #14446

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Mapping, Optional, cast
from typing import Any, Dict, List, Mapping, Optional

from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
Expand All @@ -19,6 +19,7 @@
ChatResult,
)
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr
from langchain_core.utils import extract_secret_value

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -85,7 +86,7 @@ def __init__(self, **kwargs: Any):
try:
self.client = JavelinClient(
base_url=self.gateway_uri,
api_key=cast(SecretStr, self.javelin_api_key).get_secret_value(),
api_key=extract_secret_value(self.javelin_api_key),
)
except UnauthorizedError as e:
raise ValueError("Javelin: Incorrect API Key.") from e
Expand All @@ -94,7 +95,7 @@ def __init__(self, **kwargs: Any):
def _default_params(self) -> Dict[str, Any]:
params: Dict[str, Any] = {
"gateway_uri": self.gateway_uri,
"javelin_api_key": cast(SecretStr, self.javelin_api_key).get_secret_value(),
"javelin_api_key": extract_secret_value(self.javelin_api_key),
"route": self.route,
**(self.params.dict() if self.params else {}),
}
Expand Down
4 changes: 2 additions & 2 deletions libs/community/langchain_community/chat_models/jinachat.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
extract_secret_value,
get_from_dict_or_env,
get_pydantic_field_names,
)
Expand Down Expand Up @@ -397,8 +398,7 @@ async def _agenerate(
def _invocation_params(self) -> Mapping[str, Any]:
"""Get the parameters used to invoke the model."""
jinachat_creds: Dict[str, Any] = {
"api_key": self.jinachat_api_key
and self.jinachat_api_key.get_secret_value(),
"api_key": extract_secret_value(self.jinachat_api_key),
"api_base": "https://api.chat.jina.ai/v1",
"model": "jinachat",
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Iterator, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import BaseModel, SecretStr


def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
Expand Down Expand Up @@ -38,7 +38,7 @@ class JavelinAIGatewayEmbeddings(Embeddings, BaseModel):
gateway_uri: Optional[str] = None
"""The URI for the Javelin AI Gateway API."""

javelin_api_key: Optional[str] = None
javelin_api_key: Optional[SecretStr] = None
"""The API key for the Javelin AI Gateway API."""

def __init__(self, **kwargs: Any):
Expand Down
9 changes: 6 additions & 3 deletions libs/community/langchain_community/embeddings/voyageai.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
Optional,
Tuple,
Union,
cast,
)

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import (
convert_to_secret_str,
extract_secret_value,
get_from_dict_or_env,
)
from tenacity import (
before_sleep_log,
retry,
Expand Down Expand Up @@ -103,7 +106,7 @@ def validate_environment(cls, values: Dict) -> Dict:
def _invocation_params(
self, input: List[str], input_type: Optional[str] = None
) -> Dict:
api_key = cast(SecretStr, self.voyage_api_key).get_secret_value()
api_key = extract_secret_value(self.voyage_api_key)
params = {
"url": self.voyage_api_base,
"headers": {"Authorization": f"Bearer {api_key}"},
Expand Down
14 changes: 10 additions & 4 deletions libs/community/langchain_community/llms/ai21.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import (
convert_to_secret_str,
extract_secret_value,
get_from_dict_or_env,
)


class AI21PenaltyData(BaseModel):
Expand Down Expand Up @@ -142,10 +146,12 @@ def _call(
else:
base_url = "https://api.ai21.com/studio/v1"
params = {**self._default_params, **kwargs}
self.ai21_api_key = cast(SecretStr, self.ai21_api_key)
self.ai21_api_key = convert_to_secret_str(self.ai21_api_key)
response = requests.post(
url=f"{base_url}/{self.model}/complete",
headers={"Authorization": f"Bearer {self.ai21_api_key.get_secret_value()}"},
headers={
"Authorization": f"Bearer {extract_secret_value(self.ai21_api_key)}"
},
json={"prompt": prompt, "stopSequences": stop, **params},
)
if response.status_code != 200:
Expand Down
9 changes: 6 additions & 3 deletions libs/community/langchain_community/llms/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Optional,
Set,
Tuple,
cast,
)

from langchain_core.callbacks import (
Expand All @@ -18,7 +17,11 @@
)
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import (
convert_to_secret_str,
extract_secret_value,
get_from_dict_or_env,
)

from langchain_community.llms.openai import (
BaseOpenAI,
Expand Down Expand Up @@ -132,7 +135,7 @@ def _identifying_params(self) -> Mapping[str, Any]:
def _invocation_params(self) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
openai_creds: Dict[str, Any] = {
"api_key": cast(SecretStr, self.anyscale_api_key).get_secret_value(),
"api_key": extract_secret_value(self.anyscale_api_key),
"api_base": self.anyscale_api_base,
}
return {**openai_creds, **{"model": self.model_name}, **super()._default_params}
Expand Down
10 changes: 7 additions & 3 deletions libs/community/langchain_community/llms/arcee.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Optional, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import (
convert_to_secret_str,
extract_secret_value,
get_from_dict_or_env,
)

from langchain_community.utilities.arcee import ArceeWrapper, DALMFilter

Expand Down Expand Up @@ -66,7 +70,7 @@ def __init__(self, **data: Any) -> None:
"""Initializes private fields."""

super().__init__(**data)
api_key = cast(SecretStr, self.arcee_api_key)
api_key = extract_secret_value(self.arcee_api_key)
self._client = ArceeWrapper(
arcee_api_key=api_key,
arcee_api_url=self.arcee_api_url,
Expand Down
12 changes: 7 additions & 5 deletions libs/community/langchain_community/llms/cerebriumai.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import logging
from typing import Any, Dict, List, Mapping, Optional, cast
from typing import Any, Dict, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import (
convert_to_secret_str,
extract_secret_value,
get_from_dict_or_env,
)

from langchain_community.llms.utils import enforce_stop_tokens

Expand Down Expand Up @@ -92,9 +96,7 @@ def _call(
**kwargs: Any,
) -> str:
headers: Dict = {
"Authorization": cast(
SecretStr, self.cerebriumai_api_key
).get_secret_value(),
"Authorization": extract_secret_value(self.cerebriumai_api_key),
"Content-Type": "application/json",
}
params = self.model_kwargs or {}
Expand Down
11 changes: 7 additions & 4 deletions libs/community/langchain_community/utilities/google_finance.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Util that calls Google Finance Search."""
from typing import Any, Dict, Optional, cast
from typing import Any, Dict, Optional

from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import (
convert_to_secret_str,
extract_secret_value,
get_from_dict_or_env,
)


class GoogleFinanceAPIWrapper(BaseModel):
Expand Down Expand Up @@ -52,10 +56,9 @@ def validate_environment(cls, values: Dict) -> Dict:

def run(self, query: str) -> str:
"""Run query through Google Finance with Serpapi"""
serpapi_api_key = cast(SecretStr, self.serp_api_key)
params = {
"engine": "google_finance",
"api_key": serpapi_api_key.get_secret_value(),
"api_key": extract_secret_value(self.serpapi_api_key),
"q": query,
}

Expand Down
11 changes: 7 additions & 4 deletions libs/community/langchain_community/utilities/google_jobs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Util that calls Google Scholar Search."""
from typing import Any, Dict, Optional, cast
from typing import Any, Dict, Optional

from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import (
convert_to_secret_str,
extract_secret_value,
get_from_dict_or_env,
)


class GoogleJobsAPIWrapper(BaseModel):
Expand Down Expand Up @@ -54,10 +58,9 @@ def run(self, query: str) -> str:
"""Run query through Google Trends with Serpapi"""

# set up query
serpapi_api_key = cast(SecretStr, self.serp_api_key)
params = {
"engine": "google_jobs",
"api_key": serpapi_api_key.get_secret_value(),
"api_key": extract_secret_value(self.serpapi_api_key),
"q": query,
}

Expand Down
12 changes: 7 additions & 5 deletions libs/community/langchain_community/utilities/google_lens.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Util that calls Google Lens Search."""
from typing import Any, Dict, Optional, cast
from typing import Any, Dict, Optional

import requests
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import (
convert_to_secret_str,
extract_secret_value,
get_from_dict_or_env,
)


class GoogleLensAPIWrapper(BaseModel):
Expand Down Expand Up @@ -45,11 +49,9 @@ def validate_environment(cls, values: Dict) -> Dict:

def run(self, query: str) -> str:
"""Run query through Google Trends with Serpapi"""
serpapi_api_key = cast(SecretStr, self.serp_api_key)

params = {
"engine": "google_lens",
"api_key": serpapi_api_key.get_secret_value(),
"api_key": extract_secret_value(self.serpapi_api_key),
"url": query,
}
queryURL = f"https://serpapi.com/search?engine={params['engine']}&api_key={params['api_key']}&url={params['url']}"
Expand Down
14 changes: 9 additions & 5 deletions libs/community/langchain_community/utilities/google_trends.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Util that calls Google Scholar Search."""
from typing import Any, Dict, Optional, cast
from typing import Any, Dict, Optional

from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils import (
convert_to_secret_str,
extract_secret_value,
get_from_dict_or_env,
)


class GoogleTrendsAPIWrapper(BaseModel):
Expand Down Expand Up @@ -56,10 +60,10 @@ def validate_environment(cls, values: Dict) -> Dict:

def run(self, query: str) -> str:
"""Run query through Google Trends with Serpapi"""
serpapi_api_key = cast(SecretStr, self.serp_api_key)
api_key = extract_secret_value(self.serpapi_api_key)
params = {
"engine": "google_trends",
"api_key": serpapi_api_key.get_secret_value(),
"api_key": api_key,
"q": query,
}

Expand All @@ -86,7 +90,7 @@ def run(self, query: str) -> str:

params = {
"engine": "google_trends",
"api_key": serpapi_api_key.get_secret_value(),
"api_key": api_key,
"data_type": "RELATED_QUERIES",
"q": query,
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Test JinaChat wrapper."""

from typing import cast

import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
Expand All @@ -13,7 +11,7 @@
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def test_jinachat_api_key_is_secret_string() -> None:
def test_jinachat_api_key_is_secretstr() -> None:
llm = JinaChat(jinachat_api_key="secret-api-key")
assert isinstance(llm.jinachat_api_key, SecretStr)

Expand Down Expand Up @@ -44,7 +42,7 @@ def test_jinachat_api_key_masked_when_passed_via_constructor(
def test_uses_actual_secret_value_from_secretstr() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
llm = JinaChat(jinachat_api_key="secret-api-key")
assert cast(SecretStr, llm.jinachat_api_key).get_secret_value() == "secret-api-key"
assert llm.jinachat_api_key.get_secret_value() == "secret-api-key"


def test_jinachat() -> None:
Expand Down
4 changes: 1 addition & 3 deletions libs/community/tests/integration_tests/llms/test_nlpcloud.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Test NLPCloud API wrapper."""

from pathlib import Path
from typing import cast

from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
Expand Down Expand Up @@ -34,8 +33,7 @@ def test_nlpcloud_api_key(monkeypatch: MonkeyPatch, capsys: CaptureFixture) -> N
monkeypatch.setenv("NLPCLOUD_API_KEY", "secret-api-key")
llm = NLPCloud()
assert isinstance(llm.nlpcloud_api_key, SecretStr)

assert cast(SecretStr, llm.nlpcloud_api_key).get_secret_value() == "secret-api-key"
assert llm.nlpcloud_api_key.get_secret_value() == "secret-api-key"

print(llm.nlpcloud_api_key, end="")
captured = capsys.readouterr()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def api_passed_via_constructor_fixture() -> AzureMLChatOnlineEndpoint:
["api_passed_via_constructor_fixture", "api_passed_via_environment_fixture"],
)
class TestAzureMLChatOnlineEndpoint:
def test_api_key_is_secret_string(
def test_api_key_is_secretstr(
self, fixture_name: str, request: FixtureRequest
) -> None:
"""Test that the API key is a SecretStr instance"""
Expand Down
Loading
Loading