From 6b8ddff1fa36e6bd034e170f2576e9da872678e3 Mon Sep 17 00:00:00 2001 From: Tim Gross Date: Wed, 16 Oct 2024 09:20:26 -0400 Subject: [PATCH] windows: set job object for executor and children (#24214) On Windows, if the `raw_exec` driver's executor exits, the child processes are not also killed. Create a Windows "job object" (not to be confused with a Nomad job) and add the executor to it. Child processes of the executor will inherit the job automatically. When the handle to the job object is freed (on executor exit), the job itself is destroyed and this causes all processes in that job to exit. Fixes: https://github.com/hashicorp/nomad/issues/23668 Ref: https://learn.microsoft.com/en-us/windows/win32/procthread/job-objects --- .changelog/24214.txt | 3 + .github/workflows/test-windows.yml | 2 + drivers/rawexec/driver_test.go | 98 ------------------ drivers/rawexec/driver_unix_test.go | 99 +++++++++++++++++++ drivers/rawexec/driver_windows_test.go | 96 ++++++++++++++++++ drivers/shared/executor/executor_test.go | 45 +-------- drivers/shared/executor/executor_windows.go | 33 ++++++- .../shared/executor/executor_windows_test.go | 88 +++++++++++++++++ drivers/shared/executor/utils_test.go | 47 +++++++++ 9 files changed, 369 insertions(+), 142 deletions(-) create mode 100644 .changelog/24214.txt create mode 100644 drivers/rawexec/driver_windows_test.go create mode 100644 drivers/shared/executor/executor_windows_test.go diff --git a/.changelog/24214.txt b/.changelog/24214.txt new file mode 100644 index 00000000000..d0e59532db9 --- /dev/null +++ b/.changelog/24214.txt @@ -0,0 +1,3 @@ +```release-note:bug +windows: Fixed a bug where a crashed executor would orphan task processes +``` diff --git a/.github/workflows/test-windows.yml b/.github/workflows/test-windows.yml index dc5d961a5c0..6a8b3536c1d 100644 --- a/.github/workflows/test-windows.yml +++ b/.github/workflows/test-windows.yml @@ -87,6 +87,8 @@ jobs: gotestsum --format=short-verbose \ --junitfile results.xml \ github.com/hashicorp/nomad/drivers/docker \ + github.com/hashicorp/nomad/drivers/rawexec \ + github.com/hashicorp/nomad/drivers/shared/executor \ github.com/hashicorp/nomad/client/lib/fifo \ github.com/hashicorp/nomad/client/logmon \ github.com/hashicorp/nomad/client/allocrunner/taskrunner/template \ diff --git a/drivers/rawexec/driver_test.go b/drivers/rawexec/driver_test.go index 35a60fc2b2a..df360f5eb92 100644 --- a/drivers/rawexec/driver_test.go +++ b/drivers/rawexec/driver_test.go @@ -12,7 +12,6 @@ import ( "path/filepath" "runtime" "strconv" - "sync" "syscall" "testing" "time" @@ -237,103 +236,6 @@ func TestRawExecDriver_StartWait(t *testing.T) { require.NoError(harness.DestroyTask(task.ID, true)) } -func TestRawExecDriver_StartWaitRecoverWaitStop(t *testing.T) { - ci.Parallel(t) - require := require.New(t) - - d := newEnabledRawExecDriver(t) - harness := dtestutil.NewDriverHarness(t, d) - defer harness.Kill() - - config := &Config{Enabled: true} - var data []byte - require.NoError(basePlug.MsgPackEncode(&data, config)) - bconfig := &basePlug.Config{ - PluginConfig: data, - AgentConfig: &base.AgentConfig{ - Driver: &base.ClientDriverConfig{ - Topology: d.nomadConfig.Topology, - }, - }, - } - require.NoError(harness.SetConfig(bconfig)) - - allocID := uuid.Generate() - taskName := "sleep" - task := &drivers.TaskConfig{ - AllocID: allocID, - ID: uuid.Generate(), - Name: taskName, - Env: defaultEnv(), - Resources: testResources(allocID, taskName), - } - tc := &TaskConfig{ - Command: testtask.Path(), - Args: []string{"sleep", "100s"}, - } - require.NoError(task.EncodeConcreteDriverConfig(&tc)) - - testtask.SetTaskConfigEnv(task) - - cleanup := harness.MkAllocDir(task, false) - defer cleanup() - - harness.MakeTaskCgroup(allocID, taskName) - - handle, _, err := harness.StartTask(task) - require.NoError(err) - - ch, err := harness.WaitTask(context.Background(), task.ID) - require.NoError(err) - - var waitDone bool - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - result := <-ch - require.Error(result.Err) - waitDone = true - }() - - originalStatus, err := d.InspectTask(task.ID) - require.NoError(err) - - d.tasks.Delete(task.ID) - - wg.Wait() - require.True(waitDone) - _, err = d.InspectTask(task.ID) - require.Equal(drivers.ErrTaskNotFound, err) - - err = d.RecoverTask(handle) - require.NoError(err) - - status, err := d.InspectTask(task.ID) - require.NoError(err) - require.Exactly(originalStatus, status) - - ch, err = harness.WaitTask(context.Background(), task.ID) - require.NoError(err) - - wg.Add(1) - waitDone = false - go func() { - defer wg.Done() - result := <-ch - require.NoError(result.Err) - require.NotZero(result.ExitCode) - require.Equal(9, result.Signal) - waitDone = true - }() - - time.Sleep(300 * time.Millisecond) - require.NoError(d.StopTask(task.ID, 0, "SIGKILL")) - wg.Wait() - require.NoError(d.DestroyTask(task.ID, false)) - require.True(waitDone) -} - func TestRawExecDriver_Start_Wait_AllocDir(t *testing.T) { ci.Parallel(t) require := require.New(t) diff --git a/drivers/rawexec/driver_unix_test.go b/drivers/rawexec/driver_unix_test.go index 4a620856a90..c09e3e0eba3 100644 --- a/drivers/rawexec/driver_unix_test.go +++ b/drivers/rawexec/driver_unix_test.go @@ -14,6 +14,7 @@ import ( "runtime" "strconv" "strings" + "sync" "syscall" "testing" "time" @@ -23,6 +24,7 @@ import ( "github.com/hashicorp/nomad/helper/testtask" "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/plugins/base" + basePlug "github.com/hashicorp/nomad/plugins/base" "github.com/hashicorp/nomad/plugins/drivers" dtestutil "github.com/hashicorp/nomad/plugins/drivers/testutils" "github.com/hashicorp/nomad/testutil" @@ -443,3 +445,100 @@ func TestRawExec_ExecTaskStreaming_User(t *testing.T) { require.Empty(t, stderr) require.Contains(t, stdout, "nobody") } + +func TestRawExecDriver_StartWaitRecoverWaitStop(t *testing.T) { + ci.Parallel(t) + require := require.New(t) + + d := newEnabledRawExecDriver(t) + harness := dtestutil.NewDriverHarness(t, d) + defer harness.Kill() + + config := &Config{Enabled: true} + var data []byte + require.NoError(basePlug.MsgPackEncode(&data, config)) + bconfig := &basePlug.Config{ + PluginConfig: data, + AgentConfig: &base.AgentConfig{ + Driver: &base.ClientDriverConfig{ + Topology: d.nomadConfig.Topology, + }, + }, + } + require.NoError(harness.SetConfig(bconfig)) + + allocID := uuid.Generate() + taskName := "sleep" + task := &drivers.TaskConfig{ + AllocID: allocID, + ID: uuid.Generate(), + Name: taskName, + Env: defaultEnv(), + Resources: testResources(allocID, taskName), + } + tc := &TaskConfig{ + Command: testtask.Path(), + Args: []string{"sleep", "100s"}, + } + require.NoError(task.EncodeConcreteDriverConfig(&tc)) + + testtask.SetTaskConfigEnv(task) + + cleanup := harness.MkAllocDir(task, false) + defer cleanup() + + harness.MakeTaskCgroup(allocID, taskName) + + handle, _, err := harness.StartTask(task) + require.NoError(err) + + ch, err := harness.WaitTask(context.Background(), task.ID) + require.NoError(err) + + var waitDone bool + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + result := <-ch + require.Error(result.Err) + waitDone = true + }() + + originalStatus, err := d.InspectTask(task.ID) + require.NoError(err) + + d.tasks.Delete(task.ID) + + wg.Wait() + require.True(waitDone) + _, err = d.InspectTask(task.ID) + require.Equal(drivers.ErrTaskNotFound, err) + + err = d.RecoverTask(handle) + require.NoError(err) + + status, err := d.InspectTask(task.ID) + require.NoError(err) + require.Exactly(originalStatus, status) + + ch, err = harness.WaitTask(context.Background(), task.ID) + require.NoError(err) + + wg.Add(1) + waitDone = false + go func() { + defer wg.Done() + result := <-ch + require.NoError(result.Err) + require.NotZero(result.ExitCode) + require.Equal(9, result.Signal) + waitDone = true + }() + + time.Sleep(300 * time.Millisecond) + require.NoError(d.StopTask(task.ID, 0, "SIGKILL")) + wg.Wait() + require.NoError(d.DestroyTask(task.ID, false)) + require.True(waitDone) +} diff --git a/drivers/rawexec/driver_windows_test.go b/drivers/rawexec/driver_windows_test.go new file mode 100644 index 00000000000..68876b037ed --- /dev/null +++ b/drivers/rawexec/driver_windows_test.go @@ -0,0 +1,96 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build windows + +package rawexec + +import ( + "os" + "testing" + "time" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/plugins/base" + "github.com/hashicorp/nomad/plugins/drivers" + dtestutil "github.com/hashicorp/nomad/plugins/drivers/testutils" + "github.com/shoenig/test/must" +) + +// TestRawExecDriver_ExecutorKill verifies that killing the executor will stop +// its child processes +func TestRawExecDriver_ExecutorKill(t *testing.T) { + ci.Parallel(t) + + d := newEnabledRawExecDriver(t) + harness := dtestutil.NewDriverHarness(t, d) + t.Cleanup(harness.Kill) + + config := &Config{Enabled: true} + var data []byte + must.NoError(t, base.MsgPackEncode(&data, config)) + bconfig := &base.Config{ + PluginConfig: data, + AgentConfig: &base.AgentConfig{ + Driver: &base.ClientDriverConfig{ + Topology: d.nomadConfig.Topology, + }, + }, + } + must.NoError(t, harness.SetConfig(bconfig)) + + allocID := uuid.Generate() + taskName := "test" + task := &drivers.TaskConfig{ + AllocID: allocID, + ID: uuid.Generate(), + Name: taskName, + Resources: testResources(allocID, taskName), + } + + taskConfig := map[string]interface{}{} + taskConfig["command"] = "Powershell.exe" + taskConfig["args"] = []string{"sleep", "100s"} + + must.NoError(t, task.EncodeConcreteDriverConfig(&taskConfig)) + + cleanup := harness.MkAllocDir(task, false) + t.Cleanup(cleanup) + + handle, _, err := harness.StartTask(task) + must.NoError(t, err) + + var taskState TaskState + must.NoError(t, handle.GetDriverState(&taskState)) + must.NoError(t, harness.WaitUntilStarted(task.ID, 1*time.Second)) + + // forcibly kill the executor, not the workload + must.NotEq(t, taskState.ReattachConfig.Pid, taskState.Pid) + proc, err := os.FindProcess(taskState.ReattachConfig.Pid) + must.NoError(t, err) + + taskProc, err := os.FindProcess(taskState.Pid) + must.NoError(t, err) + + must.NoError(t, proc.Kill()) + t.Logf("killed %d, waiting on %d to stop", taskState.ReattachConfig.Pid, taskState.Pid) + + t.Cleanup(func() { + if taskProc != nil { + taskProc.Kill() + } + }) + + done := make(chan struct{}) + go func() { + taskProc.Wait() + close(done) + }() + + select { + case <-time.After(5 * time.Second): + t.Fatal("expected child process to exit") + case <-done: + } +} diff --git a/drivers/shared/executor/executor_test.go b/drivers/shared/executor/executor_test.go index 50e415d6641..a0e17e66696 100644 --- a/drivers/shared/executor/executor_test.go +++ b/drivers/shared/executor/executor_test.go @@ -1,10 +1,11 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 +//go:build !windows + package executor import ( - "bytes" "context" "fmt" "io" @@ -12,7 +13,6 @@ import ( "path/filepath" "runtime" "strings" - "sync" "syscall" "testing" "time" @@ -59,15 +59,6 @@ var ( compute = topology.Compute() ) -type testExecCmd struct { - command *ExecCommand - allocDir *allocdir.AllocDir - - stdout *bytes.Buffer - stderr *bytes.Buffer - outputCopyDone *sync.WaitGroup -} - // testExecutorContext returns an ExecutorContext and AllocDir. // // The caller is responsible for calling AllocDir.Destroy() to cleanup. @@ -123,38 +114,6 @@ func testExecutorCommand(t *testing.T) *testExecCmd { return testCmd } -// configureTLogging configures a test command executor with buffer as Std{out|err} -// but using os.Pipe so it mimics non-test case where cmd is set with files as Std{out|err} -// the buffers can be used to read command output -func configureTLogging(t *testing.T, testcmd *testExecCmd) { - var stdout, stderr bytes.Buffer - var copyDone sync.WaitGroup - - stdoutPr, stdoutPw, err := os.Pipe() - require.NoError(t, err) - - stderrPr, stderrPw, err := os.Pipe() - require.NoError(t, err) - - copyDone.Add(2) - go func() { - defer copyDone.Done() - io.Copy(&stdout, stdoutPr) - }() - go func() { - defer copyDone.Done() - io.Copy(&stderr, stderrPr) - }() - - testcmd.stdout = &stdout - testcmd.stderr = &stderr - testcmd.outputCopyDone = ©Done - - testcmd.command.stdout = stdoutPw - testcmd.command.stderr = stderrPw - return -} - func TestExecutor_Start_Invalid(t *testing.T) { ci.Parallel(t) invalid := "/bin/foobar" diff --git a/drivers/shared/executor/executor_windows.go b/drivers/shared/executor/executor_windows.go index 457f29a6e02..25134ece5d1 100644 --- a/drivers/shared/executor/executor_windows.go +++ b/drivers/shared/executor/executor_windows.go @@ -9,17 +9,48 @@ import ( "fmt" "os" "syscall" + "unsafe" "golang.org/x/sys/windows" ) -// configure new process group for child process +// configure new process group for child process and creates a JobObject for the +// executor. Children of the executor will be created in the same JobObject +// Ref: https://learn.microsoft.com/en-us/windows/win32/procthread/job-objects func (e *UniversalExecutor) setNewProcessGroup() error { // We need to check that as build flags includes windows for this file if e.childCmd.SysProcAttr == nil { e.childCmd.SysProcAttr = &syscall.SysProcAttr{} } e.childCmd.SysProcAttr.CreationFlags = syscall.CREATE_NEW_PROCESS_GROUP + + // note: we don't call CloseHandle on this job handle because we need to + // hold onto it until the executor exits + job, err := windows.CreateJobObject(nil, nil) + if err != nil { + return fmt.Errorf("could not create Windows job object for executor: %w", err) + } + + info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{ + BasicLimitInformation: windows.JOBOBJECT_BASIC_LIMIT_INFORMATION{ + LimitFlags: windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, + }, + } + _, err = windows.SetInformationJobObject( + job, + windows.JobObjectExtendedLimitInformation, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info))) + if err != nil { + return fmt.Errorf("could not configure Windows job object for executor: %w", err) + } + + handle := windows.CurrentProcess() + err = windows.AssignProcessToJobObject(job, handle) + if err != nil { + return fmt.Errorf("could not assign executor to Windows job object: %w", err) + } + return nil } diff --git a/drivers/shared/executor/executor_windows_test.go b/drivers/shared/executor/executor_windows_test.go new file mode 100644 index 00000000000..c54cd497254 --- /dev/null +++ b/drivers/shared/executor/executor_windows_test.go @@ -0,0 +1,88 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +//go:build windows + +package executor + +import ( + "context" + "os" + "testing" + "time" + + "github.com/hashicorp/nomad/ci" + "github.com/hashicorp/nomad/client/allocdir" + "github.com/hashicorp/nomad/client/lib/numalib" + "github.com/hashicorp/nomad/client/taskenv" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/drivers" + "github.com/hashicorp/nomad/plugins/drivers/fsisolation" + "github.com/shoenig/test/must" +) + +// testExecutorCommand sets up a test task environment. +func testExecutorCommand(t *testing.T) *testExecCmd { + alloc := mock.Alloc() + task := alloc.Job.TaskGroups[0].Tasks[0] + taskEnv := taskenv.NewBuilder(mock.Node(), alloc, task, "global").Build() + + allocDir := allocdir.NewAllocDir(testlog.HCLogger(t), t.TempDir(), t.TempDir(), alloc.ID) + must.NoError(t, allocDir.Build()) + t.Cleanup(func() { allocDir.Destroy() }) + + must.NoError(t, allocDir.NewTaskDir(task).Build(fsisolation.None, nil, task.User)) + td := allocDir.TaskDirs[task.Name] + cmd := &ExecCommand{ + Env: taskEnv.List(), + TaskDir: td.Dir, + Resources: &drivers.Resources{ + NomadResources: &structs.AllocatedTaskResources{ + Cpu: structs.AllocatedCpuResources{ + CpuShares: 500, + }, + Memory: structs.AllocatedMemoryResources{ + MemoryMB: 256, + }, + }, + }, + } + + testCmd := &testExecCmd{ + command: cmd, + allocDir: allocDir, + } + configureTLogging(t, testCmd) + return testCmd +} + +func TestExecutor_ProcessExit(t *testing.T) { + ci.Parallel(t) + + topology := numalib.Scan(numalib.PlatformScanners()) + compute := topology.Compute() + + cmd := testExecutorCommand(t) + cmd.command.Cmd = "Powershell.exe" + cmd.command.Args = []string{"sleep", "30"} + executor := NewExecutor(testlog.HCLogger(t), compute) + + t.Cleanup(func() { executor.Shutdown("SIGKILL", 0) }) + + childPs, err := executor.Launch(cmd.command) + must.NoError(t, err) + must.NonZero(t, childPs.Pid) + + proc, err := os.FindProcess(childPs.Pid) + must.NoError(t, err) + must.NoError(t, proc.Kill()) + + ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Second) + t.Cleanup(cancel) + waitPs, err := executor.Wait(ctx) + must.NoError(t, err) + must.Eq(t, 1, waitPs.ExitCode) + must.Eq(t, childPs.Pid, waitPs.Pid) +} diff --git a/drivers/shared/executor/utils_test.go b/drivers/shared/executor/utils_test.go index 24a0598d0d8..b58a6854e09 100644 --- a/drivers/shared/executor/utils_test.go +++ b/drivers/shared/executor/utils_test.go @@ -4,8 +4,13 @@ package executor import ( + "bytes" + "io" + "os" + "sync" "testing" + "github.com/hashicorp/nomad/client/allocdir" "github.com/stretchr/testify/require" ) @@ -29,3 +34,45 @@ func TestUtils_IsolationMode(t *testing.T) { require.Equal(t, tc.exp, result) } } + +type testExecCmd struct { + command *ExecCommand + allocDir *allocdir.AllocDir + + stdout *bytes.Buffer + stderr *bytes.Buffer + outputCopyDone *sync.WaitGroup +} + +// configureTLogging configures a test command executor with buffer as +// Std{out|err} but using os.Pipe so it mimics non-test case where cmd is set +// with files as Std{out|err} the buffers can be used to read command output +func configureTLogging(t *testing.T, testcmd *testExecCmd) { + t.Helper() + var stdout, stderr bytes.Buffer + var copyDone sync.WaitGroup + + stdoutPr, stdoutPw, err := os.Pipe() + require.NoError(t, err) + + stderrPr, stderrPw, err := os.Pipe() + require.NoError(t, err) + + copyDone.Add(2) + go func() { + defer copyDone.Done() + io.Copy(&stdout, stdoutPr) + }() + go func() { + defer copyDone.Done() + io.Copy(&stderr, stderrPr) + }() + + testcmd.stdout = &stdout + testcmd.stderr = &stderr + testcmd.outputCopyDone = ©Done + + testcmd.command.stdout = stdoutPw + testcmd.command.stderr = stderrPw + return +}