forked from deepset-ai/haystack-core-integrations
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: Fixes to NvidiaRanker (deepset-ai#1191)
* Fixes to NvidiaRanker * Add inits and headers * More headers * updates * Reactivate test * Fix tests * Reenable test and add test
- Loading branch information
Showing
27 changed files
with
188 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 |
3 changes: 3 additions & 0 deletions
3
integrations/nvidia/src/haystack_integrations/components/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 |
3 changes: 3 additions & 0 deletions
3
integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 |
4 changes: 4 additions & 0 deletions
4
integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .document_embedder import NvidiaDocumentEmbedder | ||
from .text_embedder import NvidiaTextEmbedder | ||
from .truncate import EmbeddingTruncateMode | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,17 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import warnings | ||
from typing import Any, Dict, List, Optional, Tuple, Union | ||
|
||
from haystack import Document, component, default_from_dict, default_to_dict | ||
from haystack.utils import Secret, deserialize_secrets_inplace | ||
from tqdm import tqdm | ||
|
||
from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode | ||
from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation | ||
|
||
from .truncate import EmbeddingTruncateMode | ||
|
||
_DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" | ||
|
||
|
||
|
@@ -167,7 +170,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaDocumentEmbedder": | |
:returns: | ||
The deserialized component. | ||
""" | ||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) | ||
init_parameters = data.get("init_parameters", {}) | ||
if init_parameters: | ||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) | ||
return default_from_dict(cls, data) | ||
|
||
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: | ||
|
7 changes: 5 additions & 2 deletions
7
integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,16 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import warnings | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from haystack import component, default_from_dict, default_to_dict | ||
from haystack.utils import Secret, deserialize_secrets_inplace | ||
|
||
from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode | ||
from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation | ||
|
||
from .truncate import EmbeddingTruncateMode | ||
|
||
_DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" | ||
|
||
|
||
|
4 changes: 4 additions & 0 deletions
4
integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from enum import Enum | ||
|
||
|
||
|
3 changes: 3 additions & 0 deletions
3
integrations/nvidia/src/haystack_integrations/components/generators/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 |
1 change: 1 addition & 0 deletions
1
integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .generator import NvidiaGenerator | ||
|
||
__all__ = ["NvidiaGenerator"] |
1 change: 1 addition & 0 deletions
1
integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import warnings | ||
from typing import Any, Dict, List, Optional | ||
|
||
|
3 changes: 3 additions & 0 deletions
3
integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 |
4 changes: 4 additions & 0 deletions
4
integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .ranker import NvidiaRanker | ||
|
||
__all__ = ["NvidiaRanker"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,17 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import warnings | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from haystack import Document, component, default_from_dict, default_to_dict | ||
from haystack import Document, component, default_from_dict, default_to_dict, logging | ||
from haystack.utils import Secret, deserialize_secrets_inplace | ||
|
||
from haystack_integrations.components.rankers.nvidia.truncate import RankerTruncateMode | ||
from haystack_integrations.utils.nvidia import NimBackend, url_validation | ||
|
||
from .truncate import RankerTruncateMode | ||
logger = logging.getLogger(__name__) | ||
|
||
_DEFAULT_MODEL = "nvidia/nv-rerankqa-mistral-4b-v3" | ||
|
||
|
@@ -51,7 +56,7 @@ def __init__( | |
model: Optional[str] = None, | ||
truncate: Optional[Union[RankerTruncateMode, str]] = None, | ||
api_url: Optional[str] = None, | ||
api_key: Optional[Secret] = None, | ||
api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), | ||
top_k: int = 5, | ||
): | ||
""" | ||
|
@@ -100,6 +105,7 @@ def __init__( | |
self._api_key = Secret.from_env_var("NVIDIA_API_KEY") | ||
self._top_k = top_k | ||
self._initialized = False | ||
self._backend: Optional[Any] = None | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
""" | ||
|
@@ -113,7 +119,7 @@ def to_dict(self) -> Dict[str, Any]: | |
top_k=self._top_k, | ||
truncate=self._truncate, | ||
api_url=self._api_url, | ||
api_key=self._api_key, | ||
api_key=self._api_key.to_dict() if self._api_key else None, | ||
) | ||
|
||
@classmethod | ||
|
@@ -124,7 +130,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaRanker": | |
:param data: A dictionary containing the ranker's attributes. | ||
:returns: The deserialized ranker. | ||
""" | ||
deserialize_secrets_inplace(data, keys=["api_key"]) | ||
init_parameters = data.get("init_parameters", {}) | ||
if init_parameters: | ||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) | ||
return default_from_dict(cls, data) | ||
|
||
def warm_up(self): | ||
|
@@ -170,23 +178,24 @@ def run( | |
msg = "The ranker has not been loaded. Please call warm_up() before running." | ||
raise RuntimeError(msg) | ||
if not isinstance(query, str): | ||
msg = "Ranker expects the `query` parameter to be a string." | ||
msg = "NvidiaRanker expects the `query` parameter to be a string." | ||
raise TypeError(msg) | ||
if not isinstance(documents, list): | ||
msg = "Ranker expects the `documents` parameter to be a list." | ||
msg = "NvidiaRanker expects the `documents` parameter to be a list." | ||
raise TypeError(msg) | ||
if not all(isinstance(doc, Document) for doc in documents): | ||
msg = "Ranker expects the `documents` parameter to be a list of Document objects." | ||
msg = "NvidiaRanker expects the `documents` parameter to be a list of Document objects." | ||
raise TypeError(msg) | ||
if top_k is not None and not isinstance(top_k, int): | ||
msg = "Ranker expects the `top_k` parameter to be an integer." | ||
msg = "NvidiaRanker expects the `top_k` parameter to be an integer." | ||
raise TypeError(msg) | ||
|
||
if len(documents) == 0: | ||
return {"documents": []} | ||
|
||
top_k = top_k if top_k is not None else self._top_k | ||
if top_k < 1: | ||
logger.warning("top_k should be at least 1, returning nothing") | ||
warnings.warn("top_k should be at least 1, returning nothing", stacklevel=2) | ||
return {"documents": []} | ||
|
||
|
4 changes: 4 additions & 0 deletions
4
integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from enum import Enum | ||
|
||
|
||
|
3 changes: 3 additions & 0 deletions
3
integrations/nvidia/src/haystack_integrations/utils/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 |
4 changes: 4 additions & 0 deletions
4
integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .nim_backend import Model, NimBackend | ||
from .utils import is_hosted, url_validation | ||
|
||
|
4 changes: 4 additions & 0 deletions
4
integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from dataclasses import dataclass, field | ||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
|
8 changes: 6 additions & 2 deletions
8
integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,13 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import warnings | ||
from typing import List | ||
from typing import List, Optional | ||
from urllib.parse import urlparse, urlunparse | ||
|
||
|
||
def url_validation(api_url: str, default_api_url: str, allowed_paths: List[str]) -> str: | ||
def url_validation(api_url: str, default_api_url: Optional[str], allowed_paths: List[str]) -> str: | ||
""" | ||
Validate and normalize an API URL. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .conftest import MockBackend | ||
|
||
__all__ = ["MockBackend"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
import pytest | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
|
||
from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder, NvidiaTextEmbedder | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
|
||
import pytest | ||
|
@@ -104,7 +108,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): | |
}, | ||
} | ||
|
||
def from_dict(self, monkeypatch): | ||
def test_from_dict(self, monkeypatch): | ||
monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") | ||
data = { | ||
"type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", | ||
|
@@ -122,15 +126,32 @@ def from_dict(self, monkeypatch): | |
}, | ||
} | ||
component = NvidiaDocumentEmbedder.from_dict(data) | ||
assert component.model == "nvolveqa_40k" | ||
assert component.model == "playground_nvolveqa_40k" | ||
assert component.api_url == "https://example.com/v1" | ||
assert component.prefix == "prefix" | ||
assert component.suffix == "suffix" | ||
assert component.batch_size == 10 | ||
assert component.progress_bar is False | ||
assert component.meta_fields_to_embed == ["test_field"] | ||
assert component.embedding_separator == " | " | ||
assert component.truncate == EmbeddingTruncateMode.START | ||
|
||
def test_from_dict_defaults(self, monkeypatch): | ||
monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") | ||
data = { | ||
"type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", | ||
"init_parameters": {}, | ||
} | ||
component = NvidiaDocumentEmbedder.from_dict(data) | ||
assert component.model == "nvidia/nv-embedqa-e5-v5" | ||
assert component.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia" | ||
assert component.prefix == "" | ||
assert component.suffix == "" | ||
assert component.batch_size == 32 | ||
assert component.progress_bar | ||
assert component.meta_fields_to_embed == [] | ||
assert component.embedding_separator == "\n" | ||
assert component.truncate == EmbeddingTruncateMode.START | ||
assert component.truncate is None | ||
|
||
def test_prepare_texts_to_embed_w_metadata(self): | ||
documents = [ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
|
||
from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
|
||
import pytest | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
import re | ||
from typing import Any, Optional, Union | ||
|
@@ -256,3 +260,48 @@ def test_warm_up_once(self, monkeypatch) -> None: | |
backend = client._backend | ||
client.warm_up() | ||
assert backend == client._backend | ||
|
||
def test_to_dict(self) -> None: | ||
client = NvidiaRanker() | ||
assert client.to_dict() == { | ||
"type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", | ||
"init_parameters": { | ||
"model": "nvidia/nv-rerankqa-mistral-4b-v3", | ||
"top_k": 5, | ||
"truncate": None, | ||
"api_url": None, | ||
"api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True}, | ||
}, | ||
} | ||
|
||
def test_from_dict(self) -> None: | ||
client = NvidiaRanker.from_dict( | ||
{ | ||
"type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", | ||
"init_parameters": { | ||
"model": "nvidia/nv-rerankqa-mistral-4b-v3", | ||
"top_k": 5, | ||
"truncate": None, | ||
"api_url": None, | ||
"api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True}, | ||
}, | ||
} | ||
) | ||
assert client._model == "nvidia/nv-rerankqa-mistral-4b-v3" | ||
assert client._top_k == 5 | ||
assert client._truncate is None | ||
assert client._api_url is None | ||
assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY") | ||
|
||
def test_from_dict_defaults(self) -> None: | ||
client = NvidiaRanker.from_dict( | ||
{ | ||
"type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", | ||
"init_parameters": {}, | ||
} | ||
) | ||
assert client._model == "nvidia/nv-rerankqa-mistral-4b-v3" | ||
assert client._top_k == 5 | ||
assert client._truncate is None | ||
assert client._api_url is None | ||
assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY") |
Oops, something went wrong.