Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ConstantTimeEq proc macro derive #302

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ members = [
"capable/sys",
"capable/sys/types",
"capable/types",
"constant_time_derive",
varsha888 marked this conversation as resolved.
Show resolved Hide resolved
"core/build",
"core/sys/types",
"core/types",
Expand Down
14 changes: 14 additions & 0 deletions constant_time_derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "constant_time_derive"
version = "0.1.0"
edition = "2021"
license = "Apache-2.0"
varsha888 marked this conversation as resolved.
Show resolved Hide resolved

[lib]
proc-macro = true

[dependencies]
proc-macro2 = "1.0.8"
quote = "1.0"
subtle = { version = "2.4.0", default-features = false }
syn = "1.0"
115 changes: 115 additions & 0 deletions constant_time_derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Copyright (c) 2023 The MobileCoin Foundation

use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Data, DataEnum, DeriveInput, Fields, GenericParam, Generics};

#[proc_macro_derive(ConstantTimeEq)]
pub fn constant_time_eq(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
derive_ct_eq(&input)
}

// TODO: Check or remove padding and align decorators on the struct
fn parse_fields(fields: &Fields) -> Result<proc_macro2::TokenStream, &'static str> {
match &fields {
Fields::Named(fields_named) => {
let mut token_stream = quote!();
let mut iter = fields_named.named.iter().peekable();

while let Some(field) = iter.next() {
let ident = &field.ident;
match iter.peek() {
None => token_stream.extend(quote! { {self.#ident}.ct_eq(&{other.#ident}) }),
Some(_) => {
token_stream.extend(quote! { {self.#ident}.ct_eq(&{other.#ident}) & })
}
}
}
Ok(token_stream)
varsha888 marked this conversation as resolved.
Show resolved Hide resolved
}
Fields::Unnamed(unnamed_fields) => {
let mut token_stream = quote!();
let mut iter = unnamed_fields.unnamed.iter().peekable();
let mut idx = 0;
while let Some(_) = iter.next() {
let i = syn::Index::from(idx);
match iter.peek() {
None => token_stream.extend(quote! { {self.#i}.ct_eq(&{other.#i}) }),
Some(_) => {
token_stream.extend(quote! { {self.#i}.ct_eq(&{other.#i}) & });
idx += 1;
}
}
}

Ok(token_stream)
}
Fields::Unit => Err("Constant time cannot be derived for unit fields"),
}
}

fn parse_enum(data_enum: &DataEnum) -> Result<proc_macro2::TokenStream, &'static str> {
for variant in data_enum.variants.iter() {
if let Fields::Unnamed(_) = variant.fields {
panic!("Cannot derive ct_eq for fields in enums")
}
}
let token_stream = quote! {
::subtle::Choice::from((self == other) as u8)
};

Ok(token_stream)
}

fn parse_data(data: &Data) -> Result<proc_macro2::TokenStream, &'static str> {
match data {
Data::Struct(variant_data) => parse_fields(&variant_data.fields),
Data::Enum(data_enum) => parse_enum(data_enum),
Data::Union(..) => Err("Constant time cannot be derived for a union"),
}
}

fn parse_lifetime(generics: &Generics) -> u32 {
let mut count = 0;
for i in generics.params.iter() {
if let GenericParam::Lifetime(_) = i {
count += 1;
}
}
count
}

fn derive_ct_eq(input: &DeriveInput) -> TokenStream {
let ident = &input.ident;
let data = &input.data;
let generics = &input.generics;
let is_lifetime = parse_lifetime(generics);
let ct_eq_stream: proc_macro2::TokenStream =
parse_data(data).expect("Failed to parse DeriveInput data");
let data_ident = if is_lifetime != 0 {
let mut s = format!("{}<'_", ident);

for _ in 1..is_lifetime {
s.push_str(", '_");
}
s.push('>');

s
} else {
ident.to_string()
};
let ident_stream: proc_macro2::TokenStream =
data_ident.parse().expect("Should be valid lifetime tokens");

let expanded: proc_macro2::TokenStream = quote! {
impl ::subtle::ConstantTimeEq for #ident_stream {
fn ct_eq(&self, other: &Self) -> ::subtle::Choice {
use ::subtle::ConstantTimeEq;
return #ct_eq_stream
}
}
};

expanded.into()
}
Loading