diff --git a/utils/apiclient/client.go b/utils/apiclient/client.go index a61bec5cfc..5b39da610b 100644 --- a/utils/apiclient/client.go +++ b/utils/apiclient/client.go @@ -1,6 +1,7 @@ package apiclient import ( + "encoding/base64" "encoding/json" "fmt" "io/ioutil" @@ -10,17 +11,20 @@ import ( "github.com/pkg/errors" ) -func (c *APIClient) getRequest(endpoint string, queryParams url.Values) (interface{}, error) { - client := c.HTTP - if client == nil { - client = &http.Client{} - } - +func (c *APIClient) createRequestBody(endpoint string, queryParams url.Values) (*http.Request, error) { fullURL := c.url(endpoint, queryParams) req, err := http.NewRequest("GET", fullURL, nil) if err != nil { return nil, errors.Wrap(err, "http GET request creation failed") } + return req, nil +} + +func (c *APIClient) callAPI(req *http.Request) (interface{}, error) { + client := c.HTTP + if client == nil { + client = &http.Client{} + } resp, err := client.Do(req) if err != nil { @@ -45,6 +49,50 @@ func (c *APIClient) getRequest(endpoint string, queryParams url.Values) (interfa return result, nil } +func setHeaders(req *http.Request, args map[string]interface{}) { + for key, value := range args { + strValue, ok := value.(string) + if !ok { + fmt.Printf("Skipping non-string value for header %s\n", key) + continue + } + + req.Header.Set(key, strValue) + } +} + +func setAuthHeaders(req *http.Request, authType string, args map[string]interface{}) error { + switch authType { + case "basic": + username, ok := args["username"].(string) + if !ok { + return fmt.Errorf("missing or invalid username") + } + password, ok := args["password"].(string) + if !ok { + return fmt.Errorf("missing or invalid password") + } + + authHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) + setHeaders(req, map[string]interface{}{ + "Authorization": authHeader, + }) + + case "api_key": + apiKey, ok := args["api_key"].(string) + if !ok { + return fmt.Errorf("missing or invalid API key") + } + setHeaders(req, map[string]interface{}{ + "Authorization": apiKey, + }) + + default: + return fmt.Errorf("unsupported auth type: %s", authType) + } + return nil +} + func (c *APIClient) url(endpoint string, qstr url.Values) string { return fmt.Sprintf("%s/%s?%s", c.BaseURL, endpoint, qstr.Encode()) } diff --git a/utils/apiclient/client_test.go b/utils/apiclient/client_test.go index 9dbbc61ffa..8db1193cc7 100644 --- a/utils/apiclient/client_test.go +++ b/utils/apiclient/client_test.go @@ -22,7 +22,7 @@ func Test_url(t *testing.T) { assert.Equal(t, "https://stellar.org/federation?acct=2382376&federation_type=bank_account&swift=BOPBPHMM&type=forward", furl) } -func Test_getRequest(t *testing.T) { +func Test_callAPI(t *testing.T) { friendbotFundResponse := `{"key": "value"}` hmock := httptest.NewClient() @@ -41,10 +41,18 @@ func Test_getRequest(t *testing.T) { qstr.Add("swift", "BOPBPHMM") qstr.Add("acct", "2382376") - result, err := c.getRequest("federation", qstr) + req, err := c.createRequestBody("federation", qstr) if err != nil { t.Fatal(err) } + setAuthHeaders(req, "api_key", map[string]interface{}{"api_key": "test_api_key"}) + assert.Equal(t, "test_api_key", req.Header.Get("Authorization")) + + result, err := c.callAPI(req) + if err != nil { + t.Fatal(err) + } + expected := map[string]interface{}{"key": "value"} assert.Equal(t, expected, result) } diff --git a/utils/apiclient/main.go b/utils/apiclient/main.go index 707b019dac..952f17fa43 100644 --- a/utils/apiclient/main.go +++ b/utils/apiclient/main.go @@ -12,7 +12,6 @@ type HTTP interface { } type APIClient struct { - BaseURL string - AuthToken string - HTTP HTTP + BaseURL string + HTTP HTTP }