Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Refactored listing of node and task executions to shared util
Browse files Browse the repository at this point in the history
Allows for re-use by cache manager

Signed-off-by: Nick Müller <[email protected]>
  • Loading branch information
Nick Müller committed Dec 15, 2022
1 parent 3d9cd15 commit 5cd94b6
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 125 deletions.
108 changes: 20 additions & 88 deletions pkg/manager/impl/node_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package impl

import (
"context"
"strconv"

cloudeventInterfaces "github.com/flyteorg/flyteadmin/pkg/async/cloudevent/interfaces"

Expand All @@ -17,7 +16,6 @@ import (

"github.com/flyteorg/flytestdlib/contextutils"

"github.com/flyteorg/flyteadmin/pkg/manager/impl/shared"
"github.com/flyteorg/flytestdlib/promutils"
"github.com/prometheus/client_golang/prometheus"

Expand Down Expand Up @@ -74,11 +72,6 @@ const (
alreadyInTerminalStatus
)

var isParent = common.NewMapFilter(map[string]interface{}{
shared.ParentTaskExecutionID: nil,
shared.ParentID: nil,
})

func getNodeExecutionContext(ctx context.Context, identifier *core.NodeExecutionIdentifier) context.Context {
ctx = contextutils.WithProjectDomain(ctx, identifier.ExecutionId.Project, identifier.ExecutionId.Domain)
ctx = contextutils.WithExecutionID(ctx, identifier.ExecutionId.Name)
Expand Down Expand Up @@ -369,48 +362,23 @@ func (m *NodeExecutionManager) GetNodeExecution(
return nodeExecution, nil
}

func (m *NodeExecutionManager) listNodeExecutions(
ctx context.Context, identifierFilters []common.InlineFilter,
requestFilters string, limit uint32, requestToken string, sortBy *admin.Sort, mapFilters []common.MapFilter) (
*admin.NodeExecutionList, error) {

filters, err := util.AddRequestFilters(requestFilters, common.NodeExecution, identifierFilters)
if err != nil {
func (m *NodeExecutionManager) ListNodeExecutions(
ctx context.Context, request admin.NodeExecutionListRequest) (*admin.NodeExecutionList, error) {
// Check required fields
if err := validation.ValidateNodeExecutionListRequest(request); err != nil {
return nil, err
}
var sortParameter common.SortParameter
if sortBy != nil {
sortParameter, err = common.NewSortParameter(*sortBy)
if err != nil {
return nil, err
}
}
offset, err := validation.ValidateToken(requestToken)
if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"invalid pagination token %s for ListNodeExecutions", requestToken)
}
listInput := repoInterfaces.ListResourceInput{
Limit: int(limit),
Offset: offset,
InlineFilters: filters,
SortParameter: sortParameter,
}
ctx = getExecutionContext(ctx, request.WorkflowExecutionId)

listInput.MapFilters = mapFilters
output, err := m.db.NodeExecutionRepo().List(ctx, listInput)
nodeExecutions, token, err := util.ListNodeExecutionsForWorkflow(ctx, m.db, request.WorkflowExecutionId,
request.UniqueParentId, request.Filters, request.Limit, request.Token, request.SortBy)
if err != nil {
logger.Debugf(ctx, "Failed to list node executions for request with err %v", err)
return nil, err
}

var token string
if len(output.NodeExecutions) == int(limit) {
token = strconv.Itoa(offset + len(output.NodeExecutions))
}
nodeExecutionList, err := m.transformNodeExecutionModelList(ctx, output.NodeExecutions)
nodeExecutionList, err := m.transformNodeExecutionModelList(ctx, nodeExecutions)
if err != nil {
logger.Debugf(ctx, "failed to transform node execution models for request with err: %v", err)
logger.Debugf(ctx, "failed to transform node execution models for request [%+v] with err: %v", request, err)
return nil, err
}

Expand All @@ -420,42 +388,6 @@ func (m *NodeExecutionManager) listNodeExecutions(
}, nil
}

func (m *NodeExecutionManager) ListNodeExecutions(
ctx context.Context, request admin.NodeExecutionListRequest) (*admin.NodeExecutionList, error) {
// Check required fields
if err := validation.ValidateNodeExecutionListRequest(request); err != nil {
return nil, err
}
ctx = getExecutionContext(ctx, request.WorkflowExecutionId)

identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, *request.WorkflowExecutionId)
if err != nil {
return nil, err
}
var mapFilters []common.MapFilter
if request.UniqueParentId != "" {
parentNodeExecution, err := util.GetNodeExecutionModel(ctx, m.db, &core.NodeExecutionIdentifier{
ExecutionId: request.WorkflowExecutionId,
NodeId: request.UniqueParentId,
})
if err != nil {
return nil, err
}
parentIDFilter, err := common.NewSingleValueFilter(
common.NodeExecution, common.Equal, shared.ParentID, parentNodeExecution.ID)
if err != nil {
return nil, err
}
identifierFilters = append(identifierFilters, parentIDFilter)
} else {
mapFilters = []common.MapFilter{
isParent,
}
}
return m.listNodeExecutions(
ctx, identifierFilters, request.Filters, request.Limit, request.Token, request.SortBy, mapFilters)
}

// Filters on node executions matching the execution parameters (execution project, domain, and name) as well as the
// parent task execution id corresponding to the task execution identified in the request params.
func (m *NodeExecutionManager) ListNodeExecutionsForTask(
Expand All @@ -465,23 +397,23 @@ func (m *NodeExecutionManager) ListNodeExecutionsForTask(
return nil, err
}
ctx = getTaskExecutionContext(ctx, request.TaskExecutionId)
identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(
ctx, *request.TaskExecutionId.NodeExecutionId.ExecutionId)
if err != nil {
return nil, err
}
parentTaskExecutionModel, err := util.GetTaskExecutionModel(ctx, m.db, request.TaskExecutionId)

nodeExecutions, token, err := util.ListNodeExecutionsForTask(ctx, m.db, request.TaskExecutionId,
request.TaskExecutionId.NodeExecutionId.ExecutionId, request.Filters, request.Limit, request.Token, request.SortBy)
if err != nil {
return nil, err
}
nodeIDFilter, err := common.NewSingleValueFilter(
common.NodeExecution, common.Equal, shared.ParentTaskExecutionID, parentTaskExecutionModel.ID)

nodeExecutionList, err := m.transformNodeExecutionModelList(ctx, nodeExecutions)
if err != nil {
logger.Debugf(ctx, "failed to transform node execution models for request [%+v] with err: %v", request, err)
return nil, err
}
identifierFilters = append(identifierFilters, nodeIDFilter)
return m.listNodeExecutions(
ctx, identifierFilters, request.Filters, request.Limit, request.Token, request.SortBy, nil)

return &admin.NodeExecutionList{
NodeExecutions: nodeExecutionList,
Token: token,
}, nil
}

func (m *NodeExecutionManager) GetNodeExecutionData(
Expand Down
41 changes: 4 additions & 37 deletions pkg/manager/impl/task_execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package impl
import (
"context"
"fmt"
"strconv"

cloudeventInterfaces "github.com/flyteorg/flyteadmin/pkg/async/cloudevent/interfaces"

Expand Down Expand Up @@ -249,50 +248,18 @@ func (m *TaskExecutionManager) ListTaskExecutions(
}
ctx = getNodeExecutionContext(ctx, request.NodeExecutionId)

identifierFilters, err := util.GetNodeExecutionIdentifierFilters(ctx, *request.NodeExecutionId)
taskExecutions, token, err := util.ListTaskExecutions(ctx, m.db, request.NodeExecutionId, request.Filters,
request.Limit, request.Token, request.SortBy)
if err != nil {
return nil, err
}

filters, err := util.AddRequestFilters(request.Filters, common.TaskExecution, identifierFilters)
if err != nil {
return nil, err
}
var sortParameter common.SortParameter
if request.SortBy != nil {
sortParameter, err = common.NewSortParameter(*request.SortBy)
if err != nil {
return nil, err
}
}

offset, err := validation.ValidateToken(request.Token)
if err != nil {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"invalid pagination token %s for ListTaskExecutions", request.Token)
}

output, err := m.db.TaskExecutionRepo().List(ctx, repoInterfaces.ListResourceInput{
InlineFilters: filters,
Offset: offset,
Limit: int(request.Limit),
SortParameter: sortParameter,
})
if err != nil {
logger.Debugf(ctx, "Failed to list task executions with request [%+v] with err %v",
request, err)
return nil, err
}

taskExecutionList, err := transformers.FromTaskExecutionModels(output.TaskExecutions)
taskExecutionList, err := transformers.FromTaskExecutionModels(taskExecutions)
if err != nil {
logger.Debugf(ctx, "failed to transform task execution models for request [%+v] with err: %v", request, err)
return nil, err
}
var token string
if len(taskExecutionList) == int(request.Limit) {
token = strconv.Itoa(offset + len(taskExecutionList))
}

return &admin.TaskExecutionList{
TaskExecutions: taskExecutionList,
Token: token,
Expand Down
145 changes: 145 additions & 0 deletions pkg/manager/impl/util/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package util

import (
"context"
"strconv"
"time"

"github.com/flyteorg/flyteadmin/pkg/common"
Expand Down Expand Up @@ -298,3 +299,147 @@ func MergeIntoExecConfig(workflowExecConfig admin.WorkflowExecutionConfig, spec

return workflowExecConfig
}

func ListNodeExecutions(ctx context.Context, repo repoInterfaces.Repository, identifierFilters []common.InlineFilter,
requestFilters string, limit uint32, requestToken string, sortBy *admin.Sort,
mapFilters []common.MapFilter) ([]models.NodeExecution, string, error) {
filters, err := AddRequestFilters(requestFilters, common.NodeExecution, identifierFilters)
if err != nil {
return nil, "", err
}
var sortParameter common.SortParameter
if sortBy != nil {
sortParameter, err = common.NewSortParameter(*sortBy)
if err != nil {
return nil, "", err
}
}
offset, err := validation.ValidateToken(requestToken)
if err != nil {
return nil, "", errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"invalid pagination token %s for ListNodeExecutions", requestToken)
}
listInput := repoInterfaces.ListResourceInput{
Limit: int(limit),
Offset: offset,
InlineFilters: filters,
SortParameter: sortParameter,
MapFilters: mapFilters,
}

output, err := repo.NodeExecutionRepo().List(ctx, listInput)
if err != nil {
logger.Debugf(ctx, "Failed to list node executions: %v", err)
return nil, "", err
}

var token string
if len(output.NodeExecutions) == int(limit) {
token = strconv.Itoa(offset + len(output.NodeExecutions))
}

return output.NodeExecutions, token, nil
}

func ListNodeExecutionsForWorkflow(ctx context.Context, repo repoInterfaces.Repository,
workflowExecutionID *core.WorkflowExecutionIdentifier, uniqueParentID string, requestFilters string,
limit uint32, requestToken string, sortBy *admin.Sort) ([]models.NodeExecution, string, error) {
identifierFilters, err := GetWorkflowExecutionIdentifierFilters(ctx, *workflowExecutionID)
if err != nil {
return nil, "", err
}

var mapFilters []common.MapFilter
if len(uniqueParentID) > 0 {
parentNodeExecution, err := GetNodeExecutionModel(ctx, repo, &core.NodeExecutionIdentifier{
ExecutionId: workflowExecutionID,
NodeId: uniqueParentID,
})
if err != nil {
return nil, "", err
}
parentIDFilter, err := common.NewSingleValueFilter(
common.NodeExecution, common.Equal, shared.ParentID, parentNodeExecution.ID)
if err != nil {
return nil, "", err
}
identifierFilters = append(identifierFilters, parentIDFilter)
} else {
mapFilters = []common.MapFilter{
common.NewMapFilter(map[string]interface{}{
shared.ParentTaskExecutionID: nil,
shared.ParentID: nil,
}),
}
}

return ListNodeExecutions(ctx, repo, identifierFilters, requestFilters, limit, requestToken, sortBy, mapFilters)
}

func ListNodeExecutionsForTask(ctx context.Context, repo repoInterfaces.Repository,
taskExecutionID *core.TaskExecutionIdentifier, workflowExecutionID *core.WorkflowExecutionIdentifier,
requestFilters string, limit uint32, requestToken string, sortBy *admin.Sort) ([]models.NodeExecution, string, error) {
identifierFilters, err := GetWorkflowExecutionIdentifierFilters(ctx, *workflowExecutionID)
if err != nil {
return nil, "", err
}

parentTaskExecutionModel, err := GetTaskExecutionModel(ctx, repo, taskExecutionID)
if err != nil {
return nil, "", err
}

nodeIDFilter, err := common.NewSingleValueFilter(
common.NodeExecution, common.Equal, shared.ParentTaskExecutionID, parentTaskExecutionModel.ID)
if err != nil {
return nil, "", err
}
identifierFilters = append(identifierFilters, nodeIDFilter)

return ListNodeExecutions(ctx, repo, identifierFilters, requestFilters, limit, requestToken, sortBy, nil)
}

func ListTaskExecutions(ctx context.Context, repo repoInterfaces.Repository,
nodeExecutionID *core.NodeExecutionIdentifier, requestFilters string, limit uint32, requestToken string,
sortBy *admin.Sort) ([]models.TaskExecution, string, error) {
identifierFilters, err := GetNodeExecutionIdentifierFilters(ctx, *nodeExecutionID)
if err != nil {
return nil, "", err
}

filters, err := AddRequestFilters(requestFilters, common.TaskExecution, identifierFilters)
if err != nil {
return nil, "", err
}
var sortParameter common.SortParameter
if sortBy != nil {
sortParameter, err = common.NewSortParameter(*sortBy)
if err != nil {
return nil, "", err
}
}

offset, err := validation.ValidateToken(requestToken)
if err != nil {
return nil, "", errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"invalid pagination token %s for ListTaskExecutions", requestToken)
}

output, err := repo.TaskExecutionRepo().List(ctx, repoInterfaces.ListResourceInput{
InlineFilters: filters,
Offset: offset,
Limit: int(limit),
SortParameter: sortParameter,
})
if err != nil {
logger.Debugf(ctx, "Failed to list task executions: %v", err)
return nil, "", err
}

var token string
if len(output.TaskExecutions) == int(limit) {
token = strconv.Itoa(offset + len(output.TaskExecutions))
}

return output.TaskExecutions, token, nil
}

0 comments on commit 5cd94b6

Please sign in to comment.