
Optimization for Roberta unstick->reshape->transpose->reshape->stick #3056

Merged
Changes from 21 commits
Commits (25)
ee16dee
merge from remote branch
AlexandreEichenberger Dec 19, 2024
5b6b918
added files
AlexandreEichenberger Dec 19, 2024
5e7e21f
fix tests
AlexandreEichenberger Dec 19, 2024
97d871a
update
AlexandreEichenberger Dec 20, 2024
903fcb4
update
AlexandreEichenberger Jan 9, 2025
0cb084d
update
AlexandreEichenberger Jan 9, 2025
5355c04
update
AlexandreEichenberger Jan 14, 2025
a33fa39
add rudimentary roberta pattern
AlexandreEichenberger Jan 15, 2025
44c8f79
update
AlexandreEichenberger Jan 15, 2025
a9e8713
update
AlexandreEichenberger Jan 16, 2025
d311ae5
added test for size mod 32 and 64
AlexandreEichenberger Jan 16, 2025
88e7337
added lit tests
AlexandreEichenberger Jan 17, 2025
b20b23c
get second pattern
AlexandreEichenberger Jan 23, 2025
88227da
added ZHighReshapeOp
AlexandreEichenberger Jan 28, 2025
f3ac015
added all the parts
AlexandreEichenberger Jan 29, 2025
a9292ba
update
AlexandreEichenberger Jan 29, 2025
9e44254
added testing the option under test-compiler-option
AlexandreEichenberger Jan 29, 2025
e33c2d3
update
AlexandreEichenberger Jan 29, 2025
ffd8b01
testing
AlexandreEichenberger Jan 29, 2025
adf3efd
remove shape from zlow.reshape
AlexandreEichenberger Jan 29, 2025
aed20d3
added lit tests for reshape in zhigh and zlow
AlexandreEichenberger Jan 29, 2025
a27daf0
update
AlexandreEichenberger Jan 29, 2025
6c190e7
update
AlexandreEichenberger Jan 30, 2025
94f31f7
remove test
AlexandreEichenberger Jan 30, 2025
4cbead0
Merge branch 'main' into transpose-nnpa-v1
AlexandreEichenberger Jan 31, 2025
48 changes: 34 additions & 14 deletions docs/Dialects/zhigh.md
@@ -782,13 +782,6 @@ Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>op_type</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
@@ -814,13 +807,6 @@ Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>op_type</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
@@ -857,6 +843,40 @@ Effects: `MemoryEffects::Effect{}`
| :----: | ----------- |
| `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH

### `zhigh.Reshape` (::onnx_mlir::zhigh::ZHighReshapeOp)

_ZHigh Reshape operation for Z Tensors_

ZHigh operation that converts a Z Tensor from one type to an equivalent type
with a provided shape. The data is never copied or modified. When no layout is
specified, the output preserves the same layout as the source input.
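
As an illustrative sketch (not taken from the PR's lit tests; the function name, tensor shapes, and the `#zhigh.layout` encoding spelling are assumptions), a reshape between two equivalent 3DS shapes might look like:

```mlir
// Hypothetical IR: reshape a 3DS Z tensor from 12x16x64 to 12x64x16.
// All names, shapes, and the encoding syntax are illustrative assumptions.
func.func @reshape_sketch(
    %src: tensor<12x16x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>,
    %shape: tensor<3xi64>)
    -> tensor<12x64x16xf16, #zhigh.layout<{dataLayout = "3DS"}>> {
  // No data is copied; the result merely reinterprets the source buffer.
  %res = "zhigh.Reshape"(%src, %shape)
      : (tensor<12x16x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<3xi64>)
      -> tensor<12x64x16xf16, #zhigh.layout<{dataLayout = "3DS"}>>
  return %res : tensor<12x64x16xf16, #zhigh.layout<{dataLayout = "3DS"}>>
}
```

Since no `layout` attribute is given here, the result keeps the source's 3DS layout.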

Traits: `AlwaysSpeculatableImplTrait`

Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>layout</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `source` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
| `shape` | tensor of 64-bit signless integer values

#### Results:

| Result | Description |
| :----: | ----------- |
| `result` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH

### `zhigh.Sigmoid` (::onnx_mlir::zhigh::ZHighSigmoidOp)

_ZHigh Sigmoid operation_
25 changes: 23 additions & 2 deletions docs/Dialects/zlow.md
@@ -770,7 +770,6 @@ Traits: `MemRefsNormalizable`
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>layout</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
<tr><td><code>op_type</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:
@@ -795,7 +794,6 @@ Traits: `MemRefsNormalizable`
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>layout</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
<tr><td><code>op_type</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:
@@ -832,6 +830,29 @@ Interfaces: `MemoryEffectOpInterface`
| `shape` | memref of 64-bit signless integer values
| `Out` | memref of dlfloat16 type values

### `zlow.reshape` (::onnx_mlir::zlow::ZLowReshapeOp)

_ZLow Reshape operation_

ZLow operation to perform a reshape (no data movement).

Traits: `MemRefsNormalizable`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>layout</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `X` | memref of dlfloat16 type values
| `shape` | memref of 64-bit signless integer values
| `Out` | memref of dlfloat16 type values
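
A hand-written sketch of the op's form, following the operand list above (the memref shapes, SSA names, and layout value are assumptions, not taken from the PR's tests):

```mlir
// Hypothetical IR; shapes and names are illustrative only. Out is an
// output operand (buffer semantics), so the op returns no results.
"zlow.reshape"(%X, %shape, %Out) {layout = "3DS"}
    : (memref<12x16x64xf16>, memref<3xi64>, memref<12x64x16xf16>) -> ()
```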

### `zlow.sigmoid` (::onnx_mlir::zlow::ZLowSigmoidOp)

_ZLow sigmoid operation_
@@ -95,17 +95,6 @@ def replaceONNXBatchNormalizationInferenceModePattern : Pattern<
//
//===----------------------------------------------------------------------===//


Review comment from the Collaborator Author: migrated the code elsewhere so that it can be reused, as it was needed to support the reshape op.

// Create an ONNX Shape Op with type
def CreateShapeOp: NativeCodeCall<
"$_builder.create<mlir::ONNXShapeOp>($_loc, $0, $1, IntegerAttr(), 0)"
>;

// Get a type for a tensor that stores the shape of another tensor.
def GetShapeTypeOf: NativeCodeCall<
"RankedTensorType::get({mlir::cast<ShapedType>($0.getType()).getRank()}, $_builder.getIntegerType(64))"
>;

// Check unidirectional broadcasting from the first to second tensor.
def IsUniBroadcastingFromFirstToSecond: Constraint<
CPred<"isUniBroadcatableFirstToSecond($0, $1)">,
45 changes: 43 additions & 2 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
@@ -1093,6 +1093,47 @@ struct ZHighToZLowUnaryOpLowering : public ConversionPattern {
}
};

// Reshape operation. Code similar to unary lowering, except that we use the
// operation's specialized shape here.
struct ZHighToZLowReshapeOpLowering : public ConversionPattern {
ZHighToZLowReshapeOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
: ConversionPattern(ZHighReshapeOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
Value input = operands[0];

// Helper builders.
MultiDialectBuilder<IndexExprBuilderForKrnl> create(rewriter, loc);

// Convert ZTensor type to MemRefType.
ZMemRefType zMemRefType =
convertZTensorToMemRefType(*op->result_type_begin());

// Shape helper.
ZHighReshapeOpShapeHelper shapeHelper(op, operands, &create.krnlIE);
shapeHelper.computeShapeAndAssertOnFailure();
SmallVector<IndexExpr, 4> &dims = shapeHelper.getOutputDims();

// Allocate a buffer for the result MemRef. Follow this pattern to be
// similar to all the other zlow patterns. Will remove the alloc when
// lowering zlow.reshape to memref.reinterpret_cast once memrefs are
// normalized. See code in ReshapeToReinterpretCastPattern.
Value alloc = insertAllocForZMemRef(zMemRefType, dims, op, rewriter);

    // Note: we do not need to save the shape of the original operation, as
    // this reshape is a no-op that logically reorganizes the data between two
    // equivalent shapes under their given layouts.

// Emit a ZLow operation.
rewriter.create<ZLowReshapeOp>(
loc, input, /* shape,*/ alloc, zMemRefType.layout);
rewriter.replaceOp(op, alloc);
return success();
}
};
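
In effect, the pattern above replaces the tensor-level op with an allocation plus a `zlow.reshape` on the buffers; a rough before/after sketch (hypothetical types and names, not taken from the PR):

```mlir
// Before (hypothetical): tensor-level reshape.
%r = "zhigh.Reshape"(%src, %shape) : (tensor<...>, tensor<3xi64>) -> tensor<...>

// After (hypothetical): buffer allocated by insertAllocForZMemRef, then the
// zlow op; note that this lowering does not pass the shape operand (see the
// commented-out /* shape, */ argument above).
%alloc = memref.alloc() : memref<...>
"zlow.reshape"(%src, %alloc) {layout = "3DS"} : (memref<...>, memref<...>) -> ()
```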

//===----------------------------------------------------------------------===//
// Lower ZHigh ReduceMax/ReduceMin to ZLow ReduceMax/ReduceMin
//===----------------------------------------------------------------------===//
@@ -1117,8 +1158,6 @@ struct ZHighToZLowReduceOpLowering : public ConversionPattern {
: ConversionPattern(OP_TYPE::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
MLIRContext *context = rewriter.getContext();
OP_TYPE reduceOp = mlir::cast<OP_TYPE>(op);
Location loc = op->getLoc();
Value data = operands[0];

@@ -2285,6 +2324,8 @@ void populateZHighToZLowConversionPattern(mlir::RewritePatternSet &patterns,
patterns.insert<ZHighToZLowUnaryOpLowering<ZHighTanhOp>>(typeConverter, ctx);
patterns.insert<ZHighToZLowUnaryOpLowering<ZHighSigmoidOp>>(
typeConverter, ctx);
// Reshape operations.
patterns.insert<ZHighToZLowReshapeOpLowering>(typeConverter, ctx);
// Neural network operations.
patterns.insert<ZHighToZLowReduceOpLowering<ZHighReduceMaxOp>>(
typeConverter, ctx);
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt
@@ -26,6 +26,7 @@ add_onnx_mlir_library(OMZHighOps
ZHighOps/QuantizedMatMul/QuantizedMatMul.cpp
ZHighOps/QuantizedStick/QuantizedStick.cpp
ZHighOps/Reduction/Reduction.cpp
ZHighOps/Reshape/Reshape.cpp
ZHighOps/Softmax/Softmax.cpp
ZHighOps/Stick/Stick.cpp
ZHighOps/StickForGRU/StickForGRU.cpp
24 changes: 24 additions & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
@@ -1195,4 +1195,28 @@ def ZHighFixGRUYhOp:ZHigh_Op<"FixGRUYh", [Pure,
}];
}

def ZHighReshapeOp:ZHigh_Op<"Reshape", [Pure,
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
let summary = "ZHigh Reshape operation for Z Tensors";
let description = [{
    ZHigh operation that converts a Z Tensor from one type to an equivalent type
    with a provided shape. The data is never copied or modified. When no layout is
    specified, the output preserves the same layout as the source input.
}];
let arguments = (ins AnyTypeOf<[AnyZTensor]>:$source, // Input Z Tensor to be reshaped.
TensorOf<[I64]>:$shape, // Shape of output Z Tensor.
OptionalAttr<StrAttr>:$layout); // Layout of output Z Tensor, default same as input.
let results = (outs AnyTypeOf<[AnyZTensor]>:$result);

let extraClassDefinition = [{
onnx_mlir::ONNXOpShapeHelper * ZHighReshapeOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef<mlir::Value> oper,
onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
onnx_mlir::ONNXOpShapeHelper *sh = new ZHighReshapeOpShapeHelper(op, oper, ieb, scope);
assert(sh && "failed to allocate shape helper");
return sh;
}
}];
}

#endif // ZHIGH_OPS