diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4740f2312394..0c562fbe31c0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -54,13 +54,6 @@ cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLI
 
 option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF)
 
-# TODO(#3299): migrate to from member x.cast() to mlir::cast(x).
-if(MSVC)
-  add_compile_options(/wd4996)
-else()
-  add_compile_options(-Wno-deprecated-declarations)
-endif()
-
 macro(torch_mlir_enable_werror)
   if(TORCH_MLIR_ENABLE_WERROR_FLAG)
     if(NOT MSVC)
diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td
index 8e7be05e198c..3dce86149fa8 100644
--- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td
+++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td
@@ -125,7 +125,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
         llvm::copy_if(getInputOperands(),
             std::back_inserter(result),
             [](OpOperand *opOperand) {
-              return opOperand->get().getType().template isa();
+              return isa(opOperand->get().getType());
             });
         return result;
       }]
@@ -144,7 +144,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
         llvm::copy_if(getInputOperands(),
             std::back_inserter(result),
             [](OpOperand *opOperand) {
-              return opOperand->get().getType().template isa();
+              return isa(opOperand->get().getType());
             });
         return result;
       }]
@@ -200,7 +200,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
         llvm::copy_if(getOutputOperands(),
             std::back_inserter(result),
             [](OpOperand *opOperand) {
-              return opOperand->get().getType().template isa();
+              return isa(opOperand->get().getType());
             });
         return result;
       }]
@@ -219,7 +219,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
         llvm::copy_if(getOutputOperands(),
             std::back_inserter(result),
             [](OpOperand *opOperand) {
-              return opOperand->get().getType().template isa();
+              return isa(opOperand->get().getType());
             });
         return result;
       }]
@@ -238,7 +238,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
         llvm::transform(getOutputBufferOperands(),
             std::back_inserter(result),
             [](OpOperand *opOperands) {
-              return opOperands->get().getType().cast();
+              return cast(opOperands->get().getType());
             });
         return result;
       }]
@@ -257,7 +257,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
         llvm::transform(getOutputTensorOperands(),
             std::back_inserter(result),
             [](OpOperand *opOperands) {
-              return opOperands->get().getType().cast();
+              return cast(opOperands->get().getType());
             });
         return result;
       }]
@@ -318,7 +318,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
       /*args=*/(ins "OpOperand *":$opOperand),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        if (!opOperand->get().getType().template isa())
+        if (!isa(opOperand->get().getType()))
          return false;
        if (opOperand->getOperandNumber() < $_op.getNumInputs())
          return true;
@@ -334,7 +334,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
       /*args=*/(ins "OpOperand *":$opOperand),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        if (!opOperand->get().getType().template isa())
+        if (!isa(opOperand->get().getType()))
          return false;
        if (opOperand->getOperandNumber() >= $_op.getNumInputs())
          return true;
@@ -367,7 +367,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
       /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        if (auto shapedType =
-              opOperand->get().getType().template dyn_cast()
+              dyn_cast(opOperand->get().getType()))
          return shapedType.getRank();
return 0; }] @@ -383,7 +383,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { /*defaultImplementation=*/[{ assert(opOperand->getOwner() == this->getOperation()); if (auto shapedType = - opOperand->get().getType().template dyn_cast()) + dyn_cast(opOperand->get().getType())) return shapedType.getShape(); return {}; }] @@ -398,7 +398,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { /*methodBody=*/"", /*defaultImplementation=*/[{ assert(opOperand->getOwner() == this->getOperation()); - return !opOperand->get().getType().template isa(); + return !isa(opOperand->get().getType()); }] >, //===------------------------------------------------------------------===// @@ -416,10 +416,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { return this->getOperation()->getNumResults() == 0 && llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) { return isScalar(opOperand) || - opOperand->get().getType().template isa(); + isa(opOperand->get().getType()); }) && llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); + return isa(opOperand->get().getType()); }); }] >, @@ -435,10 +435,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { return llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) { return isScalar(opOperand) || - opOperand->get().getType().template isa(); + isa(opOperand->get().getType()); }) && llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); + return isa(opOperand->get().getType()); }); }] >, @@ -478,8 +478,8 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { private: void setOperandSegmentAt(unsigned idx, unsigned val) { - auto attr = (*this)->getAttr("operand_segment_sizes") - .cast(); + auto attr = cast((*this)->getAttr("operand_segment_sizes") + ); unsigned i = 0; auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32), [&](const APInt &v) { return (i++ == idx) ? 
APInt(32, val) : v; }); diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td index 12a74faa44d3..dc745097c5fb 100644 --- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td +++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td @@ -88,7 +88,7 @@ def TMTensor_ScanOp : TMTensor_Op<"scan", return getOutputOperand(0)->get(); } ShapedType getOperandType() { - return input().getType().cast(); + return cast(input().getType()); } int64_t getOperandRank() { return getOperandType().getRank(); @@ -151,10 +151,10 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter", let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{ int64_t getIndexDepth() { - return getInputOperand(1) + return cast(getInputOperand(1) ->get() .getType() - .cast() + ) .getShape() .back(); } @@ -164,7 +164,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter", } ShapedType getUpdateType() { - return updates().getType().cast(); + return cast(updates().getType()); } Value indices() { @@ -172,7 +172,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter", } ShapedType getIndicesType() { - return indices().getType().cast(); + return cast(indices().getType()); } Value original() { @@ -180,11 +180,11 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter", } ShapedType getOriginalType() { - return original().getType().cast(); + return cast(original().getType()); } int64_t getUpdateSliceRank() { - return updates().getType().cast().getRank() - 1; + return cast(updates().getType()).getRank() - 1; } bool isScalarUpdate() { @@ -224,7 +224,7 @@ def TMTensor_SortOp : TMTensor_Op<"sort", return getOutputs()[index]; } ShapedType getOperandType(int index) { - return operand(index).getType().cast(); + return cast(operand(index).getType()); } int64_t getOperandRank() { return getOperandType(0).getRank(); @@ -291,16 +291,16 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention", return getOutputOperand(0)->get(); } ShapedType getQueryType() { - return getQuery().getType().cast(); + return cast(getQuery().getType()); } ShapedType getKeyType() { - return getKey().getType().cast(); + return cast(getKey().getType()); } ShapedType getValueType() { - return getValue().getType().cast(); + return cast(getValue().getType()); } ShapedType getOutputType() { - return getOutput().getType().cast(); + return cast(getOutput().getType()); } int64_t getQueryRank() { return getQueryType().getRank(); diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h index 919146c6a1c7..97d004e367ba 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -61,12 +61,12 @@ struct onnx_list_of_constant_ints_op_binder { bool match(Operation *op) { auto constOp = dyn_cast(op); - if (!constOp || !constOp.getName().equals("onnx.Constant")) + if (!constOp || !(constOp.getName() == "onnx.Constant")) return false; if (DenseResourceElementsAttr attr = - constOp->getAttr("torch.onnx.value") - .dyn_cast_or_null()) { + dyn_cast_or_null( + constOp->getAttr("torch.onnx.value"))) { // Bytes are stored in little endian order. Big endian support will // require swizzling. 
if (!Endian::little) { diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index f49fef0721c2..d5db519bef17 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -190,7 +190,7 @@ struct torch_list_of_optional_constant_ints_op_binder { int64_t num; if (matchPattern(value, m_TorchConstantInt(&num))) bind_values.push_back(num); - else if (value.getType().isa()) + else if (isa(value.getType())) bind_values.push_back(std::nullopt); else return false; diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index f578cefe0297..65f514c2ede9 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -442,8 +442,8 @@ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [ }]; let extraClassDeclaration = [{ - Type getKeyType() { return getType().cast().getKeyType(); } - Type getValueType() { return getType().cast().getValueType(); } + Type getKeyType() { return cast(getType()).getKeyType(); } + Type getValueType() { return cast(getType()).getValueType(); } }]; } @@ -1003,7 +1003,7 @@ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [ DeclareOpInterfaceMethods, TypesMatchWith<"operand is corresponding !torch.vtensor", "result", "operand", - "$_self.cast().getWithValueSemantics()">, + "cast($_self).getWithValueSemantics()">, ]> { let summary = "Create a !torch.tensor with the same contents as the operand"; let description = [{ @@ -1036,7 +1036,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [ DeclareOpInterfaceMethods, TypesMatchWith<"operand is corresponding !torch.tensor", "result", "operand", - "$_self.cast().getWithoutValueSemantics()">, + "cast($_self).getWithoutValueSemantics()">, ]> { let summary = "Create a !torch.vtensor with the same contents as the operand"; let description = [{ @@ -1064,7 +1064,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [ def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [ TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type", "value", "overwritten", - "$_self.cast().getWithoutValueSemantics()"> + "cast($_self).getWithoutValueSemantics()"> ]> { let summary = "Ovewrite the contents of tensor with values from another."; let description = [{ diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index e7fc4bc976bb..279e694540f9 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -199,7 +199,7 @@ def Torch_ValueTensorType : AnyTorchTensorType<"ValueTensor", "vtensor"> { } def AnyTorchTensorType : Type< - CPred<"$_self.isa<::mlir::torch::Torch::BaseTensorType>()">, + CPred<"isa<::mlir::torch::Torch::BaseTensorType>($_self)">, "Any Torch tensor type" >; @@ -410,11 +410,11 @@ def AnyTorchOptionalDeviceType: def AnyTorchOptionalGeneratorType: OptionalOf; -def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">; +def IsListTypePred : CPred<"isa<::mlir::torch::Torch::ListType>($_self)">; class ListOf allowedTypes, string descr> : ContainerType, IsListTypePred, - "$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()", + "cast<::mlir::torch::Torch::ListType>($_self).getContainedType()", descr, "::mlir::torch::Torch::ListType">; def AnyTorchListOfTorchBoolType : 
ListOf<[Torch_BoolType], "Bool list type (bool[])">; diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp index edc85c7e7d63..399915459e40 100644 --- a/lib/CAPI/TorchTypes.cpp +++ b/lib/CAPI/TorchTypes.cpp @@ -26,7 +26,7 @@ bool torchMlirTypeIsValidSubtype(MlirType subtype, MlirType type) { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchNnModule(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchNnModuleTypeGet(MlirContext context, @@ -43,7 +43,7 @@ MlirTypeID torchMlirTorchNnModuleTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchOptional(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) { @@ -64,7 +64,7 @@ MlirTypeID torchMlirTorchOptionalTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchTuple(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchTupleTypeGet(MlirContext context, @@ -95,7 +95,7 @@ MlirTypeID torchMlirTorchTupleTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchUnion(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchUnionTypeGet(MlirContext context, @@ -126,7 +126,7 @@ MlirTypeID torchMlirTorchUnionTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchList(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchListTypeGet(MlirType containedType) { @@ -146,7 +146,7 @@ MlirTypeID torchMlirTorchListTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchDevice(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchDeviceTypeGet(MlirContext context) { @@ -162,7 +162,7 @@ MlirTypeID torchMlirTorchDeviceTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchGenerator(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) { @@ -178,7 +178,7 @@ MlirTypeID torchMlirTorchGeneratorTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchBool(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchBoolTypeGet(MlirContext context) { @@ -194,7 +194,7 @@ MlirTypeID torchMlirTorchBoolTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchInt(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchIntTypeGet(MlirContext context) { @@ -210,7 +210,7 @@ MlirTypeID torchMlirTorchIntTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchFloat(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchFloatTypeGet(MlirContext context) { @@ -226,7 +226,7 @@ MlirTypeID torchMlirTorchFloatTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchLinearParams(MlirType t) { - return unwrap(t).isa(); + 
return isa(unwrap(t)); } MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) { @@ -242,7 +242,7 @@ MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchQInt8(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchQInt8TypeGet(MlirContext context) { @@ -258,7 +258,7 @@ MlirTypeID torchMlirTorchQInt8TypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchQUInt8(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) { @@ -274,7 +274,7 @@ MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchNonValueTensor(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context, @@ -341,7 +341,7 @@ MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchValueTensor(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchValueTensorTypeGet(MlirContext context, @@ -408,7 +408,7 @@ MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchNone(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchNoneTypeGet(MlirContext context) { @@ -424,7 +424,7 @@ MlirTypeID torchMlirTorchNoneTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchString(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchStringTypeGet(MlirContext context) { @@ -440,7 +440,7 @@ MlirTypeID torchMlirTorchStringTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchAny(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchAnyTypeGet(MlirContext context) { @@ -456,7 +456,7 @@ MlirTypeID torchMlirTorchAnyTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchNumber(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchNumberTypeGet(MlirContext context) { @@ -472,7 +472,7 @@ MlirTypeID torchMlirTorchNumberTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchDict(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) { diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 7fc752a14680..e55756eb4305 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -352,7 +352,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rightDimsPrimList); return success(); }); - patterns.onOp("MatMul", 13, + patterns.onOp("MatMul", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; @@ -546,12 +546,12 @@ void 
mlir::torch::onnx_c::populateDefaultDomainGtoP( Value shuffledPaddingList = createConstantIntList(binder, rewriter, padding); Value zero; - if (resultTypeOut.getDtype().isa()) { + if (isa(resultTypeOut.getDtype())) { zero = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr( std::numeric_limits::lowest())); - } else if (resultTypeOut.getDtype().isa()) { + } else if (isa(resultTypeOut.getDtype())) { zero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr( std::numeric_limits::lowest())); @@ -1296,7 +1296,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.tensorResultType(resultType)) return failure(); - auto inputTensorType = operand.getType().cast(); + auto inputTensorType = cast(operand.getType()); if (!inputTensorType || !inputTensorType.hasSizes()) { return rewriter.notifyMatchFailure( binder.op, "Expected input type having sizes"); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index eaddb2dd4cda..c9f58cb04f97 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -469,6 +469,41 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand); return success(); }); + patterns.onOp( + "Scatter", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + int64_t axis; + if (binder.s64IntegerAttr(axis, "axis", {})) + return rewriter.notifyMatchFailure(binder.op, "axis bind failure"); + + Torch::ValueTensorType resultTy; + Value data, indices, updates; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorOperandAtIndex(indices, 1) || + binder.tensorOperandAtIndex(updates, 2) || + binder.tensorResultType(resultTy)) + return failure(); + + auto dataTy = data.getType().cast(), + indicesTy = indices.getType().cast(), + updatesTy = updates.getType().cast(); + + int64_t dataRank = dataTy.getSizes().size(), + indicesRank = indicesTy.getSizes().size(), + updatesRank = updatesTy.getSizes().size(); + + if ((dataRank < 1) || (indicesRank < 1) || (updatesRank < 1) || + (axis < -dataRank) || (axis >= dataRank)) + return failure(); + + Value axisValue = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + + rewriter.replaceOpWithNewOp( + binder.op, resultTy, data, axisValue, indices, updates); + + return success(); + }); patterns.onOp( "ScatterElements", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -1023,9 +1058,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value noneVal = rewriter.create(binder.getLoc()); Value constFalse = rewriter.create(binder.getLoc(), false); - auto size = data.getType() - .dyn_cast() - .getOptionalSizes(); + auto size = + dyn_cast(data.getType()).getOptionalSizes(); auto f64ResultType = rewriter.getType( size, rewriter.getF64Type()); Value dataCast = rewriter.create( @@ -3088,8 +3122,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( scalesValueList = noneVal; sizesValueList = getValueList(sizeOperand); } - if (scalesValueList.getType().isa() && - sizesValueList.getType().isa()) { + if (isa(scalesValueList.getType()) && + isa(sizesValueList.getType())) { return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode"); } rewriter diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 5476462f3c91..b9b0fb0ae5d7 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ 
-1828,9 +1828,8 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern { const TypeConverter *typeConverter = getTypeConverter(); auto input = adaptor.getSelf(); - RankedTensorType resultType = - typeConverter->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); SmallVector resultShape; SmallVector offsets; @@ -2067,9 +2066,8 @@ class ConvertAtenSliceScatterOp auto input = adaptor.getSelf(); - RankedTensorType resultType = - typeConverter->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); SmallVector resultShape; SmallVector offsets; @@ -2303,9 +2301,8 @@ class ConvertAtenDiagonalOp : public OpConversionPattern { op, "diagonal dimensions cannot be identical"); Type elementType = inputType.getElementType(); - RankedTensorType outputType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType outputType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Location loc = op.getLoc(); Value dim1Size, dim2Size; @@ -2541,9 +2538,8 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern { }) .getResult(0); - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, resultTensor); return success(); @@ -2568,9 +2564,8 @@ class ConvertSparseOperatorOp : public OpConversionPattern { return failure(); // Conversion is completed specified by information in the sparse tensor // type. Thus, we can rewrite all legalizedNames to the same construct. - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp( op, resultType, adaptor.getOperands()[0]); return success(); diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index ef44cad8d804..fbc5004c94e2 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -845,7 +845,7 @@ class ConvertAtenUpsampleNearest2dOp outputSizeIntValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), outputSizeTorchInt); - if (!op.getScalesH().getType().isa()) { + if (!isa(op.getScalesH().getType())) { // Convert float values to int values. // int_value = (int64_t)ceil(float_value) Value ceilVal = rewriter.create(loc, adaptor.getScalesH()); @@ -858,7 +858,7 @@ class ConvertAtenUpsampleNearest2dOp scaleFactorsInt.push_back(scaleFactorVal); } - if (!op.getScalesW().getType().isa()) { + if (!isa(op.getScalesW().getType())) { // Convert float values to int values. 
// int_value = (int64_t)ceil(float_value) Value ceilVal = rewriter.create(loc, adaptor.getScalesW()); @@ -1006,7 +1006,7 @@ class ConvertAtenUpsampleNearest2dBackwardOp unsigned hDimOffset = 2; SmallVector scaleFactorsFloatValues; - if (!op.getScalesH().getType().isa()) { + if (!isa(op.getScalesH().getType())) { scaleFactorsFloatValues.push_back(adaptor.getScalesH()); } else { auto scaleFactorVal = rewriter.create( @@ -1019,7 +1019,7 @@ class ConvertAtenUpsampleNearest2dBackwardOp scaleFactorsFloatValues.push_back(scaleFactorVal); } - if (!op.getScalesW().getType().isa()) { + if (!isa(op.getScalesW().getType())) { scaleFactorsFloatValues.push_back(adaptor.getScalesW()); } else { auto scaleFactorVal = rewriter.create( diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index f184f77d87f8..373ed076551b 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -41,7 +41,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg, return; int64_t minSI = -(1 << (numBits - 1)); Value minSIValue = rewriter.create( - loc, minSI, zp.getType().cast().getWidth()); + loc, minSI, cast(zp.getType()).getWidth()); zp = rewriter.create(loc, zp, minSIValue); minSIValue = rewriter.create(loc, minSI, numBits); arg = torch_to_linalg::createElementwiseLinalgGeneric( @@ -1057,10 +1057,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { loc, getAsOpFoldResult(outDims), accumulatorDType); Value outputTensor; - if (accumulatorDType != resultDTy && !bias.getType().isa()) + if (accumulatorDType != resultDTy && !isa(bias.getType())) bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias, accumulatorDType); - if (bias.getType().isa()) { + if (isa(bias.getType())) { Value c0; if (isa(accumulatorDType)) { c0 = rewriter.create( diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 454c9d408ea4..36fa9dc56f82 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -440,10 +440,8 @@ class ConvertAtenMaxPool2dWithIndicesOp Value self = adaptor.getSelf(); RankedTensorType selfType = cast(self.getType()); Type elementType = selfType.getElementType(); - RankedTensorType indicesRankedTensorType = - getTypeConverter() - ->convertType(op->getResult(1).getType()) - .cast(); + RankedTensorType indicesRankedTensorType = cast( + getTypeConverter()->convertType(op->getResult(1).getType())); // TODO: Add support for 3D inputs. 
if (selfType.getRank() == 3) @@ -750,10 +748,10 @@ class AdaptiveMaxPoolingHelper : public AdaptivePoolingHelper { Location loc = op->getLoc(); const TypeConverter *typeConverter = opConversionPattern.getTypeConverter(); - outputType = typeConverter->convertType(op.getResult0().getType()) - .template cast(); - auxTensorType = typeConverter->convertType(op.getResult1().getType()) - .template cast(); + outputType = cast( + typeConverter->convertType(op.getResult0().getType())); + auxTensorType = cast( + typeConverter->convertType(op.getResult1().getType())); Type auxTensorElementType = auxTensorType.getElementType(); auto smallestFPValueAttr = rewriter.getFloatAttr( elementType, @@ -832,8 +830,8 @@ class AdaptiveAvgPoolingHelper : public AdaptivePoolingHelper { Location loc = op->getLoc(); const TypeConverter *typeConverter = opConversionPattern.getTypeConverter(); - outputType = typeConverter->convertType(op.getResult().getType()) - .template cast(); + outputType = cast( + typeConverter->convertType(op.getResult().getType())); buffVal = rewriter.create( loc, elementType, rewriter.getFloatAttr(elementType, 0)); auxTensor = rewriter.create( diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index 1d7bfbaacb19..40ab475ca2dd 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -42,9 +42,8 @@ class ConvertAtenDropoutOp : public OpConversionPattern { if (train) return failure(); - auto resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getInput()); return success(); @@ -60,8 +59,8 @@ static Value toLinearIndex(OpBuilder &b, Location loc, Value result = b.create(loc, b.getZeroAttr(b.getI64Type())); for (auto [index, stride] : llvm::zip(indicesIntValues, shapeIntValues)) { - assert(index.getType().isa() && - stride.getType().isa() && + assert(isa(index.getType()) && + isa(stride.getType()) && "Input arrays to `toLinearIndex` must only contain values of type " "`mlir::IntegerType`"); Value mul = b.create(loc, result, stride); @@ -129,7 +128,7 @@ class ConvertAtenUniformOp : public OpConversionPattern { if (!isa(elemTy)) return rewriter.notifyMatchFailure(op, "This op only support float type"); - if (!generator.getType().isa()) + if (!isa(generator.getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); @@ -180,7 +179,7 @@ class ConvertAtenUniformOp : public OpConversionPattern { b.create(loc, updateFloat, scale); Value res = b.create(loc, updateScaled, min); Value truncRes = res; - if (elemTy.isa()) + if (isa(elemTy)) truncRes = b.create(loc, elemTy, res); b.create(loc, truncRes); }) diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index cc86f0eeda60..0e1f6426f958 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -86,11 +86,8 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { bool isUnsigned = false; if (!isa(inElementType)) { if (isa(inElementType)) { - auto integerTy = op.getSelf() - .getType() - .template cast() - .getDtype() - .template dyn_cast(); + auto integerTy = dyn_cast( + cast(op.getSelf().getType()).getDtype()); isUnsigned = integerTy.isUnsigned(); } else { return rewriter.notifyMatchFailure( @@ -280,7 +277,7 @@ class 
ConvertAtenMinMaxDimOp : public OpConversionPattern { static Value createAbsOpForNormOps(OpBuilder &b, Location loc, Value elem, Type resultElementType) { - if (elem.getType().isa()) { + if (isa(elem.getType())) { return b.create(loc, elem); } @@ -376,11 +373,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, if (isa(resultElementType)) return b.create(loc, self, result); else if (isa(resultElementType)) { - IntegerType intType = max.getSelf() - .getType() - .cast() - .getDtype() - .dyn_cast(); + IntegerType intType = dyn_cast( + cast(max.getSelf().getType()).getDtype()); if (intType.isUnsigned()) return b.create(loc, self, result); if (intType.isSigned()) @@ -393,11 +387,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, if (isa(resultElementType)) return b.create(loc, self, result); else if (isa(resultElementType)) { - IntegerType intType = min.getSelf() - .getType() - .cast() - .getDtype() - .dyn_cast(); + IntegerType intType = dyn_cast( + cast(min.getSelf().getType()).getDtype()); if (intType.isUnsigned()) return b.create(loc, self, result); if (intType.isSigned()) @@ -657,9 +648,8 @@ class ConvertReductionOp : public ConversionPattern { return opInfo; Location loc = op->getLoc(); - auto resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type elemType = resultType.getElementType(); LogicalResult elemTypeCheck = validateReductionElementType(op, elemType, rewriter); diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index b467d8c6f7b9..06da3e0018e7 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -179,15 +179,13 @@ class ConvertAtenReplicationPad2dOp for (auto i : {TOP, VCENTER, BOTTOM}) { for (auto j : {LEFT, HCENTER, RIGHT}) { - auto constVtile{ + auto constVtile{dyn_cast_or_null( mlir::dyn_cast(vTile[i].getDefiningOp()) - .getValue() - .dyn_cast_or_null()}; + .getValue())}; - auto constHtile{ + auto constHtile{dyn_cast_or_null( mlir::dyn_cast(hTile[j].getDefiningOp()) - .getValue() - .dyn_cast_or_null()}; + .getValue())}; auto vSize = constVtile.getInt(); auto hSize = constHtile.getInt(); @@ -369,8 +367,8 @@ class ConvertConstantTensorAllocOp : public OpConversionPattern { for (auto size : resultSize) resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size)); - auto resultType = typeConverter->convertType(op.getType()) - .template cast(); + auto resultType = + cast(typeConverter->convertType(op.getType())); Type resultElementType; if (isa(op.getDtype().getType())) { resultElementType = resultType.getElementType(); @@ -426,7 +424,7 @@ class ConvertAtenEmptyMemoryFormatOp op, "unimplemented: pin_memory must be either None or false"); // Only `none`, `contiguous` and `preserve` memory_format is supported. - if (!op.getMemoryFormat().getType().isa()) { + if (!isa(op.getMemoryFormat().getType())) { int64_t memoryFormat; if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) @@ -441,7 +439,7 @@ class ConvertAtenEmptyMemoryFormatOp } // TODO: Add support for device arg other than cpu. 
- if (!op.getDevice().getType().isa()) { + if (!isa(op.getDevice().getType())) { std::string device; if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) return rewriter.notifyMatchFailure( @@ -453,7 +451,7 @@ class ConvertAtenEmptyMemoryFormatOp // TODO: Add support for non-strided layout. // torch.layout is by default strided i.e. 0. - if (!op.getLayout().getType().isa()) { + if (!isa(op.getLayout().getType())) { int64_t tensorLayout; if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) return rewriter.notifyMatchFailure( @@ -478,7 +476,7 @@ class ConvertAtenEmptyMemoryFormatOp auto resultType = cast(typeConverter->convertType(op.getType())); Type resultElementType; - if (op.getDtype().getType().isa()) { + if (isa(op.getDtype().getType())) { resultElementType = getDefaultDtypeForTorchScalar( Torch::FloatType::get(op->getContext())); } else { @@ -527,7 +525,7 @@ class ConvertAtenArangeStartStepOp // The pin_memory should be either `False` or `none`. bool pinMemory; - if (!op.getPinMemory().getType().isa() && + if (!isa(op.getPinMemory().getType()) && (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || pinMemory)) { return rewriter.notifyMatchFailure( @@ -536,9 +534,8 @@ class ConvertAtenArangeStartStepOp Location loc = op.getLoc(); const TypeConverter *typeConverter = this->getTypeConverter(); - RankedTensorType resultType = - typeConverter->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); Type dtype = resultType.getElementType(); Value start = convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype); diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp index 7585e07b9825..ab5fec18f9b2 100644 --- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp +++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp @@ -138,17 +138,16 @@ class ConvertAtenScalarToTensorLike : public ConversionPattern { requires_grad = tensorFloatOp.getRequiresGrad(); } // TODO: Dtype conversion. - if (!dtype.getType().isa()) + if (!isa(dtype.getType())) return rewriter.notifyMatchFailure(op, "Unimplemented non-None dtype"); // TODO: Device information. 
- if (!device.getType().isa()) + if (!isa(device.getType())) return rewriter.notifyMatchFailure( op, "Unimplemented non-None device information"); - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type outElementType = resultType.getElementType(); Value elemValProm = convertScalarToDtype(rewriter, loc, elemVal, outElementType); @@ -171,9 +170,8 @@ class ConvertPrimNumToTensorScalarOp if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type outElementType = resultType.getElementType(); Value elemVal = adaptor.getA(); Value elemValProm = diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 2fcfbc539042..f7c40c147262 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -422,7 +422,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto clone = dyn_cast(op)) { int64_t memoryFormat; - if (!clone.getMemoryFormat().getType().isa() && + if (!isa(clone.getMemoryFormat().getType()) && (!matchPattern(clone.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)) || (memoryFormat != torch_upstream::MemoryFormat::Contiguous && @@ -434,24 +434,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return payloadArgs[0]; } if (auto bitwiseAndTensor = dyn_cast(op)) { - if (bitwiseAndTensor.getType() - .cast() - .getDtype() - .isa()) { + if (isa( + cast(bitwiseAndTensor.getType()).getDtype())) { bitwiseAndTensor.emitError( "Bitwise_And does not support floating point dtype"); return nullptr; } - Type dtype = converter->convertType(bitwiseAndTensor.getType()) - .cast() + Type dtype = cast( + converter->convertType(bitwiseAndTensor.getType())) .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto bitwiseAndScalar = dyn_cast(op)) { - Type dtype = converter->convertType(bitwiseAndScalar.getType()) - .cast() + Type dtype = cast( + converter->convertType(bitwiseAndScalar.getType())) .getElementType(); if (!isa(dtype)) { bitwiseAndScalar.emitError( @@ -469,32 +467,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, self, other); } if (auto bitwiseOrTensor = dyn_cast(op)) { - if (bitwiseOrTensor.getType() - .cast() - .getDtype() - .isa()) { + if (isa( + cast(bitwiseOrTensor.getType()).getDtype())) { bitwiseOrTensor.emitError( "Bitwise_Or does not support floating point dtype"); return nullptr; } - Type dtype = converter->convertType(bitwiseOrTensor.getType()) - .cast() + Type dtype = cast( + converter->convertType(bitwiseOrTensor.getType())) .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto bitwiseXorTensor = dyn_cast(op)) { - if (bitwiseXorTensor.getType() - .cast() - .getDtype() - .isa()) { + if (isa( + cast(bitwiseXorTensor.getType()).getDtype())) { bitwiseXorTensor.emitError( "Bitwise_Xor does not support floating point dtype"); return 
nullptr; } - Type dtype = converter->convertType(bitwiseXorTensor.getType()) - .cast() + Type dtype = cast( + converter->convertType(bitwiseXorTensor.getType())) .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); @@ -502,8 +496,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto bitwiseRightShiftTensor = dyn_cast(op)) { - Type dtype = converter->convertType(bitwiseRightShiftTensor.getType()) - .cast() + Type dtype = cast( + converter->convertType(bitwiseRightShiftTensor.getType())) .getElementType(); if (!isa(dtype)) { bitwiseRightShiftTensor.emitError( @@ -516,8 +510,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto bitwiseLeftShiftTensor = dyn_cast(op)) { - Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType()) - .cast() + Type dtype = cast( + converter->convertType(bitwiseLeftShiftTensor.getType())) .getElementType(); if (!isa(dtype)) { bitwiseLeftShiftTensor.emitError( @@ -557,7 +551,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createEqual(b, loc, floatDtype, self, zero); } if (isa(op)) { - if (payloadArgs[0].getType().isa()) + if (isa(payloadArgs[0].getType())) return b.create(loc, payloadArgs[0]); return b.create(loc, payloadArgs[0]); } @@ -653,20 +647,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, cmp, arg, zeroPoint); } if (auto round = dyn_cast(op)) { - if (!round.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( + cast(round.getType()).getDtype())) { round.emitError("unimplemented: non-floating point dtype"); return nullptr; } return b.create(loc, payloadArgs[0]); } if (auto prelu = dyn_cast(op)) { - if (!prelu.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( + cast(prelu.getType()).getDtype())) { prelu.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -685,10 +675,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, positivePart, scaledNegativePart); } if (auto gelu = dyn_cast(op)) { - if (!gelu.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( + cast(gelu.getType()).getDtype())) { gelu.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -732,10 +720,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return nullptr; } if (auto geluBackward = dyn_cast(op)) { - if (!geluBackward.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( + cast(geluBackward.getType()).getDtype())) { geluBackward.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -770,10 +756,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto hardtanhBackward = dyn_cast(op)) { AtenHardtanhBackwardOp::Adaptor adaptor(operands); - if (!hardtanhBackward.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( + cast(hardtanhBackward.getType()).getDtype())) { hardtanhBackward.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -967,10 +951,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto pow = dyn_cast(op)) { - if (!pow.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( + cast(pow.getType()).getDtype())) { pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -1047,10 +1029,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto lerp = dyn_cast(op)) { - if (!lerp.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( 
+ cast(lerp.getType()).getDtype())) { lerp.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -1064,9 +1044,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto minimum = dyn_cast(op)) { Type dtype = cast(minimum.getType()).getDtype(); - Type elemTy = converter->convertType(minimum.getType()) - .cast() - .getElementType(); + Type elemTy = + cast(converter->convertType(minimum.getType())) + .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value pred = createLessThan(b, loc, dtype, lhs, rhs); @@ -1074,9 +1054,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto maximum = dyn_cast(op)) { Type dtype = cast(maximum.getType()).getDtype(); - Type elemTy = converter->convertType(maximum.getType()) - .cast() - .getElementType(); + Type elemTy = + cast(converter->convertType(maximum.getType())) + .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value pred = createGreaterThan(b, loc, dtype, lhs, rhs); @@ -1086,8 +1066,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( AtenClampOp::Adaptor adaptor(operands); auto min = adaptor.getMin(); auto max = adaptor.getMax(); - if (min.getType().isa() || - max.getType().isa()) { + if (isa(min.getType()) || + isa(max.getType())) { clamp.emitError("unimplemented: runtime optional type"); return nullptr; } @@ -1125,9 +1105,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( }; auto result = payloadArgs[0]; - if (!min.getType().isa()) + if (!isa(min.getType())) result = cmpSelect(result, min, /*getMax=*/false); - if (!max.getType().isa()) + if (!isa(max.getType())) result = cmpSelect(result, max, /*getMax=*/true); return result; } @@ -1135,8 +1115,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( AtenClampTensorOp::Adaptor adaptor(operands); auto min = adaptor.getMin(); auto max = adaptor.getMax(); - if (min.getType().isa() || - max.getType().isa()) { + if (isa(min.getType()) || + isa(max.getType())) { clampTensor.emitError("unimplemented: runtime optional type"); return nullptr; } @@ -1145,7 +1125,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); bool isMinNone = true; auto result = payloadArgs[0]; - if (!min.getType().isa()) { + if (!isa(min.getType())) { isMinNone = false; auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value pred; @@ -1163,7 +1143,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } result = b.create(loc, pred, minPromoted, result); } - if (!max.getType().isa()) { + if (!isa(max.getType())) { max = isMinNone ? 
payloadArgs[1] : payloadArgs[2]; auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); Value pred; @@ -1252,9 +1232,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, self, other); } if (auto remScalar = dyn_cast(op)) { - Type newResultType = converter->convertType(remScalar.getType()) - .cast() - .getElementType(); + Type newResultType = + cast(converter->convertType(remScalar.getType())) + .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); Value other = convertScalarToDtype(b, loc, operands[1], newResultType); @@ -1272,9 +1252,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return result; } if (auto remTensor = dyn_cast(op)) { - Type newResultType = converter->convertType(remTensor.getType()) - .cast() - .getElementType(); + Type newResultType = + cast(converter->convertType(remTensor.getType())) + .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); @@ -1292,9 +1272,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return result; } if (auto fmod = dyn_cast(op)) { - Type newResultType = converter->convertType(fmod.getType()) - .cast() - .getElementType(); + Type newResultType = + cast(converter->convertType(fmod.getType())) + .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); @@ -1420,9 +1400,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto bitwiseNot = dyn_cast(op)) { - Type elementType = converter->convertType(bitwiseNot.getType()) - .cast() - .getElementType(); + Type elementType = + cast(converter->convertType(bitwiseNot.getType())) + .getElementType(); if (isa(elementType)) { bitwiseNot.emitError("Bitwise_Not does not support floating point dtype"); return nullptr; @@ -1607,10 +1587,9 @@ class ConvertElementwiseOp : public ConversionPattern { Location loc = op->getLoc(); auto tensorOperands = llvm::to_vector<6>(llvm::make_filter_range( - operands, [](Value v) { return v.getType().isa(); })); - auto resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + operands, [](Value v) { return isa(v.getType()); })); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); bool hadErrorCreatingPayload = false; Value generic = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, tensorOperands, resultType.getElementType(), @@ -1657,7 +1636,7 @@ class ConvertAtenNllLossForwardOp return rewriter.notifyMatchFailure(op, "dim must be constant"); // TODO: Incorporate the weight argument. 
- if (!weight.getType().isa()) + if (!isa(weight.getType())) return rewriter.notifyMatchFailure( op, "Unimplemented, the weight operand is not incorporated."); @@ -1672,9 +1651,8 @@ class ConvertAtenNllLossForwardOp return rewriter.notifyMatchFailure( op, "expected input and target to be rank <= 2"); } - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type elementType = resultType.getElementType(); Value zeroVal = rewriter.create( @@ -1948,7 +1926,7 @@ class ConvertAtenNllLossBackwardOp Value input = adaptor.getSelf(); Value target = adaptor.getTarget(); Value weight = adaptor.getWeight(); - bool weightIsNone = op.getWeight().getType().isa(); + bool weightIsNone = isa(op.getWeight().getType()); Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex()); Value totalWeight = adaptor.getTotalWeight(); @@ -2069,9 +2047,8 @@ class ConvertAtenNllLossBackwardOp }) ->getResult(0); - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, gradInput); return success(); } @@ -2214,9 +2191,8 @@ class ConvertTensorStaticInfoCastOp LogicalResult matchAndRewrite(TensorStaticInfoCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getOperand()); return success(); @@ -2243,7 +2219,7 @@ class ConvertLogitOp : public OpConversionPattern { if (succeeded(checkNotNone(rewriter, op, eps))) handleEps = true; - if (handleEps && !eps.getType().isa()) { + if (handleEps && !isa(eps.getType())) { op.emitError("Logit does not support non-floating point type"); return failure(); } @@ -2317,9 +2293,8 @@ class ConvertAtenIntReprOp : public OpConversionPattern { LogicalResult matchAndRewrite(AtenIntReprOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf()); return success(); @@ -2362,8 +2337,8 @@ class ConvertDequantizePerChannel zeropoint = converter->materializeTargetConversion( rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint); - auto resultType = converter->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + converter->convertType(op->getResult(0).getType())); llvm::SmallVector dynSizes; for (auto [index, dim] : llvm::enumerate(resultType.getShape())) { @@ -2553,9 +2528,8 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { return res; }; - auto resultType = getTypeConverter() - ->convertType(op.getResult().getType()) - .cast(); + auto resultType = cast( + getTypeConverter()->convertType(op.getResult().getType())); SmallVector resultSize{}; if (resultType.isDynamicDim(0)) resultSize.push_back(rewriter.create(loc, input, 0)); @@ -2675,7 +2649,7 @@ static Value NearestInterpolate(OpBuilder &b, Location loc, SmallVector 
scaleValues, std::string coordStr, std::string nearestMode) { - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); SmallVector indices; @@ -2764,7 +2738,7 @@ static Value BilinearInterpolate(OpBuilder &b, SmallVector scaleValues, std::string coordStr) { unsigned dimOffset = 2; - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); Value cstOneEps = @@ -2929,7 +2903,7 @@ class ConvertInterpolateOp Location loc = op->getLoc(); Value input = adaptor.getInput(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); if (mode.substr(0, 8) == "bilinear" && inputRank != 4) return rewriter.notifyMatchFailure( @@ -2945,7 +2919,7 @@ class ConvertInterpolateOp loc, rewriter.getIntegerType(64), inputSize)); } - if (!op.getScaleFactor().getType().isa()) { + if (!isa(op.getScaleFactor().getType())) { bool recompScale; if (!matchPattern(op.getRecomputeScaleFactor(), m_TorchConstantBool(&recompScale))) diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 9f63f58861cf..0aa919fe04a6 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -52,7 +52,7 @@ Value torch_to_linalg::getPaddedTensor( Value torch_to_linalg::getZeroPaddedTensor( Operation *op, OpBuilder &b, Value &input, SmallVectorImpl &paddingInts) { - assert(input.getType().isa() && + assert(isa(input.getType()) && "input must be RankedTensorType"); Location loc = op->getLoc(); Value c0 = b.create( @@ -67,7 +67,7 @@ Value torch_to_linalg::getZeroPaddedTensor( Value torch_to_linalg::getDynamicZeroPaddedTensor( Operation *op, OpBuilder &b, Value &input, SmallVectorImpl &padding, int unpaddedDims, Value pad) { - assert(input.getType().isa() && + assert(isa(input.getType()) && "input must be RankedTensorType"); unsigned int inRank = cast(input.getType()).getRank(); Location loc = op->getLoc(); diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp index e3418e38ea1f..27e0a61f4b31 100644 --- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp +++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp @@ -252,7 +252,7 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern { // "block" arguments for (const auto &barg : enumerate(op.getRegion().front().getArguments())) { Value to = block->getArgument(barg.index()); - if (to.getType().isa()) + if (isa(to.getType())) to = rewriter.create(loc, rewriter.getI64Type(), to); Type targetType = to.getType(); diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 10a8647b4b58..715f89ff9063 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -146,9 +146,9 @@ class ConvertAtenUnaryOp : public OpConversionPattern { if (!selfType) { return op.emitError("only Tensor types supported in StableHLO"); } - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); self = hlo::promoteType(rewriter, op.getLoc(), self, outType); rewriter.replaceOpWithNewOp(op, outType, self); return success(); @@ -203,9 +203,9 @@ class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern { auto selfTy = cast(self.getType()); if (!selfTy) return op.emitError("only Tensor 
types supported in StableHLO"); - auto resultTy = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto resultTy = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (isa(resultTy.getElementType())) { Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy); @@ -231,9 +231,9 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); + auto outType = dyn_cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (!outType) return op.emitError("only Tensor types supported in StableHLO"); @@ -321,9 +321,9 @@ class ConvertAtenBinaryBroadcastOp : public OpConversionPattern { if (!lhsTy || !rhsTy) return op.emitError("only Tensor types supported"); - auto outTy = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outTy = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy); @@ -354,9 +354,9 @@ class ConvertAtenAddSubOp : public OpConversionPattern { if (!lhsType) return op.emitError("only Tensor types supported in StableHLO"); - TensorType outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + TensorType outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { @@ -607,9 +607,9 @@ class ConvertAtenLogicalBinaryOp : public OpConversionPattern { if (!lhsTy) return op.emitError("lhs must be a ranked tensor type"); - TensorType outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + TensorType outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); Type outElemTy = outType.getElementType(); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); if (!rhsTy) { @@ -917,9 +917,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!lhsType) return op.emitError("only Tensor types supported in StableHLO"); - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outType = cast( + OpConversionPattern::getTypeConverter() + ->convertType(op.getType())); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { @@ -1421,9 +1421,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp. 
SmallVector zeroConstVec( - numFeatureDimSize, APFloat::getZero(inputTy.getElementType() - .cast() - .getFloatSemantics())); + numFeatureDimSize, + APFloat::getZero( + cast(inputTy.getElementType()).getFloatSemantics())); SmallVector oneConstVec( numFeatureDimSize, APFloat( @@ -1633,9 +1633,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Location loc = op->getLoc(); // Get element type of resultType as dtype - auto outType = this->getTypeConverter() - ->convertType(op.getType()) - .cast(); + auto outType = cast( + this->getTypeConverter()->convertType(op.getType())); auto dtype = outType.getElementType(); if (!isa(dtype) && !isa(dtype)) { return rewriter.notifyMatchFailure( @@ -1678,7 +1677,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenConstantPadNdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); auto selfElemTy = selfTy.getElementType(); int64_t rank = selfTy.getRank(); @@ -2029,7 +2028,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (!selfTy.hasStaticShape()) { return op->emitError("dynamic shaped input is not supported"); } @@ -2062,7 +2061,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( cmpTypeAttr); auto resTy = - getTypeConverter()->convertType(op.getType()).cast(); + cast(getTypeConverter()->convertType(op.getType())); auto bcastTy = resTy.clone(rewriter.getI1Type()); auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1}); @@ -2071,15 +2070,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto resElemTy = resTy.getElementType(); Value zeroTensor; - if (resElemTy.isa()) { + if (isa(resElemTy)) { auto constAttr = SplatElementsAttr::get( resTy, llvm::APFloat::getZero( - resElemTy.cast().getFloatSemantics(), false)); + cast(resElemTy).getFloatSemantics(), false)); zeroTensor = rewriter.create(loc, resTy, constAttr); - } else if (resElemTy.isa()) { + } else if (isa(resElemTy)) { auto constAttr = SplatElementsAttr::get( resTy, - llvm::APInt::getZero(resElemTy.cast().getWidth())); + llvm::APInt::getZero(cast(resElemTy).getWidth())); zeroTensor = rewriter.create(loc, resTy, constAttr); } else { return op.emitError("element type is not float or integer"); diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 0f16662756a9..a551e0521852 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -157,8 +157,8 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, Value builtinTypeStart = adaptor.getStart(); Value builtinTypeEnd = adaptor.getEnd(); - if (torchTypeStart.getType().isa() || - torchTypeEnd.getType().isa()) + if (isa(torchTypeStart.getType()) || + isa(torchTypeEnd.getType())) return rewriter.notifyMatchFailure(op, "unimplemented optional type arg"); int64_t step; @@ -349,11 +349,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "offsets must be a vector with static shape equal to 1"); - if (!op.getPaddingIdx().getType().isa()) + if (!isa(op.getPaddingIdx().getType())) return rewriter.notifyMatchFailure( op, "Unimplemented: padding_idx should be none"); - if (!op.getPerSampleWeights().getType().isa()) + if (!isa(op.getPerSampleWeights().getType())) return rewriter.notifyMatchFailure( op, "Unimplemented: 
per_sample_weights should be none"); @@ -453,25 +453,22 @@ LogicalResult ConvertAtenOp::matchAndRewrite( loc, getTypeConverter()->convertType(op.getType(0)), stablehloReduceOp.getResult(0), outShapeTensor); - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(1).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(1).getType())); Value resultB = createInitialValueForGatherScatterOp(op, resultType, rewriter); if (!resultB) return failure(); - resultType = getTypeConverter() - ->convertType(op->getResult(2).getType()) - .cast(); + resultType = cast( + getTypeConverter()->convertType(op->getResult(2).getType())); Value resultC = createInitialValueForGatherScatterOp(op, resultType, rewriter); if (!resultC) return failure(); - resultType = getTypeConverter() - ->convertType(op->getResult(3).getType()) - .cast(); + resultType = cast( + getTypeConverter()->convertType(op->getResult(3).getType())); Value resultD = createInitialValueForGatherScatterOp(op, resultType, rewriter); if (!resultD) @@ -612,9 +609,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto input = adaptor.getSelf(); - RankedTensorType resultType = - typeConverter->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); SmallVector resultShape; SmallVector offsets; diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index 93c6d2eac8f9..b6e9d9ba90a8 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -350,9 +350,9 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { rewriter.replaceOpWithNewOp( op, - ConvertAtenOp::getTypeConverter() - ->convertType(op.getType()) - .template cast(), + cast( + ConvertAtenOp::getTypeConverter()->convertType( + op.getType())), output); return success(); @@ -730,9 +730,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { // If transposed is set to true, // the weight shape changes to [IC, (OC//G), KH, KW] auto weightTy = cast(weight.getType()); - auto outTy = getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outTy = + cast(getTypeConverter()->convertType(op.getType())); if (!inputTy || !weightTy || !outTy) { return op.emitError("input, weight and output must be ranked tensors"); } diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index eb32cd3ac9d7..a52d4e7194e2 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -216,10 +216,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto *secondIdxArg = std::next(secondValArg); stablehlo::ComparisonTypeAttr compareTypeAttr; - if (inputTy.getElementType().isa()) { + if (isa(inputTy.getElementType())) { compareTypeAttr = stablehlo::ComparisonTypeAttr::get( rewriter.getContext(), stablehlo::ComparisonType::FLOAT); - } else if (inputTy.getElementType().isa()) { + } else if (isa(inputTy.getElementType())) { compareTypeAttr = stablehlo::ComparisonTypeAttr::get( rewriter.getContext(), stablehlo::ComparisonType::SIGNED); } @@ -395,9 +395,8 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { RankedTensorType inputTy = cast(input.getType()); Type inputElemTy = inputTy.getElementType(); int64_t inputRank = inputTy.getRank(); - RankedTensorType outTy = ConvertAtenOp::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + 
RankedTensorType outTy = cast( + ConvertAtenOp::getTypeConverter()->convertType(op.getType())); auto outShape = outTy.getShape(); if (inputRank <= Dim) { diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index d31a46035e05..d8d7d43c4d24 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -242,10 +242,10 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, auto *secondIdxArg = std::next(secondValArg); stablehlo::ComparisonTypeAttr compareTypeAttr; - if (inputTy.getElementType().isa()) { + if (isa(inputTy.getElementType())) { compareTypeAttr = stablehlo::ComparisonTypeAttr::get( rewriter.getContext(), stablehlo::ComparisonType::FLOAT); - } else if (inputTy.getElementType().isa()) { + } else if (isa(inputTy.getElementType())) { compareTypeAttr = stablehlo::ComparisonTypeAttr::get( rewriter.getContext(), stablehlo::ComparisonType::SIGNED); } @@ -535,12 +535,10 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( "AtenMaxDimOp to StableHLO"); } - RankedTensorType valResultType = getTypeConverter() - ->convertType(op.getResult(0).getType()) - .template cast(); - RankedTensorType idxResultType = getTypeConverter() - ->convertType(op.getResult(1).getType()) - .template cast(); + RankedTensorType valResultType = cast( + getTypeConverter()->convertType(op.getResult(0).getType())); + RankedTensorType idxResultType = cast( + getTypeConverter()->convertType(op.getResult(1).getType())); Type idxElementType = idxResultType.getElementType(); if (!isa(idxElementType)) { return op.emitError("Aten.max.dim needs integer-like result"); @@ -636,9 +634,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = dyn_cast(input.getType()); - auto outTy = getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); + auto outTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); if (!inputTy) { return rewriter.notifyMatchFailure( op, "only Tensor types supported in StableHLO"); diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index 4ced38656fce..46d58b8b5f8f 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -271,7 +271,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "dim is statically invalid"); auto getOptionalVal = [&](Value val) -> std::optional { - if (val.getType().isa()) { + if (isa(val.getType())) { return std::nullopt; } else { return val; @@ -451,7 +451,7 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( PrimsSplitDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto selfType = adaptor.getA().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getA().getType()); if (!selfType) { return op.emitError("only tensor types are currently supported"); } diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index c9235af9a14f..4beee4c5e82f 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -292,7 +292,7 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc, arith::CmpIPredicate predicate = isDescending ? 
ge : le; compareOp = rewriter.create( loc, predicate, block->getArgument(0), block->getArgument(1)); - } else if (elementTypes[0].isa()) { + } else if (isa(elementTypes[0])) { // Case for using arith::CmpFOp. arith::CmpFPredicate predicate = isDescending ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OLE; @@ -349,8 +349,8 @@ class ConvertAtenScatterSrcOp : public OpConversionPattern { b.create(loc, updatesElement); }); - auto resultType = typeConverter->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, scatterOp); return success(); } @@ -381,7 +381,7 @@ class ConvertAtenBincountOp : public OpConversionPattern { // Check whether the input is a 1-d tensor of integer type or not. RankedTensorType inputType = cast(input.getType()); if (inputType.getRank() != 1 || - !inputType.getElementType().isa()) + !isa(inputType.getElementType())) return rewriter.notifyMatchFailure( op, "Input tensor has to be a one-dimensional tensor of integer type."); @@ -395,7 +395,7 @@ class ConvertAtenBincountOp : public OpConversionPattern { "Unimplemented: Integer width not equal to 64 are not supported."); // TODO: Incorporate the weight argument. - if (!weights.getType().isa()) + if (!isa(weights.getType())) return rewriter.notifyMatchFailure( op, "Unimplemented: the weights operand is not incorporated."); @@ -439,8 +439,8 @@ class ConvertAtenBincountOp : public OpConversionPattern { indices = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(indices.getType()), indices); - auto resultType = typeConverter->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); Type resultElemType = resultType.getElementType(); SmallVector inputSizeDynamic = @@ -686,8 +686,8 @@ class ConvertAtenIndexPutHackedTwinOp auto valuesType = cast(values.getType()); int64_t inputRank = inputType.getSizes().size(); auto valuesTensorType = cast(op.getValues().getType()); - auto resultType = typeConverter->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); if (!valuesTensorType.hasSizes()) return rewriter.notifyMatchFailure( @@ -826,10 +826,10 @@ class ConvertAtenIndexPutHackedTwinOp Value inputElement) { Value yieldValue = valuesElement; if (accumulate) { - if (inputElement.getType().isa()) { + if (isa(inputElement.getType())) { yieldValue = b.create(loc, inputElement, valuesElement); - } else if (inputElement.getType().isa()) { + } else if (isa(inputElement.getType())) { yieldValue = b.create(loc, inputElement, valuesElement); } else { @@ -1045,10 +1045,10 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp [&](OpBuilder &b, Location loc, Value valuesElement, Value inputElement) { Value yieldValue = valuesElement; - if (inputElement.getType().isa()) { + if (isa(inputElement.getType())) { yieldValue = b.create(loc, inputElement, valuesElement); - } else if (inputElement.getType().isa()) { + } else if (isa(inputElement.getType())) { yieldValue = b.create(loc, inputElement, valuesElement); } else { @@ -1207,33 +1207,33 @@ class ConvertAtenScatterReduceTwoOp Value result; if (reduceEnum == torch_upstream::ReductionType::SUM || reduceEnum == torch_upstream::ReductionType::MEAN) { - if (update.getType().isa()) { + if (isa(update.getType())) { result = b.create(loc, update, current); - } else if 
(update.getType().isa()) { + } else if (isa(update.getType())) { result = b.create(loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } } else if (reduceEnum == torch_upstream::ReductionType::PROD) { - if (update.getType().isa()) { + if (isa(update.getType())) { result = b.create(loc, update, current); - } else if (update.getType().isa()) { + } else if (isa(update.getType())) { result = b.create(loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } } else if (reduceEnum == torch_upstream::ReductionType::MAX) { - if (update.getType().isa()) { + if (isa(update.getType())) { result = b.create(loc, update, current); - } else if (update.getType().isa()) { + } else if (isa(update.getType())) { result = b.create(loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } } else if (reduceEnum == torch_upstream::ReductionType::MIN) { - if (update.getType().isa()) { + if (isa(update.getType())) { result = b.create(loc, update, current); - } else if (update.getType().isa()) { + } else if (isa(update.getType())) { result = b.create(loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); @@ -1288,9 +1288,8 @@ class ConvertAtenScatterReduceTwoOp }) .getResult()[0]; } - auto resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, scatterOp); return success(); @@ -1395,9 +1394,8 @@ class ConvertAtenCumsumOp : public OpConversionPattern { Location loc = op.getLoc(); Value input = adaptor.getSelf(); - auto resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type elementType = resultType.getElementType(); Type inputElementType = cast(input.getType()).getElementType(); @@ -1417,7 +1415,7 @@ class ConvertAtenCumsumOp : public OpConversionPattern { int64_t inputRank = resultType.getRank(); Value dtype = op.getDtype(); - if (!dtype.getType().isa()) + if (!isa(dtype.getType())) return rewriter.notifyMatchFailure( op, "unsupported: dtype argument not supported"); @@ -1447,7 +1445,7 @@ class ConvertAtenCumsumOp : public OpConversionPattern { rewriter, loc, input, output, acc, dim, /*inclusive=*/true, [](OpBuilder &b, Location loc, Value input, Value acc) { Value sum = - (input.getType().isa() + (isa(input.getType()) ? 
b.create(loc, input, acc)->getResult(0) : b.create(loc, input, acc)->getResult(0)); b.create(loc, sum); @@ -1475,7 +1473,7 @@ class ConvertAtenScaledDotProductAttentionOp cast(adaptor.getQuery().getType()).getElementType(); // Verify inputs (only support defaults) - if (!mask.getType().isa()) + if (!isa(mask.getType())) return rewriter.notifyMatchFailure(op.getLoc(), "attention masking not supported"); double dropout; @@ -1486,7 +1484,7 @@ if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) return rewriter.notifyMatchFailure( op.getLoc(), "causal attention masking not supported"); - if (!scale.getType().isa()) { + if (!isa(scale.getType())) { double scaleFloat; if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) || scaleFloat != 1.0) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 582197567a1d..5b1027626c2d 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -101,9 +101,9 @@ class ConvertAtenBinaryOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); - auto outTy = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outTy = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); auto binaryOp = tosa::createBinaryOpAndCast(rewriter, op, outTy, lhs, rhs); @@ -250,9 +250,9 @@ class ConvertAtenAddSubOp : public OpConversionPattern { } // Get output type: tensor - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { @@ -379,9 +379,9 @@ class ConvertAtenCompareOp : public OpConversionPattern { std::is_same()); // Promote lhs and rhs dtypes for bitwise operators.
- TensorType resultTy = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + TensorType resultTy = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (isBitwiseOp) { lhs = tosa::promoteType(rewriter, lhs, resultTy); rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy); @@ -422,9 +422,9 @@ class ConvertAtenMulOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) @@ -450,9 +450,9 @@ class ConvertAtenMulOp : public OpConversionPattern { } if (isa(outElemTy) || isa(outElemTy)) { - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); auto mulOp = tosa::createMulOpAndCast(rewriter, op, outType, lhs, rhsTensor, /*shift=*/0); @@ -498,9 +498,9 @@ class ConvertAtenDivOp : public OpConversionPattern { "conversion in TOSA operation"); } auto rhsTensor = rhsTy ? rhs : rhsAsTensor; - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); // auto result; Value result; @@ -545,7 +545,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenTanhOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (selfTy && isa(selfTy.getElementType())) { rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); @@ -562,7 +562,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenSigmoidOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (selfTy && isa(selfTy.getElementType())) { rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); @@ -609,7 +609,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); if (!isa(selfTy.getElementType())) { return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization currently supported"); @@ -673,9 +673,9 @@ class ConvertAtenReductionOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); - auto outputTy = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outputTy = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (!outputTy) return rewriter.notifyMatchFailure( op, "Only ranked tensor type outputs permitted for reduce_mean"); @@ -834,9 +834,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "non-const keepdim parameter unsupported"); - auto resultTy = getTypeConverter() - ->convertType(op.getResult().getType()) - .cast(); + auto resultTy = cast( + getTypeConverter()->convertType(op.getResult().getType())); auto outputETy = 
resultTy.getElementType(); // Create a single instance of tosa.argmax. @@ -933,9 +932,9 @@ class ConvertAtenSqueezeOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Squeeze could not compute new shape"); - auto resultTy = OpConversionPattern::getTypeConverter() - ->convertType(op.getResult().getType()) - .template cast(); + auto resultTy = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getResult().getType())); auto resultElemTy = resultTy.getElementType(); auto newOutputTy = RankedTensorType::get( @@ -2070,9 +2069,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto inputTy = cast(input.getType()); auto weightTy = cast(weight.getType()); - auto outputTy = getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outputTy = + cast(getTypeConverter()->convertType(op.getType())); if (!inputTy || !weightTy || !outputTy) return rewriter.notifyMatchFailure( @@ -2655,9 +2653,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ValueTensorLiteralOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto outputTy = getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outputTy = + cast(getTypeConverter()->convertType(op.getType())); // Tensors with integer types need to be converted to signless integer // element type. All tensors with element types other than integer can reuse @@ -3325,7 +3322,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( RankedTensorType outType = cast(typeConverter->convertType(op.getType())); - auto indicesType = indices.getType().dyn_cast(); + auto indicesType = dyn_cast(indices.getType()); if (!indicesType || !isa(indicesType.getElementType())) return rewriter.notifyMatchFailure( op, "Indices must be of integer tensor type"); @@ -4616,9 +4613,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { const TypeConverter *typeConverter = this->getTypeConverter(); - RankedTensorType resultType = - typeConverter->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); // At this point all tensors should have value semantics, and hence the // `layout` check can be ignored. @@ -4720,10 +4716,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( }; const auto isIntType = - resultType.getElementType().dyn_cast_or_null(); + dyn_cast_or_null(resultType.getElementType()); const auto isDoubleType = - resultType.getElementType().dyn_cast_or_null(); + dyn_cast_or_null(resultType.getElementType()); auto maybeResult = [&]() -> std::optional { // Integer output type, and start / end / range are all integers. 
@@ -4776,9 +4772,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { const TypeConverter *typeConverter = this->getTypeConverter(); - RankedTensorType resultType = - typeConverter->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); // Only supports integer operand type, because for the floating point operand // type result tensor has to be of type `f64` which is not supported in the @@ -4894,9 +4889,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "memory_format is supported"); } - auto resultTy = getTypeConverter() - ->convertType(op.getResult().getType()) - .cast(); + auto resultTy = cast( + getTypeConverter()->convertType(op.getResult().getType())); Value result; if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.getSelf(), @@ -5337,9 +5331,9 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); + auto outType = dyn_cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (!outType) return rewriter.notifyMatchFailure(op, @@ -5399,9 +5393,9 @@ class ConvertAtenFillScalarOp : public OpConversionPattern { LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); + auto outType = dyn_cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (!outType || !outType.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -5435,9 +5429,9 @@ class ConvertAtenMaskedFillOp : public OpConversionPattern { LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); + auto outType = dyn_cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (!outType || !outType.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -5507,9 +5501,9 @@ class ConvertAtenCloneOp : public OpConversionPattern { "unimplemented: only contiguous and channels last memory " "format is supported"); } - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); + auto outType = dyn_cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); rewriter.replaceOpWithNewOp(op, outType, adaptor.getSelf()); return success(); @@ -5640,8 +5634,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); - auto resultType = typeConverter->convertType(op.getType()) - .template cast(); + auto resultType = + cast(typeConverter->convertType(op.getType())); auto elementType = resultType.getElementType(); if (isa(selfTy.getElementType())) { diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 450689cea45c..5ea4e4bc47dc 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -820,9 +820,9 @@ convertReduceProdOp(PatternRewriter &rewriter, Operation *op, return std::nullopt; bool input_is_qtype = - 
input_type.getElementType().isa(); + isa(input_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + isa(output_type.getElementType()); if (input_is_qtype || output_is_qtype) { op->emitOpError("ConvertReduceProdOp: input/output tensor should " @@ -846,9 +846,9 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op, return std::nullopt; bool input_is_qtype = - input_type.getElementType().isa(); + isa(input_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + isa(output_type.getElementType()); if (input_is_qtype != output_is_qtype) { op->emitOpError("ConvertReduceSumOp: input/output tensor should " @@ -901,9 +901,9 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, return std::nullopt; bool input_is_qtype = - input_type.getElementType().isa(); + isa(input_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + isa(output_type.getElementType()); if (input_is_qtype != output_is_qtype) { op->emitOpError("ConvertReduceSumOp: input/output tensor should " @@ -912,7 +912,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, } // Only supports float type mean() if it's non-quantized - if (!input_is_qtype && !output_type.getElementType().isa()) { + if (!input_is_qtype && !isa(output_type.getElementType())) { op->emitWarning( "Failed convertReduceMean: input unquantized type but output element " "not FloatType!"); diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index c6adfa73fc63..4af9709fdfd7 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -31,7 +31,7 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op, return false; auto tensor = dyn_cast(type); return !tensor || - tensor.toBuiltinTensor().dyn_cast_or_null(); + dyn_cast_or_null(tensor.toBuiltinTensor()); }; bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) && @@ -66,7 +66,7 @@ Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim, // Generate IR: assert(dim >= 0 && dim < inputRank) void assertIsValidDim(OpBuilder &b, Location loc, Value dim, Value inputRank) { - assert(dim.getType().isa() && + assert(isa(dim.getType()) && "dim arg of assertIsValidDim must be integer type"); Value cst0 = b.create(loc, b.getZeroAttr(inputRank.getType())); @@ -375,7 +375,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value torchOptionalInt, Value builtinInt, Value defaultValue, Value dimSize) { - if (torchOptionalInt.getType().isa()) + if (isa(torchOptionalInt.getType())) return defaultValue; auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize); Value positiveDim = diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 314191a2c428..1669cb43fbc0 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -149,14 +149,12 @@ static Value getScalarIntValue(Value input, Location loc, if (auto valueTensorLiteralOp = input.getDefiningOp()) { if (inputDtype.isInteger(64)) { - auto val = valueTensorLiteralOp.getValue() - .cast() + auto val = cast(valueTensorLiteralOp.getValue()) .getSplatValue(); return rewriter.create( loc, rewriter.getI64IntegerAttr(val)); } else { - auto val = valueTensorLiteralOp.getValue() - .cast() + auto val = cast(valueTensorLiteralOp.getValue()) .getSplatValue(); return rewriter.create( loc, rewriter.getI64IntegerAttr(val)); @@ -191,8 
+189,7 @@ static Value getScalarFloatValue(Value input, Location loc, return nullptr; if (auto valueTensorLiteralOp = input.getDefiningOp()) { - auto val = valueTensorLiteralOp.getValue() - .cast() + auto val = cast(valueTensorLiteralOp.getValue()) .getSplatValue() .getValueAsDouble(); return rewriter.create( @@ -1946,7 +1943,7 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) { auto resultType = dyn_cast(getType()); if (resultType && resultType.hasDtype() && - resultType.getDtype().isa()) { + isa(resultType.getDtype())) { return getSelf(); } return {}; @@ -2136,7 +2133,7 @@ traceKnownSizeTensorType(Value value, std::optional dim) { // Limit the loop count to 6 to avoid indefinite compilation times from // unbounded IR traversals. for (auto idx = 0; idx < 6; ++idx) { - if (!value || !value.getType().isa()) + if (!value || !isa(value.getType())) return failure(); auto tensorType = cast(value.getType()); @@ -2533,7 +2530,7 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) { // Constant fold int -> float conversion. - if (auto integerAttr = adaptor.getA().dyn_cast_or_null()) { + if (auto integerAttr = dyn_cast_or_null(adaptor.getA())) { return FloatAttr::get( mlir::Float64Type::get(getContext()), static_cast(integerAttr.getValue().getSExtValue())); @@ -2550,7 +2547,7 @@ OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) { // Constant fold float -> int conversion. - if (auto floatAttr = adaptor.getA().dyn_cast_or_null()) { + if (auto floatAttr = dyn_cast_or_null(adaptor.getA())) { return IntegerAttr::get( mlir::IntegerType::get(getContext(), 64), static_cast(floatAttr.getValue().convertToDouble())); @@ -2564,7 +2561,7 @@ OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) { // Constant fold float -> int conversion. 
- if (auto floatAttr = adaptor.getA().dyn_cast_or_null()) { + if (auto floatAttr = dyn_cast_or_null(adaptor.getA())) { return IntegerAttr::get( mlir::IntegerType::get(getContext(), 64), static_cast(floatAttr.getValue().convertToDouble())); @@ -2710,9 +2707,8 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - auto attr = properties.as() - ->getValue() - .dyn_cast_or_null(); + auto attr = + dyn_cast_or_null(properties.as()->getValue()); if (!attr) return failure(); RankedTensorType tensorType = cast(attr.getType()); @@ -2738,10 +2734,10 @@ static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) { bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred, TypeRange actual) { - if (!actual[0].isa()) + if (!isa(actual[0])) return false; - return areSizesAndDtypesCompatible(inferred[0].cast(), - actual[0].cast()); + return areSizesAndDtypesCompatible(cast(inferred[0]), + cast(actual[0])); } //===----------------------------------------------------------------------===// @@ -2752,9 +2748,8 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - auto attr = properties.as() - ->getValue() - .dyn_cast_or_null(); + auto attr = + dyn_cast_or_null(properties.as()->getValue()); if (!attr) return failure(); RankedTensorType tensorType = cast(attr.getType()); @@ -2775,8 +2770,8 @@ OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) { bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs, mlir::TypeRange outputs) { - return areSizesAndDtypesCompatible(inputs[0].cast(), - outputs[0].cast()); + return areSizesAndDtypesCompatible(cast(inputs[0]), + cast(outputs[0])); } void TensorStaticInfoCastOp::getCanonicalizationPatterns( @@ -3087,7 +3082,7 @@ OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) { if (!operandType) return nullptr; if (operandType.hasDtype()) { - bool isFloatType = operandType.getDtype().isa(); + bool isFloatType = isa(operandType.getDtype()); return IntegerAttr::get(IntegerType::get(getContext(), 1), isFloatType); } // doesn't has dtype @@ -3145,12 +3140,12 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, int64_t start; int64_t end; int64_t step; - if (op.getStart().getType().isa()) { + if (isa(op.getStart().getType())) { start = 0; } else if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) { return failure(); } - if (op.getEnd().getType().isa()) { + if (isa(op.getEnd().getType())) { end = listElements.size(); } else if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) { return failure(); @@ -3243,7 +3238,7 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // things. 
Value replacement = tupleConstruct.getElements()[i]; if (replacement.getType() != op.getType()) { - if (op.getType().isa()) { + if (isa(op.getType())) { replacement = rewriter.create( op.getLoc(), op.getType(), replacement); } else { @@ -3399,8 +3394,8 @@ using BinaryIntOperatorFn = std::function; static OpFoldResult atenBinaryIntOperatorFoldHelper(ArrayRef operands, BinaryIntOperatorFn f) { - auto intLhs = operands[0].dyn_cast_or_null(); - auto intRhs = operands[1].dyn_cast_or_null(); + auto intLhs = dyn_cast_or_null(operands[0]); + auto intRhs = dyn_cast_or_null(operands[1]); if (!intLhs || !intRhs) { return nullptr; } @@ -3726,7 +3721,7 @@ OpFoldResult AtenAddOp::fold(FoldAdaptor adaptor) { return nullptr; } - if (adaptor.getA().isa() && adaptor.getB().isa()) { + if (isa(adaptor.getA()) && isa(adaptor.getB())) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) -> int64_t { return a + b; }); @@ -3745,7 +3740,7 @@ OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) { return nullptr; } - if (adaptor.getA().isa() && adaptor.getB().isa()) { + if (isa(adaptor.getA()) && isa(adaptor.getB())) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) -> int64_t { return a * b; }); @@ -3764,7 +3759,7 @@ OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) { return nullptr; } - if (adaptor.getA().isa() && adaptor.getB().isa()) { + if (isa(adaptor.getA()) && isa(adaptor.getB())) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) -> int64_t { return a - b; }); @@ -3821,7 +3816,7 @@ OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA()) { return nullptr; } - auto floatValue = adaptor.getA().dyn_cast_or_null(); + auto floatValue = dyn_cast_or_null(adaptor.getA()); if (!floatValue) { return nullptr; } @@ -3849,7 +3844,7 @@ OpFoldResult AtenNegFloatOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA()) { return nullptr; } - auto value = adaptor.getA().dyn_cast_or_null(); + auto value = dyn_cast_or_null(adaptor.getA()); if (!value) { return nullptr; } @@ -4571,8 +4566,8 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) { if (getA() == getB()) return getA(); - auto lhs = adaptor.getA().dyn_cast_or_null(); - auto rhs = adaptor.getB().dyn_cast_or_null(); + auto lhs = dyn_cast_or_null(adaptor.getA()); + auto rhs = dyn_cast_or_null(adaptor.getB()); if (!lhs || !rhs) return nullptr; // Torch semantics are that !torch.int is 64-bit signed. @@ -4640,8 +4635,8 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) { if (getA() == getB()) return getA(); - auto lhs = adaptor.getA().dyn_cast_or_null(); - auto rhs = adaptor.getB().dyn_cast_or_null(); + auto lhs = dyn_cast_or_null(adaptor.getA()); + auto rhs = dyn_cast_or_null(adaptor.getB()); if (!lhs || !rhs) return nullptr; // Torch semantics are that !torch.int is 64-bit signed. @@ -4728,8 +4723,8 @@ LogicalResult AtenNormScalarOp::verify() { // Check if dtype is one of those supported by norm operation. // ComplexType will match any torch complex types, but each float must be // checked individually. 
- if (!inTensorDtype.isa()) { + if (!isa(inTensorDtype)) { return emitOpError( "expected a float or complex type for input tensor, but got ") << inTensorDtype; diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index d1906d6989af..6735bb37e48b 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -190,8 +190,8 @@ static bool isValidTorchDtype(Type dtype) { // Builtin floating point types. if (isa(dtype)) return true; - if (dtype.isa()) + if (isa(dtype)) return true; if (isa(dtype)) @@ -228,9 +228,9 @@ Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const { Type BaseTensorType::getWithSizesAndDtype( std::optional> optionalSizes, Type optionalDtype) const { - if (isa()) + if (mlir::isa(*this)) return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype); - if (isa()) + if (mlir::isa(*this)) return ValueTensorType::get(getContext(), optionalSizes, optionalDtype); llvm_unreachable("not a BaseTensorType!"); } @@ -248,9 +248,9 @@ Type BaseTensorType::getWithSizesAndDtypeAndSparsity( } ValueTensorType BaseTensorType::getWithValueSemantics() const { - if (auto tensor = dyn_cast()) + if (auto tensor = mlir::dyn_cast(*this)) return tensor.getWithValueSemantics(); - if (auto tensor = dyn_cast()) + if (auto tensor = mlir::dyn_cast(*this)) return tensor; llvm_unreachable("not a BaseTensorType!"); } diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 750ccc355e34..2cbfe2642045 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -110,7 +110,7 @@ class AdjustCallingConventionForCall continue; auto it = typeBoundMap.find({call.getCallee(), operand.index()}); if (it != typeBoundMap.end()) { - if (auto valueTensorType = it->second.dyn_cast()) { + if (auto valueTensorType = dyn_cast(it->second)) { newOperands.push_back(copyTensorToType( rewriter, call->getLoc(), valueTensorType, operand.value())); continue; @@ -215,11 +215,11 @@ static LogicalResult adjustCallingConventions(func::FuncOp func, for (int i = 0, e = func.getNumArguments(); i != e; i++) { if (func.getArgAttr(i, "torch.type_bound")) return false; - if (func.getArgumentTypes()[i].isa()) + if (isa(func.getArgumentTypes()[i])) return false; } for (int i = 0, e = func.getNumResults(); i != e; i++) { - if (func.getFunctionType().getResults()[i].isa()) + if (isa(func.getFunctionType().getResults()[i])) return false; } return true; diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 0edb9bc51f2d..983d04cfdb5e 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -38,7 +38,7 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) { getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); if (failed(resDtype)) return false; - return resDtype->isa(); + return isa(*resDtype); } // Helper function to compute the return type of the reduction function. 
@@ -99,19 +99,15 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc, Operation *op, Value input, Value dim, bool keepDim) { Value keepDimCst = rewriter.create(loc, keepDim); - BaseTensorType valueType = - computeReductionType(rewriter, op, cast(input.getType()), - dim, keepDim) - .cast(); + BaseTensorType valueType = cast(computeReductionType( + rewriter, op, cast(input.getType()), dim, keepDim)); if (!valueType) return nullptr; BaseTensorType indexType = - valueType - .getWithSizesAndDtype( - !valueType.hasSizes() ? std::optional>() - : llvm::ArrayRef(valueType.getSizes()), - IntegerType::get(op->getContext(), 64, IntegerType::Signed)) - .cast(); + cast(valueType.getWithSizesAndDtype( + !valueType.hasSizes() ? std::optional>() + : llvm::ArrayRef(valueType.getSizes()), + IntegerType::get(op->getContext(), 64, IntegerType::Signed))); return rewriter .create(loc, valueType, indexType, input, dim, keepDimCst) .getValues(); @@ -1059,7 +1055,7 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenEyeMOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto outType = op.getType().dyn_cast(); + auto outType = dyn_cast(op.getType()); if (!outType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); @@ -1659,11 +1655,9 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { unsigned inputRank = *maybeInputRank; if (!indicesTensorType.hasSizes()) return failure(); - BaseTensorType valueTensorType = - inputType - .getWithSizesAndDtype(indicesTensorType.getOptionalSizes(), - inputType.getOptionalDtype()) - .cast(); + BaseTensorType valueTensorType = cast( + inputType.getWithSizesAndDtype(indicesTensorType.getOptionalSizes(), + inputType.getOptionalDtype())); // If the dim type is `NoneType` i.e. reduce along all the dimensions. // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so @@ -1671,10 +1665,8 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { // happens on the 0th dimension. 
if (isa(dim.getType())) { BaseTensorType flattenType = - inputType - .getWithSizesAndDtype({kUnknownSize}, - inputType.getOptionalDtype()) - .cast(); + cast(inputType.getWithSizesAndDtype( + {kUnknownSize}, inputType.getOptionalDtype())); dim = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value end = rewriter.create( loc, rewriter.getI64IntegerAttr(inputRank - 1)); @@ -3003,7 +2995,7 @@ class DecomposeAtenRepeatInterleaveSelfIntOp bool dimIsNone = false; int64_t dim; Value dimValue = op.getDim(); - if (dimValue.getType().isa()) { + if (isa(dimValue.getType())) { dimIsNone = true; dim = inputRank - 1; } else { @@ -3887,10 +3879,9 @@ class DecomposeAtenConvolutionBackwardOp gradOutputViewSizesInt[0] = kUnknownSize; gradOutputViewSizesInt[1] = 1; BaseTensorType gradOutputTypeForView = - gradOutputTy - .getWithSizesAndDtype(llvm::ArrayRef(gradOutputViewSizesInt), - gradOutputTy.getOptionalDtype()) - .cast(); + cast(gradOutputTy.getWithSizesAndDtype( + llvm::ArrayRef(gradOutputViewSizesInt), + gradOutputTy.getOptionalDtype())); Value gradOutputView = rewriter.create( loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList); @@ -3918,10 +3909,9 @@ class DecomposeAtenConvolutionBackwardOp } BaseTensorType gradWeightTy = - inputTransposedTy - .getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt), - inputTransposedTy.getOptionalDtype()) - .cast(); + cast(inputTransposedTy.getWithSizesAndDtype( + llvm::ArrayRef(gradWeightSizesInt), + inputTransposedTy.getOptionalDtype())); Value numGroup = rewriter.create(loc, input, cstZero); gradWeight = rewriter.create( @@ -3937,10 +3927,9 @@ class DecomposeAtenConvolutionBackwardOp for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) { gradWeightSizesInt[i + 2] = weightSizes[i + 2]; BaseTensorType gradWeightNarrowTy = - gradWeightTy - .getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt), - gradWeightTy.getOptionalDtype()) - .cast(); + cast(gradWeightTy.getWithSizesAndDtype( + llvm::ArrayRef(gradWeightSizesInt), + gradWeightTy.getOptionalDtype())); Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(i + 2)); @@ -3970,10 +3959,9 @@ class DecomposeAtenConvolutionBackwardOp gradWeightViewShapeValue); BaseTensorType gradWeightTypeForView = - gradWeightTy - .getWithSizesAndDtype(llvm::ArrayRef(gradWeightViewShapeInt), - gradWeightTy.getOptionalDtype()) - .cast(); + cast(gradWeightTy.getWithSizesAndDtype( + llvm::ArrayRef(gradWeightViewShapeInt), + gradWeightTy.getOptionalDtype())); gradWeight = rewriter.create( loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList); @@ -3986,10 +3974,9 @@ class DecomposeAtenConvolutionBackwardOp gradWeightViewShapeInt[gradWeightDimsOrder[i]]); } BaseTensorType gradWeightTypeForMoveDim = - gradWeightTy - .getWithSizesAndDtype(llvm::ArrayRef(gradWeightMoveDimShape), - gradWeightTy.getOptionalDtype()) - .cast(); + cast(gradWeightTy.getWithSizesAndDtype( + llvm::ArrayRef(gradWeightMoveDimShape), + gradWeightTy.getOptionalDtype())); gradWeight = rewriter.create( loc, gradWeightTypeForMoveDim, gradWeight, /*source=*/cstZero, @@ -4009,9 +3996,8 @@ class DecomposeAtenConvolutionBackwardOp Value gradOutputTransposed = rewriter.create( loc, transposedType, gradOutput, cstZero, cstOne); // Convolve input with grad_output. 
- if (failed( - getTransposedType(op.getResultTypes()[1].cast(), - 0, 1, transposedType))) + if (failed(getTransposedType(cast(op.getResultTypes()[1]), + 0, 1, transposedType))) return failure(); gradWeight = rewriter.create( loc, transposedType, inputTransposed, gradOutputTransposed, cstNone, @@ -4063,7 +4049,7 @@ class DecomposeAtenAddmmOp : public OpRewritePattern { // TODO: Handle integer type operands. auto inputType = cast(input.getType()); - if (!inputType.hasDtype() || !inputType.getDtype().isa()) { + if (!inputType.hasDtype() || !isa(inputType.getDtype())) { return rewriter.notifyMatchFailure( op, "unimplemented: non-floating point dtype"); } @@ -4125,7 +4111,7 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern { MLIRContext *context = op.getContext(); BaseTensorType inputType = cast(input.getType()); - if (!inputType.hasDtype() || !inputType.getDtype().isa() || + if (!inputType.hasDtype() || !isa(inputType.getDtype()) || !isNoneOrFloatDtype(context, dtype)) { return rewriter.notifyMatchFailure( op, "only floating-point type is supported"); @@ -4133,7 +4119,7 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern { SmallVector dimListElements; if (!getListConstructElements(dimList, dimListElements) && - !dimList.getType().isa()) { + !isa(dimList.getType())) { return rewriter.notifyMatchFailure( op, "expected `dim` to be `None` or constructed from list construct"); } @@ -4215,7 +4201,7 @@ class DecomposeAtenDropoutOp : public OpRewritePattern { return success(); } BaseTensorType inputType = cast(input.getType()); - if (!inputType.hasDtype() || !inputType.getDtype().isa()) + if (!inputType.hasDtype() || !isa(inputType.getDtype())) return rewriter.notifyMatchFailure( op, "only support floating type input for training mode"); Value noneVal = rewriter.create(loc); @@ -4243,7 +4229,7 @@ class DeomposeAtenNativeDropoutOp Value input = op.getInput(); Value prob = op.getP(); bool train = false; - if (!op.getTrain().getType().isa()) { + if (!isa(op.getTrain().getType())) { if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) { return rewriter.notifyMatchFailure( op, "train must be a boolean constant or none"); @@ -4263,7 +4249,7 @@ class DeomposeAtenNativeDropoutOp return success(); } BaseTensorType inputType = cast(input.getType()); - if (!inputType.hasDtype() || !inputType.getDtype().isa()) { + if (!inputType.hasDtype() || !isa(inputType.getDtype())) { return rewriter.notifyMatchFailure( op, "only support floating type input for training mode"); } @@ -4332,7 +4318,7 @@ class DecomposeAtenStdOp : public OpRewritePattern { Value self = op.getSelf(); BaseTensorType inputTensorTy = cast(self.getType()); if (!inputTensorTy.hasDtype() || - !inputTensorTy.getDtype().isa()) { + !isa(inputTensorTy.getDtype())) { return rewriter.notifyMatchFailure(op, "Only aten.std support floating type"); } @@ -4388,7 +4374,7 @@ class DecomposeAtenStdDimOp : public OpRewritePattern { Value self = op.getSelf(); BaseTensorType inputTensorType = cast(self.getType()); if (!inputTensorType.hasDtype() || - !inputTensorType.getDtype().isa()) { + !isa(inputTensorType.getDtype())) { return rewriter.notifyMatchFailure( op, "aten.std.dim expects input tensor of floating-point type"); } @@ -4413,7 +4399,7 @@ class DecomposeAtenStdCorrectionOp Value self = op.getSelf(); BaseTensorType inputTensorType = cast(self.getType()); if (!inputTensorType.hasDtype() || - !inputTensorType.getDtype().isa()) { + !isa(inputTensorType.getDtype())) { return rewriter.notifyMatchFailure( op, "aten.std.correction expects 
input tensor of floating-point type"); @@ -4506,7 +4492,7 @@ class DecomposeAtenRandLikeOp : public OpRewritePattern { Value input = op.getSelf(); Type resultType = op.getType(); auto inputType = cast(input.getType()); - if (!inputType.hasDtype() || !inputType.getDtype().isa()) { + if (!inputType.hasDtype() || !isa(inputType.getDtype())) { return rewriter.notifyMatchFailure(op, "only support floating-point type"); } @@ -4547,7 +4533,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter, op, "can't decompose bernoulli like ops without sizes or dtype"); } // The `prob` is expected to be a float type tensor. - if (!probType.getDtype().isa()) { + if (!isa(probType.getDtype())) { return rewriter.notifyMatchFailure( op, "probabilities must be a float type tensor"); } @@ -4582,7 +4568,7 @@ class DecomposeAtenBernoulliOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); - if (!op.getGenerator().getType().isa()) + if (!isa(op.getGenerator().getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); @@ -4640,7 +4626,7 @@ class DecomposeAtenBernoulliTensorOp Location loc = op.getLoc(); Value input = op.getSelf(); Value prob = op.getP(); - if (!op.getGenerator().getType().isa()) + if (!isa(op.getGenerator().getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); @@ -4665,7 +4651,7 @@ class DecomposeAtenExponentialOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenExponentialOp op, PatternRewriter &rewriter) const override { - if (!op.getGenerator().getType().isa()) + if (!isa(op.getGenerator().getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); @@ -4706,7 +4692,7 @@ class DecomposeAtenNormalFunctionalOp using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNormalFunctionalOp op, PatternRewriter &rewriter) const override { - if (!op.getGenerator().getType().isa()) + if (!isa(op.getGenerator().getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); @@ -4984,10 +4970,10 @@ class DecomposeAtenNativeLayerNormOp Value weight = op.getWeight(); Value bias = op.getBias(); - if (!weight.getType().isa()) { + if (!isa(weight.getType())) { out = rewriter.create(loc, out.getType(), out, weight); } - if (!bias.getType().isa()) { + if (!isa(bias.getType())) { out = rewriter.create(loc, out.getType(), out, bias, one); } @@ -5238,13 +5224,13 @@ class DecomposeAtenNativeGroupNormOp loc, ListType::get(IntType::get(context)), viewShape); Value groupNormOutput = reshapedOutput; - if (!weight.getType().isa()) { + if (!isa(weight.getType())) { auto weightReshaped = rewriter.create( loc, baseType, weight, /*shape=*/viewShapeSizeList); groupNormOutput = rewriter.create( loc, inputType, groupNormOutput, weightReshaped); } - if (!bias.getType().isa()) { + if (!isa(bias.getType())) { auto biasReshaped = rewriter.create( loc, baseType, bias, /*shape=*/viewShapeSizeList); groupNormOutput = rewriter.create( @@ -5297,8 +5283,8 @@ class DecomposeAtenNativeBatchNormOp // In the inference mode, the `runningMean` and `runningVar` must not be // None. 
- if (runningMean.getType().isa() || - runningVar.getType().isa()) + if (isa(runningMean.getType()) || + isa(runningVar.getType())) return rewriter.notifyMatchFailure( op, "running stats must not be None in inference mode"); @@ -5354,7 +5340,7 @@ class DecomposeAtenNativeBatchNormOp // 2. bias = bias.view(1, C, 1?, 1?, 1?) // 3. output = normalizedInput * weight + bias Value batchNormOutput = normalizedInput; - if (!weight.getType().isa()) { + if (!isa(weight.getType())) { // Rank of `weight` must be exactly 1. std::optional weightRank = getTensorRank(weight); if (!weightRank || *weightRank != 1) @@ -5364,7 +5350,7 @@ class DecomposeAtenNativeBatchNormOp batchNormOutput = rewriter.create( loc, batchNormOutput.getType(), batchNormOutput, weight); } - if (!bias.getType().isa()) { + if (!isa(bias.getType())) { // Rank of `bias` must be exactly 1. std::optional biasRank = getTensorRank(bias); if (!biasRank || *biasRank != 1) @@ -5444,7 +5430,7 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Value dtype = op.getDtype(); - if (dtype.getType().isa()) { + if (isa(dtype.getType())) { BaseTensorType tensorType = cast(op.getSelf().getType()); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( @@ -5518,7 +5504,7 @@ class DecomposeAtenLinearOp : public OpRewritePattern { return transposeWeight; }; - if (bias.getType().isa()) { + if (isa(bias.getType())) { auto weightRank = weightType.getSizes().size(); if (weightRank > 2 || weightRank <= 0) return rewriter.notifyMatchFailure( @@ -5622,7 +5608,7 @@ class DecomposeAtenNewFullOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenNewFullOp op, PatternRewriter &rewriter) const override { Value dtype = op.getDtype(); - if (dtype.getType().isa()) { + if (isa(dtype.getType())) { BaseTensorType tensorType = cast(op.getSelf().getType()); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( @@ -5718,7 +5704,7 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Value noneVal = rewriter.create(op.getLoc()); Value dtype = op.getDtype(); - if (dtype.getType().isa()) { + if (isa(dtype.getType())) { BaseTensorType tensorType = cast(op.getSelf().getType()); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( @@ -5751,9 +5737,9 @@ class DecomposeAtenPadOp : public OpRewritePattern { } Value value = op.getValue(); - if (value.getType().isa()) + if (isa(value.getType())) return rewriter.notifyMatchFailure(op, "optional type not supported"); - if (value.getType().isa()) + if (isa(value.getType())) value = rewriter.create( op.getLoc(), rewriter.getF64FloatAttr(0)); @@ -5773,7 +5759,7 @@ class DecomposeAtenToDtypeLayoutOp LogicalResult matchAndRewrite(AtenToDtypeLayoutOp op, PatternRewriter &rewriter) const override { // TODO: Add support for pinMemory arg equal to `True`. - if (!op.getPinMemory().getType().isa()) { + if (!isa(op.getPinMemory().getType())) { bool pinMemory; if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory))) return rewriter.notifyMatchFailure( @@ -5784,7 +5770,7 @@ class DecomposeAtenToDtypeLayoutOp } // TODO: Add support for device arg other than cpu. 
- if (!op.getDevice().getType().isa()) { + if (!isa(op.getDevice().getType())) { std::string device; if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) return rewriter.notifyMatchFailure( @@ -5796,7 +5782,7 @@ class DecomposeAtenToDtypeLayoutOp // TODO: Add support for non-strided layout. // torch.layout is by default strided i.e. 0. - if (!op.getLayout().getType().isa()) { + if (!isa(op.getLayout().getType())) { int64_t tensorLayout; if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) return rewriter.notifyMatchFailure( @@ -6262,7 +6248,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, Type newOutputType = outputTensorType.getWithSizesAndDtype( outputTensorType.getSizes(), rewriter.getF64Type()); if (!inputTensorTy.hasDtype() || - !inputTensorTy.getDtype().isa()) { + !isa(inputTensorTy.getDtype())) { return rewriter.notifyMatchFailure( op, "support floating-point type input only"); } @@ -6399,14 +6385,14 @@ class DecomposeAtenVarCorrectionOp PatternRewriter &rewriter) const override { int64_t correctionValInt; double correctionValFloat = 1.0; - if (!op.getCorrection().getType().isa()) { - if (op.getCorrection().getType().isa()) { + if (!isa(op.getCorrection().getType())) { + if (isa(op.getCorrection().getType())) { if (!matchPattern(op.getCorrection(), m_TorchConstantFloat(&correctionValFloat))) return rewriter.notifyMatchFailure( op, "Only support constant int or float correction value for " "aten.var"); - } else if (op.getCorrection().getType().isa()) { + } else if (isa(op.getCorrection().getType())) { if (!matchPattern(op.getCorrection(), m_TorchConstantInt(&correctionValInt))) return rewriter.notifyMatchFailure( @@ -6533,11 +6519,9 @@ class DecomposeAtenMseLossOp : public OpRewritePattern { if (!inputType.hasSizes()) return rewriter.notifyMatchFailure( op, "Expected the input tensor to have sizes"); - BaseTensorType subType = - inputType - .getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()), - resultType.getOptionalDtype()) - .cast(); + BaseTensorType subType = cast( + inputType.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()), + resultType.getOptionalDtype())); Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget()); @@ -6574,7 +6558,7 @@ class DecomposeAtenNormScalarOptDimOp Location loc = op->getLoc(); Value none = rewriter.create(loc); Value ord = op.getP(); - if (ord.getType().isa()) { + if (isa(ord.getType())) { ord = rewriter.create( loc, rewriter.getF64FloatAttr(2.0)); } @@ -6617,10 +6601,8 @@ class DecomposeAtenRandintLowOp : public OpRewritePattern { loc, rewriter.getF64FloatAttr((double)cstHigh)); BaseTensorType floatResultType = - resultTensorType - .getWithSizesAndDtype(resultTensorType.getSizes(), - rewriter.getF32Type()) - .cast(); + cast(resultTensorType.getWithSizesAndDtype( + resultTensorType.getSizes(), rewriter.getF32Type())); Value emptyTensor = rewriter.create( loc, floatResultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(), @@ -6712,7 +6694,7 @@ class DecomposePrimsVarOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PrimsVarOp op, PatternRewriter &rewriter) const override { - if (!op.getOutputDtype().getType().isa()) + if (!isa(op.getOutputDtype().getType())) return rewriter.notifyMatchFailure( op, "Unimplemented non-None dtype for prims::var op"); Value cstFalse = rewriter.create(op.getLoc(), false); @@ -6824,7 +6806,7 @@ class DecomposeAtenRandnLikeOp : public OpRewritePattern { 
LogicalResult matchAndRewrite(AtenRandnLikeOp op, PatternRewriter &rewriter) const override { // Only `none`, `contiguous` and `preserve` memory_format is supported. - if (!op.getMemoryFormat().getType().isa()) { + if (!isa(op.getMemoryFormat().getType())) { int64_t memoryFormat; if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) @@ -6921,8 +6903,8 @@ class DecomposeAtenLinspaceOp : public OpRewritePattern { op.getDevice(), op.getPinMemory()); // calculate (end - start) / (steps - 1) Value sub; - if (op.getEnd().getType().isa() || - op.getStart().getType().isa()) { + if (isa(op.getEnd().getType()) || + isa(op.getStart().getType())) { sub = rewriter.create(loc, Torch::FloatType::get(context), op.getEnd(), op.getStart()); } else { @@ -6938,7 +6920,7 @@ class DecomposeAtenLinspaceOp : public OpRewritePattern { } // to dtype Value result; - if (!op.getDtype().getType().isa()) { + if (!isa(op.getDtype().getType())) { result = rewriter.create( loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal, /*copy=*/falseVal, /*memory_format=*/none); @@ -7352,11 +7334,8 @@ class DecomposeAtenScatterValueOp auto selfType = cast(self.getType()); auto indexType = cast(index.getType()); - BaseTensorType srcType = - selfType - .getWithSizesAndDtype(indexType.getOptionalSizes(), - selfType.getOptionalDtype()) - .cast(); + BaseTensorType srcType = cast(selfType.getWithSizesAndDtype( + indexType.getOptionalSizes(), selfType.getOptionalDtype())); Value src = createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList); rewriter.replaceOpWithNewOp(op, op.getType(), self, @@ -7444,7 +7423,7 @@ class DecomposeAtenSgnOp : public OpRewritePattern { "expected result type to have dtype"); } // TODO: support complex type in future. - if (outType.getDtype().isa()) { + if (isa(outType.getDtype())) { return rewriter.notifyMatchFailure(op, "doesn't support complex type now"); } @@ -7610,7 +7589,7 @@ static FailureOr createNewIndices(Operation *op, Location loc = op->getLoc(); MLIRContext *context = op->getContext(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); if (!inputType.hasSizes()) { return failure(); } @@ -7619,7 +7598,7 @@ static FailureOr createNewIndices(Operation *op, int64_t maxIndexRank = 0; for (auto index : oldIndices) { - auto indexType = index.getType().dyn_cast(); + auto indexType = dyn_cast(index.getType()); if (!indexType) // None index continue; if (!indexType.hasSizes()) @@ -7708,15 +7687,13 @@ class DecomposeAtenIndexTensorOp : public OpRewritePattern { int64_t inputRank = inputSizes.size(); auto isTensor = [](Value v) { - return v.getType().isa(); + return isa(v.getType()); }; // directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin if (llvm::all_of(indices, isTensor)) { // By default, we regard the first index type as the list element type. 
- auto indexElemType = indices[0] - .getType() - .template cast() + auto indexElemType = cast(indices[0].getType()) .getWithSizesAndDtype(std::nullopt, nullptr); auto newIndices = rewriter.create( loc, Torch::ListType::get(indexElemType), indices); @@ -7806,7 +7783,7 @@ class DecomposeAtenIndexPutLikeOp "failed to get elements of `indices`"); auto input = op.getSelf(); - auto inputType = input.getType().template cast(); + auto inputType = cast(input.getType()); if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure( op, "only input with shape information is supported"); @@ -7815,15 +7792,13 @@ class DecomposeAtenIndexPutLikeOp int64_t inputRank = inputSizes.size(); auto isTensor = [](Value v) { - return v.getType().isa(); + return isa(v.getType()); }; // directly replace current op with aten.index_put.hacked_twin if (llvm::all_of(indices, isTensor)) { // By default, we regard the first index type as the list element type. - auto indexElemType = indices[0] - .getType() - .template cast() + auto indexElemType = cast(indices[0].getType()) .getWithSizesAndDtype(std::nullopt, nullptr); auto newIndex = rewriter.create( loc, Torch::ListType::get(indexElemType), indices); @@ -7953,7 +7928,7 @@ class DecomposeAtenLinalgNormOp : public OpRewritePattern { // default ord value is 2 for vector_norm auto ord = op.getOrd(); - if (ord.getType().isa()) { + if (isa(ord.getType())) { ord = rewriter.create(loc, rewriter.getI64IntegerAttr(2)); } rewriter.replaceOpWithNewOp( diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 887766c590fa..ec80d21ef20b 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -63,8 +63,8 @@ class FlatSymbolRefProgramPoint }; static bool isTypeTriviallySafe(Type type) { - return type.isa(); + return isa(type); } static bool isUseTreatedWithValueSemantics(OpOperand &use) { diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 630c0ef94106..b73044c9bd40 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -36,8 +36,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, static LogicalResult checkType(Operation *op, Type type, bool actuallyEmitDiagnostics) { // Allow various scalar types that backends are expected to be able to handle. 
- if (type.isa()) + if (isa( + type)) return success(); // Backends are not expected to support dynamic computations on these types, diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index 095400d2b869..92e538772d85 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -187,7 +187,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock auto it = originalReturnTypes.find(i); if (it == originalReturnTypes.end()) continue; - auto originalType = it->second.cast(); + auto originalType = cast(it->second); rewriter.setInsertionPoint(returnOp); Value newReturnValue = copyTensorToType(rewriter, returnOp->getLoc(), originalType, operand.get()); @@ -350,7 +350,7 @@ class RewriteViewLikeSubgraph auto it = originalTypes.find(operand.get()); if (it == originalTypes.end()) continue; - auto originalType = it->second.cast(); + auto originalType = cast(it->second); rewriter.setInsertionPoint(op); Value newReturnValue = copyTensorToType(rewriter, op->getLoc(), originalType, operand.get()); diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 8b758a135751..84780e0426ae 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -118,7 +118,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { if (auto optionalType = dyn_cast(listType.getContainedType())) { if (!llvm::all_of(listConstruct.getElements(), [](Value val) { - return val.getType().isa(); + return isa(val.getType()); })) { rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 3b25e12c3a8e..cd6126aa4da5 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -81,7 +81,7 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable( if (name.starts_with("valsem.")) name = name.drop_front(strlen("valsem.")); if (isa(op)) - name = cast(op)->getAttr("name").cast().getValue(); + name = cast(cast(op)->getAttr("name")).getValue(); std::string libFuncName = (getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str(); auto libFunc = library.lookupSymbol(libFuncName); @@ -191,8 +191,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // to match the library function signature. 
   if (auto unionType = dyn_cast(desiredType)) {
     if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
-          return containedType
-              .isa();
+          return isa(
+              containedType);
         }))
       return b.create(loc, desiredType, operand).getResult();
   }
diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp
index 6b18af04dca6..cf4e444d37a1 100644
--- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp
+++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp
@@ -179,11 +179,10 @@ class RefineNumToTensorScalarOpType
                                        "should have concrete Scalar Type.");
     }
     Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType());
-    auto impliedTypeFromInputType =
+    auto impliedTypeFromInputType = cast(
         cast(originalResultType)
             .getWithSizesAndDtype(originalResultType.getOptionalSizes(),
-                                  inputType)
-            .cast();
+                                  inputType));
 
     op.getResult().setType(impliedTypeFromInputType);
     return success();
diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
index 37ce829cb731..6d2008a28407 100644
--- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
+++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
@@ -97,11 +97,10 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
   }
 
   auto originalResultType = cast(result.getType());
-  auto impliedTypesFromShape =
+  auto impliedTypesFromShape = cast(
       cast(originalResultType)
           .getWithSizesAndDtype(ArrayRef(sizes),
-                                originalResultType.getOptionalDtype())
-          .cast();
+                                originalResultType.getOptionalDtype()));
 
   return updateCalculateOpResultTypes(op, resultNum, impliedTypesFromShape,
                                       rewriter);
diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp
index 06f3fb8500bb..bd66bbe55330 100644
--- a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp
+++ b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp
@@ -74,7 +74,7 @@ LogicalResult FromBuiltinTensorOp::verify() {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) {
-  auto attr = adaptor.getOperand().dyn_cast_or_null();
+  auto attr = dyn_cast_or_null(adaptor.getOperand());
   if (attr) {
     return attr;
   } else {
@@ -87,7 +87,7 @@ OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) {
-  auto attr = adaptor.getOperand().dyn_cast_or_null();
+  auto attr = dyn_cast_or_null(adaptor.getOperand());
   if (attr) {
     return attr;
   } else {
@@ -100,7 +100,7 @@ OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
-  auto attr = adaptor.getOperand().dyn_cast_or_null();
+  auto attr = dyn_cast_or_null(adaptor.getOperand());
   if (attr) {
     return attr;
   } else {
@@ -113,7 +113,7 @@ OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
-  auto attr = adaptor.getOperand().dyn_cast_or_null();
+  auto attr = dyn_cast_or_null(adaptor.getOperand());
   if (attr) {
     return attr;
   } else {
@@ -126,7 +126,7 @@ OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
-  auto attr = adaptor.getOperand().dyn_cast_or_null();
+  auto attr = dyn_cast_or_null(adaptor.getOperand());
   if (attr) {
     return attr;
   } else {
@@ -139,7 +139,7 @@ OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) {
-  auto attr = adaptor.getOperand().dyn_cast_or_null();
+  auto attr = dyn_cast_or_null(adaptor.getOperand());
   if (attr) {
     return attr;
   } else {
diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
index b58e070065fc..3c0ad51fb520 100644
--- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
@@ -91,7 +91,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
           return std::nullopt;
         // Other input type to be converted to i64 are handled by other
         // materializers.
-        if (!inputs[0].getType().isa())
+        if (!isa(inputs[0].getType()))
           return std::nullopt;
         assert(inputs.size() == 1);
         return builder.createOrFold(loc, inputs[0]);
@@ -145,7 +145,7 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
           return std::nullopt;
         // Other input type to be converted to i64 are handled by other
         // materializers.
-        if (!inputs[0].getType().isa())
+        if (!isa(inputs[0].getType()))
          return std::nullopt;
         assert(inputs.size() == 1);
         return builder.create(loc, inputs[0]).getResult();
diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp
index 3bd16ed38940..880d6ace9cd6 100644
--- a/lib/RefBackend/RefBackend.cpp
+++ b/lib/RefBackend/RefBackend.cpp
@@ -56,7 +56,7 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); }
 static bool isArgMemRefTypeValid(Type type) {
   if (auto memRefType = dyn_cast(type)) {
     Type elemTy = memRefType.getElementType();
-    if (elemTy.isa()) {
+    if (isa(elemTy)) {
       return true;
     } else if (auto integerTy = dyn_cast(elemTy)) {
       if (integerTy.isSignlessInteger(64))
@@ -70,7 +70,7 @@ static bool isArgMemRefTypeValid(Type type) {
       if (integerTy.isSignlessInteger(1))
         return true;
     } else if (auto complexTy = dyn_cast(elemTy)) {
-      return complexTy.getElementType().isa();
+      return isa(complexTy.getElementType());
     }
   }
   return false;
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index a0ecafb85ad5..b7fa9eb82f20 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -215,6 +215,17 @@ func.func @test_round(%arg0: !torch.vtensor<[15],f32>) -> !torch.vtensor<[15],f3
 
 // -----
 
+// CHECK-LABEL: func.func @test_scatter
+func.func @test_scatter(%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[2,3],si64>, %arg2: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT0:.*]] = torch.constant.int 0
+  // CHECK: %[[RESULT:.*]] = torch.aten.scatter.src %arg0, %[[INT0]], %arg1, %arg2 : !torch.vtensor<[3,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[3,3],f32>
+  // CHECK: return %[[RESULT]] : !torch.vtensor<[3,3],f32>
+  %0 = torch.operator "onnx.Scatter"(%arg0, %arg1, %arg2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,3],f32>, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32>
+  return %0 : !torch.vtensor<[3,3],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_scatter_elements_with_axis
 func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT1:.*]] = torch.constant.int 1
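
For reference, the hunks above apply one mechanical rewrite: member-style `x.isa<T>()`, `x.cast<T>()`, and `x.dyn_cast<T>()` calls become the free functions `isa<T>(x)`, `cast<T>(x)`, and `dyn_cast<T>(x)`. The following is a minimal illustrative sketch of that pattern, not part of the patch; it assumes the standard MLIR headers, and the function and variable names are hypothetical.

// Illustrative sketch only (not part of this patch). Shows the old member-style
// casts versus the free-function casts used on the '+' lines above. The free
// functions are brought into the mlir namespace by "mlir/Support/LLVM.h".
// `isRankedFloatTensor` and `v` are hypothetical names for this example.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

static bool isRankedFloatTensor(Value v) {
  // Old, deprecated style (what the '-' lines used):
  //   v.getType().isa<RankedTensorType>()
  //   v.getType().cast<RankedTensorType>().getElementType().isa<FloatType>()
  //
  // New style (what the '+' lines use):
  auto tensorTy = dyn_cast<RankedTensorType>(v.getType());
  // dyn_cast returns a null type on failure, so it can be tested directly.
  return tensorTy && isa<FloatType>(tensorTy.getElementType());
}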