From 26bb3288e4c4934544bec9c2e55b292d8e113718 Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Mon, 16 Sep 2024 16:22:11 +0200
Subject: [PATCH] feat: Cohere LLM - adjust token counting meta to match
 OpenAI format (#1086)

* Cohere - adjust token counting in meta

* Update integration test

* Lint
---
 .../components/generators/cohere/chat/chat_generator.py | 8 +++++---
 integrations/cohere/tests/test_cohere_chat_generator.py | 7 +++++++
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py
index 568a26979..e635e291c 100644
--- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py
+++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py
@@ -178,7 +178,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
             if finish_response.meta.billed_units:
                 tokens_in = finish_response.meta.billed_units.input_tokens or -1
                 tokens_out = finish_response.meta.billed_units.output_tokens or -1
-                chat_message.meta["usage"] = tokens_in + tokens_out
+                chat_message.meta["usage"] = {"prompt_tokens": tokens_in, "completion_tokens": tokens_out}
             chat_message.meta.update(
                 {
                     "model": self.model,
@@ -220,11 +220,13 @@ def _build_message(self, cohere_response):
             message = ChatMessage.from_assistant(cohere_response.tool_calls[0].json())
         elif cohere_response.text:
             message = ChatMessage.from_assistant(content=cohere_response.text)
-        total_tokens = cohere_response.meta.billed_units.input_tokens + cohere_response.meta.billed_units.output_tokens
         message.meta.update(
             {
                 "model": self.model,
-                "usage": total_tokens,
+                "usage": {
+                    "prompt_tokens": cohere_response.meta.billed_units.input_tokens,
+                    "completion_tokens": cohere_response.meta.billed_units.output_tokens,
+                },
                 "index": 0,
                 "finish_reason": cohere_response.finish_reason,
                 "documents": cohere_response.documents,
diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py
index 6521503f2..fe9b7f43e 100644
--- a/integrations/cohere/tests/test_cohere_chat_generator.py
+++ b/integrations/cohere/tests/test_cohere_chat_generator.py
@@ -169,6 +169,9 @@ def test_live_run(self):
         assert len(results["replies"]) == 1
         message: ChatMessage = results["replies"][0]
         assert "Paris" in message.content
+        assert "usage" in message.meta
+        assert "prompt_tokens" in message.meta["usage"]
+        assert "completion_tokens" in message.meta["usage"]
 
     @pytest.mark.skipif(
         not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),
@@ -210,6 +213,10 @@ def __call__(self, chunk: StreamingChunk) -> None:
         assert callback.counter > 1
         assert "Paris" in callback.responses
 
+        assert "usage" in message.meta
+        assert "prompt_tokens" in message.meta["usage"]
+        assert "completion_tokens" in message.meta["usage"]
+
     @pytest.mark.skipif(
         not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),
         reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.",
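
Note (not part of the patch): after this change, `meta["usage"]` on a Cohere reply is a dict with `prompt_tokens` and `completion_tokens`, mirroring OpenAI's usage format, instead of a single summed token count. Below is a minimal sketch of how a caller might read the new field; the prompt text is an illustrative assumption, and the generator is assumed to pick up COHERE_API_KEY/CO_API_KEY from the environment, as in the integration tests above.

    # Illustrative usage sketch, not part of the patch.
    from haystack.dataclasses import ChatMessage

    from haystack_integrations.components.generators.cohere import CohereChatGenerator

    # Assumes COHERE_API_KEY or CO_API_KEY is exported, as in the tests.
    generator = CohereChatGenerator()
    result = generator.run(messages=[ChatMessage.from_user("What's the capital of France?")])

    reply: ChatMessage = result["replies"][0]
    usage = reply.meta["usage"]  # now a dict, no longer a single int
    print(usage["prompt_tokens"], usage["completion_tokens"])

Keeping the same keys as OpenAI's usage payload lets downstream code read token counts from either generator without special-casing Cohere.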