Skip to content

Commit

Permalink
Merge pull request #201 from cyb3r4nt/fix-dev-provider-reg
Browse files Browse the repository at this point in the history
Fix registration of dev provider in Service.authMiddleware.Providers
  • Loading branch information
umputun authored Jun 18, 2024
2 parents 733697d + 79b88b6 commit 350854c
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 120 deletions.
54 changes: 26 additions & 28 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,39 +234,44 @@ func (s *Service) AddProviderWithUserAttributes(name, cid, csecret string, userA
L: s.logger,
UserAttributes: userAttributes,
}
s.addProvider(name, p)
s.addProviderByName(name, p)
}

func (s *Service) addProvider(name string, p provider.Params) {
func (s *Service) addProviderByName(name string, p provider.Params) {
var prov provider.Provider
switch strings.ToLower(name) {
case "github":
s.providers = append(s.providers, provider.NewService(provider.NewGithub(p)))
prov = provider.NewGithub(p)
case "google":
s.providers = append(s.providers, provider.NewService(provider.NewGoogle(p)))
prov = provider.NewGoogle(p)
case "facebook":
s.providers = append(s.providers, provider.NewService(provider.NewFacebook(p)))
prov = provider.NewFacebook(p)
case "yandex":
s.providers = append(s.providers, provider.NewService(provider.NewYandex(p)))
prov = provider.NewYandex(p)
case "battlenet":
s.providers = append(s.providers, provider.NewService(provider.NewBattlenet(p)))
prov = provider.NewBattlenet(p)
case "microsoft":
s.providers = append(s.providers, provider.NewService(provider.NewMicrosoft(p)))
prov = provider.NewMicrosoft(p)
case "twitter":
s.providers = append(s.providers, provider.NewService(provider.NewTwitter(p)))
prov = provider.NewTwitter(p)
case "patreon":
s.providers = append(s.providers, provider.NewService(provider.NewPatreon(p)))
prov = provider.NewPatreon(p)
case "dev":
s.providers = append(s.providers, provider.NewService(provider.NewDev(p)))
prov = provider.NewDev(p)
default:
return
}

s.addProvider(prov)
}

func (s *Service) addProvider(prov provider.Provider) {
s.providers = append(s.providers, provider.NewService(prov))
s.authMiddleware.Providers = s.providers
}

// AddProvider adds provider for given name
func (s *Service) AddProvider(name, cid, csecret string) {

p := provider.Params{
URL: s.opts.URL,
JwtService: s.jwtService,
Expand All @@ -277,8 +282,7 @@ func (s *Service) AddProvider(name, cid, csecret string) {
L: s.logger,
UserAttributes: map[string]string{},
}

s.addProvider(name, p)
s.addProviderByName(name, p)
}

// AddDevProvider with a custom host and port
Expand All @@ -292,7 +296,7 @@ func (s *Service) AddDevProvider(host string, port int) {
Port: port,
Host: host,
}
s.providers = append(s.providers, provider.NewService(provider.NewDev(p)))
s.addProvider(provider.NewDev(p))
}

// AddAppleProvider allow SignIn with Apple ID
Expand All @@ -311,7 +315,7 @@ func (s *Service) AddAppleProvider(appleConfig provider.AppleConfig, privKeyLoad
return fmt.Errorf("an AppleProvider creating failed: %w", err)
}

s.providers = append(s.providers, provider.NewService(appleProvider))
s.addProvider(appleProvider)
return nil
}

Expand All @@ -326,9 +330,7 @@ func (s *Service) AddCustomProvider(name string, client Client, copts provider.C
Csecret: client.Csecret,
L: s.logger,
}

s.providers = append(s.providers, provider.NewService(provider.NewCustom(name, p, copts)))
s.authMiddleware.Providers = s.providers
s.addProvider(provider.NewCustom(name, p, copts))
}

// AddDirectProvider adds provider with direct check against data store
Expand All @@ -342,8 +344,7 @@ func (s *Service) AddDirectProvider(name string, credChecker provider.CredChecke
CredChecker: credChecker,
AvatarSaver: s.avatarProxy,
}
s.providers = append(s.providers, provider.NewService(dh))
s.authMiddleware.Providers = s.providers
s.addProvider(dh)
}

// AddDirectProviderWithUserIDFunc adds provider with direct check against data store and sets custom UserIDFunc allows
Expand All @@ -359,8 +360,7 @@ func (s *Service) AddDirectProviderWithUserIDFunc(name string, credChecker provi
AvatarSaver: s.avatarProxy,
UserIDFunc: ufn,
}
s.providers = append(s.providers, provider.NewService(dh))
s.authMiddleware.Providers = s.providers
s.addProvider(dh)
}

// AddVerifProvider adds provider user's verification sent by sender
Expand All @@ -375,14 +375,12 @@ func (s *Service) AddVerifProvider(name, msgTmpl string, sender provider.Sender)
Template: msgTmpl,
UseGravatar: s.useGravatar,
}
s.providers = append(s.providers, provider.NewService(dh))
s.authMiddleware.Providers = s.providers
s.addProvider(dh)
}

// AddCustomHandler adds user-defined self-implemented handler of auth provider
func (s *Service) AddCustomHandler(handler provider.Provider) {
s.providers = append(s.providers, provider.NewService(handler))
s.authMiddleware.Providers = s.providers
func (s *Service) AddCustomHandler(p provider.Provider) {
s.addProvider(p)
}

// DevAuth makes dev oauth2 server, for testing and development only!
Expand Down
78 changes: 46 additions & 32 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,13 @@ func TestProvider(t *testing.T) {
_, err := svc.Provider("some provider")
assert.EqualError(t, err, "provider some provider not found")

svc.AddProvider("dev", "cid", "csecret")
svc.AddProviderWithUserAttributes("dev", "cid", "csecret", provider.UserAttributes{"attrName": "attrValue"})
svc.AddProvider("github", "cid", "csecret")
svc.AddProvider("google", "cid", "csecret")
svc.AddProvider("facebook", "cid", "csecret")
svc.AddProvider("yandex", "cid", "csecret")
svc.AddProvider("microsoft", "cid", "csecret")
svc.AddProvider("twitter", "cid", "csecret")
svc.AddProvider("battlenet", "cid", "csecret")
svc.AddProvider("patreon", "cid", "csecret")
svc.AddProvider("bad", "cid", "csecret")
Expand All @@ -72,14 +73,15 @@ func TestProvider(t *testing.T) {
assert.Equal(t, "cid", op.Cid)
assert.Equal(t, "csecret", op.Csecret)
assert.Equal(t, "go-pkgz/auth", op.Issuer)
assert.Equal(t, provider.UserAttributes{"attrName": "attrValue"}, op.Params.UserAttributes)

p, err = svc.Provider("github")
assert.NoError(t, err)
op = p.Provider.(provider.Oauth2Handler)
assert.Equal(t, "github", op.Name())

pp := svc.Providers()
assert.Equal(t, 9, len(pp))
assert.Equal(t, 10, len(pp))

ch, err := svc.Provider("telegramBotMySiteCom")
assert.NoError(t, err)
Expand Down Expand Up @@ -227,7 +229,11 @@ func TestIntegrationAvatar(t *testing.T) {
}

func TestIntegrationList(t *testing.T) {
_, teardown := prepService(t)
_, teardown := prepService(t, func(svc *Service) {
svc.AddProvider("github", "cid", "csec")
// add go-oauth2/oauth2 provider
svc.AddCustomProvider("custom123", Client{"cid", "csecret"}, provider.CustomHandlerOpt{})
})
defer teardown()

resp, err := http.Get("http://127.0.0.1:8089/auth/list")
Expand All @@ -237,7 +243,7 @@ func TestIntegrationList(t *testing.T) {

b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, `["dev","github","custom123","direct","direct_custom","email"]`+"\n", string(b))
assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b))
}

func TestIntegrationUserInfo(t *testing.T) {
Expand Down Expand Up @@ -337,7 +343,11 @@ func TestBadRequests(t *testing.T) {
}

func TestDirectProvider(t *testing.T) {
_, teardown := prepService(t)
_, teardown := prepService(t, func(svc *Service) {
svc.AddDirectProvider("direct", provider.CredCheckerFunc(func(user, password string) (ok bool, err error) {
return user == "dev_direct" && password == "password", nil
}))
})
defer teardown()

// login
Expand Down Expand Up @@ -375,19 +385,28 @@ func TestDirectProvider(t *testing.T) {
}

func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
_, teardown := prepService(t)
_, teardown := prepService(t, func(svc *Service) {
svc.AddDirectProviderWithUserIDFunc("directCustom",
provider.CredCheckerFunc(func(user, password string) (ok bool, err error) {
return user == "dev_direct" && password == "password", nil
}),
func(user string, r *http.Request) string {
return "blah"
},
)
})
defer teardown()

// login
jar, err := cookiejar.New(nil)
require.Nil(t, err)
client := &http.Client{Jar: jar, Timeout: 5 * time.Second}
resp, err := client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=bad")
resp, err := client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=bad")
require.Nil(t, err)
defer resp.Body.Close()
assert.Equal(t, 403, resp.StatusCode)

resp, err = client.Get("http://127.0.0.1:8089/auth/direct_custom/login?user=dev_direct&passwd=password")
resp, err = client.Get("http://127.0.0.1:8089/auth/directCustom/login?user=dev_direct&passwd=password")
require.Nil(t, err)
defer resp.Body.Close()
assert.Equal(t, 200, resp.StatusCode)
Expand All @@ -397,7 +416,7 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
t.Logf("resp %s", string(body))
t.Logf("headers: %+v", resp.Header)

assert.Contains(t, string(body), `"name":"dev_direct","id":"direct_custom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`)
assert.Contains(t, string(body), `"name":"dev_direct","id":"directCustom_5bf1fd927dfb8679496a2e6cf00cbe50c1c87145"`)

require.Equal(t, 2, len(resp.Cookies()))
assert.Equal(t, "JWT", resp.Cookies()[0].Name)
Expand All @@ -413,7 +432,9 @@ func TestDirectProvider_WithCustomUserIDFunc(t *testing.T) {
}

func TestVerifProvider(t *testing.T) {
_, teardown := prepService(t)
_, teardown := prepService(t, func(svc *Service) {
svc.AddVerifProvider("email", "{{.Token}}", &sender)
})
defer teardown()

// login
Expand Down Expand Up @@ -489,7 +510,16 @@ func TestStatus(t *testing.T) {

}

func prepService(t *testing.T) (svc *Service, teardown func()) { //nolint unparam
func TestDevAuthServerWithoutDevProvider(t *testing.T) {
svc := NewService(Opts{})
assert.NotNil(t, svc)

_, err := svc.DevAuth()
require.NotNil(t, err)
assert.EqualError(t, err, "dev provider not registered: provider dev not found")
}

func prepService(t *testing.T, providerConfigFunctions ...func(svc *Service)) (svc *Service, teardown func()) { //nolint unparam

options := Opts{
SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }),
Expand All @@ -510,28 +540,12 @@ func prepService(t *testing.T) (svc *Service, teardown func()) { //nolint unpara
}

svc = NewService(options)
svc.AddDevProvider("localhost", 18084) // add dev provider on 18084
svc.AddProvider("github", "cid", "csec") // add github provider

// add go-oauth2/oauth2 provider
svc.AddCustomProvider("custom123", Client{"cid", "csecret"}, provider.CustomHandlerOpt{})

// add direct provider
svc.AddDirectProvider("direct", provider.CredCheckerFunc(func(user, password string) (ok bool, err error) {
return user == "dev_direct" && password == "password", nil
}))

// add direct provider with custom user id func
svc.AddDirectProviderWithUserIDFunc("direct_custom",
provider.CredCheckerFunc(func(user, password string) (ok bool, err error) {
return user == "dev_direct" && password == "password", nil
}),
func(user string, r *http.Request) string {
return "blah"
},
)
svc.AddDevProvider("localhost", 18084) // add dev provider on 18084

svc.AddVerifProvider("email", "{{.Token}}", &sender)
for _, f := range providerConfigFunctions {
f(svc)
}

// run dev/test oauth2 server on :18084
devAuth, err := svc.DevAuth()
Expand All @@ -547,7 +561,7 @@ func prepService(t *testing.T) (svc *Service, teardown func()) { //nolint unpara
_, _ = w.Write([]byte("open route, no token needed\n"))
}))
mux.Handle("/private", m.Auth(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("open route, no token needed\n"))
_, _ = w.Write([]byte("protected route, authenticated with token\n"))
})))

// setup auth routes
Expand Down
Loading

0 comments on commit 350854c

Please sign in to comment.