diff --git a/integration/autoupdate/tools/helper_test.go b/integration/autoupdate/tools/helper_test.go new file mode 100644 index 0000000000000..a3c37a9e94b55 --- /dev/null +++ b/integration/autoupdate/tools/helper_test.go @@ -0,0 +1,89 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tools_test + +import ( + "net/http" + "sync" +) + +type limitRequest struct { + limit int64 + lock chan struct{} +} + +// limitedResponseWriter wraps http.ResponseWriter and enforces a write limit +// then block the response until signal is received. +type limitedResponseWriter struct { + requests chan limitRequest +} + +// newLimitedResponseWriter creates a new limitedResponseWriter with the lock. +func newLimitedResponseWriter() *limitedResponseWriter { + lw := &limitedResponseWriter{ + requests: make(chan limitRequest, 10), + } + return lw +} + +// Wrap wraps response writer if limit was previously requested, if not, return original one. +func (lw *limitedResponseWriter) Wrap(w http.ResponseWriter) http.ResponseWriter { + select { + case request := <-lw.requests: + return &wrapper{ + ResponseWriter: w, + request: request, + } + default: + return w + } +} + +// SetLimitRequest sends limit request to the pool to wrap next response writer with defined limits. +func (lw *limitedResponseWriter) SetLimitRequest(limit limitRequest) { + lw.requests <- limit +} + +// wrapper wraps the http response writer to control writing operation by blocking it. +type wrapper struct { + http.ResponseWriter + + written int64 + request limitRequest + released bool + + mutex sync.Mutex +} + +// Write writes data to the underlying ResponseWriter but respects the byte limit. +func (lw *wrapper) Write(p []byte) (int, error) { + lw.mutex.Lock() + defer lw.mutex.Unlock() + + if lw.written >= lw.request.limit && !lw.released { + // Send signal that lock is acquired and wait till it was released by response. + lw.request.lock <- struct{}{} + <-lw.request.lock + lw.released = true + } + + n, err := lw.ResponseWriter.Write(p) + lw.written += int64(n) + return n, err +} diff --git a/integration/autoupdate/tools/helper_unix_test.go b/integration/autoupdate/tools/helper_unix_test.go new file mode 100644 index 0000000000000..61ba0766b90d4 --- /dev/null +++ b/integration/autoupdate/tools/helper_unix_test.go @@ -0,0 +1,37 @@ +//go:build !windows + +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tools_test + +import ( + "errors" + "syscall" + + "github.com/gravitational/trace" +) + +// sendInterrupt sends a SIGINT to the process. +func sendInterrupt(pid int) error { + err := syscall.Kill(pid, syscall.SIGINT) + if errors.Is(err, syscall.ESRCH) { + return trace.BadParameter("can't find the process: %v", pid) + } + return trace.Wrap(err) +} diff --git a/integration/autoupdate/tools/helper_windows_test.go b/integration/autoupdate/tools/helper_windows_test.go new file mode 100644 index 0000000000000..b2ede9ade8c19 --- /dev/null +++ b/integration/autoupdate/tools/helper_windows_test.go @@ -0,0 +1,42 @@ +//go:build windows + +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tools_test + +import ( + "syscall" + + "github.com/gravitational/trace" + "golang.org/x/sys/windows" +) + +var ( + kernel = windows.NewLazyDLL("kernel32.dll") + ctrlEvent = kernel.NewProc("GenerateConsoleCtrlEvent") +) + +// sendInterrupt sends a Ctrl-Break event to the process. +func sendInterrupt(pid int) error { + r, _, err := ctrlEvent.Call(uintptr(syscall.CTRL_BREAK_EVENT), uintptr(pid)) + if r == 0 { + return trace.Wrap(err) + } + return nil +} diff --git a/integration/autoupdate/tools/main_test.go b/integration/autoupdate/tools/main_test.go new file mode 100644 index 0000000000000..a14a6dc9fc683 --- /dev/null +++ b/integration/autoupdate/tools/main_test.go @@ -0,0 +1,173 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tools_test + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/integration/helpers" +) + +const ( + testBinaryName = "updater" + teleportToolsVersion = "TELEPORT_TOOLS_VERSION" +) + +var ( + // testVersions list of the pre-compiled binaries with encoded versions to check. + testVersions = []string{ + "1.2.3", + "3.2.1", + } + limitedWriter = newLimitedResponseWriter() + + toolsDir string + baseURL string +) + +func TestMain(m *testing.M) { + ctx := context.Background() + tmp, err := os.MkdirTemp(os.TempDir(), testBinaryName) + if err != nil { + log.Fatalf("failed to create temporary directory: %v", err) + } + + toolsDir, err = os.MkdirTemp(os.TempDir(), "tools") + if err != nil { + log.Fatalf("failed to create temporary directory: %v", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + filePath := filepath.Join(tmp, r.URL.Path) + switch { + case strings.HasSuffix(r.URL.Path, ".sha256"): + serve256File(w, r, strings.TrimSuffix(filePath, ".sha256")) + default: + http.ServeFile(limitedWriter.Wrap(w), r, filePath) + } + })) + baseURL = server.URL + for _, version := range testVersions { + if err := buildAndArchiveApps(ctx, tmp, toolsDir, version, server.URL); err != nil { + log.Fatalf("failed to build testing app binary archive: %v", err) + } + } + + // Run tests after binary is built. + code := m.Run() + + server.Close() + if err := os.RemoveAll(tmp); err != nil { + log.Fatalf("failed to remove temporary directory: %v", err) + } + if err := os.RemoveAll(toolsDir); err != nil { + log.Fatalf("failed to remove tools directory: %v", err) + } + + os.Exit(code) +} + +// serve256File calculates sha256 checksum for requested file. +func serve256File(w http.ResponseWriter, _ *http.Request, filePath string) { + log.Printf("Calculating and serving file checksum: %s\n", filePath) + + w.Header().Set("Content-Disposition", "attachment; filename=\""+filepath.Base(filePath)+".sha256\"") + w.Header().Set("Content-Type", "plain/text") + + file, err := os.Open(filePath) + if errors.Is(err, os.ErrNotExist) { + http.Error(w, "file not found", http.StatusNotFound) + return + } + if err != nil { + http.Error(w, "failed to open file", http.StatusInternalServerError) + return + } + defer file.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, file); err != nil { + http.Error(w, "failed to write to hash", http.StatusInternalServerError) + return + } + if _, err := hex.NewEncoder(w).Write(hash.Sum(nil)); err != nil { + http.Error(w, "failed to write checksum", http.StatusInternalServerError) + } +} + +// buildAndArchiveApps compiles the updater integration and pack it depends on platform is used. +func buildAndArchiveApps(ctx context.Context, path string, toolsDir string, version string, baseURL string) error { + versionPath := filepath.Join(path, version) + for _, app := range []string{"tsh", "tctl"} { + output := filepath.Join(versionPath, app) + switch runtime.GOOS { + case "windows": + output = filepath.Join(versionPath, app+".exe") + case "darwin": + output = filepath.Join(versionPath, app+".app", "Contents", "MacOS", app) + } + if err := buildBinary(output, toolsDir, version, baseURL); err != nil { + return trace.Wrap(err) + } + } + switch runtime.GOOS { + case "darwin": + archivePath := filepath.Join(path, fmt.Sprintf("teleport-%s.pkg", version)) + return trace.Wrap(helpers.CompressDirToPkgFile(ctx, versionPath, archivePath, "com.example.pkgtest")) + case "windows": + archivePath := filepath.Join(path, fmt.Sprintf("teleport-v%s-windows-amd64-bin.zip", version)) + return trace.Wrap(helpers.CompressDirToZipFile(ctx, versionPath, archivePath)) + default: + archivePath := filepath.Join(path, fmt.Sprintf("teleport-v%s-linux-%s-bin.tar.gz", version, runtime.GOARCH)) + return trace.Wrap(helpers.CompressDirToTarGzFile(ctx, versionPath, archivePath)) + } +} + +// buildBinary executes command to build binary with updater logic only for testing. +func buildBinary(output string, toolsDir string, version string, baseURL string) error { + cmd := exec.Command( + "go", "build", "-o", output, + "-ldflags", strings.Join([]string{ + fmt.Sprintf("-X 'main.toolsDir=%s'", toolsDir), + fmt.Sprintf("-X 'main.version=%s'", version), + fmt.Sprintf("-X 'main.baseURL=%s'", baseURL), + }, " "), + "./updater", + ) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + return trace.Wrap(cmd.Run()) +} diff --git a/integration/autoupdate/tools/updater/main.go b/integration/autoupdate/tools/updater/main.go new file mode 100644 index 0000000000000..e14c76e5d5aa8 --- /dev/null +++ b/integration/autoupdate/tools/updater/main.go @@ -0,0 +1,88 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package main + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "os/signal" + "runtime" + "syscall" + "time" + + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/lib/autoupdate/tools" +) + +var ( + version = "development" + baseURL = "http://localhost" + toolsDir = "" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + ctx, _ = signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) + + updater := tools.NewUpdater( + clientTools(), + toolsDir, + version, + tools.WithBaseURL(baseURL), + ) + toolsVersion, reExec := updater.CheckLocal() + if reExec { + // Download and update the version of client tools required by the cluster. + // This is required if the user passed in the TELEPORT_TOOLS_VERSION explicitly. + err := updater.UpdateWithLock(ctx, toolsVersion) + if errors.Is(err, context.Canceled) { + os.Exit(0) + return + } + if err != nil { + log.Fatalf("failed to download version (%v): %v\n", toolsVersion, err) + return + } + + // Re-execute client tools with the correct version of client tools. + code, err := updater.Exec() + if err != nil { + log.Fatalf("Failed to re-exec client tool: %v\n", err) + } else { + os.Exit(code) + } + } + if len(os.Args) > 1 && os.Args[1] == "version" { + fmt.Printf("Teleport v%v git\n", version) + } +} + +// clientTools list of the client tools needs to be updated. +func clientTools() []string { + switch runtime.GOOS { + case constants.WindowsOS: + return []string{"tsh.exe", "tctl.exe"} + default: + return []string{"tsh", "tctl"} + } +} diff --git a/integration/autoupdate/tools/updater_test.go b/integration/autoupdate/tools/updater_test.go new file mode 100644 index 0000000000000..96d5486462067 --- /dev/null +++ b/integration/autoupdate/tools/updater_test.go @@ -0,0 +1,231 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tools_test + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/lib/autoupdate/tools" +) + +var ( + // pattern is template for response on version command for client tools {tsh, tctl}. + pattern = regexp.MustCompile(`(?m)Teleport v(.*) git`) +) + +// TestUpdate verifies the basic update logic. We first download a lower version, then request +// an update to a newer version, expecting it to re-execute with the updated version. +func TestUpdate(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Fetch compiled test binary with updater logic and install to $TELEPORT_HOME. + updater := tools.NewUpdater( + clientTools(), + toolsDir, + testVersions[0], + tools.WithBaseURL(baseURL), + ) + err := updater.Update(ctx, testVersions[0]) + require.NoError(t, err) + + // Verify that the installed version is equal to requested one. + cmd := exec.CommandContext(ctx, filepath.Join(toolsDir, "tsh"), "version") + out, err := cmd.Output() + require.NoError(t, err) + + matches := pattern.FindStringSubmatch(string(out)) + require.Len(t, matches, 2) + require.Equal(t, testVersions[0], matches[1]) + + // Execute version command again with setting the new version which must + // trigger re-execution of the same command after downloading requested version. + cmd = exec.CommandContext(ctx, filepath.Join(toolsDir, "tsh"), "version") + cmd.Env = append( + os.Environ(), + fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]), + ) + out, err = cmd.Output() + require.NoError(t, err) + + matches = pattern.FindStringSubmatch(string(out)) + require.Len(t, matches, 2) + require.Equal(t, testVersions[1], matches[1]) +} + +// TestParallelUpdate launches multiple updater commands in parallel while defining a new version. +// The first process should acquire a lock and block execution for the other processes. After the +// first update is complete, other processes should acquire the lock one by one and re-execute +// the command with the updated version without any new downloads. +func TestParallelUpdate(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Initial fetch the updater binary un-archive and replace. + updater := tools.NewUpdater( + clientTools(), + toolsDir, + testVersions[0], + tools.WithBaseURL(baseURL), + ) + err := updater.Update(ctx, testVersions[0]) + require.NoError(t, err) + + // By setting the limit request next test http serving file going blocked until unlock is sent. + lock := make(chan struct{}) + limitedWriter.SetLimitRequest(limitRequest{ + limit: 1024, + lock: lock, + }) + + outputs := make([]bytes.Buffer, 3) + errChan := make(chan error, 3) + for i := 0; i < len(outputs); i++ { + cmd := exec.Command(filepath.Join(toolsDir, "tsh"), "version") + cmd.Stdout = &outputs[i] + cmd.Stderr = &outputs[i] + cmd.Env = append( + os.Environ(), + fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]), + ) + err = cmd.Start() + require.NoError(t, err, "failed to start updater") + + go func(cmd *exec.Cmd) { + errChan <- cmd.Wait() + }(cmd) + } + + select { + case err := <-errChan: + require.Fail(t, "we shouldn't receive any error", err) + case <-time.After(5 * time.Second): + require.Fail(t, "failed to wait till the download is started") + case <-lock: + // Wait for a short period to allow other processes to launch and attempt to acquire the lock. + time.Sleep(100 * time.Millisecond) + lock <- struct{}{} + } + + // Wait till process finished with exit code 0, but we still should get progress + // bar in output content. + for i := 0; i < cap(outputs); i++ { + select { + case <-time.After(5 * time.Second): + require.Fail(t, "failed to wait till the process is finished") + case err := <-errChan: + require.NoError(t, err) + } + } + + var progressCount int + for i := 0; i < cap(outputs); i++ { + matches := pattern.FindStringSubmatch(outputs[i].String()) + require.Len(t, matches, 2) + assert.Equal(t, testVersions[1], matches[1]) + if strings.Contains(outputs[i].String(), "Update progress:") { + progressCount++ + } + } + assert.Equal(t, 1, progressCount, "we should have only one progress bar downloading new version") +} + +// TestUpdateInterruptSignal verifies the interrupt signal send to the process must stop downloading. +func TestUpdateInterruptSignal(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Initial fetch the updater binary un-archive and replace. + updater := tools.NewUpdater( + clientTools(), + toolsDir, + testVersions[0], + tools.WithBaseURL(baseURL), + ) + err := updater.Update(ctx, testVersions[0]) + require.NoError(t, err) + + var output bytes.Buffer + cmd := exec.Command(filepath.Join(toolsDir, "tsh"), "version") + cmd.Stdout = &output + cmd.Stderr = &output + cmd.Env = append( + os.Environ(), + fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]), + ) + err = cmd.Start() + require.NoError(t, err, "failed to start updater") + pid := cmd.Process.Pid + + errChan := make(chan error) + go func() { + errChan <- cmd.Wait() + }() + + // By setting the limit request next test http serving file going blocked until unlock is sent. + lock := make(chan struct{}) + limitedWriter.SetLimitRequest(limitRequest{ + limit: 1024, + lock: lock, + }) + + select { + case err := <-errChan: + require.Fail(t, "we shouldn't receive any error", err) + case <-time.After(5 * time.Second): + require.Fail(t, "failed to wait till the download is started") + case <-lock: + time.Sleep(100 * time.Millisecond) + require.NoError(t, sendInterrupt(pid)) + lock <- struct{}{} + } + + // Wait till process finished with exit code 0, but we still should get progress + // bar in output content. + select { + case <-time.After(5 * time.Second): + require.Fail(t, "failed to wait till the process interrupted") + case err := <-errChan: + require.NoError(t, err) + } + assert.Contains(t, output.String(), "Update progress:") +} + +func clientTools() []string { + switch runtime.GOOS { + case constants.WindowsOS: + return []string{"tsh.exe", "tctl.exe"} + default: + return []string{"tsh", "tctl"} + } +} diff --git a/lib/autoupdate/tools/progress.go b/lib/autoupdate/tools/progress.go new file mode 100644 index 0000000000000..95395003730ec --- /dev/null +++ b/lib/autoupdate/tools/progress.go @@ -0,0 +1,43 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tools + +import ( + "fmt" + "strings" +) + +type progressWriter struct { + n int64 + limit int64 +} + +func (w *progressWriter) Write(p []byte) (int, error) { + w.n = w.n + int64(len(p)) + + n := int((w.n*100)/w.limit) / 10 + bricks := strings.Repeat("▒", n) + strings.Repeat(" ", 10-n) + fmt.Print("\rUpdate progress: [" + bricks + "] (Ctrl-C to cancel update)") + + if w.n == w.limit { + fmt.Print("\n") + } + + return len(p), nil +} diff --git a/lib/autoupdate/tools/updater.go b/lib/autoupdate/tools/updater.go new file mode 100644 index 0000000000000..96991044ccc31 --- /dev/null +++ b/lib/autoupdate/tools/updater.go @@ -0,0 +1,400 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tools + +import ( + "bytes" + "context" + "crypto/sha256" + "crypto/x509" + "encoding/hex" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "syscall" + "time" + + "github.com/google/uuid" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/types/autoupdate" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/packaging" +) + +const ( + // teleportToolsVersionEnv is environment name for requesting specific version for update. + teleportToolsVersionEnv = "TELEPORT_TOOLS_VERSION" + // baseURL is CDN URL for downloading official Teleport packages. + baseURL = "https://cdn.teleport.dev" + // reservedFreeDisk is the predefined amount of free disk space (in bytes) required + // to remain available after downloading archives. + reservedFreeDisk = 10 * 1024 * 1024 // 10 Mb + // lockFileName is file used for locking update process in parallel. + lockFileName = ".lock" + // updatePackageSuffix is directory suffix used for package extraction in tools directory. + updatePackageSuffix = "-update-pkg" +) + +var ( + // // pattern is template for response on version command for client tools {tsh, tctl}. + pattern = regexp.MustCompile(`(?m)Teleport v(.*) git`) +) + +// Option applies an option value for the Updater. +type Option func(u *Updater) + +// WithBaseURL defines custom base url for the updater. +func WithBaseURL(baseURL string) Option { + return func(u *Updater) { + u.baseURL = baseURL + } +} + +// WithClient defines custom http client for the Updater. +func WithClient(client *http.Client) Option { + return func(u *Updater) { + u.client = client + } +} + +// Updater is updater implementation for the client tools auto updates. +type Updater struct { + toolsDir string + localVersion string + tools []string + + baseURL string + client *http.Client +} + +// NewUpdater initializes the updater for client tools auto updates. We need to specify the list +// of tools (e.g., `tsh`, `tctl`) that should be updated, the tools directory path where we +// download, extract package archives with the new version, and replace symlinks (e.g., `$TELEPORT_HOME/bin`). +// The base URL of the CDN with Teleport packages and the `http.Client` can be customized via options. +func NewUpdater(tools []string, toolsDir string, localVersion string, options ...Option) *Updater { + updater := &Updater{ + tools: tools, + toolsDir: toolsDir, + localVersion: localVersion, + baseURL: baseURL, + client: http.DefaultClient, + } + for _, option := range options { + option(updater) + } + + return updater +} + +// CheckLocal is run at client tool startup and will only perform local checks. +// Returns the version needs to be updated and re-executed, by re-execution flag we +// understand that update and re-execute is required. +func (u *Updater) CheckLocal() (version string, reExec bool) { + // Check if the user has requested a specific version of client tools. + requestedVersion := os.Getenv(teleportToolsVersionEnv) + switch requestedVersion { + // The user has turned off any form of automatic updates. + case "off": + return "", false + // Requested version already the same as client version. + case u.localVersion: + return u.localVersion, false + } + + // If a version of client tools has already been downloaded to + // tools directory, return that. + toolsVersion, err := checkToolVersion(u.toolsDir) + if err != nil { + return "", false + } + // The user has requested a specific version of client tools. + if requestedVersion != "" && requestedVersion != toolsVersion { + return requestedVersion, true + } + + return toolsVersion, false +} + +// CheckRemote first checks the version set by the environment variable. If not set or disabled, +// it checks against the Proxy Service to determine if client tools need updating by requesting +// the `webapi/find` handler, which stores information about the required client tools version to +// operate with this cluster. It returns the semantic version that needs updating and whether +// re-execution is necessary, by re-execution flag we understand that update and re-execute is required. +func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string) (version string, reExec bool, err error) { + // Check if the user has requested a specific version of client tools. + requestedVersion := os.Getenv(teleportToolsVersionEnv) + switch requestedVersion { + // The user has turned off any form of automatic updates. + case "off": + return "", false, nil + // Requested version already the same as client version. + case u.localVersion: + return u.localVersion, false, nil + } + + certPool, err := x509.SystemCertPool() + if err != nil { + return "", false, trace.Wrap(err) + } + resp, err := webclient.Find(&webclient.Config{ + Context: ctx, + ProxyAddr: proxyAddr, + Pool: certPool, + Timeout: 30 * time.Second, + }) + if err != nil { + return "", false, trace.Wrap(err) + } + + // If a version of client tools has already been downloaded to + // tools directory, return that. + toolsVersion, err := checkToolVersion(u.toolsDir) + if err != nil { + return "", false, trace.Wrap(err) + } + + switch { + case requestedVersion != "" && requestedVersion != toolsVersion: + return requestedVersion, true, nil + case resp.AutoUpdate.ToolsMode != autoupdate.ToolsUpdateModeEnabled || resp.AutoUpdate.ToolsVersion == "": + return "", false, nil + case u.localVersion == resp.AutoUpdate.ToolsVersion: + return resp.AutoUpdate.ToolsVersion, false, nil + case resp.AutoUpdate.ToolsVersion != toolsVersion: + return resp.AutoUpdate.ToolsVersion, true, nil + } + + return toolsVersion, false, nil +} + +// UpdateWithLock acquires filesystem lock, downloads requested version package, +// unarchive and replace existing one. +func (u *Updater) UpdateWithLock(ctx context.Context, toolsVersion string) (err error) { + // Create tools directory if it does not exist. + if err := os.MkdirAll(u.toolsDir, 0o755); err != nil { + return trace.Wrap(err) + } + // Lock concurrent client tools execution util requested version is updated. + unlock, err := utils.FSWriteLock(filepath.Join(u.toolsDir, lockFileName)) + if err != nil { + return trace.Wrap(err) + } + defer func() { + err = trace.NewAggregate(err, unlock()) + }() + + // If the version of the running binary or the version downloaded to + // tools directory is the same as the requested version of client tools, + // nothing to be done, exit early. + teleportVersion, err := checkToolVersion(u.toolsDir) + if err != nil && !trace.IsNotFound(err) { + return trace.Wrap(err) + + } + if toolsVersion == u.localVersion || toolsVersion == teleportVersion { + return nil + } + + // Download and update client tools in tools directory. + if err := u.Update(ctx, toolsVersion); err != nil { + return trace.Wrap(err) + } + + return +} + +// Update downloads requested version and replace it with existing one and cleanups the previous downloads +// with defined updater directory suffix. +func (u *Updater) Update(ctx context.Context, toolsVersion string) error { + // Get platform specific download URLs. + packages, err := teleportPackageURLs(u.baseURL, toolsVersion) + if err != nil { + return trace.Wrap(err) + } + + for _, pkg := range packages { + if err := u.update(ctx, pkg); err != nil { + return trace.Wrap(err) + } + } + + return nil +} + +// update downloads the archive and validate against the hash. Download to a +// temporary path within tools directory. +func (u *Updater) update(ctx context.Context, pkg packageURL) error { + hash, err := u.downloadHash(ctx, pkg.Hash) + if pkg.Optional && trace.IsNotFound(err) { + return nil + } + if err != nil { + return trace.Wrap(err) + } + + f, err := os.CreateTemp(u.toolsDir, "tmp-") + if err != nil { + return trace.Wrap(err) + } + defer func() { + _ = f.Close() + if err := os.Remove(f.Name()); err != nil { + slog.WarnContext(ctx, "failed to remove temporary archive file", "error", err) + } + }() + + archiveHash, err := u.downloadArchive(ctx, pkg.Archive, f) + if pkg.Optional && trace.IsNotFound(err) { + return nil + } + if err != nil { + return trace.Wrap(err) + } + if !bytes.Equal(archiveHash, hash) { + return trace.BadParameter("hash of archive does not match downloaded archive") + } + + pkgName := fmt.Sprint(uuid.New().String(), updatePackageSuffix) + extractDir := filepath.Join(u.toolsDir, pkgName) + if runtime.GOOS != constants.DarwinOS { + if err := os.Mkdir(extractDir, 0o755); err != nil { + return trace.Wrap(err) + } + } + + // Perform atomic replace so concurrent exec do not fail. + if err := packaging.ReplaceToolsBinaries(u.toolsDir, f.Name(), extractDir, u.tools); err != nil { + return trace.Wrap(err) + } + // Cleanup the tools directory with previously downloaded and un-archived versions. + if err := packaging.RemoveWithSuffix(u.toolsDir, updatePackageSuffix, pkgName); err != nil { + return trace.Wrap(err) + } + + return nil +} + +// Exec re-executes tool command with same arguments and environ variables. +func (u *Updater) Exec() (int, error) { + path, err := toolName(u.toolsDir) + if err != nil { + return 0, trace.Wrap(err) + } + // To prevent re-execution loop we have to disable update logic for re-execution. + env := append(os.Environ(), teleportToolsVersionEnv+"=off") + + if runtime.GOOS == constants.WindowsOS { + cmd := exec.Command(path, os.Args[1:]...) + cmd.Env = env + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return 0, trace.Wrap(err) + } + + return cmd.ProcessState.ExitCode(), nil + } + + if err := syscall.Exec(path, append([]string{path}, os.Args[1:]...), env); err != nil { + return 0, trace.Wrap(err) + } + + return 0, nil +} + +// downloadHash downloads the hash file `.sha256` for package checksum validation and return the hash sum. +func (u *Updater) downloadHash(ctx context.Context, url string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, trace.Wrap(err) + } + resp, err := u.client.Do(req) + if err != nil { + return nil, trace.Wrap(err) + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusNotFound { + return nil, trace.NotFound("hash file is not found: %v", resp.StatusCode) + } + if resp.StatusCode != http.StatusOK { + return nil, trace.BadParameter("bad status when downloading archive hash: %v", resp.StatusCode) + } + + var buf bytes.Buffer + _, err = io.CopyN(&buf, resp.Body, sha256.Size*2) // SHA bytes to hex + if err != nil { + return nil, trace.Wrap(err) + } + hexBytes, err := hex.DecodeString(buf.String()) + if err != nil { + return nil, trace.Wrap(err) + } + + return hexBytes, nil +} + +// downloadArchive downloads the archive package by `url` and writes content to the writer interface, +// return calculated sha256 hash sum of the content. +func (u *Updater) downloadArchive(ctx context.Context, url string, f io.Writer) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, trace.Wrap(err) + } + resp, err := u.client.Do(req) + if err != nil { + return nil, trace.Wrap(err) + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusNotFound { + return nil, trace.NotFound("archive file is not found: %v", resp.StatusCode) + } + if resp.StatusCode != http.StatusOK { + return nil, trace.BadParameter("bad status when downloading archive: %v", resp.StatusCode) + } + + if resp.ContentLength != -1 { + if err := checkFreeSpace(u.toolsDir, uint64(resp.ContentLength)); err != nil { + return nil, trace.Wrap(err) + } + } + + h := sha256.New() + pw := &progressWriter{n: 0, limit: resp.ContentLength} + body := io.TeeReader(io.TeeReader(resp.Body, h), pw) + + // It is a little inefficient to download the file to disk and then re-load + // it into memory to unarchive later, but this is safer as it allows client + // tools to validate the hash before trying to operate on the archive. + _, err = io.CopyN(f, body, resp.ContentLength) + if err != nil { + return nil, trace.Wrap(err) + } + + return h.Sum(nil), nil +} diff --git a/lib/autoupdate/tools/utils.go b/lib/autoupdate/tools/utils.go new file mode 100644 index 0000000000000..d552b31abefe4 --- /dev/null +++ b/lib/autoupdate/tools/utils.go @@ -0,0 +1,175 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tools + +import ( + "bufio" + "bytes" + "context" + "errors" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/coreos/go-semver/semver" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/utils" +) + +// Dir returns the path to client tools in $TELEPORT_HOME/bin. +func Dir() (string, error) { + home := os.Getenv(types.HomeEnvVar) + if home == "" { + var err error + home, err = os.UserHomeDir() + if err != nil { + return "", trace.Wrap(err) + } + } + + return filepath.Join(home, ".tsh", "bin"), nil +} + +func checkToolVersion(toolsDir string) (string, error) { + // Find the path to the current executable. + path, err := toolName(toolsDir) + if err != nil { + return "", trace.Wrap(err) + } + if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) { + return "", nil + } else if err != nil { + return "", trace.Wrap(err) + } + + // Set a timeout to not let "{tsh, tctl} version" block forever. Allow up + // to 10 seconds because sometimes MDM tools like Jamf cause a lot of + // latency in launching binaries. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Execute "{tsh, tctl} version" and pass in TELEPORT_TOOLS_VERSION=off to + // turn off all automatic updates code paths to prevent any recursion. + command := exec.CommandContext(ctx, path, "version") + command.Env = []string{teleportToolsVersionEnv + "=off"} + output, err := command.Output() + if err != nil { + return "", trace.Wrap(err) + } + + // The output for "{tsh, tctl} version" can be multiple lines. Find the + // actual version line and extract the version. + scanner := bufio.NewScanner(bytes.NewReader(output)) + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "Teleport") { + continue + } + + matches := pattern.FindStringSubmatch(line) + if len(matches) != 2 { + return "", trace.BadParameter("invalid version line: %v", line) + } + version, err := semver.NewVersion(matches[1]) + if err != nil { + return "", trace.Wrap(err) + } + return version.String(), nil + } + + return "", trace.BadParameter("unable to determine version") +} + +// packageURL defines URLs to the archive and their archive sha256 hash file, and marks +// if this package is optional, for such case download needs to be ignored if package +// not found in CDN. +type packageURL struct { + Archive string + Hash string + Optional bool +} + +// teleportPackageURLs returns the URL for the Teleport archive to download. The format is: +// https://cdn.teleport.dev/teleport-{, ent-}v15.3.0-{linux, darwin, windows}-{amd64,arm64,arm,386}-{fips-}bin.tar.gz +func teleportPackageURLs(baseURL, toolsVersion string) ([]packageURL, error) { + switch runtime.GOOS { + case "darwin": + tsh := baseURL + "/tsh-" + toolsVersion + ".pkg" + teleport := baseURL + "/teleport-" + toolsVersion + ".pkg" + return []packageURL{ + {Archive: teleport, Hash: teleport + ".sha256"}, + {Archive: tsh, Hash: tsh + ".sha256", Optional: true}, + }, nil + case "windows": + archive := baseURL + "/teleport-v" + toolsVersion + "-windows-amd64-bin.zip" + return []packageURL{ + {Archive: archive, Hash: archive + ".sha256"}, + }, nil + case "linux": + m := modules.GetModules() + var b strings.Builder + b.WriteString(baseURL + "/teleport-") + if m.IsEnterpriseBuild() || m.IsBoringBinary() { + b.WriteString("ent-") + } + b.WriteString("v" + toolsVersion + "-" + runtime.GOOS + "-" + runtime.GOARCH + "-") + if m.IsBoringBinary() { + b.WriteString("fips-") + } + b.WriteString("bin.tar.gz") + archive := b.String() + return []packageURL{ + {Archive: archive, Hash: archive + ".sha256"}, + }, nil + default: + return nil, trace.BadParameter("unsupported runtime: %v", runtime.GOOS) + } +} + +// toolName returns the path to {tsh, tctl} for the executable that started +// the current process. +func toolName(toolsDir string) (string, error) { + executablePath, err := os.Executable() + if err != nil { + return "", trace.Wrap(err) + } + + return filepath.Join(toolsDir, filepath.Base(executablePath)), nil +} + +// checkFreeSpace verifies that we have enough requested space at specific directory. +func checkFreeSpace(path string, requested uint64) error { + free, err := utils.FreeDiskWithReserve(path, reservedFreeDisk) + if err != nil { + return trace.Errorf("failed to calculate free disk in %q: %v", path, err) + } + // Bail if there's not enough free disk space at the target. + if requested > free { + return trace.Errorf("%q needs %d additional bytes of disk space", path, requested-free) + } + + return nil +} diff --git a/lib/utils/disk.go b/lib/utils/disk.go index b3553c873aa3e..782c598d26229 100644 --- a/lib/utils/disk.go +++ b/lib/utils/disk.go @@ -53,11 +53,8 @@ func FreeDiskWithReserve(dir string, reservedFreeDisk uint64) (uint64, error) { if err != nil { return 0, trace.Wrap(err) } - //nolint:staticcheck // SA4003. False positive on macOS. - if stat.Bsize < 0 { - return 0, trace.Errorf("invalid size") - } - avail := stat.Bavail * uint64(stat.Bsize) + //nolint:unconvert // The cast is only necessary for linux platform. + avail := uint64(stat.Bavail) * uint64(stat.Bsize) if reservedFreeDisk > avail { return 0, trace.Errorf("no free space left") } diff --git a/lib/utils/packaging/unarchive_unix.go b/lib/utils/packaging/unarchive_unix.go index 3be7d0c473ef9..ea51afdbbc7f0 100644 --- a/lib/utils/packaging/unarchive_unix.go +++ b/lib/utils/packaging/unarchive_unix.go @@ -33,6 +33,8 @@ import ( "github.com/google/renameio/v2" "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/constants" ) // ReplaceToolsBinaries extracts executables specified by execNames from archivePath into @@ -43,7 +45,7 @@ import ( // For other POSIX, archivePath must be a gzipped tarball. func ReplaceToolsBinaries(toolsDir string, archivePath string, extractDir string, execNames []string) error { switch runtime.GOOS { - case "darwin": + case constants.DarwinOS: return replacePkg(toolsDir, archivePath, extractDir, execNames) default: return replaceTarGz(toolsDir, archivePath, extractDir, execNames)