diff --git a/mql/mql_test.go b/mql/mql_test.go index ec31064860..93c7150f16 100644 --- a/mql/mql_test.go +++ b/mql/mql_test.go @@ -220,6 +220,24 @@ func TestNullResources(t *testing.T) { }) } +func TestNamedFunctions(t *testing.T) { + x := testutils.InitTester(testutils.LinuxMock()) + x.TestSimple(t, []testutils.SimpleTest{ + { + Code: "muser.groups.where(group: group != empty).length", + ResultIndex: 0, Expectation: int64(1), + }, + { + Code: "muser.groups.where(_: _ != empty).length", + ResultIndex: 0, Expectation: int64(1), + }, + { + Code: "muser.dict.listInt.where(x: x == 2)", + ResultIndex: 0, Expectation: []any{int64(2)}, + }, + }) +} + func TestNullString(t *testing.T) { x := testutils.InitTester(testutils.LinuxMock()) x.TestSimple(t, []testutils.SimpleTest{ diff --git a/mqlc/builtin_array.go b/mqlc/builtin_array.go index bb81f42626..4261d3f8fe 100644 --- a/mqlc/builtin_array.go +++ b/mqlc/builtin_array.go @@ -26,11 +26,12 @@ func compileArrayWhere(c *compiler, typ types.Type, ref uint64, id string, call } arg := call.Function[0] + bindingName := "_" if arg.Name != "" { - return types.Nil, errors.New("called '" + id + "' with a named parameter, which is not supported") + bindingName = arg.Name } - refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref) + refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref, bindingName) if err != nil { return types.Nil, err } @@ -103,8 +104,12 @@ func compileArrayDuplicates(c *compiler, typ types.Type, ref uint64, id string, return types.Nil, errors.New("too many arguments when calling '" + id + "'") } else if call != nil && len(call.Function) == 1 { arg := call.Function[0] + bindingName := "_" + if arg.Name != "" { + bindingName = arg.Name + } - refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref) + refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref, bindingName) if err != nil { return types.Nil, err } @@ -457,11 +462,12 @@ func compileArrayMap(c *compiler, typ types.Type, ref uint64, id string, call *p } arg := call.Function[0] + bindingName := "_" if arg.Name != "" { - return types.Nil, errors.New("called '" + id + "' with a named parameter, which is not supported") + bindingName = arg.Name } - refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref) + refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, typ, ref, bindingName) if err != nil { return types.Nil, err } diff --git a/mqlc/builtin_map.go b/mqlc/builtin_map.go index 1ac2d7283e..095a716cd7 100644 --- a/mqlc/builtin_map.go +++ b/mqlc/builtin_map.go @@ -26,8 +26,9 @@ func compileDictWhere(c *compiler, typ types.Type, ref uint64, id string, call * } arg := call.Function[0] + bindingName := "_" if arg.Name != "" { - return types.Nil, errors.New("called '" + id + "' with a named parameter, which is not supported") + bindingName = arg.Name } keyType := types.Dict @@ -60,6 +61,12 @@ func compileDictWhere(c *compiler, typ types.Type, ref uint64, id string, call * // we want to make sure the `_` points to the value, which is useful when dealing // with arrays and the default in maps blockCompiler.Binding.ref = blockCompiler.tailRef() + if bindingName != "_" { + blockCompiler.vars.add(bindingName, variable{ + ref: blockCompiler.Binding.ref, + typ: valueType, + }) + } err := blockCompiler.compileExpressions([]*parser.Expression{arg.Value}) c.Result.Suggestions = append(c.Result.Suggestions, blockCompiler.Result.Suggestions...) @@ -467,8 +474,9 @@ func compileMapWhere(c *compiler, typ types.Type, ref uint64, id string, call *p } arg := call.Function[0] + bindingName := "_" if arg.Name != "" { - return types.Nil, errors.New("called '" + id + "' with a named parameter, which is not supported") + bindingName = arg.Name } keyType := typ.Key() @@ -501,6 +509,12 @@ func compileMapWhere(c *compiler, typ types.Type, ref uint64, id string, call *p // we want to make sure the `_` points to the value, which is useful when dealing // with arrays and the default in maps blockCompiler.Binding.ref = blockCompiler.tailRef() + if bindingName != "_" { + blockCompiler.vars.add(bindingName, variable{ + ref: blockCompiler.Binding.ref, + typ: valueType, + }) + } err := blockCompiler.compileExpressions([]*parser.Expression{arg.Value}) c.Result.Suggestions = append(c.Result.Suggestions, blockCompiler.Result.Suggestions...) diff --git a/mqlc/builtin_resource.go b/mqlc/builtin_resource.go index 81aff14c00..bd6a5abe81 100644 --- a/mqlc/builtin_resource.go +++ b/mqlc/builtin_resource.go @@ -134,22 +134,22 @@ func compileResourceWhere(c *compiler, typ types.Type, ref uint64, id string, ca if call == nil { return types.Nil, errors.New("missing filter argument for calling '" + id + "'") } - if len(call.Function) > 1 { - return types.Nil, errors.New("too many arguments when calling '" + id + "', only 1 is supported") - } - // if the where function is called without arguments, we don't have to do anything // so we just return the caller type as no additional step in the compiler is necessary if len(call.Function) == 0 { return typ, nil } + if len(call.Function) > 1 { + return types.Nil, errors.New("too many arguments when calling '" + id + "', only 1 is supported") + } arg := call.Function[0] + bindingName := "_" if arg.Name != "" { - return types.Nil, errors.New("called '" + id + "' function with a named parameter, which is not supported") + bindingName = arg.Name } - refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, types.Array(types.Type(resource.ListType)), ref) + refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, types.Array(types.Type(resource.ListType)), ref, bindingName) if err != nil { return types.Nil, err } @@ -209,11 +209,12 @@ func compileResourceMap(c *compiler, typ types.Type, ref uint64, id string, call } arg := call.Function[0] + bindingName := "_" if arg.Name != "" { - return types.Nil, errors.New("called '" + id + "' function with a named parameter, which is not supported") + bindingName = arg.Name } - refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, types.Array(types.Type(resource.ListType)), ref) + refs, err := c.blockExpressions([]*parser.Expression{arg.Value}, types.Array(types.Type(resource.ListType)), ref, bindingName) if err != nil { return types.Nil, err } diff --git a/mqlc/mqlc.go b/mqlc/mqlc.go index 268dfaa4d5..2a5c0e516b 100644 --- a/mqlc/mqlc.go +++ b/mqlc/mqlc.go @@ -335,7 +335,7 @@ func (c *compiler) compileBlock(expressions []*parser.Expression, typ types.Type } } - refs, err := c.blockExpressions(expressions, typ, bindingRef) + refs, err := c.blockExpressions(expressions, typ, bindingRef, "_") if err != nil { return types.Nil, err } @@ -601,7 +601,7 @@ type blockRefs struct { // evaluates the given expressions on a non-array resource (eg: no `[]int` nor `groups`) // and creates a function, whose reference is returned -func (c *compiler) blockOnResource(expressions []*parser.Expression, typ types.Type, binding uint64) (blockRefs, error) { +func (c *compiler) blockOnResource(expressions []*parser.Expression, typ types.Type, binding uint64, bindingName string) (blockRefs, error) { blockCompiler := c.newBlockCompiler(nil) blockCompiler.block.AddArgumentPlaceholder(blockCompiler.Result.CodeV2, blockCompiler.blockRef, typ, blockCompiler.Result.CodeV2.Checksums[binding]) @@ -612,7 +612,7 @@ func (c *compiler) blockOnResource(expressions []*parser.Expression, typ types.T blockCompiler.standalone = false }, } - blockCompiler.vars.add("_", v) + blockCompiler.vars.add(bindingName, v) blockCompiler.Binding = &v err := blockCompiler.compileExpressions(expressions) @@ -643,7 +643,7 @@ func (c *compiler) blockOnResource(expressions []*parser.Expression, typ types.T ref: blockCompiler.blockRef | 1, typ: nuTyp, } - blockCompiler.vars.add("_", v) + blockCompiler.vars.add(bindingName, v) blockCompiler.Binding = &v retryErr := blockCompiler.compileExpressions(expressions) if retryErr != nil { @@ -681,13 +681,13 @@ func (c *compiler) blockOnResource(expressions []*parser.Expression, typ types.T // blockExpressions evaluates the given expressions as if called by a block and // returns the compiled function reference -func (c *compiler) blockExpressions(expressions []*parser.Expression, typ types.Type, binding uint64) (blockRefs, error) { +func (c *compiler) blockExpressions(expressions []*parser.Expression, typ types.Type, binding uint64, bindingName string) (blockRefs, error) { if len(expressions) == 0 { return blockRefs{}, nil } if typ.IsArray() { - return c.blockOnResource(expressions, typ.Child(), binding) + return c.blockOnResource(expressions, typ.Child(), binding, bindingName) } // when calling a block {} on an array resource, we expand it to all its list @@ -708,7 +708,7 @@ func (c *compiler) blockExpressions(expressions []*parser.Expression, typ types. } } - return c.blockOnResource(expressions, typ, binding) + return c.blockOnResource(expressions, typ, binding, bindingName) } // Returns the singular return type of the given block. @@ -1279,6 +1279,10 @@ func (c *compiler) compileIdentifier(id string, callBinding *variable, calls []* variable, ok := c.vars.lookup(id) if ok { + if variable.name == "" { + c.standalone = false + } + if variable.callback != nil { variable.callback() } @@ -2012,7 +2016,7 @@ func (c *compiler) expandResourceFields(chunk *llx.Chunk, typ types.Type, ref ui return false } - refs, err := c.blockOnResource(ast.Expressions, types.Resource(info.Name), ref) + refs, err := c.blockOnResource(ast.Expressions, types.Resource(info.Name), ref, "_") if err != nil { log.Error().Err(err).Msg("failed to compile default for " + info.Name) } diff --git a/providers-sdk/v1/testutils/mockprovider/resources/all.go b/providers-sdk/v1/testutils/mockprovider/resources/all.go index fb8dc8a166..04026d554c 100644 --- a/providers-sdk/v1/testutils/mockprovider/resources/all.go +++ b/providers-sdk/v1/testutils/mockprovider/resources/all.go @@ -51,7 +51,7 @@ func (c *mqlMgroup) id() (string, error) { func (c *mqlMuser) dict() (any, error) { return map[string]any{ - "a1": []any{int64(1), int64(2), int64(3)}, - "s1": "hello world", + "listInt": []any{int64(1), int64(2), int64(3)}, + "string": "hello world", }, nil }