Skip to content

Commit

Permalink
TextLLM: Add token usage to response meta (#338)
Browse files Browse the repository at this point in the history
  • Loading branch information
FelixTJDietrich authored Sep 27, 2024
1 parent d9ff3bd commit 5c2eae4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from langchain.callbacks.base import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.messages.ai import UsageMetadata

from athena import emit_meta, get_meta


class UsageHandler(BaseCallbackHandler):
    """Callback handler that records LLM token usage in the request meta.

    After every LLM call it reads the current meta via ``get_meta()``, adds
    the token counts of each generation to the running ``"total_usage"``
    entry, appends one record per generation to ``"llm_calls"``, and writes
    both entries back with ``emit_meta``.
    """

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        """Accumulate the token usage reported in ``response``.

        Generations that carry no message or no usage metadata are skipped
        instead of raising, so a model that does not report usage cannot
        break the surrounding LLM call.

        Args:
            response: Result of the finished LLM call.
            **kwargs: Extra callback arguments; unused.
        """
        token_keys = ("input_tokens", "output_tokens", "total_tokens")

        meta = get_meta()
        total_usage = meta.get("total_usage", {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
        llm_calls = meta.get("llm_calls", [])

        for generations in response.generations:
            for generation in generations:
                # Only generations that wrap a message expose usage data; a
                # serialized generation without one has no "message" entry.
                message = generation.dict().get("message")
                if not message:
                    continue

                usage = message.get("usage_metadata")
                if not usage:
                    # No usage reported for this generation: nothing to count.
                    continue
                generation_usage: UsageMetadata = usage

                response_metadata = message.get("response_metadata") or {}
                model_name = response_metadata.get("model_name", None)

                # Treat a missing or null counter as zero so the sums stay numeric.
                counts = {key: generation_usage.get(key) or 0 for key in token_keys}
                for key, value in counts.items():
                    total_usage[key] = total_usage.get(key, 0) + value

                llm_calls.append({"model_name": model_name, **counts})

        emit_meta("total_usage", total_usage)
        emit_meta("llm_calls", llm_calls)
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from athena.logger import logger
from .model_config import ModelConfig
from .callbacks import UsageHandler


# String prefixes for the OpenAI and Azure OpenAI variants.
# NOTE(review): where these are consumed is not visible in this excerpt;
# presumably they namespace provider-specific names — confirm against usage.
OPENAI_PREFIX = "openai_"
AZURE_OPENAI_PREFIX = "azure_openai_"
Expand Down Expand Up @@ -132,6 +134,7 @@ def get_model(self) -> BaseLanguageModel:
# Otherwise, add it to model_kwargs (necessary for chat models)
model_kwargs[attr] = value
kwargs["model_kwargs"] = model_kwargs
kwargs["callbacks"] = [UsageHandler()]

# Initialize a copy of the model using the config
model = model.__class__(**kwargs)
Expand Down

0 comments on commit 5c2eae4

Please sign in to comment.