From 6cc3fccef0904764364eb8e285519cd34028f325 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 9 May 2023 14:49:52 +1000 Subject: [PATCH] Maintenance. --- .gitmodules | 3 - BMR/RealProgramParty.h | 2 - BMR/RealProgramParty.hpp | 3 +- BMR/Register.h | 11 +- CHANGELOG.md | 15 ++ CONFIG | 21 +- Compiler/GC/instructions.py | 2 + Compiler/allocator.py | 33 ++- Compiler/comparison.py | 3 + Compiler/compilerLib.py | 15 +- Compiler/floatingpoint.py | 4 +- Compiler/instructions.py | 31 ++- Compiler/instructions_base.py | 6 + Compiler/library.py | 2 +- Compiler/ml.py | 203 +++++++++++++----- Compiler/mpc_math.py | 6 +- Compiler/program.py | 64 ++++-- Compiler/types.py | 82 +++++-- Dockerfile | 9 +- ExternalIO/Client.hpp | 73 +++++-- ExternalIO/client.py | 23 +- ExternalIO/domains.py | 4 + FHE/Ciphertext.cpp | 2 - FHE/FFT.cpp | 3 - FHE/FFT_Data.cpp | 2 + FHE/FHE_Keys.cpp | 41 ++-- FHE/FHE_Params.cpp | 6 - FHE/Matrix.cpp | 43 ---- FHE/Matrix.h | 8 - FHE/NTL-Subs.cpp | 132 ------------ FHE/NTL-Subs.h | 5 - FHE/PPData.cpp | 99 --------- FHE/PPData.h | 61 ------ FHE/Plaintext.cpp | 115 ---------- FHEOffline/EncCommit.cpp | 6 +- GC/DealerPrep.h | 1 + GC/NoShare.h | 2 + GC/Program.hpp | 4 +- GC/RepPrep.hpp | 1 + GC/Secret.hpp | 2 +- GC/SemiPrep.cpp | 1 + GC/ThreadMaster.hpp | 4 +- GC/TinierShare.h | 5 + GC/instructions.h | 1 + Machines/OTMachine.cpp | 6 +- Makefile | 37 +--- Math/Integer.h | 6 + Math/Setup.cpp | 4 +- Math/Setup.h | 16 +- Math/ValueInterface.cpp | 18 ++ Math/ValueInterface.h | 6 +- Math/Z2k.h | 4 +- Math/Zp_Data.cpp | 36 +++- Math/Zp_Data.h | 4 + Math/bigint.cpp | 66 ------ Math/bigint.h | 15 +- Math/bigint.hpp | 74 ++++++- Math/gf2n.h | 3 + Math/gf2nlong.h | 2 + Math/gfp.h | 3 + Math/gfp.hpp | 3 +- Math/gfpvar.cpp | 4 +- Math/gfpvar.h | 1 + Math/modp.h | 2 +- Math/mpn_fixed.h | 2 +- Math/square128.cpp | 2 +- Networking/Player.cpp | 19 +- Networking/Player.h | 8 +- Networking/sockets.cpp | 2 +- Processor/BaseMachine.cpp | 22 +- Processor/BaseMachine.h | 2 + 
Processor/Conv2dTuple.h | 41 ++++ Processor/DataPositions.cpp | 3 +- Processor/Data_Files.h | 33 +-- Processor/Data_Files.hpp | 101 ++++++--- Processor/EdabitBuffer.h | 50 +++++ Processor/Instruction.h | 1 + Processor/Instruction.hpp | 21 +- Processor/Machine.h | 1 + Processor/Machine.hpp | 50 ++--- Processor/OfflineMachine.hpp | 2 +- Processor/Online-Thread.hpp | 16 ++ Processor/PrepBase.cpp | 2 +- Processor/PrepBase.h | 3 + Processor/Processor.hpp | 59 +++-- Processor/ThreadQueue.cpp | 10 + Processor/ThreadQueue.h | 5 + Processor/ThreadQueues.cpp | 29 +++ Processor/ThreadQueues.h | 4 + Processor/instructions.h | 1 + Programs/Source/alex.mpc | 114 ++++++++++ Programs/Source/bankers_bonus.mpc | 3 + Programs/Source/falcon_alex.mpc | 2 +- Programs/Source/keras_mnist_lenet_avgpool.mpc | 72 +++++++ Programs/Source/torch_mnist_lenet_avgpool.mpc | 49 +++++ Protocols/DealerMatrixPrep.hpp | 2 + Protocols/Hemi.h | 6 +- Protocols/Hemi.hpp | 36 +--- Protocols/HemiMatrixPrep.h | 2 + Protocols/MaliciousShamirMC.hpp | 6 +- Protocols/ReplicatedPrep.h | 10 + Protocols/ReplicatedPrep.hpp | 92 ++++++-- Protocols/Shamir.h | 4 +- Protocols/Shamir.hpp | 27 ++- Protocols/ShamirInput.hpp | 25 ++- Protocols/ShamirMC.h | 2 +- Protocols/ShamirMC.hpp | 11 +- Protocols/ShamirShare.h | 2 +- Protocols/Spdz2kPrep.hpp | 3 +- Protocols/edabit.h | 3 +- Protocols/fake-stuff.h | 2 +- Protocols/fake-stuff.hpp | 24 ++- README.md | 10 +- Scripts/compile-emulate.py | 2 +- Scripts/compile-run.py | 2 +- Scripts/emulate.sh | 6 +- Scripts/list-field-protocols.sh | 4 + Scripts/list-ring-protocols.sh | 4 + Scripts/memory-usage.py | 29 ++- Scripts/run-common.sh | 24 ++- Tools/Buffer.cpp | 1 + Tools/Buffer.h | 15 +- Tools/FlexBuffer.cpp | 2 +- Tools/FlexBuffer.h | 2 +- Tools/TimerWithComm.cpp | 62 +++++- Tools/TimerWithComm.h | 15 +- Tools/random.h | 2 +- Tools/time-func.cpp | 12 ++ Tools/time-func.h | 4 + Utils/Fake-Offline.cpp | 27 ++- azure-pipelines.yml | 4 +- deps/mpir | 1 - doc/io.rst | 3 +- 
doc/machine-learning.rst | 9 +- doc/troubleshooting.rst | 25 ++- 135 files changed, 1658 insertions(+), 1062 deletions(-) delete mode 100644 FHE/PPData.cpp delete mode 100644 FHE/PPData.h create mode 100644 Math/ValueInterface.cpp create mode 100644 Processor/Conv2dTuple.h create mode 100644 Processor/EdabitBuffer.h create mode 100644 Programs/Source/alex.mpc create mode 100644 Programs/Source/keras_mnist_lenet_avgpool.mpc create mode 100644 Programs/Source/torch_mnist_lenet_avgpool.mpc create mode 100755 Scripts/list-field-protocols.sh create mode 100755 Scripts/list-ring-protocols.sh delete mode 160000 deps/mpir diff --git a/.gitmodules b/.gitmodules index 7dea81d36..9307e3292 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,6 @@ [submodule "SimpleOT"] path = deps/SimpleOT url = https://github.com/mkskeller/SimpleOT -[submodule "mpir"] - path = deps/mpir - url = https://github.com/wbhart/mpir [submodule "Programs/Circuits"] path = Programs/Circuits url = https://github.com/mkskeller/bristol-fashion diff --git a/BMR/RealProgramParty.h b/BMR/RealProgramParty.h index 9f274bd88..08500914a 100644 --- a/BMR/RealProgramParty.h +++ b/BMR/RealProgramParty.h @@ -43,8 +43,6 @@ class RealProgramParty : public ProgramPartySpec bool one_shot; - size_t data_sent; - public: static RealProgramParty& s(); diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 421394146..2ab9fa75d 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -154,7 +154,6 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : while (next != GC::DONE_BREAK); MC->Check(*P); - data_sent = P->total_comm().sent; if (online_opts.verbose) P->total_comm().print(); @@ -216,7 +215,7 @@ RealProgramParty::~RealProgramParty() delete prep; delete garble_inputter; delete garble_protocol; - cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl; + garble_machine.print_comm(*this->P, this->P->total_comm()); T::MAC_Check::teardown(); } diff --git a/BMR/Register.h 
b/BMR/Register.h index 2085eb25a..3baf69041 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -62,11 +62,13 @@ class BaseKeyVector #endif }; #else -class BaseKeyVector : public vector +class BaseKeyVector : public CheckVector { + typedef CheckVector super; + public: - BaseKeyVector(int size = 0) : vector(size, Key(0)) {} - void resize(int size) { vector::resize(size, Key(0)); } + BaseKeyVector(int size = 0) : super(size, Key(0)) {} + void resize(int size) { super::resize(size, Key(0)); } }; #endif @@ -296,7 +298,8 @@ class ProgramRegister : public Phase, public Register static void andm(GC::Processor&, const BaseInstruction&) { throw runtime_error("andm not implemented"); } - static void run_tapes(const vector&) { throw not_implemented(); } + static void run_tapes(const vector&) + { throw runtime_error("multi-threading not implemented"); } // most BMR phases don't need actual input template diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a3a276d7..4b72829e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.6 (May 9, 2023) + +- More extensive benchmarking outputs +- Replace MPIR by GMP +- Secure reading of edaBits from files +- Semi-honest client communication +- Back-propagation for average pooling +- Parallelized convolution +- Probabilistic truncation as in ABY3 +- More balanced communication in Shamir secret sharing +- Avoid unnecessary communication in Dealer protocol +- Linear solver using Cholesky decomposition +- Accept .py files for compilation +- Fixed security bug: proper accounting for random elements + ## 0.3.5 (Feb 16, 2023) - Easier-to-use machine learning interface diff --git a/CONFIG b/CONFIG index 6d5f0f170..f6f436294 100644 --- a/CONFIG +++ b/CONFIG @@ -35,15 +35,32 @@ ARM := $(shell uname -m | grep x86; echo $$?) 
OS := $(shell uname -s) ifeq ($(MACHINE), x86_64) ifeq ($(OS), Linux) +ifeq ($(shell cat /proc/cpuinfo | grep -q avx2; echo $$?), 0) AVX_OT = 1 else AVX_OT = 0 endif else +AVX_OT = 0 +endif +else ARCH = AVX_OT = 0 endif +ifeq ($(OS), Darwin) +BREW_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include +BREW_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib +endif + +ifeq ($(OS), Linux) +ifeq ($(ARM), 1) +ifeq ($(shell cat /proc/cpuinfo | grep -q aes; echo $$?), 0) +ARCH = -march=armv8.2-a+crypto +endif +endif +endif + USE_KOS = 0 # allow to set compiler in CONFIG.mine @@ -66,7 +83,8 @@ endif # Default for MAX_MOD_SZ is 10, which suffices for all Overdrive protocols # MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5 -LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS) +LDLIBS = -lgmpxx -lgmp -lsodium $(MY_LDLIBS) +LDLIBS += $(BREW_LDLIBS) LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib LDLIBS += -lboost_system -lssl -lcrypto @@ -88,6 +106,7 @@ BOOST = -lboost_thread $(MY_BOOST) endif CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror +CFLAGS += $(BREW_CFLAGS) CPPFLAGS = $(CFLAGS) LD = $(CXX) diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index 73a8af216..4fc2fe7c4 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -17,8 +17,10 @@ class SecretBitsAF(base.RegisterArgFormat): reg_type = 'sb' + name = 'sbit' class ClearBitsAF(base.RegisterArgFormat): reg_type = 'cb' + name = 'cbit' base.ArgFormats['sb'] = SecretBitsAF base.ArgFormats['sbw'] = SecretBitsAF diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 530d21e4b..d1229e988 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -338,16 +338,19 @@ def add_edge(i, j): d[j] = d[i] def read(reg, n): - last_read[reg] 
= n for dup in reg.duplicates: - if last_def[dup] != -1: + if last_def[dup] not in (-1, n): add_edge(last_def[dup], n) + last_read[reg] = n def write(reg, n): - last_def[reg] = n for dup in reg.duplicates: if last_read[dup] not in (-1, n): add_edge(last_read[dup], n) + if id(dup) in [id(x) for x in block.instructions[n].get_used()] and \ + last_read[dup] not in (-1, n): + add_edge(last_read[dup], n) + last_def[reg] = n def handle_mem_access(addr, reg_type, last_access_this_kind, last_access_other_kind): @@ -434,19 +437,19 @@ def keep_text_order(inst, n): # if options.debug: # col = colordict[instr.__class__.__name__] # G.add_node(n, color=col, label=str(instr)) - for reg in inputs: + for reg in outputs: if reg.vector and instr.is_vec(): for i in reg.vector: - read(i, n) + write(i, n) else: - read(reg, n) + write(reg, n) - for reg in outputs: + for reg in inputs: if reg.vector and instr.is_vec(): for i in reg.vector: - write(i, n) + read(i, n) else: - write(reg, n) + read(reg, n) # will be merged if isinstance(instr, TextInputInstruction): @@ -556,18 +559,6 @@ def eliminate(i): if unused_result: eliminate(i) count += 1 - # remove unnecessary stack instructions - # left by optimization with budget - if isinstance(inst, popint_class) and \ - (not G.degree(i) or (G.degree(i) == 1 and - isinstance(instructions[list(G[i])[0]], StackInstruction))) \ - and \ - inst.args[0].can_eliminate and \ - len(G.pred[i]) == 1 and \ - isinstance(instructions[list(G.pred[i])[0]], pushint_class): - eliminate(list(G.pred[i])[0]) - eliminate(i) - count += 2 if count > 0 and self.block.parent.program.verbose: print('Eliminated %d dead instructions, among which %d opens: %s' \ % (count, open_count, dict(stats))) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 1a139ef6d..cf818570a 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -50,6 +50,9 @@ def set_variant(options): do_precomp = False elif variant is not None: raise CompilerError('Unknown comparison 
variant: %s' % variant) + if const_rounds and instructions_base.program.options.binary: + raise CompilerError( + 'Comparison variant choice incompatible with binary circuits') def ld2i(c, n): """ Load immediate 2^n into clear GF(p) register c """ diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index bb80dc344..f83e7ca2a 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -22,6 +22,7 @@ def __init__(self, custom_args=None, usage=None, execute=False): self.custom_args = custom_args self.build_option_parser() self.VARS = {} + self.root = os.path.dirname(__file__) + '/..' def build_option_parser(self): parser = OptionParser(usage=self.usage) @@ -269,7 +270,7 @@ def build_program(self, name=None): self.prog = Program(self.args, self.options, name=name) if self.execute: if self.options.execute in \ - ("emulate", "ring", "rep-field", "semi2k"): + ("emulate", "ring", "rep-field"): self.prog.use_trunc_pr = True if self.options.execute in ("ring",): self.prog.use_split(3) @@ -405,7 +406,7 @@ def compile_file(self): infile = open(self.prog.infile) # make compiler modules directly accessible - sys.path.insert(0, "Compiler") + sys.path.insert(0, "%s/Compiler" % self.root) # create the tapes exec(compile(infile.read(), infile.name, "exec"), self.VARS) @@ -477,15 +478,15 @@ def executable_from_protocol(protocol): def local_execution(self, args=[]): executable = self.executable_from_protocol(self.options.execute) - if not os.path.exists(executable): + if not os.path.exists("%s/%s" % (self.root, executable)): print("Creating binary for virtual machine...") try: - subprocess.run(["make", executable], check=True) + subprocess.run(["make", executable], check=True, cwd=self.root) except: raise CompilerError( "Cannot produce %s. 
" % executable + \ "Note that compilation requires a few GB of RAM.") - vm = 'Scripts/%s.sh' % self.options.execute + vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute) os.execl(vm, vm, self.prog.name, *args) def remote_execution(self, args=[]): @@ -496,7 +497,7 @@ def remote_execution(self, args=[]): from fabric import Connection import subprocess print("Creating static binary for virtual machine...") - subprocess.run(["make", "static/%s" % vm], check=True) + subprocess.run(["make", "static/%s" % vm], check=True, cwd=self.root) # transfer files import glob @@ -519,7 +520,7 @@ def run(i): "mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \ dest) # executable - connection.put("static/%s" % vm, dest) + connection.put("%s/static/%s" % (self.root, vm), dest) # program dest += "/" connection.put("Programs/Schedules/%s.sch" % self.prog.name, diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index f44d95cbe..bd5c13844 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -289,7 +289,7 @@ def BitDecRingRaw(a, k, m): def BitDecRing(a, k, m): bits = BitDecRingRaw(a, k, m) # reversing to reduce number of rounds - return [types.sint.conv(bit) for bit in reversed(bits)][::-1] + return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1] def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): instructions_base.set_global_vector_size(a.size) @@ -306,7 +306,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): def BitDecField(a, k, m, kappa, bits_to_compute=None): res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute) - return [types.sint.conv(bit) for bit in res] + return [types.sintbit.conv(bit) for bit in res] @instructions_base.ret_cisc diff --git a/Compiler/instructions.py b/Compiler/instructions.py index c2aa76684..5642c59c3 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -356,7 +356,17 @@ class reqbl(base.Instruction): code = base.opcodes['REQBL'] arg_format = ['int'] +class 
active(base.Instruction): + """ Indicate whether program is compatible with malicious-security + protocols. + + :param: 0 for no, 1 for yes + """ + code = base.opcodes['ACTIVE'] + arg_format = ['int'] + class time(base.IOInstruction): + """ Output time since start of computation. """ code = base.opcodes['TIME'] arg_format = [] @@ -2418,9 +2428,10 @@ def add_usage(self, req_node): super(matmulsm, self).add_usage(req_node) req_node.increment(('matmul', tuple(self.args[3:6])), 1) -class conv2ds(base.DataInstruction): +class conv2ds(base.DataInstruction, base.VarArgsInstruction, base.Mergeable): """ Secret 2D convolution. + :param: number of arguments to follow (int) :param: result (sint vector in row-first order) :param: inputs (sint vector in row-first order) :param: weights (sint vector in row-first order) @@ -2436,10 +2447,12 @@ class conv2ds(base.DataInstruction): :param: padding height (int) :param: padding width (int) :param: batch size (int) + :param: repeat from result... + """ code = base.opcodes['CONV2DS'] - arg_format = ['sw','s','s','int','int','int','int','int','int','int','int', - 'int','int','int','int'] + arg_format = itertools.cycle(['sw','s','s','int','int','int','int','int', + 'int','int','int','int','int','int','int']) data_type = 'triple' is_vec = lambda self: True @@ -2450,14 +2463,16 @@ def __init__(self, *args, **kwargs): assert args[2].size == args[7] * args[8] * args[11] def get_repeat(self): - return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \ - self.args[11] * self.args[14] + args = self.args + return sum(args[i+3] * args[i+4] * args[i+7] * args[i+8] * \ + args[i+11] * args[i+14] for i in range(0, len(args), 15)) def add_usage(self, req_node): super(conv2ds, self).add_usage(req_node) - args = self.args - req_node.increment(('matmul', (1, args[7] * args[8] * args[11], - args[14] * args[3] * args[4])), 1) + for i in range(0, len(self.args), 15): + args = self.args[i:i + 15] + req_node.increment(('matmul', (1, args[7] * 
args[8] * args[11], + args[14] * args[3] * args[4])), 1) @base.vectorize class trunc_pr(base.VarArgsInstruction): diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index b72079c76..9e72eea7d 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -66,6 +66,7 @@ PLAYERID = 0xE4, USE_EDABIT = 0xE5, USE_MATMUL = 0x1F, + ACTIVE = 0xE9, # Addition ADDC = 0x20, ADDS = 0x21, @@ -700,18 +701,23 @@ def __str__(self): class ClearModpAF(RegisterArgFormat): reg_type = RegType.ClearModp + name = 'cint' class SecretModpAF(RegisterArgFormat): reg_type = RegType.SecretModp + name = 'sint' class ClearGF2NAF(RegisterArgFormat): reg_type = RegType.ClearGF2N + name = 'cgf2n' class SecretGF2NAF(RegisterArgFormat): reg_type = RegType.SecretGF2N + name = 'sgf2n' class ClearIntAF(RegisterArgFormat): reg_type = RegType.ClearInt + name = 'regint' class IntArgFormat(ArgFormat): n_bits = 32 diff --git a/Compiler/library.py b/Compiler/library.py index 1ad3dae68..c857be27e 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1226,7 +1226,7 @@ def loop_fn(): result = loop_body(arg) if isinstance(result, MemValue): result = result.read() - result.link(arg) + arg.update(result) return condition(result) if not isinstance(pre_condition, (bool,int)) or pre_condition: if_statement(pre_condition, lambda: do_while(loop_fn, g=g)) diff --git a/Compiler/ml.py b/Compiler/ml.py index c7f09d09f..20e2941ed 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -372,6 +372,7 @@ def reveal_correctness(self, n=None, Y=None, debug=False): n = self.X.sizes[0] if Y is None: Y = self.Y + assert isinstance(Y, Array) n_correct = MemValue(0) n_printed = MemValue(0) @for_range_opt(n) @@ -1109,14 +1110,7 @@ class Square(ElementWiseLayer): f_prime = staticmethod(lambda x: cfix(2, size=x.size) * x) prime_type = sfix -class MaxPool(NoVariableLayer): - """ Fixed-point MaxPool layer. 
- - :param shape: input shape (tuple/list of four int) - :param strides: strides (tuple/list of four int, first and last must be 1) - :param ksize: kernel size (tuple/list of four int, first and last must be 1) - :param padding: :py:obj:`'VALID'` (default) or :py:obj:`'SAME'` - """ +class PoolBase(NoVariableLayer): def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1), padding='VALID'): assert len(shape) == 4 @@ -1152,38 +1146,6 @@ def __repr__(self): (type(self).__name__, self.X.sizes, self.strides, self.ksize, self.padding) - def forward(self, batch=None, training=False): - if batch is None: - batch = Array.create_from(regint(0)) - def process(pool, bi, k, i, j): - def m(a, b): - c = a[0] > b[0] - l = [c * x for x in a[1]] - l += [(1 - c) * x for x in b[1]] - return c.if_else(a[0], b[0]), l - red = util.tree_reduce(m, [(x[0], [1] if training else []) - for x in pool]) - self.Y[bi][i][j][k] = red[0] - for ii, x in enumerate(red[1]): - self.comparisons[bi][k][i][j][ii] = x - self.traverse(batch, process) - - def backward(self, compute_nabla_X=True, batch=None): - if compute_nabla_X: - self.nabla_X.alloc() - self.nabla_X.assign_all(0) - break_point() - def process(pool, bi, k, i, j): - for (x, h_in, w_in, h, w), c \ - in zip(pool, self.comparisons[bi][k][i][j]): - hh = h * h_in - ww = w * w_in - res = h_in * w_in * c * self.nabla_Y[bi][i][j][k] - get_program().protect_memory(True) - self.nabla_X[bi][hh][ww][k] += res - get_program().protect_memory(False) - self.traverse(batch, process) - def traverse(self, batch, process): need_padding = [self.strides[i] * (self.Y.sizes[i] - 1) + self.ksize[i] > self.X.sizes[i] for i in range(4)] @@ -1221,6 +1183,47 @@ def _(j): h_in, w_in, h, w]) process(pool, bi, k, i, j) +class MaxPool(PoolBase): + """ Fixed-point MaxPool layer. 
+ + :param shape: input shape (tuple/list of four int) + :param strides: strides (tuple/list of four int, first and last must be 1) + :param ksize: kernel size (tuple/list of four int, first and last must be 1) + :param padding: :py:obj:`'VALID'` (default), :py:obj:`'SAME'`, integer, or + list/tuple of integers + + """ + def forward(self, batch=None, training=False): + if batch is None: + batch = Array.create_from(regint(0)) + def process(pool, bi, k, i, j): + def m(a, b): + c = a[0] > b[0] + l = [c * x for x in a[1]] + l += [(1 - c) * x for x in b[1]] + return c.if_else(a[0], b[0]), l + red = util.tree_reduce(m, [(x[0], [1] if training else []) + for x in pool]) + self.Y[bi][i][j][k] = red[0] + for ii, x in enumerate(red[1]): + self.comparisons[bi][k][i][j][ii] = x + self.traverse(batch, process) + + def backward(self, compute_nabla_X=True, batch=None): + if compute_nabla_X: + self.nabla_X.alloc() + self.nabla_X.assign_all(0) + break_point() + def process(pool, bi, k, i, j): + for (x, h_in, w_in, h, w), c \ + in zip(pool, self.comparisons[bi][k][i][j]): + hh = h * h_in + ww = w * w_in + res = h_in * w_in * c * self.nabla_Y[bi][i][j][k] + get_program().protect_memory(True) + self.nabla_X[bi][hh][ww][k] += res + get_program().protect_memory(False) + self.traverse(batch, process) class Argmax(NoVariableLayer): """ Fixed-point Argmax layer. 
@@ -2058,6 +2061,12 @@ def easyMaxPool(input_shape, kernel_size, stride=None, padding=0): or tuple/list of two int """ + kernel_size, stride, padding = \ + _standardize_pool_options(kernel_size, stride, padding) + return MaxPool(input_shape, [1] + list(stride) + [1], + [1] + list(kernel_size) + [1], padding) + +def _standardize_pool_options(kernel_size, stride, padding): if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) if isinstance(stride, int): @@ -2066,8 +2075,7 @@ def easyMaxPool(input_shape, kernel_size, stride=None, padding=0): stride = kernel_size padding = padding.upper() if isinstance(padding, str) \ else padding - return MaxPool(input_shape, [1] + list(stride) + [1], - [1] + list(kernel_size) + [1], padding) + return kernel_size, stride, padding class QuantAveragePool2d(QuantBase, AveragePool2d): def input_params_from(self, player): @@ -2075,14 +2083,47 @@ def input_params_from(self, player): for s in self.input_squant, self.output_squant: s.get_params_from(player) -class FixAveragePool2d(FixBase, AveragePool2d): +class FixAveragePool2d(PoolBase, FixBase): """ Fixed-point 2D AvgPool layer. 
:param input_shape: input shape (tuple/list of four int) :param output_shape: output shape (tuple/list of four int) - :param filter_size: filter size (tuple/list of two int) - :param strides: strides (tuple/list of two int) - """ + :param filter_size: filter size (int or tuple/list of two int) + :param strides: strides (int or tuple/list of two int) + :param padding: :py:obj:`'SAME'`, :py:obj:`'VALID'`, int, + or tuple/list of two int + + """ + def __init__(self, input_shape, output_shape, filter_size, strides=(1, 1), + padding=0): + filter_size, strides, padding = \ + _standardize_pool_options(filter_size, strides, padding) + PoolBase.__init__(self, input_shape, [1] + list(strides) + [1], + [1] + list(filter_size) + [1], padding) + self.pool_size = reduce(operator.mul, filter_size) + if output_shape: + assert self.Y.shape == list(output_shape) + + def _forward(self, batch): + def process(pool, bi, k, i, j): + self.Y[bi][i][j][k] = sum(x[0] for x in pool) * (1 / self.pool_size) + self.traverse(batch, process) + + def backward(self, compute_nabla_X=True, batch=None): + if compute_nabla_X: + self.nabla_X.alloc() + self.nabla_X.assign_all(0) + break_point() + def process(pool, bi, k, i, j): + part = self.nabla_Y[bi][i][j][k] * (1 / self.pool_size) + for x, h_in, w_in, h, w in pool: + hh = h * h_in + ww = w * w_in + res = h_in * w_in * part + get_program().protect_memory(True) + self.nabla_X[bi][hh][ww][k] += res + get_program().protect_memory(False) + self.traverse(batch, process) class QuantReshape(QuantBase, BaseLayer): def __init__(self, input_shape, _, output_shape): @@ -2265,6 +2306,8 @@ def eval(self, data, batch_size=None, top=False): :param data: sample data (:py:class:`Compiler.types.Matrix` with one row per sample) :param top: return top prediction instead of probability distribution + :returns: sfix/sint Array (depening on :py:obj:`top`) + """ if isinstance(self.layers[-1].Y, Array) or top: if top: @@ -2540,6 +2583,8 @@ def _(i): @_no_mem_warnings def 
run_by_args(self, program, n_runs, batch_size, test_X, test_Y, acc_batch_size=None, reset=True): + MultiArray.disable_index_checks() + Array.check_indices = False if acc_batch_size is None: acc_batch_size = batch_size depreciation = None @@ -2943,6 +2988,10 @@ def MaxPooling2D(pool_size=2, strides=None, padding='valid'): return 'maxpool', {'pool_size': pool_size, 'strides': strides, 'padding': padding} + def AveragePooling2D(pool_size=2, strides=None, padding='valid'): + return 'avgpool', {'filter_size': pool_size, 'strides': strides, + 'padding': padding} + def Dropout(rate): l = math.log(rate, 2) if int(l) != l: @@ -3014,9 +3063,12 @@ def build(self, input_shape, batch_size=128): n_units = reduce(operator.mul, layers[-1].Y.sizes[1:]) if i == len(self.layers) - 1: - if layer[2].get('activation', 'softmax') in \ - ('softmax', 'sigmoid'): + activation = layer[2].get('activation', None) + if activation in ('softmax', 'sigmoid'): layer[2].pop('activation', None) + if activation == 'softmax' and layer[1][0] == 1: + raise CompilerError( + 'softmax requires more than one output neuron') layers.append(Dense(N, n_units, layer[1][0], **layer[2])) input_shape = layers[-1].Y.sizes @@ -3041,6 +3093,9 @@ def build(self, input_shape, batch_size=128): layers.append(easyMaxPool(input_shape, pool_size, strides, padding)) input_shape = layers[-1].Y.sizes + elif name == 'avgpool': + layers.append(FixAveragePool2d(input_shape, None, **layer[1])) + input_shape = layers[-1].Y.sizes elif name == 'dropout': layers.append(Dropout(batch_size, reduce( operator.mul, layers[-1].Y.sizes[1:]), @@ -3192,6 +3247,10 @@ def process(item): layers.append(easyMaxPool(input_shape, item.kernel_size, item.stride, item.padding)) input_shape = layers[-1].Y.shape + elif name == 'AvgPool2d': + layers.append(FixAveragePool2d(input_shape, None, item.kernel_size, + item.stride, item.padding)) + input_shape = layers[-1].Y.shape elif name == 'ReLU': layers.append(Relu(input_shape)) elif name == 'Flatten': @@ 
-3295,7 +3354,7 @@ def predict_proba(self, X): return super(SGDLogistic, self).predict(X) class SGDLinear(OneLayerSGD): - """ Logistic regression using SGD. + """ Linear regression using SGD. :param n_epochs: number of epochs :param batch_size: batch size @@ -3415,11 +3474,16 @@ def _(i): return res.read() def cholesky(A, reveal_diagonal=False): - """ Cholesky decomposition. """ + """ Cholesky decomposition. + + :returns: lower triangular matrix + + """ assert len(A.shape) == 2 assert A.shape[0] == A.shape[1] L = A.same_shape() L.assign_all(0) + diag_inv = A.value_type.Array(A.shape[0]) @for_range(A.shape[0]) def _(i): @for_range(i + 1) @@ -3429,10 +3493,47 @@ def _(j): @if_e(i == j) def _(): L[i][j] = mpc_math.sqrt(A[i][i] - sum) + diag_inv[i] = 1 / L[i][j] if reveal_diagonal: print_ln('L[%s][%s] = %s = sqrt(%s - %s)', i, j, L[i][j].reveal(), A[i][j].reveal(), sum.reveal()) @else_ def _(): - L[i][j] = (1.0 / L[j][j] * (A[i][j] - sum)) + L[i][j] = (diag_inv[j] * (A[i][j] - sum)) return L + +def solve_lower(A, b): + """ Linear solver where :py:obj:`A` is lower triangular quadratic. """ + assert len(A.shape) == 2 + assert A.shape[0] == A.shape[1] + assert len(b) == A.shape[0] + b = Array.create_from(b) + res = sfix.Array(len(b)) + @for_range(len(b)) + def _(i): + res[i] = b[i] / A[i][i] + b[:] -= res[i] * A.get_column(i) + return res + +def solve_upper(A, b): + """ Linear solver where :py:obj:`A` is upper triangular quadratic. """ + assert len(A.shape) == 2 + assert A.shape[0] == A.shape[1] + assert len(b) == A.shape[0] + b = Array.create_from(b) + res = sfix.Array(len(b)) + @for_range(len(b) - 1, -1, -1) + def _(i): + res[i] = b[i] / A[i][i] + b[:] -= res[i] * A.get_column(i) + return res + +def solve_cholesky(A, b, debug=False): + """ Linear solver using Cholesky decomposition. 
""" + L = cholesky(A, reveal_diagonal=debug) + if debug: + Optimizer.stat('L', L) + x = solve_lower(L, b) + if debug: + Optimizer.stat('intermediate', x) + return solve_upper(L.transpose(), x) diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 56d09e1ee..d8bfbf370 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -661,7 +661,7 @@ def sqrt_simplified_fx(x): h = h * r H = 4 * (h * h) - if not x.round_nearest or (2 * f < k - 1): + if not x.round_nearest or (2 * x.f < x.k - 1): H = (h < 2 ** (-x.f / 2) / 2).if_else(0, H) H = H * x @@ -806,9 +806,7 @@ def sqrt_fx(x_l, k, f): @instructions_base.sfix_cisc def sqrt(x, k=None, f=None): """ - Returns the square root (sfix) of any given fractional - value as long as it can be rounded to a integral value - with :py:obj:`f` bits of decimal precision. + Square root. :param x: fractional input (sfix). diff --git a/Compiler/program.py b/Compiler/program.py index 9860053fe..9987ad4d7 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -186,6 +186,8 @@ def __init__(self, args, options=defaults, name=None): self.input_files = {} self.base_addresses = {} self._protect_memory = False + self._always_active = True + self.active = True if not self.options.cisc: self.options.cisc = not self.options.optimize_hard @@ -207,16 +209,14 @@ def max_par_tapes(self): return self.n_threads def init_names(self, args): - # ignore path to file - source must be in Programs/Source - if "Programs" in os.listdir(os.getcwd()): - # compile prog in ./Programs/Source directory - self.programs_dir = "Programs" - else: - # assume source is in main SPDZ directory - self.programs_dir = sys.path[0] + "/Programs" + self.programs_dir = "Programs" if self.verbose: print("Compiling program in", self.programs_dir) + for dirname in (self.programs_dir, "Player-Data"): + if not os.path.exists(dirname): + os.mkdir(dirname) + # create extra directories if needed for dirname in ["Public-Input", "Bytecode", "Schedules"]: if not 
os.path.exists(self.programs_dir + "/" + dirname): @@ -224,13 +224,29 @@ def init_names(self, args): if self.name is None: self.name = args[0].split("/")[-1] - if self.name.endswith(".mpc"): - self.name = self.name[:-4] + exts = ".mpc", ".py" + for ext in exts: + if self.name.endswith(ext): + self.name = self.name[:-len(ext)] if os.path.exists(args[0]): self.infile = args[0] else: - self.infile = self.programs_dir + "/Source/" + self.name + ".mpc" + infiles = [] + for x in (self.programs_dir, sys.path[0] + "/Programs"): + for ext in exts: + filename = args[0] + if not filename.endswith(ext): + filename += ext + infiles += [x + "/Source/" + filename] + for f in infiles: + if os.path.exists(f): + self.infile = f + break + else: + raise CompilerError( + "found none of the potential input files: " + + ", ".join("'%s'" % x for x in [args[0]] + infiles)) """ self.name is input file name (minus extension) + any optional arguments. Used to generate output filenames @@ -479,6 +495,9 @@ def finalize(self): # finalize the memory self.finalize_memory() + # communicate protocol compability + Compiler.instructions.active(self._always_active) + self.write_bytes() if self.options.asmoutfile: @@ -672,6 +691,19 @@ def rabbit_gap(self): logp = int(round(math.log(p, 2))) return abs(p - 2 ** logp) / p < 2 ** -self.security + @property + def active(self): + """ Whether to use actively secure protocols. 
""" + return self._active + + @active.setter + def active(self, change): + self._always_active &= change + self._active = change + + def semi_honest(self): + self._always_active = False + @staticmethod def read_tapes(schedule): m = re.search(r"([^/]*)\.mpc", schedule) @@ -1454,6 +1486,9 @@ def copy(self): return Tape.Register(self.reg_type, Program.prog.curr_tape) def link(self, other): + if Program.prog.options.noreallocate: + raise CompilerError("reallocation necessary for linking, " + "remove option -u") self.duplicates |= other.duplicates for dup in self.duplicates: dup.duplicates = self.duplicates @@ -1466,12 +1501,15 @@ def update(self, other): :param other: any convertible type """ - other = type(self)(other) + if isinstance(other, Tape.Register) and other.block != Program.prog.curr_block: + other = type(self)(other) + else: + other = self.conv(other) + if Program.prog.curr_block in [x.block for x in self.duplicates]: + self.program.start_new_basicblock() if self.program != other.program: raise CompilerError( 'cannot update register with one from another thread') - if other.block in [x.block for x in self.duplicates]: - self.program.start_new_basicblock() self.link(other) @property diff --git a/Compiler/types.py b/Compiler/types.py index a87371736..881178cd4 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -659,6 +659,7 @@ def traverse(content, level): traverse(x, level + 1) traverse(content, 0) f.write('\n') + f.flush() if requested_shape is not None and \ list(shape) != list(requested_shape): raise CompilerError('content contradicts shape') @@ -2415,26 +2416,34 @@ def get_raw_input_from(cls, player): def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): """ Securely obtain shares of values input by a client. This uses the triple-based input protocol introduced by - `Damgård et al. `_ + `Damgård et al. `_ unless + :py:obj:`program.active` is set to false, in which case + it uses random values to mask the clients' input. 
:param n: number of inputs (int) :param client_id: regint :param size: vector size (default 1) :returns: list of sint """ - # send shares of a triple to client - triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n)))) + if program.active: + # send shares of a triple to client + triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n)))) + else: + triples = [sint.get_random() for i in range(n)] + sint.write_shares_to_socket(client_id, triples, message_type) received = util.tuplify(cint.read_from_socket(client_id, n)) y = [0] * n for i in range(n): - y[i] = received[i] - triples[i * 3] + y[i] = received[i] - triples[i * 3 if program.active else i] return y @classmethod def reveal_to_clients(cls, clients, values): """ Reveal securely to clients. + Uses :py:obj:`program.active` to determine whether to use + triples for active security. :param clients: client ids (list or array) :param values: list of sint to reveal @@ -2445,8 +2454,11 @@ def reveal_to_clients(cls, clients, values): for value in values: assert(value.size == values[0].size) - r = sint.get_random() - to_send += [value, r, value * r] + if program.active: + r = sint.get_random() + to_send += [value, r, value * r] + else: + to_send += [value] if isinstance(clients, Array): n_clients = clients.length @@ -2844,7 +2856,7 @@ def reveal_to(self, player): privateoutput(self.size, player, res._v, self) return res - def private_division(self, divisor, active=True, dividend_length=None, + def private_division(self, divisor, active=None, dividend_length=None, divisor_length=None): """ Private integer division as per `Veugen and Abspoel `_ @@ -2878,6 +2890,9 @@ def private_division(self, divisor, active=True, dividend_length=None, z_shared = ((self << (l + sigma)) + h + r_pprime) z = z_shared.reveal_to(0) + if active is None: + active = program.active + if active: z_prime = [sint(x) for x in (z // d).bit_decompose(min_length)] check = [(x * (1 - x)).reveal() == 0 for x in 
z_prime] @@ -2893,6 +2908,7 @@ def private_division(self, divisor, active=True, dividend_length=None, y_prime = sint.bit_compose(z_prime[:l + sigma]) y = sint.bit_compose(z_prime[l + sigma:]) else: + program.semi_honest() y = sint(z // (d << (l + sigma))) y_prime = sint((z // d) % (2 ** (l + sigma))) @@ -3147,7 +3163,9 @@ def bit_decompose(self, bit_length=None, step=1): for i in range(0, bit_length, step)] one = cgf2n(1) - masked = sum([b * (one << (i * step)) for i,b in enumerate(random_bits)], self).reveal() + masked = sum([b * (one << (i * step)) + for i,b in enumerate(random_bits)], self).reveal( + check=False) masked_bits = masked.bit_decompose(bit_length,step=step) return [m + r for m,r in zip(masked_bits, random_bits)] @@ -3157,7 +3175,9 @@ def bit_decompose_embedding(self): for i in range(8)] one = cgf2n(1) wanted_positions = [0, 5, 10, 15, 20, 25, 30, 35] - masked = sum([b * (one << wanted_positions[i]) for i,b in enumerate(random_bits)], self).reveal() + masked = sum([b * (one << wanted_positions[i]) + for i,b in enumerate(random_bits)], self).reveal( + check=False) return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)] for t in (sint, sgf2n): @@ -4080,7 +4100,8 @@ class _single(_number, _secret_structure): @vectorized_classmethod def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): """ - Securely obtain shares of values input by a client. Assumes client + Securely obtain shares of values input by a client via + :py:func:`sint.receive_from_client`. Assumes client has already converted values to integer representation. :param n: number of inputs (int) @@ -4095,7 +4116,7 @@ def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType @classmethod def reveal_to_clients(cls, clients, values): - """ Reveal securely to clients. + """ Reveal securely to clients via :py:func:`sint.reveal_to_clients`. 
:param clients: client ids (list or array) :param values: list of values of this class @@ -4556,7 +4577,7 @@ class sfix(_fix): :py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``), returning :py:class:`sbitint`. The other operand can be any of sfix/sint/cfix/regint/cint/int/float. It also supports ``abs()`` - and ``**``, the latter for integer exponents. + and ``**``. Note that the default precision (16 bits after the dot, 31 bits in total) only allows numbers up to :math:`2^{31-16-1} \\approx @@ -4669,6 +4690,8 @@ def hard_conv_me(self, cls): return self.v def mul_no_reduce(self, other, res_params=None): + if not isinstance(other, type(self)): + return self * other assert self.f == other.f assert self.k == other.k return self.unreduced(self.v * other.v) @@ -4734,6 +4757,11 @@ def reduce_after_mul(self): nearest=sfix.round_nearest, signed=True) return sfix._new(v, k=self.k - self.m, f=self.m) + def update(self, other): + assert self.k == other.k + assert self.m == other.m + self.v.update(other.v) + sfix.unreduced_type = unreduced_sfix sfix.set_precision(16, 31) @@ -4953,6 +4981,8 @@ class sfloat(_number, _secret_structure): This uses integer operations internally, see :py:class:`sint` for security considerations. + See `Aliasgari et al. `_ for + details. The type supports basic arithmetic (``+, -, *, /``), returning :py:class:`sfloat`, and comparisons (``==, !=, <, <=, >, >=``), @@ -5459,6 +5489,9 @@ class Array(_vectorizable): b.input_from(1) a[:] += b[:] + Arrays aren't initialized on creation, you need to call + :py:func:`assign_all` to initialize them to a constant value. 
+ """ check_indices = True @@ -5708,7 +5741,7 @@ def assign_all(self, value, use_threads=True, conv=True): mem_value = MemValue(value) self.address = MemValue.if_necessary(self.address) n_threads = 8 if use_threads and util.is_constant(self.length) and \ - len(self) > 2**20 else None + len(self) > 2**20 and not program.options.garbled else None @library.multithread(n_threads, self.length) def _(base, size): if use_vector: @@ -5896,7 +5929,7 @@ def secure_shuffle(self): self.assign_vector(self.get_vector().secure_shuffle()) def secure_permute(self, *args, **kwargs): - """ Secure permutate in place according to the security model. + """ Secure permute in place according to the security model. See :py:func:`MultiArray.secure_shuffle` for references. :param permutation: output of :py:func:`sint.get_secure_shuffle()` @@ -6227,7 +6260,10 @@ def assign_vector_by_indices(self, vector, *indices): def same_shape(self): """ :return: new multidimensional array with same shape and basic type """ - return MultiArray(self.sizes, self.value_type) + if len(self.sizes) == 2: + return Matrix(*self.sizes, self.value_type) + else: + return MultiArray(self.sizes, self.value_type) def get_part(self, start, size): """ Part multi-array. 
@@ -6400,7 +6436,7 @@ class t(self.value_type): pass t.params = res_params else: - if issubclass(self.value_type, _secret_structure): + if self.value_type == other.value_type: t = self.value_type else: t = type(self.value_type(0) * other.value_type(0)) @@ -6435,10 +6471,12 @@ def _(i): # fallback for binary circuits @library.for_range_opt(other.sizes[1]) def _(j): - res_matrix[i][j] = 0 - @library.for_range_opt(self.sizes[1]) + tmp = self[i][0].mul_no_reduce(other[0][j]) + @library.for_range_opt(1, self.sizes[1]) def _(k): - res_matrix[i][j] += self[i][k] * other[k][j] + prod = self[i][k].mul_no_reduce(other[k][j]) + tmp.iadd(prod) + res_matrix[i][j] = tmp.reduce_after_mul() return res_matrix elif isinstance(other, self.value_type): return self * Array.create_from(other) @@ -6780,6 +6818,9 @@ class MultiArray(SubMultiArray): a[1].input_from(1) a[2][:] = a[0][:] * a[1][:] + Arrays aren't initialized on creation, you need to call + :py:func:`assign_all` to initialize them to a constant value. + """ @staticmethod def disable_index_checks(): @@ -6817,6 +6858,9 @@ class Matrix(MultiArray): :param columns: compile-time (int) :param value_type: basic type of entries + Matrices aren't initialized on creation, you need to call + :py:func:`assign_all` to initialize them to a constant value. 
+ """ def __init__(self, rows, columns, value_type, debug=None, address=None): MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ diff --git a/Dockerfile b/Dockerfile index 79dc702bb..3dfd92023 100644 --- a/Dockerfile +++ b/Dockerfile @@ -47,23 +47,16 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ libboost-dev \ libboost-thread-dev \ libclang-dev \ + libgmp-dev \ libntl-dev \ libsodium-dev \ libssl-dev \ libtool \ - m4 \ - texinfo \ - yasm \ vim \ gdb \ valgrind \ && rm -rf /var/lib/apt/lists/* -# mpir -COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/include/* /usr/local/include/ -COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/lib/* /usr/local/lib/ -COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/share/info/* /usr/local/share/info/ - ENV MP_SPDZ_HOME /usr/src/MP-SPDZ WORKDIR $MP_SPDZ_HOME diff --git a/ExternalIO/Client.hpp b/ExternalIO/Client.hpp index ffc9705cc..c401d86b1 100644 --- a/ExternalIO/Client.hpp +++ b/ExternalIO/Client.hpp @@ -46,6 +46,7 @@ void Client::send_private_inputs(const vector& values) octetStream os; vector< vector > triples(num_inputs, vector(3)); vector triple_shares(3); + bool active = true; // Receive num_inputs triples from SPDZ for (size_t j = 0; j < sockets.size(); j++) @@ -61,9 +62,21 @@ void Client::send_private_inputs(const vector& values) cerr << "received " << os.get_length() << " from " << j << endl << flush; #endif + if (j == 0) + { + if (os.get_length() == 3 * values.size() * T::size()) + active = true; + else + active = false; + } + + int n_expected = active ? 
3 : 1; + if (os.get_length() != n_expected * T::size() * values.size()) + throw runtime_error("unexpected data length in sending"); + for (int j = 0; j < num_inputs; j++) { - for (int k = 0; k < 3; k++) + for (int k = 0; k < n_expected; k++) { triple_shares[k].unpack(os); triples[j][k] += triple_shares[k]; @@ -71,16 +84,18 @@ void Client::send_private_inputs(const vector& values) } } - // Check triple relations (is a party cheating?) - for (int i = 0; i < num_inputs; i++) - { - if (T(triples[i][0] * triples[i][1]) != triples[i][2]) + if (active) + // Check triple relations (is a party cheating?) + for (int i = 0; i < num_inputs; i++) { - cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl; - cerr << "Incorrect triple at " << i << ", aborting\n"; - throw mac_fail(); + if (T(triples[i][0] * triples[i][1]) != triples[i][2]) + { + cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl; + cerr << "Incorrect triple at " << i << ", aborting\n"; + throw mac_fail(); + } } - } + // Send inputs + triple[0], so SPDZ can compute shares of each value os.reset_write_head(); for (int i = 0; i < num_inputs; i++) @@ -100,6 +115,7 @@ vector Client::receive_outputs(int n) { vector triples(3 * n); octetStream os; + bool active = true; for (auto& socket : sockets) { os.reset_write_head(); @@ -107,7 +123,20 @@ vector Client::receive_outputs(int n) #ifdef VERBOSE_COMM cout << "received " << os.get_length() << endl << flush; #endif - for (int j = 0; j < 3 * n; j++) + + if (socket == sockets[0]) + { + if (os.get_length() == (size_t) 3 * n * T::size()) + active = true; + else + active = false; + } + + int n_expected = n * (active ? 
3 : 1); + if (os.get_length() != (size_t) n_expected * T::size()) + throw runtime_error("unexpected data length in receiving"); + + for (int j = 0; j < n_expected; j++) { T value; value.unpack(os); @@ -115,16 +144,24 @@ vector Client::receive_outputs(int n) } } - vector output_values; - for (int i = 0; i < 3 * n; i += 3) + if (active) { - if (T(triples[i] * triples[i + 1]) != triples[i + 2]) + vector output_values; + for (int i = 0; i < 3 * n; i += 3) { - cerr << "Unable to authenticate output value as correct, aborting." << endl; - throw mac_fail(); + if (T(triples[i] * triples[i + 1]) != triples[i + 2]) + { + cerr << "Unable to authenticate output value as correct, aborting." << endl; + throw mac_fail(); + } + output_values.push_back(triples[i]); } - output_values.push_back(triples[i]); - } - return output_values; + return output_values; + } + else + { + triples.resize(n); + return triples; + } } diff --git a/ExternalIO/client.py b/ExternalIO/client.py index a6fd0b035..17f2e3eaf 100644 --- a/ExternalIO/client.py +++ b/ExternalIO/client.py @@ -34,17 +34,25 @@ def receive_triples(self, T, n): os = octetStream() for socket in self.sockets: os.Receive(socket) + if socket == self.sockets[0]: + active = os.get_length() == 3 * n * T.size() + n_expected = 3 if active else 1 + if os.get_length() != n_expected * T.size() * n: + import sys + print (os.get_length(), n_expected, T.size(), n, active, file=sys.stderr) + raise Exception('unexpected data length') for triple in triples: - for i in range(3): + for i in range(n_expected): t = T() t.unpack(os) triple[i] += t res = [] - for triple in triples: - prod = triple[0] * triple[1] - if prod != triple[2]: - raise Exception( - 'invalid triple, diff %s' % hex(prod.v - triple[2].v)) + if active: + for triple in triples: + prod = triple[0] * triple[1] + if prod != triple[2]: + raise Exception( + 'invalid triple, diff %s' % hex(prod.v - triple[2].v)) return triples def send_private_inputs(self, values): @@ -68,6 +76,9 @@ def 
__init__(self, value=None): if value is not None: self.buf += value + def get_length(self): + return len(self.buf) + def reset_write_head(self): self.buf = b'' self.ptr = 0 diff --git a/ExternalIO/domains.py b/ExternalIO/domains.py index bd98da326..2362623e8 100644 --- a/ExternalIO/domains.py +++ b/ExternalIO/domains.py @@ -27,6 +27,10 @@ def __eq__(self, other): def __neq__(self, other): return self.v != other.v + @classmethod + def size(cls): + return cls.n_bytes + def unpack(self, os): self.v = 0 buf = os.consume(self.n_bytes) diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index 62cbd5281..0a638f852 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -1,5 +1,4 @@ #include "Ciphertext.h" -#include "PPData.h" #include "P2Data.h" #include "Tools/Exceptions.h" @@ -143,6 +142,5 @@ void Ciphertext::rerandomize(const FHE_PK& pk) template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); -template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); template void mul(Ciphertext& ans, const Plaintext& a, const Ciphertext& c); diff --git a/FHE/FFT.cpp b/FHE/FFT.cpp index e8dcc228a..8daf641f2 100644 --- a/FHE/FFT.cpp +++ b/FHE/FFT.cpp @@ -259,6 +259,3 @@ void BFFT(vector& ans,const vector& a,const FFT_Data& FFTD,bool forw else { throw crash_requested(); } } - - - diff --git a/FHE/FFT_Data.cpp b/FHE/FFT_Data.cpp index d3b67b506..da7c5d7c9 100644 --- a/FHE/FFT_Data.cpp +++ b/FHE/FFT_Data.cpp @@ -83,6 +83,8 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD) for (int r=0; r<2; r++) { FFT_Iter(b[r],twop,two_root[0],PrD); } } + else + throw bad_value(); } } diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index 742c85452..08870dc75 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -2,7 +2,6 @@ #include "FHE_Keys.h" #include "Ciphertext.h" #include "P2Data.h" -#include "PPData.h" #include "FFT_Data.h" #include "Math/modp.hpp" @@ -406,29 +405,17 @@ bigint FHE_SK::get_noise(const Ciphertext& c) } - -template void 
FHE_PK::encrypt(Ciphertext&, const Plaintext_& mess, - const Random_Coins& rc) const; -template void FHE_PK::encrypt(Ciphertext&, const Plaintext_& mess, - const Random_Coins& rc) const; - -template Ciphertext FHE_PK::encrypt(const Plaintext_& mess, - const Random_Coins& rc) const; -template Ciphertext FHE_PK::encrypt(const Plaintext_& mess) const; -template Ciphertext FHE_PK::encrypt(const Plaintext_& mess) const; - -template void FHE_SK::decrypt(Plaintext_&, const Ciphertext& c) const; -template void FHE_SK::decrypt(Plaintext_&, const Ciphertext& c) const; - -template Plaintext_ FHE_SK::decrypt(const Ciphertext& c, - const FFT_Data& FieldD); -template Plaintext_ FHE_SK::decrypt(const Ciphertext& c, - const P2Data& FieldD); - -template void FHE_SK::decrypt_any(Plaintext_& res, - const Ciphertext& c); -template void FHE_SK::decrypt_any(Plaintext_& res, - const Ciphertext& c); - -template void FHE_SK::check(const FHE_PK& pk, const FFT_Data&); -template void FHE_SK::check(const FHE_PK& pk, const P2Data&); +#define X(FD) \ + template void FHE_PK::encrypt(Ciphertext&, const Plaintext_& mess, \ + const Random_Coins& rc) const; \ + template Ciphertext FHE_PK::encrypt(const Plaintext_& mess) const; \ + template Plaintext_ FHE_SK::decrypt(const Ciphertext& c, \ + const FD& FieldD); \ + template void FHE_SK::decrypt(Plaintext_& res, \ + const Ciphertext& c) const; \ + template void FHE_SK::decrypt_any(Plaintext_& res, \ + const Ciphertext& c); \ + template void FHE_SK::check(const FHE_PK& pk, const FD&); + +X(FFT_Data) +X(P2Data) diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 5a0f3991c..4fe98e58b 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -119,12 +119,6 @@ const P2Data& FHE_Params::get_plaintext_field_data() const throw not_implemented(); } -template<> -const PPData& FHE_Params::get_plaintext_field_data() const -{ - throw not_implemented(); -} - bigint FHE_Params::get_plaintext_modulus() const { return fd.get_prime(); diff --git a/FHE/Matrix.cpp 
b/FHE/Matrix.cpp index dcec137e4..8815fb356 100644 --- a/FHE/Matrix.cpp +++ b/FHE/Matrix.cpp @@ -248,49 +248,6 @@ matrix inv(const matrix& A) } -vector solve(modp_matrix& A,const Zp_Data& PrD) -{ - unsigned int n=A.size(); - if ((n+1)!=A[0].size()) { throw invalid_params(); } - - modp t,ti; - for (unsigned int r=0; r ans; - ans.resize(n); - for (unsigned int i=0; i > matrix; -typedef vector< vector > modp_matrix; class imatrix : public vector< BitVector > { @@ -39,13 +38,6 @@ void print(const imatrix& S); // requires column operations to create the inverse matrix inv(const matrix& A); -// Another special routine for modp matrices. -// Solves -// Ax=b -// Assumes A is unimodular, square and only requires row operations to -// create the inverse. In put is C=(A||b) and the routines alters A -vector solve(modp_matrix& C,const Zp_Data& PrD); - // Finds a pseudo-inverse of a matrix A modulo 2 // - Input matrix is assumed to have more rows than columns void pinv(imatrix& Ai,const imatrix& A); diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index f3973026e..85b630ee6 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -742,135 +742,3 @@ void load_or_generate(P2Data& P2D, const Ring& R) P2D.store(R); } } - - -#ifdef USE_NTL -/* - * Create FHE parameters for a general plaintext modulus p - * Basically this is for general large primes only - */ -void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0, - bigint& pr1, int n, int sec, bigint& p, FHE_Params& params) -{ - cout << "Setting up parameters" << endl; - - int lgp=numBits(p); - int mm,idx; - - // mm is the minimum value of m we will accept - if (lgp<48) - { mm=100; // Test case - idx=0; - } - else if (lgp <96) - { mm=8192; - idx=1; - } - else if (lgp<192) - { mm=16384; - idx=2; - } - else if (lgp<384) - { mm=16384; - idx=3; - } - else if (lgp<768) - { mm=32768; - idx=4; - } - else - { throw invalid_params(); } - - // Now find the small factors of p-1 and their exponents - bigint t=p-1; - vector 
primes(100),exp(100); - - PrimeSeq s; - long pr; - pr=s.next(); - int len=0; - while (pr<2*mm) - { int e=0; - while ((t%pr)==0) - { e++; - t=t/pr; - } - if (e!=0) - { primes[len]=pr; - exp[len]=e; - if (len!=0) { cout << " * "; } - cout << pr << "^" << e << flush; - len++; - } - pr=s.next(); - } - cout << endl; - - // We want to find the best m which divides pr-1, such that - // - 2*m > phi(m) > mm - // - m has the smallest number of factors - vector ee; - ee.resize(len); - for (int i=0; imx) { mx=ee[i]; } - } - } - // Put "if" here to stop searching for things which will never work - if (cand_m>1 && cand_m<4*mm) - { //cout << " : " << cand_m << " : " << hwt << flush; - int phim=phi_N(cand_m); - //cout << " : " << phim << " : " << mm << endl; - if (phim>mm && phim<3*mm) - { if (m==-1 || hwtexp[i] && flag) - { ee[i]=0; - i++; - if (i==len) { flag=false; i=0; } - else { ee[i]=ee[i]+1; } - } - } - if (m==-1) - { throw bad_value(); } - cout << "Chosen value of m=" << m << "\t\t phi(m)=" << bphi_m << " : " << min_hwt << " : " << bmx << endl; - - Parameters parameters(n, lgp, sec); - parameters.SPDZ_Data_Setup_Char_p_Sub(idx,m,p,params); - int mx=0; - for (int i=0; i& elem) const -{ - if (elem.size()!= (unsigned) R.phi_m()) - { throw params_mismatch(); } - - throw not_implemented(); - -/* - vector ans; - ans.resize(R.phi_m()); - modp x=root; - for (int i=0; i&) const -{ - // avoid warning - throw not_implemented(); - - /* - modp_matrix A; - int n=phi_m(); - A.resize(n, vector(n+1) ); - modp x=root; - for (int i=0; i& mess) const -{ - // Uses Horner's rule - gfp ans; - ans = mess[mess.size()-1]; - gfp coeff; - for (int j=mess.size()-2; j>=0; j--) - { ans *= (thetaPow); - coeff = mess[j]; - ans += (coeff); - } - return ans; -} - - diff --git a/FHE/PPData.h b/FHE/PPData.h deleted file mode 100644 index 46c8c8e6e..000000000 --- a/FHE/PPData.h +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef _PPData -#define _PPData - -#include "Math/modp.h" -#include "Math/Zp_Data.h" -#include 
"Math/gfpvar.h" -#include "Math/fixint.h" -#include "FHE/Ring.h" -#include "FHE/FFT_Data.h" - -/* Class for holding modular arithmetic data wrt the ring - * - * It also holds the ring - */ - -class PPData -{ - public: - typedef gfp T; - typedef bigint S; - typedef typename FFT_Data::poly_type poly_type; - - Ring R; - Zp_Data prData; - - modp root; // m'th Root of Unity mod pr - - void init(const Ring& Rg,const Zp_Data& PrD); - - PPData() { ; } - PPData(const Ring& Rg,const Zp_Data& PrD) - { init(Rg,PrD); } - - const Zp_Data& get_prD() const { return prData; } - const bigint& get_prime() const { return prData.pr; } - int phi_m() const { return R.phi_m(); } - int m() const { return R.m(); } - int num_slots() const { return R.phi_m(); } - - - int p(int i) const { return R.p(i); } - int p_inv(int i) const { return R.p_inv(i); } - const vector& Phi() const { return R.Phi(); } - - // Convert input vector from poly to evaluation representation - // - Uses naive method and not FFT, we only use this rarely in any case - void to_eval(vector& elem) const; - void from_eval(vector& elem) const; - - // Following are used to iteratively get slots, as we use PPData when - // we do not have an efficient FFT algorithm - gfp thetaPow,theta; - int pow; - void reset_iteration(); - void next_iteration(); - gfp get_evaluation(const vector& mess) const; - -}; - -#endif - diff --git a/FHE/Plaintext.cpp b/FHE/Plaintext.cpp index 4eba6e8f0..8adb3d34a 100644 --- a/FHE/Plaintext.cpp +++ b/FHE/Plaintext.cpp @@ -1,7 +1,6 @@ #include "FHE/Plaintext.h" #include "FHE/Ring_Element.h" -#include "FHE/PPData.h" #include "FHE/P2Data.h" #include "FHE/Rq_Element.h" #include "FHE_Keys.h" @@ -85,39 +84,6 @@ void Plaintext::to_poly() const } -template<> -void Plaintext::from_poly() const -{ - if (type!=Polynomial) { return; } - vector aa((*Field_Data).phi_m()); - for (unsigned int i=0; iget_prD()}; - type=Both; -} - - -template<> -void Plaintext::to_poly() const -{ - if (type!=Evaluation) { return; } - cout 
<< "This is VERY inefficient to convert a plaintext to poly representation" << endl; - vector bb((*Field_Data).phi_m()); - for (unsigned int i=0; i void Plaintext::from_poly() const @@ -385,34 +351,6 @@ void add(Plaintext& z,const Plaintext& } -template<> -void add(Plaintext& z,const Plaintext& x, - const Plaintext& y) -{ - if (z.Field_Data!=x.Field_Data) { throw field_mismatch(); } - if (z.Field_Data!=y.Field_Data) { throw field_mismatch(); } - - if (x.type==Both && y.type!=Both) { z.type=y.type; } - else if (y.type==Both && x.type!=Both) { z.type=x.type; } - else if (x.type!=y.type) { throw rep_mismatch(); } - else { z.type=x.type; } - - if (z.type!=Polynomial) - { - z.a.resize(z.num_slots()); - for (unsigned int i=0; i(*z.Field_Data).get_prime()) - { z.b[i]-=(*z.Field_Data).get_prime(); } - } - } -} - - template<> @@ -475,36 +413,6 @@ void sub(Plaintext& z,const Plaintext& -template<> -void sub(Plaintext& z,const Plaintext& x, - const Plaintext& y) -{ - if (z.Field_Data!=x.Field_Data) { throw field_mismatch(); } - if (z.Field_Data!=y.Field_Data) { throw field_mismatch(); } - - if (x.type==Both && y.type!=Both) { z.type=y.type; } - else if (y.type==Both && x.type!=Both) { z.type=x.type; } - else if (x.type!=y.type) { throw rep_mismatch(); } - else { z.type=x.type; } - - z.allocate(); - if (z.type!=Polynomial) - { - z.a.resize(z.num_slots()); - for (unsigned int i=0; i @@ -572,23 +480,6 @@ void Plaintext::negate() } } -template<> -void Plaintext::negate() -{ - if (type!=Polynomial) - { - a.resize(num_slots()); - for (unsigned int i=0; i @@ -731,12 +622,6 @@ template void mul(Plaintext& z,const Plaintext; - -template void mul(Plaintext& z,const Plaintext& x,const Plaintext& y); - - - template class Plaintext; template void mul(Plaintext& z,const Plaintext& x,const Plaintext& y); diff --git a/FHEOffline/EncCommit.cpp b/FHEOffline/EncCommit.cpp index ddd77d800..0a0e100ca 100644 --- a/FHEOffline/EncCommit.cpp +++ b/FHEOffline/EncCommit.cpp @@ -274,7 +274,7 @@ void 
EncCommit::Create_More() const (*P).Broadcast_Receive(ctx_Delta); // Output the ctx_Delta to a file - sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread); + snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread); ofstream outf(filename); for (int j=0; j<(*P).num_players(); j++) { @@ -308,7 +308,7 @@ void EncCommit::Create_More() const octetStream occ,ctx_D; for (int i=0; i<2*TT; i++) { if (open[i]==1) - { sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread); + { snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread); ifstream inpf(filename); for (int j=0; j<(*P).num_players(); j++) { @@ -386,7 +386,7 @@ void EncCommit::Create_More() const Ciphertext enc1(params),enc2(params),eDelta(params); octetStream oe1,oe2; - sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,index[b*i+j],thread); + snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,index[b*i+j],thread); ifstream inpf(filename); for (int k=0; k<(*P).num_players(); k++) { diff --git a/GC/DealerPrep.h b/GC/DealerPrep.h index a3bd4bcc8..c9c8b21c1 100644 --- a/GC/DealerPrep.h +++ b/GC/DealerPrep.h @@ -26,6 +26,7 @@ class DealerPrep : public BufferPrep, ShiftableTripleBuffer::P = P; } void buffer_triples() diff --git a/GC/NoShare.h b/GC/NoShare.h index ec2c85ac0..c3d795e7b 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -183,6 +183,8 @@ class NoShare : public ShareInterface NoShare operator-(const NoShare&) const { fail(); return {}; } NoShare operator*(const NoValue&) const { fail(); return {}; } + NoShare operator^(const NoShare&) const { fail(); return {}; } + NoShare operator&(int) const { fail(); return {}; } NoShare operator>>(int) const { fail(); return {}; } diff --git a/GC/Program.hpp b/GC/Program.hpp index 768a09c58..d493e4e1e 100644 --- a/GC/Program.hpp +++ b/GC/Program.hpp @@ -123,7 +123,9 @@ BreakType Program::execute(Processor& Proc, U& dynamic_memory, } time++; #ifdef DEBUG_COMPLEXITY - cout << "complexity 
at " << time << ": " << Proc.complexity << endl; + cout << T::part_type::name() << " complexity at " << time << ": " << + Proc.complexity << " after " << hex << + instruction.get_opcode() << dec << endl; #endif } while (Proc.complexity < (size_t) OnlineOptions::singleton.batch_size); diff --git a/GC/RepPrep.hpp b/GC/RepPrep.hpp index f83fbdaf4..8cba2d75c 100644 --- a/GC/RepPrep.hpp +++ b/GC/RepPrep.hpp @@ -39,6 +39,7 @@ void RepPrep::set_protocol(typename T::Protocol& protocol) return; this->protocol = new ReplicatedBase(protocol.P); + this->P = &protocol.P; } template diff --git a/GC/Secret.hpp b/GC/Secret.hpp index 15fbd5977..68d794cf5 100644 --- a/GC/Secret.hpp +++ b/GC/Secret.hpp @@ -89,7 +89,7 @@ void Secret::random(int n_bits, int128 share) { (void)share; if (n_bits > 128) - throw not_implemented(); + throw runtime_error("too many bits"); resize_regs(n_bits); for (int i = 0; i < n_bits; i++) get_reg(i).random(); diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index 3adc385d5..2cf710f3d 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -37,6 +37,7 @@ void SemiPrep::set_protocol(SemiSecret::Protocol& protocol) protocol.P.N, -1, OnlineOptions::singleton.batch_size, 1, params, {}, &protocol.P); triple_generator->multi_threaded = false; + this->P = &protocol.P; } void SemiPrep::buffer_triples() diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index ff0763833..abcec91ec 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -103,9 +103,7 @@ void ThreadMaster::run() machine.print_timers(); - cerr << "Data sent = " << stats.sent * 1e-6 << " MB" << endl; - - machine.print_global_comm(*P, stats); + machine.print_comm(*P, stats); delete P; } diff --git a/GC/TinierShare.h b/GC/TinierShare.h index 98d114989..17aa32d38 100644 --- a/GC/TinierShare.h +++ b/GC/TinierShare.h @@ -105,6 +105,11 @@ class TinierShare: public Share_, SemiShare>, *this = a + b; } + This operator^(const This& other) const + { + return *this + other; + } + This& operator^=(const This& 
other) { *this += other; diff --git a/GC/instructions.h b/GC/instructions.h index 272011947..67ea461a1 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -146,6 +146,7 @@ X(THRESHOLD, I0 = T::threshold(Thread::s().P->num_players())) \ X(PLAYERID, I0 = Thread::s().P->my_num()) \ X(CRASH, if (I0.get()) throw crash_requested()) \ + X(ACTIVE, ) \ #define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS diff --git a/Machines/OTMachine.cpp b/Machines/OTMachine.cpp index 7a54167dc..961dfbc5f 100644 --- a/Machines/OTMachine.cpp +++ b/Machines/OTMachine.cpp @@ -365,11 +365,11 @@ void OTMachine::run() { BitVector receiver_output, sender_output; char filename[1024]; - sprintf(filename, RECEIVER_INPUT, my_num); + snprintf(filename, 1024, RECEIVER_INPUT, my_num); ofstream outf(filename); receiverInput.output(outf, false); outf.close(); - sprintf(filename, RECEIVER_OUTPUT, my_num); + snprintf(filename, 1024, RECEIVER_OUTPUT, my_num); outf.open(filename); for (unsigned int i = 0; i < nOTs; i++) { @@ -380,7 +380,7 @@ void OTMachine::run() for (int i = 0; i < 2; i++) { - sprintf(filename, SENDER_OUTPUT, my_num, i); + snprintf(filename,1024, SENDER_OUTPUT, my_num, i); outf.open(filename); for (int j = 0; j < nOTs; j++) { diff --git a/Makefile b/Makefile index 7a03b864e..fc8b8fb97 100644 --- a/Makefile +++ b/Makefile @@ -116,7 +116,7 @@ mascot: mascot-party.x spdz2k mama-party.x ifeq ($(OS), Darwin) setup: mac-setup else -setup: boost mpir linux-machine-setup +setup: boost linux-machine-setup endif tldr: setup @@ -297,27 +297,6 @@ deps/SimplestOT_C/ref10/Makefile: Programs/Circuits: git submodule update --init Programs/Circuits -.PHONY: mpir-setup mpir-global -mpir-setup: deps/mpir/Makefile -deps/mpir/Makefile: - git submodule update --init deps/mpir || git clone https://github.com/wbhart/mpir deps/mpir - cd deps/mpir; \ - autoreconf -i; \ - autoreconf -i - - $(MAKE) -C deps/mpir clean - -mpir-global: mpir-setup - cd deps/mpir; \ - ./configure --enable-cxx; - $(MAKE) -C deps/mpir 
- sudo $(MAKE) -C deps/mpir install - -mpir: local/lib/libmpirxx.so -local/lib/libmpirxx.so: deps/mpir/Makefile - cd deps/mpir; \ - ./configure --enable-cxx --prefix=$(CURDIR)/local - $(MAKE) -C deps/mpir install - deps/libOTe/libOTe: git submodule update --init --recursive deps/libOTe || git clone --recurse-submodules https://github.com/mkskeller/softspoken-implementation deps/libOTe boost: deps/libOTe/libOTe @@ -369,26 +348,16 @@ cmake: ./bootstrap --parallel=8 --prefix=../local && make && make install mac-setup: mac-machine-setup - brew install openssl boost libsodium mpir yasm ntl cmake - -echo MY_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include >> CONFIG.mine - -echo MY_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib >> CONFIG.mine -# -echo USE_NTL = 1 >> CONFIG.mine + brew install openssl boost libsodium gmp yasm ntl cmake -ifeq ($(ARM), 1) -mac-machine-setup: - -echo ARCH = >> CONFIG.mine linux-machine-setup: - -echo ARCH = -march=armv8.2-a+crypto >> CONFIG.mine -else mac-machine-setup: -linux-machine-setup: -endif deps/simde/simde: git submodule update --init deps/simde || git clone https://github.com/simd-everywhere/simde deps/simde clean-deps: - -rm -rf local deps/libOTe/out + -rm -rf local/lib/liblibOTe.* deps/libOTe/out clean: clean-deps -rm -f */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x *.so diff --git a/Math/Integer.h b/Math/Integer.h index 1fbb257fc..de594e637 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -17,6 +17,10 @@ using namespace std; #include "ValueInterface.h" #include "gf2nlong.h" +// Fix false warning +#if __GNUC__ == 10 +#pragma GCC diagnostic ignored "-Wstringop-overflow" +#endif // Functionality shared between integers and bit vectors template @@ -39,6 +43,8 @@ class IntBase : public ValueInterface static bool allows(Dtype type) { return type <= DATA_BIT; } + static void check_setup(const string&) {} + 
IntBase() { a = 0; } IntBase(T a) : a(a) {} diff --git a/Math/Setup.cpp b/Math/Setup.cpp index 715d480d6..38cc6a388 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -160,13 +160,13 @@ void check_setup(string dir, bigint pr) } string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, - const string& type_short) + const string& type_short, bool create) { string res = prep_dir + "/" + to_string(nparties) + "-" + type_short; if (log2mod > 1) res += "-" + to_string(log2mod); res += "/"; - if (mkdir_p(res.c_str()) < 0) + if (create and mkdir_p(res.c_str()) < 0) throw file_error("cannot create " + res); return res; } diff --git a/Math/Setup.h b/Math/Setup.h index 8c599198e..27724b58f 100644 --- a/Math/Setup.h +++ b/Math/Setup.h @@ -38,26 +38,28 @@ bigint generate_prime(int lgp, int m); int default_m(int& lgp, int& idx); string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, - const string& type_short); + const string& type_short, bool create = false); template -string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod) +string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, + bool create = false) { if (T::clear::length() > 1) log2mod = T::clear::length(); - return get_prep_sub_dir(prep_dir, nparties, log2mod, T::type_short()); + return get_prep_sub_dir(prep_dir, nparties, log2mod, T::type_short(), create); } template -string get_prep_sub_dir(const string& prep_dir, int nparties) +string get_prep_sub_dir(const string& prep_dir, int nparties, bool create = + false) { - return get_prep_sub_dir(prep_dir, nparties, T::clear::length()); + return get_prep_sub_dir(prep_dir, nparties, T::clear::length(), create); } template -string get_prep_sub_dir(int nparties) +string get_prep_sub_dir(int nparties, bool create = false) { - return get_prep_sub_dir(PREP_DIR, nparties); + return get_prep_sub_dir(PREP_DIR, nparties, create); } template diff --git a/Math/ValueInterface.cpp b/Math/ValueInterface.cpp new file 
mode 100644 index 000000000..db7904bba --- /dev/null +++ b/Math/ValueInterface.cpp @@ -0,0 +1,18 @@ +/* + * ValueInterface.cpp + * + */ + +#include "bigint.h" +#include "ValueInterface.h" + +#include + +void ValueInterface::check_setup(const string& directory) +{ + struct stat sb; + if (stat(directory.c_str(), &sb) != 0) + throw runtime_error(directory + " does not exist"); + if (not (sb.st_mode & S_IFDIR)) + throw runtime_error(directory + " is not a directory"); +} diff --git a/Math/ValueInterface.h b/Math/ValueInterface.h index 07807cb23..d10820201 100644 --- a/Math/ValueInterface.h +++ b/Math/ValueInterface.h @@ -7,6 +7,7 @@ #define MATH_VALUEINTERFACE_H_ #include "Tools/Exceptions.h" +#include "Math/Setup.h" class OnlineOptions; class bigint; @@ -31,9 +32,10 @@ class ValueInterface template static void generate_setup(string, int, int) {} template - static void write_setup(int) {} + static void write_setup(int nplayers) { get_prep_sub_dir(nplayers, true); } static void write_setup(string) {} - static void check_setup(string) {} + static void check_setup(const string& directory); + static const char* fake_opts() { return ""; } static bigint pr() { throw runtime_error("no prime modulus"); } diff --git a/Math/Z2k.h b/Math/Z2k.h index e8d2ba532..b5ffb196e 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -6,7 +6,7 @@ #ifndef MATH_Z2K_H_ #define MATH_Z2K_H_ -#include +#include #include using namespace std; @@ -74,6 +74,8 @@ class Z2 : public ValueInterface static Z2 power_of_two(bool bit, int exp) { return Z2(bit) << exp; } + static string fake_opts() { return " -lgp " + to_string(K); } + typedef Z2 next; typedef Z2 Scalar; diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index f816b4f04..95ac1e8d4 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -53,7 +53,7 @@ void Zp_Data::init(const bigint& p,bool mont) mpn_copyi(R3,r3.get_mpz_t()->_mp_d,mpz_size(r3.get_mpz_t())); if (sizeof(unsigned long)!=sizeof(mp_limb_t)) - { cout << "The underlying types of MPIR mean we 
cannot use our Montgomery code" << endl; + { cout << "The underlying types of GMP mean we cannot use our Montgomery code" << endl; throw not_implemented(); } } @@ -194,3 +194,37 @@ bool Zp_Data::operator==(const Zp_Data& other) const { return not (*this != other); } + +void Zp_Data::get_shanks_parameters(bigint& y, bigint& q_half, int& r) const +{ + if (shanks_y == 0) + { + auto& p = pr; + bigint n, q, yy, xx, temp; + // Find n such that (n/p)=-1 + int leg = 1; + gmp_randclass Gen(gmp_randinit_default); + Gen.seed(0); + while (leg != -1) + { + n = Gen.get_z_range(p); + leg = mpz_legendre(n.get_mpz_t(), p.get_mpz_t()); + } + // Split p-1 = 2^e q + q = p - 1; + int e = 0; + while (mpz_even_p(q.get_mpz_t())) + { + e++; + q = q / 2; + } + // y=n^q mod p, x=a^((q-1)/2) mod p, r=e + shanks_r = e; + mpz_powm(shanks_y.get_mpz_t(), n.get_mpz_t(), q.get_mpz_t(), p.get_mpz_t()); + shanks_q_half = (q - 1) / 2; + } + + y = shanks_y; + q_half = shanks_q_half; + r = shanks_r; +} diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index 3d3ecc20d..5ff3f6351 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -38,6 +38,8 @@ class Zp_Data int t; // More Montgomery data mp_limb_t overhang; Lock lock; + mutable bigint shanks_y, shanks_q_half; + mutable int shanks_r; template void Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const; @@ -89,6 +91,8 @@ class Zp_Data bool operator!=(const Zp_Data& other) const; bool operator==(const Zp_Data& other) const; + void get_shanks_parameters(bigint& y, bigint& q_half, int& r) const; + template friend void to_modp(modp_& ans,int x,const Zp_Data& ZpD); template friend void to_modp(modp_& ans,const mpz_class& x,const Zp_Data& ZpD); diff --git a/Math/bigint.cpp b/Math/bigint.cpp index 1952859d9..aef081a3a 100644 --- a/Math/bigint.cpp +++ b/Math/bigint.cpp @@ -10,76 +10,10 @@ #include "bigint.hpp" -class gmp_random -{ -public: - gmp_randclass Gen; - gmp_random() : Gen(gmp_randinit_default) - { - Gen.seed(0); - } -}; - thread_local bigint 
bigint::tmp = 0; thread_local bigint bigint::tmp2 = 0; thread_local gmp_random bigint::random; -bigint sqrRootMod(const bigint& a,const bigint& p) -{ - bigint ans; - if (a==0) { ans=0; return ans; } - if (mpz_legendre(a.get_mpz_t(), p.get_mpz_t()) != 1) - throw runtime_error("cannot compute square root of non-square"); - if (mpz_tstbit(p.get_mpz_t(),1)==1) - { // First do case with p=3 mod 4 - bigint exp=(p+1)/4; - mpz_powm(ans.get_mpz_t(),a.get_mpz_t(),exp.get_mpz_t(),p.get_mpz_t()); - } - else - { // Shanks algorithm - bigint x,y,n,q,t,b,temp; - // Find n such that (n/p)=-1 - int leg=1; - while (leg!=-1) - { n=bigint::random.Gen.get_z_range(p); - leg=mpz_legendre(n.get_mpz_t(),p.get_mpz_t()); - } - // Split p-1 = 2^e q - q=p-1; - int e=0; - while (mpz_even_p(q.get_mpz_t())) - { e++; q=q/2; } - // y=n^q mod p, x=a^((q-1)/2) mod p, r=e - int r=e; - mpz_powm(y.get_mpz_t(),n.get_mpz_t(),q.get_mpz_t(),p.get_mpz_t()); - temp=(q-1)/2; - mpz_powm(x.get_mpz_t(),a.get_mpz_t(),temp.get_mpz_t(),p.get_mpz_t()); - // b=a*x^2 mod p, x=a*x mod p - b=(a*x*x)%p; - x=(a*x)%p; - // While b!=1 do - while (b!=1) - { // Find smallest m such that b^(2^m)=1 mod p - int m=1; - temp=(b*b)%p; - while (temp!=1) - { temp=(temp*temp)%p; m++; } - // t=y^(2^(r-m-1)) mod p, y=t^2, r=m - t=y; - for (int i=0; i -#include +#include #include "Tools/Exceptions.h" #include "Tools/int.h" @@ -39,7 +39,7 @@ namespace GC /** * Type for arbitrarily large integers. - * This is a sub-class of ``mpz_class`` from MPIR. As such, it implements + * This is a sub-class of ``mpz_class`` from GMP. As such, it implements * all integers operations and input/output via C++ streams. In addition, * the ``get_ui()`` member function allows retrieving the least significant * 64 bits. 
@@ -139,8 +139,6 @@ class bigint : public mpz_class void inline_mpn_zero(mp_limb_t* x, mp_size_t size); void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t size); -#include "Z2k.h" - inline bigint& bigint::operator=(int n) { @@ -281,11 +279,7 @@ inline int numBytes(const bigint& m) inline int probPrime(const bigint& x) { - gmp_randstate_t rand_state; - gmp_randinit_default(rand_state); - int ans = mpz_probable_prime_p(x.get_mpz_t(), rand_state, - max(40, DEFAULT_SECURITY), 0); - gmp_randclear(rand_state); + int ans = mpz_probab_prime_p(x.get_mpz_t(), max(40, DEFAULT_SECURITY) / 2); return ans; } @@ -318,7 +312,8 @@ inline int isOdd(const bigint& x) } -bigint sqrRootMod(const bigint& x,const bigint& p); +template +bigint sqrRootMod(const T& x); bigint powerMod(const bigint& x,const bigint& e,const bigint& p); diff --git a/Math/bigint.hpp b/Math/bigint.hpp index 9662d0bf8..afd995a6b 100644 --- a/Math/bigint.hpp +++ b/Math/bigint.hpp @@ -26,7 +26,7 @@ bigint& bigint::from_signed(const T& other) template mpf_class bigint::get_float(T v, T p, T z, T s) { - // MPIR can't handle more precision in exponent + // GMP can't handle more precision in exponent Integer exp = Integer(p, 31).get(); bigint tmp; tmp.from_signed(v); @@ -59,4 +59,76 @@ void bigint::output_float(U& o, const mpf_class& x, T nan) o << "NaN"; } + +class gmp_random +{ +public: + gmp_randclass Gen; + gmp_random() : Gen(gmp_randinit_default) + { + Gen.seed(0); + } +}; + +template +bigint sqrRootMod(const T& aa) +{ + bigint a = aa; + bigint p = T::pr(); + + bigint ans; + if (a == 0) + { + ans = 0; + return ans; + } + if (mpz_legendre(a.get_mpz_t(), p.get_mpz_t()) != 1) + throw runtime_error("cannot compute square root of non-square"); + if (mpz_tstbit(p.get_mpz_t(), 1) == 1) + { + // First do case with p=3 mod 4 + bigint exp = (p + 1) / 4; + mpz_powm(ans.get_mpz_t(), a.get_mpz_t(), exp.get_mpz_t(), + p.get_mpz_t()); + } + else + { + // Shanks algorithm + bigint n, q, yy, xx, temp; + int r; + 
T::get_ZpD().get_shanks_parameters(yy, temp, r); + mpz_powm(xx.get_mpz_t(), a.get_mpz_t(), temp.get_mpz_t(), p.get_mpz_t()); + // b=a*x^2 mod p, x=a*x mod p + T x = xx; + T b = (aa * x * x); + x = (aa * x); + T y = yy; + // While b!=1 do + while (b != 1) + { + // Find smallest m such that b^(2^m)=1 mod p + int m = 1; + T temp = (b * b); + while (temp != 1) + { + temp = (temp * temp); + m++; + } + // t=y^(2^(r-m-1)) mod p, y=t^2, r=m + T t = y; + for (int i = 0; i < r - m - 1; i++) + { + t = (t * t); + } + y = (t * t); + r = m; + // x=x*t mod p, b=b*y mod p + x = (x * t); + b = (b * y); + } + ans = x; + } + return ans; +} + #endif /* MATH_BIGINT_HPP_ */ diff --git a/Math/gf2n.h b/Math/gf2n.h index 5b7b06a41..a09d9aed4 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -17,6 +17,7 @@ class gf2n_short; class P2Data; class Bit; class int128; +template class IntBase; template class Square; typedef Square gf2n_short_square; @@ -88,6 +89,8 @@ class gf2n_ : public ValueInterface static string options(); + static string fake_opts() { return " -lg2 " + to_string(length()); } + static const true_type invertible; static const true_type characteristic_two; diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h index a15dbfc62..64035090f 100644 --- a/Math/gf2nlong.h +++ b/Math/gf2nlong.h @@ -154,6 +154,8 @@ class gf2n_long : public gf2n_ gf2n_long(int g) : gf2n_long(int128(unsigned(g))) {} template gf2n_long(IntBase g) : super(g.get()) {} + template + gf2n_long(const gf2n_& a) : super(int128(a.get())) {} }; #if defined(__aarch64__) && defined(__clang__) diff --git a/Math/gfp.h b/Math/gfp.h index de00934a0..9f5475e24 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -105,6 +105,7 @@ class gfp_ : public ValueInterface static void write_setup(string dir) { write_online_setup(dir, pr()); } static void check_setup(string dir); + static string fake_opts() { return " -lgp " + to_string(length()); } /** * Get the prime modulus @@ -314,6 +315,8 @@ gfp_::gfp_(long x) { if (x == 0) assign_zero(); + else if 
(x == 1) + assign_one(); else *this = bigint::tmp = x; } diff --git a/Math/gfp.hpp b/Math/gfp.hpp index 3387f7ae4..0e0f7b624 100644 --- a/Math/gfp.hpp +++ b/Math/gfp.hpp @@ -146,8 +146,7 @@ gfp_ gfp_::sqrRoot() { // Temp move to bigint so as to call sqrRootMod bigint ti; - to_bigint(ti, *this); - ti = sqrRootMod(ti, ZpD.pr); + ti = sqrRootMod(*this); if (!isOdd(ti)) ti = ZpD.pr - ti; gfp_ temp; diff --git a/Math/gfpvar.cpp b/Math/gfpvar.cpp index 368bca4b4..eb065cf4f 100644 --- a/Math/gfpvar.cpp +++ b/Math/gfpvar.cpp @@ -312,8 +312,8 @@ gfpvar_ gfpvar_::invert() const template gfpvar_ gfpvar_::sqrRoot() const { - bigint ti = *this; - ti = sqrRootMod(ti, ZpD.pr); + bigint ti; + ti = sqrRootMod(*this); if (!isOdd(ti)) ti = ZpD.pr - ti; return ti; diff --git a/Math/gfpvar.h b/Math/gfpvar.h index 7d332fdd8..55c08d4b4 100644 --- a/Math/gfpvar.h +++ b/Math/gfpvar.h @@ -81,6 +81,7 @@ class gfpvar_ { write_setup(get_prep_sub_dir(nplayers)); } + static string fake_opts() { return " -lgp " + to_string(length()); } gfpvar_(); gfpvar_(int other); diff --git a/Math/modp.h b/Math/modp.h index f84da1d68..3f19de003 100644 --- a/Math/modp.h +++ b/Math/modp.h @@ -2,7 +2,7 @@ #define _Modp /* - * Currently we only support an MPIR based implementation. + * Currently we only support an GMP based implementation. 
* * What ever is type-def'd to bigint is assumed to have * operator overloading for all standard operators, has diff --git a/Math/mpn_fixed.h b/Math/mpn_fixed.h index 87e94078f..b8f73ef4a 100644 --- a/Math/mpn_fixed.h +++ b/Math/mpn_fixed.h @@ -6,7 +6,7 @@ #ifndef MATH_MPN_FIXED_H_ #define MATH_MPN_FIXED_H_ -#include +#include #include #include diff --git a/Math/square128.cpp b/Math/square128.cpp index fadbe21f6..7aa2c75ca 100644 --- a/Math/square128.cpp +++ b/Math/square128.cpp @@ -3,7 +3,7 @@ * */ -#include +#include #include "OT/BitMatrix.h" #include "Tools/random.h" diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 048843bb1..fc9350e68 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -78,7 +78,7 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante } } if (nplayers_wanted > 0 and nplayers_wanted != nplayers) - throw runtime_error("not enought hosts in HOSTS"); + throw runtime_error("not enough hosts in " + filename); #ifdef DEBUG_NETWORKING cerr << "Got list of " << nplayers << " players from file: " << endl; for (unsigned int i = 0; i < names.size(); i++) @@ -324,7 +324,9 @@ void PlainPlayer::setup_sockets(const vector& names, template void MultiPlayer::send_long(int i, long a) const { + TimeScope ts(comm_stats["Sending by number"].add(sizeof(long))); send(sockets[i], (octet*)&a, sizeof(long)); + sent += sizeof(long); } template @@ -716,7 +718,7 @@ size_t VirtualTwoPartyPlayer::send(const PlayerBuffer& buffer, bool block) const { auto sent = P.send_no_stats(other_player, buffer, block); lock.lock(); - comm_stats["Sending one-to-one"].add(sent); + comm_stats.add_to_last_round("Sending one-to-one", sent); comm_stats.sent += sent; lock.unlock(); return sent; @@ -726,7 +728,7 @@ size_t VirtualTwoPartyPlayer::recv(const PlayerBuffer& buffer, bool block) const { auto received = P.recv_no_stats(other_player, buffer, block); lock.lock(); - comm_stats["Receiving one-to-one"].add(received); + 
comm_stats.add_to_last_round("Receiving one-to-one", received); lock.unlock(); return received; } @@ -805,6 +807,17 @@ void NamedCommStats::reset() sent = 0; } +Timer& NamedCommStats::add_to_last_round(const string& name, size_t length) +{ + if (name == last) + return (*this)[name].add_length_only(length); + else + { + last = name; + return (*this)[name].add(length); + } +} + void PlayerBase::reset_stats() { comm_stats.reset(); diff --git a/Networking/Player.h b/Networking/Player.h index a02f288c9..d31e1ebac 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -136,11 +136,15 @@ struct CommStats CommStats() : data(0), rounds(0) {} Timer& add(size_t length) { + rounds++; + return add_length_only(length); + } + Timer& add_length_only(size_t length) + { #ifdef VERBOSE_COMM cout << "add " << length << endl; #endif data += length; - rounds++; return timer; } Timer& add(const octetStream& os) { return add(os.get_length()); } @@ -153,6 +157,7 @@ class NamedCommStats : public map { public: size_t sent; + string last; NamedCommStats(); @@ -161,6 +166,7 @@ class NamedCommStats : public map NamedCommStats operator-(const NamedCommStats& other) const; void print(bool newline = false); void reset(); + Timer& add_to_last_round(const string& name, size_t length); #ifdef VERBOSE_COMM CommStats& operator[](const string& name) { diff --git a/Networking/sockets.cpp b/Networking/sockets.cpp index 8034809e6..b03613d4c 100644 --- a/Networking/sockets.cpp +++ b/Networking/sockets.cpp @@ -134,7 +134,7 @@ void close_client_socket(int socket) if (close(socket)) { char tmp[1000]; - sprintf(tmp, "close(%d)", socket); + snprintf(tmp, 1000, "close(%d)", socket); error(tmp); } } diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index 2929c11b2..ee9e19bc3 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -126,7 +126,7 @@ void BaseMachine::time() void BaseMachine::start(int n) { cout << "Starting timer " << n << " at " << timer[n].elapsed() - << " 
(" << timer[n].mb_sent() << " MB)" + << " (" << timer[n] << ")" << " after " << timer[n].idle() << endl; timer[n].start(total_comm()); } @@ -135,7 +135,7 @@ void BaseMachine::stop(int n) { timer[n].stop(total_comm()); cout << "Stopped timer " << n << " at " << timer[n].elapsed() << " (" - << timer[n].mb_sent() << " MB)" << endl; + << timer[n] << ")" << endl; } void BaseMachine::print_timers() @@ -150,7 +150,7 @@ void BaseMachine::print_timers() timer.erase(0); for (auto it = timer.begin(); it != timer.end(); it++) cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds (" - << it->second.mb_sent() << " MB)" << endl; + << it->second << ")" << endl; } string BaseMachine::memory_filename(const string& type_short, int my_number) @@ -227,3 +227,19 @@ void BaseMachine::print_global_comm(Player& P, const NamedCommStats& stats) global += os.get_int(8); cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl; } + +void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats) +{ + size_t rounds = 0; + for (auto& x : comm_stats) + rounds += x.second.rounds; + cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds + << " rounds (party " << P.my_num() << " only"; + if (nthreads > 1) + cerr << "; rounds counted double due to multi-threading"; + if (not OnlineOptions::singleton.verbose) + cerr << "; use '-v' for more details"; + cerr << ")" << endl; + + print_global_comm(P, comm_stats); +} diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index 6b5a029f1..46b1a85e6 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -67,6 +67,7 @@ class BaseMachine void print_timers(); virtual void reqbl(int) {} + virtual void active(int) {} static OTTripleSetup fresh_ot_setup(Player& P); @@ -74,6 +75,7 @@ class BaseMachine void set_thread_comm(const NamedCommStats& stats); void print_global_comm(Player& P, const NamedCommStats& stats); + void print_comm(Player& P, const NamedCommStats& stats); }; 
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P) diff --git a/Processor/Conv2dTuple.h b/Processor/Conv2dTuple.h new file mode 100644 index 000000000..8e265ab36 --- /dev/null +++ b/Processor/Conv2dTuple.h @@ -0,0 +1,41 @@ +/* + * Conv2dTuple.h + * + */ + +#ifndef PROCESSOR_CONV2DTUPLE_H_ +#define PROCESSOR_CONV2DTUPLE_H_ + +#include +using namespace std; + +class Conv2dTuple +{ +public: + int output_h, output_w; + int inputs_h, inputs_w; + int weights_h, weights_w; + int stride_h, stride_w; + int n_channels_in; + int padding_h; + int padding_w; + int batch_size; + size_t r0; + size_t r1; + int r2; + vector>> lengths; + int filter_stride_h = 1; + int filter_stride_w = 1; + + Conv2dTuple(const vector& args, int start); + + template + void pre(vector& S, typename T::Protocol& protocol); + template + void post(vector& S, typename T::Protocol& protocol); + + template + void run_matrix(SubProcessor& processor); +}; + +#endif /* PROCESSOR_CONV2DTUPLE_H_ */ diff --git a/Processor/DataPositions.cpp b/Processor/DataPositions.cpp index c32eb019f..2294a5991 100644 --- a/Processor/DataPositions.cpp +++ b/Processor/DataPositions.cpp @@ -222,7 +222,8 @@ bool DataPositions::any_more(const DataPositions& other) const for (auto it = edabits.begin(); it != edabits.end(); it++) { auto x = other.edabits.find(it->first); - if (x == other.edabits.end() or it->second > x->second) + if ((x == other.edabits.end() or it->second > x->second) + and it->second > 0) return true; } diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index 7f8a49bf6..a4a3e515a 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -12,6 +12,8 @@ #include "Networking/Player.h" #include "Protocols/edabit.h" #include "PrepBase.h" +#include "EdabitBuffer.h" +#include "Tools/TimerWithComm.h" #include #include @@ -102,9 +104,6 @@ class Preprocessing : public PrepBase DataPositions& usage; - map, vector>> edabits; - map, edabitvec> my_edabits; - bool do_count; void count(Dtype dtype, int n 
= 1) @@ -120,6 +119,8 @@ class Preprocessing : public PrepBase const vector&, true_type) { throw not_implemented(); } + void fill(edabitvec& res, bool strict, int n_bits); + T get_random_from_inputs(int nplayers); public: @@ -173,12 +174,11 @@ class Preprocessing : public PrepBase virtual void get_edabits(bool strict, size_t size, T* a, vector& Sb, const vector& regs) { get_edabits<0>(strict, size, a, Sb, regs, T::clear::characteristic_two); } - template - void get_edabit_no_count(bool, int n_bits, edabit& eb); - template + virtual void get_edabit_no_count(bool, int, edabit&) + { throw runtime_error("no edaBits"); } /// Get fresh edaBit chunk - edabitvec get_edabitvec(bool strict, int n_bits); - virtual void buffer_edabits_with_queues(bool, int) { throw runtime_error("no edaBits"); } + virtual edabitvec get_edabitvec(bool, int) + { throw runtime_error("no edabitvec"); } virtual void push_triples(const vector>&) { throw runtime_error("no pushing"); } @@ -204,7 +204,8 @@ class Sub_Data_Files : public Preprocessing BufferOwner, RefInputTuple> my_input_buffers; map > extended; BufferOwner, dabit> dabit_buffer; - map edabit_buffers; + map> edabit_buffers; + map> my_edabits; int my_num,num_players; @@ -213,13 +214,11 @@ class Sub_Data_Files : public Preprocessing part_type* part; - void buffer_edabits_with_queues(bool strict, int n_bits) - { buffer_edabits_with_queues<0>(strict, n_bits, T::clear::characteristic_two); } - template - void buffer_edabits_with_queues(bool strict, int n_bits, false_type); - template - void buffer_edabits_with_queues(bool, int, true_type) - { throw not_implemented(); } + EdabitBuffer& get_edabit_buffer(int n_bits); + + /// Get fresh edaBit chunk + edabitvec get_edabitvec(bool strict, int n_bits); + void get_edabit_no_count(bool strict, int n_bits, edabit& eb); public: static string get_filename(const Names& N, Dtype type, int thread_num = -1); @@ -317,6 +316,8 @@ class Data_Files void reset_usage() { usage.reset(); skipped.reset(); } void 
set_usage(const DataPositions& pos) { usage = pos; } + + TimerWithComm total_time(); }; template inline diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index b42d2f76d..2552dc113 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -108,7 +108,21 @@ Sub_Data_Files::Sub_Data_Files(int my_num, int num_players, #ifdef DEBUG_FILES cerr << "Setting up Data_Files in: " << prep_data_dir << endl; #endif - T::clear::check_setup(prep_data_dir); + + try + { + T::clear::check_setup(prep_data_dir); + } + catch (...) + { + cerr << "Something is wrong with the preprocessing data on disk." << endl; + cerr + << "Have you run the right program for generating it, such as './Fake-Offline.x " + << num_players + << T::clear::fake_opts() << "'?" << endl; + throw; + } + string type_short = T::type_short(); string type_string = T::type_string(); @@ -135,7 +149,7 @@ Sub_Data_Files::Sub_Data_Files(int my_num, int num_players, type_short, i, my_num, thread_num); if (i == my_num) my_input_buffers.setup(filename, - T::size() + T::clear::size(), type_string); + InputTuple::size(), type_string); else input_buffers[i].setup(filename, T::size(), type_string); @@ -179,10 +193,6 @@ Data_Files::~Data_Files() template Sub_Data_Files::~Sub_Data_Files() { - for (auto& x: edabit_buffers) - { - delete x.second; - } if (part != 0) delete part; } @@ -229,6 +239,26 @@ void Sub_Data_Files::seekg(DataPositions& pos) extended[it->first].seekg(it->second); } dabit_buffer.seekg(pos.files[field_type][DATA_DABIT]); + + if (field_type == DATA_INT) + { + for (auto& x : pos.edabits) + { + // open files + get_edabit_buffer(x.first.second); + } + + + int block_size = edabitvec::MAX_SIZE; + for (auto& x : edabit_buffers) + { + int n = pos.edabits[{true, x.first}] + pos.edabits[{false, x.first}]; + x.second.seekg(n / block_size); + edabit eb; + for (int i = 0; i < n % block_size; i++) + get_edabit_no_count(false, x.first, eb); + } + } } template @@ -262,6 +292,8 @@ void 
Sub_Data_Files::prune() dabit_buffer.prune(); if (part != 0) part->prune(); + for (auto& x : edabit_buffers) + x.second.prune(); } template @@ -285,6 +317,8 @@ void Sub_Data_Files::purge() dabit_buffer.purge(); if (part != 0) part->purge(); + for (auto& x : edabit_buffers) + x.second.prune(); } template @@ -322,34 +356,43 @@ void Sub_Data_Files::get_dabit_no_count(T& a, typename T::bit_type& b) } template -template -void Sub_Data_Files::buffer_edabits_with_queues(bool strict, int n_bits, - false_type) +EdabitBuffer& Sub_Data_Files::get_edabit_buffer(int n_bits) { - if (edabit_buffers.empty()) - insecure("reading edaBits from files"); - if (edabit_buffers.find(n_bits) == edabit_buffers.end()) { string filename = PrepBase::get_edabit_filename(prep_data_dir, n_bits, my_num, thread_num); - ifstream* f = new ifstream(filename); - if (f->fail()) - throw runtime_error("cannot open " + filename); - check_file_signature(*f, filename); - edabit_buffers[n_bits] = f; + edabit_buffers[n_bits] = n_bits; + edabit_buffers[n_bits].setup(filename, + T::size() * edabitvec::MAX_SIZE + + n_bits * T::bit_type::part_type::size()); } - auto& buffer = *edabit_buffers[n_bits]; - if (buffer.peek() == EOF) + return edabit_buffers[n_bits]; +} + +template +edabitvec Sub_Data_Files::get_edabitvec(bool strict, int n_bits) +{ + if (my_edabits[n_bits].empty()) + return get_edabit_buffer(n_bits).read(); + else { - buffer.seekg(0); - check_file_signature(buffer, ""); + auto res = my_edabits[n_bits]; + my_edabits[n_bits] = {}; + this->fill(res, strict, n_bits); + return res; + } +} + +template +void Preprocessing::fill(edabitvec& res, bool strict, int n_bits) +{ + edabit eb; + while (res.size() < res.MAX_SIZE) + { + get_edabit_no_count(strict, n_bits, eb); + res.push_back(eb); } - edabitvec eb; - eb.input(n_bits, buffer); - this->edabits[{strict, n_bits}].push_back(eb); - if (buffer.fail()) - throw runtime_error("error reading edaBits"); } template @@ -362,4 +405,10 @@ typename 
Sub_Data_Files::part_type& Sub_Data_Files::get_part() return *part; } +template +TimerWithComm Data_Files::total_time() +{ + return DataFp.prep_timer + DataF2.prep_timer + DataFb.prep_timer; +} + #endif diff --git a/Processor/EdabitBuffer.h b/Processor/EdabitBuffer.h new file mode 100644 index 000000000..af87a0559 --- /dev/null +++ b/Processor/EdabitBuffer.h @@ -0,0 +1,50 @@ +/* + * EdabitBuffer.h + * + */ + +#ifndef PROCESSOR_EDABITBUFFER_H_ +#define PROCESSOR_EDABITBUFFER_H_ + +#include "Tools/Buffer.h" + +template +class EdabitBuffer : public BufferOwner +{ + int n_bits; + + int element_length() + { + return -1; + } + +public: + EdabitBuffer(int n_bits = 0) : + n_bits(n_bits) + { + } + + edabitvec read() + { + if (not BufferBase::file) + { + if (this->open()->fail()) + throw runtime_error("error opening " + this->filename); + } + + assert(BufferBase::file); + auto& buffer = *BufferBase::file; + if (buffer.peek() == EOF) + { + this->try_rewind(); + } + + edabitvec eb; + eb.input(n_bits, buffer); + if (buffer.fail()) + throw runtime_error("error reading edaBits"); + return eb; + } +}; + +#endif /* PROCESSOR_EDABITBUFFER_H_ */ diff --git a/Processor/Instruction.h b/Processor/Instruction.h index a70e095cb..ea900b3e3 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -70,6 +70,7 @@ enum PLAYERID = 0xE4, USE_EDABIT = 0xE5, USE_MATMUL = 0x1F, + ACTIVE = 0xE9, // Addition ADDC = 0x20, ADDS = 0x21, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index dd62afc40..276b2b3d4 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -311,6 +311,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case PRIVATEOUTPUT: case TRUNC_PR: case RUN_TAPE: + case CONV2DS: num_var_args = get_int(s); get_vector(num_var_args, start, s); break; @@ -322,10 +323,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) get_ints(r, s, 3); get_vector(9, start, s); break; - case CONV2DS: - 
get_ints(r, s, 3); - get_vector(12, start, s); - break; // read from file, input is opcode num_args, // start_file_posn (read), end_file_posn(write) var1, var2, ... @@ -425,6 +422,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) throw Processor_Error(ss.str()); } break; + case ACTIVE: + n = get_int(s); + BaseMachine::s().active(n); + break; case XORM: case ANDM: case XORCB: @@ -720,7 +721,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const case MATMULSM: return r[0] + start[0] * start[2]; case CONV2DS: - return r[0] + start[0] * start[1] * start[11]; + { + unsigned res = 0; + for (size_t i = 0; i < start.size(); i += 15) + { + unsigned tmp = start[i] + + start[i + 3] * start[i + 4] * start.at(i + 14); + res = max(res, tmp); + } + return res; + } case OPEN: skip = 2; break; @@ -1164,6 +1174,7 @@ inline void Instruction::execute(Processor& Proc) const break; case REQBL: case GREQBL: + case ACTIVE: case USE: case USE_INP: case USE_EDABIT: diff --git a/Processor/Machine.h b/Processor/Machine.h index 7317d3199..fb7f5d939 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -109,6 +109,7 @@ class Machine : public BaseMachine string prep_dir_prefix(); void reqbl(int n); + void active(int n); typename sint::bit_type::mac_key_type get_bit_mac_key() { return alphabi; } typename sint::mac_key_type get_sint_mac_key() { return alphapi; } diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index e9e3eb209..e2022dc4f 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -415,6 +415,9 @@ pair Machine::stop_threads() auto comm_stats = total_comm(); + if (OnlineOptions::singleton.verbose) + queues.print_breakdown(); + for (auto& queue : queues) delete queue; @@ -477,20 +480,7 @@ void Machine::run(const string& progname) print_timers(); if (sint::is_real) - { - size_t rounds = 0; - for (auto& x : comm_stats) - rounds += x.second.rounds; - cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds - << " 
rounds (party " << my_number; - if (threads.size() > 1) - cerr << "; rounds counted double due to multi-threading"; - cerr << "; use '-v' for more details"; - cerr << ")" << endl; - - auto& P = *this->P; - this->print_global_comm(P, comm_stats); - } + this->print_comm(*this->P, comm_stats); #ifdef VERBOSE_OPTIONS if (opening_sum < N.num_players() && !direct) @@ -521,23 +511,6 @@ void Machine::run(const string& progname) bit_memories.write_memory(N.my_num()); -#ifdef OLD_USAGE - for (int dtype = 0; dtype < N_DTYPE; dtype++) - { - cerr << "Num " << DataPositions::dtype_names[dtype] << "\t="; - for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++) - cerr << " " << pos.files[field_type][dtype]; - cerr << endl; - } - for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++) - { - cerr << "Num " << DataPositions::field_names[field_type] << " Inputs\t="; - for (int i = 0; i < N.num_players(); i++) - cerr << " " << pos.inputs[i][field_type]; - cerr << endl; - } -#endif - if (opts.verbose) { cerr << "Actual cost of program:" << endl; @@ -586,6 +559,17 @@ void Machine::reqbl(int n) sint::clear::reqbl(n); } +template +void Machine::active(int n) +{ + + if (sint::malicious and n == 0) + { + cerr << "Program requires a semi-honest protocol" << endl; + exit(1); + } +} + template void Machine::suggest_optimizations() { @@ -599,8 +583,8 @@ void Machine::suggest_optimizations() optimizations.append("\tprogram.use_edabit(True)\n"); if (not optimizations.empty()) cerr << "This program might benefit from some protocol options." << endl - << "Consider adding the following at the beginning of '" << progname - << ".mpc':" << endl << optimizations; + << "Consider adding the following at the beginning of your code:" + << endl << optimizations; #ifndef __clang__ cerr << "This virtual machine was compiled with GCC. Recompile with " "'CXX = clang++' in 'CONFIG.mine' for optimal performance." 
<< endl; diff --git a/Processor/OfflineMachine.hpp b/Processor/OfflineMachine.hpp index 98359a499..e18c47b5f 100644 --- a/Processor/OfflineMachine.hpp +++ b/Processor/OfflineMachine.hpp @@ -172,7 +172,7 @@ void OfflineMachine::generate() auto& opts = OnlineOptions::singleton; opts.batch_size = DIV_CEIL(opts.batch_size, batch) * batch; for (int i = 0; i < buffered_total(total, batch) / batch; i++) - preprocessing.template get_edabitvec<0>(true, n_bits).output(n_bits, + preprocessing.get_edabitvec(true, n_bits).output(n_bits, out); } else diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index 1182d61f5..3409c0470 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -44,6 +44,7 @@ void thread_info::Sub_Main_Func() auto& queues = machine.queues[num]; queues->next(); + ThreadQueue::thread_queue = queues; #ifdef DEBUG_THREADS fprintf(stderr, "\tI am in thread %d\n",num); @@ -118,6 +119,8 @@ void thread_info::Sub_Main_Func() DataPositions actual_usage(P.num_players()); Timer thread_timer(CLOCK_THREAD_CPUTIME_ID), wait_timer; thread_timer.start(); + TimerWithComm timer, online_timer, online_prep_timer; + timer.start(); while (flag) { // Wait until I have a program to run @@ -262,6 +265,8 @@ void thread_info::Sub_Main_Func() #ifdef DEBUG_THREADS printf("\tClient %d about to run %d\n",num,program); #endif + online_timer.start(P.total_comm()); + online_prep_timer -= Proc.DataF.total_time(); Proc.reset(progs[program], job.arg); // Bits, Triples, Squares, and Inverses skipping @@ -290,6 +295,8 @@ void thread_info::Sub_Main_Func() printf("\tSignalling I have finished with program %d" "in thread %d\n", program, num); #endif + online_timer.stop(P.total_comm()); + online_prep_timer += Proc.DataF.total_time(); wait_timer.start(); queues->finished(job, P.total_comm()); wait_timer.stop(); @@ -297,7 +304,11 @@ void thread_info::Sub_Main_Func() } // final check + online_timer.start(P.total_comm()); + online_prep_timer -= 
Proc.DataF.total_time(); Proc.check(); + online_timer.stop(P.total_comm()); + online_prep_timer += Proc.DataF.total_time(); if (machine.opts.file_prep_per_thread) Proc.DataF.prune(); @@ -330,6 +341,11 @@ void thread_info::Sub_Main_Func() // wind down thread by thread machine.stats += Proc.stats; + queues->timers["wait"] = wait_timer + queues->wait_timer; + timer.stop(P.total_comm()); + queues->timers["online"] = online_timer - online_prep_timer - queues->wait_timer; + queues->timers["prep"] = timer - queues->timers["wait"] - queues->timers["online"]; + // prevent faulty usage message Proc.DataF.set_usage(actual_usage); delete processor; diff --git a/Processor/PrepBase.cpp b/Processor/PrepBase.cpp index a2f79027e..775caf01a 100644 --- a/Processor/PrepBase.cpp +++ b/Processor/PrepBase.cpp @@ -69,7 +69,7 @@ void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict, cerr << " edaBits of size " << n_bits << " left" << endl; } - if (n > used / 10) + if (n * n_batch > used / 10) cerr << "Significant amount of unused edaBits of size " << n_bits << ". 
For more accurate benchmarks, " << "consider reducing the batch size with --batch-size " diff --git a/Processor/PrepBase.h b/Processor/PrepBase.h index ccc2f4b40..e598d31ce 100644 --- a/Processor/PrepBase.h +++ b/Processor/PrepBase.h @@ -10,6 +10,7 @@ using namespace std; #include "Math/field_types.h" +#include "Tools/TimerWithComm.h" class PrepBase { @@ -28,6 +29,8 @@ class PrepBase const string& type_string, size_t used); static void print_left_edabits(size_t n, size_t n_batch, bool strict, int n_bits, size_t used); + + TimerWithComm prep_timer; }; #endif /* PROCESSOR_PREPBASE_H_ */ diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 34cb5b369..747f51188 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -5,6 +5,7 @@ #include "Processor/Program.h" #include "GC/square64.h" #include "SpecificPrivateOutput.h" +#include "Conv2dTuple.h" #include "Processor/ProcessorBase.hpp" #include "GC/Processor.hpp" @@ -31,6 +32,7 @@ SubProcessor::SubProcessor(typename T::MAC_Check& MC, DataF.set_proc(this); protocol.init(DataF, MC); DataF.set_protocol(protocol); + MC.set_prep(DataF); bit_usage.set_num_players(P.num_players()); personal_bit_preps.resize(P.num_players()); for (int i = 0; i < P.num_players(); i++) @@ -40,6 +42,7 @@ SubProcessor::SubProcessor(typename T::MAC_Check& MC, template SubProcessor::~SubProcessor() { + DataF.set_proc(0); for (size_t i = 0; i < personal_bit_preps.size(); i++) { auto& x = personal_bit_preps[i]; @@ -391,7 +394,7 @@ void Processor::read_shares_from_file(int start_file_posn, int end_ return; string filename; - filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data"; + filename = binary_file_io.filename(P.my_num()); unsigned int size = data_registers.size(); @@ -652,21 +655,35 @@ void SubProcessor::conv2ds(const Instruction& instruction) { protocol.init_dotprod(); auto& args = instruction.get_start(); - int output_h = args[0], output_w = args[1]; - int inputs_h = args[2], inputs_w = args[3]; - 
int weights_h = args[4], weights_w = args[5]; - int stride_h = args[6], stride_w = args[7]; - int n_channels_in = args[8]; - int padding_h = args[9]; - int padding_w = args[10]; - int batch_size = args[11]; - size_t r0 = instruction.get_r(0); - size_t r1 = instruction.get_r(1); - int r2 = instruction.get_r(2); - int lengths[batch_size][output_h][output_w]; - memset(lengths, 0, sizeof(lengths)); - int filter_stride_h = 1; - int filter_stride_w = 1; + vector tuples; + for (size_t i = 0; i < args.size(); i += 15) + tuples.push_back(Conv2dTuple(args, i)); + for (auto& tuple : tuples) + tuple.pre(S, protocol); + protocol.exchange(); + for (auto& tuple : tuples) + tuple.post(S, protocol); +} + +inline +Conv2dTuple::Conv2dTuple(const vector& arguments, int start) +{ + assert(arguments.size() >= start + 15ul); + auto args = arguments.data() + start + 3; + output_h = args[0], output_w = args[1]; + inputs_h = args[2], inputs_w = args[3]; + weights_h = args[4], weights_w = args[5]; + stride_h = args[6], stride_w = args[7]; + n_channels_in = args[8]; + padding_h = args[9]; + padding_w = args[10]; + batch_size = args[11]; + r0 = arguments[start]; + r1 = arguments[start + 1]; + r2 = arguments[start + 2]; + lengths.resize(batch_size, vector>(output_h, vector(output_w))); + filter_stride_h = 1; + filter_stride_w = 1; if (stride_h < 0) { filter_stride_h = -stride_h; @@ -677,7 +694,11 @@ void SubProcessor::conv2ds(const Instruction& instruction) filter_stride_w = -stride_w; stride_w = 1; } +} +template +void Conv2dTuple::pre(vector& S, typename T::Protocol& protocol) +{ for (int i_batch = 0; i_batch < batch_size; i_batch ++) { size_t base = r1 + i_batch * inputs_w * inputs_h * n_channels_in; @@ -714,9 +735,11 @@ void SubProcessor::conv2ds(const Instruction& instruction) protocol.next_dotprod(); } } +} - protocol.exchange(); - +template +void Conv2dTuple::post(vector& S, typename T::Protocol& protocol) +{ for (int i_batch = 0; i_batch < batch_size; i_batch ++) { size_t base = r0 + 
i_batch * output_h * output_w; diff --git a/Processor/ThreadQueue.cpp b/Processor/ThreadQueue.cpp index 6358e4a4a..ced871b5d 100644 --- a/Processor/ThreadQueue.cpp +++ b/Processor/ThreadQueue.cpp @@ -6,6 +6,8 @@ #include "ThreadQueue.h" +thread_local ThreadQueue* ThreadQueue::thread_queue = 0; + void ThreadQueue::schedule(const ThreadJob& job) { lock.lock(); @@ -14,7 +16,11 @@ void ThreadQueue::schedule(const ThreadJob& job) cerr << this << ": " << left << " left" << endl; #endif lock.unlock(); + if (thread_queue) + thread_queue->wait_timer.start(); in.push(job); + if (thread_queue) + thread_queue->wait_timer.stop(); } ThreadJob ThreadQueue::next() @@ -42,7 +48,11 @@ void ThreadQueue::set_comm_stats(const NamedCommStats& new_comm_stats) ThreadJob ThreadQueue::result() { + if (thread_queue) + thread_queue->wait_timer.start(); auto res = out.pop(); + if (thread_queue) + thread_queue->wait_timer.stop(); lock.lock(); left--; #ifdef DEBUG_THREAD_QUEUE diff --git a/Processor/ThreadQueue.h b/Processor/ThreadQueue.h index f49722abb..c9640b7ae 100644 --- a/Processor/ThreadQueue.h +++ b/Processor/ThreadQueue.h @@ -16,6 +16,11 @@ class ThreadQueue NamedCommStats comm_stats; public: + static thread_local ThreadQueue* thread_queue; + + map timers; + Timer wait_timer; + ThreadQueue() : left(0) { diff --git a/Processor/ThreadQueues.cpp b/Processor/ThreadQueues.cpp index ecca7bbe3..5e3a3c409 100644 --- a/Processor/ThreadQueues.cpp +++ b/Processor/ThreadQueues.cpp @@ -85,3 +85,32 @@ void ThreadQueues::wrap_up(ThreadJob job) } available.clear(); } + +TimerWithComm ThreadQueues::sum(const string& phase) +{ + TimerWithComm res; + for (auto& x : *this) + res += x->timers[phase]; + return res; +} + +void ThreadQueues::print_breakdown() +{ + if (size() > 0) + { + if (size() == 1) + { + cerr << "Spent " << (*this)[0]->timers["online"].full() + << " on the online phase and " + << (*this)[0]->timers["prep"].full() + << " on the preprocessing/offline phase." 
<< endl; + } + else + { + cerr << size() << " threads spent a total of " << sum("online").full() + << " on the online phase, " << sum("prep").full() + << " on the preprocessing/offline phase, and " + << sum("wait").full() << " idling." << endl; + } + } +} diff --git a/Processor/ThreadQueues.h b/Processor/ThreadQueues.h index c5da4d1db..3191e436b 100644 --- a/Processor/ThreadQueues.h +++ b/Processor/ThreadQueues.h @@ -24,6 +24,10 @@ class ThreadQueues : int distribute_no_setup(ThreadJob job, int n_items, int base = 0, int granularity = 1, const vector* supplies = 0); void wrap_up(ThreadJob job); + + TimerWithComm sum(const string& phase); + + void print_breakdown(); }; #endif /* PROCESSOR_THREADQUEUES_H_ */ diff --git a/Processor/instructions.h b/Processor/instructions.h index 5912d8676..cc833284a 100644 --- a/Processor/instructions.h +++ b/Processor/instructions.h @@ -387,6 +387,7 @@ X(GENSECSHUFFLE, throw not_implemented(),) \ X(APPLYSHUFFLE, throw not_implemented(),) \ X(DELSHUFFLE, throw not_implemented(),) \ + X(ACTIVE, throw not_implemented(),) \ #define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \ CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS diff --git a/Programs/Source/alex.mpc b/Programs/Source/alex.mpc new file mode 100644 index 000000000..5b55a5ad8 --- /dev/null +++ b/Programs/Source/alex.mpc @@ -0,0 +1,114 @@ +from Compiler.ml import keras +import Compiler.ml as tf + +try: + n_epochs = int(program.args[1]) +except (ValueError, IndexError): + n_epochs = 20 + +try: + batch_size = int(program.args[2]) +except (ValueError, IndexError): + batch_size = 128 + +try: + n_threads = int(program.args[3]) +except (ValueError, IndexError): + n_threads = 36 + +#Instantiation +AlexNet = [] + +padding = 1 +batchnorm = 'batchnorm' in program.args +bn1 = 'bn1' in program.args +bn2 = 'bn2' in program.args + +MultiArray.disable_index_checks() + +#1st Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=64, input_shape=(32,32,3), kernel_size=3, 
strides=1, padding=2)) +AlexNet.append(keras.layers.Activation('relu')) +if batchnorm: + AlexNet.append(keras.layers.BatchNormalization()) +AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=(2,2), padding=0)) + +#2nd Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=96, kernel_size=3, strides=1, padding=2)) +AlexNet.append(keras.layers.Activation('relu')) +if batchnorm or bn2: + AlexNet.append(keras.layers.BatchNormalization()) +AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='same')) + +#3rd Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=96, kernel_size=(3,3), strides=(1,1), padding=padding)) +AlexNet.append(keras.layers.Activation('relu')) +if batchnorm: + AlexNet.append(keras.layers.BatchNormalization()) + +#4th Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=64, kernel_size=(3,3), strides=(1,1), padding=padding)) +AlexNet.append(keras.layers.Activation('relu')) +if batchnorm or bn1: + AlexNet.append(keras.layers.BatchNormalization()) + +#5th Convolutional Layer +AlexNet.append(keras.layers.Conv2D(filters=64, kernel_size=(3,3), strides=(1,1), padding=padding)) +AlexNet.append(keras.layers.Activation('relu')) +if batchnorm or bn2: + AlexNet.append(keras.layers.BatchNormalization()) +AlexNet.append(keras.layers.MaxPooling2D(pool_size=(3,3), strides=(2,2), padding=0)) + +#Passing it to a Fully Connected layer +# 1st Fully Connected Layer +AlexNet.append(keras.layers.Dense(128)) +AlexNet.append(keras.layers.Activation('relu')) + +if 'dropout' in program.args: + AlexNet.append(keras.layers.Dropout(0.5)) + +#2nd Fully Connected Layer +AlexNet.append(keras.layers.Dense(256)) +AlexNet.append(keras.layers.Activation('relu')) + +if 'dropout' in program.args: + AlexNet.append(keras.layers.Dropout(0.5)) + +#Output Layer +AlexNet.append(keras.layers.Dense(10)) + +tf.set_n_threads(n_threads) +program.options_from_args() +sfix.set_precision_from_args(program, adapt_ring=True) + 
+training_samples = MultiArray([50000, 32, 32, 3], sfix) +training_labels = MultiArray([50000, 10], sint) + +test_samples = MultiArray([10000, 32, 32, 3], sfix) +test_labels = MultiArray([10000, 10], sint) + +training_labels.input_from(0) +training_samples.input_from(0, binary='binary_samples' in program.args) + +test_labels.input_from(0) +test_samples.input_from(0, binary='binary_samples' in program.args) + +model = tf.keras.models.Sequential(AlexNet) + +model.compile_by_args(program) + +model.build(training_samples.sizes) +model.summary() + +model.opt.output_diff = 'output_diff' in program.args +model.opt.output_grad = 'output_grad' in program.args +model.opt.output_stats = 100 if 'output_stats' in program.args else 0 +model.opt.shuffle = not 'noshuffle' in program.args + +opt = model.fit( + training_samples, + training_labels, + epochs=n_epochs, + batch_size=batch_size, + validation_data=(test_samples, test_labels) +) diff --git a/Programs/Source/bankers_bonus.mpc b/Programs/Source/bankers_bonus.mpc index 674efcdad..edaa2da82 100644 --- a/Programs/Source/bankers_bonus.mpc +++ b/Programs/Source/bankers_bonus.mpc @@ -25,6 +25,9 @@ n_threads = 2 if len(program.args) > 1: n_rounds = int(program.args[1]) +if len(program.args) > 2: + program.active = bool(int(program.args[2])) + def accept_client(): client_socket_id = accept_client_connection(PORTNUM) last = regint.read_from_socket(client_socket_id) diff --git a/Programs/Source/falcon_alex.mpc b/Programs/Source/falcon_alex.mpc index 3c535248f..26422b86e 100644 --- a/Programs/Source/falcon_alex.mpc +++ b/Programs/Source/falcon_alex.mpc @@ -4,7 +4,7 @@ import Compiler.ml as tf try: n_epochs = int(program.args[1]) except (ValueError, IndexError): - n_epochs = 10 + n_epochs = 20 try: batch_size = int(program.args[2]) diff --git a/Programs/Source/keras_mnist_lenet_avgpool.mpc b/Programs/Source/keras_mnist_lenet_avgpool.mpc new file mode 100644 index 000000000..a3f0b486b --- /dev/null +++ 
b/Programs/Source/keras_mnist_lenet_avgpool.mpc @@ -0,0 +1,72 @@ +# this trains LeNet on MNIST with a dropout layer +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + +program.options_from_args() + +if 'torch' in program.args: + import torchvision + data = [] + for train in True, False: + ds = torchvision.datasets.MNIST(root='/tmp', train=train, download=True) + # normalize to [0,1] before input + samples = sfix.input_tensor_via(0, ds.data / 255., binary=True) + labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True) + data += [(labels, samples)] + + (training_labels, training_samples), (test_labels, test_samples) = data +else: + training_samples = sfix.Tensor([60000, 28, 28]) + training_labels = sint.Tensor([60000, 10]) + + test_samples = sfix.Tensor([10000, 28, 28]) + test_labels = sint.Tensor([10000, 10]) + + training_labels.input_from(0) + training_samples.input_from(0) + + test_labels.input_from(0) + test_samples.input_from(0) + +from Compiler import ml +tf = ml + +layers = [ + tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'), +] + +if 'batchnorm' in program.args: + layers += [tf.keras.layers.BatchNormalization()] + +layers += [ + tf.keras.layers.AveragePooling2D(2), + tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'), +] + + +if 'batchnorm' in program.args: + layers += [tf.keras.layers.BatchNormalization()] + +layers += [ + tf.keras.layers.AveragePooling2D(2), + tf.keras.layers.Flatten(), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(500, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax') +] + +model = tf.keras.models.Sequential(layers) + +optim = tf.keras.optimizers.Adam(amsgrad=True) + +model.compile(optimizer=optim) + +opt = model.fit( + training_samples, + training_labels, + epochs=10, + batch_size=128, + validation_data=(test_samples, test_labels) +) + +for var in model.trainable_variables: + var.write_to_file() diff --git 
a/Programs/Source/torch_mnist_lenet_avgpool.mpc b/Programs/Source/torch_mnist_lenet_avgpool.mpc new file mode 100644 index 000000000..4f1ea985a --- /dev/null +++ b/Programs/Source/torch_mnist_lenet_avgpool.mpc @@ -0,0 +1,49 @@ +# this trains LeNet with average pooling on MNIST + +program.options_from_args() + +import torchvision + +data = [] +for train in True, False: + ds = torchvision.datasets.MNIST(root='/tmp', train=train, download=True) + # normalize to [0,1] before input + samples = sfix.input_tensor_via(0, ds.data / 255., binary=True) + labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True) + data += [(labels, samples)] + +import torch +import torch.nn as nn + +net = nn.Sequential( + nn.Conv2d(1, 20, 5), + nn.ReLU(), + nn.AvgPool2d(2), + nn.Conv2d(20, 50, 5), + nn.ReLU(), + nn.AvgPool2d(2), + nn.Flatten(), + nn.ReLU(), + nn.Linear(800, 500), + nn.ReLU(), + nn.Linear(500, 10) +) + +# test network +ds = torchvision.datasets.MNIST( + root='/tmp', transform=torchvision.transforms.ToTensor()) +inputs = next(iter(torch.utils.data.DataLoader(ds)))[0] +print(inputs.shape) +outputs = net(inputs) + +from Compiler import ml + +ml.set_n_threads(int(program.args[2])) + +layers = ml.layers_from_torch(net, data[0][1].shape, 128) +layers[0].X = data[0][1] +layers[-1].Y = data[0][0] + +optimizer = ml.SGD(layers) +optimizer.run_by_args(program, int(program.args[1]), 128, + data[1][1], data[1][0]) diff --git a/Protocols/DealerMatrixPrep.hpp b/Protocols/DealerMatrixPrep.hpp index e91d4cd9c..29e4c1efd 100644 --- a/Protocols/DealerMatrixPrep.hpp +++ b/Protocols/DealerMatrixPrep.hpp @@ -11,6 +11,8 @@ DealerMatrixPrep::DealerMatrixPrep(int n_rows, int n_inner, int n_cols, super(usage), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols), prep(&prep) { + assert(prep.proc); + this->P = &prep.proc->P; } template diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index 0aa61bcba..0d8d2f695 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -20,9 +20,6 @@ class Hemi :
public T::BasicProtocol MatrixMC mc; - ShareMatrix matrix_multiply(const ShareMatrix& A, const ShareMatrix& B, - SubProcessor& processor); - public: Hemi(Player& P) : T::BasicProtocol(P) @@ -33,6 +30,9 @@ class Hemi : public T::BasicProtocol typename T::MatrixPrep& get_matrix_prep(const array& dimensions, SubProcessor& processor); + ShareMatrix matrix_multiply(const ShareMatrix& A, const ShareMatrix& B, + SubProcessor& processor); + void matmulsm(SubProcessor& processor, CheckVector& source, const Instruction& instruction, int a, int b); void conv2ds(SubProcessor& processor, const Instruction& instruction); diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index 9ba85290f..2b847530e 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -130,37 +130,23 @@ void Hemi::conv2ds(SubProcessor& processor, } auto& args = instruction.get_start(); - int output_h = args[0], output_w = args[1]; - int inputs_h = args[2], inputs_w = args[3]; - int weights_h = args[4], weights_w = args[5]; - int stride_h = args[6], stride_w = args[7]; - int n_channels_in = args[8]; - int padding_h = args[9]; - int padding_w = args[10]; - int batch_size = args[11]; - size_t r0 = instruction.get_r(0); - size_t r1 = instruction.get_r(1); - int r2 = instruction.get_r(2); - int filter_stride_h = 1; - int filter_stride_w = 1; - if (stride_h < 0) - { - filter_stride_h = -stride_h; - stride_h = 1; - } - if (stride_w < 0) - { - filter_stride_w = -stride_w; - stride_w = 1; - } + vector tuples; + for (size_t i = 0; i < args.size(); i += 15) + tuples.push_back(Conv2dTuple(args, i)); + for (auto& tuple : tuples) + tuple.run_matrix(processor); +} +template +void Conv2dTuple::run_matrix(SubProcessor& processor) +{ auto& S = processor.get_S(); array dim({{1, weights_h * weights_w * n_channels_in, batch_size * output_h * output_w}}); ShareMatrix A(dim[0], dim[1]), B(dim[1], dim[2]); if (not T::real_shares(processor.P)) { - matrix_multiply(A, B, processor); + processor.protocol.matrix_multiply(A, B, 
processor); return; } @@ -208,7 +194,7 @@ void Hemi::conv2ds(SubProcessor& processor, } } - auto C = matrix_multiply(A, B, processor); + auto C = processor.protocol.matrix_multiply(A, B, processor); for (int i_batch = 0; i_batch < batch_size; i_batch ++) { diff --git a/Protocols/HemiMatrixPrep.h b/Protocols/HemiMatrixPrep.h index 8038e8efc..db8682193 100644 --- a/Protocols/HemiMatrixPrep.h +++ b/Protocols/HemiMatrixPrep.h @@ -37,6 +37,8 @@ class HemiMatrixPrep : public BufferPrep> if (swapped) std::swap(this->n_rows, this->n_cols); assert(this->n_cols >= this->n_rows); + assert(prep.proc); + this->P = &prep.proc->P; } void set_protocol(typename ShareMatrix::Protocol&) diff --git a/Protocols/MaliciousShamirMC.hpp b/Protocols/MaliciousShamirMC.hpp index 7f66215d5..912302a11 100644 --- a/Protocols/MaliciousShamirMC.hpp +++ b/Protocols/MaliciousShamirMC.hpp @@ -21,11 +21,7 @@ void MaliciousShamirMC::init_open(const Player& P, int n) reconstructions.resize(2 * threshold + 2); for (int i = threshold + 1; i <= 2 * threshold + 1; i++) { - reconstructions[i].resize(i); - for (int j = 0; j < i; j++) - reconstructions[i][j] = - Shamir::get_rec_factor(P.get_player(j), - P.num_players(), P.my_num(), i); + reconstructions[i] = ShamirMC::get_reconstruction(P, i); } } diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index e73d9cc2c..c3899e745 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -13,6 +13,7 @@ #include "Protocols/ShuffleSacrifice.h" #include "Protocols/MAC_Check_Base.h" #include "Protocols/ShuffleSacrifice.h" +#include "Tools/TimerWithComm.h" #include "edabit.h" #include @@ -33,6 +34,8 @@ class BufferPrep : public Preprocessing { template friend class Machine; + friend class InScope; + template void buffer_inverses(true_type); template @@ -49,9 +52,13 @@ class BufferPrep : public Preprocessing vector> dabits; + map, vector>> edabits; + map, edabitvec> my_edabits; + int n_bit_rounds; SubProcessor* proc; + Player* P; virtual 
void buffer_triples() { throw runtime_error("no triples"); } virtual void buffer_squares() { throw runtime_error("no squares"); } @@ -110,6 +117,9 @@ class BufferPrep : public Preprocessing virtual void get_dabit_no_count(T& a, typename T::bit_type& b); + edabitvec get_edabitvec(bool strict, int n_bits); + void get_edabit_no_count(bool strict, int n_bits, edabit& eb); + /// Get fresh random value virtual T get_random(); diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index e8b177914..cf3515490 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -24,24 +24,44 @@ class InScope { bool& variable; bool backup; + TimerWithComm& timer; + bool running; + Player* P; public: - InScope(bool& variable, bool value) : - variable(variable) + template + InScope(bool& variable, bool value, BufferPrep& prep) : + variable(variable), timer(prep.prep_timer), + P(prep.proc ? &prep.proc->P : (prep.P ? prep.P : 0)) { backup = variable; variable = value; + running = timer.is_running(); + if (not running) + { + if (P) + timer.start(P->total_comm()); + else + timer.start({}); + } } ~InScope() { variable = backup; + if (not running) + { + if (P) + timer.stop(P->total_comm()); + else + timer.stop({}); + } } }; template BufferPrep::BufferPrep(DataPositions& usage) : Preprocessing(usage), n_bit_rounds(0), - proc(0), + proc(0), P(0), buffer_size(OnlineOptions::singleton.batch_size) { } @@ -90,6 +110,17 @@ BufferPrep::~BufferPrep() this->print_left_edabits(x.second.size(), x.second[0].size(), x.first.first, x.first.second, this->usage.edabits[x.first]); } + +#ifdef VERBOSE + if (OnlineOptions::singleton.verbose and this->prep_timer.elapsed()) + { + cerr << type_string << " preprocessing time = " + << this->prep_timer.elapsed(); + if (this->prep_timer.mb_sent()) + cerr << " (" << this->prep_timer.mb_sent() << " MB)"; + cerr << endl; + } +#endif } template @@ -186,6 +217,7 @@ void BufferPrep::get_three_no_count(Dtype dtype, T& a, T& b, T& c) if 
(triples.empty()) { + InScope in_scope(this->do_count, false, *this); buffer_triples(); assert(not triples.empty()); } @@ -277,6 +309,7 @@ void BufferPrep::buffer_inverses(true_type) template void BufferPrep::get_two_no_count(Dtype dtype, T& a, T& b) { + InScope in_scope(this->do_count, false, *this); switch (dtype) { case DATA_SQUARE: @@ -464,7 +497,7 @@ void RingPrep::buffer_personal_edabits_without_check(int n_bits, #ifdef VERBOSE_EDA fprintf(stderr, "generate personal edaBits %d to %d\n", begin, end); #endif - InScope in_scope(this->do_count, false); + InScope in_scope(this->do_count, false, *this); assert(this->proc != 0); auto& P = proc.P; typename T::Input input(*this->proc, this->proc->MC); @@ -760,8 +793,10 @@ void RingPrep::buffer_dabits_without_check(vector>& dabits, typedef typename T::bit_type::part_type bit_type; vector> player_bits; auto& party = GC::ShareThread::s(); - SubProcessor bit_proc(party.MC->get_part_MC(), - bit_prep, proc->P); + if (not bit_part_proc) + bit_part_proc = new SubProcessor(party.MC->get_part_MC(), + bit_prep, proc->P); + auto& bit_proc = *bit_part_proc; buffer_bits_from_players(player_bits, G, bit_proc, this->base_player, buffer_size, 1); vector int_bits; @@ -828,7 +863,7 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector& sums, vector> player_ints(n_relevant, vector(buffer_size)); vector>> parts(n_relevant, vector>(n_bits, vector(buffer_size / dl))); - InScope in_scope(this->do_count, false); + InScope in_scope(this->do_count, false, *this); assert(this->proc != 0); auto& P = proc->P; typename T::Input input(*this->proc, this->proc->MC); @@ -1074,7 +1109,7 @@ void BufferPrep::get_one_no_count(Dtype dtype, T& a) while (bits.empty()) { - InScope in_scope(this->do_count, false); + InScope in_scope(this->do_count, false, *this); buffer_bits(); n_bit_rounds++; } @@ -1086,6 +1121,7 @@ void BufferPrep::get_one_no_count(Dtype dtype, T& a) template void BufferPrep::get_input_no_count(T& a, typename T::open_type& x, int 
i) { + InScope in_scope(this->do_count, false, *this); (void) a, (void) x, (void) i; if (inputs.size() <= (size_t)i) inputs.resize(i + 1); @@ -1101,7 +1137,7 @@ void BufferPrep::get_dabit_no_count(T& a, typename T::bit_type& b) { if (dabits.empty()) { - InScope in_scope(this->do_count, false); + InScope in_scope(this->do_count, false, *this); ThreadQueues* queues = 0; buffer_dabits(queues); assert(not dabits.empty()); @@ -1117,7 +1153,7 @@ void BufferPrep::get_personal_dabit(int player, T& a, typename T::bit_type& b auto& buffer = personal_dabits[player]; if (buffer.empty()) { - InScope in_scope(this->do_count, false); + InScope in_scope(this->do_count, false, *this); buffer_personal_dabits(player); } a = buffer.back().first; @@ -1133,28 +1169,39 @@ void Preprocessing::get_dabit(T& a, typename T::bit_type& b) } template -template -edabitvec Preprocessing::get_edabitvec(bool strict, int n_bits) +edabitvec BufferPrep::get_edabitvec(bool strict, int n_bits) { auto& buffer = this->edabits[{strict, n_bits}]; if (buffer.empty()) { - InScope in_scope(this->do_count, false); + InScope in_scope(this->do_count, false, *this); buffer_edabits_with_queues(strict, n_bits); } auto res = buffer.back(); buffer.pop_back(); + this->fill(res, strict, n_bits); return res; } template -template -void Preprocessing::get_edabit_no_count(bool strict, int n_bits, edabit& a) +void BufferPrep::get_edabit_no_count(bool strict, int n_bits, edabit& a) { auto& my_edabit = my_edabits[{strict, n_bits}]; if (my_edabit.empty()) { - my_edabit = this->template get_edabitvec<0>(strict, n_bits); + my_edabit = this->get_edabitvec(strict, n_bits); + } + a = my_edabit.next(); +} + +template +void Sub_Data_Files::get_edabit_no_count(bool strict, int n_bits, + edabit& a) +{ + auto& my_edabit = my_edabits[n_bits]; + if (my_edabit.empty()) + { + my_edabit = this->get_edabitvec(strict, n_bits); } a = my_edabit.next(); } @@ -1174,24 +1221,25 @@ void Preprocessing::get_edabits(bool strict, size_t size, T* a, 
vector& Sb, const vector& regs, false_type) { int n_bits = regs.size(); - auto& buffer = edabits[{strict, n_bits}]; edabit eb; size_t unit = T::bit_type::default_length; for (int k = 0; k < DIV_CEIL(size, unit); k++) { - if (not buffer.empty() and buffer.back().size() == unit and (k + 1) * unit <= size) + + if (unit == edabitvec::MAX_SIZE and (k + 1) * unit <= size) { + auto buffer = get_edabitvec(strict, n_bits); + assert(unit == buffer.size()); for (int j = 0; j < n_bits; j++) - Sb[regs[j] + k] = buffer.back().get_b(j); + Sb[regs[j] + k] = buffer.get_b(j); for (size_t j = 0; j < unit; j++) - a[k * unit + j] = buffer.back().get_a(j); - buffer.pop_back(); + a[k * unit + j] = buffer.get_a(j); } else { for (size_t i = k * unit; i < min(size, (k + 1) * unit); i++) { - this->template get_edabit_no_count<0>(strict, n_bits, eb); + get_edabit_no_count(strict, n_bits, eb); a[i] = eb.first; for (int j = 0; j < n_bits; j++) { diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index db056ae49..7c29a383a 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -49,8 +49,8 @@ class Shamir : public ProtocolBase Player& P; static U get_rec_factor(int i, int n); - static U get_rec_factor(int i, int n_total, int start, int threshold, - int target = -1); + static U get_rec_factor(int i, const vector& points, int target = -1); + static vector get_rec_factors(const vector& points, int target = -1); Shamir(Player& P, int threshold = 0); ~Shamir(); diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index 89fa6853e..e11b37513 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -15,21 +15,20 @@ template typename T::open_type::Scalar Shamir::get_rec_factor(int i, int n) { - return get_rec_factor(i, n, 0, n); + vector points(n); + for (int j = 0; j < n; j++) + points[j] = j; + return get_rec_factor(i, points); } template -typename T::open_type::Scalar Shamir::get_rec_factor(int i, int n_total, - int start, int n_points, int target) +typename T::open_type::Scalar 
Shamir::get_rec_factor(int i, + const vector& points, int target) { + assert(find(points.begin(), points.end(), i) != points.end()); U res = 1; - for (int j = 0; j < n_points; j++) + for (auto& other : points) { - int other; - if (n_total > 0) - other = positive_modulo(start + j, n_total); - else - other = start + j; if (i != other) { res *= (U(other + 1) - U(target + 1)) / (U(other + 1) - U(i + 1)); @@ -42,6 +41,16 @@ typename T::open_type::Scalar Shamir::get_rec_factor(int i, int n_total, return res; } +template +vector Shamir::get_rec_factors( + const vector& points, int target) +{ + vector res; + for (auto& point : points) + res.push_back(get_rec_factor(point, points, target)); + return res; +} + template Shamir::Shamir(Player& P, int t) : resharing(0), random_input(0), P(P) diff --git a/Protocols/ShamirInput.hpp b/Protocols/ShamirInput.hpp index 41c880121..f82c6f568 100644 --- a/Protocols/ShamirInput.hpp +++ b/Protocols/ShamirInput.hpp @@ -52,10 +52,12 @@ void ShamirInput::init() for (size_t i = 0; i < reconstruction.size(); i++) { auto& x = reconstruction[i]; - for (int j = 0; j <= threshold; j++) - x.push_back( - Shamir::get_rec_factor(j - 1, 0, -1, threshold + 1, - i + threshold)); + vector points(threshold + 1); + points[0] = -1; + for (int i = 0; i < threshold; i++) + points[1 + i] = this->P.get_player(1 + i); + x = Shamir::get_rec_factors(points, + this->P.get_player(1 + i + threshold)); } } @@ -68,23 +70,24 @@ void ShamirInput::add_mine(const typename T::open_type& input, int n_bits) int t = threshold; randomness.resize(t); - for (int i = 0; i < t; i++) + for (int offset = 0; offset < t; offset++) { - randomness[i].randomize(this->send_prngs[i]); - if (i == P.my_num()) - this->shares.push_back(randomness[i]); + int i = P.get_player(1 + offset); + assert(i != P.my_num()); + randomness[offset].randomize(this->send_prngs[i]); } for (int i = threshold; i < n; i++) { + int player = P.get_player(1 + i); typename T::open_type x = input * reconstruction.at(i - 
threshold).at(0); for (int j = 0; j < t; j++) x += randomness[j] * reconstruction.at(i - threshold).at(j + 1); - if (i == P.my_num()) + if (player == P.my_num()) this->shares.push_back(x); else - x.pack(this->os[i]); + x.pack(this->os[player]); } this->senders[P.my_num()] = true; @@ -94,7 +97,7 @@ template void ShamirInput::finalize_other(int player, T& target, octetStream& o, int n_bits) { - if (this->P.my_num() < threshold) + if (this->P.get_offset(player) >= this->P.num_players() - threshold) target.randomize(this->recv_prngs.at(player)); else IndividualInput::finalize_other(player, target, o, n_bits); diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index bd0cc3176..8e2bafdbe 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -73,7 +73,7 @@ class ShamirMC : public IndirectShamirMC void Check(const Player& P) { (void)P; } - vector get_reconstruction(const Player& P); + vector get_reconstruction(const Player& P, int n = 0); open_type reconstruct(const vector& shares); }; diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index 585a6896b..db50c327a 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -36,14 +36,15 @@ void ShamirMC::POpen_Begin(vector& values, template vector ShamirMC::get_reconstruction( - const Player& P) + const Player& P, int n_relevant_players) { - int n_relevant_players = threshold + 1; + if (n_relevant_players == 0) + n_relevant_players = threshold + 1; vector reconstruction(n_relevant_players); + vector points(n_relevant_players); for (int i = 0; i < n_relevant_players; i++) - reconstruction[i] = Shamir::get_rec_factor(P.get_player(i), - P.num_players(), P.my_num(), n_relevant_players); - return reconstruction; + points[i] = P.get_player(i); + return Shamir::get_rec_factors(points); } template diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index 318f050dd..efc7e45f8 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -50,7 +50,7 @@ class ShamirShare : public T, 
public ShareInterface const static bool dishonest_majority = false; const static bool variable_players = true; const static bool expensive = false; - const static bool malicious = true; + const static bool malicious = false; static string type_short() { diff --git a/Protocols/Spdz2kPrep.hpp b/Protocols/Spdz2kPrep.hpp index 815277614..4b76904b7 100644 --- a/Protocols/Spdz2kPrep.hpp +++ b/Protocols/Spdz2kPrep.hpp @@ -32,8 +32,8 @@ Spdz2kPrep::~Spdz2kPrep() { if (bit_prep != 0) { - delete bit_prep; delete bit_proc; + delete bit_prep; delete bit_MC; } } @@ -51,7 +51,6 @@ void Spdz2kPrep::set_protocol(typename T::Protocol& protocol) bit_prep->params.amplify = false; bit_proc = new SubProcessor(*bit_MC, *bit_prep, proc->P); bit_MC->set_prep(*bit_prep); - this->proc->MC.set_prep(*this); } template diff --git a/Protocols/edabit.h b/Protocols/edabit.h index d893277ae..d84a0e94c 100644 --- a/Protocols/edabit.h +++ b/Protocols/edabit.h @@ -91,7 +91,8 @@ class edabitvec { for (size_t i = 0; i < x.second.size(); i++) { - b[i] ^= typename T::bit_type::part_type(x.second[i]) << a.size(); + b[i] ^= (typename T::bit_type::part_type(x.second[i]) + ^ b[i].get_bit(a.size())) << a.size(); } a.push_back(x.first); } diff --git a/Protocols/fake-stuff.h b/Protocols/fake-stuff.h index 8bca59536..2d4b7f88f 100644 --- a/Protocols/fake-stuff.h +++ b/Protocols/fake-stuff.h @@ -49,7 +49,7 @@ class Files Files(int N, const typename T::mac_type& key, const string& prep_data_prefix, Dtype type, int thread_num = -1) : Files(N, key, - get_prep_sub_dir(prep_data_prefix, N) + get_prep_sub_dir(prep_data_prefix, N, true) + DataPositions::dtype_names[type] + "-" + T::type_short(), thread_num) { diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index d9b6122de..4cfdb93f0 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -295,9 +295,13 @@ void write_mac_key(const string& directory, int i, int nplayers, U key) #ifdef VERBOSE cout << "Writing to " << filename.str().c_str() 
<< endl; #endif + if (not directory.empty()) + mkdir_p(directory.c_str()); outf.open(filename.str().c_str()); outf << nplayers << endl; key.output(outf,true); + if (outf.fail()) + throw IO_Error("cannot write to " + filename.str()); outf.close(); } @@ -469,16 +473,16 @@ void generate_mac_keys(typename T::mac_share_type::open_type& key, key.assign_zero(); int tmpN = 0; ifstream inpf; - prep_data_prefix = get_prep_sub_dir(prep_data_prefix, nplayers); + prep_data_prefix = get_prep_sub_dir(prep_data_prefix, nplayers, true); bool generate = false; vector key_shares(nplayers); for (int i = 0; i < nplayers; i++) { auto& pp = key_shares[i]; - stringstream filename; - filename << mac_filename(prep_data_prefix, i); - inpf.open(filename.str().c_str()); + string filename; + filename = mac_filename(prep_data_prefix, i); + inpf.open(filename); if (inpf.fail()) { inpf.close(); @@ -509,15 +513,15 @@ void generate_mac_keys(typename T::mac_share_type::open_type& key, for (int i = 0; i < nplayers; i++) { auto& pp = key_shares[i]; - stringstream filename; - filename - << mac_filename(prep_data_prefix, i); - ofstream outf(filename.str().c_str()); + string filename; + filename = mac_filename(prep_data_prefix, + i); + ofstream outf(filename); if (outf.fail()) - throw file_error(filename.str().c_str()); + throw file_error(filename); outf << nplayers << " " << pp << endl; outf.close(); - cout << "Written new MAC key share to " << filename.str() << endl; + cout << "Written new MAC key share to " << filename << endl; cout << " Key " << i << ": " << pp << endl; } } diff --git a/README.md b/README.md index 8f55f5581..9132ca069 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ parties and malicious security. On Linux, this requires a working toolchain and [all requirements](#requirements). 
On Ubuntu, the following might suffice: ``` -sudo apt-get install automake build-essential clang cmake git libboost-dev libboost-thread-dev libntl-dev libsodium-dev libssl-dev libtool m4 python3 texinfo yasm +sudo apt-get install automake build-essential clang cmake git libboost-dev libboost-thread-dev libgmp-dev libntl-dev libsodium-dev libssl-dev libtool python3 ``` On MacOS, this requires [brew](https://brew.sh) to be installed, which will be used for all dependencies. @@ -285,7 +285,9 @@ compute the preprocessing time for a particular computation. libOTe also requires boost of version at least 1.75, which is not available by default on relatively recent systems such as Ubuntu 22.04. You can install it locally by running `make boost`. - - MPIR library, compiled with C++ support (use flag `--enable-cxx` when running configure). You can use `make -j8 mpir` to install it locally. + - GMP library, compiled with C++ support (use flag `--enable-cxx` + when running configure). Tested against 6.2.1 as supplied by + Ubuntu. - libsodium library, tested against 1.0.18 - OpenSSL, tested against 3.0.2 - Boost.Asio with SSL support (`libboost-dev` on Ubuntu), tested against 1.81 @@ -311,6 +313,8 @@ compute the preprocessing time for a particular computation. [GCC documentation](https://gcc.gnu.org/onlinedocs/gcc/x86-Options.html) for the possible options. + To run on CPUs without AVX2 (CPUs from before 2014), you should + also add `AVX_OT = 0` to `CONFIG.mine`. - For optimal results on Linux on ARM, add `ARCH = -march=armv8.2-a+crypto` to `CONFIG.mine`. This enables the hardware support for AES. See the [GCC documentation](https://gcc.gnu.org/onlinedocs/gcc/AArch64-Options.html#AArch64-Options) on available options. @@ -322,6 +326,8 @@ compute the preprocessing time for a particular computation. `SECURE = -DINSECURE` to `CONFIG.mine`. This is necessary with GCC 5 and 6 because these compilers don't support the C++ standard used by libOTe.
+ - On macOS, there have been issues with non-system compilers. Add + `CXX = /usr/bin/g++` to fix them. 2. Run `make` to compile all the software (use the flag `-j` for faster compilation using multiple threads). See below on how to compile specific diff --git a/Scripts/compile-emulate.py b/Scripts/compile-emulate.py index 4f346c325..5d5fbd2f7 100755 --- a/Scripts/compile-emulate.py +++ b/Scripts/compile-emulate.py @@ -2,7 +2,7 @@ import os, sys -sys.path.append('.') +sys.path.insert(0, os.path.dirname(sys.argv[0]) + '/..') from Compiler.compilerLib import Compiler diff --git a/Scripts/compile-run.py b/Scripts/compile-run.py index d7f2711b3..70aff1bf6 100755 --- a/Scripts/compile-run.py +++ b/Scripts/compile-run.py @@ -2,7 +2,7 @@ import os, sys -sys.path.append('.') +sys.path.insert(0, os.path.dirname(sys.argv[0]) + '/..') from Compiler.compilerLib import Compiler diff --git a/Scripts/emulate.sh b/Scripts/emulate.sh index 9585c85ca..98d84fde3 100755 --- a/Scripts/emulate.sh +++ b/Scripts/emulate.sh @@ -1,7 +1,9 @@ #!/bin/bash -. $(dirname $0)/run-common.sh +dir="$(dirname $0)" +. 
"$dir"/run-common.sh prog=${1%.sch} prog=${prog##*/} shift -$prefix ./emulate.x $prog $* 2>&1 | tee logs/emulate-$prog +mkdir logs 2> /dev/null +$prefix "$dir"/../emulate.x $prog $* 2>&1 | tee logs/emulate-$prog diff --git a/Scripts/list-field-protocols.sh b/Scripts/list-field-protocols.sh new file mode 100755 index 000000000..16052052c --- /dev/null +++ b/Scripts/list-field-protocols.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +echo rep-field shamir mal-rep-field ps-rep-field sy-rep-field \ + atlas mal-shamir sy-shamir semi hemi temi mascot cowgear chaigear diff --git a/Scripts/list-ring-protocols.sh b/Scripts/list-ring-protocols.sh new file mode 100755 index 000000000..9491d066d --- /dev/null +++ b/Scripts/list-ring-protocols.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +echo ring semi2k brain mal-rep-ring ps-rep-ring sy-rep-ring \ + spdz2k rep4-ring diff --git a/Scripts/memory-usage.py b/Scripts/memory-usage.py index 1977fc2c6..d5026eaa3 100755 --- a/Scripts/memory-usage.py +++ b/Scripts/memory-usage.py @@ -20,7 +20,7 @@ def process(tapename, res, regs): t = inst.type if issubclass(t, DirectMemoryInstruction): res[t.arg_format[0]] = max(inst.args[1].i + inst.size, - res[t.arg_format[0]]) + res[t.arg_format[0]]) + 1 for arg in inst.args: if isinstance(arg, RegisterArgFormat): regs[type(arg)] = max(regs[type(arg)], arg.i + inst.size) @@ -36,6 +36,27 @@ def process(tapename, res, regs): regout = lambda regs: dict((reverse_formats[t], n) for t, n in regs.items()) -print ('Memory:', dict(res)) -print ('Registers in main thread:', regout(regs)) -print ('Registers in other threads:', regout(thread_regs)) +def output(data): + for t, n in data.items(): + if n: + try: + print('%10d %s' % (n, ArgFormats[t.removesuffix('w')].name)) + except: + pass + +total = 0 +for x in res, regs, thread_regs: + total += sum(x.values()) + +print ('Memory:') +output(res) + +print ('Registers in main thread:') +output(regout(regs)) + +if thread_regs: + print ('Registers in other threads:') + 
output(regout(thread_regs)) + +print ('The program requires at the very least %f GB of RAM per party.' % \ + (total * 8e-9)) diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index fe3c54e71..8ceef7385 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -31,8 +31,8 @@ run_player() { prog=${prog##*/} prog=${prog%.sch} shift - if ! test -e $SPDZROOT/logs; then - mkdir $SPDZROOT/logs + if ! test -e logs; then + mkdir logs fi params="$prog $* -pn $port -h localhost" if $SPDZROOT/$bin 2>&1 | grep -q '^-N,'; then @@ -47,7 +47,7 @@ run_player() { set -o pipefail for i in $(seq 0 $[players-1]); do >&2 echo Running $prefix $SPDZROOT/$bin $i $params - log=$SPDZROOT/logs/$log_prefix$i + log=logs/$log_prefix$i $prefix $SPDZROOT/$bin $i $params 2>&1 | { if test "$BENCH"; then @@ -58,11 +58,27 @@ run_player() { } & codes[$i]=$! done + ctrlc() + { + pkill -P $$ + } + trap ctrlc SIGINT for i in $(seq 0 $[players-1]); do - wait ${codes[$i]} || return 1 + if ! wait ${codes[$i]}; then + for i in $(seq 1 $[players-1]); do + echo === Party $i + tail -n 3 logs/$log_prefix$i + done + return 1 + fi done } +getopts N: opt $(getopt N: $* 2>/dev/null) +if test "$opt" = N; then + PLAYERS=$OPTARG +fi + players=${PLAYERS:-2} SPDZROOT=${SPDZROOT:-.} diff --git a/Tools/Buffer.cpp b/Tools/Buffer.cpp index f3e67c82f..e64c8e461 100644 --- a/Tools/Buffer.cpp +++ b/Tools/Buffer.cpp @@ -103,6 +103,7 @@ void BufferBase::prune() string tmp_name = filename + ".new"; ofstream tmp(tmp_name.c_str()); size_t start = file->tellg(); + start -= element_length() * (BUFFER_SIZE - next); char buf[header_length]; file->seekg(0); file->read(buf, header_length); diff --git a/Tools/Buffer.h b/Tools/Buffer.h index ffd411233..d564c64a8 100644 --- a/Tools/Buffer.h +++ b/Tools/Buffer.h @@ -33,6 +33,8 @@ class BufferBase string filename; int header_length; + virtual int element_length() = 0; + public: bool eof; @@ -59,6 +61,8 @@ class Buffer : public BufferBase void read(char* read_buffer); + int 
element_length() { return T::size(); } + public: virtual ~Buffer(); virtual ifstream* open(); @@ -109,7 +113,7 @@ class BufferOwner : public Buffer BufferOwner(const BufferOwner& other) : file(0) { - assert(other.file == 0); + *this = other; } ~BufferOwner() @@ -117,9 +121,18 @@ class BufferOwner : public Buffer close(); } + BufferOwner& operator=(const BufferOwner& other) + { + assert(other.file == 0); + file = 0; + Buffer::operator=(other); + return *this; + } + ifstream* open() { file = new ifstream(this->filename, ios::in | ios::binary); + BufferBase::file = file; if (file->good()) { auto file_spec = check_file_signature(*file, this->filename); diff --git a/Tools/FlexBuffer.cpp b/Tools/FlexBuffer.cpp index 6c663cb63..bc8bf29a0 100644 --- a/Tools/FlexBuffer.cpp +++ b/Tools/FlexBuffer.cpp @@ -40,7 +40,7 @@ void ReceivedMsgStore::push(ReceivedMsg& msg) else { char filename[1000]; - sprintf(filename, "%s/%d.XXXXXX", BUFFER_DIR, getpid()); + snprintf(filename, 1000, "%s/%d.XXXXXX", BUFFER_DIR, getpid()); FILE* file = fdopen(mkstemp(filename), "w"); if (!file) throw runtime_error("can't open file, check space on " diff --git a/Tools/FlexBuffer.h b/Tools/FlexBuffer.h index 932e15e0a..484fbb0bb 100644 --- a/Tools/FlexBuffer.h +++ b/Tools/FlexBuffer.h @@ -231,7 +231,7 @@ inline void SendBuffer::serialize_no_allocate(const void* source, size_t size) inline void ReceivedMsg::check_buffer(size_t size) { (void)size; -#ifdef CHECK_BUFFER +#ifndef NO_CHECK_BUFFER if (ptr + size > buf + len) throw overflow_error("not enough data in buffer"); #endif diff --git a/Tools/TimerWithComm.cpp b/Tools/TimerWithComm.cpp index 2a5e8e12a..c6d751058 100644 --- a/Tools/TimerWithComm.cpp +++ b/Tools/TimerWithComm.cpp @@ -5,6 +5,15 @@ #include "TimerWithComm.h" +TimerWithComm::TimerWithComm() +{ +} + +TimerWithComm::TimerWithComm(const Timer& other) : + Timer(other) +{ +} + void TimerWithComm::start(const NamedCommStats& stats) { Timer::start(); @@ -17,7 +26,58 @@ void 
TimerWithComm::stop(const NamedCommStats& stats) total_stats += stats - last_stats; } -double TimerWithComm::mb_sent() +double TimerWithComm::mb_sent() const { return total_stats.sent * 1e-6; } + +size_t TimerWithComm::rounds() const +{ + size_t res = 0; + for (auto& x : total_stats) + res += x.second.rounds; + return res; +} + +TimerWithComm TimerWithComm::operator +(const TimerWithComm& other) +{ + TimerWithComm res = *this; + res += other; + return res; +} + +TimerWithComm TimerWithComm::operator -(const TimerWithComm& other) +{ + TimerWithComm res = *this; + res.Timer::operator-=(other); + res.total_stats = total_stats - other.total_stats; + return res; +} + +TimerWithComm& TimerWithComm::operator +=(const TimerWithComm& other) +{ + Timer::operator+=(other); + total_stats += other.total_stats; + return *this; +} + +TimerWithComm& TimerWithComm::operator -=(const TimerWithComm& other) +{ + *this = *this - other; + return *this; +} + +string TimerWithComm::full() +{ + stringstream tmp; + tmp << elapsed() << " seconds"; + if (mb_sent() > 0) + tmp << " (" << *this << ")"; + return tmp.str(); +} + +ostream& operator<<(ostream& os, const TimerWithComm& stats) +{ + os << stats.mb_sent() << " MB, " << stats.rounds() << " rounds"; + return os; +} diff --git a/Tools/TimerWithComm.h b/Tools/TimerWithComm.h index 2f3976a20..04ac72864 100644 --- a/Tools/TimerWithComm.h +++ b/Tools/TimerWithComm.h @@ -14,10 +14,23 @@ class TimerWithComm : public Timer NamedCommStats total_stats, last_stats; public: + TimerWithComm(); + TimerWithComm(const Timer& other); + void start(const NamedCommStats& stats = {}); void stop(const NamedCommStats& stats = {}); - double mb_sent(); + double mb_sent() const; + size_t rounds() const; + + TimerWithComm operator+(const TimerWithComm& other); + TimerWithComm operator-(const TimerWithComm& other); + TimerWithComm& operator+=(const TimerWithComm& other); + TimerWithComm& operator-=(const TimerWithComm& other); + + string full(); + + friend ostream& 
operator<<(ostream& os, const TimerWithComm& stats); }; #endif /* TOOLS_TIMERWITHCOMM_H_ */ diff --git a/Tools/random.h b/Tools/random.h index 373fcd64d..3b16cdb2f 100644 --- a/Tools/random.h +++ b/Tools/random.h @@ -7,7 +7,7 @@ #include "Tools/avx_memcpy.h" #include "Networking/data.h" -#include +#include #define USE_AES diff --git a/Tools/time-func.cpp b/Tools/time-func.cpp index 7b52f9ef9..e52024f13 100644 --- a/Tools/time-func.cpp +++ b/Tools/time-func.cpp @@ -64,6 +64,11 @@ double Timer::idle() return convert_ns_to_seconds(elapsed_since_last_start()); } +bool Timer::is_running() +{ + return running; +} + Timer& Timer::operator -=(const Timer& other) { assert(clock_id == other.clock_id); @@ -89,3 +94,10 @@ Timer& Timer::operator +=(const TimeScope& other) elapsed_time += other.timer.elapsed_since_last_start(); return *this; } + +Timer Timer::operator +(const Timer& other) const +{ + Timer res = *this; + res += other; + return res; +} diff --git a/Tools/time-func.h b/Tools/time-func.h index e37fc6233..17d65f44b 100644 --- a/Tools/time-func.h +++ b/Tools/time-func.h @@ -27,10 +27,14 @@ class Timer double elapsed_then_reset(); double idle(); + bool is_running(); + Timer& operator-=(const Timer& other); Timer& operator+=(const Timer& other); Timer& operator+=(const TimeScope& other); + Timer operator+(const Timer& other) const; + private: timespec startv; bool running; diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index b5d593c73..2a38f7f44 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -189,13 +189,15 @@ void make_inputs(const typename T::mac_type& key,int N,int ntrip,const string& s /* Generate Inputs */ for (int player=0; player(prep_data_prefix, N) << "Inputs-" - << T::type_short() << "-P" << i << "-" << player; - cout << "Opening " << filename.str() << endl; - outf[i].open(filename.str().c_str(),ios::out | ios::binary); + { + string filename = PrepBase::get_input_filename( + get_prep_sub_dir(prep_data_prefix, N), 
T::type_short(), player, + i); + cout << "Opening " << filename << endl; + outf[i].open(filename, ios::out | ios::binary); file_signature().output(outf[i]); - if (outf[i].fail()) { throw file_error(filename.str().c_str()); } + if (outf[i].fail()) + throw file_error(filename); } for (int i=0; i -void make_Sbox(const typename T::mac_type& key, int N, int ntrip, bool zero, T, gf2n_short) +void make_Sbox(const typename T::mac_type& key, int N, int ntrip, bool zero, T, true_type) { make_AES(key, N, ntrip, zero); make_DES(key, N, ntrip, zero); } -template -void make_Sbox(const typename T::mac_type& key, int N, int ntrip, bool zero, T, U) +template +void make_Sbox(const typename T::mac_type& key, int N, int ntrip, bool zero, T, false_type) { (void)key, (void)N, (void)ntrip, (void)zero; } @@ -345,7 +347,7 @@ void make_Sbox(const typename T::mac_type& key, int N, int ntrip, bool zero, T, template void make_Sbox(const typename T::mac_type& key, int N, int ntrip, bool zero) { - make_Sbox(key, N, ntrip, zero, T(), typename T::clear()); + make_Sbox(key, N, ntrip, zero, T(), T::clear::characteristic_two); } template @@ -777,6 +779,8 @@ int FakeParams::generate() generate_field(T::clear::prime_field); generate_field(true_type()); + if (gf2n::degree() != gf2n_short::degree()) + generate_field(true_type()); // default generate_ring<64>(); @@ -807,8 +811,11 @@ void FakeParams::generate_field(true_type) make_basic>({}, nplayers, default_num, zero); make_with_mac_key>>(nplayers, default_num, zero); } + else if (nplayers == 4) + make_basic>({}, nplayers, default_num, zero); make_basic>({}, nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, zero); if (nplayers > 2) { diff --git a/azure-pipelines.yml b/azure-pipelines.yml index f026f3005..3f54d385e 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -12,9 +12,9 @@ pool: steps: - script: | - bash -c "sudo apt-get update && sudo apt-get install libsodium-dev libntl-dev yasm texinfo libboost-dev 
libboost-thread-dev python3-gmpy2 python3-networkx python3-sphinx clang" + bash -c "sudo apt-get update && sudo apt-get install libsodium-dev libntl-dev python3-gmpy2 python3-networkx" - script: | - make boost libote mpir + make boost libote - script: echo USE_NTL=1 >> CONFIG.mine - script: diff --git a/deps/mpir b/deps/mpir deleted file mode 160000 index 55fe6a9f5..000000000 --- a/deps/mpir +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 55fe6a9f52ca532a611a89f67186ed915bbf1123 diff --git a/doc/io.rst b/doc/io.rst index bd6b4db88..053e93bf9 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -48,7 +48,8 @@ By default, :py:func:`~Compiler.library.print_ln` and related functions only output to the terminal on party 0. This allows to run several parties in one terminal without spoiling the output. You can use interactive mode with option ``-I`` in order to output on all -parties. Note that this also to reading inputs from the command line +parties or ``-OF .`` to activate the output without interactive mode. +Note that the former also causes inputs to be read from the command line unless you specify ``-IF`` as well. You can also specify a file prefix with ``-OF``, so that outputs are written to ``-P-``. diff --git a/doc/machine-learning.rst b/doc/machine-learning.rst index 0a44cdc84..d59ed343e 100644 --- a/doc/machine-learning.rst +++ b/doc/machine-learning.rst @@ -310,6 +310,9 @@ and then trains it:: validation_data=(test_samples, test_labels) ) +See ``Programs/Source/keras_*.mpc`` for further examples using the +Keras interface. + Decision trees ============== @@ -436,7 +439,11 @@ and used in MP-SPDZ:: n_correct, loss = optimizer.reveal_correctness(test_samples, test_labels, 128, running=True) print_ln('Secure accuracy: %s/%s', n_correct, len(test_samples)) -This outputs the accuracy of the network. +This outputs the accuracy of the network.
You can use +:py:func:`~Compiler.ml.Optimizer.eval` instead of +:py:func:`~Compiler.ml.Optimizer.reveal_correctness` to retrieve +probability distributions or top guesses (the latter with ``top=True``) +for any sample data. Storing and loading models diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index f2295bc9e..1c6ff2eec 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -6,8 +6,8 @@ Troubleshooting This section shows how to solve some common issues. -Crash without error message or ``bad_alloc`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Crash without error message, ``Killed``, or ``bad_alloc`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Some protocols require several gigabytes of memory, and the virtual machine will crash if there is not enough RAM. You can reduce the @@ -55,14 +55,26 @@ introduction and :py:func:`~Compiler.types.sfix.set_precision` for how to change the precision. +Variable results when using :py:class:`~Compiler.types.sfix` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This is caused by the usage of probabilistic rounding, which is used to +restore the representation after a multiplication. See `Catrina and Saxena +`_ for details. You can switch +to deterministic rounding by calling ``sfix.round_nearest = True``. + + Order of memory instructions not preserved ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ By default, the compiler runs optimizations that in some corner case can introduce errors with memory accesses such as accessing an -:py:class:`~Compiler.types.Array`. If you encounter such errors, you -can fix this either with ``-M`` when compiling or placing -`break_point()` around memory accesses. +:py:class:`~Compiler.types.Array`. The error message does not +necessarily mean there will be errors, but the compiler cannot +guarantee that there will not.
If you encounter such errors, you +can fix this either with ``-M`` when compiling or by enabling memory +protection (:py:func:`~Compiler.program.Program.protect_memory`) +around specific memory accesses. Odd timings @@ -164,7 +176,8 @@ Required prime bit length is not the same as ``-F`` parameter during compilation ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This is related to statistical masking that requires the prime to be a -fair bit larger than the actual "payload". The technique goes to back +fair bit larger than the actual "payload" (40 by default). +The technique goes back to `Catrina and de Hoogh `_. See also the paragraph on unknown prime moduli in :ref:`nonlinear`.