From 672fc9575acd80177f89991a60a95ef2ceddb4f7 Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Wed, 12 Jun 2024 11:49:56 +0200 Subject: [PATCH] fix(p2p)!: use request timeout (#194) ## Overview We were missing request timeout for 1 handler. Fixing it + adding timeout tests. Also, renaming `RangeRequestTimeout` to `RequestTimeout` for consistency. --- p2p/exchange.go | 2 +- p2p/exchange_test.go | 4 +- p2p/options.go | 34 +++++------ p2p/options_test.go | 8 +-- p2p/server.go | 24 ++++---- p2p/server_test.go | 133 ++++++++++++++++++++++++++++++++++++++++++- 6 files changed, 167 insertions(+), 38 deletions(-) diff --git a/p2p/exchange.go b/p2p/exchange.go index df199e10..fbf1f23a 100644 --- a/p2p/exchange.go +++ b/p2p/exchange.go @@ -293,7 +293,7 @@ func (ex *Exchange[H]) GetRangeByHeight( )) defer span.End() session := newSession[H]( - ex.ctx, ex.host, ex.peerTracker, ex.protocolID, ex.Params.RangeRequestTimeout, ex.metrics, withValidation(from), + ex.ctx, ex.host, ex.peerTracker, ex.protocolID, ex.Params.RequestTimeout, ex.metrics, withValidation(from), ) defer session.close() // we request the next header height that we don't have: `fromHead`+1 diff --git a/p2p/exchange_test.go b/p2p/exchange_test.go index 38298e0e..1a775bd3 100644 --- a/p2p/exchange_test.go +++ b/p2p/exchange_test.go @@ -482,7 +482,7 @@ func TestExchange_RequestHeadersFromAnotherPeerWhenTimeout(t *testing.T) { // create client + server(it does not have needed headers) exchg, store := createP2PExAndServer(t, host0, host1) - exchg.Params.RangeRequestTimeout = time.Millisecond * 100 + exchg.Params.RequestTimeout = time.Millisecond * 100 // create one more server(with more headers in the store) serverSideEx, err := NewExchangeServer[*headertest.DummyHeader]( host2, headertest.NewStore[*headertest.DummyHeader](t, headertest.NewTestSuite(t), 10), @@ -490,7 +490,7 @@ func TestExchange_RequestHeadersFromAnotherPeerWhenTimeout(t *testing.T) { ) require.NoError(t, err) // change store implementation - serverSideEx.store = &timedOutStore{timeout: exchg.Params.RangeRequestTimeout} + serverSideEx.store = &timedOutStore{timeout: exchg.Params.RequestTimeout} require.NoError(t, serverSideEx.Start(context.Background())) t.Cleanup(func() { serverSideEx.Stop(context.Background()) //nolint:errcheck diff --git a/p2p/options.go b/p2p/options.go index b1b4b9cf..aa5ec505 100644 --- a/p2p/options.go +++ b/p2p/options.go @@ -22,9 +22,9 @@ type ServerParameters struct { WriteDeadline time.Duration // ReadDeadline sets the timeout for reading messages from the stream ReadDeadline time.Duration - // RangeRequestTimeout defines a timeout after which the session will try to re-request headers + // RequestTimeout defines a timeout after which the session will try to re-request headers // from another peer. - RangeRequestTimeout time.Duration + RequestTimeout time.Duration // networkID is a network that will be used to create a protocol.ID // Is empty by default networkID string @@ -35,9 +35,9 @@ type ServerParameters struct { // DefaultServerParameters returns the default params to configure the store. func DefaultServerParameters() ServerParameters { return ServerParameters{ - WriteDeadline: time.Second * 8, - ReadDeadline: time.Minute, - RangeRequestTimeout: time.Second * 10, + WriteDeadline: time.Second * 8, + ReadDeadline: time.Minute, + RequestTimeout: time.Second * 10, } } @@ -48,9 +48,9 @@ func (p *ServerParameters) Validate() error { if p.ReadDeadline == 0 { return fmt.Errorf("invalid read time duration: %v", p.ReadDeadline) } - if p.RangeRequestTimeout == 0 { + if p.RequestTimeout == 0 { return fmt.Errorf("invalid request timeout for session: "+ - "%s. %s: %v", greaterThenZero, providedSuffix, p.RangeRequestTimeout) + "%s. %s: %v", greaterThenZero, providedSuffix, p.RequestTimeout) } return nil } @@ -88,15 +88,15 @@ func WithReadDeadline[T ServerParameters](deadline time.Duration) Option[T] { } } -// WithRangeRequestTimeout is a functional option that configures the -// `RangeRequestTimeout` parameter. -func WithRangeRequestTimeout[T parameters](duration time.Duration) Option[T] { +// WithRequestTimeout is a functional option that configures the +// `RequestTimeout` parameter. +func WithRequestTimeout[T parameters](duration time.Duration) Option[T] { return func(p *T) { switch t := any(p).(type) { case *ClientParameters: - t.RangeRequestTimeout = duration + t.RequestTimeout = duration case *ServerParameters: - t.RangeRequestTimeout = duration + t.RequestTimeout = duration } } } @@ -125,9 +125,9 @@ func WithParams[T parameters](params T) Option[T] { type ClientParameters struct { // MaxHeadersPerRangeRequest defines the max amount of headers that can be requested per 1 request. MaxHeadersPerRangeRequest uint64 - // RangeRequestTimeout defines a timeout after which the session will try to re-request headers + // RequestTimeout defines a timeout after which the session will try to re-request headers // from another peer. - RangeRequestTimeout time.Duration + RequestTimeout time.Duration // networkID is a network that will be used to create a protocol.ID networkID string // chainID is an identifier of the chain. @@ -142,7 +142,7 @@ type ClientParameters struct { func DefaultClientParameters() ClientParameters { return ClientParameters{ MaxHeadersPerRangeRequest: 64, - RangeRequestTimeout: time.Second * 8, + RequestTimeout: time.Second * 8, } } @@ -156,9 +156,9 @@ func (p *ClientParameters) Validate() error { return fmt.Errorf("invalid MaxHeadersPerRangeRequest:%s. %s: %v", greaterThenZero, providedSuffix, p.MaxHeadersPerRangeRequest) } - if p.RangeRequestTimeout == 0 { + if p.RequestTimeout == 0 { return fmt.Errorf("invalid request timeout for session: "+ - "%s. %s: %v", greaterThenZero, providedSuffix, p.RangeRequestTimeout) + "%s. %s: %v", greaterThenZero, providedSuffix, p.RequestTimeout) } return nil } diff --git a/p2p/options_test.go b/p2p/options_test.go index a6d7861d..afb4e919 100644 --- a/p2p/options_test.go +++ b/p2p/options_test.go @@ -12,11 +12,11 @@ func TestOptionsClientWithParams(t *testing.T) { timeout := time.Second opt := WithParams(ClientParameters{ - RangeRequestTimeout: timeout, + RequestTimeout: timeout, }) opt(¶ms) - assert.Equal(t, timeout, params.RangeRequestTimeout) + assert.Equal(t, timeout, params.RequestTimeout) } func TestOptionsServerWithParams(t *testing.T) { @@ -24,9 +24,9 @@ func TestOptionsServerWithParams(t *testing.T) { timeout := time.Second opt := WithParams(ServerParameters{ - RangeRequestTimeout: timeout, + RequestTimeout: timeout, }) opt(¶ms) - assert.Equal(t, timeout, params.RangeRequestTimeout) + assert.Equal(t, timeout, params.RequestTimeout) } diff --git a/p2p/server.go b/p2p/server.go index 5e0194ad..ebd8111f 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -105,13 +105,16 @@ func (serv *ExchangeServer[H]) requestHandler(stream network.Stream) { log.Error(err) } + ctx, cancel := context.WithTimeout(serv.ctx, serv.Params.RequestTimeout) + defer cancel() + var headers []H // retrieve and write Headers switch pbreq.Data.(type) { case *p2p_pb.HeaderRequest_Hash: - headers, err = serv.handleRequestByHash(pbreq.GetHash()) + headers, err = serv.handleRequestByHash(ctx, pbreq.GetHash()) case *p2p_pb.HeaderRequest_Origin: - headers, err = serv.handleRequest(pbreq.GetOrigin(), pbreq.GetOrigin()+pbreq.Amount) + headers, err = serv.handleRangeRequest(ctx, pbreq.GetOrigin(), pbreq.GetOrigin()+pbreq.Amount) default: log.Warn("server: invalid data type received") stream.Reset() //nolint:errcheck @@ -166,11 +169,9 @@ func (serv *ExchangeServer[H]) requestHandler(stream network.Stream) { // handleRequestByHash returns the Header at the given hash // if it exists. -func (serv *ExchangeServer[H]) handleRequestByHash(hash []byte) ([]H, error) { +func (serv *ExchangeServer[H]) handleRequestByHash(ctx context.Context, hash []byte) ([]H, error) { startTime := time.Now() log.Debugw("server: handling header request", "hash", header.Hash(hash).String()) - ctx, cancel := context.WithTimeout(serv.ctx, serv.Params.RangeRequestTimeout) - defer cancel() ctx, span := tracerServ.Start(ctx, "request-by-hash", trace.WithAttributes( attribute.String("hash", header.Hash(hash).String()), )) @@ -194,15 +195,16 @@ func (serv *ExchangeServer[H]) handleRequestByHash(hash []byte) ([]H, error) { return []H{h}, nil } -// handleRequest fetches the Header at the given origin and +// handleRangeRequest fetches the Header at the given origin and // writes it to the stream. -func (serv *ExchangeServer[H]) handleRequest(from, to uint64) ([]H, error) { +func (serv *ExchangeServer[H]) handleRangeRequest(ctx context.Context, from, to uint64) ([]H, error) { if from == uint64(0) { - return serv.handleHeadRequest() + return serv.handleHeadRequest(ctx) } startTime := time.Now() - ctx, span := tracerServ.Start(serv.ctx, "request-range", trace.WithAttributes( + log.Debugw("server: handling range request", "from", from, "to", to) + ctx, span := tracerServ.Start(ctx, "request-range", trace.WithAttributes( attribute.Int64("from", int64(from)), attribute.Int64("to", int64(to)))) defer span.End() @@ -266,11 +268,9 @@ func (serv *ExchangeServer[H]) handleRequest(from, to uint64) ([]H, error) { } // handleHeadRequest returns the latest stored head. -func (serv *ExchangeServer[H]) handleHeadRequest() ([]H, error) { +func (serv *ExchangeServer[H]) handleHeadRequest(ctx context.Context) ([]H, error) { startTime := time.Now() log.Debug("server: handling head request") - ctx, cancel := context.WithTimeout(serv.ctx, serv.Params.RangeRequestTimeout) - defer cancel() ctx, span := tracerServ.Start(ctx, "request-head") defer span.End() diff --git a/p2p/server_test.go b/p2p/server_test.go index b8e558a8..1e896b2e 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -3,6 +3,7 @@ package p2p import ( "context" "testing" + "time" "github.com/ipfs/go-datastore" "github.com/stretchr/testify/require" @@ -28,7 +29,7 @@ func TestExchangeServer_handleRequestTimeout(t *testing.T) { server.Stop(context.Background()) //nolint:errcheck }) - _, err = server.handleRequest(1, 200) + _, err = server.handleRangeRequest(context.Background(), 1, 200) require.Error(t, err) } @@ -48,6 +49,134 @@ func TestExchangeServer_errorsOnLargeRequest(t *testing.T) { server.Stop(context.Background()) //nolint:errcheck }) - _, err = server.handleRequest(1, header.MaxRangeRequestSize*2) + _, err = server.handleRangeRequest(context.Background(), 1, header.MaxRangeRequestSize*2) require.Error(t, err) } + +func TestExchangeServer_Timeout(t *testing.T) { + const testRequestTimeout = 150 * time.Millisecond + + peer := createMocknet(t, 1) + + server, err := NewExchangeServer( + peer[0], + timeoutStore[*headertest.DummyHeader]{}, + WithNetworkID[ServerParameters](networkID), + WithRequestTimeout[ServerParameters](testRequestTimeout), + ) + require.NoError(t, err) + + err = server.Start(context.Background()) + require.NoError(t, err) + + t.Cleanup(func() { + _ = server.Stop(context.Background()) + }) + + testCases := []struct { + name string + fn func() error + }{ + { + name: "handleHeadRequest", + fn: func() error { + ctx, cancel := context.WithTimeout(context.Background(), testRequestTimeout) + defer cancel() + + _, err := server.handleHeadRequest(ctx) + return err + }, + }, + { + name: "handleRequest", + fn: func() error { + ctx, cancel := context.WithTimeout(context.Background(), testRequestTimeout) + defer cancel() + + _, err := server.handleRangeRequest(ctx, 1, 100) + return err + }, + }, + { + name: "handleHeadRequest", + fn: func() error { + ctx, cancel := context.WithTimeout(context.Background(), testRequestTimeout) + defer cancel() + + hash := headertest.RandDummyHeader(t).Hash() + _, err := server.handleRequestByHash(ctx, hash) + return err + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + start := time.Now() + err := tc.fn() + took := time.Since(start) + + require.Error(t, err) + require.GreaterOrEqual(t, took, testRequestTimeout) + }) + } +} + +var _ header.Store[*headertest.DummyHeader] = timeoutStore[*headertest.DummyHeader]{} + +// timeoutStore does nothing but waits till context cancellation for every method. +type timeoutStore[H header.Header[H]] struct{} + +func (timeoutStore[H]) Head(ctx context.Context, _ ...header.HeadOption[H]) (H, error) { + <-ctx.Done() + var zero H + return zero, ctx.Err() +} + +func (timeoutStore[H]) Get(ctx context.Context, _ header.Hash) (H, error) { + <-ctx.Done() + var zero H + return zero, ctx.Err() +} + +func (timeoutStore[H]) GetByHeight(ctx context.Context, _ uint64) (H, error) { + <-ctx.Done() + var zero H + return zero, ctx.Err() +} + +func (timeoutStore[H]) GetRangeByHeight(ctx context.Context, from H, to uint64) ([]H, error) { + <-ctx.Done() + return nil, ctx.Err() +} + +func (timeoutStore[H]) Init(ctx context.Context, _ H) error { + <-ctx.Done() + return ctx.Err() +} + +func (timeoutStore[H]) Height() uint64 { + return 0 +} + +func (timeoutStore[H]) Has(ctx context.Context, _ header.Hash) (bool, error) { + <-ctx.Done() + return false, ctx.Err() +} + +func (timeoutStore[H]) HasAt(ctx context.Context, _ uint64) bool { + <-ctx.Done() + return false +} + +func (timeoutStore[H]) Append(ctx context.Context, _ ...H) error { + <-ctx.Done() + return ctx.Err() +} + +func (timeoutStore[H]) GetRange(ctx context.Context, _ uint64, _ uint64) ([]H, error) { + <-ctx.Done() + return nil, ctx.Err() +}