From 31f7b087e1f644d120fcda27424b363eb49e1130 Mon Sep 17 00:00:00 2001 From: vignesh14052002 Date: Sat, 5 Oct 2024 17:02:49 +0530 Subject: [PATCH 1/3] Add prompt cache and reasoning token tracking --- .../callbacks/openai_info.py | 58 ++++++++++++++++--- .../unit_tests/callbacks/test_openai_info.py | 40 ++++++++++++- 2 files changed, 90 insertions(+), 8 deletions(-) diff --git a/libs/community/langchain_community/callbacks/openai_info.py b/libs/community/langchain_community/callbacks/openai_info.py index 7fbb0d097e1b9..9250ae9c80149 100644 --- a/libs/community/langchain_community/callbacks/openai_info.py +++ b/libs/community/langchain_community/callbacks/openai_info.py @@ -1,6 +1,7 @@ """Callback Handler that prints to std out.""" import threading +from enum import Enum, auto from typing import Any, Dict, List from langchain_core.callbacks import BaseCallbackHandler @@ -10,26 +11,34 @@ MODEL_COST_PER_1K_TOKENS = { # OpenAI o1-preview input "o1-preview": 0.015, + "o1-preview-cached": 0.0075, "o1-preview-2024-09-12": 0.015, + "o1-preview-2024-09-12-cached": 0.0075, # OpenAI o1-preview output "o1-preview-completion": 0.06, "o1-preview-2024-09-12-completion": 0.06, # OpenAI o1-mini input "o1-mini": 0.003, + "o1-mini-cached": 0.0015, "o1-mini-2024-09-12": 0.003, + "o1-mini-2024-09-12-cached": 0.0015, # OpenAI o1-mini output "o1-mini-completion": 0.012, "o1-mini-2024-09-12-completion": 0.012, # GPT-4o-mini input "gpt-4o-mini": 0.00015, + "gpt-4o-mini-cached": 0.000075, "gpt-4o-mini-2024-07-18": 0.00015, + "gpt-4o-mini-2024-07-18-cached": 0.000075, # GPT-4o-mini output "gpt-4o-mini-completion": 0.0006, "gpt-4o-mini-2024-07-18-completion": 0.0006, # GPT-4o input "gpt-4o": 0.005, + "gpt-4o-cached": 0.00125, "gpt-4o-2024-05-13": 0.005, "gpt-4o-2024-08-06": 0.0025, + "gpt-4o-2024-08-06-cached": 0.00125, # GPT-4o output "gpt-4o-completion": 0.015, "gpt-4o-2024-05-13-completion": 0.015, @@ -138,9 +147,17 @@ } +class TokenType(Enum): + """Token type enum.""" + + PROMPT = auto() + PROMPT_CACHED = auto() + COMPLETION = auto() + + def standardize_model_name( model_name: str, - is_completion: bool = False, + token_type: TokenType = TokenType.PROMPT, ) -> str: """ Standardize the model name to a format that can be used in the OpenAI API. @@ -161,7 +178,7 @@ def standardize_model_name( model_name = model_name.split(":")[0] + "-finetuned-legacy" if "ft:" in model_name: model_name = model_name.split(":")[1] + "-finetuned" - if is_completion and ( + if token_type == TokenType.COMPLETION and ( model_name.startswith("gpt-4") or model_name.startswith("gpt-3.5") or model_name.startswith("gpt-35") @@ -169,12 +186,16 @@ def standardize_model_name( or ("finetuned" in model_name and "legacy" not in model_name) ): return model_name + "-completion" + if token_type == TokenType.PROMPT_CACHED and ( + model_name.startswith("gpt-4o") or model_name.startswith("o1") + ): + return model_name + "-cached" else: return model_name def get_openai_token_cost_for_model( - model_name: str, num_tokens: int, is_completion: bool = False + model_name: str, num_tokens: int, token_type: TokenType = TokenType.PROMPT ) -> float: """ Get the cost in USD for a given model and number of tokens. @@ -188,7 +209,7 @@ def get_openai_token_cost_for_model( Returns: Cost in USD. """ - model_name = standardize_model_name(model_name, is_completion=is_completion) + model_name = standardize_model_name(model_name, token_type=token_type) if model_name not in MODEL_COST_PER_1K_TOKENS: raise ValueError( f"Unknown model: {model_name}. Please provide a valid OpenAI model name." @@ -202,7 +223,9 @@ class OpenAICallbackHandler(BaseCallbackHandler): total_tokens: int = 0 prompt_tokens: int = 0 + prompt_tokens_cached: int = 0 completion_tokens: int = 0 + reasoning_tokens: int = 0 successful_requests: int = 0 total_cost: float = 0.0 @@ -214,7 +237,9 @@ def __repr__(self) -> str: return ( f"Tokens Used: {self.total_tokens}\n" f"\tPrompt Tokens: {self.prompt_tokens}\n" + f"\t\tPrompt Tokens Cached: {self.prompt_tokens_cached}\n" f"\tCompletion Tokens: {self.completion_tokens}\n" + f"\t\tReasoning Tokens: {self.reasoning_tokens}\n" f"Successful Requests: {self.successful_requests}\n" f"Total Cost (USD): ${self.total_cost}" ) @@ -256,6 +281,10 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: else: usage_metadata = None response_metadata = None + + prompt_tokens_cached = 0 + reasoning_tokens = 0 + if usage_metadata: token_usage = {"total_tokens": usage_metadata["total_tokens"]} completion_tokens = usage_metadata["output_tokens"] @@ -268,7 +297,12 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: model_name = standardize_model_name( response.llm_output.get("model_name", "") ) - + match usage_metadata: + case {"input_token_details": {"cache_read": cached_tokens}}: + prompt_tokens_cached = cached_tokens + match usage_metadata: + case {"output_token_details": {"reasoning": reasoning}}: + reasoning_tokens = reasoning else: if response.llm_output is None: return None @@ -285,11 +319,19 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: model_name = standardize_model_name( response.llm_output.get("model_name", "") ) + if model_name in MODEL_COST_PER_1K_TOKENS: + uncached_prompt_tokens = prompt_tokens - prompt_tokens_cached + uncached_prompt_cost = get_openai_token_cost_for_model( + model_name, uncached_prompt_tokens, token_type=TokenType.PROMPT + ) + cached_prompt_cost = get_openai_token_cost_for_model( + model_name, prompt_tokens_cached, token_type=TokenType.PROMPT_CACHED + ) + prompt_cost = uncached_prompt_cost + cached_prompt_cost completion_cost = get_openai_token_cost_for_model( - model_name, completion_tokens, is_completion=True + model_name, completion_tokens, token_type=TokenType.COMPLETION ) - prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens) else: completion_cost = 0 prompt_cost = 0 @@ -299,7 +341,9 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self.total_cost += prompt_cost + completion_cost self.total_tokens += token_usage.get("total_tokens", 0) self.prompt_tokens += prompt_tokens + self.prompt_tokens_cached += prompt_tokens_cached self.completion_tokens += completion_tokens + self.reasoning_tokens += reasoning_tokens self.successful_requests += 1 def __copy__(self) -> "OpenAICallbackHandler": diff --git a/libs/community/tests/unit_tests/callbacks/test_openai_info.py b/libs/community/tests/unit_tests/callbacks/test_openai_info.py index 48ab5fd1a9a98..09b838b0d9d7b 100644 --- a/libs/community/tests/unit_tests/callbacks/test_openai_info.py +++ b/libs/community/tests/unit_tests/callbacks/test_openai_info.py @@ -3,7 +3,8 @@ import numpy as np import pytest -from langchain_core.outputs import LLMResult +from langchain_core.messages import AIMessage +from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.utils.pydantic import get_fields from langchain_community.callbacks import OpenAICallbackHandler @@ -35,6 +36,43 @@ def test_on_llm_end(handler: OpenAICallbackHandler) -> None: assert handler.total_cost > 0 +def test_on_llm_end_with_chat_generation(handler: OpenAICallbackHandler) -> None: + response = LLMResult( + generations=[ + [ + ChatGeneration( + text="Hello, world!", + message=AIMessage( + content="Hello, world!", + usage_metadata={ + "input_tokens": 2, + "output_tokens": 2, + "total_tokens": 4, + "input_token_details": { + "cache_read": 1, + }, + "output_token_details": { + "reasoning": 1, + }, + }, + ), + ) + ] + ], + llm_output={ + "model_name": get_fields(BaseOpenAI)["model_name"].default, + }, + ) + handler.on_llm_end(response) + assert handler.successful_requests == 1 + assert handler.total_tokens == 4 + assert handler.prompt_tokens == 2 + assert handler.prompt_tokens_cached == 1 + assert handler.completion_tokens == 2 + assert handler.reasoning_tokens == 1 + assert handler.total_cost > 0 + + def test_on_llm_end_custom_model(handler: OpenAICallbackHandler) -> None: response = LLMResult( generations=[], From 878131b5ea8244fdcb8f85789f05bd2992d06c6c Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Wed, 18 Dec 2024 11:13:26 -0500 Subject: [PATCH 2/3] keep is_completion and add warning --- .../callbacks/openai_info.py | 37 +++++++++++++++++-- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/callbacks/openai_info.py b/libs/community/langchain_community/callbacks/openai_info.py index e86ba32d6e7c1..8a194d7b42b7c 100644 --- a/libs/community/langchain_community/callbacks/openai_info.py +++ b/libs/community/langchain_community/callbacks/openai_info.py @@ -4,6 +4,7 @@ from enum import Enum, auto from typing import Any, Dict, List +from langchain_core._api import warn_deprecated from langchain_core.callbacks import BaseCallbackHandler from langchain_core.messages import AIMessage from langchain_core.outputs import ChatGeneration, LLMResult @@ -159,6 +160,8 @@ class TokenType(Enum): def standardize_model_name( model_name: str, + is_completion: bool = False, + *, token_type: TokenType = TokenType.PROMPT, ) -> str: """ @@ -167,12 +170,24 @@ def standardize_model_name( Args: model_name: Model name to standardize. is_completion: Whether the model is used for completion or not. - Defaults to False. + Defaults to False. Deprecated in favor of ``token_type``. + token_type: Token type. Defaults to ``TokenType.PROMPT``. Returns: Standardized model name. """ + if is_completion: + warn_deprecated( + since="0.3.13", + message=( + "is_completion is deprecated. Use token_type instead. Example:\n\n" + "from langchain_community.callbacks.openai_info import TokenType\n\n" + "standardize_model_name('gpt-4o', token_type=TokenType.COMPLETION)\n" + ), + removal="1.0", + ) + token_type = TokenType.COMPLETION model_name = model_name.lower() if ".ft-" in model_name: model_name = model_name.split(".ft-")[0] + "-azure-finetuned" @@ -197,7 +212,11 @@ def standardize_model_name( def get_openai_token_cost_for_model( - model_name: str, num_tokens: int, token_type: TokenType = TokenType.PROMPT + model_name: str, + num_tokens: int, + is_completion: bool = False, + *, + token_type: TokenType = TokenType.PROMPT, ) -> float: """ Get the cost in USD for a given model and number of tokens. @@ -206,11 +225,23 @@ def get_openai_token_cost_for_model( model_name: Name of the model num_tokens: Number of tokens. is_completion: Whether the model is used for completion or not. - Defaults to False. + Defaults to False. Deprecated in favor of ``token_type``. + token_type: Token type. Defaults to ``TokenType.PROMPT``. Returns: Cost in USD. """ + if is_completion: + warn_deprecated( + since="0.3.13", + message=( + "is_completion is deprecated. Use token_type instead. Example:\n\n" + "from langchain_community.callbacks.openai_info import TokenType\n\n" + "get_openai_token_cost_for_model('gpt-4o', 10, token_type=TokenType.COMPLETION)\n" # noqa: E501 + ), + removal="1.0", + ) + token_type = TokenType.COMPLETION model_name = standardize_model_name(model_name, token_type=token_type) if model_name not in MODEL_COST_PER_1K_TOKENS: raise ValueError( From cd67cf49d2d217bfee181a2dc55d5e0886f96c46 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Wed, 18 Dec 2024 11:24:16 -0500 Subject: [PATCH 3/3] python 3.9 compatibility --- .../langchain_community/callbacks/openai_info.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/libs/community/langchain_community/callbacks/openai_info.py b/libs/community/langchain_community/callbacks/openai_info.py index 8a194d7b42b7c..b0688de5a9530 100644 --- a/libs/community/langchain_community/callbacks/openai_info.py +++ b/libs/community/langchain_community/callbacks/openai_info.py @@ -330,12 +330,12 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: model_name = standardize_model_name( response.llm_output.get("model_name", "") ) - match usage_metadata: - case {"input_token_details": {"cache_read": cached_tokens}}: - prompt_tokens_cached = cached_tokens - match usage_metadata: - case {"output_token_details": {"reasoning": reasoning}}: - reasoning_tokens = reasoning + if "cache_read" in usage_metadata.get("input_token_details", {}): + prompt_tokens_cached = usage_metadata["input_token_details"][ + "cache_read" + ] + if "reasoning" in usage_metadata.get("output_token_details", {}): + reasoning_tokens = usage_metadata["output_token_details"]["reasoning"] else: if response.llm_output is None: return None