Skip to content

Commit

Permalink
python[patch]: orjson optional take 2 (#1230)
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis authored Nov 19, 2024
1 parent a638687 commit 79f3008
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 46 deletions.
26 changes: 20 additions & 6 deletions python/langsmith/_internal/_background_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,27 @@ def tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None:
# 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
num_known_refs = 3

def keep_thread_active() -> bool:
# if `client.cleanup()` was called, stop thread
if not client or (
hasattr(client, "_manual_cleanup") and client._manual_cleanup
):
return False
if not threading.main_thread().is_alive():
# main thread is dead. should not be active
return False

if hasattr(sys, "getrefcount"):
# check if client refs count indicates we're the only remaining
# reference to the client
return sys.getrefcount(client) > num_known_refs + len(sub_threads)
else:
# in PyPy, there is no sys.getrefcount attribute
# for now, keep thread alive
return True

# loop until
while (
# the main thread dies
threading.main_thread().is_alive()
# or we're the only remaining reference to the client
and sys.getrefcount(client) > num_known_refs + len(sub_threads)
):
while keep_thread_active():
for thread in sub_threads:
if not thread.is_alive():
sub_threads.remove(thread)
Expand Down
9 changes: 4 additions & 5 deletions python/langsmith/_internal/_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import uuid
from typing import Literal, Optional, Union, cast

import orjson

from langsmith import schemas as ls_schemas
from langsmith._internal import _orjson
from langsmith._internal._multipart import MultipartPart, MultipartPartsAndContext
from langsmith._internal._serde import dumps_json as _dumps_json

Expand Down Expand Up @@ -169,12 +168,12 @@ def combine_serialized_queue_operations(
if op._none is not None and op._none != create_op._none:
# TODO optimize this more - this would currently be slowest
# for large payloads
create_op_dict = orjson.loads(create_op._none)
create_op_dict = _orjson.loads(create_op._none)
op_dict = {
k: v for k, v in orjson.loads(op._none).items() if v is not None
k: v for k, v in _orjson.loads(op._none).items() if v is not None
}
create_op_dict.update(op_dict)
create_op._none = orjson.dumps(create_op_dict)
create_op._none = _orjson.dumps(create_op_dict)

if op.inputs is not None:
create_op.inputs = op.inputs
Expand Down
84 changes: 84 additions & 0 deletions python/langsmith/_internal/_orjson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Stubs for orjson operations, compatible with PyPy via a json fallback."""

try:
from orjson import (
OPT_NON_STR_KEYS,
OPT_SERIALIZE_DATACLASS,
OPT_SERIALIZE_NUMPY,
OPT_SERIALIZE_UUID,
Fragment,
JSONDecodeError,
dumps,
loads,
)

except ImportError:
import dataclasses
import json
import uuid
from typing import Any, Callable, Optional

OPT_NON_STR_KEYS = 1
OPT_SERIALIZE_DATACLASS = 2
OPT_SERIALIZE_NUMPY = 4
OPT_SERIALIZE_UUID = 8

class Fragment: # type: ignore
def __init__(self, payloadb: bytes):
self.payloadb = payloadb

from json import JSONDecodeError # type: ignore

def dumps( # type: ignore
obj: Any,
/,
default: Optional[Callable[[Any], Any]] = None,
option: int = 0,
) -> bytes: # type: ignore
# for now, don't do anything for this case because `json.dumps`
# automatically encodes non-str keys as str by default, unlike orjson
# enable_non_str_keys = bool(option & OPT_NON_STR_KEYS)

enable_serialize_numpy = bool(option & OPT_SERIALIZE_NUMPY)
enable_serialize_dataclass = bool(option & OPT_SERIALIZE_DATACLASS)
enable_serialize_uuid = bool(option & OPT_SERIALIZE_UUID)

class CustomEncoder(json.JSONEncoder): # type: ignore
def encode(self, o: Any) -> str:
if isinstance(o, Fragment):
return o.payloadb.decode("utf-8") # type: ignore
return super().encode(o)

def default(self, o: Any) -> Any:
if enable_serialize_uuid and isinstance(o, uuid.UUID):
return str(o)
if enable_serialize_numpy and hasattr(o, "tolist"):
# even objects like np.uint16(15) have a .tolist() function
return o.tolist()
if (
enable_serialize_dataclass
and dataclasses.is_dataclass(o)
and not isinstance(o, type)
):
return dataclasses.asdict(o)
if default is not None:
return default(o)

return super().default(o)

return json.dumps(obj, cls=CustomEncoder).encode("utf-8")

def loads(payload: bytes, /) -> Any: # type: ignore
return json.loads(payload)


__all__ = [
"loads",
"dumps",
"Fragment",
"JSONDecodeError",
"OPT_SERIALIZE_NUMPY",
"OPT_SERIALIZE_DATACLASS",
"OPT_SERIALIZE_UUID",
"OPT_NON_STR_KEYS",
]
18 changes: 9 additions & 9 deletions python/langsmith/_internal/_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import uuid
from typing import Any

import orjson
from langsmith._internal import _orjson

try:
from zoneinfo import ZoneInfo # type: ignore[import-not-found]
Expand Down Expand Up @@ -133,13 +133,13 @@ def dumps_json(obj: Any) -> bytes:
The JSON formatted string.
"""
try:
return orjson.dumps(
return _orjson.dumps(
obj,
default=_serialize_json,
option=orjson.OPT_SERIALIZE_NUMPY
| orjson.OPT_SERIALIZE_DATACLASS
| orjson.OPT_SERIALIZE_UUID
| orjson.OPT_NON_STR_KEYS,
option=_orjson.OPT_SERIALIZE_NUMPY
| _orjson.OPT_SERIALIZE_DATACLASS
| _orjson.OPT_SERIALIZE_UUID
| _orjson.OPT_NON_STR_KEYS,
)
except TypeError as e:
# Usually caused by UTF surrogate characters
Expand All @@ -150,9 +150,9 @@ def dumps_json(obj: Any) -> bytes:
ensure_ascii=True,
).encode("utf-8")
try:
result = orjson.dumps(
orjson.loads(result.decode("utf-8", errors="surrogateescape"))
result = _orjson.dumps(
_orjson.loads(result.decode("utf-8", errors="surrogateescape"))
)
except orjson.JSONDecodeError:
except _orjson.JSONDecodeError:
result = _elide_surrogates(result)
return result
4 changes: 2 additions & 2 deletions python/langsmith/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, overload

import orjson
from typing_extensions import TypedDict

from langsmith import client as ls_client
Expand All @@ -21,6 +20,7 @@
from langsmith import run_trees as rt
from langsmith import schemas as ls_schemas
from langsmith import utils as ls_utils
from langsmith._internal import _orjson

try:
import pytest # type: ignore
Expand Down Expand Up @@ -374,7 +374,7 @@ def _serde_example_values(values: VT) -> VT:
if values is None:
return values
bts = ls_client._dumps_json(values)
return orjson.loads(bts)
return _orjson.loads(bts)


class _LangSmithTestSuite:
Expand Down
31 changes: 19 additions & 12 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
)
from urllib import parse as urllib_parse

import orjson
import requests
from requests import adapters as requests_adapters
from requests_toolbelt import ( # type: ignore[import-untyped]
Expand All @@ -69,6 +68,7 @@
from langsmith import env as ls_env
from langsmith import schemas as ls_schemas
from langsmith import utils as ls_utils
from langsmith._internal import _orjson
from langsmith._internal._background_thread import (
TracingQueueItem,
)
Expand Down Expand Up @@ -368,6 +368,7 @@ class Client:
"_info",
"_write_api_urls",
"_settings",
"_manual_cleanup",
]

def __init__(
Expand Down Expand Up @@ -516,6 +517,8 @@ def __init__(

self._settings: Union[ls_schemas.LangSmithSettings, None] = None

self._manual_cleanup = False

def _repr_html_(self) -> str:
"""Return an HTML representation of the instance with a link to the URL.
Expand Down Expand Up @@ -1252,7 +1255,7 @@ def _hide_run_inputs(self, inputs: dict):
if self._hide_inputs is True:
return {}
if self._anonymizer:
json_inputs = orjson.loads(_dumps_json(inputs))
json_inputs = _orjson.loads(_dumps_json(inputs))
return self._anonymizer(json_inputs)
if self._hide_inputs is False:
return inputs
Expand All @@ -1262,7 +1265,7 @@ def _hide_run_outputs(self, outputs: dict):
if self._hide_outputs is True:
return {}
if self._anonymizer:
json_outputs = orjson.loads(_dumps_json(outputs))
json_outputs = _orjson.loads(_dumps_json(outputs))
return self._anonymizer(json_outputs)
if self._hide_outputs is False:
return outputs
Expand All @@ -1282,20 +1285,20 @@ def _batch_ingest_run_ops(
# form the partial body and ids
for op in ops:
if isinstance(op, SerializedRunOperation):
curr_dict = orjson.loads(op._none)
curr_dict = _orjson.loads(op._none)
if op.inputs:
curr_dict["inputs"] = orjson.Fragment(op.inputs)
curr_dict["inputs"] = _orjson.Fragment(op.inputs)
if op.outputs:
curr_dict["outputs"] = orjson.Fragment(op.outputs)
curr_dict["outputs"] = _orjson.Fragment(op.outputs)
if op.events:
curr_dict["events"] = orjson.Fragment(op.events)
curr_dict["events"] = _orjson.Fragment(op.events)
if op.attachments:
logger.warning(
"Attachments are not supported when use_multipart_endpoint "
"is False"
)
ids_and_partial_body[op.operation].append(
(f"trace={op.trace_id},id={op.id}", orjson.dumps(curr_dict))
(f"trace={op.trace_id},id={op.id}", _orjson.dumps(curr_dict))
)
elif isinstance(op, SerializedFeedbackOperation):
logger.warning(
Expand All @@ -1321,20 +1324,20 @@ def _batch_ingest_run_ops(
and body_size + len(body_deque[0][1]) > size_limit_bytes
):
self._post_batch_ingest_runs(
orjson.dumps(body_chunks),
_orjson.dumps(body_chunks),
_context=f"\n{key}: {'; '.join(context_ids[key])}",
)
body_size = 0
body_chunks.clear()
context_ids.clear()
curr_id, curr_body = body_deque.popleft()
body_size += len(curr_body)
body_chunks[key].append(orjson.Fragment(curr_body))
body_chunks[key].append(_orjson.Fragment(curr_body))
context_ids[key].append(curr_id)
if body_size:
context = "; ".join(f"{k}: {'; '.join(v)}" for k, v in context_ids.items())
self._post_batch_ingest_runs(
orjson.dumps(body_chunks), _context="\n" + context
_orjson.dumps(body_chunks), _context="\n" + context
)

def batch_ingest_runs(
Expand Down Expand Up @@ -2759,7 +2762,7 @@ def create_dataset(
"POST",
"/datasets",
headers={**self._headers, "Content-Type": "application/json"},
data=orjson.dumps(dataset),
data=_orjson.dumps(dataset),
)
ls_utils.raise_for_status_with_text(response)

Expand Down Expand Up @@ -5675,6 +5678,10 @@ def push_prompt(
)
return url

def cleanup(self) -> None:
"""Manually trigger cleanup of the background thread."""
self._manual_cleanup = True


def convert_prompt_to_openai_format(
messages: Any,
Expand Down
4 changes: 2 additions & 2 deletions python/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pydantic = [
{ version = "^2.7.4", python = ">=3.12.4" },
]
requests = "^2"
orjson = "^3.9.14"
orjson = { version = "^3.9.14", markers = "platform_python_implementation != 'PyPy'" }
httpx = ">=0.23.0,<1"
requests-toolbelt = "^1.0.0"

Expand Down
8 changes: 4 additions & 4 deletions python/tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from unittest.mock import MagicMock, patch

import dataclasses_json
import orjson
import pytest
import requests
from multipart import MultipartParser, MultipartPart, parse_options_header
Expand All @@ -33,6 +32,7 @@
import langsmith.utils as ls_utils
from langsmith import AsyncClient, EvaluationResult, run_trees
from langsmith import schemas as ls_schemas
from langsmith._internal import _orjson
from langsmith._internal._serde import _serialize_json
from langsmith.client import (
Client,
Expand Down Expand Up @@ -848,7 +848,7 @@ class MyNamedTuple(NamedTuple):
"set_with_class": set([MyClass(1)]),
"my_mock": MagicMock(text="Hello, world"),
}
res = orjson.loads(_dumps_json(to_serialize))
res = _orjson.loads(_dumps_json(to_serialize))
assert (
"model_dump" not in caplog.text
), f"Unexpected error logs were emitted: {caplog.text}"
Expand Down Expand Up @@ -898,7 +898,7 @@ def __repr__(self) -> str:
my_cyclic = CyclicClass(other=CyclicClass(other=None))
my_cyclic.other.other = my_cyclic # type: ignore

res = orjson.loads(_dumps_json({"cyclic": my_cyclic}))
res = _orjson.loads(_dumps_json({"cyclic": my_cyclic}))
assert res == {"cyclic": "my_cycles..."}
expected = {"foo": "foo", "bar": 1}

Expand Down Expand Up @@ -1142,7 +1142,7 @@ def test_batch_ingest_run_splits_large_batches(
op
for call in mock_session.request.call_args_list
for reqs in (
orjson.loads(call[1]["data"]).values() if call[0][0] == "POST" else []
_orjson.loads(call[1]["data"]).values() if call[0][0] == "POST" else []
)
for op in reqs
]
Expand Down
Loading

0 comments on commit 79f3008

Please sign in to comment.