Wfh/use system batch config (#404)
hinthornw authored Feb 6, 2024
1 parent 50ac7dc commit 8ed33dd
Showing 6 changed files with 167 additions and 134 deletions.
2 changes: 1 addition & 1 deletion js/package.json
@@ -121,4 +121,4 @@
},
"./package.json": "./package.json"
}
}
}
107 changes: 87 additions & 20 deletions python/langsmith/client.py
@@ -458,25 +458,37 @@ def _headers(self) -> Dict[str, str]:
return headers

@property
@ls_utils.ttl_cache(maxsize=1)
def info(self) -> Optional[ls_schemas.LangSmithInfo]:
"""Get the information about the LangSmith API.
Returns
-------
dict
The information about the LangSmith API.
Optional[ls_schemas.LangSmithInfo]
The information about the LangSmith API, or None if the API is
not available.
"""
info = Client._get_info(self.session, self.api_url, self.timeout_ms)
if info is None and self.tracing_queue is not None:
self.tracing_queue = None
return cast(Optional[ls_schemas.LangSmithInfo], info)

@staticmethod
@functools.lru_cache(maxsize=1)
def _get_info(
session: requests.Session, api_url: str, timeout_ms: int
) -> Optional[ls_schemas.LangSmithInfo]:
try:
response = self.session.get(
self.api_url + "/info",
headers=self._headers,
timeout=self.timeout_ms / 1000,
response = session.get(
api_url + "/info",
headers={"Accept": "application/json"},
timeout=timeout_ms / 1000,
)
ls_utils.raise_for_status_with_text(response)
return ls_schemas.LangSmithInfo(**response.json())
except ls_utils.LangSmithAPIError as e:
logger.debug("Failed to get info: %s", e)
except requests.HTTPError:
return None
except BaseException as e:
logger.warning(f"Failed to get info from {api_url}: {repr(e)}")
return None

def request_with_retries(
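
The move from the removed ls_utils.ttl_cache to functools.lru_cache works because the cache now hangs off a static method keyed on hashable arguments (the session, URL, and timeout) rather than on the property itself; a requests.Session hashes by identity, so it can ride along as a cache key. A minimal sketch of the same pattern, with illustrative names that are not part of LangSmith's API:

import functools
from typing import Optional

import requests

@functools.lru_cache(maxsize=1)
def fetch_server_info(api_url: str, timeout_ms: int) -> Optional[dict]:
    # First call hits the network; later calls with the same arguments
    # are served from the cache without another request.
    try:
        response = requests.get(f"{api_url}/info", timeout=timeout_ms / 1000)
        response.raise_for_status()
        return response.json()
    except requests.RequestException:
        return None
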
@@ -829,6 +841,15 @@ def upload_csv(
def _run_transform(
run: Union[ls_schemas.Run, dict, ls_schemas.RunLikeDict],
) -> dict:
"""
Transforms the given run object into a dictionary representation.
Args:
run (Union[ls_schemas.Run, dict]): The run object to transform.
Returns:
dict: The transformed run object as a dictionary.
"""
if hasattr(run, "dict") and callable(getattr(run, "dict")):
run_create = run.dict() # type: ignore
else:
@@ -935,6 +956,8 @@ def create_run(
# batch ingest requires trace_id and dotted_order to be set
and run_create.get("trace_id") is not None
and run_create.get("dotted_order") is not None
# Checked last since it makes a (cached) API call
and self.info is not None # Older versions don't support batch ingest
):
return self.tracing_queue.put(
TracingQueueItem(run_create["dotted_order"], "create", run_create)
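
The new comment calls out guard ordering, and the same ordering appears in update_run below: Python's `and` short-circuits left to right, so the cheap dictionary checks veto batching before self.info can trigger its one-time (cached) API call. A self-contained illustration of the idea, with invented names:

from typing import Callable, Optional

def should_batch(run: dict, server_info: Callable[[], Optional[dict]]) -> bool:
    return (
        # Cheap dictionary lookups run first...
        run.get("trace_id") is not None
        and run.get("dotted_order") is not None
        # ...and the potentially expensive probe runs last; `and`
        # short-circuits, so it is skipped whenever a cheap check fails.
        and server_info() is not None
    )
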
@@ -1109,6 +1132,8 @@ def update_run(
# batch ingest requires trace_id and dotted_order to be set
and data["trace_id"] is not None
and data["dotted_order"] is not None
# Checked last since it makes an API call
and self.info is not None # Older versions don't support batch ingest
):
return self.tracing_queue.put(
TracingQueueItem(data["dotted_order"], "update", data)
@@ -3250,65 +3275,107 @@ def _tracing_thread_handle_batch(
_AUTO_SCALE_DOWN_NEMPTY_TRIGGER = 4


def _ensure_ingest_config(
info: Optional[ls_schemas.LangSmithInfo],
) -> ls_schemas.BatchIngestConfig:
default_config = ls_schemas.BatchIngestConfig(
size_limit=100,
scale_up_nthreads_limit=_AUTO_SCALE_UP_NTHREADS_LIMIT,
scale_up_qsize_trigger=_AUTO_SCALE_UP_QSIZE_TRIGGER,
scale_down_nempty_trigger=_AUTO_SCALE_DOWN_NEMPTY_TRIGGER,
)
if not info:
return default_config
try:
if not info.batch_ingest_config:
return default_config
return info.batch_ingest_config
except BaseException:
return default_config
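
_ensure_ingest_config is deliberately defensive: a missing /info response, an older server without a batch_ingest_config, or a malformed payload all degrade to the client-side defaults. Roughly the following sketch, where every value except size_limit (100) and the scale-down trigger (4) is illustrative rather than taken from the diff:

from typing import Optional

DEFAULT_CONFIG = {
    "size_limit": 100,               # default batch size, as in the diff
    "scale_up_nthreads_limit": 8,    # illustrative value
    "scale_up_qsize_trigger": 100,   # illustrative value
    "scale_down_nempty_trigger": 4,  # matches _AUTO_SCALE_DOWN_NEMPTY_TRIGGER
}

def ensure_config(info: Optional[dict]) -> dict:
    # Mirrors the belt-and-braces BaseException guard above: any failure
    # to read a server-supplied config degrades to the defaults.
    try:
        return (info or {}).get("batch_ingest_config") or DEFAULT_CONFIG
    except Exception:
        return DEFAULT_CONFIG
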


def _tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
client = client_ref()
if client is None:
return
try:
if not client.info:
print(f"no info: {client.info}", file=sys.stderr, flush=True)
return
except BaseException as e:
logger.debug("Error in tracing control thread: %s", e)
return
tracing_queue = client.tracing_queue
assert tracing_queue is not None
batch_ingest_config = _ensure_ingest_config(client.info)
size_limit: int = batch_ingest_config["size_limit"]
scale_up_nthreads_limit: int = batch_ingest_config["scale_up_nthreads_limit"]
scale_up_qsize_trigger: int = batch_ingest_config["scale_up_qsize_trigger"]

sub_threads: List[threading.Thread] = []
# 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
num_known_refs = 3

# loop until
while (
# the main thread dies
threading.main_thread().is_alive()
# or we're the only remaining reference to the client
and sys.getrefcount(client) > 3 + len(sub_threads)
# 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
and sys.getrefcount(client) > num_known_refs + len(sub_threads)
):
for thread in sub_threads:
if not thread.is_alive():
sub_threads.remove(thread)
if (
len(sub_threads) < _AUTO_SCALE_UP_NTHREADS_LIMIT
and tracing_queue.qsize() > _AUTO_SCALE_UP_QSIZE_TRIGGER
len(sub_threads) < scale_up_nthreads_limit
and tracing_queue.qsize() > scale_up_qsize_trigger
):
new_thread = threading.Thread(
target=_tracing_sub_thread_func, args=(weakref.ref(client),)
)
sub_threads.append(new_thread)
new_thread.start()
if next_batch := _tracing_thread_drain_queue(tracing_queue):
if next_batch := _tracing_thread_drain_queue(tracing_queue, limit=size_limit):
_tracing_thread_handle_batch(client, tracing_queue, next_batch)

# drain the queue on exit
while next_batch := _tracing_thread_drain_queue(tracing_queue, block=False):
while next_batch := _tracing_thread_drain_queue(
tracing_queue, limit=size_limit, block=False
):
_tracing_thread_handle_batch(client, tracing_queue, next_batch)
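
The control thread's exit condition combines two signals: the main thread dying, and sys.getrefcount showing that only the bookkeeping references remain (per the comment: this function's frame, getrefcount's own argument, and _get_data_type_cached, plus one per live sub-thread, each of which resolves its own strong reference). A stripped-down sketch of the same lifecycle, with illustrative names and constants:

import sys
import threading
import time
import weakref

def control_loop(client_ref: weakref.ref) -> None:
    client = client_ref()  # resolve once; this frame now holds a strong ref
    if client is None:
        return
    num_known_refs = 2  # `client` in this frame + getrefcount's own argument
    while (
        threading.main_thread().is_alive()
        and sys.getrefcount(client) > num_known_refs
    ):
        time.sleep(0.1)  # stand-in for draining and dispatching batches

Started as a daemon thread holding only the weakref, such a loop can never keep the client alive on its own.
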


def _tracing_sub_thread_func(client_ref: weakref.ref[Client]) -> None:
client = client_ref()
if client is None:
return
try:
if not client.info:
return
except BaseException as e:
logger.debug("Error in tracing control thread: %s", e)
return
tracing_queue = client.tracing_queue
assert tracing_queue is not None

batch_ingest_config = _ensure_ingest_config(client.info)
size_limit = batch_ingest_config.get("size_limit", 100)
seen_successive_empty_queues = 0

# loop until
while (
# the main thread dies
threading.main_thread().is_alive()
# or we've seen the queue empty 4 times in a row
and seen_successive_empty_queues <= _AUTO_SCALE_DOWN_NEMPTY_TRIGGER
and seen_successive_empty_queues
<= batch_ingest_config["scale_down_nempty_trigger"]
):
if next_batch := _tracing_thread_drain_queue(tracing_queue):
if next_batch := _tracing_thread_drain_queue(tracing_queue, limit=size_limit):
seen_successive_empty_queues = 0
_tracing_thread_handle_batch(client, tracing_queue, next_batch)
else:
seen_successive_empty_queues += 1

# drain the queue on exit
while next_batch := _tracing_thread_drain_queue(tracing_queue, block=False):
while next_batch := _tracing_thread_drain_queue(
tracing_queue, limit=size_limit, block=False
):
_tracing_thread_handle_batch(client, tracing_queue, next_batch)
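
The sub-threads retire themselves with a hysteresis counter: every empty drain increments it, any work resets it, and once it exceeds scale_down_nempty_trigger the thread returns and the pool shrinks. A self-contained sketch of that loop (the queue draining is simplified, and the names are illustrative):

import queue

def drain_until_idle(q: queue.Queue, size_limit: int = 100, trigger: int = 4) -> None:
    seen_empty = 0
    while seen_empty <= trigger:
        batch = []
        while len(batch) < size_limit:
            try:
                batch.append(q.get_nowait())
            except queue.Empty:
                break
        if batch:
            seen_empty = 0  # work arrived: reset the idle counter
            print(f"handled a batch of {len(batch)} items")
        else:
            seen_empty += 1  # idle poll: one step closer to retiring
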
37 changes: 0 additions & 37 deletions python/langsmith/utils.py
@@ -5,8 +5,6 @@
import logging
import os
import subprocess
import threading
import time
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import requests
@@ -287,38 +285,3 @@ def filter(self, record) -> bool:
return (
"Connection pool is full, discarding connection" not in record.getMessage()
)


def ttl_cache(
ttl_seconds: Optional[int] = None, maxsize: Optional[int] = None
) -> Callable:
"""LRU cache with an optional TTL."""

def decorator(func: Callable) -> Callable:
cache: Dict[Tuple, Tuple] = {}
cache_lock = threading.RLock()

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
key = (args, frozenset(kwargs.items()))
with cache_lock:
if key in cache:
result, timestamp = cache[key]
if ttl_seconds is None or time.time() - timestamp < ttl_seconds:
# Refresh the timestamp
cache[key] = (result, time.time())
return result
result = func(*args, **kwargs)
with cache_lock:
cache[key] = (result, time.time())

if maxsize is not None:
if len(cache) > maxsize:
oldest_key = min(cache, key=lambda k: cache[k][1])
del cache[oldest_key]

return result

return wrapper

return decorator
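
The handwritten ttl_cache above is deleted because the client now relies on functools.lru_cache for its one remaining call site. If TTL semantics were ever needed again, a common lightweight substitute (a sketch, not part of this commit) keys an lru_cache on a coarse time bucket:

import functools
import time

def ttl_lru_cache(ttl_seconds: float, maxsize: int = 1):
    def decorator(func):
        @functools.lru_cache(maxsize=maxsize)
        def cached(time_bucket, *args, **kwargs):
            return func(*args, **kwargs)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Entries expire when time.time() crosses into the next bucket,
            # so the effective TTL is anywhere up to ttl_seconds.
            return cached(int(time.time() / ttl_seconds), *args, **kwargs)

        return wrapper

    return decorator
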
7 changes: 4 additions & 3 deletions python/tests/integration_tests/test_client.py
@@ -1,4 +1,5 @@
"""LangSmith langchain_client Integration Tests."""

import io
import os
import random
@@ -501,6 +502,6 @@ def test_get_info() -> None:
langchain_client = Client(api_key="not-a-real-key")
info = langchain_client.info
assert info
assert info.version is not None
assert info.batch_ingest_config is not None
assert info.batch_ingest_config["size_limit"] > 0
assert info.version is not None # type: ignore
assert info.batch_ingest_config is not None # type: ignore
assert info.batch_ingest_config["size_limit"] > 0 # type: ignore