Skip to content

Commit

Permalink
fix(p2p)!: use request timeout (#194)
Browse files Browse the repository at this point in the history
## Overview

We were missing request timeout for 1 handler. Fixing it + adding
timeout tests.

Also, renaming `RangeRequestTimeout` to `RequestTimeout` for
consistency.
  • Loading branch information
cristaloleg authored Jun 12, 2024
1 parent b81f0d7 commit 672fc95
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 38 deletions.
2 changes: 1 addition & 1 deletion p2p/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func (ex *Exchange[H]) GetRangeByHeight(
))
defer span.End()
session := newSession[H](
ex.ctx, ex.host, ex.peerTracker, ex.protocolID, ex.Params.RangeRequestTimeout, ex.metrics, withValidation(from),
ex.ctx, ex.host, ex.peerTracker, ex.protocolID, ex.Params.RequestTimeout, ex.metrics, withValidation(from),
)
defer session.close()
// we request the next header height that we don't have: `fromHead`+1
Expand Down
4 changes: 2 additions & 2 deletions p2p/exchange_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,15 +482,15 @@ func TestExchange_RequestHeadersFromAnotherPeerWhenTimeout(t *testing.T) {

// create client + server(it does not have needed headers)
exchg, store := createP2PExAndServer(t, host0, host1)
exchg.Params.RangeRequestTimeout = time.Millisecond * 100
exchg.Params.RequestTimeout = time.Millisecond * 100
// create one more server(with more headers in the store)
serverSideEx, err := NewExchangeServer[*headertest.DummyHeader](
host2, headertest.NewStore[*headertest.DummyHeader](t, headertest.NewTestSuite(t), 10),
WithNetworkID[ServerParameters](networkID),
)
require.NoError(t, err)
// change store implementation
serverSideEx.store = &timedOutStore{timeout: exchg.Params.RangeRequestTimeout}
serverSideEx.store = &timedOutStore{timeout: exchg.Params.RequestTimeout}
require.NoError(t, serverSideEx.Start(context.Background()))
t.Cleanup(func() {
serverSideEx.Stop(context.Background()) //nolint:errcheck
Expand Down
34 changes: 17 additions & 17 deletions p2p/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ type ServerParameters struct {
WriteDeadline time.Duration
// ReadDeadline sets the timeout for reading messages from the stream
ReadDeadline time.Duration
// RangeRequestTimeout defines a timeout after which the session will try to re-request headers
// RequestTimeout defines a timeout after which the session will try to re-request headers
// from another peer.
RangeRequestTimeout time.Duration
RequestTimeout time.Duration
// networkID is a network that will be used to create a protocol.ID
// Is empty by default
networkID string
Expand All @@ -35,9 +35,9 @@ type ServerParameters struct {
// DefaultServerParameters returns the default params to configure the store.
func DefaultServerParameters() ServerParameters {
return ServerParameters{
WriteDeadline: time.Second * 8,
ReadDeadline: time.Minute,
RangeRequestTimeout: time.Second * 10,
WriteDeadline: time.Second * 8,
ReadDeadline: time.Minute,
RequestTimeout: time.Second * 10,
}
}

Expand All @@ -48,9 +48,9 @@ func (p *ServerParameters) Validate() error {
if p.ReadDeadline == 0 {
return fmt.Errorf("invalid read time duration: %v", p.ReadDeadline)
}
if p.RangeRequestTimeout == 0 {
if p.RequestTimeout == 0 {
return fmt.Errorf("invalid request timeout for session: "+
"%s. %s: %v", greaterThenZero, providedSuffix, p.RangeRequestTimeout)
"%s. %s: %v", greaterThenZero, providedSuffix, p.RequestTimeout)
}
return nil
}
Expand Down Expand Up @@ -88,15 +88,15 @@ func WithReadDeadline[T ServerParameters](deadline time.Duration) Option[T] {
}
}

// WithRangeRequestTimeout is a functional option that configures the
// `RangeRequestTimeout` parameter.
func WithRangeRequestTimeout[T parameters](duration time.Duration) Option[T] {
// WithRequestTimeout is a functional option that configures the
// `RequestTimeout` parameter.
func WithRequestTimeout[T parameters](duration time.Duration) Option[T] {
return func(p *T) {
switch t := any(p).(type) {
case *ClientParameters:
t.RangeRequestTimeout = duration
t.RequestTimeout = duration
case *ServerParameters:
t.RangeRequestTimeout = duration
t.RequestTimeout = duration
}
}
}
Expand Down Expand Up @@ -125,9 +125,9 @@ func WithParams[T parameters](params T) Option[T] {
type ClientParameters struct {
// MaxHeadersPerRangeRequest defines the max amount of headers that can be requested per 1 request.
MaxHeadersPerRangeRequest uint64
// RangeRequestTimeout defines a timeout after which the session will try to re-request headers
// RequestTimeout defines a timeout after which the session will try to re-request headers
// from another peer.
RangeRequestTimeout time.Duration
RequestTimeout time.Duration
// networkID is a network that will be used to create a protocol.ID
networkID string
// chainID is an identifier of the chain.
Expand All @@ -142,7 +142,7 @@ type ClientParameters struct {
func DefaultClientParameters() ClientParameters {
return ClientParameters{
MaxHeadersPerRangeRequest: 64,
RangeRequestTimeout: time.Second * 8,
RequestTimeout: time.Second * 8,
}
}

Expand All @@ -156,9 +156,9 @@ func (p *ClientParameters) Validate() error {
return fmt.Errorf("invalid MaxHeadersPerRangeRequest:%s. %s: %v",
greaterThenZero, providedSuffix, p.MaxHeadersPerRangeRequest)
}
if p.RangeRequestTimeout == 0 {
if p.RequestTimeout == 0 {
return fmt.Errorf("invalid request timeout for session: "+
"%s. %s: %v", greaterThenZero, providedSuffix, p.RangeRequestTimeout)
"%s. %s: %v", greaterThenZero, providedSuffix, p.RequestTimeout)
}
return nil
}
Expand Down
8 changes: 4 additions & 4 deletions p2p/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@ func TestOptionsClientWithParams(t *testing.T) {

timeout := time.Second
opt := WithParams(ClientParameters{
RangeRequestTimeout: timeout,
RequestTimeout: timeout,
})

opt(&params)
assert.Equal(t, timeout, params.RangeRequestTimeout)
assert.Equal(t, timeout, params.RequestTimeout)
}

func TestOptionsServerWithParams(t *testing.T) {
params := DefaultServerParameters()

timeout := time.Second
opt := WithParams(ServerParameters{
RangeRequestTimeout: timeout,
RequestTimeout: timeout,
})

opt(&params)
assert.Equal(t, timeout, params.RangeRequestTimeout)
assert.Equal(t, timeout, params.RequestTimeout)
}
24 changes: 12 additions & 12 deletions p2p/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,16 @@ func (serv *ExchangeServer[H]) requestHandler(stream network.Stream) {
log.Error(err)
}

ctx, cancel := context.WithTimeout(serv.ctx, serv.Params.RequestTimeout)
defer cancel()

var headers []H
// retrieve and write Headers
switch pbreq.Data.(type) {
case *p2p_pb.HeaderRequest_Hash:
headers, err = serv.handleRequestByHash(pbreq.GetHash())
headers, err = serv.handleRequestByHash(ctx, pbreq.GetHash())
case *p2p_pb.HeaderRequest_Origin:
headers, err = serv.handleRequest(pbreq.GetOrigin(), pbreq.GetOrigin()+pbreq.Amount)
headers, err = serv.handleRangeRequest(ctx, pbreq.GetOrigin(), pbreq.GetOrigin()+pbreq.Amount)
default:
log.Warn("server: invalid data type received")
stream.Reset() //nolint:errcheck
Expand Down Expand Up @@ -166,11 +169,9 @@ func (serv *ExchangeServer[H]) requestHandler(stream network.Stream) {

// handleRequestByHash returns the Header at the given hash
// if it exists.
func (serv *ExchangeServer[H]) handleRequestByHash(hash []byte) ([]H, error) {
func (serv *ExchangeServer[H]) handleRequestByHash(ctx context.Context, hash []byte) ([]H, error) {
startTime := time.Now()
log.Debugw("server: handling header request", "hash", header.Hash(hash).String())
ctx, cancel := context.WithTimeout(serv.ctx, serv.Params.RangeRequestTimeout)
defer cancel()
ctx, span := tracerServ.Start(ctx, "request-by-hash", trace.WithAttributes(
attribute.String("hash", header.Hash(hash).String()),
))
Expand All @@ -194,15 +195,16 @@ func (serv *ExchangeServer[H]) handleRequestByHash(hash []byte) ([]H, error) {
return []H{h}, nil
}

// handleRequest fetches the Header at the given origin and
// handleRangeRequest fetches the Header at the given origin and
// writes it to the stream.
func (serv *ExchangeServer[H]) handleRequest(from, to uint64) ([]H, error) {
func (serv *ExchangeServer[H]) handleRangeRequest(ctx context.Context, from, to uint64) ([]H, error) {
if from == uint64(0) {
return serv.handleHeadRequest()
return serv.handleHeadRequest(ctx)
}

startTime := time.Now()
ctx, span := tracerServ.Start(serv.ctx, "request-range", trace.WithAttributes(
log.Debugw("server: handling range request", "from", from, "to", to)
ctx, span := tracerServ.Start(ctx, "request-range", trace.WithAttributes(
attribute.Int64("from", int64(from)),
attribute.Int64("to", int64(to))))
defer span.End()
Expand Down Expand Up @@ -266,11 +268,9 @@ func (serv *ExchangeServer[H]) handleRequest(from, to uint64) ([]H, error) {
}

// handleHeadRequest returns the latest stored head.
func (serv *ExchangeServer[H]) handleHeadRequest() ([]H, error) {
func (serv *ExchangeServer[H]) handleHeadRequest(ctx context.Context) ([]H, error) {
startTime := time.Now()
log.Debug("server: handling head request")
ctx, cancel := context.WithTimeout(serv.ctx, serv.Params.RangeRequestTimeout)
defer cancel()
ctx, span := tracerServ.Start(ctx, "request-head")
defer span.End()

Expand Down
133 changes: 131 additions & 2 deletions p2p/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package p2p
import (
"context"
"testing"
"time"

"github.com/ipfs/go-datastore"
"github.com/stretchr/testify/require"
Expand All @@ -28,7 +29,7 @@ func TestExchangeServer_handleRequestTimeout(t *testing.T) {
server.Stop(context.Background()) //nolint:errcheck
})

_, err = server.handleRequest(1, 200)
_, err = server.handleRangeRequest(context.Background(), 1, 200)
require.Error(t, err)
}

Expand All @@ -48,6 +49,134 @@ func TestExchangeServer_errorsOnLargeRequest(t *testing.T) {
server.Stop(context.Background()) //nolint:errcheck
})

_, err = server.handleRequest(1, header.MaxRangeRequestSize*2)
_, err = server.handleRangeRequest(context.Background(), 1, header.MaxRangeRequestSize*2)
require.Error(t, err)
}

func TestExchangeServer_Timeout(t *testing.T) {
const testRequestTimeout = 150 * time.Millisecond

peer := createMocknet(t, 1)

server, err := NewExchangeServer(
peer[0],
timeoutStore[*headertest.DummyHeader]{},
WithNetworkID[ServerParameters](networkID),
WithRequestTimeout[ServerParameters](testRequestTimeout),
)
require.NoError(t, err)

err = server.Start(context.Background())
require.NoError(t, err)

t.Cleanup(func() {
_ = server.Stop(context.Background())
})

testCases := []struct {
name string
fn func() error
}{
{
name: "handleHeadRequest",
fn: func() error {
ctx, cancel := context.WithTimeout(context.Background(), testRequestTimeout)
defer cancel()

_, err := server.handleHeadRequest(ctx)
return err
},
},
{
name: "handleRequest",
fn: func() error {
ctx, cancel := context.WithTimeout(context.Background(), testRequestTimeout)
defer cancel()

_, err := server.handleRangeRequest(ctx, 1, 100)
return err
},
},
{
name: "handleHeadRequest",
fn: func() error {
ctx, cancel := context.WithTimeout(context.Background(), testRequestTimeout)
defer cancel()

hash := headertest.RandDummyHeader(t).Hash()
_, err := server.handleRequestByHash(ctx, hash)
return err
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

start := time.Now()
err := tc.fn()
took := time.Since(start)

require.Error(t, err)
require.GreaterOrEqual(t, took, testRequestTimeout)
})
}
}

var _ header.Store[*headertest.DummyHeader] = timeoutStore[*headertest.DummyHeader]{}

// timeoutStore does nothing but waits till context cancellation for every method.
type timeoutStore[H header.Header[H]] struct{}

func (timeoutStore[H]) Head(ctx context.Context, _ ...header.HeadOption[H]) (H, error) {
<-ctx.Done()
var zero H
return zero, ctx.Err()
}

func (timeoutStore[H]) Get(ctx context.Context, _ header.Hash) (H, error) {
<-ctx.Done()
var zero H
return zero, ctx.Err()
}

func (timeoutStore[H]) GetByHeight(ctx context.Context, _ uint64) (H, error) {
<-ctx.Done()
var zero H
return zero, ctx.Err()
}

func (timeoutStore[H]) GetRangeByHeight(ctx context.Context, from H, to uint64) ([]H, error) {
<-ctx.Done()
return nil, ctx.Err()
}

func (timeoutStore[H]) Init(ctx context.Context, _ H) error {
<-ctx.Done()
return ctx.Err()
}

func (timeoutStore[H]) Height() uint64 {
return 0
}

func (timeoutStore[H]) Has(ctx context.Context, _ header.Hash) (bool, error) {
<-ctx.Done()
return false, ctx.Err()
}

func (timeoutStore[H]) HasAt(ctx context.Context, _ uint64) bool {
<-ctx.Done()
return false
}

func (timeoutStore[H]) Append(ctx context.Context, _ ...H) error {
<-ctx.Done()
return ctx.Err()
}

func (timeoutStore[H]) GetRange(ctx context.Context, _ uint64, _ uint64) ([]H, error) {
<-ctx.Done()
return nil, ctx.Err()
}

0 comments on commit 672fc95

Please sign in to comment.