Autoscale background threads for tracer auto batching #382

Merged
39 commits, merged Jan 28, 2024

Commits

c3f82b8
Update
hinthornw Jan 8, 2024
7e96bf2
update
hinthornw Jan 8, 2024
f9c17ad
Update test
hinthornw Jan 8, 2024
0ccc957
Update test
hinthornw Jan 9, 2024
fae1b65
Add optional tracing sampling rate to langsmith sdk
nfcampos Jan 19, 2024
82ba901
Lint
nfcampos Jan 19, 2024
19d608d
In batch tracing endpoint, combine patch and post payloads where poss…
nfcampos Jan 19, 2024
93489b3
Merge branch 'main' into nc/batch-trace-combine
nfcampos Jan 19, 2024
099162a
Update python/tests/integration_tests/test_runs.py
nfcampos Jan 20, 2024
511bcac
Update python/langsmith/client.py
nfcampos Jan 20, 2024
787a9a7
Add optional tracing sampling rate to langsmith sdk (#370)
nfcampos Jan 20, 2024
763b7d8
Nc/batch trace combine (#371)
nfcampos Jan 20, 2024
fc34146
Lint
nfcampos Jan 20, 2024
1f83125
Add auto_batch_tracing modality for Client
nfcampos Jan 20, 2024
2780421
Lint
nfcampos Jan 20, 2024
439e0e4
Fix GC problem
nfcampos Jan 20, 2024
1f98fbe
Remove unused signal, add comments
nfcampos Jan 20, 2024
3faa5eb
Add missing parent_run_id
nfcampos Jan 20, 2024
3bd9e6f
Adjust config
nfcampos Jan 21, 2024
64febeb
Use a priority queue to group runs from same trace in same batch wher…
nfcampos Jan 21, 2024
4f5c793
Also chunk at end
nfcampos Jan 21, 2024
8658098
Update retry
hinthornw Jan 22, 2024
a675719
Actually retry
hinthornw Jan 23, 2024
482c13c
Actually Retry by default (#374)
nfcampos Jan 23, 2024
5bed311
Catch 409's (#375)
hinthornw Jan 24, 2024
836648c
Autoscale background threads for tracer auto batching
nfcampos Jan 25, 2024
8d00547
Remove print
nfcampos Jan 25, 2024
ddadf9a
Compromise
nfcampos Jan 25, 2024
08f6735
Add limits
nfcampos Jan 25, 2024
4538ac9
Add one more constant
nfcampos Jan 25, 2024
1645fcd
Adjust
nfcampos Jan 25, 2024
e9de9f9
Add auto_batch_tracing modality for Client (#372)
nfcampos Jan 28, 2024
7916eec
Turn auto-batch tracing off
hinthornw Jan 28, 2024
70ce886
Merge
hinthornw Jan 28, 2024
8732a81
fixup int tests
hinthornw Jan 28, 2024
79f1aca
merge
hinthornw Jan 28, 2024
0a5adfe
fixup
hinthornw Jan 28, 2024
96c1afb
Merge branch 'wfh/batch_run_create' into nc/25jan/tracer-auto-batch-a…
hinthornw Jan 28, 2024
fe02f82
merge
hinthornw Jan 28, 2024
67 changes: 60 additions & 7 deletions python/langsmith/client.py
@@ -348,7 +348,7 @@ def __init__(
self.tracing_queue: Optional[PriorityQueue] = PriorityQueue()

threading.Thread(
target=_tracing_thread_func,
target=_tracing_control_thread_func,
# arg must be a weakref to self to avoid the Thread object
# preventing garbage collection of the Client object
args=(weakref.ref(self),),
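The comment above is the key constraint behind the background-thread changes: the tracing thread is only ever handed a weak reference to the Client, otherwise the thread would keep the Client (and its queue) alive forever. A minimal sketch of that pattern, with illustrative names rather than the SDK's actual ones:

```python
import threading
import time
import weakref


class Owner:
    def __init__(self) -> None:
        # Hand the thread a weakref, not self, so the thread does not keep
        # this object alive after the caller drops its last reference.
        threading.Thread(
            target=_background_loop,
            args=(weakref.ref(self),),
            daemon=True,
        ).start()


def _background_loop(owner_ref) -> None:
    while True:
        owner = owner_ref()  # None once the Owner has been garbage collected
        if owner is None:
            return  # owner is gone; let the thread exit
        # ... do one unit of work with `owner` here ...
        del owner  # drop the strong reference before sleeping
        time.sleep(0.1)
```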
@@ -3147,11 +3147,18 @@ def _evaluate_strings(self, prediction, reference=None, input=None, **kwargs) ->


def _tracing_thread_drain_queue(
tracing_queue: Queue, limit: Optional[int] = None
tracing_queue: Queue, limit: int = 100, block: bool = True
) -> List[TracingQueueItem]:
next_batch: List[TracingQueueItem] = []
try:
while item := tracing_queue.get(block=True, timeout=0.25):
# wait 250ms for the first item, then
# - drain the queue with a 50ms block timeout
# - stop draining if we hit the limit
# shorter drain timeout is used instead of non-blocking calls to
# avoid creating too many small batches
if item := tracing_queue.get(block=block, timeout=0.25):
next_batch.append(item)
while item := tracing_queue.get(block=block, timeout=0.05):
next_batch.append(item)
if limit and len(next_batch) >= limit:
break
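In plain terms, the new drain logic waits up to 250 ms for a first item, then keeps pulling with a short 50 ms timeout until the queue stalls or the batch reaches `limit`, so the worker ships reasonably full batches instead of many tiny ones. A condensed sketch of that pattern against a plain `queue.Queue` (illustrative, not the SDK's exact code):

```python
import queue
from typing import Any, List


def drain(q: "queue.Queue[Any]", limit: int = 100, block: bool = True) -> List[Any]:
    batch: List[Any] = []
    try:
        # Wait up to 250 ms for the first item, then use a 50 ms timeout for
        # the rest so we neither spin nor emit lots of single-item batches.
        batch.append(q.get(block=block, timeout=0.25))
        while len(batch) < limit:
            batch.append(q.get(block=block, timeout=0.05))
    except queue.Empty:
        pass  # queue went quiet; return whatever we collected
    return batch
```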
@@ -3172,24 +3179,70 @@ def _tracing_thread_handle_batch(
tracing_queue.task_done()


def _tracing_thread_func(client_ref: weakref.ref[Client]) -> None:
_AUTO_SCALE_UP_QSIZE_TRIGGER = 1000
_AUTO_SCALE_UP_NTHREADS_LIMIT = 16
_AUTO_SCALE_DOWN_NEMPTY_TRIGGER = 4


def _tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
client = client_ref()
if client is None:
return
tracing_queue = client.tracing_queue
assert tracing_queue is not None

sub_threads: List[threading.Thread] = []
Contributor: Can we have a max size here?

Contributor (Author): done

# loop until
while (
# the main thread dies
threading.main_thread().is_alive()
# or we're the only remaining reference to the client
and sys.getrefcount(client) > 3
and sys.getrefcount(client) > 3 + len(sub_threads)
# 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
):
if next_batch := _tracing_thread_drain_queue(tracing_queue, 100):
for thread in sub_threads:
if not thread.is_alive():
sub_threads.remove(thread)
if (
len(sub_threads) < _AUTO_SCALE_UP_NTHREADS_LIMIT
and tracing_queue.qsize() > _AUTO_SCALE_UP_QSIZE_TRIGGER
):
new_thread = threading.Thread(
target=_tracing_sub_thread_func, args=(weakref.ref(client),)
)
sub_threads.append(new_thread)
new_thread.start()
if next_batch := _tracing_thread_drain_queue(tracing_queue):
Contributor: How does this get assigned to the right thread?

Contributor (Author): Not sure I understand the comment; this function runs entirely in one thread, so everything in it happens in that single thread.

Contributor: Ah! So the main thread also drains the queue and handles requests alongside the subthreads, which it spawns to do the same thing? It's a little odd to have the main thread and the subthreads doing the same work in different places, but I guess that works.

Contributor (Author): Yeah, because I think most uses of this don't need the other threads. We can change it if we think that's better.

Contributor: So theoretically you could round-robin dequeue records from all the different threads as they're accumulated? Nothing wrong with that, just thinking through how this works.

Contributor (Author): There is a single queue shared by all consumer threads; increasing the number of threads just increases the rate at which we draw down that one queue.

_tracing_thread_handle_batch(client, tracing_queue, next_batch)

# drain the queue on exit
while next_batch := _tracing_thread_drain_queue(tracing_queue, block=False):
_tracing_thread_handle_batch(client, tracing_queue, next_batch)
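Taken together, the control thread above does two jobs: it drains and sends batches itself, and whenever the shared queue backs up past _AUTO_SCALE_UP_QSIZE_TRIGGER it spawns another worker (up to _AUTO_SCALE_UP_NTHREADS_LIMIT) that consumes the same queue. A condensed sketch of that scale-up loop, reusing the drain() helper sketched earlier (names are illustrative):

```python
import queue
import threading
from typing import Any, Callable, List

QSIZE_TRIGGER = 1_000   # spawn a helper once the backlog exceeds this
NTHREADS_LIMIT = 16     # hard cap on helper threads


def control_loop(
    q: "queue.Queue[Any]",
    handle_batch: Callable[[List[Any]], None],
    worker_target: Callable[..., None],
) -> None:
    helpers: List[threading.Thread] = []
    while threading.main_thread().is_alive():
        # Forget helpers that have already scaled themselves down and exited.
        helpers = [t for t in helpers if t.is_alive()]
        # Scale up while the backlog is large and we are under the thread cap.
        if len(helpers) < NTHREADS_LIMIT and q.qsize() > QSIZE_TRIGGER:
            t = threading.Thread(target=worker_target, args=(q,), daemon=True)
            helpers.append(t)
            t.start()
        # The control thread keeps consuming batches itself as well.
        if batch := drain(q):
            handle_batch(batch)
    # Final sweep so nothing left in the queue at shutdown is dropped.
    while batch := drain(q, block=False):
        handle_batch(batch)
```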


def _tracing_sub_thread_func(client_ref: weakref.ref[Client]) -> None:
client = client_ref()
if client is None:
return
tracing_queue = client.tracing_queue
assert tracing_queue is not None

seen_successive_empty_queues = 0

# loop until
while (
# the main thread dies
threading.main_thread().is_alive()
# or we've seen the queue empty 4 times in a row
and seen_successive_empty_queues <= _AUTO_SCALE_DOWN_NEMPTY_TRIGGER
):
if next_batch := _tracing_thread_drain_queue(tracing_queue):
seen_successive_empty_queues = 0
_tracing_thread_handle_batch(client, tracing_queue, next_batch)
else:
seen_successive_empty_queues += 1

# drain the queue on exit
while next_batch := _tracing_thread_drain_queue(tracing_queue, 100):
while next_batch := _tracing_thread_drain_queue(tracing_queue, block=False):
_tracing_thread_handle_batch(client, tracing_queue, next_batch)
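Each helper thread runs the same drain-and-send loop, but it also counts consecutive empty polls and exits once it has seen the queue empty more than _AUTO_SCALE_DOWN_NEMPTY_TRIGGER times in a row, so the pool shrinks back on its own when the backlog clears. A matching sketch of that scale-down side, continuing the helpers above:

```python
NEMPTY_TRIGGER = 4  # exit after this many consecutive empty polls


def worker_loop(
    q: "queue.Queue[Any]",
    handle_batch: Callable[[List[Any]], None],
) -> None:
    empty_polls = 0
    while threading.main_thread().is_alive() and empty_polls <= NEMPTY_TRIGGER:
        if batch := drain(q):
            empty_polls = 0          # saw work; reset the idle counter
            handle_batch(batch)
        else:
            empty_polls += 1         # a full poll interval passed with no work
    # Final sweep before this helper exits, mirroring the control thread.
    while batch := drain(q, block=False):
        handle_batch(batch)
```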
117 changes: 117 additions & 0 deletions python/tests/unit_tests/test_client.py
@@ -208,23 +208,140 @@ def __call__(self, *args: object, **kwargs: object) -> None:
self.counter += 1


@pytest.mark.parametrize("auto_batch_tracing", [True, False])
def test_client_gc_empty(auto_batch_tracing: bool) -> None:
client = Client(
api_url="http://localhost:1984",
api_key="123",
auto_batch_tracing=auto_batch_tracing,
)
tracker = CallTracker()
weakref.finalize(client, tracker)
assert tracker.counter == 0

del client
time.sleep(1) # Give the background thread time to stop
gc.collect() # Force garbage collection
assert tracker.counter == 1, "Client was not garbage collected"


@pytest.mark.parametrize("auto_batch_tracing", [True, False])
def test_client_gc(auto_batch_tracing: bool) -> None:
session = mock.MagicMock(spec=requests.Session)
client = Client(
api_url="http://localhost:1984",
api_key="123",
auto_batch_tracing=auto_batch_tracing,
session=session,
)
tracker = CallTracker()
weakref.finalize(client, tracker)
assert tracker.counter == 0

for _ in range(10):
id = uuid.uuid4()
client.create_run(
"my_run",
inputs={},
run_type="llm",
execution_order=1,
id=id,
trace_id=id,
dotted_order=id,
)

if auto_batch_tracing:
assert client.tracing_queue
client.tracing_queue.join()

request_calls = [call for call in session.request.mock_calls if call.args]
assert len(request_calls) == 1
for call in request_calls:
assert call.args[0] == "post"
assert call.args[1] == "http://localhost:1984/runs/batch"
else:
request_calls = [call for call in session.request.mock_calls if call.args]
assert len(request_calls) == 10
for call in request_calls:
assert call.args[0] == "post"
assert call.args[1] == "http://localhost:1984/runs"

del client
time.sleep(1) # Give the background thread time to stop
gc.collect() # Force garbage collection
assert tracker.counter == 1, "Client was not garbage collected"


@pytest.mark.parametrize("auto_batch_tracing", [True, False])
def test_client_gc_no_batched_runs(auto_batch_tracing: bool) -> None:
session = mock.MagicMock(spec=requests.Session)
client = Client(
api_url="http://localhost:1984",
api_key="123",
auto_batch_tracing=auto_batch_tracing,
session=session,
)
tracker = CallTracker()
weakref.finalize(client, tracker)
assert tracker.counter == 0

# because no trace_id/dotted_order provided, auto batch is disabled
for _ in range(10):
client.create_run(
"my_run", inputs={}, run_type="llm", execution_order=1, id=uuid.uuid4()
)
request_calls = [call for call in session.request.mock_calls if call.args]
assert len(request_calls) == 10
for call in request_calls:
assert call.args[0] == "post"
assert call.args[1] == "http://localhost:1984/runs"

del client
time.sleep(1) # Give the background thread time to stop
gc.collect() # Force garbage collection
assert tracker.counter == 1, "Client was not garbage collected"


def test_client_gc_after_autoscale() -> None:
session = mock.MagicMock(spec=requests.Session)
client = Client(
api_url="http://localhost:1984",
api_key="123",
session=session,
auto_batch_tracing=True,
)
tracker = CallTracker()
weakref.finalize(client, tracker)
assert tracker.counter == 0

tracing_queue = client.tracing_queue
assert tracing_queue is not None

for _ in range(50_000):
id = uuid.uuid4()
client.create_run(
"my_run",
inputs={},
run_type="llm",
execution_order=1,
id=id,
trace_id=id,
dotted_order=id,
)

del client
tracing_queue.join()
time.sleep(2) # Give the background threads time to stop
gc.collect() # Force garbage collection
assert tracker.counter == 1, "Client was not garbage collected"

request_calls = [call for call in session.request.mock_calls if call.args]
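# 50,000 runs drained in batches of up to 100 (the default drain limit) comes
# out to roughly 500 batch POSTs, hence the 500-550 window below.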
assert len(request_calls) >= 500 and len(request_calls) <= 550
for call in request_calls:
assert call.args[0] == "post"
assert call.args[1] == "http://localhost:1984/runs/batch"


@pytest.mark.parametrize("auto_batch_tracing", [True, False])
def test_create_run_includes_langchain_env_var_metadata(
auto_batch_tracing: bool,