Skip to content

Commit

Permalink
Add host resolution support to tsh scp
Browse files Browse the repository at this point in the history
Extends tsh scp functionality to honor any defined proxy templates
when resolving the remote host.

Closes #45465
  • Loading branch information
rosstimothy committed Nov 5, 2024
1 parent c9a7fce commit a9fbc43
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 115 deletions.
117 changes: 18 additions & 99 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2396,112 +2396,40 @@ func PlayFile(ctx context.Context, filename, sid string, speed float64, skipIdle
}

// SFTP securely copies files between Nodes or SSH servers using SFTP
func (tc *TeleportClient) SFTP(ctx context.Context, args []string, port int, opts sftp.Options, quiet bool) (err error) {
func (tc *TeleportClient) SFTP(ctx context.Context, cfg *sftp.Config) (err error) {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/SFTP",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
defer span.End()

if len(args) < 2 {
return trace.Errorf("local and remote destinations are required")
}
first := args[0]
last := args[len(args)-1]

// local copy?
if !isRemoteDest(first) && !isRemoteDest(last) {
return trace.BadParameter("no remote destination specified")
}

var config *sftpConfig
if isRemoteDest(last) {
config, err = tc.uploadConfig(args, port, opts)
if err != nil {
return trace.Wrap(err)
}
} else {
config, err = tc.downloadConfig(args, port, opts)
if err != nil {
return trace.Wrap(err)
}
}
if config.hostLogin == "" {
config.hostLogin = tc.Config.HostLogin
}

if !quiet {
config.cfg.ProgressStream = func(fileInfo os.FileInfo) io.ReadWriter {
return sftp.NewProgressBar(fileInfo.Size(), fileInfo.Name(), tc.Stdout)
}
}

return trace.Wrap(tc.TransferFiles(ctx, config.hostLogin, config.addr, config.cfg))
}

type sftpConfig struct {
cfg *sftp.Config
addr string
hostLogin string
}

func (tc *TeleportClient) uploadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) {
// args are guaranteed to have len(args) > 1
srcPaths := args[:len(args)-1]
// copy everything except the last arg (the destination)
dstPath := args[len(args)-1]

dst, addr, err := getSFTPDestination(dstPath, port)
if err != nil {
return nil, trace.Wrap(err)
}
cfg, err := sftp.CreateUploadConfig(srcPaths, dst.Path, opts)
clt, err := tc.ConnectToCluster(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

return &sftpConfig{
cfg: cfg,
addr: addr,
hostLogin: dst.Login,
}, nil
}

func (tc *TeleportClient) downloadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) {
if len(args) > 2 {
return nil, trace.BadParameter("only one source file is supported when downloading files")
return trace.Wrap(err)
}
defer clt.Close()

// args are guaranteed to have len(args) > 1
src, addr, err := getSFTPDestination(args[0], port)
// Expand any proxy templates and attempt host resolution.
resolvedNodes, err := tc.GetTargetNodes(ctx, clt.AuthClient, SSHOptions{})
if err != nil {
return nil, trace.Wrap(err)
}
cfg, err := sftp.CreateDownloadConfig(src.Path, args[1], opts)
if err != nil {
return nil, trace.Wrap(err)
return trace.Wrap(err)
}

return &sftpConfig{
cfg: cfg,
addr: addr,
hostLogin: src.Login,
}, nil
}

func getSFTPDestination(target string, port int) (dest *sftp.Destination, addr string, err error) {
dest, err = sftp.ParseDestination(target)
if err != nil {
return nil, "", trace.Wrap(err)
switch len(resolvedNodes) {
case 0:
return trace.BadParameter("no matching target host found")
case 1:
default:
return trace.BadParameter("multiple matching target hosts found")
}
addr = net.JoinHostPort(dest.Host.Host(), strconv.Itoa(port))
return dest, addr, nil

return trace.Wrap(tc.TransferFiles(ctx, clt, tc.HostLogin, resolvedNodes[0].Addr, cfg))
}

// TransferFiles copies files between the current machine and the
// specified Node using the supplied config
func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr string, cfg *sftp.Config) error {
func (tc *TeleportClient) TransferFiles(ctx context.Context, clt *ClusterClient, hostLogin, nodeAddr string, cfg *sftp.Config) error {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/TransferFiles",
Expand All @@ -2519,13 +2447,8 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr
if !tc.Config.ProxySpecified() {
return trace.BadParameter("proxy server is not specified")
}
clt, err := tc.ConnectToCluster(ctx)
if err != nil {
return trace.Wrap(err)
}
defer clt.Close()

client, err := tc.ConnectToNode(
nodeClient, err := tc.ConnectToNode(
ctx,
clt,
NodeDetails{
Expand All @@ -2539,11 +2462,7 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr
return trace.Wrap(err)
}

return trace.Wrap(client.TransferFiles(ctx, cfg))
}

func isRemoteDest(name string) bool {
return strings.ContainsRune(name, ':')
return trace.Wrap(nodeClient.TransferFiles(ctx, cfg))
}

// ListNodesWithFilters returns all nodes that match the filters in the current cluster
Expand Down
5 changes: 3 additions & 2 deletions lib/teleterm/clusters/cluster_file_transfer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ import (
"github.com/gravitational/trace"

api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/sshutils/sftp"
"github.com/gravitational/teleport/lib/teleterm/api/uri"
)

type FileTransferProgressSender = func(progress *api.FileTransferProgress) error

func (c *Cluster) TransferFile(ctx context.Context, request *api.FileTransferRequest, sendProgress FileTransferProgressSender) error {
func (c *Cluster) TransferFile(ctx context.Context, clt *client.ClusterClient, request *api.FileTransferRequest, sendProgress FileTransferProgressSender) error {
config, err := getSftpConfig(request)
if err != nil {
return trace.Wrap(err)
Expand All @@ -54,7 +55,7 @@ func (c *Cluster) TransferFile(ctx context.Context, request *api.FileTransferReq
}

err = AddMetadataToRetryableError(ctx, func() error {
err := c.clusterClient.TransferFiles(ctx, request.GetLogin(), serverUUID+":0", config)
err := c.clusterClient.TransferFiles(ctx, clt, request.GetLogin(), serverUUID+":0", config)
if errors.As(err, new(*sftp.NonRecursiveDirectoryTransferError)) {
return trace.Errorf("transferring directories through Teleport Connect is not supported at the moment, please use tsh scp -r")
}
Expand Down
7 changes: 6 additions & 1 deletion lib/teleterm/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,12 @@ func (s *Service) TransferFile(ctx context.Context, request *api.FileTransferReq
return trace.Wrap(err)
}

return cluster.TransferFile(ctx, request, sendProgress)
clt, err := s.GetCachedClient(ctx, cluster.URI)
if err != nil {
return trace.Wrap(err)
}

return cluster.TransferFile(ctx, clt, request, sendProgress)
}

// CreateConnectMyComputerRole creates a role which allows access to nodes with the label
Expand Down
8 changes: 7 additions & 1 deletion lib/web/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,13 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou
ctx = context.WithValue(ctx, sftp.ModeratedSessionID, req.moderatedSessionID)
}

err = tc.TransferFiles(ctx, req.login, req.serverID+":0", cfg)
cl, err := tc.ConnectToCluster(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
defer cl.Close()

err = tc.TransferFiles(ctx, cl, req.login, req.serverID+":0", cfg)
if err != nil {
if errors.As(err, new(*sftp.NonRecursiveDirectoryTransferError)) {
return nil, trace.Errorf("transferring directories through the Web UI is not supported at the moment, please use tsh scp -r")
Expand Down
126 changes: 114 additions & 12 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -3953,8 +3953,108 @@ func onJoin(cf *CLIConf) error {
return nil
}

type sftpConfig struct {
cfg *sftp.Config
addr string
hostLogin string
}

func getSFTPDestination(target string, port int) (dest *sftp.Destination, addr string, err error) {
dest, err = sftp.ParseDestination(target)
if err != nil {
return nil, "", trace.Wrap(err)
}
addr = net.JoinHostPort(dest.Host.Host(), strconv.Itoa(port))
return dest, addr, nil
}

func uploadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) {
// args are guaranteed to have len(args) > 1
srcPaths := args[:len(args)-1]
// copy everything except the last arg (the destination)
dstPath := args[len(args)-1]

dst, addr, err := getSFTPDestination(dstPath, port)
if err != nil {
return nil, trace.Wrap(err)
}
cfg, err := sftp.CreateUploadConfig(srcPaths, dst.Path, opts)
if err != nil {
return nil, trace.Wrap(err)
}

return &sftpConfig{
cfg: cfg,
addr: addr,
hostLogin: dst.Login,
}, nil
}

func downloadConfig(args []string, port int, opts sftp.Options) (*sftpConfig, error) {
if len(args) > 2 {
return nil, trace.BadParameter("only one source file is supported when downloading files")
}

// args are guaranteed to have len(args) > 1
src, addr, err := getSFTPDestination(args[0], port)
if err != nil {
return nil, trace.Wrap(err)
}
cfg, err := sftp.CreateDownloadConfig(src.Path, args[1], opts)
if err != nil {
return nil, trace.Wrap(err)
}

return &sftpConfig{
cfg: cfg,
addr: addr,
hostLogin: src.Login,
}, nil
}

// onSCP executes 'tsh scp' command
func onSCP(cf *CLIConf) error {
if len(cf.CopySpec) != 2 {
return trace.Errorf("local and remote destinations are required")
}

opts := sftp.Options{
Recursive: cf.RecursiveCopy,
PreserveAttrs: cf.PreserveAttrs,
}

first := cf.CopySpec[0]
last := cf.CopySpec[1]

localCopy := strings.ContainsRune(first, ':')
remoteCopy := strings.ContainsRune(last, ':')

if !localCopy && !remoteCopy {
return trace.BadParameter("no remote destination specified")
}

var (
config *sftpConfig
err error
)
if remoteCopy {
config, err = uploadConfig(cf.CopySpec, int(cf.NodePort), opts)
if err != nil {
return trace.Wrap(err)
}
} else {
config, err = downloadConfig(cf.CopySpec, int(cf.NodePort), opts)
if err != nil {
return trace.Wrap(err)
}
}

if !cf.Quiet {
config.cfg.ProgressStream = func(fileInfo os.FileInfo) io.ReadWriter {
return sftp.NewProgressBar(fileInfo.Size(), fileInfo.Name(), cf.Stdout())
}
}

tc, err := makeClient(cf)
if err != nil {
return trace.Wrap(err)
Expand All @@ -3967,12 +4067,8 @@ func onSCP(cf *CLIConf) error {
cf.Context = ctx
defer cancel()

opts := sftp.Options{
Recursive: cf.RecursiveCopy,
PreserveAttrs: cf.PreserveAttrs,
}
err = client.RetryWithRelogin(cf.Context, tc, func() error {
return tc.SFTP(cf.Context, cf.CopySpec, int(cf.NodePort), opts, cf.Quiet)
return tc.SFTP(cf.Context, config.cfg)
})
// don't print context canceled errors to the user
if err == nil || errors.Is(err, context.Canceled) {
Expand Down Expand Up @@ -4075,14 +4171,20 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err
} else if cf.CopySpec != nil {
for _, location := range cf.CopySpec {
// Extract username and host from "username@host:file/path"
parts := strings.Split(location, ":")
parts = strings.Split(parts[0], "@")
partsLength := len(parts)
if partsLength > 1 {
hostLogin = strings.Join(parts[:partsLength-1], "@")
hostUser = parts[partsLength-1]
break
userHost, _, found := strings.Cut(location, ":")
if !found {
continue
}

login, hostname, found := strings.Cut(userHost, "@")
if found {
hostLogin = login
hostUser = hostname
} else {
hostUser = userHost
}
break

}
}

Expand Down

0 comments on commit a9fbc43

Please sign in to comment.