diff --git a/lib/client/api.go b/lib/client/api.go index 3bde83684ff4f..b6db979d53719 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2396,7 +2396,7 @@ func PlayFile(ctx context.Context, filename, sid string, speed float64, skipIdle } // SFTP securely copies files between Nodes or SSH servers using SFTP -func (tc *TeleportClient) SFTP(ctx context.Context, args []string, port int, opts sftp.Options, quiet bool) (err error) { +func (tc *TeleportClient) SFTP(ctx context.Context, cfg *sftp.Config) (err error) { ctx, span := tc.Tracer.Start( ctx, "teleportClient/SFTP", @@ -2404,104 +2404,32 @@ func (tc *TeleportClient) SFTP(ctx context.Context, args []string, port int, opt ) defer span.End() - if len(args) < 2 { - return trace.Errorf("local and remote destinations are required") - } - first := args[0] - last := args[len(args)-1] - - // local copy? - if !isRemoteDest(first) && !isRemoteDest(last) { - return trace.BadParameter("no remote destination specified") - } - - var config *sftpConfig - if isRemoteDest(last) { - config, err = tc.uploadConfig(args, port, opts) - if err != nil { - return trace.Wrap(err) - } - } else { - config, err = tc.downloadConfig(args, port, opts) - if err != nil { - return trace.Wrap(err) - } - } - if config.hostLogin == "" { - config.hostLogin = tc.Config.HostLogin - } - - if !quiet { - config.cfg.ProgressStream = func(fileInfo os.FileInfo) io.ReadWriter { - return sftp.NewProgressBar(fileInfo.Size(), fileInfo.Name(), tc.Stdout) - } - } - - return trace.Wrap(tc.TransferFiles(ctx, config.hostLogin, config.addr, config.cfg)) -} - -type sftpConfig struct { - cfg *sftp.Config - addr string - hostLogin string -} - -func (tc *TeleportClient) uploadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) { - // args are guaranteed to have len(args) > 1 - srcPaths := args[:len(args)-1] - // copy everything except the last arg (the destination) - dstPath := args[len(args)-1] - - dst, addr, err := getSFTPDestination(dstPath, port) - if err != nil { - return nil, trace.Wrap(err) - } - cfg, err := sftp.CreateUploadConfig(srcPaths, dst.Path, opts) + clt, err := tc.ConnectToCluster(ctx) if err != nil { - return nil, trace.Wrap(err) - } - - return &sftpConfig{ - cfg: cfg, - addr: addr, - hostLogin: dst.Login, - }, nil -} - -func (tc *TeleportClient) downloadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) { - if len(args) > 2 { - return nil, trace.BadParameter("only one source file is supported when downloading files") + return trace.Wrap(err) } + defer clt.Close() - // args are guaranteed to have len(args) > 1 - src, addr, err := getSFTPDestination(args[0], port) + // Expand any proxy templates and attempt host resolution. + resolvedNodes, err := tc.GetTargetNodes(ctx, clt.AuthClient, SSHOptions{}) if err != nil { - return nil, trace.Wrap(err) - } - cfg, err := sftp.CreateDownloadConfig(src.Path, args[1], opts) - if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } - return &sftpConfig{ - cfg: cfg, - addr: addr, - hostLogin: src.Login, - }, nil -} - -func getSFTPDestination(target string, port int) (dest *sftp.Destination, addr string, err error) { - dest, err = sftp.ParseDestination(target) - if err != nil { - return nil, "", trace.Wrap(err) + switch len(resolvedNodes) { + case 0: + return trace.BadParameter("no matching target host found") + case 1: + default: + return trace.BadParameter("multiple matching target hosts found") } - addr = net.JoinHostPort(dest.Host.Host(), strconv.Itoa(port)) - return dest, addr, nil + + return trace.Wrap(tc.TransferFiles(ctx, clt, tc.HostLogin, resolvedNodes[0].Addr, cfg)) } // TransferFiles copies files between the current machine and the // specified Node using the supplied config -func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr string, cfg *sftp.Config) error { +func (tc *TeleportClient) TransferFiles(ctx context.Context, clt *ClusterClient, hostLogin, nodeAddr string, cfg *sftp.Config) error { ctx, span := tc.Tracer.Start( ctx, "teleportClient/TransferFiles", @@ -2519,13 +2447,8 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr if !tc.Config.ProxySpecified() { return trace.BadParameter("proxy server is not specified") } - clt, err := tc.ConnectToCluster(ctx) - if err != nil { - return trace.Wrap(err) - } - defer clt.Close() - client, err := tc.ConnectToNode( + nodeClient, err := tc.ConnectToNode( ctx, clt, NodeDetails{ @@ -2539,11 +2462,7 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr return trace.Wrap(err) } - return trace.Wrap(client.TransferFiles(ctx, cfg)) -} - -func isRemoteDest(name string) bool { - return strings.ContainsRune(name, ':') + return trace.Wrap(nodeClient.TransferFiles(ctx, cfg)) } // ListNodesWithFilters returns all nodes that match the filters in the current cluster diff --git a/lib/teleterm/clusters/cluster_file_transfer.go b/lib/teleterm/clusters/cluster_file_transfer.go index c55525123ddae..5047476c1f683 100644 --- a/lib/teleterm/clusters/cluster_file_transfer.go +++ b/lib/teleterm/clusters/cluster_file_transfer.go @@ -29,13 +29,14 @@ import ( "github.com/gravitational/trace" api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" + "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/sshutils/sftp" "github.com/gravitational/teleport/lib/teleterm/api/uri" ) type FileTransferProgressSender = func(progress *api.FileTransferProgress) error -func (c *Cluster) TransferFile(ctx context.Context, request *api.FileTransferRequest, sendProgress FileTransferProgressSender) error { +func (c *Cluster) TransferFile(ctx context.Context, clt *client.ClusterClient, request *api.FileTransferRequest, sendProgress FileTransferProgressSender) error { config, err := getSftpConfig(request) if err != nil { return trace.Wrap(err) @@ -54,7 +55,7 @@ func (c *Cluster) TransferFile(ctx context.Context, request *api.FileTransferReq } err = AddMetadataToRetryableError(ctx, func() error { - err := c.clusterClient.TransferFiles(ctx, request.GetLogin(), serverUUID+":0", config) + err := c.clusterClient.TransferFiles(ctx, clt, request.GetLogin(), serverUUID+":0", config) if errors.As(err, new(*sftp.NonRecursiveDirectoryTransferError)) { return trace.Errorf("transferring directories through Teleport Connect is not supported at the moment, please use tsh scp -r") } diff --git a/lib/teleterm/daemon/daemon.go b/lib/teleterm/daemon/daemon.go index b324e061e07a7..310dbed2ff713 100644 --- a/lib/teleterm/daemon/daemon.go +++ b/lib/teleterm/daemon/daemon.go @@ -942,7 +942,12 @@ func (s *Service) TransferFile(ctx context.Context, request *api.FileTransferReq return trace.Wrap(err) } - return cluster.TransferFile(ctx, request, sendProgress) + clt, err := s.GetCachedClient(ctx, cluster.URI) + if err != nil { + return trace.Wrap(err) + } + + return cluster.TransferFile(ctx, clt, request, sendProgress) } // CreateConnectMyComputerRole creates a role which allows access to nodes with the label diff --git a/lib/web/files.go b/lib/web/files.go index e43c341d70022..53248258dd034 100644 --- a/lib/web/files.go +++ b/lib/web/files.go @@ -147,7 +147,13 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou ctx = context.WithValue(ctx, sftp.ModeratedSessionID, req.moderatedSessionID) } - err = tc.TransferFiles(ctx, req.login, req.serverID+":0", cfg) + cl, err := tc.ConnectToCluster(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + defer cl.Close() + + err = tc.TransferFiles(ctx, cl, req.login, req.serverID+":0", cfg) if err != nil { if errors.As(err, new(*sftp.NonRecursiveDirectoryTransferError)) { return nil, trace.Errorf("transferring directories through the Web UI is not supported at the moment, please use tsh scp -r") diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 4875f836b3a4f..f290f4a5bf601 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -3953,8 +3953,108 @@ func onJoin(cf *CLIConf) error { return nil } +type sftpConfig struct { + cfg *sftp.Config + addr string + hostLogin string +} + +func getSFTPDestination(target string, port int) (dest *sftp.Destination, addr string, err error) { + dest, err = sftp.ParseDestination(target) + if err != nil { + return nil, "", trace.Wrap(err) + } + addr = net.JoinHostPort(dest.Host.Host(), strconv.Itoa(port)) + return dest, addr, nil +} + +func uploadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) { + // args are guaranteed to have len(args) > 1 + srcPaths := args[:len(args)-1] + // copy everything except the last arg (the destination) + dstPath := args[len(args)-1] + + dst, addr, err := getSFTPDestination(dstPath, port) + if err != nil { + return nil, trace.Wrap(err) + } + cfg, err := sftp.CreateUploadConfig(srcPaths, dst.Path, opts) + if err != nil { + return nil, trace.Wrap(err) + } + + return &sftpConfig{ + cfg: cfg, + addr: addr, + hostLogin: dst.Login, + }, nil +} + +func downloadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) { + if len(args) > 2 { + return nil, trace.BadParameter("only one source file is supported when downloading files") + } + + // args are guaranteed to have len(args) > 1 + src, addr, err := getSFTPDestination(args[0], port) + if err != nil { + return nil, trace.Wrap(err) + } + cfg, err := sftp.CreateDownloadConfig(src.Path, args[1], opts) + if err != nil { + return nil, trace.Wrap(err) + } + + return &sftpConfig{ + cfg: cfg, + addr: addr, + hostLogin: src.Login, + }, nil +} + // onSCP executes 'tsh scp' command func onSCP(cf *CLIConf) error { + if len(cf.CopySpec) != 2 { + return trace.Errorf("local and remote destinations are required") + } + + opts := sftp.Options{ + Recursive: cf.RecursiveCopy, + PreserveAttrs: cf.PreserveAttrs, + } + + first := cf.CopySpec[0] + last := cf.CopySpec[1] + + localCopy := strings.ContainsRune(first, ':') + remoteCopy := strings.ContainsRune(last, ':') + + if !localCopy && !remoteCopy { + return trace.BadParameter("no remote destination specified") + } + + var ( + config *sftpConfig + err error + ) + if remoteCopy { + config, err = uploadConfig(cf.CopySpec, int(cf.NodePort), opts) + if err != nil { + return trace.Wrap(err) + } + } else { + config, err = downloadConfig(cf.CopySpec, int(cf.NodePort), opts) + if err != nil { + return trace.Wrap(err) + } + } + + if !cf.Quiet { + config.cfg.ProgressStream = func(fileInfo os.FileInfo) io.ReadWriter { + return sftp.NewProgressBar(fileInfo.Size(), fileInfo.Name(), cf.Stdout()) + } + } + tc, err := makeClient(cf) if err != nil { return trace.Wrap(err) @@ -3967,12 +4067,8 @@ func onSCP(cf *CLIConf) error { cf.Context = ctx defer cancel() - opts := sftp.Options{ - Recursive: cf.RecursiveCopy, - PreserveAttrs: cf.PreserveAttrs, - } err = client.RetryWithRelogin(cf.Context, tc, func() error { - return tc.SFTP(cf.Context, cf.CopySpec, int(cf.NodePort), opts, cf.Quiet) + return tc.SFTP(cf.Context, config.cfg) }) // don't print context canceled errors to the user if err == nil || errors.Is(err, context.Canceled) { @@ -4075,14 +4171,20 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err } else if cf.CopySpec != nil { for _, location := range cf.CopySpec { // Extract username and host from "username@host:file/path" - parts := strings.Split(location, ":") - parts = strings.Split(parts[0], "@") - partsLength := len(parts) - if partsLength > 1 { - hostLogin = strings.Join(parts[:partsLength-1], "@") - hostUser = parts[partsLength-1] - break + userHost, _, found := strings.Cut(location, ":") + if !found { + continue } + + login, hostname, found := strings.Cut(userHost, "@") + if found { + hostLogin = login + hostUser = hostname + } else { + hostUser = userHost + } + break + } }