Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
angus-langchain committed Jan 22, 2025
1 parent 9cfc403 commit 478d1a2
Showing 1 changed file with 77 additions and 77 deletions.
154 changes: 77 additions & 77 deletions python/tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,83 @@ def test_create_run_with_zstd_compression(mock_session_cls: mock.Mock) -> None:
)


@patch("langsmith.client.requests.Session")
def test_create_feedback_with_zstd_compression(mock_session_cls: mock.Mock) -> None:
"""Test that feedback is sent using zstd compression when compression is enabled."""
# Prepare a mocked session
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_session.request.return_value = mock_response
mock_session_cls.return_value = mock_session

with patch.dict("os.environ", {}, clear=True):
info = ls_schemas.LangSmithInfo(
version="0.8.11",
instance_flags={"zstd_compression_enabled": True},
batch_ingest_config=ls_schemas.BatchIngestConfig(
use_multipart_endpoint=True,
size_limit=1,
size_limit_bytes=128,
scale_up_nthreads_limit=4,
scale_up_qsize_trigger=3,
scale_down_nempty_trigger=1,
),
)
client = Client(
api_url="http://localhost:1984",
api_key="123",
auto_batch_tracing=True,
session=mock_session,
info=info,
)

# Create a few runs with larger payloads so there's something to compress
run_id = uuid.uuid4()

feedback_data = {
"comment": "test comment",
"score": 0.95,
}
client.create_feedback(
run_id=run_id, key="test_key", trace_id=run_id, **feedback_data
)

# Let the background threads flush
if client.tracing_queue:
client.tracing_queue.join()
if client._futures is not None:
for fut in client._futures:
fut.result()

time.sleep(0.1)

# Inspect the calls
post_calls = [
call_obj
for call_obj in mock_session.request.mock_calls
if call_obj.args and call_obj.args[0] == "POST"
]
assert len(post_calls) == 1, "Expected exactly one POST request"

call_data = post_calls[0][2]["data"]
if hasattr(call_data, "read"):
call_data = call_data.read()

# Check for zstd magic bytes
zstd_magic = b"\x28\xb5\x2f\xfd"
assert call_data.startswith(zstd_magic), (
"Expected the request body to start with zstd magic bytes; "
"it appears feedback was not compressed."
)

# Verify Content-Encoding header
headers = post_calls[0][2]["headers"]
assert (
headers.get("Content-Encoding") == "zstd"
), "Expected Content-Encoding header to be 'zstd'"


@patch("langsmith.client.requests.Session")
def test_create_run_without_compression_support(mock_session_cls: mock.Mock) -> None:
"""Test that runs use regular multipart when server doesn't support compression."""
Expand Down Expand Up @@ -2348,80 +2425,3 @@ def test_create_run_with_disabled_compression(mock_session_cls: mock.Mock) -> No
run_parsed = json.loads(parts[0].value)
assert run_parsed["trace_id"] == str(run_id)
assert run_parsed["dotted_order"] == str(run_id)


@patch("langsmith.client.requests.Session")
def test_create_feedback_with_zstd_compression(mock_session_cls: mock.Mock) -> None:
"""Test that feedback is sent using zstd compression when compression is enabled."""
# Prepare a mocked session
mock_session = MagicMock()
mock_response = MagicMock()
mock_response.status_code = 200
mock_session.request.return_value = mock_response
mock_session_cls.return_value = mock_session

with patch.dict("os.environ", {}, clear=True):
info = ls_schemas.LangSmithInfo(
version="0.8.11",
instance_flags={"zstd_compression_enabled": True},
batch_ingest_config=ls_schemas.BatchIngestConfig(
use_multipart_endpoint=True,
size_limit=1,
size_limit_bytes=128,
scale_up_nthreads_limit=4,
scale_up_qsize_trigger=3,
scale_down_nempty_trigger=1,
),
)
client = Client(
api_url="http://localhost:1984",
api_key="123",
auto_batch_tracing=True,
session=mock_session,
info=info,
)

# Create a few runs with larger payloads so there's something to compress
run_id = uuid.uuid4()

feedback_data = {
"comment": "test comment",
"score": 0.95,
}
client.create_feedback(
run_id=run_id, key="test_key", trace_id=run_id, **feedback_data
)

# Let the background threads flush
if client.tracing_queue:
client.tracing_queue.join()
if client._futures is not None:
for fut in client._futures:
fut.result()

time.sleep(0.1)

# Inspect the calls
post_calls = [
call_obj
for call_obj in mock_session.request.mock_calls
if call_obj.args and call_obj.args[0] == "POST"
]
assert len(post_calls) == 1, "Expected exactly one POST request"

call_data = post_calls[0][2]["data"]
if hasattr(call_data, "read"):
call_data = call_data.read()

# Check for zstd magic bytes
zstd_magic = b"\x28\xb5\x2f\xfd"
assert call_data.startswith(zstd_magic), (
"Expected the request body to start with zstd magic bytes; "
"it appears feedback was not compressed."
)

# Verify Content-Encoding header
headers = post_calls[0][2]["headers"]
assert (
headers.get("Content-Encoding") == "zstd"
), "Expected Content-Encoding header to be 'zstd'"

0 comments on commit 478d1a2

Please sign in to comment.