diff --git a/cmd/bee/cmd/split.go b/cmd/bee/cmd/split.go index 6eba46c69c2..e173ae186fa 100644 --- a/cmd/bee/cmd/split.go +++ b/cmd/bee/cmd/split.go @@ -33,7 +33,7 @@ type pipelineFunc func(context.Context, io.Reader) (swarm.Address, error) func requestPipelineFn(s storage.Putter, encrypt bool) pipelineFunc { return func(ctx context.Context, r io.Reader) (swarm.Address, error) { - pipe := builder.NewPipelineBuilder(ctx, s, encrypt) + pipe := builder.NewPipelineBuilder(ctx, s, encrypt, 0) return builder.FeedPipeline(ctx, pipe, r) } } diff --git a/go.mod b/go.mod index 5f6f2d5a5cd..5d3e61c85ec 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/hashicorp/golang-lru/v2 v2.0.5 github.com/ipfs/go-cid v0.4.1 github.com/kardianos/service v1.2.0 + github.com/klauspost/reedsolomon v1.11.8 github.com/libp2p/go-libp2p v0.30.0 github.com/multiformats/go-multiaddr v0.11.0 github.com/multiformats/go-multiaddr-dns v0.3.1 @@ -105,7 +106,7 @@ require ( github.com/jackpal/go-nat-pmp v1.0.2 // indirect github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect github.com/klauspost/compress v1.16.7 // indirect - github.com/klauspost/cpuid/v2 v2.2.5 // indirect + github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/koron/go-ssdp v0.0.4 // indirect github.com/leodido/go-urn v1.2.1 // indirect github.com/libp2p/go-buffer-pool v0.1.0 // indirect diff --git a/go.sum b/go.sum index 0b58fffeeff..c2938b6ec32 100644 --- a/go.sum +++ b/go.sum @@ -522,10 +522,12 @@ github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQs github.com/klauspost/cpuid v0.0.0-20170728055534-ae7887de9fa5/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/klauspost/cpuid/v2 v2.0.4/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.6/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= -github.com/klauspost/cpuid/v2 v2.2.5/go.mod 
h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= +github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6/go.mod h1:+ZoRqAPRLkC4NPOvfYeR5KNOrY6TD+/sAC3HXPZgDYg= github.com/klauspost/pgzip v1.0.2-0.20170402124221-0bf5dcad4ada/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= +github.com/klauspost/reedsolomon v1.11.8 h1:s8RpUW5TK4hjr+djiOpbZJB4ksx+TdYbRH7vHQpwPOY= +github.com/klauspost/reedsolomon v1.11.8/go.mod h1:4bXRN+cVzMdml6ti7qLouuYi32KHJ5MGv0Qd8a47h6A= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/koron/go-ssdp v0.0.4 h1:1IDwrghSKYM7yLf7XCzbByg2sJ/JcNOZRXS2jczTwz0= diff --git a/openapi/Swarm.yaml b/openapi/Swarm.yaml index accfe44ae2e..61b446d1095 100644 --- a/openapi/Swarm.yaml +++ b/openapi/Swarm.yaml @@ -1,7 +1,7 @@ openapi: 3.0.3 info: - version: 5.1.1 + version: 5.2.0 title: Bee API description: "A list of the currently provided Interfaces to interact with the swarm, implementing file operations and sending messages" @@ -120,6 +120,11 @@ paths: $ref: "SwarmCommon.yaml#/components/parameters/SwarmEncryptParameter" name: swarm-encrypt required: false + - in: header + schema: + $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyLevelParameter" + name: swarm-redundancy-level + required: false requestBody: content: @@ -158,11 +163,10 @@ paths: $ref: "SwarmCommon.yaml#/components/schemas/SwarmReference" required: true description: Swarm address reference to content - - in: header - schema: - $ref: "SwarmCommon.yaml#/components/parameters/SwarmCache" - name: swarm-cache - required: false + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmCache" + - $ref: 
"SwarmCommon.yaml#/components/parameters/SwarmRedundancyStrategyParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyFallbackModeParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmChunkRetrievalTimeoutParameter" responses: "200": description: Retrieved content specified by reference @@ -254,6 +258,7 @@ paths: - $ref: "SwarmCommon.yaml#/components/parameters/SwarmErrorDocumentParameter" - $ref: "SwarmCommon.yaml#/components/parameters/SwarmPostageBatchId" - $ref: "SwarmCommon.yaml#/components/parameters/SwarmDeferredUpload" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyLevelParameter" requestBody: content: multipart/form-data: @@ -305,11 +310,10 @@ paths: $ref: "SwarmCommon.yaml#/components/schemas/SwarmReference" required: true description: Swarm address of content - - in: header - schema: - $ref: "SwarmCommon.yaml#/components/parameters/SwarmCache" - name: swarm-cache - required: false + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmCache" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyStrategyParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyFallbackModeParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmChunkRetrievalTimeoutParameter" responses: "200": description: Ok @@ -347,6 +351,9 @@ paths: type: string required: true description: Path to the file in the collection. 
+ - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyStrategyParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmRedundancyFallbackModeParameter" + - $ref: "SwarmCommon.yaml#/components/parameters/SwarmChunkRetrievalTimeoutParameter" responses: "200": description: Ok diff --git a/openapi/SwarmCommon.yaml b/openapi/SwarmCommon.yaml index a3e46342eea..07119cc6894 100644 --- a/openapi/SwarmCommon.yaml +++ b/openapi/SwarmCommon.yaml @@ -934,6 +934,52 @@ components: description: > Represents the encrypting state of the file + SwarmRedundancyLevelParameter: + in: header + name: swarm-redundancy-level + schema: + type: integer + enum: [0, 1, 2, 3, 4] + required: false + description: > + Add redundancy to the data being uploaded so that downloaders can download it with better UX. + 0 value is default and does not add any redundancy to the file. + + SwarmRedundancyStrategyParameter: + in: header + name: swarm-redundancy-strategy + schema: + type: integer + enum: [0, 1, 2, 3] + required: false + description: > + Specify the retrieve strategy on redundant data. + The numbers stand for NONE, DATA, PROX and RACE, respectively. + Strategy NONE means no prefetching takes place. + Strategy DATA means only data chunks are prefetched. + Strategy PROX means only chunks that are close to the node are prefetched. + Strategy RACE means all chunks are prefetched: n data chunks and k parity chunks. The first n chunks to arrive are used to reconstruct the file. + Multiple strategies can be used in a fallback cascade if the swarm redundancy fallback mode is set to true. + The default strategy is NONE, DATA, falling back to PROX, falling back to RACE + + SwarmRedundancyFallbackModeParameter: + in: header + name: swarm-redundancy-fallback-mode + schema: + type: boolean + required: false + description: > + Specify if the retrieve strategies (chunk prefetching on redundant data) are used in a fallback cascade. The default is true. 
+ + SwarmChunkRetrievalTimeoutParameter: + in: header + name: swarm-chunk-retrieval-timeout + schema: + $ref: "#/components/schemas/Duration" + required: false + description: > + Specify the timeout for chunk retrieval. The default is 30 seconds. + ContentTypePreserved: in: header name: Content-Type diff --git a/pkg/api/api.go b/pkg/api/api.go index bd9fdbac2da..64c5712edf8 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -33,6 +33,7 @@ import ( "github.com/ethersphere/bee/pkg/feeds" "github.com/ethersphere/bee/pkg/file/pipeline" "github.com/ethersphere/bee/pkg/file/pipeline/builder" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/p2p" @@ -69,16 +70,21 @@ import ( const loggerName = "api" const ( - SwarmPinHeader = "Swarm-Pin" - SwarmTagHeader = "Swarm-Tag" - SwarmEncryptHeader = "Swarm-Encrypt" - SwarmIndexDocumentHeader = "Swarm-Index-Document" - SwarmErrorDocumentHeader = "Swarm-Error-Document" - SwarmFeedIndexHeader = "Swarm-Feed-Index" - SwarmFeedIndexNextHeader = "Swarm-Feed-Index-Next" - SwarmCollectionHeader = "Swarm-Collection" - SwarmPostageBatchIdHeader = "Swarm-Postage-Batch-Id" - SwarmDeferredUploadHeader = "Swarm-Deferred-Upload" + SwarmPinHeader = "Swarm-Pin" + SwarmTagHeader = "Swarm-Tag" + SwarmEncryptHeader = "Swarm-Encrypt" + SwarmIndexDocumentHeader = "Swarm-Index-Document" + SwarmErrorDocumentHeader = "Swarm-Error-Document" + SwarmFeedIndexHeader = "Swarm-Feed-Index" + SwarmFeedIndexNextHeader = "Swarm-Feed-Index-Next" + SwarmCollectionHeader = "Swarm-Collection" + SwarmPostageBatchIdHeader = "Swarm-Postage-Batch-Id" + SwarmDeferredUploadHeader = "Swarm-Deferred-Upload" + SwarmRedundancyLevelHeader = "Swarm-Redundancy-Level" + SwarmRedundancyStrategyHeader = "Swarm-Redundancy-Strategy" + SwarmRedundancyFallbackModeHeader = "Swarm-Redundancy-Fallback-Mode" + SwarmChunkRetrievalTimeoutHeader = "Swarm-Chunk-Retrieval-Timeout" + 
SwarmLookAheadBufferSizeHeader = "Swarm-Lookahead-Buffer-Size" ImmutableHeader = "Immutable" GasPriceHeader = "Gas-Price" @@ -94,18 +100,6 @@ const ( OriginHeader = "Origin" ) -// The size of buffer used for prefetching content with Langos. -// Warning: This value influences the number of chunk requests and chunker join goroutines -// per file request. -// Recommended value is 8 or 16 times the io.Copy default buffer value which is 32kB, depending -// on the file size. Use lookaheadBufferSize() to get the correct buffer size for the request. -const ( - smallFileBufferSize = 8 * 32 * 1024 - largeFileBufferSize = 16 * 32 * 1024 - - largeBufferFilesizeThreshold = 10 * 1000000 // ten megs -) - const ( multiPartFormData = "multipart/form-data" contentTypeTar = "application/x-tar" @@ -610,20 +604,12 @@ func (s *Service) gasConfigMiddleware(handlerName string) func(h http.Handler) h } } -func lookaheadBufferSize(size int64) int { - if size <= largeBufferFilesizeThreshold { - return smallFileBufferSize - } - return largeFileBufferSize -} - // corsHandler sets CORS headers to HTTP response if allowed origins are configured. 
func (s *Service) corsHandler(h http.Handler) http.Handler { allowedHeaders := []string{ "User-Agent", "Accept", "X-Requested-With", "Access-Control-Request-Headers", "Access-Control-Request-Method", "Accept-Ranges", "Content-Encoding", AuthorizationHeader, AcceptEncodingHeader, ContentTypeHeader, ContentDispositionHeader, RangeHeader, OriginHeader, - SwarmTagHeader, SwarmPinHeader, SwarmEncryptHeader, SwarmIndexDocumentHeader, SwarmErrorDocumentHeader, SwarmCollectionHeader, SwarmPostageBatchIdHeader, SwarmDeferredUploadHeader, - GasPriceHeader, GasLimitHeader, ImmutableHeader, + SwarmTagHeader, SwarmPinHeader, SwarmEncryptHeader, SwarmIndexDocumentHeader, SwarmErrorDocumentHeader, SwarmCollectionHeader, SwarmPostageBatchIdHeader, SwarmDeferredUploadHeader, SwarmRedundancyLevelHeader, SwarmRedundancyStrategyHeader, SwarmRedundancyFallbackModeHeader, SwarmChunkRetrievalTimeoutHeader, SwarmFeedIndexHeader, SwarmFeedIndexNextHeader, GasPriceHeader, GasLimitHeader, ImmutableHeader, } allowedHeadersStr := strings.Join(allowedHeaders, ", ") @@ -848,16 +834,16 @@ func (s *Service) newStamperPutter(ctx context.Context, opts putterOptions) (sto type pipelineFunc func(context.Context, io.Reader) (swarm.Address, error) -func requestPipelineFn(s storage.Putter, encrypt bool) pipelineFunc { +func requestPipelineFn(s storage.Putter, encrypt bool, rLevel redundancy.Level) pipelineFunc { return func(ctx context.Context, r io.Reader) (swarm.Address, error) { - pipe := builder.NewPipelineBuilder(ctx, s, encrypt) + pipe := builder.NewPipelineBuilder(ctx, s, encrypt, rLevel) return builder.FeedPipeline(ctx, pipe, r) } } -func requestPipelineFactory(ctx context.Context, s storage.Putter, encrypt bool) func() pipeline.Interface { +func requestPipelineFactory(ctx context.Context, s storage.Putter, encrypt bool, rLevel redundancy.Level) func() pipeline.Interface { return func() pipeline.Interface { - return builder.NewPipelineBuilder(ctx, s, encrypt) + return 
builder.NewPipelineBuilder(ctx, s, encrypt, rLevel) } } diff --git a/pkg/api/api_test.go b/pkg/api/api_test.go index a31bb87ff5d..b06e5808479 100644 --- a/pkg/api/api_test.go +++ b/pkg/api/api_test.go @@ -31,6 +31,7 @@ import ( "github.com/ethersphere/bee/pkg/feeds" "github.com/ethersphere/bee/pkg/file/pipeline" "github.com/ethersphere/bee/pkg/file/pipeline/builder" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/log" p2pmock "github.com/ethersphere/bee/pkg/p2p/mock" @@ -321,9 +322,9 @@ func request(t *testing.T, client *http.Client, method, resource string, body io return resp } -func pipelineFactory(s storage.Putter, encrypt bool) func() pipeline.Interface { +func pipelineFactory(s storage.Putter, encrypt bool, rLevel redundancy.Level) func() pipeline.Interface { return func() pipeline.Interface { - return builder.NewPipelineBuilder(context.Background(), s, encrypt) + return builder.NewPipelineBuilder(context.Background(), s, encrypt, rLevel) } } diff --git a/pkg/api/bytes.go b/pkg/api/bytes.go index 46bd0e491f0..4cb90c1b763 100644 --- a/pkg/api/bytes.go +++ b/pkg/api/bytes.go @@ -12,6 +12,7 @@ import ( "strconv" "github.com/ethersphere/bee/pkg/cac" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/postage" storage "github.com/ethersphere/bee/pkg/storage" @@ -33,11 +34,12 @@ func (s *Service) bytesUploadHandler(w http.ResponseWriter, r *http.Request) { defer span.Finish() headers := struct { - BatchID []byte `map:"Swarm-Postage-Batch-Id" validate:"required"` - SwarmTag uint64 `map:"Swarm-Tag"` - Pin bool `map:"Swarm-Pin"` - Deferred *bool `map:"Swarm-Deferred-Upload"` - Encrypt bool `map:"Swarm-Encrypt"` + BatchID []byte `map:"Swarm-Postage-Batch-Id" validate:"required"` + SwarmTag uint64 `map:"Swarm-Tag"` + Pin bool `map:"Swarm-Pin"` + Deferred *bool `map:"Swarm-Deferred-Upload"` + Encrypt bool 
`map:"Swarm-Encrypt"` + RLevel redundancy.Level `map:"Swarm-Redundancy-Level"` }{} if response := s.mapStructure(r.Header, &headers); response != nil { response("invalid header params", logger, w) @@ -98,7 +100,7 @@ func (s *Service) bytesUploadHandler(w http.ResponseWriter, r *http.Request) { logger: logger, } - p := requestPipelineFn(putter, headers.Encrypt) + p := requestPipelineFn(putter, headers.Encrypt, headers.RLevel) address, err := p(ctx, r.Body) if err != nil { logger.Debug("split write all failed", "error", err) diff --git a/pkg/api/bzz.go b/pkg/api/bzz.go index 3089d3fa6c0..6d96b3e83bc 100644 --- a/pkg/api/bzz.go +++ b/pkg/api/bzz.go @@ -16,7 +16,6 @@ import ( "strings" "time" - "github.com/ethersphere/bee/pkg/topology" "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" olog "github.com/opentracing/opentracing-go/log" @@ -25,6 +24,8 @@ import ( "github.com/ethersphere/bee/pkg/feeds" "github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/loadsave" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/file/redundancy/getter" "github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/manifest" @@ -32,23 +33,44 @@ import ( storage "github.com/ethersphere/bee/pkg/storage" storer "github.com/ethersphere/bee/pkg/storer" "github.com/ethersphere/bee/pkg/swarm" + "github.com/ethersphere/bee/pkg/topology" "github.com/ethersphere/bee/pkg/tracing" "github.com/ethersphere/langos" "github.com/gorilla/mux" ) +// The size of buffer used for prefetching content with Langos when not using erasure coding +// Warning: This value influences the number of chunk requests and chunker join goroutines +// per file request. +// Recommended value is 8 or 16 times the io.Copy default buffer value which is 32kB, depending +// on the file size. Use lookaheadBufferSize() to get the correct buffer size for the request. 
+const ( + smallFileBufferSize = 8 * 32 * 1024 + largeFileBufferSize = 16 * 32 * 1024 + + largeBufferFilesizeThreshold = 10 * 1000000 // ten megs +) + +func lookaheadBufferSize(size int64) int { + if size <= largeBufferFilesizeThreshold { + return smallFileBufferSize + } + return largeFileBufferSize +} + func (s *Service) bzzUploadHandler(w http.ResponseWriter, r *http.Request) { span, logger, ctx := s.tracer.StartSpanFromContext(r.Context(), "post_bzz", s.logger.WithName("post_bzz").Build()) defer span.Finish() headers := struct { - ContentType string `map:"Content-Type,mimeMediaType" validate:"required"` - BatchID []byte `map:"Swarm-Postage-Batch-Id" validate:"required"` - SwarmTag uint64 `map:"Swarm-Tag"` - Pin bool `map:"Swarm-Pin"` - Deferred *bool `map:"Swarm-Deferred-Upload"` - Encrypt bool `map:"Swarm-Encrypt"` - IsDir bool `map:"Swarm-Collection"` + ContentType string `map:"Content-Type,mimeMediaType" validate:"required"` + BatchID []byte `map:"Swarm-Postage-Batch-Id" validate:"required"` + SwarmTag uint64 `map:"Swarm-Tag"` + Pin bool `map:"Swarm-Pin"` + Deferred *bool `map:"Swarm-Deferred-Upload"` + Encrypt bool `map:"Swarm-Encrypt"` + IsDir bool `map:"Swarm-Collection"` + RLevel redundancy.Level `map:"Swarm-Redundancy-Level"` }{} if response := s.mapStructure(r.Header, &headers); response != nil { response("invalid header params", logger, w) @@ -110,10 +132,10 @@ func (s *Service) bzzUploadHandler(w http.ResponseWriter, r *http.Request) { } if headers.IsDir || headers.ContentType == multiPartFormData { - s.dirUploadHandler(ctx, logger, span, ow, r, putter, r.Header.Get(ContentTypeHeader), headers.Encrypt, tag) + s.dirUploadHandler(ctx, logger, span, ow, r, putter, r.Header.Get(ContentTypeHeader), headers.Encrypt, tag, headers.RLevel) return } - s.fileUploadHandler(ctx, logger, span, ow, r, putter, headers.Encrypt, tag) + s.fileUploadHandler(ctx, logger, span, ow, r, putter, headers.Encrypt, tag, headers.RLevel) } // fileUploadResponse is returned when an 
HTTP request to upload a file is successful @@ -132,6 +154,7 @@ func (s *Service) fileUploadHandler( putter storer.PutterSession, encrypt bool, tagID uint64, + rLevel redundancy.Level, ) { queries := struct { FileName string `map:"name" validate:"startsnotwith=/"` @@ -141,7 +164,7 @@ func (s *Service) fileUploadHandler( return } - p := requestPipelineFn(putter, encrypt) + p := requestPipelineFn(putter, encrypt, rLevel) // first store the file and get its reference fr, err := p(ctx, r.Body) @@ -181,8 +204,8 @@ func (s *Service) fileUploadHandler( } } - factory := requestPipelineFactory(ctx, putter, encrypt) - l := loadsave.New(s.storer.ChunkStore(), factory) + factory := requestPipelineFactory(ctx, putter, encrypt, rLevel) + l := loadsave.New(s.storer.ChunkStore(), s.storer.Cache(), factory) m, err := manifest.NewDefaultManifest(l, encrypt) if err != nil { @@ -283,8 +306,13 @@ func (s *Service) serveReference(logger log.Logger, address swarm.Address, pathV loggerV1 := logger.V(1).Build() headers := struct { - Cache *bool `map:"Swarm-Cache"` + Cache *bool `map:"Swarm-Cache"` + Strategy getter.Strategy `map:"Swarm-Redundancy-Strategy"` + FallbackMode bool `map:"Swarm-Redundancy-Fallback-Mode"` + ChunkRetrievalTimeout string `map:"Swarm-Chunk-Retrieval-Timeout"` + LookaheadBufferSize *string `map:"Swarm-Lookahead-Buffer-Size"` }{} + if response := s.mapStructure(r.Header, &headers); response != nil { response("invalid header params", logger, w) return @@ -293,10 +321,12 @@ func (s *Service) serveReference(logger log.Logger, address swarm.Address, pathV if headers.Cache != nil { cache = *headers.Cache } + ls := loadsave.NewReadonly(s.storer.Download(cache)) feedDereferenced := false ctx := r.Context() + ctx = getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, getter.DefaultStrategyTimeout.String()) FETCH: // read manifest entry @@ -377,7 +407,6 @@ FETCH: jsonhttp.NotFound(w, "address not found or incorrect") return } - 
me, err := m.Lookup(ctx, pathVar) if err != nil { loggerV1.Debug("bzz download: invalid path", "address", address, "path", pathVar, "error", err) @@ -467,8 +496,13 @@ func (s *Service) serveManifestEntry( // downloadHandler contains common logic for dowloading Swarm file from API func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *http.Request, reference swarm.Address, additionalHeaders http.Header, etag bool) { headers := struct { - Cache *bool `map:"Swarm-Cache"` + Cache *bool `map:"Swarm-Cache"` + Strategy getter.Strategy `map:"Swarm-Redundancy-Strategy"` + FallbackMode bool `map:"Swarm-Redundancy-Fallback-Mode"` + ChunkRetrievalTimeout string `map:"Swarm-Chunk-Retrieval-Timeout"` + LookaheadBufferSize *string `map:"Swarm-Lookahead-Buffer-Size"` }{} + if response := s.mapStructure(r.Header, &headers); response != nil { response("invalid header params", logger, w) return @@ -477,7 +511,10 @@ func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *h if headers.Cache != nil { cache = *headers.Cache } - reader, l, err := joiner.New(r.Context(), s.storer.Download(cache), reference) + + ctx := r.Context() + ctx = getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, getter.DefaultStrategyTimeout.String()) + reader, l, err := joiner.New(ctx, s.storer.Download(cache), s.storer.Cache(), reference) if err != nil { if errors.Is(err, storage.ErrNotFound) || errors.Is(err, topology.ErrNotFound) { logger.Debug("api download: not found ", "address", reference, "error", err) @@ -500,7 +537,19 @@ func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *h } w.Header().Set(ContentLengthHeader, strconv.FormatInt(l, 10)) w.Header().Set("Access-Control-Expose-Headers", ContentDispositionHeader) - http.ServeContent(w, r, "", time.Now(), langos.NewBufferedLangos(reader, lookaheadBufferSize(l))) + bufSize := int64(lookaheadBufferSize(l)) + if 
headers.LookaheadBufferSize != nil { + bufSize, err = strconv.ParseInt(*headers.LookaheadBufferSize, 10, 64) + if err != nil { + logger.Debug("parsing lookahead buffer size", "error", err) + bufSize = 0 + } + } + if bufSize > 0 { + http.ServeContent(w, r, "", time.Now(), langos.NewBufferedLangos(reader, int(bufSize))) + return + } + http.ServeContent(w, r, "", time.Now(), reader) } // manifestMetadataLoad returns the value for a key stored in the metadata of diff --git a/pkg/api/bzz_test.go b/pkg/api/bzz_test.go index 4e54a967d3a..634d4eb0166 100644 --- a/pkg/api/bzz_test.go +++ b/pkg/api/bzz_test.go @@ -16,20 +16,223 @@ import ( "strconv" "strings" "testing" + "time" "github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/file/loadsave" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/manifest" mockbatchstore "github.com/ethersphere/bee/pkg/postage/batchstore/mock" mockpost "github.com/ethersphere/bee/pkg/postage/mock" + "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" mockstorer "github.com/ethersphere/bee/pkg/storer/mock" "github.com/ethersphere/bee/pkg/swarm" + "github.com/ethersphere/bee/pkg/util/testutil/pseudorand" ) -// nolint:paralleltest,tparallel +// nolint:paralleltest,tparallel,thelper + +// TestBzzUploadDownloadWithRedundancy tests the API for upload and download files +// with all combinations of redundancy level, encryption and size (levels, i.e., the +// +// height of the swarm hash tree). +// +// This is a variation on the same play as TestJoinerRedundancy +// but here the tested scenario is simplified since we are not testing the intricacies of +// download strategies, but only correct parameter passing and correct recovery functionality +// +// The test cases have the following structure: +// +// 1. 
upload a file with a given redundancy level and encryption +// +// 2. [positive test] download the file by the reference returned by the upload API response +// This uses range queries to target specific (number of) chunks of the file structure. +// During path traversal in the swarm hash tree, the underlying mocksore (forgetting) +// is in 'recording' mode, flagging all the retrieved chunks as chunks to forget. +// This is to simulate the scenario where some of the chunks are not available/lost +// NOTE: For this to work one needs to switch off lookaheadbuffer functionality +// (see langos pkg) +// +// 3. [negative test] attempt at downloading the file using once again the same root hash +// and the same redundancy strategy to find the file inaccessible after forgetting. +// +// 4. [positive test] attempt at downloading the file using a strategy that allows for +// using redundancy to reconstruct the file and find the file recoverable. +// +// nolint:thelper +func TestBzzUploadDownloadWithRedundancy(t *testing.T) { + fileUploadResource := "/bzz" + fileDownloadResource := func(addr string) string { return "/bzz/" + addr + "/" } + + testRedundancy := func(t *testing.T, rLevel redundancy.Level, encrypt bool, levels int, chunkCnt int, shardCnt int, parityCnt int) { + t.Helper() + seed, err := pseudorand.NewSeed() + if err != nil { + t.Fatal(err) + } + fetchTimeout := 100 * time.Millisecond + store := mockstorer.NewForgettingStore(inmemchunkstore.New()) + storerMock := mockstorer.NewWithChunkStore(store) + client, _, _, _ := newTestServer(t, testServerOptions{ + Storer: storerMock, + Logger: log.Noop, + Post: mockpost.New(mockpost.WithAcceptAll()), + }) + + dataReader := pseudorand.NewReader(seed, chunkCnt*swarm.ChunkSize) + + var refResponse api.BzzUploadResponse + jsonhttptest.Request(t, client, http.MethodPost, fileUploadResource, + http.StatusCreated, + jsonhttptest.WithRequestHeader(api.SwarmDeferredUploadHeader, "True"), + 
jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, batchOkStr), + jsonhttptest.WithRequestBody(dataReader), + jsonhttptest.WithRequestHeader(api.SwarmEncryptHeader, fmt.Sprintf("%t", encrypt)), + jsonhttptest.WithRequestHeader(api.SwarmRedundancyLevelHeader, fmt.Sprintf("%d", rLevel)), + jsonhttptest.WithRequestHeader(api.ContentTypeHeader, "image/jpeg; charset=utf-8"), + jsonhttptest.WithUnmarshalJSONResponse(&refResponse), + ) + + t.Run("download multiple ranges without redundancy should succeed", func(t *testing.T) { + // the underlying chunk store is in recording mode, so all chunks retrieved + // in this test will be forgotten in the subsequent ones. + store.Record() + defer store.Unrecord() + // we intend to forget as many chunks as possible for the given redundancy level + forget := parityCnt + if parityCnt > shardCnt { + forget = shardCnt + } + if levels == 1 { + forget = 2 + } + start, end := 420, 450 + gap := swarm.ChunkSize + for j := 2; j < levels; j++ { + gap *= shardCnt + } + ranges := make([][2]int, forget) + for i := 0; i < forget; i++ { + pre := i * gap + ranges[i] = [2]int{pre + start, pre + end} + } + rangeHeader, want := createRangeHeader(dataReader, ranges) + + var body []byte + respHeaders := jsonhttptest.Request(t, client, http.MethodGet, + fileDownloadResource(refResponse.Reference.String()), + http.StatusPartialContent, + jsonhttptest.WithRequestHeader(api.RangeHeader, rangeHeader), + jsonhttptest.WithRequestHeader(api.SwarmLookAheadBufferSizeHeader, "0"), + // set for the replicas so that no replica gets deleted + jsonhttptest.WithRequestHeader(api.SwarmRedundancyLevelHeader, "0"), + jsonhttptest.WithRequestHeader(api.SwarmRedundancyStrategyHeader, "0"), + jsonhttptest.WithRequestHeader(api.SwarmRedundancyFallbackModeHeader, "false"), + jsonhttptest.WithRequestHeader(api.SwarmChunkRetrievalTimeoutHeader, fetchTimeout.String()), + jsonhttptest.WithPutResponseBody(&body), + ) + + got := parseRangeParts(t, 
respHeaders.Get(api.ContentTypeHeader), body) + + if len(got) != len(want) { + t.Fatalf("got %v parts, want %v parts", len(got), len(want)) + } + for i := 0; i < len(want); i++ { + if !bytes.Equal(got[i], want[i]) { + t.Errorf("part %v: got %q, want %q", i, string(got[i]), string(want[i])) + } + } + }) + + t.Run("download without redundancy should NOT succeed", func(t *testing.T) { + if rLevel == 0 { + t.Skip("NA") + } + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, "GET", fileDownloadResource(refResponse.Reference.String()), nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set(api.SwarmRedundancyStrategyHeader, "0") + req.Header.Set(api.SwarmRedundancyFallbackModeHeader, "false") + req.Header.Set(api.SwarmChunkRetrievalTimeoutHeader, fetchTimeout.String()) + + _, err = client.Do(req) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected error %v; got %v", io.ErrUnexpectedEOF, err) + } + }) + + t.Run("download with redundancy should succeed", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.TODO(), "GET", fileDownloadResource(refResponse.Reference.String()), nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set(api.SwarmRedundancyStrategyHeader, "3") + req.Header.Set(api.SwarmRedundancyFallbackModeHeader, "true") + req.Header.Set(api.SwarmChunkRetrievalTimeoutHeader, fetchTimeout.String()) + + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status %d; got %d", http.StatusOK, resp.StatusCode) + } + _, err = dataReader.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + ok, err := dataReader.Equal(resp.Body) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatalf("content mismatch") + } + }) + } + for _, rLevel := range []redundancy.Level{1, 2, 3, 4} { + rLevel := rLevel + t.Run(fmt.Sprintf("level=%d", rLevel), 
func(t *testing.T) { + for _, encrypt := range []bool{false, true} { + encrypt := encrypt + shardCnt := rLevel.GetMaxShards() + parityCnt := rLevel.GetParities(shardCnt) + if encrypt { + shardCnt = rLevel.GetMaxEncShards() + parityCnt = rLevel.GetEncParities(shardCnt) + } + for _, levels := range []int{1, 2, 3} { + chunkCnt := 1 + switch levels { + case 1: + chunkCnt = 2 + case 2: + chunkCnt = shardCnt + 1 + case 3: + chunkCnt = shardCnt*shardCnt + 1 + } + levels := levels + t.Run(fmt.Sprintf("encrypt=%v levels=%d chunks=%d", encrypt, levels, chunkCnt), func(t *testing.T) { + if levels > 2 && (encrypt == (rLevel%2 == 1)) { + t.Skip("skipping to save time") + } + t.Parallel() + testRedundancy(t, rLevel, encrypt, levels, chunkCnt, shardCnt, parityCnt) + }) + } + } + }) + } +} + func TestBzzFiles(t *testing.T) { t.Parallel() @@ -187,6 +390,28 @@ func TestBzzFiles(t *testing.T) { ) }) + t.Run("redundancy", func(t *testing.T) { + fileName := "my-pictures.jpeg" + + var resp api.BzzUploadResponse + jsonhttptest.Request(t, client, http.MethodPost, fileUploadResource+"?name="+fileName, http.StatusCreated, + jsonhttptest.WithRequestHeader(api.SwarmDeferredUploadHeader, "true"), + jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, batchOkStr), + jsonhttptest.WithRequestBody(bytes.NewReader(simpleData)), + jsonhttptest.WithRequestHeader(api.SwarmEncryptHeader, "True"), + jsonhttptest.WithRequestHeader(api.SwarmRedundancyLevelHeader, "4"), + jsonhttptest.WithRequestHeader(api.ContentTypeHeader, "image/jpeg; charset=utf-8"), + jsonhttptest.WithUnmarshalJSONResponse(&resp), + ) + + jsonhttptest.Request(t, client, http.MethodGet, fileDownloadResource(resp.Reference.String()), http.StatusOK, + jsonhttptest.WithExpectedContentLength(len(simpleData)), + jsonhttptest.WithExpectedResponseHeader(api.ContentTypeHeader, "image/jpeg; charset=utf-8"), + jsonhttptest.WithExpectedResponseHeader(api.ContentDispositionHeader, fmt.Sprintf(`inline; filename="%s"`, fileName)), + 
jsonhttptest.WithExpectedResponse(simpleData), + ) + }) + t.Run("filter out filename path", func(t *testing.T) { fileName := "my-pictures.jpeg" fileNameWithPath := "../../" + fileName @@ -437,35 +662,57 @@ func TestBzzFilesRangeRequests(t *testing.T) { } } -func createRangeHeader(data []byte, ranges [][2]int) (header string, parts [][]byte) { - header = "bytes=" - for i, r := range ranges { - if i > 0 { - header += ", " +func createRangeHeader(data interface{}, ranges [][2]int) (header string, parts [][]byte) { + getLen := func() int { + switch data := data.(type) { + case []byte: + return len(data) + case interface{ Size() int }: + return data.Size() + default: + panic("unknown data type") } - if r[0] >= 0 && r[1] >= 0 { - parts = append(parts, data[r[0]:r[1]]) - // Range: =-, end is inclusive - header += fmt.Sprintf("%v-%v", r[0], r[1]-1) - } else { - if r[0] >= 0 { - header += strconv.Itoa(r[0]) // Range: =- - parts = append(parts, data[r[0]:]) + } + getRange := func(start, end int) []byte { + switch data := data.(type) { + case []byte: + return data[start:end] + case io.ReadSeeker: + buf := make([]byte, end-start) + _, err := data.Seek(int64(start), io.SeekStart) + if err != nil { + panic(err) } - header += "-" - if r[1] >= 0 { - if r[0] >= 0 { - // Range: =-, end is inclusive - header += strconv.Itoa(r[1] - 1) - } else { - // Range: =-, the parameter is length - header += strconv.Itoa(r[1]) - } - parts = append(parts, data[:r[1]]) + _, err = io.ReadFull(data, buf) + if err != nil { + panic(err) } + return buf + default: + panic("unknown data type") + } + } + + rangeStrs := make([]string, len(ranges)) + for i, r := range ranges { + start, end := r[0], r[1] + switch { + case start < 0: + // Range: =-, the parameter is length + rangeStrs[i] = "-" + strconv.Itoa(end) + start = 0 + case r[1] < 0: + // Range: =- + rangeStrs[i] = strconv.Itoa(start) + "-" + end = getLen() + default: + // Range: =-, end is inclusive + rangeStrs[i] = fmt.Sprintf("%v-%v", start, end-1) 
} + parts = append(parts, getRange(start, end)) } - return + header = "bytes=" + strings.Join(rangeStrs, ", ") // nolint:staticcheck + return header, parts } func parseRangeParts(t *testing.T, contentType string, body []byte) (parts [][]byte) { @@ -554,7 +801,7 @@ func TestFeedIndirection(t *testing.T) { t.Fatal(err) } m, err := manifest.NewDefaultManifest( - loadsave.New(storer.ChunkStore(), pipelineFactory(storer.Cache(), false)), + loadsave.New(storer.ChunkStore(), storer.Cache(), pipelineFactory(storer.Cache(), false, 0)), false, ) if err != nil { diff --git a/pkg/api/debugstorage_test.go b/pkg/api/debugstorage_test.go index b83ddb5af99..72d8800b882 100644 --- a/pkg/api/debugstorage_test.go +++ b/pkg/api/debugstorage_test.go @@ -9,7 +9,7 @@ import ( "testing" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" - storer "github.com/ethersphere/bee/pkg/storer" + "github.com/ethersphere/bee/pkg/storer" mockstorer "github.com/ethersphere/bee/pkg/storer/mock" ) diff --git a/pkg/api/dirs.go b/pkg/api/dirs.go index e22c9cacf21..06aff9c59d6 100644 --- a/pkg/api/dirs.go +++ b/pkg/api/dirs.go @@ -19,6 +19,7 @@ import ( "strings" "github.com/ethersphere/bee/pkg/file/loadsave" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/manifest" @@ -45,6 +46,7 @@ func (s *Service) dirUploadHandler( contentTypeString string, encrypt bool, tag uint64, + rLevel redundancy.Level, ) { if r.Body == http.NoBody { logger.Error(nil, "request has no body") @@ -77,6 +79,7 @@ func (s *Service) dirUploadHandler( s.storer.ChunkStore(), r.Header.Get(SwarmIndexDocumentHeader), r.Header.Get(SwarmErrorDocumentHeader), + rLevel, ) if err != nil { logger.Debug("store dir failed", "error", err) @@ -125,13 +128,14 @@ func storeDir( getter storage.Getter, indexFilename, errorFilename string, + rLevel redundancy.Level, ) (swarm.Address, error) { logger := tracing.NewLoggerWithTraceID(ctx, log) 
loggerV1 := logger.V(1).Build() - p := requestPipelineFn(putter, encrypt) - ls := loadsave.New(getter, requestPipelineFactory(ctx, putter, encrypt)) + p := requestPipelineFn(putter, encrypt, rLevel) + ls := loadsave.New(getter, putter, requestPipelineFactory(ctx, putter, encrypt, rLevel)) dirManifest, err := manifest.NewDefaultManifest(ls, encrypt) if err != nil { diff --git a/pkg/api/feed.go b/pkg/api/feed.go index d85f89fe9b3..23961e6d5fc 100644 --- a/pkg/api/feed.go +++ b/pkg/api/feed.go @@ -196,7 +196,7 @@ func (s *Service) feedPostHandler(w http.ResponseWriter, r *http.Request) { logger: logger, } - l := loadsave.New(s.storer.ChunkStore(), requestPipelineFactory(r.Context(), putter, false)) + l := loadsave.New(s.storer.ChunkStore(), s.storer.Cache(), requestPipelineFactory(r.Context(), putter, false, 0)) feedManifest, err := manifest.NewDefaultManifest(l, false) if err != nil { logger.Debug("create manifest failed", "error", err) diff --git a/pkg/api/pin.go b/pkg/api/pin.go index a98af277583..7c80e5f196b 100644 --- a/pkg/api/pin.go +++ b/pkg/api/pin.go @@ -50,7 +50,7 @@ func (s *Service) pinRootHash(w http.ResponseWriter, r *http.Request) { } getter := s.storer.Download(true) - traverser := traversal.New(getter) + traverser := traversal.New(getter, s.storer.Cache()) sem := semaphore.NewWeighted(100) var errTraverse error diff --git a/pkg/encryption/store/decrypt_store.go b/pkg/encryption/store/decrypt_store.go index cf16b8ac3cd..7a9f0395e91 100644 --- a/pkg/encryption/store/decrypt_store.go +++ b/pkg/encryption/store/decrypt_store.go @@ -9,6 +9,8 @@ import ( "encoding/binary" "github.com/ethersphere/bee/pkg/encryption" + "github.com/ethersphere/bee/pkg/file" + "github.com/ethersphere/bee/pkg/file/redundancy" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" "golang.org/x/crypto/sha3" @@ -37,7 +39,7 @@ func (s *decryptingStore) Get(ctx context.Context, addr swarm.Address) (ch swarm return nil, err } - d, err := 
decryptChunkData(ch.Data(), ref[swarm.HashSize:]) + d, err := DecryptChunkData(ch.Data(), ref[swarm.HashSize:]) if err != nil { return nil, err } @@ -48,19 +50,19 @@ func (s *decryptingStore) Get(ctx context.Context, addr swarm.Address) (ch swarm } } -func decryptChunkData(chunkData []byte, encryptionKey encryption.Key) ([]byte, error) { +func DecryptChunkData(chunkData []byte, encryptionKey encryption.Key) ([]byte, error) { decryptedSpan, decryptedData, err := decrypt(chunkData, encryptionKey) if err != nil { return nil, err } // removing extra bytes which were just added for padding - length := binary.LittleEndian.Uint64(decryptedSpan) - refSize := int64(swarm.HashSize + encryption.KeyLength) - for length > swarm.ChunkSize { - length = length + (swarm.ChunkSize - 1) - length = length / swarm.ChunkSize - length *= uint64(refSize) + level, span := redundancy.DecodeSpan(decryptedSpan) + length := binary.LittleEndian.Uint64(span) + if length > swarm.ChunkSize { + dataRefSize := uint64(swarm.HashSize + encryption.KeyLength) + dataShards, parities := file.ReferenceCount(length, level, true) + length = dataRefSize*uint64(dataShards) + uint64(parities*swarm.HashSize) } c := make([]byte, length+8) diff --git a/pkg/file/addresses/addresses_getter_test.go b/pkg/file/addresses/addresses_getter_test.go index 542c2d8a442..a5fe9d00000 100644 --- a/pkg/file/addresses/addresses_getter_test.go +++ b/pkg/file/addresses/addresses_getter_test.go @@ -63,7 +63,7 @@ func TestAddressesGetterIterateChunkAddresses(t *testing.T) { addressesGetter := addresses.NewGetter(store, addressIterFunc) - j, _, err := joiner.New(ctx, addressesGetter, rootChunk.Address()) + j, _, err := joiner.New(ctx, addressesGetter, store, rootChunk.Address()) if err != nil { t.Fatal(err) } diff --git a/pkg/file/file_test.go b/pkg/file/file_test.go index 951b679d591..cc677336b64 100644 --- a/pkg/file/file_test.go +++ b/pkg/file/file_test.go @@ -48,7 +48,7 @@ func testSplitThenJoin(t *testing.T) { paramstring = 
strings.Split(t.Name(), "/") dataIdx, _ = strconv.ParseInt(paramstring[1], 10, 0) store = inmemchunkstore.New() - p = builder.NewPipelineBuilder(context.Background(), store, false) + p = builder.NewPipelineBuilder(context.Background(), store, false, 0) data, _ = test.GetVector(t, int(dataIdx)) ) @@ -62,7 +62,7 @@ func testSplitThenJoin(t *testing.T) { } // then join - r, l, err := joiner.New(ctx, store, resultAddress) + r, l, err := joiner.New(ctx, store, store, resultAddress) if err != nil { t.Fatal(err) } diff --git a/pkg/file/joiner/joiner.go b/pkg/file/joiner/joiner.go index b046c6bd0fd..dc9c850084c 100644 --- a/pkg/file/joiner/joiner.go +++ b/pkg/file/joiner/joiner.go @@ -7,51 +7,143 @@ package joiner import ( "context" - "encoding/binary" "errors" "io" "sync" "sync/atomic" + "github.com/ethersphere/bee/pkg/bmt" "github.com/ethersphere/bee/pkg/encryption" "github.com/ethersphere/bee/pkg/encryption/store" "github.com/ethersphere/bee/pkg/file" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/file/redundancy/getter" + "github.com/ethersphere/bee/pkg/replicas" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" "golang.org/x/sync/errgroup" ) type joiner struct { - addr swarm.Address - rootData []byte - span int64 - off int64 - refLength int - - ctx context.Context - getter storage.Getter + addr swarm.Address + rootData []byte + span int64 + off int64 + refLength int + rootParity int + maxBranching int // maximum branching in an intermediate chunk + + ctx context.Context + decoders *decoderCache + chunkToSpan func(data []byte) (redundancy.Level, int64) // returns parity and span value from chunkData +} + +// decoderCache is cache of decoders for intermediate chunks +type decoderCache struct { + fetcher storage.Getter // network retrieval interface to fetch chunks + putter storage.Putter // interface to local storage to save reconstructed chunks + mu sync.Mutex // mutex to protect cache + cache 
map[string]storage.Getter // map from chunk address to RS getter + config getter.Config // getter configuration +} + +// NewDecoderCache creates a new decoder cache +func NewDecoderCache(g storage.Getter, p storage.Putter, conf getter.Config) *decoderCache { + return &decoderCache{ + fetcher: g, + putter: p, + cache: make(map[string]storage.Getter), + config: conf, + } +} + +func fingerprint(addrs []swarm.Address) string { + h := swarm.NewHasher() + for _, addr := range addrs { + _, _ = h.Write(addr.Bytes()) + } + return string(h.Sum(nil)) +} + +// GetOrCreate returns a decoder for the given chunk address +func (g *decoderCache) GetOrCreate(addrs []swarm.Address, shardCnt int) storage.Getter { + if len(addrs) == shardCnt { + return g.fetcher + } + key := fingerprint(addrs) + g.mu.Lock() + defer g.mu.Unlock() + d, ok := g.cache[key] + if ok { + if d == nil { + return g.fetcher + } + return d + } + remove := func() { + g.mu.Lock() + defer g.mu.Unlock() + g.cache[key] = nil + } + d = getter.New(addrs, shardCnt, g.fetcher, g.putter, remove, g.config) + g.cache[key] = d + return d } // New creates a new Joiner. A Joiner provides Read, Seek and Size functionalities. 
-func New(ctx context.Context, getter storage.Getter, address swarm.Address) (file.Joiner, int64, error) { - getter = store.New(getter) +func New(ctx context.Context, g storage.Getter, putter storage.Putter, address swarm.Address) (file.Joiner, int64, error) { // retrieve the root chunk to read the total data length the be retrieved - rootChunk, err := getter.Get(ctx, address) + rLevel := redundancy.GetLevelFromContext(ctx) + rootChunkGetter := store.New(g) + if rLevel != redundancy.NONE { + rootChunkGetter = store.New(replicas.NewGetter(g, rLevel)) + } + rootChunk, err := rootChunkGetter.Get(ctx, address) if err != nil { return nil, 0, err } - var chunkData = rootChunk.Data() - - span := int64(binary.LittleEndian.Uint64(chunkData[:swarm.SpanSize])) + chunkData := rootChunk.Data() + rootData := chunkData[swarm.SpanSize:] + refLength := len(address.Bytes()) + encryption := refLength == encryption.ReferenceSize + rLevel, span := chunkToSpan(chunkData) + rootParity := 0 + maxBranching := swarm.ChunkSize / refLength + spanFn := func(data []byte) (redundancy.Level, int64) { + return 0, int64(bmt.LengthFromSpan(data[:swarm.SpanSize])) + } + conf, err := getter.NewConfigFromContext(ctx, getter.DefaultConfig) + if err != nil { + return nil, 0, err + } + // override stuff if root chunk has redundancy + if rLevel != redundancy.NONE { + _, parities := file.ReferenceCount(uint64(span), rLevel, encryption) + rootParity = parities + + spanFn = chunkToSpan + if encryption { + maxBranching = rLevel.GetMaxEncShards() + } else { + maxBranching = rLevel.GetMaxShards() + } + } else { + // if root chunk has no redundancy, strategy is ignored and set to NONE and strict is set to true + conf.Strategy = getter.DATA + conf.Strict = true + } j := &joiner{ - addr: rootChunk.Address(), - refLength: len(address.Bytes()), - ctx: ctx, - getter: getter, - span: span, - rootData: chunkData[swarm.SpanSize:], + addr: rootChunk.Address(), + refLength: refLength, + ctx: ctx, + decoders: 
NewDecoderCache(g, putter, conf), + span: span, + rootData: rootData, + rootParity: rootParity, + maxBranching: maxBranching, + chunkToSpan: spanFn, } return j, span, nil @@ -81,7 +173,7 @@ func (j *joiner) ReadAt(buffer []byte, off int64) (read int, err error) { } var bytesRead int64 var eg errgroup.Group - j.readAtOffset(buffer, j.rootData, 0, j.span, off, 0, readLen, &bytesRead, &eg) + j.readAtOffset(buffer, j.rootData, 0, j.span, off, 0, readLen, &bytesRead, j.rootParity, &eg) err = eg.Wait() if err != nil { @@ -93,7 +185,13 @@ func (j *joiner) ReadAt(buffer []byte, off int64) (read int, err error) { var ErrMalformedTrie = errors.New("malformed tree") -func (j *joiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffset, bytesToRead int64, bytesRead *int64, eg *errgroup.Group) { +func (j *joiner) readAtOffset( + b, data []byte, + cur, subTrieSize, off, bufferOffset, bytesToRead int64, + bytesRead *int64, + parity int, + eg *errgroup.Group, +) { // we are at a leaf data chunk if subTrieSize <= int64(len(data)) { dataOffsetStart := off - cur @@ -108,21 +206,30 @@ func (j *joiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffse atomic.AddInt64(bytesRead, int64(n)) return } + pSize, err := file.ChunkPayloadSize(data) + if err != nil { + eg.Go(func() error { + return err + }) + return + } + addrs, shardCnt := file.ChunkAddresses(data[:pSize], parity, j.refLength) + g := store.New(j.decoders.GetOrCreate(addrs, shardCnt)) for cursor := 0; cursor < len(data); cursor += j.refLength { if bytesToRead == 0 { break } // fast forward the cursor - sec := subtrieSection(data, cursor, j.refLength, subTrieSize) - if cur+sec < off { + sec := j.subtrieSection(data, cursor, pSize, parity, subTrieSize) + if cur+sec <= off { cur += sec continue } // if we are here it means that we are within the bounds of the data we need to read - address := swarm.NewAddress(data[cursor : cursor+j.refLength]) + addr := swarm.NewAddress(data[cursor : cursor+j.refLength]) 
subtrieSpan := sec subtrieSpanLimit := sec @@ -139,22 +246,23 @@ func (j *joiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffse func(address swarm.Address, b []byte, cur, subTrieSize, off, bufferOffset, bytesToRead, subtrieSpanLimit int64) { eg.Go(func() error { - ch, err := j.getter.Get(j.ctx, address) + ch, err := g.Get(j.ctx, addr) if err != nil { return err } chunkData := ch.Data()[8:] - subtrieSpan := int64(chunkToSpan(ch.Data())) + subtrieLevel, subtrieSpan := j.chunkToSpan(ch.Data()) + _, subtrieParity := file.ReferenceCount(uint64(subtrieSpan), subtrieLevel, j.refLength == encryption.ReferenceSize) if subtrieSpan > subtrieSpanLimit { return ErrMalformedTrie } - j.readAtOffset(b, chunkData, cur, subtrieSpan, off, bufferOffset, currentReadSize, bytesRead, eg) + j.readAtOffset(b, chunkData, cur, subtrieSpan, off, bufferOffset, currentReadSize, bytesRead, subtrieParity, eg) return nil }) - }(address, b, cur, subtrieSpan, off, bufferOffset, currentReadSize, subtrieSpanLimit) + }(addr, b, cur, subtrieSpan, off, bufferOffset, currentReadSize, subtrieSpanLimit) bufferOffset += currentReadSize bytesToRead -= currentReadSize @@ -163,8 +271,13 @@ func (j *joiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffse } } +// getShards returns the effective reference number respective to the intermediate chunk payload length and its parities +func (j *joiner) getShards(payloadSize, parities int) int { + return (payloadSize - parities*swarm.HashSize) / j.refLength +} + // brute-forces the subtrie size for each of the sections in this intermediate chunk -func subtrieSection(data []byte, startIdx, refLen int, subtrieSize int64) int64 { +func (j *joiner) subtrieSection(data []byte, startIdx, payloadSize, parities int, subtrieSize int64) int64 { // assume we have a trie of size `y` then we can assume that all of // the forks except for the last one on the right are of equal size // this is due to how the splitter wraps levels. 
@@ -173,9 +286,9 @@ func subtrieSection(data []byte, startIdx, refLen int, subtrieSize int64) int64 // where y is the size of the subtrie, refs are the number of references // x is constant (the brute forced value) and l is the size of the last subtrie var ( - refs = int64(len(data) / refLen) // how many references in the intermediate chunk - branching = int64(4096 / refLen) // branching factor is chunkSize divided by reference length - branchSize = int64(4096) + refs = int64(j.getShards(payloadSize, parities)) // how many effective references in the intermediate chunk + branching = int64(j.maxBranching) // branching factor is chunkSize divided by reference length + branchSize = int64(swarm.ChunkSize) ) for { whatsLeft := subtrieSize - (branchSize * (refs - 1)) @@ -186,7 +299,7 @@ func subtrieSection(data []byte, startIdx, refLen int, subtrieSize int64) int64 } // handle last branch edge case - if startIdx == int(refs-1)*refLen { + if startIdx == int(refs-1)*j.refLength { return subtrieSize - (refs-1)*branchSize } return branchSize @@ -229,10 +342,10 @@ func (j *joiner) IterateChunkAddresses(fn swarm.AddressIterFunc) error { return err } - return j.processChunkAddresses(j.ctx, fn, j.rootData, j.span) + return j.processChunkAddresses(j.ctx, fn, j.rootData, j.span, j.rootParity) } -func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIterFunc, data []byte, subTrieSize int64) error { +func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIterFunc, data []byte, subTrieSize int64, parity int) error { // we are at a leaf data chunk if subTrieSize <= int64(len(data)) { return nil @@ -248,42 +361,43 @@ func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIter var wg sync.WaitGroup - for cursor := 0; cursor < len(data); cursor += j.refLength { - ref := data[cursor : cursor+j.refLength] - var reportAddr swarm.Address - address := swarm.NewAddress(ref) - if len(ref) == encryption.ReferenceSize { - reportAddr 
= swarm.NewAddress(ref[:swarm.HashSize]) - } else { - reportAddr = swarm.NewAddress(ref) - } - - if err := fn(reportAddr); err != nil { + eSize, err := file.ChunkPayloadSize(data) + if err != nil { + return err + } + addrs, shardCnt := file.ChunkAddresses(data[:eSize], parity, j.refLength) + g := store.New(j.decoders.GetOrCreate(addrs, shardCnt)) + for i, addr := range addrs { + if err := fn(addr); err != nil { return err } - - sec := subtrieSection(data, cursor, j.refLength, subTrieSize) + cursor := i * swarm.HashSize + if j.refLength == encryption.ReferenceSize { + cursor += swarm.HashSize * min(i, shardCnt) + } + sec := j.subtrieSection(data, cursor, eSize, parity, subTrieSize) if sec <= swarm.ChunkSize { continue } - func(address swarm.Address, eg *errgroup.Group) { - wg.Add(1) + wg.Add(1) + eg.Go(func() error { + defer wg.Done() - eg.Go(func() error { - defer wg.Done() + if j.refLength == encryption.ReferenceSize && i < shardCnt { + addr = swarm.NewAddress(data[cursor : cursor+swarm.HashSize*2]) + } + ch, err := g.Get(ectx, addr) + if err != nil { + return err + } - ch, err := j.getter.Get(ectx, address) - if err != nil { - return err - } + chunkData := ch.Data()[8:] + subtrieLevel, subtrieSpan := j.chunkToSpan(ch.Data()) + _, parities := file.ReferenceCount(uint64(subtrieSpan), subtrieLevel, j.refLength != swarm.HashSize) - chunkData := ch.Data()[8:] - subtrieSpan := int64(chunkToSpan(ch.Data())) - - return j.processChunkAddresses(ectx, fn, chunkData, subtrieSpan) - }) - }(address, eg) + return j.processChunkAddresses(ectx, fn, chunkData, subtrieSpan, parities) + }) wg.Wait() } @@ -295,6 +409,9 @@ func (j *joiner) Size() int64 { return j.span } -func chunkToSpan(data []byte) uint64 { - return binary.LittleEndian.Uint64(data[:8]) +// chunkToSpan returns redundancy level and span value +// in the types that the package uses +func chunkToSpan(data []byte) (redundancy.Level, int64) { + level, spanBytes := redundancy.DecodeSpan(data[:swarm.SpanSize]) + return 
level, int64(bmt.LengthFromSpan(spanBytes)) } diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index 6063df6d846..15d46bf220b 100644 --- a/pkg/file/joiner/joiner_test.go +++ b/pkg/file/joiner/joiner_test.go @@ -7,6 +7,7 @@ package joiner_test import ( "bytes" "context" + "crypto/rand" "encoding/binary" "errors" "fmt" @@ -17,24 +18,31 @@ import ( "time" "github.com/ethersphere/bee/pkg/cac" - "github.com/ethersphere/bee/pkg/encryption/store" "github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/pipeline/builder" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/file/redundancy/getter" "github.com/ethersphere/bee/pkg/file/splitter" filetest "github.com/ethersphere/bee/pkg/file/testing" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" testingc "github.com/ethersphere/bee/pkg/storage/testing" + mockstorer "github.com/ethersphere/bee/pkg/storer/mock" "github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/util/testutil" + "github.com/ethersphere/bee/pkg/util/testutil/pseudorand" + "github.com/ethersphere/bee/pkg/util/testutil/racedetection" "gitlab.com/nolash/go-mockbytes" + "golang.org/x/sync/errgroup" ) +// nolint:paralleltest,tparallel,thelper + func TestJoiner_ErrReferenceLength(t *testing.T) { t.Parallel() store := inmemchunkstore.New() - _, _, err := joiner.New(context.Background(), store, swarm.ZeroAddress) + _, _, err := joiner.New(context.Background(), store, store, swarm.ZeroAddress) if !errors.Is(err, storage.ErrReferenceLength) { t.Fatalf("expected ErrReferenceLength %x but got %v", swarm.ZeroAddress, err) @@ -64,7 +72,7 @@ func TestJoinerSingleChunk(t *testing.T) { } // read back data and compare - joinReader, l, err := joiner.New(ctx, store, mockAddr) + joinReader, l, err := joiner.New(ctx, store, store, mockAddr) if err != nil { t.Fatal(err) } @@ -86,7 +94,6 @@ func 
TestJoinerDecryptingStore_NormalChunk(t *testing.T) { t.Parallel() st := inmemchunkstore.New() - store := store.New(st) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -104,7 +111,7 @@ func TestJoinerDecryptingStore_NormalChunk(t *testing.T) { } // read back data and compare - joinReader, l, err := joiner.New(ctx, store, mockAddr) + joinReader, l, err := joiner.New(ctx, st, st, mockAddr) if err != nil { t.Fatal(err) } @@ -125,34 +132,34 @@ func TestJoinerDecryptingStore_NormalChunk(t *testing.T) { func TestJoinerWithReference(t *testing.T) { t.Parallel() - store := inmemchunkstore.New() + st := inmemchunkstore.New() ctx, cancel := context.WithCancel(context.Background()) defer cancel() // create root chunk and two data chunks referenced in the root chunk rootChunk := filetest.GenerateTestRandomFileChunk(swarm.ZeroAddress, swarm.ChunkSize*2, swarm.SectionSize*2) - err := store.Put(ctx, rootChunk) + err := st.Put(ctx, rootChunk) if err != nil { t.Fatal(err) } firstAddress := swarm.NewAddress(rootChunk.Data()[8 : swarm.SectionSize+8]) firstChunk := filetest.GenerateTestRandomFileChunk(firstAddress, swarm.ChunkSize, swarm.ChunkSize) - err = store.Put(ctx, firstChunk) + err = st.Put(ctx, firstChunk) if err != nil { t.Fatal(err) } secondAddress := swarm.NewAddress(rootChunk.Data()[swarm.SectionSize+8:]) secondChunk := filetest.GenerateTestRandomFileChunk(secondAddress, swarm.ChunkSize, swarm.ChunkSize) - err = store.Put(ctx, secondChunk) + err = st.Put(ctx, secondChunk) if err != nil { t.Fatal(err) } // read back data and compare - joinReader, l, err := joiner.New(ctx, store, rootChunk.Address()) + joinReader, l, err := joiner.New(ctx, st, st, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -182,7 +189,7 @@ func TestJoinerMalformed(t *testing.T) { defer cancel() subTrie := []byte{8085: 1} - pb := builder.NewPipelineBuilder(ctx, store, false) + pb := builder.NewPipelineBuilder(ctx, store, false, 0) c1addr, _ := builder.FeedPipeline(ctx, 
pb, bytes.NewReader(subTrie)) chunk2 := testingc.GenerateTestRandomChunk() @@ -208,7 +215,7 @@ func TestJoinerMalformed(t *testing.T) { t.Fatal(err) } - joinReader, _, err := joiner.New(ctx, store, rootChunk.Address()) + joinReader, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -248,13 +255,13 @@ func TestEncryptDecrypt(t *testing.T) { t.Fatal(err) } ctx := context.Background() - pipe := builder.NewPipelineBuilder(ctx, store, true) + pipe := builder.NewPipelineBuilder(ctx, store, true, 0) testDataReader := bytes.NewReader(testData) resultAddress, err := builder.FeedPipeline(ctx, pipe, testDataReader) if err != nil { t.Fatal(err) } - reader, l, err := joiner.New(context.Background(), store, resultAddress) + reader, l, err := joiner.New(context.Background(), store, store, resultAddress) if err != nil { t.Fatal(err) } @@ -341,7 +348,7 @@ func TestSeek(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, addr) + j, _, err := joiner.New(ctx, store, store, addr) if err != nil { t.Fatal(err) } @@ -618,7 +625,7 @@ func TestPrefetch(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, addr) + j, _, err := joiner.New(ctx, store, store, addr) if err != nil { t.Fatal(err) } @@ -667,7 +674,7 @@ func TestJoinerReadAt(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, rootChunk.Address()) + j, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -714,7 +721,7 @@ func TestJoinerOneLevel(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, rootChunk.Address()) + j, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -808,7 +815,7 @@ func TestJoinerTwoLevelsAcrossChunk(t *testing.T) { t.Fatal(err) } - j, _, err := joiner.New(ctx, store, rootChunk.Address()) + j, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -864,7 +871,7 @@ func 
TestJoinerIterateChunkAddresses(t *testing.T) { createdAddresses := []swarm.Address{rootChunk.Address(), firstAddress, secondAddress} - j, _, err := joiner.New(ctx, store, rootChunk.Address()) + j, _, err := joiner.New(ctx, store, store, rootChunk.Address()) if err != nil { t.Fatal(err) } @@ -911,13 +918,13 @@ func TestJoinerIterateChunkAddresses_Encrypted(t *testing.T) { t.Fatal(err) } ctx := context.Background() - pipe := builder.NewPipelineBuilder(ctx, store, true) + pipe := builder.NewPipelineBuilder(ctx, store, true, 0) testDataReader := bytes.NewReader(testData) resultAddress, err := builder.FeedPipeline(ctx, pipe, testDataReader) if err != nil { t.Fatal(err) } - j, l, err := joiner.New(context.Background(), store, resultAddress) + j, l, err := joiner.New(context.Background(), store, store, resultAddress) if err != nil { t.Fatal(err) } @@ -951,3 +958,405 @@ func TestJoinerIterateChunkAddresses_Encrypted(t *testing.T) { } } } + +type mockPutter struct { + storage.ChunkStore + shards, parities chan swarm.Chunk + done chan struct{} +} + +func newMockPutter(store storage.ChunkStore, shardCnt, parityCnt int) *mockPutter { + return &mockPutter{ + ChunkStore: store, + done: make(chan struct{}, 1), + shards: make(chan swarm.Chunk, shardCnt), + parities: make(chan swarm.Chunk, parityCnt), + } +} + +func (m *mockPutter) Put(ctx context.Context, ch swarm.Chunk) error { + if len(m.shards) < cap(m.shards) { + m.shards <- ch + return nil + } + if len(m.parities) < cap(m.parities) { + m.parities <- ch + return nil + } + err := m.ChunkStore.Put(context.Background(), ch) + select { + case m.done <- struct{}{}: + default: + } + return err +} + +func (m *mockPutter) wait(ctx context.Context) { + select { + case <-m.done: + case <-ctx.Done(): + } + close(m.parities) + close(m.shards) +} + +func (m *mockPutter) store(cnt int) error { + n := 0 + for ch := range m.parities { + if err := m.ChunkStore.Put(context.Background(), ch); err != nil { + return err + } + n++ + if n == cnt { 
+ return nil + } + } + for ch := range m.shards { + if err := m.ChunkStore.Put(context.Background(), ch); err != nil { + return err + } + n++ + if n == cnt { + break + } + } + return nil +} + +// nolint:thelper +func TestJoinerRedundancy(t *testing.T) { + t.Parallel() + for _, tc := range []struct { + rLevel redundancy.Level + encryptChunk bool + }{ + { + redundancy.MEDIUM, + false, + }, + { + redundancy.MEDIUM, + true, + }, + { + redundancy.STRONG, + false, + }, + { + redundancy.STRONG, + true, + }, + { + redundancy.INSANE, + false, + }, + { + redundancy.INSANE, + true, + }, + { + redundancy.PARANOID, + false, + }, + { + redundancy.PARANOID, + true, + }, + } { + tc := tc + t.Run(fmt.Sprintf("redundancy=%d encryption=%t", tc.rLevel, tc.encryptChunk), func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + shardCnt := tc.rLevel.GetMaxShards() + parityCnt := tc.rLevel.GetParities(shardCnt) + if tc.encryptChunk { + shardCnt = tc.rLevel.GetMaxEncShards() + parityCnt = tc.rLevel.GetEncParities(shardCnt) + } + store := inmemchunkstore.New() + putter := newMockPutter(store, shardCnt, parityCnt) + pipe := builder.NewPipelineBuilder(ctx, putter, tc.encryptChunk, tc.rLevel) + dataChunks := make([]swarm.Chunk, shardCnt) + chunkSize := swarm.ChunkSize + size := chunkSize + for i := 0; i < shardCnt; i++ { + if i == shardCnt-1 { + size = 5 + } + chunkData := make([]byte, size) + _, err := io.ReadFull(rand.Reader, chunkData) + if err != nil { + t.Fatal(err) + } + dataChunks[i], err = cac.New(chunkData) + if err != nil { + t.Fatal(err) + } + _, err = pipe.Write(chunkData) + if err != nil { + t.Fatal(err) + } + } + + // reader init + sum, err := pipe.Sum() + if err != nil { + t.Fatal(err) + } + swarmAddr := swarm.NewAddress(sum) + putter.wait(ctx) + _, err = store.Get(ctx, swarm.NewAddress(sum[:swarm.HashSize])) + if err != nil { + t.Fatal(err) + } + strategyTimeout := 100 * time.Millisecond + // all data can be read back + readCheck := 
func(t *testing.T, expErr error) { + ctx, cancel := context.WithTimeout(context.Background(), 15*strategyTimeout) + defer cancel() + ctx = getter.SetConfigInContext(ctx, getter.RACE, true, (10 * strategyTimeout).String(), strategyTimeout.String()) + joinReader, rootSpan, err := joiner.New(ctx, store, store, swarmAddr) + if err != nil { + t.Fatal(err) + } + // sanity checks + expectedRootSpan := chunkSize*(shardCnt-1) + 5 + if int64(expectedRootSpan) != rootSpan { + t.Fatalf("Expected root span %d. Got: %d", expectedRootSpan, rootSpan) + } + i := 0 + eg, ectx := errgroup.WithContext(ctx) + scnt: + for ; i < shardCnt; i++ { + select { + case <-ectx.Done(): + break scnt + default: + } + i := i + eg.Go(func() error { + chunkData := make([]byte, chunkSize) + n, err := joinReader.ReadAt(chunkData, int64(i*chunkSize)) + if err != nil { + return err + } + select { + case <-ectx.Done(): + return ectx.Err() + default: + } + expectedChunkData := dataChunks[i].Data()[swarm.SpanSize:] + if !bytes.Equal(expectedChunkData, chunkData[:n]) { + return fmt.Errorf("data mismatch on chunk position %d", i) + } + return nil + }) + } + err = eg.Wait() + + if !errors.Is(err, expErr) { + t.Fatalf("unexpected error reading chunkdata at chunk position %d: expected %v. 
got %v", i, expErr, err) + } + } + t.Run("no recovery possible with no chunk stored", func(t *testing.T) { + readCheck(t, context.DeadlineExceeded) + }) + + if err := putter.store(shardCnt - 1); err != nil { + t.Fatal(err) + } + t.Run("no recovery possible with 1 short of shardCnt chunks stored", func(t *testing.T) { + readCheck(t, context.DeadlineExceeded) + }) + + if err := putter.store(1); err != nil { + t.Fatal(err) + } + t.Run("recovery given shardCnt chunks stored", func(t *testing.T) { + readCheck(t, nil) + }) + + if err := putter.store(shardCnt + parityCnt); err != nil { + t.Fatal(err) + } + t.Run("success given shardCnt data chunks stored, no need for recovery", func(t *testing.T) { + readCheck(t, nil) + }) + // success after rootChunk deleted using replicas given shardCnt data chunks stored, no need for recovery + if err := store.Delete(ctx, swarm.NewAddress(swarmAddr.Bytes()[:swarm.HashSize])); err != nil { + t.Fatal(err) + } + t.Run("recover from replica if root deleted", func(t *testing.T) { + readCheck(t, nil) + }) + }) + } +} + +// TestJoinerRedundancyMultilevel tests the joiner with all combinations of +// redundancy level, encryption and size (levels, i.e., the height of the swarm hash tree). +// +// The test cases have the following structure: +// +// 1. upload a file with a given redundancy level and encryption +// +// 2. [positive test] download the file by the reference returned by the upload API response +// This uses range queries to target specific (number of) chunks of the file structure +// During path traversal in the swarm hash tree, the underlying mocksore (forgetting) +// is in 'recording' mode, flagging all the retrieved chunks as chunks to forget. +// This is to simulate the scenario where some of the chunks are not available/lost +// +// 3. [negative test] attempt at downloading the file using once again the same root hash +// and a no-redundancy strategy to find the file inaccessible after forgetting. +// 3a. 
[negative test] download file using NONE without fallback and fail +// 3b. [negative test] download file using DATA without fallback and fail +// +// 4. [positive test] download file using DATA with fallback to allow for +// reconstruction via erasure coding and succeed. +// +// 5. [positive test] after recovery chunks are saved, so forgetting no longer +// repeat 3a/3b but this time succeed +// +// nolint:thelper +func TestJoinerRedundancyMultilevel(t *testing.T) { + t.Parallel() + test := func(t *testing.T, rLevel redundancy.Level, encrypt bool, levels, size int) { + t.Helper() + store := mockstorer.NewForgettingStore(inmemchunkstore.New()) + testutil.CleanupCloser(t, store) + seed, err := pseudorand.NewSeed() + if err != nil { + t.Fatal(err) + } + dataReader := pseudorand.NewReader(seed, size*swarm.ChunkSize) + ctx := context.Background() + // ctx = redundancy.SetLevelInContext(ctx, rLevel) + ctx = redundancy.SetLevelInContext(ctx, redundancy.NONE) + pipe := builder.NewPipelineBuilder(ctx, store, encrypt, rLevel) + addr, err := builder.FeedPipeline(ctx, pipe, dataReader) + if err != nil { + t.Fatal(err) + } + expRead := swarm.ChunkSize + buf := make([]byte, expRead) + offset := mrand.Intn(size) * expRead + canReadRange := func(t *testing.T, s getter.Strategy, fallback bool, levels int, canRead bool) { + ctx := context.Background() + strategyTimeout := 100 * time.Millisecond + decodingTimeout := 600 * time.Millisecond + if racedetection.IsOn() { + decodingTimeout *= 2 + } + ctx = getter.SetConfigInContext(ctx, s, fallback, (2 * strategyTimeout).String(), strategyTimeout.String()) + ctx, cancel := context.WithTimeout(ctx, time.Duration(levels)*(3*strategyTimeout+decodingTimeout)) + defer cancel() + j, _, err := joiner.New(ctx, store, store, addr) + if err != nil { + t.Fatal(err) + } + n, err := j.ReadAt(buf, int64(offset)) + if !canRead { + if !errors.Is(err, storage.ErrNotFound) && !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected error %v or %v. 
got %v", storage.ErrNotFound, context.DeadlineExceeded, err) + } + return + } + if err != nil { + t.Fatal(err) + } + if n != expRead { + t.Errorf("read %d bytes out of %d", n, expRead) + } + _, err = dataReader.Seek(int64(offset), io.SeekStart) + if err != nil { + t.Fatal(err) + } + ok, err := dataReader.Match(bytes.NewBuffer(buf), expRead) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Error("content mismatch") + } + } + + // first sanity check and and recover a range + t.Run("NONE w/o fallback CAN retrieve", func(t *testing.T) { + store.Record() + defer store.Unrecord() + canReadRange(t, getter.NONE, false, levels, true) + }) + + // do not forget the root chunk + store.Unmiss(swarm.NewAddress(addr.Bytes()[:swarm.HashSize])) + // after we forget the chunks on the way to the range, we should not be able to retrieve + t.Run("NONE w/o fallback CANNOT retrieve", func(t *testing.T) { + canReadRange(t, getter.NONE, false, levels, false) + }) + + // we lost a data chunk, we cannot recover using DATA only strategy with no fallback + t.Run("DATA w/o fallback CANNOT retrieve", func(t *testing.T) { + canReadRange(t, getter.DATA, false, levels, false) + }) + + if rLevel == 0 { + // allowing fallback mode will not help if no redundancy used for upload + t.Run("DATA w fallback CANNOT retrieve", func(t *testing.T) { + canReadRange(t, getter.DATA, true, levels, false) + }) + return + } + // allowing fallback mode will make the range retrievable using erasure decoding + t.Run("DATA w fallback CAN retrieve", func(t *testing.T) { + canReadRange(t, getter.DATA, true, levels, true) + }) + // after the reconstructed data is stored, we can retrieve the range using DATA only mode + t.Run("after recovery, NONE w/o fallback CAN retrieve", func(t *testing.T) { + canReadRange(t, getter.NONE, false, levels, true) + }) + } + r2level := []int{2, 1, 2, 3, 2} + encryptChunk := []bool{false, false, true, true, true} + for _, rLevel := range []redundancy.Level{0, 1, 2, 3, 4} { + rLevel := 
rLevel + // speeding up tests by skipping some of them + t.Run(fmt.Sprintf("rLevel=%v", rLevel), func(t *testing.T) { + t.Parallel() + for _, encrypt := range []bool{false, true} { + encrypt := encrypt + shardCnt := rLevel.GetMaxShards() + if encrypt { + shardCnt = rLevel.GetMaxEncShards() + } + for _, levels := range []int{1, 2, 3} { + chunkCnt := 1 + switch levels { + case 1: + chunkCnt = 2 + case 2: + chunkCnt = shardCnt + 1 + case 3: + chunkCnt = shardCnt*shardCnt + 1 + } + t.Run(fmt.Sprintf("encrypt=%v levels=%d chunks=%d incomplete", encrypt, levels, chunkCnt), func(t *testing.T) { + if r2level[rLevel] != levels || encrypt != encryptChunk[rLevel] { + t.Skip("skipping to save time") + } + test(t, rLevel, encrypt, levels, chunkCnt) + }) + switch levels { + case 1: + chunkCnt = shardCnt + case 2: + chunkCnt = shardCnt * shardCnt + case 3: + continue + } + t.Run(fmt.Sprintf("encrypt=%v levels=%d chunks=%d full", encrypt, levels, chunkCnt), func(t *testing.T) { + test(t, rLevel, encrypt, levels, chunkCnt) + }) + } + } + }) + } +} diff --git a/pkg/file/loadsave/loadsave.go b/pkg/file/loadsave/loadsave.go index 1f5267a4e8e..1d65e45ab6b 100644 --- a/pkg/file/loadsave/loadsave.go +++ b/pkg/file/loadsave/loadsave.go @@ -27,13 +27,15 @@ var errReadonlyLoadSave = errors.New("readonly manifest loadsaver") // load all of the subtrie of a given hash in memory. type loadSave struct { getter storage.Getter + putter storage.Putter pipelineFn func() pipeline.Interface } // New returns a new read-write load-saver. 
-func New(getter storage.Getter, pipelineFn func() pipeline.Interface) file.LoadSaver { +func New(getter storage.Getter, putter storage.Putter, pipelineFn func() pipeline.Interface) file.LoadSaver { return &loadSave{ getter: getter, + putter: putter, pipelineFn: pipelineFn, } } @@ -47,7 +49,7 @@ func NewReadonly(getter storage.Getter) file.LoadSaver { } func (ls *loadSave) Load(ctx context.Context, ref []byte) ([]byte, error) { - j, _, err := joiner.New(ctx, ls.getter, swarm.NewAddress(ref)) + j, _, err := joiner.New(ctx, ls.getter, ls.putter, swarm.NewAddress(ref)) if err != nil { return nil, err } diff --git a/pkg/file/loadsave/loadsave_test.go b/pkg/file/loadsave/loadsave_test.go index 21f4847bef9..64154859757 100644 --- a/pkg/file/loadsave/loadsave_test.go +++ b/pkg/file/loadsave/loadsave_test.go @@ -28,7 +28,7 @@ func TestLoadSave(t *testing.T) { t.Parallel() store := inmemchunkstore.New() - ls := loadsave.New(store, pipelineFn(store)) + ls := loadsave.New(store, store, pipelineFn(store)) ref, err := ls.Save(context.Background(), data) if err != nil { @@ -73,6 +73,6 @@ func TestReadonlyLoadSave(t *testing.T) { func pipelineFn(s storage.Putter) func() pipeline.Interface { return func() pipeline.Interface { - return builder.NewPipelineBuilder(context.Background(), s, false) + return builder.NewPipelineBuilder(context.Background(), s, false, 0) } } diff --git a/pkg/file/pipeline/builder/builder.go b/pkg/file/pipeline/builder/builder.go index 0833f50317c..d022ed5724a 100644 --- a/pkg/file/pipeline/builder/builder.go +++ b/pkg/file/pipeline/builder/builder.go @@ -17,23 +17,31 @@ import ( "github.com/ethersphere/bee/pkg/file/pipeline/feeder" "github.com/ethersphere/bee/pkg/file/pipeline/hashtrie" "github.com/ethersphere/bee/pkg/file/pipeline/store" + "github.com/ethersphere/bee/pkg/file/redundancy" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" ) // NewPipelineBuilder returns the appropriate pipeline according to the 
specified parameters -func NewPipelineBuilder(ctx context.Context, s storage.Putter, encrypt bool) pipeline.Interface { +func NewPipelineBuilder(ctx context.Context, s storage.Putter, encrypt bool, rLevel redundancy.Level) pipeline.Interface { if encrypt { - return newEncryptionPipeline(ctx, s) + return newEncryptionPipeline(ctx, s, rLevel) } - return newPipeline(ctx, s) + return newPipeline(ctx, s, rLevel) } // newPipeline creates a standard pipeline that only hashes content with BMT to create // a merkle-tree of hashes that represent the given arbitrary size byte stream. Partial // writes are supported. The pipeline flow is: Data -> Feeder -> BMT -> Storage -> HashTrie. -func newPipeline(ctx context.Context, s storage.Putter) pipeline.Interface { - tw := hashtrie.NewHashTrieWriter(swarm.ChunkSize, swarm.Branches, swarm.HashSize, newShortPipelineFunc(ctx, s)) +func newPipeline(ctx context.Context, s storage.Putter, rLevel redundancy.Level) pipeline.Interface { + pipeline := newShortPipelineFunc(ctx, s) + tw := hashtrie.NewHashTrieWriter( + ctx, + swarm.HashSize, + redundancy.New(rLevel, false, pipeline), + pipeline, + s, + ) lsw := store.NewStoreWriter(ctx, s, tw) b := bmt.NewBmtWriter(lsw) return feeder.NewChunkFeederWriter(swarm.ChunkSize, b) @@ -53,8 +61,14 @@ func newShortPipelineFunc(ctx context.Context, s storage.Putter) func() pipeline // writes are supported. The pipeline flow is: Data -> Feeder -> Encryption -> BMT -> Storage -> HashTrie. // Note that the encryption writer will mutate the data to contain the encrypted span, but the span field // with the unencrypted span is preserved. 
-func newEncryptionPipeline(ctx context.Context, s storage.Putter) pipeline.Interface { - tw := hashtrie.NewHashTrieWriter(swarm.ChunkSize, 64, swarm.HashSize+encryption.KeyLength, newShortEncryptionPipelineFunc(ctx, s)) +func newEncryptionPipeline(ctx context.Context, s storage.Putter, rLevel redundancy.Level) pipeline.Interface { + tw := hashtrie.NewHashTrieWriter( + ctx, + swarm.HashSize+encryption.KeyLength, + redundancy.New(rLevel, true, newShortPipelineFunc(ctx, s)), + newShortEncryptionPipelineFunc(ctx, s), + s, + ) lsw := store.NewStoreWriter(ctx, s, tw) b := bmt.NewBmtWriter(lsw) enc := enc.NewEncryptionWriter(encryption.NewChunkEncrypter(), b) diff --git a/pkg/file/pipeline/builder/builder_test.go b/pkg/file/pipeline/builder/builder_test.go index c720f25f356..266474d9376 100644 --- a/pkg/file/pipeline/builder/builder_test.go +++ b/pkg/file/pipeline/builder/builder_test.go @@ -23,7 +23,7 @@ func TestPartialWrites(t *testing.T) { t.Parallel() m := inmemchunkstore.New() - p := builder.NewPipelineBuilder(context.Background(), m, false) + p := builder.NewPipelineBuilder(context.Background(), m, false, 0) _, _ = p.Write([]byte("hello ")) _, _ = p.Write([]byte("world")) @@ -41,7 +41,7 @@ func TestHelloWorld(t *testing.T) { t.Parallel() m := inmemchunkstore.New() - p := builder.NewPipelineBuilder(context.Background(), m, false) + p := builder.NewPipelineBuilder(context.Background(), m, false, 0) data := []byte("hello world") _, err := p.Write(data) @@ -64,7 +64,7 @@ func TestEmpty(t *testing.T) { t.Parallel() m := inmemchunkstore.New() - p := builder.NewPipelineBuilder(context.Background(), m, false) + p := builder.NewPipelineBuilder(context.Background(), m, false, 0) data := []byte{} _, err := p.Write(data) @@ -92,7 +92,7 @@ func TestAllVectors(t *testing.T) { t.Parallel() m := inmemchunkstore.New() - p := builder.NewPipelineBuilder(context.Background(), m, false) + p := builder.NewPipelineBuilder(context.Background(), m, false, 0) _, err := p.Write(data) if err 
!= nil { @@ -155,7 +155,7 @@ func benchmarkPipeline(b *testing.B, count int) { data := testutil.RandBytes(b, count) m := inmemchunkstore.New() - p := builder.NewPipelineBuilder(context.Background(), m, false) + p := builder.NewPipelineBuilder(context.Background(), m, false, 0) b.StartTimer() diff --git a/pkg/file/pipeline/hashtrie/hashtrie.go b/pkg/file/pipeline/hashtrie/hashtrie.go index c559a646c55..b3c898c76e1 100644 --- a/pkg/file/pipeline/hashtrie/hashtrie.go +++ b/pkg/file/pipeline/hashtrie/hashtrie.go @@ -5,10 +5,15 @@ package hashtrie import ( + "context" "encoding/binary" "errors" + "fmt" "github.com/ethersphere/bee/pkg/file/pipeline" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/replicas" + "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" ) @@ -20,26 +25,44 @@ var ( const maxLevel = 8 type hashTrieWriter struct { - branching int - chunkSize int - refSize int - fullChunk int // full chunk size in terms of the data represented in the buffer (span+refsize) - cursors []int // level cursors, key is level. level 0 is data level and is not represented in this package. writes always start at level 1. higher levels will always have LOWER cursor values. - buffer []byte // keeps all level data - full bool // indicates whether the trie is full. currently we support (128^7)*4096 = 2305843009213693952 bytes - pipelineFn pipeline.PipelineFunc + ctx context.Context // context for put function of dispersed replica chunks + refSize int + cursors []int // level cursors, key is level. level 0 is data level holds how many chunks were processed. Intermediate higher levels will always have LOWER cursor values. + buffer []byte // keeps intermediate level data + full bool // indicates whether the trie is full. 
currently we support (128^7)*4096 = 2305843009213693952 bytes + pipelineFn pipeline.PipelineFunc + rParams redundancy.RedundancyParams + parityChunkFn redundancy.ParityChunkCallback + chunkCounters []uint8 // counts the chunk references in intermediate chunks. key is the chunk level. + effectiveChunkCounters []uint8 // counts the effective chunk references in intermediate chunks. key is the chunk level. + maxChildrenChunks uint8 // maximum number of chunk references in intermediate chunks. + replicaPutter storage.Putter // putter to save dispersed replicas of the root chunk } -func NewHashTrieWriter(chunkSize, branching, refLen int, pipelineFn pipeline.PipelineFunc) pipeline.ChainWriter { - return &hashTrieWriter{ - cursors: make([]int, 9), - buffer: make([]byte, swarm.ChunkWithSpanSize*9*2), // double size as temp workaround for weak calculation of needed buffer space - branching: branching, - chunkSize: chunkSize, - refSize: refLen, - fullChunk: (refLen + swarm.SpanSize) * branching, - pipelineFn: pipelineFn, +func NewHashTrieWriter( + ctx context.Context, + refLen int, + rParams redundancy.RedundancyParams, + pipelineFn pipeline.PipelineFunc, + replicaPutter storage.Putter, +) pipeline.ChainWriter { + h := &hashTrieWriter{ + ctx: ctx, + refSize: refLen, + cursors: make([]int, 9), + buffer: make([]byte, swarm.ChunkWithSpanSize*9*2), // double size as temp workaround for weak calculation of needed buffer space + rParams: rParams, + pipelineFn: pipelineFn, + chunkCounters: make([]uint8, 9), + effectiveChunkCounters: make([]uint8, 9), + maxChildrenChunks: uint8(rParams.MaxShards() + rParams.Parities(rParams.MaxShards())), + replicaPutter: replicas.NewPutter(replicaPutter), } + h.parityChunkFn = func(level int, span, address []byte) error { + return h.writeToIntermediateLevel(level, true, span, address, []byte{}) + } + + return h } // accepts writes of hashes from the previous writer in the chain, by definition these writes @@ -47,30 +70,51 @@ func 
NewHashTrieWriter(chunkSize, branching, refLen int, pipelineFn pipeline.Pip func (h *hashTrieWriter) ChainWrite(p *pipeline.PipeWriteArgs) error { oneRef := h.refSize + swarm.SpanSize l := len(p.Span) + len(p.Ref) + len(p.Key) - if l%oneRef != 0 { + if l%oneRef != 0 || l == 0 { return errInconsistentRefs } if h.full { return errTrieFull } - return h.writeToLevel(1, p.Span, p.Ref, p.Key) + if h.rParams.Level() == redundancy.NONE { + return h.writeToIntermediateLevel(1, false, p.Span, p.Ref, p.Key) + } else { + return h.writeToDataLevel(p.Span, p.Ref, p.Key, p.Data) + } } -func (h *hashTrieWriter) writeToLevel(level int, span, ref, key []byte) error { +func (h *hashTrieWriter) writeToIntermediateLevel(level int, parityChunk bool, span, ref, key []byte) error { copy(h.buffer[h.cursors[level]:h.cursors[level]+len(span)], span) h.cursors[level] += len(span) copy(h.buffer[h.cursors[level]:h.cursors[level]+len(ref)], ref) h.cursors[level] += len(ref) copy(h.buffer[h.cursors[level]:h.cursors[level]+len(key)], key) h.cursors[level] += len(key) - howLong := (h.refSize + swarm.SpanSize) * h.branching - if h.levelSize(level) == howLong { - return h.wrapFullLevel(level) + // update counters + if !parityChunk { + h.effectiveChunkCounters[level]++ + } + h.chunkCounters[level]++ + if h.chunkCounters[level] == h.maxChildrenChunks { + // at this point the erasure coded chunks have been written + err := h.wrapFullLevel(level) + return err } return nil } +// writeToDataLevel caches data chunks and calls writeToIntermediateLevel +func (h *hashTrieWriter) writeToDataLevel(span, ref, key, data []byte) error { + // write dataChunks to the level above + err := h.writeToIntermediateLevel(1, false, span, ref, key) + if err != nil { + return err + } + + return h.rParams.ChunkWrite(0, data, h.parityChunkFn) +} + // wrapLevel wraps an existing level and writes the resulting hash to the following level // then truncates the current level data by shifting the cursors. 
// Steps are performed in the following order: @@ -81,20 +125,36 @@ func (h *hashTrieWriter) writeToLevel(level int, span, ref, key []byte) error { // - get the hash that was created, append it one level above, and if necessary, wrap that level too // - remove already hashed data from buffer // -// assumes that the function has been called when refsize+span*branching has been reached +// assumes that h.chunkCounters[level] has reached h.maxChildrenChunks at fullchunk +// or redundancy.Encode was called in case of rightmost chunks func (h *hashTrieWriter) wrapFullLevel(level int) error { data := h.buffer[h.cursors[level+1]:h.cursors[level]] sp := uint64(0) var hashes []byte - for i := 0; i < len(data); i += h.refSize + 8 { + offset := 0 + for i := uint8(0); i < h.effectiveChunkCounters[level]; i++ { // sum up the spans of the level, then we need to bmt them and store it as a chunk // then write the chunk address to the next level up - sp += binary.LittleEndian.Uint64(data[i : i+8]) - hash := data[i+8 : i+h.refSize+8] + sp += binary.LittleEndian.Uint64(data[offset : offset+swarm.SpanSize]) + offset += +swarm.SpanSize + hash := data[offset : offset+h.refSize] + offset += h.refSize hashes = append(hashes, hash...) } + parities := 0 + for offset < len(data) { + // we do not add span of parity chunks to the common because that is gibberish + offset += +swarm.SpanSize + hash := data[offset : offset+swarm.HashSize] // parity reference has always hash length + offset += swarm.HashSize + hashes = append(hashes, hash...) + parities++ + } spb := make([]byte, 8) binary.LittleEndian.PutUint64(spb, sp) + if parities > 0 { + redundancy.EncodeLevel(spb, h.rParams.Level()) + } hashes = append(spb, hashes...) 
writer := h.pipelineFn() args := pipeline.PipeWriteArgs{ @@ -105,7 +165,13 @@ func (h *hashTrieWriter) wrapFullLevel(level int) error { if err != nil { return err } - err = h.writeToLevel(level+1, args.Span, args.Ref, args.Key) + + err = h.writeToIntermediateLevel(level+1, false, args.Span, args.Ref, args.Key) + if err != nil { + return err + } + + err = h.rParams.ChunkWrite(level, args.Data, h.parityChunkFn) if err != nil { return err } @@ -113,19 +179,13 @@ func (h *hashTrieWriter) wrapFullLevel(level int) error { // this "truncates" the current level that was wrapped // by setting the cursors to the cursors of one level above h.cursors[level] = h.cursors[level+1] + h.chunkCounters[level], h.effectiveChunkCounters[level] = 0, 0 if level+1 == 8 { h.full = true } return nil } -func (h *hashTrieWriter) levelSize(level int) int { - if level == 8 { - return h.cursors[level] - } - return h.cursors[level] - h.cursors[level+1] -} - // Sum returns the Swarm merkle-root content-addressed hash // of an arbitrary-length binary data. // The algorithm it uses is as follows: @@ -142,25 +202,22 @@ func (h *hashTrieWriter) levelSize(level int) int { // - more than one hash, in which case we _do_ perform a hashing operation, appending the hash to // the next level func (h *hashTrieWriter) Sum() ([]byte, error) { - oneRef := h.refSize + swarm.SpanSize for i := 1; i < maxLevel; i++ { - l := h.levelSize(i) - if l%oneRef != 0 { - return nil, errInconsistentRefs - } + l := h.chunkCounters[i] switch { case l == 0: // level empty, continue to the next. continue - case l == h.fullChunk: + case l == h.maxChildrenChunks: // this case is possible and necessary due to the carry over // in the next switch case statement. normal writes done // through writeToLevel will automatically wrap a full level. 
+ // erasure encoding call is not necessary since ElevateCarrierChunk solves that err := h.wrapFullLevel(i) if err != nil { return nil, err } - case l == oneRef: + case l == 1: // this cursor assignment basically means: // take the hash|span|key from this level, and append it to // the data of the next level. you may wonder how this works: @@ -175,20 +232,46 @@ func (h *hashTrieWriter) Sum() ([]byte, error) { // that might or might not have data. the eventual result is that the last // hash generated will always be carried over to the last level (8), then returned. h.cursors[i+1] = h.cursors[i] + // replace cached chunk to the level as well + err := h.rParams.ElevateCarrierChunk(i-1, h.parityChunkFn) + if err != nil { + return nil, err + } + // update counters, subtracting from current level is not necessary + h.effectiveChunkCounters[i+1]++ + h.chunkCounters[i+1]++ default: + // call erasure encoding before writing the last chunk on the level + err := h.rParams.Encode(i-1, h.parityChunkFn) + if err != nil { + return nil, err + } // more than 0 but smaller than chunk size - wrap the level to the one above it - err := h.wrapFullLevel(i) + err = h.wrapFullLevel(i) if err != nil { return nil, err } } } - levelLen := h.levelSize(8) - if levelLen != oneRef { + levelLen := h.chunkCounters[maxLevel] + if levelLen != 1 { return nil, errInconsistentRefs } // return the hash in the highest level, that's all we need - data := h.buffer[0:h.cursors[8]] - return data[8:], nil + data := h.buffer[0:h.cursors[maxLevel]] + rootHash := data[swarm.SpanSize:] + + // save disperse replicas of the root chunk + if h.rParams.Level() != redundancy.NONE { + rootData, err := h.rParams.GetRootData() + if err != nil { + return nil, err + } + err = h.replicaPutter.Put(h.ctx, swarm.NewChunk(swarm.NewAddress(rootHash[:swarm.HashSize]), rootData)) + if err != nil { + return nil, fmt.Errorf("hashtrie: cannot put dispersed replica %s", err.Error()) + } + } + return rootHash, nil } diff --git 
a/pkg/file/pipeline/hashtrie/hashtrie_test.go b/pkg/file/pipeline/hashtrie/hashtrie_test.go index 2078bbf18d1..4a966e265f3 100644 --- a/pkg/file/pipeline/hashtrie/hashtrie_test.go +++ b/pkg/file/pipeline/hashtrie/hashtrie_test.go @@ -5,15 +5,26 @@ package hashtrie_test import ( + "bytes" "context" "encoding/binary" "errors" + "sync/atomic" "testing" + bmtUtils "github.com/ethersphere/bee/pkg/bmt" + "github.com/ethersphere/bee/pkg/cac" + "github.com/ethersphere/bee/pkg/encryption" + dec "github.com/ethersphere/bee/pkg/encryption/store" + "github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file/pipeline" "github.com/ethersphere/bee/pkg/file/pipeline/bmt" + enc "github.com/ethersphere/bee/pkg/file/pipeline/encryption" "github.com/ethersphere/bee/pkg/file/pipeline/hashtrie" + "github.com/ethersphere/bee/pkg/file/pipeline/mock" "github.com/ethersphere/bee/pkg/file/pipeline/store" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" "github.com/ethersphere/bee/pkg/swarm" ) @@ -34,13 +45,47 @@ func init() { binary.LittleEndian.PutUint64(span, 1) } +// newErasureHashTrieWriter returns back an redundancy param and a HastTrieWriter pipeline +// which are using simple BMT and StoreWriter pipelines for chunk writes +func newErasureHashTrieWriter( + ctx context.Context, + s storage.Putter, + rLevel redundancy.Level, + encryptChunks bool, + intermediateChunkPipeline, parityChunkPipeline pipeline.ChainWriter, + replicaPutter storage.Putter, +) (redundancy.RedundancyParams, pipeline.ChainWriter) { + pf := func() pipeline.ChainWriter { + lsw := store.NewStoreWriter(ctx, s, intermediateChunkPipeline) + return bmt.NewBmtWriter(lsw) + } + if encryptChunks { + pf = func() pipeline.ChainWriter { + lsw := store.NewStoreWriter(ctx, s, intermediateChunkPipeline) + b := bmt.NewBmtWriter(lsw) + return enc.NewEncryptionWriter(encryption.NewChunkEncrypter(), b) + } + } + ppf := 
func() pipeline.ChainWriter { + lsw := store.NewStoreWriter(ctx, s, parityChunkPipeline) + return bmt.NewBmtWriter(lsw) + } + + hashSize := swarm.HashSize + if encryptChunks { + hashSize *= 2 + } + + r := redundancy.New(rLevel, encryptChunks, ppf) + ht := hashtrie.NewHashTrieWriter(ctx, hashSize, r, pf, replicaPutter) + return r, ht +} + func TestLevels(t *testing.T) { t.Parallel() var ( - branching = 4 - chunkSize = 128 - hashSize = 32 + hashSize = 32 ) // to create a level wrap we need to do branching^(level-1) writes @@ -100,7 +145,7 @@ func TestLevels(t *testing.T) { return bmt.NewBmtWriter(lsw) } - ht := hashtrie.NewHashTrieWriter(chunkSize, branching, hashSize, pf) + ht := hashtrie.NewHashTrieWriter(ctx, hashSize, redundancy.New(0, false, pf), pf, s) for i := 0; i < tc.writes; i++ { a := &pipeline.PipeWriteArgs{Ref: addr.Bytes(), Span: span} @@ -129,21 +174,31 @@ func TestLevels(t *testing.T) { } } +type redundancyMock struct { + redundancy.Params +} + +func (r redundancyMock) MaxShards() int { + return 4 +} + func TestLevels_TrieFull(t *testing.T) { t.Parallel() var ( - branching = 4 - chunkSize = 128 - hashSize = 32 - writes = 16384 // this is to get a balanced trie - s = inmemchunkstore.New() - pf = func() pipeline.ChainWriter { + hashSize = 32 + writes = 16384 // this is to get a balanced trie + s = inmemchunkstore.New() + pf = func() pipeline.ChainWriter { lsw := store.NewStoreWriter(ctx, s, nil) return bmt.NewBmtWriter(lsw) } + r = redundancy.New(0, false, pf) + rMock = &redundancyMock{ + Params: *r, + } - ht = hashtrie.NewHashTrieWriter(chunkSize, branching, hashSize, pf) + ht = hashtrie.NewHashTrieWriter(ctx, hashSize, rMock, pf, s) ) // to create a level wrap we need to do branching^(level-1) writes @@ -176,17 +231,15 @@ func TestRegression(t *testing.T) { t.Parallel() var ( - branching = 128 - chunkSize = 4096 - hashSize = 32 - writes = 67100000 / 4096 - span = make([]byte, 8) - s = inmemchunkstore.New() - pf = func() pipeline.ChainWriter { + 
hashSize = 32 + writes = 67100000 / 4096 + span = make([]byte, 8) + s = inmemchunkstore.New() + pf = func() pipeline.ChainWriter { lsw := store.NewStoreWriter(ctx, s, nil) return bmt.NewBmtWriter(lsw) } - ht = hashtrie.NewHashTrieWriter(chunkSize, branching, hashSize, pf) + ht = hashtrie.NewHashTrieWriter(ctx, hashSize, redundancy.New(0, false, pf), pf, s) ) binary.LittleEndian.PutUint64(span, 4096) @@ -213,3 +266,119 @@ func TestRegression(t *testing.T) { t.Fatalf("want span %d got %d", writes*4096, sp) } } + +type replicaPutter struct { + storage.Putter + replicaCount atomic.Uint32 +} + +func (r *replicaPutter) Put(ctx context.Context, chunk swarm.Chunk) error { + r.replicaCount.Add(1) + return r.Putter.Put(ctx, chunk) +} + +// TestRedundancy using erasure coding library and checks carrierChunk function and modified span in intermediate chunk +func TestRedundancy(t *testing.T) { + t.Parallel() + // chunks need to have data so that it will not throw error on redundancy caching + ch, err := cac.New(make([]byte, swarm.ChunkSize)) + if err != nil { + t.Fatal(err) + } + chData := ch.Data() + chSpan := chData[:swarm.SpanSize] + chAddr := ch.Address().Bytes() + + // test logic assumes a simple 2 level chunk tree with carrier chunk + for _, tc := range []struct { + desc string + level redundancy.Level + encryption bool + writes int + parities int + }{ + { + desc: "redundancy write for not encrypted data", + level: redundancy.INSANE, + encryption: false, + writes: 98, // 97 chunk references fit into one chunk + 1 carrier + parities: 37, // 31 (full ch) + 6 (2 ref) + }, + { + desc: "redundancy write for encrypted data", + level: redundancy.PARANOID, + encryption: true, + writes: 21, // 21 encrypted chunk references fit into one chunk + 1 carrier + parities: 116, // // 87 (full ch) + 29 (2 ref) + }, + } { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + subCtx := redundancy.SetLevelInContext(ctx, tc.level) + + s := inmemchunkstore.New() + 
intermediateChunkCounter := mock.NewChainWriter() + parityChunkCounter := mock.NewChainWriter() + replicaChunkCounter := &replicaPutter{Putter: s} + + r, ht := newErasureHashTrieWriter(subCtx, s, tc.level, tc.encryption, intermediateChunkCounter, parityChunkCounter, replicaChunkCounter) + + // write data to the hashTrie + var key []byte + if tc.encryption { + key = make([]byte, swarm.HashSize) + } + for i := 0; i < tc.writes; i++ { + a := &pipeline.PipeWriteArgs{Data: chData, Span: chSpan, Ref: chAddr, Key: key} + err := ht.ChainWrite(a) + if err != nil { + t.Fatal(err) + } + } + + ref, err := ht.Sum() + if err != nil { + t.Fatal(err) + } + + // sanity check for the test samples + if tc.parities != parityChunkCounter.ChainWriteCalls() { + t.Errorf("generated parities should be %d. Got: %d", tc.parities, parityChunkCounter.ChainWriteCalls()) + } + if intermediateChunkCounter.ChainWriteCalls() != 2 { // root chunk and the chunk which was written before carrierChunk movement + t.Errorf("effective chunks should be %d. Got: %d", tc.writes, intermediateChunkCounter.ChainWriteCalls()) + } + + rootch, err := s.Get(subCtx, swarm.NewAddress(ref[:swarm.HashSize])) + if err != nil { + t.Fatal(err) + } + chData := rootch.Data() + if tc.encryption { + chData, err = dec.DecryptChunkData(chData, ref[swarm.HashSize:]) + if err != nil { + t.Fatal(err) + } + } + + // span check + level, sp := redundancy.DecodeSpan(chData[:swarm.SpanSize]) + expectedSpan := bmtUtils.LengthToSpan(int64(tc.writes * swarm.ChunkSize)) + if !bytes.Equal(expectedSpan, sp) { + t.Fatalf("want span %d got %d", expectedSpan, span) + } + if level != tc.level { + t.Fatalf("encoded level differs from the uploaded one %d. 
Got: %d", tc.level, level) + } + expectedParities := tc.parities - r.Parities(r.MaxShards()) + _, parity := file.ReferenceCount(bmtUtils.LengthFromSpan(sp), level, tc.encryption) + if expectedParities != parity { + t.Fatalf("want parity %d got %d", expectedParities, parity) + } + if tc.level.GetReplicaCount()+1 != int(replicaChunkCounter.replicaCount.Load()) { // +1 is the original chunk + t.Fatalf("unexpected number of replicas: want %d. Got: %d", tc.level.GetReplicaCount(), int(replicaChunkCounter.replicaCount.Load())) + } + }) + } +} diff --git a/pkg/file/redundancy/export_test.go b/pkg/file/redundancy/export_test.go new file mode 100644 index 00000000000..7ea470a7e9c --- /dev/null +++ b/pkg/file/redundancy/export_test.go @@ -0,0 +1,15 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package redundancy + +// SetErasureEncoder changes erasureEncoderFunc to a new erasureEncoder facade +func SetErasureEncoder(f func(shards, parities int) (ErasureEncoder, error)) { + erasureEncoderFunc = f +} + +// GetErasureEncoder returns erasureEncoderFunc +func GetErasureEncoder() func(shards, parities int) (ErasureEncoder, error) { + return erasureEncoderFunc +} diff --git a/pkg/file/redundancy/getter/getter.go b/pkg/file/redundancy/getter/getter.go new file mode 100644 index 00000000000..4e8da1b6390 --- /dev/null +++ b/pkg/file/redundancy/getter/getter.go @@ -0,0 +1,252 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package getter + +import ( + "context" + "io" + "sync" + "sync/atomic" + + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" + "github.com/klauspost/reedsolomon" +) + +// decoder is a private implementation of storage.Getter +// if retrieves children of an intermediate chunk potentially using erasure decoding +// it caches sibling chunks if erasure decoding started already +type decoder struct { + fetcher storage.Getter // network retrieval interface to fetch chunks + putter storage.Putter // interface to local storage to save reconstructed chunks + addrs []swarm.Address // all addresses of the intermediate chunk + inflight []atomic.Bool // locks to protect wait channels and RS buffer + cache map[string]int // map from chunk address shard position index + waits []chan struct{} // wait channels for each chunk + rsbuf [][]byte // RS buffer of data + parity shards for erasure decoding + ready chan struct{} // signal channel for successful retrieval of shardCnt chunks + lastLen int // length of the last data chunk in the RS buffer + shardCnt int // number of data shards + parityCnt int // number of parity shards + wg sync.WaitGroup // wait group to wait for all goroutines to finish + mu sync.Mutex // mutex to protect buffer + err error // error of the last erasure decoding + fetchedCnt atomic.Int32 // count successful retrievals + cancel func() // cancel function for RS decoding + remove func() // callback to remove decoder from decoders cache + config Config // configuration +} + +type Getter interface { + storage.Getter + io.Closer +} + +// New returns a decoder object used to retrieve children of an intermediate chunk +func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter, remove func(), conf Config) Getter { + ctx, cancel := context.WithCancel(context.Background()) + size := len(addrs) + + rsg := &decoder{ + fetcher: g, + putter: p, + addrs: addrs, + inflight: make([]atomic.Bool, size), + cache: 
make(map[string]int, size), + waits: make([]chan struct{}, shardCnt), + rsbuf: make([][]byte, size), + ready: make(chan struct{}, 1), + cancel: cancel, + remove: remove, + shardCnt: shardCnt, + parityCnt: size - shardCnt, + config: conf, + } + + // after init, cache and wait channels are immutable, need no locking + for i := 0; i < shardCnt; i++ { + rsg.cache[addrs[i].ByteString()] = i + rsg.waits[i] = make(chan struct{}) + } + + // prefetch chunks according to strategy + if !conf.Strict || conf.Strategy != NONE { + rsg.wg.Add(1) + go func() { + rsg.err = rsg.prefetch(ctx) + rsg.wg.Done() + }() + } + return rsg +} + +// Get will call parities and other sibling chunks if the chunk address cannot be retrieved +// assumes it is called for data shards only +func (g *decoder) Get(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) { + i, ok := g.cache[addr.ByteString()] + if !ok { + return nil, storage.ErrNotFound + } + if g.fly(i, true) { + g.wg.Add(1) + go func() { + g.fetch(ctx, i) + g.wg.Done() + }() + } + select { + case <-g.waits[i]: + case <-ctx.Done(): + return nil, ctx.Err() + } + return swarm.NewChunk(addr, g.getData(i)), nil +} + +// setData sets the data shard in the RS buffer +func (g *decoder) setData(i int, chdata []byte) { + data := chdata + // pad the chunk with zeros if it is smaller than swarm.ChunkSize + if len(data) < swarm.ChunkWithSpanSize { + g.lastLen = len(data) + data = make([]byte, swarm.ChunkWithSpanSize) + copy(data, chdata) + } + g.rsbuf[i] = data +} + +// getData returns the data shard from the RS buffer +func (g *decoder) getData(i int) []byte { + if i == g.shardCnt-1 && g.lastLen > 0 { + return g.rsbuf[i][:g.lastLen] // cut padding + } + return g.rsbuf[i] +} + +// fly commits to retrieve the chunk (fly and land) +// it marks a chunk as inflight and returns true unless it is already inflight +// the atomic bool implements a singleflight pattern +func (g *decoder) fly(i int, up bool) (success bool) { + return 
g.inflight[i].CompareAndSwap(!up, up) +} + +// fetch retrieves a chunk from the underlying storage +// it must be called asynchonously and only once for each chunk (singleflight pattern) +// it races with erasure recovery which takes precedence even if it started later +// due to the fact that erasure recovery could only implement global locking on all shards +func (g *decoder) fetch(ctx context.Context, i int) { + fctx, cancel := context.WithTimeout(ctx, g.config.FetchTimeout) + defer cancel() + ch, err := g.fetcher.Get(fctx, g.addrs[i]) + if err != nil { + _ = g.fly(i, false) // unset inflight + return + } + + g.mu.Lock() + defer g.mu.Unlock() + if i < len(g.waits) { + select { + case <-g.waits[i]: // if chunk is retrieved, ignore + return + default: + } + } + + select { + case <-ctx.Done(): // if context is cancelled, ignore + _ = g.fly(i, false) // unset inflight + return + default: + } + + // write chunk to rsbuf and signal waiters + g.setData(i, ch.Data()) // save the chunk in the RS buffer + if i < len(g.waits) { // if the chunk is a data shard + close(g.waits[i]) // signal that the chunk is retrieved + } + + // if all chunks are retrieved, signal ready + n := g.fetchedCnt.Add(1) + if n == int32(g.shardCnt) { + close(g.ready) // signal that just enough chunks are retrieved for decoding + } +} + +// missing gathers missing data shards not yet retrieved +// it sets the chunk as inflight and returns the index of the missing data shards +func (g *decoder) missing() (m []int) { + for i := 0; i < g.shardCnt; i++ { + select { + case <-g.waits[i]: // if chunk is retrieved, ignore + continue + default: + } + _ = g.fly(i, true) // commit (RS) or will commit to retrieve the chunk + m = append(m, i) // remember the missing chunk + } + return m +} + +// decode uses Reed-Solomon erasure coding decoder to recover data shards +// it must be called after shqrdcnt shards are retrieved +// it must be called under g.mu mutex protection +func (g *decoder) decode(ctx 
context.Context) error { + enc, err := reedsolomon.New(g.shardCnt, g.parityCnt) + if err != nil { + return err + } + + // decode data + return enc.ReconstructData(g.rsbuf) +} + +// recover wraps the stages of data shard recovery: +// 1. gather missing data shards +// 2. decode using Reed-Solomon decoder +// 3. save reconstructed chunks +func (g *decoder) recover(ctx context.Context) error { + // buffer lock acquired + g.mu.Lock() + defer g.mu.Unlock() + + // gather missing shards + m := g.missing() + if len(m) == 0 { + return nil + } + + // decode using Reed-Solomon decoder + if err := g.decode(ctx); err != nil { + return err + } + + // close wait channels for missing chunks + for _, i := range m { + close(g.waits[i]) + } + + // save chunks + return g.save(ctx, m) +} + +// save iterate over reconstructed shards and puts the corresponding chunks to local storage +func (g *decoder) save(ctx context.Context, missing []int) error { + for _, i := range missing { + if err := g.putter.Put(ctx, swarm.NewChunk(g.addrs[i], g.rsbuf[i])); err != nil { + return err + } + } + return nil +} + +// Close terminates the prefetch loop, waits for all goroutines to finish and +// removes the decoder from the cache +// it implements the io.Closer interface +func (g *decoder) Close() error { + g.cancel() + g.wg.Wait() + g.remove() + return nil +} diff --git a/pkg/file/redundancy/getter/getter_test.go b/pkg/file/redundancy/getter/getter_test.go new file mode 100644 index 00000000000..b18caa55c12 --- /dev/null +++ b/pkg/file/redundancy/getter/getter_test.go @@ -0,0 +1,386 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package getter_test + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + mrand "math/rand" + "sync" + "testing" + "time" + + "github.com/ethersphere/bee/pkg/cac" + "github.com/ethersphere/bee/pkg/file/redundancy/getter" + "github.com/ethersphere/bee/pkg/storage" + inmem "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" + mockstorer "github.com/ethersphere/bee/pkg/storer/mock" + "github.com/ethersphere/bee/pkg/swarm" + "github.com/ethersphere/bee/pkg/util/testutil/racedetection" + "github.com/klauspost/reedsolomon" + "golang.org/x/sync/errgroup" +) + +// TestGetter tests the retrieval of chunks with missing data shards +// using the RACE strategy for a number of erasure code parameters +func TestGetterRACE(t *testing.T) { + type getterTest struct { + bufSize int + shardCnt int + erasureCnt int + } + + var tcs []getterTest + for bufSize := 3; bufSize <= 128; bufSize += 21 { + for shardCnt := bufSize/2 + 1; shardCnt <= bufSize; shardCnt += 21 { + parityCnt := bufSize - shardCnt + erasures := mrand.Perm(parityCnt - 1) + if len(erasures) > 3 { + erasures = erasures[:3] + } + for _, erasureCnt := range erasures { + tcs = append(tcs, getterTest{bufSize, shardCnt, erasureCnt}) + } + tcs = append(tcs, getterTest{bufSize, shardCnt, parityCnt}, getterTest{bufSize, shardCnt, parityCnt + 1}) + erasures = mrand.Perm(shardCnt - 1) + if len(erasures) > 3 { + erasures = erasures[:3] + } + for _, erasureCnt := range erasures { + tcs = append(tcs, getterTest{bufSize, shardCnt, erasureCnt + parityCnt + 1}) + } + } + } + t.Run("GET with RACE", func(t *testing.T) { + t.Parallel() + + for _, tc := range tcs { + t.Run(fmt.Sprintf("data/total/missing=%d/%d/%d", tc.shardCnt, tc.bufSize, tc.erasureCnt), func(t *testing.T) { + testDecodingRACE(t, tc.bufSize, tc.shardCnt, tc.erasureCnt) + }) + } + }) +} + +// TestGetterFallback tests the retrieval of chunks with missing data shards +// using the strict or fallback mode starting with NONE 
and DATA strategies +func TestGetterFallback(t *testing.T) { + t.Run("GET", func(t *testing.T) { + t.Run("NONE", func(t *testing.T) { + t.Run("strict", func(t *testing.T) { + testDecodingFallback(t, getter.NONE, true) + }) + t.Run("fallback", func(t *testing.T) { + testDecodingFallback(t, getter.NONE, false) + }) + }) + t.Run("DATA", func(t *testing.T) { + t.Run("strict", func(t *testing.T) { + testDecodingFallback(t, getter.DATA, true) + }) + t.Run("fallback", func(t *testing.T) { + testDecodingFallback(t, getter.DATA, false) + }) + }) + }) +} + +func testDecodingRACE(t *testing.T, bufSize, shardCnt, erasureCnt int) { + t.Helper() + strategyTimeout := 100 * time.Millisecond + if racedetection.On { + strategyTimeout *= 2 + } + store := inmem.New() + buf := make([][]byte, bufSize) + addrs := initData(t, buf, shardCnt, store) + + var addr swarm.Address + erasures := forget(t, store, addrs, erasureCnt) + for _, i := range erasures { + if i < shardCnt { + addr = addrs[i] + break + } + } + if len(addr.Bytes()) == 0 { + t.Skip("no data shard erased") + } + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + conf := getter.Config{ + Strategy: getter.RACE, + FetchTimeout: 2 * strategyTimeout, + StrategyTimeout: strategyTimeout, + } + g := getter.New(addrs, shardCnt, store, store, func() {}, conf) + defer g.Close() + parityCnt := len(buf) - shardCnt + q := make(chan error, 1) + go func() { + _, err := g.Get(ctx, addr) + q <- err + }() + err := context.DeadlineExceeded + wait := strategyTimeout * 2 + if racedetection.On { + wait *= 2 + } + select { + case err = <-q: + case <-time.After(wait): + } + switch { + case erasureCnt > parityCnt: + t.Run("unable to recover", func(t *testing.T) { + if !errors.Is(err, storage.ErrNotFound) && + !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected not found error or deadline exceeded, got %v", err) + } + }) + case erasureCnt <= parityCnt: + t.Run("will recover", func(t *testing.T) { + if err != nil { + 
t.Fatalf("expected no error, got %v", err) + } + checkShardsAvailable(t, store, addrs[:shardCnt], buf[:shardCnt]) + }) + } +} + +// testDecodingFallback tests the retrieval of chunks with missing data shards +func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { + t.Helper() + + strategyTimeout := 150 * time.Millisecond + + bufSize := 12 + shardCnt := 6 + store := mockstorer.NewDelayedStore(inmem.New()) + buf := make([][]byte, bufSize) + addrs := initData(t, buf, shardCnt, store) + + // erase two data shards + delayed, erased := 1, 0 + ctx := context.TODO() + err := store.Delete(ctx, addrs[erased]) + if err != nil { + t.Fatal(err) + } + + // context for enforced retrievals with long timeout + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + // signal channels for delayed and erased chunk retrieval + waitDelayed, waitErased := make(chan error, 1), make(chan error, 1) + + // complete retrieval of delayed chunk by putting it into the store after a while + delay := strategyTimeout / 4 + if s == getter.NONE { + delay += strategyTimeout + } + store.Delay(addrs[delayed], delay) + // create getter + start := time.Now() + conf := getter.Config{ + Strategy: s, + Strict: strict, + FetchTimeout: strategyTimeout / 2, + StrategyTimeout: strategyTimeout, + } + g := getter.New(addrs, shardCnt, store, store, func() {}, conf) + defer g.Close() + + // launch delayed and erased chunk retrieval + wg := sync.WaitGroup{} + // defer wg.Wait() + wg.Add(2) + // signal using the waitDelayed and waitErased channels when + // delayed and erased chunk retrieval completes + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(ctx, strategyTimeout*time.Duration(5-s)) + defer cancel() + _, err := g.Get(ctx, addrs[delayed]) + waitDelayed <- err + }() + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(ctx, strategyTimeout*time.Duration(5-s)) + defer cancel() + _, err := g.Get(ctx, addrs[erased]) + waitErased <- err + }() + + 
// wait for delayed chunk retrieval to complete + select { + case err := <-waitDelayed: + if err != nil { + t.Fatal("unexpected error", err) + } + round := time.Since(start) / strategyTimeout + switch { + case strict && s == getter.NONE: + if round < 1 { + t.Fatalf("unexpected completion of delayed chunk retrieval. got round %d", round) + } + case s == getter.NONE: + if round < 1 { + t.Fatalf("unexpected completion of delayed chunk retrieval. got round %d", round) + } + if round > 2 { + t.Fatalf("unexpected late completion of delayed chunk retrieval. got round %d", round) + } + case s == getter.DATA: + if round > 0 { + t.Fatalf("unexpected late completion of delayed chunk retrieval. got round %d", round) + } + } + + checkShardsAvailable(t, store, addrs[delayed:], buf[delayed:]) + // wait for erased chunk retrieval to complete + select { + case err := <-waitErased: + if err != nil { + t.Fatal("unexpected error", err) + } + round = time.Since(start) / strategyTimeout + switch { + case strict: + t.Fatalf("unexpected completion of erased chunk retrieval. got round %d", round) + case s == getter.NONE: + if round < 3 { + t.Fatalf("unexpected early completion of erased chunk retrieval. got round %d", round) + } + if round > 3 { + t.Fatalf("unexpected late completion of erased chunk retrieval. got round %d", round) + } + case s == getter.DATA: + if round < 1 { + t.Fatalf("unexpected early completion of erased chunk retrieval. got round %d", round) + } + if round > 1 { + t.Fatalf("unexpected late completion of delayed chunk retrieval. 
got round %d", round) + } + } + checkShardsAvailable(t, store, addrs[:erased], buf[:erased]) + + case <-time.After(strategyTimeout * 2): + if !strict { + t.Fatal("unexpected timeout using strategy", s, "with strict", strict) + } + } + case <-time.After(strategyTimeout * 3): + if !strict || s != getter.NONE { + t.Fatal("unexpected timeout using strategy", s, "with strict", strict) + } + } +} + +func initData(t *testing.T, buf [][]byte, shardCnt int, s storage.ChunkStore) []swarm.Address { + t.Helper() + spanBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(spanBytes, swarm.ChunkSize) + + for i := 0; i < len(buf); i++ { + buf[i] = make([]byte, swarm.ChunkWithSpanSize) + if i >= shardCnt { + continue + } + _, err := io.ReadFull(rand.Reader, buf[i]) + if err != nil { + t.Fatal(err) + } + copy(buf[i], spanBytes) + } + + // fill in parity chunks + rs, err := reedsolomon.New(shardCnt, len(buf)-shardCnt) + if err != nil { + t.Fatal(err) + } + err = rs.Encode(buf) + if err != nil { + t.Fatal(err) + } + + // calculate chunk addresses and upload to the store + addrs := make([]swarm.Address, len(buf)) + ctx := context.TODO() + for i := 0; i < len(buf); i++ { + chunk, err := cac.NewWithDataSpan(buf[i]) + if err != nil { + t.Fatal(err) + } + err = s.Put(ctx, chunk) + if err != nil { + t.Fatal(err) + } + addrs[i] = chunk.Address() + } + + return addrs +} + +func checkShardsAvailable(t *testing.T, s storage.ChunkStore, addrs []swarm.Address, data [][]byte) { + t.Helper() + eg, ctx := errgroup.WithContext(context.Background()) + for i, addr := range addrs { + i := i + addr := addr + eg.Go(func() (err error) { + var delay time.Duration + var ch swarm.Chunk + for i := 0; i < 30; i++ { + select { + case <-ctx.Done(): + return ctx.Err() + default: + <-time.After(delay) + delay = 50 * time.Millisecond + } + ch, err = s.Get(ctx, addr) + if err == nil { + break + } + err = fmt.Errorf("datashard %d with address %v is not available: %w", i, addr, err) + select { + case <-ctx.Done(): 
+ return ctx.Err() + default: + <-time.After(delay) + delay = 50 * time.Millisecond + } + } + if err == nil && !bytes.Equal(ch.Data(), data[i]) { + return fmt.Errorf("datashard %d has incorrect data", i) + } + return err + }) + } + if err := eg.Wait(); err != nil { + t.Fatal(err) + } +} + +func forget(t *testing.T, store storage.ChunkStore, addrs []swarm.Address, erasureCnt int) (erasures []int) { + t.Helper() + + ctx := context.TODO() + erasures = mrand.Perm(len(addrs))[:erasureCnt] + for _, i := range erasures { + err := store.Delete(ctx, addrs[i]) + if err != nil { + t.Fatal(err) + } + } + return erasures +} diff --git a/pkg/file/redundancy/getter/strategies.go b/pkg/file/redundancy/getter/strategies.go new file mode 100644 index 00000000000..bb5188e9cc2 --- /dev/null +++ b/pkg/file/redundancy/getter/strategies.go @@ -0,0 +1,205 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package getter + +import ( + "context" + "errors" + "fmt" + "time" +) + +const ( + DefaultStrategy = NONE // default prefetching strategy + DefaultStrict = true // default fallback modes + DefaultFetchTimeout = 30 * time.Second // timeout for each chunk retrieval + DefaultStrategyTimeout = 300 * time.Millisecond // timeout for each strategy +) + +type ( + strategyKey struct{} + modeKey struct{} + fetchTimeoutKey struct{} + strategyTimeoutKey struct{} + Strategy = int +) + +// Config is the configuration for the getter - public +type Config struct { + Strategy Strategy + Strict bool + FetchTimeout time.Duration + StrategyTimeout time.Duration +} + +const ( + NONE Strategy = iota // no prefetching and no decoding + DATA // just retrieve data shards no decoding + PROX // proximity driven selective fetching + RACE // aggressive fetching racing all chunks + strategyCnt +) + +// DefaultConfig is the default configuration for the getter +var DefaultConfig = Config{ + Strategy: DefaultStrategy, + Strict: DefaultStrict, + FetchTimeout: DefaultFetchTimeout, + StrategyTimeout: DefaultStrategyTimeout, +} + +// NewConfigFromContext returns a new Config based on the context +func NewConfigFromContext(ctx context.Context, def Config) (conf Config, err error) { + var ok bool + conf = def + e := func(s string, errs ...error) error { + if len(errs) > 0 { + return fmt.Errorf("error setting %s from context: %w", s, errors.Join(errs...)) + } + return fmt.Errorf("error setting %s from context", s) + } + if val := ctx.Value(strategyKey{}); val != nil { + conf.Strategy, ok = val.(Strategy) + if !ok { + return conf, e("strategy") + } + } + if val := ctx.Value(modeKey{}); val != nil { + conf.Strict, ok = val.(bool) + if !ok { + return conf, e("fallback mode") + } + } + if val := ctx.Value(fetchTimeoutKey{}); val != nil { + fetchTimeoutVal, ok := val.(string) + if !ok { + return conf, e("fetcher timeout") + } + conf.FetchTimeout, err = time.ParseDuration(fetchTimeoutVal) + if err != nil { 
+ return conf, e("fetcher timeout", err) + } + } + if val := ctx.Value(strategyTimeoutKey{}); val != nil { + strategyTimeoutVal, ok := val.(string) + if !ok { + return conf, e("fetcher timeout") + } + conf.StrategyTimeout, err = time.ParseDuration(strategyTimeoutVal) + if err != nil { + return conf, e("fetcher timeout", err) + } + } + return conf, nil +} + +// SetStrategy sets the strategy for the retrieval +func SetStrategy(ctx context.Context, s Strategy) context.Context { + return context.WithValue(ctx, strategyKey{}, s) +} + +// SetStrict sets the strict mode for the retrieval +func SetStrict(ctx context.Context, strict bool) context.Context { + return context.WithValue(ctx, modeKey{}, strict) +} + +// SetFetchTimeout sets the timeout for each fetch +func SetFetchTimeout(ctx context.Context, timeout string) context.Context { + return context.WithValue(ctx, fetchTimeoutKey{}, timeout) +} + +// SetStrategyTimeout sets the timeout for each strategy +func SetStrategyTimeout(ctx context.Context, timeout string) context.Context { + return context.WithValue(ctx, fetchTimeoutKey{}, timeout) +} + +// SetConfigInContext sets the config params in the context +func SetConfigInContext(ctx context.Context, s Strategy, fallbackmode bool, fetchTimeout, strategyTimeout string) context.Context { + ctx = SetStrategy(ctx, s) + ctx = SetStrict(ctx, !fallbackmode) + ctx = SetFetchTimeout(ctx, fetchTimeout) + ctx = SetStrategyTimeout(ctx, strategyTimeout) + return ctx +} + +func (g *decoder) prefetch(ctx context.Context) error { + if g.config.Strict && g.config.Strategy == NONE { + return nil + } + defer g.remove() + var cancels []func() + cancelAll := func() { + for _, cancel := range cancels { + cancel() + } + } + defer cancelAll() + run := func(s Strategy) error { + if s == PROX { // NOT IMPLEMENTED + return errors.New("strategy not implemented") + } + + var stop <-chan time.Time + if s < RACE { + timer := time.NewTimer(g.config.StrategyTimeout) + defer timer.Stop() + stop = 
timer.C + } + lctx, cancel := context.WithCancel(ctx) + cancels = append(cancels, cancel) + prefetch(lctx, g, s) + + select { + // successfully retrieved shardCnt number of chunks + case <-g.ready: + cancelAll() + case <-stop: + return fmt.Errorf("prefetching with strategy %d timed out", s) + case <-ctx.Done(): + return nil + } + // call the erasure decoder + // if decoding is successful terminate the prefetch loop + return g.recover(ctx) // context to cancel when shardCnt chunks are retrieved + } + var err error + for s := g.config.Strategy; s < strategyCnt; s++ { + err = run(s) + if g.config.Strict || err == nil { + break + } + } + + return err +} + +// prefetch launches the retrieval of chunks based on the strategy +func prefetch(ctx context.Context, g *decoder, s Strategy) { + var m []int + switch s { + case NONE: + return + case DATA: + // only retrieve data shards + m = g.missing() + case PROX: + // proximity driven selective fetching + // NOT IMPLEMENTED + case RACE: + // retrieve all chunks at once enabling race among chunks + m = g.missing() + for i := g.shardCnt; i < len(g.addrs); i++ { + m = append(m, i) + } + } + for _, i := range m { + i := i + g.wg.Add(1) + go func() { + g.fetch(ctx, i) + g.wg.Done() + }() + } +} diff --git a/pkg/file/redundancy/level.go b/pkg/file/redundancy/level.go new file mode 100644 index 00000000000..f7b4f5c19f1 --- /dev/null +++ b/pkg/file/redundancy/level.go @@ -0,0 +1,182 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package redundancy + +import ( + "context" + "errors" + "fmt" + + "github.com/ethersphere/bee/pkg/swarm" +) + +// Level is the redundancy level +// which carries information about how much redundancy should be added to data to remain retrievable with a 1-10^(-6) certainty +// in different groups of expected chunk retrival error rates (level values) +type Level uint8 + +const ( + // no redundancy will be added + NONE Level = iota + // expected 1% chunk retrieval error rate + MEDIUM + // expected 5% chunk retrieval error rate + STRONG + // expected 10% chunk retrieval error rate + INSANE + // expected 50% chunk retrieval error rate + PARANOID +) + +// GetParities returns number of parities based on appendix F table 5 +func (l Level) GetParities(shards int) int { + et, err := l.getErasureTable() + if err != nil { + return 0 + } + return et.getParities(shards) +} + +// GetMaxShards returns back the maximum number of effective data chunks +func (l Level) GetMaxShards() int { + p := l.GetParities(swarm.Branches) + return swarm.Branches - p +} + +// GetEncParities returns number of parities for encrypted chunks based on appendix F table 6 +func (l Level) GetEncParities(shards int) int { + et, err := l.getEncErasureTable() + if err != nil { + return 0 + } + return et.getParities(shards) +} + +func (l Level) getErasureTable() (erasureTable, error) { + switch l { + case NONE: + return erasureTable{}, errors.New("redundancy: level NONE does not have erasure table") + case MEDIUM: + return mediumEt, nil + case STRONG: + return strongEt, nil + case INSANE: + return insaneEt, nil + case PARANOID: + return paranoidEt, nil + default: + return erasureTable{}, fmt.Errorf("redundancy: level value %d is not a legit redundancy level", l) + } +} + +func (l Level) getEncErasureTable() (erasureTable, error) { + switch l { + case NONE: + return erasureTable{}, errors.New("redundancy: level NONE does not have erasure table") + case MEDIUM: + return encMediumEt, nil + case STRONG: + 
return encStrongEt, nil + case INSANE: + return encInsaneEt, nil + case PARANOID: + return encParanoidEt, nil + default: + return erasureTable{}, fmt.Errorf("redundancy: level value %d is not a legit redundancy level", l) + } +} + +// GetMaxEncShards returns back the maximum number of effective encrypted data chunks +func (l Level) GetMaxEncShards() int { + p := l.GetEncParities(swarm.EncryptedBranches) + return (swarm.Branches - p) / 2 +} + +// GetReplicaCount returns back the dispersed replica number +func (l Level) GetReplicaCount() int { + return replicaCounts[int(l)] +} + +// Decrement returns a weaker redundancy level compare to the current one +func (l Level) Decrement() Level { + return Level(uint8(l) - 1) +} + +// TABLE INITS + +var mediumEt = newErasureTable( + []int{95, 69, 47, 29, 15, 6, 2, 1}, + []int{9, 8, 7, 6, 5, 4, 3, 2}, +) +var encMediumEt = newErasureTable( + []int{47, 34, 23, 14, 7, 3, 1}, + []int{9, 8, 7, 6, 5, 4, 3}, +) + +var strongEt = newErasureTable( + []int{105, 96, 87, 78, 70, 62, 54, 47, 40, 33, 27, 21, 16, 11, 7, 4, 2, 1}, + []int{21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4}, +) +var encStrongEt = newErasureTable( + []int{52, 48, 43, 39, 35, 31, 27, 23, 20, 16, 13, 10, 8, 5, 3, 2, 1}, + []int{21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5}, +) + +var insaneEt = newErasureTable( + []int{93, 88, 83, 78, 74, 69, 64, 60, 55, 51, 46, 42, 38, 34, 30, 27, 23, 20, 17, 14, 11, 9, 6, 4, 3, 2, 1}, + []int{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5}, +) +var encInsaneEt = newErasureTable( + []int{46, 44, 41, 39, 37, 34, 32, 30, 27, 25, 23, 21, 19, 17, 15, 13, 11, 10, 8, 7, 5, 4, 3, 2, 1}, + []int{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 6}, +) + +var paranoidEt = newErasureTable( + []int{ + 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, + 17, 16, 15, 14, 13, 12, 11, 
10, 9, 8, 7, 6, 5, 4, 3, 2, 1, + }, + []int{ + 89, 87, 86, 84, 83, 81, 80, 78, 76, 75, 73, 71, 70, 68, 66, 65, 63, 61, 59, 58, + 56, 54, 52, 50, 48, 47, 45, 43, 40, 38, 36, 34, 31, 29, 26, 23, 19, + }, +) +var encParanoidEt = newErasureTable( + []int{ + 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, + 8, 7, 6, 5, 4, 3, 2, 1, + }, + []int{ + 87, 84, 81, 78, 75, 71, 68, 65, 61, 58, + 54, 50, 47, 43, 38, 34, 29, 23, + }, +) + +// GetReplicaCounts returns back the ascending dispersed replica counts for all redundancy levels +func GetReplicaCounts() [5]int { + c := replicaCounts + return c +} + +// the actual number of replicas needed to keep the error rate below 1/10^6 +// for the five levels of redundancy are 0, 2, 4, 5, 19 +// we use an approximation as the successive powers of 2 +var replicaCounts = [5]int{0, 2, 4, 8, 16} + +type levelKey struct{} + +// SetLevelInContext sets the redundancy level in the context +func SetLevelInContext(ctx context.Context, level Level) context.Context { + return context.WithValue(ctx, levelKey{}, level) +} + +// GetLevelFromContext is a helper function to extract the redundancy level from the context +func GetLevelFromContext(ctx context.Context) Level { + rlevel := PARANOID + if val := ctx.Value(levelKey{}); val != nil { + rlevel = val.(Level) + } + return rlevel +} diff --git a/pkg/file/redundancy/redundancy.go b/pkg/file/redundancy/redundancy.go new file mode 100644 index 00000000000..443dc0637b3 --- /dev/null +++ b/pkg/file/redundancy/redundancy.go @@ -0,0 +1,198 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package redundancy + +import ( + "fmt" + + "github.com/ethersphere/bee/pkg/file/pipeline" + "github.com/ethersphere/bee/pkg/swarm" + "github.com/klauspost/reedsolomon" +) + +// ParityChunkCallback is called when a new parity chunk has been created +type ParityChunkCallback func(level int, span, address []byte) error + +type RedundancyParams interface { + MaxShards() int // returns the maximum data shard number being used in an intermediate chunk + Level() Level + Parities(int) int + ChunkWrite(int, []byte, ParityChunkCallback) error + ElevateCarrierChunk(int, ParityChunkCallback) error + Encode(int, ParityChunkCallback) error + GetRootData() ([]byte, error) +} + +type ErasureEncoder interface { + Encode([][]byte) error +} + +var erasureEncoderFunc = func(shards, parities int) (ErasureEncoder, error) { + return reedsolomon.New(shards, parities) +} + +type Params struct { + level Level + pipeLine pipeline.PipelineFunc + buffer [][][]byte // keeps bytes of chunks on each level for producing erasure coded data; [levelIndex][branchIndex][byteIndex] + cursor []int // index of the current buffered chunk in Buffer. this is basically the latest used branchIndex. 
+ maxShards int // number of chunks after which the parity encode function should be called + maxParity int // number of parity chunks if maxShards has been reached for erasure coding + encryption bool +} + +func New(level Level, encryption bool, pipeLine pipeline.PipelineFunc) *Params { + maxShards := 0 + maxParity := 0 + if encryption { + maxShards = level.GetMaxEncShards() + maxParity = level.GetParities(swarm.EncryptedBranches) + } else { + maxShards = level.GetMaxShards() + maxParity = level.GetParities(swarm.BmtBranches) + } + // init dataBuffer for erasure coding + rsChunkLevels := 0 + if level != NONE { + rsChunkLevels = 8 + } + Buffer := make([][][]byte, rsChunkLevels) + for i := 0; i < rsChunkLevels; i++ { + Buffer[i] = make([][]byte, swarm.BmtBranches) // 128 long always because buffer varies at encrypted chunks + } + + return &Params{ + level: level, + pipeLine: pipeLine, + buffer: Buffer, + cursor: make([]int, 9), + maxShards: maxShards, + maxParity: maxParity, + encryption: encryption, + } +} + +func (p *Params) MaxShards() int { + return p.maxShards +} + +func (p *Params) Level() Level { + return p.level +} + +func (p *Params) Parities(shards int) int { + if p.encryption { + return p.level.GetEncParities(shards) + } + return p.level.GetParities(shards) +} + +// ChunkWrite caches the chunk data on the given chunk level and if it is full then it calls Encode +func (p *Params) ChunkWrite(chunkLevel int, data []byte, callback ParityChunkCallback) error { + if p.level == NONE { + return nil + } + if len(data) != swarm.ChunkWithSpanSize { + zeros := make([]byte, swarm.ChunkWithSpanSize-len(data)) + data = append(data, zeros...) 
+ } + + return p.chunkWrite(chunkLevel, data, callback) +} + +// ChunkWrite caches the chunk data on the given chunk level and if it is full then it calls Encode +func (p *Params) chunkWrite(chunkLevel int, data []byte, callback ParityChunkCallback) error { + // append chunk to the buffer + p.buffer[chunkLevel][p.cursor[chunkLevel]] = data + p.cursor[chunkLevel]++ + + // add parity chunk if it is necessary + if p.cursor[chunkLevel] == p.maxShards { + // append erasure coded data + return p.encode(chunkLevel, callback) + } + return nil +} + +// Encode produces and stores parity chunks that will be also passed back to the caller +func (p *Params) Encode(chunkLevel int, callback ParityChunkCallback) error { + if p.level == NONE || p.cursor[chunkLevel] == 0 { + return nil + } + + return p.encode(chunkLevel, callback) +} + +func (p *Params) encode(chunkLevel int, callback ParityChunkCallback) error { + shards := p.cursor[chunkLevel] + parities := p.Parities(shards) + + n := shards + parities + // realloc for parity chunks if it does not override the prev one + // calculate parity chunks + enc, err := erasureEncoderFunc(shards, parities) + if err != nil { + return err + } + + pz := len(p.buffer[chunkLevel][0]) + for i := shards; i < n; i++ { + p.buffer[chunkLevel][i] = make([]byte, pz) + } + err = enc.Encode(p.buffer[chunkLevel][:n]) + if err != nil { + return err + } + + for i := shards; i < n; i++ { + chunkData := p.buffer[chunkLevel][i] + span := chunkData[:swarm.SpanSize] + + writer := p.pipeLine() + args := pipeline.PipeWriteArgs{ + Data: chunkData, + Span: span, + } + err = writer.ChainWrite(&args) + if err != nil { + return err + } + + err = callback(chunkLevel+1, span, args.Ref) + if err != nil { + return err + } + } + p.cursor[chunkLevel] = 0 + + return nil +} + +// ElevateCarrierChunk moves the last poor orphan chunk to the level above where it can fit and there are other chunks as well. 
+func (p *Params) ElevateCarrierChunk(chunkLevel int, callback ParityChunkCallback) error { + if p.level == NONE { + return nil + } + if p.cursor[chunkLevel] != 1 { + return fmt.Errorf("redundancy: cannot elevate carrier chunk because it is not the only chunk on the level. It has %d chunks", p.cursor[chunkLevel]) + } + + // not necessary to update current level since we will not work with it anymore + return p.chunkWrite(chunkLevel+1, p.buffer[chunkLevel][p.cursor[chunkLevel]-1], callback) +} + +// GetRootData returns the topmost chunk in the tree. +// throws and error if the encoding has not been finished in the BMT +// OR redundancy is not used in the BMT +func (p *Params) GetRootData() ([]byte, error) { + if p.level == NONE { + return nil, fmt.Errorf("redundancy: no redundancy level is used for the file in order to cache root data") + } + lastBuffer := p.buffer[len(p.buffer)-1] + if len(lastBuffer[0]) != swarm.ChunkWithSpanSize { + return nil, fmt.Errorf("redundancy: hashtrie sum has not finished in order to cache root data") + } + return lastBuffer[0], nil +} diff --git a/pkg/file/redundancy/redundancy_test.go b/pkg/file/redundancy/redundancy_test.go new file mode 100644 index 00000000000..bb7aa35ba75 --- /dev/null +++ b/pkg/file/redundancy/redundancy_test.go @@ -0,0 +1,150 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package redundancy_test + +import ( + "crypto/rand" + "fmt" + "io" + "sync" + "testing" + + "github.com/ethersphere/bee/pkg/file/pipeline" + "github.com/ethersphere/bee/pkg/file/pipeline/bmt" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/swarm" +) + +type mockEncoder struct { + shards, parities int +} + +func newMockEncoder(shards, parities int) (redundancy.ErasureEncoder, error) { + return &mockEncoder{ + shards: shards, + parities: parities, + }, nil +} + +// Encode makes MSB of span equal to data +func (m *mockEncoder) Encode(buffer [][]byte) error { + // writes parity data + indicatedValue := 0 + for i := m.shards; i < m.shards+m.parities; i++ { + data := make([]byte, 32) + data[swarm.SpanSize-1], data[swarm.SpanSize] = uint8(indicatedValue), uint8(indicatedValue) + buffer[i] = data + indicatedValue++ + } + return nil +} + +type ParityChainWriter struct { + sync.Mutex + chainWriteCalls int + sumCalls int + validCalls []bool +} + +func NewParityChainWriter() *ParityChainWriter { + return &ParityChainWriter{} +} + +func (c *ParityChainWriter) ChainWriteCalls() int { + c.Lock() + defer c.Unlock() + return c.chainWriteCalls +} +func (c *ParityChainWriter) SumCalls() int { c.Lock(); defer c.Unlock(); return c.sumCalls } + +func (c *ParityChainWriter) ChainWrite(args *pipeline.PipeWriteArgs) error { + c.Lock() + defer c.Unlock() + valid := args.Span[len(args.Span)-1] == args.Data[len(args.Span)] && args.Data[len(args.Span)] == byte(c.chainWriteCalls) + c.chainWriteCalls++ + c.validCalls = append(c.validCalls, valid) + return nil +} +func (c *ParityChainWriter) Sum() ([]byte, error) { + c.Lock() + defer c.Unlock() + c.sumCalls++ + return nil, nil +} + +func TestEncode(t *testing.T) { + t.Parallel() + // initializes mockEncoder -> creates shard chunks -> redundancy.chunkWrites -> call encode + erasureEncoder := redundancy.GetErasureEncoder() + defer func() { + redundancy.SetErasureEncoder(erasureEncoder) + }() + 
redundancy.SetErasureEncoder(newMockEncoder) + + // test on the data level + for _, level := range []redundancy.Level{redundancy.MEDIUM, redundancy.STRONG, redundancy.INSANE, redundancy.PARANOID} { + for _, encrypted := range []bool{false, true} { + maxShards := level.GetMaxShards() + if encrypted { + maxShards = level.GetMaxEncShards() + } + for shardCount := 1; shardCount <= maxShards; shardCount++ { + t.Run(fmt.Sprintf("redundancy level %d is checked with %d shards", level, shardCount), func(t *testing.T) { + parityChainWriter := NewParityChainWriter() + ppf := func() pipeline.ChainWriter { + return bmt.NewBmtWriter(parityChainWriter) + } + params := redundancy.New(level, encrypted, ppf) + // checks parity pipelinecalls are valid + + parityCount := 0 + parityCallback := func(level int, span, address []byte) error { + parityCount++ + return nil + } + + for i := 0; i < shardCount; i++ { + buffer := make([]byte, 32) + _, err := io.ReadFull(rand.Reader, buffer) + if err != nil { + t.Fatal(err) + } + err = params.ChunkWrite(0, buffer, parityCallback) + if err != nil { + t.Fatal(err) + } + } + if shardCount != maxShards { + // encode should be called automatically when reaching maxshards + err := params.Encode(0, parityCallback) + if err != nil { + t.Fatal(err) + } + } + + // CHECKS + + if parityCount != parityChainWriter.chainWriteCalls { + t.Fatalf("parity callback was called %d times meanwhile chainwrite was called %d times", parityCount, parityChainWriter.chainWriteCalls) + } + + expectedParityCount := params.Level().GetParities(shardCount) + if encrypted { + expectedParityCount = params.Level().GetEncParities(shardCount) + } + if parityCount != expectedParityCount { + t.Fatalf("parity callback was called %d times meanwhile expected parity number should be %d", parityCount, expectedParityCount) + } + + for i, validCall := range parityChainWriter.validCalls { + if !validCall { + t.Fatalf("parity chunk data is wrong at parity index %d", i) + } + } + }) + } + } + } 
+} diff --git a/pkg/file/redundancy/span.go b/pkg/file/redundancy/span.go new file mode 100644 index 00000000000..eb71286fa43 --- /dev/null +++ b/pkg/file/redundancy/span.go @@ -0,0 +1,34 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package redundancy + +import ( + "github.com/ethersphere/bee/pkg/swarm" +) + +// EncodeLevel encodes used redundancy level for uploading into span keeping the real byte count for the chunk. +// assumes span is LittleEndian +func EncodeLevel(span []byte, level Level) { + // set parity in the most significant byte + span[swarm.SpanSize-1] = uint8(level) | 1<<7 // p + 128 +} + +// DecodeSpan decodes the used redundancy level from span keeping the real byte count for the chunk. +// assumes span is LittleEndian +func DecodeSpan(span []byte) (Level, []byte) { + spanCopy := make([]byte, swarm.SpanSize) + copy(spanCopy, span) + if !IsLevelEncoded(spanCopy) { + return 0, spanCopy + } + pByte := spanCopy[swarm.SpanSize-1] + return Level(pByte & ((1 << 7) - 1)), append(spanCopy[:swarm.SpanSize-1], 0) +} + +// IsLevelEncoded checks whether the redundancy level is encoded in the span +// assumes span is LittleEndian +func IsLevelEncoded(span []byte) bool { + return span[swarm.SpanSize-1] > 128 +} diff --git a/pkg/file/redundancy/table.go b/pkg/file/redundancy/table.go new file mode 100644 index 00000000000..c993db310c6 --- /dev/null +++ b/pkg/file/redundancy/table.go @@ -0,0 +1,52 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package redundancy + +type erasureTable struct { + shards []int + parities []int +} + +// newErasureTable initializes a shards<->parities table +// +// the value order must be strictly descending in both arrays +// example usage: +// shards := []int{94, 68, 46, 28, 14, 5, 1} +// parities := []int{9, 8, 7, 6, 5, 4, 3} +// var et = newErasureTable(shards, parities) +func newErasureTable(shards, parities []int) erasureTable { + if len(shards) != len(parities) { + panic("redundancy table: shards and parities arrays must be of equal size") + } + + maxShards := shards[0] + maxParities := parities[0] + for k := 1; k < len(shards); k++ { + s := shards[k] + if maxShards <= s { + panic("redundancy table: shards should be in strictly descending order") + } + p := parities[k] + if maxParities <= p { + panic("redundancy table: parities should be in strictly descending order") + } + maxShards, maxParities = s, p + } + + return erasureTable{ + shards: shards, + parities: parities, + } +} + +// getParities gives back the optimal parity number for a given shard +func (et *erasureTable) getParities(maxShards int) int { + for k, s := range et.shards { + if maxShards >= s { + return et.parities[k] + } + } + return 0 +} diff --git a/pkg/file/utils.go b/pkg/file/utils.go new file mode 100644 index 00000000000..35f21414b82 --- /dev/null +++ b/pkg/file/utils.go @@ -0,0 +1,96 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package file + +import ( + "bytes" + "errors" + + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/swarm" +) + +var ( + zeroAddress = [32]byte{} +) + +// ChunkPayloadSize returns the effective byte length of an intermediate chunk +// assumes data is always chunk size (without span) +func ChunkPayloadSize(data []byte) (int, error) { + l := len(data) + for l >= swarm.HashSize { + if !bytes.Equal(data[l-swarm.HashSize:l], zeroAddress[:]) { + return l, nil + } + + l -= swarm.HashSize + } + + return 0, errors.New("redundancy getter: intermediate chunk does not have at least a child") +} + +// ChunkAddresses returns data shards and parities of the intermediate chunk +// assumes data is truncated by ChunkPayloadSize +func ChunkAddresses(data []byte, parities, reflen int) (addrs []swarm.Address, shardCnt int) { + shardCnt = (len(data) - parities*swarm.HashSize) / reflen + for offset := 0; offset < len(data); offset += reflen { + addrs = append(addrs, swarm.NewAddress(data[offset:offset+swarm.HashSize])) + if len(addrs) == shardCnt && reflen != swarm.HashSize { + reflen = swarm.HashSize + offset += reflen + } + } + return addrs, shardCnt +} + +// ReferenceCount brute-forces the data shard count from which identify the parity count as well in a substree +// assumes span > swarm.chunkSize +// returns data and parity shard number +func ReferenceCount(span uint64, level redundancy.Level, encrytedChunk bool) (int, int) { + // assume we have a trie of size `span` then we can assume that all of + // the forks except for the last one on the right are of equal size + // this is due to how the splitter wraps levels. 
+ // first the algorithm will search for a BMT level where span can be included + // then identify how large data one reference can hold on that level + // then count how many references can satisfy span + // and finally how many parity shards should be on that level + maxShards := level.GetMaxShards() + if encrytedChunk { + maxShards = level.GetMaxEncShards() + } + var ( + branching = uint64(maxShards) // branching factor is how many data shard references can fit into one intermediate chunk + branchSize = uint64(swarm.ChunkSize) + ) + // search for branch level big enough to include span + branchLevel := 1 + for { + if branchSize >= span { + break + } + branchSize *= branching + branchLevel++ + } + // span in one full reference + referenceSize := uint64(swarm.ChunkSize) + // referenceSize = branching ** (branchLevel - 1) + for i := 1; i < branchLevel-1; i++ { + referenceSize *= branching + } + + dataShardAddresses := 1 + spanOffset := referenceSize + for spanOffset < span { + spanOffset += referenceSize + dataShardAddresses++ + } + + parityAddresses := level.GetParities(dataShardAddresses) + if encrytedChunk { + parityAddresses = level.GetEncParities(dataShardAddresses) + } + + return dataShardAddresses, parityAddresses +} diff --git a/pkg/node/bootstrap.go b/pkg/node/bootstrap.go index bf2bc4030cb..a1e48228da7 100644 --- a/pkg/node/bootstrap.go +++ b/pkg/node/bootstrap.go @@ -237,7 +237,7 @@ func bootstrapNode( ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - reader, l, err = joiner.New(ctx, localStore.Download(true), snapshotReference) + reader, l, err = joiner.New(ctx, localStore.Download(true), localStore.Cache(), snapshotReference) if err != nil { logger.Warning("bootstrap: file joiner failed", "error", err) continue diff --git a/pkg/node/node.go b/pkg/node/node.go index 6c8aef8e4bf..214cad27d4f 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -1058,7 +1058,7 @@ func NewBee( b.resolverCloser = multiResolver feedFactory := 
factory.New(localStore.Download(true)) - steward := steward.New(localStore, retrieval) + steward := steward.New(localStore, retrieval, localStore.Cache()) extraOpts := api.ExtraOptions{ Pingpong: pingPong, diff --git a/pkg/replicas/export_test.go b/pkg/replicas/export_test.go new file mode 100644 index 00000000000..6e029f302c6 --- /dev/null +++ b/pkg/replicas/export_test.go @@ -0,0 +1,15 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package replicas + +import "github.com/ethersphere/bee/pkg/storage" + +var ( + Signer = signer +) + +func Wait(g storage.Getter) { + g.(*getter).wg.Wait() +} diff --git a/pkg/replicas/getter.go b/pkg/replicas/getter.go new file mode 100644 index 00000000000..26345919958 --- /dev/null +++ b/pkg/replicas/getter.go @@ -0,0 +1,150 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// the code below implements the integration of dispersed replicas in chunk fetching. +// using storage.Getter interface. +package replicas + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/soc" + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" +) + +// ErrSwarmageddon is returned in case of a vis mayor called Swarmageddon. +// Swarmageddon is the situation when none of the replicas can be retrieved. +// If 2^{depth} replicas were uploaded and they all have valid postage stamps +// then the probability of Swarmageddon is less than 0.000001 +// assuming the error rate of chunk retrievals stays below the level expressed +// as depth by the publisher. 
+var ErrSwarmageddon = errors.New("swarmageddon has begun") + +// getter is the private implementation of storage.Getter, an interface for +// retrieving chunks. This getter embeds the original simple chunk getter and extends it +// to a multiplexed variant that fetches chunks with replicas. +// +// the strategy to retrieve a chunk that has replicas can be configured with a few parameters: +// - RetryInterval: the delay before a new batch of replicas is fetched. +// - depth: 2^{depth} is the total number of additional replicas that have been uploaded +// (by default, it is assumed to be 4, ie. total of 16) +// - (not implemented) pivot: replicas with address in the proximity of pivot will be tried first +type getter struct { + wg sync.WaitGroup + storage.Getter + level redundancy.Level +} + +// NewGetter is the getter constructor +func NewGetter(g storage.Getter, level redundancy.Level) storage.Getter { + return &getter{Getter: g, level: level} +} + +// Get makes the getter satisfy the storage.Getter interface +func (g *getter) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // channel that the results (retrieved chunks) are gathered to from concurrent + // workers each fetching a replica + resultC := make(chan swarm.Chunk) + // errc collects the errors + errc := make(chan error, 17) + var errs error + errcnt := 0 + + // concurrently call to retrieve chunk using original CAC address + g.wg.Add(1) + go func() { + defer g.wg.Done() + ch, err := g.Getter.Get(ctx, addr) + if err != nil { + errc <- err + return + } + + select { + case resultC <- ch: + case <-ctx.Done(): + } + }() + // counters + n := 0 // counts the replica addresses tried + target := 2 // the number of replicas attempted to download in this batch + total := g.level.GetReplicaCount() + + // + rr := newReplicator(addr, g.level) + next := rr.c + var wait <-chan time.Time // nil channel to disable case + // addresses 
used are doubling each period of search expansion + // (at intervals of RetryInterval) + ticker := time.NewTicker(RetryInterval) + defer ticker.Stop() + for level := uint8(0); level <= uint8(g.level); { + select { + // at least one chunk is retrieved, cancel the rest and return early + case chunk := <-resultC: + cancel() + return chunk, nil + + case err = <-errc: + errs = errors.Join(errs, err) + errcnt++ + if errcnt > total { + return nil, errors.Join(ErrSwarmageddon, errs) + } + + // ticker switches on the address channel + case <-wait: + wait = nil + next = rr.c + level++ + target = 1 << level + n = 0 + continue + + // getting the addresses in order + case so := <-next: + if so == nil { + next = nil + continue + } + + g.wg.Add(1) + go func() { + defer g.wg.Done() + ch, err := g.Getter.Get(ctx, swarm.NewAddress(so.addr)) + if err != nil { + errc <- err + return + } + + soc, err := soc.FromChunk(ch) + if err != nil { + errc <- err + return + } + + select { + case resultC <- soc.WrappedChunk(): + case <-ctx.Done(): + } + }() + n++ + if n < target { + continue + } + next = nil + wait = ticker.C + } + } + + return nil, nil +} diff --git a/pkg/replicas/getter_test.go b/pkg/replicas/getter_test.go new file mode 100644 index 00000000000..7435b0574fd --- /dev/null +++ b/pkg/replicas/getter_test.go @@ -0,0 +1,263 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package replicas_test + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "io" + "sync/atomic" + "testing" + "time" + + "github.com/ethersphere/bee/pkg/cac" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/replicas" + "github.com/ethersphere/bee/pkg/soc" + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" +) + +type testGetter struct { + ch swarm.Chunk + now time.Time + origCalled chan struct{} + origIndex int + errf func(int) chan struct{} + firstFound int32 + attempts atomic.Int32 + cancelled chan struct{} + addresses [17]swarm.Address + latencies [17]time.Duration +} + +func (tg *testGetter) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) { + i := tg.attempts.Add(1) - 1 + tg.addresses[i] = addr + tg.latencies[i] = time.Since(tg.now) + + if addr.Equal(tg.ch.Address()) { + tg.origIndex = int(i) + close(tg.origCalled) + ch = tg.ch + } + + if i != tg.firstFound { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-tg.errf(int(i)): + return nil, storage.ErrNotFound + } + } + defer func() { + go func() { + select { + case <-time.After(100 * time.Millisecond): + case <-ctx.Done(): + close(tg.cancelled) + } + }() + }() + + if ch != nil { + return ch, nil + } + return soc.New(addr.Bytes(), tg.ch).Sign(replicas.Signer) +} + +func newTestGetter(ch swarm.Chunk, firstFound int, errf func(int) chan struct{}) *testGetter { + return &testGetter{ + ch: ch, + errf: errf, + firstFound: int32(firstFound), + cancelled: make(chan struct{}), + origCalled: make(chan struct{}), + } +} + +// Close implements the storage.Getter interface +func (tg *testGetter) Close() error { + return nil +} + +func TestGetter(t *testing.T) { + t.Parallel() + // failure is a struct that defines a failure scenario to test + type failure struct { + name string + err error + errf func(int, int) func(int) chan struct{} + } + // failures is a list of failure scenarios to test + failures := 
[]failure{ + { + "timeout", + context.Canceled, + func(_, _ int) func(i int) chan struct{} { + return func(i int) chan struct{} { + return nil + } + }, + }, + { + "not found", + storage.ErrNotFound, + func(_, _ int) func(i int) chan struct{} { + c := make(chan struct{}) + close(c) + return func(i int) chan struct{} { + return c + } + }, + }, + } + type test struct { + name string + failure failure + level int + count int + found int + } + + var tests []test + for _, f := range failures { + for level, c := range redundancy.GetReplicaCounts() { + for j := 0; j <= c*2+1; j++ { + tests = append(tests, test{ + name: fmt.Sprintf("%s level %d count %d found %d", f.name, level, c, j), + failure: f, + level: level, + count: c, + found: j, + }) + } + } + } + + // initiailise the base chunk + chunkLen := 420 + buf := make([]byte, chunkLen) + if _, err := io.ReadFull(rand.Reader, buf); err != nil { + t.Fatal(err) + } + ch, err := cac.New(buf) + if err != nil { + t.Fatal(err) + } + // reset retry interval to speed up tests + retryInterval := replicas.RetryInterval + defer func() { replicas.RetryInterval = retryInterval }() + replicas.RetryInterval = 100 * time.Millisecond + + // run the tests + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // initiate a chunk retrieval session using replicas.Getter + // embedding a testGetter that simulates the behaviour of a chunk store + store := newTestGetter(ch, tc.found, tc.failure.errf(tc.found, tc.count)) + g := replicas.NewGetter(store, redundancy.Level(tc.level)) + store.now = time.Now() + ctx, cancel := context.WithCancel(context.Background()) + if tc.found > tc.count { + wait := replicas.RetryInterval / 2 * time.Duration(1+2*tc.level) + go func() { + time.Sleep(wait) + cancel() + }() + } + _, err := g.Get(ctx, ch.Address()) + replicas.Wait(g) + cancel() + + // test the returned error + if tc.found <= tc.count { + if err != nil { + t.Fatalf("expected no error. 
got %v", err) + } + // if j <= c, the original chunk should be retrieved and the context should be cancelled + t.Run("retrievals cancelled", func(t *testing.T) { + + select { + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for context to be cancelled") + case <-store.cancelled: + } + }) + + } else { + if err == nil { + t.Fatalf("expected error. got ") + } + + t.Run("returns correct error", func(t *testing.T) { + if !errors.Is(err, replicas.ErrSwarmageddon) { + t.Fatalf("incorrect error. want Swarmageddon. got %v", err) + } + if !errors.Is(err, tc.failure.err) { + t.Fatalf("incorrect error. want it to wrap %v. got %v", tc.failure.err, err) + } + }) + } + + attempts := int(store.attempts.Load()) + // the original chunk should be among those attempted for retrieval + addresses := store.addresses[:attempts] + latencies := store.latencies[:attempts] + t.Run("original address called", func(t *testing.T) { + select { + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting form original address to be attempted for retrieval") + case <-store.origCalled: + i := store.origIndex + if i > 2 { + t.Fatalf("original address called too late. want at most 2 (preceding attempts). got %v (latency: %v)", i, latencies[i]) + } + addresses = append(addresses[:i], addresses[i+1:]...) + latencies = append(latencies[:i], latencies[i+1:]...) + attempts-- + } + }) + + t.Run("retrieved count", func(t *testing.T) { + if attempts > tc.count { + t.Fatalf("too many attempts to retrieve a replica: want at most %v. got %v.", tc.count, attempts) + } + if tc.found > tc.count { + if attempts < tc.count { + t.Fatalf("too few attempts to retrieve a replica: want at least %v. got %v.", tc.count, attempts) + } + return + } + max := 2 + for i := 1; i < tc.level && max < tc.found; i++ { + max = max * 2 + } + if attempts > max { + t.Fatalf("too many attempts to retrieve a replica: want at most %v. got %v. 
latencies %v", max, attempts, latencies) + } + }) + + t.Run("dispersion", func(t *testing.T) { + + if err := dispersed(redundancy.Level(tc.level), ch, addresses); err != nil { + t.Fatalf("addresses are not dispersed: %v", err) + } + }) + + t.Run("latency", func(t *testing.T) { + counts := redundancy.GetReplicaCounts() + for i, latency := range latencies { + multiplier := latency / replicas.RetryInterval + if multiplier > 0 && i < counts[multiplier-1] { + t.Fatalf("incorrect latency for retrieving replica %d: %v", i, err) + } + } + }) + }) + } +} diff --git a/pkg/replicas/putter.go b/pkg/replicas/putter.go new file mode 100644 index 00000000000..4aa55b638f0 --- /dev/null +++ b/pkg/replicas/putter.go @@ -0,0 +1,61 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// the code below implements the integration of dispersed replicas in chunk upload. +// using storage.Putter interface. 
+package replicas + +import ( + "context" + "errors" + "sync" + + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/soc" + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" +) + +// putter is the private implementation of the public storage.Putter interface +// putter extends the original putter to a concurrent multiputter +type putter struct { + putter storage.Putter +} + +// NewPutter is the putter constructor +func NewPutter(p storage.Putter) storage.Putter { + return &putter{p} +} + +// Put makes the getter satisfy the storage.Getter interface +func (p *putter) Put(ctx context.Context, ch swarm.Chunk) (err error) { + rlevel := redundancy.GetLevelFromContext(ctx) + errs := []error{p.putter.Put(ctx, ch)} + if rlevel == 0 { + return errs[0] + } + + rr := newReplicator(ch.Address(), rlevel) + errc := make(chan error, rlevel.GetReplicaCount()) + wg := sync.WaitGroup{} + for r := range rr.c { + r := r + wg.Add(1) + go func() { + defer wg.Done() + sch, err := soc.New(r.id, ch).Sign(signer) + if err == nil { + err = p.putter.Put(ctx, sch) + } + errc <- err + }() + } + + wg.Wait() + close(errc) + for err := range errc { + errs = append(errs, err) + } + return errors.Join(errs...) +} diff --git a/pkg/replicas/putter_test.go b/pkg/replicas/putter_test.go new file mode 100644 index 00000000000..7d05624ebca --- /dev/null +++ b/pkg/replicas/putter_test.go @@ -0,0 +1,195 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package replicas_test + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "io" + "sync/atomic" + "testing" + "time" + + "github.com/ethersphere/bee/pkg/cac" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/replicas" + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" + "github.com/ethersphere/bee/pkg/swarm" +) + +var ( + errTestA = errors.New("A") + errTestB = errors.New("B") +) + +type testBasePutter struct { + getErrors func(context.Context, swarm.Address) error + putErrors func(context.Context, swarm.Address) error + store storage.ChunkStore +} + +func (tbp *testBasePutter) Get(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) { + + g := tbp.getErrors + if g != nil { + return nil, g(ctx, addr) + } + return tbp.store.Get(ctx, addr) +} + +func (tbp *testBasePutter) Put(ctx context.Context, ch swarm.Chunk) error { + + g := tbp.putErrors + if g != nil { + return g(ctx, ch.Address()) + } + return tbp.store.Put(ctx, ch) +} + +func TestPutter(t *testing.T) { + t.Parallel() + tcs := []struct { + level redundancy.Level + length int + }{ + {0, 1}, + {1, 1}, + {2, 1}, + {3, 1}, + {4, 1}, + {0, 4096}, + {1, 4096}, + {2, 4096}, + {3, 4096}, + {4, 4096}, + } + for _, tc := range tcs { + t.Run(fmt.Sprintf("redundancy:%d, size:%d", tc.level, tc.length), func(t *testing.T) { + buf := make([]byte, tc.length) + if _, err := io.ReadFull(rand.Reader, buf); err != nil { + t.Fatal(err) + } + ctx := context.Background() + ctx = redundancy.SetLevelInContext(ctx, tc.level) + + ch, err := cac.New(buf) + if err != nil { + t.Fatal(err) + } + store := inmemchunkstore.New() + defer store.Close() + p := replicas.NewPutter(store) + + if err := p.Put(ctx, ch); err != nil { + t.Fatalf("expected no error. 
got %v", err) + } + var addrs []swarm.Address + orig := false + _ = store.Iterate(ctx, func(chunk swarm.Chunk) (stop bool, err error) { + if ch.Address().Equal(chunk.Address()) { + orig = true + return false, nil + } + addrs = append(addrs, chunk.Address()) + return false, nil + }) + if !orig { + t.Fatal("origial chunk missing") + } + t.Run("dispersion", func(t *testing.T) { + if err := dispersed(tc.level, ch, addrs); err != nil { + t.Fatalf("addresses are not dispersed: %v", err) + } + }) + t.Run("attempts", func(t *testing.T) { + count := tc.level.GetReplicaCount() + if len(addrs) != count { + t.Fatalf("incorrect number of attempts. want %v, got %v", count, len(addrs)) + } + }) + + t.Run("replication", func(t *testing.T) { + if err := replicated(store, ch, addrs); err != nil { + t.Fatalf("chunks are not replicas: %v", err) + } + }) + }) + } + t.Run("error handling", func(t *testing.T) { + tcs := []struct { + name string + level redundancy.Level + length int + f func(*testBasePutter) *testBasePutter + err []error + }{ + {"put errors", 4, 4096, func(tbp *testBasePutter) *testBasePutter { + var j int32 + i := &j + atomic.StoreInt32(i, 0) + tbp.putErrors = func(ctx context.Context, _ swarm.Address) error { + j := atomic.AddInt32(i, 1) + if j == 6 { + return errTestA + } + if j == 12 { + return errTestB + } + return nil + } + return tbp + }, []error{errTestA, errTestB}}, + {"put latencies", 4, 4096, func(tbp *testBasePutter) *testBasePutter { + var j int32 + i := &j + atomic.StoreInt32(i, 0) + tbp.putErrors = func(ctx context.Context, _ swarm.Address) error { + j := atomic.AddInt32(i, 1) + if j == 6 { + select { + case <-time.After(100 * time.Millisecond): + case <-ctx.Done(): + return ctx.Err() + } + } + if j == 12 { + return errTestA + } + return nil + } + return tbp + }, []error{errTestA, context.DeadlineExceeded}}, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + buf := make([]byte, tc.length) + if _, err := io.ReadFull(rand.Reader, buf); err 
!= nil { + t.Fatal(err) + } + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) + defer cancel() + ctx = redundancy.SetLevelInContext(ctx, tc.level) + ch, err := cac.New(buf) + if err != nil { + t.Fatal(err) + } + store := inmemchunkstore.New() + defer store.Close() + p := replicas.NewPutter(tc.f(&testBasePutter{store: store})) + errs := p.Put(ctx, ch) + for _, err := range tc.err { + if !errors.Is(errs, err) { + t.Fatalf("incorrect error. want it to contain %v. got %v.", tc.err, errs) + } + } + }) + } + }) + +} diff --git a/pkg/replicas/replica_test.go b/pkg/replicas/replica_test.go new file mode 100644 index 00000000000..ce56401d987 --- /dev/null +++ b/pkg/replicas/replica_test.go @@ -0,0 +1,59 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// replicas_test just contains helper functions to verify dispersion and replication +package replicas_test + +import ( + "context" + "errors" + "fmt" + + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/soc" + "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" +) + +// dispersed verifies that a set of addresses are maximally dispersed without repetition +func dispersed(level redundancy.Level, ch swarm.Chunk, addrs []swarm.Address) error { + nhoods := make(map[byte]bool) + + for _, addr := range addrs { + if len(addr.Bytes()) != swarm.HashSize { + return errors.New("corrupt data: invalid address length") + } + nh := addr.Bytes()[0] >> (8 - int(level)) + if nhoods[nh] { + return errors.New("not dispersed enough: duplicate neighbourhood") + } + nhoods[nh] = true + } + if len(nhoods) != len(addrs) { + return fmt.Errorf("not dispersed enough: unexpected number of neighbourhood covered: want %v. 
got %v", len(addrs), len(nhoods)) + } + + return nil +} + +// replicated verifies that the replica chunks are indeed replicas +// of the original chunk wrapped in soc +func replicated(store storage.ChunkStore, ch swarm.Chunk, addrs []swarm.Address) error { + ctx := context.Background() + for _, addr := range addrs { + chunk, err := store.Get(ctx, addr) + if err != nil { + return err + } + + sch, err := soc.FromChunk(chunk) + if err != nil { + return err + } + if !sch.WrappedChunk().Equal(ch) { + return errors.New("invalid replica: does not wrap original content addressed chunk") + } + } + return nil +} diff --git a/pkg/replicas/replicas.go b/pkg/replicas/replicas.go new file mode 100644 index 00000000000..fd45b28a3b1 --- /dev/null +++ b/pkg/replicas/replicas.go @@ -0,0 +1,124 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package replicas implements a scheme to replicate chunks +// in such a way that +// - the replicas are optimally dispersed to aid cross-neighbourhood redundancy +// - the replicas addresses can be deduced by retrievers only knowing the address +// of the original content addressed chunk +// - no new chunk validation rules are introduced +package replicas + +import ( + "time" + + "github.com/ethersphere/bee/pkg/crypto" + "github.com/ethersphere/bee/pkg/file/redundancy" + "github.com/ethersphere/bee/pkg/swarm" +) + +var ( + // RetryInterval is the duration between successive additional requests + RetryInterval = 300 * time.Millisecond + privKey, _ = crypto.DecodeSecp256k1PrivateKey(append([]byte{1}, make([]byte, 31)...)) + signer = crypto.NewDefaultSigner(privKey) +) + +// replicator running the find for replicas +type replicator struct { + addr []byte // chunk address + queue [16]*replica // to sort addresses according to di + exist [30]bool // maps the 16 distinct nibbles on all levels + sizes [5]int // number of distinct 
neighbourhoods recorded for each depth + c chan *replica + rLevel redundancy.Level +} + +// newReplicator replicator constructor +func newReplicator(addr swarm.Address, rLevel redundancy.Level) *replicator { + rr := &replicator{ + addr: addr.Bytes(), + sizes: redundancy.GetReplicaCounts(), + c: make(chan *replica, 16), + rLevel: rLevel, + } + go rr.replicas() + return rr +} + +// replica of the mined SOC chunk (address) that serve as replicas +type replica struct { + addr, id []byte // byte slice of SOC address and SOC ID +} + +// replicate returns a replica params structure seeded with a byte of entropy as argument +func (rr *replicator) replicate(i uint8) (sp *replica) { + // change the last byte of the address to create SOC ID + id := make([]byte, 32) + copy(id, rr.addr) + id[0] = i + // calculate SOC address for potential replica + h := swarm.NewHasher() + _, _ = h.Write(id) + _, _ = h.Write(swarm.ReplicasOwner) + return &replica{h.Sum(nil), id} +} + +// replicas enumerates replica parameters (SOC ID) pushing it in a channel given as argument +// the order of replicas is so that addresses are always maximally dispersed +// in successive sets of addresses. 
+// I.e., the binary tree representing the new addresses prefix bits up to depth is balanced +func (rr *replicator) replicas() { + defer close(rr.c) + n := 0 + for i := uint8(0); n < rr.rLevel.GetReplicaCount() && i < 255; i++ { + // create soc replica (ID and address using constant owner) + // the soc is added to neighbourhoods of depths in the closed interval [from...to] + r := rr.replicate(i) + d, m := rr.add(r, rr.rLevel) + if d == 0 { + continue + } + for m, r = range rr.queue[n:] { + if r == nil { + break + } + rr.c <- r + } + n += m + } +} + +// add inserts the soc replica into a replicator so that addresses are balanced +func (rr *replicator) add(r *replica, rLevel redundancy.Level) (depth int, rank int) { + if rLevel == redundancy.NONE { + return 0, 0 + } + nh := nh(rLevel, r.addr) + if rr.exist[nh] { + return 0, 0 + } + rr.exist[nh] = true + l, o := rr.add(r, rLevel.Decrement()) + d := uint8(rLevel) - 1 + if l == 0 { + o = rr.sizes[d] + rr.sizes[d]++ + rr.queue[o] = r + l = rLevel.GetReplicaCount() + } + return l, o +} + +// UTILS + +// index bases needed to keep track how many addresses were mined for a level. 
+var replicaIndexBases = [5]int{0, 2, 6, 14} + +// nh returns the lookup key based on the redundancy level +// to be used as index to the replicators exist array +func nh(rLevel redundancy.Level, addr []byte) int { + d := uint8(rLevel) + return replicaIndexBases[d-1] + int(addr[0]>>(8-d)) +} diff --git a/pkg/soc/validator.go b/pkg/soc/validator.go index a707eaded36..db0c388808e 100644 --- a/pkg/soc/validator.go +++ b/pkg/soc/validator.go @@ -5,6 +5,8 @@ package soc import ( + "bytes" + "github.com/ethersphere/bee/pkg/swarm" ) @@ -15,6 +17,11 @@ func Valid(ch swarm.Chunk) bool { return false } + // disperse replica validation + if bytes.Equal(s.owner, swarm.ReplicasOwner) && !bytes.Equal(s.WrappedChunk().Address().Bytes()[1:32], s.id[1:32]) { + return false + } + address, err := s.Address() if err != nil { return false diff --git a/pkg/soc/validator_test.go b/pkg/soc/validator_test.go index 18ed00001a4..695d43cc46f 100644 --- a/pkg/soc/validator_test.go +++ b/pkg/soc/validator_test.go @@ -5,9 +5,13 @@ package soc_test import ( + "crypto/rand" + "io" "strings" "testing" + "github.com/ethersphere/bee/pkg/cac" + "github.com/ethersphere/bee/pkg/crypto" "github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/swarm" ) @@ -31,6 +35,60 @@ func TestValid(t *testing.T) { } } +// TestValidDispersedReplica verifies that the validator can detect +// valid dispersed replicas chunks. +func TestValidDispersedReplica(t *testing.T) { + t.Parallel() + + t.Run("valid", func(t *testing.T) { + privKey, _ := crypto.DecodeSecp256k1PrivateKey(append([]byte{1}, make([]byte, 31)...)) + signer := crypto.NewDefaultSigner(privKey) + + chData := make([]byte, swarm.ChunkSize) + _, _ = io.ReadFull(rand.Reader, chData) + ch, err := cac.New(chData) + if err != nil { + t.Fatal(err) + } + id := append([]byte{1}, ch.Address().Bytes()[1:]...) 
+ + socCh, err := soc.New(id, ch).Sign(signer) + if err != nil { + t.Fatal(err) + } + + // check valid chunk + if !soc.Valid(socCh) { + t.Fatal("dispersed replica chunk is invalid") + } + }) + + t.Run("invalid", func(t *testing.T) { + privKey, _ := crypto.DecodeSecp256k1PrivateKey(append([]byte{1}, make([]byte, 31)...)) + signer := crypto.NewDefaultSigner(privKey) + + chData := make([]byte, swarm.ChunkSize) + _, _ = io.ReadFull(rand.Reader, chData) + ch, err := cac.New(chData) + if err != nil { + t.Fatal(err) + } + id := append([]byte{1}, ch.Address().Bytes()[1:]...) + // change to invalid ID + id[2] += 1 + + socCh, err := soc.New(id, ch).Sign(signer) + if err != nil { + t.Fatal(err) + } + + // check valid chunk + if soc.Valid(socCh) { + t.Fatal("dispersed replica should be invalid") + } + }) +} + // TestInvalid verifies that the validator can detect chunks // with invalid data and invalid address. func TestInvalid(t *testing.T) { diff --git a/pkg/steward/steward.go b/pkg/steward/steward.go index 9f7bf0641a2..73b3d17b87f 100644 --- a/pkg/steward/steward.go +++ b/pkg/steward/steward.go @@ -36,11 +36,11 @@ type steward struct { netTraverser traversal.Traverser } -func New(ns storer.NetStore, r retrieval.Interface) Interface { +func New(ns storer.NetStore, r retrieval.Interface, joinerPutter storage.Putter) Interface { return &steward{ netStore: ns, - traverser: traversal.New(ns.Download(true)), - netTraverser: traversal.New(&netGetter{r}), + traverser: traversal.New(ns.Download(true), joinerPutter), + netTraverser: traversal.New(&netGetter{r}, joinerPutter), } } diff --git a/pkg/steward/steward_test.go b/pkg/steward/steward_test.go index 6644c86d69b..fe3ccd05ba9 100644 --- a/pkg/steward/steward_test.go +++ b/pkg/steward/steward_test.go @@ -10,10 +10,12 @@ import ( "crypto/rand" "errors" "sync" + "sync/atomic" "testing" "time" "github.com/ethersphere/bee/pkg/file/pipeline/builder" + "github.com/ethersphere/bee/pkg/file/redundancy" postagetesting 
"github.com/ethersphere/bee/pkg/postage/mock" "github.com/ethersphere/bee/pkg/steward" storage "github.com/ethersphere/bee/pkg/storage" @@ -22,19 +24,31 @@ import ( "github.com/ethersphere/bee/pkg/swarm" ) +type counter struct { + storage.ChunkStore + count atomic.Int32 +} + +func (c *counter) Put(ctx context.Context, ch swarm.Chunk) (err error) { + c.count.Add(1) + return c.ChunkStore.Put(ctx, ch) +} + func TestSteward(t *testing.T) { t.Parallel() + inmem := &counter{ChunkStore: inmemchunkstore.New()} var ( ctx = context.Background() chunks = 1000 data = make([]byte, chunks*4096) //1k chunks - chunkStore = inmemchunkstore.New() + chunkStore = inmem store = mockstorer.NewWithChunkStore(chunkStore) localRetrieval = &localRetriever{ChunkStore: chunkStore} - s = steward.New(store, localRetrieval) + s = steward.New(store, localRetrieval, inmem) stamper = postagetesting.NewStamper() ) + ctx = redundancy.SetLevelInContext(ctx, redundancy.NONE) n, err := rand.Read(data) if n != cap(data) { @@ -44,21 +58,13 @@ func TestSteward(t *testing.T) { t.Fatal(err) } - pipe := builder.NewPipelineBuilder(ctx, chunkStore, false) + pipe := builder.NewPipelineBuilder(ctx, chunkStore, false, 0) addr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data)) if err != nil { t.Fatal(err) } - chunkCount := 0 - err = chunkStore.Iterate(context.Background(), func(ch swarm.Chunk) (bool, error) { - chunkCount++ - return false, nil - }) - if err != nil { - t.Fatalf("failed iterating: %v", err) - } - + chunkCount := int(inmem.count.Load()) done := make(chan struct{}) errc := make(chan error, 1) go func() { diff --git a/pkg/storer/mock/forgetting.go b/pkg/storer/mock/forgetting.go new file mode 100644 index 00000000000..5588b5c263e --- /dev/null +++ b/pkg/storer/mock/forgetting.go @@ -0,0 +1,128 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package mockstorer + +import ( + "context" + "sync" + "sync/atomic" + "time" + + storage "github.com/ethersphere/bee/pkg/storage" + "github.com/ethersphere/bee/pkg/swarm" +) + +type DelayedStore struct { + storage.ChunkStore + cache map[string]time.Duration + mu sync.Mutex +} + +func NewDelayedStore(s storage.ChunkStore) *DelayedStore { + return &DelayedStore{ + ChunkStore: s, + cache: make(map[string]time.Duration), + } +} + +func (d *DelayedStore) Delay(addr swarm.Address, delay time.Duration) { + d.mu.Lock() + defer d.mu.Unlock() + d.cache[addr.String()] = delay +} + +func (d *DelayedStore) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) { + d.mu.Lock() + defer d.mu.Unlock() + if delay, ok := d.cache[addr.String()]; ok && delay > 0 { + select { + case <-time.After(delay): + delete(d.cache, addr.String()) + case <-ctx.Done(): + return nil, ctx.Err() + } + } + return d.ChunkStore.Get(ctx, addr) +} + +type ForgettingStore struct { + storage.ChunkStore + record atomic.Bool + mu sync.Mutex + n atomic.Int64 + missed map[string]struct{} +} + +func NewForgettingStore(s storage.ChunkStore) *ForgettingStore { + return &ForgettingStore{ChunkStore: s, missed: make(map[string]struct{})} +} + +func (f *ForgettingStore) Stored() int64 { + return f.n.Load() +} + +func (f *ForgettingStore) Record() { + f.record.Store(true) +} + +func (f *ForgettingStore) Unrecord() { + f.record.Store(false) +} + +func (f *ForgettingStore) Miss(addr swarm.Address) { + f.mu.Lock() + defer f.mu.Unlock() + f.miss(addr) +} + +func (f *ForgettingStore) Unmiss(addr swarm.Address) { + f.mu.Lock() + defer f.mu.Unlock() + f.unmiss(addr) +} + +func (f *ForgettingStore) miss(addr swarm.Address) { + f.missed[addr.String()] = struct{}{} +} + +func (f *ForgettingStore) unmiss(addr swarm.Address) { + delete(f.missed, addr.String()) +} + +func (f *ForgettingStore) isMiss(addr swarm.Address) bool { + _, ok := f.missed[addr.String()] + return ok +} + +func (f *ForgettingStore) Reset() { 
+ f.missed = make(map[string]struct{}) +} + +func (f *ForgettingStore) Missed() int { + return len(f.missed) +} + +// Get implements the ChunkStore interface. +// if in recording phase, record the chunk address as miss and returns Get on the embedded store +// if in forgetting phase, returns ErrNotFound if the chunk address is recorded as miss +func (f *ForgettingStore) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) { + f.mu.Lock() + defer f.mu.Unlock() + if f.record.Load() { + f.miss(addr) + } else if f.isMiss(addr) { + return nil, storage.ErrNotFound + } + return f.ChunkStore.Get(ctx, addr) +} + +// Put implements the ChunkStore interface. +func (f *ForgettingStore) Put(ctx context.Context, ch swarm.Chunk) (err error) { + f.n.Add(1) + if !f.record.Load() { + f.Unmiss(ch.Address()) + } + return f.ChunkStore.Put(ctx, ch) +} diff --git a/pkg/swarm/swarm.go b/pkg/swarm/swarm.go index 7e8d514b095..b69b4c034bf 100644 --- a/pkg/swarm/swarm.go +++ b/pkg/swarm/swarm.go @@ -37,6 +37,12 @@ var ( ErrInvalidChunk = errors.New("invalid chunk") ) +var ( + // Ethereum Address for SOC owner of Dispersed Replicas + // generated from private key 0x0100000000000000000000000000000000000000000000000000000000000000 + ReplicasOwner, _ = hex.DecodeString("dc5b20847f43d67928f49cd4f85d696b5a7617b5") +) + var ( // EmptyAddress is the address that is all zeroes. EmptyAddress = NewAddress(make([]byte, HashSize)) diff --git a/pkg/traversal/traversal.go b/pkg/traversal/traversal.go index e9609a90c8f..dbdb5603e82 100644 --- a/pkg/traversal/traversal.go +++ b/pkg/traversal/traversal.go @@ -29,19 +29,20 @@ type Traverser interface { } // New constructs for a new Traverser. -func New(store storage.Getter) Traverser { - return &service{store: store} +func New(getter storage.Getter, putter storage.Putter) Traverser { + return &service{getter: getter, putter: putter} } // service is implementation of Traverser using storage.Storer as its storage. 
type service struct { - store storage.Getter + getter storage.Getter + putter storage.Putter } // Traverse implements Traverser.Traverse method. func (s *service) Traverse(ctx context.Context, addr swarm.Address, iterFn swarm.AddressIterFunc) error { processBytes := func(ref swarm.Address) error { - j, _, err := joiner.New(ctx, s.store, ref) + j, _, err := joiner.New(ctx, s.getter, s.putter, ref) if err != nil { return fmt.Errorf("traversal: joiner error on %q: %w", ref, err) } @@ -54,7 +55,7 @@ func (s *service) Traverse(ctx context.Context, addr swarm.Address, iterFn swarm // skip SOC check for encrypted references if addr.IsValidLength() { - ch, err := s.store.Get(ctx, addr) + ch, err := s.getter.Get(ctx, addr) if err != nil { return fmt.Errorf("traversal: failed to get root chunk %s: %w", addr.String(), err) } @@ -64,7 +65,7 @@ func (s *service) Traverse(ctx context.Context, addr swarm.Address, iterFn swarm } } - ls := loadsave.NewReadonly(s.store) + ls := loadsave.NewReadonly(s.getter) switch mf, err := manifest.NewDefaultManifestReference(addr, ls); { case errors.Is(err, manifest.ErrInvalidManifestType): break diff --git a/pkg/traversal/traversal_test.go b/pkg/traversal/traversal_test.go index dc8a78fcb53..7d9d473e1e4 100644 --- a/pkg/traversal/traversal_test.go +++ b/pkg/traversal/traversal_test.go @@ -161,13 +161,13 @@ func TestTraversalBytes(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - pipe := builder.NewPipelineBuilder(ctx, storerMock, false) + pipe := builder.NewPipelineBuilder(ctx, storerMock, false, 0) address, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data)) if err != nil { t.Fatal(err) } - err = traversal.New(storerMock).Traverse(ctx, address, iter.Next) + err = traversal.New(storerMock, storerMock).Traverse(ctx, address, iter.Next) if err != nil { t.Fatal(err) } @@ -256,13 +256,13 @@ func TestTraversalFiles(t *testing.T) { ctx, cancel := 
context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - pipe := builder.NewPipelineBuilder(ctx, storerMock, false) + pipe := builder.NewPipelineBuilder(ctx, storerMock, false, 0) fr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data)) if err != nil { t.Fatal(err) } - ls := loadsave.New(storerMock, pipelineFactory(storerMock, false)) + ls := loadsave.New(storerMock, storerMock, pipelineFactory(storerMock, false)) fManifest, err := manifest.NewDefaultManifest(ls, false) if err != nil { t.Fatal(err) @@ -294,7 +294,7 @@ func TestTraversalFiles(t *testing.T) { t.Fatal(err) } - err = traversal.New(storerMock).Traverse(ctx, address, iter.Next) + err = traversal.New(storerMock, storerMock).Traverse(ctx, address, iter.Next) if err != nil { t.Fatal(err) } @@ -421,7 +421,7 @@ func TestTraversalManifest(t *testing.T) { } wantHashes = append(wantHashes, tc.manifestHashes...) - ls := loadsave.New(storerMock, pipelineFactory(storerMock, false)) + ls := loadsave.New(storerMock, storerMock, pipelineFactory(storerMock, false)) dirManifest, err := manifest.NewMantarayManifest(ls, false) if err != nil { t.Fatal(err) @@ -430,7 +430,7 @@ func TestTraversalManifest(t *testing.T) { for _, f := range tc.files { data := generateSample(f.size) - pipe := builder.NewPipelineBuilder(ctx, storerMock, false) + pipe := builder.NewPipelineBuilder(ctx, storerMock, false, 0) fr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data)) if err != nil { t.Fatal(err) @@ -452,7 +452,7 @@ func TestTraversalManifest(t *testing.T) { t.Fatal(err) } - err = traversal.New(storerMock).Traverse(ctx, address, iter.Next) + err = traversal.New(storerMock, storerMock).Traverse(ctx, address, iter.Next) if err != nil { t.Fatal(err) } @@ -490,7 +490,7 @@ func TestTraversalSOC(t *testing.T) { t.Fatal(err) } - err = traversal.New(store).Traverse(ctx, sch.Address(), iter.Next) + err = traversal.New(store, store).Traverse(ctx, sch.Address(), iter.Next) if err != nil { t.Fatal(err) } @@ 
-506,6 +506,6 @@ func TestTraversalSOC(t *testing.T) { func pipelineFactory(s storage.Putter, encrypt bool) func() pipeline.Interface { return func() pipeline.Interface { - return builder.NewPipelineBuilder(context.Background(), s, encrypt) + return builder.NewPipelineBuilder(context.Background(), s, encrypt, 0) } } diff --git a/pkg/util/testutil/pseudorand/reader.go b/pkg/util/testutil/pseudorand/reader.go new file mode 100644 index 00000000000..d7fcf0c3646 --- /dev/null +++ b/pkg/util/testutil/pseudorand/reader.go @@ -0,0 +1,182 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// this is a pseudorandom reader that generates a deterministic +// sequence of bytes based on the seed. It is used in tests to +// enable large volumes of pseudorandom data to be generated +// and compared without having to store the data in memory. +package pseudorand + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/ethersphere/bee/pkg/swarm" +) + +const bufSize = 4096 + +// Reader is a pseudorandom reader that generates a deterministic +// sequence of bytes based on the seed. +type Reader struct { + cur int + len int + seg [40]byte + buf [bufSize]byte +} + +// NewSeed creates a new seed. +func NewSeed() ([]byte, error) { + seed := make([]byte, 32) + _, err := io.ReadFull(rand.Reader, seed) + return seed, err +} + +// New creates a new pseudorandom reader seeded with the given seed. +func NewReader(seed []byte, l int) *Reader { + r := &Reader{len: l} + _ = copy(r.buf[8:], seed) + r.fill() + return r +} + +// Size returns the size of the reader. +func (r *Reader) Size() int { + return r.len +} + +// Read reads len(buf) bytes into buf. It returns the number of bytes read (0 <= n <= len(buf)) and any error encountered. Even if Read returns n < len(buf), it may use all of buf as scratch space during the call. 
If some data is available but not len(buf) bytes, Read conventionally returns what is available instead of waiting for more. +func (r *Reader) Read(buf []byte) (n int, err error) { + cur := r.cur % bufSize + toRead := min(bufSize-cur, r.len-r.cur) + if toRead < len(buf) { + buf = buf[:toRead] + } + n = copy(buf, r.buf[cur:]) + r.cur += n + if r.cur == r.len { + return n, io.EOF + } + if r.cur%bufSize == 0 { + r.fill() + } + return n, nil +} + +// Equal compares the contents of the reader with the contents of +// the given reader. It returns true if the contents are equal upto n bytes +func (r1 *Reader) Equal(r2 io.Reader) (bool, error) { + ok, err := r1.Match(r2, r1.len) + if err != nil { + return false, err + } + if !ok { + return false, nil + } + n, err := io.ReadFull(r2, make([]byte, 1)) + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return n == 0, nil + } + return false, err +} + +// Match compares the contents of the reader with the contents of +// the given reader. 
It returns true if the contents are equal up to n bytes +func (r1 *Reader) Match(r2 io.Reader, l int) (bool, error) { + + read := func(r io.Reader, buf []byte) (n int, err error) { + for n < len(buf) && err == nil { + i, e := r.Read(buf[n:]) + if e == nil && i == 0 { + return n, nil + } + err = e + n += i + } + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + err = nil + } + return n, err + } + + buf1 := make([]byte, bufSize) + buf2 := make([]byte, bufSize) + for l > 0 { + if l <= len(buf1) { + buf1 = buf1[:l] + buf2 = buf2[:l] + } + + n1, err := read(r1, buf1) + if err != nil { + return false, err + } + n2, err := read(r2, buf2) + if err != nil { + return false, err + } + if n1 != n2 { + return false, nil + } + if !bytes.Equal(buf1[:n1], buf2[:n2]) { + return false, nil + } + l -= n1 + } + return true, nil +} + +// Seek sets the offset for the next Read to offset, interpreted +// according to whence: 0 means relative to the start of the file, +// 1 means relative to the current offset, and 2 means relative to +// the end. It returns the new offset and an error, if any. +func (r *Reader) Seek(offset int64, whence int) (int64, error) { + switch whence { + case 0: + r.cur = int(offset) + case 1: + r.cur += int(offset) + case 2: + r.cur = r.len - int(offset) + } + if r.cur < 0 || r.cur > r.len { + return 0, fmt.Errorf("seek: invalid offset %d", r.cur) + } + r.fill() + return int64(r.cur), nil +} + +// Offset returns the current offset of the reader. +func (r *Reader) Offset() int64 { + return int64(r.cur) +} + +// ReadAt reads len(buf) bytes into buf starting at offset off. +func (r *Reader) ReadAt(buf []byte, off int64) (n int, err error) { + if _, err := r.Seek(off, io.SeekStart); err != nil { + return 0, err + } + return r.Read(buf) +} + +// fill fills the buffer with the hash of the current segment. 
+func (r *Reader) fill() { + if r.cur >= r.len { + return + } + bufSegments := bufSize / 32 + start := r.cur / bufSegments + rem := (r.cur % bufSize) / 32 + h := swarm.NewHasher() + for i := 32 * rem; i < len(r.buf); i += 32 { + binary.BigEndian.PutUint64(r.seg[:], uint64((start+i)/32)) + h.Reset() + _, _ = h.Write(r.seg[:]) + copy(r.buf[i:], h.Sum(nil)) + } +} diff --git a/pkg/util/testutil/pseudorand/reader_test.go b/pkg/util/testutil/pseudorand/reader_test.go new file mode 100644 index 00000000000..4ec85a90d73 --- /dev/null +++ b/pkg/util/testutil/pseudorand/reader_test.go @@ -0,0 +1,121 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pseudorand_test + +import ( + "bytes" + "errors" + "fmt" + "io" + "math/rand" + "testing" + + "github.com/ethersphere/bee/pkg/util/testutil/pseudorand" +) + +func TestReader(t *testing.T) { + size := 42000 + seed := make([]byte, 32) + r := pseudorand.NewReader(seed, size) + content, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + t.Run("deterministicity", func(t *testing.T) { + r2 := pseudorand.NewReader(seed, size) + content2, err := io.ReadAll(r2) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(content, content2) { + t.Fatal("content mismatch") + } + }) + t.Run("randomness", func(t *testing.T) { + bufSize := 4096 + if bytes.Equal(content[:bufSize], content[bufSize:2*bufSize]) { + t.Fatal("buffers should not match") + } + }) + t.Run("re-readability", func(t *testing.T) { + ns, err := r.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + if ns != 0 { + t.Fatal("seek mismatch") + } + var read []byte + buf := make([]byte, 8200) + total := 0 + for { + s := rand.Intn(820) + n, err := r.Read(buf[:s]) + total += n + read = append(read, buf[:n]...) 
+ if errors.Is(err, io.EOF) { + break + } + if err != nil { + t.Fatal(err) + } + } + read = read[:total] + if len(read) != len(content) { + t.Fatal("content length mismatch. expected", len(content), "got", len(read)) + } + if !bytes.Equal(content, read) { + t.Fatal("content mismatch") + } + }) + t.Run("comparison", func(t *testing.T) { + ns, err := r.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + if ns != 0 { + t.Fatal("seek mismatch") + } + if eq, err := r.Equal(bytes.NewBuffer(content)); err != nil { + t.Fatal(err) + } else if !eq { + t.Fatal("content mismatch") + } + ns, err = r.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + if ns != 0 { + t.Fatal("seek mismatch") + } + if eq, err := r.Equal(bytes.NewBuffer(content[:len(content)-1])); err != nil { + t.Fatal(err) + } else if eq { + t.Fatal("content should match") + } + }) + t.Run("seek and match", func(t *testing.T) { + for i := 0; i < 20; i++ { + off := rand.Intn(size) + n := rand.Intn(size - off) + t.Run(fmt.Sprintf("off=%d n=%d", off, n), func(t *testing.T) { + ns, err := r.Seek(int64(off), io.SeekStart) + if err != nil { + t.Fatal(err) + } + if ns != int64(off) { + t.Fatal("seek mismatch") + } + ok, err := r.Match(bytes.NewBuffer(content[off:off+n]), n) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("content mismatch") + } + }) + } + }) +} diff --git a/pkg/util/testutil/racedetection/off.go b/pkg/util/testutil/racedetection/off.go new file mode 100644 index 00000000000..d57125bfd03 --- /dev/null +++ b/pkg/util/testutil/racedetection/off.go @@ -0,0 +1,10 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +//go:build !race +// +build !race + +package racedetection + +const On = false diff --git a/pkg/util/testutil/racedetection/on.go b/pkg/util/testutil/racedetection/on.go new file mode 100644 index 00000000000..92438ecfdc3 --- /dev/null +++ b/pkg/util/testutil/racedetection/on.go @@ -0,0 +1,10 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build race +// +build race + +package racedetection + +const On = true diff --git a/pkg/util/testutil/racedetection/race.go b/pkg/util/testutil/racedetection/race.go new file mode 100644 index 00000000000..411295e87f6 --- /dev/null +++ b/pkg/util/testutil/racedetection/race.go @@ -0,0 +1,9 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package racedetection + +func IsOn() bool { + return On +}