diff --git a/openapi/SwarmCommon.yaml b/openapi/SwarmCommon.yaml index 4a7faadc154..07119cc6894 100644 --- a/openapi/SwarmCommon.yaml +++ b/openapi/SwarmCommon.yaml @@ -954,7 +954,7 @@ components: required: false description: > Specify the retrieve strategy on redundant data. - The mumbers stand for NONE, DATA, PROX and RACE, respectively. + The numbers stand for NONE, DATA, PROX and RACE, respectively. Strategy NONE means no prefetching takes place. Strategy DATA means only data chunks are prefetched. Strategy PROX means only chunks that are close to the node are prefetched. diff --git a/pkg/api/api.go b/pkg/api/api.go index 26d11213e5a..470174c68d4 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -83,7 +83,8 @@ const ( SwarmRedundancyLevelHeader = "Swarm-Redundancy-Level" SwarmRedundancyStrategyHeader = "Swarm-Redundancy-Strategy" SwarmRedundancyFallbackModeHeader = "Swarm-Redundancy-Fallback-Mode" - SwarmChunkRetrievalTimeoutHeader = "Swarm-Chunk-Retrieval-Timeout-Level" + SwarmChunkRetrievalTimeoutHeader = "Swarm-Chunk-Retrieval-Timeout" + SwarmLookAheadBufferSizeHeader = "Swarm-Lookahead-Buffer-Size" ImmutableHeader = "Immutable" GasPriceHeader = "Gas-Price" @@ -99,18 +100,6 @@ const ( OriginHeader = "Origin" ) -// The size of buffer used for prefetching content with Langos. -// Warning: This value influences the number of chunk requests and chunker join goroutines -// per file request. -// Recommended value is 8 or 16 times the io.Copy default buffer value which is 32kB, depending -// on the file size. Use lookaheadBufferSize() to get the correct buffer size for the request. -const ( - smallFileBufferSize = 8 * 32 * 1024 - largeFileBufferSize = 16 * 32 * 1024 - - largeBufferFilesizeThreshold = 10 * 1000000 // ten megs -) - const ( multiPartFormData = "multipart/form-data" contentTypeTar = "application/x-tar" @@ -615,13 +604,6 @@ func (s *Service) gasConfigMiddleware(handlerName string) func(h http.Handler) h } } -func lookaheadBufferSize(size int64) int { - if size <= largeBufferFilesizeThreshold { - return smallFileBufferSize - } - return largeFileBufferSize -} - // corsHandler sets CORS headers to HTTP response if allowed origins are configured. func (s *Service) corsHandler(h http.Handler) http.Handler { allowedHeaders := []string{ diff --git a/pkg/api/bzz.go b/pkg/api/bzz.go index 3f8cfbac1f6..67a80a58918 100644 --- a/pkg/api/bzz.go +++ b/pkg/api/bzz.go @@ -35,6 +35,25 @@ import ( "github.com/gorilla/mux" ) +// The size of buffer used for prefetching content with Langos when not using erasure coding +// Warning: This value influences the number of chunk requests and chunker join goroutines +// per file request. +// Recommended value is 8 or 16 times the io.Copy default buffer value which is 32kB, depending +// on the file size. Use lookaheadBufferSize() to get the correct buffer size for the request. 
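+//
+// For illustration only (a hypothetical client-side sketch, not part of this
+// change): the prefetching buffer can now be overridden per request via the
+// Swarm-Lookahead-Buffer-Size header introduced in this PR, e.g.
+//
+//	req, _ := http.NewRequest(http.MethodGet, "http://localhost:1633/bzz/<reference>/", nil)
+//	req.Header.Set("Swarm-Lookahead-Buffer-Size", "0") // 0 disables lookahead buffering entirely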
+const ( + smallFileBufferSize = 8 * 32 * 1024 + largeFileBufferSize = 16 * 32 * 1024 + + largeBufferFilesizeThreshold = 10 * 1000000 // ten megs +) + +func lookaheadBufferSize(size int64) int { + if size <= largeBufferFilesizeThreshold { + return smallFileBufferSize + } + return largeFileBufferSize +} + func (s *Service) bzzUploadHandler(w http.ResponseWriter, r *http.Request) { logger := tracing.NewLoggerWithTraceID(r.Context(), s.logger.WithName("post_bzz").Build()) @@ -272,8 +291,13 @@ func (s *Service) serveReference(logger log.Logger, address swarm.Address, pathV loggerV1 := logger.V(1).Build() headers := struct { - Cache *bool `map:"Swarm-Cache"` + Cache *bool `map:"Swarm-Cache"` + Strategy getter.Strategy `map:"Swarm-Redundancy-Strategy"` + FallbackMode bool `map:"Swarm-Redundancy-Fallback-Mode"` + ChunkRetrievalTimeout string `map:"Swarm-Chunk-Retrieval-Timeout"` + LookaheadBufferSize *string `map:"Swarm-Lookahead-Buffer-Size"` }{} + if response := s.mapStructure(r.Header, &headers); response != nil { response("invalid header params", logger, w) return @@ -282,10 +306,12 @@ func (s *Service) serveReference(logger log.Logger, address swarm.Address, pathV if headers.Cache != nil { cache = *headers.Cache } + ls := loadsave.NewReadonly(s.storer.Download(cache)) feedDereferenced := false ctx := r.Context() + ctx = getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, getter.DefaultStrategyTimeout.String()) FETCH: // read manifest entry @@ -366,7 +392,6 @@ FETCH: jsonhttp.NotFound(w, "address not found or incorrect") return } - me, err := m.Lookup(ctx, pathVar) if err != nil { loggerV1.Debug("bzz download: invalid path", "address", address, "path", pathVar, "error", err) @@ -459,8 +484,10 @@ func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *h Cache *bool `map:"Swarm-Cache"` Strategy getter.Strategy `map:"Swarm-Redundancy-Strategy"` FallbackMode bool `map:"Swarm-Redundancy-Fallback-Mode"` - ChunkRetrievalTimeout time.Duration `map:"Swarm-Chunk-Retrieval-Timeout"` + ChunkRetrievalTimeout string `map:"Swarm-Chunk-Retrieval-Timeout"` + LookaheadBufferSize *string `map:"Swarm-Lookahead-Buffer-Size"` }{} + if response := s.mapStructure(r.Header, &headers); response != nil { response("invalid header params", logger, w) return @@ -471,9 +498,7 @@ func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *h } ctx := r.Context() - ctx = getter.SetStrategy(ctx, headers.Strategy) - ctx = getter.SetStrict(ctx, headers.FallbackMode) - ctx = getter.SetFetchTimeout(ctx, headers.ChunkRetrievalTimeout) + ctx = getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, getter.DefaultStrategyTimeout.String()) reader, l, err := joiner.New(ctx, s.storer.Download(cache), s.storer.Cache(), reference) if err != nil { if errors.Is(err, storage.ErrNotFound) || errors.Is(err, topology.ErrNotFound) { @@ -497,7 +522,19 @@ func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *h } w.Header().Set(ContentLengthHeader, strconv.FormatInt(l, 10)) w.Header().Set("Access-Control-Expose-Headers", ContentDispositionHeader) - http.ServeContent(w, r, "", time.Now(), langos.NewBufferedLangos(reader, lookaheadBufferSize(l))) + bufSize := int64(lookaheadBufferSize(l)) + if headers.LookaheadBufferSize != nil { + bufSize, err = strconv.ParseInt(*headers.LookaheadBufferSize, 10, 64) + if err != nil { + logger.Debug("parsing lookahead buffer size", "error", err) + bufSize = 
0
+		}
+	}
+	if bufSize > 0 {
+		http.ServeContent(w, r, "", time.Now(), langos.NewBufferedLangos(reader, int(bufSize)))
+		return
+	}
+	http.ServeContent(w, r, "", time.Now(), reader)
 }
 
 // manifestMetadataLoad returns the value for a key stored in the metadata of
diff --git a/pkg/api/bzz_test.go b/pkg/api/bzz_test.go
index d2ad44e507c..634d4eb0166 100644
--- a/pkg/api/bzz_test.go
+++ b/pkg/api/bzz_test.go
@@ -16,20 +16,223 @@ import (
 	"strconv"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/ethersphere/bee/pkg/api"
 	"github.com/ethersphere/bee/pkg/file/loadsave"
+	"github.com/ethersphere/bee/pkg/file/redundancy"
 	"github.com/ethersphere/bee/pkg/jsonhttp"
 	"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
 	"github.com/ethersphere/bee/pkg/log"
 	"github.com/ethersphere/bee/pkg/manifest"
 	mockbatchstore "github.com/ethersphere/bee/pkg/postage/batchstore/mock"
 	mockpost "github.com/ethersphere/bee/pkg/postage/mock"
+	"github.com/ethersphere/bee/pkg/storage/inmemchunkstore"
 	mockstorer "github.com/ethersphere/bee/pkg/storer/mock"
 	"github.com/ethersphere/bee/pkg/swarm"
+	"github.com/ethersphere/bee/pkg/util/testutil/pseudorand"
 )
 
-// nolint:paralleltest,tparallel
+// nolint:paralleltest,tparallel,thelper
+
+// TestBzzUploadDownloadWithRedundancy tests the API for upload and download of files
+// with all combinations of redundancy level, encryption and size (levels, i.e., the
+// height of the swarm hash tree).
+//
+// This is a variation on the same play as TestJoinerRedundancy,
+// but here the tested scenario is simplified since we are not testing the intricacies of
+// download strategies, but only correct parameter passing and correct recovery functionality.
+//
+// The test cases have the following structure:
+//
+// 1. upload a file with a given redundancy level and encryption
+//
+// 2. [positive test] download the file by the reference returned by the upload API response
+// This uses range queries to target specific (number of) chunks of the file structure.
+// During path traversal in the swarm hash tree, the underlying mockstore (forgetting)
+// is in 'recording' mode, flagging all the retrieved chunks as chunks to forget.
+// This is to simulate the scenario where some of the chunks are not available/lost.
+// NOTE: For this to work one needs to switch off the lookahead buffer functionality
+// (see langos pkg).
+//
+// 3. [negative test] attempt at downloading the file using once again the same root hash
+// and the same redundancy strategy to find the file inaccessible after forgetting.
+//
+// 4. [positive test] attempt at downloading the file using a strategy that allows for
+// using redundancy to reconstruct the file and find the file recoverable.
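+//
+// As an illustrative sketch of step 2 (hypothetical usage; the names are taken
+// from the forgetting store added in this change under pkg/storer/mock):
+//
+//	store := mockstorer.NewForgettingStore(inmemchunkstore.New())
+//	store.Record()   // every chunk retrieved from now on is flagged as a miss
+//	// ...perform the ranged download...
+//	store.Unrecord() // flagged chunks now behave as lost: Get returns storage.ErrNotFound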
+// +// nolint:thelper +func TestBzzUploadDownloadWithRedundancy(t *testing.T) { + fileUploadResource := "/bzz" + fileDownloadResource := func(addr string) string { return "/bzz/" + addr + "/" } + + testRedundancy := func(t *testing.T, rLevel redundancy.Level, encrypt bool, levels int, chunkCnt int, shardCnt int, parityCnt int) { + t.Helper() + seed, err := pseudorand.NewSeed() + if err != nil { + t.Fatal(err) + } + fetchTimeout := 100 * time.Millisecond + store := mockstorer.NewForgettingStore(inmemchunkstore.New()) + storerMock := mockstorer.NewWithChunkStore(store) + client, _, _, _ := newTestServer(t, testServerOptions{ + Storer: storerMock, + Logger: log.Noop, + Post: mockpost.New(mockpost.WithAcceptAll()), + }) + + dataReader := pseudorand.NewReader(seed, chunkCnt*swarm.ChunkSize) + + var refResponse api.BzzUploadResponse + jsonhttptest.Request(t, client, http.MethodPost, fileUploadResource, + http.StatusCreated, + jsonhttptest.WithRequestHeader(api.SwarmDeferredUploadHeader, "True"), + jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, batchOkStr), + jsonhttptest.WithRequestBody(dataReader), + jsonhttptest.WithRequestHeader(api.SwarmEncryptHeader, fmt.Sprintf("%t", encrypt)), + jsonhttptest.WithRequestHeader(api.SwarmRedundancyLevelHeader, fmt.Sprintf("%d", rLevel)), + jsonhttptest.WithRequestHeader(api.ContentTypeHeader, "image/jpeg; charset=utf-8"), + jsonhttptest.WithUnmarshalJSONResponse(&refResponse), + ) + + t.Run("download multiple ranges without redundancy should succeed", func(t *testing.T) { + // the underlying chunk store is in recording mode, so all chunks retrieved + // in this test will be forgotten in the subsequent ones. + store.Record() + defer store.Unrecord() + // we intend to forget as many chunks as possible for the given redundancy level + forget := parityCnt + if parityCnt > shardCnt { + forget = shardCnt + } + if levels == 1 { + forget = 2 + } + start, end := 420, 450 + gap := swarm.ChunkSize + for j := 2; j < levels; j++ { + gap *= shardCnt + } + ranges := make([][2]int, forget) + for i := 0; i < forget; i++ { + pre := i * gap + ranges[i] = [2]int{pre + start, pre + end} + } + rangeHeader, want := createRangeHeader(dataReader, ranges) + + var body []byte + respHeaders := jsonhttptest.Request(t, client, http.MethodGet, + fileDownloadResource(refResponse.Reference.String()), + http.StatusPartialContent, + jsonhttptest.WithRequestHeader(api.RangeHeader, rangeHeader), + jsonhttptest.WithRequestHeader(api.SwarmLookAheadBufferSizeHeader, "0"), + // set for the replicas so that no replica gets deleted + jsonhttptest.WithRequestHeader(api.SwarmRedundancyLevelHeader, "0"), + jsonhttptest.WithRequestHeader(api.SwarmRedundancyStrategyHeader, "0"), + jsonhttptest.WithRequestHeader(api.SwarmRedundancyFallbackModeHeader, "false"), + jsonhttptest.WithRequestHeader(api.SwarmChunkRetrievalTimeoutHeader, fetchTimeout.String()), + jsonhttptest.WithPutResponseBody(&body), + ) + + got := parseRangeParts(t, respHeaders.Get(api.ContentTypeHeader), body) + + if len(got) != len(want) { + t.Fatalf("got %v parts, want %v parts", len(got), len(want)) + } + for i := 0; i < len(want); i++ { + if !bytes.Equal(got[i], want[i]) { + t.Errorf("part %v: got %q, want %q", i, string(got[i]), string(want[i])) + } + } + }) + + t.Run("download without redundancy should NOT succeed", func(t *testing.T) { + if rLevel == 0 { + t.Skip("NA") + } + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, "GET", 
fileDownloadResource(refResponse.Reference.String()), nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			req.Header.Set(api.SwarmRedundancyStrategyHeader, "0")
+			req.Header.Set(api.SwarmRedundancyFallbackModeHeader, "false")
+			req.Header.Set(api.SwarmChunkRetrievalTimeoutHeader, fetchTimeout.String())
+
+			_, err = client.Do(req)
+			if !errors.Is(err, context.DeadlineExceeded) {
+				t.Fatalf("expected error %v; got %v", context.DeadlineExceeded, err)
+			}
+		})
+
+		t.Run("download with redundancy should succeed", func(t *testing.T) {
+			req, err := http.NewRequestWithContext(context.TODO(), "GET", fileDownloadResource(refResponse.Reference.String()), nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			req.Header.Set(api.SwarmRedundancyStrategyHeader, "3")
+			req.Header.Set(api.SwarmRedundancyFallbackModeHeader, "true")
+			req.Header.Set(api.SwarmChunkRetrievalTimeoutHeader, fetchTimeout.String())
+
+			resp, err := client.Do(req)
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer resp.Body.Close()
+
+			if resp.StatusCode != http.StatusOK {
+				t.Fatalf("expected status %d; got %d", http.StatusOK, resp.StatusCode)
+			}
+			_, err = dataReader.Seek(0, io.SeekStart)
+			if err != nil {
+				t.Fatal(err)
+			}
+			ok, err := dataReader.Equal(resp.Body)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !ok {
+				t.Fatalf("content mismatch")
+			}
+		})
+	}
+	for _, rLevel := range []redundancy.Level{1, 2, 3, 4} {
+		rLevel := rLevel
+		t.Run(fmt.Sprintf("level=%d", rLevel), func(t *testing.T) {
+			for _, encrypt := range []bool{false, true} {
+				encrypt := encrypt
+				shardCnt := rLevel.GetMaxShards()
+				parityCnt := rLevel.GetParities(shardCnt)
+				if encrypt {
+					shardCnt = rLevel.GetMaxEncShards()
+					parityCnt = rLevel.GetEncParities(shardCnt)
+				}
+				for _, levels := range []int{1, 2, 3} {
+					chunkCnt := 1
+					switch levels {
+					case 1:
+						chunkCnt = 2
+					case 2:
+						chunkCnt = shardCnt + 1
+					case 3:
+						chunkCnt = shardCnt*shardCnt + 1
+					}
+					levels := levels
+					t.Run(fmt.Sprintf("encrypt=%v levels=%d chunks=%d", encrypt, levels, chunkCnt), func(t *testing.T) {
+						if levels > 2 && (encrypt == (rLevel%2 == 1)) {
+							t.Skip("skipping to save time")
+						}
+						t.Parallel()
+						testRedundancy(t, rLevel, encrypt, levels, chunkCnt, shardCnt, parityCnt)
+					})
+				}
+			}
+		})
+	}
+}
+
 func TestBzzFiles(t *testing.T) {
 	t.Parallel()
@@ -459,35 +662,57 @@ func TestBzzFilesRangeRequests(t *testing.T) {
 	}
 }
 
-func createRangeHeader(data []byte, ranges [][2]int) (header string, parts [][]byte) {
-	header = "bytes="
-	for i, r := range ranges {
-		if i > 0 {
-			header += ", "
+func createRangeHeader(data interface{}, ranges [][2]int) (header string, parts [][]byte) {
+	getLen := func() int {
+		switch data := data.(type) {
+		case []byte:
+			return len(data)
+		case interface{ Size() int }:
+			return data.Size()
+		default:
+			panic("unknown data type")
 		}
-		if r[0] >= 0 && r[1] >= 0 {
-			parts = append(parts, data[r[0]:r[1]])
-			// Range: =-, end is inclusive
-			header += fmt.Sprintf("%v-%v", r[0], r[1]-1)
-		} else {
-			if r[0] >= 0 {
-				header += strconv.Itoa(r[0]) // Range: =-
-				parts = append(parts, data[r[0]:])
+	}
+	getRange := func(start, end int) []byte {
+		switch data := data.(type) {
+		case []byte:
+			return data[start:end]
+		case io.ReadSeeker:
+			buf := make([]byte, end-start)
+			_, err := data.Seek(int64(start), io.SeekStart)
+			if err != nil {
+				panic(err)
 			}
-			header += "-"
-			if r[1] >= 0 {
-				if r[0] >= 0 {
-					// Range: =-, end is inclusive
-					header += strconv.Itoa(r[1] - 1)
-				} else {
-					// Range: =-, the parameter is length
-					header += strconv.Itoa(r[1])
-				}
-				parts = append(parts, data[:r[1]])
+			_, err =
io.ReadFull(data, buf) + if err != nil { + panic(err) } + return buf + default: + panic("unknown data type") + } + } + + rangeStrs := make([]string, len(ranges)) + for i, r := range ranges { + start, end := r[0], r[1] + switch { + case start < 0: + // Range: =-, the parameter is length + rangeStrs[i] = "-" + strconv.Itoa(end) + start = 0 + case r[1] < 0: + // Range: =- + rangeStrs[i] = strconv.Itoa(start) + "-" + end = getLen() + default: + // Range: =-, end is inclusive + rangeStrs[i] = fmt.Sprintf("%v-%v", start, end-1) } + parts = append(parts, getRange(start, end)) } - return + header = "bytes=" + strings.Join(rangeStrs, ", ") // nolint:staticcheck + return header, parts } func parseRangeParts(t *testing.T, contentType string, body []byte) (parts [][]byte) { diff --git a/pkg/api/debugstorage_test.go b/pkg/api/debugstorage_test.go index b83ddb5af99..72d8800b882 100644 --- a/pkg/api/debugstorage_test.go +++ b/pkg/api/debugstorage_test.go @@ -9,7 +9,7 @@ import ( "testing" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" - storer "github.com/ethersphere/bee/pkg/storer" + "github.com/ethersphere/bee/pkg/storer" mockstorer "github.com/ethersphere/bee/pkg/storer/mock" ) diff --git a/pkg/file/joiner/joiner.go b/pkg/file/joiner/joiner.go index c1ea6f5b4ea..dc9c850084c 100644 --- a/pkg/file/joiner/joiner.go +++ b/pkg/file/joiner/joiner.go @@ -11,7 +11,6 @@ import ( "io" "sync" "sync/atomic" - "time" "github.com/ethersphere/bee/pkg/bmt" "github.com/ethersphere/bee/pkg/encryption" @@ -41,24 +40,20 @@ type joiner struct { // decoderCache is cache of decoders for intermediate chunks type decoderCache struct { - fetcher storage.Getter // network retrieval interface to fetch chunks - putter storage.Putter // interface to local storage to save reconstructed chunks - mu sync.Mutex // mutex to protect cache - cache map[string]storage.Getter // map from chunk address to RS getter - strategy getter.Strategy // strategy to use for retrieval - strict bool // strict mode - fetcherTimeout time.Duration // timeout for each fetch + fetcher storage.Getter // network retrieval interface to fetch chunks + putter storage.Putter // interface to local storage to save reconstructed chunks + mu sync.Mutex // mutex to protect cache + cache map[string]storage.Getter // map from chunk address to RS getter + config getter.Config // getter configuration } // NewDecoderCache creates a new decoder cache -func NewDecoderCache(g storage.Getter, p storage.Putter, strategy getter.Strategy, strict bool, fetcherTimeout time.Duration) *decoderCache { +func NewDecoderCache(g storage.Getter, p storage.Putter, conf getter.Config) *decoderCache { return &decoderCache{ - fetcher: g, - putter: p, - cache: make(map[string]storage.Getter), - strategy: strategy, - strict: strict, - fetcherTimeout: fetcherTimeout, + fetcher: g, + putter: p, + cache: make(map[string]storage.Getter), + config: conf, } } @@ -90,7 +85,7 @@ func (g *decoderCache) GetOrCreate(addrs []swarm.Address, shardCnt int) storage. defer g.mu.Unlock() g.cache[key] = nil } - d = getter.New(addrs, shardCnt, g.fetcher, g.putter, g.strategy, g.strict, g.fetcherTimeout, remove) + d = getter.New(addrs, shardCnt, g.fetcher, g.putter, remove, g.config) g.cache[key] = d return d } @@ -98,7 +93,7 @@ func (g *decoderCache) GetOrCreate(addrs []swarm.Address, shardCnt int) storage. // New creates a new Joiner. A Joiner provides Read, Seek and Size functionalities. 
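+//
+// A minimal usage sketch (hypothetical; store and rootAddr are assumed): callers
+// seed the context with the retrieval configuration introduced in this change
+// before constructing the joiner:
+//
+//	ctx := getter.SetConfigInContext(context.Background(), getter.RACE, true, "10s", "300ms")
+//	j, span, err := joiner.New(ctx, store, store, rootAddr)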
func New(ctx context.Context, g storage.Getter, putter storage.Putter, address swarm.Address) (file.Joiner, int64, error) { // retrieve the root chunk to read the total data length the be retrieved - rLevel := replicas.GetLevelFromContext(ctx) + rLevel := redundancy.GetLevelFromContext(ctx) rootChunkGetter := store.New(g) if rLevel != redundancy.NONE { rootChunkGetter = store.New(replicas.NewGetter(g, rLevel)) @@ -118,27 +113,32 @@ func New(ctx context.Context, g storage.Getter, putter storage.Putter, address s spanFn := func(data []byte) (redundancy.Level, int64) { return 0, int64(bmt.LengthFromSpan(data[:swarm.SpanSize])) } - var strategy getter.Strategy - var strict bool - var fetcherTimeout time.Duration + conf, err := getter.NewConfigFromContext(ctx, getter.DefaultConfig) + if err != nil { + return nil, 0, err + } // override stuff if root chunk has redundancy if rLevel != redundancy.NONE { _, parities := file.ReferenceCount(uint64(span), rLevel, encryption) rootParity = parities - strategy, strict, fetcherTimeout = getter.GetParamsFromContext(ctx) + spanFn = chunkToSpan if encryption { maxBranching = rLevel.GetMaxEncShards() } else { maxBranching = rLevel.GetMaxShards() } + } else { + // if root chunk has no redundancy, strategy is ignored and set to NONE and strict is set to true + conf.Strategy = getter.DATA + conf.Strict = true } j := &joiner{ addr: rootChunk.Address(), refLength: refLength, ctx: ctx, - decoders: NewDecoderCache(g, putter, strategy, strict, fetcherTimeout), + decoders: NewDecoderCache(g, putter, conf), span: span, rootData: rootData, rootParity: rootParity, diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index 7a97080171b..15d46bf220b 100644 --- a/pkg/file/joiner/joiner_test.go +++ b/pkg/file/joiner/joiner_test.go @@ -27,12 +27,17 @@ import ( storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" testingc "github.com/ethersphere/bee/pkg/storage/testing" + mockstorer "github.com/ethersphere/bee/pkg/storer/mock" "github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/util/testutil" + "github.com/ethersphere/bee/pkg/util/testutil/pseudorand" + "github.com/ethersphere/bee/pkg/util/testutil/racedetection" "gitlab.com/nolash/go-mockbytes" "golang.org/x/sync/errgroup" ) +// nolint:paralleltest,tparallel,thelper + func TestJoiner_ErrReferenceLength(t *testing.T) { t.Parallel() @@ -1018,12 +1023,9 @@ func (m *mockPutter) store(cnt int) error { return nil } +// nolint:thelper func TestJoinerRedundancy(t *testing.T) { - - strategyTimeout := getter.StrategyTimeout - defer func() { getter.StrategyTimeout = strategyTimeout }() - getter.StrategyTimeout = 100 * time.Millisecond - + t.Parallel() for _, tc := range []struct { rLevel redundancy.Level encryptChunk bool @@ -1063,10 +1065,8 @@ func TestJoinerRedundancy(t *testing.T) { } { tc := tc t.Run(fmt.Sprintf("redundancy=%d encryption=%t", tc.rLevel, tc.encryptChunk), func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) defer cancel() - shardCnt := tc.rLevel.GetMaxShards() parityCnt := tc.rLevel.GetParities(shardCnt) if tc.encryptChunk { @@ -1109,13 +1109,12 @@ func TestJoinerRedundancy(t *testing.T) { if err != nil { t.Fatal(err) } + strategyTimeout := 100 * time.Millisecond // all data can be read back readCheck := func(t *testing.T, expErr error) { - t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), 15*getter.StrategyTimeout) + ctx, cancel := context.WithTimeout(context.Background(), 
15*strategyTimeout)
 			defer cancel()
-			ctx = getter.SetFetchTimeout(ctx, getter.StrategyTimeout)
+			ctx = getter.SetConfigInContext(ctx, getter.RACE, true, (10 * strategyTimeout).String(), strategyTimeout.String())
 			joinReader, rootSpan, err := joiner.New(ctx, store, store, swarmAddr)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -1127,10 +1126,11 @@
 			i := 0
 			eg, ectx := errgroup.WithContext(ctx)
+		scnt:
 			for ; i < shardCnt; i++ {
 				select {
 				case <-ectx.Done():
-					break
+					break scnt
 				default:
 				}
 				i := i
@@ -1153,6 +1153,7 @@
 			})
 		}
 		err = eg.Wait()
+
 		if !errors.Is(err, expErr) {
 			t.Fatalf("unexpected error reading chunkdata at chunk position %d: expected %v. got %v", i, expErr, err)
 		}
@@ -1191,3 +1192,171 @@
 		})
 	}
 }
+
+// TestJoinerRedundancyMultilevel tests the joiner with all combinations of
+// redundancy level, encryption and size (levels, i.e., the height of the swarm hash tree).
+//
+// The test cases have the following structure:
+//
+// 1. upload a file with a given redundancy level and encryption
+//
+// 2. [positive test] download the file by the reference returned by the upload API response
+// This uses range queries to target specific (number of) chunks of the file structure.
+// During path traversal in the swarm hash tree, the underlying mockstore (forgetting)
+// is in 'recording' mode, flagging all the retrieved chunks as chunks to forget.
+// This is to simulate the scenario where some of the chunks are not available/lost.
+//
+// 3. [negative test] attempt at downloading the file using once again the same root hash
+// and a no-redundancy strategy to find the file inaccessible after forgetting.
+// 3a. [negative test] download file using NONE without fallback and fail
+// 3b. [negative test] download file using DATA without fallback and fail
+//
+// 4. [positive test] download file using DATA with fallback to allow for
+// reconstruction via erasure coding and succeed.
+//
+// 5. [positive test] after recovery the reconstructed chunks are saved, so forgetting
+// no longer applies; repeat 3a/3b, but this time succeed.
+//
+// nolint:thelper
+func TestJoinerRedundancyMultilevel(t *testing.T) {
+	t.Parallel()
+	test := func(t *testing.T, rLevel redundancy.Level, encrypt bool, levels, size int) {
+		t.Helper()
+		store := mockstorer.NewForgettingStore(inmemchunkstore.New())
+		testutil.CleanupCloser(t, store)
+		seed, err := pseudorand.NewSeed()
+		if err != nil {
+			t.Fatal(err)
+		}
+		dataReader := pseudorand.NewReader(seed, size*swarm.ChunkSize)
+		ctx := context.Background()
+		// ctx = redundancy.SetLevelInContext(ctx, rLevel)
+		ctx = redundancy.SetLevelInContext(ctx, redundancy.NONE)
+		pipe := builder.NewPipelineBuilder(ctx, store, encrypt, rLevel)
+		addr, err := builder.FeedPipeline(ctx, pipe, dataReader)
+		if err != nil {
+			t.Fatal(err)
+		}
+		expRead := swarm.ChunkSize
+		buf := make([]byte, expRead)
+		offset := mrand.Intn(size) * expRead
+		canReadRange := func(t *testing.T, s getter.Strategy, fallback bool, levels int, canRead bool) {
+			ctx := context.Background()
+			strategyTimeout := 100 * time.Millisecond
+			decodingTimeout := 600 * time.Millisecond
+			if racedetection.IsOn() {
+				decodingTimeout *= 2
+			}
+			ctx = getter.SetConfigInContext(ctx, s, fallback, (2 * strategyTimeout).String(), strategyTimeout.String())
+			ctx, cancel := context.WithTimeout(ctx, time.Duration(levels)*(3*strategyTimeout+decodingTimeout))
+			defer cancel()
+			j, _, err := joiner.New(ctx, store, store, addr)
+			if err != nil {
+				t.Fatal(err)
+			}
+			n, err := j.ReadAt(buf, int64(offset))
+			if !canRead {
+				if !errors.Is(err, storage.ErrNotFound) && !errors.Is(err, context.DeadlineExceeded) {
+					t.Fatalf("expected error %v or %v. got %v", storage.ErrNotFound, context.DeadlineExceeded, err)
+				}
+				return
+			}
+			if err != nil {
+				t.Fatal(err)
+			}
+			if n != expRead {
+				t.Errorf("read %d bytes out of %d", n, expRead)
+			}
+			_, err = dataReader.Seek(int64(offset), io.SeekStart)
+			if err != nil {
+				t.Fatal(err)
+			}
+			ok, err := dataReader.Match(bytes.NewBuffer(buf), expRead)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !ok {
+				t.Error("content mismatch")
+			}
+		}
+
+		// first sanity check and recover a range
+		t.Run("NONE w/o fallback CAN retrieve", func(t *testing.T) {
+			store.Record()
+			defer store.Unrecord()
+			canReadRange(t, getter.NONE, false, levels, true)
+		})
+
+		// do not forget the root chunk
+		store.Unmiss(swarm.NewAddress(addr.Bytes()[:swarm.HashSize]))
+		// after we forget the chunks on the way to the range, we should not be able to retrieve
+		t.Run("NONE w/o fallback CANNOT retrieve", func(t *testing.T) {
+			canReadRange(t, getter.NONE, false, levels, false)
+		})
+
+		// we lost a data chunk, we cannot recover using DATA only strategy with no fallback
+		t.Run("DATA w/o fallback CANNOT retrieve", func(t *testing.T) {
+			canReadRange(t, getter.DATA, false, levels, false)
+		})
+
+		if rLevel == 0 {
+			// allowing fallback mode will not help if no redundancy used for upload
+			t.Run("DATA w fallback CANNOT retrieve", func(t *testing.T) {
+				canReadRange(t, getter.DATA, true, levels, false)
+			})
+			return
+		}
+		// allowing fallback mode will make the range retrievable using erasure decoding
+		t.Run("DATA w fallback CAN retrieve", func(t *testing.T) {
+			canReadRange(t, getter.DATA, true, levels, true)
+		})
+		// after the reconstructed data is stored, we can retrieve the range using DATA only mode
+		t.Run("after recovery, NONE w/o fallback CAN retrieve", func(t *testing.T) {
+			canReadRange(t, getter.NONE, false, levels, true)
+		})
+	}
+
r2level := []int{2, 1, 2, 3, 2} + encryptChunk := []bool{false, false, true, true, true} + for _, rLevel := range []redundancy.Level{0, 1, 2, 3, 4} { + rLevel := rLevel + // speeding up tests by skipping some of them + t.Run(fmt.Sprintf("rLevel=%v", rLevel), func(t *testing.T) { + t.Parallel() + for _, encrypt := range []bool{false, true} { + encrypt := encrypt + shardCnt := rLevel.GetMaxShards() + if encrypt { + shardCnt = rLevel.GetMaxEncShards() + } + for _, levels := range []int{1, 2, 3} { + chunkCnt := 1 + switch levels { + case 1: + chunkCnt = 2 + case 2: + chunkCnt = shardCnt + 1 + case 3: + chunkCnt = shardCnt*shardCnt + 1 + } + t.Run(fmt.Sprintf("encrypt=%v levels=%d chunks=%d incomplete", encrypt, levels, chunkCnt), func(t *testing.T) { + if r2level[rLevel] != levels || encrypt != encryptChunk[rLevel] { + t.Skip("skipping to save time") + } + test(t, rLevel, encrypt, levels, chunkCnt) + }) + switch levels { + case 1: + chunkCnt = shardCnt + case 2: + chunkCnt = shardCnt * shardCnt + case 3: + continue + } + t.Run(fmt.Sprintf("encrypt=%v levels=%d chunks=%d full", encrypt, levels, chunkCnt), func(t *testing.T) { + test(t, rLevel, encrypt, levels, chunkCnt) + }) + } + } + }) + } +} diff --git a/pkg/file/pipeline/hashtrie/hashtrie_test.go b/pkg/file/pipeline/hashtrie/hashtrie_test.go index c3729f78901..4a966e265f3 100644 --- a/pkg/file/pipeline/hashtrie/hashtrie_test.go +++ b/pkg/file/pipeline/hashtrie/hashtrie_test.go @@ -24,7 +24,6 @@ import ( "github.com/ethersphere/bee/pkg/file/pipeline/mock" "github.com/ethersphere/bee/pkg/file/pipeline/store" "github.com/ethersphere/bee/pkg/file/redundancy" - "github.com/ethersphere/bee/pkg/replicas" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" "github.com/ethersphere/bee/pkg/swarm" @@ -46,7 +45,7 @@ func init() { binary.LittleEndian.PutUint64(span, 1) } -// NewErasureHashTrieWriter returns back an redundancy param and a HastTrieWriter pipeline +// newErasureHashTrieWriter returns back an redundancy param and a HastTrieWriter pipeline // which are using simple BMT and StoreWriter pipelines for chunk writes func newErasureHashTrieWriter( ctx context.Context, @@ -303,20 +302,20 @@ func TestRedundancy(t *testing.T) { level: redundancy.INSANE, encryption: false, writes: 98, // 97 chunk references fit into one chunk + 1 carrier - parities: 38, // 31 (full ch) + 7 (2 ref) + parities: 37, // 31 (full ch) + 6 (2 ref) }, { desc: "redundancy write for encrypted data", level: redundancy.PARANOID, encryption: true, writes: 21, // 21 encrypted chunk references fit into one chunk + 1 carrier - parities: 118, // // 88 (full ch) + 30 (2 ref) + parities: 116, // // 87 (full ch) + 29 (2 ref) }, } { tc := tc t.Run(tc.desc, func(t *testing.T) { t.Parallel() - subCtx := replicas.SetLevel(ctx, tc.level) + subCtx := redundancy.SetLevelInContext(ctx, tc.level) s := inmemchunkstore.New() intermediateChunkCounter := mock.NewChainWriter() diff --git a/pkg/file/redundancy/getter/getter.go b/pkg/file/redundancy/getter/getter.go index e23f2a3e6af..4e8da1b6390 100644 --- a/pkg/file/redundancy/getter/getter.go +++ b/pkg/file/redundancy/getter/getter.go @@ -9,7 +9,6 @@ import ( "io" "sync" "sync/atomic" - "time" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" @@ -28,13 +27,16 @@ type decoder struct { waits []chan struct{} // wait channels for each chunk rsbuf [][]byte // RS buffer of data + parity shards for erasure decoding ready chan struct{} // signal channel for successful retrieval of 
shardCnt chunks + lastLen int // length of the last data chunk in the RS buffer shardCnt int // number of data shards parityCnt int // number of parity shards wg sync.WaitGroup // wait group to wait for all goroutines to finish mu sync.Mutex // mutex to protect buffer + err error // error of the last erasure decoding fetchedCnt atomic.Int32 // count successful retrievals cancel func() // cancel function for RS decoding remove func() // callback to remove decoder from decoders cache + config Config // configuration } type Getter interface { @@ -42,14 +44,10 @@ type Getter interface { io.Closer } -// New returns a decoder object used tos retrieve children of an intermediate chunk -func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter, strategy Strategy, strict bool, fetchTimeout time.Duration, remove func()) Getter { +// New returns a decoder object used to retrieve children of an intermediate chunk +func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter, remove func(), conf Config) Getter { ctx, cancel := context.WithCancel(context.Background()) size := len(addrs) - if fetchTimeout == 0 { - fetchTimeout = 30 * time.Second - } - strategyTimeout := StrategyTimeout rsg := &decoder{ fetcher: g, @@ -64,6 +62,7 @@ func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter remove: remove, shardCnt: shardCnt, parityCnt: size - shardCnt, + config: conf, } // after init, cache and wait channels are immutable, need no locking @@ -73,11 +72,13 @@ func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter } // prefetch chunks according to strategy - rsg.wg.Add(1) - go func() { - rsg.prefetch(ctx, strategy, strict, strategyTimeout, fetchTimeout) - rsg.wg.Done() - }() + if !conf.Strict || conf.Strategy != NONE { + rsg.wg.Add(1) + go func() { + rsg.err = rsg.prefetch(ctx) + rsg.wg.Done() + }() + } return rsg } @@ -97,10 +98,30 @@ func (g *decoder) Get(ctx context.Context, addr swarm.Address) (swarm.Chunk, err } select { case <-g.waits[i]: - return swarm.NewChunk(addr, g.rsbuf[i]), nil case <-ctx.Done(): return nil, ctx.Err() } + return swarm.NewChunk(addr, g.getData(i)), nil +} + +// setData sets the data shard in the RS buffer +func (g *decoder) setData(i int, chdata []byte) { + data := chdata + // pad the chunk with zeros if it is smaller than swarm.ChunkSize + if len(data) < swarm.ChunkWithSpanSize { + g.lastLen = len(data) + data = make([]byte, swarm.ChunkWithSpanSize) + copy(data, chdata) + } + g.rsbuf[i] = data +} + +// getData returns the data shard from the RS buffer +func (g *decoder) getData(i int) []byte { + if i == g.shardCnt-1 && g.lastLen > 0 { + return g.rsbuf[i][:g.lastLen] // cut padding + } + return g.rsbuf[i] } // fly commits to retrieve the chunk (fly and land) @@ -115,7 +136,9 @@ func (g *decoder) fly(i int, up bool) (success bool) { // it races with erasure recovery which takes precedence even if it started later // due to the fact that erasure recovery could only implement global locking on all shards func (g *decoder) fetch(ctx context.Context, i int) { - ch, err := g.fetcher.Get(ctx, g.addrs[i]) + fctx, cancel := context.WithTimeout(ctx, g.config.FetchTimeout) + defer cancel() + ch, err := g.fetcher.Get(fctx, g.addrs[i]) if err != nil { _ = g.fly(i, false) // unset inflight return @@ -139,8 +162,8 @@ func (g *decoder) fetch(ctx context.Context, i int) { } // write chunk to rsbuf and signal waiters - g.rsbuf[i] = ch.Data() // save the chunk in the RS buffer - if i < len(g.waits) { + 
g.setData(i, ch.Data()) // save the chunk in the RS buffer + if i < len(g.waits) { // if the chunk is a data shard close(g.waits[i]) // signal that the chunk is retrieved } @@ -218,6 +241,9 @@ func (g *decoder) save(ctx context.Context, missing []int) error { return nil } +// Close terminates the prefetch loop, waits for all goroutines to finish and +// removes the decoder from the cache +// it implements the io.Closer interface func (g *decoder) Close() error { g.cancel() g.wg.Wait() diff --git a/pkg/file/redundancy/getter/getter_test.go b/pkg/file/redundancy/getter/getter_test.go index bb00f7ddd02..b18caa55c12 100644 --- a/pkg/file/redundancy/getter/getter_test.go +++ b/pkg/file/redundancy/getter/getter_test.go @@ -21,37 +21,13 @@ import ( "github.com/ethersphere/bee/pkg/file/redundancy/getter" "github.com/ethersphere/bee/pkg/storage" inmem "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" + mockstorer "github.com/ethersphere/bee/pkg/storer/mock" "github.com/ethersphere/bee/pkg/swarm" + "github.com/ethersphere/bee/pkg/util/testutil/racedetection" "github.com/klauspost/reedsolomon" "golang.org/x/sync/errgroup" ) -type delayed struct { - storage.ChunkStore - cache map[string]time.Duration - mu sync.Mutex -} - -func (d *delayed) delay(addr swarm.Address, delay time.Duration) { - d.mu.Lock() - defer d.mu.Unlock() - d.cache[addr.String()] = delay -} - -func (d *delayed) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) { - d.mu.Lock() - defer d.mu.Unlock() - if delay, ok := d.cache[addr.String()]; ok && delay > 0 { - select { - case <-time.After(delay): - delete(d.cache, addr.String()) - case <-ctx.Done(): - return nil, ctx.Err() - } - } - return d.ChunkStore.Get(ctx, addr) -} - // TestGetter tests the retrieval of chunks with missing data shards // using the RACE strategy for a number of erasure code parameters func TestGetterRACE(t *testing.T) { @@ -118,11 +94,10 @@ func TestGetterFallback(t *testing.T) { func testDecodingRACE(t *testing.T, bufSize, shardCnt, erasureCnt int) { t.Helper() - - strategyTimeout := getter.StrategyTimeout - defer func() { getter.StrategyTimeout = strategyTimeout }() - getter.StrategyTimeout = 100 * time.Millisecond - + strategyTimeout := 100 * time.Millisecond + if racedetection.On { + strategyTimeout *= 2 + } store := inmem.New() buf := make([][]byte, bufSize) addrs := initData(t, buf, shardCnt, store) @@ -140,7 +115,12 @@ func testDecodingRACE(t *testing.T, bufSize, shardCnt, erasureCnt int) { } ctx, cancel := context.WithCancel(context.TODO()) defer cancel() - g := getter.New(addrs, shardCnt, store, store, getter.RACE, false, 2*getter.StrategyTimeout, func() {}) + conf := getter.Config{ + Strategy: getter.RACE, + FetchTimeout: 2 * strategyTimeout, + StrategyTimeout: strategyTimeout, + } + g := getter.New(addrs, shardCnt, store, store, func() {}, conf) defer g.Close() parityCnt := len(buf) - shardCnt q := make(chan error, 1) @@ -149,9 +129,13 @@ func testDecodingRACE(t *testing.T, bufSize, shardCnt, erasureCnt int) { q <- err }() err := context.DeadlineExceeded + wait := strategyTimeout * 2 + if racedetection.On { + wait *= 2 + } select { case err = <-q: - case <-time.After(getter.StrategyTimeout * 10): + case <-time.After(wait): } switch { case erasureCnt > parityCnt: @@ -175,13 +159,11 @@ func testDecodingRACE(t *testing.T, bufSize, shardCnt, erasureCnt int) { func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { t.Helper() - strategyTimeout := getter.StrategyTimeout - defer func() { getter.StrategyTimeout = 
strategyTimeout }() - getter.StrategyTimeout = 100 * time.Millisecond + strategyTimeout := 150 * time.Millisecond bufSize := 12 shardCnt := 6 - store := &delayed{ChunkStore: inmem.New(), cache: make(map[string]time.Duration)} + store := mockstorer.NewDelayedStore(inmem.New()) buf := make([][]byte, bufSize) addrs := initData(t, buf, shardCnt, store) @@ -201,14 +183,20 @@ func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { waitDelayed, waitErased := make(chan error, 1), make(chan error, 1) // complete retrieval of delayed chunk by putting it into the store after a while - delay := +getter.StrategyTimeout / 4 + delay := strategyTimeout / 4 if s == getter.NONE { - delay += getter.StrategyTimeout + delay += strategyTimeout } - store.delay(addrs[delayed], delay) + store.Delay(addrs[delayed], delay) // create getter start := time.Now() - g := getter.New(addrs, shardCnt, store, store, s, strict, getter.StrategyTimeout/2, func() {}) + conf := getter.Config{ + Strategy: s, + Strict: strict, + FetchTimeout: strategyTimeout / 2, + StrategyTimeout: strategyTimeout, + } + g := getter.New(addrs, shardCnt, store, store, func() {}, conf) defer g.Close() // launch delayed and erased chunk retrieval @@ -219,11 +207,15 @@ func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { // delayed and erased chunk retrieval completes go func() { defer wg.Done() + ctx, cancel := context.WithTimeout(ctx, strategyTimeout*time.Duration(5-s)) + defer cancel() _, err := g.Get(ctx, addrs[delayed]) waitDelayed <- err }() go func() { defer wg.Done() + ctx, cancel := context.WithTimeout(ctx, strategyTimeout*time.Duration(5-s)) + defer cancel() _, err := g.Get(ctx, addrs[erased]) waitErased <- err }() @@ -234,7 +226,7 @@ func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { if err != nil { t.Fatal("unexpected error", err) } - round := time.Since(start) / getter.StrategyTimeout + round := time.Since(start) / strategyTimeout switch { case strict && s == getter.NONE: if round < 1 { @@ -260,15 +252,15 @@ func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { if err != nil { t.Fatal("unexpected error", err) } - round = time.Since(start) / getter.StrategyTimeout + round = time.Since(start) / strategyTimeout switch { case strict: t.Fatalf("unexpected completion of erased chunk retrieval. got round %d", round) case s == getter.NONE: - if round < 2 { + if round < 3 { t.Fatalf("unexpected early completion of erased chunk retrieval. got round %d", round) } - if round > 2 { + if round > 3 { t.Fatalf("unexpected late completion of erased chunk retrieval. 
got round %d", round) } case s == getter.DATA: @@ -281,12 +273,12 @@ func testDecodingFallback(t *testing.T, s getter.Strategy, strict bool) { } checkShardsAvailable(t, store, addrs[:erased], buf[:erased]) - case <-time.After(getter.StrategyTimeout * 2): + case <-time.After(strategyTimeout * 2): if !strict { t.Fatal("unexpected timeout using strategy", s, "with strict", strict) } } - case <-time.After(getter.StrategyTimeout * 3): + case <-time.After(strategyTimeout * 3): if !strict || s != getter.NONE { t.Fatal("unexpected timeout using strategy", s, "with strict", strict) } diff --git a/pkg/file/redundancy/getter/strategies.go b/pkg/file/redundancy/getter/strategies.go index 8bf944e8ae7..bb5188e9cc2 100644 --- a/pkg/file/redundancy/getter/strategies.go +++ b/pkg/file/redundancy/getter/strategies.go @@ -6,21 +6,34 @@ package getter import ( "context" + "errors" "fmt" "time" ) -var ( - StrategyTimeout = 500 * time.Millisecond // timeout for each strategy +const ( + DefaultStrategy = NONE // default prefetching strategy + DefaultStrict = true // default fallback modes + DefaultFetchTimeout = 30 * time.Second // timeout for each chunk retrieval + DefaultStrategyTimeout = 300 * time.Millisecond // timeout for each strategy ) type ( - strategyKey struct{} - modeKey struct{} - fetcherTimeoutKey struct{} - Strategy = int + strategyKey struct{} + modeKey struct{} + fetchTimeoutKey struct{} + strategyTimeoutKey struct{} + Strategy = int ) +// Config is the configuration for the getter - public +type Config struct { + Strategy Strategy + Strict bool + FetchTimeout time.Duration + StrategyTimeout time.Duration +} + const ( NONE Strategy = iota // no prefetching and no decoding DATA // just retrieve data shards no decoding @@ -29,17 +42,57 @@ const ( strategyCnt ) -// GetParamsFromContext extracts the strategy and strict mode from the context -func GetParamsFromContext(ctx context.Context) (s Strategy, strict bool, fetcherTimeout time.Duration) { - s, _ = ctx.Value(strategyKey{}).(Strategy) - strict, _ = ctx.Value(modeKey{}).(bool) - fetcherTimeout, _ = ctx.Value(fetcherTimeoutKey{}).(time.Duration) - return s, strict, fetcherTimeout +// DefaultConfig is the default configuration for the getter +var DefaultConfig = Config{ + Strategy: DefaultStrategy, + Strict: DefaultStrict, + FetchTimeout: DefaultFetchTimeout, + StrategyTimeout: DefaultStrategyTimeout, } -// SetFetchTimeout sets the timeout for each fetch -func SetFetchTimeout(ctx context.Context, timeout time.Duration) context.Context { - return context.WithValue(ctx, fetcherTimeoutKey{}, timeout) +// NewConfigFromContext returns a new Config based on the context +func NewConfigFromContext(ctx context.Context, def Config) (conf Config, err error) { + var ok bool + conf = def + e := func(s string, errs ...error) error { + if len(errs) > 0 { + return fmt.Errorf("error setting %s from context: %w", s, errors.Join(errs...)) + } + return fmt.Errorf("error setting %s from context", s) + } + if val := ctx.Value(strategyKey{}); val != nil { + conf.Strategy, ok = val.(Strategy) + if !ok { + return conf, e("strategy") + } + } + if val := ctx.Value(modeKey{}); val != nil { + conf.Strict, ok = val.(bool) + if !ok { + return conf, e("fallback mode") + } + } + if val := ctx.Value(fetchTimeoutKey{}); val != nil { + fetchTimeoutVal, ok := val.(string) + if !ok { + return conf, e("fetcher timeout") + } + conf.FetchTimeout, err = time.ParseDuration(fetchTimeoutVal) + if err != nil { + return conf, e("fetcher timeout", err) + } + } + if val := 
ctx.Value(strategyTimeoutKey{}); val != nil {
+		strategyTimeoutVal, ok := val.(string)
+		if !ok {
+			return conf, e("strategy timeout")
+		}
+		conf.StrategyTimeout, err = time.ParseDuration(strategyTimeoutVal)
+		if err != nil {
+			return conf, e("strategy timeout", err)
+		}
+	}
+	return conf, nil
 }
 
 // SetStrategy sets the strategy for the retrieval
@@ -52,9 +105,28 @@ func SetStrict(ctx context.Context, strict bool) context.Context {
 	return context.WithValue(ctx, modeKey{}, strict)
 }
 
-func (g *decoder) prefetch(ctx context.Context, strategy int, strict bool, strategyTimeout, fetchTimeout time.Duration) {
-	if strict && strategy == NONE {
-		return
+// SetFetchTimeout sets the timeout for each fetch
+func SetFetchTimeout(ctx context.Context, timeout string) context.Context {
+	return context.WithValue(ctx, fetchTimeoutKey{}, timeout)
+}
+
+// SetStrategyTimeout sets the timeout for each strategy
+func SetStrategyTimeout(ctx context.Context, timeout string) context.Context {
+	return context.WithValue(ctx, strategyTimeoutKey{}, timeout)
+}
+
+// SetConfigInContext sets the config params in the context
+func SetConfigInContext(ctx context.Context, s Strategy, fallbackmode bool, fetchTimeout, strategyTimeout string) context.Context {
+	ctx = SetStrategy(ctx, s)
+	ctx = SetStrict(ctx, !fallbackmode)
+	ctx = SetFetchTimeout(ctx, fetchTimeout)
+	ctx = SetStrategyTimeout(ctx, strategyTimeout)
+	return ctx
+}
+
+func (g *decoder) prefetch(ctx context.Context) error {
+	if g.config.Strict && g.config.Strategy == NONE {
+		return nil
 	}
 	defer g.remove()
 	var cancels []func()
@@ -66,16 +138,16 @@
 	defer cancelAll()
 	run := func(s Strategy) error {
 		if s == PROX { // NOT IMPLEMENTED
-			return fmt.Errorf("strategy %d not implemented", s)
+			return errors.New("strategy not implemented")
 		}
 
 		var stop <-chan time.Time
 		if s < RACE {
-			timer := time.NewTimer(strategyTimeout)
+			timer := time.NewTimer(g.config.StrategyTimeout)
 			defer timer.Stop()
 			stop = timer.C
 		}
-		lctx, cancel := context.WithTimeout(ctx, fetchTimeout)
+		lctx, cancel := context.WithCancel(ctx)
 		cancels = append(cancels, cancel)
 		prefetch(lctx, g, s)
@@ -93,9 +165,14 @@
 		return g.recover(ctx) // context to cancel when shardCnt chunks are retrieved
 	}
 	var err error
-	for s := strategy; s == strategy || (err != nil && !strict && s < strategyCnt); s++ {
+	for s := g.config.Strategy; s < strategyCnt; s++ {
 		err = run(s)
+		if g.config.Strict || err == nil {
+			break
+		}
 	}
+
+	return err
 }
 
 // prefetch launches the retrieval of chunks based on the strategy
diff --git a/pkg/file/redundancy/level.go b/pkg/file/redundancy/level.go
index 3dfba1cd084..f7b4f5c19f1 100644
--- a/pkg/file/redundancy/level.go
+++ b/pkg/file/redundancy/level.go
@@ -5,6 +5,7 @@
 package redundancy
 
 import (
+	"context"
 	"errors"
 	"fmt"
 
@@ -106,44 +107,40 @@ func (l Level) Decrement() Level {
 // TABLE INITS
 var mediumEt = newErasureTable(
-	[]int{94, 68, 46, 28, 14, 5, 1},
-	[]int{9, 8, 7, 6, 5, 4, 3},
+	[]int{95, 69, 47, 29, 15, 6, 2, 1},
+	[]int{9, 8, 7, 6, 5, 4, 3, 2},
 )
 var encMediumEt = newErasureTable(
-	[]int{47, 34, 23, 14, 7, 2},
-	[]int{9, 8, 7, 6, 5, 4},
+	[]int{47, 34, 23, 14, 7, 3, 1},
+	[]int{9, 8, 7, 6, 5, 4, 3},
 )
 var strongEt = newErasureTable(
-	[]int{104, 95, 86, 77, 69, 61, 53, 46, 39, 32, 26, 20, 15, 10, 6, 3, 1},
-	[]int{21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5},
+	[]int{105, 96, 87, 78, 70, 62, 54, 47, 40,
33, 27, 21, 16, 11, 7, 4, 2, 1}, + []int{21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4}, ) var encStrongEt = newErasureTable( - []int{52, 47, 43, 38, 34, 30, 26, 23, 19, 16, 13, 10, 7, 5, 3, 1}, - []int{21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6}, + []int{52, 48, 43, 39, 35, 31, 27, 23, 20, 16, 13, 10, 8, 5, 3, 2, 1}, + []int{21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5}, ) var insaneEt = newErasureTable( - []int{92, 87, 82, 77, 73, 68, 63, 59, 54, 50, 45, 41, 37, 33, 29, 26, 22, 19, 16, 13, 10, 8, 5, 3, 2, 1}, - []int{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6}, + []int{93, 88, 83, 78, 74, 69, 64, 60, 55, 51, 46, 42, 38, 34, 30, 27, 23, 20, 17, 14, 11, 9, 6, 4, 3, 2, 1}, + []int{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5}, ) var encInsaneEt = newErasureTable( - []int{46, 43, 41, 38, 36, 34, 31, 29, 27, 25, 22, 20, 18, 16, 14, 13, 11, 9, 8, 6, 5, 4, 2, 1}, - []int{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 7}, + []int{46, 44, 41, 39, 37, 34, 32, 30, 27, 25, 23, 21, 19, 17, 15, 13, 11, 10, 8, 7, 5, 4, 3, 2, 1}, + []int{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 6}, ) var paranoidEt = newErasureTable( []int{ - 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, - 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, - 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, - 7, 6, 5, 4, 3, 2, 1, + 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, + 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, }, []int{ - 90, 88, 87, 85, 84, 82, 81, 79, 77, 76, - 74, 72, 71, 69, 67, 66, 64, 62, 60, 59, - 57, 55, 53, 51, 49, 48, 46, 44, 41, 39, - 37, 35, 32, 30, 27, 24, 20, + 89, 87, 86, 84, 83, 81, 80, 78, 76, 75, 73, 71, 70, 68, 66, 65, 63, 61, 59, 58, + 56, 54, 52, 50, 48, 47, 45, 43, 40, 38, 36, 34, 31, 29, 26, 23, 19, }, ) var encParanoidEt = newErasureTable( @@ -152,8 +149,8 @@ var encParanoidEt = newErasureTable( 8, 7, 6, 5, 4, 3, 2, 1, }, []int{ - 88, 85, 82, 79, 76, 72, 69, 66, 62, 59, - 55, 51, 48, 44, 39, 35, 30, 24, + 87, 84, 81, 78, 75, 71, 68, 65, 61, 58, + 54, 50, 47, 43, 38, 34, 29, 23, }, ) @@ -167,3 +164,19 @@ func GetReplicaCounts() [5]int { // for the five levels of redundancy are 0, 2, 4, 5, 19 // we use an approximation as the successive powers of 2 var replicaCounts = [5]int{0, 2, 4, 8, 16} + +type levelKey struct{} + +// SetLevelInContext sets the redundancy level in the context +func SetLevelInContext(ctx context.Context, level Level) context.Context { + return context.WithValue(ctx, levelKey{}, level) +} + +// GetLevelFromContext is a helper function to extract the redundancy level from the context +func GetLevelFromContext(ctx context.Context) Level { + rlevel := PARANOID + if val := ctx.Value(levelKey{}); val != nil { + rlevel = val.(Level) + } + return rlevel +} diff --git a/pkg/replicas/getter.go b/pkg/replicas/getter.go index 496bc5d658a..26345919958 100644 --- a/pkg/replicas/getter.go +++ b/pkg/replicas/getter.go @@ -24,20 +24,7 @@ import ( // then the probability of Swarmageddon is less than 0.000001 // assuming the error rate of chunk retrievals stays below the level expressed // as depth by the publisher. 
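+//
+// With the sentinel error introduced below, callers can detect the condition
+// with a plain errors.Is check (hypothetical usage sketch):
+//
+//	if _, err := g.Get(ctx, addr); errors.Is(err, replicas.ErrSwarmageddon) {
+//		// all replica retrievals failed; treat the chunk as unrecoverable
+//	}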
-type ErrSwarmageddon struct { - error -} - -func (err *ErrSwarmageddon) Unwrap() []error { - if err == nil || err.error == nil { - return nil - } - var uwe interface{ Unwrap() []error } - if !errors.As(err.error, &uwe) { - return nil - } - return uwe.Unwrap() -} +var ErrSwarmageddon = errors.New("swarmageddon has begun") // getter is the private implementation of storage.Getter, an interface for // retrieving chunks. This getter embeds the original simple chunk getter and extends it @@ -69,7 +56,7 @@ func (g *getter) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, e resultC := make(chan swarm.Chunk) // errc collects the errors errc := make(chan error, 17) - var errs []error + var errs error errcnt := 0 // concurrently call to retrieve chunk using original CAC address @@ -108,10 +95,10 @@ func (g *getter) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, e return chunk, nil case err = <-errc: - errs = append(errs, err) + errs = errors.Join(errs, err) errcnt++ if errcnt > total { - return nil, &ErrSwarmageddon{errors.Join(errs...)} + return nil, errors.Join(ErrSwarmageddon, errs) } // ticker switches on the address channel diff --git a/pkg/replicas/getter_test.go b/pkg/replicas/getter_test.go index 3b11ad26d94..7435b0574fd 100644 --- a/pkg/replicas/getter_test.go +++ b/pkg/replicas/getter_test.go @@ -195,18 +195,11 @@ func TestGetter(t *testing.T) { } t.Run("returns correct error", func(t *testing.T) { - var esg *replicas.ErrSwarmageddon - if !errors.As(err, &esg) { + if !errors.Is(err, replicas.ErrSwarmageddon) { t.Fatalf("incorrect error. want Swarmageddon. got %v", err) } - errs := esg.Unwrap() - for _, err := range errs { - if !errors.Is(err, tc.failure.err) { - t.Fatalf("incorrect error. want it to wrap %v. got %v", tc.failure.err, err) - } - } - if len(errs) != tc.count+1 { - t.Fatalf("incorrect error. want %d. got %d", tc.count+1, len(errs)) + if !errors.Is(err, tc.failure.err) { + t.Fatalf("incorrect error. want it to wrap %v. 
got %v", tc.failure.err, err) } }) } @@ -265,7 +258,6 @@ func TestGetter(t *testing.T) { } } }) - }) } } diff --git a/pkg/replicas/putter.go b/pkg/replicas/putter.go index f2334a994b8..4aa55b638f0 100644 --- a/pkg/replicas/putter.go +++ b/pkg/replicas/putter.go @@ -11,6 +11,7 @@ import ( "errors" "sync" + "github.com/ethersphere/bee/pkg/file/redundancy" "github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" @@ -29,7 +30,7 @@ func NewPutter(p storage.Putter) storage.Putter { // Put makes the getter satisfy the storage.Getter interface func (p *putter) Put(ctx context.Context, ch swarm.Chunk) (err error) { - rlevel := GetLevelFromContext(ctx) + rlevel := redundancy.GetLevelFromContext(ctx) errs := []error{p.putter.Put(ctx, ch)} if rlevel == 0 { return errs[0] diff --git a/pkg/replicas/putter_test.go b/pkg/replicas/putter_test.go index 7a90b5ec308..7d05624ebca 100644 --- a/pkg/replicas/putter_test.go +++ b/pkg/replicas/putter_test.go @@ -75,7 +75,7 @@ func TestPutter(t *testing.T) { t.Fatal(err) } ctx := context.Background() - ctx = replicas.SetLevel(ctx, tc.level) + ctx = redundancy.SetLevelInContext(ctx, tc.level) ch, err := cac.New(buf) if err != nil { @@ -174,7 +174,7 @@ func TestPutter(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) defer cancel() - ctx = replicas.SetLevel(ctx, tc.level) + ctx = redundancy.SetLevelInContext(ctx, tc.level) ch, err := cac.New(buf) if err != nil { t.Fatal(err) diff --git a/pkg/replicas/replicas.go b/pkg/replicas/replicas.go index 18b21d4b8b8..fd45b28a3b1 100644 --- a/pkg/replicas/replicas.go +++ b/pkg/replicas/replicas.go @@ -1,4 +1,4 @@ -// Copyright 2020 The Swarm Authors. All rights reserved. +// Copyright 2023 The Swarm Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
@@ -11,7 +11,6 @@ package replicas import ( - "context" "time" "github.com/ethersphere/bee/pkg/crypto" @@ -19,31 +18,13 @@ import ( "github.com/ethersphere/bee/pkg/swarm" ) -type redundancyLevelType struct{} - var ( - // redundancyLevel is the context key for the redundancy level - redundancyLevel redundancyLevelType // RetryInterval is the duration between successive additional requests RetryInterval = 300 * time.Millisecond privKey, _ = crypto.DecodeSecp256k1PrivateKey(append([]byte{1}, make([]byte, 31)...)) signer = crypto.NewDefaultSigner(privKey) ) -// SetLevel sets the redundancy level in the context -func SetLevel(ctx context.Context, level redundancy.Level) context.Context { - return context.WithValue(ctx, redundancyLevel, level) -} - -// GetLevelFromContext is a helper function to extract the redundancy level from the context -func GetLevelFromContext(ctx context.Context) redundancy.Level { - rlevel := redundancy.PARANOID - if val := ctx.Value(redundancyLevel); val != nil { - rlevel = val.(redundancy.Level) - } - return rlevel -} - // replicator running the find for replicas type replicator struct { addr []byte // chunk address diff --git a/pkg/steward/steward_test.go b/pkg/steward/steward_test.go index 8e2083abc80..fe3ccd05ba9 100644 --- a/pkg/steward/steward_test.go +++ b/pkg/steward/steward_test.go @@ -17,7 +17,6 @@ import ( "github.com/ethersphere/bee/pkg/file/pipeline/builder" "github.com/ethersphere/bee/pkg/file/redundancy" postagetesting "github.com/ethersphere/bee/pkg/postage/mock" - "github.com/ethersphere/bee/pkg/replicas" "github.com/ethersphere/bee/pkg/steward" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" @@ -49,7 +48,7 @@ func TestSteward(t *testing.T) { s = steward.New(store, localRetrieval, inmem) stamper = postagetesting.NewStamper() ) - ctx = replicas.SetLevel(ctx, redundancy.NONE) + ctx = redundancy.SetLevelInContext(ctx, redundancy.NONE) n, err := rand.Read(data) if n != cap(data) { diff --git a/pkg/storageincentives/proof_test.go b/pkg/storageincentives/proof_test.go index 9ee3745e1af..dcd7002f913 100644 --- a/pkg/storageincentives/proof_test.go +++ b/pkg/storageincentives/proof_test.go @@ -44,8 +44,6 @@ var testData []byte // Test asserts that MakeInclusionProofs will generate the same // output for given sample. func TestMakeInclusionProofsRegression(t *testing.T) { - t.Parallel() - const sampleSize = 16 keyRaw := `00000000000000000000000000000000` diff --git a/pkg/storer/mock/forgetting.go b/pkg/storer/mock/forgetting.go new file mode 100644 index 00000000000..5588b5c263e --- /dev/null +++ b/pkg/storer/mock/forgetting.go @@ -0,0 +1,128 @@ +// Copyright 2023 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
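+//
+// This file provides two ChunkStore test doubles: DelayedStore, which delays
+// Get for selected addresses to exercise retrieval timeouts, and ForgettingStore,
+// which can flag addresses as missing to simulate lost chunks. A minimal sketch
+// of the intended use (addr is assumed to be a swarm.Address):
+//
+//	store := mockstorer.NewForgettingStore(inmemchunkstore.New())
+//	store.Miss(addr)   // while not recording, Get(addr) now returns storage.ErrNotFound
+//	store.Unmiss(addr) // Get(addr) succeeds again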
+
+package mockstorer
+
+import (
+	"context"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	storage "github.com/ethersphere/bee/pkg/storage"
+	"github.com/ethersphere/bee/pkg/swarm"
+)
+
+type DelayedStore struct {
+	storage.ChunkStore
+	cache map[string]time.Duration
+	mu    sync.Mutex
+}
+
+func NewDelayedStore(s storage.ChunkStore) *DelayedStore {
+	return &DelayedStore{
+		ChunkStore: s,
+		cache:      make(map[string]time.Duration),
+	}
+}
+
+func (d *DelayedStore) Delay(addr swarm.Address, delay time.Duration) {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	d.cache[addr.String()] = delay
+}
+
+func (d *DelayedStore) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+	if delay, ok := d.cache[addr.String()]; ok && delay > 0 {
+		select {
+		case <-time.After(delay):
+			delete(d.cache, addr.String())
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		}
+	}
+	return d.ChunkStore.Get(ctx, addr)
+}
+
+type ForgettingStore struct {
+	storage.ChunkStore
+	record atomic.Bool
+	mu     sync.Mutex
+	n      atomic.Int64
+	missed map[string]struct{}
+}
+
+func NewForgettingStore(s storage.ChunkStore) *ForgettingStore {
+	return &ForgettingStore{ChunkStore: s, missed: make(map[string]struct{})}
+}
+
+func (f *ForgettingStore) Stored() int64 {
+	return f.n.Load()
+}
+
+func (f *ForgettingStore) Record() {
+	f.record.Store(true)
+}
+
+func (f *ForgettingStore) Unrecord() {
+	f.record.Store(false)
+}
+
+func (f *ForgettingStore) Miss(addr swarm.Address) {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	f.miss(addr)
+}
+
+func (f *ForgettingStore) Unmiss(addr swarm.Address) {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	f.unmiss(addr)
+}
+
+func (f *ForgettingStore) miss(addr swarm.Address) {
+	f.missed[addr.String()] = struct{}{}
+}
+
+func (f *ForgettingStore) unmiss(addr swarm.Address) {
+	delete(f.missed, addr.String())
+}
+
+func (f *ForgettingStore) isMiss(addr swarm.Address) bool {
+	_, ok := f.missed[addr.String()]
+	return ok
+}
+
+func (f *ForgettingStore) Reset() {
+	f.missed = make(map[string]struct{})
+}
+
+func (f *ForgettingStore) Missed() int {
+	return len(f.missed)
+}
+
+// Get implements the ChunkStore interface.
+// In the recording phase, it records the chunk address as missed and delegates to the embedded store.
+// In the forgetting phase, it returns ErrNotFound if the chunk address was recorded as missed.
+func (f *ForgettingStore) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	if f.record.Load() {
+		f.miss(addr)
+	} else if f.isMiss(addr) {
+		return nil, storage.ErrNotFound
+	}
+	return f.ChunkStore.Get(ctx, addr)
+}
+
+// Put implements the ChunkStore interface.
+func (f *ForgettingStore) Put(ctx context.Context, ch swarm.Chunk) (err error) {
+	f.n.Add(1)
+	if !f.record.Load() {
+		f.Unmiss(ch.Address())
+	}
+	return f.ChunkStore.Put(ctx, ch)
+}
diff --git a/pkg/util/testutil/pseudorand/reader.go b/pkg/util/testutil/pseudorand/reader.go
new file mode 100644
index 00000000000..d7fcf0c3646
--- /dev/null
+++ b/pkg/util/testutil/pseudorand/reader.go
@@ -0,0 +1,182 @@
+// Copyright 2023 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This is a pseudorandom reader that generates a deterministic
+// sequence of bytes based on the seed. It is used in tests to
+// enable large volumes of pseudorandom data to be generated
+// and compared without having to store the data in memory.
+
+package pseudorand
+
+import (
+	"bytes"
+	"crypto/rand"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+
+	"github.com/ethersphere/bee/pkg/swarm"
+)
+
+const bufSize = 4096
+
+// Reader is a pseudorandom reader that generates a deterministic
+// sequence of bytes based on the seed.
+type Reader struct {
+	cur int
+	len int
+	seg [40]byte
+	buf [bufSize]byte
+}
+
+// NewSeed creates a new seed.
+func NewSeed() ([]byte, error) {
+	seed := make([]byte, 32)
+	_, err := io.ReadFull(rand.Reader, seed)
+	return seed, err
+}
+
+// NewReader creates a new pseudorandom reader seeded with the given seed.
+func NewReader(seed []byte, l int) *Reader {
+	r := &Reader{len: l}
+	_ = copy(r.buf[8:], seed)
+	r.fill()
+	return r
+}
+
+// Size returns the size of the reader.
+func (r *Reader) Size() int {
+	return r.len
+}
+
+// Read reads len(buf) bytes into buf. It returns the number of bytes
+// read (0 <= n <= len(buf)) and any error encountered. Even if Read
+// returns n < len(buf), it may use all of buf as scratch space during
+// the call. If some data is available but not len(buf) bytes, Read
+// conventionally returns what is available instead of waiting for more.
+func (r *Reader) Read(buf []byte) (n int, err error) {
+	cur := r.cur % bufSize
+	toRead := min(bufSize-cur, r.len-r.cur)
+	if toRead < len(buf) {
+		buf = buf[:toRead]
+	}
+	n = copy(buf, r.buf[cur:])
+	r.cur += n
+	if r.cur == r.len {
+		return n, io.EOF
+	}
+	if r.cur%bufSize == 0 {
+		r.fill()
+	}
+	return n, nil
+}
+
+// Equal compares the contents of the reader with the contents of
+// the given reader. It returns true if the contents are identical
+// and of equal length.
+func (r1 *Reader) Equal(r2 io.Reader) (bool, error) {
+	ok, err := r1.Match(r2, r1.len)
+	if err != nil {
+		return false, err
+	}
+	if !ok {
+		return false, nil
+	}
+	n, err := io.ReadFull(r2, make([]byte, 1))
+	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
+		return n == 0, nil
+	}
+	return false, err
+}
+
+// Match compares the contents of the reader with the contents of
+// the given reader. It returns true if the contents are equal up to l bytes.
+func (r1 *Reader) Match(r2 io.Reader, l int) (bool, error) {
+
+	read := func(r io.Reader, buf []byte) (n int, err error) {
+		for n < len(buf) && err == nil {
+			i, e := r.Read(buf[n:])
+			if e == nil && i == 0 {
+				return n, nil
+			}
+			err = e
+			n += i
+		}
+		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
+			err = nil
+		}
+		return n, err
+	}
+
+	buf1 := make([]byte, bufSize)
+	buf2 := make([]byte, bufSize)
+	for l > 0 {
+		if l <= len(buf1) {
+			buf1 = buf1[:l]
+			buf2 = buf2[:l]
+		}
+
+		n1, err := read(r1, buf1)
+		if err != nil {
+			return false, err
+		}
+		n2, err := read(r2, buf2)
+		if err != nil {
+			return false, err
+		}
+		if n1 != n2 {
+			return false, nil
+		}
+		if !bytes.Equal(buf1[:n1], buf2[:n2]) {
+			return false, nil
+		}
+		l -= n1
+	}
+	return true, nil
+}
+
+// Seek sets the offset for the next Read to offset, interpreted
+// according to whence: 0 means relative to the start of the file,
+// 1 means relative to the current offset, and 2 means relative to
+// the end. It returns the new offset and an error, if any.
+func (r *Reader) Seek(offset int64, whence int) (int64, error) {
+	switch whence {
+	case 0:
+		r.cur = int(offset)
+	case 1:
+		r.cur += int(offset)
+	case 2:
+		r.cur = r.len - int(offset)
+	}
+	if r.cur < 0 || r.cur > r.len {
+		return 0, fmt.Errorf("seek: invalid offset %d", r.cur)
+	}
+	r.fill()
+	return int64(r.cur), nil
+}
+
+// Offset returns the current offset of the reader.
+func (r *Reader) Offset() int64 {
+	return int64(r.cur)
+}
+
+// ReadAt reads len(buf) bytes into buf starting at offset off.
+func (r *Reader) ReadAt(buf []byte, off int64) (n int, err error) {
+	if _, err := r.Seek(off, io.SeekStart); err != nil {
+		return 0, err
+	}
+	return r.Read(buf)
+}
+
+// fill fills the buffer with the hash of the current segment.
+func (r *Reader) fill() {
+	if r.cur >= r.len {
+		return
+	}
+	bufSegments := bufSize / 32
+	start := r.cur / bufSegments
+	rem := (r.cur % bufSize) / 32
+	h := swarm.NewHasher()
+	for i := 32 * rem; i < len(r.buf); i += 32 {
+		binary.BigEndian.PutUint64(r.seg[:], uint64((start+i)/32))
+		h.Reset()
+		_, _ = h.Write(r.seg[:])
+		copy(r.buf[i:], h.Sum(nil))
+	}
+}
diff --git a/pkg/util/testutil/pseudorand/reader_test.go b/pkg/util/testutil/pseudorand/reader_test.go
new file mode 100644
index 00000000000..4ec85a90d73
--- /dev/null
+++ b/pkg/util/testutil/pseudorand/reader_test.go
@@ -0,0 +1,121 @@
+// Copyright 2023 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package pseudorand_test
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"io"
+	"math/rand"
+	"testing"
+
+	"github.com/ethersphere/bee/pkg/util/testutil/pseudorand"
+)
+
+func TestReader(t *testing.T) {
+	size := 42000
+	seed := make([]byte, 32)
+	r := pseudorand.NewReader(seed, size)
+	content, err := io.ReadAll(r)
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Run("deterministicity", func(t *testing.T) {
+		r2 := pseudorand.NewReader(seed, size)
+		content2, err := io.ReadAll(r2)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if !bytes.Equal(content, content2) {
+			t.Fatal("content mismatch")
+		}
+	})
+	t.Run("randomness", func(t *testing.T) {
+		bufSize := 4096
+		if bytes.Equal(content[:bufSize], content[bufSize:2*bufSize]) {
+			t.Fatal("buffers should not match")
+		}
+	})
+	t.Run("re-readability", func(t *testing.T) {
+		ns, err := r.Seek(0, io.SeekStart)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if ns != 0 {
+			t.Fatal("seek mismatch")
+		}
+		var read []byte
+		buf := make([]byte, 8200)
+		total := 0
+		for {
+			s := rand.Intn(820)
+			n, err := r.Read(buf[:s])
+			total += n
+			read = append(read, buf[:n]...)
+			if errors.Is(err, io.EOF) {
+				break
+			}
+			if err != nil {
+				t.Fatal(err)
+			}
+		}
+		read = read[:total]
+		if len(read) != len(content) {
expected", len(content), "got", len(read)) + } + if !bytes.Equal(content, read) { + t.Fatal("content mismatch") + } + }) + t.Run("comparison", func(t *testing.T) { + ns, err := r.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + if ns != 0 { + t.Fatal("seek mismatch") + } + if eq, err := r.Equal(bytes.NewBuffer(content)); err != nil { + t.Fatal(err) + } else if !eq { + t.Fatal("content mismatch") + } + ns, err = r.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + if ns != 0 { + t.Fatal("seek mismatch") + } + if eq, err := r.Equal(bytes.NewBuffer(content[:len(content)-1])); err != nil { + t.Fatal(err) + } else if eq { + t.Fatal("content should match") + } + }) + t.Run("seek and match", func(t *testing.T) { + for i := 0; i < 20; i++ { + off := rand.Intn(size) + n := rand.Intn(size - off) + t.Run(fmt.Sprintf("off=%d n=%d", off, n), func(t *testing.T) { + ns, err := r.Seek(int64(off), io.SeekStart) + if err != nil { + t.Fatal(err) + } + if ns != int64(off) { + t.Fatal("seek mismatch") + } + ok, err := r.Match(bytes.NewBuffer(content[off:off+n]), n) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("content mismatch") + } + }) + } + }) +} diff --git a/pkg/util/testutil/racedetection/off.go b/pkg/util/testutil/racedetection/off.go new file mode 100644 index 00000000000..d57125bfd03 --- /dev/null +++ b/pkg/util/testutil/racedetection/off.go @@ -0,0 +1,10 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !race +// +build !race + +package racedetection + +const On = false diff --git a/pkg/util/testutil/racedetection/on.go b/pkg/util/testutil/racedetection/on.go new file mode 100644 index 00000000000..92438ecfdc3 --- /dev/null +++ b/pkg/util/testutil/racedetection/on.go @@ -0,0 +1,10 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build race +// +build race + +package racedetection + +const On = true diff --git a/pkg/util/testutil/racedetection/race.go b/pkg/util/testutil/racedetection/race.go new file mode 100644 index 00000000000..411295e87f6 --- /dev/null +++ b/pkg/util/testutil/racedetection/race.go @@ -0,0 +1,9 @@ +// Copyright 2020 The Swarm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package racedetection + +func IsOn() bool { + return On +}