Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bind2] Add support for compile guard around bind #1315

Merged
merged 15 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions magma/backend/mlir/hardware_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from magma.backend.mlir.sv import sv
from magma.backend.mlir.when_utils import WhenCompiler
from magma.backend.mlir.xmr_utils import get_xmr_paths
from magma.bind2 import is_bound_instance
from magma.bind2 import maybe_get_bound_instance_info, is_bound_instance
from magma.bit import Bit
from magma.bits import Bits, BitsMeta
from magma.bitutils import clog2, clog2safe
Expand Down Expand Up @@ -1135,11 +1135,19 @@ def postprocess(self):
for sym in self._syms:
instance = hw.InnerRefAttr(defn_sym, sym)
sv.BindOp(instance=instance)
bound_instances = list(filter(is_bound_instance, self._defn.instances))
for bound_instance in bound_instances:
inst_sym = self._ctx.parent.get_mapped_symbol(bound_instance)
for instance in self._defn.instances:
bound_instance_info = maybe_get_bound_instance_info(instance)
if bound_instance_info is None:
continue
inst_sym = self._ctx.parent.get_mapped_symbol(instance)
ref = hw.InnerRefAttr(defn_sym, inst_sym)
sv.BindOp(instance=ref)
with contextlib.ExitStack() as stack:
for compile_guard_info in bound_instance_info.compile_guards:
block = _make_compile_guard_block(
dataclasses.asdict(compile_guard_info)
)
stack.enter_context(push_block(block))
sv.BindOp(instance=ref)


class CoreIRBindProcessor(BindProcessorInterface):
Expand Down
9 changes: 8 additions & 1 deletion magma/backend/mlir/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from magma.backend.mlir.common import WithId, default_field, constant
from magma.backend.mlir.print_opts import PrintOpts
from magma.backend.mlir.printer_base import PrinterBase
from magma.common import Stack, epilogue
from magma.common import Stack, epilogue, maybe_dereference


OptionalWeakRef = Optional[weakref.ReferenceType]
Expand Down Expand Up @@ -197,6 +197,13 @@ def new_region(self) -> MlirRegion:
def set_parent(self, parent: MlirBlock):
self.parent = weakref.ref(parent)

def parent_op(self) -> Optional['MlirOp']:
if (block := maybe_dereference(self.parent)) is None:
return None
if (region := maybe_dereference(block.parent)) is None:
return None
return maybe_dereference(region.parent)

@print_location
def print(self, printer: PrinterBase, opts: PrintOpts):
self.print_op(printer, opts)
Expand Down
29 changes: 29 additions & 0 deletions magma/backend/mlir/mlir_passes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import abc
from typing import Any

from magma.backend.mlir.mlir import MlirOp


class MlirOpPass(abc.ABC):
def __init__(self, root: MlirOp):
self._root = root
self._callable = callable(self)

def _run_on_op(self, op: MlirOp):
for region in op.regions:
for block in region.blocks:
for op in block.operations:
yield from self._run_on_op(op)
yield self(op)

@abc.abstractmethod
def __call__(self, op: MlirOp) -> Any:
raise NotImplementedError()

def run(self):
yield from self._run_on_op(self._root)


class CollectMlirOpsPass(MlirOpPass):
def __call__(self, op: MlirOp):
return op
11 changes: 9 additions & 2 deletions magma/backend/mlir/sv.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,21 @@ def print(self, printer: PrinterBase, opts: PrintOpts):
printer.pop()
printer.print("}")
if self._else_block is None:
printer.flush()
self._print_close(printer)
return
printer.print(" else {")
printer.flush()
printer.push()
self._else_block.print(printer, opts)
printer.pop()
printer.print_line("}")
printer.print("}")
self._print_close(printer)

def _print_close(self, printer: PrinterBase):
if self.attr_dict:
printer.print(" ")
print_attr_dict(self.attr_dict, printer)
printer.flush()

def print_op(self, printer: PrinterBase, print_opts: PrintOpts):
raise NotImplementedError()
Expand Down
18 changes: 12 additions & 6 deletions magma/backend/mlir/translation_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from magma.backend.mlir.errors import MlirCompilerInternalError
from magma.backend.mlir.hardware_module import HardwareModule
from magma.backend.mlir.hw import hw
from magma.backend.mlir.mlir import MlirSymbol, push_block
from magma.backend.mlir.mlir import MlirOp, MlirSymbol, push_block
from magma.backend.mlir.mlir_passes import CollectMlirOpsPass
from magma.backend.mlir.scoped_name_generator import ScopedNameGenerator
from magma.backend.mlir.sv import sv
from magma.bind2 import is_bound_module
Expand Down Expand Up @@ -60,10 +61,15 @@ def _prepare_for_split_verilog(
)
for bind_op in bind_ops:
magma_inst = only(find_by_value(symbol_map, bind_op.instance.name))
output_filename = bound_module_to_filename[type(magma_inst)]
bind_op.attr_dict["output_file"] = (
hw.OutputFileAttr(str(output_filename))
)
output_filename = str(bound_module_to_filename[type(magma_inst)])
curr_op = bind_op
while True:
curr_op.attr_dict["output_file"] = (
hw.OutputFileAttr(output_filename)
)
curr_op = curr_op.parent_op()
if curr_op is None or not isinstance(curr_op, sv.IfDefOp):
break


class TranslationUnit:
Expand Down Expand Up @@ -156,7 +162,7 @@ def compile(self):
self._hardware_modules.values(),
filter(
lambda op: isinstance(op, sv.BindOp),
self._mlir_module.block.operations
CollectMlirOpsPass(self._mlir_module).run(),
),
self._symbol_map,
self._opts.basename,
Expand Down
33 changes: 25 additions & 8 deletions magma/bind2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import dataclasses
from typing import Dict, List, Optional, Union

from magma.circuit import DefineCircuitKind, CircuitKind, CircuitType
from magma.compile_guard import get_active_compile_guard_info, CompileGuardInfo
from magma.generator import Generator2Kind, Generator2
from magma.passes.passes import DefinitionPass, pass_lambda
from magma.view import PortView
Expand All @@ -17,27 +19,40 @@
ArgumentType = Union[Type, PortView]


def set_bound_instance_info(inst: CircuitType, info: Dict):
@dataclasses.dataclass(frozen=True)
class BoundInstanceInfo:
args: List[ArgumentType]
compile_guards: List[CompileGuardInfo]


def set_bound_instance_info(inst: CircuitType, info: BoundInstanceInfo):
global _BOUND_INSTANCE_INFO_KEY
setattr(inst, _BOUND_INSTANCE_INFO_KEY, info)


def get_bound_instance_info(inst: CircuitType) -> Dict:
def maybe_get_bound_instance_info(
inst: CircuitType
) -> Optional[BoundInstanceInfo]:
global _BOUND_INSTANCE_INFO_KEY
return getattr(inst, _BOUND_INSTANCE_INFO_KEY, None)


def is_bound_instance(inst: CircuitType) -> bool:
return get_bound_instance_info(inst) is not None
return maybe_get_bound_instance_info(inst) is not None


def set_is_bound_module(defn: CircuitKind, value: bool = True):
global _IS_BOUND_MODULE_KEY
setattr(defn, _IS_BOUND_MODULE_KEY, value)


def is_bound_module(defn: CircuitKind) -> bool:
global _IS_BOUND_MODULE_KEY
return getattr(defn, _IS_BOUND_MODULE_KEY, False)


def get_bound_generator_info(inst: CircuitType) -> List[Generator2Kind]:
global _BOUND_GENERATOR_INFO_KEY
try:
info = getattr(inst, _BOUND_GENERATOR_INFO_KEY)
except AttributeError:
Expand Down Expand Up @@ -76,14 +91,14 @@ def _bind_impl(
dut: DefineCircuitKind,
bind_module: CircuitKind,
args: List[ArgumentType],
compile_guard: str,
compile_guards: List[CompileGuardInfo],
):
arguments = list(dut.interface.ports.values()) + args
with dut.open():
inst = bind_module()
for param, arg in zip(inst.interface.ports.values(), arguments):
wire_value_or_port_view(param, arg)
info = {"args": args, "compile_guard": compile_guard}
info = BoundInstanceInfo(args, compile_guards)
set_bound_instance_info(inst, info)
set_is_bound_module(bind_module, True)

Expand All @@ -92,9 +107,9 @@ def bind2(
dut: DutType,
bind_module: BindModuleType,
*args,
compile_guard: Optional[str] = None,
):
args = list(args)
compile_guards = list(get_active_compile_guard_info())
are_generators = (
isinstance(dut, Generator2Kind) and
isinstance(bind_module, Generator2Kind)
Expand All @@ -105,11 +120,13 @@ def bind2(
"Expected no arguments for binding generators. "
"Implement bind_arguments() instead."
)
return _bind_generator_impl(dut, bind_module)
_bind_generator_impl(dut, bind_module)
return
are_modules = (
isinstance(dut, DefineCircuitKind) and
isinstance(dut, CircuitKind)
)
if are_modules:
return _bind_impl(dut, bind_module, args, compile_guard)
_bind_impl(dut, bind_module, args, compile_guards)
return
raise TypeError(dut, bind_module)
12 changes: 11 additions & 1 deletion magma/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from functools import wraps, partial, reduce
import hashlib
import operator
from typing import Any, Callable, Dict, Iterable
from typing import Any, Callable, Dict, Iterable, Optional
import warnings
import weakref


class Stack:
Expand All @@ -29,6 +30,9 @@ def peek_default(self, default: Any) -> Any:
except IndexError:
return default

def __iter__(self) -> Iterable[Any]:
return iter(reversed(self._stack))

def __bool__(self) -> bool:
return bool(self._stack)

Expand Down Expand Up @@ -370,3 +374,9 @@ def hash_expr(expr: str) -> str:
hasher = hashlib.shake_128()
hasher.update(expr.encode())
return hasher.hexdigest(8)


def maybe_dereference(ref: Optional[weakref.ReferenceType]) -> Optional:
if ref is None:
return None
return ref()
45 changes: 39 additions & 6 deletions magma/compile_guard.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import contextlib
import dataclasses
import itertools
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Iterable, Optional, Tuple, Union

from magma.bits import BitsMeta
from magma.clock import ClockTypes
from magma.circuit import CircuitKind, AnonymousCircuitType, CircuitBuilder
from magma.common import Stack
from magma.digital import DigitalMeta
from magma.generator import Generator2
from magma.interface import IO
Expand All @@ -17,6 +19,31 @@

_logger = root_logger().getChild("compile_guard")

_compile_guard_builder_stack = Stack()


def _get_compile_guard_builder_stack() -> Stack:
global _compile_guard_builder_stack
return _compile_guard_builder_stack


@contextlib.contextmanager
def _push_compile_guard_builder_stack(builder: '_CompileGuardBuilder'):
_get_compile_guard_builder_stack().push(builder)
yield
_get_compile_guard_builder_stack().pop()


def get_active_compile_guard_info() -> Iterable['CompileGuardInfo']:
for builder in _get_compile_guard_builder_stack():
yield builder.info


@dataclasses.dataclass(frozen=True)
class CompileGuardInfo:
condition_str: str
type: str


class _Grouper(GrouperBase):
def __init__(self, instances: InstanceCollection, builder: CircuitBuilder):
Expand Down Expand Up @@ -73,11 +100,16 @@ def __init__(self, name: Optional[str], cond: str, type: str):
self._system_types_added = set()
if type not in {"defined", "undefined"}:
raise ValueError(f"Unexpected compile guard type: {type}")
metadata = {"condition_str": cond, "type": type}
self._set_inst_attr("coreir_metadata", {"compile_guard": metadata})
self._set_definition_attr("_compile_guard_", metadata)
self._info = CompileGuardInfo(cond, type)
info_as_dict = dataclasses.asdict(self._info)
self._set_inst_attr("coreir_metadata", {"compile_guard": info_as_dict})
self._set_definition_attr("_compile_guard_", info_as_dict)
self._num_ports = itertools.count()

@property
def info(self) -> CompileGuardInfo:
return self._info

def add_port(self, T: Kind, name: Optional[str] = None) -> Type:
if name is None:
name = self._new_port_name()
Expand Down Expand Up @@ -128,8 +160,9 @@ def compile_guard(
type: str = "defined",
):
builder = _make_builder(cond, defn_name, inst_name, type)
with builder.open() as f:
yield f
with _push_compile_guard_builder_stack(builder):
with builder.open() as f:
yield f
grouper = _Grouper(builder.instances(), builder)
grouper.run()

Expand Down
13 changes: 13 additions & 0 deletions tests/gold/test_bind2_compile_guard.mlir.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module attributes {circt.loweringOptions = "locationInfoStyle=none,omitVersionComment"} {
hw.module @TopCompileGuardAsserts_mlir(%I: i1, %O: i1, %other: i1) -> () attributes {output_filelist = #hw.output_filelist<"$cwd/build/test_bind2_compile_guard_bind_files.list">} {
}
hw.module @Top(%I: i1) -> (O: i1) {
%1 = hw.constant -1 : i1
%0 = comb.xor %1, %I : i1
hw.instance "TopCompileGuardAsserts_mlir_inst0" sym @Top.TopCompileGuardAsserts_mlir_inst0 @TopCompileGuardAsserts_mlir(I: %I: i1, O: %0: i1, other: %I: i1) -> () {doNotPrint = true}
hw.output %0 : i1
}
sv.ifdef "ASSERT_ON" {
sv.bind #hw.innerNameRef<@Top::@Top.TopCompileGuardAsserts_mlir_inst0>
}
}
13 changes: 13 additions & 0 deletions tests/gold/test_bind2_compile_guard_split_verilog.mlir.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module attributes {circt.loweringOptions = "locationInfoStyle=none,omitVersionComment"} {
hw.module @TopCompileGuardAsserts_mlir(%I: i1, %O: i1, %other: i1) -> () attributes {output_file = #hw.output_file<"$cwd/build/TopCompileGuardAsserts_mlir.v">, output_filelist = #hw.output_filelist<"$cwd/build/test_bind2_compile_guard_split_verilog_bind_files.list">} {
}
hw.module @Top(%I: i1) -> (O: i1) attributes {output_file = #hw.output_file<"$cwd/build/test_bind2_compile_guard_split_verilog.v">} {
%1 = hw.constant -1 : i1
%0 = comb.xor %1, %I : i1
hw.instance "TopCompileGuardAsserts_mlir_inst0" sym @Top.TopCompileGuardAsserts_mlir_inst0 @TopCompileGuardAsserts_mlir(I: %I: i1, O: %0: i1, other: %I: i1) -> () {doNotPrint = true}
hw.output %0 : i1
}
sv.ifdef "ASSERT_ON" {
sv.bind #hw.innerNameRef<@Top::@Top.TopCompileGuardAsserts_mlir_inst0> {output_file = #hw.output_file<"$cwd/build/TopCompileGuardAsserts_mlir.v">}
} {output_file = #hw.output_file<"$cwd/build/TopCompileGuardAsserts_mlir.v">}
}
Loading
Loading