Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

poly1305 rust and natmod #10

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ members = [
"curve25519",
"chacha20",
"poly1305",
"poly1305-rust",
"chacha20poly1305",
"gimli",
"sha256",
Expand Down
23 changes: 23 additions & 0 deletions poly1305-rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[package]
name = "poly1305"
version = "0.1.0"
authors = ["Franziskus Kiefer <[email protected]>"]
edition = "2021"
license = "MIT OR Apache-2.0"
description = "hacspec poly1305 message authentication code"
readme = "README.md"
repository = "https://github.com/hacspec/specs"

[lib]
path = "src/poly1305.rs"

[dependencies]
num-bigint = "0.4"
natmod = { path = "./natmod" }

[dev-dependencies]
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
rayon = "1.3.0"
criterion = "0.4"
rand = "0.8"
14 changes: 14 additions & 0 deletions poly1305-rust/natmod/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "natmod"
version = "0.1.0"
edition = "2021"
authors = ["Franziskus Kiefer <[email protected]>"]

[lib]
proc-macro = true

[dependencies]
hex = "0.4.3"
num-bigint = "0.4.3"
quote = "1.0.28"
syn = { version = "2.0.18", features = ["full"] }
134 changes: 134 additions & 0 deletions poly1305-rust/natmod/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
//! // This trait lives in the library
//! pub trait NatModTrait<T> {
//! const MODULUS: T;
//! }
//!
//! #[nat_mod("123456", 10)]
//! struct MyNatMod {}

use hex::FromHex;
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse::Parse, parse_macro_input, DeriveInput, Ident, LitInt, LitStr, Result, Token};

#[derive(Clone, Debug)]
struct NatModAttr {
/// Modulus as hex string and bytes
mod_str: String,
mod_bytes: Vec<u8>,
/// Number of bytes to use for the integer
int_size: usize,
}

impl Parse for NatModAttr {
fn parse(input: syn::parse::ParseStream) -> Result<Self> {
let mod_str = input.parse::<LitStr>()?.value();
let mod_bytes = Vec::<u8>::from_hex(&mod_str).expect("Invalid hex String");
input.parse::<Token![,]>()?;
let int_size = input.parse::<LitInt>()?.base10_parse::<usize>()?;
assert!(input.is_empty(), "Left over tokens in attribute {input:?}");
Ok(NatModAttr {
mod_str,
mod_bytes,
int_size,
})
}
}

#[proc_macro_attribute]
pub fn nat_mod(attr: TokenStream, item: TokenStream) -> TokenStream {
let item_ast = parse_macro_input!(item as DeriveInput);
let ident = item_ast.ident.clone();
let args = parse_macro_input!(attr as NatModAttr);

let num_bytes = args.int_size;
let modulus = args.mod_bytes;
let modulus_string = args.mod_str;

let mut padded_modulus = vec![0u8; num_bytes - modulus.len()];
padded_modulus.append(&mut modulus.clone());
let mod_iter1 = padded_modulus.iter();
let mod_iter2 = padded_modulus.iter();
let const_name = Ident::new(
&format!("{}_MODULUS", ident.to_string().to_uppercase()),
ident.span(),
);
let static_name = Ident::new(
&format!("{}_MODULUS_STR", ident.to_string().to_uppercase()),
ident.span(),
);
let mod_name = Ident::new(
&format!("{}_mod", ident.to_string().to_uppercase()),
ident.span(),
);

let out_struct = quote! {
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct #ident {
value: [u8; #num_bytes],
}

//#[not_hax]
#[allow(non_snake_case)]
mod #mod_name {
use super::*;

const #const_name: [u8; #num_bytes] = [#(#mod_iter1),*];
static #static_name: &str = #modulus_string;

impl NatMod<#num_bytes> for #ident {
const MODULUS: [u8; #num_bytes] = [#(#mod_iter2),*];
const MODULUS_STR: &'static str = #modulus_string;
const ZERO: [u8; #num_bytes] = [0u8; #num_bytes];


fn new(value: [u8; #num_bytes]) -> Self {
Self {
value
}
}
fn value(&self) -> &[u8] {
&self.value
}
}

impl core::convert::AsRef<[u8]> for #ident {
fn as_ref(&self) -> &[u8] {
&self.value
}
}

impl core::fmt::Display for #ident {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{}", self.to_hex())
}
}


impl Into<[u8; #num_bytes]> for #ident {
fn into(self) -> [u8; #num_bytes] {
self.value
}
}

impl core::ops::Add for #ident {
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
self.fadd(rhs)
}
}


impl core::ops::Mul for #ident {
type Output = Self;

fn mul(self, rhs: Self) -> Self::Output {
self.fmul(rhs)
}
}
}
};

out_struct.into()
}
172 changes: 172 additions & 0 deletions poly1305-rust/natmod/tests/poly1305.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
use natmod::nat_mod;

/// This has to come from the lib.

pub trait NatMod<const LEN: usize> {
const MODULUS: [u8; LEN];
const MODULUS_STR: &'static str;
const ZERO: [u8; LEN];

fn new(value: [u8; LEN]) -> Self;
fn value(&self) -> &[u8];

/// Add self with `rhs` and return the result `self + rhs % MODULUS`.
fn fadd(self, rhs: Self) -> Self
where
Self: Sized,
{
let lhs = num_bigint::BigUint::from_bytes_be(self.value());
let rhs = num_bigint::BigUint::from_bytes_be(rhs.value());
let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
let res = (lhs + rhs) % modulus;
let res = res.to_bytes_be();
assert!(res.len() <= LEN);
let mut value = Self::ZERO;
let offset = LEN - res.len();
for i in 0..res.len() {
value[offset + i] = res[i];
}
Self::new(value)
}

/// Multiply self with `rhs` and return the result `self * rhs % MODULUS`.
fn fmul(self, rhs: Self) -> Self
where
Self: Sized,
{
let lhs = num_bigint::BigUint::from_bytes_be(self.value());
let rhs = num_bigint::BigUint::from_bytes_be(rhs.value());
let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
let res = (lhs * rhs) % modulus;
let res = res.to_bytes_be();
assert!(res.len() <= LEN);
let mut value = Self::ZERO;
let offset = LEN - res.len();
for i in 0..res.len() {
value[offset + i] = res[i];
}
Self::new(value)
}

/// Returns 2 to the power of the argument
fn pow2(x: usize) -> Self
where
Self: Sized,
{
let res = num_bigint::BigUint::from(1u32) << x;
Self::from_bigint(res)
}

/// Create a new [`#ident`] from a `u128` literal.
fn from_u128(literal: u128) -> Self
where
Self: Sized,
{
Self::from_bigint(num_bigint::BigUint::from(literal))
}

/// Create a new [`#ident`] from a little endian byte slice.
fn from_le_bytes(bytes: &[u8]) -> Self
where
Self: Sized,
{
Self::from_bigint(num_bigint::BigUint::from_bytes_le(bytes))
}

/// Create a new [`#ident`] from a little endian byte slice.
fn from_be_bytes(bytes: &[u8]) -> Self
where
Self: Sized,
{
Self::from_bigint(num_bigint::BigUint::from_bytes_be(bytes))
}

fn to_le_bytes(self) -> [u8; LEN]
where
Self: Sized,
{
Self::pad(&num_bigint::BigUint::from_bytes_be(self.value()).to_bytes_le())
}

/// Get hex string representation of this.
fn to_hex(&self) -> String {
let strs: Vec<String> = self.value().iter().map(|b| format!("{:02x}", b)).collect();
strs.join("")
}

/// New from hex string
fn from_hex(hex: &str) -> Self
where
Self: Sized,
{
assert!(hex.len() % 2 == 0);
let l = hex.len() / 2;
assert!(l <= LEN);
let mut value = [0u8; LEN];
let skip = LEN - l;
for i in 0..l {
value[skip + i] = u8::from_str_radix(&hex[2 * i..2 * i + 2], 16)
.expect("An unexpected error occurred.");
}
Self::new(value)
}

fn pad(bytes: &[u8]) -> [u8; LEN] {
let mut value = [0u8; LEN];
let upper = value.len();
let lower = upper - bytes.len();
value[lower..upper].copy_from_slice(&bytes);
value
}

fn from_bigint(x: num_bigint::BigUint) -> Self
where
Self: Sized,
{
let max_value = Self::MODULUS;
assert!(
x <= num_bigint::BigUint::from_bytes_be(&max_value),
"{} is too large for type {}!",
x,
stringify!($ident)
);
let repr = x.to_bytes_be();
if repr.len() > LEN {
panic!("{} is too large for this type", x)
}

Self::new(Self::pad(&repr))
}
}

#[nat_mod("03fffffffffffffffffffffffffffffffb", 17)]
struct FieldElement {}

#[test]
fn add() {
let x = FieldElement::from_hex("03fffffffffffffffffffffffffffffffa");
let y = FieldElement::from_hex("01");
let z = x + y;
assert_eq!(FieldElement::ZERO.as_ref(), z.as_ref());

let x = FieldElement::from_hex("03fffffffffffffffffffffffffffffffa");
let y = FieldElement::from_hex("02");
let z = x + y;
assert_eq!(FieldElement::from_hex("01").as_ref(), z.as_ref());
}

#[test]
fn mul() {
let x = FieldElement::from_hex("03fffffffffffffffffffffffffffffffa");
let y = FieldElement::from_hex("01");
let z = x * y;
assert_eq!(x.as_ref(), z.as_ref());

let x = FieldElement::from_hex("03fffffffffffffffffffffffffffffffa");
let y = FieldElement::from_hex("02");
let z = x * y;
assert_eq!(
FieldElement::from_hex("03fffffffffffffffffffffffffffffff9").as_ref(),
z.as_ref()
);
}
Loading