diff --git a/integration/autoupdate/tools/helper_test.go b/integration/autoupdate/tools/helper_test.go
new file mode 100644
index 0000000000000..a3c37a9e94b55
--- /dev/null
+++ b/integration/autoupdate/tools/helper_test.go
@@ -0,0 +1,89 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package tools_test
+
+import (
+ "net/http"
+ "sync"
+)
+
+type limitRequest struct {
+ limit int64
+ lock chan struct{}
+}
+
+// limitedResponseWriter wraps http.ResponseWriter and enforces a write limit
+// then block the response until signal is received.
+type limitedResponseWriter struct {
+ requests chan limitRequest
+}
+
+// newLimitedResponseWriter creates a new limitedResponseWriter with the lock.
+func newLimitedResponseWriter() *limitedResponseWriter {
+ lw := &limitedResponseWriter{
+ requests: make(chan limitRequest, 10),
+ }
+ return lw
+}
+
+// Wrap wraps response writer if limit was previously requested, if not, return original one.
+func (lw *limitedResponseWriter) Wrap(w http.ResponseWriter) http.ResponseWriter {
+ select {
+ case request := <-lw.requests:
+ return &wrapper{
+ ResponseWriter: w,
+ request: request,
+ }
+ default:
+ return w
+ }
+}
+
+// SetLimitRequest sends limit request to the pool to wrap next response writer with defined limits.
+func (lw *limitedResponseWriter) SetLimitRequest(limit limitRequest) {
+ lw.requests <- limit
+}
+
+// wrapper wraps the http response writer to control writing operation by blocking it.
+type wrapper struct {
+ http.ResponseWriter
+
+ written int64
+ request limitRequest
+ released bool
+
+ mutex sync.Mutex
+}
+
+// Write writes data to the underlying ResponseWriter but respects the byte limit.
+func (lw *wrapper) Write(p []byte) (int, error) {
+ lw.mutex.Lock()
+ defer lw.mutex.Unlock()
+
+ if lw.written >= lw.request.limit && !lw.released {
+ // Send signal that lock is acquired and wait till it was released by response.
+ lw.request.lock <- struct{}{}
+ <-lw.request.lock
+ lw.released = true
+ }
+
+ n, err := lw.ResponseWriter.Write(p)
+ lw.written += int64(n)
+ return n, err
+}
diff --git a/integration/autoupdate/tools/helper_unix_test.go b/integration/autoupdate/tools/helper_unix_test.go
new file mode 100644
index 0000000000000..61ba0766b90d4
--- /dev/null
+++ b/integration/autoupdate/tools/helper_unix_test.go
@@ -0,0 +1,37 @@
+//go:build !windows
+
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package tools_test
+
+import (
+ "errors"
+ "syscall"
+
+ "github.com/gravitational/trace"
+)
+
+// sendInterrupt sends a SIGINT to the process.
+func sendInterrupt(pid int) error {
+ err := syscall.Kill(pid, syscall.SIGINT)
+ if errors.Is(err, syscall.ESRCH) {
+ return trace.BadParameter("can't find the process: %v", pid)
+ }
+ return trace.Wrap(err)
+}
diff --git a/integration/autoupdate/tools/helper_windows_test.go b/integration/autoupdate/tools/helper_windows_test.go
new file mode 100644
index 0000000000000..b2ede9ade8c19
--- /dev/null
+++ b/integration/autoupdate/tools/helper_windows_test.go
@@ -0,0 +1,42 @@
+//go:build windows
+
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package tools_test
+
+import (
+ "syscall"
+
+ "github.com/gravitational/trace"
+ "golang.org/x/sys/windows"
+)
+
+var (
+ kernel = windows.NewLazyDLL("kernel32.dll")
+ ctrlEvent = kernel.NewProc("GenerateConsoleCtrlEvent")
+)
+
+// sendInterrupt sends a Ctrl-Break event to the process.
+func sendInterrupt(pid int) error {
+ r, _, err := ctrlEvent.Call(uintptr(syscall.CTRL_BREAK_EVENT), uintptr(pid))
+ if r == 0 {
+ return trace.Wrap(err)
+ }
+ return nil
+}
diff --git a/integration/autoupdate/tools/main_test.go b/integration/autoupdate/tools/main_test.go
new file mode 100644
index 0000000000000..a14a6dc9fc683
--- /dev/null
+++ b/integration/autoupdate/tools/main_test.go
@@ -0,0 +1,173 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package tools_test
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "testing"
+
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/integration/helpers"
+)
+
+const (
+ testBinaryName = "updater"
+ teleportToolsVersion = "TELEPORT_TOOLS_VERSION"
+)
+
+var (
+ // testVersions list of the pre-compiled binaries with encoded versions to check.
+ testVersions = []string{
+ "1.2.3",
+ "3.2.1",
+ }
+ limitedWriter = newLimitedResponseWriter()
+
+ toolsDir string
+ baseURL string
+)
+
+func TestMain(m *testing.M) {
+ ctx := context.Background()
+ tmp, err := os.MkdirTemp(os.TempDir(), testBinaryName)
+ if err != nil {
+ log.Fatalf("failed to create temporary directory: %v", err)
+ }
+
+ toolsDir, err = os.MkdirTemp(os.TempDir(), "tools")
+ if err != nil {
+ log.Fatalf("failed to create temporary directory: %v", err)
+ }
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ filePath := filepath.Join(tmp, r.URL.Path)
+ switch {
+ case strings.HasSuffix(r.URL.Path, ".sha256"):
+ serve256File(w, r, strings.TrimSuffix(filePath, ".sha256"))
+ default:
+ http.ServeFile(limitedWriter.Wrap(w), r, filePath)
+ }
+ }))
+ baseURL = server.URL
+ for _, version := range testVersions {
+ if err := buildAndArchiveApps(ctx, tmp, toolsDir, version, server.URL); err != nil {
+ log.Fatalf("failed to build testing app binary archive: %v", err)
+ }
+ }
+
+ // Run tests after binary is built.
+ code := m.Run()
+
+ server.Close()
+ if err := os.RemoveAll(tmp); err != nil {
+ log.Fatalf("failed to remove temporary directory: %v", err)
+ }
+ if err := os.RemoveAll(toolsDir); err != nil {
+ log.Fatalf("failed to remove tools directory: %v", err)
+ }
+
+ os.Exit(code)
+}
+
+// serve256File calculates sha256 checksum for requested file.
+func serve256File(w http.ResponseWriter, _ *http.Request, filePath string) {
+ log.Printf("Calculating and serving file checksum: %s\n", filePath)
+
+ w.Header().Set("Content-Disposition", "attachment; filename=\""+filepath.Base(filePath)+".sha256\"")
+ w.Header().Set("Content-Type", "plain/text")
+
+ file, err := os.Open(filePath)
+ if errors.Is(err, os.ErrNotExist) {
+ http.Error(w, "file not found", http.StatusNotFound)
+ return
+ }
+ if err != nil {
+ http.Error(w, "failed to open file", http.StatusInternalServerError)
+ return
+ }
+ defer file.Close()
+
+ hash := sha256.New()
+ if _, err := io.Copy(hash, file); err != nil {
+ http.Error(w, "failed to write to hash", http.StatusInternalServerError)
+ return
+ }
+ if _, err := hex.NewEncoder(w).Write(hash.Sum(nil)); err != nil {
+ http.Error(w, "failed to write checksum", http.StatusInternalServerError)
+ }
+}
+
+// buildAndArchiveApps compiles the updater integration and pack it depends on platform is used.
+func buildAndArchiveApps(ctx context.Context, path string, toolsDir string, version string, baseURL string) error {
+ versionPath := filepath.Join(path, version)
+ for _, app := range []string{"tsh", "tctl"} {
+ output := filepath.Join(versionPath, app)
+ switch runtime.GOOS {
+ case "windows":
+ output = filepath.Join(versionPath, app+".exe")
+ case "darwin":
+ output = filepath.Join(versionPath, app+".app", "Contents", "MacOS", app)
+ }
+ if err := buildBinary(output, toolsDir, version, baseURL); err != nil {
+ return trace.Wrap(err)
+ }
+ }
+ switch runtime.GOOS {
+ case "darwin":
+ archivePath := filepath.Join(path, fmt.Sprintf("teleport-%s.pkg", version))
+ return trace.Wrap(helpers.CompressDirToPkgFile(ctx, versionPath, archivePath, "com.example.pkgtest"))
+ case "windows":
+ archivePath := filepath.Join(path, fmt.Sprintf("teleport-v%s-windows-amd64-bin.zip", version))
+ return trace.Wrap(helpers.CompressDirToZipFile(ctx, versionPath, archivePath))
+ default:
+ archivePath := filepath.Join(path, fmt.Sprintf("teleport-v%s-linux-%s-bin.tar.gz", version, runtime.GOARCH))
+ return trace.Wrap(helpers.CompressDirToTarGzFile(ctx, versionPath, archivePath))
+ }
+}
+
+// buildBinary executes command to build binary with updater logic only for testing.
+func buildBinary(output string, toolsDir string, version string, baseURL string) error {
+ cmd := exec.Command(
+ "go", "build", "-o", output,
+ "-ldflags", strings.Join([]string{
+ fmt.Sprintf("-X 'main.toolsDir=%s'", toolsDir),
+ fmt.Sprintf("-X 'main.version=%s'", version),
+ fmt.Sprintf("-X 'main.baseURL=%s'", baseURL),
+ }, " "),
+ "./updater",
+ )
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ return trace.Wrap(cmd.Run())
+}
diff --git a/integration/autoupdate/tools/updater/main.go b/integration/autoupdate/tools/updater/main.go
new file mode 100644
index 0000000000000..e14c76e5d5aa8
--- /dev/null
+++ b/integration/autoupdate/tools/updater/main.go
@@ -0,0 +1,88 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package main
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log"
+ "os"
+ "os/signal"
+ "runtime"
+ "syscall"
+ "time"
+
+ "github.com/gravitational/teleport/api/constants"
+ "github.com/gravitational/teleport/lib/autoupdate/tools"
+)
+
+var (
+ version = "development"
+ baseURL = "http://localhost"
+ toolsDir = ""
+)
+
+func main() {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ ctx, _ = signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
+
+ updater := tools.NewUpdater(
+ clientTools(),
+ toolsDir,
+ version,
+ tools.WithBaseURL(baseURL),
+ )
+ toolsVersion, reExec := updater.CheckLocal()
+ if reExec {
+ // Download and update the version of client tools required by the cluster.
+ // This is required if the user passed in the TELEPORT_TOOLS_VERSION explicitly.
+ err := updater.UpdateWithLock(ctx, toolsVersion)
+ if errors.Is(err, context.Canceled) {
+ os.Exit(0)
+ return
+ }
+ if err != nil {
+ log.Fatalf("failed to download version (%v): %v\n", toolsVersion, err)
+ return
+ }
+
+ // Re-execute client tools with the correct version of client tools.
+ code, err := updater.Exec()
+ if err != nil {
+ log.Fatalf("Failed to re-exec client tool: %v\n", err)
+ } else {
+ os.Exit(code)
+ }
+ }
+ if len(os.Args) > 1 && os.Args[1] == "version" {
+ fmt.Printf("Teleport v%v git\n", version)
+ }
+}
+
+// clientTools list of the client tools needs to be updated.
+func clientTools() []string {
+ switch runtime.GOOS {
+ case constants.WindowsOS:
+ return []string{"tsh.exe", "tctl.exe"}
+ default:
+ return []string{"tsh", "tctl"}
+ }
+}
diff --git a/integration/autoupdate/tools/updater_test.go b/integration/autoupdate/tools/updater_test.go
new file mode 100644
index 0000000000000..96d5486462067
--- /dev/null
+++ b/integration/autoupdate/tools/updater_test.go
@@ -0,0 +1,231 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package tools_test
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "regexp"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/api/constants"
+ "github.com/gravitational/teleport/lib/autoupdate/tools"
+)
+
+var (
+ // pattern is template for response on version command for client tools {tsh, tctl}.
+ pattern = regexp.MustCompile(`(?m)Teleport v(.*) git`)
+)
+
+// TestUpdate verifies the basic update logic. We first download a lower version, then request
+// an update to a newer version, expecting it to re-execute with the updated version.
+func TestUpdate(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ // Fetch compiled test binary with updater logic and install to $TELEPORT_HOME.
+ updater := tools.NewUpdater(
+ clientTools(),
+ toolsDir,
+ testVersions[0],
+ tools.WithBaseURL(baseURL),
+ )
+ err := updater.Update(ctx, testVersions[0])
+ require.NoError(t, err)
+
+ // Verify that the installed version is equal to requested one.
+ cmd := exec.CommandContext(ctx, filepath.Join(toolsDir, "tsh"), "version")
+ out, err := cmd.Output()
+ require.NoError(t, err)
+
+ matches := pattern.FindStringSubmatch(string(out))
+ require.Len(t, matches, 2)
+ require.Equal(t, testVersions[0], matches[1])
+
+ // Execute version command again with setting the new version which must
+ // trigger re-execution of the same command after downloading requested version.
+ cmd = exec.CommandContext(ctx, filepath.Join(toolsDir, "tsh"), "version")
+ cmd.Env = append(
+ os.Environ(),
+ fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]),
+ )
+ out, err = cmd.Output()
+ require.NoError(t, err)
+
+ matches = pattern.FindStringSubmatch(string(out))
+ require.Len(t, matches, 2)
+ require.Equal(t, testVersions[1], matches[1])
+}
+
+// TestParallelUpdate launches multiple updater commands in parallel while defining a new version.
+// The first process should acquire a lock and block execution for the other processes. After the
+// first update is complete, other processes should acquire the lock one by one and re-execute
+// the command with the updated version without any new downloads.
+func TestParallelUpdate(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ // Initial fetch the updater binary un-archive and replace.
+ updater := tools.NewUpdater(
+ clientTools(),
+ toolsDir,
+ testVersions[0],
+ tools.WithBaseURL(baseURL),
+ )
+ err := updater.Update(ctx, testVersions[0])
+ require.NoError(t, err)
+
+ // By setting the limit request next test http serving file going blocked until unlock is sent.
+ lock := make(chan struct{})
+ limitedWriter.SetLimitRequest(limitRequest{
+ limit: 1024,
+ lock: lock,
+ })
+
+ outputs := make([]bytes.Buffer, 3)
+ errChan := make(chan error, 3)
+ for i := 0; i < len(outputs); i++ {
+ cmd := exec.Command(filepath.Join(toolsDir, "tsh"), "version")
+ cmd.Stdout = &outputs[i]
+ cmd.Stderr = &outputs[i]
+ cmd.Env = append(
+ os.Environ(),
+ fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]),
+ )
+ err = cmd.Start()
+ require.NoError(t, err, "failed to start updater")
+
+ go func(cmd *exec.Cmd) {
+ errChan <- cmd.Wait()
+ }(cmd)
+ }
+
+ select {
+ case err := <-errChan:
+ require.Fail(t, "we shouldn't receive any error", err)
+ case <-time.After(5 * time.Second):
+ require.Fail(t, "failed to wait till the download is started")
+ case <-lock:
+ // Wait for a short period to allow other processes to launch and attempt to acquire the lock.
+ time.Sleep(100 * time.Millisecond)
+ lock <- struct{}{}
+ }
+
+ // Wait till process finished with exit code 0, but we still should get progress
+ // bar in output content.
+ for i := 0; i < cap(outputs); i++ {
+ select {
+ case <-time.After(5 * time.Second):
+ require.Fail(t, "failed to wait till the process is finished")
+ case err := <-errChan:
+ require.NoError(t, err)
+ }
+ }
+
+ var progressCount int
+ for i := 0; i < cap(outputs); i++ {
+ matches := pattern.FindStringSubmatch(outputs[i].String())
+ require.Len(t, matches, 2)
+ assert.Equal(t, testVersions[1], matches[1])
+ if strings.Contains(outputs[i].String(), "Update progress:") {
+ progressCount++
+ }
+ }
+ assert.Equal(t, 1, progressCount, "we should have only one progress bar downloading new version")
+}
+
+// TestUpdateInterruptSignal verifies the interrupt signal send to the process must stop downloading.
+func TestUpdateInterruptSignal(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ // Initial fetch the updater binary un-archive and replace.
+ updater := tools.NewUpdater(
+ clientTools(),
+ toolsDir,
+ testVersions[0],
+ tools.WithBaseURL(baseURL),
+ )
+ err := updater.Update(ctx, testVersions[0])
+ require.NoError(t, err)
+
+ var output bytes.Buffer
+ cmd := exec.Command(filepath.Join(toolsDir, "tsh"), "version")
+ cmd.Stdout = &output
+ cmd.Stderr = &output
+ cmd.Env = append(
+ os.Environ(),
+ fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]),
+ )
+ err = cmd.Start()
+ require.NoError(t, err, "failed to start updater")
+ pid := cmd.Process.Pid
+
+ errChan := make(chan error)
+ go func() {
+ errChan <- cmd.Wait()
+ }()
+
+ // By setting the limit request next test http serving file going blocked until unlock is sent.
+ lock := make(chan struct{})
+ limitedWriter.SetLimitRequest(limitRequest{
+ limit: 1024,
+ lock: lock,
+ })
+
+ select {
+ case err := <-errChan:
+ require.Fail(t, "we shouldn't receive any error", err)
+ case <-time.After(5 * time.Second):
+ require.Fail(t, "failed to wait till the download is started")
+ case <-lock:
+ time.Sleep(100 * time.Millisecond)
+ require.NoError(t, sendInterrupt(pid))
+ lock <- struct{}{}
+ }
+
+ // Wait till process finished with exit code 0, but we still should get progress
+ // bar in output content.
+ select {
+ case <-time.After(5 * time.Second):
+ require.Fail(t, "failed to wait till the process interrupted")
+ case err := <-errChan:
+ require.NoError(t, err)
+ }
+ assert.Contains(t, output.String(), "Update progress:")
+}
+
+func clientTools() []string {
+ switch runtime.GOOS {
+ case constants.WindowsOS:
+ return []string{"tsh.exe", "tctl.exe"}
+ default:
+ return []string{"tsh", "tctl"}
+ }
+}
diff --git a/lib/autoupdate/tools/progress.go b/lib/autoupdate/tools/progress.go
new file mode 100644
index 0000000000000..95395003730ec
--- /dev/null
+++ b/lib/autoupdate/tools/progress.go
@@ -0,0 +1,43 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package tools
+
+import (
+ "fmt"
+ "strings"
+)
+
+type progressWriter struct {
+ n int64
+ limit int64
+}
+
+func (w *progressWriter) Write(p []byte) (int, error) {
+ w.n = w.n + int64(len(p))
+
+ n := int((w.n*100)/w.limit) / 10
+ bricks := strings.Repeat("▒", n) + strings.Repeat(" ", 10-n)
+ fmt.Print("\rUpdate progress: [" + bricks + "] (Ctrl-C to cancel update)")
+
+ if w.n == w.limit {
+ fmt.Print("\n")
+ }
+
+ return len(p), nil
+}
diff --git a/lib/autoupdate/tools/updater.go b/lib/autoupdate/tools/updater.go
new file mode 100644
index 0000000000000..96991044ccc31
--- /dev/null
+++ b/lib/autoupdate/tools/updater.go
@@ -0,0 +1,400 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package tools
+
+import (
+ "bytes"
+ "context"
+ "crypto/sha256"
+ "crypto/x509"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "regexp"
+ "runtime"
+ "syscall"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/api/client/webclient"
+ "github.com/gravitational/teleport/api/constants"
+ "github.com/gravitational/teleport/api/types/autoupdate"
+ "github.com/gravitational/teleport/lib/utils"
+ "github.com/gravitational/teleport/lib/utils/packaging"
+)
+
+const (
+ // teleportToolsVersionEnv is environment name for requesting specific version for update.
+ teleportToolsVersionEnv = "TELEPORT_TOOLS_VERSION"
+ // baseURL is CDN URL for downloading official Teleport packages.
+ baseURL = "https://cdn.teleport.dev"
+ // reservedFreeDisk is the predefined amount of free disk space (in bytes) required
+ // to remain available after downloading archives.
+ reservedFreeDisk = 10 * 1024 * 1024 // 10 Mb
+ // lockFileName is file used for locking update process in parallel.
+ lockFileName = ".lock"
+ // updatePackageSuffix is directory suffix used for package extraction in tools directory.
+ updatePackageSuffix = "-update-pkg"
+)
+
+var (
+ // // pattern is template for response on version command for client tools {tsh, tctl}.
+ pattern = regexp.MustCompile(`(?m)Teleport v(.*) git`)
+)
+
+// Option applies an option value for the Updater.
+type Option func(u *Updater)
+
+// WithBaseURL defines custom base url for the updater.
+func WithBaseURL(baseURL string) Option {
+ return func(u *Updater) {
+ u.baseURL = baseURL
+ }
+}
+
+// WithClient defines custom http client for the Updater.
+func WithClient(client *http.Client) Option {
+ return func(u *Updater) {
+ u.client = client
+ }
+}
+
+// Updater is updater implementation for the client tools auto updates.
+type Updater struct {
+ toolsDir string
+ localVersion string
+ tools []string
+
+ baseURL string
+ client *http.Client
+}
+
+// NewUpdater initializes the updater for client tools auto updates. We need to specify the list
+// of tools (e.g., `tsh`, `tctl`) that should be updated, the tools directory path where we
+// download, extract package archives with the new version, and replace symlinks (e.g., `$TELEPORT_HOME/bin`).
+// The base URL of the CDN with Teleport packages and the `http.Client` can be customized via options.
+func NewUpdater(tools []string, toolsDir string, localVersion string, options ...Option) *Updater {
+ updater := &Updater{
+ tools: tools,
+ toolsDir: toolsDir,
+ localVersion: localVersion,
+ baseURL: baseURL,
+ client: http.DefaultClient,
+ }
+ for _, option := range options {
+ option(updater)
+ }
+
+ return updater
+}
+
+// CheckLocal is run at client tool startup and will only perform local checks.
+// Returns the version needs to be updated and re-executed, by re-execution flag we
+// understand that update and re-execute is required.
+func (u *Updater) CheckLocal() (version string, reExec bool) {
+ // Check if the user has requested a specific version of client tools.
+ requestedVersion := os.Getenv(teleportToolsVersionEnv)
+ switch requestedVersion {
+ // The user has turned off any form of automatic updates.
+ case "off":
+ return "", false
+ // Requested version already the same as client version.
+ case u.localVersion:
+ return u.localVersion, false
+ }
+
+ // If a version of client tools has already been downloaded to
+ // tools directory, return that.
+ toolsVersion, err := checkToolVersion(u.toolsDir)
+ if err != nil {
+ return "", false
+ }
+ // The user has requested a specific version of client tools.
+ if requestedVersion != "" && requestedVersion != toolsVersion {
+ return requestedVersion, true
+ }
+
+ return toolsVersion, false
+}
+
+// CheckRemote first checks the version set by the environment variable. If not set or disabled,
+// it checks against the Proxy Service to determine if client tools need updating by requesting
+// the `webapi/find` handler, which stores information about the required client tools version to
+// operate with this cluster. It returns the semantic version that needs updating and whether
+// re-execution is necessary, by re-execution flag we understand that update and re-execute is required.
+func (u *Updater) CheckRemote(ctx context.Context, proxyAddr string) (version string, reExec bool, err error) {
+ // Check if the user has requested a specific version of client tools.
+ requestedVersion := os.Getenv(teleportToolsVersionEnv)
+ switch requestedVersion {
+ // The user has turned off any form of automatic updates.
+ case "off":
+ return "", false, nil
+ // Requested version already the same as client version.
+ case u.localVersion:
+ return u.localVersion, false, nil
+ }
+
+ certPool, err := x509.SystemCertPool()
+ if err != nil {
+ return "", false, trace.Wrap(err)
+ }
+ resp, err := webclient.Find(&webclient.Config{
+ Context: ctx,
+ ProxyAddr: proxyAddr,
+ Pool: certPool,
+ Timeout: 30 * time.Second,
+ })
+ if err != nil {
+ return "", false, trace.Wrap(err)
+ }
+
+ // If a version of client tools has already been downloaded to
+ // tools directory, return that.
+ toolsVersion, err := checkToolVersion(u.toolsDir)
+ if err != nil {
+ return "", false, trace.Wrap(err)
+ }
+
+ switch {
+ case requestedVersion != "" && requestedVersion != toolsVersion:
+ return requestedVersion, true, nil
+ case resp.AutoUpdate.ToolsMode != autoupdate.ToolsUpdateModeEnabled || resp.AutoUpdate.ToolsVersion == "":
+ return "", false, nil
+ case u.localVersion == resp.AutoUpdate.ToolsVersion:
+ return resp.AutoUpdate.ToolsVersion, false, nil
+ case resp.AutoUpdate.ToolsVersion != toolsVersion:
+ return resp.AutoUpdate.ToolsVersion, true, nil
+ }
+
+ return toolsVersion, false, nil
+}
+
+// UpdateWithLock acquires filesystem lock, downloads requested version package,
+// unarchive and replace existing one.
+func (u *Updater) UpdateWithLock(ctx context.Context, toolsVersion string) (err error) {
+ // Create tools directory if it does not exist.
+ if err := os.MkdirAll(u.toolsDir, 0o755); err != nil {
+ return trace.Wrap(err)
+ }
+ // Lock concurrent client tools execution util requested version is updated.
+ unlock, err := utils.FSWriteLock(filepath.Join(u.toolsDir, lockFileName))
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer func() {
+ err = trace.NewAggregate(err, unlock())
+ }()
+
+ // If the version of the running binary or the version downloaded to
+ // tools directory is the same as the requested version of client tools,
+ // nothing to be done, exit early.
+ teleportVersion, err := checkToolVersion(u.toolsDir)
+ if err != nil && !trace.IsNotFound(err) {
+ return trace.Wrap(err)
+
+ }
+ if toolsVersion == u.localVersion || toolsVersion == teleportVersion {
+ return nil
+ }
+
+ // Download and update client tools in tools directory.
+ if err := u.Update(ctx, toolsVersion); err != nil {
+ return trace.Wrap(err)
+ }
+
+ return
+}
+
+// Update downloads requested version and replace it with existing one and cleanups the previous downloads
+// with defined updater directory suffix.
+func (u *Updater) Update(ctx context.Context, toolsVersion string) error {
+ // Get platform specific download URLs.
+ packages, err := teleportPackageURLs(u.baseURL, toolsVersion)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ for _, pkg := range packages {
+ if err := u.update(ctx, pkg); err != nil {
+ return trace.Wrap(err)
+ }
+ }
+
+ return nil
+}
+
+// update downloads the archive and validate against the hash. Download to a
+// temporary path within tools directory.
+func (u *Updater) update(ctx context.Context, pkg packageURL) error {
+ hash, err := u.downloadHash(ctx, pkg.Hash)
+ if pkg.Optional && trace.IsNotFound(err) {
+ return nil
+ }
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ f, err := os.CreateTemp(u.toolsDir, "tmp-")
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer func() {
+ _ = f.Close()
+ if err := os.Remove(f.Name()); err != nil {
+ slog.WarnContext(ctx, "failed to remove temporary archive file", "error", err)
+ }
+ }()
+
+ archiveHash, err := u.downloadArchive(ctx, pkg.Archive, f)
+ if pkg.Optional && trace.IsNotFound(err) {
+ return nil
+ }
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if !bytes.Equal(archiveHash, hash) {
+ return trace.BadParameter("hash of archive does not match downloaded archive")
+ }
+
+ pkgName := fmt.Sprint(uuid.New().String(), updatePackageSuffix)
+ extractDir := filepath.Join(u.toolsDir, pkgName)
+ if runtime.GOOS != constants.DarwinOS {
+ if err := os.Mkdir(extractDir, 0o755); err != nil {
+ return trace.Wrap(err)
+ }
+ }
+
+ // Perform atomic replace so concurrent exec do not fail.
+ if err := packaging.ReplaceToolsBinaries(u.toolsDir, f.Name(), extractDir, u.tools); err != nil {
+ return trace.Wrap(err)
+ }
+ // Cleanup the tools directory with previously downloaded and un-archived versions.
+ if err := packaging.RemoveWithSuffix(u.toolsDir, updatePackageSuffix, pkgName); err != nil {
+ return trace.Wrap(err)
+ }
+
+ return nil
+}
+
+// Exec re-executes tool command with same arguments and environ variables.
+func (u *Updater) Exec() (int, error) {
+ path, err := toolName(u.toolsDir)
+ if err != nil {
+ return 0, trace.Wrap(err)
+ }
+ // To prevent re-execution loop we have to disable update logic for re-execution.
+ env := append(os.Environ(), teleportToolsVersionEnv+"=off")
+
+ if runtime.GOOS == constants.WindowsOS {
+ cmd := exec.Command(path, os.Args[1:]...)
+ cmd.Env = env
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ if err := cmd.Run(); err != nil {
+ return 0, trace.Wrap(err)
+ }
+
+ return cmd.ProcessState.ExitCode(), nil
+ }
+
+ if err := syscall.Exec(path, append([]string{path}, os.Args[1:]...), env); err != nil {
+ return 0, trace.Wrap(err)
+ }
+
+ return 0, nil
+}
+
+// downloadHash downloads the hash file `.sha256` for package checksum validation and return the hash sum.
+func (u *Updater) downloadHash(ctx context.Context, url string) ([]byte, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ resp, err := u.client.Do(req)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode == http.StatusNotFound {
+ return nil, trace.NotFound("hash file is not found: %v", resp.StatusCode)
+ }
+ if resp.StatusCode != http.StatusOK {
+ return nil, trace.BadParameter("bad status when downloading archive hash: %v", resp.StatusCode)
+ }
+
+ var buf bytes.Buffer
+ _, err = io.CopyN(&buf, resp.Body, sha256.Size*2) // SHA bytes to hex
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ hexBytes, err := hex.DecodeString(buf.String())
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return hexBytes, nil
+}
+
+// downloadArchive downloads the archive package by `url` and writes content to the writer interface,
+// return calculated sha256 hash sum of the content.
+func (u *Updater) downloadArchive(ctx context.Context, url string, f io.Writer) ([]byte, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ resp, err := u.client.Do(req)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode == http.StatusNotFound {
+ return nil, trace.NotFound("archive file is not found: %v", resp.StatusCode)
+ }
+ if resp.StatusCode != http.StatusOK {
+ return nil, trace.BadParameter("bad status when downloading archive: %v", resp.StatusCode)
+ }
+
+ if resp.ContentLength != -1 {
+ if err := checkFreeSpace(u.toolsDir, uint64(resp.ContentLength)); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ }
+
+ h := sha256.New()
+ pw := &progressWriter{n: 0, limit: resp.ContentLength}
+ body := io.TeeReader(io.TeeReader(resp.Body, h), pw)
+
+ // It is a little inefficient to download the file to disk and then re-load
+ // it into memory to unarchive later, but this is safer as it allows client
+ // tools to validate the hash before trying to operate on the archive.
+ _, err = io.CopyN(f, body, resp.ContentLength)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return h.Sum(nil), nil
+}
diff --git a/lib/autoupdate/tools/utils.go b/lib/autoupdate/tools/utils.go
new file mode 100644
index 0000000000000..d552b31abefe4
--- /dev/null
+++ b/lib/autoupdate/tools/utils.go
@@ -0,0 +1,175 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package tools
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "errors"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "time"
+
+ "github.com/coreos/go-semver/semver"
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/modules"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+// Dir returns the path to client tools in $TELEPORT_HOME/bin.
+func Dir() (string, error) {
+ home := os.Getenv(types.HomeEnvVar)
+ if home == "" {
+ var err error
+ home, err = os.UserHomeDir()
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ }
+
+ return filepath.Join(home, ".tsh", "bin"), nil
+}
+
+func checkToolVersion(toolsDir string) (string, error) {
+ // Find the path to the current executable.
+ path, err := toolName(toolsDir)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
+ return "", nil
+ } else if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ // Set a timeout to not let "{tsh, tctl} version" block forever. Allow up
+ // to 10 seconds because sometimes MDM tools like Jamf cause a lot of
+ // latency in launching binaries.
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ // Execute "{tsh, tctl} version" and pass in TELEPORT_TOOLS_VERSION=off to
+ // turn off all automatic updates code paths to prevent any recursion.
+ command := exec.CommandContext(ctx, path, "version")
+ command.Env = []string{teleportToolsVersionEnv + "=off"}
+ output, err := command.Output()
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ // The output for "{tsh, tctl} version" can be multiple lines. Find the
+ // actual version line and extract the version.
+ scanner := bufio.NewScanner(bytes.NewReader(output))
+ for scanner.Scan() {
+ line := scanner.Text()
+
+ if !strings.HasPrefix(line, "Teleport") {
+ continue
+ }
+
+ matches := pattern.FindStringSubmatch(line)
+ if len(matches) != 2 {
+ return "", trace.BadParameter("invalid version line: %v", line)
+ }
+ version, err := semver.NewVersion(matches[1])
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ return version.String(), nil
+ }
+
+ return "", trace.BadParameter("unable to determine version")
+}
+
+// packageURL defines URLs to the archive and their archive sha256 hash file, and marks
+// if this package is optional, for such case download needs to be ignored if package
+// not found in CDN.
+type packageURL struct {
+ Archive string
+ Hash string
+ Optional bool
+}
+
+// teleportPackageURLs returns the URL for the Teleport archive to download. The format is:
+// https://cdn.teleport.dev/teleport-{, ent-}v15.3.0-{linux, darwin, windows}-{amd64,arm64,arm,386}-{fips-}bin.tar.gz
+func teleportPackageURLs(baseURL, toolsVersion string) ([]packageURL, error) {
+ switch runtime.GOOS {
+ case "darwin":
+ tsh := baseURL + "/tsh-" + toolsVersion + ".pkg"
+ teleport := baseURL + "/teleport-" + toolsVersion + ".pkg"
+ return []packageURL{
+ {Archive: teleport, Hash: teleport + ".sha256"},
+ {Archive: tsh, Hash: tsh + ".sha256", Optional: true},
+ }, nil
+ case "windows":
+ archive := baseURL + "/teleport-v" + toolsVersion + "-windows-amd64-bin.zip"
+ return []packageURL{
+ {Archive: archive, Hash: archive + ".sha256"},
+ }, nil
+ case "linux":
+ m := modules.GetModules()
+ var b strings.Builder
+ b.WriteString(baseURL + "/teleport-")
+ if m.IsEnterpriseBuild() || m.IsBoringBinary() {
+ b.WriteString("ent-")
+ }
+ b.WriteString("v" + toolsVersion + "-" + runtime.GOOS + "-" + runtime.GOARCH + "-")
+ if m.IsBoringBinary() {
+ b.WriteString("fips-")
+ }
+ b.WriteString("bin.tar.gz")
+ archive := b.String()
+ return []packageURL{
+ {Archive: archive, Hash: archive + ".sha256"},
+ }, nil
+ default:
+ return nil, trace.BadParameter("unsupported runtime: %v", runtime.GOOS)
+ }
+}
+
+// toolName returns the path to {tsh, tctl} for the executable that started
+// the current process.
+func toolName(toolsDir string) (string, error) {
+ executablePath, err := os.Executable()
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+
+ return filepath.Join(toolsDir, filepath.Base(executablePath)), nil
+}
+
+// checkFreeSpace verifies that we have enough requested space at specific directory.
+func checkFreeSpace(path string, requested uint64) error {
+ free, err := utils.FreeDiskWithReserve(path, reservedFreeDisk)
+ if err != nil {
+ return trace.Errorf("failed to calculate free disk in %q: %v", path, err)
+ }
+ // Bail if there's not enough free disk space at the target.
+ if requested > free {
+ return trace.Errorf("%q needs %d additional bytes of disk space", path, requested-free)
+ }
+
+ return nil
+}
diff --git a/lib/utils/disk.go b/lib/utils/disk.go
index b3553c873aa3e..782c598d26229 100644
--- a/lib/utils/disk.go
+++ b/lib/utils/disk.go
@@ -53,11 +53,8 @@ func FreeDiskWithReserve(dir string, reservedFreeDisk uint64) (uint64, error) {
if err != nil {
return 0, trace.Wrap(err)
}
- //nolint:staticcheck // SA4003. False positive on macOS.
- if stat.Bsize < 0 {
- return 0, trace.Errorf("invalid size")
- }
- avail := stat.Bavail * uint64(stat.Bsize)
+ //nolint:unconvert // The cast is only necessary for linux platform.
+ avail := uint64(stat.Bavail) * uint64(stat.Bsize)
if reservedFreeDisk > avail {
return 0, trace.Errorf("no free space left")
}
diff --git a/lib/utils/packaging/unarchive_unix.go b/lib/utils/packaging/unarchive_unix.go
index 3be7d0c473ef9..ea51afdbbc7f0 100644
--- a/lib/utils/packaging/unarchive_unix.go
+++ b/lib/utils/packaging/unarchive_unix.go
@@ -33,6 +33,8 @@ import (
"github.com/google/renameio/v2"
"github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/api/constants"
)
// ReplaceToolsBinaries extracts executables specified by execNames from archivePath into
@@ -43,7 +45,7 @@ import (
// For other POSIX, archivePath must be a gzipped tarball.
func ReplaceToolsBinaries(toolsDir string, archivePath string, extractDir string, execNames []string) error {
switch runtime.GOOS {
- case "darwin":
+ case constants.DarwinOS:
return replacePkg(toolsDir, archivePath, extractDir, execNames)
default:
return replaceTarGz(toolsDir, archivePath, extractDir, execNames)