diff --git a/model_compression_toolkit/core/keras/default_framework_info.py b/model_compression_toolkit/core/keras/default_framework_info.py
index df4014405..d26efed71 100644
--- a/model_compression_toolkit/core/keras/default_framework_info.py
+++ b/model_compression_toolkit/core/keras/default_framework_info.py
@@ -29,7 +29,7 @@
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
 from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
-    KERNEL, DEPTHWISE_KERNEL
+    KERNEL, DEPTHWISE_KERNEL, GELU
 from model_compression_toolkit.core.keras.quantizer.fake_quant_builder import power_of_two_quantization, symmetric_quantization, uniform_quantization

 """
@@ -75,7 +75,8 @@
                      TANH: (-1, 1),
                      SWISH: (-0.279, None),
                      RELU: (0, None),
-                     SELU: (None, None),
+                     SELU: (-1.76, None),
+                     GELU: (-0.17, None),
                      }

 """
diff --git a/model_compression_toolkit/core/pytorch/default_framework_info.py b/model_compression_toolkit/core/pytorch/default_framework_info.py
index 1c5961464..f3d965182 100644
--- a/model_compression_toolkit/core/pytorch/default_framework_info.py
+++ b/model_compression_toolkit/core/pytorch/default_framework_info.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid
-from torch.nn.functional import hardsigmoid, relu, relu6, softmax
+from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid, GELU, SELU
+from torch.nn.functional import hardsigmoid, relu, relu6, softmax, gelu, selu
 from torch.nn import Conv2d, ConvTranspose2d, Linear
 from torch import sigmoid

@@ -74,7 +74,12 @@
                ReLU: (0, None),
                relu: (0, None),
                ReLU6: (0, None),
-               relu6: (0, None)}
+               relu6: (0, None),
+               GELU: (-0.17, None),
+               gelu: (-0.17, None),
+               SELU: (-1.76, None),
+               selu: (-1.76, None),
+               }

 """
 Mapping from a QuantizationMethod to an activation quantizer function.
diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py
index 5fbeb5b8d..6b7b6279c 100644
--- a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py
+++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py
@@ -17,9 +17,9 @@

 import numpy as np
 import torch.nn.functional
-from torch.nn import Conv2d, Linear, PReLU, ELU, Hardswish, Dropout, ZeroPad2d, SiLU
+from torch.nn import Conv2d, Linear, PReLU, ELU, Hardswish, Dropout, ZeroPad2d, SiLU, GELU
 from torch import reshape
-from torch.nn.functional import hardswish, silu, prelu, elu
+from torch.nn.functional import hardswish, silu, prelu, elu, gelu
 from torch.nn.functional import avg_pool2d

 from model_compression_toolkit.core import CoreConfig, FrameworkInfo
@@ -68,7 +68,9 @@ def shift_negative_activation_node_matchers():
                NodeOperationMatcher(Hardswish) | \
                NodeOperationMatcher(hardswish) | \
                NodeOperationMatcher(SiLU) | \
-               NodeOperationMatcher(silu)
+               NodeOperationMatcher(silu) | \
+               NodeOperationMatcher(GELU) | \
+               NodeOperationMatcher(gelu)

     # Match linear layers where we can add a correction.
     linear_node = NodeOperationMatcher(Conv2d) | \
diff --git a/tests/pytorch_tests/model_tests/feature_models/shift_negative_activation_test.py b/tests/pytorch_tests/model_tests/feature_models/shift_negative_activation_test.py
index 1f89fe638..7145b248f 100644
--- a/tests/pytorch_tests/model_tests/feature_models/shift_negative_activation_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/shift_negative_activation_test.py
@@ -27,11 +27,11 @@


 class ShiftNegaviteActivationNet(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, activation_layer):
         super(ShiftNegaviteActivationNet, self).__init__()
         self.conv1 = torch.nn.Conv2d(3, 4, kernel_size=(5,6), stride=2)
         self.conv2 = torch.nn.Conv2d(4, 5, kernel_size=(8,7), stride=2, bias=False)
-        self.activation = torch.nn.Hardswish()
+        self.activation = activation_layer()

     def forward(self, inp):
         x0 = self.conv1(inp)
@@ -45,8 +45,9 @@ class ShiftNegaviteActivationNetTest(BasePytorchTest):
     """
     This test checks the shift negative activation feature.
     """
-    def __init__(self, unit_test, float_reconstruction_error=1e-6):
+    def __init__(self, unit_test, float_reconstruction_error=1e-6, activation_layer=torch.nn.Hardswish):
         super().__init__(unit_test, float_reconstruction_error)
+        self.activation_layer = activation_layer

     @staticmethod
     def generate_inputs(input_shapes):
@@ -74,4 +75,9 @@ def create_inputs_shape(self):
         return [[self.val_batch_size, 3, 224, 224]]

     def create_feature_network(self, input_shape):
-        return ShiftNegaviteActivationNet()
\ No newline at end of file
+        return ShiftNegaviteActivationNet(self.activation_layer)
+
+    def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
+        super().compare(quantized_models, float_model, input_x=input_x, quantization_info=quantization_info)
+        q_nodes = quantized_models['all_8bit'].node_sort
+        assert "activation_post_add" in [n.name for n in q_nodes], "Add operator hasn't been added after activation operator"
diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
index ea55eb81b..93ca341b6 100644
--- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py
+++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -425,7 +425,8 @@ def test_shift_negative_activation_net(self):
         """
         This test checks the shift negative activation feature.
         """
-        ShiftNegaviteActivationNetTest(self).run_test(seed=3)
+        for activation_layer in [torch.nn.Hardswish, torch.nn.GELU]:
+            ShiftNegaviteActivationNetTest(self, activation_layer=activation_layer).run_test(seed=3)

     def test_split_concat_net(self):
         """
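As a sanity check on the new minimum values wired into the activation min/max mappings above (-0.17 for GELU, -1.76 for SELU), they can be reproduced numerically with a short PyTorch sketch. This is not part of the patch; the sampling range and resolution below are arbitrary choices for illustration, only the printed minima matter:

```python
# Minimal sketch (not part of the diff): numerically confirm the activation
# minima recorded in the framework-info mappings above.
import torch
import torch.nn.functional as F

# Arbitrary dense grid; wide enough to reach both functions' minima.
x = torch.linspace(-20.0, 20.0, steps=400001)

print(f"GELU min: {F.gelu(x).min().item():.4f}")  # ~ -0.17, attained near x = -0.75
print(f"SELU min: {F.selu(x).min().item():.4f}")  # ~ -1.76, the -lambda*alpha asymptote
```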