
[TorchToLinalg] Add aten.fft_rfft and lowering #3857

Status: Open · wants to merge 11 commits into base: main
Conversation

@giacs-epic (Contributor) commented Nov 7, 2024

  • Add AtenFftRfftOp to Torch dialect.
  • Add conversion of AtenFftRfftOp to Linalg, using one linalg.matmul per output component (real and imaginary). Computing the DFT this way is O(n^2).
  • Add decomposition of AtenFftRfftOp into Torch-level ops (same paradigm as above).
  • Add unit and end-to-end tests.
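For reference, the math behind the O(n^2) approach described above can be sketched in a few lines of numpy (an illustrative sketch of the DFT-matrix idea, not code from this PR): the one-sided real FFT is two matmuls against cosine and negative-sine coefficient matrices, one per output component.

```python
import numpy as np

def rfft_via_matmul(x):
    # One-sided real FFT as two matmuls, one per output component,
    # mirroring the coefficient-matrix approach: for output index m,
    # X[m] = sum_k x[k] * (cos(2*pi*k*m/n) - i*sin(2*pi*k*m/n)).
    n = x.shape[-1]
    out = n // 2 + 1                      # one-sided output length
    k = np.arange(n)[:, None]             # input index, shape (n, 1)
    m = np.arange(out)[None, :]           # output index, shape (1, out)
    angle = 2.0 * np.pi * k * m / n
    real = x @ np.cos(angle)              # (..., n) @ (n, out)
    imag = x @ (-np.sin(angle))
    return real + 1j * imag

x = np.random.default_rng(0).standard_normal((4, 23))
assert np.allclose(rfft_via_matmul(x), np.fft.rfft(x))
```

Both coefficient matrices have shape (fftLength, outputFftDim), so each matmul over a batch of signals costs O(n^2) per transform, as noted in the description.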

@zjgarvey (Collaborator) left a comment


Before reviewing in detail, let me see if I understand correctly what your cross-repository goal is.

  1. You include this torch-to-linalg conversion as a baseline conversion, but in IREE, you intend to have a more performant lowering to linalg_ext? I suppose this pass exists before torch-to-linalg in the torch-to-iree pipeline?
  2. Would it make more sense to add this as a decomposition of the op at the torch-dialect level? That way, other backends like Tosa and StableHLO could benefit, and we can turn the op off via the backend-legal-ops option in the torch-decompose-complex-ops pass if we want to go a different route in IREE. I plan to modify the decompose-complex-ops pass this week to be more specific in the torch-to-iree pipeline, so we can specify a backend-legal-ops set there.

@giacs-epic giacs-epic changed the title [TorchToLinalg] Add aten.ffr_rfft and lowering [TorchToLinalg] Add aten.fft_rfft and lowering Nov 11, 2024
@giacs-epic (Contributor, Author) commented Nov 11, 2024

@zjgarvey

  1. Yes, that's the goal so far. Indeed I would place a pass to lower rfft to linalg_ext before torch-to-linalg in iree.
  2. That makes a lot of sense. Would it be compatible with an eventual decomposition of aten.stft? I.e. would having both decompositions yield the following behavior: aten.stft gets decomposed into aten.fft_rffts, which in turn get decomposed into matmuls?

@zjgarvey (Collaborator)

> 2. That makes a lot of sense. Would it be compatible with an eventual decomposition of `aten.stft`? I.e. would having both decompositions yield the following behavior: `aten.stft` gets decomposed into `aten.fft_rfft`s, which in turn get decomposed into matmuls?

Yeah, precisely. There are some limitations, however. Does the higher-performance path for fft_rfft to linalg_ext apply to the same cases as this conversion? If not, we will definitely need to keep this as a torch-to-linalg conversion to catch any patterns that failed to match the conversion to linalg_ext. This is because we won't be able to go back to decompose-complex-ops after trying to convert to linalg_ext.

@zjgarvey (Collaborator)

Ah, I see you already converted this to a decomposition. Perhaps we should just do both? StableHlo and Tosa would benefit from the decomposition, which we can turn off once you add the torch-to-linalg-ext path to IREE, and then the torch-to-linalg conversion would be a final fallback if the linalg_ext path doesn't apply.

@giacs-epic (Contributor, Author) commented Nov 13, 2024

@zjgarvey The higher-performance path would apply when the input signal length is a power of 2; all other cases would need to fall back to this "naive" algorithm. Do you think it's possible to branch compilation based on the input dimension size?
Otherwise I'm open to just keeping both the decomposition and the lowering to linalg.
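As an aside, when the input length is static, the predicate such a compile-time branch would key on is simple; a hypothetical helper sketching it (not part of the PR):

```python
def is_power_of_two(n: int) -> bool:
    # True iff n is a positive power of two: a power of two has exactly
    # one bit set, so n & (n - 1) clears that bit and yields zero.
    return n > 0 and (n & (n - 1)) == 0

# Static dim sizes that could take a fast radix-2 path vs. the
# naive O(n^2) matmul fallback discussed above.
assert is_power_of_two(1024)
assert not is_power_of_two(23)
```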

@zjgarvey (Collaborator)

It might be possible to mark the op as conditionally illegal for decompose-complex-ops, but I don't think we want to go that route. Let's add both the torch-to-linalg conversion and the decomposition for now.

@giacs-epic (Contributor, Author)

@zjgarvey Added the conversion back.

@zjgarvey (Collaborator) left a comment


I have a few comments after looking closer at the code.

Comment on lines 1394 to 1398:

```cpp
if (isRealPart) {
  v = cos(v);
} else {
  v = -sin(v);
}
```
Collaborator:

nit: I'd prefer a ternary expression

Suggested change:

```cpp
v = isRealPart ? cos(v) : -sin(v);
```

giacs-epic (Contributor, Author):

Changed according to your suggestion.

Comment on lines 9055 to 9071:

```cpp
BaseTensorType lhsType = cast<BaseTensorType>(lhs.getType());
assert(lhsType && lhsType.hasSizes());
const ArrayRef<int64_t> lhsShape = lhsType.getSizes();
assert(lhsShape.size() >= 2);
BaseTensorType rhsType = cast<BaseTensorType>(rhs.getType());
assert(rhsType && rhsType.hasSizes());
const ArrayRef<int64_t> rhsShape = rhsType.getSizes();
assert(rhsShape.size() >= 2);
assert(rhsShape[rhsShape.size() - 2] == lhsShape[lhsShape.size() - 1]);

SmallVector<int64_t> resShape(lhsShape);
resShape[resShape.size() - 1] = rhsShape[rhsShape.size() - 1];

Type dtype = lhsType.getOptionalDtype();

ValueTensorType resType =
    ValueTensorType::get(rewriter.getContext(), resShape, dtype);
```
Collaborator:

I think we can avoid this helper function entirely. The asserts should never fail anyway, since you are generating the DFT coefficient matrix from the input and are already reporting match failures for unsupported cases.

giacs-epic (Contributor, Author):

Agree. Removing the function and simplifying.

Comment on lines 9144 to 9171:

```cpp
Value unsqueezeDim =
    rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-2));
auto unsqueezed = unsqueezeTensor(rewriter, op, self, unsqueezeDim);
if (failed(unsqueezed))
  return rewriter.notifyMatchFailure(op,
                                     "cannot generate unsqueezed tensor");
Value lhs = *unsqueezed;
Type dtype = inputType.getOptionalDtype();

Value real, complex;

for (const bool isRealPart : {true, false}) {

  // coeff : (fftLength x outputFftDim)
  ValueTensorType matrixType = ValueTensorType::get(
      op.getContext(), SmallVector<int64_t>{fftLength, outputFftDim},
      dtype);
  Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType,
                                        /*isRealPart=*/isRealPart);

  // X = matmul(lhs, coeff) : (D x 1 x outputFftDim)
  Value matmulRes = createBatchMatmul(rewriter, loc, lhs, coeffMatrix);

  // Y = squeeze(X, -2) : (D x outputFftDim)
  auto squeezed = squeezeTensor(rewriter, op, loc, -2, matmulRes);
  if (failed(squeezed))
    return rewriter.notifyMatchFailure(op,
                                       "cannot generate squeezed tensor");
```
Collaborator:

I don't understand why we need to sandwich the torch.aten.matmul between an unsqueeze and a squeeze. PyTorch's matmul should do what we want regardless of the size of D: (D x fftLength) * (fftLength x outputFftDim) -> (D x outputFftDim).
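To make the broadcasting point concrete, a small numpy sketch (numpy's `@` batches over leading dimensions the same way torch.matmul does for these cases; the names and shapes here are illustrative, not from the PR):

```python
import numpy as np

# Stand-ins for the shapes in question.
fft_length, output_fft_dim = 8, 5
coeff = np.ones((fft_length, output_fft_dim))   # DFT coefficient matrix

# 2-D input: (D x fftLength) @ (fftLength x outputFftDim) -> (D x outputFftDim)
x2 = np.zeros((3, fft_length))
assert (x2 @ coeff).shape == (3, output_fft_dim)

# Batched input: leading dims broadcast, still no unsqueeze/squeeze needed.
x3 = np.zeros((2, 3, fft_length))
assert (x3 @ coeff).shape == (2, 3, output_fft_dim)
```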

giacs-epic (Contributor, Author):

You are right, we don't. Changing.


```cpp
Value real, complex;

for (const bool isRealPart : {true, false}) {
```
Collaborator:

Why is the looping variable a const bool?

giacs-epic (Contributor, Author):

This is removed in the refactoring.

```cpp
    op.getContext(), SmallVector<int64_t>{fftLength, outputFftDim},
    dtype);
Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType,
                                      /*isRealPart=*/isRealPart);
```
Collaborator:

nit: remove the arg hint.

giacs-epic (Contributor, Author):

Changed in the refactoring.

```cpp
Value lhs = *unsqueezed;
Type dtype = inputType.getOptionalDtype();

Value real, complex;
```
Collaborator:

nit: I'd rename `complex` to `imaginary`, since the latter tensor represents the imaginary part of the end result, which is complex-valued.

giacs-epic (Contributor, Author):

Yes, I misused the word `complex`; `imaginary` is the correct one. Changing.

```mlir
// CHECK: %[[INTM2:.*]] = torch.constant.int -2
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[VAR2:.*]] = torch.aten.transpose.int %[[ARG0:.*]], %[[INT0:.*]], %[[INT1:.*]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32>
```
Collaborator:

For lit tests, you should only do `[[NAME:.*]]` once. Every subsequent use of a variable should be `[[NAME]]`; otherwise the variable NAME gets overridden, even if it didn't match the original use in the first place.

giacs-epic (Contributor, Author):

Thanks! Changing.

Comment on lines 9190 to 9191:

```cpp
Value stack =
    rewriter.create<AtenStackOp>(loc, stackType, sequence, cstMinusOne);
```
Collaborator:

I'm not sure how much we need to beef up the decomposition here, but it would probably be most efficient to construct the real and imaginary parts of the coeff matrix as one literal tensor of shape [fftLength, outputFftDim*2], in such a way that unflattening to [fftLength, outputFftDim, 2] gives the real and imaginary split in the last dim. Then a single torch.aten.matmul suffices, and the result can be unflattened before being converted to a complex tensor. Concatenations and matmuls are expensive, so reducing those would be ideal.
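The single-matmul layout being suggested can be sketched in numpy (an illustrative sketch of the interleaving idea, not the PR's implementation): stacking the cosine and negative-sine coefficients along the last axis and flattening to (n, out*2) means one matmul produces both components, and a reshape recovers the real/imaginary split.

```python
import numpy as np

def rfft_single_matmul(x):
    # One matmul against an interleaved (n, out*2) coefficient matrix;
    # reshaping the result to (..., out, 2) splits real/imag in the
    # last dim, matching the unflatten step described above.
    n = x.shape[-1]
    out = n // 2 + 1
    k = np.arange(n)[:, None]
    m = np.arange(out)[None, :]
    angle = 2.0 * np.pi * k * m / n
    # Stack real and imaginary coefficients, then flatten so that
    # column 2*m holds cos and column 2*m+1 holds -sin for output m.
    coeff = np.stack([np.cos(angle), -np.sin(angle)], axis=-1).reshape(n, out * 2)
    y = (x @ coeff).reshape(*x.shape[:-1], out, 2)   # one matmul, then unflatten
    return y[..., 0] + 1j * y[..., 1]

x = np.random.default_rng(0).standard_normal((4, 23))
assert np.allclose(rfft_single_matmul(x), np.fft.rfft(x))
```

This halves the number of matmuls and drops the concatenation, at the cost of a slightly less readable coefficient construction.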

giacs-epic (Contributor, Author):

I erroneously assumed that optimization passes would transform it into the optimal computation you described, but they don't. I'm also not sure how an optimizing transformation for this case should be expressed. For simplicity I'll change the decomposition to the form you suggest, although it becomes slightly less readable.

Comment on lines 1502 to 1510:

```cpp
Value realMatrix =
    getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/true);
Value real = createLinalgMatmulOnTensors(rewriter, loc, componentsType,
                                         self, realMatrix);

Value imagMatrix =
    getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/false);
Value imag = createLinalgMatmulOnTensors(rewriter, loc, componentsType,
                                         self, imagMatrix);
```
Collaborator:

I think my comment about multiple matmuls in the decomposition applies here as well. Let me know what you think about making one DFTMatmulCoeff with both real and imaginary parts.

Collaborator:

Although, in this case, the linalg generic would need to be constructed more carefully, since you wouldn't have two tensors to iterate over (it wouldn't be elementwise anymore, and you would need to fiddle with the indexing maps). Feel free to keep this conversion as-is if it seems like too much work to make the change.

giacs-epic (Contributor, Author):

For the same reason as above I think this should also be done. I will add this in the next commit.

giacs-epic (Contributor, Author):

Refactored the conversion in the last commit.
