Skip to content

Commit

Permalink
add parallel workers for sending multipart requests
Browse files Browse the repository at this point in the history
  • Loading branch information
angus-langchain committed Dec 12, 2024
1 parent 066b9b6 commit ec15660
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 14 deletions.
52 changes: 44 additions & 8 deletions python/langsmith/_internal/_background_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)

import zstandard as zstd

import os
from langsmith import schemas as ls_schemas
from langsmith._internal._constants import (
_AUTO_SCALE_DOWN_NEMPTY_TRIGGER,
Expand Down Expand Up @@ -94,7 +94,7 @@ def _tracing_thread_drain_queue(


def _tracing_thread_drain_compressed_buffer(
client: Client, size_limit: int = 100, size_limit_bytes: int = 50 * 1024 * 1024
client: Client, size_limit: int = 100, size_limit_bytes: int = 65_536
) -> Optional[io.BytesIO]:
assert client.compressed_runs_buffer is not None
assert client.compressor_writer is not None
Expand Down Expand Up @@ -155,7 +155,7 @@ def _ensure_ingest_config(
) -> ls_schemas.BatchIngestConfig:
default_config = ls_schemas.BatchIngestConfig(
use_multipart_endpoint=False,
size_limit_bytes=50 * 1024 * 1024,
size_limit_bytes=None, # Note this field is not used here
size_limit=100,
scale_up_nthreads_limit=_AUTO_SCALE_UP_NTHREADS_LIMIT,
scale_up_qsize_trigger=_AUTO_SCALE_UP_QSIZE_TRIGGER,
Expand Down Expand Up @@ -231,16 +231,44 @@ def keep_thread_active() -> bool:
):
_tracing_thread_handle_batch(client, tracing_queue, next_batch, use_multipart)

def _worker_thread_func(client: Client, request_queue: Queue) -> None:
    """Consume drained compressed buffers from *request_queue* and send them.

    Runs until a ``None`` sentinel is received. Each queue item is a
    compressed multipart payload produced by the drain loop and is
    forwarded to the backend via ``client._send_compressed_multipart_req``.
    """
    while True:
        try:
            item = request_queue.get()
            if item is None:
                # Shutdown sentinel: exit the loop. task_done() still runs
                # in the finally block so Queue.join() accounting balances
                # the put() that enqueued the sentinel.
                break
            client._send_compressed_multipart_req(item)
        except Exception:
            # Never let a send failure kill the worker thread; log and
            # keep consuming so the queue continues to drain.
            logger.error("Error in worker thread processing request", exc_info=True)
        finally:
            request_queue.task_done()

def tracing_control_thread_func_compress(client_ref: weakref.ref[Client]) -> None:
def tracing_control_thread_func_compress_parallel(client_ref: weakref.ref[Client]) -> None:
client = client_ref()
if client is None:
return

batch_ingest_config = _ensure_ingest_config(client.info)
size_limit: int = batch_ingest_config["size_limit"]
size_limit_bytes = batch_ingest_config.get("size_limit_bytes", 50 * 1024 * 1024)
size_limit_bytes = batch_ingest_config.get("size_limit_bytes", 65_536)
assert size_limit_bytes is not None

num_workers = min(4, os.cpu_count())
request_queue: Queue = Queue(maxsize=num_workers * 2)
workers = []

for _ in range(num_workers):
worker = threading.Thread(
target=_worker_thread_func,
args=(client, request_queue),
)
worker.start()
workers.append(worker)

def keep_thread_active() -> bool:
# if `client.cleanup()` was called, stop thread
if not client or (
Expand All @@ -258,7 +286,7 @@ def keep_thread_active() -> bool:
client, size_limit, size_limit_bytes
)
if data_stream is not None:
client._send_compressed_multipart_req(data_stream)
request_queue.put(data_stream)
else:
time.sleep(0.05)
except Exception:
Expand All @@ -271,9 +299,17 @@ def keep_thread_active() -> bool:
client, size_limit=1, size_limit_bytes=1
) # Force final drain
if final_data_stream is not None:
client._send_compressed_multipart_req(final_data_stream)
request_queue.put(final_data_stream)

request_queue.join()

for _ in workers:
request_queue.put(None)
for worker in workers:
worker.join()

except Exception:
logger.error("Error in final buffer drain", exc_info=True)
logger.error("Error in final cleanup", exc_info=True)


def _tracing_sub_thread_func(
Expand Down
9 changes: 3 additions & 6 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
tracing_control_thread_func as _tracing_control_thread_func,
)
from langsmith._internal._background_thread import (
tracing_control_thread_func_compress as _tracing_control_thread_func_compress,
tracing_control_thread_func_compress_parallel as _tracing_control_thread_func_compress_parallel,
)
from langsmith._internal._beta_decorator import warn_beta
from langsmith._internal._constants import (
Expand Down Expand Up @@ -526,7 +526,7 @@ def __init__(
if auto_batch_tracing and self.compress_traces:
self.tracing_queue: Optional[PriorityQueue] = None
threading.Thread(
target=_tracing_control_thread_func_compress,
target=_tracing_control_thread_func_compress_parallel,
# arg must be a weakref to self to avoid the Thread object
# preventing garbage collection of the Client object
args=(weakref.ref(self),),
Expand Down Expand Up @@ -1716,10 +1716,7 @@ def _send_multipart_req(self, acc: MultipartPartsAndContext, *, attempts: int =
return

def _send_compressed_multipart_req(self, data_stream, *, attempts: int = 3):
"""Send a zstd-compressed multipart form data stream to the backend.
Uses similar retry logic as _send_multipart_req.
"""
"""Send a zstd-compressed multipart form data stream to the backend."""
_context: str = ""

for api_url, api_key in self._write_api_urls.items():
Expand Down

0 comments on commit ec15660

Please sign in to comment.