Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add termination reason and message to the runner API #2204

Merged
merged 7 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions runner/consts/states/state.go

This file was deleted.

13 changes: 7 additions & 6 deletions runner/docs/shim.openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,13 @@ components:
TerminationReason:
type: string
enum:
- EXECUTOR_ERROR
- CREATING_CONTAINER_ERROR
- CONTAINER_EXITED_WITH_ERROR
- DONE_BY_RUNNER
- TERMINATED_BY_USER
- TERMINATED_BY_SERVER
- executor_error
- creating_container_error
- container_exited_with_error
- done_by_runner
- terminated_by_user
- terminated_by_server
- max_duration_exceeded

GpuID:
description: >
Expand Down
9 changes: 8 additions & 1 deletion runner/internal/executor/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"

"github.com/dstackai/dstack/runner/internal/schemas"
"github.com/dstackai/dstack/runner/internal/types"
)

type Executor interface {
Expand All @@ -13,7 +14,13 @@ type Executor interface {
Run(ctx context.Context) error
SetCodePath(codePath string)
SetJob(job schemas.SubmitBody)
SetJobState(ctx context.Context, state string)
SetJobState(ctx context.Context, state types.JobState)
SetJobStateWithTerminationReason(
ctx context.Context,
state types.JobState,
termination_reason types.TerminationReason,
termination_message string,
)
SetRunnerState(state string)
Lock()
RLock()
Expand Down
47 changes: 34 additions & 13 deletions runner/internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ import (

"github.com/creack/pty"
"github.com/dstackai/dstack/runner/consts"
"github.com/dstackai/dstack/runner/consts/states"
"github.com/dstackai/dstack/runner/internal/gerrors"
"github.com/dstackai/dstack/runner/internal/log"
"github.com/dstackai/dstack/runner/internal/schemas"
"github.com/dstackai/dstack/runner/internal/types"
)

type RunExecutor struct {
Expand Down Expand Up @@ -79,14 +79,14 @@ func NewRunExecutor(tempDir string, homeDir string, workingDir string) (*RunExec
func (ex *RunExecutor) Run(ctx context.Context) (err error) {
runnerLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerLogFileName))
if err != nil {
ex.SetJobState(ctx, states.Failed)
ex.SetJobState(ctx, types.JobStateFailed)
return gerrors.Wrap(err)
}
defer func() { _ = runnerLogFile.Close() }()

jobLogFile, err := log.CreateAppendFile(filepath.Join(ex.tempDir, consts.RunnerJobLogFileName))
if err != nil {
ex.SetJobState(ctx, states.Failed)
ex.SetJobState(ctx, types.JobStateFailed)
return gerrors.Wrap(err)
}
defer func() { _ = jobLogFile.Close() }()
Expand All @@ -95,7 +95,7 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) {
// recover goes after runnerLogFile.Close() to keep the log
if r := recover(); r != nil {
log.Error(ctx, "Executor PANIC", "err", r)
ex.SetJobState(ctx, states.Failed)
ex.SetJobState(ctx, types.JobStateFailed)
err = gerrors.Newf("recovered: %v", r)
}
// no more logs will be written after this
Expand All @@ -115,17 +115,17 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) {
log.Info(ctx, "Run job", "log_level", log.GetLogger(ctx).Logger.Level.String())

if err := ex.setupRepo(ctx); err != nil {
ex.SetJobState(ctx, states.Failed)
ex.SetJobState(ctx, types.JobStateFailed)
return gerrors.Wrap(err)
}
cleanupCredentials, err := ex.setupCredentials(ctx)
if err != nil {
ex.SetJobState(ctx, states.Failed)
ex.SetJobState(ctx, types.JobStateFailed)
return gerrors.Wrap(err)
}
defer cleanupCredentials()

ex.SetJobState(ctx, states.Running)
ex.SetJobState(ctx, types.JobStateRunning)
timeoutCtx := ctx
var cancelTimeout context.CancelFunc
if ex.jobSpec.MaxDuration != 0 {
Expand All @@ -136,26 +136,33 @@ func (ex *RunExecutor) Run(ctx context.Context) (err error) {
select {
case <-ctx.Done():
log.Error(ctx, "Job canceled")
ex.SetJobState(ctx, states.Terminated)
ex.SetJobState(ctx, types.JobStateTerminated)
return gerrors.Wrap(err)
default:
}

select {
case <-timeoutCtx.Done():
log.Error(ctx, "Max duration exceeded", "max_duration", ex.jobSpec.MaxDuration)
ex.SetJobState(ctx, states.Terminated)
// We do not set "max_duration_exceeded" termination reason yet for backward compatibility
// TODO: Set it several releases after 0.18.36
ex.SetJobStateWithTerminationReason(
ctx,
types.JobStateTerminated,
types.TerminationReasonContainerExitedWithError,
"Max duration exceeded",
)
return gerrors.Wrap(err)
default:
}

// todo fail reason?
log.Error(ctx, "Exec failed", "err", err)
ex.SetJobState(ctx, states.Failed)
ex.SetJobState(ctx, types.JobStateFailed)
return gerrors.Wrap(err)
}

ex.SetJobState(ctx, states.Done)
ex.SetJobState(ctx, types.JobStateDone)
return nil
}

Expand All @@ -173,9 +180,23 @@ func (ex *RunExecutor) SetCodePath(codePath string) {
ex.state = WaitRun
}

func (ex *RunExecutor) SetJobState(ctx context.Context, state string) {
func (ex *RunExecutor) SetJobState(ctx context.Context, state types.JobState) {
ex.SetJobStateWithTerminationReason(ctx, state, "", "")
}

func (ex *RunExecutor) SetJobStateWithTerminationReason(
ctx context.Context, state types.JobState, termination_reason types.TerminationReason, termination_message string,
) {
ex.mu.Lock()
ex.jobStateHistory = append(ex.jobStateHistory, schemas.JobStateEvent{State: state, Timestamp: ex.timestamp.Next()})
ex.jobStateHistory = append(
ex.jobStateHistory,
schemas.JobStateEvent{
State: state,
Timestamp: ex.timestamp.Next(),
TerminationReason: termination_reason,
TerminationMessage: termination_message,
},
)
ex.mu.Unlock()
log.Info(ctx, "Job state changed", "new", state)
}
Expand Down
12 changes: 9 additions & 3 deletions runner/internal/schemas/schemas.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package schemas

import "strings"
import (
"strings"

"github.com/dstackai/dstack/runner/internal/types"
)

type JobStateEvent struct {
State string `json:"state"`
Timestamp int64 `json:"timestamp"`
State types.JobState `json:"state"`
Timestamp int64 `json:"timestamp"`
TerminationReason types.TerminationReason `json:"termination_reason"`
TerminationMessage string `json:"termination_message"`
}

type LogEvent struct {
Expand Down
17 changes: 9 additions & 8 deletions runner/internal/shim/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/dstackai/dstack/runner/internal/log"
"github.com/dstackai/dstack/runner/internal/shim/backends"
"github.com/dstackai/dstack/runner/internal/shim/host"
"github.com/dstackai/dstack/runner/internal/types"
bytesize "github.com/inhies/go-bytesize"
"github.com/ztrue/tracerr"
)
Expand Down Expand Up @@ -260,7 +261,7 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
gpuIDs, err := d.gpuLock.Acquire(ctx, cfg.GPU)
if err != nil {
log.Error(ctx, err.Error())
task.SetStatusTerminated("EXECUTOR_ERROR", err.Error())
task.SetStatusTerminated(string(types.TerminationReasonExecutorError), err.Error())
return tracerr.Wrap(err)
}
task.gpuIDs = gpuIDs
Expand All @@ -279,7 +280,7 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
if err := ak.AppendPublicKeys(cfg.HostSshKeys); err != nil {
errMessage := fmt.Sprintf("ak.AppendPublicKeys error: %s", err.Error())
log.Error(ctx, errMessage)
task.SetStatusTerminated("EXECUTOR_ERROR", errMessage)
task.SetStatusTerminated(string(types.TerminationReasonExecutorError), errMessage)
return tracerr.Wrap(err)
}
defer func(cfg TaskConfig) {
Expand All @@ -299,14 +300,14 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
if err != nil {
errMessage := fmt.Sprintf("prepareVolumes error: %s", err.Error())
log.Error(ctx, errMessage)
task.SetStatusTerminated("EXECUTOR_ERROR", errMessage)
task.SetStatusTerminated(string(types.TerminationReasonExecutorError), errMessage)
return tracerr.Wrap(err)
}
err = prepareInstanceMountPoints(cfg)
if err != nil {
errMessage := fmt.Sprintf("prepareInstanceMountPoints error: %s", err.Error())
log.Error(ctx, errMessage)
task.SetStatusTerminated("EXECUTOR_ERROR", errMessage)
task.SetStatusTerminated(string(types.TerminationReasonExecutorError), errMessage)
return tracerr.Wrap(err)
}

Expand All @@ -320,7 +321,7 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
if err = pullImage(pullCtx, d.client, cfg); err != nil {
errMessage := fmt.Sprintf("pullImage error: %s", err.Error())
log.Error(ctx, errMessage)
task.SetStatusTerminated("CREATING_CONTAINER_ERROR", errMessage)
task.SetStatusTerminated(string(types.TerminationReasonCreatingContainerError), errMessage)
return tracerr.Wrap(err)
}

Expand All @@ -332,7 +333,7 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
if err := d.createContainer(ctx, &task); err != nil {
errMessage := fmt.Sprintf("createContainer error: %s", err.Error())
log.Error(ctx, errMessage)
task.SetStatusTerminated("CREATING_CONTAINER_ERROR", errMessage)
task.SetStatusTerminated(string(types.TerminationReasonCreatingContainerError), errMessage)
return tracerr.Wrap(err)
}

Expand All @@ -358,12 +359,12 @@ func (d *DockerRunner) Run(ctx context.Context, taskID string) error {
log.Error(ctx, "getContainerLastLogs error", "err", err)
errMessage = ""
}
task.SetStatusTerminated("CONTAINER_EXITED_WITH_ERROR", errMessage)
task.SetStatusTerminated(string(types.TerminationReasonContainerExitedWithError), errMessage)
return tracerr.Wrap(err)
}

log.Debug(ctx, "Container finished successfully", "task", task.ID, "name", task.containerName)
task.SetStatusTerminated("DONE_BY_RUNNER", "")
task.SetStatusTerminated(string(types.TerminationReasonDoneByRunner), "")

return nil
}
Expand Down
2 changes: 1 addition & 1 deletion runner/internal/shim/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const (
type Task struct {
ID string
Status TaskStatus
TerminationReason string // TODO: enum
TerminationReason string
TerminationMessage string

config TaskConfig
Expand Down
23 changes: 23 additions & 0 deletions runner/internal/types/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package types

// TerminationReason identifies why a job or task was terminated.
// It is a string type whose values are the lowercase snake_case
// literals declared below.
type TerminationReason string

// Termination reasons declared by this package.
const (
TerminationReasonExecutorError TerminationReason = "executor_error"
TerminationReasonCreatingContainerError TerminationReason = "creating_container_error"
TerminationReasonContainerExitedWithError TerminationReason = "container_exited_with_error"
TerminationReasonDoneByRunner TerminationReason = "done_by_runner"
TerminationReasonTerminatedByUser TerminationReason = "terminated_by_user"
TerminationReasonTerminatedByServer TerminationReason = "terminated_by_server"
// NOTE(review): callers may not emit this value yet (a backward-compatibility
// note exists in the executor) — confirm before relying on it.
TerminationReasonMaxDurationExceeded TerminationReason = "max_duration_exceeded"
)

// JobState identifies the state of a job.
// It is a string type whose values are the lowercase literals
// declared below.
type JobState string

// Job states declared by this package.
const (
JobStateDone JobState = "done"
JobStateFailed JobState = "failed"
JobStateRunning JobState = "running"
JobStateTerminated JobState = "terminated"
JobStateTerminating JobState = "terminating"
)
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class JobTerminationReason(str, Enum):
PORTS_BINDING_FAILED = "ports_binding_failed"
CREATING_CONTAINER_ERROR = "creating_container_error"
EXECUTOR_ERROR = "executor_error"
MAX_DURATION_EXCEEDED = "max_duration_exceeded"

def to_status(self) -> JobStatus:
mapping = {
Expand All @@ -135,6 +136,7 @@ def to_status(self) -> JobStatus:
self.PORTS_BINDING_FAILED: JobStatus.FAILED,
self.CREATING_CONTAINER_ERROR: JobStatus.FAILED,
self.EXECUTOR_ERROR: JobStatus.FAILED,
self.MAX_DURATION_EXCEEDED: JobStatus.TERMINATED,
}
return mapping[self]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def _process_pulling_with_shim(
task.termination_message,
)
logger.debug("task status: %s", task.dict())
job_model.termination_reason = JobTerminationReason[task.termination_reason.upper()]
job_model.termination_reason = JobTerminationReason(task.termination_reason.lower())
job_model.termination_reason_message = task.termination_message
return False

Expand Down Expand Up @@ -547,7 +547,7 @@ def _process_pulling_with_shim(
shim_status.result.reason_message,
)
logger.debug("shim status: %s", shim_status.dict())
job_model.termination_reason = JobTerminationReason[shim_status.result.reason.upper()]
job_model.termination_reason = JobTerminationReason(shim_status.result.reason.lower())
job_model.termination_reason_message = shim_status.result.reason_message
return False

Expand Down Expand Up @@ -598,18 +598,20 @@ def _process_running(
job_logs=resp.job_logs,
)
if len(resp.job_states) > 0:
latest_status = resp.job_states[-1].state
# TODO(egor-s): refactor dstack-runner to return compatible statuses and reasons
latest_state_event = resp.job_states[-1]
latest_status = latest_state_event.state
if latest_status == JobStatus.DONE:
job_model.status = JobStatus.TERMINATING
job_model.termination_reason = JobTerminationReason.DONE_BY_RUNNER
# let the CLI pull logs?
# delay_job_instance_termination(job_model)
elif latest_status in {JobStatus.FAILED, JobStatus.ABORTED, JobStatus.TERMINATED}:
elif latest_status in {JobStatus.FAILED, JobStatus.TERMINATED}:
job_model.status = JobStatus.TERMINATING
job_model.termination_reason = JobTerminationReason.CONTAINER_EXITED_WITH_ERROR
# let the CLI pull logs?
# delay_job_instance_termination(job_model)
if latest_state_event.termination_reason:
job_model.termination_reason = JobTerminationReason(
latest_state_event.termination_reason.lower()
)
if latest_state_event.termination_message:
job_model.termination_reason_message = latest_state_event.termination_message
logger.info("%s: now is %s", fmt(job_model), job_model.status.name)
return True

Expand Down
Loading
Loading