diff --git a/types/operator.go b/types/operator.go index e0187e9f..7b7ada9a 100644 --- a/types/operator.go +++ b/types/operator.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "math/big" + "net/http" "github.com/Layr-Labs/eigensdk-go/crypto/bls" "github.com/Layr-Labs/eigensdk-go/utils" @@ -46,7 +47,8 @@ func (o Operator) Validate() error { return utils.WrapError(ErrInvalidMetadataUrl, err) } - body, err := utils.ReadPublicURL(o.MetadataUrl) + client := &http.Client{} + body, err := utils.ReadPublicURL(o.MetadataUrl, client) if err != nil { return utils.WrapError(ErrReadingMetadataUrlResponse, err) } diff --git a/types/operator_metadata.go b/types/operator_metadata.go index b80a5881..fd2ccada 100644 --- a/types/operator_metadata.go +++ b/types/operator_metadata.go @@ -1,6 +1,8 @@ package types import ( + "net/http" + "github.com/Layr-Labs/eigensdk-go/utils" ) @@ -45,7 +47,8 @@ func (om *OperatorMetadata) Validate() error { return ErrLogoRequired } - if err = utils.IsImageURL(om.Logo); err != nil { + client := &http.Client{} + if err = utils.IsImageURL(om.Logo, client); err != nil { return err } diff --git a/utils/utils.go b/utils/utils.go index 5b54c65f..a062f449 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -62,6 +62,7 @@ func EcdsaPrivateKeyToAddress(privateKey *ecdsa.PrivateKey) (gethcommon.Address, return crypto.PubkeyToAddress(*publicKeyECDSA), nil } +// RoundUpDivideBig divides two positive big.Int numbers and rounds up the result. func RoundUpDivideBig(a, b *big.Int) *big.Int { one := new(big.Int).SetUint64(1) res := new(big.Int) @@ -75,14 +76,13 @@ func IsValidEthereumAddress(address string) bool { return ethAddrPattern.MatchString(address) } -func ReadPublicURL(url string) ([]byte, error) { - // allow no redirects - httpClient := http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - Timeout: 3 * time.Second, +func ReadPublicURL(url string, httpClient *http.Client) ([]byte, error) { + // Allow no redirects + httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse } + httpClient.Timeout = 3 * time.Second + resp, err := httpClient.Get(url) if err != nil { return []byte{}, err @@ -170,7 +170,7 @@ func CheckIfUrlIsValid(rawUrl string) error { return nil } -func IsImageURL(urlString string) error { +func IsImageURL(urlString string, httpClient *http.Client) error { // Parse the URL parsedURL, err := url.Parse(urlString) if err != nil { @@ -186,7 +186,7 @@ func IsImageURL(urlString string) error { // Check if the extension is in the list of image extensions for _, imgExt := range ImageExtensions { if strings.EqualFold(extension, imgExt) { - imageBytes, err := ReadPublicURL(urlString) + imageBytes, err := ReadPublicURL(urlString, httpClient) if err != nil { return err } diff --git a/utils/utils_test.go b/utils/utils_test.go index 0f7e4885..7a0f0c5c 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -1,51 +1,472 @@ package utils import ( + "math/big" + "net/http" + "net/http/httptest" + "os" "strings" "testing" + gethcommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/assert" ) +func TestReadFile(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content []byte + expectError bool + }{ + { + name: "Read existing file", + content: []byte("Hello, world!"), + }, + { + name: "File does not exist", + content: nil, + expectError: true, + }, + { + name: "Empty file", + content: []byte(""), + }, + { + name: "Large file", + content: make([]byte, 1024*1024), + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var filePath string + if tc.content != nil { + // Create a temporary file + tmpFile, err := os.CreateTemp("", "testfile") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // Write content to the file + _, err = tmpFile.Write(tc.content) + assert.NoError(t, err) + + // Close the file + err = tmpFile.Close() + assert.NoError(t, err) + + filePath = tmpFile.Name() + } else { + // Non-existent file path + filePath = "nonexistentfile.txt" + } + + // Read the file using ReadFile function + readContent, err := ReadFile(filePath) + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.content, readContent) + } + }) + } +} + +func TestReadYamlConfig(t *testing.T) { + t.Parallel() + + type Config struct { + Name string `yaml:"name"` + Age int `yaml:"age"` + } + + tests := []struct { + name string + yamlContent string + expectError bool + expected Config + }{ + { + name: "Valid YAML", + yamlContent: ` +name: John Doe +age: 30 +`, + expected: Config{Name: "John Doe", Age: 30}, + }, + { + name: "Missing fields", + yamlContent: ` +name: Alice +`, + expected: Config{Name: "Alice", Age: 0}, + }, + { + name: "Empty YAML", + yamlContent: ``, + expected: Config{}, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Create a temporary YAML file + tmpFile, err := os.CreateTemp("", "testconfig*.yaml") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + _, err = tmpFile.Write([]byte(tc.yamlContent)) + assert.NoError(t, err) + err = tmpFile.Close() + assert.NoError(t, err) + + var config Config + err = ReadYamlConfig(tmpFile.Name(), &config) + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, config) + } + }) + } +} + +func TestReadJsonConfig(t *testing.T) { + t.Parallel() + + type Config struct { + Name string `json:"name"` + Age int `json:"age"` + } + + tests := []struct { + name string + jsonContent string + expectError bool + expected Config + }{ + { + name: "Valid JSON", + jsonContent: `{ + "name": "Jane Doe", + "age": 25 +}`, + expected: Config{Name: "Jane Doe", Age: 25}, + }, + { + name: "Missing fields", + jsonContent: `{"name": "Alice"}`, + expected: Config{Name: "Alice", Age: 0}, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Create a temporary JSON file + tmpFile, err := os.CreateTemp("", "testconfig*.json") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + _, err = tmpFile.Write([]byte(tc.jsonContent)) + assert.NoError(t, err) + err = tmpFile.Close() + assert.NoError(t, err) + + var config Config + err = ReadJsonConfig(tmpFile.Name(), &config) + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, config) + } + }) + } +} + +func TestEcdsaPrivateKeyToAddress(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + privateKeyHex string + expectError bool + expectedAddr gethcommon.Address + }{ + { + name: "Valid private key", + privateKeyHex: "4c0883a69102937d6231471b5dbb6204fe512961708279e6e82cc073aa8aa1a9", + expectedAddr: gethcommon.HexToAddress("0x8019FFe7A44A943c3a507C94D418DA3eD829f04d"), + }, + { + name: "Invalid private key (non-hex)", + privateKeyHex: "invalidkey", + expectError: true, + }, + { + name: "Invalid private key (empty)", + privateKeyHex: "", + expectError: true, + }, + { + name: "Valid private key with 0x prefix", + privateKeyHex: "0x8f2a5594902d6e44c069ad1c6e42c3e92a0416b4cf8ba689344844ab63b0f5f8", + expectedAddr: gethcommon.HexToAddress("0x8089DD5E304F65219dF307357256C07896f87703"), + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + privateKeyHex := strings.TrimPrefix(tc.privateKeyHex, "0x") + privateKey, err := crypto.HexToECDSA(privateKeyHex) + if tc.expectError { + assert.Error(t, err) + return + } + assert.NoError(t, err) + + address, err := EcdsaPrivateKeyToAddress(privateKey) + assert.NoError(t, err) + assert.Equal(t, tc.expectedAddr.Hex(), address.Hex()) + }) + } +} + +func TestRoundUpDivideBig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + a *big.Int + b *big.Int + expected *big.Int + expectPanic bool + }{ + { + name: "Divide 5 by 2", + a: big.NewInt(5), + b: big.NewInt(2), + expected: big.NewInt(3), + expectPanic: false, + }, + { + name: "Divide 10 by 3", + a: big.NewInt(10), + b: big.NewInt(3), + expected: big.NewInt(4), + expectPanic: false, + }, + { + name: "Divide 100 by 33", + a: big.NewInt(100), + b: big.NewInt(33), + expected: big.NewInt(4), + expectPanic: false, + }, + { + name: "Divide 0 by 1", + a: big.NewInt(0), + b: big.NewInt(1), + expected: big.NewInt(0), + expectPanic: false, + }, + { + name: "Divide 1 by 1", + a: big.NewInt(1), + b: big.NewInt(1), + expected: big.NewInt(1), + expectPanic: false, + }, + { + name: "Divide by zero (100/0)", + a: big.NewInt(100), + b: big.NewInt(0), + expected: nil, + expectPanic: true, + }, + { + name: "Both zero (0/0)", + a: big.NewInt(0), + b: big.NewInt(0), + expected: nil, + expectPanic: true, + }, + { + name: "Large division", + a: big.NewInt(999999999999999999), + b: big.NewInt(3), + expected: big.NewInt(333333333333333333), + expectPanic: false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if tc.expectPanic { + assert.Panics(t, func() { + RoundUpDivideBig(new(big.Int).Set(tc.a), tc.b) + }) + } else { + result := RoundUpDivideBig(new(big.Int).Set(tc.a), tc.b) + assert.Equal(t, tc.expected.String(), result.String(), "RoundUpDivideBig(%s, %s)", tc.a.String(), tc.b.String()) + } + }) + } +} + +func TestReadPublicURL(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Serve different content based on the request URL + if r.URL.Path == "/small" { + w.WriteHeader(http.StatusOK) + w.Write(make([]byte, 1024)) //nolint:errcheck + } else if r.URL.Path == "/large" { + w.WriteHeader(http.StatusOK) + w.Write(make([]byte, 2*1024*1024)) //nolint:errcheck + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + tests := []struct { + name string + urlPath string + expectedErr error + }{ + { + name: "request < 1mb", + urlPath: "/small", + expectedErr: nil, + }, + { + name: "request too large", + urlPath: "/large", + expectedErr: ErrResponseTooLarge, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := server.URL + tt.urlPath + client := &http.Client{} + _, err := ReadPublicURL(url, client) + assert.Equal(t, tt.expectedErr, err) + }) + } +} + +func TestIsImageURL(t *testing.T) { + t.Parallel() + + // Create an httptest server that serves a valid PNG image + pngData := []byte{137, 80, 78, 71, 13, 10, 26, 10} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, ".png") { + w.Header().Set("Content-Type", "image/png") + w.WriteHeader(http.StatusOK) + w.Write(pngData) //nolint:errcheck + } else { + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + tests := []struct { + name string + url string + expectedErr error + }{ + { + name: "valid PNG image URL", + url: server.URL + "/image.png", + expectedErr: nil, + }, + { + name: "invalid image extension", + url: server.URL + "/image.jpg", + expectedErr: ErrInvalidImageExtension, + }, + { + name: "non-image URL", + url: server.URL + "/notfound", + expectedErr: ErrInvalidImageExtension, + }, + { + name: "image with wrong MIME type", + url: server.URL + "/image.jpg", + expectedErr: ErrInvalidImageExtension, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &http.Client{} + err := IsImageURL(tt.url, client) + assert.Equal(t, tt.expectedErr, err) + }) + } +} + func TestIsValidEthereumAddress(t *testing.T) { t.Parallel() tests := []struct { - name string - address string - expected bool + name string + address string + isValid bool }{ { - name: "valid address", - address: "0x1234567890abcdef1234567890abcdef12345678", - expected: true, + name: "valid address", + address: "0x1234567890abcdef1234567890abcdef12345678", + isValid: true, }, { - name: "uppercase", - address: "0x1234567890ABCDEF1234567890ABCDEF12345678", - expected: true, + name: "uppercase", + address: "0x1234567890ABCDEF1234567890ABCDEF12345678", + isValid: true, }, { - name: "too short", - address: "0x1234567890abcdef1234567890abcdef123456", - expected: false, + name: "too short", + address: "0x1234567890abcdef1234567890abcdef123456", + isValid: false, }, { - name: "missing 0x prefix", - address: "001234567890abcdef1234567890abcdef12345678", - expected: false, + name: "missing 0x prefix", + address: "001234567890abcdef1234567890abcdef12345678", + isValid: false, }, { - name: "non-hex characters", - address: "0x1234567890abcdef1234567890abcdef123ÅÅÅÅÅ", - expected: false, + name: "non-hex characters", + address: "0x1234567890abcdef1234567890abcdef123ÅÅÅÅÅ", + isValid: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := IsValidEthereumAddress(tt.address) - assert.Equal(t, tt.expected, result) + assert.Equal(t, tt.isValid, result) }) } } @@ -141,34 +562,6 @@ func TestCheckIfUrlIsValid(t *testing.T) { } } -func TestReadPublicURL(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - url string - expectedErr error - }{ - { - name: "request < 1mb", - url: "https://raw.githubusercontent.com/shrimalmadhur/metadata/main/logo.png", - expectedErr: nil, - }, - { - name: "request too large", - url: "https://raw.githubusercontent.com/shrimalmadhur/metadata/main/2mb.png", - expectedErr: ErrResponseTooLarge, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := ReadPublicURL(tt.url) - assert.Equal(t, tt.expectedErr, err) - }) - } -} - func TestValidateText(t *testing.T) { t.Parallel()