Skip to content

Commit

Permalink
Add tests for PR #4726 (#5200)
Browse files Browse the repository at this point in the history
* Add tests to ensure the phase version is bumped in kubeflow plugin if reason changes within the same phase

Signed-off-by: Fabio Graetz <[email protected]>

* Test that ray and dask plugins bump phase version in GetTaskPhase

Signed-off-by: Fabio Graetz <[email protected]>

* Test phase version increase when reason changes for spark plugin

Signed-off-by: Fabio Graetz <[email protected]>

* Fix ray tests after rebase

Signed-off-by: Fabio Graetz <[email protected]>

* Make lint pass

Signed-off-by: Fabio Graetz <[email protected]>

---------

Signed-off-by: Fabio Graetz <[email protected]>
  • Loading branch information
fg91 authored Jun 5, 2024
1 parent d409b3b commit ecd65a0
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 83 deletions.
43 changes: 30 additions & 13 deletions flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func dummyDaskTaskTemplate(customImage string, resources *core.Resources, podTem
}
}

func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, isInterruptible bool) pluginsCore.TaskExecutionContext {
func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, isInterruptible bool, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext {
taskCtx := &mocks.TaskExecutionContext{}

inputReader := &pluginIOMocks.InputReader{}
Expand Down Expand Up @@ -199,11 +199,10 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc
taskExecutionMetadata.OnGetOverrides().Return(overrides)
taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata)

inputState := k8s.PluginState{}
pluginStateReaderMock := mocks.PluginStateReader{}
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return(
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return(
func(v interface{}) uint8 {
*(v.(*k8s.PluginState)) = inputState
*(v.(*k8s.PluginState)) = pluginState
return 0
},
func(v interface{}) error {
Expand All @@ -218,7 +217,7 @@ func TestBuildResourceDaskHappyPath(t *testing.T) {
daskResourceHandler := daskResourceHandler{}

taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false)
taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false, k8s.PluginState{})
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
assert.NotNil(t, r)
Expand Down Expand Up @@ -329,7 +328,7 @@ func TestBuildResourceDaskCustomImages(t *testing.T) {

daskResourceHandler := daskResourceHandler{}
taskTemplate := dummyDaskTaskTemplate(customImage, nil, "")
taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false)
taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{})
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
assert.NotNil(t, r)
Expand Down Expand Up @@ -362,7 +361,7 @@ func TestBuildResourceDaskDefaultResoureRequirements(t *testing.T) {

daskResourceHandler := daskResourceHandler{}
taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false)
taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false, k8s.PluginState{})
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
assert.NotNil(t, r)
Expand Down Expand Up @@ -419,7 +418,7 @@ func TestBuildResourcesDaskCustomResoureRequirements(t *testing.T) {

daskResourceHandler := daskResourceHandler{}
taskTemplate := dummyDaskTaskTemplate("", &protobufResources, "")
taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false)
taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false, k8s.PluginState{})
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
assert.NotNil(t, r)
Expand Down Expand Up @@ -474,7 +473,7 @@ func TestBuildResourceDaskInterruptible(t *testing.T) {
daskResourceHandler := daskResourceHandler{}

taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, true)
taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, true, k8s.PluginState{})
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
assert.NotNil(t, r)
Expand Down Expand Up @@ -508,7 +507,7 @@ func TestBuildResouceDaskUsePodTemplate(t *testing.T) {
flytek8s.DefaultPodTemplateStore.Store(podTemplate)
daskResourceHandler := daskResourceHandler{}
taskTemplate := dummyDaskTaskTemplate("", nil, podTemplateName)
taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false)
taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false, k8s.PluginState{})
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
assert.NotNil(t, r)
Expand Down Expand Up @@ -628,7 +627,7 @@ func TestBuildResourceDaskExtendedResources(t *testing.T) {
t.Run(f.name, func(t *testing.T) {
taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskTemplate.ExtendedResources = f.extendedResourcesBase
taskContext := dummyDaskTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, false)
taskContext := dummyDaskTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, false, k8s.PluginState{})
daskResourceHandler := daskResourceHandler{}
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
Expand Down Expand Up @@ -694,7 +693,7 @@ func TestBuildIdentityResourceDask(t *testing.T) {
}

taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false)
taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{})
identityResources, err := daskResourceHandler.BuildIdentityResource(context.TODO(), taskContext.TaskExecutionMetadata())
if err != nil {
panic(err)
Expand All @@ -707,7 +706,7 @@ func TestGetTaskPhaseDask(t *testing.T) {
ctx := context.TODO()

taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false)
taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{})

taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(""))
assert.NoError(t, err)
Expand Down Expand Up @@ -751,3 +750,21 @@ func TestGetTaskPhaseDask(t *testing.T) {
assert.NotNil(t, taskPhase.Info().Logs)
assert.Nil(t, err)
}

func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) {
daskResourceHandler := daskResourceHandler{}
ctx := context.TODO()

pluginState := k8s.PluginState{
Phase: pluginsCore.PhaseInitializing,
PhaseVersion: pluginsCore.DefaultPhaseVersion,
Reason: "task submitted to K8s",
}
taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, pluginState)

taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobCreated))

assert.NoError(t, err)
assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1)
}
48 changes: 32 additions & 16 deletions flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func dummyMPITaskTemplate(id string, args ...interface{}) *core.TaskTemplate {
}
}

func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources) pluginsCore.TaskExecutionContext {
func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.ResourceRequirements, extendedResources *core.ExtendedResources, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext {
taskCtx := &mocks.TaskExecutionContext{}
inputReader := &pluginIOMocks.InputReader{}
inputReader.OnGetInputPrefixPath().Return("/input/prefix")
Expand Down Expand Up @@ -172,11 +172,10 @@ func dummyMPITaskContext(taskTemplate *core.TaskTemplate, resources *corev1.Reso
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)

inputState := k8s.PluginState{}
pluginStateReaderMock := mocks.PluginStateReader{}
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return(
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return(
func(v interface{}) uint8 {
*(v.(*k8s.PluginState)) = inputState
*(v.(*k8s.PluginState)) = pluginState
return 0
},
func(v interface{}) error {
Expand Down Expand Up @@ -289,7 +288,7 @@ func dummyMPIJobResource(mpiResourceHandler mpiOperatorResourceHandler,

mpiObj := dummyMPICustomObj(workers, launcher, slots)
taskTemplate := dummyMPITaskTemplate(mpiID, mpiObj)
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
if err != nil {
panic(err)
}
Expand All @@ -316,7 +315,7 @@ func TestBuildResourceMPI(t *testing.T) {
mpiObj := dummyMPICustomObj(100, 50, 1)
taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj)

resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, resource)

Expand Down Expand Up @@ -352,13 +351,13 @@ func TestBuildResourceMPIForWrongInput(t *testing.T) {
mpiObj := dummyMPICustomObj(0, 0, 1)
taskTemplate := dummyMPITaskTemplate(mpiID, mpiObj)

_, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
_, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
assert.Error(t, err)

mpiObj = dummyMPICustomObj(1, 1, 1)
taskTemplate = dummyMPITaskTemplate(mpiID2, mpiObj)

resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
app, ok := resource.(*kubeflowv1.MPIJob)
assert.Nil(t, err)
assert.Equal(t, true, ok)
Expand Down Expand Up @@ -472,7 +471,7 @@ func TestBuildResourceMPIExtendedResources(t *testing.T) {
mpiObj := dummyMPICustomObj(100, 50, 1)
taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj)
taskTemplate.ExtendedResources = f.extendedResourcesBase
taskContext := dummyMPITaskContext(taskTemplate, f.resources, f.extendedResourcesOverride)
taskContext := dummyMPITaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, k8s.PluginState{})
mpiResourceHandler := mpiOperatorResourceHandler{}
r, err := mpiResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
Expand Down Expand Up @@ -504,7 +503,7 @@ func TestGetTaskPhase(t *testing.T) {
return dummyMPIJobResource(mpiResourceHandler, 2, 1, 1, conditionType)
}

taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil)
taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil, k8s.PluginState{})
taskPhase, err := mpiResourceHandler.GetTaskPhase(ctx, taskCtx, dummyMPIJobResourceCreator(mpiOp.JobCreated))
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase())
Expand Down Expand Up @@ -536,6 +535,23 @@ func TestGetTaskPhase(t *testing.T) {
assert.Nil(t, err)
}

func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) {
mpiResourceHandler := mpiOperatorResourceHandler{}
ctx := context.TODO()

pluginState := k8s.PluginState{
Phase: pluginsCore.PhaseQueued,
PhaseVersion: pluginsCore.DefaultPhaseVersion,
Reason: "task submitted to K8s",
}
taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(2, 1, 1)), resourceRequirements, nil, pluginState)

taskPhase, err := mpiResourceHandler.GetTaskPhase(ctx, taskCtx, dummyMPIJobResource(mpiResourceHandler, 2, 1, 1, mpiOp.JobCreated))

assert.NoError(t, err)
assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1)
}

func TestGetLogs(t *testing.T) {
assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{
IsKubernetesEnabled: true,
Expand All @@ -548,7 +564,7 @@ func TestGetLogs(t *testing.T) {

mpiResourceHandler := mpiOperatorResourceHandler{}
mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, mpiOp.JobRunning)
taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)), resourceRequirements, nil)
taskCtx := dummyMPITaskContext(dummyMPITaskTemplate("", dummyMPICustomObj(workers, launcher, slots)), resourceRequirements, nil, k8s.PluginState{})
jobLogs, err := common.GetLogs(taskCtx, common.MPITaskType, mpiJob.ObjectMeta, false, workers, launcher, 0, 0)
assert.NoError(t, err)
assert.Equal(t, 2, len(jobLogs))
Expand Down Expand Up @@ -581,7 +597,7 @@ func TestReplicaCounts(t *testing.T) {
mpiObj := dummyMPICustomObj(test.workerReplicaCount, test.launcherReplicaCount, 1)
taskTemplate := dummyMPITaskTemplate(mpiID2, mpiObj)

resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
if test.expectError {
assert.Error(t, err)
assert.Nil(t, resource)
Expand Down Expand Up @@ -667,7 +683,7 @@ func TestBuildResourceMPIV1(t *testing.T) {
taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig)
taskTemplate.TaskTypeVersion = 1

resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, resource)

Expand Down Expand Up @@ -719,7 +735,7 @@ func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) {
taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig)
taskTemplate.TaskTypeVersion = 1

resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, resource)

Expand Down Expand Up @@ -782,7 +798,7 @@ func TestBuildResourceMPIV1ResourceTolerations(t *testing.T) {
taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig)
taskTemplate.TaskTypeVersion = 1

resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, resource)

Expand All @@ -797,7 +813,7 @@ func TestGetReplicaCount(t *testing.T) {
mpiResourceHandler := mpiOperatorResourceHandler{}
tfObj := dummyMPICustomObj(1, 1, 0)
taskTemplate := dummyMPITaskTemplate("the job", tfObj)
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil, k8s.PluginState{}))
assert.NoError(t, err)
assert.NotNil(t, resource)
MPIJob, ok := resource.(*kubeflowv1.MPIJob)
Expand Down
Loading

0 comments on commit ecd65a0

Please sign in to comment.