diff --git a/lib/coordinator.go b/lib/coordinator.go
index be8e053..3b36745 100644
--- a/lib/coordinator.go
+++ b/lib/coordinator.go
@@ -925,7 +925,7 @@ func (crd *Coordinator) doRequest(ctx context.Context, worker *WorkerClient, req
 		// The worker is in ACPT state.
 		// It will not finish recovery because of ACPT. The worker will never get back into the pool.
 		// Just marking the state as FNSH and dispatchRequest will return the worker back to the pool.
-		worker.setState(wsFnsh, false)
+		worker.setState(wsFnsh)
 		return false, ErrReqParseFail
 	}
 	cnt := 1
diff --git a/lib/util.go b/lib/util.go
index 1852810..1ba6930 100644
--- a/lib/util.go
+++ b/lib/util.go
@@ -79,8 +79,8 @@ func IsPidRunning(pid int) (isRunning bool) {
 }
 
 /*
-	1st return value: the number
-	2nd return value: the number of digits
+1st return value: the number
+2nd return value: the number of digits
 */
 func atoi(bf []byte) (int, int) {
 	sz := len(bf)
@@ -96,8 +96,8 @@ func atoi(bf []byte) (int, int) {
 }
 
 /*
-	1st return value: the number
-	2nd return value: the number of digits
+1st return value: the number
+2nd return value: the number of digits
 */
 func atoui(str string) (uint64, int) {
 	sz := len(str)
@@ -164,3 +164,13 @@ func ExtractSQLHash(request *netstring.Netstring) (uint32, bool) {
 	}
 	return 0, false
 }
+
+// Contains reports whether value is present in slice.
+func Contains[T comparable](slice []T, value T) bool {
+	for _, val := range slice {
+		if val == value {
+			return true
+		}
+	}
+	return false
+}
diff --git a/lib/workerbroker.go b/lib/workerbroker.go
index f360114..95cfe0a 100644
--- a/lib/workerbroker.go
+++ b/lib/workerbroker.go
@@ -318,7 +318,7 @@ func (broker *WorkerBroker) startWorkerMonitor() (err error) {
 				if logger.GetLogger().V(logger.Debug) {
 					logger.GetLogger().Log(logger.Debug, "worker (pid=", workerclient.pid, ") received signal. "+
transits from state ", workerclient.Status, " to terminated.") } - workerclient.setState(wsUnset, true) // Set the state to UNSET to make sure worker does not stay in FNSH state so long + workerclient.setState(wsUnset) // Set the state to UNSET to make sure worker does not stay in FNSH state so long pool.RestartWorker(workerclient) } } else { diff --git a/lib/workerclient.go b/lib/workerclient.go index d5ebe0a..9c920c1 100644 --- a/lib/workerclient.go +++ b/lib/workerclient.go @@ -21,6 +21,10 @@ import ( "bytes" "errors" "fmt" + "github.com/paypal/hera/cal" + "github.com/paypal/hera/common" + "github.com/paypal/hera/utility/encoding/netstring" + "github.com/paypal/hera/utility/logger" "math/rand" "net" "os" @@ -28,14 +32,10 @@ import ( "runtime" "strconv" "strings" + "sync" "sync/atomic" "syscall" "time" - - "github.com/paypal/hera/cal" - "github.com/paypal/hera/common" - "github.com/paypal/hera/utility/encoding/netstring" - "github.com/paypal/hera/utility/logger" ) // HeraWorkerStatus defines the posible states the worker can be in @@ -54,6 +54,17 @@ const ( MaxWorkerState = 7 ) +var validStateTransitionMap map[HeraWorkerStatus][]HeraWorkerStatus = map[HeraWorkerStatus][]HeraWorkerStatus{ + wsUnset: {wsSchd, wsInit}, + wsSchd: {wsInit, wsUnset}, + wsInit: {wsSchd, wsAcpt, wsUnset}, + wsAcpt: {wsBusy}, + wsBusy: {wsWait, wsQuce, wsFnsh}, + wsWait: {wsQuce, wsFnsh}, + wsFnsh: {wsAcpt}, + wsQuce: {wsInit, wsFnsh}, //Forceful termination target state "wsInit", Graceful termination "wsFnsh" +} + const bfChannelSize = 30 // workerMsg is used to communicate with the coordinator, it contains the control message metadata plus the actual payload @@ -154,6 +165,9 @@ type WorkerClient struct { // Throtle workers lifecycle thr Throttler + + //mutex lock to update state from single go-routine + stateLock sync.Mutex } type strandedCalInfo struct { @@ -472,7 +486,7 @@ func (worker *WorkerClient) StartWorker() (err error) { logger.GetLogger().Log(logger.Info, "Started ", workerPath, ", pid=", pid) } worker.pid = pid - worker.setState(wsInit, false) + worker.setState(wsInit) return nil } @@ -542,7 +556,7 @@ func (worker *WorkerClient) attachToWorker() (err error) { logger.GetLogger().Log(logger.Info, "Got control message from worker (", worker.ID, ",", worker.pid, ",", worker.racID, ",", worker.dbUname, ")") } - worker.setState(wsAcpt, false) + worker.setState(wsAcpt) pool, err := GetWorkerBrokerInstance().GetWorkerPool(worker.Type, worker.instID, worker.shardID) if err != nil { @@ -669,7 +683,7 @@ func (worker *WorkerClient) Recover(p *WorkerPool, ticket string, recovParam Wor if logger.GetLogger().V(logger.Debug) { logger.GetLogger().Log(logger.Debug, fmt.Sprintf("about to recover worker Id: %d, worker process Id: %d as part of reconvery process, setting worker state to Quece", worker.ID, worker.pid)) } - worker.setState(wsQuce, true) + worker.setState(wsQuce) killparam := common.StrandedClientClose if len(param) > 0 { killparam = param[0] @@ -680,7 +694,7 @@ func (worker *WorkerClient) Recover(p *WorkerPool, ticket string, recovParam Wor select { case <-workerRecoverTimeout: worker.thr.CanRun() - worker.setState(wsInit, true) // Set the worker state to INIT when we decide to Terminate the worker + worker.setState(wsInit) // Set the worker state to INIT when we decide to Terminate the worker GetStateLog().PublishStateEvent(StateEvent{eType: WorkerStateEvt, shardID: worker.shardID, wType: worker.Type, instID: worker.instID, workerID: worker.ID, newWState: worker.Status}) worker.Terminate() 
worker.callogStranded("RECYCLED", info) @@ -727,7 +741,7 @@ func (worker *WorkerClient) Recover(p *WorkerPool, ticket string, recovParam Wor } worker.callogStranded("RECOVERED", info) - worker.setState(wsFnsh, true) + worker.setState(wsFnsh) if logger.GetLogger().V(logger.Debug) { logger.GetLogger().Log(logger.Debug, fmt.Sprintf("worker Id: %d, worker process: %d recovered as part of message from channel set status to FINSH", worker.ID, worker.pid)) } @@ -910,13 +924,13 @@ func (worker *WorkerClient) doRead() { logger.GetLogger().Log(logger.Verbose, "workerclient (<<< pid =", worker.pid, ",wrqId:", worker.rqId, "): EOR code:", eor, ", rqId: ", rqId, ", data:", DebugString(payload)) } if eor == common.EORFree { - worker.setState(wsFnsh, false) + worker.setState(wsFnsh) /*worker.sqlStartTimeMs = 0 if logger.GetLogger().V(logger.Verbose) { logger.GetLogger().Log(logger.Verbose, "workerclient sqltime=", worker.sqlStartTimeMs) }*/ } else { - worker.setState(wsWait, false) + worker.setState(wsWait) } if eor != common.EORMoreIncomingRequests { worker.outCh <- &workerMsg{data: payload, eor: true, free: (eor == common.EORFree), inTransaction: ((eor == common.EORInTransaction) || (eor == common.EORInCursorInTransaction)), rqId: rqId} @@ -940,7 +954,7 @@ func (worker *WorkerClient) doRead() { return default: if ns.Cmd != common.RcStillExecuting { - worker.setState(wsWait, false) + worker.setState(wsWait) } if logger.GetLogger().V(logger.Verbose) { logger.GetLogger().Log(logger.Verbose, "workerclient (<<< pid =", worker.pid, "): data:", DebugString(ns.Serialized), len(ns.Serialized)) @@ -956,7 +970,7 @@ func (worker *WorkerClient) doRead() { // Write sends a message to the worker func (worker *WorkerClient) Write(ns *netstring.Netstring, nsCount uint16) error { - worker.setState(wsBusy, false) + worker.setState(wsBusy) worker.rqId += uint32(nsCount) @@ -980,30 +994,24 @@ func (worker *WorkerClient) Write(ns *netstring.Netstring, nsCount uint16) error } // setState updates the worker state -func (worker *WorkerClient) setState(status HeraWorkerStatus, callFromRecovery bool) { - if worker.Status == status { +func (worker *WorkerClient) setState(status HeraWorkerStatus) { + currentStatus := worker.Status + if currentStatus == status { return } - if worker.isUnderRecovery == 1 && !callFromRecovery { - if logger.GetLogger().V(logger.Info) { - //If worker under recovery drinup of channel happens as part of DrainResponseChannel - logger.GetLogger().Log(logger.Info, "worker : ", worker.ID, " is under recovery. 
"+ - "workerclient pid=", worker.pid, "not allowed changing status from", worker.Status, "to", status) - } + //This checks whether state transition is valid or not + if Contains(validStateTransitionMap[currentStatus], status) { + worker.stateLock.Lock() + worker.Status = status + worker.stateLock.Unlock() + GetStateLog().PublishStateEvent(StateEvent{eType: WorkerStateEvt, shardID: worker.shardID, wType: worker.Type, instID: worker.instID, workerID: worker.ID, newWState: status}) + } else { + logger.GetLogger().Log(logger.Warning, "worker : ", worker.ID, "processId: ", worker.pid, " seeing invalid state transition from ", currentStatus, " to ", status) if logger.GetLogger().V(logger.Debug) { worker.printCallStack() } - return - } - if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, "worker Id=", worker.ID, " worker pid=", worker.pid, " changing status from", worker.Status, "to", status) - worker.printCallStack() } - // TODO: sync atomic set - worker.Status = status - - GetStateLog().PublishStateEvent(StateEvent{eType: WorkerStateEvt, shardID: worker.shardID, wType: worker.Type, instID: worker.instID, workerID: worker.ID, newWState: status}) } // Channel returns the worker out channel diff --git a/lib/workerpool.go b/lib/workerpool.go index 7c01693..50aab16 100644 --- a/lib/workerpool.go +++ b/lib/workerpool.go @@ -116,7 +116,7 @@ func (pool *WorkerPool) Init(wType HeraWorkerType, size int, instID int, shardID func (pool *WorkerPool) spawnWorker(wid int) error { worker := NewWorker(wid, pool.Type, pool.InstID, pool.ShardID, pool.moduleName, pool.thr) - worker.setState(wsSchd, false) + worker.setState(wsSchd) millis := rand.Intn(GetConfig().RandomStartMs) if logger.GetLogger().V(logger.Alert) { logger.GetLogger().Log(logger.Alert, wid, "randomized start ms", millis) @@ -495,7 +495,7 @@ func (pool *WorkerPool) ReturnWorker(worker *WorkerClient, ticket string) (err e worker.DrainResponseChannel(time.Microsecond * 10) } - worker.setState(wsAcpt, true) + worker.setState(wsAcpt) if (pool.desiredSize < pool.currentSize) && (worker.ID >= pool.desiredSize) { go func(w *WorkerClient) { if logger.GetLogger().V(logger.Info) { diff --git a/lib/workerpool_test.go b/lib/workerpool_test.go index a2afe0e..4e3db96 100644 --- a/lib/workerpool_test.go +++ b/lib/workerpool_test.go @@ -76,12 +76,12 @@ func TestPoolDempotency(t *testing.T) { wd := NewWorker(3, wtypeRW, 0, 0, "cloc", nil) we := NewWorker(4, wtypeRW, 0, 0, "cloc", nil) wf := NewWorker(5, wtypeRW, 0, 0, "cloc", nil) - wa.setState(wsAcpt, false) - wb.setState(wsAcpt, false) - wc.setState(wsAcpt, false) - wd.setState(wsAcpt, false) - we.setState(wsAcpt, false) - wf.setState(wsAcpt, false) + wa.setState(wsAcpt) + wb.setState(wsAcpt) + wc.setState(wsAcpt) + wd.setState(wsAcpt) + we.setState(wsAcpt) + wf.setState(wsAcpt) pool.WorkerReady(wa) pool.WorkerReady(wb) pool.WorkerReady(wc)