Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WatsonX support #10238

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions docs/extras/integrations/llms/watsonx.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# WatsonX"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is an example to use models on WatsonX.ai, to have more information on how to use WatsonX.ai you can visit this [website](https://www.ibm.com/fr-fr/products/watsonx-ai). There are several models you can use, you can also deploy custom models. To use WatsonX you will need to have an API Key."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: ibm-generative-ai in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (0.2.5)\n",
"Requirement already satisfied: urllib3<2 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from ibm-generative-ai) (1.26.16)\n",
"Requirement already satisfied: requests>=2.31.0 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from ibm-generative-ai) (2.31.0)\n",
"Requirement already satisfied: pydantic<2,>=1.10.10 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from ibm-generative-ai) (1.10.12)\n",
"Requirement already satisfied: python-dotenv>=1.0.0 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from ibm-generative-ai) (1.0.0)\n",
"Requirement already satisfied: aiohttp>=3.8.4 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from ibm-generative-ai) (3.8.5)\n",
"Requirement already satisfied: pyyaml>=0.2.5 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from ibm-generative-ai) (6.0.1)\n",
"Requirement already satisfied: httpx>=0.24.1 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from ibm-generative-ai) (0.24.1)\n",
"Requirement already satisfied: aiolimiter>=1.1.0 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from ibm-generative-ai) (1.1.0)\n",
"Requirement already satisfied: tqdm>=4.65.0 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from ibm-generative-ai) (4.66.1)\n",
"Requirement already satisfied: attrs>=17.3.0 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from aiohttp>=3.8.4->ibm-generative-ai) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from aiohttp>=3.8.4->ibm-generative-ai) (3.2.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from aiohttp>=3.8.4->ibm-generative-ai) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from aiohttp>=3.8.4->ibm-generative-ai) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from aiohttp>=3.8.4->ibm-generative-ai) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from aiohttp>=3.8.4->ibm-generative-ai) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from aiohttp>=3.8.4->ibm-generative-ai) (1.3.1)\n",
"Requirement already satisfied: certifi in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from httpx>=0.24.1->ibm-generative-ai) (2023.7.22)\n",
"Requirement already satisfied: httpcore<0.18.0,>=0.15.0 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from httpx>=0.24.1->ibm-generative-ai) (0.17.3)\n",
"Requirement already satisfied: idna in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from httpx>=0.24.1->ibm-generative-ai) (3.4)\n",
"Requirement already satisfied: sniffio in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from httpx>=0.24.1->ibm-generative-ai) (1.3.0)\n",
"Requirement already satisfied: typing-extensions>=4.2.0 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from pydantic<2,>=1.10.10->ibm-generative-ai) (4.7.1)\n",
"Requirement already satisfied: h11<0.15,>=0.13 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from httpcore<0.18.0,>=0.15.0->httpx>=0.24.1->ibm-generative-ai) (0.14.0)\n",
"Requirement already satisfied: anyio<5.0,>=3.0 in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from httpcore<0.18.0,>=0.15.0->httpx>=0.24.1->ibm-generative-ai) (3.7.1)\n",
"Requirement already satisfied: exceptiongroup in /Users/bignaudbaptiste/Desktop/langchain/.venv/lib/python3.10/site-packages (from anyio<5.0,>=3.0->httpcore<0.18.0,>=0.15.0->httpx>=0.24.1->ibm-generative-ai) (1.1.3)\n"
]
}
],
"source": [
"!pip install ibm-generative-ai"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms.watsonx import Watsonx\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"WATSONX_API_KEY\"] = \"<You Watson API key here>\"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Artificial intelligence (AI) is the simulation of human intelligence processes by machines, especially computer systems. These processes include learning (the acquisition of information and rules for using the information), reasoning (using rules to reach approximate or definite conclusions), and self-correction. AI research has been divided into subfields that often fail to communicate with each other.\n",
"The term artificial intelligence was coined in 1956, but AI research began in the 1940s and 1950\n"
]
}
],
"source": [
"llm = Watsonx(repetition_penalty=2)\n",
"print(llm._call(\"What is Artificial Intelligence?\"))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
80 changes: 80 additions & 0 deletions libs/langchain/langchain/llms/watsonx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import (
Callbacks,
)
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field
from langchain.utils import get_from_env


def _get_api_key() -> str:
return get_from_env("api_key", "WATSONX_API_KEY")


class Watsonx(LLM):
"""WatsonX LLM wrapper."""

model_name: str = "tiiuae/falcon-40b"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you be able to document the parameters? Especially ones that might not be obvious for interpretation?

Here's a good example of the format:

max_tokens_to_sample: int = Field(default=256, alias="max_tokens")

api_key: str = Field(default_factory=_get_api_key)
decoding_method: str = "sample"
temperature: float = 0.05
top_p: float = 1
top_k: int = 50
min_new_tokens: int = 1
max_new_tokens: int = 100
api_endpoint: str = "https://workbench-api.res.ibm.com/v1"
repetition_penalty: Optional[float] = None
random_seed: Optional[int] = None
stop_sequences: Optional[List[str]] = None
truncate_input_tokens: Optional[int] = None

@property
def _llm_type(self) -> str:
return "watsonx"

def _call(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you match the method signature exactly:

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the LLM on the given prompt and input."""

self,
prompt: str,
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
try:
import genai
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Package source: https://ibm.github.io/ibm-generative-ai/index.html

Looks like a name collision with: https://pypi.org/project/genai/ ?

except ImportError as e:
raise ImportError(
"Cannot import genai, please install with "
"`pip install ibm-generative-ai`."
) from e

creds = genai.credentials.Credentials(api_key=self.api_key)
params = self._identifying_params.copy()
params["stop_sequences"] = stop or params["stop_sequences"]
gen_params = genai.schemas.generate_params.GenerateParams(
**params,
)
model = genai.model.Model(
model=self.model_name, params=gen_params, credentials=creds
)
out = model.generate(prompts=[prompt])
return out[0].generated_text

@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {
"decoding_method": self.decoding_method,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"min_new_tokens": self.min_new_tokens,
"max_new_tokens": self.max_new_tokens,
"repetition_penalty": self.repetition_penalty,
"random_seed": self.random_seed,
"stop_sequences": self.stop_sequences,
"truncate_input_tokens": self.truncate_input_tokens,
}
Loading