Skip to content

Commit

Permalink
Add termination reason and message to the runner API (#2204)
Browse files Browse the repository at this point in the history
* Introduce TerminationReason and JobState types

* Handle runner API not available when stopping

Maybe relevant for local runner when the runner container or shim was stopped

* Set max duration exceeded in termination message

* Add max_duration_exceeded termination reason

* Update shim OpenAPI spec

* Revert using TerminationReason enum in shim

The shim may expect any termination reason from the server

* Send termination_reason.value to shim
  • Loading branch information
r4victor authored Jan 21, 2025
1 parent e3c221f commit 6d93ecc
Show file tree
Hide file tree
Showing 13 changed files with 193 additions and 52 deletions.
9 changes: 0 additions & 9 deletions runner/consts/states/state.go

This file was deleted.

13 changes: 7 additions & 6 deletions runner/docs/shim.openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,13 @@ components:
TerminationReason:
type: string
enum:
- EXECUTOR_ERROR
- CREATING_CONTAINER_ERROR
- CONTAINER_EXITED_WITH_ERROR
- DONE_BY_RUNNER
- TERMINATED_BY_USER
- TERMINATED_BY_SERVER
- executor_error
- creating_container_error
- container_exited_with_error
- done_by_runner
- terminated_by_user
- terminated_by_server
- max_duration_exceeded

GpuID:
description: >
Expand Down
9 changes: 8 additions & 1 deletion runner/internal/executor/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"

"github.com/dstackai/dstack/runner/internal/schemas"
"github.com/dstackai/dstack/runner/internal/types"
)

type Executor interface {
Expand All @@ -13,7 +14,13 @@ type Executor interface {
Run(ctx context.Context) error
SetCodePath(codePath string)
SetJob(job schemas.SubmitBody)
SetJobState(ctx context.Context, state string)
SetJobState(ctx context.Context, state types.JobState)
SetJobStateWithTerminationReason(
ctx context.Context,
state types.JobState,
termination_reason types.TerminationReason,
termination_message string,
)
SetRunnerState(state string)
Lock()
RLock()
Expand Down
47 changes: 34 additions & 13 deletions runner/internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ import (

"github.com/creack/pty"
"github.com/dstackai/dstack/runner/consts"
"github.com/dstackai/dstack/runner/consts/states"
"github.com/dstackai/dstack/runner/internal/gerrors"
"github.com/dstackai/dstack/runner/internal/log"
"github.com/dstackai/dstack/runner/internal/schemas"
"github.com/dstackai/dstack/runner/internal/types"
)

type RunExecutor struct {
Expand Down Expand Up @@ -79,14 +79,14 @@ func NewRunExecutor(tempDir string, homeDir string, workingDir string) (*RunExec
func (ex *RunExecutor) Run(ctx context.Context) (err error) {
runnerLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerLogFileName))
if err != nil {
ex.SetJobState(ctx, states.Failed)
ex.SetJobState(ctx, types.JobStateFailed)
return gerrors.Wrap(err)
}
defer func() { _ = runnerLogFile.Close() }()

jobLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerJobLogFileName))
if err != nil {
ex.SetJobState(ctx, states.Failed)
ex.SetJobState(ctx, types.JobStateFailed)
return gerrors.Wrap(err)
}
defer func() { _ = jobLogFile.Close() }()
Expand All @@ -95,7 +95,7 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) {
// recover goes after runnerLogFile.Close() to keep the log
if r := recover(); r != nil {
log.Error(ctx, "Executor PANIC", "err", r)
ex.SetJobState(ctx, states.Failed)
ex.SetJobState(ctx, types.JobStateFailed)
err = gerrors.Newf("recovered: %v", r)
}
// no more logs will be written after this
Expand All @@ -115,17 +115,17 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) {
log.Info(ctx, "Run job", "log_level", log.GetLogger(ctx).Logger.Level.String())

if err := ex.setupRepo(ctx); err != nil {
ex.SetJobState(ctx, states.Failed)
ex.SetJobState(ctx, types.JobStateFailed)
return gerrors.Wrap(err)
}
cleanupCredentials, err := ex.setupCredentials(ctx)
if err != nil {
ex.SetJobState(ctx, states.Failed)
ex.SetJobState(ctx, types.JobStateFailed)
return gerrors.Wrap(err)
}
defer cleanupCredentials()

ex.SetJobState(ctx, states.Running)
ex.SetJobState(ctx, types.JobStateRunning)
timeoutCtx := ctx
var cancelTimeout context.CancelFunc
if ex.jobSpec.MaxDuration != 0 {
Expand All @@ -136,26 +136,33 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) {
select {
case <-ctx.Done():
log.Error(ctx, "Job canceled")
ex.SetJobState(ctx, states.Terminated)
ex.SetJobState(ctx, types.JobStateTerminated)
return gerrors.Wrap(err)
default:
}

select {
case <-timeoutCtx.Done():
log.Error(ctx, "Max duration exceeded", "max_duration", ex.jobSpec.MaxDuration)
ex.SetJobState(ctx, states.Terminated)
// We do not set "max_duration_exceeded" termination reason yet for backward compatibility
// TODO: Set it several releases after 0.18.36
ex.SetJobStateWithTerminationReason(
ctx,
types.JobStateTerminated,
types.TerminationReasonContainerExitedWithError,
"Max duration exceeded",
)
return gerrors.Wrap(err)
default:
}

// todo fail reason?
log.Error(ctx, "Exec failed", "err", err)
ex.SetJobState(ctx, states.Failed)
ex.SetJobState(ctx, types.JobStateFailed)
return gerrors.Wrap(err)
}

ex.SetJobState(ctx, states.Done)
ex.SetJobState(ctx, types.JobStateDone)
return nil
}

Expand All @@ -173,9 +180,23 @@ func (ex *RunExecutor) SetCodePath(codePath string) {
ex.state = WaitRun
}

func (ex *RunExecutor) SetJobState(ctx context.Context, state string) {
func (ex *RunExecutor) SetJobState(ctx context.Context, state types.JobState) {
ex.SetJobStateWithTerminationReason(ctx, state, "", "")
}

func (ex *RunExecutor) SetJobStateWithTerminationReason(
ctx context.Context, state types.JobState, termination_reason types.TerminationReason, termination_message string,
) {
ex.mu.Lock()
ex.jobStateHistory = append(ex.jobStateHistory, schemas.JobStateEvent{State: state, Timestamp: ex.timestamp.Next()})
ex.jobStateHistory = append(
ex.jobStateHistory,
schemas.JobStateEvent{
State: state,
Timestamp: ex.timestamp.Next(),
TerminationReason: termination_reason,
TerminationMessage: termination_message,
},
)
ex.mu.Unlock()
log.Info(ctx, "Job state changed", "new", state)
}
Expand Down
12 changes: 9 additions & 3 deletions runner/internal/schemas/schemas.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package schemas

import "strings"
import (
"strings"

"github.com/dstackai/dstack/runner/internal/types"
)

type JobStateEvent struct {
State string `json:"state"`
Timestamp int64 `json:"timestamp"`
State types.JobState `json:"state"`
Timestamp int64 `json:"timestamp"`
TerminationReason types.TerminationReason `json:"termination_reason"`
TerminationMessage string `json:"termination_message"`
}

type LogEvent struct {
Expand Down
17 changes: 9 additions & 8 deletions runner/internal/shim/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/dstackai/dstack/runner/internal/log"
"github.com/dstackai/dstack/runner/internal/shim/backends"
"github.com/dstackai/dstack/runner/internal/shim/host"
"github.com/dstackai/dstack/runner/internal/types"
bytesize "github.com/inhies/go-bytesize"
"github.com/ztrue/tracerr"
)
Expand Down Expand Up @@ -260,7 +261,7 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
gpuIDs, err := d.gpuLock.Acquire(ctx, cfg.GPU)
if err != nil {
log.Error(ctx, err.Error())
task.SetStatusTerminated("EXECUTOR_ERROR", err.Error())
task.SetStatusTerminated(string(types.TerminationReasonExecutorError), err.Error())
return tracerr.Wrap(err)
}
task.gpuIDs = gpuIDs
Expand All @@ -279,7 +280,7 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
if err := ak.AppendPublicKeys(cfg.HostSshKeys); err != nil {
errMessage := fmt.Sprintf("ak.AppendPublicKeys error: %s", err.Error())
log.Error(ctx, errMessage)
task.SetStatusTerminated("EXECUTOR_ERROR", errMessage)
task.SetStatusTerminated(string(types.TerminationReasonExecutorError), errMessage)
return tracerr.Wrap(err)
}
defer func(cfg TaskConfig) {
Expand All @@ -299,14 +300,14 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
if err != nil {
errMessage := fmt.Sprintf("prepareVolumes error: %s", err.Error())
log.Error(ctx, errMessage)
task.SetStatusTerminated("EXECUTOR_ERROR", errMessage)
task.SetStatusTerminated(string(types.TerminationReasonExecutorError), errMessage)
return tracerr.Wrap(err)
}
err = prepareInstanceMountPoints(cfg)
if err != nil {
errMessage := fmt.Sprintf("prepareInstanceMountPoints error: %s", err.Error())
log.Error(ctx, errMessage)
task.SetStatusTerminated("EXECUTOR_ERROR", errMessage)
task.SetStatusTerminated(string(types.TerminationReasonExecutorError), errMessage)
return tracerr.Wrap(err)
}

Expand All @@ -320,7 +321,7 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
if err = pullImage(pullCtx, d.client, cfg); err != nil {
errMessage := fmt.Sprintf("pullImage error: %s", err.Error())
log.Error(ctx, errMessage)
task.SetStatusTerminated("CREATING_CONTAINER_ERROR", errMessage)
task.SetStatusTerminated(string(types.TerminationReasonCreatingContainerError), errMessage)
return tracerr.Wrap(err)
}

Expand All @@ -332,7 +333,7 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
if err := d.createContainer(ctx, &task); err != nil {
errMessage := fmt.Sprintf("createContainer error: %s", err.Error())
log.Error(ctx, errMessage)
task.SetStatusTerminated("CREATING_CONTAINER_ERROR", errMessage)
task.SetStatusTerminated(string(types.TerminationReasonCreatingContainerError), errMessage)
return tracerr.Wrap(err)
}

Expand All @@ -358,12 +359,12 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
log.Error(ctx, "getContainerLastLogs error", "err", err)
errMessage = ""
}
task.SetStatusTerminated("CONTAINER_EXITED_WITH_ERROR", errMessage)
task.SetStatusTerminated(string(types.TerminationReasonContainerExitedWithError), errMessage)
return tracerr.Wrap(err)
}

log.Debug(ctx, "Container finished successfully", "task", task.ID, "name", task.containerName)
task.SetStatusTerminated("DONE_BY_RUNNER", "")
task.SetStatusTerminated(string(types.TerminationReasonDoneByRunner), "")

return nil
}
Expand Down
2 changes: 1 addition & 1 deletion runner/internal/shim/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const (
type Task struct {
ID string
Status TaskStatus
TerminationReason string // TODO: enum
TerminationReason string
TerminationMessage string

config TaskConfig
Expand Down
23 changes: 23 additions & 0 deletions runner/internal/types/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package types

// TerminationReason indicates why a job or task ended. The string values
// are lowercase snake_case and are reported to the server verbatim (they
// mirror the TerminationReason enum in the shim OpenAPI spec), so they
// must stay in sync with the server-side termination reason values.
type TerminationReason string

// Known termination reasons reported by the runner and shim.
const (
TerminationReasonExecutorError TerminationReason = "executor_error"
TerminationReasonCreatingContainerError TerminationReason = "creating_container_error"
TerminationReasonContainerExitedWithError TerminationReason = "container_exited_with_error"
TerminationReasonDoneByRunner TerminationReason = "done_by_runner"
TerminationReasonTerminatedByUser TerminationReason = "terminated_by_user"
TerminationReasonTerminatedByServer TerminationReason = "terminated_by_server"
TerminationReasonMaxDurationExceeded TerminationReason = "max_duration_exceeded"
)

// JobState represents the lifecycle state of a job as recorded in
// job state events sent to the server (see schemas.JobStateEvent).
type JobState string

// Job lifecycle states.
const (
JobStateDone JobState = "done"
JobStateFailed JobState = "failed"
JobStateRunning JobState = "running"
JobStateTerminated JobState = "terminated"
JobStateTerminating JobState = "terminating"
)
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class JobTerminationReason(str, Enum):
PORTS_BINDING_FAILED = "ports_binding_failed"
CREATING_CONTAINER_ERROR = "creating_container_error"
EXECUTOR_ERROR = "executor_error"
MAX_DURATION_EXCEEDED = "max_duration_exceeded"

def to_status(self) -> JobStatus:
mapping = {
Expand All @@ -135,6 +136,7 @@ def to_status(self) -> JobStatus:
self.PORTS_BINDING_FAILED: JobStatus.FAILED,
self.CREATING_CONTAINER_ERROR: JobStatus.FAILED,
self.EXECUTOR_ERROR: JobStatus.FAILED,
self.MAX_DURATION_EXCEEDED: JobStatus.TERMINATED,
}
return mapping[self]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def _process_pulling_with_shim(
task.termination_message,
)
logger.debug("task status: %s", task.dict())
job_model.termination_reason = JobTerminationReason[task.termination_reason.upper()]
job_model.termination_reason = JobTerminationReason(task.termination_reason.lower())
job_model.termination_reason_message = task.termination_message
return False

Expand Down Expand Up @@ -547,7 +547,7 @@ def _process_pulling_with_shim(
shim_status.result.reason_message,
)
logger.debug("shim status: %s", shim_status.dict())
job_model.termination_reason = JobTerminationReason[shim_status.result.reason.upper()]
job_model.termination_reason = JobTerminationReason(shim_status.result.reason.lower())
job_model.termination_reason_message = shim_status.result.reason_message
return False

Expand Down Expand Up @@ -598,18 +598,20 @@ def _process_running(
job_logs=resp.job_logs,
)
if len(resp.job_states) > 0:
latest_status = resp.job_states[-1].state
# TODO(egor-s): refactor dstack-runner to return compatible statuses and reasons
latest_state_event = resp.job_states[-1]
latest_status = latest_state_event.state
if latest_status == JobStatus.DONE:
job_model.status = JobStatus.TERMINATING
job_model.termination_reason = JobTerminationReason.DONE_BY_RUNNER
# let the CLI pull logs?
# delay_job_instance_termination(job_model)
elif latest_status in {JobStatus.FAILED, JobStatus.ABORTED, JobStatus.TERMINATED}:
elif latest_status in {JobStatus.FAILED, JobStatus.TERMINATED}:
job_model.status = JobStatus.TERMINATING
job_model.termination_reason = JobTerminationReason.CONTAINER_EXITED_WITH_ERROR
# let the CLI pull logs?
# delay_job_instance_termination(job_model)
if latest_state_event.termination_reason:
job_model.termination_reason = JobTerminationReason(
latest_state_event.termination_reason.lower()
)
if latest_state_event.termination_message:
job_model.termination_reason_message = latest_state_event.termination_message
logger.info("%s: now is %s", fmt(job_model), job_model.status.name)
return True

Expand Down
Loading

0 comments on commit 6d93ecc

Please sign in to comment.