Skip to content

Commit

Permalink
[quant][graphmode][refactor] Move the whitelists to a centralized pl…
Browse files Browse the repository at this point in the history
…ace (pytorch#35721)

Summary: Pull Request resolved: pytorch#35721

Test Plan:
.

Imported from OSS

Differential Revision: D20771829

fbshipit-source-id: f6ec3afe2d8034acbdbd81e5a6fbd4a2a76aa7ac
  • Loading branch information
jerryzh168 authored and facebook-github-bot committed Apr 2, 2020
1 parent e372f42 commit b3c0939
Showing 1 changed file with 47 additions and 41 deletions.
88 changes: 47 additions & 41 deletions torch/csrc/jit/passes/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,22 @@ struct PatternsAndModules {
Module packed_params_module;
};

// Whitelist of prim::CallFunction callee names that are treated as
// quantizable (i.e. they require observation); consulted by
// nodeQuantizable() via its call_funcs argument.
std::vector<std::string> _quantizable_call_funcs = {
"conv2d",
"linear",
};

// Whitelist of aten operator names that are treated as quantizable
// (i.e. they require observation); consulted by nodeQuantizable()
// via its aten_funcs argument.
std::vector<std::string> _quantizable_aten_funcs = {
"conv2d",
"conv3d",
"linear",
"addmm",
"matmul",
"add_",
"add",
"cat",
};

// These are the prim::CallFunctions that don't require observation and
// have a single input Tensor
// example: `prim::CallFunction(%dropout, %input_tensor, ...)`
Expand All @@ -81,9 +97,35 @@ std::vector<std::string> _single_input_general_call_funcs = {
"relu",
};

std::vector<std::string> _quantizable_call_funcs = {
"conv2d",
"linear",
// Similar to the prim::CallFunctions above, these are aten ops that don't
// require observation and have a single input Tensor; quantization
// parameters pass through them from input to output
// e.g. `aten::max_pool2d(%input_tensor, ...)`
std::vector<std::string> _single_input_general_aten_funcs = {
"max_pool2d",
"avg_pool2d",
"flatten",
"max",
"min",
"mean",
"upsample_nearest1d",
"upsample_nearest2d",
"upsample_nearest3d",
"adaptive_avg_pool1d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"upsample_linear1d",
"upsample_bilinear2d",
"upsample_trilinear3d",
"upsample_bicubic2d",
"dropout",
"reshape",
"chunk",
"view",
"transpose",
"contiguous",
"permute",
"repeat_interleave",
"relu",
};

void fillQConfigMap(
Expand Down Expand Up @@ -167,33 +209,6 @@ bool isAddScalar(Node* n) {
// the quantization parameters for `v` given the list of values
std::vector<Value*> getPassThroughInputs(Value* v) {
Node* n = v->node();
std::vector<std::string> single_input_aten_funcs = {
"max_pool2d",
"avg_pool2d",
"flatten",
"max",
"min",
"mean",
"upsample_nearest1d",
"upsample_nearest2d",
"upsample_nearest3d",
"adaptive_avg_pool1d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"upsample_linear1d",
"upsample_bilinear2d",
"upsample_trilinear3d",
"upsample_bicubic2d",
"dropout",
"reshape",
"chunk",
"view",
"transpose",
"contiguous",
"permute",
"repeat_interleave",
"relu",
};
if (isFunctionNode(
n,
// We don't have call functions
Expand All @@ -207,7 +222,7 @@ std::vector<Value*> getPassThroughInputs(Value* v) {
// We don't have call functions
// after inline
/* call_funcs = */ {},
/* aten_funcs = */ single_input_aten_funcs) ||
/* aten_funcs = */ _single_input_general_aten_funcs) ||
(n->kind() == Symbol::aten("sort") && v->offset() == 0)) {
return {n->input(0)};
} else if (n->kind() == prim::If && n->outputs().size() == 1) {
Expand Down Expand Up @@ -242,16 +257,7 @@ bool nodeQuantizable(Node* n) {
/* call_funcs = */
_quantizable_call_funcs,
/* aten_funcs = */
{
"conv2d",
"conv3d",
"linear",
"addmm",
"matmul",
"add_",
"add",
"cat",
});
_quantizable_aten_funcs);
}

// We don't want to analyze the graph for some `builtin` CallFunctions
Expand Down

0 comments on commit b3c0939

Please sign in to comment.