Skip to content

Commit

Permalink
fix: unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
istae committed Sep 20, 2023
1 parent 84b6eaa commit 2b0e397
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 68 deletions.
8 changes: 4 additions & 4 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ type putterOptions struct {
type putterSessionWrapper struct {
storer.PutterSession
stamper postage.Stamper
save func(bool) error
save func() error
}

func (p *putterSessionWrapper) Put(ctx context.Context, chunk swarm.Chunk) error {
Expand All @@ -788,14 +788,14 @@ func (p *putterSessionWrapper) Done(ref swarm.Address) error {
if err != nil {
return err
}
return p.save(true)
return p.save()
}

func (p *putterSessionWrapper) Cleanup() error {
return errors.Join(p.PutterSession.Cleanup(), p.save(false))
return errors.Join(p.PutterSession.Cleanup(), p.save())
}

func (s *Service) getStamper(batchID []byte) (postage.Stamper, func(bool) error, error) {
func (s *Service) getStamper(batchID []byte) (postage.Stamper, func() error, error) {
exists, err := s.batchStore.Exists(batchID)
if err != nil {
return nil, nil, fmt.Errorf("batch exists: %w", err)
Expand Down
9 changes: 1 addition & 8 deletions pkg/api/postage.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func (s *Service) postageGetStampBucketsHandler(w http.ResponseWriter, r *http.R
}
hexBatchID := hex.EncodeToString(paths.BatchID)

issuer, save, err := s.post.GetStampIssuer(paths.BatchID)
issuer, _, err := s.post.GetStampIssuer(paths.BatchID)
if err != nil {
logger.Debug("get stamp issuer: get issuer failed", "batch_id", hexBatchID, "error", err)
logger.Error(nil, "get stamp issuer: get issuer failed")
Expand All @@ -303,13 +303,6 @@ func (s *Service) postageGetStampBucketsHandler(w http.ResponseWriter, r *http.R
resp.Buckets[i] = bucketData{BucketID: uint32(i), Collisions: v}
}

if err := save(false); err != nil {
logger.Debug("get stamp issuer: save issuer failed", "batch_id", hexBatchID, "error", err)
logger.Error(nil, "get stamp issuer: save issuer failed")
jsonhttp.InternalServerError(w, "save issuer failed")
return
}

jsonhttp.OK(w, resp)
}

Expand Down
4 changes: 1 addition & 3 deletions pkg/api/pss.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ func (s *Service) pssPostHandler(w http.ResponseWriter, r *http.Request) {

err = s.pss.Send(r.Context(), topic, payload, stamper, queries.Recipient, targets)
if err != nil {
err = errors.Join(err, save(false))
logger.Debug("send payload failed", "topic", paths.Topic, "error", err)
logger.Error(nil, "send payload failed")
switch {
Expand All @@ -110,8 +109,7 @@ func (s *Service) pssPostHandler(w http.ResponseWriter, r *http.Request) {
return
}

err = save(true)
if err != nil {
if err = save(); err != nil {
logger.Debug("save stamp failed", "error", err)
logger.Error(nil, "save stamp failed")
jsonhttp.InternalServerError(w, "pss send failed")
Expand Down
6 changes: 3 additions & 3 deletions pkg/api/stewardship.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,20 @@ func (s *Service) stewardshipPutHandler(w http.ResponseWriter, r *http.Request)

err = s.steward.Reupload(r.Context(), paths.Address, stamper)
if err != nil {
err = errors.Join(err, save(false))
err = errors.Join(err, save())
logger.Debug("re-upload failed", "chunk_address", paths.Address, "error", err)
logger.Error(nil, "re-upload failed")
jsonhttp.InternalServerError(w, "re-upload failed")
return
}

err = save(true)
if err != nil {
if err = save(); err != nil {
logger.Debug("unable to save stamper data", "batchID", batchID, "error", err)
logger.Error(nil, "unable to save stamper data")
jsonhttp.InternalServerError(w, "unable to save stamper data")
return
}

jsonhttp.OK(w, nil)
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/postage/mock/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ func (m *mockPostage) StampIssuers() []*postage.StampIssuer {
return issuers
}

func (m *mockPostage) GetStampIssuer(id []byte) (*postage.StampIssuer, func(bool) error, error) {
func (m *mockPostage) GetStampIssuer(id []byte) (*postage.StampIssuer, func() error, error) {
if m.acceptAll {
return postage.NewStampIssuer("test fallback", "test identity", id, big.NewInt(3), 24, 6, 1000, true), func(_ bool) error { return nil }, nil
return postage.NewStampIssuer("test fallback", "test identity", id, big.NewInt(3), 24, 6, 1000, false), func() error { return nil }, nil
}

m.issuerLock.Lock()
Expand All @@ -98,7 +98,7 @@ func (m *mockPostage) GetStampIssuer(id []byte) (*postage.StampIssuer, func(bool
return nil, nil, postage.ErrNotFound
}

return i, func(_ bool) error {
return i, func() error {
return nil
}, nil
}
Expand Down
11 changes: 4 additions & 7 deletions pkg/postage/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var (
type Service interface {
Add(*StampIssuer) error
StampIssuers() []*StampIssuer
GetStampIssuer([]byte) (*StampIssuer, func(bool) error, error)
GetStampIssuer([]byte) (*StampIssuer, func() error, error)
IssuerUsable(*StampIssuer) bool
BatchEventListener
BatchExpiryHandler
Expand Down Expand Up @@ -147,7 +147,7 @@ func (ps *service) IssuerUsable(st *StampIssuer) bool {
}

// GetStampIssuer finds a stamp issuer by batch ID.
func (ps *service) GetStampIssuer(batchID []byte) (*StampIssuer, func(bool) error, error) {
func (ps *service) GetStampIssuer(batchID []byte) (*StampIssuer, func() error, error) {
ps.lock.Lock()
defer ps.lock.Unlock()

Expand All @@ -156,11 +156,8 @@ func (ps *service) GetStampIssuer(batchID []byte) (*StampIssuer, func(bool) erro
if !ps.IssuerUsable(st) {
return nil, nil, ErrNotUsable
}
return st, func(update bool) error {
if update {
return ps.save(st)
}
return nil
return st, func() error {
return ps.save(st)
}, nil
}
}
Expand Down
44 changes: 8 additions & 36 deletions pkg/postage/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ package postage_test
import (
crand "crypto/rand"
"errors"
"io"
"math/big"
"testing"

"github.com/ethersphere/bee/pkg/postage"
pstoremock "github.com/ethersphere/bee/pkg/postage/batchstore/mock"
postagetesting "github.com/ethersphere/bee/pkg/postage/testing"
"github.com/ethersphere/bee/pkg/storage/inmemstore"
"github.com/google/go-cmp/cmp"
"io"
"math/big"
"testing"
)

// TestSaveLoad tests the idempotence of saving and loading the postage.Service
Expand Down Expand Up @@ -116,7 +116,7 @@ func TestGetStampIssuer(t *testing.T) {
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
_ = save(true)
_ = save()
if st.Label() != string(id) {
t.Fatalf("wrong issuer returned")
}
Expand Down Expand Up @@ -163,7 +163,7 @@ func TestGetStampIssuer(t *testing.T) {
if st.Label() != "recovered" {
t.Fatal("wrong issuer returned")
}
err = sv(true)
err = sv()
if err != nil {
t.Fatal(err)
}
Expand All @@ -177,7 +177,7 @@ func TestGetStampIssuer(t *testing.T) {
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
_ = save(true)
_ = save()
if stampIssuer.Amount().Cmp(big.NewInt(13)) != 0 {
t.Fatalf("expected amount %d got %d", 13, stampIssuer.Amount().Int64())
}
Expand All @@ -191,40 +191,12 @@ func TestGetStampIssuer(t *testing.T) {
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
_ = save(true)
_ = save()
if stampIssuer.Amount().Cmp(big.NewInt(3)) != 0 {
t.Fatalf("expected amount %d got %d", 3, stampIssuer.Amount().Int64())
}
if stampIssuer.Depth() != 17 {
t.Fatalf("expected depth %d got %d", 17, stampIssuer.Depth())
}
})
t.Run("save without update", func(t *testing.T) {
is, save, err := ps.GetStampIssuer(ids[1])
if err != nil {
t.Fatal(err)
}
data := is.Buckets()
modified := make([]uint32, len(data))
copy(modified, data)
for k, b := range modified {
b++
modified[k] = b
}

err = save(false)
if err != nil {
t.Fatal(err)
}

is, _, err = ps.GetStampIssuer(ids[1])
if err != nil {
t.Fatal(err)
}

if !cmp.Equal(is.Buckets(), data) {
t.Fatal("expected buckets to be unchanged")
}

})
}
7 changes: 5 additions & 2 deletions pkg/postage/stamper.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import (

var (
// ErrBucketFull is the error when a collision bucket is full.
ErrBucketFull = errors.New("bucket full")
ErrBucketFull = errors.New("bucket full")
ErrOverwriteImmutableIndex = errors.New("immutable batch old index overwrite due to previous faulty save")
)

// Stamper can issue stamps from the given address of chunk.
Expand Down Expand Up @@ -45,11 +46,13 @@ func (st *stamper) Stamp(addr swarm.Address) (*Stamp, error) {
}
switch err := st.store.Get(item); {
case err == nil:
// check if this index is in the past, it could happen that we encountered
// The index should be in the past. It could happen that we encountered
// some error after assigning this index and did not save the issuer data. In
// this case we should assign a new index and update it.
if st.issuer.assigned(item.BatchIndex) {
break
} else if st.issuer.ImmutableFlag() {
return nil, ErrOverwriteImmutableIndex
}
fallthrough
case errors.Is(err, storage.ErrNotFound):
Expand Down
19 changes: 18 additions & 1 deletion pkg/postage/stamper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func TestStamperStamping(t *testing.T) {
})

t.Run("incorrect old index", func(t *testing.T) {
st := newTestStampIssuer(t, 1000)
st := newTestStampIssuerMutability(t, 1000, false)
chunkAddr := swarm.RandAddress(t)
bIdx := postage.ToBucket(st.BucketDepth(), chunkAddr)
index := postage.IndexToBytes(bIdx, 100)
Expand All @@ -129,6 +129,23 @@ func TestStamperStamping(t *testing.T) {
}
})

t.Run("incorrect old index immutable", func(t *testing.T) {
st := newTestStampIssuerMutability(t, 1000, true)
chunkAddr := swarm.RandAddress(t)
bIdx := postage.ToBucket(st.BucketDepth(), chunkAddr)
index := postage.IndexToBytes(bIdx, 100)
testItem := postage.NewStampItem().
WithBatchID(st.ID()).
WithChunkAddress(chunkAddr).
WithBatchIndex(index)
testSt := &testStore{Store: inmemstore.New(), stampItem: testItem}
stamper := postage.NewStamper(testSt, st, signer)
_, err := stamper.Stamp(chunkAddr)
if !errors.Is(err, postage.ErrOverwriteImmutableIndex) {
t.Fatalf("got err %v, wanted %v", err, postage.ErrOverwriteImmutableIndex)
}
})

// tests return with ErrOwnerMismatch
t.Run("owner mismatch", func(t *testing.T) {
owner[0] ^= 0xff // bitflip the owner first byte, this case must come last!
Expand Down
7 changes: 6 additions & 1 deletion pkg/postage/stampissuer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ func TestStampIssuerMarshalling(t *testing.T) {
}

func newTestStampIssuer(t *testing.T, block uint64) *postage.StampIssuer {
t.Helper()
return newTestStampIssuerMutability(t, block, true)
}

func newTestStampIssuerMutability(t *testing.T, block uint64, immutable bool) *postage.StampIssuer {
t.Helper()
id := make([]byte, 32)
_, err := io.ReadFull(crand.Reader, id)
Expand All @@ -61,7 +66,7 @@ func newTestStampIssuer(t *testing.T, block uint64) *postage.StampIssuer {
16,
8,
block,
true,
immutable,
)
}

Expand Down

0 comments on commit 2b0e397

Please sign in to comment.