From 2b0e397c6f028815a29e879b0fdb2ce218543ab6 Mon Sep 17 00:00:00 2001 From: istae <14264581+istae@users.noreply.github.com> Date: Wed, 20 Sep 2023 14:59:07 +0300 Subject: [PATCH] fix: unit test --- pkg/api/api.go | 8 +++--- pkg/api/postage.go | 9 +------ pkg/api/pss.go | 4 +-- pkg/api/stewardship.go | 6 ++--- pkg/postage/mock/service.go | 6 ++--- pkg/postage/service.go | 11 +++------ pkg/postage/service_test.go | 44 ++++++--------------------------- pkg/postage/stamper.go | 7 ++++-- pkg/postage/stamper_test.go | 19 +++++++++++++- pkg/postage/stampissuer_test.go | 7 +++++- 10 files changed, 53 insertions(+), 68 deletions(-) diff --git a/pkg/api/api.go b/pkg/api/api.go index d3266421145..3ecea7ad000 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -772,7 +772,7 @@ type putterOptions struct { type putterSessionWrapper struct { storer.PutterSession stamper postage.Stamper - save func(bool) error + save func() error } func (p *putterSessionWrapper) Put(ctx context.Context, chunk swarm.Chunk) error { @@ -788,14 +788,14 @@ func (p *putterSessionWrapper) Done(ref swarm.Address) error { if err != nil { return err } - return p.save(true) + return p.save() } func (p *putterSessionWrapper) Cleanup() error { - return errors.Join(p.PutterSession.Cleanup(), p.save(false)) + return errors.Join(p.PutterSession.Cleanup(), p.save()) } -func (s *Service) getStamper(batchID []byte) (postage.Stamper, func(bool) error, error) { +func (s *Service) getStamper(batchID []byte) (postage.Stamper, func() error, error) { exists, err := s.batchStore.Exists(batchID) if err != nil { return nil, nil, fmt.Errorf("batch exists: %w", err) diff --git a/pkg/api/postage.go b/pkg/api/postage.go index 8eb93c5e989..7df275c29a7 100644 --- a/pkg/api/postage.go +++ b/pkg/api/postage.go @@ -276,7 +276,7 @@ func (s *Service) postageGetStampBucketsHandler(w http.ResponseWriter, r *http.R } hexBatchID := hex.EncodeToString(paths.BatchID) - issuer, save, err := s.post.GetStampIssuer(paths.BatchID) + issuer, _, err := s.post.GetStampIssuer(paths.BatchID) if err != nil { logger.Debug("get stamp issuer: get issuer failed", "batch_id", hexBatchID, "error", err) logger.Error(nil, "get stamp issuer: get issuer failed") @@ -303,13 +303,6 @@ func (s *Service) postageGetStampBucketsHandler(w http.ResponseWriter, r *http.R resp.Buckets[i] = bucketData{BucketID: uint32(i), Collisions: v} } - if err := save(false); err != nil { - logger.Debug("get stamp issuer: save issuer failed", "batch_id", hexBatchID, "error", err) - logger.Error(nil, "get stamp issuer: save issuer failed") - jsonhttp.InternalServerError(w, "save issuer failed") - return - } - jsonhttp.OK(w, resp) } diff --git a/pkg/api/pss.go b/pkg/api/pss.go index f64affe56b1..aa87a6165c8 100644 --- a/pkg/api/pss.go +++ b/pkg/api/pss.go @@ -98,7 +98,6 @@ func (s *Service) pssPostHandler(w http.ResponseWriter, r *http.Request) { err = s.pss.Send(r.Context(), topic, payload, stamper, queries.Recipient, targets) if err != nil { - err = errors.Join(err, save(false)) logger.Debug("send payload failed", "topic", paths.Topic, "error", err) logger.Error(nil, "send payload failed") switch { @@ -110,8 +109,7 @@ func (s *Service) pssPostHandler(w http.ResponseWriter, r *http.Request) { return } - err = save(true) - if err != nil { + if err = save(); err != nil { logger.Debug("save stamp failed", "error", err) logger.Error(nil, "save stamp failed") jsonhttp.InternalServerError(w, "pss send failed") diff --git a/pkg/api/stewardship.go b/pkg/api/stewardship.go index 8f2f5a961c1..fe1beb34e48 100644 --- a/pkg/api/stewardship.go +++ b/pkg/api/stewardship.go @@ -70,20 +70,20 @@ func (s *Service) stewardshipPutHandler(w http.ResponseWriter, r *http.Request) err = s.steward.Reupload(r.Context(), paths.Address, stamper) if err != nil { - err = errors.Join(err, save(false)) + err = errors.Join(err, save()) logger.Debug("re-upload failed", "chunk_address", paths.Address, "error", err) logger.Error(nil, "re-upload failed") jsonhttp.InternalServerError(w, "re-upload failed") return } - err = save(true) - if err != nil { + if err = save(); err != nil { logger.Debug("unable to save stamper data", "batchID", batchID, "error", err) logger.Error(nil, "unable to save stamper data") jsonhttp.InternalServerError(w, "unable to save stamper data") return } + jsonhttp.OK(w, nil) } diff --git a/pkg/postage/mock/service.go b/pkg/postage/mock/service.go index 0f190ba9d46..e137aaaac79 100644 --- a/pkg/postage/mock/service.go +++ b/pkg/postage/mock/service.go @@ -85,9 +85,9 @@ func (m *mockPostage) StampIssuers() []*postage.StampIssuer { return issuers } -func (m *mockPostage) GetStampIssuer(id []byte) (*postage.StampIssuer, func(bool) error, error) { +func (m *mockPostage) GetStampIssuer(id []byte) (*postage.StampIssuer, func() error, error) { if m.acceptAll { - return postage.NewStampIssuer("test fallback", "test identity", id, big.NewInt(3), 24, 6, 1000, true), func(_ bool) error { return nil }, nil + return postage.NewStampIssuer("test fallback", "test identity", id, big.NewInt(3), 24, 6, 1000, false), func() error { return nil }, nil } m.issuerLock.Lock() @@ -98,7 +98,7 @@ func (m *mockPostage) GetStampIssuer(id []byte) (*postage.StampIssuer, func(bool return nil, nil, postage.ErrNotFound } - return i, func(_ bool) error { + return i, func() error { return nil }, nil } diff --git a/pkg/postage/service.go b/pkg/postage/service.go index 7c6ec254c37..5e8aa017903 100644 --- a/pkg/postage/service.go +++ b/pkg/postage/service.go @@ -32,7 +32,7 @@ var ( type Service interface { Add(*StampIssuer) error StampIssuers() []*StampIssuer - GetStampIssuer([]byte) (*StampIssuer, func(bool) error, error) + GetStampIssuer([]byte) (*StampIssuer, func() error, error) IssuerUsable(*StampIssuer) bool BatchEventListener BatchExpiryHandler @@ -147,7 +147,7 @@ func (ps *service) IssuerUsable(st *StampIssuer) bool { } // GetStampIssuer finds a stamp issuer by batch ID. -func (ps *service) GetStampIssuer(batchID []byte) (*StampIssuer, func(bool) error, error) { +func (ps *service) GetStampIssuer(batchID []byte) (*StampIssuer, func() error, error) { ps.lock.Lock() defer ps.lock.Unlock() @@ -156,11 +156,8 @@ func (ps *service) GetStampIssuer(batchID []byte) (*StampIssuer, func(bool) erro if !ps.IssuerUsable(st) { return nil, nil, ErrNotUsable } - return st, func(update bool) error { - if update { - return ps.save(st) - } - return nil + return st, func() error { + return ps.save(st) }, nil } } diff --git a/pkg/postage/service_test.go b/pkg/postage/service_test.go index b3238993e17..8cf798e264c 100644 --- a/pkg/postage/service_test.go +++ b/pkg/postage/service_test.go @@ -7,14 +7,14 @@ package postage_test import ( crand "crypto/rand" "errors" + "io" + "math/big" + "testing" + "github.com/ethersphere/bee/pkg/postage" pstoremock "github.com/ethersphere/bee/pkg/postage/batchstore/mock" postagetesting "github.com/ethersphere/bee/pkg/postage/testing" "github.com/ethersphere/bee/pkg/storage/inmemstore" - "github.com/google/go-cmp/cmp" - "io" - "math/big" - "testing" ) // TestSaveLoad tests the idempotence of saving and loading the postage.Service @@ -116,7 +116,7 @@ func TestGetStampIssuer(t *testing.T) { if err != nil { t.Fatalf("expected no error, got %v", err) } - _ = save(true) + _ = save() if st.Label() != string(id) { t.Fatalf("wrong issuer returned") } @@ -163,7 +163,7 @@ func TestGetStampIssuer(t *testing.T) { if st.Label() != "recovered" { t.Fatal("wrong issuer returned") } - err = sv(true) + err = sv() if err != nil { t.Fatal(err) } @@ -177,7 +177,7 @@ func TestGetStampIssuer(t *testing.T) { if err != nil { t.Fatalf("expected no error, got %v", err) } - _ = save(true) + _ = save() if stampIssuer.Amount().Cmp(big.NewInt(13)) != 0 { t.Fatalf("expected amount %d got %d", 13, stampIssuer.Amount().Int64()) } @@ -191,7 +191,7 @@ func TestGetStampIssuer(t *testing.T) { if err != nil { t.Fatalf("expected no error, got %v", err) } - _ = save(true) + _ = save() if stampIssuer.Amount().Cmp(big.NewInt(3)) != 0 { t.Fatalf("expected amount %d got %d", 3, stampIssuer.Amount().Int64()) } @@ -199,32 +199,4 @@ func TestGetStampIssuer(t *testing.T) { t.Fatalf("expected depth %d got %d", 17, stampIssuer.Depth()) } }) - t.Run("save without update", func(t *testing.T) { - is, save, err := ps.GetStampIssuer(ids[1]) - if err != nil { - t.Fatal(err) - } - data := is.Buckets() - modified := make([]uint32, len(data)) - copy(modified, data) - for k, b := range modified { - b++ - modified[k] = b - } - - err = save(false) - if err != nil { - t.Fatal(err) - } - - is, _, err = ps.GetStampIssuer(ids[1]) - if err != nil { - t.Fatal(err) - } - - if !cmp.Equal(is.Buckets(), data) { - t.Fatal("expected buckets to be unchanged") - } - - }) } diff --git a/pkg/postage/stamper.go b/pkg/postage/stamper.go index 9bd0f93465d..d12797f106a 100644 --- a/pkg/postage/stamper.go +++ b/pkg/postage/stamper.go @@ -15,7 +15,8 @@ import ( var ( // ErrBucketFull is the error when a collision bucket is full. - ErrBucketFull = errors.New("bucket full") + ErrBucketFull = errors.New("bucket full") + ErrOverwriteImmutableIndex = errors.New("immutable batch old index overwrite due to previous faulty save") ) // Stamper can issue stamps from the given address of chunk. @@ -45,11 +46,13 @@ func (st *stamper) Stamp(addr swarm.Address) (*Stamp, error) { } switch err := st.store.Get(item); { case err == nil: - // check if this index is in the past, it could happen that we encountered + // The index should be in the past. It could happen that we encountered // some error after assigning this index and did not save the issuer data. In // this case we should assign a new index and update it. if st.issuer.assigned(item.BatchIndex) { break + } else if st.issuer.ImmutableFlag() { + return nil, ErrOverwriteImmutableIndex } fallthrough case errors.Is(err, storage.ErrNotFound): diff --git a/pkg/postage/stamper_test.go b/pkg/postage/stamper_test.go index 220787fa925..1fc3b3d5c34 100644 --- a/pkg/postage/stamper_test.go +++ b/pkg/postage/stamper_test.go @@ -107,7 +107,7 @@ func TestStamperStamping(t *testing.T) { }) t.Run("incorrect old index", func(t *testing.T) { - st := newTestStampIssuer(t, 1000) + st := newTestStampIssuerMutability(t, 1000, false) chunkAddr := swarm.RandAddress(t) bIdx := postage.ToBucket(st.BucketDepth(), chunkAddr) index := postage.IndexToBytes(bIdx, 100) @@ -129,6 +129,23 @@ func TestStamperStamping(t *testing.T) { } }) + t.Run("incorrect old index immutable", func(t *testing.T) { + st := newTestStampIssuerMutability(t, 1000, true) + chunkAddr := swarm.RandAddress(t) + bIdx := postage.ToBucket(st.BucketDepth(), chunkAddr) + index := postage.IndexToBytes(bIdx, 100) + testItem := postage.NewStampItem(). + WithBatchID(st.ID()). + WithChunkAddress(chunkAddr). + WithBatchIndex(index) + testSt := &testStore{Store: inmemstore.New(), stampItem: testItem} + stamper := postage.NewStamper(testSt, st, signer) + _, err := stamper.Stamp(chunkAddr) + if !errors.Is(err, postage.ErrOverwriteImmutableIndex) { + t.Fatalf("got err %v, wanted %v", err, postage.ErrOverwriteImmutableIndex) + } + }) + // tests return with ErrOwnerMismatch t.Run("owner mismatch", func(t *testing.T) { owner[0] ^= 0xff // bitflip the owner first byte, this case must come last! diff --git a/pkg/postage/stampissuer_test.go b/pkg/postage/stampissuer_test.go index 907d8b98842..f104a23bf5b 100644 --- a/pkg/postage/stampissuer_test.go +++ b/pkg/postage/stampissuer_test.go @@ -47,6 +47,11 @@ func TestStampIssuerMarshalling(t *testing.T) { } func newTestStampIssuer(t *testing.T, block uint64) *postage.StampIssuer { + t.Helper() + return newTestStampIssuerMutability(t, block, true) +} + +func newTestStampIssuerMutability(t *testing.T, block uint64, immutable bool) *postage.StampIssuer { t.Helper() id := make([]byte, 32) _, err := io.ReadFull(crand.Reader, id) @@ -61,7 +66,7 @@ func newTestStampIssuer(t *testing.T, block uint64) *postage.StampIssuer { 16, 8, block, - true, + immutable, ) }