Skip to content

Commit

Permalink
[TOSA] Add promote type to unary ops and aten.cat lowering (llvm#3860)
Browse files Browse the repository at this point in the history
Change-Id: I2699bf9007723fe629edb1c524c10ef8142e0234

Signed-off-by: Justin Ngo <[email protected]>
  • Loading branch information
justin-ngo-arm authored Nov 8, 2024
1 parent b6f04fa commit 8eb34da
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
16 changes: 12 additions & 4 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,16 @@ class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
// Lower an AtenOpT unary op to its TosaOpT counterpart, promoting the input
// element type to the converted result type first (TOSA unary ops require
// operand and result dtypes to match, unlike the Torch ops, which may
// implicitly promote e.g. integer inputs to floating-point results).
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto self = adaptor.getSelf();

  // Converted (TOSA-side) result type; drives the promotion target.
  auto outType = dyn_cast<TensorType>(
      OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
          op.getType()));

  // Cast the operand up to the result dtype so the TOSA op type-checks.
  self = tosa::promoteType(rewriter, self, outType);

  rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, self);

  return success();
}
};
Expand Down Expand Up @@ -6091,6 +6096,9 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
auto builtinTensors =
getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);

for (auto &tensor : builtinTensors)
tensor = tosa::promoteType(rewriter, tensor, outType);

auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim));
rewriter.replaceOp(op, result.getResult());
Expand Down
27 changes: 16 additions & 11 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1744,6 +1744,12 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
"ElementwiseRsqrtIntModule_basic",
"ElementwiseSinIntModule_basic",
"FloatPowerTensorTensorStaticModule_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"CollapseAllDimensionsModule_basic",
"CollapseRank1DynamicModule_basic",
Expand Down Expand Up @@ -1786,7 +1792,6 @@
"SliceCopy_Module_basic",
"Threshold1dIntModule_basic",
"Threshold2dIntModule_basic",
"Threshold3dIntModule_basic",
"EmptyModule_contiguous",
"EmptyModule_defaultDtype",
"EmptyModule_falsePinMemory",
Expand Down Expand Up @@ -2435,6 +2440,7 @@
TOSA_PASS_SET
| {
### Tests additionally passing in make_fx_tosa
"IsInfiniteModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"ResNet18StaticModule_basic",
Expand Down Expand Up @@ -2510,6 +2516,8 @@
}
) - {
### Test failing in make_fx_tosa but not in tosa
"AdaptiveMaxPool1dDimOneStatic_basic",
"FloatPowerTensorTensorStaticModule_basic",
# Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d",
# Unimplemented operator 'aten._index_put_impl_.hacked_twin'
Expand Down Expand Up @@ -3390,6 +3398,11 @@
}

FX_IMPORTER_TOSA_XFAIL_SET = {
"IsInfiniteModule_basic",
"LayerNormFwAndBwModule_basic",
"LayerNormManualFwAndBwModule_basic",
"SelfAttentionFwAndBwModule_basic",
"Threshold3dIntModule_basic",
"ElementwiseCopysignModule_basic",
"ElementwiseSignbitModule_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
Expand Down Expand Up @@ -3417,9 +3430,6 @@
"AtenPolarDoubleModule_basic",
"AtenPolarFloatModule_basic",
"HstackBasicComplexModule_basic",
"HstackBasicFloatModule_basic",
"HstackBasicIntFloatModule_basic",
"HstackBasicIntModule_basic",
"AtenIntMM_basic",
"AtenKthvalueDynamicDimsModule_basic",
"AtenKthvalueFloat64DynamicDimsModule_basic",
Expand Down Expand Up @@ -3597,8 +3607,6 @@
"ElementwiseAtanTensorIntModule_basic",
"ElementwiseAtanhIntModule_basic",
"ElementwiseAtanhModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseCosIntModule_basic",
"ElementwiseCoshIntModule_basic",
"ElementwiseCoshModule_basic",
"ElementwiseDequantizePerChannelModule_basic",
Expand All @@ -3620,10 +3628,7 @@
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseReciprocalIntModule_basic",
"ElementwiseRsqrtIntModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseSinIntModule_basic",
"ElementwiseSinhIntModule_basic",
"ElementwiseSinhModule_basic",
"ElementwiseTanIntModule_basic",
Expand Down Expand Up @@ -3850,8 +3855,6 @@
"TensorToFloat_basic",
"TensorToIntZeroRank_basic",
"TensorToInt_basic",
"TensorsConcatPromoteDTypeModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
"ThresholdBackward2dMixedModule_basic",
"ToCopyWithDTypeFalsePinMemoryModule_basic",
Expand Down Expand Up @@ -3931,6 +3934,8 @@
}

ONNX_TOSA_XFAIL_SET = {
"FloatPowerTensorTensorStaticModule_basic",
"IsInfiniteModule_basic",
"ElementwiseCopysignModule_basic",
"ElementwiseFracModule_basic",
"ElementwiseLdexpModule_basic",
Expand Down

0 comments on commit 8eb34da

Please sign in to comment.