diff --git a/pkg/bmt/bmt.go b/pkg/bmt/bmt.go index 38c0e8bff5f..e13aa3ec1cb 100644 --- a/pkg/bmt/bmt.go +++ b/pkg/bmt/bmt.go @@ -40,19 +40,6 @@ type Hasher struct { span []byte // The span of the data subsumed under the chunk } -// facade -func NewHasher(hasherFact func() hash.Hash) *Hasher { - conf := NewConf(hasherFact, swarm.BmtBranches, 32) - - return &Hasher{ - Conf: conf, - result: make(chan []byte), - errc: make(chan error, 1), - span: make([]byte, SpanSize), - bmt: newTree(conf.segmentSize, conf.maxSize, conf.depth, conf.hasher), - } -} - // Capacity returns the maximum amount of bytes that will be processed by this hasher implementation. // since BMT assumes a balanced binary tree, capacity it is always a power of 2 func (h *Hasher) Capacity() int { diff --git a/pkg/bmt/proof.go b/pkg/bmt/proof.go index fa39174c3c5..b9a958db9ab 100644 --- a/pkg/bmt/proof.go +++ b/pkg/bmt/proof.go @@ -17,17 +17,6 @@ type Proof struct { Index int } -// Override base hash function of Hasher to fill buffer with zeros until chunk length -func (p Prover) Hash(b []byte) ([]byte, error) { - for i := p.size; i < p.maxSize; i += len(zerosection) { - _, err := p.Write(zerosection) - if err != nil { - return []byte{}, err - } - } - return p.Hasher.Hash(b) -} - // Proof returns the inclusion proof of the i-th data segment func (p Prover) Proof(i int) Proof { index := i @@ -47,42 +36,26 @@ func (p Prover) Proof(i int) Proof { secsize := 2 * p.segmentSize offset := i * secsize section := p.bmt.buffer[offset : offset+secsize] - left := section[:p.segmentSize] - right := section[p.segmentSize:] - var segment, firstSegmentSister []byte - if index%2 == 0 { - segment, firstSegmentSister = left, right - } else { - segment, firstSegmentSister = right, left - } - sisters = append([][]byte{firstSegmentSister}, sisters...) - return Proof{segment, sisters, p.span, index} + return Proof{section, sisters, p.span, index} } // Verify returns the bmt hash obtained from the proof which can then be checked against // the BMT hash of the chunk func (p Prover) Verify(i int, proof Proof) (root []byte, err error) { - var section []byte - if i%2 == 0 { - section = append(append(section, proof.ProveSegment...), proof.ProofSegments[0]...) - } else { - section = append(append(section, proof.ProofSegments[0]...), proof.ProveSegment...) - } i = i / 2 n := p.bmt.leaves[i] - hasher := p.hasher() isLeft := n.isLeft - root, err = doHash(hasher, section) + root, err = doHash(n.hasher, proof.ProveSegment) if err != nil { return nil, err } n = n.parent - for _, sister := range proof.ProofSegments[1:] { + for _, sister := range proof.ProofSegments { if isLeft { - root, err = doHash(hasher, root, sister) + root, err = doHash(n.hasher, root, sister) } else { - root, err = doHash(hasher, sister, root) + root, err = doHash(n.hasher, sister, root) } if err != nil { return nil, err @@ -90,7 +63,7 @@ func (p Prover) Verify(i int, proof Proof) (root []byte, err error) { isLeft = n.isLeft n = n.parent } - return doHash(hasher, proof.Span, root) + return sha3hash(proof.Span, root) } func (n *node) getSister(isLeft bool) []byte { diff --git a/pkg/bmt/proof_test.go b/pkg/bmt/proof_test.go index 1b7f6d3b3dd..337b1bf3420 100644 --- a/pkg/bmt/proof_test.go +++ b/pkg/bmt/proof_test.go @@ -20,8 +20,7 @@ func TestProofCorrectness(t *testing.T) { t.Parallel() testData := []byte("hello world") - testDataPadded := make([]byte, swarm.ChunkSize) - copy(testDataPadded, testData) + testData = append(testData, make([]byte, 4096-len(testData))...) verifySegments := func(t *testing.T, exp []string, found [][]byte) { t.Helper() @@ -58,8 +57,8 @@ func TestProofCorrectness(t *testing.T) { if err != nil { t.Fatal(err) } - pr := bmt.Prover{hh} - rh, err := pr.Hash(nil) + + rh, err := hh.Hash(nil) if err != nil { t.Fatal(err) } @@ -67,10 +66,9 @@ func TestProofCorrectness(t *testing.T) { t.Run("proof for left most", func(t *testing.T) { t.Parallel() - proof := pr.Proof(0) + proof := bmt.Prover{hh}.Proof(0) expSegmentStrings := []string{ - "0000000000000000000000000000000000000000000000000000000000000000", "ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5", "b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30", "21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85", @@ -81,7 +79,7 @@ func TestProofCorrectness(t *testing.T) { verifySegments(t, expSegmentStrings, proof.ProofSegments) - if !bytes.Equal(proof.ProveSegment, testDataPadded[:hh.Size()]) { + if !bytes.Equal(proof.ProveSegment, testData[:2*hh.Size()]) { t.Fatal("section incorrect") } @@ -93,10 +91,9 @@ func TestProofCorrectness(t *testing.T) { t.Run("proof for right most", func(t *testing.T) { t.Parallel() - proof := pr.Proof(127) + proof := bmt.Prover{hh}.Proof(127) expSegmentStrings := []string{ - "0000000000000000000000000000000000000000000000000000000000000000", "ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5", "b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30", "21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85", @@ -107,7 +104,7 @@ func TestProofCorrectness(t *testing.T) { verifySegments(t, expSegmentStrings, proof.ProofSegments) - if !bytes.Equal(proof.ProveSegment, testDataPadded[127*hh.Size():]) { + if !bytes.Equal(proof.ProveSegment, testData[126*hh.Size():]) { t.Fatal("section incorrect") } @@ -119,10 +116,9 @@ func TestProofCorrectness(t *testing.T) { t.Run("proof for middle", func(t *testing.T) { t.Parallel() - proof := pr.Proof(64) + proof := bmt.Prover{hh}.Proof(64) expSegmentStrings := []string{ - "0000000000000000000000000000000000000000000000000000000000000000", "ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5", "b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30", "21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85", @@ -133,7 +129,7 @@ func TestProofCorrectness(t *testing.T) { verifySegments(t, expSegmentStrings, proof.ProofSegments) - if !bytes.Equal(proof.ProveSegment, testDataPadded[64*hh.Size():65*hh.Size()]) { + if !bytes.Equal(proof.ProveSegment, testData[64*hh.Size():66*hh.Size()]) { t.Fatal("section incorrect") } @@ -146,7 +142,6 @@ func TestProofCorrectness(t *testing.T) { t.Parallel() segmentStrings := []string{ - "0000000000000000000000000000000000000000000000000000000000000000", "ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5", "b4c11951957c6f8f642c4af61cd6b24640fec6dc7fc607ee8206a99e92410d30", "21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85", @@ -164,9 +159,9 @@ func TestProofCorrectness(t *testing.T) { segments = append(segments, decoded) } - segment := testDataPadded[64*hh.Size() : 65*hh.Size()] + segment := testData[64*hh.Size() : 66*hh.Size()] - rootHash, err := pr.Verify(64, bmt.Proof{ + rootHash, err := bmt.Prover{hh}.Verify(64, bmt.Proof{ ProveSegment: segment, ProofSegments: segments, Span: bmt.LengthToSpan(4096), @@ -205,7 +200,6 @@ func TestProof(t *testing.T) { } rh, err := hh.Hash(nil) - pr := bmt.Prover{hh} if err != nil { t.Fatal(err) } @@ -215,7 +209,7 @@ func TestProof(t *testing.T) { t.Run(fmt.Sprintf("segmentIndex %d", i), func(t *testing.T) { t.Parallel() - proof := pr.Proof(i) + proof := bmt.Prover{hh}.Proof(i) h := pool.Get() defer pool.Put(h) diff --git a/pkg/bmt/trhasher.go b/pkg/bmt/trhasher.go new file mode 100644 index 00000000000..00df6664b85 --- /dev/null +++ b/pkg/bmt/trhasher.go @@ -0,0 +1,25 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bmt + +import ( + "hash" + + "github.com/ethersphere/bee/pkg/swarm" +) + +func NewTrHasher(prefix []byte) *Hasher { + capacity := 32 + hasherFact := func() hash.Hash { return swarm.NewTrHasher(prefix) } + conf := NewConf(hasherFact, swarm.BmtBranches, capacity) + + return &Hasher{ + Conf: conf, + result: make(chan []byte), + errc: make(chan error, 1), + span: make([]byte, SpanSize), + bmt: newTree(conf.segmentSize, conf.maxSize, conf.depth, conf.hasher), + } +} diff --git a/pkg/storer/sample.go b/pkg/storer/sample.go index 07c92885ac0..690e06affc3 100644 --- a/pkg/storer/sample.go +++ b/pkg/storer/sample.go @@ -10,7 +10,6 @@ import ( "crypto/hmac" "encoding/binary" "fmt" - "hash" "math/big" "sort" "sync" @@ -44,13 +43,7 @@ type Sample struct { func RandSample(t *testing.T, anchor []byte) Sample { t.Helper() - prefixHasherFactory := func() hash.Hash { - return swarm.NewPrefixHasher(anchor) - } - pool := bmt.NewPool(bmt.NewConf(prefixHasherFactory, swarm.BmtBranches, 8)) - - hasher := pool.Get() - defer pool.Put(hasher) + hasher := bmt.NewTrHasher(anchor) items := make([]SampleItem, SampleSize) for i := 0; i < SampleSize; i++ { diff --git a/pkg/swarm/hasher.go b/pkg/swarm/hasher.go index b9823bb50a1..485b61ab398 100644 --- a/pkg/swarm/hasher.go +++ b/pkg/swarm/hasher.go @@ -15,15 +15,15 @@ func NewHasher() hash.Hash { return sha3.NewLegacyKeccak256() } -type PrefixHasher struct { +type trHasher struct { hash.Hash prefix []byte } -// NewPrefixHasher returns new hasher which is Keccak-256 hasher +// NewTrHasher returns new hasher which is Keccak-256 hasher // with prefix value added as initial data. -func NewPrefixHasher(prefix []byte) hash.Hash { - h := &PrefixHasher{ +func NewTrHasher(prefix []byte) hash.Hash { + h := &trHasher{ Hash: NewHasher(), prefix: prefix, } @@ -32,7 +32,7 @@ func NewPrefixHasher(prefix []byte) hash.Hash { return h } -func (h *PrefixHasher) Reset() { +func (h *trHasher) Reset() { h.Hash.Reset() _, _ = h.Write(h.prefix) } diff --git a/pkg/swarm/hasher_test.go b/pkg/swarm/hasher_test.go index 3811e09605a..bfee0a78e98 100644 --- a/pkg/swarm/hasher_test.go +++ b/pkg/swarm/hasher_test.go @@ -66,7 +66,7 @@ func TestNewTrHasher(t *testing.T) { // Run tests cases against TrHasher for _, tc := range tests { - h := swarm.NewPrefixHasher(tc.prefix) + h := swarm.NewTrHasher(tc.prefix) _, err := h.Write(tc.plaintext) if err != nil {