diff --git a/python/langsmith/client.py b/python/langsmith/client.py
index a674b034f..79d736039 100644
--- a/python/langsmith/client.py
+++ b/python/langsmith/client.py
@@ -348,7 +348,7 @@ def __init__(
             self.tracing_queue: Optional[PriorityQueue] = PriorityQueue()
 
             threading.Thread(
-                target=_tracing_thread_func,
+                target=_tracing_control_thread_func,
                 # arg must be a weakref to self to avoid the Thread object
                 # preventing garbage collection of the Client object
                 args=(weakref.ref(self),),
@@ -3147,11 +3147,18 @@ def _evaluate_strings(self, prediction, reference=None, input=None, **kwargs) ->
 def _tracing_thread_drain_queue(
-    tracing_queue: Queue, limit: Optional[int] = None
+    tracing_queue: Queue, limit: int = 100, block: bool = True
 ) -> List[TracingQueueItem]:
     next_batch: List[TracingQueueItem] = []
     try:
-        while item := tracing_queue.get(block=True, timeout=0.25):
+        # wait 250ms for the first item, then
+        # - drain the queue with a 50ms block timeout
+        # - stop draining if we hit the limit
+        # shorter drain timeout is used instead of non-blocking calls to
+        # avoid creating too many small batches
+        if item := tracing_queue.get(block=block, timeout=0.25):
+            next_batch.append(item)
+        while item := tracing_queue.get(block=block, timeout=0.05):
             next_batch.append(item)
             if limit and len(next_batch) >= limit:
                 break
     except Empty:
         pass
@@ -3172,24 +3179,70 @@ def _tracing_thread_handle_batch(
         tracing_queue.task_done()
 
 
-def _tracing_thread_func(client_ref: weakref.ref[Client]) -> None:
+_AUTO_SCALE_UP_QSIZE_TRIGGER = 1000
+_AUTO_SCALE_UP_NTHREADS_LIMIT = 16
+_AUTO_SCALE_DOWN_NEMPTY_TRIGGER = 4
+
+
+def _tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
     client = client_ref()
     if client is None:
         return
     tracing_queue = client.tracing_queue
     assert tracing_queue is not None
 
+    sub_threads: List[threading.Thread] = []
+
+    # loop until
     while (
         # the main thread dies
         threading.main_thread().is_alive()
         # or we're the only remaining reference to the client
-        and sys.getrefcount(client) > 3
+        and sys.getrefcount(client) > 3 + len(sub_threads)
         # 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
     ):
-        if next_batch := _tracing_thread_drain_queue(tracing_queue, 100):
+        for thread in sub_threads:
+            if not thread.is_alive():
+                sub_threads.remove(thread)
+        if (
+            len(sub_threads) < _AUTO_SCALE_UP_NTHREADS_LIMIT
+            and tracing_queue.qsize() > _AUTO_SCALE_UP_QSIZE_TRIGGER
+        ):
+            new_thread = threading.Thread(
+                target=_tracing_sub_thread_func, args=(weakref.ref(client),)
+            )
+            sub_threads.append(new_thread)
+            new_thread.start()
+        if next_batch := _tracing_thread_drain_queue(tracing_queue):
+            _tracing_thread_handle_batch(client, tracing_queue, next_batch)
+
+    # drain the queue on exit
+    while next_batch := _tracing_thread_drain_queue(tracing_queue, block=False):
+        _tracing_thread_handle_batch(client, tracing_queue, next_batch)
+
+
+def _tracing_sub_thread_func(client_ref: weakref.ref[Client]) -> None:
+    client = client_ref()
+    if client is None:
+        return
+    tracing_queue = client.tracing_queue
+    assert tracing_queue is not None
+
+    seen_successive_empty_queues = 0
+
+    # loop until
+    while (
+        # the main thread dies
+        threading.main_thread().is_alive()
+        # or we've seen the queue empty 4 times in a row
+        and seen_successive_empty_queues <= _AUTO_SCALE_DOWN_NEMPTY_TRIGGER
+    ):
+        if next_batch := _tracing_thread_drain_queue(tracing_queue):
+            seen_successive_empty_queues = 0
             _tracing_thread_handle_batch(client, tracing_queue, next_batch)
+        else:
+            seen_successive_empty_queues += 1
 
     # drain the queue on exit
-    while next_batch := _tracing_thread_drain_queue(tracing_queue, 100):
+    while next_batch := _tracing_thread_drain_queue(tracing_queue, block=False):
         _tracing_thread_handle_batch(client, tracing_queue, next_batch)
diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py
index 9e455c87e..956236dd5 100644
--- a/python/tests/unit_tests/test_client.py
+++ b/python/tests/unit_tests/test_client.py
@@ -208,23 +208,140 @@ def __call__(self, *args: object, **kwargs: object) -> None:
         self.counter += 1
 
 
+@pytest.mark.parametrize("auto_batch_tracing", [True, False])
+def test_client_gc_empty(auto_batch_tracing: bool) -> None:
+    client = Client(
+        api_url="http://localhost:1984",
+        api_key="123",
+        auto_batch_tracing=auto_batch_tracing,
+    )
+    tracker = CallTracker()
+    weakref.finalize(client, tracker)
+    assert tracker.counter == 0
+
+    del client
+    time.sleep(1)  # Give the background thread time to stop
+    gc.collect()  # Force garbage collection
+    assert tracker.counter == 1, "Client was not garbage collected"
+
+
 @pytest.mark.parametrize("auto_batch_tracing", [True, False])
 def test_client_gc(auto_batch_tracing: bool) -> None:
+    session = mock.MagicMock(spec=requests.Session)
     client = Client(
         api_url="http://localhost:1984",
         api_key="123",
         auto_batch_tracing=auto_batch_tracing,
+        session=session,
     )
     tracker = CallTracker()
     weakref.finalize(client, tracker)
     assert tracker.counter == 0
 
+    for _ in range(10):
+        id = uuid.uuid4()
+        client.create_run(
+            "my_run",
+            inputs={},
+            run_type="llm",
+            execution_order=1,
+            id=id,
+            trace_id=id,
+            dotted_order=id,
+        )
+
+    if auto_batch_tracing:
+        assert client.tracing_queue
+        client.tracing_queue.join()
+
+        request_calls = [call for call in session.request.mock_calls if call.args]
+        assert len(request_calls) == 1
+        for call in request_calls:
+            assert call.args[0] == "post"
+            assert call.args[1] == "http://localhost:1984/runs/batch"
+    else:
+        request_calls = [call for call in session.request.mock_calls if call.args]
+        assert len(request_calls) == 10
+        for call in request_calls:
+            assert call.args[0] == "post"
+            assert call.args[1] == "http://localhost:1984/runs"
+
     del client
     time.sleep(1)  # Give the background thread time to stop
     gc.collect()  # Force garbage collection
     assert tracker.counter == 1, "Client was not garbage collected"
 
 
+@pytest.mark.parametrize("auto_batch_tracing", [True, False])
+def test_client_gc_no_batched_runs(auto_batch_tracing: bool) -> None:
+    session = mock.MagicMock(spec=requests.Session)
+    client = Client(
+        api_url="http://localhost:1984",
+        api_key="123",
+        auto_batch_tracing=auto_batch_tracing,
+        session=session,
+    )
+    tracker = CallTracker()
+    weakref.finalize(client, tracker)
+    assert tracker.counter == 0
+
+    # because no trace_id/dotted_order provided, auto batch is disabled
+    for _ in range(10):
+        client.create_run(
+            "my_run", inputs={}, run_type="llm", execution_order=1, id=uuid.uuid4()
+        )
+    request_calls = [call for call in session.request.mock_calls if call.args]
+    assert len(request_calls) == 10
+    for call in request_calls:
+        assert call.args[0] == "post"
+        assert call.args[1] == "http://localhost:1984/runs"
+
+    del client
+    time.sleep(1)  # Give the background thread time to stop
+    gc.collect()  # Force garbage collection
+    assert tracker.counter == 1, "Client was not garbage collected"
+
+
+def test_client_gc_after_autoscale() -> None:
+    session = mock.MagicMock(spec=requests.Session)
+    client = Client(
+        api_url="http://localhost:1984",
+        api_key="123",
+        session=session,
+        auto_batch_tracing=True,
+    )
+    tracker = CallTracker()
+    weakref.finalize(client, tracker)
+    assert tracker.counter == 0
+
+    tracing_queue = client.tracing_queue
+    assert tracing_queue is not None
+
+    for _ in range(50_000):
+        id = uuid.uuid4()
+        client.create_run(
+            "my_run",
+            inputs={},
+            run_type="llm",
+            execution_order=1,
+            id=id,
+            trace_id=id,
+            dotted_order=id,
+        )
+
+    del client
+    tracing_queue.join()
+    time.sleep(2)  # Give the background threads time to stop
+    gc.collect()  # Force garbage collection
+    assert tracker.counter == 1, "Client was not garbage collected"
+
+    request_calls = [call for call in session.request.mock_calls if call.args]
+    assert len(request_calls) >= 500 and len(request_calls) <= 550
+    for call in request_calls:
+        assert call.args[0] == "post"
+        assert call.args[1] == "http://localhost:1984/runs/batch"
+
+
 @pytest.mark.parametrize("auto_batch_tracing", [True, False])
 def test_create_run_includes_langchain_env_var_metadata(
     auto_batch_tracing: bool,
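
For reviewers unfamiliar with the pattern, the sketch below is a minimal, self-contained model of what this patch does: one control thread drains the queue in batches (blocking up to 250ms for the first item, then sweeping with a 50ms timeout, capped at `limit`), and spawns short-lived helper threads once the queue backs up past a threshold; helpers retire after several successive empty drains. The `drain_queue`/`control_thread`/`sub_thread` names, the `handle` callback, and the `stop` event are illustrative stand-ins, not the SDK's API; only the thresholds mirror the constants in the diff.

# Illustrative sketch only -- models the diff's drain/auto-scale logic with
# hypothetical names (drain_queue, control_thread, sub_thread, handle, stop).
import threading
from queue import Empty, Queue
from typing import Callable, List

_AUTO_SCALE_UP_QSIZE_TRIGGER = 1000
_AUTO_SCALE_UP_NTHREADS_LIMIT = 16
_AUTO_SCALE_DOWN_NEMPTY_TRIGGER = 4

Handler = Callable[[List[str]], None]


def drain_queue(q: "Queue[str]", limit: int = 100, block: bool = True) -> List[str]:
    # Block up to 250ms for the first item, then sweep with a short 50ms
    # timeout; the short block (rather than get_nowait) avoids producing
    # many tiny batches when producers are only slightly slower than us.
    batch: List[str] = []
    try:
        if item := q.get(block=block, timeout=0.25):
            batch.append(item)
        while item := q.get(block=block, timeout=0.05):
            batch.append(item)
            if len(batch) >= limit:
                break
    except Empty:
        pass
    return batch


def sub_thread(q: "Queue[str]", handle: Handler) -> None:
    # Helper consumer: scales back down by exiting after several
    # successive empty drains.
    empty_seen = 0
    while empty_seen <= _AUTO_SCALE_DOWN_NEMPTY_TRIGGER:
        if batch := drain_queue(q):
            empty_seen = 0
            handle(batch)
        else:
            empty_seen += 1


def control_thread(q: "Queue[str]", handle: Handler, stop: threading.Event) -> None:
    # Primary consumer: spawns helpers while the queue is backed up,
    # then drains whatever is left once asked to stop.
    helpers: List[threading.Thread] = []
    while not stop.is_set():
        helpers = [t for t in helpers if t.is_alive()]  # prune dead helpers
        if (
            len(helpers) < _AUTO_SCALE_UP_NTHREADS_LIMIT
            and q.qsize() > _AUTO_SCALE_UP_QSIZE_TRIGGER
        ):
            t = threading.Thread(target=sub_thread, args=(q, handle), daemon=True)
            helpers.append(t)
            t.start()
        if batch := drain_queue(q):
            handle(batch)
    while batch := drain_queue(q, block=False):
        handle(batch)


if __name__ == "__main__":
    q: "Queue[str]" = Queue()
    stop = threading.Event()
    sizes: List[int] = []

    def handle(batch: List[str]) -> None:
        sizes.append(len(batch))
        for _ in batch:
            q.task_done()

    ctrl = threading.Thread(target=control_thread, args=(q, handle, stop))
    ctrl.start()
    for i in range(5_000):
        q.put(f"run-{i}")
    q.join()  # every item handled
    stop.set()
    ctrl.join()
    print(f"{sum(sizes)} items in {len(sizes)} batches")

Note the scale-down choice the sketch copies from the diff: helpers exit on sustained emptiness rather than being signalled and joined, so the control loop needs no bookkeeping beyond pruning dead threads.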
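
The GC tests above lean on the patch's second mechanism: worker threads hold only a weakref to the Client, and the control loop's `sys.getrefcount(client) > 3 + len(sub_threads)` guard lets it notice when no outside references remain, flush, and exit, so `del client` plus `gc.collect()` actually collects the Client (the `+ len(sub_threads)` term accounts for the strong reference each running helper holds). Below is a stripped-down model of that shutdown-by-refcount technique with hypothetical `Owner`/`worker` names; the threshold is 2 here because, unlike the SDK, nothing else in this sketch holds a reference.

# Hypothetical Owner/worker names; demonstrates shutdown-by-refcount only.
import sys
import threading
import time
import weakref


def worker(owner_ref: "weakref.ref[Owner]") -> None:
    owner = owner_ref()  # take a strong reference for the loop below
    if owner is None:
        return
    while (
        threading.main_thread().is_alive()
        # getrefcount sees our local `owner` plus its own argument temp,
        # so > 2 means someone outside this thread still holds the Owner
        and sys.getrefcount(owner) > 2
    ):
        time.sleep(0.05)  # the real thread drains the tracing queue here
    # a final flush would go here before the thread exits


class Owner:
    def __init__(self) -> None:
        self.thread = threading.Thread(
            # pass a weakref so the thread never keeps its Owner alive
            target=worker,
            args=(weakref.ref(self),),
        )
        self.thread.start()


o = Owner()
t = o.thread
del o     # drop the only strong reference outside the worker
t.join()  # worker sees the refcount fall and exits on its own
print("worker exited after owner became collectible")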