Skip to content

Commit

Permalink
Add apply function
Browse files Browse the repository at this point in the history
This is a function that applies a function to every element in an array
or set value.
  • Loading branch information
mattnibs committed Oct 11, 2023
1 parent da25095 commit 1e84ef7
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 17 deletions.
6 changes: 6 additions & 0 deletions compiler/ast/dag/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ type (
Expr Expr `json:"expr"`
Where Expr `json:"where"`
}
// ApplyExpr is the DAG node for applying the function named by Func to
// each element of the value produced by Expr. Note that Func is
// serialized under the JSON key "call".
ApplyExpr struct {
Kind string `json:"kind" unpack:""`
Expr Expr `json:"expr"`
Func string `json:"call"`
}
ArrayExpr struct {
Kind string `json:"kind" unpack:""`
Elems []VectorElem `json:"elems"`
Expand Down Expand Up @@ -113,6 +118,7 @@ type (
)

// Each expression node type declares an empty ExprDAG method.
// NOTE(review): presumably this is the marker method required by the Expr
// interface; confirm against the interface definition.
func (*Agg) ExprDAG() {}
func (*ApplyExpr) ExprDAG() {}
func (*ArrayExpr) ExprDAG() {}
func (*Assignment) ExprDAG() {}
func (*BinaryExpr) ExprDAG() {}
Expand Down
1 change: 1 addition & 0 deletions compiler/ast/dag/unpack.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

var unpacker = unpack.New(
Agg{},
ApplyExpr{},
ArrayExpr{},
Assignment{},
BinaryExpr{},
Expand Down
35 changes: 28 additions & 7 deletions compiler/kernel/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ func (b *Builder) compileExpr(e dag.Expr) (expr.Evaluator, error) {
return nil, err
}
return expr.NewAggregatorExpr(agg), nil
case *dag.ApplyExpr:
return b.compileApplyExpr(e)
case *dag.OverExpr:
return b.compileOverExpr(e)
default:
Expand Down Expand Up @@ -335,13 +337,9 @@ func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) {
var path field.Path
// First check if call is to a user defined function, otherwise check for
// builtin function.
fn, ok := b.funcs[call.Name]
if !ok {
var err error
fn, path, err = function.New(b.zctx(), call.Name, len(call.Args))
if err != nil {
return nil, fmt.Errorf("%s(): %w", call.Name, err)
}
fn, err := b.lookupFunc(call.Name, len(call.Args))
if err != nil {
return nil, fmt.Errorf("internal error %s(): %w", call.Name, err)
}
args := call.Args
if path != nil {
Expand All @@ -355,6 +353,29 @@ func (b *Builder) compileCall(call dag.Call) (expr.Evaluator, error) {
return expr.NewCall(b.zctx(), fn, exprs), nil
}

// lookupFunc resolves name to a callable function. Functions registered in
// b.funcs take precedence; otherwise the name is resolved through
// function.New for a call with nargs arguments. The path value returned by
// function.New is discarded.
func (b *Builder) lookupFunc(name string, nargs int) (expr.Function, error) {
	if registered, found := b.funcs[name]; found {
		return registered, nil
	}
	resolved, _, err := function.New(b.zctx(), name, nargs)
	if err != nil {
		return nil, err
	}
	return resolved, nil
}

// compileApplyExpr compiles an ApplyExpr node into an evaluator that calls
// the function named by a.Func on each element produced by a.Expr. The
// container expression is compiled before the function is resolved, so an
// error in the expression is reported first.
func (b *Builder) compileApplyExpr(a *dag.ApplyExpr) (expr.Evaluator, error) {
	container, err := b.compileExpr(a.Expr)
	if err != nil {
		return nil, err
	}
	// The applied function is always called with exactly one argument:
	// the current element.
	const nargs = 1
	elemFn, err := b.lookupFunc(a.Func, nargs)
	if err != nil {
		return nil, err
	}
	return expr.NewApplyFunc(b.zctx(), container, elemFn), nil
}

func (b *Builder) compileExprs(in []dag.Expr) ([]expr.Evaluator, error) {
var exprs []expr.Evaluator
for _, e := range in {
Expand Down
1 change: 1 addition & 0 deletions compiler/kernel/op.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type Builder struct {
progress *zbuf.Progress
deletes *sync.Map
funcs map[string]expr.Function
funcArgs map[string]int
}

func NewBuilder(octx *op.Context, source *data.Source) *Builder {
Expand Down
48 changes: 38 additions & 10 deletions compiler/semantic/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/brimdata/zed/pkg/reglob"
"github.com/brimdata/zed/runtime/expr"
"github.com/brimdata/zed/runtime/expr/agg"
"github.com/brimdata/zed/runtime/expr/function"
"github.com/brimdata/zed/zson"
)

Expand Down Expand Up @@ -473,26 +474,53 @@ func (a *analyzer) semCall(call *ast.Call) (dag.Expr, error) {
if err != nil {
return nil, fmt.Errorf("%s: bad argument: %w", call.Name, err)
}
err = a.validateFnCall(call.Name, len(exprs))
if errors.Is(err, function.ErrNoSuchFunction) && call.Name == "apply" {
if len(call.Args) != 2 {
return nil, fmt.Errorf("apply(): expects 2 argument(s)")
}
callid, ok := call.Args[1].(*ast.ID)
if !ok {
return nil, fmt.Errorf("apply(): second argument must be the identifier of a func")
}
if err := a.validateFnCall(callid.Name, 1); err != nil {
return nil, fmt.Errorf("%s(): %w", callid.Name, err)
}
return &dag.ApplyExpr{
Kind: "ApplyExpr",
Expr: exprs[0],
Func: callid.Name,
}, nil
}
if err != nil {
return nil, fmt.Errorf("%s(): %w", call.Name, err)
}
return &dag.Call{
Kind: "Call",
Name: call.Name,
Args: exprs,
}, nil
}

func (a *analyzer) validateFnCall(name string, nargs int) error {
// Call could be to a user defined func. Check if we have a matching func in
// scope.
e, err := a.scope.LookupExpr(call.Name)
e, err := a.scope.LookupExpr(name)
if err != nil {
return nil, err
return err
}
if e != nil {
f, ok := e.(*dag.Func)
if !ok {
return nil, fmt.Errorf("%s(): definition is not a function type: %T", call.Name, e)
return fmt.Errorf("definition is not a function type: %T", e)
}
if len(f.Params) != len(call.Args) {
return nil, fmt.Errorf("%s(): expects %d argument(s)", call.Name, len(f.Params))
if len(f.Params) != nargs {
return fmt.Errorf("expects %d argument(s)", len(f.Params))
}
return nil
}
return &dag.Call{
Kind: "Call",
Name: call.Name,
Args: exprs,
}, nil
_, _, err = function.New(a.zctx, name, nargs)
return err
}

func (a *analyzer) semExprs(in []ast.Expr) ([]dag.Expr, error) {
Expand Down
66 changes: 66 additions & 0 deletions runtime/expr/apply.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package expr

import (
"github.com/brimdata/zed"
"github.com/brimdata/zed/zcode"
)

type apply struct {
builder zcode.Builder
eval Evaluator
fn Function
zctx *zed.Context

// vals is used to reduce allocations
vals []zed.Value
// types is used to reduce allocations
types []zed.Type
}

// NewApplyFunc returns an Evaluator that evaluates e and calls fn on each
// element of the resulting value, using zctx to look up result types and
// to wrap errors.
func NewApplyFunc(zctx *zed.Context, e Evaluator, fn Function) Evaluator {
	a := new(apply)
	a.zctx = zctx
	a.eval = e
	a.fn = fn
	return a
}

// Eval evaluates the wrapped expression against in and calls fn on each
// element of the resulting container, returning a new container holding
// the results.
//
// An error value produced by the wrapped expression is returned unchanged.
// If the elements of the evaluated value cannot be enumerated, an error
// wrapping in is returned. An empty container is returned as is. When the
// results have more than one distinct type, the element type of the output
// is a union of those types. The output is a set when the evaluated
// container is a set and an array otherwise.
func (a *apply) Eval(ectx Context, in *zed.Value) *zed.Value {
	v := a.eval.Eval(ectx, in)
	if v.IsError() {
		return v
	}
	elems, err := v.Elements()
	if err != nil {
		return ectx.CopyValue(*a.zctx.WrapError(err.Error(), in))
	}
	if len(elems) == 0 {
		return v
	}
	// Reuse the scratch slices from the previous call.
	a.vals = a.vals[:0]
	a.types = a.types[:0]
	for _, elem := range elems {
		out := a.fn.Call(ectx, []zed.Value{elem})
		a.vals = append(a.vals, *out)
		a.types = append(a.types, out.Type)
	}
	inner := a.innerType(a.types)
	a.builder.Reset()
	if union, ok := inner.(*zed.TypeUnion); ok {
		// Heterogeneous results: tag each value with its union member.
		for _, val := range a.vals {
			zed.BuildUnion(&a.builder, union.TagOf(val.Type), val.Bytes())
		}
	} else {
		for _, val := range a.vals {
			a.builder.Append(val.Bytes())
		}
	}
	// The kind of the result follows the evaluated container v, not the
	// enclosing input value in. With in, apply(a, f) on a record whose
	// field a is a set would be built as an array and never normalized.
	if _, ok := zed.TypeUnder(v.Type).(*zed.TypeSet); ok {
		return ectx.NewValue(a.zctx.LookupTypeSet(inner), zed.NormalizeSet(a.builder.Bytes()))
	}
	return ectx.NewValue(a.zctx.LookupTypeArray(inner), a.builder.Bytes())
}

// innerType returns the element type for a container holding values of the
// given types: the single type when they all agree after deduplication,
// otherwise a union of the distinct types.
func (a *apply) innerType(types []zed.Type) zed.Type {
	distinct := zed.UniqueTypes(types)
	if len(distinct) != 1 {
		return a.zctx.LookupTypeUnion(distinct)
	}
	return distinct[0]
}
7 changes: 7 additions & 0 deletions runtime/expr/ztests/apply.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
script: |
echo '{a:["foo","bar","baz"]}' | zq -z 'a := apply(a,upper)' -
outputs:
- name: stdout
data: |
{a:["FOO","BAR","BAZ"]}

0 comments on commit 1e84ef7

Please sign in to comment.