diff --git a/plugins/internal/tengoutil/secure_script.go b/plugins/internal/tengoutil/secure_script.go index 388fd2cb..345e5af6 100644 --- a/plugins/internal/tengoutil/secure_script.go +++ b/plugins/internal/tengoutil/secure_script.go @@ -1,7 +1,12 @@ package tengoutil import ( + "context" + "errors" "fmt" + "io" + "net/http" + "time" "github.com/d5/tengo/v2" "github.com/d5/tengo/v2/stdlib" @@ -12,13 +17,23 @@ const ( maxConsts = 500 ) +const expectedArgsLength = 2 + +var defaultTimeout = 5 * time.Second + +var httpModule = map[string]tengo.Object{ + "get": httpGetFunction, +} + func NewSecureScript(input []byte, globals map[string]interface{}) (*tengo.Script, error) { s := tengo.NewScript(input) - s.SetImports(stdlib.GetModuleMap( + modules := stdlib.GetModuleMap( // `os` is excluded, should *not* be importable from script. "math", "text", "times", "rand", "fmt", "json", "base64", "hex", "enum", - )) + ) + modules.AddBuiltinModule("http", httpModule) + s.SetImports(modules) s.SetMaxAllocs(maxAllocs) s.SetMaxConstObjects(maxConsts) @@ -30,3 +45,81 @@ func NewSecureScript(input []byte, globals map[string]interface{}) (*tengo.Scrip return s, nil } + +var httpGetFunction = &tengo.UserFunction{ + Name: "get", + Value: func(args ...tengo.Object) (tengo.Object, error) { + url, err := extractURL(args) + if err != nil { + return nil, err + } + headers, err := extractHeaders(args) + if err != nil { + return nil, err + } + + return performGetRequest(url, headers, defaultTimeout) + }, +} + +func extractURL(args []tengo.Object) (string, error) { + if len(args) < 1 { + return "", errors.New("expected at least 1 argument (URL)") + } + url, ok := tengo.ToString(args[0]) + if !ok { + return "", errors.New("expected argument 1 (URL) to be a string") + } + + return url, nil +} + +func extractHeaders(args []tengo.Object) (map[string]string, error) { + headers := make(map[string]string) + if len(args) == expectedArgsLength { + headerMap, ok := args[1].(*tengo.Map) + if !ok { + return nil, fmt.Errorf("expected argument %d (headers) to be a map", expectedArgsLength) + } + for key, value := range headerMap.Value { + strValue, valueOk := tengo.ToString(value) + if !valueOk { + return nil, fmt.Errorf("header value for key '%s' must be a string, got %T", key, value) + } + headers[key] = strValue + } + } + + return headers, nil +} + +func performGetRequest(url string, headers map[string]string, timeout time.Duration) (tengo.Object, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + for key, value := range headers { + req.Header.Add(key, value) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return &tengo.Map{ + Value: map[string]tengo.Object{ + "body": &tengo.String{Value: string(body)}, + "code": &tengo.Int{Value: int64(resp.StatusCode)}, + }, + }, nil +} diff --git a/plugins/internal/tengoutil/secure_script_test.go b/plugins/internal/tengoutil/secure_script_test.go index cb244511..c2ea06ed 100644 --- a/plugins/internal/tengoutil/secure_script_test.go +++ b/plugins/internal/tengoutil/secure_script_test.go @@ -5,6 +5,7 @@ package tengoutil import ( "testing" + "time" "github.com/MakeNowJust/heredoc" "github.com/stretchr/testify/assert" @@ -55,4 +56,63 @@ func TestNewSecureScript(t *testing.T) { _, err = s.Compile() assert.NoError(t, err) }) + + t.Run("Allows import of custom http module", func(t *testing.T) { + s, err := NewSecureScript(([]byte)(heredoc.Doc(` + http := import("http") + response := http.get("http://example.com") + response.body + `)), nil) + assert.NoError(t, err) + _, err = s.Compile() + assert.NoError(t, err) + }) + + t.Run("HTTP GET with headers", func(t *testing.T) { + s, err := NewSecureScript(([]byte)(heredoc.Doc(` + http := import("http") + headers := { "User-Agent": "test-agent", "Accept": "application/json" } + response := http.get("http://example.com", headers) + response.body + `)), nil) + assert.NoError(t, err) + + _, err = s.Compile() + assert.NoError(t, err) + }) + + t.Run("HTTP GET with invalid URL argument type", func(t *testing.T) { + s, err := NewSecureScript(([]byte)(heredoc.Doc(` + http := import("http") + http.get(12345) + `)), nil) + assert.NoError(t, err) + + _, err = s.Compile() + assert.NoError(t, err) + + _, err = s.Run() + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported protocol scheme") + }) + + t.Run("HTTP GET with timeout", func(t *testing.T) { + s, err := NewSecureScript(([]byte)(heredoc.Doc(` + http := import("http") + response := http.get("http://example.com") + response.body + `)), nil) + assert.NoError(t, err) + + originalTimeout := defaultTimeout + defaultTimeout = 1 * time.Millisecond + defer func() { defaultTimeout = originalTimeout }() + + _, err = s.Compile() + assert.NoError(t, err) + + _, err = s.Run() + assert.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") + }) }