Skip to content

Commit

Permalink
fix: Missing Nvidia embedding truncate mode (#1043)
Browse files Browse the repository at this point in the history
* fix: Add NONE option to EmbeddingTruncateMode

* refactor: Validate input with _missing_ method

* test: Add EmbeddingTruncateMode test

* refactor: Revert "refactor: Validate input with _missing_ method"

This reverts commit 8334a50.

* Update integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py

---------

Co-authored-by: Madeesh Kannan <[email protected]>
  • Loading branch information
2 people authored and Amnah199 committed Sep 6, 2024
1 parent f7c9816 commit f74cb78
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ class EmbeddingTruncateMode(Enum):
Specifies how inputs to the NVIDIA embedding components are truncated.
If START, the input will be truncated from the start.
If END, the input will be truncated from the end.
If NONE, an error will be returned (if the input is too long).
"""

START = "START"
END = "END"
NONE = "NONE"

def __str__(self):
return self.value
Expand Down
40 changes: 40 additions & 0 deletions integrations/nvidia/tests/test_embedding_truncate_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest

from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode


class TestEmbeddingTruncateMode:
@pytest.mark.parametrize(
"mode, expected",
[
("START", EmbeddingTruncateMode.START),
("END", EmbeddingTruncateMode.END),
("NONE", EmbeddingTruncateMode.NONE),
(EmbeddingTruncateMode.START, EmbeddingTruncateMode.START),
(EmbeddingTruncateMode.END, EmbeddingTruncateMode.END),
(EmbeddingTruncateMode.NONE, EmbeddingTruncateMode.NONE),
],
)
def test_init_with_valid_mode(self, mode, expected):
assert EmbeddingTruncateMode(mode) == expected

def test_init_with_invalid_mode_raises_value_error(self):
with pytest.raises(ValueError):
invalid_mode = "INVALID"
EmbeddingTruncateMode(invalid_mode)

@pytest.mark.parametrize(
"mode, expected",
[
("START", EmbeddingTruncateMode.START),
("END", EmbeddingTruncateMode.END),
("NONE", EmbeddingTruncateMode.NONE),
],
)
def test_from_str_with_valid_mode(self, mode, expected):
assert EmbeddingTruncateMode.from_str(mode) == expected

def test_from_str_with_invalid_mode_raises_value_error(self):
with pytest.raises(ValueError):
invalid_mode = "INVALID"
EmbeddingTruncateMode.from_str(invalid_mode)

0 comments on commit f74cb78

Please sign in to comment.