[draft]: openai token counting

agola11 committed Oct 7, 2024
1 parent da3c1bb commit 360d77b

Showing 4 changed files with 233 additions and 2 deletions.

4 changes: 3 additions & 1 deletion python/docs/create_api_rst.py
@@ -105,7 +105,9 @@ def _load_module_members(module_path: str, namespace: str) -> ModuleMembers:
else (
"enum"
if issubclass(type_, Enum)
else "Pydantic" if issubclass(type_, BaseModel) else "Regular"
else "Pydantic"
if issubclass(type_, BaseModel)
else "Regular"
)
)
if hasattr(type_, "__slots__"):
63 changes: 62 additions & 1 deletion python/langsmith/schemas.py
@@ -17,7 +17,7 @@
)
from uuid import UUID

-from typing_extensions import TypedDict
+from typing_extensions import NotRequired, TypedDict

try:
from pydantic.v1 import ( # type: ignore[import]
@@ -891,3 +891,64 @@ class PromptSortField(str, Enum):
"""Last updated time."""
num_likes = "num_likes"
"""Number of likes."""


class InputTokenDetails(TypedDict, total=False):
"""Breakdown of input token counts.
Does *not* need to sum to full input token count. Does *not* need to have all keys.
"""

audio: int
"""Audio input tokens."""
cache_creation: int
"""Input tokens that were cached and there was a cache miss.
Since there was a cache miss, the cache was created from these tokens.
"""
cache_read: int
"""Input tokens that were cached and there was a cache hit.
Since there was a cache hit, the tokens were read from the cache. More precisely,
the model state given these tokens was read from the cache.
"""


class OutputTokenDetails(TypedDict, total=False):
"""Breakdown of output token counts.
Does *not* need to sum to full output token count. Does *not* need to have all keys.
"""

audio: int
"""Audio output tokens."""
reasoning: int
"""Reasoning output tokens.
    Tokens generated by the model in a chain-of-thought process (e.g., by OpenAI's o1
models) that are not returned as part of model output.
"""


class UsageMetadata(TypedDict):
"""Usage metadata for a message, such as token counts.
This is a standard representation of token usage that is consistent across models.
"""

input_tokens: int
"""Count of input (or prompt) tokens. Sum of all input token types."""
output_tokens: int
"""Count of output (or completion) tokens. Sum of all output token types."""
total_tokens: int
"""Total token count. Sum of input_tokens + output_tokens."""
input_token_details: NotRequired[InputTokenDetails]
"""Breakdown of input token counts.
Does *not* need to sum to full input token count. Does *not* need to have all keys.
"""
output_token_details: NotRequired[OutputTokenDetails]
"""Breakdown of output token counts.
Does *not* need to sum to full output token count. Does *not* need to have all keys.
"""
60 changes: 60 additions & 0 deletions python/langsmith/wrappers/_openai.py
@@ -21,6 +21,7 @@

from langsmith import client as ls_client
from langsmith import run_helpers
from langsmith.schemas import InputTokenDetails, OutputTokenDetails, UsageMetadata

if TYPE_CHECKING:
from openai import AsyncOpenAI, OpenAI
@@ -141,6 +142,12 @@ def _reduce_chat(all_chunks: List[ChatCompletionChunk]) -> dict:
]
else:
d = {"choices": [{"message": {"role": "assistant", "content": ""}}]}
# streamed outputs don't go through `process_outputs`
# so we need to flatten metadata here
    oai_token_usage = d.pop("usage", None)
d["usage_metadata"] = (
_create_usage_metadata(oai_token_usage) if oai_token_usage else None
)
return d
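
For streamed calls, OpenAI only reports token usage when the caller opts in, so the flattening above only finds a usage entry if the request enabled it. A minimal sketch, assuming a wrapped or plain OpenAI client named client:

# Usage is only attached to streamed responses when explicitly requested;
# the final chunk then carries the `usage` payload that `_reduce_chat` flattens.
stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "howdy"}],
    stream=True,
    stream_options={"include_usage": True},
)
for _ in stream:
    pass  # consume the stream; the last chunk has `usage` populated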


@@ -160,12 +167,62 @@ def _reduce_completions(all_chunks: List[Completion]) -> dict:
return d


def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata:
input_tokens = oai_token_usage.get("prompt_tokens", 0)
output_tokens = oai_token_usage.get("completion_tokens", 0)
total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens)
input_token_details: dict = {
"audio": (oai_token_usage.get("prompt_tokens_details") or {}).get(
"audio_tokens"
),
"cache_read": (oai_token_usage.get("prompt_tokens_details") or {}).get(
"cached_tokens"
),
}
output_token_details: dict = {
"audio": (oai_token_usage.get("completion_tokens_details") or {}).get(
"audio_tokens"
),
"reasoning": (oai_token_usage.get("completion_tokens_details") or {}).get(
"reasoning_tokens"
),
}
return UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
input_token_details=InputTokenDetails(
**{k: v for k, v in input_token_details.items() if v is not None}
),
output_token_details=OutputTokenDetails(
**{k: v for k, v in output_token_details.items() if v is not None}
),
)
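
A minimal sketch of the conversion above, using a usage payload shaped like OpenAI's chat completions response (the counts are illustrative only):

raw_usage = {
    "prompt_tokens": 31,
    "completion_tokens": 448,
    "total_tokens": 479,
    "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0},
    "completion_tokens_details": {"reasoning_tokens": 384, "audio_tokens": 0},
}
usage_metadata = _create_usage_metadata(raw_usage)
# -> {"input_tokens": 31, "output_tokens": 448, "total_tokens": 479,
#     "input_token_details": {"audio": 0, "cache_read": 0},
#     "output_token_details": {"audio": 0, "reasoning": 384}}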


def _process_chat_completion(outputs: Any):
"""Process the outputs of the chat completion endpoint. Turn the OpenAI objects
into a dictionary and insert the usage_metadata.
"""
try:
rdict = outputs.model_dump()
oai_token_usage = rdict.pop("usage")
rdict["usage_metadata"] = (
_create_usage_metadata(oai_token_usage) if oai_token_usage else None
)
return rdict
except BaseException as e:
logger.debug(f"Error processing chat completion: {e}")
return {"output": outputs}


def _get_wrapper(
original_create: Callable,
name: str,
reduce_fn: Callable,
tracing_extra: Optional[TracingExtra] = None,
invocation_params_fn: Optional[Callable] = None,
process_outputs: Optional[Callable] = None,
) -> Callable:
textra = tracing_extra or {}

@@ -177,6 +234,7 @@ def create(*args, stream: bool = False, **kwargs):
reduce_fn=reduce_fn if stream else None,
process_inputs=_strip_not_given,
_invocation_params_fn=invocation_params_fn,
process_outputs=process_outputs,
**textra,
)

@@ -191,6 +249,7 @@ async def acreate(*args, stream: bool = False, **kwargs):
reduce_fn=reduce_fn if stream else None,
process_inputs=_strip_not_given,
_invocation_params_fn=invocation_params_fn,
process_outputs=process_outputs,
**textra,
)
return await decorator(original_create)(*args, stream=stream, **kwargs)
@@ -232,6 +291,7 @@ def wrap_openai(
_reduce_chat,
tracing_extra=tracing_extra,
invocation_params_fn=functools.partial(_infer_invocation_params, "chat"),
process_outputs=_process_chat_completion,
)
client.completions.create = _get_wrapper( # type: ignore[method-assign]
client.completions.create,
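
With the wiring above, traced chat calls carry standardized token counts. A minimal usage sketch (assuming an OpenAI API key is configured and LangSmith tracing is enabled):

import openai
from langsmith.wrappers import wrap_openai

client = wrap_openai(openai.Client())

# Non-streamed responses are flattened by `_process_chat_completion`, so the
# traced run's outputs include `usage_metadata` alongside the model response.
client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "howdy"}],
)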
108 changes: 108 additions & 0 deletions python/tests/integration_tests/wrappers/test_openai.py
@@ -1,6 +1,8 @@
# mypy: disable-error-code="attr-defined, union-attr, arg-type, call-overload"
import time
from datetime import datetime
from unittest import mock
from uuid import uuid4

import pytest

@@ -180,3 +182,109 @@ async def test_completions_async_api(mock_session: mock.MagicMock, stream: bool)
assert mock_session.return_value.request.call_count >= 1
for call in mock_session.return_value.request.call_args_list[1:]:
assert call[0][0].upper() == "POST"


class Collect:
"""
Collects the runs for inspection.
"""

def __init__(self):
self.run = None

def __call__(self, run):
self.run = run


def test_wrap_openai_token_counts():
import openai

oai_client = openai.Client()

wrapped_oai_client = wrap_openai(oai_client)

project_name = f"__test_wrap_openai_{datetime.now().isoformat()}_{uuid4().hex[:6]}"
ls_client = langsmith.Client()

collect = Collect()
try:
run_id_to_usage_metadata = {}
with langsmith.tracing_context(
enabled=True, project_name=project_name, client=ls_client
):
# stream usage
res = wrapped_oai_client.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "howdy"}],
langsmith_extra={"on_end": collect},
stream=True,
stream_options={"include_usage": True},
)

for _ in res:
# consume the stream
pass

run_id_to_usage_metadata[collect.run.id] = collect.run.outputs[
"usage_metadata"
]

# stream without usage
res = wrapped_oai_client.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "howdy"}],
langsmith_extra={"on_end": collect},
stream=True,
)

for _ in res:
# consume the stream
pass

assert collect.run.outputs.get("usage_metadata") is None
assert collect.run.outputs.get("usage") is None

wrapped_oai_client.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "howdy"}],
langsmith_extra={"on_end": collect},
)

run_id_to_usage_metadata[collect.run.id] = collect.run.outputs[
"usage_metadata"
]

wrapped_oai_client.chat.completions.create(
model="o1-mini",
messages=[
{
"role": "user",
"content": "Write a bash script that takes a matrix represented as a string with format '[1,2],[3,4],[5,6]' and prints the transpose in the same format.",
}
],
langsmith_extra={"on_end": collect},
)

run_id_to_usage_metadata[collect.run.id] = collect.run.outputs[
"usage_metadata"
]

        # flush any pending runs before reading them back
        ls_client.flush()
        runs = list(ls_client.list_runs(project_name=project_name))
        assert len(runs) == 4
        for run in runs:
            # the stream-without-usage call intentionally has no usage metadata
            if run.id in run_id_to_usage_metadata:
                assert run.outputs.get("usage_metadata") is not None

finally:
ls_client.delete_project(project_name=project_name)
