Skip to content

Commit

Permalink
vam: Add support for conditionals (case exprs)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattnibs committed Nov 21, 2024
1 parent c06e81a commit bac0723
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 12 deletions.
20 changes: 18 additions & 2 deletions compiler/kernel/vexpr.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ func (b *Builder) compileVamExpr(e dag.Expr) (vamexpr.Evaluator, error) {
return b.compileVamUnary(*e)
case *dag.BinaryExpr:
return b.compileVamBinary(e)
//case *dag.Conditional:
// return b.compileVamConditional(*e)
case *dag.Conditional:
return b.compileVamConditional(*e)
case *dag.Call:
return b.compileVamCall(e)
//case *dag.RegexpMatch:
Expand Down Expand Up @@ -111,6 +111,22 @@ func (b *Builder) compileVamBinary(e *dag.BinaryExpr) (vamexpr.Evaluator, error)
}
}

func (b *Builder) compileVamConditional(node dag.Conditional) (vamexpr.Evaluator, error) {
predicate, err := b.compileVamExpr(node.Cond)
if err != nil {
return nil, err
}
thenExpr, err := b.compileVamExpr(node.Then)
if err != nil {
return nil, err
}
elseExpr, err := b.compileVamExpr(node.Else)
if err != nil {
return nil, err
}
return vamexpr.NewConditional(b.zctx(), predicate, thenExpr, elseExpr), nil
}

func (b *Builder) compileVamUnary(unary dag.UnaryExpr) (vamexpr.Evaluator, error) {
e, err := b.compileVamExpr(unary.Operand)
if err != nil {
Expand Down
127 changes: 127 additions & 0 deletions runtime/vam/expr/conditional.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package expr

import (
"github.com/RoaringBitmap/roaring"
"github.com/brimdata/super"
"github.com/brimdata/super/vector"
)

type conditional struct {
zctx *super.Context
predicate Evaluator
thenExpr Evaluator
elseExpr Evaluator
}

func NewConditional(zctx *super.Context, predicate, thenExpr, elseExpr Evaluator) Evaluator {
return &conditional{
zctx: zctx,
predicate: predicate,
thenExpr: thenExpr,
elseExpr: elseExpr,
}
}

func (c *conditional) Eval(this vector.Any) vector.Any {
predVec := c.predicate.Eval(this)
boolsMap, errsMap := BoolMask(predVec)
if errsMap.GetCardinality() == uint64(this.Len()) {
return c.predicateError(predVec)
}
if boolsMap.GetCardinality() == uint64(this.Len()) {
return c.thenExpr.Eval(this)
}
if boolsMap.IsEmpty() && errsMap.IsEmpty() {
return c.elseExpr.Eval(this)
}
thenVec := c.thenExpr.Eval(vector.NewView(this, boolsMap.ToArray()))
// elseMap is the difference between boolsMap or errsMap
elseMap := roaring.Or(boolsMap, errsMap)
elseMap.Flip(0, uint64(this.Len()))
elseIndex := elseMap.ToArray()
elseVec := c.elseExpr.Eval(vector.NewView(this, elseIndex))
tags := make([]uint32, this.Len())
for _, idx := range elseIndex {
tags[idx] = 1
}
vecs := []vector.Any{thenVec, elseVec}
if !errsMap.IsEmpty() {
errsIndex := errsMap.ToArray()
for _, idx := range errsIndex {
tags[idx] = 2
}
vecs = append(vecs, c.predicateError(vector.NewView(predVec, errsIndex)))
}
return vector.NewDynamic(tags, vecs)
}

func (c *conditional) predicateError(vec vector.Any) vector.Any {
return vector.Apply(false, func(vecs ...vector.Any) vector.Any {
return vector.NewWrappedError(c.zctx, "?-operator: bool predicate required", vecs[0])
}, vec)
}

func BoolMask(mask vector.Any) (*roaring.Bitmap, *roaring.Bitmap) {
bools := roaring.New()
errs := roaring.New()
if dynamic, ok := mask.(*vector.Dynamic); ok {
for i, val := range dynamic.Values {
boolMaskRidx(dynamic.TagMap.Reverse[i], bools, errs, val)
}
} else {
boolMaskRidx(nil, bools, errs, mask)
}
return bools, errs
}

func boolMaskRidx(ridx []uint32, bools, errs *roaring.Bitmap, vec vector.Any) {
switch vec := vec.(type) {
case *vector.Const:
if !vec.Value().Ptr().AsBool() {
return
}
if vec.Nulls != nil {
if ridx != nil {
for i, idx := range ridx {
if !vec.Nulls.Value(uint32(i)) {
bools.Add(idx)
}
}
} else {
for i := range vec.Len() {
if !vec.Nulls.Value(i) {
bools.Add(i)
}
}
}
} else {
if ridx != nil {
bools.AddMany(ridx)
} else {
bools.AddRange(0, uint64(vec.Len()))
}
}
case *vector.Bool:
if ridx != nil {
for i, idx := range ridx {
if vec.Value(uint32(i)) {
bools.Add(idx)
}
}
} else {
for i := range vec.Len() {
if vec.Value(i) {
bools.Add(i)
}
}
}
case *vector.Error:
if ridx != nil {
errs.AddMany(ridx)
} else {
errs.AddRange(0, uint64(vec.Len()))
}
default:
panic(vec)
}
}
15 changes: 5 additions & 10 deletions runtime/vam/op/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,13 @@ func (f *Filter) Pull(done bool) (vector.Any, error) {
// applyMask applies the mask vector mask to vec. Elements of mask that are not
// Boolean are considered false.
func applyMask(vec, mask vector.Any) (vector.Any, bool) {
n := mask.Len()
var index []uint32
for k := uint32(0); k < n; k++ {
if vector.BoolValue(mask, k) {
index = append(index, k)
}
}
if len(index) == 0 {
// errors are ignored for filters
b, _ := expr.BoolMask(mask)
if b.IsEmpty() {
return nil, false
}
if len(index) == int(n) {
if b.GetCardinality() == uint64(mask.Len()) {
return vec, true
}
return vector.NewView(vec, index), true
return vector.NewView(vec, b.ToArray()), true
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
zed: 'yield case when x==1 then "foo" when x==2 then "bar" else {y:12} end'

vector: true

input: |
{x:1}
{x:2,y:3}
Expand Down

0 comments on commit bac0723

Please sign in to comment.