From e8e5d67a8d8839c96dc54552b5ff007b95992345 Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Thu, 3 Oct 2024 18:25:38 -0700 Subject: [PATCH] openai: fix None token detail (#27091) happens in Azure --- .../langchain_openai/chat_models/base.py | 10 ++++++---- .../tests/unit_tests/chat_models/test_base.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 27d4adf06817a..4e91fab34a8d5 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -2161,16 +2161,18 @@ def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata: output_tokens = oai_token_usage.get("completion_tokens", 0) total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens) input_token_details: dict = { - "audio": oai_token_usage.get("prompt_tokens_details", {}).get("audio_tokens"), - "cache_read": oai_token_usage.get("prompt_tokens_details", {}).get( + "audio": (oai_token_usage.get("prompt_tokens_details") or {}).get( + "audio_tokens" + ), + "cache_read": (oai_token_usage.get("prompt_tokens_details") or {}).get( "cached_tokens" ), } output_token_details: dict = { - "audio": oai_token_usage.get("completion_tokens_details", {}).get( + "audio": (oai_token_usage.get("completion_tokens_details") or {}).get( "audio_tokens" ), - "reasoning": oai_token_usage.get("completion_tokens_details", {}).get( + "reasoning": (oai_token_usage.get("completion_tokens_details") or {}).get( "reasoning_tokens" ), } diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index d205276704016..cc03698e5ef93 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -23,6 +23,7 @@ from langchain_openai.chat_models.base import ( _convert_dict_to_message, _convert_message_to_dict, + _create_usage_metadata, _format_message_content, ) @@ -730,3 +731,21 @@ def test_schema_from_with_structured_output(schema: Type) -> None: } actual = structured_llm.get_output_schema().model_json_schema() assert actual == expected + + +def test__create_usage_metadata() -> None: + usage_metadata = { + "completion_tokens": 15, + "prompt_tokens_details": None, + "completion_tokens_details": None, + "prompt_tokens": 11, + "total_tokens": 26, + } + result = _create_usage_metadata(usage_metadata) + assert result == UsageMetadata( + output_tokens=15, + input_tokens=11, + total_tokens=26, + input_token_details={}, + output_token_details={}, + )