Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoBump] Merge with fixes of afca88a0 (May 31) (51) #284

Closed
wants to merge 8 commits into from
7 changes: 0 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,6 @@ cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLI

option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF)

# TODO(#3299): migrate to from member x.cast<T>() to mlir::cast<T>(x).
if(MSVC)
add_compile_options(/wd4996)
else()
add_compile_options(-Wno-deprecated-declarations)
endif()

macro(torch_mlir_enable_werror)
if(TORCH_MLIR_ENABLE_WERROR_FLAG)
if(NOT MSVC)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::copy_if(getInputOperands(),
std::back_inserter(result),
[](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>();
return isa<MemRefType>(opOperand->get().getType());
});
return result;
}]
Expand All @@ -144,7 +144,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::copy_if(getInputOperands(),
std::back_inserter(result),
[](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>();
return isa<RankedTensorType>(opOperand->get().getType());
});
return result;
}]
Expand Down Expand Up @@ -200,7 +200,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::copy_if(getOutputOperands(),
std::back_inserter(result),
[](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>();
return isa<MemRefType>(opOperand->get().getType());
});
return result;
}]
Expand All @@ -219,7 +219,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::copy_if(getOutputOperands(),
std::back_inserter(result),
[](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>();
return isa<RankedTensorType>(opOperand->get().getType());
});
return result;
}]
Expand All @@ -238,7 +238,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::transform(getOutputBufferOperands(),
std::back_inserter(result),
[](OpOperand *opOperands) {
return opOperands->get().getType().cast<MemRefType>();
return cast<MemRefType>(opOperands->get().getType());
});
return result;
}]
Expand All @@ -257,7 +257,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
llvm::transform(getOutputTensorOperands(),
std::back_inserter(result),
[](OpOperand *opOperands) {
return opOperands->get().getType().cast<RankedTensorType>();
return cast<RankedTensorType>(opOperands->get().getType());
});
return result;
}]
Expand Down Expand Up @@ -318,7 +318,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (!opOperand->get().getType().template isa<RankedTensorType>())
if (!isa<RankedTensorType>(opOperand->get().getType()))
return false;
if (opOperand->getOperandNumber() < $_op.getNumInputs())
return true;
Expand All @@ -334,7 +334,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
if (!opOperand->get().getType().template isa<RankedTensorType>())
if (!isa<RankedTensorType>(opOperand->get().getType()))
return false;
if (opOperand->getOperandNumber() >= $_op.getNumInputs())
return true;
Expand Down Expand Up @@ -367,7 +367,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
if (auto shapedType =
opOperand->get().getType().template dyn_cast<ShapedType>())
dyn_cast<ShapedType>(opOperand->get().getType()))
return shapedType.getRank();
return 0;
}]
Expand All @@ -383,7 +383,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
if (auto shapedType =
opOperand->get().getType().template dyn_cast<ShapedType>())
dyn_cast<ShapedType>(opOperand->get().getType()))
return shapedType.getShape();
return {};
}]
Expand All @@ -398,7 +398,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
return !opOperand->get().getType().template isa<ShapedType>();
return !isa<ShapedType>(opOperand->get().getType());
}]
>,
//===------------------------------------------------------------------===//
Expand All @@ -416,10 +416,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
return this->getOperation()->getNumResults() == 0 &&
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
return isScalar(opOperand) ||
opOperand->get().getType().template isa<MemRefType>();
isa<MemRefType>(opOperand->get().getType());
}) &&
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>();
return isa<MemRefType>(opOperand->get().getType());
});
}]
>,
Expand All @@ -435,10 +435,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {
return
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
return isScalar(opOperand) ||
opOperand->get().getType().template isa<RankedTensorType>();
isa<RankedTensorType>(opOperand->get().getType());
}) &&
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>();
return isa<RankedTensorType>(opOperand->get().getType());
});
}]
>,
Expand Down Expand Up @@ -478,8 +478,8 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> {

private:
void setOperandSegmentAt(unsigned idx, unsigned val) {
auto attr = (*this)->getAttr("operand_segment_sizes")
.cast<DenseIntElementsAttr>();
auto attr = cast<DenseIntElementsAttr>((*this)->getAttr("operand_segment_sizes")
);
unsigned i = 0;
auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
Expand Down
24 changes: 12 additions & 12 deletions include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def TMTensor_ScanOp : TMTensor_Op<"scan",
return getOutputOperand(0)->get();
}
ShapedType getOperandType() {
return input().getType().cast<ShapedType>();
return cast<ShapedType>(input().getType());
}
int64_t getOperandRank() {
return getOperandType().getRank();
Expand Down Expand Up @@ -151,10 +151,10 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{

int64_t getIndexDepth() {
return getInputOperand(1)
return cast<ShapedType>(getInputOperand(1)
->get()
.getType()
.cast<ShapedType>()
)
.getShape()
.back();
}
Expand All @@ -164,27 +164,27 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter",
}

ShapedType getUpdateType() {
return updates().getType().cast<ShapedType>();
return cast<ShapedType>(updates().getType());
}

Value indices() {
return getInputOperand(1)->get();
}

ShapedType getIndicesType() {
return indices().getType().cast<ShapedType>();
return cast<ShapedType>(indices().getType());
}

Value original() {
return getOutputOperand(0)->get();
}

ShapedType getOriginalType() {
return original().getType().cast<ShapedType>();
return cast<ShapedType>(original().getType());
}

int64_t getUpdateSliceRank() {
return updates().getType().cast<ShapedType>().getRank() - 1;
return cast<ShapedType>(updates().getType()).getRank() - 1;
}

bool isScalarUpdate() {
Expand Down Expand Up @@ -224,7 +224,7 @@ def TMTensor_SortOp : TMTensor_Op<"sort",
return getOutputs()[index];
}
ShapedType getOperandType(int index) {
return operand(index).getType().cast<ShapedType>();
return cast<ShapedType>(operand(index).getType());
}
int64_t getOperandRank() {
return getOperandType(0).getRank();
Expand Down Expand Up @@ -291,16 +291,16 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
return getOutputOperand(0)->get();
}
ShapedType getQueryType() {
return getQuery().getType().cast<ShapedType>();
return cast<ShapedType>(getQuery().getType());
}
ShapedType getKeyType() {
return getKey().getType().cast<ShapedType>();
return cast<ShapedType>(getKey().getType());
}
ShapedType getValueType() {
return getValue().getType().cast<ShapedType>();
return cast<ShapedType>(getValue().getType());
}
ShapedType getOutputType() {
return getOutput().getType().cast<ShapedType>();
return cast<ShapedType>(getOutput().getType());
}
int64_t getQueryRank() {
return getQueryType().getRank();
Expand Down
6 changes: 3 additions & 3 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ struct onnx_list_of_constant_ints_op_binder {

bool match(Operation *op) {
auto constOp = dyn_cast<Torch::OperatorOp>(op);
if (!constOp || !constOp.getName().equals("onnx.Constant"))
if (!constOp || !(constOp.getName() == "onnx.Constant"))
return false;

if (DenseResourceElementsAttr attr =
constOp->getAttr("torch.onnx.value")
.dyn_cast_or_null<DenseResourceElementsAttr>()) {
dyn_cast_or_null<DenseResourceElementsAttr>(
constOp->getAttr("torch.onnx.value"))) {
// Bytes are stored in little endian order. Big endian support will
// require swizzling.
if (!Endian::little) {
Expand Down
2 changes: 1 addition & 1 deletion include/torch-mlir/Dialect/Torch/IR/TorchOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ struct torch_list_of_optional_constant_ints_op_binder {
int64_t num;
if (matchPattern(value, m_TorchConstantInt(&num)))
bind_values.push_back(num);
else if (value.getType().isa<Torch::NoneType>())
else if (isa<Torch::NoneType>(value.getType()))
bind_values.push_back(std::nullopt);
else
return false;
Expand Down
10 changes: 5 additions & 5 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,8 @@ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [
}];

let extraClassDeclaration = [{
Type getKeyType() { return getType().cast<DictType>().getKeyType(); }
Type getValueType() { return getType().cast<DictType>().getValueType(); }
Type getKeyType() { return cast<DictType>(getType()).getKeyType(); }
Type getValueType() { return cast<DictType>(getType()).getValueType(); }
}];
}

Expand Down Expand Up @@ -1003,7 +1003,7 @@ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"operand is corresponding !torch.vtensor",
"result", "operand",
"$_self.cast<NonValueTensorType>().getWithValueSemantics()">,
"cast<NonValueTensorType>($_self).getWithValueSemantics()">,
]> {
let summary = "Create a !torch.tensor with the same contents as the operand";
let description = [{
Expand Down Expand Up @@ -1036,7 +1036,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"operand is corresponding !torch.tensor",
"result", "operand",
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">,
"cast<ValueTensorType>($_self).getWithoutValueSemantics()">,
]> {
let summary = "Create a !torch.vtensor with the same contents as the operand";
let description = [{
Expand Down Expand Up @@ -1064,7 +1064,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [
def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [
TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type",
"value", "overwritten",
"$_self.cast<ValueTensorType>().getWithoutValueSemantics()">
"cast<ValueTensorType>($_self).getWithoutValueSemantics()">
]> {
let summary = "Ovewrite the contents of tensor with values from another.";
let description = [{
Expand Down
6 changes: 3 additions & 3 deletions include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def Torch_ValueTensorType : AnyTorchTensorType<"ValueTensor", "vtensor"> {
}

def AnyTorchTensorType : Type<
CPred<"$_self.isa<::mlir::torch::Torch::BaseTensorType>()">,
CPred<"isa<::mlir::torch::Torch::BaseTensorType>($_self)">,
"Any Torch tensor type"
>;

Expand Down Expand Up @@ -410,11 +410,11 @@ def AnyTorchOptionalDeviceType:
def AnyTorchOptionalGeneratorType:
OptionalOf<Torch_GeneratorType, "Optional torch Generator type">;

def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">;
def IsListTypePred : CPred<"isa<::mlir::torch::Torch::ListType>($_self)">;
class ListOf<list<Type> allowedTypes, string descr> :
ContainerType<AnyTypeOf<allowedTypes>,
IsListTypePred,
"$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()",
"cast<::mlir::torch::Torch::ListType>($_self).getContainedType()",
descr, "::mlir::torch::Torch::ListType">;

def AnyTorchListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
Expand Down
Loading
Loading