diff --git a/cmd/aws-application-networking-k8s/main.go b/cmd/aws-application-networking-k8s/main.go index 744da27a..3a8acb73 100644 --- a/cmd/aws-application-networking-k8s/main.go +++ b/cmd/aws-application-networking-k8s/main.go @@ -18,12 +18,13 @@ package main import ( "flag" + "os" + "strings" + "github.com/aws/aws-application-networking-k8s/pkg/webhook" "github.com/go-logr/zapr" "go.uber.org/zap/zapcore" - "os" k8swebhook "sigs.k8s.io/controller-runtime/pkg/webhook" - "strings" "github.com/aws/aws-application-networking-k8s/pkg/aws" "github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog" @@ -119,6 +120,7 @@ func main() { "DefaultServiceNetwork", config.DefaultServiceNetwork, "ClusterName", config.ClusterName, "LogLevel", logLevel, + "EnablePrivateVPC", config.EnablePrivateVPC, ) cloud, err := aws.NewCloud(log.Named("cloud"), aws.CloudConfig{ @@ -126,6 +128,7 @@ func main() { AccountId: config.AccountID, Region: config.Region, ClusterName: config.ClusterName, + PrivateVPC: config.EnablePrivateVPC, }) if err != nil { setupLog.Fatal("cloud client setup failed: %s", err) diff --git a/pkg/aws/cloud.go b/pkg/aws/cloud.go index aae7de93..afeb725d 100644 --- a/pkg/aws/cloud.go +++ b/pkg/aws/cloud.go @@ -1,14 +1,15 @@ package aws import ( + "context" "fmt" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/vpclattice" "golang.org/x/exp/maps" - "context" "github.com/aws/aws-application-networking-k8s/pkg/aws/services" "github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog" ) @@ -25,6 +26,7 @@ type CloudConfig struct { AccountId string Region string ClusterName string + PrivateVPC bool } type Cloud interface { @@ -38,6 +40,12 @@ type Cloud interface { // creates lattice tags with default values populated and merges them with provided tags DefaultTagsMergedWith(services.Tags) services.Tags + // find tags on lattice resources + FindTagsForARNs(ctx context.Context, arns []string) (map[string]services.Tags, error) + + // find lattice target group ARNs using tags + FindTargetGroupARNs(context.Context, services.Tags) ([]string, error) + // check if managedBy tag set for lattice resource IsArnManaged(ctx context.Context, arn string) (bool, error) @@ -125,6 +133,55 @@ func (c *defaultCloud) DefaultTagsMergedWith(tags services.Tags) services.Tags { return newTags } +func (c *defaultCloud) FindTagsForARNs(ctx context.Context, arns []string) (map[string]services.Tags, error) { + if !c.cfg.PrivateVPC { + return c.tagging.GetTagsForArns(ctx, arns) + } + + tagsForARNs := map[string]services.Tags{} + + for _, arn := range arns { + tags, err := c.lattice.ListTagsForResourceWithContext(ctx, + &vpclattice.ListTagsForResourceInput{ResourceArn: aws.String(arn)}, + ) + if err != nil { + return nil, err + } + tagsForARNs[arn] = tags.Tags + } + return tagsForARNs, nil +} + +func (c *defaultCloud) FindTargetGroupARNs(ctx context.Context, tags services.Tags) ([]string, error) { + if !c.cfg.PrivateVPC { + return c.tagging.FindResourcesByTags(ctx, services.ResourceTypeTargetGroup, tags) + } + + tgs, err := c.lattice.ListTargetGroupsAsList(ctx, &vpclattice.ListTargetGroupsInput{ + VpcIdentifier: aws.String(c.cfg.VpcId), + }) + if err != nil { + return nil, err + } + + arns := make([]string, 0, len(tgs)) + + for _, tg := range tgs { + resp, err := c.lattice.ListTagsForResourceWithContext(ctx, + &vpclattice.ListTagsForResourceInput{ResourceArn: tg.Arn}, + ) + if err != nil { + return nil, err + } + + if containsTags(tags, resp.Tags) { + arns = append(arns, aws.StringValue(tg.Arn)) + } + } + + return arns, nil +} + func (c *defaultCloud) getTags(ctx context.Context, arn string) (services.Tags, error) { tagsReq := &vpclattice.ListTagsForResourceInput{ResourceArn: &arn} resp, err := c.lattice.ListTagsForResourceWithContext(ctx, tagsReq) @@ -172,6 +229,15 @@ func (c *defaultCloud) TryOwnFromTags(ctx context.Context, arn string, tags serv return c.isOwner(managedBy), nil } +func containsTags(source, check services.Tags) bool { + for k, v := range source { + if aws.StringValue(check[k]) != aws.StringValue(v) { + return false + } + } + return true +} + func (c *defaultCloud) ownResource(ctx context.Context, arn string) error { _, err := c.Lattice().TagResourceWithContext(ctx, &vpclattice.TagResourceInput{ ResourceArn: &arn, diff --git a/pkg/aws/cloud_mocks.go b/pkg/aws/cloud_mocks.go index bd8fe67b..6249ea62 100644 --- a/pkg/aws/cloud_mocks.go +++ b/pkg/aws/cloud_mocks.go @@ -77,6 +77,36 @@ func (mr *MockCloudMockRecorder) DefaultTagsMergedWith(arg0 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DefaultTagsMergedWith", reflect.TypeOf((*MockCloud)(nil).DefaultTagsMergedWith), arg0) } +// FindTagsForARNs mocks base method. +func (m *MockCloud) FindTagsForARNs(arg0 context.Context, arg1 []string) (map[string]map[string]*string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FindTagsForARNs", arg0, arg1) + ret0, _ := ret[0].(map[string]map[string]*string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FindTagsForARNs indicates an expected call of FindTagsForARNs. +func (mr *MockCloudMockRecorder) FindTagsForARNs(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindTagsForARNs", reflect.TypeOf((*MockCloud)(nil).FindTagsForARNs), arg0, arg1) +} + +// FindTargetGroupARNs mocks base method. +func (m *MockCloud) FindTargetGroupARNs(arg0 context.Context, arg1 map[string]*string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FindTargetGroupARNs", arg0, arg1) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FindTargetGroupARNs indicates an expected call of FindTargetGroupARNs. +func (mr *MockCloudMockRecorder) FindTargetGroupARNs(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindTargetGroupARNs", reflect.TypeOf((*MockCloud)(nil).FindTargetGroupARNs), arg0, arg1) +} + // IsArnManaged mocks base method. func (m *MockCloud) IsArnManaged(arg0 context.Context, arg1 string) (bool, error) { m.ctrl.T.Helper() diff --git a/pkg/aws/cloud_test.go b/pkg/aws/cloud_test.go index 73ed4b01..4619e753 100644 --- a/pkg/aws/cloud_test.go +++ b/pkg/aws/cloud_test.go @@ -30,7 +30,7 @@ func TestGetManagedByTag(t *testing.T) { } func TestDefaultTags(t *testing.T) { - cfg := CloudConfig{"acc", "vpc", "region", "cluster"} + cfg := CloudConfig{"acc", "vpc", "region", "cluster", false} c := NewDefaultCloud(nil, cfg) tags := c.DefaultTags() tagWant := getManagedByTag(cfg) diff --git a/pkg/config/controller_config.go b/pkg/config/controller_config.go index 43af0679..f5ccdac0 100644 --- a/pkg/config/controller_config.go +++ b/pkg/config/controller_config.go @@ -23,6 +23,7 @@ const ( CLUSTER_VPC_ID = "CLUSTER_VPC_ID" CLUSTER_NAME = "CLUSTER_NAME" DEFAULT_SERVICE_NETWORK = "DEFAULT_SERVICE_NETWORK" + ENABLE_PRIVATE_VPC = "ENABLE_PRIVATE_VPC" ENABLE_SERVICE_NETWORK_OVERRIDE = "ENABLE_SERVICE_NETWORK_OVERRIDE" AWS_ACCOUNT_ID = "AWS_ACCOUNT_ID" DEV_MODE = "DEV_MODE" @@ -37,6 +38,7 @@ var ClusterName = "" var DevMode = "" var WebhookEnabled = "" +var EnablePrivateVPC = false var ServiceNetworkOverrideMode = false func ConfigInit() error { @@ -82,6 +84,12 @@ func configInit(sess *session.Session, metadata EC2Metadata) error { ServiceNetworkOverrideMode = true } + privateVPC := os.Getenv(ENABLE_PRIVATE_VPC) + + if strings.ToLower(privateVPC) == "true" { + EnablePrivateVPC = true + } + ClusterName, err = getClusterName(sess) if err != nil { return fmt.Errorf("cannot get cluster name: %s", err) diff --git a/pkg/deploy/lattice/target_group_manager.go b/pkg/deploy/lattice/target_group_manager.go index 0838f594..46319894 100644 --- a/pkg/deploy/lattice/target_group_manager.go +++ b/pkg/deploy/lattice/target_group_manager.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "reflect" "github.com/aws/aws-application-networking-k8s/pkg/aws/services" "github.com/aws/aws-application-networking-k8s/pkg/utils" @@ -13,7 +14,6 @@ import ( pkg_aws "github.com/aws/aws-application-networking-k8s/pkg/aws" model "github.com/aws/aws-application-networking-k8s/pkg/model/lattice" "github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog" - "reflect" ) //go:generate mockgen -destination target_group_manager_mock.go -package lattice github.com/aws/aws-application-networking-k8s/pkg/deploy/lattice TargetGroupManager @@ -258,7 +258,7 @@ func (s *defaultTargetGroupManager) List(ctx context.Context) ([]tgListOutput, e tgArns := utils.SliceMap(resp, func(tg *vpclattice.TargetGroupSummary) string { return aws.StringValue(tg.Arn) }) - tgArnToTagsMap, err := s.cloud.Tagging().GetTagsForArns(ctx, tgArns) + tgArnToTagsMap, err := s.cloud.FindTagsForARNs(ctx, tgArns) if err != nil { return nil, err @@ -276,8 +276,7 @@ func (s *defaultTargetGroupManager) findTargetGroup( ctx context.Context, modelTargetGroup *model.TargetGroup, ) (*vpclattice.GetTargetGroupOutput, error) { - arns, err := s.cloud.Tagging().FindResourcesByTags(ctx, services.ResourceTypeTargetGroup, - model.TagsFromTGTagFields(modelTargetGroup.Spec.TargetGroupTagFields)) + arns, err := s.cloud.FindTargetGroupARNs(ctx, model.TagsFromTGTagFields(modelTargetGroup.Spec.TargetGroupTagFields)) if err != nil { return nil, err }