Skip to content

Commit

Permalink
feat: 优化 TimestampChecker (#9)
Browse files Browse the repository at this point in the history
* feat: 优化 TimestampChecker

* feat: 优化 TimestampChecker
  • Loading branch information
mmyj authored Oct 8, 2020
1 parent 4eb0aa4 commit 2a0042e
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 20 deletions.
12 changes: 8 additions & 4 deletions apiaccessor/accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ type Accessor interface {
}

var (
errArgLack = errors.New("arg lack")
errSignatureUnmatched = errors.New("signature is unmatched")
errTimestampTimeout = errors.New("timestamp time out")
errNonceUsed = errors.New("nonce is used")
// ErrArgLack represent the request's arguments are lack.
ErrArgLack = errors.New("arg lack")
// ErrSignatureUnmatched represent the signature of the request's arguments is wrong.
ErrSignatureUnmatched = errors.New("signature is unmatched")
// ErrTimestampTimeout represent the timestamp argument timeout.
ErrTimestampTimeout = errors.New("timestamp time out")
// ErrNonceUsed represent the nonce argument had been used.
ErrNonceUsed = errors.New("nonce is used")
)

const (
Expand Down
16 changes: 11 additions & 5 deletions apiaccessor/base_accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,21 @@ func defEvalSignatureFunc(move uint) EvalSignature {
}

func defTimestampChecker(timestamp int64) error {
const sec = 5
dt := time.Now().Unix() - timestamp
const (
sec = 5
timeFormat = "2006/01/02 15:04:05"
)
now := time.Now()
dt := now.Unix() - timestamp
if dt > sec || dt < -sec {
return errTimestampTimeout
nowTimeStr := now.Format(timeFormat)
timestampStr := time.Unix(timestamp, 0).In(now.Location()).Format(timeFormat)
return fmt.Errorf("%w: now %s, get %s", ErrTimestampTimeout, nowTimeStr, timestampStr)
}
return nil
}

func defNonceChecker(nonce string) error {
func defNonceChecker(_ string) error {
return nil
}

Expand Down Expand Up @@ -76,7 +82,7 @@ func (a *baseAccessor) CheckSignature() error {
signature := a.evalSignatureFunc(argText)
argSignature := a.args.kv[signatureTag]
if signature != argSignature {
return fmt.Errorf("%w: want %s, get %s", errSignatureUnmatched, signature, argSignature)
return fmt.Errorf("%w: want %s, get %s", ErrSignatureUnmatched, signature, argSignature)
}
return nil
}
Expand Down
17 changes: 17 additions & 0 deletions apiaccessor/base_accessor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package apiaccessor

import (
"errors"
"testing"
"time"

"github.com/go-playground/assert/v2"
)

func TestDefTimestampChecker(t *testing.T) {
err := defTimestampChecker(1602146001) // 2020-10-08 16:33:21
assert.Equal(t, errors.Is(err, ErrTimestampTimeout), true)
t.Log(err)
err = defTimestampChecker(time.Now().Unix())
assert.Equal(t, err, nil)
}
8 changes: 4 additions & 4 deletions apiaccessor/query_accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@ func NewQueryAccessor(query url.Values, secretKey string, setters ...Setter) (*Q
for key, vs := range query {
v := vs[0]
if len(v) == 0 {
return nil, fmt.Errorf("%w: %s", errArgLack, key)
return nil, fmt.Errorf("%w: %s", ErrArgLack, key)
}
qa.args.append(key, v)
}
qa.args.append(secretKeyTag, secretKey)
if len(qa.args.kv[nonceTag]) == 0 {
return nil, fmt.Errorf("%w: %s", errArgLack, nonceTag)
return nil, fmt.Errorf("%w: %s", ErrArgLack, nonceTag)
}
if len(qa.args.kv[timestampTag]) == 0 {
return nil, fmt.Errorf("%w: %s", errArgLack, timestampTag)
return nil, fmt.Errorf("%w: %s", ErrArgLack, timestampTag)
}
if len(qa.args.kv[signatureTag]) == 0 {
return nil, fmt.Errorf("%w: %s", errArgLack, signatureTag)
return nil, fmt.Errorf("%w: %s", ErrArgLack, signatureTag)
}

for _, setter := range setters {
Expand Down
12 changes: 6 additions & 6 deletions apiaccessor/query_accessor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ import (
func TestNewQueryAccessor(t *testing.T) {
query := url.Values{}
_, err := NewQueryAccessor(query, "123")
assert.Equal(t, errors.Is(err, errArgLack), true)
assert.Equal(t, errors.Is(err, ErrArgLack), true)

query = url.Values{
nonceTag: []string{"12345"},
}
_, err = NewQueryAccessor(query, "123")
assert.Equal(t, errors.Is(err, errArgLack), true)
assert.Equal(t, errors.Is(err, ErrArgLack), true)

query = url.Values{
nonceTag: []string{"12345"},
Expand Down Expand Up @@ -52,7 +52,7 @@ func TestCheckSignature(t *testing.T) {
qa, err := NewQueryAccessor(query, "123")
assert.Equal(t, err, nil)
err = qa.CheckSignature()
assert.Equal(t, errors.Is(err, errSignatureUnmatched), true)
assert.Equal(t, errors.Is(err, ErrSignatureUnmatched), true)

query = url.Values{
nonceTag: []string{"12345"},
Expand All @@ -78,7 +78,7 @@ func TestCheckTimestamp(t *testing.T) {
qa, err := NewQueryAccessor(query, "123")
assert.Equal(t, err, nil)
err = qa.CheckTimestamp()
assert.Equal(t, errors.Is(err, errTimestampTimeout), true)
assert.Equal(t, errors.Is(err, ErrTimestampTimeout), true)

query = url.Values{
nonceTag: []string{"12345"},
Expand All @@ -97,7 +97,7 @@ func TestCheckNonce(t *testing.T) {
nonceMap := make(map[string]bool)
mockNonceChecker := func(nonce string) error {
if _, ok := nonceMap[nonce]; ok {
return errNonceUsed
return ErrNonceUsed
}
nonceMap[nonce] = true
return nil
Expand All @@ -115,5 +115,5 @@ func TestCheckNonce(t *testing.T) {
err = qa.CheckNonce()
assert.Equal(t, err, nil)
err = qa.CheckNonce()
assert.Equal(t, errors.Is(err, errNonceUsed), true)
assert.Equal(t, errors.Is(err, ErrNonceUsed), true)
}
10 changes: 9 additions & 1 deletion apiaccessor/setter.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func WithGeneralRedisNonceChecker(client redis.Cmdable, sec int64, keyGenFunc Ke
return err
}
if re == 1 {
return errNonceUsed
return ErrNonceUsed
}
return nil
}
Expand All @@ -61,3 +61,11 @@ func WithNonceChecker(nc NonceChecker) Setter {
return nil
}
}

// WithTimestampChecker set a custom TimestampChecker for the Accessor
func WithTimestampChecker(tc TimestampChecker) Setter {
return func(b *baseAccessor) error {
b.timestampChecker = tc
return nil
}
}

0 comments on commit 2a0042e

Please sign in to comment.