diff --git a/compiler/optimizer/demand.go b/compiler/optimizer/demand.go index 6e9057ed04..f6f68f6b50 100644 --- a/compiler/optimizer/demand.go +++ b/compiler/optimizer/demand.go @@ -5,132 +5,240 @@ import ( "github.com/brimdata/super/compiler/optimizer/demand" ) -func insertDemand(seq dag.Seq) dag.Seq { - demands := InferDemandSeqOut(seq) - return walk(seq, true, func(seq dag.Seq) dag.Seq { - for _, op := range seq { - switch op := op.(type) { - case *dag.FileScan: - op.Fields = demand.Fields(demands[op]) - case *dag.SeqScan: - op.Fields = demand.Fields(demands[op]) - } - } - return seq - }) -} - -// Returns a map from op to the demand on the output of that op. -func InferDemandSeqOut(seq dag.Seq) map[dag.Op]demand.Demand { - demands := make(map[dag.Op]demand.Demand) - inferDemandSeqOutWith(demands, demand.All(), seq) - for _, d := range demands { - if !demand.IsValid(d) { - panic("Invalid demand") - } +func DemandForSeq(seq dag.Seq, downstream demand.Demand) demand.Demand { + for i := len(seq) - 1; i >= 0; i-- { + downstream = demandForOp(seq[i], downstream) } - return demands + return downstream } -func inferDemandSeqOutWith(demands map[dag.Op]demand.Demand, demandSeqOut demand.Demand, seq dag.Seq) { - demandOpOut := demandSeqOut - for i := len(seq) - 1; i >= 0; i-- { - op := seq[i] - if _, ok := demands[op]; ok { - panic("Duplicate op value") +func demandForOp(op dag.Op, downstream demand.Demand) demand.Demand { + switch op := op.(type) { + case *dag.Combine: + return downstream + case *dag.Cut: + return demandForAssignments(op.Args, demand.None()) + case *dag.Drop: + return downstream + case *dag.Explode: + d := demand.None() + for _, a := range op.Args { + d = demand.Union(d, demandForExpr(a)) } - demands[op] = demandOpOut - - // Infer the demand that `op` places on it's input. - var demandOpIn demand.Demand - switch op := op.(type) { - case *dag.FileScan: - demandOpIn = demand.Union(demandOpOut, inferDemandExprIn(demand.All(), op.Filter)) - demands[op] = demandOpIn - case *dag.Filter: - demandOpIn = demand.Union( - // Everything that downstream operations need. - demandOpOut, - // Everything that affects the outcome of this filter. - inferDemandExprIn(demand.All(), op.Expr), - ) - case *dag.SeqScan: - demandOpIn = demand.Union(demandOpOut, inferDemandExprIn(demand.All(), op.Filter)) - demands[op] = demandOpIn - case *dag.Summarize: - demandOpIn = demand.None() - // TODO If LHS not in demandOut, we can ignore RHS - for _, assignment := range op.Keys { - demandOpIn = demand.Union(demandOpIn, inferDemandExprIn(demand.All(), assignment.RHS)) - } - for _, assignment := range op.Aggs { - demandOpIn = demand.Union(demandOpIn, inferDemandExprIn(demand.All(), assignment.RHS)) - } - case *dag.Yield: - demandOpIn = demand.None() - for _, expr := range op.Exprs { - demandOpIn = demand.Union(demandOpIn, inferDemandExprIn(demandOpOut, expr)) - } - default: - // Conservatively assume that `op` uses it's entire input, regardless of output demand. - demandOpIn = demand.All() + return d + case *dag.Filter: + return demand.Union(downstream, demandForExpr(op.Expr)) + case *dag.Fork: + d := demand.None() + for _, p := range op.Paths { + d = demand.Union(d, DemandForSeq(p, downstream)) } - demandOpOut = demandOpIn - } -} + return d + case *dag.Fuse: + return demand.All() + case *dag.Head: + return downstream + case *dag.Join: + d := downstream + d = demand.Union(d, demandForExpr(op.LeftKey)) + d = demand.Union(d, demandForExpr(op.RightKey)) + return demandForAssignments(op.Args, d) + case *dag.Load: + return demand.All() + case *dag.Merge: + return demand.Union(downstream, demandForExpr(op.Expr)) + case *dag.Mirror: + return demand.Union(DemandForSeq(op.Main, demand.All()), + DemandForSeq(op.Mirror, demand.All())) + case *dag.Output: + return demand.All() + case *dag.Over: + d := demand.None() + for _, def := range op.Defs { + d = demand.Union(d, demandForExpr(def.Expr)) + } + for _, e := range op.Exprs { + d = demand.Union(d, demandForExpr(e)) + } + return d + case *dag.Pass: + return downstream + case *dag.Put: + return demandForAssignments(op.Args, downstream) + case *dag.Rename: + return demandForAssignments(op.Args, downstream) + case *dag.Scatter: + d := demand.None() + for _, p := range op.Paths { + d = demand.Union(d, DemandForSeq(p, downstream)) + } + return d + case *dag.Scope: + return DemandForSeq(op.Body, downstream) + case *dag.Shape, *dag.Sort: + return downstream + case *dag.Summarize: + d := demand.None() + for _, assignment := range op.Keys { + d = demand.Union(d, demandForExpr(assignment.RHS)) + } + for _, assignment := range op.Aggs { + d = demand.Union(d, demandForExpr(assignment.RHS)) + } + return d + case *dag.Switch: + d := demandForExpr(op.Expr) + for _, c := range op.Cases { + d = demand.Union(d, demandForExpr(c.Expr)) + d = demand.Union(d, DemandForSeq(c.Path, downstream)) + } + return d + case *dag.Tail, *dag.Top, *dag.Uniq: + return downstream + case *dag.Vectorize: + return DemandForSeq(op.Body, downstream) + case *dag.Yield: + d := demand.None() + for _, e := range op.Exprs { + d = demand.Union(d, demandForExpr(e)) + } + return d -func inferDemandExprIn(demandOut demand.Demand, expr dag.Expr) demand.Demand { - if demand.IsNone(demandOut) { + case *dag.CommitMetaScan, *dag.DefaultScan, *dag.Deleter, *dag.DeleteScan, *dag.LakeMetaScan: return demand.None() - } - if expr == nil { + case *dag.FileScan: + d := demand.Union(downstream, demandForExpr(op.Filter)) + op.Fields = demand.Fields(d) + return demand.None() + case *dag.HTTPScan, *dag.Lister, *dag.NullScan, *dag.PoolMetaScan, *dag.PoolScan: + return demand.None() + case *dag.RobotScan: + return demandForExpr(op.Expr) + case *dag.SeqScan: + d := demand.Union(downstream, demandForExpr(op.Filter)) + d = demand.Union(d, demandForExpr(op.KeyPruner)) + op.Fields = demand.Fields(d) + return demand.None() + case *dag.Slicer: return demand.None() } - var demandIn demand.Demand + panic(op) +} + +func demandForExpr(expr dag.Expr) demand.Demand { switch expr := expr.(type) { + case nil: + return demand.None() case *dag.Agg: - // Since we don't know how the expr.Name will transform the inputs, we have to assume demand.All. - return demand.Union( - inferDemandExprIn(demand.All(), expr.Expr), - inferDemandExprIn(demand.All(), expr.Where), - ) + return demand.Union(demandForExpr(expr.Expr), demandForExpr(expr.Where)) + case *dag.ArrayExpr: + return demandForArrayOrSetExpr(expr.Elems) case *dag.BinaryExpr: - // Since we don't know how the expr.Op will transform the inputs, we have to assume demand.All. - demandIn = demand.Union( - inferDemandExprIn(demand.All(), expr.LHS), - inferDemandExprIn(demand.All(), expr.RHS), - ) + return demand.Union(demandForExpr(expr.LHS), demandForExpr(expr.RHS)) + case *dag.Call: + d := demand.None() + if expr.Name == "every" { + d = demand.Key("ts", demand.All()) + } + for _, a := range expr.Args { + d = demand.Union(d, demandForExpr(a)) + } + return d + case *dag.Conditional: + return demand.Union(demandForExpr(expr.Cond), + demand.Union(demandForExpr(expr.Then), demandForExpr(expr.Else))) case *dag.Dot: - demandIn = demand.Key(expr.RHS, inferDemandExprIn(demandOut, expr.LHS)) + return demandForExpr(expr.LHS) + case *dag.Func: + // return demand.All() + case *dag.IndexExpr: + return demand.Union(demandForExpr(expr.Expr), demandForExpr(expr.Index)) + case *dag.IsNullExpr: + return demandForExpr(expr.Expr) case *dag.Literal: - demandIn = demand.None() + return demand.None() + case *dag.MapCall: + return demandForExpr(expr.Expr) case *dag.MapExpr: - demandIn = demand.None() - for _, entry := range expr.Entries { - demandIn = demand.Union(demandIn, inferDemandExprIn(demand.All(), entry.Key)) - demandIn = demand.Union(demandIn, inferDemandExprIn(demand.All(), entry.Value)) + d := demand.None() + for _, e := range expr.Entries { + d = demand.Union(d, demandForExpr(e.Key)) + d = demand.Union(d, demandForExpr(e.Value)) + } + return d + case *dag.OverExpr: + d := demand.None() + for _, def := range expr.Defs { + d = demand.Union(d, demandForExpr(def.Expr)) + } + for _, e := range expr.Exprs { + d = demand.Union(d, demandForExpr(e)) } + return d case *dag.RecordExpr: - demandIn = demand.None() - for _, elem := range expr.Elems { - switch elem := elem.(type) { + d := demand.None() + for _, e := range expr.Elems { + switch e := e.(type) { case *dag.Field: - demandValueOut := demand.GetKey(demandOut, elem.Name) - if !demand.IsNone(demandValueOut) { - demandIn = demand.Union(demandIn, inferDemandExprIn(demandValueOut, elem.Value)) - } + d = demand.Union(d, demandForExpr(e.Value)) case *dag.Spread: - demandIn = demand.Union(demandIn, inferDemandExprIn(demand.All(), elem.Expr)) + d = demand.Union(d, demandForExpr(e.Expr)) + default: + panic(e) } } + return d + case *dag.RegexpMatch: + return demandForExpr(expr.Expr) + case *dag.RegexpSearch: + return demandForExpr(expr.Expr) + case *dag.Search: + return demandForExpr(expr.Expr) + case *dag.SetExpr: + return demandForArrayOrSetExpr(expr.Elems) + case *dag.SliceExpr: + return demand.Union(demandForExpr(expr.Expr), + demand.Union(demandForExpr(expr.From), demandForExpr(expr.To))) case *dag.This: - demandIn = demandOut + d := demand.All() for i := len(expr.Path) - 1; i >= 0; i-- { - demandIn = demand.Key(expr.Path[i], demandIn) + d = demand.Key(expr.Path[i], d) + } + return d + case *dag.UnaryExpr: + return demandForExpr(expr.Operand) + case *dag.Var: + return demand.None() + } + panic(expr) +} + +func demandForArrayOrSetExpr(elems []dag.VectorElem) demand.Demand { + d := demand.None() + for _, e := range elems { + switch e := e.(type) { + case *dag.Spread: + d = demand.Union(d, demandForExpr(e.Expr)) + case *dag.VectorValue: + d = demand.Union(d, demandForExpr(e.Expr)) + default: + panic(e) + } + } + return d +} + +func demandForAssignments(assignments []dag.Assignment, downstream demand.Demand) demand.Demand { + d := downstream + for _, a := range assignments { + if _, ok := a.LHS.(*dag.This); ok { + // Assignment clobbers a static field. + d = demand.Delete(d, demandForExpr(a.LHS)) + } else { + // Add anything needed by a dynamic field. + d = demand.Union(d, demandForExpr(a.LHS)) } - default: - // Conservatively assume that `expr` uses it's entire input, regardless of output demand. - demandIn = demand.All() + d = demand.Union(d, demandForExpr(a.RHS)) } - return demandIn + return d } diff --git a/compiler/optimizer/demand/demand.go b/compiler/optimizer/demand/demand.go index 26f2a123b5..73b417a340 100644 --- a/compiler/optimizer/demand/demand.go +++ b/compiler/optimizer/demand/demand.go @@ -1,6 +1,10 @@ package demand -import "github.com/brimdata/super/pkg/field" +import ( + "maps" + + "github.com/brimdata/super/pkg/field" +) type Demand interface { isDemand() @@ -61,6 +65,35 @@ func Key(key string, value Demand) Demand { return keys{key: value} } +// Delete deletes entries in b from a. +func Delete(a, b Demand) Demand { + aa, ok := a.(keys) + if !ok { + return a + } + bb, ok := b.(keys) + if !ok { + return a + } + copyOnWrite := true + for k, bv := range bb { + av, ok := aa[k] + if !ok { + continue + } + if copyOnWrite { + aa = maps.Clone(aa) + copyOnWrite = false + } + if IsAll(bv) { + delete(aa, k) + continue + } + aa[k] = Delete(av, bv) + } + return aa +} + func Union(a Demand, b Demand) Demand { if _, ok := a.(all); ok { return a diff --git a/compiler/optimizer/optimizer.go b/compiler/optimizer/optimizer.go index 053695bc16..5cbc0c7450 100644 --- a/compiler/optimizer/optimizer.go +++ b/compiler/optimizer/optimizer.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/brimdata/super/compiler/dag" + "github.com/brimdata/super/compiler/optimizer/demand" "github.com/brimdata/super/lake" "github.com/brimdata/super/order" "github.com/brimdata/super/pkg/field" @@ -152,8 +153,8 @@ func (o *Optimizer) Optimize(seq dag.Seq) (dag.Seq, error) { if err != nil { return nil, err } - seq = insertDemand(seq) seq = removePassOps(seq) + DemandForSeq(seq, demand.All()) return seq, nil } diff --git a/fuzz/fuzz.go b/fuzz/fuzz.go index a596977ddb..39799674f3 100644 --- a/fuzz/fuzz.go +++ b/fuzz/fuzz.go @@ -115,8 +115,7 @@ func RunQuery(t testing.TB, zctx *super.Context, readers []zio.Reader, querySour t.Skipf("%v", err) } if len(dag) > 0 { - demands := optimizer.InferDemandSeqOut(dag) - demand := demands[dag[0]] + demand := optimizer.DemandForSeq(dag, demand.All()) useDemand(demand) }