Skip to content

Commit

Permalink
Attachments in patch (#1212)
Browse files Browse the repository at this point in the history
Co-authored-by: isaac hershenson <[email protected]>
  • Loading branch information
hinthornw and isahers1 authored Nov 21, 2024
1 parent a3700dd commit 09ab56b
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
8 changes: 8 additions & 0 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1620,6 +1620,9 @@ def update_run(
events: Optional[Sequence[dict]] = None,
extra: Optional[Dict] = None,
tags: Optional[List[str]] = None,
attachments: Optional[
Dict[str, tuple[str, bytes] | ls_schemas.Attachment]
] = None,
**kwargs: Any,
) -> None:
"""Update a run in the LangSmith API.
Expand All @@ -1644,6 +1647,9 @@ def update_run(
The extra information for the run.
tags : List[str] or None, default=None
The tags for the run.
attachments : Dict[str, Tuple[str, bytes] | ls_schemas.Attachment] or None, default=None
A dictionary of attachments to add to the run. The keys are the attachment names,
and the values are either Attachment objects or (mime type, data) tuples containing the data and mime type.
**kwargs : Any
Kwargs are ignored, apart from keys read explicitly (e.g. `session_id` and `session_name`).
"""
Expand All @@ -1658,6 +1664,8 @@ def update_run(
"session_id": kwargs.pop("session_id", None),
"session_name": kwargs.pop("session_name", None),
}
if attachments:
data["attachments"] = attachments
use_multipart = (
self.tracing_queue is not None
# batch ingest requires trace_id and dotted_order to be set
Expand Down
30 changes: 30 additions & 0 deletions python/langsmith/run_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,15 @@ def post(self, exclude_child_runs: bool = True) -> None:
"""Post the run tree to the API asynchronously."""
kwargs = self._get_dicts_safe()
self.client.create_run(**kwargs)
if attachments := kwargs.get("attachments"):
keys = [str(name) for name in attachments]
self.events.append(
{
"name": "uploaded_attachment",
"time": datetime.now(timezone.utc).isoformat(),
"message": set(keys),
}
)
if not exclude_child_runs:
for child_run in self.child_runs:
child_run.post(exclude_child_runs=False)
Expand All @@ -309,6 +318,26 @@ def patch(self) -> None:
"""Patch the run tree to the API in a background thread."""
if not self.end_time:
self.end()
attachments = self.attachments
try:
# Avoid uploading the same attachment twice: skip names already recorded
# in an "uploaded_attachment" event
if attachments:
uploaded = next(
(
ev
for ev in self.events
if ev.get("name") == "uploaded_attachment"
),
None,
)
if uploaded:
attachments = {
a: v
for a, v in attachments.items()
if a not in uploaded["message"]
}
except Exception as e:
logger.warning(f"Error filtering attachments to upload: {e}")
self.client.update_run(
name=self.name,
run_id=self.id,
Expand All @@ -322,6 +351,7 @@ def patch(self) -> None:
events=self.events,
tags=self.tags,
extra=self.extra,
attachments=attachments,
)

def wait(self) -> None:
Expand Down
25 changes: 20 additions & 5 deletions python/tests/unit_tests/test_run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1714,7 +1714,11 @@ def my_func(
val: int,
att1: ls_schemas.Attachment,
att2: Annotated[tuple, ls_schemas.Attachment],
run_tree: RunTree,
):
run_tree.attachments["anoutput"] = ls_schemas.Attachment(
mime_type="text/plain", data=b"noidea"
)
return "foo"

mock_client = _get_mock_client(
Expand All @@ -1739,11 +1743,15 @@ def my_func(
)
assert result == "foo"

calls = _get_calls(mock_client)
datas = _get_multipart_data(calls)
for _ in range(10):
calls = _get_calls(mock_client)
datas = _get_multipart_data(calls)
if len(datas) >= 7:
break
time.sleep(1)

# main run, inputs, outputs, events, att1, att2
assert len(datas) == 6
# main run, inputs, outputs, events, att1, att2, anoutput
assert len(datas) == 7
# First 4 are type application/json (run, inputs, outputs, events)
trace_id = datas[0][0].split(".")[1]
_, (_, run_stuff) = next(
Expand All @@ -1760,7 +1768,7 @@ def my_func(
data for data in datas if data[0] == f"post.{trace_id}.inputs"
)
assert json.loads(inputs) == {"val": 42}
# last two are the mime types provided
# last three are the mime types provided
_, (mime_type1, content1) = next(
data for data in datas if data[0] == f"attachment.{trace_id}.att1"
)
Expand All @@ -1772,3 +1780,10 @@ def my_func(
)
assert mime_type2 == "application/octet-stream"
assert content2 == b"content2"

# Assert that anoutput is uploaded
_, (mime_type_output, content_output) = next(
data for data in datas if data[0] == f"attachment.{trace_id}.anoutput"
)
assert mime_type_output == "text/plain"
assert content_output == b"noidea"

0 comments on commit 09ab56b

Please sign in to comment.