Skip to content

Commit

Permalink
reuse capability registration logic from the trigger in the executabl…
Browse files Browse the repository at this point in the history
…e capability
  • Loading branch information
ettec committed Dec 16, 2024
1 parent 62c2376 commit 8236209
Show file tree
Hide file tree
Showing 33 changed files with 2,015 additions and 682 deletions.
21 changes: 17 additions & 4 deletions core/capabilities/launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/smartcontractkit/chainlink-common/pkg/services"
"github.com/smartcontractkit/chainlink-common/pkg/values"
"github.com/smartcontractkit/chainlink/v2/core/capabilities/remote"
"github.com/smartcontractkit/chainlink/v2/core/capabilities/remote/aggregation"
"github.com/smartcontractkit/chainlink/v2/core/capabilities/remote/executable"
remotetypes "github.com/smartcontractkit/chainlink/v2/core/capabilities/remote/types"
"github.com/smartcontractkit/chainlink/v2/core/capabilities/streams"
Expand Down Expand Up @@ -60,6 +61,7 @@ func unmarshalCapabilityConfig(data []byte) (capabilities.CapabilityConfiguratio

var remoteTriggerConfig *capabilities.RemoteTriggerConfig
var remoteTargetConfig *capabilities.RemoteTargetConfig
var remoteExecutableConfig *capabilities.RemoteExecutableConfig

switch cconf.GetRemoteConfig().(type) {
case *capabilitiespb.CapabilityConfig_RemoteTriggerConfig:
Expand All @@ -73,6 +75,12 @@ func unmarshalCapabilityConfig(data []byte) (capabilities.CapabilityConfiguratio
prtc := cconf.GetRemoteTargetConfig()
remoteTargetConfig = &capabilities.RemoteTargetConfig{}
remoteTargetConfig.RequestHashExcludedAttributes = prtc.RequestHashExcludedAttributes
case *capabilitiespb.CapabilityConfig_RemoteExecutableConfig:
prtc := cconf.GetRemoteExecutableConfig()
remoteExecutableConfig = &capabilities.RemoteExecutableConfig{}
remoteExecutableConfig.RequestHashExcludedAttributes = prtc.RequestHashExcludedAttributes
remoteExecutableConfig.RegistrationRefresh = prtc.RegistrationRefresh.AsDuration()
remoteExecutableConfig.RegistrationExpiry = prtc.RegistrationExpiry.AsDuration()
}

dc, err := values.FromMapValueProto(cconf.DefaultConfig)
Expand All @@ -81,9 +89,10 @@ func unmarshalCapabilityConfig(data []byte) (capabilities.CapabilityConfiguratio
}

return capabilities.CapabilityConfiguration{
DefaultConfig: dc,
RemoteTriggerConfig: remoteTriggerConfig,
RemoteTargetConfig: remoteTargetConfig,
DefaultConfig: dc,
RemoteTriggerConfig: remoteTriggerConfig,
RemoteTargetConfig: remoteTargetConfig,
RemoteExecutableConfig: remoteExecutableConfig,
}, nil
}

Expand Down Expand Up @@ -280,7 +289,7 @@ func (w *launcher) addRemoteCapabilities(ctx context.Context, myDON registrysync
w.lggr,
)
} else {
aggregator = remote.NewDefaultModeAggregator(uint32(remoteDON.F) + 1)
aggregator = aggregation.NewDefaultModeAggregator(uint32(remoteDON.F) + 1)
}

// TODO: We need to implement a custom, Mercury-specific
Expand All @@ -307,7 +316,9 @@ func (w *launcher) addRemoteCapabilities(ctx context.Context, myDON registrysync
case capabilities.CapabilityTypeAction:
newActionFn := func(info capabilities.CapabilityInfo) (capabilityService, error) {
client := executable.NewClient(
capabilityConfig.RemoteExecutableConfig,
info,
remoteDON.DON,
myDON.DON,
w.dispatcher,
defaultTargetRequestTimeout,
Expand All @@ -325,7 +336,9 @@ func (w *launcher) addRemoteCapabilities(ctx context.Context, myDON registrysync
case capabilities.CapabilityTypeTarget:
newTargetFn := func(info capabilities.CapabilityInfo) (capabilityService, error) {
client := executable.NewClient(
capabilityConfig.RemoteExecutableConfig,
info,
remoteDON.DON,
myDON.DON,
w.dispatcher,
defaultTargetRequestTimeout,
Expand Down
58 changes: 58 additions & 0 deletions core/capabilities/remote/aggregation/default_mode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package aggregation

import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"

commoncap "github.com/smartcontractkit/chainlink-common/pkg/capabilities"
"github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb"
remotetypes "github.com/smartcontractkit/chainlink/v2/core/capabilities/remote/types"
)

// Default MODE Aggregator needs a configurable number of identical responses for aggregation to succeed
type defaultModeAggregator struct {
minIdenticalResponses uint32
}

var _ remotetypes.Aggregator = &defaultModeAggregator{}

func NewDefaultModeAggregator(minIdenticalResponses uint32) *defaultModeAggregator {
return &defaultModeAggregator{
minIdenticalResponses: minIdenticalResponses,
}
}

func (a *defaultModeAggregator) Aggregate(_ string, responses [][]byte) (commoncap.TriggerResponse, error) {
found, err := AggregateModeRaw(responses, a.minIdenticalResponses)
if err != nil {
return commoncap.TriggerResponse{}, fmt.Errorf("failed to aggregate responses, err: %w", err)
}

unmarshaled, err := pb.UnmarshalTriggerResponse(found)
if err != nil {
return commoncap.TriggerResponse{}, fmt.Errorf("failed to unmarshal aggregated responses, err: %w", err)
}
return unmarshaled, nil
}

func AggregateModeRaw(elemList [][]byte, minIdenticalResponses uint32) ([]byte, error) {
hashToCount := make(map[string]uint32)
var found []byte
for _, elem := range elemList {
hasher := sha256.New()
hasher.Write(elem)
sha := hex.EncodeToString(hasher.Sum(nil))
hashToCount[sha]++
if hashToCount[sha] >= minIdenticalResponses {
found = elem
// update in case we find another elem with an even higher count
minIdenticalResponses = hashToCount[sha]
}
}
if found == nil {
return nil, errors.New("not enough identical responses found")
}
return found, nil
}
51 changes: 51 additions & 0 deletions core/capabilities/remote/aggregation/default_mode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package aggregation

import (
"testing"

"github.com/stretchr/testify/require"

commoncap "github.com/smartcontractkit/chainlink-common/pkg/capabilities"
"github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb"
"github.com/smartcontractkit/chainlink-common/pkg/values"
)

var (
triggerEvent1 = map[string]any{"event": "triggerEvent1"}
triggerEvent2 = map[string]any{"event": "triggerEvent2"}
)

func TestDefaultModeAggregator_Aggregate(t *testing.T) {
val, err := values.NewMap(triggerEvent1)
require.NoError(t, err)
capResponse1 := commoncap.TriggerResponse{
Event: commoncap.TriggerEvent{
Outputs: val,
},
Err: nil,
}
marshaled1, err := pb.MarshalTriggerResponse(capResponse1)
require.NoError(t, err)

val2, err := values.NewMap(triggerEvent2)
require.NoError(t, err)
capResponse2 := commoncap.TriggerResponse{
Event: commoncap.TriggerEvent{
Outputs: val2,
},
Err: nil,
}
marshaled2, err := pb.MarshalTriggerResponse(capResponse2)
require.NoError(t, err)

agg := NewDefaultModeAggregator(2)
_, err = agg.Aggregate("", [][]byte{marshaled1})
require.Error(t, err)

_, err = agg.Aggregate("", [][]byte{marshaled1, marshaled2})
require.Error(t, err)

res, err := agg.Aggregate("", [][]byte{marshaled1, marshaled2, marshaled1})
require.NoError(t, err)
require.Equal(t, res, capResponse1)
}
31 changes: 16 additions & 15 deletions core/capabilities/remote/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package remote
import (
"context"
"fmt"
"strconv"
"sync"
"time"

Expand Down Expand Up @@ -42,8 +43,8 @@ type dispatcher struct {
}

type key struct {
capId string
donId uint32
capID string
donID uint32
}

var _ services.Service = &dispatcher{}
Expand Down Expand Up @@ -74,7 +75,7 @@ func (d *dispatcher) Start(ctx context.Context) error {
d.peer = d.peerWrapper.GetPeer()
d.peerID = d.peer.ID()
if d.peer == nil {
return fmt.Errorf("peer is not initialized")
return errors.New("peer is not initialized")
}
d.wg.Add(1)
go func() {
Expand All @@ -96,20 +97,20 @@ func (d *dispatcher) Close() error {
var capReceiveChannelUsage = promauto.NewGaugeVec(prometheus.GaugeOpts{
Name: "capability_receive_channel_usage",
Help: "The usage of the receive channel for each capability, 0 indicates empty, 1 indicates full.",
}, []string{"capabilityId", "donId"})
}, []string{"capabilityId", "donID"})

type receiver struct {
cancel context.CancelFunc
ch chan *types.MessageBody
}

func (d *dispatcher) SetReceiver(capabilityId string, donId uint32, rec types.Receiver) error {
func (d *dispatcher) SetReceiver(capabilityID string, donID uint32, rec types.Receiver) error {
d.mu.Lock()
defer d.mu.Unlock()
k := key{capabilityId, donId}
k := key{capabilityID, donID}
_, ok := d.receivers[k]
if ok {
return fmt.Errorf("%w: receiver already exists for capability %s and don %d", ErrReceiverExists, capabilityId, donId)
return fmt.Errorf("%w: receiver already exists for capability %s and don %d", ErrReceiverExists, capabilityID, donID)
}

receiverCh := make(chan *types.MessageBody, d.cfg.ReceiverBufferSize())
Expand All @@ -134,24 +135,24 @@ func (d *dispatcher) SetReceiver(capabilityId string, donId uint32, rec types.Re
ch: receiverCh,
}

d.lggr.Debugw("receiver set", "capabilityId", capabilityId, "donId", donId)
d.lggr.Debugw("receiver set", "capabilityID", capabilityID, "donID", donID)
return nil
}

func (d *dispatcher) RemoveReceiver(capabilityId string, donId uint32) {
func (d *dispatcher) RemoveReceiver(capabilityID string, donID uint32) {
d.mu.Lock()
defer d.mu.Unlock()

receiverKey := key{capabilityId, donId}
receiverKey := key{capabilityID, donID}
if receiver, ok := d.receivers[receiverKey]; ok {
receiver.cancel()
delete(d.receivers, receiverKey)
d.lggr.Debugw("receiver removed", "capabilityId", capabilityId, "donId", donId)
d.lggr.Debugw("receiver removed", "capabilityID", capabilityID, "donID", donID)
}
}

func (d *dispatcher) Send(peerID p2ptypes.PeerID, msgBody *types.MessageBody) error {
msgBody.Version = uint32(d.cfg.SupportedVersion())
msgBody.Version = uint32(d.cfg.SupportedVersion()) //nolint:gosec // disable G115: supported version is not expected to exceed uint32 max value
msgBody.Sender = d.peerID[:]
msgBody.Receiver = peerID[:]
msgBody.Timestamp = time.Now().UnixMilli()
Expand Down Expand Up @@ -194,17 +195,17 @@ func (d *dispatcher) receive() {
receiver, ok := d.receivers[k]
d.mu.RUnlock()
if !ok {
d.lggr.Debugw("received message for unregistered capability", "capabilityId", SanitizeLogString(k.capId), "donId", k.donId)
d.lggr.Debugw("received message for unregistered capability", "capabilityId", SanitizeLogString(k.capID), "donID", k.donID)
d.tryRespondWithError(msg.Sender, body, types.Error_CAPABILITY_NOT_FOUND)
continue
}

receiverQueueUsage := float64(len(receiver.ch)) / float64(d.cfg.ReceiverBufferSize())
capReceiveChannelUsage.WithLabelValues(k.capId, fmt.Sprint(k.donId)).Set(receiverQueueUsage)
capReceiveChannelUsage.WithLabelValues(k.capID, strconv.FormatUint(uint64(k.donID), 10)).Set(receiverQueueUsage)
select {
case receiver.ch <- body:
default:
d.lggr.Warnw("receiver channel full, dropping message", "capabilityId", k.capId, "donId", k.donId)
d.lggr.Warnw("receiver channel full, dropping message", "capabilityId", k.capID, "donID", k.donID)
}
}
}
Expand Down
35 changes: 34 additions & 1 deletion core/capabilities/remote/executable/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/smartcontractkit/chainlink-common/pkg/services"
"github.com/smartcontractkit/chainlink/v2/core/capabilities/remote"
"github.com/smartcontractkit/chainlink/v2/core/capabilities/remote/executable/request"
"github.com/smartcontractkit/chainlink/v2/core/capabilities/remote/registration"
"github.com/smartcontractkit/chainlink/v2/core/capabilities/remote/types"
"github.com/smartcontractkit/chainlink/v2/core/logger"
)
Expand All @@ -30,6 +31,7 @@ type client struct {
localDONInfo commoncap.DON
dispatcher types.Dispatcher
requestTimeout time.Duration
registrationClient *registration.Client

requestIDToCallerRequest map[string]*request.ClientRequest
mutex sync.Mutex
Expand All @@ -41,21 +43,32 @@ var _ commoncap.ExecutableCapability = &client{}
var _ types.Receiver = &client{}
var _ services.Service = &client{}

func NewClient(remoteCapabilityInfo commoncap.CapabilityInfo, localDonInfo commoncap.DON, dispatcher types.Dispatcher,
func NewClient(remoteExecutableConfig *commoncap.RemoteExecutableConfig, remoteCapabilityInfo commoncap.CapabilityInfo, remoteDonInfo commoncap.DON, localDonInfo commoncap.DON, dispatcher types.Dispatcher,
requestTimeout time.Duration, lggr logger.Logger) *client {
if remoteExecutableConfig == nil {
lggr.Info("no remote config provided, using default values")
remoteExecutableConfig = &commoncap.RemoteExecutableConfig{}
}
remoteExecutableConfig.ApplyDefaults()

return &client{
lggr: lggr.Named("ExecutableCapabilityClient"),
remoteCapabilityInfo: remoteCapabilityInfo,
localDONInfo: localDonInfo,
dispatcher: dispatcher,
requestTimeout: requestTimeout,
requestIDToCallerRequest: make(map[string]*request.ClientRequest),
registrationClient: registration.NewClient(lggr, types.MethodRegisterToWorkflow, remoteExecutableConfig.RegistrationRefresh, remoteCapabilityInfo, remoteDonInfo, localDonInfo, dispatcher, "ExecutableClient"),
stopCh: make(services.StopChan),
}
}

func (c *client) Start(ctx context.Context) error {
return c.StartOnce(c.Name(), func() error {
if err := c.registrationClient.Start(ctx); err != nil {
return fmt.Errorf("failed to start registration client: %w", err)
}

c.wg.Add(1)
go func() {
defer c.wg.Done()
Expand All @@ -77,6 +90,12 @@ func (c *client) Close() error {
close(c.stopCh)
c.cancelAllRequests(errors.New("client closed"))
c.wg.Wait()

err := c.registrationClient.Close()
if err != nil {
c.lggr.Errorw("failed to close registration client", "err", err)
}

c.lggr.Info("ExecutableCapability closed")
return nil
})
Expand Down Expand Up @@ -140,10 +159,24 @@ func (c *client) Info(ctx context.Context) (commoncap.CapabilityInfo, error) {
}

func (c *client) RegisterToWorkflow(ctx context.Context, registerRequest commoncap.RegisterToWorkflowRequest) error {
rawRequest, err := pb.MarshalRegisterToWorkflowRequest(registerRequest)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
workflowID := registerRequest.Metadata.WorkflowID
if workflowID == "" {
return errors.New("empty workflowID")
}

if err = c.registrationClient.RegisterWorkflow(workflowID, rawRequest); err != nil {
return fmt.Errorf("failed to register workflow: %w", err)
}

return nil
}

func (c *client) UnregisterFromWorkflow(ctx context.Context, unregisterRequest commoncap.UnregisterFromWorkflowRequest) error {
c.registrationClient.UnregisterWorkflow(unregisterRequest.Metadata.WorkflowID)
return nil
}

Expand Down
Loading

0 comments on commit 8236209

Please sign in to comment.