diff --git a/core/capabilities/compute/compute.go b/core/capabilities/compute/compute.go index 5a43b7bf40b..4de17e290ca 100644 --- a/core/capabilities/compute/compute.go +++ b/core/capabilities/compute/compute.go @@ -3,8 +3,10 @@ package compute import ( "context" "crypto/sha256" + "encoding/json" "errors" "fmt" + "strings" "time" "github.com/google/uuid" @@ -18,6 +20,9 @@ import ( coretypes "github.com/smartcontractkit/chainlink-common/pkg/types/core" "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host" wasmpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb" + "github.com/smartcontractkit/chainlink/v2/core/capabilities/validation" + "github.com/smartcontractkit/chainlink/v2/core/capabilities/webapi" + ghcapabilities "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" ) const ( @@ -68,7 +73,9 @@ type Compute struct { registry coretypes.CapabilitiesRegistry modules *moduleCache - transformer ConfigTransformer + transformer ConfigTransformer + outgoingConnectorHandler *webapi.OutgoingConnectorHandler + idGenerator func() string } func (c *Compute) RegisterToWorkflow(ctx context.Context, request capabilities.RegisterToWorkflowRequest) error { @@ -104,7 +111,7 @@ func (c *Compute) Execute(ctx context.Context, request capabilities.CapabilityRe m, ok := c.modules.get(id) if !ok { - mod, err := c.initModule(id, cfg.ModuleConfig, cfg.Binary, request.Metadata.WorkflowID, request.Metadata.ReferenceID) + mod, err := c.initModule(id, cfg.ModuleConfig, cfg.Binary, request.Metadata.WorkflowID, request.Metadata.WorkflowExecutionID, request.Metadata.ReferenceID) if err != nil { return capabilities.CapabilityResponse{}, err } @@ -115,8 +122,10 @@ func (c *Compute) Execute(ctx context.Context, request capabilities.CapabilityRe return c.executeWithModule(m.module, cfg.Config, request) } -func (c *Compute) initModule(id string, cfg *host.ModuleConfig, binary []byte, workflowID, referenceID string) (*module, error) { +func (c *Compute) initModule(id string, cfg *host.ModuleConfig, binary []byte, workflowID, workflowExecutionID, referenceID string) (*module, error) { initStart := time.Now() + + cfg.Fetch = c.createFetcher(workflowID, workflowExecutionID) mod, err := host.NewModule(cfg, binary) if err != nil { return nil, fmt.Errorf("failed to instantiate WASM module: %w", err) @@ -186,12 +195,62 @@ func (c *Compute) Close() error { return nil } -func NewAction(log logger.Logger, registry coretypes.CapabilitiesRegistry) *Compute { +func (c *Compute) createFetcher(workflowID, workflowExecutionID string) func(req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) { + return func(req *wasmpb.FetchRequest) (*wasmpb.FetchResponse, error) { + if err := validation.ValidateWorkflowOrExecutionID(workflowID); err != nil { + return nil, fmt.Errorf("workflow ID %q is invalid: %w", workflowID, err) + } + if err := validation.ValidateWorkflowOrExecutionID(workflowExecutionID); err != nil { + return nil, fmt.Errorf("workflow execution ID %q is invalid: %w", workflowExecutionID, err) + } + + messageID := strings.Join([]string{ + workflowID, + workflowExecutionID, + ghcapabilities.MethodComputeAction, + c.idGenerator(), + }, "/") + + fields := req.Headers.GetFields() + headersReq := make(map[string]string, len(fields)) + for k, v := range fields { + headersReq[k] = v.String() + } + + payloadBytes, err := json.Marshal(ghcapabilities.Request{ + URL: req.Url, + Method: req.Method, + Headers: headersReq, + Body: req.Body, + TimeoutMs: req.TimeoutMs, + }) + if err != nil { + return nil, fmt.Errorf("failed to marshal fetch request: %w", err) + } + + resp, err := c.outgoingConnectorHandler.HandleSingleNodeRequest(context.Background(), messageID, payloadBytes) + if err != nil { + return nil, err + } + + c.log.Debugw("received gateway response", "resp", resp) + var response wasmpb.FetchResponse + err = json.Unmarshal(resp.Body.Payload, &response) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal fetch response: %w", err) + } + return &response, nil + } +} + +func NewAction(config webapi.ServiceConfig, log logger.Logger, registry coretypes.CapabilitiesRegistry, handler *webapi.OutgoingConnectorHandler, idGenerator func() string) *Compute { compute := &Compute{ - log: logger.Named(log, "CustomCompute"), - registry: registry, - modules: newModuleCache(clockwork.NewRealClock(), 1*time.Minute, 10*time.Minute, 3), - transformer: NewTransformer(), + log: logger.Named(log, "CustomCompute"), + registry: registry, + modules: newModuleCache(clockwork.NewRealClock(), 1*time.Minute, 10*time.Minute, 3), + transformer: NewTransformer(), + outgoingConnectorHandler: handler, + idGenerator: idGenerator, } return compute } diff --git a/core/capabilities/compute/compute_test.go b/core/capabilities/compute/compute_test.go index 39e2b12150b..35021969ebe 100644 --- a/core/capabilities/compute/compute_test.go +++ b/core/capabilities/compute/compute_test.go @@ -1,10 +1,14 @@ package compute import ( + "context" + "encoding/json" + "strings" "testing" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink/v2/core/capabilities" @@ -14,30 +18,72 @@ import ( cappkg "github.com/smartcontractkit/chainlink-common/pkg/capabilities" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink-common/pkg/values" + corecapabilities "github.com/smartcontractkit/chainlink/v2/core/capabilities" + "github.com/smartcontractkit/chainlink/v2/core/capabilities/webapi" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" + gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks" + ghcapabilities "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" ) -func Test_Compute_Start_AddsToRegistry(t *testing.T) { - log := logger.TestLogger(t) - registry := capabilities.NewRegistry(log) - - compute := NewAction(log, registry) - compute.modules.clock = clockwork.NewFakeClock() +const ( + fetchBinaryLocation = "test/fetch/cmd/testmodule.wasm" + fetchBinaryCmd = "core/capabilities/compute/test/fetch/cmd" + validRequestUUID = "d2fe6db9-beb4-47c9-b2d6-d3065ace111e" +) - require.NoError(t, compute.Start(tests.Context(t))) +var defaultConfig = webapi.ServiceConfig{ + RateLimiter: common.RateLimiterConfig{ + GlobalRPS: 100.0, + GlobalBurst: 100, + PerSenderRPS: 100.0, + PerSenderBurst: 100, + }, +} - cp, err := registry.Get(tests.Context(t), CapabilityIDCompute) - require.NoError(t, err) - assert.Equal(t, compute, cp) +type testHarness struct { + registry *corecapabilities.Registry + connector *gcmocks.GatewayConnector + log logger.Logger + config webapi.ServiceConfig + connectorHandler *webapi.OutgoingConnectorHandler + compute *Compute } -func Test_Compute_Execute_MissingConfig(t *testing.T) { +func setup(t *testing.T, config webapi.ServiceConfig) testHarness { log := logger.TestLogger(t) registry := capabilities.NewRegistry(log) + connector := gcmocks.NewGatewayConnector(t) + idGeneratorFn := func() string { return validRequestUUID } + connectorHandler, err := webapi.NewOutgoingConnectorHandler(connector, config, ghcapabilities.MethodComputeAction, log) + require.NoError(t, err) - compute := NewAction(log, registry) + compute := NewAction(config, log, registry, connectorHandler, idGeneratorFn) compute.modules.clock = clockwork.NewFakeClock() - require.NoError(t, compute.Start(tests.Context(t))) + return testHarness{ + registry: registry, + connector: connector, + log: log, + config: config, + connectorHandler: connectorHandler, + compute: compute, + } +} + +func TestComputeStartAddsToRegistry(t *testing.T) { + th := setup(t, defaultConfig) + + require.NoError(t, th.compute.Start(tests.Context(t))) + + cp, err := th.registry.Get(tests.Context(t), CapabilityIDCompute) + require.NoError(t, err) + assert.Equal(t, th.compute, cp) +} + +func TestComputeExecuteMissingConfig(t *testing.T) { + th := setup(t, defaultConfig) + require.NoError(t, th.compute.Start(tests.Context(t))) binary := wasmtest.CreateTestBinary(binaryCmd, binaryLocation, true, t) @@ -52,18 +98,14 @@ func Test_Compute_Execute_MissingConfig(t *testing.T) { ReferenceID: "compute", }, } - _, err = compute.Execute(tests.Context(t), req) + _, err = th.compute.Execute(tests.Context(t), req) assert.ErrorContains(t, err, "invalid request: could not find \"config\" in map") } -func Test_Compute_Execute_MissingBinary(t *testing.T) { - log := logger.TestLogger(t) - registry := capabilities.NewRegistry(log) +func TestComputeExecuteMissingBinary(t *testing.T) { + th := setup(t, defaultConfig) - compute := NewAction(log, registry) - compute.modules.clock = clockwork.NewFakeClock() - - require.NoError(t, compute.Start(tests.Context(t))) + require.NoError(t, th.compute.Start(tests.Context(t))) config, err := values.WrapMap(map[string]any{ "config": []byte(""), @@ -76,18 +118,14 @@ func Test_Compute_Execute_MissingBinary(t *testing.T) { ReferenceID: "compute", }, } - _, err = compute.Execute(tests.Context(t), req) + _, err = th.compute.Execute(tests.Context(t), req) assert.ErrorContains(t, err, "invalid request: could not find \"binary\" in map") } -func Test_Compute_Execute(t *testing.T) { - log := logger.TestLogger(t) - registry := capabilities.NewRegistry(log) - - compute := NewAction(log, registry) - compute.modules.clock = clockwork.NewFakeClock() +func TestComputeExecute(t *testing.T) { + th := setup(t, defaultConfig) - require.NoError(t, compute.Start(tests.Context(t))) + require.NoError(t, th.compute.Start(tests.Context(t))) binary := wasmtest.CreateTestBinary(binaryCmd, binaryLocation, true, t) @@ -110,7 +148,7 @@ func Test_Compute_Execute(t *testing.T) { ReferenceID: "compute", }, } - resp, err := compute.Execute(tests.Context(t), req) + resp, err := th.compute.Execute(tests.Context(t), req) assert.NoError(t, err) assert.True(t, resp.Value.Underlying["Value"].(*values.Bool).Underlying) @@ -132,7 +170,88 @@ func Test_Compute_Execute(t *testing.T) { ReferenceID: "compute", }, } - resp, err = compute.Execute(tests.Context(t), req) + resp, err = th.compute.Execute(tests.Context(t), req) assert.NoError(t, err) assert.False(t, resp.Value.Underlying["Value"].(*values.Bool).Underlying) } + +func TestComputeFetch(t *testing.T) { + workflowID := "15c631d295ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce0" + workflowExecutionID := "95ef5e32deb99a10ee6804bc4af13855687559d7ff6552ac6dbb2ce0abbadeed" + th := setup(t, defaultConfig) + + th.connector.EXPECT().DonID().Return("don-id") + th.connector.EXPECT().GatewayIDs().Return([]string{"gateway1", "gateway2"}) + + msgID := strings.Join([]string{ + workflowID, + workflowExecutionID, + ghcapabilities.MethodComputeAction, + validRequestUUID, + }, "/") + + gatewayResp := gatewayResponse(t, msgID) + th.connector.On("SignAndSendToGateway", mock.Anything, "gateway1", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + th.connectorHandler.HandleGatewayMessage(context.Background(), "gateway1", gatewayResp) + }).Once() + + require.NoError(t, th.compute.Start(tests.Context(t))) + + binary := wasmtest.CreateTestBinary(fetchBinaryCmd, fetchBinaryLocation, true, t) + + config, err := values.WrapMap(map[string]any{ + "config": []byte(""), + "binary": binary, + }) + require.NoError(t, err) + + req := cappkg.CapabilityRequest{ + Config: config, + Metadata: cappkg.RequestMetadata{ + WorkflowID: workflowID, + WorkflowExecutionID: workflowExecutionID, + ReferenceID: "compute", + }, + } + + headers, err := values.NewMap(map[string]any{}) + require.NoError(t, err) + expected := cappkg.CapabilityResponse{ + Value: &values.Map{ + Underlying: map[string]values.Value{ + "Value": &values.Map{ + Underlying: map[string]values.Value{ + "Body": values.NewBytes([]byte("response body")), + "Headers": headers, + "StatusCode": values.NewInt64(200), + "ErrorMessage": values.NewString(""), + "ExecutionError": values.NewBool(false), + }, + }, + }, + }, + } + + actual, err := th.compute.Execute(tests.Context(t), req) + require.NoError(t, err) + assert.EqualValues(t, expected, actual) +} + +func gatewayResponse(t *testing.T, msgID string) *api.Message { + headers := map[string]string{"Content-Type": "application/json"} + body := []byte("response body") + responsePayload, err := json.Marshal(ghcapabilities.Response{ + StatusCode: 200, + Headers: headers, + Body: body, + ExecutionError: false, + }) + require.NoError(t, err) + return &api.Message{ + Body: api.MessageBody{ + MessageId: msgID, + Method: ghcapabilities.MethodComputeAction, + Payload: responsePayload, + }, + } +} diff --git a/core/capabilities/compute/test/fetch/cmd/main.go b/core/capabilities/compute/test/fetch/cmd/main.go new file mode 100644 index 00000000000..bc45b426005 --- /dev/null +++ b/core/capabilities/compute/test/fetch/cmd/main.go @@ -0,0 +1,43 @@ +//go:build wasip1 + +package main + +import ( + "net/http" + + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/cli/cmd/testdata/fixtures/capabilities/basictrigger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk" +) + +func BuildWorkflow(config []byte) *sdk.WorkflowSpecFactory { + workflow := sdk.NewWorkflowSpecFactory( + sdk.NewWorkflowParams{ + Name: "tester", + Owner: "ryan", + }, + ) + + triggerCfg := basictrigger.TriggerConfig{Name: "trigger", Number: 100} + trigger := triggerCfg.New(workflow) + + sdk.Compute1[basictrigger.TriggerOutputs, sdk.FetchResponse]( + workflow, + "compute", + sdk.Compute1Inputs[basictrigger.TriggerOutputs]{Arg0: trigger}, + func(rsdk sdk.Runtime, outputs basictrigger.TriggerOutputs) (sdk.FetchResponse, error) { + return rsdk.Fetch(sdk.FetchRequest{ + Method: http.MethodGet, + URL: "https://min-api.cryptocompare.com/data/pricemultifull?fsyms=ETH&tsyms=BTC", + }) + }) + + return workflow +} + +func main() { + runner := wasm.NewRunner() + workflow := BuildWorkflow(runner.Config()) + runner.Run(workflow) +} diff --git a/core/capabilities/webapi/target/connector_handler.go b/core/capabilities/webapi/outgoing_connector_handler.go similarity index 69% rename from core/capabilities/webapi/target/connector_handler.go rename to core/capabilities/webapi/outgoing_connector_handler.go index 11fdd115705..b00b82b2bd0 100644 --- a/core/capabilities/webapi/target/connector_handler.go +++ b/core/capabilities/webapi/outgoing_connector_handler.go @@ -1,8 +1,9 @@ -package target +package webapi import ( "context" "encoding/json" + "fmt" "sort" "sync" @@ -11,28 +12,35 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" - "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/webapicapabilities" ) -var _ connector.GatewayConnectorHandler = &ConnectorHandler{} +var _ connector.GatewayConnectorHandler = &OutgoingConnectorHandler{} -type ConnectorHandler struct { +type OutgoingConnectorHandler struct { gc connector.GatewayConnector + method string lggr logger.Logger responseChs map[string]chan *api.Message responseChsMu sync.Mutex rateLimiter *common.RateLimiter } -func NewConnectorHandler(gc connector.GatewayConnector, config ServiceConfig, lgger logger.Logger) (*ConnectorHandler, error) { +func NewOutgoingConnectorHandler(gc connector.GatewayConnector, config ServiceConfig, method string, lgger logger.Logger) (*OutgoingConnectorHandler, error) { rateLimiter, err := common.NewRateLimiter(config.RateLimiter) if err != nil { return nil, err } + + if !validMethod(method) { + return nil, fmt.Errorf("invalid outgoing connector handler method: %s", method) + } + responseChs := make(map[string]chan *api.Message) - return &ConnectorHandler{ + return &OutgoingConnectorHandler{ gc: gc, + method: method, responseChs: responseChs, responseChsMu: sync.Mutex{}, rateLimiter: rateLimiter, @@ -42,7 +50,7 @@ func NewConnectorHandler(gc connector.GatewayConnector, config ServiceConfig, lg // HandleSingleNodeRequest sends a request to first available gateway node and blocks until response is received // TODO: handle retries and timeouts -func (c *ConnectorHandler) HandleSingleNodeRequest(ctx context.Context, messageID string, payload []byte) (*api.Message, error) { +func (c *OutgoingConnectorHandler) HandleSingleNodeRequest(ctx context.Context, messageID string, payload []byte) (*api.Message, error) { ch := make(chan *api.Message, 1) c.responseChsMu.Lock() c.responseChs[messageID] = ch @@ -53,7 +61,7 @@ func (c *ConnectorHandler) HandleSingleNodeRequest(ctx context.Context, messageI body := &api.MessageBody{ MessageId: messageID, DonId: c.gc.DonID(), - Method: webapicapabilities.MethodWebAPITarget, + Method: c.method, Payload: payload, } @@ -79,7 +87,7 @@ func (c *ConnectorHandler) HandleSingleNodeRequest(ctx context.Context, messageI } } -func (c *ConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayID string, msg *api.Message) { +func (c *OutgoingConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayID string, msg *api.Message) { body := &msg.Body l := logger.With(c.lggr, "gatewayID", gatewayID, "method", body.Method, "messageID", msg.Body.MessageId) if !c.rateLimiter.Allow(body.Sender) { @@ -90,8 +98,9 @@ func (c *ConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayID s } l.Debugw("handling gateway request") switch body.Method { - case webapicapabilities.MethodWebAPITarget: - var payload webapicapabilities.TargetResponsePayload + case capabilities.MethodWebAPITarget, capabilities.MethodComputeAction: + body := &msg.Body + var payload capabilities.Response err := json.Unmarshal(body.Payload, &payload) if err != nil { l.Errorw("failed to unmarshal payload", "err", err) @@ -115,10 +124,19 @@ func (c *ConnectorHandler) HandleGatewayMessage(ctx context.Context, gatewayID s } } -func (c *ConnectorHandler) Start(ctx context.Context) error { - return c.gc.AddHandler([]string{webapicapabilities.MethodWebAPITarget}, c) +func (c *OutgoingConnectorHandler) Start(ctx context.Context) error { + return c.gc.AddHandler([]string{c.method}, c) } -func (c *ConnectorHandler) Close() error { +func (c *OutgoingConnectorHandler) Close() error { return nil } + +func validMethod(method string) bool { + switch method { + case capabilities.MethodWebAPITarget, capabilities.MethodComputeAction: + return true + default: + return false + } +} diff --git a/core/capabilities/webapi/target/target.go b/core/capabilities/webapi/target/target.go index d4a66ccf391..4576f95a54e 100644 --- a/core/capabilities/webapi/target/target.go +++ b/core/capabilities/webapi/target/target.go @@ -11,8 +11,9 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types/core" "github.com/smartcontractkit/chainlink-common/pkg/values" "github.com/smartcontractkit/chainlink/v2/core/capabilities/validation" + "github.com/smartcontractkit/chainlink/v2/core/capabilities/webapi" "github.com/smartcontractkit/chainlink/v2/core/capabilities/webapi/webapicap" - "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/webapicapabilities" + ghcapabilities "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" ) const ID = "web-api-target@1.0.0" @@ -26,7 +27,7 @@ var capabilityInfo = capabilities.MustNewCapabilityInfo( ) const ( - DefaultDeliveryMode = SingleNode + DefaultDeliveryMode = webapi.SingleNode DefaultHTTPMethod = "GET" DefaultTimeoutMs = 30000 MaxTimeoutMs = 600000 @@ -35,13 +36,13 @@ const ( // Capability is a target capability that sends HTTP requests to external clients via the Chainlink Gateway. type Capability struct { capabilityInfo capabilities.CapabilityInfo - connectorHandler *ConnectorHandler + connectorHandler *webapi.OutgoingConnectorHandler lggr logger.Logger registry core.CapabilitiesRegistry - config ServiceConfig + config webapi.ServiceConfig } -func NewCapability(config ServiceConfig, registry core.CapabilitiesRegistry, connectorHandler *ConnectorHandler, lggr logger.Logger) (*Capability, error) { +func NewCapability(config webapi.ServiceConfig, registry core.CapabilitiesRegistry, connectorHandler *webapi.OutgoingConnectorHandler, lggr logger.Logger) (*Capability, error) { return &Capability{ capabilityInfo: capabilityInfo, config: config, @@ -72,7 +73,7 @@ func getMessageID(req capabilities.CapabilityRequest) (string, error) { } messageID := []string{ req.Metadata.WorkflowExecutionID, - webapicapabilities.MethodWebAPITarget, + ghcapabilities.MethodWebAPITarget, } return strings.Join(messageID, "/"), nil } @@ -85,15 +86,15 @@ func defaultIfNil[T any](value *T, defaultValue T) T { return defaultValue } -func getPayload(input webapicap.TargetPayload, cfg webapicap.TargetConfig) (webapicapabilities.TargetRequestPayload, error) { +func getPayload(input webapicap.TargetPayload, cfg webapicap.TargetConfig) (ghcapabilities.Request, error) { method := defaultIfNil(input.Method, DefaultHTTPMethod) body := defaultIfNil(input.Body, "") timeoutMs := defaultIfNil(cfg.TimeoutMs, DefaultTimeoutMs) if timeoutMs > MaxTimeoutMs { - return webapicapabilities.TargetRequestPayload{}, fmt.Errorf("timeoutMs must be between 0 and %d", MaxTimeoutMs) + return ghcapabilities.Request{}, fmt.Errorf("timeoutMs must be between 0 and %d", MaxTimeoutMs) } - return webapicapabilities.TargetRequestPayload{ + return ghcapabilities.Request{ URL: input.Url, Method: method, Headers: input.Headers, @@ -133,17 +134,17 @@ func (c *Capability) Execute(ctx context.Context, req capabilities.CapabilityReq } // Default to SingleNode delivery mode - deliveryMode := defaultIfNil(workflowCfg.DeliveryMode, SingleNode) + deliveryMode := defaultIfNil(workflowCfg.DeliveryMode, webapi.SingleNode) switch deliveryMode { - case SingleNode: + case webapi.SingleNode: // blocking call to handle single node request. waits for response from gateway resp, err := c.connectorHandler.HandleSingleNodeRequest(ctx, messageID, payloadBytes) if err != nil { return capabilities.CapabilityResponse{}, err } c.lggr.Debugw("received gateway response", "resp", resp) - var payload webapicapabilities.TargetResponsePayload + var payload ghcapabilities.Response err = json.Unmarshal(resp.Body.Payload, &payload) if err != nil { return capabilities.CapabilityResponse{}, err diff --git a/core/capabilities/webapi/target/target_test.go b/core/capabilities/webapi/target/target_test.go index 524c99cb5f0..f51cdcd0d70 100644 --- a/core/capabilities/webapi/target/target_test.go +++ b/core/capabilities/webapi/target/target_test.go @@ -14,11 +14,12 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" registrymock "github.com/smartcontractkit/chainlink-common/pkg/types/core/mocks" "github.com/smartcontractkit/chainlink-common/pkg/values" + "github.com/smartcontractkit/chainlink/v2/core/capabilities/webapi" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks" + ghcapabilities "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" - "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/webapicapabilities" ) const ( @@ -28,7 +29,7 @@ const ( owner1 = "0x00000000000000000000000000000000000000aa" ) -var defaultConfig = ServiceConfig{ +var defaultConfig = webapi.ServiceConfig{ RateLimiter: common.RateLimiterConfig{ GlobalRPS: 100.0, GlobalBurst: 100, @@ -41,16 +42,16 @@ type testHarness struct { registry *registrymock.CapabilitiesRegistry connector *gcmocks.GatewayConnector lggr logger.Logger - config ServiceConfig - connectorHandler *ConnectorHandler + config webapi.ServiceConfig + connectorHandler *webapi.OutgoingConnectorHandler capability *Capability } -func setup(t *testing.T, config ServiceConfig) testHarness { +func setup(t *testing.T, config webapi.ServiceConfig) testHarness { registry := registrymock.NewCapabilitiesRegistry(t) connector := gcmocks.NewGatewayConnector(t) lggr := logger.Test(t) - connectorHandler, err := NewConnectorHandler(connector, config, lggr) + connectorHandler, err := webapi.NewOutgoingConnectorHandler(connector, config, ghcapabilities.MethodWebAPITarget, lggr) require.NoError(t, err) capability, err := NewCapability(config, registry, connectorHandler, lggr) @@ -89,7 +90,7 @@ func inputsAndConfig(t *testing.T) (*values.Map, *values.Map) { require.NoError(t, err) wfConfig, err := values.NewMap(map[string]interface{}{ "timeoutMs": 1000, - "schedule": SingleNode, + "schedule": webapi.SingleNode, }) require.NoError(t, err) return inputs, wfConfig @@ -111,7 +112,7 @@ func capabilityRequest(t *testing.T) capabilities.CapabilityRequest { func gatewayResponse(t *testing.T, msgID string) *api.Message { headers := map[string]string{"Content-Type": "application/json"} body := []byte("response body") - responsePayload, err := json.Marshal(webapicapabilities.TargetResponsePayload{ + responsePayload, err := json.Marshal(ghcapabilities.Response{ StatusCode: 200, Headers: headers, Body: body, @@ -121,7 +122,7 @@ func gatewayResponse(t *testing.T, msgID string) *api.Message { return &api.Message{ Body: api.MessageBody{ MessageId: msgID, - Method: webapicapabilities.MethodWebAPITarget, + Method: ghcapabilities.MethodWebAPITarget, Payload: responsePayload, }, } diff --git a/core/capabilities/webapi/trigger.go b/core/capabilities/webapi/trigger/trigger.go similarity index 90% rename from core/capabilities/webapi/trigger.go rename to core/capabilities/webapi/trigger/trigger.go index 8eb971f9a83..a08d2f577ff 100644 --- a/core/capabilities/webapi/trigger.go +++ b/core/capabilities/webapi/trigger/trigger.go @@ -18,8 +18,8 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" + ghcapabilities "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" - "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/webapicapabilities" ) const defaultSendChannelBufferSize = 1000 @@ -139,7 +139,7 @@ func (h *triggerConnectorHandler) HandleGatewayMessage(ctx context.Context, gate err := json.Unmarshal(body.Payload, &payload) if err != nil { h.lggr.Errorw("error decoding payload", "err", err) - err = h.sendResponse(ctx, gatewayID, body, webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: fmt.Errorf("error %s decoding payload", err.Error()).Error()}) + err = h.sendResponse(ctx, gatewayID, body, ghcapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: fmt.Errorf("error %s decoding payload", err.Error()).Error()}) if err != nil { h.lggr.Errorw("error sending response", "err", err) } @@ -147,13 +147,13 @@ func (h *triggerConnectorHandler) HandleGatewayMessage(ctx context.Context, gate } switch body.Method { - case webapicapabilities.MethodWebAPITrigger: + case ghcapabilities.MethodWebAPITrigger: resp := h.processTrigger(ctx, gatewayID, body, sender, payload) - var response webapicapabilities.TriggerResponsePayload + var response ghcapabilities.TriggerResponsePayload if resp == nil { - response = webapicapabilities.TriggerResponsePayload{Status: "ACCEPTED"} + response = ghcapabilities.TriggerResponsePayload{Status: "ACCEPTED"} } else { - response = webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: resp.Error()} + response = ghcapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: resp.Error()} h.lggr.Errorw("Error processing trigger", "gatewayID", gatewayID, "body", body, "response", resp) } err = h.sendResponse(ctx, gatewayID, body, response) @@ -164,7 +164,7 @@ func (h *triggerConnectorHandler) HandleGatewayMessage(ctx context.Context, gate default: h.lggr.Errorw("unsupported method", "id", gatewayID, "method", body.Method) - err = h.sendResponse(ctx, gatewayID, body, webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: fmt.Errorf("unsupported method %s", body.Method).Error()}) + err = h.sendResponse(ctx, gatewayID, body, ghcapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: fmt.Errorf("unsupported method %s", body.Method).Error()}) if err != nil { h.lggr.Errorw("error sending response", "err", err) } @@ -272,7 +272,7 @@ func (h *triggerConnectorHandler) sendResponse(ctx context.Context, gatewayID st payloadJSON, err := json.Marshal(payload) if err != nil { h.lggr.Errorw("error marshalling payload", "err", err) - payloadJSON, _ = json.Marshal(webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: fmt.Errorf("error %s marshalling payload", err.Error()).Error()}) + payloadJSON, _ = json.Marshal(ghcapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: fmt.Errorf("error %s marshalling payload", err.Error()).Error()}) } body := &api.MessageBody{ diff --git a/core/capabilities/webapi/trigger_test.go b/core/capabilities/webapi/trigger/trigger_test.go similarity index 91% rename from core/capabilities/webapi/trigger_test.go rename to core/capabilities/webapi/trigger/trigger_test.go index f5e12a48758..0c73e31fe62 100644 --- a/core/capabilities/webapi/trigger_test.go +++ b/core/capabilities/webapi/trigger/trigger_test.go @@ -22,7 +22,7 @@ import ( corelogger "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks" - "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/webapicapabilities" + ghcapabilities "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" ) const ( @@ -86,7 +86,7 @@ func setup(t *testing.T) testHarness { func gatewayRequest(t *testing.T, privateKey string, topics string, methodName string) *api.Message { messageID := "12345" if methodName == "" { - methodName = webapicapabilities.MethodWebAPITrigger + methodName = ghcapabilities.MethodWebAPITrigger } donID := "workflow_don_1" @@ -118,8 +118,8 @@ func gatewayRequest(t *testing.T, privateKey string, topics string, methodName s return msg } -func getResponseFromArg(arg interface{}) (webapicapabilities.TriggerResponsePayload, error) { - var response webapicapabilities.TriggerResponsePayload +func getResponseFromArg(arg interface{}) (ghcapabilities.TriggerResponsePayload, error) { + var response ghcapabilities.TriggerResponsePayload msgBody := arg.(*api.MessageBody) err := json.Unmarshal(msgBody.Payload, &response) return response, err @@ -182,7 +182,7 @@ func TestTriggerExecute(t *testing.T) { th.connector.On("SignAndSendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { resp, _ := getResponseFromArg(args.Get(2)) - require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ACCEPTED"}, resp) + require.Equal(t, ghcapabilities.TriggerResponsePayload{Status: "ACCEPTED"}, resp) }).Return(nil).Once() th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) @@ -204,7 +204,7 @@ func TestTriggerExecute(t *testing.T) { th.connector.On("SignAndSendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { resp, _ := getResponseFromArg(args.Get(2)) - require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ACCEPTED"}, resp) + require.Equal(t, ghcapabilities.TriggerResponsePayload{Status: "ACCEPTED"}, resp) }).Return(nil).Once() th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) @@ -231,7 +231,7 @@ func TestTriggerExecute(t *testing.T) { th.connector.On("SignAndSendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { resp, _ := getResponseFromArg(args.Get(2)) - require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: "empty Workflow Topics"}, resp) + require.Equal(t, ghcapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: "empty Workflow Topics"}, resp) }).Return(nil).Once() th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) @@ -244,7 +244,7 @@ func TestTriggerExecute(t *testing.T) { gatewayRequest := gatewayRequest(t, privateKey1, `["foo"]`, "") th.connector.On("SignAndSendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { resp, _ := getResponseFromArg(args.Get(2)) - require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: "no Matching Workflow Topics"}, resp) + require.Equal(t, ghcapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: "no Matching Workflow Topics"}, resp) }).Return(nil).Once() th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) @@ -257,7 +257,7 @@ func TestTriggerExecute(t *testing.T) { th.connector.On("SignAndSendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { resp, _ := getResponseFromArg(args.Get(2)) - require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: "unauthorized Sender 0x2dAC9f74Ee66e2D55ea1B8BE284caFedE048dB3A, messageID 12345"}, resp) + require.Equal(t, ghcapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: "unauthorized Sender 0x2dAC9f74Ee66e2D55ea1B8BE284caFedE048dB3A, messageID 12345"}, resp) }).Return(nil).Once() th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) @@ -269,7 +269,7 @@ func TestTriggerExecute(t *testing.T) { gatewayRequest := gatewayRequest(t, privateKey2, `["ad_hoc_price_update"]`, "boo") th.connector.On("SignAndSendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { resp, _ := getResponseFromArg(args.Get(2)) - require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: "unsupported method boo"}, resp) + require.Equal(t, ghcapabilities.TriggerResponsePayload{Status: "ERROR", ErrorMessage: "unsupported method boo"}, resp) }).Return(nil).Once() th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) @@ -338,7 +338,7 @@ func TestTriggerExecute2WorkflowsSameTopicDifferentAllowLists(t *testing.T) { th.connector.On("SignAndSendToGateway", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { resp, _ := getResponseFromArg(args.Get(2)) - require.Equal(t, webapicapabilities.TriggerResponsePayload{Status: "ACCEPTED"}, resp) + require.Equal(t, ghcapabilities.TriggerResponsePayload{Status: "ACCEPTED"}, resp) }).Return(nil).Once() th.trigger.HandleGatewayMessage(ctx, "gateway1", gatewayRequest) diff --git a/core/capabilities/webapi/target/types.go b/core/capabilities/webapi/types.go similarity index 97% rename from core/capabilities/webapi/target/types.go rename to core/capabilities/webapi/types.go index 799eb646872..62d6143bea5 100644 --- a/core/capabilities/webapi/target/types.go +++ b/core/capabilities/webapi/types.go @@ -1,4 +1,4 @@ -package target +package webapi import "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" diff --git a/core/services/gateway/handler_factory.go b/core/services/gateway/handler_factory.go index 0c1eeaf676e..4d58b95fbca 100644 --- a/core/services/gateway/handler_factory.go +++ b/core/services/gateway/handler_factory.go @@ -9,8 +9,8 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions" - "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/webapicapabilities" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/network" ) @@ -45,7 +45,7 @@ func (hf *handlerFactory) NewHandler(handlerType HandlerType, handlerConfig json case DummyHandlerType: return handlers.NewDummyHandler(donConfig, don, hf.lggr) case WebAPICapabilitiesType: - return webapicapabilities.NewHandler(handlerConfig, donConfig, don, hf.httpClient, hf.lggr) + return capabilities.NewHandler(handlerConfig, donConfig, don, hf.httpClient, hf.lggr) default: return nil, fmt.Errorf("unsupported handler type %s", handlerType) } diff --git a/core/services/gateway/handlers/webapicapabilities/handler.go b/core/services/gateway/handlers/capabilities/handler.go similarity index 86% rename from core/services/gateway/handlers/webapicapabilities/handler.go rename to core/services/gateway/handlers/capabilities/handler.go index aa6823e4775..904a64c8896 100644 --- a/core/services/gateway/handlers/webapicapabilities/handler.go +++ b/core/services/gateway/handlers/capabilities/handler.go @@ -1,4 +1,4 @@ -package webapicapabilities +package capabilities import ( "context" @@ -22,6 +22,7 @@ const ( // NOTE: more methods will go here. HTTP trigger/action/target; etc. MethodWebAPITarget = "web_api_target" MethodWebAPITrigger = "web_api_trigger" + MethodComputeAction = "compute_action" ) type handler struct { @@ -74,12 +75,12 @@ func NewHandler(handlerConfig json.RawMessage, donConfig *config.DONConfig, don // sendHTTPMessageToClient is an outgoing message from the gateway to external endpoints // returns message to be sent back to the capability node func (h *handler) sendHTTPMessageToClient(ctx context.Context, req network.HTTPRequest, msg *api.Message) (*api.Message, error) { - var payload TargetResponsePayload + var payload Response resp, err := h.httpClient.Send(ctx, req) if err != nil { return nil, err } - payload = TargetResponsePayload{ + payload = Response{ ExecutionError: false, StatusCode: resp.StatusCode, Headers: resp.Headers, @@ -100,28 +101,43 @@ func (h *handler) sendHTTPMessageToClient(ctx context.Context, req network.HTTPR }, nil } -func (h *handler) handleWebAPITargetMessage(ctx context.Context, msg *api.Message, nodeAddr string) error { - h.lggr.Debugw("handling web api target message", "messageId", msg.Body.MessageId, "nodeAddr", nodeAddr) +func (h *handler) handleWebAPITriggerMessage(ctx context.Context, msg *api.Message, nodeAddr string) error { + h.mu.Lock() + savedCb, found := h.savedCallbacks[msg.Body.MessageId] + delete(h.savedCallbacks, msg.Body.MessageId) + h.mu.Unlock() + + if found { + // Send first response from a node back to the user, ignore any other ones. + // TODO: in practice, we should wait for at least 2F+1 nodes to respond and then return an aggregated response + // back to the user. + savedCb.callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.NoError, ErrMsg: ""} + close(savedCb.callbackCh) + } + return nil +} + +func (h *handler) handleWebAPIOutgoingMessage(ctx context.Context, msg *api.Message, nodeAddr string) error { + h.lggr.Debugw("handling webAPI outgoing message", "messageId", msg.Body.MessageId, "nodeAddr", nodeAddr) if !h.nodeRateLimiter.Allow(nodeAddr) { return fmt.Errorf("rate limit exceeded for node %s", nodeAddr) } - var targetPayload TargetRequestPayload - err := json.Unmarshal(msg.Body.Payload, &targetPayload) + var payload Request + err := json.Unmarshal(msg.Body.Payload, &payload) if err != nil { return err } - // send message to target - timeout := time.Duration(targetPayload.TimeoutMs) * time.Millisecond + + timeout := time.Duration(payload.TimeoutMs) * time.Millisecond req := network.HTTPRequest{ - Method: targetPayload.Method, - URL: targetPayload.URL, - Headers: targetPayload.Headers, - Body: targetPayload.Body, + Method: payload.Method, + URL: payload.URL, + Headers: payload.Headers, + Body: payload.Body, Timeout: timeout, } - // this handle method must be non-blocking - // send response to node (target capability) async - // if there is a non-HTTP error (e.g. malformed request), send payload with success set to false and error messages + + // send response to node async h.wg.Add(1) go func() { defer h.wg.Done() @@ -129,11 +145,11 @@ func (h *handler) handleWebAPITargetMessage(ctx context.Context, msg *api.Messag newCtx := context.WithoutCancel(ctx) newCtx, cancel := context.WithTimeout(newCtx, timeout) defer cancel() - l := h.lggr.With("url", targetPayload.URL, "messageId", msg.Body.MessageId, "method", targetPayload.Method) + l := h.lggr.With("url", payload.URL, "messageId", msg.Body.MessageId, "method", payload.Method) respMsg, err := h.sendHTTPMessageToClient(newCtx, req, msg) if err != nil { l.Errorw("error while sending HTTP request to external endpoint", "err", err) - payload := TargetResponsePayload{ + payload := Response{ ExecutionError: true, ErrorMessage: err.Error(), } @@ -166,28 +182,12 @@ func (h *handler) handleWebAPITargetMessage(ctx context.Context, msg *api.Messag return nil } -func (h *handler) handleWebAPITriggerMessage(ctx context.Context, msg *api.Message, nodeAddr string) error { - h.mu.Lock() - savedCb, found := h.savedCallbacks[msg.Body.MessageId] - delete(h.savedCallbacks, msg.Body.MessageId) - h.mu.Unlock() - - if found { - // Send first response from a node back to the user, ignore any other ones. - // TODO: in practice, we should wait for at least 2F+1 nodes to respond and then return an aggregated response - // back to the user. - savedCb.callbackCh <- handlers.UserCallbackPayload{Msg: msg, ErrCode: api.NoError, ErrMsg: ""} - close(savedCb.callbackCh) - } - return nil -} - func (h *handler) HandleNodeMessage(ctx context.Context, msg *api.Message, nodeAddr string) error { switch msg.Body.Method { case MethodWebAPITrigger: return h.handleWebAPITriggerMessage(ctx, msg, nodeAddr) - case MethodWebAPITarget: - return h.handleWebAPITargetMessage(ctx, msg, nodeAddr) + case MethodWebAPITarget, MethodComputeAction: + return h.handleWebAPIOutgoingMessage(ctx, msg, nodeAddr) default: return fmt.Errorf("unsupported method: %s", msg.Body.Method) } diff --git a/core/services/gateway/handlers/webapicapabilities/handler_test.go b/core/services/gateway/handlers/capabilities/handler_test.go similarity index 73% rename from core/services/gateway/handlers/webapicapabilities/handler_test.go rename to core/services/gateway/handlers/capabilities/handler_test.go index 8d0308c0e71..eb5d883ac14 100644 --- a/core/services/gateway/handlers/webapicapabilities/handler_test.go +++ b/core/services/gateway/handlers/capabilities/handler_test.go @@ -1,7 +1,8 @@ -package webapicapabilities +package capabilities import ( "encoding/json" + "errors" "fmt" "strconv" "testing" @@ -75,7 +76,7 @@ func TestHandler_SendHTTPMessageToClient(t *testing.T) { handler, httpClient, don, nodes := setupHandler(t) ctx := testutils.Context(t) nodeAddr := nodes[0].Address - payload := TargetRequestPayload{ + payload := Request{ Method: "GET", URL: "http://example.com", Headers: map[string]string{}, @@ -101,7 +102,7 @@ func TestHandler_SendHTTPMessageToClient(t *testing.T) { }, nil).Once() don.EXPECT().SendToNode(mock.Anything, nodes[0].Address, mock.MatchedBy(func(m *api.Message) bool { - var payload TargetResponsePayload + var payload Response err2 := json.Unmarshal(m.Body.Payload, &payload) if err2 != nil { return false @@ -134,7 +135,7 @@ func TestHandler_SendHTTPMessageToClient(t *testing.T) { }, nil).Once() don.EXPECT().SendToNode(mock.Anything, nodes[0].Address, mock.MatchedBy(func(m *api.Message) bool { - var payload TargetResponsePayload + var payload Response err2 := json.Unmarshal(m.Body.Payload, &payload) if err2 != nil { return false @@ -163,7 +164,7 @@ func TestHandler_SendHTTPMessageToClient(t *testing.T) { httpClient.EXPECT().Send(mock.Anything, mock.Anything).Return(nil, fmt.Errorf("error while marshalling")).Once() don.EXPECT().SendToNode(mock.Anything, nodes[0].Address, mock.MatchedBy(func(m *api.Message) bool { - var payload TargetResponsePayload + var payload Response err2 := json.Unmarshal(m.Body.Payload, &payload) if err2 != nil { return false @@ -310,3 +311,119 @@ func TestHandlerReceiveHTTPMessageFromClient(t *testing.T) { }) // TODO: Validate Senders and rate limit chck, pending question in trigger about where senders and rate limits are validated } + +func TestHandleComputeActionMessage(t *testing.T) { + handler, httpClient, don, nodes := setupHandler(t) + ctx := testutils.Context(t) + nodeAddr := nodes[0].Address + payload := Request{ + Method: "GET", + URL: "http://example.com", + Headers: map[string]string{}, + Body: nil, + TimeoutMs: 2000, + } + payloadBytes, err := json.Marshal(payload) + require.NoError(t, err) + msg := &api.Message{ + Body: api.MessageBody{ + MessageId: "123", + Method: MethodComputeAction, + DonId: "testDonId", + Payload: json.RawMessage(payloadBytes), + }, + } + + t.Run("OK-compute_with_fetch", func(t *testing.T) { + httpClient.EXPECT().Send(mock.Anything, mock.Anything).Return(&network.HTTPResponse{ + StatusCode: 200, + Headers: map[string]string{}, + Body: []byte("response body"), + }, nil).Once() + + don.EXPECT().SendToNode(mock.Anything, nodes[0].Address, mock.MatchedBy(func(m *api.Message) bool { + var payload Response + err2 := json.Unmarshal(m.Body.Payload, &payload) + if err2 != nil { + return false + } + return "123" == m.Body.MessageId && + MethodComputeAction == m.Body.Method && + "testDonId" == m.Body.DonId && + 200 == payload.StatusCode && + 0 == len(payload.Headers) && + string(payload.Body) == "response body" && + !payload.ExecutionError + })).Return(nil).Once() + + err = handler.HandleNodeMessage(ctx, msg, nodeAddr) + require.NoError(t, err) + + require.Eventually(t, func() bool { + // ensure all goroutines close + err2 := handler.Close() + require.NoError(t, err2) + return httpClient.AssertExpectations(t) && don.AssertExpectations(t) + }, tests.WaitTimeout(t), 100*time.Millisecond) + }) + + t.Run("NOK-payload_error_making_external_request", func(t *testing.T) { + httpClient.EXPECT().Send(mock.Anything, mock.Anything).Return(&network.HTTPResponse{ + StatusCode: 404, + Headers: map[string]string{}, + Body: []byte("access denied"), + }, nil).Once() + + don.EXPECT().SendToNode(mock.Anything, nodes[0].Address, mock.MatchedBy(func(m *api.Message) bool { + var payload Response + err2 := json.Unmarshal(m.Body.Payload, &payload) + if err2 != nil { + return false + } + return "123" == m.Body.MessageId && + MethodComputeAction == m.Body.Method && + "testDonId" == m.Body.DonId && + 404 == payload.StatusCode && + string(payload.Body) == "access denied" && + 0 == len(payload.Headers) && + !payload.ExecutionError + })).Return(nil).Once() + + err = handler.HandleNodeMessage(ctx, msg, nodeAddr) + require.NoError(t, err) + + require.Eventually(t, func() bool { + // // ensure all goroutines close + err2 := handler.Close() + require.NoError(t, err2) + return httpClient.AssertExpectations(t) && don.AssertExpectations(t) + }, tests.WaitTimeout(t), 100*time.Millisecond) + }) + + t.Run("NOK-error_outside_payload", func(t *testing.T) { + httpClient.EXPECT().Send(mock.Anything, mock.Anything).Return(nil, errors.New("error while marshalling")).Once() + + don.EXPECT().SendToNode(mock.Anything, nodes[0].Address, mock.MatchedBy(func(m *api.Message) bool { + var payload Response + err2 := json.Unmarshal(m.Body.Payload, &payload) + if err2 != nil { + return false + } + return "123" == m.Body.MessageId && + MethodComputeAction == m.Body.Method && + "testDonId" == m.Body.DonId && + payload.ExecutionError && + "error while marshalling" == payload.ErrorMessage + })).Return(nil).Once() + + err = handler.HandleNodeMessage(ctx, msg, nodeAddr) + require.NoError(t, err) + + require.Eventually(t, func() bool { + // // ensure all goroutines close + err2 := handler.Close() + require.NoError(t, err2) + return httpClient.AssertExpectations(t) && don.AssertExpectations(t) + }, tests.WaitTimeout(t), 100*time.Millisecond) + }) +} diff --git a/core/services/gateway/handlers/webapicapabilities/webapi.go b/core/services/gateway/handlers/capabilities/webapi.go similarity index 92% rename from core/services/gateway/handlers/webapicapabilities/webapi.go rename to core/services/gateway/handlers/capabilities/webapi.go index f6e2f6679a1..a0213eb8f42 100644 --- a/core/services/gateway/handlers/webapicapabilities/webapi.go +++ b/core/services/gateway/handlers/capabilities/webapi.go @@ -1,6 +1,6 @@ -package webapicapabilities +package capabilities -type TargetRequestPayload struct { +type Request struct { URL string `json:"url"` // URL to query, only http and https protocols are supported. Method string `json:"method,omitempty"` // HTTP verb, defaults to GET. Headers map[string]string `json:"headers,omitempty"` // HTTP headers, defaults to empty. @@ -8,7 +8,7 @@ type TargetRequestPayload struct { TimeoutMs uint32 `json:"timeoutMs,omitempty"` // Timeout in milliseconds } -type TargetResponsePayload struct { +type Response struct { ExecutionError bool `json:"executionError"` // true if there were non-HTTP errors. false if HTTP request was sent regardless of status (2xx, 4xx, 5xx) ErrorMessage string `json:"errorMessage,omitempty"` // error message in case of failure StatusCode int `json:"statusCode,omitempty"` // HTTP status code diff --git a/core/services/standardcapabilities/delegate.go b/core/services/standardcapabilities/delegate.go index 87d61617762..17e7cf5c12f 100644 --- a/core/services/standardcapabilities/delegate.go +++ b/core/services/standardcapabilities/delegate.go @@ -14,9 +14,11 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types/core" "github.com/smartcontractkit/chainlink/v2/core/capabilities/compute" gatewayconnector "github.com/smartcontractkit/chainlink/v2/core/capabilities/gateway_connector" - trigger "github.com/smartcontractkit/chainlink/v2/core/capabilities/webapi" + "github.com/smartcontractkit/chainlink/v2/core/capabilities/webapi" webapitarget "github.com/smartcontractkit/chainlink/v2/core/capabilities/webapi/target" + "github.com/smartcontractkit/chainlink/v2/core/capabilities/webapi/trigger" "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" "github.com/smartcontractkit/chainlink/v2/core/services/keystore/chaintype" @@ -210,13 +212,13 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) ([]job.Ser if len(spec.StandardCapabilitiesSpec.Config) == 0 { return nil, errors.New("config is empty") } - var targetCfg webapitarget.ServiceConfig + var targetCfg webapi.ServiceConfig err := toml.Unmarshal([]byte(spec.StandardCapabilitiesSpec.Config), &targetCfg) if err != nil { return nil, err } lggr := d.logger.Named("WebAPITarget") - handler, err := webapitarget.NewConnectorHandler(connector, targetCfg, lggr) + handler, err := webapi.NewOutgoingConnectorHandler(connector, targetCfg, capabilities.MethodWebAPITarget, lggr) if err != nil { return nil, err } @@ -228,7 +230,31 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) ([]job.Ser } if spec.StandardCapabilitiesSpec.Command == commandOverrideForCustomComputeAction { - computeSrvc := compute.NewAction(log, d.registry) + if d.gatewayConnectorWrapper == nil { + return nil, errors.New("gateway connector is required for custom compute capability") + } + + if len(spec.StandardCapabilitiesSpec.Config) == 0 { + return nil, errors.New("config is empty") + } + + var fetchCfg webapi.ServiceConfig + err := toml.Unmarshal([]byte(spec.StandardCapabilitiesSpec.Config), &fetchCfg) + if err != nil { + return nil, err + } + lggr := d.logger.Named("ComputeAction") + + handler, err := webapi.NewOutgoingConnectorHandler(d.gatewayConnectorWrapper.GetGatewayConnector(), fetchCfg, capabilities.MethodComputeAction, lggr) + if err != nil { + return nil, err + } + + idGeneratorFn := func() string { + return uuid.New().String() + } + + computeSrvc := compute.NewAction(fetchCfg, log, d.registry, handler, idGeneratorFn) return []job.ServiceCtx{computeSrvc}, nil } diff --git a/core/services/workflows/engine_test.go b/core/services/workflows/engine_test.go index 67ec5c49cac..72e1583a5ca 100644 --- a/core/services/workflows/engine_test.go +++ b/core/services/workflows/engine_test.go @@ -20,6 +20,9 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/workflows" "github.com/smartcontractkit/chainlink-common/pkg/workflows/sdk" "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host" + "github.com/smartcontractkit/chainlink/v2/core/capabilities/webapi" + gcmocks "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector/mocks" + ghcapabilities "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" coreCap "github.com/smartcontractkit/chainlink/v2/core/capabilities" "github.com/smartcontractkit/chainlink/v2/core/capabilities/compute" @@ -27,6 +30,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/wasmtest" "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" "github.com/smartcontractkit/chainlink/v2/core/services/job" p2ptypes "github.com/smartcontractkit/chainlink/v2/core/services/p2p/types" "github.com/smartcontractkit/chainlink/v2/core/services/registrysyncer" @@ -1416,8 +1420,24 @@ func TestEngine_WithCustomComputeStep(t *testing.T) { ctx := testutils.Context(t) log := logger.TestLogger(t) reg := coreCap.NewRegistry(logger.TestLogger(t)) + cfg := webapi.ServiceConfig{ + RateLimiter: common.RateLimiterConfig{ + GlobalRPS: 100.0, + GlobalBurst: 100, + PerSenderRPS: 100.0, + PerSenderBurst: 100, + }, + } + + connector := gcmocks.NewGatewayConnector(t) + handler, err := webapi.NewOutgoingConnectorHandler( + connector, + cfg, + ghcapabilities.MethodComputeAction, log) + require.NoError(t, err) - compute := compute.NewAction(log, reg) + idGeneratorFn := func() string { return "validRequestID" } + compute := compute.NewAction(cfg, log, reg, handler, idGeneratorFn) require.NoError(t, compute.Start(ctx)) defer compute.Close() @@ -1463,8 +1483,23 @@ func TestEngine_CustomComputePropagatesBreaks(t *testing.T) { ctx := testutils.Context(t) log := logger.TestLogger(t) reg := coreCap.NewRegistry(logger.TestLogger(t)) + cfg := webapi.ServiceConfig{ + RateLimiter: common.RateLimiterConfig{ + GlobalRPS: 100.0, + GlobalBurst: 100, + PerSenderRPS: 100.0, + PerSenderBurst: 100, + }, + } + connector := gcmocks.NewGatewayConnector(t) + handler, err := webapi.NewOutgoingConnectorHandler( + connector, + cfg, + ghcapabilities.MethodComputeAction, log) + require.NoError(t, err) - compute := compute.NewAction(log, reg) + idGeneratorFn := func() string { return "validRequestID" } + compute := compute.NewAction(cfg, log, reg, handler, idGeneratorFn) require.NoError(t, compute.Start(ctx)) defer compute.Close()