diff --git a/go.mod b/go.mod index 49e6019..3a579d9 100644 --- a/go.mod +++ b/go.mod @@ -6,13 +6,13 @@ require ( github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/hashicorp/go-cty v1.4.1-0.20200414143053-d3edf31b6320 - github.com/hashicorp/go-retryablehttp v0.7.7 github.com/hashicorp/terraform-plugin-docs v0.19.4 github.com/hashicorp/terraform-plugin-log v0.9.0 github.com/hashicorp/terraform-plugin-sdk/v2 v2.34.0 github.com/jarcoal/httpmock v1.3.1 github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.28.0 + golang.org/x/sync v0.8.0 ) require ( @@ -75,7 +75,6 @@ require ( golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df // indirect golang.org/x/mod v0.17.0 // indirect golang.org/x/net v0.25.0 // indirect - golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.26.0 // indirect golang.org/x/text v0.19.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect diff --git a/go.sum b/go.sum index 9a1f07c..d01abe3 100644 --- a/go.sum +++ b/go.sum @@ -84,8 +84,6 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-plugin v1.6.0 h1:wgd4KxHJTVGGqWBq4QPB1i5BZNEx9BR8+OFmHDmTk8A= github.com/hashicorp/go-plugin v1.6.0/go.mod h1:lBS5MtSSBZk0SHc66KACcjjlU6WzEVP/8pwz68aMkCI= -github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= -github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= diff --git a/internal/bitwarden/webapi/client.go b/internal/bitwarden/webapi/client.go index cab6c28..2b59181 100644 --- a/internal/bitwarden/webapi/client.go +++ b/internal/bitwarden/webapi/client.go @@ -8,17 +8,23 @@ import ( "fmt" "io" "mime/multipart" + "net/http" "net/url" "strings" "time" - "github.com/hashicorp/go-retryablehttp" "github.com/hashicorp/terraform-plugin-log/tflog" "github.com/maxlaverse/terraform-provider-bitwarden/internal/bitwarden/crypto" "github.com/maxlaverse/terraform-provider-bitwarden/internal/bitwarden/crypto/keybuilder" "github.com/maxlaverse/terraform-provider-bitwarden/internal/bitwarden/models" ) +const ( + defaultRequestTimeout = 10 * time.Second + maxConcurrentRequests = 4 + maxRetryAttempts = 3 +) + type Client interface { CreateFolder(ctx context.Context, obj Folder) (*Folder, error) CreateObject(context.Context, models.Object) (*models.Object, error) @@ -55,22 +61,22 @@ type Client interface { func NewClient(serverURL, deviceIdentifier, providerVersion string, opts ...Options) Client { c := &client{ - device: DeviceInformation(deviceIdentifier, providerVersion), - serverURL: strings.TrimSuffix(serverURL, "/"), - httpClient: retryablehttp.NewClient(), + device: DeviceInformation(deviceIdentifier, providerVersion), + serverURL: strings.TrimSuffix(serverURL, "/"), + httpClient: &http.Client{ + Transport: NewRetryRoundTripper(maxConcurrentRequests, maxRetryAttempts, defaultRequestTimeout), + }, } for _, o := range opts { o(c) } - c.httpClient.Logger = nil - c.httpClient.CheckRetry = CustomRetryPolicy - c.httpClient.HTTPClient.Timeout = 10 * time.Second + return c } type client struct { device deviceInfoWithOfficialFallback - httpClient *retryablehttp.Client + httpClient *http.Client serverURL string sessionAccessToken string } @@ -86,7 +92,7 @@ func (c *client) CreateFolder(ctx context.Context, obj Folder) (*Folder, error) func (c *client) CreateObject(ctx context.Context, obj models.Object) (*models.Object, error) { var err error - var httpReq *retryablehttp.Request + var httpReq *http.Request if len(obj.CollectionIds) != 0 { cipherCreationRequest := CreateCipherRequest{ Cipher: obj, @@ -506,8 +512,8 @@ func (c *client) Sync(ctx context.Context) (*SyncResponse, error) { return doRequest[SyncResponse](ctx, c.httpClient, httpReq) } -func (c *client) prepareRequest(ctx context.Context, reqMethod, reqUrl string, reqBody interface{}) (*retryablehttp.Request, error) { - var httpReq *retryablehttp.Request +func (c *client) prepareRequest(ctx context.Context, reqMethod, reqUrl string, reqBody interface{}) (*http.Request, error) { + var httpReq *http.Request var err error if reqBody != nil { @@ -525,12 +531,12 @@ func (c *client) prepareRequest(ctx context.Context, reqMethod, reqUrl string, r return nil, fmt.Errorf("unable to marshall request body: %w", err) } } - httpReq, err = retryablehttp.NewRequestWithContext(ctx, reqMethod, reqUrl, bytes.NewBuffer(bodyBytes)) + httpReq, err = http.NewRequestWithContext(ctx, reqMethod, reqUrl, bytes.NewBuffer(bodyBytes)) if httpReq != nil && len(contentType) > 0 { httpReq.Header.Add("Content-Type", contentType) } } else { - httpReq, err = retryablehttp.NewRequestWithContext(ctx, reqMethod, reqUrl, nil) + httpReq, err = http.NewRequestWithContext(ctx, reqMethod, reqUrl, nil) } if err != nil { @@ -549,7 +555,7 @@ func (c *client) prepareRequest(ctx context.Context, reqMethod, reqUrl string, r return httpReq, nil } -func doRequest[T any](ctx context.Context, httpClient *retryablehttp.Client, httpReq *retryablehttp.Request) (*T, error) { +func doRequest[T any](ctx context.Context, httpClient *http.Client, httpReq *http.Request) (*T, error) { logRequest(ctx, httpReq) resp, err := httpClient.Do(httpReq) @@ -560,9 +566,12 @@ func doRequest[T any](ctx context.Context, httpClient *retryablehttp.Client, htt body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("error reading response body from to '%s': %w", httpReq.URL, err) + return nil, fmt.Errorf("error reading response body from '%s %s': %w", httpReq.Method, httpReq.URL, err) } + debugInfo := map[string]interface{}{"status_code": resp.StatusCode, "url": httpReq.URL.RequestURI(), "body": string(body), "headers": resp.Header} + tflog.Trace(ctx, "Response from Bitwarden server", debugInfo) + if resp.StatusCode != 200 { return nil, fmt.Errorf("bad response status code for '%s': %d!=200, body:%s", httpReq.URL, resp.StatusCode, string(body)) } @@ -580,23 +589,37 @@ func doRequest[T any](ctx context.Context, httpClient *retryablehttp.Client, htt fmt.Printf("Body to unmarshall: %s\n", string(body)) return nil, fmt.Errorf("error unmarshalling response from '%s': %w", httpReq.URL, err) } - debugInfo := map[string]interface{}{"url": httpReq.URL.RequestURI(), "body": string(body)} - tflog.Trace(ctx, "Response from Bitwarden server", debugInfo) return &res, nil } -func logRequest(ctx context.Context, httpReq *retryablehttp.Request) { - bodyCopy, err := httpReq.BodyBytes() - if err != nil { - tflog.Trace(ctx, "Unable to re-read request body", map[string]interface{}{"error": err}) - } +func logRequest(ctx context.Context, httpReq *http.Request) { debugInfo := map[string]interface{}{ "url": httpReq.URL.RequestURI(), "method": httpReq.Method, "headers": httpReq.Header, - "body": string(bodyCopy), + } + + if httpReq.Body != nil { + bodyCopy, newBody, err := readAndRestoreBody(httpReq.Body) + if err != nil { + tflog.Trace(ctx, "Unable to re-read request body", map[string]interface{}{"error": err}) + } + httpReq.Body = newBody + debugInfo["body"] = string(bodyCopy) } tflog.Trace(ctx, "Request to Bitwarden server ", debugInfo) } + +func readAndRestoreBody(rc io.ReadCloser) ([]byte, io.ReadCloser, error) { + var buf bytes.Buffer + + tee := io.TeeReader(rc, &buf) + + body, err := io.ReadAll(tee) + if err != nil { + return nil, nil, err + } + return body, io.NopCloser(bytes.NewReader(buf.Bytes())), nil +} diff --git a/internal/bitwarden/webapi/custom_retry_policy.go b/internal/bitwarden/webapi/custom_retry_policy.go deleted file mode 100644 index 9761e1e..0000000 --- a/internal/bitwarden/webapi/custom_retry_policy.go +++ /dev/null @@ -1,39 +0,0 @@ -package webapi - -import ( - "context" - "net" - "net/http" - - "github.com/hashicorp/go-retryablehttp" - "github.com/hashicorp/terraform-plugin-log/tflog" -) - -func CustomRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { - debugInfo := map[string]interface{}{ - "error": err, - } - if resp != nil { - debugInfo["status_code"] = resp.StatusCode - debugInfo["status_message"] = resp.Status - } - - var willRetry bool - var handlerErr error - - if err != nil { - if err == context.DeadlineExceeded { - willRetry = false - } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - willRetry = false - } - } else { - willRetry, handlerErr = retryablehttp.DefaultRetryPolicy(ctx, resp, err) - } - - debugInfo["will_retry"] = willRetry - debugInfo["handler_error"] = handlerErr - tflog.Trace(ctx, "retry_handler", debugInfo) - - return willRetry, handlerErr -} diff --git a/internal/bitwarden/webapi/options.go b/internal/bitwarden/webapi/options.go index abbefd1..4770f50 100644 --- a/internal/bitwarden/webapi/options.go +++ b/internal/bitwarden/webapi/options.go @@ -6,12 +6,16 @@ type Options func(c Client) func DisableRetries() Options { return func(c Client) { - c.(*client).httpClient.RetryMax = 0 + roundTripper, ok := c.(*client).httpClient.Transport.(*RetryRoundTripper) + if !ok { + return + } + roundTripper.DisableRetries = true } } func WithCustomClient(httpClient http.Client) Options { return func(c Client) { - c.(*client).httpClient.HTTPClient = &httpClient + c.(*client).httpClient = &httpClient } } diff --git a/internal/bitwarden/webapi/retry_round_tripper.go b/internal/bitwarden/webapi/retry_round_tripper.go new file mode 100644 index 0000000..0482119 --- /dev/null +++ b/internal/bitwarden/webapi/retry_round_tripper.go @@ -0,0 +1,180 @@ +package webapi + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "math" + "net" + "net/http" + "strconv" + "time" + + "github.com/hashicorp/terraform-plugin-log/tflog" + "golang.org/x/sync/semaphore" +) + +type RetryRoundTripper struct { + DisableRetries bool + Transport http.RoundTripper + + concurrentRequestsSem *semaphore.Weighted + maxLowLevelRetries int + requestTimeout time.Duration +} + +// maxLowLevelRetries is the maximum number of retries for low-level errors (e.g. timeouts). +// A value of 0 means no retries. +func NewRetryRoundTripper(maxConcurrentRequests int, maxLowLevelRetries int, requestTimeout time.Duration) *RetryRoundTripper { + return &RetryRoundTripper{ + Transport: http.DefaultTransport, + concurrentRequestsSem: semaphore.NewWeighted(int64(maxConcurrentRequests)), + maxLowLevelRetries: maxLowLevelRetries, + requestTimeout: requestTimeout, + } +} + +func (rrt *RetryRoundTripper) RoundTrip(httpReq *http.Request) (*http.Response, error) { + err := rrt.concurrentRequestsSem.Acquire(httpReq.Context(), 1) + if err == nil { + defer rrt.concurrentRequestsSem.Release(1) + } + + ctx := httpReq.Context() + attemptNumber := 0 + for { + attemptNumber += 1 + + resp, shouldRetry, err := rrt.doRequest(ctx, httpReq, attemptNumber) + if err != nil { + return nil, err + } + + if !shouldRetry || rrt.DisableRetries { + return resp, nil + } + } +} + +func (rrt *RetryRoundTripper) doRequest(ctx context.Context, httpReq *http.Request, attemptNumber int) (*http.Response, bool, error) { + ctx, cancel := context.WithTimeout(ctx, rrt.requestTimeout) + defer cancel() + + resp, err := rrt.Transport.RoundTrip(httpReq.WithContext(ctx)) + + // Successfully got an HTTP response that is not a 429 + if err == nil && resp.StatusCode != http.StatusTooManyRequests { + // We read the body as we're cancelling the context when leaving the function. + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return resp, false, fmt.Errorf("failed to read response body in round tripper: %w", err) + } + + resp.Body = io.NopCloser(bytes.NewReader(body)) + return resp, false, nil + } + + // Got a low-level error, or a 429 HTTP response. We're returning the error + // so let's not log anything additional. + if !rrt.isRetryableError(err, attemptNumber, httpReq.Method) { + return nil, false, err + } + + // Retryable request that had a response. A body is present, throw it away. + io.ReadAll(resp.Body) + resp.Body.Close() + + var waitDuration time.Duration + debugInfo := map[string]interface{}{ + "url": httpReq.URL.RequestURI(), + "method": httpReq.Method, + "attempt_number": attemptNumber, + "is_retryable": true, + } + + if err != nil { + debugInfo["error"] = err + waitDuration = backoff(attemptNumber) + } else if resp.StatusCode == http.StatusTooManyRequests { + debugInfo["status_code"] = resp.StatusCode + debugInfo["status_message"] = resp.Status + waitDuration = tryToReadWaitDurationFromHeaders(resp) + } + + debugInfo["wait_duration_sec"] = waitDuration.Seconds() + tflog.Info(ctx, "retry_round_tripper", debugInfo) + + return resp, true, sleepWithContext(ctx, waitDuration) +} + +func (rrt *RetryRoundTripper) isRetryableError(err error, attemptNumber int, httpMethod string) bool { + if err != nil { + return false + } + + if attemptNumber >= rrt.maxLowLevelRetries-1 { + return false + } + if isConnectTimeout(err) { + return true + } + if isReadTimeout(err) && httpMethod == http.MethodGet { + return true + } + return false +} + +func tryToReadWaitDurationFromHeaders(resp *http.Response) time.Duration { + retryAfterRaw := resp.Header.Get("X-Retry-After") + if len(retryAfterRaw) != 0 { + retryAfter, err := strconv.ParseInt(retryAfterRaw, 10, 64) + if err == nil { + return time.Minute * time.Duration(retryAfter) + } + } + return 0 +} + +func isConnectTimeout(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + opErr, ok := netErr.(*net.OpError) + if ok && opErr.Op == "dial" { + return true + } + } + return false +} + +func isReadTimeout(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + opErr, ok := netErr.(*net.OpError) + if ok && opErr.Op == "read" { + return true + } + } + return false +} + +func backoff(attempt int) time.Duration { + maxInterval := 30 * time.Second + delay := time.Duration(math.Pow(2, float64(attempt))) * time.Second + if delay > maxInterval { + delay = maxInterval + } + + return delay +} + +func sleepWithContext(ctx context.Context, duration time.Duration) error { + select { + case <-ctx.Done(): + return fmt.Errorf("sleep cancelled: %v", ctx.Err()) + case <-time.After(duration): + return nil + } +}