From bac0723538abc2f5e402bf1e6382061f3b0c5805 Mon Sep 17 00:00:00 2001
From: Matthew Nibecker
Date: Thu, 21 Nov 2024 10:38:17 -0500
Subject: [PATCH] vam: Add support for conditionals (case exprs)

---
 compiler/kernel/vexpr.go                      |  22 ++-
 runtime/vam/expr/conditional.go               | 148 ++++++++++++++++++
 runtime/vam/op/filter.go                      |  15 +--
 .../ztests => runtime/ztests/expr}/case.yaml  |   2 +
 4 files changed, 175 insertions(+), 12 deletions(-)
 create mode 100644 runtime/vam/expr/conditional.go
 rename {compiler/parser/ztests => runtime/ztests/expr}/case.yaml (94%)

diff --git a/compiler/kernel/vexpr.go b/compiler/kernel/vexpr.go
index ea61b44f69..7b39a55681 100644
--- a/compiler/kernel/vexpr.go
+++ b/compiler/kernel/vexpr.go
@@ -41,8 +41,8 @@ func (b *Builder) compileVamExpr(e dag.Expr) (vamexpr.Evaluator, error) {
 		return b.compileVamUnary(*e)
 	case *dag.BinaryExpr:
 		return b.compileVamBinary(e)
-	//case *dag.Conditional:
-	//	return b.compileVamConditional(*e)
+	case *dag.Conditional:
+		return b.compileVamConditional(*e)
 	case *dag.Call:
 		return b.compileVamCall(e)
 	//case *dag.RegexpMatch:
@@ -111,6 +111,24 @@ func (b *Builder) compileVamBinary(e *dag.BinaryExpr) (vamexpr.Evaluator, error)
 	}
 }
 
+// compileVamConditional compiles the Cond, Then, and Else expressions of node
+// and combines them into a single conditional evaluator.
+func (b *Builder) compileVamConditional(node dag.Conditional) (vamexpr.Evaluator, error) {
+	predicate, err := b.compileVamExpr(node.Cond)
+	if err != nil {
+		return nil, err
+	}
+	thenExpr, err := b.compileVamExpr(node.Then)
+	if err != nil {
+		return nil, err
+	}
+	elseExpr, err := b.compileVamExpr(node.Else)
+	if err != nil {
+		return nil, err
+	}
+	return vamexpr.NewConditional(b.zctx(), predicate, thenExpr, elseExpr), nil
+}
+
 func (b *Builder) compileVamUnary(unary dag.UnaryExpr) (vamexpr.Evaluator, error) {
 	e, err := b.compileVamExpr(unary.Operand)
 	if err != nil {
diff --git a/runtime/vam/expr/conditional.go b/runtime/vam/expr/conditional.go
new file mode 100644
index 0000000000..80ce9b8ad8
--- /dev/null
+++ b/runtime/vam/expr/conditional.go
@@ -0,0 +1,148 @@
+package expr
+
+import (
+	"github.com/RoaringBitmap/roaring"
+	"github.com/brimdata/super"
+	"github.com/brimdata/super/vector"
+)
+
+// conditional is an Evaluator that chooses between thenExpr and elseExpr
+// for each element according to the result of predicate.
+type conditional struct {
+	zctx      *super.Context
+	predicate Evaluator
+	thenExpr  Evaluator
+	elseExpr  Evaluator
+}
+
+// NewConditional returns an Evaluator yielding thenExpr where predicate is
+// true, a wrapped error where predicate is an error, and elseExpr for all
+// other elements.
+func NewConditional(zctx *super.Context, predicate, thenExpr, elseExpr Evaluator) Evaluator {
+	return &conditional{
+		zctx:      zctx,
+		predicate: predicate,
+		thenExpr:  thenExpr,
+		elseExpr:  elseExpr,
+	}
+}
+
+// Eval evaluates the predicate over this.  If every element is an error,
+// every element is true, or no element is true or an error, then the
+// predicate error, thenExpr, or elseExpr result for the whole input is
+// returned.  Otherwise thenExpr and elseExpr are each evaluated over a view
+// of their own elements and combined into a dynamic vector tagged 0 for
+// then, 1 for else, and 2 for predicate errors.
+func (c *conditional) Eval(this vector.Any) vector.Any {
+	predVec := c.predicate.Eval(this)
+	boolsMap, errsMap := BoolMask(predVec)
+	if errsMap.GetCardinality() == uint64(this.Len()) {
+		return c.predicateError(predVec)
+	}
+	if boolsMap.GetCardinality() == uint64(this.Len()) {
+		return c.thenExpr.Eval(this)
+	}
+	if boolsMap.IsEmpty() && errsMap.IsEmpty() {
+		return c.elseExpr.Eval(this)
+	}
+	thenVec := c.thenExpr.Eval(vector.NewView(this, boolsMap.ToArray()))
+	// elseMap is the complement of the union of boolsMap and errsMap.
+	elseMap := roaring.Or(boolsMap, errsMap)
+	elseMap.Flip(0, uint64(this.Len()))
+	elseIndex := elseMap.ToArray()
+	elseVec := c.elseExpr.Eval(vector.NewView(this, elseIndex))
+	tags := make([]uint32, this.Len())
+	for _, idx := range elseIndex {
+		tags[idx] = 1
+	}
+	vecs := []vector.Any{thenVec, elseVec}
+	if !errsMap.IsEmpty() {
+		errsIndex := errsMap.ToArray()
+		for _, idx := range errsIndex {
+			tags[idx] = 2
+		}
+		vecs = append(vecs, c.predicateError(vector.NewView(predVec, errsIndex)))
+	}
+	return vector.NewDynamic(tags, vecs)
+}
+
+// predicateError wraps vec in a "?-operator: bool predicate required" error
+// by way of vector.Apply.
+func (c *conditional) predicateError(vec vector.Any) vector.Any {
+	return vector.Apply(false, func(vecs ...vector.Any) vector.Any {
+		return vector.NewWrappedError(c.zctx, "?-operator: bool predicate required", vecs[0])
+	}, vec)
+}
+
+// BoolMask returns two bitmaps of element indexes for mask: the elements that
+// are true and the elements that belong to an error vector.  A dynamic mask
+// is processed one value vector at a time, using TagMap.Reverse to locate
+// each element.  It panics on any vector type other than Const, Bool, or Error.
+func BoolMask(mask vector.Any) (*roaring.Bitmap, *roaring.Bitmap) {
+	bools := roaring.New()
+	errs := roaring.New()
+	if dynamic, ok := mask.(*vector.Dynamic); ok {
+		for i, val := range dynamic.Values {
+			boolMaskRidx(dynamic.TagMap.Reverse[i], bools, errs, val)
+		}
+	} else {
+		boolMaskRidx(nil, bools, errs, mask)
+	}
+	return bools, errs
+}
+
+// boolMaskRidx adds the indexes of the true elements of vec to bools or, if
+// vec is an error vector, the indexes of all of its elements to errs.  When
+// ridx is non-nil, element i of vec is recorded as ridx[i]; otherwise it is
+// recorded as i.  Elements of a Const that its Nulls flags are skipped.
+func boolMaskRidx(ridx []uint32, bools, errs *roaring.Bitmap, vec vector.Any) {
+	switch vec := vec.(type) {
+	case *vector.Const:
+		if !vec.Value().Ptr().AsBool() {
+			return
+		}
+		if vec.Nulls != nil {
+			if ridx != nil {
+				for i, idx := range ridx {
+					if !vec.Nulls.Value(uint32(i)) {
+						bools.Add(idx)
+					}
+				}
+			} else {
+				for i := range vec.Len() {
+					if !vec.Nulls.Value(i) {
+						bools.Add(i)
+					}
+				}
+			}
+		} else {
+			if ridx != nil {
+				bools.AddMany(ridx)
+			} else {
+				bools.AddRange(0, uint64(vec.Len()))
+			}
+		}
+	case *vector.Bool:
+		if ridx != nil {
+			for i, idx := range ridx {
+				if vec.Value(uint32(i)) {
+					bools.Add(idx)
+				}
+			}
+		} else {
+			for i := range vec.Len() {
+				if vec.Value(i) {
+					bools.Add(i)
+				}
+			}
+		}
+	case *vector.Error:
+		if ridx != nil {
+			errs.AddMany(ridx)
+		} else {
+			errs.AddRange(0, uint64(vec.Len()))
+		}
+	default:
+		panic(vec)
+	}
+}
diff --git a/runtime/vam/op/filter.go b/runtime/vam/op/filter.go
index bc9a8a3d30..c4b23460e7 100644
--- a/runtime/vam/op/filter.go
+++ b/runtime/vam/op/filter.go
@@ -31,18 +31,13 @@ func (f *Filter) Pull(done bool) (vector.Any, error) {
 // applyMask applies the mask vector mask to vec. Elements of mask that are not
 // Boolean are considered false.
 func applyMask(vec, mask vector.Any) (vector.Any, bool) {
-	n := mask.Len()
-	var index []uint32
-	for k := uint32(0); k < n; k++ {
-		if vector.BoolValue(mask, k) {
-			index = append(index, k)
-		}
-	}
-	if len(index) == 0 {
+	// Error elements of mask are ignored for filters (treated as false).
+	b, _ := expr.BoolMask(mask)
+	if b.IsEmpty() {
 		return nil, false
 	}
-	if len(index) == int(n) {
+	if b.GetCardinality() == uint64(mask.Len()) {
 		return vec, true
 	}
-	return vector.NewView(vec, index), true
+	return vector.NewView(vec, b.ToArray()), true
 }
diff --git a/compiler/parser/ztests/case.yaml b/runtime/ztests/expr/case.yaml
similarity index 94%
rename from compiler/parser/ztests/case.yaml
rename to runtime/ztests/expr/case.yaml
index 9b9e1068bb..7f57084d6b 100644
--- a/compiler/parser/ztests/case.yaml
+++ b/runtime/ztests/expr/case.yaml
@@ -1,5 +1,7 @@
 zed: 'yield case when x==1 then "foo" when x==2 then "bar" else {y:12} end'
 
+vector: true
+
 input: |
   {x:1}
   {x:2,y:3}