diff --git a/api/client/webclient/webclient.go b/api/client/webclient/webclient.go index 6b7d7153a7136..78d4c80c9aebc 100644 --- a/api/client/webclient/webclient.go +++ b/api/client/webclient/webclient.go @@ -47,6 +47,15 @@ import ( "github.com/gravitational/teleport/api/utils/keys" ) +const ( + // AgentUpdateGroupParameter is the parameter used to specify the updater + // group when doing a Ping() or Find() query. + // The proxy server will modulate the auto_update part of the PingResponse + // based on the specified group. e.g. some groups might need to update + // before others. + AgentUpdateGroupParameter = "group" +) + // Config specifies information when building requests with the // webclient. type Config struct { @@ -183,7 +192,7 @@ func findWithClient(cfg *Config, clt *http.Client) (*PingResponse, error) { } if cfg.UpdateGroup != "" { endpoint.RawQuery = url.Values{ - "group": []string{cfg.UpdateGroup}, + AgentUpdateGroupParameter: []string{cfg.UpdateGroup}, }.Encode() } @@ -232,7 +241,7 @@ func pingWithClient(cfg *Config, clt *http.Client) (*PingResponse, error) { } if cfg.UpdateGroup != "" { endpoint.RawQuery = url.Values{ - "group": []string{cfg.UpdateGroup}, + AgentUpdateGroupParameter: []string{cfg.UpdateGroup}, }.Encode() } if cfg.ConnectorName != "" { diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index e63035dfcd759..9bbfadd29877f 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -178,6 +178,9 @@ type Handler struct { // rate-limits, each call must cause minimal work. The cached answer can be modulated after, for example if the // caller specified its Automatic Updates UUID or group. findEndpointCache *utils.FnCache + + // clusterMaintenanceConfigCache is used to cache the cluster maintenance config from the Auth Service.
+ clusterMaintenanceConfigCache *utils.FnCache } // HandlerOption is a functional argument - an option that can be passed @@ -480,6 +483,18 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { } h.findEndpointCache = findCache + // We create the cache after applying the options to make sure we use the fake clock if it was passed. + cmcCache, err := utils.NewFnCache(utils.FnCacheConfig{ + TTL: findEndpointCacheTTL, + Clock: h.clock, + Context: cfg.Context, + ReloadOnErr: false, + }) + if err != nil { + return nil, trace.Wrap(err, "creating /find cache") + } + h.clusterMaintenanceConfigCache = cmcCache + sessionLingeringThreshold := cachedSessionLingeringThreshold if cfg.CachedSessionLingeringThreshold != nil { sessionLingeringThreshold = *cfg.CachedSessionLingeringThreshold @@ -1527,6 +1542,8 @@ func (h *Handler) ping(w http.ResponseWriter, r *http.Request, p httprouter.Para return nil, trace.Wrap(err) } + group := r.URL.Query().Get(webclient.AgentUpdateGroupParameter) + return webclient.PingResponse{ Auth: authSettings, Proxy: *proxyConfig, @@ -1534,15 +1551,21 @@ func (h *Handler) ping(w http.ResponseWriter, r *http.Request, p httprouter.Para MinClientVersion: teleport.MinClientVersion, ClusterName: h.auth.clusterName, AutomaticUpgrades: pr.ServerFeatures.GetAutomaticUpgrades(), - AutoUpdate: h.automaticUpdateSettings184(r.Context()), + AutoUpdate: h.automaticUpdateSettings184(r.Context(), group, "" /* updater UUID */), Edition: modules.GetModules().BuildType(), FIPS: modules.IsBoringBinary(), }, nil } func (h *Handler) find(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { + group := r.URL.Query().Get(webclient.AgentUpdateGroupParameter) + cacheKey := "find" + if group != "" { + cacheKey += "-" + group + } + // cache the generic answer to avoid doing work for each request - resp, err := utils.FnCacheGet[*webclient.PingResponse](r.Context(), h.findEndpointCache, "find", func(ctx context.Context) 
(*webclient.PingResponse, error) { + resp, err := utils.FnCacheGet[*webclient.PingResponse](r.Context(), h.findEndpointCache, cacheKey, func(ctx context.Context) (*webclient.PingResponse, error) { proxyConfig, err := h.cfg.ProxySettings.GetProxySettings(ctx) if err != nil { return nil, trace.Wrap(err) @@ -1561,7 +1584,7 @@ func (h *Handler) find(w http.ResponseWriter, r *http.Request, p httprouter.Para ClusterName: h.auth.clusterName, Edition: modules.GetModules().BuildType(), FIPS: modules.IsBoringBinary(), - AutoUpdate: h.automaticUpdateSettings184(ctx), + AutoUpdate: h.automaticUpdateSettings184(ctx, group, "" /* updater UUID */), }, nil }) if err != nil { diff --git a/lib/web/apiserver_ping_test.go b/lib/web/apiserver_ping_test.go index 2bf325d4f7902..84e073ca7ae87 100644 --- a/lib/web/apiserver_ping_test.go +++ b/lib/web/apiserver_ping_test.go @@ -299,6 +299,7 @@ func TestPing_autoUpdateResources(t *testing.T) { name string config *autoupdatev1pb.AutoUpdateConfigSpec version *autoupdatev1pb.AutoUpdateVersionSpec + rollout *autoupdatev1pb.AutoUpdateAgentRolloutSpec cleanup bool expected webclient.AutoUpdateSettings }{ @@ -330,19 +331,12 @@ func TestPing_autoUpdateResources(t *testing.T) { }, { name: "enable agent auto update, immediate schedule", - config: &autoupdatev1pb.AutoUpdateConfigSpec{ - Agents: &autoupdatev1pb.AutoUpdateConfigSpecAgents{ - Mode: autoupdate.AgentsUpdateModeEnabled, - Strategy: autoupdate.AgentsStrategyHaltOnError, - }, - }, - version: &autoupdatev1pb.AutoUpdateVersionSpec{ - Agents: &autoupdatev1pb.AutoUpdateVersionSpecAgents{ - Mode: autoupdate.AgentsUpdateModeEnabled, - StartVersion: "1.2.3", - TargetVersion: "1.2.4", - Schedule: autoupdate.AgentsScheduleImmediate, - }, + rollout: &autoupdatev1pb.AutoUpdateAgentRolloutSpec{ + AutoupdateMode: autoupdate.AgentsUpdateModeEnabled, + Strategy: autoupdate.AgentsStrategyHaltOnError, + Schedule: autoupdate.AgentsScheduleImmediate, + StartVersion: "1.2.3", + TargetVersion: "1.2.4", }, 
expected: webclient.AutoUpdateSettings{ ToolsVersion: api.Version, @@ -354,20 +348,13 @@ func TestPing_autoUpdateResources(t *testing.T) { cleanup: true, }, { - name: "version enable agent auto update, but config disables them", - config: &autoupdatev1pb.AutoUpdateConfigSpec{ - Agents: &autoupdatev1pb.AutoUpdateConfigSpecAgents{ - Mode: autoupdate.AgentsUpdateModeDisabled, - Strategy: autoupdate.AgentsStrategyHaltOnError, - }, - }, - version: &autoupdatev1pb.AutoUpdateVersionSpec{ - Agents: &autoupdatev1pb.AutoUpdateVersionSpecAgents{ - Mode: autoupdate.AgentsUpdateModeEnabled, - StartVersion: "1.2.3", - TargetVersion: "1.2.4", - Schedule: autoupdate.AgentsScheduleImmediate, - }, + name: "agent rollout present but AU mode is disabled", + rollout: &autoupdatev1pb.AutoUpdateAgentRolloutSpec{ + AutoupdateMode: autoupdate.AgentsUpdateModeDisabled, + Strategy: autoupdate.AgentsStrategyHaltOnError, + Schedule: autoupdate.AgentsScheduleImmediate, + StartVersion: "1.2.3", + TargetVersion: "1.2.4", }, expected: webclient.AutoUpdateSettings{ ToolsVersion: api.Version, @@ -462,6 +449,12 @@ func TestPing_autoUpdateResources(t *testing.T) { _, err = env.server.Auth().UpsertAutoUpdateVersion(ctx, version) require.NoError(t, err) } + if tc.rollout != nil { + rollout, err := autoupdate.NewAutoUpdateAgentRollout(tc.rollout) + require.NoError(t, err) + _, err = env.server.Auth().UpsertAutoUpdateAgentRollout(ctx, rollout) + require.NoError(t, err) + } // expire the fn cache to force the next answer to be fresh for _, proxy := range env.proxies { @@ -480,6 +473,7 @@ func TestPing_autoUpdateResources(t *testing.T) { if tc.cleanup { require.NotErrorIs(t, env.server.Auth().DeleteAutoUpdateConfig(ctx), &trace.NotFoundError{}) require.NotErrorIs(t, env.server.Auth().DeleteAutoUpdateVersion(ctx), &trace.NotFoundError{}) + require.NotErrorIs(t, env.server.Auth().DeleteAutoUpdateAgentRollout(ctx), &trace.NotFoundError{}) } }) } diff --git a/lib/web/autoupdate_common.go 
b/lib/web/autoupdate_common.go new file mode 100644 index 0000000000000..1756172f4c6e4 --- /dev/null +++ b/lib/web/autoupdate_common.go @@ -0,0 +1,228 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package web + +import ( + "context" + "strings" + + "github.com/gravitational/trace" + + autoupdatepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/autoupdate/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/autoupdate" + "github.com/gravitational/teleport/lib/automaticupgrades" + "github.com/gravitational/teleport/lib/utils" +) + +// autoUpdateAgentVersion returns the version the agent should install/update to based on +// its group and updater UUID. +// If the cluster contains an autoupdate_agent_rollout resource from RFD184 it should take precedence. +// If the resource is not there, we fall back to RFD109-style updates with channels +// and maintenance window derived from the cluster_maintenance_config resource. +// Version returned follows semver without the leading "v". +func (h *Handler) autoUpdateAgentVersion(ctx context.Context, group, updaterUUID string) (string, error) { + rollout, err := h.cfg.AccessPoint.GetAutoUpdateAgentRollout(ctx) + if err != nil { + // Fallback to channels if there is no autoupdate_agent_rollout. 
+ if trace.IsNotFound(err) { + return getVersionFromChannel(ctx, h.cfg.AutomaticUpgradesChannels, group) + } + // Something is broken, we don't want to fallback to channels, this would be harmful. + return "", trace.Wrap(err, "getting autoupdate_agent_rollout") + } + + return getVersionFromRollout(rollout, group, updaterUUID) +} + +// autoUpdateAgentShouldUpdate returns if the agent should update now to based on its group +// and updater UUID. +// If the cluster contains an autoupdate_agent_rollout resource from RFD184 it should take precedence. +// If the resource is not there, we fall back to RFD109-style updates with channels +// and maintenance window derived from the cluster_maintenance_config resource. +func (h *Handler) autoUpdateAgentShouldUpdate(ctx context.Context, group, updaterUUID string, windowLookup bool) (bool, error) { + rollout, err := h.cfg.AccessPoint.GetAutoUpdateAgentRollout(ctx) + if err != nil { + // Fallback to channels if there is no autoupdate_agent_rollout. + if trace.IsNotFound(err) { + // Updaters using the RFD184 API are not aware of maintenance windows + // like RFD109 updaters are. To have both updaters adopt the same behavior + // we must do the CMC window lookup for them. + if windowLookup { + return h.getTriggerFromWindowThenChannel(ctx, group) + } + return getTriggerFromChannel(ctx, h.cfg.AutomaticUpgradesChannels, group) + } + // Something is broken, we don't want to fallback to channels, this would be harmful. + return false, trace.Wrap(err, "failed to get auto-update rollout") + } + + return getTriggerFromRollout(rollout, group, updaterUUID) +} + +// getVersionFromRollout returns the version we should serve to the agent based +// on the RFD184 agent rollout, the agent group name, and its UUID. +// This logic is pretty complex and described in RFD 184. 
+// The spec is summed up in the following table: +// https://github.com/gravitational/teleport/blob/master/rfd/0184-agent-auto-updates.md#rollout-status-disabled +// Version returned follows semver without the leading "v". +func getVersionFromRollout( + rollout *autoupdatepb.AutoUpdateAgentRollout, + groupName, updaterUUID string, +) (string, error) { + switch rollout.GetSpec().GetAutoupdateMode() { + case autoupdate.AgentsUpdateModeDisabled: + // If AUs are disabled, we always answer the target version + return rollout.GetSpec().GetTargetVersion(), nil + case autoupdate.AgentsUpdateModeSuspended, autoupdate.AgentsUpdateModeEnabled: + // If AUs are enabled or suspended, we modulate the response based on the schedule and agent group state + default: + return "", trace.BadParameter("unsupported agent update mode %q", rollout.GetSpec().GetAutoupdateMode()) + } + + // If the schedule is immediate, agents always update to the latest version + if rollout.GetSpec().GetSchedule() == autoupdate.AgentsScheduleImmediate { + return rollout.GetSpec().GetTargetVersion(), nil + } + + // Else we follow the regular schedule and answer based on the agent group state + group, err := getGroup(rollout, groupName) + if err != nil { + return "", trace.Wrap(err, "getting group %q", groupName) + } + + switch group.GetState() { + case autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED, + autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: + return rollout.GetSpec().GetStartVersion(), nil + case autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE, + autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE: + return rollout.GetSpec().GetTargetVersion(), nil + default: + return "", trace.NotImplemented("unsupported group state %q", group.GetState()) + } +} + +// getTriggerFromRollout returns the version we should serve to the agent based +// on the RFD184 agent rollout, the agent group name, and 
its UUID. +// This logic is pretty complex and described in RFD 184. +// The spec is summed up in the following table: +// https://github.com/gravitational/teleport/blob/master/rfd/0184-agent-auto-updates.md#rollout-status-disabled +func getTriggerFromRollout(rollout *autoupdatepb.AutoUpdateAgentRollout, groupName, updaterUUID string) (bool, error) { + // If the mode is "paused" or "disabled", we never tell to update + switch rollout.GetSpec().GetAutoupdateMode() { + case autoupdate.AgentsUpdateModeDisabled, autoupdate.AgentsUpdateModeSuspended: + // If AUs are disabled or suspended, never tell to update + return false, nil + case autoupdate.AgentsUpdateModeEnabled: + // If AUs are enabled, we modulate the response based on the schedule and agent group state + default: + return false, trace.BadParameter("unsupported agent update mode %q", rollout.GetSpec().GetAutoupdateMode()) + } + + // If the schedule is immediate, agents always update to the latest version + if rollout.GetSpec().GetSchedule() == autoupdate.AgentsScheduleImmediate { + return true, nil + } + + // Else we follow the regular schedule and answer based on the agent group state + group, err := getGroup(rollout, groupName) + if err != nil { + return false, trace.Wrap(err, "getting group %q", groupName) + } + + switch group.GetState() { + case autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED: + return false, nil + case autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE, + autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: + return true, nil + case autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE: + return rollout.GetSpec().GetStrategy() == autoupdate.AgentsStrategyHaltOnError, nil + default: + return false, trace.NotImplemented("Unsupported group state %q", group.GetState()) + } +} + +// getGroup returns the agent rollout group the requesting agent belongs to. 
+// If a group matches the agent-provided group name, this group is returned. +// Else the default group is returned. The default group currently is the last +// one. This might change in the future. +func getGroup( + rollout *autoupdatepb.AutoUpdateAgentRollout, + groupName string, +) (*autoupdatepb.AutoUpdateAgentRolloutStatusGroup, error) { + groups := rollout.GetStatus().GetGroups() + if len(groups) == 0 { + return nil, trace.BadParameter("no groups found") + } + + // Try to find a group with our name + for _, group := range groups { + if group.Name == groupName { + return group, nil + } + } + + // Fallback to the default group (currently the last one but this might change). + return groups[len(groups)-1], nil +} + +// getVersionFromChannel gets the target version from the RFD109 channels. +// Version returned follows semver without the leading "v". +func getVersionFromChannel(ctx context.Context, channels automaticupgrades.Channels, groupName string) (version string, err error) { + // RFD109 channels return the version with the 'v' prefix. + // We can't change the internals for backward compatibility, so we must trim the prefix if it's here. + defer func() { + version = strings.TrimPrefix(version, "v") + }() + + if channel, ok := channels[groupName]; ok { + return channel.GetVersion(ctx) + } + return channels.DefaultVersion(ctx) +} + +// getTriggerFromWindowThenChannel gets the target version from the RFD109 maintenance window and channels. +func (h *Handler) getTriggerFromWindowThenChannel(ctx context.Context, groupName string) (bool, error) { + // Caching the CMC for 10 seconds because this resource is cached neither by the auth nor the proxy. + // And this function can be accessed via unauthenticated endpoints. 
+ cmc, err := utils.FnCacheGet[types.ClusterMaintenanceConfig](ctx, h.clusterMaintenanceConfigCache, "cmc", func(ctx context.Context) (types.ClusterMaintenanceConfig, error) { + return h.cfg.ProxyClient.GetClusterMaintenanceConfig(ctx) + }) + + // If we have a CMC, we check if the window is active, else we just check if the update is critical. + if err == nil && cmc.WithinUpgradeWindow(h.clock.Now()) { + return true, nil + } + + return getTriggerFromChannel(ctx, h.cfg.AutomaticUpgradesChannels, groupName) +} + +// getTriggerFromWindowThenChannel gets the target version from the RFD109 channels. +func getTriggerFromChannel(ctx context.Context, channels automaticupgrades.Channels, groupName string) (bool, error) { + if channel, ok := channels[groupName]; ok { + return channel.GetCritical(ctx) + } + defaultChannel, err := channels.DefaultChannel() + if err != nil { + return false, trace.Wrap(err, "creating new default channel") + } + return defaultChannel.GetCritical(ctx) +} diff --git a/lib/web/autoupdate_common_test.go b/lib/web/autoupdate_common_test.go new file mode 100644 index 0000000000000..a365ac121b078 --- /dev/null +++ b/lib/web/autoupdate_common_test.go @@ -0,0 +1,796 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . 
+ */ + +package web + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + autoupdatepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/autoupdate/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/autoupdate" + "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/automaticupgrades" + "github.com/gravitational/teleport/lib/automaticupgrades/constants" + "github.com/gravitational/teleport/lib/utils" +) + +const ( + testVersionHigh = "2.3.4" + testVersionLow = "2.0.4" +) + +// fakeRolloutAccessPoint allows us to mock the ProxyAccessPoint in autoupdate +// tests. +type fakeRolloutAccessPoint struct { + authclient.ProxyAccessPoint + + rollout *autoupdatepb.AutoUpdateAgentRollout + err error +} + +func (ap *fakeRolloutAccessPoint) GetAutoUpdateAgentRollout(_ context.Context) (*autoupdatepb.AutoUpdateAgentRollout, error) { + return ap.rollout, ap.err +} + +// fakeRolloutAccessPoint allows us to mock the proxy's auth client in autoupdate +// tests. +type fakeCMCAuthClient struct { + authclient.ClientI + + cmc types.ClusterMaintenanceConfig + err error +} + +func (c *fakeCMCAuthClient) GetClusterMaintenanceConfig(_ context.Context) (types.ClusterMaintenanceConfig, error) { + return c.cmc, c.err +} + +func TestAutoUpdateAgentVersion(t *testing.T) { + t.Parallel() + groupName := "test-group" + ctx := context.Background() + + // brokenChannelUpstream is a buggy upstream version server. + // This allows us to craft version channels returning errors. 
+ brokenChannelUpstream := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + })) + t.Cleanup(brokenChannelUpstream.Close) + + tests := []struct { + name string + rollout *autoupdatepb.AutoUpdateAgentRollout + rolloutErr error + channel *automaticupgrades.Channel + expectedVersion string + expectError require.ErrorAssertionFunc + }{ + { + name: "version is looked up from rollout if it is here", + rollout: &autoupdatepb.AutoUpdateAgentRollout{ + Spec: &autoupdatepb.AutoUpdateAgentRolloutSpec{ + AutoupdateMode: autoupdate.AgentsUpdateModeEnabled, + TargetVersion: testVersionHigh, + Schedule: autoupdate.AgentsScheduleImmediate, + }, + }, + channel: &automaticupgrades.Channel{StaticVersion: testVersionLow}, + expectError: require.NoError, + expectedVersion: testVersionHigh, + }, + { + name: "version is looked up from channel if rollout is not here", + rolloutErr: trace.NotFound("rollout is not here"), + channel: &automaticupgrades.Channel{StaticVersion: testVersionLow}, + expectError: require.NoError, + expectedVersion: testVersionLow, + }, + { + name: "hard error getting rollout should not fallback to version channels", + rolloutErr: trace.AccessDenied("something is very broken"), + channel: &automaticupgrades.Channel{ + StaticVersion: testVersionLow, + }, + expectError: require.Error, + }, + { + name: "no rollout, error checking channel", + rolloutErr: trace.NotFound("rollout is not here"), + channel: &automaticupgrades.Channel{ForwardURL: brokenChannelUpstream.URL}, + expectError: require.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test setup: building the channel, mock client, and handler with test config. 
+ require.NoError(t, tt.channel.CheckAndSetDefaults()) + h := &Handler{ + cfg: Config{ + AccessPoint: &fakeRolloutAccessPoint{ + rollout: tt.rollout, + err: tt.rolloutErr, + }, + AutomaticUpgradesChannels: map[string]*automaticupgrades.Channel{ + groupName: tt.channel, + }, + }, + } + + // Test execution + result, err := h.autoUpdateAgentVersion(ctx, groupName, "") + tt.expectError(t, err) + require.Equal(t, tt.expectedVersion, result) + }) + } +} + +// TestAutoUpdateAgentShouldUpdate also accidentally tests getTriggerFromWindowThenChannel. +func TestAutoUpdateAgentShouldUpdate(t *testing.T) { + t.Parallel() + + groupName := "test-group" + ctx := context.Background() + + // brokenChannelUpstream is a buggy upstream version server. + // This allows us to craft version channels returning errors. + brokenChannelUpstream := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + })) + t.Cleanup(brokenChannelUpstream.Close) + + clock := clockwork.NewFakeClock() + cmcCache, err := utils.NewFnCache(utils.FnCacheConfig{ + TTL: findEndpointCacheTTL, + Clock: clock, + Context: ctx, + ReloadOnErr: false, + }) + require.NoError(t, err) + t.Cleanup(func() { + cmcCache.Shutdown(ctx) + }) + + activeUpgradeWindow := types.AgentUpgradeWindow{UTCStartHour: uint32(clock.Now().Hour())} + inactiveUpgradeWindow := types.AgentUpgradeWindow{UTCStartHour: uint32(clock.Now().Add(2 * time.Hour).Hour())} + tests := []struct { + name string + rollout *autoupdatepb.AutoUpdateAgentRollout + rolloutErr error + channel *automaticupgrades.Channel + upgradeWindow types.AgentUpgradeWindow + cmcErr error + windowLookup bool + expectedTrigger bool + expectError require.ErrorAssertionFunc + }{ + { + name: "trigger is looked up from rollout if it is here, trigger firing", + rollout: &autoupdatepb.AutoUpdateAgentRollout{ + Spec: &autoupdatepb.AutoUpdateAgentRolloutSpec{ + AutoupdateMode: autoupdate.AgentsUpdateModeEnabled, + 
TargetVersion: testVersionHigh, + Schedule: autoupdate.AgentsScheduleImmediate, + }, + }, + channel: &automaticupgrades.Channel{StaticVersion: testVersionLow}, + expectError: require.NoError, + expectedTrigger: true, + }, + { + name: "trigger is looked up from rollout if it is here, trigger not firing", + rollout: &autoupdatepb.AutoUpdateAgentRollout{ + Spec: &autoupdatepb.AutoUpdateAgentRolloutSpec{ + AutoupdateMode: autoupdate.AgentsUpdateModeDisabled, + TargetVersion: testVersionHigh, + Schedule: autoupdate.AgentsScheduleImmediate, + }, + }, + channel: &automaticupgrades.Channel{StaticVersion: testVersionLow}, + expectError: require.NoError, + expectedTrigger: false, + }, + { + name: "trigger is looked up from channel if rollout is not here and window lookup is disabled, trigger not firing", + rolloutErr: trace.NotFound("rollout is not here"), + channel: &automaticupgrades.Channel{ + StaticVersion: testVersionLow, + Critical: false, + }, + expectError: require.NoError, + expectedTrigger: false, + }, + { + name: "trigger is looked up from channel if rollout is not here and window lookup is disabled, trigger firing", + rolloutErr: trace.NotFound("rollout is not here"), + channel: &automaticupgrades.Channel{ + StaticVersion: testVersionLow, + Critical: true, + }, + expectError: require.NoError, + expectedTrigger: true, + }, + { + name: "trigger is looked up from cmc, then channel if rollout is not here and window lookup is enabled, cmc firing", + rolloutErr: trace.NotFound("rollout is not here"), + channel: &automaticupgrades.Channel{ + StaticVersion: testVersionLow, + Critical: false, + }, + upgradeWindow: activeUpgradeWindow, + windowLookup: true, + expectError: require.NoError, + expectedTrigger: true, + }, + { + name: "trigger is looked up from cmc, then channel if rollout is not here and window lookup is enabled, cmc not firing", + rolloutErr: trace.NotFound("rollout is not here"), + channel: &automaticupgrades.Channel{ + StaticVersion: testVersionLow, + 
Critical: false, + }, + upgradeWindow: inactiveUpgradeWindow, + windowLookup: true, + expectError: require.NoError, + expectedTrigger: false, + }, + { + name: "trigger is looked up from cmc, then channel if rollout is not here and window lookup is enabled, cmc not firing but channel firing", + rolloutErr: trace.NotFound("rollout is not here"), + channel: &automaticupgrades.Channel{ + StaticVersion: testVersionLow, + Critical: true, + }, + upgradeWindow: inactiveUpgradeWindow, + windowLookup: true, + expectError: require.NoError, + expectedTrigger: true, + }, + { + name: "trigger is looked up from cmc, then channel if rollout is not here and window lookup is enabled, no cmc and channel not firing", + rolloutErr: trace.NotFound("rollout is not here"), + channel: &automaticupgrades.Channel{ + StaticVersion: testVersionLow, + Critical: false, + }, + cmcErr: trace.NotFound("no cmc for this cluster"), + windowLookup: true, + expectError: require.NoError, + expectedTrigger: false, + }, + { + name: "trigger is looked up from cmc, then channel if rollout is not here and window lookup is enabled, no cmc and channel firing", + rolloutErr: trace.NotFound("rollout is not here"), + channel: &automaticupgrades.Channel{ + StaticVersion: testVersionLow, + Critical: true, + }, + cmcErr: trace.NotFound("no cmc for this cluster"), + windowLookup: true, + expectError: require.NoError, + expectedTrigger: true, + }, + { + name: "hard error getting rollout should not fallback to RFD109 trigger", + rolloutErr: trace.AccessDenied("something is very broken"), + channel: &automaticupgrades.Channel{ + StaticVersion: testVersionLow, + }, + expectError: require.Error, + }, + { + name: "no rollout, error checking channel", + rolloutErr: trace.NotFound("rollout is not here"), + channel: &automaticupgrades.Channel{ + ForwardURL: brokenChannelUpstream.URL, + }, + expectError: require.Error, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test setup: building the 
channel, mock clients, and handler with test config.
+			cmc := types.NewClusterMaintenanceConfig()
+			cmc.SetAgentUpgradeWindow(tt.upgradeWindow)
+			require.NoError(t, tt.channel.CheckAndSetDefaults())
+			// Advance clock to invalidate cache
+			clock.Advance(2 * findEndpointCacheTTL)
+			h := &Handler{
+				cfg: Config{
+					AccessPoint: &fakeRolloutAccessPoint{
+						rollout: tt.rollout,
+						err:     tt.rolloutErr,
+					},
+					ProxyClient: &fakeCMCAuthClient{
+						cmc: cmc,
+						err: tt.cmcErr,
+					},
+					AutomaticUpgradesChannels: map[string]*automaticupgrades.Channel{
+						groupName: tt.channel,
+					},
+				},
+				clock:                         clock,
+				clusterMaintenanceConfigCache: cmcCache,
+			}
+
+			// Test execution
+			result, err := h.autoUpdateAgentShouldUpdate(ctx, groupName, "", tt.windowLookup)
+			tt.expectError(t, err)
+			require.Equal(t, tt.expectedTrigger, result)
+		})
+	}
+}
+
+func TestGetVersionFromRollout(t *testing.T) {
+	t.Parallel()
+	groupName := "test-group"
+
+	// This test matrix is written based on:
+	// https://github.com/gravitational/teleport/blob/master/rfd/0184-agent-auto-updates.md#rollout-status-disabled
+	latestAllTheTime := map[autoupdatepb.AutoUpdateAgentGroupState]string{
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED:  testVersionHigh,
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE:       testVersionHigh,
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE:     testVersionHigh,
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: testVersionHigh,
+	}
+
+	activeDoneOnly := map[autoupdatepb.AutoUpdateAgentGroupState]string{
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED:  testVersionLow,
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE:       testVersionHigh,
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE:     testVersionHigh,
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: testVersionLow,
+	}
+
+	tests := map[string]map[string]map[autoupdatepb.AutoUpdateAgentGroupState]string{
+		autoupdate.AgentsUpdateModeDisabled: {
+			autoupdate.AgentsScheduleImmediate: latestAllTheTime,
+			autoupdate.AgentsScheduleRegular:   latestAllTheTime,
+		},
+		autoupdate.AgentsUpdateModeSuspended: {
+			autoupdate.AgentsScheduleImmediate: latestAllTheTime,
+			autoupdate.AgentsScheduleRegular:   activeDoneOnly,
+		},
+		autoupdate.AgentsUpdateModeEnabled: {
+			autoupdate.AgentsScheduleImmediate: latestAllTheTime,
+			autoupdate.AgentsScheduleRegular:   activeDoneOnly,
+		},
+	}
+	for mode, scheduleCases := range tests {
+		for schedule, stateCases := range scheduleCases {
+			for state, expectedVersion := range stateCases {
+				t.Run(fmt.Sprintf("%s/%s/%s", mode, schedule, state), func(t *testing.T) {
+					rollout := &autoupdatepb.AutoUpdateAgentRollout{
+						Spec: &autoupdatepb.AutoUpdateAgentRolloutSpec{
+							StartVersion:   testVersionLow,
+							TargetVersion:  testVersionHigh,
+							Schedule:       schedule,
+							AutoupdateMode: mode,
+							// Strategy does not affect which version is served
+							Strategy: autoupdate.AgentsStrategyTimeBased,
+						},
+						Status: &autoupdatepb.AutoUpdateAgentRolloutStatus{
+							Groups: []*autoupdatepb.AutoUpdateAgentRolloutStatusGroup{
+								{
+									Name:  groupName,
+									State: state,
+								},
+							},
+						},
+					}
+					version, err := getVersionFromRollout(rollout, groupName, "")
+					require.NoError(t, err)
+					require.Equal(t, expectedVersion, version)
+				})
+			}
+		}
+	}
+}
+
+func TestGetTriggerFromRollout(t *testing.T) {
+	t.Parallel()
+	groupName := "test-group"
+
+	// This test matrix is written based on:
+	// https://github.com/gravitational/teleport/blob/master/rfd/0184-agent-auto-updates.md#rollout-status-disabled
+	neverUpdate := map[autoupdatepb.AutoUpdateAgentGroupState]bool{
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED:  false,
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE:       false,
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE:     false,
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: false,
+	}
+	alwaysUpdate := map[autoupdatepb.AutoUpdateAgentGroupState]bool{
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED:  true,
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE:       true,
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE:     true,
+		autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: true,
+	}
+
+	tests := map[string]map[string]map[string]map[autoupdatepb.AutoUpdateAgentGroupState]bool{
+		autoupdate.AgentsUpdateModeDisabled: {
+			autoupdate.AgentsStrategyTimeBased: {
+				autoupdate.AgentsScheduleImmediate: neverUpdate,
+				autoupdate.AgentsScheduleRegular:   neverUpdate,
+			},
+			autoupdate.AgentsStrategyHaltOnError: {
+				autoupdate.AgentsScheduleImmediate: neverUpdate,
+				autoupdate.AgentsScheduleRegular:   neverUpdate,
+			},
+		},
+		autoupdate.AgentsUpdateModeSuspended: {
+			autoupdate.AgentsStrategyTimeBased: {
+				autoupdate.AgentsScheduleImmediate: neverUpdate,
+				autoupdate.AgentsScheduleRegular:   neverUpdate,
+			},
+			autoupdate.AgentsStrategyHaltOnError: {
+				autoupdate.AgentsScheduleImmediate: neverUpdate,
+				autoupdate.AgentsScheduleRegular:   neverUpdate,
+			},
+		},
+		autoupdate.AgentsUpdateModeEnabled: {
+			autoupdate.AgentsStrategyTimeBased: {
+				autoupdate.AgentsScheduleImmediate: alwaysUpdate,
+				autoupdate.AgentsScheduleRegular: {
+					autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED:  false,
+					autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE:       false,
+					autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE:     true,
+					autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: true,
+				},
+			},
+			autoupdate.AgentsStrategyHaltOnError: {
+				autoupdate.AgentsScheduleImmediate: alwaysUpdate,
+				autoupdate.AgentsScheduleRegular: {
+					autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED:  false,
+					autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE:       true,
+					autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE:     true,
+					autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: true,
+				},
+			},
+		},
+	}
+	for mode, strategyCases := range tests {
+		for strategy, scheduleCases := range strategyCases {
+			for schedule, stateCases := range scheduleCases {
+				for state, expectedTrigger := range stateCases {
+					t.Run(fmt.Sprintf("%s/%s/%s/%s", mode, strategy, schedule, state), func(t *testing.T) {
+						rollout := &autoupdatepb.AutoUpdateAgentRollout{
+							Spec: &autoupdatepb.AutoUpdateAgentRolloutSpec{
+								StartVersion:   testVersionLow,
+								TargetVersion:  testVersionHigh,
+								Schedule:       schedule,
+								AutoupdateMode: mode,
+								Strategy:       strategy,
+							},
+							Status: &autoupdatepb.AutoUpdateAgentRolloutStatus{
+								Groups: []*autoupdatepb.AutoUpdateAgentRolloutStatusGroup{
+									{
+										Name:  groupName,
+										State: state,
+									},
+								},
+							},
+						}
+						shouldUpdate, err := getTriggerFromRollout(rollout, groupName, "")
+						require.NoError(t, err)
+						require.Equal(t, expectedTrigger, shouldUpdate)
+					})
+				}
+			}
+		}
+	}
+}
+
+func TestGetGroup(t *testing.T) {
+	groupName := "test-group"
+	t.Parallel()
+	tests := []struct {
+		name           string
+		rollout        *autoupdatepb.AutoUpdateAgentRollout
+		expectedResult *autoupdatepb.AutoUpdateAgentRolloutStatusGroup
+		expectError    require.ErrorAssertionFunc
+	}{
+		{
+			name:        "nil",
+			expectError: require.Error,
+		},
+		{
+			name:        "nil status",
+			rollout:     &autoupdatepb.AutoUpdateAgentRollout{},
+			expectError: require.Error,
+		},
+		{
+			name:        "nil status groups",
+			rollout:     &autoupdatepb.AutoUpdateAgentRollout{Status: &autoupdatepb.AutoUpdateAgentRolloutStatus{}},
+			expectError: require.Error,
+		},
+		{
+			name: "empty status groups",
+			rollout: &autoupdatepb.AutoUpdateAgentRollout{
+				Status: &autoupdatepb.AutoUpdateAgentRolloutStatus{
+					Groups: []*autoupdatepb.AutoUpdateAgentRolloutStatusGroup{},
+				},
+			},
+			expectError: require.Error,
+		},
+		{
+			name: "group matching name",
+			rollout: &autoupdatepb.AutoUpdateAgentRollout{
+				Status: &autoupdatepb.AutoUpdateAgentRolloutStatus{
+					Groups: []*autoupdatepb.AutoUpdateAgentRolloutStatusGroup{
+						{Name: "foo", State: 1},
+						{Name: "bar", State: 1},
+						{Name: groupName, State: 2},
+						{Name: "baz", State: 1},
+					},
+				},
+			},
+			expectedResult: &autoupdatepb.AutoUpdateAgentRolloutStatusGroup{
+				Name:  groupName,
+				State: 2,
+			},
+			expectError: require.NoError,
+		},
+		{
+			name: "no group matching name, should fallback to default",
+			rollout: &autoupdatepb.AutoUpdateAgentRollout{
+				Status: &autoupdatepb.AutoUpdateAgentRolloutStatus{
+					Groups: []*autoupdatepb.AutoUpdateAgentRolloutStatusGroup{
+						{Name: "foo", State: 1},
+						{Name: "bar", State: 1},
+						{Name: "baz", State: 1},
+					},
+				},
+			},
+			expectedResult: &autoupdatepb.AutoUpdateAgentRolloutStatusGroup{
+				Name:  "baz",
+				State: 1,
+			},
+			expectError: require.NoError,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result, err := getGroup(tt.rollout, groupName)
+			tt.expectError(t, err)
+			require.Equal(t, tt.expectedResult, result)
+		})
+	}
+}
+
+type mockRFD109VersionServer struct {
+	t        *testing.T
+	channels map[string]channelStub
+}
+
+type channelStub struct {
+	// with or without the leading "v"
+	version  string
+	critical bool
+	fail     bool
+}
+
+func (m *mockRFD109VersionServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	var path string
+	var writeResp func(w http.ResponseWriter, stub channelStub) error
+
+	switch {
+	case strings.HasSuffix(r.URL.Path, constants.VersionPath):
+		path = strings.Trim(strings.TrimSuffix(r.URL.Path, constants.VersionPath), "/")
+		writeResp = func(w http.ResponseWriter, stub channelStub) error {
+			_, err := w.Write([]byte(stub.version))
+			return err
+		}
+	case strings.HasSuffix(r.URL.Path, constants.MaintenancePath):
+		path = strings.Trim(strings.TrimSuffix(r.URL.Path, constants.MaintenancePath), "/")
+		writeResp = func(w http.ResponseWriter, stub channelStub) error {
+			response := "no"
+			if stub.critical {
+				response = "yes"
+			}
+			_, err := w.Write([]byte(response))
+			return err
+		}
+	default:
+		assert.Fail(m.t, "unsupported path %q", r.URL.Path)
+		w.WriteHeader(http.StatusNotFound)
+		return
+	}
+
+	channel, ok := m.channels[path]
+	if !ok {
+		w.WriteHeader(http.StatusNotFound)
+		assert.Fail(m.t, "channel %q not found", path)
+		return
+	}
+	if channel.fail {
+		w.WriteHeader(http.StatusInternalServerError)
+		return
+	}
+	assert.NoError(m.t, writeResp(w, channel), "failed to write response")
+}
+
+func TestGetVersionFromChannel(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	channelName := "test-channel"
+
+	mock := mockRFD109VersionServer{
+		t: t,
+		channels: map[string]channelStub{
+			"broken":            {fail: true},
+			"with-leading-v":    {version: "v" + testVersionHigh},
+			"without-leading-v": {version: testVersionHigh},
+			"low":               {version: testVersionLow},
+		},
+	}
+	srv := httptest.NewServer(http.HandlerFunc(mock.ServeHTTP))
+	t.Cleanup(srv.Close)
+
+	tests := []struct {
+		name           string
+		channels       automaticupgrades.Channels
+		expectedResult string
+		expectError    require.ErrorAssertionFunc
+	}{
+		{
+			name: "channel with leading v",
+			channels: automaticupgrades.Channels{
+				channelName: {ForwardURL: srv.URL + "/with-leading-v"},
+				"default":   {ForwardURL: srv.URL + "/low"},
+			},
+			expectedResult: testVersionHigh,
+			expectError:    require.NoError,
+		},
+		{
+			name: "channel without leading v",
+			channels: automaticupgrades.Channels{
+				channelName: {ForwardURL: srv.URL + "/without-leading-v"},
+				"default":   {ForwardURL: srv.URL + "/low"},
+			},
+			expectedResult: testVersionHigh,
+			expectError:    require.NoError,
+		},
+		{
+			name: "fallback to default with leading v",
+			channels: automaticupgrades.Channels{
+				"default": {ForwardURL: srv.URL + "/with-leading-v"},
+			},
+			expectedResult: testVersionHigh,
+			expectError:    require.NoError,
+		},
+		{
+			name: "fallback to default without leading v",
+			channels: automaticupgrades.Channels{
+				"default": {ForwardURL: srv.URL + "/without-leading-v"},
+			},
+			expectedResult: testVersionHigh,
+			expectError:    require.NoError,
+		},
+		{
+			name: "broken channel",
+			channels: automaticupgrades.Channels{
+				channelName: {ForwardURL: srv.URL + "/broken"},
+				"default":   {ForwardURL: srv.URL + "/without-leading-v"},
+			},
+			expectError: require.Error,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Test setup
+			require.NoError(t, tt.channels.CheckAndSetDefaults())
+
+			// Test execution
+			result, err := getVersionFromChannel(ctx, tt.channels, channelName)
+			tt.expectError(t, err)
+			require.Equal(t, tt.expectedResult, result)
+		})
+	}
+}
+
+func TestGetTriggerFromChannel(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	channelName := "test-channel"
+
+	mock := mockRFD109VersionServer{
+		t: t,
+		channels: map[string]channelStub{
+			"broken":       {fail: true},
+			"critical":     {critical: true},
+			"non-critical": {critical: false},
+		},
+	}
+	srv := httptest.NewServer(http.HandlerFunc(mock.ServeHTTP))
+	t.Cleanup(srv.Close)
+
+	tests := []struct {
+		name           string
+		channels       automaticupgrades.Channels
+		expectedResult bool
+		expectError    require.ErrorAssertionFunc
+	}{
+		{
+			name: "critical channel",
+			channels: automaticupgrades.Channels{
+				channelName: {ForwardURL: srv.URL + "/critical"},
+				"default":   {ForwardURL: srv.URL + "/non-critical"},
+			},
+			expectedResult: true,
+			expectError:    require.NoError,
+		},
+		{
+			name: "non-critical channel",
+			channels: automaticupgrades.Channels{
+				channelName: {ForwardURL: srv.URL + "/non-critical"},
+				"default":   {ForwardURL: srv.URL + "/critical"},
+			},
+			expectedResult: false,
+			expectError:    require.NoError,
+		},
+		{
+			name: "fallback to default which is critical",
+			channels: automaticupgrades.Channels{
+				"default": {ForwardURL: srv.URL + "/critical"},
+			},
+			expectedResult: true,
+			expectError:    require.NoError,
+		},
+		{
+			name: "fallback to default which is non-critical",
+			channels: automaticupgrades.Channels{
+				"default": {ForwardURL: srv.URL + "/non-critical"},
+			},
+			expectedResult: false,
+			expectError:    require.NoError,
+		},
+		{
+			name: "broken channel",
+			channels: automaticupgrades.Channels{
+				channelName: {ForwardURL: srv.URL + "/broken"},
+				"default":   {ForwardURL: srv.URL + "/critical"},
+			},
+			expectError: require.Error,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Test setup
+			require.NoError(t, tt.channels.CheckAndSetDefaults())
+
+			// Test execution
+			result, err := getTriggerFromChannel(ctx, tt.channels, channelName)
+			tt.expectError(t, err)
+			require.Equal(t, tt.expectedResult, result)
+		})
+	}
+}
diff --git a/lib/web/autoupdate_rfd109.go b/lib/web/autoupdate_rfd109.go
index b09b151220754..3bbdd0175b106 100644
--- a/lib/web/autoupdate_rfd109.go
+++ b/lib/web/autoupdate_rfd109.go
@@ -21,6 +21,7 @@ package web
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net/http"
 	"strings"
 	"time"
@@ -28,7 +29,6 @@ import (
 	"github.com/gravitational/trace"
 	"github.com/julienschmidt/httprouter"
 
-	"github.com/gravitational/teleport/lib/automaticupgrades"
 	"github.com/gravitational/teleport/lib/automaticupgrades/constants"
 	"github.com/gravitational/teleport/lib/automaticupgrades/version"
 )
@@ -59,31 +59,25 @@ func (h *Handler) automaticUpgrades109(w http.ResponseWriter, r *http.Request, p
 		return nil, trace.BadParameter("a channel name is required")
 	}
 
-	// We check if the channel is configured
-	channel, ok := h.cfg.AutomaticUpgradesChannels[channelName]
-	if !ok {
-		return nil, trace.NotFound("channel %s not found", channelName)
-	}
-
 	// Finally, we treat the request based on its type
 	switch requestType {
 	case "version":
 		h.log.Debugf("Agent requesting version for channel %s", channelName)
-		return h.automaticUpgradesVersion109(w, r, channel)
+		return h.automaticUpgradesVersion109(w, r, channelName)
 	case "critical":
 		h.log.Debugf("Agent requesting criticality for channel %s", channelName)
-		return h.automaticUpgradesCritical109(w, r, channel)
+		return h.automaticUpgradesCritical109(w, r, channelName)
 	default:
 		return nil, trace.BadParameter("requestType path must end with 'version' or 'critical'")
 	}
 }
 
 // automaticUpgradesVersion109 handles version requests from upgraders
-func (h *Handler) automaticUpgradesVersion109(w http.ResponseWriter, r *http.Request, channel *automaticupgrades.Channel) (interface{}, error) {
+func (h *Handler) automaticUpgradesVersion109(w http.ResponseWriter, r *http.Request, channelName string) (interface{}, error) {
 	ctx, cancel := context.WithTimeout(r.Context(), defaultChannelTimeout)
 	defer cancel()
 
-	targetVersion, err := channel.GetVersion(ctx)
+	targetVersion, err := h.autoUpdateAgentVersion(ctx, channelName, "" /* updater UUID */)
 	if err != nil {
 		// If the error is that the upstream channel has no version
 		// We gracefully handle by serving "none"
@@ -96,16 +90,20 @@ func (h *Handler) automaticUpgradesVersion109(w http.ResponseWriter, r *http.Req
 		return nil, trace.Wrap(err)
 	}
 
-	_, err = w.Write([]byte(targetVersion))
+	// RFD 109 specifies that version from channels must have the leading "v".
+	// As h.autoUpdateAgentVersion doesn't, we must add it.
+	_, err = fmt.Fprintf(w, "v%s", targetVersion)
 	return nil, trace.Wrap(err)
 }
 
 // automaticUpgradesCritical109 handles criticality requests from upgraders
-func (h *Handler) automaticUpgradesCritical109(w http.ResponseWriter, r *http.Request, channel *automaticupgrades.Channel) (interface{}, error) {
+func (h *Handler) automaticUpgradesCritical109(w http.ResponseWriter, r *http.Request, channelName string) (interface{}, error) {
 	ctx, cancel := context.WithTimeout(r.Context(), defaultChannelTimeout)
 	defer cancel()
 
-	critical, err := channel.GetCritical(ctx)
+	// RFD109 agents already retrieve maintenance windows from the CMC, no need to
+	// do a maintenance window lookup for them.
+	critical, err := h.autoUpdateAgentShouldUpdate(ctx, channelName, "" /* updater UUID */, false /* window lookup */)
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
diff --git a/lib/web/autoupdate_rfd184.go b/lib/web/autoupdate_rfd184.go
index 4c3ccdeaef907..6ac532650cb64 100644
--- a/lib/web/autoupdate_rfd184.go
+++ b/lib/web/autoupdate_rfd184.go
@@ -23,6 +23,7 @@ import (
 
 	"github.com/gravitational/trace"
 
+	"github.com/gravitational/teleport"
 	"github.com/gravitational/teleport/api"
 	"github.com/gravitational/teleport/api/client/webclient"
 	autoupdatepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/autoupdate/v1"
@@ -31,8 +32,8 @@ import (
 
 // automaticUpdateSettings184 crafts the automatic updates part of the ping/find response
 // as described in RFD-184 (agents) and RFD-144 (tools).
-// TODO: add the request as a parameter when we'll need to modulate the content based on the UUID and group
-func (h *Handler) automaticUpdateSettings184(ctx context.Context) webclient.AutoUpdateSettings {
+func (h *Handler) automaticUpdateSettings184(ctx context.Context, group, updaterUUID string) webclient.AutoUpdateSettings {
+	// Tools auto updates section.
 	autoUpdateConfig, err := h.cfg.AccessPoint.GetAutoUpdateConfig(ctx)
 	// TODO(vapopov) DELETE IN v18.0.0 check of IsNotImplemented, must be backported to all latest supported versions.
 	if err != nil && !trace.IsNotFound(err) && !trace.IsNotImplemented(err) {
@@ -45,12 +46,29 @@ func (h *Handler) automaticUpdateSettings184(ctx context.Context) webclient.Auto
 		h.logger.ErrorContext(ctx, "failed to receive AutoUpdateVersion", "error", err)
 	}
 
+	// Agent auto updates section.
+	agentVersion, err := h.autoUpdateAgentVersion(ctx, group, updaterUUID)
+	if err != nil {
+		h.logger.ErrorContext(ctx, "failed to resolve AgentVersion", "error", err)
+		// Defaulting to current version
+		agentVersion = teleport.Version
+	}
+	// If the source of truth is RFD 109 configuration (channels + CMC) we must emulate the
+	// RFD109 agent maintenance window behavior by looking up the CMC and checking if
+	// we are in a maintenance window.
+	shouldUpdate, err := h.autoUpdateAgentShouldUpdate(ctx, group, updaterUUID, true /* window lookup */)
+	if err != nil {
+		h.logger.ErrorContext(ctx, "failed to resolve AgentAutoUpdate", "error", err)
+		// Failing open
+		shouldUpdate = false
+	}
+
 	return webclient.AutoUpdateSettings{
 		ToolsAutoUpdate:          getToolsAutoUpdate(autoUpdateConfig),
 		ToolsVersion:             getToolsVersion(autoUpdateVersion),
 		AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds,
-		AgentVersion:             getAgentVersion184(autoUpdateVersion),
-		AgentAutoUpdate:          agentShouldUpdate184(autoUpdateConfig, autoUpdateVersion),
+		AgentVersion:             agentVersion,
+		AgentAutoUpdate:          shouldUpdate,
 	}
 }
@@ -73,39 +91,3 @@ func getToolsVersion(version *autoupdatepb.AutoUpdateVersion) string {
 	}
 	return version.GetSpec().GetTools().GetTargetVersion()
 }
-
-func getAgentVersion184(version *autoupdatepb.AutoUpdateVersion) string {
-	// If we can't get the AU version or tools AU version is not specified, we default to the current proxy version.
-	// This ensures we always advertise a version compatible with the cluster.
-	// TODO: read the version from the autoupdate_agent_rollout when the resource is implemented
-	if version.GetSpec().GetAgents() == nil {
-		return api.Version
-	}
-
-	return version.GetSpec().GetAgents().GetTargetVersion()
-}
-
-func agentShouldUpdate184(config *autoupdatepb.AutoUpdateConfig, version *autoupdatepb.AutoUpdateVersion) bool {
-	// TODO: read the data from the autoupdate_agent_rollout when the resource is implemented
-
-	// If we can't get the AU config or if AUs are not configured, we default to "disabled".
-	// This ensures we fail open and don't accidentally update agents if something is going wrong.
-	// If we want to enable AUs by default, it would be better to create a default "autoupdate_config" resource
-	// than changing this logic.
-	if config.GetSpec().GetAgents() == nil {
-		return false
-	}
-	if version.GetSpec().GetAgents() == nil {
-		return false
-	}
-	configMode := config.GetSpec().GetAgents().GetMode()
-	versionMode := version.GetSpec().GetAgents().GetMode()
-
-	// We update only if both version and config agent modes are "enabled"
-	if configMode != autoupdate.AgentsUpdateModeEnabled || versionMode != autoupdate.AgentsUpdateModeEnabled {
-		return false
-	}
-
-	scheduleName := version.GetSpec().GetAgents().GetSchedule()
-	return scheduleName == autoupdate.AgentsScheduleImmediate
-}