Skip to content

Commit

Permalink
Fix data race.
Browse files Browse the repository at this point in the history
  • Loading branch information
blakerouse committed Sep 19, 2023
1 parent ca4e405 commit b324949
Showing 1 changed file with 35 additions and 8 deletions.
43 changes: 35 additions & 8 deletions internal/pkg/agent/install/progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type progressTrackerStep struct {

rootstep bool
substeps bool
mu sync.Mutex
step *progressTrackerStep
}

Expand Down Expand Up @@ -82,15 +83,28 @@ func (pts *progressTrackerStep) StepStart(msg string) ProgressTrackerStep {
}
pts.tracker.printf("%s%s...", prefix, strings.TrimSpace(msg))
s := newProgressTrackerStep(pts.tracker, prefix, func() {
pts.step = nil
pts.setStep(nil)
})
pts.step = s
pts.setStep(s)
return s
}

func (pts *progressTrackerStep) getStep() *progressTrackerStep {
pts.mu.Lock()
defer pts.mu.Unlock()
return pts.step
}

func (pts *progressTrackerStep) setStep(step *progressTrackerStep) {
pts.mu.Lock()
defer pts.mu.Unlock()
pts.step = step
}

func (pts *progressTrackerStep) tick() {
if pts.step != nil {
pts.step.tick()
step := pts.getStep()
if step != nil {
step.tick()
return
}
if !pts.rootstep {
Expand Down Expand Up @@ -135,20 +149,21 @@ func (pt *ProgressTracker) Start() ProgressTrackerStep {
case <-pt.stop:
return
case <-timer.C:
if pt.step != nil {
pt.step.tick()
step := pt.getStep()
if step != nil {
step.tick()
}
timer = time.NewTimer(pt.calculateTickInterval())
}
}
}()

s := newProgressTrackerStep(pt, "", func() {
pt.step = nil
pt.setStep(nil)
pt.stop <- struct{}{}
})
s.rootstep = true // is the root step
pt.step = s
pt.setStep(s)
return s
}

Expand All @@ -158,6 +173,18 @@ func (pt *ProgressTracker) printf(format string, a ...any) {
_, _ = fmt.Fprintf(pt.writer, format, a...)
}

func (pt *ProgressTracker) getStep() *progressTrackerStep {
pt.mu.Lock()
defer pt.mu.Unlock()
return pt.step
}

func (pt *ProgressTracker) setStep(step *progressTrackerStep) {
pt.mu.Lock()
defer pt.mu.Unlock()
pt.step = step
}

func (pt *ProgressTracker) calculateTickInterval() time.Duration {
if !pt.randomizeTickInterval {
return pt.tickInterval
Expand Down

0 comments on commit b324949

Please sign in to comment.