diff --git a/internal/aggregator/ingress_proxy.go b/internal/aggregator/ingress_proxy.go
index e882d714d..6aa9733ba 100644
--- a/internal/aggregator/ingress_proxy.go
+++ b/internal/aggregator/ingress_proxy.go
@@ -39,7 +39,8 @@ type clientPool struct {
 }
 
 type longpollClient struct {
-	queryID    int64
+	client     *rpc.Client
+	cc         rpc.CallbackContext
 	requestLen int
 }
 
@@ -158,7 +159,7 @@ func (ls *longpollShard) callback(client *rpc.Client, queryID int64, resp *rpc.R
 	ls.mu.Lock()
 	defer ls.mu.Unlock()
 	lpc, ok := ls.clientList[hctx]
-	if !ok || lpc.queryID != queryID {
+	if !ok || lpc.client != client || lpc.cc.QueryID() != queryID {
 		// server already cancelled longpoll call
 		// or hctx was cancelled and reused by server before client response arrived
 		// since we have no client cancellation, we rely on fact that client queryId does not repeat often
@@ -182,10 +183,13 @@ func (ls *longpollShard) CancelHijack(hctx *rpc.HandlerContext) {
 	ls.mu.Lock()
 	defer ls.mu.Unlock()
 	if lpc, ok := ls.clientList[hctx]; ok {
-		key := keyFromHctx(hctx, format.TagValueIDRPCRequestsStatusErrCancel)
-		ls.proxy.sh2.AddValueCounter(key, float64(lpc.requestLen), 1, nil)
+		delete(ls.clientList, hctx)
+		// same order of locks between client and ls.mu as below
+		if lpc.client.CancelDoCallback(lpc.cc) {
+			key := keyFromHctx(hctx, format.TagValueIDRPCRequestsStatusErrCancel)
+			ls.proxy.sh2.AddValueCounter(key, float64(lpc.requestLen), 1, nil)
+		} // otherwise callback was/will be called
 	}
-	delete(ls.clientList, hctx)
 }
 
 func (proxy *IngressProxy) syncHandler(ctx context.Context, hctx *rpc.HandlerContext) error {
@@ -221,10 +225,11 @@ func (proxy *IngressProxy) syncHandlerImpl(ctx context.Context, hctx *rpc.Handle
 	ls := proxy.longpollShards[lockShardID]
 	ls.mu.Lock() // to avoid race with longpoll cancellation, all code below must run under lock
 	defer ls.mu.Unlock()
-	if _, err := client.DoCallback(ctx, proxy.config.Network, address, req, ls.callback, hctx); err != nil {
+	cc, err := client.DoCallback(ctx, proxy.config.Network, address, req, ls.callback, hctx)
+	if err != nil {
 		return format.TagValueIDRPCRequestsStatusErrLocal, err
 	}
-	ls.clientList[hctx] = longpollClient{queryID: req.QueryID(), requestLen: requestLen}
+	ls.clientList[hctx] = longpollClient{client: client, cc: cc, requestLen: requestLen}
 	return 0, hctx.HijackResponse(ls)
 }
 