diff --git a/cmd/web/handlers.go b/cmd/web/handlers.go index a94cdb0..7e4f50f 100644 --- a/cmd/web/handlers.go +++ b/cmd/web/handlers.go @@ -33,6 +33,13 @@ type userLoginForm struct { validator.Validator `form:"-"` } +type updatePasswordForm struct { + CurrentPassword string `form:"current"` + NewPassword string `form:"new"` + NewPasswordConfirmation string `form:"confirmation"` + validator.Validator `form:"-"` +} + func (app *application) home(w http.ResponseWriter, r *http.Request) { // Because httprouter matches the "/" path exactly, we can now remove the // manual check of r.URL.Path != "/" from this handler. @@ -290,6 +297,57 @@ func (app *application) accountView(w http.ResponseWriter, r *http.Request) { app.render(w, http.StatusOK, "account.tmpl", data) } +func (app *application) updatePassword(w http.ResponseWriter, r *http.Request) { + data := app.newTemplateData(r) + data.Form = updatePasswordForm{} + + app.render(w, http.StatusOK, "password.tmpl", data) +} + +func (app *application) updatePasswordPost(w http.ResponseWriter, r *http.Request) { + var form updatePasswordForm + + err := app.decodePostForm(r, &form) + if err != nil { + app.clientError(w, http.StatusBadRequest) + return + } + + form.CheckField(validator.NotBlank(form.CurrentPassword), "current", "This field cannot be blank") + form.CheckField(validator.NotBlank(form.NewPassword), "new", "This field cannot be blank") + form.CheckField(validator.MinChars(form.NewPassword, 8), "new", "This field must be at least 8 characters long") + form.CheckField(validator.NotBlank(form.NewPasswordConfirmation), "confirmation", "This field cannot be blank") + form.CheckField(validator.MinChars(form.NewPasswordConfirmation, 8), "confirmation", "This field must be at least 8 characters long") + form.CheckField(form.NewPassword == form.NewPasswordConfirmation, "newPasswordConfirmation", "Passwords do not match") + form.CheckField(form.NewPassword != form.CurrentPassword, "newPasswordEquality", "New password must be different from the current password") + + if !form.Valid() { + data := app.newTemplateData(r) + data.Form = form + app.render(w, http.StatusUnprocessableEntity, "password.tmpl", data) + return + } + + id := app.sessionManager.GetInt(r.Context(), "authenticatedUserID") + err = app.users.UpdatePassword(id, form.CurrentPassword, form.NewPassword) + if err != nil { + if errors.Is(err, models.ErrInvalidCredentials) { + form.AddFieldError("currentPassword", "Current password is incorrect") + + data := app.newTemplateData(r) + data.Form = form + + app.render(w, http.StatusUnprocessableEntity, "password.tmpl", data) + } else if err != nil { + app.serverError(w, err) + } + return + } + + app.sessionManager.Put(r.Context(), "flash", "Your password has been updated!") + http.Redirect(w, r, "/account/view", http.StatusSeeOther) +} + func health(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) } diff --git a/cmd/web/routes.go b/cmd/web/routes.go index bfca024..4a4164d 100644 --- a/cmd/web/routes.go +++ b/cmd/web/routes.go @@ -34,6 +34,8 @@ func (app *application) routes() http.Handler { // middleware chain which includes the requireAuthentication middleware. protected := dynamic.Append(app.requireAuthentication) + router.Handler(http.MethodGet, "/account/password/update", protected.ThenFunc(app.updatePassword)) + router.Handler(http.MethodPost, "/account/password/update", protected.ThenFunc(app.updatePasswordPost)) router.Handler(http.MethodGet, "/account/view", protected.ThenFunc(app.accountView)) router.Handler(http.MethodGet, "/snippet/create", protected.ThenFunc(app.snippetCreate)) router.Handler(http.MethodPost, "/snippet/create", protected.ThenFunc(app.snippetCreatePost)) diff --git a/internal/models/mocks/users.go b/internal/models/mocks/users.go index 09e6033..dc4de0e 100644 --- a/internal/models/mocks/users.go +++ b/internal/models/mocks/users.go @@ -42,6 +42,17 @@ func (m *UserModel) Authenticate(email, password string) (int, error) { return 0, models.ErrInvalidCredentials } +func (m *UserModel) UpdatePassword(id int, currentPassword, newPassword string) error { + switch id { + case 1: + return models.ErrNoRecord + case 2: + return models.ErrInvalidCredentials + default: + return nil + } +} + func (m *UserModel) Exists(id int) (bool, error) { switch id { case 1: diff --git a/internal/models/users.go b/internal/models/users.go index fa2a7cb..059647b 100644 --- a/internal/models/users.go +++ b/internal/models/users.go @@ -16,6 +16,7 @@ type UserModelInterface interface { Authenticate(email, password string) (int, error) Exists(id int) (bool, error) Get(id int) (*User, error) + UpdatePassword(id int, currentPassword, newPassword string) error } // Define a new User type. Notice how the field names and types align @@ -67,11 +68,11 @@ func (m *UserModel) Insert(name, email, password string) error { } func (m *UserModel) Get(id int) (*User, error) { - stmt := `SELECT name, email, created FROM users WHERE id = ?` + stmt := `SELECT name, email, created, hashed_password FROM users WHERE id = ?` var row *sql.Row = m.DB.QueryRow(stmt, id) u := &User{} - err := row.Scan(&u.Name, &u.Email, &u.Created) + err := row.Scan(&u.Name, &u.Email, &u.Created, &u.HashedPassword) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNoRecord @@ -114,6 +115,32 @@ func (m *UserModel) Authenticate(email, password string) (int, error) { return id, nil } +func (m *UserModel) UpdatePassword(id int, currentPassword, newPassword string) error { + user, err := m.Get(id) + if err != nil { + return err + } + + err = bcrypt.CompareHashAndPassword(user.HashedPassword, []byte(currentPassword)) + if err != nil { + if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { + return ErrInvalidCredentials + } else { + return err + } + } + + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), 12) + if err != nil { + return err + } + + stmt := `UPDATE users SET hashed_password = ? WHERE id = ?` + _, err = m.DB.Exec(stmt, hashedPassword, id) + + return err +} + // We'll use the Exists method to check if a user exists with a specific ID. func (m *UserModel) Exists(id int) (bool, error) { log.Println("using implementation") diff --git a/ui/html/pages/account.tmpl b/ui/html/pages/account.tmpl index ccf2cd9..9c99471 100644 --- a/ui/html/pages/account.tmpl +++ b/ui/html/pages/account.tmpl @@ -16,6 +16,10 @@