Adjust bound_shape_inferencer to take 4 inputs for FCs (pytorch#41934)
Summary:
Pull Request resolved: pytorch#41934

Models exported from the online training workflow with int8 quantization contain FCs with 4 inputs; the extra input is the quant_param blob. This diff adjusts the bound_shape_inferencer and the int8 op schemas to infer shape info for the quant_param input as well.
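
For illustration, a minimal sketch of such a 4-input Int8FC op, mirroring the new test case below (the header path and blob names are illustrative):

```
#include "caffe2/utils/proto_utils.h" // CreateOperatorDef

using namespace caffe2;

// Build a one-op net whose Int8FC carries the optional fourth quant_param
// input; bound shape inference must now produce shape info for it as well.
NetDef buildInt8FCNet() {
  NetDef net;
  net.add_op()->CopyFrom(CreateOperatorDef(
      "Int8FC", "", {"X2", "W2", "B2", "quant_param"}, {"Out2"}, {}));
  return net;
}
```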

Test Plan:
```
buck test caffe2/caffe2/opt:bound_shape_inference_test
```

Reviewed By: yinghai

Differential Revision: D22683554

fbshipit-source-id: 684d1433212a528120aba1c37d27e26b6a31b403
csummersea authored and facebook-github-bot committed Aug 6, 2020
1 parent 9ea9d1b commit 509fb77
Showing 6 changed files with 79 additions and 14 deletions.
caffe2/operators/quantized/int8_fc_op.cc (2 changes: 1 addition & 1 deletion)
@@ -46,7 +46,7 @@ will throw errors.
.Input(
3,
"Qparam",
- "Optional Qparam blob that constans quant param computed on activation histogram data"
+ "Optional Qparam blob that contains quant param computed on activation histogram data"
"Will overwrite Y_scale and Y_zero_point argument if specified")
.Output(0, "Y", "2D output tensor");

caffe2/operators/quantized/int8_quantize_op.cc (13 changes: 10 additions & 3 deletions)
@@ -5,16 +5,23 @@ namespace caffe2 {
REGISTER_CPU_OPERATOR(Int8Quantize, int8::Int8QuantizeOp);

OPERATOR_SCHEMA(Int8Quantize)
- .IdenticalTypeAndShape()
.Arg("Y_scale", "Output tensor quantization scale")
.Arg("Y_zero_point", "Output tensor quantization offset")
- .NumInputs(1, 3)
+ .NumInputs(1, 2)
.NumOutputs(1)
+ .TensorInferenceFunction([](const OperatorDef& def,
+ const vector<TensorShape>& in) {
+ vector<TensorShape> out;
+ TensorShape X = in[0];
+ out.emplace_back(std::move(X));
+ out[0].set_data_type(TensorProto_DataType_UINT8);
+ return out;
+ })
.Input(0, "X", "FP32 Tensor X.")
.Input(
1,
"Qparam",
- "Optional Qparam blob that constans quant param computed on activation histogram data"
+ "Optional Qparam blob that contains quant param computed on activation histogram data"
"Will overwrite Y_scale and Y_zero_point argument if specified")
.Output(0, "Y", "Int8 Tensor qX representing X with linear quantization.");

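A hedged usage sketch of the new inference hook; `def` is assumed to be a populated Int8Quantize OperatorDef and `X_shape` the TensorShape of its input:

```
// With the TensorInferenceFunction registered, shape inference derives Y's
// shape from X alone and forces the output data type to uint8.
const OpSchema* schema = OpSchemaRegistry::Schema("Int8Quantize");
CAFFE_ENFORCE(schema);
std::vector<TensorShape> out = schema->InferTensor(def, {X_shape});
// out[0] carries X's dims with data_type() == TensorProto_DataType_UINT8.
```
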
caffe2/opt/bound_shape_inference_test.cc (28 changes: 28 additions & 0 deletions)
Expand Up @@ -634,6 +634,8 @@ TEST(BoundShapeInference, DISABLED_ON_WINDOWS(FC)) {
CreateOperatorDef("FC", "", {"X0", "W0", "B0"}, {"Out0"}, {}));
net.add_op()->CopyFrom(
CreateOperatorDef("FCTransposed", "", {"X1", "W1", "B1"}, {"Out1"}, {}));
+ net.add_op()->CopyFrom(CreateOperatorDef(
+ "Int8FC", "", {"X2", "W2", "B2", "quant_param"}, {"Out2"}, {}));
ShapeInfoMap shape_map;
shape_map.emplace(
"W0",
@@ -651,6 +653,18 @@
{16, 1024}));
shape_map.emplace(
"B1", makeTensorInfo({TensorBoundShape_DimType_CONSTANT}, {1024}));

+ shape_map.emplace(
+ "W2",
+ makeTensorInfo(
+ {TensorBoundShape_DimType_CONSTANT,
+ TensorBoundShape_DimType_CONSTANT},
+ {16, 1024}));
+ shape_map.emplace(
+ "B2", makeTensorInfo({TensorBoundShape_DimType_CONSTANT}, {16}));
+ shape_map.emplace(
+ "quant_param", makeTensorInfo({TensorBoundShape_DimType_CONSTANT}, {1}));

BoundShapeSpec spec(20, 1000);
BoundShapeInferencer eng(spec);
eng.InferBoundShapeAndType(net, shape_map, nullptr);
@@ -675,6 +689,20 @@
"Out1",
{TensorBoundShape_DimType_BATCH, TensorBoundShape_DimType_CONSTANT},
{spec.max_batch_size, 1024});
+ verifyShapeInfo(
+ out_shape,
+ "X2",
+ {TensorBoundShape_DimType_BATCH, TensorBoundShape_DimType_CONSTANT},
+ {spec.max_batch_size, 1024},
+ TensorProto_DataType_UINT8,
+ true);
+ verifyShapeInfo(
+ out_shape,
+ "Out2",
+ {TensorBoundShape_DimType_BATCH, TensorBoundShape_DimType_CONSTANT},
+ {spec.max_batch_size, 16},
+ TensorProto_DataType_UINT8,
+ true);
}
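
Condensed from the test above, the new path runs roughly as follows (a sketch; `net` is assumed to already hold the 4-input Int8FC op, and weight/bias seeding is omitted for brevity):

```
// Seed a constant shape for the 1-element quant_param blob, then run bound
// shape inference. X2 and Out2 are inferred as quantized uint8 tensors whose
// batch dimension is bound to spec.max_batch_size (20 here).
ShapeInfoMap shape_map;
shape_map.emplace(
    "quant_param", makeTensorInfo({TensorBoundShape_DimType_CONSTANT}, {1}));
BoundShapeSpec spec(20, 1000);
BoundShapeInferencer eng(spec);
eng.InferBoundShapeAndType(net, shape_map, nullptr);
```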

TEST(BoundShapeInference, FC3D) {
caffe2/opt/bound_shape_inferencer.cc (36 changes: 31 additions & 5 deletions)
@@ -605,7 +605,9 @@ void BoundShapeInferencer::InferConcat(const OperatorDef& op) {
}

void BoundShapeInferencer::InferFC(const OperatorDef& op) {
- CAFFE_ENFORCE_EQ(op.input_size(), 3, "FC has to have 3 inputs");
+ CAFFE_ENFORCE(
+ op.input_size() == 3 || op.input_size() == 4,
+ "FC has to have 3 or 4 inputs");
const auto w_it = shape_info_.find(op.input(1));
CAFFE_ENFORCE(
w_it != shape_info_.end(),
@@ -670,6 +672,16 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) {
// Standard shape inference for outputs
std::vector<TensorShape> input_shapes{
shape_info_[op.input(0)].shape, w_shape_info.shape, b_shape_info.shape};
+ if (op.input_size() == 4) {
+ const auto quant_param_it = shape_info_.find(op.input(3));
+ CAFFE_ENFORCE(
+ quant_param_it != shape_info_.end(),
+ "Shape of quant_param input of FC ",
+ op.input(3),
+ " needs to be presented");
+ const ShapeInfo& quant_param_shape_info = quant_param_it->second;
+ input_shapes.emplace_back(quant_param_shape_info.shape);
+ }
std::vector<TensorShape> output_shapes = InferOutput(op, input_shapes);
CAFFE_ENFORCE_EQ(output_shapes.size(), 1);
TensorProto::DataType output_data_type;
@@ -795,29 +807,43 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
// First, we need to check that all the input shape/types are already
// presented
try {
+ const static std::unordered_set<std::string>
+ types_with_independent_output_shape = {"Int8GenQuantParams",
+ "Int8QuantSchemeBlobFill"};
std::vector<TensorShape> input_shapes;
for (const auto& input : op.input()) {
const auto it = shape_info_.find(input);
- if (it == shape_info_.end()) {
+ if (it == shape_info_.end() &&
+ !types_with_independent_output_shape.count(op.type())) {
LOG(WARNING) << "Cannot find shape info for " << input << ". Skipping "
<< op.type();
return;
}
- input_shapes.emplace_back(it->second.shape);
+ if (types_with_independent_output_shape.count(op.type())) {
+ TensorShape input_shape;
+ input_shapes.emplace_back(std::move(input_shape));
+
+ } else {
+ input_shapes.emplace_back(it->second.shape);
+ }
}

const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
CAFFE_ENFORCE(schema);
std::vector<TensorShape> output_shapes;
output_shapes = schema->InferTensor(op, input_shapes);
- bool is_quantized =
- !(op.type().compare(0, 4, "Int8")) && (op.type() != "Int8Dequantize");
+ bool is_quantized = !(op.type().compare(0, 4, "Int8")) &&
+ (op.type() != "Int8Dequantize") &&
+ (op.type() != "Int8QuantSchemeBlobFill") &&
+ (op.type() != "Int8GenQuantParams");
float scale = 1;
int offset = 0;
TensorProto::DataType infered_data_type = TensorProto::UNDEFINED;
if (is_quantized) {
const static std::map<std::string, int> type_info_from_input = {
{"Int8Quantize", -1}, // Force this op's output to be uint8
{"Int8FCPackWeight", 0},
{"Int8ConvPackWeight", 0},
{"Int8ConvRelu", 1},
{"Int8MaxPool", 0},
{"Int8AveragePool", 0},
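
A condensed reading of the input-gathering change above (an assumed simplification, not the verbatim source; `op`, `input_shapes`, and `shape_info_` are as in InferCommonOp):

```
// Ops whose output shape is independent of their inputs are no longer
// skipped when an input's shape is missing; each of their inputs gets an
// empty placeholder TensorShape, which their inference functions never read.
const static std::unordered_set<std::string> independent = {
    "Int8GenQuantParams", "Int8QuantSchemeBlobFill"};
for (const auto& input : op.input()) {
  if (independent.count(op.type())) {
    input_shapes.emplace_back(TensorShape()); // placeholder, never inspected
  } else {
    input_shapes.emplace_back(shape_info_.at(input).shape);
  }
}
```
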
caffe2/quantization/server/int8_gen_quant_params.cc (7 changes: 2 additions & 5 deletions)
@@ -17,12 +17,9 @@ OPERATOR_SCHEMA(Int8GenQuantParams)
.NumOutputs(1)
.TensorInferenceFunction([](const OperatorDef& /* def */,
const vector<TensorShape>& in) {
- vector<TensorShape> out;
- TensorShape X = in[0];
- X.clear_dims();
- X.add_dims(1);
- out.emplace_back(std::move(X));
+ vector<TensorShape> out(1);
out[0].set_data_type(TensorProto_DataType_FLOAT);
+ out[0].add_dims(1);
return out;
})
.Input(
caffe2/quantization/server/int8_quant_scheme_blob_fill.cc (7 changes: 7 additions & 0 deletions)
@@ -12,6 +12,13 @@ REGISTER_CPU_OPERATOR(
OPERATOR_SCHEMA(Int8QuantSchemeBlobFill)
.NumInputs(0)
.NumOutputs(1)
+ .TensorInferenceFunction([](const OperatorDef& /* def */,
+ const vector<TensorShape>& in) {
+ vector<TensorShape> out(1);
+ out[0].set_data_type(TensorProto_DataType_STRING);
+ out[0].add_dims(1);
+ return out;
+ })
.Arg(
"quantization_kind",
"The kind of quant scheme that would be used to generate quant param")
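
Taken together, the two inference functions above give these ops fixed, input-independent outputs; a sketch of what downstream shape inference now observes (`def` is an assumed OperatorDef):

```
// Int8GenQuantParams      -> dims {1}, data type FLOAT  (quant param blob)
// Int8QuantSchemeBlobFill -> dims {1}, data type STRING (quant scheme blob)
const OpSchema* s = OpSchemaRegistry::Schema("Int8QuantSchemeBlobFill");
CAFFE_ENFORCE(s);
std::vector<TensorShape> out = s->InferTensor(def, {});
// out[0].data_type() == TensorProto_DataType_STRING && out[0].dims(0) == 1
```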
