Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: reduce number of recursion shapes #1721

Open
wants to merge 9 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ jobs:
args: --release -p sp1-perf -- --program workdir/program.bin --stdin workdir/stdin.bin --mode cpu
env:
RUST_LOG: info
VERIFY_VK: false
RUSTFLAGS: -Copt-level=3 -Ctarget-cpu=native
RUST_BACKTRACE: 1

Expand Down Expand Up @@ -120,6 +121,7 @@ jobs:
args: --release -p sp1-perf -- --program workdir/program.bin --stdin workdir/stdin.bin --mode cuda
env:
RUST_LOG: debug
VERIFY_VK: false
RUSTFLAGS: -Copt-level=3 -Ctarget-cpu=native
RUST_BACKTRACE: 1
SP1_PROVER: cuda
Expand Down Expand Up @@ -172,6 +174,7 @@ jobs:
args: --release -p sp1-perf --features "native-gnark,network-v2" -- --program workdir/program.bin --stdin workdir/stdin.bin --mode network
env:
RUST_LOG: info
VERIFY_VK: false
RUSTFLAGS: -Copt-level=3 -Ctarget-cpu=native
RUST_BACKTRACE: 1
SP1_PROVER: network
Expand Down
2 changes: 1 addition & 1 deletion crates/cuda/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl SP1CudaProver {
/// [SP1ProverClient] that can be used to communicate with the container.
pub fn new() -> Result<Self, Box<dyn StdError>> {
let container_name = "sp1-gpu";
let image_name = "public.ecr.aws/succinct-labs/sp1-gpu:7e66232";
let image_name = "public.ecr.aws/succinct-labs/sp1-gpu:3c231e2";

let cleaned_up = Arc::new(AtomicBool::new(false));
let cleanup_name = container_name;
Expand Down
116 changes: 70 additions & 46 deletions crates/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use std::{

use lru::LruCache;

use shapes::SP1ProofShape;
use tracing::instrument;

use p3_baby_bear::BabyBear;
Expand Down Expand Up @@ -102,7 +103,6 @@ const SHRINK_DEGREE: usize = 3;
const WRAP_DEGREE: usize = 9;

const CORE_CACHE_SIZE: usize = 5;
const COMPRESS_CACHE_SIZE: usize = 3;
pub const REDUCE_BATCH_SIZE: usize = 2;

// TODO: FIX
Expand Down Expand Up @@ -135,8 +135,7 @@ pub struct SP1Prover<C: SP1ProverComponents = DefaultProverComponents> {

pub recursion_cache_misses: AtomicUsize,

pub compress_programs:
Mutex<LruCache<SP1CompressWithVkeyShape, Arc<RecursionProgram<BabyBear>>>>,
pub compress_programs: BTreeMap<SP1CompressWithVkeyShape, Arc<RecursionProgram<BabyBear>>>,

pub compress_cache_misses: AtomicUsize,

Expand Down Expand Up @@ -188,14 +187,6 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
)
.expect("PROVER_CORE_CACHE_SIZE must be a non-zero usize");

let compress_cache_size = NonZeroUsize::new(
env::var("PROVER_COMPRESS_CACHE_SIZE")
.unwrap_or_else(|_| CORE_CACHE_SIZE.to_string())
.parse()
.unwrap_or(COMPRESS_CACHE_SIZE),
)
.expect("PROVER_COMPRESS_CACHE_SIZE must be a non-zero usize");

let core_shape_config = env::var("FIX_CORE_SHAPES")
.map(|v| v.eq_ignore_ascii_case("true"))
.unwrap_or(true)
Expand All @@ -220,14 +211,36 @@ impl<C: SP1ProverComponents> SP1Prover<C> {

let (root, merkle_tree) = MerkleTree::commit(allowed_vk_map.keys().copied().collect());

let mut compress_programs = BTreeMap::new();
if let Some(config) = &recursion_shape_config {
SP1ProofShape::generate_compress_shapes(config, 2).for_each(|shape| {
let compress_shape = SP1CompressWithVkeyShape {
compress_shape: SP1CompressShape { proof_shapes: shape },
merkle_tree_height: merkle_tree.height,
};
let input = SP1CompressWithVKeyWitnessValues::dummy(
compress_prover.machine(),
&compress_shape,
);
let program = compress_program_from_input::<C>(
recursion_shape_config.as_ref(),
&compress_prover,
vk_verification,
&input,
);
let program = Arc::new(program);
compress_programs.insert(compress_shape, program);
});
}

Self {
core_prover,
compress_prover,
shrink_prover,
wrap_prover,
recursion_programs: Mutex::new(LruCache::new(core_cache_size)),
recursion_cache_misses: AtomicUsize::new(0),
compress_programs: Mutex::new(LruCache::new(compress_cache_size)),
compress_programs,
compress_cache_misses: AtomicUsize::new(0),
vk_root: root,
vk_merkle_tree: merkle_tree,
Expand Down Expand Up @@ -355,40 +368,19 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
&self,
input: &SP1CompressWithVKeyWitnessValues<InnerSC>,
) -> Arc<RecursionProgram<BabyBear>> {
let mut cache = self.compress_programs.lock().unwrap_or_else(|e| e.into_inner());
cache
.get_or_insert(input.shape(), || {
let misses = self.compress_cache_misses.fetch_add(1, Ordering::Relaxed);
tracing::debug!("compress cache miss, misses: {}", misses);
// Get the operations.
let builder_span = tracing::debug_span!("build compress program").entered();
let mut builder = Builder::<InnerConfig>::default();

// read the input.
let input = input.read(&mut builder);
// Verify the proof.
SP1CompressWithVKeyVerifier::verify(
&mut builder,
self.compress_prover.machine(),
input,
self.vk_verification,
PublicValuesOutputDigest::Reduce,
);
let operations = builder.into_operations();
builder_span.exit();

// Compile the program.
let compiler_span = tracing::debug_span!("compile compress program").entered();
let mut compiler = AsmCompiler::<InnerConfig>::default();
let mut program = compiler.compile(operations);
if let Some(recursion_shape_config) = &self.recursion_shape_config {
recursion_shape_config.fix_shape(&mut program);
}
let program = Arc::new(program);
compiler_span.exit();
program
})
.clone()
if self.recursion_shape_config.is_some() {
self.compress_programs.get(&input.shape()).map(Clone::clone).unwrap()
} else {
let misses = self.compress_cache_misses.fetch_add(1, Ordering::Relaxed);
tracing::debug!("compress cache miss, misses: {}", misses);
// Compile the program if the recursion shape config is None.
Arc::new(compress_program_from_input::<C>(
self.recursion_shape_config.as_ref(),
&self.compress_prover,
self.vk_verification,
input,
))
}
}

pub fn shrink_program(
Expand Down Expand Up @@ -1217,6 +1209,38 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
}
}

/// Builds and compiles the recursion program that verifies the given compress witness.
///
/// The work happens in two traced phases:
/// 1. "build compress program": the witness is read into a fresh [`Builder`] and
///    `SP1CompressWithVKeyVerifier::verify` is run with `PublicValuesOutputDigest::Reduce`.
/// 2. "compile compress program": the recorded operations are compiled by [`AsmCompiler`].
///
/// # Arguments
/// * `config` - optional recursion shape config; when present, `fix_shape` is applied to the
///   compiled program before it is returned.
/// * `compress_prover` - supplies the machine handed to the verifier.
/// * `vk_verification` - forwarded unchanged to the verifier.
/// * `input` - the witness values the program is built from.
///
/// # Returns
/// The compiled program, owned by the caller (not wrapped in an `Arc`).
pub fn compress_program_from_input<C: SP1ProverComponents>(
    config: Option<&RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>>,
    compress_prover: &C::CompressProver,
    vk_verification: bool,
    input: &SP1CompressWithVKeyWitnessValues<BabyBearPoseidon2>,
) -> RecursionProgram<BabyBear> {
    // Phase 1: record the verifier circuit as a list of operations.
    let build_guard = tracing::debug_span!("build compress program").entered();
    let mut circuit_builder = Builder::<InnerConfig>::default();
    // Load the witness into builder variables.
    let witness = input.read(&mut circuit_builder);
    // Emit the proof-verification constraints.
    SP1CompressWithVKeyVerifier::verify(
        &mut circuit_builder,
        compress_prover.machine(),
        witness,
        vk_verification,
        PublicValuesOutputDigest::Reduce,
    );
    let ops = circuit_builder.into_operations();
    // Dropping the entered guard exits the span.
    drop(build_guard);

    // Phase 2: compile the operations into a recursion program.
    let compile_guard = tracing::debug_span!("compile compress program").entered();
    let mut asm_compiler = AsmCompiler::<InnerConfig>::default();
    let mut compiled = asm_compiler.compile(ops);
    // Only adjust the program's shape when a shape config was supplied.
    if let Some(shape_config) = config {
        shape_config.fix_shape(&mut compiled);
    }
    drop(compile_guard);

    compiled
}
#[cfg(any(test, feature = "export-tests"))]
pub mod tests {

Expand Down
8 changes: 8 additions & 0 deletions crates/prover/src/shapes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ impl SP1ProofShape {
)
}

/// Enumerates the proof-shape combinations for compress programs.
///
/// For every batch size from `1` through `reduce_batch_size` (inclusive, in increasing
/// order), yields each item produced by
/// `recursion_shape_config.get_all_shape_combinations(batch_size)`.
///
/// The returned iterator is lazy and borrows `recursion_shape_config`.
pub fn generate_compress_shapes(
    recursion_shape_config: &RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
    reduce_batch_size: usize,
) -> impl Iterator<Item = Vec<ProofShape>> + '_ {
    // All combinations of allowed shapes for one particular batch size.
    let combinations_for =
        move |arity: usize| recursion_shape_config.get_all_shape_combinations(arity);

    (1..=reduce_batch_size).flat_map(combinations_for)
}

pub fn dummy_vk_map<'a>(
core_shape_config: &'a CoreShapeConfig<BabyBear>,
recursion_shape_config: &'a RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
Expand Down
2 changes: 1 addition & 1 deletion crates/recursion/circuit/src/machine/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub struct SP1CompressWitnessValues<SC: StarkGenericConfig> {

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SP1CompressShape {
proof_shapes: Vec<ProofShape>,
pub proof_shapes: Vec<ProofShape>,
}

impl<C, SC, A> SP1CompressVerifier<C, SC, A>
Expand Down
114 changes: 7 additions & 107 deletions crates/recursion/core/src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub struct RecursionShape {
}

pub struct RecursionShapeConfig<F, A> {
allowed_shapes: Vec<HashMap<String, usize>>,
pub allowed_shapes: Vec<HashMap<String, usize>>,
_marker: PhantomData<(F, A)>,
}

Expand Down Expand Up @@ -101,122 +101,22 @@ impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> Default
// Specify allowed shapes.
let allowed_shapes = [
[
(base_alu.clone(), 20),
(mem_var.clone(), 18),
(ext_alu.clone(), 18),
(exp_reverse_bits_len.clone(), 17),
(mem_const.clone(), 17),
(poseidon2_wide.clone(), 16),
(select.clone(), 18),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 20),
(mem_var.clone(), 18),
(ext_alu.clone(), 18),
(exp_reverse_bits_len.clone(), 17),
(mem_const.clone(), 16),
(poseidon2_wide.clone(), 16),
(select.clone(), 18),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(ext_alu.clone(), 20),
(base_alu.clone(), 19),
(mem_var.clone(), 19),
(poseidon2_wide.clone(), 17),
(mem_const.clone(), 16),
(exp_reverse_bits_len.clone(), 16),
(select.clone(), 18),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 19),
(mem_var.clone(), 18),
(ext_alu.clone(), 18),
(exp_reverse_bits_len.clone(), 17),
(mem_const.clone(), 16),
(poseidon2_wide.clone(), 16),
(select.clone(), 18),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 19),
(mem_var.clone(), 18),
(ext_alu.clone(), 18),
(exp_reverse_bits_len.clone(), 16),
(mem_const.clone(), 16),
(poseidon2_wide.clone(), 16),
(select.clone(), 18),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 20),
(ext_alu.clone(), 21),
(base_alu.clone(), 16),
(mem_var.clone(), 19),
(ext_alu.clone(), 19),
(exp_reverse_bits_len.clone(), 17),
(mem_const.clone(), 17),
(poseidon2_wide.clone(), 17),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 21),
(mem_var.clone(), 19),
(ext_alu.clone(), 19),
(exp_reverse_bits_len.clone(), 18),
(mem_const.clone(), 18),
(poseidon2_wide.clone(), 17),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 21),
(mem_var.clone(), 19),
(ext_alu.clone(), 19),
(exp_reverse_bits_len.clone(), 18),
(mem_const.clone(), 17),
(poseidon2_wide.clone(), 17),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(ext_alu.clone(), 21),
(base_alu.clone(), 20),
(mem_var.clone(), 20),
(poseidon2_wide.clone(), 18),
(mem_const.clone(), 17),
(exp_reverse_bits_len.clone(), 17),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 20),
(mem_var.clone(), 19),
(ext_alu.clone(), 19),
(exp_reverse_bits_len.clone(), 18),
(mem_const.clone(), 17),
(poseidon2_wide.clone(), 17),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 20),
(mem_var.clone(), 19),
(ext_alu.clone(), 19),
(exp_reverse_bits_len.clone(), 17),
(mem_const.clone(), 17),
(poseidon2_wide.clone(), 17),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
[
(base_alu.clone(), 21),
(mem_var.clone(), 20),
(ext_alu.clone(), 20),
(exp_reverse_bits_len.clone(), 18),
(base_alu.clone(), 16),
(mem_var.clone(), 19),
(poseidon2_wide.clone(), 16),
(mem_const.clone(), 18),
(poseidon2_wide.clone(), 18),
(exp_reverse_bits_len.clone(), 18),
(select.clone(), 19),
(public_values.clone(), PUB_VALUES_LOG_HEIGHT),
],
Expand Down
Loading