From 1f83125197940fbe07c320ef443bb466e857a202 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 19 Jan 2024 17:29:44 -0800 Subject: [PATCH] Add auto_batch_tracing modality for Client - if on, starts a background thread that batches inserts/updates - only applies to insert/updates w/ trace_id and dotted_order --- python/langsmith/client.py | 128 ++++++++++++++++++++----- python/tests/unit_tests/test_client.py | 36 ++++++- 2 files changed, 134 insertions(+), 30 deletions(-) diff --git a/python/langsmith/client.py b/python/langsmith/client.py index 3fb11fd4c..84ca3b63e 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -11,10 +11,13 @@ import logging import os import random +import signal import socket +import threading import time import uuid import weakref +from queue import Empty, Queue from typing import ( TYPE_CHECKING, Any, @@ -281,6 +284,8 @@ class Client: "_tenant_id", "tracing_sample_rate", "_sampled_post_uuids", + "tracing_queue", + "tracing_thread", ] def __init__( @@ -292,6 +297,7 @@ def __init__( timeout_ms: Optional[int] = None, web_url: Optional[str] = None, session: Optional[requests.Session] = None, + auto_batch_tracing: bool = True, ) -> None: """Initialize a Client instance. @@ -331,6 +337,22 @@ def __init__( # Create a session and register a finalizer to close it self.session = session if session else requests.Session() weakref.finalize(self, close_session, self.session) + # Initialize auto batching + if auto_batch_tracing: + self.tracing_queue: Queue = Queue() + exit_event = threading.Event() + + def signal_exit(signum: int, frame: Any) -> None: + exit_event.set() + + signal.signal(signal.SIGINT, signal_exit) + self.tracing_thread = threading.Thread( + target=self._tracing_thread_func, args=(exit_event,) + ) + self.tracing_thread.start() + else: + self.tracing_queue = None + self.tracing_thread = None # Mount the HTTPAdapter with the retry configuration adapter = requests_adapters.HTTPAdapter(max_retries=self.retry_config) @@ -340,6 +362,35 @@ def __init__( self._get_data_type ) + def _tracing_thread_func(self, exit_event: threading.Event) -> None: + def drain_queue(limit: Optional[int] = None) -> List[Tuple[str, dict]]: + next_batch: List[Tuple[str, dict]] = [] + try: + while item := self.tracing_queue.get(block=True, timeout=0.25): + next_batch.append(item) + if limit and len(next_batch) >= limit: + break + except Empty: + pass + return next_batch + + def handle_batch(batch: List[Tuple[str, dict]]) -> None: + create = [item[1] for item in batch if item[0] == "create"] + update = [item[1] for item in batch if item[0] == "update"] + try: + self.batch_ingest_runs(create=create, update=update, pre_sampled=True) + finally: + for _ in batch: + self.tracing_queue.task_done() + + # loop until we receive a signal to exit or the main thread dies + while not exit_event.is_set() and threading.main_thread().is_alive(): + if next_batch := drain_queue(100): + handle_batch(next_batch) + # drain the queue on exit + if next_batch := drain_queue(): + handle_batch(next_batch) + def _repr_html_(self) -> str: """Return an HTML representation of the instance with a link to the URL. @@ -810,6 +861,14 @@ def create_run( "run_type": run_type, "execution_order": execution_order if execution_order is not None else 1, } + if not self._filter_for_sampling([run_create]): + return + if ( + self.tracing_queue is not None + and run_create.get("trace_id") is not None + and run_create.get("dotted_order") is not None + ): + return self.tracing_queue.put(("create", run_create)) run_create = self._run_transform(run_create) self._insert_runtime_env([run_create]) @@ -818,16 +877,16 @@ def create_run( "Accept": "application/json", "Content-Type": "application/json", } - if self._filter_for_sampling([run_create]): - self.request_with_retries( - "post", - f"{self.api_url}/runs", - request_kwargs={ - "data": json.dumps(run_create, default=_serialize_json), - "headers": headers, - "timeout": self.timeout_ms / 1000, - }, - ) + + self.request_with_retries( + "post", + f"{self.api_url}/runs", + request_kwargs={ + "data": json.dumps(run_create, default=_serialize_json), + "headers": headers, + "timeout": self.timeout_ms / 1000, + }, + ) def batch_ingest_runs( self, @@ -837,6 +896,7 @@ def batch_ingest_runs( update: Optional[ Sequence[Union[ls_schemas.Run, ls_schemas.RunLikeDict, Dict]] ] = None, + pre_sampled: bool = False, ): """ Batch ingest/upsert multiple runs in the Langsmith system. @@ -878,10 +938,16 @@ def batch_ingest_runs( standalone_updates.append(run) update_dicts = standalone_updates # filter out runs that are not sampled - body = { - "post": self._filter_for_sampling(create_dicts), - "patch": self._filter_for_sampling(update_dicts, patch=True), - } + if pre_sampled: + body = { + "post": create_dicts, + "patch": update_dicts, + } + else: + body = { + "post": self._filter_for_sampling(create_dicts), + "patch": self._filter_for_sampling(update_dicts, patch=True), + } if not body["post"] and not body["patch"]: return @@ -936,7 +1002,13 @@ def update_run( "Accept": "application/json", "Content-Type": "application/json", } - data: Dict[str, Any] = {} + data: Dict[str, Any] = { + "id": _as_uuid(run_id, "run_id"), + "trace_id": kwargs.pop("trace_id", None), + "dotted_order": kwargs.pop("dotted_order", None), + } + if not self._filter_for_sampling([data], patch=True): + return if end_time is not None: data["end_time"] = end_time.isoformat() if error is not None: @@ -947,16 +1019,22 @@ def update_run( data["outputs"] = _hide_outputs(outputs) if events is not None: data["events"] = events - if self._filter_for_sampling([data]): - self.request_with_retries( - "patch", - f"{self.api_url}/runs/{_as_uuid(run_id, 'run_id')}", - request_kwargs={ - "data": json.dumps(data, default=_serialize_json), - "headers": headers, - "timeout": self.timeout_ms / 1000, - }, - ) + if ( + self.tracing_queue is not None + and data["trace_id"] is not None + and data["dotted_order"] is not None + ): + return self.tracing_queue.put(("update", data)) + + self.request_with_retries( + "patch", + f"{self.api_url}/runs/{data['id']}", + request_kwargs={ + "data": json.dumps(data, default=_serialize_json), + "headers": headers, + "timeout": self.timeout_ms / 1000, + }, + ) def _load_child_runs(self, run: ls_schemas.Run) -> ls_schemas.Run: """Load child runs for a given run. diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py index 930d4211d..8adb5d2a8 100644 --- a/python/tests/unit_tests/test_client.py +++ b/python/tests/unit_tests/test_client.py @@ -195,8 +195,15 @@ def test_create_run_unicode() -> None: client.update_run(id_, status="completed") -def test_create_run_includes_langchain_env_var_metadata() -> None: - client = Client(api_url="http://localhost:1984", api_key="123") +@pytest.mark.parametrize("auto_batch_tracing", [True, False]) +def test_create_run_includes_langchain_env_var_metadata( + auto_batch_tracing: bool +) -> None: + client = Client( + api_url="http://localhost:1984", + api_key="123", + auto_batch_tracing=auto_batch_tracing, + ) inputs = { "foo": "これは私の友達です", "bar": "این یک کتاب است", @@ -212,13 +219,32 @@ def test_create_run_includes_langchain_env_var_metadata() -> None: ls_env.get_langchain_env_var_metadata.cache_clear() with patch.object(client, "session", session): id_ = uuid.uuid4() + start_time = datetime.now() client.create_run( - "my_run", inputs=inputs, run_type="llm", execution_order=1, id=id_ + "my_run", + inputs=inputs, + run_type="llm", + execution_order=1, + id=id_, + trace_id=id_, + dotted_order=f"{start_time.strftime('%Y%m%dT%H%M%S%fZ')}{id_}", + start_time=start_time, ) + if auto_batch_tracing: + client.tracing_queue.join() # Check the posted value in the request posted_value = json.loads(session.request.call_args[1]["data"]) - assert posted_value["extra"]["metadata"]["LANGCHAIN_REVISION"] == "abcd2234" - assert "LANGCHAIN_API_KEY" not in posted_value["extra"]["metadata"] + if not auto_batch_tracing: + assert ( + posted_value["extra"]["metadata"]["LANGCHAIN_REVISION"] + == "abcd2234" + ) + assert "LANGCHAIN_API_KEY" not in posted_value["extra"]["metadata"] + else: + assert ( + posted_value["post"][0]["extra"]["metadata"]["LANGCHAIN_REVISION"] + == "abcd2234" + ) @pytest.mark.parametrize("source_type", ["api", "model"])