diff --git a/internal/polkavm/instructions.go b/internal/polkavm/instructions.go new file mode 100644 index 0000000..1dd26b6 --- /dev/null +++ b/internal/polkavm/instructions.go @@ -0,0 +1,391 @@ +package polkavm + +import ( + "encoding/binary" +) + +type InstructionCode byte + +const ( + Trap InstructionCode = 0 + Fallthrough InstructionCode = 17 + + JumpIndirect InstructionCode = 19 + LoadImm InstructionCode = 4 + LoadU8 InstructionCode = 60 + LoadI8 InstructionCode = 74 + LoadU16 InstructionCode = 76 + LoadI16 InstructionCode = 66 + LoadU32 InstructionCode = 10 + StoreU8 InstructionCode = 71 + StoreU16 InstructionCode = 69 + StoreU32 InstructionCode = 22 + + LoadImmAndJump InstructionCode = 6 + BranchEqImm InstructionCode = 7 + BranchNotEqImm InstructionCode = 15 + BranchLessUnsignedImm InstructionCode = 44 + BranchLessSignedImm InstructionCode = 32 + BranchGreaterOrEqualUnsignedImm InstructionCode = 52 + BranchGreaterOrEqualSignedImm InstructionCode = 45 + BranchLessOrEqualSignedImm InstructionCode = 46 + BranchLessOrEqualUnsignedImm InstructionCode = 59 + BranchGreaterSignedImm InstructionCode = 53 + BranchGreaterUnsignedImm InstructionCode = 50 + + StoreImmIndirectU8 InstructionCode = 26 + StoreImmIndirectU16 InstructionCode = 54 + StoreImmIndirectU32 InstructionCode = 13 + + StoreIndirectU8 InstructionCode = 16 + StoreIndirectU16 InstructionCode = 29 + StoreIndirectU32 InstructionCode = 3 + LoadIndirectU8 InstructionCode = 11 + LoadIndirectI8 InstructionCode = 21 + LoadIndirectU16 InstructionCode = 37 + LoadIndirectI16 InstructionCode = 33 + LoadIndirectU32 InstructionCode = 1 + AddImm InstructionCode = 2 + AndImm InstructionCode = 18 + XorImm InstructionCode = 31 + OrImm InstructionCode = 49 + MulImm InstructionCode = 35 + MulUpperSignedSignedImm InstructionCode = 65 + MulUpperUnsignedUnsignedImm InstructionCode = 63 + SetLessThanUnsignedImm InstructionCode = 27 + SetLessThanSignedImm InstructionCode = 56 + ShiftLogicalLeftImm InstructionCode = 9 + ShiftLogicalRightImm InstructionCode = 14 + ShiftArithmeticRightImm InstructionCode = 25 + NegateAndAddImm InstructionCode = 40 + SetGreaterThanUnsignedImm InstructionCode = 39 + SetGreaterThanSignedImm InstructionCode = 61 + ShiftLogicalRightImmAlt InstructionCode = 72 + ShiftArithmeticRightImmAlt InstructionCode = 80 + ShiftLogicalLeftImmAlt InstructionCode = 75 + + CmovIfZeroImm InstructionCode = 85 + CmovIfNotZeroImm InstructionCode = 86 + + BranchEq InstructionCode = 24 + BranchNotEq InstructionCode = 30 + BranchLessUnsigned InstructionCode = 47 + BranchLessSigned InstructionCode = 48 + BranchGreaterOrEqualUnsigned InstructionCode = 41 + BranchGreaterOrEqualSigned InstructionCode = 43 + + Add InstructionCode = 8 + Sub InstructionCode = 20 + And InstructionCode = 23 + Xor InstructionCode = 28 + Or InstructionCode = 12 + Mul InstructionCode = 34 + MulUpperSignedSigned InstructionCode = 67 + MulUpperUnsignedUnsigned InstructionCode = 57 + MulUpperSignedUnsigned InstructionCode = 81 + SetLessThanUnsigned InstructionCode = 36 + SetLessThanSigned InstructionCode = 58 + ShiftLogicalLeft InstructionCode = 55 + ShiftLogicalRight InstructionCode = 51 + ShiftArithmeticRight InstructionCode = 77 + DivUnsigned InstructionCode = 68 + DivSigned InstructionCode = 64 + RemUnsigned InstructionCode = 73 + RemSigned InstructionCode = 70 + + CmovIfZero InstructionCode = 83 + CmovIfNotZero InstructionCode = 84 + + Jump InstructionCode = 5 + + Ecalli InstructionCode = 78 + + StoreImmU8 InstructionCode = 62 + StoreImmU16 InstructionCode = 79 + StoreImmU32 InstructionCode = 38 + + MoveReg InstructionCode = 82 + Sbrk InstructionCode = 87 + + LoadImmAndJumpIndirect InstructionCode = 42 +) + +type Reg byte + +func (r Reg) String() string { + switch r { + case RA: + return "ra" + case SP: + return "sp" + case T0: + return "t0" + case T1: + return "t1" + case T2: + return "t2" + case S0: + return "s0" + case S1: + return "s1" + case A0: + return "a0" + case A1: + return "a1" + case A2: + return "a2" + case A3: + return "a3" + case A4: + return "a4" + case A5: + return "a5" + default: + panic("unreachable") + } +} + +const ( + RA Reg = 0 + SP Reg = 1 + T0 Reg = 2 + T1 Reg = 3 + T2 Reg = 4 + S0 Reg = 5 + S1 Reg = 6 + A0 Reg = 7 + A1 Reg = 8 + A2 Reg = 9 + A3 Reg = 10 + A4 Reg = 11 + A5 Reg = 12 +) + +func parseReg(v byte) Reg { + value := v & 0b1111 + if value > 12 { + value = 12 + } + switch value { + case 0: + return RA + case 1: + return SP + case 2: + return T0 + case 3: + return T1 + case 4: + return T2 + case 5: + return S0 + case 6: + return S1 + case 7: + return A0 + case 8: + return A1 + case 9: + return A2 + case 10: + return A3 + case 11: + return A4 + case 12: + return A5 + default: + panic("unreachable") + } +} + +var ( + // Instructions with args: none + instrNone = []InstructionCode{Trap, Fallthrough} + // Instructions with args: reg, imm + instrRegImm = []InstructionCode{JumpIndirect, LoadImm, LoadU8, LoadI8, LoadU16, LoadI16, LoadU32, StoreU8, StoreU16, StoreU32} + // Instructions with args: reg, imm, offset + instrRegImmOffset = []InstructionCode{LoadImmAndJump, BranchEqImm, BranchNotEqImm, BranchLessUnsignedImm, BranchLessSignedImm, BranchGreaterOrEqualUnsignedImm, BranchGreaterOrEqualSignedImm, BranchLessOrEqualSignedImm, BranchLessOrEqualUnsignedImm, BranchGreaterSignedImm, BranchGreaterUnsignedImm} + // Instructions with args: reg, imm, imm + instrRegImm2 = []InstructionCode{StoreImmIndirectU8, StoreImmIndirectU16, StoreImmIndirectU32} + // Instructions with args: reg, reg, imm + instrReg2Imm = []InstructionCode{StoreIndirectU8, StoreIndirectU16, StoreIndirectU32, LoadIndirectU8, LoadIndirectI8, LoadIndirectU16, LoadIndirectI16, LoadIndirectU32, AddImm, AndImm, XorImm, OrImm, MulImm, MulUpperSignedSignedImm, MulUpperUnsignedUnsignedImm, SetLessThanUnsignedImm, SetLessThanSignedImm, ShiftLogicalLeftImm, ShiftLogicalRightImm, ShiftArithmeticRightImm, NegateAndAddImm, SetGreaterThanUnsignedImm, SetGreaterThanSignedImm, ShiftLogicalRightImmAlt, ShiftArithmeticRightImmAlt, ShiftLogicalLeftImmAlt, CmovIfZeroImm, CmovIfNotZeroImm} + // Instructions with args: reg, reg, offset + instrReg2Offset = []InstructionCode{BranchEq, BranchNotEq, BranchLessUnsigned, BranchLessSigned, BranchGreaterOrEqualUnsigned, BranchGreaterOrEqualSigned} + // Instructions with args: reg, reg, reg + instrReg3 = []InstructionCode{Add, Sub, And, Xor, Or, Mul, MulUpperSignedSigned, MulUpperUnsignedUnsigned, MulUpperSignedUnsigned, SetLessThanUnsigned, SetLessThanSigned, ShiftLogicalLeft, ShiftLogicalRight, ShiftArithmeticRight, DivUnsigned, DivSigned, RemUnsigned, RemSigned, CmovIfZero, CmovIfNotZero} + // Instructions with args: offset + instrOffset = []InstructionCode{Jump} + // Instructions with args: imm + instrImm = []InstructionCode{Ecalli} + // Instructions with args: imm, imm + instrImm2 = []InstructionCode{StoreImmU8, StoreImmU16, StoreImmU32} + // Instructions with args: reg, reg + instrRegReg = []InstructionCode{MoveReg, Sbrk} + // Instructions with args: reg, reg, imm, imm + instrReg2Imm2 = []InstructionCode{LoadImmAndJumpIndirect} +) + +type InstrParseArgFunc func(chunk []byte, instructionOffset, argsLength uint32) ([]Reg, []uint32) + +var parseArgsTable = map[InstructionCode]InstrParseArgFunc{} + +func init() { + for _, code := range instrNone { + parseArgsTable[code] = parseArgsNone + } + for _, code := range instrRegImm { + parseArgsTable[code] = parseArgsRegImm + } + for _, code := range instrRegImmOffset { + parseArgsTable[code] = parseArgsRegImmOffset + } + for _, code := range instrRegImm2 { + parseArgsTable[code] = parseArgsRegImm2 + } + for _, code := range instrReg2Imm { + parseArgsTable[code] = parseArgsRegs2Imm + } + for _, code := range instrReg2Offset { + parseArgsTable[code] = parseArgsRegs2Offset + } + for _, code := range instrReg3 { + parseArgsTable[code] = parseArgsRegs3 + } + for _, code := range instrOffset { + parseArgsTable[code] = parseArgsImmOffset + } + for _, code := range instrImm { + parseArgsTable[code] = parseArgsImm + } + for _, code := range instrImm2 { + parseArgsTable[code] = parseArgsImm2 + } + for _, code := range instrRegReg { + parseArgsTable[code] = parseArgsRegs2 + } + for _, code := range instrReg2Imm2 { + parseArgsTable[code] = parseArgsRegs2Imm2 + } +} + +func clamp(start, end, value uint32) uint32 { + if value < start { + return start + } else if value > end { + return end + } else { + return value + } +} + +func sext(value uint32, length uint32) uint32 { + switch length { + case 0: + return 0 + case 1: + return uint32(int32(int8(uint8(value)))) + case 2: + return uint32(int32(int16(uint16(value)))) + case 3: + return uint32((int32(value << 8)) >> 8) + case 4: + return value + default: + panic("unreachable") + } +} +func read(slice []byte, offset, length uint32) uint32 { + slice = slice[offset : offset+length] + switch length { + case 0: + return 0 + case 1: + return uint32(slice[0]) + case 2: + return uint32(binary.LittleEndian.Uint16(slice[:1])) // u16::from_le_bytes([slice[0], slice[1]]) as u32 + case 3: + return binary.LittleEndian.Uint32([]byte{slice[0], slice[1], slice[2], 0}) + case 4: + return binary.LittleEndian.Uint32([]byte{slice[0], slice[1], slice[2], slice[3]}) + default: + panic("unreachable") + } +} + +func parseArgsImm(code []byte, _, skip uint32) ([]Reg, []uint32) { + immLength := min(4, skip) + return nil, []uint32{sext(read(code, 0, immLength), immLength)} +} + +func parseArgsImmOffset(code []byte, instructionOffset, skip uint32) ([]Reg, []uint32) { + _, imm := parseArgsImm(code, instructionOffset, skip) + return nil, []uint32{instructionOffset + imm[0]} +} + +func parseArgsImm2(code []byte, _, skip uint32) ([]Reg, []uint32) { + imm1Length := min(4, uint32(code[0])&0b111) + imm2Length := clamp(0, 4, skip-imm1Length-1) + imm1 := sext(read(code, 1, imm1Length), imm1Length) + imm2 := sext(read(code, 1+imm1Length, imm2Length), imm2Length) + return nil, []uint32{imm1, imm2} +} + +func parseArgsNone(_ []byte, _, _ uint32) ([]Reg, []uint32) { + return nil, nil +} + +func parseArgsRegImm(code []byte, _, skip uint32) ([]Reg, []uint32) { + reg := min(12, code[0]&0b1111) + immLength := clamp(0, 4, skip-1) + imm := sext(read(code, 1, immLength), immLength) + return []Reg{parseReg(reg)}, []uint32{imm} +} + +func parseArgsRegImmOffset(code []byte, instructionOffset, skip uint32) ([]Reg, []uint32) { + regs, imm := parseArgsRegImm2(code, instructionOffset, skip) + return regs, []uint32{imm[0], instructionOffset + imm[1]} +} + +func parseArgsRegImm2(code []byte, _, skip uint32) ([]Reg, []uint32) { + reg := min(12, code[0]&0b1111) + imm1Length := min(4, uint32(code[0]>>4)&0b111) + imm2Length := clamp(0, 4, skip-imm1Length-1) + imm1 := sext(read(code, 1, imm1Length), imm1Length) + imm2 := sext(read(code, 1+imm1Length, imm2Length), imm2Length) + return []Reg{parseReg(reg)}, []uint32{imm1, imm2} +} + +func parseArgsRegs2Imm2(code []byte, _, skip uint32) ([]Reg, []uint32) { + reg1 := min(12, code[0]&0b1111) + reg2 := min(12, code[0]>>4) + imm1Length := min(4, uint32(code[1])&0b111) + imm2Length := clamp(0, 4, skip-imm1Length-2) + imm1 := sext(read(code, 2, imm1Length), imm1Length) + imm2 := sext(read(code, 2+imm1Length, imm2Length), imm2Length) + return []Reg{parseReg(reg1), parseReg(reg2)}, []uint32{imm1, imm2} +} +func parseArgsRegs2Imm(code []byte, _, skip uint32) ([]Reg, []uint32) { + immLength := clamp(0, 4, uint32(skip)-1) + imm := sext(read(code, 1, immLength), immLength) + return []Reg{ + parseReg(min(12, code[0]&0b1111)), + parseReg(min(12, code[0]>>4)), + }, []uint32{imm} +} + +func parseArgsRegs3(code []byte, _, _ uint32) ([]Reg, []uint32) { + return []Reg{ + parseReg(min(12, code[1]&0b1111)), + parseReg(min(12, code[0]&0b1111)), + parseReg(min(12, code[0]>>4)), + }, nil +} + +func parseArgsRegs2(code []byte, _, _ uint32) ([]Reg, []uint32) { + return []Reg{parseReg(min(12, code[0]&0b1111)), parseReg(min(12, code[0]>>4))}, nil +} + +func parseArgsRegs2Offset(code []byte, instructionOffset, skip uint32) ([]Reg, []uint32) { + regs, imm := parseArgsRegs2Imm(code, instructionOffset, skip) + return regs, []uint32{instructionOffset + imm[0]} +} diff --git a/internal/polkavm/program.go b/internal/polkavm/program.go new file mode 100644 index 0000000..aab8645 --- /dev/null +++ b/internal/polkavm/program.go @@ -0,0 +1,482 @@ +package polkavm + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "log" + "math/bits" +) + +// BlobMagic The magic bytes with which every program blob must start with. +var BlobMagic = [4]byte{'P', 'V', 'M', 0} + +// program blob sections +const ( + SectionMemoryConfig byte = 1 + SectionROData byte = 2 + SectionRWData byte = 3 + SectionImports byte = 4 + SectionExports byte = 5 + SectionCodeAndJumpTable byte = 6 + SectionOptDebugStrings byte = 128 + SectionOptDebugLinePrograms byte = 129 + SectionOptDebugLineProgramRanges byte = 130 + SectionEndOfFile byte = 0 + + BlobVersionV1 byte = 1 + VersionDebugLineProgramV1 byte = 1 + + VmMaximumJumpTableEntries uint32 = 16 * 1024 * 1024 + VmMaximumImportCount uint32 = 1024 // The maximum number of functions the program can import. + VmMaximumCodeSize uint32 = 32 * 1024 * 1024 + VmCodeAddressAlignment uint32 = 2 + BitmaskMax = 24 +) + +type Program struct { + RODataSize uint32 + RWDataSize uint32 + StackSize uint32 + ROData []byte + RWData []byte + JumpTable []uint32 + Instructions []Instruction + Imports []string + Exports []ProgramExport + DebugStrings []byte + DebugLineProgramRanges []byte + DebugLinePrograms []byte +} + +type ProgramExport struct { + TargetCodeOffset uint32 + Symbol string +} + +type Instruction struct { + Code InstructionCode + Imm []uint32 + Reg []Reg + Offset uint32 + Length uint32 +} + +func ParseBlob(r *Reader) (pp *Program, err error) { + magic := make([]byte, len(BlobMagic)) + if _, err = r.Read(magic); err != nil { + return nil, err + } + if !bytes.Equal(magic, BlobMagic[:]) { + return pp, fmt.Errorf("blob doesn't start with the expected magic bytes") + } + blobVersion, err := r.ReadByte() + if err != nil { + return nil, err + } + if blobVersion != BlobVersionV1 { + return pp, fmt.Errorf("unsupported version: %d", blobVersion) + } + + pp = &Program{} + section, err := r.ReadByte() + if err != nil { + return nil, err + } + if section == SectionMemoryConfig { + if section, err = parseMemoryConfig(r, pp); err != nil { + return nil, err + } + } + if section == SectionROData { + if pp.ROData, err = r.ReadWithLength(); err != nil { + return nil, err + } + if section, err = r.ReadByte(); err != nil { + return nil, err + } + } + if section == SectionRWData { + if pp.RWData, err = r.ReadWithLength(); err != nil { + return nil, err + } + if section, err = r.ReadByte(); err != nil { + return nil, err + } + } + if section == SectionImports { + if section, err = parseImports(r, pp); err != nil { + return nil, err + } + } + + if section == SectionExports { + if section, err = parseExports(r, pp); err != nil { + return nil, err + } + } + if section == SectionCodeAndJumpTable { + if section, err = parseCodeAndJumpTable(r, pp); err != nil { + return nil, err + } + } + if section == SectionOptDebugStrings { + if pp.DebugStrings, err = r.ReadWithLength(); err != nil { + return nil, err + } + if section, err = r.ReadByte(); err != nil { + return nil, err + } + } + if section == SectionOptDebugLinePrograms { + if pp.DebugLinePrograms, err = r.ReadWithLength(); err != nil { + return nil, err + } + if section, err = r.ReadByte(); err != nil { + return nil, err + } + } + if section == SectionOptDebugLineProgramRanges { + if pp.DebugLineProgramRanges, err = r.ReadWithLength(); err != nil { + return nil, err + } + if section, err = r.ReadByte(); err != nil { + return nil, err + } + } + + for (section & 0b10000000) != 0 { + // We don't know this section, but it's optional, so just skip it. + log.Printf("Skipping unsupported optional section: %v", section) + sectionLength, err := r.ReadVarint() + if err != nil { + return nil, err + } + discardBytes := make([]byte, sectionLength) + _, err = r.Read(discardBytes) + if err != nil { + return nil, err + } + section, err = r.ReadByte() + if err != nil { + return nil, err + } + } + if section != SectionEndOfFile { + return nil, fmt.Errorf("unexpected section: %v", section) + } + return pp, nil +} + +func parseMemoryConfig(r *Reader, p *Program) (byte, error) { + secLen, err := r.ReadVarint() + if err != nil { + return 0, err + } + pos := r.Position() + + if p.RODataSize, err = r.ReadVarint(); err != nil { + return 0, err + } + if p.RWDataSize, err = r.ReadVarint(); err != nil { + return 0, err + } + if p.StackSize, err = r.ReadVarint(); err != nil { + return 0, err + } + if pos+int64(secLen) != r.Position() { + return 0, fmt.Errorf("the memory config section contains more data than expected %v %v", pos+int64(secLen), r.Position()) + } + + return r.ReadByte() +} + +func parseImports(r *Reader, p *Program) (byte, error) { + secLen, err := r.ReadVarint() + if err != nil { + return 0, err + } + posStart := r.Position() + importCount, err := r.ReadVarint() + if err != nil { + return 0, err + } + if importCount > VmMaximumImportCount { + return 0, fmt.Errorf("too many imports") + } + //TODO check for underflow and overflow? + importOffsetsSize := importCount * 4 + importOffsets := make([]byte, importOffsetsSize) + _, err = r.Read(importOffsets) + if err != nil { + return 0, err + } + + //TODO check for underflow? + importSymbolsSize := secLen - uint32(r.Position()-posStart) + importSymbols := make([]byte, importSymbolsSize) + _, err = r.Read(importSymbols) + if err != nil { + return 0, err + } + + if len(importOffsets)%4 != 0 { + return 0, fmt.Errorf("invalid import offsets data: %d", len(importOffsets)) + } + var offsets []uint32 + for i := 0; i < len(importOffsets); i += 4 { + offsets = append(offsets, binary.BigEndian.Uint32(importOffsets[i:i+4])) + } + for i := 0; i < len(offsets); i += 2 { + if i+1 == len(offsets) { + p.Imports = append(p.Imports, string(importSymbols[offsets[i]:])) + continue + } + p.Imports = append(p.Imports, string(importSymbols[offsets[i]:offsets[i+1]])) + } + + return r.ReadByte() +} + +func parseCodeAndJumpTable(r *Reader, p *Program) (byte, error) { + secLen, err := r.ReadVarint() + if err != nil { + return 0, err + } + initialPosition := r.Position() + jumpTableEntryCount, err := r.ReadVarint() + if err != nil { + return 0, err + } + if jumpTableEntryCount > VmMaximumJumpTableEntries { + return 0, fmt.Errorf("the jump table section is too long") + } + jumpTableEntrySize, err := r.ReadByte() + if err != nil { + return 0, err + } + codeLength, err := r.ReadVarint() + if err != nil { + return 0, err + } + if codeLength > VmMaximumCodeSize { + return 0, fmt.Errorf("the code section is too long") + } + if jumpTableEntrySize > 4 { + return 0, fmt.Errorf("invalid jump table entry size") + } + + //TODO check for underflow and overflow? + jumpTableLength := jumpTableEntryCount * uint32(jumpTableEntrySize) + + jumpTable := make([]byte, jumpTableLength) + if _, err = r.Read(jumpTable); err != nil { + return 0, err + } + for i := 0; i < len(jumpTable); i += int(jumpTableEntrySize) { + switch jumpTableEntrySize { + case 1: + p.JumpTable = append(p.JumpTable, uint32(jumpTable[i])) + case 2: + p.JumpTable = append(p.JumpTable, uint32(binary.BigEndian.Uint16(jumpTable[i:i+2]))) + case 3: + p.JumpTable = append(p.JumpTable, binary.BigEndian.Uint32( + []byte{jumpTable[i], jumpTable[i+1], jumpTable[i+2], 0}, + )) + case 4: + p.JumpTable = append(p.JumpTable, binary.BigEndian.Uint32(jumpTable[i:i+int(jumpTableEntrySize)])) + default: + panic("unreachable") + } + } + + code := make([]byte, codeLength) + if _, err = r.Read(code); err != nil { + return 0, err + } + + bitmaskLength := secLen - uint32(r.Position()-initialPosition) + bitmask := make([]byte, bitmaskLength) + if _, err = r.Read(bitmask); err != nil { + return 0, err + } + expectedBitmaskLength := codeLength / 8 + if codeLength%8 != 0 { + expectedBitmaskLength += 1 + } + + if bitmaskLength != expectedBitmaskLength { + return 0, fmt.Errorf("the bitmask Length doesn't match the code Length") + } + + offset := 0 + for offset < len(code) { + nextOffset, instr, err := parseInstruction(code, bitmask, offset) + if err != nil { + return 0, err + } + p.Instructions = append(p.Instructions, instr) + offset = nextOffset + } + return r.ReadByte() +} + +func parseInstruction(code, bitmask []byte, instructionOffset int) (int, Instruction, error) { + if len(bitmask) == 0 { + return 0, Instruction{}, io.EOF + } + + nextOffset, argsLength := parseBitmaskSlow(bitmask, instructionOffset) + chunkLength := min(16, argsLength+1) + chunk := code[instructionOffset : instructionOffset+chunkLength] + opcode := InstructionCode(chunk[0]) + regs, imm := parseArgsTable[opcode](chunk[1:], uint32(instructionOffset), uint32(argsLength)) + return nextOffset, Instruction{ + Code: opcode, + Reg: regs, + Imm: imm, + Offset: uint32(instructionOffset), + Length: uint32(argsLength + 1), + }, nil +} + +//lint:ignore U1000 +func parseBitmaskFast(bitmask []byte, offset int) (int, int) { + offset += 1 + shift := offset & 7 + mask := (binary.LittleEndian.Uint32(bitmask[offset>>3:(offset>>3)+4]) >> shift) | (1 << BitmaskMax) + argsLength := bits.TrailingZeros32(mask) + if argsLength > BitmaskMax { + panic("args Length too big") + } + offset += argsLength + + return offset, argsLength +} + +func parseBitmaskSlow(bitmask []byte, offset int) (int, int) { + offset += 1 + argsLength := 0 + for offset>>3 < len(bitmask) { + b := bitmask[offset>>3] + shift := offset & 7 + mask := b >> shift + length := 0 + if mask == 0 { + length = 8 - shift + } else { + length = bits.TrailingZeros(uint(mask)) + if length == 0 { + break + } + } + + newArgsLength := argsLength + length + if newArgsLength >= BitmaskMax { + offset += BitmaskMax - argsLength + argsLength = BitmaskMax + break + } + + argsLength = newArgsLength + offset += length + } + + return offset, argsLength +} + +func parseExports(r *Reader, p *Program) (byte, error) { + var secLen uint32 + secLen, err := r.ReadVarint() + if err != nil { + return 0, err + } + initialPosition := r.Position() + nr, err := r.ReadVarint() + if err != nil { + return 0, err + } + for i := 0; i < int(nr); i++ { + targetCodeOffset, err := r.ReadVarint() + if err != nil { + return 0, err + } + symbol, err := r.ReadWithLength() + if err != nil { + return 0, err + } + + p.Exports = append(p.Exports, ProgramExport{ + TargetCodeOffset: targetCodeOffset, + Symbol: string(symbol), + }) + } + + if initialPosition+int64(secLen) != r.Position() { + return 0, fmt.Errorf("invalid exports section Length: %v", secLen) + } + + return r.ReadByte() +} + +func NewReader(r io.ReadSeeker) *Reader { return &Reader{r} } + +type Reader struct{ io.ReadSeeker } + +func (r *Reader) ReadWithLength() ([]byte, error) { + length, err := r.ReadVarint() + if err != nil { + return nil, err + } + bytes := make([]byte, length) + if _, err = r.Read(bytes); err != nil { + return nil, err + } + return bytes, nil +} + +func (r *Reader) ReadByte() (byte, error) { + b := make([]byte, 1) + _, err := r.Read(b) + return b[0], err +} + +func (r *Reader) ReadVarint() (uint32, error) { + firstByte, err := r.ReadByte() + if err != nil { + return 0, err + } + length := bits.LeadingZeros8(^firstByte) + var upperMask uint32 = 0b11111111 >> length + var upperBits = upperMask & uint32(firstByte) << (length * 8) + if length == 0 { + return upperBits, nil + } + value := make([]byte, length) + n, err := r.Read(value) + if err != nil { + return 0, err + } + switch n { + case 1: + return upperBits | uint32(value[0]), nil + case 2: + return upperBits | uint32(binary.BigEndian.Uint16(value)), nil + case 3: + return upperBits | binary.BigEndian.Uint32([]byte{value[0], value[1], value[2], 0}), nil + case 4: + return upperBits | binary.BigEndian.Uint32(value), nil + default: + return 0, fmt.Errorf("invalid varint Length: %d", n) + } +} + +func (r *Reader) Position() int64 { + pos, err := r.Seek(0, io.SeekCurrent) + if err != nil { + panic(fmt.Sprintf("the current position should always be seekable: %v", err)) + } + + return pos +} diff --git a/internal/polkavm/program_test.go b/internal/polkavm/program_test.go new file mode 100644 index 0000000..56443bc --- /dev/null +++ b/internal/polkavm/program_test.go @@ -0,0 +1,137 @@ +package polkavm + +import ( + "embed" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +//go:embed testdata +var fs embed.FS + +func Test_ParseBlob(t *testing.T) { + f, err := fs.Open("testdata/example-hello-world.polkavm") + if err != nil { + t.Fatal(err) + } + defer f.Close() + pp, err := ParseBlob(NewReader(f.(io.ReadSeeker))) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, uint32(0), pp.RODataSize) + assert.Equal(t, uint32(0), pp.RWDataSize) + assert.Equal(t, uint32(4096), pp.StackSize) + if assert.Equal(t, 10, len(pp.Instructions)) { + assert.Equal(t, Instruction{ + Code: AddImm, + Imm: []uint32{4294967288}, + Reg: []Reg{SP, SP}, + Offset: 0, Length: 3, + }, pp.Instructions[0]) + + assert.Equal(t, Instruction{ + Code: StoreIndirectU32, + Imm: []uint32{4}, + Reg: []Reg{RA, SP}, + Offset: 3, + Length: 3, + }, pp.Instructions[1]) + assert.Equal(t, Instruction{ + Code: StoreIndirectU32, + Imm: []uint32{0}, + Reg: []Reg{S0, SP}, + Offset: 6, + Length: 2, + }, pp.Instructions[2]) + assert.Equal(t, Instruction{ + Code: Add, + Reg: []Reg{S0, A1, A0}, + Offset: 8, + Length: 3, + }, pp.Instructions[3]) + assert.Equal(t, Instruction{ + Code: Ecalli, + Imm: []uint32{0}, + Offset: 11, + Length: 1, + }, pp.Instructions[4]) + assert.Equal(t, Instruction{ + Code: Add, + Imm: nil, + Reg: []Reg{A0, A0, S0}, + Offset: 12, + Length: 3, + }, pp.Instructions[5]) + assert.Equal(t, Instruction{ + Code: LoadIndirectU32, + Imm: []uint32{4}, + Reg: []Reg{RA, SP}, + Offset: 15, + Length: 3, + }, pp.Instructions[6]) + assert.Equal(t, Instruction{ + Code: LoadIndirectU32, + Imm: []uint32{0}, + Reg: []Reg{S0, SP}, + Offset: 18, + Length: 2, + }, pp.Instructions[7]) + assert.Equal(t, Instruction{ + Code: AddImm, + Imm: []uint32{8}, + Reg: []Reg{SP, SP}, + Offset: 20, + Length: 3, + }, pp.Instructions[8]) + assert.Equal(t, Instruction{ + Code: JumpIndirect, + Imm: []uint32{0}, + Reg: []Reg{RA}, + Offset: 23, + Length: 2, + }, pp.Instructions[9]) + } + assert.Equal(t, []string{"get_third_number"}, pp.Imports) + assert.Equal(t, []ProgramExport{{0, "add_numbers"}}, pp.Exports) +} + +func Test_parseBitmaskFast(t *testing.T) { + table := []struct { + bitmask []byte + offset, nextOffset, argsLength int + }{ + {[]byte{0b00000011, 0, 0, 0}, 0, 1, 0}, + {[]byte{0b00000101, 0, 0, 0}, 0, 2, 1}, + {[]byte{0b10000001, 0, 0, 0}, 0, 7, 6}, + {[]byte{0b00000001, 1, 0, 0}, 0, 8, 7}, + {[]byte{0b00000001, 1 << 7, 0, 0}, 0, 15, 14}, + {[]byte{0b00000001, 0, 1, 0}, 0, 16, 15}, + {[]byte{0b00000001, 0, 1 << 7, 0}, 0, 23, 22}, + {[]byte{0b00000001, 0, 0, 1}, 0, 24, 23}, + + {[]byte{0b11000000, 0, 0, 0, 0}, 6, 7, 0}, + {[]byte{0b01000000, 1, 0, 0, 0}, 6, 8, 1}, + + {[]byte{0b10000000, 1, 0, 0, 0}, 7, 8, 0}, + {[]byte{0b10000000, 1 << 1, 0, 0, 0}, 7, 9, 1}, + + {[]byte{0, 0, 0, 0, 0b00000001}, 0, 25, 24}, + {[]byte{0, 0, 0, 0, 0b00000001}, 6, 31, 24}, + {[]byte{0, 0, 0, 0, 0b00000001}, 7, 32, 24}, + } + for i, tc := range table { + nextOffset, argsLength := parseBitmaskFast(tc.bitmask, tc.offset) + nextOffsetSlow, argsLengthSlow := parseBitmaskSlow(tc.bitmask, tc.offset) + + assert.Equal(t, nextOffset, nextOffsetSlow, "index: %d", i) + assert.Equal(t, argsLength, argsLengthSlow, "index: %d", i) + + assert.Equal(t, tc.nextOffset, nextOffset, "index: %d", i) + assert.Equal(t, tc.argsLength, argsLength, "index: %d", i) + } +} + +//JumpIndirect, LoadImm, LoadU8, LoadI8, LoadU16, LoadI16, LoadU32, StoreU8, StoreU16, StoreU32 diff --git a/internal/polkavm/testdata/example-hello-world.polkavm b/internal/polkavm/testdata/example-hello-world.polkavm new file mode 100644 index 0000000..f6d11ab Binary files /dev/null and b/internal/polkavm/testdata/example-hello-world.polkavm differ