diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 7f9aeb3..4b79179 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -2,6 +2,7 @@ package handlers import ( "encoding/base64" + "encoding/gob" "fmt" "html/template" "log" @@ -25,8 +26,14 @@ type Handler struct { store *sessions.CookieStore } +type FlashMessage struct { + Type string + Message string +} + func NewHandler(db *database.DB) *Handler { - // TODO: use environment variable + gob.Register(FlashMessage{}) + secretKey := os.Getenv("SESSION_SECRET_KEY") if secretKey == "" { log.Fatalf("SESSION_SECRET_KEY environment variable is not set") @@ -36,6 +43,43 @@ func NewHandler(db *database.DB) *Handler { return &Handler{db: db, store: store} } +func (h *Handler) addFlashMessage(w http.ResponseWriter, r *http.Request, messageType, message string) { + session, _ := h.store.Get(r, "session") + session.AddFlash(FlashMessage{Type: messageType, Message: message}, "flashMessages") + if err := session.Save(r, w); err != nil { + log.Println("Error saving session:", err) + } +} + +func (h *Handler) getFlashMessages(w http.ResponseWriter, r *http.Request) []FlashMessage { + session, _ := h.store.Get(r, "session") + flashes := session.Flashes("flashMessages") + + var messages []FlashMessage + for _, f := range flashes { + if msg, ok := f.(FlashMessage); ok { + messages = append(messages, msg) + } + } + + if len(flashes) > 0 { + if err := session.Save(r, w); err != nil { + log.Println("Error saving session:", err) + } + } + + return messages +} + +func (h *Handler) prepareTemplateData(w http.ResponseWriter, r *http.Request, data map[string]interface{}) map[string]interface{} { + if data == nil { + data = make(map[string]interface{}) + } + flashMessages := h.getFlashMessages(w, r) + data["FlashMessages"] = flashMessages + return data +} + func (h *Handler) Routes() http.Handler { mux := http.NewServeMux() mux.HandleFunc("/", h.indexHandler) @@ -58,12 +102,13 @@ func (h *Handler) Routes() http.Handler { func (h *Handler) notFoundHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) + data := h.prepareTemplateData(w, r, nil) tmpl, err := template.ParseFiles("internal/templates/base.html", "internal/templates/404.html") if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - err = tmpl.ExecuteTemplate(w, "base.html", nil) + err = tmpl.ExecuteTemplate(w, "base.html", data) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -72,12 +117,13 @@ func (h *Handler) notFoundHandler(w http.ResponseWriter, r *http.Request) { func (h *Handler) forbiddenHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusForbidden) + data := h.prepareTemplateData(w, r, nil) tmpl, err := template.ParseFiles("internal/templates/base.html", "internal/templates/403.html") if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - err = tmpl.ExecuteTemplate(w, "base.html", nil) + err = tmpl.ExecuteTemplate(w, "base.html", data) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -93,20 +139,9 @@ func (h *Handler) indexHandler(w http.ResponseWriter, r *http.Request) { session, _ := h.store.Get(r, "session") user, _ := session.Values["user"].(*database.User) - flashes := session.Flashes("error") - var errorMsg string - if len(flashes) > 0 { - errorMsg, _ = flashes[0].(string) - } - - data := struct { - User *database.User - Error string - }{ - User: user, - Error: errorMsg, - } - session.Save(r, w) + data := h.prepareTemplateData(w, r, map[string]interface{}{ + "User": user, + }) tmpl, err := template.ParseFiles("internal/templates/base.html", "internal/templates/index.html") if err != nil { @@ -130,19 +165,7 @@ func (h *Handler) newURLHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - flashes := session.Flashes("error") - var errorMsg string - if len(flashes) > 0 { - errorMsg, _ = flashes[0].(string) - } - session.Save(r, w) - - data := struct { - Error string - }{ - Error: errorMsg, - } - + data := h.prepareTemplateData(w, r, nil) tmpl, err := template.ParseFiles("internal/templates/base.html", "internal/templates/new.html") if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -156,8 +179,7 @@ func (h *Handler) newURLHandler(w http.ResponseWriter, r *http.Request) { case http.MethodPost: err := r.ParseForm() if err != nil { - session.AddFlash("Error parsing form", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "Error parsing form") http.Redirect(w, r, "/new", http.StatusSeeOther) return } @@ -166,31 +188,27 @@ func (h *Handler) newURLHandler(w http.ResponseWriter, r *http.Request) { password := r.Form.Get("password") if url == "" { - session.AddFlash("URL is required", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "URL is required") http.Redirect(w, r, "/new", http.StatusSeeOther) return } url, isValid := utils.IsValidURL(url) if !isValid { - session.AddFlash("Invalid URL", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "Invalid URL") http.Redirect(w, r, "/new", http.StatusSeeOther) return } isSafe, err := safebrowsing.IsSafeURL(url) if err != nil { - session.AddFlash("The provided URL is not safe", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "Error checking URL safety") http.Redirect(w, r, "/new", http.StatusSeeOther) return } if !isSafe { - session.AddFlash("The provided URL is not safe", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "The provided URL is not safe") http.Redirect(w, r, "/new", http.StatusSeeOther) return } @@ -201,9 +219,8 @@ func (h *Handler) newURLHandler(w http.ResponseWriter, r *http.Request) { if password != "" { hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - session.AddFlash("Error hashing password", "error") - session.Save(r, w) - http.Redirect(w, r, r.URL.Path, http.StatusSeeOther) + h.addFlashMessage(w, r, "error", "Error hashing password") + http.Redirect(w, r, "/new", http.StatusSeeOther) return } hashedPassword = string(hash) @@ -212,8 +229,7 @@ func (h *Handler) newURLHandler(w http.ResponseWriter, r *http.Request) { shortURL := fmt.Sprintf("http://%s/r/%s", r.Host, key) qrCode, err := qrcode.Encode(shortURL, qrcode.Medium, 256) if err != nil { - session.AddFlash("Error generating QR code", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "Error generating QR code") http.Redirect(w, r, "/new", http.StatusSeeOther) return } @@ -221,14 +237,12 @@ func (h *Handler) newURLHandler(w http.ResponseWriter, r *http.Request) { qrCodeBase64 := base64.StdEncoding.EncodeToString(qrCode) if err := h.db.InsertURL(url, key, user.ID, hashedPassword, qrCodeBase64); err != nil { - session.AddFlash("Error inserting URL into database", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "Error inserting URL into database") http.Redirect(w, r, "/new", http.StatusSeeOther) return } - session.AddFlash("URL successfully added", "success") - session.Save(r, w) + h.addFlashMessage(w, r, "success", "URL successfully added") http.Redirect(w, r, "/dashboard", http.StatusSeeOther) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -251,7 +265,9 @@ func (h *Handler) redirectHandler(w http.ResponseWriter, r *http.Request) { if url.Password != "" { switch r.Method { case http.MethodGet: - data := struct{ Key string }{Key: key} + data := h.prepareTemplateData(w, r, map[string]interface{}{ + "Key": key, + }) tmpl, err := template.ParseFiles("internal/templates/base.html", "internal/templates/password.html") if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -266,7 +282,8 @@ func (h *Handler) redirectHandler(w http.ResponseWriter, r *http.Request) { case http.MethodPost: password := r.FormValue("password") if err := bcrypt.CompareHashAndPassword([]byte(url.Password), []byte(password)); err != nil { - http.Error(w, "Invalid password", http.StatusUnauthorized) + h.addFlashMessage(w, r, "error", "Invalid password") + http.Redirect(w, r, r.URL.Path, http.StatusSeeOther) return } default: @@ -277,7 +294,8 @@ func (h *Handler) redirectHandler(w http.ResponseWriter, r *http.Request) { isSafe, err := safebrowsing.IsSafeURL(url.URL) if err != nil { - http.Error(w, "Error checking URL safety", http.StatusInternalServerError) + h.addFlashMessage(w, r, "error", "Error checking URL safety") + http.Redirect(w, r, "/", http.StatusSeeOther) return } @@ -287,8 +305,7 @@ func (h *Handler) redirectHandler(w http.ResponseWriter, r *http.Request) { } if err := h.db.IncrementClicks(url.ID); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + log.Printf("Error incrementing clicks: %v", err) } http.Redirect(w, r, url.URL, http.StatusFound) @@ -297,12 +314,13 @@ func (h *Handler) redirectHandler(w http.ResponseWriter, r *http.Request) { func (h *Handler) registerHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: + data := h.prepareTemplateData(w, r, nil) tmpl, err := template.ParseFiles("internal/templates/base.html", "internal/templates/register.html") if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - err = tmpl.ExecuteTemplate(w, "base.html", nil) + err = tmpl.ExecuteTemplate(w, "base.html", data) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -313,19 +331,22 @@ func (h *Handler) registerHandler(w http.ResponseWriter, r *http.Request) { password := r.FormValue("password") if username == "" || email == "" || password == "" { - http.Error(w, "All fields are required", http.StatusBadRequest) + h.addFlashMessage(w, r, "error", "All fields are required") + http.Redirect(w, r, "/register", http.StatusSeeOther) return } hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - http.Error(w, "Error hashing password", http.StatusInternalServerError) + h.addFlashMessage(w, r, "error", "Error hashing password") + http.Redirect(w, r, "/register", http.StatusSeeOther) return } user, err := h.db.CreateUser(username, email, string(hashedPassword)) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + h.addFlashMessage(w, r, "error", "Error creating user") + http.Redirect(w, r, "/register", http.StatusSeeOther) return } @@ -334,10 +355,12 @@ func (h *Handler) registerHandler(w http.ResponseWriter, r *http.Request) { err = session.Save(r, w) if err != nil { log.Printf("Error saving session: %v", err) - http.Error(w, "Error saving session", http.StatusInternalServerError) + h.addFlashMessage(w, r, "error", "Error saving session") + http.Redirect(w, r, "/register", http.StatusSeeOther) return } + h.addFlashMessage(w, r, "success", "Registration successful") http.Redirect(w, r, "/dashboard", http.StatusSeeOther) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -345,23 +368,9 @@ func (h *Handler) registerHandler(w http.ResponseWriter, r *http.Request) { } func (h *Handler) loginHandler(w http.ResponseWriter, r *http.Request) { - session, _ := h.store.Get(r, "session") - switch r.Method { case http.MethodGet: - flashes := session.Flashes("error") - var errorMsg string - if len(flashes) > 0 { - errorMsg, _ = flashes[0].(string) - } - session.Save(r, w) - - data := struct { - Error string - }{ - Error: errorMsg, - } - + data := h.prepareTemplateData(w, r, nil) tmpl, err := template.ParseFiles("internal/templates/base.html", "internal/templates/login.html") if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -378,28 +387,29 @@ func (h *Handler) loginHandler(w http.ResponseWriter, r *http.Request) { user, err := h.db.GetUserByUsername(username) if err != nil || user == nil { - session.AddFlash("Invalid username or password", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "Invalid username or password") http.Redirect(w, r, "/login", http.StatusSeeOther) return } err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) if err != nil { - session.AddFlash("Invalid username or password", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "Invalid username or password") http.Redirect(w, r, "/login", http.StatusSeeOther) return } + session, _ := h.store.Get(r, "session") session.Values["user"] = user err = session.Save(r, w) if err != nil { log.Printf("Error saving session: %v", err) - http.Error(w, "Error saving session", http.StatusInternalServerError) + h.addFlashMessage(w, r, "error", "Error saving session") + http.Redirect(w, r, "/login", http.StatusSeeOther) return } + h.addFlashMessage(w, r, "success", "Login successful") http.Redirect(w, r, "/dashboard", http.StatusSeeOther) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -410,6 +420,7 @@ func (h *Handler) logoutHandler(w http.ResponseWriter, r *http.Request) { session, _ := h.store.Get(r, "session") session.Values["user"] = nil session.Save(r, w) + h.addFlashMessage(w, r, "info", "You have been logged out") http.Redirect(w, r, "/", http.StatusSeeOther) } @@ -434,41 +445,17 @@ func (h *Handler) dashboardHandler(w http.ResponseWriter, r *http.Request) { return } - errorFlashes := session.Flashes("error") - var errorMsg string - if len(errorFlashes) > 0 { - errorMsg, _ = errorFlashes[0].(string) - } - - var successMsg string - if errorMsg == "" { - successFlashes := session.Flashes("success") - if len(successFlashes) > 0 { - successMsg, _ = successFlashes[0].(string) - } - } - - session.Save(r, w) - urls, err := h.db.GetURLsByUserID(user.ID) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - data := struct { - User *database.User - URLs []database.URL - Host string - Success string - Error string - }{ - User: user, - URLs: urls, - Host: r.Host, - Success: successMsg, - Error: errorMsg, - } + data := h.prepareTemplateData(w, r, map[string]interface{}{ + "User": user, + "URLs": urls, + "Host": r.Host, + }) tmpl, err := template.ParseFiles("internal/templates/base.html", "internal/templates/dashboard.html") if err != nil { @@ -492,16 +479,14 @@ func (h *Handler) editURLHandler(w http.ResponseWriter, r *http.Request) { urlID, err := strconv.ParseInt(strings.TrimPrefix(r.URL.Path, "/edit/"), 10, 64) if err != nil { - session.AddFlash("Invalid URL ID", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "Invalid URL ID") http.Redirect(w, r, "/dashboard", http.StatusSeeOther) return } url, err := h.db.GetURLByID(urlID) if err != nil { - session.AddFlash("URL not found", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "URL not found") http.Redirect(w, r, "/dashboard", http.StatusSeeOther) return } @@ -513,22 +498,10 @@ func (h *Handler) editURLHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - flashes := session.Flashes("error") - var errorMsg string - if len(flashes) > 0 { - errorMsg, _ = flashes[0].(string) - } - session.Save(r, w) - - data := struct { - URL *database.URL - Host string - Error string - }{ - URL: url, - Host: r.Host, - Error: errorMsg, - } + data := h.prepareTemplateData(w, r, map[string]interface{}{ + "URL": url, + "Host": r.Host, + }) tmpl, err := template.ParseFiles("internal/templates/base.html", "internal/templates/edit.html") if err != nil { @@ -545,31 +518,27 @@ func (h *Handler) editURLHandler(w http.ResponseWriter, r *http.Request) { newPassword := r.FormValue("password") if newURL == "" { - session.AddFlash("URL is required", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "URL is required") http.Redirect(w, r, r.URL.Path, http.StatusSeeOther) return } newURL, isValid := utils.IsValidURL(newURL) if !isValid { - session.AddFlash("Invalid URL provided", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "Invalid URL provided") http.Redirect(w, r, r.URL.Path, http.StatusSeeOther) return } isSafe, err := safebrowsing.IsSafeURL(newURL) if err != nil { - session.AddFlash("Error checking URL safety", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "Error checking URL safety") http.Redirect(w, r, r.URL.Path, http.StatusSeeOther) return } if !isSafe { - session.AddFlash("The provided URL is not safe", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "The provided URL is not safe") http.Redirect(w, r, r.URL.Path, http.StatusSeeOther) return } @@ -578,8 +547,7 @@ func (h *Handler) editURLHandler(w http.ResponseWriter, r *http.Request) { if newPassword != "" { hash, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) if err != nil { - session.AddFlash("Error hashing password", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "Error hashing password") http.Redirect(w, r, r.URL.Path, http.StatusSeeOther) return } @@ -588,14 +556,12 @@ func (h *Handler) editURLHandler(w http.ResponseWriter, r *http.Request) { err = h.db.UpdateURL(urlID, newURL, hashedPassword) if err != nil { - session.AddFlash("Error updating the URL", "error") - session.Save(r, w) + h.addFlashMessage(w, r, "error", "Error updating the URL") http.Redirect(w, r, r.URL.Path, http.StatusSeeOther) return } - session.AddFlash("URL updated successfully", "success") - session.Save(r, w) + h.addFlashMessage(w, r, "success", "URL updated successfully") http.Redirect(w, r, "/dashboard", http.StatusSeeOther) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -612,7 +578,8 @@ func (h *Handler) deleteURLHandler(w http.ResponseWriter, r *http.Request) { urlID, err := strconv.ParseInt(strings.TrimPrefix(r.URL.Path, "/delete/"), 10, 64) if err != nil { - http.Error(w, "Invalid URL ID", http.StatusBadRequest) + h.addFlashMessage(w, r, "error", "Invalid URL ID") + http.Redirect(w, r, "/dashboard", http.StatusSeeOther) return } @@ -629,10 +596,12 @@ func (h *Handler) deleteURLHandler(w http.ResponseWriter, r *http.Request) { err = h.db.DeleteURL(urlID) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + h.addFlashMessage(w, r, "error", "Error deleting URL") + http.Redirect(w, r, "/dashboard", http.StatusSeeOther) return } + h.addFlashMessage(w, r, "success", "URL deleted successfully") http.Redirect(w, r, "/dashboard", http.StatusSeeOther) } @@ -646,7 +615,8 @@ func (h *Handler) urlDetailsHandler(w http.ResponseWriter, r *http.Request) { urlID, err := strconv.ParseInt(strings.TrimPrefix(r.URL.Path, "/details/"), 10, 64) if err != nil { - http.Error(w, "Invalid URL ID", http.StatusBadRequest) + h.addFlashMessage(w, r, "error", "Invalid URL ID") + http.Redirect(w, r, "/dashboard", http.StatusSeeOther) return } @@ -664,21 +634,17 @@ func (h *Handler) urlDetailsHandler(w http.ResponseWriter, r *http.Request) { shortURL := fmt.Sprintf("http://%s/r/%s", r.Host, url.Key) qrCode, err := qrcode.Encode(shortURL, qrcode.Medium, 256) if err != nil { - http.Error(w, "Error generating QR code", http.StatusInternalServerError) + h.addFlashMessage(w, r, "error", "Error generating QR code") + http.Redirect(w, r, "/dashboard", http.StatusSeeOther) return } - data := struct { - URL *database.URL - QRCode string - Host string - ShortURL string - }{ - URL: url, - QRCode: base64.StdEncoding.EncodeToString(qrCode), - Host: r.Host, - ShortURL: shortURL, - } + data := h.prepareTemplateData(w, r, map[string]interface{}{ + "URL": url, + "QRCode": base64.StdEncoding.EncodeToString(qrCode), + "Host": r.Host, + "ShortURL": shortURL, + }) tmpl, err := template.ParseFiles("internal/templates/base.html", "internal/templates/details.html") if err != nil { diff --git a/internal/templates/base.html b/internal/templates/base.html index 854f3fe..a935211 100644 --- a/internal/templates/base.html +++ b/internal/templates/base.html @@ -15,6 +15,47 @@ {{ block "head" . }}{{ end }}
+