
Add rotation matrix to equalized layers
Related to Xilinx#1073

Add a `RotationEqualizedLayer` class to `equalized_layer.py` that wraps a layer together with two learnable rotation matrices.

* Implement `forward` in `RotationEqualizedLayer` to apply the first rotation to the input, call the wrapped layer, and apply the second rotation to the output.
* Add a `fuse_rotation_matrices` method to fold the rotation matrices into the wrapped layer's weights.
* Add unit tests for `RotationEqualizedLayer` in `test_equalized_layer.py`:
  * Test the forward pass with rotation matrices.
  * Test fusing the rotation matrices into the wrapped layer.
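
A minimal usage sketch (not part of the commit; the shapes and the QR-based orthogonal initialization are illustrative assumptions, since the commit does not prescribe how the rotation matrices are constructed):

```python
import torch
import torch.nn as nn

from brevitas.nn.equalized_layer import RotationEqualizedLayer

# Hypothetical setup: wrap a Linear layer with two random orthogonal rotations
layer = nn.Linear(64, 64)
rot1 = torch.linalg.qr(torch.randn(64, 64)).Q  # orthogonal matrix via QR
rot2 = torch.linalg.qr(torch.randn(64, 64)).Q

rotated = RotationEqualizedLayer(layer, rot1, rot2)
out = rotated(torch.randn(8, 64))  # input @ rot1 -> layer -> output @ rot2
```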
vishwamartur committed Nov 3, 2024
1 parent 4617f7b commit 8cdb09e
Showing 2 changed files with 73 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/brevitas/nn/equalized_layer.py
@@ -41,3 +41,49 @@ def forward(self, *args, **kwargs):
        # We convert everything to args so that hooks can work correctly
        out = self.layer(*kwargs.values())
        return out


class RotationEqualizedLayer(torch.nn.Module):

    def __init__(self, layer, rotation_matrix1, rotation_matrix2) -> None:
        super().__init__()
        self.layer = layer
        self.rotation_matrix1 = torch.nn.Parameter(rotation_matrix1)
        self.rotation_matrix2 = torch.nn.Parameter(rotation_matrix2)

    def forward(self, *args, **kwargs):
        # Convert args + kwargs + defaults into kwargs
        bound_arguments = signature(self.layer.forward).bind(*args, **kwargs)
        bound_arguments.apply_defaults()
        kwargs = bound_arguments.arguments

        possible_input_kwargs = INPUT_NAMES
        input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0]
        x = kwargs[input_kwarg]

        # Apply the first rotation to the input
        out = torch.matmul(x, self.rotation_matrix1)

        kwargs[input_kwarg] = out
        # QuantMultiheadAttention is not a subclass of MultiheadAttention
        # We need to preserve the correctness of the forward even after
        # quantization has been applied
        if isinstance(self.layer, (torch.nn.MultiheadAttention, QuantMultiheadAttention)):
            kwargs['key'] = out
            kwargs['value'] = out
        # We convert everything to args so that hooks can work correctly
        out = self.layer(*kwargs.values())

        # Apply the second rotation to the output
        out = torch.matmul(out, self.rotation_matrix2)

        return out

    def fuse_rotation_matrices(self):
        # Fold both rotations into the wrapped layer's parameters.
        # Assumes a Linear-style layer computing y = x @ W.t() + b.
        with torch.no_grad():
            weight = torch.matmul(self.layer.weight.data, self.rotation_matrix1.t())
            self.layer.weight.data = torch.matmul(self.rotation_matrix2.t(), weight)
            if self.layer.bias is not None:
                # Only the output rotation affects the bias
                self.layer.bias.data = torch.matmul(self.layer.bias.data, self.rotation_matrix2)
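
For reference, the fusion identity implemented above for a Linear layer `y = x @ W.t() + b`: `((x @ R1) @ W.t() + b) @ R2 = x @ (R2.t() @ W @ R1.t()).t() + b @ R2`, so replacing `W` with `R2.t() @ W @ R1.t()` and `b` with `b @ R2` lets the bare wrapped layer reproduce the rotated forward exactly.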
27 changes: 27 additions & 0 deletions tests/brevitas/nn/test_equalized_layer.py
@@ -0,0 +1,27 @@
import pytest
import torch
import torch.nn as nn

from brevitas.nn.equalized_layer import RotationEqualizedLayer


class TestRotationEqualizedLayer:

    @pytest.fixture
    def setup(self):
        layer = nn.Linear(10, 10)
        rotation_matrix1 = torch.eye(10)
        rotation_matrix2 = torch.eye(10)
        return RotationEqualizedLayer(layer, rotation_matrix1, rotation_matrix2)

    def test_forward_pass(self, setup):
        rotation_layer = setup
        input_tensor = torch.randn(1, 10)
        output_tensor = rotation_layer(input_tensor)
        assert output_tensor.shape == (1, 10)

    def test_fuse_rotation_matrices(self, setup):
        rotation_layer = setup
        weight_before = rotation_layer.layer.weight.detach().clone()
        has_bias = rotation_layer.layer.bias is not None
        if has_bias:
            bias_before = rotation_layer.layer.bias.detach().clone()
        rotation_layer.fuse_rotation_matrices()
        # Fusing identity rotations must leave the wrapped layer unchanged
        assert torch.allclose(rotation_layer.layer.weight, weight_before)
        if has_bias:
            assert torch.allclose(rotation_layer.layer.bias, bias_before)
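
The identity-rotation checks above only confirm that fusion is a no-op for `R = I`. A stronger test, sketched here as a suggestion (not part of the commit), would verify that fusion reproduces the rotated forward for arbitrary rotations:

```python
import torch
import torch.nn as nn

from brevitas.nn.equalized_layer import RotationEqualizedLayer


def test_fuse_preserves_output():
    # Hypothetical extra test: fusion should reproduce the rotated forward
    # for arbitrary (not just identity) rotation matrices
    layer = nn.Linear(10, 10)
    rotation_layer = RotationEqualizedLayer(layer, torch.randn(10, 10), torch.randn(10, 10))
    x = torch.randn(4, 10)
    expected = rotation_layer(x)
    rotation_layer.fuse_rotation_matrices()
    assert torch.allclose(rotation_layer.layer(x), expected, atol=1e-5)
```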
