From ae2dd172e8f71ab193bcca989d543d8d7ca025e3 Mon Sep 17 00:00:00 2001 From: James Sanders Date: Tue, 4 Jun 2024 14:42:26 -0700 Subject: [PATCH] refactor: allow variants to implement their own k8s providerID parsing logic == Motivation == Allow further variant specific customization == Details == This change adds the ability for variants to implement their own logic to parse a k8s providerID into an identifier specific to that variant. This is done by adding a new method 'NodeId' on each variant that takes the providerID (in pre-parsed url format) and returns a respective NodeId. NodeId is a refactor of the previous InstanceId which was inadequately named for variants other than EC2 instances (e.g. fargate). I have also taken this opportunity to add a bit more strong typing to the variant methods to make these methods a bit easier to grok and safer to implement. I have also renamed the KubernetesInstanceId to KubernetesProviderId which felt a more adequate name as well. Finally, in order to prevent circular depedencies and to maintain (what I felt) was a more logical concept of a NodeId, we have created a new submodule in the v1 package named 'awsnode' which contains the NodeId type. This allows both the base v1 module and the variant submodule to use this NodeId type withou circular dependencies. == Testing == make --- pkg/controllers/tagging/tagging_controller.go | 10 +-- pkg/providers/v1/aws.go | 53 +++++++-------- pkg/providers/v1/aws_instance.go | 3 +- pkg/providers/v1/aws_loadbalancer.go | 7 +- pkg/providers/v1/aws_loadbalancer_test.go | 7 +- pkg/providers/v1/aws_routes.go | 4 +- pkg/providers/v1/aws_test.go | 3 +- pkg/providers/v1/awsnode/identifier.go | 11 ++++ pkg/providers/v1/instances.go | 66 ++++++++----------- pkg/providers/v1/instances_test.go | 31 ++++----- pkg/providers/v1/variant/fargate/fargate.go | 41 ++++++++---- pkg/providers/v1/variant/variant.go | 40 +++++++---- 12 files changed, 156 insertions(+), 120 deletions(-) create mode 100644 pkg/providers/v1/awsnode/identifier.go diff --git a/pkg/controllers/tagging/tagging_controller.go b/pkg/controllers/tagging/tagging_controller.go index 909c8237eb..84d5f07cbc 100644 --- a/pkg/controllers/tagging/tagging_controller.go +++ b/pkg/controllers/tagging/tagging_controller.go @@ -225,7 +225,7 @@ func (tc *Controller) process() bool { recordWorkItemLatencyMetrics(workItemDequeuingTimeWorkItemMetric, timeTaken) klog.Infof("Dequeuing latency %f seconds", timeTaken) - instanceID, err := awsv1.KubernetesInstanceID(workItem.node.Spec.ProviderID).MapToAWSInstanceID() + instanceID, err := awsv1.ParseProviderID(workItem.node.Spec.ProviderID) if err != nil { err = fmt.Errorf("Error in getting instanceID for node %s, error: %v", workItem.node.GetName(), err) utilruntime.HandleError(err) @@ -233,9 +233,9 @@ func (tc *Controller) process() bool { } klog.Infof("Instance ID of work item %s is %s", workItem, instanceID) - if variant.IsVariantNode(string(instanceID)) { + if variant.IsVariantNode(instanceID) { klog.Infof("Skip processing the node %s since it is a %s node", - instanceID, variant.NodeType(string(instanceID))) + instanceID, variant.NodeType(instanceID)) tc.workqueue.Forget(obj) return nil } @@ -297,7 +297,7 @@ func (tc *Controller) tagEc2Instance(node *v1.Node) error { return nil } - instanceID, _ := awsv1.KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() + instanceID, _ := awsv1.ParseProviderID(node.Spec.ProviderID) err := tc.cloud.TagResource(string(instanceID), tc.tags) @@ -349,7 +349,7 @@ func (tc *Controller) untagNodeResources(node *v1.Node) error { // untagEc2Instances deletes the provided tags to each EC2 instances in // the cluster. func (tc *Controller) untagEc2Instance(node *v1.Node) error { - instanceID, _ := awsv1.KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() + instanceID, _ := awsv1.ParseProviderID(node.Spec.ProviderID) err := tc.cloud.UntagResource(string(instanceID), tc.tags) diff --git a/pkg/providers/v1/aws.go b/pkg/providers/v1/aws.go index 5122ecd432..0613c93632 100644 --- a/pkg/providers/v1/aws.go +++ b/pkg/providers/v1/aws.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "net" "regexp" "sort" @@ -418,7 +419,7 @@ func InstanceIDIndexFunc(obj interface{}) ([]string, error) { // provider ID hasn't been populated yet return []string{""}, nil } - instanceID, err := KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(node.Spec.ProviderID) if err != nil { //logging the error as warning as Informer.AddIndexers would panic if there is an error klog.Warningf("error mapping node %q's provider ID %q to instance ID: %v", node.Name, node.Spec.ProviderID, err) @@ -832,16 +833,16 @@ func extractIPv6NodeAddresses(instance *ec2.Instance) ([]v1.NodeAddress, error) // This method will not be called from the node that is requesting this ID. i.e. metadata service // and other local methods cannot be used here func (c *Cloud) NodeAddressesByProviderID(ctx context.Context, providerID string) ([]v1.NodeAddress, error) { - instanceID, err := KubernetesInstanceID(providerID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(providerID) if err != nil { return nil, err } - if v := variant.GetVariant(string(instanceID)); v != nil { - return v.NodeAddresses(string(instanceID), c.vpcID) + if v := variant.GetVariant(instanceID); v != nil { + return v.NodeAddresses(instanceID, c.vpcID) } - instance, err := describeInstance(c.ec2, instanceID) + instance, err := describeInstance(c.ec2, string(instanceID)) if err != nil { return nil, err } @@ -871,17 +872,17 @@ func (c *Cloud) NodeAddressesByProviderID(ctx context.Context, providerID string // InstanceExistsByProviderID returns true if the instance with the given provider id still exists. // If false is returned with no error, the instance will be immediately deleted by the cloud controller manager. func (c *Cloud) InstanceExistsByProviderID(ctx context.Context, providerID string) (bool, error) { - instanceID, err := KubernetesInstanceID(providerID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(providerID) if err != nil { return false, err } - if v := variant.GetVariant(string(instanceID)); v != nil { - return v.InstanceExists(string(instanceID), c.vpcID) + if v := variant.GetVariant(instanceID); v != nil { + return v.InstanceExists(instanceID, c.vpcID) } request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{instanceID.awsString()}, + InstanceIds: []*string{instanceID.AwsString()}, } instances, err := c.ec2.DescribeInstances(request) @@ -910,17 +911,17 @@ func (c *Cloud) InstanceExistsByProviderID(ctx context.Context, providerID strin // InstanceShutdownByProviderID returns true if the instance is terminated func (c *Cloud) InstanceShutdownByProviderID(ctx context.Context, providerID string) (bool, error) { - instanceID, err := KubernetesInstanceID(providerID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(providerID) if err != nil { return false, err } - if v := variant.GetVariant(string(instanceID)); v != nil { - return v.InstanceShutdown(string(instanceID), c.vpcID) + if v := variant.GetVariant(instanceID); v != nil { + return v.InstanceShutdown(instanceID, c.vpcID) } request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{instanceID.awsString()}, + InstanceIds: []*string{instanceID.AwsString()}, } instances, err := c.ec2.DescribeInstances(request) @@ -969,16 +970,16 @@ func (c *Cloud) InstanceID(ctx context.Context, nodeName types.NodeName) (string // This method will not be called from the node that is requesting this ID. i.e. metadata service // and other local methods cannot be used here func (c *Cloud) InstanceTypeByProviderID(ctx context.Context, providerID string) (string, error) { - instanceID, err := KubernetesInstanceID(providerID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(providerID) if err != nil { return "", err } - if v := variant.GetVariant(string(instanceID)); v != nil { - return v.InstanceTypeByProviderID(string(instanceID)) + if v := variant.GetVariant(instanceID); v != nil { + return v.InstanceTypeByProviderID(instanceID) } - instance, err := describeInstance(c.ec2, instanceID) + instance, err := describeInstance(c.ec2, string(instanceID)) if err != nil { return "", err } @@ -1010,13 +1011,13 @@ func (c *Cloud) GetZone(ctx context.Context) (cloudprovider.Zone, error) { // This is particularly useful in external cloud providers where the kubelet // does not initialize node data. func (c *Cloud) GetZoneByProviderID(ctx context.Context, providerID string) (cloudprovider.Zone, error) { - instanceID, err := KubernetesInstanceID(providerID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(providerID) if err != nil { return cloudprovider.Zone{}, err } - if v := variant.GetVariant(string(instanceID)); v != nil { - return v.GetZone(string(instanceID), c.vpcID, c.region) + if v := variant.GetVariant(instanceID); v != nil { + return v.GetZone(instanceID, c.vpcID, c.region) } instance, err := c.getInstanceByID(string(instanceID)) @@ -2651,7 +2652,7 @@ func (c *Cloud) getTaggedSecurityGroups() (map[string]*ec2.SecurityGroup, error) // Open security group ingress rules on the instances so that the load balancer can talk to them // Will also remove any security groups ingress rules for the load balancer that are _not_ needed for allInstances -func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancerDescription, instances map[InstanceID]*ec2.Instance, annotations map[string]string) error { +func (c *Cloud) updateInstanceSecurityGroupsForLoadBalancer(lb *elb.LoadBalancerDescription, instances map[awsnode.NodeID]*ec2.Instance, annotations map[string]string) error { if c.cfg.Global.DisableSecurityGroupIngress { return nil } @@ -3228,15 +3229,15 @@ func nodeNameToIPAddress(nodeName string) string { return strings.ReplaceAll(nodeName, "-", ".") } -func (c *Cloud) nodeNameToInstanceID(nodeName types.NodeName) (InstanceID, error) { +func (c *Cloud) nodeNameToInstanceID(nodeName types.NodeName) (awsnode.NodeID, error) { if strings.HasPrefix(string(nodeName), rbnNamePrefix) { // depending on if you use a RHEL (e.g. AL2) or Debian (e.g. standard Ubuntu) based distribution, the // hostname on the machine may be either i-00000000000000001 or i-00000000000000001.region.compute.internal. // This handles both scenarios by returning anything before the first '.' in the node name if it has an RBN prefix. if idx := strings.IndexByte(string(nodeName), '.'); idx != -1 { - return InstanceID(nodeName[0:idx]), nil + return awsnode.NodeID(nodeName[0:idx]), nil } - return InstanceID(nodeName), nil + return awsnode.NodeID(nodeName), nil } if len(nodeName) == 0 { return "", fmt.Errorf("no nodeName provided") @@ -3254,10 +3255,10 @@ func (c *Cloud) nodeNameToInstanceID(nodeName types.NodeName) (InstanceID, error return "", fmt.Errorf("node has no providerID") } - return KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() + return ParseProviderID(node.Spec.ProviderID) } -func (c *Cloud) instanceIDToNodeName(instanceID InstanceID) (types.NodeName, error) { +func (c *Cloud) instanceIDToNodeName(instanceID awsnode.NodeID) (types.NodeName, error) { if len(instanceID) == 0 { return "", fmt.Errorf("no instanceID provided") } diff --git a/pkg/providers/v1/aws_instance.go b/pkg/providers/v1/aws_instance.go index e7e8b152a1..01423a6e92 100644 --- a/pkg/providers/v1/aws_instance.go +++ b/pkg/providers/v1/aws_instance.go @@ -20,7 +20,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "k8s.io/apimachinery/pkg/types" - "k8s.io/cloud-provider-aws/pkg/providers/v1/iface" ) @@ -67,5 +66,5 @@ func newAWSInstance(ec2Service iface.EC2, instance *ec2.Instance) *awsInstance { // Gets the full information about this instance from the EC2 API func (i *awsInstance) describeInstance() (*ec2.Instance, error) { - return describeInstance(i.ec2, InstanceID(i.awsID)) + return describeInstance(i.ec2, i.awsID) } diff --git a/pkg/providers/v1/aws_loadbalancer.go b/pkg/providers/v1/aws_loadbalancer.go index c39ea3de37..f149d92ad0 100644 --- a/pkg/providers/v1/aws_loadbalancer.go +++ b/pkg/providers/v1/aws_loadbalancer.go @@ -20,6 +20,7 @@ import ( "crypto/sha1" "encoding/hex" "fmt" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "reflect" "regexp" "strconv" @@ -781,7 +782,7 @@ func (c *Cloud) chunkTargetDescriptions(targets []*elbv2.TargetDescription, chun // updateInstanceSecurityGroupsForNLB will adjust securityGroup's settings to allow inbound traffic into instances from clientCIDRs and portMappings. // TIP: if either instances or clientCIDRs or portMappings are nil, then the securityGroup rules for lbName are cleared. -func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[InstanceID]*ec2.Instance, subnetCIDRs []string, clientCIDRs []string, portMappings []nlbPortMapping) error { +func (c *Cloud) updateInstanceSecurityGroupsForNLB(lbName string, instances map[awsnode.NodeID]*ec2.Instance, subnetCIDRs []string, clientCIDRs []string, portMappings []nlbPortMapping) error { if c.cfg.Global.DisableSecurityGroupIngress { return nil } @@ -1430,7 +1431,7 @@ func (c *Cloud) ensureLoadBalancerHealthCheck(loadBalancer *elb.LoadBalancerDesc } // Makes sure that exactly the specified hosts are registered as instances with the load balancer -func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances []*elb.Instance, instanceIDs map[InstanceID]*ec2.Instance) error { +func (c *Cloud) ensureLoadBalancerInstances(loadBalancerName string, lbInstances []*elb.Instance, instanceIDs map[awsnode.NodeID]*ec2.Instance) error { expected := sets.NewString() for id := range instanceIDs { expected.Insert(string(id)) @@ -1607,7 +1608,7 @@ func proxyProtocolEnabled(backend *elb.BackendServerDescription) bool { // findInstancesForELB gets the EC2 instances corresponding to the Nodes, for setting up an ELB // We ignore Nodes (with a log message) where the instanceid cannot be determined from the provider, // and we ignore instances which are not found -func (c *Cloud) findInstancesForELB(nodes []*v1.Node, annotations map[string]string) (map[InstanceID]*ec2.Instance, error) { +func (c *Cloud) findInstancesForELB(nodes []*v1.Node, annotations map[string]string) (map[awsnode.NodeID]*ec2.Instance, error) { targetNodes := filterTargetNodes(nodes, annotations) diff --git a/pkg/providers/v1/aws_loadbalancer_test.go b/pkg/providers/v1/aws_loadbalancer_test.go index 309d9eb209..a24c077fcb 100644 --- a/pkg/providers/v1/aws_loadbalancer_test.go +++ b/pkg/providers/v1/aws_loadbalancer_test.go @@ -18,6 +18,7 @@ package aws import ( "fmt" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "reflect" "testing" "time" @@ -592,7 +593,7 @@ func TestCloud_findInstancesForELB(t *testing.T) { return } - want := map[InstanceID]*ec2.Instance{ + want := map[awsnode.NodeID]*ec2.Instance{ "i-self": awsServices.selfInstance, } got, err := c.findInstancesForELB([]*v1.Node{defaultNode}, nil) @@ -601,9 +602,9 @@ func TestCloud_findInstancesForELB(t *testing.T) { // Add a new EC2 instance awsServices.instances = append(awsServices.instances, newInstance) - want = map[InstanceID]*ec2.Instance{ + want = map[awsnode.NodeID]*ec2.Instance{ "i-self": awsServices.selfInstance, - InstanceID(aws.StringValue(newInstance.InstanceId)): newInstance, + awsnode.NodeID(aws.StringValue(newInstance.InstanceId)): newInstance, } got, err = c.findInstancesForELB([]*v1.Node{defaultNode, newNode}, nil) assert.NoError(t, err) diff --git a/pkg/providers/v1/aws_routes.go b/pkg/providers/v1/aws_routes.go index e3e7c5b7a4..16f6e0e29e 100644 --- a/pkg/providers/v1/aws_routes.go +++ b/pkg/providers/v1/aws_routes.go @@ -19,9 +19,9 @@ package aws import ( "context" "fmt" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "k8s.io/klog/v2" cloudprovider "k8s.io/cloud-provider" @@ -114,7 +114,7 @@ func (c *Cloud) ListRoutes(ctx context.Context, clusterName string) ([]*cloudpro if instanceID != "" { _, found := instances[instanceID] if found { - node, err := c.instanceIDToNodeName(InstanceID(instanceID)) + node, err := c.instanceIDToNodeName(awsnode.NodeID(instanceID)) if err != nil { return nil, err } diff --git a/pkg/providers/v1/aws_test.go b/pkg/providers/v1/aws_test.go index 577f5d72cf..65b462574e 100644 --- a/pkg/providers/v1/aws_test.go +++ b/pkg/providers/v1/aws_test.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "io" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "math/rand" "reflect" "sort" @@ -2399,7 +2400,7 @@ func TestNodeNameToInstanceID(t *testing.T) { func TestInstanceIDToNodeName(t *testing.T) { testCases := []struct { name string - instanceID InstanceID + instanceID awsnode.NodeID node *v1.Node expectedNodeName types.NodeName expectedErr error diff --git a/pkg/providers/v1/awsnode/identifier.go b/pkg/providers/v1/awsnode/identifier.go new file mode 100644 index 0000000000..a534308f31 --- /dev/null +++ b/pkg/providers/v1/awsnode/identifier.go @@ -0,0 +1,11 @@ +package awsnode + +import "github.com/aws/aws-sdk-go/aws" + +// NodeID is the ID used to uniquely identify a node within an AWS service +type NodeID string + +// AwsString returns a pointer to the string value of the NodeID. Useful for AWS APIs +func (i NodeID) AwsString() *string { + return aws.String(string(i)) +} diff --git a/pkg/providers/v1/instances.go b/pkg/providers/v1/instances.go index 08ae3aff21..28c21e0b58 100644 --- a/pkg/providers/v1/instances.go +++ b/pkg/providers/v1/instances.go @@ -18,6 +18,7 @@ package aws import ( "fmt" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "net/url" "regexp" "strings" @@ -36,29 +37,14 @@ import ( // awsInstanceRegMatch represents Regex Match for AWS instance. var awsInstanceRegMatch = regexp.MustCompile("^i-[^/]*$") -// InstanceID represents the ID of the instance in the AWS API, e.g. i-12345678 -// The "traditional" format is "i-12345678" -// A new longer format is also being introduced: "i-12345678abcdef01" -// We should not assume anything about the length or format, though it seems -// reasonable to assume that instances will continue to start with "i-". -type InstanceID string - -func (i InstanceID) awsString() *string { - return aws.String(string(i)) -} - -// KubernetesInstanceID represents the id for an instance in the kubernetes API; -// the following form +// ParseProviderID turns a Kubernetes ProviderID into an AWS node id +// the following are forms of ProviderIDs that are supported: // - aws://// // - aws://// // - aws:////fargate- // - -type KubernetesInstanceID string - -// MapToAWSInstanceID extracts the InstanceID from the KubernetesInstanceID -func (name KubernetesInstanceID) MapToAWSInstanceID() (InstanceID, error) { - s := string(name) - +func ParseProviderID(providerID string) (awsnode.NodeID, error) { + s := providerID if !strings.HasPrefix(s, "aws://") { // Assume a bare aws instance id (i-1234...) // Build a URL with an empty host (AZ) @@ -66,10 +52,14 @@ func (name KubernetesInstanceID) MapToAWSInstanceID() (InstanceID, error) { } url, err := url.Parse(s) if err != nil { - return "", fmt.Errorf("Invalid instance name (%s): %v", name, err) + return "", fmt.Errorf("Invalid instance name (%s): %v", providerID, err) } if url.Scheme != "aws" { - return "", fmt.Errorf("Invalid scheme for AWS instance (%s)", name) + return "", fmt.Errorf("Invalid scheme for AWS instance (%s)", providerID) + } + + if nodeID := variant.GetNodeID(*url); nodeID != "" { + return nodeID, nil } awsID := "" @@ -81,21 +71,21 @@ func (name KubernetesInstanceID) MapToAWSInstanceID() (InstanceID, error) { // We sanity check the resulting instance ID; the two known formats are // i-12345678 and i-12345678abcdef01 - if awsID == "" || !(awsInstanceRegMatch.MatchString(awsID) || variant.IsVariantNode(awsID)) { - return "", fmt.Errorf("Invalid format for AWS instance (%s)", name) + if awsID == "" || !awsInstanceRegMatch.MatchString(awsID) { + return "", fmt.Errorf("Invalid format for AWS instance (%s)", providerID) } - return InstanceID(awsID), nil + return awsnode.NodeID(awsID), nil } // mapToAWSInstanceID extracts the InstanceIDs from the Nodes, returning an error if a Node cannot be mapped -func mapToAWSInstanceIDs(nodes []*v1.Node) ([]InstanceID, error) { - var instanceIDs []InstanceID +func mapToAWSInstanceIDs(nodes []*v1.Node) ([]awsnode.NodeID, error) { + var instanceIDs []awsnode.NodeID for _, node := range nodes { if node.Spec.ProviderID == "" { return nil, fmt.Errorf("node %q did not have ProviderID set", node.Name) } - instanceID, err := KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(node.Spec.ProviderID) if err != nil { return nil, fmt.Errorf("unable to parse ProviderID %q for node %q", node.Spec.ProviderID, node.Name) } @@ -106,14 +96,14 @@ func mapToAWSInstanceIDs(nodes []*v1.Node) ([]InstanceID, error) { } // mapToAWSInstanceIDsTolerant extracts the InstanceIDs from the Nodes, skipping Nodes that cannot be mapped -func mapToAWSInstanceIDsTolerant(nodes []*v1.Node) []InstanceID { - var instanceIDs []InstanceID +func mapToAWSInstanceIDsTolerant(nodes []*v1.Node) []awsnode.NodeID { + var instanceIDs []awsnode.NodeID for _, node := range nodes { if node.Spec.ProviderID == "" { klog.Warningf("node %q did not have ProviderID set", node.Name) continue } - instanceID, err := KubernetesInstanceID(node.Spec.ProviderID).MapToAWSInstanceID() + instanceID, err := ParseProviderID(node.Spec.ProviderID) if err != nil { klog.Warningf("unable to parse ProviderID %q for node %q", node.Spec.ProviderID, node.Name) continue @@ -125,9 +115,9 @@ func mapToAWSInstanceIDsTolerant(nodes []*v1.Node) []InstanceID { } // Gets the full information about this instance from the EC2 API -func describeInstance(ec2Client iface.EC2, instanceID InstanceID) (*ec2.Instance, error) { +func describeInstance(ec2Client iface.EC2, instanceID string) (*ec2.Instance, error) { request := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{instanceID.awsString()}, + InstanceIds: []*string{&instanceID}, } instances, err := ec2Client.DescribeInstances(request) @@ -165,9 +155,9 @@ func (c *instanceCache) describeAllInstancesUncached() (*allInstancesSnapshot, e return nil, err } - m := make(map[InstanceID]*ec2.Instance) + m := make(map[awsnode.NodeID]*ec2.Instance) for _, i := range instances { - id := InstanceID(aws.StringValue(i.InstanceId)) + id := awsnode.NodeID(aws.StringValue(i.InstanceId)) m[id] = i } @@ -190,7 +180,7 @@ type cacheCriteria struct { // HasInstances is a list of InstanceIDs that must be in a cached snapshot for it to be considered valid. // If an instance is not found in the cached snapshot, the snapshot be ignored and we will re-fetch. - HasInstances []InstanceID + HasInstances []awsnode.NodeID } // describeAllInstancesCached returns all instances, using cached results if applicable @@ -238,12 +228,12 @@ func (s *allInstancesSnapshot) MeetsCriteria(criteria cacheCriteria) bool { // along with the timestamp for cache-invalidation purposes type allInstancesSnapshot struct { timestamp time.Time - instances map[InstanceID]*ec2.Instance + instances map[awsnode.NodeID]*ec2.Instance } // FindInstances returns the instances corresponding to the specified ids. If an id is not found, it is ignored. -func (s *allInstancesSnapshot) FindInstances(ids []InstanceID) map[InstanceID]*ec2.Instance { - m := make(map[InstanceID]*ec2.Instance) +func (s *allInstancesSnapshot) FindInstances(ids []awsnode.NodeID) map[awsnode.NodeID]*ec2.Instance { + m := make(map[awsnode.NodeID]*ec2.Instance) for _, id := range ids { instance := s.instances[id] if instance != nil { diff --git a/pkg/providers/v1/instances_test.go b/pkg/providers/v1/instances_test.go index ac431c6cf6..5ed6318261 100644 --- a/pkg/providers/v1/instances_test.go +++ b/pkg/providers/v1/instances_test.go @@ -17,6 +17,7 @@ limitations under the License. package aws import ( + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" "testing" "time" @@ -28,8 +29,8 @@ import ( func TestMapToAWSInstanceIDs(t *testing.T) { tests := []struct { - Kubernetes KubernetesInstanceID - Aws InstanceID + Kubernetes string + Aws awsnode.NodeID ExpectError bool }{ { @@ -87,7 +88,7 @@ func TestMapToAWSInstanceIDs(t *testing.T) { } for _, test := range tests { - awsID, err := test.Kubernetes.MapToAWSInstanceID() + awsID, err := ParseProviderID(test.Kubernetes) if err != nil { if !test.ExpectError { t.Errorf("unexpected error parsing %s: %v", test.Kubernetes, err) @@ -146,18 +147,18 @@ func TestSnapshotMeetsCriteria(t *testing.T) { t.Errorf("Snapshot did not honor MaxAge") } - if snapshot.MeetsCriteria(cacheCriteria{HasInstances: []InstanceID{InstanceID("i-12345678")}}) { + if snapshot.MeetsCriteria(cacheCriteria{HasInstances: []awsnode.NodeID{awsnode.NodeID("i-12345678")}}) { t.Errorf("Snapshot did not honor HasInstances with missing instances") } - snapshot.instances = make(map[InstanceID]*ec2.Instance) - snapshot.instances[InstanceID("i-12345678")] = &ec2.Instance{} + snapshot.instances = make(map[awsnode.NodeID]*ec2.Instance) + snapshot.instances[awsnode.NodeID("i-12345678")] = &ec2.Instance{} - if !snapshot.MeetsCriteria(cacheCriteria{HasInstances: []InstanceID{InstanceID("i-12345678")}}) { + if !snapshot.MeetsCriteria(cacheCriteria{HasInstances: []awsnode.NodeID{awsnode.NodeID("i-12345678")}}) { t.Errorf("Snapshot did not honor HasInstances with matching instances") } - if snapshot.MeetsCriteria(cacheCriteria{HasInstances: []InstanceID{InstanceID("i-12345678"), InstanceID("i-00000000")}}) { + if snapshot.MeetsCriteria(cacheCriteria{HasInstances: []awsnode.NodeID{awsnode.NodeID("i-12345678"), awsnode.NodeID("i-00000000")}}) { t.Errorf("Snapshot did not honor HasInstances with partially matching instances") } } @@ -177,22 +178,22 @@ func TestOlderThan(t *testing.T) { func TestSnapshotFindInstances(t *testing.T) { snapshot := &allInstancesSnapshot{} - snapshot.instances = make(map[InstanceID]*ec2.Instance) + snapshot.instances = make(map[awsnode.NodeID]*ec2.Instance) { - id := InstanceID("i-12345678") - snapshot.instances[id] = &ec2.Instance{InstanceId: id.awsString()} + id := awsnode.NodeID("i-12345678") + snapshot.instances[id] = &ec2.Instance{InstanceId: id.AwsString()} } { - id := InstanceID("i-23456789") - snapshot.instances[id] = &ec2.Instance{InstanceId: id.awsString()} + id := awsnode.NodeID("i-23456789") + snapshot.instances[id] = &ec2.Instance{InstanceId: id.AwsString()} } - instances := snapshot.FindInstances([]InstanceID{InstanceID("i-12345678"), InstanceID("i-23456789"), InstanceID("i-00000000")}) + instances := snapshot.FindInstances([]awsnode.NodeID{"i-12345678", "i-23456789", "i-00000000"}) if len(instances) != 2 { t.Errorf("findInstances returned %d results, expected 2", len(instances)) } - for _, id := range []InstanceID{InstanceID("i-12345678"), InstanceID("i-23456789")} { + for _, id := range []awsnode.NodeID{awsnode.NodeID("i-12345678"), awsnode.NodeID("i-23456789")} { i := instances[id] if i == nil { t.Errorf("findInstances did not return %s", id) diff --git a/pkg/providers/v1/variant/fargate/fargate.go b/pkg/providers/v1/variant/fargate/fargate.go index f4d7174603..fd296c0ca8 100644 --- a/pkg/providers/v1/variant/fargate/fargate.go +++ b/pkg/providers/v1/variant/fargate/fargate.go @@ -2,6 +2,8 @@ package fargate import ( "fmt" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" + "net/url" "strings" awssdk "github.com/aws/aws-sdk-go/aws" @@ -37,12 +39,12 @@ func (f *fargateVariant) Initialize(cloudConfig *config.CloudConfig, credentials return nil } -func (f *fargateVariant) InstanceTypeByProviderID(instanceID string) (string, error) { +func (f *fargateVariant) InstanceTypeByProviderID(nodeID awsnode.NodeID) (string, error) { return "", nil } -func (f *fargateVariant) GetZone(instanceID, vpcID, region string) (cloudprovider.Zone, error) { - eni, err := f.DescribeNetworkInterfaces(f.ec2API, instanceID, vpcID) +func (f *fargateVariant) GetZone(nodeID awsnode.NodeID, vpcID, region string) (cloudprovider.Zone, error) { + eni, err := f.DescribeNetworkInterfaces(f.ec2API, nodeID, vpcID) if eni == nil || err != nil { return cloudprovider.Zone{}, err } @@ -52,12 +54,12 @@ func (f *fargateVariant) GetZone(instanceID, vpcID, region string) (cloudprovide }, nil } -func (f *fargateVariant) IsSupportedNode(nodeName string) bool { - return strings.HasPrefix(nodeName, fargateNodeNamePrefix) +func (f *fargateVariant) IsSupportedNode(nodeID awsnode.NodeID) bool { + return strings.HasPrefix(string(nodeID), fargateNodeNamePrefix) } -func (f *fargateVariant) NodeAddresses(instanceID, vpcID string) ([]v1.NodeAddress, error) { - eni, err := f.DescribeNetworkInterfaces(f.ec2API, instanceID, vpcID) +func (f *fargateVariant) NodeAddresses(nodeID awsnode.NodeID, vpcID string) ([]v1.NodeAddress, error) { + eni, err := f.DescribeNetworkInterfaces(f.ec2API, nodeID, vpcID) if eni == nil || err != nil { return nil, err } @@ -83,16 +85,29 @@ func (f *fargateVariant) NodeAddresses(instanceID, vpcID string) ([]v1.NodeAddre return addresses, nil } -func (f *fargateVariant) InstanceExists(instanceID, vpcID string) (bool, error) { - eni, err := f.DescribeNetworkInterfaces(f.ec2API, instanceID, vpcID) +func (f *fargateVariant) InstanceExists(nodeID awsnode.NodeID, vpcID string) (bool, error) { + eni, err := f.DescribeNetworkInterfaces(f.ec2API, nodeID, vpcID) return eni != nil, err } -func (f *fargateVariant) InstanceShutdown(instanceID, vpcID string) (bool, error) { - eni, err := f.DescribeNetworkInterfaces(f.ec2API, instanceID, vpcID) +func (f *fargateVariant) InstanceShutdown(nodeID awsnode.NodeID, vpcID string) (bool, error) { + eni, err := f.DescribeNetworkInterfaces(f.ec2API, nodeID, vpcID) return eni != nil, err } +func (f *fargateVariant) NodeID(providerID url.URL) awsnode.NodeID { + tokens := strings.Split(strings.Trim(providerID.Path, "/"), "/") + // last token in the providerID is the aws resource ID for Fargate nodes + if len(tokens) == 0 { + return "" + } + nodeName := awsnode.NodeID(tokens[len(tokens)-1]) + if f.IsSupportedNode(nodeName) { + return nodeName + } + return "" +} + func newEc2Filter(name string, values ...string) *ec2.Filter { filter := &ec2.Filter{ Name: awssdk.String(name), @@ -116,8 +131,8 @@ func nodeNameToIPAddress(nodeName string) string { } // DescribeNetworkInterfaces returns network interface information for the given DNS name. -func (f *fargateVariant) DescribeNetworkInterfaces(ec2API iface.EC2, instanceID, vpcID string) (*ec2.NetworkInterface, error) { - eniEndpoint := strings.TrimPrefix(instanceID, fargateNodeNamePrefix) +func (f *fargateVariant) DescribeNetworkInterfaces(ec2API iface.EC2, nodeID awsnode.NodeID, vpcID string) (*ec2.NetworkInterface, error) { + eniEndpoint := strings.TrimPrefix(string(nodeID), fargateNodeNamePrefix) filters := []*ec2.Filter{ newEc2Filter("attachment.status", "attached"), diff --git a/pkg/providers/v1/variant/variant.go b/pkg/providers/v1/variant/variant.go index 39df86b795..e74c62a7cb 100644 --- a/pkg/providers/v1/variant/variant.go +++ b/pkg/providers/v1/variant/variant.go @@ -2,6 +2,8 @@ package variant import ( "fmt" + "k8s.io/cloud-provider-aws/pkg/providers/v1/awsnode" + "net/url" "sync" v1 "k8s.io/api/core/v1" @@ -20,12 +22,13 @@ var variants = make(map[string]Variant) type Variant interface { Initialize(cloudConfig *config.CloudConfig, credentials *credentials.Credentials, provider config.SDKProvider, ec2API iface.EC2, region string) error - IsSupportedNode(nodeName string) bool - NodeAddresses(instanceID, vpcID string) ([]v1.NodeAddress, error) - GetZone(instanceID, vpcID, region string) (cloudprovider.Zone, error) - InstanceExists(instanceID, vpcID string) (bool, error) - InstanceShutdown(instanceID, vpcID string) (bool, error) - InstanceTypeByProviderID(id string) (string, error) + IsSupportedNode(nodeID awsnode.NodeID) bool + NodeAddresses(nodeID awsnode.NodeID, vpcID string) ([]v1.NodeAddress, error) + GetZone(nodeID awsnode.NodeID, vpcID, region string) (cloudprovider.Zone, error) + InstanceExists(nodeID awsnode.NodeID, vpcID string) (bool, error) + InstanceShutdown(nodeID awsnode.NodeID, vpcID string) (bool, error) + InstanceTypeByProviderID(nodeID awsnode.NodeID) (string, error) + NodeID(providerID url.URL) awsnode.NodeID } // RegisterVariant is used to register code that needs to be called for a specific variant @@ -39,11 +42,11 @@ func RegisterVariant(name string, variant Variant) { } // IsVariantNode helps evaluate if a specific variant handles a given instance -func IsVariantNode(instanceID string) bool { +func IsVariantNode(nodeID awsnode.NodeID) bool { variantsLock.Lock() defer variantsLock.Unlock() for _, v := range variants { - if v.IsSupportedNode(instanceID) { + if v.IsSupportedNode(nodeID) { return true } } @@ -51,11 +54,11 @@ func IsVariantNode(instanceID string) bool { } // NodeType returns the type name example: "fargate" -func NodeType(instanceID string) string { +func NodeType(nodeID awsnode.NodeID) string { variantsLock.Lock() defer variantsLock.Unlock() for key, v := range variants { - if v.IsSupportedNode(instanceID) { + if v.IsSupportedNode(nodeID) { return key } } @@ -63,17 +66,30 @@ func NodeType(instanceID string) string { } // GetVariant returns the interface that can then be used to handle a specific instance -func GetVariant(instanceID string) Variant { +func GetVariant(nodeID awsnode.NodeID) Variant { variantsLock.Lock() defer variantsLock.Unlock() for _, v := range variants { - if v.IsSupportedNode(instanceID) { + if v.IsSupportedNode(nodeID) { return v } } return nil } +// GetNodeID returns the node id of the variant if a variant supports this particular provider id +// A return value of an empty string denotes no variant supported the node with this providerId. +func GetNodeID(providerID url.URL) awsnode.NodeID { + variantsLock.Lock() + defer variantsLock.Unlock() + for _, v := range variants { + if varID := v.NodeID(providerID); varID != "" { + return varID + } + } + return "" +} + // GetVariants returns the names of all the variants registered func GetVariants() []Variant { variantsLock.Lock()