diff --git a/http.go b/http.go index c2d312c..786a3cc 100644 --- a/http.go +++ b/http.go @@ -8,6 +8,7 @@ import ( "net/http" "strconv" "strings" + "sync" "github.com/nautilus/graphql" ) @@ -27,6 +28,8 @@ type HTTPOperation struct { } `json:"extensions"` } +type setResultFunc func(r map[string]interface{}) + func formatErrors(err error) map[string]interface{} { return formatErrorsWithCode(nil, err, "UNKNOWN_ERROR") } @@ -70,12 +73,14 @@ func (g *Gateway) GraphQLHandler(w http.ResponseWriter, r *http.Request) { /// Handle the operations regardless of the request method // we have to respond to each operation in the right order - results := []map[string]interface{}{} + results := make([]map[string]interface{}, len(operations)) + opWg := new(sync.WaitGroup) + opMutex := new(sync.Mutex) // the status code to report statusCode := http.StatusOK - for _, operation := range operations { + for opNum, operation := range operations { // there might be a query plan cache key embedded in the operation cacheKey := "" if operation.Extensions.QueryPlanCache != nil { @@ -85,10 +90,8 @@ func (g *Gateway) GraphQLHandler(w http.ResponseWriter, r *http.Request) { // if there is no query or cache key if operation.Query == "" && cacheKey == "" { statusCode = http.StatusUnprocessableEntity - results = append( - results, - formatErrorsWithCode(nil, errors.New("could not find query body"), "BAD_USER_INPUT"), - ) + results[opNum] = formatErrorsWithCode(nil, errors.New("could not find query body"), "BAD_USER_INPUT") + continue } @@ -116,32 +119,12 @@ func (g *Gateway) GraphQLHandler(w http.ResponseWriter, r *http.Request) { return } - // fire the query with the request context passed through to execution - result, err := g.Execute(requestContext, plan) - if err != nil { - results = append(results, formatErrorsWithCode(result, err, "INTERNAL_SERVER_ERROR")) - - continue - } - - // the result for this operation - payload := map[string]interface{}{"data": result} - - // if there was a cache key associated with this query - if requestContext.CacheKey != "" { - // embed the cache key in the response - payload["extensions"] = map[string]interface{}{ - "persistedQuery": map[string]interface{}{ - "sha265Hash": requestContext.CacheKey, - "version": "1", - }, - } - } - - // add this result to the list - results = append(results, payload) + opWg.Add(1) + go g.executeRequest(requestContext, plan, opWg, g.setResultFunc(opNum, results, opMutex)) } + opWg.Wait() + // the final result depends on whether we are executing in batch mode or not var finalResponse interface{} if batchMode { @@ -165,6 +148,43 @@ func (g *Gateway) GraphQLHandler(w http.ResponseWriter, r *http.Request) { emitResponse(w, statusCode, string(response)) } +func (g *Gateway) setResultFunc(opNum int, results []map[string]interface{}, opMutex *sync.Mutex) setResultFunc { + return func(r map[string]interface{}) { + opMutex.Lock() + defer opMutex.Unlock() + results[opNum] = r + } +} + +func (g *Gateway) executeRequest(requestContext *RequestContext, plan QueryPlanList, opWg *sync.WaitGroup, setResult setResultFunc) { + defer opWg.Done() + + // fire the query with the request context passed through to execution + result, err := g.Execute(requestContext, plan) + if err != nil { + setResult(formatErrorsWithCode(result, err, "INTERNAL_SERVER_ERROR")) + + return + } + + // the result for this operation + payload := map[string]interface{}{"data": result} + + // if there was a cache key associated with this query + if requestContext.CacheKey != "" { + // embed the cache key in the response + payload["extensions"] = map[string]interface{}{ + "persistedQuery": map[string]interface{}{ + "sha265Hash": requestContext.CacheKey, + "version": "1", + }, + } + } + + // add this result to the list + setResult(payload) +} + // Parses request to operations (single or batch mode). // Returns an error and an error status code if the request is invalid. func parseRequest(r *http.Request) (operations []*HTTPOperation, batchMode bool, errStatusCode int, payloadErr error) { diff --git a/http_test.go b/http_test.go index 3b905c1..794cd31 100644 --- a/http_test.go +++ b/http_test.go @@ -14,6 +14,7 @@ import ( "strconv" "strings" "testing" + "time" "golang.org/x/net/html" @@ -971,6 +972,98 @@ func TestGraphQLHandler_postBatchWithMultipleFiles(t *testing.T) { assert.Equal(t, http.StatusOK, result.StatusCode) } +func TestGraphQLHandler_postBatchParallel(t *testing.T) { + t.Parallel() + schema, err := graphql.LoadSchema(` + type Query { + queryA: String! + queryB: String! + } + `) + assert.NoError(t, err) + + // create gateway schema we can test against + gateway, err := New([]*graphql.RemoteSchema{ + {Schema: schema, URL: "url-file-upload"}, + }, WithExecutor(ExecutorFunc( + func(ec *ExecutionContext) (map[string]interface{}, error) { + if ec.Plan.Operation.Name == "queryAOperation" { + time.Sleep(50 * time.Millisecond) + return map[string]interface{}{ + "queryA": "resultA", + }, nil + } + if ec.Plan.Operation.Name == "queryBOperation" { + return map[string]interface{}{ + "queryB": "resultB", + }, nil + } + + assert.Fail(t, "unexpected operation name", ec.Plan.Operation.Name) + return nil, nil + }, + ))) + + if err != nil { + t.Error(err.Error()) + return + } + + request := httptest.NewRequest("POST", "/graphql", strings.NewReader(`[ + { + "query": "query queryAOperation { queryA }", + "variables": null + }, + { + "query": "query queryBOperation { queryB }", + "variables": null + } + ]`)) + + // a recorder so we can check what the handler responded with + responseRecorder := httptest.NewRecorder() + + // call the http hander + gateway.GraphQLHandler(responseRecorder, request) + + // make sure we got correct order in response + response := responseRecorder.Result() + assert.Equal(t, http.StatusOK, response.StatusCode) + + // read the body + body, err := io.ReadAll(response.Body) + assert.NoError(t, response.Body.Close()) + + if err != nil { + t.Error(err.Error()) + return + } + + result := []map[string]interface{}{} + err = json.Unmarshal(body, &result) + if err != nil { + t.Error(err.Error()) + return + } + + // we should have gotten 2 responses + if !assert.Len(t, result, 2) { + return + } + + // make sure there were no errors in the first query + if firstQuery := result[0]; assert.Nil(t, firstQuery["errors"]) { + // make sure it has the right id + assert.Equal(t, map[string]interface{}{"queryA": "resultA"}, firstQuery["data"]) + } + + // make sure there were no errors in the second query + if secondQuery := result[1]; assert.Nil(t, secondQuery["errors"]) { + // make sure it has the right id + assert.Equal(t, map[string]interface{}{"queryB": "resultB"}, secondQuery["data"]) + } +} + func TestGraphQLHandler_postFilesWithError(t *testing.T) { t.Parallel() schema, err := graphql.LoadSchema(`