Skip to content

Commit

Permalink
DE-1144 refactor httphelpers (#334)
Browse files Browse the repository at this point in the history
  • Loading branch information
vtopc authored Oct 29, 2024
1 parent 609ce8b commit 2285d22
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 28 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ all: test

.PHONY: test
test:
export GO111MODULE=on; go test . -v
export GO111MODULE=on; go test . -race -count=1

.PHONY: godoc
godoc:
Expand Down
60 changes: 35 additions & 25 deletions httphelpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"net/url"
Expand Down Expand Up @@ -131,6 +130,7 @@ func (f *urlEncodedPayload) getPayloadBuffer() (*bytes.Buffer, error) {
for _, keyVal := range f.Values {
data.Add(keyVal.key, keyVal.value)
}

return bytes.NewBufferString(data.Encode()), nil
}

Expand Down Expand Up @@ -177,6 +177,7 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) {

for _, keyVal := range f.Values {
if tmp, err := writer.CreateFormField(keyVal.key); err == nil {
// TODO(DE-1139): handle error:
tmp.Write([]byte(keyVal.value))
} else {
return nil, err
Expand All @@ -186,7 +187,9 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) {
for _, file := range f.Files {
if tmp, err := writer.CreateFormFile(file.key, path.Base(file.value)); err == nil {
if fp, err := os.Open(file.value); err == nil {
// TODO(DE-1139): defer in a loop:
defer fp.Close()
// TODO(DE-1139): handle error:
io.Copy(tmp, fp)
} else {
return nil, err
Expand All @@ -198,7 +201,9 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) {

for _, file := range f.ReadClosers {
if tmp, err := writer.CreateFormFile(file.key, file.name); err == nil {
// TODO(DE-1139): defer in a loop:
defer file.value.Close()
// TODO(DE-1139): handle error:
io.Copy(tmp, file.value)
} else {
return nil, err
Expand All @@ -208,6 +213,7 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) {
for _, buff := range f.Buffers {
if tmp, err := writer.CreateFormFile(buff.key, buff.name); err == nil {
r := bytes.NewReader(buff.value)
// TODO(DE-1139): handle error:
io.Copy(tmp, r)
} else {
return nil, err
Expand All @@ -221,8 +227,10 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) {

func (f *formDataPayload) getContentType() string {
if f.contentType == "" {
// TODO(DE-1139): handle error:
f.getPayloadBuffer()
}

return f.contentType
}

Expand All @@ -234,23 +242,23 @@ func (r *httpRequest) addHeader(name, value string) {
}

func (r *httpRequest) makeGetRequest(ctx context.Context) (*httpResponse, error) {
return r.makeRequest(ctx, "GET", nil)
return r.makeRequest(ctx, http.MethodGet, nil)
}

func (r *httpRequest) makePostRequest(ctx context.Context, payload payload) (*httpResponse, error) {
return r.makeRequest(ctx, "POST", payload)
return r.makeRequest(ctx, http.MethodPost, payload)
}

func (r *httpRequest) makePutRequest(ctx context.Context, payload payload) (*httpResponse, error) {
return r.makeRequest(ctx, "PUT", payload)
return r.makeRequest(ctx, http.MethodPut, payload)
}

func (r *httpRequest) makeDeleteRequest(ctx context.Context) (*httpResponse, error) {
return r.makeRequest(ctx, "DELETE", nil)
return r.makeRequest(ctx, http.MethodDelete, nil)
}

func (r *httpRequest) NewRequest(ctx context.Context, method string, payload payload) (*http.Request, error) {
url, err := r.generateUrlWithParameters()
uri, err := r.generateUrlWithParameters()
if err != nil {
return nil, err
}
Expand All @@ -263,13 +271,12 @@ func (r *httpRequest) NewRequest(ctx context.Context, method string, payload pay
} else {
body = nil
}
req, err := http.NewRequest(method, url, body)

req, err := http.NewRequestWithContext(ctx, method, uri, body)
if err != nil {
return nil, err
}

req = req.WithContext(ctx)

if payload != nil && payload.getContentType() != "" {
req.Header.Add("Content-Type", payload.getContentType())
}
Expand All @@ -286,6 +293,7 @@ func (r *httpRequest) NewRequest(ctx context.Context, method string, payload pay
}
req.Header.Add(header, value)
}

return req, nil
}

Expand All @@ -305,52 +313,53 @@ func (r *httpRequest) makeRequest(ctx context.Context, method string, payload pa
}
}

response := httpResponse{}

resp, err := r.Client.Do(req)
if resp != nil {
response.Code = resp.StatusCode
}
if err != nil {
if urlErr, ok := err.(*url.Error); ok {
if urlErr.Err == io.EOF {
return nil, errors.Wrap(err, "remote server prematurely closed connection")
}
var urlErr *url.Error
if errors.As(err, &urlErr) && urlErr != nil && errors.Is(urlErr.Err, io.EOF) {
return nil, errors.Wrap(err, "remote server prematurely closed connection")
}

return nil, errors.Wrap(err, "while making http request")
}

defer resp.Body.Close()
responseBody, err := ioutil.ReadAll(resp.Body)

response := httpResponse{
Code: resp.StatusCode,
}

responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "while reading response body")
}

response.Data = responseBody

return &response, nil
}

func (r *httpRequest) generateUrlWithParameters() (string, error) {
url, err := url.Parse(r.URL)
uri, err := url.Parse(r.URL)
if err != nil {
return "", err
}

if !validURL.MatchString(url.Path) {
return "", errors.New(`BaseAPI must end with a /v1, /v2, /v3 or /v4; setBaseAPI("https://host/v3")`)
if !validURL.MatchString(uri.Path) {
return "", errors.New(`APIBase must end with a /v1, /v2, /v3 or /v4; SetAPIBase("https://host/v3")`)
}

q := url.Query()
q := uri.Query()
if r.Parameters != nil && len(r.Parameters) > 0 {
for name, values := range r.Parameters {
for _, value := range values {
q.Add(name, value)
}
}
}
url.RawQuery = q.Encode()
uri.RawQuery = q.Encode()

return url.String(), nil
return uri.String(), nil
}

func (r *httpRequest) curlString(req *http.Request, p payload) string {
Expand Down Expand Up @@ -384,5 +393,6 @@ func (r *httpRequest) curlString(req *http.Request, p payload) string {
}
}
}

return strings.Join(parts, " ")
}
4 changes: 2 additions & 2 deletions mailgun_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/facebookgo/ensure"
"github.com/mailgun/mailgun-go/v4"
"github.com/stretchr/testify/assert"
)

const domain = "valid-mailgun-domain"
Expand All @@ -33,8 +34,7 @@ func TestInvalidBaseAPI(t *testing.T) {

ctx := context.Background()
_, err := mg.GetDomain(ctx, "unknown.domain")
ensure.NotNil(t, err)
ensure.DeepEqual(t, err.Error(), `BaseAPI must end with a /v1, /v2, /v3 or /v4; setBaseAPI("https://host/v3")`)
assert.EqualError(t, err, `APIBase must end with a /v1, /v2, /v3 or /v4; SetAPIBase("https://host/v3")`)
}

func TestValidBaseAPI(t *testing.T) {
Expand Down

0 comments on commit 2285d22

Please sign in to comment.