diff --git a/compiler/ast/dag/expr.go b/compiler/ast/dag/expr.go index a735d4133a..7fe4779afa 100644 --- a/compiler/ast/dag/expr.go +++ b/compiler/ast/dag/expr.go @@ -21,6 +21,11 @@ type ( Expr Expr `json:"expr"` Where Expr `json:"where"` } + ApplyExpr struct { + Kind string `json:"kind" unpack:""` + Expr Expr `json:"expr"` + Func string `json:"call"` + } ArrayExpr struct { Kind string `json:"kind" unpack:""` Elems []VectorElem `json:"elems"` @@ -113,6 +118,7 @@ type ( ) func (*Agg) ExprDAG() {} +func (*ApplyExpr) ExprDAG() {} func (*ArrayExpr) ExprDAG() {} func (*Assignment) ExprDAG() {} func (*BinaryExpr) ExprDAG() {} diff --git a/compiler/ast/dag/unpack.go b/compiler/ast/dag/unpack.go index 0a03034870..6ebdcf51fe 100644 --- a/compiler/ast/dag/unpack.go +++ b/compiler/ast/dag/unpack.go @@ -8,6 +8,7 @@ import ( var unpacker = unpack.New( Agg{}, + ApplyExpr{}, ArrayExpr{}, Assignment{}, BinaryExpr{}, diff --git a/compiler/kernel/expr.go b/compiler/kernel/expr.go index 80d5dae6d4..c03ef11450 100644 --- a/compiler/kernel/expr.go +++ b/compiler/kernel/expr.go @@ -91,6 +91,8 @@ func (b *Builder) compileExpr(e dag.Expr) (expr.Evaluator, error) { return nil, err } return expr.NewAggregatorExpr(agg), nil + case *dag.ApplyExpr: + return b.compileApplyExpr(e) case *dag.OverExpr: return b.compileOverExpr(e) default: @@ -335,13 +337,9 @@ func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) { var path field.Path // First check if call is to a user defined function, otherwise check for // builtin function. - fn, ok := b.funcs[call.Name] - if !ok { - var err error - fn, path, err = function.New(b.zctx(), call.Name, len(call.Args)) - if err != nil { - return nil, fmt.Errorf("%s(): %w", call.Name, err) - } + fn, err := b.lookupFunc(call.Name, len(call.Args)) + if err != nil { + return nil, fmt.Errorf("internal error %s(): %w", call.Name, err) } args := call.Args if path != nil { @@ -355,6 +353,29 @@ func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) { return expr.NewCall(b.zctx(), fn, exprs), nil } +func (b *Builder) lookupFunc(name string, nargs int) (expr.Function, error) { + fn, ok := b.funcs[name] + if !ok { + var err error + if fn, _, err = function.New(b.zctx(), name, nargs); err != nil { + return nil, err + } + } + return fn, nil +} + +func (b *Builder) compileApplyExpr(a *dag.ApplyExpr) (expr.Evaluator, error) { + e, err := b.compileExpr(a.Expr) + if err != nil { + return nil, err + } + fn, err := b.lookupFunc(a.Func, 1) + if err != nil { + return nil, err + } + return expr.NewApplyFunc(b.zctx(), e, fn), nil +} + func (b *Builder) compileExprs(in []dag.Expr) ([]expr.Evaluator, error) { var exprs []expr.Evaluator for _, e := range in { diff --git a/compiler/kernel/op.go b/compiler/kernel/op.go index e51e960209..08353da86a 100644 --- a/compiler/kernel/op.go +++ b/compiler/kernel/op.go @@ -51,6 +51,7 @@ type Builder struct { progress *zbuf.Progress deletes *sync.Map funcs map[string]expr.Function + funcArgs map[string]int } func NewBuilder(octx *op.Context, source *data.Source) *Builder { diff --git a/compiler/semantic/expr.go b/compiler/semantic/expr.go index 68c21341eb..dd88aad122 100644 --- a/compiler/semantic/expr.go +++ b/compiler/semantic/expr.go @@ -12,6 +12,7 @@ import ( "github.com/brimdata/zed/pkg/reglob" "github.com/brimdata/zed/runtime/expr" "github.com/brimdata/zed/runtime/expr/agg" + "github.com/brimdata/zed/runtime/expr/function" "github.com/brimdata/zed/zson" ) @@ -473,26 +474,53 @@ func (a *analyzer) semCall(call *ast.Call) (dag.Expr, error) { if err != nil { return nil, fmt.Errorf("%s: bad argument: %w", call.Name, err) } + err = a.validateFnCall(call.Name, len(exprs)) + if errors.Is(err, function.ErrNoSuchFunction) && call.Name == "apply" { + if len(call.Args) != 2 { + return nil, fmt.Errorf("apply(): expects 2 argument(s)") + } + callid, ok := call.Args[1].(*ast.ID) + if !ok { + return nil, fmt.Errorf("apply(): second argument must be the identifier of a func") + } + if err := a.validateFnCall(callid.Name, 1); err != nil { + return nil, fmt.Errorf("%s(): %w", callid.Name, err) + } + return &dag.ApplyExpr{ + Kind: "ApplyExpr", + Expr: exprs[0], + Func: callid.Name, + }, nil + } + if err != nil { + return nil, fmt.Errorf("%s(): %w", call.Name, err) + } + return &dag.Call{ + Kind: "Call", + Name: call.Name, + Args: exprs, + }, nil +} + +func (a *analyzer) validateFnCall(name string, nargs int) error { // Call could be to a user defined func. Check if we have a matching func in // scope. - e, err := a.scope.LookupExpr(call.Name) + e, err := a.scope.LookupExpr(name) if err != nil { - return nil, err + return err } if e != nil { f, ok := e.(*dag.Func) if !ok { - return nil, fmt.Errorf("%s(): definition is not a function type: %T", call.Name, e) + return fmt.Errorf("definition is not a function type: %T", e) } - if len(f.Params) != len(call.Args) { - return nil, fmt.Errorf("%s(): expects %d argument(s)", call.Name, len(f.Params)) + if len(f.Params) != nargs { + return fmt.Errorf("expects %d argument(s)", len(f.Params)) } + return nil } - return &dag.Call{ - Kind: "Call", - Name: call.Name, - Args: exprs, - }, nil + _, _, err = function.New(a.zctx, name, nargs) + return err } func (a *analyzer) semExprs(in []ast.Expr) ([]dag.Expr, error) { diff --git a/runtime/expr/apply.go b/runtime/expr/apply.go new file mode 100644 index 0000000000..e7fa26d50c --- /dev/null +++ b/runtime/expr/apply.go @@ -0,0 +1,66 @@ +package expr + +import ( + "github.com/brimdata/zed" + "github.com/brimdata/zed/zcode" +) + +type apply struct { + builder zcode.Builder + eval Evaluator + fn Function + zctx *zed.Context + + // vals is used to reduce allocations + vals []zed.Value + // types is used to reduce allocations + types []zed.Type +} + +func NewApplyFunc(zctx *zed.Context, e Evaluator, fn Function) Evaluator { + return &apply{eval: e, fn: fn, zctx: zctx} +} + +func (a *apply) Eval(ectx Context, in *zed.Value) *zed.Value { + v := a.eval.Eval(ectx, in) + if v.IsError() { + return v + } + elems, err := v.Elements() + if err != nil { + return ectx.CopyValue(*a.zctx.WrapError(err.Error(), in)) + } + if len(elems) == 0 { + return v + } + a.vals = a.vals[:0] + a.types = a.types[:0] + for _, elem := range elems { + out := a.fn.Call(ectx, []zed.Value{elem}) + a.vals = append(a.vals, *out) + a.types = append(a.types, out.Type) + } + inner := a.innerType(a.types) + a.builder.Reset() + if union, ok := inner.(*zed.TypeUnion); ok { + for _, val := range a.vals { + zed.BuildUnion(&a.builder, union.TagOf(val.Type), val.Bytes()) + } + } else { + for _, val := range a.vals { + a.builder.Append(val.Bytes()) + } + } + if _, ok := zed.TypeUnder(in.Type).(*zed.TypeSet); ok { + return ectx.NewValue(a.zctx.LookupTypeSet(inner), zed.NormalizeSet(a.builder.Bytes())) + } + return ectx.NewValue(a.zctx.LookupTypeArray(inner), a.builder.Bytes()) +} + +func (a *apply) innerType(types []zed.Type) zed.Type { + types = zed.UniqueTypes(types) + if len(types) == 1 { + return types[0] + } + return a.zctx.LookupTypeUnion(types) +} diff --git a/runtime/expr/ztests/apply.yaml b/runtime/expr/ztests/apply.yaml new file mode 100644 index 0000000000..e564fbdc6c --- /dev/null +++ b/runtime/expr/ztests/apply.yaml @@ -0,0 +1,7 @@ +script: | + echo '{a:["foo","bar","baz"]}' | zq -z 'a := apply(a,upper)' - + +outputs: + - name: stdout + data: | + {a:["FOO","BAR","BAZ"]}