From 8cdb09e54324c7e3eaa23bca32e313bb2969aa0e Mon Sep 17 00:00:00 2001
From: Vishwanath Martur <64204611+vishwamartur@users.noreply.github.com>
Date: Sun, 3 Nov 2024 21:41:10 +0530
Subject: [PATCH] Add rotation matrices to equalized layers

Related to #1073

Add a `RotationEqualizedLayer` class to `equalized_layer.py` that wraps a
layer together with two learnable rotation matrices.

* Implement a `forward` method in `RotationEqualizedLayer` that applies
  the first rotation to the input, calls the wrapped layer, and applies
  the second rotation to the output.
* Add a `fuse_rotation_matrices` method that folds both rotations into
  the wrapped layer's weight and bias.
* Add unit tests for the `RotationEqualizedLayer` class in
  `test_equalized_layer.py`:
  * Test the forward pass with rotation matrices.
  * Test fusing the rotation matrices into the wrapped layer.
---
 src/brevitas/nn/equalized_layer.py        | 51 ++++++++++++++++++++++++
 tests/brevitas/nn/test_equalized_layer.py | 29 ++++++++++++
 2 files changed, 80 insertions(+)
 create mode 100644 tests/brevitas/nn/test_equalized_layer.py

diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py
index 7093c8c17..a5bc708fa 100644
--- a/src/brevitas/nn/equalized_layer.py
+++ b/src/brevitas/nn/equalized_layer.py
@@ -41,3 +41,54 @@ def forward(self, *args, **kwargs):
         # We convert everything to args so that hooks can work correctly
         out = self.layer(*kwargs.values())
         return out
+
+
+class RotationEqualizedLayer(torch.nn.Module):
+
+    def __init__(self, layer, rotation_matrix1, rotation_matrix2) -> None:
+        super().__init__()
+        self.layer = layer
+        self.rotation_matrix1 = torch.nn.Parameter(rotation_matrix1)
+        self.rotation_matrix2 = torch.nn.Parameter(rotation_matrix2)
+
+    def forward(self, *args, **kwargs):
+        # Convert args + kwargs + defaults into kwargs
+        bound_arguments = signature(self.layer.forward).bind(*args, **kwargs)
+        bound_arguments.apply_defaults()
+        kwargs = bound_arguments.arguments
+
+        possible_input_kwargs = INPUT_NAMES
+        input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0]
+        x = kwargs[input_kwarg]
+
+        # Apply the first rotation to the input
+        out = torch.matmul(x, self.rotation_matrix1)
+
+        kwargs[input_kwarg] = out
+        # QuantMultiheadAttention is not a subclass of MultiheadAttention
+        # We need to preserve the correctness of the forward even after
+        # quantization has been applied
+        if isinstance(self.layer, (torch.nn.MultiheadAttention, QuantMultiheadAttention)):
+            kwargs['key'] = out
+            kwargs['value'] = out
+        # We convert everything to args so that hooks can work correctly
+        out = self.layer(*kwargs.values())
+
+        # Apply the second rotation to the output
+        # MultiheadAttention returns an (attn_output, attn_weights) tuple
+        if isinstance(out, tuple):
+            out = (torch.matmul(out[0], self.rotation_matrix2),) + out[1:]
+        else:
+            out = torch.matmul(out, self.rotation_matrix2)
+
+        return out
+
+    def fuse_rotation_matrices(self):
+        # The forward pass computes ((x @ R1) @ W.T + b) @ R2, which equals
+        # x @ (R2.T @ W @ R1.T).T + b @ R2, so fold both rotations into the
+        # weight and the second rotation into the bias
+        with torch.no_grad():
+            weight = torch.matmul(self.rotation_matrix2.t(), self.layer.weight.data)
+            self.layer.weight.data = torch.matmul(weight, self.rotation_matrix1.t())
+            if self.layer.bias is not None:
+                self.layer.bias.data = torch.matmul(self.layer.bias.data, self.rotation_matrix2)
diff --git a/tests/brevitas/nn/test_equalized_layer.py b/tests/brevitas/nn/test_equalized_layer.py
new file mode 100644
index 000000000..9f9a07a40
--- /dev/null
+++ b/tests/brevitas/nn/test_equalized_layer.py
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+import pytest
+
+from brevitas.nn.equalized_layer import RotationEqualizedLayer
+
+class TestRotationEqualizedLayer:
+
+    @pytest.fixture
+    def setup(self):
+        layer = nn.Linear(10, 10)
+        rotation_matrix1 = torch.eye(10)
+        rotation_matrix2 = torch.eye(10)
+        return RotationEqualizedLayer(layer, rotation_matrix1, rotation_matrix2)
+
+    def test_forward_pass(self, setup):
+        rotation_layer = setup
+        input_tensor = torch.randn(1, 10)
+        output_tensor = rotation_layer(input_tensor)
+        assert output_tensor.shape == (1, 10)
+
+    def test_fuse_rotation_matrices(self, setup):
+        rotation_layer = setup
+        original_weight = rotation_layer.layer.weight.data.clone()
+        original_bias = rotation_layer.layer.bias.data.clone()
+        rotation_layer.fuse_rotation_matrices()
+        # Fusing identity rotations must leave the wrapped layer unchanged
+        assert torch.allclose(rotation_layer.layer.weight, original_weight)
+        assert torch.allclose(rotation_layer.layer.bias, original_bias)
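
---

Note, not part of the patch: a minimal usage sketch of the new class,
assuming the patch above is applied. The tensor sizes and the QR-based
orthogonal matrices below are illustrative choices, not taken from the
commit itself.

    import torch
    import torch.nn as nn

    from brevitas.nn.equalized_layer import RotationEqualizedLayer

    # Build random orthogonal rotations via QR decomposition
    r1, _ = torch.linalg.qr(torch.randn(10, 10))
    r2, _ = torch.linalg.qr(torch.randn(10, 10))

    rotated = RotationEqualizedLayer(nn.Linear(10, 10), r1, r2)

    x = torch.randn(4, 10)
    y_wrapped = rotated(x)

    # After fusing, the wrapped layer alone reproduces the rotated forward
    rotated.fuse_rotation_matrices()
    y_fused = rotated.layer(x)
    assert torch.allclose(y_wrapped, y_fused, atol=1e-5)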