Skip to content

Commit

Permalink
Improve handling of CORS requests in REST API.
Browse files Browse the repository at this point in the history
- Consolidate handling of preflights with a catch-all OPTIONS endpoint.
- Decorate file retrieval responses with CORS headers.

Also moved method requirements into the HandleFunc call for simplicity.
  • Loading branch information
LTLA committed Jan 21, 2025
1 parent 385134f commit d425a4c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 67 deletions.
73 changes: 13 additions & 60 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func dumpJsonResponse(w http.ResponseWriter, status int, v interface{}) {
}

w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Origin", "*") // setting this for CORS.
w.WriteHeader(status)
_, err = w.Write(contents)
if err != nil {
Expand Down Expand Up @@ -322,28 +322,8 @@ func newDeregisterFinishHandler(db *sql.DB, verifier *verificationRegistry, time

/**********************************************************************/

func configureCors(w http.ResponseWriter, r *http.Request) bool {
if r.Method == "OPTIONS" {
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "*")
w.WriteHeader(http.StatusNoContent)
return true
} else {
return false
}
}

func newQueryHandler(db *sql.DB, tokenizer *unicodeTokenizer, wild_tokenizer *unicodeTokenizer, endpoint string) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if configureCors(w, r) {
return
}
if r.Method != "POST" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}

params := r.URL.Query()
var scroll *scrollPosition
limit := 100
Expand Down Expand Up @@ -451,14 +431,6 @@ func getRetrievePath(params url.Values) (string, error) {

func newRetrieveMetadataHandler(db *sql.DB) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if configureCors(w, r) {
return
}
if r.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}

params := r.URL.Query()
path, err := getRetrievePath(params)
if err != nil {
Expand Down Expand Up @@ -491,14 +463,6 @@ func newRetrieveMetadataHandler(db *sql.DB) func(http.ResponseWriter, *http.Requ

func newRetrieveFileHandler(db *sql.DB) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if configureCors(w, r) {
return
}
if r.Method != "GET" && r.Method != "HEAD" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}

params := r.URL.Query()
path, err := getRetrievePath(params)
if err != nil {
Expand Down Expand Up @@ -537,6 +501,9 @@ func newRetrieveFileHandler(db *sql.DB) func(http.ResponseWriter, *http.Request)
return
}

// Setting this for CORS.
w.Header().Set("Access-Control-Allow-Origin", "*")

if (r.Method == "HEAD") {
w.Header().Set("Content-Length", strconv.FormatInt(info.Size(), 10))
w.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat))
Expand Down Expand Up @@ -570,14 +537,6 @@ func newRetrieveFileHandler(db *sql.DB) func(http.ResponseWriter, *http.Request)

func newListFilesHandler(db *sql.DB) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if configureCors(w, r) {
return
}
if r.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}

params := r.URL.Query()
recursive := params.Get("recursive") == "true"
path, err := getRetrievePath(params)
Expand Down Expand Up @@ -614,14 +573,6 @@ func newListFilesHandler(db *sql.DB) func(http.ResponseWriter, *http.Request) {

func newListRegisteredDirectoriesHandler(db *sql.DB) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if configureCors(w, r) {
return
}
if r.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}

query := listRegisteredDirectoriesQuery{}

params := r.URL.Query()
Expand Down Expand Up @@ -660,13 +611,15 @@ func newListRegisteredDirectoriesHandler(db *sql.DB) func(http.ResponseWriter, *

func newDefaultHandler() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if configureCors(w, r) {
return
}
if r.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
dumpJsonResponse(w, http.StatusOK, map[string]string{ "name": "SewerRat API", "url": "https://github.com/ArtifactDB/SewerRat" })
}
}

func newOptionsHandler() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "*")
w.WriteHeader(http.StatusNoContent)
}
}
16 changes: 9 additions & 7 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,15 @@ func main() {
http.HandleFunc("POST " + prefix + "/deregister/start", newDeregisterStartHandler(db, verifier))
http.HandleFunc("POST " + prefix + "/deregister/finish", newDeregisterFinishHandler(db, verifier, timeout))

http.HandleFunc(prefix + "/registered", newListRegisteredDirectoriesHandler(ro_db))
http.HandleFunc(prefix + "/query", newQueryHandler(ro_db, tokenizer, wild_tokenizer, "/query"))
http.HandleFunc(prefix + "/retrieve/metadata", newRetrieveMetadataHandler(ro_db))
http.HandleFunc(prefix + "/retrieve/file", newRetrieveFileHandler(ro_db))
http.HandleFunc(prefix + "/list", newListFilesHandler(ro_db))

http.HandleFunc(prefix + "/", newDefaultHandler())
http.HandleFunc("GET " + prefix + "/registered", newListRegisteredDirectoriesHandler(ro_db))
http.HandleFunc("POST " + prefix + "/query", newQueryHandler(ro_db, tokenizer, wild_tokenizer, "/query"))
http.HandleFunc("GET " + prefix + "/retrieve/metadata", newRetrieveMetadataHandler(ro_db))
http.HandleFunc("GET " + prefix + "/retrieve/file", newRetrieveFileHandler(ro_db))
http.HandleFunc("HEAD " + prefix + "/retrieve/file", newRetrieveFileHandler(ro_db))
http.HandleFunc("GET " + prefix + "/list", newListFilesHandler(ro_db))

http.HandleFunc("GET " + prefix + "/", newDefaultHandler())
http.HandleFunc("OPTIONS " + prefix + "/", newOptionsHandler())

// Adding an hourly job that does a full checkpoint.
{
Expand Down

0 comments on commit d425a4c

Please sign in to comment.