Skip to content

Commit

Permalink
Allow specifying a target container for pilot VMs (#4218)
Browse files Browse the repository at this point in the history
* Allow specifying a target container for pilot VMs

* Error handling

* format
  • Loading branch information
jsierles authored Feb 13, 2025
1 parent ef6c170 commit 49f10cd
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 8 deletions.
7 changes: 6 additions & 1 deletion internal/command/console/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ func New() *cobra.Command {
Description: "Select the machine on which to execute the console from a list.",
Default: false,
},
flag.String{
Name: "container",
Description: "Container to connect to",
},
flag.String{
Name: "user",
Shorthand: "u",
Expand Down Expand Up @@ -221,6 +225,7 @@ func runConsole(ctx context.Context) error {
Dialer: dialer,
Username: flag.GetString(ctx, "user"),
DisableSpinner: false,
Container: flag.GetString(ctx, "container"),
AppNames: []string{app.Name},
}
sshClient, err := ssh.Connect(params, machine.PrivateIP)
Expand All @@ -234,7 +239,7 @@ func runConsole(ctx context.Context) error {
consoleCommand = flag.GetString(ctx, "command")
}

return ssh.Console(ctx, sshClient, consoleCommand, true)
return ssh.Console(ctx, sshClient, consoleCommand, true, params.Container)
}

func selectMachine(ctx context.Context, app *fly.AppCompact, appConfig *appconfig.Config) (*fly.Machine, func(), error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/command/machine/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ func runMachineRun(ctx context.Context) error {
return err
}

err = ssh.Console(ctx, sshClient, flag.GetString(ctx, "command"), true)
err = ssh.Console(ctx, sshClient, flag.GetString(ctx, "command"), true, "")
if destroy {
err = soManyErrors("console", err, "destroy machine", Destroy(ctx, app, machine, true))
}
Expand Down
2 changes: 1 addition & 1 deletion internal/command/postgres/barman.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func runConsole(ctx context.Context, cmd string) error {
return err
}

if err := ssh.Console(ctx, sshc, cmd, false); err != nil {
if err := ssh.Console(ctx, sshc, cmd, false, ""); err != nil {
captureError(ctx, err, app)
return err
}
Expand Down
1 change: 1 addition & 0 deletions internal/command/ssh/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type ConnectParams struct {
Username string
Dialer agent.Dialer
DisableSpinner bool
Container string
AppNames []string
}

Expand Down
11 changes: 8 additions & 3 deletions internal/command/ssh/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ func stdArgsSSH(cmd *cobra.Command) {
Shorthand: "A",
Description: "Address of VM to connect to",
},
flag.String{
Name: "container",
Description: "Container to connect to",
},
flag.Bool{
Name: "pty",
Description: "Allocate a pseudo-terminal (default: on when no command is provided)",
Expand Down Expand Up @@ -177,6 +181,7 @@ func runConsole(ctx context.Context) error {
Dialer: dialer,
Username: flag.GetString(ctx, "user"),
DisableSpinner: quiet(ctx),
Container: flag.GetString(ctx, "container"),
AppNames: []string{app.Name},
}
sshc, err := Connect(params, addr)
Expand All @@ -185,15 +190,15 @@ func runConsole(ctx context.Context) error {
return err
}

if err := Console(ctx, sshc, cmd, allocPTY); err != nil {
if err := Console(ctx, sshc, cmd, allocPTY, params.Container); err != nil {
captureError(ctx, err, app)
return err
}

return nil
}

func Console(ctx context.Context, sshClient *ssh.Client, cmd string, allocPTY bool) error {
func Console(ctx context.Context, sshClient *ssh.Client, cmd string, allocPTY bool, container string) error {
currentStdin, currentStdout, currentStderr, err := setupConsole()
defer func() error {
if err := cleanupConsole(currentStdin, currentStdout, currentStderr); err != nil {
Expand All @@ -214,7 +219,7 @@ func Console(ctx context.Context, sshClient *ssh.Client, cmd string, allocPTY bo
TermEnv: determineTermEnv(),
}

if err := sshClient.Shell(ctx, sessIO, cmd); err != nil {
if err := sshClient.Shell(ctx, sessIO, cmd, container); err != nil {
return errors.Wrap(err, "ssh shell")
}

Expand Down
2 changes: 1 addition & 1 deletion internal/command/ssh/ssh_terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func SSHConnect(p *SSHParams, addr string) error {
TermEnv: "xterm",
}

if err := sshClient.Shell(context.Background(), sessIO, p.Cmd); err != nil {
if err := sshClient.Shell(context.Background(), sessIO, p.Cmd, ""); err != nil {
return errors.Wrap(err, "ssh shell")
}

Expand Down
11 changes: 10 additions & 1 deletion ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,26 @@ func (c *Client) Connect(ctx context.Context) error {
}
}

func (c *Client) Shell(ctx context.Context, sessIO *SessionIO, cmd string) error {
func (c *Client) Shell(ctx context.Context, sessIO *SessionIO, cmd string, container string) error {
if c.Client == nil {
if err := c.Connect(ctx); err != nil {
return err
}
}

sess, err := c.Client.NewSession()

if err != nil {
return err
}

if container != "" {
err = sess.Setenv("FLY_SSH_CONTAINER", container)
if err != nil {
return err
}
}

defer sess.Close()

return sessIO.attach(ctx, sess, cmd)
Expand Down

0 comments on commit 49f10cd

Please sign in to comment.