diff --git a/magma/__init__.py b/magma/__init__.py index 6d77e3a49..46ceebf4b 100644 --- a/magma/__init__.py +++ b/magma/__init__.py @@ -147,7 +147,7 @@ def set_mantle_target(t): from magma.debug import magma_helper_function import magma.mantle from magma.sum_type import Sum, match, case, Enum2 -from magma.syntax.fsm import fsm, FSMState, wait, wait_until +from magma.syntax.fsm import fsm, FSMState, wait, wait_until, loop ################################################################################ # BEGIN ALIASES diff --git a/magma/syntax/fsm.py b/magma/syntax/fsm.py index 8077d39ec..168d991ee 100644 --- a/magma/syntax/fsm.py +++ b/magma/syntax/fsm.py @@ -54,3 +54,37 @@ def wait_until(cond: m.Bit): def wait(): wait_until(1) + + +class LoopWrapper: + def __init__(self, value, when_ctx): + self._value = value + self._when_ctx = when_ctx + + def __enter__(self): + self._when_ctx.__enter__() + return self._value + + def __exit__(self, exc_type, exc_value, exc_tb): + self._when_ctx.__exit__(exc_type, exc_value, exc_tb) + + +def loop(start, end, step=1): + with m.no_when(): + end_reg = m.Register(type(end))() + end_reg.I @= end_reg.O + end_reg.I @= end + with m.no_when(): + i = m.Register(type(end))() + i.I @= i.O + step + i.I @= start + m.wait() + # TODO: We explicitly override after wait, otherwise defautl value will set + # start, can we avoid this? + i.I @= i.O + step + when = m.when(end_reg.O < i.O) + curr_fsm = _FSM_STACK.peek() + when.exit_stack.callback( + lambda: curr_fsm.cases[-1].exit_stack.enter_context(m.otherwise()) + ) + return LoopWrapper(i.O, when) diff --git a/magma/when.py b/magma/when.py index 55808dfe5..202425a8f 100644 --- a/magma/when.py +++ b/magma/when.py @@ -75,6 +75,7 @@ def __init__(self, parent: Optional['WhenBlock']): self._children = list() self._conditional_wires = list() self._default_drivers = dict() + self.exit_stack = contextlib.ExitStack() def spawn(self, info: 'WhenBlockInfo') -> 'WhenBlock': child = WhenBlock(self, info) @@ -220,6 +221,7 @@ def __exit__(self, exc_type, exc_value, traceback): self._add_reg_enables() _set_curr_block(self._get_exit_block()) _set_prev_block(self) + self.exit_stack.__exit__(exc_type, exc_value, traceback) BlockBase = _BlockBase diff --git a/tests/test_syntax/gold/test_fsm_wait_loop.mlir b/tests/test_syntax/gold/test_fsm_wait_loop.mlir new file mode 100644 index 000000000..ea03368f8 --- /dev/null +++ b/tests/test_syntax/gold/test_fsm_wait_loop.mlir @@ -0,0 +1,136 @@ +module attributes {circt.loweringOptions = "locationInfoStyle=none,omitVersionComment"} { + hw.module @Foo(in %n: i8, in %CLK: i1, out O: i16) { + %1 = hw.struct_create (%0) : !hw.struct + %3 = sv.reg name "Register_inst0" : !hw.inout> + sv.alwaysff(posedge %CLK) { + sv.passign %3, %1 : !hw.struct + } + %5 = hw.constant 0 : i2 + %4 = hw.struct_create (%5) : !hw.struct + sv.initial { + sv.bpassign %3, %4 : !hw.struct + } + %2 = sv.read_inout %3 : !hw.inout> + %6 = hw.struct_extract %2["tag"] : !hw.struct + %7 = hw.constant 0 : i2 + %8 = comb.icmp eq %6, %7 : i2 + %9 = hw.constant 65261 : i16 + %10 = hw.constant 2 : i2 + %11 = hw.constant 2 : i2 + %12 = comb.icmp eq %6, %11 : i2 + %13 = hw.constant 57005 : i16 + %16 = sv.reg name "Register_inst1" : !hw.inout + sv.alwaysff(posedge %CLK) { + sv.passign %16, %14 : i8 + } + %17 = hw.constant 0 : i8 + sv.initial { + sv.bpassign %16, %17 : i8 + } + %15 = sv.read_inout %16 : !hw.inout + %20 = sv.reg name "Register_inst2" : !hw.inout + sv.alwaysff(posedge %CLK) { + sv.passign %20, %18 : i8 + } + sv.initial { + sv.bpassign %20, %17 : i8 + } + %19 = sv.read_inout %20 : !hw.inout + %21 = hw.constant 1 : i8 + %22 = comb.add %19, %21 : i8 + %23 = hw.constant 0 : i8 + %24 = hw.constant 0 : i1 + %25 = hw.constant 1 : i1 + %28 = sv.reg name "Register_inst3" : !hw.inout + sv.alwaysff(posedge %CLK) { + sv.passign %28, %26 : i1 + } + %29 = hw.constant 0 : i1 + sv.initial { + sv.bpassign %28, %29 : i1 + } + %27 = sv.read_inout %28 : !hw.inout + %30 = hw.constant 1 : i8 + %31 = comb.add %19, %30 : i8 + %32 = comb.icmp ult %15, %19 : i8 + %33 = comb.extract %19 from 0 : (i8) -> i1 + %34 = comb.extract %19 from 1 : (i8) -> i1 + %35 = comb.extract %19 from 2 : (i8) -> i1 + %36 = comb.extract %19 from 3 : (i8) -> i1 + %37 = comb.extract %19 from 4 : (i8) -> i1 + %38 = comb.extract %19 from 5 : (i8) -> i1 + %39 = comb.extract %19 from 6 : (i8) -> i1 + %40 = comb.extract %19 from 7 : (i8) -> i1 + %41 = hw.constant 1 : i2 + %42 = hw.constant 1 : i2 + %43 = comb.icmp eq %6, %42 : i2 + %46 = sv.reg name "Register_inst4" : !hw.inout + sv.alwaysff(posedge %CLK) { + sv.passign %46, %44 : i1 + } + sv.initial { + sv.bpassign %46, %29 : i1 + } + %45 = sv.read_inout %46 : !hw.inout + %47 = hw.constant 3 : i2 + %49 = sv.reg : !hw.inout + %48 = sv.read_inout %49 : !hw.inout + %50 = sv.reg : !hw.inout + %0 = sv.read_inout %50 : !hw.inout + %51 = sv.reg : !hw.inout + %14 = sv.read_inout %51 : !hw.inout + %52 = sv.reg : !hw.inout + %18 = sv.read_inout %52 : !hw.inout + %53 = sv.reg : !hw.inout + %26 = sv.read_inout %53 : !hw.inout + %54 = sv.reg : !hw.inout + %44 = sv.read_inout %54 : !hw.inout + sv.alwayscomb { + sv.bpassign %51, %15 : i8 + sv.bpassign %52, %22 : i8 + sv.bpassign %53, %24 : i1 + sv.bpassign %54, %24 : i1 + sv.bpassign %50, %6 : i2 + sv.if %8 { + sv.bpassign %50, %10 : i2 + %55 = comb.concat %25, %25, %25, %25, %25, %25, %25, %24, %25, %25, %25, %24, %25, %25, %24, %25 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1 + sv.bpassign %49, %55 : i16 + } else { + sv.if %12 { + sv.bpassign %51, %n : i8 + sv.bpassign %52, %23 : i8 + sv.bpassign %53, %25 : i1 + %56 = comb.concat %25, %25, %24, %25, %25, %25, %25, %24, %25, %24, %25, %24, %25, %25, %24, %25 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1 + sv.bpassign %49, %56 : i16 + sv.if %27 { + sv.bpassign %52, %31 : i8 + sv.if %32 { + %57 = comb.concat %24, %24, %24, %24, %24, %24, %24, %24, %40, %39, %38, %37, %36, %35, %34, %33 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1 + sv.bpassign %49, %57 : i16 + } else { + %58 = comb.concat %25, %24, %25, %25, %25, %25, %25, %24, %25, %25, %25, %24, %25, %25, %25, %25 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1 + sv.bpassign %49, %58 : i16 + sv.bpassign %50, %41 : i2 + } + } + } else { + sv.if %43 { + %59 = comb.concat %25, %24, %25, %25, %25, %25, %25, %24, %25, %25, %25, %24, %25, %25, %25, %25 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1 + sv.bpassign %49, %59 : i16 + sv.bpassign %54, %25 : i1 + sv.if %45 { + %60 = comb.concat %25, %25, %24, %25, %25, %25, %25, %24, %25, %24, %25, %24, %25, %25, %24, %25 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1 + sv.bpassign %49, %60 : i16 + sv.bpassign %50, %47 : i2 + } + } else { + %61 = comb.concat %25, %25, %24, %25, %25, %25, %25, %24, %25, %25, %25, %24, %25, %25, %24, %25 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1 + sv.bpassign %49, %61 : i16 + sv.bpassign %50, %47 : i2 + } + } + } + } + hw.output %48 : i16 + } +} diff --git a/tests/test_syntax/test_fsm.py b/tests/test_syntax/test_fsm.py index f26273dec..89b65bee7 100644 --- a/tests/test_syntax/test_fsm.py +++ b/tests/test_syntax/test_fsm.py @@ -82,5 +82,37 @@ class Foo(m.Circuit): io.O @= 0xDEED state.next @= State.DONE - m.compile("build/test_fsm_wait_until", Foo, output="mlir-verilog") + m.compile("build/test_fsm_wait_until", Foo, output="mlir") assert check_gold(__file__, "test_fsm_wait_until.mlir") + + +def test_fsm_wait_loop(): + class State(m.FSMState): + INIT = 0 + RUN = 1 + WAIT = 2 + DONE = 3 + + class Foo(m.Circuit): + io = m.IO(n=m.In(m.UInt[8]), O=m.Out(m.Bits[16])) + with m.fsm(State, init=State.INIT) as state: + with m.case(State.INIT): + io.O @= 0xFEED + state.next @= State.WAIT + with m.case(State.WAIT): + io.O @= 0xDEAD + with m.loop(0, io.n) as i: + io.O @= m.zext_to(i, 16) + io.O @= 0xBEEF + state.next @= State.RUN + with m.case(State.RUN): + io.O @= 0xBEEF + m.wait() + io.O @= 0xDEAD + state.next @= State.DONE + with m.case(State.DONE): + io.O @= 0xDEED + state.next @= State.DONE + + m.compile("build/test_fsm_wait_loop", Foo, output="mlir-verilog") + assert check_gold(__file__, "test_fsm_wait_loop.mlir")