[TorchToLinalg] Add aten.fft_rfft and lowering #3857
base: main
Conversation
Before reviewing in detail, let me see if I understand correctly what your cross-repository goal is.
- You include this torch-to-linalg conversion as a baseline conversion, but in IREE you intend to have a more performant lowering to linalg_ext? I suppose this pass exists before torch-to-linalg in the torch-to-iree pipeline?
- Would it make more sense to add this as a decomposition of the op at the torch-dialect level? That way, other backends like Tosa and StableHLO could benefit, and we can turn off the op via the backend-legal-ops option in the torch-decompose-complex-ops pass if we want to go a different route in IREE. I have plans to modify the decompose-complex-ops pass to be more specific in the torch-to-iree pipeline this week, so we can specify a backend-legal-ops set there.
Yeah, precisely. There are some limitations, however. Does the higher performance path for …
Ah, I see you already converted this to a decomposition. Perhaps we should just do both? StableHlo and Tosa would benefit from the decomposition, which we can turn off once you add the …
@zjgarvey The higher-performance path would apply when the input signal length is a power of 2; all other cases would need to be translated to this "naive" algorithm. Do you think it's possible to branch compilation based on the input dimension size?
It might be possible to mark the op as conditionally illegal for …
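For the branching question above, checking whether a static dimension size is a power of 2 is cheap. A minimal sketch of such a predicate (illustrative only, not code from the patch):

```python
def is_power_of_two(n: int) -> bool:
    # A positive integer is a power of two iff exactly one bit is set,
    # i.e. clearing the lowest set bit (n & (n - 1)) yields zero.
    return n > 0 and (n & (n - 1)) == 0
```

A legality callback could use this on the statically known FFT dimension to pick the fast path or fall back to the naive matmul lowering.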
@zjgarvey Added the conversion back.
I have a few comments after looking closer at the code.
if (isRealPart) {
  v = cos(v);
} else {
  v = -sin(v);
}
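The cos / -sin split above follows from Euler's formula: entry (j, k) of the DFT matrix is e^(-2*pi*i*j*k/n) = cos(2*pi*j*k/n) - i*sin(2*pi*j*k/n). A minimal pure-Python sketch (helper names are illustrative, not from the patch) that builds both coefficient matrices and checks the matmul-based rfft against a direct DFT:

```python
import math, cmath

def dft_coeff(n, out_dim, real_part):
    # Entry (j, k): cos(2*pi*j*k/n) for the real part, -sin(2*pi*j*k/n) for the imaginary part.
    f = math.cos if real_part else (lambda v: -math.sin(v))
    return [[f(2 * math.pi * j * k / n) for k in range(out_dim)] for j in range(n)]

def rfft_via_matmul(signal):
    n = len(signal)
    out_dim = n // 2 + 1  # rfft keeps only the non-redundant half of the spectrum
    real_c = dft_coeff(n, out_dim, True)
    imag_c = dft_coeff(n, out_dim, False)
    # (1 x n) times (n x out_dim) matmuls, written out elementwise.
    real = [sum(signal[j] * real_c[j][k] for j in range(n)) for k in range(out_dim)]
    imag = [sum(signal[j] * imag_c[j][k] for j in range(n)) for k in range(out_dim)]
    return [complex(r, i) for r, i in zip(real, imag)]

def rfft_reference(signal):
    # Direct O(n^2) DFT, truncated to the rfft output length.
    n = len(signal)
    return [sum(signal[j] * cmath.exp(-2j * math.pi * j * k / n) for j in range(n))
            for k in range(n // 2 + 1)]
```

Both paths agree to floating-point precision, which is the correctness property the lowering relies on.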
nit: I'd prefer a ternary expression.
Suggested change:
v = isRealPart ? cos(v) : -sin(v);
Changed according to your suggestion.
BaseTensorType lhsType = cast<BaseTensorType>(lhs.getType());
assert(lhsType && lhsType.hasSizes());
const ArrayRef<int64_t> lhsShape = lhsType.getSizes();
assert(lhsShape.size() >= 2);
BaseTensorType rhsType = cast<BaseTensorType>(rhs.getType());
assert(rhsType && rhsType.hasSizes());
const ArrayRef<int64_t> rhsShape = rhsType.getSizes();
assert(rhsShape.size() >= 2);
assert(rhsShape[rhsShape.size() - 2] == lhsShape[lhsShape.size() - 1]);

SmallVector<int64_t> resShape(lhsShape);
resShape[resShape.size() - 1] = rhsShape[rhsShape.size() - 1];

Type dtype = lhsType.getOptionalDtype();

ValueTensorType resType =
    ValueTensorType::get(rewriter.getContext(), resShape, dtype);
I think we can avoid this helper function entirely. The asserts should never fail anyway, since you are generating the DFT coefficient matrix from the input and are already reporting match failures for unsupported cases.
Agree. Removing the function and simplifying.
Value unsqueezeDim =
    rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-2));
auto unsqueezed = unsqueezeTensor(rewriter, op, self, unsqueezeDim);
if (failed(unsqueezed))
  return rewriter.notifyMatchFailure(op,
                                     "cannot generate unsqueezed tensor");
Value lhs = *unsqueezed;
Type dtype = inputType.getOptionalDtype();

Value real, complex;

for (const bool isRealPart : {true, false}) {

  // coeff : (fftLength x outputFftDim)
  ValueTensorType matrixType = ValueTensorType::get(
      op.getContext(), SmallVector<int64_t>{fftLength, outputFftDim},
      dtype);
  Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType,
                                        /*isRealPart=*/isRealPart);

  // X = matmul(lhs, coeff) : (D x 1 x outputFftDim)
  Value matmulRes = createBatchMatmul(rewriter, loc, lhs, coeffMatrix);

  // Y = squeeze(X, -2) : (D x outputFftDim)
  auto squeezed = squeezeTensor(rewriter, op, loc, -2, matmulRes);
  if (failed(squeezed))
    return rewriter.notifyMatchFailure(op,
                                       "cannot generate squeezed tensor");
I don't understand why we need to conjugate the torch.aten.matmul with a squeeze. Pytorch's matmul should do what we want regardless of the size of D: (D x fftLength) * (fftLength x outputFftDim) -> (D x outputFftDim).
You are right, we don't. Changing.
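To see why the unsqueeze/squeeze pair is unnecessary, a plain 2-D matmul already maps a batch of D signals to a batch of D spectra. A small sketch (this `matmul` helper is illustrative, not from the patch):

```python
def matmul(a, b):
    # Plain 2-D matmul: (D x K) @ (K x M) -> (D x M); no rank manipulation needed.
    rows, inner, cols = len(a), len(b), len(b[0])
    assert all(len(row) == inner for row in a), "inner dimensions must agree"
    return [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
            for i in range(rows)]
```

With `a` of shape (D, fftLength) and `b` the (fftLength, outputFftDim) coefficient matrix, the result has shape (D, outputFftDim) directly, matching the reviewer's point.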
Value real, complex;

for (const bool isRealPart : {true, false}) {
Why is the looping variable a const bool?
This is removed in the refactoring.
    op.getContext(), SmallVector<int64_t>{fftLength, outputFftDim},
    dtype);
Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType,
                                      /*isRealPart=*/isRealPart);
nit: remove the arg hint.
Changed in the refactoring.
Value lhs = *unsqueezed;
Type dtype = inputType.getOptionalDtype();

Value real, complex;
nit: I'd rename complex to imaginary, since the latter tensor represents the imaginary part of the end result, which is complex-valued.
Yes, I misused the word complex; imaginary is the correct one. Changing.
// CHECK: %[[INTM2:.*]] = torch.constant.int -2
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[VAR2:.*]] = torch.aten.transpose.int %[[ARG0:.*]], %[[INT0:.*]], %[[INT1:.*]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32>
For lit tests, you should only do [[NAME:.*]] once. Every subsequent use of a variable should be [[NAME]]; otherwise the variable NAME gets overridden, even if it didn't match the original use in the first place.
Thanks! Changing.
Value stack =
    rewriter.create<AtenStackOp>(loc, stackType, sequence, cstMinusOne);
I'm not sure how much we need to beef up the decomposition here, but it would probably be most efficient to construct the real and imaginary parts of the coeff matrix in one literal tensor of shape [fftLength, (outputFftDim*2)], in such a way that unflattening to [fftLength, outputFftDim, 2] gives the real and imaginary split in the last dim. Then the matmul can be performed in one torch.aten.matmul, and the result can be unflattened before getting converted to a complex tensor. Concatenations and matmuls are expensive, so reducing those would be ideal.
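The fused-coefficient idea can be checked numerically in pure Python. This sketch (function names are illustrative, not from the patch) interleaves (cos, -sin) columns per frequency bin into one [fftLength, outputFftDim*2] matrix, does a single matmul, then "unflattens" the last dim to recover the complex result:

```python
import math

def combined_dft_coeff(n, out_dim):
    # One [n, out_dim*2] matrix whose columns interleave (real, imag) per bin,
    # so reshaping the last dim to [out_dim, 2] splits real/imag cleanly.
    mat = []
    for j in range(n):
        row = []
        for k in range(out_dim):
            angle = 2 * math.pi * j * k / n
            row += [math.cos(angle), -math.sin(angle)]
        mat.append(row)
    return mat

def rfft_one_matmul(signal):
    n = len(signal)
    out_dim = n // 2 + 1
    coeff = combined_dft_coeff(n, out_dim)
    # Single (1 x n) @ (n x out_dim*2) matmul, written out elementwise.
    flat = [sum(signal[j] * coeff[j][c] for j in range(n)) for c in range(out_dim * 2)]
    # Unflatten [out_dim*2] -> [out_dim, 2] and view as complex.
    return [complex(flat[2 * k], flat[2 * k + 1]) for k in range(out_dim)]
```

One matmul instead of two, and the unflatten is a free reshape, which is the cost saving the comment is after.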
I erroneously assumed that the optimization passes would transform it into the optimal computation you described, but indeed they don't. I'm also not sure how an optimizing transformation for this case should be expressed. For simplicity I'll change the decomposition to the form you suggest, although, by doing so, the decomposition becomes slightly less readable.
Value realMatrix =
    getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/true);
Value real = createLinalgMatmulOnTensors(rewriter, loc, componentsType,
                                         self, realMatrix);

Value imagMatrix =
    getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/false);
Value imag = createLinalgMatmulOnTensors(rewriter, loc, componentsType,
                                         self, imagMatrix);
I think my comment about multiple matmuls in the decomposition below applies here as well. Let me know what you think about making one DFTMatmulCoeff with both real and imaginary parts.
Although, in this case, the linalg generic would need to be constructed more carefully, since you wouldn't have two tensors to iterate over (it wouldn't be elementwise anymore, and you would need to fiddle with the indexing maps). Feel free to keep this conversion as-is if it seems like too much work to make the change.
For the same reason as above I think this should also be done. I will add this in the next commit.
Refactored the conversion in the last commit.
- Adds AtenFftRfftOp to the Torch dialect.
- Adds a conversion of AtenFftRfftOp to Linalg, using a linalg.matmul per output component (real and imaginary). Computing the DFT this way is O(n^2).
- Adds a decomposition of AtenFftRfftOp into Torch-level ops (same paradigm as above).