Skip to content

Commit

Permalink
windows: wait for pending service control actions
Browse files Browse the repository at this point in the history
  • Loading branch information
djdv committed May 15, 2021
1 parent 30b888c commit 7b85551
Showing 1 changed file with 157 additions and 65 deletions.
222 changes: 157 additions & 65 deletions service_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package service

import (
"errors"
"fmt"
"os"
"os/signal"
Expand All @@ -21,6 +22,11 @@ import (

const version = "windows-service"

var (
errAlreadyRunning = errors.New("service already running")
errAlreadyStopped = errors.New("service already stopped")
)

type windowsService struct {
i Interface
*Config
Expand Down Expand Up @@ -160,48 +166,63 @@ func (ws *windowsService) getError() error {
return ws.stopStartErr
}

func (ws *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) {
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
func (ws *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, exitCode uint32) {
var err error
defer func() {
if err != nil {
ssec = true
ws.setError(err)
}
}()

// Signal that we're starting.
changes <- svc.Status{State: svc.StartPending}

if err := ws.i.Start(ws); err != nil {
ws.setError(err)
return true, 1
// Perform the actual start.
if initErr := ws.i.Start(ws); initErr != nil {
err = initErr
exitCode = 1
return
}

changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
// Signal that we're ready.
changes <- svc.Status{
State: svc.Running,
Accepts: svc.AcceptStop | svc.AcceptShutdown,
}

// Expect service change requests.
var stopMethod func(s Service) error
loop:
for {
c := <-r
for c := range r {
switch c.Cmd {
case svc.Interrogate:
changes <- c.CurrentStatus
case svc.Stop:
changes <- svc.Status{State: svc.StopPending}
if err := ws.i.Stop(ws); err != nil {
ws.setError(err)
return true, 2
}
break loop
case svc.Shutdown:
changes <- svc.Status{State: svc.StopPending}
var err error
if wsShutdown, ok := ws.i.(Shutdowner); ok {
err = wsShutdown.Shutdown(ws)
} else {
err = ws.i.Stop(ws)
}
if err != nil {
ws.setError(err)
return true, 2
if shutdowner, ok := ws.i.(Shutdowner); ok {
stopMethod = shutdowner.Shutdown
break loop
}
fallthrough
case svc.Stop:
stopMethod = ws.i.Stop
break loop
default:
continue loop
continue
}
}

return false, 0
// We were requested to stop,
// change state and proceed to do so.
changes <- svc.Status{State: svc.StopPending}
if stopErr := stopMethod(ws); stopErr != nil {
err = stopErr
exitCode = 2
return
}

// Calling function will set our state to Stopped.
return
}

func (ws *windowsService) Install() error {
Expand Down Expand Up @@ -249,19 +270,21 @@ func (ws *windowsService) Uninstall() error {
return err
}
defer m.Disconnect()

s, err := m.OpenService(ws.Name)
if err != nil {
return fmt.Errorf("service %s is not installed", ws.Name)
}
defer s.Close()
err = s.Delete()
if err != nil {

if err := s.Delete(); err != nil {
return err
}
err = eventlog.Remove(ws.Name)
if err != nil {

if err := eventlog.Remove(ws.Name); err != nil {
return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
}

return nil
}

Expand All @@ -287,7 +310,7 @@ func (ws *windowsService) Run() error {
return err
}

sigChan := make(chan os.Signal)
sigChan := make(chan os.Signal, 1)

signal.Notify(sigChan, os.Interrupt)

Expand Down Expand Up @@ -349,26 +372,28 @@ func (ws *windowsService) Start() error {
return err
}
defer s.Close()
return s.Start()
}

func (ws *windowsService) Stop() error {
m, err := mgr.Connect()
status, err := s.Query()
if err != nil {
return err
}
defer m.Disconnect()

s, err := m.OpenService(ws.Name)
if err != nil {
return err
switch status.State {
default:
err = errAlreadyRunning
case svc.StopPending:
err = waitForStateChange(s, status, svc.Stopped)
case svc.Stopped:
if startErr := s.Start(); startErr != nil {
return startErr
}
err = waitForStateChange(s, status, svc.Running)
}
defer s.Close()

return ws.stopWait(s)
return err
}

func (ws *windowsService) Restart() error {
func (ws *windowsService) Stop() error {
m, err := mgr.Connect()
if err != nil {
return err
Expand All @@ -381,37 +406,104 @@ func (ws *windowsService) Restart() error {
}
defer s.Close()

err = ws.stopWait(s)
status, err := s.Query()
if err != nil {
return err
}

return s.Start()
}

func (ws *windowsService) stopWait(s *mgr.Service) error {
// First stop the service. Then wait for the service to
// actually stop before starting it.
status, err := s.Control(svc.Stop)
if err != nil {
return err
switch status.State {
case svc.Stopped:
err = errAlreadyStopped
case svc.StopPending:
err = waitForStateChange(s, status, svc.Stopped)
default:
if _, stopErr := s.Control(svc.Stop); stopErr != nil {
return stopErr
}
err = waitForStateChange(s, status, svc.Stopped)
}

timeDuration := time.Millisecond * 50

timeout := time.After(getStopTimeout() + (timeDuration * 2))
tick := time.NewTicker(timeDuration)
defer tick.Stop()
return err
}

for status.State != svc.Stopped {
func (ws *windowsService) Restart() error {
if stopErr := ws.Stop(); stopErr != nil {
return stopErr
}
return ws.Start()
}

// statusInterval retreives a (bounded) duration from the status,
// or provides a default.
func statusInterval(status svc.Status) time.Duration {
// MSDN:
// "Do not wait longer than the wait hint. A good interval is
// one-tenth of the wait hint but not less than 1 second
// and not more than 10 seconds."
const (
lower = time.Second
upper = time.Second * 10
)

waitDuration := (time.Duration(status.WaitHint) * time.Millisecond) / 10
if waitDuration < lower {
waitDuration = lower
} else if waitDuration > upper {
waitDuration = upper
}
return waitDuration
}

// waitForStateChange polls the service until its state matches the desiredState,
// and error is encountered, or we timeout.
func waitForStateChange(s *mgr.Service, currentStatus svc.Status, desiredState svc.State) error {
const defaultAttempts = 10
var (
initialInterval = statusInterval(currentStatus)
queryTicker = time.NewTicker(initialInterval)
queryTimer *time.Timer
)
// If the service is providing hints,
// use them, otherwise use a default timeout.
if currentStatus.CheckPoint != 0 {
queryTimer = time.NewTimer(initialInterval)
} else {
queryTimer = time.NewTimer(initialInterval * defaultAttempts)
}
defer func() {
queryTicker.Stop()
queryTimer.Stop()
}()

var (
currentState = currentStatus.State
lastCheckpoint uint32
)
for currentState != desiredState {
select {
case <-tick.C:
status, err = s.Query()
if err != nil {
return err
case <-queryTicker.C:
currentStatus, queryErr := s.Query()
if queryErr != nil {
return queryErr
}

currentState = currentStatus.State
if currentState == desiredState {
return nil
}

if currentStatus.CheckPoint > lastCheckpoint {
// Service progressed,
// give it more time to complete.
if !queryTimer.Stop() {
<-queryTimer.C
}
queryTimer.Reset(statusInterval(currentStatus))
}
case <-timeout:
break
lastCheckpoint = currentStatus.CheckPoint
case <-queryTimer.C:
return fmt.Errorf("service did not enter desired state (%v) before we timed out",
desiredState)
}
}
return nil
Expand Down

0 comments on commit 7b85551

Please sign in to comment.