diff --git a/src/main/scala/chisel3/util/experimental/CIRCTSRAMInterface.scala b/src/main/scala/chisel3/util/experimental/CIRCTSRAMInterface.scala new file mode 100644 index 0000000000..f0e02a72fe --- /dev/null +++ b/src/main/scala/chisel3/util/experimental/CIRCTSRAMInterface.scala @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: Apache-2.0 + +package chisel3.util.experimental + +import scala.collection.immutable.SeqMap + +import chisel3._ +import chisel3.experimental.BaseModule +import chisel3.experimental.hierarchy.Instance + +import chisel3.util.log2Ceil + +object CIRCTSRAMParameter { + implicit val rw: upickle.default.ReadWriter[CIRCTSRAMParameter] = upickle.default.macroRW +} + +case class CIRCTSRAMParameter( + moduleName: String, + read: Int, + write: Int, + readwrite: Int, + depth: Int, + width: Int, + maskGranularity: Int) { + def masked: Boolean = maskGranularity != 0 +} + +class CIRCTSRAMReadPort(memoryParameter: CIRCTSRAMParameter) extends Record { + val clock = Input(Clock()) + val address = Input(UInt(log2Ceil(memoryParameter.depth).W)) + val data = Output(UInt(memoryParameter.width.W)) + val enable = Input(Bool()) + + // Records store elements in reverse order + val elements: SeqMap[String, Data] = SeqMap( + "addr" -> address, + "en" -> enable, + "clk" -> clock, + "data" -> data + ).toSeq.reverse.to(SeqMap) +} + +class CIRCTSRAMReadWritePort(memoryParameter: CIRCTSRAMParameter) extends Record { + val clock = Input(Clock()) + val address = Input(UInt(log2Ceil(memoryParameter.depth).W)) + val writeData = Input(UInt(memoryParameter.width.W)) + val writeMask = Option.when(memoryParameter.masked)(Input(UInt(memoryParameter.width.W))) + val writeEnable = Input(Bool()) + val readData = Output(UInt(memoryParameter.width.W)) + val enable = Input(Bool()) + + // Records store elements in reverse order + val elements: SeqMap[String, Data] = (SeqMap( + "addr" -> address, + "en" -> enable, + "clk" -> clock, + "wmode" -> writeEnable, + "wdata" -> writeData, + "rdata" -> readData + ) ++ Option.when(memoryParameter.masked)("wmask" -> writeMask.get)).toSeq.reverse.to(SeqMap) +} + +class CIRCTSRAMWritePort(memoryParameter: CIRCTSRAMParameter) extends Record { + val clock = Input(Clock()) + val address = Input(UInt(log2Ceil(memoryParameter.depth).W)) + val data = Input(UInt(memoryParameter.width.W)) + val mask = Option.when(memoryParameter.masked)(Input(UInt(memoryParameter.width.W))) + val enable = Input(Bool()) + + // Records store elements in reverse order + val elements: SeqMap[String, Data] = (SeqMap( + "addr" -> address, + "en" -> enable, + "clk" -> clock, + "data" -> data + ) ++ Option.when(memoryParameter.masked)("mask" -> mask.get)).toSeq.reverse.to(SeqMap) +} + +class CIRCTSRAMInterface(memoryParameter: CIRCTSRAMParameter) extends Record { + def R(idx: Int) = + elements.getOrElse(s"R$idx", throw new Exception(s"Cannot get port R$idx")).asInstanceOf[CIRCTSRAMReadPort] + def RW(idx: Int) = + elements.getOrElse(s"RW$idx", throw new Exception(s"Cannot get port RW$idx")).asInstanceOf[CIRCTSRAMReadWritePort] + def W(idx: Int) = + elements.getOrElse(s"W$idx", throw new Exception(s"Cannot get port W$idx")).asInstanceOf[CIRCTSRAMWritePort] + + // Records store elements in reverse order + val elements: SeqMap[String, Data] = + (Seq.tabulate(memoryParameter.read)(i => s"R$i" -> new CIRCTSRAMReadPort(memoryParameter)) ++ + Seq.tabulate(memoryParameter.readwrite)(i => s"RW$i" -> new CIRCTSRAMReadWritePort(memoryParameter)) ++ + Seq.tabulate(memoryParameter.write)(i => s"W$i" -> new CIRCTSRAMWritePort(memoryParameter))).reverse + .to(SeqMap) +} + +abstract class CIRCTSRAM[T <: RawModule](memoryParameter: CIRCTSRAMParameter) + extends FixedIORawModule[CIRCTSRAMInterface](new CIRCTSRAMInterface(memoryParameter)) { + override def desiredName: String = memoryParameter.moduleName + val memoryInstance: Instance[_ <: BaseModule] +} diff --git a/src/test/scala/chiselTests/experimental/CIRCTSRAMInterfaceSpec.scala b/src/test/scala/chiselTests/experimental/CIRCTSRAMInterfaceSpec.scala new file mode 100644 index 0000000000..3437dc1e6c --- /dev/null +++ b/src/test/scala/chiselTests/experimental/CIRCTSRAMInterfaceSpec.scala @@ -0,0 +1,60 @@ +package chiselTests.experimental + +import scala.util.chaining.scalaUtilChainingOps + +import chisel3._ +import chisel3.experimental.hierarchy.Instantiate + +import chisel3.util.SRAM +import chisel3.util.experimental.{CIRCTSRAM, CIRCTSRAMInterface, CIRCTSRAMParameter, SlangUtils} + +import chiselTests.ChiselFlatSpec +import circt.stage.ChiselStage + +class CIRCTSRAMInterfaceSpec extends ChiselFlatSpec { + "CIRCTSRAMInterface" should "match the Verilog ports generated by CIRCT" in { + def matchPorts(rd: Int, wr: Int, rw: Int, depth: Int, width: Int) = { + class GenerateSRAMModule extends Module { + val sram = SRAM(depth, UInt(width.W), rd, wr, rw) + + val ioR = IO(chiselTypeOf(sram.readPorts)).tap(_.zip(sram.readPorts).foreach { + case (io, mem) => io <> mem + }) + val ioRW = IO(chiselTypeOf(sram.readwritePorts)).tap(_.zip(sram.readwritePorts).foreach { + case (io, mem) => io <> mem + }) + val ioW = IO(chiselTypeOf(sram.writePorts)).tap(_.zip(sram.writePorts).foreach { + case (io, mem) => io <> mem + }) + } + + class CIRCTSRAMTestModule extends CIRCTSRAM(CIRCTSRAMParameter("sram_interface", rd, wr, rw, depth, width, 0)) { + class EmptyModule extends RawModule {} + val memoryInstance = Instantiate(new EmptyModule) + + for (i <- 0 until rd) { + io.R(i).data := DontCare + } + for (i <- 0 until rw) { + io.RW(i).readData := DontCare + } + } + + val targetDir = "CIRCTSRAMInterfaceSpec" + val firrtlOpts = Array("--split-verilog", s"-td=${targetDir}") + ChiselStage.emitSystemVerilogFile(new GenerateSRAMModule, firrtlOpts) + ChiselStage.emitSystemVerilogFile(new CIRCTSRAMTestModule, firrtlOpts) + + val sramPorts = + SlangUtils.verilogModuleIO( + SlangUtils.getVerilogAst(os.read(os.pwd / targetDir / s"sram_sram_${depth}x${width}.sv")) + ) + val interfacePorts = + SlangUtils.verilogModuleIO(SlangUtils.getVerilogAst(os.read(os.pwd / targetDir / "sram_interface.sv"))) + + assert(sramPorts.toString == interfacePorts.toString) + } + + Seq.tabulate(2, 2, 2) { case (rd, wr, rw) => if (rd + rw != 0 && wr + rw != 0) matchPorts(rd, wr, rw, 32, 8) } + } +}