Skip to content

Commit

Permalink
Implement CSRF protection, update styles
Browse files Browse the repository at this point in the history
- Implement CSRF protection with middleware and token validation
- Add CSRF tokens to forms and improve error logging
- Update CSS with a different, blue-y look
- Add nosemgrep to some more open redirects

Change-Id: Ie69a0b7764113a0b05566a972a4ac047aa63d647
Signed-off-by: Ian Meyer <[email protected]>
  • Loading branch information
imeyer committed Sep 24, 2024
1 parent d177ce7 commit 9f91ad3
Show file tree
Hide file tree
Showing 7 changed files with 416 additions and 361 deletions.
34 changes: 33 additions & 1 deletion csrf.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package main

import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"log/slog"
"net/http"
"sync"
"time"
Expand All @@ -19,8 +21,13 @@ var (
tokenStore = make(map[string]time.Time)
tokenStoreMu sync.Mutex
timeNow = time.Now
csrfLogger *slog.Logger
)

func initCSRFLogger(logger *slog.Logger) {
csrfLogger = logger
}

func generateCSRFToken() (string, error) {
b := make([]byte, csrfTokenLength)
_, err := rand.Read(b)
Expand All @@ -46,6 +53,10 @@ func setCSRFToken(r *http.Request, w http.ResponseWriter) (string, error) {
SameSite: http.SameSiteStrictMode,
})

if csrfLogger != nil {
csrfLogger.DebugContext(r.Context(), "set csrf cookie")
}

tokenStoreMu.Lock()
tokenStore[token] = timeNow().Add(24 * time.Hour) // Token expires in 24 hours
tokenStoreMu.Unlock()
Expand All @@ -56,15 +67,27 @@ func setCSRFToken(r *http.Request, w http.ResponseWriter) (string, error) {
func validateCSRFToken(r *http.Request) error {
cookie, err := r.Cookie(csrfCookieName)
if err != nil {
if csrfLogger != nil {
csrfLogger.DebugContext(r.Context(), "csrf error", slog.String("message", err.Error()))
}
return errors.New("CSRF cookie not found")
}

token := r.Header.Get(csrfHeaderName)
if token == "" {
return errors.New("CSRF token not found in header")
token = r.FormValue("csrf_token")
if token == "" {
if csrfLogger != nil {
csrfLogger.DebugContext(r.Context(), "csrf token not found in header or form")
}
return errors.New("CSRF token not found in header or form")
}
}

if cookie.Value != token {
if csrfLogger != nil {
csrfLogger.DebugContext(r.Context(), "csrf token mismatch")
}
return errors.New("CSRF token mismatch")
}

Expand All @@ -73,11 +96,17 @@ func validateCSRFToken(r *http.Request) error {

expiry, exists := tokenStore[token]
if !exists {
if csrfLogger != nil {
csrfLogger.DebugContext(r.Context(), "csrf token not found in store")
}
return errors.New("CSRF token not found in store")
}

if timeNow().After(expiry) {
delete(tokenStore, token)
if csrfLogger != nil {
csrfLogger.DebugContext(r.Context(), "csrf token expired")
}
return errors.New("CSRF token expired")
}

Expand All @@ -93,6 +122,9 @@ func CSRFMiddleware(next http.HandlerFunc) http.HandlerFunc {
return
}
w.Header().Set(csrfHeaderName, token)

ctx := context.WithValue(r.Context(), "CSRFToken", token)
next.ServeHTTP(w, r.WithContext(ctx))
} else {
if err := validateCSRFToken(r); err != nil {
http.Error(w, "CSRF validation failed", http.StatusForbidden)
Expand Down
5 changes: 5 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ func main() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

// Enabling logging within csrf.go
if *debug {
initCSRFLogger(logger.With("component", "csrf"))
}

dbconn, err := setupDatabase(ctx, logger)
if err != nil {
logger.Error("failed to connect to database", slog.String("error", err.Error()))
Expand Down
36 changes: 25 additions & 11 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,16 @@ func setupMux(dsvc *DiscussService) http.Handler {
}

tailnetMux := http.NewServeMux()

tailnetMux.HandleFunc("POST /member/edit", CSRFMiddleware(dsvc.EditMemberProfile))
tailnetMux.HandleFunc("POST /thread/new", CSRFMiddleware(dsvc.CreateThread))
tailnetMux.HandleFunc("POST /thread/{id}", CSRFMiddleware(dsvc.CreateThreadPost))

tailnetMux.HandleFunc("GET /", dsvc.ListThreads)
tailnetMux.HandleFunc("GET /member/{id}", dsvc.ListMember)
tailnetMux.HandleFunc("GET /member/edit", dsvc.EditMemberProfile)
tailnetMux.HandleFunc("POST /member/edit", dsvc.EditMemberProfile)
tailnetMux.HandleFunc("GET /thread/new", dsvc.NewThread)
tailnetMux.HandleFunc("POST /thread/new", dsvc.CreateThread)
tailnetMux.HandleFunc("GET /thread/{id}", dsvc.ListThreadPosts)
tailnetMux.HandleFunc("POST /thread/{id}", dsvc.CreateThreadPost)
tailnetMux.Handle("GET /metrics", promhttp.Handler())
tailnetMux.Handle("GET /static/", http.StripPrefix("/static/", http.FileServer(http.FS(fs))))

Expand Down Expand Up @@ -401,6 +403,7 @@ func (s *DiscussService) CreateThreadPost(w http.ResponseWriter, r *http.Request
return
}

// nosemgrep
http.Redirect(w, r, fmt.Sprintf("/thread/%d", threadID), http.StatusSeeOther)
}

Expand All @@ -410,6 +413,8 @@ func (s *DiscussService) EditMemberProfile(w http.ResponseWriter, r *http.Reques
return
}

csrfToken, err := setCSRFToken(r, w)

// Get the current user's email
currentUserEmail, err := s.GetTailscaleUserEmail(r)
if err != nil {
Expand All @@ -426,6 +431,7 @@ func (s *DiscussService) EditMemberProfile(w http.ResponseWriter, r *http.Reques
"CurrentUserEmail": currentUserEmail,
"Version": s.version,
"GitSha": s.gitSha,
"CSRFToken": csrfToken,
})
return
}
Expand All @@ -451,6 +457,7 @@ func (s *DiscussService) EditMemberProfile(w http.ResponseWriter, r *http.Reques
"CurrentUserEmail": currentUserEmail,
"Version": s.version,
"GitSha": s.gitSha,
"CSRFToken": csrfToken,
})
return
}
Expand Down Expand Up @@ -483,6 +490,7 @@ func (s *DiscussService) EditMemberProfile(w http.ResponseWriter, r *http.Reques
}

// Redirect to the member's profile page
// nosemgrep
http.Redirect(w, r, fmt.Sprintf("/member/%d", memberID), http.StatusSeeOther)
}

Expand Down Expand Up @@ -567,6 +575,8 @@ func (s *DiscussService) ListThreadPosts(w http.ResponseWriter, r *http.Request)
return
}

csrfToken, err := setCSRFToken(r, w)

email, err := s.GetTailscaleUserEmail(r)
if err != nil {
s.renderError(w, r, err, http.StatusInternalServerError)
Expand Down Expand Up @@ -615,10 +625,11 @@ func (s *DiscussService) ListThreadPosts(w http.ResponseWriter, r *http.Request)
"Title": BOARD_TITLE,
"ThreadPosts": threadPosts,
// nosemgrep
"Subject": template.HTML(subject),
"ID": threadID,
"GitSha": s.gitSha,
"Version": s.version,
"Subject": template.HTML(subject),
"ID": threadID,
"GitSha": s.gitSha,
"Version": s.version,
"CSRFToken": csrfToken,
})
}

Expand Down Expand Up @@ -703,6 +714,8 @@ func (s *DiscussService) NewThread(w http.ResponseWriter, r *http.Request) {
return
}

csrfToken, err := setCSRFToken(r, w)

email, err := s.GetTailscaleUserEmail(r)
if err != nil {
s.renderError(w, r, err, http.StatusInternalServerError)
Expand All @@ -717,9 +730,10 @@ func (s *DiscussService) NewThread(w http.ResponseWriter, r *http.Request) {
}

s.renderTemplate(w, r, "newthread.html", map[string]interface{}{
"User": email,
"Title": BOARD_TITLE,
"Version": s.version,
"GitSha": s.gitSha,
"User": email,
"Title": BOARD_TITLE,
"CSRFToken": csrfToken,
"Version": s.version,
"GitSha": s.gitSha,
})
}
Loading

0 comments on commit 9f91ad3

Please sign in to comment.