diff --git a/rpc/wrtc_call_queue_memory.go b/rpc/wrtc_call_queue_memory.go
index 8145dfa6..dbc03b54 100644
--- a/rpc/wrtc_call_queue_memory.go
+++ b/rpc/wrtc_call_queue_memory.go
@@ -18,12 +18,9 @@ import (
 // testing and single node/host deployments.
 type memoryWebRTCCallQueue struct {
 	mu                      sync.Mutex
-	activeBackgroundWorkers sync.WaitGroup
+	activeBackgroundWorkers *utils.StoppableWorkers
 	hostQueues              map[string]*singleWebRTCHostQueue
 
-	cancelCtx  context.Context
-	cancelFunc func()
-
 	uuidDeterministic        bool
 	uuidDeterministicCounter int64
 	logger                   utils.ZapCompatibleLogger
@@ -41,42 +38,24 @@ func newMemoryWebRTCCallQueueTest(logger utils.ZapCompatibleLogger) *memoryWebRT
 }
 
 func newMemoryWebRTCCallQueue(uuidDeterministic bool, logger utils.ZapCompatibleLogger) *memoryWebRTCCallQueue {
-	cancelCtx, cancelFunc := context.WithCancel(context.Background())
 	queue := &memoryWebRTCCallQueue{
 		hostQueues:        map[string]*singleWebRTCHostQueue{},
-		cancelCtx:         cancelCtx,
-		cancelFunc:        cancelFunc,
 		uuidDeterministic: uuidDeterministic,
 		logger:            logger,
 	}
-	queue.activeBackgroundWorkers.Add(1)
-	ticker := time.NewTicker(5 * time.Second)
-	utils.ManagedGo(func() {
-		for {
-			if cancelCtx.Err() != nil {
-				return
-			}
-			select {
-			case <-cancelCtx.Done():
-				return
-			case <-ticker.C:
-			}
-			now := time.Now()
-			queue.mu.Lock()
-			for _, hostQueue := range queue.hostQueues {
-				hostQueue.mu.Lock()
-				for offerID, offer := range hostQueue.activeOffers {
-					if d, ok := offer.offer.answererDoneCtx.Deadline(); ok && d.Before(now) {
-						delete(hostQueue.activeOffers, offerID)
-					}
+	queue.activeBackgroundWorkers = utils.NewStoppableWorkerWithTicker(5*time.Second, func(ctx context.Context) {
+		now := time.Now()
+		queue.mu.Lock()
+		for _, hostQueue := range queue.hostQueues {
+			hostQueue.mu.Lock()
+			for offerID, offer := range hostQueue.activeOffers {
+				if d, ok := offer.offer.answererDoneCtx.Deadline(); ok && d.Before(now) {
+					delete(hostQueue.activeOffers, offerID)
 				}
-				hostQueue.mu.Unlock()
 			}
-			queue.mu.Unlock()
+			hostQueue.mu.Unlock()
 		}
-	}, func() {
-		defer queue.activeBackgroundWorkers.Done()
-		defer ticker.Stop()
+		queue.mu.Unlock()
 	})
 	return queue
 }
@@ -113,7 +92,7 @@ func (queue *memoryWebRTCCallQueue) SendOfferInit(
 	}
 	answererResponses := make(chan WebRTCCallAnswer)
 	offerDeadline := time.Now().Add(getDefaultOfferDeadline())
-	sendCtx, sendCtxCancel := context.WithDeadline(queue.cancelCtx, offerDeadline)
+	sendCtx, sendCtxCancel := context.WithDeadline(queue.activeBackgroundWorkers.Context(), offerDeadline)
 	offer := memoryWebRTCCallOfferInit{
 		uuid: newUUID,
 		sdp:  sdp,
@@ -135,9 +114,7 @@ func (queue *memoryWebRTCCallQueue) SendOfferInit(
 	hostQueueForSend.activeOffers[offer.uuid] = exchange
 	hostQueueForSend.mu.Unlock()
 
-	queue.activeBackgroundWorkers.Add(1)
-	utils.PanicCapturingGo(func() {
-		queue.activeBackgroundWorkers.Done()
+	queue.activeBackgroundWorkers.Add(func(_ context.Context) {
 		select {
 		case <-sendCtx.Done():
 		case <-ctx.Done():
@@ -213,7 +190,7 @@ func (queue *memoryWebRTCCallQueue) SendOfferError(ctx context.Context, host, uu
 func (queue *memoryWebRTCCallQueue) RecvOffer(ctx context.Context, hosts []string) (WebRTCCallOfferExchange, error) {
 	hostQueue := queue.getOrMakeHostsQueue(hosts)
 
-	recvCtx, recvCtxCancel := context.WithCancel(queue.cancelCtx)
+	recvCtx, recvCtxCancel := context.WithCancel(queue.activeBackgroundWorkers.Context())
 	defer recvCtxCancel()
 
 	select {
@@ -228,8 +205,7 @@ func (queue *memoryWebRTCCallQueue) RecvOffer(ctx context.Context, hosts []strin
 
 // Close cancels all active offers and waits to cleanly close all background workers.
 func (queue *memoryWebRTCCallQueue) Close() error {
-	queue.cancelFunc()
-	queue.activeBackgroundWorkers.Wait()
+	queue.activeBackgroundWorkers.Stop()
 	return nil
 }
 
diff --git a/rpc/wrtc_server.go b/rpc/wrtc_server.go
index cf124575..e03f06d1 100644
--- a/rpc/wrtc_server.go
+++ b/rpc/wrtc_server.go
@@ -20,8 +20,6 @@ var DefaultWebRTCMaxGRPCCalls = 256
 
 // A webrtcServer translates gRPC frames over WebRTC data channels into gRPC calls.
 type webrtcServer struct {
-	ctx      context.Context
-	cancel   context.CancelFunc
 	handlers map[string]handlerFunc
 	services map[string]*serviceInfo
 	logger   utils.ZapCompatibleLogger
@@ -29,12 +27,7 @@ type webrtcServer struct {
 	peerConnsMu sync.Mutex
 	peerConns   map[*webrtc.PeerConnection]struct{}
 
-	// processHeadersMu should be `Lock`ed in `Stop` to `Wait` on ongoing
-	// processHeaders calls (incoming method invocations). processHeaderMu should
-	// be `RLock`ed in processHeaders (allow concurrent processHeaders) to `Add`
-	// to processHeadersWorkers.
-	processHeadersMu      sync.RWMutex
-	processHeadersWorkers sync.WaitGroup
+	processHeadersWorkers *utils.StoppableWorkers
 
 	callTickets chan struct{}
 
@@ -129,18 +122,15 @@ func newWebRTCServerWithInterceptorsAndUnknownStreamHandler(
 		streamInt:         streamInt,
 		unknownStreamDesc: unknownStreamDesc,
 	}
-	srv.ctx, srv.cancel = context.WithCancel(context.Background())
+	srv.processHeadersWorkers = utils.NewBackgroundStoppableWorkers()
 	return srv
 }
 
 // Stop instructs the server and all handlers to stop. It returns when all handlers
 // are done executing.
 func (srv *webrtcServer) Stop() {
-	srv.cancel()
-	srv.processHeadersMu.Lock()
 	srv.logger.Info("waiting for handlers to complete")
-	srv.processHeadersWorkers.Wait()
-	srv.processHeadersMu.Unlock()
+	srv.processHeadersWorkers.Stop()
 	srv.logger.Info("handlers complete")
 
 	srv.logger.Info("closing lingering peer connections")
diff --git a/rpc/wrtc_server_channel.go b/rpc/wrtc_server_channel.go
index d9761c97..086f7c25 100644
--- a/rpc/wrtc_server_channel.go
+++ b/rpc/wrtc_server_channel.go
@@ -37,7 +37,7 @@ func newWebRTCServerChannel(
 	logger utils.ZapCompatibleLogger,
 ) *webrtcServerChannel {
 	base := newBaseChannel(
-		server.ctx,
+		server.processHeadersWorkers.Context(),
 		peerConn,
 		dataChannel,
 		func() { server.removePeer(peerConn) },
diff --git a/rpc/wrtc_server_stream.go b/rpc/wrtc_server_stream.go
index 76d8b0a8..1445a487 100644
--- a/rpc/wrtc_server_stream.go
+++ b/rpc/wrtc_server_stream.go
@@ -269,30 +269,18 @@ func (s *webrtcServerStream) processHeaders(headers *webrtcpb.RequestHeaders) {
 	}
 
 	s.ch.server.counters.HeadersProcessed.Add(1)
-	s.ch.server.processHeadersMu.RLock()
-	s.ch.server.processHeadersWorkers.Add(1)
-	s.ch.server.processHeadersMu.RUnlock()
-
-	// Check if context has errored: underlying server may have been `Stop`ped,
-	// in which case we mark this processHeaders worker as `Done` and return.
-	if err := s.ch.server.ctx.Err(); err != nil {
-		s.ch.server.processHeadersWorkers.Done()
-		return
-	}
 
 	// take a ticket
 	select {
 	case s.ch.server.callTickets <- struct{}{}:
 	default:
-		s.ch.server.processHeadersWorkers.Done()
 		s.closeWithSendError(status.Error(codes.ResourceExhausted, "too many in-flight requests"))
 		return
 	}
 
 	s.headersReceived = true
-	utils.PanicCapturingGo(func() {
+	s.ch.server.processHeadersWorkers.Add(func(ctx context.Context) {
 		defer func() {
-			s.ch.server.processHeadersWorkers.Done()
 			<-s.ch.server.callTickets // return a ticket
 		}()
 		if err := handlerFunc(s); err != nil {
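
Every hunk above performs the same migration: hand-rolled sync.WaitGroup plus cancel-context bookkeeping is replaced by go.viam.com/utils's StoppableWorkers. Below is a minimal, self-contained sketch of that pattern, using only the calls exercised in the diff (NewStoppableWorkerWithTicker, Add, Context, Stop); the expiringSet type, its methods, and the TTL logic are hypothetical illustrations, not code from this repository.

// Sketch only: a small type whose background work is managed by a
// *utils.StoppableWorkers, mirroring the pattern adopted in this diff.
package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"go.viam.com/utils"
)

type expiringSet struct {
	mu      sync.Mutex
	items   map[string]time.Time // item -> deadline
	workers *utils.StoppableWorkers
}

func newExpiringSet() *expiringSet {
	s := &expiringSet{items: map[string]time.Time{}}
	// One ticker-driven worker replaces the hand-rolled ticker, WaitGroup, and
	// cancel context from the old code: the callback runs once per second until
	// Stop is called, which also cancels the worker's context.
	s.workers = utils.NewStoppableWorkerWithTicker(time.Second, func(ctx context.Context) {
		now := time.Now()
		s.mu.Lock()
		for item, deadline := range s.items {
			if deadline.Before(now) {
				delete(s.items, item)
			}
		}
		s.mu.Unlock()
	})
	return s
}

func (s *expiringSet) add(item string, ttl time.Duration) {
	s.mu.Lock()
	s.items[item] = time.Now().Add(ttl)
	s.mu.Unlock()

	// Contexts derived from workers.Context() are canceled on Stop, mirroring
	// how SendOfferInit and RecvOffer derive sendCtx/recvCtx above.
	waitCtx, cancel := context.WithTimeout(s.workers.Context(), ttl)

	// Ad-hoc goroutines register via Add; no separate Add(1)/Done bookkeeping
	// or panic-capturing wrapper is needed.
	s.workers.Add(func(_ context.Context) {
		defer cancel()
		<-waitCtx.Done()
		fmt.Println(item, "expired or set closed")
	})
}

// close stops the ticker worker and every goroutine registered with Add and
// waits for them to exit, as memoryWebRTCCallQueue.Close and webrtcServer.Stop
// now do via a single Stop call.
func (s *expiringSet) close() {
	s.workers.Stop()
}

func main() {
	s := newExpiringSet()
	s.add("offer-1", 2*time.Second)
	time.Sleep(3 * time.Second)
	s.close()
}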