diff --git a/compiler/kernel/vexpr.go b/compiler/kernel/vexpr.go index 833caf9ef4..eeb9137aee 100644 --- a/compiler/kernel/vexpr.go +++ b/compiler/kernel/vexpr.go @@ -100,8 +100,8 @@ func (b *Builder) compileVamBinary(e *dag.BinaryExpr) (vamexpr.Evaluator, error) return vamexpr.NewLogicalAnd(b.zctx(), lhs, rhs), nil case "or": return vamexpr.NewLogicalOr(b.zctx(), lhs, rhs), nil - //case "in": XXX TBD - // return vamexpr.NewIn(b.zctx(), lhs, rhs), nil + case "in": + return vamexpr.NewIn(b.zctx(), lhs, rhs), nil case "==", "!=", "<", "<=", ">", ">=": return vamexpr.NewCompare(b.zctx(), lhs, rhs, op), nil case "+", "-", "*", "/", "%": diff --git a/runtime/vam/expr/logic.go b/runtime/vam/expr/logic.go index cf4da97b29..c229db0e12 100644 --- a/runtime/vam/expr/logic.go +++ b/runtime/vam/expr/logic.go @@ -1,6 +1,8 @@ package expr import ( + "slices" + "github.com/brimdata/super" "github.com/brimdata/super/vector" ) @@ -177,6 +179,18 @@ func toBool(vec vector.Any) *vector.Bool { } else { return vector.NewBoolEmpty(vec.Len(), vec.Nulls) } + case *vector.Dynamic: + nulls := vector.NewBoolEmpty(vec.Len(), nil) + out := vector.NewBoolEmpty(vec.Len(), nulls) + for i := range vec.Len() { + v, null := vector.BoolValue(vec, i) + if null { + nulls.Set(i) + } else if v { + out.Set(i) + } + } + return out case *vector.Bool: return vec default: @@ -191,3 +205,90 @@ func trueBool(n uint32) *vector.Bool { } return vec } + +type In struct { + zctx *super.Context + lhs Evaluator + rhs Evaluator + eq *Compare +} + +func NewIn(zctx *super.Context, lhs, rhs Evaluator) *In { + return &In{zctx, lhs, rhs, NewCompare(zctx, nil, nil, "==")} +} + +func (i *In) Eval(this vector.Any) vector.Any { + return vector.Apply(true, i.eval, i.lhs.Eval(this), i.rhs.Eval(this)) +} + +func (i *In) eval(vecs ...vector.Any) vector.Any { + lhs, rhs := vecs[0], vecs[1] + if lhs.Type().Kind() == super.ErrorKind { + return lhs + } + if rhs.Type().Kind() == super.ErrorKind { + return rhs + } + return i.evalResursive(lhs, rhs) +} + +func (i *In) evalResursive(vecs ...vector.Any) vector.Any { + lhs, rhs := vecs[0], vecs[1] + rhs = vector.Under(rhs) + var index []uint32 + if view, ok := rhs.(*vector.View); ok { + rhs = view.Any + index = view.Index + } + switch rhs := rhs.(type) { + case *vector.Record: + out := vector.NewBoolEmpty(lhs.Len(), nil) + for _, f := range rhs.Fields { + if index != nil { + f = vector.NewView(f, index) + } + out = vector.Or(out, toBool(i.evalResursive(lhs, f))) + } + return out + case *vector.Array: + return i.evalForList(lhs, rhs.Values, rhs.Offsets, index) + case *vector.Set: + return i.evalForList(lhs, rhs.Values, rhs.Offsets, index) + case *vector.Map: + return vector.Or(i.evalForList(lhs, rhs.Keys, rhs.Offsets, index), + i.evalForList(lhs, rhs.Values, rhs.Offsets, index)) + case *vector.Union: + return vector.Apply(true, i.evalResursive, lhs, rhs) + case *vector.Error: + return i.evalResursive(lhs, rhs.Vals) + default: + return i.eq.eval(lhs, rhs) + } +} + +func (i *In) evalForList(lhs, rhs vector.Any, offsets, index []uint32) *vector.Bool { + out := vector.NewBoolEmpty(lhs.Len(), nil) + var lhsIndex, rhsIndex []uint32 + for j := range lhs.Len() { + if index != nil { + j = index[j] + } + start, end := offsets[j], offsets[j+1] + if start == end { + continue + } + n := end - start + lhsIndex = slices.Grow(lhsIndex[:0], int(n))[:n] + rhsIndex = slices.Grow(rhsIndex[:0], int(n))[:n] + for k := range n { + lhsIndex[k] = k + rhsIndex[k] = k + start + } + lhsView := vector.NewView(lhs, lhsIndex) + rhsView := vector.NewView(rhs, rhsIndex) + if toBool(i.evalResursive(lhsView, rhsView)).TrueCount() > 0 { + out.Set(j) + } + } + return out +} diff --git a/runtime/sam/expr/ztests/in-field.yaml b/runtime/ztests/expr/in-field.yaml similarity index 92% rename from runtime/sam/expr/ztests/in-field.yaml rename to runtime/ztests/expr/in-field.yaml index 52b1619c94..fcc1b68ef7 100644 --- a/runtime/sam/expr/ztests/in-field.yaml +++ b/runtime/ztests/expr/in-field.yaml @@ -1,5 +1,7 @@ zed: 2 in a +vector: true + input: | {a:[1(uint32),2(uint32)]} {a:[1(uint32)]} diff --git a/runtime/ztests/expr/in.yaml b/runtime/ztests/expr/in.yaml new file mode 100644 index 0000000000..a6be236ae6 --- /dev/null +++ b/runtime/ztests/expr/in.yaml @@ -0,0 +1,28 @@ +zed: | + yield [1 in this, 9 in this] + +vector: true + +input: | + 1 + {a:1} + [0,1,2] + |[0,1,2]| + |{1:null}| + |{null:1}| + 1((int64,string)) + {a:[0,1]([(int64,string)])} + [error(0),error(1)] + error(0) + +output: | + [true,false] + [true,false] + [true,false] + [true,false] + [true,false] + [true,false] + [true,false] + [true,false] + [true,false] + [error(0),error(0)] diff --git a/vector/bool.go b/vector/bool.go index 289d24720c..ed3cac6671 100644 --- a/vector/bool.go +++ b/vector/bool.go @@ -146,13 +146,16 @@ func And(a, b *Bool) *Bool { // BoolValue returns the value of slot in vec if the value is a Boolean. It // returns false otherwise. -func BoolValue(vec Any, slot uint32) bool { +func BoolValue(vec Any, slot uint32) (bool, bool) { switch vec := Under(vec).(type) { case *Bool: - return vec.Value(slot) + return vec.Value(slot), vec.Nulls.Value(slot) case *Const: - return vec.Value().Ptr().AsBool() + return vec.Value().Ptr().AsBool(), vec.Nulls.Value(slot) case *Dict: + if vec.Nulls.Value(slot) { + return false, true + } return BoolValue(vec.Any, uint32(vec.Index[slot])) case *Dynamic: tag := vec.Tags[slot] @@ -160,7 +163,7 @@ func BoolValue(vec Any, slot uint32) bool { case *View: return BoolValue(vec.Any, vec.Index[slot]) } - return false + panic(vec) } func NullsOf(v Any) *Bool {