From f03869e0afc2161006ba75b6d2f02003f3e3d649 Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 20:20:25 +0100 Subject: [PATCH 01/22] test setup --- .../integration_tests/llms/test_arcee.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 libs/langchain/tests/integration_tests/llms/test_arcee.py diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py new file mode 100644 index 0000000000000..f9f2194a8c094 --- /dev/null +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -0,0 +1,39 @@ +import unittest +from unittest.mock import patch + +from pydantic import SecretStr + +from langchain.llms.arcee import Arcee + + +class TestApiConfigSecurity(unittest.TestCase): + @patch('langchain.utilities.arcee.requests.get') + def test_arcee_api_key_is_secret_string(self, mock_get) -> None: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.json.return_value = {"model_id": "", "status": "training_complete"} + + llm = Arcee( + model="DALM-PubMed", + arcee_api_key="secret_api_key", + arcee_api_url="localhost", + arcee_api_version="version", + ) + + +# def test_api_key_securely_wrapped(self): +# # Ensure that the API key is securely wrapped using SecretStr. +# config = ApiConfig(api_key="your_api_key_here") +# self.assertIsInstance(config.api_key, SecretStr) +# +# def test_no_secret_in_logs(self): +# # Ensure that sensitive data is not exposed in logs. +# config = ApiConfig(api_key="your_api_key_here") +# log_output = some_logging_function(config.api_key.get_secret_value()) +# self.assertNotIn("your_api_key_here", log_output) +# +# def test_proper_access_control(self): +# # Ensure that proper access control is enforced for sensitive data. +# config = ApiConfig(api_key="your_api_key_here") +# # Perform actions that require API key and assert proper access control. +# self.assertTrue(some_function_requiring_api_key(config.api_key.get_secret_value())) From 9abf40e8325d5ac641628d2122bdbe2d862e4e44 Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 20:24:14 +0100 Subject: [PATCH 02/22] change key str to SecretStr --- libs/langchain/langchain/llms/arcee.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index 72028097c2a05..1f8435bcff9ef 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from langchain_core.pydantic_v1 import Extra, root_validator +from langchain_core.pydantic_v1 import Extra, root_validator, SecretStr from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM @@ -30,7 +30,7 @@ class Arcee(LLM): _client: Optional[ArceeWrapper] = None #: :meta private: """Arcee _client.""" - arcee_api_key: str = "" + arcee_api_key: SecretStr """Arcee API Key""" model: str From 409ef0950580ff68949bb65ce22e3064aecac22d Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 20:26:00 +0100 Subject: [PATCH 03/22] make SecretStr optional --- libs/langchain/langchain/llms/arcee.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index 1f8435bcff9ef..4086b2ac7c4e4 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -5,7 +5,7 @@ from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.utilities.arcee import ArceeWrapper, DALMFilter -from langchain.utils import get_from_dict_or_env +from langchain.utils import get_from_dict_or_env, convert_to_secret_str class Arcee(LLM): @@ -30,7 +30,7 @@ class Arcee(LLM): _client: Optional[ArceeWrapper] = None #: :meta private: """Arcee _client.""" - arcee_api_key: SecretStr + arcee_api_key: Optional[SecretStr] """Arcee API Key""" model: str From 0609b0864b4c4089d16d2e28fe20cc897b0088bb Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 20:27:52 +0100 Subject: [PATCH 04/22] update tests --- libs/langchain/tests/integration_tests/llms/test_arcee.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index f9f2194a8c094..43422c9142f16 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -13,13 +13,14 @@ def test_arcee_api_key_is_secret_string(self, mock_get) -> None: mock_response.status_code = 200 mock_response.json.return_value = {"model_id": "", "status": "training_complete"} - llm = Arcee( + arcee = Arcee( model="DALM-PubMed", arcee_api_key="secret_api_key", arcee_api_url="localhost", arcee_api_version="version", ) + self.assertTrue(isinstance(arcee.arcee_api_key, SecretStr)) # def test_api_key_securely_wrapped(self): # # Ensure that the API key is securely wrapped using SecretStr. From d3e49c53751a66ebc755fb500468c56784db20f5 Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 20:32:23 +0100 Subject: [PATCH 05/22] update model --- libs/langchain/langchain/llms/arcee.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index 4086b2ac7c4e4..747a879ac9930 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -30,7 +30,7 @@ class Arcee(LLM): _client: Optional[ArceeWrapper] = None #: :meta private: """Arcee _client.""" - arcee_api_key: Optional[SecretStr] + arcee_api_key: Optional[str, SecretStr] """Arcee API Key""" model: str @@ -82,10 +82,12 @@ def validate_environments(cls, values: Dict) -> Dict: """Validate Arcee environment variables.""" # validate env vars - values["arcee_api_key"] = get_from_dict_or_env( - values, - "arcee_api_key", - "ARCEE_API_KEY", + values["arcee_api_key"] = convert_to_secret_str( + get_from_dict_or_env( + values, + "arcee_api_key", + "ARCEE_API_KEY", + ) ) values["arcee_api_url"] = get_from_dict_or_env( From d8b4baa9bc3c371d732aa8ae1934f90a1ad7fb84 Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 20:37:36 +0100 Subject: [PATCH 06/22] make api key str or SecretStr --- libs/langchain/langchain/llms/arcee.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index 747a879ac9930..e999b900d51f8 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from langchain_core.pydantic_v1 import Extra, root_validator, SecretStr @@ -30,7 +30,7 @@ class Arcee(LLM): _client: Optional[ArceeWrapper] = None #: :meta private: """Arcee _client.""" - arcee_api_key: Optional[str, SecretStr] + arcee_api_key: Union[Optional[str], SecretStr] """Arcee API Key""" model: str From 1ee4a043c18eba2ed8aa5437a0f5c83d505bc924 Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 20:38:29 +0100 Subject: [PATCH 07/22] update tests --- .../langchain/tests/integration_tests/llms/test_arcee.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index 43422c9142f16..f5fe83378a5ff 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -7,20 +7,23 @@ class TestApiConfigSecurity(unittest.TestCase): + @patch('langchain.utilities.arcee.requests.get') - def test_arcee_api_key_is_secret_string(self, mock_get) -> None: + def setUp(self, mock_get) -> None: mock_response = mock_get.return_value mock_response.status_code = 200 mock_response.json.return_value = {"model_id": "", "status": "training_complete"} - arcee = Arcee( + self.arcee = Arcee( model="DALM-PubMed", arcee_api_key="secret_api_key", arcee_api_url="localhost", arcee_api_version="version", ) - self.assertTrue(isinstance(arcee.arcee_api_key, SecretStr)) + def test_arcee_api_key_is_secret_string(self) -> None: + + self.assertTrue(isinstance(self.arcee.arcee_api_key, SecretStr)) # def test_api_key_securely_wrapped(self): # # Ensure that the API key is securely wrapped using SecretStr. From 092204a58ca862f52137c56bdbb0ead1dd27c6a0 Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 20:49:02 +0100 Subject: [PATCH 08/22] update Arcee wrapper --- libs/langchain/langchain/utilities/arcee.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index 318af14eb5758..c09f6b152f784 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -8,6 +8,7 @@ import requests from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.retrievers import Document +from pydantic import SecretStr class ArceeRoute(str, Enum): @@ -100,7 +101,7 @@ class ArceeWrapper: def __init__( self, - arcee_api_key: str, + arcee_api_key: Union[Optional[str], SecretStr], arcee_api_url: str, arcee_api_version: str, model_kwargs: Optional[Dict[str, Any]], @@ -167,7 +168,7 @@ def _make_request( def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict: headers = headers or {} internal_headers = { - "X-Token": self.arcee_api_key, + "X-Token": self.arcee_api_key if isinstance(self.arcee_api_key, str) else self.arcee_api_key.get_secret_value(), "Content-Type": "application/json", } headers.update(internal_headers) From 48992c96e68f1a643f839037feec623c547ee3e2 Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 20:56:10 +0100 Subject: [PATCH 09/22] update tests --- .../integration_tests/llms/test_arcee.py | 36 +++++++++---------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index f5fe83378a5ff..6d16b66d0e05b 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -1,6 +1,7 @@ import unittest from unittest.mock import patch +import pytest from pydantic import SecretStr from langchain.llms.arcee import Arcee @@ -14,30 +15,27 @@ def setUp(self, mock_get) -> None: mock_response.status_code = 200 mock_response.json.return_value = {"model_id": "", "status": "training_complete"} - self.arcee = Arcee( + self.arcee_without_env_var = Arcee( model="DALM-PubMed", arcee_api_key="secret_api_key", arcee_api_url="localhost", arcee_api_version="version", ) + + @pytest.fixture(autouse=True) + def capsys(self, capsys): + self.capsys = capsys + + @pytest.fixture(autouse=True) + def monkeypatch(self, monkeypatch): + self.monkeypatch = monkeypatch + def test_arcee_api_key_is_secret_string(self) -> None: + self.assertTrue(isinstance(self.arcee_without_env_var.arcee_api_key, SecretStr)) + + def test_api_key_masked_when_passed_via_constructor(self) -> None: + print(self.arcee_without_env_var.arcee_api_key, end="") + captured = self.capsys.readouterr() - self.assertTrue(isinstance(self.arcee.arcee_api_key, SecretStr)) - -# def test_api_key_securely_wrapped(self): -# # Ensure that the API key is securely wrapped using SecretStr. -# config = ApiConfig(api_key="your_api_key_here") -# self.assertIsInstance(config.api_key, SecretStr) -# -# def test_no_secret_in_logs(self): -# # Ensure that sensitive data is not exposed in logs. -# config = ApiConfig(api_key="your_api_key_here") -# log_output = some_logging_function(config.api_key.get_secret_value()) -# self.assertNotIn("your_api_key_here", log_output) -# -# def test_proper_access_control(self): -# # Ensure that proper access control is enforced for sensitive data. -# config = ApiConfig(api_key="your_api_key_here") -# # Perform actions that require API key and assert proper access control. -# self.assertTrue(some_function_requiring_api_key(config.api_key.get_secret_value())) + assert captured.out == "**********" From 8415a152367f5d2c13d75afdeb5b337da3a4ad2e Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 21:02:14 +0100 Subject: [PATCH 10/22] add arcee_with_env_var --- .../integration_tests/llms/test_arcee.py | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index 6d16b66d0e05b..d59e7f402bc8f 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -9,6 +9,15 @@ class TestApiConfigSecurity(unittest.TestCase): + + @pytest.fixture(autouse=True) + def capsys(self, capsys): + self.capsys = capsys + + @pytest.fixture(autouse=True) + def monkeypatch(self, monkeypatch): + self.monkeypatch = monkeypatch + @patch('langchain.utilities.arcee.requests.get') def setUp(self, mock_get) -> None: mock_response = mock_get.return_value @@ -21,15 +30,13 @@ def setUp(self, mock_get) -> None: arcee_api_url="localhost", arcee_api_version="version", ) - - - @pytest.fixture(autouse=True) - def capsys(self, capsys): - self.capsys = capsys - - @pytest.fixture(autouse=True) - def monkeypatch(self, monkeypatch): - self.monkeypatch = monkeypatch + self.monkeypatch.setenv("ARCEE_API_KEY", "secret_api_key") + self.arcee_with_env_var = Arcee( + model="DALM-PubMed", + arcee_api_key="", + arcee_api_url="localhost", + arcee_api_version="version", + ) def test_arcee_api_key_is_secret_string(self) -> None: self.assertTrue(isinstance(self.arcee_without_env_var.arcee_api_key, SecretStr)) @@ -38,4 +45,4 @@ def test_api_key_masked_when_passed_via_constructor(self) -> None: print(self.arcee_without_env_var.arcee_api_key, end="") captured = self.capsys.readouterr() - assert captured.out == "**********" + assert captured.out == "**********" \ No newline at end of file From 869ee01d208c67a7b9f15d6a0b77fabad957e9fe Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 21:03:31 +0100 Subject: [PATCH 11/22] test with env var --- libs/langchain/tests/integration_tests/llms/test_arcee.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index d59e7f402bc8f..31845ee6bbc46 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -45,4 +45,10 @@ def test_api_key_masked_when_passed_via_constructor(self) -> None: print(self.arcee_without_env_var.arcee_api_key, end="") captured = self.capsys.readouterr() - assert captured.out == "**********" \ No newline at end of file + self.assertEquals("**********", captured.out) + + def test_api_key_masked_when_passed_from_env(self) -> None: + print(self.arcee_with_env_var.arcee_api_key, end="") + captured = self.capsys.readouterr() + + self.assertEquals("**********", captured.out) From 637df73c433ff6f1615444a799ab44f64329285e Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 21:05:45 +0100 Subject: [PATCH 12/22] formatting --- libs/langchain/langchain/llms/arcee.py | 4 ++-- libs/langchain/langchain/utilities/arcee.py | 4 +++- .../langchain/tests/integration_tests/llms/test_arcee.py | 9 +++++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index e999b900d51f8..e51817ddb4b30 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -1,11 +1,11 @@ from typing import Any, Dict, List, Optional, Union -from langchain_core.pydantic_v1 import Extra, root_validator, SecretStr +from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.utilities.arcee import ArceeWrapper, DALMFilter -from langchain.utils import get_from_dict_or_env, convert_to_secret_str +from langchain.utils import convert_to_secret_str, get_from_dict_or_env class Arcee(LLM): diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index c09f6b152f784..5017697a7dc35 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -168,7 +168,9 @@ def _make_request( def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict: headers = headers or {} internal_headers = { - "X-Token": self.arcee_api_key if isinstance(self.arcee_api_key, str) else self.arcee_api_key.get_secret_value(), + "X-Token": self.arcee_api_key + if isinstance(self.arcee_api_key, str) + else self.arcee_api_key.get_secret_value(), "Content-Type": "application/json", } headers.update(internal_headers) diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index 31845ee6bbc46..31e1903d334fd 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -8,8 +8,6 @@ class TestApiConfigSecurity(unittest.TestCase): - - @pytest.fixture(autouse=True) def capsys(self, capsys): self.capsys = capsys @@ -18,11 +16,14 @@ def capsys(self, capsys): def monkeypatch(self, monkeypatch): self.monkeypatch = monkeypatch - @patch('langchain.utilities.arcee.requests.get') + @patch("langchain.utilities.arcee.requests.get") def setUp(self, mock_get) -> None: mock_response = mock_get.return_value mock_response.status_code = 200 - mock_response.json.return_value = {"model_id": "", "status": "training_complete"} + mock_response.json.return_value = { + "model_id": "", + "status": "training_complete", + } self.arcee_without_env_var = Arcee( model="DALM-PubMed", From 15ffbf50545cf61a204a96efad3093bed40fa51e Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 21:09:33 +0100 Subject: [PATCH 13/22] lint --- libs/langchain/langchain/utilities/arcee.py | 3 +-- .../tests/integration_tests/llms/test_arcee.py | 12 +++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index 5017697a7dc35..d8be0bcaa0ca1 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -6,9 +6,8 @@ from typing import Any, Dict, List, Literal, Mapping, Optional, Union import requests -from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator from langchain_core.retrievers import Document -from pydantic import SecretStr class ArceeRoute(str, Enum): diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index 31e1903d334fd..d12e8d8b7d03d 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -1,23 +1,25 @@ import unittest -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest -from pydantic import SecretStr +from _pytest.capture import CaptureFixture +from _pytest.monkeypatch import MonkeyPatch +from langchain_core.pydantic_v1 import SecretStr from langchain.llms.arcee import Arcee class TestApiConfigSecurity(unittest.TestCase): @pytest.fixture(autouse=True) - def capsys(self, capsys): + def capsys(self, capsys: CaptureFixture) -> None: self.capsys = capsys @pytest.fixture(autouse=True) - def monkeypatch(self, monkeypatch): + def monkeypatch(self, monkeypatch: MonkeyPatch) -> None: self.monkeypatch = monkeypatch @patch("langchain.utilities.arcee.requests.get") - def setUp(self, mock_get) -> None: + def setUp(self, mock_get: MagicMock) -> None: mock_response = mock_get.return_value mock_response.status_code = 200 mock_response.json.return_value = { From f9cf6d144bdbdbf27b2cc2e37fc5ac45dea244bf Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 21:33:08 +0100 Subject: [PATCH 14/22] remove unittest class --- .../integration_tests/llms/test_arcee.py | 108 ++++++++++-------- 1 file changed, 60 insertions(+), 48 deletions(-) diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index d12e8d8b7d03d..f6df105fa763e 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -1,7 +1,5 @@ -import unittest from unittest.mock import MagicMock, patch -import pytest from _pytest.capture import CaptureFixture from _pytest.monkeypatch import MonkeyPatch from langchain_core.pydantic_v1 import SecretStr @@ -9,49 +7,63 @@ from langchain.llms.arcee import Arcee -class TestApiConfigSecurity(unittest.TestCase): - @pytest.fixture(autouse=True) - def capsys(self, capsys: CaptureFixture) -> None: - self.capsys = capsys - - @pytest.fixture(autouse=True) - def monkeypatch(self, monkeypatch: MonkeyPatch) -> None: - self.monkeypatch = monkeypatch - - @patch("langchain.utilities.arcee.requests.get") - def setUp(self, mock_get: MagicMock) -> None: - mock_response = mock_get.return_value - mock_response.status_code = 200 - mock_response.json.return_value = { - "model_id": "", - "status": "training_complete", - } - - self.arcee_without_env_var = Arcee( - model="DALM-PubMed", - arcee_api_key="secret_api_key", - arcee_api_url="localhost", - arcee_api_version="version", - ) - self.monkeypatch.setenv("ARCEE_API_KEY", "secret_api_key") - self.arcee_with_env_var = Arcee( - model="DALM-PubMed", - arcee_api_key="", - arcee_api_url="localhost", - arcee_api_version="version", - ) - - def test_arcee_api_key_is_secret_string(self) -> None: - self.assertTrue(isinstance(self.arcee_without_env_var.arcee_api_key, SecretStr)) - - def test_api_key_masked_when_passed_via_constructor(self) -> None: - print(self.arcee_without_env_var.arcee_api_key, end="") - captured = self.capsys.readouterr() - - self.assertEquals("**********", captured.out) - - def test_api_key_masked_when_passed_from_env(self) -> None: - print(self.arcee_with_env_var.arcee_api_key, end="") - captured = self.capsys.readouterr() - - self.assertEquals("**********", captured.out) +@patch("langchain.utilities.arcee.requests.get") +def test_arcee_api_key_is_secret_string(mock_get: MagicMock) -> None: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.json.return_value = { + "model_id": "", + "status": "training_complete", + } + + arcee_without_env_var = Arcee( + model="DALM-PubMed", + arcee_api_key="secret_api_key", + arcee_api_url="https://localhost", + arcee_api_version="version", + ) + assert isinstance(arcee_without_env_var.arcee_api_key, SecretStr) + + +@patch("langchain.utilities.arcee.requests.get") +def test_api_key_masked_when_passed_via_constructor(mock_get: MagicMock, capsys: CaptureFixture) -> None: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.json.return_value = { + "model_id": "", + "status": "training_complete", + } + + arcee_without_env_var = Arcee( + model="DALM-PubMed", + arcee_api_key="secret_api_key", + arcee_api_url="https://localhost", + arcee_api_version="version", + ) + print(arcee_without_env_var.arcee_api_key, end="") + captured = capsys.readouterr() + + assert "**********" == captured.out + + +@patch("langchain.utilities.arcee.requests.get") +def test_api_key_masked_when_passed_from_env(mock_get: MagicMock, capsys: CaptureFixture, + monkeypatch: MonkeyPatch) -> None: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.json.return_value = { + "model_id": "", + "status": "training_complete", + } + + monkeypatch.setenv("ARCEE_API_KEY", "secret_api_key") + arcee_with_env_var = Arcee( + model="DALM-PubMed", + arcee_api_key="", + arcee_api_url="https://localhost", + arcee_api_version="version", + ) + print(arcee_with_env_var.arcee_api_key, end="") + captured = capsys.readouterr() + + assert "**********" == captured.out From 161df1e6e189401f079f7ed2b8d962f8a1b897f4 Mon Sep 17 00:00:00 2001 From: raphael Date: Thu, 30 Nov 2023 21:38:56 +0100 Subject: [PATCH 15/22] finalize make tasks --- libs/langchain/langchain/utilities/arcee.py | 9 ++++++--- .../langchain/tests/integration_tests/llms/test_arcee.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index d8be0bcaa0ca1..9d3538fe3b478 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -166,10 +166,13 @@ def _make_request( def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict: headers = headers or {} + api_key = None + if isinstance(self.arcee_api_key, str): + api_key = self.arcee_api_key + elif isinstance(self.arcee_api_key, SecretStr): + api_key = self.arcee_api_key.get_secret_value() internal_headers = { - "X-Token": self.arcee_api_key - if isinstance(self.arcee_api_key, str) - else self.arcee_api_key.get_secret_value(), + "X-Token": api_key, "Content-Type": "application/json", } headers.update(internal_headers) diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index f6df105fa763e..93a66e249d28e 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -26,7 +26,9 @@ def test_arcee_api_key_is_secret_string(mock_get: MagicMock) -> None: @patch("langchain.utilities.arcee.requests.get") -def test_api_key_masked_when_passed_via_constructor(mock_get: MagicMock, capsys: CaptureFixture) -> None: +def test_api_key_masked_when_passed_via_constructor( + mock_get: MagicMock, capsys: CaptureFixture +) -> None: mock_response = mock_get.return_value mock_response.status_code = 200 mock_response.json.return_value = { @@ -47,8 +49,9 @@ def test_api_key_masked_when_passed_via_constructor(mock_get: MagicMock, capsys: @patch("langchain.utilities.arcee.requests.get") -def test_api_key_masked_when_passed_from_env(mock_get: MagicMock, capsys: CaptureFixture, - monkeypatch: MonkeyPatch) -> None: +def test_api_key_masked_when_passed_from_env( + mock_get: MagicMock, capsys: CaptureFixture, monkeypatch: MonkeyPatch +) -> None: mock_response = mock_get.return_value mock_response.status_code = 200 mock_response.json.return_value = { From ce01b79f9e540ec5853afed9f2699a4220d368d9 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 5 Dec 2023 12:09:14 -0500 Subject: [PATCH 16/22] x --- libs/langchain/langchain/llms/arcee.py | 39 +++++++++++---------- libs/langchain/langchain/utilities/arcee.py | 28 ++++++++------- 2 files changed, 35 insertions(+), 32 deletions(-) diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index 2e9392ec39646..a3af95cca6e10 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -27,10 +27,10 @@ class Arcee(LLM): response = arcee("AI-driven music therapy") """ - _client: Optional[ArceeWrapper] = None #: :meta private: + _client: ArceeWrapper #: :meta private: """Arcee _client.""" - arcee_api_key: Optional[SecretStr] = None + arcee_api_key: SecretStr """Arcee API Key""" model: str @@ -66,29 +66,20 @@ def __init__(self, **data: Any) -> None: """Initializes private fields.""" super().__init__(**data) - self._client = ArceeWrapper( - arcee_api_key=arcee_api_key # FIX ME, - arcee_api_url=self.arcee_api_url, - arcee_api_version=self.arcee_api_version, - model_kwargs=self.model_kwargs, - model_name=self.model, - ) - - self._client.validate_model_training_status() - @root_validator() + @root_validator(pre=True) # Use pre=False to pick up defaults def validate_environments(cls, values: Dict) -> Dict: """Validate Arcee environment variables.""" # validate env vars - values["arcee_api_key"] = convert_to_secret_str( - get_from_dict_or_env( - values, - "arcee_api_key", - "ARCEE_API_KEY", - ) + arcee_api_key = get_from_dict_or_env( + values, + "arcee_api_key", + "ARCEE_API_KEY", ) + values["arcee_api_key"] = convert_to_secret_str(arcee_api_key) + values["arcee_api_url"] = get_from_dict_or_env( values, "arcee_api_url", @@ -108,7 +99,7 @@ def validate_environments(cls, values: Dict) -> Dict: ) # validate model kwargs - if values["model_kwargs"]: + if values.get("model_kwargs"): kw = values["model_kwargs"] # validate size @@ -123,6 +114,16 @@ def validate_environments(cls, values: Dict) -> Dict: for f in kw.get("filters"): DALMFilter(**f) + client = ArceeWrapper( + arcee_api_key=arcee_api_key, + arcee_api_url=values["arcee_api_url"], + arcee_api_version=values["arcee_api_version"], + model_kwargs=values.get("model_kwargs"), + model_name=values["model"], + ) + + client.validate_model_training_status() + values["_client"] = client return values def _call( diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index 8abf33d1a2aca..7781bfa4f90ad 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -96,11 +96,14 @@ def adapt(cls, arcee_document: ArceeDocument) -> Document: class ArceeWrapper: - """Wrapper for Arcee API.""" + """Wrapper for Arcee API. + + For more details, see: https://www.arcee.ai/ + """ def __init__( self, - arcee_api_key: Union[str, SecretStr, None], + arcee_api_key: str, arcee_api_url: str, arcee_api_version: str, model_kwargs: Optional[Dict[str, Any]], @@ -114,9 +117,12 @@ def __init__( arcee_api_version: Version of Arcee API. model_kwargs: Keyword arguments for Arcee API. model_name: Name of an Arcee model. - """ - self.arcee_api_key = arcee_api_key + if not isinstance(arcee_api_key, (str,)): + raise TypeError( + f"arcee_api_key must be a string. Got {type(arcee_api_key)}" + ) + self.arcee_api_key = SecretStr(arcee_api_key) self.model_kwargs = model_kwargs self.arcee_api_url = arcee_api_url self.arcee_api_version = arcee_api_version @@ -166,17 +172,13 @@ def _make_request( def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict: headers = headers or {} - api_key = None # Fix here - if isinstance(self.arcee_api_key, str): - api_key = self.arcee_api_key - elif isinstance(self.arcee_api_key, SecretStr): - api_key = self.arcee_api_key.get_secret_value() + if not isinstance(self.arcee_api_key, SecretStr): + raise TypeError( + f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}" + ) + api_key = self.arcee_api_key.get_secret_value() internal_headers = { -<<<<<<< HEAD "X-Token": api_key, -======= - "X-Token": self.arcee_api_key.get_secret_value(), ->>>>>>> master "Content-Type": "application/json", } headers.update(internal_headers) From a5451bbac0c6d2f04d202b052987a115eaf8f79c Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 5 Dec 2023 12:30:39 -0500 Subject: [PATCH 17/22] x --- libs/langchain/langchain/llms/arcee.py | 35 ++++++++++++-------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index a3af95cca6e10..94139525384fa 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, List, Optional from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator @@ -27,7 +27,7 @@ class Arcee(LLM): response = arcee("AI-driven music therapy") """ - _client: ArceeWrapper #: :meta private: + _client: Optional[ArceeWrapper] = None #: :meta private: """Arcee _client.""" arcee_api_key: SecretStr @@ -66,20 +66,28 @@ def __init__(self, **data: Any) -> None: """Initializes private fields.""" super().__init__(**data) + self._client = ArceeWrapper( + arcee_api_key=self.arcee_api_key.get_secret_value(), + arcee_api_url=self.arcee_api_url, + arcee_api_version=self.arcee_api_version, + model_kwargs=self.model_kwargs, + model_name=self.model, + ) + self._client.validate_model_training_status() @root_validator(pre=True) # Use pre=False to pick up defaults def validate_environments(cls, values: Dict) -> Dict: """Validate Arcee environment variables.""" # validate env vars - arcee_api_key = get_from_dict_or_env( - values, - "arcee_api_key", - "ARCEE_API_KEY", + values["arcee_api_key"] = convert_to_secret_str( + get_from_dict_or_env( + values, + "arcee_api_key", + "ARCEE_API_KEY", + ) ) - values["arcee_api_key"] = convert_to_secret_str(arcee_api_key) - values["arcee_api_url"] = get_from_dict_or_env( values, "arcee_api_url", @@ -113,17 +121,6 @@ def validate_environments(cls, values: Dict) -> Dict: raise ValueError("`filters` must be a list") for f in kw.get("filters"): DALMFilter(**f) - - client = ArceeWrapper( - arcee_api_key=arcee_api_key, - arcee_api_url=values["arcee_api_url"], - arcee_api_version=values["arcee_api_version"], - model_kwargs=values.get("model_kwargs"), - model_name=values["model"], - ) - - client.validate_model_training_status() - values["_client"] = client return values def _call( From 4cef21c3dc754022749cf115ceccbc1d542de820 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 5 Dec 2023 12:33:15 -0500 Subject: [PATCH 18/22] x --- .../integration_tests/llms/test_arcee.py | 33 +------------------ 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index b66a91a82e2e5..0598c824aa7df 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -1,8 +1,7 @@ from unittest.mock import MagicMock, patch -from _pytest.capture import CaptureFixture -from _pytest.monkeypatch import MonkeyPatch from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch from langchain.llms.arcee import Arcee @@ -71,37 +70,7 @@ def test_api_key_masked_when_passed_from_env( assert "**********" == captured.out -from langchain_core.pydantic_v1 import SecretStr -from pytest import CaptureFixture, MonkeyPatch - -from langchain.llms.arcee import Arcee - def test_api_key_is_secret_string() -> None: llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key") assert isinstance(llm.arcee_api_key, SecretStr) - - -def test_api_key_masked_when_passed_from_env( - monkeypatch: MonkeyPatch, capsys: CaptureFixture -) -> None: - """Test initialization with an API key provided via an env variable""" - monkeypatch.setenv("ARCEE_API_KEY", "test-arcee-api-key") - - llm = Arcee(model="DALM-PubMed") - - print(llm.arcee_api_key, end="") - captured = capsys.readouterr() - assert captured.out == "**********" - - -def test_api_key_masked_when_passed_via_constructor( - capsys: CaptureFixture, -) -> None: - """Test initialization with an API key provided via the initializer""" - llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key") - - print(llm.arcee_api_key, end="") - captured = capsys.readouterr() - assert captured.out == "**********" - From bf285e32a24d34bb2a39bd8068d93c66dc192116 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 5 Dec 2023 12:38:57 -0500 Subject: [PATCH 19/22] x --- libs/langchain/langchain/llms/arcee.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index 94139525384fa..13745884a5778 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -75,7 +75,7 @@ def __init__(self, **data: Any) -> None: ) self._client.validate_model_training_status() - @root_validator(pre=True) # Use pre=False to pick up defaults + @root_validator(pre=False) def validate_environments(cls, values: Dict) -> Dict: """Validate Arcee environment variables.""" From c9f9d448bd4cbe5f5f78f511ddb68361003c4f65 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 5 Dec 2023 13:00:57 -0500 Subject: [PATCH 20/22] x --- libs/langchain/langchain/llms/arcee.py | 1 - libs/langchain/tests/integration_tests/llms/test_arcee.py | 5 ----- 2 files changed, 6 deletions(-) diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index 13745884a5778..6e364812848b0 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -73,7 +73,6 @@ def __init__(self, **data: Any) -> None: model_kwargs=self.model_kwargs, model_name=self.model, ) - self._client.validate_model_training_status() @root_validator(pre=False) def validate_environments(cls, values: Dict) -> Dict: diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index 0598c824aa7df..be4e40e1e79ff 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -69,8 +69,3 @@ def test_api_key_masked_when_passed_from_env( captured = capsys.readouterr() assert "**********" == captured.out - - -def test_api_key_is_secret_string() -> None: - llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key") - assert isinstance(llm.arcee_api_key, SecretStr) From 8440809752d6f84f35514619733888558aaae198 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 5 Dec 2023 13:05:46 -0500 Subject: [PATCH 21/22] x --- libs/langchain/langchain/retrievers/arcee.py | 2 +- libs/langchain/langchain/utilities/arcee.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/retrievers/arcee.py b/libs/langchain/langchain/retrievers/arcee.py index 7d3e7b822f5b2..e360f62a03f3b 100644 --- a/libs/langchain/langchain/retrievers/arcee.py +++ b/libs/langchain/langchain/retrievers/arcee.py @@ -61,7 +61,7 @@ def __init__(self, **data: Any) -> None: super().__init__(**data) self._client = ArceeWrapper( - arcee_api_key=self.arcee_api_key, + arcee_api_key=self.arcee_api_key.get_secret_value(), arcee_api_url=self.arcee_api_url, arcee_api_version=self.arcee_api_version, model_kwargs=self.model_kwargs, diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index 7781bfa4f90ad..15b890f7066ba 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -118,7 +118,7 @@ def __init__( model_kwargs: Keyword arguments for Arcee API. model_name: Name of an Arcee model. """ - if not isinstance(arcee_api_key, (str,)): + if not isinstance(arcee_api_key, str): raise TypeError( f"arcee_api_key must be a string. Got {type(arcee_api_key)}" ) From de0bf9dc066d6155ec2662175c055333e8d67b03 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 5 Dec 2023 13:14:32 -0500 Subject: [PATCH 22/22] x --- libs/langchain/langchain/llms/arcee.py | 7 ++++--- libs/langchain/langchain/utilities/arcee.py | 12 ++++++------ .../tests/integration_tests/llms/test_arcee.py | 1 - 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/libs/langchain/langchain/llms/arcee.py b/libs/langchain/langchain/llms/arcee.py index 6e364812848b0..7e83219cdd6b7 100644 --- a/libs/langchain/langchain/llms/arcee.py +++ b/libs/langchain/langchain/llms/arcee.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union, cast from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator @@ -30,7 +30,7 @@ class Arcee(LLM): _client: Optional[ArceeWrapper] = None #: :meta private: """Arcee _client.""" - arcee_api_key: SecretStr + arcee_api_key: Union[SecretStr, str, None] = None """Arcee API Key""" model: str @@ -66,8 +66,9 @@ def __init__(self, **data: Any) -> None: """Initializes private fields.""" super().__init__(**data) + api_key = cast(SecretStr, self.arcee_api_key) self._client = ArceeWrapper( - arcee_api_key=self.arcee_api_key.get_secret_value(), + arcee_api_key=api_key, arcee_api_url=self.arcee_api_url, arcee_api_version=self.arcee_api_version, model_kwargs=self.model_kwargs, diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index 15b890f7066ba..7217034858310 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -103,7 +103,7 @@ class ArceeWrapper: def __init__( self, - arcee_api_key: str, + arcee_api_key: Union[str, SecretStr], arcee_api_url: str, arcee_api_version: str, model_kwargs: Optional[Dict[str, Any]], @@ -118,11 +118,11 @@ def __init__( model_kwargs: Keyword arguments for Arcee API. model_name: Name of an Arcee model. """ - if not isinstance(arcee_api_key, str): - raise TypeError( - f"arcee_api_key must be a string. Got {type(arcee_api_key)}" - ) - self.arcee_api_key = SecretStr(arcee_api_key) + if isinstance(arcee_api_key, str): + arcee_api_key_ = SecretStr(arcee_api_key) + else: + arcee_api_key_ = arcee_api_key + self.arcee_api_key: SecretStr = arcee_api_key_ self.model_kwargs = model_kwargs self.arcee_api_url = arcee_api_url self.arcee_api_version = arcee_api_version diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index be4e40e1e79ff..40daec3682fb9 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -61,7 +61,6 @@ def test_api_key_masked_when_passed_from_env( monkeypatch.setenv("ARCEE_API_KEY", "secret_api_key") arcee_with_env_var = Arcee( model="DALM-PubMed", - arcee_api_key="", arcee_api_url="https://localhost", arcee_api_version="version", )