Skip to content

Commit

Permalink
MERGE AFTER v9.14 -- ⭐ support named arguments in functions (#2964)
Browse files Browse the repository at this point in the history
* ⭐ support named arguments in functions

This is an exciting new v10 feature that allows users to set named arguments in functions.

This means that you can do things like this:

```coffee
users.all(user:
  groups.contains(group:
    user.uid == group.gid
  )
)
```

This handles all cases where you felt you had to assign `_` to a
variable. This makes it easier in all cases where you could only use one
expression (like with `all`, `one`, etc).

You can still rely on using `_` if no variable name is defined. This
remains the default.

Signed-off-by: Dominik Richter <[email protected]>

* 🐛 support named functions in dicts

Signed-off-by: Dominik Richter <[email protected]>

---------

Signed-off-by: Dominik Richter <[email protected]>
  • Loading branch information
arlimus authored Jan 7, 2024
1 parent 6866bf1 commit 1a6763d
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 25 deletions.
18 changes: 18 additions & 0 deletions mql/mql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,24 @@ func TestNullResources(t *testing.T) {
})
}

func TestNamedFunctions(t *testing.T) {
x := testutils.InitTester(testutils.LinuxMock())
x.TestSimple(t, []testutils.SimpleTest{
{
Code: "muser.groups.where(group: group != empty).length",
ResultIndex: 0, Expectation: int64(1),
},
{
Code: "muser.groups.where(_: _ != empty).length",
ResultIndex: 0, Expectation: int64(1),
},
{
Code: "muser.dict.listInt.where(x: x == 2)",
ResultIndex: 0, Expectation: []any{int64(2)},
},
})
}

func TestNullString(t *testing.T) {
x := testutils.InitTester(testutils.LinuxMock())
x.TestSimple(t, []testutils.SimpleTest{
Expand Down
16 changes: 11 additions & 5 deletions mqlc/builtin_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ func compileArrayWhere(c *compiler, typ types.Type, ref uint64, id string, call
}

arg := call.Function[0]
bindingName := "_"
if arg.Name != "" {
return types.Nil, errors.New("called '" + id + "' with a named parameter, which is not supported")
bindingName = arg.Name
}

refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref)
refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref, bindingName)
if err != nil {
return types.Nil, err
}
Expand Down Expand Up @@ -103,8 +104,12 @@ func compileArrayDuplicates(c *compiler, typ types.Type, ref uint64, id string,
return types.Nil, errors.New("too many arguments when calling '" + id + "'")
} else if call != nil && len(call.Function) == 1 {
arg := call.Function[0]
bindingName := "_"
if arg.Name != "" {
bindingName = arg.Name
}

refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref)
refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref, bindingName)
if err != nil {
return types.Nil, err
}
Expand Down Expand Up @@ -457,11 +462,12 @@ func compileArrayMap(c *compiler, typ types.Type, ref uint64, id string, call *p
}

arg := call.Function[0]
bindingName := "_"
if arg.Name != "" {
return types.Nil, errors.New("called '" + id + "' with a named parameter, which is not supported")
bindingName = arg.Name
}

refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref)
refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref, bindingName)
if err != nil {
return types.Nil, err
}
Expand Down
18 changes: 16 additions & 2 deletions mqlc/builtin_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ func compileDictWhere(c *compiler, typ types.Type, ref uint64, id string, call *
}

arg := call.Function[0]
bindingName := "_"
if arg.Name != "" {
return types.Nil, errors.New("called '" + id + "' with a named parameter, which is not supported")
bindingName = arg.Name
}

keyType := types.Dict
Expand Down Expand Up @@ -60,6 +61,12 @@ func compileDictWhere(c *compiler, typ types.Type, ref uint64, id string, call *
// we want to make sure the `_` points to the value, which is useful when dealing
// with arrays and the default in maps
blockCompiler.Binding.ref = blockCompiler.tailRef()
if bindingName != "_" {
blockCompiler.vars.add(bindingName, variable{
ref: blockCompiler.Binding.ref,
typ: valueType,
})
}

err := blockCompiler.compileExpressions([]*parser.Expression{arg.Value})
c.Result.Suggestions = append(c.Result.Suggestions, blockCompiler.Result.Suggestions...)
Expand Down Expand Up @@ -467,8 +474,9 @@ func compileMapWhere(c *compiler, typ types.Type, ref uint64, id string, call *p
}

arg := call.Function[0]
bindingName := "_"
if arg.Name != "" {
return types.Nil, errors.New("called '" + id + "' with a named parameter, which is not supported")
bindingName = arg.Name
}

keyType := typ.Key()
Expand Down Expand Up @@ -501,6 +509,12 @@ func compileMapWhere(c *compiler, typ types.Type, ref uint64, id string, call *p
// we want to make sure the `_` points to the value, which is useful when dealing
// with arrays and the default in maps
blockCompiler.Binding.ref = blockCompiler.tailRef()
if bindingName != "_" {
blockCompiler.vars.add(bindingName, variable{
ref: blockCompiler.Binding.ref,
typ: valueType,
})
}

err := blockCompiler.compileExpressions([]*parser.Expression{arg.Value})
c.Result.Suggestions = append(c.Result.Suggestions, blockCompiler.Result.Suggestions...)
Expand Down
17 changes: 9 additions & 8 deletions mqlc/builtin_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,22 +134,22 @@ func compileResourceWhere(c *compiler, typ types.Type, ref uint64, id string, ca
if call == nil {
return types.Nil, errors.New("missing filter argument for calling '" + id + "'")
}
if len(call.Function) > 1 {
return types.Nil, errors.New("too many arguments when calling '" + id + "', only 1 is supported")
}

// if the where function is called without arguments, we don't have to do anything
// so we just return the caller type as no additional step in the compiler is necessary
if len(call.Function) == 0 {
return typ, nil
}
if len(call.Function) > 1 {
return types.Nil, errors.New("too many arguments when calling '" + id + "', only 1 is supported")
}

arg := call.Function[0]
bindingName := "_"
if arg.Name != "" {
return types.Nil, errors.New("called '" + id + "' function with a named parameter, which is not supported")
bindingName = arg.Name
}

refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, types.Array(types.Type(resource.ListType)), ref)
refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, types.Array(types.Type(resource.ListType)), ref, bindingName)
if err != nil {
return types.Nil, err
}
Expand Down Expand Up @@ -209,11 +209,12 @@ func compileResourceMap(c *compiler, typ types.Type, ref uint64, id string, call
}

arg := call.Function[0]
bindingName := "_"
if arg.Name != "" {
return types.Nil, errors.New("called '" + id + "' function with a named parameter, which is not supported")
bindingName = arg.Name
}

refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, types.Array(types.Type(resource.ListType)), ref)
refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, types.Array(types.Type(resource.ListType)), ref, bindingName)
if err != nil {
return types.Nil, err
}
Expand Down
20 changes: 12 additions & 8 deletions mqlc/mqlc.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ func (c *compiler) compileBlock(expressions []*parser.Expression, typ types.Type
}
}

refs, err := c.blockExpressions(expressions, typ, bindingRef)
refs, err := c.blockExpressions(expressions, typ, bindingRef, "_")
if err != nil {
return types.Nil, err
}
Expand Down Expand Up @@ -601,7 +601,7 @@ type blockRefs struct {

// evaluates the given expressions on a non-array resource (eg: no `[]int` nor `groups`)
// and creates a function, whose reference is returned
func (c *compiler) blockOnResource(expressions []*parser.Expression, typ types.Type, binding uint64) (blockRefs, error) {
func (c *compiler) blockOnResource(expressions []*parser.Expression, typ types.Type, binding uint64, bindingName string) (blockRefs, error) {
blockCompiler := c.newBlockCompiler(nil)
blockCompiler.block.AddArgumentPlaceholder(blockCompiler.Result.CodeV2,
blockCompiler.blockRef, typ, blockCompiler.Result.CodeV2.Checksums[binding])
Expand All @@ -612,7 +612,7 @@ func (c *compiler) blockOnResource(expressions []*parser.Expression, typ types.T
blockCompiler.standalone = false
},
}
blockCompiler.vars.add("_", v)
blockCompiler.vars.add(bindingName, v)
blockCompiler.Binding = &v

err := blockCompiler.compileExpressions(expressions)
Expand Down Expand Up @@ -643,7 +643,7 @@ func (c *compiler) blockOnResource(expressions []*parser.Expression, typ types.T
ref: blockCompiler.blockRef | 1,
typ: nuTyp,
}
blockCompiler.vars.add("_", v)
blockCompiler.vars.add(bindingName, v)
blockCompiler.Binding = &v
retryErr := blockCompiler.compileExpressions(expressions)
if retryErr != nil {
Expand Down Expand Up @@ -681,13 +681,13 @@ func (c *compiler) blockOnResource(expressions []*parser.Expression, typ types.T

// blockExpressions evaluates the given expressions as if called by a block and
// returns the compiled function reference
func (c *compiler) blockExpressions(expressions []*parser.Expression, typ types.Type, binding uint64) (blockRefs, error) {
func (c *compiler) blockExpressions(expressions []*parser.Expression, typ types.Type, binding uint64, bindingName string) (blockRefs, error) {
if len(expressions) == 0 {
return blockRefs{}, nil
}

if typ.IsArray() {
return c.blockOnResource(expressions, typ.Child(), binding)
return c.blockOnResource(expressions, typ.Child(), binding, bindingName)
}

// when calling a block {} on an array resource, we expand it to all its list
Expand All @@ -708,7 +708,7 @@ func (c *compiler) blockExpressions(expressions []*parser.Expression, typ types.
}
}

return c.blockOnResource(expressions, typ, binding)
return c.blockOnResource(expressions, typ, binding, bindingName)
}

// Returns the singular return type of the given block.
Expand Down Expand Up @@ -1279,6 +1279,10 @@ func (c *compiler) compileIdentifier(id string, callBinding *variable, calls []*

variable, ok := c.vars.lookup(id)
if ok {
if variable.name == "" {
c.standalone = false
}

if variable.callback != nil {
variable.callback()
}
Expand Down Expand Up @@ -2012,7 +2016,7 @@ func (c *compiler) expandResourceFields(chunk *llx.Chunk, typ types.Type, ref ui
return false
}

refs, err := c.blockOnResource(ast.Expressions, types.Resource(info.Name), ref)
refs, err := c.blockOnResource(ast.Expressions, types.Resource(info.Name), ref, "_")
if err != nil {
log.Error().Err(err).Msg("failed to compile default for " + info.Name)
}
Expand Down
4 changes: 2 additions & 2 deletions providers-sdk/v1/testutils/mockprovider/resources/all.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (c *mqlMgroup) id() (string, error) {

func (c *mqlMuser) dict() (any, error) {
return map[string]any{
"a1": []any{int64(1), int64(2), int64(3)},
"s1": "hello world",
"listInt": []any{int64(1), int64(2), int64(3)},
"string": "hello world",
}, nil
}

0 comments on commit 1a6763d

Please sign in to comment.