Skip to content

Commit

Permalink
Guard USDC API hanging and rate limits (#270)
Browse files Browse the repository at this point in the history
  • Loading branch information
matYang authored Nov 8, 2023
1 parent e32c2e8 commit 0f3634f
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 42 deletions.
4 changes: 4 additions & 0 deletions core/services/ocr2/plugins/ccip/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@ type USDCConfig struct {
SourceTokenAddress common.Address
SourceMessageTransmitterAddress common.Address
AttestationAPI string
AttestationAPITimeoutSeconds int
}

func (uc *USDCConfig) ValidateUSDCConfig() error {
if uc.AttestationAPI == "" {
return errors.New("AttestationAPI is required")
}
if uc.AttestationAPITimeoutSeconds < 0 {
return errors.New("AttestationAPITimeoutSeconds must be non-negative")
}
if uc.SourceTokenAddress == utils.ZeroAddress {
return errors.New("SourceTokenAddress is required")
}
Expand Down
1 change: 1 addition & 0 deletions core/services/ocr2/plugins/ccip/execution_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ func getTokenDataProviders(lggr logger.Logger, pluginConfig ccipconfig.Execution
lggr,
usdcReader,
attestationURI,
pluginConfig.USDCConfig.AttestationAPITimeoutSeconds,
),
)
}
Expand Down
29 changes: 16 additions & 13 deletions core/services/ocr2/plugins/ccip/execution_reporting_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ func (r *ExecutionReportingPlugin) buildBatch(
availableGas := uint64(r.offchainConfig.BatchGasLimit)
expectedNonces := make(map[common.Address]uint64)
availableDataLen := MaxDataLenPerBatch
skipTokenWithData := false

for _, msg := range report.sendRequestsWithMeta {
msgLggr := lggr.With("messageID", hexutil.Encode(msg.MessageId[:]))
Expand Down Expand Up @@ -525,13 +526,15 @@ func (r *ExecutionReportingPlugin) buildBatch(
continue
}

tokenData, ready, err2 := getTokenData(ctx, msgLggr, msg, r.config.tokenDataProviders)
tokenData, err2 := getTokenData(ctx, msgLggr, msg, r.config.tokenDataProviders, skipTokenWithData)
if err2 != nil {
msgLggr.Errorw("Skipping message unable to check token data", "err", err2)
continue
}
if !ready {
msgLggr.Warnw("Skipping message attestation not ready")
// When fetching token data, 3rd party API could hang or rate limit or fail due to any reason.
// If this happens, we skip all remaining msgs that require token data in this batch.
// If the issue is transient, then it is likely for other nodes in the DON to succeed and execute the msg anyway.
// If the issue is API outage or rate limit, then we should indeed avoid calling the API.
// If API issues do not resolve, eventually the root will only contain msg that should be skipped, and be snoozed.
skipTokenWithData = true
msgLggr.Errorw("Skipping message unable to get token data", "err", err2)
continue
}

Expand Down Expand Up @@ -621,28 +624,28 @@ func (r *ExecutionReportingPlugin) buildBatch(
return executableMessages
}

func getTokenData(ctx context.Context, lggr logger.Logger, msg internal.EVM2EVMOnRampCCIPSendRequestedWithMeta, tokenDataProviders map[common.Address]tokendata.Reader) (tokenData [][]byte, allReady bool, err error) {
func getTokenData(ctx context.Context, lggr logger.Logger, msg internal.EVM2EVMOnRampCCIPSendRequestedWithMeta, tokenDataProviders map[common.Address]tokendata.Reader, skipTokenWithData bool) (tokenData [][]byte, err error) {
for _, token := range msg.TokenAmounts {
offchainTokenDataProvider, ok := tokenDataProviders[token.Token]
if !ok {
// No token data required
tokenData = append(tokenData, []byte{})
continue
}
if skipTokenWithData {
// If token data is required but should be skipped, exit without calling the API
return [][]byte{}, errors.New("token requiring data is flagged to be skipped")
}
lggr.Infow("Fetching token data", "token", token.Token.Hex())
tknData, err2 := offchainTokenDataProvider.ReadTokenData(ctx, msg)
if err2 != nil {
if errors.Is(err2, tokendata.ErrNotReady) {
lggr.Infow("Token data not ready yet", "token", token.Token.Hex())
return [][]byte{}, false, nil
}
return [][]byte{}, false, err2
return [][]byte{}, err2
}

lggr.Infow("Token data retrieved", "token", token.Token.Hex())
tokenData = append(tokenData, tknData)
}
return tokenData, true, nil
return tokenData, nil
}

func (r *ExecutionReportingPlugin) isRateLimitEnoughForTokenPool(
Expand Down
4 changes: 3 additions & 1 deletion core/services/ocr2/plugins/ccip/tokendata/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (
)

var (
ErrNotReady = errors.New("token data not ready")
ErrNotReady = errors.New("token data not ready")
ErrRateLimit = errors.New("token data API is being rate limited")
ErrTimeout = errors.New("token data API timed out")
)

// Reader is an interface for fetching offchain token data
Expand Down
43 changes: 33 additions & 10 deletions core/services/ocr2/plugins/ccip/tokendata/usdc/usdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"strings"
"time"

"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/pkg/errors"
Expand All @@ -23,8 +24,9 @@ import (
)

const (
apiVersion = "v1"
attestationPath = "attestations"
apiVersion = "v1"
attestationPath = "attestations"
defaultAttestationTimeout = 5 * time.Second
)

type attestationStatus string
Expand Down Expand Up @@ -63,9 +65,10 @@ func (m messageAndAttestation) Validate() error {
}

type TokenDataReader struct {
lggr logger.Logger
usdcReader ccipdata.USDCReader
attestationApi *url.URL
lggr logger.Logger
usdcReader ccipdata.USDCReader
attestationApi *url.URL
attestationApiTimeout time.Duration
}

type attestationResponse struct {
Expand All @@ -75,11 +78,17 @@ type attestationResponse struct {

var _ tokendata.Reader = &TokenDataReader{}

func NewUSDCTokenDataReader(lggr logger.Logger, usdcReader ccipdata.USDCReader, usdcAttestationApi *url.URL) *TokenDataReader {
func NewUSDCTokenDataReader(lggr logger.Logger, usdcReader ccipdata.USDCReader, usdcAttestationApi *url.URL, usdcAttestationApiTimeoutSeconds int) *TokenDataReader {
timeout := time.Duration(usdcAttestationApiTimeoutSeconds) * time.Second
if usdcAttestationApiTimeoutSeconds == 0 {
timeout = defaultAttestationTimeout
}

return &TokenDataReader{
lggr: lggr,
usdcReader: usdcReader,
attestationApi: usdcAttestationApi,
lggr: lggr,
usdcReader: usdcReader,
attestationApi: usdcAttestationApi,
attestationApiTimeout: timeout,
}
}

Expand Down Expand Up @@ -135,16 +144,30 @@ func (s *TokenDataReader) getUSDCMessageBody(ctx context.Context, msg internal.E

func (s *TokenDataReader) callAttestationApi(ctx context.Context, usdcMessageHash [32]byte) (attestationResponse, error) {
fullAttestationUrl := fmt.Sprintf("%s/%s/%s/0x%x", s.attestationApi, apiVersion, attestationPath, usdcMessageHash)
req, err := http.NewRequestWithContext(ctx, "GET", fullAttestationUrl, nil)

// Use a timeout to guard against attestation API hanging, causing observation timeout and failing to make any progress.
timeoutCtx, cancel := context.WithTimeout(ctx, s.attestationApiTimeout)
defer cancel()
req, err := http.NewRequestWithContext(timeoutCtx, "GET", fullAttestationUrl, nil)

if err != nil {
return attestationResponse{}, err
}
req.Header.Add("accept", "application/json")
res, err := http.DefaultClient.Do(req)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return attestationResponse{}, tokendata.ErrTimeout
}
return attestationResponse{}, err
}
defer res.Body.Close()

// Explicitly signal if the API is being rate limited
if res.StatusCode == http.StatusTooManyRequests {
return attestationResponse{}, tokendata.ErrRateLimit
}

body, err := io.ReadAll(res.Body)
if err != nil {
return attestationResponse{}, err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestUSDCReader_ReadTokenData(t *testing.T) {
attestationURI, err := url.ParseRequestURI(ts.URL)
require.NoError(t, err)

usdcService := usdc.NewUSDCTokenDataReader(lggr, &usdcReader, attestationURI)
usdcService := usdc.NewUSDCTokenDataReader(lggr, &usdcReader, attestationURI, 0)
msgAndAttestation, err := usdcService.ReadTokenData(context.Background(), internal.EVM2EVMOnRampCCIPSendRequestedWithMeta{
EVM2EVMMessage: internal.EVM2EVMMessage{
SequenceNumber: seqNum,
Expand Down
109 changes: 92 additions & 17 deletions core/services/ocr2/plugins/ccip/tokendata/usdc/usdc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"net/http/httptest"
"net/url"
"testing"
"time"

"github.com/ethereum/go-ethereum/common"
"github.com/pkg/errors"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

Expand All @@ -17,6 +19,7 @@ import (
"github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal"
"github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipdata"
ccipdatamocks "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipdata/mocks"
"github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/tokendata"
"github.com/smartcontractkit/chainlink/v2/core/utils"
)

Expand All @@ -32,7 +35,7 @@ func TestUSDCReader_callAttestationApi(t *testing.T) {
lggr := logger.TestLogger(t)
usdcReader, err := ccipdata.NewUSDCReader(lggr, mockMsgTransmitter, nil)
require.NoError(t, err)
usdcService := NewUSDCTokenDataReader(lggr, usdcReader, attestationURI)
usdcService := NewUSDCTokenDataReader(lggr, usdcReader, attestationURI, 0)

attestation, err := usdcService.callAttestationApi(context.Background(), [32]byte(common.FromHex(usdcMessageHash)))
require.NoError(t, err)
Expand All @@ -57,7 +60,7 @@ func TestUSDCReader_callAttestationApiMock(t *testing.T) {
lp.On("RegisterFilter", mock.Anything).Return(nil)
usdcReader, err := ccipdata.NewUSDCReader(lggr, mockMsgTransmitter, lp)
require.NoError(t, err)
usdcService := NewUSDCTokenDataReader(lggr, usdcReader, attestationURI)
usdcService := NewUSDCTokenDataReader(lggr, usdcReader, attestationURI, 0)
attestation, err := usdcService.callAttestationApi(context.Background(), utils.RandomBytes32())
require.NoError(t, err)

Expand All @@ -66,21 +69,93 @@ func TestUSDCReader_callAttestationApiMock(t *testing.T) {
}

func TestUSDCReader_callAttestationApiMockError(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer ts.Close()
attestationURI, err := url.ParseRequestURI(ts.URL)
require.NoError(t, err)
t.Parallel()

tests := []struct {
name string
getTs func() *httptest.Server
customTimeoutSeconds int
expectedError error
}{
{
name: "server error",
getTs: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
},
expectedError: nil,
},
{
name: "default timeout",
getTs: func() *httptest.Server {
response := attestationResponse{
Status: attestationStatusSuccess,
Attestation: "720502893578a89a8a87982982ef781c18b193",
}
responseBytes, _ := json.Marshal(response)

return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(defaultAttestationTimeout + time.Second)
_, err := w.Write(responseBytes)
require.NoError(t, err)
}))

},
expectedError: tokendata.ErrTimeout,
},
{
name: "custom timeout",
getTs: func() *httptest.Server {
response := attestationResponse{
Status: attestationStatusSuccess,
Attestation: "720502893578a89a8a87982982ef781c18b193",
}
responseBytes, _ := json.Marshal(response)

return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(2*time.Second + time.Second)
_, err := w.Write(responseBytes)
require.NoError(t, err)
}))

},
customTimeoutSeconds: 2,
expectedError: tokendata.ErrTimeout,
},
{
name: "rate limit",
getTs: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTooManyRequests)
}))
},
expectedError: tokendata.ErrRateLimit,
},
}

lggr := logger.TestLogger(t)
lp := mocks.NewLogPoller(t)
lp.On("RegisterFilter", mock.Anything).Return(nil)
usdcReader, err := ccipdata.NewUSDCReader(lggr, mockMsgTransmitter, lp)
require.NoError(t, err)
usdcService := NewUSDCTokenDataReader(lggr, usdcReader, attestationURI)
_, err = usdcService.callAttestationApi(context.Background(), utils.RandomBytes32())
require.Error(t, err)
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ts := test.getTs()
defer ts.Close()

attestationURI, err := url.ParseRequestURI(ts.URL)
require.NoError(t, err)

lggr := logger.TestLogger(t)
lp := mocks.NewLogPoller(t)
lp.On("RegisterFilter", mock.Anything).Return(nil)
usdcReader, err := ccipdata.NewUSDCReader(lggr, mockMsgTransmitter, lp)
require.NoError(t, err)
usdcService := NewUSDCTokenDataReader(lggr, usdcReader, attestationURI, test.customTimeoutSeconds)
_, err = usdcService.callAttestationApi(context.Background(), utils.RandomBytes32())
require.Error(t, err)

if test.expectedError != nil {
require.True(t, errors.Is(err, test.expectedError))
}
})
}
}

func getMockUSDCEndpoint(t *testing.T, response attestationResponse) *httptest.Server {
Expand All @@ -99,7 +174,7 @@ func TestGetUSDCMessageBody(t *testing.T) {
usdcReader.On("GetLastUSDCMessagePriorToLogIndexInTx", mock.Anything, mock.Anything, mock.Anything).Return(expectedBody, nil)

lggr := logger.TestLogger(t)
usdcService := NewUSDCTokenDataReader(lggr, &usdcReader, nil)
usdcService := NewUSDCTokenDataReader(lggr, &usdcReader, nil, 0)

// Make the first call and assert the underlying function is called
body, err := usdcService.getUSDCMessageBody(context.Background(), internal.EVM2EVMOnRampCCIPSendRequestedWithMeta{})
Expand Down

0 comments on commit 0f3634f

Please sign in to comment.