diff --git a/api/client/webclient/webclient.go b/api/client/webclient/webclient.go
index 6b7d7153a7136..78d4c80c9aebc 100644
--- a/api/client/webclient/webclient.go
+++ b/api/client/webclient/webclient.go
@@ -47,6 +47,15 @@ import (
"github.com/gravitational/teleport/api/utils/keys"
)
+const (
+ // AgentUpdateGroupParameter is the parameter used to specify the updater
+ // group when doing a Ping() or Find() query.
+ // The proxy server will modulate the auto_update part of the PingResponse
+ // based on the specified group. e.g. some groups might need to update
+ // before others.
+ AgentUpdateGroupParameter = "group"
+)
+
// Config specifies information when building requests with the
// webclient.
type Config struct {
@@ -183,7 +192,7 @@ func findWithClient(cfg *Config, clt *http.Client) (*PingResponse, error) {
}
if cfg.UpdateGroup != "" {
endpoint.RawQuery = url.Values{
- "group": []string{cfg.UpdateGroup},
+ AgentUpdateGroupParameter: []string{cfg.UpdateGroup},
}.Encode()
}
@@ -232,7 +241,7 @@ func pingWithClient(cfg *Config, clt *http.Client) (*PingResponse, error) {
}
if cfg.UpdateGroup != "" {
endpoint.RawQuery = url.Values{
- "group": []string{cfg.UpdateGroup},
+ AgentUpdateGroupParameter: []string{cfg.UpdateGroup},
}.Encode()
}
if cfg.ConnectorName != "" {
diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go
index e63035dfcd759..9bbfadd29877f 100644
--- a/lib/web/apiserver.go
+++ b/lib/web/apiserver.go
@@ -178,6 +178,9 @@ type Handler struct {
// rate-limits, each call must cause minimal work. The cached answer can be modulated after, for example if the
// caller specified its Automatic Updates UUID or group.
findEndpointCache *utils.FnCache
+
+ // clusterMaintenanceConfigCache is used to cache the cluster maintenance config from the Auth Service.
+ clusterMaintenanceConfigCache *utils.FnCache
}
// HandlerOption is a functional argument - an option that can be passed
@@ -480,6 +483,18 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {
}
h.findEndpointCache = findCache
+ // We create the cache after applying the options to make sure we use the fake clock if it was passed.
+ cmcCache, err := utils.NewFnCache(utils.FnCacheConfig{
+ TTL: findEndpointCacheTTL,
+ Clock: h.clock,
+ Context: cfg.Context,
+ ReloadOnErr: false,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err, "creating cluster maintenance config cache")
+ }
+ h.clusterMaintenanceConfigCache = cmcCache
+
sessionLingeringThreshold := cachedSessionLingeringThreshold
if cfg.CachedSessionLingeringThreshold != nil {
sessionLingeringThreshold = *cfg.CachedSessionLingeringThreshold
@@ -1527,6 +1542,8 @@ func (h *Handler) ping(w http.ResponseWriter, r *http.Request, p httprouter.Para
return nil, trace.Wrap(err)
}
+ group := r.URL.Query().Get(webclient.AgentUpdateGroupParameter)
+
return webclient.PingResponse{
Auth: authSettings,
Proxy: *proxyConfig,
@@ -1534,15 +1551,21 @@ func (h *Handler) ping(w http.ResponseWriter, r *http.Request, p httprouter.Para
MinClientVersion: teleport.MinClientVersion,
ClusterName: h.auth.clusterName,
AutomaticUpgrades: pr.ServerFeatures.GetAutomaticUpgrades(),
- AutoUpdate: h.automaticUpdateSettings184(r.Context()),
+ AutoUpdate: h.automaticUpdateSettings184(r.Context(), group, "" /* updater UUID */),
Edition: modules.GetModules().BuildType(),
FIPS: modules.IsBoringBinary(),
}, nil
}
func (h *Handler) find(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
+ group := r.URL.Query().Get(webclient.AgentUpdateGroupParameter)
+ cacheKey := "find"
+ if group != "" {
+ cacheKey += "-" + group
+ }
+
// cache the generic answer to avoid doing work for each request
- resp, err := utils.FnCacheGet[*webclient.PingResponse](r.Context(), h.findEndpointCache, "find", func(ctx context.Context) (*webclient.PingResponse, error) {
+ resp, err := utils.FnCacheGet[*webclient.PingResponse](r.Context(), h.findEndpointCache, cacheKey, func(ctx context.Context) (*webclient.PingResponse, error) {
proxyConfig, err := h.cfg.ProxySettings.GetProxySettings(ctx)
if err != nil {
return nil, trace.Wrap(err)
@@ -1561,7 +1584,7 @@ func (h *Handler) find(w http.ResponseWriter, r *http.Request, p httprouter.Para
ClusterName: h.auth.clusterName,
Edition: modules.GetModules().BuildType(),
FIPS: modules.IsBoringBinary(),
- AutoUpdate: h.automaticUpdateSettings184(ctx),
+ AutoUpdate: h.automaticUpdateSettings184(ctx, group, "" /* updater UUID */),
}, nil
})
if err != nil {
diff --git a/lib/web/apiserver_ping_test.go b/lib/web/apiserver_ping_test.go
index 2bf325d4f7902..84e073ca7ae87 100644
--- a/lib/web/apiserver_ping_test.go
+++ b/lib/web/apiserver_ping_test.go
@@ -299,6 +299,7 @@ func TestPing_autoUpdateResources(t *testing.T) {
name string
config *autoupdatev1pb.AutoUpdateConfigSpec
version *autoupdatev1pb.AutoUpdateVersionSpec
+ rollout *autoupdatev1pb.AutoUpdateAgentRolloutSpec
cleanup bool
expected webclient.AutoUpdateSettings
}{
@@ -330,19 +331,12 @@ func TestPing_autoUpdateResources(t *testing.T) {
},
{
name: "enable agent auto update, immediate schedule",
- config: &autoupdatev1pb.AutoUpdateConfigSpec{
- Agents: &autoupdatev1pb.AutoUpdateConfigSpecAgents{
- Mode: autoupdate.AgentsUpdateModeEnabled,
- Strategy: autoupdate.AgentsStrategyHaltOnError,
- },
- },
- version: &autoupdatev1pb.AutoUpdateVersionSpec{
- Agents: &autoupdatev1pb.AutoUpdateVersionSpecAgents{
- Mode: autoupdate.AgentsUpdateModeEnabled,
- StartVersion: "1.2.3",
- TargetVersion: "1.2.4",
- Schedule: autoupdate.AgentsScheduleImmediate,
- },
+ rollout: &autoupdatev1pb.AutoUpdateAgentRolloutSpec{
+ AutoupdateMode: autoupdate.AgentsUpdateModeEnabled,
+ Strategy: autoupdate.AgentsStrategyHaltOnError,
+ Schedule: autoupdate.AgentsScheduleImmediate,
+ StartVersion: "1.2.3",
+ TargetVersion: "1.2.4",
},
expected: webclient.AutoUpdateSettings{
ToolsVersion: api.Version,
@@ -354,20 +348,13 @@ func TestPing_autoUpdateResources(t *testing.T) {
cleanup: true,
},
{
- name: "version enable agent auto update, but config disables them",
- config: &autoupdatev1pb.AutoUpdateConfigSpec{
- Agents: &autoupdatev1pb.AutoUpdateConfigSpecAgents{
- Mode: autoupdate.AgentsUpdateModeDisabled,
- Strategy: autoupdate.AgentsStrategyHaltOnError,
- },
- },
- version: &autoupdatev1pb.AutoUpdateVersionSpec{
- Agents: &autoupdatev1pb.AutoUpdateVersionSpecAgents{
- Mode: autoupdate.AgentsUpdateModeEnabled,
- StartVersion: "1.2.3",
- TargetVersion: "1.2.4",
- Schedule: autoupdate.AgentsScheduleImmediate,
- },
+ name: "agent rollout present but AU mode is disabled",
+ rollout: &autoupdatev1pb.AutoUpdateAgentRolloutSpec{
+ AutoupdateMode: autoupdate.AgentsUpdateModeDisabled,
+ Strategy: autoupdate.AgentsStrategyHaltOnError,
+ Schedule: autoupdate.AgentsScheduleImmediate,
+ StartVersion: "1.2.3",
+ TargetVersion: "1.2.4",
},
expected: webclient.AutoUpdateSettings{
ToolsVersion: api.Version,
@@ -462,6 +449,12 @@ func TestPing_autoUpdateResources(t *testing.T) {
_, err = env.server.Auth().UpsertAutoUpdateVersion(ctx, version)
require.NoError(t, err)
}
+ if tc.rollout != nil {
+ rollout, err := autoupdate.NewAutoUpdateAgentRollout(tc.rollout)
+ require.NoError(t, err)
+ _, err = env.server.Auth().UpsertAutoUpdateAgentRollout(ctx, rollout)
+ require.NoError(t, err)
+ }
// expire the fn cache to force the next answer to be fresh
for _, proxy := range env.proxies {
@@ -480,6 +473,7 @@ func TestPing_autoUpdateResources(t *testing.T) {
if tc.cleanup {
require.NotErrorIs(t, env.server.Auth().DeleteAutoUpdateConfig(ctx), &trace.NotFoundError{})
require.NotErrorIs(t, env.server.Auth().DeleteAutoUpdateVersion(ctx), &trace.NotFoundError{})
+ require.NotErrorIs(t, env.server.Auth().DeleteAutoUpdateAgentRollout(ctx), &trace.NotFoundError{})
}
})
}
diff --git a/lib/web/autoupdate_common.go b/lib/web/autoupdate_common.go
new file mode 100644
index 0000000000000..1756172f4c6e4
--- /dev/null
+++ b/lib/web/autoupdate_common.go
@@ -0,0 +1,228 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package web
+
+import (
+ "context"
+ "strings"
+
+ "github.com/gravitational/trace"
+
+ autoupdatepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/autoupdate/v1"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/api/types/autoupdate"
+ "github.com/gravitational/teleport/lib/automaticupgrades"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+// autoUpdateAgentVersion returns the version the agent should install/update to based on
+// its group and updater UUID.
+// If the cluster contains an autoupdate_agent_rollout resource from RFD184 it should take precedence.
+// If the resource is not there, we fall back to RFD109-style updates with channels
+// and maintenance window derived from the cluster_maintenance_config resource.
+// Version returned follows semver without the leading "v".
+func (h *Handler) autoUpdateAgentVersion(ctx context.Context, group, updaterUUID string) (string, error) {
+ rollout, err := h.cfg.AccessPoint.GetAutoUpdateAgentRollout(ctx)
+ if err != nil {
+ // Fallback to channels if there is no autoupdate_agent_rollout.
+ if trace.IsNotFound(err) {
+ return getVersionFromChannel(ctx, h.cfg.AutomaticUpgradesChannels, group)
+ }
+ // Something is broken, we don't want to fallback to channels, this would be harmful.
+ return "", trace.Wrap(err, "getting autoupdate_agent_rollout")
+ }
+
+ return getVersionFromRollout(rollout, group, updaterUUID)
+}
+
+// autoUpdateAgentShouldUpdate returns if the agent should update now to based on its group
+// and updater UUID.
+// If the cluster contains an autoupdate_agent_rollout resource from RFD184 it should take precedence.
+// If the resource is not there, we fall back to RFD109-style updates with channels
+// and maintenance window derived from the cluster_maintenance_config resource.
+func (h *Handler) autoUpdateAgentShouldUpdate(ctx context.Context, group, updaterUUID string, windowLookup bool) (bool, error) {
+ rollout, err := h.cfg.AccessPoint.GetAutoUpdateAgentRollout(ctx)
+ if err != nil {
+ // Fallback to channels if there is no autoupdate_agent_rollout.
+ if trace.IsNotFound(err) {
+ // Updaters using the RFD184 API are not aware of maintenance windows
+ // like RFD109 updaters are. To have both updaters adopt the same behavior
+ // we must do the CMC window lookup for them.
+ if windowLookup {
+ return h.getTriggerFromWindowThenChannel(ctx, group)
+ }
+ return getTriggerFromChannel(ctx, h.cfg.AutomaticUpgradesChannels, group)
+ }
+ // Something is broken, we don't want to fallback to channels, this would be harmful.
+ return false, trace.Wrap(err, "failed to get auto-update rollout")
+ }
+
+ return getTriggerFromRollout(rollout, group, updaterUUID)
+}
+
+// getVersionFromRollout returns the version we should serve to the agent based
+// on the RFD184 agent rollout, the agent group name, and its UUID.
+// This logic is pretty complex and described in RFD 184.
+// The spec is summed up in the following table:
+// https://github.com/gravitational/teleport/blob/master/rfd/0184-agent-auto-updates.md#rollout-status-disabled
+// Version returned follows semver without the leading "v".
+func getVersionFromRollout(
+ rollout *autoupdatepb.AutoUpdateAgentRollout,
+ groupName, updaterUUID string,
+) (string, error) {
+ switch rollout.GetSpec().GetAutoupdateMode() {
+ case autoupdate.AgentsUpdateModeDisabled:
+ // If AUs are disabled, we always answer the target version
+ return rollout.GetSpec().GetTargetVersion(), nil
+ case autoupdate.AgentsUpdateModeSuspended, autoupdate.AgentsUpdateModeEnabled:
+ // If AUs are enabled or suspended, we modulate the response based on the schedule and agent group state
+ default:
+ return "", trace.BadParameter("unsupported agent update mode %q", rollout.GetSpec().GetAutoupdateMode())
+ }
+
+ // If the schedule is immediate, agents always update to the latest version
+ if rollout.GetSpec().GetSchedule() == autoupdate.AgentsScheduleImmediate {
+ return rollout.GetSpec().GetTargetVersion(), nil
+ }
+
+ // Else we follow the regular schedule and answer based on the agent group state
+ group, err := getGroup(rollout, groupName)
+ if err != nil {
+ return "", trace.Wrap(err, "getting group %q", groupName)
+ }
+
+ switch group.GetState() {
+ case autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK:
+ return rollout.GetSpec().GetStartVersion(), nil
+ case autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE:
+ return rollout.GetSpec().GetTargetVersion(), nil
+ default:
+ return "", trace.NotImplemented("unsupported group state %q", group.GetState())
+ }
+}
+
+// getTriggerFromRollout returns whether the agent should update now based
+// on the RFD184 agent rollout, the agent group name, and its UUID.
+// This logic is pretty complex and described in RFD 184.
+// The spec is summed up in the following table:
+// https://github.com/gravitational/teleport/blob/master/rfd/0184-agent-auto-updates.md#rollout-status-disabled
+func getTriggerFromRollout(rollout *autoupdatepb.AutoUpdateAgentRollout, groupName, updaterUUID string) (bool, error) {
+ // If the mode is "suspended" or "disabled", we never tell to update
+ switch rollout.GetSpec().GetAutoupdateMode() {
+ case autoupdate.AgentsUpdateModeDisabled, autoupdate.AgentsUpdateModeSuspended:
+ // If AUs are disabled or suspended, never tell to update
+ return false, nil
+ case autoupdate.AgentsUpdateModeEnabled:
+ // If AUs are enabled, we modulate the response based on the schedule and agent group state
+ default:
+ return false, trace.BadParameter("unsupported agent update mode %q", rollout.GetSpec().GetAutoupdateMode())
+ }
+
+ // If the schedule is immediate, agents always update to the latest version
+ if rollout.GetSpec().GetSchedule() == autoupdate.AgentsScheduleImmediate {
+ return true, nil
+ }
+
+ // Else we follow the regular schedule and answer based on the agent group state
+ group, err := getGroup(rollout, groupName)
+ if err != nil {
+ return false, trace.Wrap(err, "getting group %q", groupName)
+ }
+
+ switch group.GetState() {
+ case autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED:
+ return false, nil
+ case autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK:
+ return true, nil
+ case autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE:
+ return rollout.GetSpec().GetStrategy() == autoupdate.AgentsStrategyHaltOnError, nil
+ default:
+ return false, trace.NotImplemented("unsupported group state %q", group.GetState())
+ }
+}
+
+// getGroup returns the agent rollout group the requesting agent belongs to.
+// If a group matches the agent-provided group name, this group is returned.
+// Else the default group is returned. The default group currently is the last
+// one. This might change in the future.
+func getGroup(
+ rollout *autoupdatepb.AutoUpdateAgentRollout,
+ groupName string,
+) (*autoupdatepb.AutoUpdateAgentRolloutStatusGroup, error) {
+ groups := rollout.GetStatus().GetGroups()
+ if len(groups) == 0 {
+ return nil, trace.BadParameter("no groups found")
+ }
+
+ // Try to find a group with our name
+ for _, group := range groups {
+ if group.Name == groupName {
+ return group, nil
+ }
+ }
+
+ // Fallback to the default group (currently the last one but this might change).
+ return groups[len(groups)-1], nil
+}
+
+// getVersionFromChannel gets the target version from the RFD109 channels.
+// Version returned follows semver without the leading "v".
+func getVersionFromChannel(ctx context.Context, channels automaticupgrades.Channels, groupName string) (version string, err error) {
+ // RFD109 channels return the version with the 'v' prefix.
+ // We can't change the internals for backward compatibility, so we must trim the prefix if it's here.
+ defer func() {
+ version = strings.TrimPrefix(version, "v")
+ }()
+
+ if channel, ok := channels[groupName]; ok {
+ return channel.GetVersion(ctx)
+ }
+ return channels.DefaultVersion(ctx)
+}
+
+// getTriggerFromWindowThenChannel gets the update trigger from the RFD109 maintenance window and channels.
+func (h *Handler) getTriggerFromWindowThenChannel(ctx context.Context, groupName string) (bool, error) {
+ // Caching the CMC for 10 seconds because this resource is cached neither by the auth nor the proxy.
+ // And this function can be accessed via unauthenticated endpoints.
+ cmc, err := utils.FnCacheGet[types.ClusterMaintenanceConfig](ctx, h.clusterMaintenanceConfigCache, "cmc", func(ctx context.Context) (types.ClusterMaintenanceConfig, error) {
+ return h.cfg.ProxyClient.GetClusterMaintenanceConfig(ctx)
+ })
+
+ // If we have a CMC, we check if the window is active, else we just check if the update is critical.
+ if err == nil && cmc.WithinUpgradeWindow(h.clock.Now()) {
+ return true, nil
+ }
+
+ return getTriggerFromChannel(ctx, h.cfg.AutomaticUpgradesChannels, groupName)
+}
+
+// getTriggerFromChannel gets the update trigger from the RFD109 channels.
+func getTriggerFromChannel(ctx context.Context, channels automaticupgrades.Channels, groupName string) (bool, error) {
+ if channel, ok := channels[groupName]; ok {
+ return channel.GetCritical(ctx)
+ }
+ defaultChannel, err := channels.DefaultChannel()
+ if err != nil {
+ return false, trace.Wrap(err, "creating new default channel")
+ }
+ return defaultChannel.GetCritical(ctx)
+}
diff --git a/lib/web/autoupdate_common_test.go b/lib/web/autoupdate_common_test.go
new file mode 100644
index 0000000000000..a365ac121b078
--- /dev/null
+++ b/lib/web/autoupdate_common_test.go
@@ -0,0 +1,796 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package web
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/gravitational/trace"
+ "github.com/jonboulle/clockwork"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ autoupdatepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/autoupdate/v1"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/api/types/autoupdate"
+ "github.com/gravitational/teleport/lib/auth/authclient"
+ "github.com/gravitational/teleport/lib/automaticupgrades"
+ "github.com/gravitational/teleport/lib/automaticupgrades/constants"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+const (
+ testVersionHigh = "2.3.4"
+ testVersionLow = "2.0.4"
+)
+
+// fakeRolloutAccessPoint allows us to mock the ProxyAccessPoint in autoupdate
+// tests.
+type fakeRolloutAccessPoint struct {
+ authclient.ProxyAccessPoint
+
+ rollout *autoupdatepb.AutoUpdateAgentRollout
+ err error
+}
+
+func (ap *fakeRolloutAccessPoint) GetAutoUpdateAgentRollout(_ context.Context) (*autoupdatepb.AutoUpdateAgentRollout, error) {
+ return ap.rollout, ap.err
+}
+
+// fakeCMCAuthClient allows us to mock the proxy's auth client in autoupdate
+// tests.
+type fakeCMCAuthClient struct {
+ authclient.ClientI
+
+ cmc types.ClusterMaintenanceConfig
+ err error
+}
+
+func (c *fakeCMCAuthClient) GetClusterMaintenanceConfig(_ context.Context) (types.ClusterMaintenanceConfig, error) {
+ return c.cmc, c.err
+}
+
+func TestAutoUpdateAgentVersion(t *testing.T) {
+ t.Parallel()
+ groupName := "test-group"
+ ctx := context.Background()
+
+ // brokenChannelUpstream is a buggy upstream version server.
+ // This allows us to craft version channels returning errors.
+ brokenChannelUpstream := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusBadRequest)
+ }))
+ t.Cleanup(brokenChannelUpstream.Close)
+
+ tests := []struct {
+ name string
+ rollout *autoupdatepb.AutoUpdateAgentRollout
+ rolloutErr error
+ channel *automaticupgrades.Channel
+ expectedVersion string
+ expectError require.ErrorAssertionFunc
+ }{
+ {
+ name: "version is looked up from rollout if it is here",
+ rollout: &autoupdatepb.AutoUpdateAgentRollout{
+ Spec: &autoupdatepb.AutoUpdateAgentRolloutSpec{
+ AutoupdateMode: autoupdate.AgentsUpdateModeEnabled,
+ TargetVersion: testVersionHigh,
+ Schedule: autoupdate.AgentsScheduleImmediate,
+ },
+ },
+ channel: &automaticupgrades.Channel{StaticVersion: testVersionLow},
+ expectError: require.NoError,
+ expectedVersion: testVersionHigh,
+ },
+ {
+ name: "version is looked up from channel if rollout is not here",
+ rolloutErr: trace.NotFound("rollout is not here"),
+ channel: &automaticupgrades.Channel{StaticVersion: testVersionLow},
+ expectError: require.NoError,
+ expectedVersion: testVersionLow,
+ },
+ {
+ name: "hard error getting rollout should not fallback to version channels",
+ rolloutErr: trace.AccessDenied("something is very broken"),
+ channel: &automaticupgrades.Channel{
+ StaticVersion: testVersionLow,
+ },
+ expectError: require.Error,
+ },
+ {
+ name: "no rollout, error checking channel",
+ rolloutErr: trace.NotFound("rollout is not here"),
+ channel: &automaticupgrades.Channel{ForwardURL: brokenChannelUpstream.URL},
+ expectError: require.Error,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Test setup: building the channel, mock client, and handler with test config.
+ require.NoError(t, tt.channel.CheckAndSetDefaults())
+ h := &Handler{
+ cfg: Config{
+ AccessPoint: &fakeRolloutAccessPoint{
+ rollout: tt.rollout,
+ err: tt.rolloutErr,
+ },
+ AutomaticUpgradesChannels: map[string]*automaticupgrades.Channel{
+ groupName: tt.channel,
+ },
+ },
+ }
+
+ // Test execution
+ result, err := h.autoUpdateAgentVersion(ctx, groupName, "")
+ tt.expectError(t, err)
+ require.Equal(t, tt.expectedVersion, result)
+ })
+ }
+}
+
+// TestAutoUpdateAgentShouldUpdate also incidentally tests getTriggerFromWindowThenChannel.
+func TestAutoUpdateAgentShouldUpdate(t *testing.T) {
+ t.Parallel()
+
+ groupName := "test-group"
+ ctx := context.Background()
+
+ // brokenChannelUpstream is a buggy upstream version server.
+ // This allows us to craft version channels returning errors.
+ brokenChannelUpstream := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusBadRequest)
+ }))
+ t.Cleanup(brokenChannelUpstream.Close)
+
+ clock := clockwork.NewFakeClock()
+ cmcCache, err := utils.NewFnCache(utils.FnCacheConfig{
+ TTL: findEndpointCacheTTL,
+ Clock: clock,
+ Context: ctx,
+ ReloadOnErr: false,
+ })
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ cmcCache.Shutdown(ctx)
+ })
+
+ activeUpgradeWindow := types.AgentUpgradeWindow{UTCStartHour: uint32(clock.Now().Hour())}
+ inactiveUpgradeWindow := types.AgentUpgradeWindow{UTCStartHour: uint32(clock.Now().Add(2 * time.Hour).Hour())}
+ tests := []struct {
+ name string
+ rollout *autoupdatepb.AutoUpdateAgentRollout
+ rolloutErr error
+ channel *automaticupgrades.Channel
+ upgradeWindow types.AgentUpgradeWindow
+ cmcErr error
+ windowLookup bool
+ expectedTrigger bool
+ expectError require.ErrorAssertionFunc
+ }{
+ {
+ name: "trigger is looked up from rollout if it is here, trigger firing",
+ rollout: &autoupdatepb.AutoUpdateAgentRollout{
+ Spec: &autoupdatepb.AutoUpdateAgentRolloutSpec{
+ AutoupdateMode: autoupdate.AgentsUpdateModeEnabled,
+ TargetVersion: testVersionHigh,
+ Schedule: autoupdate.AgentsScheduleImmediate,
+ },
+ },
+ channel: &automaticupgrades.Channel{StaticVersion: testVersionLow},
+ expectError: require.NoError,
+ expectedTrigger: true,
+ },
+ {
+ name: "trigger is looked up from rollout if it is here, trigger not firing",
+ rollout: &autoupdatepb.AutoUpdateAgentRollout{
+ Spec: &autoupdatepb.AutoUpdateAgentRolloutSpec{
+ AutoupdateMode: autoupdate.AgentsUpdateModeDisabled,
+ TargetVersion: testVersionHigh,
+ Schedule: autoupdate.AgentsScheduleImmediate,
+ },
+ },
+ channel: &automaticupgrades.Channel{StaticVersion: testVersionLow},
+ expectError: require.NoError,
+ expectedTrigger: false,
+ },
+ {
+ name: "trigger is looked up from channel if rollout is not here and window lookup is disabled, trigger not firing",
+ rolloutErr: trace.NotFound("rollout is not here"),
+ channel: &automaticupgrades.Channel{
+ StaticVersion: testVersionLow,
+ Critical: false,
+ },
+ expectError: require.NoError,
+ expectedTrigger: false,
+ },
+ {
+ name: "trigger is looked up from channel if rollout is not here and window lookup is disabled, trigger firing",
+ rolloutErr: trace.NotFound("rollout is not here"),
+ channel: &automaticupgrades.Channel{
+ StaticVersion: testVersionLow,
+ Critical: true,
+ },
+ expectError: require.NoError,
+ expectedTrigger: true,
+ },
+ {
+ name: "trigger is looked up from cmc, then channel if rollout is not here and window lookup is enabled, cmc firing",
+ rolloutErr: trace.NotFound("rollout is not here"),
+ channel: &automaticupgrades.Channel{
+ StaticVersion: testVersionLow,
+ Critical: false,
+ },
+ upgradeWindow: activeUpgradeWindow,
+ windowLookup: true,
+ expectError: require.NoError,
+ expectedTrigger: true,
+ },
+ {
+ name: "trigger is looked up from cmc, then channel if rollout is not here and window lookup is enabled, cmc not firing",
+ rolloutErr: trace.NotFound("rollout is not here"),
+ channel: &automaticupgrades.Channel{
+ StaticVersion: testVersionLow,
+ Critical: false,
+ },
+ upgradeWindow: inactiveUpgradeWindow,
+ windowLookup: true,
+ expectError: require.NoError,
+ expectedTrigger: false,
+ },
+ {
+ name: "trigger is looked up from cmc, then channel if rollout is not here and window lookup is enabled, cmc not firing but channel firing",
+ rolloutErr: trace.NotFound("rollout is not here"),
+ channel: &automaticupgrades.Channel{
+ StaticVersion: testVersionLow,
+ Critical: true,
+ },
+ upgradeWindow: inactiveUpgradeWindow,
+ windowLookup: true,
+ expectError: require.NoError,
+ expectedTrigger: true,
+ },
+ {
+ name: "trigger is looked up from cmc, then channel if rollout is not here and window lookup is enabled, no cmc and channel not firing",
+ rolloutErr: trace.NotFound("rollout is not here"),
+ channel: &automaticupgrades.Channel{
+ StaticVersion: testVersionLow,
+ Critical: false,
+ },
+ cmcErr: trace.NotFound("no cmc for this cluster"),
+ windowLookup: true,
+ expectError: require.NoError,
+ expectedTrigger: false,
+ },
+ {
+ name: "trigger is looked up from cmc, then channel if rollout is not here and window lookup is enabled, no cmc and channel firing",
+ rolloutErr: trace.NotFound("rollout is not here"),
+ channel: &automaticupgrades.Channel{
+ StaticVersion: testVersionLow,
+ Critical: true,
+ },
+ cmcErr: trace.NotFound("no cmc for this cluster"),
+ windowLookup: true,
+ expectError: require.NoError,
+ expectedTrigger: true,
+ },
+ {
+ name: "hard error getting rollout should not fallback to RFD109 trigger",
+ rolloutErr: trace.AccessDenied("something is very broken"),
+ channel: &automaticupgrades.Channel{
+ StaticVersion: testVersionLow,
+ },
+ expectError: require.Error,
+ },
+ {
+ name: "no rollout, error checking channel",
+ rolloutErr: trace.NotFound("rollout is not here"),
+ channel: &automaticupgrades.Channel{
+ ForwardURL: brokenChannelUpstream.URL,
+ },
+ expectError: require.Error,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Test setup: building the channel, mock clients, and handler with test config.
+ cmc := types.NewClusterMaintenanceConfig()
+ cmc.SetAgentUpgradeWindow(tt.upgradeWindow)
+ require.NoError(t, tt.channel.CheckAndSetDefaults())
+ // Advance clock to invalidate cache
+ clock.Advance(2 * findEndpointCacheTTL)
+ h := &Handler{
+ cfg: Config{
+ AccessPoint: &fakeRolloutAccessPoint{
+ rollout: tt.rollout,
+ err: tt.rolloutErr,
+ },
+ ProxyClient: &fakeCMCAuthClient{
+ cmc: cmc,
+ err: tt.cmcErr,
+ },
+ AutomaticUpgradesChannels: map[string]*automaticupgrades.Channel{
+ groupName: tt.channel,
+ },
+ },
+ clock: clock,
+ clusterMaintenanceConfigCache: cmcCache,
+ }
+
+ // Test execution
+ result, err := h.autoUpdateAgentShouldUpdate(ctx, groupName, "", tt.windowLookup)
+ tt.expectError(t, err)
+ require.Equal(t, tt.expectedTrigger, result)
+ })
+ }
+}
+
+func TestGetVersionFromRollout(t *testing.T) {
+ t.Parallel()
+ groupName := "test-group"
+
+ // This test matrix is written based on:
+ // https://github.com/gravitational/teleport/blob/master/rfd/0184-agent-auto-updates.md#rollout-status-disabled
+ latestAllTheTime := map[autoupdatepb.AutoUpdateAgentGroupState]string{
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED: testVersionHigh,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE: testVersionHigh,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE: testVersionHigh,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: testVersionHigh,
+ }
+
+ activeDoneOnly := map[autoupdatepb.AutoUpdateAgentGroupState]string{
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED: testVersionLow,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE: testVersionHigh,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE: testVersionHigh,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: testVersionLow,
+ }
+
+ tests := map[string]map[string]map[autoupdatepb.AutoUpdateAgentGroupState]string{
+ autoupdate.AgentsUpdateModeDisabled: {
+ autoupdate.AgentsScheduleImmediate: latestAllTheTime,
+ autoupdate.AgentsScheduleRegular: latestAllTheTime,
+ },
+ autoupdate.AgentsUpdateModeSuspended: {
+ autoupdate.AgentsScheduleImmediate: latestAllTheTime,
+ autoupdate.AgentsScheduleRegular: activeDoneOnly,
+ },
+ autoupdate.AgentsUpdateModeEnabled: {
+ autoupdate.AgentsScheduleImmediate: latestAllTheTime,
+ autoupdate.AgentsScheduleRegular: activeDoneOnly,
+ },
+ }
+ for mode, scheduleCases := range tests {
+ for schedule, stateCases := range scheduleCases {
+ for state, expectedVersion := range stateCases {
+ t.Run(fmt.Sprintf("%s/%s/%s", mode, schedule, state), func(t *testing.T) {
+ rollout := &autoupdatepb.AutoUpdateAgentRollout{
+ Spec: &autoupdatepb.AutoUpdateAgentRolloutSpec{
+ StartVersion: testVersionLow,
+ TargetVersion: testVersionHigh,
+ Schedule: schedule,
+ AutoupdateMode: mode,
+ // Strategy does not affect which version are served
+ Strategy: autoupdate.AgentsStrategyTimeBased,
+ },
+ Status: &autoupdatepb.AutoUpdateAgentRolloutStatus{
+ Groups: []*autoupdatepb.AutoUpdateAgentRolloutStatusGroup{
+ {
+ Name: groupName,
+ State: state,
+ },
+ },
+ },
+ }
+ version, err := getVersionFromRollout(rollout, groupName, "")
+ require.NoError(t, err)
+ require.Equal(t, expectedVersion, version)
+ })
+ }
+ }
+ }
+}
+
+func TestGetTriggerFromRollout(t *testing.T) {
+ t.Parallel()
+ groupName := "test-group"
+
+ // This test matrix is written based on:
+ // https://github.com/gravitational/teleport/blob/master/rfd/0184-agent-auto-updates.md#rollout-status-disabled
+ neverUpdate := map[autoupdatepb.AutoUpdateAgentGroupState]bool{
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED: false,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE: false,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE: false,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: false,
+ }
+ alwaysUpdate := map[autoupdatepb.AutoUpdateAgentGroupState]bool{
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED: true,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE: true,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE: true,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: true,
+ }
+
+ tests := map[string]map[string]map[string]map[autoupdatepb.AutoUpdateAgentGroupState]bool{
+ autoupdate.AgentsUpdateModeDisabled: {
+ autoupdate.AgentsStrategyTimeBased: {
+ autoupdate.AgentsScheduleImmediate: neverUpdate,
+ autoupdate.AgentsScheduleRegular: neverUpdate,
+ },
+ autoupdate.AgentsStrategyHaltOnError: {
+ autoupdate.AgentsScheduleImmediate: neverUpdate,
+ autoupdate.AgentsScheduleRegular: neverUpdate,
+ },
+ },
+ autoupdate.AgentsUpdateModeSuspended: {
+ autoupdate.AgentsStrategyTimeBased: {
+ autoupdate.AgentsScheduleImmediate: neverUpdate,
+ autoupdate.AgentsScheduleRegular: neverUpdate,
+ },
+ autoupdate.AgentsStrategyHaltOnError: {
+ autoupdate.AgentsScheduleImmediate: neverUpdate,
+ autoupdate.AgentsScheduleRegular: neverUpdate,
+ },
+ },
+ autoupdate.AgentsUpdateModeEnabled: {
+ autoupdate.AgentsStrategyTimeBased: {
+ autoupdate.AgentsScheduleImmediate: alwaysUpdate,
+ autoupdate.AgentsScheduleRegular: {
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED: false,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE: false,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE: true,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: true,
+ },
+ },
+ autoupdate.AgentsStrategyHaltOnError: {
+ autoupdate.AgentsScheduleImmediate: alwaysUpdate,
+ autoupdate.AgentsScheduleRegular: {
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_UNSTARTED: false,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_DONE: true,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ACTIVE: true,
+ autoupdatepb.AutoUpdateAgentGroupState_AUTO_UPDATE_AGENT_GROUP_STATE_ROLLEDBACK: true,
+ },
+ },
+ },
+ }
+ for mode, strategyCases := range tests {
+ for strategy, scheduleCases := range strategyCases {
+ for schedule, stateCases := range scheduleCases {
+ for state, expectedTrigger := range stateCases {
+ t.Run(fmt.Sprintf("%s/%s/%s/%s", mode, strategy, schedule, state), func(t *testing.T) {
+ rollout := &autoupdatepb.AutoUpdateAgentRollout{
+ Spec: &autoupdatepb.AutoUpdateAgentRolloutSpec{
+ StartVersion: testVersionLow,
+ TargetVersion: testVersionHigh,
+ Schedule: schedule,
+ AutoupdateMode: mode,
+ Strategy: strategy,
+ },
+ Status: &autoupdatepb.AutoUpdateAgentRolloutStatus{
+ Groups: []*autoupdatepb.AutoUpdateAgentRolloutStatusGroup{
+ {
+ Name: groupName,
+ State: state,
+ },
+ },
+ },
+ }
+ shouldUpdate, err := getTriggerFromRollout(rollout, groupName, "")
+ require.NoError(t, err)
+ require.Equal(t, expectedTrigger, shouldUpdate)
+ })
+ }
+ }
+ }
+ }
+}
+
+func TestGetGroup(t *testing.T) {
+ groupName := "test-group"
+ t.Parallel()
+ tests := []struct {
+ name string
+ rollout *autoupdatepb.AutoUpdateAgentRollout
+ expectedResult *autoupdatepb.AutoUpdateAgentRolloutStatusGroup
+ expectError require.ErrorAssertionFunc
+ }{
+ {
+ name: "nil",
+ expectError: require.Error,
+ },
+ {
+ name: "nil status",
+ rollout: &autoupdatepb.AutoUpdateAgentRollout{},
+ expectError: require.Error,
+ },
+ {
+ name: "nil status groups",
+ rollout: &autoupdatepb.AutoUpdateAgentRollout{Status: &autoupdatepb.AutoUpdateAgentRolloutStatus{}},
+ expectError: require.Error,
+ },
+ {
+ name: "empty status groups",
+ rollout: &autoupdatepb.AutoUpdateAgentRollout{
+ Status: &autoupdatepb.AutoUpdateAgentRolloutStatus{
+ Groups: []*autoupdatepb.AutoUpdateAgentRolloutStatusGroup{},
+ },
+ },
+ expectError: require.Error,
+ },
+ {
+ name: "group matching name",
+ rollout: &autoupdatepb.AutoUpdateAgentRollout{
+ Status: &autoupdatepb.AutoUpdateAgentRolloutStatus{
+ Groups: []*autoupdatepb.AutoUpdateAgentRolloutStatusGroup{
+ {Name: "foo", State: 1},
+ {Name: "bar", State: 1},
+ {Name: groupName, State: 2},
+ {Name: "baz", State: 1},
+ },
+ },
+ },
+ expectedResult: &autoupdatepb.AutoUpdateAgentRolloutStatusGroup{
+ Name: groupName,
+ State: 2,
+ },
+ expectError: require.NoError,
+ },
+ {
+ name: "no group matching name, should fallback to default",
+ rollout: &autoupdatepb.AutoUpdateAgentRollout{
+ Status: &autoupdatepb.AutoUpdateAgentRolloutStatus{
+ Groups: []*autoupdatepb.AutoUpdateAgentRolloutStatusGroup{
+ {Name: "foo", State: 1},
+ {Name: "bar", State: 1},
+ {Name: "baz", State: 1},
+ },
+ },
+ },
+ expectedResult: &autoupdatepb.AutoUpdateAgentRolloutStatusGroup{
+ Name: "baz",
+ State: 1,
+ },
+ expectError: require.NoError,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := getGroup(tt.rollout, groupName)
+ tt.expectError(t, err)
+ require.Equal(t, tt.expectedResult, result)
+ })
+ }
+}
+
+type mockRFD109VersionServer struct {
+ t *testing.T
+ channels map[string]channelStub
+}
+
+type channelStub struct {
+ // with or without the leading "v"
+ version string
+ critical bool
+ fail bool
+}
+
+func (m *mockRFD109VersionServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ var path string
+ var writeResp func(w http.ResponseWriter, stub channelStub) error
+
+ switch {
+ case strings.HasSuffix(r.URL.Path, constants.VersionPath):
+ path = strings.Trim(strings.TrimSuffix(r.URL.Path, constants.VersionPath), "/")
+ writeResp = func(w http.ResponseWriter, stub channelStub) error {
+ _, err := w.Write([]byte(stub.version))
+ return err
+ }
+ case strings.HasSuffix(r.URL.Path, constants.MaintenancePath):
+ path = strings.Trim(strings.TrimSuffix(r.URL.Path, constants.MaintenancePath), "/")
+ writeResp = func(w http.ResponseWriter, stub channelStub) error {
+ response := "no"
+ if stub.critical {
+ response = "yes"
+ }
+ _, err := w.Write([]byte(response))
+ return err
+ }
+ default:
+ assert.Fail(m.t, "unsupported path %q", r.URL.Path)
+ w.WriteHeader(http.StatusNotFound)
+ return
+ }
+
+ channel, ok := m.channels[path]
+ if !ok {
+ w.WriteHeader(http.StatusNotFound)
+ assert.Fail(m.t, "channel %q not found", path)
+ return
+ }
+ if channel.fail {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ assert.NoError(m.t, writeResp(w, channel), "failed to write response")
+}
+
+func TestGetVersionFromChannel(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ channelName := "test-channel"
+
+ mock := mockRFD109VersionServer{
+ t: t,
+ channels: map[string]channelStub{
+ "broken": {fail: true},
+ "with-leading-v": {version: "v" + testVersionHigh},
+ "without-leading-v": {version: testVersionHigh},
+ "low": {version: testVersionLow},
+ },
+ }
+ srv := httptest.NewServer(http.HandlerFunc(mock.ServeHTTP))
+ t.Cleanup(srv.Close)
+
+ tests := []struct {
+ name string
+ channels automaticupgrades.Channels
+ expectedResult string
+ expectError require.ErrorAssertionFunc
+ }{
+ {
+ name: "channel with leading v",
+ channels: automaticupgrades.Channels{
+ channelName: {ForwardURL: srv.URL + "/with-leading-v"},
+ "default": {ForwardURL: srv.URL + "/low"},
+ },
+ expectedResult: testVersionHigh,
+ expectError: require.NoError,
+ },
+ {
+ name: "channel without leading v",
+ channels: automaticupgrades.Channels{
+ channelName: {ForwardURL: srv.URL + "/without-leading-v"},
+ "default": {ForwardURL: srv.URL + "/low"},
+ },
+ expectedResult: testVersionHigh,
+ expectError: require.NoError,
+ },
+ {
+ name: "fallback to default with leading v",
+ channels: automaticupgrades.Channels{
+ "default": {ForwardURL: srv.URL + "/with-leading-v"},
+ },
+ expectedResult: testVersionHigh,
+ expectError: require.NoError,
+ },
+ {
+ name: "fallback to default without leading v",
+ channels: automaticupgrades.Channels{
+ "default": {ForwardURL: srv.URL + "/without-leading-v"},
+ },
+ expectedResult: testVersionHigh,
+ expectError: require.NoError,
+ },
+ {
+ name: "broken channel",
+ channels: automaticupgrades.Channels{
+ channelName: {ForwardURL: srv.URL + "/broken"},
+ "default": {ForwardURL: srv.URL + "/without-leading-v"},
+ },
+ expectError: require.Error,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Test setup
+ require.NoError(t, tt.channels.CheckAndSetDefaults())
+
+ // Test execution
+ result, err := getVersionFromChannel(ctx, tt.channels, channelName)
+ tt.expectError(t, err)
+ require.Equal(t, tt.expectedResult, result)
+ })
+ }
+}
+
+func TestGetTriggerFromChannel(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+
+ channelName := "test-channel"
+
+ mock := mockRFD109VersionServer{
+ t: t,
+ channels: map[string]channelStub{
+ "broken": {fail: true},
+ "critical": {critical: true},
+ "non-critical": {critical: false},
+ },
+ }
+ srv := httptest.NewServer(http.HandlerFunc(mock.ServeHTTP))
+ t.Cleanup(srv.Close)
+
+ tests := []struct {
+ name string
+ channels automaticupgrades.Channels
+ expectedResult bool
+ expectError require.ErrorAssertionFunc
+ }{
+ {
+ name: "critical channel",
+ channels: automaticupgrades.Channels{
+ channelName: {ForwardURL: srv.URL + "/critical"},
+ "default": {ForwardURL: srv.URL + "/non-critical"},
+ },
+ expectedResult: true,
+ expectError: require.NoError,
+ },
+ {
+ name: "non-critical channel",
+ channels: automaticupgrades.Channels{
+ channelName: {ForwardURL: srv.URL + "/non-critical"},
+ "default": {ForwardURL: srv.URL + "/critical"},
+ },
+ expectedResult: false,
+ expectError: require.NoError,
+ },
+ {
+ name: "fallback to default which is critical",
+ channels: automaticupgrades.Channels{
+ "default": {ForwardURL: srv.URL + "/critical"},
+ },
+ expectedResult: true,
+ expectError: require.NoError,
+ },
+ {
+ name: "fallback to default which is non-critical",
+ channels: automaticupgrades.Channels{
+ "default": {ForwardURL: srv.URL + "/non-critical"},
+ },
+ expectedResult: false,
+ expectError: require.NoError,
+ },
+ {
+ name: "broken channel",
+ channels: automaticupgrades.Channels{
+ channelName: {ForwardURL: srv.URL + "/broken"},
+ "default": {ForwardURL: srv.URL + "/critical"},
+ },
+ expectError: require.Error,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Test setup
+ require.NoError(t, tt.channels.CheckAndSetDefaults())
+
+ // Test execution
+ result, err := getTriggerFromChannel(ctx, tt.channels, channelName)
+ tt.expectError(t, err)
+ require.Equal(t, tt.expectedResult, result)
+ })
+ }
+}
diff --git a/lib/web/autoupdate_rfd109.go b/lib/web/autoupdate_rfd109.go
index b09b151220754..3bbdd0175b106 100644
--- a/lib/web/autoupdate_rfd109.go
+++ b/lib/web/autoupdate_rfd109.go
@@ -21,6 +21,7 @@ package web
import (
"context"
"errors"
+ "fmt"
"net/http"
"strings"
"time"
@@ -28,7 +29,6 @@ import (
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
- "github.com/gravitational/teleport/lib/automaticupgrades"
"github.com/gravitational/teleport/lib/automaticupgrades/constants"
"github.com/gravitational/teleport/lib/automaticupgrades/version"
)
@@ -59,31 +59,25 @@ func (h *Handler) automaticUpgrades109(w http.ResponseWriter, r *http.Request, p
return nil, trace.BadParameter("a channel name is required")
}
- // We check if the channel is configured
- channel, ok := h.cfg.AutomaticUpgradesChannels[channelName]
- if !ok {
- return nil, trace.NotFound("channel %s not found", channelName)
- }
-
// Finally, we treat the request based on its type
switch requestType {
case "version":
h.log.Debugf("Agent requesting version for channel %s", channelName)
- return h.automaticUpgradesVersion109(w, r, channel)
+ return h.automaticUpgradesVersion109(w, r, channelName)
case "critical":
h.log.Debugf("Agent requesting criticality for channel %s", channelName)
- return h.automaticUpgradesCritical109(w, r, channel)
+ return h.automaticUpgradesCritical109(w, r, channelName)
default:
return nil, trace.BadParameter("requestType path must end with 'version' or 'critical'")
}
}
// automaticUpgradesVersion109 handles version requests from upgraders
-func (h *Handler) automaticUpgradesVersion109(w http.ResponseWriter, r *http.Request, channel *automaticupgrades.Channel) (interface{}, error) {
+func (h *Handler) automaticUpgradesVersion109(w http.ResponseWriter, r *http.Request, channelName string) (interface{}, error) {
ctx, cancel := context.WithTimeout(r.Context(), defaultChannelTimeout)
defer cancel()
- targetVersion, err := channel.GetVersion(ctx)
+ targetVersion, err := h.autoUpdateAgentVersion(ctx, channelName, "" /* updater UUID */)
if err != nil {
// If the error is that the upstream channel has no version
// We gracefully handle by serving "none"
@@ -96,16 +90,20 @@ func (h *Handler) automaticUpgradesVersion109(w http.ResponseWriter, r *http.Req
return nil, trace.Wrap(err)
}
- _, err = w.Write([]byte(targetVersion))
+ // RFD 109 specifies that versions served from channels must have the leading "v".
+ // As h.autoUpdateAgentVersion doesn't, we must add it.
+ _, err = fmt.Fprintf(w, "v%s", targetVersion)
return nil, trace.Wrap(err)
}
// automaticUpgradesCritical109 handles criticality requests from upgraders
-func (h *Handler) automaticUpgradesCritical109(w http.ResponseWriter, r *http.Request, channel *automaticupgrades.Channel) (interface{}, error) {
+func (h *Handler) automaticUpgradesCritical109(w http.ResponseWriter, r *http.Request, channelName string) (interface{}, error) {
ctx, cancel := context.WithTimeout(r.Context(), defaultChannelTimeout)
defer cancel()
- critical, err := channel.GetCritical(ctx)
+ // RFD109 agents already retrieve maintenance windows from the CMC, no need to
+ // do a maintenance window lookup for them.
+ critical, err := h.autoUpdateAgentShouldUpdate(ctx, channelName, "" /* updater UUID */, false /* window lookup */)
if err != nil {
return nil, trace.Wrap(err)
}
diff --git a/lib/web/autoupdate_rfd184.go b/lib/web/autoupdate_rfd184.go
index 4c3ccdeaef907..6ac532650cb64 100644
--- a/lib/web/autoupdate_rfd184.go
+++ b/lib/web/autoupdate_rfd184.go
@@ -23,6 +23,7 @@ import (
"github.com/gravitational/trace"
+ "github.com/gravitational/teleport"
"github.com/gravitational/teleport/api"
"github.com/gravitational/teleport/api/client/webclient"
autoupdatepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/autoupdate/v1"
@@ -31,8 +32,8 @@ import (
// automaticUpdateSettings184 crafts the automatic updates part of the ping/find response
// as described in RFD-184 (agents) and RFD-144 (tools).
-// TODO: add the request as a parameter when we'll need to modulate the content based on the UUID and group
-func (h *Handler) automaticUpdateSettings184(ctx context.Context) webclient.AutoUpdateSettings {
+func (h *Handler) automaticUpdateSettings184(ctx context.Context, group, updaterUUID string) webclient.AutoUpdateSettings {
+ // Tools auto updates section.
autoUpdateConfig, err := h.cfg.AccessPoint.GetAutoUpdateConfig(ctx)
// TODO(vapopov) DELETE IN v18.0.0 check of IsNotImplemented, must be backported to all latest supported versions.
if err != nil && !trace.IsNotFound(err) && !trace.IsNotImplemented(err) {
@@ -45,12 +46,29 @@ func (h *Handler) automaticUpdateSettings184(ctx context.Context) webclient.Auto
h.logger.ErrorContext(ctx, "failed to receive AutoUpdateVersion", "error", err)
}
+ // Agent auto updates section.
+ agentVersion, err := h.autoUpdateAgentVersion(ctx, group, updaterUUID)
+ if err != nil {
+ h.logger.ErrorContext(ctx, "failed to resolve AgentVersion", "error", err)
+ // Defaulting to current version
+ agentVersion = teleport.Version
+ }
+ // If the source of truth is RFD 109 configuration (channels + CMC) we must emulate the
+ // RFD109 agent maintenance window behavior by looking up the CMC and checking if
+ // we are in a maintenance window.
+ shouldUpdate, err := h.autoUpdateAgentShouldUpdate(ctx, group, updaterUUID, true /* window lookup */)
+ if err != nil {
+ h.logger.ErrorContext(ctx, "failed to resolve AgentAutoUpdate", "error", err)
+ // Failing open
+ shouldUpdate = false
+ }
+
return webclient.AutoUpdateSettings{
ToolsAutoUpdate: getToolsAutoUpdate(autoUpdateConfig),
ToolsVersion: getToolsVersion(autoUpdateVersion),
AgentUpdateJitterSeconds: DefaultAgentUpdateJitterSeconds,
- AgentVersion: getAgentVersion184(autoUpdateVersion),
- AgentAutoUpdate: agentShouldUpdate184(autoUpdateConfig, autoUpdateVersion),
+ AgentVersion: agentVersion,
+ AgentAutoUpdate: shouldUpdate,
}
}
@@ -73,39 +91,3 @@ func getToolsVersion(version *autoupdatepb.AutoUpdateVersion) string {
}
return version.GetSpec().GetTools().GetTargetVersion()
}
-
-func getAgentVersion184(version *autoupdatepb.AutoUpdateVersion) string {
- // If we can't get the AU version or tools AU version is not specified, we default to the current proxy version.
- // This ensures we always advertise a version compatible with the cluster.
- // TODO: read the version from the autoupdate_agent_rollout when the resource is implemented
- if version.GetSpec().GetAgents() == nil {
- return api.Version
- }
-
- return version.GetSpec().GetAgents().GetTargetVersion()
-}
-
-func agentShouldUpdate184(config *autoupdatepb.AutoUpdateConfig, version *autoupdatepb.AutoUpdateVersion) bool {
- // TODO: read the data from the autoupdate_agent_rollout when the resource is implemented
-
- // If we can't get the AU config or if AUs are not configured, we default to "disabled".
- // This ensures we fail open and don't accidentally update agents if something is going wrong.
- // If we want to enable AUs by default, it would be better to create a default "autoupdate_config" resource
- // than changing this logic.
- if config.GetSpec().GetAgents() == nil {
- return false
- }
- if version.GetSpec().GetAgents() == nil {
- return false
- }
- configMode := config.GetSpec().GetAgents().GetMode()
- versionMode := version.GetSpec().GetAgents().GetMode()
-
- // We update only if both version and config agent modes are "enabled"
- if configMode != autoupdate.AgentsUpdateModeEnabled || versionMode != autoupdate.AgentsUpdateModeEnabled {
- return false
- }
-
- scheduleName := version.GetSpec().GetAgents().GetSchedule()
- return scheduleName == autoupdate.AgentsScheduleImmediate
-}