stream compress multipart sdk #1301

Closed
wants to merge 17 commits into from
141 changes: 138 additions & 3 deletions python/langsmith/_internal/_operations.py
@@ -1,18 +1,21 @@
from __future__ import annotations

GitHub Actions / benchmark: Benchmark results

create_5_000_run_trees: Mean +- std dev: 707 ms +- 113 ms (warning: std dev is 16% of the mean; result may be unstable)
create_10_000_run_trees: Mean +- std dev: 1.47 sec +- 0.24 sec (warning: std dev is 17% of the mean; result may be unstable)
create_20_000_run_trees: Mean +- std dev: 1.43 sec +- 0.13 sec
dumps_class_nested_py_branch_and_leaf_200x400: Mean +- std dev: 707 us +- 22 us
dumps_class_nested_py_leaf_50x100: Mean +- std dev: 25.1 ms +- 0.3 ms
dumps_class_nested_py_leaf_100x200: Mean +- std dev: 105 ms +- 3 ms
dumps_dataclass_nested_50x100: Mean +- std dev: 25.7 ms +- 0.2 ms
dumps_pydantic_nested_50x100: Mean +- std dev: 72.0 ms +- 17.9 ms (warning: std dev is 25% of the mean; result may be unstable)
dumps_pydanticv1_nested_50x100: Mean +- std dev: 197 ms +- 2 ms

GitHub Actions / benchmark: Comparison against main

| Benchmark | main | changes |
|---|---|---|
| dumps_pydanticv1_nested_50x100 | 221 ms | 197 ms: 1.12x faster |
| create_5_000_run_trees | 724 ms | 707 ms: 1.02x faster |
| dumps_class_nested_py_leaf_50x100 | 25.1 ms | 25.1 ms: 1.00x faster |
| dumps_class_nested_py_leaf_100x200 | 105 ms | 105 ms: 1.00x faster |
| dumps_dataclass_nested_50x100 | 25.6 ms | 25.7 ms: 1.00x slower |
| dumps_class_nested_py_branch_and_leaf_200x400 | 705 us | 707 us: 1.00x slower |
| create_20_000_run_trees | 1.39 sec | 1.43 sec: 1.03x slower |
| create_10_000_run_trees | 1.40 sec | 1.47 sec: 1.05x slower |
| dumps_pydantic_nested_50x100 | 66.2 ms | 72.0 ms: 1.09x slower |
| Geometric mean | (ref) | 1.00x slower |

import io
import itertools
import logging
import uuid
from typing import Literal, Optional, Union, cast
from typing import Iterator, Literal, Optional, Sequence, Union, cast

import zstandard as zstd

from langsmith import schemas as ls_schemas
from langsmith._internal import _orjson
from langsmith._internal._multipart import MultipartPart, MultipartPartsAndContext
from langsmith._internal._multipart import MultipartPart, MultipartPartsAndContext, join_multipart_parts_and_context
from langsmith._internal._serde import dumps_json as _dumps_json

logger = logging.getLogger(__name__)


BOUNDARY = uuid.uuid4().hex
class SerializedRunOperation:
operation: Literal["post", "patch"]
id: uuid.UUID
@@ -271,3 +274,135 @@
acc_parts,
f"trace={op.trace_id},id={op.id}",
)

class StreamingMultipartCompressor:
"""Incrementally compress multipart form data from multiple traces."""

def __init__(
self,
*,
compression_level: int = 3,
blocksize: int = 65536,
boundary: str = BOUNDARY,
):
self.compressor = zstd.ZstdCompressor(level=compression_level)
self.buffer = io.BytesIO()
self.blocksize = blocksize
self.boundary = boundary

def _yield_and_reset_buffer(self) -> Iterator[bytes]:
# Yield the current compressed data and reset the buffer.
compressed_data = self.buffer.getvalue()
if compressed_data:
yield compressed_data
self.buffer.seek(0)
self.buffer.truncate()

def _process_bytes(
self,
compressor: zstd.ZstdCompressionWriter,
data: Union[bytes, bytearray],
) -> Iterator[bytes]:
with memoryview(data) as view:
for i in range(0, len(view), self.blocksize):
chunk = view[i:i + self.blocksize]
compressor.write(chunk)
yield from self._yield_and_reset_buffer()

def compress_multipart_stream(
self,
parts_and_contexts: MultipartPartsAndContext,
) -> Iterator[bytes]:
# Create a streaming compressor context
compressor = self.compressor.stream_writer(self.buffer, closefd=False)

try:
for part_name, (filename, data, content_type, headers) in (
parts_and_contexts.parts
):
# Write part headers
part_header = (
f'--{self.boundary}\r\n'
f'Content-Disposition: form-data; name="{part_name}"'
)
if filename:
part_header += f'; filename="{filename}"'
part_header += f'\r\nContent-Type: {content_type}\r\n'

for header_name, header_value in headers.items():
part_header += f'{header_name}: {header_value}\r\n'

part_header += '\r\n'
compressor.write(part_header.encode())
# Yield compressed data to keep memory footprint small
yield from self._yield_and_reset_buffer()

# Write part data in chunks
if isinstance(data, (bytes, bytearray)):
yield from self._process_bytes(compressor, data)
else:
# Handle other data types
compressor.write(str(data).encode())
yield from self._yield_and_reset_buffer()

# End of this part
compressor.write(b'\r\n')
yield from self._yield_and_reset_buffer()

# Write final boundary
compressor.write(f'--{self.boundary}--\r\n'.encode())
yield from self._yield_and_reset_buffer()

finally:
# Close the compressor to flush any remaining data
compressor.close()
# Yield any final compressed data
yield from self._yield_and_reset_buffer()

def close(self):
"""Clean up resources after compression is complete."""
if self.buffer:
self.buffer.close()
self.buffer = None

def compress_operations(
self,
ops: Sequence[SerializedRunOperation],
*,
batch_size: Optional[int] = None,
) -> Iterator[bytes]:
"""Compress a sequence of operations into multipart form data. Used for batch run ingestion.

Args:
ops: Sequence of operations to compress
batch_size: Optional batch size for processing operations

Yields:
Compressed chunks of the multipart form data
"""
def chunk_ops(ops: Sequence[SerializedRunOperation],
size: Optional[int] = None,
) -> Iterator[Sequence[SerializedRunOperation]]:
if size is None:
yield ops
return

for i in range(0, len(ops), size):
yield ops[i:i + size]

def get_multipart_parts(
batch: Sequence[SerializedRunOperation]
) -> MultipartPartsAndContext:
parts_and_contexts = []
for op in batch:
parts_and_contexts.append(
serialized_run_operation_to_multipart_parts_and_context(op)
)
return join_multipart_parts_and_context(
parts_and_contexts
)

# Process operations in batches
for batch in chunk_ops(ops, batch_size):
multipart_data = get_multipart_parts(batch)
yield from self.compress_multipart_stream(multipart_data)
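
For orientation, here is a minimal usage sketch of the new class (not part of the diff). The file path, batch size, and the helper name `write_compressed_batch` are illustrative only, and `ops` is assumed to be a sequence of SerializedRunOperation objects already produced by serialize_run_dict:

```python
from langsmith._internal._operations import StreamingMultipartCompressor

def write_compressed_batch(ops, path="runs.zst"):
    # Stream zstd-compressed multipart data to a file instead of an HTTP body.
    compressor = StreamingMultipartCompressor(compression_level=3)
    try:
        with open(path, "wb") as f:
            # Each yielded chunk is already compressed, so memory stays bounded
            # no matter how many runs are in the batch.
            for chunk in compressor.compress_operations(ops, batch_size=100):
                f.write(chunk)
    finally:
        compressor.close()
```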
175 changes: 128 additions & 47 deletions python/langsmith/client.py
@@ -33,6 +33,7 @@
import uuid
import warnings
import weakref
import zstandard as zstd
from inspect import signature
from queue import PriorityQueue
from typing import (
@@ -58,9 +59,6 @@

import requests
from requests import adapters as requests_adapters
from requests_toolbelt import ( # type: ignore[import-untyped]
multipart as rqtb_multipart,
)
from typing_extensions import TypeGuard, overload
from urllib3.poolmanager import PoolKey # type: ignore[attr-defined, import-untyped]
from urllib3.util import Retry # type: ignore[import-untyped]
@@ -89,6 +87,7 @@
from langsmith._internal._operations import (
SerializedFeedbackOperation,
SerializedRunOperation,
StreamingMultipartCompressor,
combine_serialized_queue_operations,
serialize_feedback_dict,
serialize_run_dict,
@@ -1623,52 +1622,134 @@ def multipart_ingest(
# sent the runs in multipart requests
self._multipart_ingest_ops(serialized_ops)

# def decompress_multipart_stream(self, compressed_stream: Iterator[bytes], boundary: str = BOUNDARY) -> Iterator[tuple[str, bytes]]:
# """Decompress and parse a multipart form data stream.

# Args:
# compressed_stream: Iterator of compressed bytes
# boundary: Multipart form boundary string

# Yields:
# Tuples of (part_name, part_data)
# """
# dctx = zstd.ZstdDecompressor()

# # Concatenate all compressed chunks and decompress
# compressed_data = b''.join(compressed_stream)
# stream_reader = dctx.stream_reader(compressed_data)
# decompressed_data = stream_reader.read()

# boundary_bytes = f'--{boundary}'.encode('utf-8')
# parts = decompressed_data.split(boundary_bytes)
# stream_reader.close()

# for part in parts:
# if not part or part.startswith(b'--\r\n'):
# continue

# # Split headers from content
# try:
# headers_raw, content = part.split(b'\r\n\r\n', 1)
# headers_text = headers_raw.decode('utf-8')

# # Extract part name from Content-Disposition
# content_disp = next(
# line for line in headers_text.split('\r\n')
# if line.startswith('Content-Disposition:')
# )
# part_name = content_disp.split('name="')[1].split('"')[0]

# # Remove trailing \r\n
# if content.endswith(b'\r\n'):
# content = content[:-2]

# yield part_name, content

# except Exception as e:
# logger.warning(f"Failed to parse multipart form part: {e}")
# continue

# def _send_multipart_req(self, acc: MultipartPartsAndContext, *, attempts: int = 3):
# """Test function that decompresses and validates the data locally"""
# compressor = StreamingMultipartCompressor(compression_level=3, blocksize=65536, boundary=BOUNDARY)

# compressed_data_iter = compressor.compress_multipart_stream(acc)

# # decompress and validate locally
# try:
# print("Decompressing and validating multipart data...")
# for part_name, content in self.decompress_multipart_stream(compressed_data_iter):
# print(f"Received part: {part_name}")
# print(f"Content length: {len(content)} bytes")

# # Validate JSON content if applicable
# if part_name.endswith('.json'):
# try:
# json_content = _orjson.loads(content)
# print(f"Valid JSON content for {part_name}")
# except Exception as e:
# print(f"Invalid JSON content for {part_name}: {e}")
# raise

# print("Successfully validated all multipart data")
# return True

# except Exception as e:
# print(f"Failed to validate multipart data: {e}")
# raise
# finally:
# compressor.close()

Removed:

def _send_multipart_req(self, acc: MultipartPartsAndContext, *, attempts: int = 3):
    parts = acc.parts
    _context = acc.context
    for api_url, api_key in self._write_api_urls.items():
        for idx in range(1, attempts + 1):
            try:
                encoder = rqtb_multipart.MultipartEncoder(parts, boundary=BOUNDARY)
                if encoder.len <= 20_000_000:  # ~20 MB
                    data = encoder.to_string()
                else:
                    data = encoder
                self.request_with_retries(
                    "POST",
                    f"{api_url}/runs/multipart",
                    request_kwargs={
                        "data": data,
                        "headers": {
                            **self._headers,
                            X_API_KEY: api_key,
                            "Content-Type": encoder.content_type,
                        },
                    },
                    stop_after_attempt=1,
                    _context=_context,
                )
                break
            except ls_utils.LangSmithConflictError:
                break
            except (
                ls_utils.LangSmithConnectionError,
                ls_utils.LangSmithRequestTimeout,
                ls_utils.LangSmithAPIError,
            ) as exc:
                if idx == attempts:
                    logger.warning(f"Failed to multipart ingest runs: {exc}")
                else:
                    continue
            except Exception as e:
                try:
                    exc_desc_lines = traceback.format_exception_only(type(e), e)
                    exc_desc = "".join(exc_desc_lines).rstrip()
                    logger.warning(f"Failed to multipart ingest runs: {exc_desc}")
                except Exception:
                    logger.warning(f"Failed to multipart ingest runs: {repr(e)}")
                # do not retry by default
                return

Added:

def _send_multipart_req(self, acc: MultipartPartsAndContext, *, attempts: int = 3):
    try:
        compressor = StreamingMultipartCompressor(compression_level=3, blocksize=65536, boundary=BOUNDARY)

        compressed_data_iter = compressor.compress_multipart_stream(acc)

        headers = {
            **self._headers,
            X_API_KEY: None,  # Set inside the loop for each api_url
            "Content-Type": f"multipart/form-data; boundary={BOUNDARY}",
            "Content-Encoding": "zstd",
        }

        for api_url, api_key in self._write_api_urls.items():
            headers[X_API_KEY] = api_key
            for idx in range(1, attempts + 1):
                try:
                    self.request_with_retries(
                        "POST",
                        f"{api_url}/runs/multipart",
                        request_kwargs={
                            "data": compressed_data_iter,
                            "headers": headers,
                        },
                        stop_after_attempt=1,
                        _context=acc.context,
                    )
                    break
                except ls_utils.LangSmithConflictError:
                    break
                except (
                    ls_utils.LangSmithConnectionError,
                    ls_utils.LangSmithRequestTimeout,
                    ls_utils.LangSmithAPIError,
                ) as exc:
                    if idx == attempts:
                        logger.warning(f"Failed to multipart ingest runs: {exc}")
                    else:
                        continue
                except Exception as e:
                    try:
                        exc_desc_lines = traceback.format_exception_only(type(e), e)
                        exc_desc = "".join(exc_desc_lines).rstrip()
                        logger.warning(f"Failed to multipart ingest runs: {exc_desc}")
                    except Exception:
                        logger.warning(f"Failed to multipart ingest runs: {repr(e)}")
                    # do not retry by default
                    return
    finally:
        compressor.close()
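
Because the request body is now a generator, requests will send it using chunked transfer encoding, and the backend has to accept the Content-Encoding: zstd header. A quick local sanity check, mirroring the commented-out validation helper above, is to decompress the stream and look for the closing boundary. This is a sketch only; `roundtrip_ok` is a hypothetical name, and `acc` is assumed to be a MultipartPartsAndContext built the same way as in _multipart_ingest_ops:

```python
import io

import zstandard as zstd

def roundtrip_ok(acc, boundary: str = BOUNDARY) -> bool:
    # Compress the multipart payload, then decompress it locally and verify
    # that the body terminates with the closing multipart boundary.
    compressor = StreamingMultipartCompressor(boundary=boundary)
    try:
        compressed = b"".join(compressor.compress_multipart_stream(acc))
    finally:
        compressor.close()
    # stream_reader handles frames written without a known content size.
    body = zstd.ZstdDecompressor().stream_reader(io.BytesIO(compressed)).read()
    return body.endswith(f"--{boundary}--\r\n".encode())
```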

def update_run(
self,