From 8f90fbedc3277d13b94e2fd3d1f4c64b3355712c Mon Sep 17 00:00:00 2001 From: qinjun-li Date: Wed, 20 Nov 2024 11:56:46 +0800 Subject: [PATCH] [rtl] Handle chaining checks across register groups --- t1/src/Lane.scala | 7 +--- t1/src/vrf/ChainingCheck.scala | 21 ++++++++++-- t1/src/vrf/VRF.scala | 6 ++-- t1/src/vrf/WriteCheck.scala | 62 +++++++++++++++++++--------------- 4 files changed, 57 insertions(+), 39 deletions(-) diff --git a/t1/src/Lane.scala b/t1/src/Lane.scala index e15c74a03..0ef01f120 100644 --- a/t1/src/Lane.scala +++ b/t1/src/Lane.scala @@ -1155,13 +1155,8 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[ lastWriteOH ) ) - // 8 register - val paddingSize: Int = elementSizeForOneRegister * 8 - val shifterMask: UInt = (((selectMask ## Fill(paddingSize, true.B)) - << laneRequest.bits.vd(2, 0) ## 0.U(log2Ceil(elementSizeForOneRegister).W)) - >> paddingSize).asUInt - vrf.instructionWriteReport.bits.elementMask := shifterMask + vrf.instructionWriteReport.bits.elementMask := selectMask // clear record by instructionFinished vrf.instructionLastReport := instructionFinished diff --git a/t1/src/vrf/ChainingCheck.scala b/t1/src/vrf/ChainingCheck.scala index 45873e738..3630fb500 100644 --- a/t1/src/vrf/ChainingCheck.scala +++ b/t1/src/vrf/ChainingCheck.scala @@ -31,8 +31,25 @@ class ChainingCheck(val parameter: VRFParam) extends Module { // 3: 8 register val readOH: UInt = UIntToOH((read.vs ## read.offset)(parameter.vrfOffsetBits + 3 - 1, 0)) - val hitElement: Bool = (readOH & record.bits.elementMask) === 0.U - val raw: Bool = record.bits.vd.valid && (read.vs(4, 3) === record.bits.vd.bits(4, 3)) && hitElement + // todo: def + val elementSizeForOneRegister: Int = parameter.vLen / parameter.datapathWidth / parameter.laneNumber + val paddingSize: Int = elementSizeForOneRegister * 8 + + // elementMask records the relative position of the relative instruction. + // Let's calculate the absolute position. + val maskShifter: UInt = (((Fill(paddingSize, true.B) ## record.bits.elementMask ## Fill(paddingSize, true.B)) + << record.bits.vd.bits(2, 0) ## 0.U(log2Ceil(elementSizeForOneRegister).W)) + >> paddingSize).asUInt(2 * paddingSize - 1, 0) + // mask for vd's group + val maskForVD: UInt = cutUIntBySize(maskShifter, 2)(0) + // Due to the existence of segment load, writes may cross register groups + // So we need the mask of the previous set of registers + val maskForVD1: UInt = cutUIntBySize(maskShifter, 2)(1) + + val hitVd: Bool = (readOH & maskForVD) === 0.U && read.vs(4, 3) === record.bits.vd.bits(4, 3) + val hitVd1: Bool = (readOH & maskForVD1) === 0.U && read.vs(4, 3) === (record.bits.vd.bits(4, 3) + 1.U) + + val raw: Bool = record.bits.vd.valid && (hitVd || hitVd1) checkResult := !(!older && raw && !sameInst && recordValid) } diff --git a/t1/src/vrf/VRF.scala b/t1/src/vrf/VRF.scala index 7e59380eb..623330418 100644 --- a/t1/src/vrf/VRF.scala +++ b/t1/src/vrf/VRF.scala @@ -521,12 +521,12 @@ class VRF(val parameter: VRFParam) extends Module with SerializableModule[VRFPar vrfAllocateIssue := freeRecord.orR && olderCheck val writePort: Seq[ValidIO[VRFWriteRequest]] = Seq(writePipe) - val writeOH = writePort.map(p => UIntToOH((p.bits.vd ## p.bits.offset)(parameter.vrfOffsetBits + 3 - 1, 0))) val loadUnitReadPorts: Seq[DecoupledIO[VRFReadRequest]] = Seq(readRequests.last) - val loadReadOH: Seq[UInt] = - loadUnitReadPorts.map(p => UIntToOH((p.bits.vs ## p.bits.offset)(parameter.vrfOffsetBits + 3 - 1, 0))) Seq(chainingRecord, chainingRecordCopy).foreach { recordVec => recordVec.zipWithIndex.foreach { case (record, i) => + // read write one hot base on base address + val writeOH = writePort.map(p => UIntToOH((p.bits.vd - record.bits.vd.bits)(2, 0) ## p.bits.offset)) + val loadReadOH = loadUnitReadPorts.map(p => UIntToOH((p.bits.vs - record.bits.vs2)(2, 0) ## p.bits.offset)) val dataInLsuQueue = ohCheck(loadDataInLSUWriteQueue, record.bits.instIndex, parameter.chainingSize) // elementMask update by write val writeUpdateValidVec: Seq[Bool] = diff --git a/t1/src/vrf/WriteCheck.scala b/t1/src/vrf/WriteCheck.scala index b454782b7..60b4f3372 100644 --- a/t1/src/vrf/WriteCheck.scala +++ b/t1/src/vrf/WriteCheck.scala @@ -32,35 +32,41 @@ class WriteCheck(val parameter: VRFParam) extends Module { val sameInst: Bool = check.instructionIndex === record.bits.instIndex val checkOH: UInt = UIntToOH((check.vd ## check.offset)(parameter.vrfOffsetBits + 3 - 1, 0)) - // this element in record not execute - val notHitMaskVd: Bool = (checkOH & record.bits.elementMask) === 0.U - val waw: Bool = record.bits.vd.valid && check.vd(4, 3) === record.bits.vd.bits(4, 3) && notHitMaskVd - // inst eg: vadd v0, v1, v1 (lmul = 1) - // We only recorded vd-related masks. - // 0 base: 11111111111111xx eg vs = 0 off=2 - // As above, using vd as the perspective, - // we will access the lowest two elements of the register group where vd is located. - // But from the perspective of vs1: - // 1 base: 111111111111xx11 eg vs = 1 off=2 - // Apparently. Our mask has shifted - // 0 base => 1 base << (1 * off) - // we need vd%8 base => vs1%8 base => vd base mask << (vs1 - vd) * off - // => vd base mask >> 8 * off << (8 + vs1 - vd) * off - // => vd base mask << (8 + vs1 - vd) * off >> 8 * off - val vs1Mask: UInt = (((-1.S(parameter.elementSize.W)).asUInt ## record.bits.elementMask) << - ((8.U + record.bits.vs1.bits(2, 0) - record.bits.vd.bits(2, 0)) << parameter.vrfOffsetBits).asUInt).asUInt( - 2 * 8 * parameter.singleGroupSize - 1, - 8 * parameter.singleGroupSize - ) + val elementSizeForOneRegister: Int = parameter.vLen / parameter.datapathWidth / parameter.laneNumber + val paddingSize: Int = elementSizeForOneRegister * 8 + + // elementMask records the relative position of the relative instruction. + // Let's calculate the absolute position. + val maskShifter: UInt = (((Fill(paddingSize, true.B) ## record.bits.elementMask ## Fill(paddingSize, true.B)) + << record.bits.vd.bits(2, 0) ## 0.U(log2Ceil(elementSizeForOneRegister).W)) + >> paddingSize).asUInt(2 * paddingSize - 1, 0) + // mask for vd's group + val maskForVD: UInt = cutUIntBySize(maskShifter, 2)(0) + // Due to the existence of segment load, writes may cross register groups + // So we need the mask of the previous set of registers + val maskForVD1: UInt = cutUIntBySize(maskShifter, 2)(1) + + val hitVd: Bool = (checkOH & maskForVD) === 0.U && check.vd(4, 3) === record.bits.vd.bits(4, 3) + val hitVd1: Bool = (checkOH & maskForVD1) === 0.U && check.vd(4, 3) === (record.bits.vd.bits(4, 3) + 1.U) + val waw: Bool = record.bits.vd.valid && (hitVd || hitVd1) + + // calculate the absolute position for vs1 + val vs1Mask: UInt = (((record.bits.elementMask ## Fill(paddingSize, true.B)) + << record.bits.vs1.bits(2, 0) ## 0.U(log2Ceil(elementSizeForOneRegister).W)) + >> paddingSize).asUInt val notHitVs1: Bool = (checkOH & vs1Mask) === 0.U val war1: Bool = record.bits.vs1.valid && check.vd(4, 3) === record.bits.vs1.bits(4, 3) && notHitVs1 - val maskForVs2: UInt = record.bits.elementMask & Fill(parameter.elementSize, !record.bits.onlyRead) - val vs2Mask: UInt = (((-1.S(parameter.elementSize.W)).asUInt ## maskForVs2) << - ((8.U + record.bits.vs2(2, 0) - record.bits.vd.bits(2, 0)) << parameter.vrfOffsetBits).asUInt).asUInt( - 2 * 8 * parameter.singleGroupSize - 1, - 8 * parameter.singleGroupSize - ) - val notHitVs2: Bool = (checkOH & vs2Mask) === 0.U - val war2: Bool = check.vd(4, 3) === record.bits.vs2(4, 3) && notHitVs2 + + // calculate the absolute position for vs2 + val maskShifterForVs2: UInt = (((Fill(paddingSize, true.B) ## record.bits.elementMask ## Fill(paddingSize, true.B)) + << record.bits.vs2(2, 0) ## 0.U(log2Ceil(elementSizeForOneRegister).W)) + >> paddingSize).asUInt(2 * paddingSize - 1, 0) + + val maskForVs2: UInt = cutUIntBySize(maskShifterForVs2, 2)(0) & Fill(parameter.elementSize, !record.bits.onlyRead) + val maskForVs21: UInt = cutUIntBySize(maskShifterForVs2, 2)(1) + val hitVs2: Bool = (checkOH & maskForVs2) === 0.U && check.vd(4, 3) === record.bits.vs2(4, 3) + val hitVs21: Bool = (checkOH & maskForVs21) === 0.U && check.vd(4, 3) === (record.bits.vs2(4, 3) + 1.U) + val war2: Bool = hitVs2 || hitVs21 + checkResult := !((!older && (waw || war1 || war2)) && !sameInst && record.valid) }