Skip to content

Commit

Permalink
Merge pull request #62 from robertknight/wasm-simd-fixes
Browse files Browse the repository at this point in the history
Support using WASM impls of `SimdFloat`/`SimdInt` outside of the rten-vecmath crate
  • Loading branch information
robertknight authored Mar 25, 2024
2 parents 8e81747 + 7be2109 commit b19e7bf
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
6 changes: 5 additions & 1 deletion rten-vecmath/src/simd_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
#[cfg(target_arch = "aarch64")]
pub(crate) mod aarch64;

// The wasm module is exposed because it contains wrapper types which are needed
// to use the functionality outside of this crate.
#[cfg(target_arch = "wasm32")]
pub(crate) mod wasm;
pub mod wasm;

#[cfg(target_arch = "x86_64")]
pub(crate) mod x86_64;

Expand Down
24 changes: 24 additions & 0 deletions rten-vecmath/src/simd_vec/wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,47 @@ impl SimdInt for v128i {

const LEN: usize = 4;

#[inline]
unsafe fn splat(val: i32) -> Self {
Self(i32x4_splat(val))
}

#[inline]
unsafe fn gt(self, other: Self) -> Self::Mask {
Self(i32x4_gt(self.0, other.0))
}

#[inline]
unsafe fn blend(self, other: Self, mask: Self::Mask) -> Self {
Self(v128_bitselect(other.0, self.0, mask.0))
}

#[inline]
unsafe fn add(self, rhs: Self) -> Self {
Self(i32x4_add(self.0, rhs.0))
}

#[inline]
unsafe fn sub(self, rhs: Self) -> Self {
Self(i32x4_sub(self.0, rhs.0))
}

#[inline]
unsafe fn shl<const COUNT: i32>(self) -> Self {
Self(i32x4_shl(self.0, COUNT as u32))
}

#[inline]
unsafe fn reinterpret_as_float(self) -> Self::Float {
v128f(self.0)
}

#[inline]
unsafe fn load(ptr: *const i32) -> Self {
Self(v128_load(ptr as *const v128))
}

#[inline]
unsafe fn store(self, ptr: *mut i32) {
v128_store(ptr as *mut v128, self.0)
}
Expand All @@ -65,62 +74,77 @@ impl SimdFloat for v128f {

const LEN: usize = 4;

#[inline]
unsafe fn splat(val: f32) -> Self {
Self(f32x4_splat(val))
}

#[inline]
unsafe fn abs(self) -> Self {
Self(f32x4_abs(self.0))
}

#[inline]
unsafe fn mul_add(self, a: Self, b: Self) -> Self {
Self(f32x4_add(f32x4_mul(self.0, a.0), b.0))
}

#[inline]
unsafe fn sub(self, rhs: Self) -> Self {
Self(f32x4_sub(self.0, rhs.0))
}

#[inline]
unsafe fn add(self, rhs: Self) -> Self {
Self(f32x4_add(self.0, rhs.0))
}

#[inline]
unsafe fn to_int_trunc(self) -> Self::Int {
v128i(i32x4_trunc_sat_f32x4(self.0))
}

#[inline]
unsafe fn mul(self, rhs: Self) -> Self {
Self(f32x4_mul(self.0, rhs.0))
}

#[inline]
unsafe fn div(self, rhs: Self) -> Self {
Self(f32x4_div(self.0, rhs.0))
}

#[inline]
unsafe fn ge(self, rhs: Self) -> Self::Mask {
v128i(f32x4_ge(self.0, rhs.0))
}

#[inline]
unsafe fn le(self, rhs: Self) -> Self::Mask {
v128i(f32x4_le(self.0, rhs.0))
}

#[inline]
unsafe fn lt(self, rhs: Self) -> Self::Mask {
v128i(f32x4_lt(self.0, rhs.0))
}

#[inline]
unsafe fn max(self, rhs: Self) -> Self {
Self(f32x4_max(self.0, rhs.0))
}

#[inline]
unsafe fn blend(self, rhs: Self, mask: Self::Mask) -> Self {
Self(v128_bitselect(rhs.0, self.0, mask.0))
}

#[inline]
unsafe fn load(ptr: *const f32) -> Self {
Self(v128_load(ptr as *const v128))
}

#[inline]
unsafe fn store(self, ptr: *mut f32) {
v128_store(ptr as *mut v128, self.0)
}
Expand Down
1 change: 1 addition & 0 deletions tools/benchmarks/wasm-gemm.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const cases = [
{ m: 512, n: 512, k: 512 }, // Square
{ m: 128, n: 2048, k: 512 }, // Wide
{ m: 2048, n: 128, k: 512 }, // Tall
{ m: 1, n: 4096, k: 512 }, // Vector
];

function logResult(engine, elapsedMs, m, n, k, iters) {
Expand Down

0 comments on commit b19e7bf

Please sign in to comment.