[TOSA] Fold Up-Casts into MatMul when supported by tosa.matmul #198
Conversation
As you noted on your PR message, I think we should guard against i16 -> i48. Otherwise lgtm
Nice idea to get rid of the casts this way! Implementation-wise, I would have expected this to be a canonicalization on TOSA, i.e. you just lower naively from torch-mlir to tosa, and then canonicalize the tosa casts + matmul into a different matmul (then you could also easily cover the …)
Kind of tricky, because for some of them the accumulator types are not supported as input types, so it would be invalid to have that kind of op (see the sketch below). It could be done for the float types, but integers may be difficult.
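The snippet from the original comment is not preserved; as an illustration of the kind of invalid intermediate meant here, a minimal sketch with made-up shapes and SSA names, where the `i48` accumulator type would have to appear as an input element type:

```mlir
// Hypothetical, invalid configuration: i48 is only an accumulator type in the
// TOSA spec, so a tosa.matmul taking i48 inputs is not a supported profile.
%bad = tosa.matmul %lhs, %rhs : (tensor<1x4x8xi48>, tensor<1x8x16xi48>) -> tensor<1x4x16xi48>
```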
In what sense do you mean guarded? Like don't convert in this case, or when there are casting ops? Because we cannot have a legal …
Oh, I see, so you cannot actually lower the integer cases when the casts are not present; I didn't notice that this was disallowed. Makes sense to keep it here then, thanks for the explanation!
I was referring to avoiding converting a case like the one you described.
Had a look into https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td#L11547 and I am not sure it is actually disallowed as you say.
That is true, the definition allows anything, but I would find it hard to find an integer value to describe the dtype if the importer does not have it defined: https://github.com/Xilinx/torch-mlir/blob/feature/backport_ea1_ops/python/torch_mlir/extras/fx_importer.py#L176-L193
Oh I see, it seems that the importer does not allow it. Could we just harden it to make it safe for the future?
LGTM
In newer versions of Torch-MLIR using Torch 2.3 and Torch Dynamo + FxImporter, we see that a GEMM with bias operating on `bf16` and `f16` types decomposes differently than in the earlier versions using TorchScript.

This decomposition contains a `torch.aten.mm` that operates on an accumulator type, where the inputs are cast from the smaller representation to the accumulator type, e.g. `bf16 -> f32`. Since the TOSA specification allows `tosa.matmul` to support these configurations, we can simply fold the casts into the operation signature rather than keeping the casting operations. The resulting IR should look like this for the bf16 case.
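The IR from the original description is not reproduced here; the following is a minimal sketch under assumed tensor shapes and SSA names (the exact TOSA assembly syntax and attributes can differ between dialect versions):

```mlir
// Hypothetical folded form: the bf16 operands feed tosa.matmul directly and the
// op itself produces the f32 accumulator, so no explicit tosa.cast ops on the
// inputs are needed. Shapes and value names are illustrative only.
%acc = tosa.matmul %lhs, %rhs : (tensor<1x4x8xbf16>, tensor<1x8x16xbf16>) -> tensor<1x4x16xf32>
```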
One caveat is that the `i16 -> i48` `tosa.matmul` case cannot be supported at the moment, as there is no `torch.int48`-equivalent dtype to represent this special accumulator type used in TOSA.
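For illustration, a hedged sketch of the configuration that currently cannot be reached from the torch side (shapes and names are made up; the op itself is what the TOSA spec describes, the limitation is the missing 48-bit integer dtype on the torch-mlir/importer side):

```mlir
// Hypothetical i16 x i16 -> i48 matmul as described by the TOSA spec; it is not
// producible from torch-mlir because torch has no dtype corresponding to i48.
%acc = tosa.matmul %lhs, %rhs : (tensor<1x4x8xi16>, tensor<1x8x16xi16>) -> tensor<1x4x16xi48>
```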