Skip to content

Commit

Permalink
bug fix in inference_endpoint wait function for proper waiting on upd…
Browse files Browse the repository at this point in the history
…ate (#2867)

* bug fix in inference_endpoint wait for proper waiting on update

* Update src/huggingface_hub/_inference_endpoints.py

improve code clarity and added logging based on review

Co-authored-by: Célina <[email protected]>

* changes in inference_endpoint wait function for robust behaviour and addition of test case in test_inference_endpoint for testing changes in wait function

* changes in test case test_wait_update

---------

Co-authored-by: Célina <[email protected]>
  • Loading branch information
Ajinkya-25 and hanouticelina authored Feb 20, 2025
1 parent c8bbf54 commit 6456491
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 8 deletions.
17 changes: 11 additions & 6 deletions src/huggingface_hub/_inference_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,16 +207,21 @@ def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "Infere

start = time.time()
while True:
if self.url is not None:
# Means the URL is provisioned => check if the endpoint is reachable
response = get_session().get(self.url, headers=self._api._build_hf_headers(token=self._token))
if response.status_code == 200:
logger.info("Inference Endpoint is ready to be used.")
return self
if self.status == InferenceEndpointStatus.FAILED:
raise InferenceEndpointError(
f"Inference Endpoint {self.name} failed to deploy. Please check the logs for more information."
)
if self.status == InferenceEndpointStatus.UPDATE_FAILED:
raise InferenceEndpointError(
f"Inference Endpoint {self.name} failed to update. Please check the logs for more information."
)
if self.status == InferenceEndpointStatus.RUNNING and self.url is not None:
# Verify the endpoint is actually reachable
response = get_session().get(self.url, headers=self._api._build_hf_headers(token=self._token))
if response.status_code == 200:
logger.info("Inference Endpoint is ready to be used.")
return self

if timeout is not None:
if time.time() - start > timeout:
raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.")
Expand Down
59 changes: 57 additions & 2 deletions tests/test_inference_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime, timezone
from unittest.mock import Mock, patch
from itertools import chain, repeat
from unittest.mock import MagicMock, Mock, patch

import pytest

Expand Down Expand Up @@ -109,6 +110,39 @@
"targetReplica": 1,
},
}
# Mock API payload used by test_wait_update: an endpoint caught mid-update
# (state "updating") that already has a provisioned URL. wait() must keep
# polling rather than return early just because the URL is reachable.
MOCK_UPDATE = {
    "name": "my-endpoint-name",
    "type": "protected",
    "accountId": None,
    "provider": {"vendor": "aws", "region": "us-east-1"},
    "compute": {
        "accelerator": "cpu",
        "instanceType": "intel-icl",
        "instanceSize": "x2",
        "scaling": {"minReplica": 0, "maxReplica": 1},
    },
    "model": {
        "repository": "gpt2",
        "revision": "11c5a3d5811f50298f278a704980280950aedb10",
        "task": "text-generation",
        "framework": "pytorch",
        "image": {"huggingface": {}},
        "secret": {"token": "my-token"},
    },
    "status": {
        "createdAt": "2023-10-26T12:41:53.263078506Z",
        "createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"},
        "updatedAt": "2023-10-26T12:41:53.263079138Z",
        "updatedBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"},
        "private": None,
        # Key difference from MOCK_RUNNING: state is "updating" while the URL
        # is already set, so a naive URL-only readiness check would pass.
        "state": "updating",
        "url": "https://vksrvs8pc1xnifhq.us-east-1.aws.endpoints.huggingface.cloud",
        "message": "Endpoint waiting for the update",
        "readyReplica": 0,
        "targetReplica": 1,
    },
}


def test_from_raw_initialization():
Expand Down Expand Up @@ -189,7 +223,7 @@ def test_fetch(mock_get: Mock):
@patch("huggingface_hub._inference_endpoints.get_session")
@patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint")
def test_wait_until_running(mock_get: Mock, mock_session: Mock):
"""Test waits waits until the endpoint is ready."""
"""Test waits until the endpoint is ready."""
endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo")

mock_get.side_effect = [
Expand Down Expand Up @@ -244,6 +278,27 @@ def test_wait_failed(mock_get: Mock):
endpoint.wait(refresh_every=0.001)


@patch("huggingface_hub.hf_api.HfApi.get_inference_endpoint")
@patch("huggingface_hub._inference_endpoints.get_session")
def test_wait_update(mock_get_session, mock_get_inference_endpoint):
    """Test that wait() returns when the endpoint transitions to running."""
    endpoint = InferenceEndpoint.from_raw(MOCK_INITIALIZING, namespace="foo")

    # The fetched endpoint reports "updating" three times, then "running"
    # forever after: wait() must keep polling through the update phase.
    updating_states = [InferenceEndpoint.from_raw(MOCK_UPDATE, namespace="foo") for _ in range(3)]
    running_states = repeat(InferenceEndpoint.from_raw(MOCK_RUNNING, namespace="foo"))
    state_iterator = chain(updating_states, running_states)
    mock_get_inference_endpoint.side_effect = lambda *args, **kwargs: next(state_iterator)

    # The readiness HTTP probe always answers 200, so the only thing gating
    # wait() is the reported endpoint state.
    ok_response = MagicMock()
    ok_response.status_code = 200
    mock_get_session.return_value.get.return_value = ok_response

    endpoint.wait(refresh_every=0.05)
    assert endpoint.status == "running"


@patch("huggingface_hub.hf_api.HfApi.pause_inference_endpoint")
def test_pause(mock: Mock):
"""Test `pause` calls the correct alias."""
Expand Down

0 comments on commit 6456491

Please sign in to comment.