Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Apr 12, 2024
1 parent 0308d11 commit 25e5c75
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
36 changes: 36 additions & 0 deletions python/langsmith/wrappers/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,21 @@ async def acreate(*args, stream: bool = False, **kwargs):
return acreate if run_helpers.is_async(original_create) else create


def _wrap_assistant_method(original_method: Callable, name: str) -> Callable:
wrapped_method = run_helpers.traceable(name=name)(original_method)

@functools.wraps(original_method)
def new_method(*args, **kwargs):
thread_id = kwargs.get("thread_id")
langsmith_extra = kwargs.get("langsmith_extra") or {}
metadata = langsmith_extra.setdefault("metadata", {})
if thread_id and "thread_id" not in metadata:
metadata["thread_id"] = thread_id
return wrapped_method(*args, **kwargs, langsmith_extra=langsmith_extra)

return new_method


def wrap_openai(client: C) -> C:
"""Patch the OpenAI client to make it traceable.
Expand All @@ -187,4 +202,25 @@ def wrap_openai(client: C) -> C:
client.completions.create = _get_wrapper( # type: ignore[method-assign]
client.completions.create, "OpenAI", _reduce_completions
)
# Beta OpenAI assistants methods
methods = [
("beta.threads.messages.create", "MessageOpenAI", None),
("beta.threads.runs.create", "RunAssistant", None),
("beta.threads.runs.create_and_poll", "RunAndPollAssistant", None),
("beta.threads.runs.stream", "StreamAssistant", None),
]

for method, name, reduce_fn in methods:
try:
root = client
methods = method.split(".")
for attr in methods[:-1]:
root = getattr(root, attr)
traced_method = _wrap_assistant_method(
getattr(root, methods[-1]), name=name
)
setattr(root, traced_method) # type: ignore[attr-defined]
except BaseException as e:
logger.debug(f"Could not patch {method}: {repr(e)}")

return client
43 changes: 43 additions & 0 deletions python/tests/integration_tests/wrappers/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# mypy: disable-error-code="attr-defined, union-attr, arg-type, call-overload"
import time
from typing import Any
from unittest import mock

import pytest
Expand Down Expand Up @@ -162,3 +163,45 @@ async def test_completions_async_api(mock_session: mock.MagicMock, stream: bool)
assert mock_session.return_value.request.call_count >= 1
for call in mock_session.return_value.request.call_args_list:
assert call[0][0].upper() == "POST"


# @mock.patch("langsmith.client.requests.Session")
@pytest.mark.parametrize("stream", [False, True])
def test_openai_assistant_sync_api(
stream: bool,
):
import openai

patched_client = wrap_openai(openai.Client())
try:
assistant = patched_client.beta.assistants.create(
name="Math Tutor",
instructions="You are a personal math tutor. Write"
" and run code to answer math questions.",
tools=[{"type": "code_interpreter"}],
model="gpt-4-turbo",
)
thread = patched_client.beta.threads.create()
patched_client.beta.threads.messages.create(
thread_id=thread.id,
role="user",
content="I need to solve the equation `3x + 11 = 14`. Can you help me?",
)
if stream:
with patched_client.beta.threads.runs.stream(
thread_id=thread.id,
assistant_id=assistant.id,
instructions="Please address the user as Jane Doe. "
"The user has a premium account.",
) as stream:
for chunk in stream:
print(chunk)
else:
patched_client.beta.threads.runs.create_and_poll(
thread_id=thread.id,
assistant_id=assistant.id,
instructions="Please address the user as Jane Doe. "
"The user has a premium account.",
)
finally:
patched_client.beta.assistants.delete(assistant.id)

0 comments on commit 25e5c75

Please sign in to comment.