From 8c5f2d3cec7160fb0daf228192997685f11a81cb Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Fri, 1 Nov 2024 17:07:57 -0600 Subject: [PATCH] small optimization --- circuits/aes-gcm/nivc/aes-gctr-nivc.circom | 129 +++++++++++++++++---- 1 file changed, 105 insertions(+), 24 deletions(-) diff --git a/circuits/aes-gcm/nivc/aes-gctr-nivc.circom b/circuits/aes-gcm/nivc/aes-gctr-nivc.circom index 6f7852b..6d17213 100644 --- a/circuits/aes-gcm/nivc/aes-gctr-nivc.circom +++ b/circuits/aes-gcm/nivc/aes-gctr-nivc.circom @@ -38,20 +38,6 @@ template AESGCTRFOLD(DATA_BYTES, MAX_STACK_HEIGHT) { last_counter_num.in[i] <== last_counter_bits.out[31 - i]; } signal index <== last_counter_num.out - 1; - - // TODO (Colin): We can probably make a template that writes to two multiple arrays at once that saves us even more constraints here instead of just using the `WriteToIndex` twice - - // write new plain text block. - signal prevAccumulatedPlaintext[DATA_BYTES]; - for(var i = 0 ; i < DATA_BYTES ; i++) { - prevAccumulatedPlaintext[i] <== step_in[i]; - } - signal nextAccumulatedPlaintext[DATA_BYTES]; - component writeToIndex = WriteToIndex(DATA_BYTES, 16); - writeToIndex.array_to_write_to <== prevAccumulatedPlaintext; - writeToIndex.array_to_write_at_index <== plainText; - writeToIndex.index <== index * 16; - nextAccumulatedPlaintext <== writeToIndex.out; // folds one block component aes = AESGCTRFOLDABLE(); @@ -64,23 +50,29 @@ template AESGCTRFOLD(DATA_BYTES, MAX_STACK_HEIGHT) { aes.lastCounter[i] <== step_in[DATA_BYTES * 2 + i]; } - // accumulate cipher text + + // Write out the plaintext and ciphertext to our accumulation arrays, both at once. + signal prevAccumulatedPlaintext[DATA_BYTES]; + for(var i = 0 ; i < DATA_BYTES ; i++) { + prevAccumulatedPlaintext[i] <== step_in[i]; + } signal prevAccumulatedCiphertext[DATA_BYTES]; for(var i = 0 ; i < DATA_BYTES ; i++) { prevAccumulatedCiphertext[i] <== step_in[DATA_BYTES + i]; - } - signal nextAccumulatedCiphertext[DATA_BYTES]; - component writeCipherText = WriteToIndex(DATA_BYTES, 16); - writeCipherText.array_to_write_to <== prevAccumulatedCiphertext; - writeCipherText.array_to_write_at_index <== aes.cipherText; - writeCipherText.index <== index * 16; - nextAccumulatedCiphertext <== writeCipherText.out; + } + component nextTexts = WriteToIndexForTwoArrays(DATA_BYTES, 16); + nextTexts.first_array_to_write_to <== prevAccumulatedPlaintext; + nextTexts.second_array_to_write_to <== prevAccumulatedCiphertext; + nextTexts.first_array_to_write_at_index <== plainText; + nextTexts.second_array_to_write_at_index <== aes.cipherText; + nextTexts.index <== index * 16; + for(var i = 0 ; i < TOTAL_BYTES_ACROSS_NIVC ; i++) { if(i < DATA_BYTES) { - step_out[i] <== nextAccumulatedPlaintext[i]; + step_out[i] <== nextTexts.outFirst[i]; } else if(i < 2 * DATA_BYTES) { - step_out[i] <== nextAccumulatedCiphertext[i - DATA_BYTES]; + step_out[i] <== nextTexts.outSecond[i - DATA_BYTES]; } else if(i < 2 * DATA_BYTES + 4) { step_out[i] <== aes.counter[i - (2 * DATA_BYTES)]; } else { @@ -88,3 +80,92 @@ template AESGCTRFOLD(DATA_BYTES, MAX_STACK_HEIGHT) { } } } + + + +template WriteToIndexForTwoArrays(m, n) { + signal input first_array_to_write_to[m]; + signal input second_array_to_write_to[m]; + signal input first_array_to_write_at_index[n]; + signal input second_array_to_write_at_index[n]; + signal input index; + + signal output outFirst[m]; + signal output outSecond[m]; + + assert(m >= n); + + // Note: this is underconstrained, we need to constrain that index + n <= m + // Need to constrain that index + n <= m -- can't be an assertion, because uses a signal + // ------------------------- // + + // Here, we get an array of ALL zeros, except at the `index` AND `index + n` + // beginning-------^^^^^ end---^^^^^^^^^ + signal indexMatched[m]; + component indexBegining[m]; + component indexEnding[m]; + for(var i = 0 ; i < m ; i++) { + indexBegining[i] = IsZero(); + indexBegining[i].in <== i - index; + indexEnding[i] = IsZero(); + indexEnding[i].in <== i - (index + n); + indexMatched[i] <== indexBegining[i].out + indexEnding[i].out; + } + + // E.g., index == 31, m == 160, n == 16 + // => indexMatch[31] == 1; + // => indexMatch[47] == 1; + // => otherwise, all 0. + + signal accum[m]; + accum[0] <== indexMatched[0]; + + component writeAt = IsZero(); + writeAt.in <== accum[0] - 1; + + component orFirst = OR(); + orFirst.a <== (writeAt.out * first_array_to_write_at_index[0]); + orFirst.b <== (1 - writeAt.out) * first_array_to_write_to[0]; + outFirst[0] <== orFirst.out; + + component orSecond = OR(); + orSecond.a <== (writeAt.out * second_array_to_write_at_index[0]); + orSecond.b <== (1 - writeAt.out) * second_array_to_write_to[0]; + outSecond[0] <== orSecond.out; + // IF accum == 1 then { array_to_write_at } ELSE IF accum != 1 then { array to write_to } + var accum_index = accum[0]; + + component writeSelector[m - 1]; + component indexSelectorFirst[m - 1]; + component indexSelectorSecond[m - 1]; + component orsFirst[m-1]; + component orsSecond[m-1]; + for(var i = 1 ; i < m ; i++) { + // accum will be 1 at all indices where we want to write the new array + accum[i] <== accum[i-1] + indexMatched[i]; + writeSelector[i-1] = IsZero(); + writeSelector[i-1].in <== accum[i] - 1; + // IsZero(accum[i] - 1); --> tells us we are in the range where we want to write the new array + + indexSelectorFirst[i-1] = IndexSelector(n); + indexSelectorFirst[i-1].index <== accum_index; + indexSelectorFirst[i-1].in <== first_array_to_write_at_index; + + indexSelectorSecond[i-1] = IndexSelector(n); + indexSelectorSecond[i-1].index <== accum_index; + indexSelectorSecond[i-1].in <== second_array_to_write_at_index; + // When accum is not zero, out is array_to_write_at_index, otherwise it is array_to_write_to + + orsFirst[i-1] = OR(); + orsFirst[i-1].a <== (writeSelector[i-1].out * indexSelectorFirst[i-1].out); + orsFirst[i-1].b <== (1 - writeSelector[i-1].out) * first_array_to_write_to[i]; + outFirst[i] <== orsFirst[i-1].out; + + orsSecond[i-1] = OR(); + orsSecond[i-1].a <== (writeSelector[i-1].out * indexSelectorSecond[i-1].out); + orsSecond[i-1].b <== (1 - writeSelector[i-1].out) * second_array_to_write_to[i]; + outSecond[i] <== orsSecond[i-1].out; + + accum_index += writeSelector[i-1].out; + } +}