From 1e65ba02985079c329e677228d9671ed1571f046 Mon Sep 17 00:00:00 2001 From: nugaon <50576770+nugaon@users.noreply.github.com> Date: Mon, 4 Dec 2023 15:40:42 +0100 Subject: [PATCH 01/23] feat: erasure encoder (#4429) --- go.mod | 3 +- go.sum | 6 +- openapi/Swarm.yaml | 6 + openapi/SwarmCommon.yaml | 11 ++ pkg/api/api.go | 12 +- pkg/api/api_test.go | 5 +- pkg/api/bytes.go | 14 +- pkg/api/bzz.go | 23 ++- pkg/api/bzz_test.go | 2 +- pkg/api/dirs.go | 13 +- pkg/api/feed.go | 2 +- pkg/encryption/encryption.go | 7 + pkg/encryption/store/decrypt_store.go | 18 +- pkg/file/file_test.go | 2 +- pkg/file/joiner/joiner_test.go | 6 +- pkg/file/loadsave/loadsave_test.go | 2 +- pkg/file/pipeline/builder/builder.go | 24 ++- pkg/file/pipeline/builder/builder_test.go | 10 +- pkg/file/pipeline/feeder/feeder.go | 1 + pkg/file/pipeline/hashtrie/hashtrie.go | 149 +++++++++----- pkg/file/pipeline/hashtrie/hashtrie_test.go | 172 +++++++++++++++-- pkg/file/redundancy/export_test.go | 8 + pkg/file/redundancy/level.go | 117 +++++++++++ pkg/file/redundancy/redundancy.go | 204 ++++++++++++++++++++ pkg/file/redundancy/redundancy_test.go | 158 +++++++++++++++ pkg/file/redundancy/span.go | 34 ++++ pkg/file/redundancy/table.go | 52 +++++ pkg/file/utils.go | 60 ++++++ pkg/steward/steward_test.go | 2 +- pkg/traversal/traversal_test.go | 8 +- 30 files changed, 1005 insertions(+), 126 deletions(-) create mode 100644 pkg/file/redundancy/export_test.go create mode 100644 pkg/file/redundancy/level.go create mode 100644 pkg/file/redundancy/redundancy.go create mode 100644 pkg/file/redundancy/redundancy_test.go create mode 100644 pkg/file/redundancy/span.go create mode 100644 pkg/file/redundancy/table.go create mode 100644 pkg/file/utils.go diff --git a/go.mod b/go.mod index 3f9b836e08a..5f38b82e566 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/hashicorp/golang-lru/v2 v2.0.5 github.com/ipfs/go-cid v0.4.1 github.com/kardianos/service v1.2.0 + github.com/klauspost/reedsolomon v1.11.8 github.com/libp2p/go-libp2p v0.30.0 github.com/multiformats/go-multiaddr v0.11.0 github.com/multiformats/go-multiaddr-dns v0.3.1 @@ -105,7 +106,7 @@ require ( github.com/jackpal/go-nat-pmp v1.0.2 // indirect github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect github.com/klauspost/compress v1.16.7 // indirect - github.com/klauspost/cpuid/v2 v2.2.5 // indirect + github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/koron/go-ssdp v0.0.4 // indirect github.com/leodido/go-urn v1.2.1 // indirect github.com/libp2p/go-buffer-pool v0.1.0 // indirect diff --git a/go.sum b/go.sum index 5c83af6a368..0afd81c3918 100644 --- a/go.sum +++ b/go.sum @@ -522,10 +522,12 @@ github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQs github.com/klauspost/cpuid v0.0.0-20170728055534-ae7887de9fa5/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/klauspost/cpuid/v2 v2.0.4/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.6/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= -github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= +github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6/go.mod h1:+ZoRqAPRLkC4NPOvfYeR5KNOrY6TD+/sAC3HXPZgDYg= 
github.com/klauspost/pgzip v1.0.2-0.20170402124221-0bf5dcad4ada/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs=
+github.com/klauspost/reedsolomon v1.11.8 h1:s8RpUW5TK4hjr+djiOpbZJB4ksx+TdYbRH7vHQpwPOY=
+github.com/klauspost/reedsolomon v1.11.8/go.mod h1:4bXRN+cVzMdml6ti7qLouuYi32KHJ5MGv0Qd8a47h6A=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/koron/go-ssdp v0.0.4 h1:1IDwrghSKYM7yLf7XCzbByg2sJ/JcNOZRXS2jczTwz0=
diff --git a/openapi/Swarm.yaml b/openapi/Swarm.yaml
index accfe44ae2e..a8d6f44c8fe 100644
--- a/openapi/Swarm.yaml
+++ b/openapi/Swarm.yaml
@@ -120,6 +120,11 @@ paths:
           $ref: "SwarmCommon.yaml#/components/parameters/SwarmEncryptParameter"
         name: swarm-encrypt
         required: false
+      - in: header
+        schema:
+          $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyParameter"
+        name: swarm-redundancy-level
+        required: false

       requestBody:
         content:
@@ -254,6 +259,7 @@ paths:
       - $ref: "SwarmCommon.yaml#/components/parameters/SwarmErrorDocumentParameter"
       - $ref: "SwarmCommon.yaml#/components/parameters/SwarmPostageBatchId"
       - $ref: "SwarmCommon.yaml#/components/parameters/SwarmDeferredUpload"
+      - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyParameter"
       requestBody:
         content:
           multipart/form-data:
diff --git a/openapi/SwarmCommon.yaml b/openapi/SwarmCommon.yaml
index a3e46342eea..c5ac5674fe9 100644
--- a/openapi/SwarmCommon.yaml
+++ b/openapi/SwarmCommon.yaml
@@ -934,6 +934,17 @@ components:
       description: >
         Represents the encrypting state of the file

+    SwarmRedundancyParameter:
+      in: header
+      name: swarm-redundancy-level
+      schema:
+        type: integer
+        enum: [0, 1, 2, 3, 4]
+      required: false
+      description: >
+        Add redundancy to the uploaded data so that downloaders can reconstruct the content
+        even if some of its chunks are missing or corrupted.
+        The default value of 0 adds no redundancy to the file.
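For orientation, a minimal client-side sketch of the new header (the header name and value range come from the spec above; the host, port, payload and batch ID below are illustrative placeholders, not part of this change):

// upload_sketch.go - hypothetical client usage of the swarm-redundancy-level header
package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	body := bytes.NewReader([]byte("hello, redundancy"))
	// assumed local bee API endpoint; adjust host/port for your node
	req, err := http.NewRequest(http.MethodPost, "http://localhost:1633/bytes", body)
	if err != nil {
		panic(err)
	}
	req.Header.Set("swarm-postage-batch-id", "<batch-id>") // placeholder value
	req.Header.Set("swarm-redundancy-level", "2")          // 0-4; 0 (default) adds no parity chunks
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}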
+ ContentTypePreserved: in: header name: Content-Type diff --git a/pkg/api/api.go b/pkg/api/api.go index ef680fd3ea5..2dd583b7df4 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -33,6 +33,7 @@ import ( "github.com/ethersphere/bee/pkg/feeds" "github.com/ethersphere/bee/pkg/file/pipeline" "github.com/ethersphere/bee/pkg/file/pipeline/builder" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/p2p" @@ -79,6 +80,7 @@ const ( SwarmCollectionHeader = "Swarm-Collection" SwarmPostageBatchIdHeader = "Swarm-Postage-Batch-Id" SwarmDeferredUploadHeader = "Swarm-Deferred-Upload" + SwarmRLevel = "Swarm-Redundancy-Level" ImmutableHeader = "Immutable" GasPriceHeader = "Gas-Price" @@ -622,7 +624,7 @@ func (s *Service) corsHandler(h http.Handler) http.Handler { allowedHeaders := []string{ "User-Agent", "Accept", "X-Requested-With", "Access-Control-Request-Headers", "Access-Control-Request-Method", "Accept-Ranges", "Content-Encoding", AuthorizationHeader, AcceptEncodingHeader, ContentTypeHeader, ContentDispositionHeader, RangeHeader, OriginHeader, - SwarmTagHeader, SwarmPinHeader, SwarmEncryptHeader, SwarmIndexDocumentHeader, SwarmErrorDocumentHeader, SwarmCollectionHeader, SwarmPostageBatchIdHeader, SwarmDeferredUploadHeader, + SwarmTagHeader, SwarmPinHeader, SwarmEncryptHeader, SwarmIndexDocumentHeader, SwarmErrorDocumentHeader, SwarmCollectionHeader, SwarmPostageBatchIdHeader, SwarmDeferredUploadHeader, SwarmRLevel, GasPriceHeader, GasLimitHeader, ImmutableHeader, } allowedHeadersStr := strings.Join(allowedHeaders, ", ") @@ -848,16 +850,16 @@ func (s *Service) newStamperPutter(ctx context.Context, opts putterOptions) (sto type pipelineFunc func(context.Context, io.Reader) (swarm.Address, error) -func requestPipelineFn(s storage.Putter, encrypt bool) pipelineFunc { +func requestPipelineFn(s storage.Putter, encrypt bool, rLevel redundancy.Level) pipelineFunc { return func(ctx context.Context, r io.Reader) (swarm.Address, error) { - pipe := builder.NewPipelineBuilder(ctx, s, encrypt) + pipe := builder.NewPipelineBuilder(ctx, s, encrypt, rLevel) return builder.FeedPipeline(ctx, pipe, r) } } -func requestPipelineFactory(ctx context.Context, s storage.Putter, encrypt bool) func() pipeline.Interface { +func requestPipelineFactory(ctx context.Context, s storage.Putter, encrypt bool, rLevel redundancy.Level) func() pipeline.Interface { return func() pipeline.Interface { - return builder.NewPipelineBuilder(ctx, s, encrypt) + return builder.NewPipelineBuilder(ctx, s, encrypt, rLevel) } } diff --git a/pkg/api/api_test.go b/pkg/api/api_test.go index a31bb87ff5d..b06e5808479 100644 --- a/pkg/api/api_test.go +++ b/pkg/api/api_test.go @@ -31,6 +31,7 @@ import ( "github.com/ethersphere/bee/pkg/feeds" "github.com/ethersphere/bee/pkg/file/pipeline" "github.com/ethersphere/bee/pkg/file/pipeline/builder" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/log" p2pmock "github.com/ethersphere/bee/pkg/p2p/mock" @@ -321,9 +322,9 @@ func request(t *testing.T, client *http.Client, method, resource string, body io return resp } -func pipelineFactory(s storage.Putter, encrypt bool) func() pipeline.Interface { +func pipelineFactory(s storage.Putter, encrypt bool, rLevel redundancy.Level) func() pipeline.Interface { return func() pipeline.Interface { - return builder.NewPipelineBuilder(context.Background(), s, encrypt) + return 
builder.NewPipelineBuilder(context.Background(), s, encrypt, rLevel) } } diff --git a/pkg/api/bytes.go b/pkg/api/bytes.go index dd2861ff285..894da9a4b06 100644 --- a/pkg/api/bytes.go +++ b/pkg/api/bytes.go @@ -12,6 +12,7 @@ import ( "strconv" "github.com/ethersphere/bee/pkg/cac" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/postage" storage "github.com/ethersphere/bee/pkg/storage" @@ -29,11 +30,12 @@ func (s *Service) bytesUploadHandler(w http.ResponseWriter, r *http.Request) { logger := tracing.NewLoggerWithTraceID(r.Context(), s.logger.WithName("post_bytes").Build()) headers := struct { - BatchID []byte `map:"Swarm-Postage-Batch-Id" validate:"required"` - SwarmTag uint64 `map:"Swarm-Tag"` - Pin bool `map:"Swarm-Pin"` - Deferred *bool `map:"Swarm-Deferred-Upload"` - Encrypt bool `map:"Swarm-Encrypt"` + BatchID []byte `map:"Swarm-Postage-Batch-Id" validate:"required"` + SwarmTag uint64 `map:"Swarm-Tag"` + Pin bool `map:"Swarm-Pin"` + Deferred *bool `map:"Swarm-Deferred-Upload"` + Encrypt bool `map:"Swarm-Encrypt"` + RLevel redundancy.Level `map:"Swarm-Redundancy-Level"` }{} if response := s.mapStructure(r.Header, &headers); response != nil { response("invalid header params", logger, w) @@ -91,7 +93,7 @@ func (s *Service) bytesUploadHandler(w http.ResponseWriter, r *http.Request) { logger: logger, } - p := requestPipelineFn(putter, headers.Encrypt) + p := requestPipelineFn(putter, headers.Encrypt, headers.RLevel) address, err := p(r.Context(), r.Body) if err != nil { logger.Debug("split write all failed", "error", err) diff --git a/pkg/api/bzz.go b/pkg/api/bzz.go index 1dba2add517..b1605bae90d 100644 --- a/pkg/api/bzz.go +++ b/pkg/api/bzz.go @@ -21,6 +21,7 @@ import ( "github.com/ethersphere/bee/pkg/feeds" "github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/loadsave" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/manifest" @@ -37,13 +38,14 @@ func (s *Service) bzzUploadHandler(w http.ResponseWriter, r *http.Request) { logger := tracing.NewLoggerWithTraceID(r.Context(), s.logger.WithName("post_bzz").Build()) headers := struct { - ContentType string `map:"Content-Type,mimeMediaType" validate:"required"` - BatchID []byte `map:"Swarm-Postage-Batch-Id" validate:"required"` - SwarmTag uint64 `map:"Swarm-Tag"` - Pin bool `map:"Swarm-Pin"` - Deferred *bool `map:"Swarm-Deferred-Upload"` - Encrypt bool `map:"Swarm-Encrypt"` - IsDir bool `map:"Swarm-Collection"` + ContentType string `map:"Content-Type,mimeMediaType" validate:"required"` + BatchID []byte `map:"Swarm-Postage-Batch-Id" validate:"required"` + SwarmTag uint64 `map:"Swarm-Tag"` + Pin bool `map:"Swarm-Pin"` + Deferred *bool `map:"Swarm-Deferred-Upload"` + Encrypt bool `map:"Swarm-Encrypt"` + IsDir bool `map:"Swarm-Collection"` + RLevel redundancy.Level `map:"Swarm-Redundancy-Level"` }{} if response := s.mapStructure(r.Header, &headers); response != nil { response("invalid header params", logger, w) @@ -105,7 +107,7 @@ func (s *Service) bzzUploadHandler(w http.ResponseWriter, r *http.Request) { s.dirUploadHandler(logger, ow, r, putter, r.Header.Get(ContentTypeHeader), headers.Encrypt, tag) return } - s.fileUploadHandler(logger, ow, r, putter, headers.Encrypt, tag) + s.fileUploadHandler(logger, ow, r, putter, headers.Encrypt, tag, headers.RLevel) } // fileUploadResponse is returned when an HTTP request to upload a file is 
successful
@@ -122,6 +124,7 @@ func (s *Service) fileUploadHandler(
 	putter storer.PutterSession,
 	encrypt bool,
 	tagID uint64,
+	rLevel redundancy.Level,
 ) {
 	queries := struct {
 		FileName string `map:"name" validate:"startsnotwith=/"`
@@ -131,7 +134,7 @@ func (s *Service) fileUploadHandler(
 		return
 	}

-	p := requestPipelineFn(putter, encrypt)
+	p := requestPipelineFn(putter, encrypt, rLevel)
 	ctx := r.Context()

 	// first store the file and get its reference
@@ -171,7 +174,7 @@ func (s *Service) fileUploadHandler(
 		}
 	}

-	factory := requestPipelineFactory(ctx, putter, encrypt)
+	factory := requestPipelineFactory(ctx, putter, encrypt, rLevel)
 	l := loadsave.New(s.storer.ChunkStore(), factory)

 	m, err := manifest.NewDefaultManifest(l, encrypt)
diff --git a/pkg/api/bzz_test.go b/pkg/api/bzz_test.go
index 4e54a967d3a..731cfcb567f 100644
--- a/pkg/api/bzz_test.go
+++ b/pkg/api/bzz_test.go
@@ -554,7 +554,7 @@ func TestFeedIndirection(t *testing.T) {
 		t.Fatal(err)
 	}
 	m, err := manifest.NewDefaultManifest(
-		loadsave.New(storer.ChunkStore(), pipelineFactory(storer.Cache(), false)),
+		loadsave.New(storer.ChunkStore(), pipelineFactory(storer.Cache(), false, 0)),
 		false,
 	)
 	if err != nil {
diff --git a/pkg/api/dirs.go b/pkg/api/dirs.go
index edd4f741120..771787a24a9 100644
--- a/pkg/api/dirs.go
+++ b/pkg/api/dirs.go
@@ -19,6 +19,7 @@ import (
 	"strings"

 	"github.com/ethersphere/bee/pkg/file/loadsave"
+	"github.com/ethersphere/bee/pkg/file/redundancy"
 	"github.com/ethersphere/bee/pkg/jsonhttp"
 	"github.com/ethersphere/bee/pkg/log"
 	"github.com/ethersphere/bee/pkg/manifest"
@@ -63,6 +64,12 @@ func (s *Service) dirUploadHandler(
 	}
 	defer r.Body.Close()

+	rsParity := uint64(0)
+	if v := r.Header.Get(SwarmRLevel); v != "" {
+		parsed, err := strconv.ParseUint(v, 10, 8) // levels 0-4 do not fit into a 1-bit integer
+		if err != nil {
+			logger.Debug("store dir failed", "error", err)
+			logger.Error(nil, "store dir failed")
+		} else {
+			rsParity = parsed
+		}
+	}
+
 	reference, err := storeDir(
 		r.Context(),
 		encrypt,
@@ -72,6 +79,7 @@ func (s *Service) dirUploadHandler(
 		s.storer.ChunkStore(),
 		r.Header.Get(SwarmIndexDocumentHeader),
 		r.Header.Get(SwarmErrorDocumentHeader),
+		redundancy.Level(rsParity),
 	)
 	if err != nil {
 		logger.Debug("store dir failed", "error", err)
@@ -117,13 +125,14 @@ func storeDir(
 	getter storage.Getter,
 	indexFilename,
 	errorFilename string,
+	rLevel redundancy.Level,
 ) (swarm.Address, error) {
 	logger := tracing.NewLoggerWithTraceID(ctx, log)
 	loggerV1 := logger.V(1).Build()

-	p := requestPipelineFn(putter, encrypt)
-	ls := loadsave.New(getter, requestPipelineFactory(ctx, putter, encrypt))
+	p := requestPipelineFn(putter, encrypt, rLevel)
+	ls := loadsave.New(getter, requestPipelineFactory(ctx, putter, encrypt, rLevel))

 	dirManifest, err := manifest.NewDefaultManifest(ls, encrypt)
 	if err != nil {
diff --git a/pkg/api/feed.go b/pkg/api/feed.go
index d85f89fe9b3..f3a3413a426 100644
--- a/pkg/api/feed.go
+++ b/pkg/api/feed.go
@@ -196,7 +196,7 @@ func (s *Service) feedPostHandler(w http.ResponseWriter, r *http.Request) {
 		logger: logger,
 	}

-	l := loadsave.New(s.storer.ChunkStore(), requestPipelineFactory(r.Context(), putter, false))
+	l := loadsave.New(s.storer.ChunkStore(), requestPipelineFactory(r.Context(), putter, false, 0))
 	feedManifest, err := manifest.NewDefaultManifest(l, false)
 	if err != nil {
 		logger.Debug("create manifest failed", "error", err)
diff --git a/pkg/encryption/encryption.go b/pkg/encryption/encryption.go
index c0019dd76bd..1d3fc3d08d8 100644
--- a/pkg/encryption/encryption.go
+++ b/pkg/encryption/encryption.go
@@ -185,3 +185,10 @@ func GenerateRandomKey(l int) Key {
 	}
 	return key
 }
+
+func min(a, b int) int {
+	if a 
< b { + return a + } + return b +} diff --git a/pkg/encryption/store/decrypt_store.go b/pkg/encryption/store/decrypt_store.go index cf16b8ac3cd..7a9f0395e91 100644 --- a/pkg/encryption/store/decrypt_store.go +++ b/pkg/encryption/store/decrypt_store.go @@ -9,6 +9,8 @@ import ( "encoding/binary" "github.com/ethersphere/bee/pkg/encryption" + "github.com/ethersphere/bee/pkg/file" + "github.com/ethersphere/bee/pkg/file/redundancy" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" "golang.org/x/crypto/sha3" @@ -37,7 +39,7 @@ func (s *decryptingStore) Get(ctx context.Context, addr swarm.Address) (ch swarm return nil, err } - d, err := decryptChunkData(ch.Data(), ref[swarm.HashSize:]) + d, err := DecryptChunkData(ch.Data(), ref[swarm.HashSize:]) if err != nil { return nil, err } @@ -48,19 +50,19 @@ func (s *decryptingStore) Get(ctx context.Context, addr swarm.Address) (ch swarm } } -func decryptChunkData(chunkData []byte, encryptionKey encryption.Key) ([]byte, error) { +func DecryptChunkData(chunkData []byte, encryptionKey encryption.Key) ([]byte, error) { decryptedSpan, decryptedData, err := decrypt(chunkData, encryptionKey) if err != nil { return nil, err } // removing extra bytes which were just added for padding - length := binary.LittleEndian.Uint64(decryptedSpan) - refSize := int64(swarm.HashSize + encryption.KeyLength) - for length > swarm.ChunkSize { - length = length + (swarm.ChunkSize - 1) - length = length / swarm.ChunkSize - length *= uint64(refSize) + level, span := redundancy.DecodeSpan(decryptedSpan) + length := binary.LittleEndian.Uint64(span) + if length > swarm.ChunkSize { + dataRefSize := uint64(swarm.HashSize + encryption.KeyLength) + dataShards, parities := file.ReferenceCount(length, level, true) + length = dataRefSize*uint64(dataShards) + uint64(parities*swarm.HashSize) } c := make([]byte, length+8) diff --git a/pkg/file/file_test.go b/pkg/file/file_test.go index 951b679d591..fd85a23fb9d 100644 --- a/pkg/file/file_test.go +++ b/pkg/file/file_test.go @@ -48,7 +48,7 @@ func testSplitThenJoin(t *testing.T) { paramstring = strings.Split(t.Name(), "/") dataIdx, _ = strconv.ParseInt(paramstring[1], 10, 0) store = inmemchunkstore.New() - p = builder.NewPipelineBuilder(context.Background(), store, false) + p = builder.NewPipelineBuilder(context.Background(), store, false, 0) data, _ = test.GetVector(t, int(dataIdx)) ) diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index 6063df6d846..772c0e12ae2 100644 --- a/pkg/file/joiner/joiner_test.go +++ b/pkg/file/joiner/joiner_test.go @@ -182,7 +182,7 @@ func TestJoinerMalformed(t *testing.T) { defer cancel() subTrie := []byte{8085: 1} - pb := builder.NewPipelineBuilder(ctx, store, false) + pb := builder.NewPipelineBuilder(ctx, store, false, 0) c1addr, _ := builder.FeedPipeline(ctx, pb, bytes.NewReader(subTrie)) chunk2 := testingc.GenerateTestRandomChunk() @@ -248,7 +248,7 @@ func TestEncryptDecrypt(t *testing.T) { t.Fatal(err) } ctx := context.Background() - pipe := builder.NewPipelineBuilder(ctx, store, true) + pipe := builder.NewPipelineBuilder(ctx, store, true, 0) testDataReader := bytes.NewReader(testData) resultAddress, err := builder.FeedPipeline(ctx, pipe, testDataReader) if err != nil { @@ -911,7 +911,7 @@ func TestJoinerIterateChunkAddresses_Encrypted(t *testing.T) { t.Fatal(err) } ctx := context.Background() - pipe := builder.NewPipelineBuilder(ctx, store, true) + pipe := builder.NewPipelineBuilder(ctx, store, true, 0) testDataReader := bytes.NewReader(testData) 
resultAddress, err := builder.FeedPipeline(ctx, pipe, testDataReader) if err != nil { diff --git a/pkg/file/loadsave/loadsave_test.go b/pkg/file/loadsave/loadsave_test.go index 21f4847bef9..cfb55a953a6 100644 --- a/pkg/file/loadsave/loadsave_test.go +++ b/pkg/file/loadsave/loadsave_test.go @@ -73,6 +73,6 @@ func TestReadonlyLoadSave(t *testing.T) { func pipelineFn(s storage.Putter) func() pipeline.Interface { return func() pipeline.Interface { - return builder.NewPipelineBuilder(context.Background(), s, false) + return builder.NewPipelineBuilder(context.Background(), s, false, 0) } } diff --git a/pkg/file/pipeline/builder/builder.go b/pkg/file/pipeline/builder/builder.go index 0833f50317c..e5cf03b2da4 100644 --- a/pkg/file/pipeline/builder/builder.go +++ b/pkg/file/pipeline/builder/builder.go @@ -17,23 +17,29 @@ import ( "github.com/ethersphere/bee/pkg/file/pipeline/feeder" "github.com/ethersphere/bee/pkg/file/pipeline/hashtrie" "github.com/ethersphere/bee/pkg/file/pipeline/store" + "github.com/ethersphere/bee/pkg/file/redundancy" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" ) // NewPipelineBuilder returns the appropriate pipeline according to the specified parameters -func NewPipelineBuilder(ctx context.Context, s storage.Putter, encrypt bool) pipeline.Interface { +func NewPipelineBuilder(ctx context.Context, s storage.Putter, encrypt bool, rLevel redundancy.Level) pipeline.Interface { if encrypt { - return newEncryptionPipeline(ctx, s) + return newEncryptionPipeline(ctx, s, rLevel) } - return newPipeline(ctx, s) + return newPipeline(ctx, s, rLevel) } // newPipeline creates a standard pipeline that only hashes content with BMT to create // a merkle-tree of hashes that represent the given arbitrary size byte stream. Partial // writes are supported. The pipeline flow is: Data -> Feeder -> BMT -> Storage -> HashTrie. -func newPipeline(ctx context.Context, s storage.Putter) pipeline.Interface { - tw := hashtrie.NewHashTrieWriter(swarm.ChunkSize, swarm.Branches, swarm.HashSize, newShortPipelineFunc(ctx, s)) +func newPipeline(ctx context.Context, s storage.Putter, rLevel redundancy.Level) pipeline.Interface { + pipeline := newShortPipelineFunc(ctx, s) + tw := hashtrie.NewHashTrieWriter( + swarm.HashSize, + redundancy.New(rLevel, false, pipeline), + pipeline, + ) lsw := store.NewStoreWriter(ctx, s, tw) b := bmt.NewBmtWriter(lsw) return feeder.NewChunkFeederWriter(swarm.ChunkSize, b) @@ -53,8 +59,12 @@ func newShortPipelineFunc(ctx context.Context, s storage.Putter) func() pipeline // writes are supported. The pipeline flow is: Data -> Feeder -> Encryption -> BMT -> Storage -> HashTrie. // Note that the encryption writer will mutate the data to contain the encrypted span, but the span field // with the unencrypted span is preserved. 
-func newEncryptionPipeline(ctx context.Context, s storage.Putter) pipeline.Interface { - tw := hashtrie.NewHashTrieWriter(swarm.ChunkSize, 64, swarm.HashSize+encryption.KeyLength, newShortEncryptionPipelineFunc(ctx, s)) +func newEncryptionPipeline(ctx context.Context, s storage.Putter, rLevel redundancy.Level) pipeline.Interface { + tw := hashtrie.NewHashTrieWriter( + swarm.HashSize+encryption.KeyLength, + redundancy.New(rLevel, true, newShortPipelineFunc(ctx, s)), + newShortEncryptionPipelineFunc(ctx, s), + ) lsw := store.NewStoreWriter(ctx, s, tw) b := bmt.NewBmtWriter(lsw) enc := enc.NewEncryptionWriter(encryption.NewChunkEncrypter(), b) diff --git a/pkg/file/pipeline/builder/builder_test.go b/pkg/file/pipeline/builder/builder_test.go index c720f25f356..266474d9376 100644 --- a/pkg/file/pipeline/builder/builder_test.go +++ b/pkg/file/pipeline/builder/builder_test.go @@ -23,7 +23,7 @@ func TestPartialWrites(t *testing.T) { t.Parallel() m := inmemchunkstore.New() - p := builder.NewPipelineBuilder(context.Background(), m, false) + p := builder.NewPipelineBuilder(context.Background(), m, false, 0) _, _ = p.Write([]byte("hello ")) _, _ = p.Write([]byte("world")) @@ -41,7 +41,7 @@ func TestHelloWorld(t *testing.T) { t.Parallel() m := inmemchunkstore.New() - p := builder.NewPipelineBuilder(context.Background(), m, false) + p := builder.NewPipelineBuilder(context.Background(), m, false, 0) data := []byte("hello world") _, err := p.Write(data) @@ -64,7 +64,7 @@ func TestEmpty(t *testing.T) { t.Parallel() m := inmemchunkstore.New() - p := builder.NewPipelineBuilder(context.Background(), m, false) + p := builder.NewPipelineBuilder(context.Background(), m, false, 0) data := []byte{} _, err := p.Write(data) @@ -92,7 +92,7 @@ func TestAllVectors(t *testing.T) { t.Parallel() m := inmemchunkstore.New() - p := builder.NewPipelineBuilder(context.Background(), m, false) + p := builder.NewPipelineBuilder(context.Background(), m, false, 0) _, err := p.Write(data) if err != nil { @@ -155,7 +155,7 @@ func benchmarkPipeline(b *testing.B, count int) { data := testutil.RandBytes(b, count) m := inmemchunkstore.New() - p := builder.NewPipelineBuilder(context.Background(), m, false) + p := builder.NewPipelineBuilder(context.Background(), m, false, 0) b.StartTimer() diff --git a/pkg/file/pipeline/feeder/feeder.go b/pkg/file/pipeline/feeder/feeder.go index 03d77cfb49c..9b4b60cae61 100644 --- a/pkg/file/pipeline/feeder/feeder.go +++ b/pkg/file/pipeline/feeder/feeder.go @@ -75,6 +75,7 @@ func (f *chunkFeeder) Write(b []byte) (int, error) { sp += n binary.LittleEndian.PutUint64(d[:span], uint64(sp)) + args := &pipeline.PipeWriteArgs{Data: d[:span+sp], Span: d[:span]} err := f.next.ChainWrite(args) if err != nil { diff --git a/pkg/file/pipeline/hashtrie/hashtrie.go b/pkg/file/pipeline/hashtrie/hashtrie.go index c559a646c55..fc1f040aca5 100644 --- a/pkg/file/pipeline/hashtrie/hashtrie.go +++ b/pkg/file/pipeline/hashtrie/hashtrie.go @@ -9,6 +9,7 @@ import ( "errors" "github.com/ethersphere/bee/pkg/file/pipeline" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/swarm" ) @@ -20,26 +21,34 @@ var ( const maxLevel = 8 type hashTrieWriter struct { - branching int - chunkSize int - refSize int - fullChunk int // full chunk size in terms of the data represented in the buffer (span+refsize) - cursors []int // level cursors, key is level. level 0 is data level and is not represented in this package. writes always start at level 1. higher levels will always have LOWER cursor values. 
-	buffer     []byte // keeps all level data
-	full       bool   // indicates whether the trie is full. currently we support (128^7)*4096 = 2305843009213693952 bytes
-	pipelineFn pipeline.PipelineFunc
+	refSize                int
+	cursors                []int  // level cursors, key is level. level 0 is the data level and holds how many chunks were processed. Intermediate higher levels will always have LOWER cursor values.
+	buffer                 []byte // keeps intermediate level data
+	full                   bool   // indicates whether the trie is full. currently we support (128^7)*4096 = 2305843009213693952 bytes
+	pipelineFn             pipeline.PipelineFunc
+	rParams                redundancy.IParams
+	parityChunkFn          redundancy.ParityChunkCallback
+	chunkCounters          []uint8 // counts the chunk references in intermediate chunks. key is the chunk level.
+	effectiveChunkCounters []uint8 // counts the effective chunk references in intermediate chunks. key is the chunk level.
+	maxChildrenChunks      uint8   // maximum number of chunk references in intermediate chunks.
 }

-func NewHashTrieWriter(chunkSize, branching, refLen int, pipelineFn pipeline.PipelineFunc) pipeline.ChainWriter {
-	return &hashTrieWriter{
-		cursors:    make([]int, 9),
-		buffer:     make([]byte, swarm.ChunkWithSpanSize*9*2), // double size as temp workaround for weak calculation of needed buffer space
-		branching:  branching,
-		chunkSize:  chunkSize,
-		refSize:    refLen,
-		fullChunk:  (refLen + swarm.SpanSize) * branching,
-		pipelineFn: pipelineFn,
+func NewHashTrieWriter(refLen int, rParams redundancy.IParams, pipelineFn pipeline.PipelineFunc) pipeline.ChainWriter {
+	h := &hashTrieWriter{
+		refSize:                refLen,
+		cursors:                make([]int, 9),
+		buffer:                 make([]byte, swarm.ChunkWithSpanSize*9*2), // double size as temp workaround for weak calculation of needed buffer space
+		rParams:                rParams,
+		pipelineFn:             pipelineFn,
+		chunkCounters:          make([]uint8, 9),
+		effectiveChunkCounters: make([]uint8, 9),
+		maxChildrenChunks:      uint8(rParams.MaxShards() + rParams.Parities(rParams.MaxShards())),
 	}
+	h.parityChunkFn = func(level int, span, address []byte) error {
+		return h.writeToIntermediateLevel(level, true, span, address, []byte{})
+	}
+
+	return h
 }

 // accepts writes of hashes from the previous writer in the chain, by definition these writes
@@ -47,30 +56,50 @@ func (h *hashTrieWriter) ChainWrite(p *pipeline.PipeWriteArgs) error {
 	oneRef := h.refSize + swarm.SpanSize
 	l := len(p.Span) + len(p.Ref) + len(p.Key)
-	if l%oneRef != 0 {
+	if l%oneRef != 0 || l == 0 {
 		return errInconsistentRefs
 	}
 	if h.full {
 		return errTrieFull
 	}
-	return h.writeToLevel(1, p.Span, p.Ref, p.Key)
+	if h.rParams.Level() == redundancy.NONE {
+		return h.writeToIntermediateLevel(1, false, p.Span, p.Ref, p.Key)
+	} else {
+		return h.writeToDataLevel(p.Span, p.Ref, p.Key, p.Data)
+	}
 }

-func (h *hashTrieWriter) writeToLevel(level int, span, ref, key []byte) error {
+func (h *hashTrieWriter) writeToIntermediateLevel(level int, parityChunk bool, span, ref, key []byte) error {
 	copy(h.buffer[h.cursors[level]:h.cursors[level]+len(span)], span)
 	h.cursors[level] += len(span)
 	copy(h.buffer[h.cursors[level]:h.cursors[level]+len(ref)], ref)
 	h.cursors[level] += len(ref)
 	copy(h.buffer[h.cursors[level]:h.cursors[level]+len(key)], key)
 	h.cursors[level] += len(key)
-	howLong := (h.refSize + swarm.SpanSize) * h.branching
-
-	if h.levelSize(level) == howLong {
-		return h.wrapFullLevel(level)
+
+	// update counters
+	if !parityChunk {
+		h.effectiveChunkCounters[level]++
+	}
+	h.chunkCounters[level]++
+	if h.chunkCounters[level] == 
h.maxChildrenChunks {
+		// at this point the erasure coded chunks have been written
+		err := h.wrapFullLevel(level)
+		return err
 	}
 	return nil
 }

+func (h *hashTrieWriter) writeToDataLevel(span, ref, key, data []byte) error {
+	// write data chunks to the level above
+	err := h.writeToIntermediateLevel(1, false, span, ref, key)
+	if err != nil {
+		return err
+	}
+
+	return h.rParams.ChunkWrite(0, data, h.parityChunkFn)
+}
+
 // wrapLevel wraps an existing level and writes the resulting hash to the following level
 // then truncates the current level data by shifting the cursors.
 // Steps are performed in the following order:
@@ -81,20 +110,36 @@ func (h *hashTrieWriter) writeToLevel(level int, span, ref, key []byte) error {
 // - get the hash that was created, append it one level above, and if necessary, wrap that level too
 // - remove already hashed data from buffer
 //
-// assumes that the function has been called when refsize+span*branching has been reached
+// assumes that h.chunkCounters[level] has reached h.maxChildrenChunks for a full chunk
+// or that redundancy.Encode was called in case of the rightmost chunks
 func (h *hashTrieWriter) wrapFullLevel(level int) error {
 	data := h.buffer[h.cursors[level+1]:h.cursors[level]]
 	sp := uint64(0)
 	var hashes []byte
-	for i := 0; i < len(data); i += h.refSize + 8 {
+	offset := 0
+	for i := uint8(0); i < h.effectiveChunkCounters[level]; i++ {
 		// sum up the spans of the level, then we need to bmt them and store it as a chunk
 		// then write the chunk address to the next level up
-		sp += binary.LittleEndian.Uint64(data[i : i+8])
-		hash := data[i+8 : i+h.refSize+8]
+		sp += binary.LittleEndian.Uint64(data[offset : offset+swarm.SpanSize])
+		offset += swarm.SpanSize
+		hash := data[offset : offset+h.refSize]
+		offset += h.refSize
+		hashes = append(hashes, hash...)
+	}
+	parities := 0
+	for offset < len(data) {
+		// we do not add the span of parity chunks to the total since it is gibberish
+		offset += swarm.SpanSize
+		hash := data[offset : offset+swarm.HashSize] // a parity reference always has hash length
+		offset += swarm.HashSize
 		hashes = append(hashes, hash...)
+		parities++
 	}
 	spb := make([]byte, 8)
 	binary.LittleEndian.PutUint64(spb, sp)
+	if parities > 0 {
+		redundancy.EncodeLevel(spb, h.rParams.Level())
+	}
 	hashes = append(spb, hashes...)
 	writer := h.pipelineFn()
 	args := pipeline.PipeWriteArgs{
@@ -105,7 +150,13 @@ func (h *hashTrieWriter) wrapFullLevel(level int) error {
 	if err != nil {
 		return err
 	}
-	err = h.writeToLevel(level+1, args.Span, args.Ref, args.Key)
+
+	err = h.writeToIntermediateLevel(level+1, false, args.Span, args.Ref, args.Key)
+	if err != nil {
+		return err
+	}
+
+	err = h.rParams.ChunkWrite(level, args.Data, h.parityChunkFn)
 	if err != nil {
 		return err
 	}
@@ -113,19 +164,13 @@ func (h *hashTrieWriter) wrapFullLevel(level int) error {
 	// this "truncates" the current level that was wrapped
 	// by setting the cursors to the cursors of one level above
 	h.cursors[level] = h.cursors[level+1]
+	h.chunkCounters[level], h.effectiveChunkCounters[level] = 0, 0
 	if level+1 == 8 {
 		h.full = true
 	}
 	return nil
 }

-func (h *hashTrieWriter) levelSize(level int) int {
-	if level == 8 {
-		return h.cursors[level]
-	}
-	return h.cursors[level] - h.cursors[level+1]
-}
-
 // Sum returns the Swarm merkle-root content-addressed hash
 // of an arbitrary-length binary data. 
// The algorithm it uses is as follows: @@ -142,25 +187,22 @@ func (h *hashTrieWriter) levelSize(level int) int { // - more than one hash, in which case we _do_ perform a hashing operation, appending the hash to // the next level func (h *hashTrieWriter) Sum() ([]byte, error) { - oneRef := h.refSize + swarm.SpanSize for i := 1; i < maxLevel; i++ { - l := h.levelSize(i) - if l%oneRef != 0 { - return nil, errInconsistentRefs - } + l := h.chunkCounters[i] switch { case l == 0: // level empty, continue to the next. continue - case l == h.fullChunk: + case l == h.maxChildrenChunks: // this case is possible and necessary due to the carry over // in the next switch case statement. normal writes done // through writeToLevel will automatically wrap a full level. + // erasure encoding call is not necessary since ElevateCarrierChunk solves that err := h.wrapFullLevel(i) if err != nil { return nil, err } - case l == oneRef: + case l == 1: // this cursor assignment basically means: // take the hash|span|key from this level, and append it to // the data of the next level. you may wonder how this works: @@ -175,20 +217,33 @@ func (h *hashTrieWriter) Sum() ([]byte, error) { // that might or might not have data. the eventual result is that the last // hash generated will always be carried over to the last level (8), then returned. h.cursors[i+1] = h.cursors[i] + // replace cached chunk to the level as well + err := h.rParams.ElevateCarrierChunk(i-1, h.parityChunkFn) + if err != nil { + return nil, err + } + // update counters, subtracting from current level is not necessary + h.effectiveChunkCounters[i+1]++ + h.chunkCounters[i+1]++ default: + // call erasure encoding before writing the last chunk on the level + err := h.rParams.Encode(i-1, h.parityChunkFn) + if err != nil { + return nil, err + } // more than 0 but smaller than chunk size - wrap the level to the one above it - err := h.wrapFullLevel(i) + err = h.wrapFullLevel(i) if err != nil { return nil, err } } } - levelLen := h.levelSize(8) - if levelLen != oneRef { + levelLen := h.chunkCounters[maxLevel] + if levelLen != 1 { return nil, errInconsistentRefs } // return the hash in the highest level, that's all we need - data := h.buffer[0:h.cursors[8]] - return data[8:], nil + data := h.buffer[0:h.cursors[maxLevel]] + return data[swarm.SpanSize:], nil } diff --git a/pkg/file/pipeline/hashtrie/hashtrie_test.go b/pkg/file/pipeline/hashtrie/hashtrie_test.go index 2078bbf18d1..beebd77cf74 100644 --- a/pkg/file/pipeline/hashtrie/hashtrie_test.go +++ b/pkg/file/pipeline/hashtrie/hashtrie_test.go @@ -5,15 +5,24 @@ package hashtrie_test import ( + "bytes" "context" "encoding/binary" "errors" "testing" + bmtUtils "github.com/ethersphere/bee/pkg/bmt" + "github.com/ethersphere/bee/pkg/cac" + "github.com/ethersphere/bee/pkg/encryption" + dec "github.com/ethersphere/bee/pkg/encryption/store" + "github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file/pipeline" "github.com/ethersphere/bee/pkg/file/pipeline/bmt" + enc "github.com/ethersphere/bee/pkg/file/pipeline/encryption" "github.com/ethersphere/bee/pkg/file/pipeline/hashtrie" + "github.com/ethersphere/bee/pkg/file/pipeline/mock" "github.com/ethersphere/bee/pkg/file/pipeline/store" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" "github.com/ethersphere/bee/pkg/swarm" ) @@ -38,9 +47,7 @@ func TestLevels(t *testing.T) { t.Parallel() var ( - branching = 4 - chunkSize = 128 - hashSize = 32 + hashSize = 32 ) // to create a level wrap we need to 
do branching^(level-1) writes
@@ -100,7 +107,7 @@ func TestLevels(t *testing.T) {
 				return bmt.NewBmtWriter(lsw)
 			}

-			ht := hashtrie.NewHashTrieWriter(chunkSize, branching, hashSize, pf)
+			ht := hashtrie.NewHashTrieWriter(hashSize, redundancy.New(0, false, pf), pf)

 			for i := 0; i < tc.writes; i++ {
 				a := &pipeline.PipeWriteArgs{Ref: addr.Bytes(), Span: span}
@@ -129,21 +136,31 @@ func TestLevels(t *testing.T) {
 	}
 }

+type redundancyMock struct {
+	redundancy.Params
+}
+
+func (r redundancyMock) MaxShards() int {
+	return 4
+}
+
 func TestLevels_TrieFull(t *testing.T) {
 	t.Parallel()

 	var (
-		branching = 4
-		chunkSize = 128
-		hashSize  = 32
-		writes    = 16384 // this is to get a balanced trie
-		s         = inmemchunkstore.New()
-		pf        = func() pipeline.ChainWriter {
+		hashSize = 32
+		writes   = 16384 // this is to get a balanced trie
+		s        = inmemchunkstore.New()
+		pf       = func() pipeline.ChainWriter {
 			lsw := store.NewStoreWriter(ctx, s, nil)
 			return bmt.NewBmtWriter(lsw)
 		}
+		r     = redundancy.New(0, false, pf)
+		rMock = &redundancyMock{
+			Params: *r,
+		}

-		ht = hashtrie.NewHashTrieWriter(chunkSize, branching, hashSize, pf)
+		ht = hashtrie.NewHashTrieWriter(hashSize, rMock, pf)
 	)

 	// to create a level wrap we need to do branching^(level-1) writes
@@ -176,17 +193,15 @@ func TestRegression(t *testing.T) {
 	t.Parallel()

 	var (
-		branching = 128
-		chunkSize = 4096
-		hashSize  = 32
-		writes    = 67100000 / 4096
-		span      = make([]byte, 8)
-		s         = inmemchunkstore.New()
-		pf        = func() pipeline.ChainWriter {
+		hashSize = 32
+		writes   = 67100000 / 4096
+		span     = make([]byte, 8)
+		s        = inmemchunkstore.New()
+		pf       = func() pipeline.ChainWriter {
 			lsw := store.NewStoreWriter(ctx, s, nil)
 			return bmt.NewBmtWriter(lsw)
 		}
-		ht = hashtrie.NewHashTrieWriter(chunkSize, branching, hashSize, pf)
+		ht = hashtrie.NewHashTrieWriter(hashSize, redundancy.New(0, false, pf), pf)
 	)

 	binary.LittleEndian.PutUint64(span, 4096)
@@ -213,3 +228,122 @@ func TestRegression(t *testing.T) {
 		t.Fatalf("want span %d got %d", writes*4096, sp)
 	}
 }
+
+// TestRedundancy uses the erasure coding library and checks the carrier chunk handling and the modified span in the intermediate chunk
+func TestRedundancy(t *testing.T) {
+	t.Parallel()
+	// chunks need to have data so that it will not throw error on redundancy caching
+	ch, err := cac.New(make([]byte, swarm.ChunkSize))
+	if err != nil {
+		t.Fatal(err)
+	}
+	chData := ch.Data()
+	chSpan := chData[:swarm.SpanSize]
+	chAddr := ch.Address().Bytes()
+
+	// test logic assumes a simple 2 level chunk tree with carrier chunk
+	for _, tc := range []struct {
+		desc       string
+		level      redundancy.Level
+		encryption bool
+		writes     int
+		parities   int
+	}{
+		{
+			desc:       "redundancy write for not encrypted data",
+			level:      redundancy.INSANE,
+			encryption: false,
+			writes:     98, // 97 chunk references fit into one chunk + 1 carrier
+			parities:   38, // 31 (full ch) + 7 (2 ref)
+		},
+		{
+			desc:       "redundancy write for encrypted data",
+			level:      redundancy.PARANOID,
+			encryption: true,
+			writes:     21,  // 20 encrypted chunk references fit into one chunk + 1 carrier
+			parities:   118, // 88 (full ch) + 30 (2 ref)
+		},
+	} {
+		tc := tc
+		t.Run(tc.desc, func(t *testing.T) {
+			t.Parallel()
+
+			s := inmemchunkstore.New()
+			intermediateChunkCounter := mock.NewChainWriter()
+			parityChunkCounter := mock.NewChainWriter()
+			pf := func() pipeline.ChainWriter {
+				lsw := store.NewStoreWriter(ctx, s, intermediateChunkCounter)
+				return bmt.NewBmtWriter(lsw)
+			}
+			if tc.encryption {
+				pf = func() pipeline.ChainWriter {
+					lsw := store.NewStoreWriter(ctx, s, intermediateChunkCounter)
+					b := 
bmt.NewBmtWriter(lsw)
+					return enc.NewEncryptionWriter(encryption.NewChunkEncrypter(), b)
+				}
+			}
+			ppf := func() pipeline.ChainWriter {
+				lsw := store.NewStoreWriter(ctx, s, parityChunkCounter)
+				return bmt.NewBmtWriter(lsw)
+			}
+
+			var key []byte
+			hashSize := swarm.HashSize
+			if tc.encryption {
+				hashSize *= 2
+				key = addr.Bytes()
+			}
+
+			r := redundancy.New(tc.level, tc.encryption, ppf)
+			ht := hashtrie.NewHashTrieWriter(hashSize, r, pf)
+
+			for i := 0; i < tc.writes; i++ {
+				a := &pipeline.PipeWriteArgs{Data: chData, Span: chSpan, Ref: chAddr, Key: key}
+				err := ht.ChainWrite(a)
+				if err != nil {
+					t.Fatal(err)
+				}
+			}
+
+			ref, err := ht.Sum()
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			// sanity check for the test samples
+			if tc.parities != parityChunkCounter.ChainWriteCalls() {
+				t.Errorf("generated parities should be %d. Got: %d", tc.parities, parityChunkCounter.ChainWriteCalls())
+			}
+			if intermediateChunkCounter.ChainWriteCalls() != 2 { // root chunk and the chunk which was written before carrierChunk movement
+				t.Errorf("effective chunks should be 2. Got: %d", intermediateChunkCounter.ChainWriteCalls())
+			}
+
+			rootch, err := s.Get(ctx, swarm.NewAddress(ref[:swarm.HashSize]))
+			if err != nil {
+				t.Fatal(err)
+			}
+			chData := rootch.Data()
+			if tc.encryption {
+				chData, err = dec.DecryptChunkData(chData, ref[swarm.HashSize:])
+				if err != nil {
+					t.Fatal(err)
+				}
+			}
+
+			// span check
+			level, sp := redundancy.DecodeSpan(chData[:swarm.SpanSize])
+			expectedSpan := bmtUtils.LengthToSpan(int64(tc.writes * swarm.ChunkSize))
+			if !bytes.Equal(expectedSpan, sp) {
+				t.Fatalf("want span %d got %d", expectedSpan, sp)
+			}
+			if level != tc.level {
+				t.Fatalf("encoded level differs from the uploaded one %d. Got: %d", tc.level, level)
+			}
+			expectedParities := tc.parities - r.Parities(r.MaxShards())
+			_, parity := file.ReferenceCount(bmtUtils.LengthFromSpan(sp), level, tc.encryption)
+			if expectedParities != parity {
+				t.Fatalf("want parity %d got %d", expectedParities, parity)
+			}
+		})
+	}
+}
diff --git a/pkg/file/redundancy/export_test.go b/pkg/file/redundancy/export_test.go
new file mode 100644
index 00000000000..492c956ec88
--- /dev/null
+++ b/pkg/file/redundancy/export_test.go
@@ -0,0 +1,8 @@
+// Copyright 2023 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package redundancy
+
+var SetErasureEncoder = setErasureEncoder
+var GetErasureEncoder = getErasureEncoder
diff --git a/pkg/file/redundancy/level.go b/pkg/file/redundancy/level.go
new file mode 100644
index 00000000000..bdd7226af72
--- /dev/null
+++ b/pkg/file/redundancy/level.go
@@ -0,0 +1,117 @@
+// Copyright 2023 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file. 
+ +package redundancy + +import "github.com/ethersphere/bee/pkg/swarm" + +type Level uint8 + +const ( + NONE Level = iota + MEDIUM + STRONG + INSANE + PARANOID +) + +const maxLevel = 8 + +// GetParities returns number of parities based on appendix F table 5 +func (l Level) GetParities(shards int) int { + switch l { + case MEDIUM: + return mediumEt.getParities(shards) + case STRONG: + return strongEt.getParities(shards) + case INSANE: + return insaneEt.getParities(shards) + case PARANOID: + return paranoidEt.getParities(shards) + default: + return 0 + } +} + +// GetMaxShards returns back the maximum number of effective data chunks +func (l Level) GetMaxShards() int { + p := l.GetParities(swarm.Branches) + return swarm.Branches - p +} + +// GetEncParities returns number of parities for encrypted chunks based on appendix F table 6 +func (l Level) GetEncParities(shards int) int { + switch l { + case MEDIUM: + return encMediumEt.getParities(shards) + case STRONG: + return encStrongEt.getParities(shards) + case INSANE: + return encInsaneEt.getParities(shards) + case PARANOID: + return encParanoidEt.getParities(shards) + default: + return 0 + } +} + +// GetMaxEncShards returns back the maximum number of effective encrypted data chunks +func (l Level) GetMaxEncShards() int { + p := l.GetEncParities(swarm.EncryptedBranches) + return (swarm.Branches - p) / 2 +} + +// TABLE INITS + +var mediumEt = newErasureTable( + []int{94, 68, 46, 28, 14, 5, 1}, + []int{9, 8, 7, 6, 5, 4, 3}, +) +var encMediumEt = newErasureTable( + []int{47, 34, 23, 14, 7, 2}, + []int{9, 8, 7, 6, 5, 4}, +) + +var strongEt = newErasureTable( + []int{104, 95, 86, 77, 69, 61, 53, 46, 39, 32, 26, 20, 15, 10, 6, 3, 1}, + []int{21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5}, +) +var encStrongEt = newErasureTable( + []int{52, 47, 43, 38, 34, 30, 26, 23, 19, 16, 13, 10, 7, 5, 3, 1}, + []int{21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6}, +) + +var insaneEt = newErasureTable( + []int{92, 87, 82, 77, 73, 68, 63, 59, 54, 50, 45, 41, 37, 33, 29, 26, 22, 19, 16, 13, 10, 8, 5, 3, 2, 1}, + []int{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6}, +) +var encInsaneEt = newErasureTable( + []int{46, 43, 41, 38, 36, 34, 31, 29, 27, 25, 22, 20, 18, 16, 14, 13, 11, 9, 8, 6, 5, 4, 2, 1}, + []int{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 7}, +) + +var paranoidEt = newErasureTable( + []int{ + 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, + 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, + 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, + 7, 6, 5, 4, 3, 2, 1, + }, + []int{ + 90, 88, 87, 85, 84, 82, 81, 79, 77, 76, + 74, 72, 71, 69, 67, 66, 64, 62, 60, 59, + 57, 55, 53, 51, 49, 48, 46, 44, 41, 39, + 37, 35, 32, 30, 27, 24, 20, + }, +) +var encParanoidEt = newErasureTable( + []int{ + 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, + 8, 7, 6, 5, 4, 3, 2, 1, + }, + []int{ + 88, 85, 82, 79, 76, 72, 69, 66, 62, 59, + 55, 51, 48, 44, 39, 35, 30, 24, + }, +) diff --git a/pkg/file/redundancy/redundancy.go b/pkg/file/redundancy/redundancy.go new file mode 100644 index 00000000000..3fe1bca2c43 --- /dev/null +++ b/pkg/file/redundancy/redundancy.go @@ -0,0 +1,204 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
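Before the encoder implementation below, a quick illustration of what the parity tables above yield (a sketch using only the exported Level API and swarm constants; the printed numbers follow from the mediumEt and insaneEt tables):

// tables_sketch.go - illustrative lookups against the appendix F tables above
package main

import (
	"fmt"

	"github.com/ethersphere/bee/pkg/file/redundancy"
	"github.com/ethersphere/bee/pkg/swarm"
)

func main() {
	// MEDIUM: 128 potential shards fall into the 94+ bucket of mediumEt -> 9 parities,
	// leaving 128-9 = 119 effective data shards per full intermediate chunk.
	fmt.Println(redundancy.MEDIUM.GetParities(swarm.Branches)) // 9
	fmt.Println(redundancy.MEDIUM.GetMaxShards())              // 119
	// INSANE: 31 parities for a full chunk -> 97 data shards, matching the
	// "97 chunk references fit into one chunk" comment in the hashtrie tests.
	fmt.Println(redundancy.INSANE.GetMaxShards()) // 97
}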
+ +package redundancy + +import ( + "fmt" + + "github.com/ethersphere/bee/pkg/file/pipeline" + "github.com/ethersphere/bee/pkg/swarm" + "github.com/klauspost/reedsolomon" +) + +// ParityChunkCallback is called when a new parity chunk has been created +type ParityChunkCallback func(level int, span, address []byte) error + +type IParams interface { + MaxShards() int + Level() Level + Parities(int) int + ChunkWrite(int, []byte, ParityChunkCallback) error + ElevateCarrierChunk(int, ParityChunkCallback) error + Encode(int, ParityChunkCallback) error +} + +type ErasureEncoder interface { + Encode([][]byte) error +} + +var erasureEncoderFunc = func(shards, parities int) (ErasureEncoder, error) { + return reedsolomon.New(shards, parities) +} + +// setErasureEncoder changes erasureEncoderFunc to a new erasureEncoder facade +// +// used for testing +func setErasureEncoder(f func(shards, parities int) (ErasureEncoder, error)) { + erasureEncoderFunc = f +} + +// getErasureEncoder returns erasureEncoderFunc +// +// used for testing +func getErasureEncoder() func(shards, parities int) (ErasureEncoder, error) { + return erasureEncoderFunc +} + +type Params struct { + level Level + pipeLine pipeline.PipelineFunc + buffer [][][]byte // keeps bytes of chunks on each level for producing erasure coded data; [levelIndex][branchIndex][byteIndex] + cursor []int // index of the current buffered chunk in Buffer. this is basically the latest used branchIndex. + maxShards int // number of chunks after which the parity encode function should be called + maxParity int // number of parity chunks if maxShards has been reached for erasure coding + encryption bool +} + +func New(level Level, encryption bool, pipeLine pipeline.PipelineFunc) *Params { + maxShards := 0 + maxParity := 0 + if encryption { + maxShards = level.GetMaxEncShards() + maxParity = level.GetParities(swarm.EncryptedBranches) + } else { + maxShards = level.GetMaxShards() + maxParity = level.GetParities(swarm.BmtBranches) + } + // init dataBuffer for erasure coding + rsChunkLevels := 0 + if level != NONE { + rsChunkLevels = maxLevel + } + Buffer := make([][][]byte, rsChunkLevels) + for i := 0; i < rsChunkLevels; i++ { + Buffer[i] = make([][]byte, swarm.BmtBranches) // 128 long always because buffer varies at encrypted chunks + } + + return &Params{ + level: level, + pipeLine: pipeLine, + buffer: Buffer, + cursor: make([]int, 9), + maxShards: maxShards, + maxParity: maxParity, + encryption: encryption, + } +} + +// ACCESSORS + +func (p *Params) MaxShards() int { + return p.maxShards +} + +func (p *Params) Level() Level { + return p.level +} + +// METHODS + +func (p *Params) Parities(shards int) int { + if p.encryption { + return p.level.GetEncParities(shards) + } + return p.level.GetParities(shards) +} + +// ChunkWrite caches the chunk data on the given chunk level and if it is full then it calls Encode +func (p *Params) ChunkWrite(chunkLevel int, data []byte, callback ParityChunkCallback) error { + if p.level == NONE { + return nil + } + if len(data) != swarm.ChunkWithSpanSize { + zeros := make([]byte, swarm.ChunkWithSpanSize-len(data)) + data = append(data, zeros...) 
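+		// note (added for clarity): the zero padding above is needed because the
+		// underlying Reed-Solomon encoder requires all shards to be of equal
+		// length, so short (e.g. rightmost) chunks are padded up to
+		// swarm.ChunkWithSpanSize before being cached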
+	}
+
+	return p.chunkWrite(chunkLevel, data, callback)
+}
+
+// chunkWrite caches the chunk data on the given chunk level and, when the level is full, calls encode
+func (p *Params) chunkWrite(chunkLevel int, data []byte, callback ParityChunkCallback) error {
+	// append chunk to the buffer
+	p.buffer[chunkLevel][p.cursor[chunkLevel]] = data
+	p.cursor[chunkLevel]++
+
+	// add parity chunk if it is necessary
+	if p.cursor[chunkLevel] == p.maxShards {
+		// append erasure coded data
+		return p.encode(chunkLevel, callback)
+	}
+	return nil
+}
+
+// Encode produces and stores parity chunks that are also passed back to the caller
+func (p *Params) Encode(chunkLevel int, callback ParityChunkCallback) error {
+	if p.level == NONE || p.cursor[chunkLevel] == 0 {
+		return nil
+	}
+
+	return p.encode(chunkLevel, callback)
+}
+
+func (p *Params) encode(chunkLevel int, callback ParityChunkCallback) error {
+	shards := p.cursor[chunkLevel]
+	parities := p.Parities(shards)
+
+	n := shards + parities
+	// initialize the erasure encoder and calculate the parity chunks
+	enc, err := erasureEncoderFunc(shards, parities)
+	if err != nil {
+		return err
+	}
+	// make parity data
+	pz := len(p.buffer[chunkLevel][0])
+	for i := shards; i < n; i++ {
+		p.buffer[chunkLevel][i] = make([]byte, pz)
+	}
+	err = enc.Encode(p.buffer[chunkLevel][:n])
+	if err != nil {
+		return err
+	}
+	// store and pass newly created parity chunks
+	for i := shards; i < n; i++ {
+		chunkData := p.buffer[chunkLevel][i]
+		span := chunkData[:swarm.SpanSize]
+
+		// store the parity chunk
+		writer := p.pipeLine()
+		args := pipeline.PipeWriteArgs{
+			Data: chunkData,
+			Span: span,
+		}
+		err = writer.ChainWrite(&args)
+		if err != nil {
+			return err
+		}
+
+		// write the parity chunk reference to the level above
+		err = callback(chunkLevel+1, span, args.Ref)
+		if err != nil {
+			return err
+		}
+	}
+	// reset cursor of dataBuffer in case it was a full chunk
+	p.cursor[chunkLevel] = 0
+
+	return nil
+}
+
+// ElevateCarrierChunk moves the last orphan chunk to the level above, where it can fit and where there are other chunks as well.
+func (p *Params) ElevateCarrierChunk(chunkLevel int, callback ParityChunkCallback) error {
+	if p.level == NONE {
+		return nil
+	}
+	if p.cursor[chunkLevel] != 1 {
+		return fmt.Errorf("redundancy: cannot elevate carrier chunk because it is not the only chunk on the level. It has %d chunks", p.cursor[chunkLevel])
+	}
+
+	// not necessary to update the current level since we will not work with it anymore
+	return p.chunkWrite(chunkLevel+1, p.buffer[chunkLevel][p.cursor[chunkLevel]-1], callback)
+}
diff --git a/pkg/file/redundancy/redundancy_test.go b/pkg/file/redundancy/redundancy_test.go
new file mode 100644
index 00000000000..a968cb43386
--- /dev/null
+++ b/pkg/file/redundancy/redundancy_test.go
@@ -0,0 +1,158 @@
+// Copyright 2023 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file. 
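Before the tests, a rough usage sketch of Params (illustrative only: the pipeline factory pf and the chunk slice are assumed inputs here, not fixed API; in the codebase this wiring lives in the hashtrie writer):

// parity_sketch.go - driving Params directly, under the assumptions above
package paritysketch

import (
	"github.com/ethersphere/bee/pkg/file/pipeline"
	"github.com/ethersphere/bee/pkg/file/redundancy"
)

func writeWithParities(pf pipeline.PipelineFunc, chunks [][]byte) error {
	p := redundancy.New(redundancy.MEDIUM, false, pf)
	cb := func(level int, span, address []byte) error {
		return nil // every freshly created parity chunk for level+1 arrives here
	}
	for _, ch := range chunks { // each ch holds the span+data bytes of one chunk
		// parities are flushed automatically whenever maxShards chunks are buffered
		if err := p.ChunkWrite(0, ch, cb); err != nil {
			return err
		}
	}
	// the rightmost, possibly partial batch must be flushed explicitly
	return p.Encode(0, cb)
}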
+ +package redundancy_test + +import ( + "crypto/rand" + "fmt" + "io" + "sync" + "testing" + + "github.com/ethersphere/bee/pkg/file/pipeline" + "github.com/ethersphere/bee/pkg/file/pipeline/bmt" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/swarm" +) + +// MOCK ENCODER + +type mockEncoder struct { + shards, parities int +} + +func newMockEncoder(shards, parities int) (redundancy.ErasureEncoder, error) { + return &mockEncoder{ + shards: shards, + parities: parities, + }, nil +} + +// Encode makes MSB of span equal to data +func (m *mockEncoder) Encode(buffer [][]byte) error { + // writes parity data + indicatedValue := 0 + for i := m.shards; i < m.shards+m.parities; i++ { + data := make([]byte, 32) + data[swarm.SpanSize-1], data[swarm.SpanSize] = uint8(indicatedValue), uint8(indicatedValue) + buffer[i] = data + indicatedValue++ + } + return nil +} + +// PARITY CHAIN WRITER + +type ParityChainWriter struct { + sync.Mutex + chainWriteCalls int + sumCalls int + validCalls []bool +} + +func NewParityChainWriter() *ParityChainWriter { + return &ParityChainWriter{} +} + +// ACCESSORS + +func (c *ParityChainWriter) ChainWriteCalls() int { + c.Lock() + defer c.Unlock() + return c.chainWriteCalls +} +func (c *ParityChainWriter) SumCalls() int { c.Lock(); defer c.Unlock(); return c.sumCalls } + +// METHODS + +func (c *ParityChainWriter) ChainWrite(args *pipeline.PipeWriteArgs) error { + c.Lock() + defer c.Unlock() + valid := args.Span[len(args.Span)-1] == args.Data[len(args.Span)] && args.Data[len(args.Span)] == byte(c.chainWriteCalls) + c.chainWriteCalls++ + c.validCalls = append(c.validCalls, valid) + return nil +} +func (c *ParityChainWriter) Sum() ([]byte, error) { + c.Lock() + defer c.Unlock() + c.sumCalls++ + return nil, nil +} + +func TestEncode(t *testing.T) { + t.Parallel() + // initializes mockEncoder -> creates shard chunks -> redundancy.chunkWrites -> call encode + erasureEncoder := redundancy.GetErasureEncoder() + defer func() { + redundancy.SetErasureEncoder(erasureEncoder) + }() + redundancy.SetErasureEncoder(newMockEncoder) + + // test on the data level + for _, level := range []redundancy.Level{redundancy.MEDIUM, redundancy.STRONG, redundancy.INSANE, redundancy.PARANOID} { + for _, encrypted := range []bool{false, true} { + maxShards := level.GetMaxShards() + if encrypted { + maxShards = level.GetMaxEncShards() + } + for shardCount := 1; shardCount <= maxShards; shardCount++ { + t.Run(fmt.Sprintf("redundancy level %d is checked with %d shards", level, shardCount), func(t *testing.T) { + parityChainWriter := NewParityChainWriter() + ppf := func() pipeline.ChainWriter { + return bmt.NewBmtWriter(parityChainWriter) + } + params := redundancy.New(level, encrypted, ppf) + // checks parity pipelinecalls are valid + + parityCount := 0 + parityCallback := func(level int, span, address []byte) error { + parityCount++ + return nil + } + + for i := 0; i < shardCount; i++ { + buffer := make([]byte, 32) + _, err := io.ReadFull(rand.Reader, buffer) + if err != nil { + t.Error(err) + } + err = params.ChunkWrite(0, buffer, parityCallback) + if err != nil { + t.Error(err) + } + } + if shardCount != maxShards { + // encode should be called automatically when reaching maxshards + err := params.Encode(0, parityCallback) + if err != nil { + t.Error(err) + } + } + + // CHECKS + + if parityCount != parityChainWriter.chainWriteCalls { + t.Fatalf("parity callback was called %d times meanwhile chainwrite was called %d times", parityCount, 
parityChainWriter.chainWriteCalls)
+					}
+
+					expectedParityCount := params.Level().GetParities(shardCount)
+					if encrypted {
+						expectedParityCount = params.Level().GetEncParities(shardCount)
+					}
+					if parityCount != expectedParityCount {
+						t.Fatalf("parity callback was called %d times but the expected parity count is %d", parityCount, expectedParityCount)
+					}
+
+					for i, validCall := range parityChainWriter.validCalls {
+						if !validCall {
+							t.Fatalf("parity chunk data is wrong at parity index %d", i)
+						}
+					}
+				})
+			}
+		}
+	}
+}
diff --git a/pkg/file/redundancy/span.go b/pkg/file/redundancy/span.go
new file mode 100644
index 00000000000..eb71286fa43
--- /dev/null
+++ b/pkg/file/redundancy/span.go
@@ -0,0 +1,34 @@
+// Copyright 2023 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package redundancy
+
+import (
+	"github.com/ethersphere/bee/pkg/swarm"
+)
+
+// EncodeLevel encodes the used redundancy level for the upload into the span while keeping the real byte count of the chunk.
+// assumes span is LittleEndian
+func EncodeLevel(span []byte, level Level) {
+	// set the parity flag and the level in the most significant byte
+	span[swarm.SpanSize-1] = uint8(level) | 1<<7 // p + 128
+}
+
+// DecodeSpan decodes the used redundancy level from the span while keeping the real byte count of the chunk.
+// assumes span is LittleEndian
+func DecodeSpan(span []byte) (Level, []byte) {
+	spanCopy := make([]byte, swarm.SpanSize)
+	copy(spanCopy, span)
+	if !IsLevelEncoded(spanCopy) {
+		return 0, spanCopy
+	}
+	pByte := spanCopy[swarm.SpanSize-1]
+	return Level(pByte & ((1 << 7) - 1)), append(spanCopy[:swarm.SpanSize-1], 0)
+}
+
+// IsLevelEncoded checks whether the redundancy level is encoded in the span
+// assumes span is LittleEndian
+func IsLevelEncoded(span []byte) bool {
+	return span[swarm.SpanSize-1] > 128
+}
diff --git a/pkg/file/redundancy/table.go b/pkg/file/redundancy/table.go
new file mode 100644
index 00000000000..a88505da2b1
--- /dev/null
+++ b/pkg/file/redundancy/table.go
@@ -0,0 +1,52 @@
+// Copyright 2023 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file. 
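A round-trip through the span helpers above behaves as follows (a small sketch; the byte values follow directly from the definitions of EncodeLevel and DecodeSpan):

// span_sketch.go - illustrative round-trip of the redundancy level span encoding
package main

import (
	"encoding/binary"
	"fmt"

	"github.com/ethersphere/bee/pkg/file/redundancy"
	"github.com/ethersphere/bee/pkg/swarm"
)

func main() {
	span := make([]byte, swarm.SpanSize)
	binary.LittleEndian.PutUint64(span, 4096)

	redundancy.EncodeLevel(span, redundancy.STRONG) // MSB becomes 2|128 = 130
	fmt.Println(redundancy.IsLevelEncoded(span))    // true

	level, plain := redundancy.DecodeSpan(span)
	fmt.Println(level)                             // 2 (STRONG)
	fmt.Println(binary.LittleEndian.Uint64(plain)) // 4096, with the MSB cleared
}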
+
+package redundancy
+
+type erasureTable struct {
+	shards   []int
+	parities []int
+}
+
+// newErasureTable initializes a shards<->parities table
+//
+//	the value order must be strictly descending in both arrays
+//	example usage:
+//		shards := []int{94, 68, 46, 28, 14, 5, 1}
+//		parities := []int{9, 8, 7, 6, 5, 4, 3}
+//		var et = newErasureTable(shards, parities)
+func newErasureTable(shards, parities []int) *erasureTable {
+	if len(shards) != len(parities) {
+		panic("redundancy table: shards and parities arrays must be of equal size")
+	}
+
+	maxShards := shards[0]
+	maxParities := parities[0]
+	for k := 1; k < len(shards); k++ {
+		s := shards[k]
+		if maxShards <= s {
+			panic("redundancy table: shards should be in strictly descending order")
+		}
+		p := parities[k]
+		if maxParities <= p {
+			panic("redundancy table: parities should be in strictly descending order")
+		}
+		maxShards, maxParities = s, p
+	}
+
+	return &erasureTable{
+		shards:   shards,
+		parities: parities,
+	}
+}
+
+// getParities returns the optimal parity count for a given number of shards
+func (et *erasureTable) getParities(maxShards int) int {
+	for k, s := range et.shards {
+		if maxShards >= s {
+			return et.parities[k]
+		}
+	}
+	return 0
+}
diff --git a/pkg/file/utils.go b/pkg/file/utils.go
new file mode 100644
index 00000000000..298a6578447
--- /dev/null
+++ b/pkg/file/utils.go
@@ -0,0 +1,60 @@
+// Copyright 2023 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package file
+
+import (
+	"github.com/ethersphere/bee/pkg/file/redundancy"
+	"github.com/ethersphere/bee/pkg/swarm"
+)
+
+// ReferenceCount brute-forces the data shard count, from which the parity count of a subtree can be identified as well
+// assumes span > swarm.ChunkSize
+// returns the data and parity shard counts
+func ReferenceCount(span uint64, level redundancy.Level, encryptedChunk bool) (int, int) {
+	// assume we have a trie of size `span` then we can assume that all of
+	// the forks except for the last one on the right are of equal size
+	// this is due to how the splitter wraps levels.
+
+	// first the algorithm will search for a BMT level where span can be included
+	// then identify how much data one reference can hold on that level
+	// then count how many references can satisfy span
+	// and finally how many parity shards should be on that level
+	maxShards := level.GetMaxShards()
+	if encryptedChunk {
+		maxShards = level.GetMaxEncShards()
+	}
+	var (
+		branching  = uint64(maxShards) // branching factor is how many data shard references can fit into one intermediate chunk
+		branchSize = uint64(swarm.ChunkSize)
+	)
+	// search for branch level big enough to include span
+	branchLevel := 1
+	for {
+		if branchSize >= span {
+			break
+		}
+		branchSize *= branching
+		branchLevel++
+	}
+	// span in one full reference
+	referenceSize := uint64(swarm.ChunkSize)
+	// referenceSize = branching ** (branchLevel - 1)
+	for i := 1; i < branchLevel-1; i++ {
+		referenceSize *= branching
+	}
+
+	dataShardAddresses := 1
+	spanOffset := referenceSize
+	for spanOffset < span {
+		spanOffset += referenceSize
+		dataShardAddresses++
+	}
+
+	parityAddresses := level.GetParities(dataShardAddresses)
+	if encryptedChunk {
+		parityAddresses = level.GetEncParities(dataShardAddresses)
+	}
+
+	return dataShardAddresses, parityAddresses
+}
diff --git a/pkg/steward/steward_test.go b/pkg/steward/steward_test.go
index 6644c86d69b..3836f75dcfd 100644
--- a/pkg/steward/steward_test.go
+++ b/pkg/steward/steward_test.go
@@ -44,7 +44,7 @@ func TestSteward(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	pipe := builder.NewPipelineBuilder(ctx, chunkStore, false)
+	pipe := builder.NewPipelineBuilder(ctx, chunkStore, false, 0)
 	addr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data))
 	if err != nil {
 		t.Fatal(err)
diff --git a/pkg/traversal/traversal_test.go b/pkg/traversal/traversal_test.go
index dc8a78fcb53..75b5f925283 100644
--- a/pkg/traversal/traversal_test.go
+++ b/pkg/traversal/traversal_test.go
@@ -161,7 +161,7 @@ func TestTraversalBytes(t *testing.T) {
 			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 			defer cancel()
 
-			pipe := builder.NewPipelineBuilder(ctx, storerMock, false)
+			pipe := builder.NewPipelineBuilder(ctx, storerMock, false, 0)
 			address, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data))
 			if err != nil {
 				t.Fatal(err)
@@ -256,7 +256,7 @@ func TestTraversalFiles(t *testing.T) {
 			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 			defer cancel()
 
-			pipe := builder.NewPipelineBuilder(ctx, storerMock, false)
+			pipe := builder.NewPipelineBuilder(ctx, storerMock, false, 0)
 			fr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data))
 			if err != nil {
 				t.Fatal(err)
@@ -430,7 +430,7 @@ func TestTraversalManifest(t *testing.T) {
 			for _, f := range tc.files {
 				data := generateSample(f.size)
 
-				pipe := builder.NewPipelineBuilder(ctx, storerMock, false)
+				pipe := builder.NewPipelineBuilder(ctx, storerMock, false, 0)
 				fr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data))
 				if err != nil {
 					t.Fatal(err)
@@ -506,6 +506,6 @@ func TestTraversalSOC(t *testing.T) {
 
 func pipelineFactory(s storage.Putter, encrypt bool) func() pipeline.Interface {
 	return func() pipeline.Interface {
-		return builder.NewPipelineBuilder(context.Background(), s, encrypt)
+		return builder.NewPipelineBuilder(context.Background(), s, encrypt, 0)
 	}
 }

From efc54cd40a28e7f12b32efa00ac1777b86b89369 Mon Sep 17 00:00:00 2001
From: nugaon <50576770+nugaon@users.noreply.github.com>
Date: Mon, 4 Dec 2023 16:23:20 +0100
Subject: [PATCH 02/23] feat: erasure decoder (#4448)

---
 pkg/api/bzz.go                              |   7 +-
pkg/api/bzz_test.go | 2 +- pkg/api/dirs.go | 2 +- pkg/api/feed.go | 2 +- pkg/api/pin.go | 2 +- pkg/file/addresses/addresses_getter_test.go | 2 +- pkg/file/file_test.go | 2 +- pkg/file/joiner/joiner.go | 135 +++++--- pkg/file/joiner/joiner_test.go | 146 ++++++++- pkg/file/loadsave/loadsave.go | 6 +- pkg/file/loadsave/loadsave_test.go | 2 +- pkg/file/pipeline/hashtrie/hashtrie_test.go | 62 ++-- pkg/file/redundancy/getter/getter.go | 333 ++++++++++++++++++++ pkg/file/redundancy/getter/getter_test.go | 217 +++++++++++++ pkg/file/redundancy/level.go | 52 ++- pkg/file/redundancy/table.go | 12 + pkg/file/utils.go | 37 +++ pkg/node/bootstrap.go | 2 +- pkg/node/node.go | 2 +- pkg/steward/steward.go | 6 +- pkg/steward/steward_test.go | 5 +- pkg/storer/epoch_migration.go | 1 + pkg/traversal/traversal.go | 13 +- pkg/traversal/traversal_test.go | 12 +- 24 files changed, 944 insertions(+), 118 deletions(-) create mode 100644 pkg/file/redundancy/getter/getter.go create mode 100644 pkg/file/redundancy/getter/getter_test.go diff --git a/pkg/api/bzz.go b/pkg/api/bzz.go index b1605bae90d..a479c6829f3 100644 --- a/pkg/api/bzz.go +++ b/pkg/api/bzz.go @@ -9,7 +9,6 @@ import ( "encoding/hex" "errors" "fmt" - "github.com/ethersphere/bee/pkg/topology" "net/http" "path" "path/filepath" @@ -17,6 +16,8 @@ import ( "strings" "time" + "github.com/ethersphere/bee/pkg/topology" + "github.com/ethereum/go-ethereum/common" "github.com/ethersphere/bee/pkg/feeds" "github.com/ethersphere/bee/pkg/file/joiner" @@ -175,7 +176,7 @@ func (s *Service) fileUploadHandler( } factory := requestPipelineFactory(ctx, putter, encrypt, rLevel) - l := loadsave.New(s.storer.ChunkStore(), factory) + l := loadsave.New(s.storer.ChunkStore(), s.storer.Cache(), factory) m, err := manifest.NewDefaultManifest(l, encrypt) if err != nil { @@ -465,7 +466,7 @@ func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *h if headers.Cache != nil { cache = *headers.Cache } - reader, l, err := joiner.New(r.Context(), s.storer.Download(cache), reference) + reader, l, err := joiner.New(r.Context(), s.storer.Download(cache), s.storer.Cache(), reference) if err != nil { if errors.Is(err, storage.ErrNotFound) || errors.Is(err, topology.ErrNotFound) { logger.Debug("api download: not found ", "address", reference, "error", err) diff --git a/pkg/api/bzz_test.go b/pkg/api/bzz_test.go index 731cfcb567f..e6d52d8acd3 100644 --- a/pkg/api/bzz_test.go +++ b/pkg/api/bzz_test.go @@ -554,7 +554,7 @@ func TestFeedIndirection(t *testing.T) { t.Fatal(err) } m, err := manifest.NewDefaultManifest( - loadsave.New(storer.ChunkStore(), pipelineFactory(storer.Cache(), false, 0)), + loadsave.New(storer.ChunkStore(), storer.Cache(), pipelineFactory(storer.Cache(), false, 0)), false, ) if err != nil { diff --git a/pkg/api/dirs.go b/pkg/api/dirs.go index 771787a24a9..00b530b9137 100644 --- a/pkg/api/dirs.go +++ b/pkg/api/dirs.go @@ -132,7 +132,7 @@ func storeDir( loggerV1 := logger.V(1).Build() p := requestPipelineFn(putter, encrypt, rLevel) - ls := loadsave.New(getter, requestPipelineFactory(ctx, putter, encrypt, rLevel)) + ls := loadsave.New(getter, putter, requestPipelineFactory(ctx, putter, encrypt, rLevel)) dirManifest, err := manifest.NewDefaultManifest(ls, encrypt) if err != nil { diff --git a/pkg/api/feed.go b/pkg/api/feed.go index f3a3413a426..23961e6d5fc 100644 --- a/pkg/api/feed.go +++ b/pkg/api/feed.go @@ -196,7 +196,7 @@ func (s *Service) feedPostHandler(w http.ResponseWriter, r *http.Request) { logger: logger, } - l := loadsave.New(s.storer.ChunkStore(), 
requestPipelineFactory(r.Context(), putter, false, 0)) + l := loadsave.New(s.storer.ChunkStore(), s.storer.Cache(), requestPipelineFactory(r.Context(), putter, false, 0)) feedManifest, err := manifest.NewDefaultManifest(l, false) if err != nil { logger.Debug("create manifest failed", "error", err) diff --git a/pkg/api/pin.go b/pkg/api/pin.go index a98af277583..7c80e5f196b 100644 --- a/pkg/api/pin.go +++ b/pkg/api/pin.go @@ -50,7 +50,7 @@ func (s *Service) pinRootHash(w http.ResponseWriter, r *http.Request) { } getter := s.storer.Download(true) - traverser := traversal.New(getter) + traverser := traversal.New(getter, s.storer.Cache()) sem := semaphore.NewWeighted(100) var errTraverse error diff --git a/pkg/file/addresses/addresses_getter_test.go b/pkg/file/addresses/addresses_getter_test.go index 542c2d8a442..a5fe9d00000 100644 --- a/pkg/file/addresses/addresses_getter_test.go +++ b/pkg/file/addresses/addresses_getter_test.go @@ -63,7 +63,7 @@ func TestAddressesGetterIterateChunkAddresses(t *testing.T) { addressesGetter := addresses.NewGetter(store, addressIterFunc) - j, _, err := joiner.New(ctx, addressesGetter, rootChunk.Address()) + j, _, err := joiner.New(ctx, addressesGetter, store, rootChunk.Address()) if err != nil { t.Fatal(err) } diff --git a/pkg/file/file_test.go b/pkg/file/file_test.go index fd85a23fb9d..cc677336b64 100644 --- a/pkg/file/file_test.go +++ b/pkg/file/file_test.go @@ -62,7 +62,7 @@ func testSplitThenJoin(t *testing.T) { } // then join - r, l, err := joiner.New(ctx, store, resultAddress) + r, l, err := joiner.New(ctx, store, store, resultAddress) if err != nil { t.Fatal(err) } diff --git a/pkg/file/joiner/joiner.go b/pkg/file/joiner/joiner.go index b046c6bd0fd..5758584f006 100644 --- a/pkg/file/joiner/joiner.go +++ b/pkg/file/joiner/joiner.go @@ -7,33 +7,40 @@ package joiner import ( "context" - "encoding/binary" "errors" "io" "sync" "sync/atomic" + "github.com/ethersphere/bee/pkg/bmt" "github.com/ethersphere/bee/pkg/encryption" "github.com/ethersphere/bee/pkg/encryption/store" "github.com/ethersphere/bee/pkg/file" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/file/redundancy/getter" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" "golang.org/x/sync/errgroup" ) type joiner struct { - addr swarm.Address - rootData []byte - span int64 - off int64 - refLength int + addr swarm.Address + rootData []byte + span int64 + off int64 + refLength int + rootParity int + maxBranching int // maximum branching in an intermediate chunk ctx context.Context getter storage.Getter + putter storage.Putter // required to save recovered data + + chunkToSpan func(data []byte) (redundancy.Level, int64) // returns parity and span value from chunkData } // New creates a new Joiner. A Joiner provides Read, Seek and Size functionalities. 
-func New(ctx context.Context, getter storage.Getter, address swarm.Address) (file.Joiner, int64, error) { +func New(ctx context.Context, getter storage.Getter, putter storage.Putter, address swarm.Address) (file.Joiner, int64, error) { getter = store.New(getter) // retrieve the root chunk to read the total data length the be retrieved rootChunk, err := getter.Get(ctx, address) @@ -41,17 +48,42 @@ func New(ctx context.Context, getter storage.Getter, address swarm.Address) (fil return nil, 0, err } - var chunkData = rootChunk.Data() - - span := int64(binary.LittleEndian.Uint64(chunkData[:swarm.SpanSize])) + chunkData := rootChunk.Data() + rootData := chunkData[swarm.SpanSize:] + refLength := len(address.Bytes()) + encryption := false + if refLength != swarm.HashSize { + encryption = true + } + rLevel, span := chunkToSpan(chunkData) + rootParity := 0 + maxBranching := swarm.ChunkSize / refLength + spanFn := func(data []byte) (redundancy.Level, int64) { + return 0, int64(bmt.LengthFromSpan(data[:swarm.SpanSize])) + } + // override stuff if root chunk has redundancy + if rLevel != redundancy.NONE { + _, parities := file.ReferenceCount(uint64(span), rLevel, encryption) + rootParity = parities + spanFn = chunkToSpan + if encryption { + maxBranching = rLevel.GetMaxEncShards() + } else { + maxBranching = rLevel.GetMaxShards() + } + } j := &joiner{ - addr: rootChunk.Address(), - refLength: len(address.Bytes()), - ctx: ctx, - getter: getter, - span: span, - rootData: chunkData[swarm.SpanSize:], + addr: rootChunk.Address(), + refLength: refLength, + ctx: ctx, + getter: getter, + putter: putter, + span: span, + rootData: rootData, + rootParity: rootParity, + maxBranching: maxBranching, + chunkToSpan: spanFn, } return j, span, nil @@ -81,7 +113,7 @@ func (j *joiner) ReadAt(buffer []byte, off int64) (read int, err error) { } var bytesRead int64 var eg errgroup.Group - j.readAtOffset(buffer, j.rootData, 0, j.span, off, 0, readLen, &bytesRead, &eg) + j.readAtOffset(buffer, j.rootData, 0, j.span, off, 0, readLen, &bytesRead, j.rootParity, &eg) err = eg.Wait() if err != nil { @@ -93,7 +125,13 @@ func (j *joiner) ReadAt(buffer []byte, off int64) (read int, err error) { var ErrMalformedTrie = errors.New("malformed tree") -func (j *joiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffset, bytesToRead int64, bytesRead *int64, eg *errgroup.Group) { +func (j *joiner) readAtOffset( + b, data []byte, + cur, subTrieSize, off, bufferOffset, bytesToRead int64, + bytesRead *int64, + parity int, + eg *errgroup.Group, +) { // we are at a leaf data chunk if subTrieSize <= int64(len(data)) { dataOffsetStart := off - cur @@ -109,14 +147,23 @@ func (j *joiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffse return } + pSize, err := file.ChunkPayloadSize(data) + if err != nil { + eg.Go(func() error { + return err + }) + return + } + sAddresses, pAddresses := file.ChunkAddresses(data[:pSize], parity, j.refLength) + getter := getter.New(sAddresses, pAddresses, j.getter, j.putter) for cursor := 0; cursor < len(data); cursor += j.refLength { if bytesToRead == 0 { break } // fast forward the cursor - sec := subtrieSection(data, cursor, j.refLength, subTrieSize) - if cur+sec < off { + sec := j.subtrieSection(data, cursor, pSize, parity, subTrieSize) + if cur+sec <= off { cur += sec continue } @@ -139,19 +186,20 @@ func (j *joiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffse func(address swarm.Address, b []byte, cur, subTrieSize, off, bufferOffset, bytesToRead, subtrieSpanLimit 
int64) { eg.Go(func() error { - ch, err := j.getter.Get(j.ctx, address) + ch, err := getter.Get(j.ctx, address) if err != nil { return err } chunkData := ch.Data()[8:] - subtrieSpan := int64(chunkToSpan(ch.Data())) + subtrieLevel, subtrieSpan := j.chunkToSpan(ch.Data()) + _, subtrieParity := file.ReferenceCount(uint64(subtrieSpan), subtrieLevel, j.refLength != swarm.HashSize) if subtrieSpan > subtrieSpanLimit { return ErrMalformedTrie } - j.readAtOffset(b, chunkData, cur, subtrieSpan, off, bufferOffset, currentReadSize, bytesRead, eg) + j.readAtOffset(b, chunkData, cur, subtrieSpan, off, bufferOffset, currentReadSize, bytesRead, subtrieParity, eg) return nil }) }(address, b, cur, subtrieSpan, off, bufferOffset, currentReadSize, subtrieSpanLimit) @@ -163,8 +211,13 @@ func (j *joiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffse } } +// getShards returns the effective reference number respective to the intermediate chunk payload length and its parities +func (j *joiner) getShards(payloadSize, parities int) int { + return (payloadSize - parities*swarm.HashSize) / j.refLength +} + // brute-forces the subtrie size for each of the sections in this intermediate chunk -func subtrieSection(data []byte, startIdx, refLen int, subtrieSize int64) int64 { +func (j *joiner) subtrieSection(data []byte, startIdx, payloadSize, parities int, subtrieSize int64) int64 { // assume we have a trie of size `y` then we can assume that all of // the forks except for the last one on the right are of equal size // this is due to how the splitter wraps levels. @@ -173,9 +226,9 @@ func subtrieSection(data []byte, startIdx, refLen int, subtrieSize int64) int64 // where y is the size of the subtrie, refs are the number of references // x is constant (the brute forced value) and l is the size of the last subtrie var ( - refs = int64(len(data) / refLen) // how many references in the intermediate chunk - branching = int64(4096 / refLen) // branching factor is chunkSize divided by reference length - branchSize = int64(4096) + refs = int64(j.getShards(payloadSize, parities)) // how many effective references in the intermediate chunk + branching = int64(j.maxBranching) // branching factor is chunkSize divided by reference length + branchSize = int64(swarm.ChunkSize) ) for { whatsLeft := subtrieSize - (branchSize * (refs - 1)) @@ -186,7 +239,7 @@ func subtrieSection(data []byte, startIdx, refLen int, subtrieSize int64) int64 } // handle last branch edge case - if startIdx == int(refs-1)*refLen { + if startIdx == int(refs-1)*j.refLength { return subtrieSize - (refs-1)*branchSize } return branchSize @@ -229,10 +282,10 @@ func (j *joiner) IterateChunkAddresses(fn swarm.AddressIterFunc) error { return err } - return j.processChunkAddresses(j.ctx, fn, j.rootData, j.span) + return j.processChunkAddresses(j.ctx, fn, j.rootData, j.span, j.rootParity) } -func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIterFunc, data []byte, subTrieSize int64) error { +func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIterFunc, data []byte, subTrieSize int64, parity int) error { // we are at a leaf data chunk if subTrieSize <= int64(len(data)) { return nil @@ -248,6 +301,12 @@ func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIter var wg sync.WaitGroup + eSize, err := file.ChunkPayloadSize(data) + if err != nil { + return err + } + sAddresses, pAddresses := file.ChunkAddresses(data[:eSize], parity, j.refLength) + getter := getter.New(sAddresses, pAddresses, 
j.getter, j.putter) for cursor := 0; cursor < len(data); cursor += j.refLength { ref := data[cursor : cursor+j.refLength] var reportAddr swarm.Address @@ -262,7 +321,7 @@ func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIter return err } - sec := subtrieSection(data, cursor, j.refLength, subTrieSize) + sec := j.subtrieSection(data, cursor, eSize, parity, subTrieSize) if sec <= swarm.ChunkSize { continue } @@ -273,15 +332,16 @@ func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIter eg.Go(func() error { defer wg.Done() - ch, err := j.getter.Get(ectx, address) + ch, err := getter.Get(ectx, address) if err != nil { return err } chunkData := ch.Data()[8:] - subtrieSpan := int64(chunkToSpan(ch.Data())) + subtrieLevel, subtrieSpan := j.chunkToSpan(ch.Data()) + _, parities := file.ReferenceCount(uint64(subtrieSpan), subtrieLevel, j.refLength != swarm.HashSize) - return j.processChunkAddresses(ectx, fn, chunkData, subtrieSpan) + return j.processChunkAddresses(ectx, fn, chunkData, subtrieSpan, parities) }) }(address, eg) @@ -295,6 +355,9 @@ func (j *joiner) Size() int64 { return j.span } -func chunkToSpan(data []byte) uint64 { - return binary.LittleEndian.Uint64(data[:8]) +// UTILITIES + +func chunkToSpan(data []byte) (redundancy.Level, int64) { + level, spanBytes := redundancy.DecodeSpan(data[:swarm.SpanSize]) + return level, int64(bmt.LengthFromSpan(spanBytes)) } diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index 772c0e12ae2..67ac1c3f9c1 100644 --- a/pkg/file/joiner/joiner_test.go +++ b/pkg/file/joiner/joiner_test.go @@ -7,6 +7,7 @@ package joiner_test import ( "bytes" "context" + "crypto/rand" "encoding/binary" "errors" "fmt" @@ -20,6 +21,7 @@ import ( "github.com/ethersphere/bee/pkg/encryption/store" "github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/pipeline/builder" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/file/splitter" filetest "github.com/ethersphere/bee/pkg/file/testing" storage "github.com/ethersphere/bee/pkg/storage" @@ -34,7 +36,7 @@ func TestJoiner_ErrReferenceLength(t *testing.T) { t.Parallel() store := inmemchunkstore.New() - _, _, err := joiner.New(context.Background(), store, swarm.ZeroAddress) + _, _, err := joiner.New(context.Background(), store, store, swarm.ZeroAddress) if !errors.Is(err, storage.ErrReferenceLength) { t.Fatalf("expected ErrReferenceLength %x but got %v", swarm.ZeroAddress, err) @@ -64,7 +66,7 @@ func TestJoinerSingleChunk(t *testing.T) { } // read back data and compare - joinReader, l, err := joiner.New(ctx, store, mockAddr) + joinReader, l, err := joiner.New(ctx, store, store, mockAddr) if err != nil { t.Fatal(err) } @@ -104,7 +106,7 @@ func TestJoinerDecryptingStore_NormalChunk(t *testing.T) { } // read back data and compare - joinReader, l, err := joiner.New(ctx, store, mockAddr) + joinReader, l, err := joiner.New(ctx, store, st, mockAddr) if err != nil { t.Fatal(err) } @@ -152,7 +154,7 @@ func TestJoinerWithReference(t *testing.T) { } // read back data and compare - joinReader, l, err := joiner.New(ctx, store, rootChunk.Address()) + joinReader, l, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -208,7 +210,7 @@ func TestJoinerMalformed(t *testing.T) { t.Fatal(err) } - joinReader, _, err := joiner.New(ctx, store, rootChunk.Address()) + joinReader, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -254,7 +256,7 @@ 
func TestEncryptDecrypt(t *testing.T) { if err != nil { t.Fatal(err) } - reader, l, err := joiner.New(context.Background(), store, resultAddress) + reader, l, err := joiner.New(context.Background(), store, store, resultAddress) if err != nil { t.Fatal(err) } @@ -341,7 +343,7 @@ func TestSeek(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, addr) + j, _, err := joiner.New(ctx, store, store, addr) if err != nil { t.Fatal(err) } @@ -618,7 +620,7 @@ func TestPrefetch(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, addr) + j, _, err := joiner.New(ctx, store, store, addr) if err != nil { t.Fatal(err) } @@ -667,7 +669,7 @@ func TestJoinerReadAt(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, rootChunk.Address()) + j, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -714,7 +716,7 @@ func TestJoinerOneLevel(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, rootChunk.Address()) + j, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -808,7 +810,7 @@ func TestJoinerTwoLevelsAcrossChunk(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, rootChunk.Address()) + j, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -864,7 +866,7 @@ func TestJoinerIterateChunkAddresses(t *testing.T) { createdAddresses := []swarm.Address{rootChunk.Address(), firstAddress, secondAddress} - j, _, err := joiner.New(ctx, store, rootChunk.Address()) + j, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -917,7 +919,7 @@ func TestJoinerIterateChunkAddresses_Encrypted(t *testing.T) { if err != nil { t.Fatal(err) } - j, l, err := joiner.New(context.Background(), store, resultAddress) + j, l, err := joiner.New(context.Background(), store, store, resultAddress) if err != nil { t.Fatal(err) } @@ -951,3 +953,121 @@ func TestJoinerIterateChunkAddresses_Encrypted(t *testing.T) { } } } + +func TestJoinerRedundancy(t *testing.T) { + t.Parallel() + for _, tc := range []struct { + rLevel redundancy.Level + encryptChunk bool + }{ + { + redundancy.MEDIUM, + true, + }, + { + redundancy.STRONG, + false, + }, + { + redundancy.INSANE, + true, + }, + { + redundancy.PARANOID, + false, + }, + } { + tc := tc + t.Run(fmt.Sprintf("redundancy %d encryption %t", tc.rLevel, tc.encryptChunk), func(t *testing.T) { + ctx := context.Background() + store := inmemchunkstore.New() + pipe := builder.NewPipelineBuilder(ctx, store, tc.encryptChunk, tc.rLevel) + + // generate and store chunks + dataChunkCount := tc.rLevel.GetMaxShards() + 1 // generate a carrier chunk + if tc.encryptChunk { + dataChunkCount = tc.rLevel.GetMaxEncShards() + 1 + } + dataChunks := make([]swarm.Chunk, dataChunkCount) + chunkSize := swarm.ChunkSize + for i := 0; i < dataChunkCount; i++ { + chunkData := make([]byte, chunkSize) + _, err := io.ReadFull(rand.Reader, chunkData) + if err != nil { + t.Fatal(err) + } + dataChunks[i], err = cac.New(chunkData) + if err != nil { + t.Fatal(err) + } + err = store.Put(ctx, dataChunks[i]) + if err != nil { + t.Fatal(err) + } + _, err = pipe.Write(chunkData) + if err != nil { + t.Fatal(err) + } + } + + // reader init + sum, err := pipe.Sum() + if err != nil { + t.Fatal(err) + } + swarmAddr := swarm.NewAddress(sum) + joinReader, rootSpan, err := joiner.New(ctx, store, store, swarmAddr) + if err != nil { + t.Fatal(err) + } + // sanity checks + expectedRootSpan := chunkSize * 
dataChunkCount + if int64(expectedRootSpan) != rootSpan { + t.Fatalf("Expected root span %d. Got: %d", expectedRootSpan, rootSpan) + } + // all data can be read back + readCheck := func() { + offset := int64(0) + for i := 0; i < dataChunkCount; i++ { + chunkData := make([]byte, chunkSize) + _, err = joinReader.ReadAt(chunkData, offset) + if err != nil { + t.Fatalf("read error check at chunkdata comparisation on %d index: %s", i, err.Error()) + } + expectedChunkData := dataChunks[i].Data()[swarm.SpanSize:] + if !bytes.Equal(expectedChunkData, chunkData) { + t.Fatalf("read error check at chunkdata comparisation on %d index. Data are not the same", i) + } + offset += int64(chunkSize) + } + } + readCheck() + + // remove data chunks in order to trigger recovery + maxShards := tc.rLevel.GetMaxShards() + maxParities := tc.rLevel.GetParities(maxShards) + if tc.encryptChunk { + maxShards = tc.rLevel.GetMaxEncShards() + maxParities = tc.rLevel.GetEncParities(maxShards) + } + removeCount := maxParities + if maxParities > maxShards { + removeCount = maxShards + } + for i := 0; i < removeCount; i++ { + err := store.Delete(ctx, dataChunks[i].Address()) + if err != nil { + t.Fatal(err) + } + } + // remove parity chunk + err = store.Delete(ctx, dataChunks[len(dataChunks)-1].Address()) + if err != nil { + t.Fatal(err) + } + + // check whether the data still be readable + readCheck() + }) + } +} diff --git a/pkg/file/loadsave/loadsave.go b/pkg/file/loadsave/loadsave.go index 1f5267a4e8e..1d65e45ab6b 100644 --- a/pkg/file/loadsave/loadsave.go +++ b/pkg/file/loadsave/loadsave.go @@ -27,13 +27,15 @@ var errReadonlyLoadSave = errors.New("readonly manifest loadsaver") // load all of the subtrie of a given hash in memory. type loadSave struct { getter storage.Getter + putter storage.Putter pipelineFn func() pipeline.Interface } // New returns a new read-write load-saver. 
-func New(getter storage.Getter, pipelineFn func() pipeline.Interface) file.LoadSaver {
+func New(getter storage.Getter, putter storage.Putter, pipelineFn func() pipeline.Interface) file.LoadSaver {
 	return &loadSave{
 		getter:     getter,
+		putter:     putter,
 		pipelineFn: pipelineFn,
 	}
 }
@@ -47,7 +49,7 @@ func NewReadonly(getter storage.Getter) file.LoadSaver {
 }
 
 func (ls *loadSave) Load(ctx context.Context, ref []byte) ([]byte, error) {
-	j, _, err := joiner.New(ctx, ls.getter, swarm.NewAddress(ref))
+	j, _, err := joiner.New(ctx, ls.getter, ls.putter, swarm.NewAddress(ref))
 	if err != nil {
 		return nil, err
 	}
diff --git a/pkg/file/loadsave/loadsave_test.go b/pkg/file/loadsave/loadsave_test.go
index cfb55a953a6..64154859757 100644
--- a/pkg/file/loadsave/loadsave_test.go
+++ b/pkg/file/loadsave/loadsave_test.go
@@ -28,7 +28,7 @@ func TestLoadSave(t *testing.T) {
 	t.Parallel()
 
 	store := inmemchunkstore.New()
-	ls := loadsave.New(store, pipelineFn(store))
+	ls := loadsave.New(store, store, pipelineFn(store))
 
 	ref, err := ls.Save(context.Background(), data)
 	if err != nil {
diff --git a/pkg/file/pipeline/hashtrie/hashtrie_test.go b/pkg/file/pipeline/hashtrie/hashtrie_test.go
index beebd77cf74..78b4c20b2d7 100644
--- a/pkg/file/pipeline/hashtrie/hashtrie_test.go
+++ b/pkg/file/pipeline/hashtrie/hashtrie_test.go
@@ -23,6 +23,7 @@ import (
 	"github.com/ethersphere/bee/pkg/file/pipeline/mock"
 	"github.com/ethersphere/bee/pkg/file/pipeline/store"
 	"github.com/ethersphere/bee/pkg/file/redundancy"
+	"github.com/ethersphere/bee/pkg/storage"
 	"github.com/ethersphere/bee/pkg/storage/inmemchunkstore"
 	"github.com/ethersphere/bee/pkg/swarm"
 )
@@ -43,6 +44,41 @@ func init() {
 	binary.LittleEndian.PutUint64(span, 1)
 }
 
+// newErasureHashTrieWriter returns a redundancy param and a HashTrieWriter pipeline,
+// both of which use simple BMT and StoreWriter pipelines for chunk writes
+func newErasureHashTrieWriter(
+	ctx context.Context,
+	s storage.Putter,
+	rLevel redundancy.Level,
+	encryptChunks bool,
+	intermediateChunkPipeline, parityChunkPipeline pipeline.ChainWriter,
+) (redundancy.IParams, pipeline.ChainWriter) {
+	pf := func() pipeline.ChainWriter {
+		lsw := store.NewStoreWriter(ctx, s, intermediateChunkPipeline)
+		return bmt.NewBmtWriter(lsw)
+	}
+	if encryptChunks {
+		pf = func() pipeline.ChainWriter {
+			lsw := store.NewStoreWriter(ctx, s, intermediateChunkPipeline)
+			b := bmt.NewBmtWriter(lsw)
+			return enc.NewEncryptionWriter(encryption.NewChunkEncrypter(), b)
+		}
+	}
+	ppf := func() pipeline.ChainWriter {
+		lsw := store.NewStoreWriter(ctx, s, parityChunkPipeline)
+		return bmt.NewBmtWriter(lsw)
+	}
+
+	hashSize := swarm.HashSize
+	if encryptChunks {
+		hashSize *= 2
+	}
+
+	r := redundancy.New(rLevel, encryptChunks, ppf)
+	ht := hashtrie.NewHashTrieWriter(hashSize, r, pf)
+	return r, ht
+}
+
 func TestLevels(t *testing.T) {
 	t.Parallel()
 
@@ -271,32 +307,14 @@ func TestRedundancy(t *testing.T) {
 			s := inmemchunkstore.New()
 			intermediateChunkCounter := mock.NewChainWriter()
 			parityChunkCounter := mock.NewChainWriter()
-			pf := func() pipeline.ChainWriter {
-				lsw := store.NewStoreWriter(ctx, s, intermediateChunkCounter)
-				return bmt.NewBmtWriter(lsw)
-			}
-			if tc.encryption {
-				pf = func() pipeline.ChainWriter {
-					lsw := store.NewStoreWriter(ctx, s, intermediateChunkCounter)
-					b := bmt.NewBmtWriter(lsw)
-					return enc.NewEncryptionWriter(encryption.NewChunkEncrypter(), b)
-				}
-			}
-			ppf := func() pipeline.ChainWriter {
-				lsw := store.NewStoreWriter(ctx, s, parityChunkCounter)
-				return bmt.NewBmtWriter(lsw)
-			}
+			r, ht := 
newErasureHashTrieWriter(ctx, s, tc.level, tc.encryption, intermediateChunkCounter, parityChunkCounter) + + // write data to the hashTrie var key []byte - hashSize := swarm.HashSize if tc.encryption { - hashSize *= 2 - key = addr.Bytes() + key = make([]byte, swarm.HashSize) } - - r := redundancy.New(tc.level, tc.encryption, ppf) - ht := hashtrie.NewHashTrieWriter(hashSize, r, pf) - for i := 0; i < tc.writes; i++ { a := &pipeline.PipeWriteArgs{Data: chData, Span: chSpan, Ref: chAddr, Key: key} err := ht.ChainWrite(a) diff --git a/pkg/file/redundancy/getter/getter.go b/pkg/file/redundancy/getter/getter.go new file mode 100644 index 00000000000..982c5080109 --- /dev/null +++ b/pkg/file/redundancy/getter/getter.go @@ -0,0 +1,333 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package getter + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/ethersphere/bee/pkg/encryption/store" + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" + "github.com/klauspost/reedsolomon" +) + +/// ERRORS + +type cannotRecoverError struct { + missingChunks int +} + +func (e cannotRecoverError) Error() string { + return fmt.Sprintf("redundancy getter: there are %d missing chunks in order to do recovery", e.missingChunks) +} + +func IsCannotRecoverError(err error, missingChunks int) bool { + return errors.Is(err, cannotRecoverError{missingChunks}) +} + +type isNotRecoveredError struct { + chAddress string +} + +func (e isNotRecoveredError) Error() string { + return fmt.Sprintf("redundancy getter: chunk with address %s is not recovered", e.chAddress) +} + +func IsNotRecoveredError(err error, chAddress string) bool { + return errors.Is(err, isNotRecoveredError{chAddress}) +} + +type noDataAddressIncludedError struct { + chAddress string +} + +func (e noDataAddressIncludedError) Error() string { + return fmt.Sprintf("redundancy getter: no data shard address given with chunk address %s", e.chAddress) +} + +func IsNoDataAddressIncludedError(err error, chAddress string) bool { + return errors.Is(err, noDataAddressIncludedError{chAddress}) +} + +type noRedundancyError struct { + chAddress string +} + +func (e noRedundancyError) Error() string { + return fmt.Sprintf("redundancy getter: cannot get chunk %s because no redundancy added", e.chAddress) +} + +func IsNoRedundancyError(err error, chAddress string) bool { + return errors.Is(err, noRedundancyError{chAddress}) +} + +/// TYPES + +// inflightChunk is initialized if recovery happened already +type inflightChunk struct { + pos int // chunk index in the erasureData/intermediate chunk + wait chan struct{} // chunk is under retrieval +} + +// getter retrieves children of an intermediate chunk +// it caches sibling chunks if erasure decoding was called on the level already +type getter struct { + storage.Getter + storage.Putter + mu sync.Mutex + sAddresses []swarm.Address // shard addresses + pAddresses []swarm.Address // parity addresses + cache map[string]inflightChunk // map from chunk address to cached shard chunk data + erasureData [][]byte // data + parity shards for erasure decoding; TODO mutex + encrypted bool // swarm datashards are encrypted +} + +// New returns a getter object which is used to retrieve children of an intermediate chunk +func New(sAddresses, pAddresses []swarm.Address, g storage.Getter, p storage.Putter) storage.Getter { + encrypted := len(sAddresses[0].Bytes()) == swarm.HashSize*2 + 
shards := len(sAddresses)
+	parities := len(pAddresses)
+	n := shards + parities
+	erasureData := make([][]byte, n)
+	cache := make(map[string]inflightChunk, n)
+	// init cache
+	for i, addr := range sAddresses {
+		cache[addr.String()] = inflightChunk{
+			pos: i,
+			// wait channel initialization is needed when recovery starts
+		}
+	}
+	for i, addr := range pAddresses {
+		cache[addr.String()] = inflightChunk{
+			pos: len(sAddresses) + i,
+			// no wait channel initialization is needed
+		}
+	}
+
+	return &getter{
+		Getter:      g,
+		Putter:      p,
+		sAddresses:  sAddresses,
+		pAddresses:  pAddresses,
+		cache:       cache,
+		encrypted:   encrypted,
+		erasureData: erasureData,
+	}
+}
+
+// Get fetches parity and sibling chunks to recover the requested chunk if it cannot be retrieved directly
+// assumes it is called for data shards only
+func (g *getter) Get(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) {
+	g.mu.Lock()
+	cValue, ok := g.cache[addr.String()]
+	g.mu.Unlock()
+	if !ok || cValue.pos >= len(g.sAddresses) {
+		return nil, noDataAddressIncludedError{addr.String()}
+	}
+
+	if cValue.wait != nil { // equivalent to g.processing but avoids taking the lock again
+		return g.getAfterProcessed(ctx, addr)
+	}
+
+	ch, err := g.Getter.Get(ctx, addr)
+	if err == nil {
+		return ch, nil
+	}
+	if errors.Is(err, storage.ErrNotFound) && len(g.pAddresses) == 0 {
+		return nil, noRedundancyError{addr.String()}
+	}
+
+	// during the Get, recovery may have been started by another process
+	if g.processing(addr) {
+		return g.getAfterProcessed(ctx, addr)
+	}
+
+	return g.executeStrategies(ctx, addr)
+}
+
+// setErasureData stores the chunk data at the given shard index under lock
+func (g *getter) setErasureData(index int, data []byte) {
+	g.mu.Lock()
+	g.erasureData[index] = data
+	g.mu.Unlock()
+}
+
+// processing returns whether the recovery workflow has been started
+func (g *getter) processing(addr swarm.Address) bool {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+	iCh := g.cache[addr.String()]
+	return iCh.wait != nil
+}
+
+// getAfterProcessed returns chunk from the cache
+func (g *getter) getAfterProcessed(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) {
+	g.mu.Lock()
+	c, ok := g.cache[addr.String()]
+	// sanity check
+	if !ok {
+		return nil, fmt.Errorf("redundancy getter: chunk %s should have been in the cache", addr.String())
+	}
+
+	cacheData := g.erasureData[c.pos]
+	g.mu.Unlock()
+	if cacheData != nil {
+		return g.cacheDataToChunk(addr, cacheData)
+	}
+
+	select {
+	case <-c.wait:
+		return g.cacheDataToChunk(addr, cacheData)
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	}
+}
+
+// executeStrategies executes recovery strategies from redundancy for the given swarm address
+func (g *getter) executeStrategies(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) {
+	g.initWaitChannels()
+	err := g.cautiousStrategy(ctx)
+	if err != nil {
+		g.closeWaitChannels()
+		return nil, err
+	}
+
+	return g.getAfterProcessed(ctx, addr)
+}
+
+// initWaitChannels initializes the wait channels in the cache mapping which indicates the start of the recovery process as well
+func (g *getter) initWaitChannels() {
+	g.mu.Lock()
+	for _, addr := range g.sAddresses {
+		iCh := g.cache[addr.String()]
+		iCh.wait = make(chan struct{})
+		g.cache[addr.String()] = iCh
+	}
+	g.mu.Unlock()
+}
+
+// closeWaitChannels closes all pending wait channels
+func (g *getter) closeWaitChannels() {
+	for _, addr := range g.sAddresses {
+		c := g.cache[addr.String()]
+		if !channelIsClosed(c.wait) {
+			close(c.wait)
+		}
+	}
+}
+
+// cautiousStrategy requests all chunks (data and parity) on the level
+// and once it has enough data for erasure decoding it cancels the other requests
+func (g *getter) cautiousStrategy(ctx context.Context) error {
+	requiredChunks := len(g.sAddresses)
+	subContext, cancelContext := context.WithCancel(ctx)
+	retrievedCh := make(chan struct{}, requiredChunks+len(g.pAddresses))
+	var wg sync.WaitGroup
+
+	addresses := append(g.sAddresses, g.pAddresses...)
+	for _, a := range addresses {
+		wg.Add(1)
+		c := g.cache[a.String()]
+		go func(a swarm.Address, c inflightChunk) {
+			defer wg.Done()
+			// encrypted chunk data should remain encrypted
+			address := swarm.NewAddress(a.Bytes()[:swarm.HashSize])
+			ch, err := g.Getter.Get(subContext, address)
+			if err != nil {
+				return
+			}
+			g.setErasureData(c.pos, ch.Data())
+			if c.pos < len(g.sAddresses) && !channelIsClosed(c.wait) {
+				close(c.wait)
+			}
+			retrievedCh <- struct{}{}
+		}(a, c)
+	}
+
+	// Goroutine to wait for WaitGroup completion
+	go func() {
+		wg.Wait()
+		close(retrievedCh)
+	}()
+	retrieved := 0
+	for retrieved < requiredChunks {
+		_, ok := <-retrievedCh
+		if !ok {
+			break
+		}
+		retrieved++
+	}
+	cancelContext()
+
+	if retrieved < requiredChunks {
+		return cannotRecoverError{requiredChunks - retrieved}
+	}
+
+	return g.erasureDecode(ctx)
+}
+
+// erasureDecode performs Reed-Solomon recovery on the data
+// assumes it is called after filling up cache with the required amount of shards and parities
+func (g *getter) erasureDecode(ctx context.Context) error {
+	enc, err := reedsolomon.New(len(g.sAddresses), len(g.pAddresses))
+	if err != nil {
+		return err
+	}
+
+	// missing chunks
+	var missingIndices []int
+	for i := range g.sAddresses {
+		if g.erasureData[i] == nil {
+			missingIndices = append(missingIndices, i)
+		}
+	}
+
+	g.mu.Lock()
+	err = enc.ReconstructData(g.erasureData)
+	g.mu.Unlock()
+	if err != nil {
+		return err
+	}
+
+	g.closeWaitChannels()
+	// save missing chunks
+	for _, index := range missingIndices {
+		data := g.erasureData[index]
+		addr := g.sAddresses[index]
+		err := g.Putter.Put(ctx, swarm.NewChunk(addr, data))
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// cacheDataToChunk transforms the cached chunk data into a valid swarm chunk
+func (g *getter) cacheDataToChunk(addr swarm.Address, chData []byte) (swarm.Chunk, error) {
+	if chData == nil {
+		return nil, isNotRecoveredError{addr.String()}
+	}
+	if g.encrypted {
+		data, err := store.DecryptChunkData(chData, addr.Bytes()[swarm.HashSize:])
+		if err != nil {
+			return nil, err
+		}
+		chData = data
+	}
+
+	return swarm.NewChunk(addr, chData), nil
+}
+
+func channelIsClosed(wait <-chan struct{}) bool {
+	select {
+	case _, ok := <-wait:
+		return !ok
+	default:
+		return false
+	}
+}
diff --git a/pkg/file/redundancy/getter/getter_test.go b/pkg/file/redundancy/getter/getter_test.go
new file mode 100644
index 00000000000..addee4b6028
--- /dev/null
+++ b/pkg/file/redundancy/getter/getter_test.go
@@ -0,0 +1,217 @@
+// Copyright 2023 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
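
An aside on the erasureDecode step above: recovery ultimately rests on a single library call, reedsolomon.ReconstructData, which repairs missing data shards in place once at least as many shards survive as there are data shards. A minimal standalone sketch of that call, with shard sizes and counts chosen arbitrarily for illustration (not the values the getter uses):

    package main

    import (
    	"fmt"

    	"github.com/klauspost/reedsolomon"
    )

    func main() {
    	// 4 data shards + 2 parity shards, 8 bytes each
    	enc, err := reedsolomon.New(4, 2)
    	if err != nil {
    		panic(err)
    	}
    	shards := make([][]byte, 6)
    	for i := 0; i < 4; i++ {
    		shards[i] = []byte{byte(i), 1, 2, 3, 4, 5, 6, 7}
    	}
    	shards[4] = make([]byte, 8) // parity, filled by Encode
    	shards[5] = make([]byte, 8) // parity, filled by Encode
    	if err := enc.Encode(shards); err != nil {
    		panic(err)
    	}

    	// lose two shards; nil marks a shard as missing
    	shards[0], shards[5] = nil, nil

    	// repairs the data shards only, just as erasureDecode requires
    	if err := enc.ReconstructData(shards); err != nil {
    		panic(err)
    	}
    	fmt.Println(shards[0]) // [0 1 2 3 4 5 6 7]
    }
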
+ +package getter_test + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "testing" + + "github.com/ethersphere/bee/pkg/cac" + "github.com/ethersphere/bee/pkg/file/redundancy/getter" + "github.com/ethersphere/bee/pkg/storage" + inmem "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" + "github.com/ethersphere/bee/pkg/swarm" + "github.com/klauspost/reedsolomon" +) + +func initData(ctx context.Context, erasureBuffer [][]byte, dataShardCount int, s storage.ChunkStore) ([]swarm.Address, []swarm.Address, error) { + spanBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(spanBytes, swarm.ChunkSize) + for i := 0; i < dataShardCount; i++ { + erasureBuffer[i] = make([]byte, swarm.HashSize) + _, err := io.ReadFull(rand.Reader, erasureBuffer[i]) + if err != nil { + return nil, nil, err + } + copy(erasureBuffer[i], spanBytes) + } + // make parity shards + for i := dataShardCount; i < len(erasureBuffer); i++ { + erasureBuffer[i] = make([]byte, swarm.HashSize) + } + // generate parity chunks + rs, err := reedsolomon.New(dataShardCount, len(erasureBuffer)-dataShardCount) + if err != nil { + return nil, nil, err + } + err = rs.Encode(erasureBuffer) + if err != nil { + return nil, nil, err + } + // calculate chunk addresses and upload to the store + var sAddresses, pAddresses []swarm.Address + for i := 0; i < dataShardCount; i++ { + chunk, err := cac.NewWithDataSpan(erasureBuffer[i]) + if err != nil { + return nil, nil, err + } + err = s.Put(ctx, chunk) + if err != nil { + return nil, nil, err + } + sAddresses = append(sAddresses, chunk.Address()) + } + for i := dataShardCount; i < len(erasureBuffer); i++ { + chunk, err := cac.NewWithDataSpan(erasureBuffer[i]) + if err != nil { + return nil, nil, err + } + err = s.Put(ctx, chunk) + if err != nil { + return nil, nil, err + } + pAddresses = append(pAddresses, chunk.Address()) + } + + return sAddresses, pAddresses, err +} + +func dataShardsAvailable(ctx context.Context, sAddresses []swarm.Address, s storage.ChunkStore) error { + for i := 0; i < len(sAddresses); i++ { + has, err := s.Has(ctx, sAddresses[i]) + if err != nil { + return err + } + if !has { + return fmt.Errorf("datashard %d is not available", i) + } + } + return nil +} + +func TestDecoding(t *testing.T) { + s := inmem.New() + + erasureBuffer := make([][]byte, 128) + dataShardCount := 100 + ctx := context.TODO() + + sAddresses, pAddresses, err := initData(ctx, erasureBuffer, dataShardCount, s) + if err != nil { + t.Fatal(err) + } + + getter := getter.New(sAddresses, pAddresses, s, s) + // sanity check - all data shards are retrievable + for i := 0; i < dataShardCount; i++ { + ch, err := getter.Get(ctx, sAddresses[i]) + if err != nil { + t.Fatalf("address %s at index %d is not retrievable by redundancy getter", sAddresses[i], i) + } + if !bytes.Equal(ch.Data(), erasureBuffer[i]) { + t.Fatalf("retrieved chunk data differ from the original at index %d", i) + } + } + + // remove maximum possible chunks from storage + removeChunkCount := len(erasureBuffer) - dataShardCount + for i := 0; i < removeChunkCount; i++ { + err := s.Delete(ctx, sAddresses[i]) + if err != nil { + t.Fatal(err) + } + } + + err = dataShardsAvailable(ctx, sAddresses, s) // sanity check + if err == nil { + t.Fatalf("some data shards should be missing") + } + ch, err := getter.Get(ctx, sAddresses[0]) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(ch.Data(), erasureBuffer[0]) { + t.Fatalf("retrieved chunk data differ from the original at index %d", 0) + } + err = 
dataShardsAvailable(ctx, sAddresses, s) + if err != nil { + t.Fatal(err) + } +} + +func TestRecoveryLimits(t *testing.T) { + s := inmem.New() + + erasureBuffer := make([][]byte, 8) + dataShardCount := 5 + ctx := context.TODO() + + sAddresses, pAddresses, err := initData(ctx, erasureBuffer, dataShardCount, s) + if err != nil { + t.Fatal(err) + } + + _getter := getter.New(sAddresses, pAddresses, s, s) + + // remove more chunks that can be corrected by erasure code + removeChunkCount := len(erasureBuffer) - dataShardCount + 1 + for i := 0; i < removeChunkCount; i++ { + err := s.Delete(ctx, sAddresses[i]) + if err != nil { + t.Fatal(err) + } + } + _, err = _getter.Get(ctx, sAddresses[0]) + if !getter.IsCannotRecoverError(err, 1) { + t.Fatal(err) + } + + // call once more + _, err = _getter.Get(ctx, sAddresses[0]) + if !getter.IsNotRecoveredError(err, sAddresses[0].String()) { + t.Fatal(err) + } +} + +func TestNoRedundancyOnRecovery(t *testing.T) { + s := inmem.New() + + dataShardCount := 5 + erasureBuffer := make([][]byte, dataShardCount) + ctx := context.TODO() + + sAddresses, pAddresses, err := initData(ctx, erasureBuffer, dataShardCount, s) + if err != nil { + t.Fatal(err) + } + + _getter := getter.New(sAddresses, pAddresses, s, s) + + // remove one chunk that trying to request later + err = s.Delete(ctx, sAddresses[0]) + if err != nil { + t.Fatal(err) + } + _, err = _getter.Get(ctx, sAddresses[0]) + if !getter.IsNoRedundancyError(err, sAddresses[0].String()) { + t.Fatal(err) + } +} + +func TestNoDataAddressIncluded(t *testing.T) { + s := inmem.New() + + erasureBuffer := make([][]byte, 8) + dataShardCount := 5 + ctx := context.TODO() + + sAddresses, pAddresses, err := initData(ctx, erasureBuffer, dataShardCount, s) + if err != nil { + t.Fatal(err) + } + + _getter := getter.New(sAddresses, pAddresses, s, s) + + // trying to retrieve a parity address + _, err = _getter.Get(ctx, pAddresses[0]) + if !getter.IsNoDataAddressIncludedError(err, pAddresses[0].String()) { + t.Fatal(err) + } +} diff --git a/pkg/file/redundancy/level.go b/pkg/file/redundancy/level.go index bdd7226af72..406e9b6c437 100644 --- a/pkg/file/redundancy/level.go +++ b/pkg/file/redundancy/level.go @@ -4,7 +4,11 @@ package redundancy -import "github.com/ethersphere/bee/pkg/swarm" +import ( + "errors" + + "github.com/ethersphere/bee/pkg/swarm" +) type Level uint8 @@ -20,18 +24,11 @@ const maxLevel = 8 // GetParities returns number of parities based on appendix F table 5 func (l Level) GetParities(shards int) int { - switch l { - case MEDIUM: - return mediumEt.getParities(shards) - case STRONG: - return strongEt.getParities(shards) - case INSANE: - return insaneEt.getParities(shards) - case PARANOID: - return paranoidEt.getParities(shards) - default: + et, err := l.getErasureTable() + if err != nil { return 0 } + return et.getParities(shards) } // GetMaxShards returns back the maximum number of effective data chunks @@ -42,17 +39,40 @@ func (l Level) GetMaxShards() int { // GetEncParities returns number of parities for encrypted chunks based on appendix F table 6 func (l Level) GetEncParities(shards int) int { + et, err := l.getEncErasureTable() + if err != nil { + return 0 + } + return et.getParities(shards) +} + +func (l Level) getErasureTable() (erasureTable, error) { switch l { case MEDIUM: - return encMediumEt.getParities(shards) + return *mediumEt, nil case STRONG: - return encStrongEt.getParities(shards) + return *strongEt, nil case INSANE: - return encInsaneEt.getParities(shards) + return *insaneEt, nil case PARANOID: - 
return encParanoidEt.getParities(shards)
+		return *paranoidEt, nil
 	default:
-		return 0
+		return erasureTable{}, errors.New("redundancy: level NONE does not have erasure table")
+	}
+}
+
+func (l Level) getEncErasureTable() (erasureTable, error) {
+	switch l {
+	case MEDIUM:
+		return *encMediumEt, nil
+	case STRONG:
+		return *encStrongEt, nil
+	case INSANE:
+		return *encInsaneEt, nil
+	case PARANOID:
+		return *encParanoidEt, nil
+	default:
+		return erasureTable{}, errors.New("redundancy: level NONE does not have erasure table")
 	}
 }
diff --git a/pkg/file/redundancy/table.go b/pkg/file/redundancy/table.go
index a88505da2b1..1b62e1ec34d 100644
--- a/pkg/file/redundancy/table.go
+++ b/pkg/file/redundancy/table.go
@@ -4,6 +4,8 @@
 
 package redundancy
 
+import "fmt"
+
 type erasureTable struct {
 	shards   []int
 	parities []int
@@ -50,3 +52,13 @@ func (et *erasureTable) getParities(maxShards int) int {
 	}
 	return 0
 }
+
+// GetMinShards returns the minimum shard count with respect to the given parity count
+func (et *erasureTable) GetMinShards(parities int) (int, error) {
+	for k, p := range et.parities {
+		if p == parities {
+			return et.shards[k], nil
+		}
+	}
+	return 0, fmt.Errorf("parity table: there is no minimum shard number for given parity %d", parities)
+}
diff --git a/pkg/file/utils.go b/pkg/file/utils.go
index 298a6578447..59afd2e9cef 100644
--- a/pkg/file/utils.go
+++ b/pkg/file/utils.go
@@ -5,10 +5,47 @@
 package file
 
 import (
+	"bytes"
+	"errors"
+
 	"github.com/ethersphere/bee/pkg/file/redundancy"
 	"github.com/ethersphere/bee/pkg/swarm"
 )
 
+// ChunkPayloadSize returns the effective byte length of an intermediate chunk
+// assumes data is always chunk size (without span)
+func ChunkPayloadSize(data []byte) (int, error) {
+	l := len(data)
+	for l >= swarm.HashSize {
+		if !bytes.Equal(data[l-swarm.HashSize:l], swarm.ZeroAddress.Bytes()) {
+			return l, nil
+		}
+
+		l -= swarm.HashSize
+	}
+
+	return 0, errors.New("redundancy getter: intermediate chunk does not have at least a child")
+}
+
+// ChunkAddresses returns data shards and parities of the intermediate chunk
+// assumes data is truncated by ChunkPayloadSize
+func ChunkAddresses(data []byte, parities, reflen int) (sAddresses, pAddresses []swarm.Address) {
+	shards := (len(data) - parities*swarm.HashSize) / reflen
+	sAddresses = make([]swarm.Address, shards)
+	pAddresses = make([]swarm.Address, parities)
+	offset := 0
+	for i := 0; i < shards; i++ {
+		sAddresses[i] = swarm.NewAddress(data[offset : offset+reflen])
+		offset += reflen
+	}
+	for i := 0; i < parities; i++ {
+		pAddresses[i] = swarm.NewAddress(data[offset : offset+swarm.HashSize])
+		offset += swarm.HashSize
+	}
+
+	return sAddresses, pAddresses
+}
+
 // ReferenceCount brute-forces the data shard count, from which the parity count of a subtree can be identified as well
 // assumes span > swarm.ChunkSize
 // returns the data and parity shard counts
diff --git a/pkg/node/bootstrap.go b/pkg/node/bootstrap.go
index bf2bc4030cb..a1e48228da7 100644
--- a/pkg/node/bootstrap.go
+++ b/pkg/node/bootstrap.go
@@ -237,7 +237,7 @@ func bootstrapNode(
 		ctx, cancel := context.WithTimeout(ctx, timeout)
 		defer cancel()
 
-		reader, l, err = joiner.New(ctx, localStore.Download(true), snapshotReference)
+		reader, l, err = joiner.New(ctx, localStore.Download(true), localStore.Cache(), snapshotReference)
 		if err != nil {
 			logger.Warning("bootstrap: file joiner failed", "error", err)
 			continue
diff --git a/pkg/node/node.go b/pkg/node/node.go
index 197c048a263..960a907a836 100644
--- a/pkg/node/node.go
+++ b/pkg/node/node.go
@@ -1069,7 
+1069,7 @@ func NewBee( b.resolverCloser = multiResolver feedFactory := factory.New(localStore.Download(true)) - steward := steward.New(localStore, retrieval) + steward := steward.New(localStore, retrieval, localStore.Cache()) extraOpts := api.ExtraOptions{ Pingpong: pingPong, diff --git a/pkg/steward/steward.go b/pkg/steward/steward.go index 9f7bf0641a2..73b3d17b87f 100644 --- a/pkg/steward/steward.go +++ b/pkg/steward/steward.go @@ -36,11 +36,11 @@ type steward struct { netTraverser traversal.Traverser } -func New(ns storer.NetStore, r retrieval.Interface) Interface { +func New(ns storer.NetStore, r retrieval.Interface, joinerPutter storage.Putter) Interface { return &steward{ netStore: ns, - traverser: traversal.New(ns.Download(true)), - netTraverser: traversal.New(&netGetter{r}), + traverser: traversal.New(ns.Download(true), joinerPutter), + netTraverser: traversal.New(&netGetter{r}, joinerPutter), } } diff --git a/pkg/steward/steward_test.go b/pkg/steward/steward_test.go index 3836f75dcfd..bdddbdd3d08 100644 --- a/pkg/steward/steward_test.go +++ b/pkg/steward/steward_test.go @@ -24,15 +24,16 @@ import ( func TestSteward(t *testing.T) { t.Parallel() + inmem := inmemchunkstore.New() var ( ctx = context.Background() chunks = 1000 data = make([]byte, chunks*4096) //1k chunks - chunkStore = inmemchunkstore.New() + chunkStore = inmem store = mockstorer.NewWithChunkStore(chunkStore) localRetrieval = &localRetriever{ChunkStore: chunkStore} - s = steward.New(store, localRetrieval) + s = steward.New(store, localRetrieval, inmem) stamper = postagetesting.NewStamper() ) diff --git a/pkg/storer/epoch_migration.go b/pkg/storer/epoch_migration.go index f93ecafa486..e02ab5b6313 100644 --- a/pkg/storer/epoch_migration.go +++ b/pkg/storer/epoch_migration.go @@ -392,6 +392,7 @@ func (e *epochMigrator) migratePinning(ctx context.Context) error { return swarm.NewChunk(addr, chData), nil }), + pStorage.ChunkStore(), ) e.logger.Debug("migrating pinning collections, if all the chunks in the collection" + diff --git a/pkg/traversal/traversal.go b/pkg/traversal/traversal.go index e9609a90c8f..dbdb5603e82 100644 --- a/pkg/traversal/traversal.go +++ b/pkg/traversal/traversal.go @@ -29,19 +29,20 @@ type Traverser interface { } // New constructs for a new Traverser. -func New(store storage.Getter) Traverser { - return &service{store: store} +func New(getter storage.Getter, putter storage.Putter) Traverser { + return &service{getter: getter, putter: putter} } // service is implementation of Traverser using storage.Storer as its storage. type service struct { - store storage.Getter + getter storage.Getter + putter storage.Putter } // Traverse implements Traverser.Traverse method. 
func (s *service) Traverse(ctx context.Context, addr swarm.Address, iterFn swarm.AddressIterFunc) error { processBytes := func(ref swarm.Address) error { - j, _, err := joiner.New(ctx, s.store, ref) + j, _, err := joiner.New(ctx, s.getter, s.putter, ref) if err != nil { return fmt.Errorf("traversal: joiner error on %q: %w", ref, err) } @@ -54,7 +55,7 @@ func (s *service) Traverse(ctx context.Context, addr swarm.Address, iterFn swarm // skip SOC check for encrypted references if addr.IsValidLength() { - ch, err := s.store.Get(ctx, addr) + ch, err := s.getter.Get(ctx, addr) if err != nil { return fmt.Errorf("traversal: failed to get root chunk %s: %w", addr.String(), err) } @@ -64,7 +65,7 @@ func (s *service) Traverse(ctx context.Context, addr swarm.Address, iterFn swarm } } - ls := loadsave.NewReadonly(s.store) + ls := loadsave.NewReadonly(s.getter) switch mf, err := manifest.NewDefaultManifestReference(addr, ls); { case errors.Is(err, manifest.ErrInvalidManifestType): break diff --git a/pkg/traversal/traversal_test.go b/pkg/traversal/traversal_test.go index 75b5f925283..7d9d473e1e4 100644 --- a/pkg/traversal/traversal_test.go +++ b/pkg/traversal/traversal_test.go @@ -167,7 +167,7 @@ func TestTraversalBytes(t *testing.T) { t.Fatal(err) } - err = traversal.New(storerMock).Traverse(ctx, address, iter.Next) + err = traversal.New(storerMock, storerMock).Traverse(ctx, address, iter.Next) if err != nil { t.Fatal(err) } @@ -262,7 +262,7 @@ func TestTraversalFiles(t *testing.T) { t.Fatal(err) } - ls := loadsave.New(storerMock, pipelineFactory(storerMock, false)) + ls := loadsave.New(storerMock, storerMock, pipelineFactory(storerMock, false)) fManifest, err := manifest.NewDefaultManifest(ls, false) if err != nil { t.Fatal(err) @@ -294,7 +294,7 @@ func TestTraversalFiles(t *testing.T) { t.Fatal(err) } - err = traversal.New(storerMock).Traverse(ctx, address, iter.Next) + err = traversal.New(storerMock, storerMock).Traverse(ctx, address, iter.Next) if err != nil { t.Fatal(err) } @@ -421,7 +421,7 @@ func TestTraversalManifest(t *testing.T) { } wantHashes = append(wantHashes, tc.manifestHashes...) 
- ls := loadsave.New(storerMock, pipelineFactory(storerMock, false)) + ls := loadsave.New(storerMock, storerMock, pipelineFactory(storerMock, false)) dirManifest, err := manifest.NewMantarayManifest(ls, false) if err != nil { t.Fatal(err) @@ -452,7 +452,7 @@ func TestTraversalManifest(t *testing.T) { t.Fatal(err) } - err = traversal.New(storerMock).Traverse(ctx, address, iter.Next) + err = traversal.New(storerMock, storerMock).Traverse(ctx, address, iter.Next) if err != nil { t.Fatal(err) } @@ -490,7 +490,7 @@ func TestTraversalSOC(t *testing.T) { t.Fatal(err) } - err = traversal.New(store).Traverse(ctx, sch.Address(), iter.Next) + err = traversal.New(store, store).Traverse(ctx, sch.Address(), iter.Next) if err != nil { t.Fatal(err) } From a5987fa85afbeee90dcac59e123acac6cd67cea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?= Date: Mon, 4 Dec 2023 17:06:26 +0100 Subject: [PATCH 03/23] fix: data race --- pkg/file/redundancy/getter/getter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/file/redundancy/getter/getter.go b/pkg/file/redundancy/getter/getter.go index 982c5080109..f7ef497743c 100644 --- a/pkg/file/redundancy/getter/getter.go +++ b/pkg/file/redundancy/getter/getter.go @@ -288,8 +288,8 @@ func (g *getter) erasureDecode(ctx context.Context) error { } g.mu.Lock() + defer g.mu.Unlock() err = enc.ReconstructData(g.erasureData) - g.mu.Unlock() if err != nil { return err } From e0fbcf73d651705dfc0f02bcb6b50d1b4761b641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tr=C3=B3n?= Date: Mon, 4 Dec 2023 19:10:29 +0100 Subject: [PATCH 04/23] feat(replicas): new replicas pkg for redundancy by dispersed replicas (#4453) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Viktor Levente Tóth --- pkg/replicas/export_test.go | 16 +++ pkg/replicas/getter.go | 163 +++++++++++++++++++++ pkg/replicas/getter_test.go | 271 +++++++++++++++++++++++++++++++++++ pkg/replicas/putter.go | 60 ++++++++ pkg/replicas/putter_test.go | 195 +++++++++++++++++++++++++ pkg/replicas/replica_test.go | 59 ++++++++ pkg/replicas/replicas.go | 145 +++++++++++++++++++ 7 files changed, 909 insertions(+) create mode 100644 pkg/replicas/export_test.go create mode 100644 pkg/replicas/getter.go create mode 100644 pkg/replicas/getter_test.go create mode 100644 pkg/replicas/putter.go create mode 100644 pkg/replicas/putter_test.go create mode 100644 pkg/replicas/replica_test.go create mode 100644 pkg/replicas/replicas.go diff --git a/pkg/replicas/export_test.go b/pkg/replicas/export_test.go new file mode 100644 index 00000000000..562dc9512c3 --- /dev/null +++ b/pkg/replicas/export_test.go @@ -0,0 +1,16 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package replicas + +import "github.com/ethersphere/bee/pkg/storage" + +var ( + Counts = counts + Signer = signer +) + +func Wait(g storage.Getter) { + g.(*getter).wg.Wait() +} diff --git a/pkg/replicas/getter.go b/pkg/replicas/getter.go new file mode 100644 index 00000000000..8b335130fa7 --- /dev/null +++ b/pkg/replicas/getter.go @@ -0,0 +1,163 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// the code below implements the integration of dispersed replicas in chunk fetching. +// using storage.Getter interface. 
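A minimal usage sketch of the getter introduced in this new file (the store, ctx and addr values are assumed for illustration and are not part of the patch): callers wrap any storage.Getter and keep calling plain Get; the wrapper races the original address against its dispersed replica addresses.

	g := replicas.NewGetter(store, redundancy.PARANOID)
	// returns the first chunk retrieved, either the original CAC
	// or the content unwrapped from one of its soc replicas
	ch, err := g.Get(ctx, addr)
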
+package replicas
+
+import (
+	"context"
+	"errors"
+	"sync"
+	"time"
+
+	"github.com/ethersphere/bee/pkg/file/redundancy"
+	"github.com/ethersphere/bee/pkg/soc"
+	"github.com/ethersphere/bee/pkg/storage"
+	"github.com/ethersphere/bee/pkg/swarm"
+)
+
+// ErrSwarmageddon is returned in case of vis maior, also known as Swarmageddon.
+// Swarmageddon is the situation when none of the replicas can be retrieved.
+// If 2^{depth} replicas were uploaded and they all have valid postage stamps
+// then the probability of Swarmageddon is less than 0.000001
+// assuming the error rate of chunk retrievals stays below the level expressed
+// as depth by the publisher.
+type ErrSwarmageddon struct {
+	error
+}
+
+func (err *ErrSwarmageddon) Unwrap() []error {
+	if err == nil || err.error == nil {
+		return nil
+	}
+	var uwe interface{ Unwrap() []error }
+	if !errors.As(err.error, &uwe) {
+		return nil
+	}
+	return uwe.Unwrap()
+}
+
+// getter is the private implementation of storage.Getter, an interface for
+// retrieving chunks. This getter embeds the original simple chunk getter and extends it
+// to a multiplexed variant that fetches chunks with replicas.
+//
+// the strategy to retrieve a chunk that has replicas can be configured with a few parameters:
+// - RetryInterval: the delay before a new batch of replicas is fetched.
+// - depth: 2^{depth} is the total number of additional replicas that have been uploaded
+//   (by default, it is assumed to be 4, i.e. a total of 16)
+// - (not implemented) pivot: replicas with address in the proximity of pivot will be tried first
+type getter struct {
+	wg sync.WaitGroup
+	storage.Getter
+	level redundancy.Level
+}
+
+// NewGetter is the getter constructor
+func NewGetter(g storage.Getter, level redundancy.Level) storage.Getter {
+	return &getter{Getter: g, level: level}
+}
+
+// Get makes the getter satisfy the storage.Getter interface
+func (g *getter) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) {
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	// channel that the results (retrieved chunks) are gathered to from concurrent
+	// workers each fetching a replica
+	resultC := make(chan swarm.Chunk)
+	// errc collects the errors
+	errc := make(chan error, 17)
+	var errs []error
+	errcnt := 0
+
+	// concurrently attempt to retrieve the chunk using its original CAC address
+	g.wg.Add(1)
+	go func() {
+		defer g.wg.Done()
+		ch, err := g.Getter.Get(ctx, addr)
+		if err != nil {
+			errc <- err
+			return
+		}
+
+		select {
+		case resultC <- ch:
+		case <-ctx.Done():
+		}
+	}()
+	// counters
+	n := 0      // counts the replica addresses tried
+	target := 2 // the number of replicas attempted to download in this batch
+	total := counts[g.level] // total number of replicas it makes sense to retrieve
+
+	//
+	rr := newReplicator(addr, uint8(g.level))
+	next := rr.c
+	var wait <-chan time.Time // nil channel to disable case
+	// addresses used are doubling each period of search expansion
+	// (at intervals of RetryInterval)
+	ticker := time.NewTicker(RetryInterval)
+	defer ticker.Stop()
+	for level := uint8(0); level <= uint8(g.level); {
+		select {
+		// at least one chunk is retrieved, cancel the rest and return early
+		case chunk := <-resultC:
+			cancel()
+			return chunk, nil
+
+		case err = <-errc:
+			errs = append(errs, err)
+			errcnt++
+			if errcnt > total {
+				return nil, &ErrSwarmageddon{errors.Join(errs...)}
+			}
+
+		// ticker switches on the address channel
+		case <-wait:
+			wait = nil
+			next = rr.c
+			level++
+			target = 1 << level
+			n = 0
+			continue
+
+		// 
getting the addresses in order + case so := <-next: + if so == nil { + next = nil + continue + } + + g.wg.Add(1) + go func() { + defer g.wg.Done() + ch, err := g.Getter.Get(ctx, swarm.NewAddress(so.addr)) + if err != nil { + errc <- err + return + } + + soc, err := soc.FromChunk(ch) + if err != nil { + errc <- err + return + } + + select { + case resultC <- soc.WrappedChunk(): + case <-ctx.Done(): + } + }() + n++ + if n < target { + continue + } + next = nil + wait = ticker.C + } + } + + return nil, nil +} diff --git a/pkg/replicas/getter_test.go b/pkg/replicas/getter_test.go new file mode 100644 index 00000000000..2664e9a6a14 --- /dev/null +++ b/pkg/replicas/getter_test.go @@ -0,0 +1,271 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package replicas_test + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "io" + "sync/atomic" + "testing" + "time" + + "github.com/ethersphere/bee/pkg/cac" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/replicas" + "github.com/ethersphere/bee/pkg/soc" + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" +) + +type testGetter struct { + ch swarm.Chunk + now time.Time + origCalled chan struct{} + origIndex int + errf func(int) chan struct{} + firstFound int32 + attempts atomic.Int32 + cancelled chan struct{} + addresses [17]swarm.Address + latencies [17]time.Duration +} + +func (tg *testGetter) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) { + i := tg.attempts.Add(1) - 1 + tg.addresses[i] = addr + tg.latencies[i] = time.Since(tg.now) + + if addr.Equal(tg.ch.Address()) { + tg.origIndex = int(i) + close(tg.origCalled) + ch = tg.ch + } + + if i != tg.firstFound { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-tg.errf(int(i)): + return nil, storage.ErrNotFound + } + } + defer func() { + go func() { + select { + case <-time.After(100 * time.Millisecond): + case <-ctx.Done(): + close(tg.cancelled) + } + }() + }() + + if ch != nil { + return ch, nil + } + return soc.New(addr.Bytes(), tg.ch).Sign(replicas.Signer) +} + +func newTestGetter(ch swarm.Chunk, firstFound int, errf func(int) chan struct{}) *testGetter { + return &testGetter{ + ch: ch, + errf: errf, + firstFound: int32(firstFound), + cancelled: make(chan struct{}), + origCalled: make(chan struct{}), + } +} + +// Close implements the storage.Getter interface +func (tg *testGetter) Close() error { + return nil +} + +func TestGetter(t *testing.T) { + t.Parallel() + // failure is a struct that defines a failure scenario to test + type failure struct { + name string + err error + errf func(int, int) func(int) chan struct{} + } + // failures is a list of failure scenarios to test + failures := []failure{ + { + "timeout", + context.Canceled, + func(_, _ int) func(i int) chan struct{} { + return func(i int) chan struct{} { + return nil + } + }, + }, + { + "not found", + storage.ErrNotFound, + func(_, _ int) func(i int) chan struct{} { + c := make(chan struct{}) + close(c) + return func(i int) chan struct{} { + return c + } + }, + }, + } + type test struct { + name string + failure failure + level int + count int + found int + } + + var tests []test + for _, f := range failures { + for level, c := range replicas.Counts { + for j := 0; j <= c*2+1; j++ { + tests = append(tests, test{ + name: fmt.Sprintf("%s level %d count %d found %d", f.name, level, c, j), + failure: f, + 
level: level,
+					count:   c,
+					found:   j,
+				})
+			}
+		}
+	}
+
+	// initialise the base chunk
+	chunkLen := 420
+	buf := make([]byte, chunkLen)
+	if _, err := io.ReadFull(rand.Reader, buf); err != nil {
+		t.Fatal(err)
+	}
+	ch, err := cac.New(buf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// reset retry interval to speed up tests
+	retryInterval := replicas.RetryInterval
+	defer func() { replicas.RetryInterval = retryInterval }()
+	replicas.RetryInterval = 100 * time.Millisecond
+
+	// run the tests
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			// initiate a chunk retrieval session using replicas.Getter
+			// embedding a testGetter that simulates the behaviour of a chunk store
+			store := newTestGetter(ch, tc.found, tc.failure.errf(tc.found, tc.count))
+			g := replicas.NewGetter(store, redundancy.Level(tc.level))
+			store.now = time.Now()
+			ctx, cancel := context.WithCancel(context.Background())
+			if tc.found > tc.count {
+				wait := replicas.RetryInterval / 2 * time.Duration(1+2*tc.level)
+				go func() {
+					time.Sleep(wait)
+					cancel()
+				}()
+			}
+			_, err := g.Get(ctx, ch.Address())
+			replicas.Wait(g)
+			cancel()
+
+			// test the returned error
+			if tc.found <= tc.count {
+				if err != nil {
+					t.Fatalf("expected no error. got %v", err)
+				}
+				// if j <= c, the original chunk should be retrieved and the context should be cancelled
+				t.Run("retrievals cancelled", func(t *testing.T) {
+
+					select {
+					case <-time.After(100 * time.Millisecond):
+						t.Fatal("timed out waiting for context to be cancelled")
+					case <-store.cancelled:
+					}
+				})
+
+			} else {
+				if err == nil {
+					t.Fatal("expected error. got nil")
+				}
+
+				t.Run("returns correct error", func(t *testing.T) {
+					var esg *replicas.ErrSwarmageddon
+					if !errors.As(err, &esg) {
+						t.Fatalf("incorrect error. want Swarmageddon. got %v", err)
+					}
+					errs := esg.Unwrap()
+					for _, err := range errs {
+						if !errors.Is(err, tc.failure.err) {
+							t.Fatalf("incorrect error. want it to wrap %v. got %v", tc.failure.err, err)
+						}
+					}
+					if len(errs) != tc.count+1 {
+						t.Fatalf("incorrect error. want %d. got %d", tc.count+1, len(errs))
+					}
+				})
+			}
+
+			attempts := int(store.attempts.Load())
+			// the original chunk should be among those attempted for retrieval
+			addresses := store.addresses[:attempts]
+			latencies := store.latencies[:attempts]
+			t.Run("original address called", func(t *testing.T) {
+				select {
+				case <-time.After(100 * time.Millisecond):
+					t.Fatal("timed out waiting for the original address to be attempted for retrieval")
+				case <-store.origCalled:
+					i := store.origIndex
+					if i > 2 {
+						t.Fatalf("original address called too late. want at most 2 (preceding attempts). got %v (latency: %v)", i, latencies[i])
+					}
+					addresses = append(addresses[:i], addresses[i+1:]...)
+					latencies = append(latencies[:i], latencies[i+1:]...)
+					attempts--
+				}
+			})
+
+			t.Run("retrieved count", func(t *testing.T) {
+				if attempts > tc.count {
+					t.Fatalf("too many attempts to retrieve a replica: want at most %v. got %v.", tc.count, attempts)
+				}
+				if tc.found > tc.count {
+					if attempts < tc.count {
+						t.Fatalf("too few attempts to retrieve a replica: want at least %v. got %v.", tc.count, attempts)
+					}
+					return
+				}
+				max := 2
+				for i := 1; i < tc.level && max < tc.found; i++ {
+					max = max * 2
+				}
+				if attempts > max {
+					t.Fatalf("too many attempts to retrieve a replica: want at most %v. got %v. 
latencies %v", max, attempts, latencies) + } + }) + + t.Run("dispersion", func(t *testing.T) { + + if err := dispersed(redundancy.Level(tc.level), ch, addresses); err != nil { + t.Fatalf("addresses are not dispersed: %v", err) + } + }) + + t.Run("latency", func(t *testing.T) { + + for i, latency := range latencies { + multiplier := latency / replicas.RetryInterval + if multiplier > 0 && i < replicas.Counts[multiplier-1] { + t.Fatalf("incorrect latency for retrieving replica %d: %v", i, err) + } + } + }) + + }) + } +} diff --git a/pkg/replicas/putter.go b/pkg/replicas/putter.go new file mode 100644 index 00000000000..2bcceb2e2da --- /dev/null +++ b/pkg/replicas/putter.go @@ -0,0 +1,60 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// the code below implements the integration of dispersed replicas in chunk upload. +// using storage.Putter interface. +package replicas + +import ( + "context" + "errors" + "sync" + + "github.com/ethersphere/bee/pkg/soc" + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" +) + +// putter is the private implementation of the public storage.Putter interface +// putter extends the original putter to a concurrent multiputter +type putter struct { + putter storage.Putter +} + +// NewPutter is the putter constructor +func NewPutter(p storage.Putter) storage.Putter { + return &putter{p} +} + +// Put makes the getter satisfy the storage.Getter interface +func (p *putter) Put(ctx context.Context, ch swarm.Chunk) (err error) { + rlevel := getLevelFromContext(ctx) + errs := []error{p.putter.Put(ctx, ch)} + if rlevel == 0 { + return errs[0] + } + + rr := newReplicator(ch.Address(), uint8(rlevel)) + errc := make(chan error, counts[rlevel]) + wg := sync.WaitGroup{} + for r := range rr.c { + r := r + wg.Add(1) + go func() { + defer wg.Done() + sch, err := soc.New(r.id, ch).Sign(signer) + if err == nil { + err = p.putter.Put(ctx, sch) + } + errc <- err + }() + } + + wg.Wait() + close(errc) + for err := range errc { + errs = append(errs, err) + } + return errors.Join(errs...) +} diff --git a/pkg/replicas/putter_test.go b/pkg/replicas/putter_test.go new file mode 100644 index 00000000000..e53bfc6f1bf --- /dev/null +++ b/pkg/replicas/putter_test.go @@ -0,0 +1,195 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+
+package replicas_test
+
+import (
+	"context"
+	"crypto/rand"
+	"errors"
+	"fmt"
+	"io"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/ethersphere/bee/pkg/cac"
+	"github.com/ethersphere/bee/pkg/file/redundancy"
+	"github.com/ethersphere/bee/pkg/replicas"
+	"github.com/ethersphere/bee/pkg/storage"
+	"github.com/ethersphere/bee/pkg/storage/inmemchunkstore"
+	"github.com/ethersphere/bee/pkg/swarm"
+)
+
+var (
+	errTestA = errors.New("A")
+	errTestB = errors.New("B")
+)
+
+type testBasePutter struct {
+	getErrors func(context.Context, swarm.Address) error
+	putErrors func(context.Context, swarm.Address) error
+	store     storage.ChunkStore
+}
+
+func (tbp *testBasePutter) Get(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) {
+
+	g := tbp.getErrors
+	if g != nil {
+		return nil, g(ctx, addr)
+	}
+	return tbp.store.Get(ctx, addr)
+}
+
+func (tbp *testBasePutter) Put(ctx context.Context, ch swarm.Chunk) error {
+
+	g := tbp.putErrors
+	if g != nil {
+		return g(ctx, ch.Address())
+	}
+	return tbp.store.Put(ctx, ch)
+}
+
+func TestPutter(t *testing.T) {
+	t.Parallel()
+	tcs := []struct {
+		level  redundancy.Level
+		length int
+	}{
+		{0, 1},
+		{1, 1},
+		{2, 1},
+		{3, 1},
+		{4, 1},
+		{0, 4096},
+		{1, 4096},
+		{2, 4096},
+		{3, 4096},
+		{4, 4096},
+	}
+	for _, tc := range tcs {
+		t.Run(fmt.Sprintf("redundancy:%d, size:%d", tc.level, tc.length), func(t *testing.T) {
+			buf := make([]byte, tc.length)
+			if _, err := io.ReadFull(rand.Reader, buf); err != nil {
+				t.Fatal(err)
+			}
+			ctx := context.Background()
+			ctx = replicas.SetLevel(ctx, tc.level)
+
+			ch, err := cac.New(buf)
+			if err != nil {
+				t.Fatal(err)
+			}
+			store := inmemchunkstore.New()
+			defer store.Close()
+			p := replicas.NewPutter(store)
+
+			if err := p.Put(ctx, ch); err != nil {
+				t.Fatalf("expected no error. got %v", err)
+			}
+			var addrs []swarm.Address
+			orig := false
+			_ = store.Iterate(ctx, func(chunk swarm.Chunk) (stop bool, err error) {
+				if ch.Address().Equal(chunk.Address()) {
+					orig = true
+					return false, nil
+				}
+				addrs = append(addrs, chunk.Address())
+				return false, nil
+			})
+			if !orig {
+				t.Fatal("original chunk missing")
+			}
+			t.Run("dispersion", func(t *testing.T) {
+				if err := dispersed(tc.level, ch, addrs); err != nil {
+					t.Fatalf("addresses are not dispersed: %v", err)
+				}
+			})
+			t.Run("attempts", func(t *testing.T) {
+				count := replicas.Counts[tc.level]
+				if len(addrs) != count {
+					t.Fatalf("incorrect number of attempts. 
want %v, got %v", count, len(addrs)) + } + }) + + t.Run("replication", func(t *testing.T) { + if err := replicated(store, ch, addrs); err != nil { + t.Fatalf("chunks are not replicas: %v", err) + } + }) + }) + } + t.Run("error handling", func(t *testing.T) { + tcs := []struct { + name string + level redundancy.Level + length int + f func(*testBasePutter) *testBasePutter + err []error + }{ + {"put errors", 4, 4096, func(tbp *testBasePutter) *testBasePutter { + var j int32 + i := &j + atomic.StoreInt32(i, 0) + tbp.putErrors = func(ctx context.Context, _ swarm.Address) error { + j := atomic.AddInt32(i, 1) + if j == 6 { + return errTestA + } + if j == 12 { + return errTestB + } + return nil + } + return tbp + }, []error{errTestA, errTestB}}, + {"put latencies", 4, 4096, func(tbp *testBasePutter) *testBasePutter { + var j int32 + i := &j + atomic.StoreInt32(i, 0) + tbp.putErrors = func(ctx context.Context, _ swarm.Address) error { + j := atomic.AddInt32(i, 1) + if j == 6 { + select { + case <-time.After(100 * time.Millisecond): + case <-ctx.Done(): + return ctx.Err() + } + } + if j == 12 { + return errTestA + } + return nil + } + return tbp + }, []error{errTestA, context.DeadlineExceeded}}, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + buf := make([]byte, tc.length) + if _, err := io.ReadFull(rand.Reader, buf); err != nil { + t.Fatal(err) + } + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) + defer cancel() + ctx = replicas.SetLevel(ctx, tc.level) + ch, err := cac.New(buf) + if err != nil { + t.Fatal(err) + } + store := inmemchunkstore.New() + defer store.Close() + p := replicas.NewPutter(tc.f(&testBasePutter{store: store})) + errs := p.Put(ctx, ch) + for _, err := range tc.err { + if !errors.Is(errs, err) { + t.Fatalf("incorrect error. want it to contain %v. got %v.", tc.err, errs) + } + } + }) + } + }) + +} diff --git a/pkg/replicas/replica_test.go b/pkg/replicas/replica_test.go new file mode 100644 index 00000000000..ce56401d987 --- /dev/null +++ b/pkg/replicas/replica_test.go @@ -0,0 +1,59 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// replicas_test just contains helper functions to verify dispersion and replication +package replicas_test + +import ( + "context" + "errors" + "fmt" + + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/soc" + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" +) + +// dispersed verifies that a set of addresses are maximally dispersed without repetition +func dispersed(level redundancy.Level, ch swarm.Chunk, addrs []swarm.Address) error { + nhoods := make(map[byte]bool) + + for _, addr := range addrs { + if len(addr.Bytes()) != swarm.HashSize { + return errors.New("corrupt data: invalid address length") + } + nh := addr.Bytes()[0] >> (8 - int(level)) + if nhoods[nh] { + return errors.New("not dispersed enough: duplicate neighbourhood") + } + nhoods[nh] = true + } + if len(nhoods) != len(addrs) { + return fmt.Errorf("not dispersed enough: unexpected number of neighbourhood covered: want %v. 
got %v", len(addrs), len(nhoods)) + } + + return nil +} + +// replicated verifies that the replica chunks are indeed replicas +// of the original chunk wrapped in soc +func replicated(store storage.ChunkStore, ch swarm.Chunk, addrs []swarm.Address) error { + ctx := context.Background() + for _, addr := range addrs { + chunk, err := store.Get(ctx, addr) + if err != nil { + return err + } + + sch, err := soc.FromChunk(chunk) + if err != nil { + return err + } + if !sch.WrappedChunk().Equal(ch) { + return errors.New("invalid replica: does not wrap original content addressed chunk") + } + } + return nil +} diff --git a/pkg/replicas/replicas.go b/pkg/replicas/replicas.go new file mode 100644 index 00000000000..65806abba7c --- /dev/null +++ b/pkg/replicas/replicas.go @@ -0,0 +1,145 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package replicas implements a scheme to replicate chunks +// in such a way that +// - the replicas are optimally dispersed to aid cross-neighbourhood redundancy +// - the replicas addresses can be deduced by retrievers only knowing the address +// of the original content addressed chunk +// - no new chunk validation rules are introduced +package replicas + +import ( + "context" + "encoding/hex" + "time" + + "github.com/ethersphere/bee/pkg/crypto" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/swarm" +) + +type redundancyLevelType struct{} + +var ( + // redundancyLevel is the context key for the redundancy level + redundancyLevel redundancyLevelType + // RetryInterval is the duration between successive additional requests + RetryInterval = 300 * time.Millisecond + // + // counts of replicas used for levels of increasing security + // the actual number of replicas needed to keep the error rate below 1/10^6 + // for the five levels of redundancy are 0, 2, 4, 5, 19 + // we use an approximation as the successive powers of 2 + counts = [5]int{0, 2, 4, 8, 16} + sums = [5]int{0, 2, 6, 14, 30} + privKey, _ = crypto.DecodeSecp256k1PrivateKey(append([]byte{1}, make([]byte, 31)...)) + signer = crypto.NewDefaultSigner(privKey) + owner, _ = hex.DecodeString("dc5b20847f43d67928f49cd4f85d696b5a7617b5") +) + +// SetLevel sets the redundancy level in the context +func SetLevel(ctx context.Context, level redundancy.Level) context.Context { + return context.WithValue(ctx, redundancyLevel, level) +} + +// getLevelFromContext is a helper function to extract the redundancy level from the context +func getLevelFromContext(ctx context.Context) redundancy.Level { + rlevel := redundancy.PARANOID + if val := ctx.Value(redundancyLevel); val != nil { + rlevel = val.(redundancy.Level) + } + return rlevel +} + +// replicator running the find for replicas +type replicator struct { + addr []byte // chunk address + queue [16]*replica // to sort addresses according to di + exist [30]bool // maps the 16 distinct nibbles on all levels + sizes [5]int // number of distinct neighnourhoods redcorded for each depth + c chan *replica + depth uint8 +} + +// newReplicator replicator constructor +func newReplicator(addr swarm.Address, depth uint8) *replicator { + rr := &replicator{ + addr: addr.Bytes(), + sizes: counts, + c: make(chan *replica, 16), + depth: depth, + } + go rr.replicas() + return rr +} + +// replica of the mined SOC chunk (address) that serve as replicas +type replica struct { + addr, id []byte // byte slice of SOC address and SOC ID +} + 
+// replicate returns a replica parameter structure seeded with a byte of entropy as argument
+func (rr *replicator) replicate(i uint8) (sp *replica) {
+	// change the first byte of the address to create the SOC ID
+	id := make([]byte, 32)
+	copy(id, rr.addr)
+	id[0] = i
+	// calculate SOC address for potential replica
+	h := swarm.NewHasher()
+	_, _ = h.Write(id)
+	_, _ = h.Write(owner)
+	return &replica{h.Sum(nil), id}
+}
+
+// nh returns the lookup key for the neighbourhood of depth d
+// to be used as index to the replicator's exist array
+func (r *replica) nh(d uint8) (nh int) {
+	return sums[d-1] + int(r.addr[0]>>(8-d))
+}
+
+// replicas enumerates replica parameters (SOC IDs) and pushes them into the replicator's channel
+// the order of replicas is so that addresses are always maximally dispersed
+// in successive sets of addresses.
+// I.e., the binary tree representing the new addresses' prefix bits up to depth is balanced
+func (rr *replicator) replicas() {
+	defer close(rr.c)
+	n := 0
+	for i := uint8(0); n < counts[rr.depth] && i < 255; i++ {
+		// create soc replica (ID and address using constant owner)
+		// the soc is added to neighbourhoods of depths in the closed interval [from...to]
+		r := rr.replicate(i)
+		d, m := rr.add(r, rr.depth)
+		if d == 0 {
+			continue
+		}
+		for m, r = range rr.queue[n:] {
+			if r == nil {
+				break
+			}
+			rr.c <- r
+		}
+		n += m
+	}
+}
+
+// add inserts the soc replica into a replicator so that addresses are balanced
+func (rr *replicator) add(r *replica, d uint8) (depth int, rank int) {
+	if d == 0 {
+		return 0, 0
+	}
+	nh := r.nh(d)
+	if rr.exist[nh] {
+		return 0, 0
+	}
+	rr.exist[nh] = true
+	l, o := rr.add(r, d-1)
+	if l == 0 {
+		o = rr.sizes[d-1]
+		rr.sizes[d-1]++
+		rr.queue[o] = r
+		l = int(d)
+	}
+	return l, o
+}
From 5d1b73c48d7793d6a8e3eadf57603417479772e2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?=
Date: Sun, 10 Dec 2023 21:55:10 +0100
Subject: [PATCH 05/23] refactor: reviews

---
 openapi/Swarm.yaml                          |  2 +-
 openapi/SwarmCommon.yaml                    |  2 +-
 pkg/api/api.go                              |  4 +-
 pkg/api/bzz.go                              |  3 +-
 pkg/api/dirs.go                             | 12 +++--
 pkg/encryption/encryption.go                |  7 ---
 pkg/file/joiner/joiner.go                   |  9 ++--
 pkg/file/joiner/joiner_test.go              |  9 ++--
 pkg/file/pipeline/feeder/feeder.go          |  1 -
 pkg/file/pipeline/hashtrie/hashtrie.go      |  5 +-
 pkg/file/pipeline/hashtrie/hashtrie_test.go |  2 +-
 pkg/file/redundancy/export_test.go          | 11 +++-
 pkg/file/redundancy/getter/export_test.go   | 23 +++++++++
 pkg/file/redundancy/getter/getter.go        | 57 ++++++---------------
 pkg/file/redundancy/level.go                | 34 ++++++++----
 pkg/file/redundancy/redundancy.go           | 41 ++++-----------
 pkg/file/redundancy/redundancy_test.go      | 14 ++---
 pkg/file/redundancy/table.go                | 16 +-----
 18 files changed, 111 insertions(+), 141 deletions(-)
 create mode 100644 pkg/file/redundancy/getter/export_test.go

diff --git a/openapi/Swarm.yaml b/openapi/Swarm.yaml
index a8d6f44c8fe..7445f98e88c 100644
--- a/openapi/Swarm.yaml
+++ b/openapi/Swarm.yaml
@@ -122,7 +122,7 @@ paths:
         required: false
       - in: header
         schema:
-          $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyLevel"
+          $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyLevelParameter"
         name: swarm-redundancy-level
         required: false

diff --git a/openapi/SwarmCommon.yaml b/openapi/SwarmCommon.yaml
index c5ac5674fe9..1c2e1bbe77d 100644
--- a/openapi/SwarmCommon.yaml
+++ b/openapi/SwarmCommon.yaml
@@ -934,7 +934,7 @@ components:
       description: >
         Represents the encrypting state of the file
 
-    SwarmRedundancyParameter:
+    SwarmRedundancyLevelParameter:
       in: header
name: swarm-redundancy-level
      schema:
diff --git a/pkg/api/api.go b/pkg/api/api.go
index 2dd583b7df4..ff7bbaa0d4c 100644
--- a/pkg/api/api.go
+++ b/pkg/api/api.go
@@ -80,7 +80,7 @@ const (
 	SwarmCollectionHeader     = "Swarm-Collection"
 	SwarmPostageBatchIdHeader = "Swarm-Postage-Batch-Id"
 	SwarmDeferredUploadHeader = "Swarm-Deferred-Upload"
-	SwarmRLevel               = "Swarm-Redundancy-Level"
+	SwarmRedunancyLevel       = "Swarm-Redundancy-Level"
 
 	ImmutableHeader = "Immutable"
 	GasPriceHeader  = "Gas-Price"
@@ -624,7 +624,7 @@ func (s *Service) corsHandler(h http.Handler) http.Handler {
 	allowedHeaders := []string{
 		"User-Agent", "Accept", "X-Requested-With", "Access-Control-Request-Headers", "Access-Control-Request-Method", "Accept-Ranges", "Content-Encoding",
 		AuthorizationHeader, AcceptEncodingHeader, ContentTypeHeader, ContentDispositionHeader, RangeHeader, OriginHeader,
-		SwarmTagHeader, SwarmPinHeader, SwarmEncryptHeader, SwarmIndexDocumentHeader, SwarmErrorDocumentHeader, SwarmCollectionHeader, SwarmPostageBatchIdHeader, SwarmDeferredUploadHeader, SwarmRLevel,
+		SwarmTagHeader, SwarmPinHeader, SwarmEncryptHeader, SwarmIndexDocumentHeader, SwarmErrorDocumentHeader, SwarmCollectionHeader, SwarmPostageBatchIdHeader, SwarmDeferredUploadHeader, SwarmRedunancyLevel,
 		GasPriceHeader, GasLimitHeader, ImmutableHeader,
 	}
 	allowedHeadersStr := strings.Join(allowedHeaders, ", ")
diff --git a/pkg/api/bzz.go b/pkg/api/bzz.go
index a479c6829f3..a4b06f738cb 100644
--- a/pkg/api/bzz.go
+++ b/pkg/api/bzz.go
@@ -16,8 +16,6 @@ import (
 	"strings"
 	"time"
 
-	"github.com/ethersphere/bee/pkg/topology"
-
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethersphere/bee/pkg/feeds"
 	"github.com/ethersphere/bee/pkg/file/joiner"
@@ -30,6 +28,7 @@ import (
 	storage "github.com/ethersphere/bee/pkg/storage"
 	storer "github.com/ethersphere/bee/pkg/storer"
 	"github.com/ethersphere/bee/pkg/swarm"
+	"github.com/ethersphere/bee/pkg/topology"
 	"github.com/ethersphere/bee/pkg/tracing"
 	"github.com/ethersphere/langos"
 	"github.com/gorilla/mux"
diff --git a/pkg/api/dirs.go b/pkg/api/dirs.go
index 00b530b9137..44cece35832 100644
--- a/pkg/api/dirs.go
+++ b/pkg/api/dirs.go
@@ -64,10 +64,14 @@ func (s *Service) dirUploadHandler(
 	}
 	defer r.Body.Close()
 
-	rsParity, err := strconv.ParseUint(r.Header.Get(SwarmRLevel), 10, 1)
+	rLevelNum, err := strconv.ParseUint(r.Header.Get(SwarmRedunancyLevel), 10, 1)
 	if err != nil {
-		logger.Debug("store dir failed", "rsParity parsing error")
-		logger.Error(nil, "store dir failed")
+		logger.Debug("store directory failed", "redundancy level parsing error")
+		logger.Error(nil, "store directory failed")
+	}
+	rLevel, err := redundancy.NewLevel(uint8(rLevelNum))
+	if err != nil {
+		jsonhttp.BadRequest(w, err.Error())
 	}
 
 	reference, err := storeDir(
@@ -79,7 +83,7 @@
 		s.storer.ChunkStore(),
 		r.Header.Get(SwarmIndexDocumentHeader),
 		r.Header.Get(SwarmErrorDocumentHeader),
-		redundancy.Level(rsParity),
+		rLevel,
 	)
 	if err != nil {
 		logger.Debug("store dir failed", "error", err)
diff --git a/pkg/encryption/encryption.go b/pkg/encryption/encryption.go
index 1d3fc3d08d8..c0019dd76bd 100644
--- a/pkg/encryption/encryption.go
+++ b/pkg/encryption/encryption.go
@@ -185,10 +185,3 @@ func GenerateRandomKey(l int) Key {
 	}
 	return key
 }
-
-func min(a, b int) int {
-	if a < b {
-		return a
-	}
-	return b
-}
diff --git a/pkg/file/joiner/joiner.go b/pkg/file/joiner/joiner.go
index 5758584f006..349e7f1a759 100644
--- a/pkg/file/joiner/joiner.go
+++ b/pkg/file/joiner/joiner.go
@@ -51,10 +51,7 @@ func New(ctx 
context.Context, getter storage.Getter, putter storage.Putter, addr
 	chunkData := rootChunk.Data()
 	rootData := chunkData[swarm.SpanSize:]
 	refLength := len(address.Bytes())
-	encryption := false
-	if refLength != swarm.HashSize {
-		encryption = true
-	}
+	encryption := refLength != swarm.HashSize
 	rLevel, span := chunkToSpan(chunkData)
 	rootParity := 0
 	maxBranching := swarm.ChunkSize / refLength
@@ -355,8 +352,8 @@ func (j *joiner) Size() int64 {
 	return j.span
 }
 
-// UTILITIES
-
+// chunkToSpan returns redundancy level and span value
+// in the types that the package uses
 func chunkToSpan(data []byte) (redundancy.Level, int64) {
 	level, spanBytes := redundancy.DecodeSpan(data[:swarm.SpanSize])
 	return level, int64(bmt.LengthFromSpan(spanBytes))
diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go
index 67ac1c3f9c1..e15fb8277d6 100644
--- a/pkg/file/joiner/joiner_test.go
+++ b/pkg/file/joiner/joiner_test.go
@@ -1032,11 +1032,11 @@ func TestJoinerRedundancy(t *testing.T) {
 			chunkData := make([]byte, chunkSize)
 			_, err = joinReader.ReadAt(chunkData, offset)
 			if err != nil {
-				t.Fatalf("read error check at chunkdata comparisation on %d index: %s", i, err.Error())
+				t.Fatalf("joiner %d: read data at offset %v: %v", i, offset, err.Error())
 			}
 			expectedChunkData := dataChunks[i].Data()[swarm.SpanSize:]
 			if !bytes.Equal(expectedChunkData, chunkData) {
-				t.Fatalf("read error check at chunkdata comparisation on %d index. Data are not the same", i)
+				t.Fatalf("joiner %d: data mismatch at offset %v", i, offset)
 			}
 			offset += int64(chunkSize)
 		}
@@ -1050,10 +1050,7 @@ func TestJoinerRedundancy(t *testing.T) {
 				maxShards = tc.rLevel.GetMaxEncShards()
 				maxParities = tc.rLevel.GetEncParities(maxShards)
 			}
-			removeCount := maxParities
-			if maxParities > maxShards {
-				removeCount = maxShards
-			}
+			removeCount := min(maxParities, maxParities)
 			for i := 0; i < removeCount; i++ {
 				err := store.Delete(ctx, dataChunks[i].Address())
 				if err != nil {
diff --git a/pkg/file/pipeline/feeder/feeder.go b/pkg/file/pipeline/feeder/feeder.go
index 9b4b60cae61..03d77cfb49c 100644
--- a/pkg/file/pipeline/feeder/feeder.go
+++ b/pkg/file/pipeline/feeder/feeder.go
@@ -75,7 +75,6 @@ func (f *chunkFeeder) Write(b []byte) (int, error) {
 		sp += n
 
 	binary.LittleEndian.PutUint64(d[:span], uint64(sp))
-
 	args := &pipeline.PipeWriteArgs{Data: d[:span+sp], Span: d[:span]}
 	err := f.next.ChainWrite(args)
 	if err != nil {
diff --git a/pkg/file/pipeline/hashtrie/hashtrie.go b/pkg/file/pipeline/hashtrie/hashtrie.go
index fc1f040aca5..3091b21bb8f 100644
--- a/pkg/file/pipeline/hashtrie/hashtrie.go
+++ b/pkg/file/pipeline/hashtrie/hashtrie.go
@@ -26,14 +26,14 @@ type hashTrieWriter struct {
 	buffer     []byte // keeps intermediate level data
 	full       bool   // indicates whether the trie is full. currently we support (128^7)*4096 = 2305843009213693952 bytes
 	pipelineFn pipeline.PipelineFunc
-	rParams    redundancy.IParams
+	rParams    redundancy.RedundancyParams
 	parityChunkFn redundancy.ParityChunkCallback
 	chunkCounters []uint8 // counts the chunk references in intermediate chunks. key is the chunk level.
 	effectiveChunkCounters []uint8 // counts the effective chunk references in intermediate chunks. key is the chunk level.
 	maxChildrenChunks uint8 // maximum number of chunk references in intermediate chunks. 
} -func NewHashTrieWriter(refLen int, rParams redundancy.IParams, pipelineFn pipeline.PipelineFunc) pipeline.ChainWriter { +func NewHashTrieWriter(refLen int, rParams redundancy.RedundancyParams, pipelineFn pipeline.PipelineFunc) pipeline.ChainWriter { h := &hashTrieWriter{ refSize: refLen, cursors: make([]int, 9), @@ -90,6 +90,7 @@ func (h *hashTrieWriter) writeToIntermediateLevel(level int, parityChunk bool, s return nil } +// writeToDataLevel caches data chunks and call writeToIntermediateLevel func (h *hashTrieWriter) writeToDataLevel(span, ref, key, data []byte) error { // write dataChunks to the level above err := h.writeToIntermediateLevel(1, false, span, ref, key) diff --git a/pkg/file/pipeline/hashtrie/hashtrie_test.go b/pkg/file/pipeline/hashtrie/hashtrie_test.go index 78b4c20b2d7..8b501b6b536 100644 --- a/pkg/file/pipeline/hashtrie/hashtrie_test.go +++ b/pkg/file/pipeline/hashtrie/hashtrie_test.go @@ -52,7 +52,7 @@ func newErasureHashTrieWriter( rLevel redundancy.Level, encryptChunks bool, intermediateChunkPipeline, parityChunkPipeline pipeline.ChainWriter, -) (redundancy.IParams, pipeline.ChainWriter) { +) (redundancy.RedundancyParams, pipeline.ChainWriter) { pf := func() pipeline.ChainWriter { lsw := store.NewStoreWriter(ctx, s, intermediateChunkPipeline) return bmt.NewBmtWriter(lsw) diff --git a/pkg/file/redundancy/export_test.go b/pkg/file/redundancy/export_test.go index 492c956ec88..7ea470a7e9c 100644 --- a/pkg/file/redundancy/export_test.go +++ b/pkg/file/redundancy/export_test.go @@ -4,5 +4,12 @@ package redundancy -var SetErasureEncoder = setErasureEncoder -var GetErasureEncoder = getErasureEncoder +// SetErasureEncoder changes erasureEncoderFunc to a new erasureEncoder facade +func SetErasureEncoder(f func(shards, parities int) (ErasureEncoder, error)) { + erasureEncoderFunc = f +} + +// GetErasureEncoder returns erasureEncoderFunc +func GetErasureEncoder() func(shards, parities int) (ErasureEncoder, error) { + return erasureEncoderFunc +} diff --git a/pkg/file/redundancy/getter/export_test.go b/pkg/file/redundancy/getter/export_test.go new file mode 100644 index 00000000000..e97c05bc499 --- /dev/null +++ b/pkg/file/redundancy/getter/export_test.go @@ -0,0 +1,23 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
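The predicates in this new test-export file work because the refactor in getter.go below turns each error type into a comparable primitive, so errors.Is matches by plain value equality once the chain is unwrapped. A minimal sketch of the idea (the wrapped error value is hypothetical, not part of the patch):

	err := fmt.Errorf("strategy failed: %w", cannotRecoverError(3))
	ok := errors.Is(err, cannotRecoverError(3)) // true: same type and value
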
+ +package getter + +import "errors" + +func IsCannotRecoverError(err error, missingChunks int) bool { + return errors.Is(err, cannotRecoverError(missingChunks)) +} + +func IsNotRecoveredError(err error, chAddress string) bool { + return errors.Is(err, notRecoveredError(chAddress)) +} + +func IsNoDataAddressIncludedError(err error, chAddress string) bool { + return errors.Is(err, noDataAddressIncludedError(chAddress)) +} + +func IsNoRedundancyError(err error, chAddress string) bool { + return errors.Is(err, noRedundancyError(chAddress)) +} diff --git a/pkg/file/redundancy/getter/getter.go b/pkg/file/redundancy/getter/getter.go index f7ef497743c..a23cd652ee6 100644 --- a/pkg/file/redundancy/getter/getter.go +++ b/pkg/file/redundancy/getter/getter.go @@ -18,52 +18,28 @@ import ( /// ERRORS -type cannotRecoverError struct { - missingChunks int -} +type cannotRecoverError int func (e cannotRecoverError) Error() string { - return fmt.Sprintf("redundancy getter: there are %d missing chunks in order to do recovery", e.missingChunks) -} - -func IsCannotRecoverError(err error, missingChunks int) bool { - return errors.Is(err, cannotRecoverError{missingChunks}) + return fmt.Sprintf("redundancy getter: there are %d missing chunks in order to do recovery", e) } -type isNotRecoveredError struct { - chAddress string -} - -func (e isNotRecoveredError) Error() string { - return fmt.Sprintf("redundancy getter: chunk with address %s is not recovered", e.chAddress) -} +type notRecoveredError string -func IsNotRecoveredError(err error, chAddress string) bool { - return errors.Is(err, isNotRecoveredError{chAddress}) +func (e notRecoveredError) Error() string { + return fmt.Sprintf("redundancy getter: chunk with address %s is not recovered", string(e)) } -type noDataAddressIncludedError struct { - chAddress string -} +type noDataAddressIncludedError string func (e noDataAddressIncludedError) Error() string { - return fmt.Sprintf("redundancy getter: no data shard address given with chunk address %s", e.chAddress) -} - -func IsNoDataAddressIncludedError(err error, chAddress string) bool { - return errors.Is(err, noDataAddressIncludedError{chAddress}) + return fmt.Sprintf("redundancy getter: no data shard address given with chunk address %s", string(e)) } -type noRedundancyError struct { - chAddress string -} +type noRedundancyError string func (e noRedundancyError) Error() string { - return fmt.Sprintf("redundancy getter: cannot get chunk %s because no redundancy added", e.chAddress) -} - -func IsNoRedundancyError(err error, chAddress string) bool { - return errors.Is(err, noRedundancyError{chAddress}) + return fmt.Sprintf("redundancy getter: cannot get chunk %s because no redundancy added", string(e)) } /// TYPES @@ -79,11 +55,11 @@ type inflightChunk struct { type getter struct { storage.Getter storage.Putter - mu sync.Mutex + mu sync.Mutex // guards erasureData and cache sAddresses []swarm.Address // shard addresses pAddresses []swarm.Address // parity addresses cache map[string]inflightChunk // map from chunk address to cached shard chunk data - erasureData [][]byte // data + parity shards for erasure decoding; TODO mutex + erasureData [][]byte // data + parity shards for erasure decoding; encrypted bool // swarm datashards are encrypted } @@ -95,7 +71,6 @@ func New(sAddresses, pAddresses []swarm.Address, g storage.Getter, p storage.Put n := shards + parities erasureData := make([][]byte, n) cache := make(map[string]inflightChunk, n) - // init cache for i, addr := range sAddresses { cache[addr.String()] = 
inflightChunk{
 			pos: i,
@@ -105,7 +80,7 @@ func New(sAddresses, pAddresses []swarm.Address, g storage.Getter, p storage.Put
 	for i, addr := range pAddresses {
 		cache[addr.String()] = inflightChunk{
 			pos: len(sAddresses) + i,
-			// no wait channel initialization is needed
+			// no wait channel initialization is needed since parity chunk addresses shouldn't be requested directly
 		}
 	}
 
@@ -127,7 +102,7 @@ func (g *getter) Get(ctx context.Context, addr swarm.Address) (swarm.Chunk, erro
 	cValue, ok := g.cache[addr.String()]
 	g.mu.Unlock()
 	if !ok || cValue.pos >= len(g.sAddresses) {
-		return nil, noDataAddressIncludedError{addr.String()}
+		return nil, noDataAddressIncludedError(addr.String())
 	}
 	if cValue.wait != nil { // equals to g.processing but does not need lock again
@@ -139,7 +114,7 @@
 		return ch, nil
 	}
 	if errors.Is(storage.ErrNotFound, err) && len(g.pAddresses) == 0 {
-		return nil, noRedundancyError{addr.String()}
+		return nil, noRedundancyError(addr.String())
 	}
 
 	// during the get, the recovery may have started by other process
@@ -265,7 +240,7 @@ func (g *getter) cautiousStrategy(ctx context.Context) error {
 	cancelContext()
 
 	if retrieved < requiredChunks {
-		return cannotRecoverError{requiredChunks - retrieved}
+		return cannotRecoverError(requiredChunks - retrieved)
 	}
 
 	return g.erasureDecode(ctx)
@@ -310,7 +285,7 @@
 // cacheDataToChunk transforms passed chunk data to legit swarm chunk
 func (g *getter) cacheDataToChunk(addr swarm.Address, chData []byte) (swarm.Chunk, error) {
 	if chData == nil {
-		return nil, isNotRecoveredError{addr.String()}
+		return nil, notRecoveredError(addr.String())
 	}
 	if g.encrypted {
 		data, err := store.DecryptChunkData(chData, addr.Bytes()[swarm.HashSize:])
diff --git a/pkg/file/redundancy/level.go b/pkg/file/redundancy/level.go
index 406e9b6c437..58470e8cf0e 100644
--- a/pkg/file/redundancy/level.go
+++ b/pkg/file/redundancy/level.go
@@ -6,21 +6,37 @@ package redundancy
 
 import (
 	"errors"
+	"fmt"
 
 	"github.com/ethersphere/bee/pkg/swarm"
 )
 
+// Level is the redundancy level
+// which carries information about how much redundancy should be added to data to remain retrievable with a 1-10^(-6) certainty
+// in different groups of expected chunk retrieval error rates (level values)
 type Level uint8
 
const (
+	// no redundancy will be added
 	NONE Level = iota
+	// expected 1% chunk retrieval error rate
 	MEDIUM
+	// expected 5% chunk retrieval error rate
 	STRONG
+	// expected 10% chunk retrieval error rate
 	INSANE
+	// expected 50% chunk retrieval error rate
 	PARANOID
 )
 
-const maxLevel = 8
+// NewLevel returns a Level corresponding to the passed number parameter;
+// it returns an error if there is no level for the passed number
+func NewLevel(n uint8) (Level, error) {
+	if n > uint8(PARANOID) {
+		return 0, fmt.Errorf("redundancy: number %d does not have corresponding level", n)
+	}
+	return Level(n), nil
+}
 
 // GetParities returns number of parities based on appendix F table 5
 func (l Level) GetParities(shards int) int {
@@ -49,13 +65,13 @@ func (l Level) GetEncParities(shards int) int {
 func (l Level) getErasureTable() (erasureTable, error) {
 	switch l {
 	case MEDIUM:
-		return *mediumEt, nil
+		return mediumEt, nil
 	case STRONG:
-		return *strongEt, nil
+		return strongEt, nil
 	case INSANE:
-		return *insaneEt, nil
+		return insaneEt, nil
 	case PARANOID:
-		return *paranoidEt, nil
+		return paranoidEt, nil
 	default:
 		return erasureTable{}, errors.New("redundancy: level NONE 
does not have erasure table")
 	}
 }
 
 func (l Level) getEncErasureTable() (erasureTable, error) {
 	switch l {
 	case MEDIUM:
-		return *encMediumEt, nil
+		return encMediumEt, nil
 	case STRONG:
-		return *encStrongEt, nil
+		return encStrongEt, nil
 	case INSANE:
-		return *encInsaneEt, nil
+		return encInsaneEt, nil
 	case PARANOID:
-		return *encParanoidEt, nil
+		return encParanoidEt, nil
 	default:
 		return erasureTable{}, errors.New("redundancy: level NONE does not have erasure table")
 	}
diff --git a/pkg/file/redundancy/redundancy.go b/pkg/file/redundancy/redundancy.go
index 3fe1bca2c43..7782b6cb4a4 100644
--- a/pkg/file/redundancy/redundancy.go
+++ b/pkg/file/redundancy/redundancy.go
@@ -15,13 +15,13 @@ import (
 // ParityChunkCallback is called when a new parity chunk has been created
 type ParityChunkCallback func(level int, span, address []byte) error
 
-type IParams interface {
-	MaxShards() int
+type RedundancyParams interface {
+	MaxShards() int // returns the maximum data shard number being used in an intermediate chunk
 	Level() Level
-	Parities(int) int
-	ChunkWrite(int, []byte, ParityChunkCallback) error
-	ElevateCarrierChunk(int, ParityChunkCallback) error
-	Encode(int, ParityChunkCallback) error
+	Parities(int) int // returns the optimal parity number for a given shard count
+	ChunkWrite(int, []byte, ParityChunkCallback) error // caches the chunk data on the given chunk level and calls encode
+	ElevateCarrierChunk(int, ParityChunkCallback) error // moves the carrier chunk to the level above
+	Encode(int, ParityChunkCallback) error // adds parities on the given level
 }
 
 type ErasureEncoder interface {
@@ -32,20 +32,6 @@ var erasureEncoderFunc = func(shards, parities int) (ErasureEncoder, error) {
 	return reedsolomon.New(shards, parities)
 }
 
-// setErasureEncoder changes erasureEncoderFunc to a new erasureEncoder facade
-//
-// used for testing
-func setErasureEncoder(f func(shards, parities int) (ErasureEncoder, error)) {
-	erasureEncoderFunc = f
-}
-
-// getErasureEncoder returns erasureEncoderFunc
-//
-// used for testing
-func getErasureEncoder() func(shards, parities int) (ErasureEncoder, error) {
-	return erasureEncoderFunc
-}
-
 type Params struct {
 	level    Level
 	pipeLine pipeline.PipelineFunc
@@ -69,7 +55,7 @@ func New(level Level, encryption bool, pipeLine pipeline.PipelineFunc) *Params {
 	// init dataBuffer for erasure coding
 	rsChunkLevels := 0
 	if level != NONE {
-		rsChunkLevels = maxLevel
+		rsChunkLevels = 8
 	}
 	Buffer := make([][][]byte, rsChunkLevels)
 	for i := 0; i < rsChunkLevels; i++ {
@@ -87,8 +73,6 @@ func New(level Level, encryption bool, pipeLine pipeline.PipelineFunc) *Params {
 	}
 }
 
-// ACCESSORS
-
 func (p *Params) MaxShards() int {
 	return p.maxShards
 }
@@ -97,8 +81,6 @@ func (p *Params) Level() Level {
 	return p.level
 }
 
-// METHODS
-
 func (p *Params) Parities(shards int) int {
 	if p.encryption {
 		return p.level.GetEncParities(shards)
@@ -148,12 +130,12 @@ func (p *Params) encode(chunkLevel int, callback ParityChunkCallback) error {
 	n := shards + parities
 
 	// realloc for parity chunks if it does not override the prev one
-	// caculate parity chunks
+	// calculate parity chunks
 	enc, err := erasureEncoderFunc(shards, parities)
 	if err != nil {
 		return err
 	}
-	// make parity data
+
 	pz := len(p.buffer[chunkLevel][0])
 	for i := shards; i < n; i++ {
 		p.buffer[chunkLevel][i] = make([]byte, pz)
@@ -162,12 +144,11 @@
 	if err != nil {
 		return err
 	}
-	// store and pass newly 
created parity chunks + for i := shards; i < n; i++ { chunkData := p.buffer[chunkLevel][i] span := chunkData[:swarm.SpanSize] - // store data chunk writer := p.pipeLine() args := pipeline.PipeWriteArgs{ Data: chunkData, @@ -178,13 +159,11 @@ func (p *Params) encode(chunkLevel int, callback ParityChunkCallback) error { return err } - // write parity chunk to the level above err = callback(chunkLevel+1, span, args.Ref) if err != nil { return err } } - // reset cursor of dataBuffer in case it was a full chunk p.cursor[chunkLevel] = 0 return nil diff --git a/pkg/file/redundancy/redundancy_test.go b/pkg/file/redundancy/redundancy_test.go index a968cb43386..bb7aa35ba75 100644 --- a/pkg/file/redundancy/redundancy_test.go +++ b/pkg/file/redundancy/redundancy_test.go @@ -17,8 +17,6 @@ import ( "github.com/ethersphere/bee/pkg/swarm" ) -// MOCK ENCODER - type mockEncoder struct { shards, parities int } @@ -43,8 +41,6 @@ func (m *mockEncoder) Encode(buffer [][]byte) error { return nil } -// PARITY CHAIN WRITER - type ParityChainWriter struct { sync.Mutex chainWriteCalls int @@ -56,8 +52,6 @@ func NewParityChainWriter() *ParityChainWriter { return &ParityChainWriter{} } -// ACCESSORS - func (c *ParityChainWriter) ChainWriteCalls() int { c.Lock() defer c.Unlock() @@ -65,8 +59,6 @@ func (c *ParityChainWriter) ChainWriteCalls() int { } func (c *ParityChainWriter) SumCalls() int { c.Lock(); defer c.Unlock(); return c.sumCalls } -// METHODS - func (c *ParityChainWriter) ChainWrite(args *pipeline.PipeWriteArgs) error { c.Lock() defer c.Unlock() @@ -117,18 +109,18 @@ func TestEncode(t *testing.T) { buffer := make([]byte, 32) _, err := io.ReadFull(rand.Reader, buffer) if err != nil { - t.Error(err) + t.Fatal(err) } err = params.ChunkWrite(0, buffer, parityCallback) if err != nil { - t.Error(err) + t.Fatal(err) } } if shardCount != maxShards { // encode should be called automatically when reaching maxshards err := params.Encode(0, parityCallback) if err != nil { - t.Error(err) + t.Fatal(err) } } diff --git a/pkg/file/redundancy/table.go b/pkg/file/redundancy/table.go index 1b62e1ec34d..c993db310c6 100644 --- a/pkg/file/redundancy/table.go +++ b/pkg/file/redundancy/table.go @@ -4,8 +4,6 @@ package redundancy -import "fmt" - type erasureTable struct { shards []int parities []int @@ -18,7 +16,7 @@ type erasureTable struct { // shards := []int{94, 68, 46, 28, 14, 5, 1} // parities := []int{9, 8, 7, 6, 5, 4, 3} // var et = newErasureTable(shards, parities) -func newErasureTable(shards, parities []int) *erasureTable { +func newErasureTable(shards, parities []int) erasureTable { if len(shards) != len(parities) { panic("redundancy table: shards and parities arrays must be of equal size") } @@ -37,7 +35,7 @@ func newErasureTable(shards, parities []int) *erasureTable { maxShards, maxParities = s, p } - return &erasureTable{ + return erasureTable{ shards: shards, parities: parities, } @@ -52,13 +50,3 @@ func (et *erasureTable) getParities(maxShards int) int { } return 0 } - -// getMinShards returns back the minimum shard number respect to the given parity number -func (et *erasureTable) GetMinShards(parities int) (int, error) { - for k, p := range et.parities { - if p == parities { - return et.shards[k], nil - } - } - return 0, fmt.Errorf("parity table: there is no minimum shard number for given parity %d", parities) -} From 485b5ac26140c17477bc38a6535c3e4e2fa77be0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?= Date: Tue, 12 Dec 2023 15:30:43 +0100 Subject: [PATCH 06/23] fix: typeo --- 
pkg/file/joiner/joiner_test.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go
index e15fb8277d6..ca568bddd2b 100644
--- a/pkg/file/joiner/joiner_test.go
+++ b/pkg/file/joiner/joiner_test.go
@@ -1050,7 +1050,7 @@ func TestJoinerRedundancy(t *testing.T) {
 				maxShards = tc.rLevel.GetMaxEncShards()
 				maxParities = tc.rLevel.GetEncParities(maxShards)
 			}
-			removeCount := min(maxParities, maxParities)
+			removeCount := min(maxParities, maxShards)
 			for i := 0; i < removeCount; i++ {
 				err := store.Delete(ctx, dataChunks[i].Address())
 				if err != nil {

From 4243730284ec656adc7cdfac393ef81b3f91b979 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?=
Date: Tue, 12 Dec 2023 15:47:20 +0100
Subject: [PATCH 07/23] fix: @ldeffenb review

---
 pkg/api/api.go  | 4 ++--
 pkg/api/dirs.go | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pkg/api/api.go b/pkg/api/api.go
index ff7bbaa0d4c..58dba18c377 100644
--- a/pkg/api/api.go
+++ b/pkg/api/api.go
@@ -80,7 +80,7 @@ const (
 	SwarmCollectionHeader     = "Swarm-Collection"
 	SwarmPostageBatchIdHeader = "Swarm-Postage-Batch-Id"
 	SwarmDeferredUploadHeader = "Swarm-Deferred-Upload"
-	SwarmRedunancyLevel       = "Swarm-Redundancy-Level"
+	SwarmRedundancyLevel      = "Swarm-Redundancy-Level"
 
 	ImmutableHeader = "Immutable"
 	GasPriceHeader  = "Gas-Price"
@@ -624,7 +624,7 @@ func (s *Service) corsHandler(h http.Handler) http.Handler {
 	allowedHeaders := []string{
 		"User-Agent", "Accept", "X-Requested-With", "Access-Control-Request-Headers", "Access-Control-Request-Method", "Accept-Ranges", "Content-Encoding",
 		AuthorizationHeader, AcceptEncodingHeader, ContentTypeHeader, ContentDispositionHeader, RangeHeader, OriginHeader,
-		SwarmTagHeader, SwarmPinHeader, SwarmEncryptHeader, SwarmIndexDocumentHeader, SwarmErrorDocumentHeader, SwarmCollectionHeader, SwarmPostageBatchIdHeader, SwarmDeferredUploadHeader, SwarmRedunancyLevel,
+		SwarmTagHeader, SwarmPinHeader, SwarmEncryptHeader, SwarmIndexDocumentHeader, SwarmErrorDocumentHeader, SwarmCollectionHeader, SwarmPostageBatchIdHeader, SwarmDeferredUploadHeader, SwarmRedundancyLevel,
 		GasPriceHeader, GasLimitHeader, ImmutableHeader,
 	}
 	allowedHeadersStr := strings.Join(allowedHeaders, ", ")
diff --git a/pkg/api/dirs.go b/pkg/api/dirs.go
index 44cece35832..867e1266629 100644
--- a/pkg/api/dirs.go
+++ b/pkg/api/dirs.go
@@ -64,7 +64,7 @@ func (s *Service) dirUploadHandler(
 	}
 	defer r.Body.Close()
 
-	rLevelNum, err := strconv.ParseUint(r.Header.Get(SwarmRedunancyLevel), 10, 1)
+	rLevelNum, err := strconv.ParseUint(r.Header.Get(SwarmRedundancyLevel), 10, 8)
 	if err != nil {
 		logger.Debug("store directory failed", "redundancy level parsing error")
 		logger.Error(nil, "store directory failed")

From 7574a9631b067dde3fbd44ff0f4576f9d7744ddc Mon Sep 17 00:00:00 2001
From: nugaon <50576770+nugaon@users.noreply.github.com>
Date: Thu, 14 Dec 2023 23:08:57 +0100
Subject: [PATCH 08/23] feat: replicas integration (#4492)

---
 pkg/file/joiner/joiner.go                   | 11 ++-
 pkg/file/joiner/joiner_test.go              | 17 +++--
 pkg/file/pipeline/builder/builder.go        |  4 ++
 pkg/file/pipeline/hashtrie/hashtrie.go      | 37 ++++++++--
 pkg/file/pipeline/hashtrie/hashtrie_test.go | 29 ++++++--
 pkg/file/redundancy/getter/getter.go        | 20 +++---
 pkg/file/redundancy/level.go                | 23 +++++++
 pkg/file/redundancy/redundancy.go           | 23 +++++--
 pkg/file/utils.go                           |  6 +-
 pkg/replicas/export_test.go                 |  1 -
 pkg/replicas/getter.go                      |  8 +--
 pkg/replicas/getter_test.go                 |  6 +-
 pkg/replicas/putter.go                      |  6 +-
pkg/replicas/putter_test.go | 2 +- pkg/replicas/replicas.go | 75 ++++++++++----------- 15 files changed, 182 insertions(+), 86 deletions(-) diff --git a/pkg/file/joiner/joiner.go b/pkg/file/joiner/joiner.go index 349e7f1a759..828eecab9f9 100644 --- a/pkg/file/joiner/joiner.go +++ b/pkg/file/joiner/joiner.go @@ -18,6 +18,7 @@ import ( "github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/file/redundancy/getter" + "github.com/ethersphere/bee/pkg/replicas" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" "golang.org/x/sync/errgroup" @@ -41,9 +42,13 @@ type joiner struct { // New creates a new Joiner. A Joiner provides Read, Seek and Size functionalities. func New(ctx context.Context, getter storage.Getter, putter storage.Putter, address swarm.Address) (file.Joiner, int64, error) { - getter = store.New(getter) // retrieve the root chunk to read the total data length the be retrieved - rootChunk, err := getter.Get(ctx, address) + rLevel := replicas.GetLevelFromContext(ctx) + rootChunkGetter := store.New(getter) + if rLevel != redundancy.NONE { + rootChunkGetter = store.New(replicas.NewGetter(getter, rLevel)) + } + rootChunk, err := rootChunkGetter.Get(ctx, address) if err != nil { return nil, 0, err } @@ -152,7 +157,7 @@ func (j *joiner) readAtOffset( return } sAddresses, pAddresses := file.ChunkAddresses(data[:pSize], parity, j.refLength) - getter := getter.New(sAddresses, pAddresses, j.getter, j.putter) + getter := store.New(getter.New(sAddresses, pAddresses, j.getter, j.putter)) for cursor := 0; cursor < len(data); cursor += j.refLength { if bytesToRead == 0 { break diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index ca568bddd2b..c406654023a 100644 --- a/pkg/file/joiner/joiner_test.go +++ b/pkg/file/joiner/joiner_test.go @@ -24,6 +24,7 @@ import ( "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/file/splitter" filetest "github.com/ethersphere/bee/pkg/file/testing" + "github.com/ethersphere/bee/pkg/replicas" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" testingc "github.com/ethersphere/bee/pkg/storage/testing" @@ -980,6 +981,7 @@ func TestJoinerRedundancy(t *testing.T) { tc := tc t.Run(fmt.Sprintf("redundancy %d encryption %t", tc.rLevel, tc.encryptChunk), func(t *testing.T) { ctx := context.Background() + ctx = replicas.SetLevel(ctx, tc.rLevel) store := inmemchunkstore.New() pipe := builder.NewPipelineBuilder(ctx, store, tc.encryptChunk, tc.rLevel) @@ -1000,10 +1002,6 @@ func TestJoinerRedundancy(t *testing.T) { if err != nil { t.Fatal(err) } - err = store.Put(ctx, dataChunks[i]) - if err != nil { - t.Fatal(err) - } _, err = pipe.Write(chunkData) if err != nil { t.Fatal(err) @@ -1065,6 +1063,17 @@ func TestJoinerRedundancy(t *testing.T) { // check whether the data still be readable readCheck() + + // remove root chunk and try to get it by disperse replica + err = store.Delete(ctx, swarmAddr) + if err != nil { + t.Fatal(err) + } + joinReader, _, err = joiner.New(ctx, store, store, swarmAddr) + if err != nil { + t.Fatal(err) + } + readCheck() }) } } diff --git a/pkg/file/pipeline/builder/builder.go b/pkg/file/pipeline/builder/builder.go index e5cf03b2da4..d022ed5724a 100644 --- a/pkg/file/pipeline/builder/builder.go +++ b/pkg/file/pipeline/builder/builder.go @@ -36,9 +36,11 @@ func NewPipelineBuilder(ctx context.Context, s storage.Putter, encrypt bool, rLe func 
newPipeline(ctx context.Context, s storage.Putter, rLevel redundancy.Level) pipeline.Interface { pipeline := newShortPipelineFunc(ctx, s) tw := hashtrie.NewHashTrieWriter( + ctx, swarm.HashSize, redundancy.New(rLevel, false, pipeline), pipeline, + s, ) lsw := store.NewStoreWriter(ctx, s, tw) b := bmt.NewBmtWriter(lsw) @@ -61,9 +63,11 @@ func newShortPipelineFunc(ctx context.Context, s storage.Putter) func() pipeline // with the unencrypted span is preserved. func newEncryptionPipeline(ctx context.Context, s storage.Putter, rLevel redundancy.Level) pipeline.Interface { tw := hashtrie.NewHashTrieWriter( + ctx, swarm.HashSize+encryption.KeyLength, redundancy.New(rLevel, true, newShortPipelineFunc(ctx, s)), newShortEncryptionPipelineFunc(ctx, s), + s, ) lsw := store.NewStoreWriter(ctx, s, tw) b := bmt.NewBmtWriter(lsw) diff --git a/pkg/file/pipeline/hashtrie/hashtrie.go b/pkg/file/pipeline/hashtrie/hashtrie.go index 3091b21bb8f..cb3c5f468a2 100644 --- a/pkg/file/pipeline/hashtrie/hashtrie.go +++ b/pkg/file/pipeline/hashtrie/hashtrie.go @@ -5,11 +5,15 @@ package hashtrie import ( + "context" "encoding/binary" "errors" + "fmt" "github.com/ethersphere/bee/pkg/file/pipeline" "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/replicas" + "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" ) @@ -21,6 +25,7 @@ var ( const maxLevel = 8 type hashTrieWriter struct { + ctx context.Context // context for put function of dispersed replica chunks refSize int cursors []int // level cursors, key is level. level 0 is data level holds how many chunks were processed. Intermediate higher levels will always have LOWER cursor values. buffer []byte // keeps intermediate level data @@ -28,13 +33,21 @@ type hashTrieWriter struct { pipelineFn pipeline.PipelineFunc rParams redundancy.RedundancyParams parityChunkFn redundancy.ParityChunkCallback - chunkCounters []uint8 // counts the chunk references in intermediate chunks. key is the chunk level. - effectiveChunkCounters []uint8 // counts the effective chunk references in intermediate chunks. key is the chunk level. - maxChildrenChunks uint8 // maximum number of chunk references in intermediate chunks. + chunkCounters []uint8 // counts the chunk references in intermediate chunks. key is the chunk level. + effectiveChunkCounters []uint8 // counts the effective chunk references in intermediate chunks. key is the chunk level. + maxChildrenChunks uint8 // maximum number of chunk references in intermediate chunks. 
+	replicaPutter          storage.Putter // putter to save dispersed replicas of the root chunk
 }

-func NewHashTrieWriter(refLen int, rParams redundancy.RedundancyParams, pipelineFn pipeline.PipelineFunc) pipeline.ChainWriter {
+func NewHashTrieWriter(
+	ctx context.Context,
+	refLen int,
+	rParams redundancy.RedundancyParams,
+	pipelineFn pipeline.PipelineFunc,
+	replicaPutter storage.Putter,
+) pipeline.ChainWriter {
 	h := &hashTrieWriter{
+		ctx:     ctx,
 		refSize: refLen,
 		cursors: make([]int, 9),
 		buffer:  make([]byte, swarm.ChunkWithSpanSize*9*2), // double size as temp workaround for weak calculation of needed buffer space
@@ -43,6 +56,7 @@ func NewHashTrieWriter(refLen int, rParams redundancy.RedundancyParams, pipeline
 		chunkCounters:          make([]uint8, 9),
 		effectiveChunkCounters: make([]uint8, 9),
 		maxChildrenChunks:      uint8(rParams.MaxShards() + rParams.Parities(rParams.MaxShards())),
+		replicaPutter:          replicas.NewPutter(replicaPutter),
 	}
 	h.parityChunkFn = func(level int, span, address []byte) error {
 		return h.writeToIntermediateLevel(level, true, span, address, []byte{})
@@ -246,5 +260,18 @@ func (h *hashTrieWriter) Sum() ([]byte, error) {
 	// return the hash in the highest level, that's all we need
 	data := h.buffer[0:h.cursors[maxLevel]]
-	return data[swarm.SpanSize:], nil
+	rootHash := data[swarm.SpanSize:]
+
+	// save dispersed replicas of the root chunk
+	if h.rParams.Level() != redundancy.NONE {
+		rootData, err := h.rParams.GetRootData()
+		if err != nil {
+			return nil, err
+		}
+		err = h.replicaPutter.Put(h.ctx, swarm.NewChunk(swarm.NewAddress(rootHash), rootData))
+		if err != nil {
+			return nil, fmt.Errorf("hashtrie: cannot put dispersed replica %s", err.Error())
+		}
+	}
+	return rootHash, nil
 }
diff --git a/pkg/file/pipeline/hashtrie/hashtrie_test.go b/pkg/file/pipeline/hashtrie/hashtrie_test.go
index 8b501b6b536..42752df77f8 100644
--- a/pkg/file/pipeline/hashtrie/hashtrie_test.go
+++ b/pkg/file/pipeline/hashtrie/hashtrie_test.go
@@ -23,6 +23,7 @@ import (
 	"github.com/ethersphere/bee/pkg/file/pipeline/mock"
 	"github.com/ethersphere/bee/pkg/file/pipeline/store"
 	"github.com/ethersphere/bee/pkg/file/redundancy"
+	"github.com/ethersphere/bee/pkg/replicas"
 	"github.com/ethersphere/bee/pkg/storage"
 	"github.com/ethersphere/bee/pkg/storage/inmemchunkstore"
 	"github.com/ethersphere/bee/pkg/swarm"
@@ -52,6 +53,7 @@ func newErasureHashTrieWriter(
 	rLevel redundancy.Level,
 	encryptChunks bool,
 	intermediateChunkPipeline, parityChunkPipeline pipeline.ChainWriter,
+	replicaPutter storage.Putter,
 ) (redundancy.RedundancyParams, pipeline.ChainWriter) {
 	pf := func() pipeline.ChainWriter {
 		lsw := store.NewStoreWriter(ctx, s, intermediateChunkPipeline)
@@ -75,7 +77,7 @@ func newErasureHashTrieWriter(
 	}

 	r := redundancy.New(rLevel, encryptChunks, ppf)
-	ht := hashtrie.NewHashTrieWriter(hashSize, r, pf)
+	ht := hashtrie.NewHashTrieWriter(ctx, hashSize, r, pf, replicaPutter)
 	return r, ht
 }

@@ -143,7 +145,7 @@ func TestLevels(t *testing.T) {
 				return bmt.NewBmtWriter(lsw)
 			}

-			ht := hashtrie.NewHashTrieWriter(hashSize, redundancy.New(0, false, pf), pf)
+			ht := hashtrie.NewHashTrieWriter(ctx, hashSize, redundancy.New(0, false, pf), pf, s)

 			for i := 0; i < tc.writes; i++ {
 				a := &pipeline.PipeWriteArgs{Ref: addr.Bytes(), Span: span}
@@ -196,7 +198,7 @@ func TestLevels_TrieFull(t *testing.T) {
 			Params: *r,
 		}

-		ht = hashtrie.NewHashTrieWriter(hashSize, rMock, pf)
+		ht = hashtrie.NewHashTrieWriter(ctx, hashSize, rMock, pf, s)
 	)

 	// to create a level wrap we need to do branching^(level-1) writes
@@ -237,7 +239,7 @@ func TestRegression(t 
*testing.T) { lsw := store.NewStoreWriter(ctx, s, nil) return bmt.NewBmtWriter(lsw) } - ht = hashtrie.NewHashTrieWriter(hashSize, redundancy.New(0, false, pf), pf) + ht = hashtrie.NewHashTrieWriter(ctx, hashSize, redundancy.New(0, false, pf), pf, s) ) binary.LittleEndian.PutUint64(span, 4096) @@ -265,6 +267,16 @@ func TestRegression(t *testing.T) { } } +type replicaPutter struct { + storage.Putter + replicaCount uint8 +} + +func (r *replicaPutter) Put(ctx context.Context, chunk swarm.Chunk) error { + r.replicaCount++ + return r.Putter.Put(ctx, chunk) +} + // TestRedundancy using erasure coding library and checks carrierChunk function and modified span in intermediate chunk func TestRedundancy(t *testing.T) { t.Parallel() @@ -303,12 +315,14 @@ func TestRedundancy(t *testing.T) { tc := tc t.Run(tc.desc, func(t *testing.T) { t.Parallel() + subCtx := replicas.SetLevel(ctx, tc.level) s := inmemchunkstore.New() intermediateChunkCounter := mock.NewChainWriter() parityChunkCounter := mock.NewChainWriter() + replicaChunkCounter := &replicaPutter{Putter: s} - r, ht := newErasureHashTrieWriter(ctx, s, tc.level, tc.encryption, intermediateChunkCounter, parityChunkCounter) + r, ht := newErasureHashTrieWriter(subCtx, s, tc.level, tc.encryption, intermediateChunkCounter, parityChunkCounter, replicaChunkCounter) // write data to the hashTrie var key []byte @@ -336,7 +350,7 @@ func TestRedundancy(t *testing.T) { t.Errorf("effective chunks should be %d. Got: %d", tc.writes, intermediateChunkCounter.ChainWriteCalls()) } - rootch, err := s.Get(ctx, swarm.NewAddress(ref[:swarm.HashSize])) + rootch, err := s.Get(subCtx, swarm.NewAddress(ref[:swarm.HashSize])) if err != nil { t.Fatal(err) } @@ -362,6 +376,9 @@ func TestRedundancy(t *testing.T) { if expectedParities != parity { t.Fatalf("want parity %d got %d", expectedParities, parity) } + if tc.level.GetReplicaCount()+1 != int(replicaChunkCounter.replicaCount) { // +1 is the original chunk + t.Fatalf("unexpected number of replicas: want %d. 
Got: %d", tc.level.GetReplicaCount(), int(replicaChunkCounter.replicaCount))
+			}
 		})
 	}
 }
diff --git a/pkg/file/redundancy/getter/getter.go b/pkg/file/redundancy/getter/getter.go
index a23cd652ee6..30eaeaded66 100644
--- a/pkg/file/redundancy/getter/getter.go
+++ b/pkg/file/redundancy/getter/getter.go
@@ -10,7 +10,6 @@ import (
 	"fmt"
 	"sync"

-	"github.com/ethersphere/bee/pkg/encryption/store"
 	"github.com/ethersphere/bee/pkg/storage"
 	"github.com/ethersphere/bee/pkg/swarm"
 	"github.com/klauspost/reedsolomon"
@@ -59,13 +58,11 @@ type getter struct {
 	sAddresses  []swarm.Address          // shard addresses
 	pAddresses  []swarm.Address          // parity addresses
 	cache       map[string]inflightChunk // map from chunk address to cached shard chunk data
-	erasureData [][]byte                 // data + parity shards for erasure decoding;
-	encrypted   bool                     // swarm datashards are encrypted
+	erasureData [][]byte                 // data + parity shards for erasure decoding; TODO mutex
 }

 // New returns a getter object which is used to retrieve children of an intermediate chunk
 func New(sAddresses, pAddresses []swarm.Address, g storage.Getter, p storage.Putter) storage.Getter {
-	encrypted := len(sAddresses[0].Bytes()) == swarm.HashSize*2
 	shards := len(sAddresses)
 	parities := len(pAddresses)
 	n := shards + parities
@@ -90,7 +87,6 @@ func New(sAddresses, pAddresses []swarm.Address, g storage.Getter, p storage.Put
 		sAddresses:  sAddresses,
 		pAddresses:  pAddresses,
 		cache:       cache,
-		encrypted:   encrypted,
 		erasureData: erasureData,
 	}
 }
@@ -287,13 +283,13 @@ func (g *getter) cacheDataToChunk(addr swarm.Address, chData []byte) (swarm.Chun
 	if chData == nil {
 		return nil, notRecoveredError(addr.String())
 	}
-	if g.encrypted {
-		data, err := store.DecryptChunkData(chData, addr.Bytes()[swarm.HashSize:])
-		if err != nil {
-			return nil, err
-		}
-		chData = data
-	}
+	// if g.encrypted {
+	// 	data, err := store.DecryptChunkData(chData, addr.Bytes()[swarm.HashSize:])
+	// 	if err != nil {
+	// 		return nil, err
+	// 	}
+	// 	chData = data
+	// }

 	return swarm.NewChunk(addr, chData), nil
 }
diff --git a/pkg/file/redundancy/level.go b/pkg/file/redundancy/level.go
index 58470e8cf0e..597c506a67c 100644
--- a/pkg/file/redundancy/level.go
+++ b/pkg/file/redundancy/level.go
@@ -98,6 +98,16 @@ func (l Level) GetMaxEncShards() int {
 	return (swarm.Branches - p) / 2
 }

+// GetReplicaCount returns the number of dispersed replicas for the level
+func (l Level) GetReplicaCount() int {
+	return replicaCounts[int(l)]
+}
+
+// Decrement returns a weaker redundancy level compared to the current one
+func (l Level) Decrement() Level {
+	return Level(uint8(l) - 1)
+}
+
 // TABLE INITS

 var mediumEt = newErasureTable(
@@ -151,3 +161,16 @@ var encParanoidEt = newErasureTable(
 		55, 51, 48, 44, 39, 35, 30, 24,
 	},
 )
+
+// DISPERSED REPLICAS INIT
+
+// GetReplicaCounts returns the ascending dispersed replica counts for all redundancy levels
+func GetReplicaCounts() [5]int {
+	c := replicaCounts
+	return c
+}
+
+// the actual number of replicas needed to keep the error rate below 1/10^6
+// for the five levels of redundancy are 0, 2, 4, 5, 19
+// we use an approximation as the successive powers of 2
+var replicaCounts = [5]int{0, 2, 4, 8, 16}
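The mapping from redundancy levels to dispersed replica counts added above can be exercised directly. A minimal sketch follows (an illustration only, assuming Level is the integer-backed enum the tests suggest, and using the GetReplicaCount and SetLevel APIs introduced in this patch series):

	package main

	import (
		"context"
		"fmt"

		"github.com/ethersphere/bee/pkg/file/redundancy"
		"github.com/ethersphere/bee/pkg/replicas"
	)

	func main() {
		// replica counts per level: the powers-of-2 approximation of 0, 2, 4, 5, 19
		for l := redundancy.NONE; l <= redundancy.PARANOID; l++ {
			fmt.Printf("level %d -> %d dispersed replicas\n", l, l.GetReplicaCount())
		}

		// uploaders select a level via the context; the joiner and putter read it back
		ctx := replicas.SetLevel(context.Background(), redundancy.STRONG)
		_ = ctx
	}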
diff --git a/pkg/file/redundancy/redundancy.go b/pkg/file/redundancy/redundancy.go
index 7782b6cb4a4..443dc0637b3 100644
--- a/pkg/file/redundancy/redundancy.go
+++ b/pkg/file/redundancy/redundancy.go
@@ -18,10 +18,11 @@ type ParityChunkCallback func(level int, span, address []byte) error

 type RedundancyParams interface {
 	MaxShards() int // returns the maximum data shard number being used in an intermediate chunk
 	Level() Level
-	Parities(int) int                                   // returns the optimal parity number for a given
-	ChunkWrite(int, []byte, ParityChunkCallback) error  // caches the chunk data on the given chunk level and call encode
-	ElevateCarrierChunk(int, ParityChunkCallback) error // moves the carrier chunk to the level above
-	Encode(int, ParityChunkCallback) error              // add parities on the given level
+	Parities(int) int
+	ChunkWrite(int, []byte, ParityChunkCallback) error
+	ElevateCarrierChunk(int, ParityChunkCallback) error
+	Encode(int, ParityChunkCallback) error
+	GetRootData() ([]byte, error)
 }

 type ErasureEncoder interface {
@@ -181,3 +182,17 @@ func (p *Params) ElevateCarrierChunk(chunkLevel int, callback ParityChunkCallbac
 	// not necessary to update current level since we will not work with it anymore
 	return p.chunkWrite(chunkLevel+1, p.buffer[chunkLevel][p.cursor[chunkLevel]-1], callback)
 }
+
+// GetRootData returns the topmost chunk in the tree.
+// throws an error if the encoding has not been finished in the BMT
+// OR redundancy is not used in the BMT
+func (p *Params) GetRootData() ([]byte, error) {
+	if p.level == NONE {
+		return nil, fmt.Errorf("redundancy: no redundancy level is used for the file in order to cache root data")
+	}
+	lastBuffer := p.buffer[len(p.buffer)-1]
+	if len(lastBuffer[0]) != swarm.ChunkWithSpanSize {
+		return nil, fmt.Errorf("redundancy: hashtrie sum has not finished in order to cache root data")
+	}
+	return lastBuffer[0], nil
+}
diff --git a/pkg/file/utils.go b/pkg/file/utils.go
index 59afd2e9cef..021109637cd 100644
--- a/pkg/file/utils.go
+++ b/pkg/file/utils.go
@@ -12,12 +12,14 @@ import (
 	"github.com/ethersphere/bee/pkg/swarm"
 )

+var ZeroAddress = [32]byte{}
+
 // ChunkPayloadSize returns the effective byte length of an intermediate chunk
 // assumes data is always chunk size (without span)
 func ChunkPayloadSize(data []byte) (int, error) {
 	l := len(data)
 	for l >= swarm.HashSize {
-		if !bytes.Equal(data[l-swarm.HashSize:l], swarm.ZeroAddress.Bytes()) {
+		if !bytes.Equal(data[l-swarm.HashSize:l], ZeroAddress[:]) {
 			return l, nil
 		}

@@ -35,7 +37,7 @@ func ChunkAddresses(data []byte, parities, reflen int) (sAddresses, pAddresses [
 	pAddresses = make([]swarm.Address, parities)
 	offset := 0
 	for i := 0; i < shards; i++ {
-		sAddresses[i] = swarm.NewAddress(data[offset : offset+reflen])
+		sAddresses[i] = swarm.NewAddress(data[offset : offset+swarm.HashSize])
 		offset += reflen
 	}
 	for i := 0; i < parities; i++ {
diff --git a/pkg/replicas/export_test.go b/pkg/replicas/export_test.go
index 562dc9512c3..6e029f302c6 100644
--- a/pkg/replicas/export_test.go
+++ b/pkg/replicas/export_test.go
@@ -7,7 +7,6 @@ package replicas
 import "github.com/ethersphere/bee/pkg/storage"

 var (
-	Counts = counts
 	Signer = signer
 )
diff --git a/pkg/replicas/getter.go b/pkg/replicas/getter.go
index 8b335130fa7..496bc5d658a 100644
--- a/pkg/replicas/getter.go
+++ b/pkg/replicas/getter.go
@@ -88,12 +88,12 @@ func (g *getter) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, e
 		}
 	}()
 	// counters
-	n := 0                   // counts the replica addresses tried
-	target := 2              // the number of replicas attempted to download in this batch
-	total := counts[g.level] // total number of replicas allowed (and makes sense) to retrieve
+	n := 0      // counts the replica addresses tried
+	target := 2 // the number of replicas attempted to download in this batch
+	total := g.level.GetReplicaCount() //

-	rr := newReplicator(addr, uint8(g.level))
+	rr := newReplicator(addr, g.level)
 	next := rr.c
 	var wait <-chan time.Time // 
nil channel to disable case // addresses used are doubling each period of search expansion diff --git a/pkg/replicas/getter_test.go b/pkg/replicas/getter_test.go index 2664e9a6a14..3b11ad26d94 100644 --- a/pkg/replicas/getter_test.go +++ b/pkg/replicas/getter_test.go @@ -126,7 +126,7 @@ func TestGetter(t *testing.T) { var tests []test for _, f := range failures { - for level, c := range replicas.Counts { + for level, c := range redundancy.GetReplicaCounts() { for j := 0; j <= c*2+1; j++ { tests = append(tests, test{ name: fmt.Sprintf("%s level %d count %d found %d", f.name, level, c, j), @@ -257,10 +257,10 @@ func TestGetter(t *testing.T) { }) t.Run("latency", func(t *testing.T) { - + counts := redundancy.GetReplicaCounts() for i, latency := range latencies { multiplier := latency / replicas.RetryInterval - if multiplier > 0 && i < replicas.Counts[multiplier-1] { + if multiplier > 0 && i < counts[multiplier-1] { t.Fatalf("incorrect latency for retrieving replica %d: %v", i, err) } } diff --git a/pkg/replicas/putter.go b/pkg/replicas/putter.go index 2bcceb2e2da..f2334a994b8 100644 --- a/pkg/replicas/putter.go +++ b/pkg/replicas/putter.go @@ -29,14 +29,14 @@ func NewPutter(p storage.Putter) storage.Putter { // Put makes the getter satisfy the storage.Getter interface func (p *putter) Put(ctx context.Context, ch swarm.Chunk) (err error) { - rlevel := getLevelFromContext(ctx) + rlevel := GetLevelFromContext(ctx) errs := []error{p.putter.Put(ctx, ch)} if rlevel == 0 { return errs[0] } - rr := newReplicator(ch.Address(), uint8(rlevel)) - errc := make(chan error, counts[rlevel]) + rr := newReplicator(ch.Address(), rlevel) + errc := make(chan error, rlevel.GetReplicaCount()) wg := sync.WaitGroup{} for r := range rr.c { r := r diff --git a/pkg/replicas/putter_test.go b/pkg/replicas/putter_test.go index e53bfc6f1bf..7a90b5ec308 100644 --- a/pkg/replicas/putter_test.go +++ b/pkg/replicas/putter_test.go @@ -107,7 +107,7 @@ func TestPutter(t *testing.T) { } }) t.Run("attempts", func(t *testing.T) { - count := replicas.Counts[tc.level] + count := tc.level.GetReplicaCount() if len(addrs) != count { t.Fatalf("incorrect number of attempts. 
want %v, got %v", count, len(addrs))
 		}
diff --git a/pkg/replicas/replicas.go b/pkg/replicas/replicas.go
index 65806abba7c..8cf6f5f88bc 100644
--- a/pkg/replicas/replicas.go
+++ b/pkg/replicas/replicas.go
@@ -27,16 +27,9 @@ var (
 	redundancyLevel redundancyLevelType
 	// RetryInterval is the duration between successive additional requests
 	RetryInterval = 300 * time.Millisecond
-	//
-	// counts of replicas used for levels of increasing security
-	// the actual number of replicas needed to keep the error rate below 1/10^6
-	// for the five levels of redundancy are 0, 2, 4, 5, 19
-	// we use an approximation as the successive powers of 2
-	counts     = [5]int{0, 2, 4, 8, 16}
-	sums       = [5]int{0, 2, 6, 14, 30}
-	privKey, _ = crypto.DecodeSecp256k1PrivateKey(append([]byte{1}, make([]byte, 31)...))
-	signer     = crypto.NewDefaultSigner(privKey)
-	owner, _   = hex.DecodeString("dc5b20847f43d67928f49cd4f85d696b5a7617b5")
+	privKey, _ = crypto.DecodeSecp256k1PrivateKey(append([]byte{1}, make([]byte, 31)...))
+	signer     = crypto.NewDefaultSigner(privKey)
+	owner, _   = hex.DecodeString("dc5b20847f43d67928f49cd4f85d696b5a7617b5")
 )

 // SetLevel sets the redundancy level in the context
@@ -44,8 +37,8 @@ func SetLevel(ctx context.Context, level redundancy.Level) context.Context {
 	return context.WithValue(ctx, redundancyLevel, level)
 }

-// getLevelFromContext is a helper function to extract the redundancy level from the context
-func getLevelFromContext(ctx context.Context) redundancy.Level {
+// GetLevelFromContext is a helper function to extract the redundancy level from the context
+func GetLevelFromContext(ctx context.Context) redundancy.Level {
 	rlevel := redundancy.PARANOID
 	if val := ctx.Value(redundancyLevel); val != nil {
 		rlevel = val.(redundancy.Level)
@@ -55,21 +48,21 @@ func getLevelFromContext(ctx context.Context) redundancy.Level {

 // replicator running the find for replicas
 type replicator struct {
-	addr  []byte        // chunk address
-	queue [16]*replica  // to sort addresses according to di
-	exist [30]bool      // maps the 16 distinct nibbles on all levels
-	sizes [5]int        // number of distinct neighnourhoods redcorded for each depth
-	c     chan *replica
-	depth uint8
+	addr   []byte       // chunk address
+	queue  [16]*replica // to sort addresses according to di
+	exist  [30]bool     // maps the 16 distinct nibbles on all levels
+	sizes  [5]int       // number of distinct neighbourhoods recorded for each depth
+	c      chan *replica
+	rLevel redundancy.Level
 }

 // newReplicator replicator constructor
-func newReplicator(addr swarm.Address, depth uint8) *replicator {
+func newReplicator(addr swarm.Address, rLevel redundancy.Level) *replicator {
 	rr := &replicator{
-		addr:  addr.Bytes(),
-		sizes: counts,
-		c:     make(chan *replica, 16),
-		depth: depth,
+		addr:   addr.Bytes(),
+		sizes:  redundancy.GetReplicaCounts(),
+		c:      make(chan *replica, 16),
+		rLevel: rLevel,
 	}
 	go rr.replicas()
 	return rr
@@ -93,12 +86,6 @@ func (rr *replicator) replicate(i uint8) (sp *replica) {
 	return &replica{h.Sum(nil), id}
 }

-// nh returns the lookup key for the neighbourhood of depth d
-// to be used as index to the replicators exist array
-func (r *replica) nh(d uint8) (nh int) {
-	return sums[d-1] + int(r.addr[0]>>(8-d))
-}
-
 // replicas enumerates replica parameters (SOC ID) pushing it in a channel given as argument
 // the order of replicas is so that addresses are always maximally dispersed
 // in successive sets of addresses.
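To make the dispersal concrete: the nh lookup key computed in the next hunk buckets a replica by the first level-many bits of its address, offset per level by an index-base table. A small self-contained sketch of that arithmetic (an illustration only, mirroring the nh helper and replicaIndexBases table added below):

	package main

	import "fmt"

	// mirrors the index-base table added below
	var replicaIndexBases = [5]int{0, 2, 6, 14}

	// mirrors the nh helper added below
	func nh(d uint8, firstByte byte) int {
		return replicaIndexBases[d-1] + int(firstByte>>(8-d))
	}

	func main() {
		// STRONG level (d=2), replica address starting with 0xb4 = 0b10110100:
		// the top 2 bits (0b10 = 2) pick the bucket, offset by the level base 2
		fmt.Println(nh(2, 0xb4)) // 4
		// a second replica starting with the same two bits lands in bucket 4 as
		// well and is skipped, keeping the mined replicas maximally dispersed
	}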
@@ -106,11 +93,11 @@ func (r *replica) nh(d uint8) (nh int) { func (rr *replicator) replicas() { defer close(rr.c) n := 0 - for i := uint8(0); n < counts[rr.depth] && i < 255; i++ { + for i := uint8(0); n < rr.rLevel.GetReplicaCount() && i < 255; i++ { // create soc replica (ID and address using constant owner) // the soc is added to neighbourhoods of depths in the closed interval [from...to] r := rr.replicate(i) - d, m := rr.add(r, rr.depth) + d, m := rr.add(r, rr.rLevel) if d == 0 { continue } @@ -125,21 +112,33 @@ func (rr *replicator) replicas() { } // add inserts the soc replica into a replicator so that addresses are balanced -func (rr *replicator) add(r *replica, d uint8) (depth int, rank int) { - if d == 0 { +func (rr *replicator) add(r *replica, rLevel redundancy.Level) (depth int, rank int) { + if rLevel == redundancy.NONE { return 0, 0 } - nh := r.nh(d) + nh := nh(rLevel, r.addr) if rr.exist[nh] { return 0, 0 } rr.exist[nh] = true - l, o := rr.add(r, d-1) + l, o := rr.add(r, rLevel.Decrement()) if l == 0 { - o = rr.sizes[d-1] - rr.sizes[d-1]++ + o = rr.sizes[uint8(rLevel.Decrement())] + rr.sizes[uint8(rLevel.Decrement())]++ rr.queue[o] = r - l = int(d) + l = rLevel.GetReplicaCount() } return l, o } + +// UTILS + +// index bases needed to keep track how many addresses were mined for a level. +var replicaIndexBases = [5]int{0, 2, 6, 14} + +// nh returns the lookup key based on the redundancy level +// to be used as index to the replicators exist array +func nh(rLevel redundancy.Level, addr []byte) int { + d := uint8(rLevel) + return replicaIndexBases[d-1] + int(addr[0]>>(8-d)) +} From e94f81c6e69d19985ce38919eb0861a49f44db34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?= Date: Fri, 15 Dec 2023 09:36:16 +0100 Subject: [PATCH 09/23] test(fix): hashtrie replicacount race condition --- pkg/file/pipeline/hashtrie/hashtrie_test.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pkg/file/pipeline/hashtrie/hashtrie_test.go b/pkg/file/pipeline/hashtrie/hashtrie_test.go index 42752df77f8..c3729f78901 100644 --- a/pkg/file/pipeline/hashtrie/hashtrie_test.go +++ b/pkg/file/pipeline/hashtrie/hashtrie_test.go @@ -9,6 +9,7 @@ import ( "context" "encoding/binary" "errors" + "sync/atomic" "testing" bmtUtils "github.com/ethersphere/bee/pkg/bmt" @@ -269,11 +270,11 @@ func TestRegression(t *testing.T) { type replicaPutter struct { storage.Putter - replicaCount uint8 + replicaCount atomic.Uint32 } func (r *replicaPutter) Put(ctx context.Context, chunk swarm.Chunk) error { - r.replicaCount++ + r.replicaCount.Add(1) return r.Putter.Put(ctx, chunk) } @@ -376,8 +377,8 @@ func TestRedundancy(t *testing.T) { if expectedParities != parity { t.Fatalf("want parity %d got %d", expectedParities, parity) } - if tc.level.GetReplicaCount()+1 != int(replicaChunkCounter.replicaCount) { // +1 is the original chunk - t.Fatalf("unexpected number of replicas: want %d. Got: %d", tc.level.GetReplicaCount(), int(replicaChunkCounter.replicaCount)) + if tc.level.GetReplicaCount()+1 != int(replicaChunkCounter.replicaCount.Load()) { // +1 is the original chunk + t.Fatalf("unexpected number of replicas: want %d. 
Got: %d", tc.level.GetReplicaCount(), int(replicaChunkCounter.replicaCount.Load())) } }) } From 6ba286b5f20ff88fed26f0876e3835ec4dde660e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?= Date: Fri, 15 Dec 2023 13:49:01 +0100 Subject: [PATCH 10/23] refactor: replicas add --- pkg/replicas/replicas.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/replicas/replicas.go b/pkg/replicas/replicas.go index 8cf6f5f88bc..0fbcee6857b 100644 --- a/pkg/replicas/replicas.go +++ b/pkg/replicas/replicas.go @@ -122,9 +122,10 @@ func (rr *replicator) add(r *replica, rLevel redundancy.Level) (depth int, rank } rr.exist[nh] = true l, o := rr.add(r, rLevel.Decrement()) + d := uint8(rLevel) - 1 if l == 0 { - o = rr.sizes[uint8(rLevel.Decrement())] - rr.sizes[uint8(rLevel.Decrement())]++ + o = rr.sizes[d] + rr.sizes[d]++ rr.queue[o] = r l = rLevel.GetReplicaCount() } From 2c6911f3ba1efc19f9e39eade0ab0ef583f95139 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tr=C3=B3n?= Date: Sun, 17 Dec 2023 23:49:19 +0100 Subject: [PATCH 11/23] feat(redundancy/getter): rs decoder rewrite (#4507) --- pkg/file/joiner/joiner.go | 144 ++++--- pkg/file/joiner/joiner_test.go | 235 ++++++++--- pkg/file/redundancy/getter/export_test.go | 23 -- pkg/file/redundancy/getter/getter.go | 406 ++++++++----------- pkg/file/redundancy/getter/getter_test.go | 470 +++++++++++++++------- pkg/file/redundancy/getter/strategies.go | 133 ++++++ pkg/file/utils.go | 29 +- pkg/steward/steward_test.go | 24 +- 8 files changed, 923 insertions(+), 541 deletions(-) delete mode 100644 pkg/file/redundancy/getter/export_test.go create mode 100644 pkg/file/redundancy/getter/strategies.go diff --git a/pkg/file/joiner/joiner.go b/pkg/file/joiner/joiner.go index 828eecab9f9..c1ea6f5b4ea 100644 --- a/pkg/file/joiner/joiner.go +++ b/pkg/file/joiner/joiner.go @@ -11,6 +11,7 @@ import ( "io" "sync" "sync/atomic" + "time" "github.com/ethersphere/bee/pkg/bmt" "github.com/ethersphere/bee/pkg/encryption" @@ -33,20 +34,74 @@ type joiner struct { rootParity int maxBranching int // maximum branching in an intermediate chunk - ctx context.Context - getter storage.Getter - putter storage.Putter // required to save recovered data - + ctx context.Context + decoders *decoderCache chunkToSpan func(data []byte) (redundancy.Level, int64) // returns parity and span value from chunkData } +// decoderCache is cache of decoders for intermediate chunks +type decoderCache struct { + fetcher storage.Getter // network retrieval interface to fetch chunks + putter storage.Putter // interface to local storage to save reconstructed chunks + mu sync.Mutex // mutex to protect cache + cache map[string]storage.Getter // map from chunk address to RS getter + strategy getter.Strategy // strategy to use for retrieval + strict bool // strict mode + fetcherTimeout time.Duration // timeout for each fetch +} + +// NewDecoderCache creates a new decoder cache +func NewDecoderCache(g storage.Getter, p storage.Putter, strategy getter.Strategy, strict bool, fetcherTimeout time.Duration) *decoderCache { + return &decoderCache{ + fetcher: g, + putter: p, + cache: make(map[string]storage.Getter), + strategy: strategy, + strict: strict, + fetcherTimeout: fetcherTimeout, + } +} + +func fingerprint(addrs []swarm.Address) string { + h := swarm.NewHasher() + for _, addr := range addrs { + _, _ = h.Write(addr.Bytes()) + } + return string(h.Sum(nil)) +} + +// GetOrCreate returns a decoder for the given chunk address +func (g *decoderCache) GetOrCreate(addrs 
[]swarm.Address, shardCnt int) storage.Getter { + if len(addrs) == shardCnt { + return g.fetcher + } + key := fingerprint(addrs) + g.mu.Lock() + defer g.mu.Unlock() + d, ok := g.cache[key] + if ok { + if d == nil { + return g.fetcher + } + return d + } + remove := func() { + g.mu.Lock() + defer g.mu.Unlock() + g.cache[key] = nil + } + d = getter.New(addrs, shardCnt, g.fetcher, g.putter, g.strategy, g.strict, g.fetcherTimeout, remove) + g.cache[key] = d + return d +} + // New creates a new Joiner. A Joiner provides Read, Seek and Size functionalities. -func New(ctx context.Context, getter storage.Getter, putter storage.Putter, address swarm.Address) (file.Joiner, int64, error) { +func New(ctx context.Context, g storage.Getter, putter storage.Putter, address swarm.Address) (file.Joiner, int64, error) { // retrieve the root chunk to read the total data length the be retrieved rLevel := replicas.GetLevelFromContext(ctx) - rootChunkGetter := store.New(getter) + rootChunkGetter := store.New(g) if rLevel != redundancy.NONE { - rootChunkGetter = store.New(replicas.NewGetter(getter, rLevel)) + rootChunkGetter = store.New(replicas.NewGetter(g, rLevel)) } rootChunk, err := rootChunkGetter.Get(ctx, address) if err != nil { @@ -56,17 +111,21 @@ func New(ctx context.Context, getter storage.Getter, putter storage.Putter, addr chunkData := rootChunk.Data() rootData := chunkData[swarm.SpanSize:] refLength := len(address.Bytes()) - encryption := refLength != swarm.HashSize + encryption := refLength == encryption.ReferenceSize rLevel, span := chunkToSpan(chunkData) rootParity := 0 maxBranching := swarm.ChunkSize / refLength spanFn := func(data []byte) (redundancy.Level, int64) { return 0, int64(bmt.LengthFromSpan(data[:swarm.SpanSize])) } + var strategy getter.Strategy + var strict bool + var fetcherTimeout time.Duration // override stuff if root chunk has redundancy if rLevel != redundancy.NONE { _, parities := file.ReferenceCount(uint64(span), rLevel, encryption) rootParity = parities + strategy, strict, fetcherTimeout = getter.GetParamsFromContext(ctx) spanFn = chunkToSpan if encryption { maxBranching = rLevel.GetMaxEncShards() @@ -79,8 +138,7 @@ func New(ctx context.Context, getter storage.Getter, putter storage.Putter, addr addr: rootChunk.Address(), refLength: refLength, ctx: ctx, - getter: getter, - putter: putter, + decoders: NewDecoderCache(g, putter, strategy, strict, fetcherTimeout), span: span, rootData: rootData, rootParity: rootParity, @@ -148,7 +206,6 @@ func (j *joiner) readAtOffset( atomic.AddInt64(bytesRead, int64(n)) return } - pSize, err := file.ChunkPayloadSize(data) if err != nil { eg.Go(func() error { @@ -156,8 +213,9 @@ func (j *joiner) readAtOffset( }) return } - sAddresses, pAddresses := file.ChunkAddresses(data[:pSize], parity, j.refLength) - getter := store.New(getter.New(sAddresses, pAddresses, j.getter, j.putter)) + + addrs, shardCnt := file.ChunkAddresses(data[:pSize], parity, j.refLength) + g := store.New(j.decoders.GetOrCreate(addrs, shardCnt)) for cursor := 0; cursor < len(data); cursor += j.refLength { if bytesToRead == 0 { break @@ -171,7 +229,7 @@ func (j *joiner) readAtOffset( } // if we are here it means that we are within the bounds of the data we need to read - address := swarm.NewAddress(data[cursor : cursor+j.refLength]) + addr := swarm.NewAddress(data[cursor : cursor+j.refLength]) subtrieSpan := sec subtrieSpanLimit := sec @@ -188,14 +246,14 @@ func (j *joiner) readAtOffset( func(address swarm.Address, b []byte, cur, subTrieSize, off, bufferOffset, bytesToRead, 
subtrieSpanLimit int64) { eg.Go(func() error { - ch, err := getter.Get(j.ctx, address) + ch, err := g.Get(j.ctx, addr) if err != nil { return err } chunkData := ch.Data()[8:] subtrieLevel, subtrieSpan := j.chunkToSpan(ch.Data()) - _, subtrieParity := file.ReferenceCount(uint64(subtrieSpan), subtrieLevel, j.refLength != swarm.HashSize) + _, subtrieParity := file.ReferenceCount(uint64(subtrieSpan), subtrieLevel, j.refLength == encryption.ReferenceSize) if subtrieSpan > subtrieSpanLimit { return ErrMalformedTrie @@ -204,7 +262,7 @@ func (j *joiner) readAtOffset( j.readAtOffset(b, chunkData, cur, subtrieSpan, off, bufferOffset, currentReadSize, bytesRead, subtrieParity, eg) return nil }) - }(address, b, cur, subtrieSpan, off, bufferOffset, currentReadSize, subtrieSpanLimit) + }(addr, b, cur, subtrieSpan, off, bufferOffset, currentReadSize, subtrieSpanLimit) bufferOffset += currentReadSize bytesToRead -= currentReadSize @@ -307,45 +365,39 @@ func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIter if err != nil { return err } - sAddresses, pAddresses := file.ChunkAddresses(data[:eSize], parity, j.refLength) - getter := getter.New(sAddresses, pAddresses, j.getter, j.putter) - for cursor := 0; cursor < len(data); cursor += j.refLength { - ref := data[cursor : cursor+j.refLength] - var reportAddr swarm.Address - address := swarm.NewAddress(ref) - if len(ref) == encryption.ReferenceSize { - reportAddr = swarm.NewAddress(ref[:swarm.HashSize]) - } else { - reportAddr = swarm.NewAddress(ref) - } - - if err := fn(reportAddr); err != nil { + addrs, shardCnt := file.ChunkAddresses(data[:eSize], parity, j.refLength) + g := store.New(j.decoders.GetOrCreate(addrs, shardCnt)) + for i, addr := range addrs { + if err := fn(addr); err != nil { return err } - + cursor := i * swarm.HashSize + if j.refLength == encryption.ReferenceSize { + cursor += swarm.HashSize * min(i, shardCnt) + } sec := j.subtrieSection(data, cursor, eSize, parity, subTrieSize) if sec <= swarm.ChunkSize { continue } - func(address swarm.Address, eg *errgroup.Group) { - wg.Add(1) - - eg.Go(func() error { - defer wg.Done() + wg.Add(1) + eg.Go(func() error { + defer wg.Done() - ch, err := getter.Get(ectx, address) - if err != nil { - return err - } + if j.refLength == encryption.ReferenceSize && i < shardCnt { + addr = swarm.NewAddress(data[cursor : cursor+swarm.HashSize*2]) + } + ch, err := g.Get(ectx, addr) + if err != nil { + return err + } - chunkData := ch.Data()[8:] - subtrieLevel, subtrieSpan := j.chunkToSpan(ch.Data()) - _, parities := file.ReferenceCount(uint64(subtrieSpan), subtrieLevel, j.refLength != swarm.HashSize) + chunkData := ch.Data()[8:] + subtrieLevel, subtrieSpan := j.chunkToSpan(ch.Data()) + _, parities := file.ReferenceCount(uint64(subtrieSpan), subtrieLevel, j.refLength != swarm.HashSize) - return j.processChunkAddresses(ectx, fn, chunkData, subtrieSpan, parities) - }) - }(address, eg) + return j.processChunkAddresses(ectx, fn, chunkData, subtrieSpan, parities) + }) wg.Wait() } diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index c406654023a..af86a9fd061 100644 --- a/pkg/file/joiner/joiner_test.go +++ b/pkg/file/joiner/joiner_test.go @@ -18,19 +18,19 @@ import ( "time" "github.com/ethersphere/bee/pkg/cac" - "github.com/ethersphere/bee/pkg/encryption/store" "github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/pipeline/builder" "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/file/redundancy/getter" 
"github.com/ethersphere/bee/pkg/file/splitter" filetest "github.com/ethersphere/bee/pkg/file/testing" - "github.com/ethersphere/bee/pkg/replicas" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" testingc "github.com/ethersphere/bee/pkg/storage/testing" "github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/util/testutil" "gitlab.com/nolash/go-mockbytes" + "golang.org/x/sync/errgroup" ) func TestJoiner_ErrReferenceLength(t *testing.T) { @@ -89,7 +89,6 @@ func TestJoinerDecryptingStore_NormalChunk(t *testing.T) { t.Parallel() st := inmemchunkstore.New() - store := store.New(st) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -107,7 +106,7 @@ func TestJoinerDecryptingStore_NormalChunk(t *testing.T) { } // read back data and compare - joinReader, l, err := joiner.New(ctx, store, st, mockAddr) + joinReader, l, err := joiner.New(ctx, st, st, mockAddr) if err != nil { t.Fatal(err) } @@ -128,34 +127,34 @@ func TestJoinerDecryptingStore_NormalChunk(t *testing.T) { func TestJoinerWithReference(t *testing.T) { t.Parallel() - store := inmemchunkstore.New() + st := inmemchunkstore.New() ctx, cancel := context.WithCancel(context.Background()) defer cancel() // create root chunk and two data chunks referenced in the root chunk rootChunk := filetest.GenerateTestRandomFileChunk(swarm.ZeroAddress, swarm.ChunkSize*2, swarm.SectionSize*2) - err := store.Put(ctx, rootChunk) + err := st.Put(ctx, rootChunk) if err != nil { t.Fatal(err) } firstAddress := swarm.NewAddress(rootChunk.Data()[8 : swarm.SectionSize+8]) firstChunk := filetest.GenerateTestRandomFileChunk(firstAddress, swarm.ChunkSize, swarm.ChunkSize) - err = store.Put(ctx, firstChunk) + err = st.Put(ctx, firstChunk) if err != nil { t.Fatal(err) } secondAddress := swarm.NewAddress(rootChunk.Data()[swarm.SectionSize+8:]) secondChunk := filetest.GenerateTestRandomFileChunk(secondAddress, swarm.ChunkSize, swarm.ChunkSize) - err = store.Put(ctx, secondChunk) + err = st.Put(ctx, secondChunk) if err != nil { t.Fatal(err) } // read back data and compare - joinReader, l, err := joiner.New(ctx, store, store, rootChunk.Address()) + joinReader, l, err := joiner.New(ctx, st, st, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -955,12 +954,84 @@ func TestJoinerIterateChunkAddresses_Encrypted(t *testing.T) { } } +type mockPutter struct { + storage.ChunkStore + shards, parities chan swarm.Chunk + done chan struct{} +} + +func newMockPutter(store storage.ChunkStore, shardCnt, parityCnt int) *mockPutter { + return &mockPutter{ + ChunkStore: store, + done: make(chan struct{}, 1), + shards: make(chan swarm.Chunk, shardCnt), + parities: make(chan swarm.Chunk, parityCnt), + } +} + +func (m *mockPutter) Put(ctx context.Context, ch swarm.Chunk) error { + if len(m.shards) < cap(m.shards) { + m.shards <- ch + return nil + } + if len(m.parities) < cap(m.parities) { + m.parities <- ch + return nil + } + err := m.ChunkStore.Put(context.Background(), ch) + select { + case m.done <- struct{}{}: + default: + } + return err +} + +func (m *mockPutter) wait(ctx context.Context) { + select { + case <-m.done: + case <-ctx.Done(): + } + close(m.parities) + close(m.shards) +} + +func (m *mockPutter) store(cnt int) error { + n := 0 + for ch := range m.parities { + if err := m.ChunkStore.Put(context.Background(), ch); err != nil { + return err + } + n++ + if n == cnt { + return nil + } + } + for ch := range m.shards { + if err := m.ChunkStore.Put(context.Background(), ch); err != nil { + 
return err + } + n++ + if n == cnt { + break + } + } + return nil +} + func TestJoinerRedundancy(t *testing.T) { - t.Parallel() + + strategyTimeout := getter.StrategyTimeout + defer func() { getter.StrategyTimeout = strategyTimeout }() + getter.StrategyTimeout = 100 * time.Millisecond + for _, tc := range []struct { rLevel redundancy.Level encryptChunk bool }{ + { + redundancy.MEDIUM, + false, + }, { redundancy.MEDIUM, true, @@ -969,6 +1040,14 @@ func TestJoinerRedundancy(t *testing.T) { redundancy.STRONG, false, }, + { + redundancy.STRONG, + true, + }, + { + redundancy.INSANE, + false, + }, { redundancy.INSANE, true, @@ -977,22 +1056,29 @@ func TestJoinerRedundancy(t *testing.T) { redundancy.PARANOID, false, }, + { + redundancy.PARANOID, + true, + }, } { tc := tc t.Run(fmt.Sprintf("redundancy %d encryption %t", tc.rLevel, tc.encryptChunk), func(t *testing.T) { - ctx := context.Background() - ctx = replicas.SetLevel(ctx, tc.rLevel) - store := inmemchunkstore.New() - pipe := builder.NewPipelineBuilder(ctx, store, tc.encryptChunk, tc.rLevel) - // generate and store chunks - dataChunkCount := tc.rLevel.GetMaxShards() + 1 // generate a carrier chunk + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + shardCnt := tc.rLevel.GetMaxShards() + parityCnt := tc.rLevel.GetParities(shardCnt) if tc.encryptChunk { - dataChunkCount = tc.rLevel.GetMaxEncShards() + 1 + shardCnt = tc.rLevel.GetMaxEncShards() + parityCnt = tc.rLevel.GetEncParities(shardCnt) } - dataChunks := make([]swarm.Chunk, dataChunkCount) + store := inmemchunkstore.New() + putter := newMockPutter(store, shardCnt, parityCnt) + pipe := builder.NewPipelineBuilder(ctx, putter, tc.encryptChunk, tc.rLevel) + dataChunks := make([]swarm.Chunk, shardCnt) chunkSize := swarm.ChunkSize - for i := 0; i < dataChunkCount; i++ { + for i := 0; i < shardCnt; i++ { chunkData := make([]byte, chunkSize) _, err := io.ReadFull(rand.Reader, chunkData) if err != nil { @@ -1014,66 +1100,91 @@ func TestJoinerRedundancy(t *testing.T) { t.Fatal(err) } swarmAddr := swarm.NewAddress(sum) - joinReader, rootSpan, err := joiner.New(ctx, store, store, swarmAddr) + putter.wait(ctx) + _, err = store.Get(ctx, swarm.NewAddress(sum[:swarm.HashSize])) if err != nil { t.Fatal(err) } - // sanity checks - expectedRootSpan := chunkSize * dataChunkCount - if int64(expectedRootSpan) != rootSpan { - t.Fatalf("Expected root span %d. 
Got: %d", expectedRootSpan, rootSpan) - } // all data can be read back - readCheck := func() { - offset := int64(0) - for i := 0; i < dataChunkCount; i++ { - chunkData := make([]byte, chunkSize) - _, err = joinReader.ReadAt(chunkData, offset) - if err != nil { - t.Fatalf("joiner %d: read data at offset %v: %v", i, offset, err.Error()) - } - expectedChunkData := dataChunks[i].Data()[swarm.SpanSize:] - if !bytes.Equal(expectedChunkData, chunkData) { - t.Fatalf("joiner %d: read data at offset %v: %v", i, offset, err.Error()) - } - offset += int64(chunkSize) - } - } - readCheck() + readCheck := func(t *testing.T, expErr error) { + t.Helper() - // remove data chunks in order to trigger recovery - maxShards := tc.rLevel.GetMaxShards() - maxParities := tc.rLevel.GetParities(maxShards) - if tc.encryptChunk { - maxShards = tc.rLevel.GetMaxEncShards() - maxParities = tc.rLevel.GetEncParities(maxShards) - } - removeCount := min(maxParities, maxShards) - for i := 0; i < removeCount; i++ { - err := store.Delete(ctx, dataChunks[i].Address()) + ctx, cancel := context.WithTimeout(context.Background(), 8*getter.StrategyTimeout) + defer cancel() + ctx = getter.SetFetchTimeout(ctx, getter.StrategyTimeout) + joinReader, rootSpan, err := joiner.New(ctx, store, store, swarmAddr) if err != nil { t.Fatal(err) } + // sanity checks + expectedRootSpan := chunkSize * shardCnt + if int64(expectedRootSpan) != rootSpan { + t.Fatalf("Expected root span %d. Got: %d", expectedRootSpan, rootSpan) + } + i := 0 + eg, ectx := errgroup.WithContext(ctx) + for ; i < shardCnt; i++ { + select { + case <-ectx.Done(): + break + default: + } + i := i + eg.Go(func() error { + chunkData := make([]byte, chunkSize) + _, err := joinReader.ReadAt(chunkData, int64(i*chunkSize)) + if err != nil { + return err + } + select { + case <-ectx.Done(): + return ectx.Err() + default: + } + expectedChunkData := dataChunks[i].Data()[swarm.SpanSize:] + if !bytes.Equal(expectedChunkData, chunkData) { + return fmt.Errorf("data mismatch on chunk position %d", i) + } + return nil + }) + } + err = eg.Wait() + if !errors.Is(err, expErr) { + t.Fatalf("unexpected error reading chunkdata at chunk position %d: expected %v. 
got %v", i, expErr, err) + } } - // remove parity chunk - err = store.Delete(ctx, dataChunks[len(dataChunks)-1].Address()) - if err != nil { + t.Run("no recovery possible with no chunk stored", func(t *testing.T) { + readCheck(t, context.DeadlineExceeded) + }) + + if err := putter.store(shardCnt - 1); err != nil { t.Fatal(err) } + t.Run("no recovery possible with 1 short of shardCnt chunks stored", func(t *testing.T) { + readCheck(t, context.DeadlineExceeded) + }) - // check whether the data still be readable - readCheck() + if err := putter.store(1); err != nil { + t.Fatal(err) + } + t.Run("recovery given shardCnt chunks stored", func(t *testing.T) { + readCheck(t, nil) + }) - // remove root chunk and try to get it by disperse replica - err = store.Delete(ctx, swarmAddr) - if err != nil { + if err := putter.store(shardCnt + parityCnt); err != nil { t.Fatal(err) } - joinReader, _, err = joiner.New(ctx, store, store, swarmAddr) - if err != nil { + t.Run("success given shardCnt data chunks stored, no need for recovery", func(t *testing.T) { + readCheck(t, nil) + }) + // success after rootChunk deleted using replicas given shardCnt data chunks stored, no need for recovery + if err := store.Delete(ctx, swarm.NewAddress(swarmAddr.Bytes()[:swarm.HashSize])); err != nil { t.Fatal(err) } - readCheck() + t.Run("recover from replica if root deleted", func(t *testing.T) { + readCheck(t, nil) + }) + }) } } diff --git a/pkg/file/redundancy/getter/export_test.go b/pkg/file/redundancy/getter/export_test.go deleted file mode 100644 index e97c05bc499..00000000000 --- a/pkg/file/redundancy/getter/export_test.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2023 The Swarm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
-
-package getter
-
-import "errors"
-
-func IsCannotRecoverError(err error, missingChunks int) bool {
-	return errors.Is(err, cannotRecoverError(missingChunks))
-}
-
-func IsNotRecoveredError(err error, chAddress string) bool {
-	return errors.Is(err, notRecoveredError(chAddress))
-}
-
-func IsNoDataAddressIncludedError(err error, chAddress string) bool {
-	return errors.Is(err, noDataAddressIncludedError(chAddress))
-}
-
-func IsNoRedundancyError(err error, chAddress string) bool {
-	return errors.Is(err, noRedundancyError(chAddress))
-}
diff --git a/pkg/file/redundancy/getter/getter.go b/pkg/file/redundancy/getter/getter.go
index 30eaeaded66..1f1b847d856 100644
--- a/pkg/file/redundancy/getter/getter.go
+++ b/pkg/file/redundancy/getter/getter.go
@@ -6,299 +6,219 @@ package getter

 import (
 	"context"
-	"errors"
-	"fmt"
+	"io"
 	"sync"
+	"sync/atomic"
+	"time"

 	"github.com/ethersphere/bee/pkg/storage"
 	"github.com/ethersphere/bee/pkg/swarm"
 	"github.com/klauspost/reedsolomon"
 )

-/// ERRORS
-
-type cannotRecoverError int
-
-func (e cannotRecoverError) Error() string {
-	return fmt.Sprintf("redundancy getter: there are %d missing chunks in order to do recovery", e)
-}
-
-type notRecoveredError string
-
-func (e notRecoveredError) Error() string {
-	return fmt.Sprintf("redundancy getter: chunk with address %s is not recovered", string(e))
-}
-
-type noDataAddressIncludedError string
-
-func (e noDataAddressIncludedError) Error() string {
-	return fmt.Sprintf("redundancy getter: no data shard address given with chunk address %s", string(e))
-}
-
-type noRedundancyError string
-
-func (e noRedundancyError) Error() string {
-	return fmt.Sprintf("redundancy getter: cannot get chunk %s because no redundancy added", string(e))
-}
-
-/// TYPES
-
-// inflightChunk is initialized if recovery happened already
-type inflightChunk struct {
-	pos  int           // chunk index in the erasureData/intermediate chunk
-	wait chan struct{} // chunk is under retrieval
-}
-
-// getter retrieves children of an intermediate chunk
-// it caches sibling chunks if erasure decoding was called on the level already
-type getter struct {
+// decoder is a private implementation of storage.Getter
+// it retrieves children of an intermediate chunk potentially using erasure decoding
+// it caches sibling chunks if erasure decoding started already
+type decoder struct {
+	fetcher      storage.Getter  // network retrieval interface to fetch chunks
+	putter       storage.Putter  // interface to local storage to save reconstructed chunks
+	addrs        []swarm.Address // all addresses of the intermediate chunk
+	inflight     []atomic.Bool   // locks to protect wait channels and RS buffer
+	cache        map[string]int  // map from chunk address to shard position index
+	waits        []chan struct{} // wait channels for each chunk
+	rsbuf        [][]byte        // RS buffer of data + parity shards for erasure decoding
+	ready        chan struct{}   // signal channel for successful retrieval of shardCnt chunks
+	shardCnt     int             // number of data shards
+	parityCnt    int             // number of parity shards
+	wg           sync.WaitGroup  // wait group to wait for all goroutines to finish
+	mu           sync.Mutex      // mutex to protect buffer
+	fetchTimeout time.Duration   // timeout for each fetch
+	fetchedCnt   atomic.Int32    // count successful retrievals
+	cancel       func()          // cancel function for RS decoding
+	remove       func()          // callback to remove decoder from decoders cache
+}
+
+type Getter interface {
 	storage.Getter
-	storage.Putter
-	mu          sync.Mutex      // guards erasureData and cache
-	sAddresses  []swarm.Address // shard addresses
-	pAddresses  []swarm.Address // 
-	pAddresses  []swarm.Address          // parity addresses
-	cache       map[string]inflightChunk // map from chunk address to cached shard chunk data
-	erasureData [][]byte                 // data + parity shards for erasure decoding; TODO mutex
-}
-
-// New returns a getter object which is used to retrieve children of an intermediate chunk
-func New(sAddresses, pAddresses []swarm.Address, g storage.Getter, p storage.Putter) storage.Getter {
-	shards := len(sAddresses)
-	parities := len(pAddresses)
-	n := shards + parities
-	erasureData := make([][]byte, n)
-	cache := make(map[string]inflightChunk, n)
-	for i, addr := range sAddresses {
-		cache[addr.String()] = inflightChunk{
-			pos: i,
-			// wait channel initialization is needed when recovery starts
-		}
-	}
-	for i, addr := range pAddresses {
-		cache[addr.String()] = inflightChunk{
-			pos: len(sAddresses) + i,
-			// no wait channel initialization is needed since parity chunk addresses shouldn't be requested directly
-		}
-	}
-
-	return &getter{
-		Getter:      g,
-		Putter:      p,
-		sAddresses:  sAddresses,
-		pAddresses:  pAddresses,
-		cache:       cache,
-		erasureData: erasureData,
-	}
+	io.Closer
+}
+
+// New returns a decoder object used to retrieve children of an intermediate chunk
+func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter, strategy Strategy, strict bool, timeout time.Duration, remove func()) Getter {
+	ctx, cancel := context.WithCancel(context.Background())
+	size := len(addrs)
+
+	rsg := &decoder{
+		fetcher:      g,
+		putter:       p,
+		addrs:        addrs,
+		inflight:     make([]atomic.Bool, size),
+		cache:        make(map[string]int, size),
+		waits:        make([]chan struct{}, shardCnt),
+		rsbuf:        make([][]byte, size),
+		ready:        make(chan struct{}, 1),
+		cancel:       cancel,
+		remove:       remove,
+		shardCnt:     shardCnt,
+		parityCnt:    size - shardCnt,
+		fetchTimeout: timeout,
+	}
+
+	// after init, cache and wait channels are immutable, need no locking
+	for i := 0; i < shardCnt; i++ {
+		rsg.cache[addrs[i].ByteString()] = i
+		rsg.waits[i] = make(chan struct{})
+	}
+
+	// prefetch chunks according to strategy
+	rsg.wg.Add(1)
+	go func() {
+		rsg.prefetch(ctx, strategy, strict)
+		rsg.wg.Done()
+	}()
+	return rsg
 }
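Worth noting between New and Get: the inflight flags implement a lock-free singleflight, so however many concurrent readers ask for the same shard, only one fetch is launched. A toy sketch of the pattern (an illustration only, not part of the patch):

	package main

	import (
		"fmt"
		"sync"
		"sync/atomic"
	)

	func main() {
		var inflight atomic.Bool // in the decoder, one such flag guards each shard
		var wg sync.WaitGroup

		launched := 0
		for i := 0; i < 4; i++ { // four concurrent readers of one shard
			wg.Add(1)
			go func() {
				defer wg.Done()
				// CompareAndSwap flips false->true exactly once, so only
				// the first caller would start the actual fetch
				if inflight.CompareAndSwap(false, true) {
					launched++
				}
			}()
		}
		wg.Wait()
		fmt.Println(launched) // 1
	}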
 // Get will call parities and other sibling chunks if the chunk address cannot be retrieved
 // assumes it is called for data shards only
-func (g *getter) Get(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) {
-	g.mu.Lock()
-	cValue, ok := g.cache[addr.String()]
-	g.mu.Unlock()
-	if !ok || cValue.pos >= len(g.sAddresses) {
-		return nil, noDataAddressIncludedError(addr.String())
-	}
-
-	if cValue.wait != nil { // equals to g.processing but does not need lock again
-		return g.getAfterProcessed(ctx, addr)
-	}
-
-	ch, err := g.Getter.Get(ctx, addr)
-	if err == nil {
-		return ch, nil
-	}
-	if errors.Is(storage.ErrNotFound, err) && len(g.pAddresses) == 0 {
-		return nil, noRedundancyError(addr.String())
-	}
-
-	// during the get, the recovery may have started by other process
-	if g.processing(addr) {
-		return g.getAfterProcessed(ctx, addr)
-	}
-
-	return g.executeStrategies(ctx, addr)
-}
-
-// Inc increments the counter for the given key.
-func (g *getter) setErasureData(index int, data []byte) {
-	g.mu.Lock()
-	g.erasureData[index] = data
-	g.mu.Unlock()
-}
-
-// processing returns whether the recovery workflow has been started
-func (g *getter) processing(addr swarm.Address) bool {
-	g.mu.Lock()
-	defer g.mu.Unlock()
-	iCh := g.cache[addr.String()]
-	return iCh.wait != nil
-}
-
-// getAfterProcessed returns chunk from the cache
-func (g *getter) getAfterProcessed(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) {
-	g.mu.Lock()
-	c, ok := g.cache[addr.String()]
-	// sanity check
+func (g *decoder) Get(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) {
+	i, ok := g.cache[addr.ByteString()]
 	if !ok {
-		return nil, fmt.Errorf("redundancy getter: chunk %s should have been in the cache", addr.String())
+		return nil, storage.ErrNotFound
 	}
+	if g.fly(i, true) {
+		g.wg.Add(1)
+		go func() {
+			g.fetch(ctx, i)
+			g.wg.Done()
+		}()
 	}
-
-	cacheData := g.erasureData[c.pos]
-	g.mu.Unlock()
-	if cacheData != nil {
-		return g.cacheDataToChunk(addr, cacheData)
+	select {
+	case <-g.waits[i]:
+		return swarm.NewChunk(addr, g.rsbuf[i]), nil
 	case <-ctx.Done():
 		return nil, ctx.Err()
 	}
 }

-// executeStrategies executes recovery strategies from redundancy for the given swarm address
-func (g *getter) executeStrategies(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) {
-	g.initWaitChannels()
-	err := g.cautiousStrategy(ctx)
-	if err != nil {
-		g.closeWaitChannels()
-		return nil, err
-	}
-
-	return g.getAfterProcessed(ctx, addr)
+// fly commits to retrieve the chunk (fly and land)
+// it marks a chunk as inflight and returns true unless it is already inflight
+// the atomic bool implements a singleflight pattern
+func (g *decoder) fly(i int, up bool) (success bool) {
+	return g.inflight[i].CompareAndSwap(!up, up)
 }

-// initWaitChannels initializes the wait channels in the cache mapping which indicates the start of the recovery process as well
-func (g *getter) initWaitChannels() {
-	g.mu.Lock()
-	for _, addr := range g.sAddresses {
-		iCh := g.cache[addr.String()]
-		iCh.wait = make(chan struct{})
-		g.cache[addr.String()] = iCh
+// fetch retrieves a chunk from the underlying storage
+// it must be called asynchronously and only once for each chunk (singleflight pattern)
+// it races with erasure recovery which takes precedence even if it started later
+// due to the fact that erasure recovery could only implement global locking on all shards
+func (g *decoder) fetch(ctx context.Context, i int) {
+	ch, err := g.fetcher.Get(ctx, g.addrs[i])
+	if err != nil {
+		_ = g.fly(i, false) // unset inflight
+		return
 	}
-	g.mu.Unlock()
-}

-// closeChannls closes all pending channels
-func (g *getter) closeWaitChannels() {
-	for _, addr := range g.sAddresses {
-		c := g.cache[addr.String()]
-		if !channelIsClosed(c.wait) {
-			close(c.wait)
 	g.mu.Lock()
+	defer g.mu.Unlock()
+	if i < len(g.waits) {
+		select {
+		case <-g.waits[i]: // if chunk is retrieved, ignore
+			return
+		default:
 		}
 	}

-// cautiousStrategy requests all chunks (data and parity) on the level
-// and if it has enough data for erasure decoding then it cancel other requests
-func (g *getter) cautiousStrategy(ctx context.Context) error {
-	requiredChunks := len(g.sAddresses)
-	subContext, cancelContext := context.WithCancel(ctx)
-	retrievedCh := make(chan struct{}, requiredChunks+len(g.pAddresses))
-	var wg sync.WaitGroup
-
-	addresses := append(g.sAddresses, g.pAddresses...)
-	for _, a := range addresses {
-		wg.Add(1)
-		c := g.cache[a.String()]
-		go func(a swarm.Address, c inflightChunk) {
-			defer wg.Done()
-			// enrypted chunk data should remain encrypted
-			address := swarm.NewAddress(a.Bytes()[:swarm.HashSize])
-			ch, err := g.Getter.Get(subContext, address)
-			if err != nil {
-				return
-			}
-			g.setErasureData(c.pos, ch.Data())
-			if c.pos < len(g.sAddresses) && !channelIsClosed(c.wait) {
-				close(c.wait)
-			}
-			retrievedCh <- struct{}{}
-		}(a, c)
	}
-	// Goroutine to wait for WaitGroup completion
-	go func() {
-		wg.Wait()
-		close(retrievedCh)
-	}()
-	retrieved := 0
-	for retrieved < requiredChunks {
-		_, ok := <-retrievedCh
-		if !ok {
-			break
-		}
-		retrieved++
+	select {
+	case <-ctx.Done(): // if context is cancelled, ignore
+		_ = g.fly(i, false) // unset inflight
+		return
+	default:
 	}
-	cancelContext()

-	if retrieved < requiredChunks {
-		return cannotRecoverError(requiredChunks - retrieved)
+	// write chunk to rsbuf and signal waiters
+	g.rsbuf[i] = ch.Data() // save the chunk in the RS buffer
+	if i < len(g.waits) {
+		close(g.waits[i]) // signal that the chunk is retrieved
 	}
+	// if all chunks are retrieved, signal ready
+	n := g.fetchedCnt.Add(1)
+	if n == int32(g.shardCnt) {
+		close(g.ready) // signal that just enough chunks are retrieved for decoding
	}
}

-	return g.erasureDecode(ctx)
+// missing gathers missing data shards not yet retrieved
+// it sets the chunk as inflight and returns the index of the missing data shards
+func (g *decoder) missing() (m []int) {
+	for i := 0; i < g.shardCnt; i++ {
+		select {
+		case <-g.waits[i]: // if chunk is retrieved, ignore
+			continue
+		default:
+		}
+		_ = g.fly(i, true) // commit (RS) or will commit to retrieve the chunk
+		m = append(m, i) // remember the missing chunk
+	}
+	return m
 }

-// erasureDecode perform Reed-Solomon recovery on data
-// assumes it is called after filling up cache with the required amount of shards and parities
-func (g *getter) erasureDecode(ctx context.Context) error {
-	enc, err := reedsolomon.New(len(g.sAddresses), len(g.pAddresses))
+// decode uses Reed-Solomon erasure coding decoder to recover data shards
+// it must be called after shardCnt shards are retrieved
+// it must be called under g.mu mutex protection
+func (g *decoder) decode(ctx context.Context) error {
+	enc, err := reedsolomon.New(g.shardCnt, g.parityCnt)
 	if err != nil {
 		return err
 	}

-	// missing chunks
-	var missingIndices []int
-	for i := range g.sAddresses {
-		if g.erasureData[i] == nil {
-			missingIndices = append(missingIndices, i)
-		}
-	}
+	// decode data
+	return enc.ReconstructData(g.rsbuf)
+}
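For context on what decode delegates to: klauspost/reedsolomon's ReconstructData rebuilds only the missing (nil) data shards, given any shardCnt intact shards of equal size. A minimal standalone sketch (an illustration only):

	package main

	import (
		"fmt"

		"github.com/klauspost/reedsolomon"
	)

	func main() {
		const shardCnt, parityCnt = 4, 2
		enc, _ := reedsolomon.New(shardCnt, parityCnt)

		// six equal-sized shards: four data followed by two parity
		shards := make([][]byte, shardCnt+parityCnt)
		for i := range shards {
			shards[i] = []byte{byte(i), byte(i)}
		}
		_ = enc.Encode(shards) // fill in the parity shards

		shards[0], shards[2] = nil, nil // lose two data shards

		// any four intact shards suffice to rebuild the nil data shards
		if err := enc.ReconstructData(shards); err != nil {
			fmt.Println("reconstruct:", err)
			return
		}
		fmt.Println(shards[0], shards[2]) // [0 0] [2 2]
	}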
save reconstructed chunks
+func (g *decoder) recover(ctx context.Context) error {
+	// buffer lock acquired
 	g.mu.Lock()
 	defer g.mu.Unlock()
-	err = enc.ReconstructData(g.erasureData)
-	if err != nil {
-		return err
+
+	// gather missing shards
+	m := g.missing()
+	if len(m) == 0 {
+		return nil
 	}
-	g.closeWaitChannels()
-	// save missing chunks
-	for _, index := range missingIndices {
-		data := g.erasureData[index]
-		addr := g.sAddresses[index]
-		err := g.Putter.Put(ctx, swarm.NewChunk(addr, data))
-		if err != nil {
-			return err
-		}
+
+	// decode using Reed-Solomon decoder
+	if err := g.decode(ctx); err != nil {
+		return err
 	}
-	return nil
-}
 
-// cacheDataToChunk transforms passed chunk data to legit swarm chunk
-func (g *getter) cacheDataToChunk(addr swarm.Address, chData []byte) (swarm.Chunk, error) {
-	if chData == nil {
-		return nil, notRecoveredError(addr.String())
+	// close wait channels for missing chunks
+	for _, i := range m {
+		close(g.waits[i])
 	}
-	// if g.encrypted {
-	// 	data, err := store.DecryptChunkData(chData, addr.Bytes()[swarm.HashSize:])
-	// 	if err != nil {
-	// 		return nil, err
-	// 	}
-	// 	chData = data
-	// }
-	return swarm.NewChunk(addr, chData), nil
+	// save chunks
+	return g.save(ctx, m)
 }
 
-func channelIsClosed(wait <-chan struct{}) bool {
-	select {
-	case _, ok := <-wait:
-		return !ok
-	default:
-		return false
+// save iterates over reconstructed shards and puts the corresponding chunks to local storage
+func (g *decoder) save(ctx context.Context, missing []int) error {
+	for _, i := range missing {
+		if err := g.putter.Put(ctx, swarm.NewChunk(g.addrs[i], g.rsbuf[i])); err != nil {
+			return err
+		}
 	}
+	return nil
+}
+
+func (g *decoder) Close() error {
+	g.cancel()
+	g.wg.Wait()
+	g.remove()
+	return nil
 }
diff --git a/pkg/file/redundancy/getter/getter_test.go b/pkg/file/redundancy/getter/getter_test.go
index addee4b6028..d19b5535057 100644
--- a/pkg/file/redundancy/getter/getter_test.go
+++ b/pkg/file/redundancy/getter/getter_test.go
@@ -9,9 +9,13 @@ import (
 	"context"
 	"crypto/rand"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"io"
+	mrand "math/rand"
+	"sync"
 	"testing"
+	"time"
 
 	"github.com/ethersphere/bee/pkg/cac"
 	"github.com/ethersphere/bee/pkg/file/redundancy/getter"
@@ -19,199 +23,383 @@ import (
 	inmem "github.com/ethersphere/bee/pkg/storage/inmemchunkstore"
 	"github.com/ethersphere/bee/pkg/swarm"
 	"github.com/klauspost/reedsolomon"
+	"golang.org/x/sync/errgroup"
 )
 
-func initData(ctx context.Context, erasureBuffer [][]byte, dataShardCount int, s storage.ChunkStore) ([]swarm.Address, []swarm.Address, error) {
-	spanBytes := make([]byte, 8)
-	binary.LittleEndian.PutUint64(spanBytes, swarm.ChunkSize)
-	for i := 0; i < dataShardCount; i++ {
-		erasureBuffer[i] = make([]byte, swarm.HashSize)
-		_, err := io.ReadFull(rand.Reader, erasureBuffer[i])
-		if err != nil {
-			return nil, nil, err
+type delayed struct {
+	storage.ChunkStore
+	cache map[string]time.Duration
+	mu    sync.Mutex
+}
+
+func (d *delayed) delay(addr swarm.Address, delay time.Duration) {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	d.cache[addr.String()] = delay
+}
+
+func (d *delayed) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	if delay, ok := d.cache[addr.String()]; ok && delay > 0 {
+		select {
+		case <-time.After(delay):
+			delete(d.cache, addr.String())
+		case <-ctx.Done():
+			return nil, ctx.Err()
 		}
-		copy(erasureBuffer[i], spanBytes)
-	}
-	// make parity shards
-	for i := dataShardCount; i < len(erasureBuffer); i++ {
-		erasureBuffer[i] = make([]byte, 
swarm.HashSize) - } - // generate parity chunks - rs, err := reedsolomon.New(dataShardCount, len(erasureBuffer)-dataShardCount) - if err != nil { - return nil, nil, err } - err = rs.Encode(erasureBuffer) - if err != nil { - return nil, nil, err + return d.ChunkStore.Get(ctx, addr) +} + +// TestGetter tests the retrieval of chunks with missing data shards +// using the RACE strategy for a number of erasure code parameters +func TestGetterRACE(t *testing.T) { + type getterTest struct { + bufSize int + shardCnt int + erasureCnt int } - // calculate chunk addresses and upload to the store - var sAddresses, pAddresses []swarm.Address - for i := 0; i < dataShardCount; i++ { - chunk, err := cac.NewWithDataSpan(erasureBuffer[i]) - if err != nil { - return nil, nil, err - } - err = s.Put(ctx, chunk) - if err != nil { - return nil, nil, err + + var tcs []getterTest + for bufSize := 3; bufSize <= 128; bufSize += 21 { + for shardCnt := bufSize/2 + 1; shardCnt <= bufSize; shardCnt += 21 { + parityCnt := bufSize - shardCnt + erasures := mrand.Perm(parityCnt - 1) + if len(erasures) > 3 { + erasures = erasures[:3] + } + for _, erasureCnt := range erasures { + tcs = append(tcs, getterTest{bufSize, shardCnt, erasureCnt}) + } + tcs = append(tcs, getterTest{bufSize, shardCnt, parityCnt}, getterTest{bufSize, shardCnt, parityCnt + 1}) + erasures = mrand.Perm(shardCnt - 1) + if len(erasures) > 3 { + erasures = erasures[:3] + } + for _, erasureCnt := range erasures { + tcs = append(tcs, getterTest{bufSize, shardCnt, erasureCnt + parityCnt + 1}) + } } - sAddresses = append(sAddresses, chunk.Address()) } - for i := dataShardCount; i < len(erasureBuffer); i++ { - chunk, err := cac.NewWithDataSpan(erasureBuffer[i]) - if err != nil { - return nil, nil, err - } - err = s.Put(ctx, chunk) - if err != nil { - return nil, nil, err + t.Run("GET with RACE", func(t *testing.T) { + t.Parallel() + + for _, tc := range tcs { + t.Run(fmt.Sprintf("data/total/missing=%d/%d/%d", tc.shardCnt, tc.bufSize, tc.erasureCnt), func(t *testing.T) { + testDecodingRACE(t, tc.bufSize, tc.shardCnt, tc.erasureCnt) + }) } - pAddresses = append(pAddresses, chunk.Address()) - } + }) +} - return sAddresses, pAddresses, err +// TestGetterFallback tests the retrieval of chunks with missing data shards +// using the strict or fallback mode starting with NONE and DATA strategies +func TestGetterFallback(t *testing.T) { + t.Run("GET", func(t *testing.T) { + t.Run("NONE", func(t *testing.T) { + t.Run("strict", func(t *testing.T) { + testDecodingFallback(t, getter.NONE, true) + }) + t.Run("fallback", func(t *testing.T) { + testDecodingFallback(t, getter.NONE, false) + }) + }) + t.Run("DATA", func(t *testing.T) { + t.Run("strict", func(t *testing.T) { + testDecodingFallback(t, getter.DATA, true) + }) + t.Run("fallback", func(t *testing.T) { + testDecodingFallback(t, getter.DATA, false) + }) + }) + }) } -func dataShardsAvailable(ctx context.Context, sAddresses []swarm.Address, s storage.ChunkStore) error { - for i := 0; i < len(sAddresses); i++ { - has, err := s.Has(ctx, sAddresses[i]) - if err != nil { - return err - } - if !has { - return fmt.Errorf("datashard %d is not available", i) +func testDecodingRACE(t *testing.T, bufSize, shardCnt, erasureCnt int) { + t.Helper() + + strategyTimeout := getter.StrategyTimeout + defer func() { getter.StrategyTimeout = strategyTimeout }() + getter.StrategyTimeout = 100 * time.Millisecond + + store := inmem.New() + buf := make([][]byte, bufSize) + addrs := initData(t, buf, shardCnt, store) + + var addr swarm.Address + 
erasures := forget(t, store, addrs, erasureCnt) + for _, i := range erasures { + if i < shardCnt { + addr = addrs[i] + break } } - return nil + if len(addr.Bytes()) == 0 { + t.Skip("no data shard erased") + } + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + g := getter.New(addrs, shardCnt, store, store, getter.RACE, false, 2*getter.StrategyTimeout, func() {}) + defer g.Close() + parityCnt := len(buf) - shardCnt + q := make(chan error, 1) + go func() { + _, err := g.Get(ctx, addr) + q <- err + }() + err := context.DeadlineExceeded + select { + case err = <-q: + case <-time.After(getter.StrategyTimeout * 4): + } + switch { + case erasureCnt > parityCnt: + t.Run("unable to recover", func(t *testing.T) { + if !errors.Is(err, storage.ErrNotFound) && + !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected not found error or deadline exceeded, got %v", err) + } + }) + case erasureCnt <= parityCnt: + t.Run("will recover", func(t *testing.T) { + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + checkShardsAvailable(t, store, addrs[:shardCnt], buf[:shardCnt]) + }) + } } -func TestDecoding(t *testing.T) { - s := inmem.New() +// testDecodingFallback tests the retrieval of chunks with missing data shards +func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { + t.Helper() - erasureBuffer := make([][]byte, 128) - dataShardCount := 100 - ctx := context.TODO() + strategyTimeout := getter.StrategyTimeout + defer func() { getter.StrategyTimeout = strategyTimeout }() + getter.StrategyTimeout = 100 * time.Millisecond - sAddresses, pAddresses, err := initData(ctx, erasureBuffer, dataShardCount, s) + bufSize := 12 + shardCnt := 6 + store := &delayed{ChunkStore: inmem.New(), cache: make(map[string]time.Duration)} + buf := make([][]byte, bufSize) + addrs := initData(t, buf, shardCnt, store) + + // erase two data shards + delayed, erased := 1, 0 + ctx := context.TODO() + err := store.Delete(ctx, addrs[erased]) if err != nil { t.Fatal(err) } - getter := getter.New(sAddresses, pAddresses, s, s) - // sanity check - all data shards are retrievable - for i := 0; i < dataShardCount; i++ { - ch, err := getter.Get(ctx, sAddresses[i]) + // context for enforced retrievals with long timeout + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + // signal channels for delayed and erased chunk retrieval + waitDelayed, waitErased := make(chan error, 1), make(chan error, 1) + + // complete retrieval of delayed chunk by putting it into the store after a while + delay := +getter.StrategyTimeout / 4 + if s == getter.NONE { + delay += getter.StrategyTimeout + } + store.delay(addrs[delayed], delay) + // create getter + start := time.Now() + g := getter.New(addrs, shardCnt, store, store, s, strict, getter.StrategyTimeout/2, func() {}) + defer g.Close() + + // launch delayed and erased chunk retrieval + wg := sync.WaitGroup{} + // defer wg.Wait() + wg.Add(2) + // signal using the waitDelayed and waitErased channels when + // delayed and erased chunk retrieval completes + go func() { + defer wg.Done() + _, err := g.Get(ctx, addrs[delayed]) + waitDelayed <- err + }() + go func() { + defer wg.Done() + _, err := g.Get(ctx, addrs[erased]) + waitErased <- err + }() + + // set timeouts for the cases + var timeout time.Duration + switch { + case strict: + timeout = 2*getter.StrategyTimeout - 10*time.Millisecond + case s == getter.NONE: + timeout = 4*getter.StrategyTimeout - 10*time.Millisecond + case s == getter.DATA: + timeout = 3*getter.StrategyTimeout - 
10*time.Millisecond + } + + // wait for delayed chunk retrieval to complete + select { + case err := <-waitDelayed: if err != nil { - t.Fatalf("address %s at index %d is not retrievable by redundancy getter", sAddresses[i], i) + t.Fatal("unexpected error", err) + } + round := time.Since(start) / getter.StrategyTimeout + switch { + case strict && s == getter.NONE: + if round < 1 { + t.Fatalf("unexpected completion of delayed chunk retrieval. got round %d", round) + } + case s == getter.NONE: + if round < 1 { + t.Fatalf("unexpected completion of delayed chunk retrieval. got round %d", round) + } + if round > 2 { + t.Fatalf("unexpected late completion of delayed chunk retrieval. got round %d", round) + } + case s == getter.DATA: + if round > 0 { + t.Fatalf("unexpected late completion of delayed chunk retrieval. got round %d", round) + } + } + + checkShardsAvailable(t, store, addrs[delayed:], buf[delayed:]) + // wait for erased chunk retrieval to complete + select { + case err := <-waitErased: + if err != nil { + t.Fatal("unexpected error", err) + } + round = time.Since(start) / getter.StrategyTimeout + switch { + case strict: + t.Fatalf("unexpected completion of erased chunk retrieval. got round %d", round) + case s == getter.NONE: + if round < 2 { + t.Fatalf("unexpected early completion of erased chunk retrieval. got round %d", round) + } + if round > 2 { + t.Fatalf("unexpected late completion of erased chunk retrieval. got round %d", round) + } + case s == getter.DATA: + if round < 1 { + t.Fatalf("unexpected early completion of erased chunk retrieval. got round %d", round) + } + if round > 1 { + t.Fatalf("unexpected late completion of delayed chunk retrieval. got round %d", round) + } + } + checkShardsAvailable(t, store, addrs[:erased], buf[:erased]) + + case <-time.After(getter.StrategyTimeout * 2): + if !strict { + t.Fatal("unexpected timeout using strategy", s, "with strict", strict) + } } - if !bytes.Equal(ch.Data(), erasureBuffer[i]) { - t.Fatalf("retrieved chunk data differ from the original at index %d", i) + case <-time.After(timeout): + if !strict || s != getter.NONE { + t.Fatal("unexpected timeout using strategy", s, "with strict", strict) } } +} + +func initData(t *testing.T, buf [][]byte, shardCnt int, s storage.ChunkStore) []swarm.Address { + t.Helper() + spanBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(spanBytes, swarm.ChunkSize) - // remove maximum possible chunks from storage - removeChunkCount := len(erasureBuffer) - dataShardCount - for i := 0; i < removeChunkCount; i++ { - err := s.Delete(ctx, sAddresses[i]) + for i := 0; i < len(buf); i++ { + buf[i] = make([]byte, swarm.ChunkWithSpanSize) + if i >= shardCnt { + continue + } + _, err := io.ReadFull(rand.Reader, buf[i]) if err != nil { t.Fatal(err) } + copy(buf[i], spanBytes) } - err = dataShardsAvailable(ctx, sAddresses, s) // sanity check - if err == nil { - t.Fatalf("some data shards should be missing") - } - ch, err := getter.Get(ctx, sAddresses[0]) + // fill in parity chunks + rs, err := reedsolomon.New(shardCnt, len(buf)-shardCnt) if err != nil { t.Fatal(err) } - if !bytes.Equal(ch.Data(), erasureBuffer[0]) { - t.Fatalf("retrieved chunk data differ from the original at index %d", 0) - } - err = dataShardsAvailable(ctx, sAddresses, s) + err = rs.Encode(buf) if err != nil { t.Fatal(err) } -} - -func TestRecoveryLimits(t *testing.T) { - s := inmem.New() - erasureBuffer := make([][]byte, 8) - dataShardCount := 5 + // calculate chunk addresses and upload to the store + addrs := make([]swarm.Address, len(buf)) 
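	// For reference, a minimal standalone sketch (not part of this test) of the
	// klauspost/reedsolomon round trip these helpers rely on, using toy 2-byte
	// shards instead of chunk-sized buffers:
	//
	//	enc, _ := reedsolomon.New(4, 2)     // 4 data shards, 2 parity shards
	//	shards := make([][]byte, 6)
	//	for i := 0; i < 4; i++ {
	//		shards[i] = []byte{byte(i), byte(i)} // data to protect
	//	}
	//	for i := 4; i < 6; i++ {
	//		shards[i] = make([]byte, 2) // parity buffers filled by Encode
	//	}
	//	_ = enc.Encode(shards)          // compute parity shards
	//	shards[1] = nil                 // simulate a lost data shard
	//	_ = enc.ReconstructData(shards) // shards[1] is restored in place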
ctx := context.TODO() - - sAddresses, pAddresses, err := initData(ctx, erasureBuffer, dataShardCount, s) - if err != nil { - t.Fatal(err) - } - - _getter := getter.New(sAddresses, pAddresses, s, s) - - // remove more chunks that can be corrected by erasure code - removeChunkCount := len(erasureBuffer) - dataShardCount + 1 - for i := 0; i < removeChunkCount; i++ { - err := s.Delete(ctx, sAddresses[i]) + for i := 0; i < len(buf); i++ { + chunk, err := cac.NewWithDataSpan(buf[i]) if err != nil { t.Fatal(err) } - } - _, err = _getter.Get(ctx, sAddresses[0]) - if !getter.IsCannotRecoverError(err, 1) { - t.Fatal(err) + err = s.Put(ctx, chunk) + if err != nil { + t.Fatal(err) + } + addrs[i] = chunk.Address() } - // call once more - _, err = _getter.Get(ctx, sAddresses[0]) - if !getter.IsNotRecoveredError(err, sAddresses[0].String()) { - t.Fatal(err) - } + return addrs } -func TestNoRedundancyOnRecovery(t *testing.T) { - s := inmem.New() - - dataShardCount := 5 - erasureBuffer := make([][]byte, dataShardCount) - ctx := context.TODO() - - sAddresses, pAddresses, err := initData(ctx, erasureBuffer, dataShardCount, s) - if err != nil { - t.Fatal(err) - } - - _getter := getter.New(sAddresses, pAddresses, s, s) - - // remove one chunk that trying to request later - err = s.Delete(ctx, sAddresses[0]) - if err != nil { - t.Fatal(err) +func checkShardsAvailable(t *testing.T, s storage.ChunkStore, addrs []swarm.Address, data [][]byte) { + t.Helper() + eg, ctx := errgroup.WithContext(context.Background()) + for i, addr := range addrs { + i := i + addr := addr + eg.Go(func() (err error) { + var delay time.Duration + var ch swarm.Chunk + for i := 0; i < 30; i++ { + select { + case <-ctx.Done(): + return ctx.Err() + default: + <-time.After(delay) + delay = 50 * time.Millisecond + } + ch, err = s.Get(ctx, addr) + if err == nil { + break + } + err = fmt.Errorf("datashard %d with address %v is not available: %w", i, addr, err) + select { + case <-ctx.Done(): + return ctx.Err() + default: + <-time.After(delay) + delay = 50 * time.Millisecond + } + } + if err == nil && !bytes.Equal(ch.Data(), data[i]) { + return fmt.Errorf("datashard %d has incorrect data", i) + } + return err + }) } - _, err = _getter.Get(ctx, sAddresses[0]) - if !getter.IsNoRedundancyError(err, sAddresses[0].String()) { + if err := eg.Wait(); err != nil { t.Fatal(err) } } -func TestNoDataAddressIncluded(t *testing.T) { - s := inmem.New() +func forget(t *testing.T, store storage.ChunkStore, addrs []swarm.Address, erasureCnt int) (erasures []int) { + t.Helper() - erasureBuffer := make([][]byte, 8) - dataShardCount := 5 ctx := context.TODO() - - sAddresses, pAddresses, err := initData(ctx, erasureBuffer, dataShardCount, s) - if err != nil { - t.Fatal(err) - } - - _getter := getter.New(sAddresses, pAddresses, s, s) - - // trying to retrieve a parity address - _, err = _getter.Get(ctx, pAddresses[0]) - if !getter.IsNoDataAddressIncludedError(err, pAddresses[0].String()) { - t.Fatal(err) + erasures = mrand.Perm(len(addrs))[:erasureCnt] + for _, i := range erasures { + err := store.Delete(ctx, addrs[i]) + if err != nil { + t.Fatal(err) + } } + return erasures } diff --git a/pkg/file/redundancy/getter/strategies.go b/pkg/file/redundancy/getter/strategies.go new file mode 100644 index 00000000000..3d2a39a3816 --- /dev/null +++ b/pkg/file/redundancy/getter/strategies.go @@ -0,0 +1,133 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
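// A minimal usage sketch (illustrative, not part of this change): callers
// configure retrieval behaviour on the context that is later handed to the
// joiner, using the helpers defined below (assumes the caller imports this
// package and "time"):
//
//	ctx := context.Background()
//	ctx = getter.SetStrategy(ctx, getter.DATA)        // prefetch data shards only
//	ctx = getter.SetStrict(ctx, false)                // allow fallback to the next strategy
//	ctx = getter.SetFetchTimeout(ctx, 30*time.Second) // timeout for each fetch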
+ +package getter + +import ( + "context" + "fmt" + "time" +) + +var ( + StrategyTimeout = 500 * time.Millisecond // timeout for each strategy +) + +type ( + strategyKey struct{} + modeKey struct{} + fetcherTimeoutKey struct{} + Strategy = int +) + +const ( + NONE Strategy = iota // no prefetching and no decoding + DATA // just retrieve data shards no decoding + PROX // proximity driven selective fetching + RACE // aggressive fetching racing all chunks + strategyCnt +) + +// GetParamsFromContext extracts the strategy and strict mode from the context +func GetParamsFromContext(ctx context.Context) (s Strategy, strict bool, fetcherTimeout time.Duration) { + var ok bool + s, ok = ctx.Value(strategyKey{}).(Strategy) + if !ok { + s = RACE + } + strict, _ = ctx.Value(modeKey{}).(bool) + fetcherTimeout, _ = ctx.Value(fetcherTimeoutKey{}).(time.Duration) + return s, strict, fetcherTimeout +} + +// SetFetchTimeout sets the timeout for each fetch +func SetFetchTimeout(ctx context.Context, timeout time.Duration) context.Context { + return context.WithValue(ctx, fetcherTimeoutKey{}, timeout) +} + +// SetStrategy sets the strategy for the retrieval +func SetStrategy(ctx context.Context, s Strategy) context.Context { + return context.WithValue(ctx, strategyKey{}, s) +} + +// SetStrict sets the strict mode for the retrieval +func SetStrict(ctx context.Context, strict bool) context.Context { + return context.WithValue(ctx, modeKey{}, strict) +} + +func (g *decoder) prefetch(ctx context.Context, strategy int, strict bool) { + if strict && strategy == NONE { + return + } + + defer g.remove() + var cancels []func() + cancelAll := func() { + for _, cancel := range cancels { + cancel() + } + } + defer cancelAll() + run := func(s Strategy) error { + if s == PROX { // NOT IMPLEMENTED + return fmt.Errorf("strategy %d not implemented", s) + } + + var stop <-chan time.Time + if s < RACE { + timer := time.NewTimer(StrategyTimeout) + defer timer.Stop() + stop = timer.C + } + lctx, cancel := context.WithTimeout(ctx, g.fetchTimeout) + cancels = append(cancels, cancel) + prefetch(lctx, g, s) + + select { + // successfully retrieved shardCnt number of chunks + case <-g.ready: + cancelAll() + case <-stop: + return fmt.Errorf("prefetching with strategy %d timed out", s) + case <-ctx.Done(): + return nil + } + // call the erasure decoder + // if decoding is successful terminate the prefetch loop + return g.recover(ctx) // context to cancel when shardCnt chunks are retrieved + } + var err error + for s := strategy; s == strategy || (err != nil && !strict && s < strategyCnt); s++ { + err = run(s) + } +} + +// prefetch launches the retrieval of chunks based on the strategy +func prefetch(ctx context.Context, g *decoder, s Strategy) { + var m []int + switch s { + case NONE: + return + case DATA: + // only retrieve data shards + m = g.missing() + case PROX: + // proximity driven selective fetching + // NOT IMPLEMENTED + case RACE: + // retrieve all chunks at once enabling race among chunks + m = g.missing() + for i := g.shardCnt; i < len(g.addrs); i++ { + m = append(m, i) + } + } + for _, i := range m { + i := i + g.wg.Add(1) + go func() { + g.fetch(ctx, i) + g.wg.Done() + }() + } +} diff --git a/pkg/file/utils.go b/pkg/file/utils.go index 021109637cd..35f21414b82 100644 --- a/pkg/file/utils.go +++ b/pkg/file/utils.go @@ -12,14 +12,16 @@ import ( "github.com/ethersphere/bee/pkg/swarm" ) -var ZeroAddress = [32]byte{} +var ( + zeroAddress = [32]byte{} +) // ChunkPayloadSize returns the effective byte length of an intermediate 
chunk // assumes data is always chunk size (without span) func ChunkPayloadSize(data []byte) (int, error) { l := len(data) for l >= swarm.HashSize { - if !bytes.Equal(data[l-swarm.HashSize:l], ZeroAddress[:]) { + if !bytes.Equal(data[l-swarm.HashSize:l], zeroAddress[:]) { return l, nil } @@ -31,21 +33,16 @@ func ChunkPayloadSize(data []byte) (int, error) { // ChunkAddresses returns data shards and parities of the intermediate chunk // assumes data is truncated by ChunkPayloadSize -func ChunkAddresses(data []byte, parities, reflen int) (sAddresses, pAddresses []swarm.Address) { - shards := (len(data) - parities*swarm.HashSize) / reflen - sAddresses = make([]swarm.Address, shards) - pAddresses = make([]swarm.Address, parities) - offset := 0 - for i := 0; i < shards; i++ { - sAddresses[i] = swarm.NewAddress(data[offset : offset+swarm.HashSize]) - offset += reflen - } - for i := 0; i < parities; i++ { - pAddresses[i] = swarm.NewAddress(data[offset : offset+swarm.HashSize]) - offset += swarm.HashSize +func ChunkAddresses(data []byte, parities, reflen int) (addrs []swarm.Address, shardCnt int) { + shardCnt = (len(data) - parities*swarm.HashSize) / reflen + for offset := 0; offset < len(data); offset += reflen { + addrs = append(addrs, swarm.NewAddress(data[offset:offset+swarm.HashSize])) + if len(addrs) == shardCnt && reflen != swarm.HashSize { + reflen = swarm.HashSize + offset += reflen + } } - - return sAddresses, pAddresses + return addrs, shardCnt } // ReferenceCount brute-forces the data shard count from which identify the parity count as well in a substree diff --git a/pkg/steward/steward_test.go b/pkg/steward/steward_test.go index bdddbdd3d08..935db798c01 100644 --- a/pkg/steward/steward_test.go +++ b/pkg/steward/steward_test.go @@ -10,6 +10,7 @@ import ( "crypto/rand" "errors" "sync" + "sync/atomic" "testing" "time" @@ -22,9 +23,20 @@ import ( "github.com/ethersphere/bee/pkg/swarm" ) +type counter struct { + storage.ChunkStore + count atomic.Int32 +} + +func (c *counter) Put(ctx context.Context, ch swarm.Chunk) (err error) { + c.count.Add(1) + return c.ChunkStore.Put(ctx, ch) +} + func TestSteward(t *testing.T) { + t.Skip("skipping test until we indentify the cause of the flakiness") t.Parallel() - inmem := inmemchunkstore.New() + inmem := &counter{ChunkStore: inmemchunkstore.New()} var ( ctx = context.Background() @@ -51,15 +63,7 @@ func TestSteward(t *testing.T) { t.Fatal(err) } - chunkCount := 0 - err = chunkStore.Iterate(context.Background(), func(ch swarm.Chunk) (bool, error) { - chunkCount++ - return false, nil - }) - if err != nil { - t.Fatalf("failed iterating: %v", err) - } - + chunkCount := int(inmem.count.Load()) done := make(chan struct{}) errc := make(chan error, 1) go func() { From 0820d737d67b85b864cba814a51d48d1cfcaa7a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?= Date: Mon, 18 Dec 2023 14:09:56 +0100 Subject: [PATCH 12/23] test: set level in context --- pkg/steward/steward_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/steward/steward_test.go b/pkg/steward/steward_test.go index 935db798c01..8e2083abc80 100644 --- a/pkg/steward/steward_test.go +++ b/pkg/steward/steward_test.go @@ -15,7 +15,9 @@ import ( "time" "github.com/ethersphere/bee/pkg/file/pipeline/builder" + "github.com/ethersphere/bee/pkg/file/redundancy" postagetesting "github.com/ethersphere/bee/pkg/postage/mock" + "github.com/ethersphere/bee/pkg/replicas" "github.com/ethersphere/bee/pkg/steward" storage "github.com/ethersphere/bee/pkg/storage" 
"github.com/ethersphere/bee/pkg/storage/inmemchunkstore" @@ -34,7 +36,6 @@ func (c *counter) Put(ctx context.Context, ch swarm.Chunk) (err error) { } func TestSteward(t *testing.T) { - t.Skip("skipping test until we indentify the cause of the flakiness") t.Parallel() inmem := &counter{ChunkStore: inmemchunkstore.New()} @@ -48,6 +49,7 @@ func TestSteward(t *testing.T) { s = steward.New(store, localRetrieval, inmem) stamper = postagetesting.NewStamper() ) + ctx = replicas.SetLevel(ctx, redundancy.NONE) n, err := rand.Read(data) if n != cap(data) { From 845b68bd5cce4414196acc12fd3412eceb095bef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?= Date: Mon, 18 Dec 2023 14:37:52 +0100 Subject: [PATCH 13/23] test: not full chunk in joiner --- pkg/file/joiner/joiner_test.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index af86a9fd061..08ce4b3b311 100644 --- a/pkg/file/joiner/joiner_test.go +++ b/pkg/file/joiner/joiner_test.go @@ -1078,8 +1078,12 @@ func TestJoinerRedundancy(t *testing.T) { pipe := builder.NewPipelineBuilder(ctx, putter, tc.encryptChunk, tc.rLevel) dataChunks := make([]swarm.Chunk, shardCnt) chunkSize := swarm.ChunkSize + size := chunkSize for i := 0; i < shardCnt; i++ { - chunkData := make([]byte, chunkSize) + if i == shardCnt-1 { + size = 5 + } + chunkData := make([]byte, size) _, err := io.ReadFull(rand.Reader, chunkData) if err != nil { t.Fatal(err) @@ -1117,7 +1121,7 @@ func TestJoinerRedundancy(t *testing.T) { t.Fatal(err) } // sanity checks - expectedRootSpan := chunkSize * shardCnt + expectedRootSpan := chunkSize*(shardCnt-1) + 5 if int64(expectedRootSpan) != rootSpan { t.Fatalf("Expected root span %d. Got: %d", expectedRootSpan, rootSpan) } @@ -1132,7 +1136,7 @@ func TestJoinerRedundancy(t *testing.T) { i := i eg.Go(func() error { chunkData := make([]byte, chunkSize) - _, err := joinReader.ReadAt(chunkData, int64(i*chunkSize)) + n, err := joinReader.ReadAt(chunkData, int64(i*chunkSize)) if err != nil { return err } @@ -1142,7 +1146,7 @@ func TestJoinerRedundancy(t *testing.T) { default: } expectedChunkData := dataChunks[i].Data()[swarm.SpanSize:] - if !bytes.Equal(expectedChunkData, chunkData) { + if !bytes.Equal(expectedChunkData, chunkData[:n]) { return fmt.Errorf("data mismatch on chunk position %d", i) } return nil From 4307af359c4b3fca099a8986d19fdf40bcd03bc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?= Date: Mon, 18 Dec 2023 14:40:14 +0100 Subject: [PATCH 14/23] test: raise strategy timeout --- pkg/file/joiner/joiner_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index 08ce4b3b311..0ef5fef9ab4 100644 --- a/pkg/file/joiner/joiner_test.go +++ b/pkg/file/joiner/joiner_test.go @@ -1113,7 +1113,7 @@ func TestJoinerRedundancy(t *testing.T) { readCheck := func(t *testing.T, expErr error) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 8*getter.StrategyTimeout) + ctx, cancel := context.WithTimeout(context.Background(), 10*getter.StrategyTimeout) defer cancel() ctx = getter.SetFetchTimeout(ctx, getter.StrategyTimeout) joinReader, rootSpan, err := joiner.New(ctx, store, store, swarmAddr) From 3da49a226a2f3eb23204e1d8fa693a05603b810f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tr=C3=B3n?= Date: Mon, 18 Dec 2023 20:36:41 +0100 Subject: [PATCH 15/23] feat(redundancy): api headers integration (#4515) --- 
openapi/Swarm.yaml | 25 ++++----- openapi/SwarmCommon.yaml | 35 +++++++++++++ pkg/api/api.go | 28 ++++++----- pkg/api/bzz.go | 13 ++++- pkg/api/dirs.go | 4 +- pkg/file/joiner/joiner_test.go | 5 +- pkg/file/redundancy/getter/getter.go | 64 ++++++++++++------------ pkg/file/redundancy/getter/strategies.go | 13 ++--- 8 files changed, 115 insertions(+), 72 deletions(-) diff --git a/openapi/Swarm.yaml b/openapi/Swarm.yaml index 7445f98e88c..6fb9561e44f 100644 --- a/openapi/Swarm.yaml +++ b/openapi/Swarm.yaml @@ -1,7 +1,7 @@ openapi: 3.0.3 info: - version: 5.1.1 + version: 5.2.0 title: Bee API description: "A list of the currently provided Interfaces to interact with the swarm, implementing file operations and sending messages" @@ -122,7 +122,7 @@ paths: required: false - in: header schema: - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyLevelParameter" + $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyLevel" name: swarm-redundancy-level required: false @@ -163,11 +163,10 @@ paths: $ref: "SwarmCommon.yaml#/components/schemas/SwarmReference" required: true description: Swarm address reference to content - - in: header - schema: - $ref: "SwarmCommon.yaml#/components/parameters/SwarmCache" - name: swarm-cache - required: false + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmCache" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyStrategyParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyFallbackModeParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmChunkRetrievalTimeoutParameter" responses: "200": description: Retrieved content specified by reference @@ -311,11 +310,10 @@ paths: $ref: "SwarmCommon.yaml#/components/schemas/SwarmReference" required: true description: Swarm address of content - - in: header - schema: - $ref: "SwarmCommon.yaml#/components/parameters/SwarmCache" - name: swarm-cache - required: false + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmCache" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyStrategyParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyFallbackModeParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmChunkRetrievalTimeoutParameter" responses: "200": description: Ok @@ -353,6 +351,9 @@ paths: type: string required: true description: Path to the file in the collection. + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyStrategyParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyModeParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmChunkRetrievalTimeoutParameter" responses: "200": description: Ok diff --git a/openapi/SwarmCommon.yaml b/openapi/SwarmCommon.yaml index 1c2e1bbe77d..4a7faadc154 100644 --- a/openapi/SwarmCommon.yaml +++ b/openapi/SwarmCommon.yaml @@ -945,6 +945,41 @@ components: Add redundancy to the data being uploaded so that downloaders can download it with better UX. 0 value is default and does not add any redundancy to the file. + SwarmRedundancyStrategyParameter: + in: header + name: swarm-redundancy-strategy + schema: + type: integer + enum: [0, 1, 2, 3] + required: false + description: > + Specify the retrieve strategy on redundant data. + The mumbers stand for NONE, DATA, PROX and RACE, respectively. + Strategy NONE means no prefetching takes place. + Strategy DATA means only data chunks are prefetched. + Strategy PROX means only chunks that are close to the node are prefetched. 
+        Strategy RACE means all chunks are prefetched: n data chunks and k parity chunks. The first n chunks to arrive are used to reconstruct the file.
+        Multiple strategies can be used in a fallback cascade if the swarm redundancy fallback mode is set to true.
+        The default strategy is NONE, falling back to DATA, then PROX, then RACE.
+
+    SwarmRedundancyFallbackModeParameter:
+      in: header
+      name: swarm-redundancy-fallback-mode
+      schema:
+        type: boolean
+      required: false
+      description: >
+        Specify if the retrieve strategies (chunk prefetching on redundant data) are used in a fallback cascade. The default is true.
+
+    SwarmChunkRetrievalTimeoutParameter:
+      in: header
+      name: swarm-chunk-retrieval-timeout
+      schema:
+        $ref: "#/components/schemas/Duration"
+      required: false
+      description: >
+        Specify the timeout for chunk retrieval. The default is 30 seconds.
+
     ContentTypePreserved:
       in: header
       name: Content-Type
diff --git a/pkg/api/api.go b/pkg/api/api.go
index 58dba18c377..26d11213e5a 100644
--- a/pkg/api/api.go
+++ b/pkg/api/api.go
@@ -70,17 +70,20 @@ import (
 const loggerName = "api"
 
 const (
-	SwarmPinHeader            = "Swarm-Pin"
-	SwarmTagHeader            = "Swarm-Tag"
-	SwarmEncryptHeader        = "Swarm-Encrypt"
-	SwarmIndexDocumentHeader  = "Swarm-Index-Document"
-	SwarmErrorDocumentHeader  = "Swarm-Error-Document"
-	SwarmFeedIndexHeader      = "Swarm-Feed-Index"
-	SwarmFeedIndexNextHeader  = "Swarm-Feed-Index-Next"
-	SwarmCollectionHeader     = "Swarm-Collection"
-	SwarmPostageBatchIdHeader = "Swarm-Postage-Batch-Id"
-	SwarmDeferredUploadHeader = "Swarm-Deferred-Upload"
-	SwarmRedundancyLevel      = "Swarm-Redundancy-Level"
+	SwarmPinHeader                    = "Swarm-Pin"
+	SwarmTagHeader                    = "Swarm-Tag"
+	SwarmEncryptHeader                = "Swarm-Encrypt"
+	SwarmIndexDocumentHeader          = "Swarm-Index-Document"
+	SwarmErrorDocumentHeader          = "Swarm-Error-Document"
+	SwarmFeedIndexHeader              = "Swarm-Feed-Index"
+	SwarmFeedIndexNextHeader          = "Swarm-Feed-Index-Next"
+	SwarmCollectionHeader             = "Swarm-Collection"
+	SwarmPostageBatchIdHeader         = "Swarm-Postage-Batch-Id"
+	SwarmDeferredUploadHeader         = "Swarm-Deferred-Upload"
+	SwarmRedundancyLevelHeader        = "Swarm-Redundancy-Level"
+	SwarmRedundancyStrategyHeader     = "Swarm-Redundancy-Strategy"
+	SwarmRedundancyFallbackModeHeader = "Swarm-Redundancy-Fallback-Mode"
+	SwarmChunkRetrievalTimeoutHeader  = "Swarm-Chunk-Retrieval-Timeout-Level"
 
 	ImmutableHeader = "Immutable"
 	GasPriceHeader  = "Gas-Price"
@@ -624,8 +627,7 @@ func (s *Service) corsHandler(h http.Handler) http.Handler {
 	allowedHeaders := []string{
 		"User-Agent", "Accept", "X-Requested-With", "Access-Control-Request-Headers", "Access-Control-Request-Method", "Accept-Ranges", "Content-Encoding",
 		AuthorizationHeader, AcceptEncodingHeader, ContentTypeHeader, ContentDispositionHeader, RangeHeader, OriginHeader,
-		SwarmTagHeader, SwarmPinHeader, SwarmEncryptHeader, SwarmIndexDocumentHeader, SwarmErrorDocumentHeader, SwarmCollectionHeader, SwarmPostageBatchIdHeader, SwarmDeferredUploadHeader, SwarmRedundancyLevel,
-		GasPriceHeader, GasLimitHeader, ImmutableHeader,
+		SwarmTagHeader, SwarmPinHeader, SwarmEncryptHeader, SwarmIndexDocumentHeader, SwarmErrorDocumentHeader, SwarmCollectionHeader, SwarmPostageBatchIdHeader, SwarmDeferredUploadHeader, SwarmRedundancyLevelHeader, SwarmRedundancyStrategyHeader, SwarmRedundancyFallbackModeHeader, SwarmChunkRetrievalTimeoutHeader, SwarmFeedIndexHeader, SwarmFeedIndexNextHeader, GasPriceHeader, GasLimitHeader, ImmutableHeader,
 	}
 
 	allowedHeadersStr := strings.Join(allowedHeaders, ", ")
diff --git a/pkg/api/bzz.go b/pkg/api/bzz.go
index 
a4b06f738cb..be514a92ba5 100644 --- a/pkg/api/bzz.go +++ b/pkg/api/bzz.go @@ -21,6 +21,7 @@ import ( "github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/loadsave" "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/file/redundancy/getter" "github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/manifest" @@ -455,7 +456,10 @@ func (s *Service) serveManifestEntry( // downloadHandler contains common logic for dowloading Swarm file from API func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *http.Request, reference swarm.Address, additionalHeaders http.Header, etag bool) { headers := struct { - Cache *bool `map:"Swarm-Cache"` + Cache *bool `map:"Swarm-Cache"` + Strategy getter.Strategy `map:"Swarm-Redundancy-Strategy"` + FallbackMode bool `map:"Swarm-Redundancy-Fallback-Mode"` + ChunkRetrievalTimeout time.Duration `map:"Swarm-Chunk-Retrieval-Timeout"` }{} if response := s.mapStructure(r.Header, &headers); response != nil { response("invalid header params", logger, w) @@ -465,7 +469,12 @@ func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *h if headers.Cache != nil { cache = *headers.Cache } - reader, l, err := joiner.New(r.Context(), s.storer.Download(cache), s.storer.Cache(), reference) + + ctx := r.Context() + ctx = getter.SetStrategy(ctx, headers.Strategy) + ctx = getter.SetStrict(ctx, headers.FallbackMode) + ctx = getter.SetFetchTimeout(ctx, headers.ChunkRetrievalTimeout) + reader, l, err := joiner.New(ctx, s.storer.Download(cache), s.storer.Cache(), reference) if err != nil { if errors.Is(err, storage.ErrNotFound) || errors.Is(err, topology.ErrNotFound) { logger.Debug("api download: not found ", "address", reference, "error", err) diff --git a/pkg/api/dirs.go b/pkg/api/dirs.go index 867e1266629..018bc4d8e5c 100644 --- a/pkg/api/dirs.go +++ b/pkg/api/dirs.go @@ -64,9 +64,9 @@ func (s *Service) dirUploadHandler( } defer r.Body.Close() - rLevelNum, err := strconv.ParseUint(r.Header.Get(SwarmRedundancyLevel), 10, 8) + rLevelNum, err := strconv.ParseUint(r.Header.Get(SwarmRedundancyLevelHeader), 10, 8) if err != nil { - logger.Debug("store directory failed failed", "redundancy level parsing error") + logger.Debug("store directory failed", "redundancy level parsing error") logger.Error(nil, "store directory failed") } rLevel, err := redundancy.NewLevel(uint8(rLevelNum)) diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index 0ef5fef9ab4..7a97080171b 100644 --- a/pkg/file/joiner/joiner_test.go +++ b/pkg/file/joiner/joiner_test.go @@ -1062,7 +1062,7 @@ func TestJoinerRedundancy(t *testing.T) { }, } { tc := tc - t.Run(fmt.Sprintf("redundancy %d encryption %t", tc.rLevel, tc.encryptChunk), func(t *testing.T) { + t.Run(fmt.Sprintf("redundancy=%d encryption=%t", tc.rLevel, tc.encryptChunk), func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -1113,7 +1113,7 @@ func TestJoinerRedundancy(t *testing.T) { readCheck := func(t *testing.T, expErr error) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 10*getter.StrategyTimeout) + ctx, cancel := context.WithTimeout(context.Background(), 15*getter.StrategyTimeout) defer cancel() ctx = getter.SetFetchTimeout(ctx, getter.StrategyTimeout) joinReader, rootSpan, err := joiner.New(ctx, store, store, swarmAddr) @@ -1188,7 +1188,6 @@ func TestJoinerRedundancy(t *testing.T) { t.Run("recover from replica if root 
deleted", func(t *testing.T) { readCheck(t, nil) }) - }) } } diff --git a/pkg/file/redundancy/getter/getter.go b/pkg/file/redundancy/getter/getter.go index 1f1b847d856..e23f2a3e6af 100644 --- a/pkg/file/redundancy/getter/getter.go +++ b/pkg/file/redundancy/getter/getter.go @@ -20,22 +20,21 @@ import ( // if retrieves children of an intermediate chunk potentially using erasure decoding // it caches sibling chunks if erasure decoding started already type decoder struct { - fetcher storage.Getter // network retrieval interface to fetch chunks - putter storage.Putter // interface to local storage to save reconstructed chunks - addrs []swarm.Address // all addresses of the intermediate chunk - inflight []atomic.Bool // locks to protect wait channels and RS buffer - cache map[string]int // map from chunk address shard position index - waits []chan struct{} // wait channels for each chunk - rsbuf [][]byte // RS buffer of data + parity shards for erasure decoding - ready chan struct{} // signal channel for successful retrieval of shardCnt chunks - shardCnt int // number of data shards - parityCnt int // number of parity shards - wg sync.WaitGroup // wait group to wait for all goroutines to finish - mu sync.Mutex // mutex to protect buffer - fetchTimeout time.Duration // timeout for each fetch - fetchedCnt atomic.Int32 // count successful retrievals - cancel func() // cancel function for RS decoding - remove func() // callback to remove decoder from decoders cache + fetcher storage.Getter // network retrieval interface to fetch chunks + putter storage.Putter // interface to local storage to save reconstructed chunks + addrs []swarm.Address // all addresses of the intermediate chunk + inflight []atomic.Bool // locks to protect wait channels and RS buffer + cache map[string]int // map from chunk address shard position index + waits []chan struct{} // wait channels for each chunk + rsbuf [][]byte // RS buffer of data + parity shards for erasure decoding + ready chan struct{} // signal channel for successful retrieval of shardCnt chunks + shardCnt int // number of data shards + parityCnt int // number of parity shards + wg sync.WaitGroup // wait group to wait for all goroutines to finish + mu sync.Mutex // mutex to protect buffer + fetchedCnt atomic.Int32 // count successful retrievals + cancel func() // cancel function for RS decoding + remove func() // callback to remove decoder from decoders cache } type Getter interface { @@ -44,24 +43,27 @@ type Getter interface { } // New returns a decoder object used tos retrieve children of an intermediate chunk -func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter, strategy Strategy, strict bool, timeout time.Duration, remove func()) Getter { +func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter, strategy Strategy, strict bool, fetchTimeout time.Duration, remove func()) Getter { ctx, cancel := context.WithCancel(context.Background()) size := len(addrs) + if fetchTimeout == 0 { + fetchTimeout = 30 * time.Second + } + strategyTimeout := StrategyTimeout rsg := &decoder{ - fetcher: g, - putter: p, - addrs: addrs, - inflight: make([]atomic.Bool, size), - cache: make(map[string]int, size), - waits: make([]chan struct{}, shardCnt), - rsbuf: make([][]byte, size), - ready: make(chan struct{}, 1), - cancel: cancel, - remove: remove, - shardCnt: shardCnt, - parityCnt: size - shardCnt, - fetchTimeout: timeout, + fetcher: g, + putter: p, + addrs: addrs, + inflight: make([]atomic.Bool, size), + cache: 
make(map[string]int, size), + waits: make([]chan struct{}, shardCnt), + rsbuf: make([][]byte, size), + ready: make(chan struct{}, 1), + cancel: cancel, + remove: remove, + shardCnt: shardCnt, + parityCnt: size - shardCnt, } // after init, cache and wait channels are immutable, need no locking @@ -73,7 +75,7 @@ func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter // prefetch chunks according to strategy rsg.wg.Add(1) go func() { - rsg.prefetch(ctx, strategy, strict) + rsg.prefetch(ctx, strategy, strict, strategyTimeout, fetchTimeout) rsg.wg.Done() }() return rsg diff --git a/pkg/file/redundancy/getter/strategies.go b/pkg/file/redundancy/getter/strategies.go index 3d2a39a3816..8bf944e8ae7 100644 --- a/pkg/file/redundancy/getter/strategies.go +++ b/pkg/file/redundancy/getter/strategies.go @@ -31,11 +31,7 @@ const ( // GetParamsFromContext extracts the strategy and strict mode from the context func GetParamsFromContext(ctx context.Context) (s Strategy, strict bool, fetcherTimeout time.Duration) { - var ok bool - s, ok = ctx.Value(strategyKey{}).(Strategy) - if !ok { - s = RACE - } + s, _ = ctx.Value(strategyKey{}).(Strategy) strict, _ = ctx.Value(modeKey{}).(bool) fetcherTimeout, _ = ctx.Value(fetcherTimeoutKey{}).(time.Duration) return s, strict, fetcherTimeout @@ -56,11 +52,10 @@ func SetStrict(ctx context.Context, strict bool) context.Context { return context.WithValue(ctx, modeKey{}, strict) } -func (g *decoder) prefetch(ctx context.Context, strategy int, strict bool) { +func (g *decoder) prefetch(ctx context.Context, strategy int, strict bool, strategyTimeout, fetchTimeout time.Duration) { if strict && strategy == NONE { return } - defer g.remove() var cancels []func() cancelAll := func() { @@ -76,11 +71,11 @@ func (g *decoder) prefetch(ctx context.Context, strategy int, strict bool) { var stop <-chan time.Time if s < RACE { - timer := time.NewTimer(StrategyTimeout) + timer := time.NewTimer(strategyTimeout) defer timer.Stop() stop = timer.C } - lctx, cancel := context.WithTimeout(ctx, g.fetchTimeout) + lctx, cancel := context.WithTimeout(ctx, fetchTimeout) cancels = append(cancels, cancel) prefetch(lctx, g, s) From a9ba136d4ed65eb5cc8fba4a752b28f45fd9f035 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?= Date: Mon, 18 Dec 2023 23:37:05 +0100 Subject: [PATCH 16/23] test: bzz api --- pkg/api/bzz_test.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/pkg/api/bzz_test.go b/pkg/api/bzz_test.go index e6d52d8acd3..d2ad44e507c 100644 --- a/pkg/api/bzz_test.go +++ b/pkg/api/bzz_test.go @@ -187,6 +187,28 @@ func TestBzzFiles(t *testing.T) { ) }) + t.Run("redundancy", func(t *testing.T) { + fileName := "my-pictures.jpeg" + + var resp api.BzzUploadResponse + jsonhttptest.Request(t, client, http.MethodPost, fileUploadResource+"?name="+fileName, http.StatusCreated, + jsonhttptest.WithRequestHeader(api.SwarmDeferredUploadHeader, "true"), + jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, batchOkStr), + jsonhttptest.WithRequestBody(bytes.NewReader(simpleData)), + jsonhttptest.WithRequestHeader(api.SwarmEncryptHeader, "True"), + jsonhttptest.WithRequestHeader(api.SwarmRedundancyLevelHeader, "4"), + jsonhttptest.WithRequestHeader(api.ContentTypeHeader, "image/jpeg; charset=utf-8"), + jsonhttptest.WithUnmarshalJSONResponse(&resp), + ) + + jsonhttptest.Request(t, client, http.MethodGet, fileDownloadResource(resp.Reference.String()), http.StatusOK, + jsonhttptest.WithExpectedContentLength(len(simpleData)), + 
jsonhttptest.WithExpectedResponseHeader(api.ContentTypeHeader, "image/jpeg; charset=utf-8"), + jsonhttptest.WithExpectedResponseHeader(api.ContentDispositionHeader, fmt.Sprintf(`inline; filename="%s"`, fileName)), + jsonhttptest.WithExpectedResponse(simpleData), + ) + }) + t.Run("filter out filename path", func(t *testing.T) { fileName := "my-pictures.jpeg" fileNameWithPath := "../../" + fileName From 139783a40ac679c3db7c2e254b9407513caddad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?= Date: Mon, 18 Dec 2023 23:38:00 +0100 Subject: [PATCH 17/23] fix: sorry --- pkg/file/pipeline/hashtrie/hashtrie.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/file/pipeline/hashtrie/hashtrie.go b/pkg/file/pipeline/hashtrie/hashtrie.go index cb3c5f468a2..b3c898c76e1 100644 --- a/pkg/file/pipeline/hashtrie/hashtrie.go +++ b/pkg/file/pipeline/hashtrie/hashtrie.go @@ -268,7 +268,7 @@ func (h *hashTrieWriter) Sum() ([]byte, error) { if err != nil { return nil, err } - err = h.replicaPutter.Put(h.ctx, swarm.NewChunk(swarm.NewAddress(rootHash), rootData)) + err = h.replicaPutter.Put(h.ctx, swarm.NewChunk(swarm.NewAddress(rootHash[:swarm.HashSize]), rootData)) if err != nil { return nil, fmt.Errorf("hashtrie: cannot put dispersed replica %s", err.Error()) } From eea6099a5294112b678890ccd94e8784ab076a29 Mon Sep 17 00:00:00 2001 From: nugaon <50576770+nugaon@users.noreply.github.com> Date: Tue, 19 Dec 2023 18:05:49 +0100 Subject: [PATCH 18/23] feat: dispersed replica validation (#4522) --- pkg/file/redundancy/getter/getter_test.go | 15 +----- pkg/replicas/replicas.go | 4 +- pkg/soc/validator.go | 7 +++ pkg/soc/validator_test.go | 58 +++++++++++++++++++++++ pkg/swarm/swarm.go | 6 +++ 5 files changed, 74 insertions(+), 16 deletions(-) diff --git a/pkg/file/redundancy/getter/getter_test.go b/pkg/file/redundancy/getter/getter_test.go index d19b5535057..bb00f7ddd02 100644 --- a/pkg/file/redundancy/getter/getter_test.go +++ b/pkg/file/redundancy/getter/getter_test.go @@ -151,7 +151,7 @@ func testDecodingRACE(t *testing.T, bufSize, shardCnt, erasureCnt int) { err := context.DeadlineExceeded select { case err = <-q: - case <-time.After(getter.StrategyTimeout * 4): + case <-time.After(getter.StrategyTimeout * 10): } switch { case erasureCnt > parityCnt: @@ -228,17 +228,6 @@ func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { waitErased <- err }() - // set timeouts for the cases - var timeout time.Duration - switch { - case strict: - timeout = 2*getter.StrategyTimeout - 10*time.Millisecond - case s == getter.NONE: - timeout = 4*getter.StrategyTimeout - 10*time.Millisecond - case s == getter.DATA: - timeout = 3*getter.StrategyTimeout - 10*time.Millisecond - } - // wait for delayed chunk retrieval to complete select { case err := <-waitDelayed: @@ -297,7 +286,7 @@ func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { t.Fatal("unexpected timeout using strategy", s, "with strict", strict) } } - case <-time.After(timeout): + case <-time.After(getter.StrategyTimeout * 3): if !strict || s != getter.NONE { t.Fatal("unexpected timeout using strategy", s, "with strict", strict) } diff --git a/pkg/replicas/replicas.go b/pkg/replicas/replicas.go index 0fbcee6857b..18b21d4b8b8 100644 --- a/pkg/replicas/replicas.go +++ b/pkg/replicas/replicas.go @@ -12,7 +12,6 @@ package replicas import ( "context" - "encoding/hex" "time" "github.com/ethersphere/bee/pkg/crypto" @@ -29,7 +28,6 @@ var ( RetryInterval = 300 * time.Millisecond privKey, _ = 
crypto.DecodeSecp256k1PrivateKey(append([]byte{1}, make([]byte, 31)...)) signer = crypto.NewDefaultSigner(privKey) - owner, _ = hex.DecodeString("dc5b20847f43d67928f49cd4f85d696b5a7617b5") ) // SetLevel sets the redundancy level in the context @@ -82,7 +80,7 @@ func (rr *replicator) replicate(i uint8) (sp *replica) { // calculate SOC address for potential replica h := swarm.NewHasher() _, _ = h.Write(id) - _, _ = h.Write(owner) + _, _ = h.Write(swarm.ReplicasOwner) return &replica{h.Sum(nil), id} } diff --git a/pkg/soc/validator.go b/pkg/soc/validator.go index a707eaded36..db0c388808e 100644 --- a/pkg/soc/validator.go +++ b/pkg/soc/validator.go @@ -5,6 +5,8 @@ package soc import ( + "bytes" + "github.com/ethersphere/bee/pkg/swarm" ) @@ -15,6 +17,11 @@ func Valid(ch swarm.Chunk) bool { return false } + // disperse replica validation + if bytes.Equal(s.owner, swarm.ReplicasOwner) && !bytes.Equal(s.WrappedChunk().Address().Bytes()[1:32], s.id[1:32]) { + return false + } + address, err := s.Address() if err != nil { return false diff --git a/pkg/soc/validator_test.go b/pkg/soc/validator_test.go index 18ed00001a4..695d43cc46f 100644 --- a/pkg/soc/validator_test.go +++ b/pkg/soc/validator_test.go @@ -5,9 +5,13 @@ package soc_test import ( + "crypto/rand" + "io" "strings" "testing" + "github.com/ethersphere/bee/pkg/cac" + "github.com/ethersphere/bee/pkg/crypto" "github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/swarm" ) @@ -31,6 +35,60 @@ func TestValid(t *testing.T) { } } +// TestValidDispersedReplica verifies that the validator can detect +// valid dispersed replicas chunks. +func TestValidDispersedReplica(t *testing.T) { + t.Parallel() + + t.Run("valid", func(t *testing.T) { + privKey, _ := crypto.DecodeSecp256k1PrivateKey(append([]byte{1}, make([]byte, 31)...)) + signer := crypto.NewDefaultSigner(privKey) + + chData := make([]byte, swarm.ChunkSize) + _, _ = io.ReadFull(rand.Reader, chData) + ch, err := cac.New(chData) + if err != nil { + t.Fatal(err) + } + id := append([]byte{1}, ch.Address().Bytes()[1:]...) + + socCh, err := soc.New(id, ch).Sign(signer) + if err != nil { + t.Fatal(err) + } + + // check valid chunk + if !soc.Valid(socCh) { + t.Fatal("dispersed replica chunk is invalid") + } + }) + + t.Run("invalid", func(t *testing.T) { + privKey, _ := crypto.DecodeSecp256k1PrivateKey(append([]byte{1}, make([]byte, 31)...)) + signer := crypto.NewDefaultSigner(privKey) + + chData := make([]byte, swarm.ChunkSize) + _, _ = io.ReadFull(rand.Reader, chData) + ch, err := cac.New(chData) + if err != nil { + t.Fatal(err) + } + id := append([]byte{1}, ch.Address().Bytes()[1:]...) + // change to invalid ID + id[2] += 1 + + socCh, err := soc.New(id, ch).Sign(signer) + if err != nil { + t.Fatal(err) + } + + // check valid chunk + if soc.Valid(socCh) { + t.Fatal("dispersed replica should be invalid") + } + }) +} + // TestInvalid verifies that the validator can detect chunks // with invalid data and invalid address. func TestInvalid(t *testing.T) { diff --git a/pkg/swarm/swarm.go b/pkg/swarm/swarm.go index 7e8d514b095..b69b4c034bf 100644 --- a/pkg/swarm/swarm.go +++ b/pkg/swarm/swarm.go @@ -37,6 +37,12 @@ var ( ErrInvalidChunk = errors.New("invalid chunk") ) +var ( + // Ethereum Address for SOC owner of Dispersed Replicas + // generated from private key 0x0100000000000000000000000000000000000000000000000000000000000000 + ReplicasOwner, _ = hex.DecodeString("dc5b20847f43d67928f49cd4f85d696b5a7617b5") +) + var ( // EmptyAddress is the address that is all zeroes. 
EmptyAddress = NewAddress(make([]byte, HashSize)) From 7ed51dbcf2f82ec629c65b8d520f85d673cf0a5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?= Date: Wed, 20 Dec 2023 17:05:10 +0100 Subject: [PATCH 19/23] refactor: reviews --- pkg/api/bzz.go | 2 +- pkg/api/dirs.go | 11 +---------- pkg/file/redundancy/level.go | 19 ++++++------------- 3 files changed, 8 insertions(+), 24 deletions(-) diff --git a/pkg/api/bzz.go b/pkg/api/bzz.go index be514a92ba5..3f8cfbac1f6 100644 --- a/pkg/api/bzz.go +++ b/pkg/api/bzz.go @@ -105,7 +105,7 @@ func (s *Service) bzzUploadHandler(w http.ResponseWriter, r *http.Request) { } if headers.IsDir || headers.ContentType == multiPartFormData { - s.dirUploadHandler(logger, ow, r, putter, r.Header.Get(ContentTypeHeader), headers.Encrypt, tag) + s.dirUploadHandler(logger, ow, r, putter, r.Header.Get(ContentTypeHeader), headers.Encrypt, tag, headers.RLevel) return } s.fileUploadHandler(logger, ow, r, putter, headers.Encrypt, tag, headers.RLevel) diff --git a/pkg/api/dirs.go b/pkg/api/dirs.go index 018bc4d8e5c..c368c8e9072 100644 --- a/pkg/api/dirs.go +++ b/pkg/api/dirs.go @@ -41,6 +41,7 @@ func (s *Service) dirUploadHandler( contentTypeString string, encrypt bool, tag uint64, + rLevel redundancy.Level, ) { if r.Body == http.NoBody { logger.Error(nil, "request has no body") @@ -64,16 +65,6 @@ func (s *Service) dirUploadHandler( } defer r.Body.Close() - rLevelNum, err := strconv.ParseUint(r.Header.Get(SwarmRedundancyLevelHeader), 10, 8) - if err != nil { - logger.Debug("store directory failed", "redundancy level parsing error") - logger.Error(nil, "store directory failed") - } - rLevel, err := redundancy.NewLevel(uint8(rLevelNum)) - if err != nil { - jsonhttp.BadRequest(w, err.Error()) - } - reference, err := storeDir( r.Context(), encrypt, diff --git a/pkg/file/redundancy/level.go b/pkg/file/redundancy/level.go index 597c506a67c..3dfba1cd084 100644 --- a/pkg/file/redundancy/level.go +++ b/pkg/file/redundancy/level.go @@ -29,15 +29,6 @@ const ( PARANOID ) -// NewLevel returns a Level coresponding to the passed number parameter -// throws an error if there is no level for the passed number -func NewLevel(n uint8) (Level, error) { - if n > uint8(PARANOID) { - return 0, fmt.Errorf("redundancy: number %d does not have corresponding level", n) - } - return Level(n), nil -} - // GetParities returns number of parities based on appendix F table 5 func (l Level) GetParities(shards int) int { et, err := l.getErasureTable() @@ -64,6 +55,8 @@ func (l Level) GetEncParities(shards int) int { func (l Level) getErasureTable() (erasureTable, error) { switch l { + case NONE: + return erasureTable{}, errors.New("redundancy: level NONE does not have erasure table") case MEDIUM: return mediumEt, nil case STRONG: @@ -73,12 +66,14 @@ func (l Level) getErasureTable() (erasureTable, error) { case PARANOID: return paranoidEt, nil default: - return erasureTable{}, errors.New("redundancy: level NONE does not have erasure table") + return erasureTable{}, fmt.Errorf("redundancy: level value %d is not a legit redundancy level", l) } } func (l Level) getEncErasureTable() (erasureTable, error) { switch l { + case NONE: + return erasureTable{}, errors.New("redundancy: level NONE does not have erasure table") case MEDIUM: return encMediumEt, nil case STRONG: @@ -88,7 +83,7 @@ func (l Level) getEncErasureTable() (erasureTable, error) { case PARANOID: return encParanoidEt, nil default: - return erasureTable{}, errors.New("redundancy: level NONE does not have erasure table") + return 
erasureTable{}, fmt.Errorf("redundancy: level value %d is not a legit redundancy level", l) } } @@ -162,8 +157,6 @@ var encParanoidEt = newErasureTable( }, ) -// DISPERSED REPLICAS INIT - // GetReplicaCounts returns back the ascending dispersed replica counts for all redundancy levels func GetReplicaCounts() [5]int { c := replicaCounts From 2306fb05d8cb2de4007080aeea23160ad16d1b88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?= Date: Wed, 20 Dec 2023 17:51:50 +0100 Subject: [PATCH 20/23] fix: split cmd pipelineBuilder call --- cmd/bee/cmd/split.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/bee/cmd/split.go b/cmd/bee/cmd/split.go index 6eba46c69c2..e173ae186fa 100644 --- a/cmd/bee/cmd/split.go +++ b/cmd/bee/cmd/split.go @@ -33,7 +33,7 @@ type pipelineFunc func(context.Context, io.Reader) (swarm.Address, error) func requestPipelineFn(s storage.Putter, encrypt bool) pipelineFunc { return func(ctx context.Context, r io.Reader) (swarm.Address, error) { - pipe := builder.NewPipelineBuilder(ctx, s, encrypt) + pipe := builder.NewPipelineBuilder(ctx, s, encrypt, 0) return builder.FeedPipeline(ctx, pipe, r) } } From 107f94770d34f75a4ff33e95deb703abc0695d18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?= Date: Wed, 20 Dec 2023 18:08:26 +0100 Subject: [PATCH 21/23] docs: fix openapi --- openapi/Swarm.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openapi/Swarm.yaml b/openapi/Swarm.yaml index 6fb9561e44f..3ac543a5970 100644 --- a/openapi/Swarm.yaml +++ b/openapi/Swarm.yaml @@ -122,7 +122,7 @@ paths: required: false - in: header schema: - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyLevel" + $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyLevelParameter" name: swarm-redundancy-level required: false @@ -258,7 +258,7 @@ paths: - $ref: "SwarmCommon.yaml#/components/parameters/SwarmErrorDocumentParameter" - $ref: "SwarmCommon.yaml#/components/parameters/SwarmPostageBatchId" - $ref: "SwarmCommon.yaml#/components/parameters/SwarmDeferredUpload" - - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyLevelParameter" requestBody: content: multipart/form-data: From 5cc9eefdddebfe2759680da47fbc783363d6234d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Levente=20T=C3=B3th?= Date: Wed, 20 Dec 2023 18:16:15 +0100 Subject: [PATCH 22/23] docs: openapi fallbackmodeparam --- openapi/Swarm.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openapi/Swarm.yaml b/openapi/Swarm.yaml index 3ac543a5970..61b446d1095 100644 --- a/openapi/Swarm.yaml +++ b/openapi/Swarm.yaml @@ -352,7 +352,7 @@ paths: required: true description: Path to the file in the collection. 
- $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyStrategyParameter" - - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyModeParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyFallbackModeParameter" - $ref: "SwarmCommon.yaml#/components/parameters/SwarmChunkRetrievalTimeoutParameter" responses: "200": From e3d4d5afa3f1b3f4217d8358f73e06a4e2242ff7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tr=C3=B3n?= Date: Wed, 7 Feb 2024 02:43:21 +0100 Subject: [PATCH 23/23] feat: redundancy retrieve api (#4529) Co-authored-by: nugaon <50576770+nugaon@users.noreply.github.com> Co-authored-by: Anatol <87016465+notanatol@users.noreply.github.com> Co-authored-by: dysordys Co-authored-by: Gyorgy Barabas --- openapi/SwarmCommon.yaml | 2 +- pkg/api/api.go | 22 +- pkg/api/bzz.go | 51 +++- pkg/api/bzz_test.go | 275 ++++++++++++++++++-- pkg/api/debugstorage_test.go | 2 +- pkg/file/joiner/joiner.go | 44 ++-- pkg/file/joiner/joiner_test.go | 193 +++++++++++++- pkg/file/pipeline/hashtrie/hashtrie_test.go | 9 +- pkg/file/redundancy/getter/getter.go | 58 +++-- pkg/file/redundancy/getter/getter_test.go | 86 +++--- pkg/file/redundancy/getter/strategies.go | 121 +++++++-- pkg/file/redundancy/level.go | 57 ++-- pkg/replicas/getter.go | 21 +- pkg/replicas/getter_test.go | 14 +- pkg/replicas/putter.go | 3 +- pkg/replicas/putter_test.go | 4 +- pkg/replicas/replicas.go | 21 +- pkg/steward/steward_test.go | 3 +- pkg/storageincentives/proof_test.go | 2 - pkg/storer/mock/forgetting.go | 128 +++++++++ pkg/util/testutil/pseudorand/reader.go | 182 +++++++++++++ pkg/util/testutil/pseudorand/reader_test.go | 121 +++++++++ pkg/util/testutil/racedetection/off.go | 10 + pkg/util/testutil/racedetection/on.go | 10 + pkg/util/testutil/racedetection/race.go | 9 + 25 files changed, 1193 insertions(+), 255 deletions(-) create mode 100644 pkg/storer/mock/forgetting.go create mode 100644 pkg/util/testutil/pseudorand/reader.go create mode 100644 pkg/util/testutil/pseudorand/reader_test.go create mode 100644 pkg/util/testutil/racedetection/off.go create mode 100644 pkg/util/testutil/racedetection/on.go create mode 100644 pkg/util/testutil/racedetection/race.go diff --git a/openapi/SwarmCommon.yaml b/openapi/SwarmCommon.yaml index 4a7faadc154..07119cc6894 100644 --- a/openapi/SwarmCommon.yaml +++ b/openapi/SwarmCommon.yaml @@ -954,7 +954,7 @@ components: required: false description: > Specify the retrieve strategy on redundant data. - The mumbers stand for NONE, DATA, PROX and RACE, respectively. + The numbers stand for NONE, DATA, PROX and RACE, respectively. Strategy NONE means no prefetching takes place. Strategy DATA means only data chunks are prefetched. Strategy PROX means only chunks that are close to the node are prefetched. diff --git a/pkg/api/api.go b/pkg/api/api.go index 26d11213e5a..470174c68d4 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -83,7 +83,8 @@ const ( SwarmRedundancyLevelHeader = "Swarm-Redundancy-Level" SwarmRedundancyStrategyHeader = "Swarm-Redundancy-Strategy" SwarmRedundancyFallbackModeHeader = "Swarm-Redundancy-Fallback-Mode" - SwarmChunkRetrievalTimeoutHeader = "Swarm-Chunk-Retrieval-Timeout-Level" + SwarmChunkRetrievalTimeoutHeader = "Swarm-Chunk-Retrieval-Timeout" + SwarmLookAheadBufferSizeHeader = "Swarm-Lookahead-Buffer-Size" ImmutableHeader = "Immutable" GasPriceHeader = "Gas-Price" @@ -99,18 +100,6 @@ const ( OriginHeader = "Origin" ) -// The size of buffer used for prefetching content with Langos. 
-// Warning: This value influences the number of chunk requests and chunker join goroutines -// per file request. -// Recommended value is 8 or 16 times the io.Copy default buffer value which is 32kB, depending -// on the file size. Use lookaheadBufferSize() to get the correct buffer size for the request. -const ( - smallFileBufferSize = 8 * 32 * 1024 - largeFileBufferSize = 16 * 32 * 1024 - - largeBufferFilesizeThreshold = 10 * 1000000 // ten megs -) - const ( multiPartFormData = "multipart/form-data" contentTypeTar = "application/x-tar" @@ -615,13 +604,6 @@ func (s *Service) gasConfigMiddleware(handlerName string) func(h http.Handler) h } } -func lookaheadBufferSize(size int64) int { - if size <= largeBufferFilesizeThreshold { - return smallFileBufferSize - } - return largeFileBufferSize -} - // corsHandler sets CORS headers to HTTP response if allowed origins are configured. func (s *Service) corsHandler(h http.Handler) http.Handler { allowedHeaders := []string{ diff --git a/pkg/api/bzz.go b/pkg/api/bzz.go index 3f8cfbac1f6..67a80a58918 100644 --- a/pkg/api/bzz.go +++ b/pkg/api/bzz.go @@ -35,6 +35,25 @@ import ( "github.com/gorilla/mux" ) +// The size of buffer used for prefetching content with Langos when not using erasure coding +// Warning: This value influences the number of chunk requests and chunker join goroutines +// per file request. +// Recommended value is 8 or 16 times the io.Copy default buffer value which is 32kB, depending +// on the file size. Use lookaheadBufferSize() to get the correct buffer size for the request. +const ( + smallFileBufferSize = 8 * 32 * 1024 + largeFileBufferSize = 16 * 32 * 1024 + + largeBufferFilesizeThreshold = 10 * 1000000 // ten megs +) + +func lookaheadBufferSize(size int64) int { + if size <= largeBufferFilesizeThreshold { + return smallFileBufferSize + } + return largeFileBufferSize +} + func (s *Service) bzzUploadHandler(w http.ResponseWriter, r *http.Request) { logger := tracing.NewLoggerWithTraceID(r.Context(), s.logger.WithName("post_bzz").Build()) @@ -272,8 +291,13 @@ func (s *Service) serveReference(logger log.Logger, address swarm.Address, pathV loggerV1 := logger.V(1).Build() headers := struct { - Cache *bool `map:"Swarm-Cache"` + Cache *bool `map:"Swarm-Cache"` + Strategy getter.Strategy `map:"Swarm-Redundancy-Strategy"` + FallbackMode bool `map:"Swarm-Redundancy-Fallback-Mode"` + ChunkRetrievalTimeout string `map:"Swarm-Chunk-Retrieval-Timeout"` + LookaheadBufferSize *string `map:"Swarm-Lookahead-Buffer-Size"` }{} + if response := s.mapStructure(r.Header, &headers); response != nil { response("invalid header params", logger, w) return @@ -282,10 +306,12 @@ func (s *Service) serveReference(logger log.Logger, address swarm.Address, pathV if headers.Cache != nil { cache = *headers.Cache } + ls := loadsave.NewReadonly(s.storer.Download(cache)) feedDereferenced := false ctx := r.Context() + ctx = getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, getter.DefaultStrategyTimeout.String()) FETCH: // read manifest entry @@ -366,7 +392,6 @@ FETCH: jsonhttp.NotFound(w, "address not found or incorrect") return } - me, err := m.Lookup(ctx, pathVar) if err != nil { loggerV1.Debug("bzz download: invalid path", "address", address, "path", pathVar, "error", err) @@ -459,8 +484,10 @@ func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *h Cache *bool `map:"Swarm-Cache"` Strategy getter.Strategy `map:"Swarm-Redundancy-Strategy"` FallbackMode bool 
`map:"Swarm-Redundancy-Fallback-Mode"` - ChunkRetrievalTimeout time.Duration `map:"Swarm-Chunk-Retrieval-Timeout"` + ChunkRetrievalTimeout string `map:"Swarm-Chunk-Retrieval-Timeout"` + LookaheadBufferSize *string `map:"Swarm-Lookahead-Buffer-Size"` }{} + if response := s.mapStructure(r.Header, &headers); response != nil { response("invalid header params", logger, w) return @@ -471,9 +498,7 @@ func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *h } ctx := r.Context() - ctx = getter.SetStrategy(ctx, headers.Strategy) - ctx = getter.SetStrict(ctx, headers.FallbackMode) - ctx = getter.SetFetchTimeout(ctx, headers.ChunkRetrievalTimeout) + ctx = getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, getter.DefaultStrategyTimeout.String()) reader, l, err := joiner.New(ctx, s.storer.Download(cache), s.storer.Cache(), reference) if err != nil { if errors.Is(err, storage.ErrNotFound) || errors.Is(err, topology.ErrNotFound) { @@ -497,7 +522,19 @@ func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *h } w.Header().Set(ContentLengthHeader, strconv.FormatInt(l, 10)) w.Header().Set("Access-Control-Expose-Headers", ContentDispositionHeader) - http.ServeContent(w, r, "", time.Now(), langos.NewBufferedLangos(reader, lookaheadBufferSize(l))) + bufSize := int64(lookaheadBufferSize(l)) + if headers.LookaheadBufferSize != nil { + bufSize, err = strconv.ParseInt(*headers.LookaheadBufferSize, 10, 64) + if err != nil { + logger.Debug("parsing lookahead buffer size", "error", err) + bufSize = 0 + } + } + if bufSize > 0 { + http.ServeContent(w, r, "", time.Now(), langos.NewBufferedLangos(reader, int(bufSize))) + return + } + http.ServeContent(w, r, "", time.Now(), reader) } // manifestMetadataLoad returns the value for a key stored in the metadata of diff --git a/pkg/api/bzz_test.go b/pkg/api/bzz_test.go index d2ad44e507c..634d4eb0166 100644 --- a/pkg/api/bzz_test.go +++ b/pkg/api/bzz_test.go @@ -16,20 +16,223 @@ import ( "strconv" "strings" "testing" + "time" "github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/file/loadsave" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/manifest" mockbatchstore "github.com/ethersphere/bee/pkg/postage/batchstore/mock" mockpost "github.com/ethersphere/bee/pkg/postage/mock" + "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" mockstorer "github.com/ethersphere/bee/pkg/storer/mock" "github.com/ethersphere/bee/pkg/swarm" + "github.com/ethersphere/bee/pkg/util/testutil/pseudorand" ) -// nolint:paralleltest,tparallel +// nolint:paralleltest,tparallel,thelper + +// TestBzzUploadDownloadWithRedundancy tests the API for upload and download files +// with all combinations of redundancy level, encryption and size (levels, i.e., the +// +// height of the swarm hash tree). +// +// This is a variation on the same play as TestJoinerRedundancy +// but here the tested scenario is simplified since we are not testing the intricacies of +// download strategies, but only correct parameter passing and correct recovery functionality +// +// The test cases have the following structure: +// +// 1. upload a file with a given redundancy level and encryption +// +// 2. 
[positive test] download the file by the reference returned by the upload API response
+// This uses range queries to target specific (number of) chunks of the file structure.
+// During path traversal in the swarm hash tree, the underlying mockstore (forgetting)
+// is in 'recording' mode, flagging all the retrieved chunks as chunks to forget.
+// This is to simulate the scenario where some of the chunks are not available/lost.
+// NOTE: For this to work one needs to switch off the lookahead buffer functionality
+// (see langos pkg)
+//
+// 3. [negative test] attempt at downloading the file using once again the same root hash
+// and the same redundancy strategy to find the file inaccessible after forgetting.
+//
+// 4. [positive test] attempt at downloading the file using a strategy that allows for
+// using redundancy to reconstruct the file and find the file recoverable.
+//
+// nolint:thelper
+func TestBzzUploadDownloadWithRedundancy(t *testing.T) {
+	fileUploadResource := "/bzz"
+	fileDownloadResource := func(addr string) string { return "/bzz/" + addr + "/" }
+
+	testRedundancy := func(t *testing.T, rLevel redundancy.Level, encrypt bool, levels int, chunkCnt int, shardCnt int, parityCnt int) {
+		t.Helper()
+		seed, err := pseudorand.NewSeed()
+		if err != nil {
+			t.Fatal(err)
+		}
+		fetchTimeout := 100 * time.Millisecond
+		store := mockstorer.NewForgettingStore(inmemchunkstore.New())
+		storerMock := mockstorer.NewWithChunkStore(store)
+		client, _, _, _ := newTestServer(t, testServerOptions{
+			Storer: storerMock,
+			Logger: log.Noop,
+			Post:   mockpost.New(mockpost.WithAcceptAll()),
+		})
+
+		dataReader := pseudorand.NewReader(seed, chunkCnt*swarm.ChunkSize)
+
+		var refResponse api.BzzUploadResponse
+		jsonhttptest.Request(t, client, http.MethodPost, fileUploadResource,
+			http.StatusCreated,
+			jsonhttptest.WithRequestHeader(api.SwarmDeferredUploadHeader, "True"),
+			jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, batchOkStr),
+			jsonhttptest.WithRequestBody(dataReader),
+			jsonhttptest.WithRequestHeader(api.SwarmEncryptHeader, fmt.Sprintf("%t", encrypt)),
+			jsonhttptest.WithRequestHeader(api.SwarmRedundancyLevelHeader, fmt.Sprintf("%d", rLevel)),
+			jsonhttptest.WithRequestHeader(api.ContentTypeHeader, "image/jpeg; charset=utf-8"),
+			jsonhttptest.WithUnmarshalJSONResponse(&refResponse),
+		)
+
+		t.Run("download multiple ranges without redundancy should succeed", func(t *testing.T) {
+			// the underlying chunk store is in recording mode, so all chunks retrieved
+			// in this test will be forgotten in the subsequent ones.
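+			// Record() switches the ForgettingStore into recording mode and
+			// Unrecord() switches it back: every chunk address served in between
+			// is flagged as a miss, so later Get calls on those addresses return
+			// storage.ErrNotFound, simulating chunk loss.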
+ store.Record() + defer store.Unrecord() + // we intend to forget as many chunks as possible for the given redundancy level + forget := parityCnt + if parityCnt > shardCnt { + forget = shardCnt + } + if levels == 1 { + forget = 2 + } + start, end := 420, 450 + gap := swarm.ChunkSize + for j := 2; j < levels; j++ { + gap *= shardCnt + } + ranges := make([][2]int, forget) + for i := 0; i < forget; i++ { + pre := i * gap + ranges[i] = [2]int{pre + start, pre + end} + } + rangeHeader, want := createRangeHeader(dataReader, ranges) + + var body []byte + respHeaders := jsonhttptest.Request(t, client, http.MethodGet, + fileDownloadResource(refResponse.Reference.String()), + http.StatusPartialContent, + jsonhttptest.WithRequestHeader(api.RangeHeader, rangeHeader), + jsonhttptest.WithRequestHeader(api.SwarmLookAheadBufferSizeHeader, "0"), + // set for the replicas so that no replica gets deleted + jsonhttptest.WithRequestHeader(api.SwarmRedundancyLevelHeader, "0"), + jsonhttptest.WithRequestHeader(api.SwarmRedundancyStrategyHeader, "0"), + jsonhttptest.WithRequestHeader(api.SwarmRedundancyFallbackModeHeader, "false"), + jsonhttptest.WithRequestHeader(api.SwarmChunkRetrievalTimeoutHeader, fetchTimeout.String()), + jsonhttptest.WithPutResponseBody(&body), + ) + + got := parseRangeParts(t, respHeaders.Get(api.ContentTypeHeader), body) + + if len(got) != len(want) { + t.Fatalf("got %v parts, want %v parts", len(got), len(want)) + } + for i := 0; i < len(want); i++ { + if !bytes.Equal(got[i], want[i]) { + t.Errorf("part %v: got %q, want %q", i, string(got[i]), string(want[i])) + } + } + }) + + t.Run("download without redundancy should NOT succeed", func(t *testing.T) { + if rLevel == 0 { + t.Skip("NA") + } + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, "GET", fileDownloadResource(refResponse.Reference.String()), nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set(api.SwarmRedundancyStrategyHeader, "0") + req.Header.Set(api.SwarmRedundancyFallbackModeHeader, "false") + req.Header.Set(api.SwarmChunkRetrievalTimeoutHeader, fetchTimeout.String()) + + _, err = client.Do(req) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected error %v; got %v", io.ErrUnexpectedEOF, err) + } + }) + + t.Run("download with redundancy should succeed", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.TODO(), "GET", fileDownloadResource(refResponse.Reference.String()), nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set(api.SwarmRedundancyStrategyHeader, "3") + req.Header.Set(api.SwarmRedundancyFallbackModeHeader, "true") + req.Header.Set(api.SwarmChunkRetrievalTimeoutHeader, fetchTimeout.String()) + + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d; got %d", http.StatusOK, resp.StatusCode) + } + _, err = dataReader.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + ok, err := dataReader.Equal(resp.Body) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatalf("content mismatch") + } + }) + } + for _, rLevel := range []redundancy.Level{1, 2, 3, 4} { + rLevel := rLevel + t.Run(fmt.Sprintf("level=%d", rLevel), func(t *testing.T) { + for _, encrypt := range []bool{false, true} { + encrypt := encrypt + shardCnt := rLevel.GetMaxShards() + parityCnt := rLevel.GetParities(shardCnt) + if encrypt { + shardCnt = rLevel.GetMaxEncShards() + parityCnt = 
rLevel.GetEncParities(shardCnt) + } + for _, levels := range []int{1, 2, 3} { + chunkCnt := 1 + switch levels { + case 1: + chunkCnt = 2 + case 2: + chunkCnt = shardCnt + 1 + case 3: + chunkCnt = shardCnt*shardCnt + 1 + } + levels := levels + t.Run(fmt.Sprintf("encrypt=%v levels=%d chunks=%d", encrypt, levels, chunkCnt), func(t *testing.T) { + if levels > 2 && (encrypt == (rLevel%2 == 1)) { + t.Skip("skipping to save time") + } + t.Parallel() + testRedundancy(t, rLevel, encrypt, levels, chunkCnt, shardCnt, parityCnt) + }) + } + } + }) + } +} + func TestBzzFiles(t *testing.T) { t.Parallel() @@ -459,35 +662,57 @@ func TestBzzFilesRangeRequests(t *testing.T) { } } -func createRangeHeader(data []byte, ranges [][2]int) (header string, parts [][]byte) { - header = "bytes=" - for i, r := range ranges { - if i > 0 { - header += ", " +func createRangeHeader(data interface{}, ranges [][2]int) (header string, parts [][]byte) { + getLen := func() int { + switch data := data.(type) { + case []byte: + return len(data) + case interface{ Size() int }: + return data.Size() + default: + panic("unknown data type") } - if r[0] >= 0 && r[1] >= 0 { - parts = append(parts, data[r[0]:r[1]]) - // Range: =-, end is inclusive - header += fmt.Sprintf("%v-%v", r[0], r[1]-1) - } else { - if r[0] >= 0 { - header += strconv.Itoa(r[0]) // Range: =- - parts = append(parts, data[r[0]:]) + } + getRange := func(start, end int) []byte { + switch data := data.(type) { + case []byte: + return data[start:end] + case io.ReadSeeker: + buf := make([]byte, end-start) + _, err := data.Seek(int64(start), io.SeekStart) + if err != nil { + panic(err) } - header += "-" - if r[1] >= 0 { - if r[0] >= 0 { - // Range: =-, end is inclusive - header += strconv.Itoa(r[1] - 1) - } else { - // Range: =-, the parameter is length - header += strconv.Itoa(r[1]) - } - parts = append(parts, data[:r[1]]) + _, err = io.ReadFull(data, buf) + if err != nil { + panic(err) } + return buf + default: + panic("unknown data type") + } + } + + rangeStrs := make([]string, len(ranges)) + for i, r := range ranges { + start, end := r[0], r[1] + switch { + case start < 0: + // Range: =-, the parameter is length + rangeStrs[i] = "-" + strconv.Itoa(end) + start = 0 + case r[1] < 0: + // Range: =- + rangeStrs[i] = strconv.Itoa(start) + "-" + end = getLen() + default: + // Range: =-, end is inclusive + rangeStrs[i] = fmt.Sprintf("%v-%v", start, end-1) } + parts = append(parts, getRange(start, end)) } - return + header = "bytes=" + strings.Join(rangeStrs, ", ") // nolint:staticcheck + return header, parts } func parseRangeParts(t *testing.T, contentType string, body []byte) (parts [][]byte) { diff --git a/pkg/api/debugstorage_test.go b/pkg/api/debugstorage_test.go index b83ddb5af99..72d8800b882 100644 --- a/pkg/api/debugstorage_test.go +++ b/pkg/api/debugstorage_test.go @@ -9,7 +9,7 @@ import ( "testing" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" - storer "github.com/ethersphere/bee/pkg/storer" + "github.com/ethersphere/bee/pkg/storer" mockstorer "github.com/ethersphere/bee/pkg/storer/mock" ) diff --git a/pkg/file/joiner/joiner.go b/pkg/file/joiner/joiner.go index c1ea6f5b4ea..dc9c850084c 100644 --- a/pkg/file/joiner/joiner.go +++ b/pkg/file/joiner/joiner.go @@ -11,7 +11,6 @@ import ( "io" "sync" "sync/atomic" - "time" "github.com/ethersphere/bee/pkg/bmt" "github.com/ethersphere/bee/pkg/encryption" @@ -41,24 +40,20 @@ type joiner struct { // decoderCache is cache of decoders for intermediate chunks type decoderCache struct { - fetcher storage.Getter // network 
retrieval interface to fetch chunks - putter storage.Putter // interface to local storage to save reconstructed chunks - mu sync.Mutex // mutex to protect cache - cache map[string]storage.Getter // map from chunk address to RS getter - strategy getter.Strategy // strategy to use for retrieval - strict bool // strict mode - fetcherTimeout time.Duration // timeout for each fetch + fetcher storage.Getter // network retrieval interface to fetch chunks + putter storage.Putter // interface to local storage to save reconstructed chunks + mu sync.Mutex // mutex to protect cache + cache map[string]storage.Getter // map from chunk address to RS getter + config getter.Config // getter configuration } // NewDecoderCache creates a new decoder cache -func NewDecoderCache(g storage.Getter, p storage.Putter, strategy getter.Strategy, strict bool, fetcherTimeout time.Duration) *decoderCache { +func NewDecoderCache(g storage.Getter, p storage.Putter, conf getter.Config) *decoderCache { return &decoderCache{ - fetcher: g, - putter: p, - cache: make(map[string]storage.Getter), - strategy: strategy, - strict: strict, - fetcherTimeout: fetcherTimeout, + fetcher: g, + putter: p, + cache: make(map[string]storage.Getter), + config: conf, } } @@ -90,7 +85,7 @@ func (g *decoderCache) GetOrCreate(addrs []swarm.Address, shardCnt int) storage. defer g.mu.Unlock() g.cache[key] = nil } - d = getter.New(addrs, shardCnt, g.fetcher, g.putter, g.strategy, g.strict, g.fetcherTimeout, remove) + d = getter.New(addrs, shardCnt, g.fetcher, g.putter, remove, g.config) g.cache[key] = d return d } @@ -98,7 +93,7 @@ func (g *decoderCache) GetOrCreate(addrs []swarm.Address, shardCnt int) storage. // New creates a new Joiner. A Joiner provides Read, Seek and Size functionalities. func New(ctx context.Context, g storage.Getter, putter storage.Putter, address swarm.Address) (file.Joiner, int64, error) { // retrieve the root chunk to read the total data length the be retrieved - rLevel := replicas.GetLevelFromContext(ctx) + rLevel := redundancy.GetLevelFromContext(ctx) rootChunkGetter := store.New(g) if rLevel != redundancy.NONE { rootChunkGetter = store.New(replicas.NewGetter(g, rLevel)) @@ -118,27 +113,32 @@ func New(ctx context.Context, g storage.Getter, putter storage.Putter, address s spanFn := func(data []byte) (redundancy.Level, int64) { return 0, int64(bmt.LengthFromSpan(data[:swarm.SpanSize])) } - var strategy getter.Strategy - var strict bool - var fetcherTimeout time.Duration + conf, err := getter.NewConfigFromContext(ctx, getter.DefaultConfig) + if err != nil { + return nil, 0, err + } // override stuff if root chunk has redundancy if rLevel != redundancy.NONE { _, parities := file.ReferenceCount(uint64(span), rLevel, encryption) rootParity = parities - strategy, strict, fetcherTimeout = getter.GetParamsFromContext(ctx) + spanFn = chunkToSpan if encryption { maxBranching = rLevel.GetMaxEncShards() } else { maxBranching = rLevel.GetMaxShards() } + } else { + // if root chunk has no redundancy, strategy is ignored and set to NONE and strict is set to true + conf.Strategy = getter.DATA + conf.Strict = true } j := &joiner{ addr: rootChunk.Address(), refLength: refLength, ctx: ctx, - decoders: NewDecoderCache(g, putter, strategy, strict, fetcherTimeout), + decoders: NewDecoderCache(g, putter, conf), span: span, rootData: rootData, rootParity: rootParity, diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index 7a97080171b..15d46bf220b 100644 --- a/pkg/file/joiner/joiner_test.go +++ 
b/pkg/file/joiner/joiner_test.go
@@ -27,12 +27,17 @@ import (
 	storage "github.com/ethersphere/bee/pkg/storage"
 	"github.com/ethersphere/bee/pkg/storage/inmemchunkstore"
 	testingc "github.com/ethersphere/bee/pkg/storage/testing"
+	mockstorer "github.com/ethersphere/bee/pkg/storer/mock"
 	"github.com/ethersphere/bee/pkg/swarm"
 	"github.com/ethersphere/bee/pkg/util/testutil"
+	"github.com/ethersphere/bee/pkg/util/testutil/pseudorand"
+	"github.com/ethersphere/bee/pkg/util/testutil/racedetection"
 	"gitlab.com/nolash/go-mockbytes"
 	"golang.org/x/sync/errgroup"
 )
 
+// nolint:paralleltest,tparallel,thelper
+
 func TestJoiner_ErrReferenceLength(t *testing.T) {
 	t.Parallel()
 
@@ -1018,12 +1023,9 @@ func (m *mockPutter) store(cnt int) error {
 	return nil
 }
 
+// nolint:thelper
 func TestJoinerRedundancy(t *testing.T) {
-
-	strategyTimeout := getter.StrategyTimeout
-	defer func() { getter.StrategyTimeout = strategyTimeout }()
-	getter.StrategyTimeout = 100 * time.Millisecond
-
+	t.Parallel()
 	for _, tc := range []struct {
 		rLevel       redundancy.Level
 		encryptChunk bool
@@ -1063,10 +1065,8 @@ func TestJoinerRedundancy(t *testing.T) {
 	} {
 		tc := tc
 		t.Run(fmt.Sprintf("redundancy=%d encryption=%t", tc.rLevel, tc.encryptChunk), func(t *testing.T) {
-
 			ctx, cancel := context.WithCancel(context.Background())
 			defer cancel()
-
 			shardCnt := tc.rLevel.GetMaxShards()
 			parityCnt := tc.rLevel.GetParities(shardCnt)
 			if tc.encryptChunk {
@@ -1109,13 +1109,12 @@ func TestJoinerRedundancy(t *testing.T) {
 			if err != nil {
 				t.Fatal(err)
 			}
+			strategyTimeout := 100 * time.Millisecond
 			// all data can be read back
 			readCheck := func(t *testing.T, expErr error) {
-				t.Helper()
-
-				ctx, cancel := context.WithTimeout(context.Background(), 15*getter.StrategyTimeout)
+				ctx, cancel := context.WithTimeout(context.Background(), 15*strategyTimeout)
 				defer cancel()
-				ctx = getter.SetFetchTimeout(ctx, getter.StrategyTimeout)
+				ctx = getter.SetConfigInContext(ctx, getter.RACE, true, (10 * strategyTimeout).String(), strategyTimeout.String())
 				joinReader, rootSpan, err := joiner.New(ctx, store, store, swarmAddr)
 				if err != nil {
 					t.Fatal(err)
@@ -1127,10 +1126,11 @@ func TestJoinerRedundancy(t *testing.T) {
 				}
 				i := 0
 				eg, ectx := errgroup.WithContext(ctx)
+			scnt:
 				for ; i < shardCnt; i++ {
 					select {
 					case <-ectx.Done():
-						break
+						break scnt
 					default:
 					}
 					i := i
@@ -1153,6 +1153,7 @@ func TestJoinerRedundancy(t *testing.T) {
 					})
 				}
 				err = eg.Wait()
+
 				if !errors.Is(err, expErr) {
 					t.Fatalf("unexpected error reading chunkdata at chunk position %d: expected %v. got %v", i, expErr, err)
 				}
@@ -1191,3 +1192,171 @@ func TestJoinerRedundancy(t *testing.T) {
 		})
 	}
 }
+
+// TestJoinerRedundancyMultilevel tests the joiner with all combinations of
+// redundancy level, encryption and size (levels, i.e., the height of the swarm hash tree).
+//
+// The test cases have the following structure:
+//
+// 1. upload a file with a given redundancy level and encryption
+//
+// 2. [positive test] download the file by the reference returned by the upload API response
+// This uses range queries to target specific (number of) chunks of the file structure.
+// During path traversal in the swarm hash tree, the underlying mockstore (forgetting)
+// is in 'recording' mode, flagging all the retrieved chunks as chunks to forget.
+// This is to simulate the scenario where some of the chunks are not available/lost.
+//
+// 3. [negative test] attempt at downloading the file using once again the same root hash
+// and a no-redundancy strategy to find the file inaccessible after forgetting.
+// 3a. 
[negative test] download file using NONE without fallback and fail
+// 3b. [negative test] download file using DATA without fallback and fail
+//
+// 4. [positive test] download file using DATA with fallback to allow for
+// reconstruction via erasure coding and succeed.
+//
+// 5. [positive test] after recovery the reconstructed chunks are saved, so the
+// forgetting no longer makes them unavailable; repeat 3a/3b but this time succeed
+//
+// nolint:thelper
+func TestJoinerRedundancyMultilevel(t *testing.T) {
+	t.Parallel()
+	test := func(t *testing.T, rLevel redundancy.Level, encrypt bool, levels, size int) {
+		t.Helper()
+		store := mockstorer.NewForgettingStore(inmemchunkstore.New())
+		testutil.CleanupCloser(t, store)
+		seed, err := pseudorand.NewSeed()
+		if err != nil {
+			t.Fatal(err)
+		}
+		dataReader := pseudorand.NewReader(seed, size*swarm.ChunkSize)
+		ctx := context.Background()
+		// ctx = redundancy.SetLevelInContext(ctx, rLevel)
+		ctx = redundancy.SetLevelInContext(ctx, redundancy.NONE)
+		pipe := builder.NewPipelineBuilder(ctx, store, encrypt, rLevel)
+		addr, err := builder.FeedPipeline(ctx, pipe, dataReader)
+		if err != nil {
+			t.Fatal(err)
+		}
+		expRead := swarm.ChunkSize
+		buf := make([]byte, expRead)
+		offset := mrand.Intn(size) * expRead
+		canReadRange := func(t *testing.T, s getter.Strategy, fallback bool, levels int, canRead bool) {
+			ctx := context.Background()
+			strategyTimeout := 100 * time.Millisecond
+			decodingTimeout := 600 * time.Millisecond
+			if racedetection.IsOn() {
+				decodingTimeout *= 2
+			}
+			ctx = getter.SetConfigInContext(ctx, s, fallback, (2 * strategyTimeout).String(), strategyTimeout.String())
+			ctx, cancel := context.WithTimeout(ctx, time.Duration(levels)*(3*strategyTimeout+decodingTimeout))
+			defer cancel()
+			j, _, err := joiner.New(ctx, store, store, addr)
+			if err != nil {
+				t.Fatal(err)
+			}
+			n, err := j.ReadAt(buf, int64(offset))
+			if !canRead {
+				if !errors.Is(err, storage.ErrNotFound) && !errors.Is(err, context.DeadlineExceeded) {
+					t.Fatalf("expected error %v or %v. got %v", storage.ErrNotFound, context.DeadlineExceeded, err)
+				}
+				return
+			}
+			if err != nil {
+				t.Fatal(err)
+			}
+			if n != expRead {
+				t.Errorf("read %d bytes out of %d", n, expRead)
+			}
+			_, err = dataReader.Seek(int64(offset), io.SeekStart)
+			if err != nil {
+				t.Fatal(err)
+			}
+			ok, err := dataReader.Match(bytes.NewBuffer(buf), expRead)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !ok {
+				t.Error("content mismatch")
+			}
+		}
+
+		// first sanity check and recover a range
+		t.Run("NONE w/o fallback CAN retrieve", func(t *testing.T) {
+			store.Record()
+			defer store.Unrecord()
+			canReadRange(t, getter.NONE, false, levels, true)
+		})
+
+		// do not forget the root chunk
+		store.Unmiss(swarm.NewAddress(addr.Bytes()[:swarm.HashSize]))
+		// after we forget the chunks on the way to the range, we should not be able to retrieve
+		t.Run("NONE w/o fallback CANNOT retrieve", func(t *testing.T) {
+			canReadRange(t, getter.NONE, false, levels, false)
+		})
+
+		// we lost a data chunk, we cannot recover using DATA only strategy with no fallback
+		t.Run("DATA w/o fallback CANNOT retrieve", func(t *testing.T) {
+			canReadRange(t, getter.DATA, false, levels, false)
+		})
+
+		if rLevel == 0 {
+			// allowing fallback mode will not help if no redundancy used for upload
+			t.Run("DATA w fallback CANNOT retrieve", func(t *testing.T) {
+				canReadRange(t, getter.DATA, true, levels, false)
+			})
+			return
+		}
+		// allowing fallback mode will make the range retrievable using erasure decoding
+		t.Run("DATA w fallback CAN retrieve", func(t *testing.T) {
+			canReadRange(t, getter.DATA, true, levels, true)
+		})
+		// after the reconstructed data is stored, we can retrieve the range using DATA only mode
+		t.Run("after recovery, NONE w/o fallback CAN retrieve", func(t *testing.T) {
+			canReadRange(t, getter.NONE, false, levels, true)
+		})
+	}
+	r2level := []int{2, 1, 2, 3, 2}
+	encryptChunk := []bool{false, false, true, true, true}
+	for _, rLevel := range []redundancy.Level{0, 1, 2, 3, 4} {
+		rLevel := rLevel
+		// speeding up tests by skipping some of them
+		t.Run(fmt.Sprintf("rLevel=%v", rLevel), func(t *testing.T) {
+			t.Parallel()
+			for _, encrypt := range []bool{false, true} {
+				encrypt := encrypt
+				shardCnt := rLevel.GetMaxShards()
+				if encrypt {
+					shardCnt = rLevel.GetMaxEncShards()
+				}
+				for _, levels := range []int{1, 2, 3} {
+					chunkCnt := 1
+					switch levels {
+					case 1:
+						chunkCnt = 2
+					case 2:
+						chunkCnt = shardCnt + 1
+					case 3:
+						chunkCnt = shardCnt*shardCnt + 1
+					}
+					t.Run(fmt.Sprintf("encrypt=%v levels=%d chunks=%d incomplete", encrypt, levels, chunkCnt), func(t *testing.T) {
+						if r2level[rLevel] != levels || encrypt != encryptChunk[rLevel] {
+							t.Skip("skipping to save time")
+						}
+						test(t, rLevel, encrypt, levels, chunkCnt)
+					})
+					switch levels {
+					case 1:
+						chunkCnt = shardCnt
+					case 2:
+						chunkCnt = shardCnt * shardCnt
+					case 3:
+						continue
+					}
+					t.Run(fmt.Sprintf("encrypt=%v levels=%d chunks=%d full", encrypt, levels, chunkCnt), func(t *testing.T) {
+						test(t, rLevel, encrypt, levels, chunkCnt)
+					})
+				}
+			}
+		})
+	}
+}
diff --git a/pkg/file/pipeline/hashtrie/hashtrie_test.go b/pkg/file/pipeline/hashtrie/hashtrie_test.go
index c3729f78901..4a966e265f3 100644
--- a/pkg/file/pipeline/hashtrie/hashtrie_test.go
+++ b/pkg/file/pipeline/hashtrie/hashtrie_test.go
@@ -24,7 +24,6 @@ import (
 	"github.com/ethersphere/bee/pkg/file/pipeline/mock"
 	"github.com/ethersphere/bee/pkg/file/pipeline/store"
 	"github.com/ethersphere/bee/pkg/file/redundancy"
-	"github.com/ethersphere/bee/pkg/replicas"
 	"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/inmemchunkstore" "github.com/ethersphere/bee/pkg/swarm" @@ -46,7 +45,7 @@ func init() { binary.LittleEndian.PutUint64(span, 1) } -// NewErasureHashTrieWriter returns back an redundancy param and a HastTrieWriter pipeline +// newErasureHashTrieWriter returns back an redundancy param and a HastTrieWriter pipeline // which are using simple BMT and StoreWriter pipelines for chunk writes func newErasureHashTrieWriter( ctx context.Context, @@ -303,20 +302,20 @@ func TestRedundancy(t *testing.T) { level: redundancy.INSANE, encryption: false, writes: 98, // 97 chunk references fit into one chunk + 1 carrier - parities: 38, // 31 (full ch) + 7 (2 ref) + parities: 37, // 31 (full ch) + 6 (2 ref) }, { desc: "redundancy write for encrypted data", level: redundancy.PARANOID, encryption: true, writes: 21, // 21 encrypted chunk references fit into one chunk + 1 carrier - parities: 118, // // 88 (full ch) + 30 (2 ref) + parities: 116, // // 87 (full ch) + 29 (2 ref) }, } { tc := tc t.Run(tc.desc, func(t *testing.T) { t.Parallel() - subCtx := replicas.SetLevel(ctx, tc.level) + subCtx := redundancy.SetLevelInContext(ctx, tc.level) s := inmemchunkstore.New() intermediateChunkCounter := mock.NewChainWriter() diff --git a/pkg/file/redundancy/getter/getter.go b/pkg/file/redundancy/getter/getter.go index e23f2a3e6af..4e8da1b6390 100644 --- a/pkg/file/redundancy/getter/getter.go +++ b/pkg/file/redundancy/getter/getter.go @@ -9,7 +9,6 @@ import ( "io" "sync" "sync/atomic" - "time" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" @@ -28,13 +27,16 @@ type decoder struct { waits []chan struct{} // wait channels for each chunk rsbuf [][]byte // RS buffer of data + parity shards for erasure decoding ready chan struct{} // signal channel for successful retrieval of shardCnt chunks + lastLen int // length of the last data chunk in the RS buffer shardCnt int // number of data shards parityCnt int // number of parity shards wg sync.WaitGroup // wait group to wait for all goroutines to finish mu sync.Mutex // mutex to protect buffer + err error // error of the last erasure decoding fetchedCnt atomic.Int32 // count successful retrievals cancel func() // cancel function for RS decoding remove func() // callback to remove decoder from decoders cache + config Config // configuration } type Getter interface { @@ -42,14 +44,10 @@ type Getter interface { io.Closer } -// New returns a decoder object used tos retrieve children of an intermediate chunk -func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter, strategy Strategy, strict bool, fetchTimeout time.Duration, remove func()) Getter { +// New returns a decoder object used to retrieve children of an intermediate chunk +func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter, remove func(), conf Config) Getter { ctx, cancel := context.WithCancel(context.Background()) size := len(addrs) - if fetchTimeout == 0 { - fetchTimeout = 30 * time.Second - } - strategyTimeout := StrategyTimeout rsg := &decoder{ fetcher: g, @@ -64,6 +62,7 @@ func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter remove: remove, shardCnt: shardCnt, parityCnt: size - shardCnt, + config: conf, } // after init, cache and wait channels are immutable, need no locking @@ -73,11 +72,13 @@ func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter } // prefetch chunks according to strategy - rsg.wg.Add(1) - go func() { - rsg.prefetch(ctx, strategy, 
strict, strategyTimeout, fetchTimeout) - rsg.wg.Done() - }() + if !conf.Strict || conf.Strategy != NONE { + rsg.wg.Add(1) + go func() { + rsg.err = rsg.prefetch(ctx) + rsg.wg.Done() + }() + } return rsg } @@ -97,10 +98,30 @@ func (g *decoder) Get(ctx context.Context, addr swarm.Address) (swarm.Chunk, err } select { case <-g.waits[i]: - return swarm.NewChunk(addr, g.rsbuf[i]), nil case <-ctx.Done(): return nil, ctx.Err() } + return swarm.NewChunk(addr, g.getData(i)), nil +} + +// setData sets the data shard in the RS buffer +func (g *decoder) setData(i int, chdata []byte) { + data := chdata + // pad the chunk with zeros if it is smaller than swarm.ChunkSize + if len(data) < swarm.ChunkWithSpanSize { + g.lastLen = len(data) + data = make([]byte, swarm.ChunkWithSpanSize) + copy(data, chdata) + } + g.rsbuf[i] = data +} + +// getData returns the data shard from the RS buffer +func (g *decoder) getData(i int) []byte { + if i == g.shardCnt-1 && g.lastLen > 0 { + return g.rsbuf[i][:g.lastLen] // cut padding + } + return g.rsbuf[i] } // fly commits to retrieve the chunk (fly and land) @@ -115,7 +136,9 @@ func (g *decoder) fly(i int, up bool) (success bool) { // it races with erasure recovery which takes precedence even if it started later // due to the fact that erasure recovery could only implement global locking on all shards func (g *decoder) fetch(ctx context.Context, i int) { - ch, err := g.fetcher.Get(ctx, g.addrs[i]) + fctx, cancel := context.WithTimeout(ctx, g.config.FetchTimeout) + defer cancel() + ch, err := g.fetcher.Get(fctx, g.addrs[i]) if err != nil { _ = g.fly(i, false) // unset inflight return @@ -139,8 +162,8 @@ func (g *decoder) fetch(ctx context.Context, i int) { } // write chunk to rsbuf and signal waiters - g.rsbuf[i] = ch.Data() // save the chunk in the RS buffer - if i < len(g.waits) { + g.setData(i, ch.Data()) // save the chunk in the RS buffer + if i < len(g.waits) { // if the chunk is a data shard close(g.waits[i]) // signal that the chunk is retrieved } @@ -218,6 +241,9 @@ func (g *decoder) save(ctx context.Context, missing []int) error { return nil } +// Close terminates the prefetch loop, waits for all goroutines to finish and +// removes the decoder from the cache +// it implements the io.Closer interface func (g *decoder) Close() error { g.cancel() g.wg.Wait() diff --git a/pkg/file/redundancy/getter/getter_test.go b/pkg/file/redundancy/getter/getter_test.go index bb00f7ddd02..b18caa55c12 100644 --- a/pkg/file/redundancy/getter/getter_test.go +++ b/pkg/file/redundancy/getter/getter_test.go @@ -21,37 +21,13 @@ import ( "github.com/ethersphere/bee/pkg/file/redundancy/getter" "github.com/ethersphere/bee/pkg/storage" inmem "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" + mockstorer "github.com/ethersphere/bee/pkg/storer/mock" "github.com/ethersphere/bee/pkg/swarm" + "github.com/ethersphere/bee/pkg/util/testutil/racedetection" "github.com/klauspost/reedsolomon" "golang.org/x/sync/errgroup" ) -type delayed struct { - storage.ChunkStore - cache map[string]time.Duration - mu sync.Mutex -} - -func (d *delayed) delay(addr swarm.Address, delay time.Duration) { - d.mu.Lock() - defer d.mu.Unlock() - d.cache[addr.String()] = delay -} - -func (d *delayed) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) { - d.mu.Lock() - defer d.mu.Unlock() - if delay, ok := d.cache[addr.String()]; ok && delay > 0 { - select { - case <-time.After(delay): - delete(d.cache, addr.String()) - case <-ctx.Done(): - return nil, ctx.Err() - } - } - return 
d.ChunkStore.Get(ctx, addr) -} - // TestGetter tests the retrieval of chunks with missing data shards // using the RACE strategy for a number of erasure code parameters func TestGetterRACE(t *testing.T) { @@ -118,11 +94,10 @@ func TestGetterFallback(t *testing.T) { func testDecodingRACE(t *testing.T, bufSize, shardCnt, erasureCnt int) { t.Helper() - - strategyTimeout := getter.StrategyTimeout - defer func() { getter.StrategyTimeout = strategyTimeout }() - getter.StrategyTimeout = 100 * time.Millisecond - + strategyTimeout := 100 * time.Millisecond + if racedetection.On { + strategyTimeout *= 2 + } store := inmem.New() buf := make([][]byte, bufSize) addrs := initData(t, buf, shardCnt, store) @@ -140,7 +115,12 @@ func testDecodingRACE(t *testing.T, bufSize, shardCnt, erasureCnt int) { } ctx, cancel := context.WithCancel(context.TODO()) defer cancel() - g := getter.New(addrs, shardCnt, store, store, getter.RACE, false, 2*getter.StrategyTimeout, func() {}) + conf := getter.Config{ + Strategy: getter.RACE, + FetchTimeout: 2 * strategyTimeout, + StrategyTimeout: strategyTimeout, + } + g := getter.New(addrs, shardCnt, store, store, func() {}, conf) defer g.Close() parityCnt := len(buf) - shardCnt q := make(chan error, 1) @@ -149,9 +129,13 @@ func testDecodingRACE(t *testing.T, bufSize, shardCnt, erasureCnt int) { q <- err }() err := context.DeadlineExceeded + wait := strategyTimeout * 2 + if racedetection.On { + wait *= 2 + } select { case err = <-q: - case <-time.After(getter.StrategyTimeout * 10): + case <-time.After(wait): } switch { case erasureCnt > parityCnt: @@ -175,13 +159,11 @@ func testDecodingRACE(t *testing.T, bufSize, shardCnt, erasureCnt int) { func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { t.Helper() - strategyTimeout := getter.StrategyTimeout - defer func() { getter.StrategyTimeout = strategyTimeout }() - getter.StrategyTimeout = 100 * time.Millisecond + strategyTimeout := 150 * time.Millisecond bufSize := 12 shardCnt := 6 - store := &delayed{ChunkStore: inmem.New(), cache: make(map[string]time.Duration)} + store := mockstorer.NewDelayedStore(inmem.New()) buf := make([][]byte, bufSize) addrs := initData(t, buf, shardCnt, store) @@ -201,14 +183,20 @@ func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { waitDelayed, waitErased := make(chan error, 1), make(chan error, 1) // complete retrieval of delayed chunk by putting it into the store after a while - delay := +getter.StrategyTimeout / 4 + delay := strategyTimeout / 4 if s == getter.NONE { - delay += getter.StrategyTimeout + delay += strategyTimeout } - store.delay(addrs[delayed], delay) + store.Delay(addrs[delayed], delay) // create getter start := time.Now() - g := getter.New(addrs, shardCnt, store, store, s, strict, getter.StrategyTimeout/2, func() {}) + conf := getter.Config{ + Strategy: s, + Strict: strict, + FetchTimeout: strategyTimeout / 2, + StrategyTimeout: strategyTimeout, + } + g := getter.New(addrs, shardCnt, store, store, func() {}, conf) defer g.Close() // launch delayed and erased chunk retrieval @@ -219,11 +207,15 @@ func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { // delayed and erased chunk retrieval completes go func() { defer wg.Done() + ctx, cancel := context.WithTimeout(ctx, strategyTimeout*time.Duration(5-s)) + defer cancel() _, err := g.Get(ctx, addrs[delayed]) waitDelayed <- err }() go func() { defer wg.Done() + ctx, cancel := context.WithTimeout(ctx, strategyTimeout*time.Duration(5-s)) + defer cancel() _, err := g.Get(ctx, 
addrs[erased]) waitErased <- err }() @@ -234,7 +226,7 @@ func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { if err != nil { t.Fatal("unexpected error", err) } - round := time.Since(start) / getter.StrategyTimeout + round := time.Since(start) / strategyTimeout switch { case strict && s == getter.NONE: if round < 1 { @@ -260,15 +252,15 @@ func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { if err != nil { t.Fatal("unexpected error", err) } - round = time.Since(start) / getter.StrategyTimeout + round = time.Since(start) / strategyTimeout switch { case strict: t.Fatalf("unexpected completion of erased chunk retrieval. got round %d", round) case s == getter.NONE: - if round < 2 { + if round < 3 { t.Fatalf("unexpected early completion of erased chunk retrieval. got round %d", round) } - if round > 2 { + if round > 3 { t.Fatalf("unexpected late completion of erased chunk retrieval. got round %d", round) } case s == getter.DATA: @@ -281,12 +273,12 @@ func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { } checkShardsAvailable(t, store, addrs[:erased], buf[:erased]) - case <-time.After(getter.StrategyTimeout * 2): + case <-time.After(strategyTimeout * 2): if !strict { t.Fatal("unexpected timeout using strategy", s, "with strict", strict) } } - case <-time.After(getter.StrategyTimeout * 3): + case <-time.After(strategyTimeout * 3): if !strict || s != getter.NONE { t.Fatal("unexpected timeout using strategy", s, "with strict", strict) } diff --git a/pkg/file/redundancy/getter/strategies.go b/pkg/file/redundancy/getter/strategies.go index 8bf944e8ae7..bb5188e9cc2 100644 --- a/pkg/file/redundancy/getter/strategies.go +++ b/pkg/file/redundancy/getter/strategies.go @@ -6,21 +6,34 @@ package getter import ( "context" + "errors" "fmt" "time" ) -var ( - StrategyTimeout = 500 * time.Millisecond // timeout for each strategy +const ( + DefaultStrategy = NONE // default prefetching strategy + DefaultStrict = true // default fallback modes + DefaultFetchTimeout = 30 * time.Second // timeout for each chunk retrieval + DefaultStrategyTimeout = 300 * time.Millisecond // timeout for each strategy ) type ( - strategyKey struct{} - modeKey struct{} - fetcherTimeoutKey struct{} - Strategy = int + strategyKey struct{} + modeKey struct{} + fetchTimeoutKey struct{} + strategyTimeoutKey struct{} + Strategy = int ) +// Config is the configuration for the getter - public +type Config struct { + Strategy Strategy + Strict bool + FetchTimeout time.Duration + StrategyTimeout time.Duration +} + const ( NONE Strategy = iota // no prefetching and no decoding DATA // just retrieve data shards no decoding @@ -29,17 +42,57 @@ const ( strategyCnt ) -// GetParamsFromContext extracts the strategy and strict mode from the context -func GetParamsFromContext(ctx context.Context) (s Strategy, strict bool, fetcherTimeout time.Duration) { - s, _ = ctx.Value(strategyKey{}).(Strategy) - strict, _ = ctx.Value(modeKey{}).(bool) - fetcherTimeout, _ = ctx.Value(fetcherTimeoutKey{}).(time.Duration) - return s, strict, fetcherTimeout +// DefaultConfig is the default configuration for the getter +var DefaultConfig = Config{ + Strategy: DefaultStrategy, + Strict: DefaultStrict, + FetchTimeout: DefaultFetchTimeout, + StrategyTimeout: DefaultStrategyTimeout, } -// SetFetchTimeout sets the timeout for each fetch -func SetFetchTimeout(ctx context.Context, timeout time.Duration) context.Context { - return context.WithValue(ctx, fetcherTimeoutKey{}, timeout) +// NewConfigFromContext returns a 
new Config based on the context
+func NewConfigFromContext(ctx context.Context, def Config) (conf Config, err error) {
+	var ok bool
+	conf = def
+	e := func(s string, errs ...error) error {
+		if len(errs) > 0 {
+			return fmt.Errorf("error setting %s from context: %w", s, errors.Join(errs...))
+		}
+		return fmt.Errorf("error setting %s from context", s)
+	}
+	if val := ctx.Value(strategyKey{}); val != nil {
+		conf.Strategy, ok = val.(Strategy)
+		if !ok {
+			return conf, e("strategy")
+		}
+	}
+	if val := ctx.Value(modeKey{}); val != nil {
+		conf.Strict, ok = val.(bool)
+		if !ok {
+			return conf, e("fallback mode")
+		}
+	}
+	if val := ctx.Value(fetchTimeoutKey{}); val != nil {
+		fetchTimeoutVal, ok := val.(string)
+		if !ok {
+			return conf, e("fetcher timeout")
+		}
+		conf.FetchTimeout, err = time.ParseDuration(fetchTimeoutVal)
+		if err != nil {
+			return conf, e("fetcher timeout", err)
+		}
+	}
+	if val := ctx.Value(strategyTimeoutKey{}); val != nil {
+		strategyTimeoutVal, ok := val.(string)
+		if !ok {
+			return conf, e("strategy timeout")
+		}
+		conf.StrategyTimeout, err = time.ParseDuration(strategyTimeoutVal)
+		if err != nil {
+			return conf, e("strategy timeout", err)
+		}
+	}
+	return conf, nil
+}
 
 // SetStrategy sets the strategy for the retrieval
@@ -52,9 +105,28 @@ func SetStrict(ctx context.Context, strict bool) context.Context {
 	return context.WithValue(ctx, modeKey{}, strict)
 }
 
-func (g *decoder) prefetch(ctx context.Context, strategy int, strict bool, strategyTimeout, fetchTimeout time.Duration) {
-	if strict && strategy == NONE {
-		return
+// SetFetchTimeout sets the timeout for each fetch
+func SetFetchTimeout(ctx context.Context, timeout string) context.Context {
+	return context.WithValue(ctx, fetchTimeoutKey{}, timeout)
+}
+
+// SetStrategyTimeout sets the timeout for each strategy
+func SetStrategyTimeout(ctx context.Context, timeout string) context.Context {
+	return context.WithValue(ctx, strategyTimeoutKey{}, timeout)
+}
+
+// SetConfigInContext sets the config params in the context
+func SetConfigInContext(ctx context.Context, s Strategy, fallbackmode bool, fetchTimeout, strategyTimeout string) context.Context {
+	ctx = SetStrategy(ctx, s)
+	ctx = SetStrict(ctx, !fallbackmode)
+	ctx = SetFetchTimeout(ctx, fetchTimeout)
+	ctx = SetStrategyTimeout(ctx, strategyTimeout)
+	return ctx
+}
+
+func (g *decoder) prefetch(ctx context.Context) error {
+	if g.config.Strict && g.config.Strategy == NONE {
+		return nil
 	}
 	defer g.remove()
 	var cancels []func()
@@ -66,16 +138,16 @@
 	defer cancelAll()
 	run := func(s Strategy) error {
 		if s == PROX { // NOT IMPLEMENTED
-			return fmt.Errorf("strategy %d not implemented", s)
+			return errors.New("strategy not implemented")
 		}
 
 		var stop <-chan time.Time
 		if s < RACE {
-			timer := time.NewTimer(strategyTimeout)
+			timer := time.NewTimer(g.config.StrategyTimeout)
 			defer timer.Stop()
 			stop = timer.C
 		}
-		lctx, cancel := context.WithTimeout(ctx, fetchTimeout)
+		lctx, cancel := context.WithCancel(ctx)
 		cancels = append(cancels, cancel)
 		prefetch(lctx, g, s)
 
@@ -93,9 +165,14 @@
 		return g.recover(ctx) // context to cancel when shardCnt chunks are retrieved
 	}
 	var err error
-	for s := strategy; s == strategy || (err != nil && !strict && s < strategyCnt); s++ {
+	for s := g.config.Strategy; s < strategyCnt; s++ {
 		err = run(s)
+		if g.config.Strict || err == nil {
+			break
+		}
 	}
+
+	return err
}
 
 // prefetch launches the retrieval 
of chunks based on the strategy diff --git a/pkg/file/redundancy/level.go b/pkg/file/redundancy/level.go index 3dfba1cd084..f7b4f5c19f1 100644 --- a/pkg/file/redundancy/level.go +++ b/pkg/file/redundancy/level.go @@ -5,6 +5,7 @@ package redundancy import ( + "context" "errors" "fmt" @@ -106,44 +107,40 @@ func (l Level) Decrement() Level { // TABLE INITS var mediumEt = newErasureTable( - []int{94, 68, 46, 28, 14, 5, 1}, - []int{9, 8, 7, 6, 5, 4, 3}, + []int{95, 69, 47, 29, 15, 6, 2, 1}, + []int{9, 8, 7, 6, 5, 4, 3, 2}, ) var encMediumEt = newErasureTable( - []int{47, 34, 23, 14, 7, 2}, - []int{9, 8, 7, 6, 5, 4}, + []int{47, 34, 23, 14, 7, 3, 1}, + []int{9, 8, 7, 6, 5, 4, 3}, ) var strongEt = newErasureTable( - []int{104, 95, 86, 77, 69, 61, 53, 46, 39, 32, 26, 20, 15, 10, 6, 3, 1}, - []int{21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5}, + []int{105, 96, 87, 78, 70, 62, 54, 47, 40, 33, 27, 21, 16, 11, 7, 4, 2, 1}, + []int{21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4}, ) var encStrongEt = newErasureTable( - []int{52, 47, 43, 38, 34, 30, 26, 23, 19, 16, 13, 10, 7, 5, 3, 1}, - []int{21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6}, + []int{52, 48, 43, 39, 35, 31, 27, 23, 20, 16, 13, 10, 8, 5, 3, 2, 1}, + []int{21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5}, ) var insaneEt = newErasureTable( - []int{92, 87, 82, 77, 73, 68, 63, 59, 54, 50, 45, 41, 37, 33, 29, 26, 22, 19, 16, 13, 10, 8, 5, 3, 2, 1}, - []int{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6}, + []int{93, 88, 83, 78, 74, 69, 64, 60, 55, 51, 46, 42, 38, 34, 30, 27, 23, 20, 17, 14, 11, 9, 6, 4, 3, 2, 1}, + []int{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5}, ) var encInsaneEt = newErasureTable( - []int{46, 43, 41, 38, 36, 34, 31, 29, 27, 25, 22, 20, 18, 16, 14, 13, 11, 9, 8, 6, 5, 4, 2, 1}, - []int{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 7}, + []int{46, 44, 41, 39, 37, 34, 32, 30, 27, 25, 23, 21, 19, 17, 15, 13, 11, 10, 8, 7, 5, 4, 3, 2, 1}, + []int{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 6}, ) var paranoidEt = newErasureTable( []int{ - 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, - 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, - 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, - 7, 6, 5, 4, 3, 2, 1, + 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, + 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, }, []int{ - 90, 88, 87, 85, 84, 82, 81, 79, 77, 76, - 74, 72, 71, 69, 67, 66, 64, 62, 60, 59, - 57, 55, 53, 51, 49, 48, 46, 44, 41, 39, - 37, 35, 32, 30, 27, 24, 20, + 89, 87, 86, 84, 83, 81, 80, 78, 76, 75, 73, 71, 70, 68, 66, 65, 63, 61, 59, 58, + 56, 54, 52, 50, 48, 47, 45, 43, 40, 38, 36, 34, 31, 29, 26, 23, 19, }, ) var encParanoidEt = newErasureTable( @@ -152,8 +149,8 @@ var encParanoidEt = newErasureTable( 8, 7, 6, 5, 4, 3, 2, 1, }, []int{ - 88, 85, 82, 79, 76, 72, 69, 66, 62, 59, - 55, 51, 48, 44, 39, 35, 30, 24, + 87, 84, 81, 78, 75, 71, 68, 65, 61, 58, + 54, 50, 47, 43, 38, 34, 29, 23, }, ) @@ -167,3 +164,19 @@ func GetReplicaCounts() [5]int { // for the five levels of redundancy are 0, 2, 4, 5, 19 // we use an approximation as the successive powers of 2 var replicaCounts = [5]int{0, 2, 4, 8, 16} + +type levelKey struct{} + +// SetLevelInContext sets the redundancy level in the context +func SetLevelInContext(ctx context.Context, level 
Level) context.Context { + return context.WithValue(ctx, levelKey{}, level) +} + +// GetLevelFromContext is a helper function to extract the redundancy level from the context +func GetLevelFromContext(ctx context.Context) Level { + rlevel := PARANOID + if val := ctx.Value(levelKey{}); val != nil { + rlevel = val.(Level) + } + return rlevel +} diff --git a/pkg/replicas/getter.go b/pkg/replicas/getter.go index 496bc5d658a..26345919958 100644 --- a/pkg/replicas/getter.go +++ b/pkg/replicas/getter.go @@ -24,20 +24,7 @@ import ( // then the probability of Swarmageddon is less than 0.000001 // assuming the error rate of chunk retrievals stays below the level expressed // as depth by the publisher. -type ErrSwarmageddon struct { - error -} - -func (err *ErrSwarmageddon) Unwrap() []error { - if err == nil || err.error == nil { - return nil - } - var uwe interface{ Unwrap() []error } - if !errors.As(err.error, &uwe) { - return nil - } - return uwe.Unwrap() -} +var ErrSwarmageddon = errors.New("swarmageddon has begun") // getter is the private implementation of storage.Getter, an interface for // retrieving chunks. This getter embeds the original simple chunk getter and extends it @@ -69,7 +56,7 @@ func (g *getter) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, e resultC := make(chan swarm.Chunk) // errc collects the errors errc := make(chan error, 17) - var errs []error + var errs error errcnt := 0 // concurrently call to retrieve chunk using original CAC address @@ -108,10 +95,10 @@ func (g *getter) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, e return chunk, nil case err = <-errc: - errs = append(errs, err) + errs = errors.Join(errs, err) errcnt++ if errcnt > total { - return nil, &ErrSwarmageddon{errors.Join(errs...)} + return nil, errors.Join(ErrSwarmageddon, errs) } // ticker switches on the address channel diff --git a/pkg/replicas/getter_test.go b/pkg/replicas/getter_test.go index 3b11ad26d94..7435b0574fd 100644 --- a/pkg/replicas/getter_test.go +++ b/pkg/replicas/getter_test.go @@ -195,18 +195,11 @@ func TestGetter(t *testing.T) { } t.Run("returns correct error", func(t *testing.T) { - var esg *replicas.ErrSwarmageddon - if !errors.As(err, &esg) { + if !errors.Is(err, replicas.ErrSwarmageddon) { t.Fatalf("incorrect error. want Swarmageddon. got %v", err) } - errs := esg.Unwrap() - for _, err := range errs { - if !errors.Is(err, tc.failure.err) { - t.Fatalf("incorrect error. want it to wrap %v. got %v", tc.failure.err, err) - } - } - if len(errs) != tc.count+1 { - t.Fatalf("incorrect error. want %d. got %d", tc.count+1, len(errs)) + if !errors.Is(err, tc.failure.err) { + t.Fatalf("incorrect error. want it to wrap %v. 
got %v", tc.failure.err, err) } }) } @@ -265,7 +258,6 @@ func TestGetter(t *testing.T) { } } }) - }) } } diff --git a/pkg/replicas/putter.go b/pkg/replicas/putter.go index f2334a994b8..4aa55b638f0 100644 --- a/pkg/replicas/putter.go +++ b/pkg/replicas/putter.go @@ -11,6 +11,7 @@ import ( "errors" "sync" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" @@ -29,7 +30,7 @@ func NewPutter(p storage.Putter) storage.Putter { // Put makes the getter satisfy the storage.Getter interface func (p *putter) Put(ctx context.Context, ch swarm.Chunk) (err error) { - rlevel := GetLevelFromContext(ctx) + rlevel := redundancy.GetLevelFromContext(ctx) errs := []error{p.putter.Put(ctx, ch)} if rlevel == 0 { return errs[0] diff --git a/pkg/replicas/putter_test.go b/pkg/replicas/putter_test.go index 7a90b5ec308..7d05624ebca 100644 --- a/pkg/replicas/putter_test.go +++ b/pkg/replicas/putter_test.go @@ -75,7 +75,7 @@ func TestPutter(t *testing.T) { t.Fatal(err) } ctx := context.Background() - ctx = replicas.SetLevel(ctx, tc.level) + ctx = redundancy.SetLevelInContext(ctx, tc.level) ch, err := cac.New(buf) if err != nil { @@ -174,7 +174,7 @@ func TestPutter(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) defer cancel() - ctx = replicas.SetLevel(ctx, tc.level) + ctx = redundancy.SetLevelInContext(ctx, tc.level) ch, err := cac.New(buf) if err != nil { t.Fatal(err) diff --git a/pkg/replicas/replicas.go b/pkg/replicas/replicas.go index 18b21d4b8b8..fd45b28a3b1 100644 --- a/pkg/replicas/replicas.go +++ b/pkg/replicas/replicas.go @@ -1,4 +1,4 @@ -// Copyright 2020 The Swarm Authors. All rights reserved. +// Copyright 2023 The Swarm Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
diff --git a/pkg/replicas/replicas.go b/pkg/replicas/replicas.go
index 18b21d4b8b8..fd45b28a3b1 100644
--- a/pkg/replicas/replicas.go
+++ b/pkg/replicas/replicas.go
@@ -1,4 +1,4 @@
-// Copyright 2020 The Swarm Authors. All rights reserved.
+// Copyright 2023 The Swarm Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
@@ -11,7 +11,6 @@ package replicas
 
 import (
-	"context"
 	"time"
 
 	"github.com/ethersphere/bee/pkg/crypto"
@@ -19,31 +18,13 @@ import (
 	"github.com/ethersphere/bee/pkg/swarm"
 )
 
-type redundancyLevelType struct{}
-
 var (
-	// redundancyLevel is the context key for the redundancy level
-	redundancyLevel redundancyLevelType
 	// RetryInterval is the duration between successive additional requests
 	RetryInterval = 300 * time.Millisecond
 
 	privKey, _ = crypto.DecodeSecp256k1PrivateKey(append([]byte{1}, make([]byte, 31)...))
 	signer     = crypto.NewDefaultSigner(privKey)
 )
 
-// SetLevel sets the redundancy level in the context
-func SetLevel(ctx context.Context, level redundancy.Level) context.Context {
-	return context.WithValue(ctx, redundancyLevel, level)
-}
-
-// GetLevelFromContext is a helper function to extract the redundancy level from the context
-func GetLevelFromContext(ctx context.Context) redundancy.Level {
-	rlevel := redundancy.PARANOID
-	if val := ctx.Value(redundancyLevel); val != nil {
-		rlevel = val.(redundancy.Level)
-	}
-	return rlevel
-}
-
 // replicator running the find for replicas
 type replicator struct {
 	addr   []byte // chunk address
diff --git a/pkg/steward/steward_test.go b/pkg/steward/steward_test.go
index 8e2083abc80..fe3ccd05ba9 100644
--- a/pkg/steward/steward_test.go
+++ b/pkg/steward/steward_test.go
@@ -17,7 +17,6 @@ import (
 	"github.com/ethersphere/bee/pkg/file/pipeline/builder"
 	"github.com/ethersphere/bee/pkg/file/redundancy"
 	postagetesting "github.com/ethersphere/bee/pkg/postage/mock"
-	"github.com/ethersphere/bee/pkg/replicas"
 	"github.com/ethersphere/bee/pkg/steward"
 	storage "github.com/ethersphere/bee/pkg/storage"
 	"github.com/ethersphere/bee/pkg/storage/inmemchunkstore"
@@ -49,7 +48,7 @@ func TestSteward(t *testing.T) {
 		s       = steward.New(store, localRetrieval, inmem)
 		stamper = postagetesting.NewStamper()
 	)
-	ctx = replicas.SetLevel(ctx, redundancy.NONE)
+	ctx = redundancy.SetLevelInContext(ctx, redundancy.NONE)
 
 	n, err := rand.Read(data)
 	if n != cap(data) {
diff --git a/pkg/storageincentives/proof_test.go b/pkg/storageincentives/proof_test.go
index 9ee3745e1af..dcd7002f913 100644
--- a/pkg/storageincentives/proof_test.go
+++ b/pkg/storageincentives/proof_test.go
@@ -44,8 +44,6 @@ var testData []byte
 // Test asserts that MakeInclusionProofs will generate the same
 // output for given sample.
 func TestMakeInclusionProofsRegression(t *testing.T) {
-	t.Parallel()
-
 	const sampleSize = 16
 
 	keyRaw := `00000000000000000000000000000000`
diff --git a/pkg/storer/mock/forgetting.go b/pkg/storer/mock/forgetting.go
new file mode 100644
index 00000000000..5588b5c263e
--- /dev/null
+++ b/pkg/storer/mock/forgetting.go
@@ -0,0 +1,128 @@
+// Copyright 2023 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package mockstorer
+
+import (
+	"context"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	storage "github.com/ethersphere/bee/pkg/storage"
+	"github.com/ethersphere/bee/pkg/swarm"
+)
+
+// DelayedStore is a chunk store wrapper that delays Get calls for selected addresses.
+type DelayedStore struct {
+	storage.ChunkStore
+	cache map[string]time.Duration
+	mu    sync.Mutex
+}
+
+func NewDelayedStore(s storage.ChunkStore) *DelayedStore {
+	return &DelayedStore{
+		ChunkStore: s,
+		cache:      make(map[string]time.Duration),
+	}
+}
+
+func (d *DelayedStore) Delay(addr swarm.Address, delay time.Duration) {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	d.cache[addr.String()] = delay
+}
+
+func (d *DelayedStore) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	if delay, ok := d.cache[addr.String()]; ok && delay > 0 {
+		select {
+		case <-time.After(delay):
+			delete(d.cache, addr.String())
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		}
+	}
+	return d.ChunkStore.Get(ctx, addr)
+}
+
+// ForgettingStore is a chunk store wrapper that can simulate chunks going missing.
+type ForgettingStore struct {
+	storage.ChunkStore
+	record atomic.Bool
+	mu     sync.Mutex
+	n      atomic.Int64
+	missed map[string]struct{}
+}
+
+func NewForgettingStore(s storage.ChunkStore) *ForgettingStore {
+	return &ForgettingStore{ChunkStore: s, missed: make(map[string]struct{})}
+}
+
+func (f *ForgettingStore) Stored() int64 {
+	return f.n.Load()
+}
+
+func (f *ForgettingStore) Record() {
+	f.record.Store(true)
+}
+
+func (f *ForgettingStore) Unrecord() {
+	f.record.Store(false)
+}
+
+func (f *ForgettingStore) Miss(addr swarm.Address) {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	f.miss(addr)
+}
+
+func (f *ForgettingStore) Unmiss(addr swarm.Address) {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	f.unmiss(addr)
+}
+
+func (f *ForgettingStore) miss(addr swarm.Address) {
+	f.missed[addr.String()] = struct{}{}
+}
+
+func (f *ForgettingStore) unmiss(addr swarm.Address) {
+	delete(f.missed, addr.String())
+}
+
+func (f *ForgettingStore) isMiss(addr swarm.Address) bool {
+	_, ok := f.missed[addr.String()]
+	return ok
+}
+
+func (f *ForgettingStore) Reset() {
+	f.missed = make(map[string]struct{})
+}
+
+func (f *ForgettingStore) Missed() int {
+	return len(f.missed)
+}
+
+// Get implements the ChunkStore interface.
+// In the recording phase, it records the chunk address as missed and calls Get on the embedded store.
+// In the forgetting phase, it returns ErrNotFound if the chunk address has been recorded as missed.
+func (f *ForgettingStore) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	if f.record.Load() {
+		f.miss(addr)
+	} else if f.isMiss(addr) {
+		return nil, storage.ErrNotFound
+	}
+	return f.ChunkStore.Get(ctx, addr)
+}
+
+// Put implements the ChunkStore interface.
+func (f *ForgettingStore) Put(ctx context.Context, ch swarm.Chunk) (err error) {
+	f.n.Add(1)
+	if !f.record.Load() {
+		f.Unmiss(ch.Address())
+	}
+	return f.ChunkStore.Put(ctx, ch)
+}
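
Note: a short sketch, not part of the patch, of how ForgettingStore is meant to be driven from a test. The chunk-generation helper import is an assumption.

package mockstorer_test

import (
	"context"
	"errors"
	"testing"

	storage "github.com/ethersphere/bee/pkg/storage"
	"github.com/ethersphere/bee/pkg/storage/inmemchunkstore"
	chunktesting "github.com/ethersphere/bee/pkg/storage/testing"
	mockstorer "github.com/ethersphere/bee/pkg/storer/mock"
)

func TestForgetting(t *testing.T) {
	ctx := context.Background()
	st := mockstorer.NewForgettingStore(inmemchunkstore.New())
	ch := chunktesting.GenerateTestRandomChunk() // assumed helper
	if err := st.Put(ctx, ch); err != nil {
		t.Fatal(err)
	}
	// Simulate the chunk having been forgotten by the store.
	st.Miss(ch.Address())
	if _, err := st.Get(ctx, ch.Address()); !errors.Is(err, storage.ErrNotFound) {
		t.Fatalf("want ErrNotFound, got %v", err)
	}
}
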
diff --git a/pkg/util/testutil/pseudorand/reader.go b/pkg/util/testutil/pseudorand/reader.go
new file mode 100644
index 00000000000..d7fcf0c3646
--- /dev/null
+++ b/pkg/util/testutil/pseudorand/reader.go
@@ -0,0 +1,182 @@
+// Copyright 2023 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package pseudorand provides a pseudorandom reader that generates a
+// deterministic sequence of bytes based on a seed. It is used in tests to
+// enable large volumes of pseudorandom data to be generated and compared
+// without having to store the data in memory.
+package pseudorand
+
+import (
+	"bytes"
+	"crypto/rand"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+
+	"github.com/ethersphere/bee/pkg/swarm"
+)
+
+const bufSize = 4096
+
+// Reader is a pseudorandom reader that generates a deterministic
+// sequence of bytes based on the seed.
+type Reader struct {
+	cur int
+	len int
+	seg [40]byte
+	buf [bufSize]byte
+}
+
+// NewSeed creates a new 32-byte random seed.
+func NewSeed() ([]byte, error) {
+	seed := make([]byte, 32)
+	_, err := io.ReadFull(rand.Reader, seed)
+	return seed, err
+}
+
+// NewReader creates a new pseudorandom reader of length l, seeded with the given seed.
+func NewReader(seed []byte, l int) *Reader {
+	r := &Reader{len: l}
+	_ = copy(r.buf[8:], seed)
+	r.fill()
+	return r
+}
+
+// Size returns the total number of bytes the reader produces.
+func (r *Reader) Size() int {
+	return r.len
+}
+
+// Read reads up to len(buf) bytes into buf. It returns the number of bytes
+// read (0 <= n <= len(buf)) and any error encountered; io.EOF is returned
+// together with the last bytes once the configured length is exhausted.
+func (r *Reader) Read(buf []byte) (n int, err error) {
+	cur := r.cur % bufSize
+	toRead := min(bufSize-cur, r.len-r.cur)
+	if toRead < len(buf) {
+		buf = buf[:toRead]
+	}
+	n = copy(buf, r.buf[cur:])
+	r.cur += n
+	if r.cur == r.len {
+		return n, io.EOF
+	}
+	if r.cur%bufSize == 0 {
+		r.fill()
+	}
+	return n, nil
+}
+
+// Equal compares the contents of the reader with the contents of
+// the given reader. It returns true only if the two contents match
+// in full, including their length.
+func (r1 *Reader) Equal(r2 io.Reader) (bool, error) {
+	ok, err := r1.Match(r2, r1.len)
+	if err != nil {
+		return false, err
+	}
+	if !ok {
+		return false, nil
+	}
+	n, err := io.ReadFull(r2, make([]byte, 1))
+	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
+		return n == 0, nil
+	}
+	return false, err
+}
+
+// Match compares the contents of the reader with the contents of
+// the given reader. It returns true if the contents are equal up to l bytes.
+func (r1 *Reader) Match(r2 io.Reader, l int) (bool, error) {
+	read := func(r io.Reader, buf []byte) (n int, err error) {
+		for n < len(buf) && err == nil {
+			i, e := r.Read(buf[n:])
+			if e == nil && i == 0 {
+				return n, nil
+			}
+			err = e
+			n += i
+		}
+		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
+			err = nil
+		}
+		return n, err
+	}
+
+	buf1 := make([]byte, bufSize)
+	buf2 := make([]byte, bufSize)
+	for l > 0 {
+		if l <= len(buf1) {
+			buf1 = buf1[:l]
+			buf2 = buf2[:l]
+		}
+
+		n1, err := read(r1, buf1)
+		if err != nil {
+			return false, err
+		}
+		n2, err := read(r2, buf2)
+		if err != nil {
+			return false, err
+		}
+		if n1 != n2 {
+			return false, nil
+		}
+		if !bytes.Equal(buf1[:n1], buf2[:n2]) {
+			return false, nil
+		}
+		l -= n1
+	}
+	return true, nil
+}
+
+// Seek sets the offset for the next Read to offset, interpreted
+// according to whence: io.SeekStart (0) means relative to the start,
+// io.SeekCurrent (1) means relative to the current offset, and
+// io.SeekEnd (2) means offset bytes back from the end. It returns
+// the new offset and an error, if any.
+func (r *Reader) Seek(offset int64, whence int) (int64, error) {
+	switch whence {
+	case io.SeekStart:
+		r.cur = int(offset)
+	case io.SeekCurrent:
+		r.cur += int(offset)
+	case io.SeekEnd:
+		// note: unlike the io.Seeker convention, offset counts back from the end
+		r.cur = r.len - int(offset)
+	}
+	if r.cur < 0 || r.cur > r.len {
+		return 0, fmt.Errorf("seek: invalid offset %d", r.cur)
+	}
+	r.fill()
+	return int64(r.cur), nil
+}
+
+// Offset returns the current offset of the reader.
+func (r *Reader) Offset() int64 {
+	return int64(r.cur)
+}
+
+// ReadAt reads up to len(buf) bytes into buf starting at offset off
+// within the pseudorandom sequence.
+func (r *Reader) ReadAt(buf []byte, off int64) (n int, err error) {
+	if _, err := r.Seek(off, io.SeekStart); err != nil {
+		return 0, err
+	}
+	return r.Read(buf)
+}
+
+// fill refills the internal buffer, from the current read position onwards,
+// with hashes derived from a running segment counter.
+func (r *Reader) fill() {
+	if r.cur >= r.len {
+		return
+	}
+	bufSegments := bufSize / 32
+	start := r.cur / bufSegments
+	rem := (r.cur % bufSize) / 32
+	h := swarm.NewHasher()
+	for i := 32 * rem; i < len(r.buf); i += 32 {
+		binary.BigEndian.PutUint64(r.seg[:], uint64((start+i)/32))
+		h.Reset()
+		_, _ = h.Write(r.seg[:])
+		copy(r.buf[i:], h.Sum(nil))
+	}
+}
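
Note: a quick sketch, not part of the patch, of the reader's intended use: stream a large deterministic payload, then verify it from the seed alone. The 1 MiB size is arbitrary.

package pseudorand_test

import (
	"bytes"
	"fmt"
	"io"

	"github.com/ethersphere/bee/pkg/util/testutil/pseudorand"
)

func roundTrip() error {
	seed, err := pseudorand.NewSeed()
	if err != nil {
		return err
	}
	// Stream 1 MiB of deterministic data without keeping a reference copy.
	var sink bytes.Buffer
	if _, err := io.Copy(&sink, pseudorand.NewReader(seed, 1<<20)); err != nil {
		return err
	}
	// A fresh reader over the same seed reproduces the exact same bytes.
	ok, err := pseudorand.NewReader(seed, 1<<20).Equal(&sink)
	if err != nil {
		return err
	}
	fmt.Println("contents equal:", ok) // expected: true
	return nil
}
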
expected", len(content), "got", len(read)) + } + if !bytes.Equal(content, read) { + t.Fatal("content mismatch") + } + }) + t.Run("comparison", func(t *testing.T) { + ns, err := r.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + if ns != 0 { + t.Fatal("seek mismatch") + } + if eq, err := r.Equal(bytes.NewBuffer(content)); err != nil { + t.Fatal(err) + } else if !eq { + t.Fatal("content mismatch") + } + ns, err = r.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + if ns != 0 { + t.Fatal("seek mismatch") + } + if eq, err := r.Equal(bytes.NewBuffer(content[:len(content)-1])); err != nil { + t.Fatal(err) + } else if eq { + t.Fatal("content should match") + } + }) + t.Run("seek and match", func(t *testing.T) { + for i := 0; i < 20; i++ { + off := rand.Intn(size) + n := rand.Intn(size - off) + t.Run(fmt.Sprintf("off=%d n=%d", off, n), func(t *testing.T) { + ns, err := r.Seek(int64(off), io.SeekStart) + if err != nil { + t.Fatal(err) + } + if ns != int64(off) { + t.Fatal("seek mismatch") + } + ok, err := r.Match(bytes.NewBuffer(content[off:off+n]), n) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("content mismatch") + } + }) + } + }) +} diff --git a/pkg/util/testutil/racedetection/off.go b/pkg/util/testutil/racedetection/off.go new file mode 100644 index 00000000000..d57125bfd03 --- /dev/null +++ b/pkg/util/testutil/racedetection/off.go @@ -0,0 +1,10 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !race +// +build !race + +package racedetection + +const On = false diff --git a/pkg/util/testutil/racedetection/on.go b/pkg/util/testutil/racedetection/on.go new file mode 100644 index 00000000000..92438ecfdc3 --- /dev/null +++ b/pkg/util/testutil/racedetection/on.go @@ -0,0 +1,10 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build race +// +build race + +package racedetection + +const On = true diff --git a/pkg/util/testutil/racedetection/race.go b/pkg/util/testutil/racedetection/race.go new file mode 100644 index 00000000000..411295e87f6 --- /dev/null +++ b/pkg/util/testutil/racedetection/race.go @@ -0,0 +1,9 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package racedetection + +func IsOn() bool { + return On +}