diff --git a/cmdutil/when.go b/cmdutil/when.go index e15077903..2446e5772 100644 --- a/cmdutil/when.go +++ b/cmdutil/when.go @@ -1,49 +1,61 @@ package cmdutil import ( - "fmt" "os" "strings" "github.com/expr-lang/expr" + "github.com/expr-lang/expr/ast" "github.com/pkg/errors" ) +// AST walker which replaces `$IDENTIFIER` with `Env.IDENTIFIER` member lookup expressions. +type EnvPatcher struct{} + +func (ep *EnvPatcher) Visit(node *ast.Node) { + if n, ok := (*node).(*ast.IdentifierNode); ok && n.Value[0] == '$' && n.Value != "$env" { + ast.Patch( + node, + &ast.MemberNode{ + Node: &ast.IdentifierNode{Value: "Env"}, + Property: &ast.StringNode{Value: n.Value[1:]}, + }, + ) + } +} + +// The predefined variables of a when expression +type WhenEnv struct { + Env map[string]string +} + +var NewWhenEnv = func() *WhenEnv { + return &WhenEnv{Env: envMap()} +} + func IsAllowedToExecute(when string) (bool, error) { if when == "" { return true, nil } - ropts := []string{} - em := envMap() - for k := range em { - ropts = append(ropts, fmt.Sprintf("$%s", k), fmt.Sprintf("Env.%s", k)) - } - r := strings.NewReplacer(ropts...) - when = r.Replace(when) - got, err := expr.Eval(fmt.Sprintf("(%s) == true", when), struct { - Env map[string]string - }{ - Env: em, - }) + + // when expressions must produce a boolean result + program, err := expr.Compile(when, expr.Patch(&EnvPatcher{}), expr.AsBool()) if err != nil { return false, errors.WithStack(err) } - return got.(bool), nil + if got, err := expr.Run(program, NewWhenEnv()); err != nil { + return false, errors.WithStack(err) + } else { + return got.(bool), nil + } } func envMap() map[string]string { m := map[string]string{} for _, kv := range os.Environ() { - if !strings.Contains(kv, "=") { - continue - } - parts := strings.SplitN(kv, "=", 2) - k := parts[0] - if len(parts) < 2 { - m[k] = "" - continue + if k, v, ok := strings.Cut(kv, "="); ok { + m[k] = v } - m[k] = parts[1] } return m } diff --git a/cmdutil/when_test.go b/cmdutil/when_test.go index 8cdb38593..086ae65fd 100644 --- a/cmdutil/when_test.go +++ b/cmdutil/when_test.go @@ -1,48 +1,128 @@ package cmdutil import ( - "os" + "strings" "testing" ) +func TestEnvMap(t *testing.T) { + t.Setenv("TEST_ENV_EMPTY", "") + t.Setenv("TEST_ENV_SET", "value") + result := envMap() + if value, ok := result["TEST_ENV_EMPTY"]; !ok { + t.Error("Expected TEST_ENV_EMPTY to be set") + } else if value != "" { + t.Errorf("Expected TEST_ENV_EMPTY to be an empty string, got %v", value) + } + if value, ok := result["TEST_ENV_SET"]; !ok { + t.Error("Expected TEST_ENV_SET to be set") + } else if value != "value" { + t.Errorf("Expected TEST_ENV_SET to be 'value', got %v", value) + } +} + func TestIsAllowedToExecute(t *testing.T) { tests := []struct { - envset map[string]string - when string - want bool + name string + envset map[string]string + when string + want bool + errorContains any }{ { + name: "Empty expression", + envset: map[string]string{}, + when: "", + want: true, + errorContains: nil, + }, + { + name: "Equality test, true", + envset: map[string]string{ + "TEST_ENV1": "a", + }, + when: "$TEST_ENV1 == 'a'", + want: true, + errorContains: nil, + }, + { + name: "Equality test, false", envset: map[string]string{ "TEST_ENV1": "a", }, - when: "$TEST_ENV1 == 'a'", - want: true, + when: "$TEST_ENV1 == 'b'", + want: false, + errorContains: nil, }, { + name: "Containment in $env", envset: map[string]string{ + "env": "should not replace $env", "TEST_ENV1": "a", }, - when: "$TEST_ENV1 == 'b'", - want: false, + when: `'TEST_ENV1' not in $env`, + want: true, + errorContains: nil, }, { + name: "Containment in Env", envset: map[string]string{ "TEST_ENV1": "a", }, - when: `$TEST_ENV1 == "a"`, - want: true, + when: "'TEST_ENV1' in Env", + want: true, + errorContains: nil, + }, + { + name: "Env var name is used in string literal", + envset: map[string]string{ + "TEST_ENV1": "foo", + "TEST_ENV2": "$TEST_ENV1", + }, + when: `$TEST_ENV2 == '$TEST_ENV1'`, + want: true, + errorContains: nil, + }, + { + name: "Env var not set", + envset: map[string]string{}, + when: `$TEST_ENV_NONESUCH == ""`, + want: true, + errorContains: nil, + }, + { + name: "Invalid expression", + envset: map[string]string{}, + when: `($TEST_ENV1 == "Missing parentheses"`, + want: false, + errorContains: "unexpected token EOF", + }, + { + name: "Expression produces a non-boolean result", + envset: map[string]string{}, + when: `"String literal expression"`, + want: false, + errorContains: "expected bool, but got string", }, } for _, tt := range tests { - for k, v := range tt.envset { - os.Setenv(k, v) - } - got, err := IsAllowedToExecute(tt.when) - if err != nil { - t.Fatal(err) - } - if got != tt.want { - t.Errorf("got %v\nwant %v", got, tt.want) - } + t.Run(tt.name, func(t *testing.T) { + NewWhenEnv = func() *WhenEnv { return &WhenEnv{Env: tt.envset} } + got, err := IsAllowedToExecute(tt.when) + if err != nil { + if tt.errorContains != nil { + if !strings.Contains(err.Error(), tt.errorContains.(string)) { + t.Errorf("Error %v does not contain %s", err, tt.errorContains) + } + } else { + t.Error(err) + } + } else if tt.errorContains != nil { + t.Errorf("Expected an error containing %v", tt.errorContains) + } + if got != tt.want { + t.Errorf("got %v\nwant %v", got, tt.want) + } + }) } }