diff --git a/client.go b/client.go index 76fc615..f7d77cd 100644 --- a/client.go +++ b/client.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" "log" - "strconv" + "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -49,12 +49,105 @@ func (p *Provider) init(ctx context.Context) { cfg, err := config.LoadDefaultConfig(ctx, opts...) if err != nil { - log.Fatal(err) + log.Fatalf("route53: unable to load AWS SDK config, %v", err) } p.client = r53.NewFromConfig(cfg) } +func chunkString(s string, chunkSize int) []string { + var chunks []string + for i := 0; i < len(s); i += chunkSize { + end := i + chunkSize + if end > len(s) { + end = len(s) + } + chunks = append(chunks, s[i:end]) + } + return chunks +} + +func parseRecordSet(set types.ResourceRecordSet) []libdns.Record { + records := make([]libdns.Record, 0) + + // Route53 returns TXT & SPF records with quotes around them. + // https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/ResourceRecordTypes.html#TXTFormat + var ttl int64 + if set.TTL != nil { + ttl = *set.TTL + } + + rtype := string(set.Type) + for _, record := range set.ResourceRecords { + value := *record.Value + switch rtype { + case "TXT", "SPF": + rows := strings.Split(value, "\n") + for i, row := range rows { + parts := strings.Split(row, `" "`) + if len(parts) > 0 { + parts[0] = strings.TrimPrefix(parts[0], `"`) + parts[len(parts)-1] = strings.TrimSuffix(parts[len(parts)-1], `"`) + } + + // Join parts + row = strings.Join(parts, "") + row = unquote(row) + rows[i] = row + + records = append(records, libdns.Record{ + Name: *set.Name, + Value: row, + Type: rtype, + TTL: time.Duration(ttl) * time.Second, + }) + } + default: + records = append(records, libdns.Record{ + Name: *set.Name, + Value: value, + Type: rtype, + TTL: time.Duration(ttl) * time.Second, + }) + } + + } + + return records +} + +func marshalRecord(record libdns.Record) []types.ResourceRecord { + resourceRecords := make([]types.ResourceRecord, 0) + + // Route53 requires TXT & SPF records to be quoted. + // https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/ResourceRecordTypes.html#TXTFormat + switch record.Type { + case "TXT", "SPF": + strs := make([]string, 0) + if len(record.Value) > 255 { + strs = append(strs, chunkString(record.Value, 255)...) + } else { + strs = append(strs, record.Value) + } + + // Quote strings + for i, str := range strs { + strs[i] = quote(str) + } + + // Finally join chunks with spaces + resourceRecords = append(resourceRecords, types.ResourceRecord{ + Value: aws.String(strings.Join(strs, " ")), + }) + default: + resourceRecords = append(resourceRecords, types.ResourceRecord{ + Value: aws.String(record.Value), + }) + } + + return resourceRecords +} + func (p *Provider) getRecords(ctx context.Context, zoneID string, zone string) ([]libdns.Record, error) { getRecordsInput := &r53.ListResourceRecordSetsInput{ HostedZoneId: aws.String(zoneID), @@ -79,6 +172,10 @@ func (p *Provider) getRecords(ctx context.Context, zoneID string, zone string) ( } recordSets = append(recordSets, getRecordResult.ResourceRecordSets...) + for _, s := range recordSets { + records = append(records, parseRecordSet(s)...) + } + if getRecordResult.IsTruncated { getRecordsInput.StartRecordName = getRecordResult.NextRecordName getRecordsInput.StartRecordType = getRecordResult.NextRecordType @@ -88,31 +185,6 @@ func (p *Provider) getRecords(ctx context.Context, zoneID string, zone string) ( } } - for _, rrset := range recordSets { - for _, rrsetRecord := range rrset.ResourceRecords { - rtype := rrset.Type - value := *rrsetRecord.Value - // Route53 returns TXT & SPF records with quotes around them. - // https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/ResourceRecordTypes.html#TXTFormat - switch rtype { - case types.RRTypeTxt, types.RRTypeSpf: - var err error - value, err = strconv.Unquote(value) - if err != nil { - return records, fmt.Errorf("Error unquoting TXT/SPF record: %s", err) - } - } - record := libdns.Record{ - Name: *rrset.Name, - Value: value, - Type: string(rtype), - TTL: time.Duration(*rrset.TTL) * time.Second, - } - - records = append(records, record) - } - } - return records, nil } @@ -170,24 +242,19 @@ func (p *Provider) createRecord(ctx context.Context, zoneID string, record libdn switch record.Type { case "TXT": return p.updateRecord(ctx, zoneID, record, zone) - case "SPF": - record.Value = strconv.Quote(record.Value) } + resourceRecords := marshalRecord(record) createInput := &r53.ChangeResourceRecordSetsInput{ ChangeBatch: &types.ChangeBatch{ Changes: []types.Change{ { Action: types.ChangeActionCreate, ResourceRecordSet: &types.ResourceRecordSet{ - Name: aws.String(libdns.AbsoluteName(record.Name, zone)), - ResourceRecords: []types.ResourceRecord{ - { - Value: aws.String(record.Value), - }, - }, - TTL: aws.Int64(int64(record.TTL.Seconds())), - Type: types.RRType(record.Type), + Name: aws.String(libdns.AbsoluteName(record.Name, zone)), + ResourceRecords: resourceRecords, + TTL: aws.Int64(int64(record.TTL.Seconds())), + Type: types.RRType(record.Type), }, }, }, @@ -206,12 +273,6 @@ func (p *Provider) createRecord(ctx context.Context, zoneID string, record libdn func (p *Provider) updateRecord(ctx context.Context, zoneID string, record libdns.Record, zone string) (libdns.Record, error) { resourceRecords := make([]types.ResourceRecord, 0) // AWS Route53 TXT record value must be enclosed in quotation marks on update - switch record.Type { - case "SPF", "TXT": - resourceRecords = append(resourceRecords, types.ResourceRecord{ - Value: aws.String(strconv.Quote(record.Value)), - }) - } if record.Type == "TXT" { txtRecords, err := p.getTxtRecordsFor(ctx, zoneID, zone, record.Name) if err != nil { @@ -219,13 +280,12 @@ func (p *Provider) updateRecord(ctx context.Context, zoneID string, record libdn } for _, r := range txtRecords { if record.Value != r.Value { - resourceRecords = append(resourceRecords, types.ResourceRecord{ - Value: aws.String(strconv.Quote(r.Value)), - }) + resourceRecords = append(resourceRecords, marshalRecord(r)...) } } } + resourceRecords = append(resourceRecords, marshalRecord(record)...) updateInput := &r53.ChangeResourceRecordSetsInput{ ChangeBatch: &types.ChangeBatch{ Changes: []types.Change{ @@ -255,28 +315,24 @@ func (p *Provider) deleteRecord(ctx context.Context, zoneID string, record libdn action := types.ChangeActionDelete resourceRecords := make([]types.ResourceRecord, 0) // AWS Route53 TXT record value must be enclosed in quotation marks on update - switch record.Type { - case "SPF", "TXT": - resourceRecords = append(resourceRecords, types.ResourceRecord{ - Value: aws.String(strconv.Quote(record.Value)), - }) - } if record.Type == "TXT" { txtRecords, err := p.getTxtRecordsFor(ctx, zoneID, zone, record.Name) if err != nil { return record, err } + switch { - case len(txtRecords) > 0 && txtRecords[0].Value != record.Value, - len(txtRecords) > 1: + // If there is only one record, we can delete the entire record set. + case len(txtRecords) == 1: + resourceRecords = append(resourceRecords, marshalRecord(record)...) + // If there are multiple records, we need to upsert the remaining records. + case len(txtRecords) > 1: action = types.ChangeActionUpsert resourceRecords = make([]types.ResourceRecord, 0) - } - for _, r := range txtRecords { - if record.Value != r.Value { - resourceRecords = append(resourceRecords, types.ResourceRecord{ - Value: aws.String(strconv.Quote(r.Value)), - }) + for _, r := range txtRecords { + if record.Value != r.Value { + resourceRecords = append(resourceRecords, marshalRecord(r)...) + } } } } diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..22896e8 --- /dev/null +++ b/client_test.go @@ -0,0 +1,341 @@ +package route53 + +import ( + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/route53/types" + "github.com/libdns/libdns" +) + +func TestTXTMarshalling(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + { + name: "string with quotes", + input: `This string includes "quotation marks".`, + expected: `"This string includes \"quotation marks\"."`, + }, + { + name: "string with backslashes", + input: `This string includes \backslashes\`, + expected: `"This string includes \\backslashes\\"`, + }, + { + name: "string with special characters", + input: `The last character in this string is an accented e specified in octal format: é`, + expected: `"The last character in this string is an accented e specified in octal format: \351"`, + }, + { + name: "simple", + input: "v=spf1 ip4:192.168.0.1/16 -all", + expected: `"v=spf1 ip4:192.168.0.1/16 -all"`, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + actual := quote(c.input) + if actual != c.expected { + t.Errorf("expected %s, got %s", c.expected, actual) + } + }) + } +} + +func TestTXTUnmarhalling(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + { + name: "string with quotes", + input: `"This string includes \"quotation marks\"."`, + expected: `This string includes "quotation marks".`, + }, + { + name: "string with backslashes", + input: `"This string includes \\backslashes\\"`, + expected: `This string includes \backslashes\`, + }, + { + name: "string with special characters", + input: `"The last character in this string is an accented e specified in octal format: \351"`, + expected: `The last character in this string is an accented e specified in octal format: é`, + }, + { + name: "simple", + input: `"v=spf1 ip4:192.168.0.1/16 -all"`, + expected: "v=spf1 ip4:192.168.0.1/16 -all", + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + actual := unquote(c.input) + if actual != c.expected { + t.Errorf("expected %s, got %s", c.expected, actual) + } + }) + } +} + +func TestParseRecordSet(t *testing.T) { + cases := []struct { + name string + input types.ResourceRecordSet + expected []libdns.Record + }{ + { + name: "A record", + input: types.ResourceRecordSet{ + Name: aws.String(""), + Type: types.RRTypeA, + ResourceRecords: []types.ResourceRecord{ + { + Value: aws.String("127.0.0.1"), + }, + }, + }, + expected: []libdns.Record{ + { + Type: "A", + Name: "", + Value: "127.0.0.1", + }, + }, + }, + { + name: "CNAME record", + input: types.ResourceRecordSet{ + Name: aws.String("*"), + Type: types.RRTypeCname, + ResourceRecords: []types.ResourceRecord{ + { + Value: aws.String("example.com"), + }, + }, + }, + expected: []libdns.Record{ + { + Type: "CNAME", + Name: "*", + Value: "example.com", + }, + }, + }, + { + name: "TXT record", + input: types.ResourceRecordSet{ + Name: aws.String("test"), + Type: types.RRTypeTxt, + ResourceRecords: []types.ResourceRecord{ + { + Value: aws.String(`"This string includes \"quotation marks\"."`), + }, + { + Value: aws.String(`"This string includes \\backslashes\\"`), + }, + { + Value: aws.String(`"The last character in this string is an accented e specified in octal format: \351"`), + }, + { + Value: aws.String(`"String 1" "String 2" "String 3"`), + }, + }, + }, + expected: []libdns.Record{ + { + Type: "TXT", + Name: "test", + Value: `This string includes "quotation marks".`, + }, + { + Type: "TXT", + Name: "test", + Value: `This string includes \backslashes\`, + }, + { + Type: "TXT", + Name: "test", + Value: `The last character in this string is an accented e specified in octal format: é`, + }, + { + Type: "TXT", + Name: "test", + Value: `String 1String 2String 3`, + }, + }, + }, + { + name: "TXT long record", + input: types.ResourceRecordSet{ + Name: aws.String("_testlong"), + Type: types.RRTypeTxt, + ResourceRecords: []types.ResourceRecord{ + { + Value: aws.String(`"3gImdrsMGi6MzHi2rMviVqvwJbv7tXDPk6JvUEI2Fnl7sRF1bUSjNIe4qnatzomDu368bV6Q45qItkF wwnYoGBXNu1uclGvlPIIcGQd6wqBPzTtv0P83brCXJ59RJNLnAif8a3EQuLy88GmblPq 42uJpHTeNYnDRLQt8WvhRCYySX6bx" "vJtK8TZJtVRFbCgUrziRgQVzLwV4fn2hitpnItt U3Ke9IE5 gcs1Obx9kG8wkQ9h4qIxKDLVsmYdhuw4kdLmM2Qm6jJ3ZlSIaQWFP2eNLq5NwZfgATZiGRhr"`), + }, + }, + }, + expected: []libdns.Record{ + { + Type: "TXT", + Name: "_testlong", + Value: "3gImdrsMGi6MzHi2rMviVqvwJbv7tXDPk6JvUEI2Fnl7sRF1bUSjNIe4qnatzomDu368bV6Q45qItkF wwnYoGBXNu1uclGvlPIIcGQd6wqBPzTtv0P83brCXJ59RJNLnAif8a3EQuLy88GmblPq 42uJpHTeNYnDRLQt8WvhRCYySX6bxvJtK8TZJtVRFbCgUrziRgQVzLwV4fn2hitpnItt U3Ke9IE5 gcs1Obx9kG8wkQ9h4qIxKDLVsmYdhuw4kdLmM2Qm6jJ3ZlSIaQWFP2eNLq5NwZfgATZiGRhr", + }, + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + actual := parseRecordSet(c.input) + if len(actual) != len(c.expected) { + t.Errorf("expected %d records, got %d", len(c.expected), len(actual)) + } + for i, record := range actual { + if record.Type != c.expected[i].Type { + t.Errorf("expected type %s, got %s", c.expected[i].Type, record.Type) + } + if record.Name != c.expected[i].Name { + t.Errorf("expected name %s, got %s", c.expected[i].Name, record.Name) + } + if record.Value != c.expected[i].Value { + t.Errorf("expected value %s, got %s", c.expected[i].Value, record.Value) + } + } + }) + } +} + +func TestMarshalRecord(t *testing.T) { + cases := []struct { + name string + input libdns.Record + expected []types.ResourceRecord + }{ + { + name: "A record", + input: libdns.Record{ + Type: "A", + Name: "", + Value: "127.0.0.1", + }, + expected: []types.ResourceRecord{ + { + Value: aws.String("127.0.0.1"), + }, + }, + }, + { + name: "A record with name", + input: libdns.Record{ + Type: "A", + Name: "test", + Value: "127.0.0.1", + }, + expected: []types.ResourceRecord{ + { + Value: aws.String("127.0.0.1"), + }, + }, + }, + { + name: "TXT record", + input: libdns.Record{ + Type: "TXT", + Name: "", + Value: "test", + }, + expected: []types.ResourceRecord{ + { + Value: aws.String(`"test"`), + }, + }, + }, + { + name: "TXT record with name", + input: libdns.Record{ + Type: "TXT", + Name: "test", + Value: "test", + }, + expected: []types.ResourceRecord{ + { + Value: aws.String(`"test"`), + }, + }, + }, + { + name: "TXT record with long value", + input: libdns.Record{ + Type: "TXT", + Name: "test", + Value: `3gImdrsMGi6MzHi2rMviVqvwJbv7tXDPk6JvUEI2Fnl7sRF1bUSjNIe4qnatzomDu368bV6Q45qItkF wwnYoGBXNu1uclGvlPIIcGQd6wqBPzTtv0P83brCXJ59RJNLnAif8a3EQuLy88GmblPq 42uJpHTeNYnDRLQt8WvhRCYySX6bxvJtK8TZJtVRFbCgUrziRgQVzLwV4fn2hitpnItt U3Ke9IE5 gcs1Obx9kG8wkQ9h4qIxKDLVsmYdhuw4kdLmM2Qm6jJ3ZlSIaQWFP2eNLq5NwZfgATZiGRhr`, + }, + expected: []types.ResourceRecord{ + { + Value: aws.String(`"3gImdrsMGi6MzHi2rMviVqvwJbv7tXDPk6JvUEI2Fnl7sRF1bUSjNIe4qnatzomDu368bV6Q45qItkF wwnYoGBXNu1uclGvlPIIcGQd6wqBPzTtv0P83brCXJ59RJNLnAif8a3EQuLy88GmblPq 42uJpHTeNYnDRLQt8WvhRCYySX6bxvJtK8TZJtVRFbCgUrziRgQVzLwV4fn2hitpnItt U3Ke9IE5 gcs1Obx9kG8wkQ9h4qIxKDLVsmYd" "huw4kdLmM2Qm6jJ3ZlSIaQWFP2eNLq5NwZfgATZiGRhr"`), + }, + }, + }, + { + name: "TXT record with a special character", + input: libdns.Record{ + Type: "TXT", + Name: "test", + Value: `test é`, + }, + expected: []types.ResourceRecord{ + { + Value: aws.String(`"test \351"`), + }, + }, + }, + { + name: "TXT record with quotes", + input: libdns.Record{ + Type: "TXT", + Name: "test", + Value: `"test"`, + }, + expected: []types.ResourceRecord{ + { + Value: aws.String(`"\"test\""`), + }, + }, + }, + { + name: "TXT record with backslashes", + input: libdns.Record{ + Type: "TXT", + Name: "test", + Value: `\test\`, + }, + expected: []types.ResourceRecord{ + { + Value: aws.String(`"\\test\\"`), + }, + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + actual := marshalRecord(c.input) + if len(actual) != len(c.expected) { + t.Errorf("expected %d records, got %d", len(c.expected), len(actual)) + } + for i, record := range actual { + if *record.Value != *c.expected[i].Value { + t.Errorf("expected value %s, got %s", *c.expected[i].Value, *record.Value) + } + } + }) + } +} diff --git a/quote.go b/quote.go new file mode 100644 index 0000000..5073343 --- /dev/null +++ b/quote.go @@ -0,0 +1,62 @@ +package route53 + +import ( + "fmt" + "strconv" + "strings" +) + +func quote(s string) string { + // Special characters in a TXT record value + // + // If your TXT record contains any of the following characters, you must specify the characters by using escape codes in the format \three-digit octal code: + // Characters 000 to 040 octal (0 to 32 decimal, 0x00 to 0x20 hexadecimal) + // Characters 177 to 377 octal (127 to 255 decimal, 0x7F to 0xFF hexadecimal) + sb := strings.Builder{} + for _, c := range s { + if (c >= 0 && c < 32) || (c >= 127 && c <= 255) { + sb.WriteString(fmt.Sprintf("\\%03o", c)) + } else if c == '"' { + sb.WriteString(`\"`) + } else if c == '\\' { + sb.WriteString(`\\`) + } else { + sb.WriteRune(c) + } + } + s = sb.String() + + // Quote strings + s = `"` + s + `"` + + return s +} + +func unquote(s string) string { + // Unescape special characters + var sb strings.Builder + for i := 0; i < len(s); i++ { + c := rune(s[i]) + if c == '\\' && len(s) > i+1 { + if s[i+1] == '"' { + sb.WriteRune('"') + i++ + continue + } else if s[i+1] == '\\' { + sb.WriteRune('\\') + i++ + continue + } else if s[i+1] >= '0' && s[i+1] <= '7' && len(s) > i+3 { + octal, err := strconv.ParseInt(s[i+1:i+4], 8, 32) + if err == nil { + sb.WriteRune(rune(octal)) + i += 3 + continue + } + } + } + sb.WriteRune(c) + } + + return strings.Trim(sb.String(), `"`) +}