Skip to content

Commit

Permalink
Add auto_batch_tracing mode for Client
Browse files Browse the repository at this point in the history
- if enabled, starts a background thread that batches inserts/updates
- only applies to inserts/updates that carry a trace_id and dotted_order
  • Loading branch information
nfcampos committed Jan 20, 2024
1 parent fc34146 commit 1f83125
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 30 deletions.
128 changes: 103 additions & 25 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
import logging
import os
import random
import signal
import socket
import threading
import time
import uuid
import weakref
from queue import Empty, Queue
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -281,6 +284,8 @@ class Client:
"_tenant_id",
"tracing_sample_rate",
"_sampled_post_uuids",
"tracing_queue",
"tracing_thread",
]

def __init__(
Expand All @@ -292,6 +297,7 @@ def __init__(
timeout_ms: Optional[int] = None,
web_url: Optional[str] = None,
session: Optional[requests.Session] = None,
auto_batch_tracing: bool = True,
) -> None:
"""Initialize a Client instance.
Expand Down Expand Up @@ -331,6 +337,22 @@ def __init__(
# Create a session and register a finalizer to close it
self.session = session if session else requests.Session()
weakref.finalize(self, close_session, self.session)
# Initialize auto batching
if auto_batch_tracing:
self.tracing_queue: Queue = Queue()
exit_event = threading.Event()

def signal_exit(signum: int, frame: Any) -> None:
exit_event.set()

signal.signal(signal.SIGINT, signal_exit)
self.tracing_thread = threading.Thread(
target=self._tracing_thread_func, args=(exit_event,)
)
self.tracing_thread.start()
else:
self.tracing_queue = None
self.tracing_thread = None

# Mount the HTTPAdapter with the retry configuration
adapter = requests_adapters.HTTPAdapter(max_retries=self.retry_config)
Expand All @@ -340,6 +362,35 @@ def __init__(
self._get_data_type
)

def _tracing_thread_func(self, exit_event: threading.Event) -> None:
def drain_queue(limit: Optional[int] = None) -> List[Tuple[str, dict]]:
next_batch: List[Tuple[str, dict]] = []
try:
while item := self.tracing_queue.get(block=True, timeout=0.25):
next_batch.append(item)
if limit and len(next_batch) >= limit:
break
except Empty:
pass
return next_batch

def handle_batch(batch: List[Tuple[str, dict]]) -> None:
create = [item[1] for item in batch if item[0] == "create"]
update = [item[1] for item in batch if item[0] == "update"]
try:
self.batch_ingest_runs(create=create, update=update, pre_sampled=True)
finally:
for _ in batch:
self.tracing_queue.task_done()

# loop until we receive a signal to exit or the main thread dies
while not exit_event.is_set() and threading.main_thread().is_alive():
if next_batch := drain_queue(100):
handle_batch(next_batch)
# drain the queue on exit
if next_batch := drain_queue():
handle_batch(next_batch)

def _repr_html_(self) -> str:
"""Return an HTML representation of the instance with a link to the URL.
Expand Down Expand Up @@ -810,6 +861,14 @@ def create_run(
"run_type": run_type,
"execution_order": execution_order if execution_order is not None else 1,
}
if not self._filter_for_sampling([run_create]):
return
if (
self.tracing_queue is not None
and run_create.get("trace_id") is not None
and run_create.get("dotted_order") is not None
):
return self.tracing_queue.put(("create", run_create))

run_create = self._run_transform(run_create)
self._insert_runtime_env([run_create])
Expand All @@ -818,16 +877,16 @@ def create_run(
"Accept": "application/json",
"Content-Type": "application/json",
}
if self._filter_for_sampling([run_create]):
self.request_with_retries(
"post",
f"{self.api_url}/runs",
request_kwargs={
"data": json.dumps(run_create, default=_serialize_json),
"headers": headers,
"timeout": self.timeout_ms / 1000,
},
)

self.request_with_retries(
"post",
f"{self.api_url}/runs",
request_kwargs={
"data": json.dumps(run_create, default=_serialize_json),
"headers": headers,
"timeout": self.timeout_ms / 1000,
},
)

def batch_ingest_runs(
self,
Expand All @@ -837,6 +896,7 @@ def batch_ingest_runs(
update: Optional[
Sequence[Union[ls_schemas.Run, ls_schemas.RunLikeDict, Dict]]
] = None,
pre_sampled: bool = False,
):
"""
Batch ingest/upsert multiple runs in the Langsmith system.
Expand Down Expand Up @@ -878,10 +938,16 @@ def batch_ingest_runs(
standalone_updates.append(run)
update_dicts = standalone_updates
# filter out runs that are not sampled
body = {
"post": self._filter_for_sampling(create_dicts),
"patch": self._filter_for_sampling(update_dicts, patch=True),
}
if pre_sampled:
body = {
"post": create_dicts,
"patch": update_dicts,
}
else:
body = {
"post": self._filter_for_sampling(create_dicts),
"patch": self._filter_for_sampling(update_dicts, patch=True),
}
if not body["post"] and not body["patch"]:
return

Expand Down Expand Up @@ -936,7 +1002,13 @@ def update_run(
"Accept": "application/json",
"Content-Type": "application/json",
}
data: Dict[str, Any] = {}
data: Dict[str, Any] = {
"id": _as_uuid(run_id, "run_id"),
"trace_id": kwargs.pop("trace_id", None),
"dotted_order": kwargs.pop("dotted_order", None),
}
if not self._filter_for_sampling([data], patch=True):
return
if end_time is not None:
data["end_time"] = end_time.isoformat()
if error is not None:
Expand All @@ -947,16 +1019,22 @@ def update_run(
data["outputs"] = _hide_outputs(outputs)
if events is not None:
data["events"] = events
if self._filter_for_sampling([data]):
self.request_with_retries(
"patch",
f"{self.api_url}/runs/{_as_uuid(run_id, 'run_id')}",
request_kwargs={
"data": json.dumps(data, default=_serialize_json),
"headers": headers,
"timeout": self.timeout_ms / 1000,
},
)
if (
self.tracing_queue is not None
and data["trace_id"] is not None
and data["dotted_order"] is not None
):
return self.tracing_queue.put(("update", data))

self.request_with_retries(
"patch",
f"{self.api_url}/runs/{data['id']}",
request_kwargs={
"data": json.dumps(data, default=_serialize_json),
"headers": headers,
"timeout": self.timeout_ms / 1000,
},
)

def _load_child_runs(self, run: ls_schemas.Run) -> ls_schemas.Run:
"""Load child runs for a given run.
Expand Down
36 changes: 31 additions & 5 deletions python/tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,15 @@ def test_create_run_unicode() -> None:
client.update_run(id_, status="completed")


def test_create_run_includes_langchain_env_var_metadata() -> None:
client = Client(api_url="http://localhost:1984", api_key="123")
@pytest.mark.parametrize("auto_batch_tracing", [True, False])
def test_create_run_includes_langchain_env_var_metadata(
auto_batch_tracing: bool
) -> None:
client = Client(
api_url="http://localhost:1984",
api_key="123",
auto_batch_tracing=auto_batch_tracing,
)
inputs = {
"foo": "これは私の友達です",
"bar": "این یک کتاب است",
Expand All @@ -212,13 +219,32 @@ def test_create_run_includes_langchain_env_var_metadata() -> None:
ls_env.get_langchain_env_var_metadata.cache_clear()
with patch.object(client, "session", session):
id_ = uuid.uuid4()
start_time = datetime.now()
client.create_run(
"my_run", inputs=inputs, run_type="llm", execution_order=1, id=id_
"my_run",
inputs=inputs,
run_type="llm",
execution_order=1,
id=id_,
trace_id=id_,
dotted_order=f"{start_time.strftime('%Y%m%dT%H%M%S%fZ')}{id_}",
start_time=start_time,
)
if auto_batch_tracing:
client.tracing_queue.join()
# Check the posted value in the request
posted_value = json.loads(session.request.call_args[1]["data"])
assert posted_value["extra"]["metadata"]["LANGCHAIN_REVISION"] == "abcd2234"
assert "LANGCHAIN_API_KEY" not in posted_value["extra"]["metadata"]
if not auto_batch_tracing:
assert (
posted_value["extra"]["metadata"]["LANGCHAIN_REVISION"]
== "abcd2234"
)
assert "LANGCHAIN_API_KEY" not in posted_value["extra"]["metadata"]
else:
assert (
posted_value["post"][0]["extra"]["metadata"]["LANGCHAIN_REVISION"]
== "abcd2234"
)


@pytest.mark.parametrize("source_type", ["api", "model"])
Expand Down

0 comments on commit 1f83125

Please sign in to comment.