From d736d9e0838983a021677bc608556b3994f46690 Mon Sep 17 00:00:00 2001
From: Matthew Pendrey
Date: Fri, 14 Jun 2024 17:30:59 +0100
Subject: [PATCH] remote target wait for initiated threads to finish on close
 (#13524)

---
 .changeset/sour-pigs-develop.md               |  5 +
 core/capabilities/remote/target/client.go     | 74 +++++++++------
 .../capabilities/remote/target/client_test.go |  4 +-
 .../remote/target/endtoend_test.go            | 91 +++++++++++++++----
 .../remote/target/request/client_request.go   | 17 +++-
 .../target/request/client_request_test.go     |  8 ++
 core/capabilities/remote/target/server.go     | 20 ++--
 .../capabilities/remote/target/server_test.go | 10 +-
 8 files changed, 160 insertions(+), 69 deletions(-)
 create mode 100644 .changeset/sour-pigs-develop.md

diff --git a/.changeset/sour-pigs-develop.md b/.changeset/sour-pigs-develop.md
new file mode 100644
index 00000000000..5737b20601f
--- /dev/null
+++ b/.changeset/sour-pigs-develop.md
@@ -0,0 +1,5 @@
+---
+"chainlink": minor
+---
+
+#internal remote target wait until initiated threads exit on close
diff --git a/core/capabilities/remote/target/client.go b/core/capabilities/remote/target/client.go
index 2fb11930164..281e7ac5fc1 100644
--- a/core/capabilities/remote/target/client.go
+++ b/core/capabilities/remote/target/client.go
@@ -52,35 +52,14 @@ func NewClient(remoteCapabilityInfo commoncap.CapabilityInfo, localDonInfo commo
 	}
 }
 
-func (c *client) expireRequests() {
-	c.mutex.Lock()
-	defer c.mutex.Unlock()
-
-	for messageID, req := range c.messageIDToCallerRequest {
-		if req.Expired() {
-			req.Cancel(errors.New("request expired"))
-			delete(c.messageIDToCallerRequest, messageID)
-		}
-	}
-}
-
 func (c *client) Start(ctx context.Context) error {
 	return c.StartOnce(c.Name(), func() error {
 		c.wg.Add(1)
 		go func() {
 			defer c.wg.Done()
-			ticker := time.NewTicker(c.requestTimeout)
-			defer ticker.Stop()
-			c.lggr.Info("TargetClient started")
-			for {
-				select {
-				case <-c.stopCh:
-					return
-				case <-ticker.C:
-					c.expireRequests()
-				}
-			}
+			c.checkForExpiredRequests()
 		}()
+		c.lggr.Info("TargetClient started")
 		return nil
 	})
 }
@@ -88,12 +67,46 @@ func (c *client) Start(ctx context.Context) error {
 func (c *client) Close() error {
 	return c.StopOnce(c.Name(), func() error {
 		close(c.stopCh)
+		c.cancelAllRequests(errors.New("client closed"))
 		c.wg.Wait()
 		c.lggr.Info("TargetClient closed")
 		return nil
 	})
 }
 
+func (c *client) checkForExpiredRequests() {
+	ticker := time.NewTicker(c.requestTimeout)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-c.stopCh:
+			return
+		case <-ticker.C:
+			c.expireRequests()
+		}
+	}
+}
+
+func (c *client) expireRequests() {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
+	for messageID, req := range c.messageIDToCallerRequest {
+		if req.Expired() {
+			req.Cancel(errors.New("request expired"))
+			delete(c.messageIDToCallerRequest, messageID)
+		}
+	}
+}
+
+func (c *client) cancelAllRequests(err error) {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	for _, req := range c.messageIDToCallerRequest {
+		req.Cancel(err)
+	}
+}
+
 func (c *client) Info(ctx context.Context) (commoncap.CapabilityInfo, error) {
 	return c.remoteCapabilityInfo, nil
 }
@@ -121,8 +134,11 @@ func (c *client) Execute(ctx context.Context, capReq commoncap.CapabilityRequest
 		return nil, fmt.Errorf("request for message ID %s already exists", messageID)
 	}
 
-	cCtx, _ := c.stopCh.NewCtx()
-	req, err := request.NewClientRequest(cCtx, c.lggr, capReq, messageID, c.remoteCapabilityInfo, c.localDONInfo, c.dispatcher,
+	// TODO confirm reasons for below workaround and see if can be resolved
+	// The context passed in by the workflow engine is cancelled prior to the results being read from the response channel
+	// The wrapping of the context with 'WithoutCancel' is a workaround for that behaviour.
+	requestCtx := context.WithoutCancel(ctx)
+	req, err := request.NewClientRequest(requestCtx, c.lggr, capReq, messageID, c.remoteCapabilityInfo, c.localDONInfo, c.dispatcher,
 		c.requestTimeout)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create client request: %w", err)
@@ -146,11 +162,9 @@ func (c *client) Receive(msg *types.MessageBody) {
 		return
 	}
 
-	go func() {
-		if err := req.OnMessage(ctx, msg); err != nil {
-			c.lggr.Errorw("failed to add response to request", "messageID", messageID, "err", err)
-		}
-	}()
+	if err := req.OnMessage(ctx, msg); err != nil {
+		c.lggr.Errorw("failed to add response to request", "messageID", messageID, "err", err)
+	}
 }
 
 func GetMessageIDForRequest(req commoncap.CapabilityRequest) (string, error) {
diff --git a/core/capabilities/remote/target/client_test.go b/core/capabilities/remote/target/client_test.go
index 5f9261eed8f..a00ec6dda6c 100644
--- a/core/capabilities/remote/target/client_test.go
+++ b/core/capabilities/remote/target/client_test.go
@@ -152,7 +152,7 @@ func testClient(ctx context.Context, t *testing.T, numWorkflowPeers int, workflo
 		ID: "workflow-don",
 	}
 
-	broker := newTestMessageBroker()
+	broker := newTestAsyncMessageBroker(100)
 
 	receivers := make([]remotetypes.Receiver, numCapabilityPeers)
 	for i := 0; i < numCapabilityPeers; i++ {
@@ -172,6 +172,8 @@ func testClient(ctx context.Context, t *testing.T, numWorkflowPeers int, workflo
 		callers[i] = caller
 	}
 
+	servicetest.Run(t, broker)
+
 	executeInputs, err := values.NewMap(
 		map[string]any{
 			"executeValue1": "aValue1",
diff --git a/core/capabilities/remote/target/endtoend_test.go b/core/capabilities/remote/target/endtoend_test.go
index c9e9fea28f0..5c9dc191878 100644
--- a/core/capabilities/remote/target/endtoend_test.go
+++ b/core/capabilities/remote/target/endtoend_test.go
@@ -14,6 +14,7 @@ import (
 	"github.com/stretchr/testify/require"
 
 	commoncap "github.com/smartcontractkit/chainlink-common/pkg/capabilities"
+	"github.com/smartcontractkit/chainlink-common/pkg/services"
 	"github.com/smartcontractkit/chainlink-common/pkg/services/servicetest"
 	"github.com/smartcontractkit/chainlink-common/pkg/values"
 	"github.com/smartcontractkit/chainlink/v2/core/capabilities/remote/target"
@@ -214,7 +215,7 @@ func testRemoteTarget(ctx context.Context, t *testing.T, underlying commoncap.Ta
 		F: workflowDonF,
 	}
 
-	broker := newTestMessageBroker()
+	broker := newTestAsyncMessageBroker(1000)
 
 	workflowDONs := map[string]commoncap.DON{
 		workflowDonInfo.ID: workflowDonInfo,
@@ -240,6 +241,8 @@ func testRemoteTarget(ctx context.Context, t *testing.T, underlying commoncap.Ta
 		workflowNodes[i] = workflowNode
 	}
 
+	servicetest.Run(t, broker)
+
 	executeInputs, err := values.NewMap(
 		map[string]any{
 			"executeValue1": "aValue1",
@@ -271,49 +274,97 @@ func testRemoteTarget(ctx context.Context, t *testing.T, underlying commoncap.Ta
 	wg.Wait()
 }
 
-type testMessageBroker struct {
+type testAsyncMessageBroker struct {
+	services.StateMachine
 	nodes map[p2ptypes.PeerID]remotetypes.Receiver
+
+	sendCh chan *remotetypes.MessageBody
+
+	stopCh services.StopChan
+	wg sync.WaitGroup
+}
+
+func (a *testAsyncMessageBroker) HealthReport() map[string]error {
+	return nil
+}
+
+func (a *testAsyncMessageBroker) Name() string {
+	return "testAsyncMessageBroker"
 }
 
-func newTestMessageBroker() *testMessageBroker {
-	return &testMessageBroker{
-		nodes: make(map[p2ptypes.PeerID]remotetypes.Receiver),
+func newTestAsyncMessageBroker(sendChBufferSize int) *testAsyncMessageBroker {
+	return &testAsyncMessageBroker{
+		nodes: make(map[p2ptypes.PeerID]remotetypes.Receiver),
+		stopCh: make(services.StopChan),
+		sendCh: make(chan *remotetypes.MessageBody, sendChBufferSize),
 	}
 }
 
-func (r *testMessageBroker) NewDispatcherForNode(nodePeerID p2ptypes.PeerID) remotetypes.Dispatcher {
+func (a *testAsyncMessageBroker) Start(ctx context.Context) error {
+	return a.StartOnce("testAsyncMessageBroker", func() error {
+		a.wg.Add(1)
+		go func() {
+			defer a.wg.Done()
+
+			for {
+				select {
+				case <-a.stopCh:
+					return
+				case msg := <-a.sendCh:
+					receiverId := toPeerID(msg.Receiver)
+
+					receiver, ok := a.nodes[receiverId]
+					if !ok {
+						panic("server not found for peer id")
+					}
+
+					receiver.Receive(msg)
+				}
+			}
+		}()
+		return nil
+	})
+}
+
+func (a *testAsyncMessageBroker) Close() error {
+	return a.StopOnce("testAsyncMessageBroker", func() error {
+		close(a.stopCh)
+
+		a.wg.Wait()
+		return nil
+	})
+}
+
+func (a *testAsyncMessageBroker) NewDispatcherForNode(nodePeerID p2ptypes.PeerID) remotetypes.Dispatcher {
 	return &nodeDispatcher{
 		callerPeerID: nodePeerID,
-		broker: r,
+		broker: a,
 	}
 }
 
-func (r *testMessageBroker) RegisterReceiverNode(nodePeerID p2ptypes.PeerID, node remotetypes.Receiver) {
-	if _, ok := r.nodes[nodePeerID]; ok {
+func (a *testAsyncMessageBroker) RegisterReceiverNode(nodePeerID p2ptypes.PeerID, node remotetypes.Receiver) {
+	if _, ok := a.nodes[nodePeerID]; ok {
 		panic("node already registered")
 	}
 
-	r.nodes[nodePeerID] = node
+	a.nodes[nodePeerID] = node
 }
 
-func (r *testMessageBroker) Send(msg *remotetypes.MessageBody) {
-	receiverId := toPeerID(msg.Receiver)
-
-	receiver, ok := r.nodes[receiverId]
-	if !ok {
-		panic("server not found for peer id")
-	}
-
-	receiver.Receive(msg)
+func (a *testAsyncMessageBroker) Send(msg *remotetypes.MessageBody) {
+	a.sendCh <- msg
 }
 
 func toPeerID(id []byte) p2ptypes.PeerID {
 	return [32]byte(id)
 }
 
+type broker interface {
+	Send(msg *remotetypes.MessageBody)
+}
+
 type nodeDispatcher struct {
 	callerPeerID p2ptypes.PeerID
-	broker *testMessageBroker
+	broker broker
 }
 
 func (t *nodeDispatcher) Send(peerID p2ptypes.PeerID, msgBody *remotetypes.MessageBody) error {
diff --git a/core/capabilities/remote/target/request/client_request.go b/core/capabilities/remote/target/request/client_request.go
index eb33a9ac70a..b48aa28207a 100644
--- a/core/capabilities/remote/target/request/client_request.go
+++ b/core/capabilities/remote/target/request/client_request.go
@@ -21,6 +21,7 @@ import (
 )
 
 type ClientRequest struct {
+	cancelFn context.CancelFunc
 	responseCh chan commoncap.CapabilityResponse
 	createdAt time.Time
 	responseIDCount map[[32]byte]int
@@ -33,6 +34,7 @@ type ClientRequest struct {
 
 	respSent bool
 	mux sync.Mutex
+	wg *sync.WaitGroup
 }
 
 func NewClientRequest(ctx context.Context, lggr logger.Logger, req commoncap.CapabilityRequest, messageID string,
@@ -56,9 +58,14 @@ func NewClientRequest(ctx context.Context, lggr logger.Logger, req commoncap.Cap
 	lggr.Debugw("sending request to peers", "execID", req.Metadata.WorkflowExecutionID, "schedule", peerIDToTransmissionDelay)
 
 	responseReceived := make(map[p2ptypes.PeerID]bool)
+
+	ctxWithCancel, cancelFn := context.WithCancel(ctx)
+	wg := &sync.WaitGroup{}
 	for peerID, delay := range peerIDToTransmissionDelay {
 		responseReceived[peerID] = false
-		go func(peerID ragep2ptypes.PeerID, delay time.Duration) {
+		wg.Add(1)
+		go func(ctx context.Context, peerID ragep2ptypes.PeerID, delay time.Duration) {
+			defer wg.Done()
 			message := &types.MessageBody{
 				CapabilityId: remoteCapabilityInfo.ID,
 				CapabilityDonId: remoteCapabilityDonInfo.ID,
@@ -69,7 +76,7 @@ func NewClientRequest(ctx context.Context, lggr logger.Logger, req commoncap.Cap
 			}
 
 			select {
-			case <-ctx.Done():
+			case <-ctxWithCancel.Done():
 				lggr.Debugw("context done, not sending request to peer", "execID", req.Metadata.WorkflowExecutionID, "peerID", peerID)
 				return
 			case <-time.After(delay):
@@ -79,10 +86,11 @@ func NewClientRequest(ctx context.Context, lggr logger.Logger, req commoncap.Cap
 					lggr.Errorw("failed to send message", "peerID", peerID, "err", err)
 				}
 			}
-		}(peerID, delay)
+		}(ctxWithCancel, peerID, delay)
 	}
 
 	return &ClientRequest{
+		cancelFn: cancelFn,
 		createdAt: time.Now(),
 		requestTimeout: requestTimeout,
 		requiredIdenticalResponses: int(remoteCapabilityDonInfo.F + 1),
@@ -90,6 +98,7 @@ func NewClientRequest(ctx context.Context, lggr logger.Logger, req commoncap.Cap
 		errorCount: make(map[string]int),
 		responseReceived: responseReceived,
 		responseCh: make(chan commoncap.CapabilityResponse, 1),
+		wg: wg,
 	}, nil
 }
 
@@ -102,6 +111,8 @@ func (c *ClientRequest) Expired() bool {
 }
 
 func (c *ClientRequest) Cancel(err error) {
+	c.cancelFn()
+	c.wg.Wait()
 	c.mux.Lock()
 	defer c.mux.Unlock()
 	if !c.respSent {
diff --git a/core/capabilities/remote/target/request/client_request_test.go b/core/capabilities/remote/target/request/client_request_test.go
index e4b0d9da88e..a053623cd2c 100644
--- a/core/capabilities/remote/target/request/client_request_test.go
+++ b/core/capabilities/remote/target/request/client_request_test.go
@@ -2,6 +2,7 @@ package request_test
 
 import (
 	"context"
+	"errors"
 	"testing"
 	"time"
 
@@ -101,6 +102,8 @@ func Test_ClientRequest_MessageValidation(t *testing.T) {
 		dispatcher := &clientRequestTestDispatcher{msgs: make(chan *types.MessageBody, 100)}
 		request, err := request.NewClientRequest(ctx, lggr, capabilityRequest, messageID, capInfo,
 			workflowDonInfo, dispatcher, 10*time.Minute)
+		defer request.Cancel(errors.New("test end"))
+
 		require.NoError(t, err)
 
 		capabilityResponse2 := commoncap.CapabilityResponse{
@@ -142,6 +145,7 @@ func Test_ClientRequest_MessageValidation(t *testing.T) {
 		request, err := request.NewClientRequest(ctx, lggr, capabilityRequest, messageID, capInfo,
 			workflowDonInfo, dispatcher, 10*time.Minute)
 		require.NoError(t, err)
+		defer request.Cancel(errors.New("test end"))
 
 		msg.Sender = capabilityPeers[0][:]
 		err = request.OnMessage(ctx, msg)
@@ -167,6 +171,7 @@ func Test_ClientRequest_MessageValidation(t *testing.T) {
 		request, err := request.NewClientRequest(ctx, lggr, capabilityRequest, messageID, capInfo,
 			workflowDonInfo, dispatcher, 10*time.Minute)
 		require.NoError(t, err)
+		defer request.Cancel(errors.New("test end"))
 
 		msg.Sender = capabilityPeers[0][:]
 		err = request.OnMessage(ctx, msg)
@@ -189,6 +194,7 @@ func Test_ClientRequest_MessageValidation(t *testing.T) {
 		request, err := request.NewClientRequest(ctx, lggr, capabilityRequest, messageID, capInfo,
 			workflowDonInfo, dispatcher, 10*time.Minute)
 		require.NoError(t, err)
+		defer request.Cancel(errors.New("test end"))
 
 		<-dispatcher.msgs
 		<-dispatcher.msgs
@@ -226,6 +232,7 @@ func Test_ClientRequest_MessageValidation(t *testing.T) {
 		request, err := request.NewClientRequest(ctx, lggr, capabilityRequest, messageID, capInfo,
 			workflowDonInfo, dispatcher, 10*time.Minute)
 		require.NoError(t, err)
+		defer request.Cancel(errors.New("test end"))
 
 		<-dispatcher.msgs
 		<-dispatcher.msgs
@@ -275,6 +282,7 @@ func Test_ClientRequest_MessageValidation(t *testing.T) {
 		request, err := request.NewClientRequest(ctx, lggr, capabilityRequest, messageID, capInfo,
 			workflowDonInfo, dispatcher, 10*time.Minute)
 		require.NoError(t, err)
+		defer request.Cancel(errors.New("test end"))
 
 		<-dispatcher.msgs
 		<-dispatcher.msgs
diff --git a/core/capabilities/remote/target/server.go b/core/capabilities/remote/target/server.go
index 83451049e75..9a578eebd3e 100644
--- a/core/capabilities/remote/target/server.go
+++ b/core/capabilities/remote/target/server.go
@@ -106,8 +106,7 @@ func (r *server) expireRequests() {
 	}
 }
 
-// Receive handles incoming messages from remote nodes and dispatches them to the corresponding request without blocking
-// the client.
+// Receive handles incoming messages from remote nodes and dispatches them to the corresponding request.
 func (r *server) Receive(msg *types.MessageBody) {
 	r.receiveLock.Lock()
 	defer r.receiveLock.Unlock()
@@ -136,16 +135,13 @@ func (r *server) Receive(msg *types.MessageBody) {
 
 	req := r.requestIDToRequest[requestID]
 
-	r.wg.Add(1)
-	go func() {
-		defer r.wg.Done()
-		ctx, cancel := r.stopCh.NewCtx()
-		defer cancel()
-		err := req.OnMessage(ctx, msg)
-		if err != nil {
-			r.lggr.Errorw("request failed to OnMessage new message", "request", req, "err", err)
-		}
-	}()
+	// TODO context should be received from the dispatcher here - pending KS-296
+	ctx, cancel := r.stopCh.NewCtx()
+	defer cancel()
+	err := req.OnMessage(ctx, msg)
+	if err != nil {
+		r.lggr.Errorw("request failed to OnMessage new message", "request", req, "err", err)
+	}
 }
 
 func GetMessageID(msg *types.MessageBody) string {
diff --git a/core/capabilities/remote/target/server_test.go b/core/capabilities/remote/target/server_test.go
index 80c0d5bc6e0..507612c143c 100644
--- a/core/capabilities/remote/target/server_test.go
+++ b/core/capabilities/remote/target/server_test.go
@@ -135,14 +135,18 @@ func testRemoteTargetServer(ctx context.Context, t *testing.T,
 		F: workflowDonF,
 	}
 
-	broker := newTestMessageBroker()
+	var srvcs []services.Service
+	broker := newTestAsyncMessageBroker(1000)
+	err := broker.Start(context.Background())
+	require.NoError(t, err)
+	srvcs = append(srvcs, broker)
 
 	workflowDONs := map[string]commoncap.DON{
 		workflowDonInfo.ID: workflowDonInfo,
 	}
 
 	capabilityNodes := make([]remotetypes.Receiver, numCapabilityPeers)
-	srvcs := make([]services.Service, numCapabilityPeers)
+
 	for i := 0; i < numCapabilityPeers; i++ {
 		capabilityPeer := capabilityPeers[i]
 		capabilityDispatcher := broker.NewDispatcherForNode(capabilityPeer)
@@ -151,7 +155,7 @@ func testRemoteTargetServer(ctx context.Context, t *testing.T,
 		require.NoError(t, capabilityNode.Start(ctx))
 		broker.RegisterReceiverNode(capabilityPeer, capabilityNode)
 		capabilityNodes[i] = capabilityNode
-		srvcs[i] = capabilityNode
+		srvcs = append(srvcs, capabilityNode)
 	}
 
 	workflowNodes := make([]*serverTestClient, numWorkflowPeers)
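
Below is a minimal, self-contained Go sketch of the shutdown pattern this patch applies to the remote target client and server: a stop channel ends periodic loops, outstanding work is cancelled, and a sync.WaitGroup guarantees that Close only returns once every goroutine the service started has exited. It is illustrative only and not part of the patch; names such as shutdownExample are invented for the example.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// shutdownExample mirrors the shape of the change: the service owns a stop
// channel, a cancelable context for in-flight work, and a WaitGroup that
// tracks every goroutine it starts.
type shutdownExample struct {
	stopCh chan struct{}
	cancel context.CancelFunc
	wg     sync.WaitGroup
}

func newShutdownExample() *shutdownExample {
	ctx, cancel := context.WithCancel(context.Background())
	s := &shutdownExample{stopCh: make(chan struct{}), cancel: cancel}

	// Periodic housekeeping loop (analogous to checkForExpiredRequests).
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		ticker := time.NewTicker(50 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-s.stopCh:
				return
			case <-ticker.C:
				fmt.Println("housekeeping tick")
			}
		}
	}()

	// In-flight work goroutine (analogous to the per-peer send goroutines in
	// ClientRequest): it watches its context and exits promptly when cancelled.
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		select {
		case <-ctx.Done():
			fmt.Println("in-flight work cancelled:", ctx.Err())
		case <-time.After(time.Hour):
			fmt.Println("work finished")
		}
	}()

	return s
}

// Close signals shutdown, cancels in-flight work, and only returns after every
// goroutine started above has exited.
func (s *shutdownExample) Close() {
	close(s.stopCh) // stop the housekeeping loop
	s.cancel()      // cancel in-flight work
	s.wg.Wait()     // wait for all initiated goroutines to finish
}

func main() {
	s := newShutdownExample()
	time.Sleep(120 * time.Millisecond)
	s.Close()
	fmt.Println("closed cleanly")
}

The ordering matters: the stop channel is closed and the context cancelled before wg.Wait, so no goroutine is left blocked when Close starts waiting, which is the same ordering client.Close and ClientRequest.Cancel use in the patch above.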
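A second illustrative sketch (again not part of the patch) shows the standard-library context.WithoutCancel behaviour (Go 1.21+) that client.Execute now relies on: the detached context keeps the parent's values but is not cancelled when the parent is, so the response can still be delivered after the workflow engine cancels its own context.

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	parent, cancel := context.WithCancel(context.Background())

	// detached inherits parent's values but ignores parent's cancellation.
	detached := context.WithoutCancel(parent)

	// Simulate the caller (e.g. the workflow engine) cancelling early.
	cancel()

	fmt.Println("parent err:", parent.Err())     // context canceled
	fmt.Println("detached err:", detached.Err()) // <nil>

	// Work bound to the detached context is unaffected by the cancellation.
	select {
	case <-detached.Done():
		fmt.Println("detached cancelled (not expected)")
	case <-time.After(10 * time.Millisecond):
		fmt.Println("detached context is still live; the response can be read")
	}
}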