diff --git a/internal/bitwarden/webapi/client.go b/internal/bitwarden/webapi/client.go index fde5c9f..c12cdbc 100644 --- a/internal/bitwarden/webapi/client.go +++ b/internal/bitwarden/webapi/client.go @@ -3,13 +3,14 @@ package webapi import ( "bytes" "context" + "encoding/base64" "encoding/json" "fmt" "io" "mime/multipart" - "net/http" "net/url" "strings" + "time" "github.com/hashicorp/go-retryablehttp" "github.com/hashicorp/terraform-plugin-log/tflog" @@ -47,26 +48,6 @@ type Client interface { Sync(ctx context.Context) (*SyncResponse, error) } -type Options func(c Client) - -func DisableRetries() Options { - return func(c Client) { - c.(*client).httpClient.RetryMax = 0 - } -} - -func WithCustomClient(httpClient http.Client) Options { - return func(c Client) { - c.(*client).httpClient.HTTPClient = &httpClient - } -} - -func WithDeviceIdentifier(deviceIdentifier string) Options { - return func(c Client) { - c.(*client).deviceIdentifier = deviceIdentifier - } -} - func NewClient(serverURL string, opts ...Options) Client { c := &client{ deviceName: deviceName, @@ -78,7 +59,8 @@ func NewClient(serverURL string, opts ...Options) Client { o(c) } c.httpClient.Logger = nil - + c.httpClient.CheckRetry = CustomRetryPolicy + c.httpClient.HTTPClient.Timeout = 10 * time.Second return c } @@ -299,6 +281,8 @@ func (c *client) LoginWithPassword(ctx context.Context, username, password strin if err != nil { return nil, fmt.Errorf("error preparing login with password request: %w", err) } + httpReq.Header.Add("Auth-Email", base64.StdEncoding.EncodeToString([]byte(username))) + httpReq.Header.Add("Device-Type", c.deviceType) tokenResp, err := doRequest[TokenResponse](ctx, c.httpClient, httpReq) if err != nil { @@ -386,9 +370,9 @@ func (c *client) Sync(ctx context.Context) (*SyncResponse, error) { func (c *client) prepareRequest(ctx context.Context, reqMethod, reqUrl string, reqBody interface{}) (*retryablehttp.Request, error) { var httpReq *retryablehttp.Request var err error - contentType := "" - debugInfo := map[string]interface{}{"url": reqUrl, "method": reqMethod} + if reqBody != nil { + contentType := "" var bodyBytes []byte if v, ok := reqBody.(url.Values); ok { bodyBytes = []byte(v.Encode()) @@ -403,7 +387,9 @@ func (c *client) prepareRequest(ctx context.Context, reqMethod, reqUrl string, r } } httpReq, err = retryablehttp.NewRequestWithContext(ctx, reqMethod, reqUrl, bytes.NewBuffer(bodyBytes)) - debugInfo["body"] = string(bodyBytes) + if len(contentType) > 0 { + httpReq.Header.Add("Content-Type", contentType) + } } else { httpReq, err = retryablehttp.NewRequestWithContext(ctx, reqMethod, reqUrl, nil) } @@ -414,17 +400,14 @@ func (c *client) prepareRequest(ctx context.Context, reqMethod, reqUrl string, r if len(c.sessionAccessToken) > 0 { httpReq.Header.Add("authorization", fmt.Sprintf("Bearer %s", c.sessionAccessToken)) } - if len(contentType) > 0 { - httpReq.Header.Add("Content-Type", contentType) - } - - debugInfo["headers"] = httpReq.Header - tflog.Trace(ctx, "Request to Bitwarden server", debugInfo) + httpReq.Header.Add("Accept", "application/json") return httpReq, nil } func doRequest[T any](ctx context.Context, httpClient *retryablehttp.Client, httpReq *retryablehttp.Request) (*T, error) { + logRequest(ctx, httpReq) + resp, err := httpClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("error doing request to '%s': %w", httpReq.URL, err) @@ -453,7 +436,22 @@ func doRequest[T any](ctx context.Context, httpClient *retryablehttp.Client, htt fmt.Printf("Body to unmarshall: %s\n", string(body)) return nil, fmt.Errorf("error unmarshalling response from '%s': %w", httpReq.URL, err) } - tflog.Trace(ctx, "Response from Bitwarden server", map[string]interface{}{"url": httpReq.URL, "body": string(body)}) + tflog.Trace(ctx, "Response from Bitwarden server", map[string]interface{}{"url": httpReq.URL.RequestURI(), "body": string(body)}) return &res, nil } + +func logRequest(ctx context.Context, httpReq *retryablehttp.Request) { + bodyCopy, err := httpReq.BodyBytes() + if err != nil { + tflog.Trace(ctx, "Unable to re-read request body", map[string]interface{}{"error": err}) + } + debugInfo := map[string]interface{}{ + "url": httpReq.URL.RequestURI(), + "method": httpReq.Method, + "headers": httpReq.Header, + "body": string(bodyCopy), + } + + tflog.Trace(ctx, "Request to Bitwarden server ", debugInfo) +} diff --git a/internal/bitwarden/webapi/custom_retry_policy.go b/internal/bitwarden/webapi/custom_retry_policy.go new file mode 100644 index 0000000..9761e1e --- /dev/null +++ b/internal/bitwarden/webapi/custom_retry_policy.go @@ -0,0 +1,39 @@ +package webapi + +import ( + "context" + "net" + "net/http" + + "github.com/hashicorp/go-retryablehttp" + "github.com/hashicorp/terraform-plugin-log/tflog" +) + +func CustomRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { + debugInfo := map[string]interface{}{ + "error": err, + } + if resp != nil { + debugInfo["status_code"] = resp.StatusCode + debugInfo["status_message"] = resp.Status + } + + var willRetry bool + var handlerErr error + + if err != nil { + if err == context.DeadlineExceeded { + willRetry = false + } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + willRetry = false + } + } else { + willRetry, handlerErr = retryablehttp.DefaultRetryPolicy(ctx, resp, err) + } + + debugInfo["will_retry"] = willRetry + debugInfo["handler_error"] = handlerErr + tflog.Trace(ctx, "retry_handler", debugInfo) + + return willRetry, handlerErr +} diff --git a/internal/bitwarden/webapi/options.go b/internal/bitwarden/webapi/options.go new file mode 100644 index 0000000..15521ad --- /dev/null +++ b/internal/bitwarden/webapi/options.go @@ -0,0 +1,23 @@ +package webapi + +import "net/http" + +type Options func(c Client) + +func DisableRetries() Options { + return func(c Client) { + c.(*client).httpClient.RetryMax = 0 + } +} + +func WithCustomClient(httpClient http.Client) Options { + return func(c Client) { + c.(*client).httpClient.HTTPClient = &httpClient + } +} + +func WithDeviceIdentifier(deviceIdentifier string) Options { + return func(c Client) { + c.(*client).deviceIdentifier = deviceIdentifier + } +}