
RMS Norm doesn't seem to be supported #355

Open
spacycoder opened this issue Sep 10, 2024 · 5 comments
Labels
enhancement New feature or request

Comments

@spacycoder

Hi, converting a model that uses nn.RMSNorm does not work:

import torch
import torch.nn as nn

from tinynn.graph.quantization.quantizer import PostQuantizer


class RMSNormModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.norm = nn.RMSNorm(3, 0.1)

    def forward(self, x):
        x = x.transpose(1, 3)  # [N, W, H, C]
        x = self.norm(x)
        return x.transpose(1, 3)  # back to [N, C, H, W]


def _main():
    dummy_input = torch.rand(1, 3, 224, 224)
    model = RMSNormModel()

    qat_config = {"backend": "qnnpack"}
    quantizer = PostQuantizer(
        model, (dummy_input,), work_dir="rms_model", config=qat_config
    )

    ptq_coarse_matcher = quantizer.quantize()


if __name__ == "__main__":
    _main()

error:

ERROR (tinynn.graph.tracer) Connection is lost when generating code for transpose_1_f of type torch.Tensor.transpose
Traceback (most recent call last):
  File ".../lib/python3.12/site-packages/tinynn/graph/tracer.py", line 3380, in trace
    new_graph.init()
  File ".../lib/python3.12/site-packages/tinynn/graph/tracer.py", line 2041, in init
    self.module(*actual_input)
  File ".../lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "test_rms_norm.py", line 15, in forward
    return x.transpose(1, 3) # [N, C, H, W]
           ^^^^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/tinynn/graph/tracer.py", line 1089, in new_func
    trace_func = TraceFunction(key, is_class).parse_args(*args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/tinynn/graph/tracer.py", line 646, in parse_args
    arg_str = _parse_args(args)
              ^^^^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/tinynn/graph/tracer.py", line 589, in _parse_args
    self.tensor_names.append(_tensor_name(a))
                             ^^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/tinynn/graph/tracer.py", line 549, in _tensor_name
    pre_node_name = current_graph().tensor_pre_node_dict[id(a)]
KeyError: 130226902469760
ERROR (tinynn.graph.tracer) inputs: ['input_0_f']
ERROR (tinynn.graph.tracer) forwards: ['transpose_0_f']
ERROR (tinynn.graph.tracer) outputs: []
ERROR (tinynn.graph.tracer) constants: []
peterjc123 added the enhancement (New feature or request) label Sep 10, 2024
spacycoder (Author) commented Sep 10, 2024

Using this implementation of RMSNorm instead of the built-in one also fails:

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, normalized_shape: int, eps=1e-8):
        """Root Mean Square Layer Normalization.

        :param normalized_shape: input size
        :param eps: epsilon value, default 1e-8
        """
        super().__init__()

        self.eps = eps
        self.normalized_shape = normalized_shape

        # Assigning an nn.Parameter attribute registers it, so no explicit
        # register_parameter call is needed.
        self.scale = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x: torch.Tensor):
        norm_x = x.norm(2, dim=-1, keepdim=True)
        d_x = self.normalized_shape

        rms_x = norm_x * d_x ** (-1.0 / 2)
        x_normed = x / (rms_x + self.eps)

        return self.scale * x_normed

error:

  File "rmsnormmodel_q.py", line 31, in forward
    norm_0_f = transpose_0_f.norm(2, dim=-1, keepdim=True)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/torch/_tensor.py", line 761, in norm
    return torch.norm(self, p, dim, keepdim, dtype=dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/torch/functional.py", line 1632, in norm
    return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: linalg.vector_norm: Expected a floating point or complex tensor as input. Got QUInt8

peterjc123 (Collaborator) commented Sep 10, 2024

@spacycoder Yes, quantization for both norm and RMSNorm is unsupported at the moment. I wonder whether you could actually do that with TFLite. But in any case, we should safely skip those ops.
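
Conceptually, skipping them means leaving the norm in float inside an otherwise quantized graph, roughly like the eager-mode pattern below (purely illustrative; the wrapper name is hypothetical and this is not how tinynn implements the skip):

import torch
import torch.nn as nn


class FloatNormWrapper(nn.Module):
    """Hypothetical wrapper: dequantize -> float norm -> requantize."""

    def __init__(self, norm: nn.Module):
        super().__init__()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.norm = norm
        self.quant = torch.ao.quantization.QuantStub()

    def forward(self, x):
        return self.quant(self.norm(self.dequant(x)))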

peterjc123 added the bug (Something isn't working) label Sep 10, 2024
spacycoder (Author) commented Sep 10, 2024

FYI torch.rsqrt also fails with the same "QUInt8" error. Maybe that's possible to support? link:

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): input tensor to normalize

        Returns:
            Tensor: The output tensor after applying RMSNorm.
        """
        # computation is in fp32
        x_fp32 = x.float()
        x_normed = (
            x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        ).type_as(x)
        return x_normed * self.scale

A working implementation of RMSNorm can be made like this:

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fp32 = x.float()
        var = x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps
        x_norm = x_fp32 * (1. / torch.sqrt(var))
        return self.scale * x_norm
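
For what it's worth, a quick float sanity check of the implementation above against the built-in module (a minimal sketch: it assumes PyTorch 2.4+ where nn.RMSNorm is available, and the shape and tolerance are only illustrative):

import torch
import torch.nn as nn

x = torch.rand(1, 224, 224, 3)  # channels last, normalizing over the final dim

ref = nn.RMSNorm(3, eps=1e-6)   # built-in module (PyTorch >= 2.4)
custom = RMSNorm(3, eps=1e-6)   # the implementation above

with torch.no_grad():
    print(torch.allclose(ref(x), custom(x), atol=1e-6))  # expected: True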

peterjc123 (Collaborator) commented Sep 10, 2024

@spacycoder Op-wise, yes, we may go through MUL -> MEAN -> RSQRT -> MUL. But I guess the quantization errors can't be ignored, especially for pow and rsqrt. Also, eps should be redesigned for quantization.
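
To make that lowering concrete, here is roughly what the op sequence looks like in plain float PyTorch (just a sketch with a made-up function name; where the observers go and how eps survives quantization are exactly the open questions above):

import torch


def rms_norm_decomposed(x: torch.Tensor, scale: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    sq = x * x                          # MUL
    ms = sq.mean(dim=-1, keepdim=True)  # MEAN
    inv = torch.rsqrt(ms + eps)         # RSQRT
    return x * inv * scale              # MUL (plus the affine scale)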

peterjc123 (Collaborator) commented Sep 12, 2024

With #356, at least it won't throw an error for the models you provided. Quantization for those ops is still skipped.

peterjc123 removed the bug (Something isn't working) label Sep 12, 2024