diff --git a/Cargo.toml b/Cargo.toml index 61b8de3..a2445f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,8 @@ indicatif = { version = "0.17.8", optional = true } ndarray-npy ="0.8.1" itertools = "0.12.1" thiserror = { version = "1.0.58", optional = true } +dtw = { git = "https://github.com/Ledger-Donjon/dtw.git", rev = "0f8d7ec3bbdf2ca4ec8ea35feddb8d1db73e7d54" } +num-traits = "0.2.19" [dev-dependencies] criterion = "0.5.1" diff --git a/src/preprocessors.rs b/src/preprocessors.rs index e6fd7cc..44895eb 100644 --- a/src/preprocessors.rs +++ b/src/preprocessors.rs @@ -1,6 +1,10 @@ use itertools::Itertools; use ndarray::{s, Array1, ArrayView1}; -use std::ops::Range; +use num_traits::{AsPrimitive, One, Zero}; +use std::{ + cmp::Ordering, + ops::{Div, Range}, +}; use crate::processors::MeanVar; @@ -148,12 +152,103 @@ impl StandardScaler { } } +pub use dtw::dist; + +/// Align traces using elastic alignment [1]. Elastic alignment is a dynamic alignment algorithm +/// based on FastDTW. +/// +/// # Examples +/// ```rust +/// use muscat::preprocessors::{ElasticAlignment, dist::euclidean_distance}; +/// use ndarray::array; +/// +/// let reference_trace = array![77, 117, 5, 51, 91, -12, -33]; +/// let trace_to_align = array![77, 117, 13, 15, 5, 51, 91]; +/// let elastic_alignment = ElasticAlignment::new(reference_trace, 1, euclidean_distance); +/// +/// let aligned_trace = elastic_alignment.align(trace_to_align.view()); +/// ``` +/// +/// # References +/// [1] van Woudenberg, J.G.J., Witteman, M.F., Bakker, B. (2011). Improving Differential Power +/// Analysis by Elastic Alignment. In: Kiayias, A. (eds) Topics in Cryptology – CT-RSA 2011. CT-RSA +/// 2011. Lecture Notes in Computer Science, vol 6558. Springer, Berlin, Heidelberg. +/// https://doi.org/10.1007/978-3-642-19074-2_8 +pub struct ElasticAlignment +where + D: Fn(T, T) -> T, +{ + reference_trace: Array1, + /// FastDTW radius + radius: usize, + /// Distance function + dist: D, +} + +impl ElasticAlignment +where + T: dtw::Average + dtw::SumContainer + Copy + Zero + 'static, + T::Container: Zero + One + Div + AsPrimitive, + D: Fn(T, T) -> T, +{ + /// Creates a new [`ElasticAlignment`] with the given reference trace and distance function. + pub fn new(reference_trace: Array1, radius: usize, dist: D) -> Self { + Self { + reference_trace, + radius, + dist, + } + } + + /// Align given trace using elastic alignment (see [`ElasticAlignment`]). + pub fn align_with_cmp(&self, trace: ArrayView1, cmp: &C) -> Array1 + where + C: Fn(&T::Container, &T::Container) -> Ordering, + { + let warp_path = dtw::fast_dtw_with_cmp( + self.reference_trace.as_slice().unwrap(), + trace.as_slice().unwrap(), + self.radius, + &self.dist, + cmp, + ); + + let mut aligned_trace = Array1::zeros([trace.len()]); + let mut k = 0; + for j in 0..trace.len() { + let mut count = T::Container::zero(); + let mut sum = T::Container::zero(); + + while k < warp_path.len() && warp_path[k].0 == j { + count = count + T::Container::one(); + sum = sum + T::Container::from(trace[warp_path[k].1]); + k += 1; + } + + aligned_trace[j] = (sum / count).as_(); + } + + aligned_trace + } +} + +impl ElasticAlignment +where + T: dtw::Average + dtw::SumContainer + Copy + Zero + 'static, + T::Container: Zero + One + Div + AsPrimitive + Ord, + D: Fn(T, T) -> T, +{ + /// Align given trace using elastic alignment (see [`ElasticAlignment`]). + /// + /// NOTE: See [`ElasticAlignment::align_with_cmp`] for type than do not implement [`Ord`]. + pub fn align(&self, trace: ArrayView1) -> Array1 { + self.align_with_cmp(trace, &T::Container::cmp) + } +} + #[cfg(test)] mod tests { - use crate::preprocessors::StandardScaler; - - use super::CenteredProduct; - use super::Power; + use crate::preprocessors::{dist, CenteredProduct, ElasticAlignment, Power, StandardScaler}; use ndarray::array; fn round_to_2_digits(x: f64) -> f64 { @@ -285,4 +380,29 @@ mod tests { ); } } + + #[test] + fn test_elastic_align() { + let reference_trace = array![77, 117, 5, 51, 91, -12, -33]; + let trace = array![77, 117, 13, 15, 5, 51, 91]; + + let elastic_alignment = + ElasticAlignment::new(reference_trace.clone(), 1, dist::euclidean_distance); + assert_eq!( + elastic_alignment.align(trace.view()), + array![77, 117, 11, 51, 51, 51, 91] + ); + } + + #[test] + fn test_elastic_align_same_trace() { + let reference_trace = array![77, 117, 5, 51, 91, -12, -33]; + + let elastic_alignment = + ElasticAlignment::new(reference_trace.clone(), 1, dist::euclidean_distance); + assert_eq!( + elastic_alignment.align(reference_trace.view()), + reference_trace + ); + } }