diff --git a/.gitignore b/.gitignore index e5724e9..2fb746c 100644 --- a/.gitignore +++ b/.gitignore @@ -72,6 +72,7 @@ gens *.der *.srl *.seq +*.cnf #latex artifacts *.aux diff --git a/Dockerfile b/Dockerfile index 16f3754..5101da3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,8 +27,14 @@ cp bin/release/pjc-client exec && \ cp bin/release/pjc-server exec && \ cp bin/release/datagen exec && \ cp bin/release/private-id-multi-key-server exec && \ -cp bin/release/private-id-multi-key-client exec - +cp bin/release/private-id-multi-key-client exec && \ +cp bin/release/dpmc-company-server exec && \ +cp bin/release/dpmc-helper exec && \ +cp bin/release/dpmc-partner-server exec && \ +cp bin/release/dspmc-company-server exec && \ +cp bin/release/dspmc-helper-server exec && \ +cp bin/release/dspmc-partner-server exec && \ +cp bin/release/dspmc-shuffler exec # thin container with binaries # base image is taken from here https://hub.docker.com/_/debian/ diff --git a/README.md b/README.md index d54165c..4a10a0d 100644 --- a/README.md +++ b/README.md @@ -1,164 +1,383 @@ # Private-ID -Private-ID is a collection of algorithms to match records between two parties, while preserving the privacy of these records. We present two algorithms to do this---one of which does an outer join between parties and another does a inner join and then generates additive shares that can then be input to a Multi Party Compute system like [CrypTen](https://github.com/facebookresearch/CrypTen). Please refer to our [paper](https://eprint.iacr.org/2020/599.pdf) for more details. +Private-ID is a collection of algorithms to match records between two or parties, while preserving the privacy of these records. We present multiple algorithms to do this---one of which does an outer join between parties, and others do inner or left join and then generate additive shares that can then be input to a Multi Party Compute system like [CrypTen](https://github.com/facebookresearch/CrypTen). Please refer to our [paper](https://eprint.iacr.org/2020/599.pdf) for more details. The MultiKey Private-ID [paper](https://eprint.iacr.org/2021/770.pdf) and the Delegated Private-ID [paper](https://eprint.iacr.org/2023/012.pdf) extend Private-ID. -### Build +## Build -Private-ID is implemented in Rust to take advantage of the languages security features and to leverage the encryption libraries that we depend on. It should compile with the nightly Rust toolchain. +Private-ID is implemented in Rust to take advantage of the language's security features and to leverage the encryption libraries that we depend on. It should compile with the nightly Rust toolchain. The following should build and run the unit tests for the building blocks used by the protocols -- `cargo build`, `cargo test` +```bash +cargo build --release +cargo test --release +``` + +Each protocol involves two (or more) parties and they have to be run in their own shell environment. We call one party Company and another party Partner. Some protocols also involve additional parties such as the Helper and the Shuffler. -Each protocol involves two parties and they have to be run in its own shell environment. We call one party Company and another party Partner. +Run the script at etc/example/generate_cert.sh to generate dummy_certs directory if you want to test protocol with TLS on local. -Run the script at etc/example/generate_cert.sh to generate dummy_certs directroy if you want to test protocol with tls on local. +### Build & Run With Docker Compose +The following, run each party in a different container: +* Private-ID: `docker compose --profile private-id up` +* Delegated Private Matching for Compute (DPMC): `docker compose --profile dpmc up` +* Delegated Private Matching for Compute with Secure Shuffling (DSPMC): `docker compose --profile dspmc up` -### Private-ID +By default, this will create datasets of 10 items each. To run with bigger datasets set the `ENV_VARIABLE_FOR_SIZE` environment variable. For example: `ENV_VARIABLE_FOR_SIZE=100 docker compose --profile dpmc up` will run DPMC with datasets of 100 items each. + +## Private-ID This protocol maps the email addresses from both parties to a single ID spine, so that same e-mail addresses map to the same key. -To run Company +To run Company: +```bash +env RUST_LOG=info cargo run --release --bin private-id-server -- \ + --host 0.0.0.0:10009 \ + --input etc/example/email_company.csv \ + --stdout \ + --no-tls +``` +To run Partner: ```bash -env RUST_LOG=info cargo run --bin private-id-server -- \ ---host 0.0.0.0:10009 \ ---input etc/example/email_company.csv \ ---stdout \ ---tls-dir etc/example/dummy_certs +env RUST_LOG=info cargo run --release --bin private-id-client -- \ + --company localhost:10009 \ + --input etc/example/email_partner.csv \ + --stdout \ + --no-tls ``` -To run Partner +## Private-ID MultiKey + +We extend the Private-ID protocol to match multiple identifiers. Please refer to our [paper](https://eprint.iacr.org/2021/770) for more details. +To run Company: ```bash -env RUST_LOG=info cargo run --bin private-id-client -- \ ---company localhost:10009 \ ---input etc/example/email_partner.csv \ ---stdout \ ---tls-dir etc/example/dummy_certs +env RUST_LOG=info cargo run --release --bin private-id-multi-key-server -- \ + --host 0.0.0.0:10009 \ + --input etc/example/private_id_multi_key/Ex1_company.csv \ + --stdout \ + --no-tls ``` -### Private-ID MultiKey +To run Partner: +```bash +env RUST_LOG=info cargo run --release --bin private-id-multi-key-client -- \ + --company localhost:10009 \ + --input etc/example/private_id_multi_key/Ex1_partner.csv \ + --stdout \ + --no-tls +``` -We extend the Private-ID protocol to match multiple identifiers. Please refer to our [paper](https://eprint.iacr.org/2021/770) for more details. +## PS3I -To run Company +This protocol does an inner join based on email addresses as keys and then generates additive share of a feature associated with that email address. The shares are generated in the designated output files as 64-bit numbers + +To run Company: +```bash +env RUST_LOG=info cargo run --release --bin cross-psi-server -- \ + --host 0.0.0.0:10010 \ + --input etc/example/input_company.csv \ + --output etc/example/output_company.csv \ + --no-tls +``` +To run Partner: ```bash -env RUST_LOG=info cargo run --bin private-id-multi-key-server -- \ - --host 0.0.0.0:10009 \ - --input etc/example/private_id_multi_key/Ex1_company.csv \ - --stdout \ - --tls-dir etc/example/dummy_certs +env RUST_LOG=info cargo run --release --bin cross-psi-client -- \ + --company localhost:10010 \ + --input etc/example/input_partner.csv \ + --output etc/example/output_partner.csv \ + --no-tls ``` -To run Partner +## PS3I XOR + +This protocol does an inner join based on email addresses as keys and then generates XOR share of a feature associated with that email address. The shares are generated in the designated output files as 64-bit numbers + +To run Company: +```bash +env RUST_LOG=info cargo run --release --bin cross-psi-xor-server -- \ + --host 0.0.0.0:10010 \ + --input etc/example/cross_psi_xor/input_company.csv \ + --output etc/example/cross_psi_xor/output_company \ + --no-tls +``` +To run Partner: ```bash -env RUST_LOG=info cargo run --bin private-id-multi-key-client -- \ - --company localhost:10009 \ - --input etc/example/private_id_multi_key/Ex1_partner.csv \ - --stdout \ - --tls-dir etc/example/dummy_certs +env RUST_LOG=info cargo run --release --bin cross-psi-xor-client -- \ + --company localhost:10010 \ + --input etc/example/cross_psi_xor/input_partner.csv \ + --output etc/example/cross_psi_xor/output_partner \ + --no-tls ``` -### PS3I +The `--output` option provides prefix for the output files that contain the shares. In this case, Company generates two files; `output_company_company_feature.csv` and `output_company_partner_feature.csv`. They contain Company's share of company and parter features respectively. Similarly Partner generates two files; `output_partner_company_feature.csv` and `output_partner_partner_feature.csv`. They contain Partner's share of company and partner features respectively. -This protocol does an inner join based on email addresses as keys and then generates additive share of a feature associated with that email address. The shares are generated in the designated output files as 64 bit numbers +Thus `output_company_company_feature.csv` and `output_partner_company_feature.csv` are XOR shares of Company's features. Similarly, `output_partner_company_feature.csv` and `output_partner_partner_feature.csv` are XOR shares of Partner's features. -To run Company +### Private Join and Compute +This is an implementation of Google's [Private Join and Compute](https://github.com/google/private-join-and-compute) protocol, that does a inner join based on email addresses and computes a sum of the corresponding feature for the Partner. +To run Company: ```bash -env RUST_LOG=info cargo run --bin cross-psi-server -- \ ---host 0.0.0.0:10010 \ ---input etc/example/input_company.csv \ ---output etc/example/output_company.csv \ ---no-tls +env RUST_LOG=info cargo run --release --bin pjc-server -- \ + --host 0.0.0.0:10011 \ + --input etc/example/pjc_company.csv \ + --stdout \ + --no-tls ``` -To run Partner +To run Partner: +```bash +env RUST_LOG=info cargo run --release --bin pjc-client -- \ + --company localhost:10011 \ + --input etc/example/pjc_partner.csv \ + --stdout \ + --no-tls +``` + +## SUMID +This is an implmentation of 2-party version of Secure Universal ID protocol. This can work on multiple keys. In the current implementation, the merger party also assumes the role of one data party and the sharer party assumes the role of all the other data parties. The data parties are the `.csv` files show below +To run merger: ```bash -env RUST_LOG=info cargo run --bin cross-psi-client -- \ ---company localhost:10010 \ ---input etc/example/input_partner.csv \ ---output etc/example/output_partner.csv \ ---no-tls +env RUST_LOG=info cargo run --release --bin suid-create-server -- \ + --host 0.0.0.0:10010 \ + --input etc/example/suid/Example1/DataParty2_input.csv \ + --stdout \ + --no-tls ``` -### PS3I XOR +To run client: +```bash +env RUST_LOG=info cargo run --release --bin suid-create-client -- \ + --merger localhost:10010 \ + --input etc/example/suid/Example1/DataParty1_input.csv \ + --input etc/example/suid/Example1/DataParty3_input.csv \ + --stdout \ + --no-tls +``` + +The output will be ElGamal encrypted Universal IDs assigned to each entry in the `.csv` file. -This protocol does an inner join based on email addresses as keys and then generates XOR share of a feature associated with that email address. The shares are generated in the designated output files as 64 bit numbers +## Delegated Private Matching for Compute (DPMC) -To run Company +We extend the Multi-key Private-ID protocol to multiple partners. Please refer to our [paper](https://eprint.iacr.org/2023/012) for more details. +To run Company: ```bash -env RUST_LOG=info cargo run --bin cross-psi-xor-server -- \ ---host 0.0.0.0:10010 \ ---input etc/example/cross_psi_xor/input_company.csv \ ---output etc/example/cross_psi_xor/output_company \ ---no-tls +env RUST_LOG=info cargo run --release --bin dpmc-company-server -- \ + --host 0.0.0.0:10010 \ + --input etc/example/dpmc/Ex0_company.csv \ + --stdout \ + --output-shares-path etc/example/dpmc/output_company \ + --no-tls ``` -To run Partner +To run multiple partners (servers): +```bash +env RUST_LOG=info cargo run --release --bin dpmc-partner-server -- \ + --host 0.0.0.0:10020 \ + --company localhost:10010 \ + --input-keys etc/example/dpmc/Ex0_partner_1.csv \ + --input-features etc/example/dpmc/Ex0_partner_1_features.csv \ + --no-tls +``` ```bash -env RUST_LOG=info cargo run --bin cross-psi-xor-client -- \ ---company localhost:10010 \ ---input etc/example/cross_psi_xor/input_partner.csv \ ---output etc/example/cross_psi_xor/output_partner \ ---no-tls +env RUST_LOG=info cargo run --release --bin dpmc-partner-server -- \ + --host 0.0.0.0:10021 \ + --company localhost:10010 \ + --input-keys etc/example/dpmc/Ex0_partner_2.csv \ + --input-features etc/example/dpmc/Ex0_partner_2_features.csv \ + --no-tls ``` -The `--output` option provides prefix for the output files that contain the shares. In this case, Company generates two files; `output_company_company_feature.csv` and `output_company_partner_feature.csv`. They contain Company's share of company and parter features respectively. Similarly Partner generates two files; `output_partner_company_feature.csv` and `output_partner_partner_feature.csv`. They contain Partner's share of company and partner features respectively. -Thus `output_company_company_feature.csv` and `output_partner_company_feature.csv` are XOR shares of Company's features. Similarly `output_partner_company_feature.csv` and `output_partner_partner_feature.csv` are XOR shares of Partner's features. +Start helper (client): +```bash +env RUST_LOG=info cargo run --release --bin dpmc-helper -- \ + --company localhost:10010 \ + --partners localhost:10020,localhost:10021 \ + --stdout \ + --output-shares-path etc/example/dpmc/output_partner \ + --no-tls +``` -### Private Join and Compute -This is an implementation of Google's [Private Join and Compute](https://github.com/google/private-join-and-compute) protocol, that does a inner join based on email addresses and computes a sum of the corresponding feature for the Partner. +The above will generate one-to-one matches. To enable one-to-many matches (one +record from C will match to `M` P records), use the flag `--one-to-many M` in the +`dpmc-helper` binary, where `M` is the number of matches. + +For example, using the same scripts as above for company and partners, to run +`1-2` matching, start the helper as follows: ```bash -env RUST_LOG=info cargo run --bin pjc-client -- \ ---company localhost:10011 \ ---input etc/example/pjc_partner.csv \ ---stdout \ ---tls-dir etc/example/dummy_certs +env RUST_LOG=info cargo run --release --bin dpmc-helper -- \ + --company localhost:10010 \ + --partners localhost:10020,localhost:10021 \ + --one-to-many 2 \ + --stdout \ + --output-shares-path etc/example/dpmc/output_partner \ + --no-tls ``` +## Delegated Private Matching for Compute with Secure Shuffling (DSPMC) + +Start helper (server): ```bash -env RUST_LOG=info cargo run --bin pjc-server -- \ ---host 0.0.0.0:10011 \ ---input etc/example/pjc_company.csv \ ---stdout \ ---tls-dir etc/example/dummy_certs +env RUST_LOG=info cargo run --release --bin dspmc-helper-server -- \ + --host 0.0.0.0:10030 \ + --stdout \ + --output-shares-path etc/example/dspmc/output_helper \ + --no-tls ``` -### SUMID -This is an implmentation of 2-party version of Secure Universal ID protocol. This can work on multiple keys. In the current implementation, the merger party also assumes the role of one data party and the sharer party assumes the role of all the other data parties. The data parties are the `.csv` files show below -To run merger +Start company (server): +```bash +env RUST_LOG=info cargo run --release --bin dspmc-company-server -- \ + --host 0.0.0.0:10010 \ + --helper localhost:10030 \ + --input etc/example/dspmc/Ex0_company.csv \ + --stdout \ + --output-shares-path etc/example/dspmc/output_company \ + --no-tls +``` + +Start multiple partners (servers): +```bash +env RUST_LOG=info cargo run --release --bin dspmc-partner-server -- \ + --host 0.0.0.0:10020 \ + --company localhost:10010 \ + --input-keys etc/example/dspmc/Ex0_partner_1.csv \ + --input-features etc/example/dspmc/Ex0_partner_1_features.csv \ + --no-tls +``` + +```bash +env RUST_LOG=info cargo run --release --bin dspmc-partner-server -- \ + --host 0.0.0.0:10021 \ + --company localhost:10010 \ + --input-keys etc/example/dspmc/Ex0_partner_2.csv \ + --input-features etc/example/dspmc/Ex0_partner_2_features.csv \ + --no-tls +``` + +Start Shuffler (client): +```bash +env RUST_LOG=info cargo run --release --bin dspmc-shuffler -- \ + --company localhost:10010 \ + --helper localhost:10030 \ + --partners localhost:10020,localhost:10021 \ + --stdout \ + --no-tls +``` + +### Note: Running over the network +To run over the network instead of localhost prepend the IP address with `http://` or `https://`. For example: + +To run Company (in IP `1.23.34.45`): ```bash -env RUST_LOG=info cargo run --bin suid-create-server -- \ - --host 0.0.0.0:10010 \ - --input etc/example/suid/Example1/DataParty2_input.csv \ - --stdout \ - --tls-dir etc/example/dummy_certs +env RUST_LOG=info cargo run --release --bin dpmc-company-server -- \ + --host 0.0.0.0:10010 \ + --input etc/example/dpmc/Ex0_company.csv \ + --stdout \ + --output-shares-path etc/example/dpmc/output_company \ + --no-tls ``` -To run merger +To run multiple partners (servers) (in IPs `76.65.54.43` and `76.65.54.44`): ```bash -env RUST_LOG=info cargo run --bin suid-create-client -- \ - --merger localhost:10010 \ - --input etc/example/suid/Example1/DataParty1_input.csv \ - --input etc/example/suid/Example1/DataParty3_input.csv \ - --stdout \ - --tls-dir etc/example/dummy_certs +env RUST_LOG=info cargo run --release --bin dpmc-partner-server -- \ + --host 0.0.0.0:10020 \ + --company http://1.23.34.45:10010 \ + --input-keys etc/example/dpmc/Ex0_partner_1.csv \ + --input-features etc/example/dpmc/Ex0_partner_1_features.csv \ + --no-tls ``` -The output will be ElGamal encrypted Universal IDs assigned to each entry in the `.csv` file +```bash +env RUST_LOG=info cargo run --release --bin dpmc-partner-server -- \ + --host 0.0.0.0:10021 \ + --company http://1.23.34.45:10010 \ + --input-keys etc/example/dpmc/Ex0_partner_2.csv \ + --input-features etc/example/dpmc/Ex0_partner_2_features.csv \ + --no-tls +``` + +Start helper (client): +```bash +env RUST_LOG=info cargo run --release --bin dpmc-helper -- \ + --company http://1.23.34.45:10010 \ + --partners http://76.65.54.43:10020,http://76.65.54.44:10021 \ + --stdout \ + --output-shares-path etc/example/dpmc/output_partner \ + --no-tls +``` + +# Citing Private-ID + +To cite Private-ID in academic papers, please use the following BibTeX entries. + +## Delegated Private-ID +``` +@Article{PoPETS:MMTSBC23, + author = "Dimitris Mouris and + Daniel Masny and + Ni Trieu and + Shubho Sengupta and + Prasad Buddhavarapu and + Benjamin M Case", + title = "{Delegated Private Matching for Compute}", + volume = 2024, + month = Jul, + year = 2024, + journal = "{Proceedings on Privacy Enhancing Technologies}", + number = 2, + pages = "1--24", +} +``` + +## Multi-Key Private-ID +``` +@Misc{EPRINT:BCGKMSTX21, + author = "Prasad Buddhavarapu and + Benjamin M Case and + Logan Gore and + Andrew Knox and + Payman Mohassel and + Shubho Sengupta and + Erik Taubeneck and + Min Xue", + title = "Multi-key Private Matching for Compute", + year = 2021, + howpublished = "Cryptology ePrint Archive, Report 2021/770", + note = "\url{https://eprint.iacr.org/2021/770}", +} +``` + +## Private-ID +``` +@Misc{EPRINT:BKMSTV20, + author = "Prasad Buddhavarapu and + Andrew Knox and + Payman Mohassel and + Shubho Sengupta and + Erik Taubeneck and + Vlad Vlaskin", + title = "Private Matching for Compute", + year = 2020, + howpublished = "Cryptology ePrint Archive, Report 2020/599", + note = "\url{https://eprint.iacr.org/2020/599}", +} +``` ## License -Private-ID is Apache 2.0 licensed, as found in the [LICENSE](/LICENSE) file +Private-ID is Apache 2.0 licensed, as found in the [LICENSE](/LICENSE) file. ## Additional Resources on Private Computation at Meta +* [Delegated Multi-key Private Matching for Compute: Improving match rates and enabling adoption](https://research.facebook.com/blog/2023/1/delegated-multi-key-private-matching-for-compute-improving-match-rates-and-enabling-adoption/) +* [Private matching for compute](https://engineering.fb.com/2020/07/10/open-source/private-matching/) * [The Value of Secure Multi-Party Computation](https://privacytech.fb.com/multi-party-computation/) * [Building the Next Era of Personalized Experiences](https://www.facebook.com/business/news/building-the-next-era-of-personalized-experiences) * [Privacy-Enhancing Technologies and Building for the Future](https://www.facebook.com/business/news/building-for-the-future) diff --git a/common/Cargo.toml b/common/Cargo.toml index f0f782e..c5d9876 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -14,7 +14,7 @@ path = "datagen/datagen.rs" [dependencies] log = "0.4" env_logger = "0.7.1" -rayon = "1.3.0" +rayon = "1.8.0" clap = "2.33.0" csv = "1.1.1" rand = { version = "0.8", features = ["small_rng"] } @@ -23,10 +23,10 @@ hex = "0.3.0" serde = {version = "1.0.104", features = ["derive"] } num = "0.2.1" wasm-timer = "0.2.5" -aws-config = "0.54.1" -aws-credential-types = "0.54.1" -aws-sdk-s3 = "0.24.0" -aws-smithy-http = "0.54.0" +aws-config = "0.56.1" +aws-credential-types = "0.56.1" +aws-sdk-s3 = "0.34.0" +aws-smithy-http = "0.56.0" lazy_static = "1.4.0" regex = "1.5.4" tempfile = "3.2.0" diff --git a/common/datagen/datagen.rs b/common/datagen/datagen.rs index 5c53d8c..08ed254 100644 --- a/common/datagen/datagen.rs +++ b/common/datagen/datagen.rs @@ -19,7 +19,7 @@ pub mod gen { pub player_a: Vec, pub player_a_values: Option>, pub player_b: Vec, - pub player_b_values: Option>, + pub player_b_values: Option>, } pub fn random_data( @@ -40,11 +40,15 @@ pub mod gen { player_b.extend_from_slice(&intersection); player_b.shuffle(&mut rng); + let player_b_features = (0..(player_a_size + intersection_size)) + .map(|_| random_u8().to_string()) + .collect::>(); + Data { player_a, player_b, player_a_values: None, - player_b_values: None, + player_b_values: Some(player_b_features), } } @@ -76,6 +80,12 @@ pub mod gen { s } + fn random_u8() -> u8 { + let mut r = thread_rng(); + let s: u8 = r.gen(); + s + } + pub fn write_slice_to_file(source: &[String], cols: usize, path: &str) -> Result<(), String> { use indicatif::ProgressBar; @@ -132,6 +142,14 @@ fn main() { .takes_value(true) .default_value("0"), ) + .arg( + Arg::with_name("features") + .short("f") + .long("features") + .value_name("FEATURES") + .help("number of features") + .takes_value(false), + ) .get_matches(); let size = matches @@ -145,16 +163,23 @@ fn main() { .unwrap() .parse::() .expect("size param"); + + let gen_features = matches.is_present("features"); let dir = matches.value_of("dir").unwrap_or("./"); let fn_a = format!("{}/input_{}_size_{}_cols_{}.csv", dir, "a", size, cols); let fn_b = format!("{}/input_{}_size_{}_cols_{}.csv", dir, "b", size, cols); + let fn_b_features = format!( + "{}/input_{}_size_{}_cols_{}_features.csv", + dir, "b", size, cols + ); info!("Generating output of size {}", size); info!("Player a output: {}", fn_a); info!("Player b output: {}", fn_b); + info!("Player b features: {}", fn_b_features); - let intrsct = size / 2 as usize; + let intrsct = size / 2_usize; let size_player = size - intrsct; let data = gen::random_data(size_player, size_player, intrsct); info!("Data generation done, writing to files"); @@ -164,6 +189,11 @@ fn main() { gen::write_slice_to_file(&data.player_b, cols, &fn_b).unwrap(); info!("File {} finished", fn_b); + if gen_features { + gen::write_slice_to_file(&data.player_b_values.unwrap(), 0, &fn_b_features).unwrap(); + info!("File {} finished", fn_b_features); + } + info!("Bye!"); } diff --git a/common/src/files.rs b/common/src/files.rs index 1227d8c..8709fae 100644 --- a/common/src/files.rs +++ b/common/src/files.rs @@ -88,6 +88,32 @@ where .collect::>>() } +/// Reads CSV file into vector of rows, +/// where each row is represented as a vector of u64 +/// All zero length fields are removed +pub fn read_csv_as_u64(filename: T) -> Vec> +where + T: AsRef, +{ + let mut reader = csv::ReaderBuilder::new() + .delimiter(b',') + .flexible(false) + .has_headers(false) + .from_path(filename) + .expect("Failure reading CSV file"); + + let it = reader.records(); + it.map(|x| { + x.unwrap() + .iter() + .map(|z| { + u64::from_str(z.trim()).unwrap_or_else(|_| panic!("Cannot format {} as u64", z)) + }) + .collect::>() + }) + .collect::>>() +} + /// Reads CSV file into vector of rows, /// where each row is a first as key, and then as interger-like values pub fn read_csv_as_keyed_nums(filename: T, has_headers: bool) -> Vec> diff --git a/crypto/Cargo.toml b/crypto/Cargo.toml index 93f35f6..a7d5d58 100644 --- a/crypto/Cargo.toml +++ b/crypto/Cargo.toml @@ -19,7 +19,7 @@ rand = "0.8" rand_core = "0.5.1" curve25519-dalek = "3.2" Cupcake = { git = "https://github.com/facebookresearch/Cupcake"} -rayon = "1.3.0" +rayon = "1.8.0" serde = {version = "1.0.104", features = ["derive"] } bincode = "1.2.1" num-bigint = { version = "0.4", features = ["rand"] } diff --git a/crypto/src/prelude.rs b/crypto/src/prelude.rs index a46c831..2d017a6 100644 --- a/crypto/src/prelude.rs +++ b/crypto/src/prelude.rs @@ -8,6 +8,7 @@ pub use curve25519_dalek::ristretto::RistrettoPoint; pub use curve25519_dalek::scalar; pub use curve25519_dalek::scalar::Scalar; pub use curve25519_dalek::traits::Identity; +pub use curve25519_dalek::traits::IsIdentity; pub use crate::spoint::ByteBuffer; diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..14ff892 --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,231 @@ +version: '3.0' + +services: + +# Datagen + + datagen: + container_name: 'datagen' + profiles: ['private-id', 'dpmc', 'dspmc'] + build: + context: . + entrypoint: + - '/opt/private-id/bin/datagen' + command: '--size ${ENV_VARIABLE_FOR_SIZE:-10} --cols 1 --features -d /etc/example/' + volumes: + - './common/datagen:/etc/example/' + +# Private-ID + + private-id-server: + container_name: 'private-id-server' + profiles: ['private-id'] + depends_on: + datagen: + condition: service_completed_successfully + build: + context: . + entrypoint: '/opt/private-id/bin/private-id-server' + command: >- + --host 0.0.0.0:10009 + --input /etc/example/private-id/company.csv + --stdout + --no-tls + environment: + - 'RUST_LOG=info' + volumes: + - './common/datagen/input_a_size_${ENV_VARIABLE_FOR_SIZE:-10}_cols_1.csv:/etc/example/private-id/company.csv' + + private-id-client: + container_name: 'private-id-client' + profiles: ['private-id'] + depends_on: + datagen: + condition: service_completed_successfully + private-id-server: + condition: service_started + build: + context: . + entrypoint: '/opt/private-id/bin/private-id-client' + command: >- + --company company-host:10009 + --input /etc/example/private-id/partner.csv + --stdout + --no-tls + environment: + - 'RUST_LOG=info' + links: + - 'private-id-server:company-host' + volumes: + - './common/datagen/input_b_size_${ENV_VARIABLE_FOR_SIZE:-10}_cols_1.csv:/etc/example/private-id/partner.csv' + +# DPMC + + dpmc-company-server: + container_name: 'dpmc-company-server' + profiles: ['dpmc'] + depends_on: + datagen: + condition: service_completed_successfully + build: + context: . + entrypoint: '/opt/private-id/bin/dpmc-company-server' + command: >- + --host 0.0.0.0:10010 + --input /etc/example/dpmc/company.csv + --stdout + --output-shares-path /etc/example/dpmc/output_company + --no-tls + environment: + - 'RUST_LOG=info' + volumes: + - './common/datagen/input_a_size_${ENV_VARIABLE_FOR_SIZE:-10}_cols_1.csv:/etc/example/dpmc/company.csv' + + dpmc-partner-server: + container_name: 'dpmc-partner-server' + profiles: ['dpmc'] + depends_on: + datagen: + condition: service_completed_successfully + dpmc-company-server: + condition: service_started + build: + context: . + entrypoint: '/opt/private-id/bin/dpmc-partner-server' + command: >- + --host 0.0.0.0:10020 + --company company-host:10010 + --input-keys /etc/example/dpmc/partner_1.csv + --input-features /etc/example/dpmc/partner_1_features.csv + --no-tls + environment: + - 'RUST_LOG=info' + links: + - 'dpmc-company-server:company-host' + volumes: + - './common/datagen/input_b_size_${ENV_VARIABLE_FOR_SIZE:-10}_cols_1.csv:/etc/example/dpmc/partner_1.csv' + - './common/datagen/input_b_size_${ENV_VARIABLE_FOR_SIZE:-10}_cols_1_features.csv:/etc/example/dpmc/partner_1_features.csv' + + dpmc-helper: + container_name: 'dpmc-helper' + profiles: ['dpmc'] + depends_on: + datagen: + condition: service_completed_successfully + dpmc-company-server: + condition: service_started + dpmc-partner-server: + condition: service_started + build: + context: . + entrypoint: '/opt/private-id/bin/dpmc-helper' + command: >- + --company company-host:10010 + --partners partner-host:10020 + --stdout --output-shares-path /etc/example/dpmc/output_partner + --no-tls + environment: + - 'RUST_LOG=info' + links: + - 'dpmc-company-server:company-host' + - 'dpmc-partner-server:partner-host' + volumes: + - './etc/example/dpmc/:/etc/example/dpmc/' + +# DsPMC + + dspmc-helper-server: + container_name: 'dspmc-helper-server' + profiles: ['dspmc'] + depends_on: + datagen: + condition: service_completed_successfully + build: + context: . + entrypoint: '/opt/private-id/bin/dspmc-helper-server' + command: >- + --host 0.0.0.0:10030 + --stdout + --output-shares-path /etc/example/dspmc/output_helper + --no-tls + environment: + - 'RUST_LOG=info' + volumes: + - './etc/example/dspmc/:/etc/example/dspmc/' + + dspmc-company-server: + container_name: 'dspmc-company-server' + profiles: ['dspmc'] + depends_on: + datagen: + condition: service_completed_successfully + dspmc-helper-server: + condition: service_started + build: + context: . + entrypoint: '/opt/private-id/bin/dspmc-company-server' + command: >- + --host 0.0.0.0:10010 + --helper helper-host:10030 + --input /etc/example/dspmc/company.csv + --stdout + --output-shares-path /etc/example/dspmc/output_company --no-tls + environment: + - 'RUST_LOG=info' + links: + - 'dspmc-helper-server:helper-host' + volumes: + - './common/datagen/input_a_size_${ENV_VARIABLE_FOR_SIZE:-10}_cols_1.csv:/etc/example/dspmc/company.csv' + + dspmc-partner-server: + container_name: 'dspmc-partner-server' + profiles: ['dspmc'] + depends_on: + datagen: + condition: service_completed_successfully + dspmc-company-server: + condition: service_started + build: + context: . + entrypoint: '/opt/private-id/bin/dspmc-partner-server' + command: >- + --host 0.0.0.0:10020 + --company company-host:10010 + --input-keys /etc/example/dspmc/partner_1.csv + --input-features /etc/example/dspmc/partner_1_features.csv + --no-tls + environment: + - 'RUST_LOG=info' + links: + - 'dspmc-company-server:company-host' + volumes: + - './common/datagen/input_b_size_${ENV_VARIABLE_FOR_SIZE:-10}_cols_1.csv:/etc/example/dspmc/partner_1.csv' + - './common/datagen/input_b_size_${ENV_VARIABLE_FOR_SIZE:-10}_cols_1_features.csv:/etc/example/dspmc/partner_1_features.csv' + + dspmc-shuffler: + container_name: 'dspmc-shuffler' + profiles: ['dspmc'] + depends_on: + datagen: + condition: service_completed_successfully + dspmc-company-server: + condition: service_started + dspmc-helper-server: + condition: service_started + dspmc-partner-server: + condition: service_started + build: + context: . + entrypoint: '/opt/private-id/bin/dspmc-shuffler' + command: >- + --company company-host:10010 + --helper helper-host:10030 + --partners partner-host:10020 + --stdout + --no-tls + environment: + - 'RUST_LOG=info' + links: + - 'dspmc-helper-server:helper-host' + - 'dspmc-company-server:company-host' + - 'dspmc-partner-server:partner-host' diff --git a/etc/example/dpmc/Ex0_company.csv b/etc/example/dpmc/Ex0_company.csv new file mode 100644 index 0000000..4fd374e --- /dev/null +++ b/etc/example/dpmc/Ex0_company.csv @@ -0,0 +1,4 @@ +email1 +email2 +email3 +email4 diff --git a/etc/example/dpmc/Ex0_partner_1.csv b/etc/example/dpmc/Ex0_partner_1.csv new file mode 100644 index 0000000..ce2fbfd --- /dev/null +++ b/etc/example/dpmc/Ex0_partner_1.csv @@ -0,0 +1,2 @@ +email1 +email7 diff --git a/etc/example/dpmc/Ex0_partner_1_features.csv b/etc/example/dpmc/Ex0_partner_1_features.csv new file mode 100644 index 0000000..af1e29e --- /dev/null +++ b/etc/example/dpmc/Ex0_partner_1_features.csv @@ -0,0 +1,2 @@ +10, 0 +50, 50 diff --git a/etc/example/dpmc/Ex0_partner_2.csv b/etc/example/dpmc/Ex0_partner_2.csv new file mode 100644 index 0000000..9c5a308 --- /dev/null +++ b/etc/example/dpmc/Ex0_partner_2.csv @@ -0,0 +1,2 @@ +email1 +email4 diff --git a/etc/example/dpmc/Ex0_partner_2_features.csv b/etc/example/dpmc/Ex0_partner_2_features.csv new file mode 100644 index 0000000..a32a3b5 --- /dev/null +++ b/etc/example/dpmc/Ex0_partner_2_features.csv @@ -0,0 +1,2 @@ +20, 21 +30, 31 diff --git a/etc/example/dpmc/Ex1_company.csv b/etc/example/dpmc/Ex1_company.csv new file mode 100644 index 0000000..d4cd8d9 --- /dev/null +++ b/etc/example/dpmc/Ex1_company.csv @@ -0,0 +1,3 @@ +email1,phone1 +phone2, +email3, diff --git a/etc/example/dpmc/Ex1_partner.csv b/etc/example/dpmc/Ex1_partner.csv new file mode 100644 index 0000000..aca4e62 --- /dev/null +++ b/etc/example/dpmc/Ex1_partner.csv @@ -0,0 +1,3 @@ +email1,phone2 +phone1, +email3, diff --git a/etc/example/dpmc/Ex1_partner_features.csv b/etc/example/dpmc/Ex1_partner_features.csv new file mode 100644 index 0000000..e26a103 --- /dev/null +++ b/etc/example/dpmc/Ex1_partner_features.csv @@ -0,0 +1,3 @@ +0, 4, 3 +1, 2, 3 +3, 2, 0 diff --git a/etc/example/dpmc/Ex2_company.csv b/etc/example/dpmc/Ex2_company.csv new file mode 100644 index 0000000..d4cd8d9 --- /dev/null +++ b/etc/example/dpmc/Ex2_company.csv @@ -0,0 +1,3 @@ +email1,phone1 +phone2, +email3, diff --git a/etc/example/dpmc/Ex2_partner.csv b/etc/example/dpmc/Ex2_partner.csv new file mode 100644 index 0000000..8753ea0 --- /dev/null +++ b/etc/example/dpmc/Ex2_partner.csv @@ -0,0 +1,3 @@ +phone1,phone2 +email1, +email3, diff --git a/etc/example/dpmc/Ex2_partner_features.csv b/etc/example/dpmc/Ex2_partner_features.csv new file mode 100644 index 0000000..7633f14 --- /dev/null +++ b/etc/example/dpmc/Ex2_partner_features.csv @@ -0,0 +1,3 @@ +2, 0 +10, 9 +5, 10 diff --git a/etc/example/dpmc/Ex3_company.csv b/etc/example/dpmc/Ex3_company.csv new file mode 100644 index 0000000..3c0ea06 --- /dev/null +++ b/etc/example/dpmc/Ex3_company.csv @@ -0,0 +1,3 @@ +email1 +phone2 +email3 diff --git a/etc/example/dpmc/Ex3_partner.csv b/etc/example/dpmc/Ex3_partner.csv new file mode 100644 index 0000000..cdc0d04 --- /dev/null +++ b/etc/example/dpmc/Ex3_partner.csv @@ -0,0 +1,3 @@ +phone2 +phone1 +email3 diff --git a/etc/example/dpmc/Ex3_partner_features.csv b/etc/example/dpmc/Ex3_partner_features.csv new file mode 100644 index 0000000..3fe01bc --- /dev/null +++ b/etc/example/dpmc/Ex3_partner_features.csv @@ -0,0 +1,3 @@ +8, 5 +5, 3 +7, 1 diff --git a/etc/example/dpmc/Ex4_company.csv b/etc/example/dpmc/Ex4_company.csv new file mode 100644 index 0000000..01bf2dd --- /dev/null +++ b/etc/example/dpmc/Ex4_company.csv @@ -0,0 +1,3 @@ +email1,phone1,zip1,fnln1 +email2,gahsyas,adhsauishjd,xybsgh +email3,phone3,zip3,fnln3 diff --git a/etc/example/dpmc/Ex4_partner.csv b/etc/example/dpmc/Ex4_partner.csv new file mode 100644 index 0000000..1628815 --- /dev/null +++ b/etc/example/dpmc/Ex4_partner.csv @@ -0,0 +1,3 @@ +fnln1,fnln3 +phone1, +zip3, diff --git a/etc/example/dpmc/Ex4_partner_features.csv b/etc/example/dpmc/Ex4_partner_features.csv new file mode 100644 index 0000000..f26e125 --- /dev/null +++ b/etc/example/dpmc/Ex4_partner_features.csv @@ -0,0 +1,3 @@ +5 +0 +1 diff --git a/etc/example/dpmc/Ex5_company.csv b/etc/example/dpmc/Ex5_company.csv new file mode 100644 index 0000000..01bf2dd --- /dev/null +++ b/etc/example/dpmc/Ex5_company.csv @@ -0,0 +1,3 @@ +email1,phone1,zip1,fnln1 +email2,gahsyas,adhsauishjd,xybsgh +email3,phone3,zip3,fnln3 diff --git a/etc/example/dpmc/Ex5_partner.csv b/etc/example/dpmc/Ex5_partner.csv new file mode 100644 index 0000000..ba834d6 --- /dev/null +++ b/etc/example/dpmc/Ex5_partner.csv @@ -0,0 +1,3 @@ +email1 +email2 +email3 diff --git a/etc/example/dpmc/Ex5_partner_features.csv b/etc/example/dpmc/Ex5_partner_features.csv new file mode 100644 index 0000000..4ee9268 --- /dev/null +++ b/etc/example/dpmc/Ex5_partner_features.csv @@ -0,0 +1,3 @@ +2, 3, 0, 1 +0, 4, 3, 5 +10, 21, 5, 9 diff --git a/etc/example/dpmc/Ex6_company.csv b/etc/example/dpmc/Ex6_company.csv new file mode 100644 index 0000000..3fcf012 --- /dev/null +++ b/etc/example/dpmc/Ex6_company.csv @@ -0,0 +1,3 @@ +email1,phone1, +email2,phone2,zip2 +email3,phone3,zip3 diff --git a/etc/example/dpmc/Ex6_partner.csv b/etc/example/dpmc/Ex6_partner.csv new file mode 100644 index 0000000..e87690b --- /dev/null +++ b/etc/example/dpmc/Ex6_partner.csv @@ -0,0 +1,3 @@ +email1,fnln1 +zip2, +phone3, diff --git a/etc/example/dpmc/Ex6_partner_features.csv b/etc/example/dpmc/Ex6_partner_features.csv new file mode 100644 index 0000000..15e01ff --- /dev/null +++ b/etc/example/dpmc/Ex6_partner_features.csv @@ -0,0 +1,3 @@ +0 +21 +9 diff --git a/etc/example/dspmc/Ex0_company.csv b/etc/example/dspmc/Ex0_company.csv new file mode 100644 index 0000000..4fd374e --- /dev/null +++ b/etc/example/dspmc/Ex0_company.csv @@ -0,0 +1,4 @@ +email1 +email2 +email3 +email4 diff --git a/etc/example/dspmc/Ex0_partner_1.csv b/etc/example/dspmc/Ex0_partner_1.csv new file mode 100644 index 0000000..6177a12 --- /dev/null +++ b/etc/example/dspmc/Ex0_partner_1.csv @@ -0,0 +1 @@ +email3 diff --git a/etc/example/dspmc/Ex0_partner_1_features.csv b/etc/example/dspmc/Ex0_partner_1_features.csv new file mode 100644 index 0000000..9e1fd09 --- /dev/null +++ b/etc/example/dspmc/Ex0_partner_1_features.csv @@ -0,0 +1 @@ +10, 11, 12 diff --git a/etc/example/dspmc/Ex0_partner_2.csv b/etc/example/dspmc/Ex0_partner_2.csv new file mode 100644 index 0000000..94ef675 --- /dev/null +++ b/etc/example/dspmc/Ex0_partner_2.csv @@ -0,0 +1,2 @@ +email1 +email2 diff --git a/etc/example/dspmc/Ex0_partner_2_features.csv b/etc/example/dspmc/Ex0_partner_2_features.csv new file mode 100644 index 0000000..fa021b7 --- /dev/null +++ b/etc/example/dspmc/Ex0_partner_2_features.csv @@ -0,0 +1,2 @@ +100, 101, 102 +300, 301, 302 diff --git a/etc/example/dspmc/Ex1_company.csv b/etc/example/dspmc/Ex1_company.csv new file mode 100644 index 0000000..d4cd8d9 --- /dev/null +++ b/etc/example/dspmc/Ex1_company.csv @@ -0,0 +1,3 @@ +email1,phone1 +phone2, +email3, diff --git a/etc/example/dspmc/Ex1_partner.csv b/etc/example/dspmc/Ex1_partner.csv new file mode 100644 index 0000000..aca4e62 --- /dev/null +++ b/etc/example/dspmc/Ex1_partner.csv @@ -0,0 +1,3 @@ +email1,phone2 +phone1, +email3, diff --git a/etc/example/dspmc/Ex1_partner_features.csv b/etc/example/dspmc/Ex1_partner_features.csv new file mode 100644 index 0000000..e26a103 --- /dev/null +++ b/etc/example/dspmc/Ex1_partner_features.csv @@ -0,0 +1,3 @@ +0, 4, 3 +1, 2, 3 +3, 2, 0 diff --git a/etc/example/dspmc/Ex2_company.csv b/etc/example/dspmc/Ex2_company.csv new file mode 100644 index 0000000..d4cd8d9 --- /dev/null +++ b/etc/example/dspmc/Ex2_company.csv @@ -0,0 +1,3 @@ +email1,phone1 +phone2, +email3, diff --git a/etc/example/dspmc/Ex2_partner.csv b/etc/example/dspmc/Ex2_partner.csv new file mode 100644 index 0000000..8753ea0 --- /dev/null +++ b/etc/example/dspmc/Ex2_partner.csv @@ -0,0 +1,3 @@ +phone1,phone2 +email1, +email3, diff --git a/etc/example/dspmc/Ex2_partner_features.csv b/etc/example/dspmc/Ex2_partner_features.csv new file mode 100644 index 0000000..7633f14 --- /dev/null +++ b/etc/example/dspmc/Ex2_partner_features.csv @@ -0,0 +1,3 @@ +2, 0 +10, 9 +5, 10 diff --git a/etc/example/dspmc/Ex3_company.csv b/etc/example/dspmc/Ex3_company.csv new file mode 100644 index 0000000..3c0ea06 --- /dev/null +++ b/etc/example/dspmc/Ex3_company.csv @@ -0,0 +1,3 @@ +email1 +phone2 +email3 diff --git a/etc/example/dspmc/Ex3_partner.csv b/etc/example/dspmc/Ex3_partner.csv new file mode 100644 index 0000000..cdc0d04 --- /dev/null +++ b/etc/example/dspmc/Ex3_partner.csv @@ -0,0 +1,3 @@ +phone2 +phone1 +email3 diff --git a/etc/example/dspmc/Ex3_partner_features.csv b/etc/example/dspmc/Ex3_partner_features.csv new file mode 100644 index 0000000..3fe01bc --- /dev/null +++ b/etc/example/dspmc/Ex3_partner_features.csv @@ -0,0 +1,3 @@ +8, 5 +5, 3 +7, 1 diff --git a/etc/example/dspmc/Ex4_company.csv b/etc/example/dspmc/Ex4_company.csv new file mode 100644 index 0000000..01bf2dd --- /dev/null +++ b/etc/example/dspmc/Ex4_company.csv @@ -0,0 +1,3 @@ +email1,phone1,zip1,fnln1 +email2,gahsyas,adhsauishjd,xybsgh +email3,phone3,zip3,fnln3 diff --git a/etc/example/dspmc/Ex4_partner.csv b/etc/example/dspmc/Ex4_partner.csv new file mode 100644 index 0000000..1628815 --- /dev/null +++ b/etc/example/dspmc/Ex4_partner.csv @@ -0,0 +1,3 @@ +fnln1,fnln3 +phone1, +zip3, diff --git a/etc/example/dspmc/Ex4_partner_features.csv b/etc/example/dspmc/Ex4_partner_features.csv new file mode 100644 index 0000000..f26e125 --- /dev/null +++ b/etc/example/dspmc/Ex4_partner_features.csv @@ -0,0 +1,3 @@ +5 +0 +1 diff --git a/etc/example/dspmc/Ex5_company.csv b/etc/example/dspmc/Ex5_company.csv new file mode 100644 index 0000000..01bf2dd --- /dev/null +++ b/etc/example/dspmc/Ex5_company.csv @@ -0,0 +1,3 @@ +email1,phone1,zip1,fnln1 +email2,gahsyas,adhsauishjd,xybsgh +email3,phone3,zip3,fnln3 diff --git a/etc/example/dspmc/Ex5_partner.csv b/etc/example/dspmc/Ex5_partner.csv new file mode 100644 index 0000000..ba834d6 --- /dev/null +++ b/etc/example/dspmc/Ex5_partner.csv @@ -0,0 +1,3 @@ +email1 +email2 +email3 diff --git a/etc/example/dspmc/Ex5_partner_features.csv b/etc/example/dspmc/Ex5_partner_features.csv new file mode 100644 index 0000000..4ee9268 --- /dev/null +++ b/etc/example/dspmc/Ex5_partner_features.csv @@ -0,0 +1,3 @@ +2, 3, 0, 1 +0, 4, 3, 5 +10, 21, 5, 9 diff --git a/etc/example/dspmc/Ex6_company.csv b/etc/example/dspmc/Ex6_company.csv new file mode 100644 index 0000000..3fcf012 --- /dev/null +++ b/etc/example/dspmc/Ex6_company.csv @@ -0,0 +1,3 @@ +email1,phone1, +email2,phone2,zip2 +email3,phone3,zip3 diff --git a/etc/example/dspmc/Ex6_partner.csv b/etc/example/dspmc/Ex6_partner.csv new file mode 100644 index 0000000..e87690b --- /dev/null +++ b/etc/example/dspmc/Ex6_partner.csv @@ -0,0 +1,3 @@ +email1,fnln1 +zip2, +phone3, diff --git a/etc/example/dspmc/Ex6_partner_features.csv b/etc/example/dspmc/Ex6_partner_features.csv new file mode 100644 index 0000000..15e01ff --- /dev/null +++ b/etc/example/dspmc/Ex6_partner_features.csv @@ -0,0 +1,3 @@ +0 +21 +9 diff --git a/protocol-rpc/Cargo.toml b/protocol-rpc/Cargo.toml index 489623e..248d8b3 100644 --- a/protocol-rpc/Cargo.toml +++ b/protocol-rpc/Cargo.toml @@ -57,6 +57,34 @@ path = "src/rpc/suid-create/server.rs" name = "suid-create-client" path = "src/rpc/suid-create/client.rs" +[[bin]] +name = "dpmc-company-server" +path = "src/rpc/dpmc/company-server.rs" + +[[bin]] +name = "dpmc-partner-server" +path = "src/rpc/dpmc/partner-server.rs" + +[[bin]] +name = "dpmc-helper" +path = "src/rpc/dpmc/client.rs" + +[[bin]] +name = "dspmc-company-server" +path = "src/rpc/dspmc/company-server.rs" + +[[bin]] +name = "dspmc-helper-server" +path = "src/rpc/dspmc/helper-server.rs" + +[[bin]] +name = "dspmc-partner-server" +path = "src/rpc/dspmc/partner-server.rs" + +[[bin]] +name = "dspmc-shuffler" +path = "src/rpc/dspmc/client.rs" + [lib] name = "rpc" path = "src/lib.rs" @@ -80,12 +108,12 @@ futures = { version = "0.3", features = ["thread-pool", "alloc"]} http = "0.2" url = "2.1.0" async-stream = "0.2" -rayon = "1.3.0" +rayon = "1.8.0" bytes = "0.4" clap = "2.33.4" csv = "1.1.1" indicatif = "0.13.0" -ctrlc = "3.1.3" +ctrlc = "3.2.3" retry = "0.5.1" bincode = "1.2.1" itertools = "0.9.0" diff --git a/protocol-rpc/build.rs b/protocol-rpc/build.rs index e938490..1c39699 100644 --- a/protocol-rpc/build.rs +++ b/protocol-rpc/build.rs @@ -11,6 +11,11 @@ fn main() -> Result<(), Box> { "crosspsixor.proto", "pjc.proto", "suidcreate.proto", + "dpmccompany.proto", + "dpmcpartner.proto", + "dspmccompany.proto", + "dspmchelper.proto", + "dspmcpartner.proto", ]; let out_env = if cfg!(fbcode_build) { "OUT" } else { "OUT_DIR" }; let out_dir = std::env::var_os(out_env).unwrap_or_else(|| panic!("env `{out_env}` is not set")); diff --git a/protocol-rpc/proto/dpmccompany.proto b/protocol-rpc/proto/dpmccompany.proto new file mode 100644 index 0000000..6b8a413 --- /dev/null +++ b/protocol-rpc/proto/dpmccompany.proto @@ -0,0 +1,54 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; +package dpmccompany; + +import "common.proto"; + +message Init {} +message InitAck {} +message UCompanyAck {} +message ECompanyAck {} +message UPartnerAck {} +message VPartnerAck {} +message SPartnerAck {} +message SPrimePartnerAck {} +message PartnerPublicKeyAck {} +message CalculateIdMapAck {} +message CalculateFeaturesXorSharesAck {} +message Commitment {} +message CommitmentAck {} + +message ServiceResponse { + oneof Ack { + Init init = 1; + InitAck init_ack = 2; + UCompanyAck u_company_ack = 3; + ECompanyAck e_company_ack = 5; + UPartnerAck u_partner_ack = 6; + VPartnerAck v_partner_ack = 7; + SPartnerAck s_partner_ack = 8; + SPrimePartnerAck s_prime_partner_ack = 9; + PartnerPublicKeyAck partner_public_key_ack = 10; + CalculateIdMapAck calculate_id_map_ack = 12; + CalculateFeaturesXorSharesAck calculate_features_xor_shares_ack = 13; + Commitment commitment = 14; + } +} + +service DpmcCompany { + rpc Initialize(Init) returns (ServiceResponse) {} + rpc RecvUCompany(ServiceResponse) returns (stream common.Payload) {} + + rpc SendUPartner(stream common.Payload) returns (ServiceResponse) {} + + rpc CalculateIdMap(Commitment) returns (CommitmentAck) {} + rpc CalculateFeaturesXorShares(stream common.Payload) + returns (ServiceResponse) {} + + rpc RecvCompanyPublicKey(ServiceResponse) returns (stream common.Payload) {} + rpc RecvVPartner(ServiceResponse) returns (stream common.Payload) {} + + rpc Reveal(Commitment) returns (CommitmentAck) {} +} diff --git a/protocol-rpc/proto/dpmcpartner.proto b/protocol-rpc/proto/dpmcpartner.proto new file mode 100644 index 0000000..159e551 --- /dev/null +++ b/protocol-rpc/proto/dpmcpartner.proto @@ -0,0 +1,43 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; +package dpmcpartner; + +import "common.proto"; + +message Init {} +message InitAck {} +message SendData {} +message SendDataAck {} +message CompanyPublicKeyAck {} +message IdMapIndicesAck {} +message Commitment {} +message CommitmentAck {} +message HelperPublicKeyAck {} + +message ServiceResponse { + oneof Ack { + Init init = 1; + InitAck init_ack = 2; + CompanyPublicKeyAck company_public_key_ack = 3; + IdMapIndicesAck id_map_indices_ack = 4; + Commitment commitment = 5; + SendData send_data = 6; + SendDataAck send_data_ack = 7; + HelperPublicKeyAck helper_public_key_ack = 8; + } +} + +service DpmcPartner { + rpc Initialize(Init) returns (ServiceResponse) {} + rpc SendDataToCompany(SendData) returns (ServiceResponse) {} + + rpc RecvPartnerPublicKey(ServiceResponse) returns (stream common.Payload) {} + + rpc SendCompanyPublicKey(stream common.Payload) returns (ServiceResponse) {} + + rpc SendHelperPublicKey(stream common.Payload) returns (ServiceResponse) {} + + rpc StopService(Commitment) returns (CommitmentAck) {} +} diff --git a/protocol-rpc/proto/dspmccompany.proto b/protocol-rpc/proto/dspmccompany.proto new file mode 100644 index 0000000..e92d7a1 --- /dev/null +++ b/protocol-rpc/proto/dspmccompany.proto @@ -0,0 +1,68 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; +package dspmccompany; + +import "common.proto"; + +message Init {} +message InitAck {} +message SendData {} +message SendDataAck {} +message RecvShares {} +message RecvSharesAck {} +message UCompanyAck {} +message ECompanyAck {} +message UPartnerAck {} +message VPartnerAck {} +message SPartnerAck {} +message SPrimePartnerAck {} +message PartnerPublicKeyAck {} +message HelperPublicKeyAck {} +message CalculateIdMapAck {} +message CalculateFeaturesXorSharesAck {} +message Commitment {} +message CommitmentAck {} + +message ServiceResponse { + oneof Ack { + Init init = 1; + InitAck init_ack = 2; + UCompanyAck u_company_ack = 3; + ECompanyAck e_company_ack = 5; + UPartnerAck u_partner_ack = 6; + VPartnerAck v_partner_ack = 7; + SPartnerAck s_partner_ack = 8; + SPrimePartnerAck s_prime_partner_ack = 9; + PartnerPublicKeyAck partner_public_key_ack = 10; + CalculateIdMapAck calculate_id_map_ack = 12; + CalculateFeaturesXorSharesAck calculate_features_xor_shares_ack = 13; + Commitment commitment = 14; + SendData send_data = 15; + SendDataAck send_data_ack = 16; + RecvShares recv_shares = 17; + RecvSharesAck recv_shares_ack = 18; + HelperPublicKeyAck helper_public_key_ack = 19; + } +} + +service DspmcCompany { + rpc Initialize(Init) returns (ServiceResponse) {} + rpc SendCt3PCdVCdToHelper(SendData) returns (ServiceResponse) {} + rpc SendU1ToHelper(SendData) returns (ServiceResponse) {} + rpc SendEncryptedKeysToHelper(SendData) returns (ServiceResponse) {} + + rpc SendHelperPublicKey(stream common.Payload) returns (ServiceResponse) {} + rpc SendPScVScCt1ct2dprime(stream common.Payload) returns (ServiceResponse) {} + rpc SendUPartner(stream common.Payload) returns (ServiceResponse) {} + + rpc RecvCompanyPublicKey(ServiceResponse) returns (stream common.Payload) {} + rpc RecvSharesFromHelper(RecvShares) returns (ServiceResponse) {} + rpc RecvPCsVCs(ServiceResponse) returns (stream common.Payload) {} + rpc RecvUCompany(ServiceResponse) returns (stream common.Payload) {} + + rpc CalculateIdMap(Commitment) returns (CommitmentAck) {} + + // rpc Reveal(Commitment) returns (CommitmentAck) {} +} diff --git a/protocol-rpc/proto/dspmchelper.proto b/protocol-rpc/proto/dspmchelper.proto new file mode 100644 index 0000000..ab7194f --- /dev/null +++ b/protocol-rpc/proto/dspmchelper.proto @@ -0,0 +1,61 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; +package dspmchelper; + +import "common.proto"; + +message Init {} +message InitAck {} +message SendData {} +message SendDataAck {} +message UHelperAck {} +message EHelperAck {} +message UPartnerAck {} +message VPartnerAck {} +message SPartnerAck {} +message SPrimePartnerAck {} +message CompanyPublicKeyAck {} +message PartnerPublicKeyAck {} +message CalculateIdMapAck {} +message CalculateFeaturesXorSharesAck {} +message Commitment {} +message CommitmentAck {} + +message ServiceResponse { + oneof Ack { + Init init = 1; + InitAck init_ack = 2; + UHelperAck u_helper_ack = 3; + EHelperAck e_helper_ack = 5; + UPartnerAck u_partner_ack = 6; + VPartnerAck v_partner_ack = 7; + SPartnerAck s_partner_ack = 8; + SPrimePartnerAck s_prime_partner_ack = 9; + PartnerPublicKeyAck partner_public_key_ack = 10; + CalculateIdMapAck calculate_id_map_ack = 12; + CalculateFeaturesXorSharesAck calculate_features_xor_shares_ack = 13; + Commitment commitment = 14; + CompanyPublicKeyAck company_public_key_ack = 15; + SendData send_data = 16; + SendDataAck send_data_ack = 17; + } +} + +service DspmcHelper { + rpc SendCompanyPublicKey(stream common.Payload) returns (ServiceResponse) {} + rpc SendEncryptedVprime(stream common.Payload) returns (ServiceResponse) {} + rpc SendEncryptedKeys(stream common.Payload) returns (ServiceResponse) {} + rpc SendCt3PCdVCd(stream common.Payload) returns (ServiceResponse) {} + rpc SendU1(stream common.Payload) returns (ServiceResponse) {} + rpc SendPSdVSd(stream common.Payload) returns (ServiceResponse) {} + + rpc RecvHelperPublicKey(ServiceResponse) returns (stream common.Payload) {} + rpc RecvXorShares(ServiceResponse) returns (stream common.Payload) {} + rpc RecvU2(ServiceResponse) returns (stream common.Payload) {} + + rpc CalculateIdMap(Commitment) returns (CommitmentAck) {} + rpc Reveal(Commitment) returns (CommitmentAck) {} + rpc StopService(Commitment) returns (CommitmentAck) {} +} diff --git a/protocol-rpc/proto/dspmcpartner.proto b/protocol-rpc/proto/dspmcpartner.proto new file mode 100644 index 0000000..6a23a4d --- /dev/null +++ b/protocol-rpc/proto/dspmcpartner.proto @@ -0,0 +1,42 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; +package dspmcpartner; + +import "common.proto"; + +message Init {} +message InitAck {} +message SendData {} +message SendDataAck {} +message CompanyPublicKeyAck {} +message IdMapIndicesAck {} +message Commitment {} +message CommitmentAck {} +message ShufflerPublicKeyAck {} +message HelperPublicKeyAck {} + +message ServiceResponse { + oneof Ack { + Init init = 1; + InitAck init_ack = 2; + CompanyPublicKeyAck company_public_key_ack = 3; + IdMapIndicesAck id_map_indices_ack = 4; + Commitment commitment = 5; + SendData send_data = 6; + SendDataAck send_data_ack = 7; + ShufflerPublicKeyAck shuffler_public_key_ack = 8; + HelperPublicKeyAck helper_public_key_ack = 9; + } +} + +service DspmcPartner { + rpc Initialize(Init) returns (ServiceResponse) {} + rpc SendDataToCompany(SendData) returns (ServiceResponse) {} + + rpc SendCompanyPublicKey(stream common.Payload) returns (ServiceResponse) {} + rpc SendHelperPublicKey(stream common.Payload) returns (ServiceResponse) {} + + rpc StopService(Commitment) returns (CommitmentAck) {} +} diff --git a/protocol-rpc/src/connect/create_client.rs b/protocol-rpc/src/connect/create_client.rs index 6d2a23b..92b9887 100644 --- a/protocol-rpc/src/connect/create_client.rs +++ b/protocol-rpc/src/connect/create_client.rs @@ -16,6 +16,11 @@ use tonic::transport::Endpoint; use crate::connect::tls; use crate::proto::gen_crosspsi::cross_psi_client::CrossPsiClient; use crate::proto::gen_crosspsi_xor::cross_psi_xor_client::CrossPsiXorClient; +use crate::proto::gen_dpmc_company::dpmc_company_client::DpmcCompanyClient; +use crate::proto::gen_dpmc_partner::dpmc_partner_client::DpmcPartnerClient; +use crate::proto::gen_dspmc_company::dspmc_company_client::DspmcCompanyClient; +use crate::proto::gen_dspmc_helper::dspmc_helper_client::DspmcHelperClient; +use crate::proto::gen_dspmc_partner::dspmc_partner_client::DspmcPartnerClient; use crate::proto::gen_pjc::pjc_client::PjcClient; use crate::proto::gen_private_id::private_id_client::PrivateIdClient; use crate::proto::gen_private_id_multi_key::private_id_multi_key_client::PrivateIdMultiKeyClient; @@ -108,11 +113,11 @@ pub fn create_client( }; let has_tls = maybe_tls.is_some(); let running = Arc::new(AtomicBool::new(true)); - let r = running.clone(); - ctrlc::set_handler(move || { - r.store(false, Ordering::SeqCst); - }) - .expect("Error setting Ctrl-C handler"); + let _r = running.clone(); + // ctrlc::set_handler(move || { + // r.store(false, Ordering::SeqCst); + // }) + // .expect("Error setting Ctrl-C handler"); let mut retry_count: u32 = 0; @@ -140,6 +145,11 @@ pub fn create_client( "cross-psi-xor" => RpcClient::CrossPsiXor(CrossPsiXorClient::new(conn)), "pjc" => RpcClient::Pjc(PjcClient::new(conn)), "suid-create" => RpcClient::SuidCreate(SuidCreateClient::new(conn)), + "dpmc-company" => RpcClient::DpmcCompany(DpmcCompanyClient::new(conn)), + "dpmc-partner" => RpcClient::DpmcPartner(DpmcPartnerClient::new(conn)), + "dspmc-company" => RpcClient::DspmcCompany(DspmcCompanyClient::new(conn)), + "dspmc-helper" => RpcClient::DspmcHelper(DspmcHelperClient::new(conn)), + "dspmc-partner" => RpcClient::DspmcPartner(DspmcPartnerClient::new(conn)), _ => panic!("wrong client"), }) } else { @@ -160,6 +170,21 @@ pub fn create_client( "suid-create" => Ok(RpcClient::SuidCreate( SuidCreateClient::connect(__uri).await.unwrap(), )), + "dpmc-company" => Ok(RpcClient::DpmcCompany( + DpmcCompanyClient::connect(__uri).await.unwrap(), + )), + "dpmc-partner" => Ok(RpcClient::DpmcPartner( + DpmcPartnerClient::connect(__uri).await.unwrap(), + )), + "dspmc-company" => Ok(RpcClient::DspmcCompany( + DspmcCompanyClient::connect(__uri).await.unwrap(), + )), + "dspmc-helper" => Ok(RpcClient::DspmcHelper( + DspmcHelperClient::connect(__uri).await.unwrap(), + )), + "dspmc-partner" => Ok(RpcClient::DspmcPartner( + DspmcPartnerClient::connect(__uri).await.unwrap(), + )), _ => panic!("wrong client"), } } diff --git a/protocol-rpc/src/proto/mod.rs b/protocol-rpc/src/proto/mod.rs index 522390d..46d7256 100644 --- a/protocol-rpc/src/proto/mod.rs +++ b/protocol-rpc/src/proto/mod.rs @@ -29,10 +29,35 @@ pub mod gen_suid_create { tonic::include_proto!("suidcreate"); } +pub mod gen_dpmc_company { + tonic::include_proto!("dpmccompany"); +} + +pub mod gen_dpmc_partner { + tonic::include_proto!("dpmcpartner"); +} + +pub mod gen_dspmc_company { + tonic::include_proto!("dspmccompany"); +} + +pub mod gen_dspmc_helper { + tonic::include_proto!("dspmchelper"); +} + +pub mod gen_dspmc_partner { + tonic::include_proto!("dspmcpartner"); +} + pub mod streaming; use gen_crosspsi::cross_psi_client::CrossPsiClient; use gen_crosspsi_xor::cross_psi_xor_client::CrossPsiXorClient; +use gen_dpmc_company::dpmc_company_client::DpmcCompanyClient; +use gen_dpmc_partner::dpmc_partner_client::DpmcPartnerClient; +use gen_dspmc_company::dspmc_company_client::DspmcCompanyClient; +use gen_dspmc_helper::dspmc_helper_client::DspmcHelperClient; +use gen_dspmc_partner::dspmc_partner_client::DspmcPartnerClient; use gen_pjc::pjc_client::PjcClient; use gen_private_id::private_id_client::PrivateIdClient; use gen_private_id_multi_key::private_id_multi_key_client::PrivateIdMultiKeyClient; @@ -45,6 +70,11 @@ pub enum RpcClient { CrossPsiXor(CrossPsiXorClient), Pjc(PjcClient), SuidCreate(SuidCreateClient), + DpmcCompany(DpmcCompanyClient), + DpmcPartner(DpmcPartnerClient), + DspmcCompany(DspmcCompanyClient), + DspmcHelper(DspmcHelperClient), + DspmcPartner(DspmcPartnerClient), } use crypto::prelude::*; diff --git a/protocol-rpc/src/rpc/dpmc/client.rs b/protocol-rpc/src/rpc/dpmc/client.rs new file mode 100644 index 0000000..cdb4931 --- /dev/null +++ b/protocol-rpc/src/rpc/dpmc/client.rs @@ -0,0 +1,449 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +use std::convert::TryInto; + +use clap::App; +use clap::Arg; +use clap::ArgGroup; +use common::timer; +use crypto::prelude::TPayload; +use log::error; +use log::info; +use protocol::dpmc::helper::HelperDpmc; +use protocol::dpmc::traits::*; +use rpc::connect::create_client::create_client; +use rpc::proto::gen_dpmc_company::service_response::Ack as CompanyAck; +use rpc::proto::gen_dpmc_company::Init as CompanyInit; +use rpc::proto::gen_dpmc_company::ServiceResponse as CompanyServiceResponse; +use rpc::proto::gen_dpmc_partner::service_response::Ack as PartnerAck; +use rpc::proto::gen_dpmc_partner::Init as PartnerInit; +use rpc::proto::gen_dpmc_partner::SendData as PartnerSendData; +use rpc::proto::RpcClient; +use tonic::Request; + +mod rpc_client_company; +mod rpc_client_partner; + +#[tokio::main] +async fn main() -> Result<(), Box> { + env_logger::init(); + + // todo: move matches outside, or move to build.rs + let matches = App::new("Delegated Private Id MultiKey Helper") + .version("0.1") + .about("Delegated Private Id Multi Key Protocol") + .args(&[ + Arg::with_name("company") + .long("company") + .short("c") + .takes_value(true) + .required(true) + .help("Company host path to connect to, ex: 0.0.0.0:10009"), + Arg::with_name("partners") + .long("partners") + .short("p") + .takes_value(true) + .required(true) + .help("Partner host path to connect to, ex: 0.0.0.0:10010"), + Arg::with_name("output") + .long("output") + .short("o") + .takes_value(true) + .help("Path to output file for spine, output format: private-id, option(key)"), + Arg::with_name("stdout") + .long("stdout") + .short("u") + .takes_value(false) + .help("Prints the output to stdout rather than file"), + Arg::with_name("output-shares-path") + .long("output-shares-path") + .takes_value(true) + .help( + "path to write shares of features.\n + Feature will be written as {path}_partner_features.csv", + ), + Arg::with_name("one-to-many") + .long("one-to-many") + .takes_value(true) + .required(false) + .help( + "By default, DPMC generates one-to-one matches. Use this\n + flag to generate one(C)-to-many(P) matches.", + ), + Arg::with_name("no-tls") + .long("no-tls") + .takes_value(false) + .help("Turns tls off"), + Arg::with_name("tls-dir") + .long("tls-dir") + .takes_value(true) + .help( + "Path to directory with files with key, cert and ca.pem file\n + client: client.key, client.pem, ca.pem \n + server: server.key, server.pem, ca.pem \n", + ), + Arg::with_name("tls-key") + .long("tls-key") + .takes_value(true) + .requires("tls-cert") + .requires("tls-ca") + .help("Path to tls key (non-encrypted)"), + Arg::with_name("tls-cert") + .long("tls-cert") + .takes_value(true) + .requires("tls-key") + .requires("tls-ca") + .help( + "Path to tls certificate (pem format), SINGLE cert, \ + NO CHAINING, required by client as well", + ), + Arg::with_name("tls-ca") + .long("tls-ca") + .takes_value(true) + .requires("tls-key") + .requires("tls-cert") + .help("Path to root CA certificate issued cert and keys"), + Arg::with_name("tls-domain") + .long("tls-domain") + .takes_value(true) + .help("Override TLS domain for SSL cert (if host is IP)"), + ]) + .groups(&[ + ArgGroup::with_name("tls") + .args(&["no-tls", "tls-dir", "tls-key"]) + .required(true), + ArgGroup::with_name("out") + .args(&["output", "stdout"]) + .required(true), + ]) + .get_matches(); + + let global_timer = timer::Timer::new_silent("global"); + + let no_tls = matches.is_present("no-tls"); + let host_pre = matches.value_of("company"); + let tls_dir = matches.value_of("tls-dir"); + let tls_key = matches.value_of("tls-key"); + let tls_cert = matches.value_of("tls-cert"); + let tls_ca = matches.value_of("tls-ca"); + let tls_domain = matches.value_of("tls-domain"); + let one_to_many = { + match matches.value_of("one-to-many") { + Some(many) => many.parse::().unwrap(), + _ => 1, + } + }; + + let mut company_client_context = { + match create_client( + no_tls, + host_pre, + tls_dir, + tls_key, + tls_cert, + tls_ca, + tls_domain, + "dpmc-company".to_string(), + ) { + RpcClient::DpmcCompany(x) => x, + _ => panic!("wrong client"), + } + }; + + let mut partner_client_context = vec![]; + let partner_host_pre = matches.value_of("partners").unwrap().split(","); + for host_pre_i in partner_host_pre { + let partner_client_context_i = { + match create_client( + no_tls, + Some(host_pre_i), + tls_dir, + tls_key, + tls_cert, + tls_ca, + tls_domain, + "dpmc-partner".to_string(), + ) { + RpcClient::DpmcPartner(x) => x, + _ => panic!("wrong client"), + } + }; + partner_client_context.push(partner_client_context_i); + } + + let output_keys_path = matches.value_of("output"); + let output_shares_path = matches.value_of("output-shares-path"); + + // 1. Create helper protocol instance + let helper_protocol = HelperDpmc::new(); + + // 2. Initialize company - this loads company's data + let company_init_ack = match company_client_context + .initialize(Request::new(CompanyInit {})) + .await? + .into_inner() + .ack + .unwrap() + { + CompanyAck::InitAck(x) => x, + _ => panic!("wrong ack"), + }; + + // 3. Initialize partners - this loads partner's data + let mut partner_init_acks = vec![]; + for i in 0..partner_client_context.len() { + let partner_init_ack = match partner_client_context[i] + .initialize(Request::new(PartnerInit {})) + .await? + .into_inner() + .ack + .unwrap() + { + PartnerAck::InitAck(x) => x, + _ => panic!("wrong ack"), + }; + partner_init_acks.push(partner_init_ack); + } + + // 4. Get public key from company and send it to partners + // Send helper's public key to partners + { + let helper_public_key = helper_protocol.get_helper_public_key().unwrap(); + + let mut company_public_key = TPayload::new(); + let _ = rpc_client_company::recv( + CompanyServiceResponse { + ack: Some(CompanyAck::InitAck(company_init_ack.clone())), + }, + "company_public_key".to_string(), + &mut company_public_key, + &mut company_client_context, + ) + .await?; + + helper_protocol.set_company_public_key(company_public_key.clone())?; + + for i in 0..partner_client_context.len() { + // Send company public key + let _ = match rpc_client_partner::send( + company_public_key.clone(), + "company_public_key".to_string(), + &mut partner_client_context[i], + ) + .await? + .into_inner() + .ack + .unwrap() + { + PartnerAck::CompanyPublicKeyAck(x) => x, + _ => panic!("wrong ack"), + }; + + // Send helper public key + let _ = match rpc_client_partner::send( + helper_public_key.clone(), + "helper_public_key".to_string(), + &mut partner_client_context[i], + ) + .await? + .into_inner() + .ack + .unwrap() + { + PartnerAck::HelperPublicKeyAck(x) => x, + _ => panic!("wrong ack"), + }; + } + } + + // 5. Get company's data from company + // h_company_beta = H(C)^beta + // beta = company.private_key + { + let mut h_company_beta = TPayload::new(); + let _ = rpc_client_company::recv( + CompanyServiceResponse { + ack: Some(CompanyAck::InitAck(company_init_ack.clone())), + }, + "u_company".to_string(), + &mut h_company_beta, + &mut company_client_context, + ) + .await?; + + let offset_len = u64::from_le_bytes( + h_company_beta + .pop() + .unwrap() + .buffer + .as_slice() + .try_into() + .unwrap(), + ) as usize; + // flattened len + let data_len = u64::from_le_bytes( + h_company_beta + .pop() + .unwrap() + .buffer + .as_slice() + .try_into() + .unwrap(), + ) as usize; + + let offset = h_company_beta + .drain(data_len..) + .map(|b| u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize) + .collect::>(); + h_company_beta.shrink_to_fit(); + + assert_eq!(offset_len, offset.len()); + + // set H(C)^beta + helper_protocol.set_encrypted_company(h_company_beta, offset)?; + } + + // 6. Send requests to partners to send their data and shares to company + let mut partner_sent_data_acks = vec![]; + for i in 0..partner_client_context.len() { + let partner_sent_data_ack = match partner_client_context[i] + .send_data_to_company(Request::new(PartnerSendData {})) + .await? + .into_inner() + .ack + .unwrap() + { + PartnerAck::SendDataAck(x) => x, + _ => panic!("wrong ack"), + }; + partner_sent_data_acks.push(partner_sent_data_ack); + } + + // 7. Stop Partner service + for i in 0..partner_client_context.len() { + rpc_client_partner::stop_service(&mut partner_client_context[i]).await?; + } + + // 8. Receive partner's data from company, deserialize, and remove + // private exponent alpha. Also decrypt the XOR shares. + // input: h_partner_alpha_beta = H(P)^alpha^beta + // output: h_partner_beta = H(P)^beta + for _ in 0..partner_client_context.len() { + let mut h_partner_alpha_beta = TPayload::new(); + + let _ = rpc_client_company::recv( + CompanyServiceResponse { + ack: Some(CompanyAck::InitAck(company_init_ack.clone())), + }, + "v_partner".to_string(), + &mut h_partner_alpha_beta, + &mut company_client_context, + ) + .await?; + + let xor_shares_len = u64::from_le_bytes( + h_partner_alpha_beta + .pop() + .unwrap() + .buffer + .as_slice() + .try_into() + .unwrap(), + ) as usize; + + let xor_shares = h_partner_alpha_beta + .drain((h_partner_alpha_beta.len() - xor_shares_len)..) + .collect::>(); + + // Last element is the p_scalar_times_g + let p_scalar_times_g = h_partner_alpha_beta.pop().unwrap(); + + // Last element is the encrypted_alpha_t + let enc_alpha_t = h_partner_alpha_beta.pop().unwrap(); + + // deserialize ragged array + let num_partner_keys = u64::from_le_bytes( + h_partner_alpha_beta + .pop() + .unwrap() + .buffer + .as_slice() + .try_into() + .unwrap(), + ) as usize; + // flattened len + let data_len = u64::from_le_bytes( + h_partner_alpha_beta + .pop() + .unwrap() + .buffer + .as_slice() + .try_into() + .unwrap(), + ) as usize; + + let offset = h_partner_alpha_beta + .drain(data_len..) + .map(|b| u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize) + .collect::>(); + h_partner_alpha_beta.shrink_to_fit(); + + assert_eq!(num_partner_keys, offset.len()); + + // Perform 1/alpha, where alpha = partner.alpha. + // Then decrypt XOR secret shares and compute features and mask. + helper_protocol.remove_partner_scalar_from_p_and_set_shares( + h_partner_alpha_beta, + offset, + enc_alpha_t.buffer, + vec![p_scalar_times_g], + xor_shares, + )?; + } + + // 9. Calculate set diffs + for i in 0..partner_client_context.len() { + // Compute multi-key matches + helper_protocol.calculate_set_diff(i)?; + } + + // 10. Create helper's ID spine + // Compute ID map for LJ + helper_protocol.calculate_id_map(one_to_many); + + // 11. Create company's ID spine + rpc_client_company::calculate_id_map(&mut company_client_context).await?; + + // 12. Get XOR share of value from partner. Depends on Id-map + let v_d_prime = helper_protocol.calculate_features_xor_shares()?; + + // 13. Set XOR share of features for company + let _ = + rpc_client_company::calculate_features_xor_shares(v_d_prime, &mut company_client_context) + .await? + .into_inner() + .ack + .unwrap(); + + // 14. Print Company's ID spine and save partners shares + rpc_client_company::reveal(&mut company_client_context).await?; + + // 15. Print Helper's ID spine (same as Partners without the keys) + match output_keys_path { + Some(p) => helper_protocol.save_id_map(&String::from(p)).unwrap(), + None => helper_protocol.print_id_map(), + }; + + // 16. Print Helper's feature shares + match output_shares_path { + Some(p) => helper_protocol + .save_features_shares(&String::from(p)) + .unwrap(), + None => error!("Output features path not set. Can't output shares"), + }; + + global_timer.qps("total time", partner_client_context.len()); + info!("Bye!"); + Ok(()) +} diff --git a/protocol-rpc/src/rpc/dpmc/company-server.rs b/protocol-rpc/src/rpc/dpmc/company-server.rs new file mode 100644 index 0000000..ef457f6 --- /dev/null +++ b/protocol-rpc/src/rpc/dpmc/company-server.rs @@ -0,0 +1,177 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +#[macro_use] +extern crate log; +extern crate clap; +extern crate ctrlc; +extern crate tonic; + +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::thread; +use std::time; + +use clap::App; +use clap::Arg; +use clap::ArgGroup; +use log::info; + +mod rpc_server_company; +use rpc::connect::create_server::create_server; +use rpc::proto::gen_dpmc_company::dpmc_company_server; + +#[tokio::main] +async fn main() -> Result<(), Box> { + env_logger::init(); + + let matches = App::new("Delegated Private Id MultiKey Company") + .version("0.1") + .about("Private Id MultiKey Protocol") + .args(&[ + Arg::with_name("host") + .long("host") + .takes_value(true) + .required(true) + .help("Host path to connect to, ex: 0.0.0.0:10009"), + Arg::with_name("input") + .long("input") + .short("i") + .default_value("input.csv") + .help("Path to input file with keys"), + Arg::with_name("input-with-headers") + .long("input-with-headers") + .takes_value(false) + .help("Indicates if the input CSV contains headers"), + Arg::with_name("output") + .long("output") + .short("o") + .takes_value(true) + .help("Path to output file for keys only"), + Arg::with_name("stdout") + .long("stdout") + .short("u") + .takes_value(false) + .help("Prints the keys to stdout rather than file"), + Arg::with_name("output-shares-path") + .long("output-shares-path") + .takes_value(true) + .required(true) + .help( + "path to write shares of features.\n + Feature will be written as {path}_partner_features.csv", + ), + Arg::with_name("no-tls") + .long("no-tls") + .takes_value(false) + .help("Turns tls off"), + Arg::with_name("tls-dir") + .long("tls-dir") + .takes_value(true) + .help( + "Path to directory with files with key, cert and ca.pem file\n + client: client.key, client.pem, ca.pem \n + server: server.key, server.pem, ca.pem \n + ", + ), + Arg::with_name("tls-key") + .long("tls-key") + .takes_value(true) + .requires("tls-cert") + .requires("tls-ca") + .help("Path to tls key (non-encrypted)"), + Arg::with_name("tls-cert") + .long("tls-cert") + .takes_value(true) + .requires("tls-key") + .requires("tls-ca") + .help( + "Path to tls certificate (pem format), SINGLE cert, \ + NO CHAINING, required by client as well", + ), + Arg::with_name("tls-ca") + .long("tls-ca") + .takes_value(true) + .requires("tls-key") + .requires("tls-cert") + .help("Path to root CA certificate issued cert and keys"), + ]) + .groups(&[ + ArgGroup::with_name("tls") + .args(&["no-tls", "tls-dir", "tls-key"]) + .required(true), + ArgGroup::with_name("out") + .args(&["output", "stdout"]) + .required(true), + ]) + .get_matches(); + + let input_path = matches.value_of("input").unwrap_or("input.csv"); + let input_with_headers = matches.is_present("input-with-headers"); + let output_keys_path = matches.value_of("output"); + let output_shares_path = matches.value_of("output-shares-path"); + + let no_tls = matches.is_present("no-tls"); + let host = matches.value_of("host"); + let tls_dir = matches.value_of("tls-dir"); + let tls_key = matches.value_of("tls-key"); + let tls_cert = matches.value_of("tls-cert"); + let tls_ca = matches.value_of("tls-ca"); + + let (mut server, tx, rx) = create_server(no_tls, tls_dir, tls_key, tls_cert, tls_ca); + + let running = Arc::new(AtomicBool::new(true)); + let r = running.clone(); + ctrlc::set_handler(move || { + r.store(false, Ordering::SeqCst); + }) + .expect("Error setting Ctrl-C handler"); + + info!("Input path: {}", input_path); + + if output_keys_path.is_some() { + info!("Output keys path: {}", output_keys_path.unwrap()); + } else { + info!("Output view to stdout (first 10 keys)"); + } + + if output_shares_path.is_some() { + info!("Output shares path: {}", output_shares_path.unwrap()); + } else { + error!("Output shares path not provided"); + } + + let service = rpc_server_company::DpmcCompanyService::new( + input_path, + output_keys_path, + output_shares_path, + input_with_headers, + ); + + let ks = service.killswitch.clone(); + let recv_thread = thread::spawn(move || { + let sleep_dur = time::Duration::from_millis(1000); + while !(ks.load(Ordering::Relaxed)) && running.load(Ordering::Relaxed) { + thread::sleep(sleep_dur); + } + + info!("Shutting down server ..."); + tx.send(()).unwrap(); + }); + + info!("Server starting at {}", host.unwrap()); + + let addr = host.unwrap().parse()?; + + server + .add_service(dpmc_company_server::DpmcCompanyServer::new(service)) + .serve_with_shutdown(addr, async { + rx.await.ok(); + }) + .await?; + + recv_thread.join().unwrap(); + info!("Bye!"); + Ok(()) +} diff --git a/protocol-rpc/src/rpc/dpmc/partner-server.rs b/protocol-rpc/src/rpc/dpmc/partner-server.rs new file mode 100644 index 0000000..d434057 --- /dev/null +++ b/protocol-rpc/src/rpc/dpmc/partner-server.rs @@ -0,0 +1,178 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +#[macro_use] +extern crate log; +extern crate clap; +extern crate ctrlc; +extern crate tonic; + +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::thread; +use std::time; + +use clap::App; +use clap::Arg; +use clap::ArgGroup; +use log::info; + +mod rpc_client_company; +mod rpc_server_partner; +use rpc::connect::create_client::create_client; +use rpc::connect::create_server::create_server; +use rpc::proto::gen_dpmc_partner::dpmc_partner_server; +use rpc::proto::RpcClient; + +#[tokio::main] +async fn main() -> Result<(), Box> { + env_logger::init(); + + let matches = App::new("Delegated Private Id MultiKey Partner") + .version("0.1") + .about("Delegated Private Id MultiKey Protocol") + .args(&[ + Arg::with_name("host") + .long("host") + .takes_value(true) + .required(true) + .help("Host path to connect to, ex: 0.0.0.0:10009"), + Arg::with_name("company") + .long("company") + .short("c") + .takes_value(true) + .required(true) + .help("Company host path to connect to, ex: 0.0.0.0:10009"), + Arg::with_name("input-keys") + .long("input-keys") + .default_value("input_keys.csv") + .help("Path to input file with keys"), + Arg::with_name("input-features") + .long("input-features") + .default_value("input_features.csv") + .help("Path to input file with keys"), + Arg::with_name("input-with-headers") + .long("input-with-headers") + .takes_value(false) + .help("Indicates if the input CSV contains headers"), + Arg::with_name("no-tls") + .long("no-tls") + .takes_value(false) + .help("Turns tls off"), + Arg::with_name("tls-dir") + .long("tls-dir") + .takes_value(true) + .help( + "Path to directory with files with key, cert and ca.pem file\n + client: client.key, client.pem, ca.pem \n + server: server.key, server.pem, ca.pem \n", + ), + Arg::with_name("tls-key") + .long("tls-key") + .takes_value(true) + .requires("tls-cert") + .requires("tls-ca") + .help("Path to tls key (non-encrypted)"), + Arg::with_name("tls-cert") + .long("tls-cert") + .takes_value(true) + .requires("tls-key") + .requires("tls-ca") + .help( + "Path to tls certificate (pem format), SINGLE cert, \ + NO CHAINING, required by client as well", + ), + Arg::with_name("tls-ca") + .long("tls-ca") + .takes_value(true) + .requires("tls-key") + .requires("tls-cert") + .help("Path to root CA certificate issued cert and keys"), + Arg::with_name("tls-domain") + .long("tls-domain") + .takes_value(true) + .help("Override TLS domain for SSL cert (if host is IP)"), + ]) + .groups(&[ArgGroup::with_name("tls") + .args(&["no-tls", "tls-dir", "tls-key"]) + .required(true)]) + .get_matches(); + + let input_keys_path = matches.value_of("input-keys").unwrap_or("input_keys.csv"); + let input_features_path = matches + .value_of("input-features") + .unwrap_or("input_features.csv"); + let input_with_headers = matches.is_present("input-with-headers"); + + let no_tls = matches.is_present("no-tls"); + let host = matches.value_of("host"); + let company_host = matches.value_of("company"); + let tls_dir = matches.value_of("tls-dir"); + let tls_key = matches.value_of("tls-key"); + let tls_cert = matches.value_of("tls-cert"); + let tls_ca = matches.value_of("tls-ca"); + let tls_domain = matches.value_of("tls-domain"); + + let company_client_context = { + match create_client( + no_tls, + company_host, + tls_dir, + tls_key, + tls_cert, + tls_ca, + tls_domain, + "dpmc-company".to_string(), + ) { + RpcClient::DpmcCompany(x) => x, + _ => panic!("wrong client"), + } + }; + + let (mut server, tx, rx) = create_server(no_tls, tls_dir, tls_key, tls_cert, tls_ca); + + let running = Arc::new(AtomicBool::new(true)); + let r = running.clone(); + ctrlc::set_handler(move || { + r.store(false, Ordering::SeqCst); + }) + .expect("Error setting Ctrl-C handler"); + + info!("Input path for keys: {}", input_keys_path); + info!("Input path for features: {}", input_features_path); + + let service = rpc_server_partner::DpmcPartnerService::new( + input_keys_path, + input_features_path, + input_with_headers, + company_client_context, + ); + + let ks = service.killswitch.clone(); + let recv_thread = thread::spawn(move || { + let sleep_dur = time::Duration::from_millis(1000); + while !(ks.load(Ordering::Relaxed)) && running.load(Ordering::Relaxed) { + thread::sleep(sleep_dur); + } + + info!("Shutting down server ..."); + tx.send(()).unwrap(); + }); + + info!("Server starting at {}", host.unwrap()); + + let addr = host.unwrap().parse()?; + + server + .add_service(dpmc_partner_server::DpmcPartnerServer::new(service)) + .serve_with_shutdown(addr, async { + rx.await.ok(); + }) + .await?; + + recv_thread.join().unwrap(); + + info!("Bye!"); + Ok(()) +} diff --git a/protocol-rpc/src/rpc/dpmc/rpc_client_company.rs b/protocol-rpc/src/rpc/dpmc/rpc_client_company.rs new file mode 100644 index 0000000..d26ac22 --- /dev/null +++ b/protocol-rpc/src/rpc/dpmc/rpc_client_company.rs @@ -0,0 +1,61 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate common; +extern crate crypto; +extern crate protocol; + +use common::timer; +use crypto::prelude::TPayload; +use rpc::proto::gen_dpmc_company::dpmc_company_client::DpmcCompanyClient; +use rpc::proto::gen_dpmc_company::Commitment; +use rpc::proto::gen_dpmc_company::ServiceResponse; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::send_data; +use tonic::transport::Channel; +use tonic::Request; +use tonic::Response; +use tonic::Status; + +pub async fn recv( + response: ServiceResponse, + name: String, + data: &mut TPayload, + rpc: &mut DpmcCompanyClient, +) -> Result<(), Status> { + let t = timer::Builder::new().label(name.as_str()).build(); + + let request = Request::new(response); + let mut strm = match name.as_str() { + "company_public_key" => rpc.recv_company_public_key(request).await?.into_inner(), + "u_company" => rpc.recv_u_company(request).await?.into_inner(), + "v_partner" => rpc.recv_v_partner(request).await?.into_inner(), + _ => panic!("wrong data type"), + }; + + let res = read_from_stream(&mut strm).await?; + t.qps(format!("received {}", name.as_str()).as_str(), res.len()); + data.clear(); + data.extend(res); + Ok(()) +} + +pub async fn calculate_features_xor_shares( + data: TPayload, + rpc: &mut DpmcCompanyClient, +) -> Result, Status> { + rpc.calculate_features_xor_shares(send_data(data)).await +} + +pub async fn calculate_id_map(rpc: &mut DpmcCompanyClient) -> Result<(), Status> { + let _r = rpc + .calculate_id_map(Request::new(Commitment {})) + .await? + .into_inner(); + Ok(()) +} + +pub async fn reveal(rpc: &mut DpmcCompanyClient) -> Result<(), Status> { + let _r = rpc.reveal(Request::new(Commitment {})).await?.into_inner(); + Ok(()) +} diff --git a/protocol-rpc/src/rpc/dpmc/rpc_client_partner.rs b/protocol-rpc/src/rpc/dpmc/rpc_client_partner.rs new file mode 100644 index 0000000..10755c8 --- /dev/null +++ b/protocol-rpc/src/rpc/dpmc/rpc_client_partner.rs @@ -0,0 +1,36 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate common; +extern crate crypto; +extern crate protocol; + +use crypto::prelude::TPayload; +use rpc::proto::gen_dpmc_partner::dpmc_partner_client::DpmcPartnerClient; +use rpc::proto::gen_dpmc_partner::Commitment; +use rpc::proto::gen_dpmc_partner::ServiceResponse; +use rpc::proto::streaming::send_data; +use tonic::transport::Channel; +use tonic::Request; +use tonic::Response; +use tonic::Status; + +pub async fn send( + data: TPayload, + name: String, + rpc: &mut DpmcPartnerClient, +) -> Result, Status> { + match name.as_str() { + "company_public_key" => rpc.send_company_public_key(send_data(data)).await, + "helper_public_key" => rpc.send_helper_public_key(send_data(data)).await, + _ => panic!("wrong data type"), + } +} + +pub async fn stop_service(rpc: &mut DpmcPartnerClient) -> Result<(), Status> { + let _r = rpc + .stop_service(Request::new(Commitment {})) + .await? + .into_inner(); + Ok(()) +} diff --git a/protocol-rpc/src/rpc/dpmc/rpc_server_company.rs b/protocol-rpc/src/rpc/dpmc/rpc_server_company.rs new file mode 100644 index 0000000..502ad68 --- /dev/null +++ b/protocol-rpc/src/rpc/dpmc/rpc_server_company.rs @@ -0,0 +1,246 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +use std::borrow::BorrowMut; +use std::convert::TryInto; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; + +use common::timer; +use protocol::dpmc::company::CompanyDpmc; +use protocol::dpmc::traits::CompanyDpmcProtocol; +use protocol::shared::TFeatures; +use rpc::proto::common::Payload; +use rpc::proto::gen_dpmc_company::dpmc_company_server::DpmcCompany; +use rpc::proto::gen_dpmc_company::service_response::*; +use rpc::proto::gen_dpmc_company::CalculateFeaturesXorSharesAck; +use rpc::proto::gen_dpmc_company::Commitment; +use rpc::proto::gen_dpmc_company::CommitmentAck; +use rpc::proto::gen_dpmc_company::Init; +use rpc::proto::gen_dpmc_company::InitAck; +use rpc::proto::gen_dpmc_company::ServiceResponse; +use rpc::proto::gen_dpmc_company::UPartnerAck; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::write_to_stream; +use rpc::proto::streaming::TPayloadStream; +use tonic::Code; +use tonic::Request; +use tonic::Response; +use tonic::Status; +use tonic::Streaming; + +pub struct DpmcCompanyService { + protocol: CompanyDpmc, + input_path: String, + output_keys_path: Option, + output_shares_path: Option, + input_with_headers: bool, + pub killswitch: Arc, +} + +impl DpmcCompanyService { + pub fn new( + input_path: &str, + output_keys_path: Option<&str>, + output_shares_path: Option<&str>, + input_with_headers: bool, + ) -> DpmcCompanyService { + DpmcCompanyService { + protocol: CompanyDpmc::new(), + input_path: String::from(input_path), + output_keys_path: output_keys_path.map(String::from), + output_shares_path: output_shares_path.map(String::from), + input_with_headers, + killswitch: Arc::new(AtomicBool::new(false)), + } + } +} + +#[tonic::async_trait] +impl DpmcCompany for DpmcCompanyService { + type RecvUCompanyStream = TPayloadStream; + type RecvVPartnerStream = TPayloadStream; + type RecvCompanyPublicKeyStream = TPayloadStream; + + async fn initialize(&self, _: Request) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("init") + .build(); + self.protocol + .load_data(&self.input_path, self.input_with_headers); + Ok(Response::new(ServiceResponse { + ack: Some(Ack::InitAck(InitAck {})), + })) + } + + async fn calculate_id_map( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("calculate_id_map") + .build(); + self.protocol + .write_company_to_id_map() + .map(|_| Response::new(CommitmentAck {})) + .map_err(|_| Status::new(Code::Aborted, "cannot init the protocol for partner")) + } + + async fn calculate_features_xor_shares( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("calculate_features_xor_shares") + .build(); + let mut data = read_from_stream(request.into_inner().borrow_mut()).await?; + + let num_features = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let num_rows = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let mask = data + .drain(num_features * num_rows..) + .map(|x| x) + .collect::>(); + let mut t = TFeatures::new(); + + for i in (0..num_features).rev() { + let x = data + .drain(i * num_rows..) + .map(|x| u64::from_le_bytes(x.buffer.as_slice().try_into().unwrap())) + .collect::>(); + t.push(x); + } + + self.protocol + .calculate_features_xor_shares(t, mask) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::CalculateFeaturesXorSharesAck( + CalculateFeaturesXorSharesAck {}, + )), + }) + }) + .map_err(|_| Status::internal("error calculating XOR shares")) + } + + async fn recv_company_public_key( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("recv_company_public_key") + .build(); + self.protocol + .get_company_public_key() + .map(write_to_stream) + .map_err(|_| Status::new(Code::Aborted, "cannot send company_public_key")) + } + + async fn recv_u_company( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("recv_u_company") + .build(); + self.protocol + .get_permuted_keys() + .map(write_to_stream) + .map_err(|_| Status::new(Code::Aborted, "cannot send u_company")) + } + + async fn recv_v_partner( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("recv_v_partner") + .build(); + self.protocol + .serialize_encrypted_keys_and_features() + .map(write_to_stream) + .map_err(|_| Status::new(Code::Aborted, "cannot init the protocol for partner")) + } + + async fn send_u_partner( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_u_partner") + .build(); + let mut data = read_from_stream(request.into_inner().borrow_mut()).await?; + + let xor_shares_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + + let xor_shares = data + .drain((data.len() - xor_shares_len)..) + .collect::>(); + + let p_scalar_g = data.pop().unwrap(); + + let enc_alpha_t = data.pop().unwrap(); + + let offset_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let data_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + + let offset = data + .drain(data_len..) + .map(|b| u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize) + .collect::>(); + data.shrink_to_fit(); + + assert_eq!(offset_len, offset.len()); + + self.protocol + .set_encrypted_partner_keys_and_shares( + data, + offset, + enc_alpha_t.buffer, + p_scalar_g.buffer, + xor_shares, + ) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::UPartnerAck(UPartnerAck {})), + }) + }) + .map_err(|_| Status::internal("error loading")) + } + + async fn reveal(&self, _: Request) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("reveal") + .build(); + match &self.output_keys_path { + Some(p) => self.protocol.save_id_map(p).unwrap(), + None => self.protocol.print_id_map(), + } + + let resp = self + .protocol + .save_features_shares(&self.output_shares_path.clone().unwrap()) + .map(|_| Response::new(CommitmentAck {})) + .map_err(|_| Status::internal("error saving feature shares")); + { + debug!("Setting up flag for graceful down"); + self.killswitch.store(true, Ordering::SeqCst); + } + + resp + } +} diff --git a/protocol-rpc/src/rpc/dpmc/rpc_server_partner.rs b/protocol-rpc/src/rpc/dpmc/rpc_server_partner.rs new file mode 100644 index 0000000..7546cdd --- /dev/null +++ b/protocol-rpc/src/rpc/dpmc/rpc_server_partner.rs @@ -0,0 +1,178 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; + +use common::timer; +use crypto::spoint::ByteBuffer; +use protocol::dpmc::partner::PartnerDpmc; +use protocol::dpmc::traits::PartnerDpmcProtocol; +use rpc::proto::common::Payload; +use rpc::proto::gen_dpmc_company::dpmc_company_client::DpmcCompanyClient; +use rpc::proto::gen_dpmc_partner::dpmc_partner_server::DpmcPartner; +use rpc::proto::gen_dpmc_partner::service_response::*; +use rpc::proto::gen_dpmc_partner::Commitment; +use rpc::proto::gen_dpmc_partner::CommitmentAck; +use rpc::proto::gen_dpmc_partner::CompanyPublicKeyAck; +use rpc::proto::gen_dpmc_partner::HelperPublicKeyAck; +use rpc::proto::gen_dpmc_partner::Init; +use rpc::proto::gen_dpmc_partner::InitAck; +use rpc::proto::gen_dpmc_partner::SendData; +use rpc::proto::gen_dpmc_partner::SendDataAck; +use rpc::proto::gen_dpmc_partner::ServiceResponse; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::send_data; +use rpc::proto::streaming::write_to_stream; +use rpc::proto::streaming::TPayloadStream; +use tonic::transport::Channel; +use tonic::Code; +use tonic::Request; +use tonic::Response; +use tonic::Status; +use tonic::Streaming; + +pub struct DpmcPartnerService { + protocol: PartnerDpmc, + input_keys_path: String, + input_features_path: String, + input_with_headers: bool, + company_client_context: DpmcCompanyClient, + pub killswitch: Arc, +} + +impl DpmcPartnerService { + pub fn new( + input_keys_path: &str, + input_features_path: &str, + input_with_headers: bool, + company_client_context: DpmcCompanyClient, + ) -> DpmcPartnerService { + DpmcPartnerService { + protocol: PartnerDpmc::new(), + input_keys_path: String::from(input_keys_path), + input_features_path: String::from(input_features_path), + input_with_headers, + company_client_context, + killswitch: Arc::new(AtomicBool::new(false)), + } + } +} + +#[tonic::async_trait] +impl DpmcPartner for DpmcPartnerService { + type RecvPartnerPublicKeyStream = TPayloadStream; + + async fn initialize(&self, _: Request) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("init") + .build(); + self.protocol.load_data( + &self.input_keys_path, + &self.input_features_path, + self.input_with_headers, + ); + Ok(Response::new(ServiceResponse { + ack: Some(Ack::InitAck(InitAck {})), + })) + } + + async fn send_data_to_company( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("init") + .build(); + + // Send partner data to company. - Partner acts as a client to company. + let mut h_partner_alpha = self.protocol.get_encrypted_keys().unwrap(); + let mut company_client_contxt = self.company_client_context.clone(); + + let xor_shares = self.protocol.get_features_xor_shares().unwrap(); + let xor_shares_len = xor_shares.len(); + h_partner_alpha.extend(xor_shares); + h_partner_alpha.push(ByteBuffer { + buffer: (xor_shares_len as u64).to_le_bytes().to_vec(), + }); + + _ = company_client_contxt + .send_u_partner(send_data(h_partner_alpha)) + .await; + + Ok(Response::new(ServiceResponse { + ack: Some(Ack::SendDataAck(SendDataAck {})), + })) + } + + async fn recv_partner_public_key( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("recv_partner_public_key") + .build(); + self.protocol + .get_partner_public_key() + .map(write_to_stream) + .map_err(|_| Status::new(Code::Aborted, "cannot send partner_public_key")) + } + + async fn send_company_public_key( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_company_public_key") + .build(); + let mut strm = request.into_inner(); + self.protocol + .set_company_public_key(read_from_stream(&mut strm).await?) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::CompanyPublicKeyAck(CompanyPublicKeyAck {})), + }) + }) + .map_err(|_| Status::internal("error writing")) + } + + async fn send_helper_public_key( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_helper_public_key") + .build(); + let mut strm = request.into_inner(); + self.protocol + .set_helper_public_key(read_from_stream(&mut strm).await?) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::HelperPublicKeyAck(HelperPublicKeyAck {})), + }) + }) + .map_err(|_| Status::internal("error writing")) + } + + async fn stop_service( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("stop") + .build(); + { + debug!("Setting up flag for graceful down"); + self.killswitch.store(true, Ordering::SeqCst); + } + + Ok(Response::new(CommitmentAck {})) + } +} diff --git a/protocol-rpc/src/rpc/dspmc/client.rs b/protocol-rpc/src/rpc/dspmc/client.rs new file mode 100644 index 0000000..d7ae9e4 --- /dev/null +++ b/protocol-rpc/src/rpc/dspmc/client.rs @@ -0,0 +1,519 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +use std::convert::TryInto; + +use clap::App; +use clap::Arg; +use clap::ArgGroup; +use common::timer; +use crypto::prelude::TPayload; +use log::info; +use protocol::dspmc::shuffler::ShufflerDspmc; +use protocol::dspmc::traits::*; +use protocol::shared::TFeatures; +use rpc::connect::create_client::create_client; +use rpc::proto::gen_dspmc_company::service_response::Ack as CompanyAck; +use rpc::proto::gen_dspmc_company::Init as CompanyInit; +use rpc::proto::gen_dspmc_company::RecvShares as CompanyRecvShares; +use rpc::proto::gen_dspmc_company::SendData as CompanySendData; +use rpc::proto::gen_dspmc_company::ServiceResponse as CompanyServiceResponse; +use rpc::proto::gen_dspmc_helper::service_response::Ack as HelperAck; +use rpc::proto::gen_dspmc_helper::SendDataAck; +use rpc::proto::gen_dspmc_helper::ServiceResponse as HelperServiceResponse; +use rpc::proto::gen_dspmc_partner::service_response::Ack as PartnerAck; +use rpc::proto::gen_dspmc_partner::Init as PartnerInit; +use rpc::proto::gen_dspmc_partner::SendData as PartnerSendData; +use rpc::proto::RpcClient; +use tonic::Request; + +mod rpc_client_company; +mod rpc_client_helper; +mod rpc_client_partner; + +#[tokio::main] +async fn main() -> Result<(), Box> { + env_logger::init(); + + // todo: move matches outside, or move to build.rs + let matches = App::new("Delegated Private Id MultiKey Shuffler") + .version("0.1") + .about("Delegated Private Id Multi Key Protocol") + .args(&[ + Arg::with_name("company") + .long("company") + .short("c") + .takes_value(true) + .required(true) + .help("Company host path to connect to, ex: 0.0.0.0:10009"), + Arg::with_name("helper") + .long("helper") + .short("helper") + .takes_value(true) + .required(true) + .help("Helper host path to connect to, ex: 0.0.0.0:10011"), + Arg::with_name("partners") + .long("partners") + .short("p") + .takes_value(true) + .required(true) + .help("Partner host path to connect to, ex: 0.0.0.0:10010"), + Arg::with_name("stdout") + .long("stdout") + .short("u") + .takes_value(false) + .help("Prints the output to stdout rather than file"), + Arg::with_name("no-tls") + .long("no-tls") + .takes_value(false) + .help("Turns tls off"), + Arg::with_name("tls-dir") + .long("tls-dir") + .takes_value(true) + .help( + "Path to directory with files with key, cert and ca.pem file\n + client: client.key, client.pem, ca.pem \n + server: server.key, server.pem, ca.pem \n", + ), + Arg::with_name("tls-key") + .long("tls-key") + .takes_value(true) + .requires("tls-cert") + .requires("tls-ca") + .help("Path to tls key (non-encrypted)"), + Arg::with_name("tls-cert") + .long("tls-cert") + .takes_value(true) + .requires("tls-key") + .requires("tls-ca") + .help( + "Path to tls certificate (pem format), SINGLE cert, \ + NO CHAINING, required by client as well", + ), + Arg::with_name("tls-ca") + .long("tls-ca") + .takes_value(true) + .requires("tls-key") + .requires("tls-cert") + .help("Path to root CA certificate issued cert and keys"), + Arg::with_name("tls-domain") + .long("tls-domain") + .takes_value(true) + .help("Override TLS domain for SSL cert (if host is IP)"), + ]) + .groups(&[ + ArgGroup::with_name("tls") + .args(&["no-tls", "tls-dir", "tls-key"]) + .required(true), + ArgGroup::with_name("out").args(&["stdout"]).required(true), + ]) + .get_matches(); + + let global_timer = timer::Timer::new_silent("global"); + + let no_tls = matches.is_present("no-tls"); + let company_host = matches.value_of("company"); + let helper_host = matches.value_of("helper"); + let tls_dir = matches.value_of("tls-dir"); + let tls_key = matches.value_of("tls-key"); + let tls_cert = matches.value_of("tls-cert"); + let tls_ca = matches.value_of("tls-ca"); + let tls_domain = matches.value_of("tls-domain"); + + let mut company_client_context = { + match create_client( + no_tls, + company_host, + tls_dir, + tls_key, + tls_cert, + tls_ca, + tls_domain, + "dspmc-company".to_string(), + ) { + RpcClient::DspmcCompany(x) => x, + _ => panic!("wrong client"), + } + }; + + let mut helper_client_context = { + match create_client( + no_tls, + helper_host, + tls_dir, + tls_key, + tls_cert, + tls_ca, + tls_domain, + "dspmc-helper".to_string(), + ) { + RpcClient::DspmcHelper(x) => x, + _ => panic!("wrong client"), + } + }; + + let mut partner_client_context = vec![]; + let partner_host_pre = matches.value_of("partners").unwrap().split(","); + for host_pre_i in partner_host_pre { + let partner_client_context_i = { + match create_client( + no_tls, + Some(host_pre_i), + tls_dir, + tls_key, + tls_cert, + tls_ca, + tls_domain, + "dspmc-partner".to_string(), + ) { + RpcClient::DspmcPartner(x) => x, + _ => panic!("wrong client"), + } + }; + partner_client_context.push(partner_client_context_i); + } + + // 1. Create shuffler protocol instance + let shuffler_protocol = ShufflerDspmc::new(); + + // 2. Initialize company - this loads company's data + let company_init_ack = match company_client_context + .initialize(Request::new(CompanyInit {})) + .await? + .into_inner() + .ack + .unwrap() + { + CompanyAck::InitAck(x) => x, + _ => panic!("wrong ack"), + }; + + // 3. Initialize partners - this loads partner's data + let mut partner_init_acks = vec![]; + for i in 0..partner_client_context.len() { + let partner_init_ack = match partner_client_context[i] + .initialize(Request::new(PartnerInit {})) + .await? + .into_inner() + .ack + .unwrap() + { + PartnerAck::InitAck(x) => x, + _ => panic!("wrong ack"), + }; + partner_init_acks.push(partner_init_ack); + } + + // 4. Get public key from company and send it to partners and to helper + // Send helper's public key to partners + { + let mut company_public_key = TPayload::new(); + let _ = rpc_client_company::recv( + CompanyServiceResponse { + ack: Some(CompanyAck::InitAck(company_init_ack.clone())), + }, + "company_public_key".to_string(), + &mut company_public_key, + &mut company_client_context, + ) + .await?; + shuffler_protocol.set_company_public_key(company_public_key.clone())?; + + let helper_public_key_ack = match rpc_client_helper::send( + company_public_key.clone(), + "company_public_key".to_string(), + &mut helper_client_context, + ) + .await? + .into_inner() + .ack + .unwrap() + { + HelperAck::CompanyPublicKeyAck(x) => x, + _ => panic!("wrong ack"), + }; + + let mut helper_public_key = TPayload::new(); + let _ = rpc_client_helper::recv( + HelperServiceResponse { + ack: Some(HelperAck::CompanyPublicKeyAck( + helper_public_key_ack.clone(), + )), + }, + "helper_public_key".to_string(), + &mut helper_public_key, + &mut helper_client_context, + ) + .await?; + shuffler_protocol.set_helper_public_key(helper_public_key.clone())?; + + // Send helper public key to Company + let _ = match rpc_client_company::send( + helper_public_key.clone(), + "helper_public_key".to_string(), + &mut company_client_context, + ) + .await? + .into_inner() + .ack + .unwrap() + { + CompanyAck::HelperPublicKeyAck(x) => x, + _ => panic!("wrong ack"), + }; + + for i in 0..partner_client_context.len() { + // Send company public key to partners + let _ = match rpc_client_partner::send( + company_public_key.clone(), + "company_public_key".to_string(), + &mut partner_client_context[i], + ) + .await? + .into_inner() + .ack + .unwrap() + { + PartnerAck::CompanyPublicKeyAck(x) => x, + _ => panic!("wrong ack"), + }; + + // Send helper public key to partners + let _ = match rpc_client_partner::send( + helper_public_key.clone(), + "helper_public_key".to_string(), + &mut partner_client_context[i], + ) + .await? + .into_inner() + .ack + .unwrap() + { + PartnerAck::HelperPublicKeyAck(x) => x, + _ => panic!("wrong ack"), + }; + } + } + + // 5. Send requests to partners to send their data and shares to company + let mut partner_sent_data_acks = vec![]; + for i in 0..partner_client_context.len() { + let partner_sent_data_ack = match partner_client_context[i] + .send_data_to_company(Request::new(PartnerSendData {})) + .await? + .into_inner() + .ack + .unwrap() + { + PartnerAck::SendDataAck(x) => x, + _ => panic!("wrong ack"), + }; + partner_sent_data_acks.push(partner_sent_data_ack); + } + + // 6. Stop Partner service + for i in 0..partner_client_context.len() { + rpc_client_partner::stop_service(&mut partner_client_context[i]).await?; + } + + // Secure shuffling starts here + + // 7. Send request to company to send ct3 from all partners to Helper along + // with p_cd and v_cd + let company_sent_ct3_v3_p3_ack = match company_client_context + .send_ct3_p_cd_v_cd_to_helper(Request::new(CompanySendData {})) + .await? + .into_inner() + .ack + .unwrap() + { + CompanyAck::SendDataAck(x) => x, + _ => panic!("wrong ack"), + }; + + // 8. Receive p_cs and v_cs from company + let mut v4_p4 = TPayload::new(); + let _ = rpc_client_company::recv( + CompanyServiceResponse { + ack: Some(CompanyAck::SendDataAck(company_sent_ct3_v3_p3_ack.clone())), + }, + "p_cs_v_cs".to_string(), + &mut v4_p4, + &mut company_client_context, + ) + .await?; + + let offset_len = + u64::from_le_bytes(v4_p4.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + // flattened len + let data_len = + u64::from_le_bytes(v4_p4.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let num_keys = offset_len - 1; + let offset = v4_p4 + .drain((num_keys * 2 + data_len * 2)..) + .map(|b| u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize) + .collect::>(); + assert_eq!(offset_len, offset.len()); + + let ct2_prime_flat = v4_p4.drain((v4_p4.len() - data_len)..).collect::>(); + let ct1_prime_flat = v4_p4.drain((v4_p4.len() - data_len)..).collect::>(); + + let v_cs_bytes = v4_p4.drain((v4_p4.len() - num_keys)..).collect::>(); + v4_p4.shrink_to_fit(); + + shuffler_protocol.set_p_cs_v_cs(v_cs_bytes, v4_p4)?; + + // 9. Receive u_2 = p_cd(v'') xor v_cd from helper + let mut data = TPayload::new(); + let _ = rpc_client_helper::recv( + HelperServiceResponse { + ack: Some(HelperAck::SendDataAck(SendDataAck {})), + }, + "u2".to_string(), + &mut data, + &mut helper_client_context, + ) + .await?; + let num_features = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let num_rows = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + data.shrink_to_fit(); + + let mut u2 = TFeatures::new(); + for i in (0..num_features).rev() { + let x = data + .drain(i * num_rows..) + .map(|x| u64::from_le_bytes(x.buffer.as_slice().try_into().unwrap())) + .collect::>(); + u2.push(x); + } + + // 10. Generete shuffler permutations + // Generate p_sc, v_sc and p_sd, v_sd + let (p_sc_v_sc, p_sd_v_sd) = shuffler_protocol.gen_permutations().unwrap(); + + // 11. Compute x_2 = p_cs(u2) xor v_cs + // Compute v_2' = p_sd(p_sc(x_2) xor v_sd) xor v_sd + // Return rerandomized ct1' and ct2' as ct1'' and ct2'' + let ct1_ct2_dprime = shuffler_protocol + .compute_v2prime_ct1ct2(u2, ct1_prime_flat, ct2_prime_flat, offset) + .unwrap(); + + // v_sc, p_sc, ct1_dprime_flat, ct2_dprime_flat, ct_offset + let mut p_sc_v_sc_ct1_ct2_dprime = p_sc_v_sc; + p_sc_v_sc_ct1_ct2_dprime.extend(ct1_ct2_dprime); + + // 12. Send v_sc, p_sc, ct1'', ct2'' to C + let _company_p_sc_v_sc_ack = match rpc_client_company::send( + p_sc_v_sc_ct1_ct2_dprime, + "p_sc_v_sc_ct1_ct2_dprime".to_string(), + &mut company_client_context, + ) + .await? + .into_inner() + .ack + .unwrap() + { + CompanyAck::UPartnerAck(x) => x, + _ => panic!("wrong ack"), + }; + + // 13. Send p_sd, v_sd to helper (D) + let _ = match rpc_client_helper::send( + p_sd_v_sd, + "p_sd_v_sd".to_string(), + &mut helper_client_context, + ) + .await? + .into_inner() + .ack + .unwrap() + { + HelperAck::UPartnerAck(x) => x, + _ => panic!("wrong ack"), + }; + + // 14. Send request to company to send u1 to Helper + // u1 = p_sc( p_cs( p_cd(v_1) xor v_cd) xor v_cs) xor v_sc + let _company_sent_u1_ack = match company_client_context + .send_u1_to_helper(Request::new(CompanySendData {})) + .await? + .into_inner() + .ack + .unwrap() + { + CompanyAck::SendDataAck(x) => x, + _ => panic!("wrong ack"), + }; + + // Secure shuffling ends here + + // Blind v' with hashed Elgamal. + // Send blinded v' and h = g^z. + let blinded_vprime = shuffler_protocol.get_blinded_vprime().unwrap(); + + // 15. Send blinded v' and g^z to helper (D) + let _helper_vprime_ack = match rpc_client_helper::send( + blinded_vprime, + "encrypted_vprime".to_string(), + &mut helper_client_context, + ) + .await? + .into_inner() + .ack + .unwrap() + { + HelperAck::UPartnerAck(x) => x, + _ => panic!("wrong ack"), + }; + + // 16. Send request to company to send ct1, ct2', and X to Helper + // ct2' = ct2^c + // X = H(C)^c + let _company_keys_ack = match company_client_context + .send_encrypted_keys_to_helper(Request::new(CompanySendData {})) + .await? + .into_inner() + .ack + .unwrap() + { + CompanyAck::SendDataAck(x) => x, + _ => panic!("wrong ack"), + }; + + // Identity Match Stage is done. + + // 17. Create company's ID spine + rpc_client_company::calculate_id_map(&mut company_client_context).await?; + + // 18. Signal the helper to run the rest of the protocol + // 1. Compute multi-key matches -- calculate_set_diff + // 2. Compute ID map for LJ -- calculate_id_map + rpc_client_helper::calculate_id_map(&mut helper_client_context).await?; + + // 19. Send request to company to receive shares from helper + // calculate_features_xor_shares + // Set XOR share of features for company + // Print Company's ID spine and save partners shares + let _company_sent_data_ack = match company_client_context + .recv_shares_from_helper(Request::new(CompanyRecvShares {})) + .await? + .into_inner() + .ack + .unwrap() + { + CompanyAck::RecvSharesAck(x) => x, + _ => panic!("wrong ack"), + }; + + // 20. Print Helper's ID spine and save partners shares + rpc_client_helper::reveal(&mut helper_client_context).await?; + + // Stop Helper service + rpc_client_helper::stop_service(&mut helper_client_context).await?; + + global_timer.qps("total time", partner_client_context.len()); + info!("Bye!"); + Ok(()) +} diff --git a/protocol-rpc/src/rpc/dspmc/company-server.rs b/protocol-rpc/src/rpc/dspmc/company-server.rs new file mode 100644 index 0000000..a0c8696 --- /dev/null +++ b/protocol-rpc/src/rpc/dspmc/company-server.rs @@ -0,0 +1,205 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +#[macro_use] +extern crate log; +extern crate clap; +extern crate ctrlc; +extern crate tonic; + +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::thread; +use std::time; + +use clap::App; +use clap::Arg; +use clap::ArgGroup; +use log::info; + +mod rpc_client_helper; +mod rpc_server_company; + +use rpc::connect::create_client::create_client; +use rpc::connect::create_server::create_server; +use rpc::proto::gen_dspmc_company::dspmc_company_server; +use rpc::proto::RpcClient; + +#[tokio::main] +async fn main() -> Result<(), Box> { + env_logger::init(); + + let matches = App::new("Delegated Private Id MultiKey Company") + .version("0.1") + .about("Private Id MultiKey Protocol") + .args(&[ + Arg::with_name("host") + .long("host") + .takes_value(true) + .required(true) + .help("Host path to connect to, ex: 0.0.0.0:10009"), + Arg::with_name("helper") + .long("helper") + .short("helper") + .takes_value(true) + .required(true) + .help("helper host path to connect to, ex: 0.0.0.0:10011"), + Arg::with_name("input") + .long("input") + .short("i") + .default_value("input.csv") + .help("Path to input file with keys"), + Arg::with_name("input-with-headers") + .long("input-with-headers") + .takes_value(false) + .help("Indicates if the input CSV contains headers"), + Arg::with_name("output") + .long("output") + .short("o") + .takes_value(true) + .help("Path to output file for keys only"), + Arg::with_name("stdout") + .long("stdout") + .short("u") + .takes_value(false) + .help("Prints the keys to stdout rather than file"), + Arg::with_name("output-shares-path") + .long("output-shares-path") + .takes_value(true) + .required(true) + .help( + "path to write shares of features.\n + Feature will be written as {path}_partner_features.csv", + ), + Arg::with_name("no-tls") + .long("no-tls") + .takes_value(false) + .help("Turns tls off"), + Arg::with_name("tls-dir") + .long("tls-dir") + .takes_value(true) + .help( + "Path to directory with files with key, cert and ca.pem file\n + client: client.key, client.pem, ca.pem \n + server: server.key, server.pem, ca.pem \n", + ), + Arg::with_name("tls-key") + .long("tls-key") + .takes_value(true) + .requires("tls-cert") + .requires("tls-ca") + .help("Path to tls key (non-encrypted)"), + Arg::with_name("tls-cert") + .long("tls-cert") + .takes_value(true) + .requires("tls-key") + .requires("tls-ca") + .help( + "Path to tls certificate (pem format), SINGLE cert, \ + NO CHAINING, required by client as well", + ), + Arg::with_name("tls-ca") + .long("tls-ca") + .takes_value(true) + .requires("tls-key") + .requires("tls-cert") + .help("Path to root CA certificate issued cert and keys"), + ]) + .groups(&[ + ArgGroup::with_name("tls") + .args(&["no-tls", "tls-dir", "tls-key"]) + .required(true), + ArgGroup::with_name("out") + .args(&["output", "stdout"]) + .required(true), + ]) + .get_matches(); + + let input_path = matches.value_of("input").unwrap_or("input.csv"); + let input_with_headers = matches.is_present("input-with-headers"); + let output_keys_path = matches.value_of("output"); + let output_shares_path = matches.value_of("output-shares-path"); + + let no_tls = matches.is_present("no-tls"); + let host = matches.value_of("host"); + let helper_host = matches.value_of("helper"); + let tls_dir = matches.value_of("tls-dir"); + let tls_key = matches.value_of("tls-key"); + let tls_cert = matches.value_of("tls-cert"); + let tls_ca = matches.value_of("tls-ca"); + let tls_domain = matches.value_of("tls-domain"); + + let helper_client_context = { + match create_client( + no_tls, + helper_host, + tls_dir, + tls_key, + tls_cert, + tls_ca, + tls_domain, + "dspmc-helper".to_string(), + ) { + RpcClient::DspmcHelper(x) => x, + _ => panic!("wrong client"), + } + }; + + let (mut server, tx, rx) = create_server(no_tls, tls_dir, tls_key, tls_cert, tls_ca); + + let running = Arc::new(AtomicBool::new(true)); + let r = running.clone(); + ctrlc::set_handler(move || { + r.store(false, Ordering::SeqCst); + }) + .expect("Error setting Ctrl-C handler"); + + info!("Input path: {}", input_path); + + if output_keys_path.is_some() { + info!("Output keys path: {}", output_keys_path.unwrap()); + } else { + info!("Output view to stdout (first 10 keys)"); + } + + if output_shares_path.is_some() { + info!("Output shares path: {}", output_shares_path.unwrap()); + } else { + error!("Output shares path not provided"); + } + + let service = rpc_server_company::DspmcCompanyService::new( + input_path, + output_keys_path, + output_shares_path, + input_with_headers, + helper_client_context, + ); + + let ks = service.killswitch.clone(); + let recv_thread = thread::spawn(move || { + let sleep_dur = time::Duration::from_millis(1000); + while !(ks.load(Ordering::Relaxed)) && running.load(Ordering::Relaxed) { + thread::sleep(sleep_dur); + } + + info!("Shutting down server ..."); + tx.send(()).unwrap(); + }); + + info!("Server starting at {}", host.unwrap()); + + let addr = host.unwrap().parse()?; + + server + .add_service(dspmc_company_server::DspmcCompanyServer::new(service)) + .serve_with_shutdown(addr, async { + rx.await.ok(); + }) + .await?; + + recv_thread.join().unwrap(); + info!("Bye!"); + Ok(()) +} diff --git a/protocol-rpc/src/rpc/dspmc/helper-server.rs b/protocol-rpc/src/rpc/dspmc/helper-server.rs new file mode 100644 index 0000000..013ccc6 --- /dev/null +++ b/protocol-rpc/src/rpc/dspmc/helper-server.rs @@ -0,0 +1,158 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +#[macro_use] +extern crate log; +extern crate clap; +extern crate ctrlc; +extern crate tonic; + +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::thread; +use std::time; + +use clap::App; +use clap::Arg; +use clap::ArgGroup; +use log::info; + +mod rpc_server_helper; +use rpc::connect::create_server::create_server; +use rpc::proto::gen_dspmc_helper::dspmc_helper_server; + +#[tokio::main] +async fn main() -> Result<(), Box> { + env_logger::init(); + + let matches = App::new("Delegated Private Id MultiKey Helper") + .version("0.1") + .about("Private Id MultiKey Protocol") + .args(&[ + Arg::with_name("host") + .long("host") + .takes_value(true) + .required(true) + .help("Host path to connect to, ex: 0.0.0.0:10009"), + Arg::with_name("output") + .long("output") + .short("o") + .takes_value(true) + .help("Path to output file for keys only"), + Arg::with_name("stdout") + .long("stdout") + .short("u") + .takes_value(false) + .help("Prints the keys to stdout rather than file"), + Arg::with_name("output-shares-path") + .long("output-shares-path") + .takes_value(true) + .required(true) + .help( + "path to write shares of features.\n + Feature will be written as {path}_partner_features.csv", + ), + Arg::with_name("no-tls") + .long("no-tls") + .takes_value(false) + .help("Turns tls off"), + Arg::with_name("tls-dir") + .long("tls-dir") + .takes_value(true) + .help( + "Path to directory with files with key, cert and ca.pem file\n + client: client.key, client.pem, ca.pem \n + server: server.key, server.pem, ca.pem \n", + ), + Arg::with_name("tls-key") + .long("tls-key") + .takes_value(true) + .requires("tls-cert") + .requires("tls-ca") + .help("Path to tls key (non-encrypted)"), + Arg::with_name("tls-cert") + .long("tls-cert") + .takes_value(true) + .requires("tls-key") + .requires("tls-ca") + .help( + "Path to tls certificate (pem format), SINGLE cert, \ + NO CHAINING, required by client as well", + ), + Arg::with_name("tls-ca") + .long("tls-ca") + .takes_value(true) + .requires("tls-key") + .requires("tls-cert") + .help("Path to root CA certificate issued cert and keys"), + ]) + .groups(&[ + ArgGroup::with_name("tls") + .args(&["no-tls", "tls-dir", "tls-key"]) + .required(true), + ArgGroup::with_name("out") + .args(&["output", "stdout"]) + .required(true), + ]) + .get_matches(); + + let output_keys_path = matches.value_of("output"); + let output_shares_path = matches.value_of("output-shares-path"); + + let no_tls = matches.is_present("no-tls"); + let host = matches.value_of("host"); + let tls_dir = matches.value_of("tls-dir"); + let tls_key = matches.value_of("tls-key"); + let tls_cert = matches.value_of("tls-cert"); + let tls_ca = matches.value_of("tls-ca"); + + let (mut server, tx, rx) = create_server(no_tls, tls_dir, tls_key, tls_cert, tls_ca); + + let running = Arc::new(AtomicBool::new(true)); + let r = running.clone(); + ctrlc::set_handler(move || { + r.store(false, Ordering::SeqCst); + }) + .expect("Error setting Ctrl-C handler"); + + if output_keys_path.is_some() { + info!("Output keys path: {}", output_keys_path.unwrap()); + } else { + info!("Output view to stdout (first 10 keys)"); + } + + if output_shares_path.is_some() { + info!("Output shares path: {}", output_shares_path.unwrap()); + } else { + error!("Output shares path not provided"); + } + + let service = rpc_server_helper::DspmcHelperService::new(output_keys_path, output_shares_path); + + let ks = service.killswitch.clone(); + let recv_thread = thread::spawn(move || { + let sleep_dur = time::Duration::from_millis(1000); + while !(ks.load(Ordering::Relaxed)) && running.load(Ordering::Relaxed) { + thread::sleep(sleep_dur); + } + + info!("Shutting down server ..."); + tx.send(()).unwrap(); + }); + + info!("Server starting at {}", host.unwrap()); + + let addr = host.unwrap().parse()?; + + server + .add_service(dspmc_helper_server::DspmcHelperServer::new(service)) + .serve_with_shutdown(addr, async { + rx.await.ok(); + }) + .await?; + + recv_thread.join().unwrap(); + info!("Bye!"); + Ok(()) +} diff --git a/protocol-rpc/src/rpc/dspmc/partner-server.rs b/protocol-rpc/src/rpc/dspmc/partner-server.rs new file mode 100644 index 0000000..8af486f --- /dev/null +++ b/protocol-rpc/src/rpc/dspmc/partner-server.rs @@ -0,0 +1,179 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +#[macro_use] +extern crate log; +extern crate clap; +extern crate ctrlc; +extern crate tonic; + +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::thread; +use std::time; + +use clap::App; +use clap::Arg; +use clap::ArgGroup; +use log::info; + +mod rpc_client_company; +mod rpc_server_partner; + +use rpc::connect::create_client::create_client; +use rpc::connect::create_server::create_server; +use rpc::proto::gen_dspmc_partner::dspmc_partner_server; +use rpc::proto::RpcClient; + +#[tokio::main] +async fn main() -> Result<(), Box> { + env_logger::init(); + + let matches = App::new("Delegated Private Id MultiKey Partner") + .version("0.1") + .about("Delegated Private Id MultiKey Protocol") + .args(&[ + Arg::with_name("host") + .long("host") + .takes_value(true) + .required(true) + .help("Host path to connect to, ex: 0.0.0.0:10009"), + Arg::with_name("company") + .long("company") + .short("c") + .takes_value(true) + .required(true) + .help("Company host path to connect to, ex: 0.0.0.0:10009"), + Arg::with_name("input-keys") + .long("input-keys") + .default_value("input_keys.csv") + .help("Path to input file with keys"), + Arg::with_name("input-features") + .long("input-features") + .default_value("input_features.csv") + .help("Path to input file with keys"), + Arg::with_name("input-with-headers") + .long("input-with-headers") + .takes_value(false) + .help("Indicates if the input CSV contains headers"), + Arg::with_name("no-tls") + .long("no-tls") + .takes_value(false) + .help("Turns tls off"), + Arg::with_name("tls-dir") + .long("tls-dir") + .takes_value(true) + .help( + "Path to directory with files with key, cert and ca.pem file\n + client: client.key, client.pem, ca.pem \n + server: server.key, server.pem, ca.pem \n", + ), + Arg::with_name("tls-key") + .long("tls-key") + .takes_value(true) + .requires("tls-cert") + .requires("tls-ca") + .help("Path to tls key (non-encrypted)"), + Arg::with_name("tls-cert") + .long("tls-cert") + .takes_value(true) + .requires("tls-key") + .requires("tls-ca") + .help( + "Path to tls certificate (pem format), SINGLE cert, \ + NO CHAINING, required by client as well", + ), + Arg::with_name("tls-ca") + .long("tls-ca") + .takes_value(true) + .requires("tls-key") + .requires("tls-cert") + .help("Path to root CA certificate issued cert and keys"), + Arg::with_name("tls-domain") + .long("tls-domain") + .takes_value(true) + .help("Override TLS domain for SSL cert (if host is IP)"), + ]) + .groups(&[ArgGroup::with_name("tls") + .args(&["no-tls", "tls-dir", "tls-key"]) + .required(true)]) + .get_matches(); + + let input_keys_path = matches.value_of("input-keys").unwrap_or("input_keys.csv"); + let input_features_path = matches + .value_of("input-features") + .unwrap_or("input_features.csv"); + let input_with_headers = matches.is_present("input-with-headers"); + + let no_tls = matches.is_present("no-tls"); + let host = matches.value_of("host"); + let company_host = matches.value_of("company"); + let tls_dir = matches.value_of("tls-dir"); + let tls_key = matches.value_of("tls-key"); + let tls_cert = matches.value_of("tls-cert"); + let tls_ca = matches.value_of("tls-ca"); + let tls_domain = matches.value_of("tls-domain"); + + let company_client_context = { + match create_client( + no_tls, + company_host, + tls_dir, + tls_key, + tls_cert, + tls_ca, + tls_domain, + "dspmc-company".to_string(), + ) { + RpcClient::DspmcCompany(x) => x, + _ => panic!("wrong client"), + } + }; + + let (mut server, tx, rx) = create_server(no_tls, tls_dir, tls_key, tls_cert, tls_ca); + + let running = Arc::new(AtomicBool::new(true)); + let r = running.clone(); + ctrlc::set_handler(move || { + r.store(false, Ordering::SeqCst); + }) + .expect("Error setting Ctrl-C handler"); + + info!("Input path for keys: {}", input_keys_path); + info!("Input path for features: {}", input_features_path); + + let service = rpc_server_partner::DspmcPartnerService::new( + input_keys_path, + input_features_path, + input_with_headers, + company_client_context, + ); + + let ks = service.killswitch.clone(); + let recv_thread = thread::spawn(move || { + let sleep_dur = time::Duration::from_millis(1000); + while !(ks.load(Ordering::Relaxed)) && running.load(Ordering::Relaxed) { + thread::sleep(sleep_dur); + } + + info!("Shutting down server ..."); + tx.send(()).unwrap(); + }); + + info!("Server starting at {}", host.unwrap()); + + let addr = host.unwrap().parse()?; + + server + .add_service(dspmc_partner_server::DspmcPartnerServer::new(service)) + .serve_with_shutdown(addr, async { + rx.await.ok(); + }) + .await?; + + recv_thread.join().unwrap(); + + info!("Bye!"); + Ok(()) +} diff --git a/protocol-rpc/src/rpc/dspmc/rpc_client_company.rs b/protocol-rpc/src/rpc/dspmc/rpc_client_company.rs new file mode 100644 index 0000000..7b08d17 --- /dev/null +++ b/protocol-rpc/src/rpc/dspmc/rpc_client_company.rs @@ -0,0 +1,61 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate common; +extern crate crypto; +extern crate protocol; + +use common::timer; +use crypto::prelude::TPayload; +use rpc::proto::gen_dspmc_company::dspmc_company_client::DspmcCompanyClient; +use rpc::proto::gen_dspmc_company::Commitment; +use rpc::proto::gen_dspmc_company::ServiceResponse; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::send_data; +use tonic::transport::Channel; +use tonic::Request; +use tonic::Response; +use tonic::Status; + +pub async fn send( + data: TPayload, + name: String, + rpc: &mut DspmcCompanyClient, +) -> Result, Status> { + match name.as_str() { + "helper_public_key" => rpc.send_helper_public_key(send_data(data)).await, + "p_sc_v_sc_ct1_ct2_dprime" => rpc.send_p_sc_v_sc_ct1ct2dprime(send_data(data)).await, + _ => panic!("wrong data type"), + } +} + +pub async fn recv( + response: ServiceResponse, + name: String, + data: &mut TPayload, + rpc: &mut DspmcCompanyClient, +) -> Result<(), Status> { + let t = timer::Builder::new().label(name.as_str()).build(); + + let request = Request::new(response); + let mut strm = match name.as_str() { + "company_public_key" => rpc.recv_company_public_key(request).await?.into_inner(), + "p_cs_v_cs" => rpc.recv_p_cs_v_cs(request).await?.into_inner(), + "u_company" => rpc.recv_u_company(request).await?.into_inner(), + _ => panic!("wrong data type"), + }; + + let res = read_from_stream(&mut strm).await?; + t.qps(format!("received {}", name.as_str()).as_str(), res.len()); + data.clear(); + data.extend(res); + Ok(()) +} + +pub async fn calculate_id_map(rpc: &mut DspmcCompanyClient) -> Result<(), Status> { + let _r = rpc + .calculate_id_map(Request::new(Commitment {})) + .await? + .into_inner(); + Ok(()) +} diff --git a/protocol-rpc/src/rpc/dspmc/rpc_client_helper.rs b/protocol-rpc/src/rpc/dspmc/rpc_client_helper.rs new file mode 100644 index 0000000..84a3c98 --- /dev/null +++ b/protocol-rpc/src/rpc/dspmc/rpc_client_helper.rs @@ -0,0 +1,74 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate common; +extern crate crypto; +extern crate protocol; + +use common::timer; +use crypto::prelude::TPayload; +use rpc::proto::gen_dspmc_helper::dspmc_helper_client::DspmcHelperClient; +use rpc::proto::gen_dspmc_helper::Commitment; +use rpc::proto::gen_dspmc_helper::ServiceResponse; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::send_data; +use tonic::transport::Channel; +use tonic::Request; +use tonic::Response; +use tonic::Status; + +pub async fn send( + data: TPayload, + name: String, + rpc: &mut DspmcHelperClient, +) -> Result, Status> { + match name.as_str() { + "company_public_key" => rpc.send_company_public_key(send_data(data)).await, + "p_sd_v_sd" => rpc.send_p_sd_v_sd(send_data(data)).await, + "encrypted_vprime" => rpc.send_encrypted_vprime(send_data(data)).await, + _ => panic!("wrong data type"), + } +} + +pub async fn recv( + response: ServiceResponse, + name: String, + data: &mut TPayload, + rpc: &mut DspmcHelperClient, +) -> Result<(), Status> { + let t = timer::Builder::new().label(name.as_str()).build(); + + let request = Request::new(response); + let mut strm = match name.as_str() { + "helper_public_key" => rpc.recv_helper_public_key(request).await?.into_inner(), + "u2" => rpc.recv_u2(request).await?.into_inner(), + _ => panic!("wrong data type"), + }; + + let res = read_from_stream(&mut strm).await?; + t.qps(format!("received {}", name.as_str()).as_str(), res.len()); + data.clear(); + data.extend(res); + Ok(()) +} + +pub async fn calculate_id_map(rpc: &mut DspmcHelperClient) -> Result<(), Status> { + let _r = rpc + .calculate_id_map(Request::new(Commitment {})) + .await? + .into_inner(); + Ok(()) +} + +pub async fn reveal(rpc: &mut DspmcHelperClient) -> Result<(), Status> { + let _r = rpc.reveal(Request::new(Commitment {})).await?.into_inner(); + Ok(()) +} + +pub async fn stop_service(rpc: &mut DspmcHelperClient) -> Result<(), Status> { + let _r = rpc + .stop_service(Request::new(Commitment {})) + .await? + .into_inner(); + Ok(()) +} diff --git a/protocol-rpc/src/rpc/dspmc/rpc_client_partner.rs b/protocol-rpc/src/rpc/dspmc/rpc_client_partner.rs new file mode 100644 index 0000000..aabe369 --- /dev/null +++ b/protocol-rpc/src/rpc/dspmc/rpc_client_partner.rs @@ -0,0 +1,36 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate common; +extern crate crypto; +extern crate protocol; + +use crypto::prelude::TPayload; +use rpc::proto::gen_dspmc_partner::dspmc_partner_client::DspmcPartnerClient; +use rpc::proto::gen_dspmc_partner::Commitment; +use rpc::proto::gen_dspmc_partner::ServiceResponse; +use rpc::proto::streaming::send_data; +use tonic::transport::Channel; +use tonic::Request; +use tonic::Response; +use tonic::Status; + +pub async fn send( + data: TPayload, + name: String, + rpc: &mut DspmcPartnerClient, +) -> Result, Status> { + match name.as_str() { + "company_public_key" => rpc.send_company_public_key(send_data(data)).await, + "helper_public_key" => rpc.send_helper_public_key(send_data(data)).await, + _ => panic!("wrong data type"), + } +} + +pub async fn stop_service(rpc: &mut DspmcPartnerClient) -> Result<(), Status> { + let _r = rpc + .stop_service(Request::new(Commitment {})) + .await? + .into_inner(); + Ok(()) +} diff --git a/protocol-rpc/src/rpc/dspmc/rpc_server_company.rs b/protocol-rpc/src/rpc/dspmc/rpc_server_company.rs new file mode 100644 index 0000000..80df38d --- /dev/null +++ b/protocol-rpc/src/rpc/dspmc/rpc_server_company.rs @@ -0,0 +1,406 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate common; +extern crate crypto; +extern crate futures; +extern crate protocol; +extern crate tokio; +extern crate tonic; + +use std::borrow::BorrowMut; +use std::convert::TryInto; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; + +use common::timer; +use protocol::dspmc::company::CompanyDspmc; +use protocol::dspmc::traits::CompanyDspmcProtocol; +use protocol::shared::TFeatures; +use rpc::proto::common::Payload; +use rpc::proto::gen_dspmc_company::dspmc_company_server::DspmcCompany; +use rpc::proto::gen_dspmc_company::service_response::*; +use rpc::proto::gen_dspmc_company::Commitment; +use rpc::proto::gen_dspmc_company::CommitmentAck; +use rpc::proto::gen_dspmc_company::HelperPublicKeyAck; +use rpc::proto::gen_dspmc_company::Init; +use rpc::proto::gen_dspmc_company::InitAck; +use rpc::proto::gen_dspmc_company::RecvShares; +use rpc::proto::gen_dspmc_company::RecvSharesAck; +use rpc::proto::gen_dspmc_company::SendData; +use rpc::proto::gen_dspmc_company::SendDataAck; +use rpc::proto::gen_dspmc_company::ServiceResponse; +use rpc::proto::gen_dspmc_company::UPartnerAck; +use rpc::proto::gen_dspmc_helper::dspmc_helper_client::DspmcHelperClient; +use rpc::proto::gen_dspmc_helper::ServiceResponse as HelperServiceResponse; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::send_data; +use rpc::proto::streaming::write_to_stream; +use rpc::proto::streaming::TPayloadStream; +use tonic::transport::Channel; +use tonic::Code; +use tonic::Request; +use tonic::Response; +use tonic::Status; +use tonic::Streaming; + +pub struct DspmcCompanyService { + protocol: CompanyDspmc, + input_path: String, + output_keys_path: Option, + output_shares_path: Option, + input_with_headers: bool, + helper_client_context: DspmcHelperClient, + pub killswitch: Arc, +} + +impl DspmcCompanyService { + pub fn new( + input_path: &str, + output_keys_path: Option<&str>, + output_shares_path: Option<&str>, + input_with_headers: bool, + helper_client_context: DspmcHelperClient, + ) -> DspmcCompanyService { + DspmcCompanyService { + protocol: CompanyDspmc::new(), + input_path: String::from(input_path), + output_keys_path: output_keys_path.map(String::from), + output_shares_path: output_shares_path.map(String::from), + input_with_headers, + helper_client_context, + killswitch: Arc::new(AtomicBool::new(false)), + } + } +} + +#[tonic::async_trait] +impl DspmcCompany for DspmcCompanyService { + type RecvUCompanyStream = TPayloadStream; + type RecvPCsVCsStream = TPayloadStream; + type RecvCompanyPublicKeyStream = TPayloadStream; + + async fn initialize(&self, _: Request) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("init") + .build(); + self.protocol + .load_data(&self.input_path, self.input_with_headers); + Ok(Response::new(ServiceResponse { + ack: Some(Ack::InitAck(InitAck {})), + })) + } + + async fn send_ct3_p_cd_v_cd_to_helper( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_ct3_p_cd_v_cd_to_helper") + .build(); + + self.protocol.gen_permutations(); + + // Send ct3 from all partners to helper. - company acts as a client to helper. + let partners_ct3 = self.protocol.get_all_ct3_p_cd_v_cd().unwrap(); + let mut helper_client_contxt = self.helper_client_context.clone(); + _ = helper_client_contxt + .send_ct3_p_cd_v_cd(send_data(partners_ct3)) + .await; + + Ok(Response::new(ServiceResponse { + ack: Some(Ack::SendDataAck(SendDataAck {})), + })) + } + + async fn send_u1_to_helper( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_u1_to_helper") + .build(); + + // Send u1 to helper. - company acts as a client to helper. + let u1 = self.protocol.get_u1().unwrap(); + let mut helper_client_contxt = self.helper_client_context.clone(); + _ = helper_client_contxt.send_u1(send_data(u1)).await; + + Ok(Response::new(ServiceResponse { + ack: Some(Ack::SendDataAck(SendDataAck {})), + })) + } + + async fn send_encrypted_keys_to_helper( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_encrypted_keys_to_helper") + .build(); + + // Send ct1, ct2', and X to helper. - company acts as a client to helper. + // H(C)^c + let mut enc_keys = self.protocol.get_company_keys().unwrap(); + let ct1_ct2 = self.protocol.get_ct1_ct2().unwrap(); + enc_keys.extend(ct1_ct2); + + let mut helper_client_contxt = self.helper_client_context.clone(); + // X, offset, metadata, ct1, ct2, offset, metadata + _ = helper_client_contxt + .send_encrypted_keys(send_data(enc_keys)) + .await; + + Ok(Response::new(ServiceResponse { + ack: Some(Ack::SendDataAck(SendDataAck {})), + })) + } + + async fn recv_shares_from_helper( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("recv_shares_from_helper") + .build(); + + let mut helper_client_contxt = self.helper_client_context.clone(); + let request = Request::new(HelperServiceResponse { + ack: Some( + rpc::proto::gen_dspmc_helper::service_response::Ack::UPartnerAck( + rpc::proto::gen_dspmc_helper::UPartnerAck {}, + ), + ), + }); + let mut strm = helper_client_contxt + .recv_xor_shares(request) + .await? + .into_inner(); + let mut data = read_from_stream(&mut strm).await?; + + let num_features = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let num_rows = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let g_zi = data + .drain(num_features * num_rows..) + .map(|x| x) + .collect::>(); + + let mut features = TFeatures::new(); + for i in (0..num_features).rev() { + let x = data + .drain(i * num_rows..) + .map(|x| u64::from_le_bytes(x.buffer.as_slice().try_into().unwrap())) + .collect::>(); + features.push(x); + } + + _ = self.protocol.calculate_features_xor_shares(features, g_zi); + + // Print Company's ID spine and save partners shares + match &self.output_keys_path { + Some(p) => self.protocol.save_id_map(p).unwrap(), + None => self.protocol.print_id_map(), + } + + let resp = self + .protocol + .save_features_shares(&self.output_shares_path.clone().unwrap()) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::RecvSharesAck(RecvSharesAck {})), + }) + }) + .map_err(|_| Status::internal("error saving feature shares")); + { + debug!("Setting up flag for graceful down"); + self.killswitch.store(true, Ordering::SeqCst); + } + + resp + } + + async fn calculate_id_map( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("calculate_id_map") + .build(); + self.protocol + .write_company_to_id_map() + .map(|_| Response::new(CommitmentAck {})) + .map_err(|_| Status::new(Code::Aborted, "cannot init the protocol for partner")) + } + + async fn recv_company_public_key( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("recv_company_public_key") + .build(); + self.protocol + .get_company_public_key() + .map(write_to_stream) + .map_err(|_| Status::new(Code::Aborted, "cannot send company_public_key")) + } + + async fn recv_p_cs_v_cs( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("recv_p_cs_v_cs") + .build(); + self.protocol + .get_p_cs_v_cs() + .map(write_to_stream) + .map_err(|_| Status::new(Code::Aborted, "cannot send p_cs_v_cs")) + } + + async fn recv_u_company( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("recv_u_company") + .build(); + self.protocol + .get_company_keys() + .map(write_to_stream) + .map_err(|_| Status::new(Code::Aborted, "cannot send u_company")) + } + + async fn send_p_sc_v_sc_ct1ct2dprime( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_p_sc_v_sc_ct1ct2dprime") + .build(); + let mut data = read_from_stream(request.into_inner().borrow_mut()).await?; + + let offset_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + // flattened len + let data_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let num_keys = offset_len - 1; + + let offset = data + .drain((num_keys * 2 + data_len * 2)..) + .map(|b| u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize) + .collect::>(); + assert_eq!(offset_len, offset.len()); + + let ct2_dprime_flat = data.drain((data.len() - data_len)..).collect::>(); + let ct1_dprime_flat = data.drain((data.len() - data_len)..).collect::>(); + + let v_sc_bytes = data.drain((data.len() - num_keys)..).collect::>(); + data.shrink_to_fit(); // p_sc + + self.protocol + .set_p_sc_v_sc_ct1ct2dprime(v_sc_bytes, data, ct1_dprime_flat, ct2_dprime_flat, offset) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::UPartnerAck(UPartnerAck {})), + }) + }) + .map_err(|_| Status::internal("error loading")) + } + + async fn send_u_partner( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_u_partner") + .build(); + let mut data = read_from_stream(request.into_inner().borrow_mut()).await?; + + let ct3 = data.pop().unwrap(); + + let num_features = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let num_rows = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + + let mut v_prime = data + .drain((data.len() - (num_features * num_rows))..) + .collect::>(); + + let mut xor_features = TFeatures::new(); + for i in (0..num_features).rev() { + let x = v_prime + .drain(i * num_rows..) + .map(|x| u64::from_le_bytes(x.buffer.as_slice().try_into().unwrap())) + .collect::>(); + xor_features.push(x); + } + + let offset_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let data_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + + let offset = data + .drain((data_len * 2)..) + .map(|b| u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize) + .collect::>(); + + // ct2 = pkd^r * H(P) + // ct1 = pkc^r + data.shrink_to_fit(); + let (ct2, ct1) = data.split_at(data_len); + + assert_eq!(offset_len, offset.len()); + + self.protocol + .set_encrypted_partner_keys_and_shares( + ct1.to_vec(), + ct2.to_vec(), + offset, + ct3.buffer, + xor_features, + ) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::UPartnerAck(UPartnerAck {})), + }) + }) + .map_err(|_| Status::internal("error loading")) + } + + async fn send_helper_public_key( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_helper_public_key") + .build(); + let mut strm = request.into_inner(); + self.protocol + .set_helper_public_key(read_from_stream(&mut strm).await?) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::HelperPublicKeyAck(HelperPublicKeyAck {})), + }) + }) + .map_err(|_| Status::internal("error writing")) + } +} diff --git a/protocol-rpc/src/rpc/dspmc/rpc_server_helper.rs b/protocol-rpc/src/rpc/dspmc/rpc_server_helper.rs new file mode 100644 index 0000000..81dfc5d --- /dev/null +++ b/protocol-rpc/src/rpc/dspmc/rpc_server_helper.rs @@ -0,0 +1,360 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate common; +extern crate crypto; +extern crate futures; +extern crate protocol; +extern crate tokio; +extern crate tonic; + +use std::borrow::BorrowMut; +use std::convert::TryInto; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; + +use common::timer; +use protocol::dspmc::helper::HelperDspmc; +use protocol::dspmc::traits::HelperDspmcProtocol; +use protocol::shared::TFeatures; +use rpc::proto::common::Payload; +use rpc::proto::gen_dspmc_helper::dspmc_helper_server::DspmcHelper; +use rpc::proto::gen_dspmc_helper::service_response::*; +use rpc::proto::gen_dspmc_helper::Commitment; +use rpc::proto::gen_dspmc_helper::CommitmentAck; +use rpc::proto::gen_dspmc_helper::CompanyPublicKeyAck; +use rpc::proto::gen_dspmc_helper::EHelperAck; +use rpc::proto::gen_dspmc_helper::ServiceResponse; +use rpc::proto::gen_dspmc_helper::UPartnerAck; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::write_to_stream; +use rpc::proto::streaming::TPayloadStream; +use tonic::Code; +use tonic::Request; +use tonic::Response; +use tonic::Status; +use tonic::Streaming; + +pub struct DspmcHelperService { + protocol: HelperDspmc, + output_keys_path: Option, + output_shares_path: Option, + pub killswitch: Arc, +} + +impl DspmcHelperService { + pub fn new( + output_keys_path: Option<&str>, + output_shares_path: Option<&str>, + ) -> DspmcHelperService { + DspmcHelperService { + protocol: HelperDspmc::new(), + output_keys_path: output_keys_path.map(String::from), + output_shares_path: output_shares_path.map(String::from), + killswitch: Arc::new(AtomicBool::new(false)), + } + } +} + +#[tonic::async_trait] +impl DspmcHelper for DspmcHelperService { + type RecvHelperPublicKeyStream = TPayloadStream; + type RecvXorSharesStream = TPayloadStream; + type RecvU2Stream = TPayloadStream; + + async fn send_company_public_key( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_company_public_key") + .build(); + let mut strm = request.into_inner(); + self.protocol + .set_company_public_key(read_from_stream(&mut strm).await?) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::CompanyPublicKeyAck(CompanyPublicKeyAck {})), + }) + }) + .map_err(|_| Status::internal("error writing")) + } + + async fn calculate_id_map( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("calculate_id_map") + .build(); + _ = self.protocol.calculate_set_diff(); + self.protocol.calculate_id_map(); + Ok(Response::new(CommitmentAck {})) + } + + async fn recv_helper_public_key( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("recv_helper_public_key") + .build(); + self.protocol + .get_helper_public_key() + .map(write_to_stream) + .map_err(|_| Status::new(Code::Aborted, "cannot send helper_public_key")) + } + + async fn recv_xor_shares( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("calculate_features_xor_shares") + .build(); + self.protocol + .calculate_features_xor_shares() // returns v_d_prime + .map(write_to_stream) + .map_err(|_| Status::new(Code::Aborted, "cannot send xor shares")) + } + + async fn recv_u2( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("recv_u2") + .build(); + self.protocol + .get_u2() + .map(write_to_stream) + .map_err(|_| Status::new(Code::Aborted, "cannot send u2")) + } + + // Gets ct3s from all partners as well as permutation p_cd and blinding v_cd. + async fn send_ct3_p_cd_v_cd( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_ct3_p_cd_v_cd") + .build(); + let mut data = read_from_stream(request.into_inner().borrow_mut()).await?; + + let data_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + + let v_cd_bytes = data.drain((data.len() - data_len)..).collect::>(); + let p_cd_bytes = data.drain((data.len() - data_len)..).collect::>(); + + let num_partners = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + + self.protocol + .set_ct3p_cd_v_cd(data, num_partners, v_cd_bytes, p_cd_bytes) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::EHelperAck(EHelperAck {})), + }) + }) + .map_err(|_| Status::internal("error loading")) + } + + async fn send_p_sd_v_sd( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_p_sd_v_sd") + .build(); + let mut data = read_from_stream(request.into_inner().borrow_mut()).await?; + + let data_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + + let v_sd_bytes = data.drain((data.len() - data_len)..).collect::>(); + data.shrink_to_fit(); + + self.protocol + .set_p_sd_v_sd(v_sd_bytes, data) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::UPartnerAck(UPartnerAck {})), + }) + }) + .map_err(|_| Status::internal("error loading")) + } + + async fn send_encrypted_vprime( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_encrypted_vprime") + .build(); + let mut data = read_from_stream(request.into_inner().borrow_mut()).await?; + + let num_features = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let num_rows = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let g_zi = data + .drain(num_features * num_rows..) + .map(|x| x) + .collect::>(); + + let mut blinded_features = TFeatures::new(); + for i in (0..num_features).rev() { + let x = data + .drain(i * num_rows..) + .map(|x| u64::from_le_bytes(x.buffer.as_slice().try_into().unwrap())) + .collect::>(); + blinded_features.push(x); + } + + self.protocol + .set_encrypted_vprime(blinded_features, g_zi) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::UPartnerAck(UPartnerAck {})), + }) + }) + .map_err(|_| Status::internal("error loading")) + } + + async fn send_u1( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_u1") + .build(); + let mut data = read_from_stream(request.into_inner().borrow_mut()).await?; + + let num_features = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let num_rows = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + data.shrink_to_fit(); + + let mut u1 = TFeatures::new(); + for i in (0..num_features).rev() { + let x = data + .drain(i * num_rows..) + .map(|x| u64::from_le_bytes(x.buffer.as_slice().try_into().unwrap())) + .collect::>(); + u1.push(x); + } + + self.protocol + .set_u1(u1) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::UPartnerAck(UPartnerAck {})), + }) + }) + .map_err(|_| Status::internal("error loading")) + } + + async fn send_encrypted_keys( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_encrypted_keys") + .build(); + + // X, offset, metadata, ct1, ct2, offset, metadata + let mut data = read_from_stream(request.into_inner().borrow_mut()).await?; + + let ct_offset_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + // flattened len + let ct_data_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + // let num_keys = ct_offset_len - 1; + + let ct_offset = data + .drain((data.len() - ct_offset_len)..) + .map(|b| u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize) + .collect::>(); + assert_eq!(ct_offset_len, ct_offset.len()); + + let ct2_flat = data.drain((data.len() - ct_data_len)..).collect::>(); + let ct1_flat = data.drain((data.len() - ct_data_len)..).collect::>(); + + // H(C)*c + let offset_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + // flattened len + let data_len = + u64::from_le_bytes(data.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + // let num_keys = offset_len - 1; + + let offset = data + .drain(data_len..) + .map(|b| u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize) + .collect::>(); + assert_eq!(offset_len, offset.len()); + data.shrink_to_fit(); + + self.protocol + .set_encrypted_keys(data, offset, ct1_flat, ct2_flat, ct_offset) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::UPartnerAck(UPartnerAck {})), + }) + }) + .map_err(|_| Status::internal("error loading")) + } + + async fn reveal(&self, _: Request) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("reveal") + .build(); + match &self.output_keys_path { + Some(p) => self.protocol.save_id_map(&String::from(p)).unwrap(), + None => self.protocol.print_id_map(), + } + + let resp = self + .protocol + .save_features_shares(&self.output_shares_path.clone().unwrap()) + .map(|_| Response::new(CommitmentAck {})) + .map_err(|_| Status::internal("error saving feature shares")); + { + debug!("Setting up flag for graceful down"); + self.killswitch.store(true, Ordering::SeqCst); + } + + resp + } + + async fn stop_service( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("stop") + .build(); + { + debug!("Setting up flag for graceful down"); + self.killswitch.store(true, Ordering::SeqCst); + } + + Ok(Response::new(CommitmentAck {})) + } +} diff --git a/protocol-rpc/src/rpc/dspmc/rpc_server_partner.rs b/protocol-rpc/src/rpc/dspmc/rpc_server_partner.rs new file mode 100644 index 0000000..aa6c1f6 --- /dev/null +++ b/protocol-rpc/src/rpc/dspmc/rpc_server_partner.rs @@ -0,0 +1,164 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate common; +extern crate crypto; +extern crate futures; +extern crate protocol; +extern crate tokio; +extern crate tonic; + +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::sync::Arc; + +use common::timer; +use protocol::dspmc::partner::PartnerDspmc; +use protocol::dspmc::traits::PartnerDspmcProtocol; +use rpc::proto::common::Payload; +use rpc::proto::gen_dspmc_company::dspmc_company_client::DspmcCompanyClient; +use rpc::proto::gen_dspmc_partner::dspmc_partner_server::DspmcPartner; +use rpc::proto::gen_dspmc_partner::service_response::*; +use rpc::proto::gen_dspmc_partner::Commitment; +use rpc::proto::gen_dspmc_partner::CommitmentAck; +use rpc::proto::gen_dspmc_partner::CompanyPublicKeyAck; +use rpc::proto::gen_dspmc_partner::HelperPublicKeyAck; +use rpc::proto::gen_dspmc_partner::Init; +use rpc::proto::gen_dspmc_partner::InitAck; +use rpc::proto::gen_dspmc_partner::SendData; +use rpc::proto::gen_dspmc_partner::SendDataAck; +use rpc::proto::gen_dspmc_partner::ServiceResponse; +use rpc::proto::streaming::read_from_stream; +use rpc::proto::streaming::send_data; +use tonic::transport::Channel; +use tonic::Request; +use tonic::Response; +use tonic::Status; +use tonic::Streaming; + +pub struct DspmcPartnerService { + protocol: PartnerDspmc, + input_keys_path: String, + input_features_path: String, + input_with_headers: bool, + company_client_context: DspmcCompanyClient, + pub killswitch: Arc, +} + +impl DspmcPartnerService { + pub fn new( + input_keys_path: &str, + input_features_path: &str, + input_with_headers: bool, + company_client_context: DspmcCompanyClient, + ) -> DspmcPartnerService { + DspmcPartnerService { + protocol: PartnerDspmc::new(), + input_keys_path: String::from(input_keys_path), + input_features_path: String::from(input_features_path), + input_with_headers, + company_client_context, + killswitch: Arc::new(AtomicBool::new(false)), + } + } +} + +#[tonic::async_trait] +impl DspmcPartner for DspmcPartnerService { + async fn initialize(&self, _: Request) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("init") + .build(); + self.protocol.load_data( + &self.input_keys_path, + &self.input_features_path, + self.input_with_headers, + ); + Ok(Response::new(ServiceResponse { + ack: Some(Ack::InitAck(InitAck {})), + })) + } + + async fn send_data_to_company( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("init") + .build(); + + // Send partner data to company. - Partner acts as a client to company. + let mut ct1_ct2 = self.protocol.get_encrypted_keys().unwrap(); + let mut company_client_contxt = self.company_client_context.clone(); + + // XOR shares + metadata + ct3 + let xor_shares = self.protocol.get_features_xor_shares().unwrap(); + + ct1_ct2.extend(xor_shares); + + // ct2 + ct1 + offset + XOR shares + metadata + ct3 + _ = company_client_contxt + .send_u_partner(send_data(ct1_ct2)) + .await; + + Ok(Response::new(ServiceResponse { + ack: Some(Ack::SendDataAck(SendDataAck {})), + })) + } + + async fn send_company_public_key( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_company_public_key") + .build(); + let mut strm = request.into_inner(); + self.protocol + .set_company_public_key(read_from_stream(&mut strm).await?) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::CompanyPublicKeyAck(CompanyPublicKeyAck {})), + }) + }) + .map_err(|_| Status::internal("error writing")) + } + + async fn send_helper_public_key( + &self, + request: Request>, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("send_helper_public_key") + .build(); + let mut strm = request.into_inner(); + self.protocol + .set_helper_public_key(read_from_stream(&mut strm).await?) + .map(|_| { + Response::new(ServiceResponse { + ack: Some(Ack::HelperPublicKeyAck(HelperPublicKeyAck {})), + }) + }) + .map_err(|_| Status::internal("error writing")) + } + + async fn stop_service( + &self, + _: Request, + ) -> Result, Status> { + let _ = timer::Builder::new() + .label("server") + .extra_label("stop") + .build(); + { + debug!("Setting up flag for graceful down"); + self.killswitch.store(true, Ordering::SeqCst); + } + + Ok(Response::new(CommitmentAck {})) + } +} diff --git a/protocol/Cargo.toml b/protocol/Cargo.toml index 52a6feb..547435a 100644 --- a/protocol/Cargo.toml +++ b/protocol/Cargo.toml @@ -21,9 +21,11 @@ itertools = "0.9.0" rand = "0.8" rand_core = "0.5.1" hex = "0.4.2" -rayon = "1.3.0" +rayon = "1.8.0" num-bigint = { version = "0.4", features = ["rand"] } num-traits = "0.2" zeroize = "1.5.5" tempfile = "3.2.0" mockall = "0.10.2" +fernet = "0.2.1" +base64 = "0.13" diff --git a/protocol/src/dpmc/company.rs b/protocol/src/dpmc/company.rs new file mode 100644 index 0000000..3a73f96 --- /dev/null +++ b/protocol/src/dpmc/company.rs @@ -0,0 +1,373 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate csv; + +use std::collections::HashMap; +use std::collections::VecDeque; +use std::convert::TryInto; +use std::path::Path; +use std::sync::Arc; +use std::sync::RwLock; + +use common::permutations::gen_permute_pattern; +use common::permutations::permute; +use common::permutations::undo_permute; +use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; +use itertools::Itertools; + +use super::load_data_keys; +use super::serialize_helper; +use super::writer_helper; +use super::ProtocolError; +use crate::dpmc::traits::CompanyDpmcProtocol; +use crate::shared::TFeatures; + +#[derive(Debug)] +struct PartnerData { + enc_alpha_t: Vec, + scalar_g: Vec, + partner_enc_shares: Vec, + e_partner: Vec>, +} + +#[derive(Debug)] +pub struct CompanyDpmc { + keypair_sk: Scalar, + keypair_pk: TPoint, + private_beta: Scalar, + ec_cipher: ECRistrettoParallel, + // TODO: consider using dyn pid::crypto::ECCipher trait? + plaintext: Arc>>>, + permutation: Arc>>, + h_k_beta_company: Arc>>>, + partners_queue: Arc>>, + id_map: Arc>>, + partner_shares: Arc>>>, +} + +impl CompanyDpmc { + pub fn new() -> CompanyDpmc { + let x = gen_scalar(); + CompanyDpmc { + keypair_sk: x, + keypair_pk: &x * &RISTRETTO_BASEPOINT_TABLE, + private_beta: gen_scalar(), + ec_cipher: ECRistrettoParallel::default(), + plaintext: Arc::new(RwLock::default()), + permutation: Arc::new(RwLock::default()), + h_k_beta_company: Arc::new(RwLock::default()), + partners_queue: Arc::new(RwLock::default()), + id_map: Arc::new(RwLock::default()), + partner_shares: Arc::new(RwLock::default()), + } + } + + pub fn get_company_public_key(&self) -> Result { + Ok(self.ec_cipher.to_bytes(&[self.keypair_pk])) + } + + pub fn load_data(&self, path: &str, input_with_headers: bool) { + load_data_keys(self.plaintext.clone(), path, input_with_headers); + } +} + +impl Default for CompanyDpmc { + fn default() -> Self { + Self::new() + } +} + +impl CompanyDpmcProtocol for CompanyDpmc { + fn set_encrypted_partner_keys_and_shares( + &self, + data: TPayload, + psum: Vec, + enc_alpha_t: Vec, + scalar_g: Vec, + xor_shares: TPayload, + ) -> Result<(), ProtocolError> { + match (self.partners_queue.clone().write(),) { + (Ok(mut partners_queue),) => { + let t = timer::Timer::new_silent("load_e_partner"); + // This is an array of exclusive-inclusive prefix sum - hence + // number of keys is one less than length + let num_keys = psum.len() - 1; + + // Unflatten + let pdata = { + let t = self.ec_cipher.to_points_encrypt(&data, &self.private_beta); + + psum.get(0..num_keys) + .unwrap() + .iter() + .zip_eq(psum.get(1..num_keys + 1).unwrap().iter()) + .map(|(&x1, &x2)| t.get(x1..x2).unwrap().to_vec()) + .collect::>>() + }; + + t.qps("deserialize_exp", pdata.len()); + + partners_queue.push_back(PartnerData { + enc_alpha_t, + scalar_g, + partner_enc_shares: xor_shares, + e_partner: pdata, + }); + + Ok(()) + } + _ => { + error!("Cannot load e_partner"); + Err(ProtocolError::ErrorDeserialization( + "cannot load e_partner".to_string(), + )) + } + } + } + + fn get_permuted_keys(&self) -> Result { + match ( + self.plaintext.clone().read(), + self.h_k_beta_company.clone().write(), + self.permutation.clone().write(), + ) { + (Ok(pdata), Ok(mut edata), Ok(mut permutation)) => { + let t = timer::Timer::new_silent("u_company"); + + permutation.clear(); + permutation.extend(gen_permute_pattern(pdata.len())); + + // Permute, flatten, encrypt + let (mut d_flat, mut offset, metadata) = { + let mut d = pdata.clone(); + permute(permutation.as_slice(), &mut d); + + let (d_flat, offset, metadata) = serialize_helper(d); + + // Encrypt + let x = self + .ec_cipher + .hash_encrypt(d_flat.as_slice(), &self.private_beta); + + (x, offset, metadata) + }; + + // Unflatten and set encrypted keys + { + let psum = offset + .iter() + .map(|b| { + u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize + }) + .collect::>(); + + let num_keys = psum.len() - 1; + let mut x = psum + .get(0..num_keys) + .unwrap() + .iter() + .zip_eq(psum.get(1..num_keys + 1).unwrap().iter()) + .map(|(&x1, &x2)| d_flat.get(x1..x2).unwrap().to_vec()) + .collect::>>(); + + edata.clear(); + edata.extend(x.drain(..)); + } + + t.qps("encryption", d_flat.len()); + + // Serialize + let buf = { + let mut x = self.ec_cipher.to_bytes(d_flat.as_slice()); + + d_flat.clear(); + d_flat.shrink_to_fit(); + + offset.extend(metadata); + x.extend(offset); + x + }; + + Ok(buf) + } + _ => { + error!("Unable to encrypt UCompany:"); + Err(ProtocolError::ErrorEncryption( + "cannot encrypt UCompany".to_string(), + )) + } + } + } + + fn serialize_encrypted_keys_and_features(&self) -> Result { + match (self.partners_queue.clone().write(),) { + (Ok(mut partners_data_q),) => { + let t = timer::Timer::new_silent("e_partner"); + + let partner_data: PartnerData = partners_data_q.pop_front().unwrap(); + let pdata = partner_data.e_partner; + let enc_a_t = partner_data.enc_alpha_t; + let scalar_g = partner_data.scalar_g; + let enc_shares = partner_data.partner_enc_shares; + + let (mut d_flat, offset) = { + let (d_flat, mut offset, metadata) = serialize_helper(pdata.to_vec()); + offset.extend(metadata); + + // Serialize + (self.ec_cipher.to_bytes(&d_flat), offset) + }; + + t.qps("encryption", d_flat.len()); + + // Append offsets array + d_flat.extend(offset); + + // Append encrypted key alpha + d_flat.push(ByteBuffer { + buffer: enc_a_t.to_vec(), + }); + + d_flat.push(ByteBuffer { + buffer: scalar_g.to_vec(), + }); + + // Append offsets array + d_flat.extend(enc_shares.clone()); + d_flat.push(ByteBuffer { + buffer: (enc_shares.len() as u64).to_le_bytes().to_vec(), + }); + + Ok(d_flat) + } + _ => { + error!("Unable to flatten e_partner:"); + Err(ProtocolError::ErrorEncryption( + "cannot flatten e_partner".to_string(), + )) + } + } + } + + fn calculate_features_xor_shares( + &self, + partner_features: TFeatures, + p_mask_d: TPayload, + ) -> Result<(), ProtocolError> { + match self.partner_shares.clone().write() { + Ok(mut shares) => { + let n_features = partner_features.len(); + let p_mask = self.ec_cipher.to_points(&p_mask_d); + + let mask = p_mask + .iter() + .map(|x| { + let t = self.ec_cipher.to_bytes(&[x * self.keypair_sk]); + u64::from_le_bytes((t[0].buffer[0..8]).try_into().unwrap()) + }) + .collect::>(); + + for f_idx in 0..n_features { + let s = partner_features[f_idx] + .iter() + .zip_eq(mask.iter()) + .map(|(x1, x2)| *x1 ^ *x2) + .collect::>(); + shares.insert(f_idx, s); + } + + Ok(()) + } + _ => { + error!("Unable to calculate XOR shares"); + Err(ProtocolError::ErrorEncryption( + "unable to calculate XOR shares".to_string(), + )) + } + } + } + + fn write_company_to_id_map(&self) -> Result<(), ProtocolError> { + match ( + self.h_k_beta_company.clone().read(), + self.permutation.clone().read(), + self.id_map.clone().write(), + ) { + (Ok(pdata), Ok(permutation), Ok(mut id_map)) => { + let mut company_ragged = pdata.clone(); + undo_permute(permutation.as_slice(), &mut company_ragged); + + // Get the first column. + let company_keys = { + let tmp = company_ragged.iter().map(|s| s[0]).collect::>(); + self.ec_cipher.to_bytes(tmp.as_slice()) + }; + + id_map.clear(); + for (idx, k) in company_keys.iter().enumerate() { + id_map.push((k.to_string(), idx, true)); + } + + // Sort the id_map by the spine + id_map.sort_by(|(a, _, _), (b, _, _)| a.cmp(b)); + + Ok(()) + } + _ => { + error!("Cannot create id_map"); + Err(ProtocolError::ErrorDeserialization( + "cannot create id_map".to_string(), + )) + } + } + } + + fn print_id_map(&self) { + match (self.plaintext.clone().read(), self.id_map.clone().read()) { + (Ok(data), Ok(id_map)) => { + writer_helper(&data, &id_map, None); + } + _ => panic!("Cannot print id_map"), + } + } + + fn save_id_map(&self, path: &str) -> Result<(), ProtocolError> { + match (self.plaintext.clone().read(), self.id_map.clone().read()) { + (Ok(data), Ok(id_map)) => { + writer_helper(&data, &id_map, Some(path.to_string())); + Ok(()) + } + _ => Err(ProtocolError::ErrorIO( + "Unable to write company view to file".to_string(), + )), + } + } + + fn save_features_shares(&self, path_prefix: &str) -> Result<(), ProtocolError> { + match self.partner_shares.clone().read() { + Ok(shares) => { + assert!(shares.len() > 0); + + let mut out: Vec> = Vec::new(); + + for key in shares.keys().sorted() { + out.push(shares.get(key).unwrap().clone()); + } + + let p_filename = format!("{}{}", path_prefix, "_partner_features.csv"); + info!("revealing partner features to output file"); + common::files::write_u64cols_to_file(&mut out, Path::new(&p_filename)).unwrap(); + + Ok(()) + } + _ => Err(ProtocolError::ErrorIO( + "Unable to write company shares of partner features to file".to_string(), + )), + } + } +} diff --git a/protocol/src/dpmc/helper.rs b/protocol/src/dpmc/helper.rs new file mode 100644 index 0000000..c8fae65 --- /dev/null +++ b/protocol/src/dpmc/helper.rs @@ -0,0 +1,674 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate csv; + +use std::collections::HashMap; +use std::collections::HashSet; +use std::convert::TryInto; +use std::path::Path; +use std::sync::Arc; +use std::sync::RwLock; + +use common::permutations::gen_permute_pattern; +use common::permutations::permute; +use common::permutations::undo_permute; +use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; +use fernet::Fernet; +use itertools::Itertools; +use rand::distributions::Uniform; +use rand::Rng; +use rayon::iter::ParallelDrainRange; +use rayon::iter::ParallelIterator; + +use super::writer_helper_dpmc; +use super::ProtocolError; +use crate::dpmc::traits::HelperDpmcProtocol; +use crate::shared::TFeatures; + +#[derive(Debug)] +struct PartnerData { + h_b_partner: Vec>, + features: TFeatures, + g_zi: Vec, +} + +struct SetDiff { + s_company: HashSet, + s_partner: HashSet, +} + +pub struct HelperDpmc { + keypair_sk: Scalar, + keypair_pk: TPoint, + company_public_key: Arc>, + ec_cipher: ECRistrettoParallel, + self_permutation: Arc>>, + partners_data: Arc>>, + set_diffs: Arc>>, + h_company_beta: Arc>>>, + partner_shares: Arc>>>, + id_map: Arc>>, +} + +impl HelperDpmc { + pub fn new() -> HelperDpmc { + let x = gen_scalar(); + HelperDpmc { + keypair_sk: x, + keypair_pk: &x * &RISTRETTO_BASEPOINT_TABLE, + company_public_key: Arc::new(RwLock::default()), + ec_cipher: ECRistrettoParallel::default(), + self_permutation: Arc::new(RwLock::default()), + partners_data: Arc::new(RwLock::default()), + set_diffs: Arc::new(RwLock::default()), + h_company_beta: Arc::new(RwLock::default()), + partner_shares: Arc::new(RwLock::default()), + id_map: Arc::new(RwLock::default()), + } + } + + pub fn set_company_public_key( + &self, + company_public_key: TPayload, + ) -> Result<(), ProtocolError> { + let pk = self.ec_cipher.to_points(&company_public_key); + // Check that one key is sent + assert_eq!(pk.len(), 1); + + match self.company_public_key.clone().write() { + Ok(mut company_pk) => { + *company_pk = pk[0]; + assert!(!(*company_pk).is_identity()); + Ok(()) + } + _ => { + error!("Unable to set company public key"); + Err(ProtocolError::ErrorEncryption( + "unable to set company public key".to_string(), + )) + } + } + } + + pub fn get_helper_public_key(&self) -> Result { + Ok(self.ec_cipher.to_bytes(&[self.keypair_pk])) + } +} + +impl Default for HelperDpmc { + fn default() -> Self { + Self::new() + } +} + +fn decrypt_shares(mut enc_t: TPayload, aes_key: String) -> (TFeatures, TPayload) { + let mut t = { + let fernet = Fernet::new(&aes_key).unwrap(); + enc_t + .par_drain(..) + .map(|x| { + let ctxt_str = String::from_utf8(x.buffer).unwrap(); + ByteBuffer { + buffer: fernet.decrypt(&ctxt_str).unwrap().to_vec(), + } + }) + .collect::>() + }; + + let num_features = + u64::from_le_bytes(t.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let num_rows = + u64::from_le_bytes(t.pop().unwrap().buffer.as_slice().try_into().unwrap()) as usize; + let g_zi = t.drain(num_features * num_rows..).collect::>(); + + let mut features = TFeatures::new(); + for i in (0..num_features).rev() { + let x = t + .drain(i * num_rows..) + .map(|x| u64::from_le_bytes(x.buffer.as_slice().try_into().unwrap())) + .collect::>(); + features.push(x); + } + + (features, g_zi) +} + +impl HelperDpmcProtocol for HelperDpmc { + fn remove_partner_scalar_from_p_and_set_shares( + &self, + data: TPayload, + psum: Vec, + enc_alpha_t: Vec, + p_scalar_g: TPayload, + xor_shares: TPayload, + ) -> Result<(), ProtocolError> { + match ( + self.partners_data.clone().write(), + self.set_diffs.clone().write(), + ) { + (Ok(mut partners_data), Ok(mut set_diffs)) => { + let t = timer::Timer::new_silent("load_h_b_partner"); + + let aes_key = { + let aes_key_bytes = { + let x = self + .ec_cipher + .to_points_encrypt(&p_scalar_g, &self.keypair_sk); + let y = self.ec_cipher.to_bytes(&x); + y[0].buffer.clone() + }; + base64::encode_config(aes_key_bytes, base64::URL_SAFE) + }; + + let alpha_t = { + let ctxt_str: String = String::from_utf8(enc_alpha_t.clone()).unwrap(); + + Scalar::from_bits( + Fernet::new(&aes_key) + .unwrap() + .decrypt(&ctxt_str) + .unwrap() + .to_vec()[0..32] + .try_into() + .unwrap(), + ) + }; + + // This is an array of exclusive-inclusive prefix sum - hence + // number of keys is one less than length + let num_keys = psum.len() - 1; + + // Unflatten + let pdata = { + let t = self.ec_cipher.to_points_encrypt(&data, &alpha_t.invert()); + + psum.get(0..num_keys) + .unwrap() + .iter() + .zip_eq(psum.get(1..num_keys + 1).unwrap().iter()) + .map(|(&x1, &x2)| t.get(x1..x2).unwrap().to_vec()) + .collect::>>() + }; + + t.qps("deserialize_exp", pdata.len()); + + let (features, g_zi) = decrypt_shares(xor_shares, aes_key); + + partners_data.push(PartnerData { + h_b_partner: pdata, + features, + g_zi, + }); + + set_diffs.push(SetDiff { + s_company: HashSet::::new(), + s_partner: HashSet::::new(), + }); + + Ok(()) + } + _ => { + error!("Cannot load e_company"); + Err(ProtocolError::ErrorDeserialization( + "cannot load h_b_partner".to_string(), + )) + } + } + } + + fn set_encrypted_company( + &self, + company: TPayload, + company_psum: Vec, + ) -> Result<(), ProtocolError> { + match (self.h_company_beta.clone().write(),) { + (Ok(mut h_company_beta),) => { + // To ragged array + let num_keys = company_psum.len() - 1; + h_company_beta.clear(); + let e_company = { + let t = self.ec_cipher.to_points(&company); + company_psum + .get(0..num_keys) + .unwrap() + .iter() + .zip_eq(company_psum.get(1..num_keys + 1).unwrap().iter()) + .map(|(&x1, &x2)| t.get(x1..x2).unwrap().to_vec()) + .collect::>>() + }; + h_company_beta.extend(e_company); + + Ok(()) + } + _ => { + error!("Unable to obtain locks to buffers for set_encrypted_company"); + Err(ProtocolError::ErrorEncryption( + "unable to encrypt data".to_string(), + )) + } + } + } + + fn calculate_set_diff(&self, partner_idx: usize) -> Result<(), ProtocolError> { + match ( + self.h_company_beta.clone().read(), + self.partners_data.clone().read(), + self.set_diffs.clone().write(), + ) { + (Ok(e_company), Ok(partners_data), Ok(mut set_diffs)) => { + let e_partner = &partners_data[partner_idx].h_b_partner; + + let set_diff = &mut set_diffs[partner_idx]; + let s_company = &mut set_diff.s_company; + let s_partner = &mut set_diff.s_partner; + + let s_c = e_company.iter().map(|e| e[0]).collect::>(); + let s_p = e_partner.iter().map(|e| e[0]).collect::>(); + + let max_len = e_company.iter().map(|e| e.len()).max().unwrap(); + + // Start with both vectors as all valid + let mut e_c_valid = vec![true; e_company.len()]; + let mut e_p_valid = vec![true; e_partner.len()]; + + for idx in 0..max_len { + // TODO: This should be a ByteBuffer instead of a vec + let mut e_c_map = HashMap::, usize>::new(); + + // Strip the idx-th key (viewed as a column) + for (e, i) in e_company + .iter() + .enumerate() + .filter(|(_, e)| e.len() > idx) + .map(|(i, e)| (e[idx], i)) + { + // Ristretto points are not hashable by themselves + e_c_map.insert(e.compress().to_bytes().to_vec(), i); + } + + // Vector of indices of e_p that match. These will be set to false + let mut e_p_match_idx = Vec::::new(); + for ((i, e), _) in e_partner + .iter() + .enumerate() + .zip_eq(e_p_valid.iter()) + .filter(|((_, _), &f)| f) + { + // Find the minimum index where match happens + let match_idx = e + .iter() + .map(|key| + // TODO: Replace with match + if e_c_map.contains_key(&key.compress().to_bytes().to_vec()) { + let &m_idx = e_c_map.get(&key.compress().to_bytes().to_vec()).unwrap(); + (m_idx, e_c_valid[m_idx]) + } else { + // Using length of vector as a sentinel value. Will get + // filtered out because of false + (e_c_valid.len(), false) + }) + .filter(|(_, f)| *f) + .map(|(e, _)| e) + .min(); + + // For those indices that have matched - set them to false + // Also assign the correct keys + if let Some(m_idx) = match_idx { + e_c_valid[m_idx] = false; + e_p_match_idx.push(i); + } + } + + // Set all e_p that matched to false - so they aren't matched in the next + // iteration + e_p_match_idx.iter().for_each(|&idx| e_p_valid[idx] = false); + } + + // Create S_p by filtering out values that matched + s_partner.clear(); + { + // Only keep s_p that have not been matched + let mut inp = s_p + .iter() + .zip_eq(e_p_valid.iter()) + .filter(|(_, &f)| f) + .map(|(&e, _)| e) + .collect::>(); + + if !inp.is_empty() { + // Permute s_p + permute(gen_permute_pattern(inp.len()).as_slice(), &mut inp); + + // save output + let x = self.ec_cipher.to_bytes(inp.as_slice()); + let y = x.iter().map(|t| t.to_string()).collect::>(); + s_partner.extend(y); + } + } + + // Create S_c by filtering out values that matched + let t = s_c + .iter() + .zip_eq(e_c_valid.iter()) + .filter(|(_, &f)| f) + .map(|(&e, _)| e) + .collect::>(); + s_company.clear(); + + if !t.is_empty() { + let x = self.ec_cipher.to_bytes(t.as_slice()); + let y = x.iter().map(|t| t.to_string()).collect::>(); + s_company.extend(y); + } + + Ok(()) + } + _ => { + error!("Unable to obtain locks to buffers for set diff operation"); + Err(ProtocolError::ErrorEncryption( + "unable to encrypt data".to_string(), + )) + } + } + } + + /// s_partner_i has all the data that are in Partner_i but did not get matched. + /// s_company_i has all the data that are in Company but not in partner i. + /// + /// 1. Add all elements from P_i that are not in SP_i in the id_map. + /// Essentially, these are the items on the intersection. + /// 2. Add intersection of all SC_i elements. These are the items that company + /// has, but none of the partners had. + /// + /// Example: + /// C: a b c d e f + /// P1: a b c g f + /// P2: a b d h + /// + /// SC1: d e + /// SP1: g + /// SC2: c e f + /// SP2: h + /// + /// 1) P1 \ SP1: a b c f + /// LJ result: a b c f + /// + /// 2) P2 \ SP2: a b d + /// LJ result: a b c d f + /// + /// 3) SC1 intersection SC2: e + /// LJ result: a b c d e f + fn calculate_id_map(&self, num_of_matches: usize) { + match ( + self.partners_data.clone().read(), + self.set_diffs.clone().read(), + self.self_permutation.clone().read(), + self.id_map.clone().write(), + ) { + (Ok(partners_data), Ok(set_diffs), Ok(permutation), Ok(mut id_map)) => { + assert_eq!(partners_data.len(), set_diffs.len()); + + // Compute the intersection of all SC_i + let sc_intersection = { + let mut x: HashSet<_> = set_diffs[0].s_company.clone(); + + for p in 1..set_diffs.len() { + x.retain(|e| set_diffs[p].s_company.contains(e)); + } + x + }; + + // Create a hashmap for all unique partner keys that are not in S_Partner + let mut unique_partner_ids: HashMap> = HashMap::new(); + for p in 0..set_diffs.len() { + // Get the first column. + let partner_keys = { + let tmp = { + let mut h_b_partner = partners_data[p].h_b_partner.clone(); + undo_permute(permutation.as_slice(), &mut h_b_partner); + + h_b_partner.iter().map(|s| s[0]).collect::>() + }; + self.ec_cipher.to_bytes(tmp.as_slice()) + }; + + for (idx, key) in partner_keys.iter().enumerate() { + // if not in S_Partner + if !set_diffs[p].s_partner.contains(&key.to_string()) { + // if not already in the id map + if let std::collections::hash_map::Entry::Vacant(e) = + unique_partner_ids.entry(key.to_string()) + { + e.insert(vec![(idx, p)]); + } else { + let v = unique_partner_ids.get_mut(&key.to_string()).unwrap(); + if v.len() < num_of_matches { + v.push((idx, p)); + } + } + } + } + } + // Add each item of unique_partner_ids into id_map. + id_map.clear(); + id_map.extend({ + let x = unique_partner_ids + .iter_mut() + .map(|(key, v)| { + v.resize(num_of_matches, (0, 0)); + v.iter() + .map(|(idx, from_p)| (key.to_string(), *idx, true, *from_p)) + .collect::>() + }) + .collect::>(); + x.into_iter().flatten().collect::>() + }); + // Add all the remaining keys that company has but the partners don't. + id_map.extend({ + let x = sc_intersection + .clone() + .iter() + .map(|key| { + (0..num_of_matches) + .map(|_| (key.to_string(), 0, false, 0)) + .collect::>() + }) + .collect::>(); + x.into_iter().flatten().collect::>() + }); + + // Sort the id_map by the spine + id_map.sort_by(|(a, _, _, _), (b, _, _, _)| a.cmp(b)); + } + _ => panic!("Cannot make v"), + } + } + + fn calculate_features_xor_shares(&self) -> Result { + match ( + self.partners_data.clone().read(), + self.id_map.clone().read(), + self.company_public_key.clone().read(), + self.partner_shares.clone().write(), + ) { + (Ok(partners_data), Ok(id_map), Ok(company_public_key), Ok(mut shares)) => { + let mut rng = rand::thread_rng(); + let range = Uniform::new(0_u64, u64::MAX); + let t = timer::Timer::new_silent("calculate_features_xor_shares"); + + // Find maximum number of features across all partners. + let n_features = partners_data + .iter() + .map(|p| p.features.len()) + .max() + .unwrap(); + + let (t_i, mut g_zi) = { + let z_i = (0..id_map.len()).map(|_| gen_scalar()).collect::>(); + let x = z_i + .iter() + .map(|a| { + let x = self.ec_cipher.to_bytes(&[a * *company_public_key]); + x[0].clone() + }) + .collect::>(); + let y = z_i + .iter() + .map(|a| a * &RISTRETTO_BASEPOINT_TABLE) + .collect::>(); + (x, y) + }; + + let mut d_flat = { + let p_mask_v = partners_data + .iter() + .map(|p_data| self.ec_cipher.to_points(&p_data.g_zi)) + .collect::>(); + + let mut v_p = Vec::::new(); + for f_idx in (0..n_features).rev() { + let mask = (0..id_map.len()) + .map(|_| rng.sample(range)) + .collect::>(); + let t = id_map + .iter() + .enumerate() + .map(|(i, (_, idx, exists, from_partner))| { + let y = if *exists { + if f_idx == 0 { + g_zi[i] = p_mask_v[*from_partner][*idx]; + } + let partner_features = &partners_data[*from_partner].features; + if f_idx < partner_features.len() { + partner_features[f_idx][*idx] ^ mask[i] + } else { + // In case the data are not padded correctly, + // return secret shares of the first feature. + partner_features[0][*idx] ^ mask[i] + } + } else { + let y = u64::from_le_bytes( + (t_i[i].buffer[0..8]).try_into().unwrap(), + ); + y ^ mask[i] + }; + ByteBuffer { + buffer: y.to_le_bytes().to_vec(), + } + }) + .collect::>(); + + v_p.push(t); + shares.insert(f_idx, mask); + } + + v_p.into_iter().flatten().collect::>() + }; + + d_flat.extend(self.ec_cipher.to_bytes(&g_zi)); + + let metadata = vec![ + ByteBuffer { + buffer: (id_map.len() as u64).to_le_bytes().to_vec(), + }, + ByteBuffer { + buffer: (n_features as u64).to_le_bytes().to_vec(), + }, + ]; + + d_flat.extend(metadata); + t.qps("d_flat", d_flat.len()); + Ok(d_flat) + } + _ => { + error!("Cannot read id_map"); + Err(ProtocolError::ErrorEncryption( + "unable to read id_map".to_string(), + )) + } + } + } + + fn print_id_map(&self) { + match self.id_map.clone().read() { + Ok(id_map) => { + // Create fake data since we only have encrypted partner keys + let m_idx = id_map + .iter() + .filter(|(_, _, flag, _)| *flag) + .map(|(_, idx, _, _)| idx) + .max() + .unwrap(); + + let mut data: Vec> = vec![vec!["NA".to_string()]; m_idx + 1]; + + for i in 0..id_map.len() { + let (_, idx, flag, _) = id_map[i]; + if flag { + data[idx] = vec![format!(" Partner enc key at pos {}", idx)]; + } + } + + writer_helper_dpmc(&data, &id_map, None); + } + _ => panic!("Cannot print id_map"), + } + } + + fn save_id_map(&self, path: &str) -> Result<(), ProtocolError> { + match self.id_map.clone().read() { + Ok(id_map) => { + // Create fake data since we only have encrypted partner keys + let m_idx = id_map + .iter() + .filter(|(_, _, flag, _)| *flag) + .map(|(_, idx, _, _)| idx) + .max() + .unwrap(); + + let mut data: Vec> = vec![vec!["NA".to_string()]; m_idx + 1]; + + for i in 0..id_map.len() { + let (_, idx, flag, _) = id_map[i]; + if flag { + data[idx] = vec![format!(" Partner enc key at pos {}", idx)]; + } + } + + writer_helper_dpmc(&data, &id_map, Some(path.to_string())); + Ok(()) + } + _ => Err(ProtocolError::ErrorIO( + "Unable to write company view to file".to_string(), + )), + } + } + + fn save_features_shares(&self, path_prefix: &str) -> Result<(), ProtocolError> { + match self.partner_shares.clone().read() { + Ok(shares) => { + assert!(shares.len() > 0); + + let mut out: Vec> = Vec::new(); + + for key in shares.keys().sorted() { + out.push(shares.get(key).unwrap().clone()); + } + + let p_filename = format!("{}{}", path_prefix, "_partner_features.csv"); + info!("revealing partner features to output file"); + common::files::write_u64cols_to_file(&mut out, Path::new(&p_filename)).unwrap(); + + Ok(()) + } + _ => Err(ProtocolError::ErrorIO( + "Unable to write company shares of partner features to file".to_string(), + )), + } + } +} diff --git a/protocol/src/dpmc/mod.rs b/protocol/src/dpmc/mod.rs new file mode 100644 index 0000000..128831e --- /dev/null +++ b/protocol/src/dpmc/mod.rs @@ -0,0 +1,216 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate csv; + +use std::collections::HashSet; +use std::error::Error; +use std::fmt; +use std::sync::Arc; +use std::sync::RwLock; + +use common::files; +use common::timer; +use crypto::prelude::*; + +#[derive(Debug)] +pub enum ProtocolError { + ErrorDeserialization(String), + ErrorSerialization(String), + ErrorEncryption(String), + ErrorCalcSetDiff(String), + ErrorReencryption(String), + ErrorIO(String), +} + +impl fmt::Display for ProtocolError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "protocol error {}", self) + } +} + +impl Error for ProtocolError {} + +fn load_data_keys(plaintext: Arc>>>, path: &str, input_with_headers: bool) { + let t = timer::Timer::new_silent("load data"); + + let mut lines = files::read_csv_as_strings(path, true); + let text_len = lines.len(); + + if let Ok(mut data) = plaintext.write() { + data.clear(); + let mut line_it = lines.drain(..); + // Strip the header + if input_with_headers { + line_it.next(); + } + + let mut t = HashSet::>::new(); + // Filter out zero length strings - these will come from ragged + // arrays since they are padded out to the longest array + // Also deduplicate all input + for line in line_it { + let v = line + .iter() + .map(String::from) + .filter(|s| !s.is_empty()) + .collect::>(); + if !t.contains(&v) { + data.push(v.clone()); + t.insert(v); + } + } + info!("Read {} lines from {}", text_len, path,); + } + + t.qps("text read", text_len); +} + +fn load_data_features(plaintext: Arc>>>, path: &str) { + let t = timer::Timer::new_silent("load features"); + + let lines = files::read_csv_as_u64(path); + let n_rows = lines.len(); + let n_cols = lines[0].len(); + assert!(n_rows > 0); + assert!(n_cols > 0); + + if let Ok(mut data) = plaintext.write() { + data.clear(); + + // Make sure its not a ragged array + for i in lines.iter() { + assert_eq!(n_cols, i.len()); + } + + let mut features: Vec> = vec![vec![0; n_rows]; n_cols]; + + for (i, v) in lines.iter().enumerate() { + for (j, z) in v.iter().enumerate() { + features[j][i] = *z; + } + } + + data.extend(features.drain(..)); + + info!("Read {} lines from {}", n_rows, path,); + } + + t.qps("text read", n_rows); +} + +fn writer_helper_dpmc( + data: &[Vec], + id_map: &[(String, usize, bool, usize)], + path: Option, +) { + let mut device = match path { + Some(path) => { + let wr = csv::WriterBuilder::new() + .flexible(true) + .buffer_capacity(1024) + .from_path(path) + .unwrap(); + Some(wr) + } + None => None, + }; + + for (key, idx, flag, _) in id_map.iter() { + let mut v = vec![(*key).clone()]; + + match flag { + true => v.extend(data[*idx].clone()), + false => v.push("NA".to_string()), + } + + match device { + Some(ref mut wr) => { + wr.write_record(v.as_slice()).unwrap(); + } + None => { + println!("{}", v.join(",")); + } + } + } +} + +fn writer_helper(data: &[Vec], id_map: &[(String, usize, bool)], path: Option) { + let mut device = match path { + Some(path) => { + let wr = csv::WriterBuilder::new() + .flexible(true) + .buffer_capacity(1024) + .from_path(path) + .unwrap(); + Some(wr) + } + None => None, + }; + + for (key, idx, flag) in id_map.iter() { + let mut v = vec![(*key).clone()]; + + match flag { + true => v.extend(data[*idx].clone()), + false => v.push("NA".to_string()), + } + + match device { + Some(ref mut wr) => { + wr.write_record(v.as_slice()).unwrap(); + } + None => { + println!("{}", v.join(",")); + } + } + } +} + +fn compute_prefix_sum(input: &[usize]) -> Vec { + let prefix_sum = input + .iter() + .scan(0, |sum, i| { + *sum += i; + Some(*sum) + }) + .collect::>(); + + // offset is now a combined exclusive and inclusive prefix sum + // that will help us convert to a flattened vector and back to a + // vector of vectors + let mut output = Vec::::with_capacity(prefix_sum.len() + 1); + output.push(0); + output.extend(prefix_sum); + output +} + +fn serialize_helper(data: Vec>) -> (Vec, TPayload, TPayload) { + let offset = { + let lengths = data.iter().map(|v| v.len()).collect::>(); + compute_prefix_sum(&lengths) + .iter() + .map(|&o| ByteBuffer { + buffer: (o as u64).to_le_bytes().to_vec(), + }) + .collect::>() + }; + + let d_flat = data.into_iter().flatten().collect::>(); + + let metadata = vec![ + ByteBuffer { + buffer: (d_flat.len() as u64).to_le_bytes().to_vec(), + }, + ByteBuffer { + buffer: (offset.len() as u64).to_le_bytes().to_vec(), + }, + ]; + + (d_flat, offset, metadata) +} + +pub mod company; +pub mod helper; +pub mod partner; +pub mod traits; diff --git a/protocol/src/dpmc/partner.rs b/protocol/src/dpmc/partner.rs new file mode 100644 index 0000000..2a8146c --- /dev/null +++ b/protocol/src/dpmc/partner.rs @@ -0,0 +1,305 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +use std::convert::TryInto; +use std::sync::Arc; +use std::sync::RwLock; + +use common::permutations::gen_permute_pattern; +use common::permutations::permute; +use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; +use fernet::Fernet; +use rayon::iter::ParallelDrainRange; +use rayon::iter::ParallelIterator; + +use super::load_data_features; +use super::load_data_keys; +use super::serialize_helper; +use super::ProtocolError; +use crate::dpmc::traits::PartnerDpmcProtocol; +use crate::shared::TFeatures; + +pub struct PartnerDpmc { + keypair_sk: Scalar, + keypair_pk: TPoint, + partner_scalar: Scalar, + company_public_key: Arc>, + helper_public_key: Arc>, + ec_cipher: ECRistrettoParallel, + permutation: Arc>>, + plaintext_keys: Arc>>>, + plaintext_features: Arc>, + aes_key: Arc>, +} + +impl PartnerDpmc { + pub fn new() -> PartnerDpmc { + let x = gen_scalar(); + PartnerDpmc { + keypair_sk: x, + keypair_pk: &x * &RISTRETTO_BASEPOINT_TABLE, + partner_scalar: gen_scalar(), + company_public_key: Arc::new(RwLock::default()), + helper_public_key: Arc::new(RwLock::default()), + ec_cipher: ECRistrettoParallel::default(), + permutation: Arc::new(RwLock::default()), + plaintext_keys: Arc::new(RwLock::default()), + plaintext_features: Arc::new(RwLock::default()), + aes_key: Arc::new(RwLock::default()), + } + } + + // TODO: Fix header processing + pub fn load_data(&self, path_keys: &str, path_features: &str, input_with_headers: bool) { + load_data_keys(self.plaintext_keys.clone(), path_keys, input_with_headers); + load_data_features(self.plaintext_features.clone(), path_features); + + match ( + self.plaintext_keys.clone().read(), + self.plaintext_features.clone().read(), + ) { + (Ok(keys), Ok(features)) => { + assert!(features.len() > 0); + assert_eq!(keys.len(), features[0].len()); + } + _ => { + error!("Unable to read keys and features"); + } + } + } + + pub fn get_size(&self) -> usize { + self.plaintext_keys.clone().read().unwrap().len() + } + + pub fn get_partner_public_key(&self) -> Result { + Ok(self.ec_cipher.to_bytes(&[self.keypair_pk])) + } + + pub fn set_company_public_key( + &self, + company_public_key: TPayload, + ) -> Result<(), ProtocolError> { + let pk = self.ec_cipher.to_points(&company_public_key); + // Check that one key is sent + assert_eq!(pk.len(), 1); + + match self.company_public_key.clone().write() { + Ok(mut company_pk) => { + *company_pk = pk[0]; + assert!(!(*company_pk).is_identity()); + Ok(()) + } + _ => { + error!("Unable to set company public key"); + Err(ProtocolError::ErrorEncryption( + "unable to set company public key".to_string(), + )) + } + } + } + + pub fn set_helper_public_key(&self, helper_public_key: TPayload) -> Result<(), ProtocolError> { + let pk = self.ec_cipher.to_points(&helper_public_key); + // Check that one key is sent + assert_eq!(pk.len(), 1); + match ( + self.helper_public_key.clone().write(), + self.aes_key.clone().write(), + ) { + (Ok(mut helper_pk), Ok(mut aes_key)) => { + *helper_pk = pk[0]; + assert!(!(*helper_pk).is_identity()); + + *aes_key = { + let x = self + .ec_cipher + .to_bytes(&[self.partner_scalar * (*helper_pk)]); + let aes_key_bytes = x[0].buffer.clone(); + base64::encode_config(aes_key_bytes, base64::URL_SAFE) + }; + Ok(()) + } + _ => { + error!("Unable to set helper public key"); + Err(ProtocolError::ErrorEncryption( + "unable to set helper public key".to_string(), + )) + } + } + } +} + +impl Default for PartnerDpmc { + fn default() -> Self { + Self::new() + } +} + +impl PartnerDpmcProtocol for PartnerDpmc { + fn get_encrypted_keys(&self) -> Result { + match ( + self.plaintext_keys.clone().read(), + self.aes_key.clone().read(), + self.permutation.clone().write(), + ) { + (Ok(pdata), Ok(aes_key), Ok(mut permutation)) => { + let t = timer::Timer::new_silent("partner data"); + + // Generate random permutation. + permutation.clear(); + permutation.extend(gen_permute_pattern(pdata.len())); + + // Permute, flatten, encrypt + let (mut d_flat, offset) = { + let mut d = pdata.clone(); + permute(permutation.as_slice(), &mut d); + + let (d_flat, mut offset, metadata) = serialize_helper(d); + offset.extend(metadata); + + // Encrypt + ( + // Blind the keys by encrypting + self.ec_cipher + .hash_encrypt_to_bytes(d_flat.as_slice(), &self.keypair_sk), + offset, + ) + }; + + t.qps("encryption", d_flat.len()); + + // Append offsets array + d_flat.extend(offset); + + let fernet = Fernet::new(&aes_key).unwrap(); + let ctxt = fernet.encrypt(self.keypair_sk.to_bytes().clone().as_slice()); + // Append encrypted key alpha + d_flat.push(ByteBuffer { + buffer: ctxt.as_bytes().to_vec(), + }); + + let p_scalar_times_g = self + .ec_cipher + .to_bytes(&[&self.partner_scalar * &RISTRETTO_BASEPOINT_TABLE]); + d_flat.extend(p_scalar_times_g); + + Ok(d_flat) + } + _ => { + error!("Unable to encrypt data"); + Err(ProtocolError::ErrorEncryption( + "unable to encrypt data".to_string(), + )) + } + } + } + + fn get_features_xor_shares(&self) -> Result { + match ( + self.plaintext_features.clone().read(), + self.company_public_key.clone().read(), + self.aes_key.clone().read(), + self.permutation.clone().read(), + ) { + (Ok(pdata), Ok(company_public_key), Ok(aes_key), Ok(permutation)) => { + let t = timer::Timer::new_silent("get_features_xor_shares"); + let n_rows = pdata[0].len(); + let n_features = pdata.len(); + + // Apply the same permutation as in keys. + let mut permuted_pdata = pdata.clone(); + for i in 0..n_features { + permute(permutation.as_slice(), &mut permuted_pdata[i]); + } + + let z_i = (0..n_rows) + .collect::>() + .iter() + .map(|_| gen_scalar()) + .collect::>(); + + let mut d_flat = { + let r_i = { + let y_zi = { + let t = z_i + .iter() + .map(|x| *x * *company_public_key) + .collect::>(); + self.ec_cipher.to_bytes(&t) + }; + y_zi.iter() + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap())) + .collect::>() + }; + + let mut v_p = Vec::::new(); + + for f_idx in (0..n_features).rev() { + let t = (0..n_rows) + .map(|x| x) + .collect::>() + .iter() + .map(|i| { + let z: u64 = permuted_pdata[f_idx][*i] ^ r_i[*i]; + ByteBuffer { + buffer: z.to_le_bytes().to_vec(), + } + }) + .collect::>(); + v_p.push(t); + } + + v_p.into_iter().flatten().collect::>() + }; + + { + let g_zi = { + let t = z_i + .iter() + .map(|x| x * &RISTRETTO_BASEPOINT_TABLE) + .collect::>(); + self.ec_cipher.to_bytes(&t) + }; + d_flat.extend(g_zi); + } + + let metadata = vec![ + ByteBuffer { + buffer: (n_rows as u64).to_le_bytes().to_vec(), + }, + ByteBuffer { + buffer: (n_features as u64).to_le_bytes().to_vec(), + }, + ]; + d_flat.extend(metadata); + + let e_d_flat = { + let fernet = Fernet::new(&aes_key.clone()).unwrap(); + d_flat + .par_drain(..) + .map(|x| { + let ctxt = fernet.encrypt(x.buffer.as_slice()); + ByteBuffer { + buffer: ctxt.as_bytes().to_vec(), + } + }) + .collect::>() + }; + t.qps("e_d_flat", e_d_flat.len()); + + Ok(e_d_flat) + } + _ => { + error!("Unable to encrypt data"); + Err(ProtocolError::ErrorEncryption( + "unable to encrypt data".to_string(), + )) + } + } + } +} diff --git a/protocol/src/dpmc/traits.rs b/protocol/src/dpmc/traits.rs new file mode 100644 index 0000000..5ce690e --- /dev/null +++ b/protocol/src/dpmc/traits.rs @@ -0,0 +1,58 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate crypto; + +use crypto::prelude::TPayload; + +use crate::dpmc::ProtocolError; +use crate::shared::TFeatures; + +pub trait PartnerDpmcProtocol { + fn get_encrypted_keys(&self) -> Result; + fn get_features_xor_shares(&self) -> Result; +} + +pub trait HelperDpmcProtocol { + fn remove_partner_scalar_from_p_and_set_shares( + &self, + data: TPayload, + psum: Vec, + enc_alpha_t: Vec, + p_scalar_g: TPayload, + xor_shares: TPayload, + ) -> Result<(), ProtocolError>; + fn calculate_set_diff(&self, partner_num: usize) -> Result<(), ProtocolError>; + fn calculate_id_map(&self, calculate_id_map: usize); + fn set_encrypted_company( + &self, + company: TPayload, + company_psum: Vec, + ) -> Result<(), ProtocolError>; + fn calculate_features_xor_shares(&self) -> Result; + fn print_id_map(&self); + fn save_id_map(&self, path: &str) -> Result<(), ProtocolError>; + fn save_features_shares(&self, path_prefix: &str) -> Result<(), ProtocolError>; +} + +pub trait CompanyDpmcProtocol { + fn set_encrypted_partner_keys_and_shares( + &self, + keys_data: TPayload, + keys_psum: Vec, + enc_alpha_t: Vec, + p_scalar_g: Vec, + shares_data: TPayload, + ) -> Result<(), ProtocolError>; + fn get_permuted_keys(&self) -> Result; + fn serialize_encrypted_keys_and_features(&self) -> Result; + fn calculate_features_xor_shares( + &self, + features: TFeatures, + data: TPayload, + ) -> Result<(), ProtocolError>; + fn write_company_to_id_map(&self) -> Result<(), ProtocolError>; + fn print_id_map(&self); + fn save_id_map(&self, path: &str) -> Result<(), ProtocolError>; + fn save_features_shares(&self, path_prefix: &str) -> Result<(), ProtocolError>; +} diff --git a/protocol/src/dspmc/company.rs b/protocol/src/dspmc/company.rs new file mode 100644 index 0000000..4ca1b70 --- /dev/null +++ b/protocol/src/dspmc/company.rs @@ -0,0 +1,754 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate csv; + +use std::collections::HashMap; +use std::collections::VecDeque; +use std::convert::TryInto; +use std::path::Path; +use std::sync::Arc; +use std::sync::RwLock; + +use common::permutations::gen_permute_pattern; +use common::permutations::permute; +use common::permutations::undo_permute; +use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; +use itertools::Itertools; +use rand::distributions::Uniform; +use rand::Rng; + +use super::load_data_keys; +use super::serialize_helper; +use super::writer_helper; +use super::ProtocolError; +use crate::dspmc::traits::CompanyDspmcProtocol; +use crate::shared::TFeatures; + +#[derive(Debug)] +struct PartnerData { + scalar_g: Vec, + n_rows: usize, + n_features: usize, +} + +#[derive(Debug)] +pub struct CompanyDspmc { + keypair_sk: (Scalar, Scalar), + keypair_pk: (TPoint, TPoint), + helper_public_key: Arc>, + ec_cipher: ECRistrettoParallel, + ct1: Arc>>>, // initially holds ct1 and later ct1'' + ct2: Arc>>>, // initially holds ct2 and later ct2'' + v1: Arc>, + u1: Arc>>, + plaintext: Arc>>>, + permutation: Arc>>, + perms: Arc, Vec)>>, // (p_3, p_4) + blinds: Arc, Vec)>>, // (v_cd, v_cs) + enc_company: Arc>>>, + partners_queue: Arc>>, + id_map: Arc>>, + partner_shares: Arc>>>, +} + +impl CompanyDspmc { + pub fn new() -> CompanyDspmc { + let x1 = gen_scalar(); + let x2 = gen_scalar(); + CompanyDspmc { + keypair_sk: (x1, x2), + keypair_pk: ( + &x1 * &RISTRETTO_BASEPOINT_TABLE, + &x2 * &RISTRETTO_BASEPOINT_TABLE, + ), + helper_public_key: Arc::new(RwLock::default()), + ec_cipher: ECRistrettoParallel::default(), + ct1: Arc::new(RwLock::default()), + ct2: Arc::new(RwLock::default()), + v1: Arc::new(RwLock::default()), + u1: Arc::new(RwLock::default()), + plaintext: Arc::new(RwLock::default()), + permutation: Arc::new(RwLock::default()), + perms: Arc::new(RwLock::default()), + blinds: Arc::new(RwLock::default()), + enc_company: Arc::new(RwLock::default()), + partners_queue: Arc::new(RwLock::default()), + id_map: Arc::new(RwLock::default()), + partner_shares: Arc::new(RwLock::default()), + } + } + + pub fn get_company_public_key(&self) -> Result { + Ok(self + .ec_cipher + .to_bytes(&vec![self.keypair_pk.0, self.keypair_pk.1])) + } + + pub fn load_data(&self, path: &str, input_with_headers: bool) { + load_data_keys(self.plaintext.clone(), path, input_with_headers); + } + + pub fn gen_permutations(&self) { + match ( + self.perms.clone().write(), + self.blinds.clone().write(), + self.ct1.clone().read(), + ) { + (Ok(mut perms), Ok(mut blinds), Ok(ct1_data)) => { + let mut rng = rand::thread_rng(); + let range = Uniform::new(0_u64, u64::MAX); + + let data_len = ct1_data.len(); + assert!(data_len > 0); + perms.0.clear(); + perms.1.clear(); + perms.0.extend(gen_permute_pattern(data_len)); + perms.1.extend(gen_permute_pattern(data_len)); + + blinds.0 = (0..data_len) + .map(|_| rng.sample(range)) + .collect::>(); + blinds.1 = (0..data_len) + .map(|_| rng.sample(range)) + .collect::>(); + } + _ => {} + } + } + + pub fn set_helper_public_key(&self, helper_public_key: TPayload) -> Result<(), ProtocolError> { + let pk = self.ec_cipher.to_points(&helper_public_key); + // Check that one key is sent + assert_eq!(pk.len(), 1); + match self.helper_public_key.clone().write() { + Ok(mut helper_pk) => { + *helper_pk = pk[0]; + Ok(()) + } + _ => { + error!("Unable to set helper public key"); + Err(ProtocolError::ErrorEncryption( + "unable to set helper public key".to_string(), + )) + } + } + } +} + +impl Default for CompanyDspmc { + fn default() -> Self { + Self::new() + } +} + +impl CompanyDspmcProtocol for CompanyDspmc { + fn set_encrypted_partner_keys_and_shares( + &self, + ct1: TPayload, + ct2: TPayload, + psum: Vec, + ct3: Vec, + xor_features: TFeatures, + ) -> Result<(), ProtocolError> { + match ( + self.partners_queue.clone().write(), + self.ct1.clone().write(), + self.ct2.clone().write(), + self.v1.clone().write(), + ) { + (Ok(mut partners_queue), Ok(mut all_ct1), Ok(mut all_ct2), Ok(mut all_v1)) => { + let t = timer::Timer::new_silent("load_ct2"); + // This is an array of exclusive-inclusive prefix sum - hence + // number of keys is one less than length + let num_keys = psum.len() - 1; + + // Unflatten + let ct1_points = { + let t = self.ec_cipher.to_points(&ct1); + + psum.get(0..num_keys) + .unwrap() + .iter() + .zip_eq(psum.get(1..num_keys + 1).unwrap().iter()) + .map(|(&x1, &x2)| t.get(x1..x2).unwrap().to_vec()) + .collect::>>() + }; + let ct2_points = { + let t = self.ec_cipher.to_points(&ct2); + + psum.get(0..num_keys) + .unwrap() + .iter() + .zip_eq(psum.get(1..num_keys + 1).unwrap().iter()) + .map(|(&x1, &x2)| t.get(x1..x2).unwrap().to_vec()) + .collect::>>() + }; + assert_eq!(ct1_points.len(), ct2_points.len()); + let data_size = ct1_points.len(); + t.qps("deserialize_exp", data_size); + + all_ct1.extend(ct1_points); + all_ct2.extend(ct2_points); + + let n_rows = xor_features[0].len(); + let n_features = xor_features.len(); + assert_eq!(n_rows, data_size); + + for f_idx in 0..n_features { + if all_v1.len() != n_features { + all_v1.push(xor_features[f_idx].clone()); + } else { + all_v1[f_idx].extend(xor_features[f_idx].clone()); + } + } + + partners_queue.push_back(PartnerData { + scalar_g: ct3, + n_rows, + n_features, + }); + Ok(()) + } + _ => { + error!("Cannot load ct2"); + Err(ProtocolError::ErrorDeserialization( + "cannot load ct2".to_string(), + )) + } + } + } + + // Get dataset C with company keys and encrypt them to H(C)^c + // With Elliptic curves: H(C)*c + fn get_company_keys(&self) -> Result { + match ( + self.plaintext.clone().read(), + self.enc_company.clone().write(), + ) { + (Ok(pdata), Ok(mut enc_company)) => { + let t = timer::Timer::new_silent("x_company"); + + // Flatten + let (mut d_flat, mut offset, metadata) = { + let (d_flat, offset, metadata) = serialize_helper(pdata.to_vec()); + + // Hash Encrypt - H(C)^c + let enc = self + .ec_cipher + .hash_encrypt(d_flat.as_slice(), &self.keypair_sk.0); + + (enc, offset, metadata) + }; + + // Unflatten and set encrypted keys + { + let psum = offset + .iter() + .map(|b| { + u64::from_le_bytes(b.buffer.as_slice().try_into().unwrap()) as usize + }) + .collect::>(); + + let num_keys = psum.len() - 1; + let mut x = psum + .get(0..num_keys) + .unwrap() + .iter() + .zip_eq(psum.get(1..num_keys + 1).unwrap().iter()) + .map(|(&x1, &x2)| d_flat.get(x1..x2).unwrap().to_vec()) + .collect::>>(); + + enc_company.clear(); + enc_company.extend(x.drain(..)); + } + + t.qps("encryption H(C)^c", d_flat.len()); + + // Serialize + let buf = { + let mut x = self.ec_cipher.to_bytes(d_flat.as_slice()); + + d_flat.clear(); + d_flat.shrink_to_fit(); + + offset.extend(metadata); + // Append offsets array + x.extend(offset); + x + }; + + Ok(buf) + } + _ => { + error!("Unable to encrypt H(C)^c:"); + Err(ProtocolError::ErrorEncryption( + "cannot encrypt H(C)^c".to_string(), + )) + } + } + } + + // Get dataset ct1 and ct2' + fn get_ct1_ct2(&self) -> Result { + match (self.ct1.clone().read(), self.ct2.clone().read()) { + (Ok(ct1), Ok(ct2)) => { + let t = timer::Timer::new_silent("x_company"); + + // Re-randomize ct1'' and ct2'' and flatten + let (mut ct1_dprime_flat, ct2_dprime_flat, ct_offset) = { + // flatten + let (ct1_dprime_flat, _ct1_offset) = { + let (d_flat, mut offset, metadata) = serialize_helper(ct1.clone()); + offset.extend(metadata); + + (self.ec_cipher.to_bytes(d_flat.as_slice()), offset) + }; + + let (ctd2_dprime_flat, ct2_offset) = { + let (d_flat, mut offset, metadata) = serialize_helper(ct2.clone()); + offset.extend(metadata); + + let d_flat_c = d_flat + .iter() + .map(|x| *x * (self.keypair_sk.0)) + .collect::>(); + + (self.ec_cipher.to_bytes(d_flat_c.as_slice()), offset) + }; + (ct1_dprime_flat, ctd2_dprime_flat, ct2_offset) + }; + + ct1_dprime_flat.extend(ct2_dprime_flat); + ct1_dprime_flat.extend(ct_offset); + + t.qps("encryption H(C)^c", ct1_dprime_flat.len()); + + // ct1_dprime_flat, ct2_dprime_flat, ct_offset + Ok(ct1_dprime_flat) + } + _ => { + error!("Unable to encrypt H(C)^c:"); + Err(ProtocolError::ErrorEncryption( + "cannot encrypt H(C)^c".to_string(), + )) + } + } + } + + // Send , v_cd, p_cd to D + fn get_all_ct3_p_cd_v_cd(&self) -> Result { + match ( + self.partners_queue.clone().write(), + self.perms.clone().read(), + self.blinds.clone().read(), + ) { + (Ok(mut partners_data_q), Ok(perms), Ok(blinds)) => { + let mut res = vec![]; + let num_partners = partners_data_q.len(); + for _ in 0..num_partners { + let partner_data: PartnerData = partners_data_q.pop_back().unwrap(); + let ct3 = partner_data.scalar_g; + let n_rows = partner_data.n_rows; + let n_features = partner_data.n_features; + + res.push(ByteBuffer { buffer: ct3 }); + let metadata = vec![ + ByteBuffer { + buffer: (n_rows as u64).to_le_bytes().to_vec(), + }, + ByteBuffer { + buffer: (n_features as u64).to_le_bytes().to_vec(), + }, + ]; + res.extend(metadata); + } + res.push(ByteBuffer { + buffer: (num_partners as u64).to_le_bytes().to_vec(), + }); + + let p_cd_bytes = perms + .0 + .iter() + .map(|e| ByteBuffer { + buffer: (*e).to_le_bytes().to_vec(), + }) + .collect::>(); + let v_cd_bytes = blinds + .0 + .iter() + .map(|e| ByteBuffer { + buffer: (*e).to_le_bytes().to_vec(), + }) + .collect::>(); + let data_len = p_cd_bytes.len(); + res.extend(p_cd_bytes); + res.extend(v_cd_bytes); + res.push(ByteBuffer { + buffer: (data_len as u64).to_le_bytes().to_vec(), + }); + + Ok(res) + } + _ => { + error!("Unable to flatten ct3:"); + Err(ProtocolError::ErrorEncryption( + "cannot flatten ct3".to_string(), + )) + } + } + } + + fn get_p_cs_v_cs(&self) -> Result { + match ( + self.perms.clone().read(), + self.blinds.clone().read(), + self.ct1.clone().write(), + self.ct2.clone().write(), + self.helper_public_key.clone().read(), + ) { + (Ok(perms), Ok(blinds), Ok(mut ct1), Ok(mut ct2), Ok(helper_pk)) => { + let mut res = vec![]; + + let p_cs_bytes = perms + .1 + .iter() + .map(|e| ByteBuffer { + buffer: (*e as u64).to_le_bytes().to_vec(), + }) + .collect::>(); + let v_cs_bytes = blinds + .1 + .iter() + .map(|e| ByteBuffer { + buffer: (*e as u64).to_le_bytes().to_vec(), + }) + .collect::>(); + let data_len = ct1.len(); + res.extend(p_cs_bytes); + res.extend(v_cs_bytes); + + // Re-randomize ct1 and ct2 to ct1' and ct2' + let (ct1_prime_flat, ct2_prime_flat, ct_offset) = { + let r_i = (0..data_len) + .collect::>() + .iter() + .map(|_| gen_scalar()) + .collect::>(); + // company_pk^r + // with EC: company_pk * r + let pkc_r = r_i + .iter() + .map(|x| *x * (self.keypair_pk.0)) + .collect::>(); + // helper_pk^r + // with EC: helper_pk * r + let pkd_r = r_i.iter().map(|x| *x * (*helper_pk)).collect::>(); + + permute(perms.0.as_slice(), &mut ct1); // p_cd + permute(perms.0.as_slice(), &mut ct2); // p_cd + permute(perms.1.as_slice(), &mut ct1); // p_cs + permute(perms.1.as_slice(), &mut ct2); // p_cs + + // ct1' = p_4(p_3(ct1)) * company_pk^r + // with EC: ct1' = p_4(p_3(ct1)) + company_pk*r + let ct1_prime = ct1 + .iter() + .zip_eq(pkc_r.iter()) + .map(|(s, t)| (*s).iter().map(|si| *si + *t).collect::>()) + .collect::>(); + // ct2' = p_4(p_3(ct2)) * helper_pk^r + // with EC: ct2' = p_4(p_3(ct2)) + helper_pk*r + let ct2_prime = ct2 + .iter() + .zip_eq(pkd_r.iter()) + .map(|(s, t)| (*s).iter().map(|si| *si + *t).collect::>()) + .collect::>(); + + let (ct1_prime_flat, _ct1_offset) = { + let (d_flat, mut offset, metadata) = serialize_helper(ct1_prime.clone()); + offset.extend(metadata); + + (self.ec_cipher.to_bytes(d_flat.as_slice()), offset) + }; + let (ct2_prime_flat, ct2_offset) = { + let (d_flat, mut offset, metadata) = serialize_helper(ct2_prime.clone()); + offset.extend(metadata); + + (self.ec_cipher.to_bytes(d_flat.as_slice()), offset) + }; + (ct1_prime_flat, ct2_prime_flat, ct2_offset) + }; + + assert_eq!(ct1_prime_flat.len(), ct2_prime_flat.len()); + + res.extend(ct1_prime_flat); + res.extend(ct2_prime_flat); + res.extend(ct_offset); + + // p_cs, v_cs, ct1_prime_flat, ct2_prime_flat, ct_offset + Ok(res) + } + _ => { + error!("Unable to flatten ct3:"); + Err(ProtocolError::ErrorEncryption( + "cannot flatten ct3".to_string(), + )) + } + } + } + + fn set_p_sc_v_sc_ct1ct2dprime( + &self, + v_sc_bytes: TPayload, + p_sc_bytes: TPayload, + ct1_dprime_flat: TPayload, + ct2_dprime_flat: TPayload, + psum: Vec, + ) -> Result<(), ProtocolError> { + match ( + self.ct1.clone().write(), + self.ct2.clone().write(), + self.perms.clone().read(), + self.blinds.clone().read(), + self.v1.clone().write(), + self.u1.clone().write(), + ) { + (Ok(mut ct1), Ok(mut ct2), Ok(perms), Ok(blinds), Ok(mut v1), Ok(mut u1)) => { + let t = timer::Timer::new_silent("set set_p_sc_v_sc_ct1ct2dprime"); + let num_keys = v_sc_bytes.len(); + // Remove the previous data and replace them with the (doubly) re-randomized + ct1.clear(); + ct2.clear(); + // Unflatten and convert to points + *ct1 = { + // ct1'' (doubly re-randomized ct1) + let t = self.ec_cipher.to_points(&ct1_dprime_flat); + + psum.get(0..num_keys) + .unwrap() + .iter() + .zip_eq(psum.get(1..num_keys + 1).unwrap().iter()) + .map(|(&x1, &x2)| t.get(x1..x2).unwrap().to_vec()) + .collect::>>() + }; + // Unflatten and convert to points + *ct2 = { + // ct2'' (doubly re-randomized ct2) + let t = self.ec_cipher.to_points(&ct2_dprime_flat); + + psum.get(0..num_keys) + .unwrap() + .iter() + .zip_eq(psum.get(1..num_keys + 1).unwrap().iter()) + .map(|(&x1, &x2)| t.get(x1..x2).unwrap().to_vec()) + .collect::>>() + }; + + let p_sc = p_sc_bytes + .iter() + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) as usize) + .collect::>(); + let v_sc = v_sc_bytes + .iter() + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap())) + .collect::>(); + + let n_features = v1.len(); + // Compute u1 = p_sc( p_cs( p_cd(v_1) xor v_cd) xor v_cs) xor v_sc + (*u1).clear(); + for f_idx in (0..n_features).rev() { + permute(perms.0.as_slice(), &mut v1[f_idx]); // p_cd + let mut u2 = v1[f_idx] + .iter() + .zip_eq(blinds.0.iter()) // v_cd + .map(|(s, v_cd)| *s ^ *v_cd) + .collect::>(); + + permute(perms.1.as_slice(), &mut u2); // p_cs + let mut t1 = u2 + .iter() + .zip_eq(blinds.1.iter()) // v_cs + .map(|(s, v_cs)| *s ^ *v_cs) + .collect::>(); + + permute(p_sc.as_slice(), &mut t1); // p_sc + (*u1).push( + t1.iter() + .zip_eq(v_sc.iter()) + .map(|(s, v_sc)| { + // v_sc + let y = *s ^ *v_sc; + ByteBuffer { + buffer: y.to_le_bytes().to_vec(), + } + }) + .collect::>(), + ); + } + + t.qps("ct1'' and ct2''", ct1.len()); + Ok(()) + } + _ => { + error!("Cannot flatten ct1'' and ct2''"); + Err(ProtocolError::ErrorDeserialization( + "cannot flatten ct1'' and ct2''".to_string(), + )) + } + } + } + + // Send u1 to D + fn get_u1(&self) -> Result { + match self.u1.clone().read() { + Ok(u1) => { + let t = timer::Timer::new_silent("company get_u1"); + t.qps("u1", u1.len()); + + let mut d_flat = (*u1).clone().into_iter().flatten().collect::>(); + let n_rows = u1[0].len(); + let n_features = u1.len(); + let metadata = vec![ + ByteBuffer { + buffer: (n_rows as u64).to_le_bytes().to_vec(), + }, + ByteBuffer { + buffer: (n_features as u64).to_le_bytes().to_vec(), + }, + ]; + d_flat.extend(metadata); + + t.qps("d_flat", d_flat.len()); + Ok(d_flat) + } + _ => { + error!("Unable to flatten u1:"); + Err(ProtocolError::ErrorEncryption( + "cannot flatten u1".to_string(), + )) + } + } + } + + fn calculate_features_xor_shares( + &self, + partner_features: TFeatures, + g_zi: TPayload, + ) -> Result<(), ProtocolError> { + match self.partner_shares.clone().write() { + Ok(mut shares) => { + let n_features = partner_features.len(); + + let g_zi_pt = self.ec_cipher.to_points(&g_zi); + let r = g_zi_pt + .iter() + .map(|x| { + let t = self.ec_cipher.to_bytes(&[x * self.keypair_sk.0]); + u64::from_le_bytes((t[0].buffer[0..8]).try_into().unwrap()) + }) + .collect::>(); + + for f_idx in 0..n_features { + let s = partner_features[f_idx] + .iter() + .zip_eq(r.iter()) + .map(|(x1, x2)| *x1 ^ *x2) + .collect::>(); + shares.insert(f_idx, s); + } + + Ok(()) + } + _ => { + error!("Unable to calculate XOR shares"); + Err(ProtocolError::ErrorEncryption( + "unable to calculate XOR shares".to_string(), + )) + } + } + } + + fn write_company_to_id_map(&self) -> Result<(), ProtocolError> { + match ( + self.enc_company.clone().read(), + self.permutation.clone().read(), + self.id_map.clone().write(), + ) { + (Ok(pdata), Ok(permutation), Ok(mut id_map)) => { + let mut company_ragged = pdata.clone(); + undo_permute(permutation.as_slice(), &mut company_ragged); + + // Get the first column. + let company_keys = { + let tmp = company_ragged.iter().map(|s| s[0]).collect::>(); + self.ec_cipher.to_bytes(tmp.as_slice()) + }; + + id_map.clear(); + for (idx, k) in company_keys.iter().enumerate() { + id_map.push((k.to_string(), idx, true)); + } + + // Sort the id_map by the spine + id_map.sort_by(|(a, _, _), (b, _, _)| a.cmp(b)); + + Ok(()) + } + _ => { + error!("Cannot create id_map"); + Err(ProtocolError::ErrorDeserialization( + "cannot create id_map".to_string(), + )) + } + } + } + + fn print_id_map(&self) { + match (self.plaintext.clone().read(), self.id_map.clone().read()) { + (Ok(data), Ok(id_map)) => { + writer_helper(&data, &id_map, None); + } + _ => panic!("Cannot print id_map"), + } + } + + fn save_id_map(&self, path: &str) -> Result<(), ProtocolError> { + match (self.plaintext.clone().read(), self.id_map.clone().read()) { + (Ok(data), Ok(id_map)) => { + writer_helper(&data, &id_map, Some(path.to_string())); + Ok(()) + } + _ => Err(ProtocolError::ErrorIO( + "Unable to write company view to file".to_string(), + )), + } + } + + fn save_features_shares(&self, path_prefix: &str) -> Result<(), ProtocolError> { + match self.partner_shares.clone().read() { + Ok(shares) => { + assert!(shares.len() > 0); + + let mut out: Vec> = Vec::new(); + + for key in shares.keys().sorted() { + out.push(shares.get(key).unwrap().clone()); + } + + let p_filename = format!("{}{}", path_prefix, "_partner_features.csv"); + info!("revealing partner features to output file"); + common::files::write_u64cols_to_file(&mut out, Path::new(&p_filename)).unwrap(); + + Ok(()) + } + _ => Err(ProtocolError::ErrorIO( + "Unable to write company shares of partner features to file".to_string(), + )), + } + } +} diff --git a/protocol/src/dspmc/helper.rs b/protocol/src/dspmc/helper.rs new file mode 100644 index 0000000..f58d9fa --- /dev/null +++ b/protocol/src/dspmc/helper.rs @@ -0,0 +1,793 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate csv; + +use std::collections::HashMap; +use std::convert::TryInto; +use std::path::Path; +use std::sync::Arc; +use std::sync::RwLock; + +use common::permutations::gen_permute_pattern; +use common::permutations::permute; +use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; +use itertools::Itertools; +use rand::distributions::Uniform; +use rand::prelude::*; +use rand::Rng; + +use super::writer_helper; +use super::ProtocolError; +use crate::dspmc::traits::HelperDspmcProtocol; +use crate::shared::TFeatures; + +#[derive(Debug)] +pub struct HelperDspmc { + keypair_sk: Scalar, + keypair_pk: TPoint, + ec_cipher: ECRistrettoParallel, + company_public_key: Arc>, + xor_shares_v2: Arc>, // v2 = v xor v1 -- The shuffler has v1 + enc_company: Arc>>>, // H(C)^c + enc_partners: Arc>>>, // H(P)^c + features: Arc>, // v''' from shuffler + p_cd: Arc>>, + v_cd: Arc>>, + p_sd: Arc>>, + v_sd: Arc>>, + shuffler_gz: Arc>>, // h = g^z from shuffler + s_company: Arc>>, + s_partner: Arc>>, + id_map: Arc>>, + helper_shares: Arc>>>, +} + +impl HelperDspmc { + pub fn new() -> HelperDspmc { + let x = gen_scalar(); + HelperDspmc { + keypair_sk: x, + keypair_pk: &x * &RISTRETTO_BASEPOINT_TABLE, + ec_cipher: ECRistrettoParallel::default(), + company_public_key: Arc::new(RwLock::default()), + xor_shares_v2: Arc::new(RwLock::default()), + enc_company: Arc::new(RwLock::default()), + enc_partners: Arc::new(RwLock::default()), + features: Arc::new(RwLock::default()), + p_cd: Arc::new(RwLock::default()), + v_cd: Arc::new(RwLock::default()), + p_sd: Arc::new(RwLock::default()), + v_sd: Arc::new(RwLock::default()), + shuffler_gz: Arc::new(RwLock::default()), + s_company: Arc::new(RwLock::default()), + s_partner: Arc::new(RwLock::default()), + id_map: Arc::new(RwLock::default()), + helper_shares: Arc::new(RwLock::default()), + } + } + + pub fn get_helper_public_key(&self) -> Result { + Ok(self.ec_cipher.to_bytes(&[self.keypair_pk])) + } + + pub fn set_company_public_key( + &self, + company_public_key: TPayload, + ) -> Result<(), ProtocolError> { + let pk = self.ec_cipher.to_points(&company_public_key); + // Check that two keys are sent + assert_eq!(pk.len(), 2); + + match self.company_public_key.clone().write() { + Ok(mut company_pk) => { + company_pk.0 = pk[0]; + company_pk.1 = pk[1]; + assert!(!(company_pk.0).is_identity()); + assert!(!(company_pk.1).is_identity()); + Ok(()) + } + _ => { + error!("Unable to set company public key"); + Err(ProtocolError::ErrorEncryption( + "unable to set company public key".to_string(), + )) + } + } + } +} + +impl Default for HelperDspmc { + fn default() -> Self { + Self::new() + } +} + +impl HelperDspmcProtocol for HelperDspmc { + fn set_ct3p_cd_v_cd( + &self, + mut data: TPayload, + num_partners: usize, + v_cd_bytes: TPayload, + p_cd_bytes: TPayload, + ) -> Result<(), ProtocolError> { + match ( + self.xor_shares_v2.clone().write(), + self.p_cd.clone().write(), + self.v_cd.clone().write(), + ) { + (Ok(mut xor_shares_v2), Ok(mut p_cd), Ok(mut v_cd)) => { + let t = timer::Timer::new_silent("set v''"); + for _ in 0..num_partners { + // Data in form [(ct3, metadata), (ct3, metadata), ... ] + let n_features = u64::from_le_bytes( + data.pop().unwrap().buffer.as_slice().try_into().unwrap(), + ) as usize; + let n_rows = u64::from_le_bytes( + data.pop().unwrap().buffer.as_slice().try_into().unwrap(), + ) as usize; + + let ct3 = data.drain((data.len() - 1)..).collect::>(); + + // PRG seed = scalar * PK_helper + let seed = { + let x = self.ec_cipher.to_points_encrypt(&ct3, &self.keypair_sk); + &self.ec_cipher.to_bytes(&x)[0].buffer + }; + let seed_array: [u8; 32] = + seed.as_slice().try_into().expect("incorrect length"); + let mut rng = StdRng::from_seed(seed_array); + + // Merge features from all partners together. Example: + // features from P1: + // 10, 11, 12 + // 20, 21, 22 + // --> [[10, 20], [11, 21], [12, 22]] + // + // features from P2: + // 30, 31, 32 + // 40, 41, 42 + // --> [[30, 40], [31, 41], [32, 42]] + // + // Merged: [[10, 20, 30, 40], [11, 21, 31, 41], [12, 22, 32, 42]] + for f_idx in 0..n_features { + let t = (0..n_rows) + .collect::>() + .iter() + .map(|_| rng.gen::()) + .collect::>(); + if xor_shares_v2.len() != n_features { + xor_shares_v2.push(t); + } else { + xor_shares_v2[f_idx].extend(t); + } + } + } + + *v_cd = v_cd_bytes + .iter() + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap())) + .collect::>(); + + *p_cd = p_cd_bytes + .iter() + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) as usize) + .collect::>(); + + t.qps("deserialize_exp", xor_shares_v2.len()); + + Ok(()) + } + _ => { + error!("Cannot load xor_shares_v2"); + Err(ProtocolError::ErrorDeserialization( + "cannot load xor_shares_v2".to_string(), + )) + } + } + } + + fn set_encrypted_vprime( + &self, + blinded_features: TFeatures, + g_zi: TPayload, + ) -> Result<(), ProtocolError> { + match ( + self.features.clone().write(), + self.shuffler_gz.clone().write(), + ) { + (Ok(mut features), Ok(mut shuffler_gz)) => { + let t = timer::Timer::new_silent("set_encrypted_vprime"); + + features.clear(); + features.extend(blinded_features); + + shuffler_gz.clear(); + shuffler_gz.extend(g_zi); + + t.qps("deserialize_exp", shuffler_gz.len()); + + Ok(()) + } + _ => { + error!("Cannot load encrypted_vprime"); + Err(ProtocolError::ErrorDeserialization( + "cannot load encrypted_vprime".to_string(), + )) + } + } + } + + fn set_p_sd_v_sd( + &self, + v_sd_bytes: TPayload, + p_sd_bytes: TPayload, + ) -> Result<(), ProtocolError> { + match (self.p_sd.clone().write(), self.v_sd.clone().write()) { + (Ok(mut p_sd), Ok(mut v_sd)) => { + let t = timer::Timer::new_silent("set set_p_sd_v_sd"); + + *v_sd = v_sd_bytes + .iter() + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap())) + .collect::>(); + + *p_sd = p_sd_bytes + .iter() + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) as usize) + .collect::>(); + + t.qps("deserialize_exp", (*p_sd).len()); + + Ok(()) + } + _ => { + error!("Cannot load set_p_sd_v_sd"); + Err(ProtocolError::ErrorDeserialization( + "cannot load set_p_sd_v_sd".to_string(), + )) + } + } + } + + fn set_u1(&self, mut u1: TFeatures) -> Result<(), ProtocolError> { + match ( + self.p_sd.clone().read(), + self.v_sd.clone().read(), + self.xor_shares_v2.clone().write(), + ) { + (Ok(p_sd), Ok(v_sd), Ok(mut xor_shares_v2)) => { + let t = timer::Timer::new_silent("set set_u1"); + + let n_features = u1.len(); + + xor_shares_v2.clear(); + for f_idx in 0..n_features { + permute(p_sd.as_slice(), &mut u1[f_idx]); // p_sc + let t = u1[f_idx] + .iter() + .zip_eq(v_sd.iter()) + .map(|(s, t)| *s ^ *t) + .collect::>(); + xor_shares_v2.push(t); + } + + t.qps("deserialize_exp", xor_shares_v2.len()); + Ok(()) + } + _ => { + error!("Cannot load set_u1"); + Err(ProtocolError::ErrorDeserialization( + "cannot load set_u1".to_string(), + )) + } + } + } + + // Gets H(C)^c and ct1, ct2. + // Stores H(C)^c as enc_company + // Computes H(P)^c = ct2 / ct1^d + // Stores H(P)^c as enc_partners + fn set_encrypted_keys( + &self, + enc_keys: TPayload, + psum: Vec, + ct1_flat: TPayload, + ct2_flat: TPayload, + ct_psum: Vec, + ) -> Result<(), ProtocolError> { + match ( + self.enc_company.clone().write(), + self.enc_partners.clone().write(), + ) { + (Ok(mut enc_company), Ok(mut enc_partners)) => { + let t = timer::Timer::new_silent("set set_encrypted_keys"); + + // Unflatten and convert to points + let num_ct_keys = ct_psum.len() - 1; + *enc_partners = { + let t1 = self.ec_cipher.to_points(&ct1_flat); + let t2 = self.ec_cipher.to_points(&ct2_flat); + + let ct1_d = t1 + .iter() + .map(|x| *x * (&self.keypair_sk)) + .collect::>(); + + let y = t2 + .iter() + .zip_eq(ct1_d.iter()) + .map(|(s2, s1)| *s2 - *s1) + .collect::>(); + + ct_psum + .get(0..num_ct_keys) + .unwrap() + .iter() + .zip_eq(ct_psum.get(1..num_ct_keys + 1).unwrap().iter()) + .map(|(&x1, &x2)| y.get(x1..x2).unwrap().to_vec()) + .collect::>>() + }; + + // Unflatten and convert to points + let num_company_keys = psum.len() - 1; + *enc_company = { + let t = self.ec_cipher.to_points(&enc_keys); + + psum.get(0..num_company_keys) + .unwrap() + .iter() + .zip_eq(psum.get(1..num_company_keys + 1).unwrap().iter()) + .map(|(&x1, &x2)| t.get(x1..x2).unwrap().to_vec()) + .collect::>>() + }; + + t.qps("deserialize_exp", (*enc_company).len()); + Ok(()) + } + _ => { + error!("Cannot load set_encrypted_keys"); + Err(ProtocolError::ErrorDeserialization( + "cannot load set_encrypted_keys".to_string(), + )) + } + } + } + + // s_partner has all the data that are in Partner but did not get matched + // s_company has all the data that are in Company but did not get matched + // + // 1. For each item of the partners ID-MAP: + // * If it is already in the ID-MAP continue. + // * If it is a company key OR if it is not in s_partner (which means + // that this item is in the intersection), save it to ID-MAP and set + // "partner index" to the correct index and "found" to true. + // 2. For all the keys from s_company add them to ID-MAP and se "found" to + // false. + fn calculate_id_map(&self) { + match ( + self.enc_partners.clone().read(), + self.enc_company.clone().read(), + self.s_partner.clone().read(), + self.s_company.clone().read(), + self.id_map.clone().write(), + ) { + (Ok(enc_partners), Ok(enc_company), Ok(s_partner), Ok(s_company), Ok(mut id_map)) => { + // Get the first column. + let partner_keys = { + let tmp = enc_partners.iter().map(|s| s[0]).collect::>(); + self.ec_cipher.to_bytes(tmp.as_slice()) + }; + + // Get the first column. + let company_keys = { + let tmp = enc_company.iter().map(|s| s[0]).collect::>(); + self.ec_cipher.to_bytes(tmp.as_slice()) + }; + + // Put all the company keys into a map to access them quickly. + let mut company_keys_map = HashMap::new(); + for key in company_keys.iter() { + company_keys_map.insert(key.to_string(), false); + } + + // Put all the items of s_partner into a map to access them quickly. + let mut s_partner_map = HashMap::new(); + for key in s_partner.iter() { + s_partner_map.insert(key.to_string(), true); + } + + // Add the index of each item of partner_keys into company_keys_map + // if it's already there (i.e., if it's in the intersection). + let mut id_hashmap = HashMap::new(); + for (idx, key) in partner_keys.iter().enumerate() { + if id_hashmap.contains_key(&key.to_string()) { + continue; + } + if company_keys_map.contains_key(&key.to_string()) + || !s_partner_map.contains_key(&key.to_string()) + { + id_hashmap.insert(key.to_string(), (idx, true)); + } + } + + // Add all the remaining keys that company has but the partner doesn't. + for (idx, key) in s_company.iter().enumerate() { + id_hashmap.insert(key.to_string(), (idx, false)); + } + + id_map.clear(); + *id_map = id_hashmap + .into_iter() + .map(|(key, (idx, score))| (key, idx, score)) + .collect::>(); + + // Sort the id_map by the spine + id_map.sort_by(|(a, _, _), (b, _, _)| a.cmp(b)); + } + _ => panic!("Cannot create id-map"), + } + } + + fn calculate_set_diff(&self) -> Result<(), ProtocolError> { + match ( + self.enc_company.clone().read(), + self.enc_partners.clone().write(), + self.s_company.clone().write(), + self.s_partner.clone().write(), + ) { + (Ok(e_company), Ok(mut e_partner), Ok(mut s_company), Ok(mut s_partner)) => { + // let t = timer::Timer::new_silent("helper calculate_set_diff"); + + let s_c = e_company.iter().map(|e| e[0]).collect::>(); + let s_p = e_partner.iter().map(|e| e[0]).collect::>(); + + let max_len = e_company.iter().map(|e| e.len()).max().unwrap(); + + // Start with both vectors as all valid + let mut e_c_valid = vec![true; e_company.len()]; + let mut e_p_valid = vec![true; e_partner.len()]; + + for idx in 0..max_len { + // TODO: This should be a ByteBuffer instead of a vec + let mut e_c_map = HashMap::, usize>::new(); + + // Strip the idx-th key (viewed as a column) + for (e, i) in e_company + .iter() + .enumerate() + .filter(|(_, e)| e.len() > idx) + .map(|(i, e)| (e[idx], i)) + { + // Ristretto points are not hashable by themselves + e_c_map.insert(e.compress().to_bytes().to_vec(), i); + } + + // Vector of indices of e_p that match. These will be set to false + let mut e_p_match_idx = Vec::::new(); + for ((i, e), _) in e_partner + .iter_mut() + .enumerate() + .zip_eq(e_p_valid.iter()) + .filter(|((_, _), &f)| f) + { + // Find the minimum index where match happens + let match_idx = e + .iter() + .map(|key| + // TODO: Replace with match + if e_c_map.contains_key(&key.compress().to_bytes().to_vec()) { + let &m_idx = e_c_map.get(&key.compress().to_bytes().to_vec()).unwrap(); + (m_idx, e_c_valid[m_idx]) + } else { + // Using length of vector as a sentinel value. Will get + // filtered out because of false + (e_c_valid.len(), false) + }) + .filter(|(_, f)| *f) + .map(|(e, _)| e) + .min(); + + // For those indices that have matched - set them to false + // Also assign the correct keys + if let Some(m_idx) = match_idx { + // if the match occurred not in the first column, + // make sure the spine keys will be the same. + if idx > 0 { + e[0] = e_company[m_idx][0]; + } + e_c_valid[m_idx] = false; + e_p_match_idx.push(i); + } + } + + // Set all e_p that matched to false - so they aren't matched in the next + // iteration + e_p_match_idx.iter().for_each(|&idx| e_p_valid[idx] = false); + } + + // Create S_p by filtering out values that matched + s_partner.clear(); + { + // Only keep s_p that have not been matched + let mut inp = s_p + .iter() + .zip_eq(e_p_valid.iter()) + .filter(|(_, &f)| f) + .map(|(&e, _)| e) + .collect::>(); + + if !inp.is_empty() { + // Permute s_p + permute(gen_permute_pattern(inp.len()).as_slice(), &mut inp); + + // save output + s_partner.extend(self.ec_cipher.to_bytes(inp.as_slice())); + } + } + + // Create S_c by filtering out values that matched + let t = s_c + .iter() + .zip_eq(e_c_valid.iter()) + .filter(|(_, &f)| f) + .map(|(&e, _)| e) + .collect::>(); + s_company.clear(); + + if !t.is_empty() { + s_company.extend(self.ec_cipher.to_bytes(t.as_slice())); + } + // t.qps("s_company", s_company.len()); + + Ok(()) + } + _ => { + error!("Unable to obtain locks to buffers for set diff operation"); + Err(ProtocolError::ErrorEncryption( + "unable to encrypt data".to_string(), + )) + } + } + } + + // Compute u2 = p_cd(v_2) xor v_cd + fn get_u2(&self) -> Result { + match ( + self.xor_shares_v2.clone().write(), + self.p_cd.clone().read(), + self.v_cd.clone().read(), + ) { + (Ok(mut xor_shares_v2), Ok(p_cd), Ok(v_cd)) => { + let t = timer::Timer::new_silent("helper get_u2"); + + let n_rows = xor_shares_v2[0].len(); + let n_features = xor_shares_v2.len(); + let mut u2 = Vec::::new(); + // for f_idx in (0..n_features).rev() { + for f_idx in 0..n_features { + permute(p_cd.as_slice(), &mut xor_shares_v2[f_idx]); + let t = xor_shares_v2[f_idx] + .iter() + .zip_eq(v_cd.iter()) + .map(|(s, t)| { + let y = *s ^ *t; + ByteBuffer { + buffer: y.to_le_bytes().to_vec(), + } + }) + .collect::>(); + u2.push(t); + } + + let mut d_flat = u2.into_iter().flatten().collect::>(); + let metadata = vec![ + ByteBuffer { + buffer: (n_rows as u64).to_le_bytes().to_vec(), + }, + ByteBuffer { + buffer: (n_features as u64).to_le_bytes().to_vec(), + }, + ]; + d_flat.extend(metadata); + + t.qps("d_flat", d_flat.len()); + Ok(d_flat) + } + _ => { + error!("Cannot read get_u2"); + Err(ProtocolError::ErrorEncryption( + "unable to read get_u2".to_string(), + )) + } + } + } + + fn calculate_features_xor_shares(&self) -> Result { + match ( + self.features.clone().read(), + self.xor_shares_v2.clone().read(), + self.shuffler_gz.clone().read(), + self.id_map.clone().read(), + self.company_public_key.clone().read(), + self.helper_shares.clone().write(), + ) { + ( + Ok(partner_features), + Ok(xor_shares_v2), + Ok(shuffler_gz), + Ok(id_map), + Ok(company_pk), + Ok(mut shares), + ) => { + let t = timer::Timer::new_silent("helper calculate_features_xor_shares"); + let mut rng = rand::thread_rng(); + let range = Uniform::new(0_u64, u64::MAX); + + let n_features = partner_features.len(); + + let (t_i, mut g_zi) = { + let z_i = (0..id_map.len()).map(|_| gen_scalar()).collect::>(); + let x = z_i + .iter() + .map(|a| { + let x = self.ec_cipher.to_bytes(&[a * company_pk.0]); + x[0].clone() + }) + .collect::>(); + let y = z_i + .iter() + .map(|a| a * &RISTRETTO_BASEPOINT_TABLE) + .collect::>(); + (x, y) + }; + + let mut d_flat = { + let mut v_p = Vec::::new(); + + let shuffler_gz_points = self.ec_cipher.to_points(&shuffler_gz); + + for f_idx in (0..n_features).rev() { + let mask = (0..id_map.len()) + .map(|_| rng.sample(range)) + .collect::>(); + let t = id_map + .iter() + .enumerate() + .map(|(i, (_, idx, exists))| { + let y = if *exists { + if f_idx == 0 { + // If exists, overwrite g_z' with g_z from shuffler. + g_zi[i] = shuffler_gz_points[*idx]; + } + // v'' xor v''' xor mask = v'' xor v' xor r xor mask = + // v xor r xor mask + xor_shares_v2[f_idx][*idx] + ^ partner_features[f_idx][*idx] + ^ mask[i] + } else { + // If it doesn't exist, r xor mask + let y = u64::from_le_bytes( + (t_i[i].buffer[0..8]).try_into().unwrap(), + ); + y ^ mask[i] + }; + ByteBuffer { + buffer: y.to_le_bytes().to_vec(), + } + }) + .collect::>(); + + v_p.push(t); + shares.insert(f_idx, mask); + } + + v_p.into_iter().flatten().collect::>() + }; + + d_flat.extend(self.ec_cipher.to_bytes(&g_zi)); + + let metadata = vec![ + ByteBuffer { + buffer: (id_map.len() as u64).to_le_bytes().to_vec(), + }, + ByteBuffer { + buffer: (n_features as u64).to_le_bytes().to_vec(), + }, + ]; + + d_flat.extend(metadata); + t.qps("d_flat", d_flat.len()); + Ok(d_flat) + } + _ => { + error!("Cannot read id_map"); + Err(ProtocolError::ErrorEncryption( + "unable to read id_map".to_string(), + )) + } + } + } + + fn print_id_map(&self) { + match self.id_map.clone().read() { + Ok(id_map) => { + // Create fake data since we only have encrypted partner keys + let m_idx = id_map + .iter() + .filter(|(_, _, flag)| *flag) + .map(|(_, idx, _)| idx) + .max() + .unwrap(); + + let mut data: Vec> = vec![vec!["NA".to_string()]; m_idx + 1]; + + for i in 0..id_map.len() { + let (_, idx, flag) = id_map[i]; + if flag { + data[idx] = vec![format!(" Partner enc key at pos {}", idx)]; + } + } + + writer_helper(&data, &id_map, None); + } + _ => panic!("Cannot print id_map"), + } + } + + fn save_id_map(&self, path: &str) -> Result<(), ProtocolError> { + match self.id_map.clone().read() { + Ok(id_map) => { + // Create fake data since we only have encrypted partner keys + let m_idx = id_map + .iter() + .filter(|(_, _, flag)| *flag) + .map(|(_, idx, _)| idx) + .max() + .unwrap(); + + let mut data: Vec> = vec![vec!["NA".to_string()]; m_idx + 1]; + + for i in 0..id_map.len() { + let (_, idx, flag) = id_map[i]; + if flag { + data[idx] = vec![format!(" Partner enc key at pos {}", idx)]; + } + } + + writer_helper(&data, &id_map, Some(path.to_string())); + Ok(()) + } + _ => Err(ProtocolError::ErrorIO( + "Unable to write company view to file".to_string(), + )), + } + } + + fn save_features_shares(&self, path_prefix: &str) -> Result<(), ProtocolError> { + match self.helper_shares.clone().read() { + Ok(shares) => { + assert!(shares.len() > 0); + + let mut out: Vec> = Vec::new(); + + for key in shares.keys().sorted() { + out.push(shares.get(key).unwrap().clone()); + } + + let p_filename = format!("{}{}", path_prefix, "_partner_features.csv"); + info!("revealing partner features to output file"); + common::files::write_u64cols_to_file(&mut out, Path::new(&p_filename)).unwrap(); + + Ok(()) + } + _ => Err(ProtocolError::ErrorIO( + "Unable to write company shares of partner features to file".to_string(), + )), + } + } +} diff --git a/protocol/src/dspmc/mod.rs b/protocol/src/dspmc/mod.rs new file mode 100644 index 0000000..0780b37 --- /dev/null +++ b/protocol/src/dspmc/mod.rs @@ -0,0 +1,181 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate csv; + +use std::collections::HashSet; +use std::error::Error; +use std::fmt; +use std::sync::Arc; +use std::sync::RwLock; + +use common::files; +use common::timer; +use crypto::prelude::*; + +#[derive(Debug)] +pub enum ProtocolError { + ErrorDeserialization(String), + ErrorSerialization(String), + ErrorEncryption(String), + ErrorCalcSetDiff(String), + ErrorReencryption(String), + ErrorIO(String), +} + +impl fmt::Display for ProtocolError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "protocol error {}", self) + } +} + +impl Error for ProtocolError {} + +fn load_data_keys(plaintext: Arc>>>, path: &str, input_with_headers: bool) { + let t = timer::Timer::new_silent("load data"); + + let mut lines = files::read_csv_as_strings(path, true); + let text_len = lines.len(); + + if let Ok(mut data) = plaintext.write() { + data.clear(); + let mut line_it = lines.drain(..); + // Strip the header + if input_with_headers { + line_it.next(); + } + + let mut t = HashSet::>::new(); + // Filter out zero length strings - these will come from ragged + // arrays since they are padded out to the longest array + // Also deduplicate all input + for line in line_it { + let v = line + .iter() + .map(String::from) + .filter(|s| !s.is_empty()) + .collect::>(); + if !t.contains(&v) { + data.push(v.clone()); + t.insert(v); + } + } + info!("Read {} lines from {}", text_len, path,); + } + + t.qps("text read", text_len); +} + +fn load_data_features(plaintext: Arc>>>, path: &str) { + let t = timer::Timer::new_silent("load features"); + + let lines = files::read_csv_as_u64(path); + let n_rows = lines.len(); + let n_cols = lines[0].len(); + assert!(n_rows > 0); + assert!(n_cols > 0); + + if let Ok(mut data) = plaintext.write() { + data.clear(); + + // Make sure its not a ragged array + for i in lines.iter() { + assert_eq!(n_cols, i.len()); + } + + let mut features: Vec> = vec![vec![0; n_rows]; n_cols]; + + for (i, v) in lines.iter().enumerate() { + for (j, z) in v.iter().enumerate() { + features[j][i] = *z; + } + } + + data.extend(features.drain(..)); + + info!("Read {} lines from {}", n_rows, path,); + } + + t.qps("text read", n_rows); +} + +fn writer_helper(data: &[Vec], id_map: &[(String, usize, bool)], path: Option) { + let mut device = match path { + Some(path) => { + let wr = csv::WriterBuilder::new() + .flexible(true) + .buffer_capacity(1024) + .from_path(path) + .unwrap(); + Some(wr) + } + None => None, + }; + + for (key, idx, flag) in id_map.iter() { + let mut v = vec![(*key).clone()]; + + match flag { + true => v.extend(data[*idx].clone()), + false => v.push("NA".to_string()), + } + + match device { + Some(ref mut wr) => { + wr.write_record(v.as_slice()).unwrap(); + } + None => { + println!("{}", v.join(",")); + } + } + } +} + +fn compute_prefix_sum(input: &[usize]) -> Vec { + let prefix_sum = input + .iter() + .scan(0, |sum, i| { + *sum += i; + Some(*sum) + }) + .collect::>(); + + // offset is now a combined exclusive and inclusive prefix sum + // that will help us convert to a flattened vector and back to a + // vector of vectors + let mut output = Vec::::with_capacity(prefix_sum.len() + 1); + output.push(0); + output.extend(prefix_sum); + output +} + +fn serialize_helper(data: Vec>) -> (Vec, TPayload, TPayload) { + let offset = { + let lengths = data.iter().map(|v| v.len()).collect::>(); + compute_prefix_sum(&lengths) + .iter() + .map(|&o| ByteBuffer { + buffer: (o as u64).to_le_bytes().to_vec(), + }) + .collect::>() + }; + + let d_flat = data.into_iter().flatten().collect::>(); + + let metadata = vec![ + ByteBuffer { + buffer: (d_flat.len() as u64).to_le_bytes().to_vec(), + }, + ByteBuffer { + buffer: (offset.len() as u64).to_le_bytes().to_vec(), + }, + ]; + + (d_flat, offset, metadata) +} + +pub mod company; +pub mod helper; +pub mod partner; +pub mod shuffler; +pub mod traits; diff --git a/protocol/src/dspmc/partner.rs b/protocol/src/dspmc/partner.rs new file mode 100644 index 0000000..58c5392 --- /dev/null +++ b/protocol/src/dspmc/partner.rs @@ -0,0 +1,258 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate csv; + +use std::convert::TryInto; +use std::sync::Arc; +use std::sync::RwLock; + +use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; +use itertools::Itertools; +use rand::prelude::*; + +use super::load_data_features; +use super::load_data_keys; +use super::serialize_helper; +use super::ProtocolError; +use crate::dspmc::traits::PartnerDspmcProtocol; +use crate::shared::TFeatures; + +pub struct PartnerDspmc { + company_public_key: Arc>, + helper_public_key: Arc>, + ec_cipher: ECRistrettoParallel, + plaintext_keys: Arc>>>, + plaintext_features: Arc>, +} + +impl PartnerDspmc { + pub fn new() -> PartnerDspmc { + PartnerDspmc { + company_public_key: Arc::new(RwLock::default()), + helper_public_key: Arc::new(RwLock::default()), + ec_cipher: ECRistrettoParallel::default(), + plaintext_keys: Arc::new(RwLock::default()), + plaintext_features: Arc::new(RwLock::default()), + } + } + + // TODO: Fix header processing + pub fn load_data(&self, path_keys: &str, path_features: &str, input_with_headers: bool) { + load_data_keys(self.plaintext_keys.clone(), path_keys, input_with_headers); + load_data_features(self.plaintext_features.clone(), path_features); + + match ( + self.plaintext_keys.clone().read(), + self.plaintext_features.clone().read(), + ) { + (Ok(keys), Ok(features)) => { + assert!(features.len() > 0); + assert_eq!(keys.len(), features[0].len()); + } + _ => { + error!("Unable to read keys and features"); + } + } + } + + pub fn get_size(&self) -> usize { + self.plaintext_keys.clone().read().unwrap().len() + } + + pub fn set_company_public_key( + &self, + company_public_key: TPayload, + ) -> Result<(), ProtocolError> { + let pk = self.ec_cipher.to_points(&company_public_key); + // Check that two keys are sent + assert_eq!(pk.len(), 2); + + match self.company_public_key.clone().write() { + Ok(mut company_pk) => { + company_pk.0 = pk[0]; + company_pk.1 = pk[1]; + assert!(!(company_pk.0).is_identity()); + assert!(!(company_pk.1).is_identity()); + Ok(()) + } + _ => { + error!("Unable to set company public key"); + Err(ProtocolError::ErrorEncryption( + "unable to set company public key".to_string(), + )) + } + } + } + + pub fn set_helper_public_key(&self, helper_public_key: TPayload) -> Result<(), ProtocolError> { + let pk = self.ec_cipher.to_points(&helper_public_key); + // Check that one key is sent + assert_eq!(pk.len(), 1); + match self.helper_public_key.clone().write() { + Ok(mut helper_pk) => { + *helper_pk = pk[0]; + assert_eq!((*helper_pk).is_identity(), false); + Ok(()) + } + _ => { + error!("Unable to set helper public key"); + Err(ProtocolError::ErrorEncryption( + "unable to set helper public key".to_string(), + )) + } + } + } +} + +impl Default for PartnerDspmc { + fn default() -> Self { + Self::new() + } +} + +impl PartnerDspmcProtocol for PartnerDspmc { + fn get_encrypted_keys(&self) -> Result { + match ( + self.plaintext_keys.clone().read(), + self.company_public_key.clone().read(), + self.helper_public_key.clone().read(), + ) { + (Ok(pdata), Ok(company_pk), Ok(helper_pk)) => { + let t = timer::Timer::new_silent("partner data"); + + // let n_rows = pdata.len(); + + // ct2 = helper_pk^r * H(P_i) + // with EC: helper_pk*r + H(P_i) + let (mut d_flat, ct1_flat, offset) = { + let (d_flat, mut offset, metadata) = serialize_helper(pdata.clone()); + offset.extend(metadata); + let hash_p = self.ec_cipher.hash(d_flat.as_slice()); + + // ct1 = company_pk^r + // with EC: company_pk * r + let (ct1, pkd_r) = { + let r_i = (0..d_flat.len()) + .collect::>() + .iter() + .map(|_| gen_scalar()) + .collect::>(); + let ct1_bytes = { + let t1 = r_i.iter().map(|x| *x * company_pk.0).collect::>(); + self.ec_cipher.to_bytes(&t1) + }; + let pkd_r = r_i.iter().map(|x| *x * (*helper_pk)).collect::>(); + (ct1_bytes, pkd_r) + }; + + let ct2 = pkd_r + .iter() + .zip_eq(hash_p.iter()) + .map(|(s, t)| *s + *t) + .collect::>(); + + (self.ec_cipher.to_bytes(ct2.as_slice()), ct1, offset) + }; + + // Append ct1 + d_flat.extend(ct1_flat); + // Append offsets array + d_flat.extend(offset); + + t.qps("encryption", d_flat.len()); + + // d_flat = ct2, ct1, offset + Ok(d_flat) + } + _ => { + error!("Unable to encrypt data"); + Err(ProtocolError::ErrorEncryption( + "unable to encrypt data".to_string(), + )) + } + } + } + + fn get_features_xor_shares(&self) -> Result { + match ( + self.plaintext_features.clone().read(), + self.helper_public_key.clone().read(), + ) { + (Ok(pdata), Ok(helper_pk)) => { + let t = timer::Timer::new_silent("get_features_xor_shares"); + let n_rows = pdata[0].len(); + let n_features = pdata.len(); + + // ct3 = scalar * g + // PRG seed = scalar * PK_helper + let (seed, ct3) = { + let x = gen_scalar(); + let ct3 = self.ec_cipher.to_bytes(&[&x * &RISTRETTO_BASEPOINT_TABLE]); + let seed: [u8; 32] = { + let t = self.ec_cipher.to_bytes(&[x * (*helper_pk)]); + t[0].buffer.as_slice().try_into().expect("incorrect length") + }; + (seed, ct3) + }; + + let mut rng = StdRng::from_seed(seed); + let mut v2 = TFeatures::new(); + for _ in 0..n_features { + let t = (0..n_rows) + .collect::>() + .iter() + .map(|_| rng.gen::()) + .collect::>(); + v2.push(t); + } + + let mut d_flat = { + let mut v_p = Vec::::new(); + + for f_idx in (0..n_features).rev() { + let t = (0..n_rows) + .map(|x| x) + .collect::>() + .iter() + .map(|i| { + let z: u64 = pdata[f_idx][*i] ^ v2[f_idx][*i]; + ByteBuffer { + buffer: z.to_le_bytes().to_vec(), + } + }) + .collect::>(); + v_p.push(t); + } + + v_p.into_iter().flatten().collect::>() + }; + + let metadata = vec![ + ByteBuffer { + buffer: (n_rows as u64).to_le_bytes().to_vec(), + }, + ByteBuffer { + buffer: (n_features as u64).to_le_bytes().to_vec(), + }, + ]; + d_flat.extend(metadata); + d_flat.extend(ct3); + + t.qps("d_flat", d_flat.len()); + + Ok(d_flat) + } + _ => { + error!("Unable to encrypt data"); + Err(ProtocolError::ErrorEncryption( + "unable to encrypt data".to_string(), + )) + } + } + } +} diff --git a/protocol/src/dspmc/shuffler.rs b/protocol/src/dspmc/shuffler.rs new file mode 100644 index 0000000..a8cc0a1 --- /dev/null +++ b/protocol/src/dspmc/shuffler.rs @@ -0,0 +1,429 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate csv; + +use std::convert::TryInto; +use std::sync::Arc; +use std::sync::RwLock; + +use common::permutations::gen_permute_pattern; +use common::permutations::permute; +use common::timer; +use crypto::eccipher::gen_scalar; +use crypto::eccipher::ECCipher; +use crypto::eccipher::ECRistrettoParallel; +use crypto::prelude::*; +use itertools::Itertools; +use rand::distributions::Uniform; +use rand::Rng; + +use super::serialize_helper; +use super::ProtocolError; +use crate::dspmc::traits::ShufflerDspmcProtocol; +use crate::shared::TFeatures; + +pub struct ShufflerDspmc { + company_public_key: Arc>, + helper_public_key: Arc>, + ec_cipher: ECRistrettoParallel, + p_cs: Arc>>, + v_cs: Arc>>, + perms: Arc, Vec)>>, // (p_sc, p_sd) + blinds: Arc, Vec)>>, // (v_sc, v_sd) + xor_shares_v1: Arc>, // v' + ct1_dprime: Arc>>>, + ct2_dprime: Arc>>>, +} + +impl ShufflerDspmc { + pub fn new() -> ShufflerDspmc { + ShufflerDspmc { + company_public_key: Arc::new(RwLock::default()), + helper_public_key: Arc::new(RwLock::default()), + ec_cipher: ECRistrettoParallel::default(), + p_cs: Arc::new(RwLock::default()), + v_cs: Arc::new(RwLock::default()), + perms: Arc::new(RwLock::default()), + blinds: Arc::new(RwLock::default()), + xor_shares_v1: Arc::new(RwLock::default()), + ct1_dprime: Arc::new(RwLock::default()), + ct2_dprime: Arc::new(RwLock::default()), + } + } + + pub fn set_company_public_key( + &self, + company_public_key: TPayload, + ) -> Result<(), ProtocolError> { + let pk = self.ec_cipher.to_points(&company_public_key); + // Check that two keys are sent + assert_eq!(pk.len(), 2); + + match self.company_public_key.clone().write() { + Ok(mut company_pk) => { + company_pk.0 = pk[0]; + company_pk.1 = pk[1]; + assert!(!(company_pk.0).is_identity()); + assert!(!(company_pk.1).is_identity()); + Ok(()) + } + _ => { + error!("Unable to set company public key"); + Err(ProtocolError::ErrorEncryption( + "unable to set company public key".to_string(), + )) + } + } + } + + pub fn set_helper_public_key(&self, helper_public_key: TPayload) -> Result<(), ProtocolError> { + let pk = self.ec_cipher.to_points(&helper_public_key); + // Check that one key is sent + assert_eq!(pk.len(), 1); + + match self.helper_public_key.clone().write() { + Ok(mut helper_pk) => { + *helper_pk = pk[0]; + assert_eq!((*helper_pk).is_identity(), false); + Ok(()) + } + _ => { + error!("Unable to set helper public key"); + Err(ProtocolError::ErrorEncryption( + "unable to set helper public key".to_string(), + )) + } + } + } +} + +impl Default for ShufflerDspmc { + fn default() -> Self { + Self::new() + } +} + +impl ShufflerDspmcProtocol for ShufflerDspmc { + fn set_p_cs_v_cs( + &self, + v_cs_bytes: TPayload, + p_cs_bytes: TPayload, + ) -> Result<(), ProtocolError> { + match (self.p_cs.clone().write(), self.v_cs.clone().write()) { + (Ok(mut p_cs), Ok(mut v_cs)) => { + let t = timer::Timer::new_silent("set p_cs, v_cs"); + *v_cs = v_cs_bytes + .iter() + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap())) + .collect::>(); + + *p_cs = p_cs_bytes + .iter() + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap()) as usize) + .collect::>(); + + t.qps("deserialize_exp", p_cs_bytes.len()); + Ok(()) + } + _ => { + error!("Cannot load p_cs, v_cs"); + Err(ProtocolError::ErrorDeserialization( + "cannot load p_cs, v_cs".to_string(), + )) + } + } + } + + fn gen_permutations(&self) -> Result<(TPayload, TPayload), ProtocolError> { + match ( + self.p_cs.clone().read(), + self.perms.clone().write(), + self.blinds.clone().write(), + ) { + (Ok(p_cs), Ok(mut perms), Ok(mut blinds)) => { + let mut rng = rand::thread_rng(); + let range = Uniform::new(0_u64, u64::MAX); + + let data_len = p_cs.len(); + assert!(data_len > 0); + perms.0.clear(); + perms.1.clear(); + perms.0.extend(gen_permute_pattern(data_len)); + perms.1.extend(gen_permute_pattern(data_len)); + + blinds.0 = (0..data_len) + .map(|_| rng.sample(range)) + .collect::>(); + blinds.1 = (0..data_len) + .map(|_| rng.sample(range)) + .collect::>(); + + let mut p_sc_v_sc = perms + .0 + .iter() + .map(|e| ByteBuffer { + buffer: (*e).to_le_bytes().to_vec(), + }) + .collect::>(); + let v_sc_bytes = blinds + .0 + .iter() + .map(|e| ByteBuffer { + buffer: (*e).to_le_bytes().to_vec(), + }) + .collect::>(); + p_sc_v_sc.extend(v_sc_bytes); + + let mut p_sd_v_sd = perms + .1 + .iter() + .map(|e| ByteBuffer { + buffer: (*e as u64).to_le_bytes().to_vec(), + }) + .collect::>(); + let v_sd_bytes = blinds + .1 + .iter() + .map(|e| ByteBuffer { + buffer: (*e as u64).to_le_bytes().to_vec(), + }) + .collect::>(); + p_sd_v_sd.extend(v_sd_bytes); + p_sd_v_sd.push(ByteBuffer { + buffer: (data_len as u64).to_le_bytes().to_vec(), + }); + + Ok((p_sc_v_sc, p_sd_v_sd)) + } + _ => { + error!("Unable to generate p_sc, v_sc, p_sd, v_sd:"); + Err(ProtocolError::ErrorEncryption( + "cannot generate p_sc, v_sc, p_sd, v_sd".to_string(), + )) + } + } + } + + fn compute_v2prime_ct1ct2( + &self, + mut u2: TFeatures, + ct1_prime_bytes: TPayload, + ct2_prime_bytes: TPayload, + psum: Vec, + ) -> Result { + match ( + self.p_cs.clone().read(), + self.v_cs.clone().read(), + self.perms.clone().read(), + self.blinds.clone().read(), + self.company_public_key.clone().read(), + self.helper_public_key.clone().read(), + self.xor_shares_v1.clone().write(), + self.ct1_dprime.clone().write(), + self.ct2_dprime.clone().write(), + ) { + ( + Ok(p_cs), + Ok(v_cs), + Ok(perms), + Ok(blinds), + Ok(company_pk), + Ok(helper_pk), + Ok(mut v_p), + Ok(mut ct1_dprime), + Ok(mut ct2_dprime), + ) => { + // This is an array of exclusive-inclusive prefix sum - hence + // number of keys is one less than length + let num_keys = psum.len() - 1; + + // Unflatten and convert to points + let mut ct1_prime = { + let t = self.ec_cipher.to_points(&ct1_prime_bytes); + + psum.get(0..num_keys) + .unwrap() + .iter() + .zip_eq(psum.get(1..num_keys + 1).unwrap().iter()) + .map(|(&x1, &x2)| t.get(x1..x2).unwrap().to_vec()) + .collect::>>() + }; + let mut ct2_prime = { + let t = self.ec_cipher.to_points(&ct2_prime_bytes); + + psum.get(0..num_keys) + .unwrap() + .iter() + .zip_eq(psum.get(1..num_keys + 1).unwrap().iter()) + .map(|(&x1, &x2)| t.get(x1..x2).unwrap().to_vec()) + .collect::>>() + }; + + let n_rows = u2[0].len(); + let n_features = u2.len(); + + // Compute p_sd(p_sc(p_cs( u2 ) xor v_cs) xor v_sc) xor v_sd + (*v_p).clear(); + for f_idx in (0..n_features).rev() { + permute(p_cs.as_slice(), &mut u2[f_idx]); // p_cs + let mut x_2 = u2[f_idx] + .iter() + .zip_eq(v_cs.iter()) + .map(|(s, v_cs)| *s ^ *v_cs) + .collect::>(); + + permute(perms.0.as_slice(), &mut x_2); // p_sc + let mut t_1 = x_2 + .iter() + .zip_eq(blinds.0.iter()) + .map(|(s, v_sc)| *s ^ *v_sc) + .collect::>(); + + permute(perms.1.as_slice(), &mut t_1); // p_sd + (*v_p).push( + t_1.iter() + .zip_eq(blinds.1.iter()) + .map(|(s, v_sd)| *s ^ *v_sd) + .collect::>(), + ); + } + + // Re-randomize ct1'' and ct2'' and flatten + let (mut ct1_dprime_flat, ct2_dprime_flat, ct_offset) = { + let r_i = (0..n_rows) + .collect::>() + .iter() + .map(|_| gen_scalar()) + .collect::>(); + // company_pk^r + // with EC: company_pk * r + let pkc_r = r_i.iter().map(|x| *x * (company_pk.0)).collect::>(); + // helper_pk^r + // with EC: helper_pk * r + let pkd_r = r_i.iter().map(|x| *x * (*helper_pk)).collect::>(); + + permute(perms.0.as_slice(), &mut ct1_prime); // p_sc + permute(perms.0.as_slice(), &mut ct2_prime); // p_sc + permute(perms.1.as_slice(), &mut ct1_prime); // p_sd + permute(perms.1.as_slice(), &mut ct2_prime); // p_sd + + // ct1' = ct1'' * company_pk^r + // with EC: ct1' = ct1'' + company_pk*r + *ct1_dprime = ct1_prime + .iter() + .zip_eq(pkc_r.iter()) + .map(|(s, t)| (*s).iter().map(|si| *si + *t).collect::>()) + .collect::>(); + // ct2' = ct2'' * helper_pk^r + // with EC: ct2' = ct2'' + helper_pk*r + *ct2_dprime = ct2_prime + .iter() + .zip_eq(pkd_r.iter()) + .map(|(s, t)| (*s).iter().map(|si| *si + *t).collect::>()) + .collect::>(); + let (ct1_dprime_flat, _ct1_offset) = { + let (d_flat, mut offset, metadata) = serialize_helper(ct1_dprime.clone()); + offset.extend(metadata); + + (self.ec_cipher.to_bytes(d_flat.as_slice()), offset) + }; + let (ctd2_dprime_flat, ct2_offset) = { + let (d_flat, mut offset, metadata) = serialize_helper(ct2_dprime.clone()); + offset.extend(metadata); + + (self.ec_cipher.to_bytes(d_flat.as_slice()), offset) + }; + (ct1_dprime_flat, ctd2_dprime_flat, ct2_offset) + }; + ct1_dprime_flat.extend(ct2_dprime_flat); + ct1_dprime_flat.extend(ct_offset); + + // ct1_dprime_flat, ct2_dprime_flat, ct_offset + Ok(ct1_dprime_flat) + } + _ => { + error!("Unable to flatten ct1'' and ct2'':"); + Err(ProtocolError::ErrorEncryption( + "cannot flatten ct1'' and ct2''".to_string(), + )) + } + } + } + + fn get_blinded_vprime(&self) -> Result { + match ( + self.company_public_key.clone().read(), + self.xor_shares_v1.clone().read(), + ) { + (Ok(company_pk), Ok(v_p)) => { + let t = timer::Timer::new_silent("get_blinded_vprime"); + + let n_rows = v_p[0].len(); + let n_features = v_p.len(); + + let z_i = (0..n_rows).map(|_| gen_scalar()).collect::>(); + + let mut d_flat = { + let r_i = { + let y_zi = { + let t = z_i.iter().map(|x| *x * company_pk.0).collect::>(); + self.ec_cipher.to_bytes(&t) + }; + y_zi.iter() + .map(|x| u64::from_le_bytes((x.buffer[0..8]).try_into().unwrap())) + .collect::>() + }; + + let mut blinded_vprime = Vec::::new(); + + for f_idx in (0..n_features).rev() { + let t = (0..n_rows) + .collect::>() + .iter() + .map(|i| { + let z: u64 = v_p[f_idx][*i] ^ r_i[*i]; + ByteBuffer { + buffer: z.to_le_bytes().to_vec(), + } + }) + .collect::>(); + blinded_vprime.push(t); + } + + blinded_vprime.into_iter().flatten().collect::>() + }; + + { + let g_zi = { + let t = z_i + .iter() + .map(|x| x * &RISTRETTO_BASEPOINT_TABLE) + .collect::>(); + self.ec_cipher.to_bytes(&t) + }; + d_flat.extend(g_zi); + } + + let metadata = vec![ + ByteBuffer { + buffer: (n_rows as u64).to_le_bytes().to_vec(), + }, + ByteBuffer { + buffer: (n_features as u64).to_le_bytes().to_vec(), + }, + ]; + d_flat.extend(metadata); + + t.qps("get_blinded_vprime", d_flat.len()); + Ok(d_flat) + } + _ => { + error!("Unable to serialize v1':"); + Err(ProtocolError::ErrorEncryption( + "cannot serialize v1'".to_string(), + )) + } + } + } +} diff --git a/protocol/src/dspmc/traits.rs b/protocol/src/dspmc/traits.rs new file mode 100644 index 0000000..2b12ef3 --- /dev/null +++ b/protocol/src/dspmc/traits.rs @@ -0,0 +1,100 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// SPDX-License-Identifier: Apache-2.0 + +extern crate crypto; + +use crypto::prelude::TPayload; + +use crate::dspmc::ProtocolError; +use crate::shared::TFeatures; + +pub trait PartnerDspmcProtocol { + fn get_encrypted_keys(&self) -> Result; + fn get_features_xor_shares(&self) -> Result; +} + +pub trait ShufflerDspmcProtocol { + fn set_p_cs_v_cs( + &self, + v_cs_bytes: TPayload, + p_cs_bytes: TPayload, + ) -> Result<(), ProtocolError>; + fn gen_permutations(&self) -> Result<(TPayload, TPayload), ProtocolError>; + fn get_blinded_vprime(&self) -> Result; + fn compute_v2prime_ct1ct2( + &self, + u2_bytes: TFeatures, + ct1_prime_bytes: TPayload, + ct2_prime_bytes: TPayload, + psum: Vec, + ) -> Result; +} + +pub trait CompanyDspmcProtocol { + fn set_encrypted_partner_keys_and_shares( + &self, + ct1: TPayload, + ct2: TPayload, + keys_psum: Vec, + ct3: Vec, + xor_features: TFeatures, + ) -> Result<(), ProtocolError>; + fn get_all_ct3_p_cd_v_cd(&self) -> Result; + fn get_company_keys(&self) -> Result; + fn get_ct1_ct2(&self) -> Result; + fn get_p_cs_v_cs(&self) -> Result; + fn get_u1(&self) -> Result; + fn calculate_features_xor_shares( + &self, + features: TFeatures, + g_zi: TPayload, + ) -> Result<(), ProtocolError>; + fn write_company_to_id_map(&self) -> Result<(), ProtocolError>; + fn print_id_map(&self); + fn save_id_map(&self, path: &str) -> Result<(), ProtocolError>; + fn save_features_shares(&self, path_prefix: &str) -> Result<(), ProtocolError>; + fn set_p_sc_v_sc_ct1ct2dprime( + &self, + v_sc_bytes: TPayload, + p_sc_bytes: TPayload, + ct1_dprime_flat: TPayload, + ct2_dprime_flat: TPayload, + psum: Vec, + ) -> Result<(), ProtocolError>; +} + +pub trait HelperDspmcProtocol { + fn set_ct3p_cd_v_cd( + &self, + ct3: TPayload, + num_partners: usize, + v_cd_bytes: TPayload, + p_cd_bytes: TPayload, + ) -> Result<(), ProtocolError>; + fn set_encrypted_vprime( + &self, + blinded_features: TFeatures, + g_zi: TPayload, + ) -> Result<(), ProtocolError>; + fn set_encrypted_keys( + &self, + enc_keys: TPayload, + psum: Vec, + ct1: TPayload, + ct2: TPayload, + ct_psum: Vec, + ) -> Result<(), ProtocolError>; + fn set_p_sd_v_sd( + &self, + v_sd_bytes: TPayload, + p_sd_bytes: TPayload, + ) -> Result<(), ProtocolError>; + fn set_u1(&self, u1_bytes: TFeatures) -> Result<(), ProtocolError>; + fn calculate_set_diff(&self) -> Result<(), ProtocolError>; + fn calculate_features_xor_shares(&self) -> Result; + fn get_u2(&self) -> Result; + fn calculate_id_map(&self); + fn print_id_map(&self); + fn save_id_map(&self, path: &str) -> Result<(), ProtocolError>; + fn save_features_shares(&self, path_prefix: &str) -> Result<(), ProtocolError>; +} diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 061ccc7..dae101d 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -8,6 +8,8 @@ extern crate log; pub mod cross_psi; pub mod cross_psi_xor; +pub mod dpmc; +pub mod dspmc; pub mod fileio; pub mod pjc; pub mod private_id;