From 58a4432c5f16f64348098dbe36b0ed1c93fdedeb Mon Sep 17 00:00:00 2001 From: bruce-riley <96066700+bruce-riley@users.noreply.github.com> Date: Wed, 31 May 2023 17:08:42 -0500 Subject: [PATCH] CCQ: Validation and marshalling changes (#3017) --- node/cmd/guardiand/node.go | 47 ++++++-------- node/cmd/guardiand/query.go | 101 +++++++++++++++---------------- node/pkg/common/queryRequest.go | 52 ++++++++++++++++ node/pkg/common/queryResponse.go | 47 +++++++------- node/pkg/watchers/evm/watcher.go | 51 +++++----------- 5 files changed, 161 insertions(+), 137 deletions(-) diff --git a/node/cmd/guardiand/node.go b/node/cmd/guardiand/node.go index 9fb1aa467f..ec32c70c5c 100644 --- a/node/cmd/guardiand/node.go +++ b/node/cmd/guardiand/node.go @@ -56,7 +56,6 @@ import ( "github.com/spf13/cobra" "github.com/wormhole-foundation/wormhole/sdk/vaa" "go.uber.org/zap" - "google.golang.org/protobuf/proto" ipfslog "github.com/ipfs/go-log/v2" googleapi_option "google.golang.org/api/option" @@ -992,7 +991,7 @@ func runNode(cmd *cobra.Command, args []string) { signedQueryReqReadC, signedQueryReqWriteC := makeChannelPair[*gossipv1.SignedQueryRequest](common.SignedQueryRequestChannelSize) // Per-chain query requests - chainQueryReqC := make(map[vaa.ChainID]chan *gossipv1.SignedQueryRequest) + chainQueryReqC := make(map[vaa.ChainID]chan *common.QueryRequest) // Query responses from watchers to query handler aggregated across all chains queryResponseReadC, queryResponseWriteC := makeChannelPair[*common.QueryResponse](0) @@ -1012,16 +1011,10 @@ func runNode(cmd *cobra.Command, args []string) { case <-rootCtx.Done(): return case response := <-c: - var queryRequest gossipv1.QueryRequest - err = proto.Unmarshal(response.Msg.Request.QueryRequest, &queryRequest) - if err != nil { - logger.Error("received invalid response from watcher", zap.Stringer("watcherChainId", chainId)) - continue - } - if vaa.ChainID(queryRequest.ChainId) != chainId { + if response.ChainID != chainId { // SECURITY: This should never happen. If it does, a watcher has been compromised. logger.Fatal("SECURITY CRITICAL: Received query response from a chain that was not marked as originating from that chain", - zap.Uint32("responseChainId", queryRequest.ChainId), + zap.Uint16("responseChainId", uint16(response.ChainID)), zap.Stringer("watcherChainId", chainId), ) } else { @@ -1234,7 +1227,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Ethereum watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDEthereum) chainObsvReqC[vaa.ChainIDEthereum] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDEthereum] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDEthereum] = make(chan *common.QueryRequest, queryRequestBufferSize) ethWatcher = evm.NewEthWatcher(*ethRPC, ethContractAddr, "eth", vaa.ChainIDEthereum, chainMsgC[vaa.ChainIDEthereum], setWriteC, chainObsvReqC[vaa.ChainIDEthereum], chainQueryReqC[vaa.ChainIDEthereum], chainQueryResponseC[vaa.ChainIDEthereum], *unsafeDevMode) if err := supervisor.Run(ctx, "ethwatch", common.WrapWithScissors(ethWatcher.Run, "ethwatch")); err != nil { @@ -1246,7 +1239,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting BSC watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDBSC) chainObsvReqC[vaa.ChainIDBSC] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDBSC] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDBSC] = make(chan *common.QueryRequest, queryRequestBufferSize) bscWatcher := evm.NewEthWatcher(*bscRPC, bscContractAddr, "bsc", vaa.ChainIDBSC, chainMsgC[vaa.ChainIDBSC], nil, chainObsvReqC[vaa.ChainIDBSC], chainQueryReqC[vaa.ChainIDBSC], chainQueryResponseC[vaa.ChainIDBSC], *unsafeDevMode) bscWatcher.SetWaitForConfirmations(true) if err := supervisor.Run(ctx, "bscwatch", common.WrapWithScissors(bscWatcher.Run, "bscwatch")); err != nil { @@ -1263,7 +1256,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Polygon watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDPolygon) chainObsvReqC[vaa.ChainIDPolygon] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDPolygon] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDPolygon] = make(chan *common.QueryRequest, queryRequestBufferSize) polygonWatcher := evm.NewEthWatcher(*polygonRPC, polygonContractAddr, "polygon", vaa.ChainIDPolygon, chainMsgC[vaa.ChainIDPolygon], nil, chainObsvReqC[vaa.ChainIDPolygon], chainQueryReqC[vaa.ChainIDPolygon], chainQueryResponseC[vaa.ChainIDPolygon], *unsafeDevMode) polygonWatcher.SetWaitForConfirmations(waitForConfirmations) if err := polygonWatcher.SetRootChainParams(*polygonRootChainRpc, *polygonRootChainContractAddress); err != nil { @@ -1277,7 +1270,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Avalanche watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDAvalanche) chainObsvReqC[vaa.ChainIDAvalanche] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDAvalanche] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDAvalanche] = make(chan *common.QueryRequest, queryRequestBufferSize) if err := supervisor.Run(ctx, "avalanchewatch", common.WrapWithScissors(evm.NewEthWatcher(*avalancheRPC, avalancheContractAddr, "avalanche", vaa.ChainIDAvalanche, chainMsgC[vaa.ChainIDAvalanche], nil, chainObsvReqC[vaa.ChainIDAvalanche], chainQueryReqC[vaa.ChainIDAvalanche], chainQueryResponseC[vaa.ChainIDAvalanche], *unsafeDevMode).Run, "avalanchewatch")); err != nil { return err @@ -1287,7 +1280,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Oasis watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDOasis) chainObsvReqC[vaa.ChainIDOasis] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDOasis] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDOasis] = make(chan *common.QueryRequest, queryRequestBufferSize) if err := supervisor.Run(ctx, "oasiswatch", common.WrapWithScissors(evm.NewEthWatcher(*oasisRPC, oasisContractAddr, "oasis", vaa.ChainIDOasis, chainMsgC[vaa.ChainIDOasis], nil, chainObsvReqC[vaa.ChainIDOasis], chainQueryReqC[vaa.ChainIDOasis], chainQueryResponseC[vaa.ChainIDOasis], *unsafeDevMode).Run, "oasiswatch")); err != nil { return err @@ -1297,7 +1290,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Aurora watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDAurora) chainObsvReqC[vaa.ChainIDAurora] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDAurora] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDAurora] = make(chan *common.QueryRequest, queryRequestBufferSize) if err := supervisor.Run(ctx, "aurorawatch", common.WrapWithScissors(evm.NewEthWatcher(*auroraRPC, auroraContractAddr, "aurora", vaa.ChainIDAurora, chainMsgC[vaa.ChainIDAurora], nil, chainObsvReqC[vaa.ChainIDAurora], chainQueryReqC[vaa.ChainIDAurora], chainQueryResponseC[vaa.ChainIDAurora], *unsafeDevMode).Run, "aurorawatch")); err != nil { return err @@ -1307,7 +1300,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Fantom watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDFantom) chainObsvReqC[vaa.ChainIDFantom] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDFantom] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDFantom] = make(chan *common.QueryRequest, queryRequestBufferSize) if err := supervisor.Run(ctx, "fantomwatch", common.WrapWithScissors(evm.NewEthWatcher(*fantomRPC, fantomContractAddr, "fantom", vaa.ChainIDFantom, chainMsgC[vaa.ChainIDFantom], nil, chainObsvReqC[vaa.ChainIDFantom], chainQueryReqC[vaa.ChainIDFantom], chainQueryResponseC[vaa.ChainIDFantom], *unsafeDevMode).Run, "fantomwatch")); err != nil { return err @@ -1317,7 +1310,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Karura watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDKarura) chainObsvReqC[vaa.ChainIDKarura] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDKarura] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDKarura] = make(chan *common.QueryRequest, queryRequestBufferSize) if err := supervisor.Run(ctx, "karurawatch", common.WrapWithScissors(evm.NewEthWatcher(*karuraRPC, karuraContractAddr, "karura", vaa.ChainIDKarura, chainMsgC[vaa.ChainIDKarura], nil, chainObsvReqC[vaa.ChainIDKarura], chainQueryReqC[vaa.ChainIDKarura], chainQueryResponseC[vaa.ChainIDKarura], *unsafeDevMode).Run, "karurawatch")); err != nil { return err @@ -1327,7 +1320,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Acala watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDAcala) chainObsvReqC[vaa.ChainIDAcala] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDAcala] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDAcala] = make(chan *common.QueryRequest, queryRequestBufferSize) if err := supervisor.Run(ctx, "acalawatch", common.WrapWithScissors(evm.NewEthWatcher(*acalaRPC, acalaContractAddr, "acala", vaa.ChainIDAcala, chainMsgC[vaa.ChainIDAcala], nil, chainObsvReqC[vaa.ChainIDAcala], chainQueryReqC[vaa.ChainIDAcala], chainQueryResponseC[vaa.ChainIDAcala], *unsafeDevMode).Run, "acalawatch")); err != nil { return err @@ -1337,7 +1330,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Klaytn watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDKlaytn) chainObsvReqC[vaa.ChainIDKlaytn] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDKlaytn] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDKlaytn] = make(chan *common.QueryRequest, queryRequestBufferSize) if err := supervisor.Run(ctx, "klaytnwatch", common.WrapWithScissors(evm.NewEthWatcher(*klaytnRPC, klaytnContractAddr, "klaytn", vaa.ChainIDKlaytn, chainMsgC[vaa.ChainIDKlaytn], nil, chainObsvReqC[vaa.ChainIDKlaytn], chainQueryReqC[vaa.ChainIDKlaytn], chainQueryResponseC[vaa.ChainIDKlaytn], *unsafeDevMode).Run, "klaytnwatch")); err != nil { return err @@ -1347,7 +1340,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Celo watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDCelo) chainObsvReqC[vaa.ChainIDCelo] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDCelo] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDCelo] = make(chan *common.QueryRequest, queryRequestBufferSize) if err := supervisor.Run(ctx, "celowatch", common.WrapWithScissors(evm.NewEthWatcher(*celoRPC, celoContractAddr, "celo", vaa.ChainIDCelo, chainMsgC[vaa.ChainIDCelo], nil, chainObsvReqC[vaa.ChainIDCelo], chainQueryReqC[vaa.ChainIDCelo], chainQueryResponseC[vaa.ChainIDCelo], *unsafeDevMode).Run, "celowatch")); err != nil { return err @@ -1357,7 +1350,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Moonbeam watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDMoonbeam) chainObsvReqC[vaa.ChainIDMoonbeam] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDMoonbeam] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDMoonbeam] = make(chan *common.QueryRequest, queryRequestBufferSize) if err := supervisor.Run(ctx, "moonbeamwatch", common.WrapWithScissors(evm.NewEthWatcher(*moonbeamRPC, moonbeamContractAddr, "moonbeam", vaa.ChainIDMoonbeam, chainMsgC[vaa.ChainIDMoonbeam], nil, chainObsvReqC[vaa.ChainIDMoonbeam], chainQueryReqC[vaa.ChainIDMoonbeam], chainQueryResponseC[vaa.ChainIDMoonbeam], *unsafeDevMode).Run, "moonbeamwatch")); err != nil { return err @@ -1370,7 +1363,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Arbitrum watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDArbitrum) chainObsvReqC[vaa.ChainIDArbitrum] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDArbitrum] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDArbitrum] = make(chan *common.QueryRequest, queryRequestBufferSize) arbitrumWatcher := evm.NewEthWatcher(*arbitrumRPC, arbitrumContractAddr, "arbitrum", vaa.ChainIDArbitrum, chainMsgC[vaa.ChainIDArbitrum], nil, chainObsvReqC[vaa.ChainIDArbitrum], chainQueryReqC[vaa.ChainIDArbitrum], chainQueryResponseC[vaa.ChainIDArbitrum], *unsafeDevMode) arbitrumWatcher.SetL1Finalizer(ethWatcher) if err := supervisor.Run(ctx, "arbitrumwatch", common.WrapWithScissors(arbitrumWatcher.Run, "arbitrumwatch")); err != nil { @@ -1381,7 +1374,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Optimism watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDOptimism) chainObsvReqC[vaa.ChainIDOptimism] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDOptimism] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDOptimism] = make(chan *common.QueryRequest, queryRequestBufferSize) optimismWatcher := evm.NewEthWatcher(*optimismRPC, optimismContractAddr, "optimism", vaa.ChainIDOptimism, chainMsgC[vaa.ChainIDOptimism], nil, chainObsvReqC[vaa.ChainIDOptimism], chainQueryReqC[vaa.ChainIDOptimism], chainQueryResponseC[vaa.ChainIDOptimism], *unsafeDevMode) // If rootChainParams are set, pass them in for pre-Bedrock mode @@ -1521,7 +1514,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Neon watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDNeon) chainObsvReqC[vaa.ChainIDNeon] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDNeon] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDNeon] = make(chan *common.QueryRequest, queryRequestBufferSize) neonWatcher := evm.NewEthWatcher(*neonRPC, neonContractAddr, "neon", vaa.ChainIDNeon, chainMsgC[vaa.ChainIDNeon], nil, chainObsvReqC[vaa.ChainIDNeon], chainQueryReqC[vaa.ChainIDNeon], chainQueryResponseC[vaa.ChainIDNeon], *unsafeDevMode) neonWatcher.SetL1Finalizer(solanaFinalizedWatcher) if err := supervisor.Run(ctx, "neonwatch", common.WrapWithScissors(neonWatcher.Run, "neonwatch")); err != nil { @@ -1532,7 +1525,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Base watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDBase) chainObsvReqC[vaa.ChainIDBase] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDBase] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDBase] = make(chan *common.QueryRequest, queryRequestBufferSize) baseWatcher := evm.NewEthWatcher(*baseRPC, baseContractAddr, "base", vaa.ChainIDBase, chainMsgC[vaa.ChainIDBase], nil, chainObsvReqC[vaa.ChainIDBase], chainQueryReqC[vaa.ChainIDBase], chainQueryResponseC[vaa.ChainIDBase], *unsafeDevMode) if err := supervisor.Run(ctx, "basewatch", common.WrapWithScissors(baseWatcher.Run, "basewatch")); err != nil { return err @@ -1545,7 +1538,7 @@ func runNode(cmd *cobra.Command, args []string) { logger.Info("Starting Sepolia watcher") common.MustRegisterReadinessSyncing(vaa.ChainIDSepolia) chainObsvReqC[vaa.ChainIDSepolia] = make(chan *gossipv1.ObservationRequest, observationRequestBufferSize) - chainQueryReqC[vaa.ChainIDSepolia] = make(chan *gossipv1.SignedQueryRequest, queryRequestBufferSize) + chainQueryReqC[vaa.ChainIDSepolia] = make(chan *common.QueryRequest, queryRequestBufferSize) sepoliaWatcher := evm.NewEthWatcher(*sepoliaRPC, sepoliaContractAddr, "sepolia", vaa.ChainIDSepolia, chainMsgC[vaa.ChainIDSepolia], nil, chainObsvReqC[vaa.ChainIDSepolia], chainQueryReqC[vaa.ChainIDSepolia], chainQueryResponseC[vaa.ChainIDSepolia], *unsafeDevMode) if err := supervisor.Run(ctx, "sepoliawatch", common.WrapWithScissors(sepoliaWatcher.Run, "sepoliawatch")); err != nil { return err diff --git a/node/cmd/guardiand/query.go b/node/cmd/guardiand/query.go index a88a49bda0..09e73da291 100644 --- a/node/cmd/guardiand/query.go +++ b/node/cmd/guardiand/query.go @@ -2,7 +2,6 @@ package guardiand import ( "context" - "encoding/hex" "fmt" "strings" "time" @@ -29,16 +28,14 @@ const ( type ( // pendingQuery is the cache entry for a given query. pendingQuery struct { - req *gossipv1.SignedQueryRequest - reqId string - chainId vaa.ChainID - channel chan *gossipv1.SignedQueryRequest + req *common.QueryRequest + channel chan *common.QueryRequest receiveTime time.Time lastUpdateTime time.Time inProgress bool - // resp is only populated when we need to retry sending the response to p2p. - resp *common.QueryResponse + // respPub is only populated when we need to retry sending the response to p2p. + respPub *common.QueryResponsePublication } ) @@ -47,7 +44,7 @@ func handleQueryRequests( ctx context.Context, logger *zap.Logger, signedQueryReqC <-chan *gossipv1.SignedQueryRequest, - chainQueryReqC map[vaa.ChainID]chan *gossipv1.SignedQueryRequest, + chainQueryReqC map[vaa.ChainID]chan *common.QueryRequest, allowedRequestors map[ethCommon.Address]struct{}, queryResponseReadC <-chan *common.QueryResponse, queryResponseWriteC chan<- *common.QueryResponsePublication, @@ -91,75 +88,85 @@ func handleQueryRequests( continue } - var queryRequest gossipv1.QueryRequest - err = proto.Unmarshal(signedQueryRequest.QueryRequest, &queryRequest) + var qr gossipv1.QueryRequest + err = proto.Unmarshal(signedQueryRequest.QueryRequest, &qr) if err != nil { - qLogger.Error("received invalid message", - zap.String("requestor", signerAddress.Hex())) + qLogger.Error("failed to unmarshal query request", zap.String("requestor", signerAddress.Hex()), zap.Error(err)) continue } - reqId := requestID(signedQueryRequest) - chainId := vaa.ChainID(queryRequest.ChainId) + if err := common.ValidateQueryRequest(&qr); err != nil { + qLogger.Error("received invalid message", zap.String("requestor", signerAddress.Hex()), zap.Error(err)) + continue + } + + queryRequest := common.CreateQueryRequest(signedQueryRequest, &qr) // Look up the channel for this chain. - channel, channelExists := chainQueryReqC[chainId] + channel, channelExists := chainQueryReqC[queryRequest.ChainID] if !channelExists { - qLogger.Error("unknown chain ID for query request, dropping it", zap.String("requestID", reqId), zap.Uint32("chain_id", queryRequest.ChainId)) + qLogger.Error("unknown chain ID for query request, dropping it", zap.String("requestID", queryRequest.RequestID), zap.Stringer("chain_id", queryRequest.ChainID)) continue } // Make sure this is not a duplicate request. TODO: Should we do something smarter here than just dropping the duplicate? - if oldReq, exists := pendingQueries[reqId]; exists { - qLogger.Warn("dropping duplicate query request", zap.String("requestID", reqId), zap.Stringer("origRecvTime", oldReq.receiveTime)) + if oldReq, exists := pendingQueries[queryRequest.RequestID]; exists { + qLogger.Warn("dropping duplicate query request", zap.String("requestID", queryRequest.RequestID), zap.Stringer("origRecvTime", oldReq.receiveTime)) continue } // Add the query to our cache. pq := &pendingQuery{ - req: signedQueryRequest, - reqId: reqId, - chainId: chainId, + req: queryRequest, channel: channel, receiveTime: time.Now(), inProgress: true, } - pendingQueries[reqId] = pq + pendingQueries[queryRequest.RequestID] = pq // Forward the request to the watcher. ccqForwardToWatcher(qLogger, pq) case resp := <-queryResponseReadC: - reqId := resp.RequestID() if resp.Status == common.QuerySuccess { + if resp.Result == nil { + qLogger.Error("received a successful query response with a nil result, dropping it!", zap.String("requestID", resp.RequestID)) + continue + } + + respPub := &common.QueryResponsePublication{ + Request: resp.SignedRequest, + Response: *resp.Result, + } + // Send the response to be published. select { - case queryResponseWriteC <- resp.Msg: - qLogger.Debug("forwarded query response to p2p", zap.String("requestID", reqId)) - delete(pendingQueries, reqId) + case queryResponseWriteC <- respPub: + qLogger.Debug("forwarded query response to p2p", zap.String("requestID", resp.RequestID)) + delete(pendingQueries, resp.RequestID) default: - if pq, exists := pendingQueries[reqId]; exists { - qLogger.Warn("failed to publish query response to p2p, will retry publishing next interval", zap.String("requestID", reqId)) - pq.resp = resp + if pq, exists := pendingQueries[resp.RequestID]; exists { + qLogger.Warn("failed to publish query response to p2p, will retry publishing next interval", zap.String("requestID", resp.RequestID)) + pq.respPub = respPub pq.inProgress = false } else { - qLogger.Warn("failed to publish query response to p2p, request is no longer in cache, dropping it", zap.String("requestID", reqId)) - delete(pendingQueries, reqId) + qLogger.Warn("failed to publish query response to p2p, request is no longer in cache, dropping it", zap.String("requestID", resp.RequestID)) + delete(pendingQueries, resp.RequestID) } } } else if resp.Status == common.QueryRetryNeeded { - if pq, exists := pendingQueries[reqId]; exists { - qLogger.Warn("query failed, will retry next interval", zap.String("requestID", reqId)) + if pq, exists := pendingQueries[resp.RequestID]; exists { + qLogger.Warn("query failed, will retry next interval", zap.String("requestID", resp.RequestID)) pq.inProgress = false } else { - qLogger.Warn("query failed, request is no longer in cache, dropping it", zap.String("requestID", reqId)) + qLogger.Warn("query failed, request is no longer in cache, dropping it", zap.String("requestID", resp.RequestID)) } } else if resp.Status == common.QueryFatalError { - qLogger.Error("query encountered a fatal error, dropping it", zap.String("requestID", reqId)) - delete(pendingQueries, reqId) + qLogger.Error("query encountered a fatal error, dropping it", zap.String("requestID", resp.RequestID)) + delete(pendingQueries, resp.RequestID) } else { - qLogger.Error("received an unexpected query status, dropping it", zap.String("requestID", reqId), zap.Int("status", int(resp.Status))) - delete(pendingQueries, reqId) + qLogger.Error("received an unexpected query status, dropping it", zap.String("requestID", resp.RequestID), zap.Int("status", int(resp.Status))) + delete(pendingQueries, resp.RequestID) } case <-ticker.C: @@ -171,17 +178,17 @@ func handleQueryRequests( qLogger.Warn("query request timed out, dropping it", zap.String("requestId", reqId), zap.Stringer("receiveTime", pq.receiveTime)) delete(pendingQueries, reqId) } else { - if pq.resp != nil { + if pq.respPub != nil { // Resend the response to be published. select { - case queryResponseWriteC <- pq.resp.Msg: + case queryResponseWriteC <- pq.respPub: qLogger.Debug("resend of query response to p2p succeeded", zap.String("requestID", reqId)) delete(pendingQueries, reqId) default: qLogger.Warn("resend of query response to p2p failed again, will keep retrying", zap.String("requestID", reqId)) } } else if !pq.inProgress && pq.lastUpdateTime.Add(retryInterval).Before(now) { - qLogger.Info("retrying query request", zap.String("requestId", pq.reqId), zap.Stringer("receiveTime", pq.receiveTime)) + qLogger.Info("retrying query request", zap.String("requestId", reqId), zap.Stringer("receiveTime", pq.receiveTime)) pq.inProgress = true ccqForwardToWatcher(qLogger, pq) } @@ -220,19 +227,11 @@ func ccqForwardToWatcher(qLogger *zap.Logger, pq *pendingQuery) { select { // TODO: only send the query request itself and reassemble in this module case pq.channel <- pq.req: - qLogger.Debug("forwarded query request to watcher", zap.String("requestID", pq.reqId), zap.Stringer("chainID", pq.chainId)) + qLogger.Debug("forwarded query request to watcher", zap.String("requestID", pq.req.RequestID), zap.Stringer("chainID", pq.req.ChainID)) pq.lastUpdateTime = pq.receiveTime default: // By leaving lastUpdateTime unset and setting inProgress to false, we will retry next interval. - qLogger.Warn("failed to send query request to watcher, will retry next interval", zap.String("requestID", pq.reqId), zap.Uint16("chain_id", uint16(pq.chainId))) + qLogger.Warn("failed to send query request to watcher, will retry next interval", zap.String("requestID", pq.req.RequestID), zap.Stringer("chain_id", pq.req.ChainID)) pq.inProgress = false } } - -// requestID returns the request signature as a hex string. -func requestID(req *gossipv1.SignedQueryRequest) string { - if req == nil { - return "nil" - } - return hex.EncodeToString(req.Signature) -} diff --git a/node/pkg/common/queryRequest.go b/node/pkg/common/queryRequest.go index 4cc08566a7..f96be4cbb7 100644 --- a/node/pkg/common/queryRequest.go +++ b/node/pkg/common/queryRequest.go @@ -1,7 +1,13 @@ package common import ( + "encoding/hex" + "fmt" + "math" + "strings" + gossipv1 "github.com/certusone/wormhole/node/pkg/proto/gossip/v1" + "github.com/wormhole-foundation/wormhole/sdk/vaa" ethCommon "github.com/ethereum/go-ethereum/common" ethCrypto "github.com/ethereum/go-ethereum/crypto" @@ -9,6 +15,25 @@ import ( const SignedQueryRequestChannelSize = 50 +// QueryRequest is an internal representation of a query request. +type QueryRequest struct { + SignedRequest *gossipv1.SignedQueryRequest + Request *gossipv1.QueryRequest + RequestID string + ChainID vaa.ChainID +} + +// CreateQueryRequest creates a QueryRequest object from the signed query request. +func CreateQueryRequest(signedRequest *gossipv1.SignedQueryRequest, request *gossipv1.QueryRequest) *QueryRequest { + return &QueryRequest{ + SignedRequest: signedRequest, + Request: request, + RequestID: hex.EncodeToString(signedRequest.Signature), + ChainID: vaa.ChainID(request.ChainId), + } +} + +// QueryRequestDigest returns the query signing prefix based on the environment. func QueryRequestDigest(env Environment, b []byte) ethCommon.Hash { // TODO: should this use a different standard of signing messages, like https://eips.ethereum.org/EIPS/eip-712 var queryRequestPrefix []byte @@ -23,6 +48,7 @@ func QueryRequestDigest(env Environment, b []byte) ethCommon.Hash { return ethCrypto.Keccak256Hash(append(queryRequestPrefix, b...)) } +// PostSignedQueryRequest posts a signed query request to the specified channel. func PostSignedQueryRequest(signedQueryReqSendC chan<- *gossipv1.SignedQueryRequest, req *gossipv1.SignedQueryRequest) error { select { case signedQueryReqSendC <- req: @@ -31,3 +57,29 @@ func PostSignedQueryRequest(signedQueryReqSendC chan<- *gossipv1.SignedQueryRequ return ErrChanFull } } + +// ValidateQueryRequest does basic validation on a received query request. +func ValidateQueryRequest(queryRequest *gossipv1.QueryRequest) error { + if queryRequest.ChainId > math.MaxUint16 { + return fmt.Errorf("invalid chain id: %d is out of bounds", queryRequest.ChainId) + } + switch req := queryRequest.Message.(type) { + case *gossipv1.QueryRequest_EthCallQueryRequest: + if len(req.EthCallQueryRequest.To) != 20 { + return fmt.Errorf("invalid length for To contract") + } + if len(req.EthCallQueryRequest.Data) > math.MaxUint32 { + return fmt.Errorf("request data too long") + } + if len(req.EthCallQueryRequest.Block) > math.MaxUint32 { + return fmt.Errorf("request block too long") + } + if !strings.HasPrefix(req.EthCallQueryRequest.Block, "0x") { + return fmt.Errorf("request block must be a hex number or hash starting with 0x") + } + default: + return fmt.Errorf("received invalid message from query module") + } + + return nil +} diff --git a/node/pkg/common/queryResponse.go b/node/pkg/common/queryResponse.go index a42e0be0bf..ec148ca2e2 100644 --- a/node/pkg/common/queryResponse.go +++ b/node/pkg/common/queryResponse.go @@ -31,15 +31,21 @@ const ( ) type QueryResponse struct { - Status QueryStatus - Msg *QueryResponsePublication + RequestID string + ChainID vaa.ChainID + Status QueryStatus + SignedRequest *gossipv1.SignedQueryRequest + Result *EthCallQueryResponse } -func (resp *QueryResponse) RequestID() string { - if resp == nil || resp.Msg == nil { - return "nil" +func CreateQueryResponse(req *QueryRequest, status QueryStatus, result *EthCallQueryResponse) *QueryResponse { + return &QueryResponse{ + RequestID: req.RequestID, + ChainID: vaa.ChainID(req.Request.ChainId), + SignedRequest: req.SignedRequest, + Status: status, + Result: result, } - return resp.Msg.RequestID() } var queryResponsePrefix = []byte("query_response_0000000000000000000|") @@ -73,6 +79,17 @@ func (msg *QueryResponsePublication) Marshal() ([]byte, error) { return nil, fmt.Errorf("received invalid message from query module") } + if err := ValidateQueryRequest(&queryRequest); err != nil { + return nil, fmt.Errorf("queryRequest is invalid: %w", err) + } + + if len(msg.Response.Hash) != 32 { + return nil, fmt.Errorf("invalid length for block hash") + } + if len(msg.Response.Result) > math.MaxUint32 { + return nil, fmt.Errorf("response data too long") + } + buf := new(bytes.Buffer) // Source @@ -86,23 +103,11 @@ func (msg *QueryResponsePublication) Marshal() ([]byte, error) { switch req := queryRequest.Message.(type) { case *gossipv1.QueryRequest_EthCallQueryRequest: vaa.MustWrite(buf, binary.BigEndian, uint8(1)) - if queryRequest.ChainId > math.MaxUint16 { - return nil, fmt.Errorf("invalid chain id: %d is out of bounds", queryRequest.ChainId) - } vaa.MustWrite(buf, binary.BigEndian, uint16(queryRequest.ChainId)) vaa.MustWrite(buf, binary.BigEndian, queryRequest.Nonce) // uint32 - if len(req.EthCallQueryRequest.To) != 20 { - return nil, fmt.Errorf("invalid length for To contract") - } buf.Write(req.EthCallQueryRequest.To) - if len(req.EthCallQueryRequest.Data) > math.MaxUint32 { - return nil, fmt.Errorf("request data too long") - } vaa.MustWrite(buf, binary.BigEndian, uint32(len(req.EthCallQueryRequest.Data))) buf.Write(req.EthCallQueryRequest.Data) - if len(req.EthCallQueryRequest.Block) > math.MaxUint32 { - return nil, fmt.Errorf("request block too long") - } vaa.MustWrite(buf, binary.BigEndian, uint32(len(req.EthCallQueryRequest.Block))) // TODO: should this be an enum or the literal string? buf.Write([]byte(req.EthCallQueryRequest.Block)) @@ -111,14 +116,8 @@ func (msg *QueryResponsePublication) Marshal() ([]byte, error) { // TODO: probably some kind of request/response pair validation // TODO: is uint64 safe? vaa.MustWrite(buf, binary.BigEndian, msg.Response.Number.Uint64()) - if len(msg.Response.Hash) != 32 { - return nil, fmt.Errorf("invalid length for block hash") - } buf.Write(msg.Response.Hash[:]) vaa.MustWrite(buf, binary.BigEndian, uint32(msg.Response.Time.Unix())) - if len(msg.Response.Result) > math.MaxUint32 { - return nil, fmt.Errorf("response data too long") - } vaa.MustWrite(buf, binary.BigEndian, uint32(len(msg.Response.Result))) buf.Write(msg.Response.Result) return buf.Bytes(), nil diff --git a/node/pkg/watchers/evm/watcher.go b/node/pkg/watchers/evm/watcher.go index c9ca47cbb5..2c039e2ce8 100644 --- a/node/pkg/watchers/evm/watcher.go +++ b/node/pkg/watchers/evm/watcher.go @@ -22,7 +22,6 @@ import ( eth_common "github.com/ethereum/go-ethereum/common" eth_hexutil "github.com/ethereum/go-ethereum/common/hexutil" "go.uber.org/zap" - "google.golang.org/protobuf/proto" "github.com/certusone/wormhole/node/pkg/common" "github.com/certusone/wormhole/node/pkg/readiness" @@ -97,7 +96,7 @@ type ( // Incoming query requests from the network. Pre-filtered to only // include requests for our chainID. - queryReqC <-chan *gossipv1.SignedQueryRequest + queryReqC <-chan *common.QueryRequest // Outbound query responses to query requests queryResponseC chan<- *common.QueryResponse @@ -151,7 +150,7 @@ func NewEthWatcher( msgC chan<- *common.MessagePublication, setC chan<- *common.GuardianSet, obsvReqC <-chan *gossipv1.ObservationRequest, - queryReqC <-chan *gossipv1.SignedQueryRequest, + queryReqC <-chan *common.QueryRequest, queryResponseC chan<- *common.QueryResponse, unsafeDevMode bool, ) *Watcher { @@ -544,23 +543,14 @@ func (w *Watcher) Run(parentCtx context.Context) error { select { case <-ctx.Done(): return nil - case signedQueryRequest := <-w.queryReqC: - // TODO: only receive the unmarshalled query request (see note in query.go) - var queryRequest gossipv1.QueryRequest - err := proto.Unmarshal(signedQueryRequest.QueryRequest, &queryRequest) - if err != nil { - logger.Error("received invalid message from query module", zap.String("component", "ccqevm")) - w.ccqSendQueryResponse(logger, common.QueryFatalError, signedQueryRequest, nil) - continue - } - + case queryRequest := <-w.queryReqC: // This can't happen unless there is a programming error - the caller // is expected to send us only requests for our chainID. - if vaa.ChainID(queryRequest.ChainId) != w.chainID { + if queryRequest.ChainID != w.chainID { panic("ccqevm: invalid chain ID") } - switch req := queryRequest.Message.(type) { + switch req := queryRequest.Request.Message.(type) { case *gossipv1.QueryRequest_EthCallQueryRequest: to := eth_common.BytesToAddress(req.EthCallQueryRequest.To) data := eth_hexutil.Encode(req.EthCallQueryRequest.Data) @@ -638,7 +628,7 @@ func (w *Watcher) Run(parentCtx context.Context) error { zap.String("block", block), zap.String("component", "ccqevm"), ) - w.ccqSendQueryResponse(logger, common.QueryRetryNeeded, signedQueryRequest, nil) + w.ccqSendQueryResponse(logger, queryRequest, common.QueryRetryNeeded, nil) continue } @@ -650,7 +640,7 @@ func (w *Watcher) Run(parentCtx context.Context) error { zap.String("block", block), zap.String("component", "ccqevm"), ) - w.ccqSendQueryResponse(logger, common.QueryRetryNeeded, signedQueryRequest, nil) + w.ccqSendQueryResponse(logger, queryRequest, common.QueryRetryNeeded, nil) continue } @@ -662,7 +652,7 @@ func (w *Watcher) Run(parentCtx context.Context) error { zap.String("block", block), zap.String("component", "ccqevm"), ) - w.ccqSendQueryResponse(logger, common.QueryRetryNeeded, signedQueryRequest, nil) + w.ccqSendQueryResponse(logger, queryRequest, common.QueryRetryNeeded, nil) continue } @@ -674,7 +664,7 @@ func (w *Watcher) Run(parentCtx context.Context) error { zap.String("block", block), zap.String("component", "ccqevm"), ) - w.ccqSendQueryResponse(logger, common.QueryRetryNeeded, signedQueryRequest, nil) + w.ccqSendQueryResponse(logger, queryRequest, common.QueryRetryNeeded, nil) continue } @@ -688,7 +678,7 @@ func (w *Watcher) Run(parentCtx context.Context) error { zap.String("block", block), zap.String("component", "ccqevm"), ) - w.ccqSendQueryResponse(logger, common.QueryRetryNeeded, signedQueryRequest, nil) + w.ccqSendQueryResponse(logger, queryRequest, common.QueryRetryNeeded, nil) continue } @@ -711,14 +701,14 @@ func (w *Watcher) Run(parentCtx context.Context) error { Result: callResult, } - w.ccqSendQueryResponse(logger, common.QuerySuccess, signedQueryRequest, resp) + w.ccqSendQueryResponse(logger, queryRequest, common.QuerySuccess, resp) default: logger.Warn("received unsupported request type", - zap.Any("payload", queryRequest.Message), + zap.Any("payload", queryRequest.Request.Message), zap.String("component", "ccqevm"), ) - w.ccqSendQueryResponse(logger, common.QueryFatalError, signedQueryRequest, nil) + w.ccqSendQueryResponse(logger, queryRequest, common.QueryFatalError, nil) } } } @@ -1141,19 +1131,10 @@ func (w *Watcher) SetMaxWaitConfirmations(maxWaitConfirmations uint64) { } // ccqSendQueryResponse sends an error response back to the query handler. -func (w *Watcher) ccqSendQueryResponse(logger *zap.Logger, status common.QueryStatus, req *gossipv1.SignedQueryRequest, resp *common.EthCallQueryResponse) { - queryResponse := common.QueryResponse{ - Status: status, - Msg: &common.QueryResponsePublication{ - Request: req, - }, - } - - if resp != nil { - queryResponse.Msg.Response = *resp - } +func (w *Watcher) ccqSendQueryResponse(logger *zap.Logger, req *common.QueryRequest, status common.QueryStatus, result *common.EthCallQueryResponse) { + queryResponse := common.CreateQueryResponse(req, status, result) select { - case w.queryResponseC <- &queryResponse: + case w.queryResponseC <- queryResponse: logger.Debug("published query response error to handler", zap.String("component", "ccqevm")) default: logger.Error("failed to published query response error to handler", zap.String("component", "ccqevm"))