diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py index 2c32eabb1..3a8eb9d07 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py @@ -6,10 +6,12 @@ class EmbeddingTruncateMode(Enum): Specifies how inputs to the NVIDIA embedding components are truncated. If START, the input will be truncated from the start. If END, the input will be truncated from the end. + If NONE, an error will be returned (if the input is too long). """ START = "START" END = "END" + NONE = "NONE" def __str__(self): return self.value diff --git a/integrations/nvidia/tests/test_embedding_truncate_mode.py b/integrations/nvidia/tests/test_embedding_truncate_mode.py new file mode 100644 index 000000000..e74d0308c --- /dev/null +++ b/integrations/nvidia/tests/test_embedding_truncate_mode.py @@ -0,0 +1,40 @@ +import pytest + +from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode + + +class TestEmbeddingTruncateMode: + @pytest.mark.parametrize( + "mode, expected", + [ + ("START", EmbeddingTruncateMode.START), + ("END", EmbeddingTruncateMode.END), + ("NONE", EmbeddingTruncateMode.NONE), + (EmbeddingTruncateMode.START, EmbeddingTruncateMode.START), + (EmbeddingTruncateMode.END, EmbeddingTruncateMode.END), + (EmbeddingTruncateMode.NONE, EmbeddingTruncateMode.NONE), + ], + ) + def test_init_with_valid_mode(self, mode, expected): + assert EmbeddingTruncateMode(mode) == expected + + def test_init_with_invalid_mode_raises_value_error(self): + with pytest.raises(ValueError): + invalid_mode = "INVALID" + EmbeddingTruncateMode(invalid_mode) + + @pytest.mark.parametrize( + "mode, expected", + [ + ("START", EmbeddingTruncateMode.START), + ("END", EmbeddingTruncateMode.END), + ("NONE", EmbeddingTruncateMode.NONE), + ], + ) + def test_from_str_with_valid_mode(self, mode, expected): + assert EmbeddingTruncateMode.from_str(mode) == expected + + def test_from_str_with_invalid_mode_raises_value_error(self): + with pytest.raises(ValueError): + invalid_mode = "INVALID" + EmbeddingTruncateMode.from_str(invalid_mode)