Skip to content

Commit

Permalink
FastAPI model changes
Browse files Browse the repository at this point in the history
Signed-off-by: shubh chaurasia <[email protected]>
  • Loading branch information
shubh0508 committed Apr 26, 2022
1 parent 2dc2f70 commit 9ce8130
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 39 deletions.
93 changes: 54 additions & 39 deletions mlflow/pyfunc/scoring_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
import pandas as pd
import sys
import traceback
from pydantic import BaseModel
from fastapi import FastAPI, APIRouter, Request, HTTPException, Response, Header, status
from typing import List, Optional, Dict
import uvicorn
import asyncio
import json

# NB: We need to be careful what we import from mlflow here. Scoring server is used from within
# model's conda environment. The version of mlflow doing the serving (outside) and the version of
Expand Down Expand Up @@ -65,13 +71,24 @@

CONTENT_TYPE_FORMAT_RECORDS_ORIENTED = "pandas-records"
CONTENT_TYPE_FORMAT_SPLIT_ORIENTED = "pandas-split"
CONTENT_TYPE_RAW_JSON = "raw-json"

FORMATS = [CONTENT_TYPE_FORMAT_RECORDS_ORIENTED, CONTENT_TYPE_FORMAT_SPLIT_ORIENTED]
FORMATS = [CONTENT_TYPE_FORMAT_RECORDS_ORIENTED, CONTENT_TYPE_FORMAT_SPLIT_ORIENTED, CONTENT_TYPE_RAW_JSON]

PREDICTIONS_WRAPPER_ATTR_NAME_ENV_KEY = "PREDICTIONS_WRAPPER_ATTR_NAME"

_logger = logging.getLogger(__name__)

class RequestData(BaseModel):
    """Request payload for the /invocations endpoint.

    Mirrors the pandas "split" orientation: a list of column names plus a
    list of data rows. Both default to empty (pydantic copies field
    defaults per instance, so the mutable defaults are safe here).
    """

    # Column names for the incoming records.
    columns: List[str] = []
    # Row-oriented payload; each element is one record's values.
    data: list = []

    def is_valid(self):
        """Placeholder validation hook; currently accepts every payload."""
        return True

    def get_dataframe(self):
        """Build a pandas DataFrame from the stored rows and column names."""
        return pd.DataFrame(data=self.data, columns=self.columns)

def infer_and_parse_json_input(json_input, schema: Schema = None):
"""
Expand Down Expand Up @@ -205,38 +222,38 @@ def _handle_serving_error(error_message, error_code, include_traceback=True):
e = MlflowException(message=error_message, error_code=error_code)
reraise(MlflowException, e)


def init(model: PyFuncModel):

"""
Initialize the server. Loads pyfunc model from the path.
"""
app = flask.Flask(__name__)
fast_app = FastAPI(title= __name__, version= "v1")
fast_app.include_router(APIRouter())
input_schema = model.metadata.get_input_schema()

@app.route("/ping", methods=["GET"])
@fast_app.get("/ping")
def ping(): # pylint: disable=unused-variable
"""
Determine if the container is working and healthy.
We declare it healthy if we can load the model successfully.
"""
health = model is not None
status = 200 if health else 404
return flask.Response(response="\n", status=status, mimetype="application/json")
if model is None:
raise HTTPException(status_code=404, detail="Model not loaded properly")
return {"message": "OK"}

@app.route("/invocations", methods=["POST"])
@catch_mlflow_exception
def transformation(): # pylint: disable=unused-variable
@fast_app.post("/invocations")
def transformation(request_data: RequestData, content_type: Optional[str] = Header(None)): # pylint: disable=unused-variable
"""
Do an inference on a single batch of data. In this sample server,
we take data as CSV or json, convert it to a Pandas DataFrame or Numpy,
generate predictions and convert them back to json.
"""
# data = _dataframe_from_json(request_data.json())

# Content-Type can include other attributes like CHARSET
# Content-type RFC: https://datatracker.ietf.org/doc/html/rfc2045#section-5.1
# TODO: Support ";" in quoted parameter values
type_parts = flask.request.content_type.split(";")
type_parts = content_type.split(";")
type_parts = list(map(str.strip, type_parts))
mime_type = type_parts[0]
parameter_value_pairs = type_parts[1:]
Expand All @@ -247,27 +264,31 @@ def transformation(): # pylint: disable=unused-variable

charset = parameter_values.get("charset", "utf-8").lower()
if charset != "utf-8":
return flask.Response(
response="The scoring server only supports UTF-8",
status=415,
mimetype="text/plain",
return Response(
content="The scoring server only supports UTF-8",
status_code=415,
media_type="text/plain"
)

content_format = parameter_values.get("format")

# Convert from CSV to pandas
if mime_type == CONTENT_TYPE_CSV and not content_format:
data = flask.request.data.decode("utf-8")
data = request_data.json()
csv_input = StringIO(data)
data = parse_csv_input(csv_input=csv_input)
elif mime_type == CONTENT_TYPE_JSON and content_format == CONTENT_TYPE_RAW_JSON:
if len(request_data.data) != 0:
data = dict(zip(request_data.columns, request_data.data[0]))
else:
data = {}
elif mime_type == CONTENT_TYPE_JSON and not content_format:
json_str = flask.request.data.decode("utf-8")
data = infer_and_parse_json_input(json_str, input_schema)
data = infer_and_parse_json_input(request_data.json(), input_schema)
elif (
mime_type == CONTENT_TYPE_JSON and content_format == CONTENT_TYPE_FORMAT_SPLIT_ORIENTED
):
data = parse_json_input(
json_input=StringIO(flask.request.data.decode("utf-8")),
json_input=StringIO(request_data.json()),
orient="split",
schema=input_schema,
)
Expand All @@ -276,29 +297,25 @@ def transformation(): # pylint: disable=unused-variable
and content_format == CONTENT_TYPE_FORMAT_RECORDS_ORIENTED
):
data = parse_json_input(
json_input=StringIO(flask.request.data.decode("utf-8")),
json_input=StringIO(request_data.json()),
orient="records",
schema=input_schema,
)
elif mime_type == CONTENT_TYPE_JSON_SPLIT_NUMPY and not content_format:
data = parse_split_oriented_json_input_to_numpy(flask.request.data.decode("utf-8"))
data = parse_split_oriented_json_input_to_numpy(request_data.json())
else:
return flask.Response(
response=(
"This predictor only supports the following content types and formats:"
return Response(
content="This predictor only supports the following content types and formats:"
" Types: {supported_content_types}; Formats: {formats}."
" Got '{received_content_type}'.".format(
supported_content_types=CONTENT_TYPES,
formats=FORMATS,
received_content_type=flask.request.content_type,
)
),
status=415,
mimetype="text/plain",
received_content_type=content_type,
),
status_code=415,
media_type="text/plain"
)

# Do the prediction

try:
raw_predictions = model.predict(data)
except MlflowException as e:
Expand All @@ -314,11 +331,10 @@ def transformation(): # pylint: disable=unused-variable
),
error_code=BAD_REQUEST,
)
result = StringIO()
predictions_to_json(raw_predictions, result)
return flask.Response(response=result.getvalue(), status=200, mimetype="application/json")
predictions = _get_jsonable_obj(raw_predictions, pandas_orient="records")
return str(predictions)

return app
return fast_app


def _predict(model_uri, input_path, output_path, content_type, json_format):
Expand All @@ -342,8 +358,8 @@ def _predict(model_uri, input_path, output_path, content_type, json_format):

def _serve(model_uri, port, host):
pyfunc_model = load_model(model_uri)
init(pyfunc_model).run(port=port, host=host)

fast_app = init(pyfunc_model)
uvicorn.run(fast_app, host=host, port=port, log_level="info")

def get_cmd(
model_uri: str, port: int = None, host: int = None, nworkers: int = None
Expand All @@ -362,8 +378,7 @@ def get_cmd(
args.append(f"-w {nworkers}")

command = (
f"gunicorn {' '.join(args)} ${{GUNICORN_CMD_ARGS}}"
" -- mlflow.pyfunc.scoring_server.wsgi:app"
"gunicorn mlflow.pyfunc.scoring_server.wsgi:app --worker-class uvicorn.workers.UvicornWorker"
)
else:
args = []
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def package_files(directory):
"alembic<=1.4.1",
# Required
"docker>=4.0.0",
"fastapi",
"uvicorn",
"Flask",
"gunicorn; platform_system != 'Windows'",
"numpy",
Expand Down

0 comments on commit 9ce8130

Please sign in to comment.