Skip to content

Commit

Permalink
Addressing HO comments
Browse files Browse the repository at this point in the history
  • Loading branch information
denisonbarbosa committed Jun 3, 2024
1 parent 3f692ce commit 073a91e
Show file tree
Hide file tree
Showing 14 changed files with 88 additions and 82 deletions.
2 changes: 1 addition & 1 deletion config/oidc-broker.broker
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ offline_expiration = 180

# The directory where the user's home directory will be created.
# The user home directory will be created in the format of $homedir_path/$username
homedir_path = /home/
home_base_dir = /home
79 changes: 34 additions & 45 deletions internal/broker/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type Config struct {
ClientID string
CachePath string
OfflineExpiration string
HomeDirPath string
HomeBaseDir string
}

// Broker is the real implementation of the broker to track sessions and process oidc calls.
Expand Down Expand Up @@ -88,7 +88,7 @@ type isAuthenticatedCtx struct {
}

type option struct {
// skipJWTSignatureCheck is used to skip the JWT validation done by the oidc dependency.
// skipJWTSignatureCheck is used to skip the JWT validation done by the oidc web server.
skipJWTSignatureCheck bool
providerInfo providers.ProviderInfoer
}
Expand Down Expand Up @@ -129,9 +129,9 @@ func New(cfg Config, args ...Option) (b *Broker, err error) {
offlineExpiration = time.Duration(intValue*24) * time.Hour
}

homeDirPath := "/home/"
if cfg.HomeDirPath != "" {
homeDirPath = cfg.HomeDirPath
homeDirPath := "/home"
if cfg.HomeBaseDir != "" {
homeDirPath = cfg.HomeBaseDir
}

// Generate a new private key for the broker.
Expand Down Expand Up @@ -212,23 +212,23 @@ func (b *Broker) GetAuthenticationModes(sessionID string, supportedUILayouts []m
return nil, err
}

supportedModes := b.getSupportedModes(supportedUILayouts)
supportedAuthModes := b.supportedAuthModesFromLayout(supportedUILayouts)

// Checks if the token exists in the cache.
_, err = os.Stat(session.cachePath)
tokenExists := err == nil

availableModes, err := b.providerInfo.AuthenticationModes(
availableModes, err := b.providerInfo.CurrentAuthenticationModesOffered(
session.mode,
b.loadLastUsedMode(&session),
supportedAuthModes,
tokenExists,
session.currentAuthStep)
if err != nil {
return nil, err
}

for _, id := range availableModes {
label, ok := supportedModes[id]
label, ok := supportedAuthModes[id]
if !ok {
return nil, fmt.Errorf("required mode %q is not supported", id)
}
Expand All @@ -238,15 +238,15 @@ func (b *Broker) GetAuthenticationModes(sessionID string, supportedUILayouts []m
})
}

session.supportedModes = supportedModes
session.supportedModes = supportedAuthModes
if err := b.updateSession(sessionID, session); err != nil {
return nil, err
}

return authModes, nil
}

func (b *Broker) getSupportedModes(supportedUILayouts []map[string]string) (supportedModes map[string]string) {
func (b *Broker) supportedAuthModesFromLayout(supportedUILayouts []map[string]string) (supportedModes map[string]string) {
supportedModes = make(map[string]string)
for _, layout := range supportedUILayouts {
supportedEntries := strings.Split(strings.TrimPrefix(layout["entry"], "optional:"), ",")
Expand Down Expand Up @@ -358,7 +358,7 @@ func (b *Broker) IsAuthenticated(sessionID, authenticationData string) (string,
}
}

ctx, err := b.startAuthentication(sessionID)
ctx, err := b.startAuthenticate(sessionID)
if err != nil {
return AuthDenied, "", err
}
Expand Down Expand Up @@ -386,10 +386,6 @@ func (b *Broker) IsAuthenticated(sessionID, authenticationData string) (string,
access = AuthDenied
data = `{"message": "maximum number of attempts reached"}`
}
case AuthGranted:
if err := b.cacheLastUsedMode(&session); err != nil {
slog.Warn(fmt.Sprintf("Could not cache authentication mode used: %v", err))
}

case AuthNext:
session.currentAuthStep++
Expand All @@ -411,6 +407,7 @@ func (b *Broker) handleIsAuthenticated(ctx context.Context, session *sessionInfo
var authInfo authCachedInfo
var userClaims claims
var groups []group.Info
var tokenRefreshed bool

switch session.selectedMode {
case "qrcode":
Expand All @@ -433,7 +430,7 @@ func (b *Broker) handleIsAuthenticated(ctx context.Context, session *sessionInfo
return AuthNext, ""

case "password":
authInfo, err = b.loadAuthInfo(session, challenge)
authInfo, tokenRefreshed, err = b.loadAuthInfo(session, challenge)
if err != nil {
return AuthRetry, fmt.Sprintf(`{"message": "could not authenticate user: %v"}`, err)
}
Expand All @@ -450,6 +447,7 @@ func (b *Broker) handleIsAuthenticated(ctx context.Context, session *sessionInfo
if !ok {
return AuthDenied, `{"message": "could not get required information"}`
}
tokenRefreshed = true
}

if authInfo.UserInfo == "" {
Expand All @@ -464,14 +462,18 @@ func (b *Broker) handleIsAuthenticated(ctx context.Context, session *sessionInfo
}
}

if !tokenRefreshed {
return AuthGranted, fmt.Sprintf(`{"userinfo": %s}`, authInfo.UserInfo)
}

if err := b.cacheAuthInfo(session, authInfo, challenge); err != nil {
return AuthRetry, fmt.Sprintf(`{"message": "could not update token: %v"}`, err)
}

return AuthGranted, fmt.Sprintf(`{"userinfo": %s}`, authInfo.UserInfo)
}

func (b *Broker) startAuthentication(sessionID string) (context.Context, error) {
func (b *Broker) startAuthenticate(sessionID string) (context.Context, error) {
session, err := b.getSession(sessionID)
if err != nil {
return nil, err
Expand Down Expand Up @@ -592,26 +594,8 @@ func (b *Broker) cacheAuthInfo(session *sessionInfo, authInfo authCachedInfo, pa
return nil
}

// cacheLastUsedMode caches the last selected mode for the user.
func (b *Broker) cacheLastUsedMode(session *sessionInfo) error {
err := os.WriteFile(session.cachePath+"-mode", []byte(session.firstSelectedMode), 0600)
if err != nil {
return fmt.Errorf("could not save last used mode: %v", err)
}
return nil
}

// loadLastUsedMode loads the last selected mode for the user.
func (b *Broker) loadLastUsedMode(session *sessionInfo) string {
mode, err := os.ReadFile(session.cachePath + "-mode")
if err != nil {
slog.Warn(fmt.Sprintf("Could not load last used mode: %v", err))
}
return string(mode)
}

// loadAuthInfo deserializes the token from the cache and refreshes it if needed.
func (b *Broker) loadAuthInfo(session *sessionInfo, password string) (loadedInfo authCachedInfo, err error) {
func (b *Broker) loadAuthInfo(session *sessionInfo, password string) (loadedInfo authCachedInfo, refreshed bool, err error) {
defer func() {
// Override the error so that we don't leak information. Also, abstract it for the user.
// We still log as error for the admin to get access.
Expand All @@ -623,43 +607,48 @@ func (b *Broker) loadAuthInfo(session *sessionInfo, password string) (loadedInfo

s, err := os.ReadFile(session.cachePath)
if err != nil {
return authCachedInfo{}, fmt.Errorf("could not read token: %v", err)
return authCachedInfo{}, false, fmt.Errorf("could not read token: %v", err)
}

deserialized, err := decrypt(s, []byte(password))
if err != nil {
return authCachedInfo{}, fmt.Errorf("could not deserializing token: %v", err)
return authCachedInfo{}, false, fmt.Errorf("could not deserializing token: %v", err)
}

var cachedInfo authCachedInfo
if err := json.Unmarshal(deserialized, &cachedInfo); err != nil {
return authCachedInfo{}, fmt.Errorf("could not unmarshaling token: %v", err)
return authCachedInfo{}, false, fmt.Errorf("could not unmarshaling token: %v", err)
}

// We always try to rewrite the token file to update the last access time.
refreshed = true
// Refresh token automatically if needed. If the service is unavailable and token is still valid from the broker
// perspective (i.e. last modification was between now and now-expiration), we can use it.
tok, err := b.auth.oauthCfg.TokenSource(context.Background(), cachedInfo.Token).Token()
if err != nil {
castErr := &oauth2.RetrieveError{}
if errors.As(err, &castErr) && castErr.Response.StatusCode != http.StatusServiceUnavailable {
return authCachedInfo{}, fmt.Errorf("could not refresh token: %v", err)
return authCachedInfo{}, false, fmt.Errorf("could not refresh token: %v", err)
}

st, err := os.Stat(session.cachePath)
if err != nil {
return authCachedInfo{}, fmt.Errorf("could not get token file info: %v", err)
return authCachedInfo{}, false, fmt.Errorf("could not get token file info: %v", err)
}

if st.ModTime().Add(b.offlineExpiration).Before(time.Now()) {
return authCachedInfo{}, errors.New("token exceeded offline expiration")
return authCachedInfo{}, false, errors.New("token exceeded offline expiration")
}

// This means we are using cached user information, so we don't need to query the provider for them.
tok = cachedInfo.Token
loadedInfo.UserInfo = cachedInfo.UserInfo

// If we are using the cached token offline, we don't want to update its modification time in order to control
// offline expiration.
refreshed = false
}
loadedInfo.Token = tok

loadedInfo.RawIDToken = cachedInfo.RawIDToken

// If the ID token was refreshed, we overwrite the cached one.
Expand All @@ -668,7 +657,7 @@ func (b *Broker) loadAuthInfo(session *sessionInfo, password string) (loadedInfo
loadedInfo.RawIDToken = refreshedIDToken
}

return loadedInfo, nil
return loadedInfo, refreshed, nil
}

func (b *Broker) fetchUserInfo(ctx context.Context, session *sessionInfo, t *authCachedInfo) (userClaims claims, userGroups []group.Info, err error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/broker/broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ func TestFetchUserInfo(t *testing.T) {
IssuerURL: defaultProvider.URL,
ClientID: "test-client-id",
CachePath: t.TempDir(),
HomeDirPath: homeDirPath,
HomeBaseDir: homeDirPath,
}

mockInfoer := &testutils.MockProviderInfoer{
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

2 changes: 1 addition & 1 deletion internal/dbusservice/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ const (
// offlineExpirationKey is the key in the config file for the offline expiration.
offlineExpirationKey = "offline_expiration"
// homeDirKey is the key in the config file for the home directory prefix.
homeDirKey = "homedir_path"
homeDirKey = "home_base_dir"
)
2 changes: 1 addition & 1 deletion internal/dbusservice/dbusservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func New(_ context.Context, cfgPath, cachePath string) (s *Service, err error) {
IssuerURL: cfg[oidcSection][issuerKey],
ClientID: cfg[oidcSection][clientIDKey],
OfflineExpiration: cfg[oidcSection][offlineExpirationKey],
HomeDirPath: cfg[oidcSection][homeDirKey],
HomeBaseDir: cfg[oidcSection][homeDirKey],
CachePath: cachePath,
}
b, err := broker.New(bCfg)
Expand Down
27 changes: 18 additions & 9 deletions internal/providers/microsoft_entra_id/microsoft-entra-id.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package microsoft_entra_id
import (
"context"
"errors"
"fmt"
"strings"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
Expand Down Expand Up @@ -89,30 +90,38 @@ func (p MSEntraIDProvider) GetGroups(token *oauth2.Token) ([]group.Info, error)
return groups, nil
}

// AuthenticationModes returns the generic authentication modes supported by the provider.
func (p MSEntraIDProvider) AuthenticationModes(sessionMode, _ string, tokenExists bool, currentAuthStep int) ([]string, error) {
var modes []string
// CurrentAuthenticationModesOffered returns the generic authentication modes supported by the provider.
//
// Token validity is not considered, only the presence of a token.
func (p MSEntraIDProvider) CurrentAuthenticationModesOffered(sessionMode string, supportedAuthModes map[string]string, tokenExists bool, currentAuthStep int) ([]string, error) {
var offeredModes []string
switch sessionMode {
case "passwd":
if !tokenExists {
return nil, errors.New("user has no cached token")
}
modes = []string{"password"}
offeredModes = []string{"password"}
if currentAuthStep > 0 {
modes = []string{"newpassword"}
offeredModes = []string{"newpassword"}
}

default: // auth mode
modes = []string{"qrcode"}
offeredModes = []string{"qrcode"}
if tokenExists {
modes = []string{"password", "qrcode"}
offeredModes = []string{"password", "qrcode"}
}
if currentAuthStep > 0 {
modes = []string{"newpassword"}
offeredModes = []string{"newpassword"}
}
}

return modes, nil
for _, mode := range offeredModes {
if _, ok := supportedAuthModes[mode]; !ok {
return nil, fmt.Errorf("auth mode %q required by the provider, but is not supported locally", mode)
}
}

return offeredModes, nil
}

type azureTokenCredential struct {
Expand Down
25 changes: 16 additions & 9 deletions internal/providers/noprovider/noprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package noprovider

import (
"errors"
"fmt"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/ubuntu/oidc-broker/internal/providers/group"
Expand All @@ -27,28 +28,34 @@ func (p NoProvider) GetGroups(_ *oauth2.Token) ([]group.Info, error) {
return nil, nil
}

// AuthenticationModes returns the generic authentication modes supported by the provider.
func (p NoProvider) AuthenticationModes(sessionMode, _ string, tokenExists bool, currentAuthStep int) ([]string, error) {
var modes []string
// CurrentAuthenticationModesOffered returns the generic authentication modes supported by the provider.
func (p NoProvider) CurrentAuthenticationModesOffered(sessionMode string, supportedAuthModes map[string]string, tokenExists bool, currentAuthStep int) ([]string, error) {
var offeredModes []string
switch sessionMode {
case "passwd":
if !tokenExists {
return nil, errors.New("user has no cached token")
}
modes = []string{"password"}
offeredModes = []string{"password"}
if currentAuthStep > 0 {
modes = []string{"newpassword"}
offeredModes = []string{"newpassword"}
}

default: // auth mode
modes = []string{"qrcode"}
offeredModes = []string{"qrcode"}
if tokenExists {
modes = []string{"password", "qrcode"}
offeredModes = []string{"password", "qrcode"}
}
if currentAuthStep > 0 {
modes = []string{"newpassword"}
offeredModes = []string{"newpassword"}
}
}

return modes, nil
for _, mode := range offeredModes {
if _, ok := supportedAuthModes[mode]; !ok {
return nil, fmt.Errorf("auth mode %q required by the provider, but is not supported locally", mode)
}
}

return offeredModes, nil
}
2 changes: 1 addition & 1 deletion internal/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ import (
type ProviderInfoer interface {
AdditionalScopes() []string
AuthOptions() []oauth2.AuthCodeOption
AuthenticationModes(sessionMode, lastUsedMode string, tokenExists bool, currentAuthStep int) ([]string, error)
CurrentAuthenticationModesOffered(sessionMode string, supportedAuthModes map[string]string, tokenExists bool, currentAuthStep int) ([]string, error)
GetGroups(*oauth2.Token) ([]group.Info, error)
}
Loading

0 comments on commit 073a91e

Please sign in to comment.