Skip to content

Commit

Permalink
remove GetRootDomain. Improve validation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maleck13 committed Mar 12, 2024
1 parent 810ac73 commit 5bfb02e
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 155 deletions.
58 changes: 26 additions & 32 deletions api/v1alpha1/dnsrecord_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,48 +106,42 @@ const (
DefaultGeo string = "default"
)

// GetRootDomain returns the shortest domain that is shared across all spec.Endpoints dns names.
// Validates that all endpoints share an equal root domain and returns an error if they don't.
func (s *DNSRecord) GetRootDomain() (string, error) {
if err := s.Validate(); err != nil {
return "", err
}
if s.Spec.RootHost != nil {
return *s.Spec.RootHost, nil
}
domain := ""
dnsNames := []string{}
for idx := range s.Spec.Endpoints {
dnsNames = append(dnsNames, s.Spec.Endpoints[idx].DNSName)
}
for idx := range dnsNames {
if domain == "" || len(domain) > len(dnsNames[idx]) {
domain = dnsNames[idx]
}
}

if domain == "" {
return "", fmt.Errorf("unable to determine root domain from %v", dnsNames)
}
const wildcardPrefix = "*."

for idx := range dnsNames {
if !strings.HasSuffix(dnsNames[idx], domain) {
return "", fmt.Errorf("inconsitent domains, got %s, expected suffix %s", dnsNames[idx], domain)
}
func (s *DNSRecord) isWildCardRoot() bool {
if s.Spec.RootHost != nil {
return strings.HasPrefix(*s.Spec.RootHost, wildcardPrefix)
}

return domain, nil
return false
}

func (s *DNSRecord) Validate() error {
if s.Spec.RootHost != nil {
if len(strings.Split(*s.Spec.RootHost, ".")) <= 1 {
root := *s.Spec.RootHost
if len(strings.Split(root, ".")) <= 1 {
return fmt.Errorf("invalid domain format no tld discovered")
}
if len(s.Spec.Endpoints) == 0 {
return fmt.Errorf("no endpoints defined for DNSRecord. Nothing to do.")
}

if s.isWildCardRoot() {
root = strings.Replace(root, wildcardPrefix, "", 1)
}

rootEndpointFound := false
for _, ep := range s.Spec.Endpoints {
if !strings.HasSuffix(ep.DNSName, *s.Spec.RootHost) {
return fmt.Errorf("invalid endpoint discovered %s all endpoints should be equal to or end with the rootHost %s", ep.DNSName, *s.Spec.RootHost)
if !strings.HasSuffix(ep.DNSName, root) {
return fmt.Errorf("invalid endpoint discovered %s all endpoints should be equal to or end with the rootHost %s", ep.DNSName, root)
}
if !rootEndpointFound {
if ep.DNSName == root {
rootEndpointFound = true
}
}
}
if !rootEndpointFound && !s.isWildCardRoot() {
return fmt.Errorf("invalid endpoint set. rootHost is set but found no endpoint defining a record for the rootHost %s", root)
}
}
return nil
Expand Down
129 changes: 32 additions & 97 deletions api/v1alpha1/dnsrecord_types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,131 +3,63 @@ package v1alpha1
import (
"testing"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/external-dns/endpoint"
)

func TestDNSRecord_GetRootDomain(t *testing.T) {
var (
rootTestExample = "test.example.com"
example = "example.com"
)
func TestValidate(t *testing.T) {
tests := []struct {
name string
rootHost *string
rootHost string
dnsNames []string
want string
wantErr bool
}{
{
name: "single endpoint",
rootHost: &rootTestExample,
dnsNames: []string{
"test.example.com",
},
want: "test.example.com",
wantErr: false,
},
{
name: "multiple endpoints matching",
dnsNames: []string{
"bar.baz.test.example.com",
"bar.test.example.com",
"test.example.com",
"foo.bar.baz.test.example.com",
},
want: "test.example.com",
wantErr: false,
name: "invalid domain",
rootHost: "example",
wantErr: true,
},
{
name: "no endpoints",
dnsNames: []string{},
want: "",
rootHost: "example.com",
wantErr: true,
},
{
rootHost: &example,
name: "multiple endpoints",
name: "invalid domain",
rootHost: "example.com",
dnsNames: []string{
"foo.bar.test.example.com",
"bar.test.example.com",
"baz.example.com",
"example.com",
"a.exmple.com",
},
want: "example.com",
wantErr: false,
wantErr: true,
},
{
rootHost: &example,
name: "multiple endpoints mismatching",
name: "valid domain",
rootHost: "example.com",
dnsNames: []string{
"foo.bar.test.other.com",
"bar.test.example.com",
"baz.example.com",
"example.com",
"a.b.example.com",
"b.a.example.com",
"a.example.com",
"b.example.com",
},
want: "",
wantErr: true,
wantErr: false,
},
{
name: "multiple endpoints no rootHost",
name: "valid wildcard domain",
rootHost: "*.example.com",
dnsNames: []string{
"foo.bar.test.other.com",
"bar.test.example.com",
"baz.example.com",
"*.example.com",
"a.b.example.com",
"b.a.example.com",
"a.example.com",
"b.example.com",
},
want: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &DNSRecord{
TypeMeta: metav1.TypeMeta{
Kind: "DNSRecord",
APIVersion: GroupVersion.String(),
},
ObjectMeta: metav1.ObjectMeta{
Name: "testRecord",
Namespace: "testNS",
},
Spec: DNSRecordSpec{
Endpoints: []*endpoint.Endpoint{},
},
}
if tt.rootHost != nil {
s.Spec.RootHost = tt.rootHost
}
for idx := range tt.dnsNames {
s.Spec.Endpoints = append(s.Spec.Endpoints, &endpoint.Endpoint{DNSName: tt.dnsNames[idx]})
}
got, err := s.GetRootDomain()
if (err != nil) != tt.wantErr {
t.Errorf("GetRootDomain() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("GetRootDomain() got = %v, want %v", got, tt.want)
}
})
}
}

func TestValidate(t *testing.T) {
tests := []struct {
name string
rootHost string
dnsNames []string
wantErr bool
}{
{
name: "invalid domain",
rootHost: "example",
wantErr: true,
wantErr: false,
},
{
name: "valid domain",
rootHost: "example.com",
name: "valid wildcard domain no endpoint",
rootHost: "*.example.com",
dnsNames: []string{
"example.com",
"a.b.example.com",
"b.a.example.com",
"a.example.com",
Expand All @@ -144,6 +76,9 @@ func TestValidate(t *testing.T) {
RootHost: &tt.rootHost,
},
}
for idx := range tt.dnsNames {
record.Spec.Endpoints = append(record.Spec.Endpoints, &endpoint.Endpoint{DNSName: tt.dnsNames[idx]})
}
err := record.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
Expand Down
51 changes: 25 additions & 26 deletions internal/controller/dnsrecord_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/reconcile"
externaldnsendpoint "sigs.k8s.io/external-dns/endpoint"
externaldnsplan "sigs.k8s.io/external-dns/plan"
externaldnsprovider "sigs.k8s.io/external-dns/provider"
Expand Down Expand Up @@ -99,31 +100,39 @@ func (r *DNSRecordReconciler) Reconcile(ctx context.Context, req ctrl.Request) (
status := metav1.ConditionTrue
reason = "ProviderSuccess"
message = "Provider ensured the dns record"

err = dnsRecord.Validate()
if err != nil {
status = metav1.ConditionFalse
reason = "ValidationError"
message = fmt.Sprintf("validation of DNSRecord failed: %v", err)
setDNSRecordCondition(dnsRecord, string(conditions.ConditionTypeReady), status, reason, message)
return r.updateStatus(ctx, previous, dnsRecord)
}
// Publish the record
err = r.publishRecord(ctx, dnsRecord)
if err != nil {
status = metav1.ConditionFalse
reason = "ProviderError"
message = fmt.Sprintf("The DNS provider failed to ensure the record: %v", provider.SanitizeError(err))
} else {
dnsRecord.Status.ObservedGeneration = dnsRecord.Generation
dnsRecord.Status.Endpoints = dnsRecord.Spec.Endpoints
setDNSRecordCondition(dnsRecord, string(conditions.ConditionTypeReady), status, reason, message)
return r.updateStatus(ctx, previous, dnsRecord)
}
// success
setDNSRecordCondition(dnsRecord, string(conditions.ConditionTypeReady), status, reason, message)
dnsRecord.Status.ObservedGeneration = dnsRecord.Generation
dnsRecord.Status.Endpoints = dnsRecord.Spec.Endpoints
return r.updateStatus(ctx, previous, dnsRecord)
}

if !equality.Semantic.DeepEqual(previous.Status, dnsRecord.Status) {
updateErr := r.Status().Update(ctx, dnsRecord)
if updateErr != nil {
// Ignore conflicts, resource might just be outdated.
if apierrors.IsConflict(updateErr) {
return ctrl.Result{Requeue: true}, nil
}
return ctrl.Result{}, updateErr
func (r *DNSRecordReconciler) updateStatus(ctx context.Context, previous, current *v1alpha1.DNSRecord) (reconcile.Result, error) {
if !equality.Semantic.DeepEqual(previous.Status, current.Status) {
updateError := r.Status().Update(ctx, current)
if apierrors.IsConflict(updateError) {
return ctrl.Result{Requeue: true}, nil
}
return ctrl.Result{}, updateError
}

return ctrl.Result{}, err
return ctrl.Result{}, nil
}

// SetupWithManager sets up the controller with the Manager.
Expand Down Expand Up @@ -175,9 +184,6 @@ func (r *DNSRecordReconciler) deleteRecord(ctx context.Context, dnsRecord *v1alp
// DNSRecord (dnsRecord.Status.ParentManagedZone).
func (r *DNSRecordReconciler) publishRecord(ctx context.Context, dnsRecord *v1alpha1.DNSRecord) error {
logger := log.FromContext(ctx)
if err := dnsRecord.Validate(); err != nil {
return fmt.Errorf("failed validation pre publish : %s", err)
}
managedZone := &v1alpha1.ManagedZone{
ObjectMeta: metav1.ObjectMeta{
Name: dnsRecord.Spec.ManagedZoneRef.Name,
Expand Down Expand Up @@ -223,21 +229,14 @@ func setDNSRecordCondition(dnsRecord *v1alpha1.DNSRecord, conditionType string,
func (r *DNSRecordReconciler) applyChanges(ctx context.Context, dnsRecord *v1alpha1.DNSRecord, managedZone *v1alpha1.ManagedZone, isDelete bool) error {
logger := log.FromContext(ctx)

rootDomain, err := dnsRecord.GetRootDomain()
if err != nil {
return err
}
if !strings.HasSuffix(rootDomain, managedZone.Spec.DomainName) {
return fmt.Errorf("inconsitent domains, does not match managedzone, got %s, expected suffix %s", rootDomain, managedZone.Spec.DomainName)
}
rootDomainFilter := externaldnsendpoint.NewDomainFilter([]string{rootDomain})
rootDomainFilter := externaldnsendpoint.NewDomainFilter([]string{managedZone.Spec.DomainName})

providerConfig := provider.Config{
DomainFilter: externaldnsendpoint.NewDomainFilter([]string{managedZone.Spec.DomainName}),
ZoneTypeFilter: externaldnsprovider.NewZoneTypeFilter(""),
ZoneIDFilter: externaldnsprovider.NewZoneIDFilter([]string{managedZone.Status.ID}),
}
logger.V(3).Info("applyChanges", "rootDomain", rootDomain, "rootDomainFilter", rootDomainFilter, "providerConfig", providerConfig)
logger.V(3).Info("applyChanges", "zone", managedZone.Spec.DomainName, "rootDomainFilter", rootDomainFilter, "providerConfig", providerConfig)
dnsProvider, err := r.ProviderFactory.ProviderFor(ctx, managedZone, providerConfig)
if err != nil {
return err
Expand Down

0 comments on commit 5bfb02e

Please sign in to comment.