diff --git a/modules/text/module_text_llm/module_text_llm/helpers/models/callbacks.py b/modules/text/module_text_llm/module_text_llm/helpers/models/callbacks.py
new file mode 100644
index 00000000..d33b0a6b
--- /dev/null
+++ b/modules/text/module_text_llm/module_text_llm/helpers/models/callbacks.py
@@ -0,0 +1,50 @@
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+from langchain_core.messages.ai import UsageMetadata
+
+from athena import emit_meta, get_meta
+
+
+class UsageHandler(BaseCallbackHandler):
+    """Record the token usage of finished LLM calls via athena's meta helpers.
+
+    Adds the reported token counts to the ``total_usage`` meta entry and
+    appends one record per generation to the ``llm_calls`` meta entry.
+    """
+
+    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
+        """Accumulate the token usage reported in ``response``.
+
+        Generations without a message, or whose message has no usage
+        metadata, are skipped instead of raising ``KeyError``/``TypeError``.
+
+        Args:
+            response: Result of the finished LLM call.
+            **kwargs: Further callback arguments; unused.
+        """
+        meta = get_meta()
+
+        total_usage = meta.get("total_usage", {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
+        llm_calls = meta.get("llm_calls", [])
+
+        for generations in response.generations:
+            for generation in generations:
+                # Plain (non-chat) generations have no "message" entry, and
+                # usage_metadata may be None when no usage was reported.
+                message = generation.dict().get("message") or {}
+                generation_usage: UsageMetadata = message.get("usage_metadata")
+                if not generation_usage:
+                    continue
+                response_metadata = message.get("response_metadata") or {}
+                model_name = response_metadata.get("model_name", None)
+
+                # Build the per-call record while adding to the running totals.
+                llm_call = {"model_name": model_name}
+                for key in ("input_tokens", "output_tokens", "total_tokens"):
+                    tokens = generation_usage.get(key, 0)
+                    total_usage[key] += tokens
+                    llm_call[key] = tokens
+                llm_calls.append(llm_call)
+
+        emit_meta("total_usage", total_usage)
+        emit_meta("llm_calls", llm_calls)
diff --git a/modules/text/module_text_llm/module_text_llm/helpers/models/openai.py b/modules/text/module_text_llm/module_text_llm/helpers/models/openai.py
index 23e1669c..d12df22d 100644
--- a/modules/text/module_text_llm/module_text_llm/helpers/models/openai.py
+++ b/modules/text/module_text_llm/module_text_llm/helpers/models/openai.py
@@ -10,6 +10,8 @@
 from athena.logger import logger
 from .model_config import ModelConfig
 
+from .callbacks import UsageHandler
+
OPENAI_PREFIX = "openai_" AZURE_OPENAI_PREFIX = "azure_openai_" @@ -132,6 +134,7 @@ def get_model(self) -> BaseLanguageModel: # Otherwise, add it to model_kwargs (necessary for chat models) model_kwargs[attr] = value kwargs["model_kwargs"] = model_kwargs + kwargs["callbacks"] = [UsageHandler()] # Initialize a copy of the model using the config model = model.__class__(**kwargs)