Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the multi-user selection feature and modify the code for access, refresh, and grant tokens #67

Merged
merged 4 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 64 additions & 204 deletions cmd/other/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,12 +411,7 @@ func executeUserLogin(currentEnv string) {
exitWithError()
}

homeDir, err := os.UserHomeDir()
if err != nil {
pterm.Error.Println("Failed to get user home directory:", err)
exitWithError()
}

homeDir, _ := os.UserHomeDir()
// Get user_id from current environment
mainViper := viper.New()
settingPath := filepath.Join(homeDir, ".cfctl", "setting.toml")
Expand All @@ -430,39 +425,57 @@ func executeUserLogin(currentEnv string) {

userID := mainViper.GetString(fmt.Sprintf("environments.%s.user_id", currentEnv))
if userID == "" {
pterm.Error.Println("No user ID found in current environment configuration.")
exitWithError()
}
userIDInput := pterm.DefaultInteractiveTextInput
userID, _ = userIDInput.Show("Enter your User ID")

// Display the current user ID
pterm.Info.Printf("Logged in as: %s\n", userID)
mainViper.Set(fmt.Sprintf("environments.%s.user_id", currentEnv), userID)
if err := mainViper.WriteConfig(); err != nil {
pterm.Error.Printf("Failed to save user ID to config: %v\n", err)
exitWithError()
}
} else {
pterm.Info.Printf("Logging in as: %s\n", userID)
}

// Prompt for password
password := promptPassword()
// Check for valid tokens first
accessToken, refreshToken, newAccessToken, err := getValidTokens(currentEnv)
var password string

// Extract the middle part of the environment name for `name`
// Extract domain name from environment
nameParts := strings.Split(currentEnv, "-")
if len(nameParts) < 3 {
pterm.Error.Println("Environment name format is invalid.")
exitWithError()
}
name := nameParts[1]

// Fetch Domain ID using the base URL and domain name
// Fetch Domain ID
domainID, err := fetchDomainID(baseUrl, name)
if err != nil {
pterm.Error.Println("Failed to fetch Domain ID:", err)
exitWithError()
}

// Issue new tokens
accessToken, refreshToken, err := issueToken(baseUrl, userID, password, domainID)
if err != nil {
pterm.Error.Printf("Failed to issue token: %v\n", err)
// If refresh token is not valid, get new tokens with password
if refreshToken == "" || isTokenExpired(refreshToken) {
password = promptPassword()
accessToken, refreshToken, err = issueToken(baseUrl, userID, password, domainID)
if err != nil {
pterm.Error.Printf("Failed to issue token: %v\n", err)
exitWithError()
}
}

// Create cache directory and save tokens
envCacheDir := filepath.Join(homeDir, ".cfctl", "cache", currentEnv)
if err := os.MkdirAll(envCacheDir, 0700); err != nil {
pterm.Error.Printf("Failed to create cache directory: %v\n", err)
exitWithError()
}

// Use the tokens
pterm.Info.Printf("Logged in as %s\n", userID)

// Use the tokens to fetch workspaces and role
workspaces, err := fetchWorkspaces(baseUrl, accessToken)
if err != nil {
pterm.Error.Println("Failed to fetch workspaces:", err)
Expand All @@ -475,6 +488,7 @@ func executeUserLogin(currentEnv string) {
exitWithError()
}

// Determine scope and select workspace
scope := determineScope(roleType, len(workspaces))
var workspaceID string
if roleType == "DOMAIN_ADMIN" {
Expand All @@ -491,19 +505,13 @@ func executeUserLogin(currentEnv string) {
}

// Grant new token using the refresh token
grantToken, err := grantToken(baseUrl, refreshToken, scope, domainID, workspaceID)
newAccessToken, err = grantToken(baseUrl, refreshToken, scope, domainID, workspaceID)
if err != nil {
pterm.Error.Println("Failed to retrieve new access token:", err)
exitWithError()
}

// Save tokens to cache
envCacheDir := filepath.Join(homeDir, ".cfctl", "cache", currentEnv)
if err := os.MkdirAll(envCacheDir, 0700); err != nil {
pterm.Error.Printf("Failed to create cache directory: %v\n", err)
exitWithError()
}

// Save all tokens
if err := os.WriteFile(filepath.Join(envCacheDir, "access_token"), []byte(accessToken), 0600); err != nil {
pterm.Error.Printf("Failed to save access token: %v\n", err)
exitWithError()
Expand All @@ -514,7 +522,7 @@ func executeUserLogin(currentEnv string) {
exitWithError()
}

if err := os.WriteFile(filepath.Join(envCacheDir, "grant_token"), []byte(grantToken), 0600); err != nil {
if err := os.WriteFile(filepath.Join(envCacheDir, "grant_token"), []byte(newAccessToken), 0600); err != nil {
pterm.Error.Printf("Failed to save grant token: %v\n", err)
exitWithError()
}
Expand All @@ -529,171 +537,6 @@ func promptPassword() string {
return password
}

// Prompt for user selection, now receiving 'users' slice as an argument
func promptUserSelection(max int, users []interface{}) int {
if err := keyboard.Open(); err != nil {
pterm.Error.Println("Failed to initialize keyboard:", err)
exitWithError()
}
defer keyboard.Close()

selectedIndex := 0
currentPage := 0
pageSize := 10
searchMode := false
searchTerm := ""
filteredUsers := users

for {
fmt.Print("\033[H\033[2J") // Clear the screen

// Apply search filter
if searchTerm != "" {
filteredUsers = filterUsers(users, searchTerm)
if len(filteredUsers) == 0 {
filteredUsers = users // Show all users if no search results
}
} else {
filteredUsers = users
}

// Calculate pagination
totalUsers := len(filteredUsers)
totalPages := (totalUsers + pageSize - 1) / pageSize
startIndex := currentPage * pageSize
endIndex := startIndex + pageSize
if endIndex > totalUsers {
endIndex = totalUsers
}

// Display header with page information
pterm.DefaultHeader.WithFullWidth().
WithBackgroundStyle(pterm.NewStyle(pterm.BgDarkGray)).
WithTextStyle(pterm.NewStyle(pterm.FgLightWhite)).
Printf("Select a user account (Page %d of %d)", currentPage+1, totalPages)

// Display option to add new user first
if selectedIndex == 0 {
pterm.Printf("→ %d: Add new user\n", 1)
} else {
pterm.Printf(" %d: Add new user\n", 1)
}

// Display existing users
for i := startIndex; i < endIndex; i++ {
userMap := filteredUsers[i].(map[string]interface{})
if i+1 == selectedIndex {
pterm.Printf("→ %d: %s\n", i+2, userMap["userid"].(string))
} else {
pterm.Printf(" %d: %s\n", i+2, userMap["userid"].(string))
}
}

// Show navigation help
pterm.DefaultBasicText.WithStyle(pterm.NewStyle(pterm.FgGray)).
Println("\nNavigation: [h]prev-page [j]down [k]up [l]next-page [/]search [Enter]select [q]quit")

// Show search prompt if in search mode
if searchMode {
fmt.Println()
pterm.Info.Printf("Search (ESC to cancel, Enter to confirm): %s", searchTerm)
}

// Get keyboard input
char, key, err := keyboard.GetKey()
if err != nil {
pterm.Error.Println("Error reading keyboard input:", err)
exitWithError()
}

// Handle search mode input
if searchMode {
switch key {
case keyboard.KeyEsc:
searchMode = false
searchTerm = ""
filteredUsers = users // Return to full user list when search term is cleared
case keyboard.KeyBackspace, keyboard.KeyBackspace2:
if len(searchTerm) > 0 {
searchTerm = searchTerm[:len(searchTerm)-1]
}
case keyboard.KeyEnter:
searchMode = false
default:
if char != 0 {
searchTerm += string(char)
}
}
currentPage = 0
selectedIndex = 0
continue
}

// Handle normal mode input
switch key {
case keyboard.KeyEnter:
if selectedIndex == 0 {
return len(users) + 1 // Add new user
} else if selectedIndex <= len(filteredUsers) {
// Find the original index of the selected user
selectedUserMap := filteredUsers[selectedIndex-1].(map[string]interface{})
selectedUserID := selectedUserMap["userid"].(string)

for i, user := range users {
userMap := user.(map[string]interface{})
if userMap["userid"].(string) == selectedUserID {
return i + 1
}
}
}
}

switch char {
case 'j': // Down
if selectedIndex < min(endIndex-startIndex, totalUsers) {
selectedIndex++
}
case 'k': // Up
if selectedIndex > 0 {
selectedIndex--
}
case 'l': // Next page
if currentPage < totalPages-1 {
currentPage++
selectedIndex = 0
}
case 'h': // Previous page
if currentPage > 0 {
currentPage--
selectedIndex = 0
}
case '/': // Enter search mode
searchMode = true
searchTerm = ""
selectedIndex = 0
case 'q', 'Q':
fmt.Println()
pterm.Error.Println("User selection cancelled.")
os.Exit(1)
}
}
}

// filterUsers filters the users list based on the search term
func filterUsers(users []interface{}, searchTerm string) []interface{} {
var filtered []interface{}
searchTerm = strings.ToLower(searchTerm)

for _, user := range users {
userMap := user.(map[string]interface{})
userid := strings.ToLower(userMap["userid"].(string))
if strings.Contains(userid, searchTerm) {
filtered = append(filtered, user)
}
}
return filtered
}

// min returns the minimum of two integers
func min(a, b int) int {
if a < b {
Expand Down Expand Up @@ -835,7 +678,7 @@ func saveCredentials(currentEnv, userID, encryptedPassword, accessToken, refresh
if grantToken != "" {
if err := os.WriteFile(filepath.Join(envCacheDir, "grant_token"), []byte(grantToken), 0600); err != nil {
pterm.Error.Printf("Failed to save grant token: %v\n", err)
exitWithError()
exitWithError()
}
}
}
Expand Down Expand Up @@ -1810,25 +1653,42 @@ func readTokenFromFile(envDir, tokenType string) (string, error) {
}

// getValidTokens checks for existing valid tokens in the environment cache directory
func getValidTokens(currentEnv string) (accessToken, refreshToken string, err error) {
func getValidTokens(currentEnv string) (accessToken, refreshToken, newAccessToken string, err error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", "", err
return "", "", "", err
}

envCacheDir := filepath.Join(homeDir, ".cfctl", "cache", currentEnv)

// Try to read and validate access token
// Try to read and validate grant token first
if newAccessToken, err = readTokenFromFile(envCacheDir, "grant_token"); err == nil {
claims, err := validateAndDecodeToken(newAccessToken)
if err == nil {
// Check if token has expired
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() < int64(exp) {
accessToken, _ = readTokenFromFile(envCacheDir, "access_token")
refreshToken, _ = readTokenFromFile(envCacheDir, "refresh_token")
return accessToken, refreshToken, newAccessToken, nil
}
}
}
}

// If grant token is invalid or expired, check refresh token
if accessToken, err = readTokenFromFile(envCacheDir, "access_token"); err == nil {
if !isTokenExpired(accessToken) {
// Try to read refresh token only if access token is valid
if refreshToken, err = readTokenFromFile(envCacheDir, "refresh_token"); err == nil {
if !isTokenExpired(refreshToken) {
return accessToken, refreshToken, nil
if refreshToken, err = readTokenFromFile(envCacheDir, "refresh_token"); err == nil {
claims, err := validateAndDecodeToken(refreshToken)
if err == nil {
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() < int64(exp) {
return accessToken, refreshToken, "", nil
}
}
}
}
}

return "", "", fmt.Errorf("no valid tokens found")
return "", "", "", fmt.Errorf("no valid tokens found")
}
7 changes: 2 additions & 5 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,6 @@ func loadCachedEndpoints() (map[string]string, error) {

// Create environment-specific cache directory
envCacheDir := filepath.Join(home, ".cfctl", "cache", currentEnv)
if err := os.MkdirAll(envCacheDir, 0755); err != nil {
return nil, err
}

// Read from environment-specific cache file
cacheFile := filepath.Join(envCacheDir, "endpoints.toml")
Expand Down Expand Up @@ -348,7 +345,7 @@ func loadConfig() (*Config, error) {
// Try to read main setting first
mainV := viper.New()
mainV.SetConfigFile(settingFile)
mainV.SetConfigType("toml") // Explicitly set config type to TOML
mainV.SetConfigType("toml") // Explicitly set config type to TOML
mainConfigErr := mainV.ReadInConfig()

if mainConfigErr != nil {
Expand All @@ -373,7 +370,7 @@ func loadConfig() (*Config, error) {
if endpoint == "" || token == "" {
cacheV := viper.New()
cacheV.SetConfigFile(cacheConfigFile)
cacheV.SetConfigType("toml") // Explicitly set config type to TOML
cacheV.SetConfigType("toml") // Explicitly set config type to TOML

if err := cacheV.ReadInConfig(); err == nil {
// If no current environment set, try to get it from cache setting
Expand Down
Loading