Implement DynamicQuantizeLinear operator
robertknight committed Sep 6, 2024
1 parent 600f4ea commit 6628040
Showing 7 changed files with 211 additions and 7 deletions.
1 change: 1 addition & 0 deletions rten-convert/rten_convert/schema_generated.py
@@ -114,6 +114,7 @@ class OperatorType(object):
    If = 104
    DequantizeLinear = 105
    QuantizeLinear = 106
    DynamicQuantizeLinear = 107


class RNNDirection(object):
2 changes: 2 additions & 0 deletions src/model_builder.rs
@@ -48,6 +48,7 @@ pub enum OpType<'a> {
    Cos,
    DequantizeLinear(DequantizeLinear),
    Div,
    DynamicQuantizeLinear,
    Einsum(Einsum),
    Elu(Elu),
    Equal,
@@ -517,6 +518,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
                }
            ),
            OpType::Div => op!(Div),
            OpType::DynamicQuantizeLinear => op!(DynamicQuantizeLinear),
            OpType::Einsum(args) => {
                let equation = self.builder.create_string(&args.equation);
                op_with_attrs!(
2 changes: 2 additions & 0 deletions src/op_registry.rs
@@ -102,6 +102,7 @@ impl OpRegistry {
        register_op!(CumSum);
        register_op!(DequantizeLinear);
        register_op!(Div);
        register_op!(DynamicQuantizeLinear);
        register_op!(Einsum);
        register_op!(Elu);
        register_op!(Equal);
@@ -476,6 +477,7 @@ impl_read_op!(Cos);
impl_read_op!(CumSum);
impl_read_op!(DequantizeLinear, attrs_as_dequantize_linear_attrs, axis);
impl_read_op!(Div);
impl_read_op!(DynamicQuantizeLinear);
impl_read_op!(Einsum, attrs_as_einsum_attrs, |attrs: sg::EinsumAttrs| {
    Ok(ops::Einsum {
        equation: attrs.equation().unwrap_or("").to_string(),
5 changes: 4 additions & 1 deletion src/ops/mod.rs
@@ -92,7 +92,10 @@ pub use pad::{pad, Pad, PadMode};
pub use pooling::{
    average_pool, global_average_pool, max_pool, AveragePool, GlobalAveragePool, MaxPool,
};
pub use quantize::{dequantize_linear, quantize_linear, DequantizeLinear, QuantizeLinear};
pub use quantize::{
    dequantize_linear, dynamic_quantize_linear, quantize_linear, DequantizeLinear,
    DynamicQuantizeLinear, QuantizeLinear,
};

#[cfg(feature = "random")]
pub use random::{RandomNormal, RandomNormalLike, RandomUniform, RandomUniformLike};
197 changes: 194 additions & 3 deletions src/ops/quantize.rs
@@ -2,9 +2,10 @@ use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, Scalar, Tensor, TensorView};

use crate::ops::{
    resolve_axis, DataType, Input, InputList, IntoOpResult, OpError, Operator, OutputList,
    reduce_max, reduce_min, resolve_axis, DataType, Input, InputList, IntoOpResult, OpError,
    Operator, Output, OutputList,
};
use crate::tensor_pool::TensorPool;
use crate::tensor_pool::{AutoReturn, TensorPool};

/// Convert a quantized tensor element to a higher precision value.
pub trait Dequantize<To> {
@@ -247,14 +248,140 @@ impl Operator for QuantizeLinear {
    }
}

/// Cast a value to a narrower type, clamping out-of-range values to the
/// target type's representable range.
pub trait SaturatingCast<To> {
    fn saturating_cast(self) -> To;
}

impl SaturatingCast<u8> for f32 {
    fn saturating_cast(self) -> u8 {
        self.clamp(0., 255.) as u8
    }
}

/// Outputs of [`dynamic_quantize_linear`]: the quantized tensor together
/// with the scale and zero point computed from the input's range.
pub struct DynamicQuantizeOutput<T> {
    pub quantized: Tensor<T>,
    pub scale: Tensor<f32>,
    pub zero_point: Tensor<T>,
}

/// Quantize a float tensor to type `T`, computing the scale and zero point
/// from the range of values in the input. This implements the ONNX
/// `DynamicQuantizeLinear` operator.
pub fn dynamic_quantize_linear<T: Copy + Default + Scalar>(
    pool: &TensorPool,
    input: TensorView<f32>,
) -> Result<DynamicQuantizeOutput<T>, OpError>
where
    f32: Quantize<T> + SaturatingCast<T>,
{
    // From the ONNX spec, this operator is defined in terms of other ONNX
    // operators as:
    //
    // ```
    // DynamicQuantizeLinear (x) => (y, y_scale, y_zero_point)
    // {
    //    Q_Min = Constant <value: tensor = float {0}> ()
    //    Q_Max = Constant <value: tensor = float {255}> ()
    //    X_Min = ReduceMin <keepdims: int = 0> (x)
    //    X_Min_Adjusted = Min (X_Min, Q_Min)
    //    X_Max = ReduceMax <keepdims: int = 0> (x)
    //    X_Max_Adjusted = Max (X_Max, Q_Min)
    //    X_Range = Sub (X_Max_Adjusted, X_Min_Adjusted)
    //    Scale = Div (X_Range, Q_Max)
    //    Min_Scaled = Div (X_Min_Adjusted, Scale)
    //    Initial_ZeroPoint_FP = Sub (Q_Min, Min_Scaled)
    //    Clipped_ZeroPoint_FP = Clip (Initial_ZeroPoint_FP, Q_Min, Q_Max)
    //    Rounded_ZeroPoint_FP = Round (Clipped_ZeroPoint_FP)
    //    Zeropoint = Cast <to: int = 2> (Rounded_ZeroPoint_FP)
    //    y_scale = Identity (Scale)
    //    y_zero_point = Identity (Zeropoint)
    //    y = QuantizeLinear (x, Scale, Zeropoint)
    // }
    // ```
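    //
    // Illustrative trace (not from the spec): for x = [-2, 0, 6],
    // X_Min_Adjusted = -2 and X_Max_Adjusted = 6, so Scale = 8 / 255 ≈ 0.0314
    // and Zeropoint = round(0 - (-2 / Scale)) = round(63.75) = 64.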

    if input.is_empty() {
        // If the input is empty, the zero point and scale can be chosen
        // arbitrarily. We pick zero/one as natural choices.
        return Ok(DynamicQuantizeOutput {
            quantized: Tensor::zeros(input.shape()),
            zero_point: Tensor::from(T::default()),
            scale: Tensor::from(1.),
        });
    }

    let q_min = 0.;
    let q_max = 255.;

    // Get the range of the input. This implementation is simple but sub-optimal
    // as it makes two passes over the same data to get the min/max.
    let x_min = reduce_min(pool, input.view(), None, false /* keep_dims */)?
        .auto_return(pool)
        .item()
        .copied()
        .unwrap();
    let x_min_adjusted = x_min.min(q_min);
    let x_max = reduce_max(pool, input.view(), None, false /* keep_dims */)?
        .auto_return(pool)
        .item()
        .copied()
        .unwrap();
    let x_max_adjusted = x_max.max(q_min);
    let x_range = x_max_adjusted - x_min_adjusted;
    let scale = x_range / q_max;
    let min_scaled = x_min_adjusted / scale;
    let initial_zero_point = q_min - min_scaled;
    let clipped_zero_point = initial_zero_point.clamp(q_min, q_max);
    let rounded_zero_point = clipped_zero_point.round_ties_even();
    let zero_point: T = rounded_zero_point.saturating_cast();

    let scale_tensor = Tensor::from(scale);
    let zero_point_tensor = Tensor::from(zero_point);
    let quantized = quantize_linear(
        pool,
        input,
        scale_tensor.view(),
        Some(zero_point_tensor.view()),
        1, /* axis */
    )?;

    Ok(DynamicQuantizeOutput {
        quantized,
        scale: scale_tensor,
        zero_point: zero_point_tensor,
    })
}

/// Operator which computes quantization parameters from the range of its
/// input and then quantizes it, returning the quantized tensor, scale and
/// zero point.
#[derive(Debug)]
pub struct DynamicQuantizeLinear {}

impl Operator for DynamicQuantizeLinear {
    fn name(&self) -> &str {
        "DynamicQuantizeLinear"
    }

    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
        let input = inputs.require_as(0)?;

        let DynamicQuantizeOutput {
            quantized,
            scale,
            zero_point,
        } = dynamic_quantize_linear::<u8>(pool, input)?;

        let quantized: Output = quantized.into();
        let scale: Output = scale.into();
        let zero_point: Output = zero_point.into();

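        // The output order matches the ONNX signature:
        // (y, y_scale, y_zero_point).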
        Ok([quantized, scale, zero_point].into_iter().collect())
    }
}

#[cfg(test)]
mod tests {
    use std::error::Error;

    use rten_tensor::prelude::*;
    use rten_tensor::test_util::expect_equal_with_tolerance;
    use rten_tensor::Tensor;

    use super::{dequantize_linear, quantize_linear};
    use super::{dequantize_linear, dynamic_quantize_linear, quantize_linear};
    use crate::ops::tests::new_pool;
    use crate::ops::{OpError, Output};

@@ -394,4 +521,68 @@ mod tests {

        Ok(())
    }

    #[test]
    fn test_dynamic_quantize_linear() -> Result<(), Box<dyn Error>> {
        struct Case {
            input: Tensor<f32>,
            max_error: f32,
        }

        let cases = [
            // Inputs centered around zero. The zero point should fall near
            // the middle of the u8 range (~128).
            Case {
                input: [-2., -1., 0., 1., 2.].into(),
                max_error: 0.01,
            },
            // Positive inputs.
            Case {
                input: [1., 2., 3., 4., 5.].into(),
                max_error: 0.01,
            },
            // Negative inputs.
            Case {
                input: [-1., -2., -3., -4., -5.].into(),
                max_error: 0.01,
            },
            // Small input values.
            Case {
                input: Tensor::arange(-0.1, 0.1, Some(0.01)),
                max_error: 0.001,
            },
            // All values equal (positive).
            Case {
                input: Tensor::from([234.56]),
                max_error: 0.,
            },
            // All values equal (negative).
            Case {
                input: Tensor::from([-234.56]),
                max_error: 0.,
            },
            // Empty tensor.
            Case {
                input: Tensor::zeros(&[0]),
                max_error: 0.,
            },
        ];
        let pool = new_pool();

        for Case { input, max_error } in cases {
            // Quantize input.
            let output = dynamic_quantize_linear::<u8>(&pool, input.view())?;
            assert_eq!(output.quantized.shape(), input.shape());
            let zero_point = *output.zero_point.item().unwrap();
            let scale = *output.scale.item().unwrap();

            // Dequantize the results and check the value is close enough to
            // the inputs.
            let dequantized = output
                .quantized
                .map(|&q| (q as i32 - zero_point as i32) as f32 * scale);
            expect_equal_with_tolerance(&dequantized, &input, max_error, max_error)?;
        }

        Ok(())
    }
}
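For context, the scale/zero-point math quoted from the ONNX spec in `src/ops/quantize.rs` can be reproduced with plain Rust. The sketch below is illustrative only — it is not code from this commit, uses a made-up three-element input, and relies solely on the standard library (`round_ties_even` matches the rounding used above):

```rust
// Standalone illustration of the DynamicQuantizeLinear parameter math,
// following the ONNX reference definition quoted in src/ops/quantize.rs.
fn main() {
    let xs = [-2.0f32, 0.0, 6.0];
    let (q_min, q_max) = (0.0f32, 255.0f32);

    // Input range, adjusted so that it always includes zero.
    let x_min = xs.iter().copied().fold(f32::INFINITY, f32::min).min(q_min);
    let x_max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max).max(q_min);

    // The scale maps the adjusted range onto the 256 representable u8 values.
    let scale = (x_max - x_min) / q_max; // 8 / 255 ≈ 0.0314

    // The zero point is the u8 value that represents 0.0.
    let zero_point = (q_min - x_min / scale).clamp(q_min, q_max).round_ties_even();
    assert_eq!(zero_point, 64.0); // round(2 / (8 / 255)) = round(63.75) = 64

    // Round-trip one element: the error is bounded by one quantization step.
    let x = 1.0f32;
    let q = ((x / scale).round_ties_even() + zero_point).clamp(q_min, q_max) as u8;
    let x_round_trip = (q as f32 - zero_point) * scale;
    assert!((x - x_round_trip).abs() <= scale);
    println!("scale={scale} zero_point={zero_point} q={q} round_trip={x_round_trip}");
}
```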
1 change: 1 addition & 0 deletions src/schema.fbs
@@ -120,6 +120,7 @@ enum OperatorType: ubyte {
  If,
  DequantizeLinear,
  QuantizeLinear,
  DynamicQuantizeLinear,
}

enum RNNDirection: ubyte {
10 changes: 7 additions & 3 deletions src/schema_generated.rs
@@ -18,13 +18,13 @@ pub const ENUM_MIN_OPERATOR_TYPE: u8 = 0;
since = "2.0.0",
note = "Use associated constants instead. This will no longer be generated in 2021."
)]
pub const ENUM_MAX_OPERATOR_TYPE: u8 = 106;
pub const ENUM_MAX_OPERATOR_TYPE: u8 = 107;
#[deprecated(
since = "2.0.0",
note = "Use associated constants instead. This will no longer be generated in 2021."
)]
#[allow(non_camel_case_types)]
pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 107] = [
pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 108] = [
    OperatorType::Add,
    OperatorType::ArgMin,
    OperatorType::ArgMax,
@@ -132,6 +132,7 @@ pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 107] = [
    OperatorType::If,
    OperatorType::DequantizeLinear,
    OperatorType::QuantizeLinear,
    OperatorType::DynamicQuantizeLinear,
];

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
@@ -246,9 +247,10 @@ impl OperatorType {
    pub const If: Self = Self(104);
    pub const DequantizeLinear: Self = Self(105);
    pub const QuantizeLinear: Self = Self(106);
    pub const DynamicQuantizeLinear: Self = Self(107);

    pub const ENUM_MIN: u8 = 0;
    pub const ENUM_MAX: u8 = 106;
    pub const ENUM_MAX: u8 = 107;
    pub const ENUM_VALUES: &'static [Self] = &[
        Self::Add,
        Self::ArgMin,
@@ -357,6 +359,7 @@ impl OperatorType {
        Self::If,
        Self::DequantizeLinear,
        Self::QuantizeLinear,
        Self::DynamicQuantizeLinear,
    ];
    /// Returns the variant's name or "" if unknown.
    pub fn variant_name(self) -> Option<&'static str> {
@@ -468,6 +471,7 @@ impl OperatorType {
Self::If => Some("If"),
Self::DequantizeLinear => Some("DequantizeLinear"),
Self::QuantizeLinear => Some("QuantizeLinear"),
Self::DynamicQuantizeLinear => Some("DynamicQuantizeLinear"),
_ => None,
}
}
