Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added a simple access token caching feature 2 #33

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 169 additions & 14 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ import (
"io"
"io/ioutil"
"log"
"math"
"mime/multipart"
"net/http"
"net/url"
"os"
"regexp"
"strings"
"time"
)

const (
Expand Down Expand Up @@ -40,6 +43,11 @@ type Server struct {
Configuration
}

type TokenCache struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
}

// New returns an initialized Secrets object
func New(config Configuration) (*Server, error) {
if config.ServerURL == "" && config.Tenant == "" || config.ServerURL != "" && config.Tenant != "" {
Expand Down Expand Up @@ -252,12 +260,43 @@ func (s Server) uploadFile(secretId int, fileField SecretField) error {
return err
}

func (s *Server) setCacheAccessToken(value string, expiresIn int) error {
cache := TokenCache{}
cache.AccessToken = value
cache.ExpiresIn = (int(time.Now().Unix()) + expiresIn) - int(math.Floor(float64(expiresIn)*0.9))

data, _ := json.Marshal(cache)
os.Setenv("SS_AT", string(data))
return nil
}

func (s *Server) getCacheAccessToken() (string, bool) {
data, ok := os.LookupEnv("SS_AT")
if !ok {
os.Setenv("SS_AT", "")
return "", ok
}
cache := TokenCache{}
if err := json.Unmarshal([]byte(data), &cache); err != nil {
return "", false
}
if time.Now().Unix() < int64(cache.ExpiresIn) {
return cache.AccessToken, true
}
return "", false
}

// getAccessToken gets an OAuth2 Access Grant and returns the token
// endpoint and get an accessGrant.
func (s *Server) getAccessToken() (string, error) {
if s.Credentials.Token != "" {
return s.Credentials.Token, nil
}
accessToken, found := s.getCacheAccessToken()
if found {
return accessToken, nil
}

response, err := s.checkPlatformDetails()
if err != nil {
log.Print("Error while checking server details:", err)
Expand Down Expand Up @@ -292,6 +331,7 @@ func (s *Server) getAccessToken() (string, error) {
log.Print("[ERROR] parsing grant response:", err)
return "", err
}
s.setCacheAccessToken(grant.AccessToken, grant.ExpiresIn)
return grant.AccessToken, nil
} else {
return response, nil
Expand All @@ -316,29 +356,63 @@ func (s *Server) checkPlatformDetails() (string, error) {
} else {
isHealthy := checkJSONResponse(platformHelthCheckUrl)
if isHealthy {
requestData := url.Values{}
requestData.Set("grant_type", "client_credentials")
requestData.Set("client_id", s.Credentials.Username)
requestData.Set("client_secret", s.Credentials.Password)
requestData.Set("scope", "xpmheadless")
requestData := map[string]string{
"User": s.Credentials.Username,
"Version": "1.0",
}
jsonData, err := json.Marshal(requestData)
if err != nil {
log.Print("Error marshaling JSON:", err)
return "", err
}

req, err := http.NewRequest("POST", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "identity/api/oauth2/token/xpmplatform"), bytes.NewBufferString(requestData.Encode()))
req, err := http.NewRequest("POST", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "identity/Security/StartAuthentication"), bytes.NewBuffer(jsonData))
if err != nil {
log.Print("Error creating HTTP request:", err)
return "", err
}

req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

data, _, err := handleResponse((&http.Client{}).Do(req))
if err != nil {
log.Print("[ERROR] get token response error:", err)
log.Print("[ERROR] start authetication response error:", err)
return "", err
}

var startAuthjsonResponse StartAuthResponse
if err = json.Unmarshal(data, &startAuthjsonResponse); err != nil {
log.Print("[ERROR] parsing start auth response:", err)
return "", err
}

requestData = map[string]string{
"Answer": s.Credentials.Password,
"MechanismId": findMechanismId(startAuthjsonResponse),
"Action": "Answer",
"SessionId": startAuthjsonResponse.Result.SessionId,
"TenantId": startAuthjsonResponse.Result.TenantId,
}

jsonData, err = json.Marshal(requestData)
if err != nil {
log.Print("Error marshaling JSON:", err)
return "", err
}

var tokenjsonResponse OAuthTokens
if err = json.Unmarshal(data, &tokenjsonResponse); err != nil {
log.Print("[ERROR] parsing get token response:", err)
req, err = http.NewRequest("POST", fmt.Sprintf("%s/%s", strings.Trim(baseURL, "/"), "identity/Security/AdvanceAuthentication"), bytes.NewBuffer(jsonData))
if err != nil {
log.Print("Error creating HTTP request:", err)
return "", err
}

data, _, err = handleResponse((&http.Client{}).Do(req))
if err != nil {
log.Print("[ERROR] advance authetication response error:", err)
return "", err
}

var advanceAuthJsonResponse AdvanceAuthResponse
if err = json.Unmarshal(data, &advanceAuthJsonResponse); err != nil {
log.Print("[ERROR] parsing advance auth response:", err)
return "", err
}

Expand All @@ -347,7 +421,7 @@ func (s *Server) checkPlatformDetails() (string, error) {
log.Print("Error creating HTTP request:", err)
return "", err
}
req.Header.Add("Authorization", "Bearer "+tokenjsonResponse.AccessToken)
req.Header.Add("Authorization", "Bearer "+advanceAuthJsonResponse.Result.OAuthTokens.AccessToken)

data, _, err = handleResponse((&http.Client{}).Do(req))
if err != nil {
Expand All @@ -374,7 +448,7 @@ func (s *Server) checkPlatformDetails() (string, error) {
return "", fmt.Errorf("no configured vault found")
}

return tokenjsonResponse.AccessToken, nil
return advanceAuthJsonResponse.Result.OAuthTokens.AccessToken, nil
}
}
return "", fmt.Errorf("invalid URL")
Expand Down Expand Up @@ -403,6 +477,17 @@ func checkJSONResponse(url string) bool {
}
}

func findMechanismId(saResponse StartAuthResponse) string {
for _, challenge := range saResponse.Result.Challenges {
for _, mechanism := range challenge.Mechanisms {
if mechanism.PromptSelectMech == "Password" {
return mechanism.MechanismId
}
}
}
return ""
}

type Response struct {
Healthy bool `json:"healthy"`
DatabaseHealthy bool `json:"databaseHealthy"`
Expand All @@ -411,6 +496,48 @@ type Response struct {
ScheduledForDeletion bool `json:"scheduledForDeletion"`
}

type ClientHints struct {
PersistDefault bool `json:"PersistDefault"`
AllowPersist bool `json:"AllowPersist"`
AllowForgotPassword bool `json:"AllowForgotPassword"`
StartingPoint string `json:"StartingPoint"`
RequestedUsername string `json:"RequestedUsername"`
}

type Mechanism struct {
AnswerType string `json:"AnswerType"`
Name string `json:"Name"`
PromptMechChosen string `json:"PromptMechChosen"`
PromptSelectMech string `json:"PromptSelectMech"`
MechanismId string `json:"MechanismId"`
}

type Challenge struct {
Mechanisms []Mechanism `json:"Mechanisms"`
}

type Result struct {
ClientHints ClientHints `json:"ClientHints"`
Version string `json:"Version"`
SessionId string `json:"SessionId"`
AllowLoginMfaCache bool `json:"AllowLoginMfaCache"`
Challenges []Challenge `json:"Challenges"`
Summary string `json:"Summary"`
TenantId string `json:"TenantId"`
}

type StartAuthResponse struct {
Success bool `json:"success"`
Result Result `json:"Result"`
Message interface{} `json:"Message"`
MessageID interface{} `json:"MessageID"`
Exception interface{} `json:"Exception"`
ErrorID interface{} `json:"ErrorID"`
ErrorCode interface{} `json:"ErrorCode"`
IsSoftError bool `json:"IsSoftError"`
InnerExceptions interface{} `json:"InnerExceptions"`
}

type OAuthTokens struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
Expand All @@ -421,6 +548,34 @@ type OAuthTokens struct {
Scope string `json:"scope"`
}

type AdvanceAuthResult struct {
AuthLevel string `json:"AuthLevel"`
DisplayName string `json:"DisplayName"`
OAuthTokens OAuthTokens `json:"OAuthTokens"`
UserId string `json:"UserId"`
EmailAddress string `json:"EmailAddress"`
UserDirectory string `json:"UserDirectory"`
StartingPoint string `json:"StartingPoint"`
PodFqdn string `json:"PodFqdn"`
User string `json:"User"`
CustomerID string `json:"CustomerID"`
SystemID string `json:"SystemID"`
SourceDsType string `json:"SourceDsType"`
Summary string `json:"Summary"`
}

type AdvanceAuthResponse struct {
Success bool `json:"success"`
Result AdvanceAuthResult `json:"Result"`
Message interface{} `json:"Message"`
MessageID interface{} `json:"MessageID"`
Exception interface{} `json:"Exception"`
ErrorID interface{} `json:"ErrorID"`
ErrorCode interface{} `json:"ErrorCode"`
IsSoftError bool `json:"IsSoftError"`
InnerExceptions interface{} `json:"InnerExceptions"`
}

type Connection struct {
Url string `json:"url"`
OAuthProfileId string `json:"oAuthProfileId"`
Expand Down
Loading