Integrate llvm 1_20_2025 #19740

Merged · 6 commits · Jan 20, 2025
@@ -37,7 +37,7 @@ static Value getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder,
return val;
matchPattern(val, m_Constant(&attr));
} else {
- attr = llvm::cast<IntegerAttr>(attrOrValue.get<Attribute>());
+ attr = llvm::cast<IntegerAttr>(cast<Attribute>(attrOrValue));
}
return builder.createOrFold<arith::ConstantIndexOp>(
loc, attr.getValue().getSExtValue());
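Note: most of the mechanical changes in this integrate move from the PointerUnion member API (`ofr.is<T>()` / `ofr.get<T>()`) to the free-function casts (`isa<T>(ofr)` / `cast<T>(ofr)` / `dyn_cast<T>(ofr)`), which upstream has been steering toward. A minimal sketch of the pattern, assuming MLIR headers; the helper name is illustrative and not part of this PR:

```cpp
#include <cstdint>
#include <optional>

#include "mlir/IR/BuiltinAttributes.h" // IntegerAttr
#include "mlir/IR/OpDefinition.h"      // OpFoldResult, Value

using namespace mlir;

// Illustrative helper: read a static integer out of an OpFoldResult, if any.
static std::optional<int64_t> getStaticInt(OpFoldResult ofr) {
  // New style: free-function casts work directly on the union.
  if (auto attr = dyn_cast<Attribute>(ofr)) {
    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
      return intAttr.getInt();
    return std::nullopt;
  }
  // Otherwise the OpFoldResult wraps an SSA Value, i.e. a dynamic quantity.
  Value val = cast<Value>(ofr);
  (void)val;
  return std::nullopt;
}
```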
@@ -145,9 +145,7 @@ static LogicalResult commonRunOnOperation(
auto packOp = cast<tensor::PackOp>(op);

// Do nothing if any of inner tile sizes is dynamic.
- if (llvm::any_of(packOp.getMixedTiles(), [](OpFoldResult tile) {
-   return tile.is<Value>();
- })) {
+ if (llvm::any_of(packOp.getMixedTiles(), llvm::IsaPred<Value>)) {
return {};
}

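The lambda-free form uses `llvm::IsaPred`, a ready-made predicate object wrapping `isa<T>`. A hedged sketch of the idea, assuming the LLVM/MLIR headers named below; `tiles` stands in for `packOp.getMixedTiles()`:

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"    // llvm::any_of
#include "llvm/Support/Casting.h"  // llvm::IsaPred (assumed header)
#include "mlir/IR/OpDefinition.h"  // mlir::OpFoldResult, mlir::Value

// Returns true if any tile size is dynamic (i.e. carried by an SSA Value).
static bool hasDynamicTile(llvm::ArrayRef<mlir::OpFoldResult> tiles) {
  // Equivalent to: any_of(tiles, [](OpFoldResult t) { return isa<Value>(t); });
  return llvm::any_of(tiles, llvm::IsaPred<mlir::Value>);
}
```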
@@ -38,7 +38,7 @@ static Value getAsIndexValue(OpFoldResult attrOrValue, OpBuilder &builder,
return val;
matchPattern(val, m_Constant(&attr));
} else {
- attr = llvm::cast<IntegerAttr>(attrOrValue.get<Attribute>());
+ attr = cast<IntegerAttr>(cast<Attribute>(attrOrValue));
}
return builder.createOrFold<arith::ConstantIndexOp>(loc, attr.getInt());
}
@@ -101,12 +101,12 @@ struct VectorizePadWithConditions final

/// Return true if the given `attrOrValue` is a constant zero.
auto isConstantZero = [](OpFoldResult attrOrValue) {
- if (attrOrValue.is<Attribute>()) {
- auto attr = llvm::dyn_cast<IntegerAttr>(attrOrValue.get<Attribute>());
- return attr && attr.getValue().getZExtValue() == 0;
+ if (auto attr = dyn_cast<Attribute>(attrOrValue)) {
+ auto intAttr = dyn_cast<IntegerAttr>(attr);
+ return intAttr && intAttr.getValue().getZExtValue() == 0;
}
IntegerAttr attr;
- return matchPattern(attrOrValue.get<Value>(), m_Constant(&attr)) &&
+ return matchPattern(cast<Value>(attrOrValue), m_Constant(&attr)) &&
attr.getValue().getZExtValue() == 0;
};

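Pulled together out of the diff above, the rewritten zero test handles the two union cases: an attribute-backed constant directly, and a Value via `matchPattern`/`m_Constant`. A sketch assuming MLIR's Matchers.h; the free-standing function form is illustrative:

```cpp
#include "mlir/IR/BuiltinAttributes.h" // IntegerAttr
#include "mlir/IR/Matchers.h"          // matchPattern, m_Constant
#include "mlir/IR/OpDefinition.h"      // OpFoldResult

using namespace mlir;

static bool isConstantZero(OpFoldResult ofr) {
  // Attribute case: a static value already folded into the op.
  if (auto attr = dyn_cast<Attribute>(ofr)) {
    auto intAttr = dyn_cast<IntegerAttr>(attr);
    return intAttr && intAttr.getValue().getZExtValue() == 0;
  }
  // Value case: pattern-match a constant-defining op.
  IntegerAttr attr;
  return matchPattern(cast<Value>(ofr), m_Constant(&attr)) &&
         attr.getValue().getZExtValue() == 0;
}
```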
@@ -305,7 +305,7 @@ void DistributionLayout::print(raw_ostream &os) const {
void DistributionLayout::onUpdate(DataFlowSolver *solver) const {
AnalysisState::onUpdate(solver);

- Value value = anchor.get<Value>();
+ Value value = cast<Value>(anchor);

if (propagation) {
// Make propagation run again on all users of this value.
@@ -13,7 +13,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
- // CHECK: %[[CST:.+]] = arith.constant -3.40282347E+38 : f32
+ // CHECK: %[[CST:.+]] = arith.constant 0xFFC00000 : f32
// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
// CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
// CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
@@ -49,7 +49,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK-NO-FUSE: func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK-NO-FUSE: %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK-NO-FUSE: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
- // CHECK-NO-FUSE: %[[CST:.+]] = arith.constant -3.40282347E+38 : f32
+ // CHECK-NO-FUSE: %[[CST:.+]] = arith.constant 0xFFC00000 : f32
// CHECK-NO-FUSE: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
// CHECK-NO-FUSE: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
// CHECK-NO-FUSE-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/DispatchABI.cpp
@@ -843,12 +843,12 @@ MemRefDescriptor HALDispatchABI::loadBinding(Operation *forOp, int64_t ordinal,
currentStrideVal = builder.create<LLVM::ConstantOp>(
loc, llvmIndexType, currentStrideInt.value());
} else {
- currentStrideVal = currentStride.get<Value>();
+ currentStrideVal = cast<Value>(currentStride);
}
currentStride =
builder.create<LLVM::MulOp>(loc, currentStrideVal, dim)
.getResult();
- desc.setStride(builder, loc, i - 1, currentStride.get<Value>());
+ desc.setStride(builder, loc, i - 1, cast<Value>(currentStride));
} else {
currentStride = builder.getIndexAttr(strides[i - 1]);
desc.setConstantStride(builder, loc, i - 1, strides[i - 1]);
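The surrounding loop in `loadBinding` (only partly visible in this hunk) materializes row-major strides for the memref descriptor: each outer stride is the next-inner stride multiplied by that dimension's size, emitted as an `LLVM::MulOp` when the size is dynamic. A plain-C++ sketch of the same recurrence, with hypothetical names rather than the IREE API:

```cpp
#include <cstdint>
#include <vector>

// Row-major strides: stride[rank-1] = 1, stride[i-1] = stride[i] * dim[i].
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &dims) {
  std::vector<int64_t> strides(dims.size(), 1);
  for (size_t i = dims.size(); i > 1; --i)
    strides[i - 2] = strides[i - 1] * dims[i - 1];
  return strides;
}
// For dims = {2, 3, 4} this yields {12, 4, 1}; the compiler emits the same
// products as IR, folding them to constants when every size is static.
```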
@@ -1788,8 +1788,8 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
// backends prefer to not decompose the ops.
DictionaryAttr pipelineConfig;
auto target = IREE::HAL::ExecutableTargetAttr::lookup(entryPointFn);
- bool hasDynamicInnerTile = llvm::any_of(
-   op.getMixedTiles(), [](OpFoldResult ofr) { return ofr.is<Value>(); });
+ bool hasDynamicInnerTile =
+   llvm::any_of(op.getMixedTiles(), llvm::IsaPred<Value>);
if (!hasDynamicInnerTile && !isX86(target) && !isRISCV(target)) {
pipelineConfig = getPipelineConfWithDecompositionAttr(op.getContext());
}
@@ -1828,8 +1828,8 @@ static LogicalResult setRootConfig(mlir::FunctionOpInterface entryPointFn,
// backends prefer to not decompose the ops.
DictionaryAttr pipelineConfig;
auto target = IREE::HAL::ExecutableTargetAttr::lookup(entryPointFn);
- bool hasDynamicInnerTile = llvm::any_of(
-   op.getMixedTiles(), [](OpFoldResult ofr) { return ofr.is<Value>(); });
+ bool hasDynamicInnerTile =
+   llvm::any_of(op.getMixedTiles(), llvm::IsaPred<Value>);
if (!hasDynamicInnerTile && !isX86(target) && !isRISCV(target)) {
pipelineConfig = getPipelineConfWithDecompositionAttr(op.getContext());
}
@@ -684,7 +684,7 @@ Type mlirType(MLIRContext *context, MMTKernel::ScalarType t) {
case MMTKernel::ScalarType::I32:
return IntegerType::get(context, 32, IntegerType::Signless);
case MMTKernel::ScalarType::F32:
- return FloatType::getF32(context);
+ return Float32Type::get(context);
}
assert(false);
return Type();
@@ -433,12 +433,12 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
currentStrideVal = rewriter.create<LLVM::ConstantOp>(
loc, llvmIndexType, currentStrideInt.value());
} else {
- currentStrideVal = currentStride.get<Value>();
+ currentStrideVal = cast<Value>(currentStride);
}
currentStride =
rewriter.create<LLVM::MulOp>(loc, currentStrideVal, dim)
.getResult();
- desc.setStride(rewriter, loc, i - 1, currentStride.get<Value>());
+ desc.setStride(rewriter, loc, i - 1, cast<Value>(currentStride));
} else {
currentStride = rewriter.getIndexAttr(strides[i - 1]);
desc.setConstantStride(rewriter, loc, i - 1, strides[i - 1]);
@@ -38,7 +38,7 @@ getPaddedShapeFromTensorLoad(IREE::Flow::DispatchTensorLoadOp tensorLoad,
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::UB,
- {size.get<Value>(), /*dim=*/std::nullopt},
+ {cast<Value>(size), /*dim=*/std::nullopt},
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(upperBound))
return failure();
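Here `computeConstantBound` asks the value-bounds analysis for a constant, closed upper bound on a dynamic tensor size so the load can be padded to a static shape. A hedged sketch of the call shape, mirroring the diff; the wrapper function is illustrative:

```cpp
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;

// Returns a constant (closed) upper bound for `size`, if the analysis finds one.
static FailureOr<int64_t> constantUpperBound(Value size) {
  return ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, {size, /*dim=*/std::nullopt},
      /*stopCondition=*/nullptr, /*closedUB=*/true);
}
// Callers typically bail out (return failure()) when no bound is found,
// exactly as getPaddedShapeFromTensorLoad does above.
```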
@@ -195,7 +195,6 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) {
// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
// CHECK: arith.maxnumf
- // CHECK: arith.maxnumf
// CHECK: vector.broadcast %{{.*}} : f32 to vector<4xf32>
Comment on lines 197 to 198

[Member] What happened here?

[Contributor, Author] @nirvedhmeshram (Jan 20, 2025): I am not sure; that maxnumf is no longer present in the IR, and overall a lot of the reduction codegen seems to have changed. I wonder if it is related to the value-bounds interface, like this commit: llvm/llvm-project#122804. I will flag this on Discord.

[Contributor, Author] @nirvedhmeshram: Actually, I think what is happening here can be explained by this commit: llvm/llvm-project#118952. I believe we were using -INF in the softmax decomposition when we wanted NaN; with NaN there is some simplification, since maxnumf simply picks the other argument, so one fewer maxnumf is needed after folding.

[Contributor] Let's check what this does to correctness... We might need to revert it. I missed it, but I don't think using NaN makes sense; it should be -INF.

[Contributor, Author] @nirvedhmeshram (Jan 20, 2025): The CI is not finding any errors, and I believe there are softmax dispatches checked into the regression tests, so I think this works. There is some discussion on this here: llvm/llvm-project#114595. NaN is safe with maxnumf, but whether we should have used maximumf and kept -INF is something to decide. Whether maxnumf should fold the same way for -INF is also something to think over.

// CHECK: scf.for {{.*}} -> (vector<4xf32>) {
// CHECK: vector.transfer_read
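The dropped `arith.maxnumf` is consistent with the new fill value above: maxnum-style semantics ignore a NaN operand and return the other operand, so a NaN-initialized accumulator (the `0xFFC00000` constant) is an identity for the max reduction and one combine folds away. A small self-contained illustration with `std::fmax`, which has the same NaN behavior; this is an analogy, not the compiler code:

```cpp
#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float ninf = -std::numeric_limits<float>::infinity();

  // fmax (like arith.maxnumf) returns the non-NaN operand when one is NaN,
  // so NaN acts as an identity element for a max reduction...
  assert(std::fmax(nan, 2.0f) == 2.0f);
  // ...just as -inf does for ordinary inputs.
  assert(std::fmax(ninf, 2.0f) == 2.0f);

  // The two identities differ only when every element is NaN:
  // a NaN seed then yields NaN, while a -inf seed yields -inf.
  return 0;
}
```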
@@ -1000,11 +1000,11 @@ struct ReifyExtractOfCreateMask final
for (auto [idx, size] :
llvm::zip_equal(extractOp.getMixedPosition(), maskOp.getOperands())) {
Value idxVal;
- if (idx.is<Attribute>()) {
+ if (auto attr = dyn_cast<Attribute>(idx)) {
idxVal = rewriter.create<arith::ConstantIndexOp>(
- loc, cast<IntegerAttr>(idx.get<Attribute>()).getInt());
+ loc, dyn_cast<IntegerAttr>(attr).getInt());
} else {
- idxVal = idx.get<Value>();
+ idxVal = dyn_cast<Value>(idx);
}
Value cmpIdx = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, idxVal, size);
@@ -229,7 +229,6 @@ func.func @softmax() attributes {hal.executable.target = #executable_target_vulk
// CHECK: arith.maxnumf
// CHECK: gpu.shuffle xor
// CHECK: arith.maxnumf
- // CHECK: arith.maxnumf
// CHECK: vector.splat %{{.*}} : vector<4xf32>
// CHECK: scf.for {{.*}} -> (vector<4xf32>) {
// CHECK: vector.transfer_read
@@ -76,18 +76,18 @@ hal.executable @i4_dequant_unit_matmul_f16 {

// CHECK-DAG: %[[CSTVEC4XI32_255:.+]] = spirv.Constant dense<255> : vector<4xi32>
// CHECK-DAG: %[[CSTVEC4XI32_0:.+]] = spirv.Constant dense<0> : vector<4xi32>
- // CHECK-DAG: %[[CSTVEC4XI32_0_4:.+]] = spirv.Constant dense<[0, 4, 0, 4]> : vector<4xi32>
- // CHECK-DAG: %[[CSTVEC4XI32_15__16:.+]] = spirv.Constant dense<[15, -16, 15, -16]> : vector<4xi32>
+ // CHECK-DAG: %[[CSTVEC2XI32_4:.+]] = spirv.Constant dense<4> : vector<2xi32>
+ // CHECK-DAG: %[[CSTVEC2XI32_15:.+]] = spirv.Constant dense<15> : vector<2xi32>

// CHECK: spirv.mlir.loop

// Load the quantized weight and get 8xi4 out of it.
// CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %{{.+}} : vector<4xi32>
// CHECK: %[[SHUF01:.+]] = spirv.VectorShuffle [0 : i32, 1 : i32] %[[LOAD]], %[[LOAD]] : vector<4xi32>, vector<4xi32> -> vector<2xi32>
- // CHECK: %[[SHUF0011:.+]] = spirv.VectorShuffle [0 : i32, 0 : i32, 1 : i32, 1 : i32] %[[SHUF01]], %[[SHUF01]] : vector<2xi32>, vector<2xi32> -> vector<4xi32>
- // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[SHUF0011]], %[[CSTVEC4XI32_15__16]] : vector<4xi32>
- // CHECK: %[[SHIFTED:.+]] = spirv.ShiftRightLogical %[[MASKED]], %[[CSTVEC4XI32_0_4]] : vector<4xi32>, vector<4xi32>
- // CHECK: %[[LOW4HIGH4_ZEROUPPER:.+]] = spirv.BitwiseAnd %[[SHIFTED]], %[[CSTVEC4XI32_255]] : vector<4xi32>
+ // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[SHUF01]], %[[CSTVEC2XI32_15]] : vector<2xi32>
+ // CHECK: %[[SHIFTED:.+]] = spirv.ShiftRightLogical %[[SHUF01]], %[[CSTVEC2XI32_4]] : vector<2xi32>, vector<2xi32>
+ // CHECK: %[[SHUF0011:.+]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[MASKED]], %[[SHIFTED]] : vector<2xi32>, vector<2xi32> -> vector<4xi32>
+ // CHECK: %[[LOW4HIGH4_ZEROUPPER:.+]] = spirv.BitwiseAnd %[[SHUF0011]], %[[CSTVEC4XI32_255]] : vector<4xi32>

// CHECK: %[[SHUF23:.+]] = spirv.VectorShuffle [2 : i32, 3 : i32] %[[LOAD:.+]], %[[LOAD:.+]] : vector<4xi32>, vector<4xi32> -> vector<2xi32>

@@ -186,8 +186,6 @@ hal.executable @i4_dequant_matvec_f16_subgroup_64 {
// CHECK-DAG: %[[C0:.+]] = spirv.Constant 0 : i32
// CHECK-DAG: %[[CSTVEC4XF16_1:.+]] = spirv.Constant dense<1.000000e+00> : vector<4xf16>
// CHECK-DAG: %[[CSTVEC4XI32_255:.+]] = spirv.Constant dense<255> : vector<4xi32>
- // CHECK-DAG: %[[CSTVEC2XI32_1:.+]] = spirv.Constant dense<[0, 4, 0, 4]> : vector<4xi32>
- // CHECK-DAG: %[[CSTVEC2XI32_2:.+]] = spirv.Constant dense<[15, -16, 15, -16]> : vector<4xi32>

// CHECK: %[[WIDX:.+]] = spirv.CompositeExtract %{{.*}}[0 : i32] : vector<3xi32>
// CHECK: %[[PCPTR:.+]] = spirv.AccessChain %{{.*}}[{{.*}}, %[[C0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<5 x i32, stride=4> [0])>, PushConstant>, i32, i32
@@ -209,9 +207,6 @@ hal.executable @i4_dequant_matvec_f16_subgroup_64 {
// CHECK: %[[ACCESS:.+]] = spirv.AccessChain %[[RADDR]][{{.*}}, %[[OFFSET]]] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: spirv.Load "StorageBuffer" %[[ACCESS]] : i32

- // CHECK: spirv.ShiftRightLogical %{{.*}}, %[[CSTVEC2XI32_1]] : vector<4xi32>, vector<4xi32>
- // CHECK: spirv.BitwiseAnd %{{.*}}, %[[CSTVEC4XI32_255]] : vector<4xi32>

// CHECK: spirv.ConvertUToF %{{.+}} : vector<4xi32> to vector<4xf16>
// CHECK: spirv.FSub %{{.+}}, %{{.+}} : vector<4xf16>
// CHECK-COUNT-2: spirv.FMul %{{.+}}, %{{.+}} : vector<4xf16>
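The updated SPIR-V codegen masks and shifts the packed words first and only then interleaves the low/high nibbles with a single `[0, 2, 1, 3]` shuffle, instead of duplicating the word and applying per-lane masks. The underlying bit manipulation is ordinary i4 unpacking; a self-contained C++ sketch of extracting two 4-bit values from a byte, illustrative rather than the generated kernel:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// Unpack one byte holding two i4 values: low nibble first, then high nibble.
std::array<uint8_t, 2> unpackI4(uint8_t packed) {
  uint8_t low = packed & 0xF;         // mask: keep bits 0..3
  uint8_t high = (packed >> 4) & 0xF; // shift bits 4..7 down, then mask
  return {low, high};
}

int main() {
  auto [lo, hi] = unpackI4(0xA3); // 0xA3 -> low = 0x3, high = 0xA
  std::printf("low=%u high=%u\n", lo, hi);
  return 0;
}
```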
@@ -45,12 +45,16 @@ hal.executable @i4_dequant {
// CHECK-LABEL: spirv.func @i4_dequant()

// CHECK: %[[BYTE1:.+]] = spirv.VectorShuffle [0 : i32, 1 : i32] {{.*}} : vector<4xi32>, vector<4xi32> -> vector<2xi32>
- // CHECK: %[[COPIED:.+]] = spirv.VectorShuffle [0 : i32, 0 : i32, 1 : i32, 1 : i32] %[[BYTE1]], %[[BYTE1]] : vector<2xi32>, vector<2xi32> -> vector<4xi32>
- // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[COPIED]]
- // CHECK: %[[SHIFTED:.+]] = spirv.ShiftRightLogical %[[MASKED]]
- // CHECK: %[[ZEROUPPER:.+]] = spirv.BitwiseAnd %[[SHIFTED]]
+ // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %[[BYTE1]]
+ // CHECK: %[[SHIFTED:.+]] = spirv.ShiftRightLogical %[[BYTE1]]
+ // CHECK: %[[COPIED:.+]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[MASKED]], %[[SHIFTED]] : vector<2xi32>, vector<2xi32> -> vector<4xi32>
+ // CHECK: %[[MASKED2:.+]] = spirv.BitwiseAnd %[[COPIED]]
// CHECK: spirv.VectorShuffle [2 : i32, 3 : i32] {{.*}} : vector<4xi32>, vector<4xi32> -> vector<2xi32>
- // CHECK-COUNT-3: spirv.VectorShuffle [0 : i32, 0 : i32, 1 : i32, 1 : i32]
+ // CHECK: spirv.VectorShuffle [0 : i32, 1 : i32]
+ // CHECK: spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32]
+ // CHECK: spirv.VectorShuffle [2 : i32, 3 : i32]
+ // CHECK: spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32]
+ // CHECK-NOT: spirv.VectorShuffle

// CHECK-COUNT-4: spirv.ConvertUToF {{.+}} : vector<4xi32> to vector<4xf32>
// CHECK-COUNT-4: spirv.FSub
@@ -67,10 +67,10 @@ static SliceAndDynamicDims cloneOffsetsSizesAndStridesImpl(
SmallVector<OpFoldResult> clonedOfrs;
clonedOfrs.reserve(ofrs.size());
for (auto ofr : ofrs) {
- if (ofr.is<Attribute>()) {
+ if (isa<Attribute>(ofr)) {
clonedOfrs.push_back(ofr);
} else {
- clonedOfrs.push_back(bvm.lookupOrDefault(ofr.get<Value>()));
+ clonedOfrs.push_back(bvm.lookupOrDefault(cast<Value>(ofr)));
}
}
return clonedOfrs;
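When offsets/sizes/strides are cloned, only the dynamic entries need remapping: attribute-backed (static) `OpFoldResult`s are copied verbatim, while `Value`s are translated through the mapping, falling back to the original when unmapped. A short sketch of that split, mirroring the diff and assuming MLIR's `IRMapping`:

```cpp
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"

using namespace mlir;

static SmallVector<OpFoldResult> cloneOfrs(ArrayRef<OpFoldResult> ofrs,
                                           IRMapping &bvm) {
  SmallVector<OpFoldResult> cloned;
  cloned.reserve(ofrs.size());
  for (OpFoldResult ofr : ofrs) {
    if (isa<Attribute>(ofr))
      cloned.push_back(ofr); // static entry: nothing to remap
    else
      cloned.push_back(bvm.lookupOrDefault(cast<Value>(ofr)));
  }
  return cloned;
}
```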
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -1013,7 +1013,7 @@ def FLOW_CallOp : FLOW_Op<"call", [

/// Set the callee for this operation.
void setCalleeFromCallable(CallInterfaceCallable callee) {
- (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+ (*this)->setAttr("callee", cast<SymbolRefAttr>(callee));
}

ValueRange getOperandDynamicDims(unsigned idx) {
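`CallInterfaceCallable` is a union of a `SymbolRefAttr` (direct call) and a `Value` (indirect call), so the setter above can only store the symbol form, and the same free-function cast style applies. A hedged sketch of handling both cases, assuming MLIR's CallInterfaces; the function itself is illustrative:

```cpp
#include "mlir/Interfaces/CallInterfaces.h"

using namespace mlir;

// Direct calls carry a SymbolRefAttr; indirect calls carry an SSA Value.
static void describeCallee(CallOpInterface callOp) {
  CallInterfaceCallable callable = callOp.getCallableForCallee();
  if (auto symbol = dyn_cast<SymbolRefAttr>(callable)) {
    // Direct: resolvable through symbol tables (as Explorer.cpp does below).
    (void)symbol;
  } else {
    // Indirect: the callee is only known at runtime.
    Value fnValue = cast<Value>(callable);
    (void)fnValue;
  }
}
```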
@@ -293,7 +293,7 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
return failure();
for (int64_t i = 0; i < shapedType.getRank(); ++i)
if (shapedType.isDynamicDim(i))
- dynamicDims.push_back(dims[opResult.getResultNumber()][i].get<Value>());
+ dynamicDims.push_back(cast<Value>(dims[opResult.getResultNumber()][i]));
return success();
}

@@ -2578,13 +2578,13 @@ module attributes { transform.with_named_sequence } {
transform.yield
}
}
- // CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
+ // CHECK: #[[MAP:.+]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
// CHECK: func @custom_op_index_handling(%[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xindex>,
// CHECK: scf.forall (%[[IV:[a-zA-Z0-9]+]],
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: iree_linalg_ext.custom_op
// CHECK-SAME: ins(%[[SLICE]]
// CHECK: %[[NEW_INDEX:.+]] = iree_linalg_ext.index 0 : index
- // CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]](%[[NEW_INDEX]], %[[IV]])
+ // CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]](%[[IV]])[%[[NEW_INDEX]]]
// CHECK: linalg.generic
// CHECK-SAME: ins(%{{.+}}, %[[INDEX]] :
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
@@ -2642,7 +2642,7 @@ def Stream_AsyncCallOp : Stream_Op<"async.call", [

/// Set the callee for this operation.
void setCalleeFromCallable(CallInterfaceCallable callee) {
- (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+ (*this)->setAttr("callee", cast<SymbolRefAttr>(callee));
}

Value getOperandSize(unsigned idx) {
@@ -3322,7 +3322,7 @@ def Stream_CmdCallOp : Stream_Op<"cmd.call", [

/// Set the callee for this operation.
void setCalleeFromCallable(CallInterfaceCallable callee) {
- (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+ (*this)->setAttr("callee", cast<SymbolRefAttr>(callee));
}

Value getOperandSize(unsigned idx) {
10 changes: 5 additions & 5 deletions compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
@@ -164,7 +164,7 @@ void Explorer::initializeGlobalInfos() {
void Explorer::initializeInverseCallGraph() {
forEachFunctionLikeOp([&](FunctionOpInterface parentOp) {
parentOp->walk([&](CallOpInterface callOp) {
- if (callOp.getCallableForCallee().is<Value>()) {
+ if (isa<Value>(callOp.getCallableForCallee())) {
// Indirect calls can't be tracked in the call graph, so ensure we mark
// the incomplete flag so that any call graph queries return
// TraversalResult::INCOMPLETE.
@@ -777,7 +777,7 @@ TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn,
// Indirect calls would require us to perform an analysis to first see if we
// can make them direct or annotate the call sites with the possible
// targets.
- if (callOp.getCallableForCallee().is<Value>()) {
+ if (isa<Value>(callOp.getCallableForCallee())) {
LLVM_DEBUG({
llvm::dbgs()
<< " !! traversal incomplete due to unanalyzable indirect call: ";
@@ -786,7 +786,7 @@ TraversalResult Explorer::walkDefiningOps(Value value, ResultWalkFn fn,
});
return TraversalResult::INCOMPLETE;
}
- auto targetSymbol = callOp.getCallableForCallee().get<SymbolRefAttr>();
+ auto targetSymbol = cast<SymbolRefAttr>(callOp.getCallableForCallee());
auto targetOp = symbolTables.lookupNearestSymbolFrom<CallableOpInterface>(
callOp, targetSymbol);
assert(targetOp && "call target not found");
@@ -1031,7 +1031,7 @@ TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn,
// Move across a call to the callee entry block.
auto traverseCallOp = [&](CallOpInterface callOp, unsigned operandIdx) {
auto callable = callOp.getCallableForCallee();
- if (callable.is<Value>()) {
+ if (isa<Value>(callable)) {
LLVM_DEBUG({
llvm::dbgs()
<< " !! traversal incomplete due to unanalyzable indirect call: ";
@@ -1040,7 +1040,7 @@ TraversalResult Explorer::walkTransitiveUses(Value value, UseWalkFn fn,
});
return TraversalResult::INCOMPLETE;
}
- auto targetSymbol = callable.get<SymbolRefAttr>();
+ auto targetSymbol = cast<SymbolRefAttr>(callable);
auto targetOp = symbolTables.lookupNearestSymbolFrom<CallableOpInterface>(
callOp, targetSymbol);
assert(targetOp && "call target not found");
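Explorer only follows direct calls: an indirect callee (a `Value`) cannot be resolved statically, so the traversal is marked `TraversalResult::INCOMPLETE`; a direct callee is resolved through the symbol tables, as the diff shows. A compressed sketch of that resolution step, assuming MLIR's symbol-table APIs and with error handling elided:

```cpp
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"

using namespace mlir;

// Resolve a direct call to its callable op; returns a null interface for
// indirect calls, which the caller should treat as an incomplete traversal.
static CallableOpInterface resolveCallee(SymbolTableCollection &symbolTables,
                                         CallOpInterface callOp) {
  CallInterfaceCallable callable = callOp.getCallableForCallee();
  if (isa<Value>(callable))
    return CallableOpInterface(); // indirect call: no static target
  auto targetSymbol = cast<SymbolRefAttr>(callable);
  return symbolTables.lookupNearestSymbolFrom<CallableOpInterface>(
      callOp, targetSymbol);
}
```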
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -820,7 +820,7 @@ def Util_CallOp : Util_Op<"call", [
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
void setCalleeFromCallable(CallInterfaceCallable callee) {
- (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+ (*this)->setAttr("callee", cast<SymbolRefAttr>(callee));
}

// Clones the call and potentially expands each operand and result.