diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go index bcd92c2f24..460a12ebb3 100644 --- a/hscontrol/notifier/notifier.go +++ b/hscontrol/notifier/notifier.go @@ -9,7 +9,7 @@ import ( ) type Notifier struct { - l sync.RWMutex + l sync.Mutex nodes map[string]chan<- types.StateUpdate } @@ -54,8 +54,8 @@ func (n *Notifier) NotifyAll(update types.StateUpdate) { } func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) { - n.l.RLock() - defer n.l.RUnlock() + n.l.Lock() + defer n.l.Unlock() for key, c := range n.nodes { if util.IsStringInSlice(ignore, key) { diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 7e86d591aa..c6b18bdcc5 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -62,6 +62,22 @@ func (h *Headscale) handlePoll( mapRequest tailcfg.MapRequest, isNoise bool, ) { + // Immediate open the channel and register it if the client wants + // a stream of MapResponses to prevent initial map response and + // following updates missing + var updateChan chan types.StateUpdate + if mapRequest.Stream { + h.pollNetMapStreamWG.Add(1) + defer h.pollNetMapStreamWG.Done() + + updateChan = make(chan types.StateUpdate) + defer closeChanWithLog(updateChan, machine.Hostname, "updateChan") + + // Register the node's update channel + h.nodeNotifier.AddNode(machine.MachineKey, updateChan) + defer h.nodeNotifier.RemoveNode(machine.MachineKey) + } + logInfo, logErr := logPollFunc(mapRequest, machine, isNoise) mapp := mapper.NewMapper( @@ -116,6 +132,21 @@ func (h *Headscale) handlePoll( return } + if !mapRequest.ReadOnly { + // It sounds like we should update the nodes when we have received a endpoint update + // even tho the comments in the tailscale code dont explicitly say so. + updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "endpoint-update"). + Inc() + + // Tell all the other nodes about the new endpoint, but dont update ourselves. + h.nodeNotifier.NotifyWithIgnore( + types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: []uint64{machine.ID}, + }, + machine.MachineKey) + } + // We update our peers if the client is not sending ReadOnly in the MapRequest // so we don't distribute its initial request (it comes with // empty endpoints to peers) @@ -165,18 +196,6 @@ func (h *Headscale) handlePoll( if err != nil { logErr(err, "Failed to write response") } - // It sounds like we should update the nodes when we have received a endpoint update - // even tho the comments in the tailscale code dont explicitly say so. - updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "endpoint-update"). - Inc() - - // Tell all the other nodes about the new endpoint, but dont update ourselves. - h.nodeNotifier.NotifyWithIgnore( - types.StateUpdate{ - Type: types.StatePeerChanged, - Changed: []uint64{machine.ID}, - }, - machine.MachineKey) return } else if mapRequest.OmitPeers && mapRequest.Stream { @@ -213,43 +232,9 @@ func (h *Headscale) handlePoll( return } - h.pollNetMapStream( - writer, - ctx, - machine, - mapp, - mapRequest, - isNoise, - ) - - logInfo("Finished stream, closing PollNetMap session") -} - -// pollNetMapStream stream logic for /machine/map, -// ensuring we communicate updates and data to the connected clients. -func (h *Headscale) pollNetMapStream( - writer http.ResponseWriter, - ctxReq context.Context, - machine *types.Machine, - mapp *mapper.Mapper, - mapRequest tailcfg.MapRequest, - isNoise bool, -) { - logInfo, logErr := logPollFunc(mapRequest, machine, isNoise) - keepAliveTicker := time.NewTicker(keepAliveInterval) - h.pollNetMapStreamWG.Add(1) - defer h.pollNetMapStreamWG.Done() - - updateChan := make(chan types.StateUpdate) - defer closeChanWithLog(updateChan, machine.Hostname, "updateChan") - - // Register the node's update channel - h.nodeNotifier.AddNode(machine.MachineKey, updateChan) - defer h.nodeNotifier.RemoveNode(machine.MachineKey) - - ctx := context.WithValue(ctxReq, machineNameContextKey, machine.Hostname) + ctx = context.WithValue(ctx, machineNameContextKey, machine.Hostname) ctx, cancel := context.WithCancel(ctx) defer cancel()