Skip to content

Commit

Permalink
feat: Instrument wrapper for OpenAI structured outputs (#1373)
Browse files Browse the repository at this point in the history
### Purpose
Instrument tracing for OpenAI Structured Outputs
### Changes
1. Created openai parser wrapper for structured outputs beta method
`client.beta.chat.completions.parse`
1. Added support for both sync and async tracing
1. Maintains backwards compatibility for clients that don't have beta
parse methods
### Tests
Added unit tests for the parsing methods and token counts. Tested locally
to confirm that parsing gets traced.
<img width="318" alt="Screenshot 2024-12-02 at 9 45 16 PM"
src="https://github.com/user-attachments/assets/20300d05-65ff-4a8b-857d-e18e2e89f79d"
/>



> [!NOTE]  
> Parsing does not support streaming, so the `reduce_fn` is always set
> to `None`.
  • Loading branch information
angus-langchain authored Jan 3, 2025
1 parent fffa5b5 commit 281dc10
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 0 deletions.
53 changes: 53 additions & 0 deletions python/langsmith/wrappers/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,44 @@ async def acreate(*args, stream: bool = False, **kwargs):
return acreate if run_helpers.is_async(original_create) else create


def _get_parse_wrapper(
    original_parse: Callable,
    name: str,
    tracing_extra: Optional[TracingExtra] = None,
    invocation_params_fn: Optional[Callable] = None,
) -> Callable:
    """Return a traced replacement for ``beta.chat.completions.parse``.

    Wraps the structured-outputs ``parse`` method (sync or async) so each
    call is recorded as an ``"llm"`` run. The parse endpoint does not
    support streaming, so ``reduce_fn`` is always ``None``.

    Args:
        original_parse: The client's original ``parse`` method.
        name: Run name to record for traced calls.
        tracing_extra: Extra metadata/tags/client forwarded to ``traceable``.
        invocation_params_fn: Callable that extracts invocation params
            from the call's kwargs, if provided.

    Returns:
        An async wrapper when ``original_parse`` is a coroutine function,
        otherwise a sync wrapper.
    """
    textra = tracing_extra or {}

    def _decorate(fn: Callable) -> Callable:
        # Single traceable config shared by the sync and async paths, so
        # the two wrappers cannot drift apart.
        return run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=None,  # parse never streams, so nothing to reduce
            process_inputs=_strip_not_given,
            _invocation_params_fn=invocation_params_fn,
            process_outputs=_process_chat_completion,
            **textra,
        )(fn)

    @functools.wraps(original_parse)
    def parse(*args, **kwargs):
        return _decorate(original_parse)(*args, **kwargs)

    @functools.wraps(original_parse)
    async def aparse(*args, **kwargs):
        # Kwargs are deliberately NOT pre-stripped here: the sync path
        # forwards kwargs unchanged, and the previous manual
        # ``_strip_not_given(kwargs)`` made the two paths inconsistent.
        return await _decorate(original_parse)(*args, **kwargs)

    return aparse if run_helpers.is_async(original_parse) else parse


class TracingExtra(TypedDict, total=False):
metadata: Optional[Mapping[str, Any]]
tags: Optional[List[str]]
Expand Down Expand Up @@ -297,4 +335,19 @@ def wrap_openai(
tracing_extra=tracing_extra,
invocation_params_fn=functools.partial(_infer_invocation_params, "llm"),
)

# Wrap beta.chat.completions.parse if it exists
if (
hasattr(client, "beta")
and hasattr(client.beta, "chat")
and hasattr(client.beta.chat, "completions")
and hasattr(client.beta.chat.completions, "parse")
):
client.beta.chat.completions.parse = _get_parse_wrapper( # type: ignore[method-assign]
client.beta.chat.completions.parse, # type: ignore
chat_name,
tracing_extra=tracing_extra,
invocation_params_fn=functools.partial(_infer_invocation_params, "chat"),
)

return client
99 changes: 99 additions & 0 deletions python/tests/integration_tests/wrappers/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,102 @@ async def test_wrap_openai_chat_async_tokens(test_case):

filename = f"langsmith_py_wrap_openai_{test_case['description'].replace(' ', '_')}"
_collect_requests(mock_session, filename)


def test_parse_sync_api():
    """Test that the sync parse method can be traced without errors.

    Calls parse on a plain client and a wrapped client and checks the
    wrapper is transparent (same type, same choices).
    """
    import openai  # noqa

    mock_session = mock.MagicMock()
    ls_client = langsmith.Client(session=mock_session)

    original_client = openai.Client()
    patched_client = wrap_openai(openai.Client(), tracing_extra={"client": ls_client})

    messages = [{"role": "user", "content": "Say 'Foo' then stop."}]

    # temperature=0 makes sampling (near-)deterministic so the equality
    # check below does not flake across two independent API calls.
    original = original_client.beta.chat.completions.parse(
        messages=messages, model="gpt-3.5-turbo", temperature=0
    )
    patched = patched_client.beta.chat.completions.parse(
        messages=messages, model="gpt-3.5-turbo", temperature=0
    )

    # The wrapper must not change the return type or the response content.
    assert type(original) is type(patched)
    assert original.choices == patched.choices

    # Give the background tracing thread a moment to flush its batch.
    time.sleep(0.1)
    for call in mock_session.request.call_args_list:
        assert call[0][0].upper() in ["POST", "GET", "PATCH"]

    _collect_requests(mock_session, "test_parse_sync_api")


@pytest.mark.asyncio
async def test_parse_async_api():
    """Test that the async parse method can be traced without errors.

    Mirrors the sync test using ``openai.AsyncClient``.
    """
    import asyncio
    import openai  # noqa

    mock_session = mock.MagicMock()
    ls_client = langsmith.Client(session=mock_session)

    original_client = openai.AsyncClient()
    patched_client = wrap_openai(
        openai.AsyncClient(), tracing_extra={"client": ls_client}
    )

    messages = [{"role": "user", "content": "Say 'Foo' then stop."}]

    # temperature=0 makes sampling (near-)deterministic so the equality
    # check below does not flake across two independent API calls.
    original = await original_client.beta.chat.completions.parse(
        messages=messages, model="gpt-3.5-turbo", temperature=0
    )
    patched = await patched_client.beta.chat.completions.parse(
        messages=messages, model="gpt-3.5-turbo", temperature=0
    )

    # The wrapper must not change the return type or the response content.
    assert type(original) is type(patched)
    assert original.choices == patched.choices

    # Use asyncio.sleep (not time.sleep) so the event loop is not blocked
    # while the background tracing thread flushes its batch.
    await asyncio.sleep(0.1)
    for call in mock_session.request.call_args_list:
        assert call[0][0].upper() in ["POST", "GET", "PATCH"]

    _collect_requests(mock_session, "test_parse_async_api")


def test_parse_tokens():
    """Test that usage tokens are captured for parse calls.

    Fails (rather than silently passing) when no ``usage_metadata`` is
    recorded on the traced run, since capturing usage is exactly what
    this test exists to verify.
    """
    import openai
    from openai.types.chat import ChatCompletion

    mock_session = mock.MagicMock()
    ls_client = langsmith.Client(session=mock_session)
    wrapped_oai_client = wrap_openai(
        openai.Client(), tracing_extra={"client": ls_client}
    )

    collect = Collect()
    messages = [{"role": "user", "content": "Check usage"}]

    with langsmith.tracing_context(enabled=True):
        res = wrapped_oai_client.beta.chat.completions.parse(
            messages=messages,
            model="gpt-3.5-turbo",
            langsmith_extra={"on_end": collect},
        )
    assert isinstance(res, ChatCompletion)

    usage_metadata = collect.run.outputs.get("usage_metadata")

    # Previously these assertions were guarded by ``if usage_metadata:``,
    # which made the test pass vacuously when usage was NOT captured.
    assert usage_metadata is not None
    assert usage_metadata["input_tokens"] >= 0
    assert usage_metadata["output_tokens"] >= 0
    assert usage_metadata["total_tokens"] >= 0

    # Give the background tracing thread a moment to flush its batch.
    time.sleep(0.1)
    for call in mock_session.request.call_args_list:
        assert call[0][0].upper() in ["POST", "GET", "PATCH"]

    _collect_requests(mock_session, "test_parse_tokens")

0 comments on commit 281dc10

Please sign in to comment.