Skip to content

Commit

Permalink
[Functions] Minor Listener refactor (#11323)
Browse files Browse the repository at this point in the history
1. Add an interface type to make Listener mockable
2. Return internal errors from handleRequest()
  • Loading branch information
bolekk authored Nov 17, 2023
1 parent 3ed1689 commit 738146e
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 32 deletions.
74 changes: 43 additions & 31 deletions core/services/functions/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ import (
)

var (
_ job.ServiceCtx = &FunctionsListener{}

sizeBuckets = []float64{
1024,
1024 * 4,
Expand Down Expand Up @@ -118,7 +116,14 @@ const (
FlagSecretsMaxSize uint32 = 2
)

type FunctionsListener struct {
//go:generate mockery --quiet --name FunctionsListener --output ./mocks/ --case=underscore
type FunctionsListener interface {
job.ServiceCtx

HandleOffchainRequest(ctx context.Context, request *OffchainRequest) error
}

type functionsListener struct {
services.StateMachine
client client.Client
contractAddressHex string
Expand All @@ -137,11 +142,13 @@ type FunctionsListener struct {
logPollerWrapper evmrelayTypes.LogPollerWrapper
}

func (l *FunctionsListener) HealthReport() map[string]error {
var _ FunctionsListener = &functionsListener{}

func (l *functionsListener) HealthReport() map[string]error {
return map[string]error{l.Name(): l.Healthy()}
}

func (l *FunctionsListener) Name() string { return l.logger.Name() }
func (l *functionsListener) Name() string { return l.logger.Name() }

func formatRequestId(requestId [32]byte) string {
return fmt.Sprintf("0x%x", requestId)
Expand All @@ -159,8 +166,8 @@ func NewFunctionsListener(
urlsMonEndpoint commontypes.MonitoringEndpoint,
decryptor threshold.Decryptor,
logPollerWrapper evmrelayTypes.LogPollerWrapper,
) *FunctionsListener {
return &FunctionsListener{
) *functionsListener {
return &functionsListener{
client: client,
contractAddressHex: contractAddressHex,
job: job,
Expand All @@ -177,7 +184,7 @@ func NewFunctionsListener(
}

// Start complies with job.Service
func (l *FunctionsListener) Start(context.Context) error {
func (l *functionsListener) Start(context.Context) error {
return l.StartOnce("FunctionsListener", func() error {
l.serviceContext, l.serviceCancel = context.WithCancel(context.Background())

Expand All @@ -204,7 +211,7 @@ func (l *FunctionsListener) Start(context.Context) error {
}

// Close complies with job.Service
func (l *FunctionsListener) Close() error {
func (l *functionsListener) Close() error {
return l.StopOnce("FunctionsListener", func() error {
l.serviceCancel()
close(l.chStop)
Expand All @@ -213,7 +220,7 @@ func (l *FunctionsListener) Close() error {
})
}

func (l *FunctionsListener) processOracleEventsV1() {
func (l *functionsListener) processOracleEventsV1() {
defer l.shutdownWaitGroup.Done()
freqMillis := l.pluginConfig.ListenerEventsCheckFrequencyMillis
if freqMillis == 0 {
Expand Down Expand Up @@ -247,15 +254,15 @@ func (l *FunctionsListener) processOracleEventsV1() {
}
}

func (l *FunctionsListener) getNewHandlerContext() (context.Context, context.CancelFunc) {
func (l *functionsListener) getNewHandlerContext() (context.Context, context.CancelFunc) {
timeoutSec := l.pluginConfig.ListenerEventHandlerTimeoutSec
if timeoutSec == 0 {
return context.WithCancel(l.serviceContext)
}
return context.WithTimeout(l.serviceContext, time.Duration(timeoutSec)*time.Second)
}

func (l *FunctionsListener) setError(ctx context.Context, requestId RequestID, errType ErrType, errBytes []byte) {
func (l *functionsListener) setError(ctx context.Context, requestId RequestID, errType ErrType, errBytes []byte) {
if errType == INTERNAL_ERROR {
promRequestInternalError.WithLabelValues(l.contractAddressHex).Inc()
} else {
Expand All @@ -267,23 +274,23 @@ func (l *FunctionsListener) setError(ctx context.Context, requestId RequestID, e
}
}

func (l *FunctionsListener) getMaxCBORsize(flags RequestFlags) uint32 {
func (l *functionsListener) getMaxCBORsize(flags RequestFlags) uint32 {
idx := flags[FlagCBORMaxSize]
if int(idx) >= len(l.pluginConfig.MaxRequestSizesList) {
return l.pluginConfig.MaxRequestSizeBytes // deprecated
}
return l.pluginConfig.MaxRequestSizesList[idx]
}

func (l *FunctionsListener) getMaxSecretsSize(flags RequestFlags) uint32 {
func (l *functionsListener) getMaxSecretsSize(flags RequestFlags) uint32 {
idx := flags[FlagSecretsMaxSize]
if int(idx) >= len(l.pluginConfig.MaxSecretsSizesList) {
return math.MaxUint32 // not enforced if not configured
}
return l.pluginConfig.MaxSecretsSizesList[idx]
}

func (l *FunctionsListener) HandleOffchainRequest(ctx context.Context, request *OffchainRequest) error {
func (l *functionsListener) HandleOffchainRequest(ctx context.Context, request *OffchainRequest) error {
if request == nil {
return errors.New("HandleOffchainRequest: received nil request")
}
Expand Down Expand Up @@ -318,11 +325,10 @@ func (l *FunctionsListener) HandleOffchainRequest(ctx context.Context, request *
}
return err
}
l.handleRequest(ctx, requestId, request.SubscriptionId, subscriptionOwner, RequestFlags{}, &request.Data)
return nil
return l.handleRequest(ctx, requestId, request.SubscriptionId, subscriptionOwner, RequestFlags{}, &request.Data)
}

func (l *FunctionsListener) handleOracleRequestV1(request *evmrelayTypes.OracleRequest) {
func (l *functionsListener) handleOracleRequestV1(request *evmrelayTypes.OracleRequest) {
defer l.shutdownWaitGroup.Done()
l.logger.Infow("handleOracleRequestV1: oracle request v1 received", "requestID", formatRequestId(request.RequestId))
ctx, cancel := l.getNewHandlerContext()
Expand Down Expand Up @@ -354,10 +360,13 @@ func (l *FunctionsListener) handleOracleRequestV1(request *evmrelayTypes.OracleR
l.setError(ctx, request.RequestId, USER_ERROR, []byte(err.Error()))
return
}
l.handleRequest(ctx, request.RequestId, request.SubscriptionId, request.SubscriptionOwner, request.Flags, requestData)
err = l.handleRequest(ctx, request.RequestId, request.SubscriptionId, request.SubscriptionOwner, request.Flags, requestData)
if err != nil {
l.logger.Errorw("handleOracleRequestV1: error in handleRequest()", "requestID", formatRequestId(request.RequestId), "err", err)
}
}

func (l *FunctionsListener) parseCBOR(requestId RequestID, cborData []byte, maxSizeBytes uint32) (*RequestData, error) {
func (l *functionsListener) parseCBOR(requestId RequestID, cborData []byte, maxSizeBytes uint32) (*RequestData, error) {
if maxSizeBytes > 0 && uint32(len(cborData)) > maxSizeBytes {
l.logger.Errorw("request too big", "requestID", formatRequestId(requestId), "requestSize", len(cborData), "maxRequestSize", maxSizeBytes)
return nil, fmt.Errorf("request too big (max %d bytes)", maxSizeBytes)
Expand All @@ -372,7 +381,8 @@ func (l *FunctionsListener) parseCBOR(requestId RequestID, cborData []byte, maxS
return &requestData, nil
}

func (l *FunctionsListener) handleRequest(ctx context.Context, requestID RequestID, subscriptionId uint64, subscriptionOwner common.Address, flags RequestFlags, requestData *RequestData) {
// Handle secret fetching/decryption and functions computation. Return error only for internal errors.
func (l *functionsListener) handleRequest(ctx context.Context, requestID RequestID, subscriptionId uint64, subscriptionOwner common.Address, flags RequestFlags, requestData *RequestData) error {
startTime := time.Now()
defer func() {
duration := time.Since(startTime)
Expand All @@ -385,34 +395,34 @@ func (l *FunctionsListener) handleRequest(ctx context.Context, requestID Request
if err != nil {
l.logger.Errorw("failed to create ExternalAdapterClient", "requestID", requestIDStr, "err", err)
l.setError(ctx, requestID, INTERNAL_ERROR, []byte(err.Error()))
return
return err
}

nodeProvidedSecrets, userErr, internalErr := l.getSecrets(ctx, eaClient, requestID, subscriptionOwner, requestData)
if internalErr != nil {
l.logger.Errorw("internal error during getSecrets", "requestID", requestIDStr, "err", internalErr)
l.setError(ctx, requestID, INTERNAL_ERROR, []byte(internalErr.Error()))
return
return internalErr
}
if userErr != nil {
l.logger.Debugw("user error during getSecrets", "requestID", requestIDStr, "err", userErr)
l.setError(ctx, requestID, USER_ERROR, []byte(userErr.Error()))
return
return nil // user error
}

maxSecretsSize := l.getMaxSecretsSize(flags)
if uint32(len(nodeProvidedSecrets)) > maxSecretsSize {
l.logger.Errorw("secrets size too big", "requestID", requestIDStr, "secretsSize", len(nodeProvidedSecrets), "maxSecretsSize", maxSecretsSize)
l.setError(ctx, requestID, USER_ERROR, []byte("secrets size too big"))
return
return nil // user error
}

computationResult, computationError, domains, err := eaClient.RunComputation(ctx, requestIDStr, l.job.Name.ValueOrZero(), subscriptionOwner.Hex(), subscriptionId, flags, nodeProvidedSecrets, requestData)

if err != nil {
l.logger.Errorw("internal adapter error", "requestID", requestIDStr, "err", err)
l.setError(ctx, requestID, INTERNAL_ERROR, []byte(err.Error()))
return
return err
}

if len(computationError) == 0 && len(computationResult) == 0 {
Expand All @@ -438,11 +448,13 @@ func (l *FunctionsListener) handleRequest(ctx context.Context, requestID Request
l.logger.Debugw("saving computation result", "requestID", requestIDStr)
if err2 := l.pluginORM.SetResult(requestID, computationResult, time.Now(), pg.WithParentCtx(ctx)); err2 != nil {
l.logger.Errorw("call to SetResult failed", "requestID", requestIDStr, "err", err2)
return err2
}
}
return nil
}

func (l *FunctionsListener) handleOracleResponseV1(response *evmrelayTypes.OracleResponse) {
func (l *functionsListener) handleOracleResponseV1(response *evmrelayTypes.OracleResponse) {
defer l.shutdownWaitGroup.Done()
l.logger.Infow("oracle response v1 received", "requestID", formatRequestId(response.RequestId))

Expand All @@ -454,7 +466,7 @@ func (l *FunctionsListener) handleOracleResponseV1(response *evmrelayTypes.Oracl
promRequestConfirmed.WithLabelValues(l.contractAddressHex).Inc()
}

func (l *FunctionsListener) timeoutRequests() {
func (l *functionsListener) timeoutRequests() {
defer l.shutdownWaitGroup.Done()
timeoutSec, freqSec, batchSize := l.pluginConfig.RequestTimeoutSec, l.pluginConfig.RequestTimeoutCheckFrequencySec, l.pluginConfig.RequestTimeoutBatchLookupSize
if timeoutSec == 0 || freqSec == 0 || batchSize == 0 {
Expand Down Expand Up @@ -490,7 +502,7 @@ func (l *FunctionsListener) timeoutRequests() {
}
}

func (l *FunctionsListener) pruneRequests() {
func (l *functionsListener) pruneRequests() {
defer l.shutdownWaitGroup.Done()
maxStoredRequests, freqSec, batchSize := l.pluginConfig.PruneMaxStoredRequests, l.pluginConfig.PruneCheckFrequencySec, l.pluginConfig.PruneBatchSize
if maxStoredRequests == 0 {
Expand Down Expand Up @@ -532,7 +544,7 @@ func (l *FunctionsListener) pruneRequests() {
}
}

func (l *FunctionsListener) reportSourceCodeDomains(requestId RequestID, domains []string) {
func (l *functionsListener) reportSourceCodeDomains(requestId RequestID, domains []string) {
r := &telem.FunctionsRequest{
RequestId: formatRequestId(requestId),
NodeAddress: l.job.OCR2OracleSpec.TransmitterID.ValueOrZero(),
Expand All @@ -547,7 +559,7 @@ func (l *FunctionsListener) reportSourceCodeDomains(requestId RequestID, domains
}
}

func (l *FunctionsListener) getSecrets(ctx context.Context, eaClient ExternalAdapterClient, requestID RequestID, subscriptionOwner common.Address, requestData *RequestData) (decryptedSecrets string, userError, internalError error) {
func (l *functionsListener) getSecrets(ctx context.Context, eaClient ExternalAdapterClient, requestID RequestID, subscriptionOwner common.Address, requestData *RequestData) (decryptedSecrets string, userError, internalError error) {
if l.decryptor == nil {
l.logger.Warn("Decryptor not configured")
return "", nil, nil
Expand Down
21 changes: 20 additions & 1 deletion core/services/functions/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ import (
)

type FunctionsListenerUniverse struct {
service *functions_service.FunctionsListener
service functions_service.FunctionsListener
bridgeAccessor *functions_mocks.BridgeAccessor
eaClient *functions_mocks.ExternalAdapterClient
pluginORM *functions_mocks.ORM
Expand Down Expand Up @@ -219,6 +219,25 @@ func TestFunctionsListener_HandleOffchainRequest_Invalid(t *testing.T) {
require.Error(t, uni.service.HandleOffchainRequest(testutils.Context(t), request))
}

func TestFunctionsListener_HandleOffchainRequest_InternalError(t *testing.T) {
testutils.SkipShortDB(t)
t.Parallel()
uni := NewFunctionsListenerUniverse(t, 0, 1_000_000)
uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil)
uni.bridgeAccessor.On("NewExternalAdapterClient").Return(uni.eaClient, nil)
uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil, nil, errors.New("error"))
uni.pluginORM.On("SetError", RequestID, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)

request := &functions_service.OffchainRequest{
RequestId: RequestID[:],
RequestInitiator: SubscriptionOwner.Bytes(),
SubscriptionId: uint64(SubscriptionID),
SubscriptionOwner: SubscriptionOwner.Bytes(),
Data: functions_service.RequestData{},
}
require.Error(t, uni.service.HandleOffchainRequest(testutils.Context(t), request))
}

func TestFunctionsListener_HandleOracleRequestV1_ComputationError(t *testing.T) {
testutils.SkipShortDB(t)
t.Parallel()
Expand Down
71 changes: 71 additions & 0 deletions core/services/functions/mocks/functions_listener.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 738146e

Please sign in to comment.