-
Notifications
You must be signed in to change notification settings - Fork 170
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Use expr.Patch and a visitor to only replace identifiers that start with `$` with a `Env.` member lookup. - use `expr.AsBool()` to assert the expression produces a boolean - Clean up envMap parsing to use strings.Cut() - Expand tests, using testing.T.Run() with test names.
- Loading branch information
Showing
2 changed files
with
135 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,61 @@ | ||
package cmdutil | ||
|
||
import ( | ||
"fmt" | ||
"os" | ||
"strings" | ||
|
||
"github.com/expr-lang/expr" | ||
"github.com/expr-lang/expr/ast" | ||
"github.com/pkg/errors" | ||
) | ||
|
||
// AST walker which replaces `$IDENTIFIER` with `Env.IDENTIFIER` member lookup expressions. | ||
type EnvPatcher struct{} | ||
|
||
func (ep *EnvPatcher) Visit(node *ast.Node) { | ||
if n, ok := (*node).(*ast.IdentifierNode); ok && n.Value[0] == '$' && n.Value != "$env" { | ||
ast.Patch( | ||
node, | ||
&ast.MemberNode{ | ||
Node: &ast.IdentifierNode{Value: "Env"}, | ||
Property: &ast.StringNode{Value: n.Value[1:]}, | ||
}, | ||
) | ||
} | ||
} | ||
|
||
// The predefined variables of a when expression | ||
type WhenEnv struct { | ||
Env map[string]string | ||
} | ||
|
||
var NewWhenEnv = func() *WhenEnv { | ||
return &WhenEnv{Env: envMap()} | ||
} | ||
|
||
func IsAllowedToExecute(when string) (bool, error) { | ||
if when == "" { | ||
return true, nil | ||
} | ||
ropts := []string{} | ||
em := envMap() | ||
for k := range em { | ||
ropts = append(ropts, fmt.Sprintf("$%s", k), fmt.Sprintf("Env.%s", k)) | ||
} | ||
r := strings.NewReplacer(ropts...) | ||
when = r.Replace(when) | ||
got, err := expr.Eval(fmt.Sprintf("(%s) == true", when), struct { | ||
Env map[string]string | ||
}{ | ||
Env: em, | ||
}) | ||
|
||
// when expressions must produce a boolean result | ||
program, err := expr.Compile(when, expr.Patch(&EnvPatcher{}), expr.AsBool()) | ||
if err != nil { | ||
return false, errors.WithStack(err) | ||
} | ||
return got.(bool), nil | ||
if got, err := expr.Run(program, NewWhenEnv()); err != nil { | ||
return false, errors.WithStack(err) | ||
} else { | ||
return got.(bool), nil | ||
} | ||
} | ||
|
||
func envMap() map[string]string { | ||
m := map[string]string{} | ||
for _, kv := range os.Environ() { | ||
if !strings.Contains(kv, "=") { | ||
continue | ||
} | ||
parts := strings.SplitN(kv, "=", 2) | ||
k := parts[0] | ||
if len(parts) < 2 { | ||
m[k] = "" | ||
continue | ||
if k, v, ok := strings.Cut(kv, "="); ok { | ||
m[k] = v | ||
} | ||
m[k] = parts[1] | ||
} | ||
return m | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,128 @@ | ||
package cmdutil | ||
|
||
import ( | ||
"os" | ||
"strings" | ||
"testing" | ||
) | ||
|
||
func TestEnvMap(t *testing.T) { | ||
t.Setenv("TEST_ENV_EMPTY", "") | ||
t.Setenv("TEST_ENV_SET", "value") | ||
result := envMap() | ||
if value, ok := result["TEST_ENV_EMPTY"]; !ok { | ||
t.Error("Expected TEST_ENV_EMPTY to be set") | ||
} else if value != "" { | ||
t.Errorf("Expected TEST_ENV_EMPTY to be an empty string, got %v", value) | ||
} | ||
if value, ok := result["TEST_ENV_SET"]; !ok { | ||
t.Error("Expected TEST_ENV_SET to be set") | ||
} else if value != "value" { | ||
t.Errorf("Expected TEST_ENV_SET to be 'value', got %v", value) | ||
} | ||
} | ||
|
||
func TestIsAllowedToExecute(t *testing.T) { | ||
tests := []struct { | ||
envset map[string]string | ||
when string | ||
want bool | ||
name string | ||
envset map[string]string | ||
when string | ||
want bool | ||
errorContains any | ||
}{ | ||
{ | ||
name: "Empty expression", | ||
envset: map[string]string{}, | ||
when: "", | ||
want: true, | ||
errorContains: nil, | ||
}, | ||
{ | ||
name: "Equality test, true", | ||
envset: map[string]string{ | ||
"TEST_ENV1": "a", | ||
}, | ||
when: "$TEST_ENV1 == 'a'", | ||
want: true, | ||
errorContains: nil, | ||
}, | ||
{ | ||
name: "Equality test, false", | ||
envset: map[string]string{ | ||
"TEST_ENV1": "a", | ||
}, | ||
when: "$TEST_ENV1 == 'a'", | ||
want: true, | ||
when: "$TEST_ENV1 == 'b'", | ||
want: false, | ||
errorContains: nil, | ||
}, | ||
{ | ||
name: "Containment in $env", | ||
envset: map[string]string{ | ||
"env": "should not replace $env", | ||
"TEST_ENV1": "a", | ||
}, | ||
when: "$TEST_ENV1 == 'b'", | ||
want: false, | ||
when: `'TEST_ENV1' not in $env`, | ||
want: true, | ||
errorContains: nil, | ||
}, | ||
{ | ||
name: "Containment in Env", | ||
envset: map[string]string{ | ||
"TEST_ENV1": "a", | ||
}, | ||
when: `$TEST_ENV1 == "a"`, | ||
want: true, | ||
when: "'TEST_ENV1' in Env", | ||
want: true, | ||
errorContains: nil, | ||
}, | ||
{ | ||
name: "Env var name is used in string literal", | ||
envset: map[string]string{ | ||
"TEST_ENV1": "foo", | ||
"TEST_ENV2": "$TEST_ENV1", | ||
}, | ||
when: `$TEST_ENV2 == '$TEST_ENV1'`, | ||
want: true, | ||
errorContains: nil, | ||
}, | ||
{ | ||
name: "Env var not set", | ||
envset: map[string]string{}, | ||
when: `$TEST_ENV_NONESUCH == ""`, | ||
want: true, | ||
errorContains: nil, | ||
}, | ||
{ | ||
name: "Invalid expression", | ||
envset: map[string]string{}, | ||
when: `($TEST_ENV1 == "Missing parentheses"`, | ||
want: false, | ||
errorContains: "unexpected token EOF", | ||
}, | ||
{ | ||
name: "Expression produces a non-boolean result", | ||
envset: map[string]string{}, | ||
when: `"String literal expression"`, | ||
want: false, | ||
errorContains: "expected bool, but got string", | ||
}, | ||
} | ||
for _, tt := range tests { | ||
for k, v := range tt.envset { | ||
os.Setenv(k, v) | ||
} | ||
got, err := IsAllowedToExecute(tt.when) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
if got != tt.want { | ||
t.Errorf("got %v\nwant %v", got, tt.want) | ||
} | ||
t.Run(tt.name, func(t *testing.T) { | ||
NewWhenEnv = func() *WhenEnv { return &WhenEnv{Env: tt.envset} } | ||
got, err := IsAllowedToExecute(tt.when) | ||
if err != nil { | ||
if tt.errorContains != nil { | ||
if !strings.Contains(err.Error(), tt.errorContains.(string)) { | ||
t.Errorf("Error %v does not contain %s", err, tt.errorContains) | ||
} | ||
} else { | ||
t.Error(err) | ||
} | ||
} else if tt.errorContains != nil { | ||
t.Errorf("Expected an error containing %v", tt.errorContains) | ||
} | ||
if got != tt.want { | ||
t.Errorf("got %v\nwant %v", got, tt.want) | ||
} | ||
}) | ||
} | ||
} |