diff --git a/src/arm/64/cdef_dist.S b/src/arm/64/cdef_dist.S
index 8a0f3404d2..afc66f9897 100644
--- a/src/arm/64/cdef_dist.S
+++ b/src/arm/64/cdef_dist.S
@@ -135,3 +135,116 @@ L(cdk_8x8):
     CDEF_DIST_REFINE
     ret
 endfunc
+
+// v0: tmp register
+// v1: src input
+// v2: dst input
+// v3 = sum(src_{i,j})
+// v4 = sum(src_{i,j}^2)
+// v5 = sum(dst_{i,j})
+// v6 = sum(dst_{i,j}^2)
+// v7 = sum(src_{i,j} * dst_{i,j})
+// v16: zero register
+.macro CDEF_DIST_HBD_W8
+    uabal   v3.4s, v1.4h, v16.4h    // sum pixel values
+    uabal2  v3.4s, v1.8h, v16.8h
+    umlal   v4.4s, v1.4h, v1.4h     // square and accumulate
+    umlal2  v4.4s, v1.8h, v1.8h
+    uabal   v5.4s, v2.4h, v16.4h    // same as above, but for dst
+    uabal2  v5.4s, v2.8h, v16.8h
+    umlal   v6.4s, v2.4h, v2.4h
+    umlal2  v6.4s, v2.8h, v2.8h
+    umlal   v7.4s, v1.4h, v2.4h     // src_{i,j} * dst_{i,j}
+    umlal2  v7.4s, v1.8h, v2.8h
+.endm
+
+.macro CDEF_DIST_HBD_REFINE shift=0
+    addv    s3, v3.4s               // s3: sum(src_{i,j})
+    umull   v3.2d, v3.2s, v3.2s     // square the sum into a 64-bit lane
+    urshr   d3, d3, #(6-\shift)     // d3: sum(src_{i,j})^2 / N
+    uaddlv  d4, v4.4s               // d4: sum(src_{i,j}^2)
+    addv    s5, v5.4s               // s5: sum(dst_{i,j})
+    umull   v5.2d, v5.2s, v5.2s     // square the sum into a 64-bit lane
+    urshr   d5, d5, #(6-\shift)     // d5: sum(dst_{i,j})^2 / N
+    uaddlv  d6, v6.4s               // d6: sum(dst_{i,j}^2)
+    uaddlv  d7, v7.4s               // d7: sum(src_{i,j} * dst_{i,j})
+    add     d0, d4, d6
+    sub     d0, d0, d7
+    sub     d0, d0, d7              // d0: sse
+    uqsub   d4, d4, d3              // d4: svar
+    uqsub   d6, d6, d5              // d6: dvar
+.if \shift != 0
+    shl     d4, d4, #\shift         // multiply svar by 2^shift
+    shl     d6, d6, #\shift         // multiply dvar by 2^shift
+.endif
+    str     s4, [x4]                // ret_ptr[0] = svar (low 32 bits)
+    str     s6, [x4, #4]            // ret_ptr[1] = dvar (low 32 bits)
+    str     s0, [x4, #8]            // ret_ptr[2] = sse (low 32 bits)
+.endm
+
+.macro LOAD_ROW_HBD
+    ldr     q1, [x0]                // one row of 8 src pixels
+    ldr     q2, [x2]                // one row of 8 dst pixels
+    add     x0, x0, x1
+    add     x2, x2, x3
+.endm
+
+.macro LOAD_ROWS_HBD
+    ldr     d1, [x0]                // 4 src pixels, first row
+    ldr     d2, [x2]                // 4 dst pixels, first row
+    ldr     d0, [x0, x1]            // 4 src pixels, second row
+    ldr     d17, [x2, x3]           // 4 dst pixels, second row
+    add     x0, x0, x1, lsl 1       // advance src by two rows
+    add     x2, x2, x3, lsl 1       // advance dst by two rows
+    mov     v1.d[1], v0.d[0]        // v1: both src rows packed together
+    mov     v2.d[1], v17.d[0]       // v2: both dst rows packed together
+.endm
+
+// x0: src: *const u16,
+// x1: src_stride: isize,
+// x2: dst: *const u16,
+// x3: dst_stride: isize,
+// x4: ret_ptr: *mut u32,
+function cdef_dist_kernel_4x4_hbd_neon, export=1
+    CDEF_DIST_INIT 4, 4
+L(cdk_hbd_4x4):
+    LOAD_ROWS_HBD
+    CDEF_DIST_HBD_W8
+    subs    w5, w5, #1
+    bne     L(cdk_hbd_4x4)
+    CDEF_DIST_HBD_REFINE 2
+    ret
+endfunc
+
+function cdef_dist_kernel_4x8_hbd_neon, export=1
+    CDEF_DIST_INIT 4, 8
+L(cdk_hbd_4x8):
+    LOAD_ROWS_HBD
+    CDEF_DIST_HBD_W8
+    subs    w5, w5, #1
+    bne     L(cdk_hbd_4x8)
+    CDEF_DIST_HBD_REFINE 1
+    ret
+endfunc
+
+function cdef_dist_kernel_8x4_hbd_neon, export=1
+    CDEF_DIST_INIT 8, 4
+L(cdk_hbd_8x4):
+    LOAD_ROW_HBD
+    CDEF_DIST_HBD_W8
+    subs    w5, w5, #1
+    bne     L(cdk_hbd_8x4)
+    CDEF_DIST_HBD_REFINE 1
+    ret
+endfunc
+
+function cdef_dist_kernel_8x8_hbd_neon, export=1
+    CDEF_DIST_INIT 8, 8
+L(cdk_hbd_8x8):
+    LOAD_ROW_HBD
+    CDEF_DIST_HBD_W8
+    subs    w5, w5, #1
+    bne     L(cdk_hbd_8x8)
+    CDEF_DIST_HBD_REFINE
+    ret
+endfunc
diff --git a/src/asm/aarch64/dist/cdef_dist.rs b/src/asm/aarch64/dist/cdef_dist.rs
index 59a2232874..54e4afcda8 100644
--- a/src/asm/aarch64/dist/cdef_dist.rs
+++ b/src/asm/aarch64/dist/cdef_dist.rs
@@ -22,6 +22,14 @@ type CdefDistKernelFn = unsafe extern fn(
   ret_ptr: *mut u32,
 );
 
+type CdefDistKernelHBDFn = unsafe extern fn(
+  src: *const u16,
+  src_stride: isize,
+  dst: *const u16,
+  dst_stride: isize,
+  ret_ptr: *mut u32,
+);
+
 extern {
   fn rav1e_cdef_dist_kernel_4x4_neon(
     src: *const u8, src_stride: isize, dst: *const u8, dst_stride: isize,
@@ -39,6 +47,22 @@ extern {
     src: *const u8, src_stride: isize, dst: *const u8, dst_stride: isize,
     ret_ptr: *mut u32,
   );
+  fn rav1e_cdef_dist_kernel_4x4_hbd_neon(
+    src: *const u16, src_stride: isize, dst: *const u16, dst_stride: isize,
+    ret_ptr: *mut u32,
+  );
+  fn rav1e_cdef_dist_kernel_4x8_hbd_neon(
+    src: *const u16, src_stride: isize, dst: *const u16, dst_stride: isize,
+    ret_ptr: *mut u32,
+  );
+  fn rav1e_cdef_dist_kernel_8x4_hbd_neon(
+    src: *const u16, src_stride: isize, dst: *const u16, dst_stride: isize,
+    ret_ptr: *mut u32,
+  );
+  fn rav1e_cdef_dist_kernel_8x8_hbd_neon(
+    src: *const u16, src_stride: isize, dst: *const u16, dst_stride: isize,
+    ret_ptr: *mut u32,
+  );
 }
 
 /// # Panics
@@ -86,7 +110,25 @@ pub fn cdef_dist_kernel(
       }
     }
     PixelType::U16 => {
-      return call_rust();
+      if let Some(func) =
+        CDEF_DIST_KERNEL_HBD_FNS[cpu.as_index()][kernel_fn_index(w, h)]
+      {
+        let mut ret_buf = [0u32; 3];
+        // SAFETY: Calls Assembly code.
+        unsafe {
+          func(
+            src.data_ptr() as *const _,
+            T::to_asm_stride(src.plane_cfg.stride),
+            dst.data_ptr() as *const _,
+            T::to_asm_stride(dst.plane_cfg.stride),
+            ret_buf.as_mut_ptr(),
+          )
+        }
+
+        (ret_buf[0], ret_buf[1], ret_buf[2])
+      } else {
+        return call_rust();
+      }
     }
   };
 
@@ -127,3 +169,23 @@ cpu_function_lookup_table!(
   default: [None; CDEF_DIST_KERNEL_FNS_LENGTH],
   [NEON]
 );
+
+static CDEF_DIST_KERNEL_HBD_FNS_NEON: [Option<CdefDistKernelHBDFn>;
+  CDEF_DIST_KERNEL_FNS_LENGTH] = {
+  let mut out: [Option<CdefDistKernelHBDFn>; CDEF_DIST_KERNEL_FNS_LENGTH] =
+    [None; CDEF_DIST_KERNEL_FNS_LENGTH];
+
+  out[kernel_fn_index(4, 4)] = Some(rav1e_cdef_dist_kernel_4x4_hbd_neon);
+  out[kernel_fn_index(4, 8)] = Some(rav1e_cdef_dist_kernel_4x8_hbd_neon);
+  out[kernel_fn_index(8, 4)] = Some(rav1e_cdef_dist_kernel_8x4_hbd_neon);
+  out[kernel_fn_index(8, 8)] = Some(rav1e_cdef_dist_kernel_8x8_hbd_neon);
+
+  out
+};
+
+cpu_function_lookup_table!(
+  CDEF_DIST_KERNEL_HBD_FNS:
+    [[Option<CdefDistKernelHBDFn>; CDEF_DIST_KERNEL_FNS_LENGTH]],
+  default: [None; CDEF_DIST_KERNEL_FNS_LENGTH],
+  [NEON]
+);