From 21861342405cb50cd310833ed0291f6071fe3fdf Mon Sep 17 00:00:00 2001
From: Dmitry Meyer
Date: Wed, 25 Dec 2024 11:14:10 +0000
Subject: [PATCH 1/2] [shim] Implement Future API

Part-of: https://github.com/dstackai/dstack/issues/1780
---
 runner/cmd/shim/main.go | 1 -
 runner/consts/consts.go | 3 +
 runner/docs/shim.openapi.yaml | 521 +++++++++++++++++++---
 runner/internal/api/common.go | 37 +-
 runner/internal/runner/api/http_test.go | 24 +-
 runner/internal/runner/api/server.go | 20 +-
 runner/internal/runner/api/submit_test.go | 4 +-
 runner/internal/shim/api/http.go | 120 ++++-
 runner/internal/shim/api/schemas.go | 53 ++-
 runner/internal/shim/api/server.go | 39 +-
 runner/internal/shim/docker.go | 359 +++++++++++----
 runner/internal/shim/docker_test.go | 52 ++-
 runner/internal/shim/errs.go | 35 ++
 runner/internal/shim/host/gpu.go | 4 +
 runner/internal/shim/models.go | 26 +-
 runner/internal/shim/resources.go | 26 +-
 runner/internal/shim/resources_test.go | 44 +-
 runner/internal/shim/runner.go | 3 +-
 runner/internal/shim/task.go | 97 +++-
 runner/internal/shim/task_test.go | 61 ++-
 20 files changed, 1283 insertions(+), 246 deletions(-)
 create mode 100644 runner/internal/shim/errs.go

diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go
index cb2274f49..a63849497 100644
--- a/runner/cmd/shim/main.go
+++ b/runner/cmd/shim/main.go
@@ -110,7 +110,6 @@ func main() {
 		}
 	}

-	args.Runner.TempDir = "/tmp/runner"
 	args.Runner.HomeDir = "/root"
 	args.Runner.WorkingDir = "/workflow"

diff --git a/runner/consts/consts.go b/runner/consts/consts.go
index ecb20ba21..e7c9f7aab 100644
--- a/runner/consts/consts.go
+++ b/runner/consts/consts.go
@@ -11,3 +11,6 @@ const (

 // Error-containing messages will be identified by this signature
 const ExecutorFailedSignature = "Executor failed"
+
+// A directory inside the container where runner stores its files (logs, etc.)
+const RunnerDir = "/tmp/runner"
diff --git a/runner/docs/shim.openapi.yaml b/runner/docs/shim.openapi.yaml
index 0617720ee..42eb8a551 100644
--- a/runner/docs/shim.openapi.yaml
+++ b/runner/docs/shim.openapi.yaml
@@ -2,20 +2,165 @@
 openapi: 3.1.1

 info:
   title: dstack-shim API
-  version: &shim-version 0.18.30
+  version: &api-version 0.18.31
+  x-logo:
+    url: https://avatars.githubusercontent.com/u/54146142?s=260

 servers:
   - url: http://localhost:10998/api

tags:
  - name: &stable-api Stable API
    description: >
      Stable API should always stay backward (or future, depending on which point of view we are
      talking about) compatible, meaning that newer versions of dstack server should be able
      to use Stable API of older versions of shim without additional API version negotiation
  - name: &future-api Future API
    description: >
      As of the current version of API, Future API is an upcoming "task-oriented" API, able to
      handle more than one task at a time managing machine resources. It is not yet supported
      by dstack server
  - name: &legacy-api Legacy API
    description: >
      As of the current version of API, Legacy API is the only API (apart from Stable one) used by
      dstack server. It can only process one task at a time and cannot manage (that is, limit)
      machine resources consumed by the task

paths:
  /healthcheck:
    get:
      tags:
        - *stable-api
      summary: Ping and API version negotiation
      description: >
        Serves two roles:

        * as the path implies, it's a healthcheck, although there is no field in the response that
        indicates if shim is healthy. Basically, it is not a proper healthcheck but
        a basic "ping" method
        * API version negotiation. 
Server inspects `version` field to figure out which API features + it should use + responses: + "200": + description: "" + content: + application/json: + schema: + $ref: "#/components/schemas/HealthcheckResponse" + + /tasks: + get: + tags: + - *future-api + summary: Get task list + description: Returns a list of all tasks known to shim, including terminated ones + responses: + "200": + description: "" + content: + application/json: + schema: + $ref: "#/components/schemas/TaskListResponse" + put: + tags: + - *future-api + summary: Submit and run new task + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/TaskSubmitRequest" + responses: + "200": + description: Pending task info + $ref: "#/components/responses/TaskInfo" + "409": + description: Task with the same ID already submitted + $ref: "#/components/responses/PlainTextConflict" + "500": + description: Internal error + $ref: "#/components/responses/PlainTextInternalError" + + /tasks/{id}: + get: + tags: + - *future-api + summary: Get task info + parameters: + - $ref: "#/parameters/taskId" + responses: + "200": + $ref: "#/components/responses/TaskInfo" + + delete: + tags: + - *future-api + summary: Remove task + description: > + Removes the task from in-memory storage and destroys its associated + resources: a container, logs, etc. + parameters: + - $ref: "#/parameters/taskId" + responses: + "200": + description: Task removed + $ref: "#/components/responses/PlainTextOk" + "404": + description: Task not found + $ref: "#/components/responses/PlainTextNotFound" + "409": + description: Task is not terminated, cannot remove + $ref: "#/components/responses/PlainTextConflict" + "500": + description: Internal error, e.g., failed to remove a container + $ref: "#/components/responses/PlainTextInternalError" + + /tasks/{id}/terminate: + post: + tags: + - *future-api + summary: Terminate task + description: > + Stops the task, that is, cancels image pulling if in progress, + stops the container if running, and sets the status to `terminated`. 
No-op if the task is already terminated
      parameters:
        - in: path
          name: id
          schema:
            $ref: "#/components/schemas/TaskID"
          required: true
      requestBody:
        required: true
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/TaskTerminateRequest"
      responses:
        "200":
          description: Updated task info
          $ref: "#/components/responses/TaskInfo"
        "404":
          description: Task not found
          $ref: "#/components/responses/PlainTextNotFound"
        "409":
          description: Task cannot be terminated in its current state
          $ref: "#/components/responses/PlainTextConflict"
        "500":
          description: Internal error, e.g., failed to stop the container
          $ref: "#/components/responses/PlainTextInternalError"

  /submit:
    post:
+      tags:
+        - *legacy-api
      requestBody:
        required: true
        content:
          application/json:
            schema:
-              $ref: "#/components/schemas/TaskConfigBody"
+              $ref: "#/components/schemas/LegacySubmitBody"
      responses:
        "200":
          description: ""
@@ -36,42 +181,71 @@ paths:

  /pull:
    get:
+      tags:
+        - *legacy-api
      responses:
        "200":
          description: ""
          content:
            application/json:
              schema:
-                $ref: "#/components/schemas/PullResponse"
+                $ref: "#/components/schemas/LegacyPullResponse"

  /stop:
    post:
+      tags:
+        - *legacy-api
      requestBody:
        required: false
        content:
          application/json:
            schema:
-              $ref: "#/components/schemas/StopBody"
+              $ref: "#/components/schemas/LegacyStopBody"
      responses:
        "200":
          description: ""
          content:
            application/json:
              schema:
-                $ref: "#/components/schemas/StopResponse"
+                $ref: "#/components/schemas/LegacyStopResponse"

-  /healthcheck:
-    get:
-      responses:
-        "200":
-          description: ""
-          content:
-            application/json:
-              schema:
-                $ref: "#/components/schemas/HealthcheckResponse"
+parameters:
+  taskId:
+    name: id
+    in: path
+    schema:
+      $ref: "#/components/schemas/TaskID"
+    required: true

components:
  schemas:
+    TaskID:
+      description: Unique task ID assigned by dstack server
+      type: string
+      examples:
+        - 23a2c7a0-6c88-48ee-8028-b9ad9f6f5c24
+
+    TaskStatus:
+      title: shim.TaskStatus
+      type: string
+      enum:
+        - pending
+        - preparing
+        - pulling
+        - creating
+        - running
+        - terminated
+
+    TerminationReason:
+      type: string
+      enum:
+        - EXECUTOR_ERROR
+        - CREATING_CONTAINER_ERROR
+        - CONTAINER_EXITED_WITH_ERROR
+        - DONE_BY_RUNNER
+        - TERMINATED_BY_USER
+        - TERMINATED_BY_SERVER
+
    RunnerStatus:
      title: shim.RunnerStatus
      type: string
@@ -86,12 +260,7 @@ components:
      type: object
      properties:
        reason:
-          type: string
-          enum:
-            - EXECUTOR_ERROR
-            - CREATING_CONTAINER_ERROR
-            - CONTAINER_EXITED_WITH_ERROR
-            - DONE_BY_RUNNER
+          $ref: "#/components/schemas/TerminationReason"
        reason_message:
          type: string
          default: ""
@@ -103,19 +272,16 @@
        - reason_message
      additionalProperties: false

-    VolumeMountPoint:
-      title: shim.VolumeMountPoint
-      type: object
-      properties:
-        name:
-          type: string
-          default: ""
-          description: >
-            `dstack` volume [name](https://dstack.ai/docs/reference/dstack.yml/volume/#name)
-        path:
-          type: string
-          default: ""
-          description: Mount point inside container
+    GpuID:
+      description: >
+        A vendor-specific unique identifier of GPU:
+        * NVIDIA: "globally unique immutable alphanumeric identifier of the GPU",
+        in the form of `GPU-`
+        * AMD: `/dev/dri/renderD` path
+      type: string
+      examples:
+        - GPU-2b79666e-d81f-f3f8-fd47-9903f118c3f5
+        - /dev/dri/renderD128

    VolumeInfo:
      title: shim.VolumeInfo
@@ -138,6 +304,20 @@
          description: >
            Create a filesystem when it doesn't exist if `true`, fail with error if `false`

+    VolumeMountPoint:
+      title: shim.VolumeMountPoint
+      type: object
+      properties:
+        name:
+          type: string
+          default: ""
+ description: > + `dstack` volume [name](https://dstack.ai/docs/reference/dstack.yml/volume/#name) + path: + type: string + default: "" + description: Mount point inside container + InstanceMountPoint: title: shim.InstanceMountPoint type: object @@ -151,10 +331,212 @@ components: default: "" description: Mount point inside container - TaskConfigBody: - title: shim.api.TaskConfigBody + HealthcheckResponse: + title: shim.api.HealthcheckResponse + type: object + properties: + service: + const: dstack-shim + version: + type: string + examples: + - *api-version + required: + - service + - version + additionalProperties: false + + TaskListResponse: + title: shim.api.TaskListResponse + type: object + properties: + ids: + type: array + items: + $ref: "#/components/schemas/TaskID" + description: A list of all task IDs tracked by shim + required: + - ids + additionalProperties: false + + TaskInfoResponse: + title: shim.api.TaskInfoResponse + description: Same as `shim.TaskInfo` + type: object + properties: + id: + $ref: "#/components/schemas/TaskID" + status: + allOf: + - $ref: "#/components/schemas/TaskStatus" + - examples: + - terminated + termination_reason: + $ref: "#/components/schemas/TerminationReason" + termination_message: + type: string + description: A shim-generated message or N last lines from the container logs + container_name: + type: string + examples: + - horrible-mule-1-0-0-44f7cb95 + container_id: + type: string + examples: + - a6bb8d4bb8af8ec72482ecd194ff92fac9974521aa5ad8a46abfc4f0ba858775 + gpu_ids: + type: array + items: + $ref: "#/components/schemas/GpuID" + required: + - id + - status + - termination_reason + - termination_message + - container_name + - container_id + - gpu_ids + additionalProperties: false + + TaskSubmitRequest: + title: shim.api.TaskSubmitRequest description: Same as `shim.TaskConfig` type: object + properties: + id: + $ref: "#/components/schemas/TaskID" + name: + type: string + description: Task name. Used to construct unique container name + examples: + - horrible-mule-1-0-0 + registry_username: + type: string + default: "" + description: Private container registry username + examples: + - registry-user + registry_password: + type: string + default: "" + description: Private container registry password + examples: + - registry-token + image_name: + type: string + default: "" + examples: + - ubuntu:22.04 + container_user: + type: string + default: "" + description: > + If not set, the default image user is used. As of 0.18.24, `dstack` always uses `root` + examples: + - root + privileged: + type: boolean + default: false + description: Start container in privileged mode + gpu: + type: integer + minimum: -1 + default: 0 + description: > + Number of GPUs allocated for the container. A special value `-1` means "all available, + even if none", `0` means "zero GPUs" + cpu: + type: number + minimum: 0 + default: 0 + description: > + Amount of CPU resources available to the container. A special value `0` means "all". + Fractional values are allowed, e.g., `1.5` — one and a half CPUs + memory: + type: number + minimum: 0 + default: 0 + description: > + Amount of memory available to the container, in bytes. A special value `0` means "all" + shm_size: + type: integer + minimum: 0 + default: 0 + description: > + POSIX shared memory, bytes. A special value `0` means "use the default value (64MiB)". 
If > 0, tmpfs is mounted with the `exec` option, unlike the default mount options
          examples:
            - 1073741824
        volumes:
          type: array
          items:
            $ref: "#/components/schemas/VolumeInfo"
          default: []
        volume_mounts:
          type: array
          items:
            $ref: "#/components/schemas/VolumeMountPoint"
          default: []
        instance_mounts:
          type: array
          items:
            $ref: "#/components/schemas/InstanceMountPoint"
          default: []
        host_ssh_user:
          type: string
          default: ""
          description: >
            Instance (host) user for SSH access, either directly (`ssh {run_name}-host`)
            or for `ProxyJump`ing inside the container. Ignored if `host_ssh_keys` is not set
          examples:
            - root
        host_ssh_keys:
          type: array
          items:
            type: string
          default: []
          description: >
            SSH public keys for access to the instance (host). If set, the keys will be added
            to the `host_ssh_user`'s `~/.ssh/authorized_keys` when the run starts and removed
            when the run exits.
          examples:
            - "ssh-ed25519 me@laptop"
        container_ssh_keys:
          type: array
          items:
            type: string
          default: []
          description: >
            SSH public keys for `container_user`. As of 0.18.24, `dstack` submits two keys:
            project key (generated by the server) and user key (either generated by
            the CLI client or provided by the user)
          examples:
            - ["ssh-rsa project@dstack", "ssh-ed25519 me@laptop"]

    TaskTerminateRequest:
      title: shim.api.TaskTerminateRequest
      type: object
      properties:
        termination_reason:
          allOf:
            - $ref: "#/components/schemas/TerminationReason"
            - examples:
                - TERMINATED_BY_USER
                - TERMINATED_BY_SERVER
          default: ""
        termination_message:
          type: string
          default: ""
        timeout:
          type: integer
          default: 0
          description: >
            Seconds to wait before killing the container. If zero, kill
            the container immediately (no graceful shutdown)

    LegacySubmitBody:
      title: shim.api.LegacySubmitBody
      type: object
      properties:
        username:
          type: string
@@ -244,8 +626,8 @@ components:
          default: []
          description: (since [0.18.21](https://github.com/dstackai/dstack/releases/tag/0.18.21))

-    PullResponse:
-      title: shim.api.PullResponse
+    LegacyPullResponse:
+      title: shim.api.LegacyPullResponse
      type: object
      properties:
        state:
@@ -272,16 +654,16 @@ components:
        - result
      additionalProperties: false

-    StopBody:
-      title: shim.api.StopBody
+    LegacyStopBody:
+      title: shim.api.LegacyStopBody
      type: object
      properties:
        force:
          type: boolean
          default: false

-    StopResponse:
-      title: shim.api.StopResponse
+    LegacyStopResponse:
+      title: shim.api.LegacyStopResponse
      type: object
      properties:
        state:
@@ -290,17 +672,46 @@ components:
        - state
      additionalProperties: false

-    HealthcheckResponse:
-      title: shim.api.HealthcheckResponse
-      type: object
-      properties:
-        service:
-          const: dstack-shim
-        version:
-          type: string
-          examples:
-            - *shim-version
-      required:
-        - service
-        - version
-      additionalProperties: false
+  responses:
+    TaskInfo:
+      description: Task info
+      content:
+        application/json:
+          schema:
+            $ref: "#/components/schemas/TaskInfoResponse"
+
+    PlainTextOk:
+      description: ""
+      content:
+        text/plain:
+          schema:
+            type: string
+          examples:
+            - OK
+
+    PlainTextNotFound:
+      description: ""
+      content:
+        text/plain:
+          schema:
+            type: string
+          examples:
+            - not found
+
+    PlainTextConflict:
+      description: ""
+      content:
+        text/plain:
+          schema:
+            type: string
+          examples:
+            - conflict
+
+    PlainTextInternalError:
+      description: ""
+      content:
+        text/plain:
+          schema:
+            type: string
+          examples:
+            - internal error
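For orientation, here is a minimal client-side walkthrough of the Future API lifecycle specified above (submit, poll, terminate, remove). This is a sketch, not dstack server code: the base URL comes from the spec's servers section, payload fields follow TaskSubmitRequest and TaskTerminateRequest, and the call helper, IDs, and image are illustrative.

    package main

    import (
    	"bytes"
    	"encoding/json"
    	"fmt"
    	"net/http"
    )

    const base = "http://localhost:10998/api"

    // call sends an optional JSON body and decodes a JSON response, if any.
    // Non-200 responses carry plain-text messages per the PlainText* schemas.
    func call(method, path string, in, out any) error {
    	var body bytes.Buffer
    	if in != nil {
    		if err := json.NewEncoder(&body).Encode(in); err != nil {
    			return err
    		}
    	}
    	req, err := http.NewRequest(method, base+path, &body)
    	if err != nil {
    		return err
    	}
    	req.Header.Set("Content-Type", "application/json")
    	resp, err := http.DefaultClient.Do(req)
    	if err != nil {
    		return err
    	}
    	defer resp.Body.Close()
    	if resp.StatusCode != http.StatusOK {
    		return fmt.Errorf("%s %s: %s", method, path, resp.Status)
    	}
    	if out != nil {
    		return json.NewDecoder(resp.Body).Decode(out)
    	}
    	return nil
    }

    func main() {
    	id := "23a2c7a0-6c88-48ee-8028-b9ad9f6f5c24" // assigned by the caller, e.g., dstack server
    	task := map[string]any{"id": id, "name": "horrible-mule-1-0-0", "image_name": "ubuntu:22.04"}
    	var info struct {
    		Status string `json:"status"`
    	}
    	_ = call("PUT", "/tasks", task, &info)    // 409 if this ID was already submitted
    	_ = call("GET", "/tasks/"+id, nil, &info) // poll until status == "terminated"
    	_ = call("POST", "/tasks/"+id+"/terminate", map[string]any{
    		"termination_reason": "TERMINATED_BY_USER",
    		"timeout":            10, // grace period in seconds before the container is killed
    	}, &info)
    	_ = call("DELETE", "/tasks/"+id, nil, nil) // only allowed once the task is terminated
    	fmt.Println("final status:", info.Status)
    }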
diff --git a/runner/internal/api/common.go b/runner/internal/api/common.go
index 7c4ceba8a..fcb5f40e6 100644
--- a/runner/internal/api/common.go
+++ b/runner/internal/api/common.go
@@ -19,7 +19,25 @@ type Error struct {
 }

 func (e *Error) Error() string {
-	return e.Err.Error()
+	if e.Msg != "" {
+		return e.Msg
+	}
+	if e.Err != nil {
+		return e.Err.Error()
+	}
+	return http.StatusText(e.Status)
+}
+
+type Router struct {
+	*http.ServeMux
+}
+
+func (r *Router) AddHandler(method string, pattern string, handler func(http.ResponseWriter, *http.Request) (interface{}, error)) {
+	r.HandleFunc(fmt.Sprintf("%s %s", method, pattern), JSONResponseHandler(handler))
+}
+
+func NewRouter() Router {
+	return Router{http.NewServeMux()}
 }

 func DecodeJSONBody(w http.ResponseWriter, r *http.Request, dst interface{}, allowUnknown bool) error {
@@ -84,25 +102,17 @@ func DecodeJSONBody(w http.ResponseWriter, r *http.Request, dst interface{}, all
 	return nil
 }

-func JSONResponseHandler(method string, handler func(http.ResponseWriter, *http.Request) (interface{}, error)) func(http.ResponseWriter, *http.Request) {
+func JSONResponseHandler(handler func(http.ResponseWriter, *http.Request) (interface{}, error)) func(http.ResponseWriter, *http.Request) {
 	return func(w http.ResponseWriter, r *http.Request) {
 		status := 200
 		msg := ""
-		var body interface{}
-		var err error
 		var apiErr *Error
-		if r.Method == method {
-			body, err = handler(w, r)
-		} else {
-			body = nil
-			err = &Error{Status: http.StatusMethodNotAllowed, Err: nil}
-		}
-
+		body, err := handler(w, r)
 		if err != nil {
 			if errors.As(err, &apiErr) {
 				status = apiErr.Status
-				msg = apiErr.Msg
+				msg = apiErr.Error()
 				log.Warning(r.Context(), "API error", "err", apiErr.Err)
 			} else {
 				status = http.StatusInternalServerError
@@ -115,9 +125,6 @@
 			w.WriteHeader(status)
 			_ = json.NewEncoder(w).Encode(body)
 		} else {
-			if msg == "" {
-				msg = http.StatusText(status)
-			}
 			http.Error(w, msg, status)
 		}
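The reworked contract is worth spelling out: a handler now returns a (body, error) pair, and JSONResponseHandler turns that into either a JSON response or a plain-text http.Error, with the status and message taken from *Error via the new Error() fallbacks. A hypothetical handler under this contract (the name and route are illustrative, not part of the patch):

    // GET /api/widgets/{id}: hypothetical, shows the (body, error) contract.
    func (s *ShimServer) widgetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
    	id := r.PathValue("id")
    	if id == "" {
    		// Error() falls back from Msg to Err to http.StatusText(Status),
    		// so even a bare Status yields a sensible plain-text response.
    		return nil, &api.Error{Status: http.StatusBadRequest, Msg: "missing widget id"}
    	}
    	// Any JSON-serializable value is written with a 200 status.
    	return map[string]string{"id": id}, nil
    }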
diff --git a/runner/internal/runner/api/http_test.go b/runner/internal/runner/api/http_test.go
index f0ef73473..a0572f42d 100644
--- a/runner/internal/runner/api/http_test.go
+++ b/runner/internal/runner/api/http_test.go
@@ -20,11 +20,29 @@ func (ds DummyRunner) GetState() (shim.RunnerStatus, shim.JobResult) {
 	return ds.State, ds.JobResult
 }

-func (ds DummyRunner) Run(context.Context, shim.TaskConfig) error {
+func (ds DummyRunner) Submit(context.Context, shim.TaskConfig) error {
 	return nil
 }

-func (ds DummyRunner) Stop(force bool) {}
+func (ds DummyRunner) Run(context.Context, string) error {
+	return nil
+}
+
+func (ds DummyRunner) Terminate(context.Context, string, uint, string, string) error {
+	return nil
+}
+
+func (ds DummyRunner) Remove(context.Context, string) error {
+	return nil
+}
+
+func (ds DummyRunner) TaskIDs() []string {
+	return []string{}
+}
+
+func (ds DummyRunner) TaskInfo(taskID string) shim.TaskInfo {
+	return shim.TaskInfo{}
+}

 func (ds DummyRunner) Resources() shim.Resources {
 	return shim.Resources{}
@@ -36,7 +54,7 @@ func TestHealthcheck(t *testing.T) {

 	server := api.NewShimServer(":12345", DummyRunner{}, "0.0.1.dev2")

-	f := common.JSONResponseHandler("GET", server.HealthcheckGetHandler)
+	f := common.JSONResponseHandler(server.HealthcheckHandler)
 	f(responseRecorder, request)

 	if responseRecorder.Code != 200 {
diff --git a/runner/internal/runner/api/server.go b/runner/internal/runner/api/server.go
index bf194bad2..7395f2213 100644
--- a/runner/internal/runner/api/server.go
+++ b/runner/internal/runner/api/server.go
@@ -35,7 +35,7 @@ type Server struct {
 }

 func NewServer(tempDir string, homeDir string, workingDir string, address string, version string) (*Server, error) {
-	mux := http.NewServeMux()
+	r := api.NewRouter()
 	ex, err := executor.NewRunExecutor(tempDir, homeDir, workingDir)
 	if err != nil {
 		return nil, err
@@ -43,7 +43,7 @@ func NewServer(tempDir string, homeDir string, workingDir string, address string
 	s := &Server{
 		srv: &http.Server{
 			Addr:    address,
-			Handler: mux,
+			Handler: r,
 		},
 		tempDir:    tempDir,
 		workingDir: workingDir,
@@ -60,14 +60,14 @@ func NewServer(tempDir string, homeDir string, workingDir string, address string
 		version: version,
 	}

-	mux.HandleFunc("/api/healthcheck", api.JSONResponseHandler("GET", s.healthcheckGetHandler))
-	mux.HandleFunc("/api/metrics", api.JSONResponseHandler("GET", s.metricsGetHandler))
-	mux.HandleFunc("/api/submit", api.JSONResponseHandler("POST", s.submitPostHandler))
-	mux.HandleFunc("/api/upload_code", api.JSONResponseHandler("POST", s.uploadCodePostHandler))
-	mux.HandleFunc("/api/run", api.JSONResponseHandler("POST", s.runPostHandler))
-	mux.HandleFunc("/api/pull", api.JSONResponseHandler("GET", s.pullGetHandler))
-	mux.HandleFunc("/api/stop", api.JSONResponseHandler("POST", s.stopPostHandler))
-	mux.HandleFunc("/logs_ws", api.JSONResponseHandler("GET", s.logsWsGetHandler))
+	r.AddHandler("GET", "/api/healthcheck", s.healthcheckGetHandler)
+	r.AddHandler("GET", "/api/metrics", s.metricsGetHandler)
+	r.AddHandler("POST", "/api/submit", s.submitPostHandler)
+	r.AddHandler("POST", "/api/upload_code", s.uploadCodePostHandler)
+	r.AddHandler("POST", "/api/run", s.runPostHandler)
+	r.AddHandler("GET", "/api/pull", s.pullGetHandler)
+	r.AddHandler("POST", "/api/stop", s.stopPostHandler)
+	r.AddHandler("GET", "/logs_ws", s.logsWsGetHandler)

 	return s, nil
 }
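Router.AddHandler relies on method-qualified patterns, which net/http's ServeMux supports since Go 1.22; that is why JSONResponseHandler no longer inspects r.Method itself. A standalone sketch of the same mechanism (the route and payload are illustrative):

    package main

    import (
    	"encoding/json"
    	"net/http"
    )

    func main() {
    	mux := http.NewServeMux()
    	// "GET /api/tasks/{id}": the mux matches the method, answers
    	// 405 Method Not Allowed for other verbs (previously emulated
    	// inside JSONResponseHandler), and binds the {id} wildcard,
    	// exposed to handlers via r.PathValue("id").
    	mux.HandleFunc("GET /api/tasks/{id}", func(w http.ResponseWriter, r *http.Request) {
    		w.Header().Set("Content-Type", "application/json")
    		_ = json.NewEncoder(w).Encode(map[string]string{"id": r.PathValue("id")})
    	})
    	_ = http.ListenAndServe("localhost:10998", mux)
    }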
diff --git a/runner/internal/runner/api/submit_test.go b/runner/internal/runner/api/submit_test.go
index b170c81be..279d91edc 100644
--- a/runner/internal/runner/api/submit_test.go
+++ b/runner/internal/runner/api/submit_test.go
@@ -21,7 +21,7 @@ func TestSubmit(t *testing.T) {

 	server := api.NewShimServer(":12340", &dummyRunner, "0.0.1.dev2")

-	firstSubmitPost := common.JSONResponseHandler("POST", server.SubmitPostHandler)
+	firstSubmitPost := common.JSONResponseHandler(server.LegacySubmitPostHandler)
 	firstSubmitPost(responseRecorder, request)

 	if responseRecorder.Code != 200 {
@@ -35,7 +35,7 @@ func TestSubmit(t *testing.T) {
 	request = httptest.NewRequest("POST", "/api/submit", strings.NewReader("{\"image_name\":\"ubuntu\"}"))
 	responseRecorder = httptest.NewRecorder()

-	secondSubmitPost := common.JSONResponseHandler("POST", server.SubmitPostHandler)
+	secondSubmitPost := common.JSONResponseHandler(server.LegacySubmitPostHandler)
 	secondSubmitPost(responseRecorder, request)

 	t.Logf("%v", responseRecorder.Result())
diff --git a/runner/internal/shim/api/http.go b/runner/internal/shim/api/http.go
index 277b4686f..6a7fee775 100644
--- a/runner/internal/shim/api/http.go
+++ b/runner/internal/shim/api/http.go
@@ -2,6 +2,7 @@ package api

 import (
 	"context"
+	"errors"
 	"fmt"
 	"log"
 	"net/http"
@@ -10,7 +11,9 @@ import (
 	"github.com/dstackai/dstack/runner/internal/shim"
 )

-func (s *ShimServer) HealthcheckGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
+// Stable API
+
+func (s *ShimServer) HealthcheckHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
 	s.mu.RLock()
 	defer s.mu.RUnlock()

@@ -20,7 +23,79 @@ func (s *ShimServer) HealthcheckGetHandler(w http.ResponseWriter, r *http.Reques
 	}, nil
 }

-func (s *ShimServer) SubmitPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
+// Future API
+
+func (s *ShimServer) TaskListHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
+	return &TaskListResponse{IDs: s.runner.TaskIDs()}, nil
+}
+
+func (s *ShimServer) TaskInfoHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
+	taskInfo := s.runner.TaskInfo(r.PathValue("id"))
+	if taskInfo.ID == "" {
+		return nil, &api.Error{Status: http.StatusNotFound}
+	}
+	return TaskInfoResponse(taskInfo), nil
+}
+
+// TaskSubmitHandler submits AND runs a task
+func (s *ShimServer) TaskSubmitHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
+	var req TaskSubmitRequest
+	if err := api.DecodeJSONBody(w, r, &req, true); err != nil {
+		return nil, err
+	}
+	taskConfig := shim.TaskConfig(req)
+	if err := s.runner.Submit(r.Context(), taskConfig); err != nil {
+		if errors.Is(err, shim.ErrRequest) {
+			return nil, &api.Error{Status: http.StatusConflict, Err: err}
+		}
+		return nil, &api.Error{Status: http.StatusInternalServerError, Err: err}
+	}
+	go func(taskID string) {
+		if err := s.runner.Run(context.Background(), taskID); err != nil {
+			fmt.Printf("failed task %v", err)
+		}
+	}(taskConfig.ID)
+	return s.runner.TaskInfo(taskConfig.ID), nil
+}
+
+func (s *ShimServer) TaskTerminateHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) {
+	taskID := r.PathValue("id")
+	var req TaskTerminateRequest
+	if err := api.DecodeJSONBody(w, r, &req, true); err != nil {
+		return nil, err
+	}
+	if err := s.runner.Terminate(r.Context(), taskID, req.Timeout, req.TerminationReason, req.TerminationMessage); err != nil {
+		if errors.Is(err, shim.ErrNotFound) {
+			return nil, &api.Error{Status: http.StatusNotFound, Err: err}
+		}
+		if 
errors.Is(err, shim.ErrRequest) { + return nil, &api.Error{Status: http.StatusConflict, Err: err} + } + return nil, &api.Error{Status: http.StatusInternalServerError, Err: err} + } + taskInfo := s.runner.TaskInfo(taskID) + if taskInfo.ID == "" { + return nil, &api.Error{Status: http.StatusNotFound} + } + return TaskInfoResponse(taskInfo), nil +} + +func (s *ShimServer) TaskRemoveHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { + if err := s.runner.Remove(r.Context(), r.PathValue("id")); err != nil { + if errors.Is(err, shim.ErrNotFound) { + return nil, &api.Error{Status: http.StatusNotFound, Err: err} + } + if errors.Is(err, shim.ErrRequest) { + return nil, &api.Error{Status: http.StatusConflict, Err: err} + } + return nil, &api.Error{Status: http.StatusInternalServerError, Err: err} + } + return nil, nil +} + +// Legacy API + +func (s *ShimServer) LegacySubmitPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.mu.RLock() defer s.mu.RUnlock() state, _ := s.runner.GetState() @@ -28,7 +103,7 @@ func (s *ShimServer) SubmitPostHandler(w http.ResponseWriter, r *http.Request) ( return nil, &api.Error{Status: http.StatusConflict} } - var body SubmitBody + var body LegacySubmitBody if err := api.DecodeJSONBody(w, r, &body, true); err != nil { log.Println("Failed to decode submit body", "err", err) return nil, err @@ -42,57 +117,68 @@ func (s *ShimServer) SubmitPostHandler(w http.ResponseWriter, r *http.Request) ( ImageName: body.ImageName, ContainerUser: body.ContainerUser, Privileged: body.Privileged, - GpuCount: -1, + GPU: -1, ShmSize: body.ShmSize, - PublicKeys: body.PublicKeys, - SshUser: body.SshUser, - SshKey: body.SshKey, Volumes: body.Volumes, VolumeMounts: body.VolumeMounts, InstanceMounts: body.InstanceMounts, + HostSshUser: body.SshUser, + HostSshKeys: []string{body.SshKey}, + ContainerSshKeys: body.PublicKeys, } go func(taskConfig shim.TaskConfig) { - err := s.runner.Run(context.Background(), taskConfig) - if err != nil { - fmt.Printf("failed Run %v\n", err) + if err := s.runner.Submit(context.Background(), taskConfig); err != nil { + fmt.Printf("failed Submit %v", err) + } + if err := s.runner.Run(context.Background(), taskConfig.ID); err != nil { + fmt.Printf("failed Run %v", err) } }(taskConfig) return nil, nil } -func (s *ShimServer) PullGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (s *ShimServer) LegacyPullGetHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.mu.RLock() defer s.mu.RUnlock() state, jobResult := s.runner.GetState() - return &PullResponse{ + return &LegacyPullResponse{ State: string(state), Result: jobResult, }, nil } -func (s *ShimServer) StopPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (s *ShimServer) LegacyStopPostHandler(w http.ResponseWriter, r *http.Request) (interface{}, error) { s.mu.RLock() defer s.mu.RUnlock() state, _ := s.runner.GetState() if state == shim.Pending { - return &StopResponse{ + return &LegacyStopResponse{ State: string(state), }, nil } - var body StopBody + var body LegacyStopBody if err := api.DecodeJSONBody(w, r, &body, true); err != nil { log.Println("Failed to decode submit stop body", "err", err) return nil, err } - s.runner.Stop(body.Force) + var timeout uint + if body.Force { + timeout = 0 + } else { + timeout = 10 // Docker default value + } + if err := s.runner.Terminate(r.Context(), shim.LegacyTaskID, timeout, "", ""); err != nil { + log.Println("Failed to terminate", "err", err) + } - return 
&StopResponse{ + state, _ = s.runner.GetState() + return &LegacyStopResponse{ State: string(state), }, nil } diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index 45eecadf4..8a402ee23 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -2,7 +2,41 @@ package api import "github.com/dstackai/dstack/runner/internal/shim" -type SubmitBody struct { +// Stable API + +type HealthcheckResponse struct { + Service string `json:"service"` + Version string `json:"version"` +} + +// Future API + +type TaskListResponse struct { + IDs []string `json:"ids"` +} + +type TaskInfoResponse struct { + ID string `json:"id"` + Status shim.TaskStatus `json:"status"` + TerminationReason string `json:"termination_reason"` + TerminationMessage string `json:"termination_message"` + // The following fields are for debugging only, server doesn't need them + ContainerName string `json:"container_name"` + ContainerID string `json:"container_id"` + GpuIDs []string `json:"gpus_ids"` +} + +type TaskSubmitRequest = shim.TaskConfig + +type TaskTerminateRequest struct { + TerminationReason string `json:"termination_reason"` + TerminationMessage string `json:"termination_message"` + Timeout uint `json:"timeout"` +} + +// Legacy API + +type LegacySubmitBody struct { Username string `json:"username"` Password string `json:"password"` ImageName string `json:"image_name"` @@ -18,20 +52,15 @@ type SubmitBody struct { InstanceMounts []shim.InstanceMountPoint `json:"instance_mounts"` } -type StopBody struct { - Force bool `json:"force"` -} - -type HealthcheckResponse struct { - Service string `json:"service"` - Version string `json:"version"` -} - -type PullResponse struct { +type LegacyPullResponse struct { State string `json:"state"` Result shim.JobResult `json:"result"` } -type StopResponse struct { +type LegacyStopBody struct { + Force bool `json:"force"` +} + +type LegacyStopResponse struct { State string `json:"state"` } diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index 748361402..33ebf2cba 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -10,11 +10,16 @@ import ( ) type TaskRunner interface { - Run(context.Context, shim.TaskConfig) error - GetState() (shim.RunnerStatus, shim.JobResult) - Stop(bool) + Submit(context.Context, shim.TaskConfig) error + Run(ctx context.Context, taskID string) error + Terminate(ctx context.Context, taskID string, timeout uint, reason string, message string) error + Remove(ctx context.Context, taskID string) error Resources() shim.Resources + TaskIDs() []string + TaskInfo(taskID string) shim.TaskInfo + + GetState() (shim.RunnerStatus, shim.JobResult) } type ShimServer struct { @@ -27,25 +32,33 @@ type ShimServer struct { } func NewShimServer(address string, runner TaskRunner, version string) *ShimServer { - mux := http.NewServeMux() + r := api.NewRouter() s := &ShimServer{ HttpServer: &http.Server{ Addr: address, - Handler: mux, + Handler: r, }, runner: runner, version: version, } + + // Stable API // The healthcheck endpoint should stay backward compatible, as it is used for negotiation - mux.HandleFunc("/api/healthcheck", api.JSONResponseHandler("GET", s.HealthcheckGetHandler)) - // The following endpoints constitute a so-called legacy API, where shim has one global state - // and is able to process only one task at a time - // NOTE: as of 2024-12-10, there is _only_ legacy API, but the "legacy" label is used to - // distinguish the "old" 
API from the upcoming new one - mux.HandleFunc("/api/submit", api.JSONResponseHandler("POST", s.SubmitPostHandler)) - mux.HandleFunc("/api/pull", api.JSONResponseHandler("GET", s.PullGetHandler)) - mux.HandleFunc("/api/stop", api.JSONResponseHandler("POST", s.StopPostHandler)) + r.AddHandler("GET", "/api/healthcheck", s.HealthcheckHandler) + + // Future API + r.AddHandler("GET", "/api/tasks", s.TaskListHandler) + r.AddHandler("GET", "/api/tasks/{id}", s.TaskInfoHandler) + r.AddHandler("PUT", "/api/tasks", s.TaskSubmitHandler) + r.AddHandler("POST", "/api/tasks/{id}/terminate", s.TaskTerminateHandler) + r.AddHandler("DELETE", "/api/tasks/{id}", s.TaskRemoveHandler) + + // Legacy API + r.AddHandler("POST", "/api/submit", s.LegacySubmitPostHandler) + r.AddHandler("GET", "/api/pull", s.LegacyPullGetHandler) + r.AddHandler("POST", "/api/stop", s.LegacyStopPostHandler) + return s } diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 160193a21..4ba13dbbc 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -26,9 +26,11 @@ import ( "github.com/docker/docker/api/types/registry" dockersystem "github.com/docker/docker/api/types/system" docker "github.com/docker/docker/client" + "github.com/docker/docker/errdefs" "github.com/docker/docker/pkg/stdcopy" "github.com/docker/go-connections/nat" "github.com/docker/go-units" + "github.com/dstackai/dstack/runner/consts" "github.com/dstackai/dstack/runner/internal/shim/backends" "github.com/dstackai/dstack/runner/internal/shim/host" bytesize "github.com/inhies/go-bytesize" @@ -62,11 +64,12 @@ type DockerRunner struct { } func NewDockerRunner(dockerParams DockerParameters) (*DockerRunner, error) { + ctx := context.TODO() client, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation()) if err != nil { return nil, tracerr.Wrap(err) } - dockerInfo, err := client.Info(context.TODO()) + dockerInfo, err := client.Info(ctx) if err != nil { return nil, tracerr.Wrap(err) } @@ -92,9 +95,87 @@ func NewDockerRunner(dockerParams DockerParameters) (*DockerRunner, error) { gpuLock: gpuLock, tasks: NewTaskStorage(), } + + if err := runner.restoreStateFromContainers(ctx); err != nil { + return nil, tracerr.Errorf("failed to restore state from containers: %w", err) + } + return runner, nil } +// restoreStateFromContainers regenerates TaskStorage and GpuLock inspecting containers +// Used to restore shim state on restarts +func (d *DockerRunner) restoreStateFromContainers(ctx context.Context) error { + listOptions := container.ListOptions{ + All: true, + Filters: filters.NewArgs(filters.Arg("label", fmt.Sprintf("%s=%s", LabelKeyIsTask, LabelValueTrue))), + } + containers, err := d.client.ContainerList(ctx, listOptions) + if err != nil { + return fmt.Errorf("failed to get container list: %w", err) + } + for _, containerShort := range containers { + containerID := containerShort.ID + taskID := containerShort.Labels[LabelKeyTaskID] + if taskID == "" { + log.Printf("container %s has no %s label", containerID, LabelKeyTaskID) + continue + } + var status TaskStatus + if containerShort.State == "exited" { + status = TaskStatusTerminated + } else { + status = TaskStatusRunning + } + var containerName string + if len(containerShort.Names) > 0 { + // "Names are prefixed with their parent and / == the docker daemon" + // https://github.com/moby/moby/issues/6705 + containerName = strings.TrimLeft(containerShort.Names[0], "/") + } + var gpuIDs []string + if d.gpuVendor != host.GpuVendorNone { + if 
containerFull, err := d.client.ContainerInspect(ctx, containerID); err != nil { + log.Printf("failed to inspect container=%s task=%s", containerID, taskID) + } else if d.gpuVendor == host.GpuVendorNvidia { + deviceRequests := containerFull.HostConfig.Resources.DeviceRequests + if len(deviceRequests) == 1 { + gpuIDs = deviceRequests[0].DeviceIDs + } else if len(deviceRequests) != 0 { + log.Printf( + "cannot extract GPU IDs container=%s task=%s: more than one DeviceRequest", + containerID, taskID, + ) + } + } else { + for _, device := range containerFull.HostConfig.Resources.Devices { + if host.IsRenderNodePath(device.PathOnHost) { + gpuIDs = append(gpuIDs, device.PathOnHost) + } + } + } + } + var runnerDir string + for _, mount := range containerShort.Mounts { + if mount.Destination == consts.RunnerDir { + runnerDir = mount.Source + break + } + } + task := NewTask(taskID, status, containerName, containerID, gpuIDs, runnerDir) + if !d.tasks.Add(task) { + log.Printf("duplicate restored task %s", taskID) + } else { + log.Printf("restored task ID=%s, status=%s, gpuIDs=%v", taskID, status, gpuIDs) + } + if status == TaskStatusRunning && len(gpuIDs) > 0 { + lockedGpuIDs := d.gpuLock.Lock(gpuIDs) + log.Printf("locked GPU(s) due to running task gpuIDs=%v task=%s", lockedGpuIDs, taskID) + } + } + return nil +} + func (d *DockerRunner) Resources() Resources { cpuCount := host.GetCpuCount() totalMemory, err := host.GetTotalMemory() @@ -118,50 +199,111 @@ func (d *DockerRunner) Resources() Resources { } } -func (d *DockerRunner) Run(ctx context.Context, cfg TaskConfig) error { - task := NewTask(cfg) +func (d *DockerRunner) TaskIDs() []string { + return d.tasks.IDs() +} +func (d *DockerRunner) TaskInfo(taskID string) TaskInfo { + task, ok := d.tasks.Get(taskID) + if !ok { + return TaskInfo{} + } + taskInfo := TaskInfo{ + ID: task.ID, + Status: task.Status, + TerminationReason: task.TerminationReason, + TerminationMessage: task.TerminationMessage, + ContainerName: task.containerName, + ContainerID: task.containerID, + GpuIDs: task.gpuIDs, + } + if taskInfo.GpuIDs == nil { + taskInfo.GpuIDs = []string{} + } + return taskInfo +} + +func (d *DockerRunner) Submit(ctx context.Context, cfg TaskConfig) error { + if cfg.ID == "" { + return tracerr.Errorf("%w: empty task ID", ErrRequest) + } + if cfg.Name == "" { + return tracerr.Errorf("%w: empty task Name", ErrRequest) + } + task := NewTaskFromConfig(cfg) // For legacy API compatibility, since LegacyTaskID is the same for all tasks if task.ID == LegacyTaskID { - d.tasks.Delete(task.ID) + if currentTask, ok := d.tasks.Get(LegacyTaskID); ok { + if currentTask.Status != TaskStatusTerminated { + if err := d.Terminate(ctx, LegacyTaskID, 0, "", ""); err != nil { + log.Printf("failed to terminate task: %v", err) + } + } + if err := d.Remove(ctx, LegacyTaskID); err != nil { + log.Printf("failed to remove task: %v", err) + } + } } - if ok := d.tasks.Add(task); !ok { - return tracerr.Errorf("task %s is already submitted", task.ID) + return tracerr.Errorf("%w: task %s is already submitted", ErrRequest, task.ID) + } + return nil +} + +func (d *DockerRunner) Run(ctx context.Context, taskID string) error { + task, ok := d.tasks.Get(taskID) + if !ok { + log.Printf("cannot run task %s: not found", taskID) + return fmt.Errorf("task %s: %w", taskID, ErrNotFound) + } + + if task.Status != TaskStatusPending { + return fmt.Errorf("%w: cannot run task %s with %s status", ErrRequest, task.ID, task.Status) } defer func() { - if ok := d.tasks.Update(task); !ok { - log.Printf("failed to 
update task %s", task.ID)
+		if err := d.tasks.Update(task); err != nil {
+			if currentTask, ok := d.tasks.Get(task.ID); ok && currentTask.Status != task.Status {
+				// ignore error if task is gone or status has not changed, e.g., terminated -> terminated
+				log.Printf("failed to update task %s: %v", task.ID, err)
+			}
 		}
 	}()

+	task.SetStatusPreparing()
+	if err := d.tasks.Update(task); err != nil {
+		return tracerr.Errorf("%w: failed to update task %s: %w", ErrInternal, task.ID, err)
+	}
+
+	cfg := task.config
 	var err error

-	if cfg.GpuCount != 0 {
-		gpuIDs, err := d.gpuLock.Acquire(cfg.GpuCount)
+	if cfg.GPU != 0 {
+		gpuIDs, err := d.gpuLock.Acquire(cfg.GPU)
 		if err != nil {
 			log.Println(err)
 			task.SetStatusTerminated("EXECUTOR_ERROR", err.Error())
 			return tracerr.Wrap(err)
 		}
 		task.gpuIDs = gpuIDs
+		log.Printf("acquired GPU(s) gpuIDs=%v task=%s", gpuIDs, task.ID)
 		defer func() {
-			d.gpuLock.Release(task.gpuIDs)
+			releasedGpuIDs := d.gpuLock.Release(task.gpuIDs)
+			log.Printf("released GPU(s) gpuIDs=%v task=%s", releasedGpuIDs, task.ID)
 		}()
 	}

-	if cfg.SshKey != "" {
-		ak := AuthorizedKeys{user: cfg.SshUser, lookup: user.Lookup}
-		if err := ak.AppendPublicKeys([]string{cfg.SshKey}); err != nil {
+	if len(cfg.HostSshKeys) > 0 {
+		ak := AuthorizedKeys{user: cfg.HostSshUser, lookup: user.Lookup}
+		if err := ak.AppendPublicKeys(cfg.HostSshKeys); err != nil {
 			errMessage := fmt.Sprintf("ak.AppendPublicKeys error: %s", err.Error())
 			log.Println(errMessage)
 			task.SetStatusTerminated("EXECUTOR_ERROR", errMessage)
 			return tracerr.Wrap(err)
 		}
 		defer func(cfg TaskConfig) {
-			err := ak.RemovePublicKeys([]string{cfg.SshKey})
+			err := ak.RemovePublicKeys(cfg.HostSshKeys)
 			if err != nil {
 				log.Printf("Error RemovePublicKeys: %s\n", err.Error())
 			}
@@ -192,8 +334,8 @@ func (d *DockerRunner) Run(ctx context.Context, cfg TaskConfig) error {
 	pullCtx, cancelPull := context.WithTimeout(ctx, ImagePullTimeout)
 	defer cancelPull()
 	task.SetStatusPulling(cancelPull)
-	if !d.tasks.Update(task) {
-		return tracerr.Errorf("failed to update task %s", task.ID)
+	if err := d.tasks.Update(task); err != nil {
+		return tracerr.Errorf("%w: failed to update task %s: %w", ErrInternal, task.ID, err)
 	}
 	if err = pullImage(pullCtx, d.client, cfg); err != nil {
 		errMessage := fmt.Sprintf("pullImage error: %s", err.Error())
@@ -202,48 +344,25 @@
 		return tracerr.Wrap(err)
 	}

-	log.Println("Creating container")
+	log.Printf("Creating container name=%s task=%s", task.containerName, task.ID)
 	task.SetStatusCreating()
-	if !d.tasks.Update(task) {
-		return tracerr.Errorf("failed to update task %s", task.ID)
+	if err := d.tasks.Update(task); err != nil {
+		return tracerr.Errorf("%w: failed to update task %s: %w", ErrInternal, task.ID, err)
 	}
-	containerID, err := d.createContainer(ctx, task)
+	containerID, err := d.createContainer(ctx, &task)
 	if err != nil {
 		errMessage := fmt.Sprintf("createContainer error: %s", err.Error())
-		log.Print(errMessage + "\n")
+		log.Println(errMessage)
 		task.SetStatusTerminated("CREATING_CONTAINER_ERROR", errMessage)
 		return tracerr.Wrap(err)
 	}

-	defer func() {
-		log.Println("Deleting old container(s)")
-		listFilters := filters.NewArgs(
-			filters.Arg("label", fmt.Sprintf("%s=%s", LabelKeyIsTask, LabelValueTrue)),
-			filters.Arg("status", "exited"),
-		)
-		containers, err := d.client.ContainerList(ctx, container.ListOptions{Filters: listFilters})
-		if err != nil {
-			log.Printf("ContainerList error: %s\n", err.Error())
-			return
-		}
-		for _, container_ := range 
containers { - if container_.ID == containerID { - continue - } - err := d.client.ContainerRemove(ctx, container_.ID, container.RemoveOptions{Force: true, RemoveVolumes: true}) - if err != nil { - log.Printf("ContainerRemove error: %s\n", err.Error()) - } - } - }() - - log.Printf("Running container, name=%s, id=%s\n", task.containerName, containerID) + log.Printf("Running container name=%s task=%s", task.containerName, task.ID) task.SetStatusRunning(containerID) - if !d.tasks.Update(task) { - return tracerr.Errorf("failed to update task %s", task.ID) + if err := d.tasks.Update(task); err != nil { + return tracerr.Errorf("%w: failed to update task %s: %w", ErrInternal, task.ID, err) } - - if err = runContainer(ctx, d.client, containerID); err != nil { + if err = d.runContainer(ctx, &task); err != nil { log.Printf("runContainer error: %s\n", err.Error()) var errMessage string if lastLogs, err := getContainerLastLogs(d.client, containerID, 5); err == nil { @@ -262,27 +381,110 @@ func (d *DockerRunner) Run(ctx context.Context, cfg TaskConfig) error { return nil } -func (d *DockerRunner) Stop(force bool) { - task, ok := d.tasks.Get(LegacyTaskID) +// Terminate aborts running operations (pulling an image, running a container) and sets task status to terminated +// Associated resources (container, logs, etc.) are not destroyed, use Remove() for cleanup +func (d *DockerRunner) Terminate(ctx context.Context, taskID string, timeout uint, reason string, message string) (err error) { + task, ok := d.tasks.Get(taskID) if !ok { - return + log.Printf("cannot terminate task %s: not found", taskID) + return fmt.Errorf("task %s: %w", taskID, ErrNotFound) + } + task.Lock() + defer task.Release() + defer func() { + if err := d.tasks.Update(task); err != nil { + log.Printf("failed to update task %s: %v", task.ID, err) + } + }() + return d.terminate(ctx, &task, timeout, reason, message) +} + +func (d *DockerRunner) terminate(ctx context.Context, task *Task, timeout uint, reason string, message string) (err error) { + log.Printf("terminating task %s", task.ID) + defer func() { + if err != nil { + log.Printf("cannot terminate task %s: %v", task.ID, err) + } + }() + if !task.IsTransitionAllowed(TaskStatusTerminated) { + return fmt.Errorf("%w: cannot terminate task %s with %s status", ErrRequest, task.ID, task.Status) } switch task.Status { - case TaskStatusPending, TaskStatusCreating, TaskStatusTerminated: + case TaskStatusPending, TaskStatusPreparing, TaskStatusCreating, TaskStatusTerminated: // nothing to do case TaskStatusPulling: task.cancelPull() case TaskStatusRunning: stopOptions := container.StopOptions{} - if force { - timeout := int(0) - stopOptions.Timeout = &timeout + timeout := int(timeout) + stopOptions.Timeout = &timeout + if err := d.client.ContainerStop(ctx, task.containerID, stopOptions); err != nil { + return fmt.Errorf("%w: failed to stop container: %w", ErrInternal, err) + } + default: + return fmt.Errorf("%w: should not reach here", ErrInternal) + } + if len(task.gpuIDs) > 0 { + releasedGpuIDs := d.gpuLock.Release(task.gpuIDs) + log.Printf("released GPU(s) gpuIDs=%v task=%s", releasedGpuIDs, task.ID) + } + task.SetStatusTerminated(reason, message) + log.Printf("task %s terminated", task.ID) + return nil +} + +// Remove destroys resources associated with task (container, logs, etc.), if any +// On success, it also removes the task from TaskStorage +func (d *DockerRunner) Remove(ctx context.Context, taskID string) error { + task, ok := d.tasks.Get(taskID) + if !ok { + log.Printf("cannot remove 
task %s: not found", taskID) + return fmt.Errorf("task %s: %w", taskID, ErrNotFound) + } + task.Lock() + defer task.Release() + err := d.remove(ctx, &task) + if err == nil { + d.tasks.Delete(taskID) + } + return err +} + +func (d *DockerRunner) remove(ctx context.Context, task *Task) (err error) { + log.Printf("removing task %s", task.ID) + defer func() { + if err != nil { + log.Printf("cannot remove task %s: %v", task.ID, err) } - err := d.client.ContainerStop(context.Background(), task.containerID, stopOptions) + }() + if task.Status != TaskStatusTerminated { + return fmt.Errorf("%w: cannot remove task %s with %s status", ErrRequest, task.ID, task.Status) + } + removeOptions := container.RemoveOptions{Force: true, RemoveVolumes: true} + // Normally, it should not be empty + if task.containerID != "" { + err := d.client.ContainerRemove(ctx, task.containerID, removeOptions) if err != nil { - log.Printf("Failed to stop container: %s", err) + if errdefs.IsNotFound(err) { + log.Printf("cannot remove container task=%s: not found", task.ID) + } else { + return fmt.Errorf("%w: failed to remove container task=%s: %w", ErrInternal, task.ID, err) + } } } + // Normally, it should not be empty + if task.runnerDir != "" { + // Failed attempts to remove or rename runner dir are considered non-fatal + if err := os.RemoveAll(task.runnerDir); err != nil { + log.Printf("failed to remove runner directory %s: %v", task.runnerDir, err) + trashName := fmt.Sprintf(".trash-%s-%d", task.runnerDir, time.Now().UnixMicro()) + if err := os.Rename(task.runnerDir, trashName); err != nil { + log.Printf("failed to rename runner directory %s: %v", task.runnerDir, err) + } + } + } + log.Printf("task %s removed", task.ID) + return nil } func (d *DockerRunner) GetState() (RunnerStatus, JobResult) { @@ -299,6 +501,8 @@ func getLegacyStatus(task Task) RunnerStatus { switch task.Status { case TaskStatusPending: return Pulling + case TaskStatusPreparing: + return Pulling case TaskStatusPulling: return Pulling case TaskStatusCreating: @@ -567,28 +771,12 @@ func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConf return nil } -func (d *DockerRunner) createContainer(ctx context.Context, task Task) (string, error) { - // For legacy API compatibility, since LegacyTaskID is the same for all tasks, containerName is not unique - // With new API where task.ID is unique (and, in turn, containerName is unique too), container name clash - // is not expected - if task.ID == LegacyTaskID { - timeout := int(0) - stopOptions := container.StopOptions{Timeout: &timeout} - err := d.client.ContainerStop(ctx, task.containerName, stopOptions) - if err != nil { - log.Printf("Cleanup routine: Cannot stop container: %s", err) - } - removeOptions := container.RemoveOptions{Force: true, RemoveVolumes: true} - err = d.client.ContainerRemove(ctx, task.containerName, removeOptions) - if err != nil { - log.Printf("Cleanup routine: Cannot remove container: %s", err) - } - } - - runnerDir, err := d.dockerParams.MakeRunnerDir() +func (d *DockerRunner) createContainer(ctx context.Context, task *Task) (string, error) { + runnerDir, err := d.dockerParams.MakeRunnerDir(task.containerName) if err != nil { return "", tracerr.Wrap(err) } + task.runnerDir = runnerDir mounts, err := d.dockerParams.DockerMounts(runnerDir) if err != nil { return "", tracerr.Wrap(err) @@ -625,7 +813,7 @@ func (d *DockerRunner) createContainer(ctx context.Context, task Task) (string, containerConfig := &container.Config{ Image: task.config.ImageName, - Cmd: 
[]string{strings.Join(d.dockerParams.DockerShellCommands(task.config.PublicKeys), " && ")}, + Cmd: []string{strings.Join(d.dockerParams.DockerShellCommands(task.config.ContainerSshKeys), " && ")}, Entrypoint: []string{"/bin/sh", "-c"}, ExposedPorts: exposePorts(d.dockerParams.DockerPorts()...), Env: envVars, @@ -647,12 +835,13 @@ func (d *DockerRunner) createContainer(ctx context.Context, task Task) (string, ShmSize: task.config.ShmSize, Tmpfs: tmpfs, } + hostConfig.Resources.NanoCPUs = int64(task.config.CPU * 1000000000) + hostConfig.Resources.Memory = task.config.Memory if len(task.gpuIDs) > 0 { configureGpus(hostConfig, d.gpuVendor, task.gpuIDs) } configureHpcNetworkingIfAvailable(hostConfig) - log.Printf("Creating container %s:\nconfig: %v\nhostConfig:%v", task.containerName, containerConfig, hostConfig) resp, err := d.client.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, task.containerName) if err != nil { return "", tracerr.Wrap(err) @@ -660,12 +849,12 @@ func (d *DockerRunner) createContainer(ctx context.Context, task Task) (string, return resp.ID, nil } -func runContainer(ctx context.Context, client docker.APIClient, containerID string) error { - if err := client.ContainerStart(ctx, containerID, container.StartOptions{}); err != nil { +func (d *DockerRunner) runContainer(ctx context.Context, task *Task) error { + if err := d.client.ContainerStart(ctx, task.containerID, container.StartOptions{}); err != nil { return tracerr.Wrap(err) } - waitCh, errorCh := client.ContainerWait(ctx, containerID, "") + waitCh, errorCh := d.client.ContainerWait(ctx, task.containerID, "") select { case waitResp := <-waitCh: { @@ -918,7 +1107,7 @@ func (c *CLIArgs) DockerMounts(hostRunnerDir string) ([]mount.Mount, error) { { Type: mount.TypeBind, Source: hostRunnerDir, - Target: c.Runner.TempDir, + Target: consts.RunnerDir, }, { Type: mount.TypeBind, @@ -932,8 +1121,8 @@ func (c *CLIArgs) DockerPorts() []int { return []int{c.Runner.HTTPPort, c.Docker.SSHPort} } -func (c *CLIArgs) MakeRunnerDir() (string, error) { - runnerTemp := filepath.Join(c.Shim.HomeDir, "runners", time.Now().Format("20060102-150405")) +func (c *CLIArgs) MakeRunnerDir(name string) (string, error) { + runnerTemp := filepath.Join(c.Shim.HomeDir, "runners", name) if err := os.MkdirAll(runnerTemp, 0o755); err != nil { return "", tracerr.Wrap(err) } diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index 61fc63b09..5a0b2efa6 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -2,6 +2,8 @@ package shim import ( "context" + "encoding/hex" + "math/rand" "os" "os/exec" "strconv" @@ -33,7 +35,11 @@ func TestDocker_SSHServer(t *testing.T) { defer cancel() dockerRunner, _ := NewDockerRunner(params) - assert.NoError(t, dockerRunner.Run(ctx, TaskConfig{ImageName: "ubuntu", Name: t.Name(), ID: time.Now().String()})) + taskConfig := createTaskConfig(t) + defer dockerRunner.Remove(context.Background(), taskConfig.ID) + + assert.NoError(t, dockerRunner.Submit(ctx, taskConfig)) + assert.NoError(t, dockerRunner.Run(ctx, taskConfig.ID)) } // TestDocker_SSHServerConnect pulls ubuntu image (without sshd), installs openssh-server and tries to connect via SSH @@ -64,7 +70,11 @@ func TestDocker_SSHServerConnect(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - assert.NoError(t, dockerRunner.Run(ctx, TaskConfig{ImageName: "ubuntu", Name: t.Name(), ID: time.Now().String()})) + taskConfig := createTaskConfig(t) + defer dockerRunner.Remove(context.Background(), 
taskConfig.ID) + + assert.NoError(t, dockerRunner.Submit(ctx, taskConfig)) + assert.NoError(t, dockerRunner.Run(ctx, taskConfig.ID)) }() for i := 0; i < timeout; i++ { @@ -101,7 +111,11 @@ func TestDocker_ShmNoexecByDefault(t *testing.T) { defer cancel() dockerRunner, _ := NewDockerRunner(params) - assert.NoError(t, dockerRunner.Run(ctx, TaskConfig{ImageName: "ubuntu", Name: t.Name(), ID: time.Now().String()})) + taskConfig := createTaskConfig(t) + defer dockerRunner.Remove(context.Background(), taskConfig.ID) + + assert.NoError(t, dockerRunner.Submit(ctx, taskConfig)) + assert.NoError(t, dockerRunner.Run(ctx, taskConfig.ID)) } func TestDocker_ShmExecIfSizeSpecified(t *testing.T) { @@ -119,7 +133,12 @@ func TestDocker_ShmExecIfSizeSpecified(t *testing.T) { defer cancel() dockerRunner, _ := NewDockerRunner(params) - assert.NoError(t, dockerRunner.Run(ctx, TaskConfig{ImageName: "ubuntu", ShmSize: 1024 * 1024, Name: t.Name(), ID: time.Now().String()})) + taskConfig := createTaskConfig(t) + taskConfig.ShmSize = 1024 * 1024 + defer dockerRunner.Remove(context.Background(), taskConfig.ID) + + assert.NoError(t, dockerRunner.Submit(ctx, taskConfig)) + assert.NoError(t, dockerRunner.Run(ctx, taskConfig.ID)) } /* Mocks */ @@ -164,7 +183,7 @@ func (c *dockerParametersMock) DockerMounts(string) ([]mount.Mount, error) { return nil, nil } -func (c *dockerParametersMock) MakeRunnerDir() (string, error) { +func (c *dockerParametersMock) MakeRunnerDir(string) (string, error) { return "", nil } @@ -175,3 +194,26 @@ var portNumber int32 = 10000 func nextPort() int { return int(atomic.AddInt32(&portNumber, 1)) } + +var ( + randSrc = rand.New(rand.NewSource(time.Now().UnixNano())) + randMu = sync.Mutex{} +) + +func generateID(t *testing.T) string { + const idLen = 16 + b := make([]byte, idLen/2) + randMu.Lock() + defer randMu.Unlock() + _, err := randSrc.Read(b) + require.Nil(t, err) + return hex.EncodeToString(b)[:idLen] +} + +func createTaskConfig(t *testing.T) TaskConfig { + return TaskConfig{ + ID: generateID(t), + Name: t.Name(), + ImageName: "ubuntu", + } +} diff --git a/runner/internal/shim/errs.go b/runner/internal/shim/errs.go new file mode 100644 index 000000000..237d2616d --- /dev/null +++ b/runner/internal/shim/errs.go @@ -0,0 +1,35 @@ +package shim + +import "errors" + +/* +Definitions of common error types used throughout shim. +Errors should wrap these errors to simplify error classifications, e.g.: + + func cleanup(containerID string) { + ... + return fmt.Errorf("%w: failed to remove container") + } + + if err := cleanup(containerID); errors.Is(err, ErrInternal) { + return ErrorResponse { + Status: 500, + Message: err.Error(), + } + } else if errors.Is(err, ErrNotFound) { + return ErrorResponse { + Status: 404, + Message: err.Error(), + } + } +*/ +var ( + // shim failed to process request due to internal error + ErrInternal = errors.New("internal error") + // shim rejected to process request, e.g., bad params, state conflict, etc. 
+	ErrRequest = errors.New("request error")
+	// referenced object does not exist
+	ErrNotFound = errors.New("not found")
+	// object already exists (conflict)
+	ErrAlreadyExists = errors.New("already exists")
+)
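To make the wrapping contract concrete, here is a minimal, self-contained illustration of how fmt.Errorf's %w keeps these sentinels classifiable with errors.Is; note that the sentinel must be passed as an argument to %w. The sentinel is redeclared locally only so the snippet runs on its own; shim code would use shim.ErrNotFound directly.

    package main

    import (
    	"errors"
    	"fmt"
    )

    // Local stand-in for shim.ErrNotFound, just to keep the snippet runnable.
    var ErrNotFound = errors.New("not found")

    func getTask(id string) error {
    	// %w places the sentinel into the error chain without losing context.
    	return fmt.Errorf("task %s: %w", id, ErrNotFound)
    }

    func main() {
    	err := getTask("23a2c7a0")
    	fmt.Println(err)                         // "task 23a2c7a0: not found"
    	fmt.Println(errors.Is(err, ErrNotFound)) // true: maps to HTTP 404 in the API layer
    }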
diff --git a/runner/internal/shim/resources.go b/runner/internal/shim/resources.go
index 8243e9da5..7145b898d 100644
--- a/runner/internal/shim/resources.go
+++ b/runner/internal/shim/resources.go
@@ -87,11 +87,33 @@ func (gl *GpuLock) Acquire(count int) ([]string, error) {
 	return ids, nil
 }
 
+// Lock marks passed Resource IDs as locked (busy)
+// This method never fails, it's safe to lock already locked resource or try to lock unknown resource
+// The returned slice contains only actually locked resource IDs
+func (gl *GpuLock) Lock(ids []string) []string {
+	gl.mu.Lock()
+	defer gl.mu.Unlock()
+	lockedIDs := make([]string, 0, len(ids))
+	for _, id := range ids {
+		if locked, ok := gl.lock[id]; !ok {
+			log.Printf("skipping %s: unknown GPU resource", id)
+		} else if locked {
+			log.Printf("skipping %s: already locked", id)
+		} else {
+			gl.lock[id] = true
+			lockedIDs = append(lockedIDs, id)
+		}
+	}
+	return lockedIDs
+}
+
 // Release marks passed Resource IDs as idle
 // This method never fails, it's safe to release already idle resource or try to release unknown resource
-func (gl *GpuLock) Release(ids []string) {
+// The returned slice contains only actually released resource IDs
+func (gl *GpuLock) Release(ids []string) []string {
 	gl.mu.Lock()
 	defer gl.mu.Unlock()
+	releasedIDs := make([]string, 0, len(ids))
 	for _, id := range ids {
 		if locked, ok := gl.lock[id]; !ok {
 			log.Printf("skipping %s: unknown GPU resource", id)
@@ -99,6 +121,8 @@ func (gl *GpuLock) Release(ids []string) {
 			log.Printf("skipping %s: not locked", id)
 		} else {
 			gl.lock[id] = false
+			releasedIDs = append(releasedIDs, id)
 		}
 	}
+	return releasedIDs
 }
diff --git a/runner/internal/shim/resources_test.go b/runner/internal/shim/resources_test.go
index 230338d1f..013a7cb9d 100644
--- a/runner/internal/shim/resources_test.go
+++ b/runner/internal/shim/resources_test.go
@@ -138,7 +138,7 @@ func TestGpuLock_Acquire_Count_ErrNoCapacity(t *testing.T) {
 	assert.True(t, gl.lock["GPU-f00d"], "GPU-f00d")
 }
 
-func TestGpuLock_Acquire_Release(t *testing.T) {
+func TestGpuLock_Lock(t *testing.T) {
 	gpus := []host.GpuInfo{
 		{Vendor: host.GpuVendorNvidia, ID: "GPU-beef"},
 		{Vendor: host.GpuVendorNvidia, ID: "GPU-f00d"},
@@ -147,17 +147,52 @@ func TestGpuLock_Acquire_Release(t *testing.T) {
 	gl, _ := NewGpuLock(gpus)
 	gl.lock["GPU-beef"] = true
 	gl.lock["GPU-f00d"] = true
-	gl.Release([]string{
+	locked := gl.Lock([]string{
+		"GPU-beef", // already locked
+		"GPU-dead", // unknown
+		"GPU-c0de", // not locked
+	})
+	assert.Equal(t, []string{"GPU-c0de"}, locked)
+	assert.True(t, gl.lock["GPU-beef"], "GPU-beef") // was already locked
+	assert.True(t, gl.lock["GPU-f00d"], "GPU-f00d") // was already locked
+	assert.True(t, gl.lock["GPU-c0de"], "GPU-c0de") // has been locked
+}
+
+func TestGpuLock_Lock_Nil(t *testing.T) {
+	gpus := []host.GpuInfo{
+		{Vendor: host.GpuVendorNvidia, ID: "GPU-beef"},
+		{Vendor: host.GpuVendorNvidia, ID: "GPU-f00d"},
+	}
+	gl, _ := NewGpuLock(gpus)
+	gl.lock["GPU-beef"] = true
+	var ids []string
+	locked := gl.Lock(ids)
+	assert.Equal(t, []string{}, locked)
+	assert.True(t, gl.lock["GPU-beef"], "GPU-beef")
+	assert.False(t, gl.lock["GPU-f00d"], "GPU-f00d")
+}
+
+func TestGpuLock_Release(t *testing.T) {
+	gpus := []host.GpuInfo{
+		{Vendor: host.GpuVendorNvidia, ID: "GPU-beef"},
+		{Vendor: host.GpuVendorNvidia, ID: "GPU-f00d"},
+		{Vendor: host.GpuVendorNvidia, ID: "GPU-c0de"},
+	}
+	gl, _ := NewGpuLock(gpus)
+	gl.lock["GPU-beef"] = true
+	gl.lock["GPU-f00d"] = true
+	released := gl.Release([]string{
 		"GPU-beef", // locked
 		"GPU-dead", // unknown
 		"GPU-c0de", // not locked
 	})
+	assert.Equal(t, []string{"GPU-beef"}, released)
 	assert.False(t, gl.lock["GPU-beef"], "GPU-beef") // has been unlocked
 	assert.True(t, gl.lock["GPU-f00d"], "GPU-f00d")  // still locked
 	assert.False(t, gl.lock["GPU-c0de"], "GPU-c0de") // was already unlocked
 }
 
-func TestGpuLock_Acquire_Release_Nil(t *testing.T) {
+func TestGpuLock_Release_Nil(t *testing.T) {
 	gpus := []host.GpuInfo{
 		{Vendor: host.GpuVendorNvidia, ID: "GPU-beef"},
 		{Vendor: host.GpuVendorNvidia, ID: "GPU-f00d"},
@@ -165,7 +200,8 @@ func TestGpuLock_Acquire_Release_Nil(t *testing.T) {
 	gl, _ := NewGpuLock(gpus)
 	gl.lock["GPU-beef"] = true
 	var ids []string
-	gl.Release(ids)
+	released := gl.Release(ids)
+	assert.Equal(t, []string{}, released)
 	assert.True(t, gl.lock["GPU-beef"], "GPU-beef")
 	assert.False(t, gl.lock["GPU-f00d"], "GPU-f00d")
 }
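An illustrative sketch (not part of the patch) of the intended GpuLock flow, using the names from resources.go above: Acquire reserves free GPUs for a new task, Lock re-marks already-known IDs as busy (e.g., for tasks restored after a shim restart), and Release frees them. The "GPU-beef" ID is a placeholder borrowed from the tests.

    // exampleGpuLockFlow shows a typical round-trip; error handling is minimal.
    func exampleGpuLockFlow(gl *GpuLock) {
    	ids, err := gl.Acquire(2) // reserve two currently idle GPUs
    	if err != nil {
    		return // e.g., not enough idle GPUs on the host
    	}
    	// Lock never fails: unknown or already locked IDs are skipped, and the
    	// returned slice contains only the IDs that were actually locked.
    	_ = gl.Lock([]string{"GPU-beef"})
    	// Release mirrors Lock: it returns only the IDs that were actually freed.
    	released := gl.Release(ids)
    	_ = released
    }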
"github.com/dstackai/dstack/runner/internal/gerrors" ) @@ -40,7 +41,7 @@ func (c *CLIArgs) getRunnerArgs() []string { "--log-level", strconv.Itoa(c.Runner.LogLevel), "start", "--http-port", strconv.Itoa(c.Runner.HTTPPort), - "--temp-dir", c.Runner.TempDir, + "--temp-dir", consts.RunnerDir, "--home-dir", c.Runner.HomeDir, "--working-dir", c.Runner.WorkingDir, } diff --git a/runner/internal/shim/task.go b/runner/internal/shim/task.go index fb56bf9f3..65f11d264 100644 --- a/runner/internal/shim/task.go +++ b/runner/internal/shim/task.go @@ -4,17 +4,19 @@ import ( "context" "crypto/sha256" "fmt" + "log" "sync" ) type TaskStatus string const ( - // pending -> pulling -> creating -> running -> terminated - // | | | - // v v v - // terminated terminated terminated + // pending -> preparing -> pulling -> creating -> running -> terminated + // | | | | + // v v v v + // terminated terminated terminated terminated TaskStatusPending TaskStatus = "pending" + TaskStatusPreparing TaskStatus = "preparing" TaskStatusPulling TaskStatus = "pulling" TaskStatusCreating TaskStatus = "creating" TaskStatusRunning TaskStatus = "running" @@ -36,6 +38,55 @@ type Task struct { containerID string cancelPull context.CancelFunc gpuIDs []string + runnerDir string // path on host mapped to consts.RunnerDir in container + + mu *sync.Mutex +} + +// Lock is used for exclusive operations, e.g, stopping a container, +// removing task data, etc. +func (t *Task) Lock() { + if !t.mu.TryLock() { + log.Fatalf("task %s already locked!", t.ID) + } + log.Printf("task %s locked", t.ID) +} + +// Release should be called Unlock, but this name triggers govet copylocks check, +// since "thanks" to Go implicit interfaces, a struct with Lock/Unlock method pair +// looks like lock: https://github.com/golang/go/issues/18451 +func (t *Task) Release() { + t.mu.Unlock() + log.Printf("task %s unlocked", t.ID) +} + +func (t *Task) IsTransitionAllowed(toStatus TaskStatus) bool { + if t.Status == TaskStatusTerminated { + // terminal status, cannot transition further + return false + } + switch toStatus { + case TaskStatusPending: + // initial status, task should be Add()ed with it, not Update()d + return false + case TaskStatusPreparing: + return t.Status == TaskStatusPending + case TaskStatusPulling: + return t.Status == TaskStatusPreparing + case TaskStatusCreating: + return t.Status == TaskStatusPulling + case TaskStatusRunning: + return t.Status == TaskStatusCreating + case TaskStatusTerminated: + // we already checked terminated -> terminated (not allowed), + // all other transitions are allowed + return true + } + return false +} + +func (t *Task) SetStatusPreparing() { + t.Status = TaskStatusPreparing } func (t *Task) SetStatusPulling(cancelPull context.CancelFunc) { @@ -60,12 +111,25 @@ func (t *Task) SetStatusTerminated(reason string, message string) { t.cancelPull = nil } -func NewTask(cfg TaskConfig) Task { +func NewTask(id string, status TaskStatus, containerName string, containerID string, gpuIDs []string, runnerDir string) Task { + return Task{ + ID: id, + Status: status, + containerName: containerName, + containerID: containerID, + runnerDir: runnerDir, + gpuIDs: gpuIDs, + mu: &sync.Mutex{}, + } +} + +func NewTaskFromConfig(cfg TaskConfig) Task { return Task{ ID: cfg.ID, Status: TaskStatusPending, config: cfg, containerName: generateUniqueName(cfg.Name, cfg.ID), + mu: &sync.Mutex{}, } } @@ -75,6 +139,16 @@ type TaskStorage struct { mu sync.RWMutex } +func (ts *TaskStorage) IDs() []string { + ts.mu.RLock() + defer ts.mu.RUnlock() + ids 
diff --git a/runner/internal/shim/task_test.go b/runner/internal/shim/task_test.go
index 088a9f065..0cf13d406 100644
--- a/runner/internal/shim/task_test.go
+++ b/runner/internal/shim/task_test.go
@@ -1,6 +1,7 @@
 package shim
 
 import (
+	"fmt"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -48,19 +49,31 @@ func TestTaskStorage_Update_OK(t *testing.T) {
 	storage.tasks["1"] = storedTask
 	updatedTask := Task{ID: "1", Status: TaskStatusTerminated}
 
-	ok := storage.Update(updatedTask)
-	assert.True(t, ok)
+	err := storage.Update(updatedTask)
+	assert.Nil(t, err)
 	assert.Equal(t, updatedTask, storage.tasks["1"])
 }
 
 func TestTaskStorage_Update_DoesNotExist(t *testing.T) {
 	storage := NewTaskStorage()
 
-	ok := storage.Update(Task{ID: "1", Status: TaskStatusPending})
-	assert.False(t, ok)
+	err := storage.Update(Task{ID: "1", Status: TaskStatusPending})
+	assert.ErrorIs(t, err, ErrNotFound)
 	assert.Equal(t, 0, len(storage.tasks))
 }
 
+func TestTaskStorage_Update_TransitionNotAllowed(t *testing.T) {
+	storage := NewTaskStorage()
+	storedTask := Task{ID: "1", Status: TaskStatusPending}
+	storage.tasks["1"] = storedTask
+	updatedTask := Task{ID: "1", Status: TaskStatusRunning}
+
+	err := storage.Update(updatedTask)
+	assert.ErrorIs(t, err, ErrRequest)
+	assert.ErrorContains(t, err, fmt.Sprintf("%s -> %s", storedTask.Status, updatedTask.Status))
+	assert.Equal(t, storedTask, storage.tasks["1"])
+}
+
 func TestTaskStorage_Delete(t *testing.T) {
 	storage := NewTaskStorage()
 	storage.tasks["1"] = Task{ID: "1", Status: TaskStatusRunning}
@@ -72,12 +85,48 @@ func TestTaskStorage_Delete(t *testing.T) {
 	assert.Equal(t, 0, len(storage.tasks))
 }
 
-func TestNewTask(t *testing.T) {
+func TestTask_IsTransitionAllowed_true(t *testing.T) {
+	testCases := []struct {
+		oldStatus, newStatus TaskStatus
+	}{
+		{TaskStatusPending, TaskStatusPreparing},
+		{TaskStatusPending, TaskStatusTerminated},
+		{TaskStatusPreparing, TaskStatusPulling},
+		{TaskStatusPreparing, TaskStatusTerminated},
+		{TaskStatusPulling, TaskStatusCreating},
+		{TaskStatusPulling, TaskStatusTerminated},
+		{TaskStatusCreating, TaskStatusRunning},
+		{TaskStatusCreating, TaskStatusTerminated},
+		{TaskStatusRunning, TaskStatusTerminated},
+	}
+	for _, tc := range testCases {
+		task := Task{ID: "1", Status: tc.oldStatus}
+		assert.True(t, task.IsTransitionAllowed(tc.newStatus), "%s -> %s", tc.oldStatus, tc.newStatus)
+	}
+}
+
+func TestTask_IsTransitionAllowed_false(t *testing.T) {
+	testCases := []struct {
+		oldStatus, newStatus TaskStatus
+	}{
+		// non-exhaustive list of impossible transitions
{TaskStatusPending, TaskStatusPending}, + {TaskStatusPending, TaskStatusRunning}, + {TaskStatusPulling, TaskStatusPending}, + {TaskStatusTerminated, TaskStatusTerminated}, + } + for _, tc := range testCases { + task := Task{ID: "1", Status: tc.oldStatus} + assert.False(t, task.IsTransitionAllowed(tc.newStatus), "%s -> %s", tc.oldStatus, tc.newStatus) + } +} + +func TestNewTaskFromConfig(t *testing.T) { cfg := TaskConfig{ ID: "66a886db-86db-4cf9-8c06-8984ad15dde2", Name: "vllm-0-0", } - task := NewTask(cfg) + task := NewTaskFromConfig(cfg) assert.Equal(t, "66a886db-86db-4cf9-8c06-8984ad15dde2", task.ID) assert.Equal(t, "vllm-0-0-cff1b8da", task.containerName) From 3cb216641e641c3e65256cc6b47f6f511f61acf6 Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Thu, 26 Dec 2024 10:46:43 +0000 Subject: [PATCH 2/2] Use only GET and POST It's not a REST API anyway --- runner/docs/shim.openapi.yaml | 49 +++++++++++++++--------------- runner/internal/shim/api/server.go | 4 +-- 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/runner/docs/shim.openapi.yaml b/runner/docs/shim.openapi.yaml index 42eb8a551..cd00f8f0d 100644 --- a/runner/docs/shim.openapi.yaml +++ b/runner/docs/shim.openapi.yaml @@ -61,7 +61,7 @@ paths: application/json: schema: $ref: "#/components/schemas/TaskListResponse" - put: + post: tags: - *future-api summary: Submit and run new task @@ -93,29 +93,6 @@ paths: "200": $ref: "#/components/responses/TaskInfo" - delete: - tags: - - *future-api - summary: Remove task - description: > - Removes the task from in-memory storage and destroys its associated - resources: a container, logs, etc. - parameters: - - $ref: "#/parameters/taskId" - responses: - "200": - description: Task removed - $ref: "#/components/responses/PlainTextOk" - "404": - description: Task not found - $ref: "#/components/responses/PlainTextNotFound" - "409": - description: Task is not terminated, cannot remove - $ref: "#/components/responses/PlainTextConflict" - "500": - description: Internal error, e.g., failed to remove a container - $ref: "#/components/responses/PlainTextInternalError" - /tasks/{id}/terminate: post: tags: @@ -151,6 +128,30 @@ paths: description: Internal error, e.g., failed to remove a container $ref: "#/components/responses/PlainTextInternalError" + /tasks/{id}/remove: + post: + tags: + - *future-api + summary: Remove task + description: > + Removes the task from in-memory storage and destroys its associated + resources: a container, logs, etc. 
+ parameters: + - $ref: "#/parameters/taskId" + responses: + "200": + description: Task removed + $ref: "#/components/responses/PlainTextOk" + "404": + description: Task not found + $ref: "#/components/responses/PlainTextNotFound" + "409": + description: Task is not terminated, cannot remove + $ref: "#/components/responses/PlainTextConflict" + "500": + description: Internal error, e.g., failed to remove a container + $ref: "#/components/responses/PlainTextInternalError" + /submit: post: tags: diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index 33ebf2cba..89bd74fd0 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -51,9 +51,9 @@ func NewShimServer(address string, runner TaskRunner, version string) *ShimServe // Future API r.AddHandler("GET", "/api/tasks", s.TaskListHandler) r.AddHandler("GET", "/api/tasks/{id}", s.TaskInfoHandler) - r.AddHandler("PUT", "/api/tasks", s.TaskSubmitHandler) + r.AddHandler("POST", "/api/tasks", s.TaskSubmitHandler) r.AddHandler("POST", "/api/tasks/{id}/terminate", s.TaskTerminateHandler) - r.AddHandler("DELETE", "/api/tasks/{id}", s.TaskRemoveHandler) + r.AddHandler("POST", "/api/tasks/{id}/remove", s.TaskRemoveHandler) // Legacy API r.AddHandler("POST", "/api/submit", s.LegacySubmitPostHandler)
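To close, an illustrative sketch (not part of the patch) of a client driving the Future API through the POST-only routes registered above. The JSON body is loosely modeled on TaskSubmitRequest and is not authoritative; the base URL matches the default shim address from the spec.

    package main

    import (
    	"net/http"
    	"strings"
    )

    // exampleSubmitAndRemove submits a task, terminates it, then removes it.
    func exampleSubmitAndRemove(baseURL, id string) error {
    	// Submit (POST /api/tasks; PUT is gone as of the second commit).
    	body := strings.NewReader(`{"id": "` + id + `", "name": "example-0-0", "image_name": "ubuntu:24.04"}`)
    	resp, err := http.Post(baseURL+"/api/tasks", "application/json", body)
    	if err != nil {
    		return err
    	}
    	resp.Body.Close()
    	// Terminate, then remove; both are POSTs, and /tasks/{id}/remove
    	// replaces the former DELETE /tasks/{id}.
    	if resp, err = http.Post(baseURL+"/api/tasks/"+id+"/terminate", "application/json", nil); err != nil {
    		return err
    	}
    	resp.Body.Close()
    	if resp, err = http.Post(baseURL+"/api/tasks/"+id+"/remove", "application/json", nil); err != nil {
    		return err
    	}
    	resp.Body.Close()
    	return nil
    }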