Skip to content

Commit

Permalink
Add support for passing token to initialise ChatWatsonx on cloud (#44)
Browse files Browse the repository at this point in the history
* Add support for passing token to initialise ChatWatsonx on cloud
  • Loading branch information
MateuszOssGit authored Dec 2, 2024
1 parent ec5b5f0 commit 3d74530
Show file tree
Hide file tree
Showing 10 changed files with 372 additions and 268 deletions.
24 changes: 22 additions & 2 deletions libs/ibm/langchain_ibm/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,8 @@ class ChatWatsonx(BaseChatModel):
"""ID of the Watson Studio space."""

url: SecretStr = Field(
alias="url", default_factory=secret_from_env("WATSONX_URL", default=None)
alias="url",
default_factory=secret_from_env("WATSONX_URL", default=None), # type: ignore[assignment]
)
"""URL to the Watson Machine Learning or CPD instance."""

Expand Down Expand Up @@ -507,6 +508,9 @@ class ChatWatsonx(BaseChatModel):
We generally recommend altering this or top_p but not both."""

response_format: Optional[dict] = None
"""The chat response format parameters."""

top_p: Optional[float] = None
"""An alternative to sampling with temperature, called nucleus sampling,
where the model considers the results of the tokens with top_p probability
Expand All @@ -515,6 +519,10 @@ class ChatWatsonx(BaseChatModel):
We generally recommend altering this or temperature but not both."""

time_limit: Optional[int] = None
"""Time limit in milliseconds - if not completed within this time,
generation will stop."""

verify: Union[str, bool, None] = None
"""You can pass one of following as verify:
* the path to a CA_BUNDLE file
Expand Down Expand Up @@ -596,8 +604,10 @@ def validate_environment(self) -> Self:
"max_tokens": self.max_tokens,
"n": self.n,
"presence_penalty": self.presence_penalty,
"response_format": self.response_format,
"temperature": self.temperature,
"top_p": self.top_p,
"time_limit": self.time_limit,
}.items()
if v is not None
}
Expand All @@ -619,7 +629,15 @@ def validate_environment(self) -> Self:
check_for_attribute(self.url, "url", "WATSONX_URL")

if "cloud.ibm.com" in self.url.get_secret_value():
check_for_attribute(self.apikey, "apikey", "WATSONX_APIKEY")
if not self.token and not self.apikey:
raise ValueError(
"Did not find 'apikey' or 'token',"
" please add an environment variable"
" `WATSONX_APIKEY` or 'WATSONX_TOKEN' "
"which contains it,"
" or pass 'apikey' or 'token'"
" as a named parameter."
)
else:
if not self.token and not self.password and not self.apikey:
raise ValueError(
Expand Down Expand Up @@ -753,8 +771,10 @@ def _merge_params(params: dict, kwargs: dict) -> dict:
"max_tokens",
"n",
"presence_penalty",
"response_format",
"temperature",
"top_p",
"time_limit",
]:
if kwargs.get(k) is not None:
param_updates[k] = kwargs.pop(k)
Expand Down
13 changes: 11 additions & 2 deletions libs/ibm/langchain_ibm/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class WatsonxEmbeddings(BaseModel, LangChainEmbeddings):
"""ID of the Watson Studio space."""

url: SecretStr = Field(
alias="url", default_factory=secret_from_env("WATSONX_URL", default=None)
alias="url",
default_factory=secret_from_env("WATSONX_URL", default=None), # type: ignore[assignment]
)
"""URL to the Watson Machine Learning or CPD instance."""

Expand Down Expand Up @@ -99,7 +100,15 @@ def validate_environment(self) -> Self:
check_for_attribute(self.url, "url", "WATSONX_URL")

if "cloud.ibm.com" in self.url.get_secret_value():
check_for_attribute(self.apikey, "apikey", "WATSONX_APIKEY")
if not self.token and not self.apikey:
raise ValueError(
"Did not find 'apikey' or 'token',"
" please add an environment variable"
" `WATSONX_APIKEY` or 'WATSONX_TOKEN' "
"which contains it,"
" or pass 'apikey' or 'token'"
" as a named parameter."
)
else:
if not self.token and not self.password and not self.apikey:
raise ValueError(
Expand Down
13 changes: 11 additions & 2 deletions libs/ibm/langchain_ibm/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class WatsonxLLM(BaseLLM):
"""ID of the Watson Studio space."""

url: SecretStr = Field(
alias="url", default_factory=secret_from_env("WATSONX_URL", default=None)
alias="url",
default_factory=secret_from_env("WATSONX_URL", default=None), # type: ignore[assignment]
)
"""URL to the Watson Machine Learning or CPD instance."""

Expand Down Expand Up @@ -182,7 +183,15 @@ def validate_environment(self) -> Self:
check_for_attribute(self.url, "url", "WATSONX_URL")

if "cloud.ibm.com" in self.url.get_secret_value():
check_for_attribute(self.apikey, "apikey", "WATSONX_APIKEY")
if not self.token and not self.apikey:
raise ValueError(
"Did not find 'apikey' or 'token',"
" please add an environment variable"
" `WATSONX_APIKEY` or 'WATSONX_TOKEN' "
"which contains it,"
" or pass 'apikey' or 'token'"
" as a named parameter."
)
else:
if not self.token and not self.password and not self.apikey:
raise ValueError(
Expand Down
13 changes: 11 additions & 2 deletions libs/ibm/langchain_ibm/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class WatsonxRerank(BaseDocumentCompressor):
"""ID of the Watson Studio space."""

url: SecretStr = Field(
alias="url", default_factory=secret_from_env("WATSONX_URL", default=None)
alias="url",
default_factory=secret_from_env("WATSONX_URL", default=None), # type: ignore[assignment]
)
"""URL to the Watson Machine Learning or CPD instance."""

Expand Down Expand Up @@ -132,7 +133,15 @@ def validate_environment(self) -> Self:
check_for_attribute(self.url, "url", "WATSONX_URL")

if "cloud.ibm.com" in self.url.get_secret_value():
check_for_attribute(self.apikey, "apikey", "WATSONX_APIKEY")
if not self.token and not self.apikey:
raise ValueError(
"Did not find 'apikey' or 'token',"
" please add an environment variable"
" `WATSONX_APIKEY` or 'WATSONX_TOKEN' "
"which contains it,"
" or pass 'apikey' or 'token'"
" as a named parameter."
)
else:
if not self.token and not self.password and not self.apikey:
raise ValueError(
Expand Down
Loading

0 comments on commit 3d74530

Please sign in to comment.