Skip to content

Commit

Permalink
[Bind2] Add support for compile guard around bind (#1315)
Browse files Browse the repository at this point in the history
* [Bind2] Separate return statements

* [Bind2] Remove compile_guard argument

* [CompileGuard] Store compile guard info directly

* [CompileGuard] Expose active compile guards

* [Bind2] Stash active compile guards

Attaches active compile guard info's to bound instance metadata.

* [Bind2] Add tests for compile guarded bind2()

* [MLIR] Print attr_dict for sv.IfDefOp

* [MLIR] Add MlirOp.parent_op() function

* [MLIR] Add support for bind2 wrapped in compile guard

* Emit sv.IfDefOp around sv.BindOp
* Attach hw.OutputFileAttr for entire sub-tree of op's surrounding
  sv.BindOp (all enclosing sv.IfDefOp, and including sv.BindOp).
* Add generic MlirOpPass infrastructure

* [CompileGuard] Fix test

* [MLIR] Update golds

* [MLIR] Use := in MlirOp.parent_op()

* [MLIR] Make MlirOpPass callable interface

* [CompileGuard] Use Stack iterable interface

* [Bind2] Use getattr to get bound inst info
  • Loading branch information
rsetaluri authored Sep 22, 2023
1 parent 5598f23 commit 07f4cb8
Show file tree
Hide file tree
Showing 13 changed files with 224 additions and 36 deletions.
18 changes: 13 additions & 5 deletions magma/backend/mlir/hardware_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from magma.backend.mlir.sv import sv
from magma.backend.mlir.when_utils import WhenCompiler
from magma.backend.mlir.xmr_utils import get_xmr_paths
from magma.bind2 import is_bound_instance
from magma.bind2 import maybe_get_bound_instance_info, is_bound_instance
from magma.bit import Bit
from magma.bits import Bits, BitsMeta
from magma.bitutils import clog2, clog2safe
Expand Down Expand Up @@ -1135,11 +1135,19 @@ def postprocess(self):
for sym in self._syms:
instance = hw.InnerRefAttr(defn_sym, sym)
sv.BindOp(instance=instance)
bound_instances = list(filter(is_bound_instance, self._defn.instances))
for bound_instance in bound_instances:
inst_sym = self._ctx.parent.get_mapped_symbol(bound_instance)
for instance in self._defn.instances:
bound_instance_info = maybe_get_bound_instance_info(instance)
if bound_instance_info is None:
continue
inst_sym = self._ctx.parent.get_mapped_symbol(instance)
ref = hw.InnerRefAttr(defn_sym, inst_sym)
sv.BindOp(instance=ref)
with contextlib.ExitStack() as stack:
for compile_guard_info in bound_instance_info.compile_guards:
block = _make_compile_guard_block(
dataclasses.asdict(compile_guard_info)
)
stack.enter_context(push_block(block))
sv.BindOp(instance=ref)


class CoreIRBindProcessor(BindProcessorInterface):
Expand Down
9 changes: 8 additions & 1 deletion magma/backend/mlir/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from magma.backend.mlir.common import WithId, default_field, constant
from magma.backend.mlir.print_opts import PrintOpts
from magma.backend.mlir.printer_base import PrinterBase
from magma.common import Stack, epilogue
from magma.common import Stack, epilogue, maybe_dereference


OptionalWeakRef = Optional[weakref.ReferenceType]
Expand Down Expand Up @@ -197,6 +197,13 @@ def new_region(self) -> MlirRegion:
def set_parent(self, parent: MlirBlock):
self.parent = weakref.ref(parent)

def parent_op(self) -> Optional['MlirOp']:
if (block := maybe_dereference(self.parent)) is None:
return None
if (region := maybe_dereference(block.parent)) is None:
return None
return maybe_dereference(region.parent)

@print_location
def print(self, printer: PrinterBase, opts: PrintOpts):
self.print_op(printer, opts)
Expand Down
29 changes: 29 additions & 0 deletions magma/backend/mlir/mlir_passes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import abc
from typing import Any

from magma.backend.mlir.mlir import MlirOp


class MlirOpPass(abc.ABC):
def __init__(self, root: MlirOp):
self._root = root
self._callable = callable(self)

def _run_on_op(self, op: MlirOp):
for region in op.regions:
for block in region.blocks:
for op in block.operations:
yield from self._run_on_op(op)
yield self(op)

@abc.abstractmethod
def __call__(self, op: MlirOp) -> Any:
raise NotImplementedError()

def run(self):
yield from self._run_on_op(self._root)


class CollectMlirOpsPass(MlirOpPass):
def __call__(self, op: MlirOp):
return op
11 changes: 9 additions & 2 deletions magma/backend/mlir/sv.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,21 @@ def print(self, printer: PrinterBase, opts: PrintOpts):
printer.pop()
printer.print("}")
if self._else_block is None:
printer.flush()
self._print_close(printer)
return
printer.print(" else {")
printer.flush()
printer.push()
self._else_block.print(printer, opts)
printer.pop()
printer.print_line("}")
printer.print("}")
self._print_close(printer)

def _print_close(self, printer: PrinterBase):
if self.attr_dict:
printer.print(" ")
print_attr_dict(self.attr_dict, printer)
printer.flush()

def print_op(self, printer: PrinterBase, print_opts: PrintOpts):
raise NotImplementedError()
Expand Down
18 changes: 12 additions & 6 deletions magma/backend/mlir/translation_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from magma.backend.mlir.errors import MlirCompilerInternalError
from magma.backend.mlir.hardware_module import HardwareModule
from magma.backend.mlir.hw import hw
from magma.backend.mlir.mlir import MlirSymbol, push_block
from magma.backend.mlir.mlir import MlirOp, MlirSymbol, push_block
from magma.backend.mlir.mlir_passes import CollectMlirOpsPass
from magma.backend.mlir.scoped_name_generator import ScopedNameGenerator
from magma.backend.mlir.sv import sv
from magma.bind2 import is_bound_module
Expand Down Expand Up @@ -60,10 +61,15 @@ def _prepare_for_split_verilog(
)
for bind_op in bind_ops:
magma_inst = only(find_by_value(symbol_map, bind_op.instance.name))
output_filename = bound_module_to_filename[type(magma_inst)]
bind_op.attr_dict["output_file"] = (
hw.OutputFileAttr(str(output_filename))
)
output_filename = str(bound_module_to_filename[type(magma_inst)])
curr_op = bind_op
while True:
curr_op.attr_dict["output_file"] = (
hw.OutputFileAttr(output_filename)
)
curr_op = curr_op.parent_op()
if curr_op is None or not isinstance(curr_op, sv.IfDefOp):
break


class TranslationUnit:
Expand Down Expand Up @@ -156,7 +162,7 @@ def compile(self):
self._hardware_modules.values(),
filter(
lambda op: isinstance(op, sv.BindOp),
self._mlir_module.block.operations
CollectMlirOpsPass(self._mlir_module).run(),
),
self._symbol_map,
self._opts.basename,
Expand Down
33 changes: 25 additions & 8 deletions magma/bind2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import dataclasses
from typing import Dict, List, Optional, Union

from magma.circuit import DefineCircuitKind, CircuitKind, CircuitType
from magma.compile_guard import get_active_compile_guard_info, CompileGuardInfo
from magma.generator import Generator2Kind, Generator2
from magma.passes.passes import DefinitionPass, pass_lambda
from magma.view import PortView
Expand All @@ -17,27 +19,40 @@
ArgumentType = Union[Type, PortView]


def set_bound_instance_info(inst: CircuitType, info: Dict):
@dataclasses.dataclass(frozen=True)
class BoundInstanceInfo:
args: List[ArgumentType]
compile_guards: List[CompileGuardInfo]


def set_bound_instance_info(inst: CircuitType, info: BoundInstanceInfo):
global _BOUND_INSTANCE_INFO_KEY
setattr(inst, _BOUND_INSTANCE_INFO_KEY, info)


def get_bound_instance_info(inst: CircuitType) -> Dict:
def maybe_get_bound_instance_info(
inst: CircuitType
) -> Optional[BoundInstanceInfo]:
global _BOUND_INSTANCE_INFO_KEY
return getattr(inst, _BOUND_INSTANCE_INFO_KEY, None)


def is_bound_instance(inst: CircuitType) -> bool:
return get_bound_instance_info(inst) is not None
return maybe_get_bound_instance_info(inst) is not None


def set_is_bound_module(defn: CircuitKind, value: bool = True):
global _IS_BOUND_MODULE_KEY
setattr(defn, _IS_BOUND_MODULE_KEY, value)


def is_bound_module(defn: CircuitKind) -> bool:
global _IS_BOUND_MODULE_KEY
return getattr(defn, _IS_BOUND_MODULE_KEY, False)


def get_bound_generator_info(inst: CircuitType) -> List[Generator2Kind]:
global _BOUND_GENERATOR_INFO_KEY
try:
info = getattr(inst, _BOUND_GENERATOR_INFO_KEY)
except AttributeError:
Expand Down Expand Up @@ -76,14 +91,14 @@ def _bind_impl(
dut: DefineCircuitKind,
bind_module: CircuitKind,
args: List[ArgumentType],
compile_guard: str,
compile_guards: List[CompileGuardInfo],
):
arguments = list(dut.interface.ports.values()) + args
with dut.open():
inst = bind_module()
for param, arg in zip(inst.interface.ports.values(), arguments):
wire_value_or_port_view(param, arg)
info = {"args": args, "compile_guard": compile_guard}
info = BoundInstanceInfo(args, compile_guards)
set_bound_instance_info(inst, info)
set_is_bound_module(bind_module, True)

Expand All @@ -92,9 +107,9 @@ def bind2(
dut: DutType,
bind_module: BindModuleType,
*args,
compile_guard: Optional[str] = None,
):
args = list(args)
compile_guards = list(get_active_compile_guard_info())
are_generators = (
isinstance(dut, Generator2Kind) and
isinstance(bind_module, Generator2Kind)
Expand All @@ -105,11 +120,13 @@ def bind2(
"Expected no arguments for binding generators. "
"Implement bind_arguments() instead."
)
return _bind_generator_impl(dut, bind_module)
_bind_generator_impl(dut, bind_module)
return
are_modules = (
isinstance(dut, DefineCircuitKind) and
isinstance(dut, CircuitKind)
)
if are_modules:
return _bind_impl(dut, bind_module, args, compile_guard)
_bind_impl(dut, bind_module, args, compile_guards)
return
raise TypeError(dut, bind_module)
12 changes: 11 additions & 1 deletion magma/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from functools import wraps, partial, reduce
import hashlib
import operator
from typing import Any, Callable, Dict, Iterable
from typing import Any, Callable, Dict, Iterable, Optional
import warnings
import weakref


class Stack:
Expand All @@ -29,6 +30,9 @@ def peek_default(self, default: Any) -> Any:
except IndexError:
return default

def __iter__(self) -> Iterable[Any]:
return iter(reversed(self._stack))

def __bool__(self) -> bool:
return bool(self._stack)

Expand Down Expand Up @@ -370,3 +374,9 @@ def hash_expr(expr: str) -> str:
hasher = hashlib.shake_128()
hasher.update(expr.encode())
return hasher.hexdigest(8)


def maybe_dereference(ref: Optional[weakref.ReferenceType]) -> Optional:
if ref is None:
return None
return ref()
45 changes: 39 additions & 6 deletions magma/compile_guard.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import contextlib
import dataclasses
import itertools
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Iterable, Optional, Tuple, Union

from magma.bits import BitsMeta
from magma.clock import ClockTypes
from magma.circuit import CircuitKind, AnonymousCircuitType, CircuitBuilder
from magma.common import Stack
from magma.digital import DigitalMeta
from magma.generator import Generator2
from magma.interface import IO
Expand All @@ -17,6 +19,31 @@

_logger = root_logger().getChild("compile_guard")

_compile_guard_builder_stack = Stack()


def _get_compile_guard_builder_stack() -> Stack:
global _compile_guard_builder_stack
return _compile_guard_builder_stack


@contextlib.contextmanager
def _push_compile_guard_builder_stack(builder: '_CompileGuardBuilder'):
_get_compile_guard_builder_stack().push(builder)
yield
_get_compile_guard_builder_stack().pop()


def get_active_compile_guard_info() -> Iterable['CompileGuardInfo']:
for builder in _get_compile_guard_builder_stack():
yield builder.info


@dataclasses.dataclass(frozen=True)
class CompileGuardInfo:
condition_str: str
type: str


class _Grouper(GrouperBase):
def __init__(self, instances: InstanceCollection, builder: CircuitBuilder):
Expand Down Expand Up @@ -73,11 +100,16 @@ def __init__(self, name: Optional[str], cond: str, type: str):
self._system_types_added = set()
if type not in {"defined", "undefined"}:
raise ValueError(f"Unexpected compile guard type: {type}")
metadata = {"condition_str": cond, "type": type}
self._set_inst_attr("coreir_metadata", {"compile_guard": metadata})
self._set_definition_attr("_compile_guard_", metadata)
self._info = CompileGuardInfo(cond, type)
info_as_dict = dataclasses.asdict(self._info)
self._set_inst_attr("coreir_metadata", {"compile_guard": info_as_dict})
self._set_definition_attr("_compile_guard_", info_as_dict)
self._num_ports = itertools.count()

@property
def info(self) -> CompileGuardInfo:
return self._info

def add_port(self, T: Kind, name: Optional[str] = None) -> Type:
if name is None:
name = self._new_port_name()
Expand Down Expand Up @@ -128,8 +160,9 @@ def compile_guard(
type: str = "defined",
):
builder = _make_builder(cond, defn_name, inst_name, type)
with builder.open() as f:
yield f
with _push_compile_guard_builder_stack(builder):
with builder.open() as f:
yield f
grouper = _Grouper(builder.instances(), builder)
grouper.run()

Expand Down
13 changes: 13 additions & 0 deletions tests/gold/test_bind2_compile_guard.mlir.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module attributes {circt.loweringOptions = "locationInfoStyle=none,omitVersionComment"} {
hw.module @TopCompileGuardAsserts_mlir(%I: i1, %O: i1, %other: i1) -> () attributes {output_filelist = #hw.output_filelist<"$cwd/build/test_bind2_compile_guard_bind_files.list">} {
}
hw.module @Top(%I: i1) -> (O: i1) {
%1 = hw.constant -1 : i1
%0 = comb.xor %1, %I : i1
hw.instance "TopCompileGuardAsserts_mlir_inst0" sym @Top.TopCompileGuardAsserts_mlir_inst0 @TopCompileGuardAsserts_mlir(I: %I: i1, O: %0: i1, other: %I: i1) -> () {doNotPrint = true}
hw.output %0 : i1
}
sv.ifdef "ASSERT_ON" {
sv.bind #hw.innerNameRef<@Top::@Top.TopCompileGuardAsserts_mlir_inst0>
}
}
13 changes: 13 additions & 0 deletions tests/gold/test_bind2_compile_guard_split_verilog.mlir.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module attributes {circt.loweringOptions = "locationInfoStyle=none,omitVersionComment"} {
hw.module @TopCompileGuardAsserts_mlir(%I: i1, %O: i1, %other: i1) -> () attributes {output_file = #hw.output_file<"$cwd/build/TopCompileGuardAsserts_mlir.v">, output_filelist = #hw.output_filelist<"$cwd/build/test_bind2_compile_guard_split_verilog_bind_files.list">} {
}
hw.module @Top(%I: i1) -> (O: i1) attributes {output_file = #hw.output_file<"$cwd/build/test_bind2_compile_guard_split_verilog.v">} {
%1 = hw.constant -1 : i1
%0 = comb.xor %1, %I : i1
hw.instance "TopCompileGuardAsserts_mlir_inst0" sym @Top.TopCompileGuardAsserts_mlir_inst0 @TopCompileGuardAsserts_mlir(I: %I: i1, O: %0: i1, other: %I: i1) -> () {doNotPrint = true}
hw.output %0 : i1
}
sv.ifdef "ASSERT_ON" {
sv.bind #hw.innerNameRef<@Top::@Top.TopCompileGuardAsserts_mlir_inst0> {output_file = #hw.output_file<"$cwd/build/TopCompileGuardAsserts_mlir.v">}
} {output_file = #hw.output_file<"$cwd/build/TopCompileGuardAsserts_mlir.v">}
}
Loading

0 comments on commit 07f4cb8

Please sign in to comment.