Compatibility Implementation for Ragas.io with WatsonxLLM Parameters Handling (#5)

* To ensure compatibility with ragas.io, two methods were added: one that ensures only valid parameters are fed into WatsonxLLM, and one that adds the capability to receive parameters outside the params dictionary, which is how ragas invokes the generate method (see the sketch after this list).

* Added test case for invoke with parameters outside params dictionary

* Renamed test function

* Fixed the _override_chat_params and _validate_chat_params methods, added support for streaming, and ran make format and make lint
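
A minimal sketch of the two invocation styles this change supports (the model id, project id, and prompts below are placeholders, not values from this commit):

    from langchain_ibm import WatsonxLLM

    llm = WatsonxLLM(
        model_id="ibm/granite-13b-instruct-v2",  # placeholder model
        url="https://us-south.ml.cloud.ibm.com",
        project_id="<your-project-id>",  # placeholder
        params={"max_new_tokens": 10},
    )

    # Established style: generation parameters inside the params dictionary.
    llm.invoke("What is a molecule?", params={"temperature": 0.6})

    # Ragas-style: parameters passed as plain keyword arguments; the new
    # _override_chat_params folds any valid ones into params.
    llm.invoke("What is a molecule?", temperature=0.6)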

---------

Co-authored-by: julioe-sanchezd <[email protected]>
Co-authored-by: Mateusz Szewczyk <[email protected]>
3 people authored Jul 22, 2024
1 parent b9a5318 commit 0c81f32
Showing 2 changed files with 190 additions and 10 deletions.
58 changes: 48 additions & 10 deletions libs/ibm/langchain_ibm/llms.py
@@ -1,16 +1,20 @@
 import logging
 import os
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Union
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union

 from ibm_watsonx_ai import Credentials  # type: ignore
 from ibm_watsonx_ai.foundation_models import Model, ModelInference  # type: ignore
+from ibm_watsonx_ai.metanames import GenTextParamsMetaNames  # type: ignore
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env

 logger = logging.getLogger(__name__)
+textgen_valid_params = [
+    value for key, value in GenTextParamsMetaNames.__dict__.items() if key.isupper()
+]


 class WatsonxLLM(BaseLLM):
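
The textgen_valid_params comprehension above collects the values of every UPPERCASE attribute of GenTextParamsMetaNames, i.e. the canonical generation-parameter names. A standalone sketch of the pattern, with the metanames class stubbed out (the real one lives in ibm_watsonx_ai.metanames):

    # Hypothetical stub standing in for ibm_watsonx_ai's GenTextParamsMetaNames.
    class GenTextParamsMetaNames:
        DECODING_METHOD = "decoding_method"
        MAX_NEW_TOKENS = "max_new_tokens"
        _helper = "skipped"  # non-uppercase attributes are filtered out

    textgen_valid_params = [
        value for key, value in GenTextParamsMetaNames.__dict__.items() if key.isupper()
    ]
    print(textgen_valid_params)  # ['decoding_method', 'max_new_tokens']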
@@ -265,18 +269,50 @@ def get_count_value(key: str, result: Dict[str, Any]) -> int:
"input_token_count": input_token_count,
}

@staticmethod
def _validate_chat_params(
params: Dict[str, Any],
) -> Dict[str, Any]:
"""Validate and fix the chat parameters"""
for param in params.keys():
if param.lower() not in textgen_valid_params:
raise Exception(
f"Parameter {param} is not valid. "
f"Valid parameters are: {textgen_valid_params}"
)
return params

@staticmethod
def _override_chat_params(
params: Dict[str, Any], **kwargs: Any
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Override class parameters with those provided in the invoke method.
Merges the 'params' dictionary with any 'params' found in kwargs,
then updates 'params' with matching keys from kwargs and removes
those keys from kwargs.
"""
for key in list(kwargs.keys()):
if key.lower() in textgen_valid_params:
params[key] = kwargs.pop(key)
return params, kwargs

def _get_chat_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Optional[Dict[str, Any]]:
params = {**self.params} if self.params else {}
params = params | {**kwargs.get("params", {})}
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
params = (
{**self.params, **kwargs.pop("params", {})}
if self.params
else kwargs.pop("params", {})
)
params, kwargs = self._override_chat_params(params, **kwargs)
if stop is not None:
if params and "stop_sequences" in params:
raise ValueError(
"`stop_sequences` found in both the input and default params."
)
params = (params or {}) | {"stop_sequences": stop}
return params
return params, kwargs

def _create_llm_result(self, response: List[dict]) -> LLMResult:
"""Create the LLMResult from the choices and prompts."""
@@ -360,7 +396,8 @@ def _generate(
                 response = watsonx_llm.generate(["What is a molecule"])
         """
-        params = self._get_chat_params(stop=stop, **kwargs)
+        params, kwargs = self._get_chat_params(stop=stop, **kwargs)
+        params = self._validate_chat_params(params)
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
             if len(prompts) > 1:
@@ -383,7 +420,7 @@
             return LLMResult(generations=[[generation]])
         else:
             response = self.watsonx_model.generate(
-                prompt=prompts, **(kwargs | {"params": params})
+                prompt=prompts, params=params, **kwargs
             )
             return self._create_llm_result(response)

@@ -406,11 +443,12 @@ def _stream(
                 response = watsonx_llm.stream("What is a molecule")
                 for chunk in response:
-                    print(chunk, end='')
+                    print(chunk, end='', flush=True)
         """
-        params = self._get_chat_params(stop=stop, **kwargs)
+        params, kwargs = self._get_chat_params(stop=stop, **kwargs)
+        params = self._validate_chat_params(params)
         for stream_resp in self.watsonx_model.generate_text_stream(
-            prompt=prompt, raw_response=True, **(kwargs | {"params": params})
+            prompt=prompt, params=params, **(kwargs | {"raw_response": True})
         ):
             if not isinstance(stream_resp, dict):
                 stream_resp = stream_resp.dict()
142 changes: 142 additions & 0 deletions libs/ibm/tests/integration_tests/test_llms.py
@@ -91,6 +91,27 @@ def test_watsonxllm_invoke_with_params_3() -> None:
     assert len(response) > 0


+def test_watsonxllm_invoke_with_params_4() -> None:
+    parameters_1 = {
+        GenTextParamsMetaNames.DECODING_METHOD: "sample",
+        GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
+    }
+    parameters_2 = {
+        "temperature": 0.6,
+    }
+
+    watsonxllm = WatsonxLLM(
+        model_id=MODEL_ID,
+        url="https://us-south.ml.cloud.ibm.com",  # type: ignore[arg-type]
+        project_id=WX_PROJECT_ID,
+        params=parameters_1,
+    )
+    response = watsonxllm.invoke("What color sunflower is?", **parameters_2)  # type: ignore[arg-type]
+    print(f"\nResponse: {response}")
+    assert isinstance(response, str)
+    assert len(response) > 0
+
+
 def test_watsonxllm_generate() -> None:
     watsonxllm = WatsonxLLM(
         model_id=MODEL_ID,
@@ -140,6 +161,18 @@ def test_watsonxllm_generate_with_multiple_prompts() -> None:
     assert len(response_text) > 0


+def test_watsonxllm_invoke_with_guardrails() -> None:
+    watsonxllm = WatsonxLLM(
+        model_id=MODEL_ID,
+        url="https://us-south.ml.cloud.ibm.com",  # type: ignore[arg-type]
+        project_id=WX_PROJECT_ID,
+    )
+    response = watsonxllm.invoke("What color sunflower is?", guardrails=True)
+    print(f"\nResponse: {response}")
+    assert isinstance(response, str)
+    assert len(response) > 0
+
+
 def test_watsonxllm_generate_stream() -> None:
     watsonxllm = WatsonxLLM(
         model_id=MODEL_ID,
@@ -177,6 +210,115 @@ def test_watsonxllm_stream() -> None:
), "Linked text stream are not the same as generated text"


def test_watsonxllm_stream_with_kwargs() -> None:
watsonxllm = WatsonxLLM(
model_id=MODEL_ID,
url="https://us-south.ml.cloud.ibm.com", # type: ignore[arg-type]
project_id=WX_PROJECT_ID,
)
stream_response = watsonxllm.stream("What color sunflower is?", raw_response=True)

for chunk in stream_response:
assert isinstance(
chunk, str
), f"chunk expect type '{str}', actual '{type(chunk)}'"


def test_watsonxllm_stream_with_params() -> None:
parameters = {
GenTextParamsMetaNames.DECODING_METHOD: "greedy",
GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
GenTextParamsMetaNames.MIN_NEW_TOKENS: 5,
}
watsonxllm = WatsonxLLM(
model_id=MODEL_ID,
url="https://us-south.ml.cloud.ibm.com", # type: ignore[arg-type]
project_id=WX_PROJECT_ID,
params=parameters,
)
response = watsonxllm.invoke("What color sunflower is?")
print(f"\nResponse: {response}")

stream_response = watsonxllm.stream("What color sunflower is?")

linked_text_stream = ""
for chunk in stream_response:
assert isinstance(
chunk, str
), f"chunk expect type '{str}', actual '{type(chunk)}'"
linked_text_stream += chunk
print(f"Linked text stream: {linked_text_stream}")
assert (
response == linked_text_stream
), "Linked text stream are not the same as generated text"


def test_watsonxllm_stream_with_params_2() -> None:
parameters = {
GenTextParamsMetaNames.DECODING_METHOD: "greedy",
GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
GenTextParamsMetaNames.MIN_NEW_TOKENS: 5,
}
watsonxllm = WatsonxLLM(
model_id=MODEL_ID,
url="https://us-south.ml.cloud.ibm.com", # type: ignore[arg-type]
project_id=WX_PROJECT_ID,
)
stream_response = watsonxllm.stream("What color sunflower is?", params=parameters)

for chunk in stream_response:
assert isinstance(
chunk, str
), f"chunk expect type '{str}', actual '{type(chunk)}'"
print(chunk)


def test_watsonxllm_stream_with_params_3() -> None:
parameters_1 = {
GenTextParamsMetaNames.DECODING_METHOD: "sample",
GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
}
parameters_2 = {
GenTextParamsMetaNames.MIN_NEW_TOKENS: 5,
}
watsonxllm = WatsonxLLM(
model_id=MODEL_ID,
url="https://us-south.ml.cloud.ibm.com", # type: ignore[arg-type]
project_id=WX_PROJECT_ID,
params=parameters_1,
)
stream_response = watsonxllm.stream("What color sunflower is?", params=parameters_2)

for chunk in stream_response:
assert isinstance(
chunk, str
), f"chunk expect type '{str}', actual '{type(chunk)}'"
print(chunk)


def test_watsonxllm_stream_with_params_4() -> None:
parameters_1 = {
GenTextParamsMetaNames.DECODING_METHOD: "sample",
GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
}
parameters_2 = {
"temperature": 0.6,
}
watsonxllm = WatsonxLLM(
model_id=MODEL_ID,
url="https://us-south.ml.cloud.ibm.com", # type: ignore[arg-type]
project_id=WX_PROJECT_ID,
params=parameters_1,
)
stream_response = watsonxllm.stream("What color sunflower is?", **parameters_2) # type: ignore[arg-type]

for chunk in stream_response:
assert isinstance(
chunk, str
), f"chunk expect type '{str}', actual '{type(chunk)}'"
print(chunk)


def test_watsonxllm_invoke_from_wx_model() -> None:
model = Model(
model_id=MODEL_ID,