diff --git a/flyteplugins/go/tasks/pluginmachinery/core/plugin.go b/flyteplugins/go/tasks/pluginmachinery/core/plugin.go index 18df83f83c..634d8b6591 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/plugin.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/plugin.go @@ -3,6 +3,9 @@ package core import ( "context" "fmt" + "sync" + + "k8s.io/utils/strings/slices" ) //go:generate mockery -all -case=underscore @@ -55,7 +58,27 @@ type Plugin interface { Finalize(ctx context.Context, tCtx TaskExecutionContext) error } -// Loads and validates a plugin. +type AgentService struct { + mu sync.RWMutex + supportedTaskTypes []TaskType + CorePlugin Plugin +} + +// ContainTaskType check if agent supports this task type. +func (p *AgentService) ContainTaskType(taskType TaskType) bool { + p.mu.RLock() + defer p.mu.RUnlock() + return slices.Contains(p.supportedTaskTypes, taskType) +} + +// SetSupportedTaskType set supportTaskType in the agent service. +func (p *AgentService) SetSupportedTaskType(taskTypes []TaskType) { + p.mu.Lock() + defer p.mu.Unlock() + p.supportedTaskTypes = taskTypes +} + +// LoadPlugin Loads and validates a plugin. func LoadPlugin(ctx context.Context, iCtx SetupContext, entry PluginEntry) (Plugin, error) { plugin, err := entry.LoadPlugin(ctx, iCtx) if err != nil { diff --git a/flyteplugins/go/tasks/pluginmachinery/core/plugin_test.go b/flyteplugins/go/tasks/pluginmachinery/core/plugin_test.go index ba6edc4f37..b5cd07c3fc 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/plugin_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/plugin_test.go @@ -93,3 +93,17 @@ func TestLoadPlugin(t *testing.T) { }) } + +func TestAgentService(t *testing.T) { + agentService := core.AgentService{} + taskTypes := []core.TaskType{"sensor", "chatgpt"} + + for _, taskType := range taskTypes { + assert.Equal(t, false, agentService.ContainTaskType(taskType)) + } + + agentService.SetSupportedTaskType(taskTypes) + for _, taskType := range taskTypes { + assert.Equal(t, true, agentService.ContainTaskType(taskType)) + } +} diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index e698d32121..35e6662107 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -90,17 +90,11 @@ func getFinalContext(ctx context.Context, operation string, agent *Deployment) ( return context.WithTimeout(ctx, timeout) } -func updateAgentRegistry(ctx context.Context, cs *ClientSet) { - agentRegistry := make(Registry) +func getAgentRegistry(ctx context.Context, cs *ClientSet) Registry { + newAgentRegistry := make(Registry) cfg := GetConfig() var agentDeployments []*Deployment - // Ensure that the old configuration is backward compatible - for taskType, agentDeploymentID := range cfg.AgentForTaskTypes { - agent := Agent{AgentDeployment: cfg.AgentDeployments[agentDeploymentID], IsSync: false} - agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: &agent} - } - if len(cfg.DefaultAgent.Endpoint) != 0 { agentDeployments = append(agentDeployments, &cfg.DefaultAgent) } @@ -137,27 +131,36 @@ func updateAgentRegistry(ctx context.Context, cs *ClientSet) { deprecatedSupportedTaskTypes := agent.SupportedTaskTypes for _, supportedTaskType := range deprecatedSupportedTaskTypes { agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync} - agentRegistry[supportedTaskType] = map[int32]*Agent{defaultTaskTypeVersion: agent} + newAgentRegistry[supportedTaskType] = map[int32]*Agent{defaultTaskTypeVersion: agent} } supportedTaskCategories := agent.SupportedTaskCategories for _, supportedCategory := range supportedTaskCategories { agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync} - agentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent} + newAgentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent} } } - // If the agent doesn't implement the metadata service, we construct the registry based on the configuration - for taskType, agentDeploymentID := range cfg.AgentForTaskTypes { - if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok { - if _, ok := agentRegistry[taskType]; !ok { - agent := &Agent{AgentDeployment: agentDeployment, IsSync: false} - agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent} - } + } + + // If the agent doesn't implement the metadata service, we construct the registry based on the configuration + for taskType, agentDeploymentID := range cfg.AgentForTaskTypes { + if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok { + if _, ok := newAgentRegistry[taskType]; !ok { + agent := &Agent{AgentDeployment: agentDeployment, IsSync: false} + newAgentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent} } } } - logger.Debugf(ctx, "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry)) - setAgentRegistry(agentRegistry) + + // Ensure that the old configuration is backward compatible + for _, taskType := range cfg.SupportedTaskTypes { + if _, ok := newAgentRegistry[taskType]; !ok { + agent := &Agent{AgentDeployment: &cfg.DefaultAgent, IsSync: false} + newAgentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent} + } + } + + return newAgentRegistry } func getAgentClientSets(ctx context.Context) *ClientSet { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index f3e626524c..ba74fbf5d2 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -35,10 +35,6 @@ import ( ) func TestEndToEnd(t *testing.T) { - agentRegistry = Registry{ - "openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}}, - "spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}}, - } iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { return nil } @@ -117,7 +113,7 @@ func TestEndToEnd(t *testing.T) { t.Run("failed to create a job", func(t *testing.T) { agentPlugin := newMockAsyncAgentPlugin() agentPlugin.PluginLoader = func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { - return Plugin{ + return &Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), cs: &ClientSet{ @@ -259,6 +255,9 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext { func newMockAsyncAgentPlugin() webapi.PluginEntry { asyncAgentClient := new(agentMocks.AsyncAgentServiceClient) + agentRegistry := Registry{ + "spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}}, + } mockCreateRequestMatcher := mock.MatchedBy(func(request *admin.CreateTaskRequest) bool { expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "/tmp/123"} @@ -283,7 +282,7 @@ func newMockAsyncAgentPlugin() webapi.PluginEntry { ID: "agent-service", SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark"}, PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { - return Plugin{ + return &Plugin{ metricScope: iCtx.MetricsScope(), cfg: &cfg, cs: &ClientSet{ @@ -291,12 +290,17 @@ func newMockAsyncAgentPlugin() webapi.PluginEntry { defaultAgentEndpoint: asyncAgentClient, }, }, + registry: agentRegistry, }, nil }, } } func newMockSyncAgentPlugin() webapi.PluginEntry { + agentRegistry := Registry{ + "openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}}, + } + syncAgentClient := new(agentMocks.SyncAgentServiceClient) output, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) resource := &admin.Resource{Phase: flyteIdlCore.TaskExecution_SUCCEEDED, Outputs: output} @@ -323,7 +327,7 @@ func newMockSyncAgentPlugin() webapi.PluginEntry { ID: "agent-service", SupportedTaskTypes: []core.TaskType{"openai"}, PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { - return Plugin{ + return &Plugin{ metricScope: iCtx.MetricsScope(), cfg: &cfg, cs: &ClientSet{ @@ -331,6 +335,7 @@ func newMockSyncAgentPlugin() webapi.PluginEntry { defaultAgentEndpoint: syncAgentClient, }, }, + registry: agentRegistry, }, nil }, } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 5470247ab7..20a65ccba1 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -24,29 +24,16 @@ import ( "github.com/flyteorg/flyte/flytestdlib/promutils" ) -type Registry map[string]map[int32]*Agent // map[taskTypeName][taskTypeVersion] => Agent - -var ( - agentRegistry Registry - mu sync.RWMutex -) +const ID = "agent-service" -func getAgentRegistry() Registry { - mu.Lock() - defer mu.Unlock() - return agentRegistry -} - -func setAgentRegistry(r Registry) { - mu.Lock() - defer mu.Unlock() - agentRegistry = r -} +type Registry map[string]map[int32]*Agent // map[taskTypeName][taskTypeVersion] => Agent type Plugin struct { metricScope promutils.Scope cfg *Config cs *ClientSet + registry Registry + mu sync.RWMutex } type ResourceWrapper struct { @@ -69,18 +56,24 @@ type ResourceMetaWrapper struct { TaskCategory admin.TaskCategory } -func (p Plugin) GetConfig() webapi.PluginConfig { +func (p *Plugin) setRegistry(r Registry) { + p.mu.Lock() + defer p.mu.Unlock() + p.registry = r +} + +func (p *Plugin) GetConfig() webapi.PluginConfig { return GetConfig().WebAPI } -func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) ( +func (p *Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) ( namespace core.ResourceNamespace, constraints core.ResourceConstraintsSpec, err error) { // Resource requirements are assumed to be the same. return "default", p.cfg.ResourceConstraints, nil } -func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta, +func (p *Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta, webapi.Resource, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) if err != nil { @@ -113,7 +106,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String() taskCategory := admin.TaskCategory{Name: taskTemplate.Type, Version: taskTemplate.TaskTypeVersion} - agent, isSync := getFinalAgent(&taskCategory, p.cfg) + agent, isSync := p.getFinalAgent(&taskCategory, p.cfg) taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) @@ -149,7 +142,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR }, nil, nil } -func (p Plugin) ExecuteTaskSync( +func (p *Plugin) ExecuteTaskSync( ctx context.Context, client service.SyncAgentServiceClient, header *admin.CreateRequestHeader, @@ -206,9 +199,9 @@ func (p Plugin) ExecuteTaskSync( }, err } -func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { +func (p *Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg) + agent, _ := p.getFinalAgent(&metadata.TaskCategory, p.cfg) client, err := p.getAsyncAgentClient(ctx, agent) if err != nil { @@ -236,12 +229,12 @@ func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest weba }, nil } -func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error { +func (p *Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error { if taskCtx.ResourceMeta() == nil { return nil } metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg) + agent, _ := p.getFinalAgent(&metadata.TaskCategory, p.cfg) client, err := p.getAsyncAgentClient(ctx, agent) if err != nil { @@ -259,7 +252,7 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error return err } -func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { +func (p *Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { resource := taskCtx.Resource().(ResourceWrapper) taskInfo := &core.TaskInfo{Logs: resource.LogLinks} @@ -311,7 +304,7 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase return core.PhaseInfoUndefined, pluginErrors.Errorf(core.SystemErrorCode, "unknown execution state [%v].", resource.State) } -func (p Plugin) getSyncAgentClient(ctx context.Context, agent *Deployment) (service.SyncAgentServiceClient, error) { +func (p *Plugin) getSyncAgentClient(ctx context.Context, agent *Deployment) (service.SyncAgentServiceClient, error) { client, ok := p.cs.syncAgentClients[agent.Endpoint] if !ok { conn, err := getGrpcConnection(ctx, agent) @@ -324,7 +317,7 @@ func (p Plugin) getSyncAgentClient(ctx context.Context, agent *Deployment) (serv return client, nil } -func (p Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (service.AsyncAgentServiceClient, error) { +func (p *Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (service.AsyncAgentServiceClient, error) { client, ok := p.cs.asyncAgentClients[agent.Endpoint] if !ok { conn, err := getGrpcConnection(ctx, agent) @@ -337,13 +330,25 @@ func (p Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (ser return client, nil } -func (p Plugin) watchAgents(ctx context.Context) { +func (p *Plugin) watchAgents(ctx context.Context, agentService *core.AgentService) { go wait.Until(func() { clientSet := getAgentClientSets(ctx) - updateAgentRegistry(ctx, clientSet) + agentRegistry := getAgentRegistry(ctx, clientSet) + p.setRegistry(agentRegistry) + agentService.SetSupportedTaskType(maps.Keys(agentRegistry)) }, p.cfg.PollInterval.Duration, ctx.Done()) } +func (p *Plugin) getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) { + p.mu.RLock() + defer p.mu.RUnlock() + + if agent, exists := p.registry[taskCategory.Name][taskCategory.Version]; exists { + return agent.AgentDeployment, agent.IsSync + } + return &cfg.DefaultAgent, false +} + func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *flyteIdl.LiteralMap) error { taskTemplate, err := taskCtx.TaskReader().Read(ctx) if err != nil { @@ -366,14 +371,6 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *fly return taskCtx.OutputWriter().Put(ctx, opReader) } -func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) { - r := getAgentRegistry() - if agent, exists := r[taskCategory.Name][taskCategory.Version]; exists { - return agent.AgentDeployment, agent.IsSync - } - return &cfg.DefaultAgent, false -} - func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata) admin.TaskExecutionMetadata { taskExecutionID := taskExecutionMetadata.GetTaskExecutionID().GetID() @@ -388,13 +385,12 @@ func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata } } -func newAgentPlugin() webapi.PluginEntry { +func newAgentPlugin(agentService *core.AgentService) webapi.PluginEntry { ctx := context.Background() cfg := GetConfig() - clientSet := getAgentClientSets(ctx) - updateAgentRegistry(ctx, clientSet) - supportedTaskTypes := append(maps.Keys(getAgentRegistry()), cfg.SupportedTaskTypes...) + agentRegistry := getAgentRegistry(ctx, clientSet) + supportedTaskTypes := maps.Keys(agentRegistry) return webapi.PluginEntry{ ID: "agent-service", @@ -404,15 +400,16 @@ func newAgentPlugin() webapi.PluginEntry { metricScope: iCtx.MetricsScope(), cfg: cfg, cs: clientSet, + registry: agentRegistry, } - plugin.watchAgents(ctx) + plugin.watchAgents(ctx, agentService) return plugin, nil }, } } -func RegisterAgentPlugin() { +func RegisterAgentPlugin(agentService *core.AgentService) { gob.Register(ResourceMetaWrapper{}) gob.Register(ResourceWrapper{}) - pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin()) + pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin(agentService)) } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 19af85eed3..3db1c464b6 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -2,7 +2,6 @@ package agent import ( "context" - "sort" "testing" "time" @@ -35,9 +34,12 @@ func TestPlugin(t *testing.T) { cfg.AgentDeployments = map[string]*Deployment{"spark_agent": {Endpoint: "localhost:80"}} cfg.AgentForTaskTypes = map[string]string{"spark": "spark_agent", "bar": "bar_agent"} + agent := &Agent{AgentDeployment: &Deployment{Endpoint: "localhost:80"}} + agentRegistry := Registry{"spark": {defaultTaskTypeVersion: agent}} plugin := Plugin{ metricScope: fakeSetupContext.MetricsScope(), cfg: GetConfig(), + registry: agentRegistry, } t.Run("get config", func(t *testing.T) { err := SetConfig(&cfg) @@ -59,16 +61,14 @@ func TestPlugin(t *testing.T) { }) t.Run("test getFinalAgent", func(t *testing.T) { - agent := &Agent{AgentDeployment: &Deployment{Endpoint: "localhost:80"}} - agentRegistry = Registry{"spark": {defaultTaskTypeVersion: agent}} spark := &admin.TaskCategory{Name: "spark", Version: defaultTaskTypeVersion} foo := &admin.TaskCategory{Name: "foo", Version: defaultTaskTypeVersion} bar := &admin.TaskCategory{Name: "bar", Version: defaultTaskTypeVersion} - agentDeployment, _ := getFinalAgent(spark, &cfg) + agentDeployment, _ := plugin.getFinalAgent(spark, &cfg) assert.Equal(t, agentDeployment.Endpoint, "localhost:80") - agentDeployment, _ = getFinalAgent(foo, &cfg) + agentDeployment, _ = plugin.getFinalAgent(foo, &cfg) assert.Equal(t, agentDeployment.Endpoint, cfg.DefaultAgent.Endpoint) - agentDeployment, _ = getFinalAgent(bar, &cfg) + agentDeployment, _ = plugin.getFinalAgent(bar, &cfg) assert.Equal(t, agentDeployment.Endpoint, cfg.DefaultAgent.Endpoint) }) @@ -318,11 +318,12 @@ func TestInitializeAgentRegistry(t *testing.T) { cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"} err := SetConfig(&cfg) assert.NoError(t, err) - updateAgentRegistry(context.Background(), cs) - // In golang, the order of keys in a map is random. So, we sort the keys before asserting. - agentRegistryKeys := maps.Keys(getAgentRegistry()) - sort.Strings(agentRegistryKeys) + agentRegistry := getAgentRegistry(context.Background(), cs) + agentRegistryKeys := maps.Keys(agentRegistry) + expectedKeys := []string{"task1", "task2", "task3", "task_type_1", "task_type_2"} - assert.Equal(t, agentRegistryKeys, []string{"task1", "task2", "task3"}) + for _, key := range expectedKeys { + assert.Contains(t, agentRegistryKeys, key) + } } diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index 5e4139296f..214540ac07 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -19,6 +19,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils" pluginK8s "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent" eventsErr "github.com/flyteorg/flyte/flytepropeller/events/errors" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" controllerConfig "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" @@ -200,6 +201,7 @@ type Handler struct { pluginScope promutils.Scope eventConfig *controllerConfig.EventConfig clusterID string + agentService *pluginCore.AgentService } func (t *Handler) FinalizeRequired() bool { @@ -226,6 +228,7 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error return err } + once.Do(func() { agent.RegisterAgentPlugin(t.agentService) }) // Create the resource negotiator here // and then convert it to proxies later and pass them to plugins enabledPlugins, defaultForTaskTypes, err := WranglePluginsAndGenerateFinalList(ctx, &t.cfg.TaskPlugins, t.pluginRegistry, t.kubeClientset) @@ -245,6 +248,11 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error tSCtx, newResourceManagerBuilder.GetResourceRegistrar(pluginResourceNamespacePrefix), p.ID) logger.Infof(ctx, "Loading Plugin [%s] ENABLED", p.ID) cp, err := pluginCore.LoadPlugin(ctx, sCtxFinal, p) + + if cp.GetID() == agent.ID { + t.agentService.CorePlugin = cp + } + if err != nil { return regErrors.Wrapf(err, "failed to load plugin - %s", p.ID) } @@ -306,7 +314,6 @@ func (t *Handler) Setup(ctx context.Context, sCtx interfaces.SetupContext) error } t.resourceManager = rm - return nil } @@ -337,6 +344,11 @@ func (t Handler) ResolvePlugin(ctx context.Context, ttype string, executionConfi logger.Debugf(ctx, "Plugin [%s] resolved for Handler type [%s]", p.GetID(), ttype) return p, nil } + + if t.agentService.ContainTaskType(ttype) { + return t.agentService.CorePlugin, nil + } + if t.defaultPlugin != nil { logger.Warnf(ctx, "No plugin found for Handler-type [%s], defaulting to [%s]", ttype, t.defaultPlugin.GetID()) return t.defaultPlugin, nil @@ -913,5 +925,6 @@ func New(ctx context.Context, kubeClient executors.Client, kubeClientset kuberne cfg: cfg, eventConfig: eventConfig, clusterID: clusterID, + agentService: &pluginCore.AgentService{}, }, nil } diff --git a/flytepropeller/pkg/controller/nodes/task/handler_test.go b/flytepropeller/pkg/controller/nodes/task/handler_test.go index 27e377236f..22e1d7451f 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/task/handler_test.go @@ -80,6 +80,7 @@ func Test_task_setDefault(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tk := &Handler{ defaultPlugin: tt.fields.defaultPlugin, + agentService: &pluginCore.AgentService{}, } if err := tk.setDefault(context.TODO(), tt.args.p); (err != nil) != tt.wantErr { t.Errorf("Handler.setDefault() error = %v, wantErr %v", err, tt.wantErr) @@ -330,6 +331,7 @@ func Test_task_ResolvePlugin(t *testing.T) { defaultPlugins: tt.fields.plugins, defaultPlugin: tt.fields.defaultPlugin, pluginsForType: tt.fields.pluginsForType, + agentService: &pluginCore.AgentService{}, } got, err := tk.ResolvePlugin(context.TODO(), tt.args.ttype, tt.args.executionConfig) if (err != nil) != tt.wantErr { @@ -702,6 +704,7 @@ func Test_task_Handle_NoCatalog(t *testing.T) { resourceManager: noopRm, taskMetricsMap: make(map[MetricKey]*taskMetrics), eventConfig: eventConfig, + agentService: &pluginCore.AgentService{}, } got, err := tk.Handle(context.TODO(), nCtx) if (err != nil) != tt.want.wantErr { @@ -887,6 +890,7 @@ func Test_task_Abort(t *testing.T) { tk := Handler{ defaultPlugin: m, resourceManager: noopRm, + agentService: &pluginCore.AgentService{}, } nCtx := createNodeCtx(tt.args.ev) if err := tk.Abort(context.TODO(), nCtx, "reason"); (err != nil) != tt.wantErr { @@ -1048,6 +1052,7 @@ func Test_task_Abort_v1(t *testing.T) { tk := Handler{ defaultPlugin: m, resourceManager: noopRm, + agentService: &pluginCore.AgentService{}, } nCtx := createNodeCtx(tt.args.ev) if err := tk.Abort(context.TODO(), nCtx, "reason"); (err != nil) != tt.wantErr { diff --git a/flytepropeller/pkg/controller/nodes/task/plugin_config.go b/flytepropeller/pkg/controller/nodes/task/plugin_config.go index 11b4bc6790..71d84af5a5 100644 --- a/flytepropeller/pkg/controller/nodes/task/plugin_config.go +++ b/flytepropeller/pkg/controller/nodes/task/plugin_config.go @@ -9,7 +9,6 @@ import ( "k8s.io/client-go/kubernetes" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/backoff" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/config" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/k8s" @@ -25,7 +24,6 @@ func WranglePluginsAndGenerateFinalList(ctx context.Context, cfg *config.TaskPlu } // Register the GRPC plugin after the config is loaded - once.Do(func() { agent.RegisterAgentPlugin() }) pluginsConfigMeta, err := cfg.GetEnabledPlugins() if err != nil {