diff --git a/service_windows.go b/service_windows.go index d52fdbf2..c08cda81 100644 --- a/service_windows.go +++ b/service_windows.go @@ -5,6 +5,7 @@ package service import ( + "errors" "fmt" "os" "os/signal" @@ -39,6 +40,11 @@ const ( errnoServiceDoesNotExist syscall.Errno = 1060 ) +var ( + errAlreadyRunning = errors.New("service already running") + errAlreadyStopped = errors.New("service already stopped") +) + type windowsService struct { i Interface *Config @@ -178,48 +184,63 @@ func (ws *windowsService) getError() error { return ws.stopStartErr } -func (ws *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) { - const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown +func (ws *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, exitCode uint32) { + var err error + defer func() { + if err != nil { + ssec = true + ws.setError(err) + } + }() + + // Signal that we're starting. changes <- svc.Status{State: svc.StartPending} - if err := ws.i.Start(ws); err != nil { - ws.setError(err) - return true, 1 + // Perform the actual start. + if initErr := ws.i.Start(ws); initErr != nil { + err = initErr + exitCode = 1 + return } - changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} + // Signal that we're ready. + changes <- svc.Status{ + State: svc.Running, + Accepts: svc.AcceptStop | svc.AcceptShutdown, + } + + // Expect service change requests. + var stopMethod func(s Service) error loop: - for { - c := <-r + for c := range r { switch c.Cmd { case svc.Interrogate: changes <- c.CurrentStatus - case svc.Stop: - changes <- svc.Status{State: svc.StopPending} - if err := ws.i.Stop(ws); err != nil { - ws.setError(err) - return true, 2 - } - break loop case svc.Shutdown: - changes <- svc.Status{State: svc.StopPending} - var err error - if wsShutdown, ok := ws.i.(Shutdowner); ok { - err = wsShutdown.Shutdown(ws) - } else { - err = ws.i.Stop(ws) - } - if err != nil { - ws.setError(err) - return true, 2 + if shutdowner, ok := ws.i.(Shutdowner); ok { + stopMethod = shutdowner.Shutdown + break loop } + fallthrough + case svc.Stop: + stopMethod = ws.i.Stop break loop default: - continue loop + continue } } - return false, 0 + // We were requested to stop, + // change state and proceed to do so. + changes <- svc.Status{State: svc.StopPending} + if stopErr := stopMethod(ws); stopErr != nil { + err = stopErr + exitCode = 2 + return + } + + // Calling function will set our state to Stopped. + return } func (ws *windowsService) Install() error { @@ -308,19 +329,21 @@ func (ws *windowsService) Uninstall() error { return err } defer m.Disconnect() + s, err := m.OpenService(ws.Name) if err != nil { return fmt.Errorf("service %s is not installed", ws.Name) } defer s.Close() - err = s.Delete() - if err != nil { + + if err := s.Delete(); err != nil { return err } - err = eventlog.Remove(ws.Name) - if err != nil { + + if err := eventlog.Remove(ws.Name); err != nil { return fmt.Errorf("RemoveEventLogSource() failed: %s", err) } + return nil } @@ -346,7 +369,7 @@ func (ws *windowsService) Run() error { return err } - sigChan := make(chan os.Signal) + sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt) @@ -408,26 +431,28 @@ func (ws *windowsService) Start() error { return err } defer s.Close() - return s.Start() -} -func (ws *windowsService) Stop() error { - m, err := mgr.Connect() + status, err := s.Query() if err != nil { return err } - defer m.Disconnect() - s, err := m.OpenService(ws.Name) - if err != nil { - return err + switch status.State { + default: + err = errAlreadyRunning + case svc.StopPending: + err = waitForStateChange(s, status, svc.Stopped) + case svc.Stopped: + if startErr := s.Start(); startErr != nil { + return startErr + } + err = waitForStateChange(s, status, svc.Running) } - defer s.Close() - return ws.stopWait(s) + return err } -func (ws *windowsService) Restart() error { +func (ws *windowsService) Stop() error { m, err := mgr.Connect() if err != nil { return err @@ -440,37 +465,104 @@ func (ws *windowsService) Restart() error { } defer s.Close() - err = ws.stopWait(s) + status, err := s.Query() if err != nil { return err } - return s.Start() -} - -func (ws *windowsService) stopWait(s *mgr.Service) error { - // First stop the service. Then wait for the service to - // actually stop before starting it. - status, err := s.Control(svc.Stop) - if err != nil { - return err + switch status.State { + case svc.Stopped: + err = errAlreadyStopped + case svc.StopPending: + err = waitForStateChange(s, status, svc.Stopped) + default: + if _, stopErr := s.Control(svc.Stop); stopErr != nil { + return stopErr + } + err = waitForStateChange(s, status, svc.Stopped) } - timeDuration := time.Millisecond * 50 - - timeout := time.After(getStopTimeout() + (timeDuration * 2)) - tick := time.NewTicker(timeDuration) - defer tick.Stop() + return err +} - for status.State != svc.Stopped { +func (ws *windowsService) Restart() error { + if stopErr := ws.Stop(); stopErr != nil { + return stopErr + } + return ws.Start() +} + +// statusInterval retreives a (bounded) duration from the status, +// or provides a default. +func statusInterval(status svc.Status) time.Duration { + // MSDN: + // "Do not wait longer than the wait hint. A good interval is + // one-tenth of the wait hint but not less than 1 second + // and not more than 10 seconds." + const ( + lower = time.Second + upper = time.Second * 10 + ) + + waitDuration := (time.Duration(status.WaitHint) * time.Millisecond) / 10 + if waitDuration < lower { + waitDuration = lower + } else if waitDuration > upper { + waitDuration = upper + } + return waitDuration +} + +// waitForStateChange polls the service until its state matches the desiredState, +// and error is encountered, or we timeout. +func waitForStateChange(s *mgr.Service, currentStatus svc.Status, desiredState svc.State) error { + const defaultAttempts = 10 + var ( + initialInterval = statusInterval(currentStatus) + queryTicker = time.NewTicker(initialInterval) + queryTimer *time.Timer + ) + // If the service is providing hints, + // use them, otherwise use a default timeout. + if currentStatus.CheckPoint != 0 { + queryTimer = time.NewTimer(initialInterval) + } else { + queryTimer = time.NewTimer(initialInterval * defaultAttempts) + } + defer func() { + queryTicker.Stop() + queryTimer.Stop() + }() + + var ( + currentState = currentStatus.State + lastCheckpoint uint32 + ) + for currentState != desiredState { select { - case <-tick.C: - status, err = s.Query() - if err != nil { - return err + case <-queryTicker.C: + currentStatus, queryErr := s.Query() + if queryErr != nil { + return queryErr + } + + currentState = currentStatus.State + if currentState == desiredState { + return nil + } + + if currentStatus.CheckPoint > lastCheckpoint { + // Service progressed, + // give it more time to complete. + if !queryTimer.Stop() { + <-queryTimer.C + } + queryTimer.Reset(statusInterval(currentStatus)) } - case <-timeout: - break + lastCheckpoint = currentStatus.CheckPoint + case <-queryTimer.C: + return fmt.Errorf("service did not enter desired state (%v) before we timed out", + desiredState) } } return nil