Revert "Optimize fp8 linalg_ext.attention by rework Q@K scaling" (i…
Browse files Browse the repository at this point in the history
…ree-org#18112)

Reverts iree-org#18031

The reverted change is mathematically correct; however, since the scaling (an
elementwise mul linalg.generic) now happens before the reduction addf, it
unexpectedly becomes a vector.contract after the GenericVectorization pass:

```
    %41 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %39, %cst_0, %34 : vector<32x128xf32>, vector<32x128xf32> into vector<32xf32>
```
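For illustration, a plain-C++ rendering of the pattern the vectorizer is matching (an editor's sketch with shapes taken from the contract above, not IREE code):

```cpp
// An elementwise mulf whose result feeds a row-reduction addf is exactly a
// multiply-accumulate over the reduced dimension, so generateContract can
// legally raise the pair to one vector.contract with kind<add>.
void scaleThenRowReduce(const float (&s)[32][128],
                        const float (&scale)[32][128], float (&acc)[32]) {
  for (int i = 0; i < 32; ++i)
    for (int j = 0; j < 128; ++j)
      acc[i] += s[i][j] * scale[i][j]; // parallel i, reduction j
}
```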

We can re-land this once VectorDistribute uses vector_ext.multi_mma and we can turn off generateContract in GenericVectorization.
raikonenfnu authored Aug 6, 2024
1 parent 5d8362c commit 71f1e20
Showing 2 changed files with 62 additions and 32 deletions.
@@ -90,19 +90,28 @@ static Value truncateFloat(OpBuilder &builder, Location loc, AffineMap inputMap,
auto srcTy = cast<FloatType>(args[0].getType());
auto dstTy = cast<FloatType>(args[1].getType());

// We clamp to the min / max of the floating point representation
double mnDbl =
APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/true)
.convertToDouble();
double mxDbl =
APFloat::getLargest(dstTy.getFloatSemantics(), /*Negative=*/false)
.convertToDouble();

// Truncate to the `fp8` range to avoid NaN values.
Value mn = builder.create<arith::ConstantOp>(
loc, builder.getFloatAttr(srcTy, mnDbl));
Value mx = builder.create<arith::ConstantOp>(
loc, builder.getFloatAttr(srcTy, mxDbl));
Value gt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
args[0], mx);
Value lt = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT,
args[0], mn);
Value sel0 = b.create<arith::SelectOp>(loc, gt, mx, args[0]);
Value sel1 = b.create<arith::SelectOp>(loc, lt, mn, sel0);

// Convert the clamped value to the same datatype as the input.
Value trunc = convertScalarToDtype(b, loc, sel0, dstTy,
Value trunc = convertScalarToDtype(b, loc, sel1, dstTy,
/*isUnsignedCast=*/false);
b.create<linalg::YieldOp>(loc, trunc);
});
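As a side note, a scalar sketch of the clamp this hunk restores (illustrative names; the real bounds come from the APFloat query above):

```cpp
// Pin the f32 value to the destination type's finite [mn, mx] range before
// truncating, so out-of-range inputs become +-fp8.max instead of NaN/Inf.
float clampForTrunc(float x, float mn, float mx) {
  float sel0 = (x > mx) ? mx : x;    // cmpf ogt + select
  float sel1 = (x < mn) ? mn : sel0; // cmpf olt + select
  return sel1;                       // convertScalarToDtype truncates this
}
```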
@@ -293,7 +302,6 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
// SMap = QMap @ KMap
Value emptyS = b.create<tensor::EmptyOp>(loc, sSizes, elementType);
Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));

Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);

@@ -315,6 +323,11 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
AffineMap maxMap = getMaxMap();
Value newMax = reduce<arith::MaximumFOp>(b, loc, sMap, maxMap, s, oldMax);

// P = exp2(S - newMax)
// PMap = SMap
AffineMap pMap = sMap;
Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);

// norm = exp2(oldMax - newMax)
// normMap = maxMap
AffineMap normMap = getMaxMap();
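For reference, the online-softmax recurrence these hunks reorder, assembled from the comments in the surrounding code (m is the running max, l the running sum):

```latex
\begin{aligned}
m_{\mathrm{new}} &= \max\bigl(m_{\mathrm{old}},\ \mathrm{rowMax}(S)\bigr)\\
P &= \mathrm{exp2}(S - m_{\mathrm{new}})\\
\mathrm{norm} &= \mathrm{exp2}(m_{\mathrm{old}} - m_{\mathrm{new}})\\
\ell_{\mathrm{new}} &= \mathrm{norm}\cdot\ell_{\mathrm{old}} + \mathrm{rowSum}(P)\\
\mathrm{acc}_{\mathrm{new}} &= \mathrm{norm}\cdot\mathrm{acc}_{\mathrm{old}} + P\,V
\end{aligned}
```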
@@ -324,36 +337,43 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
AffineMap sumMap = getSumMap();
Value normSum = scaleValueInPlace(b, loc, sumMap, normMap, oldSum, norm);

// P = exp2(S - newMax)
// PMap = SMap
AffineMap pMap = sMap;
Value p = computeSubAndExp2(b, loc, maxMap, sMap, newMax, s);

// If we need to truncate to fp8 post-softmax, we apply a scaling to use the
// full fp8 range. We can do this with an offset, as post-`exp2` this equates
// to multiplying by a static value. This is safe because `max` and `sum` are
// scaled by the same value, so the end result is the same.
if (isa<FloatType>(qETy) && qETy.getIntOrFloatBitWidth() == 8) {
auto fpTy = cast<FloatType>(qETy);
double mx =
APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
.convertToDouble();
Value offset =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, mx));
AffineMap scaleMap = AffineMap::get(/*dimCount=*/pMap.getNumInputs(),
/*symbolCount=*/0, getContext());
p = scaleValueInPlace(b, loc, pMap, scaleMap, p, offset);
}

// newSum = normSum + rowSum(P)
Value newSum = reduce<arith::AddFOp>(b, loc, pMap, sumMap, p, normSum);

// newAcc = norm * oldAcc
AffineMap accMap = getOutputMap();

// ---- Scale and truncate LHS to match RHS ----
Value pScale;
auto pETy = getElementTypeOrSelf(p.getType());
if (pETy != vETy && isa<FloatType>(vETy)) {
if (vETy.getIntOrFloatBitWidth() <= 8) {
SmallVector<OpFoldResult> mSizes(
llvm::map_range(maxMap.getResults(), [&](AffineExpr dimExpr) {
return sizes[cast<AffineDimExpr>(dimExpr).getPosition()];
}));

auto fpTy = cast<FloatType>(vETy);
double largestDbl =
APFloat::getLargest(fpTy.getFloatSemantics(), /*Negative=*/false)
.convertToDouble();

// We normalize p from [0, max] to [0, fp8.max] to guarantee we
// use the full `fp8` range, then renormalize post Softmax@V matmul
// to correct for it.
pScale = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(elementType, clAttentionSoftmaxMax / largestDbl));

// Compute the pre-matmul scale to handle fp8 quantization:
Value pScaleInv = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(elementType, largestDbl / clAttentionSoftmaxMax));

AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
/*symbolCount=*/0, getContext());
p = scaleValueInPlace(b, loc, pMap, scaleMap, p, pScaleInv);
norm = scaleValueInPlace(b, loc, normMap, scaleMap, norm, pScaleInv);
}

Value convertP = b.create<tensor::EmptyOp>(loc, sSizes, vETy);
p = truncateFloat(b, loc, pMap, pMap, p, convertP);
}
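Both fp8 paths above rest on the same identity: a uniform scale c on P cancels exactly. The offset variant lets max and sum carry the same factor, so the final acc / sum normalization is unchanged; the dynamic-scale variant removes the factor right after the Softmax@V matmul:

```latex
\mathrm{acc}_{\mathrm{new}}
  = \frac{1}{c}\Bigl((c\,\mathrm{norm})\,\mathrm{acc}_{\mathrm{old}}
      + (c\,P)\,V\Bigr)
  = \mathrm{norm}\,\mathrm{acc}_{\mathrm{old}} + P\,V,
\qquad c = \frac{\texttt{largestDbl}}{\texttt{clAttentionSoftmaxMax}}.
```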
@@ -364,6 +384,14 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {

// newAcc = P @ V + newAcc
newAcc = computeMatmul(b, loc, pMap, getValueMap(), accMap, p, value, newAcc);

// Update for the FP8 dynamic scale:
if (pScale) {
AffineMap scaleMap = AffineMap::get(/*dimCount=*/maxMap.getNumInputs(),
/*symbolCount=*/0, getContext());
newAcc = scaleValueInPlace(b, loc, accMap, scaleMap, newAcc, pScale);
}

return SmallVector<Value>{newAcc, newMax, newSum};
}
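A minimal scalar sketch of the dynamic-scale epilogue across the last two hunks, assuming a single row and illustrative names (the real code builds linalg ops on tensors, and fp8 rounding is ignored here):

```cpp
#include <cstddef>

// p and norm arrive pre-multiplied by c = pScaleInv so the P @ V operands use
// the full fp8 range; multiplying by pScale = 1/c afterwards undoes it, so
// this returns exactly norm * oldAcc + dot(p, v) in exact arithmetic.
float scaledPVUpdate(const float *p, const float *v, size_t n, float norm,
                     float c, float oldAcc) {
  float newAcc = (c * norm) * oldAcc; // newAcc = norm * oldAcc (scaled)
  for (size_t i = 0; i < n; ++i)
    newAcc += (c * p[i]) * v[i];      // newAcc = P @ V + newAcc
  return newAcc * (1.0f / c);         // newAcc = pScale * newAcc
}
```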

@@ -43,15 +43,15 @@ func.func @attention_f16(%query: tensor<192x1024x64xf16>,
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// normSum = norm * oldSum
// CHECK: linalg.generic
// CHECK: arith.mulf
// CHECK: linalg.yield
// norm = exp2(oldMax - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// normSum = norm * oldSum
// CHECK: linalg.generic
// CHECK: arith.mulf
// CHECK: linalg.yield
// newSum = normSum + rowSum(P)
// CHECK: linalg.generic
// CHECK: arith.addf
@@ -107,6 +107,11 @@ func.func @attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
// CHECK: linalg.generic
// CHECK: arith.maximumf
// CHECK: linalg.yield
// P = exp2(S - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// norm = exp2(oldMax - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
@@ -116,18 +121,15 @@
// CHECK: linalg.generic
// CHECK: arith.mulf
// CHECK: linalg.yield
// P = exp2(S - newMax)
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK: math.exp2
// CHECK: linalg.yield
// newSum = normSum + rowSum(P)
// CHECK: linalg.generic
// CHECK: arith.addf
// CHECK: linalg.yield
// clamp = clamp(P)
// CHECK: linalg.generic
// CHECK: arith.cmpf ogt
// CHECK: arith.cmpf olt
// CHECK: arith.select
// CHECK: arith.select
// CHECK: arith.truncf
// newAcc = norm * oldAcc
