diff --git a/core/services/gateway/connectionmanager.go b/core/services/gateway/connectionmanager.go index 17a31014c81..86660f03dc9 100644 --- a/core/services/gateway/connectionmanager.go +++ b/core/services/gateway/connectionmanager.go @@ -87,6 +87,9 @@ func NewConnectionManager(gwConfig *config.GatewayConfig, clock utils.Clock, lgg return nil, fmt.Errorf("duplicate node address %s in DON %s", nodeAddress, donConfig.DonId) } nodes[nodeAddress] = &nodeState{conn: network.NewWSConnectionWrapper()} + if nodes[nodeAddress].conn == nil { + return nil, fmt.Errorf("error creating WSConnectionWrapper for node %s", nodeAddress) + } } dons[donConfig.DonId] = &donConnectionManager{ donConfig: &donConfig, @@ -232,11 +235,18 @@ func (m *donConnectionManager) SetHandler(handler handlers.Handler) { } func (m *donConnectionManager) SendToNode(ctx context.Context, nodeAddress string, msg *api.Message) error { + if msg == nil { + return errors.New("nil message") + } data, err := m.codec.EncodeRequest(msg) if err != nil { return fmt.Errorf("error encoding request for node %s: %v", nodeAddress, err) } - return m.nodes[nodeAddress].conn.Write(ctx, websocket.BinaryMessage, data) + nodeState := m.nodes[nodeAddress] + if nodeState == nil { + return fmt.Errorf("node %s not found", nodeAddress) + } + return nodeState.conn.Write(ctx, websocket.BinaryMessage, data) } func (m *donConnectionManager) readLoop(nodeAddress string, nodeState *nodeState) { diff --git a/core/services/gateway/connectionmanager_test.go b/core/services/gateway/connectionmanager_test.go index f924761439a..d198ef67295 100644 --- a/core/services/gateway/connectionmanager_test.go +++ b/core/services/gateway/connectionmanager_test.go @@ -8,8 +8,10 @@ import ( "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" gc "github.com/smartcontractkit/chainlink/v2/core/services/gateway/common" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/network" @@ -208,3 +210,20 @@ func TestConnectionManager_FinalizeHandshake(t *testing.T) { err = mgr.FinalizeHandshake(attemptId, response, nil) require.ErrorIs(t, err, network.ErrChallengeInvalidSignature) } + +func TestConnectionManager_SendToNode_Failures(t *testing.T) { + t.Parallel() + + config, nodes := newTestConfig(t, 2) + clock := utils.NewFixedClock(time.Now()) + mgr, err := gateway.NewConnectionManager(config, clock, logger.TestLogger(t)) + require.NoError(t, err) + + donMgr := mgr.DONConnectionManager("my_don_1") + err = donMgr.SendToNode(testutils.Context(t), nodes[0].Address, nil) + require.Error(t, err) + + message := &api.Message{} + err = donMgr.SendToNode(testutils.Context(t), "some_other_node", message) + require.Error(t, err) +} diff --git a/core/services/gateway/gateway.go b/core/services/gateway/gateway.go index 8c77e9b7485..d64ee43233a 100644 --- a/core/services/gateway/gateway.go +++ b/core/services/gateway/gateway.go @@ -4,9 +4,12 @@ import ( "context" "encoding/json" "fmt" + "strings" "go.uber.org/multierr" + "github.com/ethereum/go-ethereum/common" + "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" @@ -59,6 +62,12 @@ func NewGatewayFromConfig(config *config.GatewayConfig, handlerFactory HandlerFa if donConnMgr == nil { return nil, fmt.Errorf("connection manager ID %s not found", donConfig.DonId) } + for idx, nodeConfig := range donConfig.Members { + donConfig.Members[idx].Address = strings.ToLower(nodeConfig.Address) + if !common.IsHexAddress(nodeConfig.Address) { + return nil, fmt.Errorf("invalid node address %s", nodeConfig.Address) + } + } handler, err := handlerFactory.NewHandler(donConfig.HandlerName, donConfig.HandlerConfig, &donConfig, donConnMgr) if err != nil { return nil, err diff --git a/core/services/gateway/gateway_test.go b/core/services/gateway/gateway_test.go index a6662505db0..a2ee8f7a6c0 100644 --- a/core/services/gateway/gateway_test.go +++ b/core/services/gateway/gateway_test.go @@ -49,6 +49,10 @@ HandlerName = "dummy" [[dons]] DonId = "my_don_2" HandlerName = "dummy" + +[[dons.Members]] +Name = "node one" +Address = "0x0001020304050607080900010203040506070809" `) lggr := logger.TestLogger(t) @@ -102,6 +106,24 @@ SomeOtherField = "abcd" require.Error(t, err) } +func TestGateway_NewGatewayFromConfig_InvalidNodeAddress(t *testing.T) { + t.Parallel() + + tomlConfig := buildConfig(` +[[dons]] +HandlerName = "dummy" +DonId = "my_don" + +[[dons.Members]] +Name = "node one" +Address = "0xnot_an_address" +`) + + lggr := logger.TestLogger(t) + _, err := gateway.NewGatewayFromConfig(parseTOMLConfig(t, tomlConfig), gateway.NewHandlerFactory(nil, lggr), lggr) + require.Error(t, err) +} + func TestGateway_CleanStartAndClose(t *testing.T) { t.Parallel() diff --git a/core/services/gateway/integration_tests/gateway_integration_test.go b/core/services/gateway/integration_tests/gateway_integration_test.go index 7a0204a3156..310047950e6 100644 --- a/core/services/gateway/integration_tests/gateway_integration_test.go +++ b/core/services/gateway/integration_tests/gateway_integration_test.go @@ -6,6 +6,7 @@ import ( "crypto/ecdsa" "fmt" "net/http" + "strings" "sync/atomic" "testing" @@ -110,6 +111,8 @@ func TestIntegration_Gateway_NoFullNodes_BasicConnectionAndMessage(t *testing.T) t.Parallel() nodeKeys := common.NewTestNodes(t, 1)[0] + // Verify that addresses in config are case-insensitive + nodeKeys.Address = strings.ToUpper(nodeKeys.Address) // Launch Gateway lggr := logger.TestLogger(t)