-
Notifications
You must be signed in to change notification settings - Fork 14
/
czds.go
307 lines (271 loc) · 7.86 KB
/
czds.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
// Package czds implements a client for the CZDS REST API, using both its documented and undocumented endpoints.
package czds
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/lanrat/czds/jwt"
)
// Production endpoints.
const (
	// AuthURL is the production authentication endpoint.
	AuthURL = "https://account-api.icann.org/api/authenticate"
	// BaseURL is the production CZDS API root.
	BaseURL = "https://czds-api.icann.org"
)

// Testing endpoints.
const (
	// TestAuthURL is the testing authentication endpoint.
	TestAuthURL = "https://account-api-test.icann.org/api/authenticate"
	// TestBaseURL is the testing CZDS API root.
	TestBaseURL = "https://czds-api-test.icann.org"
)

// defaultHTTPClient is the shared client used by any Client whose HTTPClient
// field is left nil.
var defaultHTTPClient = &http.Client{}
// Client stores all session information for czds authentication
// and manages token renewal.
//
// A Client contains a sync.Mutex, so it must not be copied after first use;
// pass *Client around instead.
type Client struct {
	// HTTPClient is used for every request; when nil, the package-level
	// defaultHTTPClient is used instead.
	HTTPClient *http.Client

	// AuthURL is the endpoint that Authenticate POSTs Creds to.
	AuthURL string

	// BaseURL is prepended to the path of every jsonAPI request.
	BaseURL string

	// auth is the last response stored by Authenticate; its AccessToken is
	// sent as a Bearer token on each request made through apiRequest.
	// NOTE(review): apiRequest reads auth without holding authMutex, so a
	// concurrent renewal may race with in-flight requests — confirm.
	auth authResponse

	// authExp is the expiration decoded from auth's token; checkAuth
	// re-authenticates once the current time is after it.
	authExp time.Time

	// Creds are the credentials sent to AuthURL when authenticating.
	Creds Credentials

	// authMutex is held by checkAuth so that concurrent callers do not
	// authenticate at the same time.
	authMutex sync.Mutex

	// log is the client's Logger; presumably consumed by the verbose-logging
	// helper v defined elsewhere in the package — TODO confirm.
	log Logger
}
// Credentials holds the username and password a czds.Client authenticates
// with. Authenticate sends them JSON-encoded, using the field names given by
// the json tags below, to the client's AuthURL.
type Credentials struct {
	Username string `json:"username"`
	Password string `json:"password"`
}
type (
	// authResponse is the JSON body decoded from the authentication endpoint.
	authResponse struct {
		AccessToken string `json:"accessToken"`
		Message     string `json:"message"`
	}

	// errorResponse is the JSON body decoded from an API response whose
	// status is not 200 OK.
	errorResponse struct {
		Message    string `json:"message"`
		HTTPStatus int    `json:"httpStatus"`
	}
)
// NewClient builds a Client that talks to the production AuthURL and BaseURL
// and authenticates with the given username and password.
// HTTPClient is left nil, so requests use the package default HTTP client
// until the caller sets one.
func NewClient(username, password string) *Client {
	creds := Credentials{Username: username, Password: password}
	return &Client{
		AuthURL: AuthURL,
		BaseURL: BaseURL,
		Creds:   creds,
	}
}
// checkAuth makes sure the client holds an unexpired access token, calling
// Authenticate when there is no token yet or the current one has expired.
// While the existing token is still valid no network request is made.
//
// authMutex is held for the whole check so that concurrent callers cannot
// authenticate at the same time.
func (c *Client) checkAuth() error {
	c.authMutex.Lock()
	defer c.authMutex.Unlock()

	switch {
	case c.auth.AccessToken == "":
		c.v("no auth token")
	case time.Now().After(c.authExp):
		c.v("auth token expired")
	default:
		// token present and still valid
		return nil
	}
	return c.Authenticate()
}
// httpClient returns the *http.Client used for all requests: the
// caller-supplied HTTPClient when set, otherwise defaultHTTPClient.
func (c *Client) httpClient() *http.Client {
	if c.HTTPClient == nil {
		return defaultHTTPClient
	}
	return c.HTTPClient
}
// apiRequest makes a request to the client's API endpoint, retrying failed
// attempts up to three times with a 10 second pause between them.
//
// When auth is true, checkAuth is called first so that a valid token is held
// before any attempt is made. request, if non-nil, is sent as a JSON body; it
// is read into memory once so that every retry sends the complete payload
// (an io.Reader can only be consumed once, so reusing it would make retries
// send an empty body).
//
// Only errors returned by the HTTP client itself are retried; any response,
// whatever its status code, is returned as-is and the caller must close its
// body.
// TODO add optional context to requests
func (c *Client) apiRequest(auth bool, method, url string, request io.Reader) (*http.Response, error) {
	c.v("HTTP API Request: %s %q", method, url)
	if auth {
		if err := c.checkAuth(); err != nil {
			return nil, err
		}
	}

	// Buffer the payload so each attempt can get a fresh reader over it.
	var payload []byte
	if request != nil {
		var readErr error
		payload, readErr = io.ReadAll(request)
		if readErr != nil {
			return nil, fmt.Errorf("error reading request body for %s: %w", url, readErr)
		}
	}

	const totalTrys = 3
	var err error
	var resp *http.Response
	for try := 1; try <= totalTrys; try++ {
		// Keep body a nil interface (not a typed nil) when there is no payload.
		var body io.Reader
		if request != nil {
			body = bytes.NewReader(payload)
		}
		req, reqErr := http.NewRequest(method, url, body)
		if reqErr != nil {
			return nil, reqErr
		}
		if request != nil {
			req.Header.Set("Content-Type", "application/json")
		}
		req.Header.Set("Accept", "application/json")
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.auth.AccessToken))

		resp, err = c.httpClient().Do(req)
		if err == nil {
			return resp, nil
		}
		err = fmt.Errorf("error on request [%d/%d] %s, got error %w: %+v", try, totalTrys, url, err, resp)
		c.v("HTTP API Request error: %s", err)

		// sleep only if we will try again
		if try < totalTrys {
			time.Sleep(time.Second * 10)
		}
	}
	return resp, err
}
// jsonAPI performs an authenticated JSON request against path, which is
// resolved relative to the client's BaseURL.
func (c *Client) jsonAPI(method, path string, request, response interface{}) error {
	endpoint := c.BaseURL + path
	return c.jsonRequest(true, method, endpoint, request, response)
}
// jsonRequest performs a request to url, JSON-encoding request (if non-nil)
// as the body and JSON-decoding a 200 OK response into response (if non-nil).
// response must be a pointer, as with json.Unmarshal.
//
// Any status other than 200 OK is returned as an error. If that error
// response carries a body, it is decoded as an errorResponse and its
// HTTPStatus and Message are appended to the error.
func (c *Client) jsonRequest(auth bool, method, url string, request, response interface{}) error {
	var payloadReader io.Reader
	if request != nil {
		jsonPayload, err := json.Marshal(request)
		if err != nil {
			return err
		}
		payloadReader = bytes.NewReader(jsonPayload)
	}
	resp, err := c.apiRequest(auth, method, url, payloadReader)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// got an error, decode it
	if resp.StatusCode != http.StatusOK {
		// resp.Status already includes the reason phrase, e.g. "404 Not Found".
		statusErr := fmt.Errorf("error on request %q: got Status %s", url, resp.Status)

		// Error payloads are small; cap what we buffer from a misbehaving server.
		const maxErrorBody = 1 << 20
		body, readErr := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		if readErr != nil {
			return fmt.Errorf("error reading body %w on errored request: %s", readErr, statusErr.Error())
		}
		// ContentLength is -1 when unknown, so an empty body is only detectable
		// after reading; it is not a JSON decoding failure.
		if len(bytes.TrimSpace(body)) == 0 {
			return statusErr
		}
		var errorResp errorResponse
		if jsonError := json.Unmarshal(body, &errorResp); jsonError != nil {
			return fmt.Errorf("error decoding json %w on errored request: %s", jsonError, statusErr.Error())
		}
		return fmt.Errorf("%w HTTP Status: %d Message: %q", statusErr, errorResp.HTTPStatus, errorResp.Message)
	}

	if response != nil {
		// Decode into the caller's pointer directly rather than into a
		// *interface{} wrapping it.
		if err := json.NewDecoder(resp.Body).Decode(response); err != nil {
			return err
		}
	}
	return nil
}
// Authenticate tests the client's credentials and gets an authentication token from the server.
// Calling this is optional: all other functions check the auth state on their own first and
// authenticate if necessary.
//
// The client's stored token and expiration are replaced only after the new token has been
// decoded and found to expire in the future, so a failed attempt leaves the previous
// state untouched.
func (c *Client) Authenticate() error {
	c.v("authenticating")
	var authResp authResponse
	if err := c.jsonRequest(false, "POST", c.AuthURL, c.Creds, &authResp); err != nil {
		return err
	}
	exp, err := authResp.getExpiration()
	if err != nil {
		return err
	}
	if !exp.After(time.Now()) {
		// surface whatever explanation the server gave, if any
		if authResp.Message != "" {
			return fmt.Errorf("unable to authenticate: %s", authResp.Message)
		}
		return fmt.Errorf("unable to authenticate")
	}
	c.auth = authResp
	c.authExp = exp
	return nil
}
// getExpiration returns the expiration of the authentication token, taken from
// the decoded token's Data.Exp field and interpreted as Unix seconds.
// If the token cannot be decoded, the zero time.Time and the decode error are
// returned.
func (ar *authResponse) getExpiration() (time.Time, error) {
	token, err := jwt.DecodeJWT(ar.AccessToken)
	if err != nil {
		// token must not be touched on failure: it may be nil.
		return time.Time{}, err
	}
	return time.Unix(token.Data.Exp, 0), nil
}
// GetZoneRequestID returns the RequestID of the first request for the given
// zone when requests are sorted by last update, newest first.
// The zone is lower-cased and compared against each request's lower-cased TLD.
//
// Pages of 100 requests (server-side filtered by zone, any status) are fetched
// in turn until a match is found or an empty page is returned. An error is
// returned when no matching request exists.
func (c *Client) GetZoneRequestID(zone string) (string, error) {
	c.v("GetZoneRequestID: %q", zone)
	zone = strings.ToLower(zone)

	filter := RequestsFilter{
		Status: RequestAll,
		Filter: zone,
		Pagination: RequestsPagination{
			Size: 100,
			Page: 0,
		},
		Sort: RequestsSort{
			Field:     SortByLastUpdated,
			Direction: SortDesc,
		},
	}

	for {
		if filter.Pagination.Page > 0 {
			c.v("GetZoneRequestID: zone %q not found yet, requesting page %d", zone, filter.Pagination.Page)
		}
		page, err := c.GetRequests(&filter)
		if err != nil {
			return "", err
		}
		for _, candidate := range page.Requests {
			if strings.ToLower(candidate.TLD) != zone {
				continue
			}
			// a page that reports zero total requests is treated as "not found"
			if page.TotalRequests == 0 {
				return "", fmt.Errorf("no request found for zone %s", zone)
			}
			return candidate.RequestID, nil
		}
		if len(page.Requests) == 0 {
			return "", fmt.Errorf("no request found for zone %s", zone)
		}
		filter.Pagination.Page++
	}
}
// GetAllRequests returns the request information for all requests with the given status.
// status should be one of the package's request status constants (for example RequestAll).
//
// Pages of pageSize requests, sorted by creation time newest first, are fetched until the
// server returns an empty page. Warning: for a large number of results this may be slow.
// If fetching a page fails, the requests collected so far are returned along with the error.
func (c *Client) GetAllRequests(status string) ([]Request, error) {
	c.v("GetAllRequests status: %q", status)
	const pageSize = 100
	filter := RequestsFilter{
		Status: status,
		Filter: "",
		Pagination: RequestsPagination{
			Size: pageSize,
			Page: 0,
		},
		Sort: RequestsSort{
			Field:     SortByCreated,
			Direction: SortDesc,
		},
	}

	out := make([]Request, 0, pageSize)
	for {
		// log each page exactly once, just before it is fetched
		c.v("GetAllRequests status: %q, page %d", status, filter.Pagination.Page)
		requests, err := c.GetRequests(&filter)
		if err != nil {
			return out, err
		}
		if len(requests.Requests) == 0 {
			return out, nil
		}
		out = append(out, requests.Requests...)
		filter.Pagination.Page++
	}
}