From 7ca8b0d45f2b0ca851528cbc57e24408be683aa6 Mon Sep 17 00:00:00 2001 From: Gabriel Paradiso Date: Thu, 22 Feb 2024 10:53:33 +0100 Subject: [PATCH] fix: use semver to compare versions --- .../handlers/functions/allowlist/allowlist.go | 27 +++-------- .../functions/allowlist/allowlist_test.go | 48 +++++++++---------- go.mod | 2 +- 3 files changed, 31 insertions(+), 46 deletions(-) diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist.go b/core/services/gateway/handlers/functions/allowlist/allowlist.go index 6bfd3a8f534..020de2359c2 100644 --- a/core/services/gateway/handlers/functions/allowlist/allowlist.go +++ b/core/services/gateway/handlers/functions/allowlist/allowlist.go @@ -13,6 +13,7 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" + "golang.org/x/mod/semver" "github.com/smartcontractkit/chainlink-common/pkg/services" evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" @@ -219,12 +220,12 @@ func (a *onchainAllowlist) updateFromContractV1(ctx context.Context, blockNum *b return errors.Wrap(err, "failed to fetch the tos contract type and version") } - batchProcessingAllowed, err := ContractVersionIsSmallerOrEqual(tosContractMinBatchProcessingVersion, typeAndVersion) + currentVersion, err := ExtractContractVersion(typeAndVersion) if err != nil { - return errors.Wrap(err, "failed to compare contract version") + return fmt.Errorf("failed to extract version: %w", err) } - if batchProcessingAllowed { + if semver.Compare(tosContractMinBatchProcessingVersion, currentVersion) <= 0 { err = a.syncBlockedSenders(ctx, tosContract, blockNum) if err != nil { return errors.Wrap(err, "failed to sync the stored allowed and blocked senders") @@ -369,23 +370,7 @@ func (a *onchainAllowlist) loadStoredAllowedSenderList() { a.update(allowedList) } -// ContractVersionIsSmallerOrEqual receives two stringify contract versions s1 and s2 with the following format: `v(\d+).(\d+).(\d+)` -// and returns true in case s1 <= s2 -func ContractVersionIsSmallerOrEqual(s1 string, s2 string) (bool, error) { - versionS1, err := extractContractVersion(s1) - if err != nil { - return false, fmt.Errorf("failed to extract version: %w", err) - } - - versionS2, err := extractContractVersion(s2) - if err != nil { - return false, fmt.Errorf("failed to extract version: %w", err) - } - - return (versionS1 <= versionS2), nil -} - -func extractContractVersion(str string) (string, error) { +func ExtractContractVersion(str string) (string, error) { pattern := `v(\d+).(\d+).(\d+)` re := regexp.MustCompile(pattern) @@ -393,5 +378,5 @@ func extractContractVersion(str string) (string, error) { if len(match) != 4 { return "", fmt.Errorf("version not found in string: %s", str) } - return match[1] + match[2] + match[3], nil + return fmt.Sprintf("v%s.%s.%s", match[1], match[2], match[3]), nil } diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist_test.go b/core/services/gateway/handlers/functions/allowlist/allowlist_test.go index b81059bc7a2..735c0bff7dc 100644 --- a/core/services/gateway/handlers/functions/allowlist/allowlist_test.go +++ b/core/services/gateway/handlers/functions/allowlist/allowlist_test.go @@ -3,6 +3,7 @@ package allowlist_test import ( "context" "encoding/hex" + "fmt" "math/big" "testing" "time" @@ -316,51 +317,50 @@ func TestUpdateFromContract(t *testing.T) { } -func TestContractVersionIsSmallerOrEqual(t *testing.T) { +func TestExtractContractVersion(t *testing.T) { type tc struct { name string - version1 string - version2 string - expectedResult bool + versionStr string + expectedResult string expectedError *string } - var errInvalidVersion = "failed to extract version: version not found in string: invalid_version" + var errInvalidVersion = func(v string) *string { + ev := fmt.Sprintf("version not found in string: %s", v) + return &ev + } + tcs := []tc{ { - name: "OK-bigger_version", - version1: "v1.0.1", - version2: "v1.0.0", - expectedResult: false, + name: "OK-Tos_type_and_version", + versionStr: "Functions Terms of Service Allow List v1.1.0", + expectedResult: "v1.1.0", expectedError: nil, }, { - name: "OK-smaller_version", - version1: "v1.1.0", - version2: "v2.0.0", - expectedResult: true, + name: "OK-double_digits_minor", + versionStr: "Functions Terms of Service Allow List v1.20.0", + expectedResult: "v1.20.0", expectedError: nil, }, { - name: "OK-same_version", - version1: "v1.2.0", - version2: "v1.2.0", - expectedResult: true, - expectedError: nil, + name: "NOK-invalid_version", + versionStr: "invalid_version", + expectedResult: "", + expectedError: errInvalidVersion("invalid_version"), }, { - name: "NOK-invalid_version", - version1: "invalid_version", - version2: "v1.1.0", - expectedResult: false, - expectedError: &errInvalidVersion, + name: "NOK-incomplete_version", + versionStr: "v2.0", + expectedResult: "", + expectedError: errInvalidVersion("v2.0"), }, } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - actualResult, actualError := allowlist.ContractVersionIsSmallerOrEqual(tc.version1, tc.version2) + actualResult, actualError := allowlist.ExtractContractVersion(tc.versionStr) require.Equal(t, tc.expectedResult, actualResult) if tc.expectedError != nil { diff --git a/go.mod b/go.mod index be354048dcd..47d24b80805 100644 --- a/go.mod +++ b/go.mod @@ -97,6 +97,7 @@ require ( go.uber.org/zap v1.26.0 golang.org/x/crypto v0.19.0 golang.org/x/exp v0.0.0-20240213143201-ec583247a57a + golang.org/x/mod v0.15.0 golang.org/x/sync v0.6.0 golang.org/x/term v0.17.0 golang.org/x/text v0.14.0 @@ -315,7 +316,6 @@ require ( go.opentelemetry.io/proto/otlp v1.0.0 // indirect go.uber.org/ratelimit v0.2.0 // indirect golang.org/x/arch v0.7.0 // indirect - golang.org/x/mod v0.15.0 // indirect golang.org/x/net v0.21.0 // indirect golang.org/x/oauth2 v0.17.0 // indirect golang.org/x/sys v0.17.0 // indirect