diff --git a/pkg/api/bzz.go b/pkg/api/bzz.go index cf99cf03fac..5cb35eaf5a6 100644 --- a/pkg/api/bzz.go +++ b/pkg/api/bzz.go @@ -9,7 +9,6 @@ import ( "encoding/hex" "errors" "fmt" - "github.com/ethersphere/bee/pkg/topology" "net/http" "path" "path/filepath" @@ -17,6 +16,8 @@ import ( "strings" "time" + "github.com/ethersphere/bee/pkg/topology" + "github.com/ethereum/go-ethereum/common" "github.com/ethersphere/bee/pkg/feeds" "github.com/ethersphere/bee/pkg/file/joiner" @@ -175,7 +176,7 @@ func (s *Service) fileUploadHandler( } factory := requestPipelineFactory(ctx, putter, encrypt, rLevel) - l := loadsave.New(s.storer.ChunkStore(), factory) + l := loadsave.New(s.storer.ChunkStore(), s.storer.Cache(), factory) m, err := manifest.NewDefaultManifest(l, encrypt) if err != nil { @@ -443,7 +444,7 @@ func (s *Service) serveManifestEntry( // downloadHandler contains common logic for dowloading Swarm file from API func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *http.Request, reference swarm.Address, additionalHeaders http.Header, etag bool) { - reader, l, err := joiner.New(r.Context(), s.storer.Download(true), reference) + reader, l, err := joiner.New(r.Context(), s.storer.Download(true), s.storer.Cache(), reference) if err != nil { if errors.Is(err, storage.ErrNotFound) || errors.Is(err, topology.ErrNotFound) { logger.Debug("api download: not found ", "address", reference, "error", err) diff --git a/pkg/api/bzz_test.go b/pkg/api/bzz_test.go index 731cfcb567f..e6d52d8acd3 100644 --- a/pkg/api/bzz_test.go +++ b/pkg/api/bzz_test.go @@ -554,7 +554,7 @@ func TestFeedIndirection(t *testing.T) { t.Fatal(err) } m, err := manifest.NewDefaultManifest( - loadsave.New(storer.ChunkStore(), pipelineFactory(storer.Cache(), false, 0)), + loadsave.New(storer.ChunkStore(), storer.Cache(), pipelineFactory(storer.Cache(), false, 0)), false, ) if err != nil { diff --git a/pkg/api/dirs.go b/pkg/api/dirs.go index 771787a24a9..00b530b9137 100644 --- a/pkg/api/dirs.go +++ b/pkg/api/dirs.go @@ -132,7 +132,7 @@ func storeDir( loggerV1 := logger.V(1).Build() p := requestPipelineFn(putter, encrypt, rLevel) - ls := loadsave.New(getter, requestPipelineFactory(ctx, putter, encrypt, rLevel)) + ls := loadsave.New(getter, putter, requestPipelineFactory(ctx, putter, encrypt, rLevel)) dirManifest, err := manifest.NewDefaultManifest(ls, encrypt) if err != nil { diff --git a/pkg/api/feed.go b/pkg/api/feed.go index f3a3413a426..23961e6d5fc 100644 --- a/pkg/api/feed.go +++ b/pkg/api/feed.go @@ -196,7 +196,7 @@ func (s *Service) feedPostHandler(w http.ResponseWriter, r *http.Request) { logger: logger, } - l := loadsave.New(s.storer.ChunkStore(), requestPipelineFactory(r.Context(), putter, false, 0)) + l := loadsave.New(s.storer.ChunkStore(), s.storer.Cache(), requestPipelineFactory(r.Context(), putter, false, 0)) feedManifest, err := manifest.NewDefaultManifest(l, false) if err != nil { logger.Debug("create manifest failed", "error", err) diff --git a/pkg/api/pin.go b/pkg/api/pin.go index a98af277583..7c80e5f196b 100644 --- a/pkg/api/pin.go +++ b/pkg/api/pin.go @@ -50,7 +50,7 @@ func (s *Service) pinRootHash(w http.ResponseWriter, r *http.Request) { } getter := s.storer.Download(true) - traverser := traversal.New(getter) + traverser := traversal.New(getter, s.storer.Cache()) sem := semaphore.NewWeighted(100) var errTraverse error diff --git a/pkg/file/addresses/addresses_getter_test.go b/pkg/file/addresses/addresses_getter_test.go index 542c2d8a442..a5fe9d00000 100644 --- a/pkg/file/addresses/addresses_getter_test.go +++ b/pkg/file/addresses/addresses_getter_test.go @@ -63,7 +63,7 @@ func TestAddressesGetterIterateChunkAddresses(t *testing.T) { addressesGetter := addresses.NewGetter(store, addressIterFunc) - j, _, err := joiner.New(ctx, addressesGetter, rootChunk.Address()) + j, _, err := joiner.New(ctx, addressesGetter, store, rootChunk.Address()) if err != nil { t.Fatal(err) } diff --git a/pkg/file/file_test.go b/pkg/file/file_test.go index fd85a23fb9d..cc677336b64 100644 --- a/pkg/file/file_test.go +++ b/pkg/file/file_test.go @@ -62,7 +62,7 @@ func testSplitThenJoin(t *testing.T) { } // then join - r, l, err := joiner.New(ctx, store, resultAddress) + r, l, err := joiner.New(ctx, store, store, resultAddress) if err != nil { t.Fatal(err) } diff --git a/pkg/file/joiner/joiner.go b/pkg/file/joiner/joiner.go index b046c6bd0fd..5758584f006 100644 --- a/pkg/file/joiner/joiner.go +++ b/pkg/file/joiner/joiner.go @@ -7,33 +7,40 @@ package joiner import ( "context" - "encoding/binary" "errors" "io" "sync" "sync/atomic" + "github.com/ethersphere/bee/pkg/bmt" "github.com/ethersphere/bee/pkg/encryption" "github.com/ethersphere/bee/pkg/encryption/store" "github.com/ethersphere/bee/pkg/file" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/file/redundancy/getter" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" "golang.org/x/sync/errgroup" ) type joiner struct { - addr swarm.Address - rootData []byte - span int64 - off int64 - refLength int + addr swarm.Address + rootData []byte + span int64 + off int64 + refLength int + rootParity int + maxBranching int // maximum branching in an intermediate chunk ctx context.Context getter storage.Getter + putter storage.Putter // required to save recovered data + + chunkToSpan func(data []byte) (redundancy.Level, int64) // returns parity and span value from chunkData } // New creates a new Joiner. A Joiner provides Read, Seek and Size functionalities. -func New(ctx context.Context, getter storage.Getter, address swarm.Address) (file.Joiner, int64, error) { +func New(ctx context.Context, getter storage.Getter, putter storage.Putter, address swarm.Address) (file.Joiner, int64, error) { getter = store.New(getter) // retrieve the root chunk to read the total data length the be retrieved rootChunk, err := getter.Get(ctx, address) @@ -41,17 +48,42 @@ func New(ctx context.Context, getter storage.Getter, address swarm.Address) (fil return nil, 0, err } - var chunkData = rootChunk.Data() - - span := int64(binary.LittleEndian.Uint64(chunkData[:swarm.SpanSize])) + chunkData := rootChunk.Data() + rootData := chunkData[swarm.SpanSize:] + refLength := len(address.Bytes()) + encryption := false + if refLength != swarm.HashSize { + encryption = true + } + rLevel, span := chunkToSpan(chunkData) + rootParity := 0 + maxBranching := swarm.ChunkSize / refLength + spanFn := func(data []byte) (redundancy.Level, int64) { + return 0, int64(bmt.LengthFromSpan(data[:swarm.SpanSize])) + } + // override stuff if root chunk has redundancy + if rLevel != redundancy.NONE { + _, parities := file.ReferenceCount(uint64(span), rLevel, encryption) + rootParity = parities + spanFn = chunkToSpan + if encryption { + maxBranching = rLevel.GetMaxEncShards() + } else { + maxBranching = rLevel.GetMaxShards() + } + } j := &joiner{ - addr: rootChunk.Address(), - refLength: len(address.Bytes()), - ctx: ctx, - getter: getter, - span: span, - rootData: chunkData[swarm.SpanSize:], + addr: rootChunk.Address(), + refLength: refLength, + ctx: ctx, + getter: getter, + putter: putter, + span: span, + rootData: rootData, + rootParity: rootParity, + maxBranching: maxBranching, + chunkToSpan: spanFn, } return j, span, nil @@ -81,7 +113,7 @@ func (j *joiner) ReadAt(buffer []byte, off int64) (read int, err error) { } var bytesRead int64 var eg errgroup.Group - j.readAtOffset(buffer, j.rootData, 0, j.span, off, 0, readLen, &bytesRead, &eg) + j.readAtOffset(buffer, j.rootData, 0, j.span, off, 0, readLen, &bytesRead, j.rootParity, &eg) err = eg.Wait() if err != nil { @@ -93,7 +125,13 @@ func (j *joiner) ReadAt(buffer []byte, off int64) (read int, err error) { var ErrMalformedTrie = errors.New("malformed tree") -func (j *joiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffset, bytesToRead int64, bytesRead *int64, eg *errgroup.Group) { +func (j *joiner) readAtOffset( + b, data []byte, + cur, subTrieSize, off, bufferOffset, bytesToRead int64, + bytesRead *int64, + parity int, + eg *errgroup.Group, +) { // we are at a leaf data chunk if subTrieSize <= int64(len(data)) { dataOffsetStart := off - cur @@ -109,14 +147,23 @@ func (j *joiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffse return } + pSize, err := file.ChunkPayloadSize(data) + if err != nil { + eg.Go(func() error { + return err + }) + return + } + sAddresses, pAddresses := file.ChunkAddresses(data[:pSize], parity, j.refLength) + getter := getter.New(sAddresses, pAddresses, j.getter, j.putter) for cursor := 0; cursor < len(data); cursor += j.refLength { if bytesToRead == 0 { break } // fast forward the cursor - sec := subtrieSection(data, cursor, j.refLength, subTrieSize) - if cur+sec < off { + sec := j.subtrieSection(data, cursor, pSize, parity, subTrieSize) + if cur+sec <= off { cur += sec continue } @@ -139,19 +186,20 @@ func (j *joiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffse func(address swarm.Address, b []byte, cur, subTrieSize, off, bufferOffset, bytesToRead, subtrieSpanLimit int64) { eg.Go(func() error { - ch, err := j.getter.Get(j.ctx, address) + ch, err := getter.Get(j.ctx, address) if err != nil { return err } chunkData := ch.Data()[8:] - subtrieSpan := int64(chunkToSpan(ch.Data())) + subtrieLevel, subtrieSpan := j.chunkToSpan(ch.Data()) + _, subtrieParity := file.ReferenceCount(uint64(subtrieSpan), subtrieLevel, j.refLength != swarm.HashSize) if subtrieSpan > subtrieSpanLimit { return ErrMalformedTrie } - j.readAtOffset(b, chunkData, cur, subtrieSpan, off, bufferOffset, currentReadSize, bytesRead, eg) + j.readAtOffset(b, chunkData, cur, subtrieSpan, off, bufferOffset, currentReadSize, bytesRead, subtrieParity, eg) return nil }) }(address, b, cur, subtrieSpan, off, bufferOffset, currentReadSize, subtrieSpanLimit) @@ -163,8 +211,13 @@ func (j *joiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffse } } +// getShards returns the effective reference number respective to the intermediate chunk payload length and its parities +func (j *joiner) getShards(payloadSize, parities int) int { + return (payloadSize - parities*swarm.HashSize) / j.refLength +} + // brute-forces the subtrie size for each of the sections in this intermediate chunk -func subtrieSection(data []byte, startIdx, refLen int, subtrieSize int64) int64 { +func (j *joiner) subtrieSection(data []byte, startIdx, payloadSize, parities int, subtrieSize int64) int64 { // assume we have a trie of size `y` then we can assume that all of // the forks except for the last one on the right are of equal size // this is due to how the splitter wraps levels. @@ -173,9 +226,9 @@ func subtrieSection(data []byte, startIdx, refLen int, subtrieSize int64) int64 // where y is the size of the subtrie, refs are the number of references // x is constant (the brute forced value) and l is the size of the last subtrie var ( - refs = int64(len(data) / refLen) // how many references in the intermediate chunk - branching = int64(4096 / refLen) // branching factor is chunkSize divided by reference length - branchSize = int64(4096) + refs = int64(j.getShards(payloadSize, parities)) // how many effective references in the intermediate chunk + branching = int64(j.maxBranching) // branching factor is chunkSize divided by reference length + branchSize = int64(swarm.ChunkSize) ) for { whatsLeft := subtrieSize - (branchSize * (refs - 1)) @@ -186,7 +239,7 @@ func subtrieSection(data []byte, startIdx, refLen int, subtrieSize int64) int64 } // handle last branch edge case - if startIdx == int(refs-1)*refLen { + if startIdx == int(refs-1)*j.refLength { return subtrieSize - (refs-1)*branchSize } return branchSize @@ -229,10 +282,10 @@ func (j *joiner) IterateChunkAddresses(fn swarm.AddressIterFunc) error { return err } - return j.processChunkAddresses(j.ctx, fn, j.rootData, j.span) + return j.processChunkAddresses(j.ctx, fn, j.rootData, j.span, j.rootParity) } -func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIterFunc, data []byte, subTrieSize int64) error { +func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIterFunc, data []byte, subTrieSize int64, parity int) error { // we are at a leaf data chunk if subTrieSize <= int64(len(data)) { return nil @@ -248,6 +301,12 @@ func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIter var wg sync.WaitGroup + eSize, err := file.ChunkPayloadSize(data) + if err != nil { + return err + } + sAddresses, pAddresses := file.ChunkAddresses(data[:eSize], parity, j.refLength) + getter := getter.New(sAddresses, pAddresses, j.getter, j.putter) for cursor := 0; cursor < len(data); cursor += j.refLength { ref := data[cursor : cursor+j.refLength] var reportAddr swarm.Address @@ -262,7 +321,7 @@ func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIter return err } - sec := subtrieSection(data, cursor, j.refLength, subTrieSize) + sec := j.subtrieSection(data, cursor, eSize, parity, subTrieSize) if sec <= swarm.ChunkSize { continue } @@ -273,15 +332,16 @@ func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIter eg.Go(func() error { defer wg.Done() - ch, err := j.getter.Get(ectx, address) + ch, err := getter.Get(ectx, address) if err != nil { return err } chunkData := ch.Data()[8:] - subtrieSpan := int64(chunkToSpan(ch.Data())) + subtrieLevel, subtrieSpan := j.chunkToSpan(ch.Data()) + _, parities := file.ReferenceCount(uint64(subtrieSpan), subtrieLevel, j.refLength != swarm.HashSize) - return j.processChunkAddresses(ectx, fn, chunkData, subtrieSpan) + return j.processChunkAddresses(ectx, fn, chunkData, subtrieSpan, parities) }) }(address, eg) @@ -295,6 +355,9 @@ func (j *joiner) Size() int64 { return j.span } -func chunkToSpan(data []byte) uint64 { - return binary.LittleEndian.Uint64(data[:8]) +// UTILITIES + +func chunkToSpan(data []byte) (redundancy.Level, int64) { + level, spanBytes := redundancy.DecodeSpan(data[:swarm.SpanSize]) + return level, int64(bmt.LengthFromSpan(spanBytes)) } diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index 772c0e12ae2..67ac1c3f9c1 100644 --- a/pkg/file/joiner/joiner_test.go +++ b/pkg/file/joiner/joiner_test.go @@ -7,6 +7,7 @@ package joiner_test import ( "bytes" "context" + "crypto/rand" "encoding/binary" "errors" "fmt" @@ -20,6 +21,7 @@ import ( "github.com/ethersphere/bee/pkg/encryption/store" "github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/pipeline/builder" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/file/splitter" filetest "github.com/ethersphere/bee/pkg/file/testing" storage "github.com/ethersphere/bee/pkg/storage" @@ -34,7 +36,7 @@ func TestJoiner_ErrReferenceLength(t *testing.T) { t.Parallel() store := inmemchunkstore.New() - _, _, err := joiner.New(context.Background(), store, swarm.ZeroAddress) + _, _, err := joiner.New(context.Background(), store, store, swarm.ZeroAddress) if !errors.Is(err, storage.ErrReferenceLength) { t.Fatalf("expected ErrReferenceLength %x but got %v", swarm.ZeroAddress, err) @@ -64,7 +66,7 @@ func TestJoinerSingleChunk(t *testing.T) { } // read back data and compare - joinReader, l, err := joiner.New(ctx, store, mockAddr) + joinReader, l, err := joiner.New(ctx, store, store, mockAddr) if err != nil { t.Fatal(err) } @@ -104,7 +106,7 @@ func TestJoinerDecryptingStore_NormalChunk(t *testing.T) { } // read back data and compare - joinReader, l, err := joiner.New(ctx, store, mockAddr) + joinReader, l, err := joiner.New(ctx, store, st, mockAddr) if err != nil { t.Fatal(err) } @@ -152,7 +154,7 @@ func TestJoinerWithReference(t *testing.T) { } // read back data and compare - joinReader, l, err := joiner.New(ctx, store, rootChunk.Address()) + joinReader, l, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -208,7 +210,7 @@ func TestJoinerMalformed(t *testing.T) { t.Fatal(err) } - joinReader, _, err := joiner.New(ctx, store, rootChunk.Address()) + joinReader, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -254,7 +256,7 @@ func TestEncryptDecrypt(t *testing.T) { if err != nil { t.Fatal(err) } - reader, l, err := joiner.New(context.Background(), store, resultAddress) + reader, l, err := joiner.New(context.Background(), store, store, resultAddress) if err != nil { t.Fatal(err) } @@ -341,7 +343,7 @@ func TestSeek(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, addr) + j, _, err := joiner.New(ctx, store, store, addr) if err != nil { t.Fatal(err) } @@ -618,7 +620,7 @@ func TestPrefetch(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, addr) + j, _, err := joiner.New(ctx, store, store, addr) if err != nil { t.Fatal(err) } @@ -667,7 +669,7 @@ func TestJoinerReadAt(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, rootChunk.Address()) + j, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -714,7 +716,7 @@ func TestJoinerOneLevel(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, rootChunk.Address()) + j, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -808,7 +810,7 @@ func TestJoinerTwoLevelsAcrossChunk(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, rootChunk.Address()) + j, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -864,7 +866,7 @@ func TestJoinerIterateChunkAddresses(t *testing.T) { createdAddresses := []swarm.Address{rootChunk.Address(), firstAddress, secondAddress} - j, _, err := joiner.New(ctx, store, rootChunk.Address()) + j, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -917,7 +919,7 @@ func TestJoinerIterateChunkAddresses_Encrypted(t *testing.T) { if err != nil { t.Fatal(err) } - j, l, err := joiner.New(context.Background(), store, resultAddress) + j, l, err := joiner.New(context.Background(), store, store, resultAddress) if err != nil { t.Fatal(err) } @@ -951,3 +953,121 @@ func TestJoinerIterateChunkAddresses_Encrypted(t *testing.T) { } } } + +func TestJoinerRedundancy(t *testing.T) { + t.Parallel() + for _, tc := range []struct { + rLevel redundancy.Level + encryptChunk bool + }{ + { + redundancy.MEDIUM, + true, + }, + { + redundancy.STRONG, + false, + }, + { + redundancy.INSANE, + true, + }, + { + redundancy.PARANOID, + false, + }, + } { + tc := tc + t.Run(fmt.Sprintf("redundancy %d encryption %t", tc.rLevel, tc.encryptChunk), func(t *testing.T) { + ctx := context.Background() + store := inmemchunkstore.New() + pipe := builder.NewPipelineBuilder(ctx, store, tc.encryptChunk, tc.rLevel) + + // generate and store chunks + dataChunkCount := tc.rLevel.GetMaxShards() + 1 // generate a carrier chunk + if tc.encryptChunk { + dataChunkCount = tc.rLevel.GetMaxEncShards() + 1 + } + dataChunks := make([]swarm.Chunk, dataChunkCount) + chunkSize := swarm.ChunkSize + for i := 0; i < dataChunkCount; i++ { + chunkData := make([]byte, chunkSize) + _, err := io.ReadFull(rand.Reader, chunkData) + if err != nil { + t.Fatal(err) + } + dataChunks[i], err = cac.New(chunkData) + if err != nil { + t.Fatal(err) + } + err = store.Put(ctx, dataChunks[i]) + if err != nil { + t.Fatal(err) + } + _, err = pipe.Write(chunkData) + if err != nil { + t.Fatal(err) + } + } + + // reader init + sum, err := pipe.Sum() + if err != nil { + t.Fatal(err) + } + swarmAddr := swarm.NewAddress(sum) + joinReader, rootSpan, err := joiner.New(ctx, store, store, swarmAddr) + if err != nil { + t.Fatal(err) + } + // sanity checks + expectedRootSpan := chunkSize * dataChunkCount + if int64(expectedRootSpan) != rootSpan { + t.Fatalf("Expected root span %d. Got: %d", expectedRootSpan, rootSpan) + } + // all data can be read back + readCheck := func() { + offset := int64(0) + for i := 0; i < dataChunkCount; i++ { + chunkData := make([]byte, chunkSize) + _, err = joinReader.ReadAt(chunkData, offset) + if err != nil { + t.Fatalf("read error check at chunkdata comparisation on %d index: %s", i, err.Error()) + } + expectedChunkData := dataChunks[i].Data()[swarm.SpanSize:] + if !bytes.Equal(expectedChunkData, chunkData) { + t.Fatalf("read error check at chunkdata comparisation on %d index. Data are not the same", i) + } + offset += int64(chunkSize) + } + } + readCheck() + + // remove data chunks in order to trigger recovery + maxShards := tc.rLevel.GetMaxShards() + maxParities := tc.rLevel.GetParities(maxShards) + if tc.encryptChunk { + maxShards = tc.rLevel.GetMaxEncShards() + maxParities = tc.rLevel.GetEncParities(maxShards) + } + removeCount := maxParities + if maxParities > maxShards { + removeCount = maxShards + } + for i := 0; i < removeCount; i++ { + err := store.Delete(ctx, dataChunks[i].Address()) + if err != nil { + t.Fatal(err) + } + } + // remove parity chunk + err = store.Delete(ctx, dataChunks[len(dataChunks)-1].Address()) + if err != nil { + t.Fatal(err) + } + + // check whether the data still be readable + readCheck() + }) + } +} diff --git a/pkg/file/loadsave/loadsave.go b/pkg/file/loadsave/loadsave.go index 1f5267a4e8e..1d65e45ab6b 100644 --- a/pkg/file/loadsave/loadsave.go +++ b/pkg/file/loadsave/loadsave.go @@ -27,13 +27,15 @@ var errReadonlyLoadSave = errors.New("readonly manifest loadsaver") // load all of the subtrie of a given hash in memory. type loadSave struct { getter storage.Getter + putter storage.Putter pipelineFn func() pipeline.Interface } // New returns a new read-write load-saver. -func New(getter storage.Getter, pipelineFn func() pipeline.Interface) file.LoadSaver { +func New(getter storage.Getter, putter storage.Putter, pipelineFn func() pipeline.Interface) file.LoadSaver { return &loadSave{ getter: getter, + putter: putter, pipelineFn: pipelineFn, } } @@ -47,7 +49,7 @@ func NewReadonly(getter storage.Getter) file.LoadSaver { } func (ls *loadSave) Load(ctx context.Context, ref []byte) ([]byte, error) { - j, _, err := joiner.New(ctx, ls.getter, swarm.NewAddress(ref)) + j, _, err := joiner.New(ctx, ls.getter, ls.putter, swarm.NewAddress(ref)) if err != nil { return nil, err } diff --git a/pkg/file/loadsave/loadsave_test.go b/pkg/file/loadsave/loadsave_test.go index cfb55a953a6..64154859757 100644 --- a/pkg/file/loadsave/loadsave_test.go +++ b/pkg/file/loadsave/loadsave_test.go @@ -28,7 +28,7 @@ func TestLoadSave(t *testing.T) { t.Parallel() store := inmemchunkstore.New() - ls := loadsave.New(store, pipelineFn(store)) + ls := loadsave.New(store, store, pipelineFn(store)) ref, err := ls.Save(context.Background(), data) if err != nil { diff --git a/pkg/file/pipeline/hashtrie/hashtrie_test.go b/pkg/file/pipeline/hashtrie/hashtrie_test.go index beebd77cf74..78b4c20b2d7 100644 --- a/pkg/file/pipeline/hashtrie/hashtrie_test.go +++ b/pkg/file/pipeline/hashtrie/hashtrie_test.go @@ -23,6 +23,7 @@ import ( "github.com/ethersphere/bee/pkg/file/pipeline/mock" "github.com/ethersphere/bee/pkg/file/pipeline/store" "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" "github.com/ethersphere/bee/pkg/swarm" ) @@ -43,6 +44,41 @@ func init() { binary.LittleEndian.PutUint64(span, 1) } +// NewErasureHashTrieWriter returns back an redundancy param and a HastTrieWriter pipeline +// which are using simple BMT and StoreWriter pipelines for chunk writes +func newErasureHashTrieWriter( + ctx context.Context, + s storage.Putter, + rLevel redundancy.Level, + encryptChunks bool, + intermediateChunkPipeline, parityChunkPipeline pipeline.ChainWriter, +) (redundancy.IParams, pipeline.ChainWriter) { + pf := func() pipeline.ChainWriter { + lsw := store.NewStoreWriter(ctx, s, intermediateChunkPipeline) + return bmt.NewBmtWriter(lsw) + } + if encryptChunks { + pf = func() pipeline.ChainWriter { + lsw := store.NewStoreWriter(ctx, s, intermediateChunkPipeline) + b := bmt.NewBmtWriter(lsw) + return enc.NewEncryptionWriter(encryption.NewChunkEncrypter(), b) + } + } + ppf := func() pipeline.ChainWriter { + lsw := store.NewStoreWriter(ctx, s, parityChunkPipeline) + return bmt.NewBmtWriter(lsw) + } + + hashSize := swarm.HashSize + if encryptChunks { + hashSize *= 2 + } + + r := redundancy.New(rLevel, encryptChunks, ppf) + ht := hashtrie.NewHashTrieWriter(hashSize, r, pf) + return r, ht +} + func TestLevels(t *testing.T) { t.Parallel() @@ -271,32 +307,14 @@ func TestRedundancy(t *testing.T) { s := inmemchunkstore.New() intermediateChunkCounter := mock.NewChainWriter() parityChunkCounter := mock.NewChainWriter() - pf := func() pipeline.ChainWriter { - lsw := store.NewStoreWriter(ctx, s, intermediateChunkCounter) - return bmt.NewBmtWriter(lsw) - } - if tc.encryption { - pf = func() pipeline.ChainWriter { - lsw := store.NewStoreWriter(ctx, s, intermediateChunkCounter) - b := bmt.NewBmtWriter(lsw) - return enc.NewEncryptionWriter(encryption.NewChunkEncrypter(), b) - } - } - ppf := func() pipeline.ChainWriter { - lsw := store.NewStoreWriter(ctx, s, parityChunkCounter) - return bmt.NewBmtWriter(lsw) - } + r, ht := newErasureHashTrieWriter(ctx, s, tc.level, tc.encryption, intermediateChunkCounter, parityChunkCounter) + + // write data to the hashTrie var key []byte - hashSize := swarm.HashSize if tc.encryption { - hashSize *= 2 - key = addr.Bytes() + key = make([]byte, swarm.HashSize) } - - r := redundancy.New(tc.level, tc.encryption, ppf) - ht := hashtrie.NewHashTrieWriter(hashSize, r, pf) - for i := 0; i < tc.writes; i++ { a := &pipeline.PipeWriteArgs{Data: chData, Span: chSpan, Ref: chAddr, Key: key} err := ht.ChainWrite(a) diff --git a/pkg/file/redundancy/getter/getter.go b/pkg/file/redundancy/getter/getter.go new file mode 100644 index 00000000000..982c5080109 --- /dev/null +++ b/pkg/file/redundancy/getter/getter.go @@ -0,0 +1,333 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package getter + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/ethersphere/bee/pkg/encryption/store" + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" + "github.com/klauspost/reedsolomon" +) + +/// ERRORS + +type cannotRecoverError struct { + missingChunks int +} + +func (e cannotRecoverError) Error() string { + return fmt.Sprintf("redundancy getter: there are %d missing chunks in order to do recovery", e.missingChunks) +} + +func IsCannotRecoverError(err error, missingChunks int) bool { + return errors.Is(err, cannotRecoverError{missingChunks}) +} + +type isNotRecoveredError struct { + chAddress string +} + +func (e isNotRecoveredError) Error() string { + return fmt.Sprintf("redundancy getter: chunk with address %s is not recovered", e.chAddress) +} + +func IsNotRecoveredError(err error, chAddress string) bool { + return errors.Is(err, isNotRecoveredError{chAddress}) +} + +type noDataAddressIncludedError struct { + chAddress string +} + +func (e noDataAddressIncludedError) Error() string { + return fmt.Sprintf("redundancy getter: no data shard address given with chunk address %s", e.chAddress) +} + +func IsNoDataAddressIncludedError(err error, chAddress string) bool { + return errors.Is(err, noDataAddressIncludedError{chAddress}) +} + +type noRedundancyError struct { + chAddress string +} + +func (e noRedundancyError) Error() string { + return fmt.Sprintf("redundancy getter: cannot get chunk %s because no redundancy added", e.chAddress) +} + +func IsNoRedundancyError(err error, chAddress string) bool { + return errors.Is(err, noRedundancyError{chAddress}) +} + +/// TYPES + +// inflightChunk is initialized if recovery happened already +type inflightChunk struct { + pos int // chunk index in the erasureData/intermediate chunk + wait chan struct{} // chunk is under retrieval +} + +// getter retrieves children of an intermediate chunk +// it caches sibling chunks if erasure decoding was called on the level already +type getter struct { + storage.Getter + storage.Putter + mu sync.Mutex + sAddresses []swarm.Address // shard addresses + pAddresses []swarm.Address // parity addresses + cache map[string]inflightChunk // map from chunk address to cached shard chunk data + erasureData [][]byte // data + parity shards for erasure decoding; TODO mutex + encrypted bool // swarm datashards are encrypted +} + +// New returns a getter object which is used to retrieve children of an intermediate chunk +func New(sAddresses, pAddresses []swarm.Address, g storage.Getter, p storage.Putter) storage.Getter { + encrypted := len(sAddresses[0].Bytes()) == swarm.HashSize*2 + shards := len(sAddresses) + parities := len(pAddresses) + n := shards + parities + erasureData := make([][]byte, n) + cache := make(map[string]inflightChunk, n) + // init cache + for i, addr := range sAddresses { + cache[addr.String()] = inflightChunk{ + pos: i, + // wait channel initialization is needed when recovery starts + } + } + for i, addr := range pAddresses { + cache[addr.String()] = inflightChunk{ + pos: len(sAddresses) + i, + // no wait channel initialization is needed + } + } + + return &getter{ + Getter: g, + Putter: p, + sAddresses: sAddresses, + pAddresses: pAddresses, + cache: cache, + encrypted: encrypted, + erasureData: erasureData, + } +} + +// Get will call parities and other sibling chunks if the chunk address cannot be retrieved +// assumes it is called for data shards only +func (g *getter) Get(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) { + g.mu.Lock() + cValue, ok := g.cache[addr.String()] + g.mu.Unlock() + if !ok || cValue.pos >= len(g.sAddresses) { + return nil, noDataAddressIncludedError{addr.String()} + } + + if cValue.wait != nil { // equals to g.processing but does not need lock again + return g.getAfterProcessed(ctx, addr) + } + + ch, err := g.Getter.Get(ctx, addr) + if err == nil { + return ch, nil + } + if errors.Is(storage.ErrNotFound, err) && len(g.pAddresses) == 0 { + return nil, noRedundancyError{addr.String()} + } + + // during the get, the recovery may have started by other process + if g.processing(addr) { + return g.getAfterProcessed(ctx, addr) + } + + return g.executeStrategies(ctx, addr) +} + +// Inc increments the counter for the given key. +func (g *getter) setErasureData(index int, data []byte) { + g.mu.Lock() + g.erasureData[index] = data + g.mu.Unlock() +} + +// processing returns whether the recovery workflow has been started +func (g *getter) processing(addr swarm.Address) bool { + g.mu.Lock() + defer g.mu.Unlock() + iCh := g.cache[addr.String()] + return iCh.wait != nil +} + +// getAfterProcessed returns chunk from the cache +func (g *getter) getAfterProcessed(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) { + g.mu.Lock() + c, ok := g.cache[addr.String()] + // sanity check + if !ok { + return nil, fmt.Errorf("redundancy getter: chunk %s should have been in the cache", addr.String()) + } + + cacheData := g.erasureData[c.pos] + g.mu.Unlock() + if cacheData != nil { + return g.cacheDataToChunk(addr, cacheData) + } + + select { + case <-c.wait: + return g.cacheDataToChunk(addr, cacheData) + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// executeStrategies executes recovery strategies from redundancy for the given swarm address +func (g *getter) executeStrategies(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) { + g.initWaitChannels() + err := g.cautiousStrategy(ctx) + if err != nil { + g.closeWaitChannels() + return nil, err + } + + return g.getAfterProcessed(ctx, addr) +} + +// initWaitChannels initializes the wait channels in the cache mapping which indicates the start of the recovery process as well +func (g *getter) initWaitChannels() { + g.mu.Lock() + for _, addr := range g.sAddresses { + iCh := g.cache[addr.String()] + iCh.wait = make(chan struct{}) + g.cache[addr.String()] = iCh + } + g.mu.Unlock() +} + +// closeChannls closes all pending channels +func (g *getter) closeWaitChannels() { + for _, addr := range g.sAddresses { + c := g.cache[addr.String()] + if !channelIsClosed(c.wait) { + close(c.wait) + } + } +} + +// cautiousStrategy requests all chunks (data and parity) on the level +// and if it has enough data for erasure decoding then it cancel other requests +func (g *getter) cautiousStrategy(ctx context.Context) error { + requiredChunks := len(g.sAddresses) + subContext, cancelContext := context.WithCancel(ctx) + retrievedCh := make(chan struct{}, requiredChunks+len(g.pAddresses)) + var wg sync.WaitGroup + + addresses := append(g.sAddresses, g.pAddresses...) + for _, a := range addresses { + wg.Add(1) + c := g.cache[a.String()] + go func(a swarm.Address, c inflightChunk) { + defer wg.Done() + // enrypted chunk data should remain encrypted + address := swarm.NewAddress(a.Bytes()[:swarm.HashSize]) + ch, err := g.Getter.Get(subContext, address) + if err != nil { + return + } + g.setErasureData(c.pos, ch.Data()) + if c.pos < len(g.sAddresses) && !channelIsClosed(c.wait) { + close(c.wait) + } + retrievedCh <- struct{}{} + }(a, c) + } + + // Goroutine to wait for WaitGroup completion + go func() { + wg.Wait() + close(retrievedCh) + }() + retrieved := 0 + for retrieved < requiredChunks { + _, ok := <-retrievedCh + if !ok { + break + } + retrieved++ + } + cancelContext() + + if retrieved < requiredChunks { + return cannotRecoverError{requiredChunks - retrieved} + } + + return g.erasureDecode(ctx) +} + +// erasureDecode perform Reed-Solomon recovery on data +// assumes it is called after filling up cache with the required amount of shards and parities +func (g *getter) erasureDecode(ctx context.Context) error { + enc, err := reedsolomon.New(len(g.sAddresses), len(g.pAddresses)) + if err != nil { + return err + } + + // missing chunks + var missingIndices []int + for i := range g.sAddresses { + if g.erasureData[i] == nil { + missingIndices = append(missingIndices, i) + } + } + + g.mu.Lock() + err = enc.ReconstructData(g.erasureData) + g.mu.Unlock() + if err != nil { + return err + } + + g.closeWaitChannels() + // save missing chunks + for _, index := range missingIndices { + data := g.erasureData[index] + addr := g.sAddresses[index] + err := g.Putter.Put(ctx, swarm.NewChunk(addr, data)) + if err != nil { + return err + } + } + return nil +} + +// cacheDataToChunk transforms passed chunk data to legit swarm chunk +func (g *getter) cacheDataToChunk(addr swarm.Address, chData []byte) (swarm.Chunk, error) { + if chData == nil { + return nil, isNotRecoveredError{addr.String()} + } + if g.encrypted { + data, err := store.DecryptChunkData(chData, addr.Bytes()[swarm.HashSize:]) + if err != nil { + return nil, err + } + chData = data + } + + return swarm.NewChunk(addr, chData), nil +} + +func channelIsClosed(wait <-chan struct{}) bool { + select { + case _, ok := <-wait: + return !ok + default: + return false + } +} diff --git a/pkg/file/redundancy/getter/getter_test.go b/pkg/file/redundancy/getter/getter_test.go new file mode 100644 index 00000000000..addee4b6028 --- /dev/null +++ b/pkg/file/redundancy/getter/getter_test.go @@ -0,0 +1,217 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package getter_test + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "testing" + + "github.com/ethersphere/bee/pkg/cac" + "github.com/ethersphere/bee/pkg/file/redundancy/getter" + "github.com/ethersphere/bee/pkg/storage" + inmem "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" + "github.com/ethersphere/bee/pkg/swarm" + "github.com/klauspost/reedsolomon" +) + +func initData(ctx context.Context, erasureBuffer [][]byte, dataShardCount int, s storage.ChunkStore) ([]swarm.Address, []swarm.Address, error) { + spanBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(spanBytes, swarm.ChunkSize) + for i := 0; i < dataShardCount; i++ { + erasureBuffer[i] = make([]byte, swarm.HashSize) + _, err := io.ReadFull(rand.Reader, erasureBuffer[i]) + if err != nil { + return nil, nil, err + } + copy(erasureBuffer[i], spanBytes) + } + // make parity shards + for i := dataShardCount; i < len(erasureBuffer); i++ { + erasureBuffer[i] = make([]byte, swarm.HashSize) + } + // generate parity chunks + rs, err := reedsolomon.New(dataShardCount, len(erasureBuffer)-dataShardCount) + if err != nil { + return nil, nil, err + } + err = rs.Encode(erasureBuffer) + if err != nil { + return nil, nil, err + } + // calculate chunk addresses and upload to the store + var sAddresses, pAddresses []swarm.Address + for i := 0; i < dataShardCount; i++ { + chunk, err := cac.NewWithDataSpan(erasureBuffer[i]) + if err != nil { + return nil, nil, err + } + err = s.Put(ctx, chunk) + if err != nil { + return nil, nil, err + } + sAddresses = append(sAddresses, chunk.Address()) + } + for i := dataShardCount; i < len(erasureBuffer); i++ { + chunk, err := cac.NewWithDataSpan(erasureBuffer[i]) + if err != nil { + return nil, nil, err + } + err = s.Put(ctx, chunk) + if err != nil { + return nil, nil, err + } + pAddresses = append(pAddresses, chunk.Address()) + } + + return sAddresses, pAddresses, err +} + +func dataShardsAvailable(ctx context.Context, sAddresses []swarm.Address, s storage.ChunkStore) error { + for i := 0; i < len(sAddresses); i++ { + has, err := s.Has(ctx, sAddresses[i]) + if err != nil { + return err + } + if !has { + return fmt.Errorf("datashard %d is not available", i) + } + } + return nil +} + +func TestDecoding(t *testing.T) { + s := inmem.New() + + erasureBuffer := make([][]byte, 128) + dataShardCount := 100 + ctx := context.TODO() + + sAddresses, pAddresses, err := initData(ctx, erasureBuffer, dataShardCount, s) + if err != nil { + t.Fatal(err) + } + + getter := getter.New(sAddresses, pAddresses, s, s) + // sanity check - all data shards are retrievable + for i := 0; i < dataShardCount; i++ { + ch, err := getter.Get(ctx, sAddresses[i]) + if err != nil { + t.Fatalf("address %s at index %d is not retrievable by redundancy getter", sAddresses[i], i) + } + if !bytes.Equal(ch.Data(), erasureBuffer[i]) { + t.Fatalf("retrieved chunk data differ from the original at index %d", i) + } + } + + // remove maximum possible chunks from storage + removeChunkCount := len(erasureBuffer) - dataShardCount + for i := 0; i < removeChunkCount; i++ { + err := s.Delete(ctx, sAddresses[i]) + if err != nil { + t.Fatal(err) + } + } + + err = dataShardsAvailable(ctx, sAddresses, s) // sanity check + if err == nil { + t.Fatalf("some data shards should be missing") + } + ch, err := getter.Get(ctx, sAddresses[0]) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(ch.Data(), erasureBuffer[0]) { + t.Fatalf("retrieved chunk data differ from the original at index %d", 0) + } + err = dataShardsAvailable(ctx, sAddresses, s) + if err != nil { + t.Fatal(err) + } +} + +func TestRecoveryLimits(t *testing.T) { + s := inmem.New() + + erasureBuffer := make([][]byte, 8) + dataShardCount := 5 + ctx := context.TODO() + + sAddresses, pAddresses, err := initData(ctx, erasureBuffer, dataShardCount, s) + if err != nil { + t.Fatal(err) + } + + _getter := getter.New(sAddresses, pAddresses, s, s) + + // remove more chunks that can be corrected by erasure code + removeChunkCount := len(erasureBuffer) - dataShardCount + 1 + for i := 0; i < removeChunkCount; i++ { + err := s.Delete(ctx, sAddresses[i]) + if err != nil { + t.Fatal(err) + } + } + _, err = _getter.Get(ctx, sAddresses[0]) + if !getter.IsCannotRecoverError(err, 1) { + t.Fatal(err) + } + + // call once more + _, err = _getter.Get(ctx, sAddresses[0]) + if !getter.IsNotRecoveredError(err, sAddresses[0].String()) { + t.Fatal(err) + } +} + +func TestNoRedundancyOnRecovery(t *testing.T) { + s := inmem.New() + + dataShardCount := 5 + erasureBuffer := make([][]byte, dataShardCount) + ctx := context.TODO() + + sAddresses, pAddresses, err := initData(ctx, erasureBuffer, dataShardCount, s) + if err != nil { + t.Fatal(err) + } + + _getter := getter.New(sAddresses, pAddresses, s, s) + + // remove one chunk that trying to request later + err = s.Delete(ctx, sAddresses[0]) + if err != nil { + t.Fatal(err) + } + _, err = _getter.Get(ctx, sAddresses[0]) + if !getter.IsNoRedundancyError(err, sAddresses[0].String()) { + t.Fatal(err) + } +} + +func TestNoDataAddressIncluded(t *testing.T) { + s := inmem.New() + + erasureBuffer := make([][]byte, 8) + dataShardCount := 5 + ctx := context.TODO() + + sAddresses, pAddresses, err := initData(ctx, erasureBuffer, dataShardCount, s) + if err != nil { + t.Fatal(err) + } + + _getter := getter.New(sAddresses, pAddresses, s, s) + + // trying to retrieve a parity address + _, err = _getter.Get(ctx, pAddresses[0]) + if !getter.IsNoDataAddressIncludedError(err, pAddresses[0].String()) { + t.Fatal(err) + } +} diff --git a/pkg/file/redundancy/level.go b/pkg/file/redundancy/level.go index bdd7226af72..406e9b6c437 100644 --- a/pkg/file/redundancy/level.go +++ b/pkg/file/redundancy/level.go @@ -4,7 +4,11 @@ package redundancy -import "github.com/ethersphere/bee/pkg/swarm" +import ( + "errors" + + "github.com/ethersphere/bee/pkg/swarm" +) type Level uint8 @@ -20,18 +24,11 @@ const maxLevel = 8 // GetParities returns number of parities based on appendix F table 5 func (l Level) GetParities(shards int) int { - switch l { - case MEDIUM: - return mediumEt.getParities(shards) - case STRONG: - return strongEt.getParities(shards) - case INSANE: - return insaneEt.getParities(shards) - case PARANOID: - return paranoidEt.getParities(shards) - default: + et, err := l.getErasureTable() + if err != nil { return 0 } + return et.getParities(shards) } // GetMaxShards returns back the maximum number of effective data chunks @@ -42,17 +39,40 @@ func (l Level) GetMaxShards() int { // GetEncParities returns number of parities for encrypted chunks based on appendix F table 6 func (l Level) GetEncParities(shards int) int { + et, err := l.getEncErasureTable() + if err != nil { + return 0 + } + return et.getParities(shards) +} + +func (l Level) getErasureTable() (erasureTable, error) { switch l { case MEDIUM: - return encMediumEt.getParities(shards) + return *mediumEt, nil case STRONG: - return encStrongEt.getParities(shards) + return *strongEt, nil case INSANE: - return encInsaneEt.getParities(shards) + return *insaneEt, nil case PARANOID: - return encParanoidEt.getParities(shards) + return *paranoidEt, nil default: - return 0 + return erasureTable{}, errors.New("redundancy: level NONE does not have erasure table") + } +} + +func (l Level) getEncErasureTable() (erasureTable, error) { + switch l { + case MEDIUM: + return *encMediumEt, nil + case STRONG: + return *encStrongEt, nil + case INSANE: + return *encInsaneEt, nil + case PARANOID: + return *encParanoidEt, nil + default: + return erasureTable{}, errors.New("redundancy: level NONE does not have erasure table") } } diff --git a/pkg/file/redundancy/table.go b/pkg/file/redundancy/table.go index a88505da2b1..1b62e1ec34d 100644 --- a/pkg/file/redundancy/table.go +++ b/pkg/file/redundancy/table.go @@ -4,6 +4,8 @@ package redundancy +import "fmt" + type erasureTable struct { shards []int parities []int @@ -50,3 +52,13 @@ func (et *erasureTable) getParities(maxShards int) int { } return 0 } + +// getMinShards returns back the minimum shard number respect to the given parity number +func (et *erasureTable) GetMinShards(parities int) (int, error) { + for k, p := range et.parities { + if p == parities { + return et.shards[k], nil + } + } + return 0, fmt.Errorf("parity table: there is no minimum shard number for given parity %d", parities) +} diff --git a/pkg/file/utils.go b/pkg/file/utils.go index 298a6578447..59afd2e9cef 100644 --- a/pkg/file/utils.go +++ b/pkg/file/utils.go @@ -5,10 +5,47 @@ package file import ( + "bytes" + "errors" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/swarm" ) +// ChunkPayloadSize returns the effective byte length of an intermediate chunk +// assumes data is always chunk size (without span) +func ChunkPayloadSize(data []byte) (int, error) { + l := len(data) + for l >= swarm.HashSize { + if !bytes.Equal(data[l-swarm.HashSize:l], swarm.ZeroAddress.Bytes()) { + return l, nil + } + + l -= swarm.HashSize + } + + return 0, errors.New("redundancy getter: intermediate chunk does not have at least a child") +} + +// ChunkAddresses returns data shards and parities of the intermediate chunk +// assumes data is truncated by ChunkPayloadSize +func ChunkAddresses(data []byte, parities, reflen int) (sAddresses, pAddresses []swarm.Address) { + shards := (len(data) - parities*swarm.HashSize) / reflen + sAddresses = make([]swarm.Address, shards) + pAddresses = make([]swarm.Address, parities) + offset := 0 + for i := 0; i < shards; i++ { + sAddresses[i] = swarm.NewAddress(data[offset : offset+reflen]) + offset += reflen + } + for i := 0; i < parities; i++ { + pAddresses[i] = swarm.NewAddress(data[offset : offset+swarm.HashSize]) + offset += swarm.HashSize + } + + return sAddresses, pAddresses +} + // ReferenceCount brute-forces the data shard count from which identify the parity count as well in a substree // assumes span > swarm.chunkSize // returns data and parity shard number diff --git a/pkg/node/bootstrap.go b/pkg/node/bootstrap.go index bf2bc4030cb..a1e48228da7 100644 --- a/pkg/node/bootstrap.go +++ b/pkg/node/bootstrap.go @@ -237,7 +237,7 @@ func bootstrapNode( ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - reader, l, err = joiner.New(ctx, localStore.Download(true), snapshotReference) + reader, l, err = joiner.New(ctx, localStore.Download(true), localStore.Cache(), snapshotReference) if err != nil { logger.Warning("bootstrap: file joiner failed", "error", err) continue diff --git a/pkg/node/node.go b/pkg/node/node.go index ad96a94ac88..bdf03b76421 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -1083,7 +1083,7 @@ func NewBee( b.resolverCloser = multiResolver feedFactory := factory.New(localStore.Download(true)) - steward := steward.New(localStore, retrieval) + steward := steward.New(localStore, retrieval, localStore.Cache()) extraOpts := api.ExtraOptions{ Pingpong: pingPong, diff --git a/pkg/steward/steward.go b/pkg/steward/steward.go index 9f7bf0641a2..73b3d17b87f 100644 --- a/pkg/steward/steward.go +++ b/pkg/steward/steward.go @@ -36,11 +36,11 @@ type steward struct { netTraverser traversal.Traverser } -func New(ns storer.NetStore, r retrieval.Interface) Interface { +func New(ns storer.NetStore, r retrieval.Interface, joinerPutter storage.Putter) Interface { return &steward{ netStore: ns, - traverser: traversal.New(ns.Download(true)), - netTraverser: traversal.New(&netGetter{r}), + traverser: traversal.New(ns.Download(true), joinerPutter), + netTraverser: traversal.New(&netGetter{r}, joinerPutter), } } diff --git a/pkg/steward/steward_test.go b/pkg/steward/steward_test.go index 3836f75dcfd..bdddbdd3d08 100644 --- a/pkg/steward/steward_test.go +++ b/pkg/steward/steward_test.go @@ -24,15 +24,16 @@ import ( func TestSteward(t *testing.T) { t.Parallel() + inmem := inmemchunkstore.New() var ( ctx = context.Background() chunks = 1000 data = make([]byte, chunks*4096) //1k chunks - chunkStore = inmemchunkstore.New() + chunkStore = inmem store = mockstorer.NewWithChunkStore(chunkStore) localRetrieval = &localRetriever{ChunkStore: chunkStore} - s = steward.New(store, localRetrieval) + s = steward.New(store, localRetrieval, inmem) stamper = postagetesting.NewStamper() ) diff --git a/pkg/storer/epoch_migration.go b/pkg/storer/epoch_migration.go index f93ecafa486..e02ab5b6313 100644 --- a/pkg/storer/epoch_migration.go +++ b/pkg/storer/epoch_migration.go @@ -392,6 +392,7 @@ func (e *epochMigrator) migratePinning(ctx context.Context) error { return swarm.NewChunk(addr, chData), nil }), + pStorage.ChunkStore(), ) e.logger.Debug("migrating pinning collections, if all the chunks in the collection" + diff --git a/pkg/traversal/traversal.go b/pkg/traversal/traversal.go index e9609a90c8f..dbdb5603e82 100644 --- a/pkg/traversal/traversal.go +++ b/pkg/traversal/traversal.go @@ -29,19 +29,20 @@ type Traverser interface { } // New constructs for a new Traverser. -func New(store storage.Getter) Traverser { - return &service{store: store} +func New(getter storage.Getter, putter storage.Putter) Traverser { + return &service{getter: getter, putter: putter} } // service is implementation of Traverser using storage.Storer as its storage. type service struct { - store storage.Getter + getter storage.Getter + putter storage.Putter } // Traverse implements Traverser.Traverse method. func (s *service) Traverse(ctx context.Context, addr swarm.Address, iterFn swarm.AddressIterFunc) error { processBytes := func(ref swarm.Address) error { - j, _, err := joiner.New(ctx, s.store, ref) + j, _, err := joiner.New(ctx, s.getter, s.putter, ref) if err != nil { return fmt.Errorf("traversal: joiner error on %q: %w", ref, err) } @@ -54,7 +55,7 @@ func (s *service) Traverse(ctx context.Context, addr swarm.Address, iterFn swarm // skip SOC check for encrypted references if addr.IsValidLength() { - ch, err := s.store.Get(ctx, addr) + ch, err := s.getter.Get(ctx, addr) if err != nil { return fmt.Errorf("traversal: failed to get root chunk %s: %w", addr.String(), err) } @@ -64,7 +65,7 @@ func (s *service) Traverse(ctx context.Context, addr swarm.Address, iterFn swarm } } - ls := loadsave.NewReadonly(s.store) + ls := loadsave.NewReadonly(s.getter) switch mf, err := manifest.NewDefaultManifestReference(addr, ls); { case errors.Is(err, manifest.ErrInvalidManifestType): break diff --git a/pkg/traversal/traversal_test.go b/pkg/traversal/traversal_test.go index 75b5f925283..7d9d473e1e4 100644 --- a/pkg/traversal/traversal_test.go +++ b/pkg/traversal/traversal_test.go @@ -167,7 +167,7 @@ func TestTraversalBytes(t *testing.T) { t.Fatal(err) } - err = traversal.New(storerMock).Traverse(ctx, address, iter.Next) + err = traversal.New(storerMock, storerMock).Traverse(ctx, address, iter.Next) if err != nil { t.Fatal(err) } @@ -262,7 +262,7 @@ func TestTraversalFiles(t *testing.T) { t.Fatal(err) } - ls := loadsave.New(storerMock, pipelineFactory(storerMock, false)) + ls := loadsave.New(storerMock, storerMock, pipelineFactory(storerMock, false)) fManifest, err := manifest.NewDefaultManifest(ls, false) if err != nil { t.Fatal(err) @@ -294,7 +294,7 @@ func TestTraversalFiles(t *testing.T) { t.Fatal(err) } - err = traversal.New(storerMock).Traverse(ctx, address, iter.Next) + err = traversal.New(storerMock, storerMock).Traverse(ctx, address, iter.Next) if err != nil { t.Fatal(err) } @@ -421,7 +421,7 @@ func TestTraversalManifest(t *testing.T) { } wantHashes = append(wantHashes, tc.manifestHashes...) - ls := loadsave.New(storerMock, pipelineFactory(storerMock, false)) + ls := loadsave.New(storerMock, storerMock, pipelineFactory(storerMock, false)) dirManifest, err := manifest.NewMantarayManifest(ls, false) if err != nil { t.Fatal(err) @@ -452,7 +452,7 @@ func TestTraversalManifest(t *testing.T) { t.Fatal(err) } - err = traversal.New(storerMock).Traverse(ctx, address, iter.Next) + err = traversal.New(storerMock, storerMock).Traverse(ctx, address, iter.Next) if err != nil { t.Fatal(err) } @@ -490,7 +490,7 @@ func TestTraversalSOC(t *testing.T) { t.Fatal(err) } - err = traversal.New(store).Traverse(ctx, sch.Address(), iter.Next) + err = traversal.New(store, store).Traverse(ctx, sch.Address(), iter.Next) if err != nil { t.Fatal(err) }