diff --git a/python/langsmith/_internal/_background_thread.py b/python/langsmith/_internal/_background_thread.py
index 525a3513c..3a468643f 100644
--- a/python/langsmith/_internal/_background_thread.py
+++ b/python/langsmith/_internal/_background_thread.py
@@ -69,9 +69,12 @@ def _tracing_thread_handle_batch(
 ) -> None:
     create = [it.item for it in batch if it.action == "create"]
     update = [it.item for it in batch if it.action == "update"]
+    feedback = [it.item for it in batch if it.action == "feedback"]
     try:
         if use_multipart:
-            client.multipart_ingest_runs(create=create, update=update, pre_sampled=True)
+            client.multipart_ingest(
+                create=create, update=update, feedback=feedback, pre_sampled=True
+            )
         else:
             client.batch_ingest_runs(create=create, update=update, pre_sampled=True)
     except Exception:
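For orientation, here is a minimal sketch of how the batch handler above partitions queue items by the new third action kind. The `QueueItem` stand-in and `partition_batch` helper are illustrative only (the real `TracingQueueItem` lives in langsmith's internals); only the three action strings are taken from this diff:

```python
# Illustrative only: mirrors the create/update/feedback partitioning in
# _tracing_thread_handle_batch. The real TracingQueueItem is internal to
# langsmith; this stand-in just carries the fields the diff relies on.
from typing import Any, List, NamedTuple


class QueueItem(NamedTuple):  # hypothetical stand-in
    priority: str
    action: str  # one of "create", "update", "feedback"
    item: Any


def partition_batch(batch: List[QueueItem]) -> dict:
    # One bucket per action, as in the list comprehensions above.
    return {
        action: [it.item for it in batch if it.action == action]
        for action in ("create", "update", "feedback")
    }


batch = [QueueItem("0", "create", {"id": 1}), QueueItem("1", "feedback", {"key": "q"})]
assert partition_batch(batch)["feedback"] == [{"key": "q"}]
```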
""" - if not create and not update: + if not (create or update or feedback): return # transform and convert to dicts all_attachments: Dict[str, ls_schemas.Attachments] = {} @@ -1454,6 +1483,7 @@ def multipart_ingest_runs( self._run_transform(run, update=True, attachments_collector=all_attachments) for run in update or EMPTY_SEQ ] + feedback_dicts = [self._feedback_transform(f) for f in feedback or EMPTY_SEQ] # require trace_id and dotted_order if create_dicts: for run in create_dicts: @@ -1491,7 +1521,7 @@ def multipart_ingest_runs( if not pre_sampled: create_dicts = self._filter_for_sampling(create_dicts) update_dicts = self._filter_for_sampling(update_dicts, patch=True) - if not create_dicts and not update_dicts: + if not create_dicts and not update_dicts and not feedback_dicts: return # insert runtime environment self._insert_runtime_env(create_dicts) @@ -1499,13 +1529,18 @@ def multipart_ingest_runs( # send the runs in multipart requests acc_context: List[str] = [] acc_parts: MultipartParts = [] - for event, payloads in (("post", create_dicts), ("patch", update_dicts)): + for event, payloads in ( + ("post", create_dicts), + ("patch", update_dicts), + ("feedback", feedback_dicts), + ): for payload in payloads: # collect fields to be sent as separate parts fields = [ ("inputs", payload.pop("inputs", None)), ("outputs", payload.pop("outputs", None)), ("events", payload.pop("events", None)), + ("feedback", payload.pop("feedback", None)), ] # encode the main run payload payloadb = _dumps_json(payload) @@ -4115,6 +4150,7 @@ def _submit_feedback(**kwargs): ), feedback_source_type=ls_schemas.FeedbackSourceType.MODEL, project_id=project_id, + trace_id=run.trace_id if run else None, ) return results @@ -4185,6 +4221,7 @@ def create_feedback( project_id: Optional[ID_TYPE] = None, comparative_experiment_id: Optional[ID_TYPE] = None, feedback_group_id: Optional[ID_TYPE] = None, + trace_id: Optional[ID_TYPE] = None, **kwargs: Any, ) -> ls_schemas.Feedback: """Create a feedback in the LangSmith API. @@ -4194,6 +4231,8 @@ def create_feedback( run_id : str or UUID The ID of the run to provide feedback for. Either the run_id OR the project_id must be provided. + trace_id : str or UUID + The trace ID of the run to provide feedback for. This is optional. key : str The name of the metric or 'aspect' this feedback is about. 
             score : float or int or bool or None, default=None
@@ -4115,6 +4150,7 @@ def _submit_feedback(**kwargs):
                 ),
                 feedback_source_type=ls_schemas.FeedbackSourceType.MODEL,
                 project_id=project_id,
+                trace_id=run.trace_id if run else None,
             )
         return results
 
@@ -4185,6 +4221,7 @@ def create_feedback(
         project_id: Optional[ID_TYPE] = None,
         comparative_experiment_id: Optional[ID_TYPE] = None,
         feedback_group_id: Optional[ID_TYPE] = None,
+        trace_id: Optional[ID_TYPE] = None,
         **kwargs: Any,
     ) -> ls_schemas.Feedback:
         """Create a feedback in the LangSmith API.
@@ -4194,6 +4231,8 @@ def create_feedback(
             run_id : str or UUID
                 The ID of the run to provide feedback for. Either the run_id OR
                 the project_id must be provided.
+            trace_id : str or UUID
+                The trace ID of the run to provide feedback for. This is optional.
             key : str
                 The name of the metric or 'aspect' this feedback is about.
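Downstream of this signature change, callers can pass the trace ID alongside the run ID. A minimal sketch (the key and score are placeholders):

```python
# Hypothetical caller: passing trace_id lets the client route this feedback
# through the background multipart queue (see the branch added below).
import uuid

from langsmith import Client

client = Client()
run_id = uuid.uuid4()
client.create_feedback(
    run_id=run_id,
    trace_id=run_id,  # for a root run, trace_id equals run_id
    key="correctness",
    score=1.0,
)
```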
@@ -4241,66 +4280,87 @@ def create_feedback(
                 f" endpoint: {sorted(kwargs)}",
                 DeprecationWarning,
             )
-        if not isinstance(feedback_source_type, ls_schemas.FeedbackSourceType):
-            feedback_source_type = ls_schemas.FeedbackSourceType(feedback_source_type)
-        if feedback_source_type == ls_schemas.FeedbackSourceType.API:
-            feedback_source: ls_schemas.FeedbackSourceBase = (
-                ls_schemas.APIFeedbackSource(metadata=source_info)
+        try:
+            if not isinstance(feedback_source_type, ls_schemas.FeedbackSourceType):
+                feedback_source_type = ls_schemas.FeedbackSourceType(
+                    feedback_source_type
+                )
+            if feedback_source_type == ls_schemas.FeedbackSourceType.API:
+                feedback_source: ls_schemas.FeedbackSourceBase = (
+                    ls_schemas.APIFeedbackSource(metadata=source_info)
+                )
+            elif feedback_source_type == ls_schemas.FeedbackSourceType.MODEL:
+                feedback_source = ls_schemas.ModelFeedbackSource(metadata=source_info)
+            else:
+                raise ValueError(f"Unknown feedback source type {feedback_source_type}")
+            feedback_source.metadata = (
+                feedback_source.metadata if feedback_source.metadata is not None else {}
             )
-        elif feedback_source_type == ls_schemas.FeedbackSourceType.MODEL:
-            feedback_source = ls_schemas.ModelFeedbackSource(metadata=source_info)
-        else:
-            raise ValueError(f"Unknown feedback source type {feedback_source_type}")
-        feedback_source.metadata = (
-            feedback_source.metadata if feedback_source.metadata is not None else {}
-        )
-        if source_run_id is not None and "__run" not in feedback_source.metadata:
-            feedback_source.metadata["__run"] = {"run_id": str(source_run_id)}
-        if feedback_source.metadata and "__run" in feedback_source.metadata:
-            # Validate that the linked run ID is a valid UUID
-            # Run info may be a base model or dict.
-            _run_meta: Union[dict, Any] = feedback_source.metadata["__run"]
-            if hasattr(_run_meta, "dict") and callable(_run_meta):
-                _run_meta = _run_meta.dict()
-            if "run_id" in _run_meta:
-                _run_meta["run_id"] = str(
-                    _as_uuid(
-                        feedback_source.metadata["__run"]["run_id"],
-                        "feedback_source.metadata['__run']['run_id']",
+            if source_run_id is not None and "__run" not in feedback_source.metadata:
+                feedback_source.metadata["__run"] = {"run_id": str(source_run_id)}
+            if feedback_source.metadata and "__run" in feedback_source.metadata:
+                # Validate that the linked run ID is a valid UUID
+                # Run info may be a base model or dict.
+                _run_meta: Union[dict, Any] = feedback_source.metadata["__run"]
+                if hasattr(_run_meta, "dict") and callable(_run_meta.dict):
+                    _run_meta = _run_meta.dict()
+                if "run_id" in _run_meta:
+                    _run_meta["run_id"] = str(
+                        _as_uuid(
+                            feedback_source.metadata["__run"]["run_id"],
+                            "feedback_source.metadata['__run']['run_id']",
+                        )
                     )
-                )
-            feedback_source.metadata["__run"] = _run_meta
-        feedback = ls_schemas.FeedbackCreate(
-            id=_ensure_uuid(feedback_id),
-            # If run_id is None, this is interpreted as session-level
-            # feedback.
-            run_id=_ensure_uuid(run_id, accept_null=True),
-            key=key,
-            score=score,
-            value=value,
-            correction=correction,
-            comment=comment,
-            feedback_source=feedback_source,
-            created_at=datetime.datetime.now(datetime.timezone.utc),
-            modified_at=datetime.datetime.now(datetime.timezone.utc),
-            feedback_config=feedback_config,
-            session_id=_ensure_uuid(project_id, accept_null=True),
-            comparative_experiment_id=_ensure_uuid(
-                comparative_experiment_id, accept_null=True
-            ),
-            feedback_group_id=_ensure_uuid(feedback_group_id, accept_null=True),
-        )
-        feedback_block = _dumps_json(feedback.dict(exclude_none=True))
-        self.request_with_retries(
-            "POST",
-            "/feedback",
-            request_kwargs={
-                "data": feedback_block,
-            },
-            stop_after_attempt=stop_after_attempt,
-            retry_on=(ls_utils.LangSmithNotFoundError,),
-        )
-        return ls_schemas.Feedback(**feedback.dict())
+                feedback_source.metadata["__run"] = _run_meta
+            feedback = ls_schemas.FeedbackCreate(
+                id=_ensure_uuid(feedback_id),
+                # If run_id is None, this is interpreted as session-level
+                # feedback.
+                run_id=_ensure_uuid(run_id, accept_null=True),
+                trace_id=_ensure_uuid(trace_id, accept_null=True),
+                key=key,
+                score=score,
+                value=value,
+                correction=correction,
+                comment=comment,
+                feedback_source=feedback_source,
+                created_at=datetime.datetime.now(datetime.timezone.utc),
+                modified_at=datetime.datetime.now(datetime.timezone.utc),
+                feedback_config=feedback_config,
+                session_id=_ensure_uuid(project_id, accept_null=True),
+                comparative_experiment_id=_ensure_uuid(
+                    comparative_experiment_id, accept_null=True
+                ),
+                feedback_group_id=_ensure_uuid(feedback_group_id, accept_null=True),
+            )
+
+            feedback_block = _dumps_json(feedback.dict(exclude_none=True))
+            use_multipart = (self.info.batch_ingest_config or {}).get(
+                "use_multipart_endpoint", False
+            )
+
+            if (
+                use_multipart
+                and self.tracing_queue is not None
+                and feedback.trace_id is not None
+            ):
+                self.tracing_queue.put(
+                    TracingQueueItem(str(feedback.id), "feedback", feedback)
                 )
+            else:
+                self.request_with_retries(
+                    "POST",
+                    "/feedback",
+                    request_kwargs={
+                        "data": feedback_block,
+                    },
+                    stop_after_attempt=stop_after_attempt,
+                    retry_on=(ls_utils.LangSmithNotFoundError,),
+                )
+            return ls_schemas.Feedback(**feedback.dict())
+        except Exception:
+            logger.error("Error creating feedback", exc_info=True)
+            raise
 
     def update_feedback(
         self,
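The behavioral core of the change above, condensed: feedback is enqueued rather than POSTed inline when the backend advertises the multipart endpoint, a tracing queue exists, and a trace_id is known. A standalone sketch with stand-in types (not the client's real internals):

```python
# Stand-alone restatement of the routing branch added in create_feedback.
from dataclasses import dataclass
from typing import Any, Optional
from uuid import UUID, uuid4


@dataclass
class FeedbackStub:  # hypothetical stand-in for ls_schemas.FeedbackCreate
    trace_id: Optional[UUID]


def route(fb: FeedbackStub, batch_config: Optional[dict], queue: Optional[Any]) -> str:
    use_multipart = (batch_config or {}).get("use_multipart_endpoint", False)
    if use_multipart and queue is not None and fb.trace_id is not None:
        return "queued"  # becomes TracingQueueItem(..., "feedback", fb)
    return "posted"  # synchronous POST /feedback with retries


assert route(FeedbackStub(uuid4()), {"use_multipart_endpoint": True}, object()) == "queued"
assert route(FeedbackStub(None), {"use_multipart_endpoint": True}, object()) == "posted"
assert route(FeedbackStub(uuid4()), None, object()) == "posted"
```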
authors = ["LangChain "] license = "MIT" diff --git a/python/tests/integration_tests/test_client.py b/python/tests/integration_tests/test_client.py index ad4b2cd93..059111e25 100644 --- a/python/tests/integration_tests/test_client.py +++ b/python/tests/integration_tests/test_client.py @@ -684,8 +684,19 @@ def test_batch_ingest_runs( }, ] if use_multipart_endpoint: - langchain_client.multipart_ingest_runs( - create=runs_to_create, update=runs_to_update + feedback = [ + { + "run_id": run["id"], + "trace_id": run["trace_id"], + "key": "test_key", + "score": 0.9, + "value": "test_value", + "comment": "test_comment", + } + for run in runs_to_create + ] + langchain_client.multipart_ingest( + create=runs_to_create, update=runs_to_update, feedback=feedback ) else: langchain_client.batch_ingest_runs(create=runs_to_create, update=runs_to_update) @@ -726,6 +737,17 @@ def test_batch_ingest_runs( assert run3.inputs == {"input1": 1, "input2": 2} assert run3.error == "error" + if use_multipart_endpoint: + feedbacks = list( + langchain_client.list_feedback(run_ids=[run.id for run in runs]) + ) + assert len(feedbacks) == 3 + for feedback in feedbacks: + assert feedback.key == "test_key" + assert feedback.score == 0.9 + assert feedback.value == "test_value" + assert feedback.comment == "test_comment" + """ Multipart partitions: @@ -744,7 +766,7 @@ def test_batch_ingest_runs( """ -def test_multipart_ingest_runs_empty( +def test_multipart_ingest_empty( langchain_client: Client, caplog: pytest.LogCaptureFixture ) -> None: runs_to_create: list[dict] = [] @@ -752,17 +774,15 @@ def test_multipart_ingest_runs_empty( # make sure no warnings logged with caplog.at_level(logging.WARNING, logger="langsmith.client"): - langchain_client.multipart_ingest_runs( - create=runs_to_create, update=runs_to_update - ) + langchain_client.multipart_ingest(create=runs_to_create, update=runs_to_update) assert not caplog.records -def test_multipart_ingest_runs_create_then_update( +def test_multipart_ingest_create_then_update( langchain_client: Client, caplog: pytest.LogCaptureFixture ) -> None: - _session = "__test_multipart_ingest_runs_create_then_update" + _session = "__test_multipart_ingest_create_then_update" trace_a_id = uuid4() current_time = datetime.datetime.now(datetime.timezone.utc).strftime( @@ -783,7 +803,7 @@ def test_multipart_ingest_runs_create_then_update( # make sure no warnings logged with caplog.at_level(logging.WARNING, logger="langsmith.client"): - langchain_client.multipart_ingest_runs(create=runs_to_create, update=[]) + langchain_client.multipart_ingest(create=runs_to_create, update=[]) assert not caplog.records @@ -796,15 +816,15 @@ def test_multipart_ingest_runs_create_then_update( } ] with caplog.at_level(logging.WARNING, logger="langsmith.client"): - langchain_client.multipart_ingest_runs(create=[], update=runs_to_update) + langchain_client.multipart_ingest(create=[], update=runs_to_update) assert not caplog.records -def test_multipart_ingest_runs_update_then_create( +def test_multipart_ingest_update_then_create( langchain_client: Client, caplog: pytest.LogCaptureFixture ) -> None: - _session = "__test_multipart_ingest_runs_update_then_create" + _session = "__test_multipart_ingest_update_then_create" trace_a_id = uuid4() current_time = datetime.datetime.now(datetime.timezone.utc).strftime( @@ -822,7 +842,7 @@ def test_multipart_ingest_runs_update_then_create( # make sure no warnings logged with caplog.at_level(logging.WARNING, logger="langsmith.client"): - langchain_client.multipart_ingest_runs(create=[], 
diff --git a/python/tests/integration_tests/test_client.py b/python/tests/integration_tests/test_client.py
index ad4b2cd93..059111e25 100644
--- a/python/tests/integration_tests/test_client.py
+++ b/python/tests/integration_tests/test_client.py
@@ -684,8 +684,19 @@ def test_batch_ingest_runs(
         },
     ]
     if use_multipart_endpoint:
-        langchain_client.multipart_ingest_runs(
-            create=runs_to_create, update=runs_to_update
+        feedback = [
+            {
+                "run_id": run["id"],
+                "trace_id": run["trace_id"],
+                "key": "test_key",
+                "score": 0.9,
+                "value": "test_value",
+                "comment": "test_comment",
+            }
+            for run in runs_to_create
+        ]
+        langchain_client.multipart_ingest(
+            create=runs_to_create, update=runs_to_update, feedback=feedback
         )
     else:
         langchain_client.batch_ingest_runs(create=runs_to_create, update=runs_to_update)
@@ -726,6 +737,17 @@ def test_batch_ingest_runs(
     assert run3.inputs == {"input1": 1, "input2": 2}
     assert run3.error == "error"
 
+    if use_multipart_endpoint:
+        feedbacks = list(
+            langchain_client.list_feedback(run_ids=[run.id for run in runs])
+        )
+        assert len(feedbacks) == 3
+        for feedback in feedbacks:
+            assert feedback.key == "test_key"
+            assert feedback.score == 0.9
+            assert feedback.value == "test_value"
+            assert feedback.comment == "test_comment"
+
 
 """
 Multipart partitions:
@@ -744,7 +766,7 @@
 """
 
 
-def test_multipart_ingest_runs_empty(
+def test_multipart_ingest_empty(
     langchain_client: Client, caplog: pytest.LogCaptureFixture
 ) -> None:
     runs_to_create: list[dict] = []
@@ -752,17 +774,15 @@ def test_multipart_ingest_runs_empty(
     runs_to_update: list[dict] = []
     # make sure no warnings logged
     with caplog.at_level(logging.WARNING, logger="langsmith.client"):
-        langchain_client.multipart_ingest_runs(
-            create=runs_to_create, update=runs_to_update
-        )
+        langchain_client.multipart_ingest(create=runs_to_create, update=runs_to_update)
 
     assert not caplog.records
 
 
-def test_multipart_ingest_runs_create_then_update(
+def test_multipart_ingest_create_then_update(
     langchain_client: Client, caplog: pytest.LogCaptureFixture
 ) -> None:
-    _session = "__test_multipart_ingest_runs_create_then_update"
+    _session = "__test_multipart_ingest_create_then_update"
 
     trace_a_id = uuid4()
     current_time = datetime.datetime.now(datetime.timezone.utc).strftime(
@@ -783,7 +803,7 @@ def test_multipart_ingest_runs_create_then_update(
 
     # make sure no warnings logged
     with caplog.at_level(logging.WARNING, logger="langsmith.client"):
-        langchain_client.multipart_ingest_runs(create=runs_to_create, update=[])
+        langchain_client.multipart_ingest(create=runs_to_create, update=[])
 
     assert not caplog.records
 
@@ -796,15 +816,15 @@ def test_multipart_ingest_runs_create_then_update(
         }
     ]
     with caplog.at_level(logging.WARNING, logger="langsmith.client"):
-        langchain_client.multipart_ingest_runs(create=[], update=runs_to_update)
+        langchain_client.multipart_ingest(create=[], update=runs_to_update)
 
     assert not caplog.records
 
 
-def test_multipart_ingest_runs_update_then_create(
+def test_multipart_ingest_update_then_create(
     langchain_client: Client, caplog: pytest.LogCaptureFixture
 ) -> None:
-    _session = "__test_multipart_ingest_runs_update_then_create"
+    _session = "__test_multipart_ingest_update_then_create"
 
     trace_a_id = uuid4()
     current_time = datetime.datetime.now(datetime.timezone.utc).strftime(
@@ -822,7 +842,7 @@ def test_multipart_ingest_runs_update_then_create(
 
     # make sure no warnings logged
     with caplog.at_level(logging.WARNING, logger="langsmith.client"):
-        langchain_client.multipart_ingest_runs(create=[], update=runs_to_update)
+        langchain_client.multipart_ingest(create=[], update=runs_to_update)
 
     assert not caplog.records
 
@@ -839,15 +859,15 @@ def test_multipart_ingest_runs_update_then_create(
     ]
 
     with caplog.at_level(logging.WARNING, logger="langsmith.client"):
-        langchain_client.multipart_ingest_runs(create=runs_to_create, update=[])
+        langchain_client.multipart_ingest(create=runs_to_create, update=[])
 
     assert not caplog.records
 
 
-def test_multipart_ingest_runs_create_wrong_type(
+def test_multipart_ingest_create_wrong_type(
     langchain_client: Client, caplog: pytest.LogCaptureFixture
 ) -> None:
-    _session = "__test_multipart_ingest_runs_create_then_update"
+    _session = "__test_multipart_ingest_create_wrong_type"
 
     trace_a_id = uuid4()
     current_time = datetime.datetime.now(datetime.timezone.utc).strftime(
@@ -868,7 +888,7 @@ def test_multipart_ingest_runs_create_wrong_type(
 
     # make sure no warnings logged
     with caplog.at_level(logging.WARNING, logger="langsmith.client"):
-        langchain_client.multipart_ingest_runs(create=runs_to_create, update=[])
+        langchain_client.multipart_ingest(create=runs_to_create, update=[])
 
     # this should 422
     assert len(caplog.records) == 1, "Should get 1 warning for 422, not retried"
diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py
index a3f15671b..d5e4b5cde 100644
--- a/python/tests/unit_tests/test_client.py
+++ b/python/tests/unit_tests/test_client.py
@@ -1059,8 +1059,20 @@ def test_batch_ingest_run_splits_large_batches(
         }
         for run_id in patch_ids
     ]
+
     if use_multipart_endpoint:
-        client.multipart_ingest_runs(create=posts, update=patches)
+        feedback = [
+            {
+                "run_id": run_id,
+                "trace_id": run_id,
+                "key": "test_key",
+                "score": 0.9,
+                "value": "test_value",
+                "comment": "test_comment",
+            }
+            for run_id in run_ids
+        ]
+        client.multipart_ingest(create=posts, update=patches, feedback=feedback)
         # multipart endpoint should only send one request
         expected_num_requests = 1
         # count the number of POST requests
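The hunk is cut off above. For reference, a sketch of the kind of assertion that typically follows, assuming a mocked session that records requests (the mock wiring is not part of this diff):

```python
# Hypothetical continuation-style check (not the literal test code): with a
# mocked session recording calls, the multipart path should produce a single
# POST no matter how many runs and feedback items were batched together.
from unittest.mock import MagicMock

mock_session = MagicMock()
# ... a Client wired to mock_session would go here ...
post_calls = [
    call
    for call in mock_session.request.call_args_list
    if call.args and call.args[0] == "POST"
]
assert len(post_calls) == 0  # becomes 1 after a single multipart_ingest call
```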