Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AVX2 bignum_copy_row_from_table specialisations #3

Merged
merged 2 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion graviola/src/low/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ cfg_if::cfg_if! {
pub(crate) use x86_64::bignum_add_p384::bignum_add_p384;
pub(crate) use x86_64::bignum_bitsize::bignum_bitsize;
pub(crate) use x86_64::bignum_cmp_lt::bignum_cmp_lt;
pub(crate) use x86_64::bignum_copy_row_from_table::bignum_copy_row_from_table;
pub(crate) use x86_64::bignum_copy_row_from_table_mux::bignum_copy_row_from_table;
pub(crate) use x86_64::bignum_demont::bignum_demont;
pub(crate) use x86_64::bignum_point_select_p256::{bignum_aff_point_select_p256, bignum_jac_point_select_p256};
pub(crate) use x86_64::bignum_point_select_p384::bignum_jac_point_select_p384;
Expand Down
57 changes: 57 additions & 0 deletions graviola/src/low/x86_64/bignum_copy_row_from_table_16_avx2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Written for Graviola by Joe Birr-Pixton, 2024.
// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0

use core::arch::x86_64::*;

pub fn bignum_copy_row_from_table_16_avx2(z: &mut [u64], table: &[u64], _height: u64, index: u64) {
debug_assert!(z.len() == 16);
debug_assert!(index < _height);
debug_assert!(table.len() == (_height as usize) * z.len());

unsafe { _bignum_copy_row_from_table_16_avx2(z, table, index) }
}

#[target_feature(enable = "avx,avx2")]
unsafe fn _bignum_copy_row_from_table_16_avx2(z: &mut [u64], table: &[u64], index: u64) {
_mm_prefetch(table.as_ptr().cast(), _MM_HINT_T0);
_mm_prefetch(table.as_ptr().add(16).cast(), _MM_HINT_T0);

let mut acc0 = _mm256_setzero_si256();
let mut acc1 = _mm256_setzero_si256();
let mut acc2 = _mm256_setzero_si256();
let mut acc3 = _mm256_setzero_si256();

let desired_index = _mm_set1_epi64x(index as i64);
let desired_index = _mm256_setr_m128i(desired_index, desired_index);

let index = _mm_set1_epi64x(0);
let mut index = _mm256_setr_m128i(index, index);

let ones = _mm_set1_epi64x(1);
let ones = _mm256_setr_m128i(ones, ones);

for row in table.chunks_exact(16) {
let mask = _mm256_cmpeq_epi64(index, desired_index);
index = _mm256_add_epi64(index, ones);

let row0 = _mm256_loadu_si256(row.as_ptr().add(0).cast());
let row1 = _mm256_loadu_si256(row.as_ptr().add(4).cast());
let row2 = _mm256_loadu_si256(row.as_ptr().add(8).cast());
let row3 = _mm256_loadu_si256(row.as_ptr().add(12).cast());

let row0 = _mm256_and_si256(row0, mask);
let row1 = _mm256_and_si256(row1, mask);
let row2 = _mm256_and_si256(row2, mask);
let row3 = _mm256_and_si256(row3, mask);

acc0 = _mm256_xor_si256(row0, acc0);
acc1 = _mm256_xor_si256(row1, acc1);
acc2 = _mm256_xor_si256(row2, acc2);
acc3 = _mm256_xor_si256(row3, acc3);
}

_mm256_storeu_si256(z.as_mut_ptr().add(0).cast(), acc0);
_mm256_storeu_si256(z.as_mut_ptr().add(4).cast(), acc1);
_mm256_storeu_si256(z.as_mut_ptr().add(8).cast(), acc2);
_mm256_storeu_si256(z.as_mut_ptr().add(12).cast(), acc3);
}
61 changes: 61 additions & 0 deletions graviola/src/low/x86_64/bignum_copy_row_from_table_8n_avx2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Written for Graviola by Joe Birr-Pixton, 2024.
// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0

use core::arch::x86_64::*;

pub fn bignum_copy_row_from_table_8n_avx2(
z: &mut [u64],
table: &[u64],
_height: u64,
width: u64,
index: u64,
) {
debug_assert!(z.len() as u64 == width);
debug_assert!(width % 8 == 0);
debug_assert!(index < _height);
debug_assert!(table.len() as u64 == _height * width);

unsafe { _bignum_copy_row_from_table_8n_avx2(z, table, width, index) }
}

#[target_feature(enable = "avx,avx2")]
unsafe fn _bignum_copy_row_from_table_8n_avx2(
z: &mut [u64],
table: &[u64],
width: u64,
index: u64,
) {
_mm_prefetch(table.as_ptr().cast(), _MM_HINT_T0);
_mm_prefetch(table.as_ptr().add(16).cast(), _MM_HINT_T0);

z.fill(0);

let desired_index = _mm_set1_epi64x(index as i64);
let desired_index = _mm256_setr_m128i(desired_index, desired_index);

let index = _mm_set1_epi64x(0);
let mut index = _mm256_setr_m128i(index, index);

let ones = _mm_set1_epi64x(1);
let ones = _mm256_setr_m128i(ones, ones);

for row in table.chunks_exact(width as usize) {
let mask = _mm256_cmpeq_epi64(index, desired_index);
index = _mm256_add_epi64(index, ones);

for (i, zz) in z.chunks_exact_mut(8).enumerate() {
let row0 = _mm256_loadu_si256(row.as_ptr().add(i * 8).cast());
let row1 = _mm256_loadu_si256(row.as_ptr().add(i * 8 + 4).cast());

let row0 = _mm256_and_si256(row0, mask);
let row1 = _mm256_and_si256(row1, mask);

let store0 = _mm256_loadu_si256(zz.as_ptr().add(0).cast());
let store1 = _mm256_loadu_si256(zz.as_ptr().add(4).cast());
let store0 = _mm256_xor_si256(store0, row0);
let store1 = _mm256_xor_si256(store1, row1);
_mm256_storeu_si256(zz.as_mut_ptr().add(0).cast(), store0);
_mm256_storeu_si256(zz.as_mut_ptr().add(4).cast(), store1);
}
}
}
26 changes: 26 additions & 0 deletions graviola/src/low/x86_64/bignum_copy_row_from_table_mux.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Written for Graviola by Joe Birr-Pixton, 2024.
// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0

/// Multiplex between specialisations of `bignum_copy_row_from_table`
#[inline]
pub fn bignum_copy_row_from_table(
z: &mut [u64],
table: &[u64],
height: u64,
width: u64,
index: u64,
) {
match width {
16 => super::bignum_copy_row_from_table_16_avx2::bignum_copy_row_from_table_16_avx2(
z, table, height, index,
),
width if width % 8 == 0 => {
super::bignum_copy_row_from_table_8n_avx2::bignum_copy_row_from_table_8n_avx2(
z, table, height, width, index,
)
}
width => super::bignum_copy_row_from_table::bignum_copy_row_from_table(
z, table, height, width, index,
),
}
}
3 changes: 3 additions & 0 deletions graviola/src/low/x86_64/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ pub(crate) mod bignum_add_p384;
pub(crate) mod bignum_bitsize;
pub(crate) mod bignum_cmp_lt;
pub(crate) mod bignum_copy_row_from_table;
pub(crate) mod bignum_copy_row_from_table_16_avx2;
pub(crate) mod bignum_copy_row_from_table_8n_avx2;
pub(crate) mod bignum_copy_row_from_table_mux;
pub(crate) mod bignum_demont;
pub(crate) mod bignum_demont_p256;
pub(crate) mod bignum_demont_p384;
Expand Down