diff --git a/awslambdaric/bootstrap.py b/awslambdaric/bootstrap.py index f10c7bf..bd18da8 100644 --- a/awslambdaric/bootstrap.py +++ b/awslambdaric/bootstrap.py @@ -102,7 +102,6 @@ def replace_line_indentation(line, indent_char, new_indent_char): if _AWS_LAMBDA_LOG_FORMAT == LogFormat.JSON: _ERROR_FRAME_TYPE = _JSON_FRAME_TYPES[logging.ERROR] - _WARNING_FRAME_TYPE = _JSON_FRAME_TYPES[logging.WARNING] def log_error(error_result, log_sink): error_result = { @@ -118,7 +117,6 @@ def log_error(error_result, log_sink): else: _ERROR_FRAME_TYPE = _TEXT_FRAME_TYPES[logging.ERROR] - _WARNING_FRAME_TYPE = _TEXT_FRAME_TYPES[logging.WARNING] def log_error(error_result, log_sink): error_description = "[ERROR]" @@ -203,7 +201,7 @@ def handle_event_request( if error_result is not None: from .lambda_literals import lambda_unhandled_exception_warning_message - log_sink.log(lambda_unhandled_exception_warning_message, _WARNING_FRAME_TYPE) + logging.warning(lambda_unhandled_exception_warning_message) log_error(error_result, log_sink) lambda_runtime_client.post_invocation_error( invoke_id, to_json(error_result), to_json(xray_fault) diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index 79dcae6..7397d62 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -18,7 +18,11 @@ import awslambdaric.bootstrap as bootstrap from awslambdaric.lambda_runtime_exception import FaultException -from awslambdaric.lambda_runtime_log_utils import LogFormat, _get_log_level_from_env_var +from awslambdaric.lambda_runtime_log_utils import ( + LogFormat, + _get_log_level_from_env_var, + JsonFormatter, +) from awslambdaric.lambda_runtime_marshaller import LambdaMarshaller from awslambdaric.lambda_literals import ( lambda_unhandled_exception_warning_message, @@ -61,6 +65,14 @@ def setUp(self): self.event_body = '"event_body"' self.working_directory = os.getcwd() + logging.getLogger().handlers.clear() + + def tearDown(self) -> None: + logging.getLogger().handlers.clear() + logging.getLogger().level = logging.NOTSET + + return super().tearDown() + @staticmethod def dummy_handler(json_input, lambda_context): return {"input": json_input, "aws_request_id": lambda_context.aws_request_id} @@ -451,6 +463,8 @@ def raise_exception_handler(json_input, lambda_context): ), ) + logging.getLogger().addHandler(logging.StreamHandler(mock_stdout)) + bootstrap.handle_event_request( self.lambda_runtime, raise_exception_handler, @@ -467,6 +481,7 @@ def raise_exception_handler(json_input, lambda_context): # NOTE: Indentation characters are NO-BREAK SPACE (U+00A0) not SPACE (U+0020) error_logs = ( lambda_unhandled_exception_warning_message + + "\n" + "[ERROR] FaultExceptionType: Fault exception msg\r" ) error_logs += "Traceback (most recent call last):\r" @@ -487,6 +502,8 @@ def raise_exception_handler(json_input, lambda_context): "FaultExceptionType", "Fault exception msg", None ) + logging.getLogger().addHandler(logging.StreamHandler(mock_stdout)) + bootstrap.handle_event_request( self.lambda_runtime, raise_exception_handler, @@ -501,6 +518,7 @@ def raise_exception_handler(json_input, lambda_context): ) error_logs = ( lambda_unhandled_exception_warning_message + + "\n" + "[ERROR] FaultExceptionType: Fault exception msg\rTraceback (most recent call last):\n" ) @@ -516,6 +534,8 @@ def raise_exception_handler(json_input, lambda_context): except ImportError: raise bootstrap.FaultException("FaultExceptionType", None, None) + logging.getLogger().addHandler(logging.StreamHandler(mock_stdout)) + bootstrap.handle_event_request( self.lambda_runtime, raise_exception_handler, @@ -530,6 +550,7 @@ def raise_exception_handler(json_input, lambda_context): ) error_logs = ( lambda_unhandled_exception_warning_message + + "\n" + "[ERROR] FaultExceptionType\rTraceback (most recent call last):\n" ) @@ -545,6 +566,8 @@ def raise_exception_handler(json_input, lambda_context): except ImportError: raise bootstrap.FaultException(None, "Fault exception msg", None) + logging.getLogger().addHandler(logging.StreamHandler(mock_stdout)) + bootstrap.handle_event_request( self.lambda_runtime, raise_exception_handler, @@ -559,6 +582,7 @@ def raise_exception_handler(json_input, lambda_context): ) error_logs = ( lambda_unhandled_exception_warning_message + + "\n" + "[ERROR] Fault exception msg\rTraceback (most recent call last):\n" ) @@ -583,6 +607,8 @@ def raise_exception_handler(json_input, lambda_context): ), ) + logging.getLogger().addHandler(logging.StreamHandler(mock_stdout)) + bootstrap.handle_event_request( self.lambda_runtime, raise_exception_handler, @@ -595,7 +621,7 @@ def raise_exception_handler(json_input, lambda_context): 0, bootstrap.StandardLogSink(), ) - error_logs = lambda_unhandled_exception_warning_message + "[ERROR]\r" + error_logs = lambda_unhandled_exception_warning_message + "\n[ERROR]\r" error_logs += "Traceback (most recent call last):\r" error_logs += '  File "spam.py", line 3, in \r' error_logs += "    spam.eggs()\r" @@ -604,6 +630,48 @@ def raise_exception_handler(json_input, lambda_context): self.assertEqual(mock_stdout.getvalue(), error_logs) + @patch("sys.stdout", new_callable=StringIO) + def test_handle_event_request_fault_exception_logging_in_json(self, mock_stdout): + def raise_exception_handler(json_input, lambda_context): + try: + import invalid_module # noqa: F401 + except ImportError: + raise bootstrap.FaultException("FaultExceptionType", None, None) + + logging_handler = logging.StreamHandler(mock_stdout) + logging_handler.setFormatter(JsonFormatter()) + logging.getLogger().addHandler(logging_handler) + + bootstrap.handle_event_request( + self.lambda_runtime, + raise_exception_handler, + "invoke_id", + self.event_body, + "application/json", + {}, + {}, + "invoked_function_arn", + 0, + bootstrap.StandardLogSink(), + ) + + stdout_value = mock_stdout.getvalue() + received_warning = stdout_value.split("\n")[0] + received_rest = stdout_value[len(received_warning) + 1 :] + + warning = json.loads(received_warning) + self.assertEqual(warning["level"], "WARNING") + self.assertEqual(warning["message"], lambda_unhandled_exception_warning_message) + self.assertEqual(warning["logger"], "root") + self.assertIn("timestamp", warning) + + # this line is not in json because of the way the test runtime is bootstrapped + error_logs = ( + "\n[ERROR] FaultExceptionType\rTraceback (most recent call last):\n" + ) + + self.assertEqual(received_rest, error_logs) + class TestXrayFault(unittest.TestCase): def test_make_xray(self):