Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ot #133

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 59 additions & 27 deletions src/encryption/asymmetric/rsa/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,56 +17,88 @@ const fn mod_inverse(e: u64, totient: u64) -> u64 {
}
d
}
/// RSAKey struct
pub struct RSA {
/// pub key (e,n)
pub private_key: PrivateKey,
/// priv key (d,n)
pub public_key: PublicKey,
}

/// private key
pub struct PrivateKey {
/// gcd(e, totient) = 1
e: usize,
/// d x e mod totient = 1
d: usize,
/// modulus
n: usize,
pub n: usize,
}

/// public key
pub struct PublicKey {
/// d x e mod totient = 1
d: usize,
/// gcd(e, totient) = 1
e: usize,
/// modulus
n: usize,
pub n: usize,
}

impl RSA {
/// Encrypts a message using the RSA algorithm
#[allow(dead_code)]
/// RSA encryption
pub struct RSAEncryption {
/// RSA public key
pub public_key: PublicKey,
}

/// RSA decryption
pub struct RSADecryption {
/// RSA private key
pub private_key: PrivateKey,
}

impl RSAEncryption {
/// Encrypts a message using the RSA algorithm
/// C = P^e mod n
const fn encrypt(&self, message: u32) -> u32 {
message.pow(self.private_key.e as u32) % self.private_key.n as u32
pub const fn encrypt(&self, plaintext: u32) -> u32 {
let mut plaintext = plaintext;
let mut res = 1;
let mut exp = self.public_key.e as u32;

while exp > 0 {
if exp % 2 == 1 {
res = ((res as u64 * plaintext as u64) % self.public_key.n as u64) as u32;
}
plaintext = ((plaintext as u64).pow(2) % self.public_key.n as u64) as u32;
exp >>= 1;
}

res
}
}

#[allow(dead_code)]
impl RSADecryption {
/// Decrypts a cipher using the RSA algorithm
/// P = C^d mod n
const fn decrypt(&self, cipher: u32) -> u32 {
cipher.pow(self.public_key.d as u32) % self.public_key.n as u32
pub const fn decrypt(&self, ciphertext: u32) -> u32 {
let mut res = 1;
let mut ciphertext = ciphertext;
let mut exp = self.private_key.d as u32;

while exp > 0 {
if exp % 2 == 1 {
res = ((res as u64 * ciphertext as u64) % self.private_key.n as u64) as u32;
}
ciphertext = ((ciphertext as u64).pow(2) % self.private_key.n as u64) as u32;
exp >>= 1;
}

res
// ((ciphertext as u64).pow(self.private_key.d as u32) % self.private_key.n as u64) as u32
}
}

/// Key generation for the RSA algorithm
/// TODO: Implement a secure key generation algorithm using miller rabin primality test
pub fn rsa_key_gen(p: usize, q: usize) -> RSA {
pub fn rsa_key_gen(p: usize, q: usize) -> (RSAEncryption, RSADecryption) {
assert!(is_prime(p));
assert!(is_prime(q));
let n = p * q;
let e = generate_e(p, q);
let totient = euler_totient(p as u64, q as u64);
let d = mod_inverse(e, totient);
RSA { private_key: PrivateKey { e: e as usize, n }, public_key: PublicKey { d: d as usize, n } }
(RSAEncryption { public_key: PublicKey { e: e as usize, n } }, RSADecryption {
private_key: PrivateKey { d: d as usize, n },
})
}

/// Generates e value for the RSA algorithm
Expand All @@ -86,10 +118,10 @@ const fn generate_e(p: usize, q: usize) -> u64 {
panic!("Failed to find coprime e; totient should be greater than 1")
}

/// Generates a random prime number bigger than 1_000_000
pub fn random_prime(first_prime: usize) -> usize {
let mut n = 1_000_000;
while !is_prime(n) && n != first_prime {
/// Generates a random prime number bigger than `begin`
pub fn random_prime(begin: usize) -> usize {
let mut n = begin;
while !is_prime(n) {
n += 1;
}
n
Expand Down
64 changes: 36 additions & 28 deletions src/encryption/asymmetric/rsa/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,26 @@ fn test_euler_totient() {

#[test]
fn key_gen() {
let key = rsa_key_gen(PRIME_1, PRIME_2);
assert_eq!(key.public_key.n, PRIME_1 * PRIME_2);
assert_eq!(gcd(key.private_key.e as u64, euler_totient(PRIME_1 as u64, PRIME_2 as u64)), 1);

let key = rsa_key_gen(PRIME_2, PRIME_3);
assert_eq!(key.public_key.n, PRIME_2 * PRIME_3);
assert_eq!(gcd(key.private_key.e as u64, euler_totient(PRIME_2 as u64, PRIME_3 as u64)), 1);

let key = rsa_key_gen(PRIME_3, PRIME_1);
assert_eq!(key.public_key.n, PRIME_3 * PRIME_1);
assert_eq!(gcd(key.private_key.e as u64, euler_totient(PRIME_3 as u64, PRIME_1 as u64)), 1);
let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(PRIME_1, PRIME_2);
assert_eq!(rsa_encrypt.public_key.n, PRIME_1 * PRIME_2);
assert_eq!(
gcd(rsa_decrypt.private_key.d as u64, euler_totient(PRIME_1 as u64, PRIME_2 as u64)),
1
);

let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(PRIME_2, PRIME_3);
assert_eq!(rsa_encrypt.public_key.n, PRIME_2 * PRIME_3);
assert_eq!(
gcd(rsa_decrypt.private_key.d as u64, euler_totient(PRIME_2 as u64, PRIME_3 as u64)),
1
);

let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(PRIME_3, PRIME_1);
assert_eq!(rsa_encrypt.public_key.n, PRIME_3 * PRIME_1);
assert_eq!(
gcd(rsa_decrypt.private_key.d as u64, euler_totient(PRIME_3 as u64, PRIME_1 as u64)),
1
);
}

#[test]
Expand Down Expand Up @@ -58,33 +67,32 @@ fn test_mod_inverse() {
#[test]
fn test_encrypt_decrypt() {
let message = 10;
let key = rsa_key_gen(PRIME_1, PRIME_2);
let cipher = key.encrypt(message);
let decrypted = key.decrypt(cipher);
let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(PRIME_1, PRIME_2);
let cipher = rsa_encrypt.encrypt(message);
let decrypted = rsa_decrypt.decrypt(cipher);
assert_eq!(decrypted, message);

let key = rsa_key_gen(PRIME_2, PRIME_3);
let cipher = key.encrypt(message);
let decrypted = key.decrypt(cipher);
let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(PRIME_2, PRIME_3);
let cipher = rsa_encrypt.encrypt(message);
let decrypted = rsa_decrypt.decrypt(cipher);
assert_eq!(decrypted, message);

let message = 10;
let key = rsa_key_gen(PRIME_3, PRIME_1);
let cipher = key.encrypt(message);
let decrypted = key.decrypt(cipher);
let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(PRIME_3, PRIME_1);
let cipher = rsa_encrypt.encrypt(message);
let decrypted = rsa_decrypt.decrypt(cipher);
assert_eq!(decrypted, message);
}

#[test]
fn test_random_prime() {
let prime = random_prime(2);
assert!(is_prime(prime));
assert!(prime >= 1_000_000);
let message = u16::MAX as u32;
let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(10007, 49999);
let cipher = rsa_encrypt.encrypt(message);
let decrypted = rsa_decrypt.decrypt(cipher);
assert_eq!(decrypted, message);
}

#[test]
fn test_random_prime_generation() {
let prime = random_prime(2);
fn test_random_prime() {
let prime = random_prime(1_000_000);
assert!(is_prime(prime));
assert!(prime >= 1_000_000);
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub mod encryption;
pub mod field;
pub mod hashes;
pub mod kzg;
pub mod ot;
pub mod polynomial;
pub mod tree;

Expand Down
Empty file added src/ot/README.md
Empty file.
2 changes: 2 additions & 0 deletions src/ot/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
//! Contains implementation of oblivious transfer and various extensions.
pub mod ot_rsa;
120 changes: 120 additions & 0 deletions src/ot/ot_rsa.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
//! Contains implementation of 1-out-of-2 OT using RSA encryption.

use rand::{thread_rng, Rng};

use crate::encryption::asymmetric::rsa::{rsa_key_gen, RSADecryption, RSAEncryption};

/// Sender that has two messages and wants to send one of it to [`OTReceiver`] without knowledge of
/// which one.
pub struct OTSender {
messages: [usize; 2],
random_messages: [usize; 2],
rsa_decrypt: RSADecryption,
}

/// Receiver wants to get access to one of the message that [`OTSender`] has without knowledge of
/// the other.
pub struct OTReceiver {
choice: bool,
key: usize,
}

impl OTSender {
/// create a new [`OTSender`] object.
/// ## Arguments
/// - `messages`: message that sender has access to
/// - `primes`: [`RSAEncryption`] primes
pub fn new(messages: [usize; 2], primes: [usize; 2]) -> (Self, RSAEncryption, [usize; 2]) {
let (rsa_encrypt, rsa_decrypt) = rsa_key_gen(primes[0], primes[1]);

let random_messages: [usize; 2] = rand::random();
(OTSender { messages, rsa_decrypt, random_messages }, rsa_encrypt, random_messages)
}

/// Encrypt messages with receiver's choice
pub fn encrypt(&self, v: usize) -> [usize; 2] {
let k0 = if v < self.random_messages[0] {
v + self.rsa_decrypt.private_key.n
- (self.random_messages[0] % self.rsa_decrypt.private_key.n)
} else {
v - self.random_messages[0]
};
let k1 = if v < self.random_messages[1] {
v + self.rsa_decrypt.private_key.n
- (self.random_messages[1] % self.rsa_decrypt.private_key.n)
} else {
v - self.random_messages[1]
};

let k0 = self.rsa_decrypt.decrypt((k0) as u32);
let k1 = self.rsa_decrypt.decrypt((k1) as u32);

println!("k0: {}, k1: {}", k0, k1);

let m0 = (self.messages[0] + k0 as usize) % self.rsa_decrypt.private_key.n;
let m1 = (self.messages[1] + k1 as usize) % self.rsa_decrypt.private_key.n;

[m0, m1]
}
}

impl OTReceiver {
/// create new [`OTReceiver`] object
/// ## Arguments
/// - `choice`: receiver message choice
pub fn new(choice: bool) -> Self {
let mut rng = thread_rng();
Self { choice, key: rng.gen::<u32>() as usize }
}

/// Encrypts receiver's choice out of sender's messages.
///
/// v = (x_b + k^e) mod N
pub fn encrypt(&self, rsa_encrypt: RSAEncryption, sender_messages: [usize; 2]) -> usize {
println!("key: {}", self.key % rsa_encrypt.public_key.n);
(rsa_encrypt.encrypt(self.key as u32) as usize + sender_messages[self.choice as usize])
% rsa_encrypt.public_key.n
}

/// Decrypts sender's encrypted message
///
/// m_b = (m'_b - k) mod N
/// ## Arguments:
/// - `messages`: sender's encrypted messages: m'_0, m'_1
/// - `modulus`: RSA modulus
pub fn decrypt(&self, messages: [usize; 2], modulus: usize) -> usize {
if messages[self.choice as usize] < self.key {
(messages[self.choice as usize] + modulus - (self.key % modulus)) % modulus
} else {
(messages[self.choice as usize] - self.key) % modulus
}
}
}

#[cfg(test)]
mod tests {

use super::*;

#[test]
fn ot_rsa() {
let mut rng = thread_rng();
let messages = [10, 2];
let random_primes = [19, 13];

let (ot_sender, rsa_encrypt, random_messages) = OTSender::new(messages, random_primes);

let modulus = rsa_encrypt.public_key.n;

let bit = rng.gen::<bool>();
let ot_receiver = OTReceiver::new(bit);

let v = ot_receiver.encrypt(rsa_encrypt, random_messages);

let encrypted_messages = ot_sender.encrypt(v);

let message = ot_receiver.decrypt(encrypted_messages, modulus);

assert_eq!(message, messages[bit as usize]);
}
}