diff --git a/go.mod b/go.mod index 1af2cb1..efc2f48 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,11 @@ module github.com/go-pay/wechat-sdk go 1.21 require ( - github.com/go-pay/bm v0.0.4 + github.com/go-pay/bm v0.0.5 github.com/go-pay/crypto v0.0.1 - github.com/go-pay/util v0.0.3 - github.com/go-pay/xhttp v0.0.2 + github.com/go-pay/smap v0.0.2 + github.com/go-pay/util v0.0.4 + github.com/go-pay/xhttp v0.0.3 github.com/go-pay/xlog v0.0.3 github.com/go-pay/xtime v0.0.2 ) diff --git a/go.sum b/go.sum index 9f617b4..faf3898 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,13 @@ -github.com/go-pay/bm v0.0.4 h1:MUECRx1t0MkbQ7Yzk2qseKBOKoUi4N8WZ+alZBrZesg= -github.com/go-pay/bm v0.0.4/go.mod h1:S7ZAxWtyjm7PX54cna4N/RzJ1JAZG8EDukZ2fZaZ+qk= +github.com/go-pay/bm v0.0.5 h1:ZAg6j1Wagc8JZ88ja7VgEF/g+kKOFzLRNc43bm0ivZc= +github.com/go-pay/bm v0.0.5/go.mod h1:S7ZAxWtyjm7PX54cna4N/RzJ1JAZG8EDukZ2fZaZ+qk= github.com/go-pay/crypto v0.0.1 h1:B6InT8CLfSLc6nGRVx9VMJRBBazFMjr293+jl0lLXUY= github.com/go-pay/crypto v0.0.1/go.mod h1:41oEIvHMKbNcYlWUlRWtsnC6+ASgh7u29z0gJXe5bes= -github.com/go-pay/util v0.0.3 h1:0OjERb7MAVpM2gLPnBESLdMsosYyJ4i31V2/YZBiPjw= -github.com/go-pay/util v0.0.3/go.mod h1:qM8VbyF1n7YAPZBSJONSPMPsPedhUTktewUAdf1AjPg= -github.com/go-pay/xhttp v0.0.2 h1:O8rnd/d03WsboFtUthwFMg61ikHRfYHyD1m0JiUx60g= -github.com/go-pay/xhttp v0.0.2/go.mod h1:BnuvXpLKkXTFMOBc5MTb0hxdrstwunbzQPJUZOsNbt4= +github.com/go-pay/smap v0.0.2 h1:kKflYor5T5FgZltPFBMTFfjJvqYMHr5VnIFSEyhVTcA= +github.com/go-pay/smap v0.0.2/go.mod h1:HW9oAo0okuyDYsbpbj5fJFxnNj/BZorRGFw26SxrNWw= +github.com/go-pay/util v0.0.4 h1:TuwSU9o3Qd7m9v1PbzFuIA/8uO9FJnA6P7neG/NwPyk= +github.com/go-pay/util v0.0.4/go.mod h1:Tsdhs8Ib9J9b4+NKNO1PHh5hWHhlg98PthsX0ckq6PM= +github.com/go-pay/xhttp v0.0.3 h1:9Vke2QeY0xs8E9oyb3bi94v47N25ZdGgZOIG1hgCgKA= +github.com/go-pay/xhttp v0.0.3/go.mod h1:LDNKLp+C6UJRZSAcxI4z4BYtRs3ksbgxPQl1W9HQGXs= github.com/go-pay/xlog v0.0.3 h1:avyMhCL/JgBHreoGx/am/kHxfs1udDOAeVqbmzP/Yes= github.com/go-pay/xlog v0.0.3/go.mod h1:mH47xbobrdsSHWsmFtSF5agWbMHFP+tK0ZbVCk5OAEw= github.com/go-pay/xtime v0.0.2 h1:7YR4/iuELsEHpJ6LUO0SVK80hQxDO9MLCfuVYIiTCRM= diff --git a/mini/access_token.go b/mini/access_token.go index 68dc1da..7211b24 100644 --- a/mini/access_token.go +++ b/mini/access_token.go @@ -4,58 +4,10 @@ import ( "context" "fmt" "runtime" - "strconv" "time" -) - -// 获取小程序全局唯一后台接口调用凭据(access_token) -// 微信小程序文档:https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/mp-access-token/getAccessToken.html -//func (s *SDK) getAccessToken() (err error) { -// defer func() { -// if err != nil { -// // reset default refresh internal -// s.RefreshInternal = time.Second * 20 -// if s.callback != nil { -// go s.callback("", "", 0, err) -// } -// } -// }() -// -// path := "/cgi-bin/token?grant_type=client_credential&appid=" + s.Appid + "&secret=" + s.Secret -// at := &AccessToken{} -// if _, err = s.DoRequestGet(s.ctx, path, at); err != nil { -// return -// } -// if at.Errcode != Success { -// err = fmt.Errorf("errcode(%d), errmsg(%s)", at.Errcode, at.Errmsg) -// return -// } -// s.accessToken = at.AccessToken -// s.RefreshInternal = time.Second * time.Duration(at.ExpiresIn) -// if s.callback != nil { -// go s.callback(s.Appid, at.AccessToken, at.ExpiresIn, nil) -// } -// return nil -//} -//func (s *SDK) goAutoRefreshAccessToken() { -// defer func() { -// if r := recover(); r != nil { -// buf := make([]byte, 64<<10) -// buf = buf[:runtime.Stack(buf, false)] -// s.logger.Errorf("mini_goAutoRefreshAccessToken: panic recovered: %s\n%s", r, buf) -// } -// }() -// for { -// // every one hour, request new access token, default 10s -// time.Sleep(s.RefreshInternal / 2) -// err := s.getAccessToken() -// if err != nil { -// s.logger.Errorf("get access token error, after 10s retry: %+v", err) -// continue -// } -// } -//} + "github.com/go-pay/bm" +) // 获取稳定版接口调用凭据 // 微信小程序文档:https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/mp-access-token/getStableAccessToken.html @@ -70,9 +22,14 @@ func (s *SDK) getStableAccessToken() (err error) { } }() - path := "/cgi-bin/stable_token?grant_type=client_credential&appid=" + s.Appid + "&secret=" + s.Secret + "&force_refresh=false" + path := "/cgi-bin/stable_token" + body := make(bm.BodyMap) + body.Set("grant_type", "client_credential"). + Set("appid", s.Appid). + Set("secret", s.Secret). + Set("force_refresh", false) at := &AccessToken{} - if _, err = s.DoRequestGet(s.ctx, path, at); err != nil { + if _, err = s.doRequestPost(s.ctx, path, body, at); err != nil { return } if at.Errcode != Success { @@ -106,7 +63,7 @@ func (s *SDK) goAutoRefreshStableAccessToken() { time.Sleep(s.RefreshInternal / 2) err := s.getStableAccessToken() if err != nil { - s.logger.Errorf("get access token error, after 10s retry: %+v", err) + s.logger.Errorf("get stable access token error, after 10s retry: %+v", err) continue } } @@ -129,7 +86,7 @@ func (s *SDK) SetMiniAccessToken(accessToken string) { // ===================================================================================================================== -// 获取接口调用凭据 +// 获取 Access Token // 微信小程序文档:https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/mp-access-token/getAccessToken.html func GetAccessToken(c context.Context, appid, secret string) (at *AccessToken, err error) { uri := HostDefault + "/cgi-bin/token?grant_type=client_credential&appid=" + appid + "&secret=" + secret @@ -143,12 +100,17 @@ func GetAccessToken(c context.Context, appid, secret string) (at *AccessToken, e return at, nil } -// 获取稳定版接口调用凭据 +// 获取 Stable Access Token // 微信小程序文档:https://developers.weixin.qq.com/miniprogram/dev/OpenApiDoc/mp-access-token/getStableAccessToken.html func GetStableAccessToken(c context.Context, appid, secret string, forceRefresh bool) (at *AccessToken, err error) { - uri := HostDefault + "/cgi-bin/stable_token?grant_type=client_credential&appid=" + appid + "&secret=" + secret + "&force_refresh=" + strconv.FormatBool(forceRefresh) + url := HostDefault + "/cgi-bin/stable_token" + body := make(bm.BodyMap) + body.Set("grant_type", "client_credential"). + Set("appid", appid). + Set("secret", secret). + Set("force_refresh", forceRefresh) at = &AccessToken{} - if err = doRequestGet(c, uri, at); err != nil { + if err = doRequestPost(c, url, body, at); err != nil { return nil, err } if at.Errcode != Success { diff --git a/mini/mini.go b/mini/mini.go index e0297e7..5b9897f 100644 --- a/mini/mini.go +++ b/mini/mini.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + "github.com/go-pay/bm" "github.com/go-pay/util" "github.com/go-pay/util/js" "github.com/go-pay/wechat-sdk" @@ -97,3 +98,16 @@ func doRequestGet(c context.Context, uri string, ptr any) (err error) { } return } + +func doRequestPost(c context.Context, url string, body bm.BodyMap, ptr any) (err error) { + req := xhttp.NewClient().Req() + req.Header.Add(wechat.HeaderRequestID, fmt.Sprintf("%s-%d", util.RandomString(21), time.Now().Unix())) + _, bs, err := req.Post(url).SendBodyMap(body).EndBytes(c) + if err != nil { + return fmt.Errorf("http.request(POST, %s), err:%w", url, err) + } + if err = js.UnmarshalBytes(bs, ptr); err != nil { + return fmt.Errorf("js.UnmarshalBytes(%s, %+v):%w", string(bs), ptr, err) + } + return +} diff --git a/open/access_token.go b/open/access_token.go index 9c1b66e..f9a888f 100644 --- a/open/access_token.go +++ b/open/access_token.go @@ -25,9 +25,8 @@ func (s *SDK) refreshAccessToken(openid, refreshToken string) { } return } - s.mu.Lock() - s.openidAccessTokenMap[at.Openid] = at - s.mu.Unlock() + + s.openidAccessTokenMap.Store(at.Openid, at) if s.callback != nil { go s.callback(&AT{ AccessToken: at.AccessToken, @@ -53,12 +52,13 @@ func (s *SDK) goAutoRefreshAccessTokenJob() { for { // request new access token, default internal 10min time.Sleep(s.autoRefreshTokenInternal) - for k, v := range s.openidAccessTokenMap { + s.openidAccessTokenMap.Range(func(k string, v *AccessToken) bool { // 有效期小于1.5倍轮询时间时,自动刷新) if time.Duration(v.ExpiresIn)*time.Second < (s.autoRefreshTokenInternal*3)/2 { s.refreshAccessToken(k, v.RefreshToken) } - } + return true + }) } } @@ -74,30 +74,24 @@ func (s *SDK) SetAccessTokenRefreshInternal(internal time.Duration) { // GetAccessTokenMap 获取 access_token map,key 为 openid func (s *SDK) GetAccessTokenMap() (openidATMap map[string]*AT) { - openidATMap = make(map[string]*AT, len(s.openidAccessTokenMap)) - if s.openidAccessTokenMap != nil && len(s.openidAccessTokenMap) > 0 { - s.mu.RLock() - defer s.mu.RUnlock() - for k, v := range s.openidAccessTokenMap { - openidATMap[k] = &AT{ - AccessToken: v.AccessToken, - ExpiresIn: v.ExpiresIn, - RefreshToken: v.RefreshToken, - Openid: v.Openid, - Scope: v.Scope, - Unionid: v.Unionid, - } + openidATMap = make(map[string]*AT) + s.openidAccessTokenMap.Range(func(k string, v *AccessToken) bool { + openidATMap[k] = &AT{ + AccessToken: v.AccessToken, + ExpiresIn: v.ExpiresIn, + RefreshToken: v.RefreshToken, + Openid: v.Openid, + Scope: v.Scope, + Unionid: v.Unionid, } - return - } + return true + }) return } // DelAccessToken 根据 openid 删除 map 中维护的 access_token func (s *SDK) DelAccessToken(openid string) { - if s.openidAccessTokenMap != nil { - delete(s.openidAccessTokenMap, openid) - } + s.openidAccessTokenMap.Delete(openid) } // Code2AccessToken 通过 code 获取用户 access_token @@ -123,9 +117,7 @@ func (s *SDK) Code2AccessToken(c context.Context, code string) (at *AccessToken, }, nil) } if s.autoManageToken { - s.mu.Lock() - s.openidAccessTokenMap[at.Openid] = at - s.mu.Unlock() + s.openidAccessTokenMap.Store(at.Openid, at) } return at, nil } @@ -153,9 +145,7 @@ func (s *SDK) RefreshAccessToken(c context.Context, refreshToken string) (at *Ac }, nil) } if s.autoManageToken { - s.mu.Lock() - s.openidAccessTokenMap[at.Openid] = at - s.mu.Unlock() + s.openidAccessTokenMap.Store(at.Openid, at) } return at, nil } @@ -171,9 +161,7 @@ func (s *SDK) CheckAccessToken(c context.Context, accessToken, openid string) (e } if ec.Errcode != Success { if s.autoManageToken { - s.mu.Lock() - delete(s.openidAccessTokenMap, openid) - s.mu.Unlock() + s.DelAccessToken(openid) } err = fmt.Errorf("errcode(%d), errmsg(%s)", ec.Errcode, ec.Errmsg) return err diff --git a/open/open.go b/open/open.go index 253572f..74baacd 100644 --- a/open/open.go +++ b/open/open.go @@ -4,9 +4,9 @@ import ( "context" "fmt" "net/http" - "sync" "time" + "github.com/go-pay/smap" "github.com/go-pay/util" "github.com/go-pay/util/js" "github.com/go-pay/wechat-sdk" @@ -17,13 +17,12 @@ import ( type SDK struct { ctx context.Context DebugSwitch wechat.DebugSwitch - mu sync.RWMutex Appid string Secret string Host string - autoManageToken bool // 是否自动维护刷新 AccessToken - autoRefreshTokenInternal time.Duration // 自动刷新 token 的间隔时间 - openidAccessTokenMap map[string]*AccessToken // key: openid + autoManageToken bool // 是否自动维护刷新 AccessToken + autoRefreshTokenInternal time.Duration // 自动刷新 token 的间隔时间 + openidAccessTokenMap smap.Map[string, *AccessToken] // key: openid hc *xhttp.Client logger xlog.XLogger @@ -49,7 +48,6 @@ func New(appid, secret string, autoManageToken bool) (o *SDK) { } if autoManageToken { o.autoRefreshTokenInternal = time.Minute * 10 - o.openidAccessTokenMap = make(map[string]*AccessToken) go o.goAutoRefreshAccessTokenJob() } return diff --git a/public/access_token.go b/public/access_token.go index 0711a92..d3e8561 100644 --- a/public/access_token.go +++ b/public/access_token.go @@ -1,14 +1,17 @@ package public import ( + "context" "fmt" "runtime" "time" + + "github.com/go-pay/bm" ) -// 获取公众号全局唯一后台接口调用凭据(access_token) -// 公众号文档:https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Get_access_token.html -func (s *SDK) getAccessToken() (err error) { +// 获取公众号全局唯一后台稳定版接口调用凭据(access_token) +// 公众号文档:https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/getStableAccessToken.html +func (s *SDK) getStableAccessToken() (err error) { defer func() { if err != nil { // reset default refresh internal @@ -19,9 +22,14 @@ func (s *SDK) getAccessToken() (err error) { } }() - path := "/cgi-bin/token?grant_type=client_credential&appid=" + s.Appid + "&secret=" + s.Secret + path := "/cgi-bin/stable_token" + body := make(bm.BodyMap) + body.Set("grant_type", "client_credential"). + Set("appid", s.Appid). + Set("secret", s.Secret). + Set("force_refresh", false) at := &AccessToken{} - if _, err = s.DoRequestGet(s.ctx, path, at); err != nil { + if _, err = s.doRequestPost(s.ctx, path, body, at); err != nil { return } if at.Errcode != Success { @@ -36,26 +44,26 @@ func (s *SDK) getAccessToken() (err error) { return nil } -func (s *SDK) goAutoRefreshAccessTokenJob() { +func (s *SDK) goAutoRefreshStableAccessToken() { defer func() { if r := recover(); r != nil { buf := make([]byte, 64<<10) buf = buf[:runtime.Stack(buf, false)] - s.logger.Errorf("public_goAutoRefreshAccessTokenJob: panic recovered: %s\n%s", r, buf) + s.logger.Errorf("public_goAutoRefreshStableAccessTokenJob: panic recovered: %s\n%s", r, buf) time.Sleep(time.Second * 3) - if err := s.getAccessToken(); err != nil { + if err := s.getStableAccessToken(); err != nil { // 失败就不再自动刷新了 return } - s.goAutoRefreshAccessTokenJob() + s.goAutoRefreshStableAccessToken() } }() for { // every one hour, request new access token, default 10s time.Sleep(s.RefreshInternal / 2) - err := s.getAccessToken() + err := s.getStableAccessToken() if err != nil { - s.logger.Errorf("get access token error, after 10s retry: %+v", err) + s.logger.Errorf("get stable access token error, after 10s retry: %+v", err) continue } } @@ -75,3 +83,36 @@ func (s *SDK) GetPublicAccessToken() (at string) { func (s *SDK) SetPublicAccessToken(accessToken string) { s.accessToken = accessToken } + +// 获取 Access Token +// 微信公众号文档:https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Get_access_token.html +func GetAccessToken(c context.Context, appid, secret string) (at *AccessToken, err error) { + uri := HostDefault + "/cgi-bin/token?grant_type=client_credential&appid=" + appid + "&secret=" + secret + at = &AccessToken{} + if err = doRequestGet(c, uri, at); err != nil { + return nil, err + } + if at.Errcode != Success { + return nil, fmt.Errorf("errcode(%d), errmsg(%s)", at.Errcode, at.Errmsg) + } + return at, nil +} + +// 获取 Stable Access Token +// 微信公众号文档:https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/getStableAccessToken.html +func GetStableAccessToken(c context.Context, appid, secret string, forceRefresh bool) (at *AccessToken, err error) { + url := HostDefault + "/cgi-bin/stable_token" + body := make(bm.BodyMap) + body.Set("grant_type", "client_credential"). + Set("appid", appid). + Set("secret", secret). + Set("force_refresh", forceRefresh) + at = &AccessToken{} + if err = doRequestPost(c, url, body, at); err != nil { + return nil, err + } + if at.Errcode != Success { + return nil, fmt.Errorf("errcode(%d), errmsg(%s)", at.Errcode, at.Errmsg) + } + return at, nil +} diff --git a/public/public.go b/public/public.go index d7f7720..d5d8422 100644 --- a/public/public.go +++ b/public/public.go @@ -3,6 +3,7 @@ package public import ( "context" "fmt" + "github.com/go-pay/bm" "net/http" "time" @@ -44,10 +45,10 @@ func New(appid, secret string, autoManageToken bool) (p *SDK, err error) { logger: logger, } if autoManageToken { - if err = p.getAccessToken(); err != nil { + if err = p.getStableAccessToken(); err != nil { return nil, err } - go p.goAutoRefreshAccessTokenJob() + go p.goAutoRefreshStableAccessToken() } return } @@ -84,3 +85,29 @@ func (s *SDK) DoRequestGet(c context.Context, path string, ptr any) (res *http.R } return res, nil } + +func doRequestGet(c context.Context, uri string, ptr any) (err error) { + req := xhttp.NewClient().Req() + req.Header.Add(wechat.HeaderRequestID, fmt.Sprintf("%s-%d", util.RandomString(21), time.Now().Unix())) + _, bs, err := req.Get(uri).EndBytes(c) + if err != nil { + return fmt.Errorf("http.request(GET, %s), err:%w", uri, err) + } + if err = js.UnmarshalBytes(bs, ptr); err != nil { + return fmt.Errorf("js.UnmarshalBytes(%s, %+v):%w", string(bs), ptr, err) + } + return +} + +func doRequestPost(c context.Context, url string, body bm.BodyMap, ptr any) (err error) { + req := xhttp.NewClient().Req() + req.Header.Add(wechat.HeaderRequestID, fmt.Sprintf("%s-%d", util.RandomString(21), time.Now().Unix())) + _, bs, err := req.Post(url).SendBodyMap(body).EndBytes(c) + if err != nil { + return fmt.Errorf("http.request(POST, %s), err:%w", url, err) + } + if err = js.UnmarshalBytes(bs, ptr); err != nil { + return fmt.Errorf("js.UnmarshalBytes(%s, %+v):%w", string(bs), ptr, err) + } + return +}