Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(meq): use simd to improve performance of meq #884

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions fuel-vm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ name = "execution"
harness = false
required-features = ["std"]

[[bench]]
name = "meq_performance"
harness = false
required-features = ["std"]

[dependencies]
anyhow = { version = "1.0", optional = true }
async-trait = "0.1"
Expand Down Expand Up @@ -110,6 +115,7 @@ test-helpers = [
"tai64",
"fuel-crypto/test-helpers",
]
experimental = []

[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fuzzing)'] }
95 changes: 95 additions & 0 deletions fuel-vm/benches/meq_performance.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use criterion::{
criterion_group,
criterion_main,
Criterion,
};
use fuel_asm::{
op,
Instruction,
RegId,
};
use fuel_tx::{
Finalizable,
GasCosts,
Script,
TransactionBuilder,
};
use fuel_types::{
Immediate12,
Word,
};
use fuel_vm::{
interpreter::{
Interpreter,
InterpreterParams,
},
prelude::{
IntoChecked,
MemoryInstance,
MemoryStorage,
},
};

/// from; fuel-vm/src/tests/test_helpers.rs
/// Set a register `r` to a Word-sized number value using left-shifts
pub fn set_full_word(r: RegId, v: Word) -> Vec<Instruction> {
let r = r.to_u8();
let mut ops = vec![op::movi(r, 0)];
for byte in v.to_be_bytes() {
ops.push(op::ori(r, r, byte as Immediate12));
ops.push(op::slli(r, r, 8));
}
ops.pop().unwrap(); // Remove last shift
ops
}

fn meq_performance(c: &mut Criterion) {
let benchmark_matrix = [
1, 10, 100, 1000, 10_000, 50_000, 100_000, 500_000, 1_000_000, 2_000_000,
2_500_000, 5_000_000, 10_000_000, 15_000_000, 20_000_000,
// some exact multiples of 8 to verify alignment perf
8, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072,
262144, 524288, 1048576, 2097152, 4194304, 8388608,
];

for size in benchmark_matrix.iter() {
let mut interpreter = Interpreter::<_, _, Script>::with_storage(
MemoryInstance::new(),
MemoryStorage::default(),
InterpreterParams {
gas_costs: GasCosts::free(),
..Default::default()
},
);

let reg_len = RegId::new_checked(0x13).unwrap();

let mut script = set_full_word(reg_len, *size as Word);
script.extend(vec![
op::cfe(0x13),
op::meq(RegId::WRITABLE, RegId::ZERO, RegId::ZERO, reg_len),
op::jmpb(RegId::ZERO, 0),
]);

let tx_builder_script =
TransactionBuilder::script(script.into_iter().collect(), vec![])
.max_fee_limit(0)
.add_fee_input()
.finalize();
let script = tx_builder_script
.into_checked_basic(Default::default(), &Default::default())
.unwrap();
let script = script.test_into_ready();

interpreter.init_script(script).unwrap();

c.bench_function(&format!("meq_performance_{}", size), |b| {
b.iter(|| {
interpreter.execute().unwrap();
});
});
}
}

criterion_group!(benches, meq_performance);
criterion_main!(benches);
270 changes: 269 additions & 1 deletion fuel-vm/src/interpreter/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,272 @@ pub(crate) fn memcopy(
Ok(inc_pc(pc)?)
}

#[cfg(feature = "experimental")]
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
fn slices_equal_neon(a: &[u8], b: &[u8]) -> bool {
use std::arch::aarch64::*;

if a.len() != b.len() {
return false;
}

let len = a.len();
let mut i = 0;
const CHUNK: usize = 96;

// if the slices are small, we don't need to
// use SIMD instructions due to overhead
if a.len() < CHUNK {
return slices_equal_fallback(a, b);
}

unsafe {
while i + CHUNK <= len {
let mut cmp =
vceqq_u8(vld1q_u8(a.as_ptr().add(i)), vld1q_u8(b.as_ptr().add(i)));

cmp = vandq_u8(
cmp,
vceqq_u8(
vld1q_u8(a.as_ptr().add(i + 16)),
vld1q_u8(b.as_ptr().add(i + 16)),
),
);
cmp = vandq_u8(
cmp,
vceqq_u8(
vld1q_u8(a.as_ptr().add(i + 32)),
vld1q_u8(b.as_ptr().add(i + 32)),
),
);
cmp = vandq_u8(
cmp,
vceqq_u8(
vld1q_u8(a.as_ptr().add(i + 48)),
vld1q_u8(b.as_ptr().add(i + 48)),
),
);
cmp = vandq_u8(
cmp,
vceqq_u8(
vld1q_u8(a.as_ptr().add(i + 64)),
vld1q_u8(b.as_ptr().add(i + 64)),
),
);
cmp = vandq_u8(
cmp,
vceqq_u8(
vld1q_u8(a.as_ptr().add(i + 80)),
vld1q_u8(b.as_ptr().add(i + 80)),
),
);

if vmaxvq_u8(cmp) != 0xFF {
return false;
}

i += CHUNK;
}

// Scalar comparison for the remainder
a[i..] == b[i..]
}
}

#[cfg(feature = "experimental")]
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
fn slices_equal_avx2(a: &[u8], b: &[u8]) -> bool {
use std::arch::x86_64::*;

if a.len() != b.len() {
return false;
}

let len = a.len();
let mut i = 0;
const CHUNK: usize = 256;

// if the slices are small, we don't need to
// use SIMD instructions due to overhead
if a.len() < CHUNK {
return slices_equal_fallback(a, b);
}

unsafe {
let mut aggregate_mask_a = -1i32;
let mut aggregate_mask_b = -1i32;
let mut aggregate_mask_c = -1i32;
let mut aggregate_mask_d = -1i32;
let mut aggregate_mask_a_b = -1i32;
let mut aggregate_mask_c_d = -1i32;

while i + CHUNK <= len {
let simd_a1 = _mm256_loadu_si256(a.as_ptr().add(i) as *const _);
let simd_b1 = _mm256_loadu_si256(b.as_ptr().add(i) as *const _);

let simd_a2 = _mm256_loadu_si256(a.as_ptr().add(i + 32) as *const _);
let simd_b2 = _mm256_loadu_si256(b.as_ptr().add(i + 32) as *const _);

let simd_a3 = _mm256_loadu_si256(a.as_ptr().add(i + 64) as *const _);
let simd_b3 = _mm256_loadu_si256(b.as_ptr().add(i + 64) as *const _);

let simd_a4 = _mm256_loadu_si256(a.as_ptr().add(i + 96) as *const _);
let simd_b4 = _mm256_loadu_si256(b.as_ptr().add(i + 96) as *const _);

let simd_a5 = _mm256_loadu_si256(a.as_ptr().add(i + 128) as *const _);
let simd_b5 = _mm256_loadu_si256(b.as_ptr().add(i + 128) as *const _);

let simd_a6 = _mm256_loadu_si256(a.as_ptr().add(i + 160) as *const _);
let simd_b6 = _mm256_loadu_si256(b.as_ptr().add(i + 160) as *const _);

let simd_a7 = _mm256_loadu_si256(a.as_ptr().add(i + 192) as *const _);
let simd_b7 = _mm256_loadu_si256(b.as_ptr().add(i + 192) as *const _);

let simd_a8 = _mm256_loadu_si256(a.as_ptr().add(i + 224) as *const _);
let simd_b8 = _mm256_loadu_si256(b.as_ptr().add(i + 224) as *const _);

let cmp1 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a1, simd_b1));
let cmp2 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a2, simd_b2));
let cmp3 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a3, simd_b3));
let cmp4 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a4, simd_b4));
let cmp5 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a5, simd_b5));
let cmp6 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a6, simd_b6));
let cmp7 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a7, simd_b7));
let cmp8 = _mm256_movemask_epi8(_mm256_cmpeq_epi8(simd_a8, simd_b8));

aggregate_mask_a &= cmp1 & cmp2;
aggregate_mask_b &= cmp3 & cmp4;
aggregate_mask_c &= cmp5 & cmp6;
aggregate_mask_d &= cmp7 & cmp8;

aggregate_mask_a_b &= aggregate_mask_a & aggregate_mask_b;
aggregate_mask_c_d &= aggregate_mask_c & aggregate_mask_d;

if aggregate_mask_a_b & aggregate_mask_c_d != -1i32 {
return false;
}

i += CHUNK;
}

a[i..] == b[i..]
}
}

#[cfg(feature = "experimental")]
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
fn slices_equal_avx512(a: &[u8], b: &[u8]) -> bool {
use std::arch::x86_64::*;

if a.len() != b.len() {
return false;
}

let len = a.len();
let mut i = 0;
const CHUNK: usize = 512;

// if the slices are small, we don't need to
// use SIMD instructions due to overhead
if a.len() < CHUNK {
return slices_equal_fallback(a, b);
}

unsafe {
while i + CHUNK <= len {
let simd_a1 = _mm512_loadu_si512(a.as_ptr().add(i) as *const _);
let simd_b1 = _mm512_loadu_si512(b.as_ptr().add(i) as *const _);

let simd_a2 = _mm512_loadu_si512(a.as_ptr().add(i + 64) as *const _);
let simd_b2 = _mm512_loadu_si512(b.as_ptr().add(i + 64) as *const _);

let simd_a3 = _mm512_loadu_si512(a.as_ptr().add(i + 128) as *const _);
let simd_b3 = _mm512_loadu_si512(b.as_ptr().add(i + 128) as *const _);

let simd_a4 = _mm512_loadu_si512(a.as_ptr().add(i + 192) as *const _);
let simd_b4 = _mm512_loadu_si512(b.as_ptr().add(i + 192) as *const _);

let simd_a5 = _mm512_loadu_si512(a.as_ptr().add(i + 256) as *const _);
let simd_b5 = _mm512_loadu_si512(b.as_ptr().add(i + 256) as *const _);

let simd_a6 = _mm512_loadu_si512(a.as_ptr().add(i + 320) as *const _);
let simd_b6 = _mm512_loadu_si512(b.as_ptr().add(i + 320) as *const _);

let simd_a7 = _mm512_loadu_si512(a.as_ptr().add(i + 384) as *const _);
let simd_b7 = _mm512_loadu_si512(b.as_ptr().add(i + 384) as *const _);

let simd_a8 = _mm512_loadu_si512(a.as_ptr().add(i + 448) as *const _);
let simd_b8 = _mm512_loadu_si512(b.as_ptr().add(i + 448) as *const _);

let cmp1 = _mm512_cmpeq_epi8_mask(simd_a1, simd_b1);
let cmp2 = _mm512_cmpeq_epi8_mask(simd_a2, simd_b2);
let cmp3 = _mm512_cmpeq_epi8_mask(simd_a3, simd_b3);
let cmp4 = _mm512_cmpeq_epi8_mask(simd_a4, simd_b4);
let cmp5 = _mm512_cmpeq_epi8_mask(simd_a5, simd_b5);
let cmp6 = _mm512_cmpeq_epi8_mask(simd_a6, simd_b6);
let cmp7 = _mm512_cmpeq_epi8_mask(simd_a7, simd_b7);
let cmp8 = _mm512_cmpeq_epi8_mask(simd_a8, simd_b8);

let cmp1_2 = cmp1 & cmp2;
let cmp3_4 = cmp3 & cmp4;
let cmp5_6 = cmp5 & cmp6;
let cmp7_8 = cmp7 & cmp8;

let cmp1_4 = cmp1_2 & cmp3_4;
let cmp5_8 = cmp5_6 & cmp7_8;

let full_cmp = cmp1_4 & cmp5_8;

if full_cmp != u64::MAX {
return false;
}

i += CHUNK_SIZE;
}

a[i..] == b[i..]
}
}

#[inline]
fn slices_equal_fallback(a: &[u8], b: &[u8]) -> bool {
a == b
}

#[inline]
fn slice_eq(a: &[u8], b: &[u8]) -> bool {
#[cfg(feature = "experimental")]
{
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
{
return slices_equal_avx512(a, b);
}
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
return slices_equal_avx2(a, b);
}
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
{
return slices_equal_neon(a, b);
}

#[allow(unreachable_code)]
slices_equal_fallback(a, b)
}
#[cfg(not(feature = "experimental"))]
{
slices_equal_fallback(a, b)
}
}

#[test]
fn slice_eq_test() {
let a = [1u8; 20000];
let b = [1u8; 20000];

assert!(slice_eq(&a, &b));
}

pub(crate) fn memeq(
memory: &mut MemoryInstance,
result: &mut Word,
Expand All @@ -1031,7 +1297,9 @@ pub(crate) fn memeq(
c: Word,
d: Word,
) -> SimpleResult<()> {
*result = (memory.read(b, d)? == memory.read(c, d)?) as Word;
let range_a = memory.read(b, d)?;
let range_b = memory.read(c, d)?;
*result = slice_eq(range_a, range_b) as Word;
Ok(inc_pc(pc)?)
}

Expand Down
Loading
Loading