From 166974d2e3e6ced2d7fa49667a0f4718de696997 Mon Sep 17 00:00:00 2001 From: zelig Date: Wed, 28 Feb 2024 11:46:36 +0100 Subject: [PATCH] fix(getter): fix freeze and combine errors return --- pkg/api/bzz.go | 4 +- pkg/api/bzz_test.go | 28 +- pkg/file/joiner/joiner_test.go | 19 +- pkg/file/redundancy/getter/getter.go | 443 +++++++++------------- pkg/file/redundancy/getter/getter_test.go | 13 +- pkg/file/redundancy/getter/strategies.go | 151 +++++++- 6 files changed, 338 insertions(+), 320 deletions(-) diff --git a/pkg/api/bzz.go b/pkg/api/bzz.go index 1102b313607..91d8fac56fe 100644 --- a/pkg/api/bzz.go +++ b/pkg/api/bzz.go @@ -327,7 +327,7 @@ func (s *Service) serveReference(logger log.Logger, address swarm.Address, pathV strategyTimeout := getter.DefaultStrategyTimeout.String() ctx := r.Context() - ctx, err := getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, &strategyTimeout, logger) + ctx, err := getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, &strategyTimeout) if err != nil { logger.Error(err, err.Error()) jsonhttp.BadRequest(w, "could not parse headers") @@ -521,7 +521,7 @@ func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *h strategyTimeout := getter.DefaultStrategyTimeout.String() ctx := r.Context() - ctx, err := getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, &strategyTimeout, logger) + ctx, err := getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, &strategyTimeout) if err != nil { logger.Error(err, err.Error()) jsonhttp.BadRequest(w, "could not parse headers") diff --git a/pkg/api/bzz_test.go b/pkg/api/bzz_test.go index fdcb3f1209f..fbdf9be8432 100644 --- a/pkg/api/bzz_test.go +++ b/pkg/api/bzz_test.go @@ -151,33 +151,20 @@ func TestBzzUploadDownloadWithRedundancy(t *testing.T) { if rLevel == 0 { t.Skip("NA") } - req, err := http.NewRequestWithContext(context.Background(), "GET", fileDownloadResource(refResponse.Reference.String()), nil) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, "GET", fileDownloadResource(refResponse.Reference.String()), nil) if err != nil { t.Fatal(err) } + req.Header.Set(api.SwarmLookAheadBufferSizeHeader, "0") req.Header.Set(api.SwarmRedundancyStrategyHeader, "0") req.Header.Set(api.SwarmRedundancyFallbackModeHeader, "false") req.Header.Set(api.SwarmChunkRetrievalTimeoutHeader, fetchTimeout.String()) - resp, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected status %d; got %d", http.StatusOK, resp.StatusCode) - } - _, err = dataReader.Seek(0, io.SeekStart) - if err != nil { - t.Fatal(err) - } - ok, err := dataReader.Equal(resp.Body) - if err != nil { - t.Fatal(err) - } - if ok { - t.Fatal("there should be missing data") + _, err = client.Do(req) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected error %v; got %v", io.ErrUnexpectedEOF, err) } }) @@ -186,6 +173,7 @@ func TestBzzUploadDownloadWithRedundancy(t *testing.T) { if err != nil { t.Fatal(err) } + req.Header.Set(api.SwarmLookAheadBufferSizeHeader, "0") req.Header.Set(api.SwarmRedundancyStrategyHeader, "3") req.Header.Set(api.SwarmRedundancyFallbackModeHeader, "true") req.Header.Set(api.SwarmChunkRetrievalTimeoutHeader, fetchTimeout.String()) diff --git 
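On the API side the handlers now derive the retrieval configuration purely from the request headers via SetConfigInContext (the logger argument is gone), and the updated test asserts that a download with fallback disabled ends in context.DeadlineExceeded rather than a truncated body. Below is a minimal client-side sketch of that behaviour; the node URL and the /bzz reference are placeholders, only the header constants come from the test above.

package main

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"time"

	"github.com/ethersphere/bee/pkg/api"
)

func main() {
	// Placeholder endpoint; substitute a real node address and a real /bzz reference.
	url := "http://localhost:1633/bzz/<reference>/"

	// Bound the whole request: with fallback disabled and chunks missing,
	// retrieval stalls and the client sees the context deadline instead of data.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set(api.SwarmLookAheadBufferSizeHeader, "0")
	req.Header.Set(api.SwarmRedundancyStrategyHeader, "0")
	req.Header.Set(api.SwarmRedundancyFallbackModeHeader, "false")
	req.Header.Set(api.SwarmChunkRetrievalTimeoutHeader, (100 * time.Millisecond).String())

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		if errors.Is(err, context.DeadlineExceeded) {
			fmt.Println("download timed out: data missing and fallback disabled")
		}
		return
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}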
a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index bfb6e4d2cd6..125f2d727d5 100644 --- a/pkg/file/joiner/joiner_test.go +++ b/pkg/file/joiner/joiner_test.go @@ -24,7 +24,6 @@ import ( "github.com/ethersphere/bee/pkg/file/redundancy/getter" "github.com/ethersphere/bee/pkg/file/splitter" filetest "github.com/ethersphere/bee/pkg/file/testing" - "github.com/ethersphere/bee/pkg/log" storage "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" testingc "github.com/ethersphere/bee/pkg/storage/testing" @@ -32,6 +31,7 @@ import ( "github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/util/testutil" "github.com/ethersphere/bee/pkg/util/testutil/pseudorand" + "github.com/ethersphere/bee/pkg/util/testutil/racedetection" "gitlab.com/nolash/go-mockbytes" "golang.org/x/sync/errgroup" ) @@ -1112,14 +1112,15 @@ func TestJoinerRedundancy(t *testing.T) { strategyTimeout := 100 * time.Millisecond // all data can be read back readCheck := func(t *testing.T, expErr error) { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 15*strategyTimeout) + defer cancel() strategyTimeoutStr := strategyTimeout.String() decodeTimeoutStr := (10 * strategyTimeout).String() fallback := true s := getter.RACE - ctx, err := getter.SetConfigInContext(ctx, &s, &fallback, &decodeTimeoutStr, &strategyTimeoutStr, log.Noop) + ctx, err := getter.SetConfigInContext(ctx, &s, &fallback, &decodeTimeoutStr, &strategyTimeoutStr) if err != nil { t.Fatal(err) } @@ -1168,14 +1169,14 @@ func TestJoinerRedundancy(t *testing.T) { } } t.Run("no recovery possible with no chunk stored", func(t *testing.T) { - readCheck(t, storage.ErrNotFound) + readCheck(t, context.DeadlineExceeded) }) if err := putter.store(shardCnt - 1); err != nil { t.Fatal(err) } t.Run("no recovery possible with 1 short of shardCnt chunks stored", func(t *testing.T) { - readCheck(t, storage.ErrNotFound) + readCheck(t, context.DeadlineExceeded) }) if err := putter.store(1); err != nil { @@ -1252,15 +1253,21 @@ func TestJoinerRedundancyMultilevel(t *testing.T) { canReadRange := func(t *testing.T, s getter.Strategy, fallback bool, levels int, canRead bool) { ctx := context.Background() strategyTimeout := 100 * time.Millisecond + decodingTimeout := 600 * time.Millisecond + if racedetection.IsOn() { + decodingTimeout *= 2 + } strategyTimeoutStr := strategyTimeout.String() decodingTimeoutStr := (2 * strategyTimeout).String() - ctx, err := getter.SetConfigInContext(ctx, &s, &fallback, &decodingTimeoutStr, &strategyTimeoutStr, log.Noop) + ctx, err := getter.SetConfigInContext(ctx, &s, &fallback, &decodingTimeoutStr, &strategyTimeoutStr) if err != nil { t.Fatal(err) } + ctx, cancel := context.WithTimeout(ctx, time.Duration(levels)*(3*strategyTimeout+decodingTimeout)) + defer cancel() j, _, err := joiner.New(ctx, store, store, addr) if err != nil { t.Fatal(err) diff --git a/pkg/file/redundancy/getter/getter.go b/pkg/file/redundancy/getter/getter.go index ce1181b045d..77c6f1e0168 100644 --- a/pkg/file/redundancy/getter/getter.go +++ b/pkg/file/redundancy/getter/getter.go @@ -7,46 +7,42 @@ package getter import ( "context" "errors" + "fmt" "io" "sync" "sync/atomic" - "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/swarm" "github.com/klauspost/reedsolomon" ) -var ( - errStrategyNotAllowed = errors.New("strategy not allowed") - errStrategyFailed = errors.New("strategy failed") -) - // decoder is a private 
implementation of storage.Getter // if retrieves children of an intermediate chunk potentially using erasure decoding // it caches sibling chunks if erasure decoding started already type decoder struct { - fetcher storage.Getter // network retrieval interface to fetch chunks - putter storage.Putter // interface to local storage to save reconstructed chunks - addrs []swarm.Address // all addresses of the intermediate chunk - inflight []atomic.Bool // locks to protect wait channels and RS buffer - cache map[string]int // map from chunk address shard position index - waits []chan error // wait channels for each chunk - rsbuf [][]byte // RS buffer of data + parity shards for erasure decoding - goodRecovery chan struct{} // signal channel for successful retrieval of shardCnt chunks - badRecovery chan struct{} // signals that either the recovery has failed or not allowed to run - lastLen int // length of the last data chunk in the RS buffer - shardCnt int // number of data shards - parityCnt int // number of parity shards - wg sync.WaitGroup // wait group to wait for all goroutines to finish - mu sync.Mutex // mutex to protect buffer - err error // error of the last erasure decoding - fetchedCnt atomic.Int32 // count successful retrievals - failedCnt atomic.Int32 // count successful retrievals - cancel func() // cancel function for RS decoding - remove func() // callback to remove decoder from decoders cache - config Config // configuration - logger log.Logger + fetcher storage.Getter // network retrieval interface to fetch chunks + putter storage.Putter // interface to local storage to save reconstructed chunks + cache map[string]int // map from chunk address shard position index + addrs []swarm.Address // all addresses of the intermediate chunk + inflight []atomic.Bool // locks to protect wait channels and RS buffer + waits []chan struct{} // wait channels for each chunk + derrs []error // decoding errors + ferrs []error // fetch errors + decoded chan struct{} // signal that the decoding is finished + err error // error of the last erasure decoding + rsbuf [][]byte // RS buffer of data + parity shards for erasure decoding + chunks [][]byte // chunks fetched + lastLen int // length of the last data chunk in the RS buffer + shardCnt int // number of data shards + parityCnt int // number of parity shards + fetched *counter // count number of fetched chunks + failed *counter // count number of failed retrievals + mu sync.Mutex // mutex to protect the decoder state + wg sync.WaitGroup // wait group to wait for all goroutines to finish + cancel func() // cancel function for RS decoding + remove func() // callback to remove decoder from decoders cache + config Config // configuration } type Getter interface { @@ -60,242 +56,184 @@ func New(addrs []swarm.Address, shardCnt int, g storage.Getter, p storage.Putter size := len(addrs) d := &decoder{ - fetcher: g, - putter: p, - addrs: addrs, - inflight: make([]atomic.Bool, size), - cache: make(map[string]int, size), - waits: make([]chan error, size), - rsbuf: make([][]byte, size), - goodRecovery: make(chan struct{}), - badRecovery: make(chan struct{}), - cancel: cancel, - remove: remove, - shardCnt: shardCnt, - parityCnt: size - shardCnt, - config: conf, - logger: conf.Logger.WithName("redundancy").Build(), + fetcher: g, + putter: p, + addrs: addrs, + cache: make(map[string]int, size), + inflight: make([]atomic.Bool, shardCnt), + waits: make([]chan struct{}, shardCnt), + ferrs: make([]error, shardCnt), + derrs: make([]error, shardCnt), + decoded: make(chan 
struct{}), + rsbuf: make([][]byte, size), + chunks: make([][]byte, size), + fetched: newCounter(shardCnt - 1), + failed: newCounter(size - shardCnt), + cancel: cancel, + remove: remove, + shardCnt: shardCnt, + parityCnt: size - shardCnt, + config: conf, } - // after init, cache and wait channels are immutable, need no locking - for i := 0; i < shardCnt; i++ { - d.cache[addrs[i].ByteString()] = i + if conf.Strategy == RACE || !conf.Strict { + go func() { // if not enough shards are retrieved, signal that decoding is finished + select { + case <-ctx.Done(): + case <-d.failed.c: + d.close(ErrNotEnoughShards) + } + }() + } else { + d.fetched.cancel() + d.failed.cancel() + d.close(ErrRecoveryUnavailable) } - // after init, cache and wait channels are immutable, need no locking - for i := 0; i < size; i++ { - d.waits[i] = make(chan error) + // init cache and wait channels + // after init, they are immutable, need no locking + for i := 0; i < shardCnt; i++ { + d.cache[addrs[i].ByteString()] = i + d.waits[i] = make(chan struct{}) } // prefetch chunks according to strategy - if !conf.Strict || conf.Strategy != NONE { - d.wg.Add(1) - go func() { - defer d.wg.Done() - d.err = d.prefetch(ctx) - }() - } else { // recovery not allowed - close(d.badRecovery) - } - + d.wg.Add(1) + go d.run(ctx) return d } // Get will call parities and other sibling chunks if the chunk address cannot be retrieved // assumes it is called for data shards only -func (g *decoder) Get(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) { +func (g *decoder) Get(ctx context.Context, addr swarm.Address) (c swarm.Chunk, err error) { i, ok := g.cache[addr.ByteString()] if !ok { return nil, storage.ErrNotFound } - err := g.fetch(ctx, i, true) - if err != nil { - return nil, err - } - return swarm.NewChunk(addr, g.getData(i)), nil -} - -// fetch retrieves a chunk from the netstore if it is the first time the chunk is fetched. -// If the fetch fails and waiting for the recovery is allowed, the function will wait -// for either a good or bad recovery signal. 
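For orientation, here is a minimal construction sketch showing how a caller wires up the rewritten decoder, mirroring the updated getter_test. The zero-valued addresses and the in-memory chunk store are stand-ins; with nothing stored, Get is expected to fail once the failure counter trips and the decoded channel is closed with an error.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/ethersphere/bee/pkg/file/redundancy/getter"
	inmem "github.com/ethersphere/bee/pkg/storage/inmemchunkstore"
	"github.com/ethersphere/bee/pkg/swarm"
)

func main() {
	store := inmem.New() // serves as both storage.Getter and storage.Putter here

	// addrs lists data shards first, then parity shards; shardCnt is the data count.
	shardCnt, parityCnt := 4, 2
	addrs := make([]swarm.Address, shardCnt+parityCnt) // placeholder addresses

	conf := getter.Config{
		Strategy:        getter.RACE, // race data and parity retrieval, recovery allowed
		Strict:          false,
		FetchTimeout:    500 * time.Millisecond,
		StrategyTimeout: 100 * time.Millisecond,
	}

	// The remove callback is a no-op here; the joiner uses it to drop the
	// decoder from its cache once the prefetch loop finishes.
	g := getter.New(addrs, shardCnt, store, store, func() {}, conf)
	defer g.Close()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	// Get blocks until the chunk is fetched, recovered, or the context expires.
	if _, err := g.Get(ctx, addrs[0]); err != nil {
		fmt.Println("get failed:", err)
	}
}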
-func (g *decoder) fetch(ctx context.Context, i int, waitForRecovery bool) (err error) { - - waitRecovery := func(err error) error { - if !waitForRecovery { - return err - } - - select { - case <-g.badRecovery: - return storage.ErrNotFound - case <-g.goodRecovery: - g.logger.Debug("recovered chunk", "address", g.addrs[i]) - return nil - case <-ctx.Done(): - return ctx.Err() - } - } - - // first time if g.fly(i) { - - fctx, cancel := context.WithTimeout(ctx, g.config.FetchTimeout) - defer cancel() - g.wg.Add(1) - defer g.wg.Done() - - // retrieval - ch, err := g.fetcher.Get(fctx, g.addrs[i]) - if err != nil { - g.failedCnt.Add(1) - close(g.waits[i]) - return waitRecovery(err) - } - - g.fetchedCnt.Add(1) - g.setData(i, ch.Data()) - close(g.waits[i]) - return nil - } - - select { - case <-g.waits[i]: - case <-ctx.Done(): - return ctx.Err() + go g.fetch(ctx, i) } - if g.getData(i) != nil { - return nil - } + fetched := g.waits[i] + decoded := g.decoded + for { + select { + case <-fetched: + // if the chunk is retrieval is completed and there is no error, return the chunk + if g.ferrs[i] == nil { + return swarm.NewChunk(addr, g.getData(i, g.chunks)), nil + } + fetched = nil - return waitRecovery(storage.ErrNotFound) -} + case <-decoded: + // if the RS decoding is completed + // if there was no error, return the chunk from the RS buffer (recovery) + if g.err == nil { + return swarm.NewChunk(addr, g.getData(i, g.rsbuf)), nil + } -func (g *decoder) prefetch(ctx context.Context) error { - defer g.remove() + // otherwise (if there was an error), and chunk retrieval had already been attempted, + // return the combined error of fetching and the decoding + if fetched == nil { + return nil, errors.Join(g.err, g.ferrs[i]) + } + // continue waiting for retrieval to complete + // disable this case and enable the case waiting for retrieval to complete + decoded = nil - run := func(s Strategy) error { - if err := g.runStrategy(ctx, s); err != nil { - return err + case <-ctx.Done(): + // if the context is cancelled, return the error + return nil, errors.Join(g.err, fmt.Errorf("Get: %w", ctx.Err())) } - - return g.recover(ctx) } +} - var err error - for s := g.config.Strategy; s < strategyCnt; s++ { - - err = run(s) - if err != nil { - if s == DATA || s == RACE { - g.logger.Debug("failed recovery", "strategy", s) - } - } - if err == nil { - if s > DATA { - g.logger.Debug("successful recovery", "strategy", s) - } - close(g.goodRecovery) - break - } - if g.config.Strict { // only run one strategy - break - } +// setData sets the data shard in the chunks slice +func (g *decoder) setData(i int, chdata []byte) { + data := chdata + // pad the chunk with zeros if it is smaller than swarm.ChunkSize + if len(data) < swarm.ChunkWithSpanSize { + g.lastLen = len(data) + data = make([]byte, swarm.ChunkWithSpanSize) + copy(data, chdata) } + g.chunks[i] = data +} - if err != nil { - close(g.badRecovery) - return err +// getData returns the data shard from the RS buffer +func (g *decoder) getData(i int, s [][]byte) []byte { + if i == g.shardCnt-1 && g.lastLen > 0 { + return s[i][:g.lastLen] // cut padding } + return s[i] +} - return err +// fly commits to retrieve the chunk (fly and land) +// it marks a chunk as inflight and returns true unless it is already inflight +// the atomic bool implements the signal for a singleflight pattern +func (g *decoder) fly(i int) (success bool) { + return g.inflight[i].CompareAndSwap(false, true) } -func (g *decoder) runStrategy(ctx context.Context, s Strategy) error { - - // across the 
different strategies, the common goal is to fetch at least as many chunks - // as the number of data shards. - // DATA strategy has a max error tolerance of zero. - // RACE strategy has a max error tolerance of number of parity chunks. - var allowedErrs int - var m []int - - switch s { - case NONE: - return errStrategyNotAllowed - case DATA: - // only retrieve data shards - m = g.unattemptedDataShards() - allowedErrs = 0 - case PROX: - // proximity driven selective fetching - // NOT IMPLEMENTED - return errStrategyNotAllowed - case RACE: - allowedErrs = g.parityCnt - // retrieve all chunks at once enabling race among chunks - m = g.unattemptedDataShards() - for i := g.shardCnt; i < len(g.addrs); i++ { - m = append(m, i) +// fetch retrieves a chunk from the underlying storage +// it must be called asynchonously and only once for each chunk (singleflight pattern) +// it races with erasure recovery; the latter takes precedence even if it started later +// due to the fact that erasure recovery could only implement global locking on all shards +func (g *decoder) fetch(ctx context.Context, i int) { + defer g.wg.Done() + // set the timeout context for the fetch + fctx, cancel := context.WithTimeout(ctx, g.config.FetchTimeout) + defer cancel() + + // retrieve the chunk using the underlying storage + ch, err := g.fetcher.Get(fctx, g.addrs[i]) + // if there was an error, the error and return + g.mu.Lock() + defer g.mu.Unlock() + // whatever happens, signal that the chunk retrieval is finished + if err != nil { + if i < g.shardCnt { + g.ferrs[i] = err } + g.failed.inc() + return } - - if len(m) == 0 { - return nil + if i < g.shardCnt { + defer close(g.waits[i]) } - c := make(chan error, len(m)) + // write chunk to rsbuf and signal waiters + g.setData(i, ch.Data()) // save the chunk in the RS buffer - for _, i := range m { - g.wg.Add(1) - go func(i int) { - defer g.wg.Done() - c <- g.fetch(ctx, i, false) - }(i) - } + // if all chunks are retrieved, signal ready + g.fetched.inc() +} - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-c: - if g.fetchedCnt.Load() >= int32(g.shardCnt) { - return nil +// missing gathers missing data shards not yet retrieved or retrieved with an error +// it sets the chunk as inflight and returns the index of the missing data shards +func (g *decoder) missing() (m []int) { + // initialize RS buffer + for i, ch := range g.chunks { + if len(ch) > 0 { + g.rsbuf[i] = ch + if i < g.shardCnt { + _ = g.fly(i) // commit (RS) or will commit to retrieve the chunk } - if g.failedCnt.Load() > int32(allowedErrs) { - return errStrategyFailed + } else { + if i < g.shardCnt { + g.derrs[i] = g.ferrs[i] + m = append(m, i) } } } -} - -// recover wraps the stages of data shard recovery: -// 1. gather missing data shards -// 2. decode using Reed-Solomon decoder -// 3. 
save reconstructed chunks -func (g *decoder) recover(ctx context.Context) error { - // gather missing shards - m := g.missingDataShards() - if len(m) == 0 { - return nil // recovery is not needed as there are no missing data chunks - } - - // decode using Reed-Solomon decoder - if err := g.decode(ctx); err != nil { - return err - } - - // save chunks - return g.save(ctx, m) + return m } // decode uses Reed-Solomon erasure coding decoder to recover data shards -// it must be called after shqrdcnt shards are retrieved +// it must be called after shardcnt shards are retrieved +// it must be called under mutex protection func (g *decoder) decode(ctx context.Context) error { - g.mu.Lock() - defer g.mu.Unlock() - enc, err := reedsolomon.New(g.shardCnt, g.parityCnt) if err != nil { return err @@ -305,72 +243,47 @@ func (g *decoder) decode(ctx context.Context) error { return enc.ReconstructData(g.rsbuf) } -func (g *decoder) unattemptedDataShards() (m []int) { - for i := 0; i < g.shardCnt; i++ { - select { - case <-g.waits[i]: // attempted - continue - default: - m = append(m, i) // remember the missing chunk - } - } - return m -} - -// it must be called under mutex protection -func (g *decoder) missingDataShards() (m []int) { - for i := 0; i < g.shardCnt; i++ { - if g.getData(i) == nil { - m = append(m, i) - } - } - return m -} +// recover wraps the stages of data shard recovery: +// 1. gather missing data shards +// 2. decode using Reed-Solomon decoder +// 3. save reconstructed chunks +func (g *decoder) recover(ctx context.Context) (err error) { -// setData sets the data shard in the RS buffer -func (g *decoder) setData(i int, chdata []byte) { + defer func() { + g.close(err) + }() g.mu.Lock() defer g.mu.Unlock() - data := chdata - // pad the chunk with zeros if it is smaller than swarm.ChunkSize - if len(data) < swarm.ChunkWithSpanSize { - g.lastLen = len(data) - data = make([]byte, swarm.ChunkWithSpanSize) - copy(data, chdata) - } - g.rsbuf[i] = data -} + // gather missing shards + m := g.missing() -// getData returns the data shard from the RS buffer -func (g *decoder) getData(i int) []byte { - g.mu.Lock() - defer g.mu.Unlock() - if i == g.shardCnt-1 && g.lastLen > 0 { - return g.rsbuf[i][:g.lastLen] // cut padding + // decode using Reed-Solomon decoder + err = g.decode(ctx) + if err != nil { + return err } - return g.rsbuf[i] -} -// fly commits to retrieve the chunk (fly and land) -// it marks a chunk as inflight and returns true unless it is already inflight -// the atomic bool implements a singleflight pattern -func (g *decoder) fly(i int) (success bool) { - return g.inflight[i].CompareAndSwap(false, true) + // save chunks + err = g.save(ctx, m) + return err } // save iterate over reconstructed shards and puts the corresponding chunks to local storage func (g *decoder) save(ctx context.Context, missing []int) error { - g.mu.Lock() - defer g.mu.Unlock() for _, i := range missing { - if err := g.putter.Put(ctx, swarm.NewChunk(g.addrs[i], g.rsbuf[i])); err != nil { + if err := g.putter.Put(ctx, swarm.NewChunk(g.addrs[i], g.getData(i, g.rsbuf))); err != nil { return err } } return nil } +func (g *decoder) close(err error) { + g.err = err + close(g.decoded) +} + // Close terminates the prefetch loop, waits for all goroutines to finish and // removes the decoder from the cache // it implements the io.Closer interface diff --git a/pkg/file/redundancy/getter/getter_test.go b/pkg/file/redundancy/getter/getter_test.go index 95609fc4c1e..8be841fd688 100644 --- 
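The core of the rewrite is that Get no longer performs the retrieval itself: it starts the fetch at most once (fly) and then waits on two signals, the per-chunk wait channel and the shared decoded channel, whichever resolves the chunk first, joining the fetch and decode errors if both fail. A stripped-down, self-contained sketch of that wait loop follows; the shard and recovery names are illustrative, not part of the patch.

package main

import (
	"context"
	"errors"
	"fmt"
	"sync/atomic"
	"time"
)

// shard mirrors the decoder's per-chunk state: a singleflight flag and a wait
// channel that is closed when its fetch attempt finishes (success or failure).
type shard struct {
	inflight atomic.Bool
	done     chan struct{}
	err      error
	data     []byte
}

func main() {
	s := &shard{done: make(chan struct{})}
	recovered := make(chan struct{}) // stands in for the decoder's decoded channel
	var recoveryErr error            // set before recovered is closed, read after

	// Lazily start the fetch exactly once, as decoder.Get does via fly().
	if s.inflight.CompareAndSwap(false, true) {
		go func() {
			time.Sleep(50 * time.Millisecond)
			s.err = errors.New("retrieval failed") // simulate a failed fetch
			close(s.done)
		}()
	}
	go func() { // simulate erasure recovery finishing later with no error
		time.Sleep(100 * time.Millisecond)
		s.data = []byte("reconstructed")
		close(recovered)
	}()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	fetched, decoded := s.done, recovered
	for {
		select {
		case <-fetched:
			if s.err == nil {
				fmt.Printf("fetched: %s\n", s.data)
				return
			}
			fetched = nil // fetch failed: fall back to waiting for recovery
		case <-decoded:
			if recoveryErr == nil {
				fmt.Printf("recovered: %s\n", s.data)
				return
			}
			if fetched == nil {
				fmt.Println(errors.Join(recoveryErr, s.err))
				return
			}
			decoded = nil // recovery failed: keep waiting for the fetch
		case <-ctx.Done():
			fmt.Println(ctx.Err())
			return
		}
	}
}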
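The recovery path itself reduces to klauspost/reedsolomon's ReconstructData over rsbuf, with nil entries marking the shards that were never fetched, after which only the previously missing data chunks are saved back to local storage. A standalone round-trip sketch of that step; the shard sizes and counts here are arbitrary.

package main

import (
	"bytes"
	"fmt"

	"github.com/klauspost/reedsolomon"
)

func main() {
	const shardCnt, parityCnt = 4, 2

	enc, err := reedsolomon.New(shardCnt, parityCnt)
	if err != nil {
		panic(err)
	}

	// Build data shards and compute parity, as the uploader side would.
	shards := make([][]byte, shardCnt+parityCnt)
	for i := 0; i < shardCnt; i++ {
		shards[i] = bytes.Repeat([]byte{byte(i + 1)}, 16)
	}
	for i := shardCnt; i < shardCnt+parityCnt; i++ {
		shards[i] = make([]byte, 16)
	}
	if err := enc.Encode(shards); err != nil {
		panic(err)
	}

	// Simulate two failed retrievals: missing shards are nil entries, which is
	// exactly how decoder.missing() leaves unfetched positions in rsbuf.
	want := append([]byte(nil), shards[1]...)
	shards[1], shards[5] = nil, nil

	// ReconstructData only rebuilds the data shards, which is all the decoder
	// needs before saving the recovered chunks.
	if err := enc.ReconstructData(shards); err != nil {
		panic(err)
	}
	fmt.Println("recovered shard 1:", bytes.Equal(shards[1], want))
}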
a/pkg/file/redundancy/getter/getter_test.go +++ b/pkg/file/redundancy/getter/getter_test.go @@ -19,7 +19,6 @@ import ( "github.com/ethersphere/bee/pkg/cac" "github.com/ethersphere/bee/pkg/file/redundancy/getter" - "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/storage" inmem "github.com/ethersphere/bee/pkg/storage/inmemchunkstore" mockstorer "github.com/ethersphere/bee/pkg/storer/mock" @@ -39,20 +38,20 @@ func TestGetterRACE(t *testing.T) { } var tcs []getterTest - for bufSize := 3; bufSize <= 128; bufSize += 21 { + for bufSize := 45; bufSize <= 128; bufSize += 42 { for shardCnt := bufSize/2 + 1; shardCnt <= bufSize; shardCnt += 21 { parityCnt := bufSize - shardCnt erasures := mrand.Perm(parityCnt - 1) - if len(erasures) > 3 { - erasures = erasures[:3] + if len(erasures) > 1 { + erasures = erasures[:1] } for _, erasureCnt := range erasures { tcs = append(tcs, getterTest{bufSize, shardCnt, erasureCnt}) } tcs = append(tcs, getterTest{bufSize, shardCnt, parityCnt}, getterTest{bufSize, shardCnt, parityCnt + 1}) erasures = mrand.Perm(shardCnt - 1) - if len(erasures) > 3 { - erasures = erasures[:3] + if len(erasures) > 1 { + erasures = erasures[:1] } for _, erasureCnt := range erasures { tcs = append(tcs, getterTest{bufSize, shardCnt, erasureCnt + parityCnt + 1}) @@ -73,7 +72,6 @@ func TestGetterRACE(t *testing.T) { // TestGetterFallback tests the retrieval of chunks with missing data shards // using the strict or fallback mode starting with NONE and DATA strategies func TestGetterFallback(t *testing.T) { - t.Skip("removed strategy timeout") t.Run("GET", func(t *testing.T) { t.Run("NONE", func(t *testing.T) { t.Run("strict", func(t *testing.T) { @@ -121,7 +119,6 @@ func testDecodingRACE(t *testing.T, bufSize, shardCnt, erasureCnt int) { Strategy: getter.RACE, FetchTimeout: 2 * strategyTimeout, StrategyTimeout: strategyTimeout, - Logger: log.Noop, } g := getter.New(addrs, shardCnt, store, store, func() {}, conf) defer g.Close() diff --git a/pkg/file/redundancy/getter/strategies.go b/pkg/file/redundancy/getter/strategies.go index 410a0003962..c3ce5493af0 100644 --- a/pkg/file/redundancy/getter/strategies.go +++ b/pkg/file/redundancy/getter/strategies.go @@ -8,9 +8,9 @@ import ( "context" "errors" "fmt" + "sync/atomic" "time" - "github.com/ethersphere/bee/pkg/log" "github.com/ethersphere/bee/pkg/retrieval" ) @@ -21,12 +21,16 @@ const ( DefaultStrategyTimeout = 300 * time.Millisecond // timeout for each strategy ) +var ( + ErrNotEnoughShards = errors.New("not enough shards to reconstruct") + ErrRecoveryUnavailable = errors.New("recovery disabled for strategy") +) + type ( strategyKey struct{} modeKey struct{} fetchTimeoutKey struct{} strategyTimeoutKey struct{} - loggerKey struct{} Strategy = int ) @@ -36,7 +40,6 @@ type Config struct { Strict bool FetchTimeout time.Duration StrategyTimeout time.Duration - Logger log.Logger } const ( @@ -53,7 +56,6 @@ var DefaultConfig = Config{ Strict: DefaultStrict, FetchTimeout: DefaultFetchTimeout, StrategyTimeout: DefaultStrategyTimeout, - Logger: log.Noop, } // NewConfigFromContext returns a new Config based on the context @@ -90,12 +92,6 @@ func NewConfigFromContext(ctx context.Context, def Config) (conf Config, err err return conf, e("strategy timeout") } } - if val := ctx.Value(loggerKey{}); val != nil { - conf.Logger, ok = val.(log.Logger) - if !ok { - return conf, e("strategy timeout") - } - } return conf, nil } @@ -120,13 +116,8 @@ func SetStrategyTimeout(ctx context.Context, timeout time.Duration) context.Cont return 
context.WithValue(ctx, strategyTimeoutKey{}, timeout) } -// SetStrategyTimeout sets the timeout for each strategy -func SetLogger(ctx context.Context, l log.Logger) context.Context { - return context.WithValue(ctx, loggerKey{}, l) -} - // SetConfigInContext sets the config params in the context -func SetConfigInContext(ctx context.Context, s *Strategy, fallbackmode *bool, fetchTimeout, strategyTimeout *string, logger log.Logger) (context.Context, error) { +func SetConfigInContext(ctx context.Context, s *Strategy, fallbackmode *bool, fetchTimeout, strategyTimeout *string) (context.Context, error) { if s != nil { ctx = SetStrategy(ctx, *s) } @@ -151,9 +142,131 @@ func SetConfigInContext(ctx context.Context, s *Strategy, fallbackmode *bool, fe ctx = SetStrategyTimeout(ctx, dur) } - if logger != nil { - ctx = SetLogger(ctx, logger) + return ctx, nil +} + +func (d *decoder) run(ctx context.Context) { + // prefetch chunks according to strategy + var strategies []Strategy + if d.config.Strategy > NONE || !d.config.Strict { + strategies = []Strategy{d.config.Strategy} + } + if !d.config.Strict { + for i := d.config.Strategy + 1; i < strategyCnt; i++ { + strategies = append(strategies, i) + } } + d.prefetch(ctx, strategies) + d.cancel() + d.wg.Done() + d.remove() +} - return ctx, nil +func (g *decoder) prefetch(ctx context.Context, strategies []Strategy) { + // context to cancel when shardCnt chunks are retrieved + lctx, cancel := context.WithCancel(ctx) + defer cancel() + + run := func(s Strategy) (err error) { + if s == PROX { // NOT IMPLEMENTED + return nil + } + + var timeout <-chan time.Time + if s < RACE { + timer := time.NewTimer(g.config.StrategyTimeout) + defer timer.Stop() + timeout = timer.C + } + prefetch(lctx, g, s) + + select { + case <-g.fetched.c: + // successfully retrieved shardCnt number of chunks + g.fetched.cancel() + g.failed.cancel() + cancel() + // sdignal that decoding is finished + return g.recover(ctx) + + case <-timeout: // strategy timeout + return nil + + case <-ctx.Done(): // context cancelled + return ctx.Err() + } + } + for _, s := range strategies { + if err := run(s); err != nil { + break + } + } +} + +// prefetch launches the retrieval of chunks based on the strategy +func prefetch(ctx context.Context, g *decoder, s Strategy) { + switch s { + case NONE: + return + case DATA: + // only retrieve data shards + for i := range g.waits { + select { + case <-ctx.Done(): + return + default: + } + if g.fly(i) { + i := i + g.wg.Add(1) + go g.fetch(ctx, i) + } + } + case PROX: + // proximity driven selective fetching + // NOT IMPLEMENTED + case RACE: + // retrieve all chunks at once enabling race among chunks + // only retrieve data shards + for i := range g.addrs { + select { + case <-ctx.Done(): + return + default: + } + if i >= g.shardCnt || g.fly(i) { + i := i + g.wg.Add(1) + go g.fetch(ctx, i) + } + } + } +} + +// counter counts the number of successful or failed retrievals +// count counts the number of successful or failed retrievals +// if either successful retrievals reach shardCnt or failed ones reach parityCnt + 1, +// it signals by true or false respectively on the channel argument and terminates +type counter struct { + c chan struct{} + n atomic.Int32 + max int + off atomic.Bool +} + +func newCounter(max int) *counter { + return &counter{ + c: make(chan struct{}), + max: max, + } +} + +func (c *counter) inc() { + if !c.off.Load() && c.n.Add(1) == int32(c.max)+1 { + close(c.c) + } +} + +func (c *counter) cancel() { + c.off.Store(true) }
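The new counter type drives both the fail-fast signal wired up in New and the per-strategy prefetch loop: it closes its channel once max+1 increments have been observed, and cancel() disarms it so late increments cannot fire an already-irrelevant signal. A small self-contained sketch of that close-on-threshold behaviour; the names are illustrative, but the thresholds match how the decoder sizes its fetched and failed counters.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// threshold mirrors the patch's counter: inc() closes c on the (max+1)-th
// increment, cancel() disarms it.
type threshold struct {
	c   chan struct{}
	n   atomic.Int32
	max int
	off atomic.Bool
}

func newThreshold(max int) *threshold {
	return &threshold{c: make(chan struct{}), max: max}
}

func (t *threshold) inc() {
	if !t.off.Load() && t.n.Add(1) == int32(t.max)+1 {
		close(t.c)
	}
}

func (t *threshold) cancel() { t.off.Store(true) }

func main() {
	shardCnt, parityCnt := 4, 2

	// As in the decoder: "fetched" fires after shardCnt successes,
	// "failed" fires after parityCnt+1 failures.
	fetched := newThreshold(shardCnt - 1)
	failed := newThreshold(parityCnt)

	var wg sync.WaitGroup
	for i := 0; i < shardCnt+parityCnt; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			if i%2 == 0 { // pretend even-indexed retrievals fail
				failed.inc()
				return
			}
			fetched.inc()
		}(i)
	}
	wg.Wait()

	select {
	case <-failed.c:
		fmt.Println("too many failures: recovery cannot succeed")
	case <-fetched.c:
		fmt.Println("enough shards retrieved: start decoding")
	default:
		fmt.Println("neither threshold reached")
	}
}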
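With the logger argument removed, retrieval parameters now travel purely through the context: the API handlers call SetConfigInContext with header-derived values, and the joiner side recovers a Config with NewConfigFromContext, falling back to defaults for anything unset. A round-trip sketch of the two helpers; the concrete values are arbitrary.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/ethersphere/bee/pkg/file/redundancy/getter"
)

func main() {
	// Values as they would arrive from the API headers (all optional, hence pointers).
	strategy := getter.RACE
	fallback := true
	fetchTimeout := (500 * time.Millisecond).String()
	strategyTimeout := getter.DefaultStrategyTimeout.String()

	ctx, err := getter.SetConfigInContext(context.Background(), &strategy, &fallback, &fetchTimeout, &strategyTimeout)
	if err != nil {
		panic(err)
	}

	// Recover the full configuration, starting from the package defaults.
	conf, err := getter.NewConfigFromContext(ctx, getter.DefaultConfig)
	if err != nil {
		panic(err)
	}
	fmt.Printf("strategy=%d strict=%v fetch=%s strategyTimeout=%s\n",
		conf.Strategy, conf.Strict, conf.FetchTimeout, conf.StrategyTimeout)
}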