diff --git a/common.go b/common.go index 85ca3b6..8750704 100644 --- a/common.go +++ b/common.go @@ -10,14 +10,23 @@ import ( ) func pae(pieces ...[]byte) []byte { - var buf bytes.Buffer - binary.Write(&buf, binary.LittleEndian, int64(len(pieces))) + size := 8 + for i := range pieces { + size += 8 + len(pieces[i]) + } + + buf := make([]byte, size) + binary.LittleEndian.PutUint64(buf, uint64(len(pieces))) + + idx := 8 + for i := range pieces { + binary.LittleEndian.PutUint64(buf[idx:], uint64(len(pieces[i]))) + idx += 8 - for _, p := range pieces { - binary.Write(&buf, binary.LittleEndian, int64(len(p))) - buf.Write(p) + copy(buf[idx:], pieces[i]) + idx += len(pieces[i]) } - return buf.Bytes() + return buf } func toBytes(x any) ([]byte, error) { diff --git a/common_test.go b/common_test.go new file mode 100644 index 0000000..9cd37b3 --- /dev/null +++ b/common_test.go @@ -0,0 +1,60 @@ +package paseto + +import ( + "testing" +) + +func TestPAE(t *testing.T) { + testCases := []struct { + pieces [][]byte + want string + }{ + { + pieces: nil, + want: "\x00\x00\x00\x00\x00\x00\x00\x00", + }, + { + pieces: [][]byte{}, + want: "\x00\x00\x00\x00\x00\x00\x00\x00", + }, + { + pieces: [][]byte{nil}, + want: "\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + }, + { + pieces: [][]byte{[]byte("test")}, + want: "\x01\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00test", + }, + } + + for _, tc := range testCases { + res := pae(tc.pieces...) + have := string(res) + + if have != tc.want { + t.Errorf("\nhave: %v\nwant: %v", have, tc.want) + } + } +} + +func BenchmarkPAE(b *testing.B) { + var nonce [32]byte + var encryptedPayload [256]byte + var footerBytes []byte + + pieces := [][]byte{ + []byte(v1LocHeader), + nonce[:], + encryptedPayload[:], + footerBytes, + } + + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + res := pae(pieces...) + if len(res) == 0 { + b.Fatal() + } + } +}