Skip to content

Commit

Permalink
Add multipart feedback ingestion
Browse files Browse the repository at this point in the history
  • Loading branch information
akira committed Oct 25, 2024
1 parent 7c51b42 commit df84a79
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 71 deletions.
198 changes: 131 additions & 67 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,34 @@ def _run_transform(

return run_create

def _feedback_transform(
    self,
    feedback: Union[ls_schemas.Feedback, dict],
) -> dict:
    """Transform the given feedback object into a dictionary representation.

    Args:
        feedback (Union[ls_schemas.Feedback, dict]): The feedback object to
            transform. Objects exposing a callable ``dict()`` method are
            converted through it; anything else is treated as a mapping.

    Returns:
        dict: The feedback as a dictionary whose ``"id"`` entry is a
            ``uuid.UUID``: generated when the id is missing or ``None``,
            parsed when it was supplied as a string.

    Raises:
        ValueError: If ``"id"`` is a string that is not a valid UUID.
    """
    to_dict = getattr(feedback, "dict", None)
    if callable(to_dict):
        feedback_create: dict = to_dict()
    else:
        # Shallow-copy so that normalising the id below never mutates the
        # dictionary owned by the caller.
        feedback_create = dict(cast(dict, feedback))

    feedback_id = feedback_create.get("id")
    if feedback_id is None:
        # Covers both an absent key and an explicit ``None`` value.
        feedback_create["id"] = uuid.uuid4()
    elif isinstance(feedback_id, str):
        feedback_create["id"] = uuid.UUID(feedback_id)

    return feedback_create

@staticmethod
def _insert_runtime_env(runs: Sequence[dict]) -> None:
runtime_env = ls_env.get_runtime_environment()
Expand Down Expand Up @@ -1556,14 +1584,15 @@ def _post_batch_ingest_runs(self, body: bytes, *, _context: str):
except Exception:
logger.warning(f"Failed to batch ingest runs: {repr(e)}")

def multipart_ingest_runs(
def multipart_ingest(
self,
create: Optional[
Sequence[Union[ls_schemas.Run, ls_schemas.RunLikeDict, Dict]]
] = None,
update: Optional[
Sequence[Union[ls_schemas.Run, ls_schemas.RunLikeDict, Dict]]
] = None,
feedback: Optional[Sequence[Union[ls_schemas.Feedback, Dict]]] = None,
*,
pre_sampled: bool = False,
) -> None:
Expand All @@ -1590,7 +1619,7 @@ def multipart_ingest_runs(
- The run objects MUST contain the dotted_order and trace_id fields
to be accepted by the API.
"""
if not create and not update:
if not create and not update and not feedback:
return
# transform and convert to dicts
all_attachments: Dict[str, ls_schemas.Attachments] = {}
Expand All @@ -1602,6 +1631,7 @@ def multipart_ingest_runs(
self._run_transform(run, update=True, attachments_collector=all_attachments)
for run in update or EMPTY_SEQ
]
feedback_dicts = [self._feedback_transform(f) for f in feedback or EMPTY_SEQ]
# require trace_id and dotted_order
if create_dicts:
for run in create_dicts:
Expand Down Expand Up @@ -1639,22 +1669,36 @@ def multipart_ingest_runs(
if not pre_sampled:
create_dicts = self._filter_for_sampling(create_dicts)
update_dicts = self._filter_for_sampling(update_dicts, patch=True)
if not create_dicts and not update_dicts:
if not create_dicts and not update_dicts and not feedback_dicts:
return
# insert runtime environment
self._insert_runtime_env(create_dicts)
self._insert_runtime_env(update_dicts)
# send the runs in multipart requests
acc_context: List[str] = []
acc_parts: MultipartParts = []
for event, payloads in (("post", create_dicts), ("patch", update_dicts)):
for event, payloads in (
("post", create_dicts),
("patch", update_dicts),
("feedback", feedback_dicts),
):
for payload in payloads:
# collect fields to be sent as separate parts
fields = [
("inputs", payload.pop("inputs", None)),
("outputs", payload.pop("outputs", None)),
("events", payload.pop("events", None)),
]
fields = []
if create_dicts or update_dicts:
fields.extend(
[
("inputs", payload.pop("inputs", None)),
("outputs", payload.pop("outputs", None)),
("events", payload.pop("events", None)),
]
)
if feedback:
fields.extend(
[
("feedback", payload.pop("feedback", None)),
]
)
# encode the main run payload
payloadb = _dumps_json(payload)
acc_parts.append(
Expand Down Expand Up @@ -4387,66 +4431,83 @@ def create_feedback(
f" endpoint: {sorted(kwargs)}",
DeprecationWarning,
)
if not isinstance(feedback_source_type, ls_schemas.FeedbackSourceType):
feedback_source_type = ls_schemas.FeedbackSourceType(feedback_source_type)
if feedback_source_type == ls_schemas.FeedbackSourceType.API:
feedback_source: ls_schemas.FeedbackSourceBase = (
ls_schemas.APIFeedbackSource(metadata=source_info)
try:
if not isinstance(feedback_source_type, ls_schemas.FeedbackSourceType):
feedback_source_type = ls_schemas.FeedbackSourceType(
feedback_source_type
)
if feedback_source_type == ls_schemas.FeedbackSourceType.API:
feedback_source: ls_schemas.FeedbackSourceBase = (
ls_schemas.APIFeedbackSource(metadata=source_info)
)
elif feedback_source_type == ls_schemas.FeedbackSourceType.MODEL:
feedback_source = ls_schemas.ModelFeedbackSource(metadata=source_info)
else:
raise ValueError(f"Unknown feedback source type {feedback_source_type}")
feedback_source.metadata = (
feedback_source.metadata if feedback_source.metadata is not None else {}
)
elif feedback_source_type == ls_schemas.FeedbackSourceType.MODEL:
feedback_source = ls_schemas.ModelFeedbackSource(metadata=source_info)
else:
raise ValueError(f"Unknown feedback source type {feedback_source_type}")
feedback_source.metadata = (
feedback_source.metadata if feedback_source.metadata is not None else {}
)
if source_run_id is not None and "__run" not in feedback_source.metadata:
feedback_source.metadata["__run"] = {"run_id": str(source_run_id)}
if feedback_source.metadata and "__run" in feedback_source.metadata:
# Validate that the linked run ID is a valid UUID
# Run info may be a base model or dict.
_run_meta: Union[dict, Any] = feedback_source.metadata["__run"]
if hasattr(_run_meta, "dict") and callable(_run_meta):
_run_meta = _run_meta.dict()
if "run_id" in _run_meta:
_run_meta["run_id"] = str(
_as_uuid(
feedback_source.metadata["__run"]["run_id"],
"feedback_source.metadata['__run']['run_id']",
if source_run_id is not None and "__run" not in feedback_source.metadata:
feedback_source.metadata["__run"] = {"run_id": str(source_run_id)}
if feedback_source.metadata and "__run" in feedback_source.metadata:
# Validate that the linked run ID is a valid UUID
# Run info may be a base model or dict.
_run_meta: Union[dict, Any] = feedback_source.metadata["__run"]
if hasattr(_run_meta, "dict") and callable(_run_meta):
_run_meta = _run_meta.dict()
if "run_id" in _run_meta:
_run_meta["run_id"] = str(
_as_uuid(
feedback_source.metadata["__run"]["run_id"],
"feedback_source.metadata['__run']['run_id']",
)
)
feedback_source.metadata["__run"] = _run_meta
feedback = ls_schemas.FeedbackCreate(
id=_ensure_uuid(feedback_id),
# If run_id is None, this is interpreted as session-level
# feedback.
run_id=_ensure_uuid(run_id, accept_null=True),
trace_id=_ensure_uuid(run_id, accept_null=True),
key=key,
score=score,
value=value,
correction=correction,
comment=comment,
feedback_source=feedback_source,
created_at=datetime.datetime.now(datetime.timezone.utc),
modified_at=datetime.datetime.now(datetime.timezone.utc),
feedback_config=feedback_config,
session_id=_ensure_uuid(project_id, accept_null=True),
comparative_experiment_id=_ensure_uuid(
comparative_experiment_id, accept_null=True
),
feedback_group_id=_ensure_uuid(feedback_group_id, accept_null=True),
)

feedback_block = _dumps_json(feedback.dict(exclude_none=True))
use_multipart = (self.info.batch_ingest_config or {}).get(
"use_multipart_endpoint", False
)

if use_multipart and self.tracing_queue is not None:
self.tracing_queue.put(
TracingQueueItem(str(feedback.id), "feedback", feedback)
)
feedback_source.metadata["__run"] = _run_meta
feedback = ls_schemas.FeedbackCreate(
id=_ensure_uuid(feedback_id),
# If run_id is None, this is interpreted as session-level
# feedback.
run_id=_ensure_uuid(run_id, accept_null=True),
key=key,
score=score,
value=value,
correction=correction,
comment=comment,
feedback_source=feedback_source,
created_at=datetime.datetime.now(datetime.timezone.utc),
modified_at=datetime.datetime.now(datetime.timezone.utc),
feedback_config=feedback_config,
session_id=_ensure_uuid(project_id, accept_null=True),
comparative_experiment_id=_ensure_uuid(
comparative_experiment_id, accept_null=True
),
feedback_group_id=_ensure_uuid(feedback_group_id, accept_null=True),
)
feedback_block = _dumps_json(feedback.dict(exclude_none=True))
self.request_with_retries(
"POST",
"/feedback",
request_kwargs={
"data": feedback_block,
},
stop_after_attempt=stop_after_attempt,
retry_on=(ls_utils.LangSmithNotFoundError,),
)
return ls_schemas.Feedback(**feedback.dict())
else:
self.request_with_retries(
"POST",
"/feedback",
request_kwargs={
"data": feedback_block,
},
stop_after_attempt=stop_after_attempt,
retry_on=(ls_utils.LangSmithNotFoundError,),
)
return ls_schemas.Feedback(**feedback.dict())
except Exception as e:
logger.error("Error creating feedback", exc_info=True)
raise e

def update_feedback(
self,
Expand Down Expand Up @@ -5797,9 +5858,12 @@ def _tracing_thread_handle_batch(
) -> None:
create = [it.item for it in batch if it.action == "create"]
update = [it.item for it in batch if it.action == "update"]
feedback = [it.item for it in batch if it.action == "feedback"]
try:
if use_multipart:
client.multipart_ingest_runs(create=create, update=update, pre_sampled=True)
client.multipart_ingest(
create=create, update=update, feedback=feedback, pre_sampled=True
)
else:
client.batch_ingest_runs(create=create, update=update, pre_sampled=True)
except Exception:
Expand Down
2 changes: 2 additions & 0 deletions python/langsmith/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,8 @@ class FeedbackBase(BaseModel):
"""The time the feedback was last modified."""
run_id: Optional[UUID]
"""The associated run ID this feedback is logged for."""
trace_id: Optional[UUID]
"""The associated trace ID this feedback is logged for."""
key: str
"""The metric name, tag, or aspect to provide feedback on."""
score: SCORE_TYPE = None
Expand Down
4 changes: 1 addition & 3 deletions python/tests/integration_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,9 +673,7 @@ def test_batch_ingest_runs(
},
]
if use_multipart_endpoint:
langchain_client.multipart_ingest_runs(
create=runs_to_create, update=runs_to_update
)
langchain_client.multipart_ingest(create=runs_to_create, update=runs_to_update)
else:
langchain_client.batch_ingest_runs(create=runs_to_create, update=runs_to_update)
runs = []
Expand Down
2 changes: 1 addition & 1 deletion python/tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,7 @@ def test_batch_ingest_run_splits_large_batches(
for run_id in patch_ids
]
if use_multipart_endpoint:
client.multipart_ingest_runs(create=posts, update=patches)
client.multipart_ingest(create=posts, update=patches)
# multipart endpoint should only send one request
expected_num_requests = 1
# count the number of POST requests
Expand Down

0 comments on commit df84a79

Please sign in to comment.