Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add CIRCTSRAMInterface #4494

Merged
merged 2 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions src/main/scala/chisel3/util/experimental/CIRCTSRAMInterface.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// SPDX-License-Identifier: Apache-2.0

package chisel3.util.experimental
unlsycn marked this conversation as resolved.
Show resolved Hide resolved

import scala.collection.immutable.SeqMap

import chisel3._
import chisel3.experimental.BaseModule
import chisel3.experimental.hierarchy.Instance

import chisel3.util.log2Ceil

object CIRCTSRAMParameter {
implicit val rw: upickle.default.ReadWriter[CIRCTSRAMParameter] = upickle.default.macroRW
}

case class CIRCTSRAMParameter(
moduleName: String,
read: Int,
write: Int,
readwrite: Int,
depth: Int,
width: Int,
maskGranularity: Int) {
def masked: Boolean = maskGranularity != 0
}

class CIRCTSRAMReadPort(memoryParameter: CIRCTSRAMParameter) extends Record {
val clock = Input(Clock())
val address = Input(UInt(log2Ceil(memoryParameter.depth).W))
val data = Output(UInt(memoryParameter.width.W))
val enable = Input(Bool())

// Records store elements in reverse order
val elements: SeqMap[String, Data] = SeqMap(
"addr" -> address,
"en" -> enable,
"clk" -> clock,
"data" -> data
).toSeq.reverse.to(SeqMap)
}

class CIRCTSRAMReadWritePort(memoryParameter: CIRCTSRAMParameter) extends Record {
val clock = Input(Clock())
val address = Input(UInt(log2Ceil(memoryParameter.depth).W))
val writeData = Input(UInt(memoryParameter.width.W))
val writeMask = Option.when(memoryParameter.masked)(Input(UInt(memoryParameter.width.W)))
val writeEnable = Input(Bool())
val readData = Output(UInt(memoryParameter.width.W))
val enable = Input(Bool())

// Records store elements in reverse order
val elements: SeqMap[String, Data] = (SeqMap(
"addr" -> address,
"en" -> enable,
"clk" -> clock,
"wmode" -> writeEnable,
"wdata" -> writeData,
"rdata" -> readData
) ++ Option.when(memoryParameter.masked)("wmask" -> writeMask.get)).toSeq.reverse.to(SeqMap)
}

class CIRCTSRAMWritePort(memoryParameter: CIRCTSRAMParameter) extends Record {
val clock = Input(Clock())
val address = Input(UInt(log2Ceil(memoryParameter.depth).W))
val data = Input(UInt(memoryParameter.width.W))
val mask = Option.when(memoryParameter.masked)(Input(UInt(memoryParameter.width.W)))
val enable = Input(Bool())

// Records store elements in reverse order
val elements: SeqMap[String, Data] = (SeqMap(
"addr" -> address,
"en" -> enable,
"clk" -> clock,
"data" -> data
) ++ Option.when(memoryParameter.masked)("mask" -> mask.get)).toSeq.reverse.to(SeqMap)
}

class CIRCTSRAMInterface(memoryParameter: CIRCTSRAMParameter) extends Record {
def R(idx: Int) =
elements.getOrElse(s"R$idx", throw new Exception(s"Cannot get port R$idx")).asInstanceOf[CIRCTSRAMReadPort]
def RW(idx: Int) =
elements.getOrElse(s"RW$idx", throw new Exception(s"Cannot get port RW$idx")).asInstanceOf[CIRCTSRAMReadWritePort]
def W(idx: Int) =
elements.getOrElse(s"W$idx", throw new Exception(s"Cannot get port W$idx")).asInstanceOf[CIRCTSRAMWritePort]

// Records store elements in reverse order
val elements: SeqMap[String, Data] =
(Seq.tabulate(memoryParameter.read)(i => s"R$i" -> new CIRCTSRAMReadPort(memoryParameter)) ++
Seq.tabulate(memoryParameter.readwrite)(i => s"RW$i" -> new CIRCTSRAMReadWritePort(memoryParameter)) ++
Seq.tabulate(memoryParameter.write)(i => s"W$i" -> new CIRCTSRAMWritePort(memoryParameter))).reverse
.to(SeqMap)
}

abstract class CIRCTSRAM[T <: RawModule](memoryParameter: CIRCTSRAMParameter)
extends FixedIORawModule[CIRCTSRAMInterface](new CIRCTSRAMInterface(memoryParameter)) {
override def desiredName: String = memoryParameter.moduleName
val memoryInstance: Instance[_ <: BaseModule]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package chiselTests.experimental

import scala.util.chaining.scalaUtilChainingOps

import chisel3._
import chisel3.experimental.hierarchy.Instantiate

import chisel3.util.SRAM
import chisel3.util.experimental.{CIRCTSRAM, CIRCTSRAMInterface, CIRCTSRAMParameter, SlangUtils}

import chiselTests.ChiselFlatSpec
import circt.stage.ChiselStage

class CIRCTSRAMInterfaceSpec extends ChiselFlatSpec {
"CIRCTSRAMInterface" should "match the Verilog ports generated by CIRCT" in {
def matchPorts(rd: Int, wr: Int, rw: Int, depth: Int, width: Int) = {
class GenerateSRAMModule extends Module {
val sram = SRAM(depth, UInt(width.W), rd, wr, rw)

val ioR = IO(chiselTypeOf(sram.readPorts)).tap(_.zip(sram.readPorts).foreach {
case (io, mem) => io <> mem
})
val ioRW = IO(chiselTypeOf(sram.readwritePorts)).tap(_.zip(sram.readwritePorts).foreach {
case (io, mem) => io <> mem
})
val ioW = IO(chiselTypeOf(sram.writePorts)).tap(_.zip(sram.writePorts).foreach {
case (io, mem) => io <> mem
})
}

class CIRCTSRAMTestModule extends CIRCTSRAM(CIRCTSRAMParameter("sram_interface", rd, wr, rw, depth, width, 0)) {
class EmptyModule extends RawModule {}
val memoryInstance = Instantiate(new EmptyModule)

for (i <- 0 until rd) {
io.R(i).data := DontCare
}
for (i <- 0 until rw) {
io.RW(i).readData := DontCare
}
}

val targetDir = "CIRCTSRAMInterfaceSpec"
val firrtlOpts = Array("--split-verilog", s"-td=${targetDir}")
ChiselStage.emitSystemVerilogFile(new GenerateSRAMModule, firrtlOpts)
ChiselStage.emitSystemVerilogFile(new CIRCTSRAMTestModule, firrtlOpts)

val sramPorts =
SlangUtils.verilogModuleIO(
SlangUtils.getVerilogAst(os.read(os.pwd / targetDir / s"sram_sram_${depth}x${width}.sv"))
)
val interfacePorts =
SlangUtils.verilogModuleIO(SlangUtils.getVerilogAst(os.read(os.pwd / targetDir / "sram_interface.sv")))

assert(sramPorts.toString == interfacePorts.toString)
}

Seq.tabulate(2, 2, 2) { case (rd, wr, rw) => if (rd + rw != 0 && wr + rw != 0) matchPorts(rd, wr, rw, 32, 8) }
}
}