diff --git a/compiler/kernel/vexpr.go b/compiler/kernel/vexpr.go index 7b39a55681..833caf9ef4 100644 --- a/compiler/kernel/vexpr.go +++ b/compiler/kernel/vexpr.go @@ -133,9 +133,8 @@ func (b *Builder) compileVamUnary(unary dag.UnaryExpr) (vamexpr.Evaluator, error return nil, err } switch unary.Op { - //XXX TBD - //case "-": - // return vamexpr.NewUnaryMinus(b.zctx(), e), nil + case "-": + return vamexpr.NewUnaryMinus(b.zctx(), e), nil case "!": return vamexpr.NewLogicalNot(b.zctx(), e), nil default: diff --git a/runtime/sam/expr/eval.go b/runtime/sam/expr/eval.go index 71e3f27559..244332f687 100644 --- a/runtime/sam/expr/eval.go +++ b/runtime/sam/expr/eval.go @@ -545,6 +545,26 @@ func NewUnaryMinus(zctx *super.Context, e Evaluator) *UnaryMinus { func (u *UnaryMinus) Eval(ectx Context, this super.Value) super.Value { val := u.expr.Eval(ectx, this) typ := val.Type() + if super.IsUnsigned(typ.ID()) { + switch typ.ID() { + case super.IDUint8: + typ = super.TypeInt8 + case super.IDUint16: + typ = super.TypeInt16 + case super.IDUint32: + typ = super.TypeInt32 + default: + typ = super.TypeInt64 + } + v, ok := coerce.ToInt(val, typ) + if !ok { + return u.zctx.WrapError("cannot cast to "+zson.FormatType(typ), val) + } + if val.IsNull() { + return super.NewValue(typ, nil) + } + val = super.NewInt(typ, v) + } if val.IsNull() && super.IsNumber(typ.ID()) { return val } @@ -575,30 +595,6 @@ func (u *UnaryMinus) Eval(ectx Context, this super.Value) super.Value { return u.zctx.WrapError("unary '-' underflow", val) } return super.NewInt64(-v) - case super.IDUint8: - v := val.Uint() - if v > math.MaxInt8 { - return u.zctx.WrapError("unary '-' overflow", val) - } - return super.NewInt8(int8(-v)) - case super.IDUint16: - v := val.Uint() - if v > math.MaxInt16 { - return u.zctx.WrapError("unary '-' overflow", val) - } - return super.NewInt16(int16(-v)) - case super.IDUint32: - v := val.Uint() - if v > math.MaxInt32 { - return u.zctx.WrapError("unary '-' overflow", val) - } - return super.NewInt32(int32(-v)) - case super.IDUint64: - v := val.Uint() - if v > math.MaxInt64 { - return u.zctx.WrapError("unary '-' overflow", val) - } - return super.NewInt64(int64(-v)) } return u.zctx.WrapError("type incompatible with unary '-' operator", val) } diff --git a/runtime/vam/expr/unaryminus.go b/runtime/vam/expr/unaryminus.go new file mode 100644 index 0000000000..79071d32a3 --- /dev/null +++ b/runtime/vam/expr/unaryminus.go @@ -0,0 +1,150 @@ +package expr + +import ( + "math" + + "github.com/brimdata/super" + "github.com/brimdata/super/runtime/vam/expr/cast" + "github.com/brimdata/super/vector" +) + +type unaryMinus struct { + zctx *super.Context + expr Evaluator +} + +func NewUnaryMinus(zctx *super.Context, eval Evaluator) Evaluator { + return &unaryMinus{zctx, eval} +} + +func (u *unaryMinus) Eval(this vector.Any) vector.Any { + return vector.Apply(true, u.eval, u.expr.Eval(this)) +} + +func (u *unaryMinus) eval(vecs ...vector.Any) vector.Any { + vec := vector.Under(vecs[0]) + if vec.Len() == 0 { + return vec + } + if _, ok := vec.(*vector.Error); ok { + return vec + } + id := vec.Type().ID() + if !super.IsNumber(vec.Type().ID()) { + return vector.NewWrappedError(u.zctx, "type incompatible with unary '-' operator", vecs[0]) + } + if super.IsUnsigned(id) { + var typ super.Type + switch id { + case super.IDUint8: + typ = super.TypeInt8 + case super.IDUint16: + typ = super.TypeInt16 + case super.IDUint32: + typ = super.TypeInt32 + default: + typ = super.TypeInt64 + } + return u.eval(cast.To(u.zctx, vec, typ)) + } + out, ok := u.convert(vec) + if !ok { + // Overflow for int detected, go slow path. + return u.slowPath(vec) + } + return out +} + +func (u *unaryMinus) convert(vec vector.Any) (vector.Any, bool) { + switch vec := vec.(type) { + case *vector.Const: + var val super.Value + if super.IsFloat(vec.Type().ID()) { + val = super.NewFloat(vec.Type(), -vec.Value().Float()) + } else { + v := vec.Value().Int() + if v == minInt(vec.Type()) { + return nil, false + } + val = super.NewInt(vec.Type(), -vec.Value().Int()) + } + return vector.NewConst(val, vec.Len(), vec.Nulls), true + case *vector.Dict: + out, ok := u.convert(vec.Any) + if !ok { + return nil, false + } + return &vector.Dict{ + Any: out, + Index: vec.Index, + Counts: vec.Counts, + Nulls: vec.Nulls, + }, true + case *vector.View: + out, ok := u.convert(vec.Any) + if !ok { + return nil, false + } + return &vector.View{Any: out, Index: vec.Index}, true + case *vector.Int: + min := minInt(vec.Type()) + out := make([]int64, vec.Len()) + for i := range vec.Len() { + if vec.Values[i] == min { + return nil, false + } + out[i] = -vec.Values[i] + } + return vector.NewInt(vec.Typ, out, vec.Nulls), true + case *vector.Float: + out := make([]float64, vec.Len()) + for i := range vec.Len() { + out[i] = -vec.Values[i] + } + return vector.NewFloat(vec.Typ, out, vec.Nulls), true + default: + panic(vec) + } +} + +func (u *unaryMinus) slowPath(vec vector.Any) vector.Any { + var nulls *vector.Bool + var ints []int64 + var errs []uint32 + minval := minInt(vec.Type()) + for i := range vec.Len() { + v, isnull := vector.IntValue(vec, i) + if isnull { + if nulls == nil { + nulls = vector.NewBoolEmpty(vec.Len(), nil) + } + nulls.Set(uint32(len(ints))) + ints = append(ints, 0) + continue + } + if v == minval { + errs = append(errs, i) + } else { + ints = append(ints, -v) + } + } + if nulls != nil { + nulls.SetLen(uint32(len(ints))) + } + out := vector.NewInt(vec.Type(), ints, nulls) + err := vector.NewWrappedError(u.zctx, "unary '-' underflow", vector.NewView(vec, errs)) + return vector.Combine(out, errs, err) +} + +func minInt(typ super.Type) int64 { + switch typ.ID() { + case super.IDInt8: + return math.MinInt8 + case super.IDInt16: + return math.MinInt16 + case super.IDInt32: + return math.MinInt32 + default: + return math.MinInt64 + } +} diff --git a/runtime/sam/expr/ztests/unary-minus.yaml b/runtime/ztests/expr/unary-minus.yaml similarity index 95% rename from runtime/sam/expr/ztests/unary-minus.yaml rename to runtime/ztests/expr/unary-minus.yaml index d8fee7e48e..9245786350 100644 --- a/runtime/sam/expr/ztests/unary-minus.yaml +++ b/runtime/ztests/expr/unary-minus.yaml @@ -1,5 +1,7 @@ zed: yield -this +vector: true + input: | 1(int8) -1(int8) @@ -66,10 +68,10 @@ output: | null(int16) null(int32) null(int64) - null(uint8) - null(uint16) - null(uint32) - null(uint64) + null(int8) + null(int16) + null(int32) + null(int64) null(float16) null(float32) null(float64) diff --git a/vector/bool.go b/vector/bool.go index b0a23d9edb..289d24720c 100644 --- a/vector/bool.go +++ b/vector/bool.go @@ -39,6 +39,10 @@ func (b *Bool) Set(slot uint32) { b.Bits[slot>>6] |= (1 << (slot & 0x3f)) } +func (b *Bool) SetLen(len uint32) { + b.len = len +} + func (b *Bool) Len() uint32 { if b == nil { return 0