From 470b0758038835b18be7b18564fb0ed021b112be Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Thu, 21 Nov 2019 17:23:51 +1100 Subject: [PATCH] Python 3, semi-honest computation using semi-homomorphic encryption. --- BMR/RealGarbleWire.hpp | 4 +- BMR/RealProgramParty.hpp | 2 +- CHANGELOG.md | 6 + Compiler/GC/program.py | 2 +- Compiler/GC/types.py | 22 +- Compiler/__init__.py | 10 +- Compiler/allocator.py | 105 ++++----- Compiler/circuit_oram.py | 18 +- Compiler/comparison.py | 38 ++-- Compiler/compilerLib.py | 28 +-- Compiler/dijkstra.py | 18 +- Compiler/floatingpoint.py | 64 +++--- Compiler/graph.py | 18 +- Compiler/gs.py | 10 +- Compiler/instructions.py | 50 +++-- Compiler/instructions_base.py | 15 +- Compiler/library.py | 107 +++++---- Compiler/ml.py | 25 ++- Compiler/mpc_math.py | 4 +- Compiler/oram.py | 139 ++++++------ Compiler/path_oram.py | 19 +- Compiler/permutation.py | 60 ++--- Compiler/program.py | 141 ++++++------ Compiler/types.py | 206 +++++++++++------- Compiler/util.py | 52 ++++- ECDSA/EcdsaOptions.h | 66 ++++++ ECDSA/Fake-ECDSA.cpp | 2 +- ECDSA/P256Element.cpp | 8 + ECDSA/P256Element.h | 7 +- ECDSA/README.md | 3 +- ECDSA/fake-spdz-ecdsa-party.cpp | 7 +- ECDSA/hm-ecdsa-party.hpp | 18 +- ECDSA/ot-ecdsa-party.hpp | 77 ++++++- ECDSA/preprocessing.hpp | 54 +++-- ECDSA/rep-ecdsa-party.cpp | 6 +- ECDSA/sign.hpp | 14 +- ExternalIO/bankers-bonus-commsec-client.cpp | 1 + FHE/Ciphertext.h | 2 + FHEOffline/Multiplier.cpp | 17 +- FHEOffline/Multiplier.h | 4 +- GC/Processor.hpp | 2 + GC/SemiPrep.cpp | 4 +- GC/ShareParty.hpp | 2 +- GC/ShareSecret.hpp | 3 +- GC/ThreadMaster.hpp | 2 +- GC/TinyPrep.hpp | 9 +- Machines/OTMachine.cpp | 5 +- Machines/Player-Online.hpp | 2 +- Machines/TripleMachine.cpp | 30 +-- Machines/hemi-party.cpp | 29 +++ Makefile | 3 +- Math/Setup.cpp | 10 +- Math/Setup.h | 1 + Math/Square.cpp | 22 +- Math/Square.h | 2 + Math/Z2k.h | 3 +- Math/Z2k.hpp | 13 ++ Math/gfp.h | 2 + Math/mpn_fixed.h | 2 +- Networking/Player.cpp | 36 ++- Networking/Player.h | 5 +- Networking/ServerSocket.cpp | 96 +++----- OT/BaseOT.cpp | 4 +- OT/BaseOT.h | 9 +- OT/BitMatrix.cpp | 1 + OT/BitMatrix.h | 2 +- OT/MascotParams.cpp | 76 +------ OT/NPartyTripleGenerator.h | 30 ++- OT/NPartyTripleGenerator.hpp | 82 +++---- OT/OTExtension.cpp | 2 +- OT/OTExtension.h | 2 +- OT/OTExtensionWithMatrix.cpp | 48 +++- OT/OTExtensionWithMatrix.h | 7 +- OT/OTMultiplier.hpp | 85 +++++--- OT/OTTripleSetup.cpp | 19 ++ OT/OTTripleSetup.h | 12 +- OT/Triple.hpp | 17 +- OT/TripleMachine.h | 22 +- Processor/BaseMachine.h | 2 +- Processor/Data_Files.h | 2 + Processor/Input.hpp | 6 +- Processor/Instruction.h | 3 + Processor/Instruction.hpp | 19 +- Processor/IntInput.cpp | 14 -- Processor/IntInput.h | 3 +- Processor/IntInput.hpp | 15 ++ Processor/Machine.hpp | 4 +- Processor/Processor.h | 2 - Processor/Processor.hpp | 53 ++--- .../{ProcessorBase.cpp => ProcessorBase.hpp} | 9 +- Programs/Source/aes.mpc | 4 +- Programs/Source/bankers_bonus_commsec.mpc | 4 +- Programs/Source/gc_and.mpc | 2 +- Programs/Source/htmac.mpc | 8 +- Programs/Source/regression.mpc | 12 +- Programs/Source/test_sbitfix.mpc | 4 +- Programs/Source/vickrey.mpc | 4 +- Protocols/Beaver.h | 5 + Protocols/Beaver.hpp | 14 ++ Protocols/HemiPrep.h | 39 ++++ Protocols/HemiPrep.hpp | 87 ++++++++ Protocols/HemiShare.h | 40 ++++ Protocols/MAC_Check.h | 6 +- Protocols/MAC_Check.hpp | 15 +- Protocols/MAC_Check_Base.h | 4 +- Protocols/MaliciousRepMC.h | 7 + Protocols/MaliciousRepMC.hpp | 29 +++ Protocols/MaliciousShamirMC.h | 5 + Protocols/MaliciousShamirMC.hpp | 23 +- Protocols/MascotPrep.h | 4 +- Protocols/MascotPrep.hpp | 20 +- Protocols/Rep3Share.h | 5 + Protocols/Replicated.h | 15 +- Protocols/Replicated.hpp | 30 +++ Protocols/ReplicatedMC.h | 4 + Protocols/ReplicatedMC.hpp | 29 ++- Protocols/ReplicatedPrep.hpp | 19 +- Protocols/SemiMC.h | 6 +- Protocols/SemiMC.hpp | 18 +- Protocols/SemiShare.h | 5 + Protocols/Shamir.h | 8 + Protocols/Shamir.hpp | 25 +++ Protocols/ShamirMC.h | 8 + Protocols/ShamirMC.hpp | 32 ++- Protocols/ShamirShare.h | 5 + Protocols/Share.h | 3 + README.md | 15 +- Scripts/hemi.sh | 8 + Scripts/run-common.sh | 15 ++ Scripts/test_ecdsa.sh | 3 - Scripts/test_tutorial.sh | 8 +- Tools/BitVector.cpp | 10 + Tools/BitVector.h | 2 + Tools/MMO.cpp | 28 ++- Tools/MMO.h | 2 + Tools/time-func.cpp | 18 +- Tools/time-func.h | 11 +- compile.py | 2 +- 138 files changed, 1931 insertions(+), 1021 deletions(-) create mode 100644 ECDSA/EcdsaOptions.h create mode 100644 Machines/hemi-party.cpp delete mode 100644 Processor/IntInput.cpp create mode 100644 Processor/IntInput.hpp rename Processor/{ProcessorBase.cpp => ProcessorBase.hpp} (87%) create mode 100644 Protocols/HemiPrep.h create mode 100644 Protocols/HemiPrep.hpp create mode 100644 Protocols/HemiShare.h create mode 100755 Scripts/hemi.sh diff --git a/BMR/RealGarbleWire.hpp b/BMR/RealGarbleWire.hpp index 625a5a300..a62b1e705 100644 --- a/BMR/RealGarbleWire.hpp +++ b/BMR/RealGarbleWire.hpp @@ -133,7 +133,7 @@ void RealGarbleWire::input(party_id_t from, char input) protocol.init_mul(party.shared_proc); protocol.prepare_mul(mask, T(1, party.P->my_num(), party.mac_key) - mask); protocol.exchange(); - if (party.MC->POpen(protocol.finalize_mul(), *party.P) != 0) + if (party.MC->open(protocol.finalize_mul(), *party.P) != 0) throw runtime_error("input mask not a bit"); } #ifdef DEBUG_MASK @@ -168,7 +168,7 @@ void RealGarbleWire::output() auto& party = RealProgramParty::s(); assert(party.MC != 0); assert(party.P != 0); - auto m = party.MC->POpen(mask, *party.P); + auto m = party.MC->open(mask, *party.P); party.output_masks.push_back(m.get_bit(0)); party.taint(); #ifdef DEBUG_MASK diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 011552d6d..ff9aef2cf 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -90,7 +90,7 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : { mac_key.randomize(prng); if (T::needs_ot) - BaseMachine::s().ot_setups.push_back({{{*P, true}}}); + BaseMachine::s().ot_setups.push_back({*P, true}); prep = Preprocessing::get_live_prep(0, usage); } else diff --git a/CHANGELOG.md b/CHANGELOG.md index 17a3f3d08..30f217019 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.1.3 (Nov 21, 2019) + +- Python 3 +- Semi-honest computation based on semi-homomorphic encryption +- Access to player information in high-level language + ## 0.1.2 (Oct 11, 2019) - Machine learning capabilities used for [MobileNets inference](https://eprint.iacr.org/2019/131) and the iDASH submission diff --git a/Compiler/GC/program.py b/Compiler/GC/program.py index d0953f4da..c77d4281c 100644 --- a/Compiler/GC/program.py +++ b/Compiler/GC/program.py @@ -5,6 +5,6 @@ def __init__(self, progname): types.program = self instructions.program = self self.curr_tape = None - execfile(progname) + exec(compile(open(progname).read(), progname, 'exec')) def malloc(self, *args): pass diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index a015aa2b5..988f0a300 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -5,6 +5,7 @@ from Compiler import util, oram, floatingpoint, library import Compiler.GC.instructions as inst import operator +from functools import reduce class bits(Tape.Register, _structure): n = 40 @@ -82,7 +83,7 @@ def load_mem(cls, address, mem_type=None, size=None): cls.load_inst[util.is_constant(address)](res, address) return res def store_in_mem(self, address): - self.store_inst[isinstance(address, (int, long))](self, address) + self.store_inst[isinstance(address, int)](self, address) def __init__(self, value=None, n=None, size=None): if size != 1 and size is not None: raise Exception('invalid size for bit type: %s' % size) @@ -92,11 +93,11 @@ def __init__(self, value=None, n=None, size=None): self.load_other(value) def set_length(self, n): if n > self.max_length: - print self.max_length + print(self.max_length) raise Exception('too long: %d' % n) self.n = n def load_other(self, other): - if isinstance(other, (int, long)): + if isinstance(other, int): self.set_length(self.n or util.int_len(other)) self.load_int(other) elif isinstance(other, regint): @@ -115,6 +116,7 @@ def long_one(self): def __repr__(self): return '%s(%d/%d)' % \ (super(bits, self).__repr__(), self.n, type(self).n) + __str__ = __repr__ class cbits(bits): max_length = 64 @@ -219,13 +221,13 @@ def get_input_from(cls, player, n_bits=None): @classmethod def load_dynamic_mem(cls, address): res = cls() - if isinstance(address, (long, int)): + if isinstance(address, int): inst.ldmsd(res, address, cls.n) else: inst.ldmsdi(res, address, cls.n) return res def store_in_dynamic_mem(self, address): - if isinstance(address, (long, int)): + if isinstance(address, int): inst.stmsd(self, address) else: inst.stmsdi(self, cbits.conv(address)) @@ -322,7 +324,7 @@ def mul_int(self, other): mul_bits = [self if b else zero for b in bits] return self.bit_compose(mul_bits) else: - print self.n, other + print(self.n, other) return NotImplemented def __lshift__(self, i): return self.bit_compose([sbit(0)] * i + self.bit_decompose()[:self.max_length-i]) @@ -478,7 +480,7 @@ def __init__(self, value, start, lengths, entries_per_block): self.start_demux = oram.demux_list(self.start_bits) self.entries = [sbits.bit_compose(self.value_bits[i*length:][:length]) \ for i in range(entries_per_block)] - self.mul_entries = map(operator.mul, self.start_demux, self.entries) + self.mul_entries = list(map(operator.mul, self.start_demux, self.entries)) self.bits = sum(self.mul_entries).bit_decompose() self.mul_value = sbits.compose(self.mul_entries, sum(self.lengths)) self.anti_value = self.mul_value + self.value @@ -662,6 +664,12 @@ def __mul__(self, other): return super(sbitfix, self).__mul__(other) __rxor__ = __xor__ __rmul__ = __mul__ + @staticmethod + def multipliable(other, k, f): + class cls(_fix): + int_type = sbitint.get_type(k) + cls.set_precision(f, k) + return cls._new(cls.int_type(other), k, f) sbitfix.set_precision(20, 41) diff --git a/Compiler/__init__.py b/Compiler/__init__.py index 1f52e36fd..9a22da461 100644 --- a/Compiler/__init__.py +++ b/Compiler/__init__.py @@ -1,15 +1,15 @@ -import compilerLib, program, instructions, types, library, floatingpoint -import GC.types +from . import compilerLib, program, instructions, types, library, floatingpoint +from .GC import types as GC_types import inspect -from config import * -from compilerLib import run +from .config import * +from .compilerLib import run # add all instructions to the program VARS dictionary compilerLib.VARS = {} instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)] -for mod in (types, GC.types): +for mod in (types, GC_types): instr_classes += [t[1] for t in inspect.getmembers(mod, inspect.isclass)\ if t[1].__module__ == mod.__name__] diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 2ed6016b7..f6d1f0e0f 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -10,16 +10,17 @@ import heapq, itertools import operator import sys +from functools import reduce class StraightlineAllocator: """Allocate variables in a straightline program using n registers. It is based on the precondition that every register is only defined once.""" def __init__(self, n): - self.alloc = {} + self.alloc = dict_by_id() self.usage = Compiler.program.RegType.create_dict(lambda: 0) - self.defined = {} - self.dealloc = set() + self.defined = dict_by_id() + self.dealloc = set_by_id() self.n = n def alloc_reg(self, reg, free): @@ -77,8 +78,8 @@ def process(self, program, alloc_pool): unused_regs.append(j) if unused_regs and len(unused_regs) == len(list(i.get_def())): # only report if all assigned registers are unused - print "Register(s) %s never used, assigned by '%s' in %s" % \ - (unused_regs,i,format_trace(i.caller)) + print("Register(s) %s never used, assigned by '%s' in %s" % \ + (unused_regs,i,format_trace(i.caller))) for j in i.get_used(): self.alloc_reg(j, alloc_pool) @@ -86,7 +87,7 @@ def process(self, program, alloc_pool): self.dealloc_reg(j, i, alloc_pool) if k % 1000000 == 0 and k > 0: - print "Allocated registers for %d instructions at" % k, time.asctime() + print("Allocated registers for %d instructions at" % k, time.asctime()) # print "Successfully allocated registers" # print "modp usage: %d clear, %d secret" % \ @@ -97,8 +98,8 @@ def process(self, program, alloc_pool): def determine_scope(block, options): - last_def = defaultdict(lambda: -1) - used_from_scope = set() + last_def = defaultdict_by_id(lambda: -1) + used_from_scope = set_by_id() def find_in_scope(reg, scope): while True: @@ -114,18 +115,18 @@ def read(reg, n): used_from_scope.add(reg) reg.can_eliminate = False else: - print 'Warning: read before write at register', reg - print '\tline %d: %s' % (n, instr) - print '\tinstruction trace: %s' % format_trace(instr.caller, '\t\t') - print '\tregister trace: %s' % format_trace(reg.caller, '\t\t') + print('Warning: read before write at register', reg) + print('\tline %d: %s' % (n, instr)) + print('\tinstruction trace: %s' % format_trace(instr.caller, '\t\t')) + print('\tregister trace: %s' % format_trace(reg.caller, '\t\t')) if options.stop: sys.exit(1) def write(reg, n): if last_def[reg] != -1: - print 'Warning: double write at register', reg - print '\tline %d: %s' % (n, instr) - print '\ttrace: %s' % format_trace(instr.caller, '\t\t') + print('Warning: double write at register', reg) + print('\tline %d: %s' % (n, instr)) + print('\ttrace: %s' % format_trace(instr.caller, '\t\t')) if options.stop: sys.exit(1) last_def[reg] = n @@ -146,7 +147,7 @@ def write(reg, n): write(reg, n) block.used_from_scope = used_from_scope - block.defined_registers = set(last_def.iterkeys()) + block.defined_registers = set_by_id(last_def.keys()) class Merger: def __init__(self, block, options, merge_classes): @@ -178,7 +179,7 @@ def expand_vector_args(inst): if inst.is_vec(): for arg in inst.args: arg.create_vector_elements() - res = sum(zip(*inst.args), ()) + res = sum(list(zip(*inst.args)), ()) return list(res) else: return inst.args @@ -241,7 +242,7 @@ def merge_inputs(self): remaining_input_nodes = [] def do_merge(nodes): if len(nodes) > 1000: - print 'Merging %d inputs...' % len(nodes) + print('Merging %d inputs...' % len(nodes)) self.do_merge(iter(nodes)) for n in self.input_nodes: inst = self.instructions[n] @@ -252,7 +253,7 @@ def do_merge(nodes): if len(merge) >= self.max_parallel_open: do_merge(merge) merge[:] = [] - for merge in reversed(sorted(merges.itervalues())): + for merge in reversed(sorted(merges.values())): if merge: do_merge(merge) self.input_nodes = remaining_input_nodes @@ -266,7 +267,7 @@ def compute_preorder(self, merges, rev_depth_of): instructions = self.instructions flex_nodes = defaultdict(dict) starters = [] - for n in xrange(len(G)): + for n in range(len(G)): if n not in merge_nodes_set and \ depth_of[n] != rev_depth_of[n] and G[n] and G.get_attr(n,'start') == -1 and not isinstance(instructions[n], AsymmetricCommunicationInstruction): #print n, depth_of[n], rev_depth_of[n] @@ -275,19 +276,19 @@ def compute_preorder(self, merges, rev_depth_of): not isinstance(self.instructions[n], RawInputInstruction): starters.append(n) if n % 10000000 == 0 and n > 0: - print "Processed %d nodes at" % n, time.asctime() + print("Processed %d nodes at" % n, time.asctime()) inputs = defaultdict(list) for node in self.input_nodes: player = self.instructions[node].args[0] inputs[player].append(node) - first_inputs = [l[0] for l in inputs.itervalues()] + first_inputs = [l[0] for l in inputs.values()] other_inputs = [] i = 0 while True: i += 1 found = False - for l in inputs.itervalues(): + for l in inputs.values(): if i < len(l): other_inputs.append(l[i]) found = True @@ -299,20 +300,20 @@ def compute_preorder(self, merges, rev_depth_of): # magical preorder for topological search max_depth = max(merges) if max_depth > 10000: - print "Computing pre-ordering ..." - for i in xrange(max_depth, 0, -1): + print("Computing pre-ordering ...") + for i in range(max_depth, 0, -1): preorder.append(G.get_attr(merges[i], 'stop')) - for j in flex_nodes[i-1].itervalues(): + for j in flex_nodes[i-1].values(): preorder.extend(j) preorder.extend(flex_nodes[0].get(i, [])) preorder.append(merges[i]) if i % 100000 == 0 and i > 0: - print "Done level %d at" % i, time.asctime() + print("Done level %d at" % i, time.asctime()) preorder.extend(other_inputs) preorder.extend(starters) preorder.extend(first_inputs) if max_depth > 10000: - print "Done at", time.asctime() + print("Done at", time.asctime()) return preorder def longest_paths_merge(self): @@ -343,8 +344,8 @@ def longest_paths_merge(self): t = type(self.instructions[merge[0]]) self.counter[t] += len(merge) if len(merge) > 1000: - print 'Merging %d %s in round %d/%d' % \ - (len(merge), t.__name__, i, len(merges)) + print('Merging %d %s in round %d/%d' % \ + (len(merge), t.__name__, i, len(merges))) self.do_merge(merge) self.merge_inputs() @@ -352,11 +353,11 @@ def longest_paths_merge(self): preorder = None if len(instructions) > 100000: - print "Topological sort ..." + print("Topological sort ...") order = Compiler.graph.topological_sort(G, preorder) instructions[:] = [instructions[i] for i in order if instructions[i] is not None] if len(instructions) > 100000: - print "Done at", time.asctime() + print("Done at", time.asctime()) return len(merges) @@ -377,7 +378,7 @@ def dependency_graph(self, merge_classes): self.G = G reg_nodes = {} - last_def = defaultdict(lambda: -1) + last_def = defaultdict_by_id(lambda: -1) last_mem_write = [] last_mem_read = [] warned_about_mem = [] @@ -411,8 +412,8 @@ def write(reg, n): def handle_mem_access(addr, reg_type, last_access_this_kind, last_access_other_kind): - this = last_access_this_kind[addr,reg_type] - other = last_access_other_kind[addr,reg_type] + this = last_access_this_kind[str(addr),reg_type] + other = last_access_other_kind[str(addr),reg_type] if this and other: if this[-1] < other[0]: del this[:] @@ -429,15 +430,15 @@ def mem_access(n, instr, last_access_this_kind, last_access_other_kind): handle_mem_access(addr_i, reg_type, last_access_this_kind, last_access_other_kind) if not warned_about_mem and (instr.get_size() > 100): - print 'WARNING: Order of memory instructions ' \ - 'not preserved due to long vector, errors possible' + print('WARNING: Order of memory instructions ' \ + 'not preserved due to long vector, errors possible') warned_about_mem.append(True) else: handle_mem_access(addr, reg_type, last_access_this_kind, last_access_other_kind) if not warned_about_mem and not isinstance(instr, DirectMemoryInstruction): - print 'WARNING: Order of memory instructions ' \ - 'not preserved, errors possible' + print('WARNING: Order of memory instructions ' \ + 'not preserved, errors possible') # hack warned_about_mem.append(True) @@ -553,11 +554,11 @@ def keep_order(instr, n, t, arg_index=None): self.sources.append(n) if n % 100000 == 0 and n > 0: - print "Processed dependency of %d/%d instructions at" % \ - (n, len(block.instructions)), time.asctime() + print("Processed dependency of %d/%d instructions at" % \ + (n, len(block.instructions)), time.asctime()) if len(open_nodes) > 1000: - print "Program has %d %s instructions" % (len(open_nodes), merge_classes) + print("Program has %d %s instructions" % (len(open_nodes), merge_classes)) def merge_nodes(self, i, j): """ Merge node j into i, removing node j """ @@ -566,8 +567,8 @@ def merge_nodes(self, i, j): G.remove_edge(i, j) if i in G[j]: G.remove_edge(j, i) - G.add_edges_from(zip(itertools.cycle([i]), G[j], [G.weights[(j,k)] for k in G[j]])) - G.add_edges_from(zip(G.pred[j], itertools.cycle([i]), [G.weights[(k,j)] for k in G.pred[j]])) + G.add_edges_from(list(zip(itertools.cycle([i]), G[j], [G.weights[(j,k)] for k in G[j]]))) + G.add_edges_from(list(zip(G.pred[j], itertools.cycle([i]), [G.weights[(k,j)] for k in G.pred[j]]))) G.get_attr(i, 'merges').append(j) G.remove_node(j) @@ -578,7 +579,7 @@ def eliminate_dead_code(self): count = 0 open_count = 0 stats = defaultdict(lambda: 0) - for i,inst in zip(xrange(len(instructions) - 1, -1, -1), reversed(instructions)): + for i,inst in zip(range(len(instructions) - 1, -1, -1), reversed(instructions)): # remove if instruction has result that isn't used unused_result = not G.degree(i) and len(list(inst.get_def())) \ and reduce(operator.and_, (reg.can_eliminate for reg in inst.get_def())) \ @@ -608,21 +609,21 @@ def eliminate(i): eliminate(i) count += 2 if count > 0: - print 'Eliminated %d dead instructions, among which %d opens: %s' \ - % (count, open_count, dict(stats)) + print('Eliminated %d dead instructions, among which %d opens: %s' \ + % (count, open_count, dict(stats))) def print_graph(self, filename): f = open(filename, 'w') - print >>f, 'digraph G {' + print('digraph G {', file=f) for i in range(self.G.n): for j in self.G[i]: - print >>f, '"%d: %s" -> "%d: %s";' % \ - (i, self.instructions[i], j, self.instructions[j]) - print >>f, '}' + print('"%d: %s" -> "%d: %s";' % \ + (i, self.instructions[i], j, self.instructions[j]), file=f) + print('}', file=f) f.close() def print_depth(self, filename): f = open(filename, 'w') for i in range(self.G.n): - print >>f, '%d: %s' % (self.depths[i], self.instructions[i]) + print('%d: %s' % (self.depths[i], self.instructions[i]), file=f) f.close() diff --git a/Compiler/circuit_oram.py b/Compiler/circuit_oram.py index 3f2539c99..30c0d7b85 100644 --- a/Compiler/circuit_oram.py +++ b/Compiler/circuit_oram.py @@ -26,7 +26,7 @@ def find_deeper(a, b, path, start, length, compute_level=True): any_empty = OR(a.empty, b.empty) a_diff = [XOR(a_bits[i], path_bits[i]) for i in range(start, length)] b_diff = [XOR(b_bits[i], path_bits[i]) for i in range(start, length)] - diff = [XOR(ab, bb) for ab,bb in zip(a_bits, b_bits)[start:length]] + diff = [XOR(ab, bb) for ab,bb in list(zip(a_bits, b_bits))[start:length]] diff_preor = type(a.value).PreOR([any_empty] + diff) diff_first = [x - y for x,y in zip(diff_preor, diff_preor[1:])] winner = sum((ad * df for ad,df in zip(a_diff, diff_first)), a.empty) @@ -38,7 +38,7 @@ def find_deeper(a, b, path, start, length, compute_level=True): def find_deepest(paths, search_path, start, length, compute_level=True): if len(paths) == 1: return None, paths[0], 1 - l = len(paths) / 2 + l = len(paths) // 2 _, a, a_index = find_deepest(paths[:l], search_path, start, length, False) _, b, b_index = find_deepest(paths[l:], search_path, start, length, False) level, winner = find_deeper(a, b, search_path, start, length, compute_level) @@ -57,7 +57,7 @@ def greater_unary(a, b): if len(a) == 1: return a[0], b[0] else: - l = len(a) / 2 + l = len(a) // 2 return gu_step(greater_unary(a[l:], b[l:]), greater_unary(a[:l], b[:l])) def comp_step(high, low): @@ -75,7 +75,7 @@ def comp_binary(a, b): if len(a) == 1: return a[0], b[0] else: - l = len(a) / 2 + l = len(a) // 2 return comp_step(comp_binary(a[l:], b[l:]), comp_binary(a[:l], b[:l])) def unary_to_binary(l): @@ -89,8 +89,8 @@ def __init__(self, size, value_type=sgf2n, value_length=1, entry_size=None, \ self.D = log2(size) self.logD = log2(self.D) self.L = self.D + 1 - print 'create oram of size %d with depth %d and %d buckets' \ - % (size, self.D, self.n_buckets()) + print('create oram of size %d with depth %d and %d buckets' \ + % (size, self.D, self.n_buckets())) self.value_type = value_type self.index_type = value_type.get_type(self.D) if entry_size is not None: @@ -245,7 +245,7 @@ def recursive_evict(self): for i,_ in enumerate(self.recursive_evict_rounds()): Program.prog.curr_tape.start_new_basicblock(name='circuit-recursive-evict-round-%d-%d' % (i, self.size)) def recursive_evict_rounds(self): - for _ in itertools.izip(self.evict_rounds(), self.index.l.recursive_evict_rounds()): + for _ in zip(self.evict_rounds(), self.index.l.recursive_evict_rounds()): yield def bucket_indices_on_path_to(self, leaf): # root is at 1, different to PathORAM @@ -272,10 +272,10 @@ class DebugCircuitORAM(CircuitORAM): def OptimalCircuitORAM(size, value_type, *args, **kwargs): if size <= threshold: - print size, 'below threshold', threshold + print(size, 'below threshold', threshold) return LinearORAM(size, value_type, *args, **kwargs) else: - print size, 'above threshold', threshold + print(size, 'above threshold', threshold) return RecursiveCircuitORAM(size, value_type, *args, **kwargs) class RecursiveCircuitIndexStructure(PackedIndexStructure): diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 80d54c6ae..e9cb21dac 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -28,8 +28,8 @@ # (r[i], r[i]^-1, r[i] * r[i-1]^-1) do_precomp = True -import instructions_base -import util +from . import instructions_base +from . import util def set_variant(options): """ Set flags based on the command-line option provided """ @@ -55,7 +55,7 @@ def ld2i(c, n): """ Load immediate 2^n into clear GF(p) register c """ t1 = program.curr_block.new_reg('c') ldi(t1, 2 ** (n % 30)) - for i in range(n / 30): + for i in range(n // 30): t2 = program.curr_block.new_reg('c') mulci(t2, t1, 2 ** 30) t1 = t2 @@ -75,13 +75,13 @@ def LTZ(s, a, k, kappa): k: bit length of a """ - from types import sint + from .types import sint t = sint() Trunc(t, a, k, k - 1, kappa, True) subsfi(s, t, 0) def LessThanZero(a, k, kappa): - import types + from . import types res = types.sint() LTZ(res, a, k, kappa) return res @@ -124,7 +124,7 @@ def TruncZeroes(a, k, m, signed): if program.options.ring: return TruncLeakyInRing(a, k, m, signed) else: - import types + from . import types tmp = types.cint() inv2m(tmp, m) return a * tmp @@ -136,7 +136,7 @@ def TruncLeakyInRing(a, k, m, signed): """ assert k > m assert int(program.options.ring) >= k - from types import sint, intbitint, cint, cgf2n + from .types import sint, intbitint, cint, cgf2n n_bits = k - m n_shift = int(program.options.ring) - n_bits r_bits = [sint.get_random_bit() for i in range(n_bits)] @@ -165,7 +165,7 @@ def TruncRoundNearest(a, k, m, kappa, signed=False): # cannot work with bit length k+1 tmp = TruncRing(None, a, k, m - 1, signed) return TruncRing(None, tmp + 1, k - m + 1, 1, signed) - from types import sint + from .types import sint res = sint() Trunc(res, a + (1 << (m - 1)), k + 1, m, kappa, signed) return res @@ -277,7 +277,7 @@ def BitLTC1(u, a, b, kappa): """ k = len(b) p = [program.curr_block.new_reg('s') for i in range(k)] - import floatingpoint + from . import floatingpoint a_bits = floatingpoint.bits(a, k) if instructions_base.get_global_vector_size() == 1: a_ = a_bits @@ -357,12 +357,12 @@ def CarryOutAux(d, a, kappa): if k > 1 and k % 2 == 1: a.append(None) k += 1 - u = [None]*(k/2) + u = [None]*(k//2) a = a[::-1] if k > 1: - for i in range(k/2): - u[i] = carry(a[2*i+1], a[2*i], i != k/2-1) - CarryOutAux(d, u[:k/2][::-1], kappa) + for i in range(k//2): + u[i] = carry(a[2*i+1], a[2*i], i != k//2-1) + CarryOutAux(d, u[:k//2][::-1], kappa) else: movs(d, a[0][1]) @@ -376,7 +376,7 @@ def CarryOut(res, a, b, c=0, kappa=None): c: initial carry-in bit """ k = len(a) - import types + from . import types d = [program.curr_block.new_reg('s') for i in range(k)] t = [[types.sint() for i in range(k)] for i in range(4)] s = [program.curr_block.new_reg('s') for i in range(3)] @@ -394,7 +394,7 @@ def CarryOut(res, a, b, c=0, kappa=None): def CarryOutLE(a, b, c=0): """ Little-endian version """ - import types + from . import types res = types.sint() CarryOut(res, a[::-1], b[::-1], c) return res @@ -407,7 +407,7 @@ def BitLTL(res, a, b, kappa): b: array of secret bits (same length as a) """ k = len(b) - import floatingpoint + from . import floatingpoint a_bits = floatingpoint.bits(a, k) s = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)] t = [program.curr_block.new_reg('s') for i in range(1)] @@ -547,7 +547,7 @@ def KMulC(a): """ Return just the product of all items in a """ - from types import sint, cint + from .types import sint, cint p = sint() if use_inv: PreMulC_with_inverses(p, a) @@ -582,7 +582,7 @@ def Mod2(a_0, a, k, kappa, signed): adds(t[2], t[0], t[1]) adds(t[3], t[2], r_prime) asm_open(c, t[3]) - import floatingpoint + from . import floatingpoint c_0 = floatingpoint.bits(c, 1)[0] mulci(tc, c_0, 2) mulm(t[4], r_0, tc) @@ -591,4 +591,4 @@ def Mod2(a_0, a, k, kappa, signed): # hack for circular dependency -from instructions import * +from .instructions import * diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index a2773a64a..7a9fcbe9a 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -1,8 +1,8 @@ from Compiler.program import Program from Compiler.config import * from Compiler.exceptions import * -import instructions, instructions_base, types, comparison, library -import GC.types +from . import instructions, instructions_base, types, comparison, library +from .GC import types as GC_types import random import time @@ -25,36 +25,36 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \ prog.DEBUG = debug VARS['program'] = prog if options.binary: - VARS['sint'] = GC.types.sbitint.get_type(int(options.binary)) - VARS['sfix'] = GC.types.sbitfix + VARS['sint'] = GC_types.sbitint.get_type(int(options.binary)) + VARS['sfix'] = GC_types.sbitfix comparison.set_variant(options) - print 'Compiling file', prog.infile + print('Compiling file', prog.infile) # no longer needed, but may want to support assembly in future (?) if assemblymode: prog.restart_main_thread() - for i in xrange(INIT_REG_MAX): + for i in range(INIT_REG_MAX): VARS['c%d'%i] = prog.curr_block.new_reg('c') VARS['s%d'%i] = prog.curr_block.new_reg('s') VARS['cg%d'%i] = prog.curr_block.new_reg('cg') VARS['sg%d'%i] = prog.curr_block.new_reg('sg') if i % 10000000 == 0 and i > 0: - print "Initialized %d register variables at" % i, time.asctime() + print("Initialized %d register variables at" % i, time.asctime()) # first pass determines how many assembler registers are used prog.FIRST_PASS = True - execfile(prog.infile, VARS) + exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS) if instructions_base.Instruction.count != 0: - print 'instructions count', instructions_base.Instruction.count + print('instructions count', instructions_base.Instruction.count) instructions_base.Instruction.count = 0 prog.FIRST_PASS = False prog.reset_values() # make compiler modules directly accessible sys.path.insert(0, 'Compiler') # create the tapes - execfile(prog.infile, VARS) + exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS) # optimize the tapes for tape in prog.tapes: @@ -66,14 +66,14 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \ sharedmem = list(prog.mem_s) prog.emulate() if prog.mem_c != clearmem or prog.mem_s != sharedmem: - print 'Warning: emulated memory values changed after compiler optimization' + print('Warning: emulated memory values changed after compiler optimization') # raise CompilerError('Compiler optimization caused incorrect memory write.') if prog.main_thread_running: prog.update_req(prog.curr_tape) - print 'Program requires:', repr(prog.req_num) - print 'Cost:', 0 if prog.req_num is None else prog.req_num.cost() - print 'Memory size:', dict(prog.allocated_mem) + print('Program requires:', repr(prog.req_num)) + print('Cost:', 0 if prog.req_num is None else prog.req_num.cost()) + print('Memory size:', dict(prog.allocated_mem)) # finalize the memory prog.finalize_memory() diff --git a/Compiler/dijkstra.py b/Compiler/dijkstra.py index 1011cd589..61a652477 100644 --- a/Compiler/dijkstra.py +++ b/Compiler/dijkstra.py @@ -86,8 +86,8 @@ def __init__(self, max_size, oram_type=ORAM, init_rounds=-1, int_type=sint): self.size = MemValue(int_type(0)) self.int_type = int_type self.basic_type = basic_type - print 'heap: %d levels, depth %d, size %d, index size %d' % \ - (self.levels, self.depth, self.heap.oram.size, self.value_index.size) + print('heap: %d levels, depth %d, size %d, index size %d' % \ + (self.levels, self.depth, self.heap.oram.size, self.value_index.size)) def update(self, value, prio, for_real=True): self._update(self.basic_type.hard_conv(value), \ self.basic_type.hard_conv(prio), \ @@ -217,7 +217,7 @@ def dump(self, msg=''): def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint): basic_type = int_type.basic_type - vert_loops = n_loops * e_index.size / edges.size \ + vert_loops = n_loops * e_index.size // edges.size \ if n_loops else -1 dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \ init_rounds=vert_loops, value_type=basic_type) @@ -287,7 +287,7 @@ def test_dijkstra(G, source, oram_type=ORAM, n_loops=None, int_type=sint): cint(i).print_reg('edge') time() edges[i] = edges_list[i] - vert_loops = n_loops * e_index.size / edges.size \ + vert_loops = n_loops * e_index.size // edges.size \ if n_loops else e_index.size for i in range(vert_loops): cint(i).print_reg('vert') @@ -307,7 +307,7 @@ def f(i): time() neighbour = ((i >> 1) + 2 * (i % 2) - 1 + n) % n edges[i] = (neighbour, 1, i % 2) - vert_loops = n_loops * e_index.size / edges.size \ + vert_loops = n_loops * e_index.size // edges.size \ if n_loops else e_index.size @for_range(vert_loops) def f(i): @@ -390,14 +390,14 @@ def __repr__(self): class Vector(object): """ Works like a vector. """ def __add__(self, other): - print 'add', type(self) + print('add', type(self)) res = type(self)(len(self)) @for_range(len(self)) def f(i): res[i] = self[i] + other[i] return res def __sub__(self, other): - print 'sub', type(other) + print('sub', type(other)) res = type(other)(len(self)) @for_range(len(self)) def f(i): @@ -412,7 +412,7 @@ def f(i): res[0] += self[i] * other[i] return res[0] else: - print 'mul', type(self) + print('mul', type(self)) res = type(self)(len(self)) @for_range_parallel(1024, len(self)) def f(i): @@ -477,7 +477,7 @@ def binarymin(A): if len(A) == 1: return [1], A[0] else: - half = len(A) / 2 + half = len(A) // 2 A_prime = VectorArray(half) B = IntVectorArray(half) i = IntVectorArray(len(A)) diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index e673fe616..2a51ebd46 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -1,9 +1,9 @@ from math import log, floor, ceil from Compiler.instructions import * -import types -import comparison -import program -import util +from . import types +from . import comparison +from . import program +from . import util ## ## Helper functions for floating point arithmetic @@ -16,7 +16,7 @@ def two_power(n): else: max = types.cint(1) << 31 res = 2**(n%31) - for i in range(n / 31): + for i in range(n // 31): res *= max return res @@ -25,7 +25,7 @@ def shift_two(n, pos): return n >> pos else: res = (n >> (pos%63)) - for i in range(pos / 63): + for i in range(pos // 63): res >>= 63 return res @@ -139,7 +139,7 @@ def PreOpL(op, items): kmax = 2**logk output = list(items) for i in range(logk): - for j in range(kmax/(2**(i+1))): + for j in range(kmax//(2**(i+1))): y = two_power(i) + j*two_power(i+1) - 1 for z in range(1, 2**i+1): if y+z < k: @@ -153,7 +153,7 @@ def PreOpL2(op, items): op must be a binary function that outputs a new register """ k = len(items) - half = k / 2 + half = k // 2 output = list(items) if k == 0: return [] @@ -161,7 +161,7 @@ def PreOpL2(op, items): v = PreOpL2(op, u) for i in range(half): output[2 * i + 1] = v[i] - for i in range(1, (k + 1) / 2): + for i in range(1, (k + 1) // 2): output[2 * i] = op(v[i - 1], items[2 * i]) return output @@ -185,8 +185,8 @@ def KOpL(op, a): if k == 1: return a[0] else: - t1 = KOpL(op, a[:k/2]) - t2 = KOpL(op, a[k/2:]) + t1 = KOpL(op, a[:k//2]) + t2 = KOpL(op, a[k//2:]) return op(t1, t2) def KORL(a, kappa): @@ -195,8 +195,8 @@ def KORL(a, kappa): if k == 1: return a[0] else: - t1 = KORL(a[:k/2], kappa) - t2 = KORL(a[k/2:], kappa) + t1 = KORL(a[:k//2], kappa) + t2 = KORL(a[k//2:], kappa) return t1 + t2 - t1*t2 def KORC(a, kappa): @@ -234,7 +234,7 @@ def BitAdd(a, b, bits_to_compute=None): bits s[0], ... , s[k] """ k = len(a) if not bits_to_compute: - bits_to_compute = range(k) + bits_to_compute = list(range(k)) d = [None] * k for i in range(1,k): #assert(a[i].value == 0 or a[i].value == 1) @@ -248,25 +248,25 @@ def BitAdd(a, b, bits_to_compute=None): # (for testing) def print_state(): - print 'a: ', + print('a: ', end=' ') for i in range(k): - print '%d ' % a[i].value, - print '\nb: ', + print('%d ' % a[i].value, end=' ') + print('\nb: ', end=' ') for i in range(k): - print '%d ' % b[i].value, - print '\nd: ', + print('%d ' % b[i].value, end=' ') + print('\nd: ', end=' ') for i in range(k): - print '%d ' % d[i][0].value, - print '\n ', + print('%d ' % d[i][0].value, end=' ') + print('\n ', end=' ') for i in range(k): - print '%d ' % d[i][1].value, - print '\n\npg:', + print('%d ' % d[i][1].value, end=' ') + print('\n\npg:', end=' ') for i in range(k): - print '%d ' % pg[i][0].value, - print '\n ', + print('%d ' % pg[i][0].value, end=' ') + print('\n ', end=' ') for i in range(k): - print '%d ' % pg[i][1].value, - print '' + print('%d ' % pg[i][1].value, end=' ') + print('') for bit in c: pass#assert(bit.value == 0 or bit.value == 1) @@ -281,7 +281,7 @@ def print_state(): try: pass#assert(s[i].value == 0 or s[i].value == 1) except AssertionError: - print '#assertion failed in BitAdd for s[%d]' % i + print('#assertion failed in BitAdd for s[%d]' % i) print_state() s[k] = c[k-1] #print_state() @@ -316,9 +316,9 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None): try: pass#assert(c.value == (2**(k + kappa) + 2**k + (a.value%2**k) - rval) % comparison.program.P) except AssertionError: - print 'BitDec assertion failed' - print 'a =', a.value - print 'a mod 2^%d =' % k, (a.value % 2**k) + print('BitDec assertion failed') + print('a =', a.value) + print('a mod 2^%d =' % k, (a.value % 2**k)) return types.intbitint.bit_adder(list(bits(c,m)), r) @@ -503,7 +503,7 @@ def TruncPrRing(a, k, m, signed=True): a += types.sint.get_random_bit() << i return comparison.TruncLeakyInRing(a, k, m, signed=signed) else: - from types import sint + from .types import sint if signed: a += (1 << (k - 1)) if program.Program.prog.use_trunc_pr: diff --git a/Compiler/graph.py b/Compiler/graph.py index 9ce87be72..43bea6ba5 100644 --- a/Compiler/graph.py +++ b/Compiler/graph.py @@ -19,10 +19,10 @@ def __init__(self, max_nodes, default_attributes=None): if default_attributes is None: default_attributes = { 'merges': None, 'stop': -1, 'start': -1 } self.default_attributes = default_attributes - self.attribute_pos = dict(zip(default_attributes.keys(), range(len(default_attributes)))) + self.attribute_pos = dict(list(zip(list(default_attributes.keys()), list(range(len(default_attributes)))))) self.n = max_nodes # each node contains list of default attributes, followed by outoing edges - self.nodes = [self.default_attributes.values() for i in range(self.n)] + self.nodes = [list(self.default_attributes.values()) for i in range(self.n)] self.succ = [set() for i in range(self.n)] self.pred = [set() for i in range(self.n)] self.weights = {} @@ -45,7 +45,7 @@ def add_node(self, i, **attr): raise CompilerError('Cannot add node %d to graph of size %d' % (i, self.n)) node = self.nodes[i] - for a,value in attr.items(): + for a,value in list(attr.items()): if a in self.default_attributes: node[self.attribute_pos[a]] = value else: @@ -72,7 +72,7 @@ def remove_node(self, i): #del self.weights[(v,i)] #self.nodes[v].remove(i) self.pred[i] = [] - self.nodes[i] = self.default_attributes.values() + self.nodes[i] = list(self.default_attributes.values()) def add_edge(self, i, j, weight=1): if j not in self[i]: @@ -111,7 +111,7 @@ def get_children(node): return G[node] else: def get_children(node): - if pref.has_key(node): + if node in pref: pref_set = set(pref[node]) for i in G[node]: if i not in pref_set: @@ -123,7 +123,7 @@ def get_children(node): yield i if nbunch is None: - nbunch = reversed(range(len(G))) + nbunch = reversed(list(range(len(G)))) for v in nbunch: # process all vertices in G if v in explored: continue @@ -170,8 +170,8 @@ def reverse_dag_shortest_paths(G, source): dist[source] = 0 for u in top_order: if u ==68273: - print 'dist[68273]', dist[u] - print 'pred[u]', G.pred[u] + print('dist[68273]', dist[u]) + print('pred[u]', G.pred[u]) if dist[u] is None: continue for v in G.pred[u]: @@ -207,7 +207,7 @@ def longest_paths(G, sources=None): G.weights[edge] = -G.weights[edge] dist = {} for source in sources: - print ('%s, ' % source), + print(('%s, ' % source), end=' ') dist[source] = dag_shortest_paths(G, source) #dist = johnson(G, sources) # reset weights diff --git a/Compiler/gs.py b/Compiler/gs.py index b7ad8ee6c..ab3b958f2 100644 --- a/Compiler/gs.py +++ b/Compiler/gs.py @@ -5,8 +5,8 @@ from Compiler.util import * -from oram import OptimalORAM,LinearORAM,RecursiveORAM,TrivialORAM,Entry -from library import for_range,do_while,time,start_timer,stop_timer,if_,print_ln,crash,print_str +from .oram import OptimalORAM,LinearORAM,RecursiveORAM,TrivialORAM,Entry +from .library import for_range,do_while,time,start_timer,stop_timer,if_,print_ln,crash,print_str class OMatrixRow(object): def __init__(self, oram, base, add_type): @@ -27,7 +27,7 @@ def read(self, offset): class OMatrix: def __init__(self, N, M=None, oram_type=OptimalORAM, int_type=types.sint): - print 'matrix', oram_type + print('matrix', oram_type) self.N = N self.M = M or N self.oram = oram_type(N * self.M, entry_size=log2(N), init_rounds=0, \ @@ -73,7 +73,7 @@ def __setitem__(self, index, value): class OStack: def __init__(self, N, oram_type=OptimalORAM, int_type=types.sint): - print 'stack', oram_type + print('stack', oram_type) self.oram = oram_type(N, entry_size=log2(N), init_rounds=0, \ value_type=int_type.basic_type) self.size = types.MemValue(int_type(0)) @@ -247,4 +247,4 @@ def __init__(self, N, M=None, reverse=False, oram_type=OptimalORAM, \ self.reverse = reverse self.int_type = int_type self.basic_type = int_type.basic_type - print 'match', self.oram_type + print('match', self.oram_type) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 12f4f33e8..2209cfac7 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -11,7 +11,7 @@ """ import itertools -import tools +from . import tools from random import randint from Compiler.config import * from Compiler.exceptions import * @@ -51,7 +51,7 @@ def execute(self): @base.vectorize class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction): r""" Assigns register $c_i$ the value in memory \verb+C[n]+. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['LDMC'] arg_format = ['cw','int'] @@ -62,7 +62,7 @@ def execute(self): @base.vectorize class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction): r""" Assigns register $s_i$ the value in memory \verb+S[n]+. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['LDMS'] arg_format = ['sw','int'] @@ -73,7 +73,7 @@ def execute(self): @base.vectorize class stmc(base.DirectMemoryWriteInstruction): r""" Sets \verb+C[n]+ to be the value $c_i$. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['STMC'] arg_format = ['c','int'] @@ -84,7 +84,7 @@ def execute(self): @base.vectorize class stms(base.DirectMemoryWriteInstruction): r""" Sets \verb+S[n]+ to be the value $s_i$. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['STMS'] arg_format = ['s','int'] @@ -94,7 +94,7 @@ def execute(self): @base.vectorize class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction): r""" Assigns register $ci_i$ the value in memory \verb+Ci[n]+. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['LDMINT'] arg_format = ['ciw','int'] @@ -104,7 +104,7 @@ def execute(self): @base.vectorize class stmint(base.DirectMemoryWriteInstruction): r""" Sets \verb+Ci[n]+ to be the value $ci_i$. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['STMINT'] arg_format = ['ci','int'] @@ -227,7 +227,7 @@ class protectmemint(base.Instruction): @base.vectorize class movc(base.Instruction): r""" Assigns register $c_i$ the value in the register $c_j$. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['MOVC'] arg_format = ['cw','c'] @@ -238,7 +238,7 @@ def execute(self): @base.vectorize class movs(base.Instruction): r""" Assigns register $s_i$ the value in the register $s_j$. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['MOVS'] arg_format = ['sw','s'] @@ -248,7 +248,7 @@ def execute(self): @base.vectorize class movint(base.Instruction): r""" Assigns register $ci_i$ the value in the register $ci_j$. """ - __slots__ = ["code"] + __slots__ = [] code = base.opcodes['MOVINT'] arg_format = ['ciw','ci'] @@ -347,6 +347,21 @@ class use_prep(base.Instruction): code = base.opcodes['USE_PREP'] arg_format = ['str','int'] +class nplayers(base.Instruction): + r""" Number of players """ + code = base.opcodes['NPLAYERS'] + arg_format = ['ciw'] + +class threshold(base.Instruction): + r""" Maximal number of corrupt players """ + code = base.opcodes['THRESHOLD'] + arg_format = ['ciw'] + +class playerid(base.Instruction): + r""" My player number """ + code = base.opcodes['PLAYERID'] + arg_format = ['ciw'] + ### ### Basic arithmetic ### @@ -738,7 +753,7 @@ class shrci(base.ClearShiftInstruction): class triple(base.DataInstruction): r""" Load secret variables $s_i$, $s_j$ and $s_k$ with the next multiplication triple. """ - __slots__ = ['data_type'] + __slots__ = [] code = base.opcodes['TRIPLE'] arg_format = ['sw','sw','sw'] data_type = 'triple' @@ -752,7 +767,7 @@ def execute(self): class gbittriple(base.DataInstruction): r""" Load secret variables $s_i$, $s_j$ and $s_k$ with the next GF(2) multiplication triple. """ - __slots__ = ['data_type'] + __slots__ = [] code = base.opcodes['GBITTRIPLE'] arg_format = ['sgw','sgw','sgw'] data_type = 'bittriple' @@ -1400,7 +1415,8 @@ def __init__(self, *args, **kwargs): bitlength = program.bit_length if bitlength is None else bitlength if bitlength > 64: raise CompilerError('%d-bit conversion requested ' \ - 'but integer registers only have 64 bits') + 'but integer registers only have 64 bits' % \ + bitlength) super(convmodp_class, self).__init__(*(args + (bitlength,))) @base.vectorize @@ -1433,7 +1449,7 @@ class muls(base.VarArgsInstruction, base.DataInstruction): data_type = 'triple' def get_repeat(self): - return len(self.args) / 3 + return len(self.args) // 3 def merge_id(self): # can merge different sizes @@ -1508,6 +1524,8 @@ def arg_format(self): for j in range(self.args[i] - 2): yield 's' + field + gf2n_arg_format = arg_format + def bases(self): i = 0 while i < len(self.args): @@ -1515,7 +1533,7 @@ def bases(self): i += self.args[i] def get_repeat(self): - return sum(self.args[i] / 2 for i in self.bases()) * self.get_size() + return sum(self.args[i] // 2 for i in self.bases()) * self.get_size() def get_def(self): return [self.args[i + 1] for i in self.bases()] @@ -1567,7 +1585,7 @@ class lts(base.CISC): arg_format = ['sw', 's', 's', 'int', 'int'] def expand(self): - from types import sint + from .types import sint a = sint() subs(a, self.args[1], self.args[2]) comparison.LTZ(self.args[0], a, self.args[3], self.args[4]) diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index d4baa25d5..8683c305b 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -56,6 +56,9 @@ USE_PREP = 0x1C, STARTGRIND = 0x1D, STOPGRIND = 0x1E, + NPLAYERS = 0xE2, + THRESHOLD = 0xE3, + PLAYERID = 0xE4, # Addition ADDC = 0x20, ADDS = 0x21, @@ -277,7 +280,7 @@ def gf2n(instruction): vectorized GF_2^n instruction if a modp version exists. """ global_dict = inspect.getmodule(instruction).__dict__ - if global_dict.has_key('v' + instruction.__name__): + if 'v' + instruction.__name__ in global_dict: vectorized = True else: vectorized = False @@ -316,7 +319,7 @@ class GF2N_Instruction(instruction_cls): if 'gf2n_arg_format' in instruction_cls.__dict__: arg_format = instruction_cls.gf2n_arg_format elif isinstance(instruction_cls.arg_format, itertools.repeat): - __f = instruction_cls.arg_format.next() + __f = next(instruction_cls.arg_format) if __f != 'int' and __f != 'p': arg_format = itertools.repeat(__f[0] + 'g' + __f[1:]) else: @@ -420,7 +423,7 @@ class ClearIntAF(RegisterArgFormat): class IntArgFormat(ArgFormat): @classmethod def check(cls, arg): - if not isinstance(arg, (int, long)): + if not isinstance(arg, int): raise ArgumentError(arg, 'Expected an integer-valued argument') @classmethod @@ -512,7 +515,7 @@ def __init__(self, *args, **kwargs): Instruction.count += 1 if Instruction.count % 100000 == 0: - print "Compiled %d lines at" % self.__class__.count, time.asctime() + print("Compiled %d lines at" % self.__class__.count, time.asctime()) def get_code(self): return self.code @@ -535,7 +538,7 @@ def execute(self): def check_args(self): """ Check the args match up with that specified in arg_format """ - for n,(arg,f) in enumerate(itertools.izip_longest(self.args, self.arg_format)): + for n,(arg,f) in enumerate(itertools.zip_longest(self.args, self.arg_format)): if arg is None: if not isinstance(self.arg_format, (list, tuple)): break # end of optional arguments @@ -758,8 +761,6 @@ def check_args(self): else: # assume 64-bit machine bits = 63 - elif program.options.ring: - bits = int(program.options.ring) - 1 if self.args[2] > bits: raise CompilerError('Shifting by more than %d bits ' 'not implemented' % bits) diff --git a/Compiler/library.py b/Compiler/library.py index 25cf861ef..5314ee94c 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1,4 +1,4 @@ -from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single +from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single, localint from Compiler.instructions import * from Compiler.util import tuplify,untuplify from Compiler import instructions,instructions_base,comparison,program,util @@ -6,6 +6,7 @@ import random import collections import operator +from functools import reduce def get_program(): return instructions.program @@ -101,6 +102,8 @@ def print_ln_if(cond, ss, *args): else: subs = ss.split('%s') assert len(subs) == len(args) + 1 + if isinstance(cond, localint): + cond = cond._v cond = cint.conv(cond) for i, s in enumerate(subs): if i != 0: @@ -274,7 +277,7 @@ def start(self): def join(self): self.thread.join() instructions.program.free(self.base, 'ci') - for reg_type,addr in self.bases.iteritems(): + for reg_type,addr in self.bases.items(): get_program().free(addr, reg_type.reg_type) class Function: @@ -287,7 +290,7 @@ def __init__(self, function, name=None, compile_args=[]): self.compile_args = compile_args def __call__(self, *args): args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args) - get_reg_type = lambda x: regint if isinstance(x, (int, long)) else type(x) + get_reg_type = lambda x: regint if isinstance(x, int) else type(x) if len(args) not in self.type_args: # first call type_args = collections.defaultdict(list) @@ -296,9 +299,11 @@ def __call__(self, *args): def wrapped_function(*compile_args): base = get_arg() bases = dict((t, regint.load_mem(base + i)) \ - for i,t in enumerate(sorted(type_args))) + for i,t in enumerate(sorted(type_args, + key=lambda x: + x.reg_type))) runtime_args = [None] * len(args) - for t in sorted(type_args): + for t in sorted(type_args, key=lambda x: x.reg_type): for i,i_arg in enumerate(type_args[t]): runtime_args[i_arg] = t.load_mem(bases[t] + i) return self.function(*(list(compile_args) + runtime_args)) @@ -308,7 +313,8 @@ def wrapped_function(*compile_args): base = instructions.program.malloc(len(type_args), 'ci') bases = dict((t, get_program().malloc(len(type_args[t]), t)) \ for t in type_args) - for i,reg_type in enumerate(sorted(type_args)): + for i,reg_type in enumerate(sorted(type_args, + key=lambda x: x.reg_type)): store_in_mem(bases[reg_type], base + i) for j,i_arg in enumerate(type_args[reg_type]): if get_reg_type(args[i_arg]) != reg_type: @@ -353,13 +359,13 @@ def on_first_call(self, wrapped_function): block.alloc_pool = defaultdict(set) del parent_node.children[-1] self.node = get_tape().req_node - print 'Compiling function', self.name + print('Compiling function', self.name) result = wrapped_function(*self.compile_args) if result is not None: self.result = memorize(result) else: self.result = None - print 'Done compiling function', self.name + print('Done compiling function', self.name) p_return_address = get_tape().program.malloc(1, 'ci') get_tape().function_basicblocks[block] = p_return_address return_address = regint.load_mem(p_return_address) @@ -429,7 +435,7 @@ def sort(a): res = a for i in range(len(a)): - for j in reversed(range(i)): + for j in reversed(list(range(i))): res[j], res[j+1] = cond_swap(res[j], res[j+1]) return res @@ -443,7 +449,7 @@ def odd_even_merge(a): odd_even_merge(even) odd_even_merge(odd) a[0] = even[0] - for i in range(1, len(a) / 2): + for i in range(1, len(a) // 2): a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i]) a[-1] = odd[-1] @@ -451,8 +457,8 @@ def odd_even_merge_sort(a): if len(a) == 1: return elif len(a) % 2 == 0: - lower = a[:len(a)/2] - upper = a[len(a)/2:] + lower = a[:len(a)//2] + upper = a[len(a)//2:] odd_even_merge_sort(lower) odd_even_merge_sort(upper) a[:] = lower + upper @@ -472,10 +478,10 @@ def chunky_odd_even_merge_sort(a): def round(): for i in range(len(a)): a[i] = type(a[i]).load_mem(i * a[i].sizeof()) - for i in range(len(a) / l): - for j in range(l / k): + for i in range(len(a) // l): + for j in range(l // k): base = i * l + j - step = l / k + step = l // k if k == 2: a[base], a[base+step] = cond_swap(a[base], a[base+step]) else: @@ -514,7 +520,7 @@ def run_threads(): def run_chunk(size, base): if size not in chunks: def swap_list(list_base): - for i in range(size / 2): + for i in range(size // 2): base = list_base + 2 * i x, y = cond_swap(load_secret_mem(base), load_secret_mem(base + 1)) @@ -526,8 +532,8 @@ def swap_list(list_base): def run_round(size): # minimize number of chunk sizes n_chunks = int(math.ceil(1.0 * size / max_chunk_size)) - lower_size = size / n_chunks / 2 * 2 - n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2 + lower_size = size // n_chunks // 2 * 2 + n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2 # print len(to_swap) == lower_size * n_lower_size + \ # (lower_size + 2) * (n_chunks - n_lower_size), \ # len(to_swap), n_chunks, lower_size, n_lower_size @@ -603,10 +609,10 @@ def run_postproc(): k *= 2 size = 0 instructions.program.curr_tape.merge_opens = False - for i in range(n / l): - for j in range(l / k): + for i in range(n // l): + for j in range(l // k): base = i * l + j - step = l / k + step = l // k size += run_setup(k, a_base + base, step, tmp_base + size) run_threads_in_rounds(pre_threads) run_round(size) @@ -651,7 +657,7 @@ def run_threads_in_rounds(all_threads): def run_chunk(size, base): if size not in chunks: def swap_list(list_base): - for i in range(size / 2): + for i in range(size // 2): base = list_base + 2 * i x, y = cond_swap(load_secret_mem(base), load_secret_mem(base + 1)) @@ -663,8 +669,8 @@ def swap_list(list_base): def run_round(size): # minimize number of chunk sizes n_chunks = int(math.ceil(1.0 * size / max_chunk_size)) - lower_size = size / n_chunks / 2 * 2 - n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2 + lower_size = size // n_chunks // 2 * 2 + n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2 # print len(to_swap) == lower_size * n_lower_size + \ # (lower_size + 2) * (n_chunks - n_lower_size), \ # len(to_swap), n_chunks, lower_size, n_lower_size @@ -692,7 +698,7 @@ def load_and_store(x, y): def outer(i): def inner(j): base = j - step = l / k + step = l // k if k == 2: tmp_addr = regint.load_mem(tmp_i) load_and_store(base, tmp_addr) @@ -704,19 +710,19 @@ def inner2(m): load_and_store(m, tmp_addr) store_in_mem(tmp_addr + 1, tmp_i) range_loop(inner2, base + step, base + (k - 1) * step, step) - range_loop(inner, a_base + i * l, a_base + i * l + l / k) + range_loop(inner, a_base + i * l, a_base + i * l + l // k) instructions.program.curr_tape.merge_opens = False to_tmp = True store_in_mem(tmp_base, tmp_i) - range_loop(outer, n / l) + range_loop(outer, n // l) if k == 2: run_round(n) else: - run_round(n / k * (k - 2)) + run_round(n // k * (k - 2)) instructions.program.curr_tape.merge_opens = False to_tmp = False store_in_mem(tmp_base, tmp_i) - range_loop(outer, n / l) + range_loop(outer, n // l) if isinstance(a, list): instructions.program.restart_main_thread() @@ -734,15 +740,15 @@ def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32): k = 1 while k < l: k *= 2 - n_outer = len(a) / l - n_inner = l / k - n_innermost = 1 if k == 2 else k / 2 - 1 - @for_range_parallel(n_parallel / n_innermost / n_inner, n_outer) + n_outer = len(a) // l + n_inner = l // k + n_innermost = 1 if k == 2 else k // 2 - 1 + @for_range_parallel(n_parallel // n_innermost // n_inner, n_outer) def loop(i): - @for_range_parallel(n_parallel / n_innermost, n_inner) + @for_range_parallel(n_parallel // n_innermost, n_inner) def inner(j): base = i*l + j - step = l/k + step = l//k if k == 2: a[base], a[base+step] = cond_swap(a[base], a[base+step]) else: @@ -805,7 +811,7 @@ def loop_fn(i): # known loop count if condition(start): get_tape().req_node.children[-1].aggregator = \ - lambda x: ((stop - start) / step) * x[0] + lambda x: ((stop - start) // step) * x[0] def for_range(start, stop=None, step=None): """ Execute loop bodies consecutively """ @@ -840,7 +846,7 @@ def decorator(loop_body): my_n_parallel = n_parallel if isinstance(n_parallel, int): if isinstance(n_loops, int): - loop_rounds = n_loops / n_parallel \ + loop_rounds = n_loops // n_parallel \ if n_parallel < n_loops else 0 else: loop_rounds = n_loops / n_parallel @@ -884,7 +890,7 @@ def _(i): regint.push(k) return i + k my_n_parallel = n_opt_loops - loop_rounds = n_loops / my_n_parallel + loop_rounds = n_loops // my_n_parallel blocks = get_tape().basicblocks n_to_merge = 5 if loop_rounds == 1 and parent_block is blocks[-n_to_merge]: @@ -966,7 +972,7 @@ def new_body(i): indices = [] for n in reversed(split): indices.insert(0, i % n) - i /= n + i //= n return loop_body(*indices) return new_body new_dec = map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, thread_mem_req) @@ -979,7 +985,7 @@ def new_body(i): else: return dec def decorator(loop_body): - thread_rounds = n_loops / n_threads + thread_rounds = n_loops // n_threads remainder = n_loops % n_threads for t in thread_mem_req: if t != regint: @@ -1233,10 +1239,25 @@ def stop_timer(timer_id=0): stop(timer_id) get_tape().start_new_basicblock(name='post-stop-timer') +def get_number_of_players(): + res = regint() + nplayers(res) + return res + +def get_threshold(): + res = regint() + threshold(res) + return res + +def get_player_id(): + res = localint() + playerid(res._v) + return res + # Fixed point ops from math import ceil, log -from floatingpoint import PreOR, TruncPr, two_power, shift_two +from .floatingpoint import PreOR, TruncPr, two_power, shift_two def approximate_reciprocal(divisor, k, f, theta): """ @@ -1369,7 +1390,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): # no probabilistic truncation in binary circuits nearest = True res_f = f - f = max((k - nearest) / 2 + 1, f) + f = max((k - nearest) // 2 + 1, f) assert 2 * f > k - nearest theta = int(ceil(log(k/3.5) / log(2))) alpha = b.get_type(2 * k).two_power(2*f) @@ -1387,7 +1408,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): x = x.round(2*k, 2*f, kappa, nearest, signed=True) y = y.extend(2 * k) * (alpha + x).extend(2 * k) - y = y.round(k + 2 * f, 3 * f - res_f, kappa, nearest, signed=True) + y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True) return y def AppRcr(b, k, f, kappa, simplex_flag=False, nearest=False): """ diff --git a/Compiler/ml.py b/Compiler/ml.py index 4c8e59b4c..9d5790558 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -3,6 +3,7 @@ from Compiler.types import * from Compiler.types import _unreduced_squant from Compiler.library import * +from functools import reduce def log_e(x): return mpc_math.log_fx(x, math.e) @@ -129,7 +130,7 @@ def _(j, k): tmp[j][k] = sfix.unreduced_dot_product(a, b) if self.d_in * self.d_out < 100000: - print 'reduce at once' + print('reduce at once') @multithread(self.n_threads, self.d_in * self.d_out) def _(base, size): self.nabla_W.assign_vector( @@ -386,7 +387,7 @@ def input_from(self, player): s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) self.weights.input_from(player, budget=100000) self.bias.input_from(player) - print 'WARNING: assuming that bias quantization parameters are correct' + print('WARNING: assuming that bias quantization parameters are correct') self.output_squant.params.precompute(self.input_squant.params, self.weight_squant.params) @@ -404,7 +405,7 @@ def reduction(self): start_timer(2) n_outputs = reduce(operator.mul, self.output_shape) if n_outputs % self.n_threads == 0: - n_per_thread = n_outputs / self.n_threads + n_per_thread = n_outputs // self.n_threads @for_range_opt_multithread(self.n_threads, self.n_threads) def _(i): res = _unreduced_squant( @@ -556,7 +557,7 @@ def __init__(self, input_shape, output_shape, filter_size): self.filter_size = filter_size def input_from(self, player): - print 'WARNING: assuming that input and output quantization parameters are the same' + print('WARNING: assuming that input and output quantization parameters are the same') for s in self.input_squant, self.output_squant: s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) @@ -567,7 +568,7 @@ def forward(self, N=1): _, output_h, output_w, n_channels_out = self.output_shape n = input_h * input_w - print 'divisor: ', n + print('divisor: ', n) assert output_h == output_w == 1 assert n_channels_in == n_channels_out @@ -599,7 +600,7 @@ def _(c): acc += self.X[0][in_y][in_x][c].v #fc += 1 logn = int(math.log(n, 2)) - acc = (acc + n / 2) + acc = (acc + n // 2) if 2 ** logn == n: acc = acc.round(self.output_squant.params.k + logn, logn, nearest=True) @@ -614,7 +615,7 @@ def __init__(self, input_shape, _, output_shape): super(QuantReshape, self).__init__(input_shape, output_shape) def input_from(self, player): - print 'WARNING: assuming that input and output quantization parameters are the same' + print('WARNING: assuming that input and output quantization parameters are the same') _ = self.new_squant() for s in self.input_squant, _, self.output_squant: s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) @@ -628,7 +629,7 @@ def forward(self, N=1): class QuantSoftmax(QuantBase): def input_from(self, player): - print 'WARNING: assuming that input and output quantization parameters are the same' + print('WARNING: assuming that input and output quantization parameters are the same') for s in self.input_squant, self.output_squant: s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) @@ -666,14 +667,14 @@ def _(): N = self.layers[0].N assert self.layers[-1].N == N assert N % 2 == 0 - n = N / 2 + n = N // 2 @for_range(n) def _(i): self.layers[-1].Y[i] = 0 self.layers[-1].Y[i + n] = 1 n_per_epoch = int(math.ceil(1. * max(len(X) for X in self.X_by_label) / n)) - print '%d runs per epoch' % n_per_epoch + print('%d runs per epoch' % n_per_epoch) indices_by_label = [] for label, X in enumerate(self.X_by_label): indices = regint.Array(n * n_per_epoch) @@ -794,8 +795,8 @@ def _(k): x = x.reveal() print_ln_if((x > 1000) + (x < -1000), name + ': %s %s %s %s', - *[y.v.reveal() for y in old, red_old, \ - new, diff]) + *[y.v.reveal() for y in (old, red_old, \ + new, diff)]) if self.debug: d = delta_theta.get_vector().reveal() a = cfix.Array(len(d.v)) diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 23926f24e..b2cc4803e 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -474,7 +474,7 @@ def norm_simplified_SQ(b, k): m_odd = m_odd + z[i] # construct w, - k_over_2 = k / 2 + 1 + k_over_2 = k // 2 + 1 w_array = [0] * (k_over_2) w_array[0] = z[0] for i in range(1, k_over_2): @@ -510,7 +510,7 @@ def sqrt_simplified_fx(x): m_odd = (1 - 2 * m_odd) + m_odd w = (w * 2 - w) * (1-m_odd) + w # map number to use sfix format and instantiate the number - w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) / 2)) + w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) // 2)) # obtains correct 2 ** (m/2) w = (w * (types.cfix(2 ** (1/2.0))) - w) * m_odd + w # produce x/ 2^(m/2) diff --git a/Compiler/oram.py b/Compiler/oram.py index 12786e4f2..93e34c844 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -4,6 +4,7 @@ import itertools import operator import sys +from functools import reduce from Compiler.types import * from Compiler.types import _secret @@ -95,7 +96,7 @@ def __init__(self, value, start, lengths, entries_per_block): prod_bits = [start * bit for bit in value_bits] anti_bits = [v - p for v,p in zip(value_bits,prod_bits)] self.lower = sum(bit << i for i,bit in enumerate(prod_bits[:length])) - self.bits = map(operator.add, anti_bits[:length], prod_bits[length:]) + \ + self.bits = list(map(operator.add, anti_bits[:length], prod_bits[length:])) + \ anti_bits[length:] self.adjust = if_else(start, 1 << length, cgf2n(1)) elif entries_per_block < 4: @@ -105,7 +106,7 @@ def __init__(self, value, start, lengths, entries_per_block): choice_bits = demux(start_bits) inv_bits = [1 - bit for bit in floatingpoint.PreOR(choice_bits, None)] mask_bits = sum(([x] * length for x in inv_bits), []) - lower_bits = map(operator.mul, value_bits, mask_bits) + lower_bits = list(map(operator.mul, value_bits, mask_bits)) self.lower = sum(bit << i for i,bit in enumerate(lower_bits)) self.bits = [sum(map(operator.mul, choice_bits, value_bits[i::length])) \ for i in range(length)] @@ -124,7 +125,7 @@ def __init__(self, value, start, lengths, entries_per_block): pre_bits = floatingpoint.PreOpL(lambda x,y,z=None: x + y, bits) inv_bits = [1 - bit for bit in pre_bits] mask_bits = sum(([x] * length for x in inv_bits), []) - lower_bits = map(operator.mul, value_bits, mask_bits) + lower_bits = list(map(operator.mul, value_bits, mask_bits)) masked = self.value - sum(bit << i for i,bit in enumerate(lower_bits)) self.lower = sum(bit << i for i,bit in enumerate(lower_bits)) self.bits = (masked / adjust).bit_decompose(used_bits) @@ -177,12 +178,12 @@ def demux_list(x): return [1] elif n == 1: return [1 - x[0], x[0]] - a = demux_list(x[:n/2]) - b = demux_list(x[n/2:]) + a = demux_list(x[:n//2]) + b = demux_list(x[n//2:]) n_a = len(a) a *= len(b) b = reduce(operator.add, ([i] * n_a for i in b)) - res = map(operator.mul, a, b) + res = list(map(operator.mul, a, b)) return res def demux_array(x, res=None): @@ -193,12 +194,12 @@ def demux_array(x, res=None): res[0] = 1 - x[0] res[1] = x[0] else: - a = Array(2**(n/2), type(x[0])) - a.assign(demux(x[:n/2])) - b = Array(2**(n-n/2), type(x[0])) - b.assign(demux(x[n/2:])) + a = Array(2**(n//2), type(x[0])) + a.assign(demux(x[:n//2])) + b = Array(2**(n-n//2), type(x[0])) + b.assign(demux(x[n//2:])) @for_range_multithread(get_n_threads(len(res)), \ - max(1, n_parallel / len(b)), len(a)) + max(1, n_parallel // len(b)), len(a)) def f(i): @for_range_parallel(n_parallel, len(b)) def f(j): @@ -234,7 +235,7 @@ def __mul__(self, other): return Value(other * self.value, other * self.empty) __rmul__ = __mul__ def equal(self, other, length=None): - if isinstance(other, (int, long)) and isinstance(self.value, (int, long)): + if isinstance(other, int) and isinstance(self.value, int): return (1 - self.empty) * (other == self.value) return (1 - self.empty) * self.value.equal(other, length) def reveal(self): @@ -252,9 +253,9 @@ def __repr__(self): try: value = self.empty while True: - if value in (1, 1L): + if value == 1: return '<>' - if value in (0, 0L): + if value == 0: return '<%s>' % str(self.value) value = value.value except: @@ -297,8 +298,8 @@ def __init__(self, v, x=None, empty=None, value_type=None): self.created_non_empty = False if x is None: v = iter(v) - self.is_empty = v.next() - self.v = v.next() + self.is_empty = next(v) + self.v = next(v) self.x = ValueTuple(v) else: if empty is None: @@ -332,7 +333,7 @@ def __add__(self, other): try: return Entry(i + j for i,j in zip(self, other)) except: - print self, other + print(self, other) raise def __sub__(self, other): return Entry(i - j for i,j in zip(self, other)) @@ -342,7 +343,7 @@ def __mul__(self, other): try: return Entry(other * i for i in self) except: - print self, other + print(self, other) raise __rmul__ = __mul__ def reveal(self): @@ -372,8 +373,8 @@ def f(): for t,array in zip(self.entry_type,oram.ram.l)] self.index = index def init_mem(self, empty_entry): - print 'init ram' - for a,value in zip(self.l, empty_entry.defaults.values()): + print('init ram') + for a,value in zip(self.l, list(empty_entry.defaults.values())): # don't use threads if n_threads explicitly set to 1 a.assign_all(value, n_threads != 1, conv=False) def get_empty_bits(self): @@ -392,14 +393,14 @@ def get_value_array(self, index): return [Value(self.l[2+index][i], self.l[0][i]) for i in range(self.size)] def __getitem__(self, index): if print_access: - print 'get', id(self), index + print('get', id(self), index) return Entry(a[index] for a in self.l) def __setitem__(self, index, value): if print_access: - print 'set', id(self), index + print('set', id(self), index) if not isinstance(value, Entry): raise Exception('entries only please: %s' % str(value)) - for i,(a,v) in enumerate(zip(self.l, value.values())): + for i,(a,v) in enumerate(zip(self.l, list(value.values()))): a[index] = v def __len__(self): return self.size @@ -524,7 +525,7 @@ def __init__(self, index, oram): self.value_type, self.entry_size = oram.internal_entry_size() self.size = oram.bucket_size def init_mem(self): - print 'init trivial oram' + print('init trivial oram') self.ram.init_mem(self.empty_entry(apply_type=False)) def search(self, read_index): if use_binary_search and self.value_type == sgf2n: @@ -554,7 +555,7 @@ def read_and_remove(self, read_index, skip=0): self.last_index = read_index found, empty = self.search(read_index) entries = [entry for entry in self.ram] - prod_entries = map(operator.mul, found, entries) + prod_entries = list(map(operator.mul, found, entries)) read_value = sum((entry.x.skip(skip) for entry in prod_entries), \ empty * empty_entry.x.skip(skip)) for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)): @@ -566,7 +567,7 @@ def read_and_maybe_remove(self, index): def read_and_remove_by_public(self, index): empty_entry = self.empty_entry(False) entries = [entry for entry in self.ram] - prod_entries = map(operator.mul, index, entries) + prod_entries = list(map(operator.mul, index, entries)) read_entry = reduce(operator.add, prod_entries) for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)): self.ram[i] = entry - prod_entry + index[i] * empty_entry @@ -574,7 +575,7 @@ def read_and_remove_by_public(self, index): @method_block def _read(self, index): found, empty = self.search(index) - read_value = sum(map(operator.mul, found, self.ram.get_values()), \ + read_value = sum(list(map(operator.mul, found, self.ram.get_values())), \ empty * self.empty_entry(False).x) return read_value, empty @method_block @@ -583,8 +584,8 @@ def _access(self, index, write, new_empty, *new_value): found, not_found = self.search(index) add_here = self.find_first_empty() entries = [entry for entry in self.ram] - prod_values = map(operator.mul, found, \ - (entry.x for entry in entries)) + prod_values = list(map(operator.mul, found, \ + (entry.x for entry in entries))) read_value = sum(prod_values, not_found * empty_entry.x) new_value = ValueTuple(new_value) \ if isinstance(new_value, (tuple, list)) \ @@ -699,15 +700,15 @@ def binary_search(self, index): for k in range(2**(j)): t = k + 2**(j) - 1 if k % 2 == 0: - M += bit_prods[(t-1)/2] * mult_tree[t] + M += bit_prods[(t-1)//2] * mult_tree[t] b = 1 - M.equal(0, 40, expand) for k in range(2**j): t = k + 2**j - 1 if k % 2 == 0: - v = bit_prods[(t-1)/2] * b - bit_prods[t] = bit_prods[(t-1)/2] - v + v = bit_prods[(t-1)//2] * b + bit_prods[t] = bit_prods[(t-1)//2] - v else: bit_prods[t] = v return bit_prods[n-1:n-1+self.size], 1 - bit_prods[0] @@ -734,7 +735,7 @@ def f(): print_ln('Bucket overflow') crash() if debug and not sum(add_here) and not new_entry.empty(): - print self.empty_entry() + print(self.empty_entry()) raise Exception('no space for %s in %s' % (str(new_entry), str(self))) self.check(new_entry=new_entry, op='add') def pop(self): @@ -746,7 +747,7 @@ def pop(self): pop_here = [prefix_empty[i+1] - prefix_empty[i] \ for i in range(len(self.ram))] entries = [entry for entry in self.ram] - prod_entries = map(operator.mul, pop_here, self.ram) + prod_entries = list(map(operator.mul, pop_here, self.ram)) result = (1 - sum(pop_here)) * empty_entry result = sum(prod_entries, result) for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)): @@ -980,7 +981,7 @@ def __init__(self, size, entry_size, value_type=sint, init_rounds=-1, \ @for_range(init_rounds if init_rounds > 0 else size) def f(i): self.l[0][i] = random_block(entry_size, value_type) - print 'index size:', size + print('index size:', size) def update(self, index, value, evict=None): read_value = self[index] #print 'read', index, read_value @@ -1005,7 +1006,7 @@ class TreeORAM(AbstractORAM): """ Tree ORAM. """ def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \ bucket_oram=TrivialORAM, init_rounds=-1): - print 'create oram of size', size + print('create oram of size', size) self.bucket_oram = bucket_oram # heuristic bucket size delta = 3 @@ -1013,9 +1014,9 @@ def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \ # size + 1 for bucket overflow check self.bucket_size = min(int(math.ceil((1 + delta) * k)), size + 1) self.D = log2(max(size / k, 2)) - print 'bucket size:', self.bucket_size - print 'depth:', self.D - print 'complexity:', self.bucket_size * (self.D + 1) + print('bucket size:', self.bucket_size) + print('depth:', self.D) + print('complexity:', self.bucket_size * (self.D + 1)) self.value_type = value_type if entry_size is not None: self.value_length = len(tuplify(entry_size)) @@ -1279,8 +1280,8 @@ def batch_init(self, values): # split into 2 if bucket size can't fit into one field elem if self.bucket_size + Program.prog.security > 128: parity = (empty_positions[i]+1) % 2 - half = (empty_positions[i]+1 - parity) / 2 - half_max = self.bucket_size / 2 + half = (empty_positions[i]+1 - parity) // 2 + half_max = self.bucket_size // 2 bits = floatingpoint.B2U(half, half_max, Program.prog.security)[0] bits2 = floatingpoint.B2U(half+parity, half_max, Program.prog.security)[0] @@ -1384,11 +1385,11 @@ def get_parallel(index_size, value_type, value_length): value_size = get_value_size(value_type) if value_type == sint: value_size *= 2 - res = max(1, min(50 * 32 / (value_length * value_size), \ - 800 * 32 / (value_length * index_size))) + res = max(1, min(50 * 32 // (value_length * value_size), \ + 800 * 32 // (value_length * index_size))) if comparison.const_rounds: - res = max(1, res / 2) - print 'Reading %d buckets in parallel' % res + res = max(1, res // 2) + print('Reading %d buckets in parallel' % res) return res class PackedIndexStructure(object): @@ -1403,7 +1404,7 @@ def __init__(self, size, entry_size=None, value_type=sint, init_rounds=-1, \ self.value_type = value_type for demux_bits in range(max_demux_bits + 1): self.log_entries_per_element = min(log2(size), \ - int(math.floor(math.log(float(get_value_size(value_type)) / \ + int(math.floor(math.log(float(get_value_size(value_type)) // \ sum(self.entry_size), 2)))) self.log_elements_per_block = \ max(0, min(demux_bits, log2(size) - \ @@ -1423,24 +1424,24 @@ def __init__(self, size, entry_size=None, value_type=sint, init_rounds=-1, \ self.elements_per_entry = len(self.split_sizes) self.log_elements_per_block = log2(self.elements_per_entry) self.log_entries_per_element = -self.log_elements_per_block - print 'split sizes:', self.split_sizes + print('split sizes:', self.split_sizes) self.log_entries_per_block = \ self.log_elements_per_block + self.log_entries_per_element self.elements_per_block = 2**self.log_elements_per_block self.entries_per_element = 2**self.log_entries_per_element self.entries_per_block = 2**self.log_entries_per_block self.used_bits = self.entries_per_element * sum(self.entry_size) - real_size = -(-size / self.entries_per_block) - print 'packed size:', real_size - print 'index size:', size - print 'entry size:', self.entry_size - print 'log(entries per element):', self.log_entries_per_element - print 'entries per element:', self.entries_per_element - print 'log(entries per block):', self.log_entries_per_block - print 'entries per block:', self.entries_per_block - print 'log(elements per block):', self.log_elements_per_block - print 'elements per block:', self.elements_per_block - print 'used bits:', self.used_bits + real_size = -(-size // self.entries_per_block) + print('packed size:', real_size) + print('index size:', size) + print('entry size:', self.entry_size) + print('log(entries per element):', self.log_entries_per_element) + print('entries per element:', self.entries_per_element) + print('log(entries per block):', self.log_entries_per_block) + print('entries per block:', self.entries_per_block) + print('log(elements per block):', self.log_elements_per_block) + print('elements per block:', self.elements_per_block) + print('used bits:', self.used_bits) entry_size = [self.used_bits] * self.elements_per_block if real_size > 1: # no need to init underlying ORAM, will be initialized implicitely @@ -1454,10 +1455,10 @@ def __init__(self, size, entry_size=None, value_type=sint, init_rounds=-1, \ self.index_type = self.l.index_type if init_rounds: if init_rounds > 0: - real_init_rounds = init_rounds * real_size / size + real_init_rounds = init_rounds * real_size // size else: real_init_rounds = real_size - print 'packed init rounds:', real_init_rounds + print('packed init rounds:', real_init_rounds) @for_range(real_init_rounds) def f(i): if random_init: @@ -1467,7 +1468,7 @@ def f(i): self.l[i] = [0] * self.elements_per_block time() print_ln('packed ORAM init %s/%s', i, real_init_rounds) - print 'index initialized, size', size + print('index initialized, size', size) def translate_index(self, index): """ Bit slicing *index* according parameters. Output is tuple (storage address, index with storage cell, index within @@ -1501,16 +1502,16 @@ def read(self, block): self.block = block self.index_vector = \ demux(bit_decompose(self.b, self.pack.log_elements_per_block)) - self.vector = map(operator.mul, self.index_vector, block) + self.vector = list(map(operator.mul, self.index_vector, block)) self.element = get_block(sum(self.vector), self.c, \ self.pack.entry_size, \ self.pack.entries_per_element) return tuple(self.element.get_slice()) def write(self, value): self.element.set_slice(value) - anti_vector = map(operator.sub, self.block, self.vector) + anti_vector = list(map(operator.sub, self.block, self.vector)) updated_vector = [self.element.value * i for i in self.index_vector] - updated_block = map(operator.add, anti_vector, updated_vector) + updated_block = list(map(operator.add, anti_vector, updated_vector)) return updated_block class MultiSlicer(object): def __init__(self, pack, index): @@ -1685,7 +1686,7 @@ def test_oram(oram_type, N, value_type=sint, iterations=100): value_type = value_type.get_type(32) index_type = value_type.get_type(log2(N)) start_grind() - print 'initialized' + print('initialized') print_ln('initialized') stop_timer() # synchronize @@ -1718,7 +1719,7 @@ def f(i): def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations=100): oram = oram_type(N, value_type=value_type, entry_size=32, \ init_rounds=0) - print 'initialized' + print('initialized') print_reg(cint(0), 'init') stop_timer() # synchronize @@ -1731,11 +1732,11 @@ def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations= def f(i): oram.access(value_type(i % N), value_type(0), value_type(True)) oram.access(value_type(i % N), value_type(i % N), value_type(True)) - print 'first write' + print('first write') time() x = oram.access(value_type(i % N), value_type(0), value_type(False)) x[0][0].reveal().print_reg('writ') - print 'first read' + print('first read') # @for_range(iterations) # def f(i): # x = oram.access(value_type(i % N), value_type(0), value_type(False), \ @@ -1747,7 +1748,7 @@ def f(i): def test_batch_init(oram_type, N): value_type = sint oram = oram_type(N, value_type) - print 'initialized' + print('initialized') print_reg(cint(0), 'init') oram.batch_init([value_type(i) for i in range(N)]) print_reg(cint(0), 'done') diff --git a/Compiler/path_oram.py b/Compiler/path_oram.py index e1265a0ce..fb1601c3d 100644 --- a/Compiler/path_oram.py +++ b/Compiler/path_oram.py @@ -1,9 +1,10 @@ if '_Array' not in dir(): - from oram import * - import permutation + from Compiler.oram import * + from Compiler import permutation _Array = Array -import oram +from Compiler import oram +from functools import reduce #import pdb @@ -140,7 +141,7 @@ def __init__(self, size, value_type=sgf2n, value_length=1, entry_size=None, \ bucket_size=2, init_rounds=-1): #if size <= k: # raise CompilerError('ORAM size too small') - print 'create oram of size', size + print('create oram of size', size) self.bucket_oram = bucket_oram self.bucket_size = bucket_size self.D = log2(size) @@ -240,7 +241,7 @@ def evict(): self.state.write(self.value_type(leaf)) - print 'eviction leaf =', leaf + print('eviction leaf =', leaf) # load the path for i, ram_indices in enumerate(self.bucket_indices_on_path_to(leaf)): @@ -325,7 +326,7 @@ def read_and_remove_levels(self, u): # at most one 1 in found empty = 1 - sum(found) - prod_entries = map(operator.mul, found, entries) + prod_entries = list(map(operator.mul, found, entries)) read_value = sum((entry.x.skip(skip) for entry in prod_entries), \ empty * empty_entry.x.skip(skip)) for i,(j, entry, prod_entry) in enumerate(zip(ram_indices, entries, prod_entries)): @@ -528,7 +529,7 @@ def read_and_remove(self, u): values = (ValueTuple(x) for x in zip(*self.read_value)) not_empty = [1 - x for x in self.read_empty] read_empty = 1 - sum(not_empty) - read_value = sum(map(operator.mul, not_empty, values), \ + read_value = sum(list(map(operator.mul, not_empty, values)), \ ValueTuple(0 for i in range(self.value_length))) self.check(u) Program.prog.curr_tape.\ @@ -545,7 +546,7 @@ def buckets_on_path_to(self, leaf): yield bucket def bucket_indices_on_path_to(self, leaf): leaf = regint(leaf) - yield range(self.bucket_size) + yield list(range(self.bucket_size)) index = 0 for i in range(self.D): index = 2*index + 1 + regint(cint(leaf) & 1) @@ -742,7 +743,7 @@ def add(self, entry, state=None, evict=True): try: self.stash.add(e) except Exception: - print self + print(self) raise if evict: self.evict() diff --git a/Compiler/permutation.py b/Compiler/permutation.py index 288965333..79c32e275 100644 --- a/Compiler/permutation.py +++ b/Compiler/permutation.py @@ -69,7 +69,7 @@ def odd_even_merge(a, comp): odd_even_merge(even, comp) odd_even_merge(odd, comp) a[0] = even[0] - for i in range(1, len(a) / 2): + for i in range(1, len(a) // 2): a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i], comp) a[-1] = odd[-1] @@ -77,8 +77,8 @@ def odd_even_merge_sort(a, comp=bitwise_comparator): if len(a) == 1: return elif len(a) % 2 == 0: - lower = a[:len(a)/2] - upper = a[len(a)/2:] + lower = a[:len(a)//2] + upper = a[len(a)//2:] odd_even_merge_sort(lower, comp) odd_even_merge_sort(upper, comp) a[:] = lower + upper @@ -137,7 +137,7 @@ def random_perm(n): if not Program.prog.options.insecure: raise CompilerError('no secure implementation of Waksman permution, ' 'use --insecure to activate') - a = range(n) + a = list(range(n)) for i in range(n-1, 0, -1): j = randint(0, i) t = a[i] @@ -155,10 +155,10 @@ def configure_waksman(perm): n = len(perm) if n == 2: return [(perm[0], perm[0])] - I = [None] * (n/2) - O = [None] * (n/2) - p0 = [None] * (n/2) - p1 = [None] * (n/2) + I = [None] * (n//2) + O = [None] * (n//2) + p0 = [None] * (n//2) + p1 = [None] * (n//2) inv_perm = [0] * n for i, p in enumerate(perm): @@ -170,7 +170,7 @@ def configure_waksman(perm): except ValueError: break #print 'j =', j - O[j/2] = 0 + O[j//2] = 0 via = 0 j0 = j while True: @@ -178,10 +178,10 @@ def configure_waksman(perm): i = inv_perm[j] #print ' p0[%d] = %d' % (inv_perm[j]/2, j/2) - p0[i/2] = j/2 + p0[i//2] = j//2 - I[i/2] = i % 2 - O[j/2] = j % 2 + I[i//2] = i % 2 + O[j//2] = j % 2 #print ' O[%d] = %d' % (j/2, j % 2) if i % 2 == 1: i -= 1 @@ -198,7 +198,7 @@ def configure_waksman(perm): j += 1 #j, via = set_swapper(O, i, via, perm) #print ' p1[%d] = %d' % (i/2, perm[i]/2) - p1[i/2] = perm[i]/2 + p1[i//2] = perm[i]//2 #print ' i = %d, j = %d' %(i,j) if j == j0: @@ -206,8 +206,8 @@ def configure_waksman(perm): if None not in p0 and None not in p1: break - assert sorted(p0) == range(n/2) - assert sorted(p1) == range(n/2) + assert sorted(p0) == list(range(n//2)) + assert sorted(p1) == list(range(n//2)) p0_config = configure_waksman(p0) p1_config = configure_waksman(p1) return [I + O] + [a+b for a,b in zip(p0_config, p1_config)] @@ -219,22 +219,22 @@ def waksman(a, config, depth=0, start=0, reverse=False): a[0], a[1] = cond_swap_bit(a[0], a[1], config[depth][start]) return - a0 = [0] * (n/2) - a1 = [0] * (n/2) - for i in range(n/2): + a0 = [0] * (n//2) + a1 = [0] * (n//2) + for i in range(n//2): if reverse: - a0[i], a1[i] = cond_swap_bit(a[2*i], a[2*i+1], config[depth][i + n/2 + start]) + a0[i], a1[i] = cond_swap_bit(a[2*i], a[2*i+1], config[depth][i + n//2 + start]) else: a0[i], a1[i] = cond_swap_bit(a[2*i], a[2*i+1], config[depth][i + start]) waksman(a0, config, depth+1, start, reverse) - waksman(a1, config, depth+1, start + n/2, reverse) + waksman(a1, config, depth+1, start + n//2, reverse) - for i in range(n/2): + for i in range(n//2): if reverse: a[2*i], a[2*i+1] = cond_swap_bit(a0[i], a1[i], config[depth][i + start]) else: - a[2*i], a[2*i+1] = cond_swap_bit(a0[i], a1[i], config[depth][i + n/2 + start]) + a[2*i], a[2*i+1] = cond_swap_bit(a0[i], a1[i], config[depth][i + n//2 + start]) WAKSMAN_FUNCTIONS = {} @@ -263,11 +263,11 @@ def do_round(size, config_address, a_address, a2_address): outwards = 1 - inwards sizeval = size - #for k in range(n/2): - @for_range_parallel(200, n/2) + #for k in range(n//2): + @for_range_parallel(200, n//2) def f(k): j = cint(k) % sizeval - i = (cint(k) - j)/sizeval + i = (cint(k) - j)//sizeval base = 2*i*sizeval in1, in2 = (base+j+j*inwards), (base+j+j*inwards+1*inwards+sizeval*outwards) @@ -297,7 +297,7 @@ def f(k): # going into middle of network @for_range(logn) def f(i): - size.write(n/(2*nblocks)) + size.write(n//(2*nblocks)) conf_address = MemValue(config.address + depth.read()*n) do_round(size, conf_address, a.address, a2.address, 1) @@ -307,20 +307,20 @@ def f(i): nblocks.write(nblocks*2) depth.write(depth+1) - nblocks.write(nblocks/4) + nblocks.write(nblocks//4) depth.write(depth-2) # and back out @for_range(logn-1) def f(i): - size.write(n/(2*nblocks)) + size.write(n//(2*nblocks)) conf_address = MemValue(config.address + depth.read()*n) do_round(size, conf_address, a.address, a2.address, 0) for i in range(n): a[i] = a2[i] - nblocks.write(nblocks/2) + nblocks.write(nblocks//2) depth.write(depth-1) ## going into middle of network @@ -375,7 +375,7 @@ def config_shuffle(n, value_type): if n & (n-1) != 0: # pad permutation to power of 2 m = 2**int(math.ceil(math.log(n, 2))) - perm += range(n, m) + perm += list(range(n, m)) config_bits = configure_waksman(perm) # 2-D array config = Array(len(config_bits) * len(perm), value_type.reg_type) diff --git a/Compiler/program.py b/Compiler/program.py index 76c3b00bc..eebc411a3 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -3,15 +3,17 @@ from Compiler.instructions_base import RegType import Compiler.instructions import Compiler.instructions_base -import compilerLib -import allocator as al +from . import compilerLib +from . import allocator as al +from . import util import random import time import sys, os, errno import inspect -from collections import defaultdict +from collections import defaultdict, deque import itertools import math +from functools import reduce data_types = dict( @@ -50,11 +52,11 @@ def __init__(self, args, options, param=-1, assemblymode=False): self.bit_length = int(options.binary) or int(options.field) if not self.bit_length: self.bit_length = BIT_LENGTHS[param] - print 'Default bit length:', self.bit_length + print('Default bit length:', self.bit_length) self.security = 40 - print 'Default security parameter:', self.security + print('Default security parameter:', self.security) self.galois_length = int(options.galois) - print 'Galois length:', self.galois_length + print('Galois length:', self.galois_length) self.schedule = [('start', [])] self.tape_counter = 0 self.tapes = [] @@ -118,7 +120,7 @@ def max_par_tapes(self): running[tape] -= 1 else: raise CompilerError('Invalid schedule action') - res = max(res, sum(running.itervalues())) + res = max(res, sum(running.values())) return res def init_names(self, args, assemblymode): @@ -129,7 +131,7 @@ def init_names(self, args, assemblymode): else: # assume source is in main SPDZ directory self.programs_dir = sys.path[0] + '/Programs' - print 'Compiling program in', self.programs_dir + print('Compiling program in', self.programs_dir) # create extra directories if needed for dirname in ['Public-Input', 'Bytecode', 'Schedules']: @@ -225,7 +227,7 @@ def update_req(self, tape): def read_memory(self, filename): """ Read the clear and shared memory from a file """ f = open(filename) - n = int(f.next()) + n = int(next(f)) self.mem_c = [0]*n self.mem_s = [0]*n mem = self.mem_c @@ -253,8 +255,8 @@ def reset_values(self): """ Reset register and memory values. """ for tape in self.tapes: tape.reset_registers() - self.mem_c = range(USER_MEM + TMP_MEM) - self.mem_s = range(USER_MEM + TMP_MEM) + self.mem_c = list(range(USER_MEM + TMP_MEM)) + self.mem_s = list(range(USER_MEM + TMP_MEM)) def write_bytes(self, outfile=None): """ Write all non-empty threads and schedule to files. """ @@ -265,7 +267,7 @@ def write_bytes(self, outfile=None): sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name sch_file = open(sch_filename, 'w') - print 'Writing to', sch_filename + print('Writing to', sch_filename) sch_file.write(str(self.max_par_tapes()) + '\n') sch_file.write(str(len(nonempty_tapes)) + '\n') sch_file.write(' '.join(tape.name for tape in nonempty_tapes) + '\n') @@ -276,7 +278,7 @@ def write_bytes(self, outfile=None): for sch in self.schedule: # schedule may still contain empty tapes: ignore these - tapes = filter(lambda x: not x[0].is_empty(), sch[1]) + tapes = [x for x in sch[1] if not x[0].is_empty()] # no empty line if not tapes: continue @@ -358,7 +360,7 @@ def curr_block(self): def malloc(self, size, mem_type, reg_type=None): """ Allocate memory from the top """ - if not isinstance(size, (int, long)): + if not isinstance(size, int): raise CompilerError('size must be known at compile time') if size == 0: return @@ -374,7 +376,7 @@ def malloc(self, size, mem_type, reg_type=None): addr = self.allocated_mem[mem_type] self.allocated_mem[mem_type] += size if len(str(addr)) != len(str(addr + size)): - print "Memory of type '%s' now of size %d" % (mem_type, addr + size) + print("Memory of type '%s' now of size %d" % (mem_type, addr + size)) self.allocated_mem_blocks[addr,mem_type] = size return addr @@ -387,11 +389,11 @@ def free(self, addr, mem_type): self.free_mem_blocks[size,mem_type].add(addr) def finalize_memory(self): - import library + from . import library self.curr_tape.start_new_basicblock(None, 'memory-usage') # reset register counter to 0 self.curr_tape.init_registers() - for mem_type,size in self.allocated_mem.items(): + for mem_type,size in list(self.allocated_mem.items()): if size: #print "Memory of type '%s' of size %d" % (mem_type, size) if mem_type in self.types: @@ -404,11 +406,11 @@ def public_input(self, x): def set_bit_length(self, bit_length): self.bit_length = bit_length - print 'Changed bit length for comparisons etc. to', bit_length + print('Changed bit length for comparisons etc. to', bit_length) def set_security(self, security): self.security = security - print 'Changed statistical security for comparison etc. to', security + print('Changed statistical security for comparison etc. to', security) def optimize_for_gc(self): pass @@ -500,9 +502,9 @@ def adjust_jump(self): #print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset) def purge(self): - relevant = lambda inst: inst.add_usage.__func__ is not \ - Compiler.instructions_base.Instruction.add_usage.__func__ - self.usage_instructions = filter(relevant, self.instructions) + relevant = lambda inst: inst.add_usage is not \ + Compiler.instructions_base.Instruction.add_usage + self.usage_instructions = list(filter(relevant, self.instructions)) del self.instructions del self.defined_registers self.purged = True @@ -568,8 +570,8 @@ def purge(self): def unpurged(function): def wrapper(self, *args, **kwargs): if self.purged: - print '%s called on purged block %s, ignoring' % \ - (function.__name__, self.name) + print('%s called on purged block %s, ignoring' % \ + (function.__name__, self.name)) return return function(self, *args, **kwargs) return wrapper @@ -577,13 +579,13 @@ def wrapper(self, *args, **kwargs): @unpurged def optimize(self, options): if len(self.basicblocks) == 0: - print 'Tape %s is empty' % self.name + print('Tape %s is empty' % self.name) return if self.if_states: raise CompilerError('Unclosed if/else blocks') - print 'Processing tape', self.name, 'with %d blocks' % len(self.basicblocks) + print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)) for block in self.basicblocks: al.determine_scope(block, options) @@ -593,38 +595,38 @@ def optimize(self, options): if (options.merge_opens and self.merge_opens) or options.dead_code_elimination: for i,block in enumerate(self.basicblocks): if len(block.instructions) > 0: - print 'Processing basic block %s, %d/%d, %d instructions' % \ + print('Processing basic block %s, %d/%d, %d instructions' % \ (block.name, i, len(self.basicblocks), \ - len(block.instructions)) + len(block.instructions))) # the next call is necessary for allocation later even without merging merger = al.Merger(block, options, \ tuple(self.program.to_merge)) if options.dead_code_elimination: if len(block.instructions) > 10000: - print 'Eliminate dead code...' + print('Eliminate dead code...') merger.eliminate_dead_code() if options.merge_opens and self.merge_opens: if len(block.instructions) == 0: - block.used_from_scope = set() - block.defined_registers = set() + block.used_from_scope = util.set_by_id() + block.defined_registers = util.set_by_id() continue if len(block.instructions) > 10000: - print 'Merging instructions...' + print('Merging instructions...') numrounds = merger.longest_paths_merge() block.n_rounds = numrounds block.n_to_merge = len(merger.open_nodes) if numrounds > 0: - print 'Program requires %d rounds of communication' % numrounds + print('Program requires %d rounds of communication' % numrounds) if merger.counter: - print 'Block requires', \ + print('Block requires', \ ', '.join('%d %s' % (y, x.__name__) \ - for x, y in merger.counter.items()) + for x, y in list(merger.counter.items()))) # free memory merger = None if options.dead_code_elimination: - block.instructions = filter(lambda x: x is not None, block.instructions) + block.instructions = [x for x in block.instructions if x is not None] if not (options.merge_opens and self.merge_opens): - print 'Not merging instructions in tape %s' % self.name + print('Not merging instructions in tape %s' % self.name) # add jumps offset = 0 @@ -640,39 +642,44 @@ def optimize(self, options): block.adjust_return() # now remove any empty blocks (must be done after setting jumps) - self.basicblocks = filter(lambda x: len(x.instructions) != 0, self.basicblocks) + self.basicblocks = [x for x in self.basicblocks if len(x.instructions) != 0] # allocate registers reg_counts = self.count_regs() if not options.noreallocate: if self.program.verbose: - print 'Tape register usage:', dict(reg_counts) - print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]) - print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]) - print 'Re-allocating...' + print('Tape register usage:', dict(reg_counts)) + print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])) + print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])) + print('Re-allocating...') allocator = al.StraightlineAllocator(REG_MAX) - def alloc_loop(block): + def alloc(block): for reg in sorted(block.used_from_scope, key=lambda x: (x.reg_type, x.i)): allocator.alloc_reg(reg, block.alloc_pool) - for child in block.children: - if child.instructions: - alloc_loop(child) + def alloc_loop(block): + left = deque([block]) + while left: + block = left.popleft() + alloc(block) + for child in block.children: + if child.instructions: + left.append(child) for i,block in enumerate(reversed(self.basicblocks)): if len(block.instructions) > 10000: - print 'Allocating %s, %d/%d' % \ - (block.name, i, len(self.basicblocks)) + print('Allocating %s, %d/%d' % \ + (block.name, i, len(self.basicblocks))) if block.exit_condition is not None: jump = block.exit_condition.get_relative_jump() - if isinstance(jump, (int,long)) and jump < 0 and \ + if isinstance(jump, int) and jump < 0 and \ block.exit_block.scope is not None: alloc_loop(block.exit_block.scope) allocator.process(block.instructions, block.alloc_pool) # offline data requirements - print 'Compile offline data requirements...' + print('Compile offline data requirements...') self.req_num = self.req_tree.aggregate() - print 'Tape requires', self.req_num + print('Tape requires', self.req_num) for req,num in sorted(self.req_num.items()): if num == float('inf') or num >= 2 ** 32: num = -1 @@ -706,8 +713,8 @@ def alloc_loop(block): Compiler.instructions.reqbl(bl, add_to_prog=False)) if self.program.verbose: - print 'Tape requires prime bit length', self.req_bit_length['p'] - print 'Tape requires galois bit length', self.req_bit_length['2'] + print('Tape requires prime bit length', self.req_bit_length['p']) + print('Tape requires galois bit length', self.req_bit_length['2']) @unpurged def _get_instructions(self): @@ -722,12 +729,12 @@ def get_encoding(self): @unpurged def get_bytes(self): """ Get the byte encoding of the program as an actual string of bytes. """ - return "".join(str(i.get_bytes()) for i in self._get_instructions() if i is not None) + return b"".join(i.get_bytes() for i in self._get_instructions() if i is not None) @unpurged def write_encoding(self, filename): """ Write the readable encoding to a file. """ - print 'Writing to', filename + print('Writing to', filename) f = open(filename, 'w') for line in self.get_encoding(): f.write(str(line) + '\n') @@ -736,7 +743,7 @@ def write_encoding(self, filename): @unpurged def write_str(self, filename): """ Write the sequence of instructions to a file. """ - print 'Writing to', filename + print('Writing to', filename) f = open(filename, 'w') n = 0 for block in self.basicblocks: @@ -756,8 +763,8 @@ def write_bytes(self, filename=None): filename += '.bc' if not 'Bytecode' in filename: filename = self.program.programs_dir + '/Bytecode/' + filename - print 'Writing to', filename - f = open(filename, 'w') + print('Writing to', filename) + f = open(filename, 'wb') f.write(self.get_bytes()) f.close() @@ -785,9 +792,9 @@ def __init__(self, init={}): super(Tape.ReqNum, self).__init__(lambda: 0, init) def __add__(self, other): res = Tape.ReqNum() - for i,count in self.items(): + for i,count in list(self.items()): res[i] += count - for i,count in other.items(): + for i,count in list(other.items()): res[i] += count return res def __mul__(self, other): @@ -798,7 +805,7 @@ def __mul__(self, other): __rmul__ = __mul__ def set_all(self, value): if value == float('inf') and self['all', 'inv'] > 0: - print 'Going to unknown from %s' % self + print('Going to unknown from %s' % self) res = Tape.ReqNum() for i in self: res[i] = value @@ -811,14 +818,14 @@ def max(self, other): res[i] = max(self[i], other[i]) return res def cost(self): - return sum(num * COST[req[0]][req[1]] for req,num in self.items() \ + return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \ if req[1] != 'input') def __str__(self): return ", ".join('%s inputs in %s from player %d' \ % (num, req[0], req[2]) \ if req[1] == 'input' \ else '%s %ss in %s' % (num, req[1], req[0]) \ - for req,num in self.items()) + for req,num in list(self.items())) def __repr__(self): return repr(dict(self)) @@ -853,8 +860,8 @@ def aggregate(self, name): n_rounds = res['all', 'round'] n_invs = res['all', 'inv'] if (n_invs / n_rounds) * 1000 < n_reps: - print self.nodes[0].blocks[0].name, 'blowing up rounds: ', \ - '(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs) + print(self.nodes[0].blocks[0].name, 'blowing up rounds: ', \ + '(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs)) except: pass return res @@ -892,7 +899,7 @@ class Register(object): The 'value' property is for emulation. """ - __slots__ = ["reg_type", "program", "i", "value", "_is_active", \ + __slots__ = ["reg_type", "program", "i", "_is_active", \ "size", "vector", "vectorbase", "caller", \ "can_eliminate"] @@ -925,7 +932,7 @@ def __init__(self, reg_type, program, value=None, size=None, i=None): else: self.caller = None if self.i % 1000000 == 0 and self.i > 0: - print "Initialized %d registers at" % self.i, time.asctime() + print("Initialized %d registers at" % self.i, time.asctime()) def set_size(self, size): if self.size == size: diff --git a/Compiler/types.py b/Compiler/types.py index 3fa5992a2..f72690a7b 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2,11 +2,12 @@ from Compiler.exceptions import * from Compiler.instructions import * from Compiler.instructions_base import * -from floatingpoint import two_power -import comparison, floatingpoint +from .floatingpoint import two_power +from . import comparison, floatingpoint import math -import util +from . import util import operator +from functools import reduce class ClientMessageType: @@ -127,15 +128,15 @@ def square(self): return self * self def __add__(self, other): - if other is 0 or other is 0L: + if other is 0: return self else: return self.add(other) def __mul__(self, other): - if other is 0 or other is 0L: + if other is 0: return 0 - elif other is 1 or other is 1L: + elif other is 1: return self else: return self.mul(other) @@ -301,7 +302,7 @@ def __init__(self, reg_type, val, size): if isinstance(val, (tuple, list)): size = len(val) super(_register, self).__init__(reg_type, program.curr_tape, size=size) - if isinstance(val, (int, long)): + if isinstance(val, int): self.load_int(val) elif isinstance(val, (tuple, list)): for i, x in enumerate(val): @@ -374,7 +375,7 @@ def clear_op(self, other, c_inst, ci_inst, reverse=False): res = self.prep_res(other) if isinstance(other, cls): c_inst(res, self, other) - elif isinstance(other, (int, long)): + elif isinstance(other, int): if self.in_immediate_range(other): ci_inst(res, self, other) else: @@ -392,7 +393,7 @@ def clear_op(self, other, c_inst, ci_inst, reverse=False): def coerce_op(self, other, inst, reverse=False): cls = self.__class__ res = cls() - if isinstance(other, (int, long)): + if isinstance(other, int): other = cls(other) elif not isinstance(other, cls): return NotImplemented @@ -414,14 +415,14 @@ def __sub__(self, other): def __rsub__(self, other): return self.clear_op(other, subc, subcfi, True) - def __div__(self, other): + def __truediv__(self, other): return self.clear_op(other, divc, divci) - def __rdiv__(self, other): + def __rtruediv__(self, other): return self.coerce_op(other, divc, True) def __eq__(self, other): - if isinstance(other, (_clear,int,long)): + if isinstance(other, (_clear,int)): return regint(self) == other else: return NotImplemented @@ -493,12 +494,12 @@ def load_int(self, val): ldi(self, val) else: max = 2**31 - 1 - sign = abs(val) / val + sign = abs(val) // val val = abs(val) chunks = [] while val: mod = val % max - val = (val - mod) / max + val = (val - mod) // max chunks.append(mod) sum = cint(sign * chunks.pop()) for i,chunk in enumerate(reversed(chunks)): @@ -520,13 +521,13 @@ def __rmod__(self, other): return self.coerce_op(other, modc, True) def __lt__(self, other): - if isinstance(other, (type(self),int,long)): + if isinstance(other, (type(self),int)): return regint(self) < other else: return NotImplemented def __gt__(self, other): - if isinstance(other, (type(self),int,long)): + if isinstance(other, (type(self),int)): return regint(self) > other else: return NotImplemented @@ -537,6 +538,23 @@ def __le__(self, other): def __ge__(self, other): return 1 - (self < other) + @vectorize + def __eq__(self, other): + if not isinstance(other, (_clear, int)): + return NotImplemented + res = 1 + remaining = program.bit_length + while remaining > 0: + if isinstance(other, cint): + o = other.to_regint(min(remaining, 64)) + else: + o = other % 2 ** 64 + res *= (self.to_regint(min(remaining, 64)) == o) + self >>= 64 + other >>= 64 + remaining -= 64 + return res + def __lshift__(self, other): return self.clear_op(other, shlc, shlci) @@ -683,7 +701,7 @@ def __rshift__(self, other): def bit_decompose(self, bit_length=None, step=None): bit_length = bit_length or program.galois_length step = step or 1 - res = [type(self)() for _ in range(bit_length / step)] + res = [type(self)() for _ in range(bit_length // step)] gbitdec(self, step, *res) return res @@ -817,12 +835,15 @@ def mul(self, other): def __neg__(self): return 0 - self - def __div__(self, other): + def __floordiv__(self, other): return self.int_op(other, divint) - def __rdiv__(self, other): + def __rfloordiv__(self, other): return self.int_op(other, divint, True) + __truediv__ = __floordiv__ + __rtruediv__ = __rfloordiv__ + def __mod__(self, other): return self - (self / other) * other @@ -851,13 +872,13 @@ def __ge__(self, other): return 1 - (self < other) def __lshift__(self, other): - if isinstance(other, (int, long)): + if isinstance(other, int): return self * 2**other else: return regint(cint(self) << other) def __rshift__(self, other): - if isinstance(other, (int, long)): + if isinstance(other, int): return self / 2**other else: return regint(cint(self) >> other) @@ -911,6 +932,24 @@ def print_reg_plain(self): def print_if(self, string): cint(self).print_if(string) +class localint(object): + """ Local integer that must prevented from leaking into the secure + computation. Uses regint internally. """ + + def __init__(self, value=None): + self._v = regint(value) + self.size = 1 + + def output(self): + self._v.print_reg_plain() + + __lt__ = lambda self, other: localint(self._v < other) + __le__ = lambda self, other: localint(self._v <= other) + __gt__ = lambda self, other: localint(self._v > other) + __ge__ = lambda self, other: localint(self._v >= other) + __eq__ = lambda self, other: localint(self._v == other) + __ne__ = lambda self, other: localint(self._v != other) + class _secret(_register): __slots__ = [] @@ -996,11 +1035,11 @@ def row_matrix_mul(cls, row, matrix, res_params=None): def matrix_mul(cls, A, B, n, res_params=None): assert len(A) % n == 0 assert len(B) % n == 0 - size = len(A) * len(B) / n**2 + size = len(A) * len(B) // n**2 res = cls(size=size) - n_rows = len(A) / n - n_cols = len(B) / n - dotprods(*sum(([res[j], [A[j / n_cols * n + k] for k in range(n)], + n_rows = len(A) // n + n_cols = len(B) // n + dotprods(*sum(([res[j], [A[j // n_cols * n + k] for k in range(n)], [B[k * n_cols + j % n_cols] for k in range(n)]] for j in range(size)), [])) return res @@ -1054,7 +1093,7 @@ def secret_op(self, other, s_inst, m_inst, si_inst, reverse=False): m_inst(res, other, self) else: m_inst(res, self, other) - elif isinstance(other, (int, long)): + elif isinstance(other, int): if self.clear_type.in_immediate_range(other): si_inst(res, self, other) else: @@ -1086,11 +1125,11 @@ def __rsub__(self, other): return self.secret_op(other, subs, submr, subsfi, True) @vectorize - def __div__(self, other): + def __truediv__(self, other): return self * (self.clear_type(1) / other) @vectorize - def __rdiv__(self, other): + def __rtruediv__(self, other): a,b = self.get_random_inverse() return other * a / (a * self).reveal() @@ -1253,7 +1292,7 @@ def __ne__(self, other, bit_length=None, security=None): @vectorize def __mod__(self, modulus): - if isinstance(modulus, (int, long)): + if isinstance(modulus, int): l = math.log(modulus, 2) if 2**int(round(l)) == modulus: return self.mod2m(int(l)) @@ -1405,7 +1444,7 @@ def __invert__(self): return self ^ cgf2n(2**program.galois_length - 1) def __xor__(self, other): - if other is 0 or other is 0L: + if other is 0: return self else: return super(sgf2n, self).add(other) @@ -1414,7 +1453,7 @@ def __xor__(self, other): @vectorize def __and__(self, other): - if isinstance(other, (int, long)): + if isinstance(other, int): other_bits = [(other >> i) & 1 \ for i in range(program.galois_length)] else: @@ -1515,7 +1554,7 @@ def carry_lookahead_adder(cls, a, b, fewer_inv=False, carry_in=0, else: pre_op = floatingpoint.PreOpL if d: - carries = zip(*pre_op(carry, [(0, carry_in)] + d))[1] + carries = list(zip(*pre_op(carry, [(0, carry_in)] + d)))[1] else: carries = [] res = lower + cls.sum_from_carries(a, b, carries) @@ -1539,7 +1578,7 @@ def carry_select_adder(cls, a, b, get_carry=False, carry_in=0): for k in range(m, -1, -1): if sum(range(m, k - 1, -1)) + 1 >= n: break - blocks = range(m, k, -1) + blocks = list(range(m, k, -1)) blocks.append(n - sum(blocks)) blocks.reverse() #print 'blocks:', blocks @@ -1597,9 +1636,9 @@ def bit_less_than(cls, a, b): @staticmethod def get_highest_different_bits(a, b, index): - diff = [ai + bi for (ai,bi) in reversed(zip(a,b))] + diff = [ai + bi for (ai,bi) in reversed(list(zip(a,b)))] preor = floatingpoint.PreOR(diff, raw=True) - highest_diff = [x - y for (x,y) in reversed(zip(preor, [0] + preor))] + highest_diff = [x - y for (x,y) in reversed(list(zip(preor, [0] + preor)))] raw = sum(map(operator.mul, highest_diff, (a,b)[index])) return raw.bit_decompose()[0] @@ -1622,7 +1661,7 @@ def mul(self, other): if type(other) == self.bin_type: raise CompilerError('Unclear multiplication') self_bits = self.bit_decompose() - if isinstance(other, (int, long)): + if isinstance(other, int): other_bits = util.bit_decompose(other, self.n_bits) bit_matrix = [[x * y for y in self_bits] for x in other_bits] else: @@ -1644,8 +1683,8 @@ def mul(self, other): @classmethod def wallace_tree_from_matrix(cls, bit_matrix, get_carry=True): - columns = [filter(None, (bit_matrix[j][i-j] \ - for j in range(min(len(bit_matrix), i + 1)))) \ + columns = [[_f for _f in (bit_matrix[j][i-j] \ + for j in range(min(len(bit_matrix), i + 1))) if _f] \ for i in range(len(bit_matrix[0]))] return cls.wallace_tree_from_columns(columns, get_carry) @@ -1671,7 +1710,7 @@ def wallace_tree_from_columns(cls, columns, get_carry=True): columns = new_columns[:-1] for col in columns: col.extend([0] * (2 - len(col))) - return self.bit_adder(*zip(*columns)) + return self.bit_adder(*list(zip(*columns))) @classmethod def wallace_tree(cls, rows): @@ -1685,17 +1724,17 @@ def __sub__(self, other): d = [(1 + ai + bi, (1 - ai) * bi) for (ai,bi) in zip(a,b)] borrow = lambda y,x,*args: \ (x[0] * y[0], 1 - (1 - x[1]) * (1 - x[0] * y[1])) - borrows = (0,) + zip(*floatingpoint.PreOpL(borrow, d))[1] + borrows = (0,) + list(zip(*floatingpoint.PreOpL(borrow, d)))[1] return self.compose(ai + bi + borrow \ for (ai,bi,borrow) in zip(a,b,borrows)) def __rsub__(self, other): raise NotImplementedError() - def __div__(self, other): + def __truediv__(self, other): raise NotImplementedError() - def __rdiv__(self, other): + def __truerdiv__(self, other): raise NotImplementedError() def __lshift__(self, other): @@ -1953,7 +1992,7 @@ class cfix(_number, _structure): """ Clear fixed point type. """ __slots__ = ['value', 'f', 'k', 'size'] reg_type = 'c' - scalars = (int, long, float, regint) + scalars = (int, float, regint) @classmethod def set_precision(cls, f, k = None): # k is the whole bitlength of fixed point @@ -1978,7 +2017,7 @@ def read_from_socket(cls, client_id, n=1): if n == 1: return cfix(cint_inputs) else: - return map(cfix, cint_inputs) + return list(map(cfix, cint_inputs)) @vectorize def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): @@ -1990,7 +2029,7 @@ def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoTy """ Send a list of cfix values to socket. Values are sent as bit shifted cints. """ def cfix_to_cint(fix_val): return cint(fix_val.v) - cint_values = map(cfix_to_cint, values) + cint_values = list(map(cfix_to_cint, values)) writesocketc(client_id, message_type, *cint_values) @staticmethod @@ -2152,7 +2191,7 @@ def __ne__(self, other): raise NotImplementedError @vectorize - def __div__(self, other): + def __truediv__(self, other): other = parse_type(other) if isinstance(other, cfix): return cfix(library.cint_cint_division(self.v, other.v, self.k, self.f)) @@ -2190,7 +2229,7 @@ def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType """ Securely obtain shares of n values input by a client. Assumes client has already run bit shift to convert fixed point to integer.""" sint_inputs = cls.int_type.receive_from_client(n, client_id, ClientMessageType.TripleShares) - return map(cls, sint_inputs) + return list(map(cls, sint_inputs)) @vectorized_classmethod def load_mem(cls, address, mem_type=None): @@ -2333,7 +2372,7 @@ def set_precision(cls, f, k = None): @classmethod def coerce(cls, other): - if isinstance(other, (_fix, cfix)): + if isinstance(other, (_fix, cls.clear_type)): return other else: return cls.conv(other) @@ -2402,7 +2441,7 @@ def add(self, other): @vectorize def mul(self, other): - if isinstance(other, (sint, cint, regint, int, long)): + if isinstance(other, (sint, cint, regint, int)): return self._new(self.v * other, k=self.k, f=self.f) elif isinstance(other, float): if int(other) == other: @@ -2413,13 +2452,11 @@ def mul(self, other): f = self.f while v % 2 == 0: f -= 1 - v /= 2 + v //= 2 k = len(bin(abs(v))) - 1 - other = cfix(cint(v)) - other.f = f - other.k = k + other = self.multipliable(v, k, f) other = self.coerce(other) - if isinstance(other, (_fix, cfix)): + if isinstance(other, (_fix, self.clear_type)): val = self.v.TruncMul(other.v, self.k + other.k, other.f, self.kappa, self.round_nearest) @@ -2438,7 +2475,7 @@ def __neg__(self): return type(self)(-self.v) @vectorize - def __div__(self, other): + def __truediv__(self, other): other = self.coerce(other) if isinstance(other, _fix): return type(self)(library.FPDiv(self.v, other.v, self.k, self.f, @@ -2450,7 +2487,7 @@ def __div__(self, other): raise TypeError('Incompatible fixed point types in division') @vectorize - def __rdiv__(self, other): + def __rtruediv__(self, other): return self.coerce(other) / self @vectorize @@ -2497,6 +2534,10 @@ def pre_mul(self): def unreduced(self, v, other=None, res_params=None, n_summands=1): return unreduced_sfix(v, self.k * 2, self.f, self.kappa) + @staticmethod + def multipliable(v, k, f): + return cfix(cint.conv(v), k, f) + class unreduced_sfix(_single): int_type = sint @@ -2511,7 +2552,7 @@ def __init__(self, v, k, m, kappa): self.kappa = kappa def __add__(self, other): - if other is 0 or other is 0L: + if other is 0: return self assert self.k == other.k assert self.m == other.m @@ -2643,7 +2684,7 @@ def __init__(self, v, params, res_params=None, n_summands=1): self.res_params = res_params or params[0] def __add__(self, other): - if other is 0 or other is 0L: + if other is 0: return self assert self.params == other.params assert self.res_params == other.res_params @@ -2807,10 +2848,10 @@ def convert_float(v, vlen, plen): v = int(round(abs(v) * 2 ** (-p))) if v == 2 ** vlen: p += 1 - v /= 2 + v //= 2 z = 0 if p < -2 ** (plen - 1): - print 'Warning: %e truncated to zero' % vv + print('Warning: %e truncated to zero' % vv) v, p, z = 0, 0, 1 if p >= 2 ** (plen - 1): raise CompilerError('Cannot convert %s to float ' \ @@ -2950,8 +2991,8 @@ def add(self, other): v = t u = floatingpoint.BitDec(v, self.vlen + 2 + sfloat.round_nearest, self.vlen + 2 + sfloat.round_nearest, self.kappa, - range(1 + sfloat.round_nearest, - self.vlen + 2 + sfloat.round_nearest)) + list(range(1 + sfloat.round_nearest, + self.vlen + 2 + sfloat.round_nearest))) # using u[0] doesn't seem necessary h = floatingpoint.PreOR(u[:sfloat.round_nearest:-1], self.kappa) p0 = self.vlen + 1 - sum(h) @@ -3013,7 +3054,7 @@ def __sub__(self, other): def __rsub__(self, other): return -self + other - def __div__(self, other): + def __truediv__(self, other): other = self.conv(other) v = floatingpoint.SDiv(self.v, other.v + other.z * (2**self.vlen - 1), self.vlen, self.kappa, self.round_nearest) @@ -3029,21 +3070,16 @@ def __div__(self, other): sfloat.set_error(other.z) return sfloat(v, p, z, s) - def __rdiv__(self, other): + def __rtruediv__(self, other): return self.conv(other) / self @vectorize def __neg__(self): return sfloat(self.v, self.p, self.z, (1 - self.s) * (1 - self.z)) - def __abs__(self): - if self.s: - return -self - else: - return self - @vectorize def __lt__(self, other): + other = self.conv(other) if isinstance(other, sfloat): z1 = self.z z2 = other.z @@ -3066,8 +3102,15 @@ def __lt__(self, other): def __ge__(self, other): return 1 - (self < other) + def __gt__(self, other): + return self.conv(other) < self + + def __le__(self, other): + return self.conv(other) >= self + @vectorize def __eq__(self, other): + other = self.conv(other) # the sign can be both ways for zeroes both_zero = self.z * other.z return floatingpoint.EQZ(self.v - other.v, self.vlen, self.kappa) * \ @@ -3151,24 +3194,25 @@ def delete(self): program.free(self.address, self.value_type.reg_type) def get_address(self, index): + key = str(index) if isinstance(index, int) and self.length is not None: index += self.length * (index < 0) if index >= self.length or index < 0: raise IndexError('index %s, length %s' % \ (str(index), str(self.length))) - if (program.curr_block, index) not in self.address_cache: + if (program.curr_block, key) not in self.address_cache: n = self.value_type.n_elements() length = self.length if n == 1: # length can be None for single-element arrays length = 0 - self.address_cache[program.curr_block, index] = \ + self.address_cache[program.curr_block, key] = \ util.untuplify([self.address + index + i * length \ for i in range(n)]) if self.debug: library.print_ln_if(index >= self.length, 'OF:' + self.debug) - library.print_ln_if(self.address_cache[program.curr_block, index] >= program.allocated_mem[self.value_type.reg_type], 'AOF:' + self.debug) - return self.address_cache[program.curr_block, index] + library.print_ln_if(self.address_cache[program.curr_block, key] >= program.allocated_mem[self.value_type.reg_type], 'AOF:' + self.debug) + return self.address_cache[program.curr_block, key] def get_slice(self, index): if index.stop is None and self.length is None: @@ -3178,7 +3222,7 @@ def get_slice(self, index): def __getitem__(self, index): if isinstance(index, slice): start, stop, step = self.get_slice(index) - res_length = (stop - start - 1) / step + 1 + res_length = (stop - start - 1) // step + 1 res = Array(res_length, self.value_type) @library.for_range(res_length) def f(i): @@ -3303,7 +3347,7 @@ def __init__(self, sizes, value_type, address, index, debug=None): def __getitem__(self, index): if util.is_constant(index) and index >= self.sizes[0]: raise StopIteration - key = program.curr_block, index + key = program.curr_block, str(index) if key not in self.sub_cache: if self.debug: library.print_ln_if(index >= self.sizes[0], \ @@ -3531,7 +3575,7 @@ class _mem(_number): __add__ = lambda self,other: self.read() + other __sub__ = lambda self,other: self.read() - other __mul__ = lambda self,other: self.read() * other - __div__ = lambda self,other: self.read() / other + __truediv__ = lambda self,other: self.read() / other __mod__ = lambda self,other: self.read() % other __pow__ = lambda self,other: self.read() ** other __neg__ = lambda self,other: -self.read() @@ -3550,7 +3594,7 @@ class _mem(_number): __radd__ = lambda self,other: other + self.read() __rsub__ = lambda self,other: other - self.read() __rmul__ = lambda self,other: other * self.read() - __rdiv__ = lambda self,other: other / self.read() + __rtruediv__ = lambda self,other: other / self.read() __rmod__ = lambda self,other: other % self.read() __rand__ = lambda self,other: other & self.read() __rxor__ = lambda self,other: other ^ self.read() @@ -3627,7 +3671,7 @@ def write(self, value): self.check() if isinstance(value, MemValue): self.register = value.read() - elif isinstance(value, (int,long)): + elif isinstance(value, int): self.register = self.value_type(value) else: self.register = value @@ -3717,7 +3761,7 @@ def getNamedTupleType(*names): class NamedTuple(object): class NamedTupleArray(object): def __init__(self, size, t): - import types + from . import types self.arrays = [types.Array(size, t) for i in range(len(names))] def __getitem__(self, index): return NamedTuple(array[index] for array in self.arrays) @@ -3749,4 +3793,4 @@ def reveal(self): return self.__type__(x.reveal() for x in self) return NamedTuple -import library +from . import library diff --git a/Compiler/util.py b/Compiler/util.py index 403a81fa1..8b7ea214a 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -1,5 +1,6 @@ import math import operator +from functools import reduce def format_trace(trace, prefix=' '): if trace is None: @@ -46,7 +47,7 @@ def right_shift(a, b, bits): return a.right_shift(b, bits) def bit_decompose(a, bits): - if isinstance(a, (int,long)): + if isinstance(a, int): return [int((a >> i) & 1) for i in range(bits)] else: return a.bit_decompose(bits) @@ -82,7 +83,7 @@ def if_else(cond, a, b): else: return cond.if_else(a, b) except: - print cond, a, b + print(cond, a, b) raise def cond_swap(cond, a, b): @@ -112,8 +113,8 @@ def tree_reduce(function, sequence): if n == 1: return sequence[0] else: - reduced = [function(sequence[2*i], sequence[2*i+1]) for i in range(n/2)] - return tree_reduce(function, reduced + sequence[n/2*2:]) + reduced = [function(sequence[2*i], sequence[2*i+1]) for i in range(n//2)] + return tree_reduce(function, reduced + sequence[n//2*2:]) def or_op(a, b): return a + b - a * b @@ -144,7 +145,7 @@ def reveal(x): return x def is_constant(x): - return isinstance(x, (int, long, bool)) + return isinstance(x, (int, bool)) def is_constant_float(x): return isinstance(x, float) or is_constant(x) @@ -180,3 +181,44 @@ def expand(x, size): return x.expand_to_vector(size) except AttributeError: return x + +class set_by_id(object): + def __init__(self, init=[]): + self.content = {} + for x in init: + self.add(x) + + def __contains__(self, value): + return id(value) in self.content + + def __iter__(self): + return iter(self.content.values()) + + def add(self, value): + self.content[id(value)] = value + +class dict_by_id(object): + def __init__(self): + self.content = {} + + def __contains__(self, key): + return id(key) in self.content + + def __getitem__(self, key): + return self.content[id(key)][1] + + def __setitem__(self, key, value): + self.content[id(key)] = (key, value) + + def keys(self): + return (x[0] for x in self.content.values()) + +class defaultdict_by_id(dict_by_id): + def __init__(self, default): + dict_by_id.__init__(self) + self.default = default + + def __getitem__(self, key): + if key not in self: + self[key] = self.default() + return dict_by_id.__getitem__(self, key) diff --git a/ECDSA/EcdsaOptions.h b/ECDSA/EcdsaOptions.h new file mode 100644 index 000000000..619aaa032 --- /dev/null +++ b/ECDSA/EcdsaOptions.h @@ -0,0 +1,66 @@ +/* + * EcdsaOptions.h + * + */ + +#ifndef ECDSA_ECDSAOPTIONS_H_ +#define ECDSA_ECDSAOPTIONS_H_ + +#include "Tools/ezOptionParser.h" + +class EcdsaOptions +{ +public: + bool prep_mul; + bool fewer_rounds; + bool check_open; + bool check_beaver_open; + + EcdsaOptions(ez::ezOptionParser& opt, int argc, const char** argv) + { + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Delay multiplication until signing", // Help description. + "-D", // Flag token. + "--delay-multiplication" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Fewer rounds, more EC", // Help description. + "-P", // Flag token. + "--parallel-open" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Skip checking final openings (but not necessarily openings for Beaver; only relevant with active protocols)", // Help description. + "-C", // Flag token. + "--no-open-check" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Skip checking Beaver openings (only relevant with active protocols)", // Help description. + "-B", // Flag token. + "--no-beaver-open-check" // Flag token. + ); + opt.parse(argc, argv); + prep_mul = not opt.isSet("-D"); + fewer_rounds = opt.isSet("-P"); + check_open = not opt.isSet("-C"); + check_beaver_open = not opt.isSet("-B"); + opt.resetArgs(); + } +}; + +#endif /* ECDSA_ECDSAOPTIONS_H_ */ diff --git a/ECDSA/Fake-ECDSA.cpp b/ECDSA/Fake-ECDSA.cpp index 6ddc519c6..bf2c544ff 100644 --- a/ECDSA/Fake-ECDSA.cpp +++ b/ECDSA/Fake-ECDSA.cpp @@ -18,7 +18,7 @@ int main() string prefix = PREP_DIR "ECDSA/"; mkdir_p(prefix.c_str()); ofstream outf; - write_online_setup(outf, prefix, P256Element::Scalar::pr(), 0, false); + write_online_setup_without_init(outf, prefix, P256Element::Scalar::pr(), 0); generate_mac_keys>(key, key2, 2, prefix); make_mult_triples>(key, 2, 1000, false, prefix); make_inverse>(key, 2, 1000, false, prefix); diff --git a/ECDSA/P256Element.cpp b/ECDSA/P256Element.cpp index 0d58db435..d0828ec0c 100644 --- a/ECDSA/P256Element.cpp +++ b/ECDSA/P256Element.cpp @@ -77,6 +77,14 @@ P256Element& P256Element::operator +=(const P256Element& other) return *this; } +P256Element& P256Element::operator /=(const Scalar& other) +{ + auto tmp = other; + tmp.invert(); + *this = *this * tmp; + return *this; +} + bool P256Element::operator ==(const P256Element& other) const { return point == other.point; diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h index 603a7e0d0..adb38a9d1 100644 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -10,14 +10,10 @@ #include "Math/gfp.h" -#if GFP_MOD_SZ != 4 -#error GFP_MOD_SZ must be 4 -#endif - class P256Element : public ValueInterface { public: - typedef gfp Scalar; + typedef gfp_<2, 4> Scalar; private: static CryptoPP::DL_GroupParameters_EC params; @@ -52,6 +48,7 @@ class P256Element : public ValueInterface P256Element operator*(const Scalar& other) const; P256Element& operator+=(const P256Element& other); + P256Element& operator/=(const Scalar& other); bool operator==(const P256Element& other) const; bool operator!=(const P256Element& other) const; diff --git a/ECDSA/README.md b/ECDSA/README.md index d7db61910..5d1db349f 100644 --- a/ECDSA/README.md +++ b/ECDSA/README.md @@ -5,8 +5,7 @@ in `preprocessing.hpp` and `sign.hpp`, respectively. #### Compilation -- Add `MOD = -DGFP_MOD_SZ=4` to `CONFIG.mine` -- Also consider adding either `CXX = clang++` or `OPTIM = -O2` because GCC 8 or later with `-O3` will produce a segfault when using `mascot-ecdsa-party.x` +- Add either `CXX = clang++` or `OPTIM = -O2` because GCC 8 or later with `-O3` will produce a segfault when using `mascot-ecdsa-party.x` - For older hardware, also add `ARCH = -march=native` - Install [Crypto++](https://www.cryptopp.com) (`libcrypto++-dev` on Ubuntu). We used version 5.6.4, which is the default on Ubuntu 18.04. - Compile the binaries: `make -j8 ecdsa` diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp index 9a703f376..bea4db5bb 100644 --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -23,6 +23,7 @@ int main(int argc, const char** argv) { ez::ezOptionParser opt; + EcdsaOptions opts(opt, argc, argv); Names N(opt, argc, argv, 2); int n_tuples = 1000; if (not opt.lastArgs.empty()) @@ -38,7 +39,7 @@ int main(int argc, const char** argv) typedef Share pShare; DataPositions usage; Sub_Data_Files prep(N, prefix, usage); - typename pShare::MAC_Check MCp(keyp); + typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); SubProcessor proc(_, MCp, prep, P); @@ -46,7 +47,7 @@ int main(int argc, const char** argv) proc.DataF.get_two(DATA_INVERSE, sk, __); vector> tuples; - preprocessing(tuples, n_tuples, sk, proc); + preprocessing(tuples, n_tuples, sk, proc, opts); check(tuples, sk, keyp, P); - sign_benchmark(tuples, sk, MCp, P); + sign_benchmark(tuples, sk, MCp, P, opts); } diff --git a/ECDSA/hm-ecdsa-party.hpp b/ECDSA/hm-ecdsa-party.hpp index 174ec672d..3ab383a80 100644 --- a/ECDSA/hm-ecdsa-party.hpp +++ b/ECDSA/hm-ecdsa-party.hpp @@ -29,15 +29,7 @@ void run(int argc, const char** argv) { bigint::init_thread(); ez::ezOptionParser opt; - opt.add( - "", // Default. - 0, // Required? - 0, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Delay multiplication until signing", // Help description. - "-D", // Flag token. - "--delay-multiplication" // Flag token. - ); + EcdsaOptions opts(opt, argc, argv); Names N(opt, argc, argv, 3); int n_tuples = 1000; if (not opt.lastArgs.empty()) @@ -51,10 +43,12 @@ void run(int argc, const char** argv) P.Broadcast_Receive(bundle, false); Timer timer; timer.start(); + auto stats = P.comm_stats; pShare sk = typename T::Honest::Protocol(P).get_random(); cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl; + (P.comm_stats - stats).print(true); - OnlineOptions::singleton.batch_size = n_tuples; + OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples; DataPositions usage; auto& prep = *Preprocessing::get_live_prep(0, usage); typename pShare::MAC_Check MCp; @@ -63,9 +57,9 @@ void run(int argc, const char** argv) bool prep_mul = not opt.isSet("-D"); vector> tuples; - preprocessing(tuples, n_tuples, sk, proc, prep_mul); + preprocessing(tuples, n_tuples, sk, proc, opts); // check(tuples, sk, {}, P); - sign_benchmark(tuples, sk, MCp, P, prep_mul ? 0 : &proc); + sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc); delete &prep; } diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index 998657dbc..f9cc9c9fc 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -25,27 +25,72 @@ template class T> void run(int argc, const char** argv) { ez::ezOptionParser opt; + EcdsaOptions opts(opt, argc, argv); opt.add( "", // Default. 0, // Required? 0, // Number of args expected. 0, // Delimiter if expecting multiple args. - "Delay multiplication until signing", // Help description. - "-D", // Flag token. - "--delay-multiplication" // Flag token. + "Use SimpleOT instead of OT extension", // Help description. + "-S", // Flag token. + "--simple-ot" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Don't check correlation in OT extension (only relevant with MASCOT)", // Help description. + "-U", // Flag token. + "--unchecked-correlation" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Fewer rounds for authentication (only relevant with MASCOT)", // Help description. + "-A", // Flag token. + "--auth-fewer-rounds" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Use Fiat-Shamir for amplification (only relevant with MASCOT)", // Help description. + "-H", // Flag token. + "--fiat-shamir" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Skip sacrifice (only relevant with MASCOT)", // Help description. + "-E", // Flag token. + "--embrace-life" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "No MACs (only relevant with MASCOT; implies skipping MAC checks)", // Help description. + "-M", // Flag token. + "--no-macs" // Flag token. + ); + Names N(opt, argc, argv, 2); int n_tuples = 1000; if (not opt.lastArgs.empty()) n_tuples = atoi(opt.lastArgs[0]->c_str()); PlainPlayer P(N); P256Element::init(); - gfp1::init_field(P256Element::Scalar::pr(), false); + P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false); BaseMachine machine; - machine.ot_setups.resize(1); - for (int i = 0; i < 2; i++) - machine.ot_setups[0].push_back({P, true}); + machine.ot_setups.push_back({P, true}); P256Element::Scalar keyp; SeededPRNG G; @@ -65,16 +110,28 @@ void run(int argc, const char** argv) P.Broadcast_Receive(bundle, false); Timer timer; timer.start(); + auto stats = P.comm_stats; sk_prep.get_two(DATA_INVERSE, sk, __); cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl; + (P.comm_stats - stats).print(true); - OnlineOptions::singleton.batch_size = n_tuples; + OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples; typename pShare::LivePrep prep(0, usage); + prep.params.correlation_check &= not opt.isSet("-U"); + prep.params.fewer_rounds = opt.isSet("-A"); + prep.params.fiat_shamir = opt.isSet("-H"); + prep.params.check = not opt.isSet("-E"); + prep.params.generateMACs = not opt.isSet("-M"); + opts.check_beaver_open &= prep.params.generateMACs; + opts.check_open &= prep.params.generateMACs; SubProcessor proc(_, MCp, prep, P); + typename pShare::prep_type::Direct_MC MCpp(keyp); + prep.triple_generator->MC = &MCpp; bool prep_mul = not opt.isSet("-D"); + prep.params.use_extension = not opt.isSet("-S"); vector> tuples; - preprocessing(tuples, n_tuples, sk, proc, prep_mul); + preprocessing(tuples, n_tuples, sk, proc, opts); //check(tuples, sk, keyp, P); - sign_benchmark(tuples, sk, MCp, P, prep_mul ? 0 : &proc); + sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc); } diff --git a/ECDSA/preprocessing.hpp b/ECDSA/preprocessing.hpp index 10306dafb..449f9331f 100644 --- a/ECDSA/preprocessing.hpp +++ b/ECDSA/preprocessing.hpp @@ -7,6 +7,7 @@ #define ECDSA_PREPROCESSING_HPP_ #include "P256Element.h" +#include "EcdsaOptions.h" #include "Processor/Data_Files.h" #include "Protocols/ReplicatedPrep.h" #include "Protocols/MaliciousShamirShare.h" @@ -23,40 +24,66 @@ class EcTuple template class T> void preprocessing(vector>& tuples, int buffer_size, T& sk, - SubProcessor>& proc, bool prep_mul = true) + SubProcessor>& proc, + EcdsaOptions opts) { + bool prep_mul = opts.prep_mul; Timer timer; timer.start(); Player& P = proc.P; auto& prep = proc.DataF; size_t start = P.sent + prep.data_sent(); + auto stats = P.comm_stats + prep.comm_stats(); + auto& extra_player = P; + auto& protocol = proc.protocol; auto& MCp = proc.MC; typedef T pShare; typedef T cShare; vector inv_ks; vector secret_Rs; + prep.buffer_triples(); + vector bs, cs; for (int i = 0; i < buffer_size; i++) { - pShare a, a_inv; - prep.get_two(DATA_INVERSE, a, a_inv); - inv_ks.push_back(a_inv); - secret_Rs.push_back(a); + pShare a, b, c; + prep.get_three(DATA_TRIPLE, a, b, c); + inv_ks.push_back(a); + bs.push_back(b); + cs.push_back(c); + } + vector cs_opened; + MCp.POpen_Begin(cs_opened, cs, extra_player); + if (opts.fewer_rounds) + secret_Rs.insert(secret_Rs.begin(), bs.begin(), bs.end()); + else + { + MCp.POpen_End(cs_opened, cs, extra_player); + for (int i = 0; i < buffer_size; i++) + secret_Rs.push_back(bs[i] / cs_opened[i]); } + vector opened_Rs; + typename cShare::Direct_MC MCc(MCp.get_alphai()); + MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player); if (prep_mul) { protocol.init_mul(&proc); for (int i = 0; i < buffer_size; i++) protocol.prepare_mul(inv_ks[i], sk); - protocol.exchange(); + protocol.start_exchange(); } - else - prep.buffer_triples(); - vector opened_Rs; - typename cShare::MAC_Check MCc(MCp.get_alphai()); - MCc.POpen(opened_Rs, secret_Rs, P); - MCc.Check(P); - MCp.Check(P); + if (opts.fewer_rounds) + MCp.POpen_End(cs_opened, cs, extra_player); + MCc.POpen_End(opened_Rs, secret_Rs, extra_player); + if (opts.fewer_rounds) + for (int i = 0; i < buffer_size; i++) + opened_Rs[i] /= cs_opened[i]; + if (prep_mul) + protocol.stop_exchange(); + if (opts.check_open) + MCc.Check(extra_player); + if (opts.check_open or opts.check_beaver_open) + MCp.Check(extra_player); for (int i = 0; i < buffer_size; i++) { tuples.push_back( @@ -68,6 +95,7 @@ void preprocessing(vector>& tuples, int buffer_size, << " seconds, throughput " << buffer_size / timer.elapsed() << ", " << 1e-3 * (P.sent + prep.data_sent() - start) / buffer_size << " kbytes per tuple" << endl; + (P.comm_stats + prep.comm_stats() - stats).print(true); } template class T> diff --git a/ECDSA/rep-ecdsa-party.cpp b/ECDSA/rep-ecdsa-party.cpp index 1678f1de1..f1ccc5729 100644 --- a/ECDSA/rep-ecdsa-party.cpp +++ b/ECDSA/rep-ecdsa-party.cpp @@ -8,10 +8,10 @@ #include "hm-ecdsa-party.hpp" template<> -Preprocessing>* Preprocessing>::get_live_prep( - SubProcessor>* proc, DataPositions& usage) +Preprocessing>* Preprocessing>::get_live_prep( + SubProcessor>* proc, DataPositions& usage) { - return new ReplicatedPrep>(proc, usage); + return new ReplicatedPrep>(proc, usage); } int main(int argc, const char** argv) diff --git a/ECDSA/sign.hpp b/ECDSA/sign.hpp index c0e8c338c..f6f4d6631 100644 --- a/ECDSA/sign.hpp +++ b/ECDSA/sign.hpp @@ -71,12 +71,12 @@ EcSignature sign(const unsigned char* message, size_t length, prod = protocol.finalize_mul(); } auto rx = tuple.R.x(); - signature.s = MC.POpen( + signature.s = MC.open( tuple.a * hash_to_scalar(message, length) + prod * rx, P); cout << "Minimal signing took " << timer.elapsed() * 1e3 << " ms and sending " << (P.sent - start) << " bytes" << endl; auto diff = (P.comm_stats - stats); - diff.print(); + diff.print(true); return signature; } @@ -112,26 +112,30 @@ void check(EcSignature signature, const unsigned char* message, size_t length, template class T> void sign_benchmark(vector>& tuples, T sk, typename T::MAC_Check& MCp, Player& P, + EcdsaOptions& opts, SubProcessor>* proc = 0) { unsigned char message[1024]; GlobalPRNG(P).get_octets(message, 1024); - typename T::MAC_Check MCc(MCp.get_alphai()); + typename T::Direct_MC MCc(MCp.get_alphai()); // synchronize Bundle bundle(P); P.Broadcast_Receive(bundle, true); Timer timer; timer.start(); - P256Element pk = MCc.POpen(sk, P); + auto stats = P.comm_stats; + P256Element pk = MCc.open(sk, P); MCc.Check(P); cout << "Public key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - P.comm_stats.print(); + (P.comm_stats - stats).print(true); for (size_t i = 0; i < min(10lu, tuples.size()); i++) { check(sign(message, 1 << i, tuples[i], MCp, P, pk, sk, proc), message, 1 << i, pk); + if (not opts.check_open) + continue; Timer timer; timer.start(); auto& check_player = MCp.get_check_player(P); diff --git a/ExternalIO/bankers-bonus-commsec-client.cpp b/ExternalIO/bankers-bonus-commsec-client.cpp index 33ab007cd..365b9b634 100644 --- a/ExternalIO/bankers-bonus-commsec-client.cpp +++ b/ExternalIO/bankers-bonus-commsec-client.cpp @@ -365,6 +365,7 @@ int main(int argc, char** argv) // init static gfp string prep_data_prefix = get_prep_dir(nparties, 128, gf2n::default_degree()); initialise_fields(prep_data_prefix); + bigint::init_thread(); // Generate session keys to decrypt data sent from each spdz engine (party) vector session_keys(nparties); diff --git a/FHE/Ciphertext.h b/FHE/Ciphertext.h index ee1870c7c..84fe88cd9 100644 --- a/FHE/Ciphertext.h +++ b/FHE/Ciphertext.h @@ -39,6 +39,8 @@ class Ciphertext // Rely on default copy assignment/constructor + word get_pk_id() { return pk_id; } + void set(const Rq_Element& a0, const Rq_Element& a1, word pk_id) { cc0=a0; cc1=a1; this->pk_id = pk_id; } void set(const Rq_Element& a0, const Rq_Element& a1, const FHE_PK& pk); diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 04dd9105b..460fee18e 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -9,13 +9,20 @@ template Multiplier::Multiplier(int offset, PairwiseGenerator& generator) : - generator(generator), machine(generator.machine), - P(generator.P, offset), - num_players(generator.P.num_players()), - my_num(generator.P.my_num()), + Multiplier(offset, generator.machine, generator.P, generator.timers) +{ +} + +template +Multiplier::Multiplier(int offset, PairwiseMachine& machine, Player& P, + map& timers) : + machine(machine), + P(P, offset), + num_players(P.num_players()), + my_num(P.my_num()), other_pk(machine.other_pks[(my_num + num_players - offset) % num_players]), other_enc_alpha(machine.enc_alphas[(my_num + num_players - offset) % num_players]), - timers(generator.timers), + timers(timers), C(machine.pk), mask(machine.pk), product_share(machine.setup().FieldD), rc(machine.pk), volatile_capacity(0) diff --git a/FHEOffline/Multiplier.h b/FHEOffline/Multiplier.h index 4a9ba4a5f..17159a7f7 100644 --- a/FHEOffline/Multiplier.h +++ b/FHEOffline/Multiplier.h @@ -21,7 +21,6 @@ class PairwiseMachine; template class Multiplier { - PairwiseGenerator& generator; PairwiseMachine& machine; OffsetPlayer P; int num_players, my_num; @@ -39,6 +38,9 @@ class Multiplier public: Multiplier(int offset, PairwiseGenerator& generator); + Multiplier(int offset, PairwiseMachine& machine, Player& P, + map& timers); + void multiply_and_add(Plaintext_& res, const Ciphertext& C, const Plaintext_& b); void multiply_and_add(Plaintext_& res, const Ciphertext& C, diff --git a/GC/Processor.hpp b/GC/Processor.hpp index ae23693ac..9ee5c2b92 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -13,6 +13,8 @@ using namespace std; #include "Access.h" #include "Processor/FixInput.h" +#include "Processor/ProcessorBase.hpp" + namespace GC { diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index 8b4de98ba..37c1e4115 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -25,9 +25,9 @@ void SemiPrep::set_protocol(Beaver& protocol) (void) protocol; params.set_passive(); triple_generator = new SemiSecret::TripleGenerator( - thread.processor.machine.ot_setups.at(thread.thread_num).at(0), + thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(), thread.master.N, thread.thread_num, thread.master.opts.batch_size, - 1, params, thread.P); + 1, params, {}, thread.P); triple_generator->multi_threaded = false; } diff --git a/GC/ShareParty.hpp b/GC/ShareParty.hpp index 2f96a2e77..d8a04964b 100644 --- a/GC/ShareParty.hpp +++ b/GC/ShareParty.hpp @@ -95,7 +95,7 @@ ShareParty::ShareParty(int argc, const char** argv, int default_batch_size) : else P = new PlainPlayer(this->N, 0xFFFF); for (int i = 0; i < this->machine.nthreads; i++) - this->machine.ot_setups.push_back({{{*P, true}}}); + this->machine.ot_setups.push_back({*P, true}); delete P; } diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 7a9c0b38d..1e18d7e1e 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -152,8 +152,7 @@ void ReplicatedSecret::reveal(size_t n_bits, Clear& x) auto& share = *this; vector opened; auto& party = ShareThread::s(); - party.MC->POpen_Begin(opened, {share}, *party.P); - party.MC->POpen_End(opened, {share}, *party.P); + party.MC->POpen(opened, {share}, *party.P); x = IntBase(opened[0]); } diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index 043d93557..0a533b2ea 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -59,7 +59,7 @@ void ThreadMaster::run() if (T::needs_ot) for (int i = 0; i < machine.nthreads; i++) - machine.ot_setups.push_back({{*P, true}, {*P, true}}); + machine.ot_setups.push_back({*P, true}); for (int i = 0; i < machine.nthreads; i++) threads.push_back(new_thread(i)); diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp index 0ff4c2a0a..e332be638 100644 --- a/GC/TinyPrep.hpp +++ b/GC/TinyPrep.hpp @@ -31,18 +31,17 @@ void TinyPrep::set_protocol(Beaver& protocol) params.generateMACs = true; params.amplify = false; params.check = false; - params.set_mac_key(thread.MC->get_alphai()); triple_generator = new typename T::TripleGenerator( - thread.processor.machine.ot_setups.at(thread.thread_num).at(0), + thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(), thread.master.N, thread.thread_num, thread.master.opts.batch_size, - 1, params, thread.P); + 1, params, thread.MC->get_alphai(), thread.P); triple_generator->multi_threaded = false; input_generator = new typename T::part_type::TripleGenerator( - thread.processor.machine.ot_setups.at(thread.thread_num).at(1), + thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(), thread.master.N, thread.thread_num, thread.master.opts.batch_size, - 1, params, thread.P); + 1, params, thread.MC->get_alphai(), thread.P); input_generator->multi_threaded = false; thread.MC->get_part_MC().set_prep(*this); } diff --git a/Machines/OTMachine.cpp b/Machines/OTMachine.cpp index 7d16cd0fc..e589cfc8c 100644 --- a/Machines/OTMachine.cpp +++ b/Machines/OTMachine.cpp @@ -264,11 +264,8 @@ OTMachine::OTMachine(int argc, const char** argv) // convert baseOT selection bits to BitVector // (not already BitVector due to legacy PVW code) + baseReceiverInput = bot.receiver_inputs; baseReceiverInput.resize(nbase); - for (int i = 0; i < nbase; i++) - { - baseReceiverInput.set_bit(i, bot.receiver_inputs[i]); - } } OTMachine::~OTMachine() diff --git a/Machines/Player-Online.hpp b/Machines/Player-Online.hpp index 69a6d0b6c..08fc64d0c 100644 --- a/Machines/Player-Online.hpp +++ b/Machines/Player-Online.hpp @@ -107,7 +107,7 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr 1, // Number of args expected. 0, // Delimiter if expecting multiple args. "Maximum number of parties to send to at once", // Help description. - "-b", // Flag token. + "-B", // Flag token. "--max-broadcast" // Flag token. ); opt.add( diff --git a/Machines/TripleMachine.cpp b/Machines/TripleMachine.cpp index 79a23c14b..811cb482e 100644 --- a/Machines/TripleMachine.cpp +++ b/Machines/TripleMachine.cpp @@ -118,6 +118,7 @@ TripleMachine::TripleMachine(int argc, const char** argv) : opt.get("-l")->getInt(nloops); generateBits = opt.get("-B")->isSet; check = opt.get("-c")->isSet || generateBits; + correlation_check = opt.get("-c")->isSet; generateMACs = opt.get("-m")->isSet || check; amplify = opt.get("-a")->isSet || generateMACs; primeField = opt.get("-P")->isSet; @@ -143,21 +144,22 @@ TripleMachine::TripleMachine(int argc, const char** argv) : // doesn't work with Montgomery multiplication gfp1::init_field(p, false); + gfp::init_field(p, true); gf2n_long::init_field(128); PRNG G; G.ReSeed(); - mac_key2l.randomize(G); - mac_key2s.randomize(G); + mac_key2.randomize(G); mac_keyp.randomize(G); mac_keyz.randomize(G); } template -GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i) +GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i, + typename T::mac_key_type mac_key) { return new typename T::TripleGenerator(setup, N[i % nConnections], i, - nTriplesPerThread, nloops, *this); + nTriplesPerThread, nloops, *this, mac_key); } void TripleMachine::run() @@ -180,24 +182,24 @@ void TripleMachine::run() for (int i = 0; i < nthreads; i++) { if (primeField) - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_keyp); else if (z2k) { if (z2k == 32 and z2s == 32) - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_keyz); else if (z2k == 64 and z2s == 64) - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_keyz); else if (z2k == 64 and z2s == 48) - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_keyz); else if (z2k == 66 and z2s == 64) - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_keyz); else if (z2k == 66 and z2s == 48) - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_keyz); else throw runtime_error("not compiled for k=" + to_string(z2k) + " and s=" + to_string(z2s)); } else - generators[i] = new_generator>(setup, i); + generators[i] = new_generator>(setup, i, mac_key2); } ntriples = generators[0]->nTriples * nthreads; cout <<"Setup generators\n"; @@ -251,10 +253,8 @@ void TripleMachine::run() void TripleMachine::output_mac_keys() { if (z2k) { - write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyz, mac_key2l); + write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyz, mac_key2); } - else if (gf2n::degree() > 64) - write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2l); else - write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2s); + write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2); } diff --git a/Machines/hemi-party.cpp b/Machines/hemi-party.cpp new file mode 100644 index 000000000..2ec4618bc --- /dev/null +++ b/Machines/hemi-party.cpp @@ -0,0 +1,29 @@ +/* + * hemi-party.cpp + * + */ + +#include "Protocols/HemiShare.h" +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "FHE/P2Data.h" +#include "Tools/ezOptionParser.h" + +#include "Player-Online.hpp" +#include "Protocols/HemiPrep.hpp" +#include "Processor/Data_Files.hpp" +#include "Processor/Instruction.hpp" +#include "Processor/Machine.hpp" +#include "Protocols/SemiPrep.hpp" +#include "Protocols/SemiInput.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/MAC_Check.hpp" +#include "Protocols/fake-stuff.hpp" +#include "Protocols/SemiMC.hpp" +#include "Protocols/Beaver.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + spdz_main, HemiShare>(argc, argv, opt); +} diff --git a/Makefile b/Makefile index ed1bfe48c..443280f16 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,7 @@ DEPS := $(wildcard */*.d) all: gen_input online offline externalIO bmr yao replicated shamir real-bmr spdz2k-party.x brain-party.x semi-party.x semi2k-party.x semi-bin-party.x mascot-party.x tiny-party.x ifeq ($(USE_NTL),1) -all: overdrive she-offline cowgear-party.x +all: overdrive she-offline cowgear-party.x hemi-party.x endif -include $(DEPS) @@ -165,6 +165,7 @@ malicious-shamir-party.x: Machines/ShamirMachine.o spdz2k-party.x: $(OT) semi-party.x: $(OT) semi2k-party.x: $(OT) +hemi-party.x: $(FHEOFFLINE) cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o mascot-party.x: Machines/SPDZ.o $(OT) Player-Online.x: Machines/SPDZ.o $(OT) diff --git a/Math/Setup.cpp b/Math/Setup.cpp index e103353da..ebe18dd0f 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -110,6 +110,13 @@ void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, i } void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2, bool mont) +{ + write_online_setup_without_init(outf, dirname, p, lg2); + gfp::init_field(p, mont); + init_gf2n(lg2); +} + +void write_online_setup_without_init(ofstream& outf, string dirname, const bigint& p, int lg2) { if (p == 0) throw runtime_error("prime cannot be 0"); @@ -132,9 +139,6 @@ void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2 // Fix as a negative lg2 is a ``signal'' to choose slightly weaker // LWE parameters outf << abs(lg2) << endl; - - gfp::init_field(p, mont); - init_gf2n(lg2); } void init_gf2n(int lg2) diff --git a/Math/Setup.h b/Math/Setup.h index d9a196df6..b95747643 100644 --- a/Math/Setup.h +++ b/Math/Setup.h @@ -22,6 +22,7 @@ using namespace std; // Create setup file for gfp and gf2n void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, int lg2); void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2, bool mont = true); +void write_online_setup_without_init(ofstream& outf, string dirname, const bigint& p, int lg2); // Setup primes only // Chooses a p of at least lgp bits diff --git a/Math/Square.cpp b/Math/Square.cpp index b7f3ece8a..83e1d97f2 100644 --- a/Math/Square.cpp +++ b/Math/Square.cpp @@ -15,14 +15,14 @@ void Square::to(gf2n_short& result) result = sum; } -template <> -void Square::to(gfp1& result) +template +template +void Square::to(gfp_& result) { - const int L = gfp1::N_LIMBS; mp_limb_t product[2 * L], sum[2 * L], tmp[L][2 * L]; memset(tmp, 0, sizeof(tmp)); memset(sum, 0, sizeof(sum)); - for (int i = 0; i < gfp1::length(); i++) + for (int i = 0; i < gfp_::length(); i++) { memcpy(&(tmp[i/64][i/64]), &(rows[i]), sizeof(rows[i])); if (i % 64 == 0) @@ -32,10 +32,22 @@ void Square::to(gfp1& result) mpn_add_fixed_n<2 * L>(sum, product, sum); } mp_limb_t q[2 * L], ans[2 * L]; - mpn_tdiv_qr(q, ans, 0, sum, 2 * L, gfp1::get_ZpD().get_prA(), L); + mpn_tdiv_qr(q, ans, 0, sum, 2 * L, gfp_::get_ZpD().get_prA(), L); result.assign((void*) ans); } +template<> +void Square::to(gfp1& result) +{ + to<1, GFP_MOD_SZ>(result); +} + +template<> +void Square::to(gfp3& result) +{ + to<3, 4>(result); +} + template<> void Square::to(BitVec& result) { diff --git a/Math/Square.h b/Math/Square.h index b33d81348..28484dbd4 100644 --- a/Math/Square.h +++ b/Math/Square.h @@ -31,6 +31,8 @@ class Square void conditional_add(BitVector& conditions, Square& other, int offset); void to(U& result); + template + void to(gfp_& result); void pack(octetStream& os) const; void unpack(octetStream& os); diff --git a/Math/Z2k.h b/Math/Z2k.h index 05530d8df..2bd9fbed7 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -283,7 +283,8 @@ Z2 Z2::operator>>(int i) const { Z2 res; int n_byte_shift = i / 8; - memcpy(res.a, (char*)a + n_byte_shift, N_BYTES - n_byte_shift); + if (N_BYTES - n_byte_shift > 0) + memcpy(res.a, (char*)a + n_byte_shift, N_BYTES - n_byte_shift); int n_inside_shift = i % 8; if (n_inside_shift > 0) mpn_rshift(res.a, res.a, N_WORDS, n_inside_shift); diff --git a/Math/Z2k.hpp b/Math/Z2k.hpp index 41e18f0be..2bc5f652a 100644 --- a/Math/Z2k.hpp +++ b/Math/Z2k.hpp @@ -147,4 +147,17 @@ ostream& operator<<(ostream& o, const Z2& x) return o; } +template +istream& operator>>(istream& i, SignedZ2& x) +{ + auto& tmp = bigint::tmp; + i >> tmp; + if (tmp.numBits() > K + 1) + throw runtime_error( + tmp.get_str() + " out of range for signed " + to_string(K) + + "-bit numbers"); + x = tmp; + return i; +} + #endif diff --git a/Math/gfp.h b/Math/gfp.h index de79d04fe..8360bee8c 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -262,6 +262,8 @@ typedef gfp_<0, GFP_MOD_SZ> gfp; typedef gfp_<1, GFP_MOD_SZ> gfp1; // enough for Brain protocol with 64-bit computation and 40-bit security typedef gfp_<2, 4> gfp2; +// for OT-based ECDSA +typedef gfp_<3, 4> gfp3; void to_signed_bigint(bigint& ans,const gfp& x); diff --git a/Math/mpn_fixed.h b/Math/mpn_fixed.h index 0dbb9449f..83aa47454 100644 --- a/Math/mpn_fixed.h +++ b/Math/mpn_fixed.h @@ -110,7 +110,7 @@ inline mp_limb_t mpn_add_n_with_carry(mp_limb_t* res, const mp_limb_t* x, const char carry = 0; for (int i = 0; i < n; i++) #if defined(__clang__) -#if __clang_major__ < 8 || defined(__APPLE__) +#if __clang_major__ < 8 || (defined(__APPLE__) && __clang_major__ < 11) carry = __builtin_ia32_addcarry_u64(carry, x[i], y[i], (unsigned long long*)&res[i]); #else carry = __builtin_ia32_addcarryx_u64(carry, x[i], y[i], (unsigned long long*)&res[i]); diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 9e853ee8a..5edd49509 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -97,12 +97,20 @@ Names::Names(ez::ezOptionParser& opt, int argc, const char** argv, 1, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - "This player's number", // Help description. + "This player's number (required)", // Help description. "-p", // Flag token. "--player" // Flag token. ); opt.parse(argc, argv); opt.get("-p")->getInt(player_no); + vector missing; + if (not opt.gotRequired(missing)) + { + string usage; + opt.getUsage(usage); + cerr << usage; + exit(1); + } global_server = network_opts.start_networking(*this, player_no); } @@ -123,7 +131,7 @@ void Names::setup_names(const char *servername, int my_port) set_up_client_socket(socket_num, servername, pn); send(socket_num, (octet*)&player_no, sizeof(player_no)); #ifdef DEBUG_NETWORKING - fprintf(stderr, "Sent %d to %s:%d\n", player_no, servername, pn); + cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl; #endif int inst=-1; // wait until instruction to start. @@ -338,6 +346,14 @@ void MultiPlayer::send_all(const octetStream& o,bool donthash) const } +void Player::receive_all(vector& os) const +{ + for (int j = 0; j < num_players(); j++) + if (j != my_num()) + receive_player(j, os[j], true); +} + + void Player::receive_player(int i,octetStream& o,bool donthash) const { #ifdef VERBOSE_COMM @@ -345,6 +361,7 @@ void Player::receive_player(int i,octetStream& o,bool donthash) const #endif TimeScope ts(timer); receive_player_no_stats(i, o); + comm_stats["Receiving directly"].add(o, ts); if (!donthash) { blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); } } @@ -627,12 +644,14 @@ void RealTwoPartyPlayer::receive(octetStream& o) const TimeScope ts(timer); o.reset_write_head(); o.Receive(socket); + comm_stats["Receiving one-to-one"].add(o, ts); } void VirtualTwoPartyPlayer::receive(octetStream& o) const { TimeScope ts(timer); P.receive_player_no_stats(other_player, o); + comm_stats["Receiving one-to-one"].add(o, ts); } void RealTwoPartyPlayer::send_receive_player(vector& o) const @@ -688,6 +707,8 @@ void TwoPartyPlayer::Broadcast_Receive(vector& o, CommStats& CommStats::operator +=(const CommStats& other) { data += other.data; + rounds += other.rounds; + timer += other.timer; return *this; } @@ -698,6 +719,13 @@ NamedCommStats& NamedCommStats::operator +=(const NamedCommStats& other) return *this; } +NamedCommStats NamedCommStats::operator +(const NamedCommStats& other) const +{ + auto res = *this; + res += other; + return res; +} + CommStats& CommStats::operator -=(const CommStats& other) { data -= other.data; @@ -722,13 +750,15 @@ size_t NamedCommStats::total_data() return res; } -void NamedCommStats::print() +void NamedCommStats::print(bool newline) { for (auto it = begin(); it != end(); it++) if (it->second.data) cerr << it->first << " " << 1e-6 * it->second.data << " MB in " << it->second.rounds << " rounds, taking " << it->second.timer.elapsed() << " seconds" << endl; + if (size() and newline) + cerr << endl; } template class MultiPlayer; diff --git a/Networking/Player.h b/Networking/Player.h index d07fd92cb..88e3293f9 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -91,6 +91,7 @@ struct CommStats Timer timer; CommStats() : data(0), rounds(0) {} Timer& add(const octetStream& os) { data += os.get_length(); rounds++; return timer; } + void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; } CommStats& operator+=(const CommStats& other); CommStats& operator-=(const CommStats& other); }; @@ -99,9 +100,10 @@ class NamedCommStats : public map { public: NamedCommStats& operator+=(const NamedCommStats& other); + NamedCommStats operator+(const NamedCommStats& other) const; NamedCommStats operator-(const NamedCommStats& other) const; size_t total_data(); - void print(); + void print(bool newline = false); #ifdef VERBOSE_COMM CommStats& operator[](const string& name) { @@ -160,6 +162,7 @@ class Player : public PlayerBase virtual void send_all(const octetStream& o,bool donthash=false) const = 0; void send_to(int player,const octetStream& o,bool donthash=false) const; virtual void send_to_no_stats(int player,const octetStream& o) const = 0; + void receive_all(vector& os) const; void receive_player(int i,octetStream& o,bool donthash=false) const; virtual void receive_player_no_stats(int i,octetStream& o) const = 0; virtual void receive_player(int i,FlexBuffer& buffer) const; diff --git a/Networking/ServerSocket.cpp b/Networking/ServerSocket.cpp index 87cc845f6..8adbb0989 100644 --- a/Networking/ServerSocket.cpp +++ b/Networking/ServerSocket.cpp @@ -83,84 +83,40 @@ ServerSocket::~ServerSocket() void ServerSocket::accept_clients() { - map unassigned_sockets; - while (true) { - fd_set readfds; - FD_ZERO(&readfds); - int nfds = main_socket; - FD_SET(main_socket, &readfds); - for (auto &socket : unassigned_sockets) - { - FD_SET(socket.first, &readfds); - nfds = max(socket.first, nfds); - } - - select(nfds + 1, &readfds, 0, 0, 0); - - if (FD_ISSET(main_socket, &readfds)) - { - struct sockaddr dest; - memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */ - int socksize = sizeof(dest); - int consocket = accept(main_socket, (struct sockaddr*) &dest, - (socklen_t*) &socksize); - if (consocket < 0) - error("set_up_socket:accept"); - unassigned_sockets[consocket] = dest; + struct sockaddr dest; + memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */ + int socksize = sizeof(dest); + int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize); + if (consocket<0) { error("set_up_socket:accept"); } + + int client_id; + try + { + receive(consocket, (unsigned char*)&client_id, sizeof(client_id)); + } + catch (closed_connection&) + { #ifdef DEBUG_NETWORKING - auto &conn = *(sockaddr_in*) &dest; - fprintf(stderr, "new client on %s:%d\n", inet_ntoa(conn.sin_addr), - ntohs(conn.sin_port)); + auto& conn = *(sockaddr_in*)&dest; + cerr << "client on " << inet_ntoa(conn.sin_addr) << ":" + << ntohs(conn.sin_port) << " left without identification" + << endl; #endif - } - - vector processed_sockets; - for (auto &socket : unassigned_sockets) - { - int consocket = socket.first; - if (FD_ISSET(consocket, &readfds)) - { - try - { - int client_id; - receive(consocket, (unsigned char*) &client_id, - sizeof(client_id)); - - data_signal.lock(); - clients[client_id] = consocket; - data_signal.broadcast(); - data_signal.unlock(); + } -#ifdef DEBUG_NETWORKING - auto &conn = *(sockaddr_in*) &socket.second; - fprintf(stderr, "client id %d on %s:%d\n", client_id, - inet_ntoa(conn.sin_addr), ntohs(conn.sin_port)); -#endif + data_signal.lock(); + clients[client_id] = consocket; + data_signal.broadcast(); + data_signal.unlock(); #ifdef __APPLE__ - int flags = fcntl(consocket, F_GETFL, 0); - int fl = fcntl(consocket, F_SETFL, O_NONBLOCK | flags); - if (fl < 0) - error("set non-blocking"); + int flags = fcntl(consocket, F_GETFL, 0); + int fl = fcntl(consocket, F_SETFL, O_NONBLOCK | flags); + if (fl < 0) + error("set non-blocking"); #endif - } - catch (closed_connection&) - { -#ifdef DEBUG_NETWORKING - auto &conn = *(sockaddr_in*) &socket.second; - cerr << "client on " << inet_ntoa(conn.sin_addr) << ":" - << ntohs(conn.sin_port) << " left without identification" - << endl; -#endif - close_client_socket(consocket); - } - processed_sockets.push_back(consocket); - } - } - for (int socket : processed_sockets) - unassigned_sockets.erase(socket); } } diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp index e93f35cbe..99c5a3a6a 100644 --- a/OT/BaseOT.cpp +++ b/OT/BaseOT.cpp @@ -116,7 +116,7 @@ void BaseOT::exec_base(bool new_receiver_inputs) { if (new_receiver_inputs) receiver_inputs[i + j] = G.get_uchar()&1; - cs[j] = receiver_inputs[i + j]; + cs[j] = receiver_inputs[i + j].get(); } receiver_rsgen(&receiver, Rs_pack[0], cs); os[0].store_bytes(Rs_pack[0], sizeof(Rs_pack[0])); @@ -293,7 +293,7 @@ void FakeOT::exec_base(bool new_receiver_inputs) { for (int j = 0; j < 2; j++) bv[j].unpack(os[1]); - receiver_outputs[i] = bv[receiver_inputs[i]]; + receiver_outputs[i] = bv[receiver_inputs[i].get()]; } set_seeds(); diff --git a/OT/BaseOT.h b/OT/BaseOT.h index 54be8a1da..1b314d388 100644 --- a/OT/BaseOT.h +++ b/OT/BaseOT.h @@ -30,7 +30,7 @@ void send_if_ot_receiver(TwoPartyPlayer* P, vector& os, OT_ROLE rol class BaseOT { public: - vector receiver_inputs; + BitVector receiver_inputs; vector< vector > sender_inputs; vector receiver_outputs; TwoPartyPlayer* P; @@ -63,7 +63,7 @@ class BaseOT int length() { return ot_length; } - void set_receiver_inputs(const vector& new_inputs) + void set_receiver_inputs(const BitVector& new_inputs) { if ((int)new_inputs.size() != nOT) throw invalid_length(); @@ -72,7 +72,7 @@ class BaseOT void set_receiver_inputs(int128 inputs) { - vector new_inputs(128); + BitVector new_inputs(128); for (int i = 0; i < 128; i++) new_inputs[i] = (inputs >> i).get_lower() & 1; set_receiver_inputs(new_inputs); @@ -81,6 +81,7 @@ class BaseOT // do the OTs -- generate fresh random choice bits by default virtual void exec_base(bool new_receiver_inputs=true); // use PRG to get the next ot_length bits + void set_seeds(); void extend_length(); void check(); @@ -90,8 +91,6 @@ class BaseOT bool is_sender() { return (bool) (ot_role & SENDER); } bool is_receiver() { return (bool) (ot_role & RECEIVER); } - - void set_seeds(); }; class FakeOT : public BaseOT diff --git a/OT/BitMatrix.cpp b/OT/BitMatrix.cpp index bf3c32fa0..95687bb00 100644 --- a/OT/BitMatrix.cpp +++ b/OT/BitMatrix.cpp @@ -702,4 +702,5 @@ BMS XXXX(Matrix, gf2n_short) XXXX(Matrix>, gf2n_long) XXXX(Matrix>, gfp1) +XXXX(Matrix>, gfp3) XXXX(Matrix, BitVec) diff --git a/OT/BitMatrix.h b/OT/BitMatrix.h index 3b189b2e7..9ddfde0c0 100644 --- a/OT/BitMatrix.h +++ b/OT/BitMatrix.h @@ -142,7 +142,7 @@ class BitMatrix : public Matrix BitMatrix() {} BitMatrix(int length); - __m128i operator[](int i) { return squares[i / 128].rows[i % 128]; } + __m128i& operator[](int i) { return squares[i / 128].rows[i % 128]; } void resize(int length); int size(); diff --git a/OT/MascotParams.cpp b/OT/MascotParams.cpp index c6375fbea..868ea8747 100644 --- a/OT/MascotParams.cpp +++ b/OT/MascotParams.cpp @@ -26,81 +26,15 @@ MascotParams::MascotParams() generateMACs = true; amplify = true; check = true; + correlation_check = true; generateBits = false; + use_extension = true; + fewer_rounds = false; + fiat_shamir = false; timerclear(&start); } void MascotParams::set_passive() { - generateMACs = amplify = check = false; -} - -template<> gf2n_long MascotParams::get_mac_key() -{ - return mac_key2l; -} - -template<> gf2n_short MascotParams::get_mac_key() -{ - return mac_key2s; -} - -template<> gfp1 MascotParams::get_mac_key() -{ - return mac_keyp; -} - -template<> Z2<48> MascotParams::get_mac_key() -{ - return mac_keyz; -} - -template<> Z2<64> MascotParams::get_mac_key() -{ - return mac_keyz; -} - -template<> Z2<40> MascotParams::get_mac_key() -{ - return mac_keyz; -} - -template<> Z2<32> MascotParams::get_mac_key() -{ - return mac_keyz; -} - -template<> BitVec MascotParams::get_mac_key() -{ - return 0; -} - -template<> void MascotParams::set_mac_key(gf2n_long key) -{ - mac_key2l = key; -} - -template<> void MascotParams::set_mac_key(gf2n_short key) -{ - mac_key2s = key; -} - -template<> void MascotParams::set_mac_key(gfp1 key) -{ - mac_keyp = key; -} - -template<> void MascotParams::set_mac_key(Z2<64> key) -{ - mac_keyz = key; -} - -template<> void MascotParams::set_mac_key(Z2<48> key) -{ - mac_keyz = key; -} - -template<> void MascotParams::set_mac_key(Z2<40> key) -{ - mac_keyz = key; + generateMACs = amplify = check = correlation_check = false; } diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index b0e7f6c2b..484112647 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -68,6 +68,8 @@ class OTTripleGenerator : public GeneratorThread SeededPRNG share_prg; + mac_key_type mac_key; + void start_progress(); void print_progress(int k); @@ -101,8 +103,11 @@ class OTTripleGenerator : public GeneratorThread vector> preampTriples; vector> plainTriples; - OTTripleGenerator(OTTripleSetup& setup, const Names& names, + typename T::MAC_Check* MC; + + OTTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, MascotParams& machine, + mac_key_type mac_key, Player* parentPlayer = 0); ~OTTripleGenerator(); @@ -113,7 +118,10 @@ class OTTripleGenerator : public GeneratorThread void run_multipliers(MultJob job); + mac_key_type get_mac_key() const { return mac_key; } + size_t data_sent(); + NamedCommStats comm_stats(); }; template @@ -130,8 +138,9 @@ class NPartyTripleGenerator : public OTTripleGenerator vector< ShareTriple_ > uncheckedTriples; vector>> inputs; - NPartyTripleGenerator(OTTripleSetup& setup, const Names& names, + NPartyTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, MascotParams& machine, + mac_key_type mac_key, Player* parentPlayer = 0); virtual ~NPartyTripleGenerator() {} @@ -159,8 +168,9 @@ class MascotTripleGenerator : public NPartyTripleGenerator public: vector bits; - MascotTripleGenerator(OTTripleSetup& setup, const Names& names, + MascotTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, MascotParams& machine, + mac_key_type mac_key, Player* parentPlayer = 0); }; @@ -181,8 +191,9 @@ class Spdz2kTripleGenerator : public NPartyTripleGenerator U& MC, PRNG& G); public: - Spdz2kTripleGenerator(OTTripleSetup& setup, const Names& names, + Spdz2kTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, MascotParams& machine, + mac_key_type mac_key, Player* parentPlayer = 0); void generateTriples(); @@ -199,4 +210,15 @@ size_t OTTripleGenerator::data_sent() return res; } +template +NamedCommStats OTTripleGenerator::comm_stats() +{ + NamedCommStats res; + if (parentPlayer != &globalPlayer) + res = globalPlayer.comm_stats; + for (auto& player : players) + res += player->comm_stats; + return res; +} + #endif diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index e92ccc948..9460106df 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -5,27 +5,14 @@ #include "OT/OTExtensionWithMatrix.h" #include "OT/OTMultiplier.h" -#include "Math/gfp.h" -#include "Protocols/Share.h" -#include "Protocols/SemiShare.h" -#include "Protocols/Semi2kShare.h" -#include "Protocols/Spdz2kShare.h" #include "Math/operators.h" #include "Tools/Subroutines.h" #include "Protocols/MAC_Check.h" -#include "Protocols/Spdz2kPrep.h" -#include "GC/SemiSecret.h" #include "OT/Triple.hpp" -#include "OT/Rectangle.hpp" #include "OT/OTMultiplier.hpp" #include "Protocols/MAC_Check.hpp" -#include "Protocols/SemiMC.h" -#include "Protocols/MascotPrep.hpp" -#include "Protocols/ReplicatedInput.hpp" #include "Protocols/SemiInput.hpp" -#include "Processor/Input.hpp" -#include "Math/Z2k.hpp" #include #include @@ -43,44 +30,46 @@ void* run_ot_thread(void* ptr) * N.B. setup must not be stored as it will be used by other threads */ template -NPartyTripleGenerator::NPartyTripleGenerator(OTTripleSetup& setup, +NPartyTripleGenerator::NPartyTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int _nTriples, int nloops, - MascotParams& machine, Player* parentPlayer) : + MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) : OTTripleGenerator(setup, names, thread_num, _nTriples, nloops, - machine, parentPlayer) + machine, mac_key, parentPlayer) { } template -MascotTripleGenerator::MascotTripleGenerator(OTTripleSetup& setup, +MascotTripleGenerator::MascotTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int _nTriples, int nloops, - MascotParams& machine, Player* parentPlayer) : + MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) : NPartyTripleGenerator(setup, names, thread_num, _nTriples, nloops, - machine, parentPlayer) + machine, mac_key, parentPlayer) { } template -Spdz2kTripleGenerator::Spdz2kTripleGenerator(OTTripleSetup& setup, +Spdz2kTripleGenerator::Spdz2kTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int _nTriples, int nloops, - MascotParams& machine, Player* parentPlayer) : + MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) : NPartyTripleGenerator(setup, names, thread_num, _nTriples, nloops, - machine, parentPlayer) + machine, mac_key, parentPlayer) { } template -OTTripleGenerator::OTTripleGenerator(OTTripleSetup& setup, +OTTripleGenerator::OTTripleGenerator(const OTTripleSetup& setup, const Names& names, int thread_num, int _nTriples, int nloops, - MascotParams& machine, Player* parentPlayer) : + MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) : globalPlayer(parentPlayer ? *parentPlayer : *new PlainPlayer(names, - thread_num * names.num_players() * names.num_players())), parentPlayer(parentPlayer), thread_num(thread_num), + mac_key(mac_key), my_num(setup.get_my_num()), nloops(nloops), nparties(setup.get_nparties()), - machine(machine) + machine(machine), + MC(0) { nTriplesPerLoop = DIV_CEIL(_nTriples, nloops); nTriples = nTriplesPerLoop * nloops; @@ -208,7 +197,6 @@ void NPartyTripleGenerator::generateInputs(int player) { typedef open_type T; - auto& machine = this->machine; auto& nTriplesPerLoop = this->nTriplesPerLoop; auto& valueBits = this->valueBits; auto& share_prg = this->share_prg; @@ -235,7 +223,7 @@ void NPartyTripleGenerator::generateInputs(int player) GlobalPRNG G(globalPlayer); Share check_sum; inputs.resize(toCheck); - auto mac_key = machine.template get_mac_key(); + auto mac_key = this->get_mac_key(); SemiInput> input(0, globalPlayer); input.reset_all(globalPlayer); vector secrets(toCheck); @@ -289,7 +277,7 @@ void MascotTripleGenerator::generateBitsGf2n() bits.resize(nBitsToCheck); vector to_open(1); vector opened(1); - MAC_Check_ MC(this->machine.template get_mac_key()); + MAC_Check_ MC(this->get_mac_key()); this->start_progress(); @@ -313,7 +301,7 @@ void MascotTripleGenerator::generateBitsGf2n() typename T::clear r; for (int j = 0; j < nBitsToCheck; j++) { - auto mac_sum = valueBits[0].get_bit(j) ? MC.get_alphai() : 0; + auto mac_sum = valueBits[0].get_bit(j) ? this->get_mac_key() : 0; for (int i = 0; i < this->nparties-1; i++) mac_sum += this->ot_multipliers[i]->macs[0][j]; bits[j].set_share(valueBits[0].get_bit(j)); @@ -352,6 +340,13 @@ void MascotTripleGenerator>::generateBits() generateTriples(); } +template<> +inline +void MascotTripleGenerator>::generateBits() +{ + generateTriples(); +} + template void Spdz2kTripleGenerator::generateTriples() { @@ -360,7 +355,6 @@ void Spdz2kTripleGenerator::generateTriples() auto& uncheckedTriples = this->uncheckedTriples; auto& timers = this->timers; - auto& machine = this->machine; auto& nTriplesPerLoop = this->nTriplesPerLoop; auto& valueBits = this->valueBits; auto& share_prg = this->share_prg; @@ -382,7 +376,7 @@ void Spdz2kTripleGenerator::generateTriples() vector< PlainTriple_, Z2, 2> > amplifiedTriples(nTriplesPerLoop); uncheckedTriples.resize(nTriplesPerLoop); MAC_Check_Z2k, Z2, Z2, Share> > MC( - machine.template get_mac_key >()); + this->get_mac_key()); this->start_progress(); @@ -455,7 +449,7 @@ void Spdz2kTripleGenerator::generateTriples() // get piggy-backed random value Z2 r_share = b_padded_bits.get_ptr_to_byte(nTriplesPerLoop, Z2::N_BYTES); Z2 r_mac; - r_mac.mul(r_share, this->machine.template get_mac_key>()); + r_mac.mul(r_share, this->get_mac_key()); for (int i = 0; i < this->nparties-1; i++) r_mac += (ot_multipliers[i])->macs.at(1).at(nTriplesPerLoop); Share> r; @@ -563,16 +557,17 @@ void MascotTripleGenerator::generateTriples() valueBits[2*i].resize(field_size * nPreampTriplesPerLoop); valueBits[1].resize(field_size * nTriplesPerLoop); vector< PlainTriple > amplifiedTriples; - MAC_Check MC(machine.template get_mac_key()); + MAC_Check MC(this->get_mac_key()); if (machine.amplify) preampTriples.resize(nTriplesPerLoop); if (machine.generateMACs) { amplifiedTriples.resize(nTriplesPerLoop); - uncheckedTriples.resize(nTriplesPerLoop); } + uncheckedTriples.resize(nTriplesPerLoop); + this->start_progress(); for (int k = 0; k < nloops; k++) @@ -581,10 +576,15 @@ void MascotTripleGenerator::generateTriples() if (machine.amplify) { - octet seed[SEED_SIZE]; - Create_Random_Seed(seed, globalPlayer, SEED_SIZE); PRNG G; - G.SetSeed(seed); + if (machine.fiat_shamir and nparties == 2) + ot_multipliers[0]->otCorrelator.common_seed(G); + else + { + octet seed[SEED_SIZE]; + Create_Random_Seed(seed, globalPlayer, SEED_SIZE); + G.SetSeed(seed); + } for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++) { PlainTriple triple; @@ -598,12 +598,16 @@ void MascotTripleGenerator::generateTriples() triple.output(outputFile); timers["Writing"].stop(); } + else + for (int i = 0; i < 3; i++) + uncheckedTriples[iTriple].byIndex(i, 0).set_share(triple.byIndex(i, 0)); } if (machine.generateMACs) { for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++) - amplifiedTriples[iTriple].to(valueBits, iTriple); + amplifiedTriples[iTriple].to(valueBits, iTriple, + machine.check ? 2 : 1); for (int i = 0; i < nparties-1; i++) ot_multipliers[i]->inbox.push({}); @@ -625,7 +629,7 @@ void MascotTripleGenerator::generateTriples() if (machine.check) { - sacrifice(uncheckedTriples, MC, G); + sacrifice(uncheckedTriples, this->MC ? *this->MC : MC, G); } } } diff --git a/OT/OTExtension.cpp b/OT/OTExtension.cpp index e87d919c3..f3634adfc 100644 --- a/OT/OTExtension.cpp +++ b/OT/OTExtension.cpp @@ -259,7 +259,7 @@ void naive_transpose64(vector& output, const vector& input } -OTExtension::OTExtension(BaseOT& baseOT, TwoPartyPlayer* player, +OTExtension::OTExtension(const BaseOT& baseOT, TwoPartyPlayer* player, bool passive) : player(player) { nbaseOTs = baseOT.nOT; diff --git a/OT/OTExtension.h b/OT/OTExtension.h index 05067e0f5..df53eca41 100644 --- a/OT/OTExtension.h +++ b/OT/OTExtension.h @@ -30,7 +30,7 @@ class OTExtension vector receiverOutput; map times; - OTExtension(BaseOT& baseOT, TwoPartyPlayer* player, bool passive); + OTExtension(const BaseOT& baseOT, TwoPartyPlayer* player, bool passive); OTExtension(int nbaseOTs, int baseLength, int nloops, int nsubloops, diff --git a/OT/OTExtensionWithMatrix.cpp b/OT/OTExtensionWithMatrix.cpp index 1d6b5bcd2..39be81ce9 100644 --- a/OT/OTExtensionWithMatrix.cpp +++ b/OT/OTExtensionWithMatrix.cpp @@ -308,7 +308,8 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs) } template -void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, V& receiverOutput) +void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, + V& receiverOutput, bool correlated) { //cout << "Hashing... " << flush; octetStream os, h_os(HASH_SIZE); @@ -341,7 +342,11 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, V& r for (int j = 0; j < 8; j++) { tmp[0][j] = senderOutputMatrices[0].squares[i_outer_input].rows[i_inner_input + j]; - tmp[1][j] = tmp[0][j] ^ baseReceiverInput.get_int128(0); + if (correlated) + tmp[1][j] = tmp[0][j] ^ baseReceiverInput.get_int128(0); + else + tmp[1][j] = + senderOutputMatrices[1].squares[i_outer_input].rows[i_inner_input + j]; } for (int j = 0; j < 2; j++) mmo.hashBlocks( @@ -366,17 +371,39 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, V& r template template -void OTCorrelator::reduce_squares(unsigned int nTriples, vector& output) +void OTCorrelator::reduce_squares(unsigned int nTriples, vector& output, int start) { - if (receiverOutputMatrix.squares.size() < nTriples) + if (receiverOutputMatrix.squares.size() < nTriples + start) throw invalid_length(); output.resize(nTriples); for (unsigned int j = 0; j < nTriples; j++) { - receiverOutputMatrix.squares[j].sub(senderOutputMatrices[0].squares[j]).to(output[j]); + receiverOutputMatrix.squares[j + start].sub( + senderOutputMatrices[0].squares[j + start]).to(output[j]); } } +template +void OTCorrelator::common_seed(PRNG& G) +{ + Slice t1Slice(t1, 0, t1.squares.size()); + Slice uSlice(u, 0, u.squares.size()); + + octetStream os; + if (player->my_num()) + { + t1Slice.pack(os); + uSlice.pack(os); + } + else + { + uSlice.pack(os); + t1Slice.pack(os); + } + auto hash = os.hash(); + G = PRNG(hash); +} + octet* OTExtensionWithMatrix::get_receiver_output(int i) { return (octet*)&receiverOutputMatrix.squares[i/128].rows[i%128]; @@ -515,34 +542,35 @@ template class OTCorrelator; #define Z(BM,GF) \ template class OTCorrelator; \ template void OTCorrelator::reduce_squares(unsigned int nTriples, \ - vector& output); + vector& output, int); #define ZZZZ(GF) \ template void OTExtensionWithMatrix::print_post_correlate( \ BitVector& newReceiverInput, int j, int offset, int sender); \ #define ZZZ(GF, M) Z(M, GF) \ -template void OTExtensionWithMatrix::hash_outputs(int, vector&, M&); +template void OTExtensionWithMatrix::hash_outputs(int, vector&, M&, bool); ZZZZ(gf2n_long) ZZZ(gf2n_short, Matrix) ZZZ(gf2n_long, Matrix>) ZZZ(gfp1, Matrix>) +ZZZ(gfp3, Matrix>) ZZZ(BitVec, Matrix) #undef XX #define XX(T,U,N,L) \ template class OTCorrelator, Z2 > > >; \ template void OTCorrelator, Z2 > > >::reduce_squares(unsigned int nTriples, \ - vector& output); \ + vector& output, int); \ template void OTExtensionWithMatrix::hash_outputs(int, \ std::vector, Z2 > >, std::allocator, Z2 > > > >&, \ - Matrix, Z2 > >&); + Matrix, Z2 > >&, bool); #undef X #define X(N,L) \ template void OTCorrelator, Z2 > > >::reduce_squares(unsigned int nTriples, \ - vector >& output); \ + vector >& output, int); \ XX(Z2,Z2,N,L) //X(96, 160) diff --git a/OT/OTExtensionWithMatrix.h b/OT/OTExtensionWithMatrix.h index 29af01f3d..72007fa17 100644 --- a/OT/OTExtensionWithMatrix.h +++ b/OT/OTExtensionWithMatrix.h @@ -45,7 +45,9 @@ class OTCorrelator : public OTExtension U& baseReceiverOutput); void correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat = 1); template - void reduce_squares(unsigned int nTriples, vector& output); + void reduce_squares(unsigned int nTriples, vector& output, + int start = 0); + void common_seed(PRNG& G); }; class OTExtensionWithMatrix : public OTCorrelator @@ -80,7 +82,8 @@ class OTExtensionWithMatrix : public OTCorrelator void transpose(int start, int slice); void expand_transposed(); template - void hash_outputs(int nOTs, vector& senderOutput, V& receiverOutput); + void hash_outputs(int nOTs, vector& senderOutput, V& receiverOutput, + bool correlated = true); void print(BitVector& newReceiverInput, int i = 0); template diff --git a/OT/OTMultiplier.hpp b/OT/OTMultiplier.hpp index f40da615e..f5ba18d93 100644 --- a/OT/OTMultiplier.hpp +++ b/OT/OTMultiplier.hpp @@ -7,18 +7,11 @@ #include "OT/OTMultiplier.h" #include "OT/NPartyTripleGenerator.h" -#include "OT/Rectangle.h" -#include "Math/Z2k.h" -#include "Math/BitVec.h" -#include "Protocols/SemiShare.h" -#include "Protocols/Semi2kShare.h" -#include "Protocols/Spdz2kShare.h" +#include "OT/BaseOT.h" #include "OT/OTVole.hpp" #include "OT/Row.hpp" #include "OT/Rectangle.hpp" -#include "Math/Z2k.hpp" -#include "Math/Square.hpp" #include @@ -31,7 +24,8 @@ OTMultiplier::OTMultiplier(OTTripleGenerator& generator, rot_ext(128, 128, 0, 1, generator.players[thread_num], generator.baseReceiverInput, generator.baseSenderInputs[thread_num], - generator.baseReceiverOutputs[thread_num], BOTH, !generator.machine.check), + generator.baseReceiverOutputs[thread_num], BOTH, + !generator.machine.correlation_check), otCorrelator(0, 0, 0, 0, generator.players[thread_num], {}, {}, {}, BOTH, true) { this->thread = 0; @@ -89,7 +83,7 @@ OTMultiplier::~OTMultiplier() template void OTMultiplier::multiply() { - keyBits.set(generator.machine.template get_mac_key()); + keyBits.set(generator.get_mac_key()); rot_ext.extend(keyBits.size(), keyBits); this->outbox.push({}); senderOutput.resize(keyBits.size()); @@ -140,11 +134,6 @@ void OTMultiplier::multiplyForTriples() { typedef typename W::Rectangle X; - // dummy input for OT correlator - vector _; - vector< vector > __; - BitVector ___; - otCorrelator.resize(X::N_COLUMNS * generator.nPreampTriplesPerLoop); rot_ext.resize(X::N_ROWS * generator.nPreampTriplesPerLoop + 2 * 128); @@ -161,8 +150,26 @@ void OTMultiplier::multiplyForTriples() this->inbox.pop(job); BitVector aBits = generator.valueBits[0]; //timers["Extension"].start(); - rot_ext.extend_correlated(aBits); - rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput); + if (generator.machine.use_extension) + { + rot_ext.extend_correlated(aBits); + } + else + { + BaseOT bot(aBits.size(), -1, generator.players[thread_num]); + bot.set_receiver_inputs(aBits); + bot.exec_base(false); + for (size_t i = 0; i < aBits.size(); i++) + { + rot_ext.receiverOutputMatrix[i] = + bot.receiver_outputs[i].get_int128(0).a; + for (int j = 0; j < 2; j++) + rot_ext.senderOutputMatrices[j][i] = + bot.sender_inputs[i][j].get_int128(0).a; + } + } + rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, + baseReceiverOutput, generator.machine.use_extension); //timers["Extension"].stop(); //timers["Correlation"].start(); @@ -215,8 +222,6 @@ void MascotMultiplier::after_correlation() { typedef typename U::open_type T; - this->auth_ot_ext.resize( - this->generator.nPreampTriplesPerLoop * T::Square::N_COLUMNS); this->auth_ot_ext.set_role(BOTH); this->otCorrelator.reduce_squares(this->generator.nPreampTriplesPerLoop, @@ -229,15 +234,45 @@ void MascotMultiplier::after_correlation() this->macs.resize(3); MultJob job; this->inbox.pop(job); + auto& generator = this->generator; + array n_vals; for (int j = 0; j < 3; j++) { - int nValues = this->generator.nTriplesPerLoop; + n_vals[j] = generator.nTriplesPerLoop; if (this->generator.machine.check && (j % 2 == 0)) - nValues *= 2; - this->auth_ot_ext.expand(0, nValues); - this->auth_ot_ext.correlate(0, nValues, - this->generator.valueBits[j], true); - this->auth_ot_ext.reduce_squares(nValues, this->macs[j]); + n_vals[j] *= 2; + } + if (generator.machine.fewer_rounds) + { + BitVector bits; + int total = 0; + for (int j = 0; j < 3; j++) + { + bits.append(generator.valueBits[j], + n_vals[j] * T::Square::N_COLUMNS); + total += n_vals[j]; + } + this->auth_ot_ext.resize(bits.size()); + this->auth_ot_ext.expand(0, total); + this->auth_ot_ext.correlate(0, total, bits, true); + total = 0; + for (int j = 0; j < 3; j++) + { + this->auth_ot_ext.reduce_squares(n_vals[j], this->macs[j], total); + total += n_vals[j]; + } + } + else + { + this->auth_ot_ext.resize(n_vals[0] * T::Square::N_COLUMNS); + for (int j = 0; j < 3; j++) + { + int nValues = n_vals[j]; + this->auth_ot_ext.expand(0, nValues); + this->auth_ot_ext.correlate(0, nValues, + this->generator.valueBits[j], true); + this->auth_ot_ext.reduce_squares(nValues, this->macs[j]); + } } this->outbox.push(job); } diff --git a/OT/OTTripleSetup.cpp b/OT/OTTripleSetup.cpp index bc0a01aba..7cabdf5af 100644 --- a/OT/OTTripleSetup.cpp +++ b/OT/OTTripleSetup.cpp @@ -41,3 +41,22 @@ void OTTripleSetup::close_connections() delete players[i]; } } + +OTTripleSetup OTTripleSetup::get_fresh() +{ + OTTripleSetup res = *this; + for (int i = 0; i < nparties - 1; i++) + { + BaseOT bot(nbase, 128, 0); + bot.sender_inputs = baseSenderInputs[i]; + bot.receiver_outputs = baseReceiverOutputs[i]; + bot.set_seeds(); + bot.extend_length(); + baseSenderInputs[i] = bot.sender_inputs; + baseReceiverOutputs[i] = bot.receiver_outputs; + bot.extend_length(); + res.baseSenderInputs[i] = bot.sender_inputs; + res.baseReceiverOutputs[i] = bot.receiver_outputs; + } + return res; +} diff --git a/OT/OTTripleSetup.h b/OT/OTTripleSetup.h index 52ae7e077..a30b72bdc 100644 --- a/OT/OTTripleSetup.h +++ b/OT/OTTripleSetup.h @@ -11,7 +11,7 @@ */ class OTTripleSetup { - vector base_receiver_inputs; + BitVector base_receiver_inputs; vector baseOTs; PRNG G; @@ -25,10 +25,10 @@ class OTTripleSetup vector< vector< vector > > baseSenderInputs; vector< vector > baseReceiverOutputs; - int get_nparties() { return nparties; } - int get_nbase() { return nbase; } - int get_my_num() { return my_num; } - int get_base_receiver_input(int i) { return base_receiver_inputs[i]; } + int get_nparties() const { return nparties; } + int get_nbase() const { return nbase; } + int get_my_num() const { return my_num; } + int get_base_receiver_input(int i) const { return base_receiver_inputs[i]; } OTTripleSetup(Player& N, bool real_OTs) : nparties(N.num_players()), my_num(N.my_num()), nbase(128) @@ -78,6 +78,8 @@ class OTTripleSetup //template //T get_mac_key(); + + OTTripleSetup get_fresh(); }; diff --git a/OT/Triple.hpp b/OT/Triple.hpp index 842f13b17..a7fd99c76 100644 --- a/OT/Triple.hpp +++ b/OT/Triple.hpp @@ -16,13 +16,16 @@ class Triple T b; T c[N]; - int repeat(int l) + int repeat(int l, bool check) { switch (l) { case 0: case 2: - return N; + if (check) + return N; + else + return 1; case 1: return 1; default: @@ -75,12 +78,12 @@ class PlainTriple : public Triple { public: // this assumes that valueBits[1] is still set to the bits of b - void to(vector& valueBits, int i) + void to(vector& valueBits, int i, int repeat = N) { for (int j = 0; j < N; j++) { - valueBits[0].set_portion(i * N + j, this->a[j]); - valueBits[2].set_portion(i * N + j, this->c[j]); + valueBits[0].set_portion(i * repeat + j, this->a[j]); + valueBits[2].set_portion(i * repeat + j, this->c[j]); } } }; @@ -123,12 +126,12 @@ class ShareTriple_ : public Triple, N> { for (int l = 0; l < 3; l++) { - int repeat = this->repeat(l); + int repeat = this->repeat(l, generator.machine.check); for (int j = 0; j < repeat; j++) { T value = triple.byIndex(l,j); T mac; - mac.mul(value, generator.machine.template get_mac_key()); + mac.mul(value, generator.get_mac_key()); for (int i = 0; i < generator.nparties-1; i++) mac += generator.ot_multipliers[i]->macs.at(l).at(iTriple * repeat + j); Share& share = this->byIndex(l,j); diff --git a/OT/TripleMachine.h b/OT/TripleMachine.h index e49906d90..0dd9dde57 100644 --- a/OT/TripleMachine.h +++ b/OT/TripleMachine.h @@ -16,28 +16,21 @@ class GeneratorThread; class MascotParams : virtual public OfflineParams { -protected: - gf2n_short mac_key2s; - gf2n_long mac_key2l; - gfp1 mac_keyp; - Z2<128> mac_keyz; - public: string prep_data_dir; bool generateMACs; bool amplify; bool check; + bool correlation_check; bool generateBits; + bool use_extension; + bool fewer_rounds; + bool fiat_shamir; struct timeval start, stop; MascotParams(); void set_passive(); - - template - T get_mac_key(); - template - void set_mac_key(T key); }; class TripleMachine : public OfflineMachineBase, public MascotParams @@ -45,6 +38,10 @@ class TripleMachine : public OfflineMachineBase, public MascotParams Names N[2]; int nConnections; + gf2n mac_key2; + gfp1 mac_keyp; + Z2<128> mac_keyz; + public: int nloops; bool primeField; @@ -54,7 +51,8 @@ class TripleMachine : public OfflineMachineBase, public MascotParams TripleMachine(int argc, const char** argv); template - GeneratorThread* new_generator(OTTripleSetup& setup, int i); + GeneratorThread* new_generator(OTTripleSetup& setup, int i, + typename T::mac_key_type mac_key); void run(); diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index a5e227d6f..5adebec57 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -30,7 +30,7 @@ class BaseMachine string progname; int nthreads; - vector> ot_setups; + vector ot_setups; static BaseMachine& s(); diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index 663dfecfa..e5671d26d 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -94,6 +94,7 @@ class Preprocessing virtual void purge() {} virtual size_t data_sent() { return 0; } + virtual NamedCommStats comm_stats() { return {}; } virtual void get_three_no_count(Dtype dtype, T& a, T& b, T& c) = 0; virtual void get_two_no_count(Dtype dtype, T& a, T& b) = 0; @@ -112,6 +113,7 @@ class Preprocessing virtual array get_triple(int n_bits); virtual void buffer_triples() {} + virtual void buffer_inverses() {} }; template diff --git a/Processor/Input.hpp b/Processor/Input.hpp index f2a22350f..8bb376d39 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -13,6 +13,8 @@ #include "FixInput.h" #include "FloatInput.h" +#include "IntInput.hpp" + template InputBase::InputBase(ArithmeticProcessor* proc) : P(0), values_input(0) @@ -295,7 +297,7 @@ void InputBase::input_mixed(SubProcessor& Proc, const vector& args, cout << "Please input " << U::NAME << "s:" << endl; \ prepare(Proc, player, &args[i + U::N_DEST + 1], size); \ break; - X(IntInput) X(FixInput) X(FloatInput) + X(IntInput) X(FixInput) X(FloatInput) #undef X default: throw runtime_error("unknown input type: " + to_string(type)); @@ -317,7 +319,7 @@ void InputBase::input_mixed(SubProcessor& Proc, const vector& args, n_arg_tuple = U::N_DEST + U::N_PARAM + 2; \ finalize(Proc, args[i + n_arg_tuple - 1], &args[i + 1], size); \ break; - X(IntInput) X(FixInput) X(FloatInput) + X(IntInput) X(FixInput) X(FloatInput) #undef X default: throw runtime_error("unknown input type: " + to_string(type)); diff --git a/Processor/Instruction.h b/Processor/Instruction.h index f4772da26..6d22f2c7c 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -61,6 +61,9 @@ enum USE_PREP = 0x1C, STARTGRIND = 0x1D, STOPGRIND = 0x1E, + NPLAYERS = 0xE2, + THRESHOLD = 0xE3, + PLAYERID = 0xE4, // Addition ADDC = 0x20, ADDS = 0x21, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 8fd8fdbb9..a7d9a36f7 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -177,6 +177,9 @@ void BaseInstruction::parse_operands(istream& s, int pos) case PRINTCHRINT: case PRINTSTRINT: case PRINTINT: + case NPLAYERS: + case THRESHOLD: + case PLAYERID: r[0]=get_int(s); break; // instructions with 3 registers + 1 integer operand @@ -442,6 +445,9 @@ int BaseInstruction::get_reg_type() const case CONVMODP: case GCONVGF2N: case RAND: + case NPLAYERS: + case THRESHOLD: + case PLAYERID: return INT; case PREP: case USE_PREP: @@ -1046,10 +1052,10 @@ inline void Instruction::execute(Processor& Proc) const Proc.temp.ans2.output(Proc.private_output, false); break; case INPUT: - sint::Input::template input(Proc.Procp, start, size); + sint::Input::template input>(Proc.Procp, start, size); return; case GINPUT: - sgf2n::Input::template input(Proc.Proc2, start, size); + sgf2n::Input::template input>(Proc.Proc2, start, size); return; case INPUTFIX: sint::Input::template input(Proc.Procp, start, size); @@ -1404,6 +1410,15 @@ inline void Instruction::execute(Processor& Proc) const case STOPGRIND: CALLGRIND_STOP_INSTRUMENTATION; break; + case NPLAYERS: + Proc.write_Ci(r[0], Proc.P.num_players()); + break; + case THRESHOLD: + Proc.write_Ci(r[0], sint::threshold(Proc.P.num_players())); + break; + case PLAYERID: + Proc.write_Ci(r[0], Proc.P.my_num()); + break; // *** // TODO: read/write shared GF(2^n) data instructions // *** diff --git a/Processor/IntInput.cpp b/Processor/IntInput.cpp deleted file mode 100644 index 959745da6..000000000 --- a/Processor/IntInput.cpp +++ /dev/null @@ -1,14 +0,0 @@ -/* - * IntInput.cpp - * - */ - -#include "IntInput.h" - -const char* IntInput::NAME = "integer"; - -void IntInput::read(std::istream& in, const int* params) -{ - (void) params; - in >> items[0]; -} diff --git a/Processor/IntInput.h b/Processor/IntInput.h index 2881e0c34..c550cf450 100644 --- a/Processor/IntInput.h +++ b/Processor/IntInput.h @@ -8,6 +8,7 @@ #include +template class IntInput { public: @@ -17,7 +18,7 @@ class IntInput const static int TYPE = 0; - long items[N_DEST]; + T items[N_DEST]; void read(std::istream& in, const int* params); }; diff --git a/Processor/IntInput.hpp b/Processor/IntInput.hpp new file mode 100644 index 000000000..97bc7c0b5 --- /dev/null +++ b/Processor/IntInput.hpp @@ -0,0 +1,15 @@ +/* + * IntInput.cpp + * + */ + +#include "IntInput.h" + +template +const char* IntInput::NAME = "integer"; + +template +void IntInput::read(std::istream& in, const int*) +{ + in >> items[0]; +} diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index cbcfa9412..bc4eef067 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -127,10 +127,8 @@ Machine::Machine(int my_number, Names& playerNames, P = new CryptoPlayer(playerNames, 0xF000); else P = new PlainPlayer(playerNames, 0xF000); - ot_setups.resize(nthreads); for (int i = 0; i < nthreads; i++) - for (int j = 0; j < 3; j++) - ot_setups.at(i).push_back({ *P, true }); + ot_setups.push_back({ *P, true }); delete P; } diff --git a/Processor/Processor.h b/Processor/Processor.h index 4baa92818..c3217c095 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -52,8 +52,6 @@ class SubProcessor Preprocessing& DataF, Player& P); // Access to PO (via calls to POpen start/stop) - void POpen_Start(const vector& reg,const Player& P,int size); - void POpen_Stop(const vector& reg,const Player& P,int size); void POpen(const vector& reg,const Player& P,int size); void muls(const vector& reg, int size); diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 0a96bd74f..ba504a579 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -6,6 +6,7 @@ #include "Protocols/ReplicatedInput.hpp" #include "Protocols/ReplicatedPrivateOutput.hpp" +#include "Processor/ProcessorBase.hpp" #include #include @@ -406,15 +407,16 @@ void Processor::write_shares_to_file(const vector& data_regist } template -void SubProcessor::POpen_Start(const vector& reg,const Player& P,int size) +void SubProcessor::POpen(const vector& reg,const Player& P,int size) { - int sz=reg.size(); + assert(reg.size() % 2 == 0); + int sz=reg.size() / 2; Sh_PO.clear(); Sh_PO.reserve(sz*size); if (size>1) { - for (typename vector::const_iterator reg_it=reg.begin(); - reg_it!=reg.end(); reg_it++) + for (typename vector::const_iterator reg_it=reg.begin() + 1; + reg_it < reg.end(); reg_it += 2) { auto begin=S.begin()+*reg_it; Sh_PO.insert(Sh_PO.end(),begin,begin+size); @@ -423,24 +425,15 @@ void SubProcessor::POpen_Start(const vector& reg,const Player& P,int siz else { for (int i=0; i -void SubProcessor::POpen_Stop(const vector& reg,const Player& P,int size) -{ - int sz=reg.size(); - PO.resize(sz*size); - MC.POpen_End(PO,Sh_PO,P); + MC.POpen(PO,Sh_PO,P); if (size>1) { auto PO_it=PO.begin(); for (typename vector::const_iterator reg_it=reg.begin(); - reg_it!=reg.end(); reg_it++) + reg_it!=reg.end(); reg_it += 2) { for (auto C_it=C.begin()+*reg_it; C_it!=C.begin()+*reg_it+size; C_it++) @@ -452,36 +445,16 @@ void SubProcessor::POpen_Stop(const vector& reg,const Player& P,int size } else { - for (unsigned int i=0; i& dest, vector& source, const vector& reg) -{ - int n = reg.size() / 2; - source.resize(n); - dest.resize(n); - for (int i = 0; i < n; i++) - { - source[i] = reg[2 * i + 1]; - dest[i] = reg[2 * i]; - } -} - -template -void SubProcessor::POpen(const vector& reg, const Player& P, - int size) -{ - vector source, dest; - unzip_open(dest, source, reg); - POpen_Start(source, P, size); - POpen_Stop(dest, P, size); -} - template void SubProcessor::muls(const vector& reg, int size) { diff --git a/Processor/ProcessorBase.cpp b/Processor/ProcessorBase.hpp similarity index 87% rename from Processor/ProcessorBase.cpp rename to Processor/ProcessorBase.hpp index 03f408137..7d17b4285 100644 --- a/Processor/ProcessorBase.cpp +++ b/Processor/ProcessorBase.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROCESSOR_PROCESSORBASE_HPP_ +#define PROCESSOR_PROCESSORBASE_HPP_ + #include "ProcessorBase.h" #include "IntInput.h" #include "FixInput.h" @@ -11,6 +14,7 @@ #include +inline void ProcessorBase::open_input_file(const string& name) { #ifdef DEBUG_FILES @@ -20,6 +24,7 @@ void ProcessorBase::open_input_file(const string& name) input_filename = name; } +inline void ProcessorBase::open_input_file(int my_num, int thread_num) { string input_file = "Player-Data/Input-P" + to_string(my_num) + "-" + to_string(thread_num); @@ -54,6 +59,4 @@ T ProcessorBase::get_input(istream& input_file, const string& input_filename, co return res; } -template IntInput ProcessorBase::get_input(bool, const int*); -template FixInput ProcessorBase::get_input(bool, const int*); -template FloatInput ProcessorBase::get_input(bool, const int*); +#endif diff --git a/Programs/Source/aes.mpc b/Programs/Source/aes.mpc index 191a15ee0..aaa6f1d2f 100644 --- a/Programs/Source/aes.mpc +++ b/Programs/Source/aes.mpc @@ -129,7 +129,7 @@ def expandAESKey(cipherKey, Nr = 10, Nb = 4, Nk = 4): temp[2] = box.apply_sbox(temp[2]) temp[3] = box.apply_sbox(temp[3]) - temp[0] = temp[0] + ApplyEmbedding(rcon[int(i/Nk)]) + temp[0] = temp[0] + ApplyEmbedding(rcon[int(i//Nk)]) for j in range(4): round_key[4 * i + j] = round_key[4 * (i - Nk) + j] + temp[j] @@ -233,7 +233,7 @@ def inverseMod(val): for idx in range(40): if idx % 5 == 0: - bd_val[idx] = raw_bit_dec[idx / 5] + bd_val[idx] = raw_bit_dec[idx // 5] bd_squared = bd_val squared_index = 2 diff --git a/Programs/Source/bankers_bonus_commsec.mpc b/Programs/Source/bankers_bonus_commsec.mpc index 99df66b0e..45a47ef95 100644 --- a/Programs/Source/bankers_bonus_commsec.mpc +++ b/Programs/Source/bankers_bonus_commsec.mpc @@ -118,10 +118,10 @@ def main(): return True if n_rounds > 0: - print 'run %d rounds' % n_rounds + print('run %d rounds' % n_rounds) for_range(n_rounds)(game_loop) else: - print 'run forever' + print('run forever') do_while(game_loop) main() diff --git a/Programs/Source/gc_and.mpc b/Programs/Source/gc_and.mpc index be9fd8f0a..3da453d63 100644 --- a/Programs/Source/gc_and.mpc +++ b/Programs/Source/gc_and.mpc @@ -13,7 +13,7 @@ if len(program.args) > 2: m = int(program.args[2]) pack = min(n, 50) -n = (n + pack - 1) / pack +n = (n + pack - 1) // pack a = sbit(1) b = sbit(1, n=pack) diff --git a/Programs/Source/htmac.mpc b/Programs/Source/htmac.mpc index c796469f2..310e20764 100644 --- a/Programs/Source/htmac.mpc +++ b/Programs/Source/htmac.mpc @@ -50,9 +50,9 @@ test_decryption = True instructions_base.set_global_vector_size(n_parallel) if use_mimc_prf: - execfile('./Programs/Source/prf_mimc.mpc') + exec(compile(__builtins__['open']('./Programs/Source/prf_mimc.mpc').read(), './Programs/Source/prf_mimc.mpc', 'exec')) elif use_leg_prf: - execfile('./Programs/Source/prf_leg.mpc') + exec(compile(__builtins__['open']('./Programs/Source/prf_leg.mpc').read(), './Programs/Source/prf_leg.mpc', 'exec')) class HMAC(object): def __init__(self, _enc): @@ -97,7 +97,7 @@ class NonceEncryptMAC(object): def get_long_random(self, nbits): """ Returns random cint() % 2^{nbits} """ result = cint(0) - for i in range(nbits / 30): + for i in range(nbits // 30): result += cint(regint.get_random(30)) result <<= 30 @@ -178,7 +178,7 @@ def time_private_mac(n_total, n_parallel, nmessages): # Benchmark n_total HtMAC's while executing in parallel n_parallel start_timer(1) - @for_range(n_total / n_parallel) + @for_range(n_total // n_parallel) def block(index): # Re-use off-line data after n_parallel runs for benchmarking purposes. # If real system-use need to initialize num_calls with a larger constant. diff --git a/Programs/Source/regression.mpc b/Programs/Source/regression.mpc index 992a5db0e..a599782cd 100644 --- a/Programs/Source/regression.mpc +++ b/Programs/Source/regression.mpc @@ -8,7 +8,7 @@ ml.set_n_threads(8) debug = False if 'halfprec' in program.args: - print '8-bit precision after point' + print('8-bit precision after point') sfix.set_precision(8, 31) cfix.set_precision(8, 31) else: @@ -26,7 +26,7 @@ n_features = 12634 if len(program.args) > 2: if 'bc' in program.args: - print 'Compiling for BC-TCGA' + print('Compiling for BC-TCGA') n_examples = 472 n_normal = 49 n_features = 17814 @@ -41,7 +41,7 @@ try: except: pass -print 'Using %d threads' % ml.Layer.n_threads +print('Using %d threads' % ml.Layer.n_threads) n_fold = 5 test_share = 1. / n_fold @@ -63,8 +63,8 @@ else: n_test = sum(n_tests) indices = [regint.Array(n) for n in n_ex] -indices[0].assign(range(n_pos, n_pos + n_normal)) -indices[1].assign(range(n_pos)) +indices[0].assign(list(range(n_pos, n_pos + n_normal))) +indices[1].assign(list(range(n_pos))) test = regint.Array(n_test) @@ -97,7 +97,7 @@ for arg in program.args: m = re.match('tol=(.*)', arg) if m: sgd.tol = float(m.group(1)) - print 'Stop with tolerance', sgd.tol + print('Stop with tolerance', sgd.tol) sum_acc = cfix.MemValue(0) diff --git a/Programs/Source/test_sbitfix.mpc b/Programs/Source/test_sbitfix.mpc index b8dd6f4fe..22be93c7a 100644 --- a/Programs/Source/test_sbitfix.mpc +++ b/Programs/Source/test_sbitfix.mpc @@ -6,8 +6,8 @@ sbitfix.set_precision(16, 32) def test(a, b, value_type=None): try: b = int(round((b * (1 << a.f)))) - if b < 0: - b += 2 ** sbitfix.k + if b < 0: + b += 2 ** sbitfix.k a = a.v.reveal() except AttributeError: pass diff --git a/Programs/Source/vickrey.mpc b/Programs/Source/vickrey.mpc index ba02c80ab..2e7e0c7ea 100644 --- a/Programs/Source/vickrey.mpc +++ b/Programs/Source/vickrey.mpc @@ -55,12 +55,12 @@ def f(_): def thread(): i = get_arg() - n_per_thread = n_inputs / n_threads + n_per_thread = n_inputs // n_threads if n_per_thread % 2 != 0: raise Exception('Number of inputs must be divisible by 2') start = i * n_per_thread tuples = [bid_sort(bids[start+2*j], bids[start+2*j+1]) \ - for j in range(n_per_thread / 2)] + for j in range(n_per_thread // 2)] first, second = util.tree_reduce(first_and_second, tuples) results[2*i], results[2*i+1] = first, second diff --git a/Protocols/Beaver.h b/Protocols/Beaver.h index 9d2346283..9e7c3f81d 100644 --- a/Protocols/Beaver.h +++ b/Protocols/Beaver.h @@ -29,6 +29,8 @@ class Beaver : public ProtocolBase typename T::MAC_Check* MC; public: + static const bool uses_triples = true; + Player& P; Beaver(Player& P) : prep(0), MC(0), P(P) {} @@ -39,6 +41,9 @@ class Beaver : public ProtocolBase void exchange(); T finalize_mul(int n = -1); + void start_exchange(); + void stop_exchange(); + int get_n_relevant_players() { return P.num_players(); } }; diff --git a/Protocols/Beaver.hpp b/Protocols/Beaver.hpp index 66d53f387..0ed322f2d 100644 --- a/Protocols/Beaver.hpp +++ b/Protocols/Beaver.hpp @@ -50,6 +50,20 @@ void Beaver::exchange() triple = triples.begin(); } +template +void Beaver::start_exchange() +{ + MC->POpen_Begin(opened, shares, P); +} + +template +void Beaver::stop_exchange() +{ + MC->POpen_End(opened, shares, P); + it = opened.begin(); + triple = triples.begin(); +} + template T Beaver::finalize_mul(int n) { diff --git a/Protocols/HemiPrep.h b/Protocols/HemiPrep.h new file mode 100644 index 000000000..fda19a0db --- /dev/null +++ b/Protocols/HemiPrep.h @@ -0,0 +1,39 @@ +/* + * HemiPrep.h + * + */ + +#ifndef PROTOCOLS_HEMIPREP_H_ +#define PROTOCOLS_HEMIPREP_H_ + +#include "ReplicatedPrep.h" +#include "FHEOffline/Multiplier.h" + +template +class HemiPrep : public SemiHonestRingPrep +{ + typedef typename T::clear::FD FD; + + static PairwiseMachine* pairwise_machine; + static Lock lock; + + vector*> multipliers; + + SeededPRNG G; + + map timers; + +public: + static void basic_setup(Player& P); + static void teardown(); + + HemiPrep(SubProcessor* proc, DataPositions& usage) : + RingPrep(proc, usage), SemiHonestRingPrep(proc, usage) + { + } + + void buffer_triples(); + void buffer_inverses(); +}; + +#endif /* PROTOCOLS_HEMIPREP_H_ */ diff --git a/Protocols/HemiPrep.hpp b/Protocols/HemiPrep.hpp new file mode 100644 index 000000000..6621f744b --- /dev/null +++ b/Protocols/HemiPrep.hpp @@ -0,0 +1,87 @@ +/* + * HemiPrep.hpp + * + */ + +#ifndef PROTOCOLS_HEMIPREP_HPP_ +#define PROTOCOLS_HEMIPREP_HPP_ + +#include "HemiPrep.h" +#include "FHEOffline/PairwiseMachine.h" +#include "Tools/Bundle.h" + +template +PairwiseMachine* HemiPrep::pairwise_machine = 0; + +template +Lock HemiPrep::lock; + +template +void HemiPrep::teardown() +{ + if (pairwise_machine) + delete pairwise_machine; +} + +template +void HemiPrep::basic_setup(Player& P) +{ + assert(pairwise_machine == 0); + pairwise_machine = new PairwiseMachine(P); + auto& machine = *pairwise_machine; + auto& setup = machine.setup(); + setup.secure_init(P, machine, T::clear::length(), 40); +} + +template +void HemiPrep::buffer_triples() +{ + assert(this->proc != 0); + auto& P = this->proc->P; + + lock.lock(); + if (pairwise_machine == 0 or pairwise_machine->enc_alphas.empty()) + { + PlainPlayer P(this->proc->P.N, T::clear::type_char()); + if (pairwise_machine == 0) + basic_setup(P); + pairwise_machine->setup().covert_key_generation(P, + *pairwise_machine, 1); + pairwise_machine->enc_alphas.resize(1, pairwise_machine->pk); + } + lock.unlock(); + + if (multipliers.empty()) + for (int i = 1; i < P.num_players(); i++) + multipliers.push_back( + new Multiplier(i, *pairwise_machine, P, timers)); + + auto& FieldD = pairwise_machine->setup().FieldD; + Plaintext_ a(FieldD), b(FieldD), c(FieldD); + a.randomize(G); + b.randomize(G); + c.mul(a, b); + Bundle bundle(P); + pairwise_machine->pk.encrypt(a).pack(bundle.mine); + P.Broadcast_Receive(bundle, true); + Ciphertext C(pairwise_machine->pk); + for (auto m : multipliers) + { + C.unpack(bundle[P.get_player(-m->get_offset())]); + m->multiply_and_add(c, C, b); + } + assert(b.num_slots() == a.num_slots()); + assert(c.num_slots() == a.num_slots()); + for (unsigned i = 0; i < a.num_slots(); i++) + this->triples.push_back( + {{ a.element(i), b.element(i), c.element(i) }}); +} + +template +void HemiPrep::buffer_inverses() +{ + assert(this->proc != 0); + ::buffer_inverses(this->inverses, *this, this->proc->MC, this->proc->P); +} + +#endif diff --git a/Protocols/HemiShare.h b/Protocols/HemiShare.h new file mode 100644 index 000000000..e51c0f61e --- /dev/null +++ b/Protocols/HemiShare.h @@ -0,0 +1,40 @@ +/* + * HemiShare.h + * + */ + +#ifndef PROTOCOLS_HEMISHARE_H_ +#define PROTOCOLS_HEMISHARE_H_ + +#include "SemiShare.h" + +template class HemiPrep; + +template +class HemiShare : public SemiShare +{ + typedef HemiShare This; + typedef SemiShare super; + +public: + typedef SemiMC MAC_Check; + typedef DirectSemiMC Direct_MC; + typedef SemiInput Input; + typedef ::PrivateOutput PrivateOutput; + typedef SPDZ Protocol; + typedef HemiPrep LivePrep; + + static const bool needs_ot = false; + + HemiShare() + { + } + template + HemiShare(const U& other) : + super(other) + { + } + +}; + +#endif /* PROTOCOLS_HEMISHARE_H_ */ diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index d520504fd..29f509f64 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -188,7 +188,7 @@ class Parallel_MAC_Check: public Separate_MAC_Check> template -class Direct_MAC_Check: public Separate_MAC_Check +class Direct_MAC_Check: public MAC_Check_ { typedef typename T::open_type open_type; @@ -196,7 +196,9 @@ class Direct_MAC_Check: public Separate_MAC_Check vector oss; public: - Direct_MAC_Check(const typename T::mac_key_type& ai, Names& Nms, int thread_num); + // legacy interface + Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai, Names& Nms, int thread_num); + Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai); ~Direct_MAC_Check(); void POpen_Begin(vector& values,const vector& S,const Player& P); diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index e0f930488..06b360d05 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -487,7 +487,16 @@ void Parallel_MAC_Check::POpen_End(vector& values, template -Direct_MAC_Check::Direct_MAC_Check(const typename T::mac_key_type& ai, Names& Nms, int num) : Separate_MAC_Check(ai, Nms, num) { +Direct_MAC_Check::Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai, + Names&, int) : + Direct_MAC_Check(ai) +{ +} + +template +Direct_MAC_Check::Direct_MAC_Check(const typename T::mac_key_type::Scalar& ai) : + MAC_Check_(ai) +{ open_counter = 0; } @@ -532,9 +541,7 @@ void Direct_MAC_Check::POpen_End(vector& values,const vector& S this->timers[RECV].start(); - for (int j=0; jtimers[RECV].stop(); open_counter++; diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index d231cfac7..1b2393f4b 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -32,8 +32,10 @@ class MAC_Check_Base virtual void POpen_Begin(vector& values,const vector& S,const Player& P) = 0; virtual void POpen_End(vector& values,const vector& S,const Player& P) = 0; - void POpen(vector& values,const vector& S,const Player& P); + virtual void POpen(vector& values,const vector& S,const Player& P); typename T::open_type POpen(const T& secret, const Player& P); + // alternative name to avoid conflict + typename T::open_type open(const T& secret, const Player& P) { return POpen(secret, P); } virtual void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); diff --git a/Protocols/MaliciousRepMC.h b/Protocols/MaliciousRepMC.h index 03c4ee1d0..3457426d0 100644 --- a/Protocols/MaliciousRepMC.h +++ b/Protocols/MaliciousRepMC.h @@ -15,6 +15,8 @@ class MaliciousRepMC : public ReplicatedMC typedef ReplicatedMC super; public: + virtual void POpen(vector& values, + const vector& S, const Player& P); virtual void POpen_Begin(vector& values, const vector& S, const Player& P); virtual void POpen_End(vector& values, @@ -35,6 +37,8 @@ class HashMaliciousRepMC : public MaliciousRepMC void reset(); void update(); + void finalize(const vector& values); + public: // emulate MAC_Check HashMaliciousRepMC(const typename T::value_type& _, int __ = 0, int ___ = 0) : HashMaliciousRepMC() @@ -47,6 +51,7 @@ class HashMaliciousRepMC : public MaliciousRepMC HashMaliciousRepMC(); ~HashMaliciousRepMC(); + void POpen(vector& values,const vector& S,const Player& P); void POpen_End(vector& values,const vector& S,const Player& P); void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); @@ -60,6 +65,8 @@ class CommMaliciousRepMC : public MaliciousRepMC vector os; public: + void POpen(vector& values, const vector& S, + const Player& P); void POpen_Begin(vector& values, const vector& S, const Player& P); void POpen_End(vector& values, const vector& S, diff --git a/Protocols/MaliciousRepMC.hpp b/Protocols/MaliciousRepMC.hpp index 32d8af011..c9cc1850d 100644 --- a/Protocols/MaliciousRepMC.hpp +++ b/Protocols/MaliciousRepMC.hpp @@ -30,6 +30,13 @@ void MaliciousRepMC::POpen_End(vector& values, throw runtime_error("use subclass"); } +template +void MaliciousRepMC::POpen(vector&, + const vector&, const Player&) +{ + throw runtime_error("use subclass"); +} + template void MaliciousRepMC::Check(const Player& P) { @@ -60,11 +67,25 @@ HashMaliciousRepMC::~HashMaliciousRepMC() free(hash_state); } +template +void HashMaliciousRepMC::POpen(vector& values, + const vector& S, const Player& P) +{ + ReplicatedMC::POpen(values, S, P); + finalize(values); +} + template void HashMaliciousRepMC::POpen_End(vector& values, const vector& S, const Player& P) { ReplicatedMC::POpen_End(values, S, P); + finalize(values); +} + +template +void HashMaliciousRepMC::finalize(const vector& values) +{ os.reset_write_head(); for (auto& value : values) value.pack(os); @@ -118,6 +139,14 @@ void HashMaliciousRepMC::Check(const Player& P) } } +template +void CommMaliciousRepMC::POpen(vector& values, + const vector& S, const Player& P) +{ + POpen_Begin(values, S, P); + POpen_End(values, S, P); +} + template void CommMaliciousRepMC::POpen_Begin(vector& values, const vector& S, const Player& P) diff --git a/Protocols/MaliciousShamirMC.h b/Protocols/MaliciousShamirMC.h index 999e77b3f..b9ca6b157 100644 --- a/Protocols/MaliciousShamirMC.h +++ b/Protocols/MaliciousShamirMC.h @@ -13,6 +13,9 @@ class MaliciousShamirMC : public ShamirMC { vector> reconstructions; + void finalize(vector& values, const vector& S, + const Player& P); + public: MaliciousShamirMC(); @@ -28,6 +31,8 @@ class MaliciousShamirMC : public ShamirMC { (void)_; (void)__; (void)___; (void)____; } + void POpen(vector& values, const vector& S, + const Player& P); void POpen_End(vector& values, const vector& S, const Player& P); }; diff --git a/Protocols/MaliciousShamirMC.hpp b/Protocols/MaliciousShamirMC.hpp index c2bc0bf79..0236f4ead 100644 --- a/Protocols/MaliciousShamirMC.hpp +++ b/Protocols/MaliciousShamirMC.hpp @@ -12,11 +12,27 @@ MaliciousShamirMC::MaliciousShamirMC() this->threshold = 2 * ShamirMachine::s().threshold; } +template +void MaliciousShamirMC::POpen(vector& values, + const vector& S, const Player& P) +{ + this->prepare(S, P); + this->exchange(P); + finalize(values, S, P); +} + template void MaliciousShamirMC::POpen_End(vector& values, const vector& S, const Player& P) { - (void) P; + P.receive_all(this->os); + finalize(values, S, P); +} + +template +void MaliciousShamirMC::finalize(vector& values, + const vector& S, const Player& P) +{ int threshold = ShamirMachine::s().threshold; if (reconstructions.empty()) { @@ -36,7 +52,10 @@ void MaliciousShamirMC::POpen_End(vector& values, for (size_t i = 0; i < values.size(); i++) { for (size_t j = 0; j < shares.size(); j++) - shares[j].unpack(this->os[j]); + if (int(j) == P.my_num()) + shares[j] = S[i]; + else + shares[j].unpack(this->os[j]); T value = 0; for (int j = 0; j < threshold + 1; j++) value += shares[j] * reconstructions[threshold + 1][j]; diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index 721b1f611..19a4f7cbf 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -13,10 +13,9 @@ template class OTPrep : public virtual RingPrep { -protected: +public: typename T::TripleGenerator* triple_generator; -public: MascotParams params; OTPrep(SubProcessor* proc, DataPositions& usage); @@ -25,6 +24,7 @@ class OTPrep : public virtual RingPrep void set_protocol(typename T::Protocol& protocol); size_t data_sent(); + NamedCommStats comm_stats(); }; template diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index 3ddd06edd..9a94f484f 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -17,7 +17,6 @@ template OTPrep::OTPrep(SubProcessor* proc, DataPositions& usage) : RingPrep(proc, usage), triple_generator(0) { - this->buffer_size = OnlineOptions::singleton.batch_size; } template @@ -34,13 +33,11 @@ void OTPrep::set_protocol(typename T::Protocol& protocol) SubProcessor* proc = this->proc; assert(proc != 0); auto& ot_setups = BaseMachine::s().ot_setups.at(proc->Proc.thread_num); - assert(not ot_setups.empty()); - OTTripleSetup setup = ot_setups.back(); - ot_setups.pop_back(); - params.set_mac_key(typename T::mac_key_type::next(proc->MC.get_alphai())); + OTTripleSetup setup = ot_setups.get_fresh(); triple_generator = new typename T::TripleGenerator(setup, - proc->P.N, proc->Proc.thread_num, this->buffer_size, 1, - params, &proc->P); + proc->P.N, proc->Proc.thread_num, + OnlineOptions::singleton.batch_size, 1, + params, proc->MC.get_alphai(), &proc->P); triple_generator->multi_threaded = false; } @@ -119,4 +116,13 @@ size_t OTPrep::data_sent() return 0; } +template +NamedCommStats OTPrep::comm_stats() +{ + if (triple_generator) + return triple_generator->comm_stats(); + else + return {}; +} + #endif diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index 05b071474..ba0a3a730 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -42,6 +42,11 @@ class Rep3Share : public FixedVec return "replicated " + T::type_string(); } + static int threshold(int) + { + return 1; + } + static Rep3Share constant(T value, int my_num, const T& alphai = {}) { return Rep3Share(value, my_num, alphai); diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 8612aa9d5..f41bf98b2 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -8,6 +8,7 @@ #include #include +#include using namespace std; #include "Tools/octetStream.h" @@ -26,11 +27,14 @@ template class Preprocessing; class ReplicatedBase { public: - PRNG shared_prngs[2]; + array shared_prngs; Player& P; ReplicatedBase(Player& P); + ReplicatedBase(Player& P, array& prngs); + + ReplicatedBase branch(); int get_n_relevant_players() { return P.num_players() - 1; } }; @@ -62,6 +66,9 @@ class ProtocolBase virtual void trunc_pr(const vector& regs, int size, SubProcessor& proc) { (void) regs, (void) size; (void) proc; throw not_implemented(); } + + virtual void start_exchange() { exchange(); } + virtual void stop_exchange() {} }; template @@ -75,7 +82,10 @@ class Replicated : public ReplicatedBase, public ProtocolBase typedef ReplicatedMC MAC_Check; typedef ReplicatedInput Input; + static const bool uses_triples = false; + Replicated(Player& P); + Replicated(const ReplicatedBase& other); static void assign(T& share, const typename T::clear& value, int my_num) { @@ -103,6 +113,9 @@ class Replicated : public ReplicatedBase, public ProtocolBase void trunc_pr(const vector& regs, int size, SubProcessor& proc); T get_random(); + + void start_exchange(); + void stop_exchange(); }; #endif /* PROTOCOLS_REPLICATED_H_ */ diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index aeafc6706..c9e8271c9 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -29,6 +29,12 @@ Replicated::Replicated(Player& P) : ReplicatedBase(P) assert(T::length == 2); } +template +Replicated::Replicated(const ReplicatedBase& other) : + ReplicatedBase(other) +{ +} + inline ReplicatedBase::ReplicatedBase(Player& P) : P(P) { assert(P.num_players() == 3); @@ -43,6 +49,18 @@ inline ReplicatedBase::ReplicatedBase(Player& P) : P(P) shared_prngs[1].SetSeed(os.get_data()); } +inline ReplicatedBase::ReplicatedBase(Player& P, array& prngs) : + P(P) +{ + for (int i = 0; i < 2; i++) + shared_prngs[i].SetSeed(prngs[i]); +} + +inline ReplicatedBase ReplicatedBase::branch() +{ + return {P, shared_prngs}; +} + template ProtocolBase::~ProtocolBase() { @@ -128,6 +146,18 @@ void Replicated::exchange() P.pass_around(os[0], os[1], 1); } +template +void Replicated::start_exchange() +{ + P.send_relative(1, os[0]); +} + +template +void Replicated::stop_exchange() +{ + P.receive_relative(-1, os[1]); +} + template inline T Replicated::finalize_mul(int n) { diff --git a/Protocols/ReplicatedMC.h b/Protocols/ReplicatedMC.h index 496e4ff01..6bc657bc5 100644 --- a/Protocols/ReplicatedMC.h +++ b/Protocols/ReplicatedMC.h @@ -14,6 +14,9 @@ class ReplicatedMC : public MAC_Check_Base octetStream o; octetStream to_send; + void prepare(const vector& S); + void finalize(vector& values, const vector& S); + public: // emulate MAC_Check ReplicatedMC(const typename T::value_type& _ = {}, int __ = 0, int ___ = 0) @@ -23,6 +26,7 @@ class ReplicatedMC : public MAC_Check_Base ReplicatedMC(const typename T::value_type& _, Names& ____, int __ = 0, int ___ = 0) { (void)_; (void)__; (void)___; (void)____; } + void POpen(vector& values,const vector& S,const Player& P); void POpen_Begin(vector& values,const vector& S,const Player& P); void POpen_End(vector& values,const vector& S,const Player& P); diff --git a/Protocols/ReplicatedMC.hpp b/Protocols/ReplicatedMC.hpp index 1f3c781d1..f3924a644 100644 --- a/Protocols/ReplicatedMC.hpp +++ b/Protocols/ReplicatedMC.hpp @@ -9,23 +9,44 @@ #include "ReplicatedMC.h" template -void ReplicatedMC::POpen_Begin(vector& values, +void ReplicatedMC::POpen(vector& values, const vector& S, const Player& P) +{ + prepare(S); + P.pass_around(to_send, o, -1); + finalize(values, S); +} + +template +void ReplicatedMC::POpen_Begin(vector&, + const vector& S, const Player& P) +{ + prepare(S); + P.send_relative(-1, to_send); +} + +template +void ReplicatedMC::prepare(const vector& S) { assert(T::length == 2); - (void)values; o.reset_write_head(); to_send.reset_write_head(); for (auto& x : S) x[0].pack(to_send); - P.pass_around(to_send, o, -1); } template void ReplicatedMC::POpen_End(vector& values, const vector& S, const Player& P) { - (void)P; + P.receive_relative(1, o); + finalize(values, S); +} + +template +void ReplicatedMC::finalize(vector& values, + const vector& S) +{ values.resize(S.size()); for (size_t i = 0; i < S.size(); i++) { diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index 5203690cb..f7b601a68 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -54,7 +54,8 @@ template void ReplicatedRingPrep::buffer_triples() { assert(this->protocol != 0); - typename T::Protocol protocol(this->protocol->P); + // independent instance to avoid conflicts + typename T::Protocol protocol(this->protocol->branch()); generate_triples(this->triples, OnlineOptions::singleton.batch_size, &protocol); } @@ -264,8 +265,7 @@ void RingPrep::buffer_bits_without_check() int n_relevant_players = protocol->get_n_relevant_players(); vector> player_bits(n_relevant_players, vector(buffer_size)); typename T::Input input(proc, P); - for (int i = 0; i < P.num_players(); i++) - input.reset(i); + input.reset_all(P); for (int i = 0; i < n_relevant_players; i++) { int input_player = (base_player + i) % P.num_players(); @@ -274,20 +274,15 @@ void RingPrep::buffer_bits_without_check() SeededPRNG G; for (int i = 0; i < buffer_size; i++) input.add_mine(G.get_bit()); - input.send_mine(); - for (auto& x : player_bits[i]) - x = input.finalize_mine(); } else - { for (int i = 0; i < buffer_size; i++) input.add_other(input_player); - octetStream os; - P.receive_player(input_player, os, true); - for (auto& x : player_bits[i]) - input.finalize_other(input_player, x, os); - } } + input.exchange(); + for (int i = 0; i < n_relevant_players; i++) + for (auto& x : player_bits[i]) + x = input.finalize((base_player + i) % P.num_players()); auto& prot = *protocol; XOR(bits, player_bits[0], player_bits[1], buffer_size, prot, proc); for (int i = 2; i < n_relevant_players; i++) diff --git a/Protocols/SemiMC.h b/Protocols/SemiMC.h index 97fcc9c64..8eefc5a26 100644 --- a/Protocols/SemiMC.h +++ b/Protocols/SemiMC.h @@ -29,12 +29,12 @@ class DirectSemiMC : public SemiMC public: DirectSemiMC() {} // emulate Direct_MAC_Check - DirectSemiMC(const typename T::mac_key_type& _, Names& ____, int __ = 0, int ___ = 0) - { (void)_; (void)__; (void)___; (void)____; } + DirectSemiMC(const typename T::mac_key_type&, const Names& = {}, int = 0, int = 0) {} void POpen_(vector& values,const vector& S,const PlayerBase& P); - void POpen_Begin(vector& values,const vector& S,const Player& P) + void POpen(vector& values,const vector& S,const Player& P) { POpen_(values, S, P); } + void POpen_Begin(vector& values,const vector& S,const Player& P); void POpen_End(vector& values,const vector& S,const Player& P); void Check(const Player& P) { (void)P; } diff --git a/Protocols/SemiMC.hpp b/Protocols/SemiMC.hpp index 386b68da4..486561339 100644 --- a/Protocols/SemiMC.hpp +++ b/Protocols/SemiMC.hpp @@ -42,10 +42,24 @@ void DirectSemiMC::POpen_(vector& values, } template -void DirectSemiMC::POpen_End(vector& values, +void DirectSemiMC::POpen_Begin(vector& values, const vector& S, const Player& P) { - (void) values, (void) S, (void) P; + values.clear(); + values.insert(values.begin(), S.begin(), S.end()); + octetStream os; + for (auto& x : values) + x.pack(os); + P.send_all(os, true); +} + +template +void DirectSemiMC::POpen_End(vector& values, + const vector&, const Player& P) +{ + Bundle oss(P); + P.receive_all(oss); + direct_add_openings(values, P, oss); } #endif diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index 2a42f636e..20d5e4ad4 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -52,6 +52,11 @@ class SemiShare : public T static string type_short() { return "D" + string(1, T::type_char()); } + static int threshold(int nplayers) + { + return nplayers - 1; + } + static SemiShare constant(const clear& other, int my_num, const T& alphai = {}) { return SemiShare(other, my_num, alphai); diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index 83341f641..558509c93 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -38,6 +38,8 @@ class Shamir : public ProtocolBase> int n_mul_players; public: + static const bool uses_triples = false; + Player& P; static U get_rec_factor(int i, int n); @@ -45,6 +47,8 @@ class Shamir : public ProtocolBase> Shamir(Player& P); ~Shamir(); + Shamir branch(); + int get_n_relevant_players(); void reset(); @@ -52,7 +56,11 @@ class Shamir : public ProtocolBase> void init_mul(); void init_mul(SubProcessor* proc); U prepare_mul(const T& x, const T& y, int n = -1); + void exchange(); + void start_exchange(); + void stop_exchange(); + T finalize_mul(int n = -1); T finalize(int n_input_players); diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index d361db1ae..6ecf8663b 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -34,6 +34,12 @@ Shamir::~Shamir() delete resharing; } +template +Shamir Shamir::branch() +{ + return P; +} + template int Shamir::get_n_relevant_players() { @@ -100,6 +106,25 @@ void Shamir::exchange() } } +template +void Shamir::start_exchange() +{ + if (P.my_num() < n_mul_players) + for (int offset = 1; offset < P.num_players(); offset++) + P.send_relative(offset, resharing->os[P.get_player(offset)]); +} + +template +void Shamir::stop_exchange() +{ + for (int offset = 1; offset < P.num_players(); offset++) + { + int receive_from = P.get_player(-offset); + if (receive_from < n_mul_players) + P.receive_player(receive_from, os[receive_from], true); + } +} + template ShamirShare Shamir::finalize_mul(int n) { diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index fd98fe001..4a4228ce1 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -15,10 +15,17 @@ class ShamirMC : public MAC_Check_Base { vector reconstruction; + bool send; + + void finalize(vector& values, const vector& S); + protected: vector os; int threshold; + void prepare(const vector& S, const Player& P); + void exchange(const Player& P); + public: ShamirMC() : threshold(ShamirMachine::s().threshold) {} @@ -31,6 +38,7 @@ class ShamirMC : public MAC_Check_Base ShamirMC() { (void)_; (void)__; (void)___; (void)____; } + void POpen(vector& values,const vector& S,const Player& P); void POpen_Begin(vector& values,const vector& S,const Player& P); void POpen_End(vector& values,const vector& S,const Player& P); diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index 6d3627570..63b5ba391 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -10,14 +10,35 @@ void ShamirMC::POpen_Begin(vector& values, const vector& S, const Player& P) { (void) values; + prepare(S, P); + P.send_all(os[P.my_num()], true); +} + +template +void ShamirMC::prepare(const vector& S, const Player& P) +{ os.clear(); os.resize(P.num_players()); - bool send = P.my_num() <= threshold; + send = P.my_num() <= threshold; if (send) { for (auto& share : S) share.pack(os[P.my_num()]); } +} + +template +void ShamirMC::POpen(vector& values, const vector& S, + const Player& P) +{ + prepare(S, P); + exchange(P); + finalize(values, S); +} + +template +void ShamirMC::exchange(const Player& P) +{ for (int offset = 1; offset < P.num_players(); offset++) { int send_to = P.get_player(offset); @@ -37,7 +58,14 @@ template void ShamirMC::POpen_End(vector& values, const vector& S, const Player& P) { - (void) P; + P.receive_all(os); + finalize(values, S); +} + +template +void ShamirMC::finalize(vector& values, + const vector& S) +{ int n_relevant_players = ShamirMachine::s().threshold + 1; if (reconstruction.empty()) { diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index b312db5ab..eef06c155 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -40,6 +40,11 @@ class ShamirShare : public T return "Shamir " + T::type_string(); } + static int threshold(int) + { + return ShamirMachine::s().threshold; + } + static ShamirShare constant(T value, int my_num, const T& alphai = {}) { return ShamirShare(value, my_num, alphai); diff --git a/Protocols/Share.h b/Protocols/Share.h index 80394fcee..b011ecf43 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -71,6 +71,9 @@ class Share static DataFieldType field_type() { return T::field_type(); } + static int threshold(int nplayers) + { return nplayers - 1; } + static Share constant(const clear& aa, int my_num, const typename T::Scalar& alphai) { return Share(aa, my_num, alphai); } diff --git a/README.md b/README.md index 3ab0d11b1..865321df3 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ us, but you can also write an email to mp-spdz@googlegroups.com #### TL;DR (Binary Distribution on Linux or Source Distribution on macOS) This requires either a Linux distribution originally released 2011 or -later (glibc 2.12) or macOS High Sierra or later as well as Python 2 +later (glibc 2.12) or macOS High Sierra or later as well as Python 3 and basic command-line utilities. Download and unpack the [distribution](https://github.com/n1analytics/MP-SPDZ/releases), @@ -72,7 +72,7 @@ The following table lists all protocols that are fully supported. | --- | --- | --- | --- | --- | | Malicious, dishonest majority | [MASCOT](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny](#secret-sharing) | [BMR](#bmr) | | Covert, dishonest majority | [CowGear](#secret-sharing) | N/A | N/A | N/A | -| Semi-honest, dishonest majority | [Semi](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | +| Semi-honest, dishonest majority | [Semi / Hemi](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | | Malicious, honest majority | [Shamir / Rep3 / PS](#honest-majority) | [Brain / Rep3 / PS](#honest-majority) | [Rep3](#honest-majority) | [BMR](#bmr) | | Semi-honest, honest majority | [Shamir / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3](#honest-majority) | [BMR](#bmr) | @@ -128,14 +128,14 @@ phase outputs the amount of offline material required, which allows to compute the preprocessing time for a particular computation. #### Requirements - - GCC 5 or later (tested with 8.2) or LLVM/clang 5 or later (tested with 7) + - GCC 5 or later (tested with 8.2) or LLVM/clang 5 or later (tested with 7). We recommend clang because it performs better. - MPIR library, compiled with C++ support (use flag --enable-cxx when running configure) - libsodium library, tested against 1.0.16 - OpenSSL, tested against and 1.0.2 and 1.1.0 - Boost.Asio with SSL support (`libboost-dev` on Ubuntu), tested against 1.65 - Boost.Thread for BMR (`libboost-thread-dev` on Ubuntu), tested against 1.65 - 64-bit CPU - - Python 2.x + - Python 3.5 or later - NTL library for CowGear and the SPDZ-2 and Overdrive offline phases (optional; tested with NTL 10.5) - If using macOS, Sierra or later @@ -239,6 +239,7 @@ The following table shows all programs for dishonest-majority computation using | `semi-party.x` | OT-based | Mod prime | Semi-honest | `semi.sh` | | `semi2k-party.x` | OT-based | Mod 2^k | Semi-honest | `semi2k.sh` | | `cowgear-party.x` | Adapted [LowGear](https://eprint.iacr.org/2017/1230) | Mod prime | Covert | `cowgear.sh` | +| `hemi-party.x` | Semi-homomorphic encryption | Mod prime | Semi-honest | `hemi.sh` | | `semi-bin-party.x` | OT-based | Binary | Semi-honest | `semi-bin.sh` | | `tiny-party.x` | Adapted SPDZ2k | Binary | Malicious | `tiny.sh` | @@ -254,13 +255,17 @@ security. Tiny denotes the adaption of SPDZ2k to the binary setting. In particular, the SPDZ2k sacrifice does not work for bits, so we replace it by cut-and-choose according to [Furukawa et -al.](https://eprint.iacr.org/2016/944.pdf). +al.](https://eprint.iacr.org/2016/944) CowGear denotes a covertly secure version of LowGear. The reason for this is the key generation that only achieves covert security. It is possible however to run full LowGear for triple generation by using `-s` with the desired security parameter. +Hemi denotes the stripped version version of LowGear for semi-honest +security similar to Semi, that is, generating additively shared Beaver +triples using semi-homomorphic encryption. + We will use MASCOT to demonstrate the use, but the other protocols work similarly. diff --git a/Scripts/hemi.sh b/Scripts/hemi.sh new file mode 100755 index 000000000..f0be05bf2 --- /dev/null +++ b/Scripts/hemi.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player hemi-party.x $* || exit 1 diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index b1b5e2ebf..431186fd5 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -9,6 +9,20 @@ gdb_screen() screen -S :$name -d -m bash -l -c "echo $*; echo $LIBRARY_PATH; gdb $prog -ex \"run $*\"" } +lldb_screen() +{ + prog=$1 + shift + IFS= + name=${*/-/} + IFS=' ' + echo debug $prog with arguments $* + echo name: $name + tmp=/tmp/$RANDOM + echo run > $tmp + screen -S :$i -d -m bash -l -c "lldb -s $tmp $prog -- $*" +} + run_player() { port=$((RANDOM%10000+10000)) bin=$1 @@ -42,6 +56,7 @@ run_player() { { if test $i = 0; then tee $log; else cat > $log; fi; } & done last_player=$(($players - 1)) + i=$last_player >&2 echo Running $prefix $SPDZROOT/$bin $last_player $params $prefix $SPDZROOT/$bin $last_player $params > $SPDZROOT/logs/$log_prefix$last_player 2>&1 || return 1 } diff --git a/Scripts/test_ecdsa.sh b/Scripts/test_ecdsa.sh index 78194c6bf..57ba11557 100755 --- a/Scripts/test_ecdsa.sh +++ b/Scripts/test_ecdsa.sh @@ -1,8 +1,5 @@ #!/bin/bash -echo 'MOD = -DGFP_MOD_SZ=4' >> CONFIG.mine - -make clean make -j4 ecdsa Fake-ECDSA.x run() diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index 4853d7ff3..0a519ea27 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -12,15 +12,15 @@ function test fi } -./compile.py tutorial +./compile.py -R 64 tutorial -for i in rep-field mal-rep-field ps-rep-field shamir mal-shamir cowgear semi mascot; do +for i in ring brain mal-rep-ring ps-rep-ring semi2k spdz2k; do test $i done -./compile.py -R 64 tutorial +./compile.py tutorial -for i in ring brain mal-rep-ring ps-rep-ring semi2k spdz2k; do +for i in rep-field mal-rep-field ps-rep-field shamir mal-shamir hemi cowgear semi mascot; do test $i done diff --git a/Tools/BitVector.cpp b/Tools/BitVector.cpp index 31b4e6864..e4668ce14 100644 --- a/Tools/BitVector.cpp +++ b/Tools/BitVector.cpp @@ -52,6 +52,16 @@ bool BitVector::parity() const #endif } +void BitVector::append(const BitVector& other, size_t length) +{ + assert(nbits % 8 == 0); + assert(length % 8 == 0); + assert(length <= other.nbits); + auto old_nbytes = nbytes; + resize(nbits + length); + memcpy(bytes + old_nbytes, other.bytes, length / 8); +} + void BitVector::randomize(PRNG& G) { G.get_octets(bytes, nbytes); diff --git a/Tools/BitVector.h b/Tools/BitVector.h index 54e319b31..3c174a188 100644 --- a/Tools/BitVector.h +++ b/Tools/BitVector.h @@ -239,6 +239,8 @@ class BitVector return true; } + void append(const BitVector& other, size_t length); + void randomize(PRNG& G); template void randomize_blocks(PRNG& G); diff --git a/Tools/MMO.cpp b/Tools/MMO.cpp index 9c410d3bd..8c78ce3dd 100644 --- a/Tools/MMO.cpp +++ b/Tools/MMO.cpp @@ -81,13 +81,13 @@ void MMO::hashBlocks(void* output, const void* input) encrypt_and_xor<1>(output, output, IV[0]); } -template <> -void MMO::hashBlocks(void* output, const void* input) +template +void MMO::hashEightGfp(void* output, const void* input) { - if (gfp1::get_ZpD().get_t() < 2) + if (gfp_::get_ZpD().get_t() < 2) throw not_implemented(); - gfp1* out = (gfp1*)output; - hashBlocks<8, gfp1::N_BYTES>(output, input, sizeof(gfp1)); + gfp_* out = (gfp_*)output; + hashBlocks<8, gfp_::N_BYTES>(output, input, sizeof(gfp_)); for (int i = 0; i < 8; i++) out[i].zero_overhang(); int left = 8; @@ -97,7 +97,7 @@ void MMO::hashBlocks(void* output, const void* input) int now_left = 0; for (int j = 0; j < left; j++) if (mpn_cmp((mp_limb_t*) out[indices[j]].get_ptr(), - gfp1::get_ZpD().get_prA(), gfp1::t()) >= 0) + gfp_::get_ZpD().get_prA(), gfp_::t()) >= 0) { indices[now_left] = indices[j]; now_left++; @@ -105,19 +105,31 @@ void MMO::hashBlocks(void* output, const void* input) left = now_left; int block_size = sizeof(__m128i); - int n_blocks = DIV_CEIL(gfp1::size(), block_size); + int n_blocks = DIV_CEIL(gfp_::size(), block_size); for (int i = 0; i < n_blocks; i++) for (int j = 0; j < left; j++) { __m128i* addr = (__m128i*) out[indices[j]].get_ptr() + i; __m128i* in = (__m128i*) out[indices[j]].get_ptr(); auto tmp = aes_128_encrypt(_mm_loadu_si128(in), IV[i]); - memcpy(addr, &tmp, min(block_size, gfp1::size() - i * block_size)); + memcpy(addr, &tmp, min(block_size, gfp_::size() - i * block_size)); out[indices[j]].zero_overhang(); } } } +template <> +void MMO::hashBlocks(void* output, const void* input) +{ + hashEightGfp<1, GFP_MOD_SZ>(output, input); +} + +template <> +void MMO::hashBlocks(void* output, const void* input) +{ + hashEightGfp<3, 4>(output, input); +} + #define ZZ(F,N) \ template void MMO::hashBlocks(void*, const void*); #define Z(F) ZZ(F,1) ZZ(F,2) ZZ(F,8) diff --git a/Tools/MMO.h b/Tools/MMO.h index 2631640f0..f2fd2996e 100644 --- a/Tools/MMO.h +++ b/Tools/MMO.h @@ -32,6 +32,8 @@ class MMO void hashBlocks(void* output, const void* input, size_t alloc_size); template void hashBlocks(void* output, const void* input); + template + void hashEightGfp(void* output, const void* input); template void outputOneBlock(octet* output); Key hash(const Key& input); diff --git a/Tools/time-func.cpp b/Tools/time-func.cpp index 9dc0b8006..37eb28321 100644 --- a/Tools/time-func.cpp +++ b/Tools/time-func.cpp @@ -23,7 +23,7 @@ double timeval_diff_in_seconds(struct timeval *start_time, struct timeval *end_t } -long long timespec_diff(struct timespec *start_time, struct timespec *end_time) +long long timespec_diff(const struct timespec *start_time, const struct timespec *end_time) { long long sec =end_time->tv_sec -start_time->tv_sec ; long long nsec=end_time->tv_nsec-start_time->tv_nsec; @@ -72,3 +72,19 @@ Timer& Timer::operator -=(const Timer& other) elapsed_time -= other.elapsed_time; return *this; } + +Timer& Timer::operator +=(const Timer& other) +{ + assert(clock_id == other.clock_id); + assert(not running); + elapsed_time += other.elapsed_time + other.elapsed_since_last_start(); + return *this; +} + +Timer& Timer::operator +=(const TimeScope& other) +{ + assert(clock_id == other.timer.clock_id); + assert(not running); + elapsed_time += other.timer.elapsed_since_last_start(); + return *this; +} diff --git a/Tools/time-func.h b/Tools/time-func.h index 144f11ff6..8381d6ed0 100644 --- a/Tools/time-func.h +++ b/Tools/time-func.h @@ -10,7 +10,9 @@ long long timeval_diff(struct timeval *start_time, struct timeval *end_time); double timeval_diff_in_seconds(struct timeval *start_time, struct timeval *end_time); -long long timespec_diff(struct timespec *start_time, struct timespec *end_time); +long long timespec_diff(const struct timespec *start_time, const struct timespec *end_time); + +class TimeScope; class Timer { @@ -26,6 +28,8 @@ class Timer double idle(); Timer& operator-=(const Timer& other); + Timer& operator+=(const Timer& other); + Timer& operator+=(const TimeScope& other); private: timespec startv; @@ -33,11 +37,12 @@ class Timer long long elapsed_time; clockid_t clock_id; - long long elapsed_since_last_start(); + long long elapsed_since_last_start() const; }; class TimeScope { + friend class Timer; Timer& timer; public: @@ -83,7 +88,7 @@ inline void Timer::reset() clock_gettime(clock_id, &startv); } -inline long long Timer::elapsed_since_last_start() +inline long long Timer::elapsed_since_last_start() const { timespec endv; clock_gettime(clock_id, &endv); diff --git a/compile.py b/compile.py index 48586cf3f..77c2c5775 100755 --- a/compile.py +++ b/compile.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # ===== Compiler usage instructions =====