diff --git a/internal/abstractions/taskqueue/autoscaler.go b/internal/abstractions/taskqueue/autoscaler.go
index 7e6fa147c..8c35b7b24 100644
--- a/internal/abstractions/taskqueue/autoscaler.go
+++ b/internal/abstractions/taskqueue/autoscaler.go
@@ -59,6 +59,13 @@ func taskQueueScaleFunc(i *taskQueueInstance, s *taskQueueAutoscalerSample) *abs
 	if s.QueueLength == 0 {
 		desiredContainers = 0
 	} else {
+		if s.QueueLength == -1 {
+			return &abstractions.AutoscalerResult{
+				DesiredContainers: 0,
+				ResultValid:       false,
+			}
+		}
+
 		desiredContainers = int(s.QueueLength / int64(i.stubConfig.Concurrency))
 		if s.QueueLength%int64(i.stubConfig.Concurrency) > 0 {
 			desiredContainers += 1
diff --git a/internal/abstractions/taskqueue/taskqueue.go b/internal/abstractions/taskqueue/taskqueue.go
index ad2f05b3b..a9974bd17 100644
--- a/internal/abstractions/taskqueue/taskqueue.go
+++ b/internal/abstractions/taskqueue/taskqueue.go
@@ -44,8 +44,9 @@ type TaskQueueServiceOpts struct {
 }
 
 const (
-	taskQueueContainerPrefix string = "taskqueue"
-	taskQueueRoutePrefix     string = "/taskqueue"
+	taskQueueContainerPrefix       string = "taskqueue"
+	taskQueueRoutePrefix           string = "/taskqueue"
+	taskQueueDefaultTaskExpiration int    = 3600 * 12 // 12 hours
 )
 
 type RedisTaskQueue struct {
@@ -234,7 +235,10 @@ func (tq *RedisTaskQueue) put(ctx context.Context, workspaceName, stubId string,
 		return "", err
 	}
 
-	task, err := tq.taskDispatcher.SendAndExecute(ctx, string(types.ExecutorTaskQueue), workspaceName, stubId, payload, stubConfig.TaskPolicy)
+	policy := stubConfig.TaskPolicy
+	policy.Expires = time.Now().Add(time.Duration(taskQueueDefaultTaskExpiration) * time.Second)
+
+	task, err := tq.taskDispatcher.SendAndExecute(ctx, string(types.ExecutorTaskQueue), workspaceName, stubId, payload, policy)
 	if err != nil {
 		return "", err
 	}
@@ -496,7 +500,7 @@ func (tq *RedisTaskQueue) TaskQueueMonitor(req *pb.TaskQueueMonitorRequest, stre
 		return err
 	}
 
-	err = tq.rdb.SetEx(ctx, Keys.taskQueueTaskRunningLock(authInfo.Workspace.Name, req.StubId, req.ContainerId, task.ExternalId), 1, time.Duration(defaultTaskRunningExpiration)*time.Second).Err()
+	err = tq.rdb.SetEx(ctx, Keys.taskQueueTaskRunningLock(authInfo.Workspace.Name, req.StubId, req.ContainerId, task.ExternalId), 1, time.Duration(1)*time.Second).Err()
 	if err != nil {
 		return err
 	}
diff --git a/internal/repository/container_redis.go b/internal/repository/container_redis.go
index af5599d94..622a94eca 100644
--- a/internal/repository/container_redis.go
+++ b/internal/repository/container_redis.go
@@ -163,13 +163,6 @@ func (cr *ContainerRedisRepository) DeleteContainerState(request *types.Containe
 		return fmt.Errorf("failed to delete container addr <%v>: %w", addrKey, err)
 	}
 
-	// Remove container state key from index
-	indexKey := common.RedisKeys.SchedulerContainerIndex(request.StubId)
-	err = cr.rdb.SRem(context.TODO(), indexKey, stateKey).Err()
-	if err != nil {
-		return fmt.Errorf("failed to remove container state key from index <%v>: %w", indexKey, err)
-	}
-
 	return nil
 }
 
diff --git a/internal/worker/worker.go b/internal/worker/worker.go
index 5a92bdc42..c8be548c6 100644
--- a/internal/worker/worker.go
+++ b/internal/worker/worker.go
@@ -247,13 +247,15 @@ func (s *Worker) Run() error {
 			if err != nil {
 				log.Printf("Unable to run container <%s>: %v\n", containerId, err)
 
-				err := s.containerRepo.SetContainerExitCode(containerId, 1)
+				// Set a non-zero exit code for the container (both in memory, and in repo)
+				exitCode := 1
+				err := s.containerRepo.SetContainerExitCode(containerId, exitCode)
 				if err != nil {
 					log.Printf("<%s> - failed to set exit code: %v\n", containerId, err)
 				}
 
 				s.containerLock.Unlock()
-				s.clearContainer(containerId, request, time.Duration(0))
+				s.clearContainer(containerId, request, time.Duration(0), exitCode)
 				continue
 			}
 		}
@@ -394,7 +396,7 @@ func (s *Worker) stopContainer(event *common.Event) bool {
 	containerId := event.Args["container_id"].(string)
 
 	var err error = nil
-	if _, containerExists := s.containerInstances.Get(containerId); containerExists {
+	if _, exists := s.containerInstances.Get(containerId); exists {
 		log.Printf("<%s> - received stop container event.\n", containerId)
 		s.stopContainerChan <- containerId
 	}
@@ -406,7 +408,13 @@ func (s *Worker) processStopContainerEvents() {
 	for containerId := range s.stopContainerChan {
 		log.Printf("<%s> - stopping container.\n", containerId)
 
-		if _, containerExists := s.containerInstances.Get(containerId); !containerExists {
+		instance, exists := s.containerInstances.Get(containerId)
+		if !exists {
+			continue
+		}
+
+		// Container has already exited, just skip this event
+		if instance.ExitCode >= 0 {
 			continue
 		}
 
@@ -442,10 +450,10 @@ func (s *Worker) terminateContainer(containerId string, request *types.Container
 
 	defer s.containerWg.Done()
 
-	s.clearContainer(containerId, request, time.Duration(s.config.Worker.TerminationGracePeriod)*time.Second)
+	s.clearContainer(containerId, request, time.Duration(s.config.Worker.TerminationGracePeriod)*time.Second, *exitCode)
 }
 
-func (s *Worker) clearContainer(containerId string, request *types.ContainerRequest, delay time.Duration) {
+func (s *Worker) clearContainer(containerId string, request *types.ContainerRequest, delay time.Duration, exitCode int) {
 	s.containerLock.Lock()
 
 	if request.Gpu != "" {
@@ -455,6 +463,12 @@ func (s *Worker) clearContainer(containerId string, request *types.ContainerRequ
 	s.completedRequests <- request
 	s.containerLock.Unlock()
 
+	instance, exists := s.containerInstances.Get(containerId)
+	if exists {
+		instance.ExitCode = exitCode
+		s.containerInstances.Set(containerId, instance)
+	}
+
 	go func() {
 		// Allow for some time to pass before clearing the container. This way we can handle some last
 		// minute logs or events or if the user wants to inspect the container before it's cleared.
diff --git a/sdk/src/beta9/runner/taskqueue.py b/sdk/src/beta9/runner/taskqueue.py
index 926946085..b2d4dfccf 100644
--- a/sdk/src/beta9/runner/taskqueue.py
+++ b/sdk/src/beta9/runner/taskqueue.py
@@ -34,20 +34,18 @@
 
 TASK_PROCESS_WATCHDOG_INTERVAL = 0.01
 TASK_POLLING_INTERVAL = 0.01
+TASK_MANAGER_INTERVAL = 0.1
 
 
 class TaskQueueManager:
     def __init__(self) -> None:
-        self._setup_signal_handlers()
-
-        set_start_method("spawn", force=True)
-
         # Manager attributes
         self.pid: int = os.getpid()
         self.exit_code: int = 0
+        self.shutdown_event = Event()
 
-        # Register signal handlers
-        signal.signal(signal.SIGTERM, self.shutdown)
+        self._setup_signal_handlers()
+        set_start_method("spawn", force=True)
 
         # Task worker attributes
         self.task_worker_count: int = config.concurrency
@@ -59,23 +57,33 @@ def __init__(self) -> None:
         self.task_worker_watchdog_threads: List[threading.Thread] = []
 
     def _setup_signal_handlers(self):
-        signal.signal(signal.SIGTERM, self.shutdown)
+        if os.getpid() == self.pid:
+            signal.signal(signal.SIGTERM, self._init_shutdown)
+
+    def _init_shutdown(self, signum=None, frame=None):
+        self.shutdown_event.set()
 
     def run(self):
         for worker_index in range(self.task_worker_count):
             print(f"Starting task worker[{worker_index}]")
             self._start_worker(worker_index)
 
-        for task_process in self.task_processes:
-            task_process.join()
+        while not self.shutdown_event.is_set():
+            time.sleep(TASK_MANAGER_INTERVAL)
+
+        self.shutdown()
+
+    def shutdown(self):
+        print("Spinning down taskqueue")
 
-    def shutdown(self, signum=None, frame=None):
+        # Terminate all worker processes
         for task_process in self.task_processes:
             task_process.terminate()
-            task_process.join()
+            task_process.join(timeout=5)
 
         for task_process in self.task_processes:
             if task_process.is_alive():
+                print("Task process did not join within the timeout. Terminating...")
                 task_process.terminate()
                 task_process.join(timeout=0)