Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(py): Add multipart feedback ingestion #1129

Merged
merged 8 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/langsmith/_internal/_background_thread.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import annotations

Check notice on line 1 in python/langsmith/_internal/_background_thread.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... create_5_000_run_trees: Mean +- std dev: 572 ms +- 42 ms ......................................... create_10_000_run_trees: Mean +- std dev: 1.12 sec +- 0.05 sec ......................................... create_20_000_run_trees: Mean +- std dev: 1.12 sec +- 0.06 sec ......................................... dumps_class_nested_py_branch_and_leaf_200x400: Mean +- std dev: 774 us +- 16 us ......................................... dumps_class_nested_py_leaf_50x100: Mean +- std dev: 27.4 ms +- 0.4 ms ......................................... dumps_class_nested_py_leaf_100x200: Mean +- std dev: 114 ms +- 4 ms ......................................... dumps_dataclass_nested_50x100: Mean +- std dev: 27.9 ms +- 0.6 ms ......................................... WARNING: the benchmark result may be unstable * the standard deviation (7.02 ms) is 12% of the mean (59.4 ms) Try to rerun the benchmark with more runs, values and/or loops. Run 'python -m pyperf system tune' command to reduce the system jitter. Use pyperf stats, pyperf dump and pyperf hist to analyze results. Use --quiet option to hide these warnings. dumps_pydantic_nested_50x100: Mean +- std dev: 59.4 ms +- 7.0 ms ......................................... WARNING: the benchmark result may be unstable * the standard deviation (31.2 ms) is 15% of the mean (214 ms) Try to rerun the benchmark with more runs, values and/or loops. Run 'python -m pyperf system tune' command to reduce the system jitter. Use pyperf stats, pyperf dump and pyperf hist to analyze results. Use --quiet option to hide these warnings. dumps_pydanticv1_nested_50x100: Mean +- std dev: 214 ms +- 31 ms

Check notice on line 1 in python/langsmith/_internal/_background_thread.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+-----------------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +===============================================+=========+=======================+ | dumps_class_nested_py_branch_and_leaf_200x400 | 763 us | 774 us: 1.01x slower | +-----------------------------------------------+---------+-----------------------+ | dumps_class_nested_py_leaf_50x100 | 26.8 ms | 27.4 ms: 1.02x slower | +-----------------------------------------------+---------+-----------------------+ | dumps_dataclass_nested_50x100 | 27.1 ms | 27.9 ms: 1.03x slower | +-----------------------------------------------+---------+-----------------------+ | dumps_class_nested_py_leaf_100x200 | 111 ms | 114 ms: 1.03x slower | +-----------------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.02x slower | +-----------------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (5): create_10_000_run_trees, dumps_pydanticv1_nested_50x100, create_20_000_run_trees, create_5_000_run_trees, dumps_pydantic_nested_50x100

import logging
import sys
Expand Down Expand Up @@ -69,9 +69,12 @@
) -> None:
create = [it.item for it in batch if it.action == "create"]
update = [it.item for it in batch if it.action == "update"]
feedback = [it.item for it in batch if it.action == "feedback"]
try:
if use_multipart:
client.multipart_ingest_runs(create=create, update=update, pre_sampled=True)
client.multipart_ingest(
create=create, update=update, feedback=feedback, pre_sampled=True
)
else:
client.batch_ingest_runs(create=create, update=update, pre_sampled=True)
akira marked this conversation as resolved.
Show resolved Hide resolved
except Exception:
Expand Down
180 changes: 119 additions & 61 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,34 @@ def _run_transform(

return run_create

def _feedback_transform(
    self,
    feedback: Union[ls_schemas.Feedback, dict],
) -> dict:
    """Transform the given feedback object into a dictionary representation.

    The result always carries an ``"id"`` entry holding a ``uuid.UUID``:
    a missing id is generated with ``uuid.uuid4()`` and a string id is
    parsed into a ``uuid.UUID``. Any other existing id value is kept as is.

    Args:
        feedback (Union[ls_schemas.Feedback, dict]): The feedback to
            transform. Either a mapping, or any object exposing a callable
            ``dict()`` method, which is used to obtain the dictionary.

    Returns:
        dict: The transformed feedback as a dictionary. When a dict is
            passed in, a shallow copy is returned so the caller's dict is
            never modified.

    Raises:
        ValueError: If ``"id"`` is a string that is not a valid UUID.
    """
    to_dict = getattr(feedback, "dict", None)
    if callable(to_dict):
        feedback_create: dict = to_dict()
    else:
        # Shallow-copy so that adding/normalizing "id" here (and any later
        # key removal by whoever consumes the result) cannot leak back into
        # the dict owned by the caller.
        feedback_create = dict(feedback)  # type: ignore[arg-type]

    feedback_id = feedback_create.get("id")
    if "id" not in feedback_create:
        feedback_create["id"] = uuid.uuid4()
    elif isinstance(feedback_id, str):
        feedback_create["id"] = uuid.UUID(feedback_id)

    return feedback_create

@staticmethod
def _insert_runtime_env(runs: Sequence[dict]) -> None:
runtime_env = ls_env.get_runtime_environment()
Expand Down Expand Up @@ -1408,14 +1436,15 @@ def _post_batch_ingest_runs(self, body: bytes, *, _context: str):
except Exception:
logger.warning(f"Failed to batch ingest runs: {repr(e)}")

def multipart_ingest_runs(
def multipart_ingest(
self,
create: Optional[
Sequence[Union[ls_schemas.Run, ls_schemas.RunLikeDict, Dict]]
] = None,
update: Optional[
Sequence[Union[ls_schemas.Run, ls_schemas.RunLikeDict, Dict]]
] = None,
feedback: Optional[Sequence[Union[ls_schemas.Feedback, Dict]]] = None,
akira marked this conversation as resolved.
Show resolved Hide resolved
*,
pre_sampled: bool = False,
) -> None:
Expand All @@ -1442,7 +1471,7 @@ def multipart_ingest_runs(
- The run objects MUST contain the dotted_order and trace_id fields
to be accepted by the API.
"""
if not create and not update:
if not create and not update and not feedback:
akira marked this conversation as resolved.
Show resolved Hide resolved
return
# transform and convert to dicts
all_attachments: Dict[str, ls_schemas.Attachments] = {}
Expand All @@ -1454,6 +1483,7 @@ def multipart_ingest_runs(
self._run_transform(run, update=True, attachments_collector=all_attachments)
for run in update or EMPTY_SEQ
]
feedback_dicts = [self._feedback_transform(f) for f in feedback or EMPTY_SEQ]
# require trace_id and dotted_order
if create_dicts:
for run in create_dicts:
Expand Down Expand Up @@ -1491,21 +1521,26 @@ def multipart_ingest_runs(
if not pre_sampled:
create_dicts = self._filter_for_sampling(create_dicts)
update_dicts = self._filter_for_sampling(update_dicts, patch=True)
if not create_dicts and not update_dicts:
if not create_dicts and not update_dicts and not feedback_dicts:
return
# insert runtime environment
self._insert_runtime_env(create_dicts)
self._insert_runtime_env(update_dicts)
# send the runs in multipart requests
acc_context: List[str] = []
acc_parts: MultipartParts = []
for event, payloads in (("post", create_dicts), ("patch", update_dicts)):
for event, payloads in (
("post", create_dicts),
("patch", update_dicts),
("feedback", feedback_dicts),
):
for payload in payloads:
# collect fields to be sent as separate parts
fields = [
("inputs", payload.pop("inputs", None)),
("outputs", payload.pop("outputs", None)),
("events", payload.pop("events", None)),
("feedback", payload.pop("feedback", None)),
]
# encode the main run payload
payloadb = _dumps_json(payload)
Expand Down Expand Up @@ -4115,6 +4150,7 @@ def _submit_feedback(**kwargs):
),
feedback_source_type=ls_schemas.FeedbackSourceType.MODEL,
project_id=project_id,
trace_id=run.trace_id if run else None,
)
return results

Expand Down Expand Up @@ -4185,6 +4221,7 @@ def create_feedback(
project_id: Optional[ID_TYPE] = None,
comparative_experiment_id: Optional[ID_TYPE] = None,
feedback_group_id: Optional[ID_TYPE] = None,
trace_id: Optional[ID_TYPE] = None,
agola11 marked this conversation as resolved.
Show resolved Hide resolved
**kwargs: Any,
) -> ls_schemas.Feedback:
"""Create a feedback in the LangSmith API.
Expand Down Expand Up @@ -4241,66 +4278,87 @@ def create_feedback(
f" endpoint: {sorted(kwargs)}",
DeprecationWarning,
)
if not isinstance(feedback_source_type, ls_schemas.FeedbackSourceType):
feedback_source_type = ls_schemas.FeedbackSourceType(feedback_source_type)
if feedback_source_type == ls_schemas.FeedbackSourceType.API:
feedback_source: ls_schemas.FeedbackSourceBase = (
ls_schemas.APIFeedbackSource(metadata=source_info)
try:
if not isinstance(feedback_source_type, ls_schemas.FeedbackSourceType):
feedback_source_type = ls_schemas.FeedbackSourceType(
feedback_source_type
)
if feedback_source_type == ls_schemas.FeedbackSourceType.API:
feedback_source: ls_schemas.FeedbackSourceBase = (
ls_schemas.APIFeedbackSource(metadata=source_info)
)
elif feedback_source_type == ls_schemas.FeedbackSourceType.MODEL:
feedback_source = ls_schemas.ModelFeedbackSource(metadata=source_info)
else:
raise ValueError(f"Unknown feedback source type {feedback_source_type}")
feedback_source.metadata = (
feedback_source.metadata if feedback_source.metadata is not None else {}
)
elif feedback_source_type == ls_schemas.FeedbackSourceType.MODEL:
feedback_source = ls_schemas.ModelFeedbackSource(metadata=source_info)
else:
raise ValueError(f"Unknown feedback source type {feedback_source_type}")
feedback_source.metadata = (
feedback_source.metadata if feedback_source.metadata is not None else {}
)
if source_run_id is not None and "__run" not in feedback_source.metadata:
feedback_source.metadata["__run"] = {"run_id": str(source_run_id)}
if feedback_source.metadata and "__run" in feedback_source.metadata:
# Validate that the linked run ID is a valid UUID
# Run info may be a base model or dict.
_run_meta: Union[dict, Any] = feedback_source.metadata["__run"]
if hasattr(_run_meta, "dict") and callable(_run_meta):
_run_meta = _run_meta.dict()
if "run_id" in _run_meta:
_run_meta["run_id"] = str(
_as_uuid(
feedback_source.metadata["__run"]["run_id"],
"feedback_source.metadata['__run']['run_id']",
if source_run_id is not None and "__run" not in feedback_source.metadata:
feedback_source.metadata["__run"] = {"run_id": str(source_run_id)}
if feedback_source.metadata and "__run" in feedback_source.metadata:
# Validate that the linked run ID is a valid UUID
# Run info may be a base model or dict.
_run_meta: Union[dict, Any] = feedback_source.metadata["__run"]
if hasattr(_run_meta, "dict") and callable(_run_meta):
_run_meta = _run_meta.dict()
if "run_id" in _run_meta:
_run_meta["run_id"] = str(
_as_uuid(
feedback_source.metadata["__run"]["run_id"],
"feedback_source.metadata['__run']['run_id']",
)
)
feedback_source.metadata["__run"] = _run_meta
feedback = ls_schemas.FeedbackCreate(
id=_ensure_uuid(feedback_id),
# If run_id is None, this is interpreted as session-level
# feedback.
run_id=_ensure_uuid(run_id, accept_null=True),
trace_id=_ensure_uuid(trace_id, accept_null=True),
key=key,
score=score,
value=value,
correction=correction,
comment=comment,
feedback_source=feedback_source,
created_at=datetime.datetime.now(datetime.timezone.utc),
modified_at=datetime.datetime.now(datetime.timezone.utc),
feedback_config=feedback_config,
session_id=_ensure_uuid(project_id, accept_null=True),
comparative_experiment_id=_ensure_uuid(
comparative_experiment_id, accept_null=True
),
feedback_group_id=_ensure_uuid(feedback_group_id, accept_null=True),
)

feedback_block = _dumps_json(feedback.dict(exclude_none=True))
use_multipart = (self.info.batch_ingest_config or {}).get(
"use_multipart_endpoint", False
)

if (
use_multipart
and self.tracing_queue is not None
and feedback.trace_id is not None
akira marked this conversation as resolved.
Show resolved Hide resolved
):
self.tracing_queue.put(
TracingQueueItem(str(feedback.id), "feedback", feedback)
)
feedback_source.metadata["__run"] = _run_meta
feedback = ls_schemas.FeedbackCreate(
id=_ensure_uuid(feedback_id),
# If run_id is None, this is interpreted as session-level
# feedback.
run_id=_ensure_uuid(run_id, accept_null=True),
key=key,
score=score,
value=value,
correction=correction,
comment=comment,
feedback_source=feedback_source,
created_at=datetime.datetime.now(datetime.timezone.utc),
modified_at=datetime.datetime.now(datetime.timezone.utc),
feedback_config=feedback_config,
session_id=_ensure_uuid(project_id, accept_null=True),
comparative_experiment_id=_ensure_uuid(
comparative_experiment_id, accept_null=True
),
feedback_group_id=_ensure_uuid(feedback_group_id, accept_null=True),
)
feedback_block = _dumps_json(feedback.dict(exclude_none=True))
self.request_with_retries(
"POST",
"/feedback",
request_kwargs={
"data": feedback_block,
},
stop_after_attempt=stop_after_attempt,
retry_on=(ls_utils.LangSmithNotFoundError,),
)
return ls_schemas.Feedback(**feedback.dict())
else:
self.request_with_retries(
"POST",
"/feedback",
request_kwargs={
"data": feedback_block,
},
stop_after_attempt=stop_after_attempt,
retry_on=(ls_utils.LangSmithNotFoundError,),
)
return ls_schemas.Feedback(**feedback.dict())
except Exception as e:
logger.error("Error creating feedback", exc_info=True)
akira marked this conversation as resolved.
Show resolved Hide resolved
raise e

def update_feedback(
self,
Expand Down
2 changes: 2 additions & 0 deletions python/langsmith/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,8 @@ class FeedbackBase(BaseModel):
"""The time the feedback was last modified."""
run_id: Optional[UUID]
"""The associated run ID this feedback is logged for."""
trace_id: Optional[UUID]
akira marked this conversation as resolved.
Show resolved Hide resolved
"""The associated trace ID this feedback is logged for."""
key: str
"""The metric name, tag, or aspect to provide feedback on."""
score: SCORE_TYPE = None
Expand Down
32 changes: 14 additions & 18 deletions python/tests/integration_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,9 +684,7 @@ def test_batch_ingest_runs(
},
]
if use_multipart_endpoint:
langchain_client.multipart_ingest_runs(
create=runs_to_create, update=runs_to_update
)
langchain_client.multipart_ingest(create=runs_to_create, update=runs_to_update)
else:
langchain_client.batch_ingest_runs(create=runs_to_create, update=runs_to_update)
runs = []
Expand Down Expand Up @@ -744,25 +742,23 @@ def test_batch_ingest_runs(
"""


def test_multipart_ingest_runs_empty(
def test_multipart_ingest_empty(
langchain_client: Client, caplog: pytest.LogCaptureFixture
) -> None:
runs_to_create: list[dict] = []
runs_to_update: list[dict] = []

# make sure no warnings logged
with caplog.at_level(logging.WARNING, logger="langsmith.client"):
langchain_client.multipart_ingest_runs(
create=runs_to_create, update=runs_to_update
)
langchain_client.multipart_ingest(create=runs_to_create, update=runs_to_update)

assert not caplog.records


def test_multipart_ingest_runs_create_then_update(
def test_multipart_ingest_create_then_update(
langchain_client: Client, caplog: pytest.LogCaptureFixture
) -> None:
_session = "__test_multipart_ingest_runs_create_then_update"
_session = "__test_multipart_ingest_create_then_update"

trace_a_id = uuid4()
current_time = datetime.datetime.now(datetime.timezone.utc).strftime(
Expand All @@ -783,7 +779,7 @@ def test_multipart_ingest_runs_create_then_update(

# make sure no warnings logged
with caplog.at_level(logging.WARNING, logger="langsmith.client"):
langchain_client.multipart_ingest_runs(create=runs_to_create, update=[])
langchain_client.multipart_ingest(create=runs_to_create, update=[])

assert not caplog.records

Expand All @@ -796,15 +792,15 @@ def test_multipart_ingest_runs_create_then_update(
}
]
with caplog.at_level(logging.WARNING, logger="langsmith.client"):
langchain_client.multipart_ingest_runs(create=[], update=runs_to_update)
langchain_client.multipart_ingest(create=[], update=runs_to_update)

assert not caplog.records


def test_multipart_ingest_runs_update_then_create(
def test_multipart_ingest_update_then_create(
langchain_client: Client, caplog: pytest.LogCaptureFixture
) -> None:
_session = "__test_multipart_ingest_runs_update_then_create"
_session = "__test_multipart_ingest_update_then_create"

trace_a_id = uuid4()
current_time = datetime.datetime.now(datetime.timezone.utc).strftime(
Expand All @@ -822,7 +818,7 @@ def test_multipart_ingest_runs_update_then_create(

# make sure no warnings logged
with caplog.at_level(logging.WARNING, logger="langsmith.client"):
langchain_client.multipart_ingest_runs(create=[], update=runs_to_update)
langchain_client.multipart_ingest(create=[], update=runs_to_update)

assert not caplog.records

Expand All @@ -839,15 +835,15 @@ def test_multipart_ingest_runs_update_then_create(
]

with caplog.at_level(logging.WARNING, logger="langsmith.client"):
langchain_client.multipart_ingest_runs(create=runs_to_create, update=[])
langchain_client.multipart_ingest(create=runs_to_create, update=[])

assert not caplog.records


def test_multipart_ingest_runs_create_wrong_type(
def test_multipart_ingest_create_wrong_type(
langchain_client: Client, caplog: pytest.LogCaptureFixture
) -> None:
_session = "__test_multipart_ingest_runs_create_then_update"
_session = "__test_multipart_ingest_create_then_update"

trace_a_id = uuid4()
current_time = datetime.datetime.now(datetime.timezone.utc).strftime(
Expand All @@ -868,7 +864,7 @@ def test_multipart_ingest_runs_create_wrong_type(

# make sure no warnings logged
with caplog.at_level(logging.WARNING, logger="langsmith.client"):
langchain_client.multipart_ingest_runs(create=runs_to_create, update=[])
langchain_client.multipart_ingest(create=runs_to_create, update=[])

# this should 422
assert len(caplog.records) == 1, "Should get 1 warning for 422, not retried"
Expand Down
2 changes: 1 addition & 1 deletion python/tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,7 @@ def test_batch_ingest_run_splits_large_batches(
for run_id in patch_ids
]
if use_multipart_endpoint:
client.multipart_ingest_runs(create=posts, update=patches)
client.multipart_ingest(create=posts, update=patches)
# multipart endpoint should only send one request
expected_num_requests = 1
# count the number of POST requests
Expand Down
Loading