diff --git a/libs/community/langchain_community/llms/replicate.py b/libs/community/langchain_community/llms/replicate.py index f7f317cb3fa43..64e8e3425338a 100644 --- a/libs/community/langchain_community/llms/replicate.py +++ b/libs/community/langchain_community/llms/replicate.py @@ -13,14 +13,7 @@ if TYPE_CHECKING: from replicate.prediction import Prediction - -try: from replicate.client import Client -except ImportError: - raise ImportError( - "Could not import replicate python package. " - "Please install it with `pip install replicate`." - ) logger = logging.getLogger(__name__) @@ -65,7 +58,7 @@ class Replicate(LLM): stop: List[str] = Field(default_factory=list) """Stop sequences to early-terminate generation.""" - _client: Client = Client() + _client: Client = None model_config = ConfigDict( populate_by_name=True, @@ -112,6 +105,13 @@ def build_extra(cls, values: Dict[str, Any]) -> Any: @model_validator(mode="after") def set_client(self) -> Self: """Add a client to the values.""" + try: + from replicate.client import Client + except ImportError: + raise ImportError( + "Could not import replicate python package. " + "Please install it with `pip install replicate`." + ) self._client = Client(api_token=self.replicate_api_token) return self