From 6628040e4bd5ee417c24281b839a1ac3b80314c7 Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Sat, 31 Aug 2024 14:17:32 +0100
Subject: [PATCH] Implement DynamicQuantizeLinear operator

---
 rten-convert/rten_convert/schema_generated.py |   1 +
 src/model_builder.rs                          |   2 +
 src/op_registry.rs                            |   2 +
 src/ops/mod.rs                                |   5 +-
 src/ops/quantize.rs                           | 197 +++++++++++++++++-
 src/schema.fbs                                |   1 +
 src/schema_generated.rs                       |  10 +-
 7 files changed, 211 insertions(+), 7 deletions(-)

diff --git a/rten-convert/rten_convert/schema_generated.py b/rten-convert/rten_convert/schema_generated.py
index 5e8d38bb..64674bc3 100644
--- a/rten-convert/rten_convert/schema_generated.py
+++ b/rten-convert/rten_convert/schema_generated.py
@@ -114,6 +114,7 @@ class OperatorType(object):
     If = 104
     DequantizeLinear = 105
     QuantizeLinear = 106
+    DynamicQuantizeLinear = 107


 class RNNDirection(object):
diff --git a/src/model_builder.rs b/src/model_builder.rs
index cbdada27..0953a9ce 100644
--- a/src/model_builder.rs
+++ b/src/model_builder.rs
@@ -48,6 +48,7 @@ pub enum OpType<'a> {
     Cos,
     DequantizeLinear(DequantizeLinear),
     Div,
+    DynamicQuantizeLinear,
     Einsum(Einsum),
     Elu(Elu),
     Equal,
@@ -517,6 +518,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
                 }
             ),
             OpType::Div => op!(Div),
+            OpType::DynamicQuantizeLinear => op!(DynamicQuantizeLinear),
             OpType::Einsum(args) => {
                 let equation = self.builder.create_string(&args.equation);
                 op_with_attrs!(
diff --git a/src/op_registry.rs b/src/op_registry.rs
index 4469e7bb..e5682617 100644
--- a/src/op_registry.rs
+++ b/src/op_registry.rs
@@ -102,6 +102,7 @@ impl OpRegistry {
         register_op!(CumSum);
         register_op!(DequantizeLinear);
         register_op!(Div);
+        register_op!(DynamicQuantizeLinear);
         register_op!(Einsum);
         register_op!(Elu);
         register_op!(Equal);
@@ -476,6 +477,7 @@ impl_read_op!(Cos);
 impl_read_op!(CumSum);
 impl_read_op!(DequantizeLinear, attrs_as_dequantize_linear_attrs, axis);
 impl_read_op!(Div);
+impl_read_op!(DynamicQuantizeLinear);
 impl_read_op!(Einsum, attrs_as_einsum_attrs, |attrs: sg::EinsumAttrs| {
     Ok(ops::Einsum {
         equation: attrs.equation().unwrap_or("").to_string(),
diff --git a/src/ops/mod.rs b/src/ops/mod.rs
index 0096cbec..087e5c49 100644
--- a/src/ops/mod.rs
+++ b/src/ops/mod.rs
@@ -92,7 +92,10 @@ pub use pad::{pad, Pad, PadMode};
 pub use pooling::{
     average_pool, global_average_pool, max_pool, AveragePool, GlobalAveragePool, MaxPool,
 };
-pub use quantize::{dequantize_linear, quantize_linear, DequantizeLinear, QuantizeLinear};
+pub use quantize::{
+    dequantize_linear, dynamic_quantize_linear, quantize_linear, DequantizeLinear,
+    DynamicQuantizeLinear, QuantizeLinear,
+};
 #[cfg(feature = "random")]
 pub use random::{RandomNormal, RandomNormalLike, RandomUniform, RandomUniformLike};

diff --git a/src/ops/quantize.rs b/src/ops/quantize.rs
index 31e7a4fc..4453ef35 100644
--- a/src/ops/quantize.rs
+++ b/src/ops/quantize.rs
@@ -2,9 +2,10 @@ use rten_tensor::prelude::*;
 use rten_tensor::{NdTensor, NdTensorView, Scalar, Tensor, TensorView};

 use crate::ops::{
-    resolve_axis, DataType, Input, InputList, IntoOpResult, OpError, Operator, OutputList,
+    reduce_max, reduce_min, resolve_axis, DataType, Input, InputList, IntoOpResult, OpError,
+    Operator, Output, OutputList,
 };
-use crate::tensor_pool::TensorPool;
+use crate::tensor_pool::{AutoReturn, TensorPool};

 /// Convert a quantized tensor element to a higher precision value.
 pub trait Dequantize<To> {
@@ -247,14 +248,140 @@ impl Operator for QuantizeLinear {
     }
 }

+pub trait SaturatingCast<To> {
+    fn saturating_cast(self) -> To;
+}
+
+impl SaturatingCast<u8> for f32 {
+    fn saturating_cast(self) -> u8 {
+        self.clamp(0., 255.) as u8
+    }
+}
+
+pub struct DynamicQuantizeOutput<T> {
+    pub quantized: Tensor<T>,
+    pub scale: Tensor,
+    pub zero_point: Tensor<T>,
+}
+
+pub fn dynamic_quantize_linear<T: Copy + Default>(
+    pool: &TensorPool,
+    input: TensorView,
+) -> Result<DynamicQuantizeOutput<T>, OpError>
+where
+    f32: Quantize<T> + SaturatingCast<T>,
+{
+    // From the ONNX spec, this operator is defined in terms of other ONNX
+    // operators as:
+    //
+    // ```
+    // DynamicQuantizeLinear (x) => (y, y_scale, y_zero_point)
+    // {
+    //    Q_Min = Constant ()
+    //    Q_Max = Constant ()
+    //    X_Min = ReduceMin (x)
+    //    X_Min_Adjusted = Min (X_Min, Q_Min)
+    //    X_Max = ReduceMax (x)
+    //    X_Max_Adjusted = Max (X_Max, Q_Min)
+    //    X_Range = Sub (X_Max_Adjusted, X_Min_Adjusted)
+    //    Scale = Div (X_Range, Q_Max)
+    //    Min_Scaled = Div (X_Min_Adjusted, Scale)
+    //    Initial_ZeroPoint_FP = Sub (Q_Min, Min_Scaled)
+    //    Clipped_ZeroPoint_FP = Clip (Initial_ZeroPoint_FP, Q_Min, Q_Max)
+    //    Rounded_ZeroPoint_FP = Round (Clipped_ZeroPoint_FP)
+    //    Zeropoint = Cast (Rounded_ZeroPoint_FP)
+    //    y_scale = Identity (Scale)
+    //    y_zero_point = Identity (Zeropoint)
+    //    y = QuantizeLinear (x, Scale, Zeropoint)
+    // }
+    // ```
+
+    if input.is_empty() {
+        // If the input is empty, the zero point and scale can be chosen
+        // arbitrarily. We pick zero/one as natural choices.
+        return Ok(DynamicQuantizeOutput {
+            quantized: Tensor::zeros(input.shape()),
+            zero_point: Tensor::from(T::default()),
+            scale: Tensor::from(1.),
+        });
+    }
+
+    let q_min = 0.;
+    let q_max = 255.;
+
+    // Get the range of the input. This implementation is simple but sub-optimal
+    // as it makes two passes over the same data to get the min/max.
+    let x_min = reduce_min(pool, input.view(), None, false /* keep_dims */)?
+        .auto_return(pool)
+        .item()
+        .copied()
+        .unwrap();
+    let x_min_adjusted = x_min.min(q_min);
+    let x_max = reduce_max(pool, input.view(), None, false /* keep_dims */)?
+        .auto_return(pool)
+        .item()
+        .copied()
+        .unwrap();
+    let x_max_adjusted = x_max.max(q_min);
+    let x_range = x_max_adjusted - x_min_adjusted;
+    let scale = x_range / q_max;
+    let min_scaled = x_min_adjusted / scale;
+    let initial_zero_point = q_min - min_scaled;
+    let clipped_zero_point = initial_zero_point.clamp(q_min, q_max);
+    let rounded_zero_point = clipped_zero_point.round_ties_even();
+    let zero_point: T = rounded_zero_point.saturating_cast();
+
+    let scale_tensor = Tensor::from(scale);
+    let zero_point_tensor = Tensor::from(zero_point);
+    let quantized = quantize_linear(
+        pool,
+        input,
+        scale_tensor.view(),
+        Some(zero_point_tensor.view()),
+        1, /* axis */
+    )?;
+
+    Ok(DynamicQuantizeOutput {
+        quantized,
+        scale: scale_tensor,
+        zero_point: zero_point_tensor,
+    })
+}
+
+#[derive(Debug)]
+pub struct DynamicQuantizeLinear {}
+
+impl Operator for DynamicQuantizeLinear {
+    fn name(&self) -> &str {
+        "DynamicQuantizeLinear"
+    }
+
+    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
+        let input = inputs.require_as(0)?;
+
+        let DynamicQuantizeOutput {
+            quantized,
+            scale,
+            zero_point,
+        } = dynamic_quantize_linear::<u8>(pool, input)?;
+
+        let quantized: Output = quantized.into();
+        let scale: Output = scale.into();
+        let zero_point: Output = zero_point.into();
+
+        Ok([quantized, scale, zero_point].into_iter().collect())
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::error::Error;

     use rten_tensor::prelude::*;
+    use rten_tensor::test_util::expect_equal_with_tolerance;
     use rten_tensor::Tensor;

-    use super::{dequantize_linear, quantize_linear};
+    use super::{dequantize_linear, dynamic_quantize_linear, quantize_linear};
     use crate::ops::tests::new_pool;
     use crate::ops::{OpError, Output};
@@ -394,4 +521,68 @@ mod tests {

         Ok(())
     }
+
+    #[test]
+    fn test_dynamic_quantize_linear() -> Result<(), Box<dyn Error>> {
+        struct Case {
+            input: Tensor,
+            max_error: f32,
+        }
+
+        let cases = [
+            // Inputs centered around zero. Zero point should be near the
+            // middle of the u8 range.
+            Case {
+                input: [-2., -1., 0., 1., 2.].into(),
+                max_error: 0.01,
+            },
+            // Positive inputs.
+            Case {
+                input: [1., 2., 3., 4., 5.].into(),
+                max_error: 0.01,
+            },
+            // Negative inputs.
+            Case {
+                input: [-1., -2., -3., -4., -5.].into(),
+                max_error: 0.01,
+            },
+            // Small input values
+            Case {
+                input: Tensor::arange(-0.1, 0.1, Some(0.01)),
+                max_error: 0.001,
+            },
+            // All values equal (positive)
+            Case {
+                input: Tensor::from([234.56]),
+                max_error: 0.,
+            },
+            // All values equal (negative)
+            Case {
+                input: Tensor::from([-234.56]),
+                max_error: 0.,
+            },
+            // Empty tensor
+            Case {
+                input: Tensor::zeros(&[0]),
+                max_error: 0.,
+            },
+        ];
+        let pool = new_pool();
+
+        for Case { input, max_error } in cases {
+            // Quantize input.
+            let output = dynamic_quantize_linear::<u8>(&pool, input.view())?;
+            assert_eq!(output.quantized.shape(), input.shape());
+            let zero_point = *output.zero_point.item().unwrap();
+            let scale = *output.scale.item().unwrap();
+
+            // Dequantize the results and check the value is close enough to
+            // the inputs.
+            let dequantized = output
+                .quantized
+                .map(|&q| (q as i32 - zero_point as i32) as f32 * scale);
+            expect_equal_with_tolerance(&dequantized, &input, max_error, max_error)?;
+        }
+
+        Ok(())
+    }
 }
diff --git a/src/schema.fbs b/src/schema.fbs
index cddfd6e1..c09b09aa 100644
--- a/src/schema.fbs
+++ b/src/schema.fbs
@@ -120,6 +120,7 @@ enum OperatorType: ubyte {
   If,
   DequantizeLinear,
   QuantizeLinear,
+  DynamicQuantizeLinear,
 }

 enum RNNDirection: ubyte {
diff --git a/src/schema_generated.rs b/src/schema_generated.rs
index a3d974b5..e7eb1667 100644
--- a/src/schema_generated.rs
+++ b/src/schema_generated.rs
@@ -18,13 +18,13 @@ pub const ENUM_MIN_OPERATOR_TYPE: u8 = 0;
     since = "2.0.0",
     note = "Use associated constants instead. This will no longer be generated in 2021."
 )]
-pub const ENUM_MAX_OPERATOR_TYPE: u8 = 106;
+pub const ENUM_MAX_OPERATOR_TYPE: u8 = 107;
 #[deprecated(
     since = "2.0.0",
     note = "Use associated constants instead. This will no longer be generated in 2021."
 )]
 #[allow(non_camel_case_types)]
-pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 107] = [
+pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 108] = [
   OperatorType::Add,
   OperatorType::ArgMin,
   OperatorType::ArgMax,
@@ -132,6 +132,7 @@ pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 107] = [
   OperatorType::If,
   OperatorType::DequantizeLinear,
   OperatorType::QuantizeLinear,
+  OperatorType::DynamicQuantizeLinear,
 ];

 #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
@@ -246,9 +247,10 @@ impl OperatorType {
   pub const If: Self = Self(104);
   pub const DequantizeLinear: Self = Self(105);
   pub const QuantizeLinear: Self = Self(106);
+  pub const DynamicQuantizeLinear: Self = Self(107);

   pub const ENUM_MIN: u8 = 0;
-  pub const ENUM_MAX: u8 = 106;
+  pub const ENUM_MAX: u8 = 107;
   pub const ENUM_VALUES: &'static [Self] = &[
     Self::Add,
     Self::ArgMin,
@@ -357,6 +359,7 @@ impl OperatorType {
     Self::If,
     Self::DequantizeLinear,
     Self::QuantizeLinear,
+    Self::DynamicQuantizeLinear,
   ];
   /// Returns the variant's name or "" if unknown.
   pub fn variant_name(self) -> Option<&'static str> {
@@ -468,6 +471,7 @@ impl OperatorType {
       Self::If => Some("If"),
       Self::DequantizeLinear => Some("DequantizeLinear"),
       Self::QuantizeLinear => Some("QuantizeLinear"),
+      Self::DynamicQuantizeLinear => Some("DynamicQuantizeLinear"),
       _ => None,
     }
   }
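
The scale / zero-point arithmetic in dynamic_quantize_linear can be sanity-checked by hand against the "Positive inputs" test case above. The sketch below is a standalone plain-Rust trace of the same formula, independent of the rten types in the patch; the expected values (scale = 5/255, zero point 0, quantized [51, 102, 153, 204, 255]) are hand-derived from the ONNX decomposition quoted in the operator comment and assume the usual round-then-saturate quantization convention, so treat them as illustrative rather than as output captured from this change.

// Hand-traced dynamic quantization of [1, 2, 3, 4, 5] to u8 (assumes a
// toolchain with f32::round_ties_even, which the patch itself relies on).
fn main() {
    let (q_min, q_max) = (0.0_f32, 255.0_f32);
    let (x_min, x_max) = (1.0_f32, 5.0_f32);

    // The real-valued range is widened to include zero, as in the operator.
    let x_min_adjusted = x_min.min(q_min); // 0.0
    let x_max_adjusted = x_max.max(q_min); // 5.0

    let scale = (x_max_adjusted - x_min_adjusted) / q_max; // 5 / 255 ~= 0.0196
    let zero_point = (q_min - x_min_adjusted / scale) // 0 - 0 = 0
        .clamp(q_min, q_max)
        .round_ties_even();
    assert_eq!(scale, 5.0 / 255.0);
    assert_eq!(zero_point, 0.0);

    // Quantize each element: round(x / scale) + zero_point, saturated to 0..=255.
    let quantized: Vec<u8> = [1.0_f32, 2., 3., 4., 5.]
        .iter()
        .map(|x| (x / scale + zero_point).round_ties_even().clamp(0., 255.) as u8)
        .collect();
    assert_eq!(quantized, vec![51, 102, 153, 204, 255]);

    // Dequantizing (q - zero_point) * scale recovers the inputs to within
    // roughly half a quantization step, which is what the test tolerances check.
    println!("scale = {scale}, zero_point = {zero_point}, quantized = {quantized:?}");
}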