diff --git a/controllers/gateway_controller.go b/controllers/gateway_controller.go index fc3610ca..5a86d03d 100644 --- a/controllers/gateway_controller.go +++ b/controllers/gateway_controller.go @@ -74,7 +74,7 @@ func RegisterGatewayController( scheme := mgr.GetScheme() evtRec := mgr.GetEventRecorderFor("gateway") - modelBuilder := gateway.NewServiceNetworkModelBuilder() + modelBuilder := gateway.NewServiceNetworkModelBuilder(mgrClient) stackDeployer := deploy.NewServiceNetworkStackDeployer(cloud, mgrClient) stackMarshaller := deploy.NewDefaultStackMarshaller() diff --git a/pkg/deploy/lattice/service_network_synthesizer_test.go b/pkg/deploy/lattice/service_network_synthesizer_test.go index b95ecef8..833a6393 100644 --- a/pkg/deploy/lattice/service_network_synthesizer_test.go +++ b/pkg/deploy/lattice/service_network_synthesizer_test.go @@ -8,7 +8,6 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" gateway_api "sigs.k8s.io/gateway-api/apis/v1beta1" @@ -87,6 +86,11 @@ func Test_SynthesizeTriggeredGateways(t *testing.T) { }, } + c := gomock.NewController(t) + defer c.Finish() + k8sClient := mock_client.NewMockClient(c) + k8sClient.EXPECT().List(context.Background(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + for _, tt := range tests { fmt.Printf("Testing >>>>> %v\n", tt.name) @@ -94,7 +98,7 @@ func Test_SynthesizeTriggeredGateways(t *testing.T) { defer c.Finish() ctx := context.TODO() - builder := gateway.NewServiceNetworkModelBuilder() + builder := gateway.NewServiceNetworkModelBuilder(k8sClient) stack, mesh, _ := builder.Build(context.Background(), tt.gw) diff --git a/pkg/gateway/model_build_service_network_test.go b/pkg/gateway/model_build_service_network_test.go index e066d767..92bd23cd 100644 --- a/pkg/gateway/model_build_service_network_test.go +++ b/pkg/gateway/model_build_service_network_test.go @@ -5,25 +5,46 @@ import ( "fmt" "testing" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + "sigs.k8s.io/gateway-api/apis/v1alpha2" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" gateway_api "sigs.k8s.io/gateway-api/apis/v1beta1" + + mock_client "github.com/aws/aws-application-networking-k8s/mocks/controller-runtime/client" + "github.com/aws/aws-application-networking-k8s/pkg/apis/applicationnetworking/v1alpha1" ) func Test_MeshModelBuild(t *testing.T) { now := metav1.Now() + trueBool := true + falseBool := false + notRelatedVpcAssociationPolicy := v1alpha1.VpcAssociationPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "another-vpc-association-policy", + }, + Spec: v1alpha1.VpcAssociationPolicySpec{ + TargetRef: &v1alpha2.PolicyTargetReference{ + Group: gateway_api.GroupName, + Kind: "Gateway", + Name: "another-mesh", + }, + AssociateWithVpc: &falseBool, + }, + } tests := []struct { - name string - gw *gateway_api.Gateway - wantErr error - wantName string - wantNamespace string - wantIsDeleted bool - associateToVPC bool + name string + gw *gateway_api.Gateway + vpcAssociationPolicy *v1alpha1.VpcAssociationPolicy + wantErr error + wantName string + wantNamespace string + wantIsDeleted bool + associateToVPC bool }{ { - name: "Adding Mesh in default namespace, no annotation on VPC association", + name: "Adding Mesh in default namespace, no annotation on VPC association, associate to VPC by default", gw: &gateway_api.Gateway{ ObjectMeta: metav1.ObjectMeta{ Name: "mesh1", @@ -92,13 +113,106 @@ func Test_MeshModelBuild(t *testing.T) { wantIsDeleted: true, associateToVPC: true, }, + { + name: "Gateway has attached VpcAssociationPolicy found, VpcAssociationPolicy SecurityGroupIds are not empty", + gw: &gateway_api.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "mesh1", + Finalizers: []string{"gateway.k8s.aws/resources"}, + }, + }, + vpcAssociationPolicy: &v1alpha1.VpcAssociationPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-vpc-association-policy", + }, + Spec: v1alpha1.VpcAssociationPolicySpec{ + TargetRef: &v1alpha2.PolicyTargetReference{ + Group: gateway_api.GroupName, + Kind: "Gateway", + Name: "mesh1", + }, + SecurityGroupIds: []v1alpha1.SecurityGroupId{"sg-123456", "sg-654321"}, + }, + }, + wantErr: nil, + wantName: "mesh1", + wantNamespace: "", + wantIsDeleted: false, + associateToVPC: true, + }, + { + name: "Gateway does not have LatticeVPCAssociationAnnotation, it has attached VpcAssociationPolicy found, which AssociateWithVpc field set to true", + gw: &gateway_api.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "mesh1", + Finalizers: []string{"gateway.k8s.aws/resources"}, + }, + }, + vpcAssociationPolicy: &v1alpha1.VpcAssociationPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-vpc-association-policy", + }, + Spec: v1alpha1.VpcAssociationPolicySpec{ + TargetRef: &v1alpha2.PolicyTargetReference{ + Group: gateway_api.GroupName, + Kind: "Gateway", + Name: "mesh1", + }, + AssociateWithVpc: &trueBool, + }, + }, + wantErr: nil, + wantName: "mesh1", + wantNamespace: "", + wantIsDeleted: false, + associateToVPC: true, + }, + { + name: "Gateway does not have LatticeVPCAssociationAnnotation, it has attached VpcAssociationPolicy found, which AssociateWithVpc field set to false", + gw: &gateway_api.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "mesh1", + Finalizers: []string{"gateway.k8s.aws/resources"}, + }, + }, + vpcAssociationPolicy: &v1alpha1.VpcAssociationPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-vpc-association-policy", + }, + Spec: v1alpha1.VpcAssociationPolicySpec{ + TargetRef: &v1alpha2.PolicyTargetReference{ + Group: gateway_api.GroupName, + Kind: "Gateway", + Name: "mesh1", + }, + AssociateWithVpc: &falseBool, + }, + }, + wantErr: nil, + wantName: "mesh1", + wantNamespace: "", + wantIsDeleted: false, + associateToVPC: false, + }, } + c := gomock.NewController(t) + defer c.Finish() + mock_client := mock_client.NewMockClient(c) + ctx := context.Background() for _, tt := range tests { fmt.Printf("Testing >>> %v\n", tt.name) t.Run(tt.name, func(t *testing.T) { - builder := NewServiceNetworkModelBuilder() - + mock_client.EXPECT().List(ctx, gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, policyList *v1alpha1.VpcAssociationPolicyList, arg3 ...interface{}) error { + policyList.Items = append(policyList.Items, notRelatedVpcAssociationPolicy) + if tt.vpcAssociationPolicy != nil { + policyList.Items = append(policyList.Items, *tt.vpcAssociationPolicy) + } + return nil + }, + ) + builder := NewServiceNetworkModelBuilder(mock_client) _, got, err := builder.Build(context.Background(), tt.gw) if tt.wantErr != nil { @@ -108,6 +222,9 @@ func Test_MeshModelBuild(t *testing.T) { assert.Equal(t, tt.wantNamespace, got.Spec.Namespace) assert.Equal(t, tt.wantIsDeleted, got.Spec.IsDeleted) assert.Equal(t, tt.associateToVPC, got.Spec.AssociateToVPC) + if tt.vpcAssociationPolicy != nil { + assert.Equal(t, securityGroupIdsToStringPointersSlice(tt.vpcAssociationPolicy.Spec.SecurityGroupIds), got.Spec.SecurityGroupIds) + } } }) diff --git a/pkg/gateway/model_build_servicenetwork.go b/pkg/gateway/model_build_servicenetwork.go index 5f613700..e2c60c48 100644 --- a/pkg/gateway/model_build_servicenetwork.go +++ b/pkg/gateway/model_build_servicenetwork.go @@ -4,8 +4,10 @@ import ( "context" corev1 "k8s.io/api/core/v1" + "sigs.k8s.io/controller-runtime/pkg/client" gateway_api "sigs.k8s.io/gateway-api/apis/v1beta1" + "github.com/aws/aws-application-networking-k8s/pkg/apis/applicationnetworking/v1alpha1" "github.com/aws/aws-application-networking-k8s/pkg/config" "github.com/aws/aws-application-networking-k8s/pkg/k8s" "github.com/aws/aws-application-networking-k8s/pkg/model/core" @@ -25,18 +27,23 @@ type ServiceNetworkModelBuilder interface { } type serviceNetworkModelBuilder struct { + client client.Client defaultTags map[string]string } -func NewServiceNetworkModelBuilder() *serviceNetworkModelBuilder { - return &serviceNetworkModelBuilder{} +func NewServiceNetworkModelBuilder(client client.Client) *serviceNetworkModelBuilder { + return &serviceNetworkModelBuilder{client: client} } func (b *serviceNetworkModelBuilder) Build(ctx context.Context, gw *gateway_api.Gateway) (core.Stack, *latticemodel.ServiceNetwork, error) { stack := core.NewDefaultStack(core.StackID(k8s.NamespacedName(gw))) - + vpcAssociationPolicy, err := GetAttachedPolicy(ctx, b.client, k8s.NamespacedName(gw), &v1alpha1.VpcAssociationPolicy{}) + if err != nil { + return nil, nil, err + } task := &serviceNetworkModelBuildTask{ - gateway: gw, - stack: stack, + gateway: gw, + vpcAssociationPolicy: vpcAssociationPolicy, + stack: stack, } if err := task.run(ctx); err != nil { @@ -63,25 +70,33 @@ func (t *serviceNetworkModelBuildTask) buildModel(ctx context.Context) error { func (t *serviceNetworkModelBuildTask) buildServiceNetwork(ctx context.Context) error { spec := latticemodel.ServiceNetworkSpec{ - Name: t.gateway.Name, - Namespace: t.gateway.Namespace, - Account: config.AccountID, - AssociateToVPC: false, + Name: t.gateway.Name, + Namespace: t.gateway.Namespace, + Account: config.AccountID, } - // by default it is true - spec.AssociateToVPC = true + // default associateToVPC is true + associateToVPC := true if len(t.gateway.ObjectMeta.Annotations) > 0 { if value, exist := t.gateway.Annotations[LatticeVPCAssociationAnnotation]; exist { if value == "true" { - spec.AssociateToVPC = true + associateToVPC = true } else { - spec.AssociateToVPC = false + associateToVPC = false } + } + } + if t.vpcAssociationPolicy != nil { + if t.vpcAssociationPolicy.Spec.AssociateWithVpc != nil { + associateToVPC = *t.vpcAssociationPolicy.Spec.AssociateWithVpc } + if t.vpcAssociationPolicy.Spec.SecurityGroupIds != nil { + spec.SecurityGroupIds = securityGroupIdsToStringPointersSlice(t.vpcAssociationPolicy.Spec.SecurityGroupIds) + } } + spec.AssociateToVPC = associateToVPC defaultSN, err := config.GetClusterLocalGateway() if err == nil && defaultSN != t.gateway.Name { @@ -101,9 +116,18 @@ func (t *serviceNetworkModelBuildTask) buildServiceNetwork(ctx context.Context) } type serviceNetworkModelBuildTask struct { - gateway *gateway_api.Gateway - - mesh *latticemodel.ServiceNetwork + gateway *gateway_api.Gateway + vpcAssociationPolicy *v1alpha1.VpcAssociationPolicy + mesh *latticemodel.ServiceNetwork stack core.Stack } + +func securityGroupIdsToStringPointersSlice(sgIds []v1alpha1.SecurityGroupId) []*string { + var ret []*string + for _, sgId := range sgIds { + sgIdStr := string(sgId) + ret = append(ret, &sgIdStr) + } + return ret +} diff --git a/pkg/model/lattice/servicenetwork.go b/pkg/model/lattice/servicenetwork.go index 6f653b55..35ef0bfe 100644 --- a/pkg/model/lattice/servicenetwork.go +++ b/pkg/model/lattice/servicenetwork.go @@ -21,11 +21,12 @@ type ServiceNetwork struct { type ServiceNetworkSpec struct { // The name of the ServiceNetwork - Name string `json:"name"` - Namespace string `json:"namespace"` - Account string `json:"account"` - AssociateToVPC bool - IsDeleted bool + Name string `json:"name"` + Namespace string `json:"namespace"` + Account string `json:"account"` + SecurityGroupIds []*string `json:"securityGroupIds"` + AssociateToVPC bool + IsDeleted bool } type ServiceNetworkStatus struct {