Optimization for Roberta unstick->reshape->transpose->reshape->stick #3056

Open · @AlexandreEichenberger wants to merge 24 commits into main
Conversation

@AlexandreEichenberger (Collaborator) commented Jan 29, 2025

In some situations, a sequence of transformations is a "no-op" under a given ztensor representation.

The pattern exploited here is, for the 3DS layout, (A, B, C*D) <-> (A*C, B, D): the two shapes are equivalent in stickified memory when B % 32 = 0 and D % 64 = 0.
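
To make this equivalence concrete, here is a small standalone C++ check (mine, not part of the PR) that brute-forces the equality of stickified offsets. It assumes the 3DS stickification map used for ztensors, (d0, d1, d2) -> (d0, d2 floordiv 64, d1 floordiv 32, d1 mod 32, d2 mod 64), linearized row-major over the tile dimensions, with no tile padding since B % 32 = 0 and D % 64 = 0.

    #include <cassert>
    #include <cstdint>

    // Linearized offset of element (i3, i2, i1) in a stickified 3DS tensor of
    // shape (e3, e2, e1): tiles of 32x64 elements, laid out as
    // [e3][e1/64][e2/32][32][64]. Assumes e2 % 32 == 0 and e1 % 64 == 0.
    int64_t offset3DS(int64_t e2, int64_t e1, int64_t i3, int64_t i2, int64_t i1) {
      int64_t tilesE1 = e1 / 64, tilesE2 = e2 / 32;
      return ((i3 * tilesE1 + i1 / 64) * tilesE2 + i2 / 32) * (32 * 64) +
             (i2 % 32) * 64 + (i1 % 64);
    }

    int main() {
      // One concrete instance with B % 32 == 0 and D % 64 == 0.
      const int64_t A = 2, B = 64, C = 3, D = 128;
      for (int64_t a = 0; a < A; ++a)
        for (int64_t b = 0; b < B; ++b)
          for (int64_t c = 0; c < C; ++c)
            for (int64_t d = 0; d < D; ++d)
              // (a, b, c*D + d) in shape (A, B, C*D) maps to the same
              // stickified location as (a*C + c, b, d) in shape (A*C, B, D).
              assert(offset3DS(B, C * D, a, b, c * D + d) ==
                     offset3DS(B, D, a * C + c, b, d));
      return 0;
    }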

The patterns detected and transformed into a zhigh/zlow reshape are the following:

[image: first detected pattern]

and

[image: second detected pattern]

A high-level proof is given below. For a detailed proof, one has to follow every step of the transformations above and show equality of memory accesses, namely that when accessing (e3, e2, e1), we get the same memory location in the original 3DS tensor as in the final 3DS tensor in the above examples.

[image: high-level proof]

In practice, this PR adds 2 rules to catch the above 2 patterns and replace them with a zhigh.Reshape, which is similar to memref.reshape in that it performs no data layout transformation; it just provides a mapping between two equivalent shapes. The ZHigh version establishes this equivalence for zTensor formats such as 3DS.
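
For illustration, below is a condensed sketch of what one such rule could look like as a C++ OpRewritePattern. The op class names (ZHighStickOp, ZHighUnstickOp, ONNXReshapeOp, ONNXTransposeOp, and the new ZHighReshapeOp) follow onnx-mlir naming, but the accessors, the permutation/shape checks, and the actual structure of the PR's rules may differ; treat it as a sketch, not the PR's code.

    // Hypothetical sketch: match stick(reshape(transpose(reshape(unstick(X)))))
    // and replace the whole chain with a single metadata-only zhigh.Reshape.
    struct RecomposeToZHighReshape : public OpRewritePattern<ZHighStickOp> {
      using OpRewritePattern<ZHighStickOp>::OpRewritePattern;

      LogicalResult matchAndRewrite(
          ZHighStickOp stickOp, PatternRewriter &rewriter) const override {
        // Walk up the chain from the final stick.
        auto reshape2 = stickOp->getOperand(0).getDefiningOp<ONNXReshapeOp>();
        if (!reshape2)
          return failure();
        auto transpose = reshape2.getData().getDefiningOp<ONNXTransposeOp>();
        if (!transpose)
          return failure();
        auto reshape1 = transpose.getData().getDefiningOp<ONNXReshapeOp>();
        if (!reshape1)
          return failure();
        auto unstick = reshape1.getData().getDefiningOp<ZHighUnstickOp>();
        if (!unstick)
          return failure();
        Value stickified = unstick->getOperand(0);
        // Static shapes only (the PR's current restriction), plus the layout
        // conditions B % 32 == 0 and D % 64 == 0 and the expected transpose
        // permutation; those checks are elided here.
        if (!cast<ShapedType>(stickified.getType()).hasStaticShape())
          return failure();
        rewriter.replaceOpWithNewOp<ZHighReshapeOp>(
            stickOp, stickOp.getResult().getType(), stickified);
        return success();
      }
    };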

The ZHigh reshape operation is lowered to the equivalent ZLow reshape operation, which is then transformed into a memref.reinterpret_cast operation after all memrefs are normalized.
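
A rough sketch of that last step, with a hypothetical helper (again, not the PR's code, and assuming the usual mlir/llvm includes and namespaces): once the input memref has been normalized to an identity layout, the target shape's row-major strides can be computed directly and the reshape folds into a memref.reinterpret_cast.

    // Illustrative only: fold a shape-only reshape into memref.reinterpret_cast,
    // assuming `input` has already been normalized to an identity layout.
    static Value reshapeViaReinterpretCast(PatternRewriter &rewriter,
        Location loc, Value input, ArrayRef<int64_t> outShape) {
      // Row-major strides for the target shape, e.g. sizes [96, 384, 64]
      // give strides [24576, 64, 1], as in the lit test excerpt further below.
      SmallVector<int64_t> strides(outShape.size(), 1);
      for (int64_t i = (int64_t)outShape.size() - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * outShape[i + 1];
      auto inType = cast<MemRefType>(input.getType());
      auto resultType = MemRefType::get(outShape, inType.getElementType());
      return rewriter.create<memref::ReinterpretCastOp>(loc, resultType, input,
          /*offset=*/0, /*sizes=*/outShape, /*strides=*/strides);
    }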

The PR adds lit tests: one to catch the patterns listed above, one for the ZHigh-to-ZLow conversion, and one for the ZLow-to-memref lowering.

I checked that the values generated by Roberta with and without this PR were the same. Performance measurements show that in Roberta, the number of transposes was reduced from 48 to 12 (with stick/unstick counts also reduced by 36 operations). The time spent in transpose, stick, and unstick was reduced by 9%, 33%, and 37%, respectively. Overall (with one NNPA and one CPU), the total time was reduced by 4%.

At this time, this PR is restricted to static shapes.

Signed-off-by: Alexandre Eichenberger <[email protected]>
@@ -95,17 +95,6 @@ def replaceONNXBatchNormalizationInferenceModePattern : Pattern<
//
//===----------------------------------------------------------------------===//


@AlexandreEichenberger (Collaborator, Author) commented:

migrated the code elsewhere so that it can be reused, as it was needed to support the reshape op.

@@ -402,8 +473,8 @@ AffineMapAttr getTiling2DTo4DMap(OpBuilder &b, Value val) {
return AffineMapAttr::get(map);
}

-AffineMapAttr getTiling3DTo4DMap(OpBuilder &b, Value val) {
-  assert(isTiling3DTo4D(val) &&
+AffineMapAttr getLeftmostTiling3DTo4DMap(OpBuilder &b, Value val) {
@AlexandreEichenberger (Collaborator, Author) commented:

Many of the prior operations were not specific about whether they applied to the rightmost or leftmost position, as only one variant was needed. As I added more, I made all names more explicit.

IndexExprScope currScope(&rewriter, loc);
// Here, cannot use the shape found in the reshape op, as it is the original
// shape before memref normalization.
Value input = reshapeOp.getX();
@tungld (Collaborator) commented:

We should check if X is normalized or not before processing further, by using something like this:

    // Input must have no affine layout. In other words, it has been normalized.
    if (hasNonIdentityLayout(input.getType()))
      return failure();

Without this check, I see that, in your lit test zlow-rewrite.mlir, zlow.reshape with affine_maps is still lowered to memref.reinterpret_cast.

@AlexandreEichenberger (Collaborator, Author) replied:

Got it, I did not realize that zlow-rewrite ran twice. It's now fixed.

// CHECK-LABEL: func.func @handle_zlow_reshape
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<8x384x768xf16, #map>, [[PARAM_1_:%.+]]: memref<96x64x384xf16, #map>) -> memref<96x384x384xf16, #map> {

// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [96, 384, 64], strides: [24576, 64, 1] : memref<8x384x768xf16, #map> to memref<96x384x64xf16>
@tungld (Collaborator) commented:

It does not look like what we are expecting since the input memref is not normalized.
To check this case, you can

  • add a new check for this case where we call --normalize-memrefs, by adding this line to the top of this file:
// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --normalize-memrefs --zlow-rewrite --canonicalize %s -split-input-file | FileCheck %s --check-prefix=RESHAPE
  • then, replace CHECK by RESHAPE in CHECK-DAG, CHECK-LABEL, ... since we use the prefix RESHAPE for this check.

@AlexandreEichenberger (Collaborator, Author) replied:

I simply wrote two versions of that test, one without and one with memrefs normalized, and checked that the pattern only applies when memrefs are normalized.

Signed-off-by: Alexandre Eichenberger <[email protected]>
@AlexandreEichenberger (Collaborator, Author) commented:

Thanks @tungld for the feedback, implemented both suggestions.

@tungld (Collaborator) left a comment:

LGTM!

Glad to see the performance improvement!
