From b58ef82592b8968e2de3cb0fa1b2d3bd80944c2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20G=C3=B6rgens?= Date: Fri, 13 Dec 2024 12:29:11 +0800 Subject: [PATCH] Implement the TODOs in the Harvard Architecture PR (#711) In the [Harvard Architecture PR](https://github.com/scroll-tech/ceno/pull/688) we left a few TODOs. Here we make good on them. There's a few things happening in this PR: - We remove the fake instruction `EANY`. `EANY` was only ever an artifact of Risc0's incredible 'interesting' approach to decoding. We use the real instructions `ECALL` and `EBREAK` instead. - Remove `AUIPC` and `LUI`. Both of them are now implemented as pseudo-instructions, that translate to `ADDI` during decoding. - Use the same library as SP1 for RiscV decoding, instead of copy-and-pasting-and-editing Risc0's 'interesting' decoder. That simplifies our code, and comes with a lot more tests than we ever had. Both because of explicit tests in the library, and because of the usage in SP1 and other projects. This gets rid of much error prone bit manipulating code. - Use `struct Instruction` throughout the code when handling and testing instructions, instead of `u32`. That makes specifying tests a lot simpler and more readable. No more `0b_000000001010_00000_000_00001_0010011, // addi x1, x0, 10` in the code. - Remove the notion of executable vs non-executable ROM. This is only necessary for a von-Neumann architecture: everything that's in our instruction-cache is meant to be executable already. (We can re-implement this restriction later by controlling what is allowed to make it into the instruction cache when we eg decode the ELF. But it's unnecessary: we already honour the executable flag for memory sections in the ELF.) --- Cargo.lock | 75 ++ ceno_emul/Cargo.toml | 1 + ceno_emul/src/disassemble/mod.rs | 378 ++++++++ ceno_emul/src/elf.rs | 8 +- ceno_emul/src/lib.rs | 7 +- ceno_emul/src/platform.rs | 5 - ceno_emul/src/rv32im.rs | 895 ++++++------------ ceno_emul/src/rv32im_encode.rs | 116 --- ceno_emul/src/tracer.rs | 46 +- ceno_emul/src/vm_state.rs | 20 +- ceno_emul/tests/test_elf.rs | 2 +- ceno_emul/tests/test_vm_trace.rs | 61 +- ceno_zkvm/examples/riscv_opcodes.rs | 57 +- ceno_zkvm/src/e2e.rs | 9 +- ceno_zkvm/src/instructions/riscv.rs | 5 +- ceno_zkvm/src/instructions/riscv/arith_imm.rs | 7 +- .../src/instructions/riscv/branch/test.rs | 16 +- .../instructions/riscv/dummy/dummy_circuit.rs | 18 +- .../src/instructions/riscv/dummy/test.rs | 2 +- ceno_zkvm/src/instructions/riscv/ecall.rs | 2 +- .../src/instructions/riscv/ecall_insn.rs | 4 +- ceno_zkvm/src/instructions/riscv/jump.rs | 4 - .../src/instructions/riscv/jump/auipc.rs | 94 -- ceno_zkvm/src/instructions/riscv/jump/lui.rs | 64 -- ceno_zkvm/src/instructions/riscv/jump/test.rs | 78 +- .../riscv/logic_imm/logic_imm_circuit.rs | 4 +- .../src/instructions/riscv/memory/test.rs | 53 +- ceno_zkvm/src/instructions/riscv/rv32im.rs | 40 +- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 8 +- ceno_zkvm/src/instructions/riscv/slti.rs | 2 +- .../src/instructions/riscv/test_utils.rs | 23 - ceno_zkvm/src/instructions/riscv/u_insn.rs | 65 -- ceno_zkvm/src/scheme/mock_prover.rs | 35 +- ceno_zkvm/src/scheme/tests.rs | 43 +- ceno_zkvm/src/tables/program.rs | 65 +- clippy.toml | 7 +- 36 files changed, 959 insertions(+), 1360 deletions(-) create mode 100644 ceno_emul/src/disassemble/mod.rs delete mode 100644 ceno_emul/src/rv32im_encode.rs delete mode 100644 ceno_zkvm/src/instructions/riscv/jump/auipc.rs delete mode 100644 ceno_zkvm/src/instructions/riscv/jump/lui.rs delete mode 100644 ceno_zkvm/src/instructions/riscv/test_utils.rs delete mode 100644 ceno_zkvm/src/instructions/riscv/u_insn.rs diff --git a/Cargo.lock b/Cargo.lock index d84685eb4..865b783f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -238,6 +238,7 @@ dependencies = [ "itertools 0.13.0", "num-derive", "num-traits", + "rrs-succinct", "strum", "strum_macros", "tracing", @@ -583,6 +584,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "downcast-rs" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" + [[package]] name = "either" version = "1.13.0" @@ -1194,6 +1201,27 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_enum" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f646caf906c20226733ed5b1374287eb97e3c2a5c227ce668c1f2ce20ae57c9" +dependencies = [ + "num_enum_derive", +] + +[[package]] +name = "num_enum_derive" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcbff9bc912032c62bf65ef1d5aea88983b420f4f839db1e9b0c281a25c9c799" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "object" version = "0.36.5" @@ -1424,6 +1452,16 @@ dependencies = [ "uint", ] +[[package]] +name = "proc-macro-crate" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4c021e1093a56626774e81216a4ce732a735e5bad4868a03f3ed65ca0c3919" +dependencies = [ + "once_cell", + "toml_edit", +] + [[package]] name = "proc-macro2" version = "1.0.92" @@ -1611,6 +1649,17 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8188909339ccc0c68cfb5a04648313f09621e8b87dc03095454f1a11f6c5d436" +[[package]] +name = "rrs-succinct" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3372685893a9f67d18e98e792d690017287fd17379a83d798d958e517d380fa9" +dependencies = [ + "downcast-rs", + "num_enum", + "paste", +] + [[package]] name = "rustc-demangle" version = "0.1.24" @@ -1927,6 +1976,23 @@ dependencies = [ "serde_json", ] +[[package]] +name = "toml_datetime" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" + +[[package]] +name = "toml_edit" +version = "0.19.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" +dependencies = [ + "indexmap", + "toml_datetime", + "winnow", +] + [[package]] name = "tracing" version = "0.1.41" @@ -2345,6 +2411,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876" +dependencies = [ + "memchr", +] + [[package]] name = "wyz" version = "0.5.1" diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index 38f0a8bfd..74562142a 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -15,6 +15,7 @@ elf = "0.7" itertools.workspace = true num-derive.workspace = true num-traits.workspace = true +rrs_lib = { package = "rrs-succinct", version = "0.1.0" } strum.workspace = true strum_macros.workspace = true tracing.workspace = true diff --git a/ceno_emul/src/disassemble/mod.rs b/ceno_emul/src/disassemble/mod.rs new file mode 100644 index 000000000..ca6161000 --- /dev/null +++ b/ceno_emul/src/disassemble/mod.rs @@ -0,0 +1,378 @@ +use crate::rv32im::{InsnKind, Instruction}; +use itertools::izip; +use rrs_lib::{ + InstructionProcessor, + instruction_formats::{BType, IType, ITypeCSR, ITypeShamt, JType, RType, SType, UType}, + process_instruction, +}; + +/// A transpiler that converts the 32-bit encoded instructions into instructions. +pub(crate) struct InstructionTranspiler { + pc: u32, + word: u32, +} + +impl Instruction { + /// Create a new [`Instruction`] from an R-type instruction. + #[must_use] + pub const fn from_r_type(kind: InsnKind, dec_insn: &RType, raw: u32) -> Self { + Self { + kind, + rd: dec_insn.rd, + rs1: dec_insn.rs1, + rs2: dec_insn.rs2, + imm: 0, + raw, + } + } + + /// Create a new [`Instruction`] from an I-type instruction. + #[must_use] + pub const fn from_i_type(kind: InsnKind, dec_insn: &IType, raw: u32) -> Self { + Self { + kind, + rd: dec_insn.rd, + rs1: dec_insn.rs1, + imm: dec_insn.imm, + rs2: 0, + raw, + } + } + + /// Create a new [`Instruction`] from an I-type instruction with a shamt. + #[must_use] + pub const fn from_i_type_shamt(kind: InsnKind, dec_insn: &ITypeShamt, raw: u32) -> Self { + Self { + kind, + rd: dec_insn.rd, + rs1: dec_insn.rs1, + imm: dec_insn.shamt as i32, + rs2: 0, + raw, + } + } + + /// Create a new [`Instruction`] from an S-type instruction. + #[must_use] + pub const fn from_s_type(kind: InsnKind, dec_insn: &SType, raw: u32) -> Self { + Self { + kind, + rd: 0, + rs1: dec_insn.rs1, + rs2: dec_insn.rs2, + imm: dec_insn.imm, + raw, + } + } + + /// Create a new [`Instruction`] from a B-type instruction. + #[must_use] + pub const fn from_b_type(kind: InsnKind, dec_insn: &BType, raw: u32) -> Self { + Self { + kind, + rd: 0, + rs1: dec_insn.rs1, + rs2: dec_insn.rs2, + imm: dec_insn.imm, + raw, + } + } + + /// Create a new [`Instruction`] that is not implemented. + #[must_use] + pub const fn unimp(raw: u32) -> Self { + Self { + kind: InsnKind::INVALID, + rd: 0, + rs1: 0, + rs2: 0, + imm: 0, + raw, + } + } +} + +impl InstructionProcessor for InstructionTranspiler { + type InstructionResult = Instruction; + + fn process_add(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::ADD, &dec_insn, self.word) + } + + fn process_addi(&mut self, dec_insn: IType) -> Self::InstructionResult { + Instruction::from_i_type(InsnKind::ADDI, &dec_insn, self.word) + } + + fn process_sub(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::SUB, &dec_insn, self.word) + } + + fn process_xor(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::XOR, &dec_insn, self.word) + } + + fn process_xori(&mut self, dec_insn: IType) -> Self::InstructionResult { + Instruction::from_i_type(InsnKind::XORI, &dec_insn, self.word) + } + + fn process_or(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::OR, &dec_insn, self.word) + } + + fn process_ori(&mut self, dec_insn: IType) -> Self::InstructionResult { + Instruction::from_i_type(InsnKind::ORI, &dec_insn, self.word) + } + + fn process_and(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::AND, &dec_insn, self.word) + } + + fn process_andi(&mut self, dec_insn: IType) -> Self::InstructionResult { + Instruction::from_i_type(InsnKind::ANDI, &dec_insn, self.word) + } + + fn process_sll(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::SLL, &dec_insn, self.word) + } + + fn process_slli(&mut self, dec_insn: ITypeShamt) -> Self::InstructionResult { + Instruction::from_i_type_shamt(InsnKind::SLLI, &dec_insn, self.word) + } + + fn process_srl(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::SRL, &dec_insn, self.word) + } + + fn process_srli(&mut self, dec_insn: ITypeShamt) -> Self::InstructionResult { + Instruction::from_i_type_shamt(InsnKind::SRLI, &dec_insn, self.word) + } + + fn process_sra(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::SRA, &dec_insn, self.word) + } + + fn process_srai(&mut self, dec_insn: ITypeShamt) -> Self::InstructionResult { + Instruction::from_i_type_shamt(InsnKind::SRAI, &dec_insn, self.word) + } + + fn process_slt(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::SLT, &dec_insn, self.word) + } + + fn process_slti(&mut self, dec_insn: IType) -> Self::InstructionResult { + Instruction::from_i_type(InsnKind::SLTI, &dec_insn, self.word) + } + + fn process_sltu(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::SLTU, &dec_insn, self.word) + } + + fn process_sltui(&mut self, dec_insn: IType) -> Self::InstructionResult { + Instruction::from_i_type(InsnKind::SLTIU, &dec_insn, self.word) + } + + fn process_lb(&mut self, dec_insn: IType) -> Self::InstructionResult { + Instruction::from_i_type(InsnKind::LB, &dec_insn, self.word) + } + + fn process_lh(&mut self, dec_insn: IType) -> Self::InstructionResult { + Instruction::from_i_type(InsnKind::LH, &dec_insn, self.word) + } + + fn process_lw(&mut self, dec_insn: IType) -> Self::InstructionResult { + Instruction::from_i_type(InsnKind::LW, &dec_insn, self.word) + } + + fn process_lbu(&mut self, dec_insn: IType) -> Self::InstructionResult { + Instruction::from_i_type(InsnKind::LBU, &dec_insn, self.word) + } + + fn process_lhu(&mut self, dec_insn: IType) -> Self::InstructionResult { + Instruction::from_i_type(InsnKind::LHU, &dec_insn, self.word) + } + + fn process_sb(&mut self, dec_insn: SType) -> Self::InstructionResult { + Instruction::from_s_type(InsnKind::SB, &dec_insn, self.word) + } + + fn process_sh(&mut self, dec_insn: SType) -> Self::InstructionResult { + Instruction::from_s_type(InsnKind::SH, &dec_insn, self.word) + } + + fn process_sw(&mut self, dec_insn: SType) -> Self::InstructionResult { + Instruction::from_s_type(InsnKind::SW, &dec_insn, self.word) + } + + fn process_beq(&mut self, dec_insn: BType) -> Self::InstructionResult { + Instruction::from_b_type(InsnKind::BEQ, &dec_insn, self.word) + } + + fn process_bne(&mut self, dec_insn: BType) -> Self::InstructionResult { + Instruction::from_b_type(InsnKind::BNE, &dec_insn, self.word) + } + + fn process_blt(&mut self, dec_insn: BType) -> Self::InstructionResult { + Instruction::from_b_type(InsnKind::BLT, &dec_insn, self.word) + } + + fn process_bge(&mut self, dec_insn: BType) -> Self::InstructionResult { + Instruction::from_b_type(InsnKind::BGE, &dec_insn, self.word) + } + + fn process_bltu(&mut self, dec_insn: BType) -> Self::InstructionResult { + Instruction::from_b_type(InsnKind::BLTU, &dec_insn, self.word) + } + + fn process_bgeu(&mut self, dec_insn: BType) -> Self::InstructionResult { + Instruction::from_b_type(InsnKind::BGEU, &dec_insn, self.word) + } + + fn process_jal(&mut self, dec_insn: JType) -> Self::InstructionResult { + Instruction { + kind: InsnKind::JAL, + rd: dec_insn.rd, + rs1: 0, + rs2: 0, + imm: dec_insn.imm, + raw: self.word, + } + } + + fn process_jalr(&mut self, dec_insn: IType) -> Self::InstructionResult { + Instruction { + kind: InsnKind::JALR, + rd: dec_insn.rd, + rs1: dec_insn.rs1, + rs2: 0, + imm: dec_insn.imm, + raw: self.word, + } + } + + /// Convert LUI to ADDI. + fn process_lui(&mut self, dec_insn: UType) -> Self::InstructionResult { + // Verify assumption that the immediate is already shifted left by 12 bits. + assert_eq!(dec_insn.imm & 0xfff, 0); + Instruction { + kind: InsnKind::ADDI, + rd: dec_insn.rd, + rs1: 0, + rs2: 0, + imm: dec_insn.imm, + raw: self.word, + } + } + + /// Convert AUIPC to ADDI. + fn process_auipc(&mut self, dec_insn: UType) -> Self::InstructionResult { + let pc = self.pc; + // Verify our assumption that the immediate is already shifted left by 12 bits. + assert_eq!(dec_insn.imm & 0xfff, 0); + Instruction { + kind: InsnKind::ADDI, + rd: dec_insn.rd, + rs1: 0, + rs2: 0, + imm: dec_insn.imm.wrapping_add(pc as i32), + raw: self.word, + } + } + + fn process_ecall(&mut self) -> Self::InstructionResult { + Instruction { + kind: InsnKind::ECALL, + rd: 0, + rs1: 0, + rs2: 0, + imm: 0, + raw: self.word, + } + } + + fn process_ebreak(&mut self) -> Self::InstructionResult { + Instruction::unimp(self.word) + } + + fn process_mul(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::MUL, &dec_insn, self.word) + } + + fn process_mulh(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::MULH, &dec_insn, self.word) + } + + fn process_mulhu(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::MULHU, &dec_insn, self.word) + } + + fn process_mulhsu(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::MULHSU, &dec_insn, self.word) + } + + fn process_div(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::DIV, &dec_insn, self.word) + } + + fn process_divu(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::DIVU, &dec_insn, self.word) + } + + fn process_rem(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::REM, &dec_insn, self.word) + } + + fn process_remu(&mut self, dec_insn: RType) -> Self::InstructionResult { + Instruction::from_r_type(InsnKind::REMU, &dec_insn, self.word) + } + + fn process_csrrc(&mut self, _: ITypeCSR) -> Self::InstructionResult { + Instruction::unimp(self.word) + } + + fn process_csrrci(&mut self, _: ITypeCSR) -> Self::InstructionResult { + Instruction::unimp(self.word) + } + + fn process_csrrs(&mut self, _: ITypeCSR) -> Self::InstructionResult { + Instruction::unimp(self.word) + } + + fn process_csrrsi(&mut self, _: ITypeCSR) -> Self::InstructionResult { + Instruction::unimp(self.word) + } + + fn process_csrrw(&mut self, _: ITypeCSR) -> Self::InstructionResult { + Instruction::unimp(self.word) + } + + fn process_csrrwi(&mut self, _: ITypeCSR) -> Self::InstructionResult { + Instruction::unimp(self.word) + } + + fn process_fence(&mut self, _: IType) -> Self::InstructionResult { + Instruction::unimp(self.word) + } + + fn process_mret(&mut self) -> Self::InstructionResult { + Instruction::unimp(self.word) + } + + fn process_wfi(&mut self) -> Self::InstructionResult { + Instruction::unimp(self.word) + } +} + +/// Transpile the [`Instruction`]s from the 32-bit encoded instructions. +#[must_use] +pub fn transpile(base: u32, instructions_u32: &[u32]) -> Vec { + izip!(enumerate(base, 4), instructions_u32) + .map(|(pc, &word)| { + process_instruction(&mut InstructionTranspiler { pc, word }, word) + .unwrap_or(Instruction::unimp(word)) + }) + .collect() +} + +fn enumerate(start: u32, step: u32) -> impl Iterator { + std::iter::successors(Some(start), move |&i| Some(i + step)) +} diff --git a/ceno_emul/src/elf.rs b/ceno_emul/src/elf.rs index 82697c619..ee59d3de3 100644 --- a/ceno_emul/src/elf.rs +++ b/ceno_emul/src/elf.rs @@ -18,7 +18,7 @@ extern crate alloc; use alloc::collections::BTreeMap; -use crate::addr::WORD_SIZE; +use crate::{addr::WORD_SIZE, disassemble::transpile, rv32im::Instruction}; use anyhow::{Context, Result, anyhow, bail}; use elf::{ ElfBytes, @@ -35,7 +35,7 @@ pub struct Program { /// This is the lowest address of the program's executable code pub base_address: u32, /// The instructions of the program - pub instructions: Vec, + pub instructions: Vec, /// The initial memory image pub image: BTreeMap, } @@ -45,7 +45,7 @@ impl Program { pub fn new( entry: u32, base_address: u32, - instructions: Vec, + instructions: Vec, image: BTreeMap, ) -> Program { Self { @@ -162,6 +162,8 @@ impl Program { assert!(entry >= base_address); assert!((entry - base_address) as usize <= instructions.len() * WORD_SIZE); + let instructions = transpile(base_address, &instructions); + Ok(Program { entry, base_address, diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index c734b1794..1a855006e 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -12,10 +12,11 @@ mod vm_state; pub use vm_state::VMState; mod rv32im; -pub use rv32im::{DecodedInstruction, EmuContext, InsnCategory, InsnCodes, InsnFormat, InsnKind}; +pub use rv32im::{ + EmuContext, InsnCategory, InsnFormat, InsnKind, Instruction, encode_rv32, encode_rv32u, +}; mod elf; pub use elf::Program; -mod rv32im_encode; -pub use rv32im_encode::encode_rv32; +pub mod disassemble; diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index c28bcda40..fec7e0181 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -73,10 +73,6 @@ impl Platform { self.is_ram(addr) } - pub fn can_execute(&self, addr: Addr) -> bool { - self.is_rom(addr) - } - // Environment calls. /// Register containing the ecall function code. (x5, t0) @@ -113,7 +109,6 @@ mod tests { #[test] fn test_no_overlap() { let p = CENO_PLATFORM; - assert!(p.can_execute(p.pc_base())); // ROM and RAM do not overlap. assert!(!p.is_rom(p.ram.start)); assert!(!p.is_rom(p.ram.end - WORD_SIZE as Addr)); diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index 52fac1508..e63fa761b 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -15,13 +15,39 @@ // limitations under the License. use anyhow::{Result, anyhow}; -use itertools::enumerate; use num_derive::ToPrimitive; -use std::sync::OnceLock; use strum_macros::{Display, EnumIter}; use super::addr::{ByteAddr, RegIdx, WORD_SIZE, Word, WordAddr}; +/// Convenience function to create an `Instruction` with the given fields. +/// +/// Pass 0 for unused fields. +pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: i32) -> Instruction { + Instruction { + kind, + rs1: rs1 as usize, + rs2: rs2 as usize, + rd: rd as usize, + imm, + raw: 0, + } +} + +/// Convenience function to create an `Instruction` with the given fields. +/// +/// Pass 0 for unused fields. +pub const fn encode_rv32u(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: u32) -> Instruction { + Instruction { + kind, + rs1: rs1 as usize, + rs2: rs2 as usize, + rd: rd as usize, + imm: imm as i32, + raw: 0, + } +} + pub trait EmuContext { // Handle environment call fn ecall(&mut self) -> Result; @@ -29,11 +55,8 @@ pub trait EmuContext { // Handle a trap fn trap(&self, cause: TrapCause) -> Result; - // Callback when instructions are decoded - fn on_insn_decoded(&mut self, _decoded: &DecodedInstruction) {} - // Callback when instructions end normally - fn on_normal_end(&mut self, _decoded: &DecodedInstruction) {} + fn on_normal_end(&mut self, _decoded: &Instruction) {} // Get the program counter fn get_pc(&self) -> ByteAddr; @@ -60,14 +83,7 @@ pub trait EmuContext { fn peek_memory(&self, addr: WordAddr) -> Word; /// Load from instruction cache - // TODO(Matthias): this should really return `Result` - // because the instruction cache should contain instructions, not just words. - fn fetch(&mut self, pc: WordAddr) -> Option; - - // Check access for instruction load - fn check_insn_load(&self, _addr: ByteAddr) -> bool { - true - } + fn fetch(&mut self, pc: WordAddr) -> Option; // Check access for data load fn check_data_load(&self, _addr: ByteAddr) -> bool { @@ -80,11 +96,6 @@ pub trait EmuContext { } } -/// An implementation of the basic ISA (RV32IM), that is instruction decoding and functional units. -pub struct Emulator { - table: &'static FastDecodeTable, -} - #[derive(Debug)] pub enum TrapCause { InstructionAddressMisaligned, @@ -98,17 +109,18 @@ pub enum TrapCause { EcallError, } -#[derive(Clone, Debug, Default)] -pub struct DecodedInstruction { - insn: u32, - top_bit: u32, - // The bit fields of the instruction encoding, regardless of the instruction format. - func7: u32, - rs2: u32, - rs1: u32, - func3: u32, - rd: u32, - opcode: u32, +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +pub struct Instruction { + pub kind: InsnKind, + pub rs1: RegIdx, + pub rs2: RegIdx, + pub rd: RegIdx, + pub imm: i32, + /// `raw` is there only to produce better logging and error messages. + /// + /// Set to 0, if you are creating an instruction directly, + /// instead of decoding it from a raw 32-bit `Word`. + pub raw: Word, } #[derive(Clone, Copy, Debug)] @@ -133,9 +145,12 @@ pub enum InsnFormat { } use InsnFormat::*; -#[derive(Clone, Copy, Display, Debug, PartialEq, Eq, PartialOrd, Ord, EnumIter, ToPrimitive)] +#[derive( + Clone, Copy, Display, Debug, PartialEq, Eq, PartialOrd, Ord, EnumIter, ToPrimitive, Default, +)] #[allow(clippy::upper_case_acronyms)] pub enum InsnKind { + #[default] INVALID, ADD, SUB, @@ -164,8 +179,6 @@ pub enum InsnKind { BGEU, JAL, JALR, - LUI, - AUIPC, MUL, MULH, MULHSU, @@ -182,611 +195,331 @@ pub enum InsnKind { SB, SH, SW, - /// ECALL and EBREAK etc. - EANY, + ECALL, } use InsnKind::*; -impl InsnKind { - pub const fn codes(self) -> InsnCodes { - RV32IM_ISA[self as usize] +impl From for InsnCategory { + fn from(kind: InsnKind) -> Self { + match kind { + INVALID => Invalid, + ADD | SUB | XOR | OR | AND | SLL | SRL | SRA | SLT | SLTU | MUL | MULH | MULHSU + | MULHU | DIV | DIVU | REM | REMU => Compute, + ADDI | XORI | ORI | ANDI | SLLI | SRLI | SRAI | SLTI | SLTIU => Compute, + BEQ | BNE | BLT | BGE | BLTU | BGEU => Branch, + JAL | JALR => Compute, + LB | LH | LW | LBU | LHU => Load, + SB | SH | SW => Store, + ECALL => System, + } } } -#[derive(Clone, Copy, Debug)] -pub struct InsnCodes { - pub format: InsnFormat, - pub kind: InsnKind, - pub category: InsnCategory, - pub(crate) opcode: u32, - pub(crate) func3: u32, - pub(crate) func7: u32, -} - -impl DecodedInstruction { - /// A virtual register which absorbs the writes to x0. - pub const RD_NULL: u32 = 32; - - pub fn new(insn: u32) -> Self { - Self { - insn, - top_bit: (insn & 0x80000000) >> 31, - func7: (insn & 0xfe000000) >> 25, - rs2: (insn & 0x01f00000) >> 20, - rs1: (insn & 0x000f8000) >> 15, - func3: (insn & 0x00007000) >> 12, - rd: (insn & 0x00000f80) >> 7, - opcode: insn & 0x0000007f, +// For encoding, which is useful for tests. +impl From for InsnFormat { + fn from(kind: InsnKind) -> Self { + match kind { + ADD | SUB | XOR | OR | AND | SLL | SRL | SRA | SLT | SLTU | MUL | MULH | MULHSU + | MULHU | DIV | DIVU | REM | REMU => R, + ADDI | XORI | ORI | ANDI | SLLI | SRLI | SRAI | SLTI | SLTIU => I, + BEQ | BNE | BLT | BGE | BLTU | BGEU => B, + JAL => J, + JALR => I, + LB | LH | LW | LBU | LHU => I, + SB | SH | SW => S, + ECALL => I, + INVALID => I, } } +} - pub fn encoded(&self) -> u32 { - self.insn - } - - pub fn opcode(&self) -> u32 { - self.opcode - } - - /// The internal register destination. It is either the regular rd, or an internal RD_NULL if - /// the instruction does not write to a register or writes to x0. +impl Instruction { + pub const RD_NULL: u32 = 32; pub fn rd_internal(&self) -> u32 { - match self.codes().format { - R | I | U | J if self.rd != 0 => self.rd, + match InsnFormat::from(self.kind) { + R | I | U | J if self.rd != 0 => self.rd as u32, _ => Self::RD_NULL, } } - - /// Get the rs1 field, regardless of the instruction format. - pub fn rs1(&self) -> u32 { - self.rs1 - } - /// Get the register source 1, or zero if the instruction does not use rs1. pub fn rs1_or_zero(&self) -> u32 { - match self.codes().format { - R | I | S | B => self.rs1, + match InsnFormat::from(self.kind) { + R | I | S | B => self.rs1 as u32, _ => 0, } } - - /// Get the rs2 field, regardless of the instruction format. - pub fn rs2(&self) -> u32 { - self.rs2 - } - /// Get the register source 2, or zero if the instruction does not use rs2. pub fn rs2_or_zero(&self) -> u32 { - match self.codes().format { - R | S | B => self.rs2, + match InsnFormat::from(self.kind) { + R | S | B => self.rs2 as u32, _ => 0, } } - - pub fn immediate(&self) -> u32 { - match self.codes().format { - R => 0, - I => self.imm_i(), - S => self.imm_s(), - B => self.imm_b(), - U => self.imm_u(), - J => self.imm_j(), - } - } - - pub fn codes(&self) -> InsnCodes { - FastDecodeTable::get().lookup(self) - } - - fn imm_b(&self) -> u32 { - (self.top_bit * 0xfffff000) - | ((self.rd & 1) << 11) - | ((self.func7 & 0x3f) << 5) - | (self.rd & 0x1e) - } - - fn imm_i(&self) -> u32 { - (self.top_bit * 0xffff_f000) | (self.func7 << 5) | self.rs2 - } - - fn imm_s(&self) -> u32 { - (self.top_bit * 0xfffff000) | (self.func7 << 5) | self.rd - } - - fn imm_j(&self) -> u32 { - (self.top_bit * 0xfff00000) - | (self.rs1 << 15) - | (self.func3 << 12) - | ((self.rs2 & 1) << 11) - | ((self.func7 & 0x3f) << 5) - | (self.rs2 & 0x1e) - } - - fn imm_u(&self) -> u32 { - self.insn & 0xfffff000 - } } -const fn insn( - format: InsnFormat, - kind: InsnKind, - category: InsnCategory, - opcode: u32, - func3: i32, - func7: i32, -) -> InsnCodes { - InsnCodes { - format, - kind, - category, - opcode, - func3: func3 as u32, - func7: func7 as u32, - } -} - -type InstructionTable = [InsnCodes; 47]; -type FastInstructionTable = [u8; 1 << 10]; - -const RV32IM_ISA: InstructionTable = [ - insn(R, INVALID, Invalid, 0x00, 0x0, 0x00), - insn(R, ADD, Compute, 0x33, 0x0, 0x00), - insn(R, SUB, Compute, 0x33, 0x0, 0x20), - insn(R, XOR, Compute, 0x33, 0x4, 0x00), - insn(R, OR, Compute, 0x33, 0x6, 0x00), - insn(R, AND, Compute, 0x33, 0x7, 0x00), - insn(R, SLL, Compute, 0x33, 0x1, 0x00), - insn(R, SRL, Compute, 0x33, 0x5, 0x00), - insn(R, SRA, Compute, 0x33, 0x5, 0x20), - insn(R, SLT, Compute, 0x33, 0x2, 0x00), - insn(R, SLTU, Compute, 0x33, 0x3, 0x00), - insn(I, ADDI, Compute, 0x13, 0x0, -1), - insn(I, XORI, Compute, 0x13, 0x4, -1), - insn(I, ORI, Compute, 0x13, 0x6, -1), - insn(I, ANDI, Compute, 0x13, 0x7, -1), - insn(I, SLLI, Compute, 0x13, 0x1, 0x00), - insn(I, SRLI, Compute, 0x13, 0x5, 0x00), - insn(I, SRAI, Compute, 0x13, 0x5, 0x20), - insn(I, SLTI, Compute, 0x13, 0x2, -1), - insn(I, SLTIU, Compute, 0x13, 0x3, -1), - insn(B, BEQ, Branch, 0x63, 0x0, -1), - insn(B, BNE, Branch, 0x63, 0x1, -1), - insn(B, BLT, Branch, 0x63, 0x4, -1), - insn(B, BGE, Branch, 0x63, 0x5, -1), - insn(B, BLTU, Branch, 0x63, 0x6, -1), - insn(B, BGEU, Branch, 0x63, 0x7, -1), - insn(J, JAL, Compute, 0x6f, -1, -1), - insn(I, JALR, Compute, 0x67, 0x0, -1), - insn(U, LUI, Compute, 0x37, -1, -1), - insn(U, AUIPC, Compute, 0x17, -1, -1), - insn(R, MUL, Compute, 0x33, 0x0, 0x01), - insn(R, MULH, Compute, 0x33, 0x1, 0x01), - insn(R, MULHSU, Compute, 0x33, 0x2, 0x01), - insn(R, MULHU, Compute, 0x33, 0x3, 0x01), - insn(R, DIV, Compute, 0x33, 0x4, 0x01), - insn(R, DIVU, Compute, 0x33, 0x5, 0x01), - insn(R, REM, Compute, 0x33, 0x6, 0x01), - insn(R, REMU, Compute, 0x33, 0x7, 0x01), - insn(I, LB, Load, 0x03, 0x0, -1), - insn(I, LH, Load, 0x03, 0x1, -1), - insn(I, LW, Load, 0x03, 0x2, -1), - insn(I, LBU, Load, 0x03, 0x4, -1), - insn(I, LHU, Load, 0x03, 0x5, -1), - insn(S, SB, Store, 0x23, 0x0, -1), - insn(S, SH, Store, 0x23, 0x1, -1), - insn(S, SW, Store, 0x23, 0x2, -1), - insn(I, EANY, System, 0x73, 0x0, 0x00), -]; - -#[cfg(test)] -#[test] -fn test_isa_table() { - use strum::IntoEnumIterator; - for kind in InsnKind::iter() { - assert_eq!(kind.codes().kind, kind); - } -} - -// RISC-V instruction are determined by 3 parts: -// - Opcode: 7 bits -// - Func3: 3 bits -// - Func7: 7 bits -// In many cases, func7 and/or func3 is ignored. A standard trick is to decode -// via a table, but a 17 bit lookup table destroys L1 cache. Luckily for us, -// in practice the low 2 bits of opcode are always 11, so we can drop them, and -// also func7 is always either 0, 1, 0x20 or don't care, so we can reduce func7 -// to 2 bits, which gets us to 10 bits, which is only 1k. -struct FastDecodeTable { - table: FastInstructionTable, -} - -impl FastDecodeTable { - fn new() -> Self { - let mut table: FastInstructionTable = [0; 1 << 10]; - for (isa_idx, insn) in enumerate(&RV32IM_ISA) { - Self::add_insn(&mut table, insn, isa_idx); - } - Self { table } - } - - fn get() -> &'static Self { - FAST_DECODE_TABLE.get_or_init(Self::new) - } - - // Map to 10 bit format - fn map10(opcode: u32, func3: u32, func7: u32) -> usize { - let op_high = opcode >> 2; - // Map 0 -> 0, 1 -> 1, 0x20 -> 2, everything else to 3 - let func72bits = if func7 <= 1 { - func7 - } else if func7 == 0x20 { - 2 - } else { - 3 - }; - ((op_high << 5) | (func72bits << 3) | func3) as usize - } - - fn add_insn(table: &mut FastInstructionTable, insn: &InsnCodes, isa_idx: usize) { - let op_high = insn.opcode >> 2; - if (insn.func3 as i32) < 0 { - for f3 in 0..8 { - for f7b in 0..4 { - let idx = (op_high << 5) | (f7b << 3) | f3; - table[idx as usize] = isa_idx as u8; - } - } - } else if (insn.func7 as i32) < 0 { - for f7b in 0..4 { - let idx = (op_high << 5) | (f7b << 3) | insn.func3; - table[idx as usize] = isa_idx as u8; - } - } else { - table[Self::map10(insn.opcode, insn.func3, insn.func7)] = isa_idx as u8; - } - } - - fn lookup(&self, decoded: &DecodedInstruction) -> InsnCodes { - let isa_idx = self.table[Self::map10(decoded.opcode, decoded.func3, decoded.func7)]; - RV32IM_ISA[isa_idx as usize] - } +pub fn step(ctx: &mut C) -> Result<()> { + let pc = ctx.get_pc(); + + let Some(insn) = ctx.fetch(pc.waddr()) else { + ctx.trap(TrapCause::InstructionAccessFault)?; + return Err(anyhow!( + "Fatal: could not fetch instruction at pc={pc:?}, ELF does not have instructions there." + )); + }; + + tracing::trace!("pc: {:x}, kind: {:?}", pc.0, insn.kind); + + if match InsnCategory::from(insn.kind) { + InsnCategory::Compute => step_compute(ctx, insn.kind, &insn)?, + InsnCategory::Branch => step_branch(ctx, insn.kind, &insn)?, + InsnCategory::Load => step_load(ctx, insn.kind, &insn)?, + InsnCategory::Store => step_store(ctx, insn.kind, &insn)?, + InsnCategory::System => step_system(ctx, insn.kind, &insn)?, + InsnCategory::Invalid => ctx.trap(TrapCause::IllegalInstruction(insn.raw))?, + } { + ctx.on_normal_end(&insn); + }; + + Ok(()) } -static FAST_DECODE_TABLE: OnceLock = OnceLock::new(); - -impl Emulator { - pub fn new() -> Self { - Self { - table: FastDecodeTable::get(), +fn step_compute(ctx: &mut M, kind: InsnKind, insn: &Instruction) -> Result { + use super::InsnKind::*; + + let pc = ctx.get_pc(); + let mut new_pc = pc + WORD_SIZE; + let imm_i = insn.imm as u32; + let out = match kind { + // Instructions that do not read rs1 nor rs2. + JAL => { + new_pc = pc.wrapping_add(insn.imm as u32); + (pc + WORD_SIZE).0 } - } - - pub fn step(&self, ctx: &mut C) -> Result<()> { - let pc = ctx.get_pc(); - - // TODO(Matthias): `check_insn_load` should be unnecessary: we can statically - // check in `fn new_from_elf` that the program only has instructions where - // our platform accepts them. - if !ctx.check_insn_load(pc) { - ctx.trap(TrapCause::InstructionAccessFault)?; - return Err(anyhow!("Fatal: could not fetch instruction at pc={:?}", pc)); - } - let Some(word) = ctx.fetch(pc.waddr()) else { - ctx.trap(TrapCause::InstructionAccessFault)?; - return Err(anyhow!("Fatal: could not fetch instruction at pc={:?}", pc)); - }; - - // TODO(Matthias): our `Program` that we are fetching from should really store - // already decoded instructions, instead of doing this weird, partial checking - // for `0x03` here. - // - // Note how we can fail here with an IllegalInstruction, and again further down - // when we match against the decoded instruction. We should centralise that. And - // our `step` function here shouldn't need to know anything about how instruction - // are encoded as numbers. - // - // One way to centralise is to do the check once when loading the program from the - // ELF. - if word & 0x03 != 0x03 { - // Opcode must end in 0b11 in RV32IM. - ctx.trap(TrapCause::IllegalInstruction(word))?; - return Err(anyhow!( - "Fatal: illegal instruction at pc={:?}: 0x{:08x}", - pc, - word - )); - } - - let decoded = DecodedInstruction::new(word); - let insn = self.table.lookup(&decoded); - ctx.on_insn_decoded(&decoded); - tracing::trace!("pc: {:x}, kind: {:?}", pc.0, insn.kind); - - if match insn.category { - InsnCategory::Compute => self.step_compute(ctx, insn.kind, &decoded)?, - InsnCategory::Branch => self.step_branch(ctx, insn.kind, &decoded)?, - InsnCategory::Load => self.step_load(ctx, insn.kind, &decoded)?, - InsnCategory::Store => self.step_store(ctx, insn.kind, &decoded)?, - InsnCategory::System => self.step_system(ctx, insn.kind, &decoded)?, - InsnCategory::Invalid => ctx.trap(TrapCause::IllegalInstruction(word))?, - } { - ctx.on_normal_end(&decoded); - }; - - Ok(()) - } - - fn step_compute( - &self, - ctx: &mut M, - kind: InsnKind, - decoded: &DecodedInstruction, - ) -> Result { - use InsnKind::*; - - let pc = ctx.get_pc(); - let mut new_pc = pc + WORD_SIZE; - let imm_i = decoded.imm_i(); - let out = match kind { - // Instructions that do not read rs1 nor rs2. - JAL => { - new_pc = pc.wrapping_add(decoded.imm_j()); - (pc + WORD_SIZE).0 - } - LUI => decoded.imm_u(), - AUIPC => (pc.wrapping_add(decoded.imm_u())).0, - - _ => { - // Instructions that read rs1 but not rs2. - let rs1 = ctx.load_register(decoded.rs1 as usize)?; - - match kind { - ADDI => rs1.wrapping_add(imm_i), - XORI => rs1 ^ imm_i, - ORI => rs1 | imm_i, - ANDI => rs1 & imm_i, - SLLI => rs1 << (imm_i & 0x1f), - SRLI => rs1 >> (imm_i & 0x1f), - SRAI => ((rs1 as i32) >> (imm_i & 0x1f)) as u32, - SLTI => { - if (rs1 as i32) < (imm_i as i32) { - 1 - } else { - 0 - } + _ => { + // Instructions that read rs1 but not rs2. + let rs1 = ctx.load_register(insn.rs1)?; + + match kind { + ADDI => rs1.wrapping_add(imm_i), + XORI => rs1 ^ imm_i, + ORI => rs1 | imm_i, + ANDI => rs1 & imm_i, + SLLI => rs1 << (imm_i & 0x1f), + SRLI => rs1 >> (imm_i & 0x1f), + SRAI => ((rs1 as i32) >> (imm_i & 0x1f)) as u32, + SLTI => { + if (rs1 as i32) < (imm_i as i32) { + 1 + } else { + 0 } - SLTIU => { - if rs1 < imm_i { - 1 - } else { - 0 - } - } - JALR => { - new_pc = ByteAddr(rs1.wrapping_add(imm_i) & 0xfffffffe); - (pc + WORD_SIZE).0 + } + SLTIU => { + if rs1 < imm_i { + 1 + } else { + 0 } + } + JALR => { + new_pc = ByteAddr(rs1.wrapping_add(imm_i) & !1); + (pc + WORD_SIZE).0 + } - _ => { - // Instructions that use rs1 and rs2. - let rs2 = ctx.load_register(decoded.rs2 as usize)?; - - match kind { - ADD => rs1.wrapping_add(rs2), - SUB => rs1.wrapping_sub(rs2), - XOR => rs1 ^ rs2, - OR => rs1 | rs2, - AND => rs1 & rs2, - SLL => rs1 << (rs2 & 0x1f), - SRL => rs1 >> (rs2 & 0x1f), - SRA => ((rs1 as i32) >> (rs2 & 0x1f)) as u32, - SLT => { - if (rs1 as i32) < (rs2 as i32) { - 1 - } else { - 0 - } - } - SLTU => { - if rs1 < rs2 { - 1 - } else { - 0 - } + _ => { + // Instructions that use rs1 and rs2. + let rs2 = ctx.load_register(insn.rs2)?; + + match kind { + ADD => rs1.wrapping_add(rs2), + SUB => rs1.wrapping_sub(rs2), + XOR => rs1 ^ rs2, + OR => rs1 | rs2, + AND => rs1 & rs2, + SLL => rs1 << (rs2 & 0x1f), + SRL => rs1 >> (rs2 & 0x1f), + SRA => ((rs1 as i32) >> (rs2 & 0x1f)) as u32, + SLT => { + if (rs1 as i32) < (rs2 as i32) { + 1 + } else { + 0 } - MUL => rs1.wrapping_mul(rs2), - MULH => { - (sign_extend_u32(rs1).wrapping_mul(sign_extend_u32(rs2)) >> 32) - as u32 + } + SLTU => { + if rs1 < rs2 { + 1 + } else { + 0 } - MULHSU => (sign_extend_u32(rs1).wrapping_mul(rs2 as i64) >> 32) as u32, - MULHU => (((rs1 as u64).wrapping_mul(rs2 as u64)) >> 32) as u32, - DIV => { - if rs2 == 0 { - u32::MAX - } else { - ((rs1 as i32).wrapping_div(rs2 as i32)) as u32 - } + } + MUL => rs1.wrapping_mul(rs2), + MULH => { + (sign_extend_u32(rs1).wrapping_mul(sign_extend_u32(rs2)) >> 32) as u32 + } + MULHSU => (sign_extend_u32(rs1).wrapping_mul(rs2 as i64) >> 32) as u32, + MULHU => (((rs1 as u64).wrapping_mul(rs2 as u64)) >> 32) as u32, + DIV => { + if rs2 == 0 { + u32::MAX + } else { + ((rs1 as i32).wrapping_div(rs2 as i32)) as u32 } - DIVU => { - if rs2 == 0 { - u32::MAX - } else { - rs1 / rs2 - } + } + DIVU => { + if rs2 == 0 { + u32::MAX + } else { + rs1 / rs2 } - REM => { - if rs2 == 0 { - rs1 - } else { - ((rs1 as i32).wrapping_rem(rs2 as i32)) as u32 - } + } + REM => { + if rs2 == 0 { + rs1 + } else { + ((rs1 as i32).wrapping_rem(rs2 as i32)) as u32 } - REMU => { - if rs2 == 0 { - rs1 - } else { - rs1 % rs2 - } + } + REMU => { + if rs2 == 0 { + rs1 + } else { + rs1 % rs2 } - - _ => unreachable!("Illegal compute instruction: {:?}", kind), } + + _ => unreachable!("Illegal compute instruction: {:?}", kind), } } } - }; - if !new_pc.is_aligned() { - return ctx.trap(TrapCause::InstructionAddressMisaligned); } - ctx.store_register(decoded.rd_internal() as usize, out)?; - ctx.set_pc(new_pc); - Ok(true) + }; + if !new_pc.is_aligned() { + return ctx.trap(TrapCause::InstructionAddressMisaligned); } + ctx.store_register(insn.rd_internal() as usize, out)?; + ctx.set_pc(new_pc); + Ok(true) +} - fn step_branch( - &self, - ctx: &mut M, - kind: InsnKind, - decoded: &DecodedInstruction, - ) -> Result { - use InsnKind::*; - - let pc = ctx.get_pc(); - let rs1 = ctx.load_register(decoded.rs1 as RegIdx)?; - let rs2 = ctx.load_register(decoded.rs2 as RegIdx)?; - - let taken = match kind { - BEQ => rs1 == rs2, - BNE => rs1 != rs2, - BLT => (rs1 as i32) < (rs2 as i32), - BGE => (rs1 as i32) >= (rs2 as i32), - BLTU => rs1 < rs2, - BGEU => rs1 >= rs2, - _ => unreachable!("Illegal branch instruction: {:?}", kind), - }; - - let new_pc = if taken { - pc.wrapping_add(decoded.imm_b()) - } else { - pc + WORD_SIZE - }; - - if !new_pc.is_aligned() { - return ctx.trap(TrapCause::InstructionAddressMisaligned); - } - ctx.set_pc(new_pc); - Ok(true) - } +fn step_branch(ctx: &mut M, kind: InsnKind, decoded: &Instruction) -> Result { + use super::InsnKind::*; + + let pc = ctx.get_pc(); + let rs1 = ctx.load_register(decoded.rs1 as RegIdx)?; + let rs2 = ctx.load_register(decoded.rs2 as RegIdx)?; + + let taken = match kind { + BEQ => rs1 == rs2, + BNE => rs1 != rs2, + BLT => (rs1 as i32) < (rs2 as i32), + BGE => (rs1 as i32) >= (rs2 as i32), + BLTU => rs1 < rs2, + BGEU => rs1 >= rs2, + _ => unreachable!("Illegal branch instruction: {:?}", kind), + }; + + let new_pc = if taken { + pc.wrapping_add(decoded.imm as u32) + } else { + pc + WORD_SIZE + }; + + if !new_pc.is_aligned() { + return ctx.trap(TrapCause::InstructionAddressMisaligned); + } + ctx.set_pc(new_pc); + Ok(true) +} - fn step_load( - &self, - ctx: &mut M, - kind: InsnKind, - decoded: &DecodedInstruction, - ) -> Result { - let rs1 = ctx.load_register(decoded.rs1 as usize)?; - // LOAD instructions do not read rs2. - let addr = ByteAddr(rs1.wrapping_add(decoded.imm_i())); - if !ctx.check_data_load(addr) { - return ctx.trap(TrapCause::LoadAccessFault(addr)); +fn step_load(ctx: &mut M, kind: InsnKind, decoded: &Instruction) -> Result { + let rs1 = ctx.load_register(decoded.rs1)?; + // LOAD instructions do not read rs2. + let addr = ByteAddr(rs1.wrapping_add_signed(decoded.imm)); + if !ctx.check_data_load(addr) { + return ctx.trap(TrapCause::LoadAccessFault(addr)); + } + let data = ctx.load_memory(addr.waddr())?; + let shift = 8 * (addr.0 & 3); + let out = match kind { + InsnKind::LB => { + let mut out = (data >> shift) & 0xff; + if out & 0x80 != 0 { + out |= 0xffffff00; + } + out } - let data = ctx.load_memory(addr.waddr())?; - let shift = 8 * (addr.0 & 3); - let out = match kind { - InsnKind::LB => { - let mut out = (data >> shift) & 0xff; - if out & 0x80 != 0 { - out |= 0xffffff00; - } - out + InsnKind::LH => { + if addr.0 & 0x01 != 0 { + return ctx.trap(TrapCause::LoadAddressMisaligned); } - InsnKind::LH => { - if addr.0 & 0x01 != 0 { - return ctx.trap(TrapCause::LoadAddressMisaligned); - } - let mut out = (data >> shift) & 0xffff; - if out & 0x8000 != 0 { - out |= 0xffff0000; - } - out + let mut out = (data >> shift) & 0xffff; + if out & 0x8000 != 0 { + out |= 0xffff0000; } - InsnKind::LW => { - if addr.0 & 0x03 != 0 { - return ctx.trap(TrapCause::LoadAddressMisaligned); - } - data + out + } + InsnKind::LW => { + if addr.0 & 0x03 != 0 { + return ctx.trap(TrapCause::LoadAddressMisaligned); } - InsnKind::LBU => (data >> shift) & 0xff, - InsnKind::LHU => { - if addr.0 & 0x01 != 0 { - return ctx.trap(TrapCause::LoadAddressMisaligned); - } - (data >> shift) & 0xffff + data + } + InsnKind::LBU => (data >> shift) & 0xff, + InsnKind::LHU => { + if addr.0 & 0x01 != 0 { + return ctx.trap(TrapCause::LoadAddressMisaligned); } - _ => unreachable!(), - }; - ctx.store_register(decoded.rd_internal() as usize, out)?; - ctx.set_pc(ctx.get_pc() + WORD_SIZE); - Ok(true) - } + (data >> shift) & 0xffff + } + _ => unreachable!(), + }; + ctx.store_register(decoded.rd_internal() as usize, out)?; + ctx.set_pc(ctx.get_pc() + WORD_SIZE); + Ok(true) +} - fn step_store( - &self, - ctx: &mut M, - kind: InsnKind, - decoded: &DecodedInstruction, - ) -> Result { - let rs1 = ctx.load_register(decoded.rs1 as usize)?; - let rs2 = ctx.load_register(decoded.rs2 as usize)?; - let addr = ByteAddr(rs1.wrapping_add(decoded.imm_s())); - let shift = 8 * (addr.0 & 3); - if !ctx.check_data_store(addr) { - tracing::error!("mstore: addr={:x?},rs1={:x}", addr, rs1); - return ctx.trap(TrapCause::StoreAccessFault); +fn step_store(ctx: &mut M, kind: InsnKind, decoded: &Instruction) -> Result { + let rs1 = ctx.load_register(decoded.rs1)?; + let rs2 = ctx.load_register(decoded.rs2)?; + let addr = ByteAddr(rs1.wrapping_add(decoded.imm as u32)); + let shift = 8 * (addr.0 & 3); + if !ctx.check_data_store(addr) { + tracing::error!("mstore: addr={:x?},rs1={:x}", addr, rs1); + return ctx.trap(TrapCause::StoreAccessFault); + } + let mut data = ctx.peek_memory(addr.waddr()); + match kind { + InsnKind::SB => { + data ^= data & (0xff << shift); + data |= (rs2 & 0xff) << shift; } - let mut data = ctx.peek_memory(addr.waddr()); - match kind { - InsnKind::SB => { - data ^= data & (0xff << shift); - data |= (rs2 & 0xff) << shift; + InsnKind::SH => { + if addr.0 & 0x01 != 0 { + tracing::debug!("Misaligned SH"); + return ctx.trap(TrapCause::StoreAddressMisaligned(addr)); } - InsnKind::SH => { - if addr.0 & 0x01 != 0 { - tracing::debug!("Misaligned SH"); - return ctx.trap(TrapCause::StoreAddressMisaligned(addr)); - } - data ^= data & (0xffff << shift); - data |= (rs2 & 0xffff) << shift; - } - InsnKind::SW => { - if addr.0 & 0x03 != 0 { - tracing::debug!("Misaligned SW"); - return ctx.trap(TrapCause::StoreAddressMisaligned(addr)); - } - data = rs2; + data ^= data & (0xffff << shift); + data |= (rs2 & 0xffff) << shift; + } + InsnKind::SW => { + if addr.0 & 0x03 != 0 { + tracing::debug!("Misaligned SW"); + return ctx.trap(TrapCause::StoreAddressMisaligned(addr)); } - _ => unreachable!(), + data = rs2; } - ctx.store_memory(addr.waddr(), data)?; - ctx.set_pc(ctx.get_pc() + WORD_SIZE); - Ok(true) + _ => unreachable!(), } + ctx.store_memory(addr.waddr(), data)?; + ctx.set_pc(ctx.get_pc() + WORD_SIZE); + Ok(true) +} - fn step_system( - &self, - ctx: &mut M, - kind: InsnKind, - decoded: &DecodedInstruction, - ) -> Result { - match kind { - InsnKind::EANY => match decoded.rs2 { - 0 => ctx.ecall(), - 1 => ctx.trap(TrapCause::Breakpoint), - _ => ctx.trap(TrapCause::IllegalInstruction(decoded.insn)), - }, - _ => unreachable!(), - } +fn step_system(ctx: &mut M, kind: InsnKind, decoded: &Instruction) -> Result { + match kind { + InsnKind::ECALL => ctx.ecall(), + _ => ctx.trap(TrapCause::IllegalInstruction(decoded.raw)), } } diff --git a/ceno_emul/src/rv32im_encode.rs b/ceno_emul/src/rv32im_encode.rs deleted file mode 100644 index b6a80f32b..000000000 --- a/ceno_emul/src/rv32im_encode.rs +++ /dev/null @@ -1,116 +0,0 @@ -use crate::{InsnKind, rv32im::InsnFormat}; - -const MASK_4_BITS: u32 = 0xF; -const MASK_5_BITS: u32 = 0x1F; -const MASK_6_BITS: u32 = 0x3F; -const MASK_7_BITS: u32 = 0x7F; -const MASK_8_BITS: u32 = 0xFF; -const MASK_10_BITS: u32 = 0x3FF; -const MASK_12_BITS: u32 = 0xFFF; - -/// Generate bit encoding of a RISC-V instruction. -/// -/// Values `rs1`, `rs2` and `rd1` are 5-bit register indices, and `imm` is of -/// bit length depending on the requirements of the instruction format type. -/// -/// Fields not required by the instruction's format type are ignored, so one can -/// safely pass an arbitrary value for these, say 0. -pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: u32) -> u32 { - match kind.codes().format { - InsnFormat::R => encode_r(kind, rs1, rs2, rd), - InsnFormat::I => encode_i(kind, rs1, rd, imm), - InsnFormat::S => encode_s(kind, rs1, rs2, imm), - InsnFormat::B => encode_b(kind, rs1, rs2, imm), - InsnFormat::U => encode_u(kind, rd, imm), - InsnFormat::J => encode_j(kind, rd, imm), - } -} - -// R-Type -// 25 20 15 12 7 0 -// +------+-----+-----+--------+----+-------+ -// funct7 | rs2 | rs1 | funct3 | rd | opcode -const fn encode_r(kind: InsnKind, rs1: u32, rs2: u32, rd: u32) -> u32 { - let rs2 = rs2 & MASK_5_BITS; // 5-bits mask - let rs1 = rs1 & MASK_5_BITS; - let rd = rd & MASK_5_BITS; - let func7 = kind.codes().func7; - let func3 = kind.codes().func3; - let opcode = kind.codes().opcode; - func7 << 25 | rs2 << 20 | rs1 << 15 | func3 << 12 | rd << 7 | opcode -} - -// I-Type -// 20 15 12 7 0 -// +---------+-----+--------+----+-------+ -// imm[0:11] | rs1 | funct3 | rd | opcode -const fn encode_i(kind: InsnKind, rs1: u32, rd: u32, imm: u32) -> u32 { - let rs1 = rs1 & MASK_5_BITS; - let rd = rd & MASK_5_BITS; - let func3 = kind.codes().func3; - let opcode = kind.codes().opcode; - // SRLI/SRAI use a specialization of the I-type format with the shift type in imm[10]. - let is_arithmetic_right_shift = (matches!(kind, InsnKind::SRAI) as u32) << 10; - let imm = imm & MASK_12_BITS | is_arithmetic_right_shift; - imm << 20 | rs1 << 15 | func3 << 12 | rd << 7 | opcode -} - -// S-Type -// 25 20 15 12 7 0 -// +---------+-----+-----+--------+----------+-------+ -// imm[5:11] | rs2 | rs1 | funct3 | imm[0:4] | opcode -const fn encode_s(kind: InsnKind, rs1: u32, rs2: u32, imm: u32) -> u32 { - let rs2 = rs2 & MASK_5_BITS; - let rs1 = rs1 & MASK_5_BITS; - let func3 = kind.codes().func3; - let opcode = kind.codes().opcode; - let imm_lo = imm & MASK_5_BITS; - let imm_hi = (imm >> 5) & MASK_7_BITS; // 7-bits mask - imm_hi << 25 | rs2 << 20 | rs1 << 15 | func3 << 12 | imm_lo << 7 | opcode -} - -// B-Type -// 31 25 20 15 12 8 7 0 -// +-------+-----------+-----+-----+--------+----------+---------+-------+ -// imm[12] | imm[5:10] | rs2 | rs1 | funct3 | imm[1:4] | imm[11] | opcode -const fn encode_b(kind: InsnKind, rs1: u32, rs2: u32, imm: u32) -> u32 { - let rs2 = rs2 & MASK_5_BITS; - let rs1 = rs1 & MASK_5_BITS; - let func3 = kind.codes().func3; - let opcode = kind.codes().opcode; - let imm_1_4 = (imm >> 1) & MASK_4_BITS; // skip imm[0] - let imm_5_10 = (imm >> 5) & MASK_6_BITS; - ((imm >> 12) & 1) << 31 - | imm_5_10 << 25 - | rs2 << 20 - | rs1 << 15 - | func3 << 12 - | imm_1_4 << 8 - | ((imm >> 11) & 1) << 7 - | opcode -} - -// J-Type -// 31 21 20 12 7 0 -// +-------+-----------+---------+------------+----+-------+ -// imm[20] | imm[1:10] | imm[11] | imm[12:19] | rd | opcode -const fn encode_j(kind: InsnKind, rd: u32, imm: u32) -> u32 { - let rd = rd & MASK_5_BITS; - let opcode = kind.codes().opcode; - let imm_1_10 = (imm >> 1) & MASK_10_BITS; // skip imm[0] - let imm_12_19 = (imm >> 12) & MASK_8_BITS; - ((imm >> 20) & 1) << 31 - | imm_1_10 << 21 - | ((imm >> 11) & 1) << 20 - | imm_12_19 << 12 - | rd << 7 - | opcode -} - -// U-Type -// 12 7 0 -// +----------+----+--------+ -// imm[12:31] | rd | opcode -const fn encode_u(kind: InsnKind, rd: u32, imm: u32) -> u32 { - (imm >> 12) << 12 | (rd & MASK_5_BITS) << 7 | kind.codes().opcode -} diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index a1df363e2..01ad33e01 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -1,10 +1,9 @@ use std::{collections::HashMap, fmt, mem}; use crate::{ - CENO_PLATFORM, InsnKind, PC_STEP_SIZE, Platform, + CENO_PLATFORM, InsnKind, Instruction, PC_STEP_SIZE, Platform, addr::{ByteAddr, Cycle, RegIdx, Word, WordAddr}, encode_rv32, - rv32im::DecodedInstruction, }; /// An instruction and its context in an execution trace. That is concrete values of registers and memory. @@ -22,7 +21,7 @@ use crate::{ pub struct StepRecord { cycle: Cycle, pc: Change, - insn_code: Word, + pub insn: Instruction, rs1: Option, rs2: Option, @@ -57,7 +56,7 @@ impl StepRecord { pub fn new_r_instruction( cycle: Cycle, pc: ByteAddr, - insn_code: u32, + insn_code: Instruction, rs1_read: Word, rs2_read: Word, rd: Change, @@ -79,7 +78,7 @@ impl StepRecord { pub fn new_b_instruction( cycle: Cycle, pc: Change, - insn_code: u32, + insn_code: Instruction, rs1_read: Word, rs2_read: Word, prev_cycle: Cycle, @@ -99,7 +98,7 @@ impl StepRecord { pub fn new_i_instruction( cycle: Cycle, pc: Change, - insn_code: u32, + insn_code: Instruction, rs1_read: Word, rd: Change, prev_cycle: Cycle, @@ -119,7 +118,7 @@ impl StepRecord { pub fn new_im_instruction( cycle: Cycle, pc: ByteAddr, - insn_code: u32, + insn_code: Instruction, rs1_read: Word, rd: Change, mem_op: ReadOp, @@ -148,7 +147,7 @@ impl StepRecord { pub fn new_u_instruction( cycle: Cycle, pc: ByteAddr, - insn_code: u32, + insn_code: Instruction, rd: Change, prev_cycle: Cycle, ) -> StepRecord { @@ -159,7 +158,7 @@ impl StepRecord { pub fn new_j_instruction( cycle: Cycle, pc: Change, - insn_code: u32, + insn_code: Instruction, rd: Change, prev_cycle: Cycle, ) -> StepRecord { @@ -169,7 +168,7 @@ impl StepRecord { pub fn new_s_instruction( cycle: Cycle, pc: ByteAddr, - insn_code: u32, + insn_code: Instruction, rs1_read: Word, rs2_read: Word, memory_op: WriteOp, @@ -194,7 +193,7 @@ impl StepRecord { Self::new_insn( cycle, Change::new(pc, pc + PC_STEP_SIZE), - encode_rv32(InsnKind::EANY, 0, 0, 0, 0), + encode_rv32(InsnKind::ECALL, 0, 0, 0, 0), Some(value), Some(value), Some(Change::new(value, value)), @@ -214,25 +213,23 @@ impl StepRecord { fn new_insn( cycle: Cycle, pc: Change, - insn_code: u32, + insn: Instruction, rs1_read: Option, rs2_read: Option, rd: Option>, memory_op: Option, previous_cycle: Cycle, ) -> StepRecord { - let insn = DecodedInstruction::new(insn_code); StepRecord { cycle, pc, - insn_code, rs1: rs1_read.map(|rs1| ReadOp { - addr: Platform::register_vma(insn.rs1() as RegIdx).into(), + addr: Platform::register_vma(insn.rs1).into(), value: rs1, previous_cycle, }), rs2: rs2_read.map(|rs2| ReadOp { - addr: Platform::register_vma(insn.rs2() as RegIdx).into(), + addr: Platform::register_vma(insn.rs2).into(), value: rs2, previous_cycle, }), @@ -241,6 +238,7 @@ impl StepRecord { value: rd, previous_cycle, }), + insn, memory_op, } } @@ -253,14 +251,9 @@ impl StepRecord { self.pc } - /// The instruction as a raw code. - pub fn insn_code(&self) -> Word { - self.insn_code - } - /// The instruction as a decoded structure. - pub fn insn(&self) -> DecodedInstruction { - DecodedInstruction::new(self.insn_code) + pub fn insn(&self) -> Instruction { + self.insn } pub fn rs1(&self) -> Option { @@ -327,12 +320,9 @@ impl Tracer { self.record.pc.after = pc; } - // TODO(Matthias): this should perhaps record `DecodedInstruction`s instead - // of raw codes, or perhaps only the pc, because we can always look up the - // instruction in the program. - pub fn fetch(&mut self, pc: WordAddr, value: Word) { + pub fn fetch(&mut self, pc: WordAddr, value: Instruction) { self.record.pc.before = pc.baddr(); - self.record.insn_code = value; + self.record.insn = value; } pub fn load_register(&mut self, idx: RegIdx, value: Word) { diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index e071cb000..838779979 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -5,7 +5,7 @@ use crate::{ PC_STEP_SIZE, Program, WORD_SIZE, addr::{ByteAddr, RegIdx, Word, WordAddr}, platform::Platform, - rv32im::{DecodedInstruction, Emulator, TrapCause}, + rv32im::{Instruction, TrapCause}, tracer::{Change, StepRecord, Tracer}, }; use anyhow::{Result, anyhow}; @@ -77,18 +77,17 @@ impl VMState { } pub fn iter_until_halt(&mut self) -> impl Iterator> + '_ { - let emu = Emulator::new(); from_fn(move || { if self.halted() { None } else { - Some(self.step(&emu)) + Some(self.step()) } }) } - fn step(&mut self, emu: &Emulator) -> Result { - emu.step(self)?; + fn step(&mut self) -> Result { + crate::rv32im::step(self)?; let step = self.tracer.advance(); if step.is_busy_loop() && !self.halted() { Err(anyhow!("Stuck in loop {}", "{}")) @@ -121,7 +120,7 @@ impl EmuContext for VMState { // Treat unknown ecalls as all powerful instructions: // Read two registers, write one register, write one memory word, and branch. tracing::warn!("ecall ignored: syscall_id={}", function); - self.store_register(DecodedInstruction::RD_NULL as RegIdx, 0)?; + self.store_register(Instruction::RD_NULL as RegIdx, 0)?; // Example ecall effect - any writable address will do. let addr = (self.platform.stack_top - WORD_SIZE as u32).into(); self.store_memory(addr, self.peek_memory(addr))?; @@ -136,7 +135,7 @@ impl EmuContext for VMState { Err(anyhow!("Trap {:?}", cause)) // Crash. } - fn on_normal_end(&mut self, _decoded: &DecodedInstruction) { + fn on_normal_end(&mut self, _decoded: &Instruction) { self.tracer.store_pc(ByteAddr(self.pc)); } @@ -189,8 +188,7 @@ impl EmuContext for VMState { *self.memory.get(&addr).unwrap_or(&0) } - // TODO(Matthias): this should really return `Result` - fn fetch(&mut self, pc: WordAddr) -> Option { + fn fetch(&mut self, pc: WordAddr) -> Option { let byte_pc: ByteAddr = pc.into(); let relative_pc = byte_pc.0.wrapping_sub(self.program.base_address); let idx = (relative_pc / WORD_SIZE as u32) as usize; @@ -206,8 +204,4 @@ impl EmuContext for VMState { fn check_data_store(&self, addr: ByteAddr) -> bool { self.platform.can_write(addr.0) } - - fn check_insn_load(&self, addr: ByteAddr) -> bool { - self.platform.can_execute(addr.0) - } } diff --git a/ceno_emul/tests/test_elf.rs b/ceno_emul/tests/test_elf.rs index ca7a14c1c..7448d4508 100644 --- a/ceno_emul/tests/test_elf.rs +++ b/ceno_emul/tests/test_elf.rs @@ -15,7 +15,7 @@ fn test_ceno_rt_panic() -> Result<()> { let mut state = VMState::new_from_elf(CENO_PLATFORM, program_elf)?; let steps = run(&mut state)?; let last = steps.last().unwrap(); - assert_eq!(last.insn().codes().kind, InsnKind::EANY); + assert_eq!(last.insn().kind, InsnKind::ECALL); assert_eq!(last.rs1().unwrap().value, Platform::ecall_halt()); assert_eq!(last.rs2().unwrap().value, 1); // panic / halt(1) Ok(()) diff --git a/ceno_emul/tests/test_vm_trace.rs b/ceno_emul/tests/test_vm_trace.rs index 2aa5f0da2..ef3b8820e 100644 --- a/ceno_emul/tests/test_vm_trace.rs +++ b/ceno_emul/tests/test_vm_trace.rs @@ -6,8 +6,8 @@ use std::{ }; use ceno_emul::{ - CENO_PLATFORM, Cycle, EmuContext, InsnKind, Platform, Program, StepRecord, Tracer, VMState, - WORD_SIZE, WordAddr, + CENO_PLATFORM, Cycle, EmuContext, InsnKind, Instruction, Platform, Program, StepRecord, Tracer, + VMState, WordAddr, encode_rv32, }; #[test] @@ -15,17 +15,8 @@ fn test_vm_trace() -> Result<()> { let program = Program::new( CENO_PLATFORM.pc_base(), CENO_PLATFORM.pc_base(), - PROGRAM_FIBONACCI_20.to_vec(), - PROGRAM_FIBONACCI_20 - .iter() - .enumerate() - .map(|(insn_idx, &insn)| { - ( - CENO_PLATFORM.pc_base() + (WORD_SIZE * insn_idx) as u32, - insn, - ) - }) - .collect(), + program_fibonacci_20(), + Default::default(), ); let mut ctx = VMState::new(CENO_PLATFORM, Arc::new(program)); @@ -36,7 +27,7 @@ fn test_vm_trace() -> Result<()> { assert_eq!(ctx.peek_register(2), x2); assert_eq!(ctx.peek_register(3), x3); - let ops: Vec = steps.iter().map(|step| step.insn().codes().kind).collect(); + let ops: Vec = steps.iter().map(|step| step.insn().kind).collect(); assert_eq!(ops, expected_ops_fibonacci_20()); assert_eq!( @@ -66,27 +57,25 @@ fn run(state: &mut VMState) -> Result> { } /// Example in RISC-V bytecode and assembly. -const PROGRAM_FIBONACCI_20: [u32; 7] = [ - // x1 = 10; - // x3 = 1; - // immediate rs1 f3 rd opcode - 0b_000000001010_00000_000_00001_0010011, // addi x1, x0, 10 - 0b_000000000001_00000_000_00011_0010011, // addi x3, x0, 1 - // loop { - // x1 -= 1; - // immediate rs1 f3 rd opcode - 0b_111111111111_00001_000_00001_0010011, // addi x1, x1, -1 - // x2 += x3; - // x3 += x2; - // zeros rs2 rs1 f3 rd opcode - 0b_0000000_00011_00010_000_00010_0110011, // add x2, x2, x3 - 0b_0000000_00011_00010_000_00011_0110011, // add x3, x2, x3 - // if x1 == 0 { break } - // imm rs2 rs1 f3 imm opcode - 0b_1_111111_00000_00001_001_1010_1_1100011, // bne x1, x0, -12 - // ecall HALT, SUCCESS - 0b_000000000000_00000_000_00000_1110011, -]; +pub fn program_fibonacci_20() -> Vec { + vec![ + // x1 = 10; + // x3 = 1; + encode_rv32(InsnKind::ADDI, 0, 0, 1, 10), + encode_rv32(InsnKind::ADDI, 0, 0, 3, 1), + // loop { + // x1 -= 1; + encode_rv32(InsnKind::ADDI, 1, 0, 1, -1), + // x2 += x3; + // x3 += x2; + encode_rv32(InsnKind::ADD, 2, 3, 2, 0), + encode_rv32(InsnKind::ADD, 2, 3, 3, 0), + // if x1 == 0 { break } + encode_rv32(InsnKind::BNE, 1, 0, 0, -12), + // ecall HALT, SUCCESS + encode_rv32(InsnKind::ECALL, 0, 0, 0, 0), + ] +} /// Rust version of the example. Reconstruct the output. fn expected_fibonacci_20() -> (u32, u32, u32) { @@ -115,7 +104,7 @@ fn expected_ops_fibonacci_20() -> Vec { for _ in 0..10 { ops.extend(&[ADDI, ADD, ADD, BNE]); } - ops.push(EANY); + ops.push(ECALL); ops } diff --git a/ceno_zkvm/examples/riscv_opcodes.rs b/ceno_zkvm/examples/riscv_opcodes.rs index fe119bb25..81d1f6eb3 100644 --- a/ceno_zkvm/examples/riscv_opcodes.rs +++ b/ceno_zkvm/examples/riscv_opcodes.rs @@ -1,7 +1,6 @@ use std::{panic, sync::Arc, time::Instant}; use ceno_zkvm::{ - declare_program, instructions::riscv::{MemPadder, MmuConfig, Rv32imConfig, constants::EXIT_PC}, scheme::{mock_prover::MockProver, prover::ZKVMProver}, state::GlobalState, @@ -13,8 +12,9 @@ use clap::Parser; use ceno_emul::{ CENO_PLATFORM, EmuContext, - InsnKind::{ADD, BLTU, EANY, LUI, LW}, - PC_WORD_SIZE, Platform, Program, StepRecord, Tracer, VMState, Word, WordAddr, encode_rv32, + InsnKind::{ADD, ADDI, BLTU, ECALL, LW}, + Instruction, Platform, Program, StepRecord, Tracer, VMState, Word, WordAddr, encode_rv32, + encode_rv32u, }; use ceno_zkvm::{ scheme::{PublicValues, constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier}, @@ -28,33 +28,26 @@ use mpcs::{Basefold, BasefoldRSParams, PolynomialCommitmentScheme}; use sumcheck::macros::{entered_span, exit_span}; use tracing_subscriber::{EnvFilter, Registry, fmt, fmt::format::FmtSpan, layer::SubscriberExt}; use transcript::BasicTranscript as Transcript; -const PROGRAM_SIZE: usize = 16; // For now, we assume registers // - x0 is not touched, // - x1 is initialized to 1, // - x2 is initialized to -1, // - x3 is initialized to loop bound. // we use x4 to hold the acc_sum. -#[allow(clippy::unusual_byte_groupings)] -const ECALL_HALT: u32 = 0b_000000000000_00000_000_00000_1110011; -#[allow(clippy::unusual_byte_groupings)] -const PROGRAM_CODE: [u32; PROGRAM_SIZE] = { - let mut program: [u32; PROGRAM_SIZE] = [ECALL_HALT; PROGRAM_SIZE]; - declare_program!( - program, - encode_rv32(LUI, 0, 0, 10, CENO_PLATFORM.public_io.start), // lui x10, public_io - encode_rv32(LW, 10, 0, 1, 0), // lw x1, 0(x10) - encode_rv32(LW, 10, 0, 2, 4), // lw x2, 4(x10) - encode_rv32(LW, 10, 0, 3, 8), // lw x3, 8(x10) +fn program_code() -> Vec { + vec![ + encode_rv32u(ADDI, 0, 0, 10, CENO_PLATFORM.public_io.start), // lui x10, public_io + encode_rv32(LW, 10, 0, 1, 0), // lw x1, 0(x10) + encode_rv32(LW, 10, 0, 2, 4), // lw x2, 4(x10) + encode_rv32(LW, 10, 0, 3, 8), // lw x3, 8(x10) // Main loop. - encode_rv32(ADD, 1, 4, 4, 0), // add x4, x1, x4 - encode_rv32(ADD, 2, 3, 3, 0), // add x3, x2, x3 - encode_rv32(BLTU, 0, 3, 0, -8_i32 as u32), // bltu x0, x3, -8 + encode_rv32(ADD, 1, 4, 4, 0), // add x4, x1, x4 + encode_rv32(ADD, 2, 3, 3, 0), // add x3, x2, x3 + encode_rv32(BLTU, 0, 3, 0, -8_i32), // bltu x0, x3, -8 // End. - ECALL_HALT, // ecall halt - ); - program -}; + encode_rv32(ECALL, 0, 0, 0, 0), // ecall (halt) + ] +} type ExampleProgramTableCircuit = ProgramTableCircuit; /// Simple program to greet a person @@ -74,21 +67,14 @@ fn main() { let args = Args::parse(); type E = GoldilocksExt2; type Pcs = Basefold; + let program_code = program_code(); + let program_size = program_code.len(); let program = Program::new( CENO_PLATFORM.pc_base(), CENO_PLATFORM.pc_base(), - PROGRAM_CODE.to_vec(), - PROGRAM_CODE - .iter() - .enumerate() - .map(|(insn_idx, &insn)| { - ( - (insn_idx * PC_WORD_SIZE) as u32 + CENO_PLATFORM.pc_base(), - insn, - ) - }) - .collect(), + program_code, + Default::default(), ); let mem_addresses = CENO_PLATFORM.ram.clone(); let io_addresses = CENO_PLATFORM.public_io.clone(); @@ -117,7 +103,7 @@ fn main() { let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); let (pp, vp) = Pcs::trim(pcs_param, 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let program_params = ProgramParams { - program_size: PROGRAM_SIZE, + program_size, ..Default::default() }; let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params); @@ -193,8 +179,7 @@ fn main() { .iter() .rev() .find(|record| { - record.insn().codes().kind == EANY - && record.rs1().unwrap().value == Platform::ecall_halt() + record.insn().kind == ECALL && record.rs1().unwrap().value == Platform::ecall_halt() }) .expect("halt record not found"); diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index b1d5b4d77..95bcb73be 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -11,8 +11,8 @@ use crate::{ tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig}, }; use ceno_emul::{ - ByteAddr, EmuContext, InsnKind::EANY, IterAddresses, Platform, Program, StepRecord, Tracer, - VMState, WORD_SIZE, WordAddr, + ByteAddr, EmuContext, InsnKind, IterAddresses, Platform, Program, StepRecord, Tracer, VMState, + WORD_SIZE, WordAddr, }; use ff_ext::ExtensionField; use itertools::{Itertools, MinMaxResult, chain}; @@ -73,7 +73,7 @@ fn emulate_program( .iter() .rev() .find(|record| { - record.insn().codes().kind == EANY + record.insn().kind == InsnKind::ECALL && record.rs1().unwrap().value == Platform::ecall_halt() }) .and_then(|halt_record| halt_record.rs2()) @@ -537,9 +537,8 @@ fn format_segments( fn format_segment(platform: &Platform, addr: u32) -> String { format!( - "{}{}{}", + "{}{}", if platform.can_read(addr) { "R" } else { "-" }, if platform.can_write(addr) { "W" } else { "-" }, - if platform.can_execute(addr) { "X" } else { "-" }, ) } diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index de0dfc771..cc5e12b12 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -28,7 +28,6 @@ mod i_insn; mod insn_base; mod j_insn; mod r_insn; -mod u_insn; mod ecall_insn; @@ -37,15 +36,13 @@ mod memory; mod s_insn; #[cfg(test)] mod test; -#[cfg(test)] -mod test_utils; pub trait RIVInstruction { const INST_KIND: InsnKind; } pub use arith::{AddInstruction, SubInstruction}; -pub use jump::{AuipcInstruction, JalInstruction, JalrInstruction, LuiInstruction}; +pub use jump::{JalInstruction, JalrInstruction}; pub use memory::{ LbInstruction, LbuInstruction, LhInstruction, LhuInstruction, LwInstruction, SbInstruction, ShInstruction, SwInstruction, diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index c4c09629b..b7f50e898 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -89,7 +89,7 @@ mod test { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{Instruction, riscv::test_utils::imm_i}, + instructions::Instruction, scheme::mock_prover::{MOCK_PC_START, MockProver}, }; @@ -110,7 +110,7 @@ mod test { .unwrap() .unwrap(); - let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm_i(3)); + let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, 3); let (raw_witin, lkm) = AddiInstruction::::assign_instances( &config, cb.cs.num_witin as usize, @@ -143,7 +143,8 @@ mod test { .unwrap() .unwrap(); - let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm_i(-3)); + let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, -3); + let (raw_witin, lkm) = AddiInstruction::::assign_instances( &config, cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 9847c66d5..5b53ec904 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -5,7 +5,7 @@ use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, - instructions::{Instruction, riscv::test_utils::imm_b}, + instructions::Instruction, scheme::mock_prover::{MOCK_PC_START, MockProver}, }; @@ -32,7 +32,7 @@ fn impl_opcode_beq(equal: bool) { .unwrap() .unwrap(); - let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, imm_b(8)); + let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, 8); let pc_offset = if equal { 8 } else { PC_STEP_SIZE }; let (raw_witin, lkm) = BeqInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ @@ -70,7 +70,7 @@ fn impl_opcode_bne(equal: bool) { .unwrap() .unwrap(); - let insn_code = encode_rv32(InsnKind::BNE, 2, 3, 0, imm_b(8)); + let insn_code = encode_rv32(InsnKind::BNE, 2, 3, 0, 8); let pc_offset = if equal { PC_STEP_SIZE } else { 8 }; let (raw_witin, lkm) = BneInstruction::assign_instances(&config, cb.cs.num_witin as usize, vec![ @@ -111,8 +111,8 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { MOCK_PC_START + PC_STEP_SIZE }; - let insn_code = encode_rv32(InsnKind::BLTU, 2, 3, 0, imm_b(-8)); - println!("{:#b}", insn_code); + let insn_code = encode_rv32(InsnKind::BLTU, 2, 3, 0, -8); + println!("{:?}", insn_code); let (raw_witin, lkm) = BltuInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( @@ -153,7 +153,7 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { MOCK_PC_START + PC_STEP_SIZE }; - let insn_code = encode_rv32(InsnKind::BGEU, 2, 3, 0, imm_b(-8)); + let insn_code = encode_rv32(InsnKind::BGEU, 2, 3, 0, -8); let (raw_witin, lkm) = BgeuInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( @@ -195,7 +195,7 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { MOCK_PC_START + PC_STEP_SIZE }; - let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, imm_b(-8)); + let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, -8); let (raw_witin, lkm) = BltInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( @@ -237,7 +237,7 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { MOCK_PC_START + PC_STEP_SIZE }; - let insn_code = encode_rv32(InsnKind::BGE, 2, 3, 0, imm_b(-8)); + let insn_code = encode_rv32(InsnKind::BGE, 2, 3, 0, -8); let (raw_witin, lkm) = BgeInstruction::assign_instances(&config, circuit_builder.cs.num_witin as usize, vec![ StepRecord::new_b_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index 1d98d1271..b73c932a1 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -33,13 +33,15 @@ impl Instruction for DummyInstruction, ) -> Result { - let codes = I::INST_KIND.codes(); + let kind = I::INST_KIND; + let format = InsnFormat::from(kind); + let category = InsnCategory::from(kind); // ECALL can do everything. - let is_ecall = matches!(codes.kind, InsnKind::EANY); + let is_ecall = matches!(kind, InsnKind::ECALL); // Regular instructions do what is implied by their format. - let (with_rs1, with_rs2, with_rd) = match codes.format { + let (with_rs1, with_rs2, with_rd) = match format { _ if is_ecall => (true, true, true), InsnFormat::R => (true, true, true), InsnFormat::I => (true, false, true), @@ -48,10 +50,10 @@ impl Instruction for DummyInstruction (false, false, true), InsnFormat::J => (false, false, true), }; - let with_mem_write = matches!(codes.category, InsnCategory::Store) || is_ecall; - let with_mem_read = matches!(codes.category, InsnCategory::Load); - let branching = matches!(codes.category, InsnCategory::Branch) - || matches!(codes.kind, InsnKind::JAL | InsnKind::JALR) + let with_mem_write = matches!(category, InsnCategory::Store) || is_ecall; + let with_mem_read = matches!(category, InsnCategory::Load); + let branching = matches!(category, InsnCategory::Branch) + || matches!(kind, InsnKind::JAL | InsnKind::JALR) || is_ecall; DummyConfig::construct_circuit( @@ -174,7 +176,7 @@ impl DummyConfig { // Fetch instruction // The register IDs of ECALL is fixed, not encoded. - let is_ecall = matches!(kind, InsnKind::EANY); + let is_ecall = matches!(kind, InsnKind::ECALL); let rs1_id = match &rs1 { Some((r, _)) if !is_ecall => r.id.expr(), _ => 0.into(), diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index df1eb0572..5f3a03bb9 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -30,7 +30,7 @@ fn test_dummy_ecall() { .unwrap(); let step = StepRecord::new_ecall_any(4, MOCK_PC_START); - let insn_code = step.insn_code(); + let insn_code = step.insn(); let (raw_witin, lkm) = EcallDummy::assign_instances(&config, cb.cs.num_witin as usize, vec![step]).unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/ecall.rs b/ceno_zkvm/src/instructions/riscv/ecall.rs index 0d7a3315a..261772638 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall.rs @@ -7,7 +7,7 @@ use super::{RIVInstruction, dummy::DummyInstruction}; pub struct EcallOp; impl RIVInstruction for EcallOp { - const INST_KIND: InsnKind = InsnKind::EANY; + const INST_KIND: InsnKind = InsnKind::ECALL; } /// Unsafe. A dummy ecall circuit that ignores unimplemented functions. pub type EcallDummy = DummyInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs index 078223617..8de7e02c8 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs @@ -10,7 +10,7 @@ use crate::{ tables::InsnRecord, witness::LkMultiplicity, }; -use ceno_emul::{InsnKind::EANY, PC_STEP_SIZE, Platform, StepRecord, Tracer}; +use ceno_emul::{InsnKind::ECALL, PC_STEP_SIZE, Platform, StepRecord, Tracer}; use ff_ext::ExtensionField; pub struct EcallInstructionConfig { @@ -38,7 +38,7 @@ impl EcallInstructionConfig { cb.lk_fetch(&InsnRecord::new( pc.expr(), - EANY.into(), + ECALL.into(), None, 0.into(), 0.into(), diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs index 50708d07a..b57aadbbb 100644 --- a/ceno_zkvm/src/instructions/riscv/jump.rs +++ b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -1,12 +1,8 @@ -mod auipc; mod jal; mod jalr; -mod lui; -pub use auipc::AuipcInstruction; pub use jal::JalInstruction; pub use jalr::JalrInstruction; -pub use lui::LuiInstruction; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/jump/auipc.rs b/ceno_zkvm/src/instructions/riscv/jump/auipc.rs deleted file mode 100644 index ed1f07d1f..000000000 --- a/ceno_zkvm/src/instructions/riscv/jump/auipc.rs +++ /dev/null @@ -1,94 +0,0 @@ -use std::marker::PhantomData; - -use ceno_emul::InsnKind; -use ff_ext::ExtensionField; - -use crate::{ - Value, - circuit_builder::CircuitBuilder, - error::ZKVMError, - expression::{ToExpr, WitIn}, - instructions::{ - Instruction, - riscv::{constants::UInt, u_insn::UInstructionConfig}, - }, - set_val, - tables::InsnRecord, - utils::i64_to_base, - witness::LkMultiplicity, -}; - -pub struct AuipcConfig { - pub u_insn: UInstructionConfig, - pub imm: WitIn, - pub overflow_bit: WitIn, - pub rd_written: UInt, -} - -pub struct AuipcInstruction(PhantomData); - -/// AUIPC instruction circuit -impl Instruction for AuipcInstruction { - type InstructionConfig = AuipcConfig; - - fn name() -> String { - format!("{:?}", InsnKind::AUIPC) - } - - fn construct_circuit( - circuit_builder: &mut CircuitBuilder, - ) -> Result, ZKVMError> { - let imm = circuit_builder.create_witin(|| "imm"); - let rd_written = UInt::new(|| "rd_written", circuit_builder)?; - - let u_insn = UInstructionConfig::construct_circuit( - circuit_builder, - InsnKind::AUIPC, - &imm.expr(), - rd_written.register_expr(), - )?; - - let overflow_bit = circuit_builder.create_witin(|| "overflow_bit"); - circuit_builder.assert_bit(|| "is_bit", overflow_bit.expr())?; - - // assert: imm + pc = rd_written + overflow_bit * 2^32 - // valid formulation of mod 2^32 arithmetic because: - // - imm and pc are constrained to 4 bytes by instruction table lookup - // - rd_written is constrained to 4 bytes by UInt checked limbs - circuit_builder.require_equal( - || "imm+pc = rd_written+2^32*overflow", - imm.expr() + u_insn.vm_state.pc.expr(), - rd_written.value() + overflow_bit.expr() * (1u64 << 32), - )?; - - Ok(AuipcConfig { - u_insn, - imm, - overflow_bit, - rd_written, - }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [E::BaseField], - lk_multiplicity: &mut LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), ZKVMError> { - let pc: u32 = step.pc().before.0; - let imm = InsnRecord::imm_internal(&step.insn()); - let (sum, overflow) = pc.overflowing_add(imm as u32); - - set_val!(instance, config.imm, i64_to_base::(imm)); - set_val!(instance, config.overflow_bit, overflow as u64); - - let sum_limbs = Value::new(sum, lk_multiplicity); - config.rd_written.assign_value(instance, sum_limbs); - - config - .u_insn - .assign_instance(instance, lk_multiplicity, step)?; - - Ok(()) - } -} diff --git a/ceno_zkvm/src/instructions/riscv/jump/lui.rs b/ceno_zkvm/src/instructions/riscv/jump/lui.rs deleted file mode 100644 index 8ad9d497e..000000000 --- a/ceno_zkvm/src/instructions/riscv/jump/lui.rs +++ /dev/null @@ -1,64 +0,0 @@ -use std::marker::PhantomData; - -use ceno_emul::InsnKind; -use ff_ext::ExtensionField; - -use crate::{ - Value, - circuit_builder::CircuitBuilder, - error::ZKVMError, - instructions::{ - Instruction, - riscv::{constants::UInt, u_insn::UInstructionConfig}, - }, - witness::LkMultiplicity, -}; - -pub struct LuiConfig { - pub u_insn: UInstructionConfig, - pub rd_written: UInt, -} - -pub struct LuiInstruction(PhantomData); - -/// LUI instruction circuit -impl Instruction for LuiInstruction { - type InstructionConfig = LuiConfig; - - fn name() -> String { - format!("{:?}", InsnKind::LUI) - } - - fn construct_circuit( - circuit_builder: &mut CircuitBuilder, - ) -> Result, ZKVMError> { - let rd_written = UInt::new(|| "rd_limbs", circuit_builder)?; - - // rd_written = imm, so just enforce that U-type immediate from program - // table is equal to rd_written value - let u_insn = UInstructionConfig::construct_circuit( - circuit_builder, - InsnKind::LUI, - &rd_written.value(), // instruction immediate for program table lookup - rd_written.register_expr(), - )?; - - Ok(LuiConfig { u_insn, rd_written }) - } - - fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [E::BaseField], - lk_multiplicity: &mut LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), ZKVMError> { - config - .u_insn - .assign_instance(instance, lk_multiplicity, step)?; - - let rd = Value::new(step.rd().unwrap().value.after, lk_multiplicity); - config.rd_written.assign_limbs(instance, rd.as_u16_limbs()); - - Ok(()) - } -} diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index 4453f4c3a..648b17f66 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -3,14 +3,11 @@ use goldilocks::GoldilocksExt2; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{ - Instruction, - riscv::test_utils::{imm_j, imm_u}, - }, + instructions::Instruction, scheme::mock_prover::{MOCK_PC_START, MockProver}, }; -use super::{AuipcInstruction, JalInstruction, JalrInstruction, LuiInstruction}; +use super::{JalInstruction, JalrInstruction}; #[test] fn test_opcode_jal() { @@ -29,7 +26,7 @@ fn test_opcode_jal() { let pc_offset: i32 = -8i32; let new_pc: ByteAddr = ByteAddr(MOCK_PC_START.0.wrapping_add_signed(pc_offset)); - let insn_code = encode_rv32(InsnKind::JAL, 0, 0, 4, imm_j(pc_offset)); + let insn_code = encode_rv32(InsnKind::JAL, 0, 0, 4, pc_offset); let (raw_witin, lkm) = JalInstruction::::assign_instances( &config, cb.cs.num_witin as usize, @@ -64,7 +61,7 @@ fn test_opcode_jalr() { let imm = -15i32; let rs1_read: Word = 100u32; let new_pc: ByteAddr = ByteAddr(rs1_read.wrapping_add_signed(imm) & (!1)); - let insn_code = encode_rv32(InsnKind::JALR, 2, 0, 4, imm as u32); + let insn_code = encode_rv32(InsnKind::JALR, 2, 0, 4, imm); let (raw_witin, lkm) = JalrInstruction::::assign_instances( &config, @@ -79,72 +76,5 @@ fn test_opcode_jalr() { )], ) .unwrap(); - - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); -} - -#[test] -fn test_opcode_lui() { - let mut cs = ConstraintSystem::::new(|| "riscv"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = cb - .namespace( - || "lui", - |cb| { - let config = LuiInstruction::::construct_circuit(cb); - Ok(config) - }, - ) - .unwrap() - .unwrap(); - - let imm_value = imm_u(0x90005); - let insn_code = encode_rv32(InsnKind::LUI, 0, 0, 4, imm_value); - let (raw_witin, lkm) = LuiInstruction::::assign_instances( - &config, - cb.cs.num_witin as usize, - vec![StepRecord::new_u_instruction( - 4, - MOCK_PC_START, - insn_code, - Change::new(0, imm_value), - 0, - )], - ) - .unwrap(); - - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); -} - -#[test] -fn test_opcode_auipc() { - let mut cs = ConstraintSystem::::new(|| "riscv"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = cb - .namespace( - || "auipc", - |cb| { - let config = AuipcInstruction::::construct_circuit(cb); - Ok(config) - }, - ) - .unwrap() - .unwrap(); - - let imm_value = imm_u(0x90005); - let insn_code = encode_rv32(InsnKind::AUIPC, 0, 0, 4, imm_value); - let (raw_witin, lkm) = AuipcInstruction::::assign_instances( - &config, - cb.cs.num_witin as usize, - vec![StepRecord::new_u_instruction( - 4, - MOCK_PC_START, - insn_code, - Change::new(0, MOCK_PC_START.0.wrapping_add(imm_value)), - 0, - )], - ) - .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index 9328e27a7..4e9bd5faa 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -123,7 +123,7 @@ impl LogicConfig { #[cfg(test)] mod test { - use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32u}; use goldilocks::GoldilocksExt2; use crate::{ @@ -190,7 +190,7 @@ mod test { .unwrap() .unwrap(); - let insn_code = encode_rv32(I::INST_KIND, 2, 0, 4, imm); + let insn_code = encode_rv32u(I::INST_KIND, 2, 0, 4, imm); let (raw_witin, lkm) = LogicInstruction::::assign_instances( &config, cb.cs.num_witin as usize, diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index 6243f197b..4785994d3 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -74,7 +74,7 @@ fn load(mem_value: Word, insn: InsnKind, shift: u32) -> Word { } } -fn impl_opcode_store>(imm: u32) { +fn impl_opcode_store>(imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let config = cb @@ -92,7 +92,7 @@ fn impl_opcode_store sb(prev_mem_value, rs2_word, unaligned_addr.shift()), InsnKind::SH => sh(prev_mem_value, rs2_word, unaligned_addr.shift()), @@ -122,7 +122,7 @@ fn impl_opcode_store>(imm: u32) { +fn impl_opcode_load>(imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); let config = cb @@ -140,7 +140,7 @@ fn impl_opcode_load>(imm) } -fn impl_opcode_sh(imm: u32) { +fn impl_opcode_sh(imm: i32) { assert_eq!(imm & 0x01, 0); impl_opcode_store::>(imm) } -fn impl_opcode_sw(imm: u32) { +fn impl_opcode_sw(imm: i32) { assert_eq!(imm & 0x03, 0); impl_opcode_store::>(imm) } @@ -187,9 +187,8 @@ fn test_sb() { impl_opcode_sb(10); impl_opcode_sb(15); - let neg_one = u32::MAX; - for i in 0..4 { - impl_opcode_sb(neg_one - i); + for i in -4..0 { + impl_opcode_sb(i); } } @@ -198,9 +197,8 @@ fn test_sh() { impl_opcode_sh(0); impl_opcode_sh(2); - let neg_two = u32::MAX - 1; - for i in [0, 2] { - impl_opcode_sh(neg_two - i) + for i in [-4, -2] { + impl_opcode_sh(i) } } @@ -209,8 +207,7 @@ fn test_sw() { impl_opcode_sw(0); impl_opcode_sw(4); - let neg_four = u32::MAX - 3; - impl_opcode_sw(neg_four); + impl_opcode_sw(-4); } #[test] @@ -220,10 +217,8 @@ fn test_lb() { impl_opcode_load::>(2); impl_opcode_load::>(3); - let neg_one = u32::MAX; - // imm = -1, -2, -3 - for i in 0..3 { - impl_opcode_load::>(neg_one - i); + for i in -3..0 { + impl_opcode_load::>(i); } } @@ -234,10 +229,8 @@ fn test_lbu() { impl_opcode_load::>(2); impl_opcode_load::>(3); - let neg_one = u32::MAX; - // imm = -1, -2, -3 - for i in 0..3 { - impl_opcode_load::>(neg_one - i); + for i in -3..0 { + impl_opcode_load::>(i); } } @@ -247,10 +240,8 @@ fn test_lh() { impl_opcode_load::>(2); impl_opcode_load::>(4); - let neg_two = u32::MAX - 1; - // imm = -2, -4 - for i in [0, 2] { - impl_opcode_load::>(neg_two - i); + for i in [-4, -2] { + impl_opcode_load::>(i); } } @@ -260,10 +251,8 @@ fn test_lhu() { impl_opcode_load::>(2); impl_opcode_load::>(4); - let neg_two = u32::MAX - 1; - // imm = -2, -4 - for i in [0, 2] { - impl_opcode_load::>(neg_two - i); + for i in [-4, -2] { + impl_opcode_load::>(i); } } @@ -271,5 +260,5 @@ fn test_lhu() { fn test_lw() { impl_opcode_load::>(0); impl_opcode_load::>(4); - impl_opcode_load::>(u32::MAX - 3); // imm = -4 + impl_opcode_load::>(-4); } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index d404cd073..a00d2c7e4 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -39,10 +39,7 @@ use std::collections::{BTreeMap, BTreeSet}; use strum::IntoEnumIterator; use super::{ - arith::AddInstruction, - branch::BltuInstruction, - ecall::HaltInstruction, - jump::{JalInstruction, LuiInstruction}, + arith::AddInstruction, branch::BltuInstruction, ecall::HaltInstruction, jump::JalInstruction, memory::LwInstruction, }; @@ -88,8 +85,6 @@ pub struct Rv32imConfig { // Jump Opcodes pub jal_config: as Instruction>::InstructionConfig, pub jalr_config: as Instruction>::InstructionConfig, - pub auipc_config: as Instruction>::InstructionConfig, - pub lui_config: as Instruction>::InstructionConfig, // Memory Opcodes pub lw_config: as Instruction>::InstructionConfig, @@ -155,10 +150,8 @@ impl Rv32imConfig { let bgeu_config = cs.register_opcode_circuit::>(); // jump opcodes - let lui_config = cs.register_opcode_circuit::>(); let jal_config = cs.register_opcode_circuit::>(); let jalr_config = cs.register_opcode_circuit::>(); - let auipc_config = cs.register_opcode_circuit::>(); // memory opcodes let lw_config = cs.register_opcode_circuit::>(); @@ -218,10 +211,8 @@ impl Rv32imConfig { bge_config, bgeu_config, // jump opcodes - lui_config, jal_config, jalr_config, - auipc_config, // memory opcodes sw_config, sh_config, @@ -287,8 +278,6 @@ impl Rv32imConfig { // jump fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); - fixed.register_opcode_circuit::>(cs); - fixed.register_opcode_circuit::>(cs); // memory fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); @@ -318,20 +307,19 @@ impl Rv32imConfig { witness: &mut ZKVMWitnesses, steps: Vec, ) -> Result { - let mut all_records: BTreeMap> = InsnKind::iter() - .map(|insn_kind| ((insn_kind as usize), Vec::new())) + let mut all_records: BTreeMap> = InsnKind::iter() + .map(|insn_kind| (insn_kind, Vec::new())) .collect(); let mut halt_records = Vec::new(); steps.into_iter().for_each(|record| { - let insn_kind = record.insn().codes().kind; + let insn_kind = record.insn.kind; match insn_kind { // ecall / halt - EANY if record.rs1().unwrap().value == Platform::ecall_halt() => { + InsnKind::ECALL if record.rs1().unwrap().value == Platform::ecall_halt() => { halt_records.push(record); } // other type of ecalls are handled by dummy ecall instruction _ => { - let insn_kind = insn_kind as usize; // it's safe to unwrap as all_records are initialized with Vec::new() all_records.get_mut(&insn_kind).unwrap().push(record); } @@ -353,7 +341,7 @@ impl Rv32imConfig { witness.assign_opcode_circuit::<$instruction>( cs, &self.$config, - all_records.remove(&($insn_kind as usize)).unwrap(), + all_records.remove(&($insn_kind)).unwrap(), )?; }; } @@ -393,8 +381,6 @@ impl Rv32imConfig { // jump assign_opcode!(JAL, JalInstruction, jal_config); assign_opcode!(JALR, JalrInstruction, jalr_config); - assign_opcode!(AUIPC, AuipcInstruction, auipc_config); - assign_opcode!(LUI, LuiInstruction, lui_config); // memory assign_opcode!(LW, LwInstruction, lw_config); assign_opcode!(LB, LbInstruction, lb_config); @@ -411,9 +397,8 @@ impl Rv32imConfig { assert_eq!( all_records.keys().cloned().collect::>(), // these are opcodes that haven't been implemented - [INVALID, DIV, REM, REMU, EANY] + [INVALID, DIV, REM, REMU, ECALL] .into_iter() - .map(|insn_kind| insn_kind as usize) .collect::>(), ); Ok(GroupedSteps(all_records)) @@ -439,7 +424,7 @@ impl Rv32imConfig { } /// Opaque type to pass unimplemented instructions from Rv32imConfig to DummyExtraConfig. -pub struct GroupedSteps(BTreeMap>); +pub struct GroupedSteps(BTreeMap>); /// Fake version of what is missing in Rv32imConfig, for some tests. pub struct DummyExtraConfig { @@ -487,7 +472,7 @@ impl DummyExtraConfig { witness.assign_opcode_circuit::<$instruction>( cs, &self.$config, - steps.remove(&($insn_kind as usize)).unwrap(), + steps.remove(&($insn_kind)).unwrap(), )?; }; } @@ -495,10 +480,11 @@ impl DummyExtraConfig { assign_opcode!(DIV, DivDummy, div_config); assign_opcode!(REM, RemDummy, rem_config); assign_opcode!(REMU, RemuDummy, remu_config); - assign_opcode!(EANY, EcallDummy, ecall_config); + assign_opcode!(ECALL, EcallDummy, ecall_config); - let _ = steps.remove(&(INVALID as usize)); - assert!(steps.is_empty()); + let _ = steps.remove(&INVALID); + let keys: Vec<&InsnKind> = steps.keys().collect::>(); + assert!(steps.is_empty(), "unimplemented opcodes: {:?}", keys); Ok(()) } } diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 50ebe4ff0..b16dd5bf2 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -188,7 +188,7 @@ impl Instruction for ShiftImmInstructio #[cfg(test)] mod test { - use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_emul::{Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32u}; use goldilocks::GoldilocksExt2; use super::{ShiftImmInstruction, SlliOp, SraiOp, SrliOp}; @@ -255,17 +255,17 @@ mod test { let (prefix, insn_code, rd_written) = match I::INST_KIND { InsnKind::SLLI => ( "SLLI", - encode_rv32(InsnKind::SLLI, 2, 0, 4, imm), + encode_rv32u(InsnKind::SLLI, 2, 0, 4, imm), rs1_read << imm, ), InsnKind::SRAI => ( "SRAI", - encode_rv32(InsnKind::SRAI, 2, 0, 4, imm), + encode_rv32u(InsnKind::SRAI, 2, 0, 4, imm), (rs1_read as i32 >> imm as i32) as u32, ), InsnKind::SRLI => ( "SRLI", - encode_rv32(InsnKind::SRLI, 2, 0, 4, imm), + encode_rv32u(InsnKind::SRLI, 2, 0, 4, imm), rs1_read >> imm, ), _ => unreachable!(), diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index a0766b425..cbab2fe16 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -228,7 +228,7 @@ mod test { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let insn_code = encode_rv32(I::INST_KIND, 2, 0, 4, imm as u32); + let insn_code = encode_rv32(I::INST_KIND, 2, 0, 4, imm); let config = cb .namespace( diff --git a/ceno_zkvm/src/instructions/riscv/test_utils.rs b/ceno_zkvm/src/instructions/riscv/test_utils.rs deleted file mode 100644 index 416ce628c..000000000 --- a/ceno_zkvm/src/instructions/riscv/test_utils.rs +++ /dev/null @@ -1,23 +0,0 @@ -pub fn imm_b(imm: i32) -> u32 { - // imm is 13 bits in B-type - imm_with_max_valid_bits(imm, 13) -} - -pub fn imm_i(imm: i32) -> u32 { - // imm is 12 bits in I-type - imm_with_max_valid_bits(imm, 12) -} - -pub fn imm_j(imm: i32) -> u32 { - // imm is 21 bits in J-type - imm_with_max_valid_bits(imm, 21) -} - -fn imm_with_max_valid_bits(imm: i32, bits: u32) -> u32 { - imm as u32 & !(u32::MAX << bits) -} - -pub fn imm_u(imm: u32) -> u32 { - // valid imm is imm[12:31] in U-type - imm << 12 -} diff --git a/ceno_zkvm/src/instructions/riscv/u_insn.rs b/ceno_zkvm/src/instructions/riscv/u_insn.rs deleted file mode 100644 index 719ee2df1..000000000 --- a/ceno_zkvm/src/instructions/riscv/u_insn.rs +++ /dev/null @@ -1,65 +0,0 @@ -use ceno_emul::{InsnKind, StepRecord}; -use ff_ext::ExtensionField; - -use crate::{ - chip_handler::RegisterExpr, - circuit_builder::CircuitBuilder, - error::ZKVMError, - expression::{Expression, ToExpr}, - instructions::riscv::insn_base::{StateInOut, WriteRD}, - tables::InsnRecord, - witness::LkMultiplicity, -}; - -/// This config handles the common part of the U-type instruction: -/// - PC, cycle, fetch -/// - Register access -/// -/// It does not witness the output rd value or instruction immediate -#[derive(Debug)] -pub struct UInstructionConfig { - pub vm_state: StateInOut, - pub rd: WriteRD, -} - -impl UInstructionConfig { - pub fn construct_circuit( - circuit_builder: &mut CircuitBuilder, - insn_kind: InsnKind, - imm: &Expression, - rd_written: RegisterExpr, - ) -> Result { - // State in and out - let vm_state = StateInOut::construct_circuit(circuit_builder, false)?; - - // Registers - let rd = WriteRD::construct_circuit(circuit_builder, rd_written, vm_state.ts)?; - - // Fetch instruction - circuit_builder.lk_fetch(&InsnRecord::new( - vm_state.pc.expr(), - insn_kind.into(), - Some(rd.id.expr()), - 0.into(), - 0.into(), - imm.clone(), - ))?; - - Ok(UInstructionConfig { vm_state, rd }) - } - - pub fn assign_instance( - &self, - instance: &mut [::BaseField], - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; - self.rd.assign_instance(instance, lk_multiplicity, step)?; - - // Fetch the instruction. - lk_multiplicity.fetch(step.pc().before.0); - - Ok(()) - } -} diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 21b6d4c05..d2ce320ee 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -17,7 +17,7 @@ use crate::{ }; use ark_std::test_rng; use base64::{Engine, engine::general_purpose::STANDARD_NO_PAD}; -use ceno_emul::{ByteAddr, CENO_PLATFORM, PC_WORD_SIZE, Program}; +use ceno_emul::{ByteAddr, CENO_PLATFORM, Program}; use ff::Field; use ff_ext::ExtensionField; use generic_static::StaticTypeMap; @@ -26,7 +26,7 @@ use itertools::{Itertools, enumerate, izip}; use multilinear_extensions::{mle::IntoMLEs, virtual_poly_v2::ArcMultilinearExtension}; use rand::thread_rng; use std::{ - collections::{BTreeMap, HashMap, HashSet}, + collections::{HashMap, HashSet}, fs::File, hash::Hash, io::{BufReader, ErrorKind}, @@ -400,7 +400,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { pub fn run( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], - programs: &[u32], + programs: &[ceno_emul::Instruction], lkm: Option, ) -> Result<(), Vec>> { Self::run_maybe_challenge(cb, wits_in, programs, &[], None, lkm) @@ -409,33 +409,16 @@ impl<'a, E: ExtensionField + Hash> MockProver { fn run_maybe_challenge( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], - input_programs: &[u32], + input_programs: &[ceno_emul::Instruction], pi: &[ArcMultilinearExtension<'a, E>], challenge: Option<[E; 2]>, lkm: Option, ) -> Result<(), Vec>> { - // fix the program table - let instructions = input_programs - .iter() - .cloned() - .chain(std::iter::repeat(0)) - .take(MOCK_PROGRAM_SIZE) - .collect_vec(); - let image = instructions - .iter() - .enumerate() - .map(|(insn_idx, &insn)| { - ( - CENO_PLATFORM.pc_base() + (insn_idx * PC_WORD_SIZE) as u32, - insn, - ) - }) - .collect::>(); let program = Program::new( CENO_PLATFORM.pc_base(), CENO_PLATFORM.pc_base(), - instructions, - image, + input_programs.to_vec(), + Default::default(), ); // load tables @@ -667,7 +650,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { pub fn assert_with_expected_errors( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], - programs: &[u32], + programs: &[ceno_emul::Instruction], constraint_names: &[&str], challenge: Option<[E; 2]>, lkm: Option, @@ -719,7 +702,7 @@ Hints: pub fn assert_satisfied_raw( cb: &CircuitBuilder, raw_witin: RowMajorMatrix, - programs: &[u32], + programs: &[ceno_emul::Instruction], challenge: Option<[E; 2]>, lkm: Option, ) { @@ -734,7 +717,7 @@ Hints: pub fn assert_satisfied( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], - programs: &[u32], + programs: &[ceno_emul::Instruction], challenge: Option<[E; 2]>, lkm: Option, ) { diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 945e49730..8a97ff437 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -3,8 +3,8 @@ use std::marker::PhantomData; use ark_std::test_rng; use ceno_emul::{ CENO_PLATFORM, - InsnKind::{ADD, EANY}, - PC_WORD_SIZE, Platform, Program, StepRecord, VMState, + InsnKind::{ADD, ECALL}, + Platform, Program, StepRecord, VMState, encode_rv32, }; use ff::Field; use ff_ext::ExtensionField; @@ -18,7 +18,6 @@ use transcript::{BasicTranscript, Transcript}; use crate::{ circuit_builder::CircuitBuilder, - declare_program, error::ZKVMError, expression::{ToExpr, WitIn}, instructions::{ @@ -187,23 +186,12 @@ fn test_rw_lk_expression_combination() { test_rw_lk_expression_combination_inner::<17, 61>(); } -const PROGRAM_SIZE: usize = 4; -#[allow(clippy::unusual_byte_groupings)] -const ECALL_HALT: u32 = 0b_000000000000_00000_000_00000_1110011; -#[allow(clippy::unusual_byte_groupings)] -const PROGRAM_CODE: [u32; PROGRAM_SIZE] = { - let mut program: [u32; PROGRAM_SIZE] = [ECALL_HALT; PROGRAM_SIZE]; - - declare_program!( - program, - // func7 rs2 rs1 f3 rd opcode - 0b_0000000_00100_00001_000_00100_0110011, // add x4, x4, x1 <=> addi x4, x4, 1 - ECALL_HALT, // ecall halt - ECALL_HALT, // ecall halt - ECALL_HALT, // ecall halt - ); - program -}; +const PROGRAM_CODE: [ceno_emul::Instruction; 4] = [ + encode_rv32(ADD, 4, 1, 4, 0), + encode_rv32(ECALL, 0, 0, 0, 0), + encode_rv32(ECALL, 0, 0, 0, 0), + encode_rv32(ECALL, 0, 0, 0, 0), +]; #[ignore = "this case is already tested in riscv_example as ecall_halt has only one instance"] #[test] @@ -216,16 +204,7 @@ fn test_single_add_instance_e2e() { CENO_PLATFORM.pc_base(), CENO_PLATFORM.pc_base(), PROGRAM_CODE.to_vec(), - PROGRAM_CODE - .iter() - .enumerate() - .map(|(insn_idx, &insn)| { - ( - (insn_idx * PC_WORD_SIZE) as u32 + CENO_PLATFORM.pc_base(), - insn, - ) - }) - .collect(), + Default::default(), ); let pcs_param = Pcs::setup(1 << MAX_NUM_VARIABLES).expect("Basefold PCS setup"); @@ -271,10 +250,10 @@ fn test_single_add_instance_e2e() { let mut add_records = vec![]; let mut halt_records = vec![]; all_records.into_iter().for_each(|record| { - let kind = record.insn().codes().kind; + let kind = record.insn().kind; match kind { ADD => add_records.push(record), - EANY => { + ECALL => { if record.rs1().unwrap().value == Platform::ecall_halt() { halt_records.push(record); } diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index a5ae2ad7f..5695ee247 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -13,27 +13,13 @@ use crate::{ witness::RowMajorMatrix, }; use ceno_emul::{ - DecodedInstruction, InsnCodes, InsnFormat::*, InsnKind::*, PC_STEP_SIZE, Program, WORD_SIZE, + InsnFormat, InsnFormat::*, InsnKind::*, Instruction, PC_STEP_SIZE, Program, WORD_SIZE, }; use ff_ext::ExtensionField; use goldilocks::SmallField; use itertools::Itertools; use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; -#[macro_export] -macro_rules! declare_program { - ($program:ident, $($instr:expr),* $(,)?) => { - - { - let mut _i = 0; - $( - $program[_i] = $instr; - _i += 1; - )* - } - }; -} - /// This structure establishes the order of the fields in instruction records, common to the program table and circuit fetches. #[derive(Clone, Debug)] pub struct InsnRecord([T; 6]); @@ -43,7 +29,7 @@ impl InsnRecord { where T: From, { - let rd = rd.unwrap_or_else(|| T::from(DecodedInstruction::RD_NULL)); + let rd = rd.unwrap_or_else(|| T::from(Instruction::RD_NULL)); InsnRecord([pc, kind, rd, rs1, rs2, imm_internal]) } @@ -53,10 +39,10 @@ impl InsnRecord { } impl InsnRecord { - fn from_decoded(pc: u32, insn: &DecodedInstruction) -> Self { + fn from_decoded(pc: u32, insn: &Instruction) -> Self { InsnRecord([ (pc as u64).into(), - (insn.codes().kind as u64).into(), + (insn.kind as u64).into(), (insn.rd_internal() as u64).into(), (insn.rs1_or_zero() as u64).into(), (insn.rs2_or_zero() as u64).into(), @@ -73,25 +59,17 @@ impl InsnRecord<()> { /// - `as u32` and `as i32` as usual. /// - `i64_to_base(imm)` gives the field element going into the program table. /// - `as u64` in unsigned cases. - pub fn imm_internal(insn: &DecodedInstruction) -> i64 { - let imm: u32 = insn.immediate(); - match insn.codes() { + pub fn imm_internal(insn: &Instruction) -> i64 { + match (insn.kind, InsnFormat::from(insn.kind)) { // Prepare the immediate for ShiftImmInstruction. // The shift is implemented as a multiplication/division by 1 << immediate. - InsnCodes { - kind: SLLI | SRLI | SRAI, - .. - } => 1 << (imm & 0x1F), + (SLLI | SRLI | SRAI, _) => 1 << insn.imm, // Unsigned view. // For example, u32::MAX is `u32::MAX mod p` in the finite field. - InsnCodes { format: R | U, .. } - | InsnCodes { - kind: ADDI | SLTIU | ANDI | XORI | ORI, - .. - } => imm as u64 as i64, + (_, R | U) | (ADDI | SLTIU | ANDI | XORI | ORI, _) => insn.imm as u32 as i64, // Signed view. // For example, u32::MAX is `-1 mod p` in the finite field. - _ => imm as i32 as i64, + _ => insn.imm as i64, } } } @@ -171,7 +149,7 @@ impl TableCircuit for ProgramTableCircuit { .zip((0..num_instructions).into_par_iter()) .for_each(|(row, i)| { let pc = pc_base + (i * PC_STEP_SIZE) as u32; - let insn = DecodedInstruction::new(program.instructions[i]); + let insn = program.instructions[i]; let values: InsnRecord<_> = InsnRecord::from_decoded(pc, &insn); // Copy all the fields. @@ -222,29 +200,6 @@ mod tests { use ff::Field; use goldilocks::{Goldilocks as F, GoldilocksExt2 as E}; - #[test] - #[allow(clippy::identity_op)] - fn test_decode_imm() { - for (i, expected) in [ - // Example of I-type: ADDI. - // imm | rs1 | funct3 | rd | opcode - (89 << 20 | 1 << 15 | 0b000 << 12 | 1 << 7 | 0x13, 89), - // Shifts get a precomputed power of 2: SLLI, SRLI, SRAI. - (31 << 20 | 1 << 15 | 0b001 << 12 | 1 << 7 | 0x13, 1 << 31), - (31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, 1 << 31), - ( - 1 << 30 | 31 << 20 | 1 << 15 | 0b101 << 12 | 1 << 7 | 0x13, - 1 << 31, - ), - // Example of R-type with funct7: SUB. - // funct7 | rs2 | rs1 | funct3 | rd | opcode - (0x20 << 25 | 1 << 20 | 1 << 15 | 0 << 12 | 1 << 7 | 0x33, 0), - ] { - let imm = InsnRecord::imm_internal(&DecodedInstruction::new(i)); - assert_eq!(imm, expected); - } - } - #[test] fn test_program_padding() { let mut cs = ConstraintSystem::::new(|| "riscv"); diff --git a/clippy.toml b/clippy.toml index 6e64e6f38..87fda2227 100644 --- a/clippy.toml +++ b/clippy.toml @@ -2,9 +2,10 @@ # Eg removing syn is blocked by ark-ff-asm cutting a new release # (https://github.com/arkworks-rs/algebra/issues/813) amongst other things. allowed-duplicate-crates = [ - "syn", - "windows-sys", + "hashbrown", + "itertools", "regex-automata", "regex-syntax", - "itertools", + "syn", + "windows-sys", ]