[TorchToTosa] Refactoring to separate construction of legal/illegal ops and conversion patterns. #3759
Conversation
Hi @sjarus, tagging you as you recently reviewed a PR for changes in the …
(force-pushed from a76c138 to 49199da)
I'll take a look within a day.
```cpp
  target.addLegalOp<PrimTupleConstructOp>();
}

void torch::populateTorchToTosaConversionIllegalOps(ConversionTarget &target) {
```
While rebasing the `main` branch I realized that it's easy to miss updating this list, but the list is probably not required since `target.addIllegalDialect<Torch::TorchDialect>()` is also present in this pass. There's also a check during `VerifyBackend*` that ensures no `torch` ops are present after a full pipeline runs. Thoughts?
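For reference, a minimal sketch (using the stock MLIR `ConversionTarget` API, not this PR's code) of the two levels of illegality being discussed:

```cpp
// Sketch only: dialect-level vs. per-op illegality on a ConversionTarget.
// Namespaces follow torch-mlir / upstream MLIR conventions.
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

void configureTarget(mlir::ConversionTarget &target) {
  // Blanket rule: every Torch-dialect op left after the pass is an error,
  // even under applyPartialConversion.
  target.addIllegalDialect<mlir::torch::Torch::TorchDialect>();

  // Per-op rule: only the listed ops must be converted; with partial
  // conversion, other torch ops are simply left in place.
  target.addIllegalOp<mlir::torch::Torch::AtenMaxDimOp>();
}
```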
Adding the Torch dialect would probably get in the way of partial legalization, if that's the goal.
Isn't there a way to make the aten op illegal conditionally - a successful rewrite ought to make it illegal but not otherwise? I recall there was infrastructure proposed to do this, but this can be tricky when some instances of the pattern replacement succeed and others do not.
The alternative is to have the macro append to a list rather than make the op illegal, and to parameterize the pass behavior to either attempt a full conversion with a pass/fail or a partial conversion. You'd then apply `target.addIllegalOp<>` depending on whether the intended behavior is to have the pass fail on conversion failures or not. Does that work?
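A rough sketch of that parameterization (the `failOnUnconvertedOps` flag is hypothetical, and this is not code from the PR):

```cpp
// Sketch: choose between "fail on anything left unconverted" and
// "convert what you can, leave the rest" based on a hypothetical flag,
// inside a conversion pass's runOnOperation().
if (failOnUnconvertedOps) {
  // Full conversion: any op not explicitly marked legal must be rewritten,
  // so a missed legalization becomes a pass failure.
  if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
                                             std::move(patterns))))
    return signalPassFailure();
} else {
  // Partial conversion: ops with no legality rule attached are left alone,
  // ready for a later pass (e.g. TorchToLinalg) to pick up.
  if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                std::move(patterns))))
    return signalPassFailure();
}
```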
Thanks for the review, @sjarus.
Yes, it's possible to parameterize the pass to bypass the illegal-ops check, but I'd like to understand your concern better as I am still learning MLIR. For the existing pass, since there is already `target.addIllegalDialect<Torch::TorchDialect>()`, are the individual `addIllegalOp<>` calls adding any benefit or are they redundant?
For the `tosa+linalg` pass that I'm prototyping, the end goal is to fail if any `torch` op is present at the end of the full pipeline, so adding `target.addIllegalDialect<Torch::TorchDialect>()` there works as well (I've used `populateTorchToTosaConversionPatterns` in that new pipeline to perform partial conversion of the `torch->tosa` ops in the pattern list).
So one immediate concern would be that the manual addition to the Aten ops list here would not be a tenable solution. It used to be abstracted behind the macro and worked cleanly, but within the constraints applied originally. The new approach should not add a construct that you've correctly recognized as easily breakable.
Secondly, I'm trying to understand the design goal here. Let me try to describe my own understanding of it: the current pass makes all legalized Torch ops invalid so that a conversion failure manifests itself as a pass failure. This is ok if the goal is to layer two passes - TorchToTosa handling subset A of Torch ops, and subsequently TorchToLinalg for all remaining ones of interest (let's call that B).
However, this breaks if there exist variants of ops in A such that A is a combination of ops supported by TorchToTosa (A') and variants of those same ops that happen to be unsupported (A"). For example, let's trivially presume conv2d with unit strides is supported but non-unit strides are not.
The goal is to layer an implementation of A" within TorchToLinalg and implement a pipeline that has TorchToTosa for A' followed by TorchToLinalg for A" and B. This won't currently work because A" would be handled in TorchToTosa and emit a pass failure. The goal here is to mark A" valid in TorchToTosa so they can be consumed by TorchToLinalg later within this intended pipeline. Is that correct?
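To make the A'/A" split concrete, here is a hedged sketch of expressing it with dynamic legality; the unit-stride check is purely illustrative and not a criterion used by torch-mlir:

```cpp
// Sketch: treat only the supported variant (unit-stride conv2d) as illegal,
// i.e. TorchToTosa must convert it, while unsupported variants stay legal
// and flow through untouched to a later TorchToLinalg stage.
target.addDynamicallyLegalOp<mlir::torch::Torch::AtenConv2dOp>(
    [](mlir::torch::Torch::AtenConv2dOp op) {
      llvm::SmallVector<int64_t> strides;
      // m_TorchListOfConstantInts extracts a constant integer list if present.
      if (!mlir::matchPattern(
              op.getStride(),
              mlir::torch::Torch::m_TorchListOfConstantInts(strides)))
        return true; // unknown strides: leave the op for a later pass
      // Legal (left alone) iff some stride is non-unit; otherwise illegal,
      // so the TOSA pattern is required to fire.
      return llvm::any_of(strides, [](int64_t s) { return s != 1; });
    });
```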
Thanks for the elaborate response. I agree with your understanding, but here are some clarifications:

> So one immediate concern would be that the manual addition to the Aten ops list here would not be a tenable solution.

Yes, I agree. My thought is that since this pass pipeline already specifies `torch` as an illegal dialect, we don't need to maintain this individual list of illegal ops. If any op from A'' is present for the current `TorchToTosa` pipeline, then even though the op itself is not marked as illegal, the pass will fail because the `torch` dialect is illegal after the pass completes.

> The goal is to layer an implementation of A" within TorchToLinalg and implement a pipeline that has TorchToTosa for A' followed by TorchToLinalg for A" and B.

Yes, the end goal is to let the `TorchToTosa` pass handle as many ops as it can and let `TorchToLinalg` handle the rest. So ops in A' will be handled by the conversions in `TorchToTosa`, and ops in A'' and B will be handled by conversions in `TorchToLinalg`, as you mentioned. Here is how I've set up the new pipeline: https://github.com/sahas3/torch-mlir/blob/b0468a9ec367da0fb2c2e813f74437b9fa9ff7e8/lib/Conversion/TorchToTosaLinalg/TorchToTosaLinalg.cpp#L63. Instead of relying on the existing `TorchToTosa` and `TorchToLinalg` passes as individual passes, I'm calling the rewrite patterns back to back in the same pass. In this case, my understanding is that `populateTorchToTosaConversionPatterns` will handle A' but leave any ops in A'' as they are. Since we have not marked any ops in A as illegal, there won't be any pass failure even if ops in A'' are present after running the `TorchToTosa` pass patterns. Assuming A'' is supported by the `TorchToLinalg` pass, it will then be processed correctly and lowered to `linalg` + other dialects as part of `populateTorchToLinalgOnTensorsPatternsAndLegality`. The whole pipeline will fail if A'' is not handled by `populateTorchToLinalgOnTensorsPatternsAndLegality`, since we have `torch` as an illegal dialect in the new pass as well. For sanity, I verified that this new pass does handle the `AvgPool2D` op with `count_pad` set to `true`, which `TorchToTosa` cannot support correctly (#3822) -- I see the same IR generated for `TorchToLinalg` and the new pass.
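For readers following along, a condensed sketch of that combined pass body (the pass class name and exact populate-function signatures are approximations, not verbatim from the linked branch):

```cpp
// Sketch of a combined Torch -> (TOSA + Linalg) conversion pass, roughly
// following the pipeline described above. Signatures are approximate.
void ConvertTorchToTosaLinalg::runOnOperation() {
  mlir::MLIRContext *ctx = &getContext();

  mlir::ConversionTarget target(*ctx);
  target.addLegalDialect<mlir::tosa::TosaDialect, mlir::linalg::LinalgDialect,
                         mlir::tensor::TensorDialect,
                         mlir::arith::ArithDialect>();
  // Any torch op surviving both pattern sets makes the pass fail.
  target.addIllegalDialect<mlir::torch::Torch::TorchDialect>();

  mlir::TypeConverter typeConverter;
  // ... backend type conversions (torch tensor types -> builtin tensors) ...

  mlir::RewritePatternSet patterns(ctx);
  // A': torch ops TorchToTosa can handle.
  populateTorchToTosaConversionPatterns(typeConverter, patterns);
  // A'' + B: everything else goes through the Linalg-on-tensors path.
  populateTorchToLinalgOnTensorsPatternsAndLegality(typeConverter, patterns,
                                                    target);

  if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                std::move(patterns))))
    signalPassFailure();
}
```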
> My thought is that since this pass pipeline already specifies torch as an illegal dialect, we don't need to maintain this individual list of illegal ops.

That would not quite be the same behavior. Leveraging the originally defined classes, i.e.:
A': the op variants supported by TorchToTosa
A" + B: variants of A not supported by TorchToTosa, plus additional ops B not supported in any manner (e.g. aten.sort)
Right now, the pass will return TOSA+Torch when presented with a model containing A' (which would convert to TOSA) + B (left alone). It will not fail but will simply partially convert. If the Torch dialect is made illegal, B would be illegal and the pass would fail. That's materially different behavior.
So you'd want to retain the explicit list A. The main problem is that we currently do not disambiguate A' from A" when doing addIllegalOp(). That affects your ability to put things into one pipeline. If you add controllability of addIllegalOp such that it's only done if the conversion succeeded, then you can craft your pipeline synthetically by just sequencing TorchToTosa before TorchToLinalg, since it'll just work.
Had an offline discussion with @sjarus about the intended behavior of the `TorchToTosa` pass. Originally it was designed to partially convert the ops that are added in the conversion pattern list and leave any other op unchanged. This behavior changed when `addIllegalDialect<torch>` was added to the pass, as now the pass fails for any op that it cannot convert to Tosa. Based on our offline discussion, we decided to:
1. remove the `addIllegalDialect<torch>` call
2. keep the individual `addIllegalOp` calls
I've only done 2 here, by returning a list of illegal ops from `populateTorchToTosaConversionPatternsAndIllegalOps` that the pass author can then mark as illegal. I think this also allows partial conversion of A'' ops by marking them as `dynamicallyIllegal` based on the configurations of such ops that are supported by the conversion pass -- this is to be explored more in a future PR.
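Roughly, this is the shape of consuming that returned list in the pass; the return type and exact signature here are my guesses for illustration, not necessarily what this PR lands:

```cpp
// Sketch: the populate function registers the TOSA patterns and reports the
// op names it handled; the pass author then decides how to mark them.
std::set<mlir::StringRef> illegalOps =
    populateTorchToTosaConversionPatternsAndIllegalOps(typeConverter,
                                                       patterns);
for (mlir::StringRef opName : illegalOps)
  target.setOpAction(mlir::OperationName(opName, context),
                     mlir::ConversionTarget::LegalizationAction::Illegal);
```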
I haven't done 1 because, looking back at the change that introduced the `addIllegalDialect<torch>` call, the motivation was to point to the actual `torch` ops that cannot be converted in the full `torch-backend-to-tosa-backend-pipeline`.
Consider the IR:
```mlir
func.func @torch.prim.TupleConstruct() {
  %int128 = torch.constant.int 128
  %0 = torch.prim.TupleConstruct %int128 : !torch.int -> !torch.tuple<int>
  // expected-error @below {{failed to legalize operation 'torch.prim.Print' that was explicitly marked illegal}}
  torch.prim.Print(%0) : !torch.tuple<int>
  return
}
```
Without `addIllegalDialect<torch>` the failure points to `torch.constant.int`, but with it the failure points to `torch.prim.Print`, which provides more context as to the op that isn't supported by the pipeline. This insight seems valuable to me, and making the change will also cause a backward incompatibility if other devs are already familiar with this behavior. Maybe this option can be parameterized in the pass instead?
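For instance, a hedged sketch of such a knob (the option name is hypothetical and not an existing flag in torch-mlir):

```cpp
// Hypothetical pass option gating the blanket Torch-dialect illegality.
Option<bool> requireFullConversion{
    *this, "require-full-conversion",
    llvm::cl::desc("Fail if any torch op remains after this pass"),
    llvm::cl::init(true)};

// ...later, when configuring the ConversionTarget:
if (requireFullConversion)
  target.addIllegalDialect<mlir::torch::Torch::TorchDialect>();
```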
Quick comment on the latest PR: would `addDynamicallyLegalOp` be more suitable in this use case? Ops that fail to legalize in TorchToTosa would be dynamically legal downstream for your TorchToLinalg pass to handle.
I think so too. As mentioned in my last comment, I want to explore marking A'' ops as dynamicallyLegal in the new pass as a separate PR.
Hi @sjarus, a gentle reminder to review this PR when you get a chance. Thanks!
(force-pushed from e62285c to a197fdf)
(force-pushed from a197fdf to 14e3cea)
Hi @sjarus, any thoughts on this PR?
(force-pushed from 14e3cea to 6760b46)
Terribly sorry - I completely missed this PR! Please ping me on discord if you don't get a timely response from me. Reviewing this now.
(force-pushed from 0f92856 to 7e120d4)
(force-pushed from 7e120d4 to ffa472f)
This PR refactors TorchToTosa to separate the construction of legal/illegal ops and conversion patterns into their own functions.

Currently the (il)legality of the ops that are (il)legal after the conversion pass runs is embedded within the conversion pattern. Our end goal is to write a new pass pipeline that converts `torch` ops to a mix of `tosa`, `linalg`, `tensor`, etc. dialect ops. The reason we want to also emit `tosa` ops (instead of using the existing `TorchToLinalg` to emit `linalg`+`tensor`+...) is that some operations like `conv2d` encode the padding behavior in the op in `tosa`, unlike the `linalg` version -- this helps in lowering the `tosa.conv2d` to a custom implementation that does padding on the fly.

To implement this new pipeline we need to be able to separate the op (il)legality from the conversion patterns themselves. Otherwise we will hit an issue for ops like `AtenMaxDimOp`, which can be lowered to both `tosa` and `linalg + others` dialects. Not every `AtenMaxDimOp` can be lowered successfully to `tosa`, because the implementation uses `tosa.reshape`, which cannot handle multiple dynamic dimensions, whereas the `TorchToLinalg` lowering can handle it. With the current behavior the pipeline will stop as soon as the existing `TorchToTosa` conversion runs, since `AtenMaxDimOp` will be marked as an illegal op.

Essentially we want to be able to control what the legality of the ops should be independently of the conversion pattern. This is also in line with the conversion patterns in the llvm-mlir repo, such as https://github.com/llvm/llvm-project/blob/000e790be35b77a01872851646d54432a203542c/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp#L718
"THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY."