diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 882c7d889d4d..cdac8508b468 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -1084,24 +1084,26 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = cast(self.getType()); + auto selfTy = dyn_cast(self.getType()); + auto outType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + Value expTensor = adaptor.getExponent(); + auto expTensorTy = dyn_cast(expTensor.getType()); - if (!selfTy) + if (!selfTy || !outType || !expTensorTy) { return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Pow"); + } - if (!isa(selfTy.getElementType())) + if (!isa(selfTy.getElementType())) { return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization supported"); + } - auto outType = - cast(getTypeConverter()->convertType(op.getType())); - - Value expTensor = adaptor.getExponent(); - if (expTensor.getType() != selfTy) { + if (expTensorTy.getElementType() != selfTy.getElementType()) { expTensor = rewriter.createOrFold( op->getLoc(), - RankedTensorType::get(outType.getShape(), selfTy.getElementType()), + RankedTensorType::get(expTensorTy.getShape(), selfTy.getElementType()), expTensor); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index fc5102e63b19..3606e48d8996 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1993,6 +1993,7 @@ "ElementwisePowTensorBroadcastModule_basic", "ElementwisePowTensorBroadcastStaticModule_basic", "ElementwisePowTensorModule_basic", + "ElementwisePowTensorStaticModule_basic", "ElementwisePreluModule_basic", "ElementwisePreluStaticModule_basic", "ElementwiseRad2DegModule_basic",