diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp index 42307b744c0c..2f85e2375e64 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/AggregatedOpInterfaceImpl.cpp @@ -229,9 +229,12 @@ static Value applyMask(OpBuilder &builder, Location loc, AffineMap qkMap, Value maskVal = args[0]; // TODO: Replace bool mask condition once treated as i1 (instead of i8) - if (maskVal.getType().isInteger()) { - maskVal = - b.create(loc, builder.getI1Type(), maskVal); + auto maskValType = maskVal.getType(); + if (maskValType.isInteger()) { + if (maskValType.getIntOrFloatBitWidth() != 1) { + maskVal = + b.create(loc, builder.getI1Type(), maskVal); + } maskVal = b.create(loc, maskVal, zero, negInf); } else { maskVal = convertScalarToDtype(b, loc, maskVal, qkVal.getType(),