Skip to content

Commit

Permalink
Moved refreshSession to the Session struct, added ability for rest tr…
Browse files Browse the repository at this point in the history
…ansport to automatically refresh token if needed
  • Loading branch information
allmightyspiff committed Sep 20, 2024
1 parent 6eff079 commit 303b98d
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 85 deletions.
4 changes: 2 additions & 2 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"files": "go.sum|^.secrets.baseline$",
"lines": null
},
"generated_at": "2024-06-06T22:18:14Z",
"generated_at": "2024-09-20T19:18:35Z",
"plugins_used": [
{
"name": "AWSKeyDetector"
Expand Down Expand Up @@ -242,7 +242,7 @@
"hashed_secret": "6f667d3e9627f5549ffeb1055ff294c34430b837",
"is_secret": false,
"is_verified": false,
"line_number": 171,
"line_number": 194,
"type": "Secret Keyword",
"verified_result": null
}
Expand Down
56 changes: 56 additions & 0 deletions examples/cmd/iam_demo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package cmd

import (
"fmt"
"time"
"github.com/spf13/cobra"

"github.com/softlayer/softlayer-go/services"
"github.com/softlayer/softlayer-go/session"
)

func init() {
rootCmd.AddCommand(iamDemoCmd)
}

var iamDemoCmd = &cobra.Command{
Use: "iam-demo",
Short: "Will make 1 API call per minute and refresh API key when needed.",
RunE: func(cmd *cobra.Command, args []string) error {
return RunIamCmd(cmd, args)
},
}

func RunIamCmd(cmd *cobra.Command, args []string) error {
objectMask := "mask[id,companyName]"

// Sets up the session with authentication headers.
sess := &session.Session{
Endpoint: session.DefaultEndpoint,
IAMToken: "Bearer TOKEN",
IAMRefreshToken: "REFRESH TOKEN",
Debug: true,
}

// creates a reference to the service object (SoftLayer_Account)
service := services.GetAccountService(sess)

// Sets the mask, filter, result limit, and then makes the API call SoftLayer_Account::getHardware()


for {
account, err := service.Mask(objectMask).GetObject()
if err != nil {
fmt.Printf("======= ERROR ======")
return err
}
fmt.Printf("AccountId: %v, CompanyName: %v\n", *account.Id, *account.CompanyName)
fmt.Printf("Refreshing Token for no reason...\n")
sess.RefreshToken()
fmt.Printf("%s\n", sess.IAMToken)
fmt.Printf("Sleeping for 60s.......\n")
time.Sleep(60 * time.Second)
}

return nil
}
90 changes: 7 additions & 83 deletions session/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package session

import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
Expand All @@ -36,20 +35,6 @@ import (

type RestTransport struct{}

const IBMCLOUDIAMENDPOINT = "https://iam.cloud.ibm.com/identity/token"

// IAMTokenResponse ...
type IAMTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
}

// IAMErrorMessage -
type IAMErrorMessage struct {
ErrorMessage string `json:"errormessage"`
ErrorCode string `json:"errorcode"`
}

// DoRequest - Implementation of the TransportHandler interface for handling
// calls to the REST endpoint.
Expand All @@ -68,12 +53,13 @@ func (r *RestTransport) DoRequest(sess *Session, service string, method string,

path := buildPath(service, method, options)

resp, code, err := sendHTTPRequest(
sess,
path,
restMethod,
parameters,
options)
resp, code, err := sendHTTPRequest(sess, path, restMethod, parameters, options)

//Check if this is a refreshable exception
if err != nil && sess.IAMRefreshToken != "" && NeedsRefresh(err) {
sess.RefreshToken()
resp, code, err = sendHTTPRequest(sess, path, restMethod, parameters, options)
}

if err != nil {
//Preserve the original sl error
Expand Down Expand Up @@ -205,22 +191,6 @@ func tryHTTPRequest(

resp, code, err := makeHTTPRequest(sess, path, requestType, requestBody, options)
if err != nil {
if code == 500 && (sess.IAMToken != "" && sess.IAMRefreshToken != "") {
authErr := refreshToken(sess)
if authErr == nil {
if retries--; retries > 0 {
jitter := time.Duration(rand.Int63n(int64(wait)))
wait = wait + jitter/2
time.Sleep(wait)
return tryHTTPRequest(
retries, wait, sess, path, requestType, requestBody, options)
}
}
if authErr != nil {
return resp, code, fmt.Errorf("Unable to refresh auth token: {{%v}}", authErr)
}

}
if !isRetryable(err) {
return resp, code, err
}
Expand Down Expand Up @@ -379,50 +349,4 @@ func findResponseError(code int, resp []byte) error {
return nil
}

func refreshToken(sess *Session) error {

client := http.DefaultClient
reqPayload := url.Values{}
reqPayload.Add("grant_type", "refresh_token")
reqPayload.Add("refresh_token", sess.IAMRefreshToken)

req, err := http.NewRequest("POST", IBMCLOUDIAMENDPOINT, strings.NewReader(reqPayload.Encode()))
if err != nil {
return err
}
req.Header.Add("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("bx:bx")))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Accept", "application/json")
var token IAMTokenResponse
var eresp IAMErrorMessage

resp, err := client.Do(req)
if err != nil {
return err
}

defer resp.Body.Close()

responseBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}

if resp != nil && resp.StatusCode != 200 {
err = json.Unmarshal(responseBody, &eresp)
if err != nil {
return err
}
if eresp.ErrorCode != "" {
return sl.Error{Exception: eresp.ErrorCode, Message: eresp.ErrorMessage}
}
}

err = json.Unmarshal(responseBody, &token)
if err != nil {
return err
}
sess.IAMToken = fmt.Sprintf("%s %s", token.TokenType, token.AccessToken)
sess.IAMRefreshToken = token.RefreshToken
return nil
}
79 changes: 79 additions & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@ import (
"math/rand"
"net"
"net/http"
"net/url"
"os"
"os/user"
"strings"
"time"
"encoding/base64"
"encoding/json"
"io/ioutil"


"github.com/softlayer/softlayer-go/config"
"github.com/softlayer/softlayer-go/sl"
Expand All @@ -44,6 +49,21 @@ func init() {
// DefaultEndpoint is the default endpoint for API calls, when no override is provided.
const DefaultEndpoint = "https://api.softlayer.com/rest/v3.1"

const IBMCLOUDIAMENDPOINT = "https://iam.cloud.ibm.com/identity/token"

// IAMTokenResponse ...
type IAMTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
}

// IAMErrorMessage -
type IAMErrorMessage struct {
ErrorMessage string `json:"errormessage"`
ErrorCode string `json:"errorcode"`
}

var retryableErrorCodes = []string{"SoftLayer_Exception_WebService_RateLimitExceeded"}

// TransportHandler interface for the protocol-specific handling of API requests.
Expand Down Expand Up @@ -319,6 +339,56 @@ func (r *Session) ResetUserAgent() {
r.userAgent = getDefaultUserAgent()
}

// Refreshes an IAM authenticated session
func (r *Session) RefreshToken() error {

Logger.Println("[DEBUG] Refreshing IAM Token")
client := http.DefaultClient
reqPayload := url.Values{}
reqPayload.Add("grant_type", "refresh_token")
reqPayload.Add("refresh_token", r.IAMRefreshToken)

req, err := http.NewRequest("POST", IBMCLOUDIAMENDPOINT, strings.NewReader(reqPayload.Encode()))
if err != nil {
return err
}
req.Header.Add("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("bx:bx")))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Accept", "application/json")
var token IAMTokenResponse
var eresp IAMErrorMessage

resp, err := client.Do(req)
if err != nil {
return err
}

defer resp.Body.Close()

responseBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}

if resp != nil && resp.StatusCode != 200 {
err = json.Unmarshal(responseBody, &eresp)
if err != nil {
return err
}
if eresp.ErrorCode != "" {
return sl.Error{Exception: eresp.ErrorCode, Message: eresp.ErrorMessage}
}
}

err = json.Unmarshal(responseBody, &token)
if err != nil {
return err
}
r.IAMToken = fmt.Sprintf("%s %s", token.TokenType, token.AccessToken)
r.IAMRefreshToken = token.RefreshToken
return nil
}

func envFallback(keyName string, value *string) {
if *value == "" {
*value = os.Getenv(keyName)
Expand Down Expand Up @@ -375,6 +445,15 @@ func isRetryable(err error) bool {
return isTimeout(err) || hasRetryableCode(err)
}

func NeedsRefresh(err error) bool {
if slError, ok := err.(sl.Error); ok {
if slError.StatusCode == 500 && slError.Exception == "SoftLayer_Exception_Account_Authentication_AccessTokenValidation" {
return true
}
}
return false
}

// Set ENV Variable SL_USERAGENT to append that to the useragent string
func getDefaultUserAgent() string {
envAgent := os.Getenv("SL_USERAGENT")
Expand Down

0 comments on commit 303b98d

Please sign in to comment.