diff --git a/lambda/handler.go b/lambda/handler.go index e4cfaf7a..a4960b0a 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -10,8 +10,10 @@ import ( "fmt" "io" "io/ioutil" // nolint:staticcheck + "os" "reflect" "strings" + "syscall" "github.com/aws/aws-lambda-go/lambda/handlertrace" ) @@ -28,6 +30,8 @@ type handlerOptions struct { jsonResponseIndentValue string enableSIGTERM bool sigtermCallbacks []func() + enableExecWrapper bool + execWrapperCallbacks []func() } type Option func(*handlerOptions) @@ -102,6 +106,28 @@ func WithEnableSIGTERM(callbacks ...func()) Option { }) } +// WithEnableExecWrapper enables applying the value of the AWS_LAMBDA_EXEC_WRAPPER environment +// variable. If this fariable is set, the current process will be re-started, wrapped under the +// specified wrapper script. Optionally, an array of callback functions to run before restarting +// the process may be provided. +// +// Usage: +// +// lambda.StartWithOptions( +// func (event any) (any, error) { +// return event, nil +// }, +// lambda.WithEnableExecWrapper(func(){ +// log.Print("[AWS_LAMBDA_EXEC_WRAPPER] process is about to be re-started...") +// }), +// ) +func WithEnableExecWrapper(callbacks ...func()) Option { + return Option(func(h *handlerOptions) { + h.execWrapperCallbacks = append(h.execWrapperCallbacks, callbacks...) + h.enableExecWrapper = true + }) +} + // handlerTakesContext returns whether the handler takes a context.Context as its first argument. func handlerTakesContext(handler reflect.Type) (bool, error) { switch handler.NumIn() { @@ -184,6 +210,9 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions { for _, option := range options { option(h) } + if h.enableExecWrapper { + execAWSLambdaExecWrapper(os.Getenv, syscall.Exec, h.execWrapperCallbacks) + } if h.enableSIGTERM { enableSIGTERM(h.sigtermCallbacks) } diff --git a/lambda/wrapper.go b/lambda/wrapper.go new file mode 100644 index 00000000..854a239c --- /dev/null +++ b/lambda/wrapper.go @@ -0,0 +1,40 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016-present Datadog, Inc. + +package lambda + +import ( + "log" + "os" +) + +const awsLambdaExecWrapper = "AWS_LAMBDA_EXEC_WRAPPER" + +// execAWSLambdaExecWrapper applies the AWS_LAMBDA_EXEC_WRAPPER environment variable. +// If AWS_LAMBDA_EXEC_WRAPPER is defined, replace the current process by spawning +// it with the current process' arguments (including the program name). If the call +// to syscall.Exec fails, this aborts the process with a fatal error. +func execAWSLambdaExecWrapper( + getenv func(key string) string, + sysExec func(argv0 string, argv []string, envv []string) error, + callbacks []func(), +) { + wrapper := getenv(awsLambdaExecWrapper) + if wrapper == "" { + return + } + + // Execute the provided callbacks before re-starting the process... + for _, callback := range callbacks { + callback() + } + + // The AWS_LAMBDA_EXEC_WRAPPER variable is blanked before replacing the process + // in order to avoid endlessly restarting the process. + env := append(os.Environ(), awsLambdaExecWrapper+"=") + if err := sysExec(wrapper, append([]string{wrapper}, os.Args...), env); err != nil { + log.Fatalf("failed to sysExec() %s=%s: %v", awsLambdaExecWrapper, wrapper, err) + } +} diff --git a/lambda/wrapper_test.go b/lambda/wrapper_test.go new file mode 100644 index 00000000..05ed1d85 --- /dev/null +++ b/lambda/wrapper_test.go @@ -0,0 +1,65 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016-present Datadog, Inc. + +package lambda + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestExecAwsLambdaExecWrapperNotSet(t *testing.T) { + var called bool + callback := func() { called = true } + + exec, execCalled := mockExec(t, "") + execAWSLambdaExecWrapper( + mockedGetenv(t, ""), + exec, + []func(){callback}, + ) + require.False(t, *execCalled) + require.False(t, called) +} + +func TestExecAwsLambdaExecWrapperSet(t *testing.T) { + var called bool + callback := func() { called = true } + + wrapper := "/path/to/wrapper/entry/point" + exec, execCalled := mockExec(t, wrapper) + execAWSLambdaExecWrapper( + mockedGetenv(t, wrapper), + exec, + []func(){callback}, + ) + require.True(t, *execCalled) + require.True(t, called) +} + +func mockExec(t *testing.T, value string) (mock func(string, []string, []string) error, called *bool) { + mock = func(argv0 string, argv []string, envv []string) error { + *called = true + require.Equal(t, value, argv0) + require.Equal(t, append([]string{value}, os.Args...), argv) + require.Equal(t, awsLambdaExecWrapper+"=", envv[len(envv)-1]) + return nil + } + called = ptrTo(false) + return +} + +func mockedGetenv(t *testing.T, value string) func(string) string { + return func(key string) string { + require.Equal(t, awsLambdaExecWrapper, key) + return value + } +} + +func ptrTo(val bool) *bool { + return &val +}