
Neon impl of ChaCha20 (better size & perf) #159

Open · wants to merge 11 commits into base: development
3 changes: 3 additions & 0 deletions ChangeLog.d/chacha20-neon.txt
@@ -0,0 +1,3 @@
* ChaCha20 size and performance: add a Neon implementation of ChaCha20 for
Thumb2 and for 32- and 64-bit Arm, from Armv7 onwards. At default settings,
this improves performance by around 2x to 2.7x on Aarch64.
318 changes: 297 additions & 21 deletions drivers/builtin/src/chacha20.c
@@ -22,13 +22,250 @@

#include "mbedtls/platform.h"

#define CHACHA20_CTR_INDEX (12U)

#define CHACHA20_BLOCK_SIZE_BYTES (4U * 16U)

/*
* The Neon implementation can be configured to process multiple blocks in parallel; increasing the
* number of blocks gains a lot of performance, but adds on average around 250 bytes of code size
* for each additional block.
*
* This is controlled by setting MBEDTLS_CHACHA20_NEON_MULTIBLOCK in the range [0..6] (0 selects
* the scalar implementation; 1 selects single-block Neon; 2..6 select multi-block Neon).
*
* The default (i.e., if MBEDTLS_CHACHA20_NEON_MULTIBLOCK is not set) selects the fastest variant
* that still has better code size than the scalar implementation (based on testing for Aarch64
* on clang and gcc).
*
* Size & performance notes for the Neon implementation, from informal tests on Aarch64
* (these apply to both gcc and clang except where noted):
* - When single-block is selected, this saves around 400-550 bytes of code size compared to
* the scalar implementation
* - Multi-block Neon is smaller and faster than scalar (up to 2 blocks for gcc, 3 for clang)
* - Code size increases consistently with the number of blocks
* - Performance increases with the number of blocks (except at 5, which is slightly slower than 4)
* - Performance is within a few % between gcc and clang at all settings
* - Performance at 4 blocks roughly matches our hardware-accelerated AES-GCM implementation,
* with better code size
* - Performance is worse at 7 or more blocks, due to running out of Neon registers
*/
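/*
 * For example (a hypothetical invocation, not part of this PR): since this is an ordinary
 * preprocessor macro, a variant can be forced from the build command line, e.g.
 *   CFLAGS="-O2 -DMBEDTLS_CHACHA20_NEON_MULTIBLOCK=4"
 * to trade roughly 250 extra bytes (relative to the Aarch64 default of 3 blocks) for
 * performance approaching the hardware-accelerated AES-GCM implementation, per the notes above.
 */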

#if !defined(MBEDTLS_HAVE_NEON_INTRINSICS)
// Select scalar implementation if Neon not available
#define MBEDTLS_CHACHA20_NEON_MULTIBLOCK 0
#elif !defined(MBEDTLS_CHACHA20_NEON_MULTIBLOCK)
// By default, select the best performing option that is not a code-size regression (based on
// measurements from recent gcc and clang).
#if defined(MBEDTLS_ARCH_IS_THUMB)
#if defined(MBEDTLS_COMPILER_IS_GCC)
#define MBEDTLS_CHACHA20_NEON_MULTIBLOCK 1
#else
#define MBEDTLS_CHACHA20_NEON_MULTIBLOCK 2
#endif
#elif defined(MBEDTLS_ARCH_IS_ARM64)
#define MBEDTLS_CHACHA20_NEON_MULTIBLOCK 3
#else
#if defined(MBEDTLS_COMPILER_IS_GCC)
#define MBEDTLS_CHACHA20_NEON_MULTIBLOCK 2
#else
#define MBEDTLS_CHACHA20_NEON_MULTIBLOCK 3
#endif
#endif
#endif

#if MBEDTLS_CHACHA20_NEON_MULTIBLOCK != 0
// Tested on all combinations of Armv7 arm/thumb2; Armv8 arm/thumb2/aarch64; Armv8 aarch64_be on
// clang 14, gcc 11, and some more recent versions.

// Define rotate-left operations that rotate within each 32-bit element in a 128-bit vector.
static inline uint32x4_t chacha20_neon_vrotlq_16_u32(uint32x4_t v)
{
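// vrev32q_u16 swaps the two 16-bit halves of each 32-bit element, which is a rotate by 16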
return vreinterpretq_u32_u16(vrev32q_u16(vreinterpretq_u16_u32(v)));
}

static inline uint32x4_t chacha20_neon_vrotlq_12_u32(uint32x4_t v)
{
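// Shift-left plus shift-right-insert computes (v << 12) | (v >> 20) within each 32-bit element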
uint32x4_t x = vshlq_n_u32(v, 12);
return vsriq_n_u32(x, v, 20);
}

static inline uint32x4_t chacha20_neon_vrotlq_8_u32(uint32x4_t v)
{
uint32x4_t result;
#if defined(MBEDTLS_ARCH_IS_ARM64)
// A table look-up that rotates each 32-bit element left by 8 bits; slightly faster than
// shift + insert, but vqtbl1q_u8 is only available on 64-bit Arm
const uint8_t tbl_rotl8[16] = { 3, 0, 1, 2, 7, 4, 5, 6, 11, 8, 9, 10, 15, 12, 13, 14 };
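// e.g. result byte 0 = input byte 3: each element's bytes [b0,b1,b2,b3] become [b3,b0,b1,b2],
// i.e. (v << 8) | (v >> 24) for little-endian lanes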
const uint8x16_t vrotl8_tbl = vld1q_u8(tbl_rotl8);
result = vreinterpretq_u32_u8(vqtbl1q_u8(vreinterpretq_u8_u32(v), vrotl8_tbl));
#else
uint32x4_t a = vshlq_n_u32(v, 8);
result = vsriq_n_u32(a, v, 24);
#endif
return result;
}

static inline uint32x4_t chacha20_neon_vrotlq_7_u32(uint32x4_t v)
{
uint32x4_t x = vshlq_n_u32(v, 7);
return vsriq_n_u32(x, v, 25);
}

// Increment the 32-bit element within v that corresponds to the ChaCha20 counter
static inline uint32x4_t chacha20_neon_inc_counter(uint32x4_t v)
{
if (MBEDTLS_IS_BIG_ENDIAN) {
v[3]++;
} else {
v[0]++;
}
return v;
}

typedef struct {
uint32x4_t a, b, c, d;
} chacha20_neon_regs_t;
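// a holds state words 0-3 (the constants), b holds words 4-7 and c words 8-11 (the key),
// and d holds words 12-15 (the 32-bit counter and the 96-bit nonce).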

static inline chacha20_neon_regs_t chacha20_neon_singlepass(chacha20_neon_regs_t r)
{
for (unsigned i = 0; i < 2; i++) {
r.a = vaddq_u32(r.a, r.b); // r.a += b
r.d = veorq_u32(r.d, r.a); // r.d ^= a
r.d = chacha20_neon_vrotlq_16_u32(r.d); // r.d <<<= 16

r.c = vaddq_u32(r.c, r.d); // r.c += d
r.b = veorq_u32(r.b, r.c); // r.b ^= c
r.b = chacha20_neon_vrotlq_12_u32(r.b); // r.b <<<= 12

r.a = vaddq_u32(r.a, r.b); // r.a += b
r.d = veorq_u32(r.d, r.a); // r.d ^= a
r.d = chacha20_neon_vrotlq_8_u32(r.d); // r.d <<<= 8

r.c = vaddq_u32(r.c, r.d); // r.c += d
r.b = veorq_u32(r.b, r.c); // r.b ^= c
r.b = chacha20_neon_vrotlq_7_u32(r.b); // r.b <<<= 7

if (i == 0) {
// re-order b, c and d for the diagonal rounds
r.b = vextq_u32(r.b, r.b, 1); // r.b now holds positions 5, 6, 7, 4
r.c = vextq_u32(r.c, r.c, 2); // r.c now holds positions 10, 11, 8, 9
r.d = vextq_u32(r.d, r.d, 3); // r.d now holds positions 15, 12, 13, 14
} else {
// restore element order in b, c, d
r.b = vextq_u32(r.b, r.b, 3);
r.c = vextq_u32(r.c, r.c, 2);
r.d = vextq_u32(r.d, r.d, 1);
}
}

return r;
}
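// For reference, on each 32-bit lane the loop above computes the standard ChaCha20 quarter
// round (RFC 8439, section 2.1):
//   a += b; d ^= a; d <<<= 16;
//   c += d; b ^= c; b <<<= 12;
//   a += b; d ^= a; d <<<= 8;
//   c += d; b ^= c; b <<<= 7;
// applied first to the columns, then (after the vext rotations) to the diagonals.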

static inline void chacha20_neon_finish_block(chacha20_neon_regs_t r,
chacha20_neon_regs_t r_original,
uint8_t **output,
const uint8_t **input)
{
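// Feed-forward (RFC 8439): add the block's initial state back into the permuted state,
// which makes the block function non-invertible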
r.a = vaddq_u32(r.a, r_original.a);
r.b = vaddq_u32(r.b, r_original.b);
r.c = vaddq_u32(r.c, r_original.c);
r.d = vaddq_u32(r.d, r_original.d);

vst1q_u8(*output + 0, veorq_u8(vld1q_u8(*input + 0), vreinterpretq_u8_u32(r.a)));
vst1q_u8(*output + 16, veorq_u8(vld1q_u8(*input + 16), vreinterpretq_u8_u32(r.b)));
vst1q_u8(*output + 32, veorq_u8(vld1q_u8(*input + 32), vreinterpretq_u8_u32(r.c)));
vst1q_u8(*output + 48, veorq_u8(vld1q_u8(*input + 48), vreinterpretq_u8_u32(r.d)));

*input += CHACHA20_BLOCK_SIZE_BYTES;
*output += CHACHA20_BLOCK_SIZE_BYTES;
}

// Prevent gcc from rolling up the (manually unrolled) interleaved block loops
MBEDTLS_OPTIMIZE_FOR_PERFORMANCE
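// Processes `blocks` (>= 1) 64-byte blocks of input into output; returns the d row with the
// counter advanced past the blocks consumed, so the caller can persist it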
static inline uint32x4_t chacha20_neon_blocks(chacha20_neon_regs_t r_original,
uint8_t *output,
const uint8_t *input,
size_t blocks)
{
// Assuming 32 Neon registers, with 4 holding the original values and 4 needed as scratch,
// at 4 registers per block we should be able to process up to 24/4 = 6 blocks simultaneously.
// Testing confirms that performance indeed increases with more blocks, then falls off after 6.

for (;;) {
chacha20_neon_regs_t r[6];

// It's essential to unroll these loops to benefit from interleaving multiple blocks.
// If MBEDTLS_CHACHA20_NEON_MULTIBLOCK < 6, gcc and clang will optimise away the unused copies
r[0] = r_original;
r[1] = r_original;
r[2] = r_original;
r[3] = r_original;
r[4] = r_original;
r[5] = r_original;
r[1].d = chacha20_neon_inc_counter(r[0].d);
r[2].d = chacha20_neon_inc_counter(r[1].d);
r[3].d = chacha20_neon_inc_counter(r[2].d);
r[4].d = chacha20_neon_inc_counter(r[3].d);
r[5].d = chacha20_neon_inc_counter(r[4].d);
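// After this staggering, r[k].d holds the block counter for block k (base counter + k)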

for (unsigned i = 0; i < 10; i++) {
r[0] = chacha20_neon_singlepass(r[0]);
r[1] = chacha20_neon_singlepass(r[1]);
r[2] = chacha20_neon_singlepass(r[2]);
r[3] = chacha20_neon_singlepass(r[3]);
r[4] = chacha20_neon_singlepass(r[4]);
r[5] = chacha20_neon_singlepass(r[5]);
}

chacha20_neon_finish_block(r[0], r_original, &output, &input);
r_original.d = chacha20_neon_inc_counter(r_original.d);
if (--blocks == 0) {
return r_original.d;
}
#if MBEDTLS_CHACHA20_NEON_MULTIBLOCK >= 2
chacha20_neon_finish_block(r[1], r_original, &output, &input);
r_original.d = chacha20_neon_inc_counter(r_original.d);
if (--blocks == 0) {
return r_original.d;
}
#endif
#if MBEDTLS_CHACHA20_NEON_MULTIBLOCK >= 3
chacha20_neon_finish_block(r[2], r_original, &output, &input);
r_original.d = chacha20_neon_inc_counter(r_original.d);
if (--blocks == 0) {
return r_original.d;
}
#endif
#if MBEDTLS_CHACHA20_NEON_MULTIBLOCK >= 4
chacha20_neon_finish_block(r[3], r_original, &output, &input);
r_original.d = chacha20_neon_inc_counter(r_original.d);
if (--blocks == 0) {
return r_original.d;
}
#endif
#if MBEDTLS_CHACHA20_NEON_MULTIBLOCK >= 5
chacha20_neon_finish_block(r[4], r_original, &output, &input);
r_original.d = chacha20_neon_inc_counter(r_original.d);
if (--blocks == 0) {
return r_original.d;
}
#endif
#if MBEDTLS_CHACHA20_NEON_MULTIBLOCK >= 6
chacha20_neon_finish_block(r[5], r_original, &output, &input);
r_original.d = chacha20_neon_inc_counter(r_original.d);
if (--blocks == 0) {
return r_original.d;
}
#endif
}
}

#else

#define ROTL32(value, amount) \
((uint32_t) ((value) << (amount)) | ((value) >> (32 - (amount))))
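// e.g. ROTL32(0x80000001u, 1) == 0x00000003u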

/**
* \brief ChaCha20 quarter round operation.
*
@@ -138,10 +375,11 @@ static void chacha20_block(const uint32_t initial_state[16],
mbedtls_platform_zeroize(working_state, sizeof(working_state));
}

#endif

void mbedtls_chacha20_init(mbedtls_chacha20_context *ctx)
{
mbedtls_platform_zeroize(ctx, sizeof(mbedtls_chacha20_context));

/* Initially, there are no keystream bytes available */
ctx->keystream_bytes_used = CHACHA20_BLOCK_SIZE_BYTES;
@@ -158,20 +396,22 @@ int mbedtls_chacha20_setkey(mbedtls_chacha20_context *ctx,
const unsigned char key[32])
{
/* ChaCha20 constants - the string "expand 32-byte k" */
static const char EXPAND_32_BYTE_K[16] = "expand 32-byte k";
if (MBEDTLS_IS_BIG_ENDIAN) {
ctx->state[0] = 0x61707865;
ctx->state[1] = 0x3320646e;
ctx->state[2] = 0x79622d32;
ctx->state[3] = 0x6b206574;
} else {
memcpy(ctx->state, EXPAND_32_BYTE_K, 16);
}
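/* Note: in little-endian byte order the string's 16 bytes are exactly the four words
 * 0x61707865, 0x3320646e, 0x79622d32, 0x6b206574, which is why memcpy suffices there. */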

/* Set key */
if (MBEDTLS_IS_BIG_ENDIAN) {
ctx->state[4] = MBEDTLS_GET_UINT32_LE(key, 0);
ctx->state[5] = MBEDTLS_GET_UINT32_LE(key, 4);
ctx->state[6] = MBEDTLS_GET_UINT32_LE(key, 8);
ctx->state[7] = MBEDTLS_GET_UINT32_LE(key, 12);
ctx->state[8] = MBEDTLS_GET_UINT32_LE(key, 16);
ctx->state[9] = MBEDTLS_GET_UINT32_LE(key, 20);
ctx->state[10] = MBEDTLS_GET_UINT32_LE(key, 24);
ctx->state[11] = MBEDTLS_GET_UINT32_LE(key, 28);
} else {
memcpy(&ctx->state[4], key, 32);
}

return 0;
}
@@ -184,9 +424,13 @@ int mbedtls_chacha20_starts(mbedtls_chacha20_context *ctx,
ctx->state[12] = counter;

/* Nonce */
if (MBEDTLS_IS_BIG_ENDIAN) {
ctx->state[13] = MBEDTLS_GET_UINT32_LE(nonce, 0);
ctx->state[14] = MBEDTLS_GET_UINT32_LE(nonce, 4);
ctx->state[15] = MBEDTLS_GET_UINT32_LE(nonce, 8);
} else {
memcpy(&ctx->state[13], nonce, 12);
}

mbedtls_platform_zeroize(ctx->keystream8, sizeof(ctx->keystream8));

@@ -204,7 +448,7 @@ int mbedtls_chacha20_update(mbedtls_chacha20_context *ctx,
size_t offset = 0U;

/* Use leftover keystream bytes, if available */
while (ctx->keystream_bytes_used < CHACHA20_BLOCK_SIZE_BYTES && size) {
output[offset] = input[offset]
^ ctx->keystream8[ctx->keystream_bytes_used];

@@ -213,6 +457,37 @@ int mbedtls_chacha20_update(mbedtls_chacha20_context *ctx,
size--;
}

#if MBEDTLS_CHACHA20_NEON_MULTIBLOCK != 0
/* Load state into Neon registers */
chacha20_neon_regs_t state;
state.a = vld1q_u32(&ctx->state[0]);
state.b = vld1q_u32(&ctx->state[4]);
state.c = vld1q_u32(&ctx->state[8]);
state.d = vld1q_u32(&ctx->state[12]);

/* Process full blocks */
if (size >= CHACHA20_BLOCK_SIZE_BYTES) {
size_t blocks = size / CHACHA20_BLOCK_SIZE_BYTES;
state.d = chacha20_neon_blocks(state, output + offset, input + offset, blocks);

offset += CHACHA20_BLOCK_SIZE_BYTES * blocks;
size -= CHACHA20_BLOCK_SIZE_BYTES * blocks;
}

/* Last (partial) block */
if (size > 0U) {
/* Generate new keystream block and increment counter */
memset(ctx->keystream8, 0, CHACHA20_BLOCK_SIZE_BYTES);
state.d = chacha20_neon_blocks(state, ctx->keystream8, ctx->keystream8, 1);
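/* The buffer was zeroed, so encrypting it in place leaves the raw keystream in keystream8 */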

mbedtls_xor_no_simd(output + offset, input + offset, ctx->keystream8, size);

ctx->keystream_bytes_used = size;
}

/* Capture the updated counter; state words 0-11 (constants and key) are unchanged */
vst1q_u32(&ctx->state[12], state.d);
#else
/* Process full blocks */
while (size >= CHACHA20_BLOCK_SIZE_BYTES) {
/* Generate new keystream block and increment counter */
@@ -236,6 +511,7 @@ int mbedtls_chacha20_update(mbedtls_chacha20_context *ctx,
ctx->keystream_bytes_used = size;

}
#endif

return 0;
}
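/*
 * For context, a minimal usage sketch of the (unchanged) public API that exercises this code
 * path; `key`, `nonce`, `input`, `output` and `len` are hypothetical caller-provided values:
 *
 *     mbedtls_chacha20_context ctx;
 *     mbedtls_chacha20_init(&ctx);
 *     mbedtls_chacha20_setkey(&ctx, key);        // 32-byte key
 *     mbedtls_chacha20_starts(&ctx, nonce, 1U);  // 12-byte nonce, initial counter 1
 *     mbedtls_chacha20_update(&ctx, len, input, output);
 *     mbedtls_chacha20_free(&ctx);
 */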
5 changes: 5 additions & 0 deletions include/tf-psa-crypto/build_info.h
@@ -27,6 +27,11 @@
#define MBEDTLS_ARCH_IS_ARM32
#endif

#if !defined(MBEDTLS_ARCH_IS_THUMB) && \
(defined(_M_ARMT) || defined(__thumb__) || defined(__thumb2__))
#define MBEDTLS_ARCH_IS_THUMB
#endif

#if !defined(MBEDTLS_ARCH_IS_X64) && \
(defined(__amd64__) || defined(__x86_64__) || \
((defined(_M_X64) || defined(_M_AMD64)) && !defined(_M_ARM64EC)))