Skip to content

Commit

Permalink
Implement vector switch operator
Browse files Browse the repository at this point in the history
  • Loading branch information
nwt committed Dec 18, 2024
1 parent 5b29662 commit 6be6309
Show file tree
Hide file tree
Showing 13 changed files with 231 additions and 5 deletions.
62 changes: 58 additions & 4 deletions compiler/kernel/vop.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"

"github.com/brimdata/super"
"github.com/brimdata/super/compiler/dag"
"github.com/brimdata/super/compiler/optimizer"
"github.com/brimdata/super/pkg/field"
Expand Down Expand Up @@ -39,10 +40,10 @@ func (b *Builder) compileVam(o dag.Op, parents []vector.Puller) ([]vector.Puller
case *dag.Scope:
//return b.compileVecScope(o, parents)
case *dag.Switch:
//if o.Expr != nil {
// return b.compileVamExprSwitch(o, parents)
//}
//return b.compileVecSwitch(o, parents)
if o.Expr != nil {
return b.compileVamExprSwitch(o, parents)
}
return b.compileVamSwitch(o, parents)
default:
var parent vector.Puller
if len(parents) == 1 {
Expand Down Expand Up @@ -114,6 +115,59 @@ func (b *Builder) compileVamScatter(scatter *dag.Scatter, parents []vector.Pulle
return ops, nil
}

func (b *Builder) compileVamExprSwitch(swtch *dag.Switch, parents []vector.Puller) ([]vector.Puller, error) {
parent := parents[0]
if len(parents) > 1 {
parent = vamop.NewCombine(b.rctx, parents)
}
e, err := b.compileVamExpr(swtch.Expr)
if err != nil {
return nil, err
}
s := vamop.NewExprSwitch(b.rctx, parent, e)
var exits []vector.Puller
for _, c := range swtch.Cases {
var val *super.Value
if c.Expr != nil {
val2, err := b.evalAtCompileTime(c.Expr)
if err != nil {
return nil, err
}
if val2.IsError() {
return nil, errors.New("switch case is not a constant expression")
}
val = &val2
}
parents, err := b.compileVamSeq(c.Path, []vector.Puller{s.AddCase(val)})
if err != nil {
return nil, err
}
exits = append(exits, parents...)
}
return exits, nil
}

func (b *Builder) compileVamSwitch(swtch *dag.Switch, parents []vector.Puller) ([]vector.Puller, error) {
parent := parents[0]
if len(parents) > 1 {
parent = vamop.NewCombine(b.rctx, parents)
}
s := vamop.NewSwitch(b.rctx, parent)
var exits []vector.Puller
for _, c := range swtch.Cases {
e, err := b.compileVamExpr(c.Expr)
if err != nil {
return nil, fmt.Errorf("compiling switch case filter: %w", err)
}
exit, err := b.compileVamSeq(c.Path, []vector.Puller{s.AddCase(e)})
if err != nil {
return nil, err
}
exits = append(exits, exit...)
}
return exits, nil
}

func (b *Builder) compileVamLeaf(o dag.Op, parent vector.Puller) (vector.Puller, error) {
switch o := o.(type) {
case *dag.Cut:
Expand Down
58 changes: 58 additions & 0 deletions runtime/vam/op/exprswitch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package op

import (
"context"

"github.com/brimdata/super"
"github.com/brimdata/super/runtime/vam/expr"
"github.com/brimdata/super/vector"
"github.com/brimdata/super/zcode"
)

type ExprSwitch struct {
expr expr.Evaluator
router *router

builder zcode.Builder
cases map[string]*route
caseIndexes map[*route][]uint32
defaultRoute *route
}

func NewExprSwitch(ctx context.Context, parent vector.Puller, e expr.Evaluator) *ExprSwitch {
s := &ExprSwitch{expr: e, cases: map[string]*route{}, caseIndexes: map[*route][]uint32{}}
s.router = newRouter(ctx, s, parent)
return s
}

func (s *ExprSwitch) AddCase(val *super.Value) vector.Puller {
r := s.router.addRoute()
if val == nil {
s.defaultRoute = r
} else {
s.cases[string(val.Bytes())] = r
}
return r
}

func (s *ExprSwitch) forward(vec vector.Any) bool {
exprVec := s.expr.Eval(vec)
for i := range exprVec.Len() {
s.builder.Truncate()
exprVec.Serialize(&s.builder, i)
route, ok := s.cases[string(s.builder.Bytes().Body())]
if !ok {
route = s.defaultRoute
}
if route != nil {
s.caseIndexes[route] = append(s.caseIndexes[route], i)
}
}
for route, index := range s.caseIndexes {
view := vector.NewView(vec, index)
if !route.send(view, nil) {
return false
}
}
return true
}
2 changes: 1 addition & 1 deletion runtime/vam/op/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func newRouter(ctx context.Context, f forwarder, parent vector.Puller) *router {
return &router{ctx: ctx, forwarder: f, parent: parent}
}

func (r *router) addRoute() vector.Puller {
func (r *router) addRoute() *route {
route := &route{r, make(chan result), make(chan struct{}), false}
r.routes = append(r.routes, route)
return route
Expand Down
94 changes: 94 additions & 0 deletions runtime/vam/op/swtich.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package op

import (
"context"

"github.com/RoaringBitmap/roaring"
"github.com/brimdata/super"
"github.com/brimdata/super/runtime/vam/expr"
"github.com/brimdata/super/vector"
)

type Switch struct {
router *router
cases []expr.Evaluator
}

func NewSwitch(ctx context.Context, parent vector.Puller) *Switch {
s := &Switch{}
s.router = newRouter(ctx, s, parent)
return s
}

func (s *Switch) AddCase(e expr.Evaluator) vector.Puller {
s.cases = append(s.cases, e)
return s.router.addRoute()
}

func (s *Switch) forward(vec vector.Any) bool {
doneMap := roaring.New()
for i, c := range s.cases {
maskVec := c.Eval(vec)
boolMap, errMap := expr.BoolMask(maskVec)
boolMap.AndNot(doneMap)
errMap.AndNot(doneMap)
doneMap.Or(boolMap)
if !errMap.IsEmpty() {
// Clone because iteration results are undefined if the bitmap is modified.
for it := errMap.Clone().Iterator(); it.HasNext(); {
i := it.Next()
if isErrorMissing(maskVec, i) {
errMap.Remove(i)
}
}
}
var vec2 vector.Any
if errMap.IsEmpty() {
if boolMap.IsEmpty() {
continue
}
vec2 = vector.NewView(vec, boolMap.ToArray())
} else if boolMap.IsEmpty() {
vec2 = vector.NewView(maskVec, errMap.ToArray())
} else {
valIndex := boolMap.ToArray()
errIndex := errMap.ToArray()
tags := make([]uint32, 0, len(valIndex)+len(errIndex))
for len(valIndex) > 0 && len(errIndex) > 0 {
if valIndex[0] < errIndex[0] {
valIndex = valIndex[1:]
tags = append(tags, 0)
} else {
errIndex = errIndex[1:]
tags = append(tags, 1)
}
}
tags = append(tags, valIndex...)
tags = append(tags, errIndex...)
valVec := vector.NewView(vec, valIndex)
errVec := vector.NewView(maskVec, errIndex)
vec2 = vector.NewDynamic(tags, []vector.Any{valVec, errVec})
}
if !s.router.routes[i].send(vec2, nil) {
return false
}
}
return true
}

func isErrorMissing(vec vector.Any, i uint32) bool {
vec = vector.Under(vec)
if dynVec, ok := vec.(*vector.Dynamic); ok {
vec = dynVec.Values[dynVec.Tags[i]]
i = dynVec.TagMap.Forward[i]
}
errVec, ok := vec.(*vector.Error)
if !ok {
return false
}
if errVec.Vals.Type().ID() != super.IDString {
return false
}
s, null := vector.StringValue(errVec.Vals, i)
return !null && s == string(super.Missing)
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ zed: |
case this==3 => yield 4
)
vector: true

input: |
1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ zed: |
default => count:=count() |> put a:=-1
) |> sort a
vector: true

input: |
{a:1,s:"a"}
{a:2,s:"B"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ zed: |
switch (
case a == 1 => put v:='one'
case a / 0 => put v:='xxx'
case a % 0 => put v:='yyy'
) |> sort this
vector: true

input: |
{a:1,s:"a"}
{a:2,s:"b"}
output: |
{a:1,s:"a",v:"one"}
error("divide by zero")
error("divide by zero")
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ zed: |
default => count:=count() |> put a:=-1
) |> sort a
vector: true

input: |
{a:1,s:"a"}
{a:2,s:"B"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ zed: |
default => pass
) |> sort b
vector: true

input: |
{a:1,b:1}
{a:2,b:2}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ zed: |
default => over a |> yield {b:this}
) |> sort this
vector: true

input: |
{a:[1,2,3]}
{a:[6,7,8,9]}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ zed: |
case 3 => ? null
) |> sort a
vector: true

input: |
{a:1(int32),s:"a"}
{a:2(int32),s:"B"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ zed: |
default => over a |> yield {b:this}
) |> sort this
vector: true

input: |
{a:[1,2,3]}
{a:[6,7,8,9]}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ zed: |
case true => count:=count() |> put a:=-1
) |> sort a
vector: true

input: |
{a:1(int32),s:"a"}
{a:2(int32),s:"B"}
Expand Down

0 comments on commit 6be6309

Please sign in to comment.