Skip to content

Commit

Permalink
Remove redundant code for gwvw > 1 route (#1573)
Browse files Browse the repository at this point in the history
  • Loading branch information
KKyang authored Jan 27, 2025
1 parent 258a216 commit 403cb39
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions tensilelite/Tensile/Components/GlobalWriteBatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,22 +1086,23 @@ def _emitNonatomicAdd(self, module: Module):

activationCDataType = self.kernel["ProblemType"]["ActivationComputeDataType"]

if self.kernel["ProblemType"]["DestDataType"].isBFloat16() and self.kernel["ProblemType"]["HighPrecisionAccumulate"]:
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprBf16Mask), "0xffff0000", "mask for pack two bfloat16 element to 32bit" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp32Nan), "0x7fff0000", "fp32 Nan" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprBf16Inc), "0x7fff", "rounding bias for bfloat16" ))
elif self.kernel["ProblemType"]["DestDataType"].isFloat8() and self.kernel["ProblemType"]["HighPrecisionAccumulate"]:
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp8NanInf), "0x207", "Nan and +/- inf" ))
if self.parentWriter.states.archCaps["HasFP8_OCP"]:
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp8Max), "0x43E00000", "Fp8 Max value 448 as float32" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp8Min), "0xc3E00000", "Fp8 Min value -448 as float32" ))
else:
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp8Max), "0x43700000", "Fp8 Max value 240 as float32" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp8Min), "0xc3700000", "Fp8 Min value -240 as float32" ))
elif self.kernel["ProblemType"]["DestDataType"].isBFloat8() and self.kernel["ProblemType"]["HighPrecisionAccumulate"]:
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprBF8NanInf), "0x207", "Nan and +/- inf" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprBF8Max), "0x47600000", "BF8 Max value 57344 as float32" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprBF8Min), "0xc7600000", "BF8 Min value -57344 as float32" ))
if self.kernel["_GlobalAccumulation"] != 'MultipleBuffer':
if self.kernel["ProblemType"]["DestDataType"].isBFloat16() and self.kernel["ProblemType"]["HighPrecisionAccumulate"]:
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprBf16Mask), "0xffff0000", "mask for pack two bfloat16 element to 32bit" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp32Nan), "0x7fff0000", "fp32 Nan" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprBf16Inc), "0x7fff", "rounding bias for bfloat16" ))
elif self.kernel["ProblemType"]["DestDataType"].isFloat8() and self.kernel["ProblemType"]["HighPrecisionAccumulate"]:
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp8NanInf), "0x207", "Nan and +/- inf" ))
if self.parentWriter.states.archCaps["HasFP8_OCP"]:
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp8Max), "0x43E00000", "Fp8 Max value 448 as float32" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp8Min), "0xc3E00000", "Fp8 Min value -448 as float32" ))
else:
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp8Max), "0x43700000", "Fp8 Max value 240 as float32" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprFp8Min), "0xc3700000", "Fp8 Min value -240 as float32" ))
elif self.kernel["ProblemType"]["DestDataType"].isBFloat8() and self.kernel["ProblemType"]["HighPrecisionAccumulate"]:
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprBF8NanInf), "0x207", "Nan and +/- inf" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprBF8Max), "0x47600000", "BF8 Max value 57344 as float32" ))
module.add(VMovB32(vgpr(self.cvtVgprStruct.vgprBF8Min), "0xc7600000", "BF8 Min value -57344 as float32" ))

storeCode = Module("GroupLoadStore")
vmcntTotalIssued = self.loadsBetaIssued + self.loadsEIssued
Expand Down

0 comments on commit 403cb39

Please sign in to comment.