diff --git a/python/langsmith/client.py b/python/langsmith/client.py index beea86fdb..4835234d0 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -583,7 +583,10 @@ def _headers(self) -> Dict[str, str]: Dict[str, str] The headers for the API request. """ - headers = {"User-Agent": f"langsmith-py/{langsmith.__version__}"} + headers = { + "User-Agent": f"langsmith-py/{langsmith.__version__}", + "Accept": "application/json", + } if self.api_key: headers[X_API_KEY] = self.api_key return headers @@ -614,13 +617,16 @@ def info(self) -> ls_schemas.LangSmithInfo: def request_with_retries( self, - request_method: str, - url: str, - request_kwargs: Mapping, + /, + method: Literal["GET", "POST", "PUT", "PATCH", "DELETE"], + pathname: str, + *, + request_kwargs: Optional[Mapping] = None, stop_after_attempt: int = 1, retry_on: Optional[Sequence[Type[BaseException]]] = None, to_ignore: Optional[Sequence[Type[BaseException]]] = None, handle_response: Optional[Callable[[requests.Response, int], Any]] = None, + **kwargs: Any, ) -> requests.Response: """Send a request with retries. @@ -628,8 +634,8 @@ def request_with_retries( ---------- request_method : str The HTTP request method. - url : str - The URL to send the request to. + pathname : str + The pathname of the request URL. Will be appended to the API URL. request_kwargs : Mapping Additional request parameters. stop_after_attempt : int, default=1 @@ -642,6 +648,8 @@ def request_with_retries( handle_response : Callable[[requests.Response, int], Any] or None, default=None A function to handle the response and return whether to continue retrying. + **kwargs : Any + Additional keyword arguments to pass to the request. Returns: ------- @@ -659,6 +667,23 @@ def request_with_retries( LangSmithError If the request fails. """ + request_kwargs = request_kwargs or {} + request_kwargs = { + "headers": { + **self._headers, + **request_kwargs.get("headers", {}), + **kwargs.get("headers", {}), + }, + "timeout": self.timeout_ms / 1000, + **request_kwargs, + **kwargs, + } + if ( + method != "GET" + and "data" in request_kwargs + and not request_kwargs["headers"].get("Content-Type") + ): + request_kwargs["headers"]["Content-Type"] = "application/json" logging_filters = [ ls_utils.FilterLangSmithRetry(), ls_utils.FilterPoolFullWarning(host=str(self._host)), @@ -675,7 +700,14 @@ def request_with_retries( try: with ls_utils.filter_logs(_urllib3_logger, logging_filters): response = self.session.request( - request_method, url, stream=False, **request_kwargs + method, + ( + self.api_url + pathname + if not pathname.startswith("http") + else pathname + ), + stream=False, + **request_kwargs, ) ls_utils.raise_for_status_with_text(response) return response @@ -688,35 +720,35 @@ def request_with_retries( continue if response.status_code == 500: raise ls_utils.LangSmithAPIError( - f"Server error caused failure to {request_method}" - f" {url} in" + f"Server error caused failure to {method}" + f" {pathname} in" f" LangSmith API. {repr(e)}" ) elif response.status_code == 429: raise ls_utils.LangSmithRateLimitError( - f"Rate limit exceeded for {url}. {repr(e)}" + f"Rate limit exceeded for {pathname}. {repr(e)}" ) elif response.status_code == 401: raise ls_utils.LangSmithAuthError( - f"Authentication failed for {url}. {repr(e)}" + f"Authentication failed for {pathname}. {repr(e)}" ) elif response.status_code == 404: raise ls_utils.LangSmithNotFoundError( - f"Resource not found for {url}. {repr(e)}" + f"Resource not found for {pathname}. {repr(e)}" ) elif response.status_code == 409: raise ls_utils.LangSmithConflictError( - f"Conflict for {url}. {repr(e)}" + f"Conflict for {pathname}. {repr(e)}" ) else: raise ls_utils.LangSmithError( - f"Failed to {request_method} {url} in LangSmith" + f"Failed to {method} {pathname} in LangSmith" f" API. {repr(e)}" ) else: raise ls_utils.LangSmithUserError( - f"Failed to {request_method} {url} in LangSmith API." + f"Failed to {method} {pathname} in LangSmith API." f" {repr(e)}" ) except requests.ConnectionError as e: @@ -726,7 +758,7 @@ def request_with_retries( else "Please confirm your internet connection." ) raise ls_utils.LangSmithConnectionError( - f"Connection error caused failure to {request_method} {url}" + f"Connection error caused failure to {method} {pathname}" f" in LangSmith API. {recommendation}." f" {repr(e)}" ) from e @@ -738,7 +770,7 @@ def request_with_retries( [str(args[0])] + [msg] + [str(arg) for arg in args[2:]] ) raise ls_utils.LangSmithError( - f"Failed to {request_method} {url} in LangSmith API. {emsg}" + f"Failed to {method} {pathname} in LangSmith API. {emsg}" ) from e except to_ignore_ as e: if response is not None: @@ -753,20 +785,7 @@ def request_with_retries( continue raise ls_utils.LangSmithError( - f"Failed to {request_method} {url} in LangSmith API." - ) - - def _get_with_retries( - self, path: str, params: Optional[Dict[str, Any]] = None - ) -> requests.Response: - return self.request_with_retries( - "get", - f"{self.api_url}{path}", - request_kwargs={ - "params": params, - "headers": self._headers, - "timeout": self.timeout_ms / 1000, - }, + f"Failed to {method} {pathname} in LangSmith API." ) def _get_paginated_list( @@ -791,7 +810,7 @@ def _get_paginated_list( params_["limit"] = params_.get("limit", 100) while True: params_["offset"] = offset - response = self._get_with_retries(path, params=params_) + response = self.request_with_retries("GET", path, params=params_) items = response.json() if not items: @@ -808,7 +827,7 @@ def _get_cursor_paginated_list( path: str, *, body: Optional[dict] = None, - request_method: str = "post", + request_method: Literal["GET", "POST"] = "POST", data_key: str = "runs", ) -> Iterator[dict]: """Get a cursor paginated list of items. @@ -832,11 +851,9 @@ def _get_cursor_paginated_list( while True: response = self.request_with_retries( request_method, - f"{self.api_url}{path}", + path, request_kwargs={ "data": _dumps_json(params_), - "headers": self._headers, - "timeout": self.timeout_ms / 1000, }, ) response_body = response.json() @@ -1110,19 +1127,13 @@ def create_run( def _create_run(self, run_create: dict): for api_url, api_key in self._write_api_urls.items(): - headers = { - **self._headers, - "Accept": "application/json", - "Content-Type": "application/json", - X_API_KEY: api_key, - } + headers = {**self._headers, X_API_KEY: api_key} self.request_with_retries( - "post", + "POST", f"{api_url}/runs", request_kwargs={ "data": _dumps_json(run_create), "headers": headers, - "timeout": self.timeout_ms / 1000, }, to_ignore=(ls_utils.LangSmithConflictError,), ) @@ -1263,15 +1274,12 @@ def handle_429(response: requests.Response, attempt: int) -> bool: try: for api_url, api_key in self._write_api_urls.items(): self.request_with_retries( - "post", + "POST", f"{api_url}/runs/batch", request_kwargs={ "data": body, - "timeout": self.timeout_ms / 1000, "headers": { **self._headers, - "Accept": "application/json", - "Content-Type": "application/json", X_API_KEY: api_key, }, }, @@ -1355,18 +1363,15 @@ def _update_run(self, run_update: dict) -> None: for api_url, api_key in self._write_api_urls.items(): headers = { **self._headers, - "Accept": "application/json", - "Content-Type": "application/json", X_API_KEY: api_key, } self.request_with_retries( - "patch", + "PATCH", f"{api_url}/runs/{run_update['id']}", request_kwargs={ "data": _dumps_json(run_update), "headers": headers, - "timeout": self.timeout_ms / 1000, }, ) @@ -1423,7 +1428,9 @@ def read_run( Run The run. """ - response = self._get_with_retries(f"/runs/{_as_uuid(run_id, 'run_id')}") + response = self.request_with_retries( + "GET", f"/runs/{_as_uuid(run_id, 'run_id')}" + ) run = ls_schemas.Run(**response.json(), _host_url=self._host_url) if load_child_runs and run.child_run_ids: run = self._load_child_runs(run) @@ -1618,9 +1625,7 @@ def list_runs( } body_query = {k: v for k, v in body_query.items() if v is not None} for i, run in enumerate( - self._get_cursor_paginated_list( - "/runs/query", body=body_query, request_method="post" - ) + self._get_cursor_paginated_list("/runs/query", body=body_query) ): yield ls_schemas.Run(**run, _host_url=self._host_url) if limit is not None and i + 1 >= limit: @@ -1979,7 +1984,9 @@ def _get_optional_tenant_id(self) -> Optional[uuid.UUID]: if self._tenant_id is not None: return self._tenant_id try: - response = self._get_with_retries("/sessions", params={"limit": 1}) + response = self.request_with_retries( + "GET", "/sessions", params={"limit": 1} + ) result = response.json() if isinstance(result, list) and len(result) > 0: tracer_session = ls_schemas.TracerSessionResult( @@ -2033,7 +2040,7 @@ def read_project( else: raise ValueError("Must provide project_name or project_id") params["include_stats"] = include_stats - response = self._get_with_retries(path, params=params) + response = self.request_with_retries("GET", path, params=params) result = response.json() if isinstance(result, list): if len(result) == 0: @@ -2316,7 +2323,8 @@ def read_dataset( params["name"] = dataset_name else: raise ValueError("Must provide dataset_name or dataset_id") - response = self._get_with_retries( + response = self.request_with_retries( + "GET", path, params=params, ) @@ -2436,7 +2444,8 @@ def read_dataset_openai_finetuning( dataset_id = self.read_dataset(dataset_name=dataset_name).id else: raise ValueError("Must provide dataset_name or dataset_id") - response = self._get_with_retries( + response = self.request_with_retries( + "GET", f"{path}/{_as_uuid(dataset_id, 'dataset_id')}/openai_ft", ) dataset = [json.loads(line) for line in response.text.strip().split("\n")] @@ -2641,7 +2650,8 @@ def read_dataset_version( dataset_id = self.read_dataset(dataset_name=dataset_name).id if (as_of and tag) or (as_of is None and tag is None): raise ValueError("Exactly one of as_of and tag must be specified.") - response = self._get_with_retries( + response = self.request_with_retries( + "GET", f"/datasets/{_as_uuid(dataset_id, 'dataset_id')}/version", params={"as_of": as_of, "tag": tag}, ) @@ -2997,7 +3007,8 @@ def read_example( Returns: Example: The example. """ - response = self._get_with_retries( + response = self.request_with_retries( + "GET", f"/examples/{_as_uuid(example_id, 'example_id')}", params={ "as_of": as_of.isoformat() if as_of else None, @@ -3451,15 +3462,9 @@ def create_feedback( ) self.request_with_retries( "POST", - self.api_url + "/feedback", + "/feedback", request_kwargs={ "data": _dumps_json(feedback.dict(exclude_none=True)), - "headers": { - **self._headers, - "Content-Type": "application/json", - "Accept": "application/json", - }, - "timeout": self.timeout_ms / 1000, }, stop_after_attempt=stop_after_attempt, retry_on=(ls_utils.LangSmithNotFoundError,), @@ -3519,7 +3524,8 @@ def read_feedback(self, feedback_id: ID_TYPE) -> ls_schemas.Feedback: Feedback The feedback. """ - response = self._get_with_retries( + response = self.request_with_retries( + "GET", f"/feedback/{_as_uuid(feedback_id, 'feedback_id')}", ) return ls_schemas.Feedback(**response.json()) @@ -3688,12 +3694,9 @@ def create_presigned_feedback_token( raise ValueError(f"Unknown expiration type: {type(expiration)}") response = self.request_with_retries( - "post", - f"{self.api_url}/feedback/tokens", - { - "data": _dumps_json(body), - "headers": self._headers, - }, + "POST", + "/feedback/tokens", + data=_dumps_json(body), ) ls_utils.raise_for_status_with_text(response) return ls_schemas.FeedbackIngestToken(**response.json()) @@ -3798,12 +3801,9 @@ def create_annotation_queue( "id": queue_id, } response = self.request_with_retries( - "post", - f"{self.api_url}/annotation-queues", - { - "json": {k: v for k, v in body.items() if v is not None}, - "headers": self._headers, - }, + "POST", + "/annotation-queues", + json={k: v for k, v in body.items() if v is not None}, ) ls_utils.raise_for_status_with_text(response) return ls_schemas.AnnotationQueue( @@ -3836,14 +3836,11 @@ def update_annotation_queue( annotation queue. Defaults to None. """ response = self.request_with_retries( - "patch", - f"{self.api_url}/annotation-queues/{_as_uuid(queue_id, 'queue_id')}", - { - "json": { - "name": name, - "description": description, - }, - "headers": self._headers, + "PATCH", + f"/annotation-queues/{_as_uuid(queue_id, 'queue_id')}", + json={ + "name": name, + "description": description, }, ) ls_utils.raise_for_status_with_text(response) @@ -3871,14 +3868,9 @@ def add_runs_to_annotation_queue( queue. """ response = self.request_with_retries( - "post", - f"{self.api_url}/annotation-queues/{_as_uuid(queue_id, 'queue_id')}/runs", - { - "json": [ - str(_as_uuid(id_, f"run_ids[{i}]")) for i, id_ in enumerate(run_ids) - ], - "headers": self._headers, - }, + "POST", + f"/annotation-queues/{_as_uuid(queue_id, 'queue_id')}/runs", + json=[str(_as_uuid(id_, f"run_ids[{i}]")) for i, id_ in enumerate(run_ids)], ) ls_utils.raise_for_status_with_text(response) diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py index 476750f12..b64e5b982 100644 --- a/python/tests/unit_tests/test_client.py +++ b/python/tests/unit_tests/test_client.py @@ -358,7 +358,7 @@ def mock_get(*args, **kwargs): assert len(request_calls) >= 1 for call in request_calls: - assert call.args[0] == "post" + assert call.args[0] == "POST" assert call.args[1] == "http://localhost:1984/runs/batch" get_calls = [call for call in session.get.mock_calls if call.args] # assert len(get_calls) == 1 @@ -369,7 +369,7 @@ def mock_get(*args, **kwargs): assert len(request_calls) == 10 for call in request_calls: - assert call.args[0] == "post" + assert call.args[0] == "POST" assert call.args[1] == "http://localhost:1984/runs" if auto_batch_tracing: get_calls = [call for call in session.get.mock_calls if call.args] @@ -482,7 +482,7 @@ def test_client_gc_after_autoscale() -> None: request_calls = [call for call in session.request.mock_calls if call.args] assert len(request_calls) >= 500 and len(request_calls) <= 550 for call in request_calls: - assert call.args[0] == "post" + assert call.args[0] == "POST" assert call.args[1] == "http://localhost:1984/runs/batch" @@ -796,7 +796,7 @@ def test_retry_on_connection_error(mock_sleep: MagicMock): mock_session.request.side_effect = requests.ConnectionError() with pytest.raises(ls_utils.LangSmithConnectionError): - client.request_with_retries("GET", "https://test.url", {}, stop_after_attempt=2) + client.request_with_retries("GET", "https://test.url", stop_after_attempt=2) assert mock_session.request.call_count == 2 @@ -810,7 +810,7 @@ def test_http_status_500_handling(mock_sleep): mock_session.request.return_value = mock_response with pytest.raises(ls_utils.LangSmithAPIError): - client.request_with_retries("GET", "https://test.url", {}, stop_after_attempt=2) + client.request_with_retries("GET", "https://test.url", stop_after_attempt=2) assert mock_session.request.call_count == 2 @@ -826,7 +826,6 @@ def test_pass_on_409_handling(mock_sleep): response = client.request_with_retries( "GET", "https://test.url", - {}, stop_after_attempt=5, to_ignore=[ls_utils.LangSmithConflictError], ) @@ -843,7 +842,7 @@ def test_http_status_429_handling(mock_raise_for_status): mock_session.request.return_value = mock_response mock_raise_for_status.side_effect = HTTPError() with pytest.raises(ls_utils.LangSmithRateLimitError): - client.request_with_retries("GET", "https://test.url", {}) + client.request_with_retries("GET", "https://test.url") @patch("langsmith.client.ls_utils.raise_for_status_with_text") @@ -855,7 +854,7 @@ def test_http_status_401_handling(mock_raise_for_status): mock_session.request.return_value = mock_response mock_raise_for_status.side_effect = HTTPError() with pytest.raises(ls_utils.LangSmithAuthError): - client.request_with_retries("GET", "https://test.url", {}) + client.request_with_retries("GET", "https://test.url") @patch("langsmith.client.ls_utils.raise_for_status_with_text") @@ -867,7 +866,7 @@ def test_http_status_404_handling(mock_raise_for_status): mock_session.request.return_value = mock_response mock_raise_for_status.side_effect = HTTPError() with pytest.raises(ls_utils.LangSmithNotFoundError): - client.request_with_retries("GET", "https://test.url", {}) + client.request_with_retries("GET", "https://test.url") @patch("langsmith.client.ls_utils.raise_for_status_with_text") @@ -894,7 +893,7 @@ def test_batch_ingest_run_retry_on_429(mock_raise_for_status): assert mock_session.request.call_count >= 3 # count the number of POST requests assert ( - sum([1 for call in mock_session.request.call_args_list if call[0][0] == "post"]) + sum([1 for call in mock_session.request.call_args_list if call[0][0] == "POST"]) == 3 ) @@ -941,7 +940,7 @@ def test_batch_ingest_run_splits_large_batches(payload_size: int): expected_num_requests = min(6, math.ceil((len(run_ids) * 2) / max_in_batch)) # count the number of POST requests assert ( - sum([1 for call in mock_session.request.call_args_list if call[0][0] == "post"]) + sum([1 for call in mock_session.request.call_args_list if call[0][0] == "POST"]) == expected_num_requests ) request_bodies = [ diff --git a/python/tests/unit_tests/test_run_helpers.py b/python/tests/unit_tests/test_run_helpers.py index 599d14a20..0abea6647 100644 --- a/python/tests/unit_tests/test_run_helpers.py +++ b/python/tests/unit_tests/test_run_helpers.py @@ -197,7 +197,7 @@ def my_iterator_fn(a, b, d): assert 1 <= len(mock_calls) <= 2 call = mock_calls[0] - assert call.args[0] == "post" + assert call.args[0] == "POST" assert call.args[1].startswith("https://api.smith.langchain.com") body = json.loads(mock_calls[0].kwargs["data"]) assert body["post"] @@ -232,7 +232,7 @@ async def my_iterator_fn(a, b, d): assert 1 <= len(mock_calls) <= 2 call = mock_calls[0] - assert call.args[0] == "post" + assert call.args[0] == "POST" assert call.args[1].startswith("https://api.smith.langchain.com") body = json.loads(call.kwargs["data"]) assert body["post"] @@ -333,7 +333,7 @@ def my_function(a: int, b: int, d: int) -> int: mock_calls = mock_client_.session.request.mock_calls # type: ignore assert 1 <= len(mock_calls) <= 2 call = mock_calls[0] - assert call.args[0] == "post" + assert call.args[0] == "POST" assert call.args[1].startswith("https://api.smith.langchain.com") body = json.loads(call.kwargs["data"]) assert body["post"] @@ -354,7 +354,7 @@ def my_other_function(run_tree) -> int: mock_calls = mock_client_.session.request.mock_calls # type: ignore assert 1 <= len(mock_calls) <= 2 call = mock_calls[0] - assert call.args[0] == "post" + assert call.args[0] == "POST" assert call.args[1].startswith("https://api.smith.langchain.com") body = json.loads(call.kwargs["data"]) assert body["post"]