Skip to content

Commit

Permalink
Merge pull request #608 from tableau/jichikawa/handle-413-payload-res…
Browse files Browse the repository at this point in the history
…ponse

Handle 413 Request Entity Too Large
  • Loading branch information
jakeichikawasalesforce authored May 23, 2023
2 parents fe0a6e3 + 647f2a6 commit 5716a67
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 11 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## v2.8.0

### Improvements

- Returns a 413 (Request Entity Too Large) error code when the request payload
exceeds the TABPY_MAX_REQUEST_SIZE_MB config setting.

## v2.7.0

### Improvements
Expand Down
2 changes: 1 addition & 1 deletion tabpy/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.7.1
2.8.0
39 changes: 29 additions & 10 deletions tabpy/tabpy_server/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
import shutil
import signal
import sys
import _thread

import tornado
from tornado.http1connection import HTTP1Connection

import tabpy
import tabpy.tabpy_server.app.arrow_server as pa
from tabpy.tabpy import __version__
from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters
from tabpy.tabpy_server.app.util import parse_pwd_file
Expand All @@ -26,9 +32,6 @@
StatusHandler,
UploadDestinationHandler,
)
import tornado
import tabpy.tabpy_server.app.arrow_server as pa
import _thread

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,6 +65,7 @@ class TabPyApp:
python_service = None
credentials = {}
arrow_server = None
max_request_size = None

def __init__(self, config_file):
if config_file is None:
Expand Down Expand Up @@ -116,11 +120,7 @@ def _get_arrow_server(self, config):

def run(self):
application = self._create_tornado_web_app()
max_request_size = (
int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024
)
logger.info(f"Setting max request size to {max_request_size} bytes")


init_model_evaluator(self.settings, self.tabpy_state, self.python_service)

protocol = self.settings[SettingsParameters.TransferProtocol]
Expand All @@ -142,8 +142,8 @@ def run(self):
application.listen(
self.settings[SettingsParameters.Port],
ssl_options=ssl_options,
max_buffer_size=max_request_size,
max_body_size=max_request_size,
max_buffer_size=self.max_request_size,
max_body_size=self.max_request_size,
**settings,
)

Expand Down Expand Up @@ -354,6 +354,12 @@ def _parse_config(self, config_file):
].lower()

self._validate_transfer_protocol_settings()

# Set max request size in bytes
self.max_request_size = (
int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024
)
logger.info(f"Setting max request size to {self.max_request_size} bytes")

# if state.ini does not exist try and create it - remove
# last dependence on batch/shell script
Expand Down Expand Up @@ -497,3 +503,16 @@ def _build_tabpy_state(self):
logger.info(f"Loading state from state file {state_file_path}")
tabpy_state = _get_state_from_file(state_file_dir)
return tabpy_state, TabPyState(config=tabpy_state, settings=self.settings)


# Override HTTP1Connection._read_body so that the body of a request whose
# Content-Length exceeds max_body_size is skipped (not read).
# This enables proper handling of 413 errors in base_handler.
def _read_body_allow_max_size(self, code, headers, delegate):
if "Content-Length" in headers:
content_length = int(headers["Content-Length"])
if content_length > self._max_body_size:
return
return self.original_read_body(code, headers, delegate)

# Install the override exactly once. Without this guard, executing the module
# a second time (e.g. via importlib.reload) would store the override itself as
# ``original_read_body``, and the override would then call itself recursively.
if not hasattr(HTTP1Connection, "original_read_body"):
    HTTP1Connection.original_read_body = HTTP1Connection._read_body
    HTTP1Connection._read_body = _read_body_allow_max_size
23 changes: 23 additions & 0 deletions tabpy/tabpy_server/handlers/base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def initialize(self, app):
self.username = None
self.password = None
self.eval_timeout = self.settings[SettingsParameters.EvaluateTimeout]
self.max_request_size = app.max_request_size

self.logger = ContextLoggerWrapper(self.request)
self.logger.enable_context_logging(
Expand Down Expand Up @@ -442,3 +443,25 @@ def fail_with_auth_error(self):
info="Not Acceptable",
log_message="Username or password provided when authentication not available.",
)

def request_body_size_within_limit(self):
    """
    Check the declared request size against the configured maximum.

    Compares the ``Content-Length`` request header with
    ``self.max_request_size``. If the declared size is larger, a 413
    response is issued through ``self.error_out``.

    Returns
    -------
    bool
        False if the 413 error was issued. True otherwise, which includes
        the cases where no limit is configured or the header is absent.
    """
    limit = self.max_request_size
    if limit is None:
        # No limit configured: nothing to enforce.
        return True

    headers = self.request.headers
    if "Content-Length" not in headers:
        # Size was not declared, so it cannot be checked here.
        return True

    content_length = int(headers["Content-Length"])
    if content_length <= limit:
        return True

    self.error_out(
        413,
        info="Request Entity Too Large",
        log_message=f"Request with size {content_length} exceeded limit of {self.max_request_size} (bytes).",
    )
    return False
7 changes: 7 additions & 0 deletions tabpy/tabpy_server/handlers/evaluation_plane_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def post(self):
if self.should_fail_with_auth_error() != AuthErrorStates.NONE:
self.fail_with_auth_error()
return

if not self.request_body_size_within_limit():
return

self.error_out(404, "Ad-hoc scripts have been disabled on this analytics extension, please contact your "
"administrator.")

Expand Down Expand Up @@ -165,6 +169,9 @@ def post(self):
if self.should_fail_with_auth_error() != AuthErrorStates.NONE:
self.fail_with_auth_error()
return

if not self.request_body_size_within_limit():
return

self._add_CORS_header()
try:
Expand Down
6 changes: 6 additions & 0 deletions tabpy/tabpy_server/handlers/query_plane_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ def get(self, endpoint_name):
self.fail_with_auth_error()
return

if not self.request_body_size_within_limit():
return

start = time.time()
endpoint_name = urllib.parse.unquote(endpoint_name)
self._process_query(endpoint_name, start)
Expand All @@ -229,6 +232,9 @@ def post(self, endpoint_name):
self.fail_with_auth_error()
return

if not self.request_body_size_within_limit():
return

start = time.time()
endpoint_name = urllib.parse.unquote(endpoint_name)
self._process_query(endpoint_name, start)
63 changes: 63 additions & 0 deletions tests/unit/server_tests/test_evaluation_plane_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import base64
import json
import os
import tempfile
import string

from tornado.testing import AsyncHTTPTestCase

Expand Down Expand Up @@ -458,6 +460,67 @@ def test_evaluation_enabled(self):
)
self.assertEqual(200, response.code)

class TestEvaluationPlaneHandlerMaxRequestSize(AsyncHTTPTestCase):
    """
    Tests for /evaluate against an app configured with
    TABPY_MAX_REQUEST_SIZE_MB = 1, using a request body that embeds a
    string of 2 * 1024 * 1024 characters.
    """

    @classmethod
    def setUpClass(cls):
        # The config file is created with delete=False so it survives the
        # close() below; tearDownClass removes it.
        cls.config_file = tempfile.NamedTemporaryFile(
            mode="w+t",
            prefix="__TestEvaluationPlaneHandlerMaxRequestSize_",
            suffix=".conf",
            delete=False,
        )
        config_text = (
            "[TabPy]\n"
            "TABPY_MAX_REQUEST_SIZE_MB = 1"
        )
        cls.config_file.write(config_text)
        cls.config_file.close()

    @classmethod
    def tearDownClass(cls):
        os.remove(cls.config_file.name)

    def get_app(self):
        # Keep a handle on the TabPyApp so individual tests can adjust it.
        self.app = TabPyApp(self.config_file.name)
        return self.app._create_tornado_web_app()

    def create_large_payload(self):
        """Return a UTF-8 encoded JSON request body embedding a 2MB string."""
        target_length = 2 * 1024 * 1024  # 2MB Size
        alphabet = string.printable
        # Whole repetitions of the alphabet plus a partial tail give a string
        # of exactly target_length characters.
        repeats, remainder = divmod(target_length, len(alphabet))
        large_string = alphabet * repeats + alphabet[:remainder]
        body = {
            "data": {"_arg1": [1, large_string]},
            "script": "return _arg1",
        }
        return json.dumps(body).encode("utf-8")

    def test_evaluation_payload_exceeds_max_request_size(self):
        oversized_body = self.create_large_payload()
        response = self.fetch("/evaluate", method="POST", body=oversized_body)
        self.assertEqual(413, response.code)

    def test_evaluation_max_request_size_not_applied(self):
        # With the limit cleared on the app, the same payload is evaluated.
        self.app.max_request_size = None
        oversized_body = self.create_large_payload()
        response = self.fetch("/evaluate", method="POST", body=oversized_body)
        self.assertEqual(200, response.code)
        self.assertEqual(1, json.loads(response.body)[0])

    def test_no_content_length_header_present(self):
        # POST without a body; allow_nonstandard_methods permits that.
        response = self.fetch(
            "/evaluate",
            method="POST",
            allow_nonstandard_methods=True,
        )
        message = json.loads(response.body)["message"]
        # Ensure it reaches script processing stage in EvaluationPlaneHandler.post
        self.assertEqual("Error processing script", message)


class TestEvaluationPlaneHandlerDefault(AsyncHTTPTestCase):
@classmethod
Expand Down

0 comments on commit 5716a67

Please sign in to comment.