diff --git a/runtime/expr/sort.go b/runtime/expr/sort.go index d14ca74e70..13d37e1dc2 100644 --- a/runtime/expr/sort.go +++ b/runtime/expr/sort.go @@ -2,6 +2,7 @@ package expr import ( "bytes" + "cmp" "fmt" "math" "slices" @@ -326,24 +327,12 @@ func LookupCompare(typ zed.Type) comparefn { case zed.IDInt16, zed.IDInt32, zed.IDInt64: return func(a, b zcode.Bytes) int { - va, vb := zed.DecodeInt(a), zed.DecodeInt(b) - if va < vb { - return -1 - } else if va > vb { - return 1 - } - return 0 + return cmp.Compare(zed.DecodeInt(a), zed.DecodeInt(b)) } case zed.IDUint16, zed.IDUint32, zed.IDUint64: return func(a, b zcode.Bytes) int { - va, vb := zed.DecodeUint(a), zed.DecodeUint(b) - if va < vb { - return -1 - } else if va > vb { - return 1 - } - return 0 + return cmp.Compare(zed.DecodeUint(a), zed.DecodeUint(b)) } case zed.IDFloat16, zed.IDFloat32, zed.IDFloat64: @@ -352,41 +341,19 @@ func LookupCompare(typ zed.Type) comparefn { aNaN, bNaN := math.IsNaN(va), math.IsNaN(vb) if aNaN && bNaN { // Order different NaNs so ZNG sets have a canonical form. - aBits, bBits := math.Float64bits(va), math.Float64bits(vb) - if aBits < bBits { - return -1 - } else if aBits > bBits { - return 1 - } - return 0 - } else if aNaN || va < vb { - return -1 - } else if bNaN || va > vb { - return 1 + cmp.Compare(math.Float64bits(va), math.Float64bits(vb)) } - return 0 + return cmp.Compare(va, vb) } case zed.IDTime: return func(a, b zcode.Bytes) int { - va, vb := zed.DecodeTime(a), zed.DecodeTime(b) - if va < vb { - return -1 - } else if va > vb { - return 1 - } - return 0 + return cmp.Compare(zed.DecodeTime(a), zed.DecodeTime(b)) } case zed.IDDuration: return func(a, b zcode.Bytes) int { - va, vb := zed.DecodeDuration(a), zed.DecodeDuration(b) - if va < vb { - return -1 - } else if va > vb { - return 1 - } - return 0 + return cmp.Compare(zed.DecodeDuration(a), zed.DecodeDuration(b)) } case zed.IDIP: diff --git a/type.go b/type.go index e6186ae1ad..0fa00a69c4 100644 --- a/type.go +++ b/type.go @@ -10,6 +10,7 @@ package zed import ( + "cmp" "encoding/binary" "errors" "fmt" @@ -411,17 +412,17 @@ func CompareTypes(a, b Type) int { // a == b return 0 } - if cmp := compareInts(int(a.Kind()), int(b.Kind())); cmp != 0 { + if cmp := cmp.Compare(a.Kind(), b.Kind()); cmp != 0 { return cmp } switch a.Kind() { case PrimitiveKind: - return compareInts(aID, bID) + return cmp.Compare(aID, bID) case RecordKind: ra, rb := TypeRecordOf(a), TypeRecordOf(b) // First compare number of fields. - if len(ra.Fields) != len(rb.Fields) { - return compareInts(len(ra.Fields), len(rb.Fields)) + if cmp := cmp.Compare(len(ra.Fields), len(rb.Fields)); cmp != 0 { + return cmp } // Second compare field names. for i := 0; i < len(ra.Fields); i++ { @@ -447,7 +448,7 @@ func CompareTypes(a, b Type) int { return CompareTypes(ma.ValType, mb.ValType) case UnionKind: ua, ub := a.(*TypeUnion), b.(*TypeUnion) - if cmp := compareInts(len(ua.Types), len(ub.Types)); cmp != 0 { + if cmp := cmp.Compare(len(ua.Types), len(ub.Types)); cmp != 0 { return cmp } for i := 0; i < len(ua.Types); i++ { @@ -458,7 +459,7 @@ func CompareTypes(a, b Type) int { return 0 case EnumKind: ea, eb := a.(*TypeEnum), b.(*TypeEnum) - if cmp := compareInts(len(ea.Symbols), len(eb.Symbols)); cmp != 0 { + if cmp := cmp.Compare(len(ea.Symbols), len(eb.Symbols)); cmp != 0 { return cmp } for i := 0; i < len(ea.Symbols); i++ { @@ -474,15 +475,6 @@ func CompareTypes(a, b Type) int { return 0 } -func compareInts(a, b int) int { - if a < b { - return -1 - } else if a > b { - return 1 - } - return 0 -} - type TypeOfType struct{} func (t *TypeOfType) ID() int {