diff --git a/cmd/proplet/main.go b/cmd/proplet/main.go index 60388ec..c6f5f92 100644 --- a/cmd/proplet/main.go +++ b/cmd/proplet/main.go @@ -115,7 +115,7 @@ func checkRegistryConnectivity(ctx context.Context, registryURL string, registry defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("fegistry returned unexpected status: %d", resp.StatusCode) + return fmt.Errorf("registry returned unexpected status: %d", resp.StatusCode) } return nil diff --git a/manager/api/transport.go b/manager/api/transport.go index 73db15b..554ab9b 100644 --- a/manager/api/transport.go +++ b/manager/api/transport.go @@ -74,6 +74,12 @@ func MakeHandler(svc manager.Service, logger *slog.Logger, instanceID string) ht api.EncodeResponse, opts..., ), "update-task").ServeHTTP) + r.Put("/upload", otelhttp.NewHandler(kithttp.NewServer( + updateTaskEndpoint(svc), + decodeUploadTaskFileReq, + api.EncodeResponse, + opts..., + ), "upload-task-file").ServeHTTP) r.Delete("/", otelhttp.NewHandler(kithttp.NewServer( deleteTaskEndpoint(svc), decodeEntityReq("taskID"), @@ -110,6 +116,19 @@ func decodeEntityReq(key string) kithttp.DecodeRequestFunc { } func decodeTaskReq(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return nil, errors.Join(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) + } + + var req taskReq + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Join(err, apiutil.ErrValidation) + } + + return req, nil +} + +func decodeUploadTaskFileReq(_ context.Context, r *http.Request) (interface{}, error) { var req taskReq if err := r.ParseMultipartForm(maxFileSize); err != nil { return nil, err @@ -128,7 +147,7 @@ func decodeTaskReq(_ context.Context, r *http.Request) (interface{}, error) { return nil, err } req.File = data - req.Name = header.Filename + req.Task.ID = chi.URLParam(r, "taskID") return req, nil } diff --git a/manager/service.go b/manager/service.go index 63f57aa..658b0db 100644 --- a/manager/service.go +++ b/manager/service.go @@ -154,8 +154,15 @@ func (svc *service) UpdateTask(ctx context.Context, t task.Task) (task.Task, err return task.Task{}, err } dbT.UpdatedAt = time.Now() - dbT.Name = t.Name - dbT.Inputs = t.Inputs + if t.Name != "" { + dbT.Name = t.Name + } + if t.Inputs != nil { + dbT.Inputs = t.Inputs + } + if t.File != nil { + dbT.File = t.File + } if err := svc.tasksDB.Update(ctx, dbT.ID, dbT); err != nil { return task.Task{}, err