diff --git a/api/v1alpha1/dnsrecord_types.go b/api/v1alpha1/dnsrecord_types.go index 8de27098..8d7fb766 100644 --- a/api/v1alpha1/dnsrecord_types.go +++ b/api/v1alpha1/dnsrecord_types.go @@ -106,48 +106,42 @@ const ( DefaultGeo string = "default" ) -// GetRootDomain returns the shortest domain that is shared across all spec.Endpoints dns names. -// Validates that all endpoints share an equal root domain and returns an error if they don't. -func (s *DNSRecord) GetRootDomain() (string, error) { - if err := s.Validate(); err != nil { - return "", err - } - if s.Spec.RootHost != nil { - return *s.Spec.RootHost, nil - } - domain := "" - dnsNames := []string{} - for idx := range s.Spec.Endpoints { - dnsNames = append(dnsNames, s.Spec.Endpoints[idx].DNSName) - } - for idx := range dnsNames { - if domain == "" || len(domain) > len(dnsNames[idx]) { - domain = dnsNames[idx] - } - } - - if domain == "" { - return "", fmt.Errorf("unable to determine root domain from %v", dnsNames) - } +const wildcardPrefix = "*." - for idx := range dnsNames { - if !strings.HasSuffix(dnsNames[idx], domain) { - return "", fmt.Errorf("inconsitent domains, got %s, expected suffix %s", dnsNames[idx], domain) - } +func (s *DNSRecord) isWildCardRoot() bool { + if s.Spec.RootHost != nil { + return strings.HasPrefix(*s.Spec.RootHost, wildcardPrefix) } - - return domain, nil + return false } func (s *DNSRecord) Validate() error { if s.Spec.RootHost != nil { - if len(strings.Split(*s.Spec.RootHost, ".")) <= 1 { + root := *s.Spec.RootHost + if len(strings.Split(root, ".")) <= 1 { return fmt.Errorf("invalid domain format no tld discovered") } + if len(s.Spec.Endpoints) == 0 { + return fmt.Errorf("no endpoints defined for DNSRecord. Nothing to do.") + } + + if s.isWildCardRoot() { + root = strings.Replace(root, wildcardPrefix, "", 1) + } + + rootEndpointFound := false for _, ep := range s.Spec.Endpoints { - if !strings.HasSuffix(ep.DNSName, *s.Spec.RootHost) { - return fmt.Errorf("invalid endpoint discovered %s all endpoints should be equal to or end with the rootHost %s", ep.DNSName, *s.Spec.RootHost) + if !strings.HasSuffix(ep.DNSName, root) { + return fmt.Errorf("invalid endpoint discovered %s all endpoints should be equal to or end with the rootHost %s", ep.DNSName, root) } + if !rootEndpointFound { + if ep.DNSName == root { + rootEndpointFound = true + } + } + } + if !rootEndpointFound && !s.isWildCardRoot() { + return fmt.Errorf("invalid endpoint set. rootHost is set but found no endpoint defining a record for the rootHost %s", root) } } return nil diff --git a/api/v1alpha1/dnsrecord_types_test.go b/api/v1alpha1/dnsrecord_types_test.go index e9f242ce..ac64a9a7 100644 --- a/api/v1alpha1/dnsrecord_types_test.go +++ b/api/v1alpha1/dnsrecord_types_test.go @@ -3,131 +3,63 @@ package v1alpha1 import ( "testing" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/external-dns/endpoint" ) -func TestDNSRecord_GetRootDomain(t *testing.T) { - var ( - rootTestExample = "test.example.com" - example = "example.com" - ) +func TestValidate(t *testing.T) { tests := []struct { name string - rootHost *string + rootHost string dnsNames []string - want string wantErr bool }{ { - name: "single endpoint", - rootHost: &rootTestExample, - dnsNames: []string{ - "test.example.com", - }, - want: "test.example.com", - wantErr: false, - }, - { - name: "multiple endpoints matching", - dnsNames: []string{ - "bar.baz.test.example.com", - "bar.test.example.com", - "test.example.com", - "foo.bar.baz.test.example.com", - }, - want: "test.example.com", - wantErr: false, + name: "invalid domain", + rootHost: "example", + wantErr: true, }, { name: "no endpoints", - dnsNames: []string{}, - want: "", + rootHost: "example.com", wantErr: true, }, { - rootHost: &example, - name: "multiple endpoints", + name: "invalid domain", + rootHost: "example.com", dnsNames: []string{ - "foo.bar.test.example.com", - "bar.test.example.com", - "baz.example.com", + "example.com", + "a.exmple.com", }, - want: "example.com", - wantErr: false, + wantErr: true, }, { - rootHost: &example, - name: "multiple endpoints mismatching", + name: "valid domain", + rootHost: "example.com", dnsNames: []string{ - "foo.bar.test.other.com", - "bar.test.example.com", - "baz.example.com", + "example.com", + "a.b.example.com", + "b.a.example.com", + "a.example.com", + "b.example.com", }, - want: "", - wantErr: true, + wantErr: false, }, { - name: "multiple endpoints no rootHost", + name: "valid wildcard domain", + rootHost: "*.example.com", dnsNames: []string{ - "foo.bar.test.other.com", - "bar.test.example.com", - "baz.example.com", + "*.example.com", + "a.b.example.com", + "b.a.example.com", + "a.example.com", + "b.example.com", }, - want: "", - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := &DNSRecord{ - TypeMeta: metav1.TypeMeta{ - Kind: "DNSRecord", - APIVersion: GroupVersion.String(), - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "testRecord", - Namespace: "testNS", - }, - Spec: DNSRecordSpec{ - Endpoints: []*endpoint.Endpoint{}, - }, - } - if tt.rootHost != nil { - s.Spec.RootHost = tt.rootHost - } - for idx := range tt.dnsNames { - s.Spec.Endpoints = append(s.Spec.Endpoints, &endpoint.Endpoint{DNSName: tt.dnsNames[idx]}) - } - got, err := s.GetRootDomain() - if (err != nil) != tt.wantErr { - t.Errorf("GetRootDomain() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("GetRootDomain() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestValidate(t *testing.T) { - tests := []struct { - name string - rootHost string - dnsNames []string - wantErr bool - }{ - { - name: "invalid domain", - rootHost: "example", - wantErr: true, + wantErr: false, }, { - name: "valid domain", - rootHost: "example.com", + name: "valid wildcard domain no endpoint", + rootHost: "*.example.com", dnsNames: []string{ - "example.com", "a.b.example.com", "b.a.example.com", "a.example.com", @@ -144,6 +76,9 @@ func TestValidate(t *testing.T) { RootHost: &tt.rootHost, }, } + for idx := range tt.dnsNames { + record.Spec.Endpoints = append(record.Spec.Endpoints, &endpoint.Endpoint{DNSName: tt.dnsNames[idx]}) + } err := record.Validate() if (err != nil) != tt.wantErr { t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) diff --git a/internal/controller/dnsrecord_controller.go b/internal/controller/dnsrecord_controller.go index 42dd9aa0..41842d1b 100644 --- a/internal/controller/dnsrecord_controller.go +++ b/internal/controller/dnsrecord_controller.go @@ -31,6 +31,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/reconcile" externaldnsendpoint "sigs.k8s.io/external-dns/endpoint" externaldnsplan "sigs.k8s.io/external-dns/plan" externaldnsprovider "sigs.k8s.io/external-dns/provider" @@ -99,31 +100,39 @@ func (r *DNSRecordReconciler) Reconcile(ctx context.Context, req ctrl.Request) ( status := metav1.ConditionTrue reason = "ProviderSuccess" message = "Provider ensured the dns record" - + err = dnsRecord.Validate() + if err != nil { + status = metav1.ConditionFalse + reason = "ValidationError" + message = fmt.Sprintf("validation of DNSRecord failed: %v", err) + setDNSRecordCondition(dnsRecord, string(conditions.ConditionTypeReady), status, reason, message) + return r.updateStatus(ctx, previous, dnsRecord) + } // Publish the record err = r.publishRecord(ctx, dnsRecord) if err != nil { status = metav1.ConditionFalse reason = "ProviderError" message = fmt.Sprintf("The DNS provider failed to ensure the record: %v", provider.SanitizeError(err)) - } else { - dnsRecord.Status.ObservedGeneration = dnsRecord.Generation - dnsRecord.Status.Endpoints = dnsRecord.Spec.Endpoints + setDNSRecordCondition(dnsRecord, string(conditions.ConditionTypeReady), status, reason, message) + return r.updateStatus(ctx, previous, dnsRecord) } + // success setDNSRecordCondition(dnsRecord, string(conditions.ConditionTypeReady), status, reason, message) + dnsRecord.Status.ObservedGeneration = dnsRecord.Generation + dnsRecord.Status.Endpoints = dnsRecord.Spec.Endpoints + return r.updateStatus(ctx, previous, dnsRecord) +} - if !equality.Semantic.DeepEqual(previous.Status, dnsRecord.Status) { - updateErr := r.Status().Update(ctx, dnsRecord) - if updateErr != nil { - // Ignore conflicts, resource might just be outdated. - if apierrors.IsConflict(updateErr) { - return ctrl.Result{Requeue: true}, nil - } - return ctrl.Result{}, updateErr +func (r *DNSRecordReconciler) updateStatus(ctx context.Context, previous, current *v1alpha1.DNSRecord) (reconcile.Result, error) { + if !equality.Semantic.DeepEqual(previous.Status, current.Status) { + updateError := r.Status().Update(ctx, current) + if apierrors.IsConflict(updateError) { + return ctrl.Result{Requeue: true}, nil } + return ctrl.Result{}, updateError } - - return ctrl.Result{}, err + return ctrl.Result{}, nil } // SetupWithManager sets up the controller with the Manager. @@ -175,9 +184,6 @@ func (r *DNSRecordReconciler) deleteRecord(ctx context.Context, dnsRecord *v1alp // DNSRecord (dnsRecord.Status.ParentManagedZone). func (r *DNSRecordReconciler) publishRecord(ctx context.Context, dnsRecord *v1alpha1.DNSRecord) error { logger := log.FromContext(ctx) - if err := dnsRecord.Validate(); err != nil { - return fmt.Errorf("failed validation pre publish : %s", err) - } managedZone := &v1alpha1.ManagedZone{ ObjectMeta: metav1.ObjectMeta{ Name: dnsRecord.Spec.ManagedZoneRef.Name, @@ -223,21 +229,14 @@ func setDNSRecordCondition(dnsRecord *v1alpha1.DNSRecord, conditionType string, func (r *DNSRecordReconciler) applyChanges(ctx context.Context, dnsRecord *v1alpha1.DNSRecord, managedZone *v1alpha1.ManagedZone, isDelete bool) error { logger := log.FromContext(ctx) - rootDomain, err := dnsRecord.GetRootDomain() - if err != nil { - return err - } - if !strings.HasSuffix(rootDomain, managedZone.Spec.DomainName) { - return fmt.Errorf("inconsitent domains, does not match managedzone, got %s, expected suffix %s", rootDomain, managedZone.Spec.DomainName) - } - rootDomainFilter := externaldnsendpoint.NewDomainFilter([]string{rootDomain}) + rootDomainFilter := externaldnsendpoint.NewDomainFilter([]string{managedZone.Spec.DomainName}) providerConfig := provider.Config{ DomainFilter: externaldnsendpoint.NewDomainFilter([]string{managedZone.Spec.DomainName}), ZoneTypeFilter: externaldnsprovider.NewZoneTypeFilter(""), ZoneIDFilter: externaldnsprovider.NewZoneIDFilter([]string{managedZone.Status.ID}), } - logger.V(3).Info("applyChanges", "rootDomain", rootDomain, "rootDomainFilter", rootDomainFilter, "providerConfig", providerConfig) + logger.V(3).Info("applyChanges", "zone", managedZone.Spec.DomainName, "rootDomainFilter", rootDomainFilter, "providerConfig", providerConfig) dnsProvider, err := r.ProviderFactory.ProviderFor(ctx, managedZone, providerConfig) if err != nil { return err