Skip to content

Commit

Permalink
tlv: fuzz test encoding/decoding (#7889)
Browse files Browse the repository at this point in the history
* tlv: fuzz tests for primitives

* tlv: fuzz tests for BigSize

We use a new harness to compare decoded values instead of encoded
values, since there may be some unparsed bytes in the original data.

* tlv: fuzz tests for truncated integers

These fuzz tests are identical to non-truncated integers, except that we
allow the fuzzer to choose decode lengths shorter than the length of
normal integers.

* tlv: fuzz tests for streams

* fixup! tlv: fuzz tests for truncated integers

loop over decode length

* fixup! tlv: fuzz tests for streams

better documentation
  • Loading branch information
morehouse authored Oct 10, 2023
1 parent 4f34606 commit 2d98dcf
Showing 1 changed file with 289 additions and 0 deletions.
289 changes: 289 additions & 0 deletions tlv/fuzz_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
package tlv

import (
"bytes"
"testing"

"github.com/btcsuite/btcd/btcec/v2"
"github.com/stretchr/testify/require"
)

// harness runs a decode/encode round trip over the fuzzer-supplied data using
// the given decoder and encoder, and fails the test if the re-encoded bytes
// differ from the input. Inputs longer than decodeLen, as well as inputs that
// the decoder rejects, are skipped without failing.
func harness(t *testing.T, data []byte, encode Encoder, decode Decoder,
	val interface{}, decodeLen uint64) {

	// Presumably the decoder would not consume bytes beyond decodeLen, so
	// longer inputs could never satisfy the byte-for-byte check below.
	if decodeLen < uint64(len(data)) {
		return
	}

	// The same scratch buffer is shared by the decoder and the encoder.
	var scratch [8]byte

	err := decode(bytes.NewReader(data), val, &scratch, decodeLen)
	if err != nil {
		return
	}

	var encoded bytes.Buffer
	err = encode(&encoded, val, &scratch)
	require.NoError(t, err)

	// Compare with bytes.Equal rather than require.Equal, because a nil
	// slice and an empty slice must be treated as the same encoding.
	got := encoded.Bytes()
	require.True(t, bytes.Equal(data, got), "%v != %v", data, got)
}

// FuzzUint8 round-trips fuzzer data through DUint8 and EUint8 via harness,
// using a decode length of 1.
func FuzzUint8(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		const size = 1

		var decoded uint8
		harness(t, input, EUint8, DUint8, &decoded, size)
	})
}

// FuzzUint16 round-trips fuzzer data through DUint16 and EUint16 via harness,
// using a decode length of 2.
func FuzzUint16(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		const size = 2

		var decoded uint16
		harness(t, input, EUint16, DUint16, &decoded, size)
	})
}

// FuzzUint32 round-trips fuzzer data through DUint32 and EUint32 via harness,
// using a decode length of 4.
func FuzzUint32(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		const size = 4

		var decoded uint32
		harness(t, input, EUint32, DUint32, &decoded, size)
	})
}

// FuzzUint64 round-trips fuzzer data through DUint64 and EUint64 via harness,
// using a decode length of 8.
func FuzzUint64(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		const size = 8

		var decoded uint64
		harness(t, input, EUint64, DUint64, &decoded, size)
	})
}

// FuzzBytes32 round-trips fuzzer data through DBytes32 and EBytes32 via
// harness, using a decode length of 32.
func FuzzBytes32(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		var decoded [32]byte

		// The decode length is the array size, i.e. 32.
		size := uint64(len(decoded))
		harness(t, input, EBytes32, DBytes32, &decoded, size)
	})
}

// FuzzBytes33 round-trips fuzzer data through DBytes33 and EBytes33 via
// harness, using a decode length of 33.
func FuzzBytes33(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		var decoded [33]byte

		// The decode length is the array size, i.e. 33.
		size := uint64(len(decoded))
		harness(t, input, EBytes33, DBytes33, &decoded, size)
	})
}

// FuzzBytes64 round-trips fuzzer data through DBytes64 and EBytes64 via
// harness, using a decode length of 64.
func FuzzBytes64(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		var decoded [64]byte

		// The decode length is the array size, i.e. 64.
		size := uint64(len(decoded))
		harness(t, input, EBytes64, DBytes64, &decoded, size)
	})
}

// FuzzPubKey round-trips fuzzer data through DPubKey and EPubKey via harness,
// using a decode length of 33. The decode target is a nil *btcec.PublicKey
// whose address is handed to the decoder.
func FuzzPubKey(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		const size = 33

		var key *btcec.PublicKey
		harness(t, input, EPubKey, DPubKey, &key, size)
	})
}

// FuzzVarBytes round-trips fuzzer data through DVarBytes and EVarBytes via
// harness. The decode length is the full length of the fuzzer input.
func FuzzVarBytes(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		var decoded []byte

		inputLen := uint64(len(input))
		harness(t, input, EVarBytes, DVarBytes, &decoded, inputLen)
	})
}

// bigSizeHarness checks DBigSize and EBigSize against each other. Unlike
// harness, it compares decoded values rather than encoded bytes: DBigSize may
// leave part of data unparsed, in which case the re-encoded bytes are shorter
// than data. val1 receives the value decoded from data, and val2 receives the
// value decoded from the re-encoding of val1; the two must be equal. Inputs
// longer than 9 bytes and inputs that DBigSize rejects are skipped.
func bigSizeHarness(t *testing.T, data []byte, val1, val2 interface{}) {
	// 9 is both the cap on the input size and the length given to DBigSize.
	const maxLen = 9

	if len(data) > maxLen {
		return
	}

	// The same scratch buffer is shared by all decode and encode calls.
	var scratch [8]byte

	err := DBigSize(bytes.NewReader(data), val1, &scratch, maxLen)
	if err != nil {
		return
	}

	var encoded bytes.Buffer
	err = EBigSize(&encoded, val1, &scratch)
	require.NoError(t, err)

	// Bytes we encoded ourselves must always decode successfully.
	reencoded := bytes.NewReader(encoded.Bytes())
	err = DBigSize(reencoded, val2, &scratch, maxLen)
	require.NoError(t, err)

	require.Equal(t, val1, val2)
}

// FuzzBigSize32 runs bigSizeHarness on fuzzer data with uint32 decode targets.
func FuzzBigSize32(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		first, second := new(uint32), new(uint32)
		bigSizeHarness(t, input, first, second)
	})
}

// FuzzBigSize64 runs bigSizeHarness on fuzzer data with uint64 decode targets.
func FuzzBigSize64(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		first, second := new(uint64), new(uint64)
		bigSizeHarness(t, input, first, second)
	})
}

// FuzzTUint16 round-trips fuzzer data through DTUint16 and ETUint16 via
// harness, once for every decode length from 0 through 2 inclusive. The same
// decode target is reused across all lengths.
func FuzzTUint16(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		const maxLen uint64 = 2

		var decoded uint16
		for l := uint64(0); l <= maxLen; l++ {
			harness(t, input, ETUint16, DTUint16, &decoded, l)
		}
	})
}

// FuzzTUint32 round-trips fuzzer data through DTUint32 and ETUint32 via
// harness, once for every decode length from 0 through 4 inclusive. The same
// decode target is reused across all lengths.
func FuzzTUint32(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		const maxLen uint64 = 4

		var decoded uint32
		for l := uint64(0); l <= maxLen; l++ {
			harness(t, input, ETUint32, DTUint32, &decoded, l)
		}
	})
}

// FuzzTUint64 round-trips fuzzer data through DTUint64 and ETUint64 via
// harness, once for every decode length from 0 through 8 inclusive. The same
// decode target is reused across all lengths.
func FuzzTUint64(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		const maxLen uint64 = 8

		var decoded uint64
		for l := uint64(0); l <= maxLen; l++ {
			harness(t, input, ETUint64, DTUint64, &decoded, l)
		}
	})
}

// encodeParsedTypes builds a TLV stream from the results of a prior stream
// decode and returns its encoding. For each type in parsedTypes, the matching
// entry of decodedRecords is used when one exists; otherwise a new primitive
// record is created from the value stored in parsedTypes.
//
// NOTE: Callers must ensure that every record in decodedRecords has a type
// number equal to its index in the slice, since records are looked up by
// indexing the slice with the type.
func encodeParsedTypes(t *testing.T, parsedTypes TypeMap,
	decodedRecords []Record) []byte {

	// Types below this bound have a corresponding entry in decodedRecords.
	numKnown := Type(len(decodedRecords))

	var records []Record
	for typ, raw := range parsedTypes {
		if typ >= numKnown {
			// typ has no entry in decodedRecords, so build a fresh
			// record around the value stored in parsedTypes. The
			// copy gives each record its own variable to point at.
			raw := raw
			records = append(
				records, MakePrimitiveRecord(typ, &raw),
			)

			continue
		}

		// typ was decoded into one of the known records; reuse it.
		records = append(records, decodedRecords[typ])
	}

	// Map iteration order is random, so sort before building the stream.
	SortRecords(records)

	var encoded bytes.Buffer
	err := MustNewStream(records...).Encode(&encoded)
	require.NoError(t, err)

	return encoded.Bytes()
}

// FuzzStream decodes the fuzzer data as a TLV stream, re-encodes the result,
// then decodes and re-encodes that output a second time. The test fails if the
// two encodings differ. Fuzzer inputs that fail the initial decode are skipped.
func FuzzStream(f *testing.F) {
	f.Fuzz(func(t *testing.T, input []byte) {
		// Decode targets, one for each record defined below.
		var (
			u8   uint8
			u16  uint16
			u32  uint32
			u64  uint64
			b32  [32]byte
			b33  [33]byte
			b64  [64]byte
			pk   *btcec.PublicKey
			b    []byte
			bs32 uint32
			bs64 uint64
			tu16 uint16
			tu32 uint32
			tu64 uint64
		)

		// Every record's type number is intentionally the same as its
		// index in this slice; encodeParsedTypes() depends on that to
		// look records up by type.
		records := []Record{
			MakePrimitiveRecord(0, &u8),
			MakePrimitiveRecord(1, &u16),
			MakePrimitiveRecord(2, &u32),
			MakePrimitiveRecord(3, &u64),
			MakePrimitiveRecord(4, &b32),
			MakePrimitiveRecord(5, &b33),
			MakePrimitiveRecord(6, &b64),
			MakePrimitiveRecord(7, &pk),
			MakePrimitiveRecord(8, &b),
			MakeBigSizeRecord(9, &bs32),
			MakeBigSizeRecord(10, &bs64),
			// The size callbacks read the current value of the
			// truncated integers each time they are invoked.
			MakeDynamicRecord(
				11, &tu16,
				func() uint64 { return SizeTUint16(tu16) },
				ETUint16, DTUint16,
			),
			MakeDynamicRecord(
				12, &tu32,
				func() uint64 { return SizeTUint32(tu32) },
				ETUint32, DTUint32,
			),
			MakeDynamicRecord(
				13, &tu64,
				func() uint64 { return SizeTUint64(tu64) },
				ETUint64, DTUint64,
			),
		}

		// The fuzzer may supply TLV data with very large lengths, so
		// the P2P decoding method is used here to avoid OOMs.
		firstStream := MustNewStream(records...)
		parsed, err := firstStream.DecodeWithParsedTypesP2P(
			bytes.NewReader(input),
		)
		if err != nil {
			return
		}

		firstPass := encodeParsedTypes(t, parsed, records)

		// firstPass was produced by us rather than by the fuzzer, so
		// the P2P decoding method is not needed for the second decode.
		secondStream := MustNewStream(records...)
		reparsed, err := secondStream.DecodeWithParsedTypes(
			bytes.NewReader(firstPass),
		)
		require.NoError(t, err)

		secondPass := encodeParsedTypes(t, reparsed, records)

		require.Equal(t, firstPass, secondPass)
	})
}

0 comments on commit 2d98dcf

Please sign in to comment.