Skip to content

Commit

Permalink
Add support for clusters running in private VPC
Browse files Browse the repository at this point in the history
  • Loading branch information
aaroniscode committed Apr 29, 2024
1 parent 70c9054 commit 5e8d935
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 8 deletions.
7 changes: 5 additions & 2 deletions cmd/aws-application-networking-k8s/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ package main

import (
"flag"
"os"
"strings"

"github.com/aws/aws-application-networking-k8s/pkg/webhook"
"github.com/go-logr/zapr"
"go.uber.org/zap/zapcore"
"os"
k8swebhook "sigs.k8s.io/controller-runtime/pkg/webhook"
"strings"

"github.com/aws/aws-application-networking-k8s/pkg/aws"
"github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog"
Expand Down Expand Up @@ -119,13 +120,15 @@ func main() {
"DefaultServiceNetwork", config.DefaultServiceNetwork,
"ClusterName", config.ClusterName,
"LogLevel", logLevel,
"EnablePrivateVPC", config.EnablePrivateVPC,
)

cloud, err := aws.NewCloud(log.Named("cloud"), aws.CloudConfig{
VpcId: config.VpcID,
AccountId: config.AccountID,
Region: config.Region,
ClusterName: config.ClusterName,
PrivateVPC: config.EnablePrivateVPC,
})
if err != nil {
setupLog.Fatal("cloud client setup failed: %s", err)
Expand Down
68 changes: 67 additions & 1 deletion pkg/aws/cloud.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package aws

import (
"context"
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/vpclattice"
"golang.org/x/exp/maps"

"context"
"github.com/aws/aws-application-networking-k8s/pkg/aws/services"
"github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog"
)
Expand All @@ -25,6 +26,7 @@ type CloudConfig struct {
AccountId string
Region string
ClusterName string
PrivateVPC bool
}

type Cloud interface {
Expand All @@ -38,6 +40,12 @@ type Cloud interface {
// creates lattice tags with default values populated and merges them with provided tags
DefaultTagsMergedWith(services.Tags) services.Tags

// find tags on lattice resources
FindTagsForARNs(ctx context.Context, arns []string) (map[string]services.Tags, error)

// find lattice target group ARNs using tags
FindTargetGroupARNs(context.Context, services.Tags) ([]string, error)

// check if managedBy tag set for lattice resource
IsArnManaged(ctx context.Context, arn string) (bool, error)

Expand Down Expand Up @@ -125,6 +133,55 @@ func (c *defaultCloud) DefaultTagsMergedWith(tags services.Tags) services.Tags {
return newTags
}

func (c *defaultCloud) FindTagsForARNs(ctx context.Context, arns []string) (map[string]services.Tags, error) {
if !c.cfg.PrivateVPC {
return c.tagging.GetTagsForArns(ctx, arns)
}

tagsForARNs := map[string]services.Tags{}

for _, arn := range arns {
tags, err := c.lattice.ListTagsForResourceWithContext(ctx,
&vpclattice.ListTagsForResourceInput{ResourceArn: aws.String(arn)},
)
if err != nil {
return nil, err
}
tagsForARNs[arn] = tags.Tags
}
return tagsForARNs, nil
}

func (c *defaultCloud) FindTargetGroupARNs(ctx context.Context, tags services.Tags) ([]string, error) {
if !c.cfg.PrivateVPC {
return c.tagging.FindResourcesByTags(ctx, services.ResourceTypeTargetGroup, tags)
}

tgs, err := c.lattice.ListTargetGroupsAsList(ctx, &vpclattice.ListTargetGroupsInput{
VpcIdentifier: aws.String(c.cfg.VpcId),
})
if err != nil {
return nil, err
}

arns := make([]string, 0, len(tgs))

for _, tg := range tgs {
resp, err := c.lattice.ListTagsForResourceWithContext(ctx,
&vpclattice.ListTagsForResourceInput{ResourceArn: tg.Arn},
)
if err != nil {
return nil, err
}

if containsTags(tags, resp.Tags) {
arns = append(arns, aws.StringValue(tg.Arn))
}
}

return arns, nil
}

func (c *defaultCloud) getTags(ctx context.Context, arn string) (services.Tags, error) {
tagsReq := &vpclattice.ListTagsForResourceInput{ResourceArn: &arn}
resp, err := c.lattice.ListTagsForResourceWithContext(ctx, tagsReq)
Expand Down Expand Up @@ -172,6 +229,15 @@ func (c *defaultCloud) TryOwnFromTags(ctx context.Context, arn string, tags serv
return c.isOwner(managedBy), nil
}

func containsTags(source, check services.Tags) bool {
for k, v := range source {
if aws.StringValue(check[k]) != aws.StringValue(v) {
return false
}
}
return true
}

func (c *defaultCloud) ownResource(ctx context.Context, arn string) error {
_, err := c.Lattice().TagResourceWithContext(ctx, &vpclattice.TagResourceInput{
ResourceArn: &arn,
Expand Down
30 changes: 30 additions & 0 deletions pkg/aws/cloud_mocks.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/aws/cloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestGetManagedByTag(t *testing.T) {
}

func TestDefaultTags(t *testing.T) {
cfg := CloudConfig{"acc", "vpc", "region", "cluster"}
cfg := CloudConfig{"acc", "vpc", "region", "cluster", false}
c := NewDefaultCloud(nil, cfg)
tags := c.DefaultTags()
tagWant := getManagedByTag(cfg)
Expand Down
8 changes: 8 additions & 0 deletions pkg/config/controller_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
CLUSTER_VPC_ID = "CLUSTER_VPC_ID"
CLUSTER_NAME = "CLUSTER_NAME"
DEFAULT_SERVICE_NETWORK = "DEFAULT_SERVICE_NETWORK"
ENABLE_PRIVATE_VPC = "ENABLE_PRIVATE_VPC"
ENABLE_SERVICE_NETWORK_OVERRIDE = "ENABLE_SERVICE_NETWORK_OVERRIDE"
AWS_ACCOUNT_ID = "AWS_ACCOUNT_ID"
DEV_MODE = "DEV_MODE"
Expand All @@ -37,6 +38,7 @@ var ClusterName = ""
var DevMode = ""
var WebhookEnabled = ""

var EnablePrivateVPC = false
var ServiceNetworkOverrideMode = false

func ConfigInit() error {
Expand Down Expand Up @@ -82,6 +84,12 @@ func configInit(sess *session.Session, metadata EC2Metadata) error {
ServiceNetworkOverrideMode = true
}

privateVPC := os.Getenv(ENABLE_PRIVATE_VPC)

if strings.ToLower(privateVPC) == "true" {
EnablePrivateVPC = true
}

ClusterName, err = getClusterName(sess)
if err != nil {
return fmt.Errorf("cannot get cluster name: %s", err)
Expand Down
7 changes: 3 additions & 4 deletions pkg/deploy/lattice/target_group_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"reflect"

"github.com/aws/aws-application-networking-k8s/pkg/aws/services"
"github.com/aws/aws-application-networking-k8s/pkg/utils"
Expand All @@ -13,7 +14,6 @@ import (
pkg_aws "github.com/aws/aws-application-networking-k8s/pkg/aws"
model "github.com/aws/aws-application-networking-k8s/pkg/model/lattice"
"github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog"
"reflect"
)

//go:generate mockgen -destination target_group_manager_mock.go -package lattice github.com/aws/aws-application-networking-k8s/pkg/deploy/lattice TargetGroupManager
Expand Down Expand Up @@ -258,7 +258,7 @@ func (s *defaultTargetGroupManager) List(ctx context.Context) ([]tgListOutput, e
tgArns := utils.SliceMap(resp, func(tg *vpclattice.TargetGroupSummary) string {
return aws.StringValue(tg.Arn)
})
tgArnToTagsMap, err := s.cloud.Tagging().GetTagsForArns(ctx, tgArns)
tgArnToTagsMap, err := s.cloud.FindTagsForARNs(ctx, tgArns)

if err != nil {
return nil, err
Expand All @@ -276,8 +276,7 @@ func (s *defaultTargetGroupManager) findTargetGroup(
ctx context.Context,
modelTargetGroup *model.TargetGroup,
) (*vpclattice.GetTargetGroupOutput, error) {
arns, err := s.cloud.Tagging().FindResourcesByTags(ctx, services.ResourceTypeTargetGroup,
model.TagsFromTGTagFields(modelTargetGroup.Spec.TargetGroupTagFields))
arns, err := s.cloud.FindTargetGroupARNs(ctx, model.TagsFromTGTagFields(modelTargetGroup.Spec.TargetGroupTagFields))
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 5e8d935

Please sign in to comment.