diff --git a/python/langsmith/_internal/_background_thread.py b/python/langsmith/_internal/_background_thread.py
index b6aee1f4e..c0f3d46ab 100644
--- a/python/langsmith/_internal/_background_thread.py
+++ b/python/langsmith/_internal/_background_thread.py
@@ -155,13 +155,27 @@ def tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
     # 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
     num_known_refs = 3
 
+    def keep_thread_active() -> bool:
+        # if `client.cleanup()` was called, stop thread
+        if not client or (
+            hasattr(client, "_manual_cleanup") and client._manual_cleanup
+        ):
+            return False
+        if not threading.main_thread().is_alive():
+            # main thread is dead. should not be active
+            return False
+
+        if hasattr(sys, "getrefcount"):
+            # check if client refs count indicates we're the only remaining
+            # reference to the client
+            return sys.getrefcount(client) > num_known_refs + len(sub_threads)
+        else:
+            # in PyPy, there is no sys.getrefcount attribute
+            # for now, keep thread alive
+            return True
+
     # loop until
-    while (
-        # the main thread dies
-        threading.main_thread().is_alive()
-        # or we're the only remaining reference to the client
-        and sys.getrefcount(client) > num_known_refs + len(sub_threads)
-    ):
+    while keep_thread_active():
         for thread in sub_threads:
             if not thread.is_alive():
                 sub_threads.remove(thread)
diff --git a/python/langsmith/_internal/_operations.py b/python/langsmith/_internal/_operations.py
index e1e99d6e2..66decff0f 100644
--- a/python/langsmith/_internal/_operations.py
+++ b/python/langsmith/_internal/_operations.py
@@ -5,9 +5,8 @@
 import uuid
 from typing import Literal, Optional, Union, cast
 
-import orjson
-
 from langsmith import schemas as ls_schemas
+from langsmith._internal import _orjson
 from langsmith._internal._multipart import MultipartPart, MultipartPartsAndContext
 from langsmith._internal._serde import dumps_json as _dumps_json
 
@@ -169,12 +168,12 @@ def combine_serialized_queue_operations(
             if op._none is not None and op._none != create_op._none:
                 # TODO optimize this more - this would currently be slowest
                 # for large payloads
-                create_op_dict = orjson.loads(create_op._none)
+                create_op_dict = _orjson.loads(create_op._none)
                 op_dict = {
-                    k: v for k, v in orjson.loads(op._none).items() if v is not None
+                    k: v for k, v in _orjson.loads(op._none).items() if v is not None
                 }
                 create_op_dict.update(op_dict)
-                create_op._none = orjson.dumps(create_op_dict)
+                create_op._none = _orjson.dumps(create_op_dict)
 
             if op.inputs is not None:
                 create_op.inputs = op.inputs
diff --git a/python/langsmith/_internal/_orjson.py b/python/langsmith/_internal/_orjson.py
new file mode 100644
index 000000000..ecd9e20bc
--- /dev/null
+++ b/python/langsmith/_internal/_orjson.py
@@ -0,0 +1,84 @@
+"""Stubs for orjson operations, compatible with PyPy via a json fallback."""
+
+try:
+    from orjson import (
+        OPT_NON_STR_KEYS,
+        OPT_SERIALIZE_DATACLASS,
+        OPT_SERIALIZE_NUMPY,
+        OPT_SERIALIZE_UUID,
+        Fragment,
+        JSONDecodeError,
+        dumps,
+        loads,
+    )
+
+except ImportError:
+    import dataclasses
+    import json
+    import uuid
+    from typing import Any, Callable, Optional
+
+    OPT_NON_STR_KEYS = 1
+    OPT_SERIALIZE_DATACLASS = 2
+    OPT_SERIALIZE_NUMPY = 4
+    OPT_SERIALIZE_UUID = 8
+
+    class Fragment:  # type: ignore
+        def __init__(self, payloadb: bytes):
+            self.payloadb = payloadb
+
+    from json import JSONDecodeError  # type: ignore
+
+    def dumps(  # type: ignore
+        obj: Any,
+        /,
+        default: Optional[Callable[[Any], Any]] = None,
+        option: int = 0,
+    ) -> bytes:  # type: ignore
+        # for now, don't do anything for this case because `json.dumps`
+        # automatically encodes non-str keys as str by default, unlike orjson
+        # enable_non_str_keys = bool(option & OPT_NON_STR_KEYS)
+
+        enable_serialize_numpy = bool(option & OPT_SERIALIZE_NUMPY)
+        enable_serialize_dataclass = bool(option & OPT_SERIALIZE_DATACLASS)
+        enable_serialize_uuid = bool(option & OPT_SERIALIZE_UUID)
+
+        class CustomEncoder(json.JSONEncoder):  # type: ignore
+            def encode(self, o: Any) -> str:
+                if isinstance(o, Fragment):
+                    return o.payloadb.decode("utf-8")  # type: ignore
+                return super().encode(o)
+
+            def default(self, o: Any) -> Any:
+                if enable_serialize_uuid and isinstance(o, uuid.UUID):
+                    return str(o)
+                if enable_serialize_numpy and hasattr(o, "tolist"):
+                    # even objects like np.uint16(15) have a .tolist() function
+                    return o.tolist()
+                if (
+                    enable_serialize_dataclass
+                    and dataclasses.is_dataclass(o)
+                    and not isinstance(o, type)
+                ):
+                    return dataclasses.asdict(o)
+                if default is not None:
+                    return default(o)
+
+                return super().default(o)
+
+        return json.dumps(obj, cls=CustomEncoder).encode("utf-8")
+
+    def loads(payload: bytes, /) -> Any:  # type: ignore
+        return json.loads(payload)
+
+
+__all__ = [
+    "loads",
+    "dumps",
+    "Fragment",
+    "JSONDecodeError",
+    "OPT_SERIALIZE_NUMPY",
+    "OPT_SERIALIZE_DATACLASS",
+    "OPT_SERIALIZE_UUID",
+    "OPT_NON_STR_KEYS",
+]
diff --git a/python/langsmith/_internal/_serde.py b/python/langsmith/_internal/_serde.py
index e77f7319d..1bf8865c1 100644
--- a/python/langsmith/_internal/_serde.py
+++ b/python/langsmith/_internal/_serde.py
@@ -12,7 +12,7 @@
 import uuid
 from typing import Any
 
-import orjson
+from langsmith._internal import _orjson
 
 try:
     from zoneinfo import ZoneInfo  # type: ignore[import-not-found]
@@ -133,13 +133,13 @@ def dumps_json(obj: Any) -> bytes:
         The JSON formatted string.
     """
     try:
-        return orjson.dumps(
+        return _orjson.dumps(
             obj,
             default=_serialize_json,
-            option=orjson.OPT_SERIALIZE_NUMPY
-            | orjson.OPT_SERIALIZE_DATACLASS
-            | orjson.OPT_SERIALIZE_UUID
-            | orjson.OPT_NON_STR_KEYS,
+            option=_orjson.OPT_SERIALIZE_NUMPY
+            | _orjson.OPT_SERIALIZE_DATACLASS
+            | _orjson.OPT_SERIALIZE_UUID
+            | _orjson.OPT_NON_STR_KEYS,
         )
     except TypeError as e:
         # Usually caused by UTF surrogate characters
@@ -150,9 +150,9 @@ def dumps_json(obj: Any) -> bytes:
             ensure_ascii=True,
         ).encode("utf-8")
         try:
-            result = orjson.dumps(
-                orjson.loads(result.decode("utf-8", errors="surrogateescape"))
+            result = _orjson.dumps(
+                _orjson.loads(result.decode("utf-8", errors="surrogateescape"))
             )
-        except orjson.JSONDecodeError:
+        except _orjson.JSONDecodeError:
             result = _elide_surrogates(result)
         return result
diff --git a/python/langsmith/_testing.py b/python/langsmith/_testing.py
index 8dd72fbcb..9eaa0877f 100644
--- a/python/langsmith/_testing.py
+++ b/python/langsmith/_testing.py
@@ -12,7 +12,6 @@
 from pathlib import Path
 from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, overload
 
-import orjson
 from typing_extensions import TypedDict
 
 from langsmith import client as ls_client
@@ -21,6 +20,7 @@
 from langsmith import run_trees as rt
 from langsmith import schemas as ls_schemas
 from langsmith import utils as ls_utils
+from langsmith._internal import _orjson
 
 try:
     import pytest  # type: ignore
@@ -374,7 +374,7 @@ def _serde_example_values(values: VT) -> VT:
     if values is None:
         return values
     bts = ls_client._dumps_json(values)
-    return orjson.loads(bts)
+    return _orjson.loads(bts)
 
 
 class _LangSmithTestSuite:
diff --git a/python/langsmith/client.py b/python/langsmith/client.py
index eb397b4c4..8348b57d1 100644
--- a/python/langsmith/client.py
+++ b/python/langsmith/client.py
@@ -55,7 +55,6 @@
 )
 from urllib import parse as urllib_parse
 
-import orjson
 import requests
 from requests import adapters as requests_adapters
 from requests_toolbelt import (  # type: ignore[import-untyped]
@@ -69,6 +68,7 @@
 from langsmith import env as ls_env
 from langsmith import schemas as ls_schemas
 from langsmith import utils as ls_utils
+from langsmith._internal import _orjson
 from langsmith._internal._background_thread import (
     TracingQueueItem,
 )
@@ -368,6 +368,7 @@ class Client:
         "_info",
         "_write_api_urls",
         "_settings",
+        "_manual_cleanup",
     ]
 
     def __init__(
@@ -516,6 +517,8 @@ def __init__(
 
         self._settings: Union[ls_schemas.LangSmithSettings, None] = None
 
+        self._manual_cleanup = False
+
     def _repr_html_(self) -> str:
         """Return an HTML representation of the instance with a link to the URL.
 
@@ -1252,7 +1255,7 @@ def _hide_run_inputs(self, inputs: dict):
         if self._hide_inputs is True:
             return {}
         if self._anonymizer:
-            json_inputs = orjson.loads(_dumps_json(inputs))
+            json_inputs = _orjson.loads(_dumps_json(inputs))
             return self._anonymizer(json_inputs)
         if self._hide_inputs is False:
             return inputs
@@ -1262,7 +1265,7 @@ def _hide_run_outputs(self, outputs: dict):
         if self._hide_outputs is True:
             return {}
         if self._anonymizer:
-            json_outputs = orjson.loads(_dumps_json(outputs))
+            json_outputs = _orjson.loads(_dumps_json(outputs))
             return self._anonymizer(json_outputs)
         if self._hide_outputs is False:
             return outputs
@@ -1282,20 +1285,20 @@ def _batch_ingest_run_ops(
         # form the partial body and ids
         for op in ops:
             if isinstance(op, SerializedRunOperation):
-                curr_dict = orjson.loads(op._none)
+                curr_dict = _orjson.loads(op._none)
                 if op.inputs:
-                    curr_dict["inputs"] = orjson.Fragment(op.inputs)
+                    curr_dict["inputs"] = _orjson.Fragment(op.inputs)
                 if op.outputs:
-                    curr_dict["outputs"] = orjson.Fragment(op.outputs)
+                    curr_dict["outputs"] = _orjson.Fragment(op.outputs)
                 if op.events:
-                    curr_dict["events"] = orjson.Fragment(op.events)
+                    curr_dict["events"] = _orjson.Fragment(op.events)
                 if op.attachments:
                     logger.warning(
                         "Attachments are not supported when use_multipart_endpoint "
                         "is False"
                     )
                 ids_and_partial_body[op.operation].append(
-                    (f"trace={op.trace_id},id={op.id}", orjson.dumps(curr_dict))
+                    (f"trace={op.trace_id},id={op.id}", _orjson.dumps(curr_dict))
                 )
             elif isinstance(op, SerializedFeedbackOperation):
                 logger.warning(
@@ -1321,7 +1324,7 @@ def _batch_ingest_run_ops(
                     and body_size + len(body_deque[0][1]) > size_limit_bytes
                 ):
                     self._post_batch_ingest_runs(
-                        orjson.dumps(body_chunks),
+                        _orjson.dumps(body_chunks),
                         _context=f"\n{key}: {'; '.join(context_ids[key])}",
                     )
                     body_size = 0
@@ -1329,12 +1332,12 @@ def _batch_ingest_run_ops(
                     context_ids.clear()
                 curr_id, curr_body = body_deque.popleft()
                 body_size += len(curr_body)
-                body_chunks[key].append(orjson.Fragment(curr_body))
+                body_chunks[key].append(_orjson.Fragment(curr_body))
                 context_ids[key].append(curr_id)
         if body_size:
             context = "; ".join(f"{k}: {'; '.join(v)}" for k, v in context_ids.items())
             self._post_batch_ingest_runs(
-                orjson.dumps(body_chunks), _context="\n" + context
+                _orjson.dumps(body_chunks), _context="\n" + context
             )
 
     def batch_ingest_runs(
@@ -2759,7 +2762,7 @@ def create_dataset(
             "POST",
             "/datasets",
             headers={**self._headers, "Content-Type": "application/json"},
-            data=orjson.dumps(dataset),
+            data=_orjson.dumps(dataset),
         )
         ls_utils.raise_for_status_with_text(response)
 
@@ -5675,6 +5678,10 @@ def push_prompt(
         )
         return url
 
+    def cleanup(self) -> None:
+        """Manually trigger cleanup of the background thread."""
+        self._manual_cleanup = True
+
 
 def convert_prompt_to_openai_format(
     messages: Any,
diff --git a/python/poetry.lock b/python/poetry.lock
index a2e1c3667..2b362f986 100644
--- a/python/poetry.lock
+++ b/python/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "annotated-types"
@@ -2070,4 +2070,4 @@ vcr = []
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "ca8fa5c9a82d58bea646d5e7e1089175111ddec2c24cd0b19920d1afd4dd93da"
+content-hash = "a5a6c61cba1b5ce9cf739700a780c2df63ff7aaa482c29de9910418263318586"
diff --git a/python/pyproject.toml b/python/pyproject.toml
index fc1d71da3..191d61b22 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -31,7 +31,7 @@ pydantic = [
     { version = "^2.7.4", python = ">=3.12.4" },
 ]
 requests = "^2"
-orjson = "^3.9.14"
+orjson = { version = "^3.9.14", markers = "platform_python_implementation != 'PyPy'" }
 httpx = ">=0.23.0,<1"
 requests-toolbelt = "^1.0.0"
 
diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py
index 5dc1bbe1e..feec2c2f6 100644
--- a/python/tests/unit_tests/test_client.py
+++ b/python/tests/unit_tests/test_client.py
@@ -22,7 +22,6 @@
 from unittest.mock import MagicMock, patch
 
 import dataclasses_json
-import orjson
 import pytest
 import requests
 from multipart import MultipartParser, MultipartPart, parse_options_header
@@ -33,6 +32,7 @@
 import langsmith.utils as ls_utils
 from langsmith import AsyncClient, EvaluationResult, run_trees
 from langsmith import schemas as ls_schemas
+from langsmith._internal import _orjson
 from langsmith._internal._serde import _serialize_json
 from langsmith.client import (
     Client,
@@ -848,7 +848,7 @@ class MyNamedTuple(NamedTuple):
         "set_with_class": set([MyClass(1)]),
         "my_mock": MagicMock(text="Hello, world"),
     }
-    res = orjson.loads(_dumps_json(to_serialize))
+    res = _orjson.loads(_dumps_json(to_serialize))
     assert (
         "model_dump" not in caplog.text
     ), f"Unexpected error logs were emitted: {caplog.text}"
@@ -898,7 +898,7 @@ def __repr__(self) -> str:
 
     my_cyclic = CyclicClass(other=CyclicClass(other=None))
     my_cyclic.other.other = my_cyclic  # type: ignore
-    res = orjson.loads(_dumps_json({"cyclic": my_cyclic}))
+    res = _orjson.loads(_dumps_json({"cyclic": my_cyclic}))
     assert res == {"cyclic": "my_cycles..."}
 
     expected = {"foo": "foo", "bar": 1}
@@ -1142,7 +1142,7 @@ def test_batch_ingest_run_splits_large_batches(
         op
         for call in mock_session.request.call_args_list
         for reqs in (
-            orjson.loads(call[1]["data"]).values() if call[0][0] == "POST" else []
+            _orjson.loads(call[1]["data"]).values() if call[0][0] == "POST" else []
         )
         for op in reqs
     ]
diff --git a/python/tests/unit_tests/test_operations.py b/python/tests/unit_tests/test_operations.py
index a6b5cdeb3..43d06ebc5 100644
--- a/python/tests/unit_tests/test_operations.py
+++ b/python/tests/unit_tests/test_operations.py
@@ -1,5 +1,4 @@
-import orjson
-
+from langsmith._internal import _orjson
 from langsmith._internal._operations import (
     SerializedFeedbackOperation,
     SerializedRunOperation,
@@ -14,7 +13,7 @@ def test_combine_serialized_queue_operations():
             operation="post",
             id="id1",
             trace_id="trace_id1",
-            _none=orjson.dumps({"a": 1}),
+            _none=_orjson.dumps({"a": 1}),
             inputs="inputs1",
             outputs="outputs1",
             events="events1",
@@ -24,7 +23,7 @@ def test_combine_serialized_queue_operations():
             operation="patch",
             id="id1",
             trace_id="trace_id1",
-            _none=orjson.dumps({"b": "2"}),
+            _none=_orjson.dumps({"b": "2"}),
             inputs="inputs1-patched",
             outputs="outputs1-patched",
             events="events1",
@@ -87,7 +86,7 @@ def test_combine_serialized_queue_operations():
             operation="post",
             id="id1",
             trace_id="trace_id1",
-            _none=orjson.dumps({"a": 1, "b": "2"}),
+            _none=_orjson.dumps({"a": 1, "b": "2"}),
             inputs="inputs1-patched",
             outputs="outputs1-patched",
             events="events1",