vam: Support partials for aggregations (#5416)
mattnibs authored Nov 1, 2024
1 parent 5bc4b98 commit f2bf9e6
Showing 7 changed files with 159 additions and 50 deletions.
2 changes: 1 addition & 1 deletion compiler/kernel/vop.go
@@ -265,7 +265,7 @@ func (b *Builder) compileVamSummarize(s *dag.Summarize, parent vector.Puller) (v
keyNames = append(keyNames, lhs.Path)
keyExprs = append(keyExprs, rhs)
}
return summarize.New(parent, b.zctx(), aggNames, aggs, keyNames, keyExprs)
return summarize.New(parent, b.zctx(), aggNames, aggs, keyNames, keyExprs, s.PartialsIn, s.PartialsOut)
}

func (b *Builder) compileVamAgg(agg *dag.Agg) (*vamexpr.Aggregator, error) {
22 changes: 4 additions & 18 deletions runtime/vam/expr/agg/agg.go
@@ -9,7 +9,9 @@ import (

type Func interface {
Consume(vector.Any)
Result() super.Value
ConsumeAsPartial(vector.Any)
Result(*super.Context) super.Value
ResultAsPartial(*super.Context) super.Value
}

type Pattern func() Func
@@ -21,7 +23,7 @@ func NewPattern(op string, hasarg bool) (Pattern, error) {
case "count":
needarg = false
pattern = func() Func {
return newAggCount()
return &count{}
}
// case "any":
// pattern = func() AggFunc {
@@ -71,19 +73,3 @@ func NewPattern(op string, hasarg bool) (Pattern, error) {
}
return pattern, nil
}

type aggCount struct {
count uint64
}

func newAggCount() *aggCount {
return &aggCount{}
}

func (a *aggCount) Consume(vec vector.Any) {
a.count += uint64(vec.Len())
}

func (a *aggCount) Result() super.Value {
return super.NewUint64(a.count)
}
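
For orientation, here is a minimal sketch (not part of this commit) of how the widened Func interface composes across two stages, using the count aggregate and only the methods visible in this diff. The driver itself, super.NewContext, super.NewInt64, and the runtime/vam/expr/agg import path are assumptions about the surrounding library rather than code from the change.

package main

import (
	"fmt"

	"github.com/brimdata/super"
	"github.com/brimdata/super/runtime/vam/expr/agg"
	"github.com/brimdata/super/vector"
)

func main() {
	pattern, err := agg.NewPattern("count", false)
	if err != nil {
		panic(err)
	}
	zctx := super.NewContext()
	// First stage: each "worker" counts its own batch and emits a partial.
	batches := []vector.Any{
		vector.NewConst(super.NewInt64(1), 3, nil), // a batch of 3 values
		vector.NewConst(super.NewInt64(2), 5, nil), // a batch of 5 values
	}
	var partials []super.Value
	for _, b := range batches {
		f := pattern()
		f.Consume(b)
		partials = append(partials, f.ResultAsPartial(zctx))
	}
	// Second stage: a single func merges the partials into the final count.
	final := pattern()
	for _, p := range partials {
		final.ConsumeAsPartial(vector.NewConst(p, 1, nil))
	}
	fmt.Println(final.Result(zctx).Uint()) // 8
}
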
38 changes: 37 additions & 1 deletion runtime/vam/expr/agg/avg.go
@@ -3,6 +3,7 @@ package agg
import (
"github.com/brimdata/super"
"github.com/brimdata/super/vector"
"github.com/brimdata/super/zcode"
)

type avg struct {
@@ -20,9 +21,44 @@ func (a *avg) Consume(vec vector.Any) {
a.sum = sum(a.sum, vec)
}

func (a *avg) Result() super.Value {
func (a *avg) Result(*super.Context) super.Value {
if a.count > 0 {
return super.NewFloat64(a.sum / float64(a.count))
}
return super.NullFloat64
}

const (
sumName = "sum"
countName = "count"
)

func (a *avg) ConsumeAsPartial(partial vector.Any) {
rec, ok := partial.(*vector.Record)
if !ok || rec.Len() != 1 {
panic("avg: invalid partial")
}
si, ok1 := rec.Typ.IndexOfField(sumName)
ci, ok2 := rec.Typ.IndexOfField(countName)
if !ok1 || !ok2 {
panic("avg: invalid partial")
}
sumVal, ok1 := rec.Fields[si].(*vector.Const)
countVal, ok2 := rec.Fields[ci].(*vector.Const)
if !ok1 || !ok2 || sumVal.Type() != super.TypeFloat64 || countVal.Type() != super.TypeUint64 {
panic("avg: invalid partial")
}
a.sum += sumVal.Value().Float()
a.count += countVal.Value().Uint()
}

func (a *avg) ResultAsPartial(zctx *super.Context) super.Value {
var zv zcode.Bytes
zv = super.NewFloat64(a.sum).Encode(zv)
zv = super.NewUint64(a.count).Encode(zv)
typ := zctx.MustLookupTypeRecord([]super.Field{
super.NewField(sumName, super.TypeFloat64),
super.NewField(countName, super.TypeUint64),
})
return super.NewValue(typ, zv)
}
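
The shape of the avg partial above is what makes the aggregate mergeable: each stage ships its running sum and count, and only the final stage divides. A tiny standalone illustration of that arithmetic follows; the helper name is made up for this sketch.

// mergeAvgPartials shows why an avg partial carries {sum, count} rather than
// a per-stage average: the true average is the total sum over the total count.
// For example, mergeAvgPartials([]float64{6, 4}, []uint64{2, 4}) == 10.0/6,
// while naively averaging the per-stage averages (3 and 1) would give 2.
func mergeAvgPartials(sums []float64, counts []uint64) float64 {
	var sum float64
	var count uint64
	for i := range sums {
		sum += sums[i]
		count += counts[i]
	}
	if count == 0 {
		return 0 // the runtime returns a null float64 in this case
	}
	return sum / float64(count)
}
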
30 changes: 30 additions & 0 deletions runtime/vam/expr/agg/count.go
@@ -0,0 +1,30 @@
package agg

import (
"github.com/brimdata/super"
"github.com/brimdata/super/vector"
)

type count struct {
count uint64
}

func (a *count) Consume(vec vector.Any) {
a.count += uint64(vec.Len())
}

func (a *count) Result(*super.Context) super.Value {
return super.NewUint64(a.count)
}

func (a *count) ConsumeAsPartial(partial vector.Any) {
c, ok := partial.(*vector.Const)
if !ok || c.Len() != 1 || partial.Type() != super.TypeUint64 {
panic("count: bad partial")
}
a.count += c.Value().Uint()
}

func (a *count) ResultAsPartial(*super.Context) super.Value {
return a.Result(nil)
}
10 changes: 9 additions & 1 deletion runtime/vam/expr/agg/math.go
@@ -27,7 +27,7 @@ func newMathReducer(f *mathFunc) *mathReducer {

var _ Func = (*mathReducer)(nil)

func (m *mathReducer) Result() super.Value {
func (m *mathReducer) Result(*super.Context) super.Value {
if !m.hasval {
if m.math == nil {
return super.Null
@@ -78,6 +78,14 @@ func (m *mathReducer) Consume(vec vector.Any) {
m.math.consume(vec)
}

func (m *mathReducer) ConsumeAsPartial(vec vector.Any) {
m.Consume(vec)
}

func (m *mathReducer) ResultAsPartial(*super.Context) super.Value {
return m.Result(nil)
}

func isNull(vec vector.Any) bool {
if c, ok := vec.(*vector.Const); ok && c.Value().IsNull() {
return true
64 changes: 53 additions & 11 deletions runtime/vam/op/summarize/agg.go
@@ -20,9 +20,12 @@ type aggTable interface {
}

type superTable struct {
table map[string]aggRow
aggs []*expr.Aggregator
builder *vector.RecordBuilder
aggs []*expr.Aggregator
builder *vector.RecordBuilder
partialsIn bool
partialsOut bool
table map[string]aggRow
zctx *super.Context
}

var _ aggTable = (*superTable)(nil)
@@ -58,7 +61,11 @@ func (s *superTable) update(keys []vector.Any, args []vector.Any) {
if len(m) > 1 {
arg = vector.NewView(arg, index)
}
row.funcs[i].Consume(arg)
if s.partialsIn {
row.funcs[i].ConsumeAsPartial(arg)
} else {
row.funcs[i].Consume(arg)
}
}
}
}
@@ -95,23 +102,37 @@ func (s *superTable) materializeRow(row aggRow) vector.Any {
vecs = append(vecs, vector.NewConst(key, 1, nil))
}
for _, fn := range row.funcs {
val := fn.Result()
var val super.Value
if s.partialsOut {
val = fn.ResultAsPartial(s.zctx)
} else {
val = fn.Result(s.zctx)
}
vecs = append(vecs, vector.NewConst(val, 1, nil))
}
return s.builder.New(vecs)
}

type countByString struct {
nulls uint64
table map[string]uint64
builder *vector.RecordBuilder
nulls uint64
table map[string]uint64
builder *vector.RecordBuilder
partialsIn bool
}

func newCountByString(b *vector.RecordBuilder) aggTable {
return &countByString{builder: b, table: make(map[string]uint64)}
func newCountByString(b *vector.RecordBuilder, partialsIn bool) aggTable {
return &countByString{
builder: b,
table: make(map[string]uint64),
partialsIn: partialsIn,
}
}

func (c *countByString) update(keys []vector.Any, _ []vector.Any) {
func (c *countByString) update(keys, vals []vector.Any) {
if c.partialsIn {
c.updatePartial(keys[0], vals[0])
return
}
switch val := keys[0].(type) {
case *vector.String:
c.count(val)
@@ -124,6 +145,27 @@ func (c *countByString) update(keys []vector.Any, _ []vector.Any) {
}
}

func (c *countByString) updatePartial(keyvec, valvec vector.Any) {
key, ok1 := keyvec.(*vector.String)
val, ok2 := valvec.(*vector.Uint)
if !ok1 || !ok2 {
panic("count by string: invalid partials in")
}
if val.Nulls != nil {
for i := range key.Len() {
if val.Nulls.Value(i) {
c.nulls++
} else {
c.table[key.Value(i)] += val.Values[i]
}
}
} else {
for i := range key.Len() {
c.table[key.Value(i)] += val.Values[i]
}
}
}

func (c *countByString) count(vec *vector.String) {
offs := vec.Offsets
bytes := vec.Bytes
43 changes: 25 additions & 18 deletions runtime/vam/op/summarize/summarize.go
@@ -12,31 +12,35 @@ type Summarize struct {
zctx *super.Context
// XX Abstract this runtime into a generic table computation.
// Then the generic interface can execute fast paths for simple scenarios.
aggs []*expr.Aggregator
aggNames field.List
keyExprs []expr.Evaluator
keyNames field.List
typeTable *super.TypeVectorTable
builder *vector.RecordBuilder
aggs []*expr.Aggregator
aggNames field.List
keyExprs []expr.Evaluator
keyNames field.List
typeTable *super.TypeVectorTable
builder *vector.RecordBuilder
partialsIn bool
partialsOut bool

types []super.Type
tables map[int]aggTable
results []aggTable
}

func New(parent vector.Puller, zctx *super.Context, aggPaths field.List, aggs []*expr.Aggregator, keyNames []field.Path, keyExprs []expr.Evaluator) (*Summarize, error) {
func New(parent vector.Puller, zctx *super.Context, aggPaths field.List, aggs []*expr.Aggregator, keyNames []field.Path, keyExprs []expr.Evaluator, partialsIn, partialsOut bool) (*Summarize, error) {
builder, err := vector.NewRecordBuilder(zctx, append(keyNames, aggPaths...))
if err != nil {
return nil, err
}
return &Summarize{
parent: parent,
aggs: aggs,
keyExprs: keyExprs,
tables: make(map[int]aggTable),
typeTable: super.NewTypeVectorTable(),
types: make([]super.Type, len(keyExprs)),
builder: builder,
parent: parent,
aggs: aggs,
keyExprs: keyExprs,
tables: make(map[int]aggTable),
typeTable: super.NewTypeVectorTable(),
types: make([]super.Type, len(keyExprs)),
builder: builder,
partialsIn: partialsIn,
partialsOut: partialsOut,
}, nil
}

@@ -94,12 +98,15 @@ func (s *Summarize) consume(keys []vector.Any, vals []vector.Any) {
func (s *Summarize) newAggTable(keyTypes []super.Type) aggTable {
// Check if we can use an optimized table, else go slow path.
if s.isCountByString(keyTypes) {
return newCountByString(s.builder)
return newCountByString(s.builder, s.partialsIn)
}
return &superTable{
table: make(map[string]aggRow),
aggs: s.aggs,
builder: s.builder,
aggs: s.aggs,
builder: s.builder,
partialsIn: s.partialsIn,
partialsOut: s.partialsOut,
table: make(map[string]aggRow),
zctx: s.zctx,
}
}

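To see how the new partialsIn/partialsOut flags are meant to be used end to end, here is a hypothetical sketch of a two-stage plan: worker stages emit partials and a single merge stage folds them into final values. The buildDistributedAgg helper, the combine argument, and any import paths beyond those implied by the diff are assumptions, not part of this commit.

package plan // hypothetical package, for illustration only

import (
	"github.com/brimdata/super"
	"github.com/brimdata/super/pkg/field"
	"github.com/brimdata/super/runtime/vam/expr"
	"github.com/brimdata/super/runtime/vam/op/summarize"
	"github.com/brimdata/super/vector"
)

// buildDistributedAgg sketches a two-stage aggregation: each worker's
// Summarize runs with partialsOut=true so it emits partial tables, and the
// merge stage runs with partialsIn=true so it folds those partials into
// final results. The combine argument stands in for whatever operator
// unions the worker outputs (e.g. across the network) and is hypothetical.
func buildDistributedAgg(
	zctx *super.Context,
	workers []vector.Puller,
	combine func([]vector.Puller) vector.Puller,
	aggNames, keyNames field.List,
	aggs []*expr.Aggregator,
	keyExprs []expr.Evaluator,
) (vector.Puller, error) {
	var partials []vector.Puller
	for _, w := range workers {
		s, err := summarize.New(w, zctx, aggNames, aggs, keyNames, keyExprs, false, true)
		if err != nil {
			return nil, err
		}
		partials = append(partials, s)
	}
	return summarize.New(combine(partials), zctx, aggNames, aggs, keyNames, keyExprs, true, false)
}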
