From 389a0851d504a5cf1b4babb996e674214ce13111 Mon Sep 17 00:00:00 2001 From: David Evans Date: Wed, 25 Oct 2023 11:20:11 +0100 Subject: [PATCH] URL encode user ID in picture URL --- provider/custom_server.go | 2 +- provider/custom_server_test.go | 33 ++++++++++++++++++++++++++------- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/provider/custom_server.go b/provider/custom_server.go index 64b90bc..0ca145e 100644 --- a/provider/custom_server.go +++ b/provider/custom_server.go @@ -163,7 +163,7 @@ func (c *CustomServer) handleUserInfo(w http.ResponseWriter, r *http.Request) { user := token.User{ ID: userID, Name: userID, - Picture: fmt.Sprintf(c.URL+"/avatar?user=%s", userID), + Picture: fmt.Sprintf(c.URL+"/avatar?user=%s", url.QueryEscape(userID)), } res, err := json.Marshal(user) if err != nil { diff --git a/provider/custom_server_test.go b/provider/custom_server_test.go index 527e190..3083920 100644 --- a/provider/custom_server_test.go +++ b/provider/custom_server_test.go @@ -43,6 +43,9 @@ func TestCustomProvider(t *testing.T) { L: logger.Std, } + var loginUsername string + var capturedUser token.User + sopts := CustomServerOpt{ URL: "http://127.0.0.1:9096", L: logger.Std, @@ -61,7 +64,7 @@ func TestCustomProvider(t *testing.T) { jar.SetCookies(u, r.Cookies()) form := url.Values{} - form.Add("username", "admin") + form.Add("username", loginUsername) form.Add("password", "pwd1234") req, err := http.NewRequest("POST", "", strings.NewReader(form.Encode())) @@ -87,9 +90,7 @@ func TestCustomProvider(t *testing.T) { claims, err := params.JwtService.Parse(resp.Cookies()[0].Value) assert.NoError(t, err) - assert.Equal(t, token.User{Name: "admin", ID: "admin", - Picture: "http://127.0.0.1:9096/avatar?user=admin", IP: ""}, *claims.User) - + capturedUser = *claims.User }, } @@ -120,6 +121,7 @@ func TestCustomProvider(t *testing.T) { client := &http.Client{Jar: jar, Timeout: time.Second * 10} // check non-admin, permanent + loginUsername = "admin" resp, err := client.Get("http://127.0.0.1:8080/auth/customprov/login?site=my-test-site") require.Nil(t, err) assert.Equal(t, 200, resp.StatusCode) @@ -127,6 +129,8 @@ func TestCustomProvider(t *testing.T) { assert.NoError(t, err) t.Logf("resp %s", string(body)) t.Logf("headers: %+v", resp.Header) + assert.Equal(t, token.User{Name: "admin", ID: "admin", + Picture: "http://127.0.0.1:9096/avatar?user=admin", IP: ""}, capturedUser) // check avatar resp, err = client.Get("http://127.0.0.1:9096/avatar?user=dev_user") @@ -137,6 +141,18 @@ func TestCustomProvider(t *testing.T) { assert.Equal(t, 960, len(body)) t.Logf("headers: %+v", resp.Header) + // check malicious user ID + loginUsername = "attack" + resp, err = client.Get("http://127.0.0.1:8080/auth/customprov/login?site=my-test-site") + require.Nil(t, err) + assert.Equal(t, 200, resp.StatusCode) + body, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + t.Logf("resp %s", string(body)) + t.Logf("headers: %+v", resp.Header) + // user ID in picture URL is encoded + assert.Equal(t, "http://127.0.0.1:9096/avatar?user=none%26attack%3Dvalue%22%3E%3Cscript%3Enasty+stuff%3C%2Fscript%3E", capturedUser.Picture) + // check default login page prov.LoginPageHandler = nil resp, err = client.Get("http://127.0.0.1:8080/auth/customprov/login?site=my-test-site") @@ -196,10 +212,13 @@ func initGoauth2Srv(t *testing.T) *goauth2.Server { if r.ParseForm() != nil { return "", fmt.Errorf("no username and password in request") } - if r.Form.Get("username") != "admin" || r.Form.Get("password") != "pwd1234" { - return "", fmt.Errorf("wrong creds") + if r.Form.Get("username") == "admin" && r.Form.Get("password") == "pwd1234" { + return "admin", nil + } + if r.Form.Get("username") == "attack" && r.Form.Get("password") == "pwd1234" { + return "none&attack=value\">", nil } - return "admin", nil + return "", fmt.Errorf("wrong creds") }) srv.SetInternalErrorHandler(func(err error) (re *errors.Response) {