From 5d0a956ef591bfbc3841c6b0a6f80710b614b893 Mon Sep 17 00:00:00 2001 From: isaac hershenson Date: Wed, 18 Dec 2024 14:18:24 -0800 Subject: [PATCH] flag --- python/langsmith/client.py | 51 +++- python/tests/integration_tests/test_client.py | 250 +++++++++++++++--- 2 files changed, 257 insertions(+), 44 deletions(-) diff --git a/python/langsmith/client.py b/python/langsmith/client.py index 7d8e9ac2f..3f2912013 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -1537,6 +1537,7 @@ def multipart_ingest( ] = None, *, pre_sampled: bool = False, + dangerously_allow_filesystem: bool = False, ) -> None: """Batch ingest/upsert multiple runs in the Langsmith system. @@ -1621,6 +1622,18 @@ def multipart_ingest( ) ) + for op in serialized_ops: + if isinstance(op, SerializedRunOperation) and op.attachments: + for attachment in op.attachments.values(): + if ( + isinstance(attachment, tuple) + and isinstance(attachment[1], Path) + and not dangerously_allow_filesystem + ): + raise ValueError( + "Must set dangerously_allow_filesystem to True to use filesystem paths in multipart ingest." + ) + # sent the runs in multipart requests self._multipart_ingest_ops(serialized_ops) @@ -1684,6 +1697,7 @@ def update_run( extra: Optional[Dict] = None, tags: Optional[List[str]] = None, attachments: Optional[ls_schemas.Attachments] = None, + dangerously_allow_filesystem: bool = False, **kwargs: Any, ) -> None: """Update a run in the LangSmith API. @@ -1726,6 +1740,15 @@ def update_run( "session_name": kwargs.pop("session_name", None), } if attachments: + for _, attachment in attachments.items(): + if ( + isinstance(attachment, tuple) + and isinstance(attachment[1], Path) + and not dangerously_allow_filesystem + ): + raise ValueError( + "Must set dangerously_allow_filesystem=True to allow filesystem attachments." + ) data["attachments"] = attachments use_multipart = ( self.tracing_queue is not None @@ -3481,7 +3504,7 @@ def _prepare_multipart_data( | List[ls_schemas.ExampleUpdateWithAttachments], ], include_dataset_id: bool = False, - dangerously_allow_filesystem: Optional[bool] = False, + dangerously_allow_filesystem: bool = False, ) -> Tuple[Any, bytes]: parts: List[MultipartPart] = [] if include_dataset_id: @@ -3567,7 +3590,7 @@ def _prepare_multipart_data( for name, attachment in example.attachments.items(): if isinstance(attachment, tuple): if isinstance(attachment[1], Path): - if dangerously_allow_filesystem == True: + if dangerously_allow_filesystem: mime_type, file_path = attachment file_size = os.path.getsize(file_path) parts.append( @@ -3641,7 +3664,7 @@ def update_examples_multipart( *, dataset_id: ID_TYPE, updates: Optional[List[ls_schemas.ExampleUpdateWithAttachments]] = None, - dangerously_allow_filesystem: Optional[bool] = False, + dangerously_allow_filesystem: bool = False, ) -> ls_schemas.UpsertExamplesResponse: """Upload examples.""" if not (self.info.instance_flags or {}).get( @@ -3653,7 +3676,11 @@ def update_examples_multipart( if updates is None: updates = [] - encoder, data = self._prepare_multipart_data(updates, include_dataset_id=False, dangerously_allow_filesystem=dangerously_allow_filesystem) + encoder, data = self._prepare_multipart_data( + updates, + include_dataset_id=False, + dangerously_allow_filesystem=dangerously_allow_filesystem, + ) response = self.request_with_retries( "PATCH", @@ -3674,7 +3701,7 @@ def upload_examples_multipart( *, dataset_id: ID_TYPE, uploads: Optional[List[ls_schemas.ExampleUploadWithAttachments]] = None, - dangerously_allow_filesystem: Optional[bool] = False, + dangerously_allow_filesystem: bool = False, ) -> ls_schemas.UpsertExamplesResponse: """Upload examples.""" if not (self.info.instance_flags or {}).get( @@ -3685,7 +3712,11 @@ def upload_examples_multipart( ) if uploads is None: uploads = [] - encoder, data = self._prepare_multipart_data(uploads, include_dataset_id=False, dangerously_allow_filesystem=dangerously_allow_filesystem) + encoder, data = self._prepare_multipart_data( + uploads, + include_dataset_id=False, + dangerously_allow_filesystem=dangerously_allow_filesystem, + ) response = self.request_with_retries( "POST", @@ -3705,7 +3736,7 @@ def upsert_examples_multipart( self, *, upserts: Optional[List[ls_schemas.ExampleUpsertWithAttachments]] = None, - dangerously_allow_filesystem: Optional[bool] = False, + dangerously_allow_filesystem: bool = False, ) -> ls_schemas.UpsertExamplesResponse: """Upsert examples. @@ -3722,7 +3753,11 @@ def upsert_examples_multipart( if upserts is None: upserts = [] - encoder, data = self._prepare_multipart_data(upserts, include_dataset_id=True, dangerously_allow_filesystem=dangerously_allow_filesystem) + encoder, data = self._prepare_multipart_data( + upserts, + include_dataset_id=True, + dangerously_allow_filesystem=dangerously_allow_filesystem, + ) response = self.request_with_retries( "POST", diff --git a/python/tests/integration_tests/test_client.py b/python/tests/integration_tests/test_client.py index 4c277aa23..7b104a603 100644 --- a/python/tests/integration_tests/test_client.py +++ b/python/tests/integration_tests/test_client.py @@ -65,7 +65,7 @@ def langchain_client() -> Client: "dataset_examples_multipart_enabled": True, "examples_multipart_enabled": True, } - }, + } ) @@ -378,6 +378,41 @@ def test_persist_update_run(langchain_client: Client) -> None: langchain_client.delete_project(project_name=project_name) +def test_update_run_attachments(langchain_client: Client) -> None: + """Test the persist and update methods work as expected.""" + project_name = "__test_update_run_attachments" + uuid4().hex[:4] + if langchain_client.has_project(project_name): + langchain_client.delete_project(project_name=project_name) + try: + trace_id = uuid4() + start_time = datetime.datetime.now(datetime.timezone.utc) + run: dict = dict( + id=str(trace_id), + name="test_run", + run_type="llm", + inputs={"text": "hello world"}, + project_name=project_name, + api_url=os.getenv("LANGCHAIN_ENDPOINT"), + start_time=start_time, + extra={"extra": "extra"}, + trace_id=str(trace_id), + dotted_order=f"{start_time.strftime('%Y%m%dT%H%M%S%fZ')}{str(trace_id)}", + ) + langchain_client.create_run(**run) + run["outputs"] = {"output": ["Hi"]} + run["extra"]["foo"] = "bar" + run["name"] = "test_run_updated" + langchain_client.update_run(run["id"], **run) + wait_for(lambda: langchain_client.read_run(run["id"]).end_time is not None) + stored_run = langchain_client.read_run(run["id"]) + assert stored_run.name == run["name"] + assert str(stored_run.id) == run["id"] + assert stored_run.outputs == run["outputs"] + assert stored_run.start_time == run["start_time"].replace(tzinfo=None) + finally: + langchain_client.delete_project(project_name=project_name) + + @pytest.mark.parametrize("uri", ["http://localhost:1981", "http://api.langchain.minus"]) def test_error_surfaced_invalid_uri(uri: str) -> None: get_env_var.cache_clear() @@ -937,7 +972,7 @@ def test_multipart_ingest_empty( assert not caplog.records -def test_multipart_ingest_create_with_attachments( +def test_multipart_ingest_create_with_attachments_error( langchain_client: Client, caplog: pytest.LogCaptureFixture ) -> None: _session = "__test_multipart_ingest_create_with_attachments" @@ -966,13 +1001,46 @@ def test_multipart_ingest_create_with_attachments( ] # make sure no warnings logged - with caplog.at_level(logging.WARNING, logger="langsmith.client"): + with pytest.raises(ValueError, match="Must set dangerously_allow_filesystem"): langchain_client.multipart_ingest(create=runs_to_create, update=[]) + +def test_multipart_ingest_create_with_attachments( + langchain_client: Client, caplog: pytest.LogCaptureFixture +) -> None: + _session = "__test_multipart_ingest_create_with_attachments" + trace_a_id = uuid4() + current_time = datetime.datetime.now(datetime.timezone.utc).strftime( + "%Y%m%dT%H%M%S%fZ" + ) + + runs_to_create: list[dict] = [ + { + "id": str(trace_a_id), + "session_name": _session, + "name": "trace a root", + "run_type": "chain", + "dotted_order": f"{current_time}{str(trace_a_id)}", + "trace_id": str(trace_a_id), + "inputs": {"input1": 1, "input2": 2}, + "attachments": { + "foo": ("text/plain", b"bar"), + "bar": ( + "image/png", + Path(__file__).parent / "test_data/parrot-icon.png", + ), + }, + } + ] + + # make sure no warnings logged + with caplog.at_level(logging.WARNING, logger="langsmith.client"): + langchain_client.multipart_ingest( + create=runs_to_create, update=[], dangerously_allow_filesystem=True + ) assert not caplog.records - time.sleep(5) # Need this so the run persists + wait_for(lambda: _get_run(str(trace_a_id), langchain_client)) created_run = langchain_client.read_run(run_id=str(trace_a_id)) - assert created_run.attachments assert sorted(created_run.attachments.keys()) == sorted(["foo", "bar"]) assert created_run.attachments["foo"]["reader"].read() == b"bar" assert ( @@ -980,17 +1048,79 @@ def test_multipart_ingest_create_with_attachments( == (Path(__file__).parent / "test_data/parrot-icon.png").read_bytes() ) - created_run = next(langchain_client.list_runs(run_ids=[str(trace_a_id)])) + +def test_multipart_ingest_update_with_attachments_no_paths( + langchain_client: Client, caplog: pytest.LogCaptureFixture +): + _session = "__test_multipart_ingest_update_with_attachments_no_paths" + trace_a_id = uuid4() + current_time = datetime.datetime.now(datetime.timezone.utc).strftime( + "%Y%m%dT%H%M%S%fZ" + ) + + runs_to_create: list[dict] = [ + { + "id": str(trace_a_id), + "session_name": _session, + "name": "trace a root", + "run_type": "chain", + "dotted_order": f"{current_time}{str(trace_a_id)}", + "trace_id": str(trace_a_id), + "outputs": {"output1": 3, "output2": 4}, + "attachments": { + "foo": ("text/plain", b"bar"), + "bar": ("image/png", b"bar"), + }, + } + ] + with caplog.at_level(logging.WARNING, logger="langsmith.client"): + langchain_client.multipart_ingest(create=runs_to_create, update=[]) + + assert not caplog.records + wait_for(lambda: _get_run(str(trace_a_id), langchain_client)) + created_run = langchain_client.read_run(run_id=str(trace_a_id)) assert created_run.attachments assert sorted(created_run.attachments.keys()) == sorted(["foo", "bar"]) assert created_run.attachments["foo"]["reader"].read() == b"bar" - assert ( - created_run.attachments["bar"]["reader"].read() - == (Path(__file__).parent / "test_data/parrot-icon.png").read_bytes() + assert created_run.attachments["bar"]["reader"].read() == b"bar" + + runs_to_update: list[dict] = [ + { + "id": str(trace_a_id), + "dotted_order": f"{current_time}{str(trace_a_id)}", + "trace_id": str(trace_a_id), + "outputs": {"output1": 3, "output2": 4}, + "attachments": { + "baz": ("text/plain", b"bar"), + "qux": ("image/png", b"bar"), + }, + } + ] + with caplog.at_level(logging.WARNING, logger="langsmith.client"): + langchain_client.multipart_ingest(create=[], update=runs_to_update) + + assert not caplog.records + wait_for(lambda: _get_run(str(trace_a_id), langchain_client)) + created_run = langchain_client.read_run(run_id=str(trace_a_id)) + assert created_run.attachments + assert sorted(created_run.attachments.keys()) == sorted( + ["bar", "baz", "foo", "qux"] ) + assert created_run.attachments["baz"]["reader"].read() == b"bar" + assert created_run.attachments["qux"]["reader"].read() == b"bar" -def test_multipart_ingest_update_with_attachments( +def _get_run(run_id: ID_TYPE, langchain_client: Client, has_end: bool = False) -> bool: + try: + r = langchain_client.read_run(run_id) # type: ignore + if has_end: + return r.end_time is not None + return True + except LangSmithError: + return False + + +def test_multipart_ingest_update_with_attachments_error( langchain_client: Client, caplog: pytest.LogCaptureFixture ) -> None: _session = "__test_multipart_ingest_update_with_attachments" @@ -1015,31 +1145,79 @@ def test_multipart_ingest_update_with_attachments( # make sure no warnings logged with caplog.at_level(logging.WARNING, logger="langsmith.client"): langchain_client.multipart_ingest(create=runs_to_create, update=[]) - assert not caplog.records + wait_for(lambda: _get_run(str(trace_a_id), langchain_client)) - runs_to_update: list[dict] = [ + runs_to_update: list[dict] = [ + { + "id": str(trace_a_id), + "dotted_order": f"{current_time}{str(trace_a_id)}", + "trace_id": str(trace_a_id), + "inputs": {"input1": 3, "input2": 4}, + "attachments": { + "foo": ("text/plain", b"bar"), + "bar": ( + "image/png", + Path(__file__).parent / "test_data/parrot-icon.png", + ), + }, + } + ] + with pytest.raises(ValueError, match="Must set dangerously_allow_filesystem"): + langchain_client.multipart_ingest(create=[], update=runs_to_update) + + +def test_multipart_ingest_update_with_attachments( + langchain_client: Client, caplog: pytest.LogCaptureFixture +) -> None: + _session = "__test_multipart_ingest_update_with_attachments" + + trace_a_id = uuid4() + current_time = datetime.datetime.now(datetime.timezone.utc).strftime( + "%Y%m%dT%H%M%S%fZ" + ) + + runs_to_create: list[dict] = [ { "id": str(trace_a_id), + "session_name": _session, + "name": "trace a root", + "run_type": "chain", "dotted_order": f"{current_time}{str(trace_a_id)}", "trace_id": str(trace_a_id), - "outputs": {"output1": 3, "output2": 4}, - "attachments": { - "foo": ("text/plain", b"bar"), - "bar": ( - "image/png", - Path(__file__).parent / "test_data/parrot-icon.png", - ), - }, + "inputs": {"input1": 1, "input2": 2}, } ] + + # make sure no warnings logged with caplog.at_level(logging.WARNING, logger="langsmith.client"): - langchain_client.multipart_ingest(create=[], update=runs_to_update) + langchain_client.multipart_ingest(create=runs_to_create, update=[]) + assert not caplog.records + wait_for(lambda: _get_run(str(trace_a_id), langchain_client)) + + runs_to_update: list[dict] = [ + { + "id": str(trace_a_id), + "dotted_order": f"{current_time}{str(trace_a_id)}", + "trace_id": str(trace_a_id), + "inputs": {"input1": 3, "input2": 4}, + "attachments": { + "foo": ("text/plain", b"bar"), + "bar": ( + "image/png", + Path(__file__).parent / "test_data/parrot-icon.png", + ), + }, + } + ] + langchain_client.multipart_ingest( + create=[], update=runs_to_update, dangerously_allow_filesystem=True + ) assert not caplog.records - time.sleep(3) + wait_for(lambda: _get_run(str(trace_a_id), langchain_client)) created_run = langchain_client.read_run(run_id=str(trace_a_id)) - assert created_run.attachments + assert created_run.inputs == {"input1": 3, "input2": 4} assert sorted(created_run.attachments.keys()) == sorted(["foo", "bar"]) assert created_run.attachments["foo"]["reader"].read() == b"bar" assert ( @@ -1075,6 +1253,7 @@ def test_multipart_ingest_create_then_update( langchain_client.multipart_ingest(create=runs_to_create, update=[]) assert not caplog.records + wait_for(lambda: _get_run(str(trace_a_id), langchain_client)) runs_to_update: list[dict] = [ { @@ -1088,6 +1267,10 @@ def test_multipart_ingest_create_then_update( langchain_client.multipart_ingest(create=[], update=runs_to_update) assert not caplog.records + wait_for(lambda: _get_run(str(trace_a_id), langchain_client)) + + created_run = langchain_client.read_run(run_id=str(trace_a_id)) + assert created_run.outputs == {"output1": 3, "output2": 4} def test_multipart_ingest_update_then_create( @@ -1201,16 +1384,7 @@ def test_update_run_extra(add_metadata: bool, do_batching: bool) -> None: revision_id = uuid4() langchain_client.create_run(**run, revision_id=revision_id) # type: ignore - def _get_run(run_id: ID_TYPE, has_end: bool = False) -> bool: - try: - r = langchain_client.read_run(run_id) # type: ignore - if has_end: - return r.end_time is not None - return True - except LangSmithError: - return False - - wait_for(lambda: _get_run(run_id)) + wait_for(lambda: _get_run(run_id, langchain_client)) created_run = langchain_client.read_run(run_id) assert created_run.metadata["foo"] == "bar" assert created_run.metadata["revision_id"] == str(revision_id) @@ -1219,7 +1393,7 @@ def _get_run(run_id: ID_TYPE, has_end: bool = False) -> bool: run["extra"]["metadata"]["foo2"] = "baz" # type: ignore run["tags"] = ["tag3"] langchain_client.update_run(run_id, **run) # type: ignore - wait_for(lambda: _get_run(run_id, has_end=True)) + wait_for(lambda: _get_run(run_id, langchain_client, has_end=True)) updated_run = langchain_client.read_run(run_id) assert updated_run.metadata["foo"] == "bar" # type: ignore assert updated_run.revision_id == str(revision_id) @@ -1988,7 +2162,9 @@ def test_examples_multipart_attachment_path(langchain_client: Client) -> None: ) langchain_client.update_examples_multipart( - dataset_id=dataset.id, updates=[example_update], dangerously_allow_filesystem=True + dataset_id=dataset.id, + updates=[example_update], + dangerously_allow_filesystem=True, ) retrieved = langchain_client.read_example(example_id) @@ -2013,7 +2189,9 @@ def test_examples_multipart_attachment_path(langchain_client: Client) -> None: with pytest.raises(FileNotFoundError) as exc_info: langchain_client.upload_examples_multipart( - dataset_id=dataset.id, uploads=[example_wrong_path] + dataset_id=dataset.id, + uploads=[example_wrong_path], + dangerously_allow_filesystem=True, ) assert "test_data/not-a-real-file.txt" in str(exc_info.value)