Skip to content

Commit

Permalink
Fix edge case in task queue, improve shutdown logic (#145)
Browse files Browse the repository at this point in the history
Co-authored-by: Luke Lombardi <[email protected]>
  • Loading branch information
luke-lombardi and Luke Lombardi authored Apr 17, 2024
1 parent 5ef19c2 commit 7c8d153
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 28 deletions.
7 changes: 7 additions & 0 deletions internal/abstractions/taskqueue/autoscaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ func taskQueueScaleFunc(i *taskQueueInstance, s *taskQueueAutoscalerSample) *abs
if s.QueueLength == 0 {
desiredContainers = 0
} else {
if s.QueueLength == -1 {
return &abstractions.AutoscalerResult{
DesiredContainers: 0,
ResultValid: false,
}
}

desiredContainers = int(s.QueueLength / int64(i.stubConfig.Concurrency))
if s.QueueLength%int64(i.stubConfig.Concurrency) > 0 {
desiredContainers += 1
Expand Down
12 changes: 8 additions & 4 deletions internal/abstractions/taskqueue/taskqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ type TaskQueueServiceOpts struct {
}

const (
taskQueueContainerPrefix string = "taskqueue"
taskQueueRoutePrefix string = "/taskqueue"
taskQueueContainerPrefix string = "taskqueue"
taskQueueRoutePrefix string = "/taskqueue"
taskQueueDefaultTaskExpiration int = 3600 * 12 // 12 hours
)

type RedisTaskQueue struct {
Expand Down Expand Up @@ -234,7 +235,10 @@ func (tq *RedisTaskQueue) put(ctx context.Context, workspaceName, stubId string,
return "", err
}

task, err := tq.taskDispatcher.SendAndExecute(ctx, string(types.ExecutorTaskQueue), workspaceName, stubId, payload, stubConfig.TaskPolicy)
policy := stubConfig.TaskPolicy
policy.Expires = time.Now().Add(time.Duration(taskQueueDefaultTaskExpiration) * time.Second)

task, err := tq.taskDispatcher.SendAndExecute(ctx, string(types.ExecutorTaskQueue), workspaceName, stubId, payload, policy)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -496,7 +500,7 @@ func (tq *RedisTaskQueue) TaskQueueMonitor(req *pb.TaskQueueMonitorRequest, stre
return err
}

err = tq.rdb.SetEx(ctx, Keys.taskQueueTaskRunningLock(authInfo.Workspace.Name, req.StubId, req.ContainerId, task.ExternalId), 1, time.Duration(defaultTaskRunningExpiration)*time.Second).Err()
err = tq.rdb.SetEx(ctx, Keys.taskQueueTaskRunningLock(authInfo.Workspace.Name, req.StubId, req.ContainerId, task.ExternalId), 1, time.Duration(1)*time.Second).Err()
if err != nil {
return err
}
Expand Down
7 changes: 0 additions & 7 deletions internal/repository/container_redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,6 @@ func (cr *ContainerRedisRepository) DeleteContainerState(request *types.Containe
return fmt.Errorf("failed to delete container addr <%v>: %w", addrKey, err)
}

// Remove container state key from index
indexKey := common.RedisKeys.SchedulerContainerIndex(request.StubId)
err = cr.rdb.SRem(context.TODO(), indexKey, stateKey).Err()
if err != nil {
return fmt.Errorf("failed to remove container state key from index <%v>: %w", indexKey, err)
}

return nil
}

Expand Down
26 changes: 20 additions & 6 deletions internal/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,15 @@ func (s *Worker) Run() error {
if err != nil {
log.Printf("Unable to run container <%s>: %v\n", containerId, err)

err := s.containerRepo.SetContainerExitCode(containerId, 1)
// Set a non-zero exit code for the container (both in memory, and in repo)
exitCode := 1
err := s.containerRepo.SetContainerExitCode(containerId, exitCode)
if err != nil {
log.Printf("<%s> - failed to set exit code: %v\n", containerId, err)
}

s.containerLock.Unlock()
s.clearContainer(containerId, request, time.Duration(0))
s.clearContainer(containerId, request, time.Duration(0), exitCode)
continue
}
}
Expand Down Expand Up @@ -394,7 +396,7 @@ func (s *Worker) stopContainer(event *common.Event) bool {
containerId := event.Args["container_id"].(string)

var err error = nil
if _, containerExists := s.containerInstances.Get(containerId); containerExists {
if _, exists := s.containerInstances.Get(containerId); exists {
log.Printf("<%s> - received stop container event.\n", containerId)
s.stopContainerChan <- containerId
}
Expand All @@ -406,7 +408,13 @@ func (s *Worker) processStopContainerEvents() {
for containerId := range s.stopContainerChan {
log.Printf("<%s> - stopping container.\n", containerId)

if _, containerExists := s.containerInstances.Get(containerId); !containerExists {
instance, exists := s.containerInstances.Get(containerId)
if !exists {
continue
}

// Container has already exited, just skip this event
if instance.ExitCode >= 0 {
continue
}

Expand Down Expand Up @@ -442,10 +450,10 @@ func (s *Worker) terminateContainer(containerId string, request *types.Container

defer s.containerWg.Done()

s.clearContainer(containerId, request, time.Duration(s.config.Worker.TerminationGracePeriod)*time.Second)
s.clearContainer(containerId, request, time.Duration(s.config.Worker.TerminationGracePeriod)*time.Second, *exitCode)
}

func (s *Worker) clearContainer(containerId string, request *types.ContainerRequest, delay time.Duration) {
func (s *Worker) clearContainer(containerId string, request *types.ContainerRequest, delay time.Duration, exitCode int) {
s.containerLock.Lock()

if request.Gpu != "" {
Expand All @@ -455,6 +463,12 @@ func (s *Worker) clearContainer(containerId string, request *types.ContainerRequ
s.completedRequests <- request
s.containerLock.Unlock()

instance, exists := s.containerInstances.Get(containerId)
if exists {
instance.ExitCode = exitCode
s.containerInstances.Set(containerId, instance)
}

go func() {
// Allow for some time to pass before clearing the container. This way we can handle some last
// minute logs or events or if the user wants to inspect the container before it's cleared.
Expand Down
30 changes: 19 additions & 11 deletions sdk/src/beta9/runner/taskqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,18 @@

TASK_PROCESS_WATCHDOG_INTERVAL = 0.01
TASK_POLLING_INTERVAL = 0.01
TASK_MANAGER_INTERVAL = 0.1


class TaskQueueManager:
def __init__(self) -> None:
self._setup_signal_handlers()

set_start_method("spawn", force=True)

# Manager attributes
self.pid: int = os.getpid()
self.exit_code: int = 0
self.shutdown_event = Event()

# Register signal handlers
signal.signal(signal.SIGTERM, self.shutdown)
self._setup_signal_handlers()

This comment has been minimized.

Copy link
@jsun-m

jsun-m Apr 17, 2024

Contributor

were these just moved?

set_start_method("spawn", force=True)

# Task worker attributes
self.task_worker_count: int = config.concurrency
Expand All @@ -59,23 +57,33 @@ def __init__(self) -> None:
self.task_worker_watchdog_threads: List[threading.Thread] = []

def _setup_signal_handlers(self):
signal.signal(signal.SIGTERM, self.shutdown)
if os.getpid() == self.pid:
signal.signal(signal.SIGTERM, self._init_shutdown)

def _init_shutdown(self, signum=None, frame=None):
self.shutdown_event.set()

def run(self):
for worker_index in range(self.task_worker_count):
print(f"Starting task worker[{worker_index}]")
self._start_worker(worker_index)

for task_process in self.task_processes:
task_process.join()
while not self.shutdown_event.is_set():
time.sleep(TASK_MANAGER_INTERVAL)

self.shutdown()

def shutdown(self):
print("Spinning down taskqueue")

def shutdown(self, signum=None, frame=None):
# Terminate all worker processes
for task_process in self.task_processes:
task_process.terminate()
task_process.join()
task_process.join(timeout=5)

for task_process in self.task_processes:
if task_process.is_alive():
print("Task process did not join within the timeout. Terminating...")
task_process.terminate()
task_process.join(timeout=0)

Expand Down

0 comments on commit 7c8d153

Please sign in to comment.