From 1a9f8ccbbe950dc97b5027e9223d51308aa63366 Mon Sep 17 00:00:00 2001 From: Cedric Date: Wed, 13 Nov 2024 15:35:50 +0000 Subject: [PATCH] Add worker pool to WASM capability (#15088) * [chore] Add worker pool to compute capability - Also add step-level timeout to engine. This was removed when we moved away from ExecuteSync(). * WIP * Some more comments --- core/capabilities/compute/compute.go | 119 +++++++++++++++--- core/capabilities/compute/compute_test.go | 20 +-- core/capabilities/compute/transformer.go | 54 ++++---- core/capabilities/compute/transformer_test.go | 21 +++- .../services/standardcapabilities/delegate.go | 8 +- core/services/workflows/engine.go | 13 +- core/services/workflows/engine_test.go | 32 ++--- 7 files changed, 196 insertions(+), 71 deletions(-) diff --git a/core/capabilities/compute/compute.go b/core/capabilities/compute/compute.go index 7e6961d2e8a..78a4cc1e033 100644 --- a/core/capabilities/compute/compute.go +++ b/core/capabilities/compute/compute.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "strings" + "sync" "time" "github.com/google/uuid" @@ -19,6 +20,7 @@ import ( capabilitiespb "github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb" "github.com/smartcontractkit/chainlink-common/pkg/custmsg" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" coretypes "github.com/smartcontractkit/chainlink-common/pkg/types/core" "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host" wasmpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb" @@ -73,7 +75,8 @@ var ( var _ capabilities.ActionCapability = (*Compute)(nil) type Compute struct { - log logger.Logger + stopCh services.StopChan + log logger.Logger // emitter is used to emit messages from the WASM module to a configured collector. emitter custmsg.MessageEmitter @@ -82,9 +85,13 @@ type Compute struct { // transformer is used to transform a values.Map into a ParsedConfig struct on each execution // of a request. - transformer ConfigTransformer + transformer *transformer outgoingConnectorHandler *webapi.OutgoingConnectorHandler idGenerator func() string + + numWorkers int + queue chan request + wg sync.WaitGroup } func (c *Compute) RegisterToWorkflow(ctx context.Context, request capabilities.RegisterToWorkflowRequest) error { @@ -100,35 +107,76 @@ func generateID(binary []byte) string { return fmt.Sprintf("%x", id) } -func copyRequest(req capabilities.CapabilityRequest) capabilities.CapabilityRequest { - return capabilities.CapabilityRequest{ - Metadata: req.Metadata, - Inputs: req.Inputs.CopyMap(), - Config: req.Config.CopyMap(), +func (c *Compute) Execute(ctx context.Context, request capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) { + ch, err := c.enqueueRequest(ctx, request) + if err != nil { + return capabilities.CapabilityResponse{}, err + } + + select { + case <-c.stopCh: + return capabilities.CapabilityResponse{}, errors.New("service shutting down, aborting request") + case <-ctx.Done(): + return capabilities.CapabilityResponse{}, fmt.Errorf("request cancelled by upstream: %w", ctx.Err()) + case resp := <-ch: + return resp.resp, resp.err } } -func (c *Compute) Execute(ctx context.Context, request capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) { - copied := copyRequest(request) +type request struct { + ch chan response + req capabilities.CapabilityRequest + ctx func() context.Context +} - cfg, err := c.transformer.Transform(copied.Config) +type response struct { + resp capabilities.CapabilityResponse + err error +} + +func (c *Compute) enqueueRequest(ctx context.Context, req capabilities.CapabilityRequest) (<-chan response, error) { + ch := make(chan response) + r := request{ + ch: ch, + req: req, + ctx: func() context.Context { return ctx }, + } + select { + case <-c.stopCh: + return nil, errors.New("service shutting down, aborting request") + case <-ctx.Done(): + return nil, fmt.Errorf("could not enqueue request: %w", ctx.Err()) + case c.queue <- r: + return ch, nil + } +} + +func (c *Compute) execute(ctx context.Context, respCh chan response, req capabilities.CapabilityRequest) { + copiedReq, cfg, err := c.transformer.Transform(req) if err != nil { - return capabilities.CapabilityResponse{}, fmt.Errorf("invalid request: could not transform config: %w", err) + respCh <- response{err: fmt.Errorf("invalid request: could not transform config: %w", err)} + return } id := generateID(cfg.Binary) m, ok := c.modules.get(id) if !ok { - mod, err := c.initModule(id, cfg.ModuleConfig, cfg.Binary, request.Metadata) - if err != nil { - return capabilities.CapabilityResponse{}, err + mod, innerErr := c.initModule(id, cfg.ModuleConfig, cfg.Binary, copiedReq.Metadata) + if innerErr != nil { + respCh <- response{err: innerErr} + return } m = mod } - return c.executeWithModule(ctx, m.module, cfg.Config, request) + resp, err := c.executeWithModule(ctx, m.module, cfg.Config, copiedReq) + select { + case <-c.stopCh: + case <-ctx.Done(): + case respCh <- response{resp: resp, err: err}: + } } func (c *Compute) initModule(id string, cfg *host.ModuleConfig, binary []byte, requestMetadata capabilities.RequestMetadata) (*module, error) { @@ -196,11 +244,35 @@ func (c *Compute) Info(ctx context.Context) (capabilities.CapabilityInfo, error) func (c *Compute) Start(ctx context.Context) error { c.modules.start() + + c.wg.Add(c.numWorkers) + for i := 0; i < c.numWorkers; i++ { + go func() { + innerCtx, cancel := c.stopCh.NewCtx() + defer cancel() + + defer c.wg.Done() + c.worker(innerCtx) + }() + } return c.registry.Add(ctx, c) } +func (c *Compute) worker(ctx context.Context) { + for { + select { + case <-c.stopCh: + return + case req := <-c.queue: + c.execute(req.ctx(), req.ch, req.req) + } + } +} + func (c *Compute) Close() error { c.modules.close() + close(c.stopCh) + c.wg.Wait() return nil } @@ -270,18 +342,31 @@ func (c *Compute) createFetcher() func(ctx context.Context, req *wasmpb.FetchReq } } +const ( + defaultNumWorkers = 3 +) + +type Config struct { + webapi.ServiceConfig + NumWorkers int +} + func NewAction( - config webapi.ServiceConfig, + config Config, log logger.Logger, registry coretypes.CapabilitiesRegistry, handler *webapi.OutgoingConnectorHandler, idGenerator func() string, opts ...func(*Compute), ) *Compute { + if config.NumWorkers == 0 { + config.NumWorkers = defaultNumWorkers + } var ( lggr = logger.Named(log, "CustomCompute") labeler = custmsg.NewLabeler() compute = &Compute{ + stopCh: make(services.StopChan), log: lggr, emitter: labeler, registry: registry, @@ -289,6 +374,8 @@ func NewAction( transformer: NewTransformer(lggr, labeler), outgoingConnectorHandler: handler, idGenerator: idGenerator, + queue: make(chan request), + numWorkers: defaultNumWorkers, } ) diff --git a/core/capabilities/compute/compute_test.go b/core/capabilities/compute/compute_test.go index ec82533f2bb..719bff82edf 100644 --- a/core/capabilities/compute/compute_test.go +++ b/core/capabilities/compute/compute_test.go @@ -32,12 +32,14 @@ const ( validRequestUUID = "d2fe6db9-beb4-47c9-b2d6-d3065ace111e" ) -var defaultConfig = webapi.ServiceConfig{ - RateLimiter: common.RateLimiterConfig{ - GlobalRPS: 100.0, - GlobalBurst: 100, - PerSenderRPS: 100.0, - PerSenderBurst: 100, +var defaultConfig = Config{ + ServiceConfig: webapi.ServiceConfig{ + RateLimiter: common.RateLimiterConfig{ + GlobalRPS: 100.0, + GlobalBurst: 100, + PerSenderRPS: 100.0, + PerSenderBurst: 100, + }, }, } @@ -45,17 +47,17 @@ type testHarness struct { registry *corecapabilities.Registry connector *gcmocks.GatewayConnector log logger.Logger - config webapi.ServiceConfig + config Config connectorHandler *webapi.OutgoingConnectorHandler compute *Compute } -func setup(t *testing.T, config webapi.ServiceConfig) testHarness { +func setup(t *testing.T, config Config) testHarness { log := logger.TestLogger(t) registry := capabilities.NewRegistry(log) connector := gcmocks.NewGatewayConnector(t) idGeneratorFn := func() string { return validRequestUUID } - connectorHandler, err := webapi.NewOutgoingConnectorHandler(connector, config, ghcapabilities.MethodComputeAction, log) + connectorHandler, err := webapi.NewOutgoingConnectorHandler(connector, config.ServiceConfig, ghcapabilities.MethodComputeAction, log) require.NoError(t, err) compute := NewAction(config, log, registry, connectorHandler, idGeneratorFn) diff --git a/core/capabilities/compute/transformer.go b/core/capabilities/compute/transformer.go index 99efcda8323..3b4ae4cfa69 100644 --- a/core/capabilities/compute/transformer.go +++ b/core/capabilities/compute/transformer.go @@ -5,21 +5,13 @@ import ( "fmt" "time" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" "github.com/smartcontractkit/chainlink-common/pkg/custmsg" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/values" "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host" ) -type Transformer[T any, U any] interface { - // Transform changes a struct of type T into a struct of type U. Accepts a variadic list of options to modify the - // output struct. - Transform(T, ...func(*U)) (*U, error) -} - -// ConfigTransformer is a Transformer that converts a values.Map into a ParsedConfig struct. -type ConfigTransformer = Transformer[*values.Map, ParsedConfig] - // ParsedConfig is a struct that contains the binary and config for a wasm module, as well as the module config. type ParsedConfig struct { Binary []byte @@ -36,25 +28,41 @@ type transformer struct { emitter custmsg.MessageEmitter } +func shallowCopy(m *values.Map) *values.Map { + to := values.EmptyMap() + + for k, v := range m.Underlying { + to.Underlying[k] = v + } + + return to +} + // Transform attempts to read a valid ParsedConfig from an arbitrary values map. The map must // contain the binary and config keys. Optionally the map may specify wasm module specific // configuration values such as maxMemoryMBs, timeout, and tickInterval. Default logger and // emitter for the module are taken from the transformer instance. Override these values with // the functional options. -func (t *transformer) Transform(in *values.Map, opts ...func(*ParsedConfig)) (*ParsedConfig, error) { - binary, err := popValue[[]byte](in, binaryKey) +func (t *transformer) Transform(req capabilities.CapabilityRequest, opts ...func(*ParsedConfig)) (capabilities.CapabilityRequest, *ParsedConfig, error) { + copiedReq := capabilities.CapabilityRequest{ + Inputs: req.Inputs, + Metadata: req.Metadata, + Config: shallowCopy(req.Config), + } + + binary, err := popValue[[]byte](copiedReq.Config, binaryKey) if err != nil { - return nil, NewInvalidRequestError(err) + return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err) } - config, err := popValue[[]byte](in, configKey) + config, err := popValue[[]byte](copiedReq.Config, configKey) if err != nil { - return nil, NewInvalidRequestError(err) + return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err) } - maxMemoryMBs, err := popOptionalValue[int64](in, maxMemoryMBsKey) + maxMemoryMBs, err := popOptionalValue[int64](copiedReq.Config, maxMemoryMBsKey) if err != nil { - return nil, NewInvalidRequestError(err) + return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err) } mc := &host.ModuleConfig{ @@ -63,30 +71,30 @@ func (t *transformer) Transform(in *values.Map, opts ...func(*ParsedConfig)) (*P Labeler: t.emitter, } - timeout, err := popOptionalValue[string](in, timeoutKey) + timeout, err := popOptionalValue[string](copiedReq.Config, timeoutKey) if err != nil { - return nil, NewInvalidRequestError(err) + return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err) } var td time.Duration if timeout != "" { td, err = time.ParseDuration(timeout) if err != nil { - return nil, NewInvalidRequestError(err) + return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err) } mc.Timeout = &td } - tickInterval, err := popOptionalValue[string](in, tickIntervalKey) + tickInterval, err := popOptionalValue[string](copiedReq.Config, tickIntervalKey) if err != nil { - return nil, NewInvalidRequestError(err) + return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err) } var ti time.Duration if tickInterval != "" { ti, err = time.ParseDuration(tickInterval) if err != nil { - return nil, NewInvalidRequestError(err) + return capabilities.CapabilityRequest{}, nil, NewInvalidRequestError(err) } mc.TickInterval = ti } @@ -101,7 +109,7 @@ func (t *transformer) Transform(in *values.Map, opts ...func(*ParsedConfig)) (*P opt(pc) } - return pc, nil + return copiedReq, pc, nil } func NewTransformer(lggr logger.Logger, emitter custmsg.MessageEmitter) *transformer { diff --git a/core/capabilities/compute/transformer_test.go b/core/capabilities/compute/transformer_test.go index 83131636462..ee77e20d6f6 100644 --- a/core/capabilities/compute/transformer_test.go +++ b/core/capabilities/compute/transformer_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" "github.com/smartcontractkit/chainlink-common/pkg/custmsg" "github.com/smartcontractkit/chainlink-common/pkg/values" "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host" @@ -94,6 +95,9 @@ func Test_transformer(t *testing.T) { "binary": []byte{0x01, 0x02, 0x03}, "config": []byte{0x04, 0x05, 0x06}, }) + giveReq := capabilities.CapabilityRequest{ + Config: giveMap, + } require.NoError(t, err) wantTO := 4 * time.Second @@ -110,7 +114,7 @@ func Test_transformer(t *testing.T) { } tf := NewTransformer(lgger, emitter) - gotConfig, err := tf.Transform(giveMap) + _, gotConfig, err := tf.Transform(giveReq) require.NoError(t, err) assert.Equal(t, wantConfig, gotConfig) @@ -121,6 +125,9 @@ func Test_transformer(t *testing.T) { "binary": []byte{0x01, 0x02, 0x03}, "config": []byte{0x04, 0x05, 0x06}, }) + giveReq := capabilities.CapabilityRequest{ + Config: giveMap, + } require.NoError(t, err) wantConfig := &ParsedConfig{ @@ -133,7 +140,7 @@ func Test_transformer(t *testing.T) { } tf := NewTransformer(lgger, emitter) - gotConfig, err := tf.Transform(giveMap) + _, gotConfig, err := tf.Transform(giveReq) require.NoError(t, err) assert.Equal(t, wantConfig, gotConfig) @@ -145,10 +152,13 @@ func Test_transformer(t *testing.T) { "binary": []byte{0x01, 0x02, 0x03}, "config": []byte{0x04, 0x05, 0x06}, }) + giveReq := capabilities.CapabilityRequest{ + Config: giveMap, + } require.NoError(t, err) tf := NewTransformer(lgger, emitter) - _, err = tf.Transform(giveMap) + _, _, err = tf.Transform(giveReq) require.Error(t, err) require.ErrorContains(t, err, "invalid request") @@ -160,10 +170,13 @@ func Test_transformer(t *testing.T) { "binary": []byte{0x01, 0x02, 0x03}, "config": []byte{0x04, 0x05, 0x06}, }) + giveReq := capabilities.CapabilityRequest{ + Config: giveMap, + } require.NoError(t, err) tf := NewTransformer(lgger, emitter) - _, err = tf.Transform(giveMap) + _, _, err = tf.Transform(giveReq) require.Error(t, err) require.ErrorContains(t, err, "invalid request") diff --git a/core/services/standardcapabilities/delegate.go b/core/services/standardcapabilities/delegate.go index 80a60c334fc..a92e082dead 100644 --- a/core/services/standardcapabilities/delegate.go +++ b/core/services/standardcapabilities/delegate.go @@ -237,14 +237,14 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) ([]job.Ser return nil, errors.New("config is empty") } - var fetchCfg webapi.ServiceConfig - err := toml.Unmarshal([]byte(spec.StandardCapabilitiesSpec.Config), &fetchCfg) + var cfg compute.Config + err := toml.Unmarshal([]byte(spec.StandardCapabilitiesSpec.Config), &cfg) if err != nil { return nil, err } lggr := d.logger.Named("ComputeAction") - handler, err := webapi.NewOutgoingConnectorHandler(d.gatewayConnectorWrapper.GetGatewayConnector(), fetchCfg, capabilities.MethodComputeAction, lggr) + handler, err := webapi.NewOutgoingConnectorHandler(d.gatewayConnectorWrapper.GetGatewayConnector(), cfg.ServiceConfig, capabilities.MethodComputeAction, lggr) if err != nil { return nil, err } @@ -253,7 +253,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) ([]job.Ser return uuid.New().String() } - computeSrvc := compute.NewAction(fetchCfg, log, d.registry, handler, idGeneratorFn) + computeSrvc := compute.NewAction(cfg, log, d.registry, handler, idGeneratorFn) return []job.ServiceCtx{computeSrvc}, nil } diff --git a/core/services/workflows/engine.go b/core/services/workflows/engine.go index 8e2cb8e34cb..e20af85540d 100644 --- a/core/services/workflows/engine.go +++ b/core/services/workflows/engine.go @@ -113,6 +113,7 @@ type Engine struct { newWorkerTimeout time.Duration maxExecutionDuration time.Duration heartbeatCadence time.Duration + stepTimeoutDuration time.Duration // testing lifecycle hook to signal when an execution is finished. onExecutionFinished func(string) @@ -755,7 +756,10 @@ func (e *Engine) workerForStepRequest(ctx context.Context, msg stepRequest) { // TODO ks-462 inputs logCustMsg(ctx, cma, "executing step", l) - inputs, outputs, err := e.executeStep(ctx, l, msg) + stepCtx, cancel := context.WithTimeout(ctx, e.stepTimeoutDuration) + defer cancel() + + inputs, outputs, err := e.executeStep(stepCtx, l, msg) var stepStatus string switch { case errors.Is(capabilities.ErrStopExecution, err): @@ -1137,6 +1141,7 @@ type Config struct { Binary []byte SecretsFetcher secretsFetcher HeartbeatCadence time.Duration + StepTimeout time.Duration // For testing purposes only maxRetries int @@ -1152,6 +1157,7 @@ const ( defaultNewWorkerTimeout = 2 * time.Second defaultMaxExecutionDuration = 10 * time.Minute defaultHeartbeatCadence = 5 * time.Minute + defaultStepTimeout = 2 * time.Minute ) func NewEngine(ctx context.Context, cfg Config) (engine *Engine, err error) { @@ -1183,6 +1189,10 @@ func NewEngine(ctx context.Context, cfg Config) (engine *Engine, err error) { cfg.HeartbeatCadence = defaultHeartbeatCadence } + if cfg.StepTimeout == 0 { + cfg.StepTimeout = defaultStepTimeout + } + if cfg.retryMs == 0 { cfg.retryMs = 5000 } @@ -1235,6 +1245,7 @@ func NewEngine(ctx context.Context, cfg Config) (engine *Engine, err error) { triggerEvents: make(chan capabilities.TriggerResponse), stopCh: make(chan struct{}), newWorkerTimeout: cfg.NewWorkerTimeout, + stepTimeoutDuration: cfg.StepTimeout, maxExecutionDuration: cfg.MaxExecutionDuration, heartbeatCadence: cfg.HeartbeatCadence, onExecutionFinished: cfg.onExecutionFinished, diff --git a/core/services/workflows/engine_test.go b/core/services/workflows/engine_test.go index 5e87d4f7603..e6667fe0bc6 100644 --- a/core/services/workflows/engine_test.go +++ b/core/services/workflows/engine_test.go @@ -1429,19 +1429,21 @@ func TestEngine_WithCustomComputeStep(t *testing.T) { ctx := testutils.Context(t) log := logger.TestLogger(t) reg := coreCap.NewRegistry(logger.TestLogger(t)) - cfg := webapi.ServiceConfig{ - RateLimiter: common.RateLimiterConfig{ - GlobalRPS: 100.0, - GlobalBurst: 100, - PerSenderRPS: 100.0, - PerSenderBurst: 100, + cfg := compute.Config{ + ServiceConfig: webapi.ServiceConfig{ + RateLimiter: common.RateLimiterConfig{ + GlobalRPS: 100.0, + GlobalBurst: 100, + PerSenderRPS: 100.0, + PerSenderBurst: 100, + }, }, } connector := gcmocks.NewGatewayConnector(t) handler, err := webapi.NewOutgoingConnectorHandler( connector, - cfg, + cfg.ServiceConfig, ghcapabilities.MethodComputeAction, log) require.NoError(t, err) @@ -1493,18 +1495,20 @@ func TestEngine_CustomComputePropagatesBreaks(t *testing.T) { ctx := testutils.Context(t) log := logger.TestLogger(t) reg := coreCap.NewRegistry(logger.TestLogger(t)) - cfg := webapi.ServiceConfig{ - RateLimiter: common.RateLimiterConfig{ - GlobalRPS: 100.0, - GlobalBurst: 100, - PerSenderRPS: 100.0, - PerSenderBurst: 100, + cfg := compute.Config{ + ServiceConfig: webapi.ServiceConfig{ + RateLimiter: common.RateLimiterConfig{ + GlobalRPS: 100.0, + GlobalBurst: 100, + PerSenderRPS: 100.0, + PerSenderBurst: 100, + }, }, } connector := gcmocks.NewGatewayConnector(t) handler, err := webapi.NewOutgoingConnectorHandler( connector, - cfg, + cfg.ServiceConfig, ghcapabilities.MethodComputeAction, log) require.NoError(t, err)