nvidia-ai-endpoints[patch]: model arguments (e.g. temperature) on construction bug (#17290)

- **Issue:** Issue with model argument support (it has been there for a while):
  - Arguments without special handling, such as temperature, don't work when
passed through the constructor.
  - Such arguments DO work with `bind`, but they also do not abide by the
field requirements.
  - Since the initial push, server-side error messages have improved and
v0.0.2 raises better exceptions, so it may be better to let the server side
handle such issues.
- **Description:**
  - Removed ChatNVIDIA's argument fields in favor of
`model_kwargs`/`model_kws` arguments, which aggregate constructor kwargs
(constructor pathway) and merge them with call kwargs (bind pathway).
  - Shuffled a few functions from `_NVIDIAClient` to `ChatNVIDIA` to
streamline construction for future integrations.
  - Minor/optional: older services didn't have stop support, so client-side
stopping was implemented; both client-side and server-side stopping are now
applied.
- **Any Breaking Changes:** Minor breaking changes if you rely strongly on
chat_model.temperature, etc.; this is now captured by
chat_model.model_kwargs (see the sketch below).
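
For context, a minimal usage sketch of the two pathways this PR reconciles. The model name and parameter values are illustrative placeholders, not part of this change:

```python
from langchain_nvidia_ai_endpoints import ChatNVIDIA

# Constructor pathway: temperature/max_tokens are now plain optional fields
# that get merged into the request payload instead of being dropped.
# (model name is a placeholder; use an endpoint you have access to)
llm = ChatNVIDIA(model="mixtral_8x7b", temperature=0.2, max_tokens=256)
print(llm.invoke("Hello!"))

# Bind pathway: call-time kwargs are merged over the constructor values,
# so this binding overrides temperature without rebuilding the client.
warm_llm = llm.bind(temperature=0.9)
print(warm_llm.invoke("Hello!"))

# Stop words can also be passed per call; with this change they are sent to
# the service when supported and still enforced client-side as a fallback.
print(llm.invoke("Count to ten.", stop=["5"]))
```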

The PR passes tests, the example notebooks, and example testing. Still going
to chat with some people, so leaving this as a draft for now.

---------

Co-authored-by: Erick Friis <[email protected]>
VKudlay and efriis authored Feb 9, 2024
1 parent 932c52c commit 5f9ac69
Showing 6 changed files with 311 additions and 240 deletions.
297 changes: 190 additions & 107 deletions docs/docs/integrations/chat/nvidia_ai_endpoints.ipynb

Large diffs are not rendered by default.

44 changes: 30 additions & 14 deletions docs/docs/integrations/text_embedding/nvidia_ai_endpoints.ipynb
@@ -28,9 +28,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install --upgrade --quiet langchain-nvidia-ai-endpoints"
]
@@ -56,15 +64,23 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hoF41-tNczS3",
"outputId": "7f2833dc-191c-4d73-b823-7b2745a93a2f"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Valid NVIDIA_API_KEY already in environment. Delete to reset\n"
]
}
],
"source": [
"import getpass\n",
"import os\n",
@@ -105,7 +121,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {
"id": "hbXmJssPdIPX"
},
@@ -180,7 +196,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -194,15 +210,15 @@
"output_type": "stream",
"text": [
"Single Query Embedding: \n",
"\u001b[1mExecuted in 1.39 seconds.\u001b[0m\n",
"\u001b[1mExecuted in 2.19 seconds.\u001b[0m\n",
"Shape: (1024,)\n",
"\n",
"Sequential Embedding: \n",
"\u001b[1mExecuted in 3.20 seconds.\u001b[0m\n",
"\u001b[1mExecuted in 3.16 seconds.\u001b[0m\n",
"Shape: (5, 1024)\n",
"\n",
"Batch Query Embedding: \n",
"\u001b[1mExecuted in 1.52 seconds.\u001b[0m\n",
"\u001b[1mExecuted in 1.23 seconds.\u001b[0m\n",
"Shape: (5, 1024)\n"
]
}
@@ -260,7 +276,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
@@ -274,11 +290,11 @@
"output_type": "stream",
"text": [
"Single Document Embedding: \n",
"\u001b[1mExecuted in 0.76 seconds.\u001b[0m\n",
"\u001b[1mExecuted in 0.52 seconds.\u001b[0m\n",
"Shape: (1024,)\n",
"\n",
"Batch Document Embedding: \n",
"\u001b[1mExecuted in 0.86 seconds.\u001b[0m\n",
"\u001b[1mExecuted in 0.89 seconds.\u001b[0m\n",
"Shape: (5, 1024)\n"
]
}
@@ -324,7 +340,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -341,7 +357,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -20,7 +20,6 @@

import aiohttp
import requests
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
@@ -440,10 +439,6 @@ class _NVIDIAClient(BaseModel):

model: str = Field(..., description="Name of the model to invoke")

temperature: float = Field(0.2, le=1.0, gt=0.0)
top_p: float = Field(0.7, le=1.0, ge=0.0)
max_tokens: int = Field(1024, le=1024, ge=32)

####################################################################################

@root_validator(pre=True)
@@ -485,67 +480,3 @@ def get_model_details(self, model: Optional[str] = None) -> dict:
known_fns = self.client.available_functions
fn_spec = [f for f in known_fns if f.get("id") == model_key][0]
return fn_spec

def get_generation(
self,
inputs: Sequence[Dict],
labels: Optional[dict] = None,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> dict:
"""Call to client generate method with call scope"""
payload = self.get_payload(inputs=inputs, stream=False, labels=labels, **kwargs)
out = self.client.get_req_generation(self.model, stop=stop, payload=payload)
return out

def get_stream(
self,
inputs: Sequence[Dict],
labels: Optional[dict] = None,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> Iterator:
"""Call to client stream method with call scope"""
payload = self.get_payload(inputs=inputs, stream=True, labels=labels, **kwargs)
return self.client.get_req_stream(self.model, stop=stop, payload=payload)

def get_astream(
self,
inputs: Sequence[Dict],
labels: Optional[dict] = None,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> AsyncIterator:
"""Call to client astream methods with call scope"""
payload = self.get_payload(inputs=inputs, stream=True, labels=labels, **kwargs)
return self.client.get_req_astream(self.model, stop=stop, payload=payload)

def get_payload(
self, inputs: Sequence[Dict], labels: Optional[dict] = None, **kwargs: Any
) -> dict:
"""Generates payload for the _NVIDIAClient API to send to service."""
return {
**self.preprocess(inputs=inputs, labels=labels),
**kwargs,
}

def preprocess(self, inputs: Sequence[Dict], labels: Optional[dict] = None) -> dict:
"""Prepares a message or list of messages for the payload"""
messages = [self.prep_msg(m) for m in inputs]
if labels:
# (WFH) Labels are currently (?) always passed as an assistant
# suffix message, but this API seems less stable.
messages += [{"labels": labels, "role": "assistant"}]
return {"messages": messages}

def prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict:
"""Helper Method: Ensures a message is a dictionary with a role and content."""
if isinstance(msg, str):
# (WFH) this shouldn't ever be reached but leaving this here bcs
# it's a Chesterton's fence I'm unwilling to touch
return dict(role="user", content=msg)
if isinstance(msg, dict):
if msg.get("content", None) is None:
raise ValueError(f"Message {msg} has no content")
return msg
raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")
@@ -27,6 +27,7 @@
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import BaseMessage, ChatMessage, ChatMessageChunk
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.pydantic_v1 import Field

from langchain_nvidia_ai_endpoints import _common as nvidia_ai_endpoints

@@ -116,6 +117,14 @@ class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
response = model.invoke("Hello")
"""

temperature: Optional[float] = Field(description="Sampling temperature in [0, 1]")
max_tokens: Optional[int] = Field(description="Maximum # of tokens to generate")
top_p: Optional[float] = Field(description="Top-p for distribution sampling")
seed: Optional[int] = Field(description="The seed for deterministic results")
bad: Optional[Sequence[str]] = Field(description="Bad words to avoid (cased)")
stop: Optional[Sequence[str]] = Field(description="Stop words (cased)")
labels: Optional[Dict[str, float]] = Field(description="Steering parameters")

@property
def _llm_type(self) -> str:
"""Return type of NVIDIA AI Foundation Model Interface."""
@@ -126,14 +135,11 @@ def _call(
messages: List[BaseMessage],
stop: Optional[Sequence[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
labels: Optional[dict] = None,
**kwargs: Any,
) -> str:
"""Invoke on a single list of chat messages."""
inputs = self.custom_preprocess(messages)
responses = self.get_generation(
inputs=inputs, stop=stop, labels=labels, **kwargs
)
responses = self.get_generation(inputs=inputs, stop=stop, **kwargs)
outputs = self.custom_postprocess(responses)
return outputs

@@ -148,14 +154,11 @@ def _stream(
messages: List[BaseMessage],
stop: Optional[Sequence[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
labels: Optional[dict] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Allows streaming to model!"""
inputs = self.custom_preprocess(messages)
for response in self.get_stream(
inputs=inputs, stop=stop, labels=labels, **kwargs
):
for response in self.get_stream(inputs=inputs, stop=stop, **kwargs):
chunk = self._get_filled_chunk(self.custom_postprocess(response))
yield chunk
if run_manager:
@@ -166,13 +169,10 @@ async def _astream(
messages: List[BaseMessage],
stop: Optional[Sequence[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
labels: Optional[dict] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
inputs = self.custom_preprocess(messages)
async for response in self.get_astream(
inputs=inputs, stop=stop, labels=labels, **kwargs
):
async for response in self.get_astream(inputs=inputs, stop=stop, **kwargs):
chunk = self._get_filled_chunk(self.custom_postprocess(response))
yield chunk
if run_manager:
@@ -229,7 +229,78 @@ def preprocess_msg(self, msg: BaseMessage) -> Dict[str, str]:
def custom_postprocess(self, msg: dict) -> str:
if "content" in msg:
return msg["content"]
logger.warning(
f"Got ambiguous message in postprocessing; returning as-is: msg = {msg}"
)
elif "b64_json" in msg:
return msg["b64_json"]
return str(msg)

######################################################################################
## Core client-side interfaces

def get_generation(
self,
inputs: Sequence[Dict],
**kwargs: Any,
) -> dict:
"""Call to client generate method with call scope"""
stop = kwargs.get("stop", None)
payload = self.get_payload(inputs=inputs, stream=False, **kwargs)
out = self.client.get_req_generation(self.model, stop=stop, payload=payload)
return out

def get_stream(
self,
inputs: Sequence[Dict],
**kwargs: Any,
) -> Iterator:
"""Call to client stream method with call scope"""
stop = kwargs.get("stop", None)
payload = self.get_payload(inputs=inputs, stream=True, **kwargs)
return self.client.get_req_stream(self.model, stop=stop, payload=payload)

def get_astream(
self,
inputs: Sequence[Dict],
**kwargs: Any,
) -> AsyncIterator:
"""Call to client astream methods with call scope"""
stop = kwargs.get("stop", None)
payload = self.get_payload(inputs=inputs, stream=True, **kwargs)
return self.client.get_req_astream(self.model, stop=stop, payload=payload)

def get_payload(self, inputs: Sequence[Dict], **kwargs: Any) -> dict:
"""Generates payload for the _NVIDIAClient API to send to service."""
attr_kwargs = {
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"seed": self.seed,
"bad": self.bad,
"stop": self.stop,
"labels": self.labels,
}
attr_kwargs = {k: v for k, v in attr_kwargs.items() if v is not None}
new_kwargs = {**attr_kwargs, **kwargs}
return self.prep_payload(inputs=inputs, **new_kwargs)

def prep_payload(self, inputs: Sequence[Dict], **kwargs: Any) -> dict:
"""Prepares a message or list of messages for the payload"""
messages = [self.prep_msg(m) for m in inputs]
if kwargs.get("labels"):
# (WFH) Labels are currently (?) always passed as an assistant
# suffix message, but this API seems less stable.
messages += [{"labels": kwargs.pop("labels"), "role": "assistant"}]
if kwargs.get("stop") is None:
kwargs.pop("stop")
return {"messages": messages, **kwargs}

def prep_msg(self, msg: Union[str, dict, BaseMessage]) -> dict:
"""Helper Method: Ensures a message is a dictionary with a role and content."""
if isinstance(msg, str):
# (WFH) this shouldn't ever be reached but leaving this here bcs
# it's a Chesterton's fence I'm unwilling to touch
return dict(role="user", content=msg)
if isinstance(msg, dict):
if msg.get("content", None) is None:
raise ValueError(f"Message {msg} has no content")
return msg
raise ValueError(f"Unknown message received: {msg} of type {type(msg)}")