diff --git a/python/langsmith/_internal/_background_thread.py b/python/langsmith/_internal/_background_thread.py
index 54013c7a6..6b8bbdcd7 100644
--- a/python/langsmith/_internal/_background_thread.py
+++ b/python/langsmith/_internal/_background_thread.py
@@ -8,6 +8,8 @@ import threading
 import time
 import weakref
+from multiprocessing import cpu_count
+import concurrent.futures
 from queue import Empty, Queue
 from typing import (
     TYPE_CHECKING,
@@ -36,6 +38,7 @@
 
 logger = logging.getLogger("langsmith.client")
 
+HTTP_REQUEST_THREAD_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=cpu_count())
 
 @functools.total_ordering
 class TracingQueueItem:
@@ -233,22 +236,6 @@ def keep_thread_active() -> bool:
         _tracing_thread_handle_batch(client, tracing_queue, next_batch, use_multipart)
 
 
-def _worker_thread_func(client: Client, request_queue: Queue) -> None:
-    while True:
-        try:
-            data_stream = request_queue.get()
-
-            if data_stream is None:
-                break
-
-            client._send_compressed_multipart_req(data_stream)
-
-        except Exception:
-            logger.error("Error in worker thread processing request", exc_info=True)
-        finally:
-            request_queue.task_done()
-
-
 def tracing_control_thread_func_compress_parallel(
     client_ref: weakref.ref[Client],
 ) -> None:
@@ -261,18 +248,6 @@ def tracing_control_thread_func_compress_parallel(
     size_limit_bytes = batch_ingest_config.get("size_limit_bytes", 20_971_520)
     assert size_limit_bytes is not None
 
-    num_workers = min(4, os.cpu_count() or 1)
-    request_queue: Queue = Queue(maxsize=num_workers * 2)
-    workers = []
-
-    for _ in range(num_workers):
-        worker = threading.Thread(
-            target=_worker_thread_func,
-            args=(client, request_queue),
-        )
-        worker.start()
-        workers.append(worker)
-
     def keep_thread_active() -> bool:
         # if `client.cleanup()` was called, stop thread
         if not client or (
@@ -290,7 +265,12 @@ def keep_thread_active() -> bool:
                 client, size_limit, size_limit_bytes
             )
             if data_stream is not None:
-                request_queue.put(data_stream)
+                try:
+                    # Fire-and-forget: the shared pool sends the request;
+                    # fall back to a synchronous send if the pool is shut down.
+                    HTTP_REQUEST_THREAD_POOL.submit(
+                        client._send_compressed_multipart_req, data_stream
+                    )
+                except RuntimeError:
+                    client._send_compressed_multipart_req(data_stream)
             else:
                 time.sleep(0.05)
         except Exception:
@@ -303,13 +283,14 @@ def keep_thread_active() -> bool:
             client, size_limit=1, size_limit_bytes=1
         )  # Force final drain
         if final_data_stream is not None:
-            request_queue.put(final_data_stream)
-
-        request_queue.join()
-
-        for _ in workers:
-            request_queue.put(None)
-        for worker in workers:
-            worker.join()
+            try:
+                # Block until the final request completes before the thread
+                # exits; wait() takes an iterable of futures, not a Future.
+                final_future = HTTP_REQUEST_THREAD_POOL.submit(
+                    client._send_compressed_multipart_req, final_data_stream
+                )
+                concurrent.futures.wait([final_future])
+            except RuntimeError:
+                client._send_compressed_multipart_req(final_data_stream)
+
     except Exception:
         logger.error("Error in final cleanup", exc_info=True)