From 2813c0ef0fa9ca48ced3f70b7081c071f0bd28fd Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Mon, 14 Aug 2023 18:29:08 +1000 Subject: [PATCH] Maintenance. --- CHANGELOG.md | 14 +++ Compiler/GC/types.py | 12 +-- Compiler/allocator.py | 125 +++++++++++++++++++--- Compiler/circuit.py | 4 +- Compiler/compilerLib.py | 5 +- Compiler/decision_tree.py | 3 +- Compiler/instructions.py | 38 +++++++ Compiler/instructions_base.py | 35 ++++--- Compiler/library.py | 11 +- Compiler/ml.py | 25 ++--- Compiler/program.py | 90 +++++++++++----- Compiler/types.py | 140 ++++++++++++++++++++----- ExternalIO/bankers-bonus-client.py | 2 + FHE/Ciphertext.cpp | 2 +- FHE/Ring.cpp | 7 ++ FHE/Ring.h | 3 +- FHE/Ring_Element.cpp | 33 +++++- FHE/Ring_Element.h | 12 ++- FHE/Rq_Element.cpp | 4 + FHE/Rq_Element.h | 7 +- FHE/Subroutines.cpp | 3 + FHEOffline/DistDecrypt.cpp | 22 +++- FHEOffline/PairwiseSetup.cpp | 7 +- GC/DealerPrep.h | 6 +- GC/MaliciousRepSecret.h | 1 + GC/Rep4Secret.h | 1 + GC/Secret.h | 3 + GC/SemiPrep.cpp | 3 + GC/ShareSecret.h | 7 ++ GC/ShareThread.hpp | 15 +-- GC/TinierShare.h | 1 + GC/TinyPrep.hpp | 4 +- GC/TinySecret.h | 3 +- Makefile | 2 +- Math/Setup.cpp | 4 +- Math/ValueInterface.cpp | 4 +- Math/Z2k.h | 5 + Math/Zp_Data.cpp | 6 +- Math/Zp_Data.h | 7 +- Math/bigint.h | 1 + Math/gfp.h | 4 + Math/gfp.hpp | 3 + Math/gfpvar.cpp | 6 ++ Math/gfpvar.h | 1 + Math/modp.hpp | 8 +- OT/NPartyTripleGenerator.h | 2 + OT/NPartyTripleGenerator.hpp | 12 ++- OT/OTMultiplier.h | 3 +- OT/OTMultiplier.hpp | 4 +- OT/OTVole.hpp | 4 +- Processor/BaseMachine.cpp | 24 +++++ Processor/BaseMachine.h | 114 ++++++++++++++++++++ Processor/DataPositions.cpp | 6 ++ Processor/Data_Files.h | 8 +- Processor/Data_Files.hpp | 51 +++++---- Processor/DummyProtocol.h | 7 ++ Processor/EdabitBuffer.h | 7 +- Processor/Instruction.cpp | 3 +- Processor/Instruction.h | 5 +- Processor/Instruction.hpp | 19 ++++ Processor/Machine.h | 6 -- Processor/Machine.hpp | 25 +++-- Processor/Memory.h | 36 ++++--- Processor/OfflineMachine.h | 6 +- Processor/OfflineMachine.hpp | 17 ++- Processor/Online-Thread.hpp | 21 ++-- Processor/PrepBase.cpp | 32 ++++-- Processor/PrepBase.h | 4 +- Processor/instructions.h | 2 + Programs/Source/mnist_full_B.mpc | 2 +- Programs/Source/mnist_full_C.mpc | 2 +- Programs/Source/mnist_full_D.mpc | 2 +- Programs/Source/tf.mpc | 6 +- Protocols/BrainPrep.hpp | 2 +- Protocols/BufferScope.h | 32 ++++++ Protocols/ChaiGearPrep.h | 2 + Protocols/CowGearPrep.h | 2 + Protocols/CowGearPrep.hpp | 3 - Protocols/DabitSacrifice.h | 13 ++- Protocols/DabitSacrifice.hpp | 25 ++++- Protocols/DealerPrep.hpp | 24 +++-- Protocols/HemiMatrixPrep.h | 2 + Protocols/HemiPrep.h | 2 + Protocols/HemiPrep.hpp | 4 + Protocols/LowGearKeyGen.hpp | 6 +- Protocols/MAC_Check.h | 50 +++++++-- Protocols/MAC_Check.hpp | 4 +- Protocols/MalRepRingPrep.h | 5 + Protocols/MalRepRingPrep.hpp | 3 +- Protocols/MalRepRingShare.h | 5 - Protocols/MaliciousRep3Share.h | 4 + Protocols/MaliciousRepPrep.hpp | 8 +- Protocols/MaliciousRingPrep.hpp | 3 +- Protocols/MascotPrep.h | 2 + Protocols/MascotPrep.hpp | 3 + Protocols/Rep3Shuffler.hpp | 3 +- Protocols/Rep4Prep.h | 5 + Protocols/Rep4Prep.hpp | 4 +- Protocols/RepRingOnlyEdabitPrep.hpp | 4 +- Protocols/ReplicatedPrep.h | 24 +++-- Protocols/ReplicatedPrep.hpp | 100 +++++++++++++----- Protocols/RingOnlyPrep.hpp | 1 + Protocols/SemiPrep.h | 2 +- Protocols/SemiPrep.hpp | 10 +- Protocols/SemiRep3Prep.h | 2 + Protocols/ShareInterface.h | 3 + Protocols/ShareMatrix.h | 1 + Protocols/ShuffleSacrifice.h | 16 +-- Protocols/ShuffleSacrifice.hpp | 21 ++-- Protocols/SohoPrep.h | 2 + Protocols/Spdz2kPrep.h | 1 + Protocols/Spdz2kPrep.hpp | 7 +- Protocols/Spdz2kShare.h | 5 +- Protocols/SpdzWisePrep.hpp | 10 +- Protocols/SpdzWiseRingPrep.h | 5 + Protocols/TemiPrep.h | 2 + Protocols/fake-stuff.h | 2 +- Protocols/fake-stuff.hpp | 6 ++ README.md | 54 +++++----- Scripts/list-field-protocols.sh | 2 +- Scripts/list-protocols.sh | 4 + Scripts/list-ring-protocols.sh | 2 +- Scripts/memory-usage.py | 6 +- Scripts/run-common.sh | 21 +++- Scripts/test_tutorial.sh | 3 - Tools/Buffer.cpp | 2 +- Tools/Exceptions.cpp | 16 +++ Tools/Exceptions.h | 12 +++ Tools/random.h | 2 +- Utils/Fake-Offline.cpp | 6 ++ Utils/protocol-tutorial.cpp | 74 +++++++++++++ Utils/stream-fake-mascot-triples.cpp | 5 +- deps/libOTe | 2 +- doc/add-protocol.rst | 3 +- doc/compilation.rst | 7 +- doc/index.rst | 1 + doc/low-level.rst | 4 +- doc/lowest-level.rst | 149 +++++++++++++++++++++++++++ doc/machine-learning.rst | 3 +- doc/troubleshooting.rst | 39 +++++++ 140 files changed, 1598 insertions(+), 388 deletions(-) create mode 100644 Protocols/BufferScope.h create mode 100755 Scripts/list-protocols.sh create mode 100644 Utils/protocol-tutorial.cpp create mode 100644 doc/lowest-level.rst diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b72829e3..313bd3de3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,19 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.7 (August 14, 2023) + +- Path Oblivious Heap (@tskovlund) +- Adjust batch and bucket size to program +- Direct communication available in more protocols +- Option for seed in fake preprocessing (@strieflin) +- Lower memory usage due to improved register allocation +- New instructions to speed up CISC compilation +- Protocol implementation example +- Fixed security bug: missing MAC checks in multi-threaded programs +- Fixed security bug: race condition in MAC check +- Fixed security bug: missing shuffling check in PS mod 2^k and Brain +- Fixed security bug: insufficient drowning in pairwise protocols + ## 0.3.6 (May 9, 2023) - More extensive benchmarking outputs diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 118cc76df..ec3644f87 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -781,8 +781,6 @@ def store_in_mem(self, address): for i in range(n): for j, x in enumerate(v[i].bit_decompose()): x.store_in_mem(address + i + j * n) - def reveal(self): - return util.untuplify([x.reveal() for x in self.elements()]) @classmethod def two_power(cls, nn, size=1): return cls.from_vec( @@ -919,8 +917,7 @@ def bit_decompose(self, n_bits=None, security=None, maybe_mixed=None): return self.v[:n_bits] bit_compose = from_vec def reveal(self): - assert len(self) == 1 - return self.v[0].reveal() + return util.untuplify([x.reveal() for x in self.elements()]) def long_one(self): return [x.long_one() for x in self.v] def __rsub__(self, other): @@ -1279,7 +1276,8 @@ def pow2(self, k): class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): """ - Vector of signed integers for parallel binary computation:: + Vector of signed integers for parallel binary computation. + The following example uses vectors of size two:: sb32 = sbits.get_type(32) siv32 = sbitintvec.get_type(32) @@ -1291,7 +1289,7 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): print_ln('mul: %s, %s', c[0].reveal(), c[1].reveal()) c = (a - b).elements() print_ln('sub: %s, %s', c[0].reveal(), c[1].reveal()) - c = (a < b).bit_decompose() + c = (a < b).elements() print_ln('lt: %s, %s', c[0].reveal(), c[1].reveal()) This should output:: @@ -1467,7 +1465,7 @@ class sbitfixvec(_fix, _vec): print_ln('mul: %s, %s', c[0].reveal(), c[1].reveal()) c = (a - b).elements() print_ln('sub: %s, %s', c[0].reveal(), c[1].reveal()) - c = (a < b).bit_decompose() + c = (a < b).elements() print_ln('lt: %s, %s', c[0].reveal(), c[1].reveal()) This should output roughly:: diff --git a/Compiler/allocator.py b/Compiler/allocator.py index afc2933e3..fe1848035 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -43,7 +43,7 @@ def pop(self, size): else: done = False for x in self.by_logsize[logsize + 1:]: - for block_size, addresses in x.items(): + for block_size, addresses in sorted(x.items()): if len(addresses) > 0: done = True break @@ -60,16 +60,92 @@ def pop(self, size): self.by_address[addr + size] = diff return addr +class AllocRange: + def __init__(self, base=0): + self.base = base + self.top = base + self.limit = base + self.grow = True + self.pool = defaultdict(set) + + def alloc(self, size): + if self.pool[size]: + return self.pool[size].pop() + elif self.grow or self.top + size <= self.limit: + res = self.top + self.top += size + self.limit = max(self.limit, self.top) + if res >= REG_MAX: + raise RegisterOverflowError() + return res + + def free(self, base, size): + assert self.base <= base < self.top + self.pool[size].add(base) + + def stop_growing(self): + self.grow = False + + def consolidate(self): + regs = [] + for size, pool in self.pool.items(): + for base in pool: + regs.append((base, size)) + for base, size in reversed(sorted(regs)): + if base + size == self.top: + self.top -= size + self.pool[size].remove(base) + regs.pop() + else: + if program.Program.prog.verbose: + print('cannot free %d register blocks ' + 'by a gap of %d at %d' % + (len(regs), self.top - size - base, base)) + break + +class AllocPool: + def __init__(self): + self.ranges = defaultdict(lambda: [AllocRange()]) + self.by_base = {} + + def alloc(self, reg_type, size): + for r in self.ranges[reg_type]: + res = r.alloc(size) + if res is not None: + self.by_base[reg_type, res] = r + return res + + def free(self, reg): + r = self.by_base.pop((reg.reg_type, reg.i)) + r.free(reg.i, reg.size) + + def new_ranges(self, min_usage): + for t, n in min_usage.items(): + r = self.ranges[t][-1] + assert (n >= r.limit) + if r.limit < n: + r.stop_growing() + self.ranges[t].append(AllocRange(n)) + + def consolidate(self): + for r in self.ranges.values(): + for rr in r: + rr.consolidate() + + def n_fragments(self): + return max(len(r) for r in self.ranges) + class StraightlineAllocator: """Allocate variables in a straightline program using n registers. It is based on the precondition that every register is only defined once.""" def __init__(self, n, program): self.alloc = dict_by_id() - self.usage = Compiler.program.RegType.create_dict(lambda: 0) + self.max_usage = defaultdict(lambda: 0) self.defined = dict_by_id() self.dealloc = set_by_id() - self.n = n + assert(n == REG_MAX) self.program = program + self.old_pool = None def alloc_reg(self, reg, free): base = reg.vectorbase @@ -79,14 +155,7 @@ def alloc_reg(self, reg, free): reg_type = reg.reg_type size = base.size - if free[reg_type, size]: - res = free[reg_type, size].pop() - else: - if self.usage[reg_type] < self.n: - res = self.usage[reg_type] - self.usage[reg_type] += size - else: - raise RegisterOverflowError() + res = free.alloc(reg_type, size) self.alloc[base] = res base.i = self.alloc[base] @@ -126,7 +195,7 @@ def dealloc_reg(self, reg, inst, free): for x in itertools.chain(dup.duplicates, base.duplicates): to_check.add(x) - free[reg.reg_type, base.size].append(self.alloc[base]) + free.free(base) if inst.is_vec() and base.vector: self.defined[base] = inst for i in base.vector: @@ -135,6 +204,7 @@ def dealloc_reg(self, reg, inst, free): self.defined[reg] = inst def process(self, program, alloc_pool): + self.update_usage(alloc_pool) for k,i in enumerate(reversed(program)): unused_regs = [] for j in i.get_def(): @@ -161,12 +231,26 @@ def process(self, program, alloc_pool): if k % 1000000 == 0 and k > 0: print("Allocated registers for %d instructions at" % k, time.asctime()) + self.update_max_usage(alloc_pool) + alloc_pool.consolidate() + # print "Successfully allocated registers" # print "modp usage: %d clear, %d secret" % \ # (self.usage[Compiler.program.RegType.ClearModp], self.usage[Compiler.program.RegType.SecretModp]) # print "GF2N usage: %d clear, %d secret" % \ # (self.usage[Compiler.program.RegType.ClearGF2N], self.usage[Compiler.program.RegType.SecretGF2N]) - return self.usage + return self.max_usage + + def update_max_usage(self, alloc_pool): + for t, r in alloc_pool.ranges.items(): + self.max_usage[t] = max(self.max_usage[t], r[-1].limit) + + def update_usage(self, alloc_pool): + if self.old_pool: + self.update_max_usage(self.old_pool) + if id(self.old_pool) != id(alloc_pool): + alloc_pool.new_ranges(self.max_usage) + self.old_pool = alloc_pool def finalize(self, options): for reg in self.alloc: @@ -178,6 +262,21 @@ def finalize(self, options): '\t\t')) if options.stop: sys.exit(1) + if self.program.verbose: + def p(sizes): + total = defaultdict(lambda: 0) + for (t, size) in sorted(sizes): + n = sizes[t, size] + total[t] += size * n + print('%s:%d*%d' % (t, size, n), end=' ') + print() + print('Total:', dict(total)) + + sizes = defaultdict(lambda: 0) + for reg in self.alloc: + x = reg.reg_type, reg.size + print('Used registers: ', end='') + p(sizes) def determine_scope(block, options): last_def = defaultdict_by_id(lambda: -1) diff --git a/Compiler/circuit.py b/Compiler/circuit.py index 395c66146..147c6c93e 100644 --- a/Compiler/circuit.py +++ b/Compiler/circuit.py @@ -10,7 +10,7 @@ """ from Compiler.GC.types import * -from Compiler.library import function_block +from Compiler.library import function_block, get_tape from Compiler import util import itertools import struct @@ -54,7 +54,7 @@ def __call__(self, *inputs): return self.run(*inputs) def run(self, *inputs): - n = inputs[0][0].n + n = inputs[0][0].n, get_tape() if n not in self.functions: self.functions[n] = function_block(lambda *args: self.compile(*args)) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index f83e7ca2a..f34674a7b 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -270,9 +270,9 @@ def build_program(self, name=None): self.prog = Program(self.args, self.options, name=name) if self.execute: if self.options.execute in \ - ("emulate", "ring", "rep-field"): + ("emulate", "ring", "rep-field", "rep4-ring"): self.prog.use_trunc_pr = True - if self.options.execute in ("ring",): + if self.options.execute in ("ring", "ps-rep-ring", "sy-rep-ring"): self.prog.use_split(3) if self.options.execute in ("semi2k",): self.prog.use_split(2) @@ -487,6 +487,7 @@ def local_execution(self, args=[]): "Cannot produce %s. " % executable + \ "Note that compilation requires a few GB of RAM.") vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute) + sys.stdout.flush() os.execl(vm, vm, self.prog.name, *args) def remote_execution(self, args=[]): diff --git a/Compiler/decision_tree.py b/Compiler/decision_tree.py index 92be5fc0d..5f7ac8716 100644 --- a/Compiler/decision_tree.py +++ b/Compiler/decision_tree.py @@ -633,7 +633,8 @@ def preprocess_pandas(data): res.append(data.iloc[:,i].to_numpy()) types.append('c') elif pandas.api.types.is_object_dtype(t): - values = data.iloc[:,i].unique() + values = list(filter(lambda x: isinstance(x, str), + list(data.iloc[:,i].unique()))) print('converting the following to unary:', values) if len(values) == 2: res.append(data.iloc[:,i].to_numpy() == values[1]) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 5642c59c3..f3cb6ea66 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -638,6 +638,44 @@ class prefixsums(base.Instruction): code = base.opcodes['PREFIXSUMS'] arg_format = ['sw','s'] +class picks(base.VectorInstruction): + """ Extract part of vector. + + :param: result (sint) + :param: input (sint) + :param: start offset (int) + :param: step + + """ + __slots__ = [] + code = base.opcodes['PICKS'] + arg_format = ['sw','s','int','int'] + + def __init__(self, *args): + super(picks, self).__init__(*args) + assert 0 <= args[2] < len(args[1]) + assert 0 <= args[2] + args[3] * len(args[0]) <= len(args[1]) + +class concats(base.VectorInstruction): + """ Concatenate vectors. + + :param: result (sint) + :param: start offset (int) + :param: input (sint) + :param: (repeat from offset)... + + """ + __slots__ = [] + code = base.opcodes['CONCATS'] + arg_format = tools.chain(['sw'], tools.cycle(['int','s'])) + + def __init__(self, *args): + super(concats, self).__init__(*args) + assert len(args) % 2 == 1 + assert len(args[0]) == sum(args[1::2]) + for i in range(1, len(args), 2): + assert args[i] == len(args[i + 1]) + @base.gf2n @base.vectorize class mulc(base.MulBase): diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 9e72eea7d..57ff46197 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -82,6 +82,8 @@ SUBCFI = 0x2B, SUBSFI = 0x2C, PREFIXSUMS = 0x2D, + PICKS = 0x2E, + CONCATS = 0x2F, # Multiplication/division MULC = 0x30, MULM = 0x31, @@ -523,29 +525,26 @@ def expand_merged(self, skip): tape.active_basicblock = block size = sum(call[0][0].size for call in self.calls) new_regs = [] - for arg in self.args: + for i, arg in enumerate(self.args): try: - new_regs.append(type(arg)(size=size)) - except TypeError: + if i == 0: + new_regs.append(type(arg)(size=size)) + else: + new_regs.append(type(arg).concat( + call[0][i] for call in self.calls)) + assert len(new_regs[-1]) == size + except (TypeError, AttributeError): + if not isinstance(arg, int): + raise break except: print([call[0][0].size for call in self.calls]) raise - assert len(new_regs) > 1 - base = 0 - for call in self.calls: - for new_reg, reg in zip(new_regs[1:], call[0][1:]): - set_global_vector_size(reg.size) - reg.mov(new_reg.get_vector(base, reg.size), reg) - reset_global_vector_size() - base += reg.size self.new_instructions(size, new_regs) base = 0 for call in self.calls: reg = call[0][0] - set_global_vector_size(reg.size) - reg.mov(reg, new_regs[0].get_vector(base, reg.size)) - reset_global_vector_size() + reg.copy_from_part(new_regs[0], base, reg.size) base += reg.size return block.instructions, self.n_rounds - 1 @@ -628,7 +627,7 @@ def instruction(res, arg, k, f, *args): instruction = cisc(instruction) def wrapper(*args, **kwargs): - if isinstance(args[0], sfix): + if isinstance(args[0], sfix) and program.options.cisc: for arg in args[1:]: assert util.is_constant(arg) assert not kwargs @@ -844,6 +843,12 @@ def __init__(self, *args, **kwargs): Instruction.count += 1 if Instruction.count % 100000 == 0: print("Compiled %d lines at" % self.__class__.count, time.asctime()) + if Instruction.count > 10 ** 7: + print("Compilation produced more that 10 million instructions. " + "Consider using './compile.py -l' or replacing for loops " + "with @for_range_opt: " + "https://mp-spdz.readthedocs.io/en/latest/Compiler.html#" + "Compiler.library.for_range_opt") def get_code(self, prefix=0): return (prefix << self.code_length) + self.code diff --git a/Compiler/library.py b/Compiler/library.py index 187450fe2..91f23b3eb 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -6,7 +6,7 @@ from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single, localint, personal, copy_doc, _vec from Compiler.instructions import * from Compiler.util import tuplify,untuplify,is_zero -from Compiler.allocator import RegintOptimizer +from Compiler.allocator import RegintOptimizer, AllocPool from Compiler import instructions,instructions_base,comparison,program,util import inspect,math import random @@ -411,7 +411,7 @@ def on_first_call(self, wrapped_function): parent_node = get_tape().req_node get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name) block = get_tape().active_basicblock - block.alloc_pool = defaultdict(list) + block.alloc_pool = AllocPool() del parent_node.children[-1] self.node = get_tape().req_node if get_program().verbose: @@ -935,7 +935,7 @@ def f(i, j): ... Note that you cannot use registers across threads. Use - :py:class:`MemValue` instead:: + :py:class:`~Compiler.types.MemValue` instead:: a = MemValue(sint(0)) @for_range_opt_multithread(8, 80) @@ -1069,6 +1069,7 @@ def f(i): threads = prog.run_tapes(thread_args) for thread in threads: prog.join_tape(thread) + prog.free_later() if len(state): if thread_rounds: for i in range(n_threads - remainder): @@ -1320,6 +1321,7 @@ class State: pass state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \ name='if-block') state.has_else = False + state.closed_if = False state.caller = [frame[1:] for frame in inspect.stack()[1:]] instructions.program.curr_tape.if_states.append(state) @@ -1434,6 +1436,7 @@ def decorator(body): else: if_then(condition) _run_and_link(body) + get_tape().if_states[-1].closed_if = True return decorator def else_(body): @@ -1443,6 +1446,8 @@ def else_(body): _run_and_link(body) if_states.pop() else: + if not if_states[-1].closed_if: + raise CompilerError('@if_e not closed before else block') else_then() _run_and_link(body) end_if() diff --git a/Compiler/ml.py b/Compiler/ml.py index 07e30bec2..98677e1fc 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -822,10 +822,10 @@ def reset(self): self.W.randomize(-r, r, n_threads=self.n_threads) self.b.assign_all(0) - def input_from(self, player, raw=False): - self.W.input_from(player, raw=raw) + def input_from(self, player, **kwargs): + self.W.input_from(player, **kwargs) if self.input_bias: - self.b.input_from(player, raw=raw) + self.b.input_from(player, **kwargs) def compute_f_input(self, batch): N = len(batch) @@ -1088,10 +1088,7 @@ class Relu(ElementWiseLayer): :param shape: input/output shape (tuple/list of int) """ - f = staticmethod(relu) - f_prime = staticmethod(relu_prime) prime_type = sint - comparisons = None def __init__(self, shape, inputs=None): super(Relu, self).__init__(shape) @@ -1310,12 +1307,12 @@ def __init__(self, shape, inputs=None): self.bias = sfix.Array(shape[3]) self.inputs = inputs - def input_from(self, player, raw=False): - self.weights.input_from(player, raw=raw) - self.bias.input_from(player, raw=raw) + def input_from(self, player, **kwargs): + self.weights.input_from(player, **kwargs) + self.bias.input_from(player, **kwargs) tmp = sfix.Array(len(self.bias)) - tmp.input_from(player, raw=raw) - tmp.input_from(player, raw=raw) + tmp.input_from(player, **kwargs) + tmp.input_from(player, **kwargs) def _forward(self, batch=[0]): assert len(batch) == 1 @@ -1611,11 +1608,11 @@ def __repr__(self): self.bias_shape, self.Y.sizes, self.stride, repr(self.padding), self.tf_weight_format) - def input_from(self, player, raw=False): + def input_from(self, player, **kwargs): self.input_params_from(player) - self.weights.input_from(player, budget=100000, raw=raw) + self.weights.input_from(player, budget=100000, **kwargs) if self.input_bias: - self.bias.input_from(player, raw=raw) + self.bias.input_from(player, **kwargs) def output_weights(self): self.weights.print_reveal_nested() diff --git a/Compiler/program.py b/Compiler/program.py index c64370ca6..77755f7d6 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -112,8 +112,8 @@ def __init__(self, args, options=defaults, name=None): self.non_linear = Prime(self.security) if not self.bit_length: self.bit_length = 64 - print("Default bit length:", self.bit_length) - print("Default security parameter:", self.security) + print("Default bit length for compilation:", self.bit_length) + print("Default security parameter for compilation:", self.security) self.galois_length = int(options.galois) if self.verbose: print("Galois length:", self.galois_length) @@ -122,6 +122,7 @@ def __init__(self, args, options=defaults, name=None): self.DEBUG = options.debug self.allocated_mem = RegType.create_dict(lambda: USER_MEM) self.free_mem_blocks = defaultdict(al.BlockAllocator) + self.later_mem_blocks = defaultdict(list) self.allocated_mem_blocks = {} self.saved = 0 self.req_num = None @@ -229,24 +230,26 @@ def init_names(self, args): if self.name.endswith(ext): self.name = self.name[:-len(ext)] - if os.path.exists(args[0]): - self.infile = args[0] + infiles = [args[0]] + for x in (self.programs_dir, sys.path[0] + "/Programs"): + for ext in exts: + filename = args[0] + if not filename.endswith(ext): + filename += ext + filename = x + "/Source/" + filename + if os.path.abspath(filename) not in \ + [os.path.abspath(f) for f in infiles]: + infiles += [filename] + existing = [f for f in infiles if os.path.exists(f)] + if len(existing) == 1: + self.infile = existing[0] + elif len(existing) > 1: + raise CompilerError("ambiguous input files: " + + ", ".join(existing)) else: - infiles = [] - for x in (self.programs_dir, sys.path[0] + "/Programs"): - for ext in exts: - filename = args[0] - if not filename.endswith(ext): - filename += ext - infiles += [x + "/Source/" + filename] - for f in infiles: - if os.path.exists(f): - self.infile = f - break - else: - raise CompilerError( - "found none of the potential input files: " + - ", ".join("'%s'" % x for x in [args[0]] + infiles)) + raise CompilerError( + "found none of the potential input files: " + + ", ".join("'%s'" % x for x in [args[0]] + infiles)) """ self.name is input file name (minus extension) + any optional arguments. Used to generate output filenames @@ -463,7 +466,7 @@ def malloc(self, size, mem_type, reg_type=None, creator_tape=None): print("Memory of type '%s' now of size %d" % (mem_type, addr + size)) if addr + size >= 2**64: raise CompilerError("allocation exceeded for type '%s'" % mem_type) - self.allocated_mem_blocks[addr, mem_type] = size + self.allocated_mem_blocks[addr, mem_type] = size, self.curr_block.alloc_pool if single_size: from .library import get_thread_number, runtime_error_if @@ -477,12 +480,24 @@ def malloc(self, size, mem_type, reg_type=None, creator_tape=None): def free(self, addr, mem_type): """Free memory""" - if self.curr_block.alloc_pool is not self.curr_tape.basicblocks[0].alloc_pool: - raise CompilerError("Cannot free memory within function block") + now = True if not util.is_constant(addr): addr = self.base_addresses[str(addr)] - size = self.allocated_mem_blocks.pop((addr, mem_type)) - self.free_mem_blocks[mem_type].push(addr, size) + now = self.curr_tape == self.tapes[0] + size, pool = self.allocated_mem_blocks[addr, mem_type] + if self.curr_block.alloc_pool is not pool: + raise CompilerError("Cannot free memory across function blocks") + self.allocated_mem_blocks.pop((addr, mem_type)) + if now: + self.free_mem_blocks[mem_type].push(addr, size) + else: + self.later_mem_blocks[mem_type].append((addr, size)) + + def free_later(self): + for mem_type in self.later_mem_blocks: + for block in self.later_mem_blocks[mem_type]: + self.free_mem_blocks[mem_type].push(*block) + self.later_mem_blocks.clear() def finalize(self): # optimize the tapes @@ -744,6 +759,7 @@ def __init__(self, name, program): self.purged = False self.block_counter = 0 self.active_basicblock = None + self.old_allocated_mem = program.allocated_mem.copy() self.start_new_basicblock() self._is_empty = False self.merge_opens = True @@ -771,7 +787,7 @@ def __init__(self, parent, name, scope, exit_condition=None): scope.children.append(self) self.alloc_pool = scope.alloc_pool else: - self.alloc_pool = defaultdict(list) + self.alloc_pool = al.AllocPool() self.purged = False self.n_rounds = 0 self.n_to_merge = 0 @@ -869,6 +885,15 @@ def is_empty(self): return self._is_empty def start_new_basicblock(self, scope=False, name=""): + if self.program.verbose and self.active_basicblock and \ + self.program.allocated_mem != self.old_allocated_mem: + print("New allocated memory in %s " % self.active_basicblock.name, + end="") + for t, n in self.program.allocated_mem.items(): + if n != self.old_allocated_mem[t]: + print("%s:%d " % (t, n - self.old_allocated_mem[t]), end="") + print() + self.old_allocated_mem = self.program.allocated_mem.copy() # use False because None means no scope if scope is False: scope = self.active_basicblock @@ -1029,6 +1054,7 @@ def optimize(self, options): allocator = al.StraightlineAllocator(REG_MAX, self.program) def alloc(block): + allocator.update_usage(block.alloc_pool) for reg in sorted( block.used_from_scope, key=lambda x: (x.reg_type, x.i) ): @@ -1042,6 +1068,7 @@ def alloc_loop(block): for child in block.children: left.append(child) + allocator.old_pool = None for i, block in enumerate(reversed(self.basicblocks)): if len(block.instructions) > 1000000: print( @@ -1055,10 +1082,20 @@ def alloc_loop(block): and block.exit_block.scope is not None ): alloc_loop(block.exit_block.scope) + usage = allocator.max_usage.copy() allocator.process(block.instructions, block.alloc_pool) + if self.program.verbose and usage != allocator.max_usage: + print("Allocated registers in %s " % block.name, end="") + for t, n in allocator.max_usage.items(): + if n > usage[t]: + print("%s:%d " % (t, n - usage[t]), end="") + print() allocator.finalize(options) if self.program.verbose: - print("Tape register usage:", dict(allocator.usage)) + print("Tape register usage:", dict(allocator.max_usage)) + scopes = set(block.alloc_pool for block in self.basicblocks) + n_fragments = sum(scope.n_fragments() for scope in scopes) + print("%d register fragments in %d scopes" % (n_fragments, len(scopes))) # offline data requirements if self.program.verbose: @@ -1499,6 +1536,7 @@ def link(self, other): if Program.prog.options.noreallocate: raise CompilerError("reallocation necessary for linking, " "remove option -u") + assert self.reg_type == other.reg_type self.duplicates |= other.duplicates for dup in self.duplicates: dup.duplicates = self.duplicates diff --git a/Compiler/types.py b/Compiler/types.py index 881178cd4..1a82daafb 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -38,6 +38,7 @@ sfloat sgf2n cgf2n + personal Container types --------------- @@ -392,7 +393,7 @@ def cond_swap(self, a, b): return a - prod, b + prod def bit_xor(self, other): - """ XOR in arithmetic circuits. + """ Single-bit XOR in arithmetic circuits. :param self/other: 0 or 1 (any compatible type) :return: type depends on inputs (secret if any of them is) """ @@ -404,7 +405,7 @@ def bit_xor(self, other): return self + other - 2 * self * other def bit_or(self, other): - """ OR in arithmetic circuits. + """ Single-bit OR in arithmetic circuits. :param self/other: 0 or 1 (any compatible type) :return: type depends on inputs (secret if any of them is) """ @@ -416,14 +417,14 @@ def bit_or(self, other): return self + other - self * other def bit_and(self, other): - """ AND in arithmetic circuits. + """ Single-bit AND in arithmetic circuits. :param self/other: 0 or 1 (any compatible type) :rtype: depending on inputs (secret if any of them is) """ return self * other def bit_not(self): - """ NOT in arithmetic circuits. """ + """ Single-bit NOT in arithmetic circuits. """ return 1 - self def half_adder(self, other): @@ -611,6 +612,8 @@ def input_tensor_via(cls, player, content=None, shape=None, binary=True, if program.curr_tape != program.tapes[0]: raise CompilerError('only available in main thread') if content is not None: + if isinstance(content, (_vectorizable, Tape.Register)): + raise CompilerError('cannot input data already in the VM') requested_shape = shape if binary: import numpy @@ -800,11 +803,31 @@ def expand_to_vector(self, size=None): if self.size == size: return self assert self.size == 1 + return self._expand_to_vector(size) + + def _expand_to_vector(self, size): res = type(self)(size=size) for i in range(size): self.mov(res[i], self) return res + def copy_from_part(self, source, base, size): + set_global_vector_size(size) + self.mov(self, source.get_vector(base, size)) + reset_global_vector_size() + + @classmethod + def concat(cls, parts): + parts = list(parts) + res = cls(size=sum(len(part) for part in parts)) + base = 0 + for reg in parts: + set_global_vector_size(reg.size) + reg.mov(res.get_vector(base, reg.size), reg) + reset_global_vector_size() + base += reg.size + return res + class _arithmetic_register(_register): """ Arithmetic circuit type. """ def __init__(self, *args, **kwargs): @@ -1508,6 +1531,14 @@ def load_other(self, val): raise CompilerError("Cannot convert '%s' to integer" % \ type(val)) + def expand_to_vector(self, size=None): + if size is None: + size = get_global_vector_size() + if self.size == size: + return self + assert self.size == 1 + return self.inc(size, self, 0) + @vectorize @read_mem_value def int_op(self, other, inst, reverse=False): @@ -1762,7 +1793,8 @@ def output(self): class personal(Tape._no_truth): """ Value known to one player. Supports operations with public values and personal values known to the same player. Can be used - with :py:func:`~Compiler.library.print_ln_to`. + with :py:func:`~Compiler.library.print_ln_to`. It is possible to + convert to secret types like :py:class:`sint`. :param player: player (int) :param value: cleartext value (cint, cfix, cfloat) or array thereof @@ -2023,11 +2055,16 @@ def dot_product(cls, x, y): :rtype: same as inputs """ - x = list(x) - set_global_vector_size(x[0].size) - res = cls() - dotprods(res, x, y) - reset_global_vector_size() + if isinstance(x, cls) and isinstance(y, cls): + assert len(x) == len(y) + res = cls() + matmuls(res, x, y, 1, len(x), 1) + else: + x = list(x) + set_global_vector_size(x[0].size) + res = cls() + dotprods(res, x, y) + reset_global_vector_size() return res @classmethod @@ -2952,6 +2989,27 @@ def prefix_sum(self): prefixsums(res, self) return res + def sum(self): + res = type(self)(size=1) + picks(res, self.prefix_sum(), len(self) - 1, 0) + return res + + def _expand_to_vector(self, size): + res = type(self)(size=size) + picks(res, self, 0, 0) + return res + + def copy_from_part(self, source, base, size): + picks(self, source, base, 1) + + @classmethod + def concat(cls, parts): + parts = list(parts) + res = cls(size=sum(len(part) for part in parts)) + args = sum(([len(part), part] for part in parts), []) + concats(res, *args) + return res + class sintbit(sint): """ :py:class:`sint` holding a bit, supporting binary operations (``&, |, ^``). """ @@ -4726,6 +4784,24 @@ def secure_permute(self, *args, **kwargs): def prefix_sum(self): return self._new(self.v.prefix_sum(), k=self.k, f=self.f) + def sum(self): + return self._new(self.v.sum()) + + @classmethod + def concat(cls, parts): + parts = list(parts) + int_parts = [] + f = parts[0].f + k = parts[0].k + for part in parts: + assert part.f == f + assert part.k == k + int_parts.append(part.v) + return cls._new(cls.int_type.concat(int_parts), k=k, f=f) + + def __repr__(self): + return '' % (self.f, self.k, self.v) + class unreduced_sfix(_single): int_type = sint @@ -5739,9 +5815,11 @@ def assign_all(self, value, use_threads=True, conv=True): if value.size != 1: raise CompilerError('cannot assign vector to all elements') mem_value = MemValue(value) - self.address = MemValue.if_necessary(self.address) n_threads = 8 if use_threads and util.is_constant(self.length) and \ - len(self) > 2**20 and not program.options.garbled else None + len(self) > 2**20 and not program.options.garbled and \ + program.curr_tape.singular else None + if n_threads is not None: + self.address = MemValue.if_necessary(self.address) @library.multithread(n_threads, self.length) def _(base, size): if use_vector: @@ -6444,12 +6522,20 @@ class t(self.value_type): try: try: self.value_type.direct_matrix_mul - assert self.value_type == other.value_type + skip_reduce = set((sint, sfix)) == \ + set((self.value_type, other.value_type)) + assert self.value_type == other.value_type or skip_reduce max_size = _register.maximum_size // res_matrix.sizes[1] @library.multithread(n_threads, self.sizes[0], max_size) def _(base, size): - res_matrix.assign_part_vector( - self.get_part(base, size).direct_mul(other), base) + tmp = self.get_part(base, size).direct_mul( + other, reduce=not skip_reduce, + res_type=sfix if skip_reduce else None) + if skip_reduce: + tmp = t._new(tmp.v) + else: + tmp = tmp.reduce_after_mul() + res_matrix.assign_part_vector(tmp, base) except AttributeError: assert n_threads is None if max(res_matrix.sizes) > 1000: @@ -6483,7 +6569,7 @@ def _(k): else: raise NotImplementedError - def direct_mul(self, other, reduce=True, indices=None): + def direct_mul(self, other, reduce=True, indices=None, res_type=None): """ Matrix multiplication in the virtual machine. Unlike :py:func:`dot`, this only works for sint and sfix, and it returns a vector instead of a data structure. @@ -6511,10 +6597,15 @@ def direct_mul(self, other, reduce=True, indices=None): other_sizes = other.sizes assert len(other.sizes) == 2 assert self.sizes[1] == other_sizes[0] - assert self.value_type == other.value_type - return self.value_type.direct_matrix_mul(self.address, other.address, - self.sizes[0], *other_sizes, - reduce=reduce, indices=indices) + if self.value_type == other.value_type: + assert res_type in (self.value_type, None) + res_type = self.value_type + else: + assert not reduce + assert res_type + return res_type.direct_matrix_mul(self.address, other.address, + self.sizes[0], *other_sizes, + reduce=reduce, indices=indices) def direct_mul_trans(self, other, reduce=True, indices=None): """ @@ -6988,7 +7079,10 @@ class _mem(_number): __ilshift__ = lambda self,other: self.write(self.read() << other) __irshift__ = lambda self,other: self.write(self.read() >> other) - iadd = __iadd__ + def iadd(self, other): + """ Addition assignment. """ + return self.__iadd__(other) + isub = __isub__ imul = __imul__ itruediv = __itruediv__ @@ -7016,7 +7110,7 @@ class MemValue(_mem): @classmethod def if_necessary(cls, value): - if util.is_constant_float(value): + if util.is_constant_float(value) or isinstance(value, MemValue): return value else: return cls(value) @@ -7121,7 +7215,7 @@ def expand_to_vector(self, size=None): return self.value_type.load_mem(addresses) def __repr__(self): - return 'MemValue(%s,%d)' % (self.value_type, self.address) + return 'MemValue(%s,%s)' % (self.value_type, self.address) class MemFloat(MemValue): diff --git a/ExternalIO/bankers-bonus-client.py b/ExternalIO/bankers-bonus-client.py index d0f8d285b..f849ac2cc 100755 --- a/ExternalIO/bankers-bonus-client.py +++ b/ExternalIO/bankers-bonus-client.py @@ -28,6 +28,8 @@ os.store(finish) os.Send(socket) +# running two rounds +# first for sint, then for sfix for x in bonus, bonus * 2 ** 16: client.send_private_inputs([domain(x)]) diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index 27f85b619..bb22c11d2 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -124,7 +124,7 @@ void Ciphertext::rerandomize(const FHE_PK& pk) { Rq_Element tmp(*params); SeededPRNG G; - vector r(params->FFTD()[0].m()); + vector r(params->FFTD()[0].phi_m()); bigint p = pk.p(); assert(p != 0); for (auto& x : r) diff --git a/FHE/Ring.cpp b/FHE/Ring.cpp index 3b63f3069..9848dd1a2 100644 --- a/FHE/Ring.cpp +++ b/FHE/Ring.cpp @@ -2,6 +2,13 @@ #include "Ring.h" #include "Tools/Exceptions.h" +Ring::Ring(int m) : + mm(0), phim(0) +{ + if (m != 0) + init(*this, m); +} + void Ring::pack(octetStream& o) const { o.store(mm); diff --git a/FHE/Ring.h b/FHE/Ring.h index 8225ab88c..3f2686310 100644 --- a/FHE/Ring.h +++ b/FHE/Ring.h @@ -22,7 +22,7 @@ class Ring public: - Ring() : mm(0), phim(0) { ; } + Ring(int m = 0); ~Ring() { ; } // Rely on default copy assignment/constructor @@ -40,6 +40,7 @@ class Ring void unpack(octetStream& o); bool operator!=(const Ring& other) const; + bool operator==(const Ring& other) const { return not (*this != other); } }; void init(Ring& Rg, int m, bool generate_poly = false); diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 39690fa6a..4c53d28da 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -44,6 +44,7 @@ void Ring_Element::prepare(const Ring_Element& other) void Ring_Element::prepare_push() { element.clear(); + assert(FFTD); element.reserve(FFTD->phi_m()); } @@ -63,6 +64,7 @@ void Ring_Element::assign_zero() void Ring_Element::assign_one() { + assert(FFTD); allocate(); modp fill; if (rep==polynomial) { assignZero(fill,(*FFTD).get_prD()); } @@ -79,6 +81,7 @@ void Ring_Element::negate() if (element.empty()) return; + assert(FFTD); for (int i=0; i<(*FFTD).phi_m(); i++) { Negate(element[i],element[i],(*FFTD).get_prD()); } } @@ -87,6 +90,7 @@ void Ring_Element::negate() void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) { + assert(a.FFTD); if (a.FFTD!=b.FFTD) { throw pr_mismatch(); } if (a.element.empty()) { @@ -119,6 +123,7 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) { + assert(a.FFTD); if (a.rep!=b.rep) { throw rep_mismatch(); } if (a.FFTD!=b.FFTD) { throw pr_mismatch(); } if (a.element.empty()) @@ -148,6 +153,7 @@ void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) { + assert(a.FFTD); if (a.rep!=b.rep) { throw rep_mismatch(); } if (a.FFTD!=b.FFTD) { throw pr_mismatch(); } if (a.element.empty() or b.element.empty()) @@ -200,9 +206,11 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) } else if ((*a.FFTD).get_twop()==0) { // m a power of two case - ans.partial_assign(a); Ring_Element aa(*ans.FFTD,ans.rep); + aa.partial_assign(a); modp temp; + cerr << "slow polynomial multiplication " + "(change representation to change this)..." << endl; for (int i=0; i<(*ans.FFTD).phi_m(); i++) { for (int j=0; j<(*ans.FFTD).phi_m(); j++) { Mul(temp,a.element[i],b.element[j],(*a.FFTD).get_prD()); @@ -213,7 +221,9 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) } Add(aa.element[k],aa.element[k],temp,(*a.FFTD).get_prD()); } + cerr << "\r" << i << "/" << ans.FFTD->phi_m(); } + cerr << endl; ans=aa; } else @@ -241,6 +251,7 @@ void mul(Ring_Element& ans,const Ring_Element& a,const modp& b) Ring_Element& Ring_Element::operator +=(const Ring_Element& other) { assert(element.size() == other.element.size()); + assert(FFTD); assert(FFTD == other.FFTD); assert(rep == other.rep); for (size_t i = 0; i < element.size(); i++) @@ -252,6 +263,7 @@ Ring_Element& Ring_Element::operator +=(const Ring_Element& other) Ring_Element& Ring_Element::operator -=(const Ring_Element& other) { assert(element.size() == other.element.size()); + assert(FFTD); assert(FFTD == other.FFTD); assert(rep == other.rep); for (size_t i = 0; i < element.size(); i++) @@ -263,6 +275,7 @@ Ring_Element& Ring_Element::operator -=(const Ring_Element& other) Ring_Element& Ring_Element::operator *=(const Ring_Element& other) { assert(element.size() == other.element.size()); + assert(FFTD); assert(FFTD == other.FFTD); assert(rep == other.rep); assert(rep == evaluation); @@ -274,6 +287,7 @@ Ring_Element& Ring_Element::operator *=(const Ring_Element& other) Ring_Element& Ring_Element::operator *=(const modp& other) { + assert(FFTD); for (size_t i = 0; i < element.size(); i++) element[i] = element[i].mul(other, FFTD->get_prD()); return *this; @@ -282,6 +296,7 @@ Ring_Element& Ring_Element::operator *=(const modp& other) Ring_Element Ring_Element::mul_by_X_i(int j) const { + assert(FFTD); Ring_Element ans; ans.prepare(*this); if (element.empty()) @@ -331,6 +346,7 @@ Ring_Element Ring_Element::mul_by_X_i(int j) const void Ring_Element::randomize(PRNG& G,bool Diag) { + assert(FFTD); allocate(); if (Diag==false) { for (int i=0; i<(*FFTD).phi_m(); i++) @@ -352,6 +368,7 @@ void Ring_Element::randomize(PRNG& G,bool Diag) void Ring_Element::change_rep(RepType r) { + assert(FFTD); if (element.empty()) { rep = r; @@ -403,6 +420,7 @@ void Ring_Element::change_rep(RepType r) bool Ring_Element::equals(const Ring_Element& a) const { + assert(FFTD); if (rep!=a.rep) { throw rep_mismatch(); } if (*FFTD!=*a.FFTD) { throw pr_mismatch(); } @@ -417,6 +435,7 @@ bool Ring_Element::equals(const Ring_Element& a) const bool Ring_Element::is_zero() const { + assert(FFTD); if (element.empty()) return true; for (auto& x : element) @@ -428,6 +447,7 @@ bool Ring_Element::is_zero() const ConversionIterator Ring_Element::get_iterator() const { + assert(FFTD); if (rep != polynomial) throw runtime_error("simple iterator only available in polynomial represention"); assert(not element.empty()); @@ -436,16 +456,19 @@ ConversionIterator Ring_Element::get_iterator() const RingReadIterator Ring_Element::get_copy_iterator() const { + assert(FFTD); return *this; } RingWriteIterator Ring_Element::get_write_iterator() { + assert(FFTD); return *this; } vector Ring_Element::to_vec_bigint() const { + assert(FFTD); vector v; to_vec_bigint(v); return v; @@ -454,6 +477,7 @@ vector Ring_Element::to_vec_bigint() const void Ring_Element::to_vec_bigint(vector& v) const { + assert(FFTD); v.resize(FFTD->phi_m()); if (element.empty()) return; @@ -476,6 +500,7 @@ void Ring_Element::to_vec_bigint(vector& v) const modp Ring_Element::get_constant() const { + assert(FFTD); if (element.empty()) return {}; else @@ -516,6 +541,7 @@ void get(octetStream& o,vector& v,const Zp_Data& ZpD) void Ring_Element::pack(octetStream& o) const { + assert(FFTD); check_size(); o.store(unsigned(rep)); store(o,element,(*FFTD).get_prD()); @@ -524,6 +550,7 @@ void Ring_Element::pack(octetStream& o) const void Ring_Element::unpack(octetStream& o) { + assert(FFTD); unsigned int a; o.get(a); rep=(RepType) a; @@ -542,12 +569,14 @@ void Ring_Element::check_rep() void Ring_Element::check_size() const { + assert(FFTD); if (not element.empty() and (int)element.size() != FFTD->phi_m()) throw runtime_error("invalid element size"); } void Ring_Element::output(ostream& s) const { + assert(FFTD); s.write((char*)&rep, sizeof(rep)); auto size = element.size(); s.write((char*)&size, sizeof(size)); @@ -558,6 +587,7 @@ void Ring_Element::output(ostream& s) const void Ring_Element::input(istream& s) { + assert(FFTD); s.read((char*)&rep, sizeof(rep)); check_rep(); auto size = element.size(); @@ -579,6 +609,7 @@ void Ring_Element::check(const FFT_Data& FFTD) const size_t Ring_Element::report_size(ReportType type) const { + assert(FFTD); if (type == CAPACITY) return sizeof(modp) * element.capacity(); else diff --git a/FHE/Ring_Element.h b/FHE/Ring_Element.h index 5982bbe32..04698ade8 100644 --- a/FHE/Ring_Element.h +++ b/FHE/Ring_Element.h @@ -56,9 +56,9 @@ class Ring_Element void allocate(); void set_data(const FFT_Data& prd) { FFTD=&prd; } - const FFT_Data& get_FFTD() const { return *FFTD; } - const Zp_Data& get_prD() const { return (*FFTD).get_prD(); } - const bigint& get_prime() const { return (*FFTD).get_prime(); } + const FFT_Data& get_FFTD() const { assert(FFTD); return *FFTD; } + const Zp_Data& get_prD() const { return get_FFTD().get_prD(); } + const bigint& get_prime() const { return get_FFTD().get_prime(); } void assign_zero(); void assign_one(); @@ -120,6 +120,7 @@ class Ring_Element template void from(const vector& source) { + assert(source.size() == (size_t) get_FFTD().phi_m()); from(Iterator(source)); } @@ -162,7 +163,7 @@ class RingWriteIterator : public WriteConversionIterator RepType rep; public: RingWriteIterator(Ring_Element& element) : - WriteConversionIterator(element.element, element.FFTD->get_prD()), + WriteConversionIterator(element.element, element.get_FFTD().get_prD()), element(element), rep(element.rep) { element.rep = polynomial; @@ -177,7 +178,7 @@ class RingReadIterator : public ConversionIterator Ring_Element element; public: RingReadIterator(const Ring_Element& element) : - ConversionIterator(this->element.element, element.FFTD->get_prD()), + ConversionIterator(this->element.element, element.get_FFTD().get_prD()), element(element) { this->element.change_rep(polynomial); @@ -198,6 +199,7 @@ void Ring_Element::from(const Generator& generator) T tmp; modp tmp2; prepare_push(); + assert(FFTD); for (int i=0; i<(*FFTD).phi_m(); i++) { generator.get(tmp); diff --git a/FHE/Rq_Element.cpp b/FHE/Rq_Element.cpp index f65eddbd5..97977fdba 100644 --- a/FHE/Rq_Element.cpp +++ b/FHE/Rq_Element.cpp @@ -14,7 +14,10 @@ Rq_Element::Rq_Element(const vector& prd, RepType r0, RepType r1) if (prd.size() > 0) a.push_back({prd[0], r0}); if (prd.size() > 1) + { + assert(prd[0].get_R() == prd[1].get_R()); a.push_back({prd[1], r1}); + } lev = n_mults(); } @@ -155,6 +158,7 @@ void Rq_Element::to_vec_bigint(vector& v) const if (lev==1) { vector v1; a[1].to_vec_bigint(v1); + assert(v.size() == v1.size()); bigint p0=a[0].get_prime(); bigint p1=a[1].get_prime(); bigint p0i,lambda,Q=p0*p1; diff --git a/FHE/Rq_Element.h b/FHE/Rq_Element.h index f315d22b2..cdf3626c7 100644 --- a/FHE/Rq_Element.h +++ b/FHE/Rq_Element.h @@ -63,7 +63,10 @@ class Rq_Element Rq_Element(const FHE_PK& pk); Rq_Element(const Ring_Element& b0,const Ring_Element& b1) : - a({b0, b1}), lev(n_mults()) {} + a({b0, b1}), lev(n_mults()) + { + assert(b0.get_FFTD().get_R() == b1.get_FFTD().get_R()); + } Rq_Element(const Ring_Element& b0) : a({b0}), lev(n_mults()) {} @@ -139,6 +142,8 @@ class Rq_Element template void from(const vector& source, int level=-1) { + for (auto& x : a) + assert(source.size() == (size_t) x.get_FFTD().phi_m()); from(Iterator(source), level); } diff --git a/FHE/Subroutines.cpp b/FHE/Subroutines.cpp index a688b1691..f9d8ed497 100644 --- a/FHE/Subroutines.cpp +++ b/FHE/Subroutines.cpp @@ -76,11 +76,14 @@ modp Find_Primitive_Root_2m(int m,const vector& poly,const Zp_Data& ZpD) */ modp Find_Primitive_Root_2power(int m,const Zp_Data& ZpD) { + assert((m & (m - 1)) == 0); + assert(m > 1); modp ans,e,one,base; assignOne(one,ZpD); assignOne(base,ZpD); bigint exp; exp=(ZpD.pr-1)/m; + assert(exp * m == ZpD.pr - 1); bool flag=true; while (flag) { Add(base,base,one,ZpD); // Keep incrementing base until we hit the answer diff --git a/FHEOffline/DistDecrypt.cpp b/FHEOffline/DistDecrypt.cpp index 1063717f1..a247f19d1 100644 --- a/FHEOffline/DistDecrypt.cpp +++ b/FHEOffline/DistDecrypt.cpp @@ -17,6 +17,23 @@ DistDecrypt::DistDecrypt(const Player& P, const FHE_SK& share, mf.allocate_slots(pk.p() << 64); } +class ModuloTreeSum : public TreeSum +{ + bigint modulo; + + void post_add_process(vector& values) + { + for (auto& v : values) + v %= modulo; + } + +public: + ModuloTreeSum(bigint modulo) : + modulo(modulo) + { + } +}; + template Plaintext_& DistDecrypt::run(const Ciphertext& ctx, bool NewCiphertext) { @@ -57,10 +74,7 @@ Plaintext_& DistDecrypt::run(const Ciphertext& ctx, bool NewCiphertext) } else { - TreeSum().run(vv, P); - bigint mod=params.p0(); - for (auto& v : vv) - v %= mod; + ModuloTreeSum(params.p0()).run(vv, P); } // Now get the final message diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index bc890ed21..25efc2778 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -76,6 +76,7 @@ void secure_init(T& setup, Player& P, U& machine, + to_string(CowGearOptions::singleton.top_gear()) + "-P" + to_string(P.my_num()) + "-" + to_string(P.num_players()); string reason; + auto base_setup = setup; try { @@ -107,7 +108,7 @@ void secure_init(T& setup, Player& P, U& machine, << " because no suitable material " "from a previous run was found (" << reason << ")" << endl; - setup = {}; + setup = base_setup; setup.generate(P, machine, plaintext_length, sec); setup.check(P, machine); octetStream os; @@ -122,6 +123,10 @@ void secure_init(T& setup, Player& P, U& machine, cerr << "Ciphertext length: " << params.p0().numBits(); for (size_t i = 1; i < params.FFTD().size(); i++) cerr << "+" << params.FFTD()[i].get_prime().numBits(); + cerr << " (" << DIV_CEIL(params.p0().numBits(), 64); + for (size_t i = 1; i < params.FFTD().size(); i++) + cerr << "+" << DIV_CEIL(params.FFTD()[i].get_prime().numBits(), 64); + cerr << " limbs)"; cerr << endl; } } diff --git a/GC/DealerPrep.h b/GC/DealerPrep.h index c9c8b21c1..ec2d36a71 100644 --- a/GC/DealerPrep.h +++ b/GC/DealerPrep.h @@ -33,7 +33,11 @@ class DealerPrep : public BufferPrep, ShiftableTripleBuffer> setup(*P); ProtocolSet> set(*P, setup); - for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + int buffer_size = DIV_CEIL( + BaseMachine::batch_size(DATA_TRIPLE), + DealerSecret::default_length); + set.preprocessing.buffer_extra(DATA_TRIPLE, buffer_size); + for (int i = 0; i < buffer_size; i++) { auto triple = set.preprocessing.get_triple( DealerSecret::default_length); diff --git a/GC/MaliciousRepSecret.h b/GC/MaliciousRepSecret.h index 9f941d51d..1f2a74ac3 100644 --- a/GC/MaliciousRepSecret.h +++ b/GC/MaliciousRepSecret.h @@ -70,6 +70,7 @@ class MalRepSecretBase : public ReplicatedSecret typedef U whole_type; static const bool expensive_triples = true; + static const bool malicious = true; static MC* new_mc(typename super::mac_key_type) { diff --git a/GC/Rep4Secret.h b/GC/Rep4Secret.h index 5b5582296..f464c82c7 100644 --- a/GC/Rep4Secret.h +++ b/GC/Rep4Secret.h @@ -26,6 +26,7 @@ class Rep4Secret : public RepSecretBase typedef Rep4Input Input; static const bool expensive_triples = false; + static const bool malicious = true; static MC* new_mc(typename super::mac_key_type) { return new MC; } diff --git a/GC/Secret.h b/GC/Secret.h index b4f9ac8e9..b6c3b3f65 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -135,6 +135,9 @@ class Secret static void run_tapes(const vector& args) { T::run_tapes(args); } + template + static string proto_fake_opts() { return U::fake_opts(); } + Secret(); Secret(const Integer& x) { *this = x; } diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index 2cf710f3d..02cf31a5f 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -43,6 +43,9 @@ void SemiPrep::set_protocol(SemiSecret::Protocol& protocol) void SemiPrep::buffer_triples() { assert(this->triple_generator); + this->triple_generator->set_batch_size( + DIV_CEIL(BaseMachine::batch_size(DATA_TRIPLE, + this->buffer_size), 64)); this->triple_generator->generatePlainTriples(); for (auto& x : this->triple_generator->plainTriples) { diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 5114bdef6..64189822f 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -151,6 +151,12 @@ class RepSecretBase : public FixedVec, public ShareSecret { } + template + static string proto_fake_opts() + { + return T::fake_opts(); + } + RepSecretBase() { } @@ -258,6 +264,7 @@ class SemiHonestRepSecret : public ReplicatedSecret typedef SemiHonestRepSecret whole_type; static const bool expensive_triples = false; + static const bool malicious = false; static MC* new_mc(mac_key_type) { return new MC; } diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index b0eea1b0b..4056a5a9c 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -184,12 +184,15 @@ void ShareThread::xors(Processor& processor, const vector& args) int out = args[i + 1]; int left = args[i + 2]; int right = args[i + 3]; - for (int j = 0; j < DIV_CEIL(n_bits, T::default_length); j++) - { - int n = min(T::default_length, n_bits - j * T::default_length); - processor.S[out + j].xor_(n, processor.S[left + j], - processor.S[right + j]); - } + if (n_bits == 1) + processor.S[out].xor_(1, processor.S[left], processor.S[right]); + else + for (int j = 0; j < DIV_CEIL(n_bits, T::default_length); j++) + { + int n = min(T::default_length, n_bits - j * T::default_length); + processor.S[out + j].xor_(n, processor.S[left + j], + processor.S[right + j]); + } } } diff --git a/GC/TinierShare.h b/GC/TinierShare.h index c7ea00ef0..c7f672596 100644 --- a/GC/TinierShare.h +++ b/GC/TinierShare.h @@ -62,6 +62,7 @@ class TinierShare: public Share_, SemiShare>, typedef TinierSecret whole_type; static const int default_length = 1; + static const bool expensive_triples = true; static string name() { diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp index d3efbb831..1fe06c486 100644 --- a/GC/TinyPrep.hpp +++ b/GC/TinyPrep.hpp @@ -34,7 +34,9 @@ void TinierSharePrep::buffer_secret_triples() vector> triples; TripleShuffleSacrifice sacrifice; size_t required; - required = sacrifice.minimum_n_inputs_with_combining(); + required = sacrifice.minimum_n_inputs_with_combining( + BaseMachine::batch_size(DATA_TRIPLE)); + triple_generator->set_batch_size(DIV_CEIL(required, 64)); while (triples.size() < required) { triple_generator->generatePlainTriples(); diff --git a/GC/TinySecret.h b/GC/TinySecret.h index b9c037762..ada5ca885 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -50,7 +50,8 @@ class VectorSecret : public Secret static const bool variable_players = T::variable_players; static const bool needs_ot = T::needs_ot; static const bool has_mac = T::has_mac; - static const bool expensive_triples = false; + static const bool malicious = T::malicious; + static const bool expensive_triples = T::expensive_triples; static const bool randoms_for_opens = false; static const int default_length = 64; diff --git a/Makefile b/Makefile index fc8b8fb97..282e839ce 100644 --- a/Makefile +++ b/Makefile @@ -82,7 +82,7 @@ CONFIG.mine: %.o: %.cpp $(CXX) -o $@ $< $(CFLAGS) -MMD -MP -c -online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x emulate.x +online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x emulate.x mascot-party.x offline: $(OT_EXE) Check-Offline.x mascot-offline.x cowgear-offline.x mal-shamir-offline.x diff --git a/Math/Setup.cpp b/Math/Setup.cpp index 38cc6a388..2a54575c2 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -154,9 +154,9 @@ void check_setup(string dir, bigint pr) string filename = dir + "Params-Data"; ifstream(filename) >> p; if (p == 0) - throw runtime_error("no modulus in " + filename); + throw setup_error("no modulus in " + filename); if (p != pr) - throw runtime_error("wrong modulus in " + filename); + throw setup_error("wrong modulus in " + filename); } string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, diff --git a/Math/ValueInterface.cpp b/Math/ValueInterface.cpp index db7904bba..68758fb08 100644 --- a/Math/ValueInterface.cpp +++ b/Math/ValueInterface.cpp @@ -12,7 +12,7 @@ void ValueInterface::check_setup(const string& directory) { struct stat sb; if (stat(directory.c_str(), &sb) != 0) - throw runtime_error(directory + " does not exist"); + throw setup_error(directory + " does not exist"); if (not (sb.st_mode & S_IFDIR)) - throw runtime_error(directory + " is not a directory"); + throw setup_error(directory + " is not a directory"); } diff --git a/Math/Z2k.h b/Math/Z2k.h index 1414b5d23..924aa9536 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -109,6 +109,11 @@ class Z2 : public ValueInterface void assign(const void* buffer) { avx_memcpy(a, buffer, N_BYTES); normalize(); } void assign(int x) { *this = x; } + /** + * Get 64-bit part. + * + * @param i return word containing 64*i- to 64*i+63-least significant bits + */ mp_limb_t get_limb(int i) const { return a[i]; } bool get_bit(int i) const; diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index 95ac1e8d4..e207d8250 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -8,15 +8,17 @@ void Zp_Data::init(const bigint& p,bool mont) { lock.lock(); -#ifdef VERBOSE if (pr != 0) { +#ifdef VERBOSE if (pr != p) cerr << "Changing prime from " << pr << " to " << p << endl; if (mont != montgomery) cerr << "Changing Montgomery" << endl; - } #endif + if (pr != p or mont != montgomery) + throw runtime_error("Zp_Data instance already initialized"); + } if (not probPrime(p)) throw runtime_error(p.get_str() + " is not a prime"); diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index 5ff3f6351..8677df404 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -47,7 +47,7 @@ class Zp_Data void Mont_Mult_switch(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const; void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y, int t) const; void Mont_Mult_variable(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const - { Mont_Mult(z, x, y, t); } + { Mont_Mult(z, x, y, get_t()); } void Mont_Mult_max(mp_limb_t* z, const mp_limb_t* x, const mp_limb_t* y, int max_t) const; @@ -61,7 +61,7 @@ class Zp_Data void assign(const Zp_Data& Zp); void init(const bigint& p,bool mont=true); - int get_t() const { return t; } + int get_t() const { assert(t > 0); return t; } const mp_limb_t* get_prA() const { return prA; } bool get_mont() const { return montgomery; } mp_limb_t overhang_mask() const; @@ -73,8 +73,9 @@ class Zp_Data Zp_Data() : montgomery(0), pi(0), mask(0), pr_byte_length(0), pr_bit_length(0) { - t = MAX_MOD_SZ; + t = -1; overhang = 0; + shanks_r = 0; } // The main init funciton diff --git a/Math/bigint.h b/Math/bigint.h index bdf3666a4..41da70f1e 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -91,6 +91,7 @@ class bigint : public mpz_class template bigint& operator=(const SignedZ2& x); + /// Convert to signed representation in :math:`[-p/2,p/2]`. template bigint& from_signed(const gfp_& other); template diff --git a/Math/gfp.h b/Math/gfp.h index fe6f64c3c..31f3a571c 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -61,6 +61,8 @@ class gfp_ : public ValueInterface static thread_local vector powers; + static gfp_ two; + public: typedef gfp_ value_type; @@ -317,6 +319,8 @@ gfp_::gfp_(long x) assign_zero(); else if (x == 1) assign_one(); + else if (x == 2) + *this = two; else *this = bigint::tmp = x; } diff --git a/Math/gfp.hpp b/Math/gfp.hpp index 0e0f7b624..2a2e27785 100644 --- a/Math/gfp.hpp +++ b/Math/gfp.hpp @@ -16,6 +16,8 @@ template const true_type gfp_::prime_field; template const int gfp_::MAX_N_BITS; +template +gfp_ gfp_::two; template inline void gfp_::read_or_generate_setup(string dir, @@ -50,6 +52,7 @@ void gfp_::init_field(const bigint& p, bool mont) else cerr << name << " larger than necessary for modulus " << p << endl; } + two = bigint::tmp = 2; } template diff --git a/Math/gfpvar.cpp b/Math/gfpvar.cpp index 06b753854..383d45751 100644 --- a/Math/gfpvar.cpp +++ b/Math/gfpvar.cpp @@ -80,6 +80,12 @@ void gfpvar_::init_default(int lgp, bool montgomery) init_field(SPDZ_Data_Setup_Primes(lgp), montgomery); } +template +inline void gfpvar_::reset() +{ + ZpD = {}; +} + template const Zp_Data& gfpvar_::get_ZpD() { diff --git a/Math/gfpvar.h b/Math/gfpvar.h index b6ab2ae3e..ceb4e9ed3 100644 --- a/Math/gfpvar.h +++ b/Math/gfpvar.h @@ -68,6 +68,7 @@ class gfpvar_ { init_field(T::pr(), montgomery); } + static void reset(); static const Zp_Data& get_ZpD(); static const bigint& pr(); diff --git a/Math/modp.hpp b/Math/modp.hpp index 50f93cae7..32faf9766 100644 --- a/Math/modp.hpp +++ b/Math/modp.hpp @@ -102,7 +102,7 @@ bool isZero(const modp_& ans,const Zp_Data& ZpD) template void assignOne(modp_& x,const Zp_Data& ZpD) { if (ZpD.montgomery) - { mpn_copyi(x.x,ZpD.R,ZpD.t); } + { mpn_copyi(x.x,ZpD.R,ZpD.get_t()); } else { assignZero(x,ZpD); x.x[0]=1; @@ -177,7 +177,7 @@ void modp_::to_bigint(bigint& ans,const Zp_Data& ZpD,bool reduce) const template void to_modp(modp_& ans,int x,const Zp_Data& ZpD) { - inline_mpn_zero(ans.x,ZpD.t); + inline_mpn_zero(ans.x,ZpD.get_t()); if (x>=0) { ans.x[0]=x; if (ZpD.t==1) { ans.x[0]=ans.x[0]%ZpD.prA[0]; } @@ -232,13 +232,13 @@ void modp_::convert_destroy(const fixint& xx, SignedZ2<64 * L> tmp = xx; if (xx.negative()) tmp += ZpD.pr; - convert(tmp.get(), ZpD.t, ZpD, false); + convert(tmp.get(), ZpD.get_t(), ZpD, false); } template void modp_::convert(const mp_limb_t* source, mp_size_t size, const Zp_Data& ZpD, bool negative) { - assert(size <= ZpD.t); + assert(size <= ZpD.get_t()); if (negative) mpn_sub(x, ZpD.prA, ZpD.t, source, size); else diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index b212a4805..f2aab9db1 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -110,6 +110,8 @@ class OTTripleGenerator : public GeneratorThread Player* parentPlayer = 0); ~OTTripleGenerator(); + void set_batch_size(int nTriples); + void generate() { throw not_implemented(); } void generatePlainTriples(); diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index f52502540..6a73e9834 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -69,6 +69,14 @@ Spdz2kTripleGenerator::Spdz2kTripleGenerator(const OTTripleSetup& setup, { } +template +void OTTripleGenerator::set_batch_size(int batch_size) +{ + nTriplesPerLoop = DIV_CEIL(batch_size, nloops); + nTriples = nTriplesPerLoop * nloops; + nPreampTriplesPerLoop = nTriplesPerLoop * nAmplify; +} + template OTTripleGenerator::OTTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int _nTriples, int nloops, @@ -84,11 +92,9 @@ OTTripleGenerator::OTTripleGenerator(const OTTripleSetup& setup, machine(machine), MC(0) { - nTriplesPerLoop = DIV_CEIL(_nTriples, nloops); - nTriples = nTriplesPerLoop * nloops; field_size = T::open_type::size() * 8; nAmplify = machine.amplify ? N_AMPLIFY : 1; - nPreampTriplesPerLoop = nTriplesPerLoop * nAmplify; + set_batch_size(_nTriples); int n = nparties; //baseReceiverInput = machines[0]->baseReceiverInput; diff --git a/OT/OTMultiplier.h b/OT/OTMultiplier.h index 0f86bc0ca..64b78412c 100644 --- a/OT/OTMultiplier.h +++ b/OT/OTMultiplier.h @@ -13,6 +13,7 @@ using namespace std; #include "OT/OTVole.h" #include "OT/Rectangle.h" #include "Tools/random.h" +#include "Tools/CheckVector.h" template class NPartyTripleGenerator; @@ -187,7 +188,7 @@ class SemiMultiplier : public OTMultiplier } public: - vector c_output; + CheckVector c_output; SemiMultiplier(OTTripleGenerator& generator, int i) : OTMultiplier(generator, i) diff --git a/OT/OTMultiplier.hpp b/OT/OTMultiplier.hpp index 1c05296b8..69636cfe7 100644 --- a/OT/OTMultiplier.hpp +++ b/OT/OTMultiplier.hpp @@ -163,10 +163,10 @@ void SemiMultiplier::multiplyForBits() rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput); + int n_squares = otCorrelator.receiverOutputMatrix.squares.size(); otCorrelator.setup_for_correlation(aBits, baseSenderOutputs, baseReceiverOutput); - otCorrelator.correlate(0, otCorrelator.receiverOutputMatrix.squares.size(), - this->generator.valueBits[0], false, -1); + otCorrelator.correlate(0, n_squares, aBits, false, -1); c_output.clear(); diff --git a/OT/OTVole.hpp b/OT/OTVole.hpp index 13f58f7fd..f206d613f 100644 --- a/OT/OTVole.hpp +++ b/OT/OTVole.hpp @@ -114,7 +114,7 @@ void OTVoleBase::hash_row(__m128i res[2], const U& row, int num_blocks = DIV_CEIL(row.size() * T::size(), 16); __m128i buffer[T::size()]; size_t next = 0; - while (next + 16 < row.size()) + while (next + 16 <= row.size()) { for (int j = 0; j < 16; j++) memcpy((char*) buffer + j * T::size(), row[next++].get_ptr(), T::size()); @@ -124,6 +124,8 @@ void OTVoleBase::hash_row(__m128i res[2], const U& row, for (int j = 0; j < 16; j++) if (next < row.size()) memcpy((char*) buffer + j * T::size(), row[next++].get_ptr(), T::size()); + else + memset((char*) buffer + j * T::size(), 0, T::size()); for (int j = 0; j < num_blocks % T::size(); j++) add_mul(res, buffer[j], *coefficients++); assert(coefficients == coeff_base + num_blocks); diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index ee9e19bc3..105a755f2 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -8,6 +8,8 @@ #include "Math/Setup.h" #include "Tools/Bundle.h" +#include "Protocols/ShuffleSacrifice.hpp" + #include #include using namespace std; @@ -30,6 +32,28 @@ BaseMachine& BaseMachine::s() throw runtime_error("no singleton"); } +bool BaseMachine::has_program() +{ + return has_singleton() and not s().progs.empty(); +} + +int BaseMachine::edabit_bucket_size(int n_bits) +{ + int res = OnlineOptions::singleton.bucket_size; + + if (has_program()) + { + auto usage = s().progs[0].get_offline_data_used().total_edabits(n_bits); + for (int B = res; B <= 5; B++) + if (ShuffleSacrifice(B).minimum_n_outputs() < usage * .9) + break; + else + res = B; + } + + return res; +} + BaseMachine::BaseMachine() : nthreads(0) { if (sodium_init() == -1) diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index 46b1a85e6..7d2da9be3 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -11,6 +11,8 @@ #include "OT/OTTripleSetup.h" #include "ThreadJob.h" #include "ThreadQueues.h" +#include "Program.h" +#include "OnlineOptions.h" #include #include @@ -44,8 +46,11 @@ class BaseMachine vector bc_filenames; + vector progs; + static BaseMachine& s(); static bool has_singleton() { return singleton != 0; } + static bool has_program(); static string memory_filename(const string& type_short, int my_number); @@ -54,6 +59,12 @@ class BaseMachine static int prime_length_from_schedule(string progname); static bigint prime_from_schedule(string progname); + template + static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0); + template + static int edabit_batch_size(int n_bits, int buffer_size = 0); + static int edabit_bucket_size(int n_bits); + BaseMachine(); virtual ~BaseMachine() {} @@ -76,6 +87,8 @@ class BaseMachine void print_global_comm(Player& P, const NamedCommStats& stats); void print_comm(Player& P, const NamedCommStats& stats); + + virtual const Names& get_N() { throw not_implemented(); } }; inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P) @@ -83,4 +96,105 @@ inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P) return ot_setup.get_fresh(P); } +template +int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback) +{ + int n_opts; + int n = 0; + int res = 0; + + if (buffer_size > 0) + n_opts = buffer_size; + else if (fallback > 0) + n_opts = fallback; + else + n_opts = OnlineOptions::singleton.batch_size; + + if (buffer_size <= 0 and has_program()) + { + auto files = s().progs[0].get_offline_data_used().files; + auto usage = files[T::clear::field_type()]; + + if (type == DATA_DABIT and T::LivePrep::bits_from_dabits()) + n = usage[DATA_BIT] + usage[DATA_DABIT]; + else if (type == DATA_BIT and T::LivePrep::dabits_from_bits()) + n = usage[DATA_BIT] + usage[DATA_DABIT]; + else + n = usage[type]; + } + else if (type != DATA_DABIT) + { + n = buffer_size; + buffer_size = 0; + n_opts = OnlineOptions::singleton.batch_size; + } + + if (n > 0 and not (buffer_size > 0)) + { + bool used_frac = false; + if (n > n_opts) + { + // finding the right fraction + for (int i = 1; i <= 10; i++) + { + int frac = DIV_CEIL(n, i); + if (frac <= n_opts) + { + res = frac; + used_frac = true; +#ifdef DEBUG_BATCH_SIZE + cerr << "found fraction " << frac << endl; +#endif + break; + } + } + } + if (not used_frac) + res = min(n, n_opts); + } + else + res = n_opts; + +#ifdef DEBUG_BATCH_SIZE + cerr << DataPositions::dtype_names[type] << " " << T::type_string() + << " res=" << res << " n=" + << n << " n_opts=" << n_opts << " buffer_size=" << buffer_size << endl; +#endif + + assert(res > 0); + return res; +} + +template +int BaseMachine::edabit_batch_size(int n_bits, int buffer_size) +{ + int n_opts; + int n = 0; + int res; + + if (buffer_size > 0) + n_opts = buffer_size; + else + n_opts = OnlineOptions::singleton.batch_size; + + if (has_program()) + { + n = s().progs[0].get_offline_data_used().total_edabits(n_bits); + } + + if (n > 0 and not (buffer_size > 0)) + res = min(n, n_opts); + else + res = n_opts; + +#ifdef DEBUG_BATCH_SIZE + cerr << "edaBits " << T::type_string() << " (" << n_bits + << ") res=" << res << " n=" + << n << " n_opts=" << n_opts << " buffer_size=" << buffer_size << endl; +#endif + + assert(res > 0); + return res; +} + #endif /* PROCESSOR_BASEMACHINE_H_ */ diff --git a/Processor/DataPositions.cpp b/Processor/DataPositions.cpp index 2294a5991..b93990619 100644 --- a/Processor/DataPositions.cpp +++ b/Processor/DataPositions.cpp @@ -229,3 +229,9 @@ bool DataPositions::any_more(const DataPositions& other) const return false; } + +long long DataPositions::total_edabits(int n_bits) const +{ + auto usage = edabits; + return usage[{false, n_bits}] + usage[{true, n_bits}]; +} diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index a4a3e515a..ae948294e 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -85,6 +85,8 @@ class DataPositions void print_cost() const; bool empty() const; bool any_more(const DataPositions& other) const; + + long long total_edabits(int n_bits) const; }; template class Processor; @@ -229,6 +231,10 @@ class Sub_Data_Files : public Preprocessing static long additional_inputs(const DataPositions& usage); + static string get_prep_dir(const Names& N); + static void check_setup(const Names& N); + static void check_setup(int num_players, const string& prep_dir); + Sub_Data_Files(int my_num, int num_players, const string& prep_data_dir, DataPositions& usage, int thread_num = -1); Sub_Data_Files(const Names& N, DataPositions& usage, int thread_num = -1); @@ -299,7 +305,7 @@ class Data_Files Data_Files(Machine& machine, SubProcessor* procp = 0, SubProcessor* proc2 = 0); - Data_Files(const Names& N); + Data_Files(const Names& N, int thread_num = -1); ~Data_Files(); DataPositions tellg() { return usage; } diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index 2552dc113..df6dfba38 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -60,8 +60,7 @@ T Preprocessing::get_random_from_inputs(int nplayers) template Sub_Data_Files::Sub_Data_Files(const Names& N, DataPositions& usage, int thread_num) : - Sub_Data_Files(N, - OnlineOptions::singleton.prep_dir_prefix(N.num_players()), usage, + Sub_Data_Files(N, get_prep_dir(N), usage, thread_num) { } @@ -98,6 +97,32 @@ string Sub_Data_Files::get_edabit_filename(const Names& N, int n_bits, get_prep_sub_dir(N.num_players()), n_bits, N.my_num(), thread_num); } +template +string Sub_Data_Files::get_prep_dir(const Names& N) +{ + return OnlineOptions::singleton.prep_dir_prefix(N.num_players()); +} + +template +void Sub_Data_Files::check_setup(const Names& N) +{ + return check_setup(N.num_players(), get_prep_dir(N)); +} + +template +void Sub_Data_Files::check_setup(int num_players, const string& prep_dir) +{ + try + { + T::clear::check_setup(prep_dir); + } + catch (exception& e) + { + throw prep_setup_error(e.what(), num_players, + T::template proto_fake_opts()); + } +} + template Sub_Data_Files::Sub_Data_Files(int my_num, int num_players, const string& prep_data_dir, DataPositions& usage, int thread_num) : @@ -109,19 +134,7 @@ Sub_Data_Files::Sub_Data_Files(int my_num, int num_players, cerr << "Setting up Data_Files in: " << prep_data_dir << endl; #endif - try - { - T::clear::check_setup(prep_data_dir); - } - catch (...) - { - cerr << "Something is wrong with the preprocessing data on disk." << endl; - cerr - << "Have you run the right program for generating it, such as './Fake-Offline.x " - << num_players - << T::clear::fake_opts() << "'?" << endl; - throw; - } + check_setup(num_players, prep_data_dir); string type_short = T::type_short(); string type_string = T::type_string(); @@ -173,11 +186,11 @@ Data_Files::Data_Files(Machine& machine, SubProcessor< } template -Data_Files::Data_Files(const Names& N) : +Data_Files::Data_Files(const Names& N, int thread_num) : usage(N.num_players()), - DataFp(*new Sub_Data_Files(N, usage)), - DataF2(*new Sub_Data_Files(N, usage)), - DataFb(*new Sub_Data_Files(N, usage)) + DataFp(*new Sub_Data_Files(N, usage, thread_num)), + DataF2(*new Sub_Data_Files(N, usage, thread_num)), + DataFb(*new Sub_Data_Files(N, usage, thread_num)) { } diff --git a/Processor/DummyProtocol.h b/Processor/DummyProtocol.h index b3ed5bc54..028b65ac9 100644 --- a/Processor/DummyProtocol.h +++ b/Processor/DummyProtocol.h @@ -112,6 +112,8 @@ template class DummyLivePrep : public Preprocessing { public: + static const bool homomorphic = true; + static void basic_setup(Player&) { } @@ -125,6 +127,11 @@ class DummyLivePrep : public Preprocessing "live preprocessing not implemented for " + T::type_string()); } + static bool bits_from_dabits() + { + return false; + } + DummyLivePrep(DataPositions& usage, GC::ShareThread&) : Preprocessing(usage) { diff --git a/Processor/EdabitBuffer.h b/Processor/EdabitBuffer.h index af87a0559..d6506b75f 100644 --- a/Processor/EdabitBuffer.h +++ b/Processor/EdabitBuffer.h @@ -29,7 +29,12 @@ class EdabitBuffer : public BufferOwner if (not BufferBase::file) { if (this->open()->fail()) - throw runtime_error("error opening " + this->filename); + throw runtime_error( + "error opening " + this->filename + + ", have you generated edaBits, " + "for example by running " + "'./Fake-Offline.x -e " + + to_string(n_bits) + " ...'?"); } assert(BufferBase::file); diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp index 68acda3b9..7cb309fc2 100644 --- a/Processor/Instruction.cpp +++ b/Processor/Instruction.cpp @@ -6,6 +6,7 @@ #include "Instruction.h" #include "instructions.h" #include "Processor.h" +#include "Memory.h" #include "Math/gf2n.h" #include "GC/instructions.h" @@ -54,7 +55,7 @@ void Instruction::gbitcom(vector& registers) const } } -void Instruction::execute_regint(ArithmeticProcessor& Proc, vector& Mi) const +void Instruction::execute_regint(ArithmeticProcessor& Proc, MemoryPart& Mi) const { (void) Mi; auto& Ci = Proc.get_Ci(); diff --git a/Processor/Instruction.h b/Processor/Instruction.h index ea900b3e3..36ffbed57 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -14,6 +14,7 @@ using namespace std; template class Machine; template class Processor; template class SubProcessor; +template class MemoryPart; class ArithmeticProcessor; class SwitchableOutput; @@ -86,6 +87,8 @@ enum SUBCFI = 0x2B, SUBSFI = 0x2C, PREFIXSUMS = 0x2D, + PICKS = 0x2E, + CONCATS = 0x2F, // Multiplication/division/other arithmetic MULC = 0x30, MULM = 0x31, @@ -392,7 +395,7 @@ class Instruction : public BaseInstruction template void gbitcom(vector& registers) const; - void execute_regint(ArithmeticProcessor& Proc, vector& Mi) const; + void execute_regint(ArithmeticProcessor& Proc, MemoryPart& Mi) const; void shuffle(ArithmeticProcessor& Proc) const; void bitdecint(ArithmeticProcessor& Proc) const; diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 276b2b3d4..bb555281b 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -208,6 +208,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) get_ints(r, s, 2); n = get_int(s); break; + case PICKS: + get_ints(r, s, 3); + n = get_int(s); + break; case USE: case USE_INP: case USE_EDABIT: @@ -392,6 +396,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case EDABIT: case SEDABIT: case WRITEFILESHARE: + case CONCATS: num_var_args = get_int(s) - 1; r[0] = get_int(s); get_vector(num_var_args, start, s); @@ -930,6 +935,20 @@ inline void Instruction::execute(Processor& Proc) const case MOVC: Proc.write_Cp(r[0],Proc.read_Cp(r[1])); break; + case CONCATS: + { + auto& S = Proc.Procp.get_S(); + auto dest = S.begin() + r[0]; + for (auto j = start.begin(); j < start.end(); j += 2) + { + auto source = S.begin() + *(j + 1); + assert(dest + *j <= S.end()); + assert(source + *j <= S.end()); + for (int k = 0; k < *j; k++) + *dest++ = *source++; + } + return; + } case DIVC: Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / sanitize(Proc.Procp, r[2])); break; diff --git a/Processor/Machine.h b/Processor/Machine.h index fb7f5d939..803e3d919 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -53,8 +53,6 @@ class Machine : public BaseMachine public: - vector progs; - Memory M2; Memory Mp; Memory Mi; @@ -63,10 +61,6 @@ class Machine : public BaseMachine vector join_timer; Timer finish_timer; - bool direct; - int opening_sum; - bool receive_threads; - int max_broadcast; bool use_encryption; bool live_prep; diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 965e0fdd3..d9c245819 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -55,8 +55,6 @@ template Machine::Machine(Names& playerNames, bool use_encryption, const OnlineOptions opts, int lg2) : my_number(playerNames.my_num()), N(playerNames), - direct(opts.direct), opening_sum(opts.opening_sum), - receive_threads(opts.receive_threads), max_broadcast(opts.max_broadcast), use_encryption(use_encryption), live_prep(opts.live_prep), opts(opts), external_clients(my_number) { @@ -69,11 +67,6 @@ Machine::Machine(Names& playerNames, bool use_encryption, exit(1); } - if (opening_sum < 2) - this->opening_sum = N.num_players(); - if (max_broadcast < 2) - this->max_broadcast = N.num_players(); - // Set the prime modulus from command line or program if applicable if (opts.prime) sint::clear::init_field(opts.prime); @@ -102,7 +95,17 @@ Machine::Machine(Names& playerNames, bool use_encryption, sint::bit_type::MAC_Check::setup(*P); sgf2n::MAC_Check::setup(*P); - alphapi = read_generate_write_mac_key(*P); + if (opts.live_prep) + alphapi = read_generate_write_mac_key(*P); + else + { + // check for directory + Sub_Data_Files::check_setup(N); + // require existing MAC key + if (sint::has_mac) + read_mac_key(N, alphapi); + } + alpha2i = read_generate_write_mac_key(*P); alphabi = read_generate_write_mac_key(*P); @@ -451,6 +454,7 @@ void Machine::run(const string& progname) finish_timer.start(); // actual usage + bool multithread = nthreads > 1; auto res = stop_threads(); DataPositions& pos = res.first; @@ -479,7 +483,10 @@ void Machine::run(const string& progname) cerr << "Communication details " "(rounds in parallel threads counted double):" << endl; comm_stats.print(); - cerr << "CPU time = " << proc_timer.elapsed() << endl; + cerr << "CPU time = " << proc_timer.elapsed(); + if (multithread) + cerr << " (overall core time)"; + cerr << endl; } print_timers(); diff --git a/Processor/Memory.h b/Processor/Memory.h index 1fbeda7ec..16e885485 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -19,6 +19,28 @@ template class MemoryPart : public CheckVector { public: + template + static void check_index(const vector& M, size_t i) + { + (void) M, (void) i; +#ifndef NO_CHECK_INDEX + if (i >= M.size()) + throw overflow(U::type_string() + " memory", i, M.size()); +#endif + } + + T& operator[](size_t i) + { + check_index(*this, i); + return CheckVector::operator[](i); + } + + const T& operator[](size_t i) const + { + check_index(*this, i); + return CheckVector::operator[](i); + } + void minimum_size(size_t size); }; @@ -40,35 +62,21 @@ class Memory size_t size_c() { return MC.size(); } - template - static void check_index(const vector& M, size_t i) - { - (void) M, (void) i; -#ifndef NO_CHECK_INDEX - if (i >= M.size()) - throw overflow(U::type_string() + " memory", i, M.size()); -#endif - } - const typename T::clear& read_C(size_t i) const { - check_index(MC, i); return MC[i]; } const T& read_S(size_t i) const { - check_index(MS, i); return MS[i]; } void write_C(size_t i,const typename T::clear& x) { - check_index(MC, i); MC[i]=x; } void write_S(size_t i,const T& x) { - check_index(MS, i); MS[i]=x; } diff --git a/Processor/OfflineMachine.h b/Processor/OfflineMachine.h index d1b142569..5f3f0ea6f 100644 --- a/Processor/OfflineMachine.h +++ b/Processor/OfflineMachine.h @@ -12,13 +12,11 @@ #include "Networking/CryptoPlayer.h" template -class OfflineMachine : public W +class OfflineMachine : public W, BaseMachine { DataPositions usage; - BaseMachine machine; Names& playerNames; Player& P; - int n_threads; template void generate(); @@ -34,6 +32,8 @@ class OfflineMachine : public W template int run(); + + const Names& get_N(); }; #endif /* PROCESSOR_OFFLINEMACHINE_H_ */ diff --git a/Processor/OfflineMachine.hpp b/Processor/OfflineMachine.hpp index e18c47b5f..f869d638f 100644 --- a/Processor/OfflineMachine.hpp +++ b/Processor/OfflineMachine.hpp @@ -18,18 +18,19 @@ OfflineMachine::OfflineMachine(int argc, const char** argv, W(argc, argv, opt, online_opts, V(), nplayers), playerNames( W::playerNames), P(*this->new_player("machine")) { - machine.load_schedule(online_opts.progname, false); + load_schedule(online_opts.progname, false); Program program(playerNames.num_players()); - program.parse(machine.bc_filenames[0]); + program.parse(bc_filenames[0]); + progs.push_back(program); if (program.usage_unknown()) { - cerr << "Preprocessing might be insufficient " + cerr << "Preprocessing will be insufficient " << "due to unknown requirements" << endl; + exit(1); } usage = program.get_offline_data_used(); - n_threads = machine.nthreads; } template @@ -73,7 +74,7 @@ int OfflineMachine::run() template int OfflineMachine::buffered_total(size_t required, size_t batch) { - return DIV_CEIL(required, batch) * batch + (n_threads - 1) * batch; + return DIV_CEIL(required, batch) * batch + (nthreads - 1) * batch; } template @@ -183,4 +184,10 @@ void OfflineMachine::generate() output.Check(P); } +template +const Names& OfflineMachine::get_N() +{ + return playerNames; +} + #endif /* PROCESSOR_OFFLINEMACHINE_HPP_ */ diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index 3409c0470..5bce952fe 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -43,6 +43,7 @@ void thread_info::Sub_Main_Func() BaseMachine::s().thread_num = num; auto& queues = machine.queues[num]; + auto& opts = machine.opts; queues->next(); ThreadQueue::thread_queue = queues; @@ -58,7 +59,7 @@ void thread_info::Sub_Main_Func() #endif player = new CryptoPlayer(*(tinfo->Nms), id); } - else if (!machine.receive_threads or machine.direct) + else if (!opts.receive_threads or opts.direct) { #ifdef VERBOSE_OPTIONS cerr << "Using single-threaded receiving" << endl; @@ -80,7 +81,7 @@ void thread_info::Sub_Main_Func() typename sgf2n::MAC_Check* MC2; typename sint::MAC_Check* MCp; - if (machine.direct) + if (opts.direct) { #ifdef VERBOSE_OPTIONS cerr << "Using direct communication." << endl; @@ -93,8 +94,8 @@ void thread_info::Sub_Main_Func() #ifdef VERBOSE_OPTIONS cerr << "Using indirect communication." << endl; #endif - MC2 = new typename sgf2n::MAC_Check(*(tinfo->alpha2i), machine.opening_sum, machine.max_broadcast); - MCp = new typename sint::MAC_Check(*(tinfo->alphapi), machine.opening_sum, machine.max_broadcast); + MC2 = new typename sgf2n::MAC_Check(*(tinfo->alpha2i), opts.opening_sum, opts.max_broadcast); + MCp = new typename sint::MAC_Check(*(tinfo->alphapi), opts.opening_sum, opts.max_broadcast); } // Allocate memory for first program before starting the clock @@ -376,6 +377,10 @@ void* thread_info::Main_Func(void* ptr) { ti.Sub_Main_Func(); } + catch (setup_error&) + { + throw; + } catch (...) { thread_info* ti = (thread_info*)ptr; @@ -393,16 +398,20 @@ void thread_info::purge_preprocessing(const Names& N, int thread_nu cerr << "Purging preprocessed data because something is wrong" << endl; try { - Data_Files df(N); + Data_Files df(N, thread_num); df.purge(); DataPositions pos; Sub_Data_Files bit_df(N, pos, thread_num); bit_df.get_part(); bit_df.purge(); } - catch(...) + catch(setup_error&) + { + } + catch(exception& e) { cerr << "Purging failed. This might be because preprocessed data is incomplete." << endl << "SECURITY FAILURE; YOU ARE ON YOUR OWN NOW!" << endl; + cerr << "Reason: " << e.what() << endl; } } diff --git a/Processor/PrepBase.cpp b/Processor/PrepBase.cpp index 775caf01a..5403b6f7b 100644 --- a/Processor/PrepBase.cpp +++ b/Processor/PrepBase.cpp @@ -41,25 +41,28 @@ string PrepBase::get_edabit_filename(const string& prep_data_dir, } void PrepBase::print_left(const char* name, size_t n, const string& type_string, - size_t used) + size_t used, bool large) { if (n > 0 and OnlineOptions::singleton.verbose) cerr << "\t" << n << " " << name << " of " << type_string << " left" << endl; - if (n > used / 10) + if (n > used / 10 and n >= 64) { cerr << "Significant amount of unused " << name << " of " << type_string - << ". For more accurate benchmarks, " - << "consider reducing the batch size with --batch-size." << endl; - cerr - << "Note that some protocols have larger minimum batch sizes." - << endl; + << " distorting the benchmark. "; + if (large) + cerr << "This protocol has a large minimum batch size, " + << "which makes this unavoidable for small programs."; + else + cerr << "For more accurate benchmarks, " + << "consider reducing the batch size with --batch-size."; + cerr << endl; } } void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict, - int n_bits, size_t used) + int n_bits, size_t used, bool malicious) { if (n > 0 and OnlineOptions::singleton.verbose) { @@ -70,8 +73,15 @@ void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict, } if (n * n_batch > used / 10) + { cerr << "Significant amount of unused edaBits of size " << n_bits - << ". For more accurate benchmarks, " - << "consider reducing the batch size with --batch-size " - << "or increasing the bucket size with --bucket-size." << endl; + << ". "; + if (malicious) + cerr << "This protocol has a large minimum batch size, " + << "which makes this unavoidable for small programs."; + else + cerr << "For more accurate benchmarks, " + << "consider reducing the batch size with --batch-size."; + cerr << endl; + } } diff --git a/Processor/PrepBase.h b/Processor/PrepBase.h index e598d31ce..8e68b1ad7 100644 --- a/Processor/PrepBase.h +++ b/Processor/PrepBase.h @@ -26,9 +26,9 @@ class PrepBase int my_num, int thread_num = 0); static void print_left(const char* name, size_t n, - const string& type_string, size_t used); + const string& type_string, size_t used, bool large = false); static void print_left_edabits(size_t n, size_t n_batch, bool strict, - int n_bits, size_t used); + int n_bits, size_t used, bool malicious); TimerWithComm prep_timer; }; diff --git a/Processor/instructions.h b/Processor/instructions.h index cc833284a..756bbf7ca 100644 --- a/Processor/instructions.h +++ b/Processor/instructions.h @@ -65,6 +65,8 @@ X(PREFIXSUMS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \ sint s, \ s += *op1++; *dest++ = s) \ + X(PICKS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1] + r[2]], \ + *dest++ = *op1; op1 += int(n)) \ X(MULM, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \ auto op2 = &Procp.get_C()[r[2]], \ *dest++ = *op1++ * *op2++) \ diff --git a/Programs/Source/mnist_full_B.mpc b/Programs/Source/mnist_full_B.mpc index a7cdb3688..41d83f313 100644 --- a/Programs/Source/mnist_full_B.mpc +++ b/Programs/Source/mnist_full_B.mpc @@ -38,7 +38,7 @@ except: try: batch_size = int(program.args[2]) except: - batch_size = N + batch_size = min(N, 128) if 'savemem' in program.args: N = batch_size diff --git a/Programs/Source/mnist_full_C.mpc b/Programs/Source/mnist_full_C.mpc index 76b5e9acf..2933d7fe0 100644 --- a/Programs/Source/mnist_full_C.mpc +++ b/Programs/Source/mnist_full_C.mpc @@ -38,7 +38,7 @@ except: try: batch_size = int(program.args[2]) except: - batch_size = N + batch_size = min(N, 128) if 'savemem' in program.args: N = batch_size diff --git a/Programs/Source/mnist_full_D.mpc b/Programs/Source/mnist_full_D.mpc index 68d12f977..13ad13980 100644 --- a/Programs/Source/mnist_full_D.mpc +++ b/Programs/Source/mnist_full_D.mpc @@ -38,7 +38,7 @@ except: try: batch_size = int(program.args[2]) except: - batch_size = N + batch_size = min(N, 128) assert batch_size <= N ml.Layer.back_batch_size = batch_size diff --git a/Programs/Source/tf.mpc b/Programs/Source/tf.mpc index 0c51cf581..bd6d8c8ef 100644 --- a/Programs/Source/tf.mpc +++ b/Programs/Source/tf.mpc @@ -26,12 +26,14 @@ exec(subprocess.check_output(['Scripts/process-tf.py', program.args[1]])) opt = ml.Optimizer() opt.set_layers_with_inputs(layers) -layers[0].X.input_from(0) +layers[0].X.input_from(0, binary=True) for layer in layers: - layer.input_from(0, raw='raw' in program.args) + layer.input_from(0, binary=True) sint(0).reveal().store_in_mem(0) +opt.time_layers = 'time_layers' in program.args + start_timer(1) opt.forward(1, keep_intermediate=False) stop_timer(1) diff --git a/Protocols/BrainPrep.hpp b/Protocols/BrainPrep.hpp index d3333d3a7..65b9ad0be 100644 --- a/Protocols/BrainPrep.hpp +++ b/Protocols/BrainPrep.hpp @@ -115,7 +115,7 @@ void BrainPrep::buffer_triples() + to_string(ZProtocol::share_type::clear::N_BITS) + "-bit integer computation"); typedef Rep3Share pShare; - auto buffer_size = OnlineOptions::singleton.batch_size; + auto buffer_size = BaseMachine::batch_size(DATA_TRIPLE); Player& P = this->protocol->P; vector, 3>> triples; vector, 3>> check_triples; diff --git a/Protocols/BufferScope.h b/Protocols/BufferScope.h new file mode 100644 index 000000000..6cb2b82c9 --- /dev/null +++ b/Protocols/BufferScope.h @@ -0,0 +1,32 @@ +/* + * BufferScope.h + * + */ + +#ifndef PROTOCOLS_BUFFERSCOPE_H_ +#define PROTOCOLS_BUFFERSCOPE_H_ + +template class BufferPrep; +template class Preprocessing; + +template +class BufferScope +{ + BufferPrep& prep; + int bak; + +public: + BufferScope(Preprocessing & prep, int buffer_size) : + prep(dynamic_cast&>(prep)) + { + bak = this->prep.buffer_size; + this->prep.buffer_size = buffer_size; + } + + ~BufferScope() + { + prep.buffer_size = bak; + } +}; + +#endif /* PROTOCOLS_BUFFERSCOPE_H_ */ diff --git a/Protocols/ChaiGearPrep.h b/Protocols/ChaiGearPrep.h index b57e32ca4..8893a8a8a 100644 --- a/Protocols/ChaiGearPrep.h +++ b/Protocols/ChaiGearPrep.h @@ -33,6 +33,8 @@ class ChaiGearPrep : public MaliciousRingPrep void buffer_bits(false_type); public: + static const bool homomorphic = true; + static void basic_setup(Player& P); static void key_setup(Player& P, mac_key_type alphai); static void teardown(); diff --git a/Protocols/CowGearPrep.h b/Protocols/CowGearPrep.h index e15d3feab..8657e7654 100644 --- a/Protocols/CowGearPrep.h +++ b/Protocols/CowGearPrep.h @@ -33,6 +33,8 @@ class CowGearPrep : public MaliciousRingPrep void buffer_bits(false_type); public: + static const bool homomorphic = true; + static void basic_setup(Player& P); static void key_setup(Player& P, mac_key_type alphai); static void setup(Player& P, mac_key_type alphai); diff --git a/Protocols/CowGearPrep.hpp b/Protocols/CowGearPrep.hpp index 61c240aaa..5c54ae035 100644 --- a/Protocols/CowGearPrep.hpp +++ b/Protocols/CowGearPrep.hpp @@ -112,9 +112,6 @@ PairwiseGenerator& CowGearPrep::get_generator() { auto& machine = *pairwise_machine; typedef typename T::open_type::FD FD; - // generate minimal number of items - this->buffer_size = min(machine.setup().alpha.num_slots(), - (unsigned)OnlineOptions::singleton.batch_size); pairwise_generator = new PairwiseGenerator(0, machine, &proc->P); } return *pairwise_generator; diff --git a/Protocols/DabitSacrifice.h b/Protocols/DabitSacrifice.h index 6da8cc238..134e829c8 100644 --- a/Protocols/DabitSacrifice.h +++ b/Protocols/DabitSacrifice.h @@ -6,21 +6,28 @@ #ifndef PROTOCOLS_DABITSACRIFICE_H_ #define PROTOCOLS_DABITSACRIFICE_H_ +#include "Processor/BaseMachine.h" + template class DabitSacrifice { const int S; + size_t n_masks, n_produced; + public: DabitSacrifice(); + ~DabitSacrifice(); int minimum_n_inputs(int n_outputs = 0) { - if (n_outputs < 1) - n_outputs = OnlineOptions::singleton.batch_size; if (T::clear::N_BITS < 0) // sacrifice uses S^2 random bits - n_outputs = max(n_outputs, 10 * S * S); + n_outputs = BaseMachine::batch_size(DATA_DABIT, + n_outputs, max(n_outputs, 10 * S * S)); + else + n_outputs = BaseMachine::batch_size(DATA_DABIT, n_outputs); + assert(n_outputs > 0); return n_outputs + S; } diff --git a/Protocols/DabitSacrifice.hpp b/Protocols/DabitSacrifice.hpp index d6f485cc9..26d24def7 100644 --- a/Protocols/DabitSacrifice.hpp +++ b/Protocols/DabitSacrifice.hpp @@ -7,13 +7,15 @@ #define PROTOCOLS_DABITSACRIFICE_HPP_ #include "DabitSacrifice.h" +#include "BufferScope.h" #include "Tools/PointerVector.h" #include template DabitSacrifice::DabitSacrifice() : - S(OnlineOptions::singleton.security_parameter) + S(OnlineOptions::singleton.security_parameter), + n_masks(0), n_produced() { } @@ -36,10 +38,15 @@ void DabitSacrifice::sacrifice_without_bit_check(vector >& dabits, timer.start(); #endif int n = check_dabits.size() - S; + n_masks += S; + assert(n > 0); GlobalPRNG G(proc.P); typedef typename T::bit_type::part_type BT; vector shares; vector bit_shares; + if (T::clear::N_BITS <= 0) + dynamic_cast&>(proc.DataF).buffer_extra(DATA_BIT, + S * (ceil(log2(n)) + S)); for (int i = 0; i < S; i++) { dabit to_check; @@ -58,6 +65,7 @@ void DabitSacrifice::sacrifice_without_bit_check(vector >& dabits, T tmp; proc.DataF.get_one(DATA_BIT, tmp); masked += tmp << (1 + j); + n_masks++; } shares.push_back(masked); bit_shares.push_back(to_check.second); @@ -84,6 +92,7 @@ void DabitSacrifice::sacrifice_without_bit_check(vector >& dabits, } } dabits.insert(dabits.end(), check_dabits.begin(), check_dabits.begin() + n); + n_produced += n; MCBB.Check(proc.P); delete &MCBB; #ifdef VERBOSE_DABIT @@ -92,6 +101,17 @@ void DabitSacrifice::sacrifice_without_bit_check(vector >& dabits, #endif } +template +DabitSacrifice::~DabitSacrifice() +{ +#ifdef DABIT_WASTAGE + if (n_produced > 0) + { + cerr << "daBit wastage: " << float(n_masks) / n_produced << endl; + } +#endif +} + template void DabitSacrifice::sacrifice_and_check_bits(vector >& dabits, vector >& check_dabits, SubProcessor& proc, @@ -113,7 +133,10 @@ void DabitSacrifice::sacrifice_and_check_bits(vector >& dabits, queues->wrap_up(job); } else + { + BufferScope scope(proc.DataF, multiplicands.size()); protocol.multiply(products, multiplicands, 0, multiplicands.size(), proc); + } vector check_for_zero; for (auto& x : to_check) check_for_zero.push_back(x.first - products.next()); diff --git a/Protocols/DealerPrep.hpp b/Protocols/DealerPrep.hpp index ea334257c..e70edecb5 100644 --- a/Protocols/DealerPrep.hpp +++ b/Protocols/DealerPrep.hpp @@ -17,11 +17,13 @@ void DealerPrep::buffer_triples() vector senders(P.num_players()); senders.back() = true; octetStreams os(P), to_receive(P); + int buffer_size = BaseMachine::batch_size(DATA_TRIPLE, + this->buffer_size); if (this->proc->input.is_dealer()) { SeededPRNG G; vector> shares(P.num_players() - 1); - for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + for (int i = 0; i < buffer_size; i++) { T triples[3]; for (int i = 0; i < 2; i++) @@ -41,7 +43,7 @@ void DealerPrep::buffer_triples() else { P.send_receive_all(senders, os, to_receive); - for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + for (int i = 0; i < buffer_size; i++) this->triples.push_back(to_receive.back().get>().get()); } } @@ -68,11 +70,12 @@ void DealerPrep::buffer_inverses(true_type) vector senders(P.num_players()); senders.back() = true; octetStreams os(P), to_receive(P); + int buffer_size = BaseMachine::batch_size(DATA_INVERSE); if (this->proc->input.is_dealer()) { SeededPRNG G; vector> shares(P.num_players() - 1); - for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + for (int i = 0; i < buffer_size; i++) { T tuple[2]; while (tuple[0] == 0) @@ -92,7 +95,7 @@ void DealerPrep::buffer_inverses(true_type) else { P.send_receive_all(senders, os, to_receive); - for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + for (int i = 0; i < buffer_size; i++) this->inverses.push_back(to_receive.back().get>().get()); } } @@ -105,11 +108,12 @@ void DealerPrep::buffer_bits() vector senders(P.num_players()); senders.back() = true; octetStreams os(P), to_receive(P); + int buffer_size = BaseMachine::batch_size(DATA_BIT); if (this->proc->input.is_dealer()) { SeededPRNG G; vector> shares(P.num_players() - 1); - for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + for (int i = 0; i < buffer_size; i++) { T bit = G.get_bit(); make_share(shares.data(), typename T::clear(bit), @@ -123,7 +127,7 @@ void DealerPrep::buffer_bits() else { P.send_receive_all(senders, os, to_receive); - for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + for (int i = 0; i < buffer_size; i++) this->bits.push_back(to_receive.back().get()); } } @@ -136,12 +140,13 @@ void DealerPrep::buffer_dabits(ThreadQueues*) vector senders(P.num_players()); senders.back() = true; octetStreams os(P), to_receive(P); + int buffer_size = BaseMachine::batch_size(DATA_DABIT); if (this->proc->input.is_dealer()) { SeededPRNG G; vector> shares(P.num_players() - 1); vector bit_shares(P.num_players() - 1); - for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + for (int i = 0; i < buffer_size; i++) { auto bit = G.get_bit(); make_share(shares.data(), typename T::clear(bit), @@ -160,7 +165,7 @@ void DealerPrep::buffer_dabits(ThreadQueues*) else { P.send_receive_all(senders, os, to_receive); - for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + for (int i = 0; i < buffer_size; i++) { this->dabits.push_back({to_receive.back().get(), to_receive.back().get()}); @@ -200,7 +205,8 @@ void DealerPrep::buffer_edabits(int length, false_type) vector senders(P.num_players()); senders.back() = true; octetStreams os(P), to_receive(P); - int n_vecs = OnlineOptions::singleton.batch_size / edabitvec::MAX_SIZE; + int n_vecs = DIV_CEIL(BaseMachine::edabit_batch_size(length), + edabitvec::MAX_SIZE); auto& buffer = this->edabits[{false, length}]; if (this->proc->input.is_dealer()) { diff --git a/Protocols/HemiMatrixPrep.h b/Protocols/HemiMatrixPrep.h index db8682193..d6912bb83 100644 --- a/Protocols/HemiMatrixPrep.h +++ b/Protocols/HemiMatrixPrep.h @@ -28,6 +28,8 @@ class HemiMatrixPrep : public BufferPrep> HemiMatrixPrep(const HemiMatrixPrep&) = delete; public: + static const bool homomorphic = true; + HemiMatrixPrep(int n_rows, int n_inner, int n_cols, LivePrep& prep, DataPositions& usage) : super(usage), n_rows(n_rows), n_inner(n_inner), diff --git a/Protocols/HemiPrep.h b/Protocols/HemiPrep.h index 6db5bf432..29599c5e0 100644 --- a/Protocols/HemiPrep.h +++ b/Protocols/HemiPrep.h @@ -35,6 +35,8 @@ class HemiPrep : public SemiHonestRingPrep SemiPrep& get_two_party_prep(); public: + static const bool homomorphic = true; + static void basic_setup(Player& P); static void teardown(); diff --git a/Protocols/HemiPrep.hpp b/Protocols/HemiPrep.hpp index aa94506ee..0f212b438 100644 --- a/Protocols/HemiPrep.hpp +++ b/Protocols/HemiPrep.hpp @@ -142,6 +142,8 @@ void HemiPrep::buffer_bits() if (this->proc->P.num_players() == 2) { auto& prep = get_two_party_prep(); + prep.buffer_size = BaseMachine::batch_size(DATA_BIT, + this->buffer_size); prep.buffer_dabits(0); for (auto& x : prep.dabits) this->bits.push_back(x.first); @@ -158,6 +160,8 @@ void HemiPrep::buffer_dabits(ThreadQueues* queues) if (this->proc->P.num_players() == 2) { auto& prep = get_two_party_prep(); + prep.buffer_size = BaseMachine::batch_size(DATA_DABIT, + this->buffer_size); prep.buffer_dabits(queues); this->dabits = prep.dabits; prep.dabits.clear(); diff --git a/Protocols/LowGearKeyGen.hpp b/Protocols/LowGearKeyGen.hpp index 5056c3a84..f84c8ac32 100644 --- a/Protocols/LowGearKeyGen.hpp +++ b/Protocols/LowGearKeyGen.hpp @@ -24,7 +24,7 @@ KeyGenProtocol::KeyGenProtocol(Player& P, const FHE_Params& params, int level) : P(P), params(params), fftd(params.FFTD().at(level)), usage(P) { - open_type::init_field(params.FFTD().at(level).get_prD().pr); + open_type::init_field(params.FFTD().at(level).get_prD().pr, false); typename share_type::mac_key_type alphai; auto& batch_size = OnlineOptions::singleton.batch_size; @@ -54,7 +54,8 @@ KeyGenProtocol::~KeyGenProtocol() { MC->Check(P); - usage.print_cost(); + if (OnlineOptions::singleton.verbose) + usage.print_cost(); delete proc; delete prep; @@ -63,6 +64,7 @@ KeyGenProtocol::~KeyGenProtocol() MC->teardown(); OnlineOptions::singleton.batch_size = backup_batch_size; + open_type::reset(); } template diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 2cfe3d8ba..53dc7e557 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -12,6 +12,7 @@ using namespace std; #include "Protocols/MAC_Check_Base.h" #include "Tools/time-func.h" #include "Tools/Coordinator.h" +#include "Processor/OnlineOptions.h" /* The MAX number of things we will partially open before running @@ -38,6 +39,8 @@ class TreeSum void add_openings(vector& values, const Player& P, int sum_players, int last_sum_players, int send_player); + virtual void post_add_process(vector&) {} + protected: int base_player; int opening_sum; @@ -54,7 +57,9 @@ class TreeSum vector timers; vector player_timers; - TreeSum(int opening_sum = 10, int max_broadcast = 10, int base_player = 0); + TreeSum(int opening_sum = OnlineOptions::singleton.opening_sum, + int max_broadcast = OnlineOptions::singleton.max_broadcast, + int base_player = 0); virtual ~TreeSum(); void run(vector& values, const Player& P); @@ -114,7 +119,7 @@ Coordinator* Tree_MAC_Check::coordinator = 0; * SPDZ opening protocol with MAC check (indirect communication) */ template -class MAC_Check_ : public Tree_MAC_Check +class MAC_Check_ : public virtual Tree_MAC_Check { public: MAC_Check_(const typename U::mac_key_type::Scalar& ai, int opening_sum = 10, @@ -135,7 +140,7 @@ template class MascotPrep; * SPDZ2k opening protocol with MAC check */ template -class MAC_Check_Z2k : public Tree_MAC_Check +class MAC_Check_Z2k : public virtual Tree_MAC_Check { protected: Preprocessing* prep; @@ -161,12 +166,11 @@ template using MAC_Check_Z2k_ = MAC_Check_Z2k; - /** * SPDZ opening protocol with MAC check (pairwise communication) */ template -class Direct_MAC_Check: public MAC_Check_ +class Direct_MAC_Check: public virtual MAC_Check_ { typedef MAC_Check_ super; @@ -186,7 +190,35 @@ class Direct_MAC_Check: public MAC_Check_ void init_open(const Player& P, int n = 0); void prepare_open(const T& secret, int = -1); - void exchange(const Player& P); + virtual void exchange(const Player& P); +}; + +template +class Direct_MAC_Check_Z2k: virtual public MAC_Check_Z2k_, + virtual public Direct_MAC_Check +{ +public: + Direct_MAC_Check_Z2k(const typename T::mac_key_type& ai) : + Tree_MAC_Check(ai), MAC_Check_Z2k_(ai), MAC_Check_(ai), + Direct_MAC_Check(ai) + { + } + + void prepare_open(const T& secret, int = -1) + { + MAC_Check_Z2k_::prepare_open(secret); + } + + void exchange(const Player& P) + { + Direct_MAC_Check::exchange(P); + assert(this->WaitingForCheck() > 0); + } + + void Check(const Player& P) + { + MAC_Check_Z2k_::Check(P); + } }; @@ -272,6 +304,7 @@ void TreeSum::add_openings(vector& values, const Player& P, { values[i].add(oss[j], use_lengths ? lengths[i] : -1); } + post_add_process(values); MC.timers[SUM].stop(); } } @@ -279,6 +312,11 @@ void TreeSum::add_openings(vector& values, const Player& P, template void TreeSum::start(vector& values, const Player& P) { + if (opening_sum < 2) + opening_sum = P.num_players(); + if (max_broadcast < 2) + max_broadcast = P.num_players(); + os.reset_write_head(); int sum_players = P.num_players(); int my_relative_num = positive_modulo(P.my_num() - base_player, P.num_players()); diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index f986d9a6e..cd42f242a 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -98,6 +98,7 @@ void Tree_MAC_Check::init_open(const Player&, int n) template void Tree_MAC_Check::prepare_open(const U& secret, int) { + assert(U::mac_type::invertible); this->values.push_back(secret.get_share()); macs.push_back(secret.get_mac()); } @@ -344,7 +345,7 @@ Direct_MAC_Check::Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai template Direct_MAC_Check::Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai) : - MAC_Check_(ai) + Tree_MAC_Check(ai), MAC_Check_(ai) { open_counter = 0; } @@ -405,6 +406,7 @@ void Direct_MAC_Check::init_open(const Player& P, int n) template void Direct_MAC_Check::prepare_open(const T& secret, int) { + assert(T::mac_type::invertible); this->values.push_back(secret.get_share()); this->macs.push_back(secret.get_mac()); } diff --git a/Protocols/MalRepRingPrep.h b/Protocols/MalRepRingPrep.h index ea857a5a4..58e2313e1 100644 --- a/Protocols/MalRepRingPrep.h +++ b/Protocols/MalRepRingPrep.h @@ -77,6 +77,11 @@ class MalRepRingPrepWithBits: public virtual MaliciousRingPrep, public virtual SimplerMalRepRingPrep { public: + static bool dabits_from_bits() + { + return true; + } + MalRepRingPrepWithBits(SubProcessor* proc, DataPositions& usage); void set_protocol(typename T::Protocol& protocol) diff --git a/Protocols/MalRepRingPrep.hpp b/Protocols/MalRepRingPrep.hpp index 1be5bb9a5..fa66bcf64 100644 --- a/Protocols/MalRepRingPrep.hpp +++ b/Protocols/MalRepRingPrep.hpp @@ -64,7 +64,8 @@ void MalRepRingPrep::buffer_squares() MaliciousRepPrep prep(_); assert(this->proc != 0); prep.init_honest(this->proc->P); - prep.buffer_size = this->buffer_size; + prep.buffer_size = BaseMachine::batch_size(DATA_SQUARE, + this->buffer_size); prep.buffer_squares(); for (auto& x : prep.squares) this->squares.push_back({{x[0], x[1]}}); diff --git a/Protocols/MalRepRingShare.h b/Protocols/MalRepRingShare.h index ff33a6eea..4c263fd72 100644 --- a/Protocols/MalRepRingShare.h +++ b/Protocols/MalRepRingShare.h @@ -35,11 +35,6 @@ class MalRepRingShare : public MaliciousRep3Share> typedef MalRepRingShare SquareToBitShare; typedef MalRepRingPrep SquarePrep; - static string type_short() - { - return "RR"; - } - MalRepRingShare() { } diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index 7c94b5d81..2a153e5f2 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -53,6 +53,10 @@ class MaliciousRep3Share : public Rep3Share { return "M" + string(1, T::type_char()); } + static string type_string() + { + return "malicious " + super::type_string(); + } MaliciousRep3Share() { diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index 9203d849b..0a98e8e27 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -76,7 +76,8 @@ void MaliciousRepPrep::buffer_triples() { check_field_size(); auto& triples = this->triples; - auto buffer_size = this->buffer_size; + auto buffer_size = BaseMachine::batch_size(DATA_TRIPLE, + this->buffer_size); auto& honest_proc = this->honest_proc; assert(honest_proc != 0); Player& P = honest_proc->P; @@ -140,7 +141,7 @@ void MaliciousRepPrep::buffer_squares() vector opened; vector> check_squares; auto& squares = this->squares; - auto buffer_size = this->buffer_size; + auto buffer_size = BaseMachine::batch_size(DATA_SQUARE, this->buffer_size); auto& honest_prep = this->honest_prep; auto& honest_proc = this->honest_proc; auto& MC = this->MC; @@ -186,7 +187,8 @@ void MaliciousBitOnlyRepPrep::buffer_bits() vector opened; vector> check_squares; auto& bits = this->bits; - auto buffer_size = this->buffer_size; + auto buffer_size = BaseMachine::batch_size(DATA_BIT, + this->buffer_size); assert(honest_proc); Player& P = honest_proc->P; honest_prep.buffer_size = buffer_size; diff --git a/Protocols/MaliciousRingPrep.hpp b/Protocols/MaliciousRingPrep.hpp index 0594446bb..63a11f76d 100644 --- a/Protocols/MaliciousRingPrep.hpp +++ b/Protocols/MaliciousRingPrep.hpp @@ -32,9 +32,8 @@ void MaliciousDabitOnlyPrep::buffer_dabits(ThreadQueues* queues, false_type, { assert(this->proc != 0); vector> check_dabits; - DabitSacrifice dabit_sacrifice; this->buffer_dabits_without_check(check_dabits, - dabit_sacrifice.minimum_n_inputs(), queues); + dabit_sacrifice.minimum_n_inputs(this->buffer_size), queues); dabit_sacrifice.sacrifice_and_check_bits(this->dabits, check_dabits, *this->proc, queues); } diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index 5cfa82b84..754efec0c 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -90,6 +90,8 @@ class MascotPrep : public virtual MaliciousRingPrep, public virtual MascotDabitOnlyPrep { public: + static bool bits_from_triples() { return true; } + MascotPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index f5c09941d..5d3464c5e 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -62,8 +62,11 @@ void MascotTriplePrep::buffer_triples() auto& params = this->params; auto& triple_generator = this->triple_generator; params.generateBits = false; + triple_generator->set_batch_size( + BaseMachine::batch_size(DATA_TRIPLE, this->buffer_size)); triple_generator->generate(); triple_generator->unlock(); + triple_generator->set_batch_size(OnlineOptions::singleton.batch_size); assert(triple_generator->uncheckedTriples.size() != 0); for (auto& triple : triple_generator->uncheckedTriples) this->triples.push_back( diff --git a/Protocols/Rep3Shuffler.hpp b/Protocols/Rep3Shuffler.hpp index a2edfb76f..19de8c4f9 100644 --- a/Protocols/Rep3Shuffler.hpp +++ b/Protocols/Rep3Shuffler.hpp @@ -118,8 +118,7 @@ void Rep3Shuffler::apply(vector& a, size_t n, int unit_size, template void Rep3Shuffler::del(int handle) { - for (int i = 0; i < 2; i++) - shuffles.at(handle)[i].clear(); + shuffles.at(handle) = {}; } template diff --git a/Protocols/Rep4Prep.h b/Protocols/Rep4Prep.h index 148e4285a..81d4f7281 100644 --- a/Protocols/Rep4Prep.h +++ b/Protocols/Rep4Prep.h @@ -42,6 +42,11 @@ class Rep4RingOnlyPrep : public virtual Rep4RingPrep, } public: + static bool dabits_from_bits() + { + return true; + } + static void edabit_sacrifice_buckets(vector>&, size_t, bool, int, SubProcessor&, int, int, const void* = 0) { diff --git a/Protocols/Rep4Prep.hpp b/Protocols/Rep4Prep.hpp index e871e82c9..92ccc7efa 100644 --- a/Protocols/Rep4Prep.hpp +++ b/Protocols/Rep4Prep.hpp @@ -46,7 +46,7 @@ void Rep4RingPrep::buffer_inputs(int player) template void Rep4RingPrep::buffer_triples() { - generate_triples(this->triples, OnlineOptions::singleton.batch_size, + generate_triples(this->triples, BaseMachine::batch_size(DATA_TRIPLE), this->protocol); } @@ -78,7 +78,7 @@ void Rep4RingPrep::buffer_bits() auto& protocol = this->proc->protocol; vector bits; - int batch_size = OnlineOptions::singleton.batch_size; + int batch_size = BaseMachine::batch_size(DATA_BIT); bits.reserve(batch_size); for (int i = 0; i < batch_size; i++) bits.push_back(G.get_bit()); diff --git a/Protocols/RepRingOnlyEdabitPrep.hpp b/Protocols/RepRingOnlyEdabitPrep.hpp index c2650c27b..5213cf491 100644 --- a/Protocols/RepRingOnlyEdabitPrep.hpp +++ b/Protocols/RepRingOnlyEdabitPrep.hpp @@ -12,7 +12,7 @@ void RepRingOnlyEdabitPrep::buffer_edabits(int n_bits, ThreadQueues*) { assert(this->proc); int dl = T::bit_type::default_length; - int buffer_size = DIV_CEIL(this->buffer_size, dl) * dl; + int buffer_size = DIV_CEIL(BaseMachine::edabit_batch_size(n_bits, this->buffer_size), dl) * dl; vector wholes; wholes.resize(buffer_size); Instruction inst; @@ -49,5 +49,5 @@ void RepRingOnlyEdabitPrep::buffer_edabits(int n_bits, ThreadQueues*) SubProcessor bit_proc(party.MC->get_part_MC(), this->proc->bit_prep, P); bit_adder.multi_add(sums, summands, 0, buffer_size / dl, bit_proc, dl, 0); - this->push_edabits(this->edabits[{false, n_bits}], wholes, sums, buffer_size); + this->push_edabits(this->edabits[{false, n_bits}], wholes, sums); } diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index c3899e745..0633cbdc4 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -15,6 +15,7 @@ #include "Protocols/ShuffleSacrifice.h" #include "Tools/TimerWithComm.h" #include "edabit.h" +#include "DabitSacrifice.h" #include @@ -36,13 +37,13 @@ class BufferPrep : public Preprocessing friend class InScope; + static const bool homomorphic = false; + template void buffer_inverses(true_type); template void buffer_inverses(false_type) { throw runtime_error("no inverses"); } - virtual bool bits_from_dabits() { return false; } - protected: vector> triples; vector> squares; @@ -83,8 +84,9 @@ class BufferPrep : public Preprocessing { throw runtime_error("no personal daBits"); } void push_edabits(vector>& edabits, - const vector& sums, const vector>& bits, - int buffer_size); + const vector& sums, + const vector>& bits); + public: typedef T share_type; @@ -103,6 +105,10 @@ class BufferPrep : public Preprocessing throw runtime_error("sacrifice not available"); } + static bool bits_from_dabits() { return false; } + static bool bits_from_triples() { return false; } + static bool dabits_from_bits() { return false; } + BufferPrep(DataPositions& usage); virtual ~BufferPrep(); @@ -135,6 +141,8 @@ class BufferPrep : public Preprocessing SubProcessor* get_proc() { return proc; } void set_proc(SubProcessor* proc) { this->proc = proc; } + + void buffer_extra(Dtype type, int n_items); }; /** @@ -272,7 +280,7 @@ class SemiHonestRingPrep : public virtual RingPrep void buffer_edabits(int n_bits, false_type) { this->template buffer_edabits_without_check<0>(n_bits, this->edabits[{false, n_bits}], - OnlineOptions::singleton.batch_size); } + BaseMachine::edabit_batch_size(n_bits, this->buffer_size)); } template void buffer_edabits(int, true_type) { throw not_implemented(); } @@ -286,6 +294,8 @@ class SemiHonestRingPrep : public virtual RingPrep template class MaliciousDabitOnlyPrep : public virtual RingPrep { + DabitSacrifice dabit_sacrifice; + template void buffer_dabits(ThreadQueues* queues, true_type, false_type); template @@ -312,6 +322,8 @@ class MaliciousRingPrep : public virtual MaliciousDabitOnlyPrep { typedef typename T::bit_type::part_type BT; + DabitSacrifice dabit_sacrifice; + protected: void buffer_personal_edabits(int n_bits, vector& sums, vector>& bits, SubProcessor& proc, int input_player, @@ -343,7 +355,7 @@ class MaliciousRingPrep : public virtual MaliciousDabitOnlyPrep bool strict, int player, SubProcessor& proc, int begin, int end, const void* supply = 0) { - EdabitShuffleSacrifice().edabit_sacrifice_buckets(to_check, n_bits, strict, + EdabitShuffleSacrifice(n_bits).edabit_sacrifice_buckets(to_check, strict, player, proc, begin, end, supply); } diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index 00ee24b12..2d9a161b2 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -7,10 +7,11 @@ #define PROTOCOlS_REPLICATEDPREP_HPP_ #include "ReplicatedPrep.h" + +#include "BufferScope.h" #include "SemiRep3Prep.h" #include "DabitSacrifice.h" #include "Spdz2kPrep.h" - #include "GC/BitAdder.h" #include "Processor/OnlineOptions.h" #include "Protocols/Rep3Share.h" @@ -62,7 +63,7 @@ template BufferPrep::BufferPrep(DataPositions& usage) : Preprocessing(usage), n_bit_rounds(0), proc(0), P(0), - buffer_size(OnlineOptions::singleton.batch_size) + buffer_size(0) { } @@ -82,13 +83,14 @@ BufferPrep::~BufferPrep() this->print_left("triples", triples.size() * T::default_length, type_string, this->usage.files.at(T::clear::field_type()).at(DATA_TRIPLE) - * T::default_length); + * T::default_length, + T::LivePrep::homomorphic or T::expensive_triples); size_t used_bits = my_usage.at(DATA_BIT); size_t used_dabits = my_usage.at(DATA_DABIT); - if (bits_from_dabits()) + if (T::LivePrep::bits_from_dabits()) { - if (not T::clear::invertible and field_type == DATA_INT and not T::has_mac) + if (field_type == DATA_INT and not T::has_mac) // add dabits with computation modulo power of two but without MAC used_dabits += my_usage.at(DATA_BIT); } @@ -108,7 +110,8 @@ BufferPrep::~BufferPrep() for (auto& x : this->edabits) { this->print_left_edabits(x.second.size(), x.second[0].size(), - x.first.first, x.first.second, this->usage.edabits[x.first]); + x.first.first, x.first.second, this->usage.edabits[x.first], + T::malicious); } #ifdef VERBOSE @@ -180,7 +183,7 @@ void ReplicatedRingPrep::buffer_triples() assert(this->protocol != 0); // independent instance to avoid conflicts typename T::Protocol protocol(this->protocol->branch()); - generate_triples(this->triples, OnlineOptions::singleton.batch_size, + generate_triples(this->triples, BaseMachine::batch_size(DATA_TRIPLE), &protocol); } @@ -232,7 +235,8 @@ template void BitPrep::buffer_squares() { auto proc = this->proc; - auto buffer_size = this->buffer_size; + auto buffer_size = BaseMachine::batch_size(DATA_SQUARE, + this->buffer_size); assert(proc != 0); vector a_plus_b(buffer_size), as(buffer_size), cs(buffer_size); T b; @@ -258,6 +262,7 @@ template void generate_squares(vector>& squares, int n_squares, U* protocol) { + n_squares = BaseMachine::batch_size(DATA_SQUARE, n_squares); assert(protocol != 0); squares.resize(n_squares); protocol->init_mul(); @@ -286,7 +291,7 @@ void BufferPrep::buffer_inverses(true_type) auto& P = proc->P; auto& MC = proc->MC; auto& prep = *this; - int buffer_size = OnlineOptions::singleton.batch_size; + int buffer_size = BaseMachine::batch_size(DATA_INVERSE); vector> triples(buffer_size); vector c; for (int i = 0; i < buffer_size; i++) @@ -372,13 +377,17 @@ void buffer_bits_from_squares(RingPrep& prep) auto proc = prep.get_proc(); assert(proc != 0); auto& bits = prep.get_bits(); - vector> squares(prep.buffer_size); + vector> squares( + BaseMachine::batch_size(DATA_BIT, prep.buffer_size)); + int bak = prep.buffer_size; + prep.buffer_size = squares.size(); vector s; - for (int i = 0; i < prep.buffer_size; i++) + for (size_t i = 0; i < squares.size(); i++) { prep.get_two(DATA_SQUARE, squares[i][0], squares[i][1]); s.push_back(squares[i][1]); } + prep.buffer_size = bak; vector open; proc->MC.POpen(open, s, proc->P); auto one = T::constant(1, proc->P.my_num(), proc->MC.get_alphai()); @@ -406,7 +415,8 @@ template void BitPrep::buffer_bits_without_check() { SeededPRNG G; - buffer_ring_bits_without_check(this->bits, G, this->buffer_size); + buffer_ring_bits_without_check(this->bits, G, + BaseMachine::batch_size(DATA_BIT, this->buffer_size)); } template @@ -430,9 +440,8 @@ void MaliciousRingPrep::buffer_personal_dabits(int input_player, false_type, { assert(this->proc != 0); vector> check_dabits; - DabitSacrifice dabit_sacrifice; this->buffer_personal_dabits_without_check<0>(input_player, check_dabits, - dabit_sacrifice.minimum_n_inputs()); + dabit_sacrifice.minimum_n_inputs(this->buffer_size)); dabit_sacrifice.sacrifice_and_check_bits( this->personal_dabits[input_player], check_dabits, *this->proc, 0); } @@ -585,7 +594,7 @@ void MaliciousRingPrep::buffer_personal_edabits(int n_bits, vector& wholes Timer timer; timer.start(); #endif - EdabitShuffleSacrifice shuffle_sacrifice; + EdabitShuffleSacrifice shuffle_sacrifice(n_bits); int buffer_size = shuffle_sacrifice.minimum_n_inputs(); vector sums(buffer_size); vector> bits(n_bits, vector(DIV_CEIL(buffer_size, BT::default_length))); @@ -606,8 +615,9 @@ void MaliciousRingPrep::buffer_personal_edabits(int n_bits, vector& wholes << " seconds" << endl; #endif vector> edabits; - shuffle_sacrifice.edabit_sacrifice(edabits, sums, bits, n_bits, *this->proc, + shuffle_sacrifice.edabit_sacrifice(edabits, sums, bits, *this->proc, strict, input_player, queues); + assert(not edabits.empty()); wholes.clear(); parts.clear(); parts.resize(n_bits); @@ -670,6 +680,7 @@ void BitPrep::buffer_ring_bits_without_check(vector& bits, PRNG& G, int n_relevant_players = protocol->get_n_relevant_players(); vector> player_bits; auto stat = proc->P.total_comm(); + BufferScope _(*this, buffer_size); buffer_bits_from_players(player_bits, G, *proc, this->base_player, buffer_size, 1); auto& prot = *protocol; @@ -688,8 +699,7 @@ template void RingPrep::buffer_dabits_without_check(vector>& dabits, int buffer_size, ThreadQueues* queues) { - if (buffer_size < 0) - buffer_size = OnlineOptions::singleton.batch_size; + buffer_size = BaseMachine::batch_size(DATA_DABIT, buffer_size); int old_size = dabits.size(); dabits.resize(dabits.size() + buffer_size); if (queues) @@ -712,7 +722,9 @@ void SemiRep3Prep::buffer_dabits(ThreadQueues*) assert(this->proc); typedef typename T::bit_type BT; - int n_blocks = DIV_CEIL(this->buffer_size, BT::default_length); + int n_blocks = DIV_CEIL( + BaseMachine::batch_size(DATA_DABIT, this->buffer_size), + BT::default_length); int n_bits = n_blocks * BT::default_length; vector b(n_blocks); @@ -838,7 +850,6 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector& sums, } else buffer_edabits_without_check<0>(n_bits, sums, bits, 0, rounded); - sums.resize(buffer_size); #ifdef VERBOSE_EDA cerr << "Done with unchecked edaBit generation after " << timer.elapsed() << " seconds" << endl; @@ -910,7 +921,7 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector>& vector> bits; vector sums; buffer_edabits_without_check<0>(n_bits, sums, bits, buffer_size); - this->push_edabits(edabits, sums, bits, buffer_size); + this->push_edabits(edabits, sums, bits); (void) stat; #ifdef VERBOSE_PREP cerr << "edaBit generation" << endl; @@ -920,12 +931,11 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector>& template void BufferPrep::push_edabits(vector>& edabits, - const vector& sums, const vector>& bits, - int buffer_size) + const vector& sums, const vector>& bits) { int unit = T::bit_type::part_type::default_length; - edabits.reserve(edabits.size() + DIV_CEIL(buffer_size, unit)); - for (int i = 0; i < buffer_size; i++) + edabits.reserve(edabits.size() + DIV_CEIL(sums.size(), unit)); + for (size_t i = 0; i < sums.size(); i++) { if (i % unit == 0) edabits.push_back(bits.at(i / unit)); @@ -938,9 +948,11 @@ template void RingPrep::buffer_sedabits_from_edabits(int n_bits, false_type) { assert(this->proc != 0); - size_t buffer_size = OnlineOptions::singleton.batch_size; + size_t buffer_size = DIV_CEIL(BaseMachine::edabit_batch_size(n_bits), + edabitvec::MAX_SIZE); auto& loose = this->edabits[{false, n_bits}]; - while (loose.size() < size_t(DIV_CEIL(buffer_size, edabitvec::MAX_SIZE))) + BufferScope scope(*this, buffer_size); + while (loose.size() < buffer_size) this->buffer_edabits(false, n_bits); sanitize<0>(loose, n_bits); for (auto& x : loose) @@ -980,6 +992,7 @@ void RingPrep::sanitize(vector>& edabits, int n_bits, int player, vector dabits; typedef typename T::bit_type::part_type::small_type BT; vector to_open; + BufferScope scope(*this, (end - begin)); for (int i = begin; i < end; i++) { auto& x = edabits[i]; @@ -1028,6 +1041,12 @@ void RingPrep::sanitize(vector>& edabits, int n_bits) vector dabits; typedef typename T::bit_type::part_type BT; vector to_open; + BufferScope scope(*this, edabits.size() * edabits[0].size()); + +#ifdef DEBUG_BATCH_SIZE + cerr << this->dabits.size() << " daBits left before" << endl; +#endif + for (auto& x : edabits) { for (size_t j = n_bits; j < x.b.size(); j++) @@ -1044,6 +1063,11 @@ void RingPrep::sanitize(vector>& edabits, int n_bits) to_open.push_back(x.b[j] + bits); } } + +#ifdef DEBUG_BATCH_SIZE + cerr << this->dabits.size() << " daBits left after" << endl; +#endif + vector opened; auto& MCB = *BT::new_mc( GC::ShareThread::s().MC->get_alphai()); @@ -1184,6 +1208,7 @@ edabitvec BufferPrep::get_edabitvec(bool strict, int n_bits) InScope in_scope(this->do_count, false, *this); buffer_edabits_with_queues(strict, n_bits); } + assert(not buffer.empty()); auto res = buffer.back(); buffer.pop_back(); this->fill(res, strict, n_bits); @@ -1343,4 +1368,25 @@ T BufferPrep::get_random() } } +template +void BufferPrep::buffer_extra(Dtype type, int n_items) +{ + BufferScope scope(*this, n_items); + + switch (type) + { + case DATA_TRIPLE: + buffer_triples(); + break; + case DATA_SQUARE: + buffer_squares(); + break; + case DATA_BIT: + buffer_bits(); + break; + default: + throw not_implemented(); + } +} + #endif diff --git a/Protocols/RingOnlyPrep.hpp b/Protocols/RingOnlyPrep.hpp index 727b4f184..cf1d0675d 100644 --- a/Protocols/RingOnlyPrep.hpp +++ b/Protocols/RingOnlyPrep.hpp @@ -18,6 +18,7 @@ void RingOnlyPrep::buffer_dabits_from_bits_without_check( this->proc->bit_prep, this->proc->P); typename T::bit_type::part_type::Input input(bit_proc); input.reset_all(this->proc->P); + BufferScope scope(*this, buffer_size); for (int i = 0; i < buffer_size; i++) { T bit; diff --git a/Protocols/SemiPrep.h b/Protocols/SemiPrep.h index 9646e9453..bfcae4245 100644 --- a/Protocols/SemiPrep.h +++ b/Protocols/SemiPrep.h @@ -27,7 +27,7 @@ class SemiPrep : public virtual OTPrep, public virtual SemiHonestRingPrep void get_one_no_count(Dtype dtype, T& a); - bool bits_from_dabits(); + static bool bits_from_dabits(); }; #endif /* PROTOCOLS_SEMIPREP_H_ */ diff --git a/Protocols/SemiPrep.hpp b/Protocols/SemiPrep.hpp index f1ec6efd9..9fc54b5d0 100644 --- a/Protocols/SemiPrep.hpp +++ b/Protocols/SemiPrep.hpp @@ -24,7 +24,10 @@ template void SemiPrep::buffer_triples() { assert(this->triple_generator); + this->triple_generator->set_batch_size( + BaseMachine::batch_size(DATA_TRIPLE)); this->triple_generator->generatePlainTriples(); + this->triple_generator->set_batch_size(OnlineOptions::singleton.batch_size); for (auto& x : this->triple_generator->plainTriples) { this->triples.push_back({{x[0], x[1], x[2]}}); @@ -35,8 +38,8 @@ void SemiPrep::buffer_triples() template bool SemiPrep::bits_from_dabits() { - assert(this->proc); - return this->proc->P.num_players() == 2 and not T::clear::characteristic_two; + return not T::clear::characteristic_two and BaseMachine::has_singleton() + and BaseMachine::s().get_N().num_players() == 2; } template @@ -45,9 +48,12 @@ void SemiPrep::buffer_dabits(ThreadQueues* queues) if (bits_from_dabits()) { assert(this->triple_generator); + this->triple_generator->set_batch_size( + BaseMachine::batch_size(DATA_DABIT, this->buffer_size)); this->triple_generator->generatePlainBits(); for (auto& x : this->triple_generator->plainBits) this->dabits.push_back({x.first, x.second}); + this->triple_generator->set_batch_size(OnlineOptions::singleton.batch_size); } else SemiHonestRingPrep::buffer_dabits(queues); diff --git a/Protocols/SemiRep3Prep.h b/Protocols/SemiRep3Prep.h index 5d68f03ea..2e2ba7344 100644 --- a/Protocols/SemiRep3Prep.h +++ b/Protocols/SemiRep3Prep.h @@ -18,6 +18,8 @@ class SemiRep3Prep : public virtual SemiHonestRingPrep, void buffer_dabits(ThreadQueues*); public: + static bool bits_from_dabits() { return true; } + SemiRep3Prep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index 3182b1c05..b45828e35 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -74,6 +74,9 @@ class ShareInterface static void generate_mac_key(T&, U&) {} static int threshold(int) { throw runtime_error("undefined threshold"); } + + template + static string proto_fake_opts() { return T::fake_opts(); } }; #endif /* PROTOCOLS_SHAREINTERFACE_H_ */ diff --git a/Protocols/ShareMatrix.h b/Protocols/ShareMatrix.h index b31aa7085..d67ce7123 100644 --- a/Protocols/ShareMatrix.h +++ b/Protocols/ShareMatrix.h @@ -250,6 +250,7 @@ class ShareMatrix : public ValueMatrix, public ShareInterface typedef MatrixMC MAC_Check; typedef Beaver Protocol; typedef ::Input Input; + typedef DummyLivePrep LivePrep; typedef ValueMatrix clear; typedef clear open_type; diff --git a/Protocols/ShuffleSacrifice.h b/Protocols/ShuffleSacrifice.h index b8ffd0aaf..56ae0f0a3 100644 --- a/Protocols/ShuffleSacrifice.h +++ b/Protocols/ShuffleSacrifice.h @@ -30,15 +30,15 @@ class ShuffleSacrifice const int C; ShuffleSacrifice(); - ShuffleSacrifice(int B, int C); + ShuffleSacrifice(int B, int C = 3); int minimum_n_inputs(int n_outputs = 1) { return max(n_outputs, minimum_n_outputs()) * B + C; } - int minimum_n_inputs_with_combining() + int minimum_n_inputs_with_combining(int n_outputs = 1) { - return minimum_n_inputs(B * minimum_n_outputs()); + return minimum_n_inputs(B * max(n_outputs, minimum_n_outputs())); } int minimum_n_outputs() { @@ -89,17 +89,21 @@ class EdabitShuffleSacrifice : public ShuffleSacrifice { typedef typename T::bit_type::part_type BT; + size_t n_bits; + public: + EdabitShuffleSacrifice(int n_bits); + void edabit_sacrifice(vector>& output, vector& sums, - vector>& bits, size_t n_bits, + vector>& bits, SubProcessor& proc, bool strict = false, int player = -1, ThreadQueues* = 0); - void edabit_sacrifice_buckets(vector>& to_check, size_t n_bits, + void edabit_sacrifice_buckets(vector>& to_check, bool strict, int player, SubProcessor& proc, int begin, int end, const void* supply = 0); - void edabit_sacrifice_buckets(vector>& to_check, size_t n_bits, + void edabit_sacrifice_buckets(vector>& to_check, bool strict, int player, SubProcessor& proc, int begin, int end, LimitedPrep& personal_prep, const void* supply = 0); }; diff --git a/Protocols/ShuffleSacrifice.hpp b/Protocols/ShuffleSacrifice.hpp index 150cdb610..7fadd8195 100644 --- a/Protocols/ShuffleSacrifice.hpp +++ b/Protocols/ShuffleSacrifice.hpp @@ -14,7 +14,7 @@ inline ShuffleSacrifice::ShuffleSacrifice() : - ShuffleSacrifice(OnlineOptions::singleton.bucket_size, 3) + ShuffleSacrifice(OnlineOptions::singleton.bucket_size) { } @@ -175,10 +175,16 @@ void DabitShuffleSacrifice::dabit_sacrifice(vector >& output, delete &MCB; } +template +EdabitShuffleSacrifice::EdabitShuffleSacrifice(int n_bits) : + ShuffleSacrifice(BaseMachine::edabit_bucket_size(n_bits)), n_bits(n_bits) +{ +} + template void EdabitShuffleSacrifice::edabit_sacrifice(vector >& output, vector& wholes, vector>& parts, - size_t n_bits, SubProcessor& proc, bool strict, int player, + SubProcessor& proc, bool strict, int player, ThreadQueues* queues) { #ifdef VERBOSE_EDA @@ -227,6 +233,7 @@ void EdabitShuffleSacrifice::edabit_sacrifice(vector >& output, int buffer_size = to_check.size(); int N = (buffer_size - C) / B; + assert(N > 0); // needs to happen before shuffling for security LimitedPrep personal_prep; @@ -310,13 +317,13 @@ void EdabitShuffleSacrifice::edabit_sacrifice(vector >& output, EdabitSacrificeJob job(&to_check, n_bits, strict, player); int start = queues->distribute_no_setup(job, N, 0, BT::default_length, &supplies); - edabit_sacrifice_buckets(to_check, n_bits, strict, player, proc, start, + edabit_sacrifice_buckets(to_check, strict, player, proc, start, N, personal_prep); if (start) queues->wrap_up(job); } else - edabit_sacrifice_buckets(to_check, n_bits, strict, player, proc, 0, N, + edabit_sacrifice_buckets(to_check, strict, player, proc, 0, N, personal_prep); #ifdef VERBOSE_EDA cerr << "Bucket sacrifice took " << bucket_timer.elapsed() << " seconds" @@ -348,17 +355,17 @@ void EdabitShuffleSacrifice::edabit_sacrifice(vector >& output, template void EdabitShuffleSacrifice::edabit_sacrifice_buckets(vector>& to_check, - size_t n_bits, bool strict, int player, SubProcessor& proc, int begin, + bool strict, int player, SubProcessor& proc, int begin, int end, const void* supply) { LimitedPrep personal_prep; - edabit_sacrifice_buckets(to_check, n_bits, strict, player, proc, begin, end, + edabit_sacrifice_buckets(to_check, strict, player, proc, begin, end, personal_prep, supply); } template void EdabitShuffleSacrifice::edabit_sacrifice_buckets(vector>& to_check, - size_t n_bits, bool strict, int player, SubProcessor& proc, int begin, + bool strict, int player, SubProcessor& proc, int begin, int end, LimitedPrep& personal_prep, const void* supply) { typedef typename T::bit_type::part_type BT; diff --git a/Protocols/SohoPrep.h b/Protocols/SohoPrep.h index 5e28381be..e6ba1495b 100644 --- a/Protocols/SohoPrep.h +++ b/Protocols/SohoPrep.h @@ -18,6 +18,8 @@ class SohoPrep : public SemiHonestRingPrep static Lock lock; public: + static const bool homomorphic = true; + static void basic_setup(Player& P); static void teardown(); diff --git a/Protocols/Spdz2kPrep.h b/Protocols/Spdz2kPrep.h index 8b447f311..5b6f75774 100644 --- a/Protocols/Spdz2kPrep.h +++ b/Protocols/Spdz2kPrep.h @@ -27,6 +27,7 @@ class Spdz2kPrep : public virtual MaliciousRingPrep, MascotTriplePrep* bit_prep; SubProcessor* bit_proc; typename BitShare::MAC_Check* bit_MC; + DabitSacrifice dabit_sacrifice; public: Spdz2kPrep(SubProcessor* proc, DataPositions& usage); diff --git a/Protocols/Spdz2kPrep.hpp b/Protocols/Spdz2kPrep.hpp index 4b76904b7..a3e3f2fc2 100644 --- a/Protocols/Spdz2kPrep.hpp +++ b/Protocols/Spdz2kPrep.hpp @@ -85,6 +85,7 @@ void Spdz2kPrep::buffer_bits() template void bits_from_square_in_ring(vector& bits, int buffer_size, U* bit_prep) { + buffer_size = BaseMachine::batch_size(DATA_BIT, buffer_size); typedef typename U::share_type BitShare; typedef typename BitShare::open_type open_type; assert(bit_prep != 0); @@ -129,10 +130,7 @@ void Spdz2kPrep::buffer_dabits(ThreadQueues* queues) { assert(this->proc != 0); vector> check_dabits; - DabitSacrifice dabit_sacrifice; - int buffer_size = OnlineOptions::singleton.batch_size; - if (queues) - buffer_size *= queues->size(); + int buffer_size = BaseMachine::batch_size(DATA_DABIT, this->buffer_size); this->buffer_dabits_from_bits_without_check(check_dabits, dabit_sacrifice.minimum_n_inputs(buffer_size), queues); dabit_sacrifice.sacrifice_without_bit_check(this->dabits, check_dabits, @@ -187,6 +185,7 @@ void MaliciousRingPrep::buffer_edabits_from_personal(bool strict, int n_bits, vector> tmp_bits; this->buffer_personal_edabits(n_bits, tmp, tmp_bits, bit_proc, i, strict, queues); + assert(not tmp.empty()); sums.resize(tmp.size()); for (size_t j = 0; j < tmp.size(); j++) sums[j] += tmp[j]; diff --git a/Protocols/Spdz2kShare.h b/Protocols/Spdz2kShare.h index 82be3c384..85f9ecb91 100644 --- a/Protocols/Spdz2kShare.h +++ b/Protocols/Spdz2kShare.h @@ -48,7 +48,7 @@ class Spdz2kShare : public Share> typedef Z2kRectangle Rectangle; typedef MAC_Check_Z2k, Z2, open_type, Spdz2kShare> MAC_Check; - typedef MAC_Check Direct_MC; + typedef Direct_MAC_Check_Z2k Direct_MC; typedef ::Input Input; typedef ::PrivateOutput PrivateOutput; typedef SPDZ2k Protocol; @@ -68,6 +68,9 @@ class Spdz2kShare : public Share> static string type_string() { return "SPDZ2^(" + to_string(K) + "+" + to_string(S) + ")"; } static string type_short() { return "Z" + to_string(K) + "," + to_string(S); } + template + static string proto_fake_opts() { return " -Z " + to_string(K) + " -S " + to_string(S); } + Spdz2kShare() {} template Spdz2kShare(const Share_& x) : super(x) {} diff --git a/Protocols/SpdzWisePrep.hpp b/Protocols/SpdzWisePrep.hpp index 1090fc08e..e94a59ee3 100644 --- a/Protocols/SpdzWisePrep.hpp +++ b/Protocols/SpdzWisePrep.hpp @@ -24,7 +24,8 @@ void SpdzWisePrep::buffer_triples() assert(this->proc != 0); this->protocol->init_mul(); generate_triples_initialized(this->triples, - OnlineOptions::singleton.batch_size, this->protocol); + BaseMachine::batch_size(DATA_TRIPLE, this->buffer_size), + this->protocol); } template @@ -42,6 +43,9 @@ void SpdzWisePrep>>::buffer_bits() ProtocolSet set(this->proc->P, {}); auto& protocol = set.protocol; auto& prep = set.preprocessing; + int buffer_size = BaseMachine::batch_size< + SpdzWiseShare>>(DATA_BIT, + this->buffer_size); for (int i = 0; i < buffer_size; i++) bits.push_back(prep.get_bit()); protocol.init_mul(); @@ -63,7 +67,9 @@ void buffer_bits_from_squares_in_ring(vector>& bits, SquarePrep prep(usage); SubProcessor bit_proc(MC, prep, proc->P, proc->Proc); prep.set_proc(&bit_proc); - bits_from_square_in_ring(bits, OnlineOptions::singleton.batch_size, &prep); + bits_from_square_in_ring(bits, + BaseMachine::batch_size>(DATA_BIT), + &prep); } template diff --git a/Protocols/SpdzWiseRingPrep.h b/Protocols/SpdzWiseRingPrep.h index 4a16b92ee..c59ff4f9a 100644 --- a/Protocols/SpdzWiseRingPrep.h +++ b/Protocols/SpdzWiseRingPrep.h @@ -35,6 +35,11 @@ class SpdzWiseRingPrep : public virtual SpdzWisePrep, } public: + static bool dabits_from_bits() + { + return true; + } + static void edabit_sacrifice_buckets(vector>&, size_t, bool, int, SubProcessor&, int, int, const void* = 0) { diff --git a/Protocols/TemiPrep.h b/Protocols/TemiPrep.h index ad12837a8..2b6d8b87d 100644 --- a/Protocols/TemiPrep.h +++ b/Protocols/TemiPrep.h @@ -50,6 +50,8 @@ class TemiPrep : public SemiHonestRingPrep vector*> multipliers; public: + static const bool homomorphic = true; + static void basic_setup(Player& P); static void teardown(); diff --git a/Protocols/fake-stuff.h b/Protocols/fake-stuff.h index 7a63298f1..735a59531 100644 --- a/Protocols/fake-stuff.h +++ b/Protocols/fake-stuff.h @@ -45,7 +45,7 @@ class Files ofstream* outf; int N; typename T::mac_type key; - PRNG G; + PRNG& G; Files(int N, const typename T::mac_type& key, const string& prep_data_prefix, Dtype type, PRNG& G, int thread_num = -1) : Files(N, key, diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index fa4572f31..f46b3805f 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -312,6 +312,12 @@ void write_mac_key(const Names& N, typename T::mac_key_type key) N.num_players(), key); } +template +void read_mac_key(const Names& N, typename T::mac_key_type& key) +{ + read_mac_key(get_prep_sub_dir(N.num_players()), N, key); +} + template void read_mac_key(const string& directory, const Names& N, T& key) { diff --git a/README.md b/README.md index 5cf8b6d42..2e457d2f7 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ parties and malicious security. On Linux, this requires a working toolchain and [all requirements](#requirements). On Ubuntu, the following might suffice: ``` -sudo apt-get install automake build-essential clang cmake git libboost-dev libboost-thread-dev libgmp3-dev libntl-dev libsodium-dev libssl-dev libtool python3 +sudo apt-get install automake build-essential clang cmake git libboost-dev libboost-thread-dev libgmp-dev libntl-dev libsodium-dev libssl-dev libtool python3 ``` On MacOS, this requires [brew](https://brew.sh) to be installed, which will be used for all dependencies. @@ -230,28 +230,13 @@ following repositories: - https://github.com/mkskeller/SPDZ-BMR-ORAM - https://github.com/mkskeller/SPDZ-Yao -#### Alternatives - -There is another fork of SPDZ-2 called -[SCALE-MAMBA](https://github.com/KULeuven-COSIC/SCALE-MAMBA). -The main differences at the time of writing are as follows: -- It provides honest-majority computation for any Q2 structure. -- For dishonest majority computation, it provides integration of -SPDZ/Overdrive offline and online phases but without secure key -generation. -- It only provides computation modulo a prime. -- It only provides malicious security. - -More information can be found here: -https://homes.esat.kuleuven.be/~nsmart/SCALE - #### Overview For the actual computation, the software implements a virtual machine that executes programs in a specific bytecode. Such code can be generated from high-level Python code using a compiler that optimizes the computation with a particular focus on minimizing the number of -communication rounds (for protocol based on secret sharing) or on +communication rounds (for protocols based on secret sharing) or on AES-NI pipelining (for garbled circuits). The software uses two different bytecode sets, one for @@ -275,7 +260,7 @@ compute the preprocessing time for a particular computation. better. Note that GCC 5/6 and clang 9 don't support libOTe, so you need to deactivate its use for these compilers (see the next section). - - For protocol using oblivious transfer, libOTe with [the necessary + - For protocols using oblivious transfer, libOTe with [the necessary patches](https://github.com/mkskeller/softspoken-implementation) but without SimplestOT. The easiest way is to run `make libote`, which will install it as needed in a subdirectory. libOTe requires @@ -605,7 +590,7 @@ run SqueezeNet inference for ImageNet as follows: git clone https://github.com/mkskeller/EzPC cd EzPC/Athos/Networks/SqueezeNetImgNet axel -a -n 5 -c --output ./PreTrainedModel https://github.com/avoroshilov/tf-squeezenet/raw/master/sqz_full.mat -pip3 install scipy==1.1.0 +pip3 install numpy scipy pillow>=9.1 tensorflow python3 squeezenet_main.py --in ./SampleImages/n02109961_36.JPEG --saveTFMetadata True python3 squeezenet_main.py --in ./SampleImages/n02109961_36.JPEG --scalingFac 12 --saveImgAndWtData True cd ../../../.. @@ -620,8 +605,8 @@ three-party semi-honest computation, similar to CrypTFlow's Porthos. Replace 1 by the desired number of thread in the last two lines. If you run with some other protocols, you will need to remove `trunc_pr` and/or `split`. Also note that you will need to use a -CrypTFlow repository that includes the patch in -https://github.com/mkskeller/EzPC/commit/2021be90d21dc26894be98f33cd10dd26769f479. +CrypTFlow repository that includes the patches in +https://github.com/mkskeller/EzPC. [The reference](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.ml) contains further documentation on available layers. @@ -965,6 +950,9 @@ that: Make sure to run `make clean` before recompiling any binaries. Then, you need to run `make Fake-Offline.x -party.x`. +Note that you can as well run the full protocol with option `-v` to +see the cost split by preprocessing and online phase. + ### SPDZ The SPDZ protocol uses preprocessing, that is, in a first (sometimes @@ -1007,22 +995,26 @@ This creates the bytecode and schedule files in Programs/Bytecode/ and Programs/ To run the above program with two parties on one machine, run: -`./Player-Online.x -N 2 0 tutorial` +`./mascot-party.x -F -N 2 0 tutorial` -`./Player-Online.x -N 2 1 tutorial` (in a separate terminal) +`./mascot-party.x -F -N 2 1 tutorial` (in a separate terminal) Or, you can use a script to do the above automatically: -`Scripts/run-online.sh tutorial` +`Scripts/mascot.sh -F tutorial` -To run a program on two different machines, firstly the preprocessing data must be -copied across to the second machine (or shared using sshfs), and secondly, Player-Online.x -needs to be passed the machine where the first party is running. -e.g. if this machine is name `diffie` on the local network: +MASCOT is one of the protocols that use SPDZ for the online phase, and +`-F` causes the programs to read preprocessing material from files. + +To run a program on two different machines, firstly the preprocessing +data must be copied across to the second machine (or shared using +sshfs), and secondly, `mascot-party.x` needs to be passed the machine +where the first party is running. E.g., if this machine is named +`diffie` on the local network: -`./Player-Online.x -N 2 -h diffie 0 test_all` +`./mascot-party.x -F -N 2 -h diffie 0 test_all` -`./Player-Online.x -N 2 -h diffie 1 test_all` +`./mascot-party.x -F -N 2 -h diffie 1 test_all` The software uses TCP ports around 5000 by default, use the `-pn` argument to change that. @@ -1094,7 +1086,7 @@ the actual computation. First, compile the binary: `make -offline.x` At the time of writing the supported protocols are `mascot`, -`cowgear`, mal-shamir`, `semi`, `semi2k`, and `hemi`. +`cowgear`, `mal-shamir`, `semi`, `semi2k`, and `hemi`. If you have not done so already, then compile your high-level program: diff --git a/Scripts/list-field-protocols.sh b/Scripts/list-field-protocols.sh index 16052052c..c303c9775 100755 --- a/Scripts/list-field-protocols.sh +++ b/Scripts/list-field-protocols.sh @@ -1,4 +1,4 @@ #!/bin/bash echo rep-field shamir mal-rep-field ps-rep-field sy-rep-field \ - atlas mal-shamir sy-shamir semi hemi temi mascot cowgear chaigear + atlas mal-shamir sy-shamir semi hemi temi mascot soho cowgear chaigear diff --git a/Scripts/list-protocols.sh b/Scripts/list-protocols.sh new file mode 100755 index 000000000..6b68fb118 --- /dev/null +++ b/Scripts/list-protocols.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +dir="$(dirname $0)" +echo `$dir/list-ring-protocols.sh` `$dir/list-field-protocols.sh` diff --git a/Scripts/list-ring-protocols.sh b/Scripts/list-ring-protocols.sh index 9491d066d..b67c12a69 100755 --- a/Scripts/list-ring-protocols.sh +++ b/Scripts/list-ring-protocols.sh @@ -1,4 +1,4 @@ #!/bin/bash echo ring semi2k brain mal-rep-ring ps-rep-ring sy-rep-ring \ - spdz2k rep4-ring + spdz2k rep4-ring dealer-ring diff --git a/Scripts/memory-usage.py b/Scripts/memory-usage.py index d5026eaa3..098f90b77 100755 --- a/Scripts/memory-usage.py +++ b/Scripts/memory-usage.py @@ -19,8 +19,8 @@ def process(tapename, res, regs): for inst in Tape.read_instructions(tapename): t = inst.type if issubclass(t, DirectMemoryInstruction): - res[t.arg_format[0]] = max(inst.args[1].i + inst.size, - res[t.arg_format[0]]) + 1 + res[type(inst.args[0])] = max(inst.args[1].i + inst.size, + res[type(inst.args[0])]) + 1 for arg in inst.args: if isinstance(arg, RegisterArgFormat): regs[type(arg)] = max(regs[type(arg)], arg.i + inst.size) @@ -49,7 +49,7 @@ def output(data): total += sum(x.values()) print ('Memory:') -output(res) +output(regout(res)) print ('Registers in main thread:') output(regout(regs)) diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index 8ceef7385..64d24d300 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -1,4 +1,11 @@ +gdb_front() +{ + prog=$1 + shift + gdb $prog -ex "run $*" +} + gdb_screen() { prog=$1 @@ -46,14 +53,20 @@ run_player() { fi set -o pipefail for i in $(seq 0 $[players-1]); do - >&2 echo Running $prefix $SPDZROOT/$bin $i $params + if test "$GDB_PLAYER" -a $i = "$GDB_PLAYER"; then + my_prefix=gdb_front + else + my_prefix=$prefix + fi + front_player=${GDB_PLAYER:-0} + >&2 echo Running $my_prefix $SPDZROOT/$bin $i $params log=logs/$log_prefix$i - $prefix $SPDZROOT/$bin $i $params 2>&1 | + $my_prefix $SPDZROOT/$bin $i $params 2>&1 | { if test "$BENCH"; then - if test $i = 0; then tee -a $log; else cat >> $log; fi; + if test $i = $front_player; then tee -a $log; else cat >> $log; fi; else - if test $i = 0; then tee $log; else cat > $log; fi; + if test $i = $front_player; then tee $log; else cat > $log; fi; fi } & codes[$i]=$! diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index 094a6393d..5d790dee7 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -37,9 +37,6 @@ function test_vm fi } -# big buckets for smallest batches -run_opts="$run_opts -B 5" - export PORT=$((RANDOM%10000+10000)) export BENCH= diff --git a/Tools/Buffer.cpp b/Tools/Buffer.cpp index e64c8e461..8e61aa135 100644 --- a/Tools/Buffer.cpp +++ b/Tools/Buffer.cpp @@ -78,7 +78,7 @@ void BufferBase::try_rewind() if (file->peek() == ifstream::traits_type::eof()) throw runtime_error("empty file: " + filename); if (!rewind) - cerr << "REWINDING - ONLY FOR BENCHMARKING" << endl; + cerr << "REUSING DATA - ONLY FOR BENCHMARKING" << endl; rewind = true; eof = true; } diff --git a/Tools/Exceptions.cpp b/Tools/Exceptions.cpp index ec39b7728..2d38ec90f 100644 --- a/Tools/Exceptions.cpp +++ b/Tools/Exceptions.cpp @@ -90,3 +90,19 @@ gf2n_not_supported::gf2n_not_supported(int n, string options) : + (options.empty() ? "" : ", options are " + options)) { } + +setup_error::setup_error(const string& error) : + runtime_error(error) +{ +} + +prep_setup_error::prep_setup_error(const string& error, int nplayers, + const string& fake_opts) : + setup_error( + "Something is wrong with the preprocessing data on disk: " + + error + + "\nHave you run the right program for generating it, " + "such as './Fake-Offline.x " + + to_string(nplayers) + fake_opts + "'?") +{ +} diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index a3ca3a5d0..b62f01c32 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -284,4 +284,16 @@ class gf2n_not_supported : public runtime_error gf2n_not_supported(int n, string options = ""); }; +class setup_error : public runtime_error +{ +public: + setup_error(const string& error); +}; + +class prep_setup_error : public setup_error +{ +public: + prep_setup_error(const string& error, int nplayers, const string& fake_opts); +}; + #endif diff --git a/Tools/random.h b/Tools/random.h index 3b16cdb2f..c80f60160 100644 --- a/Tools/random.h +++ b/Tools/random.h @@ -72,7 +72,7 @@ class PRNG public: - /// Construction without initialization. Usage without initilization will fail. + /// Construction without initialization. Usage without initialization will fail. PRNG(); /// Initialize with ``SEED_SIZE`` bytes from buffer. PRNG(octetStream& seed); diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index 6dc233383..68465bcc7 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -811,6 +811,10 @@ int FakeParams::generate() // default generate_ring<64>(G); +#if defined(RING_SIZE) and RING_SIZE != 64 + generate_ring(G); +#endif + // reuse lgp for simplified interface switch (lgp) { @@ -859,6 +863,8 @@ inline void FakeParams::generate_ring(PRNG& G) if (nplayers == 3) { make_bits>({}, nplayers, default_num, zero, G); + make_basic>({}, nplayers, + default_num, zero, G); make_basic>({}, nplayers, default_num, zero, G); make_basic>({}, nplayers, diff --git a/Utils/protocol-tutorial.cpp b/Utils/protocol-tutorial.cpp new file mode 100644 index 000000000..6b460e058 --- /dev/null +++ b/Utils/protocol-tutorial.cpp @@ -0,0 +1,74 @@ +/* + * protocol-tutorial.cpp + * + * This file demonstrates the use of the low-level capabilities + * to build a protocol, namely Rep3 multiplication and opening. + * + */ + +#include "Networking/CryptoPlayer.h" +#include "Math/Z2k.hpp" + +int main(int argc, char** argv) +{ + if (argc != 2) + { + cerr << "Usage: " << argv[0] << " " << endl; + exit(1); + } + + // set up networking on localhost + int my_number = atoi(argv[1]); + int port_base = 9999; + Names N(my_number, 3, "localhost", port_base); + CryptoPlayer P(N); + + // correlated randomness for resharing + SeededPRNG G[2]; + + // synchronize with other parties + octetStream os; + os.append(G[0].get_seed(), SEED_SIZE); + P.pass_around(os, os, 1); + G[1].SetSeed(os.consume(SEED_SIZE)); + + // simplify code + typedef Z2<64> Z; + + // start with same shares on all parties for simplicity + // replicated secret sharing of 3 + Z a[2] = {1, 1}; + // and 6 + Z b[2] = {2, 2}; + + // compute an additive sharing of the product + Z cc = a[0] * (b[0] + b[1]) + a[1] * b[0]; + + // result shares + Z c[2]; + + // re-randomize + c[0] = cc + G[0].get() - G[1].get(); + + // send and receive share + os.reset_write_head(); + c[0].pack(os); + P.pass_around(os, os, 1); + c[1].unpack(os); + + // open value to party 0 + if (P.my_num() == 1) + { + os.reset_write_head(); + c[0].pack(os); + P.send_to(0, os); + } + + // output result on party 0, which should be 18 + if (P.my_num() == 0) + { + P.receive_player(1, os); + cout << "My shares: " << c[0] << ", " << c[1] << endl; + cout << "Result: " << (os.get() + c[0] + c[1]) << endl; + } +} diff --git a/Utils/stream-fake-mascot-triples.cpp b/Utils/stream-fake-mascot-triples.cpp index 517056e72..16901471c 100644 --- a/Utils/stream-fake-mascot-triples.cpp +++ b/Utils/stream-fake-mascot-triples.cpp @@ -22,8 +22,8 @@ class Info void* run(void* arg) { auto& info = *(Info*) arg; - Files> files(info.nplayers, info.key, PREP_DIR, DATA_TRIPLE, info.thread_num); SeededPRNG G; + Files> files(info.nplayers, info.key, PREP_DIR, DATA_TRIPLE, G, info.thread_num); int count = 0; while (true) { @@ -53,7 +53,8 @@ int main() string prep_data_prefix = PREP_DIR; gfpvar::generate_setup(prep_data_prefix, nplayers, lgp); T::mac_key_type keyp; - generate_mac_keys(keyp, nplayers, prep_data_prefix); + SeededPRNG G; + generate_mac_keys(keyp, nplayers, prep_data_prefix, G); int nthreads = 3; OnlineOptions::singleton.file_prep_per_thread = true; diff --git a/deps/libOTe b/deps/libOTe index 5d9f9c400..e2d782519 160000 --- a/deps/libOTe +++ b/deps/libOTe @@ -1 +1 @@ -Subproject commit 5d9f9c400c6acda734cbd20b5b8ea02392c0f75e +Subproject commit e2d7825196986c106394ab4e2d05f179a239624a diff --git a/doc/add-protocol.rst b/doc/add-protocol.rst index 1f741f36a..28019f68a 100644 --- a/doc/add-protocol.rst +++ b/doc/add-protocol.rst @@ -36,7 +36,8 @@ of the share class. For example ``replicated-ring-party.x`` is implemented in ``Machines/replicated-ring-party.cpp``, which refers to :c:func:`Rep3Share2` in ``Protocols/Rep3Share2.h``. There you will find that it uses :c:func:`Replicated` for multiplication, which is -found in ``Protocols/Replicated.h``. +found in ``Protocols/Replicated.h``. You can also consult :ref:`the +tutorial for the lowest-level interface `. 1. Fill in the :c:func:`constant` static member function of :c:type:`NoShare` as well as the :c:func:`exchange` member function diff --git a/doc/compilation.rst b/doc/compilation.rst index 993f75da6..9d9cfe875 100644 --- a/doc/compilation.rst +++ b/doc/compilation.rst @@ -5,7 +5,7 @@ The easiest way of using MP-SPDZ is using ``compile.py`` as described below. If you would like to run compilation directly from Python, see :ref:`Direct Compilation in Python`. -After putting your code in ``Program/Source/.mpc``, run the +After putting your code in ``Program/Source/.[mpc|py]``, run the compiler from the root directory as follows .. code-block:: bash @@ -13,8 +13,9 @@ compiler from the root directory as follows ./compile.py [options] [args] The arguments `` [args]`` are accessible as list under -``program.args`` within ``progname.mpc``, with ```` as -``program.args[0]``. +``program.args`` within ``progname.[mpc|py]``, with ```` as +``program.args[0]``. The resulting program for the virtual machine +will be called ``[-[-...]``. The following options influence the computation domain: diff --git a/doc/index.rst b/doc/index.rst index f072135bc..5054c1d91 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -30,6 +30,7 @@ If you're new to MP-SPDZ, consider the following: client-interface non-linear preprocessing + lowest-level add-protocol homomorphic-encryption troubleshooting diff --git a/doc/low-level.rst b/doc/low-level.rst index 89302d97e..34ae0ce08 100644 --- a/doc/low-level.rst +++ b/doc/low-level.rst @@ -1,7 +1,9 @@ +.. _low-level: + Low-Level Interface =================== -In the following we will explain the basic of the C++ interface by +In the following we will explain the basics of the C++ interface by walking trough :file:`Utils/paper-example.cpp`. .. default-domain:: cpp diff --git a/doc/lowest-level.rst b/doc/lowest-level.rst new file mode 100644 index 000000000..e3b29c213 --- /dev/null +++ b/doc/lowest-level.rst @@ -0,0 +1,149 @@ +.. _lowest-level: + +Lowest-Level Interface +---------------------- + +In the following, we will introduce the most protocol-independent +interfaces by walking through `Utils/protocol-tutorial +<../Utils/protocol-tutorial.cpp>`_. It implements the Rep3 +multiplication protocol independently of the usual protocol interface +for illustration purposes. + +.. default-domain:: cpp + +.. code-block:: cpp + + // set up networking on localhost + int my_number = atoi(argv[1]); + int port_base = 9999; + Names N(my_number, 3, "localhost", port_base); + CryptoPlayer P(N); + +This sets up pairwise encrypted connections as in :ref:`the low-level +example `. + +.. code-block:: cpp + + // correlated randomness for resharing + SeededPRNG G[2]; + +The protocol requires every pair of parties to have a common PRNG, so +we need two instances. We use :class:`SeededPRNG` to make sure to +never use an uninitialized one. + +.. code-block:: cpp + + // synchronize with other parties + octetStream os; + os.append(G[0].get_seed(), SEED_SIZE); + +:class:`octetStream` is generally used to serialize and aggregate +network communication. In this case, we use it to store the seed of +one of the PRNGs. + +.. code-block:: cpp + + P.pass_around(os, os, 1); + +:func:`Player::pass_around` allows simultaneous sending to the "next" party +and receiving from the "previous" party. We use this with the buffer +holding the seed. As we don't need the send buffer afterwards, we can +use the same buffer for receiving. + +.. code-block:: cpp + + G[1].SetSeed(os.consume(SEED_SIZE)); + +We seed the second PRNG using the received data. :func:`PRNG::SetSeed` +implicitly uses the required number of bits. + +.. code-block:: cpp + + // simplify code + typedef Z2<64> Z; + +In this example, we use integers modulo :math:`2^{64}`, but the +protocol also works for any modulus, so we could also use +:class:`gfp_`. + +.. code-block:: cpp + + // start with same shares on all parties for simplicity + // replicated secret sharing of 3 + Z a[2] = {1, 1}; + // and 6 + Z b[2] = {2, 2}; + +For every secret number in Rep3, every party holds a pair of numbers +in the domain such that every pair of parties has the same number. The +sum of the unique numbers is the secret. + +.. code-block:: cpp + + // compute an additive sharing of the product + Z cc = a[0] * (b[0] + b[1]) + a[1] * b[0]; + +In a first step, every party computes an additive share of the +product. See `Araki et al. `_ for +details. All domain classes support the standard operators. + +.. code-block:: cpp + + // result shares + Z c[2]; + + // re-randomize + c[0] = cc + G[0].get() - G[1].get(); + +Sending the computed additive secret sharing directly to another party +to get back to a replicative secret sharing would be +insecure. Therefore, we randomize it using random numbers from the two +PRNGs. + +.. code-block:: cpp + + // send and receive share + os.reset_write_head(); + c[0].pack(os); + P.pass_around(os, os, 1); + c[1].unpack(os); + +We clear the buffer, serialize our share, send it to the "next" party, +and receive one from the "previous" party. This concludes the +multiplication protocol. :func:`Z2::pack` and :func:`Z2::unpack` are +main methods for (de-)serialization. All domain classes support +this. You can use :func:`octetStream::output` to write the buffer to a +C++ output stream. + +.. code-block:: cpp + + // open value to party 0 + if (P.my_num() == 1) + { + os.reset_write_head(); + c[0].pack(os); + P.send_to(0, os); + } + +To allow party 0 to output the result, party 1 serializes one of their +shares and sends it to party 0. + +.. code-block:: cpp + + // output result on party 0, which should be 18 + if (P.my_num() == 0) + { + P.receive_player(1, os); + cout << "My shares: " << c[0] << ", " << c[1] << endl; + cout << "Result: " << (os.get() + c[0] + c[1]) << endl; + } + +Party 0 receives the missing share from party 1 and reconstructs the +secret by summing up. + +You can run the example as follows in the main directory: + +.. code-block:: sh + + make protocol-tutorial.x + for i in 0 1 2; do ./protocol-tutorial.x $i & true; done diff --git a/doc/machine-learning.rst b/doc/machine-learning.rst index d59ed343e..d873256bd 100644 --- a/doc/machine-learning.rst +++ b/doc/machine-learning.rst @@ -237,7 +237,8 @@ of linear regression. PyTorch interface ================= -MP-SPDZ supports importing sequential models from PyTorch as shown in +MP-SPDZ supports importing sequential models from PyTorch using +:py:func:`~Compiler.ml.layers_from_torch` as shown in this code snippet in ``torch_mnist_dense.mpc``:: import torch.nn as nn diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 1c6ff2eec..508936a96 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -45,6 +45,45 @@ resulting in potentially too much virtual machine code. Consider using version. +Cannot derive truth value from register +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This message appears when you try to use branching on run-time data +types, for example:: + + x = cint(0) + y = 0 + if x == 0: + y = 1 + print_ln('x is zero') + +There a number of ways to solve this: + +1. Use the ``--flow-optimization`` argument during compilation. +2. Use run-time branching:: + + x = cint(0) + y = cint(0) + @if_(x == 0) + def _(): + y.update(1) + print_ln('x is zero') + + See :py:func:`~Compiler.library.if_e` for the equivalent to + if/else. +3. Use conditional statements:: + + check = x == 0 + y = check.if_else(1, y) + print_ln_if(check, 'x is zero') + +If the condition is secret, for example, :py:obj:`x` is an +:py:class:`~Compiler.types.sint` and thus ``x == 0`` is secret too, +:py:func:`~Compiler.types.sint.if_else` is the only option because +branching would reveal the secret. For the same reason, +:py:func:`~Compiler.library.print_ln_if` doesn't work on secret values. + + Incorrect results when using :py:class:`~Compiler.types.sfix` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~