From a84bc954c4c16d20dfbdbd64c528ac6ee27720f5 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 29 Jan 2024 14:09:55 +0100 Subject: [PATCH 1/7] Don't double length of embeddings vs documents --- .../components/embedders/cohere/text_embedder.py | 2 +- .../haystack_integrations/components/embedders/cohere/utils.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py index 2fa922004..a9db15213 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py @@ -125,4 +125,4 @@ def run(self, text: str): ) embedding, metadata = get_response(cohere_client, [text], self.model, self.input_type, self.truncate) - return {"embedding": embedding[0], "meta": metadata} + return {"embedding": embedding, "meta": metadata} diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py index 7b9c90730..9c16ecee7 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py @@ -46,8 +46,6 @@ def get_response( response = cohere_client.embed(batch, model=model_name, input_type=input_type, truncate=truncate) for emb in response.embeddings: all_embeddings.append(emb) - embeddings = [list(map(float, emb)) for emb in response.embeddings] - all_embeddings.extend(embeddings) if response.meta is not None: metadata = response.meta From 5417a6f1e3c94b96ffb3ef0f7e98e36b57ba05bf Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 29 Jan 2024 14:19:45 +0100 Subject: [PATCH 2/7] Fix issue --- .../components/embedders/cohere/text_embedder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py index a9db15213..2fa922004 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py @@ -125,4 +125,4 @@ def run(self, text: str): ) embedding, metadata = get_response(cohere_client, [text], self.model, self.input_type, self.truncate) - return {"embedding": embedding, "meta": metadata} + return {"embedding": embedding[0], "meta": metadata} From b0c17f778f48f227e679a780a9833cd6d6b189fb Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 29 Jan 2024 14:29:40 +0100 Subject: [PATCH 3/7] Fix failing tests --- integrations/cohere/tests/test_cohere_chat_generator.py | 2 +- integrations/cohere/tests/test_cohere_generators.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index c91ada419..c4d727f4a 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -260,7 +260,7 @@ def test_live_run(self): @pytest.mark.integration def test_live_run_wrong_model(self, chat_messages): component = CohereChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")) - with pytest.raises(cohere.CohereAPIError, match="finetuned model something-obviously-wrong is not valid"): + with pytest.raises(cohere.CohereAPIError, match="model not found, make sure the correct model ID was used and that you have access to the model."): component.run(chat_messages) @pytest.mark.skipif( diff --git a/integrations/cohere/tests/test_cohere_generators.py b/integrations/cohere/tests/test_cohere_generators.py index e2ce10405..90d4d3e28 100644 --- a/integrations/cohere/tests/test_cohere_generators.py +++ b/integrations/cohere/tests/test_cohere_generators.py @@ -164,7 +164,7 @@ def __init__(self): self.responses = "" def __call__(self, chunk): - self.responses += chunk.text + self.responses += chunk.content return chunk callback = Callback() From 96ad18638d3f066f46a0bce733fb85d9fc513d58 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 29 Jan 2024 14:32:36 +0100 Subject: [PATCH 4/7] Fix pylint --- integrations/cohere/tests/test_cohere_chat_generator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index c4d727f4a..b6568534b 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -260,7 +260,10 @@ def test_live_run(self): @pytest.mark.integration def test_live_run_wrong_model(self, chat_messages): component = CohereChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")) - with pytest.raises(cohere.CohereAPIError, match="model not found, make sure the correct model ID was used and that you have access to the model."): + with pytest.raises( + cohere.CohereAPIError, + match="model not found, make sure the correct model ID was used and that you have access to the model." + ): component.run(chat_messages) @pytest.mark.skipif( From a962b94606c08ac9dec68c4008455f432e7df0ad Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 29 Jan 2024 14:38:53 +0100 Subject: [PATCH 5/7] Pylint --- integrations/cohere/tests/test_cohere_chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index b6568534b..edefc1a43 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -262,7 +262,7 @@ def test_live_run_wrong_model(self, chat_messages): component = CohereChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")) with pytest.raises( cohere.CohereAPIError, - match="model not found, make sure the correct model ID was used and that you have access to the model." + match="model not found, make sure the correct model ID was used and that you have access to the model.", ): component.run(chat_messages) From 00d3f9f3088c34df02e6a589dada25bba1772b45 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 5 Feb 2024 09:58:39 +0100 Subject: [PATCH 6/7] Undo test fixes --- integrations/cohere/tests/test_cohere_chat_generator.py | 5 +---- integrations/cohere/tests/test_cohere_generators.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index edefc1a43..c91ada419 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -260,10 +260,7 @@ def test_live_run(self): @pytest.mark.integration def test_live_run_wrong_model(self, chat_messages): component = CohereChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")) - with pytest.raises( - cohere.CohereAPIError, - match="model not found, make sure the correct model ID was used and that you have access to the model.", - ): + with pytest.raises(cohere.CohereAPIError, match="finetuned model something-obviously-wrong is not valid"): component.run(chat_messages) @pytest.mark.skipif( diff --git a/integrations/cohere/tests/test_cohere_generators.py b/integrations/cohere/tests/test_cohere_generators.py index 90d4d3e28..e2ce10405 100644 --- a/integrations/cohere/tests/test_cohere_generators.py +++ b/integrations/cohere/tests/test_cohere_generators.py @@ -164,7 +164,7 @@ def __init__(self): self.responses = "" def __call__(self, chunk): - self.responses += chunk.content + self.responses += chunk.text return chunk callback = Callback() From c3257a9aa503d424db376eb91a8dca0f1e535555 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 5 Feb 2024 10:00:23 +0100 Subject: [PATCH 7/7] Fix linter --- .../components/generators/cohere/generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py index fee410eab..92fed51aa 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py @@ -122,7 +122,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator": """ init_params = data.get("init_parameters", {}) streaming_callback = None - if "streaming_callback" in init_params and init_params["streaming_callback"]: + if "streaming_callback" in init_params and init_params["streaming_callback"] is not None: parts = init_params["streaming_callback"].split(".") module_name = ".".join(parts[:-1]) function_name = parts[-1]