
[TOSA] Fold Up-Casts into MatMul when supported by tosa.matmul #198

Merged

Conversation

cmcgirr-amd

In the newer version of Torch-MLIR using Torch 2.3 and Torch Dynamo + FxImporter, we see that a GEMM with bias that operates on bf16 and f16 types decomposes differently than in the earlier versions using TorchScript.

  %0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.vtensor<[4,8],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,8],f32>
  %1 = torch.aten.to.dtype %arg1, %int6, %false, %false, %none : !torch.vtensor<[8,16],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8,16],f32>
  %2 = torch.aten.mm %0, %1 : !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[4,16],f32>

This decomposition contains a torch.aten.mm that operates on an accumulator type, where the inputs are cast from the smaller representation to the accumulator type, e.g. bf16 -> f32. Since the TOSA specification allows tosa.matmul to support these configurations, we can simply fold the casts into the operation signature rather than keeping the casting operations.

The resulting IR should look like this for the bf16 case.

%8 = tosa.matmul %6, %7 : (tensor<1x4x8xbf16>, tensor<1x8x16xbf16>) -> tensor<1x4x16xf32>

One caveat is that the i16 -> i48 tosa.matmul case cannot be supported at the moment, as there is no torch.int48-equivalent dtype to represent this special accumulator type used in TOSA.
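
For illustration, here is a minimal C++ sketch of the element-type check such a fold could rely on. The helper name isFoldableMatmulAccPair is hypothetical and not taken from the actual patch; it only captures the pairings described above, folding bf16/f16 inputs that accumulate into f32 and rejecting i16 -> i48 because torch-mlir has no si48 dtype to express it.

#include "mlir/IR/Types.h"

// Hypothetical helper (illustrative only, not part of this PR): returns true
// when a cast from `inputElemTy` up to `accElemTy` around a matmul may be
// folded into a single tosa.matmul that keeps the narrow inputs and produces
// the wide accumulator type directly.
static bool isFoldableMatmulAccPair(mlir::Type inputElemTy,
                                    mlir::Type accElemTy) {
  // bf16 x bf16 -> f32 and f16 x f16 -> f32 are valid tosa.matmul signatures.
  if (inputElemTy.isBF16() || inputElemTy.isF16())
    return accElemTy.isF32();
  // i16 x i16 -> i48 is valid TOSA, but there is no torch.int48 dtype, so
  // that pairing is intentionally left unfolded.
  return false;
}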

roberteg16 left a comment

As you noted on your PR message, I think we should guard against i16 -> i48. Otherwise lgtm

TinaAMD (Collaborator) commented Jul 15, 2024

Nice idea to get rid of the casts this way! Implementation-wise, I would have expected this to be a canonicalization on TOSA, i.e. you just lower naively from torch-mlir to tosa, and then canonicalize the tosa casts + matmul to a different matmul (then you could also easily cover the i48 case). WDYT?

cmcgirr-amd (Author) commented

> Nice idea to get rid of the casts this way! Implementation-wise, I would have expected this to be a canonicalization on TOSA, i.e. you just lower naively from torch-mlir to tosa, and then canonicalize the tosa casts + matmul to a different matmul (then you could also easily cover the i48 case). WDYT?

Kind of tricky, because for some cases the accumulator types are not supported as input types, so it would be invalid to have:

tosa.matmul %6, %7 : (tensor<1x4x8xi32>, tensor<1x8x16xi32>) -> tensor<1x4x16xi32>

Could be done for the float types, but integers may be difficult.
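
For context, a hypothetical sketch of what the float-only TOSA canonicalization route could look like (illustrative only, not part of this PR; the pattern name FoldCastsIntoMatMul and its details are assumptions): it folds tosa.cast producers back into the matmul when the bf16/f16 -> f32 pairing is valid, and it deliberately has no integer counterpart because a naively lowered wide-integer tosa.matmul would already be invalid IR.

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"

// Hypothetical float-only canonicalization sketch (not part of this PR):
// fold tosa.cast (bf16/f16 -> f32) operands of a tosa.matmul back into the
// matmul, leaving narrow inputs with an f32 accumulator result.
struct FoldCastsIntoMatMul : mlir::OpRewritePattern<mlir::tosa::MatMulOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::tosa::MatMulOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto lhsCast = op.getA().getDefiningOp<mlir::tosa::CastOp>();
    auto rhsCast = op.getB().getDefiningOp<mlir::tosa::CastOp>();
    if (!lhsCast || !rhsCast)
      return mlir::failure();
    auto lhsTy = mlir::cast<mlir::ShapedType>(lhsCast.getInput().getType());
    auto rhsTy = mlir::cast<mlir::ShapedType>(rhsCast.getInput().getType());
    auto resTy = mlir::cast<mlir::ShapedType>(op.getType());
    // Only the float pairings tosa.matmul supports natively are folded:
    // bf16 x bf16 -> f32 and f16 x f16 -> f32.
    if (lhsTy.getElementType() != rhsTy.getElementType() ||
        !(lhsTy.getElementType().isBF16() || lhsTy.getElementType().isF16()) ||
        !resTy.getElementType().isF32())
      return mlir::failure();
    // Swap the cast results for the original narrow inputs; the casts then
    // become dead and are cleaned up separately.
    rewriter.modifyOpInPlace(op, [&] {
      op.getAMutable().assign(lhsCast.getInput());
      op.getBMutable().assign(rhsCast.getInput());
    });
    return mlir::success();
  }
};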

cmcgirr-amd (Author) commented Jul 15, 2024

> As you noted on your PR message, I think we should guard against i16 -> i48. Otherwise lgtm

In what sense do you mean guarded? Like don't convert in this case:

torch.aten.mm %0, %1 : !torch.vtensor<[4,8],si16>, !torch.vtensor<[8,16],si16> -> !torch.vtensor<[4,16],si48>

Or when there are casting ops? Because we cannot have a legal torch.aten.to_dtype op with si48 AFAIK

cmcgirr-amd requested a review from roberteg16 July 15, 2024 08:58
TinaAMD (Collaborator) commented Jul 15, 2024

>> Nice idea to get rid of the casts this way! Implementation-wise, I would have expected this to be a canonicalization on TOSA, i.e. you just lower naively from torch-mlir to tosa, and then canonicalize the tosa casts + matmul to a different matmul (then you could also easily cover the i48 case). WDYT?

> Kind of tricky, because for some cases the accumulator types are not supported as input types, so it would be invalid to have:
>
> tosa.matmul %6, %7 : (tensor<1x4x8xi32>, tensor<1x8x16xi32>) -> tensor<1x4x16xi32>

Oh, I see, so you cannot actually lower the integer cases when the casts are not present; I didn't notice that this was disallowed. Makes sense to keep it here then, thanks for the explanation!

> Could be done for the float types, but integers may be difficult.

roberteg16 commented Jul 15, 2024

>> As you noted on your PR message, I think we should guard against i16 -> i48. Otherwise lgtm

> In what sense do you mean guarded? Like don't convert in this case:
>
> torch.aten.mm %0, %1 : !torch.vtensor<[4,8],si16>, !torch.vtensor<[8,16],si16> -> !torch.vtensor<[4,16],si48>
>
> Or when there are casting ops? Because we cannot have a legal torch.aten.to_dtype op with si48 AFAIK

I was referring to avoiding the conversion of a case like the one you described.

> Because we cannot have a legal torch.aten.to_dtype op with si48 AFAIK

Had a look into: https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td#L11547

I am not sure it is disallowed as you say.

cmcgirr-amd (Author) commented

>>> As you noted on your PR message, I think we should guard against i16 -> i48. Otherwise lgtm

>> In what sense do you mean guarded? Like don't convert in this case:
>>
>> torch.aten.mm %0, %1 : !torch.vtensor<[4,8],si16>, !torch.vtensor<[8,16],si16> -> !torch.vtensor<[4,16],si48>
>>
>> Or when there are casting ops? Because we cannot have a legal torch.aten.to_dtype op with si48 AFAIK

> I was referring to avoiding the conversion of a case like the one you described.

>> Because we cannot have a legal torch.aten.to_dtype op with si48 AFAIK

> Had a look into: https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td#L11547
>
> I am not sure it is disallowed as you say.

That is true, the definition allows anything, but it would be hard to find an integer value to describe the dtype if the importer does not have it defined: https://github.com/Xilinx/torch-mlir/blob/feature/backport_ea1_ops/python/torch_mlir/extras/fx_importer.py#L176-L193

cmcgirr-amd requested a review from TinaAMD July 15, 2024 14:54
roberteg16 commented

>>>> As you noted on your PR message, I think we should guard against i16 -> i48. Otherwise lgtm

>>> In what sense do you mean guarded? Like don't convert in this case:
>>>
>>> torch.aten.mm %0, %1 : !torch.vtensor<[4,8],si16>, !torch.vtensor<[8,16],si16> -> !torch.vtensor<[4,16],si48>
>>>
>>> Or when there are casting ops? Because we cannot have a legal torch.aten.to_dtype op with si48 AFAIK

>> I was referring to avoiding the conversion of a case like the one you described.

>>> Because we cannot have a legal torch.aten.to_dtype op with si48 AFAIK

>> Had a look into: https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td#L11547
>> I am not sure it is disallowed as you say.

> That is true, the definition allows anything, but it would be hard to find an integer value to describe the dtype if the importer does not have it defined: https://github.com/Xilinx/torch-mlir/blob/feature/backport_ea1_ops/python/torch_mlir/extras/fx_importer.py#L176-L193

Oh I see, it seems that the importer does not allow it. Could we just harden it to make it safe for the future?

cmcgirr-amd requested a review from TinaAMD July 16, 2024 09:28
TinaAMD (Collaborator) left a comment

LGTM

cmcgirr-amd merged commit 611039c into feature/backport_ea1_ops Aug 13, 2024
3 checks passed
cmcgirr-amd deleted the christopher.matmul_casted_inputs branch August 13, 2024 09:01