Add model_dir arg to testing functions (#22)
faph authored Sep 19, 2024
2 parents 627a9cf + 3aa8fa3 commit 69ef77b
Showing 6 changed files with 86 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/deployment.rst
@@ -18,6 +18,7 @@ like this:
COPY entrypoint.sh /usr/local/bin/
RUN python -m pip install \
gunicorn \
inference-server \
shipping-forecast # Our package implementing the hooks
EXPOSE 8080
15 changes: 14 additions & 1 deletion docs/testing.rst
@@ -39,6 +39,20 @@ Here we can use any serializer compatible with :mod:`sagemaker.serializers` and

If no serializer or deserializer is configured, bytes data are passed through as is for both input and output.

:func:`inference_server.testing.predict` accepts a ``model_dir`` argument which can be used to set the directory
containing the model artifacts to be loaded. At runtime, this directory is always :file:`/opt/ml/model`. In our tests,
we may want to create model artifacts on the fly, for example in a temporary directory using a pytest fixture, like
this::

    import pathlib

    import pytest

    @pytest.fixture
    def model_artifacts_dir(tmp_path) -> pathlib.Path:
        dir_ = tmp_path / "model"
        dir_.mkdir()
        # instantiate a model object and serialize as 1 or more files to the directory
        ...
        return dir_
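
We can then request this fixture in a test and pass the directory to :func:`inference_server.testing.predict`. A
minimal sketch (the input bytes and the assertion are placeholders for whatever your model consumes and returns)::

    import inference_server.testing

    def test_predict_with_model_artifacts(model_artifacts_dir):
        prediction = inference_server.testing.predict(b"input data", model_dir=model_artifacts_dir)
        assert prediction is not None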


Testing model predictions (low-level API)
-----------------------------------------
@@ -63,7 +77,6 @@ Instead of using the high-level testing API, we can also use invoke requests sim
assert response.json() == expected_prediction



Verifying plugin registration
-----------------------------

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -67,6 +67,7 @@ dependencies = [
[project.optional-dependencies]

docs = [
    "pytest",  # Because we import this in inference_server.testing
    "sphinx",
    "sphinx-rtd-theme",
]
@@ -81,6 +82,7 @@ linting = [
    "isort",
    "mypy",
    "pre-commit",
    "pytest",  # Because we import this in inference_server.testing
]


3 changes: 2 additions & 1 deletion src/inference_server/_plugin.py
@@ -46,7 +46,8 @@ def model_fn(model_dir: str) -> ModelType:
    This function will be called when the server starts up. Here, ``ModelType`` can be any Python class corresponding to
    the model, for example :class:`sklearn.tree.DecisionTreeClassifier`.
    :param model_dir: Local filesystem directory containing the model files
    :param model_dir: Local filesystem directory containing the model files. This is always :file:`/opt/ml/model` when
        invoked by **inference-server**.
    """
    raise NotImplementedError
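
# For illustration only, a hook implementation in your own package could look roughly like
# the sketch below. This is not part of inference-server: the artifact name "model.joblib"
# and the use of joblib are assumptions for the example.
#
#     import pathlib
#
#     import joblib
#     import inference_server
#
#     @inference_server.plugin_hook()
#     def model_fn(model_dir: str):
#         # Load the serialized model artifact(s) from model_dir
#         return joblib.load(pathlib.Path(model_dir) / "model.joblib")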

24 changes: 19 additions & 5 deletions src/inference_server/testing.py
@@ -14,11 +14,13 @@
"""

import io
import pathlib
from types import ModuleType
from typing import Any, Callable, Optional, Protocol, Tuple, Type, Union

import botocore.response # type: ignore[import-untyped]
import pluggy
import pytest
import werkzeug.test

import inference_server
@@ -79,12 +81,18 @@ def deserialize(self, stream: "botocore.response.StreamingBody", content_type: s


def predict(
    data: Any, serializer: Optional[ImplementsSerialize] = None, deserializer: Optional[ImplementsDeserialize] = None
    data: Any,
    *,
    model_dir: Optional[pathlib.Path] = None,
    serializer: Optional[ImplementsSerialize] = None,
    deserializer: Optional[ImplementsDeserialize] = None,
) -> Any:
    """
    Invoke the model and return a prediction
    :param data: Model input data
    :param model_dir: Optional. A custom model directory to load the model from. Default is
        :file:`/opt/ml/model/`.
    :param serializer: Optional. A serializer for sending the data as bytes to the model server. Should be compatible
        with :class:`sagemaker.serializers.BaseSerializer`. Default: bytes pass-through.
    :param deserializer: Optional. A deserializer for processing the prediction as sent by the model server. Should be
@@ -98,7 +106,7 @@ def predict(
        "Content-Type": serializer.CONTENT_TYPE,  # The serializer declares the content-type of the input data
        "Accept": ", ".join(deserializer.ACCEPT),  # The deserializer dictates the content-type of the prediction
    }
    prediction_response = post_invocations(data=serialized_data, headers=http_headers)
    prediction_response = post_invocations(model_dir=model_dir, data=serialized_data, headers=http_headers)
    prediction_stream = botocore.response.StreamingBody(
        raw_stream=io.BytesIO(prediction_response.data),
        content_length=prediction_response.content_length,
@@ -117,15 +125,21 @@ def client() -> werkzeug.test.Client:
    return werkzeug.test.Client(inference_server.create_app())


def post_invocations(**kwargs) -> werkzeug.test.TestResponse:
def post_invocations(*, model_dir: Optional[pathlib.Path] = None, **kwargs) -> werkzeug.test.TestResponse:
    """
    Send an HTTP POST request to ``/invocations`` using a test HTTP client and return the response
    This function should be used to verify an inference request using the full **inference-server** logic.
    :param kwargs: Keyword arguments passed to :meth:`werkzeug.test.Client.post`
    :param model_dir: Optional. A custom model directory to load the model from. Default is :file:`/opt/ml/model/`.
    :param kwargs: Keyword arguments passed to :meth:`werkzeug.test.Client.post`
    """
    response = client().post("/invocations", **kwargs)
    # pytest should be available when we are using inference_server.testing
    with pytest.MonkeyPatch.context() as monkeypatch:
        if model_dir:
            monkeypatch.setattr(inference_server, "_MODEL_DIR", str(model_dir))
        response = client().post("/invocations", **kwargs)

    assert response.status_code == 200
    return response

49 changes: 48 additions & 1 deletion tests/test_inference_server.py
@@ -8,7 +8,7 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import pathlib
from typing import Tuple

import botocore.response
@@ -22,6 +22,14 @@ def test_package_has_version():
    assert inference_server.__version__ is not None


@pytest.fixture(autouse=True)
def reset_caches():
    try:
        yield
    finally:
        inference_server._model.cache_clear()


@pytest.fixture
def client():
    return inference_server.testing.client()
@@ -46,6 +54,26 @@ def ping_fn(model):
pm.unregister(PingPlugin)


@pytest.fixture
def model_using_dir():
    class ModelPlugin:
        """Plugin which just defines a model_fn"""

        @staticmethod
        @inference_server.plugin_hook()
        def model_fn(model_dir: str):
            """Model function for testing we are passing a custom directory"""
            assert model_dir != "/opt/ml/model"
            return lambda data: data

    pm = inference_server.testing.plugin_manager()
    pm.register(ModelPlugin)
    try:
        yield
    finally:
        pm.unregister(ModelPlugin)


def test_version():
    """Test that the package has a version"""
    assert inference_server.__version__ is not None
@@ -80,6 +108,17 @@ def test_invocations():
    assert response.headers["Content-Type"] == "application/octet-stream"


def test_invocations_custom_model_dir(model_using_dir):
    """Test the default plugin (which passes through any input bytes) using low-level testing.post_invocations"""
    data = b"What's the shipping forecast for tomorrow"
    model_dir = pathlib.Path(__file__).parent

    response = inference_server.testing.post_invocations(
        data=data, model_dir=model_dir, headers={"Accept": "application/octet-stream"}
    )
    assert response.data == data


def test_prediction_custom_serializer():
    """Test the default plugin again, now using high-level testing.predict"""

@@ -115,6 +154,14 @@ def test_prediction_no_serializer():
    assert prediction == input_data


def test_prediction_model_dir(model_using_dir):
    input_data = b"What's the shipping forecast for tomorrow"
    model_dir = pathlib.Path(__file__).parent

    prediction = inference_server.testing.predict(input_data, model_dir=model_dir)
    assert prediction == input_data


def test_execution_parameters(client):
    response = client.get("/execution-parameters")
    assert response.data == b'{"BatchStrategy":"MultiRecord","MaxConcurrentTransforms":1,"MaxPayloadInMB":6}'
