Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for snapstart runtime hooks #176

Merged
merged 1 commit into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions awslambdaric/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
_AWS_LAMBDA_LOG_LEVEL = _get_log_level_from_env_var(
os.environ.get("AWS_LAMBDA_LOG_LEVEL")
)
# Name of the environment variable that run() reads to decide whether the
# SnapStart hooks have to be executed.
AWS_LAMBDA_INITIALIZATION_TYPE = "AWS_LAMBDA_INITIALIZATION_TYPE"
# Value of that variable for which run() calls on_init_complete().
INIT_TYPE_SNAP_START = "snap-start"


def _get_handler(handler):
Expand Down Expand Up @@ -286,6 +288,29 @@ def extract_traceback(tb):
]


def on_init_complete(lambda_runtime_client, log_sink):
    """Run the registered snapshot/restore hooks around ``restore_next``.

    First the before-snapshot hooks are executed and ``restore_next`` is
    called on the runtime client; afterwards the after-restore hooks run.

    Args:
        lambda_runtime_client: Client used to call ``restore_next`` and to
            report failures (``post_init_error`` / ``report_restore_error``).
        log_sink: Sink handed to ``log_error`` when a phase fails.

    Raises:
        SystemExit: With code 64 when the before-snapshot phase (hooks or
            ``restore_next``) fails, and with code 65 when an after-restore
            hook fails.
    """
    from . import lambda_runtime_hooks_runner

    # Phase 1: before-snapshot hooks followed by the restore/next call.
    try:
        lambda_runtime_hooks_runner.run_before_snapshot()
        lambda_runtime_client.restore_next()
    except:
        # A bare except also traps BaseException subclasses raised by a hook.
        snapshot_fault = build_fault_result(sys.exc_info(), None)
        log_error(snapshot_fault, log_sink)
        lambda_runtime_client.post_init_error(
            snapshot_fault, FaultException.BEFORE_SNAPSHOT_ERROR
        )
        sys.exit(64)

    # Phase 2: after-restore hooks.
    try:
        lambda_runtime_hooks_runner.run_after_restore()
    except:
        restore_fault = build_fault_result(sys.exc_info(), None)
        log_error(restore_fault, log_sink)
        lambda_runtime_client.report_restore_error(restore_fault)
        sys.exit(65)


class LambdaLoggerHandler(logging.Handler):
def __init__(self, log_sink):
logging.Handler.__init__(self)
Expand Down Expand Up @@ -454,10 +479,10 @@ def run(app_root, handler, lambda_runtime_api_addr):
sys.stdout = Unbuffered(sys.stdout)
sys.stderr = Unbuffered(sys.stderr)

use_thread_for_polling_next = os.environ.get("AWS_EXECUTION_ENV") in [
use_thread_for_polling_next = os.environ.get("AWS_EXECUTION_ENV") in {
"AWS_Lambda_python3.12",
"AWS_Lambda_python3.13",
]
}

with create_log_sink() as log_sink:
lambda_runtime_client = LambdaRuntimeClient(
Expand Down Expand Up @@ -485,6 +510,9 @@ def run(app_root, handler, lambda_runtime_api_addr):

sys.exit(1)

if os.environ.get(AWS_LAMBDA_INITIALIZATION_TYPE) == INIT_TYPE_SNAP_START:
on_init_complete(lambda_runtime_client, log_sink)

while True:
event_request = lambda_runtime_client.wait_next_invocation()

Expand Down
50 changes: 41 additions & 9 deletions awslambdaric/lambda_runtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,25 +62,57 @@ def __init__(self, lambda_runtime_address, use_thread_for_polling_next=False):
# Not defining symbol as global to avoid relying on TPE being imported unconditionally.
self.ThreadPoolExecutor = ThreadPoolExecutor

def post_init_error(self, error_response_data):
def call_rapid(
    self, http_method, endpoint, expected_http_code, payload=None, headers=None
):
    """Send one HTTP request to the runtime API and validate the status code.

    Args:
        http_method: HTTP verb. "GET" is sent without body or custom headers;
            any other verb sends ``to_json(payload)`` together with ``headers``.
        endpoint: Request path on the runtime API.
        expected_http_code: Status code that counts as success.
        payload: Object serialized with ``to_json`` for non-GET requests.
        headers: Optional mapping of request headers for non-GET requests.

    Raises:
        LambdaRuntimeClientError: If the response status differs from
            ``expected_http_code``.
    """
    # These imports are heavy-weight. They implicitly trigger `import ssl, hashlib`.
    # Importing them lazily to speed up critical path of a common case.
    import http
    import http.client

    runtime_connection = http.client.HTTPConnection(self.lambda_runtime_address)
    try:
        runtime_connection.connect()
        if http_method == "GET":
            runtime_connection.request(http_method, endpoint)
        else:
            runtime_connection.request(
                http_method,
                endpoint,
                to_json(payload),
                # http.client iterates over the headers mapping, so the None
                # default has to be replaced by an empty mapping.
                headers={} if headers is None else headers,
            )

        response = runtime_connection.getresponse()
        response_body = response.read()
    finally:
        # Release the socket on success and on failure alike.
        runtime_connection.close()

    if response.code != expected_http_code:
        raise LambdaRuntimeClientError(endpoint, response.code, response_body)

def post_init_error(self, error_response_data, error_type_override=None):
    """POST an initialization error to ``/2018-06-01/runtime/init/error``.

    Args:
        error_response_data: Error payload; its ``"errorType"`` entry is used
            for the error-type header unless an override is supplied.
        error_type_override: Optional error type for the header. A falsy
            value (``None`` or an empty string) falls back to the payload.
    """
    import http

    reported_error_type = error_type_override or error_response_data["errorType"]
    self.call_rapid(
        "POST",
        "/2018-06-01/runtime/init/error",
        http.HTTPStatus.ACCEPTED,
        error_response_data,
        {ERROR_TYPE_HEADER: reported_error_type},
    )

def restore_next(self):
    """Issue ``GET /2018-06-01/runtime/restore/next`` and expect HTTP 200."""
    from http import HTTPStatus

    self.call_rapid("GET", "/2018-06-01/runtime/restore/next", HTTPStatus.OK)

def report_restore_error(self, restore_error_data):
    """POST a restore failure to ``/2018-06-01/runtime/restore/error``.

    The error-type header is always ``FaultException.AFTER_RESTORE_ERROR``;
    ``restore_error_data`` is sent as the request payload.
    """
    import http

    self.call_rapid(
        "POST",
        "/2018-06-01/runtime/restore/error",
        http.HTTPStatus.ACCEPTED,
        restore_error_data,
        {ERROR_TYPE_HEADER: FaultException.AFTER_RESTORE_ERROR},
    )

def wait_next_invocation(self):
# Calling runtime_client.next() from a separate thread unblocks the main thread,
# which can then process signals.
Expand Down
2 changes: 2 additions & 0 deletions awslambdaric/lambda_runtime_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class FaultException(Exception):
IMPORT_MODULE_ERROR = "Runtime.ImportModuleError"
BUILT_IN_MODULE_CONFLICT = "Runtime.BuiltInModuleConflict"
MALFORMED_HANDLER_NAME = "Runtime.MalformedHandlerName"
BEFORE_SNAPSHOT_ERROR = "Runtime.BeforeSnapshotError"
AFTER_RESTORE_ERROR = "Runtime.AfterRestoreError"
LAMBDA_CONTEXT_UNMARSHAL_ERROR = "Runtime.LambdaContextUnmarshalError"
LAMBDA_RUNTIME_CLIENT_ERROR = "Runtime.LambdaRuntimeClientError"

Expand Down
18 changes: 18 additions & 0 deletions awslambdaric/lambda_runtime_hooks_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from snapshot_restore_py import get_before_snapshot, get_after_restore


def run_before_snapshot():
    """Invoke every registered before-snapshot hook, newest registration first.

    The collection returned by ``get_before_snapshot()`` is drained while the
    hooks run.
    """
    pending_hooks = get_before_snapshot()
    while pending_hooks:
        # pop() takes the most recently registered entry, which yields the
        # reverse-of-registration order required for before-snapshot hooks.
        hook, hook_args, hook_kwargs = pending_hooks.pop()
        hook(*hook_args, **hook_kwargs)


def run_after_restore():
    """Invoke every registered after-restore hook in iteration order."""
    for registration in get_after_restore():
        hook, hook_args, hook_kwargs = registration
        hook(*hook_args, **hook_kwargs)
1 change: 1 addition & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
simplejson>=3.18.4
snapshot-restore-py>=1.0.0
51 changes: 50 additions & 1 deletion tests/test_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import unittest
from io import StringIO
from tempfile import NamedTemporaryFile
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, Mock, patch, ANY

import awslambdaric.bootstrap as bootstrap
from awslambdaric.lambda_runtime_exception import FaultException
Expand All @@ -23,6 +23,7 @@
from awslambdaric.lambda_literals import (
lambda_unhandled_exception_warning_message,
)
import snapshot_restore_py


class TestUpdateXrayEnv(unittest.TestCase):
Expand Down Expand Up @@ -1457,5 +1458,53 @@ class TestException(Exception):
mock_sys.exit.assert_called_once_with(1)


class TestOnInitComplete(unittest.TestCase):
    """Checks how bootstrap.on_init_complete reports failing runtime hooks."""

    # ANY is used for the stack trace because the main thing under test is the
    # errorType propagation and the fact that a stack trace is generated.
    error_result = {
        "errorMessage": "This is a Dummy type error",
        "errorType": "TypeError",
        "requestId": "",
        "stackTrace": ANY,
    }

    def tearDown(self):
        # Private fields are accessed here purely to clean up between tests.
        snapshot_restore_py._before_snapshot_registry = []
        snapshot_restore_py._after_restore_registry = []

    def raise_type_error(self):
        raise TypeError("This is a Dummy type error")

    @patch("awslambdaric.bootstrap.LambdaRuntimeClient")
    def test_before_snapshot_exception(self, mock_runtime_client):
        """A failing before-snapshot hook exits with 64 and posts an init error."""
        snapshot_restore_py.register_before_snapshot(self.raise_type_error)
        sink = bootstrap.StandardLogSink()

        with self.assertRaises(SystemExit) as exit_ctx:
            bootstrap.on_init_complete(mock_runtime_client, log_sink=sink)

        self.assertEqual(exit_ctx.exception.code, 64)
        mock_runtime_client.post_init_error.assert_called_once_with(
            self.error_result,
            FaultException.BEFORE_SNAPSHOT_ERROR,
        )

    @patch("awslambdaric.bootstrap.LambdaRuntimeClient")
    def test_after_restore_exception(self, mock_runtime_client):
        """A failing after-restore hook exits with 65 and reports a restore error."""
        snapshot_restore_py.register_after_restore(self.raise_type_error)
        sink = bootstrap.StandardLogSink()

        with self.assertRaises(SystemExit) as exit_ctx:
            bootstrap.on_init_complete(mock_runtime_client, log_sink=sink)

        self.assertEqual(exit_ctx.exception.code, 65)
        mock_runtime_client.restore_next.assert_called_once()
        mock_runtime_client.report_restore_error.assert_called_once_with(
            self.error_result
        )


if __name__ == "__main__":
unittest.main()
73 changes: 73 additions & 0 deletions tests/test_lambda_runtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,21 @@ def test_wait_next_invocation(self, mock_runtime_client):

headers = {"Lambda-Runtime-Function-Error-Type": error_result["errorType"]}

# Error payload handed to report_restore_error in test_restore_error.
restore_error_result = {
"errorMessage": "Dummy Restore error",
"errorType": "Runtime.DummyRestoreError",
"requestId": "",
"stackTrace": [],
}

# Headers the restore/error request is expected to carry.
restore_error_header = {
"Lambda-Runtime-Function-Error-Type": "Runtime.AfterRestoreError"
}

# Headers the init/error request is expected to carry when the error type is
# overridden with "Runtime.BeforeSnapshotError".
before_snapshot_error_header = {
"Lambda-Runtime-Function-Error-Type": "Runtime.BeforeSnapshotError"
}

@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
def test_post_init_error(self, MockHTTPConnection):
mock_conn = MockHTTPConnection.return_value
Expand Down Expand Up @@ -225,6 +240,64 @@ def test_post_invocation_error_with_too_large_xray_cause(self, mock_runtime_clie
invoke_id, error_data, ""
)

@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
def test_restore_next(self, MockHTTPConnection):
    """restore_next sends a body-less GET to the restore/next endpoint."""
    fake_response = MagicMock(autospec=http.client.HTTPResponse)
    fake_response.read.return_value = b""
    fake_response.code = http.HTTPStatus.OK
    fake_connection = MockHTTPConnection.return_value
    fake_connection.getresponse.return_value = fake_response

    client = LambdaRuntimeClient("localhost:1234")
    client.restore_next()

    MockHTTPConnection.assert_called_with("localhost:1234")
    fake_connection.request.assert_called_once_with(
        "GET", "/2018-06-01/runtime/restore/next"
    )
    fake_response.read.assert_called_once()

@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
def test_restore_error(self, MockHTTPConnection):
    """report_restore_error POSTs the payload with the AfterRestoreError header."""
    fake_response = MagicMock(autospec=http.client.HTTPResponse)
    fake_response.read.return_value = b""
    fake_response.code = http.HTTPStatus.ACCEPTED
    fake_connection = MockHTTPConnection.return_value
    fake_connection.getresponse.return_value = fake_response

    client = LambdaRuntimeClient("localhost:1234")
    client.report_restore_error(self.restore_error_result)

    MockHTTPConnection.assert_called_with("localhost:1234")
    expected_body = to_json(self.restore_error_result)
    fake_connection.request.assert_called_once_with(
        "POST",
        "/2018-06-01/runtime/restore/error",
        expected_body,
        headers=self.restore_error_header,
    )
    fake_response.read.assert_called_once()

@patch("http.client.HTTPConnection", autospec=http.client.HTTPConnection)
def test_init_before_snapshot_error(self, MockHTTPConnection):
    """post_init_error uses the override as the error-type header value."""
    fake_response = MagicMock(autospec=http.client.HTTPResponse)
    fake_response.read.return_value = b""
    fake_response.code = http.HTTPStatus.ACCEPTED
    fake_connection = MockHTTPConnection.return_value
    fake_connection.getresponse.return_value = fake_response

    client = LambdaRuntimeClient("localhost:1234")
    client.post_init_error(self.error_result, "Runtime.BeforeSnapshotError")

    MockHTTPConnection.assert_called_with("localhost:1234")
    expected_body = to_json(self.error_result)
    fake_connection.request.assert_called_once_with(
        "POST",
        "/2018-06-01/runtime/init/error",
        expected_body,
        headers=self.before_snapshot_error_header,
    )
    fake_response.read.assert_called_once()

def test_connection_refused(self):
with self.assertRaises(ConnectionRefusedError):
runtime_client = LambdaRuntimeClient("127.0.0.1:1")
Expand Down
65 changes: 65 additions & 0 deletions tests/test_runtime_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import unittest
from unittest.mock import patch, call
from awslambdaric import lambda_runtime_hooks_runner
import snapshot_restore_py


def fun_test1():
    """Hook fixture that prints a fixed marker so call order can be asserted."""
    print("In function ONE")


def fun_test2():
    """Hook fixture that prints a fixed marker so call order can be asserted."""
    print("In function TWO")


def fun_with_args_kwargs(x, y, **kwargs):
    """Hook fixture that prints its positional and keyword arguments."""
    print("Here are the args:", x, y)
    print("Here are the keyword args:", kwargs)


class TestRuntimeHooks(unittest.TestCase):
    """Verifies the order in which the hooks runner invokes registered hooks."""

    def tearDown(self):
        # Private fields are accessed here purely to clean up between tests.
        snapshot_restore_py._before_snapshot_registry = []
        snapshot_restore_py._after_restore_registry = []

    @patch("builtins.print")
    def test_before_snapshot_execution_order(self, mock_print):
        """Before-snapshot hooks run in reverse order of registration."""
        snapshot_restore_py.register_before_snapshot(
            fun_with_args_kwargs, 5, 7, arg1="Lambda", arg2="SnapStart"
        )
        snapshot_restore_py.register_before_snapshot(fun_test2)
        snapshot_restore_py.register_before_snapshot(fun_test1)

        lambda_runtime_hooks_runner.run_before_snapshot()

        expected_calls = [
            call("In function ONE"),
            call("In function TWO"),
            call("Here are the args:", 5, 7),
            call("Here are the keyword args:", {"arg1": "Lambda", "arg2": "SnapStart"}),
        ]
        self.assertEqual(expected_calls, mock_print.mock_calls)

    @patch("builtins.print")
    def test_after_restore_execution_order(self, mock_print):
        """After-restore hooks run in the order they were registered."""
        snapshot_restore_py.register_after_restore(
            fun_with_args_kwargs, 11, 13, arg1="Lambda", arg2="SnapStart"
        )
        snapshot_restore_py.register_after_restore(fun_test2)
        snapshot_restore_py.register_after_restore(fun_test1)

        lambda_runtime_hooks_runner.run_after_restore()

        expected_calls = [
            call("Here are the args:", 11, 13),
            call("Here are the keyword args:", {"arg1": "Lambda", "arg2": "SnapStart"}),
            call("In function TWO"),
            call("In function ONE"),
        ]
        self.assertEqual(expected_calls, mock_print.mock_calls)