Skip to content

Commit

Permalink
Honour -u flag which overrides username.
Browse files Browse the repository at this point in the history
In non-cache mode try to get EC2 username from tags, as we are describing
it anyway.

Fixes #16
Fixes #17
  • Loading branch information
Eugene Dementyev authored and ekini committed Jul 29, 2021
1 parent cdc822f commit f8f499c
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 38 deletions.
53 changes: 27 additions & 26 deletions cmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ These placeholders are useful when you need to override the ssh command.`,
var sshEntries lib.SSHEntries
var profile string
var instanceID = viper.GetString("instanceid")
var defaultUser = viper.GetString("user")
var instanceUser = viper.GetString("user")

profiles := viper.GetStringSlice("profiles")
if len(profiles) > 0 {
Expand All @@ -50,40 +50,41 @@ These placeholders are useful when you need to override the ssh command.`,
&lib.SSHEntry{
ProfileConfig: lib.ProfileConfig{Name: profile},
InstanceID: instanceID,
User: defaultUser,
User: instanceUser,
Names: []string{instanceID},
},
},
viper.GetString("ssh-config-path"),
args,
)
}
// ok, profile is not set, switch to mode 2
log.Info("No profile has been provided, switching to the cache mode")
cache := cache.NewYAMLCache(viper.GetString("cache-dir"))

sshEntry, err := cache.Lookup(instanceID)
if err != nil {
log.WithError(err).Fatalf("can't lookup %s in cache", instanceID)
}
if sshEntry.User == "" {
sshEntry.User = defaultUser
}
} else {
// ok, profile is not set, switch to mode 2
log.Info("no profile has been provided, switching to the cache mode")
cache := cache.NewYAMLCache(viper.GetString("cache-dir"))

sshEntries = append(sshEntries, &sshEntry)
// ProxyJump is set, which means we need to lookup the bastion host too
if sshEntry.ProxyJump != "" {
bastionEntry, err := cache.Lookup(sshEntry.ProxyJump)
sshEntry, err := cache.Lookup(instanceID)
if err != nil {
log.WithError(err).Fatalf("can't lookup bastion %s in cache", sshEntry.ProxyJump)
log.WithError(err).Fatalf("can't lookup %s in cache", instanceID)
}
if instanceUser != "" {
sshEntry.User = instanceUser
}
if bastionEntry.User == "" {
bastionEntry.User = defaultUser

sshEntries = append(sshEntries, &sshEntry)
// ProxyJump is set, which means we need to lookup the bastion host too
if sshEntry.ProxyJump != "" {
bastionEntry, err := cache.Lookup(sshEntry.ProxyJump)
if err != nil {
log.WithError(err).Fatalf("can't lookup bastion %s in cache", sshEntry.ProxyJump)
}
if instanceUser == "" {
bastionEntry.User = instanceUser
}
log.WithField("instance_id", bastionEntry.InstanceID).Infof("Got bastion %s", bastionEntry.Names[0])
sshEntries = append(sshEntries, &bastionEntry)
}
log.WithField("instance_id", bastionEntry.InstanceID).Infof("Got bastion %s", bastionEntry.Names[0])
sshEntries = append(sshEntries, &bastionEntry)
ec2connect.ConnectEC2(sshEntries, viper.GetString("ssh-config-path"), args)
}
ec2connect.ConnectEC2(sshEntries, viper.GetString("ssh-config-path"), args)
},
}

Expand All @@ -95,8 +96,8 @@ func init() {
defaultSSHConfigFile := path.Join(homeDir, ".ssh", "ec2_connect_config")

connectCmd.Flags().StringP("instanceid", "i", "", "Instance ID to connect to")
connectCmd.Flags().StringP("user", "u", "ec2-user", "Existing user on the instance")
connectCmd.Flags().StringP("ssh-config-path", "c", defaultSSHConfigFile, "Existing user on the instance")
connectCmd.Flags().StringP("user", "u", "", "Existing user on the instance")
connectCmd.Flags().StringP("ssh-config-path", "c", defaultSSHConfigFile, "Path to the ssh config to generate")
viper.BindPFlag("instanceid", connectCmd.Flags().Lookup("instanceid"))
viper.BindPFlag("user", connectCmd.Flags().Lookup("user"))
viper.BindPFlag("ssh-config-path", connectCmd.Flags().Lookup("ssh-config-path"))
Expand Down
4 changes: 2 additions & 2 deletions lib/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ func TraverseProfiles(profiles []ProfileConfig, noProfilePrefix bool) ([]Process
Domain: summary.Domain,
},
}
entry.User = getTagValue("x-aws-ssh-user", instance.Tags)
entry.Port = getTagValue("x-aws-ssh-port", instance.Tags)
entry.User = GetUserFromTags(instance.Tags)
entry.Port = getPortFromTags(instance.Tags)

// first try to find a bastion from this vpc
bastion := findBestBastion(instanceName, vpcBastions)
Expand Down
40 changes: 30 additions & 10 deletions lib/ec2connect/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"golang.org/x/crypto/ssh/agent"
)

const defaultUser = "ec2-user"

// ConnectEC2 connects to an EC2 instance by pushing your public key onto it first
// using EC2 connect feature and then runs ssh.
func ConnectEC2(sshEntries lib.SSHEntries, sshConfigPath string, args []string) {
Expand All @@ -37,15 +39,17 @@ func ConnectEC2(sshEntries lib.SSHEntries, sshConfigPath string, args []string)
if len(sshEntry.Names) > 0 {
instanceName = sshEntry.Names[0]
}

log.WithField("instance", instanceName).WithField("user", sshEntry.User).Info("Pushing SSH key...")
instanceIpAddress, err := pushEC2Connect(sshEntry.ProfileConfig.Name, sshEntry.InstanceID, sshEntry.User, pubkey)
log.WithField("instance", instanceName).Info("trying to do ec2 connect...")
instanceIPAddress, instanceUser, err := pushEC2Connect(sshEntry.ProfileConfig.Name, sshEntry.InstanceID, sshEntry.User, pubkey)
if err != nil {
log.WithError(err).Fatal("can't push ssh key to the instance")
}
// if the address is empty we set to the value we got from ec2 connect push
if sshEntry.Address == "" {
sshEntry.Address = instanceIpAddress
sshEntry.Address = instanceIPAddress
}
if sshEntry.User == "" {
sshEntry.User = instanceUser
}
}

Expand Down Expand Up @@ -91,42 +95,58 @@ func ConnectEC2(sshEntries lib.SSHEntries, sshConfigPath string, args []string)

// pushEC2Connect pushes the ssh key to a given profile and instance ID
// and returns the public (or private if public doesn't exist) address of the EC2 instance
func pushEC2Connect(profile, instanceID, instanceUser, pubKey string) (string, error) {
func pushEC2Connect(profile, instanceID, instanceUser, pubKey string) (string, string, error) {
ctx := log.WithField("instance_id", instanceID)
localSession, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{},

SharedConfigState: session.SharedConfigEnable,
Profile: profile,
})
if err != nil {
return "", fmt.Errorf("can't get aws session: %s", err)
return "", "", fmt.Errorf("can't get aws session: %s", err)
}
ec2Svc := ec2.New(localSession)
ec2Result, err := ec2Svc.DescribeInstances(&ec2.DescribeInstancesInput{
InstanceIds: aws.StringSlice([]string{instanceID}),
})
if err != nil {
return "", fmt.Errorf("can't get ec2 instance: %s", err)
return "", "", fmt.Errorf("can't get ec2 instance: %s", err)
}

if len(ec2Result.Reservations) == 0 || len(ec2Result.Reservations[0].Instances) == 0 {
return "", fmt.Errorf("Couldn't find the instance %s", instanceID)
return "", "", fmt.Errorf("Couldn't find the instance %s", instanceID)
}

ec2Instance := ec2Result.Reservations[0].Instances[0]
ec2ICSvc := ec2instanceconnect.New(localSession)

// no username has been provided, so we try to get it fom the instance tag first
if instanceUser == "" {
ctx.Debug("no user has been set provided, trying to get it from the tags")
// next try to get username from the instance tags
if instanceUser = lib.GetUserFromTags(ec2Instance.Tags); instanceUser == "" {
// otherwise fallback to default
ctx.WithField("user", defaultUser).Debugf("got no user from the instance tags, setting to default")
instanceUser = defaultUser
} else {
ctx.WithField("user", instanceUser).Debugf("got username from tags")
}
}

ctx.WithField("user", instanceUser).Info("pushing SSH key...")

if _, err := ec2ICSvc.SendSSHPublicKey(&ec2instanceconnect.SendSSHPublicKeyInput{
InstanceId: ec2Instance.InstanceId,
InstanceOSUser: aws.String(instanceUser),
AvailabilityZone: ec2Instance.Placement.AvailabilityZone,
SSHPublicKey: aws.String(pubKey),
}); err != nil {
return "", fmt.Errorf("can't push ssh key: %s", err)
return "", "", fmt.Errorf("can't push ssh key: %s", err)
}
var address = aws.StringValue(ec2Instance.PrivateIpAddress)
if aws.StringValue(ec2Instance.PublicIpAddress) != "" {
address = aws.StringValue(ec2Instance.PublicIpAddress)
}
return address, nil
return address, instanceUser, nil
}
10 changes: 10 additions & 0 deletions lib/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,20 @@ func getTagValue(tag string, tags []*ec2.Tag, caseInsensitive ...bool) string {
return ""

}

func getNameFromTags(tags []*ec2.Tag) string {
return strings.ToLower(getTagValue("Name", tags))
}

func getPortFromTags(tags []*ec2.Tag) string {
return strings.ToLower(getTagValue("x-aws-ssh-port", tags))
}

// GetUserFromTags gets the ec2 username from tags
func GetUserFromTags(tags []*ec2.Tag) string {
return strings.ToLower(getTagValue("x-aws-ssh-user", tags))
}

func isBastionFromTags(tags []*ec2.Tag, checkGlobal bool) bool {
if len(tags) > 0 {
var name string
Expand Down

0 comments on commit f8f499c

Please sign in to comment.