From a69a248d63e0a651056ae0a8fafe04f442876aeb Mon Sep 17 00:00:00 2001 From: "daniel.vladco" Date: Mon, 19 Aug 2024 18:48:09 +0300 Subject: [PATCH] feat: parse jumptable and code in separate sections --- internal/polkavm/program.go | 238 +++++++++++++++++++------------ internal/polkavm/program_test.go | 39 ++++- 2 files changed, 177 insertions(+), 100 deletions(-) diff --git a/internal/polkavm/program.go b/internal/polkavm/program.go index 563564b..9e0418b 100644 --- a/internal/polkavm/program.go +++ b/internal/polkavm/program.go @@ -27,7 +27,9 @@ const ( BlobVersionV1 byte = 1 VersionDebugLineProgramV1 byte = 1 - VmMaximumImportCount uint32 = 1024 // The maximum number of functions the program can import. + VmMaximumJumpTableEntries uint32 = 16 * 1024 * 1024 + VmMaximumImportCount uint32 = 1024 // The maximum number of functions the program can import. + VmMaximumCodeSize uint32 = 32 * 1024 * 1024 ) type ProgramParts struct { @@ -36,7 +38,10 @@ type ProgramParts struct { StackSize uint32 ROData []byte RWData []byte - CodeAndJumpTable []byte + JumpTableEntrySize byte + JumpTable []byte + Code []byte + Bitmask []byte ImportOffsets []byte ImportSymbols []byte Exports []byte @@ -45,85 +50,69 @@ type ProgramParts struct { DebugLinePrograms []byte } -type Reader interface { - io.Reader - io.Seeker -} - -func ParseBlob(reader Reader) (pp *ProgramParts, err error) { +func ParseBlob(r *Reader) (pp *ProgramParts, err error) { magic := make([]byte, len(BlobMagic)) - _, err = reader.Read(magic) + _, err = r.Read(magic) if err != nil { return nil, err } if [len(BlobMagic)]byte(magic) != BlobMagic { return pp, fmt.Errorf("blob doesn't start with the expected magic bytes") } - var blobVersion = new(byte) - err = readByte(reader, blobVersion) + blobVersion, err := r.ReadByte() if err != nil { return nil, err } - if *blobVersion != BlobVersionV1 { + if blobVersion != BlobVersionV1 { return pp, fmt.Errorf("unsupported version: %d", blobVersion) } pp = &ProgramParts{} - section := new(byte) - err = readByte(reader, section) + section, err := r.ReadByte() if err != nil { return nil, err } - if *section == SectionMemoryConfig { - secLen, err := readVariant(reader) - if err != nil { - return nil, err - } - pos, err := reader.Seek(0, io.SeekCurrent) + if section == SectionMemoryConfig { + secLen, err := r.ReadVarint() if err != nil { return nil, err } + pos := r.Position() - pp.RODataSize, err = readVariant(reader) - if err != nil { + if pp.RODataSize, err = r.ReadVarint(); err != nil { return nil, err } - pp.RWDataSize, err = readVariant(reader) - if err != nil { - return nil, err - } - pp.StackSize, err = readVariant(reader) - if err != nil { + if pp.RWDataSize, err = r.ReadVarint(); err != nil { return nil, err } - pos2, err := reader.Seek(0, io.SeekCurrent) - if err != nil { + if pp.StackSize, err = r.ReadVarint(); err != nil { return nil, err } - if pos+int64(secLen) != pos2 { - return pp, fmt.Errorf("the memory config section contains more data than expected %v %v", pos+int64(secLen), pos2) + if pos+int64(secLen) != r.Position() { + return pp, fmt.Errorf("the memory config section contains more data than expected %v %v", pos+int64(secLen), r.Position()) } - err = readByte(reader, section) + section, err = r.ReadByte() if err != nil { return nil, err } } - if pp.ROData, err = readSectionAsBytes(reader, section, SectionROData); err != nil { - return nil, err - } - if pp.RWData, err = readSectionAsBytes(reader, section, SectionRWData); err != nil { - return nil, err + if section == SectionROData { + if section, pp.ROData, err = r.ReadSection(); err != nil { + return nil, err + } } - if *section == SectionImports { - secLen, err := readVariant(reader) - if err != nil { + if section == SectionRWData { + if section, pp.RWData, err = r.ReadSection(); err != nil { return nil, err } - posStart, err := reader.Seek(0, io.SeekCurrent) + } + if section == SectionImports { + secLen, err := r.ReadVarint() if err != nil { return nil, err } - importCount, err := readVariant(reader) + posStart := r.Position() + importCount, err := r.ReadVarint() if err != nil { return nil, err } @@ -133,112 +122,166 @@ func ParseBlob(reader Reader) (pp *ProgramParts, err error) { //TODO check for underflow and overflow? importOffsetsSize := importCount * 4 pp.ImportOffsets = make([]byte, importOffsetsSize) - _, err = reader.Read(pp.ImportOffsets) + _, err = r.Read(pp.ImportOffsets) if err != nil { return nil, err } - pos, err := reader.Seek(0, io.SeekCurrent) - if err != nil { - return nil, err - } //TODO check for underflow? - importSymbolsSize := secLen - uint32(pos-posStart) + importSymbolsSize := secLen - uint32(r.Position()-posStart) pp.ImportSymbols = make([]byte, importSymbolsSize) - _, err = reader.Read(pp.ImportSymbols) + _, err = r.Read(pp.ImportSymbols) if err != nil { return nil, err } - err = readByte(reader, section) + section, err = r.ReadByte() if err != nil { return nil, err } } - if pp.Exports, err = readSectionAsBytes(reader, section, SectionExports); err != nil { - return nil, err + if section == SectionExports { + if section, pp.Exports, err = r.ReadSection(); err != nil { + return nil, err + } } - if pp.CodeAndJumpTable, err = readSectionAsBytes(reader, section, SectionCodeAndJumpTable); err != nil { - return nil, err + if section == SectionCodeAndJumpTable { + secLen, err := r.ReadVarint() + if err != nil { + return nil, err + } + initialPosition := r.Position() + jumpTableEntryCount, err := r.ReadVarint() + if err != nil { + return nil, err + } + if jumpTableEntryCount > VmMaximumJumpTableEntries { + return nil, fmt.Errorf("the jump table section is too long") + } + jumpTableEntrySize, err := r.ReadByte() + if err != nil { + return nil, err + } + codeLength, err := r.ReadVarint() + if err != nil { + return nil, err + } + if codeLength > VmMaximumCodeSize { + return nil, fmt.Errorf("the code section is too long") + } + if jumpTableEntrySize > 4 { + return nil, fmt.Errorf("invalid jump table entry size") + } + + //TODO check for underflow and overflow? + jumpTableLength := jumpTableEntryCount * uint32(jumpTableEntrySize) + pp.JumpTableEntrySize = jumpTableEntrySize + + pp.JumpTable = make([]byte, jumpTableLength) + if _, err = r.Read(pp.JumpTable); err != nil { + return nil, err + } + pp.Code = make([]byte, codeLength) + if _, err = r.Read(pp.Code); err != nil { + return nil, err + } + + bitmaskLength := secLen - uint32(r.Position()-initialPosition) + pp.Bitmask = make([]byte, bitmaskLength) + if _, err = r.Read(pp.Bitmask); err != nil { + return nil, err + } + expectedBitmaskLength := codeLength / 8 + if codeLength%8 != 0 { + expectedBitmaskLength += 1 + } + + if bitmaskLength != expectedBitmaskLength { + return nil, fmt.Errorf("the bitmask length doesn't match the code length") + } + if section, err = r.ReadByte(); err != nil { + return nil, err + } } - if pp.DebugStrings, err = readSectionAsBytes(reader, section, SectionOptDebugStrings); err != nil { - return nil, err + if section == SectionOptDebugStrings { + if section, pp.DebugStrings, err = r.ReadSection(); err != nil { + return nil, err + } } - if pp.DebugLinePrograms, err = readSectionAsBytes(reader, section, SectionOptDebugLinePrograms); err != nil { - return nil, err + if section == SectionOptDebugLinePrograms { + if section, pp.DebugLinePrograms, err = r.ReadSection(); err != nil { + return nil, err + } } - if pp.DebugLineProgramRanges, err = readSectionAsBytes(reader, section, SectionOptDebugLineProgramRanges); err != nil { - return nil, err + if section == SectionOptDebugLineProgramRanges { + if section, pp.DebugLineProgramRanges, err = r.ReadSection(); err != nil { + return nil, err + } } - for (*section & 0b10000000) != 0 { + for (section & 0b10000000) != 0 { // We don't know this section, but it's optional, so just skip it. log.Printf("Skipping unsupported optional section: %v", section) - sectionLength, err := readVariant(reader) + sectionLength, err := r.ReadVarint() if err != nil { return nil, err } discardBytes := make([]byte, sectionLength) - _, err = reader.Read(discardBytes) + _, err = r.Read(discardBytes) if err != nil { return nil, err } - err = readByte(reader, section) + section, err = r.ReadByte() if err != nil { return nil, err } } - if *section != SectionEndOfFile { - return nil, fmt.Errorf("unexpected section: %v", *section) + if section != SectionEndOfFile { + return nil, fmt.Errorf("unexpected section: %v", section) } return pp, nil } -func readSectionAsBytes(reader Reader, outSection *byte, expected byte) ([]byte, error) { - if *outSection != expected { - return nil, nil - } +func NewReader(r io.ReadSeeker) *Reader { return &Reader{r} } + +type Reader struct{ io.ReadSeeker } - secLen, err := readVariant(reader) +func (r *Reader) ReadSection() (section byte, bytes []byte, err error) { + var secLen uint32 + secLen, err = r.ReadVarint() if err != nil { - return nil, err + return } - bb := make([]byte, secLen) - _, err = reader.Read(bb) - if err != nil { - return nil, err + bytes = make([]byte, secLen) + if _, err = r.Read(bytes); err != nil { + return } - err = readByte(reader, outSection) - if err != nil { - return nil, err - } - return bb, nil + section, err = r.ReadByte() + return } -func readByte(reader Reader, section *byte) error { +func (r *Reader) ReadByte() (byte, error) { b := make([]byte, 1) - _, err := reader.Read(b) + _, err := r.Read(b) if err != nil { - return err + return 0, err } - *section = b[0] - return nil + return b[0], nil } -func readVariant(reader Reader) (uint32, error) { - firstByte := new(byte) - err := readByte(reader, firstByte) +func (r *Reader) ReadVarint() (uint32, error) { + firstByte, err := r.ReadByte() if err != nil { return 0, err } - length := bits.LeadingZeros8(^*firstByte) + length := bits.LeadingZeros8(^firstByte) var upperMask uint32 = 0b11111111 >> length - var upperBits = upperMask & uint32(*firstByte) << (length * 8) + var upperBits = upperMask & uint32(firstByte) << (length * 8) if length == 0 { return upperBits, nil } value := make([]byte, length) - n, err := reader.Read(value) + n, err := r.Read(value) if err != nil { return 0, err } @@ -253,3 +296,12 @@ func readVariant(reader Reader) (uint32, error) { return 0, fmt.Errorf("invalid variant length: %d", n) } } + +func (r *Reader) Position() int64 { + pos, err := r.Seek(0, io.SeekCurrent) + if err != nil { + panic(fmt.Sprintf("the current position should always be seekable: %v", err)) + } + + return pos +} diff --git a/internal/polkavm/program_test.go b/internal/polkavm/program_test.go index 0ef0661..0d1ef51 100644 --- a/internal/polkavm/program_test.go +++ b/internal/polkavm/program_test.go @@ -2,6 +2,7 @@ package polkavm import ( "embed" + "io" "testing" "github.com/stretchr/testify/assert" @@ -15,15 +16,39 @@ func Test_ParseBlob(t *testing.T) { if err != nil { t.Fatal(err) } - defer f.Close() - pp, err := ParseBlob(f.(Reader)) + pp, err := ParseBlob(NewReader(f.(io.ReadSeeker))) if err != nil { t.Fatal(err) } - assert.Equal(t, pp.StackSize, uint32(4096)) - assert.Equal(t, pp.CodeAndJumpTable, []byte{0, 0, 25, 2, 17, 248, 3, 16, 4, 3, 21, 8, 120, 5, 78, 8, 87, 7, 1, 16, 4, 1, 21, 2, 17, 8, 19, 0, 73, 153, 148, 254}) - assert.Equal(t, pp.ImportOffsets, []byte{0, 0, 0, 0}) - assert.Equal(t, pp.ImportSymbols, []byte{103, 101, 116, 95, 116, 104, 105, 114, 100, 95, 110, 117, 109, 98, 101, 114}) - assert.Equal(t, pp.Exports, []byte{1, 0, 11, 97, 100, 100, 95, 110, 117, 109, 98, 101, 114, 115}) + assert.Equal(t, uint32(0), pp.RODataSize) + assert.Equal(t, uint32(0), pp.RWDataSize) + assert.Equal(t, uint32(4096), pp.StackSize) + assert.Equal(t, pp.JumpTableEntrySize, byte(0)) + assert.Equal(t, []byte{2, 17, 248, 3, 16, 4, 3, 21, 8, 120, 5, 78, 8, 87, 7, 1, 16, 4, 1, 21, 2, 17, 8, 19, 0}, pp.Code) + assert.Equal(t, []byte{73, 153, 148, 254}, pp.Bitmask) + assert.Equal(t, []byte{0, 0, 0, 0}, pp.ImportOffsets) + assert.Equal(t, []byte{103, 101, 116, 95, 116, 104, 105, 114, 100, 95, 110, 117, 109, 98, 101, 114}, pp.ImportSymbols) + assert.Equal(t, []byte{1, 0, 11, 97, 100, 100, 95, 110, 117, 109, 98, 101, 114, 115}, pp.Exports) +} + +// 55..80 +var a = []byte{ + 80, 86, 77, 0, 1, //5 + 1, 4, 0, 0, 144, //10 + 0, 4, 21, 1, 0, //15 + 0, 0, 0, 103, 101, //20 + 116, 95, 116, 104, 105, //25 + 114, 100, 95, 110, 117, //30 + 109, 98, 101, 114, 5, //35 + 14, 1, 0, 11, 97, //40 + 100, 100, 95, 110, 117, //45 + 109, 98, 101, 114, 115, //50 + 6, 32, 0, 0, 25, //55 + 2, 17, 248, 3, 16, //60 + 4, 3, 21, 8, 120, //65 + 5, 78, 8, 87, 7, //70 + 1, 16, 4, 1, 21, //75 + 2, 17, 8, 19, 0, //80 + 73, 153, 148, 254, 0, //85 }