diff --git a/execution.go b/execution.go index c13b151..d8de4e8 100644 --- a/execution.go +++ b/execution.go @@ -8,6 +8,7 @@ import ( "fmt" "math/big" "net/http" + "strings" "time" "github.com/golang-jwt/jwt/v5" @@ -58,23 +59,10 @@ func NewEngineAPIExecutionClient( return nil, err } - authToken := "" - if jwtSecret != "" { - secret, err := hex.DecodeString(jwtSecret) - if err != nil { - return nil, err - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "exp": time.Now().Add(time.Hour * 1).Unix(), // Expires in 1 hour - "iat": time.Now().Unix(), - }) - - // Sign the token with the decoded secret - authToken, err = token.SignedString(secret) - if err != nil { - return nil, err - } + authToken, err := getAuthToken(jwtSecret) + if err != nil { + ethClient.Close() + return nil, err } engineClient, err := rpc.DialOptions(context.Background(), engineURL, @@ -294,3 +282,26 @@ func (c *EngineAPIExecutionClient) SetFinal(ctx context.Context, height uint64) func (c *EngineAPIExecutionClient) derivePrevRandao(blockHeight uint64) common.Hash { return common.BigToHash(big.NewInt(int64(blockHeight))) //nolint:gosec // disable G115 } + +// Add this function to execution.go +func getAuthToken(jwtSecret string) (string, error) { + if jwtSecret == "" { + return "", nil + } + secret, err := hex.DecodeString(strings.TrimPrefix(jwtSecret, "0x")) + if err != nil { + return "", fmt.Errorf("failed to decode JWT secret: %w", err) + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "exp": time.Now().Add(time.Hour * 1).Unix(), // Expires in 1 hour + "iat": time.Now().Unix(), + }) + + // Sign the token with the decoded secret + authToken, err := token.SignedString(secret) + if err != nil { + return "", fmt.Errorf("failed to sign JWT token: %w", err) + } + return authToken, nil +} diff --git a/integration_test.go b/integration_test.go index f577830..9db5bfc 100644 --- a/integration_test.go +++ b/integration_test.go @@ -13,7 +13,6 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -72,9 +71,19 @@ func setupTestRethEngine(t *testing.T) string { err = os.WriteFile(jwtFile, []byte(jwtSecret), 0600) require.NoError(t, err) + t.Cleanup(func() { + err := os.Remove(jwtFile) + require.NoError(t, err) + }) + cli, err := client.NewClientWithOpts() require.NoError(t, err) + t.Cleanup(func() { + err := cli.Close() + require.NoError(t, err) + }) + rethContainer, err := cli.ContainerCreate(context.Background(), &container.Config{ Image: "ghcr.io/paradigmxyz/reth:v1.1.1", @@ -124,21 +133,19 @@ func setupTestRethEngine(t *testing.T) string { nil, nil, "reth") require.NoError(t, err) - err = cli.ContainerStart(context.Background(), rethContainer.ID, container.StartOptions{}) - require.NoError(t, err) - - err = waitForRethContainer(t, jwtSecret) - require.NoError(t, err) - t.Cleanup(func() { - err = cli.ContainerStop(context.Background(), rethContainer.ID, container.StopOptions{}) + err := cli.ContainerStop(context.Background(), rethContainer.ID, container.StopOptions{}) require.NoError(t, err) err = cli.ContainerRemove(context.Background(), rethContainer.ID, container.RemoveOptions{}) require.NoError(t, err) - err = os.Remove(jwtFile) - require.NoError(t, err) }) + err = cli.ContainerStart(context.Background(), rethContainer.ID, container.StartOptions{}) + require.NoError(t, err) + + err = waitForRethContainer(t, jwtSecret) + require.NoError(t, err) + return jwtSecret } @@ -146,16 +153,16 @@ func setupTestRethEngine(t *testing.T) string { func waitForRethContainer(t *testing.T, jwtSecret string) error { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - client := &http.Client{ - Timeout: 1 * time.Second, + Timeout: 100 * time.Millisecond, } + timer := time.NewTimer(500 * time.Millisecond) + defer timer.Stop() + for { select { - case <-ctx.Done(): + case <-timer.C: return fmt.Errorf("timeout waiting for reth container to be ready") default: // check :8545 is ready @@ -164,28 +171,19 @@ func waitForRethContainer(t *testing.T, jwtSecret string) error { if err == nil { resp.Body.Close() if resp.StatusCode == http.StatusOK { - // check :8551 is ready - req, err := http.NewRequest("POST", TEST_ENGINE_URL, strings.NewReader(`{"jsonrpc":"2.0","method":"engine_exchangeTransitionConfigurationV1","params":[],"id":1}`)) + // check :8551 is ready with a stateless call + req, err := http.NewRequest("POST", TEST_ENGINE_URL, strings.NewReader(`{"jsonrpc":"2.0","method":"engine_getClientVersionV1","params":[],"id":1}`)) if err != nil { return err } req.Header.Set("Content-Type", "application/json") - jwtSecretBytes, err := hex.DecodeString(strings.TrimPrefix(jwtSecret, "0x")) + authToken, err := getAuthToken(jwtSecret) if err != nil { - return fmt.Errorf("failed to decode JWT secret: %w", err) - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "iat": time.Now().Unix(), - }) - - signedToken, err := token.SignedString(jwtSecretBytes) - if err != nil { - return fmt.Errorf("failed to sign JWT token: %w", err) + return err } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signedToken)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken)) resp, err := client.Do(req) if err == nil { @@ -196,7 +194,7 @@ func waitForRethContainer(t *testing.T, jwtSecret string) error { } } } - time.Sleep(500 * time.Millisecond) + time.Sleep(100 * time.Millisecond) } } } @@ -223,14 +221,21 @@ func TestExecutionClientLifecycle(t *testing.T) { ) require.NoError(t, err) - // sub tests are only functional grouping. - t.Run("InitChain", func(t *testing.T) { + require.True(t, t.Run("InitChain", func(t *testing.T) { stateRoot, gasLimit, err := executionClient.InitChain(context.Background(), genesisTime, initialHeight, CHAIN_ID) require.NoError(t, err) require.Equal(t, rollkitGenesisStateRoot, stateRoot) require.Equal(t, uint64(1000000), gasLimit) - }) + })) + + require.True(t, t.Run("InitChain_InvalidPayloadTimestamp", func(t *testing.T) { + blockTime := time.Date(2024, 3, 13, 13, 54, 0, 0, time.UTC) // pre-cancun timestamp not supported + _, _, err := executionClient.InitChain(context.Background(), blockTime, initialHeight, CHAIN_ID) + // payload timestamp is not within the cancun timestamp + require.Error(t, err) + require.ErrorContains(t, err, "Unsupported fork") + })) privateKey, err := crypto.HexToECDSA(TEST_PRIVATE_KEY) require.NoError(t, err) @@ -252,7 +257,7 @@ func TestExecutionClientLifecycle(t *testing.T) { err = rpcClient.SendTransaction(context.Background(), signedTx) require.NoError(t, err) - t.Run("GetTxs", func(t *testing.T) { + require.True(t, t.Run("GetTxs", func(t *testing.T) { txs, err := executionClient.GetTxs(context.Background()) require.NoError(t, err) assert.Equal(t, 1, len(txs)) @@ -269,7 +274,7 @@ func TestExecutionClientLifecycle(t *testing.T) { assert.Equal(t, rSignedTx, r) assert.Equal(t, sSignedTx, s) assert.Equal(t, vSignedTx, v) - }) + })) txBytes, err := signedTx.MarshalBinary() require.NoError(t, err) @@ -277,35 +282,12 @@ func TestExecutionClientLifecycle(t *testing.T) { blockHeight := uint64(1) blockTime := genesisTime.Add(10 * time.Second) - t.Run("ExecuteTxs", func(t *testing.T) { + require.True(t, t.Run("ExecuteTxs", func(t *testing.T) { newStateroot := common.HexToHash("0x362b7d8a31e7671b0f357756221ac385790c25a27ab222dc8cbdd08944f5aea4") stateroot, gasUsed, err := executionClient.ExecuteTxs(context.Background(), []rollkit_types.Tx{rollkit_types.Tx(txBytes)}, blockHeight, blockTime, rollkitGenesisStateRoot) require.NoError(t, err) assert.Greater(t, gasLimit, gasUsed) assert.Equal(t, rollkit_types.Hash(newStateroot[:]), stateroot) - }) -} - -func TestExecutionClient_InitChain_InvalidPayloadTimestamp(t *testing.T) { - jwtSecret := setupTestRethEngine(t) - - initialHeight := uint64(0) - genesisHash := common.HexToHash(GENESIS_HASH) - blockTime := time.Date(2024, 3, 13, 13, 54, 0, 0, time.UTC) // pre-cancun timestamp not supported - - executionClient, err := NewEngineAPIExecutionClient( - &proxy_json_rpc.Config{}, - TEST_ETH_URL, - TEST_ENGINE_URL, - jwtSecret, - genesisHash, - common.Address{}, - ) - require.NoError(t, err) - - _, _, err = executionClient.InitChain(context.Background(), blockTime, initialHeight, CHAIN_ID) - // payload timestamp is not within the cancun timestamp - require.Error(t, err) - require.ErrorContains(t, err, "Unsupported fork") + })) }