Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: save stamps for all upload types regardless of any errors #4327

Merged
merged 1 commit into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ type putterOptions struct {
type putterSessionWrapper struct {
storer.PutterSession
stamper postage.Stamper
save func(bool) error
save func() error
}

func (p *putterSessionWrapper) Put(ctx context.Context, chunk swarm.Chunk) error {
Expand All @@ -788,14 +788,14 @@ func (p *putterSessionWrapper) Done(ref swarm.Address) error {
if err != nil {
return err
}
return p.save(true)
return p.save()
}

func (p *putterSessionWrapper) Cleanup() error {
return errors.Join(p.PutterSession.Cleanup(), p.save(false))
return errors.Join(p.PutterSession.Cleanup(), p.save())
}

func (s *Service) getStamper(batchID []byte) (postage.Stamper, func(bool) error, error) {
func (s *Service) getStamper(batchID []byte) (postage.Stamper, func() error, error) {
exists, err := s.batchStore.Exists(batchID)
if err != nil {
return nil, nil, fmt.Errorf("batch exists: %w", err)
Expand Down
9 changes: 1 addition & 8 deletions pkg/api/postage.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func (s *Service) postageGetStampBucketsHandler(w http.ResponseWriter, r *http.R
}
hexBatchID := hex.EncodeToString(paths.BatchID)

issuer, save, err := s.post.GetStampIssuer(paths.BatchID)
issuer, _, err := s.post.GetStampIssuer(paths.BatchID)
if err != nil {
logger.Debug("get stamp issuer: get issuer failed", "batch_id", hexBatchID, "error", err)
logger.Error(nil, "get stamp issuer: get issuer failed")
Expand All @@ -303,13 +303,6 @@ func (s *Service) postageGetStampBucketsHandler(w http.ResponseWriter, r *http.R
resp.Buckets[i] = bucketData{BucketID: uint32(i), Collisions: v}
}

if err := save(false); err != nil {
logger.Debug("get stamp issuer: save issuer failed", "batch_id", hexBatchID, "error", err)
logger.Error(nil, "get stamp issuer: save issuer failed")
jsonhttp.InternalServerError(w, "save issuer failed")
return
}

jsonhttp.OK(w, resp)
}

Expand Down
4 changes: 1 addition & 3 deletions pkg/api/pss.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ func (s *Service) pssPostHandler(w http.ResponseWriter, r *http.Request) {

err = s.pss.Send(r.Context(), topic, payload, stamper, queries.Recipient, targets)
if err != nil {
err = errors.Join(err, save(false))
logger.Debug("send payload failed", "topic", paths.Topic, "error", err)
logger.Error(nil, "send payload failed")
switch {
Expand All @@ -110,8 +109,7 @@ func (s *Service) pssPostHandler(w http.ResponseWriter, r *http.Request) {
return
}

err = save(true)
if err != nil {
if err = save(); err != nil {
logger.Debug("save stamp failed", "error", err)
logger.Error(nil, "save stamp failed")
jsonhttp.InternalServerError(w, "pss send failed")
Expand Down
6 changes: 3 additions & 3 deletions pkg/api/stewardship.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,20 @@ func (s *Service) stewardshipPutHandler(w http.ResponseWriter, r *http.Request)

err = s.steward.Reupload(r.Context(), paths.Address, stamper)
if err != nil {
err = errors.Join(err, save(false))
err = errors.Join(err, save())
logger.Debug("re-upload failed", "chunk_address", paths.Address, "error", err)
logger.Error(nil, "re-upload failed")
jsonhttp.InternalServerError(w, "re-upload failed")
return
}

err = save(true)
if err != nil {
if err = save(); err != nil {
logger.Debug("unable to save stamper data", "batchID", batchID, "error", err)
logger.Error(nil, "unable to save stamper data")
jsonhttp.InternalServerError(w, "unable to save stamper data")
return
}

jsonhttp.OK(w, nil)
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/postage/mock/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ func (m *mockPostage) StampIssuers() []*postage.StampIssuer {
return issuers
}

func (m *mockPostage) GetStampIssuer(id []byte) (*postage.StampIssuer, func(bool) error, error) {
func (m *mockPostage) GetStampIssuer(id []byte) (*postage.StampIssuer, func() error, error) {
if m.acceptAll {
return postage.NewStampIssuer("test fallback", "test identity", id, big.NewInt(3), 24, 6, 1000, true), func(_ bool) error { return nil }, nil
return postage.NewStampIssuer("test fallback", "test identity", id, big.NewInt(3), 24, 6, 1000, false), func() error { return nil }, nil
}

m.issuerLock.Lock()
Expand All @@ -98,7 +98,7 @@ func (m *mockPostage) GetStampIssuer(id []byte) (*postage.StampIssuer, func(bool
return nil, nil, postage.ErrNotFound
}

return i, func(_ bool) error {
return i, func() error {
return nil
}, nil
}
Expand Down
11 changes: 4 additions & 7 deletions pkg/postage/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var (
type Service interface {
Add(*StampIssuer) error
StampIssuers() []*StampIssuer
GetStampIssuer([]byte) (*StampIssuer, func(bool) error, error)
GetStampIssuer([]byte) (*StampIssuer, func() error, error)
IssuerUsable(*StampIssuer) bool
BatchEventListener
BatchExpiryHandler
Expand Down Expand Up @@ -147,7 +147,7 @@ func (ps *service) IssuerUsable(st *StampIssuer) bool {
}

// GetStampIssuer finds a stamp issuer by batch ID.
func (ps *service) GetStampIssuer(batchID []byte) (*StampIssuer, func(bool) error, error) {
func (ps *service) GetStampIssuer(batchID []byte) (*StampIssuer, func() error, error) {
ps.lock.Lock()
defer ps.lock.Unlock()

Expand All @@ -156,11 +156,8 @@ func (ps *service) GetStampIssuer(batchID []byte) (*StampIssuer, func(bool) erro
if !ps.IssuerUsable(st) {
return nil, nil, ErrNotUsable
}
return st, func(update bool) error {
if update {
return ps.save(st)
}
return nil
return st, func() error {
return ps.save(st)
}, nil
}
}
Expand Down
44 changes: 8 additions & 36 deletions pkg/postage/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ package postage_test
import (
crand "crypto/rand"
"errors"
"io"
"math/big"
"testing"

"github.com/ethersphere/bee/pkg/postage"
pstoremock "github.com/ethersphere/bee/pkg/postage/batchstore/mock"
postagetesting "github.com/ethersphere/bee/pkg/postage/testing"
"github.com/ethersphere/bee/pkg/storage/inmemstore"
"github.com/google/go-cmp/cmp"
"io"
"math/big"
"testing"
)

// TestSaveLoad tests the idempotence of saving and loading the postage.Service
Expand Down Expand Up @@ -116,7 +116,7 @@ func TestGetStampIssuer(t *testing.T) {
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
_ = save(true)
_ = save()
if st.Label() != string(id) {
t.Fatalf("wrong issuer returned")
}
Expand Down Expand Up @@ -163,7 +163,7 @@ func TestGetStampIssuer(t *testing.T) {
if st.Label() != "recovered" {
t.Fatal("wrong issuer returned")
}
err = sv(true)
err = sv()
if err != nil {
t.Fatal(err)
}
Expand All @@ -177,7 +177,7 @@ func TestGetStampIssuer(t *testing.T) {
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
_ = save(true)
_ = save()
if stampIssuer.Amount().Cmp(big.NewInt(13)) != 0 {
t.Fatalf("expected amount %d got %d", 13, stampIssuer.Amount().Int64())
}
Expand All @@ -191,40 +191,12 @@ func TestGetStampIssuer(t *testing.T) {
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
_ = save(true)
_ = save()
if stampIssuer.Amount().Cmp(big.NewInt(3)) != 0 {
t.Fatalf("expected amount %d got %d", 3, stampIssuer.Amount().Int64())
}
if stampIssuer.Depth() != 17 {
t.Fatalf("expected depth %d got %d", 17, stampIssuer.Depth())
}
})
t.Run("save without update", func(t *testing.T) {
is, save, err := ps.GetStampIssuer(ids[1])
if err != nil {
t.Fatal(err)
}
data := is.Buckets()
modified := make([]uint32, len(data))
copy(modified, data)
for k, b := range modified {
b++
modified[k] = b
}

err = save(false)
if err != nil {
t.Fatal(err)
}

is, _, err = ps.GetStampIssuer(ids[1])
if err != nil {
t.Fatal(err)
}

if !cmp.Equal(is.Buckets(), data) {
t.Fatal("expected buckets to be unchanged")
}

})
}
7 changes: 5 additions & 2 deletions pkg/postage/stamper.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import (

var (
// ErrBucketFull is the error when a collision bucket is full.
ErrBucketFull = errors.New("bucket full")
ErrBucketFull = errors.New("bucket full")
ErrOverwriteImmutableIndex = errors.New("immutable batch old index overwrite due to previous faulty save")
)

// Stamper can issue stamps from the given address of chunk.
Expand Down Expand Up @@ -45,11 +46,13 @@ func (st *stamper) Stamp(addr swarm.Address) (*Stamp, error) {
}
switch err := st.store.Get(item); {
case err == nil:
// check if this index is in the past, it could happen that we encountered
// The index should be in the past. It could happen that we encountered
// some error after assigning this index and did not save the issuer data. In
// this case we should assign a new index and update it.
if st.issuer.assigned(item.BatchIndex) {
break
} else if st.issuer.ImmutableFlag() {
return nil, ErrOverwriteImmutableIndex
}
fallthrough
case errors.Is(err, storage.ErrNotFound):
Expand Down
19 changes: 18 additions & 1 deletion pkg/postage/stamper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func TestStamperStamping(t *testing.T) {
})

t.Run("incorrect old index", func(t *testing.T) {
st := newTestStampIssuer(t, 1000)
st := newTestStampIssuerMutability(t, 1000, false)
chunkAddr := swarm.RandAddress(t)
bIdx := postage.ToBucket(st.BucketDepth(), chunkAddr)
index := postage.IndexToBytes(bIdx, 100)
Expand All @@ -129,6 +129,23 @@ func TestStamperStamping(t *testing.T) {
}
})

t.Run("incorrect old index immutable", func(t *testing.T) {
st := newTestStampIssuerMutability(t, 1000, true)
chunkAddr := swarm.RandAddress(t)
bIdx := postage.ToBucket(st.BucketDepth(), chunkAddr)
index := postage.IndexToBytes(bIdx, 100)
testItem := postage.NewStampItem().
WithBatchID(st.ID()).
WithChunkAddress(chunkAddr).
WithBatchIndex(index)
testSt := &testStore{Store: inmemstore.New(), stampItem: testItem}
stamper := postage.NewStamper(testSt, st, signer)
_, err := stamper.Stamp(chunkAddr)
if !errors.Is(err, postage.ErrOverwriteImmutableIndex) {
t.Fatalf("got err %v, wanted %v", err, postage.ErrOverwriteImmutableIndex)
}
})

// tests return with ErrOwnerMismatch
t.Run("owner mismatch", func(t *testing.T) {
owner[0] ^= 0xff // bitflip the owner first byte, this case must come last!
Expand Down
7 changes: 6 additions & 1 deletion pkg/postage/stampissuer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ func TestStampIssuerMarshalling(t *testing.T) {
}

func newTestStampIssuer(t *testing.T, block uint64) *postage.StampIssuer {
t.Helper()
return newTestStampIssuerMutability(t, block, true)
}

func newTestStampIssuerMutability(t *testing.T, block uint64, immutable bool) *postage.StampIssuer {
t.Helper()
id := make([]byte, 32)
_, err := io.ReadFull(crand.Reader, id)
Expand All @@ -61,7 +66,7 @@ func newTestStampIssuer(t *testing.T, block uint64) *postage.StampIssuer {
16,
8,
block,
true,
immutable,
)
}

Expand Down