diff --git a/python/langsmith/wrappers/_openai.py b/python/langsmith/wrappers/_openai.py index 22b4e25c5..d850c8961 100644 --- a/python/langsmith/wrappers/_openai.py +++ b/python/langsmith/wrappers/_openai.py @@ -171,6 +171,21 @@ async def acreate(*args, stream: bool = False, **kwargs): return acreate if run_helpers.is_async(original_create) else create +def _wrap_assistant_method(original_method: Callable, name: str) -> Callable: + wrapped_method = run_helpers.traceable(name=name)(original_method) + + @functools.wraps(original_method) + def new_method(*args, **kwargs): + thread_id = kwargs.get("thread_id") + langsmith_extra = kwargs.get("langsmith_extra") or {} + metadata = langsmith_extra.setdefault("metadata", {}) + if thread_id and "thread_id" not in metadata: + metadata["thread_id"] = thread_id + return wrapped_method(*args, **kwargs, langsmith_extra=langsmith_extra) + + return new_method + + def wrap_openai(client: C) -> C: """Patch the OpenAI client to make it traceable. @@ -187,4 +202,25 @@ def wrap_openai(client: C) -> C: client.completions.create = _get_wrapper( # type: ignore[method-assign] client.completions.create, "OpenAI", _reduce_completions ) + # Beta OpenAI assistants methods + methods = [ + ("beta.threads.messages.create", "MessageOpenAI", None), + ("beta.threads.runs.create", "RunAssistant", None), + ("beta.threads.runs.create_and_poll", "RunAndPollAssistant", None), + ("beta.threads.runs.stream", "StreamAssistant", None), + ] + + for method, name, reduce_fn in methods: + try: + root = client + methods = method.split(".") + for attr in methods[:-1]: + root = getattr(root, attr) + traced_method = _wrap_assistant_method( + getattr(root, methods[-1]), name=name + ) + setattr(root, traced_method) # type: ignore[attr-defined] + except BaseException as e: + logger.debug(f"Could not patch {method}: {repr(e)}") + return client diff --git a/python/tests/integration_tests/wrappers/test_openai.py b/python/tests/integration_tests/wrappers/test_openai.py index 40804b96d..b6b2d5003 100644 --- a/python/tests/integration_tests/wrappers/test_openai.py +++ b/python/tests/integration_tests/wrappers/test_openai.py @@ -1,5 +1,6 @@ # mypy: disable-error-code="attr-defined, union-attr, arg-type, call-overload" import time +from typing import Any from unittest import mock import pytest @@ -162,3 +163,45 @@ async def test_completions_async_api(mock_session: mock.MagicMock, stream: bool) assert mock_session.return_value.request.call_count >= 1 for call in mock_session.return_value.request.call_args_list: assert call[0][0].upper() == "POST" + + +# @mock.patch("langsmith.client.requests.Session") +@pytest.mark.parametrize("stream", [False, True]) +def test_openai_assistant_sync_api( + stream: bool, +): + import openai + + patched_client = wrap_openai(openai.Client()) + try: + assistant = patched_client.beta.assistants.create( + name="Math Tutor", + instructions="You are a personal math tutor. Write" + " and run code to answer math questions.", + tools=[{"type": "code_interpreter"}], + model="gpt-4-turbo", + ) + thread = patched_client.beta.threads.create() + patched_client.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content="I need to solve the equation `3x + 11 = 14`. Can you help me?", + ) + if stream: + with patched_client.beta.threads.runs.stream( + thread_id=thread.id, + assistant_id=assistant.id, + instructions="Please address the user as Jane Doe. " + "The user has a premium account.", + ) as stream: + for chunk in stream: + print(chunk) + else: + patched_client.beta.threads.runs.create_and_poll( + thread_id=thread.id, + assistant_id=assistant.id, + instructions="Please address the user as Jane Doe. " + "The user has a premium account.", + ) + finally: + patched_client.beta.assistants.delete(assistant.id)