Feature/gelu shift negative (#1163)
* setup updates & activation net

* gelu & selu shift negative

* onnx version

* GELU test

---------

Co-authored-by: yardeny-sony <[email protected]>
yarden-yagil-sony and yardeny-sony authored Aug 14, 2024
1 parent fc68360 commit b4c74c8
Showing 5 changed files with 28 additions and 13 deletions.
@@ -29,7 +29,7 @@
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
 from model_compression_toolkit.constants import SOFTMAX_THRESHOLD
 from model_compression_toolkit.core.keras.constants import SOFTMAX, LINEAR, RELU, SWISH, SIGMOID, IDENTITY, TANH, SELU, \
-    KERNEL, DEPTHWISE_KERNEL
+    KERNEL, DEPTHWISE_KERNEL, GELU
 from model_compression_toolkit.core.keras.quantizer.fake_quant_builder import power_of_two_quantization, symmetric_quantization, uniform_quantization
 
 """
@@ -75,7 +75,8 @@
     TANH: (-1, 1),
     SWISH: (-0.279, None),
     RELU: (0, None),
-    SELU: (None, None),
+    SELU: (-1.76, None),
+    GELU: (-0.17, None),
 }
 
 """
11 changes: 8 additions & 3 deletions model_compression_toolkit/core/pytorch/default_framework_info.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid
-from torch.nn.functional import hardsigmoid, relu, relu6, softmax
+from torch.nn import Hardsigmoid, ReLU, ReLU6, Softmax, Sigmoid, GELU, SELU
+from torch.nn.functional import hardsigmoid, relu, relu6, softmax, gelu, selu
 from torch.nn import Conv2d, ConvTranspose2d, Linear
 from torch import sigmoid
 
@@ -74,7 +74,12 @@
     ReLU: (0, None),
     relu: (0, None),
     ReLU6: (0, None),
-    relu6: (0, None)}
+    relu6: (0, None),
+    GELU: (-0.17, None),
+    gelu: (-0.17, None),
+    SELU: (-1.76, None),
+    selu: (-1.76, None),
+}
 
 """
 Mapping from a QuantizationMethod to an activation quantizer function.
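As a side note (not part of the commit), the new lower bounds can be reproduced numerically with PyTorch's own activations; the constants in the mappings appear to be these minima rounded to two decimals. A rough sketch:

```python
import torch

x = torch.linspace(-10.0, 10.0, 200001)

# GELU(x) = x * Phi(x) attains its minimum near x ≈ -0.75.
print(torch.nn.functional.gelu(x).min())   # ≈ -0.1700, hence the -0.17 bound

# SELU(x) = scale * alpha * (exp(x) - 1) for x < 0 approaches
# -scale * alpha ≈ -1.7581 as x -> -inf, hence the -1.76 bound.
print(torch.nn.functional.selu(x).min())   # ≈ -1.758 on this grid
```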
@@ -17,9 +17,9 @@
 
 import numpy as np
 import torch.nn.functional
-from torch.nn import Conv2d, Linear, PReLU, ELU, Hardswish, Dropout, ZeroPad2d, SiLU
+from torch.nn import Conv2d, Linear, PReLU, ELU, Hardswish, Dropout, ZeroPad2d, SiLU, GELU
 from torch import reshape
-from torch.nn.functional import hardswish, silu, prelu, elu
+from torch.nn.functional import hardswish, silu, prelu, elu, gelu
 from torch.nn.functional import avg_pool2d
 
 from model_compression_toolkit.core import CoreConfig, FrameworkInfo
@@ -68,7 +68,9 @@ def shift_negative_activation_node_matchers():
        NodeOperationMatcher(Hardswish) | \
        NodeOperationMatcher(hardswish) | \
        NodeOperationMatcher(SiLU) | \
-       NodeOperationMatcher(silu)
+       NodeOperationMatcher(silu) | \
+       NodeOperationMatcher(GELU) | \
+       NodeOperationMatcher(gelu)
 
     # Match linear layers where we can add a correction.
     linear_node = NodeOperationMatcher(Conv2d) | \
@@ -27,11 +27,11 @@
 
 
 class ShiftNegaviteActivationNet(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, activation_layer):
         super(ShiftNegaviteActivationNet, self).__init__()
         self.conv1 = torch.nn.Conv2d(3, 4, kernel_size=(5,6), stride=2)
         self.conv2 = torch.nn.Conv2d(4, 5, kernel_size=(8,7), stride=2, bias=False)
-        self.activation = torch.nn.Hardswish()
+        self.activation = activation_layer()
 
     def forward(self, inp):
         x0 = self.conv1(inp)
@@ -45,8 +45,9 @@ class ShiftNegaviteActivationNetTest(BasePytorchTest):
     """
     This test checks the shift negative activation feature.
     """
-    def __init__(self, unit_test, float_reconstruction_error=1e-6):
+    def __init__(self, unit_test, float_reconstruction_error=1e-6, activation_layer=torch.nn.Hardswish):
         super().__init__(unit_test, float_reconstruction_error)
+        self.activation_layer = activation_layer
 
     @staticmethod
     def generate_inputs(input_shapes):
@@ -74,4 +75,9 @@ def create_inputs_shape(self):
         return [[self.val_batch_size, 3, 224, 224]]
 
     def create_feature_network(self, input_shape):
-        return ShiftNegaviteActivationNet()
+        return ShiftNegaviteActivationNet(self.activation_layer)
+
+    def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
+        super().compare(quantized_models, float_model, input_x, quantization_info)
+        q_nodes = quantized_models['all_8bit'].node_sort
+        assert "activation_post_add" in [n.name for n in q_nodes], "Add operator hasn't been added after the activation operator"
@@ -425,7 +425,8 @@ def test_shift_negative_activation_net(self):
         """
         This test checks the shift negative activation feature.
         """
-        ShiftNegaviteActivationNetTest(self).run_test(seed=3)
+        for activation_layer in [torch.nn.Hardswish, torch.nn.GELU]:
+            ShiftNegaviteActivationNetTest(self, activation_layer=activation_layer).run_test(seed=3)
 
     def test_split_concat_net(self):
         """
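For context on what the new `compare` assertion checks: shift negative correction handles activations whose output is only slightly negative (such as GELU) by adding a constant after the activation so the result can be quantized as an unsigned tensor, which is why an extra add node (`activation_post_add`) is expected in the quantized graph. Below is a rough, illustrative sketch of the idea, not MCT's actual implementation; the fake quantizer and the chosen range are assumptions made for the example.

```python
import torch

def fake_quant(x, n_bits=8, x_min=0.0, x_max=8.0):
    # Illustrative uniform fake-quantizer over [x_min, x_max].
    scale = (x_max - x_min) / (2 ** n_bits - 1)
    q = torch.clamp(torch.round((x - x_min) / scale), 0, 2 ** n_bits - 1)
    return q * scale + x_min

x = torch.randn(1000)
y = torch.nn.functional.gelu(x)

shift = 0.17                # magnitude of GELU's lower bound
y_shifted = y + shift       # the add node inserted after the activation

# Quantize the shifted (non-negative) tensor with an unsigned range, then
# undo the shift; in MCT the compensation is folded into the following
# linear layer rather than subtracted explicitly.
y_hat = fake_quant(y_shifted, x_min=0.0, x_max=8.0) - shift
print((y - y_hat).abs().max())
```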