From 4950c31af80b769c28496d6ad52fd68d12246d52 Mon Sep 17 00:00:00 2001 From: istae <14264581+istae@users.noreply.github.com> Date: Tue, 13 Feb 2024 23:24:56 +0300 Subject: [PATCH] fix: strategy and fetch timeout parsing (#4579) --- pkg/api/bytes.go | 9 ++-- pkg/api/bzz.go | 38 +++++++++----- pkg/file/joiner/joiner_test.go | 22 ++++++++- pkg/file/redundancy/getter/strategies.go | 63 +++++++++++++++--------- pkg/retrieval/retrieval.go | 6 +-- 5 files changed, 92 insertions(+), 46 deletions(-) diff --git a/pkg/api/bytes.go b/pkg/api/bytes.go index 4cb90c1b763..988d73018ac 100644 --- a/pkg/api/bytes.go +++ b/pkg/api/bytes.go @@ -20,7 +20,6 @@ import ( "github.com/ethersphere/bee/pkg/tracing" "github.com/gorilla/mux" "github.com/opentracing/opentracing-go/ext" - "github.com/opentracing/opentracing-go/log" olog "github.com/opentracing/opentracing-go/log" ) @@ -63,7 +62,7 @@ func (s *Service) bytesUploadHandler(w http.ResponseWriter, r *http.Request) { default: jsonhttp.InternalServerError(w, "cannot get or create tag") } - ext.LogError(span, err, log.String("action", "tag.create")) + ext.LogError(span, err, olog.String("action", "tag.create")) return } span.SetTag("tagID", tag) @@ -90,7 +89,7 @@ func (s *Service) bytesUploadHandler(w http.ResponseWriter, r *http.Request) { default: jsonhttp.BadRequest(w, nil) } - ext.LogError(span, err, log.String("action", "new.StamperPutter")) + ext.LogError(span, err, olog.String("action", "new.StamperPutter")) return } @@ -111,7 +110,7 @@ func (s *Service) bytesUploadHandler(w http.ResponseWriter, r *http.Request) { default: jsonhttp.InternalServerError(ow, "split write all failed") } - ext.LogError(span, err, log.String("action", "split.WriteAll")) + ext.LogError(span, err, olog.String("action", "split.WriteAll")) return } @@ -122,7 +121,7 @@ func (s *Service) bytesUploadHandler(w http.ResponseWriter, r *http.Request) { logger.Debug("done split failed", "error", err) logger.Error(nil, "done split failed") jsonhttp.InternalServerError(ow, "done split failed") - ext.LogError(span, err, log.String("action", "putter.Done")) + ext.LogError(span, err, olog.String("action", "putter.Done")) return } diff --git a/pkg/api/bzz.go b/pkg/api/bzz.go index 2cf1d66d53d..91d8fac56fe 100644 --- a/pkg/api/bzz.go +++ b/pkg/api/bzz.go @@ -306,11 +306,10 @@ func (s *Service) serveReference(logger log.Logger, address swarm.Address, pathV loggerV1 := logger.V(1).Build() headers := struct { - Cache *bool `map:"Swarm-Cache"` - Strategy getter.Strategy `map:"Swarm-Redundancy-Strategy"` - FallbackMode bool `map:"Swarm-Redundancy-Fallback-Mode"` - ChunkRetrievalTimeout string `map:"Swarm-Chunk-Retrieval-Timeout"` - LookaheadBufferSize *string `map:"Swarm-Lookahead-Buffer-Size"` + Cache *bool `map:"Swarm-Cache"` + Strategy *getter.Strategy `map:"Swarm-Redundancy-Strategy"` + FallbackMode *bool `map:"Swarm-Redundancy-Fallback-Mode"` + ChunkRetrievalTimeout *string `map:"Swarm-Chunk-Retrieval-Timeout"` }{} if response := s.mapStructure(r.Header, &headers); response != nil { @@ -325,8 +324,15 @@ func (s *Service) serveReference(logger log.Logger, address swarm.Address, pathV ls := loadsave.NewReadonly(s.storer.Download(cache)) feedDereferenced := false + strategyTimeout := getter.DefaultStrategyTimeout.String() + ctx := r.Context() - ctx = getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, getter.DefaultStrategyTimeout.String()) + ctx, err := getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, &strategyTimeout) + if err != nil { + logger.Error(err, err.Error()) + jsonhttp.BadRequest(w, "could not parse headers") + return + } FETCH: // read manifest entry @@ -496,11 +502,11 @@ func (s *Service) serveManifestEntry( // downloadHandler contains common logic for dowloading Swarm file from API func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *http.Request, reference swarm.Address, additionalHeaders http.Header, etag bool) { headers := struct { - Cache *bool `map:"Swarm-Cache"` - Strategy getter.Strategy `map:"Swarm-Redundancy-Strategy"` - FallbackMode bool `map:"Swarm-Redundancy-Fallback-Mode"` - ChunkRetrievalTimeout string `map:"Swarm-Chunk-Retrieval-Timeout"` - LookaheadBufferSize *int `map:"Swarm-Lookahead-Buffer-Size"` + Strategy *getter.Strategy `map:"Swarm-Redundancy-Strategy"` + FallbackMode *bool `map:"Swarm-Redundancy-Fallback-Mode"` + ChunkRetrievalTimeout *string `map:"Swarm-Chunk-Retrieval-Timeout"` + LookaheadBufferSize *int `map:"Swarm-Lookahead-Buffer-Size"` + Cache *bool `map:"Swarm-Cache"` }{} if response := s.mapStructure(r.Header, &headers); response != nil { @@ -512,8 +518,16 @@ func (s *Service) downloadHandler(logger log.Logger, w http.ResponseWriter, r *h cache = *headers.Cache } + strategyTimeout := getter.DefaultStrategyTimeout.String() + ctx := r.Context() - ctx = getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, getter.DefaultStrategyTimeout.String()) + ctx, err := getter.SetConfigInContext(ctx, headers.Strategy, headers.FallbackMode, headers.ChunkRetrievalTimeout, &strategyTimeout) + if err != nil { + logger.Error(err, err.Error()) + jsonhttp.BadRequest(w, "could not parse headers") + return + } + reader, l, err := joiner.New(ctx, s.storer.Download(cache), s.storer.Cache(), reference) if err != nil { if errors.Is(err, storage.ErrNotFound) || errors.Is(err, topology.ErrNotFound) { diff --git a/pkg/file/joiner/joiner_test.go b/pkg/file/joiner/joiner_test.go index 15d46bf220b..125f2d727d5 100644 --- a/pkg/file/joiner/joiner_test.go +++ b/pkg/file/joiner/joiner_test.go @@ -1114,7 +1114,17 @@ func TestJoinerRedundancy(t *testing.T) { readCheck := func(t *testing.T, expErr error) { ctx, cancel := context.WithTimeout(context.Background(), 15*strategyTimeout) defer cancel() - ctx = getter.SetConfigInContext(ctx, getter.RACE, true, (10 * strategyTimeout).String(), strategyTimeout.String()) + + strategyTimeoutStr := strategyTimeout.String() + decodeTimeoutStr := (10 * strategyTimeout).String() + fallback := true + s := getter.RACE + + ctx, err := getter.SetConfigInContext(ctx, &s, &fallback, &decodeTimeoutStr, &strategyTimeoutStr) + if err != nil { + t.Fatal(err) + } + joinReader, rootSpan, err := joiner.New(ctx, store, store, swarmAddr) if err != nil { t.Fatal(err) @@ -1247,7 +1257,15 @@ func TestJoinerRedundancyMultilevel(t *testing.T) { if racedetection.IsOn() { decodingTimeout *= 2 } - ctx = getter.SetConfigInContext(ctx, s, fallback, (2 * strategyTimeout).String(), strategyTimeout.String()) + + strategyTimeoutStr := strategyTimeout.String() + decodingTimeoutStr := (2 * strategyTimeout).String() + + ctx, err := getter.SetConfigInContext(ctx, &s, &fallback, &decodingTimeoutStr, &strategyTimeoutStr) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithTimeout(ctx, time.Duration(levels)*(3*strategyTimeout+decodingTimeout)) defer cancel() j, _, err := joiner.New(ctx, store, store, addr) diff --git a/pkg/file/redundancy/getter/strategies.go b/pkg/file/redundancy/getter/strategies.go index bb5188e9cc2..9269636e3e9 100644 --- a/pkg/file/redundancy/getter/strategies.go +++ b/pkg/file/redundancy/getter/strategies.go @@ -9,13 +9,15 @@ import ( "errors" "fmt" "time" + + "github.com/ethersphere/bee/pkg/retrieval" ) const ( - DefaultStrategy = NONE // default prefetching strategy - DefaultStrict = true // default fallback modes - DefaultFetchTimeout = 30 * time.Second // timeout for each chunk retrieval - DefaultStrategyTimeout = 300 * time.Millisecond // timeout for each strategy + DefaultStrategy = NONE // default prefetching strategy + DefaultStrict = true // default fallback modes + DefaultFetchTimeout = retrieval.RetrieveChunkTimeout // timeout for each chunk retrieval + DefaultStrategyTimeout = 300 * time.Millisecond // timeout for each strategy ) type ( @@ -73,25 +75,18 @@ func NewConfigFromContext(ctx context.Context, def Config) (conf Config, err err } } if val := ctx.Value(fetchTimeoutKey{}); val != nil { - fetchTimeoutVal, ok := val.(string) + conf.FetchTimeout, ok = val.(time.Duration) if !ok { return conf, e("fetcher timeout") } - conf.FetchTimeout, err = time.ParseDuration(fetchTimeoutVal) - if err != nil { - return conf, e("fetcher timeout", err) - } } if val := ctx.Value(strategyTimeoutKey{}); val != nil { - strategyTimeoutVal, ok := val.(string) + conf.StrategyTimeout, ok = val.(time.Duration) if !ok { - return conf, e("fetcher timeout") - } - conf.StrategyTimeout, err = time.ParseDuration(strategyTimeoutVal) - if err != nil { - return conf, e("fetcher timeout", err) + return conf, e("strategy timeout") } } + return conf, nil } @@ -106,22 +101,42 @@ func SetStrict(ctx context.Context, strict bool) context.Context { } // SetFetchTimeout sets the timeout for each fetch -func SetFetchTimeout(ctx context.Context, timeout string) context.Context { +func SetFetchTimeout(ctx context.Context, timeout time.Duration) context.Context { return context.WithValue(ctx, fetchTimeoutKey{}, timeout) } // SetStrategyTimeout sets the timeout for each strategy -func SetStrategyTimeout(ctx context.Context, timeout string) context.Context { - return context.WithValue(ctx, fetchTimeoutKey{}, timeout) +func SetStrategyTimeout(ctx context.Context, timeout time.Duration) context.Context { + return context.WithValue(ctx, strategyTimeoutKey{}, timeout) } // SetConfigInContext sets the config params in the context -func SetConfigInContext(ctx context.Context, s Strategy, fallbackmode bool, fetchTimeout, strategyTimeout string) context.Context { - ctx = SetStrategy(ctx, s) - ctx = SetStrict(ctx, !fallbackmode) - ctx = SetFetchTimeout(ctx, fetchTimeout) - ctx = SetStrategyTimeout(ctx, strategyTimeout) - return ctx +func SetConfigInContext(ctx context.Context, s *Strategy, fallbackmode *bool, fetchTimeout, strategyTimeout *string) (context.Context, error) { + if s != nil { + ctx = SetStrategy(ctx, *s) + } + + if fallbackmode != nil { + ctx = SetStrict(ctx, !(*fallbackmode)) + } + + if fetchTimeout != nil { + dur, err := time.ParseDuration(*fetchTimeout) + if err != nil { + return nil, err + } + ctx = SetFetchTimeout(ctx, dur) + } + + if strategyTimeout != nil { + dur, err := time.ParseDuration(*strategyTimeout) + if err != nil { + return nil, err + } + ctx = SetStrategyTimeout(ctx, dur) + } + + return ctx, nil } func (g *decoder) prefetch(ctx context.Context) error { diff --git a/pkg/retrieval/retrieval.go b/pkg/retrieval/retrieval.go index cae78558df9..69a04355ee9 100644 --- a/pkg/retrieval/retrieval.go +++ b/pkg/retrieval/retrieval.go @@ -122,7 +122,7 @@ func (s *Service) Protocol() p2p.ProtocolSpec { } const ( - retrieveChunkTimeout = time.Second * 30 + RetrieveChunkTimeout = time.Second * 30 preemptiveInterval = time.Second overDraftRefresh = time.Millisecond * 600 skiplistDur = time.Minute @@ -320,7 +320,7 @@ func (s *Service) retrieveChunk(ctx context.Context, quit chan struct{}, chunkAd } }() - ctx, cancel := context.WithTimeout(ctx, retrieveChunkTimeout) + ctx, cancel := context.WithTimeout(ctx, RetrieveChunkTimeout) defer cancel() stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName) @@ -425,7 +425,7 @@ func (s *Service) closestPeer(addr swarm.Address, skipPeers []swarm.Address, all } func (s *Service) handler(p2pctx context.Context, p p2p.Peer, stream p2p.Stream) (err error) { - ctx, cancel := context.WithTimeout(p2pctx, retrieveChunkTimeout) + ctx, cancel := context.WithTimeout(p2pctx, RetrieveChunkTimeout) defer cancel() w, r := protobuf.NewWriterAndReader(stream)