-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* NvidiaGenerator first draft * Refine generator * Move function to get model id in client * Simplify invocation to keep it in line with embedders * Rename nvidia_generator.py to generator.py * Export NvidiaGenerator at package level * Add NvidiaGenerator tests * Fix embedders tests * Update docstring * Remove some unecessary logic, add usage output, fix docstrings * Update integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py * Fix linting --------- Co-authored-by: Madeesh Kannan <[email protected]>
- Loading branch information
1 parent
24c06fb
commit ead381e
Showing
12 changed files
with
465 additions
and
27 deletions.
There are no files selected for viewing
21 changes: 0 additions & 21 deletions
21
integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_schema.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5 changes: 4 additions & 1 deletion
5
integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from .generator import NvidiaGenerator | ||
|
||
__all__ = ["NvidiaGenerator"] |
69 changes: 69 additions & 0 deletions
69
integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_schema.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from dataclasses import asdict, dataclass | ||
from typing import Any, Dict, List, Optional | ||
|
||
|
||
@dataclass | ||
class Message: | ||
content: str | ||
role: str | ||
|
||
|
||
@dataclass | ||
class GenerationRequest: | ||
messages: List[Message] | ||
temperature: float = 0.2 | ||
top_p: float = 0.7 | ||
max_tokens: int = 1024 | ||
seed: Optional[int] = None | ||
bad: Optional[List[str]] = None | ||
stop: Optional[List[str]] = None | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
return asdict(self) | ||
|
||
|
||
@dataclass | ||
class Choice: | ||
index: int | ||
message: Message | ||
finish_reason: str | ||
|
||
|
||
@dataclass | ||
class Usage: | ||
completion_tokens: int | ||
prompt_tokens: int | ||
total_tokens: int | ||
|
||
|
||
@dataclass | ||
class GenerationResponse: | ||
id: str | ||
choices: List[Choice] | ||
usage: Usage | ||
|
||
@classmethod | ||
def from_dict(cls, data: dict) -> "GenerationResponse": | ||
try: | ||
return cls( | ||
id=data["id"], | ||
choices=[ | ||
Choice( | ||
index=choice["index"], | ||
message=Message(content=choice["message"]["content"], role=choice["message"]["role"]), | ||
finish_reason=choice["finish_reason"], | ||
) | ||
for choice in data["choices"] | ||
], | ||
usage=Usage( | ||
completion_tokens=data["usage"]["completion_tokens"], | ||
prompt_tokens=data["usage"]["prompt_tokens"], | ||
total_tokens=data["usage"]["total_tokens"], | ||
), | ||
) | ||
except (KeyError, TypeError) as e: | ||
msg = f"Failed to parse {cls.__name__} from data: {data}" | ||
raise ValueError(msg) from e |
2 changes: 1 addition & 1 deletion
2
integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 |
154 changes: 154 additions & 0 deletions
154
integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from haystack import component, default_from_dict, default_to_dict | ||
from haystack.utils.auth import Secret, deserialize_secrets_inplace | ||
from haystack_integrations.utils.nvidia import NvidiaCloudFunctionsClient | ||
|
||
from ._schema import GenerationRequest, GenerationResponse, Message | ||
from .models import NvidiaGeneratorModel | ||
|
||
|
||
@component | ||
class NvidiaGenerator: | ||
""" | ||
A component for generating text using generative models provided by | ||
[NVIDIA AI Foundation Endpoints](https://www.nvidia.com/en-us/ai-data-science/foundation-models/). | ||
Usage example: | ||
```python | ||
from haystack_integrations.components.generators.nvidia import NvidiaGenerator | ||
generator = NvidiaGenerator( | ||
model=NvidiaGeneratorModel.NV_LLAMA2_RLHF_70B, | ||
model_arguments={ | ||
"temperature": 0.2, | ||
"top_p": 0.7, | ||
"max_tokens": 1024, | ||
"seed": None, | ||
"bad": None, | ||
"stop": None, | ||
}, | ||
) | ||
generator.warm_up() | ||
result = generator.run(prompt="What is the answer?") | ||
print(result["replies"]) | ||
print(result["meta"]) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model: Union[str, NvidiaGeneratorModel], | ||
api_key: Secret = Secret.from_env_var("NVIDIA_API_KEY"), | ||
model_arguments: Optional[Dict[str, Any]] = None, | ||
): | ||
""" | ||
Create a NvidiaGenerator component. | ||
:param model: | ||
Name of the model to use for text generation. | ||
See the [Nvidia catalog](https://catalog.ngc.nvidia.com/ai-foundation-models) | ||
for more information on the supported models. | ||
:param api_key: | ||
Nvidia API key to use for authentication. | ||
:param model_arguments: | ||
Additional arguments to pass to the model provider. Different models accept different arguments. | ||
Search your model in the [Nvidia catalog](https://catalog.ngc.nvidia.com/ai-foundation-models) | ||
to know the supported arguments. | ||
:raises ValueError: If `model` is not supported. | ||
""" | ||
if isinstance(model, str): | ||
model = NvidiaGeneratorModel.from_str(model) | ||
|
||
self._model = model | ||
self._api_key = api_key | ||
self._model_arguments = model_arguments or {} | ||
# This is initialized in warm_up | ||
self._model_id = None | ||
|
||
self._client = NvidiaCloudFunctionsClient( | ||
api_key=api_key, | ||
headers={ | ||
"Content-Type": "application/json", | ||
"Accept": "application/json", | ||
}, | ||
) | ||
|
||
def warm_up(self): | ||
""" | ||
Initializes the component. | ||
""" | ||
if self._model_id is not None: | ||
return | ||
self._model_id = self._client.get_model_nvcf_id(str(self._model)) | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
""" | ||
Serializes the component to a dictionary. | ||
:returns: | ||
Dictionary with serialized data. | ||
""" | ||
return default_to_dict( | ||
self, model=str(self._model), api_key=self._api_key.to_dict(), model_arguments=self._model_arguments | ||
) | ||
|
||
@classmethod | ||
def from_dict(cls, data: Dict[str, Any]) -> "NvidiaGenerator": | ||
""" | ||
Deserializes the component from a dictionary. | ||
:param data: | ||
Dictionary to deserialize from. | ||
:returns: | ||
Deserialized component. | ||
""" | ||
init_params = data.get("init_parameters", {}) | ||
deserialize_secrets_inplace(init_params, ["api_key"]) | ||
return default_from_dict(cls, data) | ||
|
||
@component.output_types(replies=List[str], meta=List[Dict[str, Any]], usage=Dict[str, int]) | ||
def run(self, prompt: str): | ||
""" | ||
Queries the model with the provided prompt. | ||
:param prompt: | ||
Text to be sent to the generative model. | ||
:returns: | ||
A dictionary with the following keys: | ||
- `replies` - Replies generated by the model. | ||
- `meta` - Metadata for each reply. | ||
- `usage` - Usage statistics for the model. | ||
""" | ||
if self._model_id is None: | ||
msg = "The generation model has not been loaded. Call warm_up() before running." | ||
raise RuntimeError(msg) | ||
|
||
messages = [Message(role="user", content=prompt)] | ||
request = GenerationRequest(messages=messages, **self._model_arguments).to_dict() | ||
json_response = self._client.query_function(self._model_id, request) | ||
|
||
replies = [] | ||
meta = [] | ||
data = GenerationResponse.from_dict(json_response) | ||
for choice in data.choices: | ||
replies.append(choice.message.content) | ||
meta.append( | ||
{ | ||
"role": choice.message.role, | ||
"finish_reason": choice.finish_reason, | ||
} | ||
) | ||
|
||
usage = { | ||
"completion_tokens": data.usage.completion_tokens, | ||
"prompt_tokens": data.usage.prompt_tokens, | ||
"total_tokens": data.usage.total_tokens, | ||
} | ||
|
||
return {"replies": replies, "meta": meta, "usage": usage} |
35 changes: 35 additions & 0 deletions
35
integrations/nvidia/src/haystack_integrations/components/generators/nvidia/models.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from enum import Enum | ||
|
||
|
||
class NvidiaGeneratorModel(Enum): | ||
""" | ||
Generator models supported by NvidiaGenerator and NvidiaChatGenerator. | ||
""" | ||
|
||
NV_LLAMA2_RLHF_70B = "playground_nv_llama2_rlhf_70b" | ||
STEERLM_LLAMA_70B = "playground_steerlm_llama_70b" | ||
NEMOTRON_STEERLM_8B = "playground_nemotron_steerlm_8b" | ||
NEMOTRON_QA_8B = "playground_nemotron_qa_8b" | ||
|
||
def __str__(self): | ||
return self.value | ||
|
||
@classmethod | ||
def from_str(cls, string: str) -> "NvidiaGeneratorModel": | ||
""" | ||
Create a generator model from a string. | ||
:param string: | ||
String to convert. | ||
:returns: | ||
A generator model. | ||
""" | ||
enum_map = {e.value: e for e in NvidiaGeneratorModel} | ||
models = enum_map.get(string) | ||
if models is None: | ||
msg = f"Unknown model '{string}'. Supported models are: {list(enum_map.keys())}" | ||
raise ValueError(msg) | ||
return models |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.