Skip to content

Commit

Permalink
TorchToTosa: Correctly lower pow with broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Dec 11, 2024
1 parent d42a648 commit 8a224ba
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
5 changes: 3 additions & 2 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1098,10 +1098,11 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
cast<TensorType>(getTypeConverter()->convertType(op.getType()));

Value expTensor = adaptor.getExponent();
if (expTensor.getType() != selfTy) {
auto expTensorTy = cast<RankedTensorType>(expTensor.getType());
if (expTensorTy.getElementType() != selfTy.getElementType()) {
expTensor = rewriter.createOrFold<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(outType.getShape(), selfTy.getElementType()),
RankedTensorType::get(expTensorTy.getShape(), selfTy.getElementType()),
expTensor);
}

Expand Down
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2161,6 +2161,7 @@
"PermuteNegativeIndexModule_basic",
"PowFloatFloatModule_basic",
"PowFloatIntModule_basic",
"PowBroadcastModule_basic",
"PrimListUnpackNumMismatchModule_basic",
"PrimsIotaModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",
Expand Down
24 changes: 24 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4551,6 +4551,30 @@ def PowFloatFloatModule_basic(module, tu: TestUtils):
# ==============================================================================


class PowBroadcastModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
([], torch.float32, True),
]
)
def forward(self, x, y):
return torch.ops.aten.pow(x, y)


@register_test_case(module_factory=lambda: PowBroadcastModule())
def PowBroadcastModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), torch.ones([]))


# ==============================================================================


class PowIntFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down

0 comments on commit 8a224ba

Please sign in to comment.