Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: use api token from Replicate constructor for service access #27859

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions libs/community/langchain_community/llms/replicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from langchain_core.utils import get_from_dict_or_env, pre_init
from langchain_core.utils.pydantic import get_fields
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self

if TYPE_CHECKING:
from replicate.client import Client
from replicate.prediction import Prediction

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -56,6 +58,8 @@ class Replicate(LLM):
stop: List[str] = Field(default_factory=list)
"""Stop sequences to early-terminate generation."""

_client: Client = None

model_config = ConfigDict(
populate_by_name=True,
extra="forbid",
Expand Down Expand Up @@ -98,6 +102,19 @@ def build_extra(cls, values: Dict[str, Any]) -> Any:
values["model_kwargs"] = extra
return values

@model_validator(mode="after")
def set_client(self) -> Self:
"""Add a client to the values."""
try:
from replicate.client import Client
except ImportError:
raise ImportError(
"Could not import replicate python package. "
"Please install it with `pip install replicate`."
)
self._client = Client(api_token=self.replicate_api_token)
return self

@pre_init
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
Expand Down Expand Up @@ -188,22 +205,14 @@ def _stream(
break

def _create_prediction(self, prompt: str, **kwargs: Any) -> Prediction:
try:
import replicate as replicate_python
except ImportError:
raise ImportError(
"Could not import replicate python package. "
"Please install it with `pip install replicate`."
)

# get the model and version
if self.version_obj is None:
if ":" in self.model:
model_str, version_str = self.model.split(":")
model = replicate_python.models.get(model_str)
model = self._client.models.get(model_str)
self.version_obj = model.versions.get(version_str)
else:
model = replicate_python.models.get(self.model)
model = self._client.models.get(self.model)
self.version_obj = model.latest_version

if self.prompt_key is None:
Expand All @@ -225,8 +234,8 @@ def _create_prediction(self, prompt: str, **kwargs: Any) -> Prediction:

# if it's an official model
if ":" not in self.model:
return replicate_python.models.predictions.create(self.model, input=input_)
return self._client.models.predictions.create(self.model, input=input_)
else:
return replicate_python.predictions.create(
return self._client.predictions.create(
version=self.version_obj, input=input_
)
4 changes: 2 additions & 2 deletions libs/community/pyproject.toml
efriis marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ select = [ "E", "F", "I", "T201",]
omit = [ "tests/*",]

[tool.pytest.ini_options]
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv"
addopts = "--strict-markers --strict-config --durations=5 -vv"
efriis marked this conversation as resolved.
Show resolved Hide resolved
markers = [ "requires: mark tests as requiring a specific library", "scheduled: mark tests to run in scheduled testing", "compile: mark placeholder test used to compile integration tests without running them",]
asyncio_mode = "auto"
# asyncio_mode = "auto"
efriis marked this conversation as resolved.
Show resolved Hide resolved
filterwarnings = [ "ignore::langchain_core._api.beta_decorator.LangChainBetaWarning", "ignore::langchain_core._api.deprecation.LangChainDeprecationWarning:test", "ignore::langchain_core._api.deprecation.LangChainPendingDeprecationWarning:test",]

[tool.poetry.group.test]
Expand Down
18 changes: 18 additions & 0 deletions libs/community/tests/integration_tests/llms/test_replicate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test Replicate API wrapper."""

import os

from langchain_community.llms.replicate import Replicate
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

Expand Down Expand Up @@ -47,3 +49,19 @@ def test_replicate_model_kwargs() -> None:
def test_replicate_input() -> None:
llm = Replicate(model=TEST_MODEL_LANG, input={"max_new_tokens": 10})
assert llm.model_kwargs == {"max_new_tokens": 10}


def test_replicate_api_token_propagation() -> None:
"""Test that API token passed to the model is used to access the service."""
# Grab the api token from the environment variable.
api_token = os.getenv("REPLICATE_API_TOKEN")

# Reset the environment variable to ensure it's not available.
os.environ["REPLICATE_API_TOKEN"] = "yo"

# Pass the api token into the model.
llm = Replicate(model=TEST_MODEL_HELLO, replicate_api_token=api_token)
output = llm.invoke("What is a duck?")

assert output
assert isinstance(output, str)
Loading