Skip to content

Commit

Permalink
Autoscale background threads for tracer auto batching
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Jan 25, 2024
1 parent 5bed311 commit 836648c
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 7 deletions.
53 changes: 46 additions & 7 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def __init__(
self.tracing_queue: Optional[PriorityQueue] = PriorityQueue()

threading.Thread(
target=_tracing_thread_func,
target=_tracing_control_thread_func,
# arg must be a weakref to self to avoid the Thread object
# preventing garbage collection of the Client object
args=(weakref.ref(self),),
Expand Down Expand Up @@ -3138,11 +3138,11 @@ def _evaluate_strings(self, prediction, reference=None, input=None, **kwargs) ->


def _tracing_thread_drain_queue(
tracing_queue: Queue, limit: Optional[int] = None
tracing_queue: Queue, limit: int = 100, block: bool = True
) -> List[TracingQueueItem]:
next_batch: List[TracingQueueItem] = []
try:
while item := tracing_queue.get(block=True, timeout=0.25):
while item := tracing_queue.get(block=block, timeout=0.25):
next_batch.append(item)
if limit and len(next_batch) >= limit:
break
Expand All @@ -3163,24 +3163,63 @@ def _tracing_thread_handle_batch(
tracing_queue.task_done()


def _tracing_thread_func(client_ref: weakref.ref[Client]) -> None:
def _tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
client = client_ref()
if client is None:
return
tracing_queue = client.tracing_queue
assert tracing_queue is not None

sub_threads: List[threading.Thread] = []

# loop until
while (
# the main thread dies
threading.main_thread().is_alive()
# or we're the only remaining reference to the client
and sys.getrefcount(client) > 3
and sys.getrefcount(client) > 3 + len(sub_threads)
# 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
):
if next_batch := _tracing_thread_drain_queue(tracing_queue, 100):
print("im looping", tracing_queue.qsize())
for thread in sub_threads:
if not thread.is_alive():
sub_threads.remove(thread)
if tracing_queue.qsize() > 1000:
new_thread = threading.Thread(
target=_tracing_sub_thread_func, args=(weakref.ref(client),)
)
sub_threads.append(new_thread)
new_thread.start()
if next_batch := _tracing_thread_drain_queue(tracing_queue):
_tracing_thread_handle_batch(client, tracing_queue, next_batch)

# drain the queue on exit
while next_batch := _tracing_thread_drain_queue(tracing_queue, 100):
while next_batch := _tracing_thread_drain_queue(tracing_queue, block=False):
_tracing_thread_handle_batch(client, tracing_queue, next_batch)


def _tracing_sub_thread_func(client_ref: weakref.ref[Client]) -> None:
client = client_ref()
if client is None:
return
tracing_queue = client.tracing_queue
assert tracing_queue is not None

seen_successive_empty_queues = 0

# loop until
while (
# the main thread dies
threading.main_thread().is_alive()
# or we've seen the queue empty 4 times in a row
and seen_successive_empty_queues < 5
):
if next_batch := _tracing_thread_drain_queue(tracing_queue):
seen_successive_empty_queues = 0
_tracing_thread_handle_batch(client, tracing_queue, next_batch)
else:
seen_successive_empty_queues += 1

# drain the queue on exit
while next_batch := _tracing_thread_drain_queue(tracing_queue, block=False):
_tracing_thread_handle_batch(client, tracing_queue, next_batch)
116 changes: 116 additions & 0 deletions python/tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,23 +208,139 @@ def __call__(self, *args: object, **kwargs: object) -> None:
self.counter += 1


@pytest.mark.parametrize("auto_batch_tracing", [True, False])
def test_client_gc_empty(auto_batch_tracing: bool) -> None:
client = Client(
api_url="http://localhost:1984",
api_key="123",
auto_batch_tracing=auto_batch_tracing,
)
tracker = CallTracker()
weakref.finalize(client, tracker)
assert tracker.counter == 0

del client
time.sleep(1) # Give the background thread time to stop
gc.collect() # Force garbage collection
assert tracker.counter == 1, "Client was not garbage collected"


@pytest.mark.parametrize("auto_batch_tracing", [True, False])
def test_client_gc(auto_batch_tracing: bool) -> None:
session = mock.MagicMock(spec=requests.Session)
client = Client(
api_url="http://localhost:1984",
api_key="123",
auto_batch_tracing=auto_batch_tracing,
session=session,
)
tracker = CallTracker()
weakref.finalize(client, tracker)
assert tracker.counter == 0

for _ in range(10):
id = uuid.uuid4()
client.create_run(
"my_run",
inputs={},
run_type="llm",
execution_order=1,
id=id,
trace_id=id,
dotted_order=id,
)

if auto_batch_tracing:
assert client.tracing_queue
client.tracing_queue.join()

request_calls = [call for call in session.request.mock_calls if call.args]
assert len(request_calls) == 1
for call in request_calls:
assert call.args[0] == "post"
assert call.args[1] == "http://localhost:1984/runs/batch"
else:
request_calls = [call for call in session.request.mock_calls if call.args]
assert len(request_calls) == 10
for call in request_calls:
assert call.args[0] == "post"
assert call.args[1] == "http://localhost:1984/runs"

del client
time.sleep(1) # Give the background thread time to stop
gc.collect() # Force garbage collection
assert tracker.counter == 1, "Client was not garbage collected"


@pytest.mark.parametrize("auto_batch_tracing", [True, False])
def test_client_gc_no_batched_runs(auto_batch_tracing: bool) -> None:
session = mock.MagicMock(spec=requests.Session)
client = Client(
api_url="http://localhost:1984",
api_key="123",
auto_batch_tracing=auto_batch_tracing,
session=session,
)
tracker = CallTracker()
weakref.finalize(client, tracker)
assert tracker.counter == 0

# because no trace_id/dotted_order provided, auto batch is disabled
for _ in range(10):
client.create_run(
"my_run", inputs={}, run_type="llm", execution_order=1, id=uuid.uuid4()
)
request_calls = [call for call in session.request.mock_calls if call.args]
assert len(request_calls) == 10
for call in request_calls:
assert call.args[0] == "post"
assert call.args[1] == "http://localhost:1984/runs"

del client
time.sleep(1) # Give the background thread time to stop
gc.collect() # Force garbage collection
assert tracker.counter == 1, "Client was not garbage collected"


def test_client_gc_after_autoscale() -> None:
session = mock.MagicMock(spec=requests.Session)
client = Client(
api_url="http://localhost:1984",
api_key="123",
session=session,
)
tracker = CallTracker()
weakref.finalize(client, tracker)
assert tracker.counter == 0

tracing_queue = client.tracing_queue
assert tracing_queue is not None

for _ in range(50_000):
id = uuid.uuid4()
client.create_run(
"my_run",
inputs={},
run_type="llm",
execution_order=1,
id=id,
trace_id=id,
dotted_order=id,
)

del client
tracing_queue.join()
time.sleep(2) # Give the background threads time to stop
gc.collect() # Force garbage collection
assert tracker.counter == 1, "Client was not garbage collected"

request_calls = [call for call in session.request.mock_calls if call.args]
assert len(request_calls) >= 500 and len(request_calls) <= 550
for call in request_calls:
assert call.args[0] == "post"
assert call.args[1] == "http://localhost:1984/runs/batch"


@pytest.mark.parametrize("auto_batch_tracing", [True, False])
def test_create_run_includes_langchain_env_var_metadata(
auto_batch_tracing: bool,
Expand Down

0 comments on commit 836648c

Please sign in to comment.