Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hide api key: arcee #14304

Merged
merged 23 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions libs/langchain/langchain/llms/arcee.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, Union, cast

from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator

Expand Down Expand Up @@ -30,7 +30,7 @@ class Arcee(LLM):
_client: Optional[ArceeWrapper] = None #: :meta private:
"""Arcee _client."""

arcee_api_key: Optional[SecretStr] = None
arcee_api_key: Union[SecretStr, str, None] = None
"""Arcee API Key"""

model: str
Expand Down Expand Up @@ -66,15 +66,16 @@ def __init__(self, **data: Any) -> None:
"""Initializes private fields."""

super().__init__(**data)
api_key = cast(SecretStr, self.arcee_api_key)
self._client = ArceeWrapper(
arcee_api_key=cast(SecretStr, self.arcee_api_key),
arcee_api_key=api_key,
arcee_api_url=self.arcee_api_url,
arcee_api_version=self.arcee_api_version,
model_kwargs=self.model_kwargs,
model_name=self.model,
)

@root_validator()
@root_validator(pre=False)
def validate_environments(cls, values: Dict) -> Dict:
"""Validate Arcee environment variables."""

Expand Down Expand Up @@ -106,7 +107,7 @@ def validate_environments(cls, values: Dict) -> Dict:
)

# validate model kwargs
if values["model_kwargs"]:
if values.get("model_kwargs"):
kw = values["model_kwargs"]

# validate size
Expand All @@ -120,7 +121,6 @@ def validate_environments(cls, values: Dict) -> Dict:
raise ValueError("`filters` must be a list")
for f in kw.get("filters"):
DALMFilter(**f)

return values

def _call(
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/retrievers/arcee.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, **data: Any) -> None:
super().__init__(**data)

self._client = ArceeWrapper(
arcee_api_key=self.arcee_api_key,
arcee_api_key=self.arcee_api_key.get_secret_value(),
arcee_api_url=self.arcee_api_url,
arcee_api_version=self.arcee_api_version,
model_kwargs=self.model_kwargs,
Expand Down
21 changes: 16 additions & 5 deletions libs/langchain/langchain/utilities/arcee.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,14 @@ def adapt(cls, arcee_document: ArceeDocument) -> Document:


class ArceeWrapper:
"""Wrapper for Arcee API."""
"""Wrapper for Arcee API.

For more details, see: https://www.arcee.ai/
"""

def __init__(
self,
arcee_api_key: SecretStr,
arcee_api_key: Union[str, SecretStr],
arcee_api_url: str,
arcee_api_version: str,
model_kwargs: Optional[Dict[str, Any]],
Expand All @@ -114,9 +117,12 @@ def __init__(
arcee_api_version: Version of Arcee API.
model_kwargs: Keyword arguments for Arcee API.
model_name: Name of an Arcee model.

"""
self.arcee_api_key = arcee_api_key
if isinstance(arcee_api_key, str):
arcee_api_key_ = SecretStr(arcee_api_key)
else:
arcee_api_key_ = arcee_api_key
self.arcee_api_key: SecretStr = arcee_api_key_
self.model_kwargs = model_kwargs
self.arcee_api_url = arcee_api_url
self.arcee_api_version = arcee_api_version
Expand Down Expand Up @@ -166,8 +172,13 @@ def _make_request(

def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
headers = headers or {}
if not isinstance(self.arcee_api_key, SecretStr):
raise TypeError(
f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}"
)
api_key = self.arcee_api_key.get_secret_value()
internal_headers = {
"X-Token": self.arcee_api_key.get_secret_value(),
"X-Token": api_key,
"Content-Type": "application/json",
}
headers.update(internal_headers)
Expand Down
72 changes: 54 additions & 18 deletions libs/langchain/tests/integration_tests/llms/test_arcee.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,70 @@
"""Test Arcee llm"""
from unittest.mock import MagicMock, patch

from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain.llms.arcee import Arcee


def test_api_key_is_secret_string() -> None:
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
assert isinstance(llm.arcee_api_key, SecretStr)
@patch("langchain.utilities.arcee.requests.get")
def test_arcee_api_key_is_secret_string(mock_get: MagicMock) -> None:
mock_response = mock_get.return_value
mock_response.status_code = 200
mock_response.json.return_value = {
"model_id": "",
"status": "training_complete",
}

arcee_without_env_var = Arcee(
model="DALM-PubMed",
arcee_api_key="secret_api_key",
arcee_api_url="https://localhost",
arcee_api_version="version",
)
assert isinstance(arcee_without_env_var.arcee_api_key, SecretStr)

def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("ARCEE_API_KEY", "test-arcee-api-key")

llm = Arcee(model="DALM-PubMed")
@patch("langchain.utilities.arcee.requests.get")
def test_api_key_masked_when_passed_via_constructor(
mock_get: MagicMock, capsys: CaptureFixture
) -> None:
mock_response = mock_get.return_value
mock_response.status_code = 200
mock_response.json.return_value = {
"model_id": "",
"status": "training_complete",
}

print(llm.arcee_api_key, end="")
arcee_without_env_var = Arcee(
model="DALM-PubMed",
arcee_api_key="secret_api_key",
arcee_api_url="https://localhost",
arcee_api_version="version",
)
print(arcee_without_env_var.arcee_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"

assert "**********" == captured.out

def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,

@patch("langchain.utilities.arcee.requests.get")
def test_api_key_masked_when_passed_from_env(
mock_get: MagicMock, capsys: CaptureFixture, monkeypatch: MonkeyPatch
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
mock_response = mock_get.return_value
mock_response.status_code = 200
mock_response.json.return_value = {
"model_id": "",
"status": "training_complete",
}

print(llm.arcee_api_key, end="")
monkeypatch.setenv("ARCEE_API_KEY", "secret_api_key")
arcee_with_env_var = Arcee(
model="DALM-PubMed",
arcee_api_url="https://localhost",
arcee_api_version="version",
)
print(arcee_with_env_var.arcee_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"

assert "**********" == captured.out
Loading