Skip to content

Commit

Permalink
Populate shift in parser
Browse files Browse the repository at this point in the history
  • Loading branch information
jorickert committed Feb 22, 2025
1 parent 987183a commit 48190a3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
19 changes: 19 additions & 0 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,25 @@ mlir::ParseResult parseQuantizeDequantizeLikeOp(
result.addAttribute(zeroPointAttrName, zeroPointAttr);
}

// Try to populate shift form scale, but only if the zero point is zero
if (!result.attributes.getNamed(shiftAttrName)) {
const auto scaleAttr = result.attributes.getNamed(scaleAttrName);
if (scaleAttr) {
const auto calculatedShift =
getShiftValue(cast<mlir::FloatAttr>(scaleAttr->getValue())
.getValue()
.convertToFloat());
if (calculatedShift &&
cast<IntegerAttr>(
result.attributes.getNamed(zeroPointAttrName)->getValue())
.getValue()
.isZero()) {
result.addAttribute(shiftAttrName,
builder.getSI32IntegerAttr(*calculatedShift));
}
}
}

if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc,
result.operands))
return mlir::failure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ func.func @qdq_different_shift(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32

func.func @qdq_different_zero(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
%0 = "tosa.concat"(%arg0, %arg0) {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
%1 = xten_nn.quantize(%0 : tensor<1x2x7x7xf32>) {scale = 3.125000e-02 : f32, zero_point = 0 : i8} -> tensor<1x2x7x7xi8>
%1 = xten_nn.quantize(%0 : tensor<1x2x7x7xf32>) {shift = -5 : si32} -> tensor<1x2x7x7xi8>
%2 = xten_nn.dequantize(%1 : tensor<1x2x7x7xi8>) {scale = 3.125000e-02 : f32, zero_point = 1 : i8} -> tensor<1x2x7x7xf32>
%3 = "tosa.concat"(%2, %2) {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
return %3 : tensor<1x4x7x7xf32>
Expand All @@ -226,7 +226,7 @@ func.func @qdq_different_zero(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
// CHECK-LABEL: func.func @qdq_different_zero(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
// CHECK: %[[VAL_2:.*]] = xten_nn.quantize(%[[VAL_1]] : tensor<1x2x7x7xf32>) {scale = 3.125000e-02 : f32, zero_point = 0 : i8} -> tensor<1x2x7x7xi8>
// CHECK: %[[VAL_2:.*]] = xten_nn.quantize(%[[VAL_1]] : tensor<1x2x7x7xf32>) {shift = -5 : si32} -> tensor<1x2x7x7xi8>
// CHECK: %[[VAL_3:.*]] = xten_nn.dequantize(%[[VAL_2]] : tensor<1x2x7x7xi8>) {scale = 3.125000e-02 : f32, zero_point = 1 : i8} -> tensor<1x2x7x7xf32>
// CHECK: %[[VAL_4:.*]] = tosa.concat %[[VAL_3]], %[[VAL_3]] {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
// CHECK: return %[[VAL_4]] : tensor<1x4x7x7xf32>
Expand All @@ -235,7 +235,7 @@ func.func @qdq_different_zero(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>
// SANE-LABEL: func.func @qdq_different_zero(
// SANE-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
// SANE: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
// SANE: %[[VAL_2:.*]] = xten_nn.quantize(%[[VAL_1]] : tensor<1x2x7x7xf32>) {scale = 3.125000e-02 : f32, zero_point = 0 : i8} -> tensor<1x2x7x7xi8>
// SANE: %[[VAL_2:.*]] = xten_nn.quantize(%[[VAL_1]] : tensor<1x2x7x7xf32>) {shift = -5 : si32} -> tensor<1x2x7x7xi8>
// SANE: %[[VAL_3:.*]] = xten_nn.dequantize(%[[VAL_2]] : tensor<1x2x7x7xi8>) {scale = 3.125000e-02 : f32, zero_point = 1 : i8} -> tensor<1x2x7x7xf32>
// SANE: %[[VAL_4:.*]] = tosa.concat %[[VAL_3]], %[[VAL_3]] {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
// SANE: return %[[VAL_4]] : tensor<1x4x7x7xf32>
Expand All @@ -246,26 +246,26 @@ func.func @qdq_different_zero(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32>

func.func @qdq_different_scale(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
%0 = "tosa.concat"(%arg0, %arg0) {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
%1 = xten_nn.quantize(%0 : tensor<1x2x7x7xf32>) {scale = 3.125000e-02 : f32, zero_point = 0 : i8} -> tensor<1x2x7x7xi8>
%2 = xten_nn.dequantize(%1 : tensor<1x2x7x7xi8>) {scale = 1.250000e-01 : f32, zero_point = 0 : i8} -> tensor<1x2x7x7xf32>
%1 = xten_nn.quantize(%0 : tensor<1x2x7x7xf32>) {shift = -5 : si32} -> tensor<1x2x7x7xi8>
%2 = xten_nn.dequantize(%1 : tensor<1x2x7x7xi8>) {shift = -3 : si32} -> tensor<1x2x7x7xf32>
%3 = "tosa.concat"(%2, %2) {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
return %3 : tensor<1x4x7x7xf32>
}

// CHECK-LABEL: func.func @qdq_different_scale(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
// CHECK: %[[VAL_2:.*]] = xten_nn.quantize(%[[VAL_1]] : tensor<1x2x7x7xf32>) {scale = 3.125000e-02 : f32, zero_point = 0 : i8} -> tensor<1x2x7x7xi8>
// CHECK: %[[VAL_3:.*]] = xten_nn.dequantize(%[[VAL_2]] : tensor<1x2x7x7xi8>) {scale = 1.250000e-01 : f32, zero_point = 0 : i8} -> tensor<1x2x7x7xf32>
// CHECK: %[[VAL_2:.*]] = xten_nn.quantize(%[[VAL_1]] : tensor<1x2x7x7xf32>) {shift = -5 : si32} -> tensor<1x2x7x7xi8>
// CHECK: %[[VAL_3:.*]] = xten_nn.dequantize(%[[VAL_2]] : tensor<1x2x7x7xi8>) {shift = -3 : si32} -> tensor<1x2x7x7xf32>
// CHECK: %[[VAL_4:.*]] = tosa.concat %[[VAL_3]], %[[VAL_3]] {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
// CHECK: return %[[VAL_4]] : tensor<1x4x7x7xf32>
// CHECK: }

// SANE-LABEL: func.func @qdq_different_scale(
// SANE-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> {
// SANE: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32>
// SANE: %[[VAL_2:.*]] = xten_nn.quantize(%[[VAL_1]] : tensor<1x2x7x7xf32>) {scale = 3.125000e-02 : f32, zero_point = 0 : i8} -> tensor<1x2x7x7xi8>
// SANE: %[[VAL_3:.*]] = xten_nn.dequantize(%[[VAL_2]] : tensor<1x2x7x7xi8>) {scale = 1.250000e-01 : f32, zero_point = 0 : i8} -> tensor<1x2x7x7xf32>
// SANE: %[[VAL_2:.*]] = xten_nn.quantize(%[[VAL_1]] : tensor<1x2x7x7xf32>) {shift = -5 : si32} -> tensor<1x2x7x7xi8>
// SANE: %[[VAL_3:.*]] = xten_nn.dequantize(%[[VAL_2]] : tensor<1x2x7x7xi8>) {shift = -3 : si32} -> tensor<1x2x7x7xf32>
// SANE: %[[VAL_4:.*]] = tosa.concat %[[VAL_3]], %[[VAL_3]] {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32>
// SANE: return %[[VAL_4]] : tensor<1x4x7x7xf32>
// SANE: }
Expand Down

0 comments on commit 48190a3

Please sign in to comment.