fix threading
hinthornw committed Mar 5, 2024
1 parent f10dfa5 commit 18b6ca2
Showing 1 changed file with 63 additions and 54 deletions.
python/langsmith/client.py: 117 changes (63 additions & 54 deletions)
@@ -418,17 +418,14 @@ def __init__(
         self.session = session if session else requests.Session()
         weakref.finalize(self, close_session, self.session)
         # Initialize auto batching
-        if auto_batch_tracing:
-            self.tracing_queue: Optional[PriorityQueue] = PriorityQueue()
-
-            threading.Thread(
-                target=_tracing_control_thread_func,
-                # arg must be a weakref to self to avoid the Thread object
-                # preventing garbage collection of the Client object
-                args=(weakref.ref(self),),
-            ).start()
-        else:
-            self.tracing_queue = None
+        self.tracing_queue: Optional[PriorityQueue] = PriorityQueue()
+
+        threading.Thread(
+            target=_tracing_control_thread_func,
+            # arg must be a weakref to self to avoid the Thread object
+            # preventing garbage collection of the Client object
+            args=(weakref.ref(self), auto_batch_tracing),
+        ).start()
 
         # Mount the HTTPAdapter with the retry configuration
         adapter = requests_adapters.HTTPAdapter(max_retries=self.retry_config)
@@ -506,7 +503,7 @@ def _headers(self) -> Dict[str, str]:
         return headers
 
     @property
-    def info(self) -> Optional[ls_schemas.LangSmithInfo]:
+    def info(self) -> ls_schemas.LangSmithInfo:
         """Get the information about the LangSmith API.
 
         Returns:
@@ -515,16 +512,13 @@ def info(self) -> Optional[ls_schemas.LangSmithInfo]:
             The information about the LangSmith API, or None if the API is
             not available.
         """
-        info = Client._get_info(self.session, self.api_url, self.timeout_ms)
-        if info is None and self.tracing_queue is not None:
-            self.tracing_queue = None
-        return cast(Optional[ls_schemas.LangSmithInfo], info)
+        return Client._get_info(self.session, self.api_url, self.timeout_ms)
 
     @staticmethod
     @functools.lru_cache(maxsize=1)
     def _get_info(
         session: requests.Session, api_url: str, timeout_ms: int
-    ) -> Optional[ls_schemas.LangSmithInfo]:
+    ) -> ls_schemas.LangSmithInfo:
         try:
             response = session.get(
                 api_url + "/info",
@@ -534,10 +528,10 @@ def _get_info(
             ls_utils.raise_for_status_with_text(response)
             return ls_schemas.LangSmithInfo(**response.json())
         except requests.HTTPError:
-            return None
+            return ls_schemas.LangSmithInfo()
         except BaseException as e:
             logger.warning(f"Failed to get info from {api_url}: {repr(e)}")
-            return None
+            return ls_schemas.LangSmithInfo()
 
     def request_with_retries(
         self,
@@ -647,9 +641,14 @@ def request_with_retries(
                     f" {repr(e)}"
                 )
             except requests.ConnectionError as e:
+                recommendation = (
+                    "Please confirm your LANGCHAIN_ENDPOINT"
+                    if self.api_url != "https://api.smith.langchain.com"
+                    else "Please confirm your internet connection"
+                )
                 raise ls_utils.LangSmithConnectionError(
                     f"Connection error caused failure to {request_method} {url}"
-                    " in LangSmith API. Please confirm your LANGCHAIN_ENDPOINT."
+                    f" in LangSmith API. {recommendation}."
                     f" {repr(e)}"
                 ) from e
             except Exception as e:
@@ -1022,13 +1021,13 @@ def create_run(
             # batch ingest requires trace_id and dotted_order to be set
             and run_create.get("trace_id") is not None
             and run_create.get("dotted_order") is not None
-            # Checked last since it makes a (cached) API call
-            and self.info is not None  # Older versions don't support batch ingest
         ):
             return self.tracing_queue.put(
                 TracingQueueItem(run_create["dotted_order"], "create", run_create)
             )
+        self._create_run(run_create)
+
+    def _create_run(self, run_create: dict):
         headers = {
             **self._headers,
             "Accept": "application/json",
@@ -1135,13 +1134,7 @@ def batch_ingest_runs(
             return
 
         self._insert_runtime_env(raw_body["post"])
-
-        if self.info is None:
-            raise ls_utils.LangSmithUserError(
-                "Batch ingest is not supported by your LangSmith server version. "
-                "Please upgrade to a newer version."
-            )
-        info = cast(ls_schemas.LangSmithInfo, self.info)
+        info = self.info
 
         size_limit_bytes = (info.batch_ingest_config or {}).get(
             "size_limit_bytes"
@@ -1240,11 +1233,6 @@ def update_run(
         **kwargs : Any
             Kwargs are ignored.
         """
-        headers = {
-            **self._headers,
-            "Accept": "application/json",
-            "Content-Type": "application/json",
-        }
         data: Dict[str, Any] = {
             "id": _as_uuid(run_id, "run_id"),
             "trace_id": kwargs.pop("trace_id", None),
@@ -1272,18 +1260,23 @@ def update_run(
             # batch ingest requires trace_id and dotted_order to be set
             and data["trace_id"] is not None
             and data["dotted_order"] is not None
-            # Checked last since it makes an API call
-            and self.info is not None  # Older versions don't support batch ingest
         ):
             return self.tracing_queue.put(
                 TracingQueueItem(data["dotted_order"], "update", data)
             )
+        self._update_run(data)
+
+    def _update_run(self, run_update: dict) -> None:
+        headers = {
+            **self._headers,
+            "Accept": "application/json",
+            "Content-Type": "application/json",
+        }
 
         self.request_with_retries(
             "patch",
-            f"{self.api_url}/runs/{data['id']}",
+            f"{self.api_url}/runs/{run_update['id']}",
             request_kwargs={
-                "data": _dumps_json(data),
+                "data": _dumps_json(run_update),
                 "headers": headers,
                 "timeout": self.timeout_ms / 1000,
             },
@@ -3883,12 +3876,21 @@ def _tracing_thread_drain_queue(
 
 
 def _tracing_thread_handle_batch(
-    client: Client, tracing_queue: Queue, batch: List[TracingQueueItem]
+    client: Client,
+    tracing_queue: Queue,
+    batch: List[TracingQueueItem],
+    auto_batch_tracing: bool = True,
 ) -> None:
     create = [it.item for it in batch if it.action == "create"]
     update = [it.item for it in batch if it.action == "update"]
     try:
-        client.batch_ingest_runs(create=create, update=update, pre_sampled=True)
+        if auto_batch_tracing:
+            client.batch_ingest_runs(create=create, update=update, pre_sampled=True)
+        else:
+            for run_create in create:
+                client._create_run(run_create)
+            for run_update in update:
+                client._update_run(run_update)
     except Exception:
         logger.error("Error in tracing queue", exc_info=True)
         # exceptions are logged elsewhere, but we need to make sure the
@@ -3905,7 +3907,7 @@ def _tracing_thread_handle_batch(
 
 
 def _ensure_ingest_config(
-    info: Optional[ls_schemas.LangSmithInfo],
+    info: ls_schemas.LangSmithInfo,
 ) -> ls_schemas.BatchIngestConfig:
     default_config = ls_schemas.BatchIngestConfig(
         size_limit_bytes=None,  # Note this field is not used here
@@ -3924,16 +3926,12 @@
     return default_config
 
 
-def _tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
+def _tracing_control_thread_func(
+    client_ref: weakref.ref[Client], auto_batch_tracing: bool
+) -> None:
     client = client_ref()
     if client is None:
         return
-    try:
-        if not client.info:
-            return
-    except BaseException as e:
-        logger.debug("Error in tracing control thread: %s", e)
-        return
     tracing_queue = client.tracing_queue
     assert tracing_queue is not None
     batch_ingest_config = _ensure_ingest_config(client.info)
@@ -3960,20 +3958,27 @@ def _tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
             and tracing_queue.qsize() > scale_up_qsize_trigger
         ):
             new_thread = threading.Thread(
-                target=_tracing_sub_thread_func, args=(weakref.ref(client),)
+                target=_tracing_sub_thread_func,
+                args=(weakref.ref(client), auto_batch_tracing),
             )
             sub_threads.append(new_thread)
             new_thread.start()
         if next_batch := _tracing_thread_drain_queue(tracing_queue, limit=size_limit):
-            _tracing_thread_handle_batch(client, tracing_queue, next_batch)
+            _tracing_thread_handle_batch(
+                client, tracing_queue, next_batch, auto_batch_tracing=auto_batch_tracing
+            )
     # drain the queue on exit
     while next_batch := _tracing_thread_drain_queue(
         tracing_queue, limit=size_limit, block=False
     ):
-        _tracing_thread_handle_batch(client, tracing_queue, next_batch)
+        _tracing_thread_handle_batch(
+            client, tracing_queue, next_batch, auto_batch_tracing=auto_batch_tracing
+        )
 
 
-def _tracing_sub_thread_func(client_ref: weakref.ref[Client]) -> None:
+def _tracing_sub_thread_func(
+    client_ref: weakref.ref[Client], auto_batch_tracing: bool
+) -> None:
     client = client_ref()
     if client is None:
         return
@@ -3999,12 +4004,16 @@ def _tracing_sub_thread_func(client_ref: weakref.ref[Client]) -> None:
     ):
         if next_batch := _tracing_thread_drain_queue(tracing_queue, limit=size_limit):
             seen_successive_empty_queues = 0
-            _tracing_thread_handle_batch(client, tracing_queue, next_batch)
+            _tracing_thread_handle_batch(
+                client, tracing_queue, next_batch, auto_batch_tracing=auto_batch_tracing
+            )
         else:
            seen_successive_empty_queues += 1
 
     # drain the queue on exit
     while next_batch := _tracing_thread_drain_queue(
         tracing_queue, limit=size_limit, block=False
     ):
-        _tracing_thread_handle_batch(client, tracing_queue, next_batch)
+        _tracing_thread_handle_batch(
+            client, tracing_queue, next_batch, auto_batch_tracing=auto_batch_tracing
+        )
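
Note on the `weakref.ref(self)` argument in the `__init__` hunk: if the tracing thread held a strong reference to the client, the `Client` could never be garbage collected and the thread would never exit. The sketch below shows that idiom in isolation; it is illustrative only, and `Owner`/`worker` are invented names, not langsmith APIs.

```python
import threading
import time
import weakref


class Owner:
    def __init__(self) -> None:
        threading.Thread(
            target=worker,
            # Pass a weakref, not self: a strong reference captured in the
            # thread's args would keep Owner instances alive forever.
            args=(weakref.ref(self),),
            daemon=True,
        ).start()


def worker(owner_ref: "weakref.ref[Owner]") -> None:
    while True:
        owner = owner_ref()  # returns None once the Owner is collected
        if owner is None:
            return  # let the thread exit
        # ... do work with owner ...
        del owner  # drop the strong reference between iterations
        time.sleep(0.1)
```

The commit applies the same idea when scaling up: each sub-thread also receives `weakref.ref(client)` rather than the client itself.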
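The batching helper the control thread relies on, `_tracing_thread_drain_queue`, sits outside this diff. A hedged sketch of the drain pattern it suggests, under assumed semantics (the timeout and limit values here are made up, and this is not the function's actual body):

```python
import queue
from typing import List, Optional


def drain_queue(
    q: queue.Queue, limit: int = 100, block: bool = True
) -> Optional[List[object]]:
    """Take up to `limit` items: wait briefly for the first, poll for the rest."""
    batch: List[object] = []
    try:
        # Block a short while for the first item (or poll when block=False).
        batch.append(q.get(block=block, timeout=0.25 if block else None))
        while len(batch) < limit:
            batch.append(q.get_nowait())  # grab the rest opportunistically
    except queue.Empty:
        pass
    return batch or None
```

With `auto_batch_tracing` now threaded through to `_tracing_thread_handle_batch`, each drained batch is either posted in a single `batch_ingest_runs` call or replayed item by item through `_create_run`/`_update_run`, which is what lets `__init__` start the control thread unconditionally.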