From 836648cdd618fd468273ccd290c587f05cc8a990 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 25 Jan 2024 12:59:12 -0800 Subject: [PATCH] Autoscale background threads for tracer auto batching --- python/langsmith/client.py | 53 +++++++++-- python/tests/unit_tests/test_client.py | 116 +++++++++++++++++++++++++ 2 files changed, 162 insertions(+), 7 deletions(-) diff --git a/python/langsmith/client.py b/python/langsmith/client.py index d8bc7a8bc..50629173d 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -349,7 +349,7 @@ def __init__( self.tracing_queue: Optional[PriorityQueue] = PriorityQueue() threading.Thread( - target=_tracing_thread_func, + target=_tracing_control_thread_func, # arg must be a weakref to self to avoid the Thread object # preventing garbage collection of the Client object args=(weakref.ref(self),), @@ -3138,11 +3138,11 @@ def _evaluate_strings(self, prediction, reference=None, input=None, **kwargs) -> def _tracing_thread_drain_queue( - tracing_queue: Queue, limit: Optional[int] = None + tracing_queue: Queue, limit: int = 100, block: bool = True ) -> List[TracingQueueItem]: next_batch: List[TracingQueueItem] = [] try: - while item := tracing_queue.get(block=True, timeout=0.25): + while item := tracing_queue.get(block=block, timeout=0.25): next_batch.append(item) if limit and len(next_batch) >= limit: break @@ -3163,24 +3163,63 @@ def _tracing_thread_handle_batch( tracing_queue.task_done() -def _tracing_thread_func(client_ref: weakref.ref[Client]) -> None: +def _tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None: client = client_ref() if client is None: return tracing_queue = client.tracing_queue assert tracing_queue is not None + sub_threads: List[threading.Thread] = [] + # loop until while ( # the main thread dies threading.main_thread().is_alive() # or we're the only remaining reference to the client - and sys.getrefcount(client) > 3 + and sys.getrefcount(client) > 3 + len(sub_threads) # 1 for 
this func, 1 for getrefcount, 1 for _get_data_type_cached ): - if next_batch := _tracing_thread_drain_queue(tracing_queue, 100): + for thread in sub_threads: + if not thread.is_alive(): + sub_threads.remove(thread) + if tracing_queue.qsize() > 1000: + new_thread = threading.Thread( + target=_tracing_sub_thread_func, args=(weakref.ref(client),) + ) + sub_threads.append(new_thread) + new_thread.start() + if next_batch := _tracing_thread_drain_queue(tracing_queue): _tracing_thread_handle_batch(client, tracing_queue, next_batch) # drain the queue on exit - while next_batch := _tracing_thread_drain_queue(tracing_queue, 100): + while next_batch := _tracing_thread_drain_queue(tracing_queue, block=False): + _tracing_thread_handle_batch(client, tracing_queue, next_batch) + + +def _tracing_sub_thread_func(client_ref: weakref.ref[Client]) -> None: + client = client_ref() + if client is None: + return + tracing_queue = client.tracing_queue + assert tracing_queue is not None + + seen_successive_empty_queues = 0 + + # loop until + while ( + # the main thread dies + threading.main_thread().is_alive() + # or we've seen the queue empty 5 times in a row + and seen_successive_empty_queues < 5 + ): + if next_batch := _tracing_thread_drain_queue(tracing_queue): + seen_successive_empty_queues = 0 + _tracing_thread_handle_batch(client, tracing_queue, next_batch) + else: + seen_successive_empty_queues += 1 + + # drain the queue on exit + while next_batch := _tracing_thread_drain_queue(tracing_queue, block=False): _tracing_thread_handle_batch(client, tracing_queue, next_batch) diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py index 9e455c87e..7cc484e5b 100644 --- a/python/tests/unit_tests/test_client.py +++ b/python/tests/unit_tests/test_client.py @@ -208,23 +208,139 @@ def __call__(self, *args: object, **kwargs: object) -> None: self.counter += 1 +@pytest.mark.parametrize("auto_batch_tracing", [True, 
False]) +def test_client_gc_empty(auto_batch_tracing: bool) -> None: + client = Client( + api_url="http://localhost:1984", + api_key="123", + auto_batch_tracing=auto_batch_tracing, + ) + tracker = CallTracker() + weakref.finalize(client, tracker) + assert tracker.counter == 0 + + del client + time.sleep(1) # Give the background thread time to stop + gc.collect() # Force garbage collection + assert tracker.counter == 1, "Client was not garbage collected" + + @pytest.mark.parametrize("auto_batch_tracing", [True, False]) def test_client_gc(auto_batch_tracing: bool) -> None: + session = mock.MagicMock(spec=requests.Session) client = Client( api_url="http://localhost:1984", api_key="123", auto_batch_tracing=auto_batch_tracing, + session=session, ) tracker = CallTracker() weakref.finalize(client, tracker) assert tracker.counter == 0 + for _ in range(10): + id = uuid.uuid4() + client.create_run( + "my_run", + inputs={}, + run_type="llm", + execution_order=1, + id=id, + trace_id=id, + dotted_order=id, + ) + + if auto_batch_tracing: + assert client.tracing_queue + client.tracing_queue.join() + + request_calls = [call for call in session.request.mock_calls if call.args] + assert len(request_calls) == 1 + for call in request_calls: + assert call.args[0] == "post" + assert call.args[1] == "http://localhost:1984/runs/batch" + else: + request_calls = [call for call in session.request.mock_calls if call.args] + assert len(request_calls) == 10 + for call in request_calls: + assert call.args[0] == "post" + assert call.args[1] == "http://localhost:1984/runs" + del client time.sleep(1) # Give the background thread time to stop gc.collect() # Force garbage collection assert tracker.counter == 1, "Client was not garbage collected" +@pytest.mark.parametrize("auto_batch_tracing", [True, False]) +def test_client_gc_no_batched_runs(auto_batch_tracing: bool) -> None: + session = mock.MagicMock(spec=requests.Session) + client = Client( + api_url="http://localhost:1984", + api_key="123", + 
auto_batch_tracing=auto_batch_tracing, + session=session, + ) + tracker = CallTracker() + weakref.finalize(client, tracker) + assert tracker.counter == 0 + + # because no trace_id/dotted_order provided, auto batch is disabled + for _ in range(10): + client.create_run( + "my_run", inputs={}, run_type="llm", execution_order=1, id=uuid.uuid4() + ) + request_calls = [call for call in session.request.mock_calls if call.args] + assert len(request_calls) == 10 + for call in request_calls: + assert call.args[0] == "post" + assert call.args[1] == "http://localhost:1984/runs" + + del client + time.sleep(1) # Give the background thread time to stop + gc.collect() # Force garbage collection + assert tracker.counter == 1, "Client was not garbage collected" + + +def test_client_gc_after_autoscale() -> None: + session = mock.MagicMock(spec=requests.Session) + client = Client( + api_url="http://localhost:1984", + api_key="123", + session=session, + ) + tracker = CallTracker() + weakref.finalize(client, tracker) + assert tracker.counter == 0 + + tracing_queue = client.tracing_queue + assert tracing_queue is not None + + for _ in range(50_000): + id = uuid.uuid4() + client.create_run( + "my_run", + inputs={}, + run_type="llm", + execution_order=1, + id=id, + trace_id=id, + dotted_order=id, + ) + + del client + tracing_queue.join() + time.sleep(2) # Give the background threads time to stop + gc.collect() # Force garbage collection + assert tracker.counter == 1, "Client was not garbage collected" + + request_calls = [call for call in session.request.mock_calls if call.args] + assert len(request_calls) >= 500 and len(request_calls) <= 550 + for call in request_calls: + assert call.args[0] == "post" + assert call.args[1] == "http://localhost:1984/runs/batch" + + @pytest.mark.parametrize("auto_batch_tracing", [True, False]) def test_create_run_includes_langchain_env_var_metadata( auto_batch_tracing: bool,