diff --git a/python/docs/create_api_rst.py b/python/docs/create_api_rst.py
index 253352767..7a2f75260 100644
--- a/python/docs/create_api_rst.py
+++ b/python/docs/create_api_rst.py
@@ -105,7 +105,9 @@ def _load_module_members(module_path: str, namespace: str) -> ModuleMembers:
                 else (
                     "enum"
                     if issubclass(type_, Enum)
-                    else "Pydantic" if issubclass(type_, BaseModel) else "Regular"
+                    else "Pydantic"
+                    if issubclass(type_, BaseModel)
+                    else "Regular"
                 )
             )
             if hasattr(type_, "__slots__"):
diff --git a/python/langsmith/schemas.py b/python/langsmith/schemas.py
index d859418e7..8ad12b3d0 100644
--- a/python/langsmith/schemas.py
+++ b/python/langsmith/schemas.py
@@ -17,7 +17,7 @@
 )
 from uuid import UUID

-from typing_extensions import TypedDict
+from typing_extensions import NotRequired, TypedDict

 try:
     from pydantic.v1 import (  # type: ignore[import]
@@ -891,3 +891,64 @@ class PromptSortField(str, Enum):
     """Last updated time."""
     num_likes = "num_likes"
     """Number of likes."""
+
+
+class InputTokenDetails(TypedDict, total=False):
+    """Breakdown of input token counts.
+
+    Does *not* need to sum to full input token count. Does *not* need to have all keys.
+    """
+
+    audio: int
+    """Audio input tokens."""
+    cache_creation: int
+    """Input tokens that were cached and there was a cache miss.
+
+    Since there was a cache miss, the cache was created from these tokens.
+    """
+    cache_read: int
+    """Input tokens that were cached and there was a cache hit.
+
+    Since there was a cache hit, the tokens were read from the cache. More precisely,
+    the model state given these tokens was read from the cache.
+    """
+
+
+class OutputTokenDetails(TypedDict, total=False):
+    """Breakdown of output token counts.
+
+    Does *not* need to sum to full output token count. Does *not* need to have all keys.
+    """
+
+    audio: int
+    """Audio output tokens."""
+    reasoning: int
+    """Reasoning output tokens.
+
+    Tokens generated by the model in a chain of thought process (i.e. by OpenAI's o1
+    models) that are not returned as part of model output.
+    """
+
+
+class UsageMetadata(TypedDict):
+    """Usage metadata for a message, such as token counts.
+
+    This is a standard representation of token usage that is consistent across models.
+    """
+
+    input_tokens: int
+    """Count of input (or prompt) tokens. Sum of all input token types."""
+    output_tokens: int
+    """Count of output (or completion) tokens. Sum of all output token types."""
+    total_tokens: int
+    """Total token count. Sum of input_tokens + output_tokens."""
+    input_token_details: NotRequired[InputTokenDetails]
+    """Breakdown of input token counts.
+
+    Does *not* need to sum to full input token count. Does *not* need to have all keys.
+    """
+    output_token_details: NotRequired[OutputTokenDetails]
+    """Breakdown of output token counts.
+
+    Does *not* need to sum to full output token count. Does *not* need to have all keys.
+    """
diff --git a/python/langsmith/wrappers/_openai.py b/python/langsmith/wrappers/_openai.py
index 014d364cd..c6d8184e4 100644
--- a/python/langsmith/wrappers/_openai.py
+++ b/python/langsmith/wrappers/_openai.py
@@ -21,6 +21,7 @@

 from langsmith import client as ls_client
 from langsmith import run_helpers
+from langsmith.schemas import InputTokenDetails, OutputTokenDetails, UsageMetadata

 if TYPE_CHECKING:
     from openai import AsyncOpenAI, OpenAI
@@ -141,6 +142,12 @@ def _reduce_chat(all_chunks: List[ChatCompletionChunk]) -> dict:
         ]
     else:
         d = {"choices": [{"message": {"role": "assistant", "content": ""}}]}
+    # streamed outputs don't go through `process_outputs`
+    # so we need to flatten metadata here
+    oai_token_usage = d.pop("usage", None)
+    d["usage_metadata"] = (
+        _create_usage_metadata(oai_token_usage) if oai_token_usage else None
+    )
     return d


@@ -160,12 +167,62 @@ def _reduce_completions(all_chunks: List[Completion]) -> dict:

     return d


+def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata:
+    input_tokens = oai_token_usage.get("prompt_tokens", 0)
+    output_tokens = oai_token_usage.get("completion_tokens", 0)
+    total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens)
+    input_token_details: dict = {
+        "audio": (oai_token_usage.get("prompt_tokens_details") or {}).get(
+            "audio_tokens"
+        ),
+        "cache_read": (oai_token_usage.get("prompt_tokens_details") or {}).get(
+            "cached_tokens"
+        ),
+    }
+    output_token_details: dict = {
+        "audio": (oai_token_usage.get("completion_tokens_details") or {}).get(
+            "audio_tokens"
+        ),
+        "reasoning": (oai_token_usage.get("completion_tokens_details") or {}).get(
+            "reasoning_tokens"
+        ),
+    }
+    return UsageMetadata(
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        total_tokens=total_tokens,
+        input_token_details=InputTokenDetails(
+            **{k: v for k, v in input_token_details.items() if v is not None}
+        ),
+        output_token_details=OutputTokenDetails(
+            **{k: v for k, v in output_token_details.items() if v is not None}
+        ),
+    )
+
+
+def _process_chat_completion(outputs: Any):
+    """Process the outputs of the chat completion endpoint. Turn the OpenAI objects
+    into a dictionary and insert the usage_metadata.
+    """
+    try:
+        rdict = outputs.model_dump()
+        oai_token_usage = rdict.pop("usage")
+        rdict["usage_metadata"] = (
+            _create_usage_metadata(oai_token_usage) if oai_token_usage else None
+        )
+        return rdict
+    except BaseException as e:
+        logger.debug(f"Error processing chat completion: {e}")
+        return {"output": outputs}
+
+
 def _get_wrapper(
     original_create: Callable,
     name: str,
     reduce_fn: Callable,
     tracing_extra: Optional[TracingExtra] = None,
     invocation_params_fn: Optional[Callable] = None,
+    process_outputs: Optional[Callable] = None,
 ) -> Callable:
     textra = tracing_extra or {}
@@ -177,6 +234,7 @@ def create(*args, stream: bool = False, **kwargs):
             reduce_fn=reduce_fn if stream else None,
             process_inputs=_strip_not_given,
             _invocation_params_fn=invocation_params_fn,
+            process_outputs=process_outputs,
             **textra,
         )

@@ -191,6 +249,7 @@ async def acreate(*args, stream: bool = False, **kwargs):
             reduce_fn=reduce_fn if stream else None,
             process_inputs=_strip_not_given,
             _invocation_params_fn=invocation_params_fn,
+            process_outputs=process_outputs,
             **textra,
         )
         return await decorator(original_create)(*args, stream=stream, **kwargs)
@@ -232,6 +291,7 @@ def wrap_openai(
         _reduce_chat,
         tracing_extra=tracing_extra,
         invocation_params_fn=functools.partial(_infer_invocation_params, "chat"),
+        process_outputs=_process_chat_completion,
     )
     client.completions.create = _get_wrapper(  # type: ignore[method-assign]
         client.completions.create,
diff --git a/python/tests/integration_tests/wrappers/test_openai.py b/python/tests/integration_tests/wrappers/test_openai.py
index 32dcd85c2..c98805775 100644
--- a/python/tests/integration_tests/wrappers/test_openai.py
+++ b/python/tests/integration_tests/wrappers/test_openai.py
@@ -1,6 +1,8 @@
 # mypy: disable-error-code="attr-defined, union-attr, arg-type, call-overload"
 import time
+from datetime import datetime
 from unittest import mock
+from uuid import uuid4

 import pytest

@@ -180,3 +182,98 @@ async def test_completions_async_api(mock_session: mock.MagicMock, stream: bool)
     assert mock_session.return_value.request.call_count >= 1
     for call in mock_session.return_value.request.call_args_list[1:]:
         assert call[0][0].upper() == "POST"
+
+
+class Collect:
+    """
+    Collects the runs for inspection.
+    """
+
+    def __init__(self):
+        self.run = None
+
+    def __call__(self, run):
+        self.run = run
+
+
+def test_wrap_openai_token_counts():
+    import openai
+
+    oai_client = openai.Client()
+
+    wrapped_oai_client = wrap_openai(oai_client)
+
+    project_name = f"__test_wrap_openai_{datetime.now().isoformat()}_{uuid4().hex[:6]}"
+    ls_client = langsmith.Client()
+
+    collect = Collect()
+    try:
+        run_id_to_usage_metadata = {}
+        with langsmith.tracing_context(
+            enabled=True, project_name=project_name, client=ls_client
+        ):
+            # stream usage
+            res = wrapped_oai_client.chat.completions.create(
+                model="gpt-4o-mini",
+                messages=[{"role": "user", "content": "howdy"}],
+                langsmith_extra={"on_end": collect},
+                stream=True,
+                stream_options={"include_usage": True},
+            )
+
+            for _ in res:
+                # consume the stream
+                pass
+
+            run_id_to_usage_metadata[collect.run.id] = collect.run.outputs[
+                "usage_metadata"
+            ]
+
+            # stream without usage
+            res = wrapped_oai_client.chat.completions.create(
+                model="gpt-4o-mini",
+                messages=[{"role": "user", "content": "howdy"}],
+                langsmith_extra={"on_end": collect},
+                stream=True,
+            )
+
+            for _ in res:
+                # consume the stream
+                pass
+
+            assert collect.run.outputs.get("usage_metadata") is None
+            assert collect.run.outputs.get("usage") is None
+
+            wrapped_oai_client.chat.completions.create(
+                model="gpt-4o-mini",
+                messages=[{"role": "user", "content": "howdy"}],
+                langsmith_extra={"on_end": collect},
+            )
+
+            run_id_to_usage_metadata[collect.run.id] = collect.run.outputs[
+                "usage_metadata"
+            ]
+
+            wrapped_oai_client.chat.completions.create(
+                model="o1-mini",
+                messages=[
+                    {
+                        "role": "user",
+                        "content": "Write a bash script that takes a matrix represented as a string with format '[1,2],[3,4],[5,6]' and prints the transpose in the same format.",
+                    }
+                ],
+                langsmith_extra={"on_end": collect},
+            )
+
+            run_id_to_usage_metadata[collect.run.id] = collect.run.outputs[
+                "usage_metadata"
+            ]
+
+        # handle pending runs
+        runs = list(ls_client.list_runs(project_name=project_name))
+        assert len(runs) == 4
+        for run in runs:
+            if run.id in run_id_to_usage_metadata:
+                assert run.outputs.get("usage_metadata") is not None
+    finally:
+        ls_client.delete_project(project_name=project_name)
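Reviewer note: below is a minimal sketch (not part of the diff) of how the new `_create_usage_metadata` helper flattens an OpenAI `usage` payload into the `UsageMetadata` TypedDict added to `schemas.py`. The token counts are illustrative only, the import path simply mirrors the file touched above, and the helper is private, so this is for inspection rather than public use.

# Illustrative only: private helper imported for inspection.
from langsmith.wrappers._openai import _create_usage_metadata

# Shape of the `usage` block OpenAI returns for chat completions
# (and for streamed calls when stream_options={"include_usage": True}).
oai_usage = {
    "prompt_tokens": 31,
    "completion_tokens": 12,
    "total_tokens": 43,
    "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 24},
    "completion_tokens_details": {"audio_tokens": 0, "reasoning_tokens": 0},
}

print(_create_usage_metadata(oai_usage))
# {'input_tokens': 31, 'output_tokens': 12, 'total_tokens': 43,
#  'input_token_details': {'audio': 0, 'cache_read': 24},
#  'output_token_details': {'audio': 0, 'reasoning': 0}}

Note that `cache_creation` is never populated here, since the OpenAI usage payload does not report cache-write tokens; detail keys with missing values are dropped rather than stored as None.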
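A similarly hedged end-to-end sketch of what callers should observe once `process_outputs=_process_chat_completion` is wired into the wrapper (this mirrors the integration test above; the model name is a placeholder and working OpenAI plus LangSmith credentials are assumed):

import openai

import langsmith
from langsmith.wrappers import wrap_openai

client = wrap_openai(openai.Client())  # assumes OPENAI_API_KEY is set
captured = {}

with langsmith.tracing_context(enabled=True):
    client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "howdy"}],
        # on_end receives the Run after process_outputs has flattened the usage
        langsmith_extra={"on_end": lambda run: captured.update(run.outputs or {})},
    )

# The run outputs now carry the standardized usage_metadata block.
print(captured["usage_metadata"]["total_tokens"])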