diff --git a/python/langsmith/client.py b/python/langsmith/client.py index 1647b790d..9dd1442a1 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -1620,6 +1620,9 @@ def update_run( events: Optional[Sequence[dict]] = None, extra: Optional[Dict] = None, tags: Optional[List[str]] = None, + attachments: Optional[ + Dict[str, tuple[str, bytes] | ls_schemas.Attachment] + ] = None, **kwargs: Any, ) -> None: """Update a run in the LangSmith API. @@ -1644,6 +1647,9 @@ def update_run( The extra information for the run. tags : List[str] or None, default=None The tags for the run. + attachments: dict[str, ls_schemas.Attachment] or None, default=None + A dictionary of attachments to add to the run. The keys are the attachment names, + and the values are Attachment objects containing the data and mime type. **kwargs : Any Kwargs are ignored. """ @@ -1658,6 +1664,8 @@ def update_run( "session_id": kwargs.pop("session_id", None), "session_name": kwargs.pop("session_name", None), } + if attachments: + data["attachments"] = attachments use_multipart = ( self.tracing_queue is not None # batch ingest requires trace_id and dotted_order to be set diff --git a/python/langsmith/run_trees.py b/python/langsmith/run_trees.py index d3083e4f6..63f0cb4e5 100644 --- a/python/langsmith/run_trees.py +++ b/python/langsmith/run_trees.py @@ -301,6 +301,15 @@ def post(self, exclude_child_runs: bool = True) -> None: """Post the run tree to the API asynchronously.""" kwargs = self._get_dicts_safe() self.client.create_run(**kwargs) + if attachments := kwargs.get("attachments"): + keys = [str(name) for name in attachments] + self.events.append( + { + "name": "uploaded_attachment", + "time": datetime.now(timezone.utc).isoformat(), + "message": set(keys), + } + ) if not exclude_child_runs: for child_run in self.child_runs: child_run.post(exclude_child_runs=False) @@ -309,6 +318,26 @@ def patch(self) -> None: """Patch the run tree to the API in a background thread.""" if not self.end_time: self.end() + attachments = self.attachments + try: + # Avoid loading the same attachment twice + if attachments: + uploaded = next( + ( + ev + for ev in self.events + if ev.get("name") == "uploaded_attachment" + ), + None, + ) + if uploaded: + attachments = { + a: v + for a, v in attachments.items() + if a not in uploaded["message"] + } + except Exception as e: + logger.warning(f"Error filtering attachments to upload: {e}") self.client.update_run( name=self.name, run_id=self.id, @@ -322,6 +351,7 @@ def patch(self) -> None: events=self.events, tags=self.tags, extra=self.extra, + attachments=attachments, ) def wait(self) -> None: diff --git a/python/tests/unit_tests/test_run_helpers.py b/python/tests/unit_tests/test_run_helpers.py index dbbbe1adf..34df400e7 100644 --- a/python/tests/unit_tests/test_run_helpers.py +++ b/python/tests/unit_tests/test_run_helpers.py @@ -1714,7 +1714,11 @@ def my_func( val: int, att1: ls_schemas.Attachment, att2: Annotated[tuple, ls_schemas.Attachment], + run_tree: RunTree, ): + run_tree.attachments["anoutput"] = ls_schemas.Attachment( + mime_type="text/plain", data=b"noidea" + ) return "foo" mock_client = _get_mock_client( @@ -1739,11 +1743,15 @@ def my_func( ) assert result == "foo" - calls = _get_calls(mock_client) - datas = _get_multipart_data(calls) + for _ in range(10): + calls = _get_calls(mock_client) + datas = _get_multipart_data(calls) + if len(datas) >= 7: + break + time.sleep(1) - # main run, inputs, outputs, events, att1, att2 - assert len(datas) == 6 + # main run, inputs, outputs, events, att1, att2, anoutput + assert len(datas) == 7 # First 4 are type application/json (run, inputs, outputs, events) trace_id = datas[0][0].split(".")[1] _, (_, run_stuff) = next( @@ -1760,7 +1768,7 @@ def my_func( data for data in datas if data[0] == f"post.{trace_id}.inputs" ) assert json.loads(inputs) == {"val": 42} - # last two are the mime types provided + # last three are the mime types provided _, (mime_type1, content1) = next( data for data in datas if data[0] == f"attachment.{trace_id}.att1" ) @@ -1772,3 +1780,10 @@ def my_func( ) assert mime_type2 == "application/octet-stream" assert content2 == b"content2" + + # Assert that anoutput is uploaded + _, (mime_type_output, content_output) = next( + data for data in datas if data[0] == f"attachment.{trace_id}.anoutput" + ) + assert mime_type_output == "text/plain" + assert content_output == b"noidea"