From d1d132a68634bf06d3e8e6f4ace00524717626ff Mon Sep 17 00:00:00 2001 From: Fayvor Love Date: Sat, 2 Nov 2024 09:42:01 -0400 Subject: [PATCH 1/9] Add test for api token propagation (fails). Signed-off-by: Fayvor Love --- .../integration_tests/llms/test_replicate.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/libs/community/tests/integration_tests/llms/test_replicate.py b/libs/community/tests/integration_tests/llms/test_replicate.py index dbf4ba2ff1100..85fdeb33d329d 100644 --- a/libs/community/tests/integration_tests/llms/test_replicate.py +++ b/libs/community/tests/integration_tests/llms/test_replicate.py @@ -2,6 +2,7 @@ from langchain_community.llms.replicate import Replicate from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +import os TEST_MODEL_HELLO = ( "replicate/hello-world:" @@ -47,3 +48,20 @@ def test_replicate_model_kwargs() -> None: def test_replicate_input() -> None: llm = Replicate(model=TEST_MODEL_LANG, input={"max_new_tokens": 10}) assert llm.model_kwargs == {"max_new_tokens": 10} + + +def test_replicate_api_key_propagation() -> None: + """Test that API key passed to the model is used to access the service.""" + # Grab the api token from the environment variable. + api_token = os.getenv("REPLICATE_API_TOKEN") + + # Reset the environment variable to ensure it's not available. + os.environ["REPLICATE_API_TOKEN"] = "yo" + + # Pass the api token into the model. + llm = Replicate(model=TEST_MODEL_HELLO, replicate_api_token=api_token) + output = llm.invoke("What is a duck?") + + assert output + assert isinstance(output, str) + \ No newline at end of file From f174ea682dc21e8318e549cd9d19614ffa54d83c Mon Sep 17 00:00:00 2001 From: Fayvor Love Date: Sat, 2 Nov 2024 10:46:21 -0400 Subject: [PATCH 2/9] Clean up test. Signed-off-by: Fayvor Love --- libs/community/tests/integration_tests/llms/test_replicate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/community/tests/integration_tests/llms/test_replicate.py b/libs/community/tests/integration_tests/llms/test_replicate.py index 85fdeb33d329d..8dc4b8b7a984b 100644 --- a/libs/community/tests/integration_tests/llms/test_replicate.py +++ b/libs/community/tests/integration_tests/llms/test_replicate.py @@ -50,8 +50,8 @@ def test_replicate_input() -> None: assert llm.model_kwargs == {"max_new_tokens": 10} -def test_replicate_api_key_propagation() -> None: - """Test that API key passed to the model is used to access the service.""" +def test_replicate_api_token_propagation() -> None: + """Test that API token passed to the model is used to access the service.""" # Grab the api token from the environment variable. api_token = os.getenv("REPLICATE_API_TOKEN") From b1096c82ee93b0478327da871977e6e213461d38 Mon Sep 17 00:00:00 2001 From: Fayvor Love Date: Sat, 2 Nov 2024 11:37:52 -0400 Subject: [PATCH 3/9] Use internal client for Replicate to access services. Signed-off-by: Fayvor Love --- .../langchain_community/llms/replicate.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/libs/community/langchain_community/llms/replicate.py b/libs/community/langchain_community/llms/replicate.py index f6c4e15ba621c..061feed54bb4c 100644 --- a/libs/community/langchain_community/llms/replicate.py +++ b/libs/community/langchain_community/llms/replicate.py @@ -2,6 +2,7 @@ import logging from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional +from typing_extensions import Self from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM @@ -9,6 +10,7 @@ from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_core.utils.pydantic import get_fields from pydantic import ConfigDict, Field, model_validator +from replicate.client import Client if TYPE_CHECKING: from replicate.prediction import Prediction @@ -56,6 +58,8 @@ class Replicate(LLM): stop: List[str] = Field(default_factory=list) """Stop sequences to early-terminate generation.""" + _client: Client = Client() + model_config = ConfigDict( populate_by_name=True, extra="forbid", @@ -98,6 +102,12 @@ def build_extra(cls, values: Dict[str, Any]) -> Any: values["model_kwargs"] = extra return values + @model_validator(mode="after") + def set_client(self) -> Self: + """Add a client to the values.""" + self._client = Client(api_token=self.replicate_api_token) + return self + @pre_init def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" @@ -188,22 +198,14 @@ def _stream( break def _create_prediction(self, prompt: str, **kwargs: Any) -> Prediction: - try: - import replicate as replicate_python - except ImportError: - raise ImportError( - "Could not import replicate python package. " - "Please install it with `pip install replicate`." - ) - # get the model and version if self.version_obj is None: if ":" in self.model: model_str, version_str = self.model.split(":") - model = replicate_python.models.get(model_str) + model = self._client.models.get(model_str) self.version_obj = model.versions.get(version_str) else: - model = replicate_python.models.get(self.model) + model = self._client.models.get(self.model) self.version_obj = model.latest_version if self.prompt_key is None: @@ -225,8 +227,8 @@ def _create_prediction(self, prompt: str, **kwargs: Any) -> Prediction: # if it's an official model if ":" not in self.model: - return replicate_python.models.predictions.create(self.model, input=input_) + return self._client.models.predictions.create(self.model, input=input_) else: - return replicate_python.predictions.create( + return self._client.predictions.create( version=self.version_obj, input=input_ ) From 72e242bb44276c9f98cd456051b21bcc46d8649f Mon Sep 17 00:00:00 2001 From: Fayvor Love Date: Sat, 2 Nov 2024 11:53:04 -0400 Subject: [PATCH 4/9] Format code. Signed-off-by: Fayvor Love --- libs/community/langchain_community/llms/replicate.py | 2 +- libs/community/tests/integration_tests/llms/test_replicate.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/llms/replicate.py b/libs/community/langchain_community/llms/replicate.py index 061feed54bb4c..43a2d81283601 100644 --- a/libs/community/langchain_community/llms/replicate.py +++ b/libs/community/langchain_community/llms/replicate.py @@ -2,7 +2,6 @@ import logging from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional -from typing_extensions import Self from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM @@ -11,6 +10,7 @@ from langchain_core.utils.pydantic import get_fields from pydantic import ConfigDict, Field, model_validator from replicate.client import Client +from typing_extensions import Self if TYPE_CHECKING: from replicate.prediction import Prediction diff --git a/libs/community/tests/integration_tests/llms/test_replicate.py b/libs/community/tests/integration_tests/llms/test_replicate.py index 8dc4b8b7a984b..50cf45305d6c1 100644 --- a/libs/community/tests/integration_tests/llms/test_replicate.py +++ b/libs/community/tests/integration_tests/llms/test_replicate.py @@ -1,8 +1,9 @@ """Test Replicate API wrapper.""" +import os + from langchain_community.llms.replicate import Replicate from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler -import os TEST_MODEL_HELLO = ( "replicate/hello-world:" @@ -64,4 +65,3 @@ def test_replicate_api_token_propagation() -> None: assert output assert isinstance(output, str) - \ No newline at end of file From b581070ad53b9b9cf3a6d28d53ad049ebe65449b Mon Sep 17 00:00:00 2001 From: Fayvor Love Date: Sat, 2 Nov 2024 12:26:11 -0400 Subject: [PATCH 5/9] Wrap replicate package import in try/except. Signed-off-by: Fayvor Love --- libs/community/langchain_community/llms/replicate.py | 9 ++++++++- libs/community/pyproject.toml | 4 ++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/llms/replicate.py b/libs/community/langchain_community/llms/replicate.py index 43a2d81283601..f7f317cb3fa43 100644 --- a/libs/community/langchain_community/llms/replicate.py +++ b/libs/community/langchain_community/llms/replicate.py @@ -9,12 +9,19 @@ from langchain_core.utils import get_from_dict_or_env, pre_init from langchain_core.utils.pydantic import get_fields from pydantic import ConfigDict, Field, model_validator -from replicate.client import Client from typing_extensions import Self if TYPE_CHECKING: from replicate.prediction import Prediction +try: + from replicate.client import Client +except ImportError: + raise ImportError( + "Could not import replicate python package. " + "Please install it with `pip install replicate`." + ) + logger = logging.getLogger(__name__) diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index 822c2b7b94995..1eba41f652ed7 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -56,9 +56,9 @@ select = [ "E", "F", "I", "T201",] omit = [ "tests/*",] [tool.pytest.ini_options] -addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv" +addopts = "--strict-markers --strict-config --durations=5 -vv" markers = [ "requires: mark tests as requiring a specific library", "scheduled: mark tests to run in scheduled testing", "compile: mark placeholder test used to compile integration tests without running them",] -asyncio_mode = "auto" +# asyncio_mode = "auto" filterwarnings = [ "ignore::langchain_core._api.beta_decorator.LangChainBetaWarning", "ignore::langchain_core._api.deprecation.LangChainDeprecationWarning:test", "ignore::langchain_core._api.deprecation.LangChainPendingDeprecationWarning:test",] [tool.poetry.group.test] From 28169fe6f8f309aa299b8c6bed4f212393873bec Mon Sep 17 00:00:00 2001 From: Fayvor Love Date: Sat, 2 Nov 2024 12:30:15 -0400 Subject: [PATCH 6/9] Import replicate package inside constructor. Signed-off-by: Fayvor Love --- .../langchain_community/llms/replicate.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/libs/community/langchain_community/llms/replicate.py b/libs/community/langchain_community/llms/replicate.py index f7f317cb3fa43..64e8e3425338a 100644 --- a/libs/community/langchain_community/llms/replicate.py +++ b/libs/community/langchain_community/llms/replicate.py @@ -13,14 +13,7 @@ if TYPE_CHECKING: from replicate.prediction import Prediction - -try: from replicate.client import Client -except ImportError: - raise ImportError( - "Could not import replicate python package. " - "Please install it with `pip install replicate`." - ) logger = logging.getLogger(__name__) @@ -65,7 +58,7 @@ class Replicate(LLM): stop: List[str] = Field(default_factory=list) """Stop sequences to early-terminate generation.""" - _client: Client = Client() + _client: Client = None model_config = ConfigDict( populate_by_name=True, @@ -112,6 +105,13 @@ def build_extra(cls, values: Dict[str, Any]) -> Any: @model_validator(mode="after") def set_client(self) -> Self: """Add a client to the values.""" + try: + from replicate.client import Client + except ImportError: + raise ImportError( + "Could not import replicate python package. " + "Please install it with `pip install replicate`." + ) self._client = Client(api_token=self.replicate_api_token) return self From 67cae610402c2e7537bda0aa2d8437ef5caaee02 Mon Sep 17 00:00:00 2001 From: Fayvor Love Date: Sat, 2 Nov 2024 12:32:53 -0400 Subject: [PATCH 7/9] Format code. Signed-off-by: Fayvor Love --- libs/community/langchain_community/llms/replicate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/community/langchain_community/llms/replicate.py b/libs/community/langchain_community/llms/replicate.py index 64e8e3425338a..70ea672f24fea 100644 --- a/libs/community/langchain_community/llms/replicate.py +++ b/libs/community/langchain_community/llms/replicate.py @@ -12,8 +12,8 @@ from typing_extensions import Self if TYPE_CHECKING: - from replicate.prediction import Prediction from replicate.client import Client + from replicate.prediction import Prediction logger = logging.getLogger(__name__) From 1d09a37bbbd2bdbd3e3b675157b4b440d8c2fb41 Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Fri, 13 Dec 2024 16:17:10 -0800 Subject: [PATCH 8/9] Update libs/community/pyproject.toml --- libs/community/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index 1eba41f652ed7..ecdfe3590c72d 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -56,7 +56,7 @@ select = [ "E", "F", "I", "T201",] omit = [ "tests/*",] [tool.pytest.ini_options] -addopts = "--strict-markers --strict-config --durations=5 -vv" +addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv" markers = [ "requires: mark tests as requiring a specific library", "scheduled: mark tests to run in scheduled testing", "compile: mark placeholder test used to compile integration tests without running them",] # asyncio_mode = "auto" filterwarnings = [ "ignore::langchain_core._api.beta_decorator.LangChainBetaWarning", "ignore::langchain_core._api.deprecation.LangChainDeprecationWarning:test", "ignore::langchain_core._api.deprecation.LangChainPendingDeprecationWarning:test",] From df8a0805d91c1e064502ebf933cfc93d597ee2bf Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Fri, 13 Dec 2024 16:17:17 -0800 Subject: [PATCH 9/9] Update libs/community/pyproject.toml --- libs/community/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index ecdfe3590c72d..822c2b7b94995 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -58,7 +58,7 @@ omit = [ "tests/*",] [tool.pytest.ini_options] addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv" markers = [ "requires: mark tests as requiring a specific library", "scheduled: mark tests to run in scheduled testing", "compile: mark placeholder test used to compile integration tests without running them",] -# asyncio_mode = "auto" +asyncio_mode = "auto" filterwarnings = [ "ignore::langchain_core._api.beta_decorator.LangChainBetaWarning", "ignore::langchain_core._api.deprecation.LangChainDeprecationWarning:test", "ignore::langchain_core._api.deprecation.LangChainPendingDeprecationWarning:test",] [tool.poetry.group.test]