From 71f1e205a3df61e236cf10054798c53fcfe460c9 Mon Sep 17 00:00:00 2001
From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com>
Date: Tue, 6 Aug 2024 13:29:21 -0700
Subject: [PATCH] Revert "Optimize `fp8` `linalg_ext.attention` by rework Q@K
 scaling" (#18112)

Reverts iree-org/iree#18031

Mathematically correct; however, since the scaling (an elementwise-mul
linalg.generic) now happens before the reduction addf, it unexpectedly
becomes a vector.contract after the GenericVectorization pass:

```
%41 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %39, %cst_0, %34 : vector<32x128xf32>, vector<32x128xf32> into vector<32xf32>
```

We can re-land this once VectorDistribute uses vector_ext.multi_mma and
we can turn off generateContract in GenericVectorization.
---
 .../Transforms/AggregatedOpInterfaceImpl.cpp  | 74 +++++++++++++------
 .../test/decompose_online_attention.mlir      | 20 ++--
 2 files changed, 62 insertions(+), 32 deletions(-)

diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
index b4e3d2e027ed..23ba79c2f665 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/AggregatedOpInterfaceImpl.cpp
@@ -90,19 +90,28 @@ static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap,
         auto srcTy = cast<FloatType>(args[0].getType());
         auto dstTy = cast<FloatType>(args[1].getType());
 
+        // We clamp to the min / max of the floating point representation.
+        double mnDbl =
+            APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/true)
+                .convertToDouble();
         double mxDbl =
             APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false)
                 .convertToDouble();
 
         // Truncate to the `fp8` range to avoid nan values.
+        Value mn = builder.create<arith::ConstantOp>(
+            loc, builder.getFloatAttr(srcTy, mnDbl));
         Value mx = builder.create<arith::ConstantOp>(
             loc, builder.getFloatAttr(srcTy, mxDbl));
         Value gt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                            args[0], mx);
+        Value lt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
+                                           args[0], mn);
         Value sel0 = b.create<arith::SelectOp>(loc, gt, mx, args[0]);
+        Value sel1 = b.create<arith::SelectOp>(loc, lt, mn, sel0);
 
         // Convert scale to the same datatype as input.
-        Value trunc = convertScalarToDtype(b, loc, sel0, dstTy,
+        Value trunc = convertScalarToDtype(b, loc, sel1, dstTy,
                                            /*isUnsignedCast=*/false);
         b.create<linalg::YieldOp>(loc, trunc);
       });
@@ -293,7 +302,6 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   // SMap = QMap @ KMap
   Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, elementType);
   Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
-
   Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
   s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
 
@@ -315,6 +323,11 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   AffineMap maxMap = getMaxMap();
   Value newMax = reduce(b, loc, sMap, maxMap, s, oldMax);
 
+  // P = exp2(S - newMax)
+  // PMap = SMap
+  AffineMap pMap = sMap;
+  Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);
+
   // norm = exp2(oldMax - newMax)
   // normMap = maxMap
   AffineMap normMap = getMaxMap();
@@ -324,27 +337,6 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   AffineMap sumMap = getSumMap();
   Value normSum = scaleValueInPlace(b, loc, sumMap, normMap, oldSum, norm);
 
-  // P = exp2(S - newMax)
-  // PMap = SMap
-  AffineMap pMap = sMap;
-  Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);
-
-  // If we need to truncate to fp8 post softmax we apply a scaling to use the
-  // full fp8 range. We can do this with an offset as post `exp2` this equates
-  // to multiplying by a static value. We are able to do this as `max` and `sum`
-  // are scaled by the same value so the end result is the same.
-  if (isa<FloatType>(qETy) && qETy.getIntOrFloatBitWidth() == 8) {
-    auto fpTy = cast<FloatType>(qETy);
-    double mx =
-        APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
-            .convertToDouble();
-    Value offset =
-        b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, mx));
-    AffineMap scaleMap = AffineMap::get(/*dimCount=*/pMap.getNumInputs(),
-                                        /*symbolCount=*/0, getContext());
-    p = scaleValueInPlace(b, loc, pMap, scaleMap, p, offset);
-  }
-
   // newSum = normSum + rowSum(P)
   Value newSum = reduce(b, loc, pMap, sumMap, p, normSum);
 
@@ -352,8 +344,36 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   AffineMap accMap = getOutputMap();
 
   // ---- Scale and truncate LHS to match RHS ----
+  Value pScale;
   auto pETy = getElementTypeOrSelf(p.getType());
   if (pETy != vETy && isa<FloatType>(vETy)) {
+    if (vETy.getIntOrFloatBitWidth() <= 8) {
+      SmallVector<OpFoldResult> mSizes(
+          llvm::map_range(maxMap.getResults(), [&](AffineExpr dimExpr) {
+            return sizes[cast<AffineDimExpr>(dimExpr).getPosition()];
+          }));
+
+      auto fpTy = cast<FloatType>(vETy);
+      double largestDbl =
+          APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
+              .convertToDouble();
+
+      // We normalize p from [0, max] to [0, fp8.max] to guarantee we
+      // use the full `fp8` range, then renormalize post Softmax@V matmul
+      // to correct.
+      pScale = b.create<arith::ConstantOp>(
+          loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / largestDbl));
+
+      // Compute the pre matmul scale to handle fp8 quantization:
+      Value pScaleInv = b.create<arith::ConstantOp>(
+          loc, b.getFloatAttr(elementType, largestDbl / clAttentionSoftmaxMax));
+
+      AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
+                                          /*symbolCount=*/0, getContext());
+      p = scaleValueInPlace(b, loc, pMap, scaleMap, p, pScaleInv);
+      norm = scaleValueInPlace(b, loc, normMap, scaleMap, norm, pScaleInv);
+    }
+
     Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
     p = truncateFloat(b, loc, pMap, pMap, p, convertP);
   }
@@ -364,6 +384,14 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   // newAcc = P @ V + newAcc
   newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value,
                          newAcc);
+
+  // Update for the FP8 dynamic scale:
+  if (pScale) {
+    AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
+                                        /*symbolCount=*/0, getContext());
+    newAcc = scaleValueInPlace(b, loc, accMap, scaleMap, newAcc, pScale);
+  }
+
   return SmallVector<Value>{newAcc, newMax, newSum};
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
index 51fcc6dc2608..3e323eda0c10 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/decompose_online_attention.mlir
@@ -43,15 +43,15 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>,
 // CHECK:        arith.subf
 // CHECK:        math.exp2
 // CHECK:        linalg.yield
-// normSum = norm * oldSum
-// CHECK:      linalg.generic
-// CHECK:        arith.mulf
-// CHECK:        linalg.yield
 // norm = exp2(oldMax - newMax)
 // CHECK:      linalg.generic
 // CHECK:        arith.subf
 // CHECK:        math.exp2
 // CHECK:        linalg.yield
+// normSum = norm * oldSum
+// CHECK:      linalg.generic
+// CHECK:        arith.mulf
+// CHECK:        linalg.yield
 // newSum = normSum + rowMax(P)
 // CHECK:      linalg.generic
 // CHECK:        arith.addf
@@ -107,6 +107,11 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
 // CHECK:      linalg.generic
 // CHECK:        arith.maximumf
 // CHECK:        linalg.yield
+// P = exp2(S - newMax)
+// CHECK:      linalg.generic
+// CHECK:        arith.subf
+// CHECK:        math.exp2
+// CHECK:        linalg.yield
 // norm = exp2(oldMax - newMax)
 // CHECK:      linalg.generic
 // CHECK:        arith.subf
@@ -116,11 +121,6 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
 // CHECK:        math.exp2
 // CHECK:        linalg.yield
 // normSum = norm * oldSum
 // CHECK:      linalg.generic
 // CHECK:        arith.mulf
 // CHECK:        linalg.yield
-// P = exp2(S - newMax)
-// CHECK:      linalg.generic
-// CHECK:        arith.subf
-// CHECK:        math.exp2
-// CHECK:        linalg.yield
 // newSum = normSum + rowMax(P)
 // CHECK:      linalg.generic
 // CHECK:        arith.addf
@@ -128,6 +128,8 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
 // clamp = clamp(norm)
 // CHECK:      linalg.generic
 // CHECK:        arith.cmpf ogt
+// CHECK:        arith.cmpf olt
+// CHECK:        arith.select
 // CHECK:        arith.select
 // CHECK:        arith.truncf
 // newAcc = norm * oldAcc
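To make the regression in the commit message concrete, here is a minimal
sketch of the shape of IR the reverted change produced: an elementwise
mulf that feeds the reduction addf directly, which the generateContract
patterns in GenericVectorization raise to the vector.contract quoted
above. The function and operand names are hypothetical; the 32x128 -> 32
shapes are borrowed from the quoted snippet.

```
#row = affine_map<(d0, d1) -> (d0, d1)>
#sum = affine_map<(d0, d1) -> (d0)>

func.func @scaled_row_sum(%s: tensor<32x128xf32>, %scale: tensor<32x128xf32>,
                          %acc: tensor<32xf32>) -> tensor<32xf32> {
  // The scale multiply is fused into the same body as the add-reduction,
  // so the mul + add-reduce pair matches the contraction pattern.
  %0 = linalg.generic {indexing_maps = [#row, #row, #sum],
                       iterator_types = ["parallel", "reduction"]}
      ins(%s, %scale : tensor<32x128xf32>, tensor<32x128xf32>)
      outs(%acc : tensor<32xf32>) {
  ^bb0(%in: f32, %sc: f32, %out: f32):
    %mul = arith.mulf %in, %sc : f32
    %add = arith.addf %mul, %out : f32
    linalg.yield %add : f32
  } -> tensor<32xf32>
  return %0 : tensor<32xf32>
}
```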
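The restored fp8 path relies on the linearity of the Softmax@V matmul:
scaling P by pScaleInv before the matmul and the accumulator by pScale
after it cancels exactly. A worked sketch of the arithmetic, ignoring
fp8 rounding and the clamp in truncateFloat, and writing fp8_max for
largestDbl:

```
newAcc = ((P * pScaleInv) @ V) * pScale
       = ((P * fp8_max / clAttentionSoftmaxMax) @ V)
         * (clAttentionSoftmaxMax / fp8_max)
       = P @ V
```

norm is scaled by the same pScaleInv so the norm * oldAcc contribution
carries the identical factor into the accumulator and is undone by the
same final pScale rescale.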