From b5d8123ae08e6921cbdfa2b5decbe7608da3414e Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 30 Apr 2019 17:24:11 +1000 Subject: [PATCH] BMR, Use your Brain, Semi/Semi2k. --- Auth/MAC_Check.hpp | 10 - Auth/MaliciousRepMC.h | 2 - Auth/MaliciousRepMC.hpp | 1 + Auth/SemiMC.h | 29 ++ Auth/SemiMC.hpp | 24 ++ Auth/ShamirMC.hpp | 22 +- Auth/Subroutines.cpp | 2 + Auth/fake-stuff.h | 7 +- Auth/fake-stuff.hpp | 41 +++ BMR/CommonParty.cpp | 111 ++----- BMR/CommonParty.h | 41 ++- BMR/CommonParty.hpp | 77 +++++ BMR/GarbledGate.cpp | 25 +- BMR/GarbledGate.h | 17 +- BMR/Key.h | 15 +- BMR/Machine.cpp | 40 +-- BMR/Party.cpp | 159 ++++------ BMR/Party.h | 148 ++++++--- BMR/Program.cpp | 27 -- BMR/ProgramParty.hpp | 106 +++++++ BMR/RealGarbleWire.h | 56 ++++ BMR/RealGarbleWire.hpp | 215 +++++++++++++ BMR/RealProgramParty.h | 72 +++++ BMR/RealProgramParty.hpp | 236 ++++++++++++++ BMR/Register.cpp | 290 +++--------------- BMR/Register.h | 39 ++- BMR/Register.hpp | 238 ++++++++++++++ BMR/SpdzWire.cpp | 24 -- BMR/SpdzWire.h | 27 +- BMR/TrustedParty.cpp | 19 +- BMR/TrustedParty.h | 3 +- BMR/aes.cpp | 38 ++- BMR/config.h | 6 +- BMR/network/Node.cpp | 4 +- BMR/network/Node.h | 2 +- BMR/network/Server.cpp | 5 + BMR/network/Server.h | 5 +- CHANGELOG.md | 8 + CONFIG | 9 +- Check-Offline.cpp | 1 + Compiler/GC/instructions.py | 5 + Compiler/GC/types.py | 4 + Compiler/dijkstra.py | 3 + Compiler/instructions.py | 6 +- Compiler/oram.py | 2 +- Compiler/path_oram.py | 3 + Compiler/program.py | 5 +- Compiler/types.py | 35 ++- Exceptions/Exceptions.h | 2 +- Fake-Offline.cpp | 15 +- GC/FakeSecret.h | 5 +- GC/Instruction.h | 5 +- GC/Instruction.hpp | 21 +- GC/Instruction_inline.h | 4 +- GC/Machine.cpp | 1 + GC/Machine.h | 10 +- GC/Machine.hpp | 26 +- GC/MaliciousRepSecret.h | 2 +- GC/Memory.h | 13 + GC/Processor.h | 21 +- GC/Processor.hpp | 38 ++- GC/Program.h | 6 +- GC/Program.hpp | 14 +- GC/ReplicatedParty.cpp | 2 + GC/ReplicatedSecret.cpp | 3 +- GC/ReplicatedSecret.h | 8 +- GC/RuntimeBranching.h | 36 +++ GC/Secret.h | 18 +- GC/Secret.hpp | 10 +- GC/Thread.hpp | 5 +- GC/ThreadMaster.h | 2 +- GC/ThreadMaster.hpp | 6 +- GC/instructions.h | 14 +- GC/square64.cpp | 57 ++-- Machines/Rep.cpp | 33 +- Machines/SPDZ.cpp | 31 +- Machines/Semi.cpp | 25 ++ Machines/ShamirMachine.cpp | 37 +-- Makefile | 28 +- Math/BrainShare.h | 49 +++ Math/FixedVec.h | 21 +- Math/Integer.cpp | 11 - Math/Integer.h | 6 +- Math/MaliciousRep3Share.h | 8 +- Math/MaliciousShamirShare.h | 6 +- Math/Rep3Share.h | 11 +- Math/Semi2kShare.h | 46 +++ Math/SemiShare.h | 113 +++++++ Math/ShamirShare.h | 22 +- Math/Share.h | 10 +- Math/Spdz2kShare.h | 10 +- Math/Z2k.cpp | 43 +++ Math/Z2k.h | 118 ++++++- Math/Zp_Data.cpp | 17 +- Math/Zp_Data.h | 50 ++- Math/bigint.cpp | 40 ++- Math/bigint.h | 55 +++- Math/gf2n.cpp | 5 +- Math/gf2n.h | 8 +- Math/gf2nlong.cpp | 10 +- Math/gf2nlong.h | 35 ++- Math/gfp.cpp | 7 +- Math/gfp.h | 25 +- Math/modp.cpp | 18 +- Math/modp.h | 10 +- Math/mpn_fixed.h | 84 ++++- Networking/Player.cpp | 12 +- Networking/Player.h | 6 +- Networking/sockets.h | 7 +- OT/BaseOT.cpp | 3 + OT/BitMatrix.cpp | 21 +- OT/NPartyTripleGenerator.cpp | 248 ++++++++++----- OT/NPartyTripleGenerator.h | 79 +++-- OT/OTExtension.cpp | 6 - OT/OTExtensionWithMatrix.cpp | 24 +- OT/OTMultiplier.cpp | 23 +- OT/OTMultiplier.h | 37 ++- OT/TripleMachine.cpp | 5 + OT/TripleMachine.h | 2 + Player-Online.cpp | 2 +- Player-Online.hpp | 8 +- Processor/BaseMachine.cpp | 5 + Processor/BaseMachine.h | 2 + Processor/Beaver.h | 2 + Processor/Binary_File_IO.hpp | 2 - Processor/BrainPrep.h | 21 ++ Processor/BrainPrep.hpp | 149 +++++++++ Processor/Data_Files.h | 1 + Processor/Data_Files.hpp | 9 +- Processor/Input.h | 27 +- Processor/Input.hpp | 72 ++++- Processor/Instruction.hpp | 9 +- Processor/Machine.hpp | 31 +- Processor/MaliciousRepPrep.h | 1 - Processor/MascotPrep.h | 21 +- Processor/MascotPrep.hpp | 17 +- Processor/Memory.hpp | 3 - Processor/NoLivePrep.h | 26 ++ Processor/Online-Thread.hpp | 3 - Processor/PrivateOutput.h | 10 +- .../{PrivateOutput.cpp => PrivateOutput.hpp} | 9 +- Processor/Processor.h | 1 + Processor/Processor.hpp | 2 +- Processor/Replicated.h | 2 + Processor/Replicated.hpp | 5 - Processor/ReplicatedInput.h | 8 +- Processor/ReplicatedInput.hpp | 28 +- Processor/ReplicatedMachine.hpp | 26 +- Processor/ReplicatedPrep.h | 5 + Processor/ReplicatedPrep.hpp | 21 +- Processor/RingOptions.cpp | 26 ++ Processor/RingOptions.h | 19 ++ Processor/SemiInput.h | 33 ++ Processor/SemiInput.hpp | 27 ++ Processor/SemiPrep.h | 21 ++ Processor/SemiPrep.hpp | 31 ++ Processor/Shamir.h | 1 + Processor/Shamir.hpp | 28 +- Processor/ShamirInput.h | 38 ++- Processor/ShamirInput.hpp | 11 +- Programs/Source/test_gc.mpc | 11 +- Programs/Source/tutorial.mpc | 27 +- README.md | 103 +++++-- Scripts/brain.sh | 10 + Scripts/build.sh | 8 +- Scripts/fake-spdz-real-bmr.sh | 8 + Scripts/mal-rep-bmr.sh | 10 + Scripts/mal-shamir-bmr.sh | 10 + Scripts/real-bmr.sh | 8 + Scripts/rep-bmr.sh | 10 + Scripts/ring.sh | 2 +- Scripts/run-common.sh | 3 +- Scripts/semi.sh | 12 + Scripts/semi2k.sh | 8 + Scripts/shamir-bmr.sh | 10 + Scripts/test_tutorial.sh | 28 ++ Scripts/tldr.sh | 2 - Scripts/yao.sh | 2 +- Tools/FlexBuffer.h | 1 + Tools/MMO.cpp | 2 + Tools/NetworkOptions.cpp | 33 ++ Tools/NetworkOptions.h | 22 ++ Tools/aes-ni.cpp | 86 +++--- Tools/aes.h | 60 ++-- Tools/avx_memcpy.h | 8 +- Tools/cpu_support.h | 71 +++++ Tools/random.cpp | 4 +- Yao/Machine.cpp | 36 --- Yao/Program.cpp | 23 -- Yao/YaoCommon.h | 5 +- Yao/YaoEvalMaster.cpp | 11 +- Yao/YaoEvalMaster.h | 4 +- Yao/YaoEvalWire.cpp | 13 + Yao/YaoEvalWire.h | 2 + Yao/YaoEvaluator.cpp | 22 +- Yao/YaoEvaluator.h | 4 +- Yao/YaoGarbleMaster.cpp | 11 +- Yao/YaoGarbleMaster.h | 4 +- Yao/YaoGarbleWire.cpp | 13 + Yao/YaoGarbleWire.h | 2 + Yao/YaoGarbler.cpp | 31 +- Yao/YaoGarbler.h | 5 +- Yao/YaoPlayer.cpp | 15 +- bmr-program-party.cpp | 4 +- brain-party.cpp | 33 ++ mal-rep-bmr-party.cpp | 13 + mal-shamir-bmr-party.cpp | 14 + real-bmr-party.cpp | 13 + rep-bmr-party.cpp | 13 + replicated-ring-party.cpp | 17 +- semi-party.cpp | 15 + semi2k-party.cpp | 27 ++ shamir-bmr-party.cpp | 14 + spdz2k-party.cpp | 4 +- 214 files changed, 4223 insertions(+), 1540 deletions(-) create mode 100644 Auth/SemiMC.h create mode 100644 Auth/SemiMC.hpp create mode 100644 BMR/CommonParty.hpp delete mode 100644 BMR/Program.cpp create mode 100644 BMR/ProgramParty.hpp create mode 100644 BMR/RealGarbleWire.h create mode 100644 BMR/RealGarbleWire.hpp create mode 100644 BMR/RealProgramParty.h create mode 100644 BMR/RealProgramParty.hpp create mode 100644 BMR/Register.hpp delete mode 100644 BMR/SpdzWire.cpp create mode 100644 GC/RuntimeBranching.h create mode 100644 Machines/Semi.cpp create mode 100644 Math/BrainShare.h create mode 100644 Math/Semi2kShare.h create mode 100644 Math/SemiShare.h create mode 100644 Processor/BrainPrep.h create mode 100644 Processor/BrainPrep.hpp create mode 100644 Processor/NoLivePrep.h rename Processor/{PrivateOutput.cpp => PrivateOutput.hpp} (78%) create mode 100644 Processor/RingOptions.cpp create mode 100644 Processor/RingOptions.h create mode 100644 Processor/SemiInput.h create mode 100644 Processor/SemiInput.hpp create mode 100644 Processor/SemiPrep.h create mode 100644 Processor/SemiPrep.hpp create mode 100755 Scripts/brain.sh create mode 100755 Scripts/fake-spdz-real-bmr.sh create mode 100755 Scripts/mal-rep-bmr.sh create mode 100755 Scripts/mal-shamir-bmr.sh create mode 100755 Scripts/real-bmr.sh create mode 100755 Scripts/rep-bmr.sh create mode 100755 Scripts/semi.sh create mode 100755 Scripts/semi2k.sh create mode 100755 Scripts/shamir-bmr.sh create mode 100755 Scripts/test_tutorial.sh create mode 100644 Tools/NetworkOptions.cpp create mode 100644 Tools/NetworkOptions.h create mode 100644 Tools/cpu_support.h delete mode 100644 Yao/Machine.cpp delete mode 100644 Yao/Program.cpp create mode 100644 brain-party.cpp create mode 100644 mal-rep-bmr-party.cpp create mode 100644 mal-shamir-bmr-party.cpp create mode 100644 real-bmr-party.cpp create mode 100644 rep-bmr-party.cpp create mode 100644 semi-party.cpp create mode 100644 semi2k-party.cpp create mode 100644 shamir-bmr-party.cpp diff --git a/Auth/MAC_Check.hpp b/Auth/MAC_Check.hpp index d01b2c8e9..de1922c79 100644 --- a/Auth/MAC_Check.hpp +++ b/Auth/MAC_Check.hpp @@ -8,16 +8,6 @@ #include "Tools/int.h" #include "Tools/benchmarking.h" -#include "Math/gfp.h" -#include "Math/gf2n.h" -#include "Math/BitVec.h" -#include "Math/Rep3Share.h" -#include "Math/MaliciousRep3Share.h" -#include "Math/ShamirShare.h" -#include "Math/MaliciousShamirShare.h" -#include "Math/Z2k.h" -#include "Math/Spdz2kShare.h" - #include template diff --git a/Auth/MaliciousRepMC.h b/Auth/MaliciousRepMC.h index bfd019c10..5313b90d3 100644 --- a/Auth/MaliciousRepMC.h +++ b/Auth/MaliciousRepMC.h @@ -7,8 +7,6 @@ #define AUTH_MALICIOUSREPMC_H_ #include "ReplicatedMC.h" -#include "GC/MaliciousRepSecret.h" -#include "GC/Machine.h" template class MaliciousRepMC : public ReplicatedMC diff --git a/Auth/MaliciousRepMC.hpp b/Auth/MaliciousRepMC.hpp index 6b0cad192..1422d3f5d 100644 --- a/Auth/MaliciousRepMC.hpp +++ b/Auth/MaliciousRepMC.hpp @@ -5,6 +5,7 @@ #include "MaliciousRepMC.h" #include "GC/Machine.h" +#include "Math/BitVec.h" #include "ReplicatedMC.hpp" diff --git a/Auth/SemiMC.h b/Auth/SemiMC.h new file mode 100644 index 000000000..525e207c4 --- /dev/null +++ b/Auth/SemiMC.h @@ -0,0 +1,29 @@ +/* + * SemiMC.h + * + */ + +#ifndef AUTH_SEMIMC_H_ +#define AUTH_SEMIMC_H_ + +#include "MAC_Check.h" + +template +class SemiMC : public TreeSum, public MAC_Check_Base +{ +public: + // emulate MAC_Check + SemiMC(const typename T::mac_key_type& _ = {}, int __ = 0, int ___ = 0) + { (void)_; (void)__; (void)___; } + + // emulate Direct_MAC_Check + SemiMC(const typename T::mac_key_type& _, Names& ____, int __ = 0, int ___ = 0) + { (void)_; (void)__; (void)___; (void)____; } + + void POpen_Begin(vector& values,const vector& S,const Player& P); + void POpen_End(vector& values,const vector& S,const Player& P); + + void Check(const Player& P) { (void)P; } +}; + +#endif /* AUTH_SEMIMC_H_ */ diff --git a/Auth/SemiMC.hpp b/Auth/SemiMC.hpp new file mode 100644 index 000000000..928da99e3 --- /dev/null +++ b/Auth/SemiMC.hpp @@ -0,0 +1,24 @@ +/* + * SemiMC.cpp + * + */ + +#include "SemiMC.h" + +template +void SemiMC::POpen_Begin(vector& values, + const vector& S, const Player& P) +{ + values.clear(); + for (auto& x : S) + values.push_back(x); + this->start(values, P); +} + +template +void SemiMC::POpen_End(vector& values, + const vector& S, const Player& P) +{ + (void) S; + this->finish(values, P); +} diff --git a/Auth/ShamirMC.hpp b/Auth/ShamirMC.hpp index 41840ebfb..f2cd6f52d 100644 --- a/Auth/ShamirMC.hpp +++ b/Auth/ShamirMC.hpp @@ -12,17 +12,25 @@ void ShamirMC::POpen_Begin(vector& values, (void) values; os.clear(); os.resize(P.num_players()); - if (P.my_num() <= threshold) + bool send = P.my_num() <= threshold; + if (send) { for (auto& share : S) share.pack(os[P.my_num()]); - for (int i = 0; i < P.num_players(); i++) - if (i != P.my_num()) - P.send_to(i, os[P.my_num()], true); } - for (int i = 0; i <= threshold; i++) - if (i != P.my_num()) - P.receive_player(i, os[i], true); + for (int offset = 1; offset < P.num_players(); offset++) + { + int send_to = P.get_player(offset); + int receive_from = P.get_player(-offset); + bool receive = receive_from <= threshold; + if (send) + if (receive) + P.pass_around(os[P.my_num()], os[receive_from], offset); + else + P.send_to(send_to, os[P.my_num()], true); + else if (receive) + P.receive_player(receive_from, os[receive_from], true); + } } template diff --git a/Auth/Subroutines.cpp b/Auth/Subroutines.cpp index 0f3a03674..2bba3f248 100644 --- a/Auth/Subroutines.cpp +++ b/Auth/Subroutines.cpp @@ -236,3 +236,5 @@ template void Create_Random(gf2n_short& ans,const Player& P); #endif template void Create_Random(gfp& ans,const Player& P); +template void Create_Random(gfp1& ans,const Player& P); +template void Create_Random(gfp2& ans,const Player& P); diff --git a/Auth/fake-stuff.h b/Auth/fake-stuff.h index e4af6c82c..a99db93c2 100644 --- a/Auth/fake-stuff.h +++ b/Auth/fake-stuff.h @@ -4,10 +4,6 @@ #include "Math/gf2n.h" #include "Math/gfp.h" -#include "Math/Z2k.h" -#include "Math/Share.h" -#include "Math/Rep3Share.h" -#include "GC/MaliciousRepSecret.h" #include using namespace std; @@ -29,6 +25,9 @@ void generate_keys(const string& directory, int nplayers); template void write_mac_keys(const string& directory, int player_num, int nplayers, U keyp, T key2); +template +void read_mac_keys(const string& directory, int player_num, int nplayers, U& keyp, T& key2); + // Read MAC key shares and compute keys void read_keys(const string& directory, gfp& keyp, gf2n& key2, int nplayers); diff --git a/Auth/fake-stuff.hpp b/Auth/fake-stuff.hpp index 44820ceb1..d9507036a 100644 --- a/Auth/fake-stuff.hpp +++ b/Auth/fake-stuff.hpp @@ -3,6 +3,7 @@ #include "Math/gfp.h" #include "Math/Z2k.h" #include "Math/Share.h" +#include "Math/SemiShare.h" #include "Auth/fake-stuff.h" #include "Tools/benchmarking.h" #include "Processor/config.h" @@ -29,6 +30,21 @@ void make_share(Share* Sa,const U& a,int N,const V& key,PRNG& G) Sa[N-1]=S; } +template +void make_share(SemiShare* Sa,const T& a,int N,const T& key,PRNG& G) +{ + (void) key; + insecure("share generation", false); + T x, S = a; + for (int i=0; i void make_share(FixedVec* Sa, const T& a, int N, const T& key, PRNG& G); @@ -154,6 +170,31 @@ void write_mac_keys(const string& directory, int i, int nplayers, U macp, T mac2 outf.close(); } +template +void read_mac_keys(const string& directory, int player_num, int nplayers, U& keyp, T& key2) +{ + int nn; + + string filename = directory + "Player-MAC-Keys-P" + to_string(player_num); + ifstream inpf; + inpf.open(filename); + if (inpf.fail()) + { + cerr << "Could not open MAC key file. Perhaps it needs to be generated?\n"; + throw file_error(filename); + } + inpf >> nn; + if (nn!=nplayers) + { cerr << "KeyGen was last run with " << nn << " players." << endl; + cerr << " - You are running Online with " << nplayers << " players." << endl; + exit(1); + } + + keyp.input(inpf,true); + key2.input(inpf,true); + inpf.close(); +} + inline void read_keys(const string& directory, gfp& keyp, gf2n& key2, int nplayers) { gfp sharep; diff --git a/BMR/CommonParty.cpp b/BMR/CommonParty.cpp index 84c4cee74..8add42bc6 100644 --- a/BMR/CommonParty.cpp +++ b/BMR/CommonParty.cpp @@ -9,11 +9,16 @@ CommonParty* CommonParty::singleton = 0; -CommonParty::CommonParty() : - _node(0), gate_counter(0), gate_counter2(0), garbled_tbl_size(0), - cpu_timer(CLOCK_PROCESS_CPUTIME_ID), buffers(TYPE_MAX) +CommonFakeParty::CommonFakeParty() : + _node(0), buffers(TYPE_MAX) { insecure("MPC emulation"); +} + +CommonParty::CommonParty() : + gate_counter(0), gate_counter2(0), garbled_tbl_size(0), + cpu_timer(CLOCK_PROCESS_CPUTIME_ID) +{ if (singleton != 0) throw runtime_error("there can only be one"); singleton = this; @@ -29,20 +34,27 @@ CommonParty::CommonParty() : mac_key.randomize(prng); } +CommonFakeParty::~CommonFakeParty() +{ + if (_node) + delete _node; +} + CommonParty::~CommonParty() { - if (_node) - delete _node; - cout << "Wire storage: " << 1e-9 * wires.capacity() << " GB" << endl; - cout << "CPU time: " << cpu_timer.elapsed() << endl; - cout << "Total time: " << timer.elapsed() << endl; - cout << "First phase time: " << timers[0].elapsed() << endl; - cout << "Second phase time: " << timers[1].elapsed() << endl; - cout << "Number of gates: " << gate_counter << endl; + cerr << "Total time: " << timer.elapsed() << endl; +#ifdef VERBOSE + cerr << "Wire storage: " << 1e-9 * wires.capacity() << " GB" << endl; + cerr << "CPU time: " << cpu_timer.elapsed() << endl; + cerr << "First phase time: " << timers[0].elapsed() << endl; + cerr << "Second phase time: " << timers[1].elapsed() << endl; + cerr << "Number of gates: " << gate_counter << endl; +#endif } -void CommonParty::init(const char* netmap_file, int id, int n_parties) +void CommonParty::check(int n_parties) { + (void) n_parties; #ifdef N_PARTIES if (n_parties != N_PARTIES) throw runtime_error("wrong number of parties"); @@ -53,6 +65,11 @@ void CommonParty::init(const char* netmap_file, int id, int n_parties) #endif _N = n_parties; #endif // N_PARTIES +} + +void CommonFakeParty::init(const char* netmap_file, int id, int n_parties) +{ + check(n_parties); printf("netmap_file: %s\n", netmap_file); if (0 == strcmp(netmap_file, LOOPBACK_STR)) { _node = new Node( NULL, id, this, _N + 1); @@ -61,7 +78,7 @@ void CommonParty::init(const char* netmap_file, int id, int n_parties) } } -int CommonParty::init(const char* netmap_file, int id) +int CommonFakeParty::init(const char* netmap_file, int id) { int n_parties; if (string(netmap_file) != string(LOOPBACK_STR)) @@ -93,7 +110,7 @@ void CommonParty::next_gate(GarbledGate& gate) gate.init_inputs(gate_counter2, _N); } -SendBuffer& CommonParty::get_buffer(MSG_TYPE type) +SendBuffer& CommonFakeParty::get_buffer(MSG_TYPE type) { SendBuffer& buffer = buffers[type]; buffer.clear(); @@ -122,52 +139,6 @@ void CommonCircuitParty::print_outputs(const vector& indices) } -template -GC::BreakType CommonParty::first_phase(GC::Program& program, - GC::Processor& processor, GC::Machine& machine) -{ - (void)machine; - timers[0].start(); - reset(); - wires.clear(); - GC::BreakType next = (reinterpret_cast*>(&program))->execute(processor); -#ifdef DEBUG_ROUNDS - cout << "finished first phase at pc " << processor.PC - << " reason " << next << endl; -#endif - timers[0].stop(); - cout << "First round time: " << timers[0].elapsed() << " / " - << timer.elapsed() << endl; -#ifdef DEBUG_WIRES - cout << "Storing wires with " << 1e-9 * wires.size() << " GB on disk" << endl; -#endif - wire_storage.push(wires); - return next; -} - -template -GC::BreakType CommonParty::second_phase(GC::Program& program, - GC::Processor& processor, GC::Machine& machine) -{ - (void)machine; - wire_storage.pop(wires); - wires.reset_head(); - timers[1].start(); - GC::BreakType next = GC::TIME_BREAK; - next = program.execute(processor); -#ifdef DEBUG_ROUNDS - cout << "finished second phase at " << processor.PC - << " reason " << next << endl; -#endif - timers[1].stop(); -// cout << "Second round time: " << timers[1].elapsed() << ", "; -// cout << "total time: " << timer.elapsed() << endl; - if (false) - return GC::CAP_BREAK; - else - return next; -} - void CommonCircuitParty::prepare_input_regs(party_id_t from) { party_t sender = _circuit->_parties[from]; @@ -188,23 +159,3 @@ void CommonCircuitParty::prepare_output_regs() for (size_t i = 0; i < _OW; i++) output_regs.push_back(_circuit->OutWiresStart()+i); } - -template GC::BreakType CommonParty::first_phase( - GC::Program >& program, - GC::Processor >& processor, - GC::Machine >& machine); - -template GC::BreakType CommonParty::first_phase( - GC::Program >& program, - GC::Processor >& processor, - GC::Machine >& machine); - -template GC::BreakType CommonParty::second_phase( - GC::Program >& program, - GC::Processor >& processor, - GC::Machine >& machine); - -template GC::BreakType CommonParty::second_phase( - GC::Program >& program, - GC::Processor >& processor, - GC::Machine >& machine); diff --git a/BMR/CommonParty.h b/BMR/CommonParty.h index 771048fad..b6cb1d614 100644 --- a/BMR/CommonParty.h +++ b/BMR/CommonParty.h @@ -50,7 +50,7 @@ namespace GC template class Machine; } -class CommonParty : public NodeUpdatable +class CommonParty { protected: friend class Register; @@ -60,8 +60,6 @@ class CommonParty : public NodeUpdatable #else party_id_t _N; #endif - Node* _node; - int gate_counter, gate_counter2; int garbled_tbl_size; @@ -71,24 +69,20 @@ class CommonParty : public NodeUpdatable gf2n mac_key; - mutex global_lock; - LocalBuffer wires; ReceivedMsgStore wire_storage; template GC::BreakType first_phase(GC::Program& program, GC::Processor& processor, GC::Machine& machine); - template + template GC::BreakType second_phase(GC::Program& program, GC::Processor& processor, - GC::Machine& machine); + GC::Machine& machine, U& dynamic_memory); public: static CommonParty* singleton; static CommonParty& s(); - vector buffers; - PRNG prng; CommonParty(); @@ -100,25 +94,40 @@ class CommonParty : public NodeUpdatable static int get_n_parties() { return s()._N; } #endif - void init(const char* netmap_file, int id, int n_parties); - int init(const char* netmap_file, int id); - virtual void reset(); + void check(int n_parties); - virtual party_id_t get_id() { return -1; } + virtual void reset(); gate_id_t new_gate(); void next_gate(GarbledGate& gate); gate_id_t next_gate(int skip) { return gate_counter2 += skip; } size_t get_garbled_tbl_size() { return garbled_tbl_size; } - SendBuffer& get_buffer(MSG_TYPE type); - gf2n get_mac_key() { return mac_key; } }; +class CommonFakeParty : virtual public CommonParty, public NodeUpdatable +{ +protected: + Node* _node; + + mutex global_lock; + +public: + CommonFakeParty(); + virtual ~CommonFakeParty(); + + vector buffers; + + void init(const char* netmap_file, int id, int n_parties); + int init(const char* netmap_file, int id); + + SendBuffer& get_buffer(MSG_TYPE type); +}; + class BooleanCircuit; -class CommonCircuitParty : virtual public CommonParty +class CommonCircuitParty : virtual public CommonFakeParty { protected: BooleanCircuit* _circuit; diff --git a/BMR/CommonParty.hpp b/BMR/CommonParty.hpp new file mode 100644 index 000000000..1051a2751 --- /dev/null +++ b/BMR/CommonParty.hpp @@ -0,0 +1,77 @@ +/* + * CommonParty.hpp + * + */ + +#ifndef BMR_COMMONPARTY_HPP_ +#define BMR_COMMONPARTY_HPP_ + +#include "CommonParty.h" + +template +GC::BreakType CommonParty::first_phase(GC::Program& program, + GC::Processor& processor, GC::Machine& machine) +{ + (void)machine; + timers[0].start(); + reset(); + wires.clear(); + NoMemory dynamic_memory; + GC::BreakType next; + try + { + next = (reinterpret_cast*>(&program))->execute(processor, dynamic_memory); + } + catch (needs_cleaning& e) + { + next = GC::CLEANING_BREAK; + processor.PC--; + } +#ifdef DEBUG_ROUNDS + cout << "finished first phase at pc " << processor.PC + << " reason " << next << endl; +#endif + timers[0].stop(); +#ifdef VERBOSE + cerr << "First round time: " << timers[0].elapsed() << " / " + << timer.elapsed() << endl; +#endif +#ifdef DEBUG_WIRES + cout << "Storing wires with " << 1e-9 * wires.size() << " GB on disk" << endl; +#endif + wire_storage.push(wires); + return next; +} + +template +GC::BreakType CommonParty::second_phase(GC::Program& program, + GC::Processor& processor, GC::Machine& machine, + U& dynamic_memory) +{ + (void)machine; + wire_storage.pop(wires); + wires.reset_head(); + timers[1].start(); + GC::BreakType next = GC::TIME_BREAK; + try + { + next = program.execute(processor, dynamic_memory); + } + catch (needs_cleaning& e) + { + next = GC::CLEANING_BREAK; + } +#ifdef DEBUG_ROUNDS + cout << "finished second phase at " << processor.PC + << " reason " << next << endl; +#endif + timers[1].stop(); +// cout << "Second round time: " << timers[1].elapsed() << ", "; +// cout << "total time: " << timer.elapsed() << endl; + if (false) + return GC::CAP_BREAK; + else + return next; +} + +#endif /* BMR_COMMONPARTY_HPP_ */ diff --git a/BMR/GarbledGate.cpp b/BMR/GarbledGate.cpp index 6d601e02a..68ef7562d 100644 --- a/BMR/GarbledGate.cpp +++ b/BMR/GarbledGate.cpp @@ -36,11 +36,10 @@ void GarbledGate::init_inputs(gate_id_t g, int n_parties) } void GarbledGate::compute_prfs_outputs(const Register** in_wires, int my_id, - SendBuffer& buffer, gate_id_t g) + PRFOutputs& prf_output, gate_id_t g) { int n_parties = CommonParty::get_n_parties(); init_inputs(g, n_parties); - PRFOutputs prf_output(n_parties); for(int w=0; w<=1; w++) { for (int b=0; b<=1; b++) { const Key& key = in_wires[w]->key(my_id, b); @@ -51,7 +50,7 @@ void GarbledGate::compute_prfs_outputs(const Register** in_wires, int my_id, #endif for (int e=0; e<=1; e++) { for (int j=1; j<= n_parties; j++) { - prf_output[my_id-1][j-1].outputs[w][b][e][0] = + prf_output[j-1].outputs[w][b][e][0] = aes_128_encrypt(*(__m128i*)input(e, j), (octet*)aes_key.rd_key); #ifdef __PRIME_FIELD__ ((Key*)prf_outputs_index)->adjust(); @@ -60,14 +59,28 @@ void GarbledGate::compute_prfs_outputs(const Register** in_wires, int my_id, } } } - for (int i = 0; i < n_parties; i++) - buffer.serialize(prf_output[my_id - 1][i]); +} + +void GarbledGate::compute_prfs_outputs(const Register** in_wires, int my_id, + SendBuffer& buffer, gate_id_t g) +{ + int n_parties = CommonParty::get_n_parties(); + PRFOutputs prf_output(n_parties); + compute_prfs_outputs(in_wires, my_id, prf_output, g); + prf_output.serialize(buffer, my_id, n_parties); #ifdef DEBUG wire_id_t wire_ids[] = { (wire_id_t)in_wires[0]->get_id(), (wire_id_t)in_wires[1]->get_id() }; prf_output.print_prfs(g, wire_ids, my_id, n_parties); #endif } +void PRFOutputs::serialize(SendBuffer& buffer, int my_id, int n_parties) +{ + (void) my_id; + for (int i = 0; i < n_parties; i++) + buffer.serialize(tuples[i]); +} + void PRFOutputs::print_prfs(gate_id_t g, wire_id_t* in_wires, party_id_t my_id, int n_parties) { for(int w=0; w<=1; w++) { @@ -75,7 +88,7 @@ void PRFOutputs::print_prfs(gate_id_t g, wire_id_t* in_wires, party_id_t my_id, for (int e=0; e<=1; e++) { for(party_id_t j=1; j<=(size_t)n_parties; j++) { printf("F_k^%d_{%lu,%u}(%d,%lu,%u) = ", my_id, in_wires[w], b, e, g, j); - Key k = *((Key*)(*this)[my_id-1][j-1].outputs[w][b][e]); + Key k = *((Key*)(*this)[j-1].outputs[w][b][e]); std::cout << k << std::endl; } } diff --git a/BMR/GarbledGate.h b/BMR/GarbledGate.h index 840f90e34..62438f26d 100644 --- a/BMR/GarbledGate.h +++ b/BMR/GarbledGate.h @@ -11,6 +11,13 @@ struct PRFTuple { Key outputs[2][2][2][1]; + // i = 0..3 + Key for_garbling(int i) + { + int a = i / 2; + int b = i % 2; + return outputs[0][a][b][0] ^ outputs[1][b][a][0]; + } }; /* @@ -27,16 +34,15 @@ struct PRFTuple { */ struct PRFOutputs { #ifdef MAX_N_PARTIES - PRFTuple tuples[MAX_N_PARTIES][MAX_N_PARTIES]; + PRFTuple tuples[MAX_N_PARTIES]; PRFOutputs(int n_parties) { (void)n_parties; } - PRFTuple* operator[](int i) { return tuples[i]; } #else - int n_parties; vector tuples; - PRFOutputs(int n_parties) : n_parties(n_parties), tuples(n_parties * n_parties) {} - PRFTuple* operator[](int i) { return &tuples[i*n_parties]; } + PRFOutputs(int n_parties) : tuples(n_parties) {} #endif + PRFTuple& operator[](int i) { return tuples[i]; } + void serialize(SendBuffer& buffer, int my_id, int n_parties); void print_prfs(gate_id_t g, wire_id_t* in_wires, party_id_t my_id, int n_parties); }; @@ -82,6 +88,7 @@ class GarbledGate : public KeyTuple<4> { char* input(int e, party_id_t j) { return (char*)&prf_inputs[e][j-1]; } void compute_prfs_outputs(const Register** in_wires, int my_id, SendBuffer& buffer, gate_id_t g); + void compute_prfs_outputs(const Register** in_wires, int my_id, PRFOutputs& outputs, gate_id_t g); void print(); }; diff --git a/BMR/Key.h b/BMR/Key.h index b902f1d93..eb12b396f 100644 --- a/BMR/Key.h +++ b/BMR/Key.h @@ -89,13 +89,16 @@ inline void Key::set_signal(bool signal) inline Key Key::doubling(int i) const { #ifdef __AVX2__ - return _mm_sllv_epi64(r, _mm_set_epi64x(i, i)); -#else - uint64_t halfs[2]; - halfs[1] = _mm_cvtsi128_si64(_mm_unpackhi_epi64(r, r)) << i; - halfs[0] = _mm_cvtsi128_si64(r) << i; - return _mm_loadu_si128((__m128i*)halfs); + if (cpu_has_avx2()) + return _mm_sllv_epi64(r, _mm_set_epi64x(i, i)); + else #endif + { + uint64_t halfs[2]; + halfs[1] = _mm_cvtsi128_si64(_mm_unpackhi_epi64(r, r)) << i; + halfs[0] = _mm_cvtsi128_si64(r) << i; + return _mm_loadu_si128((__m128i*)halfs); + } } diff --git a/BMR/Machine.cpp b/BMR/Machine.cpp index 935fc6d06..9c917412d 100644 --- a/BMR/Machine.cpp +++ b/BMR/Machine.cpp @@ -6,11 +6,15 @@ #include "BMR/CommonParty.h" #include "BMR/Register_inline.h" +#include "BMR/Register.hpp" #include "GC/Machine.hpp" #include "GC/Processor.hpp" #include "GC/Secret.hpp" #include "GC/Thread.hpp" #include "GC/ThreadMaster.hpp" +#include "GC/Program.hpp" +#include "GC/Instruction.hpp" +#include "Processor/Instruction.hpp" namespace GC { @@ -58,8 +62,8 @@ void Secret::store(Memory& mem, size_t address) mac_mask.random(mac_length, mask_share.mac); word masked; int128 masked_mac; - (*this + mask).reveal(masked); - (mac + mac_mask).reveal(masked_mac); + (*this + mask).reveal(length, masked); + (mac + mac_mask).reveal(mac_length, masked_mac); #ifdef DEBUG_DYNAMIC word a,b; int128 c,d; @@ -84,7 +88,7 @@ void Secret::load(int n, const Memory& mem, size_t address) mac_key = reconstruct(CommonParty::s().get_mac_key().get(), default_length); check_mac = carryless_mult(*this, mac_key); int128 result; - (mac + check_mac).reveal(result); + (mac + check_mac).reveal(2 * default_length, result); #ifdef DEBUG_DYNAMIC cout << "loading " << hex << x.share << " " << x.mac << endl; int128 a; @@ -96,34 +100,4 @@ void Secret::load(int n, const Memory& mem, size_t address) T::check(result, x.share, x.mac); } -template class Secret; -template class Secret; -template class Secret; -template class Secret; - -template void Secret::reveal(Clear& x); -template void Secret::reveal(Clear& x); -template void Secret::reveal(Clear& x); -template void Secret::reveal(Clear& x); - -template class Machine< Secret >; -template class Machine< Secret >; -template class Machine< Secret >; -template class Machine< Secret >; - -template class Processor< Secret >; -template class Processor< Secret >; -template class Processor< Secret >; -template class Processor< Secret >; - -template class Thread< Secret >; -template class Thread< Secret >; -template class Thread< Secret >; -template class Thread< Secret >; - -template class ThreadMaster< Secret >; -template class ThreadMaster< Secret >; -template class ThreadMaster< Secret >; -template class ThreadMaster< Secret >; - } diff --git a/BMR/Party.cpp b/BMR/Party.cpp index 50d0d9806..3b76e5070 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -19,7 +19,20 @@ #include "BooleanCircuit.h" #include "Math/Setup.h" +#include "Register_inline.h" + +#include "CommonParty.hpp" +#include "ProgramParty.hpp" #include "Auth/MAC_Check.hpp" +#include "BMR/Register.hpp" +#include "GC/Machine.hpp" +#include "GC/Processor.hpp" +#include "GC/Secret.hpp" +#include "GC/Thread.hpp" +#include "GC/ThreadMaster.hpp" +#include "GC/Program.hpp" +#include "GC/Instruction.hpp" +#include "Processor/Instruction.hpp" #ifdef __PURE_SHE__ #include "mpirxx.h" @@ -28,7 +41,7 @@ ProgramParty* ProgramParty::singleton = 0; -BaseParty::BaseParty(party_id_t id) : _id(id) +BaseParty::BaseParty() { #ifdef DEBUG_PRNG_PARTY octet seed[SEED_SIZE]; @@ -48,11 +61,12 @@ Party::Party(const char* netmap_file, // required to init Node const std::string input, int numthreads, int numtries - ) :BaseParty(id), + ) :BaseParty(), _all_input(input), _NUMTHREADS(numthreads), _NUMTRIES(numtries) { + _id = id; _circuit = new BooleanCircuit( circuit_file ); _circuit->party = this; _G = _circuit->NumGates(); @@ -715,13 +729,22 @@ void BaseParty::done() { _node->Stop(); } -ProgramParty::ProgramParty(int argc, char** argv) : - BaseParty(-1), keys_for_prf(0), +ProgramParty::ProgramParty() : spdz_storage(0), garbled_storage(0), spdz_counters(SPDZ_OP_N), - machine(dynamic_memory), - processor(machine), prf_machine(dynamic_memory), + processor(machine), prf_processor(prf_machine), - MC(0) + P(0) +{ + if (singleton) + throw runtime_error("there can only be one"); + singleton = this; + threshold = 128; + eval_threads = new Worker[N_EVAL_THREADS]; + and_jobs.resize(N_EVAL_THREADS); +} + +FakeProgramParty::FakeProgramParty(int argc, const char** argv) : + keys_for_prf(0) { if (argc < 3) { @@ -729,16 +752,10 @@ ProgramParty::ProgramParty(int argc, char** argv) : exit(1); } + load(argv[2]); _id = atoi(argv[1]); - program.parse(string(argv[2]) + "-0"); - machine.reset(program); - processor.reset(program); processor.open_input_file("user_inputs/user_" + to_string(_id - 1) + "_input.txt"); - prf_machine.reset(*reinterpret_cast >* >(&program)); - prf_processor.reset(*reinterpret_cast >* >(&program)); - if (singleton) - throw runtime_error("there can only be one"); - singleton = this; + if (argc > 3) { int n_parties = init(argv[3], _id); @@ -759,9 +776,6 @@ ProgramParty::ProgramParty(int argc, char** argv) : int n_parties = init("LOOPBACK", _id); N.init(_id - 1, 5000, vector(n_parties, "localhost")); } - prf_output = (char*)new __m128i[PAD_TO_8(get_n_parties())]; - mac_key = prng.get_word() & ((1ULL << GC::Secret::default_length) - 1); - cout << "MAC key: " << hex << mac_key << endl; ifstream schfile((string("Programs/Schedules/") + argv[2] + ".sch").c_str()); string curr, prev; while (schfile.good()) @@ -773,30 +787,39 @@ ProgramParty::ProgramParty(int argc, char** argv) : P = new PlainPlayer(N, 0); if (argc > 4) threshold = atoi(argv[4]); - else - threshold = 128; cout << "Threshold for multi-threaded evaluation: " << threshold << endl; - eval_threads = new Worker[N_EVAL_THREADS]; - and_jobs.resize(N_EVAL_THREADS); } ProgramParty::~ProgramParty() { reset(); - delete[] prf_output; - delete P; - if (MC) - delete MC; + if (P) + { + cerr << "Data sent: " << 1e-6 * P->comm_stats.total_data() << " MB" << endl; + delete P; + } delete[] eval_threads; - cout << "SPDZ loading: " << spdz_counters[SPDZ_LOAD] << endl; - cout << "SPDZ storing: " << spdz_counters[SPDZ_STORE] << endl; - cout << "SPDZ wire storage: " << 1e-9 * spdz_storage << " GB" << endl; - cout << "Dynamic storage: " << 1e-9 * dynamic_memory.capacity() * - sizeof(GC::Secret::DynamicType) << " GB" << endl; - cout << "Maximum circuit storage: " << 1e-9 * garbled_storage << " GB" << endl; +#ifdef VERBOSE + if (spdz_counters[SPDZ_LOAD]) + cerr << "SPDZ loading: " << spdz_counters[SPDZ_LOAD] << endl; + if (spdz_counters[SPDZ_STORE]) + cerr << "SPDZ storing: " << spdz_counters[SPDZ_STORE] << endl; + if (spdz_storage) + cerr << "SPDZ wire storage: " << 1e-9 * spdz_storage << " GB" << endl; + cerr << "Maximum circuit storage: " << 1e-9 * garbled_storage << " GB" << endl; +#endif +} + +FakeProgramParty::~FakeProgramParty() +{ +#ifdef VERBOSE + if (dynamic_memory.capacity_in_bytes()) + cerr << "Dynamic storage: " << 1e-9 * dynamic_memory.capacity_in_bytes() + << " GB" << endl; +#endif } -void ProgramParty::_compute_prfs_outputs(Key* keys) +void FakeProgramParty::_compute_prfs_outputs(Key* keys) { keys_for_prf = keys; first_phase(program, prf_processor, prf_machine); @@ -835,45 +858,7 @@ void ProgramParty::start_online_round() _check_evaluate(); } -void ProgramParty::_check_evaluate() -{ -#ifdef DEBUG_REGS - print_round_regs(); -#endif - cout << "Online time at evaluation start: " << online_timer.elapsed() - << endl; - GC::BreakType next = GC::TIME_BREAK; - while (next == GC::TIME_BREAK) - { - load_garbled_circuit(); - next = second_phase(program, processor, machine); - } - cout << "Online time at evaluation stop: " << online_timer.elapsed() - << endl; - if (next == GC::TIME_BREAK) - { -#ifdef DEBUG_STEPS - cout << "another round of garbling" << endl; -#endif - } - if (next != GC::DONE_BREAK) - { -#ifdef DEBUG_STEPS - cout << "another round of evaluation" << endl; -#endif - start_online_round(); - } - else - { - Timer timer; - timer.start(); - MC->Check(*P); - cout << "Final check took " << timer.elapsed() << endl; - done(); - } -} - -void ProgramParty::receive_keys(Register& reg) +void FakeProgramParty::receive_keys(Register& reg) { reg.init(_N); for (int i = 0; i < 2; i++) @@ -885,14 +870,21 @@ void ProgramParty::receive_keys(Register& reg) #endif } -void ProgramParty::receive_all_keys(Register& reg, bool external) +void FakeProgramParty::receive_all_keys(Register& reg, bool external) { reg.init(get_n_parties()); for (int i = 0; i < get_n_parties(); i++) reg.keys[external][i] = *(keys_for_prf++); } -void ProgramParty::receive_spdz_wires(ReceivedMsg& msg) +void FakeProgramParty::process_prf_output(PRFOutputs& prf_output, + PRFRegister* wire, const PRFRegister* left, const PRFRegister* right) +{ + (void) wire, (void) left, (void) right; + prf_output.serialize(buffers[TYPE_PRF_OUTPUTS], _id, get_n_parties()); +} + +void FakeProgramParty::receive_spdz_wires(ReceivedMsg& msg) { int op; msg.unserialize(op); @@ -917,25 +909,6 @@ void ProgramParty::receive_spdz_wires(ReceivedMsg& msg) } } -void ProgramParty::get_spdz_wire(SpdzOp op, SpdzWire& spdz_wire) -{ - while (true) - { - if (spdz_wires[op].empty()) - throw runtime_error("no SPDZ wires available"); - if (spdz_wires[op].front().done()) - spdz_wires[op].pop_front(); - else - break; - } - spdz_wire.unpack(spdz_wires[op].front(), get_n_parties()); - spdz_counters[op]++; -#ifdef DEBUG_SPDZ_WIRE - cout << "get SPDZ wire of type " << op << ", " << spdz_wires[op].front().left() << " bytes left" << endl; - cout << "mask share for " << get_id() << ": " << spdz_wire.mask << endl; -#endif -} - void ProgramParty::store_wire(const Register& reg) { wires.serialize(reg.key(get_id(), 0)); diff --git a/BMR/Party.h b/BMR/Party.h index a0a701260..312e5daa5 100644 --- a/BMR/Party.h +++ b/BMR/Party.h @@ -21,6 +21,7 @@ #include "GC/Program.h" #include "GC/Processor.h" #include "GC/Secret.h" +#include "GC/RuntimeBranching.h" #include "Tools/Worker.h" class BooleanCircuit; @@ -40,9 +41,27 @@ typedef struct { unsigned long long acc=0; } exec_props_t; -class BaseParty : virtual public CommonParty { +class PartyProperties +{ +protected: + party_id_t _id; + + Timer online_timer; + + Key delta; + +public: + PartyProperties() : _id(-1) {} + + party_id_t get_id() { return _id; } + Key get_delta() { return delta; } + +}; + +class BaseParty : virtual public CommonFakeParty, virtual public PartyProperties +{ public: - BaseParty(party_id_t id); + BaseParty(); virtual ~BaseParty(); /* From NodeUpdatable class */ @@ -52,19 +71,10 @@ class BaseParty : virtual public CommonParty { void Start(); - party_id_t get_id() { return _id; } - Key get_delta() { return delta; } - protected: - party_id_t _id; - // int _num_evaluation_threads; struct timeval _start_online_net, _end_online_net; - Timer online_timer; - - Key delta; - virtual void _compute_prfs_outputs(Key* keys) = 0; void _send_prfs(); @@ -165,14 +175,14 @@ class Party : public BaseParty, public CommonCircuitParty { int get_n_inputs(); }; -class ProgramParty : public BaseParty +class ProgramParty : virtual public CommonParty, virtual public PartyProperties, public GC::RuntimeBranching { +protected: friend class PRFRegister; friend class EvalRegister; friend class Register; - char* prf_output; - Key* keys_for_prf; + vector prf_output; deque spdz_wires[SPDZ_OP_N]; size_t spdz_storage; @@ -185,7 +195,6 @@ class ProgramParty : public BaseParty ReceivedMsgStore output_masks_store; ReceivedMsgStore input_masks_store; - GC::Memory< GC::Secret::DynamicType > dynamic_memory; GC::Machine< GC::Secret > machine; GC::Processor > processor; GC::Program > program; @@ -193,24 +202,16 @@ class ProgramParty : public BaseParty GC::Machine< GC::Secret > prf_machine; GC::Processor > prf_processor; - void _compute_prfs_outputs(Key* keys); - - void _process_external_received(char* externals, - party_id_t from) { (void)externals; (void)from; } - void _process_all_external_received(char* externals) { (void)externals; } - void _process_input_keys(Key* keys, party_id_t from) - { (void)keys; (void)from; } - void _process_all_input_keys(char* keys) { (void)keys; } - void store_garbled_circuit(ReceivedMsg& msg); void load_garbled_circuit(); - void _check_evaluate(); - - void receive_keys(Register& reg); - void receive_all_keys(Register& reg, bool external); + virtual void _check_evaluate() = 0; + virtual void done() = 0; - void receive_spdz_wires(ReceivedMsg& msg); + virtual void receive_keys(Register& reg) = 0; + virtual void receive_all_keys(Register& reg, bool external) = 0; + virtual void process_prf_output(PRFOutputs& prf_output, + PRFRegister* out, const PRFRegister* left, const PRFRegister* right) = 0; void start_online_round(); @@ -220,31 +221,95 @@ class ProgramParty : public BaseParty public: static ProgramParty* singleton; - ReceivedMsg garbled_circuit; + LocalBuffer garbled_circuit; ReceivedMsgStore garbled_circuits; - ReceivedMsg output_masks; - ReceivedMsg input_masks; + LocalBuffer output_masks; + LocalBuffer input_masks; - MAC_Check* MC; Player* P; Names N; int threshold; + Integer convcbit; + static ProgramParty& s(); - ProgramParty(int argc, char** argv); - ~ProgramParty(); + ProgramParty(); + virtual ~ProgramParty(); void reset(); - void get_spdz_wire(SpdzOp op, SpdzWire& spdz_wire); - void store_wire(const Register& reg); void load_wire(Register& reg); }; +template +class ProgramPartySpec : public ProgramParty +{ + static ProgramPartySpec* singleton; + +protected: + GC::Memory dynamic_memory; + + void _check_evaluate(); + +public: + typename T::MAC_Check* MC; + + static ProgramPartySpec& s(); + + ProgramPartySpec(); + ~ProgramPartySpec(); + + void load(string progname); + + void get_spdz_wire(SpdzOp op, DualWire& spdz_wire); +}; + +#ifdef SPDZ_AUTH +typedef ProgramPartySpec> FakeProgramPartySuper; +#else +typedef ProgramPartySpec> FakeProgramPartySuper; +#endif + +class FakeProgramParty : virtual public BaseParty, virtual public FakeProgramPartySuper +{ + Key* keys_for_prf; + + void _compute_prfs_outputs(Key* keys); + + void _process_external_received(char* externals, + party_id_t from) { (void)externals; (void)from; } + void _process_all_external_received(char* externals) { (void)externals; } + void _process_input_keys(Key* keys, party_id_t from) + { (void)keys; (void)from; } + void _process_all_input_keys(char* keys) { (void)keys; } + + void store_garbled_circuit(ReceivedMsg& msg) { ProgramParty::store_garbled_circuit(msg); } + + void _check_evaluate() { FakeProgramPartySuper::_check_evaluate(); } + + void receive_keys(Register& reg); + void receive_all_keys(Register& reg, bool external); + void process_prf_output(PRFOutputs& prf_output, PRFRegister* out, + const PRFRegister* left, const PRFRegister* right); + + void receive_spdz_wires(ReceivedMsg& msg); + + void start_online_round() { FakeProgramPartySuper::start_online_round(); } + + void mask_output(ReceivedMsg& msg) { ProgramParty::mask_output(msg); } + void mask_input(ReceivedMsg& msg) { ProgramParty::mask_input(msg); } + + void done() { BaseParty::done(); } + +public: + FakeProgramParty(int argc, const char** argv); + ~FakeProgramParty(); +}; + inline ProgramParty& ProgramParty::s() { if (singleton) @@ -253,4 +318,13 @@ inline ProgramParty& ProgramParty::s() throw runtime_error("no singleton"); } +template +inline ProgramPartySpec& ProgramPartySpec::s() +{ + if (singleton) + return *singleton; + else + throw runtime_error("no singleton"); +} + #endif /* PROTOCOL_PARTY_H_ */ diff --git a/BMR/Program.cpp b/BMR/Program.cpp deleted file mode 100644 index 25c1b7368..000000000 --- a/BMR/Program.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Program.cpp - * - */ - -#include "Register.h" -#include "GC/Secret.h" - -#include "GC/Instruction.hpp" -#include "GC/Program.hpp" - -#include "Processor/Instruction.hpp" - -namespace GC -{ - -template class Instruction< Secret >; -template class Instruction< Secret >; -template class Instruction< Secret >; -template class Instruction< Secret >; - -template class Program< Secret >; -template class Program< Secret >; -template class Program< Secret >; -template class Program< Secret >; - -} diff --git a/BMR/ProgramParty.hpp b/BMR/ProgramParty.hpp new file mode 100644 index 000000000..7a136118f --- /dev/null +++ b/BMR/ProgramParty.hpp @@ -0,0 +1,106 @@ +/* + * ProgramParty.hpp + * + */ + +#ifndef BMR_PROGRAMPARTY_HPP_ +#define BMR_PROGRAMPARTY_HPP_ + +#include "Party.h" + +template +ProgramPartySpec* ProgramPartySpec::singleton = 0; + +template +ProgramPartySpec::ProgramPartySpec() : MC(0) +{ + assert(singleton == 0); + singleton = this; +} + +template +ProgramPartySpec::~ProgramPartySpec() +{ + if (MC) + delete MC; +} + +template +void ProgramPartySpec::load(string progname) +{ + program.parse(progname + "-0"); + machine.reset(program, dynamic_memory); + processor.reset(program); + prf_machine.reset(*reinterpret_cast >* >(&program)); + prf_processor.reset(*reinterpret_cast >* >(&program)); +} + +template +void ProgramPartySpec::_check_evaluate() +{ +#ifdef DEBUG_REGS + print_round_regs(); +#endif +#ifdef VERBOSE + cerr << "Online time at evaluation start: " << online_timer.elapsed() + << endl; +#endif + GC::BreakType next = GC::TIME_BREAK; + while (next == GC::TIME_BREAK) + { + load_garbled_circuit(); + next = second_phase(program, processor, machine, dynamic_memory); + } +#ifdef VERBOSE + cerr << "Online time at evaluation stop: " << online_timer.elapsed() + << endl; +#endif + if (next == GC::TIME_BREAK) + { +#ifdef DEBUG_STEPS + cout << "another round of garbling" << endl; +#endif + } + if (next == GC::CLEANING_BREAK) + return; + if (next != GC::DONE_BREAK) + { +#ifdef DEBUG_STEPS + cout << "another round of evaluation" << endl; +#endif + start_online_round(); + } + else + { + Timer timer; + timer.start(); + MC->Check(*P); +#ifdef VERBOSE + cerr << "Final check took " << timer.elapsed() << endl; +#endif + done(); + machine.write_memory(N.my_num()); + } +} + +template +void ProgramPartySpec::get_spdz_wire(SpdzOp op, DualWire& spdz_wire) +{ + while (true) + { + if (spdz_wires[op].empty()) + throw runtime_error("no SPDZ wires available"); + if (spdz_wires[op].front().done()) + spdz_wires[op].pop_front(); + else + break; + } + spdz_wire.unpack(spdz_wires[op].front(), get_n_parties()); + spdz_counters[op]++; +#ifdef DEBUG_SPDZ_WIRE + cout << "get SPDZ wire of type " << op << ", " << spdz_wires[op].front().left() << " bytes left" << endl; + cout << "mask share for " << get_id() << ": " << spdz_wire.mask << endl; +#endif +} + +#endif /* BMR_PROGRAMPARTY_HPP_ */ diff --git a/BMR/RealGarbleWire.h b/BMR/RealGarbleWire.h new file mode 100644 index 000000000..754e26b26 --- /dev/null +++ b/BMR/RealGarbleWire.h @@ -0,0 +1,56 @@ +/* + * RealGarbleWire.h + * + */ + +#ifndef BMR_REALGARBLEWIRE_H_ +#define BMR_REALGARBLEWIRE_H_ + +#include "Register.h" + +template class RealProgramParty; + +template +class RealGarbleWire : public PRFRegister +{ + friend class RealProgramParty; + + T mask; + +public: + static void store(NoMemory& dest, + const vector>>& accesses); + static void load(vector>>& accesses, + const NoMemory& source); + + static void convcbit(Integer& dest, const GC::Clear& source); + + RealGarbleWire(const Register& reg) : PRFRegister(reg) {} + + void garble(PRFOutputs& prf_output, const RealGarbleWire& left, + const RealGarbleWire& right); + + void XOR(const RealGarbleWire& left, const RealGarbleWire& right); + + void input(party_id_t from, char input = -1); + void public_input(bool value); + void random(); + void output(); +}; + +template +class GarbleJob +{ + typedef typename T::Protocol Protocol; + typedef typename T::Input Inputter; + + T lambda_u, lambda_v, lambda_uv, lambda_w; + +public: + GarbleJob(T lambda_u, T lambda_v, T lambda_w); + void middle_round(RealProgramParty& party, Protocol& second_protocol); + void last_round(RealProgramParty& party, Inputter& inputter, + Protocol& second_protocol, vector& wires); +}; + +#endif /* BMR_REALGARBLEWIRE_H_ */ diff --git a/BMR/RealGarbleWire.hpp b/BMR/RealGarbleWire.hpp new file mode 100644 index 000000000..26cb70890 --- /dev/null +++ b/BMR/RealGarbleWire.hpp @@ -0,0 +1,215 @@ +/* + * RealGarbleWire.cpp + * + */ + +#include "RealGarbleWire.h" +#include "RealProgramParty.h" +#include "Processor/MascotPrep.h" + +template +void RealGarbleWire::garble(PRFOutputs& prf_output, + const RealGarbleWire& left, const RealGarbleWire& right) +{ + auto& party = RealProgramParty::s(); + assert(party.prep != 0); + party.prep->get_one(DATA_BIT, mask); + auto& inputter = *party.garble_inputter; + int n = party.N.num_players(); + int me = party.N.my_num(); + inputter.add_from_all(int128(keys[0][me].r)); + for (int k = 0; k < 4; k++) + for (int j = 0; j < n; j++) + inputter.add_from_all(int128(prf_output[j].for_garbling(k).r)); + + assert(party.shared_proc != 0); + assert(party.garble_protocol != 0); + auto& protocol = *party.garble_protocol; + protocol.prepare_mul(left.mask, right.mask); + GarbleJob job(left.mask, right.mask, mask); + party.garble_jobs.push_back(job); +} + +template +GarbleJob::GarbleJob(T lambda_u, T lambda_v, T lambda_w) : + lambda_u(lambda_u), lambda_v(lambda_v), lambda_w(lambda_w) +{ +} + +template +void GarbleJob::middle_round(RealProgramParty& party, Protocol& second_protocol) +{ + int n = party.N.num_players(); + int me = party.N.my_num(); + assert(party.garble_protocol != 0); + auto& protocol = *party.garble_protocol; + lambda_uv = protocol.finalize_mul(); + +#ifdef DEBUG_MASK + cout << "lambda_u " << party.MC->POpen(lambda_u, *party.P) << endl; + cout << "lambda_v " << party.MC->POpen(lambda_v, *party.P) << endl; + cout << "lambda_w " << party.MC->POpen(lambda_w, *party.P) << endl; + cout << "lambda_uv " << party.MC->POpen(lambda_uv, *party.P) << endl; +#endif + + for (int alpha = 0; alpha < 2; alpha++) + for (int beta = 0; beta < 2; beta++) + for (int j = 0; j < n; j++) + { + second_protocol.prepare_mul(party.shared_delta(j), + lambda_uv + lambda_v * alpha + lambda_u * beta + + T(alpha * beta, me, party.MC->get_alphai()) + + lambda_w); + } +} + +template +void GarbleJob::last_round(RealProgramParty& party, Inputter& inputter, + Protocol& second_protocol, vector& wires) +{ + int n = party.N.num_players(); + auto& protocol = second_protocol; + + vector base_keys; + for (int i = 0; i < n; i++) + base_keys.push_back(inputter.finalize(i)); + + for (int k = 0; k < 4; k++) + for (int j = 0; j < n; j++) + { + wires.push_back({}); + auto& wire = wires.back(); + for (int i = 0; i < n; i++) + wire += inputter.finalize(i); + wire += base_keys[j]; + wire += protocol.finalize_mul(); + } +} + +template +void RealGarbleWire::XOR(const RealGarbleWire& left, const RealGarbleWire& right) +{ + PRFRegister::XOR(left, right); + mask = left.mask + right.mask; +} + +template +void RealGarbleWire::input(party_id_t from, char input) +{ + PRFRegister::input(from, input); + auto& party = RealProgramParty::s(); + assert(party.shared_proc != 0); + auto& inputter = party.shared_proc->input; + inputter.reset(from - 1); + if (from == party.get_id()) + { + char my_mask; + my_mask = party.prng.get_bit(); + party.input_masks.serialize(my_mask); + inputter.add_mine(my_mask); + inputter.send_mine(); + mask = inputter.finalize_mine(); +#ifdef DEBUG_MASK + cout << "my mask: " << (int)my_mask << endl; +#endif + } + else + { + inputter.add_other(from - 1); + octetStream os; + party.P->receive_player(from - 1, os, true); + inputter.finalize_other(from - 1, mask, os); + } + // important to make sure that mask is a bit + try + { + mask.force_to_bit(); + } + catch (not_implemented& e) + { + assert(party.P != 0); + assert(party.MC != 0); + auto& protocol = party.shared_proc->protocol; + protocol.init_mul(party.shared_proc); + protocol.prepare_mul(mask, T(1, party.P->my_num(), party.mac_key) - mask); + protocol.exchange(); + if (party.MC->POpen(protocol.finalize_mul(), *party.P) != 0) + throw runtime_error("input mask not a bit"); + } +#ifdef DEBUG_MASK + cout << "shared mask: " << party.MC->POpen(mask, *party.P) << endl; +#endif +} + +template +void RealGarbleWire::public_input(bool value) +{ + PRFRegister::public_input(value); + mask = {}; +} + +template +void RealGarbleWire::random() +{ + // no need to randomize keys + PRFRegister::public_input(0); + auto& party = RealProgramParty::s(); + assert(party.prep != 0); + party.prep->get_one(DATA_BIT, mask); + // this is necessary to match the fake BMR evaluation phase + party.store_wire(*this); + keys[0].serialize(party.wires); +} + +template +void RealGarbleWire::output() +{ + PRFRegister::output(); + auto& party = RealProgramParty::s(); + assert(party.MC != 0); + assert(party.P != 0); + auto m = party.MC->POpen(mask, *party.P); + party.output_masks.push_back(m.get_bit(0)); + party.taint(); +#ifdef DEBUG_MASK + cout << "output mask: " << m << endl; +#endif +} + +template +void RealGarbleWire::store(NoMemory& dest, + const vector > >& accesses) +{ + (void) dest; + auto& party = RealProgramParty::s(); + for (auto access : accesses) + for (auto& reg : access.source.get_regs()) + { + party.push_spdz_wire(SPDZ_STORE, reg); + } +} + +template +void RealGarbleWire::load( + vector > >& accesses, + const NoMemory& source) +{ + PRFRegister::load(accesses, source); + auto& party = RealProgramParty::s(); + assert(party.prep != 0); + for (auto access : accesses) + for (auto& reg : access.dest.get_regs()) + { + party.prep->get_one(DATA_BIT, reg.mask); + party.push_spdz_wire(SPDZ_LOAD, reg); + } +} + +template +void RealGarbleWire::convcbit(Integer& dest, const GC::Clear& source) +{ + (void) source; + auto& party = RealProgramParty::s(); + party.untaint(); + dest = party.convcbit; +} diff --git a/BMR/RealProgramParty.h b/BMR/RealProgramParty.h new file mode 100644 index 000000000..d3ef1d0b4 --- /dev/null +++ b/BMR/RealProgramParty.h @@ -0,0 +1,72 @@ +/* + * RealProgramParty.h + * + */ + +#ifndef BMR_REALPROGRAMPARTY_H_ +#define BMR_REALPROGRAMPARTY_H_ + +#include "Party.h" +#include "RealGarbleWire.h" + +#include "GC/Machine.h" +#include "GC/RuntimeBranching.h" +#include "Processor/Processor.h" + +template +class RealProgramParty : public ProgramPartySpec +{ + typedef typename T::Input Inputter; + + friend class RealGarbleWire; + friend class GarbleJob; + + static RealProgramParty* singleton; + + GC::Machine>> garble_machine; + GC::Processor>> garble_processor; + + DataPositions usage; + Preprocessing* prep; + SubProcessor* shared_proc; + + ArithmeticProcessor dummy_proc; + + vector deltas; + + Inputter* garble_inputter; + typename T::Protocol* garble_protocol; + vector> garble_jobs; + + GC::BreakType next; + +public: + static RealProgramParty& s(); + + RealProgramParty(int argc, const char** argv); + ~RealProgramParty(); + + void garble(); + + void receive_keys(Register& reg); + void receive_all_keys(Register& reg, bool external); + void process_prf_output(PRFOutputs& prf_output, PRFRegister* out, + const PRFRegister* left, const PRFRegister* right); + + void push_spdz_wire(SpdzOp op, const RealGarbleWire& wire); + + void done() {} + + T shared_delta(int i) { return deltas[i]; } +}; + +template +inline RealProgramParty& RealProgramParty::s() +{ + if (singleton) + return *singleton; + else + throw runtime_error("no singleton"); +} + +#endif /* BMR_REALPROGRAMPARTY_H_ */ diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp new file mode 100644 index 000000000..7f493868c --- /dev/null +++ b/BMR/RealProgramParty.hpp @@ -0,0 +1,236 @@ +/* + * RealProgramParty.cpp + * + */ + +#include "RealProgramParty.h" + +#include "Register_inline.h" + +#include "Tools/NetworkOptions.h" +#include "Math/Setup.h" + +#include "RealGarbleWire.hpp" +#include "CommonParty.hpp" +#include "Register.hpp" +#include "ProgramParty.hpp" +#include "GC/Machine.hpp" +#include "GC/Processor.hpp" +#include "GC/Program.hpp" +#include "GC/Instruction.hpp" +#include "GC/Secret.hpp" +#include "GC/Thread.hpp" +#include "GC/ThreadMaster.hpp" + +template +RealProgramParty* RealProgramParty::singleton = 0; + +template +RealProgramParty::RealProgramParty(int argc, const char** argv) : + garble_processor(garble_machine), dummy_proc({{}, 0}) +{ + assert(singleton == 0); + singleton = this; + + ez::ezOptionParser opt; + opt.add( + T::needs_ot ? "2" : "3", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Number of players", // Help description. + "-N", // Flag token. + "--nparties" // Flag token. + ); + opt.parse(argc, argv); + int nparties; + opt.get("-N")->getInt(nparties); + this->check(nparties); + + NetworkOptions network_opts(opt, argc, argv); + OnlineOptions online_opts(opt, argc, argv); + assert(not online_opts.interactive); + + online_opts.finalize(opt, argc, argv); + this->load(online_opts.progname); + + auto& N = this->N; + auto& P = this->P; + auto& delta = this->delta; + auto& mac_key = this->mac_key; + auto& garble_processor = this->garble_processor; + auto& prng = this->prng; + auto& program = this->program; + auto& MC = this->MC; + + this->_id = online_opts.playerno + 1; + Server* server = Server::start_networking(N, online_opts.playerno, nparties, + network_opts.hostname, network_opts.portnum_base); + if (T::needs_ot) + P = new PlainPlayer(N, 0); + else + P = new CryptoPlayer(N, 0); + + delta = prng.get_doubleword(); +#ifdef KEY_SIGNAL + delta.set_signal(1); +#endif +#ifdef VERBOSE + cerr << "delta: " << delta << endl; +#endif + + string prep_dir = get_prep_dir(nparties, 128, 128); + usage = DataPositions(N.num_players()); + if (online_opts.live_prep) + { + mac_key.randomize(prng); + if (T::needs_ot) + BaseMachine::s().ot_setups.push_back({{{*P, true}}}); + prep = Preprocessing::get_live_prep(0, usage); + } + else + { + Z2<64> _; + read_mac_keys(prep_dir, online_opts.playerno, nparties, _, mac_key); + prep = new Sub_Data_Files(N, prep_dir, usage); + } + + MC = new typename T::MAC_Check(mac_key); + + garble_processor.reset(program); + this->processor.open_input_file(N.my_num(), 0); + + shared_proc = new SubProcessor(dummy_proc, *MC, *prep, *P); + + auto& inputter = shared_proc->input; + inputter.reset_all(*P); + for (int i = 0; i < N.num_players(); i++) + if (i == N.my_num()) + inputter.add_mine(int128(delta.r)); + else + inputter.add_other(i); + inputter.exchange(); + for (int i = 0; i < N.num_players(); i++) + deltas.push_back(inputter.finalize(i)); + + garble_inputter = new Inputter(shared_proc, *P); + garble_protocol = new typename T::Protocol(*P); + for (int i = 0; i < SPDZ_OP_N; i++) + this->spdz_wires[i].push_back({}); + + do + { + next = GC::TIME_BREAK; + garble(); + try + { + this->online_timer.start(); + this->start_online_round(); + this->online_timer.stop(); + } + catch (needs_cleaning& e) + { + } + } + while (next != GC::DONE_BREAK); + + MC->Check(*P); + + if (server) + delete server; +} + +template +void RealProgramParty::garble() +{ + auto& P = this->P; + auto& garble_processor = this->garble_processor; + auto& program = this->program; + auto& MC = this->MC; + + while (next == GC::TIME_BREAK) + { + garble_jobs.clear(); + garble_inputter->reset_all(*P); + auto& protocol = *garble_protocol; + protocol.init_mul(shared_proc); + + next = this->first_phase(program, garble_processor, this->garble_machine); + + garble_inputter->exchange(); + protocol.exchange(); + + typename T::Protocol second_protocol(*P); + second_protocol.init_mul(shared_proc); + for (auto& job : garble_jobs) + job.middle_round(*this, second_protocol); + + second_protocol.exchange(); + + vector wires; + for (auto& job : garble_jobs) + job.last_round(*this, *garble_inputter, second_protocol, wires); + + vector opened; + MC->POpen(opened, wires, *P); + + for (auto& x : opened) + this->garbled_circuit.serialize(x); + + this->garbled_circuits.push_and_clear(this->garbled_circuit); + this->input_masks_store.push_and_clear(this->input_masks); + this->output_masks_store.push_and_clear(this->output_masks); + } +} + +template +RealProgramParty::~RealProgramParty() +{ + delete shared_proc; + delete prep; + delete garble_inputter; + delete garble_protocol; +} + +template +void RealProgramParty::receive_keys(Register& reg) +{ +#ifndef FREE_XOR +#error not implemented +#endif + auto& _id = this->_id; + auto& _N = this->_N; + reg.init(_N); + reg.keys[0][_id - 1] = this->prng.get_doubleword(); +#ifdef KEY_SIGNAL + reg.keys[0][_id - 1].set_signal(0); +#endif + reg.keys[1][_id - 1] = reg.keys[0][_id - 1] ^ this->get_delta(); +} + +template +void RealProgramParty::receive_all_keys(Register& reg, bool external) +{ + (void) reg, (void) external; + throw not_implemented(); +} + +template +void RealProgramParty::process_prf_output(PRFOutputs& prf_output, PRFRegister* out, const PRFRegister* left, const PRFRegister* right) +{ + assert(out != 0 and left != 0 and right != 0); + auto l = reinterpret_cast*>(left); + auto r = reinterpret_cast*>(right); + reinterpret_cast*>(out)->garble(prf_output, *l, *r); +} + +template +void RealProgramParty::push_spdz_wire(SpdzOp op, const RealGarbleWire& wire) +{ + DualWire spdz_wire; + spdz_wire.mask = wire.mask; + for (int i = 0; i < 2; i++) + spdz_wire.my_keys[i] = wire.keys[i][this->N.my_num()]; + spdz_wire.pack(this->spdz_wires[op].back()); + this->spdz_storage += sizeof(SpdzWire); +} diff --git a/BMR/Register.cpp b/BMR/Register.cpp index 5fd637851..19060b037 100644 --- a/BMR/Register.cpp +++ b/BMR/Register.cpp @@ -11,13 +11,17 @@ #include "TrustedParty.h" #include "CommonParty.h" #include "Register_inline.h" +#include "RealGarbleWire.h" #include "prf.h" #include "GC/Secret.h" +#include "GC/Secret_inline.h" #include "GC/Processor.h" #include "Tools/FlexBuffer.h" +#include "GC/Processor.hpp" + #include ostream& EvalRegister::out = cout; @@ -101,6 +105,9 @@ void Register::check_mask() const void Register::set_mask(char mask) { +#ifdef DEBUG_MASK + cout << "setting mask: " << (int)mask << endl; +#endif this->mask = mask; check_mask(); } @@ -206,7 +213,8 @@ void EvalRegister::op(const ProgramRegister& left, const ProgramRegister& right, GarbledGate gate(party.get_n_parties()); party.next_gate(gate); gate.unserialize(party.garbled_circuit, party.get_n_parties()); - Register::eval(left, right, gate, party._id, party.prf_output, + party.prf_output.resize(PAD_TO_8(party.get_n_parties()) * sizeof(__m128i)); + Register::eval(left, right, gate, party._id, party.prf_output.data(), get_id(), left.get_id(), right.get_id()); } @@ -221,6 +229,10 @@ void Register::eval(const Register& left, const Register& right, GarbledGate& ga int sig_r = right.get_external(); int entry = 2 * sig_l + sig_r; +#ifdef DEBUG_MASK + cout << "input signals: " << sig_l << " " << sig_r << endl; +#endif + #ifdef DEBUG gate.print(); cout << "picking " << entry << endl; @@ -300,6 +312,10 @@ void Register::eval(const Register& left, const Register& right, GarbledGate& ga } #endif +#ifdef DEBUG_MASK + cout << "output signal: " << (int)external << endl; +#endif + #ifdef DEBUG std::cout << "k^"< >& processor, access.received_labels(oss); } +void EvalRegister::convcbit(Integer& dest, const GC::Clear& source) +{ + auto& party = ProgramParty::s(); + dest = source; + party.convcbit = source; + party.untaint(); +} + void EvalRegister::input_helper(char value, octetStream& os) { set_mask(ProgramParty::s().input_masks.pop_front()); @@ -761,6 +783,7 @@ void EvalRegister::output() #endif check_signal_key(party.get_id(), garbled_entry); #endif + party.taint(); } #ifdef FREE_XOR @@ -855,24 +878,8 @@ void EvalRegister::unmask(GC::AuthValue& dest, word mask_share, int128 mac_mask_ #endif } -template -void EvalRegister::store_clear_in_dynamic(GC::Memory& mem, - const vector& accesses) -{ - for (auto access : accesses) - { - T& dest = mem[access.address]; - GC::Clear value = access.value; - ProgramParty& party = ProgramParty::s(); - dest.assign(value.get(), party.get_mac_key().get(), party.get_id() == 1); -#ifdef DEBUG_DYNAMIC - cout << "store clear " << dest.share << " " << dest.mac << " " << value << endl; -#endif - } -} - template <> -void RandomRegister::store(GC::Memory& mem, +void RandomRegister::store(NoMemory& mem, const vector< GC::WriteAccess< GC::Secret > >& accesses) { (void)mem; @@ -883,121 +890,9 @@ void RandomRegister::store(GC::Memory& mem, } } -template -void check_for_doubles(const vector& accesses, const char* name) -{ - (void)accesses; - (void)name; -#ifdef OUTPUT_DOUBLES - set seen; - int doubles = 0; - for (auto access : accesses) - { - if (seen.find(access.address) != seen.end()) - doubles++; - seen.insert(access.address); - } - cout << doubles << "/" << accesses.size() << " doubles in " << name << endl; -#endif -} - -template<> -void EvalRegister::store(GC::Memory& mem, - const vector< GC::WriteAccess< GC::Secret > >& accesses) -{ - check_for_doubles(accesses, "storing"); - ProgramParty& party = ProgramParty::s(); - vector< Share > S, S2, S3, S4, S5, SS; - vector exts; - int n_registers = 0; - for (auto access : accesses) - n_registers += access.source.get_regs().size(); - for (auto access : accesses) - { - GC::SpdzShare& dest = mem[access.address]; - dest.assign_zero(); - const vector& sources = access.source.get_regs(); - for (unsigned int i = 0; i < sources.size(); i++) - { - SpdzWire spdz_wire; - party.get_spdz_wire(SPDZ_STORE, spdz_wire); - const EvalRegister& reg = sources[i]; - Share tmp; - gf2n ext = (int)reg.get_external(); - //cout << "ext:" << ext << "/" << (int)reg.get_external() << " " << endl; - tmp.add(spdz_wire.mask, ext, (int)party.get_id() - 1, party.get_mac_key()); - S.push_back(tmp); - tmp *= gf2n(1) << i; - dest += tmp; - const Key& key = reg.external_key(party.get_id()); - Key& expected_key = spdz_wire.my_keys[(int)reg.get_external()]; - if (expected_key != key) - { - cout << "wire label: " << key << ", expected: " - << expected_key << endl; - cout << "opposite: " << spdz_wire.my_keys[1-reg.get_external()] << endl; - sources[i].keys.print(sources[i].get_id()); - throw runtime_error("key check failed"); - } -#ifdef DEBUG_SPDZ - S3.push_back(spdz_wire.mask); - S4.push_back(dest); - S5.push_back(tmp); - exts.push_back(ext); -#endif - } -#ifdef DEBUG_SPDZ - SS.push_back(dest); -#endif - } - -#ifdef DEBUG_SPDZ - party.MC->Check(*party.P); - vector v, v3, vv; - party.MC->POpen_Begin(vv, SS, *party.P); - party.MC->POpen_End(vv, SS, *party.P); - cout << "stored " << vv.back() << " from bits:"; - vv.pop_back(); - party.MC->Check(*party.P); - party.MC->POpen_Begin(v, S, *party.P); - party.MC->POpen_End(v, S, *party.P); - for (auto val : v) - cout << val.get_bit(0); - party.MC->Check(*party.P); - cout << " / exts:"; - for (auto ext : exts) - cout << ext.get_bit(0); - cout << " / masks:"; - party.MC->POpen_Begin(v3, S3, *party.P); - party.MC->POpen_End(v3, S3, *party.P); - for (auto val : v3) - cout << val.get_word(); - cout << endl; - party.MC->Check(*party.P); - cout << "share: " << SS.back() << endl; - party.MC->Check(*party.P); - - party.MC->POpen_Begin(v, S4, *party.P); - party.MC->POpen_End(v, S4, *party.P); - for (auto x : v) - cout << x << " "; - cout << endl; - - party.MC->POpen_Begin(v, S5, *party.P); - party.MC->POpen_End(v, S5, *party.P); - for (auto x : v) - cout << x << " "; - cout << endl; - - party.MC->POpen_Begin(v, S2, *party.P); - party.MC->POpen_End(v, S2, *party.P); - party.MC->Check(*party.P); -#endif -} - template <> void RandomRegister::load(vector > >& accesses, - const GC::Memory& source) + const NoMemory& source) { (void)source; for (auto access : accesses) @@ -1011,7 +906,7 @@ void RandomRegister::load(vector > >& template <> void GarbleRegister::load(vector > >& accesses, - const GC::Memory& source) + const NoMemory& source) { (void)source; for (auto access : accesses) @@ -1019,104 +914,6 @@ void GarbleRegister::load(vector > >& TrustedProgramParty::s().load_wire(reg); } -template <> -void PRFRegister::load(vector > >& accesses, - const GC::Memory& source) -{ - (void)source; - for (auto access : accesses) - for (auto& reg : access.dest.get_regs()) - { - ProgramParty::s().receive_keys(reg); - ProgramParty::s().store_wire(reg); - } -} - -template <> -void EvalRegister::load(vector > >& accesses, - const GC::Memory& mem) -{ - check_for_doubles(accesses, "loading"); - vector< Share > shares; - shares.reserve(accesses.size()); - ProgramParty& party = ProgramParty::s(); - deque spdz_wires; - vector< Share > S; - for (auto access : accesses) - { - const GC::SpdzShare& source = mem[access.address]; - Share mask; - vector& dests = access.dest.get_regs(); - for (unsigned int i = 0; i < dests.size(); i++) - { - spdz_wires.push_back({}); - ProgramParty::s().get_spdz_wire(SPDZ_LOAD, spdz_wires.back()); - mask += spdz_wires.back().mask << i; - } - shares.push_back(source + mask); -#ifdef DEBUG_SPDZ - S.push_back(source); -#endif - } - -#ifdef DEBUG_SPDZ - party.MC->Check(*party.P); - vector v; - party.MC->POpen_Begin(v, S, *party.P); - party.MC->POpen_End(v, S, *party.P); - for (size_t j = 0; j < accesses.size(); j++) - { - cout << "loaded " << v[j] << " / "; - vector& dests = accesses[j].dest.get_regs(); - for (unsigned int i = 0; i < dests.size(); i++) - cout << (int)dests[i].get_external(); - cout << " from " << S[j] << endl; - } - party.MC->Check(*party.P); -#endif - - vector masked; - party.MC->POpen_Begin(masked, shares, *party.P); - party.MC->POpen_End(masked, shares, *party.P); - vector keys(party.get_n_parties()); - - for (size_t j = 0; j < accesses.size(); j++) - { - vector& dests = accesses[j].dest.get_regs(); - for (unsigned int i = 0; i < dests.size(); i++) - { - bool ext = masked[j].get_bit(i); - party.load_wire(dests[i]); - dests[i].set_external(ext); - keys[party.get_id() - 1].serialize(spdz_wires.front().my_keys[ext]); - spdz_wires.pop_front(); - } - } - - party.P->Broadcast_Receive(keys, true); - - int base = 0; - for (auto access : accesses) - { - vector& dests = access.dest.get_regs(); - for (unsigned int i = 0; i < dests.size(); i++) - for (int j = 0; j < party.get_n_parties(); j++) - { - Key key; - keys[j].unserialize(key); - dests[i].set_external_key(j + 1, key); - } - base += dests.size() * party.get_n_parties(); - } - -#ifdef DEBUG_SPDZ - cout << "masked: "; - for (auto& m : masked) - cout << m << " "; - cout << endl; -#endif -} - void KeyVector::operator=(const KeyVector& other) { resize(other.size()); @@ -1224,8 +1021,3 @@ void KeyTuple::print(int wire_id, party_id_t pid) template class KeyTuple<2>; template class KeyTuple<4>; - -template void EvalRegister::store_clear_in_dynamic( - GC::Memory& mem, const vector& accesses); -template void EvalRegister::store_clear_in_dynamic( - GC::Memory& mem, const vector& accesses); diff --git a/BMR/Register.h b/BMR/Register.h index bcc2ff2cc..fe8711cc1 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -53,7 +53,12 @@ class BaseKeyVector #endif }; #else -typedef vector BaseKeyVector; +class BaseKeyVector : public vector +{ +public: + BaseKeyVector(int size = 0) : vector(size, Key(0)) {} + void resize(int size) { vector::resize(size, Key(0)); } +}; #endif class KeyVector : public BaseKeyVector @@ -195,6 +200,8 @@ inline BlackHole& flush(BlackHole& b) { return b; } class Phase { public: + typedef NoMemory DynamicMemory; + typedef BlackHole out_type; static const BlackHole out; @@ -210,12 +217,12 @@ class Phase { (void)dest; (void)mask_share; (void)mac_mask_share; (void)masked; (void)masked_mac; } template - static void store(GC::Memory& dest, + static void store(NoMemory& dest, const vector >& accesses) { (void)dest; (void)accesses; throw runtime_error("dynamic memory not implemented"); } template static void load(vector >& accesses, - const GC::Memory& source) + const NoMemory& source) { (void)accesses; (void)source; throw runtime_error("dynamic memory not implemented"); } template @@ -225,10 +232,12 @@ class Phase template static void inputb(T& processor, const vector& args) { processor.input(args); } template - static T get_input(int from, GC::Processor& processor, int n_bits) { return T::input(from, processor.get_input(n_bits), n_bits); } + static void convcbit(Integer& dest, const GC::Clear& source) + { (void) dest, (void) source; throw not_implemented(); } + void input(party_id_t from, char value = -1) { (void)from; (void)value; } void public_input(bool value) { (void)value; } void random() {} @@ -243,7 +252,7 @@ class ProgramRegister : public Phase, public Register static Register and_reg() { return new_reg(); } template - static void store(GC::Memory& dest, + static void store(NoMemory& dest, const vector >& accesses) { (void)dest; (void)accesses; } template static void load(vector >& accesses, @@ -266,11 +275,11 @@ class PRFRegister : public ProgramRegister template static void load(vector >& accesses, - const GC::Memory& source); + const NoMemory& source); PRFRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const ProgramRegister& left, const ProgramRegister& right, Function func); + void op(const PRFRegister& left, const PRFRegister& right, Function func); void XOR(const Register& left, const Register& right); void input(party_id_t from, char input = -1); void public_input(bool value); @@ -291,12 +300,12 @@ class EvalRegister : public ProgramRegister static void unmask(GC::AuthValue& dest, word mask_share, int128 mac_mask_share, word masked, int128 masked_mac); - template - static void store(GC::Memory& dest, + template + static void store(GC::Memory& dest, const vector >& accesses); - template + template static void load(vector >& accesses, - const GC::Memory& source); + const GC::Memory& source); template static void andrs(T& processor, const vector& args); @@ -310,6 +319,8 @@ class EvalRegister : public ProgramRegister throw runtime_error("use EvalRegister::inputb()"); } + static void convcbit(Integer& dest, const GC::Clear& source); + EvalRegister(const Register& reg) : ProgramRegister(reg) {} void op(const ProgramRegister& left, const ProgramRegister& right, Function func); @@ -335,7 +346,7 @@ class GarbleRegister : public ProgramRegister template static void load(vector >& accesses, - const GC::Memory& source); + const NoMemory& source); GarbleRegister(const Register& reg) : ProgramRegister(reg) {} @@ -353,11 +364,11 @@ class RandomRegister : public ProgramRegister static string name() { return "Randomization"; } template - static void store(GC::Memory& dest, + static void store(NoMemory& dest, const vector >& accesses); template static void load(vector >& accesses, - const GC::Memory& source); + const NoMemory& source); RandomRegister(const Register& reg) : ProgramRegister(reg) {} diff --git a/BMR/Register.hpp b/BMR/Register.hpp new file mode 100644 index 000000000..3537a4dde --- /dev/null +++ b/BMR/Register.hpp @@ -0,0 +1,238 @@ +/* + * Register.hpp + * + */ + +#ifndef BMR_REGISTER_HPP_ +#define BMR_REGISTER_HPP_ + +#include "Register.h" +#include "Party.h" + +template +void PRFRegister::load(vector >& accesses, + const NoMemory& source) +{ + (void)source; + for (auto access : accesses) + for (auto& reg : access.dest.get_regs()) + { + ProgramParty::s().receive_keys(reg); + ProgramParty::s().store_wire(reg); + } +} + +template +void EvalRegister::store_clear_in_dynamic(GC::Memory& mem, + const vector& accesses) +{ + for (auto access : accesses) + { + T& dest = mem[access.address]; + GC::Clear value = access.value; + ProgramParty& party = ProgramParty::s(); + dest.assign(value.get(), party.get_id() - 1, party.get_mac_key().get()); +#ifdef DEBUG_DYNAMIC + cout << "store clear " << dest.share << " " << dest.mac << " " << value << endl; +#endif + } +} + +template +void check_for_doubles(const vector& accesses, const char* name) +{ + (void)accesses; + (void)name; +#ifdef OUTPUT_DOUBLES + set seen; + int doubles = 0; + for (auto access : accesses) + { + if (seen.find(access.address) != seen.end()) + doubles++; + seen.insert(access.address); + } + cout << doubles << "/" << accesses.size() << " doubles in " << name << endl; +#endif +} + +template +void EvalRegister::store(GC::Memory& mem, + const vector< GC::WriteAccess >& accesses) +{ + check_for_doubles(accesses, "storing"); + auto& party = ProgramPartySpec::s(); + vector S, S2, S3, S4, S5, SS; + vector exts; + int n_registers = 0; + for (auto access : accesses) + n_registers += access.source.get_regs().size(); + for (auto access : accesses) + { + U& dest = mem[access.address]; + dest.assign_zero(); + const vector& sources = access.source.get_regs(); + for (unsigned int i = 0; i < sources.size(); i++) + { + DualWire spdz_wire; + party.get_spdz_wire(SPDZ_STORE, spdz_wire); + const EvalRegister& reg = sources[i]; + U tmp; + gf2n ext = (int)reg.get_external(); + //cout << "ext:" << ext << "/" << (int)reg.get_external() << " " << endl; + tmp.add(spdz_wire.mask, ext, (int)party.get_id() - 1, party.get_mac_key()); + S.push_back(tmp); + tmp *= gf2n(1) << i; + dest += tmp; + const Key& key = reg.external_key(party.get_id()); + Key& expected_key = spdz_wire.my_keys[(int)reg.get_external()]; + if (expected_key != key) + { + cout << "wire label: " << key << ", expected: " + << expected_key << endl; + cout << "opposite: " << spdz_wire.my_keys[1-reg.get_external()] << endl; + sources[i].keys.print(sources[i].get_id()); + throw runtime_error("key check failed"); + } +#ifdef DEBUG_SPDZ + S3.push_back(spdz_wire.mask); + S4.push_back(dest); + S5.push_back(tmp); + exts.push_back(ext); +#endif + } +#ifdef DEBUG_SPDZ + SS.push_back(dest); +#endif + } + +#ifdef DEBUG_SPDZ + party.MC->Check(*party.P); + vector v, v3, vv; + party.MC->POpen_Begin(vv, SS, *party.P); + party.MC->POpen_End(vv, SS, *party.P); + cout << "stored " << vv.back() << " from bits:"; + vv.pop_back(); + party.MC->Check(*party.P); + party.MC->POpen_Begin(v, S, *party.P); + party.MC->POpen_End(v, S, *party.P); + for (auto val : v) + cout << val.get_bit(0); + party.MC->Check(*party.P); + cout << " / exts:"; + for (auto ext : exts) + cout << ext.get_bit(0); + cout << " / masks:"; + party.MC->POpen_Begin(v3, S3, *party.P); + party.MC->POpen_End(v3, S3, *party.P); + for (auto val : v3) + cout << val.get_word(); + cout << endl; + party.MC->Check(*party.P); + cout << "share: " << SS.back() << endl; + party.MC->Check(*party.P); + + party.MC->POpen_Begin(v, S4, *party.P); + party.MC->POpen_End(v, S4, *party.P); + for (auto x : v) + cout << x << " "; + cout << endl; + + party.MC->POpen_Begin(v, S5, *party.P); + party.MC->POpen_End(v, S5, *party.P); + for (auto x : v) + cout << x << " "; + cout << endl; + + party.MC->POpen_Begin(v, S2, *party.P); + party.MC->POpen_End(v, S2, *party.P); + party.MC->Check(*party.P); +#endif +} + +template +void EvalRegister::load(vector >& accesses, + const GC::Memory& mem) +{ + check_for_doubles(accesses, "loading"); + vector shares; + shares.reserve(accesses.size()); + auto& party = ProgramPartySpec::s(); + deque> spdz_wires; + vector S; + for (auto access : accesses) + { + const U& source = mem[access.address]; + U mask; + vector& dests = access.dest.get_regs(); + for (unsigned int i = 0; i < dests.size(); i++) + { + spdz_wires.push_back({}); + party.get_spdz_wire(SPDZ_LOAD, spdz_wires.back()); + mask += spdz_wires.back().mask << i; + } + shares.push_back(source + mask); +#ifdef DEBUG_SPDZ + S.push_back(source); +#endif + } + +#ifdef DEBUG_SPDZ + party.MC->Check(*party.P); + vector v; + party.MC->POpen_Begin(v, S, *party.P); + party.MC->POpen_End(v, S, *party.P); + for (size_t j = 0; j < accesses.size(); j++) + { + cout << "loaded " << v[j] << " / "; + vector& dests = accesses[j].dest.get_regs(); + for (unsigned int i = 0; i < dests.size(); i++) + cout << (int)dests[i].get_external(); + cout << " from " << S[j] << endl; + } + party.MC->Check(*party.P); +#endif + + vector masked; + party.MC->POpen_Begin(masked, shares, *party.P); + party.MC->POpen_End(masked, shares, *party.P); + vector keys(party.get_n_parties()); + + for (size_t j = 0; j < accesses.size(); j++) + { + vector& dests = accesses[j].dest.get_regs(); + for (unsigned int i = 0; i < dests.size(); i++) + { + bool ext = masked[j].get_bit(i); + party.load_wire(dests[i]); + dests[i].set_external(ext); + keys[party.get_id() - 1].serialize(spdz_wires.front().my_keys[ext]); + spdz_wires.pop_front(); + } + } + + party.P->Broadcast_Receive(keys, true); + + int base = 0; + for (auto access : accesses) + { + vector& dests = access.dest.get_regs(); + for (unsigned int i = 0; i < dests.size(); i++) + for (int j = 0; j < party.get_n_parties(); j++) + { + Key key; + keys[j].unserialize(key); + dests[i].set_external_key(j + 1, key); + } + base += dests.size() * party.get_n_parties(); + } + +#ifdef DEBUG_SPDZ + cout << "masked: "; + for (auto& m : masked) + cout << m << " "; + cout << endl; +#endif +} + +#endif /* BMR_REGISTER_HPP_ */ diff --git a/BMR/SpdzWire.cpp b/BMR/SpdzWire.cpp deleted file mode 100644 index 752168b36..000000000 --- a/BMR/SpdzWire.cpp +++ /dev/null @@ -1,24 +0,0 @@ -/* - * SpdzWire.cpp - * - */ - -#include "SpdzWire.h" - -SpdzWire::SpdzWire() -{ - -} - -void SpdzWire::pack(octetStream& os) const -{ - mask.pack(os); - os.serialize(my_keys); -} - -void SpdzWire::unpack(octetStream& os, size_t wanted_size) -{ - (void)wanted_size; - mask.unpack(os); - os.unserialize(my_keys); -} diff --git a/BMR/SpdzWire.h b/BMR/SpdzWire.h index c9a95bfeb..b54fd063b 100644 --- a/BMR/SpdzWire.h +++ b/BMR/SpdzWire.h @@ -9,15 +9,32 @@ #include "Math/Share.h" #include "Key.h" -class SpdzWire +template +class DualWire { public: - Share mask; + T mask; Key my_keys[2]; - SpdzWire(); - void pack(octetStream& os) const; - void unpack(octetStream& os, size_t wanted_size); + DualWire() + { + my_keys[0] = 0; + my_keys[1] = 0; + } + + void pack(octetStream& os) const + { + mask.pack(os); + os.serialize(my_keys); + } + void unpack(octetStream& os, size_t wanted_size) + { + (void)wanted_size; + mask.unpack(os); + os.unserialize(my_keys); + } }; +typedef DualWire> SpdzWire; + #endif /* BMR_SPDZWIRE_H_ */ diff --git a/BMR/TrustedParty.cpp b/BMR/TrustedParty.cpp index 431cae27b..609d37dac 100644 --- a/BMR/TrustedParty.cpp +++ b/BMR/TrustedParty.cpp @@ -17,7 +17,19 @@ #include "SpdzWire.h" #include "Auth/fake-stuff.h" +#include "Register_inline.h" + +#include "CommonParty.hpp" #include "Auth/fake-stuff.hpp" +#include "BMR/Register.hpp" +#include "GC/Machine.hpp" +#include "GC/Processor.hpp" +#include "GC/Secret.hpp" +#include "GC/Thread.hpp" +#include "GC/ThreadMaster.hpp" +#include "GC/Program.hpp" +#include "GC/Instruction.hpp" +#include "Processor/Instruction.hpp" TrustedProgramParty* TrustedProgramParty::singleton = 0; @@ -53,8 +65,8 @@ TrustedParty::TrustedParty(const char* netmap_file, // required to init Node } TrustedProgramParty::TrustedProgramParty(int argc, char** argv) : - machine(dynamic_memory), processor(machine), - random_machine(dynamic_memory), random_processor(random_machine) + processor(machine), + random_processor(random_machine) { if (argc < 2) { @@ -406,7 +418,8 @@ bool TrustedProgramParty::_fill_keys() void TrustedProgramParty::garble() { - second_phase(program, processor, machine); + NoMemory dynamic_memory; + second_phase(program, processor, machine, dynamic_memory); vector< Share > tmp; make_share(tmp, 1, get_n_parties(), mac_key, prng); diff --git a/BMR/TrustedParty.h b/BMR/TrustedParty.h index 2cfa464f4..2db406fa4 100644 --- a/BMR/TrustedParty.h +++ b/BMR/TrustedParty.h @@ -14,7 +14,7 @@ #include "Register.h" #include "CommonParty.h" -class BaseTrustedParty : virtual public CommonParty { +class BaseTrustedParty : virtual public CommonFakeParty { public: vector prf_outputs; vector msg_input_masks; @@ -109,7 +109,6 @@ class TrustedProgramParty : public BaseTrustedParty { static TrustedProgramParty* singleton; static TrustedProgramParty& s(); - GC::Memory< GC::Secret::DynamicType > dynamic_memory; GC::Machine< GC::Secret > machine; GC::Processor< GC::Secret > processor; GC::Program< GC::Secret > program; diff --git a/BMR/aes.cpp b/BMR/aes.cpp index 19bb3531c..484d83a28 100644 --- a/BMR/aes.cpp +++ b/BMR/aes.cpp @@ -1,4 +1,6 @@ #include "aes.h" +#include "Tools/aes.h" +#include "Tools/cpu_support.h" #include #ifdef _WIN32 @@ -7,23 +9,25 @@ void AES_128_Key_Expansion(const unsigned char *userkey, AES_KEY *aesKey) { - block x0,x1,x2; - //block *kp = (block *)&aesKey; - aesKey->rd_key[0] = x0 = _mm_loadu_si128((block*)userkey); - x2 = _mm_setzero_si128(); #ifdef __AES__ - EXPAND_ASSIST(x0, x1, x2, x0, 255, 1); aesKey->rd_key[1] = x0; - EXPAND_ASSIST(x0, x1, x2, x0, 255, 2); aesKey->rd_key[2] = x0; - EXPAND_ASSIST(x0, x1, x2, x0, 255, 4); aesKey->rd_key[3] = x0; - EXPAND_ASSIST(x0, x1, x2, x0, 255, 8); aesKey->rd_key[4] = x0; - EXPAND_ASSIST(x0, x1, x2, x0, 255, 16); aesKey->rd_key[5] = x0; - EXPAND_ASSIST(x0, x1, x2, x0, 255, 32); aesKey->rd_key[6] = x0; - EXPAND_ASSIST(x0, x1, x2, x0, 255, 64); aesKey->rd_key[7] = x0; - EXPAND_ASSIST(x0, x1, x2, x0, 255, 128); aesKey->rd_key[8] = x0; - EXPAND_ASSIST(x0, x1, x2, x0, 255, 27); aesKey->rd_key[9] = x0; - EXPAND_ASSIST(x0, x1, x2, x0, 255, 54); aesKey->rd_key[10] = x0; -#else - (void) x1, (void) x2; - throw std::runtime_error("need to compile with AES-NI support"); + if (cpu_has_aes()) + { + block x0,x1,x2; + //block *kp = (block *)&aesKey; + aesKey->rd_key[0] = x0 = _mm_loadu_si128((block*)userkey); + x2 = _mm_setzero_si128(); + EXPAND_ASSIST(x0, x1, x2, x0, 255, 1); aesKey->rd_key[1] = x0; + EXPAND_ASSIST(x0, x1, x2, x0, 255, 2); aesKey->rd_key[2] = x0; + EXPAND_ASSIST(x0, x1, x2, x0, 255, 4); aesKey->rd_key[3] = x0; + EXPAND_ASSIST(x0, x1, x2, x0, 255, 8); aesKey->rd_key[4] = x0; + EXPAND_ASSIST(x0, x1, x2, x0, 255, 16); aesKey->rd_key[5] = x0; + EXPAND_ASSIST(x0, x1, x2, x0, 255, 32); aesKey->rd_key[6] = x0; + EXPAND_ASSIST(x0, x1, x2, x0, 255, 64); aesKey->rd_key[7] = x0; + EXPAND_ASSIST(x0, x1, x2, x0, 255, 128); aesKey->rd_key[8] = x0; + EXPAND_ASSIST(x0, x1, x2, x0, 255, 27); aesKey->rd_key[9] = x0; + EXPAND_ASSIST(x0, x1, x2, x0, 255, 54); aesKey->rd_key[10] = x0; + } + else #endif + aes_128_schedule((uint*) aesKey->rd_key, userkey); } diff --git a/BMR/config.h b/BMR/config.h index 473674d6b..2d76c14b4 100644 --- a/BMR/config.h +++ b/BMR/config.h @@ -7,7 +7,8 @@ #define BMR_CONFIG_H_ // change number of parties here or omit to allow any number -#define N_PARTIES 2 +//#define N_PARTIES 2 +#define MAX_N_PARTIES 3 #define FREE_XOR #define KEY_SIGNAL @@ -15,4 +16,7 @@ #define NO_INPUT #define MAX_INLINE +//#define SIGNAL_CHECK +//#define DEBUG_MASK + #endif /* BMR_CONFIG_H_ */ diff --git a/BMR/network/Node.cpp b/BMR/network/Node.cpp index 20f633029..ef583d76e 100644 --- a/BMR/network/Node.cpp +++ b/BMR/network/Node.cpp @@ -34,7 +34,7 @@ Node::Node(const char* netmap_file, int my_id, NodeUpdatable* updatable, int num throw_bad_id(_id); _ready_nodes = new bool[_numparties](); //initialized to false _clients_connected = new bool[_numparties](); - _server = new Server(_port, _numparties-1, this, max_message_size); + _server = new BIU::Server(_port, _numparties-1, this, max_message_size); _client = new Client(_endpoints, _numparties-1, this, max_message_size); } @@ -169,7 +169,7 @@ void Node::Broadcast2(SendBuffer& msg) { void Node::_identify() { char* msg = id_msg; strncpy(msg, ID_HDR, strlen(ID_HDR)); - strncpy(msg+strlen(ID_HDR), (const char *)&_id, sizeof(_id)); + memcpy(msg+strlen(ID_HDR), (const char *)&_id, sizeof(_id)); //printf("Node:: identifying myself:\n"); SendBuffer buffer; buffer.serialize(msg, strlen(ID_HDR)+4); diff --git a/BMR/network/Node.h b/BMR/network/Node.h index 5eb9f1912..cd4839396 100644 --- a/BMR/network/Node.h +++ b/BMR/network/Node.h @@ -70,7 +70,7 @@ class Node : public ServerUpdatable, public ClientUpdatable { endpoint_t* _endpoints; Client* _client; - Server* _server; + BIU::Server* _server; bool* _ready_nodes; volatile bool _connected_to_servers; std::atomic_int _num_parties_identified; diff --git a/BMR/network/Server.cpp b/BMR/network/Server.cpp index 3ced867a9..c5263df89 100644 --- a/BMR/network/Server.cpp +++ b/BMR/network/Server.cpp @@ -13,6 +13,9 @@ #include "Server.h" +namespace BIU +{ + /* Opens server socket for listening - not yet accepting */ Server::Server(int port, int expected_clients, ServerUpdatable* updatable, unsigned int max_message_size) :starter(0), @@ -142,3 +145,5 @@ bool Server::_handle_recv_len(int id, size_t actual_len, size_t expected_len) { } return true; } + +} diff --git a/BMR/network/Server.h b/BMR/network/Server.h index 7928e6d73..a454686a1 100644 --- a/BMR/network/Server.h +++ b/BMR/network/Server.h @@ -20,6 +20,9 @@ class ServerUpdatable { virtual void NodeAborted(struct sockaddr_in* from) =0; }; +namespace BIU +{ + class Server { public: Server(int port, int expected_clients, ServerUpdatable* updatable, unsigned int max_message_size); @@ -48,6 +51,6 @@ class Server { bool _handle_recv_len(int id, size_t actual_len, size_t expected_len); }; - +} #endif /* NETWORK_INC_SERVER_H_ */ diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ed00dd14..904165cc9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enchancements are committed between releases and not documented here. +## 0.0.9 (Apr 30, 2019) + +- Complete BMR for all GF(2^n) protocols +- [Use your Brain!](https://eprint.iacr.org/2019/164) +- Semi/Semi2k for semi-honest OT-based computation +- Branching on revealed values in garbled circuits +- Fixed security bug: Potentially revealing too much information when opening linear combinations of private inputs in MASCOT and SPDZ2k with more than two parties + ## 0.0.8 (Mar 28, 2019) - SPDZ2k diff --git a/CONFIG b/CONFIG index b7b3f8bab..e386db27a 100644 --- a/CONFIG +++ b/CONFIG @@ -22,7 +22,8 @@ USE_GF2N_LONG = 1 # AVX2 support (Haswell or later) is used to optimize OT # AVX/AVX2 is required for replicated binary secret sharing # BMI2 is used to optimize multiplication modulo a prime -ARCH = -mtune=native -msse4.1 -maes -mpclmul -mavx -mavx2 -mbmi2 +# ADX is used to optimize big integer additions +ARCH = -mtune=native -msse4.1 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx # allow to set compiler in CONFIG.mine CXX = g++ @@ -50,7 +51,11 @@ ifeq ($(OS), Linux) LDLIBS += -lrt endif -BOOST = -lboost_system -lboost_thread $(MY_BOOST) +ifeq ($(OS), Darwin) +BOOST = -lboost_thread-mt $(MY_BOOST) +else +BOOST = -lboost_thread $(MY_BOOST) +endif CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -std=c++11 -Werror CPPFLAGS = $(CFLAGS) diff --git a/Check-Offline.cpp b/Check-Offline.cpp index f2b82fb57..77b9aa264 100644 --- a/Check-Offline.cpp +++ b/Check-Offline.cpp @@ -10,6 +10,7 @@ #include "Auth/MAC_Check.h" #include "Tools/ezOptionParser.h" #include "Exceptions/Exceptions.h" +#include "GC/MaliciousRepSecret.h" #include "Math/Setup.h" #include "Processor/Data_Files.h" diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index 8a3cf6aaa..8526cc62a 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -35,6 +35,7 @@ class ClearBitsAF(base.RegisterArgFormat): STMSDCI = 0x215, INPUTB = 0x216, PRINTREGSIGNED = 0x220, + CONVCBIT = 0x230, ) class xors(base.Instruction): @@ -153,6 +154,10 @@ class convcint(base.Instruction): code = opcodes['CONVCINT'] arg_format = ['cbw','ci'] +class convcbit(base.Instruction): + code = opcodes['CONVCBIT'] + arg_format = ['ciw','cb'] + class movs(base.Instruction): code = base.opcodes['MOVS'] arg_format = ['sbw','sb'] diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index d864426b0..731186616 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -176,6 +176,10 @@ def print_if(self, string): inst.cond_print_str(self, string) def reveal(self): return self + def to_regint(self, dest): + if self.n > 64: + raise CompilerError('too many bits') + inst.convcbit(dest, self) class sbits(bits): max_length = 128 diff --git a/Compiler/dijkstra.py b/Compiler/dijkstra.py index 684179947..1011cd589 100644 --- a/Compiler/dijkstra.py +++ b/Compiler/dijkstra.py @@ -4,6 +4,9 @@ ORAM = OptimalORAM +prog = program.Program.prog +prog.set_bit_length(min(64, prog.bit_length)) + class HeapEntry(object): fields = ['empty', 'prio', 'value'] def __init__(self, int_type, *args): diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 2c65dd0ca..ceb37d111 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -1318,7 +1318,11 @@ class convmodp(base.Instruction): code = base.opcodes['CONVMODP'] arg_format = ['ciw', 'c', 'int'] def __init__(self, *args, **kwargs): - bitlength = kwargs.get('bitlength', program.bit_length) + bitlength = kwargs.get('bitlength') + bitlength = program.bit_length if bitlength is None else bitlength + if bitlength > 64: + raise CompilerError('%d-bit conversion requested ' \ + 'but integer registers only have 64 bits') super(convmodp_class, self).__init__(*(args + (bitlength,))) @base.vectorize diff --git a/Compiler/oram.py b/Compiler/oram.py index 52e33aa8f..12786e4f2 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -375,7 +375,7 @@ def init_mem(self, empty_entry): print 'init ram' for a,value in zip(self.l, empty_entry.defaults.values()): # don't use threads if n_threads explicitly set to 1 - a.assign_all(value, n_threads != 1) + a.assign_all(value, n_threads != 1, conv=False) def get_empty_bits(self): return self.l[0] def get_indices(self): diff --git a/Compiler/path_oram.py b/Compiler/path_oram.py index f8ebd1e3f..e1265a0ce 100644 --- a/Compiler/path_oram.py +++ b/Compiler/path_oram.py @@ -7,6 +7,9 @@ #import pdb +prog = program.Program.prog +prog.set_bit_length(min(64, prog.bit_length)) + class Counter(object): def __init__(self, val=0, max_val=None, size=None, value_type=sgf2n): if value_type is sgf2n: diff --git a/Compiler/program.py b/Compiler/program.py index 3a7fc6176..6fdb23af0 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -139,7 +139,10 @@ def init_names(self, args, assemblymode): self.name is input file name (minus extension) + any optional arguments. Used to generate output filenames """ - self.name = progname + if self.options.outfile: + self.name = self.options.outfile + '-' + progname + else: + self.name = progname if len(args) > 1: self.name += '-' + '-'.join(args[1:]) self.progname = progname diff --git a/Compiler/types.py b/Compiler/types.py index 2960ba394..4c8db3d81 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -502,6 +502,11 @@ def load_int(self, val): elif chunk: sum += sign * chunk + def to_regint(self, n_bits=None, dest=None): + dest = regint() if dest is None else dest + convmodp(dest, self, bitlength=n_bits) + return dest + def __mod__(self, other): return self.clear_op(other, modc, modci) @@ -761,15 +766,13 @@ def load_int(self, val): @read_mem_value def load_other(self, val): - if isinstance(val, cint): - convmodp(self, val) - elif isinstance(val, cgf2n): + if isinstance(val, cgf2n): gconvgf2n(self, val) elif isinstance(val, regint): addint(self, val, regint(0)) else: try: - val.to_regint(self) + val.to_regint(dest=self) except AttributeError: raise CompilerError("Cannot convert '%s' to integer" % \ type(val)) @@ -872,7 +875,7 @@ def mod2m(self, *args, **kwargs): @vectorize def bit_decompose(self, bit_length=None): - bit_length = bit_length or program.bit_length + bit_length = bit_length or min(64, program.bit_length) if bit_length > 64: raise CompilerError('too many bits demanded') res = [regint() for i in range(bit_length)] @@ -2450,8 +2453,6 @@ class squant(_single): """ Quantization as in ArXiv:1712.05877v1 """ __slots__ = ['params'] int_type = sint - # cheaper probabilistic truncation - max_length = 63 clamp = True @classmethod @@ -2581,6 +2582,12 @@ def __init__(self, S, Z=0, k=8): self.Z = Z self.k = k self._store = {} + if program.options.ring: + # cheaper probabilistic truncation + self.max_length = int(program.options.ring) - 1 + else: + # safe choice for secret shift + self.max_length = 71 def __iter__(self): yield self.S @@ -2594,7 +2601,7 @@ def get(self, input_params, n_summands): p = input_params M = p[0].S * p[1].S / self.S logM = util.log2(M) - n_shift = squant.max_length - p[0].k - p[1].k - util.log2(n_summands) + n_shift = self.max_length - p[0].k - p[1].k - util.log2(n_summands) if util.is_constant_float(M): n_shift -= logM int_mult = int(round(M * 2 ** (n_shift))) @@ -2621,10 +2628,10 @@ def reduce(self, unreduced): n_shift = util.expand(n_shift, size) shifted_Z = util.expand(shifted_Z, size) tmp = unreduced.v * int_mult + shifted_Z - shifted = tmp.round(squant.max_length, n_shift, + shifted = tmp.round(self.max_length, n_shift, squant.kappa, squant.round_nearest) if squant.clamp: - length = max(self.k, squant.max_length - n_shift) + 1 + length = max(self.k, self.max_length - n_shift) + 1 top = (1 << self.k) - 1 over = shifted.greater_than(top, length, squant.kappa) under = shifted.less_than(0, length, squant.kappa) @@ -3122,8 +3129,10 @@ def loop(i): self[i] = j return self - def assign_all(self, value, use_threads=True): - mem_value = MemValue(self.value_type.conv(value)) + def assign_all(self, value, use_threads=True, conv=True): + if conv: + value = self.value_type.conv(value) + mem_value = MemValue(value) n_threads = 8 if use_threads and len(self) > 2**20 else 1 @library.for_range_multithread(n_threads, 1024, len(self)) def f(i): @@ -3265,7 +3274,7 @@ def _(i): @library.for_range(other.sizes[1]) def _(j): res_matrix[i][j] = 0 - @library.for_range(self.sizes[0]) + @library.for_range(self.sizes[1]) def _(k): res_matrix[i][j] += self[i][k] * other[k][j] return res_matrix diff --git a/Exceptions/Exceptions.h b/Exceptions/Exceptions.h index b0430f95a..58e22c4cc 100644 --- a/Exceptions/Exceptions.h +++ b/Exceptions/Exceptions.h @@ -202,6 +202,6 @@ class not_enough_to_buffer : public runtime_error { } }; - +class needs_cleaning : public exception {}; #endif diff --git a/Fake-Offline.cpp b/Fake-Offline.cpp index 68635352d..ed83e07a4 100644 --- a/Fake-Offline.cpp +++ b/Fake-Offline.cpp @@ -4,6 +4,9 @@ #include "Math/Share.h" #include "Math/Setup.h" #include "Math/Spdz2kShare.h" +#include "Math/BrainShare.h" +#include "Math/MaliciousRep3Share.h" +#include "Math/SemiShare.h" #include "Auth/fake-stuff.h" #include "Exceptions/Exceptions.h" #include "GC/MaliciousRepSecret.h" @@ -307,8 +310,11 @@ void make_basic(const typename T::mac_type& key, int nplayers, int nitems, bool make_bits(key, nplayers, nitems, zero); make_square_tuples(key, nplayers, nitems, T::type_short(), zero); make_inputs(key, nplayers, nitems, T::type_short(), zero); - make_inverse(key, nplayers, nitems, zero); - make_PreMulC(key, nplayers, nitems, zero); + if (T::clear::invertible) + { + make_inverse(key, nplayers, nitems, zero); + make_PreMulC(key, nplayers, nitems, zero); + } } template @@ -631,10 +637,15 @@ int generate(ez::ezOptionParser& opt) make_bits>({}, nplayers, nbitsp, zero); make_basic>({}, nplayers, default_num, zero); make_basic>({}, nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, zero); make_mult_triples({}, nplayers, ntrip2, zero); make_bits({}, nplayers, nbits2, zero); } + make_basic>({}, nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, zero); + return 0; } diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index 05952a723..cd85ef3da 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -29,6 +29,7 @@ class FakeSecret public: typedef FakeSecret DynamicType; + typedef Memory DynamicMemory; // dummy typedef DummyMC MC; @@ -59,6 +60,8 @@ class FakeSecret static void trans(Processor& processor, int n_inputs, const vector& args); + static void convcbit(Integer& dest, const Clear& source) { dest = source; } + static FakeSecret input(int from, GC::Processor& processor, int n_bits); static FakeSecret input(int from, const int128& input, int n_bits); @@ -88,7 +91,7 @@ class FakeSecret void random_bit() { a = random() % 2; } - void reveal(Clear& x) { x = a; } + void reveal(int n_bits, Clear& x) { (void) n_bits; x = a; } int size() { return -1; } }; diff --git a/GC/Instruction.h b/GC/Instruction.h index ffd9cf3be..b8e3d7958 100644 --- a/GC/Instruction.h +++ b/GC/Instruction.h @@ -55,7 +55,8 @@ class Instruction : public ::BaseInstruction // Execute this instruction bool exe(Processor& processor) const { return code(*this, processor); } - bool execute(Processor& processor) const; + template + bool execute(Processor& processor, U& dynamic_memory) const; }; enum @@ -86,6 +87,8 @@ enum INPUTB = 0x216, // don't write PRINTREGSIGNED = 0x220, + // write to regint + CONVCBIT = 0x230, }; } /* namespace GC */ diff --git a/GC/Instruction.hpp b/GC/Instruction.hpp index 0f867327d..9e8f297d9 100644 --- a/GC/Instruction.hpp +++ b/GC/Instruction.hpp @@ -8,25 +8,15 @@ #include "GC/Instruction.h" #include "GC/Processor.h" -#ifdef MAX_INLINE -#include "GC/Secret_inline.h" -#endif #include "Processor/Instruction.h" -#include "Secret.h" #include "Tools/parse.h" #include "GC/Instruction_inline.h" -#include "GC/ReplicatedSecret.h" namespace GC { -#define X(NAME, CODE) template bool NAME##_code(const Instruction& instruction, \ - Processor& processor) { (void)instruction; (void)processor; CODE; return true; } - INSTRUCTIONS -#undef X - template Instruction::Instruction() : BaseInstruction() @@ -66,6 +56,13 @@ int Instruction::get_reg_type() const switch (opcode) { case LDMC: + case STMC: + case XORC: + case ADDC: + case ADDCI: + case MULCI: + case SHRCI: + case SHLCI: return CBIT; } return SBIT; @@ -99,6 +96,8 @@ unsigned GC::Instruction::get_max_reg(int reg_type) const skip = 3; offset = 2; break; + case CONVCBIT: + return BaseInstruction::get_max_reg(INT); default: return BaseInstruction::get_max_reg(reg_type); } @@ -189,6 +188,7 @@ void Instruction::parse(istream& s, int pos) get_vector(m, start, s); break; case CONVCINT: + case CONVCBIT: get_ints(r, s, 2); break; case REVEAL: @@ -227,7 +227,6 @@ void Instruction::parse(istream& s, int pos) switch(opcode) { #define X(NAME, CODE) case NAME: \ - code = NAME##_code; \ break; INSTRUCTIONS #undef X diff --git a/GC/Instruction_inline.h b/GC/Instruction_inline.h index 9f122960b..1ab2735b5 100644 --- a/GC/Instruction_inline.h +++ b/GC/Instruction_inline.h @@ -28,7 +28,9 @@ inline bool fallback_code(const Instruction& instruction, Processor& proce } template -MAYBE_INLINE bool Instruction::execute(Processor& processor) const +template +MAYBE_INLINE bool Instruction::execute(Processor& processor, + U& dynamic_memory) const { #ifdef DEBUG_OPS cout << typeid(T).name() << " "; diff --git a/GC/Machine.cpp b/GC/Machine.cpp index 384670db9..48d360fcd 100644 --- a/GC/Machine.cpp +++ b/GC/Machine.cpp @@ -16,6 +16,7 @@ #include "Processor/Machine.hpp" #include "Processor/Instruction.hpp" +#include "Auth/MaliciousRepMC.hpp" namespace GC { diff --git a/GC/Machine.h b/GC/Machine.h index 4244128fc..88c62c54b 100644 --- a/GC/Machine.h +++ b/GC/Machine.h @@ -27,20 +27,22 @@ class Machine : public ::BaseMachine Memory MS; Memory MC; Memory MI; - Memory& MD; vector > progs; bool use_encryption; bool more_comm_less_comp; - Machine(Memory& MD); + Machine(); ~Machine(); void load_schedule(string progname); void load_program(string threadname, string filename); - void reset(const Program& program); + template + void reset(const U& program); + template + void reset(const U& program, V& dynamic_memory); void start_timer() { timer[0].start(); } void stop_timer() { timer[0].stop(); } @@ -48,6 +50,8 @@ class Machine : public ::BaseMachine void run_tape(int thread_number, int tape_number, int arg); void join_tape(int thread_numer); + + void write_memory(int my_num); }; } /* namespace GC */ diff --git a/GC/Machine.hpp b/GC/Machine.hpp index 78ce078ee..2c5b18752 100644 --- a/GC/Machine.hpp +++ b/GC/Machine.hpp @@ -6,15 +6,13 @@ #include #include "GC/Program.h" -#include "Secret.h" -#include "ReplicatedSecret.h" #include "ThreadMaster.h" namespace GC { template -Machine::Machine(Memory& dynamic_memory) : MD(dynamic_memory) +Machine::Machine() { use_encryption = false; more_comm_less_comp = false; @@ -55,12 +53,23 @@ void Machine::load_schedule(string progname) } template -void Machine::reset(const Program& program) +template +void Machine::reset(const U& program) { MS.resize_min(program.direct_mem(SBIT), "memory"); MC.resize_min(program.direct_mem(CBIT), "memory"); MI.resize_min(program.direct_mem(INT), "memory"); +} + +template +template +void Machine::reset(const U& program, V& MD) +{ + reset(program); MD.resize_min(program.direct_mem(DYN_SBIT), "dynamic memory"); +#ifdef DEBUG_MEMORY + cerr << "reset dynamic mem to " << program.direct_mem(DYN_SBIT) << endl; +#endif } template @@ -75,4 +84,13 @@ void Machine::join_tape(int thread_number) ThreadMaster::s().join_tape(thread_number); } +template +void GC::Machine::write_memory(int my_num) +{ + ofstream outf(memory_filename("B", my_num)); + outf << 0 << endl; + outf << MC.size() << endl << MC; + outf << 0 << endl << 0 << endl << 0 << endl << 0 << endl; +} + } /* namespace GC */ diff --git a/GC/MaliciousRepSecret.h b/GC/MaliciousRepSecret.h index 0fcb5811b..0ecca4d9b 100644 --- a/GC/MaliciousRepSecret.h +++ b/GC/MaliciousRepSecret.h @@ -20,7 +20,7 @@ class MaliciousRepSecret : public ReplicatedSecret typedef ReplicatedSecret super; public: - typedef MaliciousRepSecret DynamicType; + typedef Memory DynamicMemory; typedef MaliciousRepMC MC; diff --git a/GC/Memory.h b/GC/Memory.h index da0f4dac0..89f01d4b9 100644 --- a/GC/Memory.h +++ b/GC/Memory.h @@ -16,6 +16,10 @@ using namespace std; #include "Clear.h" #include "config.h" +class NoMemory +{ +}; + namespace GC { @@ -28,6 +32,7 @@ class Memory : public vector void check_index(Integer index) const; T& operator[] (Integer i); const T& operator[] (Integer i) const; + size_t capacity_in_bytes() const { return this->capacity() * sizeof(T); } template Memory& cast() { return *reinterpret_cast< Memory* >(this); } @@ -84,6 +89,14 @@ inline void Memory::resize_min(size_t size, const char* name) resize(size, name); } +template +inline ostream& operator<<(ostream& s, const Memory& memory) +{ + for (auto& x : memory) + x.output(s, false); + return s; +} + } /* namespace GC */ #endif /* GC_MEMORY_H_ */ diff --git a/GC/Processor.h b/GC/Processor.h index 76d5b2345..f5b744336 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -56,8 +56,10 @@ class Processor : public ::ProcessorBase Processor(Machine& machine); ~Processor(); - void reset(const Program& program, int arg); - void reset(const Program& program); + template + void reset(const U& program, int arg); + template + void reset(const U& program); long long get_input(int n_bits, bool interactive = false); @@ -68,11 +70,16 @@ class Processor : public ::ProcessorBase void random_bit(T &x) { x.random_bit(); } - void load_dynamic_direct(const vector& args); - void store_dynamic_direct(const vector& args); - void load_dynamic_indirect(const vector& args); - void store_dynamic_indirect(const vector& args); - void store_clear_in_dynamic(const vector& args); + template + void load_dynamic_direct(const vector& args, U& dynamic_memory); + template + void store_dynamic_direct(const vector& args, U& dynamic_memory); + template + void load_dynamic_indirect(const vector& args, U& dynamic_memory); + template + void store_dynamic_indirect(const vector& args, U& dynamic_memory); + template + void store_clear_in_dynamic(const vector& args, U& dynamic_memory); void xors(const vector& args); void and_(const vector& args, bool repeat); diff --git a/GC/Processor.hpp b/GC/Processor.hpp index 4a7e270e3..871942df1 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -10,9 +10,7 @@ using namespace std; #include "GC/Program.h" -#include "Secret.h" #include "Access.h" -#include "ReplicatedSecret.h" namespace GC { @@ -33,7 +31,8 @@ Processor::~Processor() } template -void Processor::reset(const Program& program, int arg) +template +void Processor::reset(const U& program, int arg) { S.resize(program.num_reg(SBIT), "registers"); C.resize(program.num_reg(CBIT), "registers"); @@ -43,7 +42,8 @@ void Processor::reset(const Program& program, int arg) } template -void Processor::reset(const Program& program) +template +void Processor::reset(const U& program) { reset(program, 0); machine.reset(program); @@ -89,59 +89,69 @@ void Processor::bitdecint(const vector& regs, const Integer& x) } template -void GC::Processor::load_dynamic_direct(const vector& args) +template +void Processor::load_dynamic_direct(const vector& args, + U& dynamic_memory) { vector< ReadAccess > accesses; if (args.size() % 3 != 0) throw runtime_error("invalid number of arguments"); for (size_t i = 0; i < args.size(); i += 3) accesses.push_back({S[args[i]], args[i+1], args[i+2], complexity}); - T::load(accesses, machine.MD); + T::load(accesses, dynamic_memory); } template -void GC::Processor::load_dynamic_indirect(const vector& args) +template +void GC::Processor::load_dynamic_indirect(const vector& args, + U& dynamic_memory) { vector< ReadAccess > accesses; if (args.size() % 3 != 0) throw runtime_error("invalid number of arguments"); for (size_t i = 0; i < args.size(); i += 3) accesses.push_back({S[args[i]], C[args[i+1]], args[i+2], complexity}); - T::load(accesses, machine.MD); + T::load(accesses, dynamic_memory); } template -void GC::Processor::store_dynamic_direct(const vector& args) +template +void GC::Processor::store_dynamic_direct(const vector& args, + U& dynamic_memory) { vector< WriteAccess > accesses; if (args.size() % 2 != 0) throw runtime_error("invalid number of arguments"); for (size_t i = 0; i < args.size(); i += 2) accesses.push_back({args[i+1], S[args[i]]}); - T::store(machine.MD, accesses); + T::store(dynamic_memory, accesses); complexity += accesses.size() / 2 * T::default_length; } template -void GC::Processor::store_dynamic_indirect(const vector& args) +template +void GC::Processor::store_dynamic_indirect(const vector& args, + U& dynamic_memory) { vector< WriteAccess > accesses; if (args.size() % 2 != 0) throw runtime_error("invalid number of arguments"); for (size_t i = 0; i < args.size(); i += 2) accesses.push_back({C[args[i+1]], S[args[i]]}); - T::store(machine.MD, accesses); + T::store(dynamic_memory, accesses); complexity += accesses.size() / 2 * T::default_length; } template -void GC::Processor::store_clear_in_dynamic(const vector& args) +template +void GC::Processor::store_clear_in_dynamic(const vector& args, + U& dynamic_memory) { vector accesses; check_args(args, 2); for (size_t i = 0; i < args.size(); i += 2) accesses.push_back({C[args[i+1]], C[args[i]]}); - T::store_clear_in_dynamic(machine.MD, accesses); + T::store_clear_in_dynamic(dynamic_memory, accesses); } template diff --git a/GC/Program.h b/GC/Program.h index ed13a1455..53b0e3a08 100644 --- a/GC/Program.h +++ b/GC/Program.h @@ -18,6 +18,7 @@ enum BreakType { TIME_BREAK, DONE_BREAK, CAP_BREAK, + CLEANING_BREAK, }; template class Processor; @@ -59,9 +60,8 @@ class Program unsigned direct_mem(RegType reg_type) const { return max_mem[reg_type]; } - // Execute this program, updateing the processor and memory - // and streams pointing to the triples etc - BreakType execute(Processor& Proc, int PC = -1) const; + template + BreakType execute(Processor& Proc, U& dynamic_memory, int PC = -1) const; bool done(Processor& Proc) const { return Proc.PC >= p.size(); } diff --git a/GC/Program.hpp b/GC/Program.hpp index 2fa19dfa4..24ceee76f 100644 --- a/GC/Program.hpp +++ b/GC/Program.hpp @@ -5,16 +5,10 @@ #include -#include "Secret.h" -#include "ReplicatedSecret.h" #include "config.h" #include "Tools/callgrind.h" -#ifdef MAX_INLINE -#include "Instruction_inline.h" -#endif - namespace GC { @@ -98,8 +92,10 @@ void Program::print_offline_cost() const } template +template __attribute__((flatten)) -BreakType Program::execute(Processor& Proc, int PC) const +BreakType Program::execute(Processor& Proc, U& dynamic_memory, + int PC) const { if (PC != -1) Proc.PC = PC; @@ -122,13 +118,13 @@ BreakType Program::execute(Processor& Proc, int PC) const #ifdef COUNT_INSTRUCTIONS Proc.stats[p[Proc.PC].get_opcode()]++; #endif - p[Proc.PC++].execute(Proc); + p[Proc.PC++].execute(Proc, dynamic_memory); time++; #ifdef DEBUG_COMPLEXITY cout << "complexity at " << time << ": " << Proc.complexity << endl; #endif } - while (Proc.complexity < (1 << 20)); + while (Proc.complexity < (1 << 19)); Proc.time = time; #ifdef DEBUG_ROUNDS cout << "breaking at time " << Proc.time << endl; diff --git a/GC/ReplicatedParty.cpp b/GC/ReplicatedParty.cpp index 20b6353b6..72b209c0f 100644 --- a/GC/ReplicatedParty.cpp +++ b/GC/ReplicatedParty.cpp @@ -72,6 +72,8 @@ ReplicatedParty::ReplicatedParty(int argc, const char** argv) : this->run(); + this->machine.write_memory(this->N.my_num()); + if (server) delete server; } diff --git a/GC/ReplicatedSecret.cpp b/GC/ReplicatedSecret.cpp index 3bc6331af..b5232c2e8 100644 --- a/GC/ReplicatedSecret.cpp +++ b/GC/ReplicatedSecret.cpp @@ -264,8 +264,9 @@ void ReplicatedSecret::trans(Processor& processor, } template -void ReplicatedSecret::reveal(Clear& x) +void ReplicatedSecret::reveal(size_t n_bits, Clear& x) { + (void) n_bits; ReplicatedSecret share = *this; vector opened; auto& party = ReplicatedParty::s(); diff --git a/GC/ReplicatedSecret.h b/GC/ReplicatedSecret.h index 938fa06ec..6ae046e43 100644 --- a/GC/ReplicatedSecret.h +++ b/GC/ReplicatedSecret.h @@ -37,8 +37,6 @@ class ReplicatedSecret : public FixedVec typedef BitVec mac_type; typedef BitVec mac_key_type; - typedef void Inp; - typedef void PO; typedef ReplicatedBase Protocol; static string type_string() { return "replicated secret"; } @@ -63,6 +61,8 @@ class ReplicatedSecret : public FixedVec static void trans(Processor& processor, int n_outputs, const vector& args); + static void convcbit(Integer& dest, const Clear& source) { dest = source; } + static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); } static U input(int from, Processor& processor, int n_bits); @@ -87,7 +87,7 @@ class ReplicatedSecret : public FixedVec Thread& party, bool repeat); void finalize_andrs(vector& os, int n); - void reveal(Clear& x); + void reveal(size_t n_bits, Clear& x); void random_bit(); }; @@ -98,7 +98,7 @@ class SemiHonestRepSecret : public ReplicatedSecret typedef ReplicatedSecret super; public: - typedef SemiHonestRepSecret DynamicType; + typedef Memory DynamicMemory; typedef ReplicatedMC MC; diff --git a/GC/RuntimeBranching.h b/GC/RuntimeBranching.h new file mode 100644 index 000000000..d0e0743ea --- /dev/null +++ b/GC/RuntimeBranching.h @@ -0,0 +1,36 @@ +/* + * RuntimeBranching.h + * + */ + +#ifndef GC_RUNTIMEBRANCHING_H_ +#define GC_RUNTIMEBRANCHING_H_ + +namespace GC +{ + +class RuntimeBranching +{ + bool tainted; + +public: + RuntimeBranching() : tainted(false) + { + } + + void untaint() + { + bool was_tainted = tainted; + tainted = false; + if (was_tainted) + throw needs_cleaning(); + } + void taint() + { + tainted = true; + } +}; + +} /* namespace GC */ + +#endif /* GC_RUNTIMEBRANCHING_H_ */ diff --git a/GC/Secret.h b/GC/Secret.h index 2d111d1eb..002a47321 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -59,8 +59,6 @@ class Mask class SpdzShare : public Share { public: - void assign(const gf2n& value, const gf2n& mac_key, bool first_player) - { Share::assign(value, first_player ? 0 : 1, mac_key); } }; template class Processor; @@ -73,11 +71,7 @@ class Secret T& get_new_reg(); public: -#ifdef SPDZ_AUTH - typedef SpdzShare DynamicType; -#else - typedef AuthValue DynamicType; -#endif + typedef typename T::DynamicMemory DynamicMemory; // dummy typedef DummyMC MC; @@ -105,8 +99,10 @@ class Secret static Secret carryless_mult(const Secret& x, const Secret& y); static void output(T& reg); - static void load(vector< ReadAccess< Secret > >& accesses, const Memory& mem); - static void store(Memory& mem, vector< WriteAccess< Secret > >& accesses); + template + static void load(vector< ReadAccess< Secret > >& accesses, const U& mem); + template + static void store(U& mem, vector< WriteAccess< Secret > >& accesses); static void andrs(Processor< Secret >& processor, const vector& args) { T::andrs(processor, args); } @@ -117,6 +113,8 @@ class Secret static void trans(Processor >& processor, int n_inputs, const vector& args); + static void convcbit(Integer& dest, const Clear& source) { T::convcbit(dest, source); } + Secret(); Secret(const Integer& x) { *this = x; } @@ -143,7 +141,7 @@ class Secret void andrs(int n, const Secret& x, const Secret& y) { and_(n, x, y, true); } template - void reveal(U& x); + void reveal(size_t n_bits, U& x); int size() const { return registers.size(); } CheckVector& get_regs() { return registers; } diff --git a/GC/Secret.hpp b/GC/Secret.hpp index e468ed150..1c83b5fdf 100644 --- a/GC/Secret.hpp +++ b/GC/Secret.hpp @@ -102,7 +102,8 @@ void Secret::random_bit() } template -void Secret::store(Memory& mem, +template +void Secret::store(U& mem, vector > >& accesses) { T::store(mem, accesses); @@ -218,7 +219,8 @@ void Secret::load(int n, const Integer& x) } template -void Secret::load(vector > >& accesses, const Memory& mem) +template +void Secret::load(vector > >& accesses, const U& mem) { for (auto&& access : accesses) { @@ -293,12 +295,14 @@ void Secret::trans(Processor >& processor, int n_outputs, template template -void Secret::reveal(U& x) +void Secret::reveal(size_t n_bits, U& x) { #ifdef DEBUG_OUTPUT cout << "revealing " << this << " with min(" << 8 * sizeof(U) << "," << registers.size() << ") bits" << endl; #endif + if (n_bits > registers.size()) + throw out_of_range("not enough wires for revealing"); x = 0; for (unsigned int i = 0; i < min(8 * sizeof(U), registers.size()); i++) { diff --git a/GC/Thread.hpp b/GC/Thread.hpp index bd2c7f54d..7c1d0ef58 100644 --- a/GC/Thread.hpp +++ b/GC/Thread.hpp @@ -6,9 +6,6 @@ #include "Thread.h" #include "Program.h" -#include "ReplicatedSecret.h" -#include "Secret.h" - #include "Networking/CryptoPlayer.h" namespace GC @@ -76,7 +73,7 @@ void Thread::run() template void Thread::run(Program& program) { - while (program.execute(processor) != DONE_BREAK) + while (program.execute(processor, master.memory) != DONE_BREAK) ; } diff --git a/GC/ThreadMaster.h b/GC/ThreadMaster.h index 285c12a49..e198c0f5d 100644 --- a/GC/ThreadMaster.h +++ b/GC/ThreadMaster.h @@ -45,7 +45,7 @@ class ThreadMaster : public ThreadMasterBase Player* P; Machine machine; - Memory memory; + typename T::DynamicMemory memory; OnlineOptions& opts; diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index 330630fd9..8ce35fc99 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -6,9 +6,6 @@ #include "ThreadMaster.h" #include "Program.h" -#include "ReplicatedSecret.h" -#include "Secret.h" - #include "instructions.h" namespace GC @@ -28,7 +25,7 @@ ThreadMaster& ThreadMaster::s() template ThreadMaster::ThreadMaster(OnlineOptions& opts) : - P(0), machine(memory), opts(opts) + P(0), opts(opts) { if (singleton) throw runtime_error("there can only be one"); @@ -94,6 +91,7 @@ void ThreadMaster::run() { #define X(NAME, CODE) case NAME: cerr << it.second << " " #NAME << endl; break; INSTRUCTIONS +#undef X } for (auto it = stats.begin(); it != stats.end(); it++) diff --git a/GC/instructions.h b/GC/instructions.h index 36284deee..28c1b1f6e 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -11,6 +11,7 @@ #define PROC processor #define INST instruction #define MACH processor.machine +#define MD dynamic_memory #define R0 instruction.get_r(0) #define R1 instruction.get_r(1) @@ -61,17 +62,18 @@ X(STMSI, MSI = S0) \ X(LDMC, C0 = MMC) \ X(STMC, MMC = C0) \ - X(LDMSD, PROC.load_dynamic_direct(EXTRA)) \ - X(STMSD, PROC.store_dynamic_direct(EXTRA)) \ - X(LDMSDI, PROC.load_dynamic_indirect(EXTRA)) \ - X(STMSDI, PROC.store_dynamic_indirect(EXTRA)) \ - X(STMSDCI, PROC.store_clear_in_dynamic(EXTRA)) \ + X(LDMSD, PROC.load_dynamic_direct(EXTRA, MD)) \ + X(STMSD, PROC.store_dynamic_direct(EXTRA, MD)) \ + X(LDMSDI, PROC.load_dynamic_indirect(EXTRA, MD)) \ + X(STMSDI, PROC.store_dynamic_indirect(EXTRA, MD)) \ + X(STMSDCI, PROC.store_clear_in_dynamic(EXTRA, MD)) \ X(CONVSINT, S0.load(IMM, I1)) \ X(CONVCINT, C0 = I1) \ + X(CONVCBIT, T::convcbit(I0, C1)) \ X(MOVS, S0 = PS1) \ X(TRANS, T::trans(PROC, IMM, EXTRA)) \ X(BIT, PROC.random_bit(S0)) \ - X(REVEAL, PS1.reveal(C0)) \ + X(REVEAL, PS1.reveal(IMM, C0)) \ X(PRINTREG, PROC.print_reg(R0, IMM)) \ X(PRINTREGPLAIN, PROC.print_reg_plain(C0)) \ X(PRINTREGSIGNED, PROC.print_reg_signed(IMM, C0)) \ diff --git a/GC/square64.cpp b/GC/square64.cpp index a10351dc6..9c4774ed6 100644 --- a/GC/square64.cpp +++ b/GC/square64.cpp @@ -4,6 +4,7 @@ */ #include "square64.h" +#include "Tools/cpu_support.h" #include #include using namespace std; @@ -24,18 +25,23 @@ union matrix32x8 void transpose(square64& output, int x, int y) { #ifdef __AVX2__ - for (int j = 0; j < 8; j++) + if (cpu_has_avx2()) { - int row = _mm256_movemask_epi8(whole); - whole = _mm256_slli_epi64(whole, 1); + for (int j = 0; j < 8; j++) + { + int row = _mm256_movemask_epi8(whole); + whole = _mm256_slli_epi64(whole, 1); - // _mm_movemask_epi8 uses most significant bit, hence +7-j - output.halfrows[8*x+7-j][y] = row; + // _mm_movemask_epi8 uses most significant bit, hence +7-j + output.halfrows[8*x+7-j][y] = row; + } } -#else - (void) output, (void) x, (void) y; - throw runtime_error("need to compile with AVX2 support"); + else #endif + { + (void) output, (void) x, (void) y; + throw runtime_error("need AVX2 support"); + } } }; @@ -60,24 +66,29 @@ void zip(int chunk_size, __m256i& lows, __m256i& highs, const __m256i& a, const __m256i& b) { #ifdef __AVX2__ - switch (chunk_size) + if (cpu_has_avx2()) { - ZIP_CASE(8, lows, highs, a, b); - ZIP_CASE(16, lows, highs, a, b); - ZIP_CASE(32, lows, highs, a, b); - ZIP_CASE(64, lows, highs, a, b); - case 128: - lows = a; - highs = b; - swap(((__m128i*)&lows)[1], ((__m128i*)&highs)[0]); - break; - default: - throw invalid_argument("not supported"); + switch (chunk_size) + { + ZIP_CASE(8, lows, highs, a, b); + ZIP_CASE(16, lows, highs, a, b); + ZIP_CASE(32, lows, highs, a, b); + ZIP_CASE(64, lows, highs, a, b); + case 128: + lows = a; + highs = b; + swap(((__m128i*)&lows)[1], ((__m128i*)&highs)[0]); + break; + default: + throw invalid_argument("not supported"); + } } -#else - (void) chunk_size, (void) lows, (void) highs, (void) a, (void) b; - throw runtime_error("need to compile with AVX2 support"); + else #endif + { + (void) chunk_size, (void) lows, (void) highs, (void) a, (void) b; + throw runtime_error("need AVX2 support"); + } } void square64::transpose(int n_rows, int n_cols) diff --git a/Machines/Rep.cpp b/Machines/Rep.cpp index 5fc33ae53..f8bd44f31 100644 --- a/Machines/Rep.cpp +++ b/Machines/Rep.cpp @@ -3,10 +3,17 @@ * */ +#include "Math/MaliciousRep3Share.h" +#include "Math/BrainShare.h" +#include "Processor/BrainPrep.h" + #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" #include "Processor/Machine.hpp" +#include "Processor/BrainPrep.hpp" #include "Auth/MAC_Check.hpp" +#include "Auth/fake-stuff.hpp" +#include "Auth/MaliciousRepMC.hpp" template<> Preprocessing>* Preprocessing>::get_live_prep( @@ -23,28 +30,22 @@ Preprocessing>* Preprocessing>::get_live_prep( } template<> -Preprocessing>* Preprocessing>::get_live_prep( - SubProcessor>* proc, DataPositions& usage) -{ - return new ReplicatedRingPrep>(proc, usage); -} - -template<> -Preprocessing>* Preprocessing>::get_live_prep( - SubProcessor>* proc, DataPositions& usage) +Preprocessing>>* Preprocessing>>::get_live_prep( + SubProcessor>>* proc, DataPositions& usage) { - (void) proc; - return new MaliciousRepPrep>(proc, usage); + return new ReplicatedRingPrep>>(proc, usage); } template<> -Preprocessing>* Preprocessing>::get_live_prep( - SubProcessor>* proc, DataPositions& usage) +Preprocessing>>* Preprocessing>>::get_live_prep( + SubProcessor>>* proc, DataPositions& usage) { - (void) proc; - return new MaliciousRepPrep>(proc, usage); + return new ReplicatedRingPrep>>(proc, usage); } -template class Machine, Rep3Share>; +template class Machine>, Rep3Share>; +template class Machine>, Rep3Share>; template class Machine, Rep3Share>; template class Machine, MaliciousRep3Share>; +template class Machine, MaliciousRep3Share>; +template class Machine, MaliciousRep3Share>; diff --git a/Machines/SPDZ.cpp b/Machines/SPDZ.cpp index 0c4110860..fa6f94c03 100644 --- a/Machines/SPDZ.cpp +++ b/Machines/SPDZ.cpp @@ -4,40 +4,11 @@ #include "Processor/Instruction.hpp" #include "Processor/Machine.hpp" #include "Auth/MAC_Check.hpp" +#include "Auth/fake-stuff.hpp" #include "Processor/MascotPrep.hpp" #include "Processor/Spdz2kPrep.hpp" -#ifdef USE_GF2N_LONG -template<> -Preprocessing>* Preprocessing>::get_live_prep( - SubProcessor>* proc, DataPositions& usage) -{ - return new MascotFieldPrep>(proc, usage); -} - -template<> -Preprocessing>* Preprocessing>::get_live_prep( - SubProcessor>* proc, DataPositions& usage) -{ - return new MascotFieldPrep>(proc, usage); -} - -template<> -Preprocessing>* Preprocessing>::get_live_prep( - SubProcessor>* proc, DataPositions& usage) -{ - return new Spdz2kPrep>(proc, usage); -} - -template<> -Preprocessing>* Preprocessing>::get_live_prep( - SubProcessor>* proc, DataPositions& usage) -{ - return new Spdz2kPrep>(proc, usage); -} -#endif - template class Machine>; template class Machine, Share>; diff --git a/Machines/Semi.cpp b/Machines/Semi.cpp new file mode 100644 index 000000000..7b31093b2 --- /dev/null +++ b/Machines/Semi.cpp @@ -0,0 +1,25 @@ +/* + * Semi.cpp + * + */ + +#include "Math/SemiShare.h" +#include "Math/Semi2kShare.h" +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "Auth/SemiMC.h" +#include "Processor/SemiPrep.h" + +#include "Processor/Data_Files.hpp" +#include "Processor/Instruction.hpp" +#include "Processor/Machine.hpp" +#include "Processor/MascotPrep.hpp" +#include "Processor/SemiPrep.hpp" +#include "Processor/SemiInput.hpp" +#include "Auth/MAC_Check.hpp" +#include "Auth/fake-stuff.hpp" +#include "Auth/SemiMC.hpp" + +template class Machine, SemiShare>; +template class Machine, SemiShare>; +template class Machine, SemiShare>; diff --git a/Machines/ShamirMachine.cpp b/Machines/ShamirMachine.cpp index 5738f30df..8b286e82a 100644 --- a/Machines/ShamirMachine.cpp +++ b/Machines/ShamirMachine.cpp @@ -14,7 +14,12 @@ #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" #include "Processor/Machine.hpp" +#include "Processor/ShamirInput.hpp" +#include "Processor/Shamir.hpp" +#include "Auth/ShamirMC.hpp" +#include "Auth/MaliciousShamirMC.hpp" #include "Auth/MAC_Check.hpp" +#include "Auth/fake-stuff.hpp" ShamirMachine* ShamirMachine::singleton = 0; @@ -57,7 +62,9 @@ ShamirMachine::ShamirMachine(int argc, const char** argv) opt.get("-T")->getInt(threshold); else threshold = (nparties - 1) / 2; +#ifdef VERBOSE cerr << "Using threshold " << threshold << " out of " << nparties << endl; +#endif if (2 * threshold >= nparties) throw runtime_error("threshold too high"); if (threshold < 1) @@ -74,36 +81,6 @@ ShamirMachineSpec::ShamirMachineSpec(int argc, const char** argv) : ReplicatedMachine, T>(argc, argv, "shamir", opt, nparties); } -template<> -Preprocessing>* Preprocessing>::get_live_prep( - SubProcessor>* proc, DataPositions& usage) -{ - return new ReplicatedPrep>(proc, usage); -} - -template<> -Preprocessing>* Preprocessing>::get_live_prep( - SubProcessor>* proc, DataPositions& usage) -{ - return new ReplicatedPrep>(proc, usage); -} - -template<> -Preprocessing>* Preprocessing>::get_live_prep( - SubProcessor>* proc, DataPositions& usage) -{ - (void) proc; - return new MaliciousRepPrep>(proc, usage); -} - -template<> -Preprocessing>* Preprocessing>::get_live_prep( - SubProcessor>* proc, DataPositions& usage) -{ - (void) proc; - return new MaliciousRepPrep>(proc, usage); -} - template class ShamirMachineSpec; template class ShamirMachineSpec; diff --git a/Makefile b/Makefile index 27fdb5f7f..73a855f13 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,7 @@ endif COMMON = $(MATH) $(TOOLS) $(NETWORK) $(AUTH) COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT) YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) $(GC) BMR/Key.o -BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp)) $(COMMON) Processor/BaseMachine.o Processor/ProcessorBase.o +BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp)) $(COMMON) $(PROCESSOR) $(OT) LIB = libSPDZ.a @@ -40,7 +40,7 @@ OBJS = $(BMR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(YAO) $(COMPLETE) $(patsubst %.cpp DEPS := $(OBJS:.o=.d) -all: gen_input online offline externalIO yao replicated shamir spdz2k +all: gen_input online offline externalIO yao replicated shamir spdz2k real-bmr brain-party.x semi-party.x semi2k-party.x ifeq ($(USE_GF2N_LONG),1) ifneq ($(OS), Darwin) @@ -67,6 +67,8 @@ externalIO: client-setup.x bankers-bonus-client.x bankers-bonus-commsec-client.x bmr: bmr-program-party.x bmr-program-tparty.x +real-bmr: $(patsubst %.cpp,%.x,$(wildcard *-bmr-party.cpp)) + yao: yao-player.x she-offline: Check-Offline.x spdz2-offline.x @@ -79,7 +81,7 @@ rep-ring: replicated-ring-party.x Fake-Offline.x rep-bin: replicated-bin-party.x malicious-rep-bin-party.x Fake-Offline.x -replicated: rep-field rep-ring rep-bin +replicated: rep-field rep-ring rep-bin brain-party.x spdz2k: spdz2k-party.x ot-offline.x Check-Offline-Z2k.x galois-degree.x Fake-Offline.x @@ -95,7 +97,7 @@ endif shamir: shamir-party.x malicious-shamir-party.x galois-degree.x -$(LIBRELEASE): Machines/Rep.o Machines/ShamirMachine.o Machines/SPDZ.o $(YAO) $(PROCESSOR) $(COMMON) +$(LIBRELEASE): Machines/Rep.o Machines/Semi.o Machines/ShamirMachine.o Machines/SPDZ.o $(YAO) $(PROCESSOR) $(COMMON) $(BMR) $(AR) -csr $@ $^ static/%.x: %.cpp $(LIBRELEASE) $(LIBSIMPLEOT) @@ -104,7 +106,7 @@ static/%.x: %.cpp $(LIBRELEASE) $(LIBSIMPLEOT) static-dir: @ mkdir static 2> /dev/null; true -static-release: static-dir $(patsubst %.cpp, static/%.x, $(wildcard *ring*.cpp *field*.cpp *shamir*.cpp yao*.cpp spdz2k*.cpp Player-Online.cpp )) +static-release: static-dir $(patsubst %.cpp, static/%.x, $(wildcard *ring*.cpp *field*.cpp *shamir*.cpp yao*.cpp spdz2k*.cpp Player-Online.cpp *-bmr-*.cpp brain-party.cpp semi*.cpp replicated-bin-party.cpp )) Fake-Offline.x: Fake-Offline.cpp $(COMMON) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) @@ -151,11 +153,14 @@ gc-emulate.x: $(GC) $(COMMON) $(PROCESSOR) gc-emulate.cpp $(GC) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) ifeq ($(USE_GF2N_LONG),1) -bmr-program-party.x: $(BMR) bmr-program-party.cpp +bmr-program-party.x: $(BMR) bmr-program-party.cpp $(LIBSIMPLEOT) $(CXX) $(CFLAGS) -o $@ $^ $(BOOST) $(LDLIBS) -bmr-program-tparty.x: $(BMR) bmr-program-tparty.cpp +bmr-program-tparty.x: $(BMR) bmr-program-tparty.cpp $(LIBSIMPLEOT) $(CXX) $(CFLAGS) -o $@ $^ $(BOOST) $(LDLIBS) + +%-bmr-party.x: %-bmr-party.cpp $(wildcard BMR/*) $(BMR) $(LIBSIMPLEOT) + $(CXX) $(CFLAGS) -o $@ $< $(BMR) $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) endif bmr-clean: @@ -208,6 +213,9 @@ replicated-field-party.x: replicated-field-party.cpp Machines/Rep.o $(PROCESSOR) malicious-rep-field-party.x: malicious-rep-field-party.cpp Machines/Rep.o $(PROCESSOR) $(COMMON) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) +brain-party.x: brain-party.cpp Machines/Rep.o $(PROCESSOR) $(COMMON) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + shamir-party.x: shamir-party.cpp Machines/ShamirMachine.o $(PROCESSOR) $(COMMON) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) @@ -217,6 +225,12 @@ malicious-shamir-party.x: malicious-shamir-party.cpp Machines/ShamirMachine.o $( spdz2k-party.x: spdz2k-party.cpp Machines/SPDZ.o $(PROCESSOR) $(COMMON) $(OT) $(LIBSIMPLEOT) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) +semi-party.x: semi-party.cpp Machines/Semi.o $(PROCESSOR) $(COMMON) $(OT) $(LIBSIMPLEOT) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + +semi2k-party.x: semi2k-party.cpp Machines/Semi.o $(PROCESSOR) $(COMMON) $(OT) $(LIBSIMPLEOT) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + $(LIBSIMPLEOT): SimpleOT/Makefile $(MAKE) -C SimpleOT diff --git a/Math/BrainShare.h b/Math/BrainShare.h new file mode 100644 index 000000000..6f37de889 --- /dev/null +++ b/Math/BrainShare.h @@ -0,0 +1,49 @@ +/* + * BrainShare.h + * + */ + +#ifndef MATH_BRAINSHARE_H_ +#define MATH_BRAINSHARE_H_ + +#include "Rep3Share.h" + +template class HashMaliciousRepMC; +template class Beaver; +template class BrainPrep; + +template +class BrainShare : public Rep3Share> +{ + typedef SignedZ2 T; + typedef Rep3Share super; + +public: + typedef T clear; + + typedef Beaver Protocol; + typedef HashMaliciousRepMC MAC_Check; + typedef MAC_Check Direct_MC; + typedef ReplicatedInput Input; + typedef ReplicatedPrivateOutput PrivateOutput; + typedef BrainPrep LivePrep; + + const static int N_MASK_BITS = clear::N_BITS + S; + const static int Z_BITS = 2 * (N_MASK_BITS) + 5 + S; + + BrainShare() + { + } + template + BrainShare(const FixedVec& other) + { + FixedVec::operator=(other); + } + template + BrainShare(const U& other, int my_num = 0, T alphai = {}) : super(other) + { + (void) my_num, (void) alphai; + } +}; + +#endif /* MATH_BRAINSHARE_H_ */ diff --git a/Math/FixedVec.h b/Math/FixedVec.h index f4f981472..5214d58c7 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -55,6 +55,13 @@ class FixedVec { } + template + FixedVec(const FixedVec& other) + { + for (int i = 0; i < L; i++) + v[i] = other[i]; + } + T& operator[](int i) { return v[i]; @@ -99,7 +106,7 @@ class FixedVec void mul(const FixedVec& x, const FixedVec& y) { for (int i = 0; i < L; i++) - v[i] = x.v[i] * y.v[i]; + v[i].mul(x.v[i], y.v[i]); } void add(const FixedVec& x) @@ -181,6 +188,12 @@ class FixedVec return *this; } + FixedVec& operator*=(const FixedVec& other) + { + *this = *this * other; + return *this; + } + FixedVec& operator/=(const FixedVec& other) { *this = *this / other; @@ -263,6 +276,12 @@ class FixedVec v[0] = sum - s; } + void force_to_bit() + { + for (auto& x : v) + x.force_to_bit(); + } + void output(ostream& s, bool human) const { for (auto& x : v) diff --git a/Math/Integer.cpp b/Math/Integer.cpp index 8cf9a6d0e..21ba9a302 100644 --- a/Math/Integer.cpp +++ b/Math/Integer.cpp @@ -21,17 +21,6 @@ void IntBase::input(istream& s,bool human) s.read((char*)&a, sizeof(a)); } -void to_signed_bigint(bigint& res, const Integer& x, int n) -{ - res = abs(x.get()); - bigint& tmp = bigint::tmp = 1; - tmp <<= n; - tmp -= 1; - res &= tmp; - if (x < 0) - res.negate(); -} - void Integer::reqbl(int n) { if ((int)n < 0 && size() * 8 != -(int)n) diff --git a/Math/Integer.h b/Math/Integer.h index 84689669a..ed97889f7 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -24,6 +24,8 @@ class IntBase : public ValueInterface long a; public: + static const int N_BITS = 8 * sizeof(a); + static int size() { return sizeof(a); } static string type_string() { return "integer"; } @@ -92,7 +94,7 @@ class Integer : public IntBase Integer() { a = 0; } Integer(long a) : IntBase(a) {} - Integer(const bigint& x) { *this = x.get_si(); } + Integer(const bigint& x) { *this = (x > 0) ? x.get_ui() : -x.get_ui(); } template Integer(const Z2& x) : Integer(x.get_limb(0)) {} @@ -146,8 +148,6 @@ inline void to_signed_bigint(bigint& res, const Integer& x) res = x.get(); } -void to_signed_bigint(bigint& res, const Integer& x, int n); - // slight misnomer inline void to_gfp(Integer& res, const bigint& x) { diff --git a/Math/MaliciousRep3Share.h b/Math/MaliciousRep3Share.h index 4ade5c6f7..eb8e6ddff 100644 --- a/Math/MaliciousRep3Share.h +++ b/Math/MaliciousRep3Share.h @@ -11,6 +11,7 @@ template class HashMaliciousRepMC; template class Beaver; +template class MaliciousRepPrep; template class MaliciousRep3Share : public Rep3Share @@ -24,15 +25,20 @@ class MaliciousRep3Share : public Rep3Share typedef ReplicatedInput> Input; typedef ReplicatedPrivateOutput> PrivateOutput; typedef Rep3Share Honest; + typedef MaliciousRepPrep LivePrep; static string type_short() { - return "M" + string(1, gfp::type_char()); + return "M" + string(1, T::type_char()); } MaliciousRep3Share() { } + MaliciousRep3Share(const T& other, int my_num, T alphai = {}) : + super(other, my_num, alphai) + { + } template MaliciousRep3Share(const U& other) : super(other) { diff --git a/Math/MaliciousShamirShare.h b/Math/MaliciousShamirShare.h index 85995ba02..501371d84 100644 --- a/Math/MaliciousShamirShare.h +++ b/Math/MaliciousShamirShare.h @@ -10,6 +10,8 @@ #include "Processor/Beaver.h" #include "Auth/MaliciousShamirMC.h" +template class MaliciousRepPrep; + template class MaliciousShamirShare : public ShamirShare { @@ -22,6 +24,7 @@ class MaliciousShamirShare : public ShamirShare typedef ShamirInput Input; typedef ReplicatedPrivateOutput PrivateOutput; typedef ShamirShare Honest; + typedef MaliciousRepPrep LivePrep; static string type_short() { @@ -32,8 +35,9 @@ class MaliciousShamirShare : public ShamirShare { } template - MaliciousShamirShare(const U& other) : super(other) + MaliciousShamirShare(const U& other, int my_num = 0, T alphai = {}) : super(other) { + (void) my_num, (void) alphai; } }; diff --git a/Math/Rep3Share.h b/Math/Rep3Share.h index 4249c7ced..15163b0d8 100644 --- a/Math/Rep3Share.h +++ b/Math/Rep3Share.h @@ -39,13 +39,15 @@ class Rep3Share : public FixedVec Rep3Share() { } - Rep3Share(const FixedVec& other) + template + Rep3Share(const FixedVec& other) { FixedVec::operator=(other); } - Rep3Share(T value, int my_num) + Rep3Share(T value, int my_num, const T& alphai = {}) { + (void) alphai; Replicated::assign(*this, value, my_num); } @@ -90,7 +92,10 @@ class Rep3Share : public FixedVec clear local_mul(const Rep3Share& other) const { - return (*this)[0] * other.sum() + (*this)[1] * other[0]; + T a, b; + a.mul((*this)[0], other.sum()); + b.mul((*this)[1], other[0]); + return a + b; } void mul_by_bit(const Rep3Share& x, const T& y) diff --git a/Math/Semi2kShare.h b/Math/Semi2kShare.h new file mode 100644 index 000000000..9d584764c --- /dev/null +++ b/Math/Semi2kShare.h @@ -0,0 +1,46 @@ +/* + * Semi2kShare.h + * + */ + +#ifndef MATH_SEMI2KSHARE_H_ +#define MATH_SEMI2KSHARE_H_ + +#include "SemiShare.h" +#include "OT/Rectangle.h" + +template +class Semi2kShare : public SemiShare> +{ + typedef SignedZ2 T; + +public: + typedef Z2<64> mac_key_type; + + typedef SemiMC MAC_Check; + typedef MAC_Check Direct_MC; + typedef SemiInput Input; + typedef ::PrivateOutput PrivateOutput; + typedef SPDZ Protocol; + typedef SemiPrep LivePrep; + + typedef Semi2kShare prep_type; + typedef SemiMultiplier Multiplier; + typedef OTTripleGenerator TripleGenerator; + typedef Z2kSquare Rectangle; + + Semi2kShare() + { + } + template + Semi2kShare(const U& other) : SemiShare>(other) + { + } + Semi2kShare(const T& other, int my_num, const T& alphai = {}) + { + (void) alphai; + assign(other, my_num); + } +}; + +#endif /* MATH_SEMI2KSHARE_H_ */ diff --git a/Math/SemiShare.h b/Math/SemiShare.h new file mode 100644 index 000000000..e46ac4480 --- /dev/null +++ b/Math/SemiShare.h @@ -0,0 +1,113 @@ +/* + * SemiShare.h + * + */ + +#ifndef MATH_SEMISHARE_H_ +#define MATH_SEMISHARE_H_ + +#include "ValueInterface.h" +#include "Processor/Beaver.h" +#include "Processor/DummyProtocol.h" +#include "Processor/NoLivePrep.h" + +#include +using namespace std; + +template class Input; +template class SemiMC; +template class SPDZ; +template class SemiPrep; +template class SemiInput; +template class SemiMultiplier; +template class OTTripleGenerator; + +template +class SemiShare : public T +{ + typedef T super; + +public: + typedef T mac_key_type; + typedef T mac_type; + typedef T open_type; + typedef T clear; + + typedef SemiMC MAC_Check; + typedef MAC_Check Direct_MC; + typedef SemiInput Input; + typedef ::PrivateOutput PrivateOutput; + typedef SPDZ Protocol; + typedef SemiPrep LivePrep; + + typedef SemiShare prep_type; + typedef SemiMultiplier Multiplier; + typedef OTTripleGenerator TripleGenerator; + typedef T sacri_type; + typedef square128 Rectangle; + + const static bool needs_ot = true; + + static string type_short() { return "D" + string(1, T::type_char()); } + + SemiShare() + { + } + template + SemiShare(const U& other) : T(other) + { + } + SemiShare(const clear& other, int my_num, const T& alphai = {}) + { + (void) alphai; + assign(other, my_num); + } + + void assign(const clear& other, int my_num, const T& alphai = {}) + { + (void) alphai; + Protocol::assign(*this, other, my_num); + } + void assign(const char* buffer) + { + super::assign(buffer); + } + + void add(const SemiShare& x, const SemiShare& y) + { + *this = x + y; + } + void sub(const SemiShare& x, const SemiShare& y) + { + *this = x - y; + } + + void add(const SemiShare& S, const clear aa, int my_num, const T& alphai) + { + (void) alphai; + *this = S + SemiShare(aa, my_num); + } + void sub(const SemiShare& S, const clear& aa, int my_num, const T& alphai) + { + (void) alphai; + *this = S - SemiShare(aa, my_num); + } + void sub(const clear& aa, const SemiShare& S, int my_num, const T& alphai) + { + (void) alphai; + *this = SemiShare(aa, my_num) - S; + } + + void pack(octetStream& os, bool full = true) const + { + (void)full; + super::pack(os); + } + void unpack(octetStream& os, bool full = true) + { + (void)full; + super::unpack(os); + } +}; + +#endif /* MATH_SEMISHARE_H_ */ diff --git a/Math/ShamirShare.h b/Math/ShamirShare.h index c191ac41c..34e64dcce 100644 --- a/Math/ShamirShare.h +++ b/Math/ShamirShare.h @@ -11,6 +11,8 @@ #include "Processor/Shamir.h" #include "Processor/ShamirInput.h" +template class ReplicatedPrep; + template class ShamirShare : public T { @@ -24,6 +26,7 @@ class ShamirShare : public T typedef MAC_Check Direct_MC; typedef ShamirInput Input; typedef ReplicatedPrivateOutput PrivateOutput; + typedef ReplicatedPrep LivePrep; const static bool needs_ot = false; @@ -45,9 +48,9 @@ class ShamirShare : public T T::operator=(other); } template - ShamirShare(const U& other, int my_num) : ShamirShare(other) + ShamirShare(const U& other, int my_num, T alphai = {}) : ShamirShare(other) { - (void) my_num; + (void) my_num, (void) alphai; } // Share compatibility @@ -89,6 +92,21 @@ class ShamirShare : public T *this = aa - S; } + ShamirShare operator<<(int i) + { + return *this * (T(1) << i); + } + ShamirShare& operator<<=(int i) + { + *this = *this << i; + return *this; + } + + void force_to_bit() + { + throw not_implemented(); + } + void pack(octetStream& os, bool full = true) const { (void)full; diff --git a/Math/Share.h b/Math/Share.h index 85266e868..aa02949d3 100644 --- a/Math/Share.h +++ b/Math/Share.h @@ -23,6 +23,8 @@ template bool check_macs(const vector< Share >& S,const T& key); template class MAC_Check_; template class Direct_MAC_Check; template class MascotMultiplier; +template class MascotFieldPrep; +template class NPartyTripleGenerator; union square128; @@ -37,18 +39,20 @@ class Share typedef T mac_key_type; typedef T mac_type; typedef T open_type; - typedef typename T::value_type clear; + typedef T clear; typedef Share prep_type; typedef MascotMultiplier Multiplier; + typedef NPartyTripleGenerator TripleGenerator; typedef T sacri_type; typedef square128 Rectangle; typedef MAC_Check_ MAC_Check; typedef Direct_MAC_Check Direct_MC; typedef ::Input Input; - typedef typename T::PO PrivateOutput; + typedef ::PrivateOutput PrivateOutput; typedef SPDZ Protocol; + typedef MascotFieldPrep LivePrep; const static bool needs_ot = true; @@ -127,6 +131,8 @@ class Share Share operator<<(int i) { return this->operator*(T(1) << i); } Share& operator<<=(int i) { return *this = *this << i; } + void force_to_bit() { a.force_to_bit(); } + // Input and output from a stream // - Can do in human or machine only format (later should be faster) void output(ostream& s,bool human) const diff --git a/Math/Spdz2kShare.h b/Math/Spdz2kShare.h index f459aaca2..c214ace07 100644 --- a/Math/Spdz2kShare.h +++ b/Math/Spdz2kShare.h @@ -21,8 +21,8 @@ class Spdz2kShare : public Share> typedef Z2 tmp_type; typedef Share super; - typedef Integer clear; -// typedef Z2 clear; +// typedef Integer clear; + typedef SignedZ2 clear; typedef Z2 mac_key_type; typedef Z2 mac_type; @@ -30,6 +30,7 @@ class Spdz2kShare : public Share> typedef Spdz2kShare prep_type; typedef Spdz2kMultiplier Multiplier; + typedef NPartyTripleGenerator TripleGenerator; typedef Z2 sacri_type; typedef Z2kRectangle Rectangle; @@ -38,6 +39,7 @@ class Spdz2kShare : public Share> typedef ::Input Input; typedef NotImplementedOutput PrivateOutput; typedef SPDZ Protocol; + typedef Spdz2kPrep LivePrep; const static int k = K; const static int s = S; @@ -48,6 +50,10 @@ class Spdz2kShare : public Share> Spdz2kShare() {} template Spdz2kShare(const Share& x) : super(x) {} + Spdz2kShare(const clear& x, int my_num, const mac_key_type& alphai) : + super(x, my_num, alphai) + { + } }; diff --git a/Math/Z2k.cpp b/Math/Z2k.cpp index ea425a6cb..a7f480d20 100644 --- a/Math/Z2k.cpp +++ b/Math/Z2k.cpp @@ -11,6 +11,28 @@ const int Z2::N_BITS; template const int Z2::N_BYTES; +template +void Z2::reqbl(int n) +{ + if (n < 0 && N_BITS != -(int)n) + { + throw Processor_Error( + "Program compiled for rings of length " + to_string(-n) + + " but VM supports only " + + to_string(N_BITS)); + } + else if (n > 0) + { + throw Processor_Error("Program compiled for fields not rings"); + } +} + +template +bool Z2::allows(Dtype dtype) +{ + return Integer::allows(dtype); +} + template Z2::Z2(const bigint& x) : Z2() { @@ -73,6 +95,24 @@ Z2 Z2::sqrRoot() return res; } +template +void Z2::AND(const Z2& x, const Z2& y) +{ + mpn_and_n(a, x.a, y.a, N_WORDS); +} + +template +void Z2::OR(const Z2& x, const Z2& y) +{ + mpn_ior_n(a, x.a, y.a, N_WORDS); +} + +template +void Z2::XOR(const Z2& x, const Z2& y) +{ + mpn_xor_n(a, x.a, y.a, N_WORDS); +} + template void Z2::input(istream& s, bool human) { @@ -113,3 +153,6 @@ X(48) X(112) X(208) X(114) X(130) X(162) X(194) X(324) X(388) X(66) X(210) X(258) +X(72) +X(106) +X(104) X(144) X(253) X(255) X(269) X(271) diff --git a/Math/Z2k.h b/Math/Z2k.h index 536523f63..6f8e4842c 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -19,6 +19,7 @@ using namespace std; template class Z2 : public ValueInterface { +protected: template friend class Z2; friend class bigint; @@ -26,8 +27,7 @@ class Z2 : public ValueInterface static const int N_WORDS = ((K + 7) / 8 + sizeof(mp_limb_t) - 1) / sizeof(mp_limb_t); static const int N_LIMB_BITS = 8 * sizeof(mp_limb_t); - static const uint64_t UPPER_MASK = - ((K % N_LIMB_BITS) == 0) ? -1 : -1 + (1LL << (K % N_LIMB_BITS)); + static const uint64_t UPPER_MASK = uint64_t(-1LL) >> (N_LIMB_BITS - 1 - (K - 1) % N_LIMB_BITS); mp_limb_t a[N_WORDS]; @@ -51,9 +51,10 @@ class Z2 : public ValueInterface static const int N_BYTES = (K + 7) / 8; static int size() { return N_BYTES; } + static int size_in_limbs() { return N_WORDS; } static int t() { return 0; } - static char type_char() { return 'Z'; } + static char type_char() { return 'R'; } static string type_string() { return "Z2^" + to_string(int(N_BITS)); } static DataFieldType field_type() { return DATA_INT; } @@ -63,7 +64,9 @@ class Z2 : public ValueInterface template static Z2 Mul(const Z2& x, const Z2& y); - typedef Z2 value_type; + static void reqbl(int n); + static bool allows(Dtype dtype); + typedef Z2 next; Z2() { assign_zero(); } @@ -89,6 +92,9 @@ class Z2 : public ValueInterface bool get_bit(int i) const; const void* get_ptr() const { return a; } + const mp_limb_t* get() const { return a; } + + void convert_destroy(bigint& a) { *this = a; } void negate() { throw not_implemented(); @@ -122,18 +128,32 @@ class Z2 : public ValueInterface template void mul(const Integer& a, const Z2& b) { *this = Z2::Mul(Z2<64>(a), b); } + void mul(const Z2& a) { *this = Z2::Mul(*this, a); } + template void add(octetStream& os) { add(os.consume(size())); } Z2& invert(); + void invert(const Z2& a) { *this = a; invert(); } Z2 sqrRoot(); - bool is_zero() { return *this == Z2(); } + bool is_zero() const { return *this == Z2(); } + bool is_one() const { return *this == 1; } + bool is_bit() const { return is_zero() or is_one(); } + + void SHL(const Z2& a, const bigint& i) { *this = a << i.get_ui(); } + void SHR(const Z2& a, const bigint& i) { *this = a >> i.get_ui(); } + + void AND(const Z2& a, const Z2& b); + void OR(const Z2& a, const Z2& b); + void XOR(const Z2& a, const Z2& b); void randomize(PRNG& G); void almost_randomize(PRNG& G) { randomize(G); } + void force_to_bit() { throw runtime_error("impossible"); } + void pack(octetStream& o) const; void unpack(octetStream& o); @@ -144,6 +164,53 @@ class Z2 : public ValueInterface friend ostream& operator<<(ostream& o, const Z2& x); }; +template +class SignedZ2 : public Z2 +{ +public: + SignedZ2() + { + } + + template + SignedZ2(const SignedZ2& other) : Z2(other) + { + if (K < L and other.negative()) + { + this->a[Z2::N_WORDS - 1] |= ~Z2::UPPER_MASK; + for (int i = Z2::N_WORDS; i < this->N_WORDS; i++) + this->a[i] = -1; + } + } + + SignedZ2(const Integer& other) : SignedZ2(SignedZ2<64>(other)) + { + } + + template + SignedZ2(const T& other) : + Z2(other) + { + } + + bool negative() const + { + return this->a[this->N_WORDS - 1] & 1ll << ((K - 1) % (8 * sizeof(mp_limb_t))); + } + + SignedZ2 operator-() const + { + return SignedZ2() - *this; + } + + SignedZ2 operator-(const SignedZ2& other) const + { + return Z2::operator-(other); + } + + void output(ostream& s, bool human = true) const; +}; + template inline Z2 Z2::operator+(const Z2& other) const { @@ -239,4 +306,45 @@ void Z2::unpack(octetStream& o) o.consume((octet*)a, N_BYTES); } +template +void to_gfp(Z2& res, const bigint& a) +{ + res = a; +} + +template +SignedZ2 abs(const SignedZ2& x) +{ + if (x.negative()) + return -x; + else + return x; +} + +template +void SignedZ2::output(ostream& s, bool human) const +{ + if (human) + { + bigint::tmp = *this; + s << bigint::tmp; + } + else + Z2::output(s, false); +} + +template +ostream& operator<<(ostream& o, const SignedZ2& x) +{ + x.output(o, true); + return o; +} + +template +inline void to_signed_bigint(bigint& res, const SignedZ2& x, int n) +{ + bigint tmp = x; + to_signed_bigint(res, tmp, n); +} + #endif /* MATH_Z2K_H_ */ diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index 38bd195a9..7ec39df8e 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -7,6 +7,7 @@ void Zp_Data::init(const bigint& p,bool mont) { pr=p; mask=(1<<((mpz_sizeinbase(pr.get_mpz_t(),2)-1)%(8*sizeof(mp_limb_t))))-1; pr_byte_length = numBytes(pr); + pr_bit_length = numBits(pr); montgomery=mont; t=mpz_size(pr.get_mpz_t()); @@ -41,22 +42,6 @@ void Zp_Data::init(const bigint& p,bool mont) } -void Zp_Data::assign(const Zp_Data& Zp) -{ pr=Zp.pr; - mask=Zp.mask; - pr_byte_length = Zp.pr_byte_length; - - montgomery=Zp.montgomery; - t=Zp.t; - mpn_copyi(R,Zp.R,t); - mpn_copyi(R2,Zp.R2,t); - mpn_copyi(R3,Zp.R3,t); - pi=Zp.pi; - - mpn_copyi(prA,Zp.prA,t+1); -} - - __m128i Zp_Data::get_random128(PRNG& G) { while (true) diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index 01c693dc6..9b047e72e 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -46,6 +46,7 @@ class Zp_Data bigint pr; mp_limb_t mask; size_t pr_byte_length; + size_t pr_bit_length; void assign(const Zp_Data& Zp); void init(const bigint& p,bool mont=true); @@ -56,19 +57,16 @@ class Zp_Data void unpack(octetStream& o); // This one does nothing, needed so as to make vectors of Zp_Data - Zp_Data() : montgomery(0), pi(0), mask(0), pr_byte_length(0) { t=MAX_MOD_SZ; } + Zp_Data() : + montgomery(0), pi(0), mask(0), pr_byte_length(0), pr_bit_length(0) + { + t = MAX_MOD_SZ; + } // The main init funciton Zp_Data(const bigint& p,bool mont=true) { init(p,mont); } - Zp_Data(const Zp_Data& Zp) { assign(Zp); } - Zp_Data& operator=(const Zp_Data& Zp) - { if (this!=&Zp) { assign(Zp); } - return *this; - } - ~Zp_Data() { ; } - template void Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const; void Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const; @@ -110,9 +108,9 @@ class Zp_Data template<> inline void Zp_Data::Add<0>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const { - mp_limb_t carry = mpn_add_n(ans,x,y,t); + mp_limb_t carry = mpn_add_n_with_carry(ans,x,y,t); if (carry!=0 || mpn_cmp(ans,prA,t)>=0) - { mpn_sub_n(ans,ans,prA,t); } + { mpn_sub_n_borrow(ans,ans,prA,t); } } template<> @@ -148,10 +146,20 @@ inline void Zp_Data::Add<2>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y #endif } +template +inline void Zp_Data::Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const +{ + mp_limb_t carry = mpn_add_fixed_n_with_carry(ans,x,y); + if (carry!=0 || mpn_cmp(ans,prA,T)>=0) + { mpn_sub_n_borrow(ans,ans,prA,T); } +} + inline void Zp_Data::Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const { switch (t) { + case 4: + return Add<4>(ans, x, y); case 2: return Add<2>(ans, x, y); case 1: @@ -175,9 +183,9 @@ inline void Zp_Data::Sub(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) c template <> inline void Zp_Data::Sub<0>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const { - mp_limb_t borrow = mpn_sub_n(ans,x,y,t); + mp_limb_t borrow = mpn_sub_n_borrow(ans,x,y,t); if (borrow!=0) - mpn_add_n(ans,ans,prA,t); + mpn_add_n_with_carry(ans,ans,prA,t); } inline void Zp_Data::Sub(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const @@ -224,15 +232,23 @@ inline void Zp_Data::Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* inline void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const { + if (not cpu_has_bmi2()) + return Mont_Mult_variable(z, x, y); switch (t) { #ifdef __BMI2__ - case 2: - Mont_Mult_<2>(z, x, y); - break; - case 1: - Mont_Mult_<1>(z, x, y); +#define CASE(N) \ + case N: \ + Mont_Mult_(z, x, y); \ break; + CASE(1) + CASE(2) +#if MAX_MOD_SZ >= 5 + CASE(3) + CASE(4) + CASE(5) +#endif +#undef CASE #endif default: Mont_Mult_variable(z, x, y); diff --git a/Math/bigint.cpp b/Math/bigint.cpp index 1c607149a..d5b7860b1 100644 --- a/Math/bigint.cpp +++ b/Math/bigint.cpp @@ -2,6 +2,7 @@ #include "bigint.h" #include "gfp.h" #include "Integer.h" +#include "Z2k.h" #include "GC/Clear.h" #include "Exceptions/Exceptions.h" @@ -106,18 +107,6 @@ int powerMod(int x,int e,int p) } -bigint::bigint(const gfp& x) -{ - *this = x; -} - -bigint& bigint::operator=(const gfp& x) -{ - to_bigint(*this, x); - return *this; -} - - size_t bigint::report_size(ReportType type) const { size_t res = 0; @@ -145,11 +134,19 @@ int limb_size() return 0; } +bigint::bigint(const Integer& x) : bigint(SignedZ2<64>(x)) +{ +} + + +bigint::bigint(const GC::Clear& x) : bigint(SignedZ2<64>(x)) +{ +} + template mpf_class bigint::get_float(T v, Integer exp, T z, T s) { - bigint tmp; - to_signed_bigint(tmp, v); + bigint tmp = v; mpf_class res = tmp; if (exp > 0) mpf_mul_2exp(res.get_mpf_t(), res.get_mpf_t(), exp.get()); @@ -166,6 +163,17 @@ mpf_class bigint::get_float(T v, Integer exp, T z, T s) return res; } +void to_signed_bigint(bigint& res, const bigint& x, int n) +{ + res = abs(x); + bigint& tmp = bigint::tmp = 1; + tmp <<= n; + tmp -= 1; + res &= tmp; + if (x < 0) + res.negate(); +} + #ifdef REALLOC_POLICE void bigint::lottery() { @@ -177,4 +185,8 @@ void bigint::lottery() template mpf_class bigint::get_float(gfp, Integer, gfp, gfp); template mpf_class bigint::get_float(Integer, Integer, Integer, Integer); +template mpf_class bigint::get_float(Z2<64>, Integer, Z2<64>, Z2<64>); +template mpf_class bigint::get_float(Z2<72>, Integer, Z2<72>, Z2<72>); +template mpf_class bigint::get_float(SignedZ2<64>, Integer, SignedZ2<64>, SignedZ2<64>); +template mpf_class bigint::get_float(SignedZ2<72>, Integer, SignedZ2<72>, SignedZ2<72>); template mpf_class bigint::get_float(GC::Clear, Integer, GC::Clear, GC::Clear); diff --git a/Math/bigint.h b/Math/bigint.h index 8779a334b..19b740398 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -27,6 +27,12 @@ typedef gfp_<0> gfp; class gmp_random; class Integer; template class Z2; +template class SignedZ2; + +namespace GC +{ + class Clear; +} class bigint : public mpz_class { @@ -43,20 +49,28 @@ class bigint : public mpz_class bigint() : mpz_class() {} template bigint(const T& x) : mpz_class(x) {} - bigint(const gfp& x); + template + bigint(const gfp_& x); template bigint(const Z2& x); + template + bigint(const SignedZ2& x); + bigint(const Integer& x); + bigint(const GC::Clear& x); bigint& operator=(int n); bigint& operator=(long n); bigint& operator=(word n); - bigint& operator=(const gfp& other); + template + bigint& operator=(const gfp_& other); void allocate_slots(const bigint& x) { *this = x; } int get_min_alloc() { return get_mpz_t()->_mp_alloc; } void negate() { mpz_neg(get_mpz_t(), get_mpz_t()); } + void mul(const bigint& x, const bigint& y) { *this = x * y; } + #ifdef REALLOC_POLICE ~bigint() { lottery(); } void lottery(); @@ -92,6 +106,8 @@ class bigint : public mpz_class }; +void to_signed_bigint(bigint& res, const bigint& x, int n); + void inline_mpn_zero(mp_limb_t* x, mp_size_t size); void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t size); @@ -122,6 +138,31 @@ bigint::bigint(const Z2& x) mpz_import(get_mpz_t(), Z2::N_WORDS, -1, sizeof(mp_limb_t), 0, 0, x.get_ptr()); } +template +bigint::bigint(const SignedZ2& x) +{ + mpz_import(get_mpz_t(), Z2::N_WORDS, -1, sizeof(mp_limb_t), 0, 0, x.get_ptr()); + if (x.negative()) + { + bigint::tmp = 1; + bigint::tmp <<= K; + *this -= bigint::tmp; + } +} + +template +bigint::bigint(const gfp_& x) +{ + *this = x; +} + +template +bigint& bigint::operator=(const gfp_& x) +{ + to_bigint(*this, x); + return *this; +} + /********************************** * Utility Functions * @@ -226,16 +267,6 @@ inline int Hwt(int N) return result; } -inline void inline_mpn_zero(mp_limb_t* x, mp_size_t size) -{ - avx_memzero(x, size * sizeof(mp_limb_t)); -} - -inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t size) -{ - avx_memcpy(dest, src, size * sizeof(mp_limb_t)); -} - template int limb_size(); diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index d3f1fd32c..b2f915dd9 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -19,7 +19,6 @@ int gf2n_short::l3; int gf2n_short::nterms; word gf2n_short::mask; bool gf2n_short::useC; -bool gf2n_short::rewind = false; word gf2n_short_table[256][256]; @@ -59,7 +58,9 @@ void gf2n_short::init_field(int nn) if (nn == 0) { nn = default_length(); +#ifdef VERBOSE cerr << "Using GF(2^" << nn << ")" << endl; +#endif } gf2n_short::init_tables(); @@ -92,7 +93,7 @@ void gf2n_short::init_field(int nn) mask=(1ULL< PO; - typedef SPDZ> Protocol; - static void init_field(int nn); static int degree() { return n; } static int default_degree() { return 40; } @@ -83,6 +79,8 @@ class gf2n_short static bool allows(Dtype type) { (void) type; return true; } + static const bool invertible = true; + word get() const { return a; } word get_word() const { return a; } @@ -184,6 +182,8 @@ class gf2n_short // compatibility with gfp void almost_randomize(PRNG& G) { randomize(G); } + void force_to_bit() { a &= 1; } + void output(ostream& s,bool human) const; void input(istream& s,bool human); diff --git a/Math/gf2nlong.cpp b/Math/gf2nlong.cpp index c5b1e6e6b..ca94ee3a7 100644 --- a/Math/gf2nlong.cpp +++ b/Math/gf2nlong.cpp @@ -51,7 +51,6 @@ int gf2n_long::nterms; int128 gf2n_long::mask; int128 gf2n_long::lowermask; int128 gf2n_long::uppermask; -bool gf2n_long::rewind = false; #define num_2_fields 1 @@ -68,7 +67,9 @@ void gf2n_long::init_field(int nn) if (nn == 0) { nn = default_length(); +#ifdef VERBOSE cerr << "Using GF(2^" << nn << ")" << endl; +#endif } if (nn!=128) { @@ -257,12 +258,7 @@ void gf2n_long::input(istream& s,bool human) { cout << "IO problem. Empty file?" << endl; throw file_error(); } - //throw end_of_file(); - s.clear(); // unset EOF flag - s.seekg(0); - if (!rewind) - cout << "REWINDING - ONLY FOR BENCHMARKING" << endl; - rewind = true; + throw end_of_file("gf2n_long"); } if (human) diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h index 13f11bd96..e7ea492bd 100644 --- a/Math/gf2nlong.h +++ b/Math/gf2nlong.h @@ -86,7 +86,6 @@ class gf2n_long static int n,t1,t2,t3,nterms; static int l0,l1,l2,l3; static int128 mask,lowermask,uppermask; - static bool rewind; /* Assign x[0..2*nwords] to a and reduce it... */ void reduce_trinomial(int128 xh,int128 xl); @@ -97,9 +96,6 @@ class gf2n_long typedef gf2n_long value_type; typedef int128 internal_type; - typedef PrivateOutput PO; - typedef SPDZ> Protocol; - typedef gf2n_long next; void reduce(int128 xh,int128 xl) @@ -132,6 +128,8 @@ class gf2n_long static bool allows(Dtype type) { (void) type; return true; } + static const bool invertible = true; + int128 get() const { return a; } __m128i to_m128i() const { return a.a; } word get_word() const { return _mm_cvtsi128_si64(a.a); } @@ -235,6 +233,8 @@ class gf2n_long // compatibility with gfp void almost_randomize(PRNG& G) { randomize(G); } + void force_to_bit() { a &= 1; } + template T convert() const { return *this; } @@ -289,26 +289,35 @@ inline int128 int128::operator>>(const int& other) const void mul64(word x, word y, word& lo, word& hi); -inline __m128i clmul(__m128i a, __m128i b, int choice) +inline __m128i software_clmul(__m128i a, __m128i b, int choice) { word lo, hi; - mul64(int128(a).get_half(choice & 1), int128(b).get_half((choice & 0x10) >> 4), lo, hi); + mul64(int128(a).get_half(choice & 1), + int128(b).get_half((choice & 0x10) >> 4), lo, hi); return int128(hi, lo).a; } -#ifndef __PCLMUL__ -#undef _mm_clmulepi64_si128 -#define _mm_clmulepi64_si128 clmul +template +inline __m128i clmul(__m128i a, __m128i b) +{ +#ifdef __PCLMUL__ + if (cpu_has_pclmul()) + { + return _mm_clmulepi64_si128(a, b, choice); + } + else #endif + return software_clmul(a, b, choice); +} inline void mul128(__m128i a, __m128i b, __m128i *res1, __m128i *res2) { __m128i tmp3, tmp4, tmp5, tmp6; - tmp3 = _mm_clmulepi64_si128(a, b, 0x00); - tmp4 = _mm_clmulepi64_si128(a, b, 0x10); - tmp5 = _mm_clmulepi64_si128(a, b, 0x01); - tmp6 = _mm_clmulepi64_si128(a, b, 0x11); + tmp3 = clmul<0x00>(a, b); + tmp4 = clmul<0x10>(a, b); + tmp5 = clmul<0x01>(a, b); + tmp6 = clmul<0x11>(a, b); tmp4 = _mm_xor_si128(tmp4, tmp5); tmp5 = _mm_slli_si128(tmp4, 8); diff --git a/Math/gfp.cpp b/Math/gfp.cpp index 5d45fdb47..c1438b493 100644 --- a/Math/gfp.cpp +++ b/Math/gfp.cpp @@ -5,9 +5,9 @@ #include "Exceptions/Exceptions.h" template -void gfp_::init_default(int lgp) +void gfp_::init_default(int lgp, bool mont) { - init_field(SPDZ_Data_Setup_Primes(lgp)); + init_field(SPDZ_Data_Setup_Primes(lgp), mont); } template @@ -182,8 +182,9 @@ void to_signed_bigint(bigint& ans, const gfp& x) // get sign and abs(x) bigint& p_half = bigint::tmp = (gfp::pr()-1)/2; if (mpz_cmp(ans.get_mpz_t(), p_half.get_mpz_t()) > 0) - ans = gfp::pr() - ans; + ans -= gfp::pr(); } template class gfp_<0>; template class gfp_<1>; +template class gfp_<2>; diff --git a/Math/gfp.h b/Math/gfp.h index 5dea31b6e..1b22de721 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -32,13 +32,12 @@ class gfp_ public: typedef gfp_ value_type; - typedef PrivateOutput PO; typedef gfp_ next; static void init_field(const bigint& p,bool mont=true) { ZpD.init(p,mont); } - static void init_default(int lgp); + static void init_default(int lgp, bool mont = true); static bigint pr() { return ZpD.pr; } @@ -92,13 +91,8 @@ class gfp_ gfp_(const void* buffer) { assign((char*)buffer); } template gfp_(const gfp_& x); - - ~gfp_() { ; } - - gfp_& operator=(const gfp_& g) - { if (&g!=this) { a=g.a; } - return *this; - } + template + gfp_(const SignedZ2& other); gfp_& operator=(const __m128i other) { @@ -218,6 +212,8 @@ class gfp_ gfp_ operator<<(int i) { gfp_ res; res.SHL(*this, i); return res; } gfp_ operator>>(int i) { gfp_ res; res.SHR(*this, i); return res; } + void force_to_bit() { throw runtime_error("impossible"); } + // Pack and unpack in native format // i.e. Dont care about conversion to human readable form void pack(octetStream& o) const @@ -236,6 +232,7 @@ class gfp_ typedef gfp_<0> gfp; typedef gfp_<1> gfp1; +typedef gfp_<2> gfp2; void to_signed_bigint(bigint& ans,const gfp& x); @@ -250,4 +247,14 @@ gfp_::gfp_(const gfp_& x) *this = bigint::tmp; } +template +template +gfp_::gfp_(const SignedZ2& other) +{ + if (K >= ZpD.pr_bit_length) + *this = bigint::tmp = other; + else + a.convert(abs(other).get(), other.size_in_limbs(), ZpD, other.negative()); +} + #endif diff --git a/Math/modp.cpp b/Math/modp.cpp index 84f38f3ce..8bb8ae58f 100644 --- a/Math/modp.cpp +++ b/Math/modp.cpp @@ -3,8 +3,6 @@ #include "Exceptions/Exceptions.h" -bool modp::rewind = false; - /*********************************************************************** * The following functions remain the same in Real and Montgomery rep * ***********************************************************************/ @@ -136,10 +134,20 @@ void modp::convert_destroy(bigint& xx, const Zp_Data& ZpD) { xx %= ZpD.pr; - if (xx<0) { xx+=ZpD.pr; } //mpz_mod(xx.get_mpz_t(),x.get_mpz_t(),ZpD.pr.get_mpz_t()); - inline_mpn_zero(x, ZpD.t); - inline_mpn_copyi(x, xx.get_mpz_t()->_mp_d, xx.get_mpz_t()->_mp_size); + convert(xx.get_mpz_t()->_mp_d, abs(xx.get_mpz_t()->_mp_size), ZpD, xx < 0); +} + +void modp::convert(const mp_limb_t* source, mp_size_t size, const Zp_Data& ZpD, bool negative) +{ + assert(size <= ZpD.t); + if (negative) + mpn_sub(x, ZpD.prA, ZpD.t, source, size); + else + { + inline_mpn_zero(x + size, ZpD.t - size); + inline_mpn_copyi(x, source, size); + } if (ZpD.montgomery) ZpD.Mont_Mult(x, x, ZpD.R2); } diff --git a/Math/modp.h b/Math/modp.h index 9b1d76b8b..b1ac65f22 100644 --- a/Math/modp.h +++ b/Math/modp.h @@ -24,8 +24,6 @@ void to_bigint(bigint& ans,const modp& x,const Zp_Data& ZpD,bool reduce=true); class modp { - static bool rewind; - mp_limb_t x[MAX_MOD_SZ]; public: @@ -36,15 +34,11 @@ class modp // use mem* functions instead of mpn_*, so the compiler can optimize modp() { avx_memzero(x, sizeof(x)); } - modp(const modp& y) - { memcpy(x, y.x, sizeof(x)); } - modp& operator=(const modp& y) - { if (this!=&y) { memcpy(x, y.x, sizeof(x)); } - return *this; - } void assign(const char* buffer, int t) { memcpy(x, buffer, t * sizeof(mp_limb_t)); } + void convert(const mp_limb_t* source, mp_size_t size, const Zp_Data& ZpD, + bool negative = false); void convert_destroy(bigint& source, const Zp_Data& ZpD); void convert_destroy(int source, const Zp_Data& ZpD) { to_modp(*this, source, ZpD); } diff --git a/Math/mpn_fixed.h b/Math/mpn_fixed.h index 3ce7f83f0..1202f8eb5 100644 --- a/Math/mpn_fixed.h +++ b/Math/mpn_fixed.h @@ -9,8 +9,20 @@ #include #include #include +#include -#include "bigint.h" +#include "Tools/avx_memcpy.h" +#include "Tools/cpu_support.h" + +inline void inline_mpn_zero(mp_limb_t* x, mp_size_t size) +{ + avx_memzero(x, size * sizeof(mp_limb_t)); +} + +inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t size) +{ + avx_memcpy(dest, src, size * sizeof(mp_limb_t)); +} inline void debug_print(const char* name, const mp_limb_t* x, int n) { @@ -23,10 +35,13 @@ inline void debug_print(const char* name, const mp_limb_t* x, int n) #endif } +template +mp_limb_t mpn_add_fixed_n_with_carry(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y); + template inline void mpn_add_fixed_n(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) { - mpn_add(res, x, N, y, N); + mpn_add_fixed_n_with_carry(res, x, y); } template <> @@ -85,16 +100,51 @@ inline void mpn_add_fixed_n<4>(mp_limb_t* res, const mp_limb_t* x, const mp_limb ); } +inline mp_limb_t mpn_add_n_with_carry(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y, int n) +{ +#ifdef __ADX__ + if (cpu_has_adx()) + { + char carry = 0; + for (int i = 0; i < n; i++) + carry = _addcarryx_u64(carry, x[i], y[i], (unsigned long long*)&res[i]); + return carry; + } + else +#endif + return mpn_add_n(res, x, y, n); +} + +template +mp_limb_t mpn_add_fixed_n_with_carry(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + return mpn_add_n_with_carry(res, x, y, N); +} + +inline mp_limb_t mpn_sub_n_borrow(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y, int n) +{ +#ifndef __clang__ +#if __GNUC__ < 7 + // GCC 6 can't handle the code below + return mpn_sub_n(res, x, y, n); +#endif +#endif + char borrow = 0; + for (int i = 0; i < n; i++) + borrow = _subborrow_u64(borrow, x[i], y[i], (unsigned long long*)&res[i]); + return borrow; +} + template inline void mpn_sub_fixed_n(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) { - mpn_sub(res, x, N, y, N); + mpn_sub_n_borrow(res, x, y, N); } template inline mp_limb_t mpn_sub_fixed_n_borrow(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) { - return mpn_sub(res, x, N, y, N); + return mpn_sub_n_borrow(res, x, y, N); } template <> @@ -188,8 +238,9 @@ inline void mpn_add_n_use_fixed(mp_limb_t* res, const mp_limb_t* x, const mp_lim CASE(2); CASE(3); CASE(4); +#undef CASE default: - mpn_add_n(res, x, y, n); + mpn_add_n_with_carry(res, x, y, n); break; } } @@ -235,6 +286,15 @@ inline void mpn_addmul_1_fixed_(mp_limb_t* res, const mp_limb_t* y, mp_limb_t x) memcpy(tmp, y, M * sizeof(mp_limb_t)); mpn_addmul_1(res, tmp, L, x); } + +template +inline void mpn_mul_1_fixed(mp_limb_t* res, const mp_limb_t* y, mp_limb_t x) +{ + mp_limb_t tmp[L]; + memset(tmp, 0, sizeof(tmp)); + memcpy(tmp, y, M * sizeof(mp_limb_t)); + mpn_mul_1(res, tmp, L, x); +} #endif template @@ -254,6 +314,20 @@ inline void mpn_mul_fixed_(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* inline_mpn_copyi(res, tmp, L); } +template <> +inline void mpn_mul_fixed_<1,1,1>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + *res = *x * *y; +} + +template <> +inline void mpn_mul_fixed_<2,2,2>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) +{ + mp_limb_t* tmp = res; + mpn_mul_1_fixed<2,2>(tmp, y, x[0]); + mpn_addmul_1_fixed_<1,1>(tmp + 1, y, x[1]); +} + template <> inline void mpn_mul_fixed_<3,3,3>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) { diff --git a/Networking/Player.cpp b/Networking/Player.cpp index c48b998a6..5bdf67ac1 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -394,10 +394,10 @@ void Player::exchange_relative(int offset, octetStream& o) const template -void MultiPlayer::pass_around(octetStream& o, int offset) const +void MultiPlayer::pass_around(octetStream& o, octetStream& to_receive, int offset) const { TimeScope ts(comm_stats["Passing around"].add(o)); - o.exchange(sockets.at(get_player(offset)), sockets.at(get_player(-offset))); + o.exchange(sockets.at(get_player(offset)), sockets.at(get_player(-offset)), to_receive); sent += o.get_length(); } @@ -742,5 +742,13 @@ NamedCommStats& NamedCommStats::operator +=(const NamedCommStats& other) return *this; } +size_t NamedCommStats::total_data() +{ + size_t res = 0; + for (auto& x : *this) + res += x.second.data; + return res; +} + template class MultiPlayer; template class MultiPlayer; diff --git a/Networking/Player.h b/Networking/Player.h index b207d80d1..eb502b908 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -111,6 +111,7 @@ class NamedCommStats : public map { public: NamedCommStats& operator+=(const NamedCommStats& other); + size_t total_data(); }; class PlayerBase @@ -174,7 +175,8 @@ class Player : public PlayerBase virtual void exchange_no_stats(int other, const octetStream& to_send, octetStream& ot_receive) const = 0; void exchange(int other, octetStream& o) const; void exchange_relative(int offset, octetStream& o) const; - virtual void pass_around(octetStream& o, int offset = 1) const = 0; + void pass_around(octetStream& o, int offset = 1) const { pass_around(o, o, offset); } + virtual void pass_around(octetStream& to_send, octetStream& to_receive, int offset) const = 0; /* Broadcast and Receive data to/from all players * - Assumes o[player_no] contains the thing broadcast by me @@ -233,7 +235,7 @@ class MultiPlayer : public Player void exchange_no_stats(int other, const octetStream& to_send, octetStream& ot_receive) const; // send to next and receive from previous player - void pass_around(octetStream& o, int offset = 1) const; + void pass_around(octetStream& to_send, octetStream& to_receive, int offset) const; // Receive one from player i diff --git a/Networking/sockets.h b/Networking/sockets.h index 07e4a8ceb..9ad7efa2e 100644 --- a/Networking/sockets.h +++ b/Networking/sockets.h @@ -98,7 +98,10 @@ inline void receive(int socket,octet *msg,size_t len) int fail = 0; while (len-i>0) { int j=recv(socket,msg+i,len-i,0); - if (j<0) + // success first + if (j > 0) + i = i + j; + else if (j < 0) { if (errno == EAGAIN or errno == EINTR) { @@ -114,7 +117,7 @@ inline void receive(int socket,octet *msg,size_t len) { error("Receiving error - 1"); } } else - i=i+j; + throw runtime_error("connection closed down"); } } diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp index c90ac192b..49b42366b 100644 --- a/OT/BaseOT.cpp +++ b/OT/BaseOT.cpp @@ -70,6 +70,9 @@ void send_if_ot_receiver(TwoPartyPlayer* P, vector& os, OT_ROLE rol void BaseOT::exec_base(bool new_receiver_inputs) { + if (not cpu_has_avx()) + throw runtime_error("SimpleOT needs AVX support"); + int i, j, k; size_t len; PRNG G; diff --git a/OT/BitMatrix.cpp b/OT/BitMatrix.cpp index 42826713f..38ae9181f 100644 --- a/OT/BitMatrix.cpp +++ b/OT/BitMatrix.cpp @@ -791,22 +791,29 @@ void Slice::unpack(octetStream& os) } #define M(N,L) Matrix, Z2 > > -#undef X -#define X(N,L) \ + +#undef XXX +#define XXX(T,N,L) \ template class Matrix, Z2 > >; \ -template M(N,L)& M(N,L)::operator=(const Matrix& other); \ template class Slice, Z2 > > >; \ -template void Slice, Z2 > > >::randomize >(int row, PRNG& G); \ template Slice, Z2 > > >& Slice< \ - Matrix, Z2 > > >::rsub >( \ + Matrix, Z2 > > >::rsub( \ Slice, Z2 > > >& other); \ template Slice, Z2 > > >& Slice< \ - Matrix, Z2 > > >::sub >(BitVector& other, int repeat); \ + Matrix, Z2 > > >::sub(BitVector& other, int repeat); \ template void Slice, Z2 > > >::conditional_add< \ - Z2 >(BitVector& conditions, \ + T>(BitVector& conditions, \ Matrix, Z2 > >& other, bool useOffset); \ +#undef X +#define X(N,L) \ +template M(N,L)& M(N,L)::operator=(const Matrix& other); \ +template void Slice, Z2 > > >::randomize >(int row, PRNG& G); \ +XXX(Z2, N, L) + //X(96, 160) +XXX(SignedZ2<64>, 64, 64) +XXX(SignedZ2<72>, 72, 72) Y(64, 64) Y(64, 48) diff --git a/OT/NPartyTripleGenerator.cpp b/OT/NPartyTripleGenerator.cpp index 4bb5c8e5e..80f906620 100644 --- a/OT/NPartyTripleGenerator.cpp +++ b/OT/NPartyTripleGenerator.cpp @@ -4,6 +4,8 @@ #include "OT/OTMultiplier.h" #include "Math/gfp.h" #include "Math/Share.h" +#include "Math/SemiShare.h" +#include "Math/Semi2kShare.h" #include "Math/operators.h" #include "Auth/Subroutines.h" #include "Auth/MAC_Check.h" @@ -12,7 +14,11 @@ #include "OT/Triple.hpp" #include "OT/Rectangle.hpp" #include "Auth/MAC_Check.hpp" +#include "Auth/SemiMC.h" #include "Processor/MascotPrep.hpp" +#include "Processor/ReplicatedInput.hpp" +#include "Processor/SemiInput.hpp" +#include "Processor/Input.hpp" #include #include @@ -31,6 +37,15 @@ void* run_ot_thread(void* ptr) */ template NPartyTripleGenerator::NPartyTripleGenerator(OTTripleSetup& setup, + const Names& names, int thread_num, int _nTriples, int nloops, + MascotParams& machine, Player* parentPlayer) : + OTTripleGenerator(setup, names, thread_num, _nTriples, nloops, + machine, parentPlayer) +{ +} + +template +OTTripleGenerator::OTTripleGenerator(OTTripleSetup& setup, const Names& names, int thread_num, int _nTriples, int nloops, MascotParams& machine, Player* parentPlayer) : globalPlayer(parentPlayer ? *parentPlayer : *new PlainPlayer(names, @@ -91,7 +106,7 @@ NPartyTripleGenerator::NPartyTripleGenerator(OTTripleSetup& setup, } template -NPartyTripleGenerator::~NPartyTripleGenerator() +OTTripleGenerator::~OTTripleGenerator() { // wait for threads to finish for (int i = 0; i < nparties-1; i++) @@ -117,7 +132,7 @@ NPartyTripleGenerator::~NPartyTripleGenerator() } template -typename T::Multiplier* NPartyTripleGenerator::new_multiplier(int i) +typename T::Multiplier* OTTripleGenerator::new_multiplier(int i) { return new typename T::Multiplier(*this, i); } @@ -127,6 +142,13 @@ void NPartyTripleGenerator::generate() { bigint::init_thread(); + auto& timers = this->timers; + auto& machine = this->machine; + auto& my_num = this->my_num; + auto& thread_num = this->thread_num; + auto& nTriples = this->nTriples; + auto& outputFile = this->outputFile; + timers["Generator thread"].start(); // add up the shares from each thread and write to file @@ -161,9 +183,18 @@ void NPartyTripleGenerator::generateInputs(int player) { typedef open_type T; + auto& machine = this->machine; + auto& nTriplesPerLoop = this->nTriplesPerLoop; + auto& valueBits = this->valueBits; + auto& share_prg = this->share_prg; + auto& field_size = this->field_size; + auto& ot_multipliers = this->ot_multipliers; + auto& nparties = this->nparties; + auto& globalPlayer = this->globalPlayer; + // extra value for sacrifice int toCheck = nTriplesPerLoop + 1; - signal_multipliers({player, toCheck}); + this->signal_multipliers({player, toCheck}); bool mine = player == globalPlayer.my_num(); valueBits.resize(1); @@ -171,23 +202,32 @@ void NPartyTripleGenerator::generateInputs(int player) { valueBits[0].resize(toCheck * field_size); valueBits[0].template randomize_blocks(share_prg); - signal_multipliers({}); + this->signal_multipliers({}); } - wait_for_multipliers(); + this->wait_for_multipliers(); GlobalPRNG G(globalPlayer); Share check_sum; inputs.resize(toCheck); - auto mac_key = machine.get_mac_key(); + auto mac_key = machine.template get_mac_key(); + SemiInput> input(0, globalPlayer); + input.reset_all(globalPlayer); + vector secrets(toCheck); + if (mine) + for (int j = 0; j < toCheck; j++) + { + secrets[j] = valueBits[0].template get_portion(j); + input.add_mine(secrets[j]); + } + input.exchange(); for (int j = 0; j < toCheck; j++) { T share, mac_sum; - if (mine) - share = valueBits[0].template get_portion(j); + share = input.finalize(player); if (mine) { - mac_sum = share * mac_key; + mac_sum = secrets[j] * mac_key; for (int i = 0; i < nparties-1; i++) mac_sum += (ot_multipliers[i])->input_macs[j]; } @@ -196,7 +236,7 @@ void NPartyTripleGenerator::generateInputs(int player) int i_thread = player - (player > globalPlayer.my_num() ? 1 : 0); mac_sum = (ot_multipliers[i_thread])->input_macs[j]; } - inputs[j] = {{share, mac_sum}, share}; + inputs[j] = {{share, mac_sum}, secrets[j]}; check_sum += inputs[j].share * G.get(); } inputs.resize(nTriplesPerLoop); @@ -276,8 +316,6 @@ void NPartyTripleGenerator>::generateBits() template void NPartyTripleGenerator::generateBits() { - (void)ot_multipliers; - (void)outputFile; throw not_implemented(); } @@ -285,8 +323,18 @@ template template void NPartyTripleGenerator::generateTriplesZ2k() { - (void) outputFile; - signal_multipliers(DATA_TRIPLE); + auto& timers = this->timers; + auto& machine = this->machine; + auto& nTriplesPerLoop = this->nTriplesPerLoop; + auto& valueBits = this->valueBits; + auto& share_prg = this->share_prg; + auto& ot_multipliers = this->ot_multipliers; + auto& nparties = this->nparties; + auto& globalPlayer = this->globalPlayer; + auto& nloops = this->nloops; + auto& b_padded_bits = this->b_padded_bits; + + this->signal_multipliers(DATA_TRIPLE); const int TAU = Spdz2kMultiplier::TAU; const int TAU_ROUNDED = (TAU + 7) / 8 * 8; @@ -297,13 +345,14 @@ void NPartyTripleGenerator::generateTriplesZ2k() b_padded_bits.resize(8 * Z2::N_BYTES * (nTriplesPerLoop + 1)); vector< PlainTriple_, Z2, 2> > amplifiedTriples(nTriplesPerLoop); uncheckedTriples.resize(nTriplesPerLoop); - MAC_Check_Z2k, Z2, Z2, Share> > MC(machine.get_mac_key >()); + MAC_Check_Z2k, Z2, Z2, Share> > MC( + machine.template get_mac_key >()); - start_progress(); + this->start_progress(); for (int k = 0; k < nloops; k++) { - print_progress(k); + this->print_progress(k); for (int j = 0; j < 2; j++) valueBits[j].template randomize_blocks(share_prg); @@ -315,8 +364,8 @@ void NPartyTripleGenerator::generateTriplesZ2k() } timers["OTs"].start(); - signal_multipliers({}); - wait_for_multipliers(); + this->signal_multipliers({}); + this->wait_for_multipliers(); timers["OTs"].stop(); octet seed[SEED_SIZE]; @@ -358,8 +407,8 @@ void NPartyTripleGenerator::generateTriplesZ2k() amplifiedTriples[j].to(valueBits, j); } - signal_multipliers({}); - wait_for_multipliers(); + this->signal_multipliers({}); + this->wait_for_multipliers(); for (int j = 0; j < nTriplesPerLoop; j++) { @@ -412,11 +461,91 @@ void NPartyTripleGenerator>::generateTriples() this->generateTriplesZ2k<66, 48>(); } +template +void OTTripleGenerator::generatePlainTriples() +{ + machine.set_passive(); + machine.output = false; + signal_multipliers(DATA_TRIPLE); + + valueBits.resize(3); + for (int i = 0; i < 3; i++) + valueBits[i].resize(field_size * nPreampTriplesPerLoop); + + start_progress(); + for (int i = 0; i < nloops; i++) + plainTripleRound(i); +} + +template +void OTTripleGenerator::plainTripleRound(int k) +{ + typedef typename U::open_type T; + + if (not (machine.amplify or machine.output)) + plainTriples.resize(nPreampTriplesPerLoop); + + print_progress(k); + + for (int j = 0; j < 2; j++) + valueBits[j].template randomize_blocks(share_prg); + + timers["OTs"].start(); + for (int i = 0; i < nparties-1; i++) + ot_multipliers[i]->inbox.push({}); + this->wait_for_multipliers(); + timers["OTs"].stop(); + + for (int j = 0; j < nPreampTriplesPerLoop; j++) + { + T a((char*)valueBits[0].get_ptr() + j * T::size()); + T b((char*)valueBits[1].get_ptr() + j / nAmplify * T::size()); + T c = a * b; + timers["Triple computation"].start(); + for (int i = 0; i < nparties-1; i++) + { + c += dynamic_cast(ot_multipliers[i])->c_output[j]; + } + timers["Triple computation"].stop(); + if (machine.amplify) + { + preampTriples[j/nAmplify].a[j%nAmplify] = a; + preampTriples[j/nAmplify].b = b; + preampTriples[j/nAmplify].c[j%nAmplify] = c; + } + else if (machine.output) + { + timers["Writing"].start(); + a.output(outputFile, false); + b.output(outputFile, false); + c.output(outputFile, false); + timers["Writing"].stop(); + } + else + { + plainTriples[j] = {{a, b, c}}; + } + } +} + template void NPartyTripleGenerator::generateTriples() { typedef typename U::open_type T; + auto& timers = this->timers; + auto& machine = this->machine; + auto& nTriplesPerLoop = this->nTriplesPerLoop; + auto& valueBits = this->valueBits; + auto& ot_multipliers = this->ot_multipliers; + auto& nparties = this->nparties; + auto& globalPlayer = this->globalPlayer; + auto& nloops = this->nloops; + auto& preampTriples = this->preampTriples; + auto& outputFile = this->outputFile; + auto& field_size = this->field_size; + auto& nPreampTriplesPerLoop = this->nPreampTriplesPerLoop; + for (int i = 0; i < nparties-1; i++) ot_multipliers[i]->inbox.push(DATA_TRIPLE); @@ -424,9 +553,8 @@ void NPartyTripleGenerator::generateTriples() for (int i = 0; i < 2; i++) valueBits[2*i].resize(field_size * nPreampTriplesPerLoop); valueBits[1].resize(field_size * nTriplesPerLoop); - vector< PlainTriple > preampTriples; vector< PlainTriple > amplifiedTriples; - MAC_Check MC(machine.get_mac_key()); + MAC_Check MC(machine.template get_mac_key()); if (machine.amplify) preampTriples.resize(nTriplesPerLoop); @@ -436,47 +564,11 @@ void NPartyTripleGenerator::generateTriples() uncheckedTriples.resize(nTriplesPerLoop); } - start_progress(); + this->start_progress(); for (int k = 0; k < nloops; k++) { - print_progress(k); - - for (int j = 0; j < 2; j++) - valueBits[j].template randomize_blocks(share_prg); - - timers["OTs"].start(); - for (int i = 0; i < nparties-1; i++) - ot_multipliers[i]->inbox.push({}); - wait_for_multipliers(); - timers["OTs"].stop(); - - for (int j = 0; j < nPreampTriplesPerLoop; j++) - { - T a((char*)valueBits[0].get_ptr() + j * T::size()); - T b((char*)valueBits[1].get_ptr() + j / nAmplify * T::size()); - T c = a * b; - timers["Triple computation"].start(); - for (int i = 0; i < nparties-1; i++) - { - c += ((MascotMultiplier*)ot_multipliers[i])->c_output[j]; - } - timers["Triple computation"].stop(); - if (machine.amplify) - { - preampTriples[j/nAmplify].a[j%nAmplify] = a; - preampTriples[j/nAmplify].b = b; - preampTriples[j/nAmplify].c[j%nAmplify] = c; - } - else if (machine.output) - { - timers["Writing"].start(); - a.output(outputFile, false); - b.output(outputFile, false); - c.output(outputFile, false); - timers["Writing"].stop(); - } - } + this->plainTripleRound(); if (machine.amplify) { @@ -507,7 +599,7 @@ void NPartyTripleGenerator::generateTriples() for (int i = 0; i < nparties-1; i++) ot_multipliers[i]->inbox.push({}); timers["Authentication OTs"].start(); - wait_for_multipliers(); + this->wait_for_multipliers(); timers["Authentication OTs"].stop(); for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++) @@ -535,6 +627,11 @@ template void NPartyTripleGenerator::sacrifice( vector >& uncheckedTriples, typename T::MAC_Check& MC, PRNG& G) { + auto& machine = this->machine; + auto& nTriplesPerLoop = this->nTriplesPerLoop; + auto& globalPlayer = this->globalPlayer; + auto& outputFile = this->outputFile; + vector maskedAs(nTriplesPerLoop); vector > maskedTriples(nTriplesPerLoop); for (int j = 0; j < nTriplesPerLoop; j++) @@ -570,6 +667,11 @@ void NPartyTripleGenerator::sacrificeZ2k( typedef sacri_type T; typedef open_type V; + auto& machine = this->machine; + auto& nTriplesPerLoop = this->nTriplesPerLoop; + auto& globalPlayer = this->globalPlayer; + auto& outputFile = this->outputFile; + vector< Share > maskedAs(nTriplesPerLoop); vector > maskedTriples(nTriplesPerLoop); for (int j = 0; j < nTriplesPerLoop; j++) @@ -653,7 +755,7 @@ void NPartyTripleGenerator::generateBitsFromTriples( } template -void NPartyTripleGenerator::start_progress() +void OTTripleGenerator::start_progress() { wait_for_multipliers(); lock(); @@ -663,7 +765,7 @@ void NPartyTripleGenerator::start_progress() } template -void NPartyTripleGenerator::print_progress(int k) +void OTTripleGenerator::print_progress(int k) { if (thread_num == 0 && my_num == 0) { @@ -704,32 +806,28 @@ void MascotGenerator::wait() } template -void NPartyTripleGenerator::signal_multipliers(MultJob job) +void OTTripleGenerator::signal_multipliers(MultJob job) { for (int i = 0; i < nparties-1; i++) ot_multipliers[i]->inbox.push(job); } template -void NPartyTripleGenerator::wait_for_multipliers() +void OTTripleGenerator::wait_for_multipliers() { for (int i = 0; i < nparties-1; i++) ot_multipliers[i]->outbox.pop(); } -template -size_t NPartyTripleGenerator::data_sent() -{ - size_t res = globalPlayer.sent; - for (auto& player : players) - res += player->sent; - return res; -} - template class NPartyTripleGenerator>; template class NPartyTripleGenerator>; +template class OTTripleGenerator>; +template class OTTripleGenerator>; +template class OTTripleGenerator>; +template class OTTripleGenerator>; + template class NPartyTripleGenerator>; template class NPartyTripleGenerator>; template class NPartyTripleGenerator>; diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index 0900788be..9633e1da4 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -20,6 +20,8 @@ template class ShareTriple_; +template +class PlainTriple; template using ShareTriple = ShareTriple_; @@ -48,12 +50,13 @@ class MascotGenerator }; template -class NPartyTripleGenerator : public MascotGenerator +class OTTripleGenerator : public MascotGenerator { typedef typename T::open_type open_type; typedef typename T::mac_key_type mac_key_type; typedef typename T::sacri_type sacri_type; +protected: //OTTripleSetup* setup; Player& globalPlayer; Player* parentPlayer; @@ -67,21 +70,6 @@ class NPartyTripleGenerator : public MascotGenerator SeededPRNG share_prg; - template - void generateTriplesZ2k(); - - void generateTriples(); - void generateBits(); - template - void generateBitsFromTriples(vector >& triples, - W& MC, ofstream& outputFile); - - void sacrifice(vector >& uncheckedTriples, - typename T::MAC_Check& MC, PRNG& G); - template - void sacrificeZ2k(vector >& uncheckedTriples, - U& MC, PRNG& G); - void start_progress(); void print_progress(int k); @@ -99,9 +87,6 @@ class NPartyTripleGenerator : public MascotGenerator vector< vector< vector > > baseSenderInputs; vector< vector > baseReceiverOutputs; vector valueBits; - vector< ShareTriple_ > uncheckedTriples; - vector bits; - vector>> inputs; BitVector b_padded_bits; int my_num; @@ -115,14 +100,64 @@ class NPartyTripleGenerator : public MascotGenerator MascotParams& machine; + vector> preampTriples; + vector> plainTriples; + + OTTripleGenerator(OTTripleSetup& setup, const Names& names, + int thread_num, int nTriples, int nloops, MascotParams& machine, + Player* parentPlayer = 0); + ~OTTripleGenerator(); + + void generate() { throw not_implemented(); } + + void generatePlainTriples(); + void plainTripleRound(int k = 0); + + size_t data_sent(); +}; + +template +class NPartyTripleGenerator : public OTTripleGenerator +{ + typedef typename T::open_type open_type; + typedef typename T::mac_key_type mac_key_type; + typedef typename T::sacri_type sacri_type; + + template + void generateTriplesZ2k(); + + void generateTriples(); + void generateBits(); + template + void generateBitsFromTriples(vector >& triples, + W& MC, ofstream& outputFile); + + void sacrifice(vector >& uncheckedTriples, + typename T::MAC_Check& MC, PRNG& G); + template + void sacrificeZ2k(vector >& uncheckedTriples, + U& MC, PRNG& G); + +public: + vector< ShareTriple_ > uncheckedTriples; + vector bits; + vector>> inputs; + NPartyTripleGenerator(OTTripleSetup& setup, const Names& names, int thread_num, int nTriples, int nloops, MascotParams& machine, Player* parentPlayer = 0); - ~NPartyTripleGenerator(); + void generate(); void generateInputs(int player); - - size_t data_sent(); }; +template +size_t OTTripleGenerator::data_sent() +{ + size_t res = globalPlayer.sent; + for (auto& player : players) + res += player->sent; + return res; +} + #endif diff --git a/OT/OTExtension.cpp b/OT/OTExtension.cpp index 2e0c5a5dd..e87d919c3 100644 --- a/OT/OTExtension.cpp +++ b/OT/OTExtension.cpp @@ -690,12 +690,6 @@ void OTExtension::check_correlation(int nOTs, delete[] seed; vector os(2); - if (!Check_CPU_support_AES()) - { - cerr << "Not implemented GF(2^128) multiplication in C\n"; - throw not_implemented(); - } - __m128i Delta, x128i; Delta = _mm_load_si128((__m128i*)&(baseReceiverInput.get_ptr()[0])); diff --git a/OT/OTExtensionWithMatrix.cpp b/OT/OTExtensionWithMatrix.cpp index b5391d060..260783cc7 100644 --- a/OT/OTExtensionWithMatrix.cpp +++ b/OT/OTExtensionWithMatrix.cpp @@ -539,23 +539,27 @@ ZZZZ(gfp1) ZZZZ(gf2n_long) ZZZ(Z2<160>, MM) -#undef X -#define X(N,L) \ +#undef XX +#define XX(T,U,N,L) \ template class OTCorrelator, Z2 > > >; \ -template void OTCorrelator, Z2 > > >::correlate >(int start, int slice, \ +template void OTCorrelator, Z2 > > >::correlate(int start, int slice, \ BitVector& newReceiverInput, bool useConstantBase, int repeat); \ -template void OTCorrelator, Z2 > > >::expand >(int start, int slice); \ template void OTCorrelator, Z2 > > >::reduce_squares(unsigned int nTriples, \ - vector >& output); \ -template void OTCorrelator, Z2 > > >::reduce_squares(unsigned int nTriples, \ - vector >& output); \ -template void OTCorrelator, Z2 > > >::reduce_squares(unsigned int nTriples, \ - vector >& output); \ -template void OTExtensionWithMatrix::hash_outputs, Matrix, Z2 > > >(int, \ + vector& output); \ +template void OTExtensionWithMatrix::hash_outputs, Z2 > > >(int, \ std::vector, Z2 > >, std::allocator, Z2 > > > >&, \ Matrix, Z2 > >&); +#undef X +#define X(N,L) \ +template void OTCorrelator, Z2 > > >::expand >(int start, int slice); \ +template void OTCorrelator, Z2 > > >::reduce_squares(unsigned int nTriples, \ + vector >& output); \ +XX(Z2,Z2,N,L) + //X(96, 160) +XX(SignedZ2<64>, SignedZ2<64>, 64, 64) +XX(SignedZ2<72>, SignedZ2<72>, 72, 72) Y(64, 64) Y(64, 48) diff --git a/OT/OTMultiplier.cpp b/OT/OTMultiplier.cpp index 44ea79f37..a9b743b6f 100644 --- a/OT/OTMultiplier.cpp +++ b/OT/OTMultiplier.cpp @@ -9,6 +9,8 @@ #include "OT/NPartyTripleGenerator.h" #include "OT/Rectangle.h" #include "Math/Z2k.h" +#include "Math/SemiShare.h" +#include "Math/Semi2kShare.h" #include "OT/OTVole.hpp" #include "OT/Row.hpp" @@ -19,7 +21,7 @@ //#define OTCORR_TIMER template -OTMultiplier::OTMultiplier(NPartyTripleGenerator& generator, +OTMultiplier::OTMultiplier(OTTripleGenerator& generator, int thread_num) : generator(generator), thread_num(thread_num), rot_ext(128, 128, 0, 1, @@ -32,7 +34,7 @@ OTMultiplier::OTMultiplier(NPartyTripleGenerator& generator, } template -MascotMultiplier::MascotMultiplier(NPartyTripleGenerator>& generator, +MascotMultiplier::MascotMultiplier(OTTripleGenerator>& generator, int thread_num) : OTMultiplier>(generator, thread_num), auth_ot_ext(128, 128, 0, 1, generator.players[thread_num], {}, {}, {}, BOTH, true) @@ -41,7 +43,7 @@ MascotMultiplier::MascotMultiplier(NPartyTripleGenerator>& generator } template -Spdz2kMultiplier::Spdz2kMultiplier(NPartyTripleGenerator>& generator, int thread_num) : +Spdz2kMultiplier::Spdz2kMultiplier(OTTripleGenerator>& generator, int thread_num) : OTMultiplier> (generator, thread_num) { @@ -173,6 +175,15 @@ void Spdz2kMultiplier::init_authenticator(const BitVector& keyBits, input_mac_vole->init(keyBits, senderOutput, receiverOutput); } +template +void SemiMultiplier::after_correlation() +{ + this->otCorrelator.reduce_squares(this->generator.nPreampTriplesPerLoop, + this->c_output); + + this->outbox.push({}); +} + template void MascotMultiplier::after_correlation() { @@ -329,6 +340,12 @@ void OTMultiplier::multiplyForBits() template class OTMultiplier>; template class OTMultiplier>; +template class OTMultiplier>; +template class OTMultiplier>; +template class SemiMultiplier>; +template class SemiMultiplier>; +template class SemiMultiplier>; +template class SemiMultiplier>; template class MascotMultiplier; template class MascotMultiplier; diff --git a/OT/OTMultiplier.h b/OT/OTMultiplier.h index f4796cd73..04628e6ac 100644 --- a/OT/OTMultiplier.h +++ b/OT/OTMultiplier.h @@ -17,6 +17,8 @@ using namespace std; template class NPartyTripleGenerator; +template +class OTTripleGenerator; class MultJob { @@ -66,13 +68,13 @@ class OTMultiplier : public OTMultiplierMac& baseReceiverOutput) = 0; public: - NPartyTripleGenerator& generator; + OTTripleGenerator& generator; int thread_num; OTExtensionWithMatrix rot_ext; OTCorrelator > otCorrelator; - OTMultiplier(NPartyTripleGenerator& generator, int thread_num); + OTMultiplier(OTTripleGenerator& generator, int thread_num); virtual ~OTMultiplier(); void multiply(); }; @@ -89,7 +91,7 @@ class MascotMultiplier : public OTMultiplier> public: vector c_output; - MascotMultiplier(NPartyTripleGenerator>& generator, int thread_num); + MascotMultiplier(OTTripleGenerator>& generator, int thread_num); void multiplyForInputs(MultJob job); }; @@ -115,8 +117,35 @@ class Spdz2kMultiplier: public OTMultiplier> OTVoleBase, Z2>* mac_vole; OTVoleBase, Z2>* input_mac_vole; - Spdz2kMultiplier(NPartyTripleGenerator& generator, int thread_num); + Spdz2kMultiplier(OTTripleGenerator& generator, int thread_num); ~Spdz2kMultiplier(); }; +template +class SemiMultiplier : public OTMultiplier +{ + void multiplyForInputs(MultJob job) + { + (void) job; + throw not_implemented(); + } + + void after_correlation(); + + void init_authenticator(const BitVector& baseReceiverInput, + const vector< vector >& baseSenderInput, + const vector& baseReceiverOutput) + { + (void) baseReceiverInput, (void) baseReceiverOutput, (void) baseSenderInput; + } + +public: + vector c_output; + + SemiMultiplier(OTTripleGenerator& generator, int i) : + OTMultiplier(generator, i) + { + } +}; + #endif /* OT_OTMULTIPLIER_H_ */ diff --git a/OT/TripleMachine.cpp b/OT/TripleMachine.cpp index af8468ca4..9c8eede88 100644 --- a/OT/TripleMachine.cpp +++ b/OT/TripleMachine.cpp @@ -34,6 +34,11 @@ MascotParams::MascotParams() timerclear(&start); } +void MascotParams::set_passive() +{ + generateMACs = amplify = check = false; +} + TripleMachine::TripleMachine(int argc, const char** argv) : nConnections(1), bonding(0) { diff --git a/OT/TripleMachine.h b/OT/TripleMachine.h index dbd6a4574..c3612a3d1 100644 --- a/OT/TripleMachine.h +++ b/OT/TripleMachine.h @@ -31,6 +31,8 @@ class MascotParams : virtual public OfflineParams MascotParams(); + void set_passive(); + template T get_mac_key(); template diff --git a/Player-Online.cpp b/Player-Online.cpp index 8f2919d49..207837484 100644 --- a/Player-Online.cpp +++ b/Player-Online.cpp @@ -10,5 +10,5 @@ int main(int argc, const char** argv) { ez::ezOptionParser opt; - return spdz_main(argc, argv, opt); + return spdz_main>(argc, argv, opt); } diff --git a/Player-Online.hpp b/Player-Online.hpp index bcf00b027..8bf4e0919 100644 --- a/Player-Online.hpp +++ b/Player-Online.hpp @@ -11,7 +11,7 @@ #include using namespace std; -template +template int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt) { OnlineOptions online_opts(opt, argc, argv); @@ -212,7 +212,7 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt) try #endif { - Machine>(playerno, playerNames, online_opts.progname, memtype, lg2, + Machine(playerno, playerNames, online_opts.progname, memtype, lg2, opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet, opt.get("--threads")->isSet, max_broadcast, opt.get("--encrypted")->isSet, online_opts.live_prep, @@ -231,9 +231,9 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt) #ifndef INSECURE catch(...) { - Machine> machine(playerNames); + Machine machine(playerNames); machine.live_prep = false; - thread_info>::purge_preprocessing(machine); + thread_info::purge_preprocessing(machine); throw; } #endif diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index 311b5e19e..a42f784eb 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -103,3 +103,8 @@ void BaseMachine::print_timers() for (map::iterator it = timer.begin(); it != timer.end(); it++) cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl; } + +string BaseMachine::memory_filename(string type_short, int my_number) +{ + return PREP_DIR "Memory-" + type_short + "-P" + to_string(my_number); +} diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index 0250ad22f..a5e227d6f 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -34,6 +34,8 @@ class BaseMachine static BaseMachine& s(); + static string memory_filename(string type_short, int my_number); + BaseMachine(); virtual ~BaseMachine() {} diff --git a/Processor/Beaver.h b/Processor/Beaver.h index 834070851..bc88cf5ea 100644 --- a/Processor/Beaver.h +++ b/Processor/Beaver.h @@ -35,6 +35,8 @@ class Beaver : public ProtocolBase typename T::clear prepare_mul(const T& x, const T& y); void exchange(); T finalize_mul(); + + int get_n_relevant_players() { return P.num_players(); } }; #endif /* PROCESSOR_BEAVER_H_ */ diff --git a/Processor/Binary_File_IO.hpp b/Processor/Binary_File_IO.hpp index cf2040faf..087494b7a 100644 --- a/Processor/Binary_File_IO.hpp +++ b/Processor/Binary_File_IO.hpp @@ -1,6 +1,4 @@ #include "Processor/Binary_File_IO.h" -#include "Math/Rep3Share.h" -#include "Math/gfp.h" /* * Provides generalised file read and write methods for arrays of shares. diff --git a/Processor/BrainPrep.h b/Processor/BrainPrep.h new file mode 100644 index 000000000..deb63c8be --- /dev/null +++ b/Processor/BrainPrep.h @@ -0,0 +1,21 @@ +/* + * BrainPrep.h + * + */ + +#ifndef PROCESSOR_BRAINPREP_H_ +#define PROCESSOR_BRAINPREP_H_ + +#include "ReplicatedPrep.h" +#include "Math/BrainShare.h" + +template +class BrainPrep : public RingPrep +{ +public: + BrainPrep(SubProcessor* proc, DataPositions& usage) : + RingPrep(proc, usage) {} + void buffer_triples(); +}; + +#endif /* PROCESSOR_BRAINPREP_H_ */ diff --git a/Processor/BrainPrep.hpp b/Processor/BrainPrep.hpp new file mode 100644 index 000000000..8e0f26ccc --- /dev/null +++ b/Processor/BrainPrep.hpp @@ -0,0 +1,149 @@ +/* + * BrainPrep.cpp + * + */ + +#include "BrainPrep.h" +#include "Processor.h" +#include "Auth/MaliciousRepMC.h" + +template class ZProtocol; + +template +class Zint : public SignedZ2 +{ + typedef SignedZ2 super; + +public: + static string type_string() + { + return "Zint" + to_string(L); + } + + Zint() + { + } + + template + Zint(const T& other) : super(other) + { + } + + void randomize(PRNG& G) + { + *this = G.get>(); + } +}; + +template +class ZShare : public Rep3Share> +{ +public: + typedef ZProtocol Protocol; + typedef ReplicatedMC MAC_Check; + + ZShare() + { + } + + template + ZShare(const FixedVec& other) + { + FixedVec::operator=(other); + } +}; + +template +class ZProtocol : public Replicated>> +{ + typedef Rep3Share> T; + vector random; + SeededPRNG G; + +public: + ZProtocol(Player& P) : Replicated(P) + { + } + + T get_random() + { + if (random.empty()) + { + int buffer_size = 10000; + ReplicatedInput>> input(0, this->P); + input.reset_all(this->P); + for (int i = 0; i < buffer_size; i++) + { + typename U::clear tmp; + tmp.randomize(G); + input.add_mine(tmp); + } + input.exchange(); + for (int i = 0; i < buffer_size; i++) + { + random.push_back({}); + for (int j = 0; j < 3; j++) + random.back() += input.finalize(j); + } + } + + auto res = random.back(); + random.pop_back(); + return res; + } +}; + +template +void BrainPrep::buffer_triples() +{ + if(gfp2::get_ZpD().pr_bit_length + <= ZProtocol::share_type::clear::N_BITS) + throw runtime_error( + to_string(gfp2::get_ZpD().pr_bit_length) + + "-bit prime too short for " + + to_string(ZProtocol::share_type::clear::N_BITS) + + "-bit integer computation"); + typedef Rep3Share pShare; + auto buffer_size = this->buffer_size; + Player& P = this->protocol->P; + vector, 3>> triples; + vector, 3>> check_triples; + DataPositions usage; + HashMaliciousRepMC MC; + vector> masked, checks; + vector opened; + ZProtocol Z_protocol(P); + Replicated p_protocol(P); + generate_triples(triples, buffer_size, &Z_protocol); + generate_triples(check_triples, buffer_size, &p_protocol); + auto t = Create_Random(P); + vector> converted_bs; + for (int i = 0; i < buffer_size; i++) + { + pShare a = triples[i][0]; + converted_bs.push_back(triples[i][1]); + auto& b = converted_bs[i]; + auto& f = check_triples[i][0]; + auto& g = check_triples[i][1]; + masked.push_back(a * t - f); + masked.push_back(b - g); + } + MC.POpen(opened, masked, P); + for (int i = 0; i < buffer_size; i++) + { + auto& b = converted_bs[i]; + pShare c = triples[i][2]; + auto& f = check_triples[i][0]; + auto& h = check_triples[i][2]; + auto& rho = opened[2 * i]; + auto& sigma = opened[2 * i + 1]; + checks.push_back(t * c - h - rho * b - sigma * f); + } + MC.POpen(opened, checks, P); + for (auto& check : opened) + if (check != 0) + throw Offline_Check_Error("triple"); + MC.Check(P); + for (auto& x : triples) + this->triples.push_back({{x[0], x[1], x[2]}}); +} diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index b8a5dd28e..8162e7238 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -89,6 +89,7 @@ class Preprocessing virtual ~Preprocessing() {} virtual void set_protocol(typename T::Protocol& protocol) = 0; + virtual void set_proc(SubProcessor* proc) { (void) proc; } virtual void seekg(DataPositions& pos) { (void) pos; } virtual void prune() {} diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index 749d32b69..3c98b3b0e 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -1,12 +1,6 @@ #include "Processor/Data_Files.h" #include "Processor/Processor.h" -#include "Processor/ReplicatedPrep.h" -#include "Processor/MaliciousRepPrep.h" -#include "GC/MaliciousRepSecret.h" -#include "Math/MaliciousRep3Share.h" -#include "Math/ShamirShare.h" -#include "Math/MaliciousShamirShare.h" #include "Processor/MaliciousRepPrep.hpp" //#include "Processor/Replicated.hpp" @@ -25,8 +19,7 @@ template Preprocessing* Preprocessing::get_live_prep(SubProcessor* proc, DataPositions& usage) { - (void) proc, (void) usage; - throw not_implemented(); + return new typename T::LivePrep(proc, usage); } template diff --git a/Processor/Input.h b/Processor/Input.h index a03fd7165..4e433a2d5 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -20,17 +20,37 @@ class ArithmeticProcessor; template class InputBase { + typedef typename T::clear clear; + + Player* P; + protected: Buffer buffer; Timer timer; + vector os; + public: int values_input; static void input(SubProcessor& Proc, const vector& args); InputBase(ArithmeticProcessor* proc); - ~InputBase(); + virtual ~InputBase(); + + virtual void reset(int player) = 0; + void reset_all(Player& P); + + virtual void add_mine(const clear& input) = 0; + virtual void add_other(int player) = 0; + void add_from_all(const clear& input); + + virtual void send_mine() = 0; + void exchange(); + + virtual T finalize_mine() = 0; + virtual void finalize_other(int player, T& target, octetStream& o) = 0; + T finalize(int player); }; template @@ -43,18 +63,17 @@ class Input : public InputBase SubProcessor& proc; MAC_Check& MC; vector< PointerVector > shares; - octetStream o; open_type rr, t, xi; - void adjust_mac(T& share, const open_type& value); - public: Input(SubProcessor& proc, MAC_Check& mc); Input(SubProcessor* proc, Player& P); void reset(int player); + void add_mine(const clear& input); void add_other(int player); + void send_mine(); T finalize_mine(); diff --git a/Processor/Input.hpp b/Processor/Input.hpp index ec9c07b92..4c93f0e84 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -9,7 +9,7 @@ template InputBase::InputBase(ArithmeticProcessor* proc) : - values_input(0) + P(0), values_input(0) { if (proc) buffer.setup(&proc->private_input, -1, proc->private_input_filename); @@ -17,7 +17,8 @@ InputBase::InputBase(ArithmeticProcessor* proc) : template Input::Input(SubProcessor& proc, MAC_Check& mc) : - InputBase(&proc.Proc), proc(proc), MC(mc), shares(proc.P.num_players()) + InputBase(&proc.Proc), proc(proc), MC(mc), + shares(proc.P.num_players()) { } @@ -39,20 +40,26 @@ InputBase::~InputBase() } template -void Input::adjust_mac(T& share, const open_type& value) +void Input::reset(int player) { - typename T::mac_type tmp; - tmp.mul(MC.get_alphai(), value); - tmp.add(share.get_mac(),tmp); - share.set_mac(tmp); + InputBase::reset(player); + shares[player].clear(); } template -void Input::reset(int player) +void InputBase::reset(int player) { - shares[player].clear(); - if (player == proc.P.my_num()) - o.reset_write_head(); + os.resize(max(os.size(), player + 1UL)); + os[player].reset_write_head(); +} + +template +void InputBase::reset_all(Player& P) +{ + this->P = &P; + os.resize(P.num_players()); + for (int i = 0; i < P.num_players(); i++) + reset(i); } template @@ -63,10 +70,8 @@ void Input::add_mine(const clear& input) T& share = shares[player].back(); proc.DataF.get_input(share, rr, player); t.sub(input, rr); - t.pack(o); - xi.add(t, share.get_share()); - share.set_share(xi); - adjust_mac(share, t); + t.pack(this->os[player]); + share += T(t, 0, MC.get_alphai()); this->values_input++; } @@ -78,10 +83,30 @@ void Input::add_other(int player) proc.DataF.get_input(shares[player].back(), t, player); } +template +void InputBase::add_from_all(const clear& input) +{ + for (int i = 0; i < P->num_players(); i++) + if (i == P->my_num()) + add_mine(input); + else + add_other(i); +} + template void Input::send_mine() { - proc.P.send_all(o, true); + proc.P.send_all(this->os[proc.P.my_num()], true); +} + +template +void InputBase::exchange() +{ + for (int i = 0; i < P->num_players(); i++) + if (i == P->my_num()) + send_mine(); + else + P->receive_player(i, os[i], true); } template @@ -143,7 +168,20 @@ void Input::finalize_other(int player, T& target, { target = shares[player].next(); t.unpack(o); - adjust_mac(target, t); + target += T(t, 1, MC.get_alphai()); +} + +template +T InputBase::finalize(int player) +{ + if (player == P->my_num()) + return finalize_mine(); + else + { + T res; + finalize_other(player, res, os[player]); + return res; + } } template diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index d6f310d9d..ae5108bb2 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -5,14 +5,10 @@ #include "Exceptions/Exceptions.h" #include "Tools/time-func.h" #include "Tools/parse.h" -#include "Auth/ReplicatedMC.h" -#include "Math/MaliciousRep3Share.h" -#include "Math/ShamirShare.h" -#include "Auth/ShamirMC.h" -#include "Math/MaliciousShamirShare.h" //#include "Processor/Processor.hpp" #include "Processor/Binary_File_IO.hpp" +#include "Processor/PrivateOutput.hpp" //#include "Processor/Input.hpp" //#include "Processor/Beaver.hpp" //#include "Processor/Shamir.hpp" @@ -1217,6 +1213,9 @@ inline void Instruction::execute(Processor& Proc) const } else { + if (n > 64) + throw Processor_Error(to_string(n) + "-bit conversion impossible; " + "integer registers only have 64 bits"); to_signed_bigint(Proc.temp.aa,Proc.read_Cp(r[1]),n); Proc.write_Ci(r[0], Proc.temp.aa.get_si()); } diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index fb24b6f46..00cb2b54d 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -2,21 +2,14 @@ #include "Memory.hpp" #include "Online-Thread.hpp" -#include "ShamirInput.hpp" -#include "Shamir.hpp" #include "Replicated.hpp" #include "Beaver.hpp" -#include "Auth/ShamirMC.hpp" -#include "Auth/MaliciousShamirMC.hpp" #include "Exceptions/Exceptions.h" #include #include "Math/Setup.h" -#include "Math/MaliciousRep3Share.h" -#include "Math/ShamirShare.h" -#include "Math/MaliciousShamirShare.h" #include "Tools/mkpath.h" #include @@ -50,27 +43,7 @@ Machine::Machine(int my_number, Names& playerNames, try { read_setup(prep_dir_prefix); - - int nn; - - sprintf(filename, (prep_dir_prefix + "Player-MAC-Keys-P%d").c_str(), my_number); - ifstream inpf; - inpf.open(filename); - if (inpf.fail()) - { - cerr << "Could not open MAC key file. Perhaps it needs to be generated?\n"; - throw file_error(filename); - } - inpf >> nn; - if (nn!=N.num_players()) - { cerr << "KeyGen was last run with " << nn << " players." << endl; - cerr << " - You are running Online with " << N.num_players() << " players." << endl; - exit(1); - } - - alphapi.input(inpf,true); - alpha2i.input(inpf,true); - inpf.close(); + ::read_mac_keys(prep_dir_prefix, my_number, N.num_players(), alphapi, alpha2i); read_mac_keys = true; } catch (file_error& e) @@ -410,7 +383,7 @@ void Machine::run() template string Machine::memory_filename() { - return PREP_DIR "Memory-" + sint::type_short() + "-P" + to_string(my_number); + return BaseMachine::memory_filename(sint::type_short(), my_number); } template diff --git a/Processor/MaliciousRepPrep.h b/Processor/MaliciousRepPrep.h index c1a0d8759..8be5cdcce 100644 --- a/Processor/MaliciousRepPrep.h +++ b/Processor/MaliciousRepPrep.h @@ -8,7 +8,6 @@ #include "Data_Files.h" #include "ReplicatedPrep.h" -#include "Math/MaliciousRep3Share.h" #include "Auth/MaliciousRepMC.h" #include diff --git a/Processor/MascotPrep.h b/Processor/MascotPrep.h index 8a0a140cd..5956fc56e 100644 --- a/Processor/MascotPrep.h +++ b/Processor/MascotPrep.h @@ -10,25 +10,34 @@ #include "OT/NPartyTripleGenerator.h" template -class MascotPrep : public RingPrep +class OTPrep : public RingPrep { protected: - NPartyTripleGenerator* triple_generator; + typename T::TripleGenerator* triple_generator; public: MascotParams params; - MascotPrep(SubProcessor* proc, DataPositions& usage); - ~MascotPrep(); + OTPrep(SubProcessor* proc, DataPositions& usage); + ~OTPrep(); void set_protocol(typename T::Protocol& protocol); + size_t data_sent(); +}; + +template +class MascotPrep : public OTPrep +{ +public: + MascotPrep(SubProcessor* proc, DataPositions& usage) : OTPrep(proc, usage) + { + } + void buffer_triples(); void buffer_inputs(int player); T get_random(); - - size_t data_sent(); }; template diff --git a/Processor/MascotPrep.hpp b/Processor/MascotPrep.hpp index 224960473..54b52aeda 100644 --- a/Processor/MascotPrep.hpp +++ b/Processor/MascotPrep.hpp @@ -9,31 +9,31 @@ #include "OT/Triple.hpp" template -MascotPrep::MascotPrep(SubProcessor* proc, DataPositions& usage) : +OTPrep::OTPrep(SubProcessor* proc, DataPositions& usage) : RingPrep(proc, usage), triple_generator(0) { this->buffer_size = 1000; } template -MascotPrep::~MascotPrep() +OTPrep::~OTPrep() { if (triple_generator) delete triple_generator; } template -void MascotPrep::set_protocol(typename T::Protocol& protocol) +void OTPrep::set_protocol(typename T::Protocol& protocol) { RingPrep::set_protocol(protocol); SubProcessor* proc = this->proc; assert(proc != 0); - auto& ot_setups = BaseMachine::s().ot_setups[proc->Proc.thread_num]; + auto& ot_setups = BaseMachine::s().ot_setups.at(proc->Proc.thread_num); assert(not ot_setups.empty()); OTTripleSetup setup = ot_setups.back(); ot_setups.pop_back(); params.set_mac_key(typename T::mac_key_type::next(proc->MC.get_alphai())); - triple_generator = new NPartyTripleGenerator(setup, + triple_generator = new typename T::TripleGenerator(setup, proc->P.N, proc->Proc.thread_num, this->buffer_size, 1, params, &proc->P); triple_generator->multi_threaded = false; @@ -42,13 +42,15 @@ void MascotPrep::set_protocol(typename T::Protocol& protocol) template void MascotPrep::buffer_triples() { + auto& params = this->params; + auto& triple_generator = this->triple_generator; params.generateBits = false; triple_generator->generate(); triple_generator->unlock(); assert(triple_generator->uncheckedTriples.size() != 0); for (auto& triple : triple_generator->uncheckedTriples) this->triples.push_back( - { triple.a[0], triple.b, triple.c[0] }); + {{ triple.a[0], triple.b, triple.c[0] }}); } template @@ -73,6 +75,7 @@ void MascotFieldPrep::buffer_bits() template void MascotPrep::buffer_inputs(int player) { + auto& triple_generator = this->triple_generator; assert(triple_generator); triple_generator->generateInputs(player); if (this->inputs.size() <= (size_t)player) @@ -97,7 +100,7 @@ T MascotPrep::get_random() } template -size_t MascotPrep::data_sent() +size_t OTPrep::data_sent() { if (triple_generator) return triple_generator->data_sent(); diff --git a/Processor/Memory.hpp b/Processor/Memory.hpp index 5aae26071..67ea21a8c 100644 --- a/Processor/Memory.hpp +++ b/Processor/Memory.hpp @@ -1,8 +1,5 @@ #include "Processor/Memory.h" #include "Processor/Instruction.h" -#include "Math/gf2n.h" -#include "Math/gfp.h" -#include "Math/Integer.h" #include diff --git a/Processor/NoLivePrep.h b/Processor/NoLivePrep.h new file mode 100644 index 000000000..77302ba55 --- /dev/null +++ b/Processor/NoLivePrep.h @@ -0,0 +1,26 @@ +/* + * NoLivePrep.h + * + */ + +#ifndef PROCESSOR_NOLIVEPREP_H_ +#define PROCESSOR_NOLIVEPREP_H_ + +#include "Exceptions/Exceptions.h" +#include "Data_Files.h" + +template class SubProcessor; +class DataPositions; + +template +class NoLivePrep : public Sub_Data_Files +{ +public: + NoLivePrep(SubProcessor* proc, DataPositions& usage) : Sub_Data_Files(0, 0, "", usage, 0) + { + (void) proc; + throw not_implemented(); + } +}; + +#endif /* PROCESSOR_NOLIVEPREP_H_ */ diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index 65294c88d..c8039a412 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -5,13 +5,10 @@ #include "Processor/Data_Files.h" #include "Processor/Machine.h" #include "Processor/Processor.h" -#include "Auth/ReplicatedMC.h" -#include "Auth/ShamirMC.h" #include "Networking/CryptoPlayer.h" #include "Processor/Processor.hpp" #include "Processor/Input.hpp" -#include "Auth/MaliciousRepMC.hpp" #include #include diff --git a/Processor/PrivateOutput.h b/Processor/PrivateOutput.h index 40952744d..4cd47d0db 100644 --- a/Processor/PrivateOutput.h +++ b/Processor/PrivateOutput.h @@ -9,16 +9,18 @@ #include using namespace std; -#include "Math/Share.h" +template class SubProcessor; template class PrivateOutput { - SubProcessor>& proc; - deque masks; + typedef typename T::open_type open_type; + + SubProcessor& proc; + deque masks; public: - PrivateOutput(SubProcessor>& proc) : proc(proc) { }; + PrivateOutput(SubProcessor& proc) : proc(proc) { }; void start(int player, int target, int source); void stop(int player, int source); diff --git a/Processor/PrivateOutput.cpp b/Processor/PrivateOutput.hpp similarity index 78% rename from Processor/PrivateOutput.cpp rename to Processor/PrivateOutput.hpp index 8eec49901..da8ce1e10 100644 --- a/Processor/PrivateOutput.cpp +++ b/Processor/PrivateOutput.hpp @@ -9,9 +9,9 @@ template void PrivateOutput::start(int player, int target, int source) { - T mask; + open_type mask; proc.DataF.get_input(proc.get_S_ref(target), mask, player); - proc.get_S_ref(target).add(proc.get_S_ref(source)); + proc.get_S_ref(target) += proc.get_S_ref(source); if (player == proc.P.my_num()) masks.push_back(mask); @@ -22,12 +22,9 @@ void PrivateOutput::stop(int player, int source) { if (player == proc.P.my_num()) { - T value; + open_type value; value.sub(proc.get_C_ref(source), masks.front()); value.output(proc.Proc.private_output, false); masks.pop_front(); } } - -template class PrivateOutput; -template class PrivateOutput; diff --git a/Processor/Processor.h b/Processor/Processor.h index 0c848dbda..88d992e23 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -15,6 +15,7 @@ #include "Data_Files.h" #include "Input.h" #include "ReplicatedInput.h" +#include "SemiInput.h" #include "PrivateOutput.h" #include "ReplicatedPrivateOutput.h" #include "Machine.h" diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index e7aba2e45..c7fefdf98 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -2,7 +2,6 @@ #include "Processor/Processor.h" #include "Networking/STS.h" #include "Auth/MAC_Check.h" -#include "Auth/ReplicatedMC.h" #include "Auth/fake-stuff.h" #include "Processor/ReplicatedInput.hpp" @@ -16,6 +15,7 @@ SubProcessor::SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check& Preprocessing& DataF, Player& P) : Proc(Proc), MC(MC), P(P), DataF(DataF), protocol(P), input(*this, MC) { + DataF.set_proc(this); DataF.set_protocol(protocol); } diff --git a/Processor/Replicated.h b/Processor/Replicated.h index 283277de1..abaafb286 100644 --- a/Processor/Replicated.h +++ b/Processor/Replicated.h @@ -38,6 +38,8 @@ template class ProtocolBase { public: + typedef T share_type; + int counter; ProtocolBase(); diff --git a/Processor/Replicated.hpp b/Processor/Replicated.hpp index 3e1e188cc..56f7bac4d 100644 --- a/Processor/Replicated.hpp +++ b/Processor/Replicated.hpp @@ -5,12 +5,7 @@ #include "Replicated.h" #include "Processor.h" -#include "Math/FixedVec.h" -#include "Math/Integer.h" -#include "Math/MaliciousRep3Share.h" -#include "Math/ShamirShare.h" #include "Tools/benchmarking.h" -#include "GC/ReplicatedSecret.h" template ProtocolBase::ProtocolBase() : counter(0) diff --git a/Processor/ReplicatedInput.h b/Processor/ReplicatedInput.h index 929db4126..6c89687f5 100644 --- a/Processor/ReplicatedInput.h +++ b/Processor/ReplicatedInput.h @@ -7,6 +7,7 @@ #define PROCESSOR_REPLICATEDINPUT_H_ #include "Input.h" +#include "Replicated.h" template class PrepLessInput : public InputBase @@ -40,12 +41,12 @@ class ReplicatedInput : public PrepLessInput Player& P; vector os; SeededPRNG secure_prng; + ReplicatedBase protocol; public: ReplicatedInput(SubProcessor& proc) : - PrepLessInput(&proc), proc(&proc), P(proc.P) + ReplicatedInput(&proc, proc.P) { - assert(T::length == 2); } ReplicatedInput(SubProcessor& proc, ReplicatedMC& MC) : ReplicatedInput(proc) @@ -53,8 +54,9 @@ class ReplicatedInput : public PrepLessInput (void) MC; } ReplicatedInput(SubProcessor* proc, Player& P) : - PrepLessInput(proc), proc(proc), P(P) + PrepLessInput(proc), proc(proc), P(P), protocol(P) { + assert(T::length == 2); } void reset(int player); diff --git a/Processor/ReplicatedInput.hpp b/Processor/ReplicatedInput.hpp index 615e5ee61..0c0b56a5a 100644 --- a/Processor/ReplicatedInput.hpp +++ b/Processor/ReplicatedInput.hpp @@ -26,12 +26,9 @@ inline void ReplicatedInput::add_mine(const typename T::clear& input) auto& shares = this->shares; shares.push_back({}); T& my_share = shares.back(); - my_share[0].randomize(secure_prng); + my_share[0].randomize(protocol.shared_prngs[0]); my_share[1] = input - my_share[0]; - for (int j = 0; j < 2; j++) - { - my_share[j].pack(os[j]); - } + my_share[1].pack(os[1]); this->values_input++; } @@ -58,7 +55,7 @@ void PrepLessInput::start(int player, int n_inputs) { for (int i = 0; i < n_inputs; i++) { - typename T::value_type t; + typename T::open_type t; this->buffer.input(t); add_mine(t); } @@ -92,13 +89,18 @@ template inline void ReplicatedInput::finalize_other(int player, T& target, octetStream& o) { - typename T::value_type t; - t.unpack(o); - int j = P.get_offset(player) == 2; - T share; - share[j] = t; - share[1 - j] = 0; - target = share; + if (P.get_offset(player) == 1) + { + typename T::value_type t; + t.unpack(o); + target[0] = t; + target[1] = 0; + } + else + { + target[0] = 0; + target[1].randomize(protocol.shared_prngs[1]); + } } template diff --git a/Processor/ReplicatedMachine.hpp b/Processor/ReplicatedMachine.hpp index 52f27f57a..4d81074eb 100644 --- a/Processor/ReplicatedMachine.hpp +++ b/Processor/ReplicatedMachine.hpp @@ -5,6 +5,7 @@ #include "Tools/ezOptionParser.h" #include "Tools/benchmarking.h" +#include "Tools/NetworkOptions.h" #include "Networking/Server.h" #include "Math/Rep3Share.h" #include "Processor/Machine.h" @@ -17,24 +18,7 @@ ReplicatedMachine::ReplicatedMachine(int argc, const char** argv, (void) name; OnlineOptions online_opts(opt, argc, argv); - opt.add( - "localhost", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Host where party 0 is running (default: localhost)", // Help description. - "-h", // Flag token. - "--hostname" // Flag token. - ); - opt.add( - "5000", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Base port number (default: 5000).", // Help description. - "-pn", // Flag token. - "--portnum" // Flag token. - ); + NetworkOptions network_opts(opt, argc, argv); opt.add( "", // Default. 0, // Required? @@ -49,10 +33,8 @@ ReplicatedMachine::ReplicatedMachine(int argc, const char** argv, int playerno = online_opts.playerno; string progname = online_opts.progname; - int pnb; - string hostname; - opt.get("-pn")->getInt(pnb); - opt.get("-h")->getString(hostname); + int pnb = network_opts.portnum_base; + string hostname = network_opts.hostname; bool use_encryption = not opt.get("-u")->isSet; if (not use_encryption) diff --git a/Processor/ReplicatedPrep.h b/Processor/ReplicatedPrep.h index adab93816..41bf9c9a9 100644 --- a/Processor/ReplicatedPrep.h +++ b/Processor/ReplicatedPrep.h @@ -61,11 +61,16 @@ class RingPrep : public BufferPrep RingPrep(SubProcessor* proc, DataPositions& usage); virtual ~RingPrep() {} + void set_proc(SubProcessor* proc) { this->proc = proc; } void set_protocol(typename T::Protocol& protocol); virtual void buffer_bits(); }; +template +void generate_triples(vector>& triples, int n_triples, + typename T::Protocol* protocol); + template class ReplicatedRingPrep : public RingPrep { diff --git a/Processor/ReplicatedPrep.hpp b/Processor/ReplicatedPrep.hpp index 39ec66d61..667cf6243 100644 --- a/Processor/ReplicatedPrep.hpp +++ b/Processor/ReplicatedPrep.hpp @@ -5,9 +5,6 @@ #include "ReplicatedPrep.h" #include "Math/gfp.h" -#include "Math/MaliciousRep3Share.h" -#include "Auth/ReplicatedMC.h" -#include "Auth/ShamirMC.h" template RingPrep::RingPrep(SubProcessor* proc, DataPositions& usage) : @@ -26,12 +23,16 @@ void RingPrep::set_protocol(typename T::Protocol& protocol) template void ReplicatedRingPrep::buffer_triples() { - auto protocol = this->protocol; - auto proc = this->proc; - assert(protocol != 0); - auto& triples = this->triples; - triples.resize(this->buffer_size); - protocol->init_mul(proc); + assert(this->protocol != 0); + generate_triples(this->triples, this->buffer_size, this->protocol); +} + +template +void generate_triples(vector>& triples, int n_triples, + typename T::Protocol* protocol) +{ + triples.resize(n_triples); + protocol->init_mul(); for (size_t i = 0; i < triples.size(); i++) { auto& triple = triples[i]; @@ -75,7 +76,7 @@ void RingPrep::buffer_squares() vector opened(buffer_size); proc->MC.POpen(opened, a_plus_b, proc->P); for (int i = 0; i < buffer_size; i++) - this->squares.push_back({as[i], as[i] * opened[i] - cs[i]}); + this->squares.push_back({{as[i], as[i] * opened[i] - cs[i]}}); } template diff --git a/Processor/RingOptions.cpp b/Processor/RingOptions.cpp new file mode 100644 index 000000000..1da103095 --- /dev/null +++ b/Processor/RingOptions.cpp @@ -0,0 +1,26 @@ +/* + * RingOptions.cpp + * + */ + +#include "RingOptions.h" + +#include +using namespace std; + +RingOptions::RingOptions(ez::ezOptionParser& opt, int argc, const char** argv) +{ + opt.add( + "64", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Number of integer bits (default: 64)", // Help description. + "-R", // Flag token. + "--ring" // Flag token. + ); + opt.parse(argc, argv); + opt.get("-R")->getInt(R); + opt.resetArgs(); + cerr << "Trying to run " << R << "-bit computation" << endl; +} diff --git a/Processor/RingOptions.h b/Processor/RingOptions.h new file mode 100644 index 000000000..4a34e88a1 --- /dev/null +++ b/Processor/RingOptions.h @@ -0,0 +1,19 @@ +/* + * RingOptions.h + * + */ + +#ifndef PROCESSOR_RINGOPTIONS_H_ +#define PROCESSOR_RINGOPTIONS_H_ + +#include "Tools/ezOptionParser.h" + +class RingOptions +{ +public: + int R; + + RingOptions(ez::ezOptionParser& opt, int argc, const char** argv); +}; + +#endif /* PROCESSOR_RINGOPTIONS_H_ */ diff --git a/Processor/SemiInput.h b/Processor/SemiInput.h new file mode 100644 index 000000000..59ab506de --- /dev/null +++ b/Processor/SemiInput.h @@ -0,0 +1,33 @@ +/* + * SemiInput.h + * + */ + +#ifndef PROCESSOR_SEMIINPUT_H_ +#define PROCESSOR_SEMIINPUT_H_ + +#include "ShamirInput.h" + +template class SemiMC; + +template +class SemiInput : public IndividualInput +{ + SeededPRNG secure_prng; + +public: + SemiInput(SubProcessor& proc, SemiMC& MC) : + IndividualInput(proc) + { + (void) MC; + } + + SemiInput(SubProcessor* proc, Player& P) : + IndividualInput(proc, P) + { + } + + void add_mine(const typename T::clear& input); +}; + +#endif /* PROCESSOR_SEMIINPUT_H_ */ diff --git a/Processor/SemiInput.hpp b/Processor/SemiInput.hpp new file mode 100644 index 000000000..11d9c1adb --- /dev/null +++ b/Processor/SemiInput.hpp @@ -0,0 +1,27 @@ +/* + * SemiInput.cpp + * + */ + +#include "SemiInput.h" + +#include "ShamirInput.hpp" + +template +void SemiInput::add_mine(const typename T::clear& input) +{ + auto& P = this->P; + typename T::open_type sum, share; + for (int i = 0; i < P.num_players(); i++) + { + if (i < P.num_players() - 1) + share.randomize(secure_prng); + else + share = input - sum; + sum += share; + if (i == P.my_num()) + this->shares.push_back(share); + else + share.pack(this->os[i]); + } +} diff --git a/Processor/SemiPrep.h b/Processor/SemiPrep.h new file mode 100644 index 000000000..ec5d1cbf4 --- /dev/null +++ b/Processor/SemiPrep.h @@ -0,0 +1,21 @@ +/* + * SemiPrep.h + * + */ + +#ifndef PROCESSOR_SEMIPREP_H_ +#define PROCESSOR_SEMIPREP_H_ + +#include "MascotPrep.h" + +template +class SemiPrep : public OTPrep +{ +public: + SemiPrep(SubProcessor* proc, DataPositions& usage); + + void buffer_triples(); + void buffer_inverses(); +}; + +#endif /* PROCESSOR_SEMIPREP_H_ */ diff --git a/Processor/SemiPrep.hpp b/Processor/SemiPrep.hpp new file mode 100644 index 000000000..de74b5695 --- /dev/null +++ b/Processor/SemiPrep.hpp @@ -0,0 +1,31 @@ +/* + * SemiPrep.cpp + * + */ + +#include "SemiPrep.h" + +template +SemiPrep::SemiPrep(SubProcessor* proc, DataPositions& usage) : OTPrep(proc, usage) +{ + this->params.set_passive(); +} + +template +void SemiPrep::buffer_triples() +{ + assert(this->triple_generator); + this->triple_generator->generatePlainTriples(); + for (auto& x : this->triple_generator->plainTriples) + { + this->triples.push_back({{x[0], x[1], x[2]}}); + } + this->triple_generator->unlock(); +} + +template +void SemiPrep::buffer_inverses() +{ + assert(this->proc != 0); + BufferPrep::buffer_inverses(this->proc->MC, this->proc->P); +} diff --git a/Processor/Shamir.h b/Processor/Shamir.h index 206843d3a..1a3a853d1 100644 --- a/Processor/Shamir.h +++ b/Processor/Shamir.h @@ -49,6 +49,7 @@ class Shamir : public ProtocolBase> void reset(); + void init_mul(); void init_mul(SubProcessor* proc); U prepare_mul(const T& x, const T& y); void exchange(); diff --git a/Processor/Shamir.hpp b/Processor/Shamir.hpp index 4ee365127..f5e5b2d9d 100644 --- a/Processor/Shamir.hpp +++ b/Processor/Shamir.hpp @@ -20,6 +20,8 @@ U Shamir::get_rec_factor(int i, int n) template Shamir::Shamir(Player& P) : resharing(0), P(P) { + if (not P.is_encrypted()) + insecure("unencrypted communication"); threshold = ShamirMachine::s().threshold; n_mul_players = 2 * threshold + 1; } @@ -56,6 +58,12 @@ template void Shamir::init_mul(SubProcessor* proc) { (void) proc; + init_mul(); +} + +template +void Shamir::init_mul() +{ reset(); if (rec_factor == 0 and P.my_num() < n_mul_players) rec_factor = get_rec_factor(P.my_num(), n_mul_players); @@ -73,11 +81,21 @@ U Shamir::prepare_mul(const T& x, const T& y) template void Shamir::exchange() { - if (P.my_num() < n_mul_players) - resharing->send_mine(); - for (int i = 0; i < n_mul_players; i++) - if (i != P.my_num()) - P.receive_player(i, os[i], true); + for (int offset = 1; offset < P.num_players(); offset++) + { + int receive_from = P.get_player(-offset); + int send_to = P.get_player(offset); + bool receive = receive_from < n_mul_players; + if (P.my_num() < n_mul_players) + { + if (receive) + P.pass_around(resharing->os[send_to], os[receive_from], offset); + else + P.send_to(send_to, resharing->os[send_to], true); + } + else if (receive) + P.receive_player(receive_from, os[receive_from], true); + } } template diff --git a/Processor/ShamirInput.h b/Processor/ShamirInput.h index ce2596e26..d3ae8ce77 100644 --- a/Processor/ShamirInput.h +++ b/Processor/ShamirInput.h @@ -11,37 +11,51 @@ #include "ReplicatedInput.h" template -class ShamirInput : public PrepLessInput +class IndividualInput : public PrepLessInput { +protected: Player& P; vector os; - vector> vandermonde; - SeededPRNG secure_prng; - - vector randomness; public: - ShamirInput(SubProcessor& proc) : + IndividualInput(SubProcessor* proc, Player& P) : + PrepLessInput(proc), P(P) + { + } + IndividualInput(SubProcessor& proc) : PrepLessInput(&proc), P(proc.P) { } + void reset(int player); + void add_other(int player); + void send_mine(); + void finalize_other(int player, T& target, octetStream& o); +}; + +template +class ShamirInput : public IndividualInput +{ + friend class Shamir; + + vector> vandermonde; + SeededPRNG secure_prng; + + vector randomness; + +public: ShamirInput(SubProcessor& proc, ShamirMC& MC) : - ShamirInput(proc) + IndividualInput(proc) { (void) MC; } ShamirInput(SubProcessor* proc, Player& P) : - PrepLessInput(proc), P(P) + IndividualInput(proc, P) { } - void reset(int player); void add_mine(const typename T::clear& input); - void add_other(int player); - void send_mine(); - void finalize_other(int player, T& target, octetStream& o); }; #endif /* PROCESSOR_SHAMIRINPUT_H_ */ diff --git a/Processor/ShamirInput.hpp b/Processor/ShamirInput.hpp index 067456422..40d88826f 100644 --- a/Processor/ShamirInput.hpp +++ b/Processor/ShamirInput.hpp @@ -7,7 +7,7 @@ #include "Machines/ShamirMachine.h" template -void ShamirInput::reset(int player) +void IndividualInput::reset(int player) { if (player == P.my_num()) { @@ -21,6 +21,7 @@ void ShamirInput::reset(int player) template void ShamirInput::add_mine(const typename T::clear& input) { + auto& P = this->P; int n = P.num_players(); int t = ShamirMachine::s().threshold; if (vandermonde.empty()) @@ -49,18 +50,18 @@ void ShamirInput::add_mine(const typename T::clear& input) if (i == P.my_num()) this->shares.push_back(x); else - x.pack(os[i]); + x.pack(this->os[i]); } } template -void ShamirInput::add_other(int player) +void IndividualInput::add_other(int player) { (void) player; } template -void ShamirInput::send_mine() +void IndividualInput::send_mine() { for (int i = 0; i < P.num_players(); i++) if (i != P.my_num()) @@ -68,7 +69,7 @@ void ShamirInput::send_mine() } template -void ShamirInput::finalize_other(int player, T& target, octetStream& o) +void IndividualInput::finalize_other(int player, T& target, octetStream& o) { (void) player; target.unpack(o); diff --git a/Programs/Source/test_gc.mpc b/Programs/Source/test_gc.mpc index 24719eb4c..bdf695d8b 100644 --- a/Programs/Source/test_gc.mpc +++ b/Programs/Source/test_gc.mpc @@ -1,13 +1,4 @@ -def test(a, b, value_type=None): - try: - a = a.reveal() - except AttributeError: - pass - import inspect - print_ln('%s: %s %s %s', inspect.currentframe().f_back.f_lineno, \ - (a ^ cbits(b)).reveal(), a, hex(b)) - test(sbits(3) + sbits(5), 3 ^ 5) test(cbits(3) + cbits(5), 3 + 5) test(cbits(3) + (5), 3 + 5) @@ -86,7 +77,7 @@ test(a[1], 1) test(a[2], 64) a = sbits(-1, n=64) -test(a & a, -1) +test(a & a, 2**64 - 1) sbits.n = 64 a = sbitvec(64 * [sbits(2**64 - 1, n=64)]).popcnt().elements() diff --git a/Programs/Source/tutorial.mpc b/Programs/Source/tutorial.mpc index e948a2b71..bfe240155 100644 --- a/Programs/Source/tutorial.mpc +++ b/Programs/Source/tutorial.mpc @@ -103,24 +103,15 @@ for i in range(3): weight_total = sum(point[0] for point in data) result = sum(point[0] * point[1] for point in data) / weight_total -# the following only works with arithmetic circuits - -# @if_e((sum(point[0] for point in data) != 0).reveal()) -# def _(): -# print_ln('weighted average: %s', result.reveal()) -# @else_ -# def _(): -# print_ln('your inputs made no sense') - -# so we output even an invalid result (the weights adding up to zero) - -print_ln('weighted average: %s', result.reveal()) - -# but we warn the user -# note that the we don't reveal the weight sum, only the comparison - -print_ln_if((sum(point[0] for point in data) == 0).reveal(), \ - 'but the inputs were invalid (weights add up to zero)') +# branching is supported also depending on revealed secret data +# with garbled circuits this triggers a interruption of the garbling + +@if_e((sum(point[0] for point in data) != 0).reveal()) +def _(): + print_ln('weighted average: %s', result.reveal()) +@else_ +def _(): + print_ln('your inputs made no sense') # permutation matrix diff --git a/README.md b/README.md index 93222610d..1c2dbcd38 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,10 @@ # Multi-Protocol SPDZ Software to benchmark various secure multi-party computation (MPC) -protocols such as SPDZ, MASCOT, Overdrive, BMR garbled circuits -(evaluation only), Yao's garbled circuits, and computation based on -semi-honest three-party replicated secret sharing (with an honest majority). +protocols such as SPDZ, SPDZ2k, MASCOT, Overdrive, BMR garbled circuits, +Yao's garbled circuits, and computation based on +three-party replicated secret sharing as well as Shamir's secret +sharing (with an honest majority). #### TL;DR (Binary Distribution on Linux or Source Distribution on macOS) @@ -64,10 +65,10 @@ stands for three-party replicated secret sharing. | Security model | Mod prime / GF(2^n) | Mod 2^k | Binary | | --- | --- | --- | --- | -| Malicious, dishonest majority | [MASCOT](#arithmetic-circuits-mascot--spdz2k) | [SPDZ2k](#arithmetic-circuits-mascot--spdz2k) | N/A | -| Semi-honest, dishonest majority | N/A | N/A | [Yao's GC](#yaos-garbled-circuits) | -| [Malicious, honest majority](#honest-majority) | Shamir / Rep3 | N/A | N/A | -| [Semi-honest, honest majority](#honest-majority) | Shamir / Rep3 | Rep3 | Rep3 | +| Malicious, dishonest majority | [MASCOT](#arithmetic-circuits) | [SPDZ2k](#arithmetic-circuits) | [BMR](#bmr) | +| Semi-honest, dishonest majority | [Semi](#arithmetic-circuits) | [Semi2k](#arithmetic-circuits) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | +| Malicious, honest majority | [Shamir / Rep3](#honest-majority) | [Brain](#honest-majority) | [BMR](#bmr) | +| Semi-honest, honest majority | [Shamir / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3](#honest-majority) / [BMR](#bmr) | #### History @@ -121,7 +122,7 @@ phase outputs the amount of offline material required, which allows to compute the preprocessing time for a particulor computation. #### Requirements - - GCC 5 or later (tested with 7.3) or LLVM (tested with 6.0) + - GCC 5 or later (tested with 8.2) or LLVM/clang 5 or later (tested with 7) - MPIR library, compiled with C++ support (use flag --enable-cxx when running configure) - libsodium library, tested against 1.0.16 - OpenSSL, tested against and 1.0.2 and 1.1.0 @@ -222,20 +223,31 @@ All current full implementations requires oblivious transfer, which is implemented as OT extension based on https://github.com/mkskeller/SimpleOT. -### Arithmetic circuits (MASCOT / SPDZ2k) +### Arithmetic circuits -The two protocols are implemented in `Player-Online.x` and -`spdz2k-party.x`, -respectively. [MASCOT](https://eprint.iacr.org/2016/505) works modulo -a prime, while a [SPDZ2k](https://eprint.iacr.org/2018/482) works -modulo 2^k. We will use MASCOT to demonstrate the use, -but SPDZ2k works similarly. +The following table shows all programs for arithmetic dishonest-majority computation: + +| Program | Protocol | Domain | Malicious | Script | +| --- | --- | --- | --- | --- | +| `Player-Online.x` | [MASCOT](https://eprint.iacr.org/2016/505) | Mod prime | Y | `mascot.sh` | +| `spdz2k-party.x` | [SPDZ2k](https://eprint.iacr.org/2018/482) | Mod 2^k | Y | `spdk2k.sh` | +| `semi-party.x` | OT-based | Mod prime | N | `semi.sh` | +| `semi2k-party.x` | OT-based | Mod 2^k | N | `semi2k.sh` | + +Semi and Semi2k denote the result of stripping MASCOT/SPDZ2k of all +steps required for malicious security, namely amplifying, sacrificing, +MAC generation, and OT correlation checks. What remains is the +generation of additively shared Beaver triples using OT. + +We will use MASCOT to demonstrate the use, but the other protocols +work similarly. First compile the virtual machine: `make -j8 Player-Online.x` -and a high-level program, for example the tutorial (use `-R 64` for SPDZ2k): +and a high-level program, for example the tutorial (use `-R 64` for +SPDZ2k and Semi2k): `./compile.py -F 64 tutorial` @@ -251,7 +263,7 @@ party. Omitting `-I` leads to inputs being read from `Player-Data/Input-P-0` in text format. Or, you can use a script to do run two parties in non-interactive mode -automatically (the script for SPDZ2k is `Scripts/spdz2k.sh`): +automatically: `Scripts/mascot.sh tutorial` @@ -286,9 +298,9 @@ When running locally, you can omit the host argument. As above, `-I` activates interactive input, otherwise inputs are read from `Player-Data/Input-P-0`. -By default, the circuit is garbled at once and stored on the evaluator -side before evaluating. You can activate a more continuous operation -by adding `-C` to the command line on both sides. +By default, the circuit is garbled in chunks that are evaluated +whenever received.You can activate garbling all at once by adding +`-O` to the command line on both sides. ## Honest majority @@ -297,6 +309,7 @@ The following table shows all programs for honest-majority computation: | Program | Sharing | Domain | Malicious | \# parties | Script | | --- | --- | --- | --- | --- | --- | | `replicated-ring-party.x` | Replicated | Mod 2^k | N | 3 | `ring.sh` | +| `brain-party.x` | Replicated | Mod 2^k | Y | 3 | `brain.sh` | | `replicated-bin-party.x` | Replicated | Binary | N | 3 | `replicated.sh` | | `replicated-field-party.x` | Replicated | Mod prime | N | 3 | `rep-field.sh` | | `malicious-rep-field-party.x` | Replicated | Mod prime | Y | 3 | `mal-rep-field.sh` | @@ -306,7 +319,10 @@ The following table shows all programs for honest-majority computation: We use the "generate random triple optimistically/sacrifice/Beaver" methodology described by [Lindell and Nof](https://eprint.iacr.org/2017/816) to achieve malicious -security. Otherwise, we use resharing by [Cramer et +security. The implementation in `brain-party.x` is inspired by +[Eerikson et al.](https://eprint.iacr.org/2019/164) but does not use fast Fourier transform for batch +verification. +Otherwise, we use resharing by [Cramer et al.](https://eprint.iacr.org/2000/037) for Shamir's secret sharing and the optimized approach by [Araki et al.](https://eprint.iacr.org/2016/768) for replicated secret sharing. @@ -365,6 +381,51 @@ the number of parties with `-N` and the maximum number of corrupted parties with `-T`. The latter can be at most half the number of parties. +### BMR + +BMR (Bellare-Micali-Rogaway) is a method of generating a garbled circuit +using another secure computation protocol. We have implemented BMR +based all available implementations using GF(2^128) because the nature +of this field particularly suits the Free-XOR optimization for garbled +circuits. Our implementation is based on [SPDZ-BMR-ORAM +construction](https://eprint.iacr.org/2017/981). The following table +lists the available schemes. + +| Program | Protocol | Dishonest Maj. | Malicious | \# parties | Script | +| --- | --- | --- | --- | --- | --- | +| `real-bmr-party.x` | MASCOT | Y | Y | 2 or more | `real-bmr.sh` | +| `shamir-bmr-party.x` | Shamir | N | N | 3 or more | `shamir-bmr.sh` | +| `mal-shamir-bmr-party.x` | Shamir | N | Y | 3 or more | `mal-shamir-bmr.sh` | +| `rep-bmr-party.x` | Replicated | N | N | 3 | `rep-bmr.sh` | +| `mal-rep-bmr-party.x` | Replicated | N | Y | 3 | `mal-rep-bmr.sh` | + +In the following, we will walk through running the tutorial with BMR +based on MASCOT and two parties. The other programs work similarly. + +First, compile the virtual machine. In order to run with more than +three parties, change the definition of `MAX_N_PARTIES` in +`BMR/config.h` accordingly. + +`make -j 8 real-bmr-party.x` + +In order to compile a high-level program, use `./compile.py -B`: + +`./compile.py -B 32 tutorial` + +Finally, run the two parties as follows: + +`./real-bmr-party.x -I 0 tutorial` + +`./real-bmr-party.x -I 1 tutorial` (in a separate terminal) + +or + +`Scripts/real-bmr.sh tutorial` + +The `-I` enable interactive inputs, and in the tutorial party 0 and 1 +will be asked to provide three numbers. Otherwise, and when using the +script, the inputs are read from `Player-Data/Input-P-0`. + ## Online-only benchmarking In this section we show how to benchmark purely the data-dependent diff --git a/Scripts/brain.sh b/Scripts/brain.sh new file mode 100755 index 000000000..50d19c32d --- /dev/null +++ b/Scripts/brain.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=3 + +. $HERE/run-common.sh + +run_player brain-party.x $* || exit 1 diff --git a/Scripts/build.sh b/Scripts/build.sh index 661c12155..88909f464 100755 --- a/Scripts/build.sh +++ b/Scripts/build.sh @@ -4,10 +4,11 @@ function build { echo ARCH = $1 >> CONFIG.mine echo GDEBUG = >> CONFIG.mine + echo MOD = -DMAX_MOD_SZ=4 >> CONFIG.mine make clean rm -R static mkdir static - make -j 12 static-release + make -j 4 static-release mkdir bin dest=bin/`uname`-$2 rm -R $dest @@ -15,6 +16,5 @@ function build strip $dest/* } -build '' amd64 -build '-msse4.1 -maes -mpclmul' aes -build '-msse4.1 -maes -mpclmul -mavx -mavx2 -mbmi2' avx2 +build '-maes -mpclmul -DCHECK_AES -DCHECK_PCLMUL -DCHECK_AVX' amd64 +build '-msse4.1 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx -DCHECK_ADX' avx2 diff --git a/Scripts/fake-spdz-real-bmr.sh b/Scripts/fake-spdz-real-bmr.sh new file mode 100755 index 000000000..bb3a21fed --- /dev/null +++ b/Scripts/fake-spdz-real-bmr.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player real-bmr-party.x $* -F || exit 1 diff --git a/Scripts/mal-rep-bmr.sh b/Scripts/mal-rep-bmr.sh new file mode 100755 index 000000000..e36f22776 --- /dev/null +++ b/Scripts/mal-rep-bmr.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=${PLAYERS:-3} + +. $HERE/run-common.sh + +run_player mal-rep-bmr-party.x $* || exit 1 diff --git a/Scripts/mal-shamir-bmr.sh b/Scripts/mal-shamir-bmr.sh new file mode 100755 index 000000000..4f684a7b3 --- /dev/null +++ b/Scripts/mal-shamir-bmr.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=${PLAYERS:-3} + +. $HERE/run-common.sh + +run_player mal-shamir-bmr-party.x $* || exit 1 diff --git a/Scripts/real-bmr.sh b/Scripts/real-bmr.sh new file mode 100755 index 000000000..85790e379 --- /dev/null +++ b/Scripts/real-bmr.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player real-bmr-party.x $* || exit 1 diff --git a/Scripts/rep-bmr.sh b/Scripts/rep-bmr.sh new file mode 100755 index 000000000..f0c4a0ff2 --- /dev/null +++ b/Scripts/rep-bmr.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=${PLAYERS:-3} + +. $HERE/run-common.sh + +run_player rep-bmr-party.x $* || exit 1 diff --git a/Scripts/ring.sh b/Scripts/ring.sh index 61c520675..0548f0395 100755 --- a/Scripts/ring.sh +++ b/Scripts/ring.sh @@ -7,4 +7,4 @@ export PLAYERS=3 . $HERE/run-common.sh -run_player replicated-ring-party.x ${1:-test_all} || exit 1 +run_player replicated-ring-party.x $* || exit 1 diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index 4b6b50a9f..ccc26d78d 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -18,7 +18,7 @@ run_player() { fi if [[ $bin = Player-Online.x || $bin =~ 'party.x' ]]; then params="$* -pn $port -h localhost" - if [[ ! $bin =~ 'rep' ]]; then + if [[ ! ($bin =~ 'rep' || $bin =~ 'brain') ]]; then params="$params -N $players" fi else @@ -33,7 +33,6 @@ run_player() { fi rem=$(($players - 2)) for i in $(seq 0 $rem); do - echo "trying with player $i" >&2 echo Running $prefix $SPDZROOT/$bin $i $params log=$SPDZROOT/logs/$i $prefix $SPDZROOT/$bin $i $params 2>&1 | diff --git a/Scripts/semi.sh b/Scripts/semi.sh new file mode 100755 index 000000000..dd71bbf3c --- /dev/null +++ b/Scripts/semi.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +bits=${2:-128} +g=${3:-0} +mem=${4:-empty} + +. $HERE/run-common.sh + +run_player semi-party.x ${1:-test_all} -lgp ${bits} -lg2 ${g} -m ${mem} || exit 1 diff --git a/Scripts/semi2k.sh b/Scripts/semi2k.sh new file mode 100755 index 000000000..3fdd4f4a8 --- /dev/null +++ b/Scripts/semi2k.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player semi2k-party.x $* || exit 1 diff --git a/Scripts/shamir-bmr.sh b/Scripts/shamir-bmr.sh new file mode 100755 index 000000000..c92d026c4 --- /dev/null +++ b/Scripts/shamir-bmr.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=${PLAYERS:-3} + +. $HERE/run-common.sh + +run_player shamir-bmr-party.x $* || exit 1 diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh new file mode 100755 index 000000000..72a895e8e --- /dev/null +++ b/Scripts/test_tutorial.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +for i in 0 1; do + seq 3 > Player-Data/Input-P$i-0 +done + +function test +{ + Scripts/$1.sh tutorial | grep 'expected -0.2, got -0.2' || exit 1 +} + +./compile.py tutorial + +for i in rep-field mal-rep-field shamir mal-shamir semi mascot; do + test $i +done + +./compile.py -R 64 tutorial + +for i in ring brain semi2k spdz2k; do + test $i +done + +./compile.py -B 16 tutorial + +for i in replicated yao rep-bmr mal-rep-bmr shamir-bmr mal-shamir-bmr; do + test $i +done diff --git a/Scripts/tldr.sh b/Scripts/tldr.sh index 9db68ad52..9dc2b27f9 100755 --- a/Scripts/tldr.sh +++ b/Scripts/tldr.sh @@ -23,8 +23,6 @@ fi if test "$flags"; then if $flags | grep -q avx2; then cpu=avx2 - elif $flags | grep -q aes; then - cpu=aes else cpu=amd64 fi diff --git a/Scripts/yao.sh b/Scripts/yao.sh index 84e8ce0ca..e77fafe7f 100755 --- a/Scripts/yao.sh +++ b/Scripts/yao.sh @@ -6,7 +6,7 @@ for i in 0 1; do IFS="" log="yao-$*-$i" IFS=" " - $prefix ./yao-player.x -p $i $* | tee -a logs/$log & true + $prefix ./yao-player.x -p $i $* 2>&1 | tee -a logs/$log & true done wait || exit 1 diff --git a/Tools/FlexBuffer.h b/Tools/FlexBuffer.h index 81e34db82..932e15e0a 100644 --- a/Tools/FlexBuffer.h +++ b/Tools/FlexBuffer.h @@ -93,6 +93,7 @@ class ReceivedMsgStore ReceivedMsgStore() : start(0), mem_size(0), total_size(0) {} ~ReceivedMsgStore(); void push(ReceivedMsg& msg); + void push_and_clear(LocalBuffer& msg) { push(msg); msg.clear(); } bool pop(ReceivedMsg& msg); bool empty() { return mem_size == 0 and files.empty(); } }; diff --git a/Tools/MMO.cpp b/Tools/MMO.cpp index e8ca280f5..86737d13c 100644 --- a/Tools/MMO.cpp +++ b/Tools/MMO.cpp @@ -146,3 +146,5 @@ void MMO::hashBlockWise(octet* output, octet* input) template void MMO::hashBlocks(void*, const void*); #define Z(F) ZZ(F,1) ZZ(F,2) ZZ(F,8) Z(gf2n_long) Z(Z2<64>) Z(Z2<112>) Z(Z2<128>) Z(Z2<160>) Z(Z2<114>) Z(Z2<130>) +Z(Z2<72>) +Z(SignedZ2<64>) Z(SignedZ2<72>) diff --git a/Tools/NetworkOptions.cpp b/Tools/NetworkOptions.cpp new file mode 100644 index 000000000..4b3c507ca --- /dev/null +++ b/Tools/NetworkOptions.cpp @@ -0,0 +1,33 @@ +/* + * NetworkOptions.cpp + * + */ + +#include "NetworkOptions.h" + +NetworkOptions::NetworkOptions(ez::ezOptionParser& opt, int argc, + const char** argv) +{ + opt.add( + "localhost", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Host where party 0 is running (default: localhost)", // Help description. + "-h", // Flag token. + "--hostname" // Flag token. + ); + opt.add( + "5000", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Base port number (default: 5000).", // Help description. + "-pn", // Flag token. + "--portnum" // Flag token. + ); + opt.parse(argc, argv); + opt.get("-pn")->getInt(portnum_base); + opt.get("-h")->getString(hostname); + opt.resetArgs(); +} diff --git a/Tools/NetworkOptions.h b/Tools/NetworkOptions.h new file mode 100644 index 000000000..d3c11a5a8 --- /dev/null +++ b/Tools/NetworkOptions.h @@ -0,0 +1,22 @@ +/* + * NetworkOptions.h + * + */ + +#ifndef TOOLS_NETWORKOPTIONS_H_ +#define TOOLS_NETWORKOPTIONS_H_ + +#include "ezOptionParser.h" + +#include + +class NetworkOptions +{ +public: + int portnum_base; + std::string hostname; + + NetworkOptions(ez::ezOptionParser& opt, int argc, const char** argv); +}; + +#endif /* TOOLS_NETWORKOPTIONS_H_ */ diff --git a/Tools/aes-ni.cpp b/Tools/aes-ni.cpp index ec2d4ece4..49d4c13a1 100644 --- a/Tools/aes-ni.cpp +++ b/Tools/aes-ni.cpp @@ -6,17 +6,6 @@ * M-Code Version * **********************/ -#define cpuid(func,ax,bx,cx,dx)\ - __asm__ __volatile__ ("cpuid":\ - "=a" (ax), "=b" (bx), "=c" (cx), "=d" (dx) : "a" (func)); - - -int Check_CPU_support_AES() -{ unsigned int a,b,c,d; - cpuid(1, a,b,c,d); - return (c & 0x2000000); -} - inline __m128i AES_128_ASSIST (__m128i temp1, __m128i temp2) { __m128i temp3; temp2 = _mm_shuffle_epi32 (temp2 ,0xff); temp3 = _mm_slli_si128 (temp1, 0x4); @@ -33,43 +22,46 @@ inline __m128i AES_128_ASSIST (__m128i temp1, __m128i temp2) void aes_128_schedule( octet* key, const octet* userkey ) { #ifdef __AES__ - __m128i temp1, temp2; - __m128i *Key_Schedule = (__m128i*)key; - temp1 = _mm_loadu_si128((__m128i*)userkey); - Key_Schedule[0] = temp1; - temp2 = _mm_aeskeygenassist_si128 (temp1 ,0x1); - temp1 = AES_128_ASSIST(temp1, temp2); - Key_Schedule[1] = temp1; - temp2 = _mm_aeskeygenassist_si128 (temp1,0x2); - temp1 = AES_128_ASSIST(temp1, temp2); - Key_Schedule[2] = temp1; - temp2 = _mm_aeskeygenassist_si128 (temp1,0x4); - temp1 = AES_128_ASSIST(temp1, temp2); - Key_Schedule[3] = temp1; - temp2 = _mm_aeskeygenassist_si128 (temp1,0x8); - temp1 = AES_128_ASSIST(temp1, temp2); - Key_Schedule[4] = temp1; - temp2 = _mm_aeskeygenassist_si128 (temp1,0x10); - temp1 = AES_128_ASSIST(temp1, temp2); - Key_Schedule[5] = temp1; - temp2 = _mm_aeskeygenassist_si128 (temp1,0x20); - temp1 = AES_128_ASSIST(temp1, temp2); - Key_Schedule[6] = temp1; - temp2 = _mm_aeskeygenassist_si128 (temp1,0x40); - temp1 = AES_128_ASSIST(temp1, temp2); - Key_Schedule[7] = temp1; - temp2 = _mm_aeskeygenassist_si128 (temp1,0x80); - temp1 = AES_128_ASSIST(temp1, temp2); - Key_Schedule[8] = temp1; - temp2 = _mm_aeskeygenassist_si128 (temp1,0x1b); - temp1 = AES_128_ASSIST(temp1, temp2); - Key_Schedule[9] = temp1; - temp2 = _mm_aeskeygenassist_si128 (temp1,0x36); - temp1 = AES_128_ASSIST(temp1, temp2); - Key_Schedule[10] = temp1; -#else - aes_128_schedule((uint*) key, userkey); + if (cpu_has_aes()) + { + __m128i temp1, temp2; + __m128i *Key_Schedule = (__m128i*)key; + temp1 = _mm_loadu_si128((__m128i*)userkey); + Key_Schedule[0] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1 ,0x1); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[1] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x2); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[2] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x4); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[3] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x8); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[4] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x10); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[5] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x20); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[6] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x40); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[7] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x80); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[8] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x1b); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[9] = temp1; + temp2 = _mm_aeskeygenassist_si128 (temp1,0x36); + temp1 = AES_128_ASSIST(temp1, temp2); + Key_Schedule[10] = temp1; + } + else #endif + aes_128_schedule((uint*) key, userkey); } #ifdef __AES__ diff --git a/Tools/aes.h b/Tools/aes.h index 1e2049c46..fec9bc949 100644 --- a/Tools/aes.h +++ b/Tools/aes.h @@ -4,6 +4,7 @@ #include #include "Networking/data.h" +#include "cpu_support.h" typedef unsigned int uint; @@ -33,7 +34,7 @@ inline void aes_encrypt( octet* C, octet* M, uint* RK ) /*********** M-Code Version ***********/ // Check can support this -int Check_CPU_support_AES(); +inline int Check_CPU_support_AES() { return cpu_has_aes(); } // Key Schedule void aes_128_schedule( octet* key, const octet* userkey ); void aes_192_schedule( octet* key, const octet* userkey ); @@ -52,17 +53,32 @@ void aes_256_encrypt( octet* C, const octet* M,const octet* RK ); __attribute__((optimize("unroll-loops"))) #endif inline __m128i aes_128_encrypt(__m128i in, const octet* key) -{ __m128i& tmp = in; - tmp = _mm_xor_si128 (tmp,((__m128i*)key)[0]); +{ #ifdef __AES__ - int j; - for(j=1; j <10; j++) - { tmp = _mm_aesenc_si128 (tmp,((__m128i*)key)[j]); } - tmp = _mm_aesenclast_si128 (tmp,((__m128i*)key)[j]); -#else - throw runtime_error("need to compile with AES-NI support"); + if (cpu_has_aes()) + { + __m128i& tmp = in; + tmp = _mm_xor_si128 (tmp,((__m128i*)key)[0]); + int j; + for(j=1; j <10; j++) + tmp = _mm_aesenc_si128 (tmp,((__m128i*)key)[j]); + tmp = _mm_aesenclast_si128 (tmp,((__m128i*)key)[j]); + return tmp; + } + else #endif - return tmp; + { + __m128i tmp; + aes_128_encrypt((octet*) &tmp, (octet*) &in, (uint*) key); + return tmp; + } +} + +template +inline void software_ecb_aes_128_encrypt(__m128i* out, __m128i* in, uint* key) +{ + for (int i = 0; i < N; i++) + aes_128_encrypt((octet*)&out[i], (octet*)&in[i], key); } template @@ -72,19 +88,21 @@ __attribute__((optimize("unroll-loops"))) inline void ecb_aes_128_encrypt(__m128i* out, __m128i* in, const octet* key) { #ifdef __AES__ - __m128i tmp[N]; - for (int i = 0; i < N; i++) - tmp[i] = _mm_xor_si128 (in[i],((__m128i*)key)[0]); - int j; - for(j=1; j <10; j++) + if (cpu_has_aes()) + { + __m128i tmp[N]; for (int i = 0; i < N; i++) - tmp[i] = _mm_aesenc_si128 (tmp[i],((__m128i*)key)[j]); - for (int i = 0; i < N; i++) - out[i] = _mm_aesenclast_si128 (tmp[i],((__m128i*)key)[j]); -#else - for (int i = 0; i < N; i++) - aes_128_encrypt((octet*)&out[i], (octet*)&in[i], (uint*)key); + tmp[i] = _mm_xor_si128 (in[i],((__m128i*)key)[0]); + int j; + for(j=1; j <10; j++) + for (int i = 0; i < N; i++) + tmp[i] = _mm_aesenc_si128 (tmp[i],((__m128i*)key)[j]); + for (int i = 0; i < N; i++) + out[i] = _mm_aesenclast_si128 (tmp[i],((__m128i*)key)[j]); + } + else #endif + software_ecb_aes_128_encrypt(out, in, (uint*) key); } template diff --git a/Tools/avx_memcpy.h b/Tools/avx_memcpy.h index 5e3281d30..fa8cd6d6a 100644 --- a/Tools/avx_memcpy.h +++ b/Tools/avx_memcpy.h @@ -60,8 +60,14 @@ inline void avx_memzero(void* dest, size_t length) length -= 32; } #endif - if (length) + switch (length) + { + case 8: + *(int64_t*)d = 0; + return; + default: memset((void*)d, 0, length); + } } #endif /* TOOLS_AVX_MEMCPY_H_ */ diff --git a/Tools/cpu_support.h b/Tools/cpu_support.h new file mode 100644 index 000000000..405bb1f25 --- /dev/null +++ b/Tools/cpu_support.h @@ -0,0 +1,71 @@ +/* + * cpu_support.h + * + */ + +#ifndef TOOLS_CPU_SUPPORT_H_ +#define TOOLS_CPU_SUPPORT_H_ + +inline bool check_cpu(int func, bool ecx, int feature) +{ + int ax = func, bx, cx = 0, dx; + __asm__ __volatile__ ("cpuid": + "+a" (ax), "=b" (bx), "+c" (cx), "=d" (dx)); + return ((ecx ? cx : bx) >> feature) & 1; +} + +inline bool cpu_has_adx() +{ +#ifdef CHECK_ADX + return check_cpu(7, false, 19); +#else + return true; +#endif +} + +inline bool cpu_has_bmi2() +{ +#ifdef CHECK_BMI2 + return check_cpu(7, false, 8); +#else + return true; +#endif +} + +inline bool cpu_has_avx2() +{ +#ifdef CHECK_AVX2 + return check_cpu(7, false, 5); +#else + return true; +#endif +} + +inline bool cpu_has_avx() +{ +#ifdef CHECK_AVX + return check_cpu(1, true, 28); +#else + return true; +#endif +} + +inline bool cpu_has_pclmul() +{ +#ifdef CHECK_PCLMUL + return check_cpu(1, true, 1); +#else + return true; +#endif +} + +inline bool cpu_has_aes() +{ +#ifdef CHECK_AES + return check_cpu(1, true, 25); +#else + return true; +#endif +} + +#endif /* TOOLS_CPU_SUPPORT_H_ */ diff --git a/Tools/random.cpp b/Tools/random.cpp index 48bdd2d25..68e3046ae 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -81,7 +81,7 @@ void PRNG::print_state() const cout << hex << (int) random[i]; } cout << "\t"; - for (i=0; i((__m128i*)random,(__m128i*)state,KeyScheduleC); } else { ecb_aes_128_encrypt((__m128i*)random,(__m128i*)state,KeySchedule); } #endif diff --git a/Yao/Machine.cpp b/Yao/Machine.cpp deleted file mode 100644 index 880521d17..000000000 --- a/Yao/Machine.cpp +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Secret.cpp - * - */ - -#include "YaoGarbleWire.h" -#include "YaoEvalWire.h" - -#include "GC/Machine.hpp" -#include "GC/Processor.hpp" -#include "GC/Secret.hpp" -#include "GC/Thread.hpp" -#include "GC/ThreadMaster.hpp" - -namespace GC -{ - -template class Secret; -template class Secret; - -template void Secret::reveal(Clear& x); -template void Secret::reveal(Clear& x); - -template class Machine< Secret >; -template class Machine< Secret >; - -template class Processor< Secret >; -template class Processor< Secret >; - -template class Thread< Secret >; -template class Thread< Secret >; - -template class ThreadMaster< Secret >; -template class ThreadMaster< Secret >; - -} diff --git a/Yao/Program.cpp b/Yao/Program.cpp deleted file mode 100644 index 9302ac311..000000000 --- a/Yao/Program.cpp +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Program.cpp - * - */ - -#include "YaoEvalWire.h" -#include "YaoGarbleWire.h" - -#include "GC/Instruction.hpp" -#include "GC/Program.hpp" - -#include "Processor/Instruction.hpp" - -namespace GC -{ - -template class Instruction< Secret >; -template class Instruction< Secret >; - -template class Program< Secret >; -template class Program< Secret >; - -} diff --git a/Yao/YaoCommon.h b/Yao/YaoCommon.h index 95fbcb28f..714e193e1 100644 --- a/Yao/YaoCommon.h +++ b/Yao/YaoCommon.h @@ -8,7 +8,10 @@ #include -class YaoCommon +#include "Exceptions/Exceptions.h" +#include "GC/RuntimeBranching.h" + +class YaoCommon : public GC::RuntimeBranching { int log_n_threads; diff --git a/Yao/YaoEvalMaster.cpp b/Yao/YaoEvalMaster.cpp index 2c50426f1..575b38001 100644 --- a/Yao/YaoEvalMaster.cpp +++ b/Yao/YaoEvalMaster.cpp @@ -6,12 +6,21 @@ #include "YaoEvalMaster.h" #include "YaoEvaluator.h" +#include "GC/Instruction.hpp" +#include "GC/Machine.hpp" +#include "GC/Program.hpp" +#include "GC/Processor.hpp" +#include "GC/Secret.hpp" +#include "GC/Thread.hpp" +#include "GC/ThreadMaster.hpp" +#include "Processor/Instruction.hpp" + YaoEvalMaster::YaoEvalMaster(bool continuous, OnlineOptions& opts) : ThreadMaster>(opts), continuous(continuous) { } -Thread>* YaoEvalMaster::new_thread(int i) +GC::Thread>* YaoEvalMaster::new_thread(int i) { return new YaoEvaluator(i, *this); } diff --git a/Yao/YaoEvalMaster.h b/Yao/YaoEvalMaster.h index 69d8fa195..69128c132 100644 --- a/Yao/YaoEvalMaster.h +++ b/Yao/YaoEvalMaster.h @@ -10,8 +10,6 @@ #include "GC/Secret.h" #include "YaoEvalWire.h" -using namespace GC; - class YaoEvalMaster : public GC::ThreadMaster> { public: @@ -19,7 +17,7 @@ class YaoEvalMaster : public GC::ThreadMaster> YaoEvalMaster(bool continuous, OnlineOptions& opts); - Thread>* new_thread(int i); + GC::Thread>* new_thread(int i); }; #endif /* YAO_YAOEVALMASTER_H_ */ diff --git a/Yao/YaoEvalWire.cpp b/Yao/YaoEvalWire.cpp index 17820426e..b134a0db4 100644 --- a/Yao/YaoEvalWire.cpp +++ b/Yao/YaoEvalWire.cpp @@ -11,6 +11,10 @@ #include "BMR/common.h" #include "GC/ArgTuples.h" +#include "GC/Processor.hpp" +#include "GC/Secret.hpp" +#include "GC/Thread.hpp" + ostream& YaoEvalWire::out = cout; void YaoEvalWire::random() @@ -165,6 +169,7 @@ void YaoEvalWire::XOR(const YaoEvalWire& left, const YaoEvalWire& right) bool YaoEvalWire::get_output() { + YaoEvaluator::s().taint(); bool res = external ^ YaoEvaluator::s().output_masks.pop_front(); #ifdef DEBUG cout << "output " << res << " mask " << (external ^ res) << " external " @@ -185,6 +190,14 @@ void YaoEvalWire::set(Key key, bool external) set(key); } +void YaoEvalWire::convcbit(Integer& dest, const GC::Clear& source) +{ + auto& evaluator = YaoEvaluator::s(); + dest = source; + evaluator.P->send_long(0, source.get()); + evaluator.untaint(); +} + template void YaoEvalWire::and_( GC::Processor >& processor, const vector& args); diff --git a/Yao/YaoEvalWire.h b/Yao/YaoEvalWire.h index 388323a72..3f2c4dd61 100644 --- a/Yao/YaoEvalWire.h +++ b/Yao/YaoEvalWire.h @@ -43,6 +43,8 @@ class YaoEvalWire : public Phase static void inputb(GC::Processor>& processor, const vector& args); + static void convcbit(Integer& dest, const GC::Clear& source); + void set(const Key& key); void set(Key key, bool external); diff --git a/Yao/YaoEvaluator.cpp b/Yao/YaoEvaluator.cpp index 32490d9b7..5f5a82c0e 100644 --- a/Yao/YaoEvaluator.cpp +++ b/Yao/YaoEvaluator.cpp @@ -5,6 +5,14 @@ #include "YaoEvaluator.h" +#include "GC/Instruction.hpp" +#include "GC/Machine.hpp" +#include "GC/Program.hpp" +#include "GC/Processor.hpp" +#include "GC/Secret.hpp" +#include "GC/Thread.hpp" +#include "GC/ThreadMaster.hpp" + thread_local YaoEvaluator* YaoEvaluator::singleton = 0; YaoEvaluator::YaoEvaluator(int thread_num, YaoEvalMaster& master) : @@ -36,9 +44,19 @@ void YaoEvaluator::run(GC::Program>& program) void YaoEvaluator::run(GC::Program>& program, Player& P) { + auto next = GC::TIME_BREAK; do + { receive(P); - while(GC::DONE_BREAK != program.execute(processor, -1)); + try + { + next = program.execute(processor, master.memory, -1); + } + catch (needs_cleaning& e) + { + } + } + while(GC::DONE_BREAK != next); } void YaoEvaluator::run_from_store(GC::Program>& program) @@ -49,7 +67,7 @@ void YaoEvaluator::run_from_store(GC::Program>& program) gates_store.pop(gates); output_masks_store.pop(output_masks); } - while(GC::DONE_BREAK != program.execute(processor, -1)); + while(GC::DONE_BREAK != program.execute(processor, master.memory, -1)); } bool YaoEvaluator::receive(Player& P) diff --git a/Yao/YaoEvaluator.h b/Yao/YaoEvaluator.h index 8c19be9bd..b542304cb 100644 --- a/Yao/YaoEvaluator.h +++ b/Yao/YaoEvaluator.h @@ -14,7 +14,7 @@ #include "Tools/MMO.h" #include "OT/OTExtensionWithMatrix.h" -class YaoEvaluator : public GC::Thread>, public YaoCommon +class YaoEvaluator : public GC::Thread>, public YaoCommon { protected: static thread_local YaoEvaluator* singleton; @@ -37,7 +37,7 @@ class YaoEvaluator : public GC::Thread>, public YaoCommon YaoEvaluator(int thread_num, YaoEvalMaster& master); - bool continuous() { return master.continuous and thread_num == 0; } + bool continuous() { return master.continuous and master.machine.nthreads == 1; } void pre_run(); void run(GC::Program>& program); diff --git a/Yao/YaoGarbleMaster.cpp b/Yao/YaoGarbleMaster.cpp index b2e0da618..593d6c466 100644 --- a/Yao/YaoGarbleMaster.cpp +++ b/Yao/YaoGarbleMaster.cpp @@ -6,6 +6,15 @@ #include "YaoGarbleMaster.h" #include "YaoGarbler.h" +#include "GC/Instruction.hpp" +#include "GC/Machine.hpp" +#include "GC/Program.hpp" +#include "GC/Processor.hpp" +#include "GC/Secret.hpp" +#include "GC/Thread.hpp" +#include "GC/ThreadMaster.hpp" +#include "Processor/Instruction.hpp" + YaoGarbleMaster::YaoGarbleMaster(bool continuous, OnlineOptions& opts, int threshold) : super(opts), continuous(continuous), threshold(threshold) { @@ -15,7 +24,7 @@ YaoGarbleMaster::YaoGarbleMaster(bool continuous, OnlineOptions& opts, int thres delta.set_signal(1); } -Thread>* YaoGarbleMaster::new_thread(int i) +GC::Thread>* YaoGarbleMaster::new_thread(int i) { return new YaoGarbler(i, *this); } diff --git a/Yao/YaoGarbleMaster.h b/Yao/YaoGarbleMaster.h index 5df7ad18c..40914013b 100644 --- a/Yao/YaoGarbleMaster.h +++ b/Yao/YaoGarbleMaster.h @@ -11,8 +11,6 @@ #include "YaoGarbleWire.h" #include "Processor/OnlineOptions.h" -using namespace GC; - class YaoGarbleMaster : public GC::ThreadMaster> { typedef GC::ThreadMaster> super; @@ -24,7 +22,7 @@ class YaoGarbleMaster : public GC::ThreadMaster> YaoGarbleMaster(bool continuous, OnlineOptions& opts, int threshold = 1024); - Thread>* new_thread(int i); + GC::Thread>* new_thread(int i); }; #endif /* YAO_YAOGARBLEMASTER_H_ */ diff --git a/Yao/YaoGarbleWire.cpp b/Yao/YaoGarbleWire.cpp index 03ed88299..35257b3de 100644 --- a/Yao/YaoGarbleWire.cpp +++ b/Yao/YaoGarbleWire.cpp @@ -8,6 +8,10 @@ #include "YaoGarbler.h" #include "GC/ArgTuples.h" +#include "GC/Processor.hpp" +#include "GC/Secret.hpp" +#include "GC/Thread.hpp" + void YaoGarbleWire::randomize(PRNG& prng) { key = prng.get_doubleword(); @@ -254,6 +258,15 @@ void YaoGarbleWire::XOR(const YaoGarbleWire& left, const YaoGarbleWire& right) char YaoGarbleWire::get_output() { + YaoGarbler::s().taint(); YaoGarbler::s().output_masks.push_back(mask); return -1; } + +void YaoGarbleWire::convcbit(Integer& dest, const GC::Clear& source) +{ + (void) source; + auto& garbler = YaoGarbler::s(); + garbler.untaint(); + dest = garbler.P->receive_long(1); +} diff --git a/Yao/YaoGarbleWire.h b/Yao/YaoGarbleWire.h index d582c4aaf..b576b115f 100644 --- a/Yao/YaoGarbleWire.h +++ b/Yao/YaoGarbleWire.h @@ -51,6 +51,8 @@ class YaoGarbleWire : public Phase static void inputb(GC::Processor>& processor, const vector& args); + static void convcbit(Integer& dest, const GC::Clear& source); + void randomize(PRNG& prng); void set(Key key, bool mask); diff --git a/Yao/YaoGarbler.cpp b/Yao/YaoGarbler.cpp index 090bf04f5..aa9344b04 100644 --- a/Yao/YaoGarbler.cpp +++ b/Yao/YaoGarbler.cpp @@ -6,10 +6,18 @@ #include "YaoGarbler.h" #include "YaoGate.h" +#include "GC/ThreadMaster.hpp" +#include "GC/Instruction.hpp" +#include "GC/Processor.hpp" +#include "GC/Program.hpp" +#include "GC/Machine.hpp" +#include "GC/Secret.hpp" +#include "GC/Thread.hpp" + thread_local YaoGarbler* YaoGarbler::singleton = 0; YaoGarbler::YaoGarbler(int thread_num, YaoGarbleMaster& master) : - Thread>(thread_num, master), + GC::Thread>(thread_num, master), master(master), and_proc_timer(CLOCK_PROCESS_CPUTIME_ID), and_main_thread_timer(CLOCK_THREAD_CPUTIME_ID), @@ -46,28 +54,31 @@ YaoGarbler::~YaoGarbler() void YaoGarbler::run(GC::Program>& program) { singleton = this; - bool continuous = master.continuous; - if (continuous and thread_num > 0) - { - cerr << "continuous running not available for more than one thread" << endl; - continuous = false; - } GC::BreakType b = GC::TIME_BREAK; while(GC::DONE_BREAK != b) { - b = program.execute(processor, -1); + try + { + b = program.execute(processor, master.memory, -1); + } + catch (needs_cleaning& e) + { + if (not continuous()) + throw runtime_error("run-time branching impossible with garbling at once"); + processor.PC--; + } send(*P); gates.clear(); output_masks.clear(); - if (continuous) + if (continuous()) process_receiver_inputs(); } } void YaoGarbler::post_run() { - if (not (master.continuous and thread_num == 0)) + if (not continuous()) { P->send_long(1, YaoCommon::DONE); process_receiver_inputs(); diff --git a/Yao/YaoGarbler.h b/Yao/YaoGarbler.h index 6c047515b..7a8c4b314 100644 --- a/Yao/YaoGarbler.h +++ b/Yao/YaoGarbler.h @@ -18,8 +18,6 @@ #include -using namespace GC; - class YaoGate; class YaoGarbler : public GC::Thread>, public YaoCommon @@ -57,6 +55,9 @@ class YaoGarbler : public GC::Thread>, public YaoCommo YaoGarbler(int thread_num, YaoGarbleMaster& master); ~YaoGarbler(); + + bool continuous() { return master.continuous and master.machine.nthreads == 1; } + void run(GC::Program>& program); void run(Player& P, bool continuous); void post_run(); diff --git a/Yao/YaoPlayer.cpp b/Yao/YaoPlayer.cpp index 30869d9ed..a2d32af45 100644 --- a/Yao/YaoPlayer.cpp +++ b/Yao/YaoPlayer.cpp @@ -8,6 +8,8 @@ #include "YaoEvaluator.h" #include "Tools/ezOptionParser.h" +#include "GC/Machine.hpp" + YaoPlayer::YaoPlayer(int argc, const char** argv) { ez::ezOptionParser opt; @@ -43,9 +45,9 @@ YaoPlayer::YaoPlayer(int argc, const char** argv) 0, // Required? 0, // Number of args expected. 0, // Delimiter if expecting multiple args. - "Evaluate while garbling (default: false).", // Help description. - "-C", // Flag token. - "--continuous" // Flag token. + "Evaluate only after garbling (default only with multi-threading).", // Help description. + "-O", // Flag token. + "--oneshot" // Flag token. ); opt.add( "1024", // Default. @@ -78,10 +80,10 @@ YaoPlayer::YaoPlayer(int argc, const char** argv) opt.get("-p")->getInt(my_num); opt.get("-pn")->getInt(pnb); opt.get("-h")->getString(hostname); - bool continuous = opt.get("-C")->isSet; + bool continuous = not opt.get("-O")->isSet; opt.get("-t")->getInt(threshold); - ThreadMasterBase* master; + GC::ThreadMasterBase* master; if (my_num == 0) master = new YaoGarbleMaster(continuous, online_opts, threshold); else @@ -89,6 +91,9 @@ YaoPlayer::YaoPlayer(int argc, const char** argv) server = Server::start_networking(master->N, my_num, 2, hostname, pnb); master->run(progname); + + if (my_num == 1) + ((YaoEvalMaster*)master)->machine.write_memory(0); } YaoPlayer::~YaoPlayer() diff --git a/bmr-program-party.cpp b/bmr-program-party.cpp index 2da1a531e..2ae6edc0c 100644 --- a/bmr-program-party.cpp +++ b/bmr-program-party.cpp @@ -5,8 +5,8 @@ #include "BMR/Party.h" -int main(int argc, char** argv) +int main(int argc, const char** argv) { - ProgramParty party(argc, argv); + FakeProgramParty party(argc, argv); party.Start(); } diff --git a/brain-party.cpp b/brain-party.cpp new file mode 100644 index 000000000..36f285c16 --- /dev/null +++ b/brain-party.cpp @@ -0,0 +1,33 @@ +/* + * brain-party.cpp + * + */ + +#include "Math/BrainShare.h" +#include "Math/MaliciousRep3Share.h" +#include "Processor/RingOptions.h" + +#include "Processor/ReplicatedMachine.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + RingOptions opts(opt, argc, argv); + switch (opts.R) + { + case 64: + // multiple of eight for quicker randomness generation + gfp2::init_default(DIV_CEIL(BrainShare<64, 40>::Z_BITS + 3, 8) * 8); + ReplicatedMachine, MaliciousRep3Share>(argc, + argv, "", opt); + break; + case 72: + // multiple of eight for quicker randomness generation + gfp2::init_default(DIV_CEIL(BrainShare<72, 40>::Z_BITS + 3, 8) * 8); + ReplicatedMachine, MaliciousRep3Share>(argc, + argv, "", opt); + break; + default: + throw runtime_error(to_string(opts.R) + "-bit computation not implemented"); + } +} diff --git a/mal-rep-bmr-party.cpp b/mal-rep-bmr-party.cpp new file mode 100644 index 000000000..21112ada4 --- /dev/null +++ b/mal-rep-bmr-party.cpp @@ -0,0 +1,13 @@ +/* + * mal-rep-shamir-party.cpp + * + */ + +#include "Machines/Rep.cpp" + +#include "BMR/RealProgramParty.hpp" + +int main(int argc, const char** argv) +{ + RealProgramParty>(argc, argv); +} diff --git a/mal-shamir-bmr-party.cpp b/mal-shamir-bmr-party.cpp new file mode 100644 index 000000000..b6e4a0efe --- /dev/null +++ b/mal-shamir-bmr-party.cpp @@ -0,0 +1,14 @@ +/* + * mal-shamir-bmr-party.cpp + * + */ + +#include "Machines/ShamirMachine.cpp" + +#include "BMR/RealProgramParty.hpp" + +int main(int argc, const char** argv) +{ + ShamirMachine machine(argc, argv); + RealProgramParty>(argc, argv); +} diff --git a/real-bmr-party.cpp b/real-bmr-party.cpp new file mode 100644 index 000000000..6872ef39c --- /dev/null +++ b/real-bmr-party.cpp @@ -0,0 +1,13 @@ +/* + * real-bmr-party.cpp + * + */ + +#include "Machines/SPDZ.cpp" + +#include "BMR/RealProgramParty.hpp" + +int main(int argc, const char** argv) +{ + RealProgramParty>(argc, argv); +} diff --git a/rep-bmr-party.cpp b/rep-bmr-party.cpp new file mode 100644 index 000000000..5192d786b --- /dev/null +++ b/rep-bmr-party.cpp @@ -0,0 +1,13 @@ +/* + * rep-bmr-party.cpp + * + */ + +#include "Machines/Rep.cpp" + +#include "BMR/RealProgramParty.hpp" + +int main(int argc, const char** argv) +{ + RealProgramParty>(argc, argv); +} diff --git a/replicated-ring-party.cpp b/replicated-ring-party.cpp index 8977a2c4e..278440be3 100644 --- a/replicated-ring-party.cpp +++ b/replicated-ring-party.cpp @@ -4,11 +4,24 @@ */ #include "Processor/ReplicatedMachine.hpp" +#include "Processor/RingOptions.h" #include "Math/Integer.h" int main(int argc, const char** argv) { ez::ezOptionParser opt; - ReplicatedMachine, Rep3Share>(argc, argv, - "replicated-ring", opt); + RingOptions opts(opt, argc, argv); + switch (opts.R) + { + case 64: + ReplicatedMachine>, Rep3Share>(argc, argv, + "replicated-ring", opt); + break; + case 72: + ReplicatedMachine>, Rep3Share>(argc, argv, + "replicated-ring", opt); + break; + default: + throw runtime_error(to_string(opts.R) + "-bit computation not implemented"); + } } diff --git a/semi-party.cpp b/semi-party.cpp new file mode 100644 index 000000000..749d1a3e2 --- /dev/null +++ b/semi-party.cpp @@ -0,0 +1,15 @@ +/* + * semi-party.cpp + * + */ + +#include "Math/gfp.h" +#include "Math/SemiShare.h" + +#include "Player-Online.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + spdz_main, SemiShare>(argc, argv, opt); +} diff --git a/semi2k-party.cpp b/semi2k-party.cpp new file mode 100644 index 000000000..0edbac358 --- /dev/null +++ b/semi2k-party.cpp @@ -0,0 +1,27 @@ +/* + * semi2k-party.cpp + * + */ + +#include "Math/Semi2kShare.h" +#include "Math/gf2n.h" +#include "Processor/RingOptions.h" + +#include "Player-Online.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + RingOptions opts(opt, argc, argv); + switch (opts.R) + { + case 64: + spdz_main, SemiShare>(argc, argv, opt); + break; + case 72: + spdz_main, SemiShare>(argc, argv, opt); + break; + default: + throw runtime_error(to_string(opts.R) + "-bit computation not implemented"); + } +} diff --git a/shamir-bmr-party.cpp b/shamir-bmr-party.cpp new file mode 100644 index 000000000..e7d38d672 --- /dev/null +++ b/shamir-bmr-party.cpp @@ -0,0 +1,14 @@ +/* + * shamir-bmr-party.cpp + * + */ + +#include "Machines/ShamirMachine.cpp" + +#include "BMR/RealProgramParty.hpp" + +int main(int argc, const char** argv) +{ + ShamirMachine machine(argc, argv); + RealProgramParty>(argc, argv); +} diff --git a/spdz2k-party.cpp b/spdz2k-party.cpp index 3655b4a73..f50063563 100644 --- a/spdz2k-party.cpp +++ b/spdz2k-party.cpp @@ -30,9 +30,9 @@ int main(int argc, const char** argv) cerr << "Using SPDZ2k with security parameter " << s << endl; #endif if (s == 64) - return spdz_main>(argc, argv, opt); + return spdz_main, Share>(argc, argv, opt); else if (s == 48) - return spdz_main>(argc, argv, opt); + return spdz_main, Share>(argc, argv, opt); else throw runtime_error("not compiled for s=" + to_string(s)); }