Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework field selection for projection pushdown (aka demand) #5553

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
314 changes: 211 additions & 103 deletions compiler/optimizer/demand.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,132 +5,240 @@ import (
"github.com/brimdata/super/compiler/optimizer/demand"
)

func insertDemand(seq dag.Seq) dag.Seq {
demands := InferDemandSeqOut(seq)
return walk(seq, true, func(seq dag.Seq) dag.Seq {
for _, op := range seq {
switch op := op.(type) {
case *dag.FileScan:
op.Fields = demand.Fields(demands[op])
case *dag.SeqScan:
op.Fields = demand.Fields(demands[op])
}
}
return seq
})
}

// Returns a map from op to the demand on the output of that op.
func InferDemandSeqOut(seq dag.Seq) map[dag.Op]demand.Demand {
demands := make(map[dag.Op]demand.Demand)
inferDemandSeqOutWith(demands, demand.All(), seq)
for _, d := range demands {
if !demand.IsValid(d) {
panic("Invalid demand")
}
func DemandForSeq(seq dag.Seq, downstream demand.Demand) demand.Demand {
for i := len(seq) - 1; i >= 0; i-- {
downstream = demandForOp(seq[i], downstream)
}
return demands
return downstream
}

func inferDemandSeqOutWith(demands map[dag.Op]demand.Demand, demandSeqOut demand.Demand, seq dag.Seq) {
demandOpOut := demandSeqOut
for i := len(seq) - 1; i >= 0; i-- {
op := seq[i]
if _, ok := demands[op]; ok {
panic("Duplicate op value")
func demandForOp(op dag.Op, downstream demand.Demand) demand.Demand {
switch op := op.(type) {
case *dag.Combine:
return downstream
case *dag.Cut:
return demandForAssignments(op.Args, demand.None())
case *dag.Drop:
return downstream
case *dag.Explode:
d := demand.None()
for _, a := range op.Args {
d = demand.Union(d, demandForExpr(a))
}
demands[op] = demandOpOut

// Infer the demand that `op` places on it's input.
var demandOpIn demand.Demand
switch op := op.(type) {
case *dag.FileScan:
demandOpIn = demand.Union(demandOpOut, inferDemandExprIn(demand.All(), op.Filter))
demands[op] = demandOpIn
case *dag.Filter:
demandOpIn = demand.Union(
// Everything that downstream operations need.
demandOpOut,
// Everything that affects the outcome of this filter.
inferDemandExprIn(demand.All(), op.Expr),
)
case *dag.SeqScan:
demandOpIn = demand.Union(demandOpOut, inferDemandExprIn(demand.All(), op.Filter))
demands[op] = demandOpIn
case *dag.Summarize:
demandOpIn = demand.None()
// TODO If LHS not in demandOut, we can ignore RHS
for _, assignment := range op.Keys {
demandOpIn = demand.Union(demandOpIn, inferDemandExprIn(demand.All(), assignment.RHS))
}
for _, assignment := range op.Aggs {
demandOpIn = demand.Union(demandOpIn, inferDemandExprIn(demand.All(), assignment.RHS))
}
case *dag.Yield:
demandOpIn = demand.None()
for _, expr := range op.Exprs {
demandOpIn = demand.Union(demandOpIn, inferDemandExprIn(demandOpOut, expr))
}
default:
// Conservatively assume that `op` uses it's entire input, regardless of output demand.
demandOpIn = demand.All()
return d
case *dag.Filter:
return demand.Union(downstream, demandForExpr(op.Expr))
case *dag.Fork:
d := demand.None()
for _, p := range op.Paths {
d = demand.Union(d, DemandForSeq(p, downstream))
}
demandOpOut = demandOpIn
}
}
return d
case *dag.Fuse:
return demand.All()
case *dag.Head:
return downstream
case *dag.Join:
d := downstream
d = demand.Union(d, demandForExpr(op.LeftKey))
d = demand.Union(d, demandForExpr(op.RightKey))
return demandForAssignments(op.Args, d)
case *dag.Load:
return demand.All()
case *dag.Merge:
return demand.Union(downstream, demandForExpr(op.Expr))
case *dag.Mirror:
return demand.Union(DemandForSeq(op.Main, demand.All()),
DemandForSeq(op.Mirror, demand.All()))
case *dag.Output:
return demand.All()
case *dag.Over:
d := demand.None()
for _, def := range op.Defs {
d = demand.Union(d, demandForExpr(def.Expr))
}
for _, e := range op.Exprs {
d = demand.Union(d, demandForExpr(e))
}
return d
case *dag.Pass:
return downstream
case *dag.Put:
return demandForAssignments(op.Args, downstream)
case *dag.Rename:
return demandForAssignments(op.Args, downstream)
case *dag.Scatter:
d := demand.None()
for _, p := range op.Paths {
d = demand.Union(d, DemandForSeq(p, downstream))
}
return d
case *dag.Scope:
return DemandForSeq(op.Body, downstream)
case *dag.Shape, *dag.Sort:
return downstream
case *dag.Summarize:
d := demand.None()
for _, assignment := range op.Keys {
d = demand.Union(d, demandForExpr(assignment.RHS))
}
for _, assignment := range op.Aggs {
d = demand.Union(d, demandForExpr(assignment.RHS))
}
return d
case *dag.Switch:
d := demandForExpr(op.Expr)
for _, c := range op.Cases {
d = demand.Union(d, demandForExpr(c.Expr))
d = demand.Union(d, DemandForSeq(c.Path, downstream))
}
return d
case *dag.Tail, *dag.Top, *dag.Uniq:
return downstream
case *dag.Vectorize:
return DemandForSeq(op.Body, downstream)
case *dag.Yield:
d := demand.None()
for _, e := range op.Exprs {
d = demand.Union(d, demandForExpr(e))
}
return d

func inferDemandExprIn(demandOut demand.Demand, expr dag.Expr) demand.Demand {
if demand.IsNone(demandOut) {
case *dag.CommitMetaScan, *dag.DefaultScan, *dag.Deleter, *dag.DeleteScan, *dag.LakeMetaScan:
return demand.None()
}
if expr == nil {
case *dag.FileScan:
d := demand.Union(downstream, demandForExpr(op.Filter))
op.Fields = demand.Fields(d)
return demand.None()
case *dag.HTTPScan, *dag.Lister, *dag.NullScan, *dag.PoolMetaScan, *dag.PoolScan:
return demand.None()
case *dag.RobotScan:
return demandForExpr(op.Expr)
case *dag.SeqScan:
d := demand.Union(downstream, demandForExpr(op.Filter))
d = demand.Union(d, demandForExpr(op.KeyPruner))
op.Fields = demand.Fields(d)
return demand.None()
case *dag.Slicer:
return demand.None()
}
var demandIn demand.Demand
panic(op)
}

func demandForExpr(expr dag.Expr) demand.Demand {
switch expr := expr.(type) {
case nil:
return demand.None()
case *dag.Agg:
// Since we don't know how the expr.Name will transform the inputs, we have to assume demand.All.
return demand.Union(
inferDemandExprIn(demand.All(), expr.Expr),
inferDemandExprIn(demand.All(), expr.Where),
)
return demand.Union(demandForExpr(expr.Expr), demandForExpr(expr.Where))
case *dag.ArrayExpr:
return demandForArrayOrSetExpr(expr.Elems)
case *dag.BinaryExpr:
// Since we don't know how the expr.Op will transform the inputs, we have to assume demand.All.
demandIn = demand.Union(
inferDemandExprIn(demand.All(), expr.LHS),
inferDemandExprIn(demand.All(), expr.RHS),
)
return demand.Union(demandForExpr(expr.LHS), demandForExpr(expr.RHS))
case *dag.Call:
d := demand.None()
if expr.Name == "every" {
d = demand.Key("ts", demand.All())
}
for _, a := range expr.Args {
d = demand.Union(d, demandForExpr(a))
}
return d
case *dag.Conditional:
return demand.Union(demandForExpr(expr.Cond),
demand.Union(demandForExpr(expr.Then), demandForExpr(expr.Else)))
case *dag.Dot:
demandIn = demand.Key(expr.RHS, inferDemandExprIn(demandOut, expr.LHS))
return demandForExpr(expr.LHS)
case *dag.Func:
// return demand.All()
case *dag.IndexExpr:
return demand.Union(demandForExpr(expr.Expr), demandForExpr(expr.Index))
case *dag.IsNullExpr:
return demandForExpr(expr.Expr)
case *dag.Literal:
demandIn = demand.None()
return demand.None()
case *dag.MapCall:
return demandForExpr(expr.Expr)
case *dag.MapExpr:
demandIn = demand.None()
for _, entry := range expr.Entries {
demandIn = demand.Union(demandIn, inferDemandExprIn(demand.All(), entry.Key))
demandIn = demand.Union(demandIn, inferDemandExprIn(demand.All(), entry.Value))
d := demand.None()
for _, e := range expr.Entries {
d = demand.Union(d, demandForExpr(e.Key))
d = demand.Union(d, demandForExpr(e.Value))
}
return d
case *dag.OverExpr:
d := demand.None()
for _, def := range expr.Defs {
d = demand.Union(d, demandForExpr(def.Expr))
}
for _, e := range expr.Exprs {
d = demand.Union(d, demandForExpr(e))
}
return d
case *dag.RecordExpr:
demandIn = demand.None()
for _, elem := range expr.Elems {
switch elem := elem.(type) {
d := demand.None()
for _, e := range expr.Elems {
switch e := e.(type) {
case *dag.Field:
demandValueOut := demand.GetKey(demandOut, elem.Name)
if !demand.IsNone(demandValueOut) {
demandIn = demand.Union(demandIn, inferDemandExprIn(demandValueOut, elem.Value))
}
d = demand.Union(d, demandForExpr(e.Value))
case *dag.Spread:
demandIn = demand.Union(demandIn, inferDemandExprIn(demand.All(), elem.Expr))
d = demand.Union(d, demandForExpr(e.Expr))
default:
panic(e)
}
}
return d
case *dag.RegexpMatch:
return demandForExpr(expr.Expr)
case *dag.RegexpSearch:
return demandForExpr(expr.Expr)
case *dag.Search:
return demandForExpr(expr.Expr)
case *dag.SetExpr:
return demandForArrayOrSetExpr(expr.Elems)
case *dag.SliceExpr:
return demand.Union(demandForExpr(expr.Expr),
demand.Union(demandForExpr(expr.From), demandForExpr(expr.To)))
case *dag.This:
demandIn = demandOut
d := demand.All()
for i := len(expr.Path) - 1; i >= 0; i-- {
demandIn = demand.Key(expr.Path[i], demandIn)
d = demand.Key(expr.Path[i], d)
}
return d
case *dag.UnaryExpr:
return demandForExpr(expr.Operand)
case *dag.Var:
return demand.None()
}
panic(expr)
}

func demandForArrayOrSetExpr(elems []dag.VectorElem) demand.Demand {
d := demand.None()
for _, e := range elems {
switch e := e.(type) {
case *dag.Spread:
d = demand.Union(d, demandForExpr(e.Expr))
case *dag.VectorValue:
d = demand.Union(d, demandForExpr(e.Expr))
default:
panic(e)
}
}
return d
}

func demandForAssignments(assignments []dag.Assignment, downstream demand.Demand) demand.Demand {
d := downstream
for _, a := range assignments {
if _, ok := a.LHS.(*dag.This); ok {
// Assignment clobbers a static field.
d = demand.Delete(d, demandForExpr(a.LHS))
} else {
// Add anything needed by a dynamic field.
d = demand.Union(d, demandForExpr(a.LHS))
}
default:
// Conservatively assume that `expr` uses it's entire input, regardless of output demand.
demandIn = demand.All()
d = demand.Union(d, demandForExpr(a.RHS))
}
return demandIn
return d
}
35 changes: 34 additions & 1 deletion compiler/optimizer/demand/demand.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package demand

import "github.com/brimdata/super/pkg/field"
import (
"maps"

"github.com/brimdata/super/pkg/field"
)

type Demand interface {
isDemand()
Expand Down Expand Up @@ -61,6 +65,35 @@ func Key(key string, value Demand) Demand {
return keys{key: value}
}

// Delete deletes entries in b from a.
func Delete(a, b Demand) Demand {
aa, ok := a.(keys)
if !ok {
return a
}
bb, ok := b.(keys)
if !ok {
return a
}
copyOnWrite := true
for k, bv := range bb {
av, ok := aa[k]
if !ok {
continue
}
if copyOnWrite {
aa = maps.Clone(aa)
copyOnWrite = false
}
if IsAll(bv) {
delete(aa, k)
continue
}
aa[k] = Delete(av, bv)
}
return aa
}

func Union(a Demand, b Demand) Demand {
if _, ok := a.(all); ok {
return a
Expand Down
Loading