Skip to content

Commit

Permalink
Merge commit '2341f207' into matthias.update_torch_stable
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Dec 12, 2024
2 parents e98c52f + 2341f20 commit 35789a6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
20 changes: 11 additions & 9 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1084,24 +1084,26 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {

Value self = adaptor.getSelf();
auto selfTy = cast<RankedTensorType>(self.getType());
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
auto outType =
dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
Value expTensor = adaptor.getExponent();
auto expTensorTy = dyn_cast<RankedTensorType>(expTensor.getType());

if (!selfTy)
if (!selfTy || !outType || !expTensorTy) {
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Pow");
}

if (!isa<mlir::FloatType>(selfTy.getElementType()))
if (!isa<mlir::FloatType>(selfTy.getElementType())) {
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
}

auto outType =
cast<TensorType>(getTypeConverter()->convertType(op.getType()));

Value expTensor = adaptor.getExponent();
if (expTensor.getType() != selfTy) {
if (expTensorTy.getElementType() != selfTy.getElementType()) {
expTensor = rewriter.createOrFold<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(outType.getShape(), selfTy.getElementType()),
RankedTensorType::get(expTensorTy.getShape(), selfTy.getElementType()),
expTensor);
}

Expand Down
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1993,6 +1993,7 @@
"ElementwisePowTensorBroadcastModule_basic",
"ElementwisePowTensorBroadcastStaticModule_basic",
"ElementwisePowTensorModule_basic",
"ElementwisePowTensorStaticModule_basic",
"ElementwisePreluModule_basic",
"ElementwisePreluStaticModule_basic",
"ElementwiseRad2DegModule_basic",
Expand Down

0 comments on commit 35789a6

Please sign in to comment.