feat: Cohere LLM - adjust token counting meta to match OpenAI format (#1086)

* Cohere - adjust token counting in meta

* Update integration test

* Lint
vblagoje authored Sep 16, 2024
1 parent b47583f commit 26bb328
Showing 2 changed files with 12 additions and 3 deletions.
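
In short, the commit replaces the single combined token count stored under meta["usage"] with an OpenAI-style dictionary. A minimal sketch of the before/after shape, with illustrative values and model name (only the key names come from the diff below):

# Before this commit: "usage" held one int, the sum of input and output tokens.
meta_before = {"model": "command-r", "usage": 57}

# After this commit: "usage" is a dict that mirrors the OpenAI usage format.
meta_after = {
    "model": "command-r",
    "usage": {"prompt_tokens": 42, "completion_tokens": 15},
}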
@@ -178,7 +178,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
         if finish_response.meta.billed_units:
             tokens_in = finish_response.meta.billed_units.input_tokens or -1
             tokens_out = finish_response.meta.billed_units.output_tokens or -1
-            chat_message.meta["usage"] = tokens_in + tokens_out
+            chat_message.meta["usage"] = {"prompt_tokens": tokens_in, "completion_tokens": tokens_out}
         chat_message.meta.update(
             {
                 "model": self.model,
@@ -220,11 +220,13 @@ def _build_message(self, cohere_response):
             message = ChatMessage.from_assistant(cohere_response.tool_calls[0].json())
         elif cohere_response.text:
             message = ChatMessage.from_assistant(content=cohere_response.text)
-        total_tokens = cohere_response.meta.billed_units.input_tokens + cohere_response.meta.billed_units.output_tokens
         message.meta.update(
             {
                 "model": self.model,
-                "usage": total_tokens,
+                "usage": {
+                    "prompt_tokens": cohere_response.meta.billed_units.input_tokens,
+                    "completion_tokens": cohere_response.meta.billed_units.output_tokens,
+                },
                 "index": 0,
                 "finish_reason": cohere_response.finish_reason,
                 "documents": cohere_response.documents,
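With the change above, callers read per-direction token counts from a reply's meta instead of one summed value. A minimal sketch of the access pattern (the results variable stands in for the output of the generator's run(...), as exercised in the tests below; only the meta keys are taken from the diff):

reply = results["replies"][0]                     # a ChatMessage produced by the generator
usage = reply.meta["usage"]                       # OpenAI-style usage dict after this commit
prompt_tokens = usage["prompt_tokens"]            # tokens billed for the input
completion_tokens = usage["completion_tokens"]    # tokens billed for the output
total_tokens = prompt_tokens + completion_tokens  # the single value "usage" used to hold
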
7 changes: 7 additions & 0 deletions integrations/cohere/tests/test_cohere_chat_generator.py
@@ -169,6 +169,9 @@ def test_live_run(self):
         assert len(results["replies"]) == 1
         message: ChatMessage = results["replies"][0]
         assert "Paris" in message.content
+        assert "usage" in message.meta
+        assert "prompt_tokens" in message.meta["usage"]
+        assert "completion_tokens" in message.meta["usage"]
 
     @pytest.mark.skipif(
         not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),
@@ -210,6 +213,10 @@ def __call__(self, chunk: StreamingChunk) -> None:
         assert callback.counter > 1
         assert "Paris" in callback.responses
 
+        assert "usage" in message.meta
+        assert "prompt_tokens" in message.meta["usage"]
+        assert "completion_tokens" in message.meta["usage"]
+
     @pytest.mark.skipif(
         not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),
         reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.",
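The new assertions above run against the live Cohere API. As an offline illustration of the same contract, the shape check can be expressed as a small helper (the helper name and sample values are hypothetical, not part of the commit):

def check_usage_shape(meta: dict) -> None:
    # Mirrors the assertions added in the integration tests above.
    assert "usage" in meta
    assert "prompt_tokens" in meta["usage"]
    assert "completion_tokens" in meta["usage"]

check_usage_shape({"usage": {"prompt_tokens": 12, "completion_tokens": 3}})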
