diff --git a/.gitignore b/.gitignore index 134c8c088..d7d3a322c 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,7 @@ tags .project .cproject .settings +.pydevproject # VS Code IDE # ############### diff --git a/.gitmodules b/.gitmodules index d21c2beb9..7c1438dd6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "SimpleOT"] path = SimpleOT - url = https://github.com/pascholl/SimpleOT + url = https://github.com/mkskeller/SimpleOT +[submodule "mpir"] + path = mpir + url = git://github.com/wbhart/mpir.git diff --git a/Auth/MAC_Check.cpp b/Auth/MAC_Check.cpp index 462422017..c12139ddb 100644 --- a/Auth/MAC_Check.cpp +++ b/Auth/MAC_Check.cpp @@ -10,6 +10,8 @@ #include "Math/gfp.h" #include "Math/gf2n.h" #include "Math/BitVec.h" +#include "Math/Rep3Share.h" +#include "Math/MaliciousRep3Share.h" #include @@ -73,12 +75,19 @@ void MAC_Check::POpen_End(vector& values,const vector >& S,const } template -void MAC_Check::POpen(vector& values,const vector >& S,const Player& P) +void MAC_Check_Base::POpen(vector& values,const vector& S,const Player& P) { POpen_Begin(values, S, P); POpen_End(values, S, P); } +template +typename T::clear MAC_Check_Base::POpen(const T& secret, const Player& P) +{ + vector opened; + POpen(opened, {secret}, P); + return opened[0]; +} template void MAC_Check::AddToMacs(const vector >& shares) @@ -446,5 +455,7 @@ template class Parallel_MAC_Check; template class Passing_MAC_Check; #endif -template class MAC_Check_Base; -template class MAC_Check_Base; +template class MAC_Check_Base>; +template class MAC_Check_Base>; +template class MAC_Check_Base>; +template class MAC_Check_Base>; diff --git a/Auth/MAC_Check.h b/Auth/MAC_Check.h index ca60a76f2..84de474ed 100644 --- a/Auth/MAC_Check.h +++ b/Auth/MAC_Check.h @@ -60,7 +60,7 @@ class MAC_Check_Base { protected: /* MAC Share */ - T alphai; + typename T::clear alphai; public: int values_opened; @@ -72,12 +72,17 @@ class MAC_Check_Base int number() const { return values_opened; } - const T& get_alphai() const { return alphai; } + const typename T::clear& get_alphai() const { return alphai; } + + virtual void POpen_Begin(vector& values,const vector& S,const Player& P) = 0; + virtual void POpen_End(vector& values,const vector& S,const Player& P) = 0; + void POpen(vector& values,const vector& S,const Player& P); + typename T::clear POpen(const T& secret, const Player& P); }; template -class MAC_Check : public TreeSum, public MAC_Check_Base +class MAC_Check : public TreeSum, public MAC_Check_Base> { protected: @@ -107,7 +112,6 @@ class MAC_Check : public TreeSum, public MAC_Check_Base */ virtual void POpen_Begin(vector& values,const vector >& S,const Player& P); virtual void POpen_End(vector& values,const vector >& S,const Player& P); - void POpen(vector& values,const vector >& S,const Player& P); void AddToCheck(const T& mac, const T& value, const Player& P); virtual void Check(const Player& P); diff --git a/Auth/MaliciousRepMC.h b/Auth/MaliciousRepMC.h new file mode 100644 index 000000000..bfd019c10 --- /dev/null +++ b/Auth/MaliciousRepMC.h @@ -0,0 +1,66 @@ +/* + * MaliciousRepMC.h + * + */ + +#ifndef AUTH_MALICIOUSREPMC_H_ +#define AUTH_MALICIOUSREPMC_H_ + +#include "ReplicatedMC.h" +#include "GC/MaliciousRepSecret.h" +#include "GC/Machine.h" + +template +class MaliciousRepMC : public ReplicatedMC +{ +protected: + typedef ReplicatedMC super; + +public: + virtual void POpen_Begin(vector& values, + const vector& S, const Player& P); + virtual void POpen_End(vector& values, + const vector& S, const Player& P); + + virtual void Check(const Player& P); +}; + +template +class HashMaliciousRepMC : public MaliciousRepMC +{ + crypto_generichash_state* hash_state; + + octetStream os; + +public: + // emulate MAC_Check + HashMaliciousRepMC(const typename T::value_type& _, int __ = 0, int ___ = 0) : HashMaliciousRepMC() + { (void)_; (void)__; (void)___; } + + // emulate Direct_MAC_Check + HashMaliciousRepMC(const typename T::value_type& _, Names& ____, int __ = 0, int ___ = 0) : HashMaliciousRepMC() + { (void)_; (void)__; (void)___; (void)____; } + + HashMaliciousRepMC(); + ~HashMaliciousRepMC(); + + void POpen_End(vector& values,const vector& S,const Player& P); + + void Check(const Player& P); +}; + +template +class CommMaliciousRepMC : public MaliciousRepMC +{ + vector os; + +public: + void POpen_Begin(vector& values, const vector& S, + const Player& P); + void POpen_End(vector& values, const vector& S, + const Player& P); + + void Check(const Player& P); +}; + +#endif /* AUTH_MALICIOUSREPMC_H_ */ diff --git a/Auth/MaliciousRepMC.hpp b/Auth/MaliciousRepMC.hpp new file mode 100644 index 000000000..81709ca68 --- /dev/null +++ b/Auth/MaliciousRepMC.hpp @@ -0,0 +1,107 @@ +/* + * MaliciousRepMC.cpp + * + */ + +#include "MaliciousRepMC.h" +#include "GC/Machine.h" + +#include "ReplicatedMC.hpp" + +#include + +template +void MaliciousRepMC::POpen_Begin(vector& values, + const vector& S, const Player& P) +{ + super::POpen_Begin(values, S, P); +} + +template +void MaliciousRepMC::POpen_End(vector& values, + const vector& S, const Player& P) +{ + (void)values, (void)S, (void)P; + throw runtime_error("use subclass"); +} + +template +void MaliciousRepMC::Check(const Player& P) +{ + (void)P; + throw runtime_error("use subclass"); +} + +template +HashMaliciousRepMC::HashMaliciousRepMC() +{ + // deal with alignment issues + int error = posix_memalign((void**)&hash_state, 64, sizeof(crypto_generichash_state)); + if (error) + throw runtime_error(string("failed to allocate hash state: ") + strerror(error)); + crypto_generichash_init(hash_state, 0, 0, crypto_generichash_BYTES); +} + +template +HashMaliciousRepMC::~HashMaliciousRepMC() +{ + free(hash_state); +} + +template +void HashMaliciousRepMC::POpen_End(vector& values, + const vector& S, const Player& P) +{ + ReplicatedMC::POpen_End(values, S, P); + os.reset_write_head(); + for (auto& value : values) + value.pack(os); + crypto_generichash_update(hash_state, os.get_data(), os.get_length()); +} + +template +void HashMaliciousRepMC::Check(const Player& P) +{ + unsigned char hash[crypto_generichash_BYTES]; + crypto_generichash_final(hash_state, hash, sizeof hash); + crypto_generichash_init(hash_state, 0, 0, crypto_generichash_BYTES); + vector os(P.num_players()); + os[P.my_num()].serialize(hash); + P.Broadcast_Receive(os); + for (int i = 0; i < P.num_players(); i++) + if (os[i] != os[P.my_num()]) + throw mac_fail(); +} + +template +void CommMaliciousRepMC::POpen_Begin(vector& values, + const vector& S, const Player& P) +{ + assert(T::length == 2); + (void)values; + os.resize(2); + for (auto& o : os) + o.reset_write_head(); + for (auto& x : S) + for (int i = 0; i < 2; i++) + x[i].pack(os[1 - i]); + P.send_relative(os); +} + +template +void CommMaliciousRepMC::POpen_End(vector& values, + const vector& S, const Player& P) +{ + P.receive_relative(os); + if (os[0] != os[1]) + throw mac_fail(); + values.clear(); + for (auto& x : S) + values.push_back(os[0].template get() + x.sum()); +} + +template +void CommMaliciousRepMC::Check(const Player& P) +{ + (void)P; +} diff --git a/Auth/ReplicatedMC.h b/Auth/ReplicatedMC.h index 3b006f3b1..7a693dccd 100644 --- a/Auth/ReplicatedMC.h +++ b/Auth/ReplicatedMC.h @@ -9,17 +9,17 @@ #include "MAC_Check.h" template -class ReplicatedMC : public MAC_Check_Base +class ReplicatedMC : public MAC_Check_Base { + octetStream o; + public: // emulate MAC_Check - ReplicatedMC(const gfp& _ = {}, int __ = 0, int ___ = 0) : - MAC_Check_Base({}) + ReplicatedMC(const typename T::value_type& _ = {}, int __ = 0, int ___ = 0) { (void)_; (void)__; (void)___; } // emulate Direct_MAC_Check - ReplicatedMC(const gfp& _, Names& ____, int __ = 0, int ___ = 0) : - MAC_Check_Base({}) + ReplicatedMC(const typename T::value_type& _, Names& ____, int __ = 0, int ___ = 0) { (void)_; (void)__; (void)___; (void)____; } void POpen_Begin(vector& values,const vector& S,const Player& P); diff --git a/Auth/ReplicatedMC.cpp b/Auth/ReplicatedMC.hpp similarity index 70% rename from Auth/ReplicatedMC.cpp rename to Auth/ReplicatedMC.hpp index f686dd612..261fc716b 100644 --- a/Auth/ReplicatedMC.cpp +++ b/Auth/ReplicatedMC.hpp @@ -4,8 +4,6 @@ */ #include "ReplicatedMC.h" -#include "GC/ReplicatedSecret.h" -#include "Math/Rep3Share.h" template void ReplicatedMC::POpen_Begin(vector& values, @@ -13,18 +11,17 @@ void ReplicatedMC::POpen_Begin(vector& values, { assert(T::length == 2); (void)values; - octetStream o; + o.reset_write_head(); for (auto& x : S) x[0].pack(o); - P.send_relative(-1, o); + P.pass_around(o, -1); } template void ReplicatedMC::POpen_End(vector& values, const vector& S, const Player& P) { - octetStream o; - P.receive_relative(1, o); + (void)P; values.resize(S.size()); for (size_t i = 0; i < S.size(); i++) { @@ -33,6 +30,3 @@ void ReplicatedMC::POpen_End(vector& values, values[i] = S[i].sum() + tmp; } } - -template class ReplicatedMC; -template class ReplicatedMC; diff --git a/Auth/Subroutines.h b/Auth/Subroutines.h index e1e4a6253..2830d3426 100644 --- a/Auth/Subroutines.h +++ b/Auth/Subroutines.h @@ -63,6 +63,14 @@ int Open_Challenge(vector& e,vector& Open_e, template void Create_Random(T& ans,const Player& P); +template +T Create_Random(const Player& P) +{ + T res; + Create_Random(res, P); + return res; +} + /* Produce a random seed of length len */ void Create_Random_Seed(octet* seed,const Player& P,int len); diff --git a/Auth/fake-stuff.h b/Auth/fake-stuff.h index 115099976..53bcc6b35 100644 --- a/Auth/fake-stuff.h +++ b/Auth/fake-stuff.h @@ -6,22 +6,19 @@ #include "Math/gfp.h" #include "Math/Share.h" #include "Math/Rep3Share.h" +#include "GC/MaliciousRepSecret.h" #include using namespace std; template void make_share(vector >& Sa,const T& a,int N,const T& key,PRNG& G); -void make_share(vector& Sa, const Integer& a, int N, - const Integer& key, PRNG& G); template void check_share(vector >& Sa,T& value,T& mac,int N,const T& key); -void check_share(vector& Sa, Integer& value, Integer& mac, int N, - const Integer& key); - -void expand_byte(gf2n_short& a,int b); -void collapse_byte(int& b,const gf2n_short& a); +template +void check_share(vector& Sa, typename T::clear& value, + typename T::value_type& mac, int N, const typename T::value_type& key); // Generate MAC key shares void generate_keys(const string& directory, int nplayers); @@ -38,9 +35,9 @@ class Files public: ofstream* outf; int N; - T key; + typename T::value_type key; PRNG G; - Files(int N, const T& key, const string& prefix) : N(N), key(key) + Files(int N, const typename T::value_type& key, const string& prefix) : N(N), key(key) { outf = new ofstream[N]; for (int i=0; i > Sa(N); + vector Sa(N); make_share(Sa,a,N,key,G); for (int j=0; j >& Sa,const T& a,int N,const T& key,PRNG& G) Sa[N-1]=S; } -void make_share(vector& Sa, - const Integer& a, int N, const Integer& key, +template +void make_share(FixedVec* Sa, const T& a, int N, PRNG& G); + +template +inline void make_share(vector& Sa, + const typename T::clear& a, int N, const typename T::value_type& key, PRNG& G) { (void)key; + Sa.resize(N); + make_share(Sa.data(), a, N, G); +} + +template +void make_share(FixedVec* Sa, const T& a, int N, PRNG& G) +{ assert(N == 3); insecure("share generation", false); - Sa.resize(N); - FixedVec add_shares; + FixedVec add_shares; // hack add_shares.randomize_to_sum(a, G); for (int i=0; i share; + FixedVec share; share[0] = add_shares[(i + 1) % 3]; share[1] = add_shares[i]; Sa[i] = share; @@ -73,9 +83,9 @@ void check_share(vector >& Sa,T& value,T& mac,int N,const T& key) } } -void check_share(vector& Sa, - Integer& value, Integer& mac, int N, - const Integer& key) +template +void check_share(vector& Sa, typename T::clear& value, + typename T::value_type& mac, int N, const typename T::value_type& key) { assert(N == 3); value = 0; @@ -86,69 +96,20 @@ void check_share(vector& Sa, { auto share = Sa[i]; value += share[0]; - if (share[1] != Sa[positive_modulo(i - 1, N)][0]) + auto a = share[1]; + auto b = Sa[positive_modulo(i - 1, N)][0]; + if (a != b) + { + cout << a << " != " << b << endl; + cout << hex << a.debug() << " != " << b.debug() << endl; + for (int i = 0; i < N; i++) + cout << Sa[i] << endl; throw bad_value("invalid replicated secret sharing"); + } } } -template void make_share(vector >& Sa,const gf2n& a,int N,const gf2n& key,PRNG& G); -template void make_share(vector >& Sa,const gfp& a,int N,const gfp& key,PRNG& G); - -template void check_share(vector >& Sa,gf2n& value,gf2n& mac,int N,const gf2n& key); -template void check_share(vector >& Sa,gfp& value,gfp& mac,int N,const gfp& key); - -#ifdef USE_GF2N_LONG -template void make_share(vector >& Sa,const gf2n_short& a,int N,const gf2n_short& key,PRNG& G); -template void check_share(vector >& Sa,gf2n_short& value,gf2n_short& mac,int N,const gf2n_short& key); -#endif - -// Expansion is by x=y^5+1 (as we embed GF(256) into GF(2^40) -void expand_byte(gf2n_short& a,int b) -{ - gf2n_short x,xp; - x.assign(32+1); - xp.assign_one(); - a.assign_zero(); - - while (b!=0) - { if ((b&1)==1) - { a.add(a,xp); } - xp.mul(x); - b>>=1; - } -} - - -// Have previously worked out the linear equations we need to solve -void collapse_byte(int& b,const gf2n_short& aa) -{ - word w=aa.get(); - int e35=(w>>35)&1; - int e30=(w>>30)&1; - int e25=(w>>25)&1; - int e20=(w>>20)&1; - int e15=(w>>15)&1; - int e10=(w>>10)&1; - int e5=(w>>5)&1; - int e0=w&1; - int a[8]; - a[7]=e35; - a[6]=e30^a[7]; - a[5]=e25^a[7]; - a[4]=e20^a[5]^a[6]^a[7]; - a[3]=e15^a[7]; - a[2]=e10^a[3]^a[6]^a[7]; - a[1]=e5^a[3]^a[5]^a[7]; - a[0]=e0^a[1]^a[2]^a[3]^a[4]^a[5]^a[6]^a[7]; - - b=0; - for (int i=7; i>=0; i--) - { b=b<<1; - b+=a[i]; - } -} - -void generate_keys(const string& directory, int nplayers) +inline void generate_keys(const string& directory, int nplayers) { PRNG G; G.ReSeed(); @@ -166,14 +127,19 @@ void generate_keys(const string& directory, int nplayers) } } +inline string mac_filename(string directory, int playerno) +{ + if (directory.empty()) + directory = "."; + return directory + "/Player-MAC-Keys-P" + to_string(playerno); +} + template void write_mac_keys(const string& directory, int i, int nplayers, gfp macp, T mac2) { ofstream outf; stringstream filename; - if (directory.size()) - filename << directory << "/"; - filename << "Player-MAC-Keys-P" << i; + filename << mac_filename(directory, i); cout << "Writing to " << filename.str().c_str() << endl; outf.open(filename.str().c_str()); outf << nplayers << endl; @@ -184,7 +150,7 @@ void write_mac_keys(const string& directory, int i, int nplayers, gfp macp, T ma outf.close(); } -void read_keys(const string& directory, gfp& keyp, gf2n& key2, int nplayers) +inline void read_keys(const string& directory, gfp& keyp, gf2n& key2, int nplayers) { gfp sharep; gf2n share2; @@ -217,6 +183,3 @@ void read_keys(const string& directory, gfp& keyp, gf2n& key2, int nplayers) } std::cout << "Final MAC keys :\t p: " << keyp << "\n\t\t 2: " << key2 << std::endl; } - -template void write_mac_keys(const string& directory, int i, int nplayers, gfp macp, gf2n_short mac2); -template void write_mac_keys(const string& directory, int i, int nplayers, gfp macp, gf2n_long mac2); diff --git a/GC/AuthValue.cpp b/BMR/AuthValue.cpp similarity index 95% rename from GC/AuthValue.cpp rename to BMR/AuthValue.cpp index d58b1a003..7c74a94ed 100644 --- a/GC/AuthValue.cpp +++ b/BMR/AuthValue.cpp @@ -3,7 +3,7 @@ * */ -#include "Secret.h" +#include "GC/Secret.h" namespace GC { diff --git a/BMR/CommonParty.h b/BMR/CommonParty.h index 7439c5da8..771048fad 100644 --- a/BMR/CommonParty.h +++ b/BMR/CommonParty.h @@ -45,6 +45,11 @@ class PersistentFront int get_i() { return i; } }; +namespace GC +{ +template class Machine; +} + class CommonParty : public NodeUpdatable { protected: diff --git a/BMR/Party.cpp b/BMR/Party.cpp index 5c6fb9055..072e72896 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -11,7 +11,7 @@ #include #include -#include +#include "Tools/callgrind.h" #include "proto_utils.h" #include "msg_types.h" diff --git a/BMR/Register.h b/BMR/Register.h index 1c0208d0f..bcc2ff2cc 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -229,7 +229,6 @@ class Phase static T get_input(int from, GC::Processor& processor, int n_bits) { return T::input(from, processor.get_input(n_bits), n_bits); } -// static void check_input(long long in, int n_bits) { (void)in; (void)n_bits; } void input(party_id_t from, char value = -1) { (void)from; (void)value; } void public_input(bool value) { (void)value; } void random() {} diff --git a/BMR/TrustedParty.cpp b/BMR/TrustedParty.cpp index b1842bff2..7ecefd27f 100644 --- a/BMR/TrustedParty.cpp +++ b/BMR/TrustedParty.cpp @@ -17,6 +17,8 @@ #include "SpdzWire.h" #include "Auth/fake-stuff.h" +#include "Auth/fake-stuff.hpp" + TrustedProgramParty* TrustedProgramParty::singleton = 0; diff --git a/Check-Offline.cpp b/Check-Offline.cpp index 8cc56c777..1f0ec91d6 100644 --- a/Check-Offline.cpp +++ b/Check-Offline.cpp @@ -13,6 +13,8 @@ #include "Math/Setup.h" #include "Processor/Data_Files.h" +#include "Auth/fake-stuff.hpp" + #include #include #include @@ -21,10 +23,10 @@ using namespace std; string PREP_DATA_PREFIX; template -void check_mult_triples(const T& key,int N,vector>*>& dataF) +void check_mult_triples(const typename T::value_type& key,int N,vector*>& dataF) { - T a,b,c,mac,res; - vector > Sa(N),Sb(N),Sc(N); + typename T::clear a,b,c,mac,res; + vector Sa(N),Sb(N),Sc(N); int n = 0; try { @@ -89,10 +91,10 @@ void check_tuple(const T& a, const T& b, int n, Dtype type) } template -void check_tuples(const T& key,int N,vector>*>& dataF, Dtype type) +void check_tuples(const typename T::value_type& key,int N,vector*>& dataF, Dtype type) { - T a,b,c,mac,res; - vector > Sa(N),Sb(N),Sc(N); + typename T::clear a,b,c,mac,res; + vector Sa(N),Sb(N),Sc(N); int n = 0; try { @@ -147,10 +149,10 @@ void check_bits(const typename T::value_type& key,int N,vector } template -void check_inputs(const T& key,int N,vector>*>& dataF) +void check_inputs(const typename T::value_type& key,int N,vector*>& dataF) { - T a, mac, x; - vector< Share > Sa(N); + typename T::clear a, mac, x; + vector Sa(N); for (int player = 0; player < N; player++) { @@ -176,27 +178,27 @@ void check_inputs(const T& key,int N,vector>*>& dataF) } template -vector*> setup(int N, DataPositions& usage) +vector*> setup(int N, DataPositions& usage, int thread_num = -1) { vector*> dataF(N); for (int i = 0; i < N; i++) - dataF[i] = new Sub_Data_Files(i, N, PREP_DATA_PREFIX, usage); + dataF[i] = new Sub_Data_Files(i, N, PREP_DATA_PREFIX, usage, thread_num); return dataF; } template -void check(T key, int N, bool only_bits = false) +void check(typename T::value_type key, int N, bool only_bits = false) { DataPositions usage(N); - auto dataF = setup>(N, usage); + auto dataF = setup(N, usage); check_bits(key, N, dataF); if (not only_bits) { check_mult_triples(key, N, dataF); - check_inputs(key, N, dataF); - check_tuples(key, N, dataF, DATA_SQUARE); - check_tuples(key, N, dataF, DATA_INVERSE); + check_inputs(key, N, dataF); + check_tuples(key, N, dataF, DATA_SQUARE); + check_tuples(key, N, dataF, DATA_INVERSE); } } @@ -304,13 +306,19 @@ int main(int argc, const char** argv) cout << "--------------\n"; cout << "Final Keys :\t p: " << keyp << "\n\t\t 2: " << key2 << endl; - check(keyp, N); - check(key2, N); + check(keyp, N); + check>(key2, N); if (N == 3) { DataPositions pos(N); - auto dataF = setup(N, pos); + auto dataF = setup>(N, pos); check_bits({}, N, dataF); + + check>({}, N); + + auto dataF2 = setup(N, pos, 0); + check_mult_triples({}, N, dataF2); + check_bits({}, N, dataF2); } } diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index 3cbf3b910..8a3cf6aaa 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -89,6 +89,10 @@ class shrci(base.Instruction): code = base.opcodes['SHRCI'] arg_format = ['cbw','cb','int'] +class shlci(base.Instruction): + code = base.opcodes['SHLCI'] + arg_format = ['cbw','cb','int'] + class ldbits(base.Instruction): code = opcodes['LDBITS'] arg_format = ['sbw','i','i'] @@ -186,3 +190,16 @@ class print_reg_plain(base.IOInstruction): class print_reg_signed(base.IOInstruction): code = opcodes['PRINTREGSIGNED'] arg_format = ['int','cb'] + +class print_float_plain(base.IOInstruction): + __slots__ = [] + code = base.opcodes['PRINTFLOATPLAIN'] + arg_format = ['cb', 'cb', 'cb', 'cb'] + +class cond_print_str(base.IOInstruction): + r""" Print a 4 character string. """ + code = base.opcodes['CONDPRINTSTR'] + arg_format = ['cb', 'int'] + + def __init__(self, cond, val): + super(cond_print_str, self).__init__(cond, self.str_to_int(val)) diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 285268bc1..8a5ad329e 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -155,11 +155,17 @@ def __rshift__(self, other): res = cbits(n=self.n-other) inst.shrci(res, self, other) return res + def __lshift__(self, other): + res = cbits(n=self.n+other) + inst.shlci(res, self, other) + return res def print_reg(self, desc=''): inst.print_reg(self, desc) def print_reg_plain(self): inst.print_reg_signed(self.n, self) output = print_reg_plain + def print_if(self, string): + inst.cond_print_str(self, string) def reveal(self): return self @@ -557,9 +563,18 @@ def less_than(self, other, *args, **kwargs): assert(len(self.v) == len(other.v)) return self.from_vec(sbitint.bit_less_than(self.v, other.v)) +class cbitfix(object): + def __init__(self, value): + self.v = value + def output(self): + bits = self.v.bit_decompose(self.k) + sign = bits[-1] + v = self.v + (sign << (self.k)) * -1 + inst.print_float_plain(v, cbits(-self.f, n=32), cbits(0), cbits(0)) + class sbitfix(_fix): float_type = type(None) - clear_type = staticmethod(lambda x: x) + clear_type = cbitfix @classmethod def set_precision(cls, f, k=None): super(cls, sbitfix).set_precision(f, k) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index e53b190dd..a2773a64a 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -26,6 +26,7 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \ VARS['program'] = prog if options.binary: VARS['sint'] = GC.types.sbitint.get_type(int(options.binary)) + VARS['sfix'] = GC.types.sbitfix comparison.set_variant(options) print 'Compiling file', prog.infile diff --git a/Compiler/dijkstra.py b/Compiler/dijkstra.py index 9c02b07d5..684179947 100644 --- a/Compiler/dijkstra.py +++ b/Compiler/dijkstra.py @@ -443,7 +443,7 @@ def __setitem__(self, index, value): def __len__(self): return len(self.arrays[0]) -class IntVectorArray(Array, Vector): +class IntVectorArray(Vector, Array): def __init__(self, length): Array.__init__(self, length, 's') diff --git a/Compiler/graph.py b/Compiler/graph.py index 7fd67d738..9ce87be72 100644 --- a/Compiler/graph.py +++ b/Compiler/graph.py @@ -123,7 +123,7 @@ def get_children(node): yield i if nbunch is None: - nbunch = range(len(G)) + nbunch = reversed(range(len(G))) for v in nbunch: # process all vertices in G if v in explored: continue diff --git a/Compiler/instructions.py b/Compiler/instructions.py index a6c6e822a..04bd560cc 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -856,15 +856,16 @@ def has_var_args(self): @base.gf2n @base.vectorize -class asm_input(base.IOInstruction): +class asm_input(base.VarArgsInstruction): r""" Receive input from player $p$ and put in register $s_i$. """ __slots__ = [] code = base.opcodes['INPUT'] - arg_format = ['sw', 'p'] + arg_format = tools.cycle(['sw', 'p']) field_type = 'modp' def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', self.args[1]), \ + for player in self.args[1::2]: + req_node.increment((self.field_type, 'input', player), \ self.get_size()) def execute(self): self.args[0].value = _python_input("Enter player %d's input:" % self.args[1]) % program.P @@ -967,6 +968,14 @@ class print_char4(base.IOInstruction): def __init__(self, val): super(print_char4, self).__init__(self.str_to_int(val)) +class cond_print_str(base.IOInstruction): + r""" Print a 4 character string. """ + code = base.opcodes['CONDPRINTSTR'] + arg_format = ['c', 'int'] + + def __init__(self, cond, val): + super(cond_print_str, self).__init__(cond, self.str_to_int(val)) + @base.vectorize class print_char_regint(base.IOInstruction): r""" Print register $ci_i$ as a single character to stdout. """ diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 4ee7e1526..382fbd5cd 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -162,6 +162,7 @@ PRINTFLOATPLAIN = 0xBC, WRITEFILESHARE = 0xBD, READFILESHARE = 0xBE, + CONDPRINTSTR = 0xBF, GBITDEC = 0x184, GBITCOM = 0x185, # Secure socket diff --git a/Compiler/library.py b/Compiler/library.py index ba4df0fff..d56c9a730 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1,7 +1,7 @@ from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat from Compiler.instructions import * from Compiler.util import tuplify,untuplify -from Compiler import instructions,instructions_base,comparison,program +from Compiler import instructions,instructions_base,comparison,program,util import inspect,math import random import collections @@ -93,6 +93,16 @@ def print_ln(s='', *args): print_str(s, *args) print_char('\n') +def print_ln_if(cond, s): + if util.is_constant(cond): + if cond: + print_ln(s) + else: + s += '\n' + while s: + cond.print_if(s[:4]) + s = s[4:] + def runtime_error(msg='', *args): """ Print an error message and abort the runtime. """ print_str('User exception: ') diff --git a/Compiler/program.py b/Compiler/program.py index 6f83d2c3a..c657e4db9 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -67,7 +67,9 @@ def __init__(self, args, options, param=-1, assemblymode=False): self.to_merge = [Compiler.instructions.asm_open_class, \ Compiler.instructions.gasm_open_class, \ Compiler.instructions.muls_class, \ - Compiler.instructions.gmuls_class] + Compiler.instructions.gmuls_class, \ + Compiler.instructions.asm_input_class, \ + Compiler.instructions.gasm_input_class] import Compiler.GC.instructions as gc self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \ gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb] @@ -563,6 +565,8 @@ def optimize(self, options): print 'Block requires', \ ', '.join('%d %s' % (y, x.__name__) \ for x, y in merger.counter.items()) + # free memory + merger = None if options.dead_code_elimination: block.instructions = filter(lambda x: x is not None, block.instructions) if not (options.merge_opens and self.merge_opens): diff --git a/Compiler/types.py b/Compiler/types.py index c556dfd7e..e8311d3ee 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -192,6 +192,8 @@ def conv(cls, val): return type(val)(cls.conv(v) for v in val) except TypeError: pass + except CompilerError: + pass return cls(val) @vectorized_classmethod @@ -517,6 +519,9 @@ def digest(self, num_bytes): digestc(res, self, num_bytes) return res + def print_if(self, string): + cond_print_str(self, string) + @@ -1941,7 +1946,12 @@ def __div__(self, other): raise TypeError('Incompatible fixed point types in division') def print_plain(self): - sign = cint(self.v < 0) + if self.k > 64: + raise CompilerError('Printing of fixed-point numbers not ' + + 'implemented for more than 64-bit precision') + tmp = regint() + convmodp(tmp, self.v, bitlength=self.k) + sign = cint(tmp < 0) abs_v = sign.if_else(-self.v, self.v) print_float_plain(cint(abs_v), cint(-self.f), \ cint(0), cint(sign)) @@ -1984,11 +1994,22 @@ def load_mem(cls, address, mem_type=None): return cls(*res) @classmethod - def load_sint(cls, v): + def from_sint(cls, other): res = cls() - res.load_int(v) + res.load_int(cls.int_type.conv(other)) return res + @classmethod + def conv(cls, other): + if isinstance(other, cls): + return other + else: + try: + return cls.from_sint(other) + except (TypeError, CompilerError): + pass + return cls(other) + @vectorize_init def __init__(self, _v=None, size=None): self.size = get_global_vector_size() @@ -2017,9 +2038,6 @@ def __init__(self, _v=None, size=None): def load_int(self, v): self.v = self.int_type(v) << self.f - def conv(self): - return self - def store_in_mem(self, address): self.v.store_in_mem(address) @@ -2028,7 +2046,7 @@ def sizeof(self): @vectorize def add(self, other): - other = parse_type(other) + other = self.conv(other) if isinstance(other, (_fix, cfix)): return type(self)(self.v + other.v) elif isinstance(other, cfix.scalars): @@ -2039,7 +2057,7 @@ def add(self, other): @vectorize def mul(self, other): - other = parse_type(other) + other = self.conv(other) if isinstance(other, _fix): val = self.v.TruncMul(other.v, self.k * 2, self.f, self.kappa) return type(self)(val) @@ -2054,7 +2072,7 @@ def mul(self, other): @vectorize def __sub__(self, other): - other = parse_type(other) + other = self.conv(other) return self + (-other) @vectorize @@ -2066,7 +2084,7 @@ def __rsub__(self, other): @vectorize def __eq__(self, other): - other = parse_type(other) + other = self.conv(other) if isinstance(other, (cfix, _fix)): return self.v.equal(other.v, self.k, self.kappa) else: @@ -2074,7 +2092,7 @@ def __eq__(self, other): @vectorize def __le__(self, other): - other = parse_type(other) + other = self.conv(other) if isinstance(other, (cfix, _fix)): return self.v.less_equal(other.v, self.k, self.kappa) else: @@ -2082,7 +2100,7 @@ def __le__(self, other): @vectorize def __lt__(self, other): - other = parse_type(other) + other = self.conv(other) if isinstance(other, (cfix, _fix)): return self.v.less_than(other.v, self.k, self.kappa) else: @@ -2090,7 +2108,7 @@ def __lt__(self, other): @vectorize def __ge__(self, other): - other = parse_type(other) + other = self.conv(other) if isinstance(other, (cfix, _fix)): return self.v.greater_equal(other.v, self.k, self.kappa) else: @@ -2098,7 +2116,7 @@ def __ge__(self, other): @vectorize def __gt__(self, other): - other = parse_type(other) + other = self.conv(other) if isinstance(other, (cfix, _fix)): return self.v.greater_than(other.v, self.k, self.kappa) else: @@ -2106,7 +2124,7 @@ def __gt__(self, other): @vectorize def __ne__(self, other): - other = parse_type(other) + other = self.conv(other) if isinstance(other, (cfix, _fix)): return self.v.not_equal(other.v, self.k, self.kappa) else: @@ -2114,7 +2132,7 @@ def __ne__(self, other): @vectorize def __div__(self, other): - other = parse_type(other) + other = self.conv(other) if isinstance(other, _fix): return type(self)(library.FPDiv(self.v, other.v, self.k, self.f, self.kappa)) elif isinstance(other, cfix): @@ -2122,18 +2140,28 @@ def __div__(self, other): else: raise TypeError('Incompatible fixed point types in division') + def __rdiv__(self, other): + return self.conv(other) / self + @vectorize def compute_reciprocal(self): return type(self)(library.FPDiv(cint(2) ** self.f, self.v, self.k, self.f, self.kappa, True)) def reveal(self): val = self.v.reveal() - return self.clear_type(val) + res = self.clear_type(val) + res.f = self.f + res.k = self.k + return res class sfix(_fix): int_type = sint clear_type = cfix + @classmethod + def conv(cls, other): + return parse_type(other) + # this is for 20 bit decimal precision # with 40 bitlength of entire number # these constants have been chosen for multiplications to fit in 128 bit prime field @@ -2467,7 +2495,7 @@ def create_from(cls, l): res.assign(tmp) return res - def __init__(self, length, value_type, address=None): + def __init__(self, length, value_type, address=None, debug=None): if value_type in _types: value_type = _types[value_type] self.address = address @@ -2476,6 +2504,7 @@ def __init__(self, length, value_type, address=None): if address is None: self.address = self._malloc() self.address_cache = {} + self.debug = debug def _malloc(self): return program.malloc(self.length, self.value_type) @@ -2492,6 +2521,9 @@ def get_address(self, index): (str(index), str(self.length))) if (program.curr_block, index) not in self.address_cache: self.address_cache[program.curr_block, index] = self.address + index + if self.debug: + library.print_ln_if(index >= self.length, 'OF:' + self.debug) + library.print_ln_if(self.address_cache[program.curr_block, index] >= program.allocated_mem[self.value_type.reg_type], 'AOF:' + self.debug) return self.address_cache[program.curr_block, index] def get_slice(self, index): @@ -2589,25 +2621,37 @@ def reveal(self): class SubMultiArray(object): - def __init__(self, sizes, value_type, address, index): + def __init__(self, sizes, value_type, address, index, debug=None): self.sizes = sizes self.value_type = value_type self.address = address + index * reduce(operator.mul, self.sizes) self.sub_cache = {} + self.debug = debug + if debug: + library.print_ln_if(self.address + reduce(operator.mul, self.sizes) > program.allocated_mem[self.value_type.reg_type], 'AOF%d:' % len(self.sizes) + self.debug) def __getitem__(self, index): + if util.is_constant(index) and index >= self.sizes[0]: + raise StopIteration key = program.curr_block, index if key not in self.sub_cache: + if self.debug: + library.print_ln_if(index >= self.sizes[0], \ + 'OF%d:' % len(self.sizes) + self.debug) if len(self.sizes) == 2: self.sub_cache[key] = \ Array(self.sizes[1], self.value_type, \ - self.address + index * self.sizes[1]) + self.address + index * self.sizes[1], \ + debug=self.debug) else: self.sub_cache[key] = \ SubMultiArray(self.sizes[1:], self.value_type, \ - self.address, index) + self.address, index, debug=self.debug) return self.sub_cache[key] + def __len__(self): + return self.sizes[0] + def assign_all(self, value): @library.for_range(self.sizes[0]) def f(i): @@ -2615,16 +2659,17 @@ def f(i): return self class MultiArray(SubMultiArray): - def __init__(self, sizes, value_type): + def __init__(self, sizes, value_type, debug=None): self.array = Array(reduce(operator.mul, sizes), \ value_type) - SubMultiArray.__init__(self, sizes, value_type, self.array.address, 0) + SubMultiArray.__init__(self, sizes, value_type, self.array.address, 0, \ + debug=debug) if len(sizes) < 2: raise CompilerError('Use Array') class Matrix(MultiArray): - def __init__(self, rows, columns, value_type): - MultiArray.__init__(self, [rows, columns], value_type) + def __init__(self, rows, columns, value_type, debug=None): + MultiArray.__init__(self, [rows, columns], value_type, debug=debug) def __mul__(self, other): assert isinstance(other, Array) diff --git a/FHEOffline/PairwiseMachine.cpp b/FHEOffline/PairwiseMachine.cpp index 07cfb6ba6..ebf31c36d 100644 --- a/FHEOffline/PairwiseMachine.cpp +++ b/FHEOffline/PairwiseMachine.cpp @@ -7,6 +7,8 @@ #include "Tools/benchmarking.h" #include "Auth/fake-stuff.h" +#include "Auth/fake-stuff.hpp" + PairwiseMachine::PairwiseMachine(int argc, const char** argv) : MachineBase(argc, argv), P(N, 0xffff << 16), other_pks(N.num_players(), {setup_p.params, 0}), diff --git a/Fake-Offline.cpp b/Fake-Offline.cpp index a7e63cbd3..7c80a44db 100644 --- a/Fake-Offline.cpp +++ b/Fake-Offline.cpp @@ -5,6 +5,7 @@ #include "Math/Setup.h" #include "Auth/fake-stuff.h" #include "Exceptions/Exceptions.h" +#include "GC/MaliciousRepSecret.h" #include "Math/Setup.h" #include "Processor/Data_Files.h" @@ -12,6 +13,8 @@ #include "Tools/ezOptionParser.h" #include "Tools/benchmarking.h" +#include "Auth/fake-stuff.hpp" + #include #include using namespace std; @@ -24,18 +27,20 @@ string prep_data_prefix; * str = "2" or "p" */ template -void make_mult_triples(const T& key,int N,int ntrip,const string& str,bool zero) +void make_mult_triples(const typename T::value_type& key, int N, int ntrip, + bool zero, int thread_num = -1) { PRNG G; G.ReSeed(); ofstream* outf=new ofstream[N]; - T a,b,c; - vector > Sa(N),Sb(N),Sc(N); + typename T::value_type a,b,c; + vector Sa(N),Sb(N),Sc(N); /* Generate Triples */ for (int i=0; i::get_suffix(thread_num); cout << "Opening " << filename.str() << endl; outf[i].open(filename.str().c_str(),ios::out | ios::binary); if (outf[i].fail()) { throw file_error(filename.str().c_str()); } @@ -108,14 +113,14 @@ void make_bit_triples(const gf2n& key,int N,int ntrip,Dtype dtype,bool zero) * str = "2" or "p" */ template -void make_square_tuples(const T& key,int N,int ntrip,const string& str,bool zero) +void make_square_tuples(const typename T::value_type& key,int N,int ntrip,const string& str,bool zero) { PRNG G; G.ReSeed(); ofstream* outf=new ofstream[N]; - T a,c; - vector > Sa(N),Sc(N); + typename T::clear a,c; + vector Sa(N),Sc(N); /* Generate Squares */ for (int i=0; i -void make_bits(const typename T::value_type& key,int N,int ntrip,bool zero) +void make_bits(const typename T::value_type& key, int N, int ntrip, bool zero, + int thread_num = -1) { PRNG G; G.ReSeed(); @@ -156,7 +162,8 @@ void make_bits(const typename T::value_type& key,int N,int ntrip,bool zero) /* Generate Bits */ for (int i=0; i::get_suffix(thread_num); cout << "Opening " << filename.str() << endl; outf[i].open(filename.str().c_str(),ios::out | ios::binary); if (outf[i].fail()) { throw file_error(filename.str().c_str()); } @@ -180,14 +187,14 @@ void make_bits(const typename T::value_type& key,int N,int ntrip,bool zero) * */ template -void make_inputs(const T& key,int N,int ntrip,const string& str,bool zero) +void make_inputs(const typename T::value_type& key,int N,int ntrip,const string& str,bool zero) { PRNG G; G.ReSeed(); ofstream* outf=new ofstream[N]; - T a; - vector > Sa(N); + typename T::clear a; + vector Sa(N); /* Generate Inputs */ for (int player=0; player -void make_inverse(const T& key,int N,int ntrip,bool zero) +void make_inverse(const typename T::value_type& key,int N,int ntrip,bool zero) { PRNG G; G.ReSeed(); ofstream* outf=new ofstream[N]; - T a,b; - vector > Sa(N),Sb(N); + typename T::clear a,b; + vector Sa(N),Sb(N); /* Generate Triples */ for (int i=0; i -void make_PreMulC(const T& key, int N, int ntrip, bool zero) +void make_PreMulC(const typename T::value_type& key, int N, int ntrip, bool zero) { stringstream ss; - ss << prep_data_prefix << "PreMulC-" << T::type_char(); + ss << prep_data_prefix << "PreMulC-" << T::type_short(); Files files(N, key, ss.str()); PRNG G; G.ReSeed(); - T a, b, c; + typename T::clear a, b, c; c = 1; for (int i=0; i +void make_basic(const typename T::value_type& key, int nplayers, int nitems, bool zero) +{ + make_mult_triples(key, nplayers, nitems, zero); + make_bits(key, nplayers, nitems, zero); + make_square_tuples(key, nplayers, nitems, T::type_short(), zero); + make_inputs(key, nplayers, nitems, T::type_short(), zero); + make_inverse(key, nplayers, nitems, zero); + make_PreMulC(key, nplayers, nitems, zero); +} + int main(int argc, const char** argv) { insecure("preprocessing"); @@ -535,22 +553,31 @@ int main(int argc, const char** argv) cout << "--------------\n"; cout << "Final Keys :\t p: " << keyp << "\n\t\t 2: " << key2 << endl; - make_mult_triples(key2,nplayers,ntrip2,"2",zero); - make_mult_triples(keyp,nplayers,ntripp,"p",zero); + typedef Share sgf2n; + + make_mult_triples(key2,nplayers,ntrip2,zero); + make_mult_triples(keyp,nplayers,ntripp,zero); make_bits>(key2,nplayers,nbits2,zero); make_bits>(keyp,nplayers,nbitsp,zero); - make_square_tuples(key2,nplayers,nsqr2,"2",zero); - make_square_tuples(keyp,nplayers,nsqrp,"p",zero); - make_inputs(key2,nplayers,ninp2,"2",zero); - make_inputs(keyp,nplayers,ninpp,"p",zero); - make_inverse(key2,nplayers,ninv,zero); - make_inverse(keyp,nplayers,ninv,zero); + make_square_tuples(key2,nplayers,nsqr2,"2",zero); + make_square_tuples(keyp,nplayers,nsqrp,"p",zero); + make_inputs(key2,nplayers,ninp2,"2",zero); + make_inputs(keyp,nplayers,ninpp,"p",zero); + make_inverse(key2,nplayers,ninv,zero); + make_inverse(keyp,nplayers,ninv,zero); make_bit_triples(key2,nplayers,nbittrip,DATA_BITTRIPLE,zero); make_bit_triples(key2,nplayers,nbitgf2ntrip,DATA_BITGF2NTRIPLE,zero); - make_PreMulC(key2,nplayers,ninv,zero); - make_PreMulC(keyp,nplayers,ninv,zero); + make_PreMulC(key2,nplayers,ninv,zero); + make_PreMulC(keyp,nplayers,ninv,zero); // replicated secret sharing only for three parties if (nplayers == 3) - make_bits({}, nplayers, nbitsp, zero); + { + make_bits>({}, nplayers, nbitsp, zero); + make_basic>({}, nplayers, default_num, zero); + make_basic>({}, nplayers, default_num, zero); + + make_mult_triples({}, nplayers, ntrip2, zero); + make_bits({}, nplayers, nbits2, zero); + } } diff --git a/GC/ArgTuples.h b/GC/ArgTuples.h index 7bf676726..10e3ef1d9 100644 --- a/GC/ArgTuples.h +++ b/GC/ArgTuples.h @@ -78,6 +78,21 @@ class InputArgs } }; +class InputArgList : public ArgList +{ +public: + InputArgList(const vector& args) : + ArgList(args) + { + } + int n_inputs_from(int from) + { + int res = 0; + for (auto x : *this) + res += x.from == from; + return res; + } +}; #endif /* GC_ARGTUPLES_H_ */ diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index fd721cccd..05952a723 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -10,7 +10,8 @@ #include "GC/Memory.h" #include "GC/Access.h" -#include "Auth/MAC_Check.h" +#include "Math/gf2nlong.h" + #include "Processor/DummyProtocol.h" #include @@ -30,7 +31,7 @@ class FakeSecret typedef FakeSecret DynamicType; // dummy - typedef MAC_Check_Base MC; + typedef DummyMC MC; typedef DummyProtocol Protocol; static string type_string() { return "fake secret"; } diff --git a/GC/Instruction.h b/GC/Instruction.h index 928b2fb9f..ffd9cf3be 100644 --- a/GC/Instruction.h +++ b/GC/Instruction.h @@ -10,8 +10,6 @@ #include using namespace std; -#include "GC/Processor.h" - #include "Processor/Instruction.h" namespace GC @@ -27,6 +25,7 @@ enum RegType { NONE }; +template class Processor; template class Instruction : public ::BaseInstruction @@ -49,10 +48,10 @@ class Instruction : public ::BaseInstruction int get_reg_type() const; // Returns the maximal register used - int get_max_reg(int reg_type) const; + unsigned get_max_reg(int reg_type) const; // Returns the memory size used if applicable and known - int get_mem(RegType reg_type) const; + unsigned get_mem(RegType reg_type) const; // Execute this instruction bool exe(Processor& processor) const { return code(*this, processor); } diff --git a/GC/Instruction.hpp b/GC/Instruction.hpp index a68f5a6a7..0f867327d 100644 --- a/GC/Instruction.hpp +++ b/GC/Instruction.hpp @@ -6,6 +6,8 @@ #include #include "GC/Instruction.h" +#include "GC/Processor.h" + #ifdef MAX_INLINE #include "GC/Secret_inline.h" #endif @@ -73,7 +75,7 @@ int Instruction::get_reg_type() const } template -int GC::Instruction::get_max_reg(int reg_type) const +unsigned GC::Instruction::get_max_reg(int reg_type) const { int skip; int offset = 0; @@ -101,17 +103,17 @@ int GC::Instruction::get_max_reg(int reg_type) const return BaseInstruction::get_max_reg(reg_type); } - int m = 0; + unsigned m = 0; if (reg_type == SBIT) for (size_t i = offset; i < start.size(); i += skip) - m = max(m, start[i] + 1); + m = max(m, (unsigned)start[i] + 1); return m; } template -int Instruction::get_mem(RegType reg_type) const +unsigned Instruction::get_mem(RegType reg_type) const { - int m = n + 1; + unsigned m = n + 1; switch (opcode) { case LDMSD: @@ -119,7 +121,7 @@ int Instruction::get_mem(RegType reg_type) const { m = 0; for (size_t i = 0; i < start.size() / 3; i++) - m = max(m, start[3*i+1] + 1); + m = max(m, (unsigned)start[3*i+1] + 1); return m; } break; @@ -128,7 +130,7 @@ int Instruction::get_mem(RegType reg_type) const { m = 0; for (size_t i = 0; i < start.size() / 2; i++) - m = max(m, start[2*i+1] + 1); + m = max(m, (unsigned)start[2*i+1] + 1); return m; } break; @@ -232,6 +234,8 @@ void Instruction::parse(istream& s, int pos) default: ostringstream os; os << "Code not defined for instruction " << showbase << hex << opcode << dec; + os << "This virtual machine executes binary circuits only." << endl; + os << "Try compiling with '-B' or use only sbit* types." << endl; throw Invalid_Instruction(os.str()); break; } diff --git a/GC/Machine.cpp b/GC/Machine.cpp index fee6d9920..2c5e9f50b 100644 --- a/GC/Machine.cpp +++ b/GC/Machine.cpp @@ -3,24 +3,34 @@ * */ +#include "MaliciousRepSecret.h" +#include "Auth/ReplicatedMC.h" +#include "Auth/MaliciousRepMC.h" + +#include "Instruction.hpp" #include "Machine.hpp" #include "Processor.hpp" +#include "Program.hpp" #include "Thread.hpp" #include "ThreadMaster.hpp" +#include "Auth/MaliciousRepMC.hpp" namespace GC { -template class Machine; -template class Machine; - -template class Processor; -template class Processor; +extern template class ReplicatedSecret; +extern template class ReplicatedSecret; -template class Thread; -template class Thread; +#define GC_MACHINE(T) \ + template class Instruction; \ + template class Machine; \ + template class Processor; \ + template class Program; \ + template class Thread; \ + template class ThreadMaster; \ -template class ThreadMaster; -template class ThreadMaster; +GC_MACHINE(FakeSecret); +GC_MACHINE(SemiHonestRepSecret); +GC_MACHINE(MaliciousRepSecret) } /* namespace GC */ diff --git a/GC/Machine.h b/GC/Machine.h index 411f38d2d..4244128fc 100644 --- a/GC/Machine.h +++ b/GC/Machine.h @@ -10,7 +10,7 @@ #include "GC/Clear.h" #include "GC/Memory.h" -#include "Processor/Machine.h" +#include "Processor/BaseMachine.h" #include using namespace std; @@ -32,6 +32,7 @@ class Machine : public ::BaseMachine vector > progs; bool use_encryption; + bool more_comm_less_comp; Machine(Memory& MD); ~Machine(); diff --git a/GC/Machine.hpp b/GC/Machine.hpp index f1b6ca59c..4ed2959c3 100644 --- a/GC/Machine.hpp +++ b/GC/Machine.hpp @@ -17,6 +17,7 @@ template Machine::Machine(Memory& dynamic_memory) : MD(dynamic_memory) { use_encryption = false; + more_comm_less_comp = false; start_timer(); } diff --git a/GC/MaliciousRepSecret.h b/GC/MaliciousRepSecret.h new file mode 100644 index 000000000..0fcb5811b --- /dev/null +++ b/GC/MaliciousRepSecret.h @@ -0,0 +1,34 @@ +/* + * MaliciousRepSecret.h + * + */ + +#ifndef GC_MALICIOUSREPSECRET_H_ +#define GC_MALICIOUSREPSECRET_H_ + +#include "ReplicatedSecret.h" + +template class MaliciousRepMC; + +namespace GC +{ + +class MaliciousRepThread; + +class MaliciousRepSecret : public ReplicatedSecret +{ + typedef ReplicatedSecret super; + +public: + typedef MaliciousRepSecret DynamicType; + + typedef MaliciousRepMC MC; + + MaliciousRepSecret() {} + template + MaliciousRepSecret(const T& other) : super(other) {} +}; + +} + +#endif /* GC_MALICIOUSREPSECRET_H_ */ diff --git a/GC/MaliciousRepThread.cpp b/GC/MaliciousRepThread.cpp new file mode 100644 index 000000000..966d332b9 --- /dev/null +++ b/GC/MaliciousRepThread.cpp @@ -0,0 +1,94 @@ +/* + * MalicousRepParty.cpp + * + */ + +#include "Auth/MaliciousRepMC.h" +#include "MaliciousRepThread.h" +#include "Math/Setup.h" + +#include "Auth/MaliciousRepMC.hpp" + +namespace GC +{ + +thread_local MaliciousRepThread* MaliciousRepThread::singleton = 0; + +MaliciousRepThread::MaliciousRepThread(int i, + ThreadMaster& master) : + Thread(i, master), DataF(N.my_num(), + N.num_players(), + get_prep_dir(N.num_players(), 128, gf2n::default_degree()), + usage, i) +{ +} + +MaliciousRepMC* MaliciousRepThread::new_mc() +{ + if (machine.more_comm_less_comp) + return new CommMaliciousRepMC; + else + return new HashMaliciousRepMC; +} + +void MaliciousRepThread::pre_run() +{ + if (singleton) + throw runtime_error("there can only be one"); + singleton = this; +} + +void MaliciousRepThread::post_run() +{ +#ifndef INSECURE + cerr << "Removing used pre-processed data" << endl; + DataF.prune(); +#endif +} + +void MaliciousRepThread::and_(Processor& processor, + const vector& args, bool repeat) +{ + assert(P->num_players() == 3); + os.resize(2); + for (auto& o : os) + o.reset_write_head(); + processor.check_args(args, 4); + shares.clear(); + triples.clear(); + for (size_t i = 0; i < args.size(); i += 4) + { + int n_bits = args[i]; + int left = args[i + 2]; + int right = args[i + 3]; + triples.push_back({0}); + DataF.get(DATA_TRIPLE, triples.back().data()); + shares.push_back((processor.S[left] - triples.back()[0]).mask(n_bits)); + MaliciousRepSecret y_ext; + if (repeat) + y_ext = processor.S[right].extend_bit(); + else + y_ext = processor.S[right]; + shares.push_back((y_ext - triples.back()[1]).mask(n_bits)); + } + + MC->POpen_Begin(opened, shares, *P); + MC->POpen_End(opened, shares, *P); + auto it = opened.begin(); + + for (size_t i = 0; i < args.size(); i += 4) + { + int n_bits = args[i]; + int out = args[i + 1]; + MaliciousRepSecret tmp = triples[i / 4][2]; + BitVec masked[2]; + for (int k = 0; k < 2; k++) + { + masked[k] = *it++; + tmp += triples[i / 4][1 - k] & masked[k]; + } + processor.S[out] = (tmp + (masked[0] & masked[1])).mask(n_bits); + } +} + +} /* namespace GC */ diff --git a/GC/MaliciousRepThread.h b/GC/MaliciousRepThread.h new file mode 100644 index 000000000..1e46cebc9 --- /dev/null +++ b/GC/MaliciousRepThread.h @@ -0,0 +1,53 @@ +/* + * MalicousRepParty.h + * + */ + +#ifndef GC_MALICIOUSREPTHREAD_H_ +#define GC_MALICIOUSREPTHREAD_H_ + +#include "Thread.h" +#include "MaliciousRepSecret.h" +#include "Processor/Data_Files.h" + +#include + +namespace GC +{ + +class MaliciousRepThread : public Thread +{ + static thread_local MaliciousRepThread* singleton; + + vector shares; + vector opened; + vector> triples; + +public: + static MaliciousRepThread& s(); + + DataPositions usage; + Sub_Data_Files DataF; + + MaliciousRepThread(int i, ThreadMaster& master); + virtual ~MaliciousRepThread() {} + + MaliciousRepSecret::MC* new_mc(); + + void pre_run(); + void post_run(); + + void and_(Processor& processor, const vector& args, bool repeat); +}; + +inline MaliciousRepThread& MaliciousRepThread::s() +{ + if (singleton) + return *singleton; + else + throw runtime_error("no singleton"); +} + +} /* namespace GC */ + +#endif /* GC_MALICIOUSREPTHREAD_H_ */ diff --git a/GC/Memory.h b/GC/Memory.h index 82486a1ea..33c3f51c4 100644 --- a/GC/Memory.h +++ b/GC/Memory.h @@ -14,6 +14,7 @@ using namespace std; #include "Exceptions/Exceptions.h" #include "Clear.h" +#include "config.h" namespace GC { @@ -41,7 +42,7 @@ inline void Memory::check_index(Integer index) const if (i >= vector::size()) { stringstream ss; - ss << "Memory overflow: " << i << "/" << vector::size(); + ss << T::type_string() << " memory overflow: " << i << "/" << vector::size(); throw Processor_Error(ss.str()); } #endif @@ -69,7 +70,7 @@ template inline void Memory::resize(size_t size, const char* name) { if (size > 1000) - cout << "Resizing " << T::type_string() << " " << name << " to " << size << endl; + cerr << "Resizing " << T::type_string() << " " << name << " to " << size << endl; vector::resize(size); } diff --git a/GC/Processor.h b/GC/Processor.h index 059ac1819..a05f8682f 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -13,19 +13,27 @@ using namespace std; #include "GC/Machine.h" #include "Math/Integer.h" -#include "Processor/Processor.h" +#include "Processor/ProcessorBase.h" namespace GC { template class Program; +class ExecutionStats : public map +{ +public: + ExecutionStats& operator+=(const ExecutionStats& other) + { + for (auto it : other) + (*this)[it.first] += it.second; + return *this; + } +}; + template class Processor : public ::ProcessorBase { - ifstream input_file; - string input_filename; - public: static int check_args(const vector& args, int n); @@ -43,14 +51,15 @@ class Processor : public ::ProcessorBase Memory C; Memory I; + ExecutionStats stats; + Processor(Machine& machine); ~Processor(); void reset(const Program& program, int arg); void reset(const Program& program); - void open_input_file(const string& name); - long long get_input(int n_bits); + long long get_input(int n_bits, bool interactive = false); void bitcoms(T& x, const vector& regs) { x.bitcom(S, regs); } void bitdecs(const vector& regs, const T& x) { x.bitdec(S, regs); } @@ -77,6 +86,7 @@ class Processor : public ::ProcessorBase void print_reg_signed(unsigned n_bits, Clear& value); void print_chr(int n); void print_str(int n); + void print_float(const vector& args); }; template diff --git a/GC/Processor.hpp b/GC/Processor.hpp index d66a6c56a..60ba2cb9d 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -47,22 +47,9 @@ void Processor::reset(const Program& program) } template -void GC::Processor::open_input_file(const string& name) +inline long long GC::Processor::get_input(int n_bits, bool interactive) { - cout << "opening " << name << endl; - input_file.open(name); - input_filename = name; -} - -template -inline long long GC::Processor::get_input(int n_bits) -{ - long long res; - input_file >> res; - if (input_file.eof()) - throw IO_Error("not enough inputs in " + input_filename); - if (input_file.fail()) - throw IO_Error("cannot read from " + input_filename); + long long res = ProcessorBase::get_input(interactive); check_input(res, n_bits); return res; } @@ -231,4 +218,10 @@ void Processor::print_str(int n) T::out << string((char*)&n,sizeof(n)) << flush; } +template +void Processor::print_float(const vector& args) +{ + T::out << bigint::get_float(C[args[0]], C[args[1]], C[args[2]], C[args[3]]) << flush; +} + } /* namespace GC */ diff --git a/GC/Program.cpp b/GC/Program.cpp deleted file mode 100644 index cf3185b52..000000000 --- a/GC/Program.cpp +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Program.cpp - * - */ - -#include "Instruction.hpp" -#include "Program.hpp" - -namespace GC -{ - -template class Instruction; -template class Instruction; - -template class Program; -template class Program; - -} /* namespace GC */ diff --git a/GC/Program.h b/GC/Program.h index 420c6e2c1..ed13a1455 100644 --- a/GC/Program.h +++ b/GC/Program.h @@ -29,10 +29,10 @@ class Program int offline_data_used; // Maximal register used - int max_reg[MAX_REG_TYPE]; + unsigned max_reg[MAX_REG_TYPE]; // Memory size used directly - int max_mem[MAX_REG_TYPE]; + unsigned max_mem[MAX_REG_TYPE]; // True if program contains variable-sized loop bool unknown_usage; @@ -53,10 +53,10 @@ class Program bool usage_unknown() const { return unknown_usage; } - int num_reg(RegType reg_type) const + unsigned num_reg(RegType reg_type) const { return max_reg[reg_type]; } - int direct_mem(RegType reg_type) const + unsigned direct_mem(RegType reg_type) const { return max_mem[reg_type]; } // Execute this program, updateing the processor and memory diff --git a/GC/Program.hpp b/GC/Program.hpp index 70c2e23b2..2fa19dfa4 100644 --- a/GC/Program.hpp +++ b/GC/Program.hpp @@ -7,8 +7,9 @@ #include "Secret.h" #include "ReplicatedSecret.h" +#include "config.h" -#include +#include "Tools/callgrind.h" #ifdef MAX_INLINE #include "Instruction_inline.h" @@ -118,6 +119,9 @@ BreakType Program::execute(Processor& Proc, int PC) const Proc.time = time; return DONE_BREAK; } +#ifdef COUNT_INSTRUCTIONS + Proc.stats[p[Proc.PC].get_opcode()]++; +#endif p[Proc.PC++].execute(Proc); time++; #ifdef DEBUG_COMPLEXITY diff --git a/GC/ReplicatedParty.cpp b/GC/ReplicatedParty.cpp index eab0a9967..5bfe06874 100644 --- a/GC/ReplicatedParty.cpp +++ b/GC/ReplicatedParty.cpp @@ -5,6 +5,7 @@ #include "ReplicatedParty.h" #include "Thread.h" +#include "MaliciousRepThread.h" #include "Networking/Server.h" #include "Tools/ezOptionParser.h" #include "Tools/benchmarking.h" @@ -12,9 +13,10 @@ namespace GC { -ReplicatedParty::ReplicatedParty(int argc, const char** argv) +template +ReplicatedParty::ReplicatedParty(int argc, const char** argv) : + ThreadMaster(online_opts), online_opts(opt, argc, argv) { - ez::ezOptionParser opt; opt.add( "", // Default. 1, // Required? @@ -51,11 +53,20 @@ ReplicatedParty::ReplicatedParty(int argc, const char** argv) "-u", // Flag token. "--unencrypted" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Check opening by communication instead of hashing.", // Help description. + "-c", // Flag token. + "--communication" // Flag token. + ); opt.parse(argc, argv); opt.syntax = "./replicated-bin-party.x [OPTIONS] "; if (opt.lastArgs.size() == 1) { - progname = *opt.lastArgs[0]; + this->progname = *opt.lastArgs[0]; } else { @@ -71,20 +82,52 @@ ReplicatedParty::ReplicatedParty(int argc, const char** argv) opt.get("-p")->getInt(my_num); opt.get("-pn")->getInt(pnb); opt.get("-h")->getString(hostname); - machine.use_encryption = not opt.get("-u")->isSet; + this->machine.use_encryption = not opt.get("-u")->isSet; + this->machine.more_comm_less_comp = opt.get("-c")->isSet; - if (my_num != 0) - ReplicatedSecret::out.activate(false); + T::out.activate(my_num == 0 or online_opts.interactive); - if (not machine.use_encryption) + if (not this->machine.use_encryption) insecure("unencrypted communication"); - Server* server = Server::start_networking(N, my_num, 3, hostname, pnb); + Server* server = Server::start_networking(this->N, my_num, 3, hostname, pnb); - run(); + this->run(); if (server) delete server; } +template<> +Thread* ReplicatedParty::new_thread(int i) +{ + return ThreadMaster::new_thread(i); +} + +template<> +Thread* ReplicatedParty::new_thread(int i) +{ + return new MaliciousRepThread(i, *this); +} + +template<> +void ReplicatedParty::post_run() +{ +} + +template<> +void ReplicatedParty::post_run() +{ + DataPositions usage; + for (auto thread : threads) + usage.increase(((MaliciousRepThread*)thread)->usage); + usage.print_cost(); +} + +extern template class ReplicatedSecret; +extern template class ReplicatedSecret; + +template class ReplicatedParty; +template class ReplicatedParty; + } diff --git a/GC/ReplicatedParty.h b/GC/ReplicatedParty.h index cddb2b870..35d424768 100644 --- a/GC/ReplicatedParty.h +++ b/GC/ReplicatedParty.h @@ -7,6 +7,7 @@ #define GC_REPLICATEDPARTY_H_ #include "Auth/ReplicatedMC.h" +#include "Auth/MaliciousRepMC.h" #include "ReplicatedSecret.h" #include "Processor.h" #include "Program.h" @@ -16,17 +17,26 @@ namespace GC { -class ReplicatedParty : public ThreadMaster +template +class ReplicatedParty : public ThreadMaster { + ez::ezOptionParser opt; + OnlineOptions online_opts; + public: - static Thread& s(); + static Thread& s(); ReplicatedParty(int argc, const char** argv); + + Thread* new_thread(int i); + + void post_run(); }; -inline Thread& ReplicatedParty::s() +template +inline Thread& ReplicatedParty::s() { - return Thread::s(); + return Thread::s(); } } diff --git a/GC/ReplicatedSecret.cpp b/GC/ReplicatedSecret.cpp index 2e4a8c454..3bc6331af 100644 --- a/GC/ReplicatedSecret.cpp +++ b/GC/ReplicatedSecret.cpp @@ -5,67 +5,85 @@ #include "ReplicatedSecret.h" #include "ReplicatedParty.h" +#include "MaliciousRepSecret.h" +#include "Auth/MaliciousRepMC.h" +#include "MaliciousRepThread.h" #include "Thread.h" #include "square64.h" #include "Math/Share.h" +#include "Auth/ReplicatedMC.hpp" + namespace GC { -int ReplicatedSecret::default_length = 8 * sizeof(ReplicatedSecret::value_type); +template +int ReplicatedSecret::default_length = 8 * sizeof(ReplicatedSecret::value_type); -SwitchableOutput ReplicatedSecret::out; +template +SwitchableOutput ReplicatedSecret::out; -void ReplicatedSecret::load(int n, const Integer& x) +template +void ReplicatedSecret::load(int n, const Integer& x) { if ((size_t)n < 8 * sizeof(x) and abs(x.get()) >= (1LL << n)) throw out_of_range("public value too long"); *this = x; } -void ReplicatedSecret::bitcom(Memory& S, const vector& regs) +template +void ReplicatedSecret::bitcom(Memory& S, const vector& regs) { *this = 0; for (unsigned int i = 0; i < regs.size(); i++) *this ^= (S[regs[i]] << i); } -void ReplicatedSecret::bitdec(Memory& S, const vector& regs) const +template +void ReplicatedSecret::bitdec(Memory& S, const vector& regs) const { for (unsigned int i = 0; i < regs.size(); i++) S[regs[i]] = (*this >> i) & 1; } -void ReplicatedSecret::load(vector >& accesses, - const Memory& mem) +template +void ReplicatedSecret::load(vector >& accesses, + const Memory& mem) { for (auto access : accesses) access.dest = mem[access.address]; } -void ReplicatedSecret::store(Memory& mem, - vector >& accesses) +template +void ReplicatedSecret::store(Memory& mem, + vector >& accesses) { for (auto access : accesses) mem[access.address] = access.source; } -void ReplicatedSecret::store_clear_in_dynamic(Memory& mem, +template +void ReplicatedSecret::store_clear_in_dynamic(Memory& mem, const vector& accesses) { for (auto access : accesses) mem[access.address] = access.value; } -void ReplicatedSecret::inputb(Processor& processor, +template +void ReplicatedSecret::inputb(Processor& processor, const vector& args) { - auto& party = ReplicatedParty::s(); + auto& party = ReplicatedParty::s(); party.os.resize(2); for (auto& o : party.os) o.reset_write_head(); processor.check_args(args, 3); + + InputArgList a(args); + bool interactive = party.n_interactive_inputs_from_me(a) > 0; + for (size_t i = 0; i < args.size(); i += 3) { int from = args[i]; @@ -73,12 +91,15 @@ void ReplicatedSecret::inputb(Processor& processor, if (from == party.P->my_num()) { auto& res = processor.S[args[i + 2]]; - res.prepare_input(party.os, processor.get_input(n_bits), n_bits, party.secure_prng); + res.prepare_input(party.os, processor.get_input(n_bits, interactive), n_bits, party.secure_prng); } } - party.P->send_relative(party.os); - party.P->receive_relative(party.os); + if (interactive) + cout << "Thank you" << endl; + + for (int i = 0; i < 2; i++) + party.P->pass_around(party.os[i], i + 1); for (size_t i = 0; i < args.size(); i += 3) { @@ -87,17 +108,18 @@ void ReplicatedSecret::inputb(Processor& processor, if (from != party.P->my_num()) { auto& res = processor.S[args[i + 2]]; - res.finalize_input(party, party.os[party.P->get_offset(from) == 2], from, n_bits); + res.finalize_input(party, party.os[party.P->get_offset(from) == 1], from, n_bits); } } } -ReplicatedSecret ReplicatedSecret::input(int from, Processor& processor, int n_bits) +template +U ReplicatedSecret::input(int from, Processor& processor, int n_bits) { // BMR stuff counts from 1 from--; - auto& party = ReplicatedParty::s(); - ReplicatedSecret res; + auto& party = ReplicatedParty::s(); + U res; party.os.resize(2); for (auto& o : party.os) o.reset_write_head(); @@ -114,7 +136,8 @@ ReplicatedSecret ReplicatedSecret::input(int from, Processor& return res; } -void ReplicatedSecret::prepare_input(vector& os, long input, int n_bits, PRNG& secure_prng) +template +void ReplicatedSecret::prepare_input(vector& os, long input, int n_bits, PRNG& secure_prng) { randomize_to_sum(input, secure_prng); *this &= get_mask(n_bits); @@ -122,36 +145,18 @@ void ReplicatedSecret::prepare_input(vector& os, long input, int n_ BitVec(get_mask(n_bits) & (*this)[i]).pack(os[i], n_bits); } -void ReplicatedSecret::finalize_input(Thread& party, octetStream& o, int from, int n_bits) +template +void ReplicatedSecret::finalize_input(Thread& party, octetStream& o, int from, int n_bits) { int j = party.P->get_offset(from) == 2; (*this)[j] = BitVec::unpack_new(o, n_bits); (*this)[1 - j] = 0; } -void ReplicatedSecret::and_(Processor& processor, - const vector& args, bool repeat) -{ - auto& party = ReplicatedParty::s(); - assert(party.P->num_players() == 3); - vector& os = party.os; - os.resize(2); - for (auto& o : os) - o.reset_write_head(); - processor.check_args(args, 4); - for (size_t i = 0; i < args.size(); i += 4) - processor.S[args[i + 1]].prepare_and(os, args[i], - processor.S[args[i + 2]], processor.S[args[i + 3]], - party, repeat); - party.P->send_relative(os); - party.P->receive_relative(os); - for (size_t i = 0; i < args.size(); i += 4) - processor.S[args[i + 1]].finalize_andrs(os, args[i]); -} - -inline void ReplicatedSecret::prepare_and(vector& os, int n, - const ReplicatedSecret& x, const ReplicatedSecret& y, - Thread& party, bool repeat) +template<> +inline void ReplicatedSecret::prepare_and(vector& os, int n, + const ReplicatedSecret& x, const ReplicatedSecret& y, + Thread& party, bool repeat) { ReplicatedSecret y_ext; if (repeat) @@ -169,8 +174,34 @@ inline void ReplicatedSecret::prepare_and(vector& os, int n, BitVec(mask & (*this)[0]).pack(os[0], n); } -void ReplicatedSecret::and_(int n, const ReplicatedSecret& x, - const ReplicatedSecret& y, bool repeat) +template<> +inline void ReplicatedSecret::finalize_andrs( + vector& os, int n) +{ + (*this)[1].unpack(os[1], n); +} + +template<> +void ReplicatedSecret::andrs(int n, + const ReplicatedSecret& x, + const ReplicatedSecret& y) +{ + auto& party = Thread::s(); + assert(party.P->num_players() == 3); + vector& os = party.os; + os.resize(2); + for (auto& o : os) + o.reset_write_head(); + prepare_and(os, n, x, y, party, true); + party.P->send_relative(os); + party.P->receive_relative(os); + finalize_andrs(os, n); +} + +template<> +void ReplicatedSecret::and_(int n, + const ReplicatedSecret& x, + const ReplicatedSecret& y, bool repeat) { if (repeat) andrs(n, x, y); @@ -178,27 +209,46 @@ void ReplicatedSecret::and_(int n, const ReplicatedSecret& x, throw runtime_error("call static ReplicatedSecret::ands()"); } -void ReplicatedSecret::andrs(int n, const ReplicatedSecret& x, - const ReplicatedSecret& y) +template<> +void ReplicatedSecret::and_(int n, + const ReplicatedSecret& x, + const ReplicatedSecret& y, bool repeat) +{ + (void)n, (void)x, (void)y, (void)repeat; + throw runtime_error("use static method"); +} + +template<> +void ReplicatedSecret::and_(Processor& processor, + const vector& args, bool repeat) { - auto& party = ReplicatedParty::s(); + auto& party = Thread::s(); assert(party.P->num_players() == 3); vector& os = party.os; os.resize(2); for (auto& o : os) o.reset_write_head(); - prepare_and(os, n, x, y, party, true); + processor.check_args(args, 4); + for (size_t i = 0; i < args.size(); i += 4) + processor.S[args[i + 1]].prepare_and(os, args[i], + processor.S[args[i + 2]], processor.S[args[i + 3]], + party, repeat); party.P->send_relative(os); party.P->receive_relative(os); - finalize_andrs(os, n); + for (size_t i = 0; i < args.size(); i += 4) + processor.S[args[i + 1]].finalize_andrs(os, args[i]); } -inline void ReplicatedSecret::finalize_andrs(vector& os, int n) +template<> +void ReplicatedSecret::and_( + Processor& processor, const vector& args, + bool repeat) { - (*this)[1].unpack(os[1], n); + MaliciousRepThread::s().and_(processor, args, repeat); } -void ReplicatedSecret::trans(Processor& processor, +template +void ReplicatedSecret::trans(Processor& processor, int n_outputs, const vector& args) { assert(length == 2); @@ -213,19 +263,21 @@ void ReplicatedSecret::trans(Processor& processor, } } -void ReplicatedSecret::reveal(Clear& x) +template +void ReplicatedSecret::reveal(Clear& x) { ReplicatedSecret share = *this; vector opened; - auto& party = ReplicatedParty::s(); - party.MC.POpen_Begin(opened, {share}, *party.P); - party.MC.POpen_End(opened, {share}, *party.P); + auto& party = ReplicatedParty::s(); + party.MC->POpen_Begin(opened, {share}, *party.P); + party.MC->POpen_End(opened, {share}, *party.P); x = IntBase(opened[0]); } -void ReplicatedSecret::random_bit() +template<> +void ReplicatedSecret::random_bit() { - auto& party = ReplicatedParty::s(); + auto& party = ReplicatedParty::s(); *this = party.secure_prng.get_bit(); octetStream o; (*this)[0].pack(o, 1); @@ -233,4 +285,15 @@ void ReplicatedSecret::random_bit() (*this)[1].unpack(o, 1); } +template<> +void ReplicatedSecret::random_bit() +{ + MaliciousRepSecret res; + MaliciousRepThread::s().DataF.get_one(DATA_BIT, res); + *this = res; +} + +template class ReplicatedSecret; +template class ReplicatedSecret; + } diff --git a/GC/ReplicatedSecret.h b/GC/ReplicatedSecret.h index c086886dd..96326321a 100644 --- a/GC/ReplicatedSecret.h +++ b/GC/ReplicatedSecret.h @@ -26,16 +26,14 @@ class Processor; template class Thread; +template class ReplicatedSecret : public FixedVec { typedef FixedVec super; public: - typedef ReplicatedSecret DynamicType; - typedef BitVec clear; - typedef ReplicatedMC MC; typedef void Inp; typedef void PO; typedef ReplicatedBase Protocol; @@ -46,27 +44,27 @@ class ReplicatedSecret : public FixedVec static int default_length; static SwitchableOutput out; - static void store_clear_in_dynamic(Memory& mem, + static void store_clear_in_dynamic(Memory& mem, const vector& accesses); - static void load(vector< ReadAccess >& accesses, const Memory& mem); - static void store(Memory& mem, vector< WriteAccess >& accesses); + static void load(vector< ReadAccess >& accesses, const Memory& mem); + static void store(Memory& mem, vector< WriteAccess >& accesses); - static void andrs(Processor& processor, const vector& args) + static void andrs(Processor& processor, const vector& args) { and_(processor, args, true); } - static void ands(Processor& processor, const vector& args) + static void ands(Processor& processor, const vector& args) { and_(processor, args, false); } - static void and_(Processor& processor, const vector& args, bool repeat); - static void inputb(Processor& processor, const vector& args); + static void and_(Processor& processor, const vector& args, bool repeat); + static void inputb(Processor& processor, const vector& args); - static void trans(Processor& processor, int n_outputs, + static void trans(Processor& processor, int n_outputs, const vector& args); static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); } - static ReplicatedSecret input(int from, Processor& processor, int n_bits); + static U input(int from, Processor& processor, int n_bits); void prepare_input(vector& os, long input, int n_bits, PRNG& secure_prng); - void finalize_input(Thread& party, octetStream& o, int from, int n_bits); + void finalize_input(Thread& party, octetStream& o, int from, int n_bits); ReplicatedSecret() {} template @@ -74,8 +72,8 @@ class ReplicatedSecret : public FixedVec void load(int n, const Integer& x); - void bitcom(Memory& S, const vector& regs); - void bitdec(Memory& S, const vector& regs) const; + void bitcom(Memory& S, const vector& regs); + void bitdec(Memory& S, const vector& regs) const; void xor_(int n, const ReplicatedSecret& x, const ReplicatedSecret& y) { *this = x ^ y; (void)n; } @@ -83,7 +81,7 @@ class ReplicatedSecret : public FixedVec void andrs(int n, const ReplicatedSecret& x, const ReplicatedSecret& y); void prepare_and(vector& os, int n, const ReplicatedSecret& x, const ReplicatedSecret& y, - Thread& party, bool repeat); + Thread& party, bool repeat); void finalize_andrs(vector& os, int n); void reveal(Clear& x); @@ -91,6 +89,21 @@ class ReplicatedSecret : public FixedVec void random_bit(); }; + +class SemiHonestRepSecret : public ReplicatedSecret +{ + typedef ReplicatedSecret super; + +public: + typedef SemiHonestRepSecret DynamicType; + + typedef ReplicatedMC MC; + + SemiHonestRepSecret() {} + template + SemiHonestRepSecret(const T& other) : super(other) {} +}; + } #endif /* GC_REPLICATEDSECRET_H_ */ diff --git a/GC/Secret.h b/GC/Secret.h index 55dface2f..2d111d1eb 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -13,7 +13,6 @@ #include "GC/Clear.h" #include "GC/Memory.h" #include "GC/Access.h" -#include "GC/Processor.h" #include "Math/Share.h" @@ -64,6 +63,8 @@ class SpdzShare : public Share { Share::assign(value, first_player ? 0 : 1, mac_key); } }; +template class Processor; + template class Secret { @@ -79,7 +80,7 @@ class Secret #endif // dummy - typedef MAC_Check_Base > MC; + typedef DummyMC MC; typedef DummyProtocol Protocol; static string type_string() { return "evaluation secret"; } diff --git a/GC/Thread.h b/GC/Thread.h index b29627533..a10267d50 100644 --- a/GC/Thread.h +++ b/GC/Thread.h @@ -9,6 +9,7 @@ #include "Networking/Player.h" #include "Tools/random.h" #include "Processor.h" +#include "ArgTuples.h" namespace GC { @@ -20,6 +21,8 @@ struct ScheduleItem ScheduleItem(int tape = 0, int arg = 0) : tape(tape), arg(arg) {} }; +template class ThreadMaster; + template class Thread { @@ -28,9 +31,10 @@ class Thread static void* run_thread(void* thread); public: + ThreadMaster& master; Machine& machine; Processor processor; - typename T::MC MC; + typename T::MC* MC; typename T::Protocol* protocol; Names& N; Player* P; @@ -44,9 +48,11 @@ class Thread static Thread& s(); - Thread(int thread_num, Machine& machine, Names& N); + Thread(int thread_num, ThreadMaster& master); virtual ~Thread(); + virtual typename T::MC* new_mc() { return new typename T::MC; } + void run(); virtual void pre_run() {} virtual void run(Program& program); @@ -54,6 +60,8 @@ class Thread void join_tape(); void finish(); + + int n_interactive_inputs_from_me(InputArgList& args); }; template diff --git a/GC/Thread.hpp b/GC/Thread.hpp index 3832c8f41..bd2c7f54d 100644 --- a/GC/Thread.hpp +++ b/GC/Thread.hpp @@ -25,8 +25,9 @@ void* Thread::run_thread(void* thread) } template -Thread::Thread(int thread_num, Machine& machine, Names& N) : - machine(machine), processor(machine), protocol(0), N(N), P(0), +Thread::Thread(int thread_num, ThreadMaster& master) : + master(master), machine(master.machine), processor(machine), + protocol(0), N(master.N), P(0), thread_num(thread_num) { pthread_create(&thread, 0, run_thread, this); @@ -35,6 +36,8 @@ Thread::Thread(int thread_num, Machine& machine, Names& N) : template Thread::~Thread() { + if (MC) + delete MC; if (P) delete P; if (protocol) @@ -53,8 +56,8 @@ void Thread::run() else P = new PlainPlayer(N, thread_num << 16); protocol = new typename T::Protocol(*P); - string input_file = "Player-Data/Input-P" + to_string(N.my_num()) + "-" + to_string(thread_num); - processor.open_input_file(input_file); + MC = this->new_mc(); + processor.open_input_file(N.my_num(), thread_num); done.push(0); pre_run(); @@ -67,7 +70,7 @@ void Thread::run() } post_run(); - MC.Check(*P); + MC->Check(*P); } template @@ -91,4 +94,16 @@ void Thread::finish() pthread_join(thread, 0); } + +template +int GC::Thread::n_interactive_inputs_from_me(InputArgList& args) +{ + int res = 0; + if (thread_num == 0 and master.opts.interactive) + res = args.n_inputs_from(P->my_num()); + if (res > 0) + cout << "Please enter " << res << " numbers:" << endl; + return res; +} + } /* namespace GC */ diff --git a/GC/ThreadMaster.h b/GC/ThreadMaster.h index 10b71c84a..285c12a49 100644 --- a/GC/ThreadMaster.h +++ b/GC/ThreadMaster.h @@ -9,6 +9,8 @@ #include "Thread.h" #include "Program.h" +#include "Processor/OnlineOptions.h" + namespace GC { @@ -45,9 +47,11 @@ class ThreadMaster : public ThreadMasterBase Machine machine; Memory memory; + OnlineOptions& opts; + static ThreadMaster& s(); - ThreadMaster(); + ThreadMaster(OnlineOptions& opts); virtual ~ThreadMaster() {} void run_tape(int thread_number, int tape_number, int arg); @@ -56,6 +60,8 @@ class ThreadMaster : public ThreadMasterBase virtual Thread* new_thread(int i); void run(); + + virtual void post_run() {} }; } /* namespace GC */ diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index e2a9eba3e..330630fd9 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -9,6 +9,8 @@ #include "ReplicatedSecret.h" #include "Secret.h" +#include "instructions.h" + namespace GC { @@ -25,7 +27,8 @@ ThreadMaster& ThreadMaster::s() } template -ThreadMaster::ThreadMaster() : P(0), machine(memory) +ThreadMaster::ThreadMaster(OnlineOptions& opts) : + P(0), machine(memory), opts(opts) { if (singleton) throw runtime_error("there can only be one"); @@ -47,13 +50,13 @@ void ThreadMaster::join_tape(int thread_number) template Thread* ThreadMaster::new_thread(int i) { - return new Thread(i, machine, N); + return new Thread(i, *this); } template void ThreadMaster::run() { - P = new PlainPlayer(N, 1 << 24); + P = new PlainPlayer(N, 0xff << 24); machine.load_schedule(progname); for (int i = 0; i < machine.nthreads; i++) @@ -73,11 +76,30 @@ void ThreadMaster::run() vector os(P->num_players()); P->Broadcast_Receive(os); + post_run(); + + NamedCommStats stats = P->comm_stats; + ExecutionStats exe_stats; for (auto thread : threads) + { + stats += thread->P->comm_stats; + exe_stats += thread->processor.stats; delete thread; + } delete P; + for (auto it : exe_stats) + switch (it.first) + { +#define X(NAME, CODE) case NAME: cerr << it.second << " " #NAME << endl; break; + INSTRUCTIONS + } + + for (auto it = stats.begin(); it != stats.end(); it++) + if (it->second.data > 0) + cerr << it->first << " " << 1e-6 * it->second.data << " MB" << endl; + cerr << "Time = " << timer.elapsed() << endl; } diff --git a/GC/config.h b/GC/config.h index c6d309fe9..7c80f42ac 100644 --- a/GC/config.h +++ b/GC/config.h @@ -8,4 +8,8 @@ //#define CHECK_SIZE +//#define COUNT_INSTRUCTIONS + +//#define CHECK_SIZE + #endif /* GC_CONFIG_H_ */ diff --git a/GC/instructions.h b/GC/instructions.h index d206ebf25..b8551e0f9 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -6,19 +6,19 @@ #ifndef GC_INSTRUCTIONS_H_ #define GC_INSTRUCTIONS_H_ -#include +#include "Tools/callgrind.h" -#define P processor +#define PROC processor #define INST instruction -#define M processor.machine +#define MACH processor.machine #define R0 instruction.get_r(0) #define R1 instruction.get_r(1) #define R2 instruction.get_r(2) #define S0 processor.S[instruction.get_r(0)] -#define S1 processor.S[instruction.get_r(1)] -#define S2 processor.S[instruction.get_r(2)] +#define PS1 processor.S[instruction.get_r(1)] +#define PS2 processor.S[instruction.get_r(2)] #define C0 processor.C[instruction.get_r(0)] #define C1 processor.C[instruction.get_r(1)] @@ -28,88 +28,89 @@ #define I1 processor.I[instruction.get_r(1)] #define I2 processor.I[instruction.get_r(2)] -#define N instruction.get_n() +#define IMM instruction.get_n() #define EXTRA instruction.get_start() -#define MSD M.MS[N] -#define MMC M.MC[N] -#define MID M.MI[N] +#define MSD MACH.MS[IMM] +#define MMC MACH.MC[IMM] +#define MID MACH.MI[IMM] -#define MSI M.MS[I1.get()] -#define MII M.MI[I1.get()] - -#define MD M.MD +#define MSI MACH.MS[I1.get()] +#define MII MACH.MI[I1.get()] #define INSTRUCTIONS \ - X(XORS, P.xors(EXTRA)) \ + X(XORS, PROC.xors(EXTRA)) \ X(XORC, C0.xor_(C1, C2)) \ - X(XORCI, C0.xor_(C1, N)) \ - X(ANDRS, T::andrs(P, EXTRA)) \ - X(ANDS, T::ands(P, EXTRA)) \ - X(INPUTB, T::inputb(P, EXTRA)) \ + X(XORCI, C0.xor_(C1, IMM)) \ + X(ANDRS, T::andrs(PROC, EXTRA)) \ + X(ANDS, T::ands(PROC, EXTRA)) \ + X(INPUTB, T::inputb(PROC, EXTRA)) \ X(ADDC, C0 = C1 + C2) \ - X(ADDCI, C0 = C1 + N) \ - X(MULCI, C0 = C1 * N) \ - X(BITDECS, P.bitdecs(EXTRA, S0)) \ - X(BITCOMS, P.bitcoms(S0, EXTRA)) \ - X(BITDECC, P.bitdecc(EXTRA, C0)) \ - X(BITDECINT, P.bitdecint(EXTRA, I0)) \ - X(SHRCI, C0 = C1 >> N) \ - X(LDBITS, S0.load(R1, N)) \ + X(ADDCI, C0 = C1 + IMM) \ + X(MULCI, C0 = C1 * IMM) \ + X(BITDECS, PROC.bitdecs(EXTRA, S0)) \ + X(BITCOMS, PROC.bitcoms(S0, EXTRA)) \ + X(BITDECC, PROC.bitdecc(EXTRA, C0)) \ + X(BITDECINT, PROC.bitdecint(EXTRA, I0)) \ + X(SHRCI, C0 = C1 >> IMM) \ + X(SHLCI, C0 = C1 << IMM) \ + X(LDBITS, S0.load(R1, IMM)) \ X(LDMS, S0 = MSD) \ X(STMS, MSD = S0) \ X(LDMSI, S0 = MSI) \ X(STMSI, MSI = S0) \ X(LDMC, C0 = MMC) \ X(STMC, MMC = C0) \ - X(LDMSD, P.load_dynamic_direct(EXTRA)) \ - X(STMSD, P.store_dynamic_direct(EXTRA)) \ - X(LDMSDI, P.load_dynamic_indirect(EXTRA)) \ - X(STMSDI, P.store_dynamic_indirect(EXTRA)) \ - X(STMSDCI, P.store_clear_in_dynamic(EXTRA)) \ - X(CONVSINT, S0.load(N, I1)) \ + X(LDMSD, PROC.load_dynamic_direct(EXTRA)) \ + X(STMSD, PROC.store_dynamic_direct(EXTRA)) \ + X(LDMSDI, PROC.load_dynamic_indirect(EXTRA)) \ + X(STMSDI, PROC.store_dynamic_indirect(EXTRA)) \ + X(STMSDCI, PROC.store_clear_in_dynamic(EXTRA)) \ + X(CONVSINT, S0.load(IMM, I1)) \ X(CONVCINT, C0 = I1) \ - X(MOVS, S0 = S1) \ - X(TRANS, T::trans(P, N, EXTRA)) \ - X(BIT, P.random_bit(S0)) \ - X(REVEAL, S1.reveal(C0)) \ - X(PRINTREG, P.print_reg(R0, N)) \ - X(PRINTREGPLAIN, P.print_reg_plain(C0)) \ - X(PRINTREGSIGNED, P.print_reg_signed(N, C0)) \ - X(PRINTCHR, P.print_chr(N)) \ - X(PRINTSTR, P.print_str(N)) \ - X(LDINT, I0 = int(N)) \ + X(MOVS, S0 = PS1) \ + X(TRANS, T::trans(PROC, IMM, EXTRA)) \ + X(BIT, PROC.random_bit(S0)) \ + X(REVEAL, PS1.reveal(C0)) \ + X(PRINTREG, PROC.print_reg(R0, IMM)) \ + X(PRINTREGPLAIN, PROC.print_reg_plain(C0)) \ + X(PRINTREGSIGNED, PROC.print_reg_signed(IMM, C0)) \ + X(PRINTCHR, PROC.print_chr(IMM)) \ + X(PRINTSTR, PROC.print_str(IMM)) \ + X(PRINTFLOATPLAIN, PROC.print_float(EXTRA)) \ + X(CONDPRINTSTR, if(C0.get()) PROC.print_str(IMM)) \ + X(LDINT, I0 = int(IMM)) \ X(ADDINT, I0 = I1 + I2) \ X(SUBINT, I0 = I1 - I2) \ X(MULINT, I0 = I1 * I2) \ X(DIVINT, I0 = I1 / I2) \ - X(JMP, P.PC += N) \ - X(JMPNZ, if (I0 != 0) P.PC += N) \ - X(JMPEQZ, if (I0 == 0) P.PC += N) \ + X(JMP, PROC.PC += IMM) \ + X(JMPNZ, if (I0 != 0) PROC.PC += IMM) \ + X(JMPEQZ, if (I0 == 0) PROC.PC += IMM) \ X(EQZC, I0 = I1 == 0) \ X(LTZC, I0 = I1 < 0) \ X(LTC, I0 = I1 < I2) \ X(GTC, I0 = I1 > I2) \ X(EQC, I0 = I1 == I2) \ - X(JMPI, P.PC += I0) \ + X(JMPI, PROC.PC += I0) \ X(LDMINT, I0 = MID) \ X(STMINT, MID = I0) \ X(LDMINTI, I0 = MII) \ X(STMINTI, MII = I0) \ - X(PUSHINT, P.pushi(I0.get())) \ - X(POPINT, long x; P.popi(x); I0 = x) \ + X(PUSHINT, PROC.pushi(I0.get())) \ + X(POPINT, long x; PROC.popi(x); I0 = x) \ X(MOVINT, I0 = I1) \ - X(LDARG, I0 = P.get_arg()) \ - X(STARG, P.set_arg(I0.get())) \ - X(TIME, M.time()) \ - X(START, M.start(N)) \ - X(STOP, M.stop(N)) \ + X(LDARG, I0 = PROC.get_arg()) \ + X(STARG, PROC.set_arg(I0.get())) \ + X(TIME, MACH.time()) \ + X(START, MACH.start(IMM)) \ + X(STOP, MACH.stop(IMM)) \ X(GLDMS, ) \ X(GLDMC, ) \ X(PRINTINT, S0.out << I0) \ X(STARTGRIND, CALLGRIND_START_INSTRUMENTATION) \ X(STOPGRIND, CALLGRIND_STOP_INSTRUMENTATION) \ - X(RUN_TAPE, M.run_tape(R0, N, R1)) \ - X(JOIN_TAPE, M.join_tape(R0)) \ + X(RUN_TAPE, MACH.run_tape(R0, IMM, R1)) \ + X(JOIN_TAPE, MACH.join_tape(R0)) \ #endif /* GC_INSTRUCTIONS_H_ */ diff --git a/Makefile b/Makefile index dc2a10768..f1e3e80db 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,7 @@ endif COMMON = $(MATH) $(TOOLS) $(NETWORK) $(AUTH) COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT) YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) $(GC) -BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp)) $(COMMON) $(PROCESSOR) $(GC) +BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp)) $(COMMON) $(PROCESSOR) LIB = libSPDZ.a @@ -39,7 +39,7 @@ OBJS = $(BMR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(YAO) $(COMPLETE) DEPS := $(OBJS:.o=.d) -all: gen_input online offline externalIO yao replicated-bin-party.x replicated-ring-party.x +all: gen_input online offline externalIO yao replicated ifeq ($(USE_GF2N_LONG),1) all: bmr @@ -70,10 +70,20 @@ she-offline: Check-Offline.x spdz2-offline.x overdrive: simple-offline.x pairwise-offline.x cnc-offline.x +rep-field: malicious-rep-bin-party.x replicated-field-party.x Setup.x + +rep-ring: replicated-ring-party.x Fake-Offline.x + +rep-bin: replicated-bin-party.x malicious-rep-bin-party.x Fake-Offline.x + +replicated: rep-field rep-ring rep-bin + +tldr: malicious-rep-field-party.x Setup.x + Fake-Offline.x: Fake-Offline.cpp $(COMMON) $(PROCESSOR) - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + $(CXX) $(CFLAGS) -o $@ Fake-Offline.cpp $(COMMON) $(PROCESSOR) $(LDLIBS) -Check-Offline.x: Check-Offline.cpp $(COMMON) $(PROCESSOR) +Check-Offline.x: Check-Offline.cpp $(COMMON) $(PROCESSOR) Auth/fake-stuff.hpp $(CXX) $(CFLAGS) Check-Offline.cpp -o Check-Offline.x $(COMMON) $(PROCESSOR) $(LDLIBS) Server.x: Server.cpp $(COMMON) @@ -82,6 +92,9 @@ Server.x: Server.cpp $(COMMON) Player-Online.x: Player-Online.cpp $(COMMON) $(PROCESSOR) $(CXX) $(CFLAGS) Player-Online.cpp -o Player-Online.x $(COMMON) $(PROCESSOR) $(LDLIBS) +Setup.x: Setup.cpp $(COMMON) + $(CXX) $(CFLAGS) Setup.cpp -o Setup.x $(COMMON) $(LDLIBS) + ifeq ($(USE_GF2N_LONG),1) ot.x: $(OT) $(COMMON) OT/OText_main.cpp $(LIBSIMPLEOT) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) @@ -106,7 +119,7 @@ gen_input_fp.x: Scripts/gen_input_fp.cpp $(COMMON) $(CXX) $(CFLAGS) Scripts/gen_input_fp.cpp -o gen_input_fp.x $(COMMON) $(LDLIBS) gc-emulate.x: $(GC) $(COMMON) $(PROCESSOR) gc-emulate.cpp $(GC) - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(BOOST) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) ifeq ($(USE_GF2N_LONG),1) bmr-program-party.x: $(BMR) bmr-program-party.cpp @@ -152,11 +165,20 @@ galois-degree.x: $(COMMON) galois-degree.cpp $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) replicated-bin-party.x: $(COMMON) $(GC) replicated-bin-party.cpp - $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) $(BOOST) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + +malicious-rep-bin-party.x: $(COMMON) $(GC) malicious-rep-bin-party.cpp + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) replicated-ring-party.x: replicated-ring-party.cpp $(PROCESSOR) $(COMMON) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) +replicated-field-party.x: replicated-field-party.cpp $(PROCESSOR) $(COMMON) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + +malicious-rep-field-party.x: malicious-rep-field-party.cpp $(PROCESSOR) $(COMMON) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + $(LIBSIMPLEOT): SimpleOT/Makefile $(MAKE) -C SimpleOT @@ -166,5 +188,18 @@ OT/BaseOT.o: SimpleOT/Makefile SimpleOT/Makefile: git submodule update --init SimpleOT +.PHONY: mpir +mpir: + git submodule update --init mpir + cd mpir; \ + libtoolize --force; \ + aclocal; \ + autoheader; \ + automake --force-missing --add-missing; \ + autoconf; \ + ./configure --enable-cxx; + $(MAKE) -C mpir + sudo $(MAKE) -C mpir install + clean: -rm */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o diff --git a/Math/BitVec.h b/Math/BitVec.h index 784b7d7ab..fe3786a2c 100644 --- a/Math/BitVec.h +++ b/Math/BitVec.h @@ -7,12 +7,16 @@ #define MATH_BITVEC_H_ #include "Integer.h" +#include "field_types.h" class BitVec : public IntBase { public: static const int n_bits = sizeof(a) * 8; + static char type_char() { return 'B'; } + static DataFieldType field_type() { return DATA_GF2; } + BitVec() {} BitVec(long a) : IntBase(a) {} BitVec(const IntBase& a) : IntBase(a) {} @@ -24,6 +28,9 @@ class BitVec : public IntBase BitVec& operator+=(const BitVec& other) { *this ^= other; return *this; } BitVec extend_bit() const { return -(a & 1); } + BitVec mask(int n) const { return n < n_bits ? *this & ((1L << n) - 1) : *this; } + + void mul(const BitVec& a, const BitVec& b) { *this = a * b; } void pack(octetStream& os, int n = n_bits) const { os.store_int(a, DIV_CEIL(n, 8)); } void unpack(octetStream& os, int n = n_bits) { a = os.get_int(DIV_CEIL(n, 8)); } diff --git a/Math/FixedVec.h b/Math/FixedVec.h index 52a8f75a1..f4f981472 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -35,13 +35,13 @@ class FixedVec { return T::type_string() + "^" + to_string(L); } - static char type_char() + static string type_short() { - return T::type_char(); + return string(1, T::type_char()); } static DataFieldType field_type() { - return DATA_MODP; + return T::field_type(); } FixedVec(const T& other = 0) @@ -71,8 +71,8 @@ class FixedVec } void assign(const char* buffer) { - for (auto& x : v) - x.assign(buffer); + for (int i = 0; i < L; i++) + v[i].assign(buffer + i * T::size()); } void assign_zero() @@ -144,6 +144,21 @@ class FixedVec return res; } + FixedVecoperator*(const FixedVec& other) const + { + FixedVec res; + res.mul(*this, other); + return res; + } + + FixedVecoperator/(const FixedVec& other) const + { + FixedVec res; + for (int i = 0; i < L; i++) + res[i] = v[i] / other[i]; + return res; + } + FixedVecoperator^(const FixedVec& other) const { FixedVec res; @@ -166,6 +181,12 @@ class FixedVec return *this; } + FixedVec& operator/=(const FixedVec& other) + { + *this = *this / other; + return *this; + } + FixedVec& operator^=(const FixedVec& other) { for (int i = 0; i < L; i++) @@ -196,6 +217,12 @@ class FixedVec return res; } + FixedVec& operator>>=(int i) + { + *this = *this >> i; + return *this; + } + T sum() const { T res = 0; @@ -212,6 +239,14 @@ class FixedVec return res; } + FixedVec mask(int n_bits) const + { + FixedVec res; + for (int i = 0; i < L; i++) + res[i] = v[i].mask(n_bits); + return res; + } + void randomize(PRNG& G) { for (auto& x : v) @@ -251,6 +286,12 @@ class FixedVec } }; +template +FixedVec operator*(const T& a, const FixedVec& b) +{ + return b * a; +} + template ostream& operator<<(ostream& os, const FixedVec& v) { diff --git a/Math/Integer.cpp b/Math/Integer.cpp index 707ec0772..8cf9a6d0e 100644 --- a/Math/Integer.cpp +++ b/Math/Integer.cpp @@ -31,3 +31,18 @@ void to_signed_bigint(bigint& res, const Integer& x, int n) if (x < 0) res.negate(); } + +void Integer::reqbl(int n) +{ + if ((int)n < 0 && size() * 8 != -(int)n) + { + throw Processor_Error( + "Program compiled for rings of length " + to_string(-(int)n) + + " but VM supports only " + + to_string(size() * 8)); + } + else if ((int)n > 0) + { + throw Processor_Error("Program compiled for fields not rings"); + } +} diff --git a/Math/Integer.h b/Math/Integer.h index d8884544f..75055a1f9 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -12,6 +12,7 @@ using namespace std; #include "Tools/octetStream.h" #include "Tools/random.h" #include "bigint.h" +#include "field_types.h" // Functionality shared between integers and bit vectors @@ -30,9 +31,16 @@ class IntBase long get() const { return a; } bool get_bit(int i) const { return (a >> i) & 1; } + unsigned long debug() const { return a; } + void assign(long x) { *this = x; } void assign(const char* buffer) { avx_memcpy(&a, buffer, sizeof(a)); } void assign_zero() { a = 0; } + void assign_one() { a = 1; } + + bool is_zero() const { return a == 0; } + bool is_one() const { return a == 1; } + bool is_bit() const { return is_zero() or is_one(); } long operator>>(const IntBase& other) const { return a >> other.a; } long operator<<(const IntBase& other) const { return a << other.a; } @@ -44,6 +52,8 @@ class IntBase bool operator==(const IntBase& other) const { return a == other.a; } bool operator!=(const IntBase& other) const { return a != other.a; } + bool equal(const IntBase& other) const { return *this == other; } + long operator^=(const IntBase& other) { return a ^= other.a; } long operator&=(const IntBase& other) { return a &= other.a; } @@ -67,6 +77,9 @@ class Integer : public IntBase typedef Integer clear; static char type_char() { return 'R'; } + static DataFieldType field_type() { return DATA_INT64; } + + static void reqbl(int n); Integer() { a = 0; } Integer(long a) : IntBase(a) {} @@ -74,12 +87,6 @@ class Integer : public IntBase void convert_destroy(bigint& other) { *this = other.get_si(); } - void assign_one() { a = 1; } - - bool is_zero() const { return a == 0; } - bool is_one() const { return a == 1; } - bool is_bit() const { return is_zero() or is_one(); } - long operator+(const Integer& other) const { return a + other.a; } long operator-(const Integer& other) const { return a - other.a; } long operator*(const Integer& other) const { return a * other.a; } @@ -126,6 +133,11 @@ inline void to_bigint(bigint& res, const Integer& x) res = (unsigned long)x.get(); } +inline void to_signed_bigint(bigint& res, const Integer& x) +{ + res = x.get(); +} + void to_signed_bigint(bigint& res, const Integer& x, int n); // slight misnomer diff --git a/Math/MaliciousRep3Share.h b/Math/MaliciousRep3Share.h new file mode 100644 index 000000000..815ef2c83 --- /dev/null +++ b/Math/MaliciousRep3Share.h @@ -0,0 +1,41 @@ +/* + * MaliciousRep3Share.h + * + */ + +#ifndef MATH_MALICIOUSREP3SHARE_H_ +#define MATH_MALICIOUSREP3SHARE_H_ + +#include "Rep3Share.h" +#include "gfp.h" + +template class HashMaliciousRepMC; +template class Beaver; + +template +class MaliciousRep3Share : public Rep3Share +{ + typedef Rep3Share super; + +public: + typedef Beaver> Protocol; + typedef HashMaliciousRepMC> MAC_Check; + typedef MAC_Check Direct_MC; + typedef ReplicatedInput> Input; + typedef ReplicatedPrivateOutput> PrivateOutput; + + static string type_short() + { + return "M" + string(1, gfp::type_char()); + } + + MaliciousRep3Share() + { + } + template + MaliciousRep3Share(const U& other) : super(other) + { + } +}; + +#endif /* MATH_MALICIOUSREP3SHARE_H_ */ diff --git a/Math/Rep3Share.h b/Math/Rep3Share.h index ea414e64b..298b84e2e 100644 --- a/Math/Rep3Share.h +++ b/Math/Rep3Share.h @@ -10,10 +10,11 @@ #include "Math/Integer.h" #include "Processor/Replicated.h" -class Rep3Share: public FixedVec +template +class Rep3Share : public FixedVec { public: - typedef Integer clear; + typedef T clear; typedef Replicated Protocol; typedef ReplicatedMC MAC_Check; @@ -21,33 +22,37 @@ class Rep3Share: public FixedVec typedef ReplicatedInput Input; typedef ReplicatedPrivateOutput PrivateOutput; - static char type_char() + static string type_short() { - return clear::type_char(); + return "R" + string(1, clear::type_char()); + } + static string type_string() + { + return "replicated " + T::type_string(); } Rep3Share() { } - Rep3Share(const FixedVec& other) + Rep3Share(const FixedVec& other) { - FixedVec::operator=(other); + FixedVec::operator=(other); } - Rep3Share(Integer value, int my_num) + Rep3Share(T value, int my_num) { Replicated::assign(*this, value, my_num); } // Share compatibility - void assign(clear other, int my_num, const Integer& alphai) + void assign(clear other, int my_num, const T& alphai) { (void)alphai; *this = Rep3Share(other, my_num); } void assign(const char* buffer) { - FixedVec::assign(buffer); + FixedVec::assign(buffer); } void add(const Rep3Share& x, const Rep3Share& y) @@ -60,33 +65,39 @@ class Rep3Share: public FixedVec } void add(const Rep3Share& S, const clear aa, int my_num, - const Integer& alphai) + const T& alphai) { (void)alphai; *this = S + Rep3Share(aa, my_num); } void sub(const Rep3Share& S, const clear& aa, int my_num, - const Integer& alphai) + const T& alphai) { (void)alphai; *this = S - Rep3Share(aa, my_num); } void sub(const clear& aa, const Rep3Share& S, int my_num, - const Integer& alphai) + const T& alphai) { (void)alphai; *this = Rep3Share(aa, my_num) - S; } + void mul_by_bit(const Rep3Share& x, const T& y) + { + (void) x, (void) y; + throw not_implemented(); + } + void pack(octetStream& os, bool full = true) const { (void)full; - FixedVec::pack(os); + FixedVec::pack(os); } void unpack(octetStream& os, bool full = true) { (void)full; - FixedVec::unpack(os); + FixedVec::unpack(os); } }; diff --git a/Math/Share.h b/Math/Share.h index cf2193d53..6fa439371 100644 --- a/Math/Share.h +++ b/Math/Share.h @@ -45,8 +45,8 @@ class Share static string type_string() { return T::type_string(); } - static char type_char() - { return T::type_char(); } + static string type_short() + { return string(1, T::type_char()); } static DataFieldType field_type() { return T::field_type(); } diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index d7320e336..9e1aa71f1 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -6,6 +6,7 @@ void Zp_Data::init(const bigint& p,bool mont) { pr=p; mask=(1<<((mpz_sizeinbase(pr.get_mpz_t(),2)-1)%(8*sizeof(mp_limb_t))))-1; + pr_byte_length = numBytes(pr); montgomery=mont; t=mpz_size(pr.get_mpz_t()); diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index bcff3a151..deb534b52 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -45,6 +45,7 @@ class Zp_Data bigint pr; mp_limb_t mask; + size_t pr_byte_length; void assign(const Zp_Data& Zp); void init(const bigint& p,bool mont=true); @@ -55,7 +56,7 @@ class Zp_Data void unpack(octetStream& o); // This one does nothing, needed so as to make vectors of Zp_Data - Zp_Data() : montgomery(0), pi(0), mask(0) { t=MAX_MOD_SZ; } + Zp_Data() : montgomery(0), pi(0), mask(0), pr_byte_length(0) { t=MAX_MOD_SZ; } // The main init funciton Zp_Data(const bigint& p,bool mont=true) diff --git a/Math/bigint.cpp b/Math/bigint.cpp index c290cfb08..ab8b65de9 100644 --- a/Math/bigint.cpp +++ b/Math/bigint.cpp @@ -1,9 +1,22 @@ #include "bigint.h" #include "gfp.h" +#include "Integer.h" +#include "GC/Clear.h" #include "Exceptions/Exceptions.h" +class gmp_random +{ +public: + gmp_randclass Gen; + gmp_random() : Gen(gmp_randinit_default) + { + Gen.seed(0); + } +}; + thread_local bigint bigint::tmp; +thread_local gmp_random bigint::random; bigint sqrRootMod(const bigint& a,const bigint& p) { @@ -18,13 +31,11 @@ bigint sqrRootMod(const bigint& a,const bigint& p) } else { // Shanks algorithm - gmp_randclass Gen(gmp_randinit_default); - Gen.seed(0); bigint x,y,n,q,t,b,temp; // Find n such that (n/p)=-1 int leg=1; while (leg!=-1) - { n=Gen.get_z_range(p); + { n=bigint::random.Gen.get_z_range(p); leg=mpz_legendre(n.get_mpz_t(),p.get_mpz_t()); } // Split p-1 = 2^e q @@ -134,6 +145,27 @@ int limb_size() return 0; } +template +mpf_class bigint::get_float(T v, Integer exp, T z, T s) +{ + bigint tmp; + to_signed_bigint(tmp, v); + mpf_class res = tmp; + if (exp > 0) + mpf_mul_2exp(res.get_mpf_t(), res.get_mpf_t(), exp.get()); + else + mpf_div_2exp(res.get_mpf_t(), res.get_mpf_t(), -exp.get()); + if (z.is_one()) + res = 0; + if (s.is_one()) + { + res *= -1; + } + if (not z.is_bit() or not s.is_bit()) + throw Processor_Error("invalid floating point number"); + return res; +} + #ifdef REALLOC_POLICE void bigint::lottery() { @@ -142,3 +174,7 @@ void bigint::lottery() throw runtime_error("much deallocation"); } #endif + +template mpf_class bigint::get_float(gfp, Integer, gfp, gfp); +template mpf_class bigint::get_float(Integer, Integer, Integer, Integer); +template mpf_class bigint::get_float(GC::Clear, Integer, GC::Clear, GC::Clear); diff --git a/Math/bigint.h b/Math/bigint.h index 9b86498d2..a5ce439fc 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -22,11 +22,17 @@ enum ReportType }; class gfp; +class gmp_random; +class Integer; class bigint : public mpz_class { public: static thread_local bigint tmp; + static thread_local gmp_random random; + + template + static mpf_class get_float(T v, Integer exp, T z, T s); bigint() : mpz_class() {} template diff --git a/Math/field_types.h b/Math/field_types.h index 96693d8a2..fe5c0bbc8 100644 --- a/Math/field_types.h +++ b/Math/field_types.h @@ -7,7 +7,7 @@ #define MATH_FIELD_TYPES_H_ -enum DataFieldType { DATA_MODP, DATA_GF2N, N_DATA_FIELD_TYPE }; +enum DataFieldType { DATA_MODP, DATA_GF2N, DATA_GF2, DATA_INT64, N_DATA_FIELD_TYPE }; #endif /* MATH_FIELD_TYPES_H_ */ diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index 2b30da28f..1a3240cbd 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -349,3 +349,49 @@ void gf2n_short::input(istream& s,bool human) a &= mask; } + +// Expansion is by x=y^5+1 (as we embed GF(256) into GF(2^40) +void expand_byte(gf2n_short& a,int b) +{ + gf2n_short x,xp; + x.assign(32+1); + xp.assign_one(); + a.assign_zero(); + + while (b!=0) + { if ((b&1)==1) + { a.add(a,xp); } + xp.mul(x); + b>>=1; + } +} + + +// Have previously worked out the linear equations we need to solve +void collapse_byte(int& b,const gf2n_short& aa) +{ + word w=aa.get(); + int e35=(w>>35)&1; + int e30=(w>>30)&1; + int e25=(w>>25)&1; + int e20=(w>>20)&1; + int e15=(w>>15)&1; + int e10=(w>>10)&1; + int e5=(w>>5)&1; + int e0=w&1; + int a[8]; + a[7]=e35; + a[6]=e30^a[7]; + a[5]=e25^a[7]; + a[4]=e20^a[5]^a[6]^a[7]; + a[3]=e15^a[7]; + a[2]=e10^a[3]^a[6]^a[7]; + a[1]=e5^a[3]^a[5]^a[7]; + a[0]=e0^a[1]^a[2]^a[3]^a[4]^a[5]^a[6]^a[7]; + + b=0; + for (int i=7; i>=0; i--) + { b=b<<1; + b+=a[i]; + } +} diff --git a/Math/gf2n.h b/Math/gf2n.h index ad1923b11..dd689a9d4 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -12,6 +12,11 @@ using namespace std; #include "Math/gf2nlong.h" #include "Math/field_types.h" +class gf2n_short; + +void expand_byte(gf2n_short& a,int b); +void collapse_byte(int& b,const gf2n_short& a); + /* This interface compatible with the gfp interface * which then allows us to template the Share * data type. @@ -138,12 +143,13 @@ class gf2n_short // x * y when one of x,y is a bit void mul_by_bit(const gf2n_short& x, const gf2n_short& y) { a = x.a * y.a; } - gf2n_short operator+(const gf2n_short& x) { gf2n_short res; res.add(*this, x); return res; } - gf2n_short operator*(const gf2n_short& x) { gf2n_short res; res.mul(*this, x); return res; } + gf2n_short operator+(const gf2n_short& x) const { gf2n_short res; res.add(*this, x); return res; } + gf2n_short operator*(const gf2n_short& x) const { gf2n_short res; res.mul(*this, x); return res; } gf2n_short& operator+=(const gf2n_short& x) { add(x); return *this; } gf2n_short& operator*=(const gf2n_short& x) { mul(x); return *this; } - gf2n_short operator-(const gf2n_short& x) { gf2n_short res; res.add(*this, x); return res; } + gf2n_short operator-(const gf2n_short& x) const { gf2n_short res; res.add(*this, x); return res; } gf2n_short& operator-=(const gf2n_short& x) { sub(x); return *this; } + gf2n_short operator/(const gf2n_short& x) const { gf2n_short tmp; tmp.invert(x); return *this * tmp; } void square(); void square(const gf2n_short& aa); @@ -161,12 +167,12 @@ class gf2n_short void SHL(const gf2n_short& x,int n) { a=(x.a<>n; } - gf2n_short operator&(const gf2n_short& x) { gf2n_short res; res.AND(*this, x); return res; } - gf2n_short operator^(const gf2n_short& x) { gf2n_short res; res.XOR(*this, x); return res; } - gf2n_short operator|(const gf2n_short& x) { gf2n_short res; res.OR(*this, x); return res; } - gf2n_short operator!() { gf2n_short res; res.NOT(*this); return res; } - gf2n_short operator<<(int i) { gf2n_short res; res.SHL(*this, i); return res; } - gf2n_short operator>>(int i) { gf2n_short res; res.SHR(*this, i); return res; } + gf2n_short operator&(const gf2n_short& x) const { gf2n_short res; res.AND(*this, x); return res; } + gf2n_short operator^(const gf2n_short& x) const { gf2n_short res; res.XOR(*this, x); return res; } + gf2n_short operator|(const gf2n_short& x) const { gf2n_short res; res.OR(*this, x); return res; } + gf2n_short operator!() const { gf2n_short res; res.NOT(*this); return res; } + gf2n_short operator<<(int i) const { gf2n_short res; res.SHL(*this, i); return res; } + gf2n_short operator>>(int i) const { gf2n_short res; res.SHR(*this, i); return res; } /* Crap RNG */ void randomize(PRNG& G); diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h index 6a814b066..f4ff6bd48 100644 --- a/Math/gf2nlong.h +++ b/Math/gf2nlong.h @@ -187,12 +187,13 @@ class gf2n_long // x * y when one of x,y is a bit void mul_by_bit(const gf2n_long& x, const gf2n_long& y) { a = x.a.a * y.a.a; } - gf2n_long operator+(const gf2n_long& x) { gf2n_long res; res.add(*this, x); return res; } - gf2n_long operator*(const gf2n_long& x) { gf2n_long res; res.mul(*this, x); return res; } + gf2n_long operator+(const gf2n_long& x) const { gf2n_long res; res.add(*this, x); return res; } + gf2n_long operator*(const gf2n_long& x) const { gf2n_long res; res.mul(*this, x); return res; } gf2n_long& operator+=(const gf2n_long& x) { add(x); return *this; } gf2n_long& operator*=(const gf2n_long& x) { mul(x); return *this; } - gf2n_long operator-(const gf2n_long& x) { gf2n_long res; res.add(*this, x); return res; } + gf2n_long operator-(const gf2n_long& x) const { gf2n_long res; res.add(*this, x); return res; } gf2n_long& operator-=(const gf2n_long& x) { sub(x); return *this; } + gf2n_long operator/(const gf2n_long& x) const { gf2n_long tmp; tmp.invert(x); return *this * tmp; } void square(); void square(const gf2n_long& aa); @@ -210,12 +211,12 @@ class gf2n_long void SHL(const gf2n_long& x,int n) { a=(x.a<>n; } - gf2n_long operator&(const gf2n_long& x) { gf2n_long res; res.AND(*this, x); return res; } - gf2n_long operator^(const gf2n_long& x) { gf2n_long res; res.XOR(*this, x); return res; } - gf2n_long operator|(const gf2n_long& x) { gf2n_long res; res.OR(*this, x); return res; } - gf2n_long operator!() { gf2n_long res; res.NOT(*this); return res; } - gf2n_long operator<<(int i) { gf2n_long res; res.SHL(*this, i); return res; } - gf2n_long operator>>(int i) { gf2n_long res; res.SHR(*this, i); return res; } + gf2n_long operator&(const gf2n_long& x) const { gf2n_long res; res.AND(*this, x); return res; } + gf2n_long operator^(const gf2n_long& x) const { gf2n_long res; res.XOR(*this, x); return res; } + gf2n_long operator|(const gf2n_long& x) const { gf2n_long res; res.OR(*this, x); return res; } + gf2n_long operator!() const { gf2n_long res; res.NOT(*this); return res; } + gf2n_long operator<<(int i) const { gf2n_long res; res.SHL(*this, i); return res; } + gf2n_long operator>>(int i) const { gf2n_long res; res.SHR(*this, i); return res; } /* Crap RNG */ void randomize(PRNG& G); diff --git a/Math/gfp.cpp b/Math/gfp.cpp index 39b926920..5a6342bd1 100644 --- a/Math/gfp.cpp +++ b/Math/gfp.cpp @@ -131,3 +131,25 @@ gfp gfp::sqrRoot() to_gfp(temp, ti); return temp; } + +void gfp::reqbl(int n) +{ + if ((int)n > 0 && gfp::pr() < bigint(1) << (n-1)) + { + cout << "Tape requires prime of bit length " << n << endl; + throw invalid_params(); + } + else if ((int)n < 0) + { + throw Processor_Error("Program compiled for rings not fields"); + } +} + +void to_signed_bigint(bigint& ans, const gfp& x) +{ + to_bigint(ans, x); + // get sign and abs(x) + bigint& p_half = bigint::tmp = (gfp::pr()-1)/2; + if (mpz_cmp(ans.get_mpz_t(), p_half.get_mpz_t()) > 0) + ans = gfp::pr() - ans; +} diff --git a/Math/gfp.h b/Math/gfp.h index 6c260725c..b4b96bd62 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -51,16 +51,26 @@ class gfp static int size() { return t() * sizeof(mp_limb_t); } + static void reqbl(int n); + void assign(const gfp& g) { a=g.a; } void assign_zero() { assignZero(a,ZpD); } void assign_one() { assignOne(a,ZpD); } void assign(word aa) { bigint::tmp=aa; to_gfp(*this,bigint::tmp); } - void assign(long aa) { bigint::tmp=aa; to_gfp(*this,bigint::tmp); } - void assign(int aa) { bigint::tmp=aa; to_gfp(*this,bigint::tmp); } + void assign(long aa) + { + if (aa == 0) + assignZero(a, ZpD); + else + to_gfp(*this, bigint::tmp = aa); + } + void assign(int aa) { assign(long(aa)); } void assign(const char* buffer) { a.assign(buffer, ZpD.get_t()); } modp get() const { return a; } + unsigned long debug() const { return a.get_limb(0); } + // Assumes prD behind x is equal to ZpD void assign(modp& x) { a=x; } @@ -134,6 +144,7 @@ class gfp gfp operator+(const gfp& x) const { gfp res; res.add(*this, x); return res; } gfp operator-(const gfp& x) const { gfp res; res.sub(*this, x); return res; } gfp operator*(const gfp& x) const { gfp res; res.mul(*this, x); return res; } + gfp operator/(const gfp& x) const { gfp tmp; tmp.invert(x); return *this * tmp; } gfp& operator+=(const gfp& x) { add(x); return *this; } gfp& operator-=(const gfp& x) { sub(x); return *this; } gfp& operator*=(const gfp& x) { mul(x); return *this; } @@ -211,5 +222,6 @@ class gfp { to_modp(ans.a,x,ans.ZpD); } }; +void to_signed_bigint(bigint& ans,const gfp& x); #endif diff --git a/Math/modp.cpp b/Math/modp.cpp index 5e6758bc1..7f6a46e51 100644 --- a/Math/modp.cpp +++ b/Math/modp.cpp @@ -11,8 +11,7 @@ bool modp::rewind = false; void modp::randomize(PRNG& G, const Zp_Data& ZpD) { - bigint x=G.randomBnd(ZpD.pr); - memcpy(this->x, x.get_mpz_t()->_mp_d, ZpD.get_t() * sizeof(mp_limb_t)); + G.randomBnd(x, ZpD.get_prA(), ZpD.pr_byte_length); } void modp::pack(octetStream& o,const Zp_Data& ZpD) const diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 315165e52..8c7cbb3f4 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -199,7 +199,8 @@ Player::~Player() cerr << it->first << " " << 1e-6 * it->second.data << " MB in " << it->second.rounds << " rounds, taking " << it->second.timer.elapsed() << " seconds" << endl; - cerr << "Receiving took " << timer.elapsed() << " seconds" << endl; + if (timer.elapsed() > 0) + cerr << "Receiving took " << timer.elapsed() << " seconds" << endl; } @@ -213,7 +214,7 @@ void MultiPlayer::setup_sockets(const vector& names,const vector::setup_sockets(const vector& names,const vector void MultiPlayer::pass_around(octetStream& o, int offset) const { TimeScope ts(comm_stats["Passing around"].add(o)); - o.exchange(sockets.at((my_num() + offset) % num_players()), - sockets.at((my_num() + num_players() - offset) % num_players())); + o.exchange(sockets.at(get_player(offset)), sockets.at(get_player(-offset))); sent += o.get_length(); } @@ -563,6 +563,7 @@ static pair sts_responder(int socket, CommsecKeysPackage *keys, void TwoPartyPlayer::setup_sockets(int other_player, const Names &nms, int portNum, int id) { + id += 0xF << 28; const char *hostname = nms.names[other_player].c_str(); ServerSocket *server = nms.server; if (is_server) { @@ -637,5 +638,18 @@ void TwoPartyPlayer::exchange(octetStream& o) const o.exchange(socket, socket); } +CommStats& CommStats::operator +=(const CommStats& other) +{ + data += other.data; + return *this; +} + +NamedCommStats& NamedCommStats::operator +=(const NamedCommStats& other) +{ + for (auto it = other.begin(); it != other.end(); it++) + (*this)[it->first] += it->second; + return *this; +} + template class MultiPlayer; template class MultiPlayer; diff --git a/Networking/Player.h b/Networking/Player.h index 3bd365db6..7098bc1a0 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -125,6 +125,13 @@ struct CommStats Timer timer; CommStats() : data(0), rounds(0) {} Timer& add(const octetStream& os) { data += os.get_length(); rounds++; return timer; } + CommStats& operator+=(const CommStats& other); +}; + +class NamedCommStats : public map +{ +public: + NamedCommStats& operator+=(const NamedCommStats& other); }; class Player : public PlayerBase @@ -134,9 +141,9 @@ class Player : public PlayerBase mutable blk_SHA_CTX ctx; - mutable map comm_stats; - public: + mutable NamedCommStats comm_stats; + Player(const Names& Nms); virtual ~Player(); @@ -144,6 +151,7 @@ class Player : public PlayerBase int my_num() const { return player_no; } int get_offset(int other_player) const { return positive_modulo(other_player - my_num(), num_players()); } + int get_player(int offset) const { return positive_modulo(offset + my_num(), num_players()); } virtual bool is_encrypted() { return false; } diff --git a/Networking/data.h b/Networking/data.h index 441d34cd4..974164e9b 100644 --- a/Networking/data.h +++ b/Networking/data.h @@ -6,6 +6,12 @@ #include "Exceptions/Exceptions.h" #include "Tools/avx_memcpy.h" +#ifdef __APPLE__ +# include +#define htole64(x) OSSwapHostToLittleInt64(x) +#define le64toh(x) OSSwapLittleToHostInt64(x) +#endif + typedef unsigned char octet; diff --git a/Player-Online.cpp b/Player-Online.cpp index 0d435dbf2..b98d3ad81 100644 --- a/Player-Online.cpp +++ b/Player-Online.cpp @@ -1,9 +1,12 @@ #include "Processor/Machine.h" +#include "Processor/OnlineOptions.h" #include "Math/Setup.h" #include "Tools/ezOptionParser.h" #include "Tools/Config.h" #include "Networking/Server.h" +#include "Processor/Online-Thread.hpp" + #include #include #include @@ -13,6 +16,7 @@ using namespace std; int main(int argc, const char** argv) { ez::ezOptionParser opt; + OnlineOptions online_opts(opt, argc, argv); opt.syntax = "./Player-Online.x [OPTIONS] \n"; opt.example = "./Player-Online.x -lgp 64 -lg2 128 -m new 0 sample-prog\n./Player-Online.x -pn 13000 -h localhost 1 sample-prog\n"; @@ -27,11 +31,11 @@ int main(int argc, const char** argv) "--lgp" // Flag token. ); opt.add( - "40", // Default. + to_string(gf2n::default_degree()).c_str(), // Default. 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - "Bit length of GF(2^n) field (default: 40)", // Help description. + ("Bit length of GF(2^n) field (default: " + to_string(gf2n::default_degree()) + ")").c_str(), // Help description. "-lg2", // Flag token. "--lg2" // Flag token. ); @@ -149,6 +153,7 @@ int main(int argc, const char** argv) "--nparties" // Flag token. ); + opt.resetArgs(); opt.parse(argc, argv); vector allArgs(opt.firstArgs); @@ -253,9 +258,10 @@ int main(int argc, const char** argv) try #endif { - Machine(playerno, playerNames, progname, memtype, lgp, lg2, + Machine>(playerno, playerNames, progname, memtype, lgp, lg2, opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet, - opt.get("--threads")->isSet, max_broadcast, false).run(); + opt.get("--threads")->isSet, max_broadcast, false, false, + online_opts).run(); if (server) delete server; @@ -268,8 +274,9 @@ int main(int argc, const char** argv) #ifndef INSECURE catch(...) { - thread_info::purge_preprocessing(playerNames, - get_prep_dir(playerNames.num_players(), lgp, lg2)); + Machine> machine(playerNames); + machine.live_prep = false; + thread_info>::purge_preprocessing(machine); throw; } #endif diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h new file mode 100644 index 000000000..9c3008ee1 --- /dev/null +++ b/Processor/BaseMachine.h @@ -0,0 +1,47 @@ +/* + * BaseMachine.h + * + */ + +#ifndef PROCESSOR_BASEMACHINE_H_ +#define PROCESSOR_BASEMACHINE_H_ + +#include "Tools/time-func.h" + +#include +#include +using namespace std; + +class BaseMachine +{ +protected: + static BaseMachine* singleton; + + std::map timer; + + ifstream inpf; + + void print_timers(); + + virtual void load_program(string threadname, string filename); + +public: + string progname; + int nthreads; + + static BaseMachine& s(); + + BaseMachine(); + virtual ~BaseMachine() {} + + void load_schedule(string progname); + void print_compiler(); + + void time(); + void start(int n); + void stop(int n); + + virtual void reqbl(int n) { (void)n; throw runtime_error("not defined"); } +}; + +#endif /* PROCESSOR_BASEMACHINE_H_ */ diff --git a/Processor/Beaver.h b/Processor/Beaver.h new file mode 100644 index 000000000..99abbec7f --- /dev/null +++ b/Processor/Beaver.h @@ -0,0 +1,28 @@ +/* + * Beaver.h + * + */ + +#ifndef PROCESSOR_BEAVER_H_ +#define PROCESSOR_BEAVER_H_ + +#include +using namespace std; + +template class SubProcessor; +template class MAC_Check_Base; +class Player; + +template +class Beaver +{ +public: + Player& P; + + static void muls(const vector& reg, SubProcessor& proc, + MAC_Check_Base& MC, int size); + + Beaver(Player& P) : P(P) {} +}; + +#endif /* PROCESSOR_BEAVER_H_ */ diff --git a/Processor/Beaver.hpp b/Processor/Beaver.hpp new file mode 100644 index 000000000..75934bbbb --- /dev/null +++ b/Processor/Beaver.hpp @@ -0,0 +1,50 @@ +/* + * Beaver.cpp + * + */ + +#include "Beaver.h" + +#include + +template +void Beaver::muls(const vector& reg, SubProcessor& proc, MAC_Check_Base& MC, + int size) +{ + assert(reg.size() % 3 == 0); + int n = reg.size() / 3; + vector& shares = proc.Sh_PO; + vector& opened = proc.PO; + shares.clear(); + vector> triples(n * size); + auto triple = triples.begin(); + + for (int i = 0; i < n; i++) + for (int j = 0; j < size; j++) + { + proc.DataF.get(DATA_TRIPLE, triple->data()); + for (int k = 0; k < 2; k++) + shares.push_back(proc.S[reg[i * 3 + k + 1] + j] - (*triple)[k]); + triple++; + } + + MC.POpen_Begin(opened, shares, proc.P); + MC.POpen_End(opened, shares, proc.P); + auto it = opened.begin(); + triple = triples.begin(); + + for (int i = 0; i < n; i++) + for (int j = 0; j < size; j++) + { + typename T::clear masked[2]; + T& tmp = (*triple)[2]; + for (int k = 0; k < 2; k++) + { + masked[k] = *it++; + tmp += (masked[k] * (*triple)[1 - k]); + } + tmp.add(tmp, masked[0] * masked[1], proc.P.my_num(), MC.get_alphai()); + proc.S[reg[i * 3] + j] = tmp; + triple++; + } +} diff --git a/Processor/Binary_File_IO.cpp b/Processor/Binary_File_IO.hpp similarity index 79% rename from Processor/Binary_File_IO.cpp rename to Processor/Binary_File_IO.hpp index 92ccfdd44..cf2040faf 100644 --- a/Processor/Binary_File_IO.cpp +++ b/Processor/Binary_File_IO.hpp @@ -66,9 +66,3 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, for (unsigned int i = 0; i < buffer.size(); i++) buffer[i].assign(&read_buffer[i*T::size()]); } - -template void Binary_File_IO::write_to_file(const string filename, const vector< Share >& buffer); -template void Binary_File_IO::read_from_file(const string filename, vector< Share >& buffer, const int start_posn, int &end_posn); - -template void Binary_File_IO::write_to_file(const string filename, const vector< Rep3Share >& buffer); -template void Binary_File_IO::read_from_file(const string filename, vector< Rep3Share >& buffer, const int start_posn, int &end_posn); diff --git a/Processor/Buffer.cpp b/Processor/Buffer.cpp index 7eedd1f8d..338928722 100644 --- a/Processor/Buffer.cpp +++ b/Processor/Buffer.cpp @@ -4,8 +4,6 @@ */ #include "Buffer.h" -#include "Processor/InputTuple.h" -#include "Processor/Data_Files.h" bool BufferBase::rewind = false; @@ -36,7 +34,7 @@ void BufferBase::try_rewind() { #ifndef INSECURE string type; - if (field_type and data_type) + if (field_type.size() and data_type.size()) type = (string)" of " + field_type + " " + data_type; throw not_enough_to_buffer(type); #endif @@ -75,81 +73,3 @@ void BufferBase::purge() file = 0; } } - -template -Buffer::~Buffer() -{ - if (timer.elapsed() && data_type) - cerr << T::type_string() << " " << data_type << " reading: " - << timer.elapsed() << endl; -} - -template -void Buffer::fill_buffer() -{ - if (T::size() == sizeof(T)) - { - // read directly - read((char*)buffer); - } - else - { - char read_buffer[sizeof(buffer)]; - read(read_buffer); - //memset(buffer, 0, sizeof(buffer)); - for (int i = 0; i < BUFFER_SIZE; i++) - buffer[i].assign(&read_buffer[i*T::size()]); - } -} - -template -void Buffer::read(char* read_buffer) -{ - int size_in_bytes = T::size() * BUFFER_SIZE; - int n_read = 0; - timer.start(); - do - { - file->read(read_buffer + n_read, size_in_bytes - n_read); - n_read += file->gcount(); - if (file->eof()) - { - try_rewind(); - } - if (file->fail()) - { - stringstream ss; - ss << "IO problem when buffering " << T::type_string(); - if (data_type) - ss << " " << data_type; - ss << " from " << filename; - throw file_error(ss.str()); - } - } - while (n_read < size_in_bytes); - timer.stop(); -} - -template -void Buffer::input(U& a) -{ - if (next == BUFFER_SIZE) - { - fill_buffer(); - next = 0; - } - - a = buffer[next]; - next++; -} - -template class Buffer< Share, Share >; -template class Buffer< Share, Share >; -template class Buffer< Rep3Share, Rep3Share>; -template class Buffer< InputTuple, RefInputTuple >; -template class Buffer< InputTuple, RefInputTuple >; -template class Buffer< InputTuple, RefInputTuple >; -template class Buffer< gfp, gfp >; -template class Buffer< gf2n, gf2n >; -template class Buffer< FixedVec, FixedVec >; -template class Buffer< Integer, Integer >; diff --git a/Processor/Buffer.h b/Processor/Buffer.h index 9cf050ff6..60c23a687 100644 --- a/Processor/Buffer.h +++ b/Processor/Buffer.h @@ -18,7 +18,6 @@ using namespace std; #define BUFFER_SIZE 101 #endif - class BufferBase { protected: @@ -26,8 +25,8 @@ class BufferBase ifstream* file; int next; - const char* data_type; - const char* field_type; + string data_type; + string field_type; Timer timer; int tuple_length; string filename; @@ -35,10 +34,10 @@ class BufferBase public: bool eof; - BufferBase() : file(0), next(BUFFER_SIZE), data_type(0), field_type(0), + BufferBase() : file(0), next(BUFFER_SIZE), tuple_length(-1), eof(false) {} - void setup(ifstream* f, int length, string filename, const char* type = 0, - const char* field = 0); + void setup(ifstream* f, int length, string filename, const char* type = "", + const char* field = ""); void seekg(int pos); bool is_up() { return file != 0; } void try_rewind(); @@ -71,7 +70,7 @@ class BufferOwner : public Buffer { } - void setup(string filename, int tuple_length, const char* data_type = 0) + void setup(string filename, int tuple_length, const char* data_type = "") { file = new ifstream(filename, ios::in | ios::binary); Buffer::setup(file, tuple_length, filename, data_type, U::type_string().c_str()); @@ -85,4 +84,73 @@ class BufferOwner : public Buffer } }; +template +inline Buffer::~Buffer() +{ + if (timer.elapsed() && data_type.size()) + cerr << T::type_string() << " " << data_type << " reading: " + << timer.elapsed() << endl; +} + +template +inline void Buffer::fill_buffer() +{ + if (T::size() == sizeof(T)) + { + // read directly + read((char*)buffer); + } + else + { + char read_buffer[sizeof(buffer)]; + read(read_buffer); + //memset(buffer, 0, sizeof(buffer)); + for (int i = 0; i < BUFFER_SIZE; i++) + buffer[i].assign(&read_buffer[i*T::size()]); + } +} + +template +inline void Buffer::read(char* read_buffer) +{ + int size_in_bytes = T::size() * BUFFER_SIZE; + int n_read = 0; + timer.start(); + if (not file) + throw IO_Error(T::type_string() + " buffer not set up"); + do + { + file->read(read_buffer + n_read, size_in_bytes - n_read); + n_read += file->gcount(); + if (file->eof()) + { + try_rewind(); + } + if (file->fail()) + { + stringstream ss; + ss << "IO problem when buffering " << T::type_string(); + if (data_type.size()) + ss << " " << data_type; + ss << " from " << filename; + throw file_error(ss.str()); + } + } + while (n_read < size_in_bytes); + timer.stop(); +} + +template +inline void Buffer::input(U& a) +{ + if (next == BUFFER_SIZE) + { + fill_buffer(); + next = 0; + } + + a = buffer[next]; + next++; +} + #endif /* PROCESSOR_BUFFER_H_ */ diff --git a/Processor/Data_Files.cpp b/Processor/Data_Files.cpp index 5d84313c7..ecf68f383 100644 --- a/Processor/Data_Files.cpp +++ b/Processor/Data_Files.cpp @@ -1,10 +1,17 @@ #include "Processor/Data_Files.h" #include "Processor/Processor.h" +#include "Processor/ReplicatedPrep.h" +#include "Processor/MaliciousRepPrep.h" +#include "GC/MaliciousRepSecret.h" +#include "Math/MaliciousRep3Share.h" + +#include "Processor/MaliciousRepPrep.hpp" #include +#include -const char* DataPositions::field_names[] = { "sint", "sgf2n" }; +const char* DataPositions::field_names[] = { "gfp", "gf2n", "bit", "int64" }; template<> const bool Sub_Data_Files::implemented[N_DTYPE] = @@ -12,15 +19,40 @@ const bool Sub_Data_Files::implemented[N_DTYPE] = ; template<> -const bool Sub_Data_Files::implemented[N_DTYPE] = +const bool Sub_Data_Files>::implemented[N_DTYPE] = { true, true, true, true, true, true } ; template<> -const bool Sub_Data_Files::implemented[N_DTYPE] = +const bool Sub_Data_Files>::implemented[N_DTYPE] = { false, false, true, false, false, false } ; +template<> +const bool Sub_Data_Files>::implemented[N_DTYPE] = + { true, true, true, true, false, false } +; + +template<> +const bool Sub_Data_Files>::implemented[N_DTYPE] = + { true, true, true, true, false, false } +; + +template<> +const bool Sub_Data_Files>::implemented[N_DTYPE] = + { true, true, true, true, false, false } +; + +template<> +const bool Sub_Data_Files>::implemented[N_DTYPE] = + { true, true, true, true, false, false } +; + +template<> +const bool Sub_Data_Files::implemented[N_DTYPE] = + { true, false, true, false, false, false } +; + const int DataPositions::tuple_size[N_DTYPE] = { 3, 2, 1, 2, 3, 3 }; template @@ -28,6 +60,58 @@ Lock Sub_Data_Files::tuple_lengths_lock; template map Sub_Data_Files::tuple_lengths; +template<> +template<> +Preprocessing>* Preprocessing>::get_new( + Machine, Rep3Share>& machine, DataPositions& usage) +{ + if (machine.live_prep) + return new ReplicatedPrep>; + else + return new Sub_Data_Files>(machine.get_N(), machine.prep_dir_prefix, usage); +} + +template<> +template<> +Preprocessing>* Preprocessing>::get_new( + Machine, Rep3Share>& machine, DataPositions& usage) +{ + if (machine.live_prep) + return new ReplicatedPrep>; + else + return new Sub_Data_Files>(machine.get_N(), machine.prep_dir_prefix, usage); +} + +template<> +template<> +Preprocessing>* Preprocessing>::get_new( + Machine, MaliciousRep3Share>& machine, DataPositions& usage) +{ + if (machine.live_prep) + return new MaliciousRepPrep; + else + return new Sub_Data_Files>(machine.get_N(), machine.prep_dir_prefix, usage); +} + +template<> +template<> +Preprocessing>* Preprocessing>::get_new( + Machine, MaliciousRep3Share>& machine, DataPositions& usage) +{ + if (machine.live_prep) + return new MaliciousRepPrep; + else + return new Sub_Data_Files>(machine.get_N(), machine.prep_dir_prefix, usage); +} + + +template +template +Preprocessing* Preprocessing::get_new(Machine& machine, + DataPositions& usage) +{ + return new Sub_Data_Files(machine.get_N(), machine.prep_dir_prefix, usage); +} void DataPositions::set_num_players(int num_players) { @@ -59,13 +143,19 @@ void DataPositions::print_cost() const double total_cost = 0; for (int i = 0; i < N_DATA_FIELD_TYPE; i++) { - cerr << " Type " << field_names[i] << endl; + if (accumulate(files[i].begin(), files[i].end(), 0) > 0) + cerr << " Type " << field_names[i] << endl; + bool reading_field = true; for (int j = 0; j < N_DTYPE; j++) { double cost_per_item = 0; - file >> cost_per_item; + if (reading_field) + file >> cost_per_item; if (cost_per_item < 0) - break; + { + reading_field = false; + cost_per_item = 0; + } long long items_used = files[i][j]; double cost = items_used * cost_per_item; total_cost += cost; @@ -83,7 +173,8 @@ void DataPositions::print_cost() const } } - cerr << "Total cost: " << total_cost << endl; + if (total_cost > 0) + cerr << "Total cost: " << total_cost << endl; } @@ -93,20 +184,35 @@ int Sub_Data_Files::tuple_length(int dtype) return DataPositions::tuple_size[dtype] * T::size(); } +template +string Sub_Data_Files::get_suffix(int thread_num) +{ +#ifdef INSECURE + (void) thread_num; + return ""; +#else + if (thread_num >= 0) + return "-T" + to_string(thread_num); + else + return ""; +#endif +} + template Sub_Data_Files::Sub_Data_Files(int my_num, int num_players, - const string& prep_data_dir, DataPositions& usage) : + const string& prep_data_dir, DataPositions& usage, int thread_num) : my_num(my_num), num_players(num_players), prep_data_dir(prep_data_dir), usage(usage) { cerr << "Setting up Data_Files in: " << prep_data_dir << endl; char filename[1024]; + string suffix = get_suffix(thread_num); for (int dtype = 0; dtype < N_DTYPE; dtype++) { if (implemented[dtype]) { - sprintf(filename,(prep_data_dir + "%s-%s-P%d").c_str(),DataPositions::dtype_names[dtype], - string(1, T::type_char()).c_str(),my_num); + sprintf(filename,(prep_data_dir + "%s-%s-P%d%s").c_str(),DataPositions::dtype_names[dtype], + (T::type_short()).c_str(),my_num,suffix.c_str()); buffers[dtype].setup(filename, tuple_length(dtype), DataPositions::dtype_names[dtype]); } @@ -115,8 +221,8 @@ Sub_Data_Files::Sub_Data_Files(int my_num, int num_players, input_buffers.resize(num_players); for (int i=0; i::Sub_Data_Files(int my_num, int num_players, cerr << "done\n"; } -template -Data_Files::Data_Files(int myn, int n, const string& prep_data_dir) : - usage(n), DataFp(myn, n, prep_data_dir, usage), - DataF2(myn, n, prep_data_dir, usage), prep_data_dir(prep_data_dir) +template +Data_Files::Data_Files(Machine& machine) : + usage(machine.get_N().num_players()), + DataFp(*Preprocessing::get_new(machine, usage)), + DataF2(*Preprocessing::get_new(machine, usage)) { } @@ -168,16 +275,16 @@ void Sub_Data_Files::seekg(DataPositions& pos) } } -template -void Data_Files::seekg(DataPositions& pos) +template +void Data_Files::seekg(DataPositions& pos) { DataFp.seekg(pos); DataF2.seekg(pos); usage = pos; } -template -void Data_Files::skip(const DataPositions& pos) +template +void Data_Files::skip(const DataPositions& pos) { DataPositions new_pos = usage; new_pos.increase(pos); @@ -196,8 +303,8 @@ void Sub_Data_Files::prune() it.second.prune(); } -template -void Data_Files::prune() +template +void Data_Files::prune() { DataFp.prune(); DataF2.prune(); @@ -215,8 +322,8 @@ void Sub_Data_Files::purge() it.second.purge(); } -template -void Data_Files::purge() +template +void Data_Files::purge() { DataFp.purge(); DataF2.purge(); @@ -247,24 +354,31 @@ void Sub_Data_Files::setup_extended(const DataTag& tag, int tuple_size) if (!buffer.is_up()) { stringstream ss; - ss << prep_data_dir << tag.get_string() << "-" << T::type_char() << "-P" << my_num; + ss << prep_data_dir << tag.get_string() << "-" << T::type_short() << "-P" << my_num; extended[tag].setup(ss.str(), tuple_length); } } template -void Sub_Data_Files::get(SubProcessor& proc, DataTag tag, const vector& regs, int vector_size) +void Sub_Data_Files::get(vector& S, DataTag tag, const vector& regs, int vector_size) { usage.extended[T::field_type()][tag] += vector_size; setup_extended(tag, regs.size()); for (int j = 0; j < vector_size; j++) for (unsigned int i = 0; i < regs.size(); i++) - extended[tag].input(proc.get_S_ref(regs[i] + j)); + extended[tag].input(S[regs[i] + j]); } -template class Sub_Data_Files; +template class Sub_Data_Files>; template class Sub_Data_Files; -template class Sub_Data_Files; - -template class Data_Files; -template class Data_Files; +template class Sub_Data_Files>; +template class Sub_Data_Files>; +template class Sub_Data_Files>; +template class Sub_Data_Files; +template class Sub_Data_Files>; +template class Sub_Data_Files>; + +template class Data_Files>; +template class Data_Files, Rep3Share>; +template class Data_Files, Rep3Share>; +template class Data_Files, MaliciousRep3Share>; diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index 1441d99ba..3ed16f825 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -60,12 +60,39 @@ struct DataPositions void print_cost() const; }; -template class Processor; -template class Data_Files; +template class Processor; +template class Data_Files; +template class Machine; template -class Sub_Data_Files +class Preprocessing { +public: + template + static Preprocessing* get_new(Machine& machine, DataPositions& usage); + + virtual ~Preprocessing() {} + + virtual void set_protocol(typename T::Protocol& protocol) = 0; + + virtual void seekg(DataPositions& pos) { (void) pos; } + virtual void prune() {} + virtual void purge() {} + + virtual void get(Dtype dtype, T* a); + virtual void get_three(Dtype dtype, T& a, T& b, T& c) = 0; + virtual void get_two(Dtype dtype, T& a, T& b) = 0; + virtual void get_one(Dtype dtype, T& a) = 0; + virtual void get_input(T& a, typename T::clear& x, int i) = 0; + virtual void get(vector& S, DataTag tag, const vector& regs, + int vector_size) = 0; +}; + +template +class Sub_Data_Files : public Preprocessing +{ + template friend class Sub_Data_Files; + static const bool implemented[N_DTYPE]; static map tuple_lengths; @@ -85,10 +112,19 @@ class Sub_Data_Files DataPositions& usage; public: + static string get_suffix(int thread_num); + Sub_Data_Files(int my_num, int num_players, const string& prep_data_dir, - DataPositions& usage); + DataPositions& usage, int thread_num = -1); + Sub_Data_Files(const Names& N, const string& prep_data_dir, + DataPositions& usage, int thread_num = -1) : + Sub_Data_Files(N.my_num(), N.num_players(), prep_data_dir, usage, thread_num) + { + } ~Sub_Data_Files(); + void set_protocol(typename T::Protocol& protocol) { (void) protocol; } + void seekg(DataPositions& pos); void prune(); void purge(); @@ -130,24 +166,20 @@ class Sub_Data_Files } void setup_extended(const DataTag& tag, int tuple_size = 0); - void get(SubProcessor& proc, DataTag tag, const vector& regs, int vector_size); + void get(vector& S, DataTag tag, const vector& regs, int vector_size); }; -template +template class Data_Files { DataPositions usage; public: - Sub_Data_Files DataFp; - Sub_Data_Files DataF2; - - const string& prep_data_dir; + Preprocessing& DataFp; + Preprocessing& DataF2; - Data_Files(int my_num,int n,const string& prep_data_dir); - Data_Files(Names& N, const string& prep_data_dir) : - Data_Files(N.my_num(), N.num_players(), prep_data_dir) {} + Data_Files(Machine& machine); DataPositions tellg(); void seekg(DataPositions& pos); @@ -155,64 +187,10 @@ class Data_Files void prune(); void purge(); - template - bool eof(Dtype dtype) - { - return get_sub().eof(dtype); - } - template - bool input_eof(int player) - { - return get_sub().input_eof(player); - } - - void setup_extended(DataFieldType field_type, const DataTag& tag, int tuple_size = 0); - template - void get(SubProcessor& proc, DataTag tag, const vector& regs, int vector_size) - { - get_sub().get(proc, tag, regs, vector_size); - } - DataPositions get_usage() { return usage; } - - template - Sub_Data_Files& get_sub(); - - template - void get(Dtype dtype, T* a) - { - get_sub().get(dtype, a); - } - - template - void get_three(DataFieldType field_type, Dtype dtype, T& a, T& b, T& c) - { - (void)field_type; - get_sub().get_three(dtype, a, b, c); - } - - template - void get_two(DataFieldType field_type, Dtype dtype, T& a, T& b) - { - (void)field_type; - get_sub().get_two(dtype, a, b); - } - - template - void get_one(DataFieldType field_type, Dtype dtype, T& a) - { - (void)field_type; - get_sub().get_one(dtype, a); - } - - template - void get_input(T& a,typename T::clear& x,int i) - { - get_sub().get_input(a, x, i); - } }; template inline @@ -236,32 +214,24 @@ inline void Sub_Data_Files::get(Dtype dtype, T* a) buffers[dtype].input(a[i]); } -template<> -template<> -inline Sub_Data_Files& Data_Files::get_sub() -{ - return DataFp; -} - -template<> -template<> -inline Sub_Data_Files& Data_Files::get_sub() -{ - return DataF2; -} - -template<> -template<> -inline Sub_Data_Files& Data_Files::get_sub() -{ - return DataFp; -} - -template<> -template<> -inline Sub_Data_Files& Data_Files::get_sub() +template +inline void Preprocessing::get(Dtype dtype, T* a) { - return DataF2; + switch (dtype) + { + case DATA_TRIPLE: + get_three(dtype, a[0], a[1], a[2]); + break; + case DATA_SQUARE: + case DATA_INVERSE: + get_two(dtype, a[0], a[1]); + break; + case DATA_BIT: + get_one(dtype, a[0]); + break; + default: + throw not_implemented(); + } } #endif diff --git a/Processor/DummyProtocol.h b/Processor/DummyProtocol.h index 524c6f5f5..2a574ba88 100644 --- a/Processor/DummyProtocol.h +++ b/Processor/DummyProtocol.h @@ -8,6 +8,15 @@ class Player; +class DummyMC +{ +public: + void Check(Player& P) + { + (void) P; + } +}; + class DummyProtocol { public: diff --git a/Processor/Input.cpp b/Processor/Input.cpp deleted file mode 100644 index eec0870c4..000000000 --- a/Processor/Input.cpp +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Input.cpp - * - */ - -#include "Input.h" -#include "Processor.h" - -template -InputBase::InputBase(ArithmeticProcessor& proc) : - values_input(0) -{ - buffer.setup(&proc.private_input, -1, proc.private_input_filename); -} - -template -Input::Input(SubProcessor>& proc, MAC_Check& mc) : - InputBase(proc.Proc), proc(proc), MC(mc), shares(proc.P.num_players()) -{ -} - -template -InputBase::~InputBase() -{ - if (timer.elapsed() > 0) - cerr << T::type_string() << " inputs: " << timer.elapsed() << endl; -} - -template -void Input::adjust_mac(Share& share, T& value) -{ - T tmp; - tmp.mul(MC.get_alphai(), value); - tmp.add(share.get_mac(),tmp); - share.set_mac(tmp); -} - -template -void Input::start(int player, int n_inputs) -{ - shares[player].resize(n_inputs); - vector rr(n_inputs); - - if (player == proc.P.my_num()) - { - octetStream o; - - for (int i = 0; i < n_inputs; i++) - { - T rr, t; - Share& share = shares[player][i]; - proc.DataF.get_input(share, rr, player); - T xi; - try - { - this->buffer.input(t); - } - catch (not_enough_to_buffer& e) - { - throw runtime_error("Insufficient input data to buffer"); - } - t.sub(t, rr); - t.pack(o); - xi.add(t, share.get_share()); - share.set_share(xi); - adjust_mac(share, t); - } - - proc.P.send_all(o, true); - this->values_input += n_inputs; - } - else - { - T t; - for (int i = 0; i < n_inputs; i++) - proc.DataF.get_input(shares[player][i], t, player); - } -} - -template -void Input::stop(int player, vector targets) -{ - for (unsigned int i = 0; i < targets.size(); i++) - proc.get_S_ref(targets[i]) = shares[player][i]; - - if (proc.P.my_num() != player) - { - T t; - octetStream o; - this->timer.start(); - proc.P.receive_player(player, o, true); - this->timer.stop(); - for (unsigned int i = 0; i < targets.size(); i++) - { - Share& share = proc.get_S_ref(targets[i]); - t.unpack(o); - adjust_mac(share, t); - } - } -} - -template class InputBase; -template class Input; -template class Input; diff --git a/Processor/Input.h b/Processor/Input.h index 8a2bb20a1..8722daeef 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -10,7 +10,6 @@ using namespace std; #include "Math/Share.h" -#include "Auth/MAC_Check.h" #include "Processor/Buffer.h" #include "Tools/time-func.h" @@ -20,31 +19,38 @@ template class InputBase { protected: - Buffer buffer; + Buffer buffer; Timer timer; public: int values_input; + static void input(SubProcessor& Proc, const vector& args); + InputBase(ArithmeticProcessor& proc); ~InputBase(); }; template -class Input : public InputBase +class Input : public InputBase> { SubProcessor>& proc; MAC_Check& MC; vector< vector< Share > > shares; + octetStream o; void adjust_mac(Share& share, T& value); public: Input(SubProcessor>& proc, MAC_Check& mc); - void start(int player, int n_inputs); - void stop(int player, vector targets); + void reset(int player); + void add_mine(const T& input); + void add_other(int player); + void send_mine(); + void start(int player, int n_inputs); + void stop(int player, const vector& targets); }; #endif /* PROCESSOR_INPUT_H_ */ diff --git a/Processor/Input.hpp b/Processor/Input.hpp new file mode 100644 index 000000000..c6a2fbcaa --- /dev/null +++ b/Processor/Input.hpp @@ -0,0 +1,173 @@ +/* + * Input.cpp + * + */ + +#include "Input.h" +#include "Processor.h" +#include "Auth/MAC_Check.h" + +template +InputBase::InputBase(ArithmeticProcessor& proc) : + values_input(0) +{ + buffer.setup(&proc.private_input, -1, proc.private_input_filename); +} + +template +Input::Input(SubProcessor>& proc, MAC_Check& mc) : + InputBase>(proc.Proc), proc(proc), MC(mc), shares(proc.P.num_players()) +{ +} + +template +InputBase::~InputBase() +{ + if (timer.elapsed() > 0) + cerr << T::type_string() << " inputs: " << timer.elapsed() << endl; +} + +template +void Input::adjust_mac(Share& share, T& value) +{ + T tmp; + tmp.mul(MC.get_alphai(), value); + tmp.add(share.get_mac(),tmp); + share.set_mac(tmp); +} + +template +void Input::reset(int player) +{ + shares[player].clear(); + if (player == proc.P.my_num()) + o.reset_write_head(); +} + +template +void Input::add_mine(const T& input) +{ + int player = proc.P.my_num(); + T rr, t = input; + shares[player].push_back({}); + Share& share = shares[player].back(); + proc.DataF.get_input(share, rr, player); + T xi; + t.sub(t, rr); + t.pack(o); + xi.add(t, share.get_share()); + share.set_share(xi); + adjust_mac(share, t); + this->values_input++; +} + +template +void Input::add_other(int player) +{ + T t; + shares[player].push_back({}); + proc.DataF.get_input(shares[player].back(), t, player); +} + +template +void Input::send_mine() +{ + proc.P.send_all(o, true); +} + +template +void Input::start(int player, int n_inputs) +{ + reset(player); + if (player == proc.P.my_num()) + { + for (int i = 0; i < n_inputs; i++) + { + T t; + try + { + this->buffer.input(t); + } + catch (not_enough_to_buffer& e) + { + throw runtime_error("Insufficient input data to buffer"); + } + add_mine(t); + } + send_mine(); + } + else + { + for (int i = 0; i < n_inputs; i++) + add_other(player); + } +} + +template +void Input::stop(int player, const vector& targets) +{ + for (unsigned int i = 0; i < targets.size(); i++) + proc.get_S_ref(targets[i]) = shares[player][i]; + + if (proc.P.my_num() != player) + { + T t; + octetStream o; + this->timer.start(); + proc.P.receive_player(player, o, true); + this->timer.stop(); + for (unsigned int i = 0; i < targets.size(); i++) + { + Share& share = proc.get_S_ref(targets[i]); + t.unpack(o); + adjust_mac(share, t); + } + } +} + +template +void InputBase::input(SubProcessor& Proc, + const vector& args) +{ + auto& input = Proc.input; + for (int i = 0; i < Proc.P.num_players(); i++) + input.reset(i); + assert(args.size() % 2 == 0); + + int n_from_me = 0; + + if (Proc.Proc.opts.interactive and Proc.Proc.thread_num == 0) + { + for (size_t i = 1; i < args.size(); i += 2) + n_from_me += (args[i] == Proc.P.my_num()); + if (n_from_me > 0) + cout << "Please input " << n_from_me << " numbers:" << endl; + } + + for (size_t i = 0; i < args.size(); i += 2) + { + int n = args[i + 1]; + if (n == Proc.P.my_num()) + { + long x = Proc.Proc.get_input(n_from_me > 0); + input.add_mine(x); + } + else + { + input.add_other(n); + } + } + + if (n_from_me > 0) + cout << "Thank you" << endl; + + input.send_mine(); + + vector> regs(Proc.P.num_players()); + for (size_t i = 0; i < args.size(); i += 2) + { + regs[args[i + 1]].push_back(args[i]); + } + for (int i = 0; i < Proc.P.num_players(); i++) + input.stop(i, regs[i]); +} diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp index f7fc62320..aa1f4469d 100644 --- a/Processor/Instruction.cpp +++ b/Processor/Instruction.cpp @@ -6,12 +6,20 @@ #include "Tools/time-func.h" #include "Tools/parse.h" #include "Auth/ReplicatedMC.h" +#include "Math/MaliciousRep3Share.h" + +#include "Processor/Processor.hpp" +#include "Processor/Binary_File_IO.hpp" +#include "Processor/Input.hpp" +#include "Processor/Beaver.hpp" +#include "Auth/MaliciousRepMC.hpp" #include #include #include #include -#include + +#include "Tools/callgrind.h" // broken #undef DEBUG @@ -223,7 +231,6 @@ void BaseInstruction::parse_operands(istream& s, int pos) case STMS: case LDMINT: case STMINT: - case INPUT: case JMPNZ: case JMPEQZ: case GLDI: @@ -232,7 +239,6 @@ void BaseInstruction::parse_operands(istream& s, int pos) case GLDMS: case GSTMC: case GSTMS: - case GINPUT: case PRINTREG: case GPRINTREG: case LDINT: @@ -244,6 +250,7 @@ void BaseInstruction::parse_operands(istream& s, int pos) case GINPUTMASK: case ACCEPTCLIENTCONNECTION: case INV2M: + case CONDPRINTSTR: r[0]=get_int(s); n = get_int(s); break; @@ -272,6 +279,8 @@ void BaseInstruction::parse_operands(istream& s, int pos) case GOPEN: case MULS: case GMULS: + case INPUT: + case GINPUT: num_var_args = get_int(s); get_vector(num_var_args, start, s); break; @@ -360,7 +369,9 @@ void BaseInstruction::parse_operands(istream& s, int pos) break; default: ostringstream os; - os << "Invalid instruction " << hex << showbase << opcode << " at " << dec << pos; + os << "Invalid instruction " << hex << showbase << opcode << " at " << dec << pos << endl; + os << "This virtual machine executes arithmetic circuits only." << endl; + os << "Try compiling without '-B' and don't use sbit* types." << endl; throw Invalid_Instruction(os.str()); } } @@ -431,17 +442,29 @@ int BaseInstruction::get_reg_type() const } } -int BaseInstruction::get_max_reg(int reg_type) const +unsigned BaseInstruction::get_max_reg(int reg_type) const { if (get_reg_type() != reg_type) { return 0; } + const int *begin, *end; if (start.size()) - return *max_element(start.begin(), start.end()) + size; + { + begin = start.data(); + end = start.data() + start.size(); + } else - return *max_element(r, r + 3) + size; + { + begin = r; + end = r + 3; + } + + unsigned res = 0; + for (auto it = begin; it != end; it++) + res = max(res, (unsigned)*it); + return res + size; } -int Instruction::get_mem(RegType reg_type, SecrecyType sec_type) const +unsigned Instruction::get_mem(RegType reg_type, SecrecyType sec_type) const { if (get_reg_type() == reg_type and is_direct_memory_access(sec_type)) return n + size; @@ -498,13 +521,15 @@ ostream& operator<<(ostream& s,const Instruction& instr) } -template +template #ifndef __clang__ __attribute__((always_inline)) #endif -inline void Instruction::execute(Processor& Proc) const +inline void Instruction::execute(Processor& Proc) const { Proc.PC+=1; + auto& Procp = Proc.Procp; + auto& Proc2 = Proc.Proc2; #ifndef DEBUG // optimize some instructions @@ -560,12 +585,12 @@ inline void Instruction::execute(Processor& Proc) const return; case TRIPLE: for (int i = 0; i < size; i++) - Proc.DataF.get_three(DATA_MODP, DATA_TRIPLE, Proc.get_Sp_ref(r[0] + i), + Procp.DataF.get_three(DATA_TRIPLE, Proc.get_Sp_ref(r[0] + i), Proc.get_Sp_ref(r[1] + i), Proc.get_Sp_ref(r[2] + i)); return; case BIT: for (int i = 0; i < size; i++) - Proc.DataF.get_one(DATA_MODP, DATA_BIT, Proc.get_Sp_ref(r[0] + i)); + Procp.DataF.get_one(DATA_BIT, Proc.get_Sp_ref(r[0] + i)); return; } #endif @@ -586,15 +611,7 @@ inline void Instruction::execute(Processor& Proc) const Proc.get_Sp_ref(r[0]).assign(n, Proc.P.my_num(), Proc.MCp.get_alphai()); break; case GLDSI: - { Proc.temp.ans2.assign(n); - if (Proc.P.my_num()==0) - Proc.get_S2_ref(r[0]).set_share(Proc.temp.ans2); - else - Proc.get_S2_ref(r[0]).assign_zero(); - gf2n& tmp=Proc.temp.tmp2; - tmp.mul(Proc.MC2.get_alphai(),Proc.temp.ans2); - Proc.get_S2_ref(r[0]).set_mac(tmp); - } + Proc.get_S2_ref(r[0]).assign(n, Proc.P.my_num(), Proc.MC2.get_alphai()); break; case LDMC: Proc.write_Cp(r[0],Proc.machine.Mp.read_C(n)); @@ -1085,86 +1102,62 @@ inline void Instruction::execute(Processor& Proc) const #endif break; case TRIPLE: - Proc.DataF.get_three(DATA_MODP, DATA_TRIPLE, Proc.get_Sp_ref(r[0]),Proc.get_Sp_ref(r[1]),Proc.get_Sp_ref(r[2])); + Procp.DataF.get_three(DATA_TRIPLE, Proc.get_Sp_ref(r[0]),Proc.get_Sp_ref(r[1]),Proc.get_Sp_ref(r[2])); break; case GTRIPLE: - Proc.DataF.get_three(DATA_GF2N, DATA_TRIPLE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]),Proc.get_S2_ref(r[2])); + Proc2.DataF.get_three(DATA_TRIPLE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]),Proc.get_S2_ref(r[2])); break; case GBITTRIPLE: - Proc.DataF.get_three(DATA_GF2N, DATA_BITTRIPLE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]),Proc.get_S2_ref(r[2])); + Proc2.DataF.get_three(DATA_BITTRIPLE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]),Proc.get_S2_ref(r[2])); break; case GBITGF2NTRIPLE: - Proc.DataF.get_three(DATA_GF2N, DATA_BITGF2NTRIPLE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]),Proc.get_S2_ref(r[2])); + Proc2.DataF.get_three(DATA_BITGF2NTRIPLE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]),Proc.get_S2_ref(r[2])); break; case SQUARE: - Proc.DataF.get_two(DATA_MODP, DATA_SQUARE, Proc.get_Sp_ref(r[0]),Proc.get_Sp_ref(r[1])); + Procp.DataF.get_two(DATA_SQUARE, Proc.get_Sp_ref(r[0]),Proc.get_Sp_ref(r[1])); break; case GSQUARE: - Proc.DataF.get_two(DATA_GF2N, DATA_SQUARE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1])); + Proc2.DataF.get_two(DATA_SQUARE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1])); break; case BIT: - Proc.DataF.get_one(DATA_MODP, DATA_BIT, Proc.get_Sp_ref(r[0])); + Procp.DataF.get_one(DATA_BIT, Proc.get_Sp_ref(r[0])); break; case GBIT: - Proc.DataF.get_one(DATA_GF2N, DATA_BIT, Proc.get_S2_ref(r[0])); + Proc2.DataF.get_one(DATA_BIT, Proc.get_S2_ref(r[0])); break; case INV: - Proc.DataF.get_two(DATA_MODP, DATA_INVERSE, Proc.get_Sp_ref(r[0]),Proc.get_Sp_ref(r[1])); + Procp.DataF.get_two(DATA_INVERSE, Proc.get_Sp_ref(r[0]),Proc.get_Sp_ref(r[1])); break; case GINV: - Proc.DataF.get_two(DATA_GF2N, DATA_INVERSE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1])); + Proc2.DataF.get_two(DATA_INVERSE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1])); break; case INPUTMASK: - Proc.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, n); + Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, n); if (n == Proc.P.my_num()) Proc.temp.rrp.output(Proc.private_output, false); break; case GINPUTMASK: - Proc.DataF.get_input(Proc.get_S2_ref(r[0]), Proc.temp.ans2, n); + Proc2.DataF.get_input(Proc.get_S2_ref(r[0]), Proc.temp.ans2, n); if (n == Proc.P.my_num()) Proc.temp.ans2.output(Proc.private_output, false); break; case INPUT: - sint::Protocol::input(Proc.Procp, n, r); + sint::Input::input(Proc.Procp, start); break; case GINPUT: - { gf2n& rr=Proc.temp.rr2; gf2n& t=Proc.temp.t2; gf2n& tmp=Proc.temp.tmp2; - Proc.DataF.get_input(Proc.get_S2_ref(r[0]),rr,n); - octetStream o; - if (n==Proc.P.my_num()) - { gf2n& xi=Proc.temp.xi2; - #ifdef DEBUG - printf("Enter your input : \n"); - #endif - word x; - cin >> x; - t.assign(x); - t.sub(t,rr); - t.pack(o); - Proc.P.send_all(o); - xi.add(t,Proc.get_S2_ref(r[0]).get_share()); - Proc.get_S2_ref(r[0]).set_share(xi); - } - else - { Proc.P.receive_player(n,o); - t.unpack(o); - } - tmp.mul(Proc.MC2.get_alphai(),t); - tmp.add(Proc.get_S2_ref(r[0]).get_mac(),tmp); - Proc.get_S2_ref(r[0]).set_mac(tmp); - } + sgf2n::Input::input(Proc.Proc2, start); break; case STARTINPUT: - Proc.inputp.start(r[0],n); + Proc.Procp.input.start(r[0],n); break; case GSTARTINPUT: - Proc.input2.start(r[0],n); + Proc.Proc2.input.start(r[0],n); break; case STOPINPUT: - Proc.inputp.stop(n,start); + Proc.Procp.input.stop(n,start); break; case GSTOPINPUT: - Proc.input2.stop(n,start); + Proc.Proc2.input.stop(n,start); break; case ANDC: #ifdef DEBUG @@ -1382,7 +1375,7 @@ inline void Instruction::execute(Processor& Proc) const Proc.Procp.protocol.muls(start, Proc.Procp, Proc.MCp, size); return; case GMULS: - SPDZ::muls(start, Proc.Proc2, Proc.MC2, size); + Proc.Proc2.protocol.muls(start, Proc.Proc2, Proc.MC2, size); return; case JMP: Proc.PC += (signed int) n; @@ -1466,92 +1459,72 @@ inline void Instruction::execute(Processor& Proc) const Proc.write_Ci(r[0], Proc.read_C2(r[1]).get_word()); break; case PRINTMEM: - if (Proc.P.my_num() == 0) - { cout << "Mem[" << r[0] << "] = " << Proc.machine.Mp.read_C(r[0]) << endl; } + { Proc.out << "Mem[" << r[0] << "] = " << Proc.machine.Mp.read_C(r[0]) << endl; } break; case GPRINTMEM: - if (Proc.P.my_num() == 0) - { cout << "Mem[" << r[0] << "] = " << Proc.machine.M2.read_C(r[0]) << endl; } + { Proc.out << "Mem[" << r[0] << "] = " << Proc.machine.M2.read_C(r[0]) << endl; } break; case PRINTREG: - if (Proc.P.my_num() == 0) { - cout << "Reg[" << r[0] << "] = " << Proc.read_Cp(r[0]) + Proc.out << "Reg[" << r[0] << "] = " << Proc.read_Cp(r[0]) << " # " << string((char*)&n,sizeof(n)) << endl; } break; case GPRINTREG: - if (Proc.P.my_num() == 0) { - cout << "Reg[" << r[0] << "] = " << Proc.read_C2(r[0]) + Proc.out << "Reg[" << r[0] << "] = " << Proc.read_C2(r[0]) << " # " << string((char*)&n,sizeof(n)) << endl; } break; case PRINTREGPLAIN: - if (Proc.P.my_num() == 0) { - cout << Proc.read_Cp(r[0]) << flush; + Proc.out << Proc.read_Cp(r[0]) << flush; } break; case GPRINTREGPLAIN: - if (Proc.P.my_num() == 0) { - cout << Proc.read_C2(r[0]) << flush; + Proc.out << Proc.read_C2(r[0]) << flush; } break; case PRINTINT: - if (Proc.P.my_num() == 0) { - cout << Proc.read_Ci(r[0]) << flush; + Proc.out << Proc.read_Ci(r[0]) << flush; } break; case PRINTFLOATPLAIN: - if (Proc.P.my_num() == 0) { typename sint::clear v = Proc.read_Cp(start[0]); typename sint::clear p = Proc.read_Cp(start[1]); typename sint::clear z = Proc.read_Cp(start[2]); typename sint::clear s = Proc.read_Cp(start[3]); - to_bigint(Proc.temp.aa, v); // MPIR can't handle more precision in exponent to_signed_bigint(Proc.temp.aa2, p, 31); long exp = Proc.temp.aa2.get_si(); - mpf_class res = Proc.temp.aa; - if (exp > 0) - mpf_mul_2exp(res.get_mpf_t(), res.get_mpf_t(), exp); - else - mpf_div_2exp(res.get_mpf_t(), res.get_mpf_t(), -exp); - if (z.is_one()) - res = 0; - if (!s.is_zero()) - res *= -1; - if (not z.is_bit() or not s.is_bit()) - throw Processor_Error("invalid floating point number"); - cout << res << flush; + Proc.out << bigint::get_float(v, exp, z, s) << flush; } break; case PRINTSTR: - if (Proc.P.my_num() == 0) { - cout << string((char*)&n,sizeof(n)) << flush; + Proc.out << string((char*)&n,sizeof(n)) << flush; } break; + case CONDPRINTSTR: + if (not Proc.read_Cp(r[0]).is_zero()) + Proc.out << string((char*)&n,sizeof(n)) << flush; + break; case PRINTCHR: - if (Proc.P.my_num() == 0) { - cout << string((char*)&n,1) << flush; + Proc.out << string((char*)&n,1) << flush; } break; case PRINTCHRINT: - if (Proc.P.my_num() == 0) { - cout << string((char*)&(Proc.read_Ci(r[0])),1) << flush; + Proc.out << string((char*)&(Proc.read_Ci(r[0])),1) << flush; } break; case PRINTSTRINT: - if (Proc.P.my_num() == 0) { - cout << string((char*)&(Proc.read_Ci(r[0])),sizeof(int)) << flush; + Proc.out << string((char*)&(Proc.read_Ci(r[0])),sizeof(int)) << flush; } break; case RAND: @@ -1688,10 +1661,10 @@ inline void Instruction::execute(Processor& Proc) const Proc.privateOutput2.stop(n,r[0]); break; case PREP: - Proc.DataF.get(Proc.Procp, r, start, size); + Procp.DataF.get(Proc.Procp.get_S(), r, start, size); return; case GPREP: - Proc.DataF.get(Proc.Proc2, r, start, size); + Proc2.DataF.get(Proc.Proc2.get_S(), r, start, size); return; default: printf("Case of opcode=%d not implemented yet\n",opcode); @@ -1705,8 +1678,8 @@ inline void Instruction::execute(Processor& Proc) const } } -template -void Program::execute(Processor& Proc) const +template +void Program::execute(Processor& Proc) const { unsigned int size = p.size(); Proc.PC=0; @@ -1717,5 +1690,7 @@ void Program::execute(Processor& Proc) const { p[Proc.PC].execute(Proc); } } -template void Program::execute(Processor& Proc) const; -template void Program::execute(Processor& Proc) const; +template void Program::execute(Processor>& Proc) const; +template void Program::execute(Processor, Rep3Share>& Proc) const; +template void Program::execute(Processor, Rep3Share>& Proc) const; +template void Program::execute(Processor, MaliciousRep3Share>& Proc) const; diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 9dcceef70..96c68c65b 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -9,13 +9,12 @@ #include using namespace std; -#include "Processor/Data_Files.h" #include "Networking/Player.h" #include "Math/Integer.h" -#include "Auth/MAC_Check.h" +#include "Math/Share.h" -template class Machine; -template class Processor; +template class Machine; +template class Processor; /* * Opcode constants @@ -166,6 +165,7 @@ enum PRINTFLOATPLAIN = 0xBC, WRITEFILESHARE = 0xBD, READFILESHARE = 0xBE, + CONDPRINTSTR = 0xBF, // GF(2^n) versions @@ -272,7 +272,7 @@ enum SecrecyType { MAX_SECRECY_TYPE }; -template +template struct TempVars { gf2n ans2; Share Sans2; typename sint::clear ansp; @@ -312,9 +312,10 @@ class BaseInstruction bool is_direct_memory_access(SecrecyType sec_type) const; // Returns the maximal register used - int get_max_reg(int reg_type) const; + unsigned get_max_reg(int reg_type) const; }; +struct DataPositions; class Instruction : public BaseInstruction { @@ -326,14 +327,14 @@ class Instruction : public BaseInstruction bool get_offline_data_usage(DataPositions& usage); // Returns the memory size used if applicable and known - int get_mem(RegType reg_type, SecrecyType sec_type) const; + unsigned get_mem(RegType reg_type, SecrecyType sec_type) const; friend ostream& operator<<(ostream& s,const Instruction& instr); // Execute this instruction, updateing the processor and memory // and streams pointing to the triples etc - template - void execute(Processor& Proc) const; + template + void execute(Processor& Proc) const; }; diff --git a/Processor/Machine.cpp b/Processor/Machine.cpp index f6c210795..a8a85e508 100644 --- a/Processor/Machine.cpp +++ b/Processor/Machine.cpp @@ -1,10 +1,14 @@ #include "Machine.h" +#include "Memory.hpp" +#include "Online-Thread.hpp" + #include "Exceptions/Exceptions.h" #include #include "Math/Setup.h" +#include "Math/MaliciousRep3Share.h" #include #include @@ -31,15 +35,15 @@ BaseMachine::BaseMachine() : nthreads(0) singleton = this; } -template -Machine::Machine(int my_number, Names& playerNames, +template +Machine::Machine(int my_number, Names& playerNames, string progname_str, string memtype, int lgp, int lg2, bool direct, int opening_sum, bool parallel, bool receive_threads, int max_broadcast, - bool use_encryption) + bool use_encryption, bool live_prep, OnlineOptions opts) : my_number(my_number), N(playerNames), tn(0), numt(0), usage_unknown(false), direct(direct), opening_sum(opening_sum), parallel(parallel), receive_threads(receive_threads), max_broadcast(max_broadcast), - use_encryption(use_encryption) + use_encryption(use_encryption), live_prep(live_prep), opts(opts) { if (opening_sum < 2) this->opening_sum = N.num_players(); @@ -86,8 +90,7 @@ Machine::Machine(int my_number, Names& playerNames, } else if (memtype.compare("old")==0) { - sprintf(filename, PREP_DIR "Memory-P%d", my_number); - inpf.open(filename,ios::in | ios::binary); + inpf.open(memory_filename(), ios::in | ios::binary); if (inpf.fail()) { throw file_error(); } inpf >> M2 >> Mp >> Mi; inpf.close(); @@ -126,7 +129,7 @@ Machine::Machine(int my_number, Names& playerNames, tinfo[i].machine=this; // lock for synchronization pthread_mutex_lock(&t_mutex[i]); - pthread_create(&threads[i],NULL,thread_info::Main_Func,&tinfo[i]); + pthread_create(&threads[i],NULL,thread_info::Main_Func,&tinfo[i]); } // synchronize with clients before starting timer @@ -177,8 +180,8 @@ void BaseMachine::print_compiler() inpf.close(); } -template -void Machine::load_program(string threadname, string filename) +template +void Machine::load_program(string threadname, string filename) { ifstream pinp(filename); if (pinp.fail()) { throw file_error(filename); } @@ -191,8 +194,8 @@ void Machine::load_program(string threadname, string filename) Mi.minimum_size(INT, progs[i], threadname); } -template -DataPositions Machine::run_tape(int thread_number, int tape_number, int arg, int line_number) +template +DataPositions Machine::run_tape(int thread_number, int tape_number, int arg, int line_number) { if (thread_number >= (int)tinfo.size()) throw Processor_Error("invalid thread number: " + to_string(thread_number) + "/" + to_string(tinfo.size())); @@ -231,8 +234,8 @@ DataPositions Machine::run_tape(int thread_number, int tape_number, int ar } } -template -void Machine::join_tape(int i) +template +void Machine::join_tape(int i) { join_timer[i].start(); pthread_mutex_lock(&t_mutex[i]); @@ -243,8 +246,8 @@ void Machine::join_tape(int i) join_timer[i].stop(); } -template -void Machine::run() +template +void Machine::run() { Timer proc_timer(CLOCK_PROCESS_CPUTIME_ID); proc_timer.start(); @@ -332,16 +335,14 @@ void Machine::run() cerr << "Full broadcast" << endl; // Reduce memory size to speed up - int max_size = 1 << 20; + unsigned max_size = 1 << 20; if (M2.size_s() > max_size) M2.resize_s(max_size); if (Mp.size_s() > max_size) Mp.resize_s(max_size); // Write out the memory to use next time - char filename[1024]; - sprintf(filename,PREP_DIR "Memory-P%d",my_number); - ofstream outf(filename,ios::out | ios::binary); + ofstream outf(memory_filename(), ios::out | ios::binary); outf << M2 << Mp << Mi; outf.close(); @@ -370,7 +371,7 @@ void Machine::run() pos.print_cost(); #ifndef INSECURE - Data_Files df(N.my_num(), N.num_players(), prep_dir_prefix); + Data_Files df(*this); df.seekg(pos); df.prune(); #endif @@ -378,6 +379,12 @@ void Machine::run() cerr << "End of prog" << endl; } +template +string Machine::memory_filename() +{ + return PREP_DIR "Memory-" + sint::type_short() + "-P" + to_string(my_number); +} + void BaseMachine::load_program(string threadname, string filename) { (void)threadname; @@ -411,11 +418,13 @@ void BaseMachine::print_timers() cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl; } -template -void Machine::reqbl(int n) +template +void Machine::reqbl(int n) { - sint::Protocol::reqbl(n); + sint::clear::reqbl(n); } -template class Machine; -template class Machine; +template class Machine>; +template class Machine, Rep3Share>; +template class Machine, Rep3Share>; +template class Machine, MaliciousRep3Share>; diff --git a/Processor/Machine.h b/Processor/Machine.h index eab3f44dd..738f2c816 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -6,11 +6,12 @@ #ifndef MACHINE_H_ #define MACHINE_H_ +#include "Processor/BaseMachine.h" #include "Processor/Memory.h" #include "Processor/Program.h" +#include "Processor/OnlineOptions.h" #include "Processor/Online-Thread.h" -#include "Processor/Data_Files.h" #include "Math/gfp.h" #include "Tools/time-func.h" @@ -19,39 +20,7 @@ #include using namespace std; -class BaseMachine -{ -protected: - static BaseMachine* singleton; - - std::map timer; - - ifstream inpf; - - void print_timers(); - - virtual void load_program(string threadname, string filename); - -public: - string progname; - int nthreads; - - static BaseMachine& s(); - - BaseMachine(); - virtual ~BaseMachine() {} - - void load_schedule(string progname); - void print_compiler(); - - void time(); - void start(int n); - void stop(int n); - - virtual void reqbl(int n) { (void)n; } -}; - -template +template class Machine : public BaseMachine { /* The mutex's lock the C-threads and then only release @@ -60,12 +29,12 @@ class Machine : public BaseMachine * MPC thread releases the mutex */ - vector> tinfo; + vector> tinfo; vector threads; int my_number; Names& N; - gfp alphapi; + typename sint::value_type alphapi; gf2n alpha2i; // Keep record of used offline data @@ -98,17 +67,25 @@ class Machine : public BaseMachine bool receive_threads; int max_broadcast; bool use_encryption; + bool live_prep; + + OnlineOptions opts; Machine(int my_number, Names& playerNames, string progname, string memtype, int lgp, int lg2, bool direct, int opening_sum, bool parallel, - bool receive_threads, int max_broadcast, bool use_encryption); + bool receive_threads, int max_broadcast, bool use_encryption, bool live_prep, + OnlineOptions opts); + + const Names& get_N() { return N; } DataPositions run_tape(int thread_number, int tape_number, int arg, int line_number); void join_tape(int thread_number); void run(); + string memory_filename(); + // Only for Player-Demo.cpp - Machine(): N(*(new Names())) {} + Machine(Names& N = *(new Names())): N(N) {} void reqbl(int n); }; diff --git a/Processor/MaliciousRepPrep.h b/Processor/MaliciousRepPrep.h new file mode 100644 index 000000000..8fb40734c --- /dev/null +++ b/Processor/MaliciousRepPrep.h @@ -0,0 +1,47 @@ +/* + * MaliciousRepPrep.h + * + */ + +#ifndef PROCESSOR_MALICIOUSREPPREP_H_ +#define PROCESSOR_MALICIOUSREPPREP_H_ + +#include "Data_Files.h" +#include "ReplicatedPrep.h" +#include "Math/MaliciousRep3Share.h" +#include "Auth/MaliciousRepMC.h" + +#include + +template +class MaliciousRepPrep : public BufferPrep> +{ + typedef MaliciousRep3Share T; + typedef BufferPrep> super; + + ReplicatedPrep> honest_prep; + Replicated>* replicated; + HashMaliciousRepMC MC; + + vector masked; + vector checks; + vector opened; + + vector> check_triples; + vector> check_squares; + + void clear_tmp(); + + void buffer_triples(); + void buffer_squares(); + void buffer_inverses(); + void buffer_bits(); + +public: + MaliciousRepPrep(); + ~MaliciousRepPrep(); + + void set_protocol(Beaver& protocol); +}; + +#endif /* PROCESSOR_MALICIOUSREPPREP_H_ */ diff --git a/Processor/MaliciousRepPrep.hpp b/Processor/MaliciousRepPrep.hpp new file mode 100644 index 000000000..2f43983be --- /dev/null +++ b/Processor/MaliciousRepPrep.hpp @@ -0,0 +1,165 @@ +/* + * MaliciousRepPrep.cpp + * + */ + +#include "MaliciousRepPrep.h" +#include "Auth/Subroutines.h" +#include "Auth/MaliciousRepMC.hpp" + +template +MaliciousRepPrep::MaliciousRepPrep() : replicated(0) +{ +} + +template +MaliciousRepPrep::~MaliciousRepPrep() +{ + if (replicated) + delete replicated; +} + +template +void MaliciousRepPrep::set_protocol(Beaver& protocol) +{ + replicated = new Replicated>(protocol.P); + honest_prep.set_protocol(*replicated); +} + +template +void MaliciousRepPrep::clear_tmp() +{ + masked.clear(); + checks.clear(); + check_triples.clear(); + check_squares.clear(); +} + +template +void MaliciousRepPrep::buffer_triples() +{ + auto& triples = this->triples; + auto& buffer_size = this->buffer_size; + clear_tmp(); + Player& P = honest_prep.protocol->P; + triples.clear(); + for (int i = 0; i < buffer_size; i++) + { + T a, b, c; + T f, g, h; + honest_prep.get_three(DATA_TRIPLE, a, b, c); + honest_prep.get_three(DATA_TRIPLE, f, g, h); + triples.push_back({a, b, c}); + check_triples.push_back({f, g, h}); + } + auto t = Create_Random(P); + for (int i = 0; i < buffer_size; i++) + { + T& a = triples[i][0]; + T& b = triples[i][1]; + T& f = check_triples[i][0]; + T& g = check_triples[i][1]; + masked.push_back(a * t - f); + masked.push_back(b - g); + } + MC.POpen(opened, masked, P); + for (int i = 0; i < buffer_size; i++) + { + T& b = triples[i][1]; + T& c = triples[i][2]; + T& f = check_triples[i][0]; + T& h = check_triples[i][2]; + typename T::clear& rho = opened[2 * i]; + typename T::clear& sigma = opened[2 * i + 1]; + checks.push_back(t * c - h - rho * b - sigma * f); + } + MC.POpen(opened, checks, P); + for (auto& check : opened) + if (check != 0) + throw Offline_Check_Error("triple"); + MC.Check(P); +} + +template +void MaliciousRepPrep::buffer_squares() +{ + auto& squares = this->squares; + auto& buffer_size = this->buffer_size; + clear_tmp(); + Player& P = honest_prep.protocol->P; + squares.clear(); + for (int i = 0; i < buffer_size; i++) + { + T a, b; + T f, h; + honest_prep.get_two(DATA_SQUARE, a, b); + honest_prep.get_two(DATA_SQUARE, f, h); + squares.push_back({a, b}); + check_squares.push_back({f, h}); + } + auto t = Create_Random(P); + for (int i = 0; i < buffer_size; i++) + { + T& a = squares[i][0]; + T& f = check_squares[i][0]; + masked.push_back(a * t - f); + } + MC.POpen(opened, masked, P); + for (int i = 0; i < buffer_size; i++) + { + T& a = squares[i][0]; + T& b = squares[i][1]; + T& f = check_squares[i][0]; + T& h = check_squares[i][1]; + auto& rho = opened[i]; + checks.push_back(t * t * b - h - rho * (t * a + f)); + } + MC.POpen(opened, checks, P); + for (auto& check : opened) + if (check != 0) + throw Offline_Check_Error("square"); +} + +template +void MaliciousRepPrep::buffer_inverses() +{ + BufferPrep::buffer_inverses(MC, honest_prep.protocol->P); +} + +template +void MaliciousRepPrep::buffer_bits() +{ + auto& bits = this->bits; + auto& buffer_size = this->buffer_size; + clear_tmp(); + Player& P = honest_prep.protocol->P; + bits.clear(); + for (int i = 0; i < buffer_size; i++) + { + T a, f, h; + honest_prep.get_one(DATA_BIT, a); + honest_prep.get_two(DATA_SQUARE, f, h); + bits.push_back(a); + check_squares.push_back({f, h}); + } + auto t = Create_Random(P); + for (int i = 0; i < buffer_size; i++) + { + T& a = bits[i]; + T& f = check_squares[i][0]; + masked.push_back(t * a - f); + } + MC.POpen(opened, masked, P); + for (int i = 0; i < buffer_size; i++) + { + T& a = bits[i]; + T& f = check_squares[i][0]; + T& h = check_squares[i][1]; + auto& rho = opened[i]; + masked.push_back(t * t * a - h - rho * (t * a + f)); + } + MC.POpen(opened, checks, P); + for (auto& check : opened) + if (check != 0) + throw Offline_Check_Error("bit"); +} diff --git a/Processor/Memory.h b/Processor/Memory.h index fd77050f1..99a9c5206 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -31,9 +31,9 @@ class Memory void resize_c(int sz) { MC.resize(sz); } - int size_s() + unsigned size_s() { return MS.size(); } - int size_c() + unsigned size_c() { return MC.size(); } const typename T::clear& read_C(int i) const diff --git a/Processor/Memory.cpp b/Processor/Memory.hpp similarity index 81% rename from Processor/Memory.cpp rename to Processor/Memory.hpp index 65814fe4b..2d0b9ce1e 100644 --- a/Processor/Memory.cpp +++ b/Processor/Memory.hpp @@ -9,7 +9,7 @@ template void Memory::minimum_size(RegType reg_type, const Program& program, string threadname) { - const int* sizes = program.direct_mem(reg_type); + const unsigned* sizes = program.direct_mem(reg_type); if (sizes[SECRET] > size_s()) { cerr << threadname << " needs more secret " << T::type_string() << " memory, resizing to " @@ -141,18 +141,3 @@ void Memory::Load_Memory(ifstream& inpf) S.input(inpf,true); } } - -template class Memory; -template class Memory; -template class Memory; -template class Memory; - -template istream& operator>>(istream& s,Memory& M); -template istream& operator>>(istream& s,Memory& M); -template istream& operator>>(istream& s,Memory& M); -template istream& operator>>(istream& s,Memory& M); - -template ostream& operator<<(ostream& s,const Memory& M); -template ostream& operator<<(ostream& s,const Memory& M); -template ostream& operator<<(ostream& s,const Memory& M); -template ostream& operator<<(ostream& s,const Memory& M); diff --git a/Processor/Online-Thread.h b/Processor/Online-Thread.h index c4f837029..240e7c3e4 100644 --- a/Processor/Online-Thread.h +++ b/Processor/Online-Thread.h @@ -10,9 +10,9 @@ #include using namespace std; -template class Machine; +template class Machine; -template +template class thread_info { public: @@ -21,7 +21,7 @@ class thread_info int covert; Names* Nms; gf2n *alpha2i; - gfp *alphapi; + typename sint::value_type *alphapi; int prognum; bool finished; bool ready; @@ -31,11 +31,11 @@ class thread_info // Integer arg (optional) int arg; - Machine* machine; + Machine* machine; static void* Main_Func(void *ptr); - static void purge_preprocessing(Names& N, string prep_dir); + static void purge_preprocessing(Machine& machine); }; #endif diff --git a/Processor/Online-Thread.cpp b/Processor/Online-Thread.hpp similarity index 82% rename from Processor/Online-Thread.cpp rename to Processor/Online-Thread.hpp index e3119bc04..4c1d31ef6 100644 --- a/Processor/Online-Thread.cpp +++ b/Processor/Online-Thread.hpp @@ -5,19 +5,24 @@ #include "Processor/Data_Files.h" #include "Processor/Machine.h" #include "Processor/Processor.h" +#include "Auth/ReplicatedMC.h" #include "Networking/CryptoPlayer.h" +#include "Processor/Processor.hpp" +#include "Processor/Input.hpp" +#include "Auth/MaliciousRepMC.hpp" + #include #include #include using namespace std; -template +template void* Sub_Main_Func(void* ptr) { - thread_info *tinfo=(thread_info *) ptr; - Machine& machine=*(tinfo->machine); + thread_info *tinfo=(thread_info *) ptr; + Machine& machine=*(tinfo->machine); vector& t_mutex = machine.t_mutex; vector& client_ready = machine.client_ready; vector& server_ready = machine.server_ready; @@ -44,34 +49,34 @@ void* Sub_Main_Func(void* ptr) Player& P = *player; fprintf(stderr, "\tSet up player in thread %d\n",num); - Data_Files DataF(P.my_num(),P.num_players(),machine.prep_dir_prefix); + Data_Files DataF(machine); - MAC_Check* MC2; + typename sgf2n::MAC_Check* MC2; typename sint::MAC_Check* MCp; // Use MAC_Check instead for more than 10000 openings at once if (machine.direct) { cerr << "Using direct communication. If computation stalls, use -m when compiling." << endl; - MC2 = new Direct_MAC_Check(*(tinfo->alpha2i),*(tinfo->Nms), num); + MC2 = new typename sgf2n::Direct_MC(*(tinfo->alpha2i),*(tinfo->Nms), num); MCp = new typename sint::Direct_MC(*(tinfo->alphapi),*(tinfo->Nms), num); } else if (machine.parallel) { cerr << "Using indirect communication with background threads." << endl; - MC2 = new Parallel_MAC_Check(*(tinfo->alpha2i),*(tinfo->Nms), num, machine.opening_sum, machine.max_broadcast); + //MC2 = new Parallel_MAC_Check(*(tinfo->alpha2i),*(tinfo->Nms), num, machine.opening_sum, machine.max_broadcast); //MCp = new Parallel_MAC_Check(*(tinfo->alphapi),*(tinfo->Nms), num, machine.opening_sum, machine.max_broadcast); throw not_implemented(); } else { cerr << "Using indirect communication." << endl; - MC2 = new MAC_Check(*(tinfo->alpha2i), machine.opening_sum, machine.max_broadcast); + MC2 = new typename sgf2n::MAC_Check(*(tinfo->alpha2i), machine.opening_sum, machine.max_broadcast); MCp = new typename sint::MAC_Check(*(tinfo->alphapi), machine.opening_sum, machine.max_broadcast); } // Allocate memory for first program before starting the clock - Processor Proc(tinfo->thread_num,DataF,P,*MC2,*MCp,machine,progs[0]); + Processor Proc(tinfo->thread_num,DataF,P,*MC2,*MCp,machine,progs[0]); Share a,b,c; bool flag=true; @@ -177,21 +182,20 @@ void* Sub_Main_Func(void* ptr) } -template -void* thread_info::Main_Func(void* ptr) +template +void* thread_info::Main_Func(void* ptr) { #ifndef INSECURE try #endif { - Sub_Main_Func(ptr); + Sub_Main_Func(ptr); } #ifndef INSECURE catch (...) { - thread_info* ti = (thread_info*)ptr; - purge_preprocessing(*ti->Nms, - ti->machine->prep_dir_prefix); + thread_info* ti = (thread_info*)ptr; + ti->purge_preprocessing(*ti->machine); throw; } #endif @@ -199,13 +203,13 @@ void* thread_info::Main_Func(void* ptr) } -template -void thread_info::purge_preprocessing(Names& N, string prep_dir) +template +void thread_info::purge_preprocessing(Machine& machine) { cerr << "Purging preprocessed data because something is wrong" << endl; try { - Data_Files df(N, prep_dir); + Data_Files df(machine); df.purge(); } catch(...) @@ -214,7 +218,3 @@ void thread_info::purge_preprocessing(Names& N, string prep_dir) << "SECURITY FAILURE; YOU ARE ON YOUR OWN NOW!" << endl; } } - - -template class thread_info; -template class thread_info; diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp new file mode 100644 index 000000000..5f140cef7 --- /dev/null +++ b/Processor/OnlineOptions.cpp @@ -0,0 +1,31 @@ +/* + * OnlineOptions.cpp + * + */ + +#include "OnlineOptions.h" + +OnlineOptions::OnlineOptions() +{ + interactive = false; +} + +OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, + const char** argv) +{ + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Interactive mode in the main thread (default: disabled)", // Help description. + "-I", // Flag token. + "--interactive" // Flag token. + ); + + opt.parse(argc, argv); + + interactive = opt.isSet("-I"); + + opt.resetArgs(); +} diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h new file mode 100644 index 000000000..583f7a6d4 --- /dev/null +++ b/Processor/OnlineOptions.h @@ -0,0 +1,20 @@ +/* + * OnlineOptions.h + * + */ + +#ifndef PROCESSOR_ONLINEOPTIONS_H_ +#define PROCESSOR_ONLINEOPTIONS_H_ + +#include "Tools/ezOptionParser.h" + +class OnlineOptions +{ +public: + bool interactive; + + OnlineOptions(); + OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv); +}; + +#endif /* PROCESSOR_ONLINEOPTIONS_H_ */ diff --git a/Processor/Processor.h b/Processor/Processor.h index 137447043..87df4b958 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -12,8 +12,6 @@ #include "Math/Integer.h" #include "Exceptions/Exceptions.h" #include "Networking/Player.h" -#include "Auth/MAC_Check.h" -#include "Auth/ReplicatedMC.h" #include "Data_Files.h" #include "Input.h" #include "ReplicatedInput.h" @@ -25,32 +23,8 @@ #include "Instruction.h" #include "SPDZ.h" #include "Replicated.h" - -#include - -class ProcessorBase -{ - // Stack - stack stacki; - -protected: - // Optional argument to tape - int arg; - -public: - void pushi(long x) { stacki.push(x); } - void popi(long& x) { x = stacki.top(); stacki.pop(); } - - int get_arg() const - { - return arg; - } - - void set_arg(int new_arg) - { - arg=new_arg; - } -}; +#include "ProcessorBase.h" +#include "Tools/SwitchableOutput.h" template class SubProcessor @@ -65,20 +39,22 @@ class SubProcessor void resize(int size) { C.resize(size); S.resize(size); } - template friend class Processor; + template friend class Processor; template friend class SPDZ; template friend class Replicated; + template friend class Beaver; public: ArithmeticProcessor& Proc; typename T::MAC_Check& MC; Player& P; - Sub_Data_Files& DataF; + Preprocessing& DataF; typename T::Protocol protocol; + typename T::Input input; SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check& MC, - Sub_Data_Files& DataF, Player& P); + Preprocessing& DataF, Player& P); // Access to PO (via calls to POpen start/stop) void POpen_Start(const vector& reg,const Player& P,int size); @@ -87,6 +63,11 @@ class SubProcessor void muls(const vector& reg,const Player& P,int size); + vector& get_S() + { + return S; + } + T& get_S_ref(int i) { return S[i]; @@ -101,6 +82,8 @@ class SubProcessor class ArithmeticProcessor : public ProcessorBase { public: + int thread_num; + PRNG secure_prng; string private_input_filename; @@ -112,16 +95,18 @@ class ArithmeticProcessor : public ProcessorBase int sent, rounds; - ArithmeticProcessor() : sent(0), rounds(0) {} + OnlineOptions opts; + + ArithmeticProcessor(OnlineOptions opts, int thread_num) : thread_num(thread_num), + sent(0), rounds(0), opts(opts) {} }; -template +template class Processor : public ArithmeticProcessor { vector Ci; int reg_max2,reg_maxp,reg_maxi; - int thread_num; // Data structure used for reading/writing data to/from a socket (i.e. an external party to SPDZ) octetStream socket_stream; @@ -143,23 +128,20 @@ class Processor : public ArithmeticProcessor vector& get_PO(); public: - Data_Files& DataF; + Data_Files& DataF; Player& P; - MAC_Check& MC2; + typename sgf2n::MAC_Check& MC2; typename sint::MAC_Check& MCp; - Machine& machine; + Machine& machine; SubProcessor Proc2; SubProcessor Procp; - Input input2; - typename sint::Input inputp; - - PrivateOutput privateOutput2; + typename sgf2n::PrivateOutput privateOutput2; typename sint::PrivateOutput privateOutputp; unsigned int PC; - TempVars temp; + TempVars temp; PRNG shared_prng; @@ -169,13 +151,16 @@ class Processor : public ArithmeticProcessor // avoid re-computation of expensive division map inverses2m; + SwitchableOutput out; + static const int reg_bytes = 4; void reset(const Program& program,int arg); // Reset the state of the processor string get_filename(const char* basename, bool use_number); - Processor(int thread_num,Data_Files& DataF,Player& P, - MAC_Check& MC2,typename sint::MAC_Check& MCp,Machine& machine, + Processor(int thread_num,Data_Files& DataF,Player& P, + typename sgf2n::MAC_Check& MC2,typename sint::MAC_Check& MCp, + Machine& machine, const Program& program); ~Processor(); @@ -255,15 +240,15 @@ class Processor : public ArithmeticProcessor #else const gf2n& read_C2(int i) const { return Proc2.C[i]; } - const Share & read_S2(int i) const + const sgf2n& read_S2(int i) const { return Proc2.S[i]; } gf2n& get_C2_ref(int i) { return Proc2.C[i]; } - Share & get_S2_ref(int i) + sgf2n& get_S2_ref(int i) { return Proc2.S[i]; } void write_C2(int i,const gf2n& x) { Proc2.C[i]=x; } - void write_S2(int i,const Share & x) + void write_S2(int i,const sgf2n& x) { Proc2.S[i]=x; } const typename sint::clear& read_Cp(int i) const @@ -307,8 +292,8 @@ class Processor : public ArithmeticProcessor void write_shares_to_file(const vector& data_registers); // Print the processor state - template - friend ostream& operator<<(ostream& s,const Processor& P); + template + friend ostream& operator<<(ostream& s,const Processor& P); private: void maybe_decrypt_sequence(int client_id); diff --git a/Processor/Processor.cpp b/Processor/Processor.hpp similarity index 83% rename from Processor/Processor.cpp rename to Processor/Processor.hpp index fb622d0d9..89985876a 100644 --- a/Processor/Processor.cpp +++ b/Processor/Processor.hpp @@ -2,26 +2,34 @@ #include "Processor/Processor.h" #include "Networking/STS.h" #include "Auth/MAC_Check.h" - +#include "Auth/ReplicatedMC.h" #include "Auth/fake-stuff.h" + +#include "Processor/ReplicatedInput.hpp" +#include "Processor/ReplicatedPrivateOutput.hpp" + #include #include template SubProcessor::SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check& MC, - Sub_Data_Files& DataF, Player& P) : - Proc(Proc), MC(MC), P(P), DataF(DataF), protocol(P) + Preprocessing& DataF, Player& P) : + Proc(Proc), MC(MC), P(P), DataF(DataF), protocol(P), input(*this, MC) { + DataF.set_protocol(protocol); } -template -Processor::Processor(int thread_num,Data_Files& DataF,Player& P, - MAC_Check& MC2,typename sint::MAC_Check& MCp,Machine& machine, +template +Processor::Processor(int thread_num,Data_Files& DataF,Player& P, + typename sgf2n::MAC_Check& MC2,typename sint::MAC_Check& MCp, + Machine& machine, const Program& program) -: thread_num(thread_num),DataF(DataF),P(P),MC2(MC2),MCp(MCp),machine(machine), +: ArithmeticProcessor(machine.opts, thread_num),DataF(DataF),P(P), + MC2(MC2),MCp(MCp),machine(machine), Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P), - input2(Proc2,MC2),inputp(Procp,MCp),privateOutput2(Proc2),privateOutputp(Procp), - external_clients(ExternalClients(P.my_num(), DataF.prep_data_dir)),binary_file_io(Binary_File_IO()) + privateOutput2(Proc2),privateOutputp(Procp), + external_clients(ExternalClients(P.my_num(), machine.prep_dir_prefix)), + binary_file_io(Binary_File_IO()) { reset(program,0); @@ -31,18 +39,22 @@ Processor::Processor(int thread_num,Data_Files& DataF,Player& P, public_output.open(get_filename(PREP_DIR "Public-Output-",true).c_str(), ios_base::out); private_output.open(get_filename(PREP_DIR "Private-Output-",true).c_str(), ios_base::out); + open_input_file(P.my_num(), thread_num); + secure_prng.ReSeed(); + + out.activate(P.my_num() == 0 or machine.opts.interactive); } -template -Processor::~Processor() +template +Processor::~Processor() { cerr << "Sent " << sent << " elements in " << rounds << " rounds" << endl; } -template -string Processor::get_filename(const char* prefix, bool use_number) +template +string Processor::get_filename(const char* prefix, bool use_number) { stringstream filename; filename << prefix; @@ -57,8 +69,8 @@ string Processor::get_filename(const char* prefix, bool use_number) } -template -void Processor::reset(const Program& program,int arg) +template +void Processor::reset(const Program& program,int arg) { reg_max2 = program.num_reg(GF2N); reg_maxp = program.num_reg(MODP); @@ -86,8 +98,8 @@ void Processor::reset(const Program& program,int arg) // If message_type is > 0, send message_type in bytes 0 - 3, to allow an external client to // determine the data structure being sent in a message. // Encryption is enabled if key material (for DH Auth Encryption and/or STS protocol) has been already setup. -template -void Processor::write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs, +template +void Processor::write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs, int socket_id, int message_type, const vector& registers) { if (socket_id >= (int)external_clients.external_client_sockets.size()) @@ -144,8 +156,8 @@ void Processor::write_socket(const RegType reg_type, const SecrecyType sec // Receive vector of 32-bit clear ints -template -void Processor::read_socket_ints(int client_id, const vector& registers) +template +void Processor::read_socket_ints(int client_id, const vector& registers) { if (client_id >= (int)external_clients.external_client_sockets.size()) { @@ -166,8 +178,8 @@ void Processor::read_socket_ints(int client_id, const vector& registe } // Receive vector of public field elements -template -void Processor::read_socket_vector(int client_id, const vector& registers) +template +void Processor::read_socket_vector(int client_id, const vector& registers) { if (client_id >= (int)external_clients.external_client_sockets.size()) { @@ -186,8 +198,8 @@ void Processor::read_socket_vector(int client_id, const vector& regis } // Receive vector of field element shares over private channel -template -void Processor::read_socket_private(int client_id, const vector& registers, bool read_macs) +template +void Processor::read_socket_private(int client_id, const vector& registers, bool read_macs) { if (client_id >= (int)external_clients.external_client_sockets.size()) { @@ -211,8 +223,8 @@ void Processor::read_socket_private(int client_id, const vector& regi } // Read socket for client public key as 8 ints, calculate session key for client. -template -void Processor::read_client_public_key(int client_id, const vector& registers) { +template +void Processor::read_client_public_key(int client_id, const vector& registers) { read_socket_ints(client_id, registers); @@ -225,8 +237,8 @@ void Processor::read_client_public_key(int client_id, const vector& r external_clients.generate_session_key_for_client(client_id, client_public_key); } -template -void Processor::init_secure_socket_internal(int client_id, const vector& registers) { +template +void Processor::init_secure_socket_internal(int client_id, const vector& registers) { external_clients.symmetric_client_commsec_send_keys.erase(client_id); external_clients.symmetric_client_commsec_recv_keys.erase(client_id); unsigned char client_public_bytes[crypto_sign_PUBLICKEYBYTES]; @@ -276,8 +288,8 @@ void Processor::init_secure_socket_internal(int client_id, const vector -void Processor::init_secure_socket(int client_id, const vector& registers) { +template +void Processor::init_secure_socket(int client_id, const vector& registers) { try { init_secure_socket_internal(client_id, registers); @@ -287,8 +299,8 @@ void Processor::init_secure_socket(int client_id, const vector& regis } } -template -void Processor::resp_secure_socket(int client_id, const vector& registers) { +template +void Processor::resp_secure_socket(int client_id, const vector& registers) { try { resp_secure_socket_internal(client_id, registers); } catch (char const *e) { @@ -297,8 +309,8 @@ void Processor::resp_secure_socket(int client_id, const vector& regis } } -template -void Processor::resp_secure_socket_internal(int client_id, const vector& registers) { +template +void Processor::resp_secure_socket_internal(int client_id, const vector& registers) { external_clients.symmetric_client_commsec_send_keys.erase(client_id); external_clients.symmetric_client_commsec_recv_keys.erase(client_id); unsigned char client_public_bytes[crypto_sign_PUBLICKEYBYTES]; @@ -351,8 +363,8 @@ void Processor::resp_secure_socket_internal(int client_id, const vector -void Processor::read_shares_from_file(int start_file_posn, int end_file_pos_register, const vector& data_registers) { +template +void Processor::read_shares_from_file(int start_file_posn, int end_file_pos_register, const vector& data_registers) { string filename; filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data"; @@ -379,8 +391,8 @@ void Processor::read_shares_from_file(int start_file_posn, int end_file_po } // Append share data in data_registers to end of file. Expects Persistence directory to exist. -template -void Processor::write_shares_to_file(const vector& data_registers) { +template +void Processor::write_shares_to_file(const vector& data_registers) { string filename; filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data"; @@ -451,7 +463,7 @@ void SubProcessor::POpen_Stop(const vector& reg,const Player& P,int size Proc.rounds++; } -void unzip_open(vector& dest, vector& source, const vector& reg) +inline void unzip_open(vector& dest, vector& source, const vector& reg) { int n = reg.size() / 2; source.resize(n); @@ -473,8 +485,8 @@ void SubProcessor::POpen(const vector& reg, const Player& P, POpen_Stop(dest, P, size); } -template -ostream& operator<<(ostream& s,const Processor& P) +template +ostream& operator<<(ostream& s,const Processor& P) { s << "Processor State" << endl; s << "Char 2 Registers" << endl; @@ -499,8 +511,8 @@ ostream& operator<<(ostream& s,const Processor& P) return s; } -template -void Processor::maybe_decrypt_sequence(int client_id) +template +void Processor::maybe_decrypt_sequence(int client_id) { map,uint64_t> >::iterator it_cs = external_clients.symmetric_client_commsec_recv_keys.find(client_id); if (it_cs != external_clients.symmetric_client_commsec_recv_keys.end()) @@ -510,8 +522,8 @@ void Processor::maybe_decrypt_sequence(int client_id) } } -template -void Processor::maybe_encrypt_sequence(int client_id) +template +void Processor::maybe_encrypt_sequence(int client_id) { map,uint64_t> >::iterator it_cs = external_clients.symmetric_client_commsec_send_keys.find(client_id); if (it_cs != external_clients.symmetric_client_commsec_send_keys.end()) @@ -520,10 +532,3 @@ void Processor::maybe_encrypt_sequence(int client_id) it_cs->second.second++; } } - -template class SubProcessor; -template class SubProcessor; -template class SubProcessor; - -template class Processor; -template class Processor; diff --git a/Processor/ProcessorBase.cpp b/Processor/ProcessorBase.cpp new file mode 100644 index 000000000..2b6d1f1c1 --- /dev/null +++ b/Processor/ProcessorBase.cpp @@ -0,0 +1,41 @@ +/* + * ProcessorBase.cpp + * + */ + +#include "ProcessorBase.h" +#include "Exceptions/Exceptions.h" + +#include + +void ProcessorBase::open_input_file(const string& name) +{ + cerr << "opening " << name << endl; + input_file.open(name); + input_filename = name; +} + +void ProcessorBase::open_input_file(int my_num, int thread_num) +{ + string input_file = "Player-Data/Input-P" + to_string(my_num) + "-" + to_string(thread_num); + open_input_file(input_file); +} + +long long ProcessorBase::get_input(bool interactive) +{ + if (interactive) + return get_input(cin, "standard input"); + else + return get_input(input_file, input_filename); +} + +long long ProcessorBase::get_input(istream& input_file, const string& input_filename) +{ + long long res; + input_file >> res; + if (input_file.eof()) + throw IO_Error("not enough inputs in " + input_filename); + if (input_file.fail()) + throw IO_Error("cannot read from " + input_filename); + return res; +} diff --git a/Processor/ProcessorBase.h b/Processor/ProcessorBase.h new file mode 100644 index 000000000..b1d8e3eb8 --- /dev/null +++ b/Processor/ProcessorBase.h @@ -0,0 +1,47 @@ +/* + * ProcessorBase.h + * + */ + +#ifndef PROCESSOR_PROCESSORBASE_H_ +#define PROCESSOR_PROCESSORBASE_H_ + +#include +#include +#include +using namespace std; + +class ProcessorBase +{ + // Stack + stack stacki; + + ifstream input_file; + string input_filename; + +protected: + // Optional argument to tape + int arg; + +public: + void pushi(long x) { stacki.push(x); } + void popi(long& x) { x = stacki.top(); stacki.pop(); } + + int get_arg() const + { + return arg; + } + + void set_arg(int new_arg) + { + arg=new_arg; + } + + void open_input_file(const string& name); + void open_input_file(int my_num, int thread_num); + + long long get_input(bool interactive); + long long get_input(istream& is, const string& input_filename); +}; + +#endif /* PROCESSOR_PROCESSORBASE_H_ */ diff --git a/Processor/Program.h b/Processor/Program.h index 8b8cb6d23..dc5d98d14 100644 --- a/Processor/Program.h +++ b/Processor/Program.h @@ -4,7 +4,7 @@ #include "Processor/Instruction.h" #include "Processor/Data_Files.h" -template class Machine; +template class Machine; /* A program is a vector of instructions */ @@ -18,10 +18,10 @@ class Program DataPositions offline_data_used; // Maximal register used - int max_reg[MAX_REG_TYPE]; + unsigned max_reg[MAX_REG_TYPE]; // Memory size used directly - int max_mem[MAX_REG_TYPE][MAX_SECRECY_TYPE]; + unsigned max_mem[MAX_REG_TYPE][MAX_SECRECY_TYPE]; // True if program contains variable-sized loop bool unknown_usage; @@ -45,15 +45,15 @@ class Program int num_reg(RegType reg_type) const { return max_reg[reg_type]; } - const int* direct_mem(RegType reg_type) const + const unsigned* direct_mem(RegType reg_type) const { return max_mem[reg_type]; } friend ostream& operator<<(ostream& s,const Program& P); // Execute this program, updateing the processor and memory // and streams pointing to the triples etc - template - void execute(Processor& Proc) const; + template + void execute(Processor& Proc) const; }; diff --git a/Processor/Replicated.cpp b/Processor/Replicated.cpp index 88ce5bcf7..87314c101 100644 --- a/Processor/Replicated.cpp +++ b/Processor/Replicated.cpp @@ -7,16 +7,19 @@ #include "Processor.h" #include "Math/FixedVec.h" #include "Math/Integer.h" +#include "Math/MaliciousRep3Share.h" #include "Tools/benchmarking.h" #include "GC/ReplicatedSecret.h" template Replicated::Replicated(Player& P) : ReplicatedBase(P), counter(0) { + assert(T::length == 2); } -ReplicatedBase::ReplicatedBase(Player& P) +ReplicatedBase::ReplicatedBase(Player& P) : P(P) { + assert(P.num_players() == 3); if (not P.is_encrypted()) insecure("unencrypted communication"); @@ -31,7 +34,7 @@ ReplicatedBase::ReplicatedBase(Player& P) template inline Replicated::~Replicated() { - cout << "Number of multiplications: " << counter << endl; + cerr << "Number of multiplications: " << counter << endl; } template @@ -43,60 +46,75 @@ void Replicated::muls(const vector& reg, assert(reg.size() % 3 == 0); int n = reg.size() / 3; - os.resize(2); - for (auto& o : os) - o.reset_write_head(); - results.resize(n * size); + init_mul(); for (int i = 0; i < n; i++) for (int j = 0; j < size; j++) { auto& x = proc.S[reg[3 * i + 1] + j]; auto& y = proc.S[reg[3 * i + 2] + j]; - typename T::value_type add_share = x[0] * y.sum() + x[1] * y[0]; - typename T::value_type tmp[2]; - for (int i = 0; i < 2; i++) - tmp[i].randomize(shared_prngs[i]); - add_share += tmp[0] - tmp[1]; - add_share.pack(os[0]); - auto& result = results[i * size + j]; - result[0] = add_share; + prepare_mul(x, y); } - proc.P.send_relative(1, os[0]); - proc.P.receive_relative(- 1, os[0]); + exchange(); for (int i = 0; i < n; i++) for (int j = 0; j < size; j++) { - auto& result = results[i * size + j]; - result[1].unpack(os[0]); - proc.S[reg[3 * i] + j] = result; + proc.S[reg[3 * i] + j] = finalize_mul(); } counter += n * size; } -template<> -void Replicated::reqbl(int n) +template +void Replicated::init_mul() +{ + os.resize(2); + for (auto& o : os) + o.reset_write_head(); + add_shares.clear(); +} + +template +typename T::clear Replicated::prepare_mul(const T& x, + const T& y) +{ + typename T::value_type add_share = x[0] * y.sum() + x[1] * y[0]; + typename T::value_type tmp[2]; + for (int i = 0; i < 2; i++) + tmp[i].randomize(shared_prngs[i]); + add_share += tmp[0] - tmp[1]; + add_share.pack(os[0]); + add_shares.push_back(add_share); + return add_share; +} + +template +void Replicated::exchange() +{ + P.send_relative(1, os[0]); + P.receive_relative(- 1, os[0]); +} + +template +T Replicated::finalize_mul() { - if ((int)n < 0 && Integer::size() * 8 != -(int)n) - { - throw Processor_Error( - "Program compiled for rings of length " + to_string(-(int)n) - + " but VM supports only " - + to_string(Integer::size() * 8)); - } - else if ((int)n > 0) - { - throw Processor_Error("Program compiled for fields not rings"); - } + T result; + result[0] = add_shares.front(); + add_shares.pop_front(); + result[1].unpack(os[0]); + return result; } template -inline void Replicated::input(SubProcessor& Proc, int n, int* r) +T Replicated::get_random() { - (void)Proc; - (void)n; - (void)r; - throw not_implemented(); + T res; + for (int i = 0; i < 2; i++) + res[i].randomize(shared_prngs[i]); + return res; } -template class Replicated; +template class Replicated>; +template class Replicated>; +template class Replicated>; +template class Replicated>; +template class Replicated>; diff --git a/Processor/Replicated.h b/Processor/Replicated.h index 81932f376..33c27a07e 100644 --- a/Processor/Replicated.h +++ b/Processor/Replicated.h @@ -19,21 +19,23 @@ template class ReplicatedMC; template class ReplicatedInput; template class ReplicatedPrivateOutput; template class Share; -template class Processor; +template class Rep3Share; class ReplicatedBase { public: PRNG shared_prngs[2]; + Player& P; + ReplicatedBase(Player& P); }; template -class Replicated : ReplicatedBase +class Replicated : public ReplicatedBase { vector os; - vector results; + deque add_shares; int counter; public: @@ -55,9 +57,12 @@ class Replicated : ReplicatedBase void muls(const vector& reg, SubProcessor& proc, ReplicatedMC& MC, int size); - static void reqbl(int n); + void init_mul(); + typename T::clear prepare_mul(const T& x, const T& y); + void exchange(); + T finalize_mul(); - static void input(SubProcessor& Proc, int n, int* r); + T get_random(); }; #endif /* PROCESSOR_REPLICATED_H_ */ diff --git a/Processor/ReplicatedInput.h b/Processor/ReplicatedInput.h index bda52ea7f..e434b700f 100644 --- a/Processor/ReplicatedInput.h +++ b/Processor/ReplicatedInput.h @@ -6,22 +6,27 @@ #ifndef PROCESSOR_REPLICATEDINPUT_H_ #define PROCESSOR_REPLICATEDINPUT_H_ -#include "Auth/ReplicatedMC.h" #include "Input.h" template -class ReplicatedInput : public InputBase +class ReplicatedInput : public InputBase { SubProcessor& proc; vector shares; + vector os; public: ReplicatedInput(SubProcessor& proc, ReplicatedMC& MC) : - InputBase(proc.Proc), proc(proc) + InputBase(proc.Proc), proc(proc) { (void) MC; } + void reset(int player); + void add_mine(const typename T::clear& input); + void add_other(int player); + void send_mine(); + void start(int player, int n_inputs); void stop(int player, vector targets); }; diff --git a/Processor/ReplicatedInput.cpp b/Processor/ReplicatedInput.hpp similarity index 57% rename from Processor/ReplicatedInput.cpp rename to Processor/ReplicatedInput.hpp index d107f3a59..af076b02a 100644 --- a/Processor/ReplicatedInput.cpp +++ b/Processor/ReplicatedInput.hpp @@ -6,32 +6,62 @@ #include "ReplicatedInput.h" #include "Processor.h" + +template +void ReplicatedInput::reset(int player) +{ + if (player == proc.P.my_num()) + { + shares.clear(); + os.resize(2); + for (auto& o : os) + o.reset_write_head(); + } +} + +template +void ReplicatedInput::add_mine(const typename T::clear& input) +{ + shares.push_back({}); + T& my_share = shares.back(); + my_share[0].randomize(proc.Proc.secure_prng); + my_share[1] = input - my_share[0]; + for (int j = 0; j < 2; j++) + { + my_share[j].pack(os[j]); + } + this->values_input++; +} + +template +void ReplicatedInput::add_other(int player) +{ + (void) player; +} + +template +void ReplicatedInput::send_mine() +{ + proc.P.send_relative(os); +} + template void ReplicatedInput::start(int player, int n_inputs) { assert(T::length == 2); - shares.resize(n_inputs); + reset(player); if (player == proc.P.my_num()) { - vector os(2); - for (int i = 0; i < n_inputs; i++) { typename T::value_type t; this->buffer.input(t); - T& my_share = shares[i]; - my_share[0].randomize(proc.Proc.secure_prng); - my_share[1] = t - my_share[0]; - for (int j = 0; j < 2; j++) - { - my_share[j].pack(os[j]); - } + add_mine(t); } - proc.P.send_relative(os); - this->values_input += n_inputs; + send_mine(); } } @@ -61,5 +91,3 @@ void ReplicatedInput::stop(int player, vector targets) } } } - -template class ReplicatedInput; diff --git a/Processor/ReplicatedMachine.h b/Processor/ReplicatedMachine.h new file mode 100644 index 000000000..a373112d0 --- /dev/null +++ b/Processor/ReplicatedMachine.h @@ -0,0 +1,19 @@ +/* + * ReplicatedMachine.h + * + */ + +#ifndef PROCESSOR_REPLICATEDMACHINE_H_ +#define PROCESSOR_REPLICATEDMACHINE_H_ + +#include +using namespace std; + +template +class ReplicatedMachine +{ +public: + ReplicatedMachine(int argc, const char** argv, string name); +}; + +#endif /* PROCESSOR_REPLICATEDMACHINE_H_ */ diff --git a/Processor/ReplicatedMachine.hpp b/Processor/ReplicatedMachine.hpp new file mode 100644 index 000000000..1da8ac3d8 --- /dev/null +++ b/Processor/ReplicatedMachine.hpp @@ -0,0 +1,100 @@ +/* + * ReplicatedMachine.cpp + * + */ + +#include "Tools/ezOptionParser.h" +#include "Tools/benchmarking.h" +#include "Networking/Server.h" +#include "Math/Rep3Share.h" +#include "Processor/Machine.h" +#include "ReplicatedMachine.h" + +template +ReplicatedMachine::ReplicatedMachine(int argc, const char** argv, + string name) +{ + ez::ezOptionParser opt; + OnlineOptions online_opts(opt, argc, argv); + opt.add( + "localhost", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Host where party 0 is running (default: localhost)", // Help description. + "-h", // Flag token. + "--hostname" // Flag token. + ); + opt.add( + "5000", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Base port number (default: 5000).", // Help description. + "-pn", // Flag token. + "--portnum" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Unencrypted communication.", // Help description. + "-u", // Flag token. + "--unencrypted" // Flag token. + ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Preprocessing from files (default for rings).", // Help description. + "-F", // Flag token. + "--file-preprocessing" // Flag token. + ); + opt.syntax = "./" + name + "-party.x [OPTIONS] "; + opt.resetArgs(); + opt.parse(argc, argv); + vector allArgs(opt.firstArgs); + allArgs.insert(allArgs.end(), opt.lastArgs.begin(), opt.lastArgs.end()); + + int playerno; + string progname; + + if (allArgs.size() != 3) + { + cerr << "ERROR: incorrect number of arguments to " << argv[0] << endl; + cerr << "Arguments given were:\n"; + for (unsigned int j = 1; j < allArgs.size(); j++) + cout << "'" << *allArgs[j] << "'" << endl; + string usage; + opt.getUsage(usage); + cout << usage; + exit(1); + } + else + { + playerno = atoi(allArgs[1]->c_str()); + progname = *allArgs[2]; + + } + + int pnb; + string hostname; + opt.get("-pn")->getInt(pnb); + opt.get("-h")->getString(hostname); + bool use_encryption = not opt.get("-u")->isSet; + bool live_prep = not opt.get("-F")->isSet; + + if (not use_encryption) + insecure("unencrypted communication"); + Names N; + Server* server = Server::start_networking(N, playerno, 3, hostname, pnb); + + Machine(playerno, N, progname, "empty", 128, + gf2n::default_degree(), 0, 0, 0, 0, 0, use_encryption, + live_prep, online_opts).run(); + + if (server) + delete server; +} diff --git a/Processor/ReplicatedPrep.cpp b/Processor/ReplicatedPrep.cpp new file mode 100644 index 000000000..ee8f8f348 --- /dev/null +++ b/Processor/ReplicatedPrep.cpp @@ -0,0 +1,230 @@ +/* + * ReplicatedPrep.cpp + * + */ + +#include "ReplicatedPrep.h" +#include "Math/gfp.h" +#include "Math/MaliciousRep3Share.h" +#include "Auth/ReplicatedMC.h" + +template +ReplicatedPrep::ReplicatedPrep() : protocol(0) +{ +} + +template +void ReplicatedPrep::buffer_triples() +{ + assert(protocol != 0); + auto& triples = this->triples; + triples.resize(this->buffer_size); + protocol->init_mul(); + for (size_t i = 0; i < triples.size(); i++) + { + auto& triple = triples[i]; + triple[0] = protocol->get_random(); + triple[1] = protocol->get_random(); + protocol->prepare_mul(triple[0], triple[1]); + } + protocol->exchange(); + for (size_t i = 0; i < triples.size(); i++) + triples[i][2] = protocol->finalize_mul(); +} + +template +void BufferPrep::get_three(Dtype dtype, T& a, T& b, T& c) +{ + if (dtype != DATA_TRIPLE) + throw not_implemented(); + + if (triples.empty()) + buffer_triples(); + + a = triples.back()[0]; + b = triples.back()[1]; + c = triples.back()[2]; + triples.pop_back(); +} + +template +void ReplicatedPrep::buffer_squares() +{ + assert(protocol != 0); + auto& squares = this->squares; + squares.resize(this->buffer_size); + protocol->init_mul(); + for (size_t i = 0; i < squares.size(); i++) + { + auto& square = squares[i]; + square[0] = protocol->get_random(); + protocol->prepare_mul(square[0], square[0]); + } + protocol->exchange(); + for (size_t i = 0; i < squares.size(); i++) + squares[i][1] = protocol->finalize_mul(); +} + +template +void ReplicatedPrep::buffer_inverses() +{ + assert(protocol != 0); + ReplicatedMC MC; + BufferPrep::buffer_inverses(MC, protocol->P); +} + +template +void BufferPrep::buffer_inverses(MAC_Check_Base& MC, Player& P) +{ + vector> triples(buffer_size); + vector c; + for (int i = 0; i < buffer_size; i++) + { + get_three(DATA_TRIPLE, triples[i][0], triples[i][1], triples[i][2]); + c.push_back(triples[i][2]); + } + vector c_open; + MC.POpen(c_open, c, P); + for (size_t i = 0; i < c.size(); i++) + if (c_open[i] != 0) + inverses.push_back({triples[i][0], triples[i][1] / c_open[i]}); + triples.clear(); + if (inverses.empty()) + throw runtime_error("products were all zero"); + MC.Check(P); +} + +template +void BufferPrep::get_two(Dtype dtype, T& a, T& b) +{ + switch (dtype) + { + case DATA_SQUARE: + { + if (squares.empty()) + buffer_squares(); + + a = squares.back()[0]; + b = squares.back()[1]; + squares.pop_back(); + return; + } + case DATA_INVERSE: + { + while (inverses.empty()) + buffer_inverses(); + + a = inverses.back()[0]; + b = inverses.back()[1]; + inverses.pop_back(); + return; + } + default: + throw not_implemented(); + } +} + +template<> +void ReplicatedPrep>::buffer_bits() +{ + assert(protocol != 0); +#ifdef BIT_BY_SQUARE + vector, 2>> squares(buffer_size); + vector> s; + for (int i = 0; i < buffer_size; i++) + { + get_two(DATA_SQUARE, squares[i][0], squares[i][1]); + s.push_back(squares[i][1]); + } + vector open; + ReplicatedMC>().POpen(open, s, protocol->P); + auto one = Rep3Share(1, protocol->P.my_num()); + for (size_t i = 0; i < s.size(); i++) + if (open[i] != 0) + bits.push_back((squares[i][0] / open[i].sqrRoot() + one) / 2); + squares.clear(); + if (bits.empty()) + throw runtime_error("squares were all zero"); +#else + vector>> player_bits(3, vector>(buffer_size)); + vector os(2); + SeededPRNG G; + for (auto& share : player_bits[protocol->P.my_num()]) + { + share.randomize_to_sum(G.get_bit(), G); + for (int i = 0; i < 2; i++) + share[i].pack(os[i]); + } + auto& prot = *protocol; + prot.P.send_relative(os); + prot.P.receive_relative(os); + for (int i = 0; i < 2; i++) + for (auto& share : player_bits[prot.P.get_player(i + 1)]) + share[i].unpack(os[i]); + prot.init_mul(); + for (int i = 0; i < buffer_size; i++) + prot.prepare_mul(player_bits[0][i], player_bits[1][i]); + prot.exchange(); + vector> first_xor(buffer_size); + gfp two(2); + for (int i = 0; i < buffer_size; i++) + first_xor[i] = player_bits[0][i] + player_bits[1][i] - prot.finalize_mul() * two; + prot.init_mul(); + for (int i = 0; i < buffer_size; i++) + prot.prepare_mul(player_bits[2][i], first_xor[i]); + prot.exchange(); + bits.resize(buffer_size); + for (int i = 0; i < buffer_size; i++) + bits[i] = player_bits[2][i] + first_xor[i] - prot.finalize_mul() * two; +#endif +} + +template<> +void ReplicatedPrep>::buffer_bits() +{ + assert(protocol != 0); + for (int i = 0; i < DIV_CEIL(buffer_size, gf2n::degree()); i++) + { + Rep3Share share = protocol->get_random(); + for (int j = 0; j < gf2n::degree(); j++) + { + bits.push_back(share & 1); + share >>= 1; + } + } +} + +template +void BufferPrep::get_one(Dtype dtype, T& a) +{ + if (dtype != DATA_BIT) + throw not_implemented(); + + while (bits.empty()) + buffer_bits(); + + a = bits.back(); + bits.pop_back(); +} + +template +void BufferPrep::get_input(T& a, typename T::clear& x, int i) +{ + (void) a, (void) x, (void) i; + throw not_implemented(); +} + +template +void BufferPrep::get(vector& S, DataTag tag, + const vector& regs, int vector_size) +{ + (void) S, (void) tag, (void) regs, (void) vector_size; + throw not_implemented(); +} + +template class BufferPrep>; +template class BufferPrep>; +template class BufferPrep>; +template class BufferPrep>; +template class ReplicatedPrep>; +template class ReplicatedPrep>; diff --git a/Processor/ReplicatedPrep.h b/Processor/ReplicatedPrep.h new file mode 100644 index 000000000..9e4238f4c --- /dev/null +++ b/Processor/ReplicatedPrep.h @@ -0,0 +1,61 @@ +/* + * ReplicatedPrep.h + * + */ + +#ifndef PROCESSOR_REPLICATEDPREP_H_ +#define PROCESSOR_REPLICATEDPREP_H_ + +#include "Networking/Player.h" +#include "Data_Files.h" + +#include + +template +class BufferPrep : public Preprocessing +{ +protected: + static const int buffer_size = 1000; + + vector> triples; + vector> squares; + vector> inverses; + vector bits; + + virtual void buffer_triples() = 0; + virtual void buffer_squares() = 0; + virtual void buffer_inverses() = 0; + virtual void buffer_bits() = 0; + + virtual void buffer_inverses(MAC_Check_Base& MC, Player& P); + +public: + virtual ~BufferPrep() {} + + void get_three(Dtype dtype, T& a, T& b, T& c); + void get_two(Dtype dtype, T& a, T& b); + void get_one(Dtype dtype, T& a); + void get_input(T& a, typename T::clear& x, int i); + void get(vector& S, DataTag tag, const vector& regs, + int vector_size); +}; + +template +class ReplicatedPrep : public BufferPrep +{ + template friend class MaliciousRepPrep; + + Replicated* protocol; + + void buffer_triples(); + void buffer_squares(); + void buffer_inverses(); + void buffer_bits(); + +public: + ReplicatedPrep(); + + void set_protocol(Replicated& protocol) { this->protocol = &protocol; } +}; + +#endif /* PROCESSOR_REPLICATEDPREP_H_ */ diff --git a/Processor/ReplicatedPrivateOutput.cpp b/Processor/ReplicatedPrivateOutput.hpp similarity index 91% rename from Processor/ReplicatedPrivateOutput.cpp rename to Processor/ReplicatedPrivateOutput.hpp index 0d41a4570..616e5f026 100644 --- a/Processor/ReplicatedPrivateOutput.cpp +++ b/Processor/ReplicatedPrivateOutput.hpp @@ -28,5 +28,3 @@ void ReplicatedPrivateOutput::stop(int player, int source) { (void)player, (void)source; } - -template class ReplicatedPrivateOutput; diff --git a/Processor/SPDZ.cpp b/Processor/SPDZ.cpp index 3d2f697af..8e5720be3 100644 --- a/Processor/SPDZ.cpp +++ b/Processor/SPDZ.cpp @@ -6,93 +6,9 @@ #include "SPDZ.h" #include "Processor.h" #include "Math/Share.h" +#include "Auth/MAC_Check.h" -template -void SPDZ::muls(const vector& reg, SubProcessor >& proc, MAC_Check& MC, - int size) -{ - assert(reg.size() % 3 == 0); - int n = reg.size() / 3; - vector >& shares = proc.Sh_PO; - vector& opened = proc.PO; - shares.clear(); - vector, 3>> triples(n * size); - auto triple = triples.begin(); - - for (int i = 0; i < n; i++) - for (int j = 0; j < size; j++) - { - proc.DataF.get(DATA_TRIPLE, triple->data()); - for (int k = 0; k < 2; k++) - shares.push_back(proc.S[reg[i * 3 + k + 1] + j] - (*triple)[k]); - triple++; - } - - MC.POpen_Begin(opened, shares, proc.P); - MC.POpen_End(opened, shares, proc.P); - auto it = opened.begin(); - triple = triples.begin(); - - for (int i = 0; i < n; i++) - for (int j = 0; j < size; j++) - { - T masked[2]; - Share& tmp = (*triple)[2]; - for (int k = 0; k < 2; k++) - { - masked[k] = *it++; - tmp.add(masked[k] * (*triple)[1 - k]); - } - tmp.add(tmp, masked[0] * masked[1], proc.P.my_num(), MC.get_alphai()); - proc.S[reg[i * 3] + j] = tmp; - triple++; - } -} - -template<> -void SPDZ::reqbl(int n) -{ - if ((int)n > 0 && gfp::pr() < bigint(1) << (n-1)) - { - cout << "Tape requires prime of bit length " << n << endl; - throw invalid_params(); - } - else if ((int)n < 0) - { - throw Processor_Error("Program compiled for rings not fields"); - } -} - -template -inline void SPDZ::input(SubProcessor>& Proc, int n, int* r) -{ - T rr, t, tmp; - Proc.DataF.get_input(Proc.get_S_ref(r[0]),rr,n); - octetStream o; - if (n==Proc.P.my_num()) - { - T xi; -#ifdef DEBUG - printf("Enter your input : \n"); -#endif - long x; - cin >> x; - t.assign(x); - t.sub(t,rr); - t.pack(o); - Proc.P.send_all(o); - xi.add(t,Proc.get_S_ref(r[0]).get_share()); - Proc.get_S_ref(r[0]).set_share(xi); - } - else - { - Proc.P.receive_player(n,o); - t.unpack(o); - } - tmp.mul(t, Proc.MC.get_alphai()); - tmp.add(Proc.get_S_ref(r[0]).get_mac(),tmp); - Proc.get_S_ref(r[0]).set_mac(tmp); -} +#include "Input.hpp" template class SPDZ; template class SPDZ; diff --git a/Processor/SPDZ.h b/Processor/SPDZ.h index cd030d826..38f5c6a77 100644 --- a/Processor/SPDZ.h +++ b/Processor/SPDZ.h @@ -6,6 +6,8 @@ #ifndef PROCESSOR_SPDZ_H_ #define PROCESSOR_SPDZ_H_ +#include "Beaver.h" + #include using namespace std; @@ -13,15 +15,13 @@ template class SubProcessor; template class MAC_Check; template class Share; class Player; -template class Processor; template -class SPDZ +class SPDZ : public Beaver> { public: - SPDZ(Player& P) + SPDZ(Player& P) : Beaver>(P) { - (void) P; } static void assign(T& share, const T& clear, int my_num) @@ -31,13 +31,6 @@ class SPDZ else share = 0; } - - static void muls(const vector& reg, SubProcessor >& proc, MAC_Check& MC, - int size); - - static void reqbl(int n); - - static void input(SubProcessor>& Proc, int n, int* r); }; #endif /* PROCESSOR_SPDZ_H_ */ diff --git a/Processor/config.h b/Processor/config.h index b50592533..266234d98 100644 --- a/Processor/config.h +++ b/Processor/config.h @@ -13,7 +13,6 @@ #error REPLICATED flag is obsolete #endif -typedef Share sgf2n; typedef Share sgfp; #endif /* PROCESSOR_CONFIG_H_ */ diff --git a/Programs/Source/blink.mpc b/Programs/Source/blink.mpc index 07adec609..37bb17761 100644 --- a/Programs/Source/blink.mpc +++ b/Programs/Source/blink.mpc @@ -12,19 +12,32 @@ if len(program.args) > 1: else: n_batches = 78 +if len(program.args) > 2: + m_batches = int(program.args[2]) +else: + m_batches = n_batches + +if len(program.args) > 3: + n_threads = int(program.args[3]) +else: + n_threads = n_batches + batch_size = 64 n = n_batches * batch_size +m = m_batches * batch_size l = 16 a = Matrix(n, l, full_t) -b = Matrix(n, l, full_t) +b = Matrix(m, l, full_t) t = sbitint.get_type(int(math.ceil(math.log(batch_size * l, 2))) + 1) -matches = Matrix(n, n, t.bit_type) -mismatches = Matrix(n, n, t) +matches = Matrix(n, m, t.bit_type) +mismatches = Matrix(n, m, t) threshold = MemValue(t(10)) for i in range(n): for j in range(l): a[i][j] = full_t.get_input_from(0) +for i in range(m): + for j in range(l): b[i][j] = full_t.get_input_from(1) # test, create match between a[0] and b[1] but no match for a[1] @@ -34,10 +47,10 @@ a[0][0] = -1 b[1][0] = -1 a[1][1] = -1 -@for_range_multithread(n_batches, 1, n) +@for_range_multithread(n_threads, 1, n) def _(i): print_ln('%s', i) - @for_range_parallel(100, n_batches) + @for_range_parallel(m_batches, m_batches) def _(j): j = j * batch_size av = sbitintvec.from_matrix((a[i][kk] for _ in range(batch_size)) \ @@ -47,35 +60,35 @@ def _(i): res = xor_op(av, bv).popcnt() mismatches[i].set_range(j, (t(x) for x in res.elements())) -@for_range_multithread(n_batches, 8, n) +@for_range_multithread(n_threads, 8, n) def _(i): print_ln('%s', i) - @for_range_parallel(100, n_batches) + @for_range_parallel(m_batches, m_batches) def _(j): j = j * batch_size v = sbitintvec(mismatches[i].get_range(j, batch_size)) vv = sbitintvec([threshold.read()] * batch_size) matches[i].set_range(j, v.less_than(vv, 10).elements()) -mg = MultiArray([n_batches, n, t.n], full_t) -ag = Matrix(n_batches, n, full_t) +mg = MultiArray([n_batches, m, t.n], full_t) +ag = Matrix(n_batches, m, full_t) -@for_range_multithread(n_batches, 1, n_batches) +@for_range_multithread(n_threads, 1, n_batches) def _(i): - m = mg[i] + mgi = mg[i] a = ag[i] i = i * batch_size print_ln('best %s', i) - @for_range(n) + @for_range(m) def _(j): - m[j].assign(sbitintvec(mismatches[i + k][j] + mgi[j].assign(sbitintvec(mismatches[i + k][j] for k in range(batch_size)).v) - m = [sbitintvec.from_vec(m[j]) for j in range(n)] + mgi = [sbitintvec.from_vec(mgi[j]) for j in range(m)] def reducer(a, b): c = a[0].less_than(b[0]) return util.if_else(c, (a[0], a[1] + [0] * len(b[1])), (b[0], [0] * len(a[1]) + b[1])) - mm = util.tree_reduce(reducer, ((x, [2**batch_size - 1]) for x in m)) + mm = util.tree_reduce(reducer, ((x, [2**batch_size - 1]) for x in mgi)) a.assign(mm[1]) @for_range_parallel(100, len(a)) def _(j): diff --git a/Programs/Source/fixed_point_tutorial.mpc b/Programs/Source/fixed_point_tutorial.mpc deleted file mode 100644 index bfe218837..000000000 --- a/Programs/Source/fixed_point_tutorial.mpc +++ /dev/null @@ -1,56 +0,0 @@ -program.bit_length = 80 -print "program.bit_length: ", program.bit_length -program.security = 40 - -sfix.set_precision(16, 32) -cfix.set_precision(16, 32) - -n = 10 -m = 5 - -# array of fixed points -A = Array(n, sfix) - -for i in range(n): - A[i] = sfix(i) - -print_ln('array of fixed points') -for i in range(n): - print_ln('%s', A[i].reveal()) - -# matrix of fixed points -M = Matrix(n, m, sfix) - -for i in range(n): - for j in range(m): - M[i][j] = sfix(i*j) - -print_ln('matrix of fixed points') -for i in range(n): - for j in range(m): - print_str('%s ', M[i][j].reveal()) - print_ln(' ') - - -# assign scalar to sfix -A[5] = sfix(1.12345) -print_ln('%s', A[5].reveal()) - -AC = Array(n, cfix) - -for i in range(n): - AC[i] = cfix(1.5 * i) - -for i in range(n): - print_ln('%s', AC[i]) - -# assign sint to sfix -s = sint(10) -sa = sfix(); sa.load_int(s) -print_ln('successfully assigned sint to sfix %s', sa.reveal()) - -# division between fixed points -sb = sfix(2.5) -print_ln('division between %s %s = %s', sa.reveal(), sb.reveal(), (sa/sb).reveal()) - - diff --git a/Programs/Source/gc_fixed_point_tutorial.mpc b/Programs/Source/gc_fixed_point_tutorial.mpc deleted file mode 100644 index fd7f392ff..000000000 --- a/Programs/Source/gc_fixed_point_tutorial.mpc +++ /dev/null @@ -1,44 +0,0 @@ -sfix = sbitfix -sint = sbitint.get_type(20) - -sfix.set_precision(16, 32) - -n = 10 -m = 5 - -# array of fixed points -A = Array(n, sfix) - -for i in range(n): - A[i] = sfix(i) - -print_ln('mrray of fixed points') -for i in range(n): - print_ln('%s', A[i].reveal()) - -# matrix of fixed points -M = Matrix(n, m, sfix) - -for i in range(n): - for j in range(m): - M[i][j] = sfix(i*j) - -print_ln('matrix of fixed points') -for i in range(n): - for j in range(m): - print_str('%s ', M[i][j].reveal()) - print_ln(' ') - - -# assign scalar to sfix -A[5] = sfix(1.12345) -print_ln('%s', A[5].reveal()) - -# assign sint to sfix -s = sint(10) -sa = sfix(); sa.load_int(s) -print_ln('successfully assigned sint to sfix %s', sa.reveal()) - -# division between fixed points -sb = sfix(2.5) -print_ln('division between %s %s = %s', sa.reveal(), sb.reveal(), (sa/sb).reveal()) diff --git a/Programs/Source/gc_tutorial.mpc b/Programs/Source/gc_tutorial.mpc deleted file mode 100644 index 48591c33a..000000000 --- a/Programs/Source/gc_tutorial.mpc +++ /dev/null @@ -1,50 +0,0 @@ -# sbitint: factory for signed integer types - -sint = sbitint.get_type(32) - -def test(a, b, value_type=None): - try: - a = a.reveal() - except AttributeError: - pass - import inspect - print_ln('line %s: diff %s, got %s, expected %s', - inspect.currentframe().f_back.f_lineno, \ - (a ^ cbits(b, n=a.n)).reveal(), a, hex(b)) - -a = sint(1) -b = sint(2) - -test(a + b, 3) -test(a + a, 2) -test(a * b, 2) -test(a * a, 1) -test(a - b, -1) -test(a < b, 1) -test(a <= b, 1) -test(a >= b, 0) -test(a > b, 0) -test(a == b, 0) -test(a != b, 1) - -clear_a = a.reveal() - -# arrays and loops - -a = Array(100, sint) - -@for_range(100) -def f(i): - a[i] = sint(i)**2 - -test(a[99], 99**2) - -# conditional - -if_then(regint(0)) -a[0] = 123 -else_then() -a[0] = 789 -end_if() - -test(a[0], 789) diff --git a/Programs/Source/test_sbitfix.mpc b/Programs/Source/test_sbitfix.mpc index ac9a197fa..b8dd6f4fe 100644 --- a/Programs/Source/test_sbitfix.mpc +++ b/Programs/Source/test_sbitfix.mpc @@ -1,10 +1,13 @@ from Compiler.GC.types import sbitfix, cbits +import math -#sbitfix.set_precision(3, 7) +sbitfix.set_precision(16, 32) def test(a, b, value_type=None): try: b = int(round((b * (1 << a.f)))) + if b < 0: + b += 2 ** sbitfix.k a = a.v.reveal() except AttributeError: pass @@ -16,8 +19,8 @@ def test(a, b, value_type=None): print_ln('%s: %s %s %s', inspect.currentframe().f_back.f_lineno, \ (a ^ cbits(b)).reveal(), a, (b)) -aa = 5321.0 -bb = 142.0 +aa = 53.21 +bb = 142 for a_sign, b_sign in (1, -1), (-1, -1): a = a_sign * aa @@ -33,6 +36,16 @@ for a_sign, b_sign in (1, -1), (-1, -1): test(-sa, -a) + test(sa + b, a+b) + test(sa - b, a-b) + test(sa * b, a*b) + test(sa / b, a/b) + + test(a + sb, a+b) + test(a - sb, a-b) + test(a * sb, a*b) + test(a / sb, a/b) + a = 126 b = 125 sa = sbitfix(a) @@ -51,3 +64,31 @@ test(sa >= sb, int(a>=b)) test(sa == sb, int(a==b)) test(sa != sb, int(a!=b)) test(sa != sa, int(a!=a)) + +test(sa < b, int(a b, int(a>b)) +test(sa <= b, int(a<=b)) +test(sa >= b, int(a>=b)) +test(sa == b, int(a==b)) +test(sa != b, int(a!=b)) +test(sa != a, int(a!=a)) + +test(a < sb, int(a sb, int(a>b)) +test(a <= sb, int(a<=b)) +test(a >= sb, int(a>=b)) +test(a == sb, int(a==b)) +test(a != sb, int(a!=b)) +test(a != sa, int(a!=a)) diff --git a/Programs/Source/tutorial.mpc b/Programs/Source/tutorial.mpc index 9616cd1a1..7567e6054 100644 --- a/Programs/Source/tutorial.mpc +++ b/Programs/Source/tutorial.mpc @@ -1,19 +1,28 @@ -def test(actual, expected): - if isinstance(actual, (sint, sgf2n)): - actual = actual.reveal() - print_ln('expected %s, got %s', expected, actual) +# sint: secret integers -# cint: clear integers modulo p -# sint: secret integers modulo p +# you can assign public numbers to sint a = sint(1) -b = cint(2) +b = sint(2) + +def test(actual, expected): + + # you can reveal a number in order to print it + + actual = actual.reveal() + print_ln('expected %s, got %s', expected, actual) + +# some arithmetic works as expected test(a + b, 3) -test(a + a, 2) test(a * b, 2) -test(a * a, 1) test(a - b, -1) + +# but division doesn't, don't do the following +# test(b / a, 2) + +# comparisons produce 1 for true and 0 for false + test(a < b, 1) test(a <= b, 1) test(a >= b, 0) @@ -21,36 +30,86 @@ test(a > b, 0) test(a == b, 0) test(a != b, 1) -clear_a = a.reveal() +# if_else() can be used instead of branching +# let's find out the larger number +test((a < b).if_else(b, a), 2) -# sgfn2/cgf2n: secret/clear elements of GF(2^n) +# arrays and loops work as follows -a = sgf2n(1) -b = cgf2n(2) +a = Array(100, sint) -test(a + b, 3) -test(a + a, 0) -test(a * b, 2) -test(a * a, 1) +@for_range(100) +def f(i): + a[i] = sint(i) * sint(i - 1) + +test(a[99], 99 * 98) + +# if you use loops, use Array to store results +# don't do this +# @for_range(100) +# def f(i): +# a = sint(i) +# test(a, 99) + +# sfix: fixed-point numbers + +# set the precision after the dot and in total + +sfix.set_precision(16, 32) + +# you can do all basic arithmetic with sfix, including division + +a = sfix(2) +b = sfix(-0.1) + +test(a + b, 1.9) +test(a - b, 2.1) +test(a * b, -0.2) +test(a / b, -20) +test(a < b, 0) +test(a <= b, 0) +test(a >= b, 1) +test(a > b, 1) test(a == b, 0) test(a != b, 1) -# arrays and loops +test((a < b).if_else(a, b), -0.1) -a = Array(100, sint) +# now let's do a computation with private inputs +# party 0 supplies three number and party 1 supplies three percentages +# we want to compute the weighted mean -@for_range(100) -def f(i): - a[i] = sint(i)**2 +print_ln('Party 0: please input three numbers not adding up to zero') +print_ln('Party 1: please input any three numbers') + +data = Matrix(3, 2, sfix) + +# use Python loops for compile-time optimization + +for i in range(3): + for j in range(2): + data[i][j] = sfix.from_sint(sint.get_input_from(j)) + +# compute weighted average + +weight_total = sum(point[0] for point in data) +result = sum(point[0] * point[1] for point in data) / weight_total + +# the following only works with arithmetic circuits + +# @if_e((sum(point[0] for point in data) != 0).reveal()) +# def _(): +# print_ln('weighted average: %s', result.reveal()) +# @else_ +# def _(): +# print_ln('your inputs made no sense') -test(a[99], 99**2) +# so we output even an invalid result (the weights adding up to zero) -# conditional +print_ln('weighted average: %s', result.reveal()) -if_then(cint(0)) -a[0] = 123 -else_then() -a[0] = 789 -end_if() +# but we warn the user +# note that the we don't reveal the weight sum, only the comparison -test(a[0], 789) +print_ln_if((sum(point[0] for point in data) == 0).reveal(), \ + 'but the inputs were invalid (weights add up to zero)') diff --git a/README.md b/README.md index 7d828b462..053d0e488 100644 --- a/README.md +++ b/README.md @@ -3,13 +3,31 @@ Software to benchmark various secure multi-party computation (MPC) protocols such as SPDZ, MASCOT, Overdrive, BMR garbled circuits (evaluation only), Yao's garbled circuits, and computation based on -semi-honest 3-party replicated secret sharing. +semi-honest three-party replicated secret sharing (with an honest majority). + +#### TL;DR + +This requires `sudo` rights as well as a working toolchain installed +for the first step, refer to [the requirements](#requirements) +otherwise. It will execute [the +tutorial](Programs/Source/tutorial.mpc) with three +parties, an honest majority, and malicious security. + +``` +make -j 8 mpir +make -j 8 tldr +./compile.py tutorial +Scripts/setup-replicated.sh +echo 1 2 3 > Player-Data/Input-P0-0 +echo 1 2 3 > Player-Data/Input-P1-0 +Scripts/mal-rep-field.sh tutorial +``` #### Preface The primary aim of this software is to benchmark the same computation in various protocols in order to compare the performance. In order to -do, it uses functionality that is not secure. Many MPC protocols +do, it sometimes uses functionality that is not secure. Many MPC protocols involve several phases that have to be executed in a secure manner for the whole protocol to be sure. However, for benchmarking it does not make a difference whether a previous phase was executed securely or @@ -17,13 +35,9 @@ whether its output were generated insecurely. The focus on this software is to benchmark each phases individually rather than running the whole sequence of phases at once. -Furthermore, the replicated secret sharing implementation currently -uses unencrypted communication which reveals all information to an -adversary wiretapping all connections. - In order to make it clear where insecure functionality is used, it is disabled by default but can be activated as explained in the section -on compilation. Many parts of the software will not work without doing so. +on compilation. Some parts of the software will not work without doing so. #### History @@ -72,7 +86,7 @@ Overdrive are the names for two alternative preprocessing phases to go with the SPDZ online phase. In the section on computation we will explain how to run the SPDZ -online phase and semi-honest 3-party replicated secret sharing as well +online phase and the various honest-majority three-party comptuation as well as BMR and Yao's garbled circuits. The section on offline phases will then explain how to benchmark the @@ -85,8 +99,8 @@ compute the preprocessing time for a particulor computation. - MPIR library, compiled with C++ support (use flag --enable-cxx when running configure) - libsodium library, tested against 1.0.16 - OpenSSL, tested against 1.1.0 - - Boost.Asio with SSL support, tested against 1.65 - - Boost.Thread for BMR, tested against 1.65 + - Boost.Asio with SSL support (`libboost-dev` on Ubuntu), tested against 1.65 + - Boost.Thread for BMR (`libboost-thread-dev` on Ubuntu), tested against 1.65 - CPU supporting AES-NI, PCLMUL, AVX2 - Python 2.x - NTL library for the SPDZ-2 and Overdrive offline phases (optional; tested with NTL 10.5) @@ -96,7 +110,7 @@ compute the preprocessing time for a particulor computation. 1) Edit `CONFIG` or `CONFIG.mine` to your needs: - - To benchmark anything other than replicated secret sharing for binary circuits, Yao's garbled circuits, or covertly secure SPDZ, add the following line at the top: `MY_CFLAGS = -DINSECURE` + - To benchmark malicious SPDZ, some honest-majority three-party computation (semi-honest modulo 2^64 or malicious binary), or BMR, add the following line at the top: `MY_CFLAGS = -DINSECURE` - `PREP_DIR` should point to should be a local, unversioned directory to store preprocessing data (default is `Player-Data` in the current directory). - For the SPDZ-2 and Overdrive offline phases, set `USE_NTL = 1` and `MOD = -DMAX_MOD_SZ=6`. - To use GF(2^40), in particular for the SPDZ-2 offline phase, set `USE_GF2N_LONG = 0`. This will deactive anything that requires GF(2^128) such as MASCOT. @@ -109,25 +123,16 @@ or `CONFIG.mine`. # Benchmarking computation See `Programs/Source/` for some example MPC programs, in particular -`tutorial.mpc` and `fixed_point_tutorial.mpc` for arithmetic circuits -and `gc_tutorial.mpc` and `gc_fixed_point_tutorial.mpc` for binary -circuits. - -Because the focus is on benchmarking, the facilities for private -inputs to communication are rather rudimentary. For arithmetic -circuits, `sint.get_raw_input_from()` reads internal representations -from `Player-Data/Private-Input-`, and for binary circuits -`sbits.get_input_from()` reads numbers in ASCII from -`Player-Data/Input-P-`. +`tutorial.mpc`. ## Arithmetic circuits +### SPDZ + All programs required in this section can be compiled with the target `online`: `make -j 8 online` -### SPDZ - #### To setup for benchmarking the online phase This requires the INSECURE flag to be set before compilation as explained above. For a secure offline phase, see the section on SPDZ-2 below. @@ -194,11 +199,50 @@ Player-Data Programs $ ../spdz/Scripts/run-online.sh test ``` -### Semi-honest 3-party replicated secret sharing modulo 2^64 +### Three-party honest-majority computation modulo a prime -Compile the virtual machine: +Compile the virtual machines: + +`make -j 8 rep-field` + +Run setup to generate a 128-bit prime. This will also generate SSL keys and certificates. See the section replicated secret sharing for binary circuits below for details. + +`Scripts/setup-replicated.sh` + +In order to compile a program, use `./compile.py`, for example: + +`./compile.py tutorial` + +Running the computation is similar to SPDZ but you will need to start +three parties: + +`./malicious-rep-field-party.x -I 0 tutorial` -`make -j 8 replicated-ring-party.x` +`./malicious-rep-field-party.x -I 1 tutorial` (in a separate terminal) + +`./malicious-rep-field-party.x -I 2 tutorial` (in a separate terminal) + +The `-I` enable interactive inputs, and in the tutorial party 0 and 1 +will be asked to provide three numbers. Using +`./replicated-field-party.x` will provide semi-honest security instead +of malicious. + +You can run all parties at once with + +`Scripts/mal-rep-field.sh tutorial` + +for malicious security or + +`Scripts/rep-field.sh tutorial` + +for semi-honest security. In this case, the inputs are read from +`Player-Data/Input-P-0`. + +### Semi-honest honest-majority computation modulo 2^64 + +Compile the necessary programs: + +`make -j 8 rep-ring` Run setup to create necessary files and random bits (needed for comparisons etc.): @@ -210,31 +254,41 @@ In order to compile a program, use `./compile.py -R 64`, for example: `./compile.py -R 64 tutorial` -Running the computation is similar to SPDZ but you will need to start -three parties: +Then, run the three parties as follows: -`./replicated-ring-party.x 0 tutorial` +`./replicated-ring-party.x -I 0 tutorial` -`./replicated-ring-party.x 1 tutorial` (in a separate terminal) +`./replicated-ring-party.x -I 1 tutorial` (in a separate terminal) -`./replicated-ring-party.x 2 tutorial` (in a separate terminal) +`./replicated-ring-party.x -I 2 tutorial` (in a separate terminal) or `Scripts/ring.sh tutorial` +Again, `-I` activates interactive input, otherwise inputs are read +from `Player-Data/Input-P-0`. + ## Binary circuits -Compilation is the same as for SPDZ (no need to use the `-R` -argument), but you will need to use different types instead of `sint` -and `sfix`. See `gc_tutorial.mpc` and `gc_fixed_point_tutorial.mpc` in -`Programs/Source`. +For binary circuits, you can compile your programs giving the desired +integer length, for example: -### Semi-honest 3-party replicated secret sharing +`./compile.py -B 32 tutorial` -Compile the virtual machine: +for using 32-bit integers with `sint` and 16/16-bit fixed-point +numbers for `sfix`. The latter is independent of the `-B` option and +can be changed with `sfix.set_precision`. See [the +tutorial](Programs/Source/tutorial.mpc). + +Alternatively, you can directly use `sbitint.get_type(n)` and +`sbitfix` instead of `sint`and `sfix`, respectively. + +### Honest-majority three-party computation -`make -j 8 replicated-bin-party.x` +Compile the virtual machines: + +`make -j 8 rep-bin` Set up SSL certificate and keys: @@ -244,9 +298,20 @@ The programs expect the keys and certificates to be in `Player-Data/P.key` an After compilating the mpc file, run as follows: -`replicated-bin-party.x -h -p <0/1/2> gc_tutorial` +`replicated-bin-party.x [-I] -h -p <0/1/2> tutorial` + +When running locally, you can omit the host argument. As above, `-I` +activates interactive input, otherwise inputs are read from +`Player-Data/Input-P-0`. + +The program above runs a semi-honest computation. For malicious +security you have to generate some preprocessing data (requires +compilation with the INSECURE flag): + +`Scripts/setup-online.sh 3` -When running locally, you can omit the host argument. +and then use `malicious-rep-bin-party.x` instead of +`replicated-bin-party.x`. ### Yao's garbled circuits @@ -257,10 +322,12 @@ Compile the virtual machine: `make -j 8 yao` After compilating the mpc file, run as follows: - - Garbler: ```./yao-player.x -p 0 ``` - - Evaluator: ```./yao-player.x -p 1 -h ``` + - Garbler: ```./yao-player.x [-I] -p 0 ``` + - Evaluator: ```./yao-player.x [-I] -p 1 -h ``` -When running locally, you can omit the host argument. +When running locally, you can omit the host argument. As above, `-I` +activates interactive input, otherwise inputs are read from +`Player-Data/Input-P-0`. By default, the circuit is garbled at once and stored on the evaluator side before evaluating. You can activate a more continuous operation diff --git a/Scripts/mal-rep-bin.sh b/Scripts/mal-rep-bin.sh new file mode 100755 index 000000000..5b79d3592 --- /dev/null +++ b/Scripts/mal-rep-bin.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +port=$[RANDOM+1024] + +for i in 0 1 2; do + IFS="" + log="mal-rep-bin-$*-$i" + IFS=" " + $prefix ./malicious-rep-bin-party.x -p $i -pn $port $* 2>&1 | + { + if test $i = 0; then + tee -a logs/$log + else + cat >> logs/$log + fi + } & true +done + +wait || exit 1 diff --git a/Scripts/mal-rep-field.sh b/Scripts/mal-rep-field.sh new file mode 100755 index 000000000..21593fb03 --- /dev/null +++ b/Scripts/mal-rep-field.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=3 + +. $HERE/run-common.sh + +run_player malicious-rep-field-party.x ${1:-test_all} || exit 1 diff --git a/Scripts/rep-field.sh b/Scripts/rep-field.sh new file mode 100755 index 000000000..54591d72b --- /dev/null +++ b/Scripts/rep-field.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=3 + +. $HERE/run-common.sh + +run_player replicated-field-party.x ${1:-test_all} || exit 1 diff --git a/Scripts/replicated.sh b/Scripts/replicated.sh index ccfcd2f30..e35f3522c 100755 --- a/Scripts/replicated.sh +++ b/Scripts/replicated.sh @@ -1,10 +1,19 @@ #!/bin/bash +port=$[RANDOM+1024] + for i in 0 1 2; do IFS="" log="replicated-$*-$i" IFS=" " - $prefix ./replicated-bin-party.x -p $i $* | tee -a logs/$log & true + $prefix ./replicated-bin-party.x -p $i -pn $port $* 2>&1 | + { + if test $i = 0; then + tee -a logs/$log + else + cat >> logs/$log + fi + } & true done wait || exit 1 diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index 3ee29ced5..bdbde4fc6 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -16,7 +16,7 @@ run_player() { if ! test -e $SPDZROOT/logs; then mkdir $SPDZROOT/logs fi - if test $bin = Player-Online.x -o $bin = replicated-ring-party.x; then + if [[ $bin = Player-Online.x || $bin =~ 'party.x' ]]; then params="$* -pn $port -h localhost" else params="$port localhost $*" @@ -24,8 +24,10 @@ run_player() { if test $bin = Player-KeyGen.x -a ! -e Player-Data/Params-Data; then ./Setup.x $players $size 40 fi - >&2 echo Running $SPDZROOT/Server.x $players $port - $SPDZROOT/Server.x $players $port & + if [[ $bin =~ Player- ]]; then + >&2 echo Running $SPDZROOT/Server.x $players $port + $SPDZROOT/Server.x $players $port & + fi rem=$(($players - 2)) for i in $(seq 0 $rem); do echo "trying with player $i" diff --git a/Scripts/setup-replicated.sh b/Scripts/setup-replicated.sh new file mode 100755 index 000000000..786c92b32 --- /dev/null +++ b/Scripts/setup-replicated.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +$HERE/setup-ssl.sh 3 + +$SPDZROOT/Setup.x 3 128 0 online diff --git a/Scripts/setup-ssl.sh b/Scripts/setup-ssl.sh index ad1089a59..af98e2719 100755 --- a/Scripts/setup-ssl.sh +++ b/Scripts/setup-ssl.sh @@ -2,12 +2,12 @@ n=${1:-3} -echo Setting up SSL for $n parties +test -e Player-Data || mkdir Player-Data -mkdir Player-Data +echo Setting up SSL for $n parties for i in `seq 0 $[n-1]`; do - openssl req -new -nodes -x509 -out Player-Data/P$i.pem -keyout Player-Data/P$i.key -subj "/CN=P$i" + openssl req -newkey rsa -nodes -x509 -out Player-Data/P$i.pem -keyout Player-Data/P$i.key -subj "/CN=P$i" done c_rehash Player-Data diff --git a/Setup.cpp b/Setup.cpp new file mode 100644 index 000000000..374509efb --- /dev/null +++ b/Setup.cpp @@ -0,0 +1,38 @@ +#include "Math/Setup.h" +#include "Auth/fake-stuff.hpp" +#include +#include +using namespace std; + +int main(int argc, char** argv) +{ + if (argc < 4) + { cout << "Call using\n\t"; + cout << "Setup.x n lgp lg2 \n"; + cout << "\t\t n = Number of players" << endl; + cout << "\t\t lgp = Bit size of char p message space" << endl; + cout << "\t\t lg2 = Bit size of char 2 message space" << endl; + exit(1); + } + + int n=atoi(argv[1]); + int lgp=atoi(argv[2]); + int lg2=atoi(argv[3]); + + string dir = get_prep_dir(n, lgp, lg2); + ofstream outf; + bigint p; + generate_online_setup(outf, dir, p, lgp, lg2); + + bool need_mac = false; + for (int i = 0; i < n; i++) + { + string filename = mac_filename(dir, i); + ifstream in(filename); + need_mac |= not in.good(); + } + if (need_mac) + generate_keys(dir, n); +} + + diff --git a/Tools/callgrind.h b/Tools/callgrind.h new file mode 100644 index 000000000..3c0041075 --- /dev/null +++ b/Tools/callgrind.h @@ -0,0 +1,17 @@ +/* + * callgrind.h + * + */ + +#ifndef TOOLS_CALLGRIND_H_ +#define TOOLS_CALLGRIND_H_ + +#ifdef USE_CALLGRIND +#include +#else +#define CALLGRIND_START_INSTRUMENTATION +#define CALLGRIND_STOP_INSTRUMENTATION +#define CALLGRIND_DUMP_STATS +#endif + +#endif /* TOOLS_CALLGRIND_H_ */ diff --git a/Tools/octetStream.h b/Tools/octetStream.h index 36435f277..98d4ec930 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -87,6 +87,7 @@ class octetStream bool equals(const octetStream& a) const; bool operator==(const octetStream& a) const { return equals(a); } + bool operator!=(const octetStream& a) const { return not equals(a); } /* Append NUM random bytes from dev/random */ void append_random(size_t num); @@ -118,6 +119,9 @@ class octetStream void store(const bigint& x); void get(bigint& ans); + template + T get(); + // works for all statically allocated types template void serialize(const T& x) { append((octet*)&x, sizeof(x)); } @@ -268,5 +272,13 @@ inline void octetStream::ReceiveExpected(int socket_num, size_t expected) receive(socket_num,data,len); } +template +T octetStream::get() +{ + T res; + res.unpack(*this); + return res; +} + #endif diff --git a/Tools/random.cpp b/Tools/random.cpp index e6dc14375..660ba1b23 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -146,6 +146,21 @@ void PRNG::get_octetStream(octetStream& ans,int len) } +void PRNG::randomBnd(mp_limb_t* res, const mp_limb_t* B, size_t n_bytes) +{ + if (n_bytes == 16) + do + get_octets<16>((octet*) res); + while (mpn_cmp(res, B, 2) >= 0); + else + { + size_t n_limbs = (n_bytes + sizeof(mp_limb_t) - 1) / sizeof(mp_limb_t); + do + get_octets((octet*) res, n_bytes); + while (mpn_cmp(res, B, n_limbs) >= 0); + } +} + bigint PRNG::randomBnd(const bigint& B, bool positive) { bigint x; diff --git a/Tools/random.h b/Tools/random.h index f7c5cf1d9..a2bb0c3c6 100644 --- a/Tools/random.h +++ b/Tools/random.h @@ -5,6 +5,9 @@ #include "Tools/sha1.h" #include "Tools/aes.h" #include "Tools/avx_memcpy.h" +#include "Networking/data.h" + +#include #define USE_AES @@ -72,6 +75,8 @@ class PRNG void get(int& res, int n_bits, bool positive = true); void randomBnd(bigint& res, const bigint& B, bool positive=true); bigint randomBnd(const bigint& B, bool positive=true); + // only efficient if byte length of B is exactly n_bytes + void randomBnd(mp_limb_t* res, const mp_limb_t* B, size_t n_bytes); word get_word() { word a; @@ -88,6 +93,15 @@ class PRNG { return seed; } }; +class SeededPRNG : public PRNG +{ +public: + SeededPRNG() + { + ReSeed(); + } +}; + inline unsigned char PRNG::get_uchar() { diff --git a/Yao/YaoEvalMaster.cpp b/Yao/YaoEvalMaster.cpp index ccc56ef5d..2c50426f1 100644 --- a/Yao/YaoEvalMaster.cpp +++ b/Yao/YaoEvalMaster.cpp @@ -6,7 +6,8 @@ #include "YaoEvalMaster.h" #include "YaoEvaluator.h" -YaoEvalMaster::YaoEvalMaster(bool continuous) : continuous(continuous) +YaoEvalMaster::YaoEvalMaster(bool continuous, OnlineOptions& opts) : + ThreadMaster>(opts), continuous(continuous) { } diff --git a/Yao/YaoEvalMaster.h b/Yao/YaoEvalMaster.h index cfbd87e07..69d8fa195 100644 --- a/Yao/YaoEvalMaster.h +++ b/Yao/YaoEvalMaster.h @@ -17,7 +17,7 @@ class YaoEvalMaster : public GC::ThreadMaster> public: bool continuous; - YaoEvalMaster(bool continuous); + YaoEvalMaster(bool continuous, OnlineOptions& opts); Thread>* new_thread(int i); }; diff --git a/Yao/YaoEvalWire.cpp b/Yao/YaoEvalWire.cpp index e0f2949d3..17820426e 100644 --- a/Yao/YaoEvalWire.cpp +++ b/Yao/YaoEvalWire.cpp @@ -96,9 +96,12 @@ void YaoEvalWire::and_(GC::Processor >& processor, void YaoEvalWire::inputb(GC::Processor >& processor, const vector& args) { - ArgList a(args); + InputArgList a(args); BitVector inputs; inputs.resize(0); + auto& evaluator = YaoEvaluator::s(); + bool interactive = evaluator.n_interactive_inputs_from_me(a) > 0; + for (auto x : a) { auto& dest = processor.S[x.dest]; @@ -112,7 +115,7 @@ void YaoEvalWire::inputb(GC::Processor >& processor, } else { - long long input = processor.get_input(x.n_bits); + long long input = processor.get_input(x.n_bits, interactive); size_t start = inputs.size(); inputs.resize(start + x.n_bits); for (int i = 0; i < x.n_bits; i++) @@ -120,7 +123,9 @@ void YaoEvalWire::inputb(GC::Processor >& processor, } } - auto& evaluator = YaoEvaluator::s(); + if (interactive) + cout << "Thank you" << endl; + evaluator.ot_ext.extend_correlated(inputs.size(), inputs); octetStream os; evaluator.player.receive(os); diff --git a/Yao/YaoEvalWire.h b/Yao/YaoEvalWire.h index d751543f1..388323a72 100644 --- a/Yao/YaoEvalWire.h +++ b/Yao/YaoEvalWire.h @@ -9,13 +9,12 @@ #include "BMR/Key.h" #include "BMR/Gate.h" #include "BMR/Register.h" -#include "GC/Processor.h" -#include "Auth/MAC_Check.h" +#include "Processor/DummyProtocol.h" class YaoEvalWire : public Phase { public: - typedef MAC_Check_Base MC; + typedef DummyMC MC; static string name() { return "YaoEvalWire"; } diff --git a/Yao/YaoEvaluator.cpp b/Yao/YaoEvaluator.cpp index 5c5f107ce..32490d9b7 100644 --- a/Yao/YaoEvaluator.cpp +++ b/Yao/YaoEvaluator.cpp @@ -8,7 +8,7 @@ thread_local YaoEvaluator* YaoEvaluator::singleton = 0; YaoEvaluator::YaoEvaluator(int thread_num, YaoEvalMaster& master) : - Thread>(thread_num, master.machine, master.N), + Thread>(thread_num, master), master(master), player(N, 0, thread_num << 24), ot_ext(OTExtensionWithMatrix::setup(player, {}, RECEIVER, true)) diff --git a/Yao/YaoEvaluator.h b/Yao/YaoEvaluator.h index 062f9cf0b..2335c5558 100644 --- a/Yao/YaoEvaluator.h +++ b/Yao/YaoEvaluator.h @@ -7,11 +7,9 @@ #define YAO_YAOEVALUATOR_H_ #include "YaoGate.h" -#include "YaoPlayer.h" #include "YaoEvalMaster.h" #include "YaoCommon.h" #include "GC/Secret.h" -#include "GC/Program.h" #include "GC/Thread.h" #include "Tools/MMO.h" #include "OT/OTExtensionWithMatrix.h" diff --git a/Yao/YaoGarbleMaster.cpp b/Yao/YaoGarbleMaster.cpp index 31dc3aa42..b2e0da618 100644 --- a/Yao/YaoGarbleMaster.cpp +++ b/Yao/YaoGarbleMaster.cpp @@ -6,8 +6,8 @@ #include "YaoGarbleMaster.h" #include "YaoGarbler.h" -YaoGarbleMaster::YaoGarbleMaster(bool continuous, int threshold) : - continuous(continuous), threshold(threshold) +YaoGarbleMaster::YaoGarbleMaster(bool continuous, OnlineOptions& opts, int threshold) : + super(opts), continuous(continuous), threshold(threshold) { PRNG G; G.ReSeed(); diff --git a/Yao/YaoGarbleMaster.h b/Yao/YaoGarbleMaster.h index c8a3ecc6b..5df7ad18c 100644 --- a/Yao/YaoGarbleMaster.h +++ b/Yao/YaoGarbleMaster.h @@ -9,17 +9,20 @@ #include "GC/ThreadMaster.h" #include "GC/Secret.h" #include "YaoGarbleWire.h" +#include "Processor/OnlineOptions.h" using namespace GC; class YaoGarbleMaster : public GC::ThreadMaster> { + typedef GC::ThreadMaster> super; + public: bool continuous; int threshold; Key delta; - YaoGarbleMaster(bool continuous, int threshold = 1024); + YaoGarbleMaster(bool continuous, OnlineOptions& opts, int threshold = 1024); Thread>* new_thread(int i); }; diff --git a/Yao/YaoGarbleWire.cpp b/Yao/YaoGarbleWire.cpp index 817b4dc8e..03ed88299 100644 --- a/Yao/YaoGarbleWire.cpp +++ b/Yao/YaoGarbleWire.cpp @@ -190,15 +190,17 @@ void YaoGarbleWire::and_(GC::Memory >& S, void YaoGarbleWire::inputb(GC::Processor>& processor, const vector& args) { - ArgList a(args); + InputArgList a(args); int n_evaluator_bits = 0; + auto& garbler = YaoGarbler::s(); + bool interactive = garbler.n_interactive_inputs_from_me(a) > 0; for (auto x : a) { auto& dest = processor.S[x.dest]; dest.resize_regs(x.n_bits); if (x.from == 0) { - long long input = processor.get_input(x.n_bits); + long long input = processor.get_input(x.n_bits, interactive); for (auto& reg : dest.get_regs()) { reg.public_input(input & 1); @@ -211,7 +213,9 @@ void YaoGarbleWire::inputb(GC::Processor>& processor, } } - auto& garbler = YaoGarbler::s(); + if (interactive) + cout << "Thank you"; + garbler.receiver_input_keys.push_back({}); for (auto x : a) diff --git a/Yao/YaoGarbleWire.h b/Yao/YaoGarbleWire.h index 4f7a15ca5..d582c4aaf 100644 --- a/Yao/YaoGarbleWire.h +++ b/Yao/YaoGarbleWire.h @@ -8,7 +8,8 @@ #include "BMR/Key.h" #include "BMR/Register.h" -#include "GC/Processor.h" + +#include class YaoGate; class YaoGarbler; diff --git a/Yao/YaoGarbler.cpp b/Yao/YaoGarbler.cpp index 36d72b83d..0fc399e92 100644 --- a/Yao/YaoGarbler.cpp +++ b/Yao/YaoGarbler.cpp @@ -9,7 +9,7 @@ thread_local YaoGarbler* YaoGarbler::singleton = 0; YaoGarbler::YaoGarbler(int thread_num, YaoGarbleMaster& master) : - Thread>(thread_num, master.machine, master.N), + Thread>(thread_num, master), master(master), and_proc_timer(CLOCK_PROCESS_CPUTIME_ID), and_main_thread_timer(CLOCK_THREAD_CPUTIME_ID), diff --git a/Yao/YaoGarbler.h b/Yao/YaoGarbler.h index 181948492..3271d4129 100644 --- a/Yao/YaoGarbler.h +++ b/Yao/YaoGarbler.h @@ -13,7 +13,6 @@ #include "Tools/random.h" #include "Tools/MMO.h" #include "GC/Secret.h" -#include "GC/Program.h" #include "Networking/Player.h" #include "OT/OTExtensionWithMatrix.h" #include "sys/sysinfo.h" diff --git a/Yao/YaoGate.h b/Yao/YaoGate.h index 8ff8160e0..08dd45d08 100644 --- a/Yao/YaoGate.h +++ b/Yao/YaoGate.h @@ -10,7 +10,6 @@ #include "BMR/Key.h" #include "YaoGarbleWire.h" #include "YaoEvalWire.h" -#include "YaoGarbler.h" class YaoGate { diff --git a/Yao/YaoPlayer.cpp b/Yao/YaoPlayer.cpp index 426f4c51d..30869d9ed 100644 --- a/Yao/YaoPlayer.cpp +++ b/Yao/YaoPlayer.cpp @@ -56,6 +56,7 @@ YaoPlayer::YaoPlayer(int argc, const char** argv) "-t", // Flag token. "--threshold" // Flag token. ); + OnlineOptions online_opts(opt, argc, argv); opt.parse(argc, argv); opt.syntax = "./yao-player.x [OPTIONS] "; if (opt.lastArgs.size() == 1) @@ -82,9 +83,9 @@ YaoPlayer::YaoPlayer(int argc, const char** argv) ThreadMasterBase* master; if (my_num == 0) - master = new YaoGarbleMaster(continuous, threshold); + master = new YaoGarbleMaster(continuous, online_opts, threshold); else - master = new YaoEvalMaster(continuous); + master = new YaoEvalMaster(continuous, online_opts); server = Server::start_networking(master->N, my_num, 2, hostname, pnb); master->run(progname); diff --git a/malicious-rep-bin-party.cpp b/malicious-rep-bin-party.cpp new file mode 100644 index 000000000..502112e03 --- /dev/null +++ b/malicious-rep-bin-party.cpp @@ -0,0 +1,12 @@ +/* + * malicious-rep-bin-party.cpp + * + */ + +#include "GC/ReplicatedParty.h" +#include "GC/MaliciousRepSecret.h" + +int main(int argc, const char** argv) +{ + GC::ReplicatedParty(argc, argv); +} diff --git a/malicious-rep-field-party.cpp b/malicious-rep-field-party.cpp new file mode 100644 index 000000000..e2300519e --- /dev/null +++ b/malicious-rep-field-party.cpp @@ -0,0 +1,12 @@ +/* + * malicious-rep-field-party.cpp + * + */ + +#include "Math/MaliciousRep3Share.h" +#include "Processor/ReplicatedMachine.hpp" + +int main(int argc, const char** argv) +{ + ReplicatedMachine, MaliciousRep3Share>(argc, argv, "malicious-rep-field"); +} diff --git a/pairwise-offline.cpp b/pairwise-offline.cpp index 50f96a662..18d059e03 100644 --- a/pairwise-offline.cpp +++ b/pairwise-offline.cpp @@ -1,5 +1,5 @@ #include "FHEOffline/PairwiseMachine.h" -#include +#include "Tools/callgrind.h" int main(int argc, const char** argv) { diff --git a/replicated-bin-party.cpp b/replicated-bin-party.cpp index 53d6eaff2..d528e4db3 100644 --- a/replicated-bin-party.cpp +++ b/replicated-bin-party.cpp @@ -7,5 +7,5 @@ int main(int argc, const char** argv) { - GC::ReplicatedParty(argc, argv); + GC::ReplicatedParty(argc, argv); } diff --git a/replicated-field-party.cpp b/replicated-field-party.cpp new file mode 100644 index 000000000..9e27f5365 --- /dev/null +++ b/replicated-field-party.cpp @@ -0,0 +1,12 @@ +/* + * replicated-field-party.cpp + * + */ + +#include "Processor/ReplicatedMachine.hpp" +#include "Math/gfp.h" + +int main(int argc, const char** argv) +{ + ReplicatedMachine, Rep3Share>(argc, argv, "replicated-field"); +} diff --git a/replicated-ring-party.cpp b/replicated-ring-party.cpp index 66864a658..9c6fe7334 100644 --- a/replicated-ring-party.cpp +++ b/replicated-ring-party.cpp @@ -3,83 +3,10 @@ * */ -#include "Tools/ezOptionParser.h" -#include "Tools/benchmarking.h" -#include "Networking/Server.h" -#include "Math/Rep3Share.h" -#include "Processor/Machine.h" +#include "Processor/ReplicatedMachine.hpp" +#include "Math/Integer.h" int main(int argc, const char** argv) { - ez::ezOptionParser opt; - opt.add( - "localhost", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Host where party 0 is running (default: localhost)", // Help description. - "-h", // Flag token. - "--hostname" // Flag token. - ); - opt.add( - "5000", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Base port number (default: 5000).", // Help description. - "-pn", // Flag token. - "--portnum" // Flag token. - ); - opt.add( - "", // Default. - 0, // Required? - 0, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "Unencrypted communication.", // Help description. - "-u", // Flag token. - "--unencrypted" // Flag token. - ); - opt.syntax = "./replicated-ring-party.x [OPTIONS] "; - opt.parse(argc, argv); - vector allArgs(opt.firstArgs); - allArgs.insert(allArgs.end(), opt.lastArgs.begin(), opt.lastArgs.end()); - - int playerno; - string progname; - - if (allArgs.size() != 3) - { - cerr << "ERROR: incorrect number of arguments to " << argv[0] << endl; - cerr << "Arguments given were:\n"; - for (unsigned int j = 1; j < allArgs.size(); j++) - cout << "'" << *allArgs[j] << "'" << endl; - string usage; - opt.getUsage(usage); - cout << usage; - return 1; - } - else - { - playerno = atoi(allArgs[1]->c_str()); - progname = *allArgs[2]; - - } - - int pnb; - string hostname; - opt.get("-pn")->getInt(pnb); - opt.get("-h")->getString(hostname); - bool use_encryption = not opt.get("-u")->isSet; - - if (not use_encryption) - insecure("unencrypted communication"); - Names N; - Server* server = Server::start_networking(N, playerno, 3, hostname, pnb); - - Machine(playerno, N, progname, "empty", 128, - gf2n::default_degree(), 0, 0, 0, 0, 0, use_encryption).run(); - - if (server) - delete server; - + ReplicatedMachine, Rep3Share>(argc, argv, "replicated-ring"); } diff --git a/simple-offline.cpp b/simple-offline.cpp index 988ef8910..3ceb70952 100644 --- a/simple-offline.cpp +++ b/simple-offline.cpp @@ -4,7 +4,7 @@ */ #include -#include +#include "Tools/callgrind.h" int main(int argc, const char** argv) {