Skip to content

Commit

Permalink
fix: first 401 triggers Login to refresh token
Browse files Browse the repository at this point in the history
  • Loading branch information
masonkatz committed Feb 16, 2024
1 parent f4358de commit e2f161a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 19 deletions.
38 changes: 21 additions & 17 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ type Options struct {

// Client is a connection to a REST service.
type Client struct {
BaseURL string
Password string
Username string
Logger *slog.Logger
Debug bool
Token string
UserAgent string
ErrorParser func(code int, r io.Reader) error
BaseURL string
Password string
Username string
Logger *slog.Logger
Debug bool
Token string
UserAgent string
NewError func(code int, r io.Reader) error
http.Client
}

Expand Down Expand Up @@ -82,10 +82,10 @@ func New(o *Options) *Client {
},
Timeout: o.Timeout,
},
BaseURL: fmt.Sprintf("https://%s", o.Address),
Username: o.Username,
Password: o.Password,
ErrorParser: defaultErrorParser,
BaseURL: fmt.Sprintf("https://%s", o.Address),
Username: o.Username,
Password: o.Password,
NewError: newDefaultError,
}
}

Expand Down Expand Up @@ -185,9 +185,7 @@ func (c *Client) requestWrapper(ctx context.Context, method, path string, in, ou
}

if err := c.request(ctx, method, path, in, out); err != nil {
var e defaultError // 401 means the token expired so automatically re-login and try again.

if errors.As(err, &e) && e.Code == http.StatusUnauthorized {
if errors.Is(err, errUnauthorized) {
c.Logger.Info("token expired")

if err := c.Login(ctx); err != nil {
Expand Down Expand Up @@ -238,8 +236,12 @@ func (c *Client) Request(ctx context.Context, method, path string, r io.Reader)
fmt.Printf("RESPONSE:\n%s", string(b))
}

if resp.StatusCode == http.StatusUnauthorized {
return nil, errUnauthorized // special error to trigger re-Login on stale token
}

if resp.StatusCode < http.StatusOK || resp.StatusCode > 299 {
return nil, c.ErrorParser(resp.StatusCode, resp.Body)
return nil, c.NewError(resp.StatusCode, resp.Body)
}

var buf bytes.Buffer
Expand Down Expand Up @@ -277,6 +279,8 @@ func (c *Client) request(ctx context.Context, method, path string, in, out inter
return nil
}

var errUnauthorized = errors.New("unauthorized")

type defaultError struct {
Code int
Body string
Expand All @@ -286,7 +290,7 @@ func (e defaultError) Error() string {
return fmt.Sprintf("%s (%d) - %s", strings.ToLower(http.StatusText(e.Code)), e.Code, e.Body)
}

func defaultErrorParser(code int, r io.Reader) error {
func newDefaultError(code int, r io.Reader) error {
var b bytes.Buffer

if _, err := b.ReadFrom(r); err != nil {
Expand Down
4 changes: 2 additions & 2 deletions sifi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Client struct {
// NewClient returns a new client.
func NewClient(o *client.Options) *Client {
c := client.New(o)
c.ErrorParser = errorParser
c.NewError = newError

prefix := APIPrefix + "/" + APIVersion

Expand All @@ -37,7 +37,7 @@ func NewClient(o *client.Options) *Client {
}
}

func errorParser(code int, r io.Reader) error {
func newError(code int, r io.Reader) error {
var resp ResponseError

decodeErr := json.NewDecoder(r).Decode(&resp)
Expand Down

0 comments on commit e2f161a

Please sign in to comment.