Skip to content

Commit

Permalink
vam: Support for unary minus (#5510)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattnibs authored Dec 4, 2024
1 parent 1f0c0ce commit 200f373
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 31 deletions.
5 changes: 2 additions & 3 deletions compiler/kernel/vexpr.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,8 @@ func (b *Builder) compileVamUnary(unary dag.UnaryExpr) (vamexpr.Evaluator, error
return nil, err
}
switch unary.Op {
//XXX TBD
//case "-":
// return vamexpr.NewUnaryMinus(b.zctx(), e), nil
case "-":
return vamexpr.NewUnaryMinus(b.zctx(), e), nil
case "!":
return vamexpr.NewLogicalNot(b.zctx(), e), nil
default:
Expand Down
44 changes: 20 additions & 24 deletions runtime/sam/expr/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,26 @@ func NewUnaryMinus(zctx *super.Context, e Evaluator) *UnaryMinus {
func (u *UnaryMinus) Eval(ectx Context, this super.Value) super.Value {
val := u.expr.Eval(ectx, this)
typ := val.Type()
if super.IsUnsigned(typ.ID()) {
switch typ.ID() {
case super.IDUint8:
typ = super.TypeInt8
case super.IDUint16:
typ = super.TypeInt16
case super.IDUint32:
typ = super.TypeInt32
default:
typ = super.TypeInt64
}
v, ok := coerce.ToInt(val, typ)
if !ok {
return u.zctx.WrapError("cannot cast to "+zson.FormatType(typ), val)
}
if val.IsNull() {
return super.NewValue(typ, nil)
}
val = super.NewInt(typ, v)
}
if val.IsNull() && super.IsNumber(typ.ID()) {
return val
}
Expand Down Expand Up @@ -575,30 +595,6 @@ func (u *UnaryMinus) Eval(ectx Context, this super.Value) super.Value {
return u.zctx.WrapError("unary '-' underflow", val)
}
return super.NewInt64(-v)
case super.IDUint8:
v := val.Uint()
if v > math.MaxInt8 {
return u.zctx.WrapError("unary '-' overflow", val)
}
return super.NewInt8(int8(-v))
case super.IDUint16:
v := val.Uint()
if v > math.MaxInt16 {
return u.zctx.WrapError("unary '-' overflow", val)
}
return super.NewInt16(int16(-v))
case super.IDUint32:
v := val.Uint()
if v > math.MaxInt32 {
return u.zctx.WrapError("unary '-' overflow", val)
}
return super.NewInt32(int32(-v))
case super.IDUint64:
v := val.Uint()
if v > math.MaxInt64 {
return u.zctx.WrapError("unary '-' overflow", val)
}
return super.NewInt64(int64(-v))
}
return u.zctx.WrapError("type incompatible with unary '-' operator", val)
}
Expand Down
150 changes: 150 additions & 0 deletions runtime/vam/expr/unaryminus.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package expr

import (
"math"

"github.com/brimdata/super"
"github.com/brimdata/super/runtime/vam/expr/cast"
"github.com/brimdata/super/vector"
)

type unaryMinus struct {
zctx *super.Context
expr Evaluator
}

func NewUnaryMinus(zctx *super.Context, eval Evaluator) Evaluator {
return &unaryMinus{zctx, eval}
}

func (u *unaryMinus) Eval(this vector.Any) vector.Any {
return vector.Apply(true, u.eval, u.expr.Eval(this))
}

func (u *unaryMinus) eval(vecs ...vector.Any) vector.Any {
vec := vector.Under(vecs[0])
if vec.Len() == 0 {
return vec
}
if _, ok := vec.(*vector.Error); ok {
return vec
}
id := vec.Type().ID()
if !super.IsNumber(vec.Type().ID()) {
return vector.NewWrappedError(u.zctx, "type incompatible with unary '-' operator", vecs[0])
}
if super.IsUnsigned(id) {
var typ super.Type
switch id {
case super.IDUint8:
typ = super.TypeInt8
case super.IDUint16:
typ = super.TypeInt16
case super.IDUint32:
typ = super.TypeInt32
default:
typ = super.TypeInt64
}
return u.eval(cast.To(u.zctx, vec, typ))
}
out, ok := u.convert(vec)
if !ok {
// Overflow for int detected, go slow path.
return u.slowPath(vec)
}
return out
}

func (u *unaryMinus) convert(vec vector.Any) (vector.Any, bool) {
switch vec := vec.(type) {
case *vector.Const:
var val super.Value
if super.IsFloat(vec.Type().ID()) {
val = super.NewFloat(vec.Type(), -vec.Value().Float())
} else {
v := vec.Value().Int()
if v == minInt(vec.Type()) {
return nil, false
}
val = super.NewInt(vec.Type(), -vec.Value().Int())
}
return vector.NewConst(val, vec.Len(), vec.Nulls), true
case *vector.Dict:
out, ok := u.convert(vec.Any)
if !ok {
return nil, false
}
return &vector.Dict{
Any: out,
Index: vec.Index,
Counts: vec.Counts,
Nulls: vec.Nulls,
}, true
case *vector.View:
out, ok := u.convert(vec.Any)
if !ok {
return nil, false
}
return &vector.View{Any: out, Index: vec.Index}, true
case *vector.Int:
min := minInt(vec.Type())
out := make([]int64, vec.Len())
for i := range vec.Len() {
if vec.Values[i] == min {
return nil, false
}
out[i] = -vec.Values[i]
}
return vector.NewInt(vec.Typ, out, vec.Nulls), true
case *vector.Float:
out := make([]float64, vec.Len())
for i := range vec.Len() {
out[i] = -vec.Values[i]
}
return vector.NewFloat(vec.Typ, out, vec.Nulls), true
default:
panic(vec)
}
}

func (u *unaryMinus) slowPath(vec vector.Any) vector.Any {
var nulls *vector.Bool
var ints []int64
var errs []uint32
minval := minInt(vec.Type())
for i := range vec.Len() {
v, isnull := vector.IntValue(vec, i)
if isnull {
if nulls == nil {
nulls = vector.NewBoolEmpty(vec.Len(), nil)
}
nulls.Set(uint32(len(ints)))
ints = append(ints, 0)
continue
}
if v == minval {
errs = append(errs, i)
} else {
ints = append(ints, -v)
}
}
if nulls != nil {
nulls.SetLen(uint32(len(ints)))
}
out := vector.NewInt(vec.Type(), ints, nulls)
err := vector.NewWrappedError(u.zctx, "unary '-' underflow", vector.NewView(vec, errs))
return vector.Combine(out, errs, err)
}

func minInt(typ super.Type) int64 {
switch typ.ID() {
case super.IDInt8:
return math.MinInt8
case super.IDInt16:
return math.MinInt16
case super.IDInt32:
return math.MinInt32
default:
return math.MinInt64
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
zed: yield -this

vector: true

input: |
1(int8)
-1(int8)
Expand Down Expand Up @@ -66,10 +68,10 @@ output: |
null(int16)
null(int32)
null(int64)
null(uint8)
null(uint16)
null(uint32)
null(uint64)
null(int8)
null(int16)
null(int32)
null(int64)
null(float16)
null(float32)
null(float64)
4 changes: 4 additions & 0 deletions vector/bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ func (b *Bool) Set(slot uint32) {
b.Bits[slot>>6] |= (1 << (slot & 0x3f))
}

func (b *Bool) SetLen(len uint32) {
b.len = len
}

func (b *Bool) Len() uint32 {
if b == nil {
return 0
Expand Down

0 comments on commit 200f373

Please sign in to comment.