Skip to content

Commit

Permalink
Rework field selection for projection pushdown (aka demand) (#5553)
Browse files Browse the repository at this point in the history
Simplify compiler/optimizer/demand.go and extend it to handle all DAG
expressions and operators.
  • Loading branch information
nwt authored Jan 8, 2025
1 parent a070a07 commit 86dc144
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 107 deletions.
314 changes: 211 additions & 103 deletions compiler/optimizer/demand.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,132 +5,240 @@ import (
"github.com/brimdata/super/compiler/optimizer/demand"
)

func insertDemand(seq dag.Seq) dag.Seq {
demands := InferDemandSeqOut(seq)
return walk(seq, true, func(seq dag.Seq) dag.Seq {
for _, op := range seq {
switch op := op.(type) {
case *dag.FileScan:
op.Fields = demand.Fields(demands[op])
case *dag.SeqScan:
op.Fields = demand.Fields(demands[op])
}
}
return seq
})
}

// Returns a map from op to the demand on the output of that op.
func InferDemandSeqOut(seq dag.Seq) map[dag.Op]demand.Demand {
demands := make(map[dag.Op]demand.Demand)
inferDemandSeqOutWith(demands, demand.All(), seq)
for _, d := range demands {
if !demand.IsValid(d) {
panic("Invalid demand")
}
func DemandForSeq(seq dag.Seq, downstream demand.Demand) demand.Demand {
for i := len(seq) - 1; i >= 0; i-- {
downstream = demandForOp(seq[i], downstream)
}
return demands
return downstream
}

func inferDemandSeqOutWith(demands map[dag.Op]demand.Demand, demandSeqOut demand.Demand, seq dag.Seq) {
demandOpOut := demandSeqOut
for i := len(seq) - 1; i >= 0; i-- {
op := seq[i]
if _, ok := demands[op]; ok {
panic("Duplicate op value")
func demandForOp(op dag.Op, downstream demand.Demand) demand.Demand {
switch op := op.(type) {
case *dag.Combine:
return downstream
case *dag.Cut:
return demandForAssignments(op.Args, demand.None())
case *dag.Drop:
return downstream
case *dag.Explode:
d := demand.None()
for _, a := range op.Args {
d = demand.Union(d, demandForExpr(a))
}
demands[op] = demandOpOut

// Infer the demand that `op` places on it's input.
var demandOpIn demand.Demand
switch op := op.(type) {
case *dag.FileScan:
demandOpIn = demand.Union(demandOpOut, inferDemandExprIn(demand.All(), op.Filter))
demands[op] = demandOpIn
case *dag.Filter:
demandOpIn = demand.Union(
// Everything that downstream operations need.
demandOpOut,
// Everything that affects the outcome of this filter.
inferDemandExprIn(demand.All(), op.Expr),
)
case *dag.SeqScan:
demandOpIn = demand.Union(demandOpOut, inferDemandExprIn(demand.All(), op.Filter))
demands[op] = demandOpIn
case *dag.Summarize:
demandOpIn = demand.None()
// TODO If LHS not in demandOut, we can ignore RHS
for _, assignment := range op.Keys {
demandOpIn = demand.Union(demandOpIn, inferDemandExprIn(demand.All(), assignment.RHS))
}
for _, assignment := range op.Aggs {
demandOpIn = demand.Union(demandOpIn, inferDemandExprIn(demand.All(), assignment.RHS))
}
case *dag.Yield:
demandOpIn = demand.None()
for _, expr := range op.Exprs {
demandOpIn = demand.Union(demandOpIn, inferDemandExprIn(demandOpOut, expr))
}
default:
// Conservatively assume that `op` uses it's entire input, regardless of output demand.
demandOpIn = demand.All()
return d
case *dag.Filter:
return demand.Union(downstream, demandForExpr(op.Expr))
case *dag.Fork:
d := demand.None()
for _, p := range op.Paths {
d = demand.Union(d, DemandForSeq(p, downstream))
}
demandOpOut = demandOpIn
}
}
return d
case *dag.Fuse:
return demand.All()
case *dag.Head:
return downstream
case *dag.Join:
d := downstream
d = demand.Union(d, demandForExpr(op.LeftKey))
d = demand.Union(d, demandForExpr(op.RightKey))
return demandForAssignments(op.Args, d)
case *dag.Load:
return demand.All()
case *dag.Merge:
return demand.Union(downstream, demandForExpr(op.Expr))
case *dag.Mirror:
return demand.Union(DemandForSeq(op.Main, demand.All()),
DemandForSeq(op.Mirror, demand.All()))
case *dag.Output:
return demand.All()
case *dag.Over:
d := demand.None()
for _, def := range op.Defs {
d = demand.Union(d, demandForExpr(def.Expr))
}
for _, e := range op.Exprs {
d = demand.Union(d, demandForExpr(e))
}
return d
case *dag.Pass:
return downstream
case *dag.Put:
return demandForAssignments(op.Args, downstream)
case *dag.Rename:
return demandForAssignments(op.Args, downstream)
case *dag.Scatter:
d := demand.None()
for _, p := range op.Paths {
d = demand.Union(d, DemandForSeq(p, downstream))
}
return d
case *dag.Scope:
return DemandForSeq(op.Body, downstream)
case *dag.Shape, *dag.Sort:
return downstream
case *dag.Summarize:
d := demand.None()
for _, assignment := range op.Keys {
d = demand.Union(d, demandForExpr(assignment.RHS))
}
for _, assignment := range op.Aggs {
d = demand.Union(d, demandForExpr(assignment.RHS))
}
return d
case *dag.Switch:
d := demandForExpr(op.Expr)
for _, c := range op.Cases {
d = demand.Union(d, demandForExpr(c.Expr))
d = demand.Union(d, DemandForSeq(c.Path, downstream))
}
return d
case *dag.Tail, *dag.Top, *dag.Uniq:
return downstream
case *dag.Vectorize:
return DemandForSeq(op.Body, downstream)
case *dag.Yield:
d := demand.None()
for _, e := range op.Exprs {
d = demand.Union(d, demandForExpr(e))
}
return d

func inferDemandExprIn(demandOut demand.Demand, expr dag.Expr) demand.Demand {
if demand.IsNone(demandOut) {
case *dag.CommitMetaScan, *dag.DefaultScan, *dag.Deleter, *dag.DeleteScan, *dag.LakeMetaScan:
return demand.None()
}
if expr == nil {
case *dag.FileScan:
d := demand.Union(downstream, demandForExpr(op.Filter))
op.Fields = demand.Fields(d)
return demand.None()
case *dag.HTTPScan, *dag.Lister, *dag.NullScan, *dag.PoolMetaScan, *dag.PoolScan:
return demand.None()
case *dag.RobotScan:
return demandForExpr(op.Expr)
case *dag.SeqScan:
d := demand.Union(downstream, demandForExpr(op.Filter))
d = demand.Union(d, demandForExpr(op.KeyPruner))
op.Fields = demand.Fields(d)
return demand.None()
case *dag.Slicer:
return demand.None()
}
var demandIn demand.Demand
panic(op)
}

func demandForExpr(expr dag.Expr) demand.Demand {
switch expr := expr.(type) {
case nil:
return demand.None()
case *dag.Agg:
// Since we don't know how the expr.Name will transform the inputs, we have to assume demand.All.
return demand.Union(
inferDemandExprIn(demand.All(), expr.Expr),
inferDemandExprIn(demand.All(), expr.Where),
)
return demand.Union(demandForExpr(expr.Expr), demandForExpr(expr.Where))
case *dag.ArrayExpr:
return demandForArrayOrSetExpr(expr.Elems)
case *dag.BinaryExpr:
// Since we don't know how the expr.Op will transform the inputs, we have to assume demand.All.
demandIn = demand.Union(
inferDemandExprIn(demand.All(), expr.LHS),
inferDemandExprIn(demand.All(), expr.RHS),
)
return demand.Union(demandForExpr(expr.LHS), demandForExpr(expr.RHS))
case *dag.Call:
d := demand.None()
if expr.Name == "every" {
d = demand.Key("ts", demand.All())
}
for _, a := range expr.Args {
d = demand.Union(d, demandForExpr(a))
}
return d
case *dag.Conditional:
return demand.Union(demandForExpr(expr.Cond),
demand.Union(demandForExpr(expr.Then), demandForExpr(expr.Else)))
case *dag.Dot:
demandIn = demand.Key(expr.RHS, inferDemandExprIn(demandOut, expr.LHS))
return demandForExpr(expr.LHS)
case *dag.Func:
// return demand.All()
case *dag.IndexExpr:
return demand.Union(demandForExpr(expr.Expr), demandForExpr(expr.Index))
case *dag.IsNullExpr:
return demandForExpr(expr.Expr)
case *dag.Literal:
demandIn = demand.None()
return demand.None()
case *dag.MapCall:
return demandForExpr(expr.Expr)
case *dag.MapExpr:
demandIn = demand.None()
for _, entry := range expr.Entries {
demandIn = demand.Union(demandIn, inferDemandExprIn(demand.All(), entry.Key))
demandIn = demand.Union(demandIn, inferDemandExprIn(demand.All(), entry.Value))
d := demand.None()
for _, e := range expr.Entries {
d = demand.Union(d, demandForExpr(e.Key))
d = demand.Union(d, demandForExpr(e.Value))
}
return d
case *dag.OverExpr:
d := demand.None()
for _, def := range expr.Defs {
d = demand.Union(d, demandForExpr(def.Expr))
}
for _, e := range expr.Exprs {
d = demand.Union(d, demandForExpr(e))
}
return d
case *dag.RecordExpr:
demandIn = demand.None()
for _, elem := range expr.Elems {
switch elem := elem.(type) {
d := demand.None()
for _, e := range expr.Elems {
switch e := e.(type) {
case *dag.Field:
demandValueOut := demand.GetKey(demandOut, elem.Name)
if !demand.IsNone(demandValueOut) {
demandIn = demand.Union(demandIn, inferDemandExprIn(demandValueOut, elem.Value))
}
d = demand.Union(d, demandForExpr(e.Value))
case *dag.Spread:
demandIn = demand.Union(demandIn, inferDemandExprIn(demand.All(), elem.Expr))
d = demand.Union(d, demandForExpr(e.Expr))
default:
panic(e)
}
}
return d
case *dag.RegexpMatch:
return demandForExpr(expr.Expr)
case *dag.RegexpSearch:
return demandForExpr(expr.Expr)
case *dag.Search:
return demandForExpr(expr.Expr)
case *dag.SetExpr:
return demandForArrayOrSetExpr(expr.Elems)
case *dag.SliceExpr:
return demand.Union(demandForExpr(expr.Expr),
demand.Union(demandForExpr(expr.From), demandForExpr(expr.To)))
case *dag.This:
demandIn = demandOut
d := demand.All()
for i := len(expr.Path) - 1; i >= 0; i-- {
demandIn = demand.Key(expr.Path[i], demandIn)
d = demand.Key(expr.Path[i], d)
}
return d
case *dag.UnaryExpr:
return demandForExpr(expr.Operand)
case *dag.Var:
return demand.None()
}
panic(expr)
}

func demandForArrayOrSetExpr(elems []dag.VectorElem) demand.Demand {
d := demand.None()
for _, e := range elems {
switch e := e.(type) {
case *dag.Spread:
d = demand.Union(d, demandForExpr(e.Expr))
case *dag.VectorValue:
d = demand.Union(d, demandForExpr(e.Expr))
default:
panic(e)
}
}
return d
}

func demandForAssignments(assignments []dag.Assignment, downstream demand.Demand) demand.Demand {
d := downstream
for _, a := range assignments {
if _, ok := a.LHS.(*dag.This); ok {
// Assignment clobbers a static field.
d = demand.Delete(d, demandForExpr(a.LHS))
} else {
// Add anything needed by a dynamic field.
d = demand.Union(d, demandForExpr(a.LHS))
}
default:
// Conservatively assume that `expr` uses it's entire input, regardless of output demand.
demandIn = demand.All()
d = demand.Union(d, demandForExpr(a.RHS))
}
return demandIn
return d
}
35 changes: 34 additions & 1 deletion compiler/optimizer/demand/demand.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package demand

import "github.com/brimdata/super/pkg/field"
import (
"maps"

"github.com/brimdata/super/pkg/field"
)

type Demand interface {
isDemand()
Expand Down Expand Up @@ -61,6 +65,35 @@ func Key(key string, value Demand) Demand {
return keys{key: value}
}

// Delete deletes entries in b from a.
func Delete(a, b Demand) Demand {
aa, ok := a.(keys)
if !ok {
return a
}
bb, ok := b.(keys)
if !ok {
return a
}
copyOnWrite := true
for k, bv := range bb {
av, ok := aa[k]
if !ok {
continue
}
if copyOnWrite {
aa = maps.Clone(aa)
copyOnWrite = false
}
if IsAll(bv) {
delete(aa, k)
continue
}
aa[k] = Delete(av, bv)
}
return aa
}

func Union(a Demand, b Demand) Demand {
if _, ok := a.(all); ok {
return a
Expand Down
Loading

0 comments on commit 86dc144

Please sign in to comment.