diff --git a/namespace/namespace.go b/namespace/namespace.go index 1feb6f9..f2e6cb6 100644 --- a/namespace/namespace.go +++ b/namespace/namespace.go @@ -119,23 +119,23 @@ func (n Namespace) IsSecondaryReserved() bool { } func (n Namespace) IsParityShares() bool { - return bytes.Equal(n.Bytes(), ParitySharesNamespace.Bytes()) + return n.Equals(ParitySharesNamespace) } func (n Namespace) IsTailPadding() bool { - return bytes.Equal(n.Bytes(), TailPaddingNamespace.Bytes()) + return n.Equals(TailPaddingNamespace) } func (n Namespace) IsPrimaryReservedPadding() bool { - return bytes.Equal(n.Bytes(), PrimaryReservedPaddingNamespace.Bytes()) + return n.Equals(PrimaryReservedPaddingNamespace) } func (n Namespace) IsTx() bool { - return bytes.Equal(n.Bytes(), TxNamespace.Bytes()) + return n.Equals(TxNamespace) } func (n Namespace) IsPayForBlob() bool { - return bytes.Equal(n.Bytes(), PayForBlobNamespace.Bytes()) + return n.Equals(PayForBlobNamespace) } func (n Namespace) Repeat(times int) []Namespace { @@ -147,23 +147,34 @@ func (n Namespace) Repeat(times int) []Namespace { } func (n Namespace) Equals(n2 Namespace) bool { - return bytes.Equal(n.Bytes(), n2.Bytes()) + return n.Version == n2.Version && bytes.Equal(n.ID, n2.ID) } func (n Namespace) IsLessThan(n2 Namespace) bool { - return bytes.Compare(n.Bytes(), n2.Bytes()) == -1 + return n.Compare(n2) == -1 } func (n Namespace) IsLessOrEqualThan(n2 Namespace) bool { - return bytes.Compare(n.Bytes(), n2.Bytes()) < 1 + return n.Compare(n2) < 1 } func (n Namespace) IsGreaterThan(n2 Namespace) bool { - return bytes.Compare(n.Bytes(), n2.Bytes()) == 1 + return n.Compare(n2) == 1 } func (n Namespace) IsGreaterOrEqualThan(n2 Namespace) bool { - return bytes.Compare(n.Bytes(), n2.Bytes()) > -1 + return n.Compare(n2) > -1 +} + +func (n Namespace) Compare(n2 Namespace) int { + switch { + case n.Version == n2.Version: + return bytes.Compare(n.ID, n2.ID) + case n.Version < n2.Version: + return -1 + default: + return 1 + } } // leftPad returns a new byte slice with the provided byte slice left-padded to the provided size. diff --git a/namespace/namespace_test.go b/namespace/namespace_test.go index 57c125c..0bf5430 100644 --- a/namespace/namespace_test.go +++ b/namespace/namespace_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -297,3 +298,121 @@ func TestIsReserved(t *testing.T) { assert.Equal(t, tc.want, got) } } + +func Test_compareMethods(t *testing.T) { + minID := RandomBlobNamespaceID() + maxID := RandomBlobNamespaceID() + // repeat until maxID meets our expectations (maxID > minID). + for bytes.Compare(maxID, minID) != 1 { + maxID = RandomBlobNamespaceID() + } + + vers := []byte{NamespaceVersionZero, NamespaceVersionMax} + ids := [][]byte{minID, maxID} + + // collect all possible pairs: (ver1 ?? ver2) x (id1 ?? id2) + var testPairs [][2]Namespace + for _, ver1 := range vers { + for _, ver2 := range vers { + for _, id1 := range ids { + for _, id2 := range ids { + testPairs = append(testPairs, [2]Namespace{ + {Version: ver1, ID: id1}, + {Version: ver2, ID: id2}, + }) + } + } + } + } + require.Len(t, testPairs, 16) // len(vers) * len(vers) * len(ids) * len(ids) + + type testCase struct { + name string + fn func(n, n2 Namespace) bool + old func(n, n2 Namespace) bool + } + testCases := []testCase{ + { + name: "Equals", + fn: Namespace.Equals, + old: func(n, n2 Namespace) bool { + return bytes.Equal(n.Bytes(), n2.Bytes()) + }, + }, + { + name: "IsLessThan", + fn: Namespace.IsLessThan, + old: func(n, n2 Namespace) bool { + return bytes.Compare(n.Bytes(), n2.Bytes()) == -1 + }, + }, + { + name: "IsLessOrEqualThan", + fn: Namespace.IsLessOrEqualThan, + old: func(n, n2 Namespace) bool { + return bytes.Compare(n.Bytes(), n2.Bytes()) < 1 + }, + }, + { + name: "IsGreaterThan", + fn: Namespace.IsGreaterThan, + old: func(n, n2 Namespace) bool { + return bytes.Compare(n.Bytes(), n2.Bytes()) == 1 + }, + }, + { + name: "IsGreaterOrEqualThan", + fn: Namespace.IsGreaterOrEqualThan, + old: func(n, n2 Namespace) bool { + return bytes.Compare(n.Bytes(), n2.Bytes()) > -1 + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for i, p := range testPairs { + n, n2 := p[0], p[1] + got := tc.fn(n, n2) + want := tc.old(n, n2) + assert.Equal(t, want, got, "for pair %d", i) + } + }) + } +} + +func BenchmarkEqual(b *testing.B) { + n1 := RandomNamespace() + n2 := RandomNamespace() + // repeat until n2 meets our expectations (n1 != n2). + for n1.Equals(n2) { + n2 = RandomNamespace() + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + if n1.Equals(n2) { + b.Fatal() + } + } +} + +func BenchmarkCompare(b *testing.B) { + n1 := RandomNamespace() + n2 := RandomNamespace() + // repeat until n2 meets our expectations (n1 > n2). + for n1.Compare(n2) != 1 { + n2 = RandomNamespace() + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + if n1.Compare(n2) != 1 { + b.Fatal() + } + } +}