-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbackrun.mpc
430 lines (363 loc) · 20.4 KB
/
backrun.mpc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
### [PoC] Private Backrunning of Private Transactions Using MPC ###
from Compiler.mpc_math import sqrt
from Compiler.oram import OptimalORAM
# RLP coding rules:
# [0x00 - 0x7f]: as is
# [0x80 - 0xb7]: short string. length of string in bytes is encoded as VALUE-0x80
# [0xb8 - 0xbf]: long string. VALUE-0xb7 encodes the number of subsequent bytes in which the length of the long string is encoded.
# [0xc0 - 0xf7]: short list. length of list in bytes is encoded as VALUE-0xc0
# [0xf8 - 0xff]: long list. VALUE-0xf7 encodes the number of subsequent bytes in which the length of the long list is encoded.
# returns the number of bytes in a (short or long) list (+the number of bytes used to encode the list)
# modifies the 'success' variable in case value at the current index does not refer to a list
# increments 'user_tx_index' by the number of bytes read from the user tx
def decode_list():
    """Decode the RLP list header at the current position of the user tx.

    Reads ``user_tx[user_tx_index]``, clears ``success`` if that byte is not a
    list prefix (>= 0xc0) and advances ``user_tx_index`` past the header
    (the prefix byte plus, for a long list, its length bytes).

    Returns the secret total size of the list: payload bytes plus header bytes.
    """
    prefix = user_tx[user_tx_index]
    user_tx_index.update(user_tx_index + 1)
    # bit_length=9 keeps the difference of two byte values inside the signed range
    success.update(success & prefix.greater_equal(0xc0, bit_length=9))
    # number of length bytes of a long list; zero or negative for a short list
    len_of_len = prefix - 0xf7
    # long list. use fixed-length loop to not leak data. maximum length is 8 bytes (0xff - 0xf7)
    long_payload = sint(0)
    @for_range(0xff - 0xf7)
    def _(i):
        value = user_tx[user_tx_index + i].left_shift(8 * (len_of_len - i - 1), bit_length=64)
        update = ((len_of_len - i).greater_than(0, bit_length=9)).if_else(value, 0)
        long_payload.update(long_payload + update)
    is_long = len_of_len.greater_than(0, bit_length=9)
    # short list: the prefix itself carries the payload length as prefix - 0xc0
    payload = is_long.if_else(long_payload, prefix - 0xc0)
    # only a long list has extra header bytes to skip; the prefix byte was
    # already consumed above, so a short list must not advance the index again
    extra_header = is_long.if_else(len_of_len, 0)
    user_tx_index.update(user_tx_index + extra_header)
    return payload + extra_header + 1
# TODO: generalize decode_as_is(), decode_short_string(), and decode_long_string() into a single, unified decoding function
# returns rlp-decoded byte from user tx
# modifies the 'success' variable in case value at the current index does not refer to an as-is encoded byte (0x00 - 0x7f)
# increments 'user_tx_index' by one
def decode_as_is():
    """Read one byte of the user tx that must be RLP-encoded as itself.

    Clears ``success`` unless the byte lies in [0x00, 0x7f], advances
    ``user_tx_index`` by one and returns the byte.
    """
    byte = user_tx[user_tx_index]
    user_tx_index.update(user_tx_index + 1)
    not_negative = byte.greater_equal(0, bit_length=8)
    success.update(success & not_negative)
    single_byte_range = byte.less_equal(0x7f, bit_length=8)
    success.update(success & single_byte_range)
    return byte
# returns rlp-decoded short string from user tx
# modifies the 'success' variable in case value at the current index does not refer to a short string
# increments 'user_tx_index' by the number of bytes read from the user tx
def decode_short_string():
    """Decode an RLP short string of at most 32 bytes from the user tx.

    Clears ``success`` unless the prefix byte lies in [0x80, 0xa0], advances
    ``user_tx_index`` past the prefix and the string bytes, and returns the
    string interpreted as a big-endian integer.
    """
    prefix = user_tx[user_tx_index]
    user_tx_index.update(user_tx_index + 1)
    success.update(success & prefix.greater_equal(0x80, bit_length=8))
    # TODO: short string may not fit into the 256 bit that we use as integer size. short string could be up to 440 bit (i.e., 0xb7 not only 0xa0).
    success.update(success & prefix.less_equal(0xa0, bit_length=8))
    # clamp the length to 32 so the fixed-length loop below covers every byte
    n_bytes = ((prefix - 0x80).greater_than(32, bit_length=8)).if_else(32, prefix - 0x80)
    string = sint(0)
    # use fixed-length loop to not leak data.
    @for_range(0xa0 - 0x80)
    def _(i):
        # the shift reaches 8 * 31 = 248 bits, so the power of two needs a
        # 256-bit bound (a 64-bit bound only covers strings of up to 8 bytes)
        value = user_tx[user_tx_index + i].left_shift(8 * (n_bytes - i - 1), bit_length=256)
        update = ((n_bytes - i).greater_than(0, bit_length=6)).if_else(value, 0)
        string.update(string + update)
    # n_bytes can be as large as 32, which needs 6 bits (4 bits only cover up to 8)
    user_tx_index.update((n_bytes.greater_than(0, bit_length=6)).if_else(user_tx_index + n_bytes, user_tx_index))
    return string
# returns the number of bytes in a long string from user tx
# modifies the 'success' variable in case value at the current index does not refer to a long string
# increments 'user_tx_index' by the number of bytes read from the user tx
def decode_long_string():
    """Decode the header of an RLP long string from the user tx.

    Clears ``success`` unless the prefix byte lies in [0xb8, 0xbf], advances
    ``user_tx_index`` past the prefix and the length bytes, and returns the
    length of the string payload in bytes (the payload itself is not read).
    """
    prefix = user_tx[user_tx_index]
    user_tx_index.update(user_tx_index + 1)
    success.update(success & prefix.greater_equal(0xb8, bit_length=8))
    success.update(success & prefix.less_equal(0xbf, bit_length=8))
    # number of bytes that encode the payload length (1..8)
    len_of_len = prefix - 0xb7
    payload_len = sint(0)
    # fixed-length loop to not leak data; a prefix of 0xbf announces
    # 0xbf - 0xb7 = 8 length bytes, so all 8 positions must be visited
    @for_range(0xbf - 0xb7)
    def _(i):
        value = user_tx[user_tx_index + i].left_shift(8 * (len_of_len - i - 1), bit_length=64)
        update = ((len_of_len - i).greater_than(0, bit_length=6)).if_else(value, 0)
        payload_len.update(payload_len + update)
    user_tx_index.update((len_of_len.greater_than(0, bit_length=4)).if_else(user_tx_index + len_of_len, user_tx_index))
    return payload_len
# populate data items into storage
# modifies the 'success' variable in case the length of data is greater than the maximum length we support
# increments 'user_tx_index' by the number of bytes in data
def populate_data(len, maxlen):
    """Copy the data section of the user tx into ``storage``.

    ``len`` is the secret byte length of the data, ``maxlen`` the public upper
    bound. Clears ``success`` if ``len`` exceeds ``maxlen``. The first four
    bytes go to ``storage[5]`` as the function selector; the following bytes
    are packed 32 at a time into ``storage[6]``, ``storage[7]``, ...
    Advances ``user_tx_index`` by the (clamped) data length.
    """
    fits = len.less_equal(maxlen, bit_length=10)
    success.update(success & fits)
    # never read more than maxlen bytes, even if the announced length is larger
    n_bytes = fits.if_else(len, maxlen)
    # function selector: first 4 bytes, big-endian
    selector = sint(0)
    @for_range(4)
    def _(pos):
        selector_byte = user_tx[user_tx_index + pos]
        selector.update(selector + (selector_byte.left_shift(8 * (4 - pos - 1), bit_length=32)))
    storage[5] = selector
    # remaining data, one 256-bit word per iteration, starting after the selector
    @while_do(lambda x: x < maxlen, regint(4))
    def _(offset):
        word = sint(0)
        @for_range(32)
        def _(k):
            # bytes beyond the actual data length contribute zero
            in_data = (n_bytes - offset - k).greater_than(0, bit_length=10)
            shifted = user_tx[user_tx_index + offset + k].left_shift(8 * (32 - k - 1), bit_length=256)
            word.update(word + in_data.if_else(shifted, 0))
        storage[6 + (offset - 4) / 32] = word
        return offset + 32
    user_tx_index.update(user_tx_index + n_bytes)
# returns rlp-decoded v, r, and s parts of the ECDSA signature from user tx
# modifies the 'success' variable in case any of the three indices is incorrect
# increments 'user_tx_index' by 67
def decode_signature():
    """Decode the v, r and s entries of the signature from the user tx.

    ``v`` must be a single as-is byte (0x00 - 0x7f); ``r`` and ``s`` must each
    be a 32-byte string introduced by the prefix 0xa0. Clears ``success`` on
    any mismatch, advances ``user_tx_index`` by 67 and returns ``(v, r, s)``.
    """
    def read_scalar():
        # one 0xa0 prefix byte followed by 32 big-endian payload bytes
        success.update(success & user_tx[user_tx_index].equal(0xa0, bit_length=8))
        scalar = sint(0)
        @for_range(32)
        def _(pos):
            payload_byte = user_tx[user_tx_index + pos + 1]
            scalar.update(scalar + (payload_byte.left_shift(8 * (32 - pos - 1), bit_length=256)))
        user_tx_index.update(user_tx_index + 33)
        return scalar

    v = user_tx[user_tx_index]
    user_tx_index.update(user_tx_index + 1)
    success.update(success & v.greater_equal(0, bit_length=8))
    success.update(success & v.less_equal(0x7f, bit_length=8))
    r = read_scalar()
    s = read_scalar()
    return (v, r, s)
# returns rlp-decoded address from user tx
# modifies the 'success' variable in case value at the current index does not refer to an address
# increments 'user_tx_index' by 21
def decode_address():
    """Decode a 20-byte address from the user tx.

    Clears ``success`` unless the byte at ``user_tx_index`` is the prefix 0x94,
    advances ``user_tx_index`` by 21 and returns the address as a big-endian
    integer.
    """
    has_address_prefix = user_tx[user_tx_index].equal(0x94, bit_length=8)
    success.update(success & has_address_prefix)
    address = sint(0)
    @for_range(20)
    def _(pos):
        address_byte = user_tx[user_tx_index + pos + 1]
        address.update(address + (address_byte.left_shift(8 * (20 - pos - 1), bit_length=160)))
    user_tx_index.update(user_tx_index + 21)
    return address
# TODO: generalize encode_as_is() and encode_short_string() into a single, unified encoding function
# updates backrunning transaction with rlp-encoded value as-is
# modifies the 'success' variable if the string is too large to refer to an as-is encoded value
# increments 'backrun_tx_index' by 1
def encode_as_is(str):
    """Append ``str`` to the backrunning tx as a single as-is byte.

    Clears ``success`` unless ``str`` lies in [0x00, 0x7f], writes it at
    ``backrun_tx_index`` and advances that index by one.
    """
    not_negative = str.greater_equal(0, bit_length=256)
    success.update(not_negative.if_else(success, sintbit(0)))
    single_byte_range = str.less_equal(0x7f, bit_length=256)
    success.update(single_byte_range.if_else(success, sintbit(0)))
    backrun_tx_nolen[backrun_tx_index] = str
    backrun_tx_index.update(backrun_tx_index + 1)
# updates backrunning transaction with rlp-encoded short string
# modifies the 'success' variable if the string is too small to refer to a short string
# increments 'backrun_tx_index' by length of the string + 1
def encode_short_string(str):
    """Append ``str`` to the backrunning tx as an RLP short string.

    Clears ``success`` if ``str`` is below 0x80, writes the prefix
    ``0x80 + length`` at ``backrun_tx_index`` and then the bytes of ``str``
    via ``updateTx``.
    """
    needs_prefix = str.greater_equal(0x80, bit_length=256)
    success.update(needs_prefix.if_else(success, sintbit(0)))
    # TODO: should ensure here that 'str' is smaller than 440 bits, but we support only 256 bit for now anyways
    n_bytes = strlen(str)
    backrun_tx_nolen[backrun_tx_index] = 0x80 + n_bytes
    backrun_tx_index.update(backrun_tx_index + 1)
    updateTx(str, n_bytes)
# retrieve next value from protocol-internal storage, either a searcher-provided constant, from the user's tx, or from a previous computation's result
# modifies the 'success' variable if an out-of-bounds access in the protocol-internal storage would happen. TODO: is this check correctly positioned here or should it go earlier in the code, when reading the input from the searcher?
def getValue():
    """Fetch the storage entry named by the next word of the searcher program.

    Consumes one entry of ``searcher_program`` as a storage index. Clears
    ``success`` if that index is >= ``STORAGE_MAX_ENTRIES`` and in that case
    reads the last storage entry instead.
    """
    slot = searcher_program[searcher_program_index]
    searcher_program_index.update(searcher_program_index + 1)
    # check for invalid access
    out_of_bounds = slot.greater_equal(STORAGE_MAX_ENTRIES, bit_length=10)
    success.update(success & sintbit(out_of_bounds.if_else(0, 1)))
    # clamp so that the ORAM access below stays inside the storage
    safe_slot = out_of_bounds.if_else(sint(STORAGE_MAX_ENTRIES - 1), slot)
    return storage[safe_slot]
# get number of bytes that a short string needs to be encoded in. supports strings represented by an integer up to 256 bit. TODO: short string can be up to 440 bit.
def strlen(str):
    """Return the number of bytes needed to represent ``str`` (up to 256 bit).

    Shifts the value right one byte per iteration and records the first
    iteration at which nothing is left. A value of 0 yields 1, because a
    recorded length of 0 is indistinguishable from "not recorded yet".
    Clears ``success`` if the value is still non-zero after 33 shifts.
    """
    remaining = sint(str)
    n_bytes = sint(0)
    @for_range(33)
    def _(i):
        exhausted = remaining.equal(0, bit_length=256)
        # the recorded length can reach 32, which needs 6 bits; with only 4
        # bits a length of 16 is taken for "unset" and overwritten with 17
        unset = n_bytes.equal(0, bit_length=6)
        n_bytes.update(exhausted.bit_and(unset).if_else(i, n_bytes))
        remaining.update(remaining.right_shift(8, bit_length=256))
    success.update((remaining.equal(0, bit_length=256)).if_else(success, sintbit(0)))
    return n_bytes
# update backrunning transaction with data of specified length
# increments 'backrun_tx_index' by the specified length
def updateTx(str, len):
    """Append the ``len`` big-endian bytes of ``str`` to the backrunning tx.

    Always performs 32 writes starting at ``backrun_tx_index`` so the length
    is not leaked; positions at or beyond ``len`` receive -1. Advances
    ``backrun_tx_index`` by ``len`` only.
    """
    rest = sint(str)
    @for_range(32)
    def _(i):
        shift = 8 * (len - i - 1)
        # peel off the most significant remaining byte
        byte = rest.right_shift(shift, bit_length=256)
        rest.update(rest - byte.left_shift(shift, bit_length=256))
        # i - len spans [-32, 31], which needs 6 bits (5 bits only cover [-16, 15])
        backrun_tx_nolen[backrun_tx_index + i] = (sint(i).less_than(len, bit_length=6)).if_else(byte, -1)
    backrun_tx_index.update(backrun_tx_index + len)
### end of functions ###
### here starts the main() spaghetti code.. ###
# track whether any error occurred
success = sintbit(1)
sfix.set_precision(258, 515) # for sqrt()

# read user tx as fixed-size user input (from player 0)
USER_TX_MAX_LEN = 500
user_tx = OptimalORAM(USER_TX_MAX_LEN, sint)
@for_range(USER_TX_MAX_LEN)
def _(i):
    # keep the input in a local so the range check does not have to read it
    # back from the ORAM twice
    byte = sint.get_input_from(0)
    user_tx[i] = byte
    # every entry must be a byte; values up to 0xff need 9 bits for a signed
    # comparison (8 bits only cover -128..127)
    success.update(success & byte.greater_equal(0, bit_length=9))
    success.update(success & byte.less_equal(0xff, bit_length=9))
user_tx_index = sint(0)

CONSTANTS_MAX = 30 # maximum number of constants in a searcher program
COMPUTATIONS_MAX = 50 # maximum number of computations in a searcher program
COMPARISONS_MAX = 10 # maximum number of comparisons in a searcher program
BACKRUNNING_MAX = 15 # maximum number of backrunning tx' items in a searcher program

# read searcher program as fixed-size searcher input (from player 1):
# 4 words per constant, 3 words per computation, 3 words per comparison,
# 1 word per backrunning tx item
SEARCHER_PROGRAM_MAX_LEN = CONSTANTS_MAX * 4 + COMPUTATIONS_MAX * 3 + COMPARISONS_MAX * 3 + BACKRUNNING_MAX
searcher_program = OptimalORAM(SEARCHER_PROGRAM_MAX_LEN, sint)
@for_range(SEARCHER_PROGRAM_MAX_LEN)
def _(i):
    searcher_program[i] = sint.get_input_from(1)
searcher_program_index = sint(0)

USER_TX_ENTRIES_BEFORE_DATA = 5 # nonce, gas price, gas limit, to, value
USER_TX_ENTRIES_AFTER_DATA = 3 # v, r, s
# maximum entries we support in a user's transaction data section.
DATA_MAX_ENTRIES = 15
COMPUTATIONS_START = USER_TX_ENTRIES_BEFORE_DATA + DATA_MAX_ENTRIES + USER_TX_ENTRIES_AFTER_DATA
# storage contains the user's tx, searcher-provided constants, and the results of the searcher's computation.
STORAGE_MAX_ENTRIES = COMPUTATIONS_START + CONSTANTS_MAX + COMPUTATIONS_MAX
# TODO: should we use fixed precision floating point numbers (sfix) here?
storage = OptimalORAM(STORAGE_MAX_ENTRIES, sint)
# an Ethereum tx is an RLP-encoded list
total_bytes = decode_list()
storage[0] = decode_as_is() # nonce
storage[1] = decode_short_string() # gas price
storage[2] = decode_short_string() # gas limit
storage[3] = decode_address() # to address
storage[4] = decode_short_string() # value
# data: the function selector occupies one of the DATA_MAX_ENTRIES storage
# slots, so only DATA_MAX_ENTRIES - 1 full 32-byte words fit before the
# signature slots begin. Allowing DATA_MAX_ENTRIES words would place the last
# word in the slot that is overwritten by v below.
populate_data(decode_long_string(), (DATA_MAX_ENTRIES - 1) * 32 + 4)
# signature
SIGNATURE_INDEX = USER_TX_ENTRIES_BEFORE_DATA + DATA_MAX_ENTRIES
(storage[SIGNATURE_INDEX], storage[SIGNATURE_INDEX + 1], storage[SIGNATURE_INDEX + 2]) = decode_signature()
# ensure that we have indeed read exactly the number of bytes that we were supposed to read according to the RLP-encoded list
success.update(success & user_tx_index.equal(total_bytes, bit_length=16))
print_ln('DEBUG success: %s', success.reveal())
print_ln('DEBUG storage: %s', storage)
# Load the searcher's constants into storage. They are used later on for
# computations, comparisons, and the backrunning tx. Each constant arrives as
# four consecutive program words that are combined, most significant first,
# with 64 bits per word.
@for_range(CONSTANTS_MAX)
def _(slot):
    constant = sint(0)
    @for_range(4)
    def _(limb):
        word = searcher_program[searcher_program_index + slot * 4 + limb]
        shifted = word.left_shift(64 * (4 - limb - 1), bit_length=256)
        constant.update(constant + shifted)
    print_ln('DEBUG constant #%s: %s', slot, constant.reveal())
    storage[COMPUTATIONS_START + slot] = constant
# skip the constants section of the searcher program
searcher_program_index += CONSTANTS_MAX * 4
print_ln('DEBUG storage: %s', storage)
# simple interpreted language that allows the searcher to conduct computations on protocol-internal storage (which consists of the user's tx, the constants provided earlier, and results from previous computations).
# each computation is encoded as (operand1 index, operator, operand2 index); operators: 0 add, 1 subtract, 2 multiply, 3 integer division, 4 square root of operand1
@for_range(COMPUTATIONS_MAX)
def _(i):
    print_ln('DEBUG computation #%s', i)
    operand1 = getValue()
    print_ln('DEBUG operand1: %s', operand1.reveal())
    operator = searcher_program[searcher_program_index]
    searcher_program_index.update(searcher_program_index + 1)
    # the operator value 4 does not fit a signed 3-bit comparison (-4..3), so use 4 bits
    success.update(success & operator.greater_equal(0, bit_length=4))
    success.update(success & operator.less_equal(4, bit_length=4))
    print_ln('DEBUG operator: %s', operator.reveal())
    operand2 = getValue()
    print_ln('DEBUG operand2: %s', operand2.reveal())
    target = COMPUTATIONS_START + CONSTANTS_MAX + i
    # select the result obliviously in a local value and write it back once,
    # instead of a storage read and write per candidate operator
    result = storage[target]
    result = (operator.equal(0, bit_length=3)).if_else(operand1 + operand2, result)
    result = (operator.equal(1, bit_length=3)).if_else(operand1 - operand2, result)
    result = (operator.equal(2, bit_length=3)).if_else(operand1 * operand2, result)
    result = (operator.equal(3, bit_length=3)).if_else(operand1.int_div(operand2, bit_length=256), result)
    result = (operator.equal(4, bit_length=3)).if_else(sint(sqrt(sfix(operand1))), result)
    storage[target] = result
    print_ln('DEBUG result: %s', result.reveal())
print_ln('DEBUG storage: %s', storage)
print_ln('DEBUG success: %s', success.reveal())
# simple interpreted language that allows the searcher to conduct comparisons/matching operations on protocol-internal storage (which consists of the user's tx, the constants provided earlier, and results from previous computations).
# each comparison is encoded as (comparand1 index, comparator, comparand2 index); comparators: 5 <, 6 <=, 7 >, 8 >=, 9 ==, 10 !=. every comparison has to hold for 'success' to stay set.
@for_range(COMPARISONS_MAX)
def _(i):
    print_ln('DEBUG comparison #%s', i)
    comparand1 = getValue()
    print_ln('DEBUG comparand1: %s', comparand1.reveal())
    comparator = searcher_program[searcher_program_index]
    searcher_program_index.update(searcher_program_index + 1)
    # comparator values go up to 10, which does not fit a signed 3-bit comparison; use 5 bits
    success.update(success & comparator.greater_equal(5, bit_length=5))
    success.update(success & comparator.less_equal(10, bit_length=5))
    print_ln('DEBUG comparator: %s', comparator.reveal())
    comparand2 = getValue()
    print_ln('DEBUG comparand2: %s', comparand2.reveal())
    result = sint(0)
    # bit length is needed as our default bit length is only 64 bit and comparisons will provide wrong results if comparands are larger
    result = ((result.equal(0, bit_length=256)).bit_and(comparator.equal(5, bit_length=4))).if_else(comparand1.less_than(comparand2, bit_length=256), result)
    result = ((result.equal(0, bit_length=256)).bit_and(comparator.equal(6, bit_length=4))).if_else(comparand1.less_equal(comparand2, bit_length=256), result)
    result = ((result.equal(0, bit_length=256)).bit_and(comparator.equal(7, bit_length=4))).if_else(comparand1.greater_than(comparand2, bit_length=256), result)
    result = ((result.equal(0, bit_length=256)).bit_and(comparator.equal(8, bit_length=4))).if_else(comparand1.greater_equal(comparand2, bit_length=256), result)
    result = ((result.equal(0, bit_length=256)).bit_and(comparator.equal(9, bit_length=4))).if_else(comparand1.equal(comparand2, bit_length=256), result)
    result = ((result.equal(0, bit_length=256)).bit_and(comparator.equal(10, bit_length=4))).if_else(comparand1.not_equal(comparand2, bit_length=256), result)
    # sintbit does not support if_else, so the result is always a sint, which we have to convert back to sintbit
    success.update(success & sintbit(result))
print_ln('DEBUG success: %s', success.reveal())
# minimum number of entries in the backrunning transaction (i.e., nonce, gas price, gas limit, to, value, and data)
BACKRUN_TX_ENTRIES_MIN = 6
# maximum number of entries in the backrunning transaction.
BACKRUN_TX_ENTRIES_MAX = 15
# body of the backrunning tx, still without the enclosing RLP list header
backrun_tx_nolen = OptimalORAM(BACKRUN_TX_ENTRIES_MAX * 32, sint)
backrun_tx_index = sint(0)

# The first five fields are taken from the storage slots that the searcher
# program names next. NOTE(review): the labels below are the original ones;
# the RLP field order of a transaction is gas price before gas limit, so
# confirm which order the searcher program is expected to supply.
encode_as_is(getValue()) # nonce
encode_short_string(getValue()) # gas limit
encode_short_string(getValue()) # gas price
encode_short_string(getValue()) # to
encode_as_is(getValue()) # value

# TODO: data is always rlp-encoded as long string (but may actually be a short string or even an as-is-encoded 0)
# long-string header of the data field: prefix 0xb7 + number of length bytes, then the length itself
data_size = 4 + 32 * (BACKRUN_TX_ENTRIES_MAX - BACKRUN_TX_ENTRIES_MIN)
data_len = strlen(data_size)
backrun_tx_nolen[backrun_tx_index] = 0xb7 + data_len
backrun_tx_index += 1
updateTx(data_size, data_len)

# set function selector in backrunning tx
data_function = getValue()
selector_len = strlen(data_function)
updateTx(data_function, selector_len)

# set data entries in backrunning tx
@for_range(BACKRUN_TX_ENTRIES_MAX - BACKRUN_TX_ENTRIES_MIN)
def _(_):
    data_entry = OptimalORAM(32, sint)
    val = getValue()
    # convert 256 bit sint to list of 32 single bytes, most significant byte first
    @for_range(32)
    def _(pos):
        shift = 8 * (31 - pos)
        top_byte = val.right_shift(shift, bit_length=256)
        val.update(val - top_byte.left_shift(shift, bit_length=256))
        data_entry[pos] = top_byte
    print_ln('DEBUG tx data_entry: %s', data_entry)
    # append the 32 bytes to the backrunning tx
    @for_range(32)
    def _(pos):
        backrun_tx_nolen[backrun_tx_index] = data_entry[pos]
        backrun_tx_index.update(backrun_tx_index + 1)

# TODO: signature should be generated here. We just add the signature from the user's tx as a dummy signature for now, as generating a signature is a pretty sophisticated task but no showstopper for the backrunning in MPC approach.
# Notes for later: We first have to RLP-encode the transaction as a list (i.e., [nonce, gas price, gas limit, to address, value, data]), hash it with Keccak-256, hash that with prefix “\x19Ethereum signed message:\n32”, and finally sign the resulting hash with ECDSA (using curve secp256k1).
# for Keccak-256, there exists an arithmetic circuit that supports small inputs here: https://homes.esat.kuleuven.be/~nsmart/MPC/, but we have larger inputs. Do we need another Keccak-256 implementation? For ECDSA signing in general MPC, there exists at least one paper https://eprint.iacr.org/2019/889.pdf and an implementation of it in C++ https://github.com/data61/MP-SPDZ/tree/master/ECDSA. Eventually, v, r, and s need to be replaced in the backrunning transaction.
# TODO: v looks weird in the original tx. set to 0x27?
encode_as_is(storage[SIGNATURE_INDEX]) # v
encode_short_string(storage[SIGNATURE_INDEX + 1]) # r
encode_short_string(storage[SIGNATURE_INDEX + 2]) # s
print_ln('DEBUG tmp backrunning tx: %s', backrun_tx_nolen)
print_ln('DEBUG success: %s', success.reveal())
BACKRUN_LEN = 9 + BACKRUN_TX_ENTRIES_MAX * 32 # rlp-encoding of long list takes at most 9 bytes
backrun_tx = OptimalORAM(BACKRUN_LEN, sint)

# rlp encode backrunning tx as long list: prefix 0xf7 + number of length bytes
len_of_len = strlen(backrun_tx_index)
backrun_tx[0] = 0xf7 + len_of_len

# copy content of backrunning tx after the rlp-encoded length of long list
@for_range(BACKRUN_TX_ENTRIES_MAX * 32)
def _(pos):
    backrun_tx[len_of_len + 1 + pos] = backrun_tx_nolen[pos]

# update length of backrunning tx. max bytes to encode length is 8 bytes for rlp-encoded long list.
# The loop always runs 8 times; only the last len_of_len iterations emit a
# length byte and move the write position. backrun_tx_index is consumed here.
write_pos = sint(1)
@for_range(8)
def _(step):
    cond = step + len_of_len - 8
    emits_byte = cond.greater_equal(0)
    # inactive iterations shift by 64 bits
    shift = emits_byte.if_else(56 - 8 * step, 64)
    length_byte = backrun_tx_index.right_shift(shift, bit_length=256)
    backrun_tx_index.update(backrun_tx_index - emits_byte.if_else(length_byte.left_shift(shift, bit_length=256), 0))
    backrun_tx[write_pos] = length_byte
    write_pos.update(write_pos + emits_byte.if_else(1, 0))

# copy backrunning tx to array to facilitate printing
output = Array(BACKRUN_LEN, sint)
@for_range(BACKRUN_LEN)
def _(pos):
    output[pos] = backrun_tx[pos]

# output backrunning tx iff all conditions were met
# TODO: reveal only to builder
output = success.if_else(output, 0)
print_ln("DEBUG final backrunning tx: %s", output.reveal())