From d425a4c61f9e0f724506e7522d7ce3fafe47b150 Mon Sep 17 00:00:00 2001 From: LTLA Date: Tue, 21 Jan 2025 14:09:13 -0800 Subject: [PATCH] Improve handling of CORS requests in REST API. - Consolidate handling of preflights with a catch-all OPTIONS endpoint. - Decorate file retrieval responses with CORS headers. Also moved method requirements into the HandleFunc call for simplicity. --- handlers.go | 73 ++++++++++------------------------------------------- main.go | 16 +++++++----- 2 files changed, 22 insertions(+), 67 deletions(-) diff --git a/handlers.go b/handlers.go index 0b82e59..4a279b5 100644 --- a/handlers.go +++ b/handlers.go @@ -31,7 +31,7 @@ func dumpJsonResponse(w http.ResponseWriter, status int, v interface{}) { } w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Origin", "*") // setting this for CORS. w.WriteHeader(status) _, err = w.Write(contents) if err != nil { @@ -322,28 +322,8 @@ func newDeregisterFinishHandler(db *sql.DB, verifier *verificationRegistry, time /**********************************************************************/ -func configureCors(w http.ResponseWriter, r *http.Request) bool { - if r.Method == "OPTIONS" { - w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Headers", "*") - w.WriteHeader(http.StatusNoContent) - return true - } else { - return false - } -} - func newQueryHandler(db *sql.DB, tokenizer *unicodeTokenizer, wild_tokenizer *unicodeTokenizer, endpoint string) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - if configureCors(w, r) { - return - } - if r.Method != "POST" { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - params := r.URL.Query() var scroll *scrollPosition limit := 100 @@ -451,14 +431,6 @@ func getRetrievePath(params url.Values) (string, error) { func newRetrieveMetadataHandler(db *sql.DB) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - if configureCors(w, r) { - return - } - if r.Method != "GET" { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - params := r.URL.Query() path, err := getRetrievePath(params) if err != nil { @@ -491,14 +463,6 @@ func newRetrieveMetadataHandler(db *sql.DB) func(http.ResponseWriter, *http.Requ func newRetrieveFileHandler(db *sql.DB) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - if configureCors(w, r) { - return - } - if r.Method != "GET" && r.Method != "HEAD" { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - params := r.URL.Query() path, err := getRetrievePath(params) if err != nil { @@ -537,6 +501,9 @@ func newRetrieveFileHandler(db *sql.DB) func(http.ResponseWriter, *http.Request) return } + // Setting this for CORS. + w.Header().Set("Access-Control-Allow-Origin", "*") + if (r.Method == "HEAD") { w.Header().Set("Content-Length", strconv.FormatInt(info.Size(), 10)) w.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat)) @@ -570,14 +537,6 @@ func newRetrieveFileHandler(db *sql.DB) func(http.ResponseWriter, *http.Request) func newListFilesHandler(db *sql.DB) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - if configureCors(w, r) { - return - } - if r.Method != "GET" { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - params := r.URL.Query() recursive := params.Get("recursive") == "true" path, err := getRetrievePath(params) @@ -614,14 +573,6 @@ func newListFilesHandler(db *sql.DB) func(http.ResponseWriter, *http.Request) { func newListRegisteredDirectoriesHandler(db *sql.DB) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - if configureCors(w, r) { - return - } - if r.Method != "GET" { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - query := listRegisteredDirectoriesQuery{} params := r.URL.Query() @@ -660,13 +611,15 @@ func newListRegisteredDirectoriesHandler(db *sql.DB) func(http.ResponseWriter, * func newDefaultHandler() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - if configureCors(w, r) { - return - } - if r.Method != "GET" { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } dumpJsonResponse(w, http.StatusOK, map[string]string{ "name": "SewerRat API", "url": "https://github.com/ArtifactDB/SewerRat" }) } } + +func newOptionsHandler() func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Headers", "*") + w.WriteHeader(http.StatusNoContent) + } +} diff --git a/main.go b/main.go index 9ac0095..1667d0d 100644 --- a/main.go +++ b/main.go @@ -65,13 +65,15 @@ func main() { http.HandleFunc("POST " + prefix + "/deregister/start", newDeregisterStartHandler(db, verifier)) http.HandleFunc("POST " + prefix + "/deregister/finish", newDeregisterFinishHandler(db, verifier, timeout)) - http.HandleFunc(prefix + "/registered", newListRegisteredDirectoriesHandler(ro_db)) - http.HandleFunc(prefix + "/query", newQueryHandler(ro_db, tokenizer, wild_tokenizer, "/query")) - http.HandleFunc(prefix + "/retrieve/metadata", newRetrieveMetadataHandler(ro_db)) - http.HandleFunc(prefix + "/retrieve/file", newRetrieveFileHandler(ro_db)) - http.HandleFunc(prefix + "/list", newListFilesHandler(ro_db)) - - http.HandleFunc(prefix + "/", newDefaultHandler()) + http.HandleFunc("GET " + prefix + "/registered", newListRegisteredDirectoriesHandler(ro_db)) + http.HandleFunc("POST " + prefix + "/query", newQueryHandler(ro_db, tokenizer, wild_tokenizer, "/query")) + http.HandleFunc("GET " + prefix + "/retrieve/metadata", newRetrieveMetadataHandler(ro_db)) + http.HandleFunc("GET " + prefix + "/retrieve/file", newRetrieveFileHandler(ro_db)) + http.HandleFunc("HEAD " + prefix + "/retrieve/file", newRetrieveFileHandler(ro_db)) + http.HandleFunc("GET " + prefix + "/list", newListFilesHandler(ro_db)) + + http.HandleFunc("GET " + prefix + "/", newDefaultHandler()) + http.HandleFunc("OPTIONS " + prefix + "/", newOptionsHandler()) // Adding an hourly job that does a full checkpoint. {