Optimization for Roberta unstick->reshape->transpose->reshape->stick #3056

Open · wants to merge 24 commits into base: main
48 changes: 34 additions & 14 deletions docs/Dialects/zhigh.md
@@ -782,13 +782,6 @@ Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>op_type</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
@@ -814,13 +807,6 @@ Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>op_type</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
@@ -857,6 +843,40 @@ Effects: `MemoryEffects::Effect{}`
| :----: | ----------- |
| `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH

### `zhigh.Reshape` (::onnx_mlir::zhigh::ZHighReshapeOp)

_ZHigh Reshape operation for Z Tensors_

ZHigh operation that converts a Z Tensor from one type to an equivalent type
with a provided shape. The data is never copied or modified. When no layout is specified,
the output preserves the same layout as the source input.

Traits: `AlwaysSpeculatableImplTrait`

Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>layout</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `source` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
| `shape` | tensor of 64-bit signless integer values

#### Results:

| Result | Description |
| :----: | ----------- |
| `result` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
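
The description above can be sketched as follows (hypothetical example; the shapes, layout attribute encoding, and SSA names are illustrative, not taken from this repository's tests):

```mlir
// Reinterpret an 8x64 2D Z tensor as a 2x4x64 3DS Z tensor. No data is
// copied or modified; only the logical shape (and, here, the layout) changes.
%shape = onnx.Constant dense<[2, 4, 64]> : tensor<3xi64>
%r = "zhigh.Reshape"(%x, %shape) {layout = "3DS"}
    : (tensor<8x64xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<3xi64>)
    -> tensor<2x4x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
```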

### `zhigh.Sigmoid` (::onnx_mlir::zhigh::ZHighSigmoidOp)

_ZHigh Sigmoid operation_
25 changes: 23 additions & 2 deletions docs/Dialects/zlow.md
@@ -770,7 +770,6 @@ Traits: `MemRefsNormalizable`
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>layout</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
<tr><td><code>op_type</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:
@@ -795,7 +794,6 @@ Traits: `MemRefsNormalizable`
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>layout</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
<tr><td><code>op_type</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:
@@ -832,6 +830,29 @@ Interfaces: `MemoryEffectOpInterface`
| `shape` | memref of 64-bit signless integer values
| `Out` | memref of dlfloat16 type values

### `zlow.reshape` (::onnx_mlir::zlow::ZLowReshapeOp)

_ZLow Reshape operation_

ZLow operation to perform a reshape (no data movement).

Traits: `MemRefsNormalizable`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>layout</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `X` | memref of dlfloat16 type values
| `shape` | memref of 64-bit signless integer values
| `Out` | memref of dlfloat16 type values
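
As an analogy only (NumPy, not the actual zlow semantics): a reshape that performs no data movement is a metadata-only change over the same underlying buffer, so the result aliases the source.

```python
import numpy as np

# Analogy for a no-data-movement reshape: y is a view over x's buffer,
# so a write through x is visible through y.
x = np.arange(12, dtype=np.float16)
y = x.reshape(3, 4)
assert y.base is x            # view, not a copy
x[0] = 42.0
assert float(y[0, 0]) == 42.0 # same memory, new logical shape
```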

### `zlow.sigmoid` (::onnx_mlir::zlow::ZLowSigmoidOp)

_ZLow sigmoid operation_
@@ -95,17 +95,6 @@ def replaceONNXBatchNormalizationInferenceModePattern : Pattern<
//
//===----------------------------------------------------------------------===//


> **Review comment (Collaborator, Author):** migrated the code elsewhere so that it can be reused, as it was needed to support the reshape op.

// Create an ONNX Shape Op with type
def CreateShapeOp: NativeCodeCall<
"$_builder.create<mlir::ONNXShapeOp>($_loc, $0, $1, IntegerAttr(), 0)"
>;

// Get a type for a tensor that stores the shape of another tensor.
def GetShapeTypeOf: NativeCodeCall<
"RankedTensorType::get({mlir::cast<ShapedType>($0.getType()).getRank()}, $_builder.getIntegerType(64))"
>;
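
A hypothetical sketch of how these two helpers could be combined in a rewrite (the pattern and op names below are illustrative placeholders, not definitions from this file):

```tablegen
// Illustrative only: build an ONNXShapeOp whose result type is the 1-D i64
// tensor sized by the rank of $x, then feed that shape to a target op.
def IllustrativeShapeUsePattern : Pat<
  (SomeSourceOp $x),
  (SomeTargetOp $x, (CreateShapeOp (GetShapeTypeOf $x), $x))
>;
```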

// Check unidirectional broadcasting from the first to second tensor.
def IsUniBroadcastingFromFirstToSecond: Constraint<
CPred<"isUniBroadcatableFirstToSecond($0, $1)">,
45 changes: 43 additions & 2 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
@@ -1093,6 +1093,47 @@ struct ZHighToZLowUnaryOpLowering : public ConversionPattern {
}
};

// Reshape operation. Code similar to unary lowering, except that we use the
// operation's specialized shape here.
struct ZHighToZLowReshapeOpLowering : public ConversionPattern {
ZHighToZLowReshapeOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
: ConversionPattern(ZHighReshapeOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
Value input = operands[0];

// Helper builders.
MultiDialectBuilder<IndexExprBuilderForKrnl> create(rewriter, loc);

// Convert ZTensor type to MemRefType.
ZMemRefType zMemRefType =
convertZTensorToMemRefType(*op->result_type_begin());

// Shape helper.
ZHighReshapeOpShapeHelper shapeHelper(op, operands, &create.krnlIE);
shapeHelper.computeShapeAndAssertOnFailure();
SmallVector<IndexExpr, 4> &dims = shapeHelper.getOutputDims();

// Allocate a buffer for the result MemRef. Follow this pattern to be
// similar to all the other zlow patterns. Will remove the alloc when
// lowering zlow.reshape to memref.reinterpret_cast once memrefs are
// normalized. See code in ReshapeToReinterpretCastPattern.
Value alloc = insertAllocForZMemRef(zMemRefType, dims, op, rewriter);

// Note: we do not need to save the shape of the original operation, as this
// reshape is a no-op that logically reinterprets the same data under two
// equivalent shapes for the given layout.

// Emit a ZLow operation.
rewriter.create<ZLowReshapeOp>(
loc, input, /* shape,*/ alloc, zMemRefType.layout);
rewriter.replaceOp(op, alloc);
return success();
}
};
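
Schematically, this pattern rewrites along these lines (the IR below is an illustrative sketch, not output captured from the compiler; the shape operand is omitted from `zlow.reshape` to mirror the commented-out argument above):

```mlir
// Before lowering: a zhigh.Reshape on a stickified tensor.
%r = "zhigh.Reshape"(%x, %shape)
    : (tensor<8x64xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<3xi64>)
    -> tensor<2x4x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>

// After lowering: a buffer is allocated for the result, then zlow.reshape is
// emitted; the alloc is expected to fold into a memref.reinterpret_cast once
// memrefs are normalized (see ReshapeToReinterpretCastPattern).
%alloc = memref.alloc() : memref<2x4x64xf16, #map3DS>
"zlow.reshape"(%x_memref, %alloc) {layout = "3DS"}
    : (memref<8x64xf16, #map2D>, memref<2x4x64xf16, #map3DS>) -> ()
```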

//===----------------------------------------------------------------------===//
// Lower ZHigh ReduceMax/ReduceMin to ZLow ReduceMax/ReduceMin
//===----------------------------------------------------------------------===//
@@ -1117,8 +1158,6 @@ struct ZHighToZLowReduceOpLowering : public ConversionPattern {
: ConversionPattern(OP_TYPE::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
MLIRContext *context = rewriter.getContext();
OP_TYPE reduceOp = mlir::cast<OP_TYPE>(op);
Location loc = op->getLoc();
Value data = operands[0];

@@ -2285,6 +2324,8 @@ void populateZHighToZLowConversionPattern(mlir::RewritePatternSet &patterns,
patterns.insert<ZHighToZLowUnaryOpLowering<ZHighTanhOp>>(typeConverter, ctx);
patterns.insert<ZHighToZLowUnaryOpLowering<ZHighSigmoidOp>>(
typeConverter, ctx);
// Reshape operations.
patterns.insert<ZHighToZLowReshapeOpLowering>(typeConverter, ctx);
// Neural network operations.
patterns.insert<ZHighToZLowReduceOpLowering<ZHighReduceMaxOp>>(
typeConverter, ctx);
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt
@@ -26,6 +26,7 @@ add_onnx_mlir_library(OMZHighOps
ZHighOps/QuantizedMatMul/QuantizedMatMul.cpp
ZHighOps/QuantizedStick/QuantizedStick.cpp
ZHighOps/Reduction/Reduction.cpp
ZHighOps/Reshape/Reshape.cpp
ZHighOps/Softmax/Softmax.cpp
ZHighOps/Stick/Stick.cpp
ZHighOps/StickForGRU/StickForGRU.cpp
24 changes: 24 additions & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
@@ -1195,4 +1195,28 @@ def ZHighFixGRUYhOp:ZHigh_Op<"FixGRUYh", [Pure,
}];
}

def ZHighReshapeOp:ZHigh_Op<"Reshape", [Pure,
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
let summary = "ZHigh Reshape operation for Z Tensors";
let description = [{
ZHigh operation that converts a Z Tensor from one type to an equivalent type
with a provided shape. The data is never copied or modified. When no layout is specified,
the output preserves the same layout as the source input.
}];
let arguments = (ins AnyTypeOf<[AnyZTensor]>:$source, // Input Z Tensor to be reshaped.
TensorOf<[I64]>:$shape, // Shape of output Z Tensor.
OptionalAttr<StrAttr>:$layout); // Layout of output Z Tensor, default same as input.
let results = (outs AnyTypeOf<[AnyZTensor]>:$result);

let extraClassDefinition = [{
onnx_mlir::ONNXOpShapeHelper * ZHighReshapeOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef<mlir::Value> oper,
onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
onnx_mlir::ONNXOpShapeHelper *sh = new ZHighReshapeOpShapeHelper(op, oper, ieb, scope);
assert(sh && "failed to allocate shape helper");
return sh;
}
}];
}

#endif // ZHIGH_OPS