Skip to content

Commit

Permalink
fix: Fixes to NvidiaRanker (#1191)
Browse files Browse the repository at this point in the history
* Fixes to NvidiaRanker

* Add inits and headers

* More headers

* updates

* Reactivate test

* Fix tests

* Reenable test and add test
  • Loading branch information
sjrl authored Nov 14, 2024
1 parent c2d1b20 commit 3c04cfe
Show file tree
Hide file tree
Showing 27 changed files with 188 additions and 22 deletions.
2 changes: 1 addition & 1 deletion integrations/nvidia/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai", "requests"]
dependencies = ["haystack-ai", "requests", "tqdm"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia#readme"
Expand Down
3 changes: 3 additions & 0 deletions integrations/nvidia/src/haystack_integrations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from .document_embedder import NvidiaDocumentEmbedder
from .text_embedder import NvidiaTextEmbedder
from .truncate import EmbeddingTruncateMode
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
from tqdm import tqdm

from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode
from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation

from .truncate import EmbeddingTruncateMode

_DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia"


Expand Down Expand Up @@ -167,7 +170,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaDocumentEmbedder":
:returns:
The deserialized component.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
init_parameters = data.get("init_parameters", {})
if init_parameters:
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace

from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode
from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation

from .truncate import EmbeddingTruncateMode

_DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia"


Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from enum import Enum


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from .generator import NvidiaGenerator

__all__ = ["NvidiaGenerator"]
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Any, Dict, List, Optional

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from .ranker import NvidiaRanker

__all__ = ["NvidiaRanker"]
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Any, Dict, List, Optional, Union

from haystack import Document, component, default_from_dict, default_to_dict
from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.utils import Secret, deserialize_secrets_inplace

from haystack_integrations.components.rankers.nvidia.truncate import RankerTruncateMode
from haystack_integrations.utils.nvidia import NimBackend, url_validation

from .truncate import RankerTruncateMode
logger = logging.getLogger(__name__)

_DEFAULT_MODEL = "nvidia/nv-rerankqa-mistral-4b-v3"

Expand Down Expand Up @@ -51,7 +56,7 @@ def __init__(
model: Optional[str] = None,
truncate: Optional[Union[RankerTruncateMode, str]] = None,
api_url: Optional[str] = None,
api_key: Optional[Secret] = None,
api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"),
top_k: int = 5,
):
"""
Expand Down Expand Up @@ -100,6 +105,7 @@ def __init__(
self._api_key = Secret.from_env_var("NVIDIA_API_KEY")
self._top_k = top_k
self._initialized = False
self._backend: Optional[Any] = None

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -113,7 +119,7 @@ def to_dict(self) -> Dict[str, Any]:
top_k=self._top_k,
truncate=self._truncate,
api_url=self._api_url,
api_key=self._api_key,
api_key=self._api_key.to_dict() if self._api_key else None,
)

@classmethod
Expand All @@ -124,7 +130,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaRanker":
:param data: A dictionary containing the ranker's attributes.
:returns: The deserialized ranker.
"""
deserialize_secrets_inplace(data, keys=["api_key"])
init_parameters = data.get("init_parameters", {})
if init_parameters:
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)

def warm_up(self):
Expand Down Expand Up @@ -170,23 +178,24 @@ def run(
msg = "The ranker has not been loaded. Please call warm_up() before running."
raise RuntimeError(msg)
if not isinstance(query, str):
msg = "Ranker expects the `query` parameter to be a string."
msg = "NvidiaRanker expects the `query` parameter to be a string."
raise TypeError(msg)
if not isinstance(documents, list):
msg = "Ranker expects the `documents` parameter to be a list."
msg = "NvidiaRanker expects the `documents` parameter to be a list."
raise TypeError(msg)
if not all(isinstance(doc, Document) for doc in documents):
msg = "Ranker expects the `documents` parameter to be a list of Document objects."
msg = "NvidiaRanker expects the `documents` parameter to be a list of Document objects."
raise TypeError(msg)
if top_k is not None and not isinstance(top_k, int):
msg = "Ranker expects the `top_k` parameter to be an integer."
msg = "NvidiaRanker expects the `top_k` parameter to be an integer."
raise TypeError(msg)

if len(documents) == 0:
return {"documents": []}

top_k = top_k if top_k is not None else self._top_k
if top_k < 1:
logger.warning("top_k should be at least 1, returning nothing")
warnings.warn("top_k should be at least 1, returning nothing", stacklevel=2)
return {"documents": []}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from enum import Enum


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from .nim_backend import Model, NimBackend
from .utils import is_hosted, url_validation

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import List
from typing import List, Optional
from urllib.parse import urlparse, urlunparse


def url_validation(api_url: str, default_api_url: str, allowed_paths: List[str]) -> str:
def url_validation(api_url: str, default_api_url: Optional[str], allowed_paths: List[str]) -> str:
"""
Validate and normalize an API URL.
Expand Down
1 change: 1 addition & 0 deletions integrations/nvidia/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from .conftest import MockBackend

__all__ = ["MockBackend"]
4 changes: 4 additions & 0 deletions integrations/nvidia/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional, Tuple

import pytest
Expand Down
4 changes: 4 additions & 0 deletions integrations/nvidia/tests/test_base_url.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import pytest

from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder, NvidiaTextEmbedder
Expand Down
27 changes: 24 additions & 3 deletions integrations/nvidia/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import os

import pytest
Expand Down Expand Up @@ -104,7 +108,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
},
}

def from_dict(self, monkeypatch):
def test_from_dict(self, monkeypatch):
monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key")
data = {
"type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder",
Expand All @@ -122,15 +126,32 @@ def from_dict(self, monkeypatch):
},
}
component = NvidiaDocumentEmbedder.from_dict(data)
assert component.model == "nvolveqa_40k"
assert component.model == "playground_nvolveqa_40k"
assert component.api_url == "https://example.com/v1"
assert component.prefix == "prefix"
assert component.suffix == "suffix"
assert component.batch_size == 10
assert component.progress_bar is False
assert component.meta_fields_to_embed == ["test_field"]
assert component.embedding_separator == " | "
assert component.truncate == EmbeddingTruncateMode.START

def test_from_dict_defaults(self, monkeypatch):
monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key")
data = {
"type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder",
"init_parameters": {},
}
component = NvidiaDocumentEmbedder.from_dict(data)
assert component.model == "nvidia/nv-embedqa-e5-v5"
assert component.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia"
assert component.prefix == ""
assert component.suffix == ""
assert component.batch_size == 32
assert component.progress_bar
assert component.meta_fields_to_embed == []
assert component.embedding_separator == "\n"
assert component.truncate == EmbeddingTruncateMode.START
assert component.truncate is None

def test_prepare_texts_to_embed_w_metadata(self):
documents = [
Expand Down
4 changes: 4 additions & 0 deletions integrations/nvidia/tests/test_embedding_truncate_mode.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import pytest

from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode
Expand Down
1 change: 1 addition & 0 deletions integrations/nvidia/tests/test_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import os

import pytest
Expand Down
49 changes: 49 additions & 0 deletions integrations/nvidia/tests/test_ranker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import os
import re
from typing import Any, Optional, Union
Expand Down Expand Up @@ -256,3 +260,48 @@ def test_warm_up_once(self, monkeypatch) -> None:
backend = client._backend
client.warm_up()
assert backend == client._backend

def test_to_dict(self) -> None:
client = NvidiaRanker()
assert client.to_dict() == {
"type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker",
"init_parameters": {
"model": "nvidia/nv-rerankqa-mistral-4b-v3",
"top_k": 5,
"truncate": None,
"api_url": None,
"api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True},
},
}

def test_from_dict(self) -> None:
client = NvidiaRanker.from_dict(
{
"type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker",
"init_parameters": {
"model": "nvidia/nv-rerankqa-mistral-4b-v3",
"top_k": 5,
"truncate": None,
"api_url": None,
"api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True},
},
}
)
assert client._model == "nvidia/nv-rerankqa-mistral-4b-v3"
assert client._top_k == 5
assert client._truncate is None
assert client._api_url is None
assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY")

def test_from_dict_defaults(self) -> None:
client = NvidiaRanker.from_dict(
{
"type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker",
"init_parameters": {},
}
)
assert client._model == "nvidia/nv-rerankqa-mistral-4b-v3"
assert client._top_k == 5
assert client._truncate is None
assert client._api_url is None
assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY")
Loading

0 comments on commit 3c04cfe

Please sign in to comment.