Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lazy dialect loading #357

Merged
merged 2 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions compiler/dialects/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from collections.abc import Callable

from xdsl.ir import Dialect


def get_all_snax_dialects() -> dict[str, Callable[[], Dialect]]:
"""Returns all available snax dialects"""

def get_accfg():
from compiler.dialects.accfg import ACCFG

return ACCFG

def get_dart():
from compiler.dialects.dart import Dart

return Dart

def get_debug():
from compiler.dialects.test.debug import Debug

return Debug

def get_kernel():
from compiler.dialects.kernel import Kernel

return Kernel

def get_snax():
from compiler.dialects.snax import Snax

return Snax

def get_snax_stream():
from compiler.dialects.snax_stream import SnaxStream

return SnaxStream

def get_tsl():
from compiler.dialects.tsl import TSL

return TSL

return {
"accfg": get_accfg,
"dart": get_dart,
"debug": get_debug,
"kernel": get_kernel,
"snax": get_snax,
"snax_stream": get_snax_stream,
"tsl": get_tsl,
}
35 changes: 12 additions & 23 deletions compiler/tools/snax_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,10 @@
from collections.abc import Sequence

from xdsl.context import MLContext
from xdsl.dialects import get_all_dialects
from xdsl.xdsl_opt_main import xDSLOptMain

from compiler.dialects.accfg import ACCFG
from compiler.dialects.dart import Dart
from compiler.dialects.kernel import Kernel
from compiler.dialects.snax import Snax
from compiler.dialects.snax_stream import SnaxStream
from compiler.dialects.test.debug import Debug
from compiler.dialects.tsl import TSL
from compiler.dialects import get_all_snax_dialects
from compiler.transforms.accfg_config_overlap import AccfgConfigOverlapPass
from compiler.transforms.accfg_dedup import AccfgDeduplicate
from compiler.transforms.accfg_insert_resets import InsertResetsPass
Expand Down Expand Up @@ -53,6 +48,15 @@


class SNAXOptMain(xDSLOptMain):
def register_all_dialects(self):
all_dialects = get_all_dialects()
# FIXME: override upstream accfg and stream dialect.
all_dialects.pop("accfg", None)
all_dialects.pop("stream", None)
all_dialects.update(get_all_snax_dialects())
for dialect_name, dialect_factory in all_dialects.items():
self.ctx.register_dialect(dialect_name, dialect_factory)

def __init__(
self,
description: str = "SNAX modular optimizer driver",
Expand All @@ -63,24 +67,11 @@ def __init__(
self.available_targets = {}

self.ctx = MLContext()
super().register_all_dialects()
self.register_all_dialects()
super().register_all_frontends()
super().register_all_passes()
super().register_all_targets()

## Add custom dialects & passes
# FIXME: override upstream accfg dialect. Remove this after upstreaming full downstream accfg dialect.
self.ctx._registered_dialects.pop("accfg", None) # pyright: ignore
# Warning: overrides upstream stream dialect.
self.ctx._registered_dialects.pop("stream", None) # pyright: ignore

self.ctx.load_dialect(Snax)
self.ctx.load_dialect(TSL)
self.ctx.load_dialect(Kernel)
self.ctx.load_dialect(ACCFG)
self.ctx.load_dialect(SnaxStream)
self.ctx.load_dialect(Debug)
self.ctx.load_dialect(Dart)
super().register_pass(DispatchKernels.name, lambda: DispatchKernels)
super().register_pass(SetMemorySpace.name, lambda: SetMemorySpace)
super().register_pass(SetMemoryLayout.name, lambda: SetMemoryLayout)
Expand Down Expand Up @@ -144,8 +135,6 @@ def __init__(

super().setup_pipeline()

pass


def main():
SNAXOptMain().run()
Expand Down