diff --git a/magma/backend/mlir/compile_to_mlir_opts.py b/magma/backend/mlir/compile_to_mlir_opts.py index f13e36020..6f04ad952 100644 --- a/magma/backend/mlir/compile_to_mlir_opts.py +++ b/magma/backend/mlir/compile_to_mlir_opts.py @@ -16,3 +16,4 @@ class CompileToMlirOpts: location_info_style: str = "none" explicit_bitcast: bool = False disallow_expression_inlining_in_ports: bool = False + emit_muxes_as_if_then_else: bool = False diff --git a/magma/backend/mlir/hardware_module.py b/magma/backend/mlir/hardware_module.py index 3638b8c79..b353a6a27 100644 --- a/magma/backend/mlir/hardware_module.py +++ b/magma/backend/mlir/hardware_module.py @@ -160,16 +160,49 @@ def get_module_interface( return operands, results +def make_if_then_else_mux( + ctx: 'HardwareModule', + data: List[MlirValue], + select: MlirValue, + result: MlirValue, +): + wire = ctx.new_value(hw.InOutType(result.type)) + sv.RegOp(results=[wire]) + sv.ReadInOutOp(operands=[wire], results=[result]) + counter = itertools.count() + + def _process(): + index = next(counter) + if index >= len(data): + return + cond = ctx.new_value(builtin.IntegerType(1)) + const = ctx.new_value(select.type) + hw.ConstantOp(value=index, results=[const]) + comb.ICmpOp(predicate="eq", operands=[select, const], results=[cond]) + curr = sv.IfOp(operands=[cond]) + with push_block(curr.then_block): + sv.BPAssignOp(operands=[wire, data[index]]) + with push_block(curr.else_block): + _process() + + with push_block(sv.AlwaysCombOp().body_block): + _process() + + def make_mux( ctx: 'HardwareModule', data: List[MlirValue], select: MlirValue, - result: MlirValue): + result: MlirValue, +): if ctx.opts.extend_non_power_of_two_muxes: closest_power_of_two_size = 2 ** clog2(len(data)) if closest_power_of_two_size != len(data): extension = [data[0]] * (closest_power_of_two_size - len(data)) data = extension + data + if ctx.opts.emit_muxes_as_if_then_else: + make_if_then_else_mux(ctx, data, select, result) + return mlir_type = hw.ArrayType((len(data),), data[0].type) array = ctx.new_value(mlir_type) hw.ArrayCreateOp( diff --git a/tests/test_backend/test_mlir/golds/aggregate_mux_wrapper_emit_muxes_as_if_then_else.mlir b/tests/test_backend/test_mlir/golds/aggregate_mux_wrapper_emit_muxes_as_if_then_else.mlir new file mode 100644 index 000000000..2987169d1 --- /dev/null +++ b/tests/test_backend/test_mlir/golds/aggregate_mux_wrapper_emit_muxes_as_if_then_else.mlir @@ -0,0 +1,28 @@ +module attributes {circt.loweringOptions = "locationInfoStyle=none"} { + hw.module @aggregate_mux_wrapper(%a: !hw.struct, %s: i1) -> (y: !hw.struct) { + %0 = hw.struct_extract %a["x"] : !hw.struct + %2 = hw.constant -1 : i8 + %1 = comb.xor %2, %0 : i8 + %3 = hw.struct_extract %a["y"] : !hw.struct + %5 = hw.constant -1 : i1 + %4 = comb.xor %5, %3 : i1 + %6 = hw.struct_create (%1, %4) : !hw.struct + %8 = sv.reg : !hw.inout> + %7 = sv.read_inout %8 : !hw.inout> + sv.alwayscomb { + %10 = hw.constant 0 : i1 + %9 = comb.icmp eq %s, %10 : i1 + sv.if %9 { + sv.bpassign %8, %6 : !hw.struct + } else { + %12 = hw.constant 1 : i1 + %11 = comb.icmp eq %s, %12 : i1 + sv.if %11 { + sv.bpassign %8, %a : !hw.struct + } else { + } + } + } + hw.output %7 : !hw.struct + } +} diff --git a/tests/test_backend/test_mlir/golds/aggregate_mux_wrapper_emit_muxes_as_if_then_else.v b/tests/test_backend/test_mlir/golds/aggregate_mux_wrapper_emit_muxes_as_if_then_else.v new file mode 100644 index 000000000..a2aa44e5c --- /dev/null +++ b/tests/test_backend/test_mlir/golds/aggregate_mux_wrapper_emit_muxes_as_if_then_else.v @@ -0,0 +1,16 @@ +// Generated by CIRCT circtorg-0.0.0-1018-g3a39b339f +module aggregate_mux_wrapper( + input struct packed {logic [7:0] x; logic y; } a, + input s, + output struct packed {logic [7:0] x; logic y; } y); + + struct packed {logic [7:0] x; logic y; } _GEN; + always_comb begin + if (~s) + _GEN = '{x: (~a.x), y: (~a.y)}; + else if (s) + _GEN = a; + end // always_comb + assign y = _GEN; +endmodule + diff --git a/tests/test_backend/test_mlir/test_compile_to_mlir_local_examples.py b/tests/test_backend/test_mlir/test_compile_to_mlir_local_examples.py index ef4359d1d..37279a445 100644 --- a/tests/test_backend/test_mlir/test_compile_to_mlir_local_examples.py +++ b/tests/test_backend/test_mlir/test_compile_to_mlir_local_examples.py @@ -215,3 +215,14 @@ def test_compile_to_mlir_disallow_expression_inlining_in_ports( } kwargs.update({"check_verilog": False}) run_test_compile_to_mlir(ckt, **kwargs) + + +@pytest.mark.parametrize("ckt", (aggregate_mux_wrapper,)) +def test_compile_to_mlir_emit_muxes_as_if_then_else(ckt): + gold_name = f"{ckt.name}_emit_muxes_as_if_then_else" + kwargs = { + "extend_non_power_of_two_muxes": False, + "emit_muxes_as_if_then_else": True, + "gold_name": gold_name, + } + run_test_compile_to_mlir(ckt, **kwargs)