diff --git a/api/client/webclient/webclient.go b/api/client/webclient/webclient.go index f3b6ba5586768..d14ef610472f7 100644 --- a/api/client/webclient/webclient.go +++ b/api/client/webclient/webclient.go @@ -68,6 +68,9 @@ type Config struct { Timeout time.Duration // TraceProvider is used to retrieve a Tracer for creating spans TraceProvider oteltrace.TracerProvider + // UpdateGroup is used to modulate the webapi response based on the + // client's auto-update group. + UpdateGroup string } // CheckAndSetDefaults checks and sets defaults @@ -169,9 +172,18 @@ func Find(cfg *Config) (*PingResponse, error) { ctx, span := cfg.TraceProvider.Tracer("webclient").Start(cfg.Context, "webclient/Find") defer span.End() - endpoint := fmt.Sprintf("https://%s/webapi/find", cfg.ProxyAddr) + endpoint := url.URL{ + Scheme: "https", + Host: cfg.ProxyAddr, + Path: "/webapi/find", + } + if cfg.UpdateGroup != "" { + endpoint.RawQuery = url.Values{ + "group": []string{cfg.UpdateGroup}, + }.Encode() + } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) if err != nil { return nil, trace.Wrap(err) } @@ -205,12 +217,21 @@ func Ping(cfg *Config) (*PingResponse, error) { ctx, span := cfg.TraceProvider.Tracer("webclient").Start(cfg.Context, "webclient/Ping") defer span.End() - endpoint := fmt.Sprintf("https://%s/webapi/ping", cfg.ProxyAddr) + endpoint := url.URL{ + Scheme: "https", + Host: cfg.ProxyAddr, + Path: "/webapi/ping", + } + if cfg.UpdateGroup != "" { + endpoint.RawQuery = url.Values{ + "group": []string{cfg.UpdateGroup}, + }.Encode() + } if cfg.ConnectorName != "" { - endpoint = fmt.Sprintf("%s/%s", endpoint, cfg.ConnectorName) + endpoint.Path += "/" + cfg.ConnectorName } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/FIPS_and_Enterprise_flags.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/FIPS_and_Enterprise_flags.golden new file mode 100644 index 0000000000000..d9e09a2c95d71 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Enable/FIPS_and_Enterprise_flags.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + group: "" + url_template: "" + enabled: true +status: + active_version: 16.3.0 + backup_version: "" diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go index ade0704607cb9..7071f16e42d15 100644 --- a/lib/autoupdate/agent/updater.go +++ b/lib/autoupdate/agent/updater.go @@ -240,20 +240,26 @@ func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error { } desiredVersion := override.ForceVersion + var flags InstallFlags if desiredVersion == "" { resp, err := webclient.Find(&webclient.Config{ - Context: ctx, - ProxyAddr: addr.Addr, - Insecure: u.InsecureSkipVerify, - Timeout: 30 * time.Second, - //Group: cfg.Spec.Group, // TODO(sclevine): add web API for verssion - Pool: u.Pool, + Context: ctx, + ProxyAddr: addr.Addr, + Insecure: u.InsecureSkipVerify, + Timeout: 30 * time.Second, + UpdateGroup: cfg.Spec.Group, + Pool: u.Pool, }) if err != nil { return trace.Errorf("failed to request version from proxy: %w", err) } - desiredVersion, _ = "16.3.0", resp // TODO(sclevine): add web API for version - //desiredVersion := resp.AutoUpdate.AgentVersion + desiredVersion = resp.AutoUpdate.AgentVersion + if resp.Edition == "ent" { + flags |= FlagEnterprise + } + if resp.FIPS { + flags |= FlagFIPS + } } if desiredVersion == "" { @@ -277,7 +283,7 @@ func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error { if template == "" { template = cdnURITemplate } - err = u.Installer.Install(ctx, desiredVersion, template, 0) // TODO(sclevine): add web API for flags + err = u.Installer.Install(ctx, desiredVersion, template, flags) if err != nil { return trace.Errorf("failed to install: %w", err) } diff --git a/lib/autoupdate/agent/updater_test.go b/lib/autoupdate/agent/updater_test.go index d6d0128316c20..e817851fed1f7 100644 --- a/lib/autoupdate/agent/updater_test.go +++ b/lib/autoupdate/agent/updater_test.go @@ -20,6 +20,7 @@ package agent import ( "context" + "encoding/json" "errors" "net/http" "net/http/httptest" @@ -33,6 +34,7 @@ import ( "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" + "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/lib/utils/golden" ) @@ -129,10 +131,12 @@ func TestUpdater_Enable(t *testing.T) { cfg *UpdateConfig // nil -> file not present userCfg OverrideConfig installErr error + flags InstallFlags removedVersion string installedVersion string installedTemplate string + requestGroup string errMatch string }{ { @@ -150,6 +154,7 @@ func TestUpdater_Enable(t *testing.T) { }, installedVersion: "16.3.0", installedTemplate: "https://example.com", + requestGroup: "group", }, { name: "config from user", @@ -255,6 +260,12 @@ func TestUpdater_Enable(t *testing.T) { installedVersion: "16.3.0", installedTemplate: cdnURITemplate, }, + { + name: "FIPS and Enterprise flags", + flags: FlagEnterprise | FlagFIPS, + installedVersion: "16.3.0", + installedTemplate: cdnURITemplate, + }, { name: "invalid metadata", cfg: &UpdateConfig{}, @@ -276,9 +287,20 @@ func TestUpdater_Enable(t *testing.T) { require.NoError(t, err) } + var requestedGroup string server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // TODO(sclevine): add web API test including group verification - w.Write([]byte(`{}`)) + requestedGroup = r.URL.Query().Get("group") + config := webclient.PingResponse{ + AutoUpdate: webclient.AutoUpdateSettings{ + AgentVersion: "16.3.0", + }, + } + if tt.flags&FlagEnterprise != 0 { + config.Edition = "ent" + } + config.FIPS = tt.flags&FlagFIPS != 0 + err := json.NewEncoder(w).Encode(config) + require.NoError(t, err) })) t.Cleanup(server.Close) @@ -297,11 +319,13 @@ func TestUpdater_Enable(t *testing.T) { installedTemplate string linkedVersion string removedVersion string + installedFlags InstallFlags ) updater.Installer = &testInstaller{ - FuncInstall: func(_ context.Context, version, template string, _ InstallFlags) error { + FuncInstall: func(_ context.Context, version, template string, flags InstallFlags) error { installedVersion = version installedTemplate = template + installedFlags = flags return tt.installErr }, FuncLink: func(_ context.Context, version string) error { @@ -329,6 +353,8 @@ func TestUpdater_Enable(t *testing.T) { require.Equal(t, tt.installedTemplate, installedTemplate) require.Equal(t, tt.installedVersion, linkedVersion) require.Equal(t, tt.removedVersion, removedVersion) + require.Equal(t, tt.flags, installedFlags) + require.Equal(t, tt.requestGroup, requestedGroup) data, err := os.ReadFile(cfgPath) require.NoError(t, err)