Skip to content

Commit

Permalink
feat: parse jumptable and code in separate sections
Browse files Browse the repository at this point in the history
  • Loading branch information
danielvladco committed Aug 19, 2024
1 parent 24fc849 commit a69a248
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 100 deletions.
238 changes: 145 additions & 93 deletions internal/polkavm/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ const (
BlobVersionV1 byte = 1
VersionDebugLineProgramV1 byte = 1

VmMaximumImportCount uint32 = 1024 // The maximum number of functions the program can import.
VmMaximumJumpTableEntries uint32 = 16 * 1024 * 1024
VmMaximumImportCount uint32 = 1024 // The maximum number of functions the program can import.
VmMaximumCodeSize uint32 = 32 * 1024 * 1024
)

type ProgramParts struct {
Expand All @@ -36,7 +38,10 @@ type ProgramParts struct {
StackSize uint32
ROData []byte
RWData []byte
CodeAndJumpTable []byte
JumpTableEntrySize byte
JumpTable []byte
Code []byte
Bitmask []byte
ImportOffsets []byte
ImportSymbols []byte
Exports []byte
Expand All @@ -45,85 +50,69 @@ type ProgramParts struct {
DebugLinePrograms []byte
}

type Reader interface {
io.Reader
io.Seeker
}

func ParseBlob(reader Reader) (pp *ProgramParts, err error) {
func ParseBlob(r *Reader) (pp *ProgramParts, err error) {
magic := make([]byte, len(BlobMagic))
_, err = reader.Read(magic)
_, err = r.Read(magic)
if err != nil {
return nil, err
}
if [len(BlobMagic)]byte(magic) != BlobMagic {
return pp, fmt.Errorf("blob doesn't start with the expected magic bytes")
}
var blobVersion = new(byte)
err = readByte(reader, blobVersion)
blobVersion, err := r.ReadByte()
if err != nil {
return nil, err
}
if *blobVersion != BlobVersionV1 {
if blobVersion != BlobVersionV1 {
return pp, fmt.Errorf("unsupported version: %d", blobVersion)
}

pp = &ProgramParts{}
section := new(byte)
err = readByte(reader, section)
section, err := r.ReadByte()
if err != nil {
return nil, err
}
if *section == SectionMemoryConfig {
secLen, err := readVariant(reader)
if err != nil {
return nil, err
}
pos, err := reader.Seek(0, io.SeekCurrent)
if section == SectionMemoryConfig {
secLen, err := r.ReadVarint()
if err != nil {
return nil, err
}
pos := r.Position()

pp.RODataSize, err = readVariant(reader)
if err != nil {
if pp.RODataSize, err = r.ReadVarint(); err != nil {
return nil, err
}
pp.RWDataSize, err = readVariant(reader)
if err != nil {
return nil, err
}
pp.StackSize, err = readVariant(reader)
if err != nil {
if pp.RWDataSize, err = r.ReadVarint(); err != nil {
return nil, err
}
pos2, err := reader.Seek(0, io.SeekCurrent)
if err != nil {
if pp.StackSize, err = r.ReadVarint(); err != nil {
return nil, err
}
if pos+int64(secLen) != pos2 {
return pp, fmt.Errorf("the memory config section contains more data than expected %v %v", pos+int64(secLen), pos2)
if pos+int64(secLen) != r.Position() {
return pp, fmt.Errorf("the memory config section contains more data than expected %v %v", pos+int64(secLen), r.Position())
}
err = readByte(reader, section)
section, err = r.ReadByte()
if err != nil {
return nil, err
}
}
if pp.ROData, err = readSectionAsBytes(reader, section, SectionROData); err != nil {
return nil, err
}
if pp.RWData, err = readSectionAsBytes(reader, section, SectionRWData); err != nil {
return nil, err
if section == SectionROData {
if section, pp.ROData, err = r.ReadSection(); err != nil {
return nil, err
}
}
if *section == SectionImports {
secLen, err := readVariant(reader)
if err != nil {
if section == SectionRWData {
if section, pp.RWData, err = r.ReadSection(); err != nil {
return nil, err
}
posStart, err := reader.Seek(0, io.SeekCurrent)
}
if section == SectionImports {
secLen, err := r.ReadVarint()
if err != nil {
return nil, err
}
importCount, err := readVariant(reader)
posStart := r.Position()
importCount, err := r.ReadVarint()
if err != nil {
return nil, err
}
Expand All @@ -133,112 +122,166 @@ func ParseBlob(reader Reader) (pp *ProgramParts, err error) {
//TODO check for underflow and overflow?
importOffsetsSize := importCount * 4
pp.ImportOffsets = make([]byte, importOffsetsSize)
_, err = reader.Read(pp.ImportOffsets)
_, err = r.Read(pp.ImportOffsets)
if err != nil {
return nil, err
}

pos, err := reader.Seek(0, io.SeekCurrent)
if err != nil {
return nil, err
}
//TODO check for underflow?
importSymbolsSize := secLen - uint32(pos-posStart)
importSymbolsSize := secLen - uint32(r.Position()-posStart)
pp.ImportSymbols = make([]byte, importSymbolsSize)
_, err = reader.Read(pp.ImportSymbols)
_, err = r.Read(pp.ImportSymbols)
if err != nil {
return nil, err
}
err = readByte(reader, section)
section, err = r.ReadByte()
if err != nil {
return nil, err
}
}

if pp.Exports, err = readSectionAsBytes(reader, section, SectionExports); err != nil {
return nil, err
if section == SectionExports {
if section, pp.Exports, err = r.ReadSection(); err != nil {
return nil, err
}
}
if pp.CodeAndJumpTable, err = readSectionAsBytes(reader, section, SectionCodeAndJumpTable); err != nil {
return nil, err
if section == SectionCodeAndJumpTable {
secLen, err := r.ReadVarint()
if err != nil {
return nil, err
}
initialPosition := r.Position()
jumpTableEntryCount, err := r.ReadVarint()
if err != nil {
return nil, err
}
if jumpTableEntryCount > VmMaximumJumpTableEntries {
return nil, fmt.Errorf("the jump table section is too long")
}
jumpTableEntrySize, err := r.ReadByte()
if err != nil {
return nil, err
}
codeLength, err := r.ReadVarint()
if err != nil {
return nil, err
}
if codeLength > VmMaximumCodeSize {
return nil, fmt.Errorf("the code section is too long")
}
if jumpTableEntrySize > 4 {
return nil, fmt.Errorf("invalid jump table entry size")
}

//TODO check for underflow and overflow?
jumpTableLength := jumpTableEntryCount * uint32(jumpTableEntrySize)
pp.JumpTableEntrySize = jumpTableEntrySize

pp.JumpTable = make([]byte, jumpTableLength)
if _, err = r.Read(pp.JumpTable); err != nil {
return nil, err
}
pp.Code = make([]byte, codeLength)
if _, err = r.Read(pp.Code); err != nil {
return nil, err
}

bitmaskLength := secLen - uint32(r.Position()-initialPosition)
pp.Bitmask = make([]byte, bitmaskLength)
if _, err = r.Read(pp.Bitmask); err != nil {
return nil, err
}
expectedBitmaskLength := codeLength / 8
if codeLength%8 != 0 {
expectedBitmaskLength += 1
}

if bitmaskLength != expectedBitmaskLength {
return nil, fmt.Errorf("the bitmask length doesn't match the code length")
}
if section, err = r.ReadByte(); err != nil {
return nil, err
}
}
if pp.DebugStrings, err = readSectionAsBytes(reader, section, SectionOptDebugStrings); err != nil {
return nil, err
if section == SectionOptDebugStrings {
if section, pp.DebugStrings, err = r.ReadSection(); err != nil {
return nil, err
}
}
if pp.DebugLinePrograms, err = readSectionAsBytes(reader, section, SectionOptDebugLinePrograms); err != nil {
return nil, err
if section == SectionOptDebugLinePrograms {
if section, pp.DebugLinePrograms, err = r.ReadSection(); err != nil {
return nil, err
}
}
if pp.DebugLineProgramRanges, err = readSectionAsBytes(reader, section, SectionOptDebugLineProgramRanges); err != nil {
return nil, err
if section == SectionOptDebugLineProgramRanges {
if section, pp.DebugLineProgramRanges, err = r.ReadSection(); err != nil {
return nil, err
}
}

for (*section & 0b10000000) != 0 {
for (section & 0b10000000) != 0 {
// We don't know this section, but it's optional, so just skip it.
log.Printf("Skipping unsupported optional section: %v", section)
sectionLength, err := readVariant(reader)
sectionLength, err := r.ReadVarint()
if err != nil {
return nil, err
}
discardBytes := make([]byte, sectionLength)
_, err = reader.Read(discardBytes)
_, err = r.Read(discardBytes)
if err != nil {
return nil, err
}
err = readByte(reader, section)
section, err = r.ReadByte()
if err != nil {
return nil, err
}
}
if *section != SectionEndOfFile {
return nil, fmt.Errorf("unexpected section: %v", *section)
if section != SectionEndOfFile {
return nil, fmt.Errorf("unexpected section: %v", section)
}
return pp, nil
}

func readSectionAsBytes(reader Reader, outSection *byte, expected byte) ([]byte, error) {
if *outSection != expected {
return nil, nil
}
func NewReader(r io.ReadSeeker) *Reader { return &Reader{r} }

type Reader struct{ io.ReadSeeker }

secLen, err := readVariant(reader)
func (r *Reader) ReadSection() (section byte, bytes []byte, err error) {
var secLen uint32
secLen, err = r.ReadVarint()
if err != nil {
return nil, err
return
}
bb := make([]byte, secLen)
_, err = reader.Read(bb)
if err != nil {
return nil, err
bytes = make([]byte, secLen)
if _, err = r.Read(bytes); err != nil {
return
}
err = readByte(reader, outSection)
if err != nil {
return nil, err
}
return bb, nil
section, err = r.ReadByte()
return
}

func readByte(reader Reader, section *byte) error {
func (r *Reader) ReadByte() (byte, error) {
b := make([]byte, 1)
_, err := reader.Read(b)
_, err := r.Read(b)
if err != nil {
return err
return 0, err
}
*section = b[0]
return nil
return b[0], nil
}

func readVariant(reader Reader) (uint32, error) {
firstByte := new(byte)
err := readByte(reader, firstByte)
func (r *Reader) ReadVarint() (uint32, error) {
firstByte, err := r.ReadByte()
if err != nil {
return 0, err
}
length := bits.LeadingZeros8(^*firstByte)
length := bits.LeadingZeros8(^firstByte)
var upperMask uint32 = 0b11111111 >> length
var upperBits = upperMask & uint32(*firstByte) << (length * 8)
var upperBits = upperMask & uint32(firstByte) << (length * 8)
if length == 0 {
return upperBits, nil
}
value := make([]byte, length)
n, err := reader.Read(value)
n, err := r.Read(value)
if err != nil {
return 0, err
}
Expand All @@ -253,3 +296,12 @@ func readVariant(reader Reader) (uint32, error) {
return 0, fmt.Errorf("invalid variant length: %d", n)
}
}

func (r *Reader) Position() int64 {
pos, err := r.Seek(0, io.SeekCurrent)
if err != nil {
panic(fmt.Sprintf("the current position should always be seekable: %v", err))
}

return pos
}
Loading

0 comments on commit a69a248

Please sign in to comment.