Skip to content

Commit

Permalink
Add lazy pass loading (#358)
Browse files Browse the repository at this point in the history
  • Loading branch information
JosseVanDelm authored Jan 31, 2025
1 parent bdd3e87 commit a435ab7
Show file tree
Hide file tree
Showing 2 changed files with 254 additions and 94 deletions.
106 changes: 12 additions & 94 deletions compiler/tools/snax_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,11 @@

from xdsl.context import MLContext
from xdsl.dialects import get_all_dialects
from xdsl.transforms import get_all_passes
from xdsl.xdsl_opt_main import xDSLOptMain

from compiler.dialects import get_all_snax_dialects
from compiler.transforms.accfg_config_overlap import AccfgConfigOverlapPass
from compiler.transforms.accfg_dedup import AccfgDeduplicate
from compiler.transforms.accfg_insert_resets import InsertResetsPass
from compiler.transforms.alloc_to_global import AllocToGlobalPass
from compiler.transforms.backend.postprocess_mlir import PostprocessPass
from compiler.transforms.clear_memory_space import ClearMemorySpace
from compiler.transforms.convert_accfg_to_csr import ConvertAccfgToCsrPass
from compiler.transforms.convert_dart_to_snax_stream import ConvertDartToSnaxStream
from compiler.transforms.convert_kernel_to_linalg import ConvertKernelToLinalg
from compiler.transforms.convert_linalg_to_accfg import (
ConvertLinalgToAccPass,
TraceStatesPass,
)
from compiler.transforms.convert_linalg_to_kernel import ConvertLinalgToKernel
from compiler.transforms.convert_tosa_to_kernel import ConvertTosaToKernelPass
from compiler.transforms.dart.convert_linalg_to_dart import ConvertLinalgToDart
from compiler.transforms.dart.dart_fuse_operations import DartFuseOperationsPass
from compiler.transforms.dart.dart_layout_resolution import DartLayoutResolutionPass
from compiler.transforms.dart.dart_scheduler import DartSchedulerPass
from compiler.transforms.dispatch_kernels import DispatchKernels
from compiler.transforms.dispatch_regions import DispatchRegions
from compiler.transforms.frontend.preprocess_mlir import PreprocessPass
from compiler.transforms.frontend.preprocess_mlperf_tiny import PreprocessMLPerfTiny
from compiler.transforms.insert_accfg_op import InsertAccOp
from compiler.transforms.insert_sync_barrier import InsertSyncBarrier
from compiler.transforms.memref_to_snax import MemrefToSNAX
from compiler.transforms.realize_memref_casts import RealizeMemrefCastsPass
from compiler.transforms.reuse_memref_allocs import ReuseMemrefAllocs
from compiler.transforms.set_memory_layout import SetMemoryLayout
from compiler.transforms.set_memory_space import SetMemorySpace
from compiler.transforms.snax_bufferize import SnaxBufferize
from compiler.transforms.snax_copy_to_dma import SNAXCopyToDMA
from compiler.transforms.snax_lower_mcycle import SNAXLowerMCycle
from compiler.transforms.snax_to_func import SNAXToFunc
from compiler.transforms.test.debug_to_func import DebugToFuncPass
from compiler.transforms.test.insert_debugs import InsertDebugPass
from compiler.transforms.test.test_add_mcycle_around_launch import AddMcycleAroundLaunch
from compiler.transforms.test_add_mcycle_around_loop import AddMcycleAroundLoopPass
from compiler.transforms.test_remove_memref_copy import RemoveMemrefCopyPass
from compiler.transforms import get_all_snax_passes


class SNAXOptMain(xDSLOptMain):
Expand All @@ -57,6 +20,15 @@ def register_all_dialects(self):
for dialect_name, dialect_factory in all_dialects.items():
self.ctx.register_dialect(dialect_name, dialect_factory)

def register_all_passes(self):
"""
Register all SNAX and xDSL passes
"""
all_passes = get_all_passes()
all_passes.update(get_all_snax_passes())
for pass_name, pass_factory in all_passes.items():
self.register_pass(pass_name, pass_factory)

def __init__(
self,
description: str = "SNAX modular optimizer driver",
Expand All @@ -69,63 +41,9 @@ def __init__(
self.ctx = MLContext()
self.register_all_dialects()
super().register_all_frontends()
super().register_all_passes()
self.register_all_passes()
super().register_all_targets()

super().register_pass(DispatchKernels.name, lambda: DispatchKernels)
super().register_pass(SetMemorySpace.name, lambda: SetMemorySpace)
super().register_pass(SetMemoryLayout.name, lambda: SetMemoryLayout)
super().register_pass(InsertAccOp.name, lambda: InsertAccOp)
super().register_pass(InsertSyncBarrier.name, lambda: InsertSyncBarrier)
super().register_pass(DispatchRegions.name, lambda: DispatchRegions)
super().register_pass(SNAXCopyToDMA.name, lambda: SNAXCopyToDMA)
super().register_pass(SNAXToFunc.name, lambda: SNAXToFunc)
super().register_pass(SNAXLowerMCycle.name, lambda: SNAXLowerMCycle)
super().register_pass(ClearMemorySpace.name, lambda: ClearMemorySpace)
super().register_pass(
RealizeMemrefCastsPass.name, lambda: RealizeMemrefCastsPass
)
super().register_pass(InsertResetsPass.name, lambda: InsertResetsPass)
super().register_pass(MemrefToSNAX.name, lambda: MemrefToSNAX)
super().register_pass(AccfgDeduplicate.name, lambda: AccfgDeduplicate)
super().register_pass(
ConvertLinalgToAccPass.name, lambda: ConvertLinalgToAccPass
)
super().register_pass(TraceStatesPass.name, lambda: TraceStatesPass)
super().register_pass(ConvertAccfgToCsrPass.name, lambda: ConvertAccfgToCsrPass)
super().register_pass(
AccfgConfigOverlapPass.name, lambda: AccfgConfigOverlapPass
)
super().register_pass(
ConvertDartToSnaxStream.name, lambda: ConvertDartToSnaxStream
)
super().register_pass(ReuseMemrefAllocs.name, lambda: ReuseMemrefAllocs)
super().register_pass(RemoveMemrefCopyPass.name, lambda: RemoveMemrefCopyPass)
super().register_pass(
AddMcycleAroundLoopPass.name, lambda: AddMcycleAroundLoopPass
)
super().register_pass(ConvertLinalgToKernel.name, lambda: ConvertLinalgToKernel)
super().register_pass(ConvertKernelToLinalg.name, lambda: ConvertKernelToLinalg)
super().register_pass(
ConvertTosaToKernelPass.name, lambda: ConvertTosaToKernelPass
)
super().register_pass(InsertDebugPass.name, lambda: InsertDebugPass)
super().register_pass(DebugToFuncPass.name, lambda: DebugToFuncPass)
super().register_pass(PreprocessMLPerfTiny.name, lambda: PreprocessMLPerfTiny)
super().register_pass(AddMcycleAroundLaunch.name, lambda: AddMcycleAroundLaunch)
super().register_pass(ConvertLinalgToDart.name, lambda: ConvertLinalgToDart)
super().register_pass(SnaxBufferize.name, lambda: SnaxBufferize)
super().register_pass(
DartFuseOperationsPass.name, lambda: DartFuseOperationsPass
)
super().register_pass(AllocToGlobalPass.name, lambda: AllocToGlobalPass)
super().register_pass(PreprocessPass.name, lambda: PreprocessPass)
super().register_pass(PostprocessPass.name, lambda: PostprocessPass)
super().register_pass(DartSchedulerPass.name, lambda: DartSchedulerPass)
super().register_pass(
DartLayoutResolutionPass.name, lambda: DartLayoutResolutionPass
)

# arg handling
arg_parser = argparse.ArgumentParser(description=description)
super().register_all_arguments(arg_parser)
Expand Down
242 changes: 242 additions & 0 deletions compiler/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
from collections.abc import Callable

from xdsl.passes import ModulePass


def get_all_snax_passes() -> dict[str, Callable[[], type[ModulePass]]]:
"""Return the list of all available passes."""

def get_accfg_config_overlap():
from compiler.transforms.accfg_config_overlap import AccfgConfigOverlapPass

return AccfgConfigOverlapPass

def get_accfg_dedup():
from compiler.transforms.accfg_dedup import AccfgDeduplicate

return AccfgDeduplicate

def get_accfg_insert_resets():
from compiler.transforms.accfg_insert_resets import InsertResetsPass

return InsertResetsPass

def get_accfg_trace_states():
from compiler.transforms.convert_linalg_to_accfg import TraceStatesPass

return TraceStatesPass

def get_alloc_to_global():
from compiler.transforms.alloc_to_global import AllocToGlobalPass

return AllocToGlobalPass

def get_clear_memory_space():
from compiler.transforms.clear_memory_space import ClearMemorySpace

return ClearMemorySpace

def get_convert_accfg_to_csr():
from compiler.transforms.convert_accfg_to_csr import ConvertAccfgToCsrPass

return ConvertAccfgToCsrPass

def get_convert_dart_to_snax_stream():
from compiler.transforms.convert_dart_to_snax_stream import (
ConvertDartToSnaxStream,
)

return ConvertDartToSnaxStream

def get_convert_kernel_to_linalg():
from compiler.transforms.convert_kernel_to_linalg import ConvertKernelToLinalg

return ConvertKernelToLinalg

def get_convert_linalg_to_accfg():
from compiler.transforms.convert_linalg_to_accfg import ConvertLinalgToAccPass

return ConvertLinalgToAccPass

def get_convert_linalg_to_dart():
from compiler.transforms.dart.convert_linalg_to_dart import ConvertLinalgToDart

return ConvertLinalgToDart

def get_convert_linalg_to_kernel():
from compiler.transforms.convert_linalg_to_kernel import ConvertLinalgToKernel

return ConvertLinalgToKernel

def get_convert_tosa_to_kernel():
from compiler.transforms.convert_tosa_to_kernel import ConvertTosaToKernelPass

return ConvertTosaToKernelPass

def get_dart_fuse_operations():
from compiler.transforms.dart.dart_fuse_operations import DartFuseOperationsPass

return DartFuseOperationsPass

def get_dart_layout_resolution():
from compiler.transforms.dart.dart_layout_resolution import (
DartLayoutResolutionPass,
)

return DartLayoutResolutionPass

def get_dart_scheduler():
from compiler.transforms.dart.dart_scheduler import DartSchedulerPass

return DartSchedulerPass

def get_dispatch_kernels():
from compiler.transforms.dispatch_kernels import DispatchKernels

return DispatchKernels

def get_dispatch_regions():
from compiler.transforms.dispatch_regions import DispatchRegions

return DispatchRegions

def get_insert_accfg_op():
from compiler.transforms.insert_accfg_op import InsertAccOp

return InsertAccOp

def get_insert_sync_barrier():
from compiler.transforms.insert_sync_barrier import InsertSyncBarrier

return InsertSyncBarrier

def get_memref_to_snax():
from compiler.transforms.memref_to_snax import MemrefToSNAX

return MemrefToSNAX

def get_postprocess_mlir():
from compiler.transforms.backend.postprocess_mlir import PostprocessPass

return PostprocessPass

def get_preprocess_mlir():
from compiler.transforms.frontend.preprocess_mlir import PreprocessPass

return PreprocessPass

def get_preprocess_mlperf_tiny():
from compiler.transforms.frontend.preprocess_mlperf_tiny import (
PreprocessMLPerfTiny,
)

return PreprocessMLPerfTiny

def get_realize_memref_casts():
from compiler.transforms.realize_memref_casts import RealizeMemrefCastsPass

return RealizeMemrefCastsPass

def get_reuse_memref_allocs():
from compiler.transforms.reuse_memref_allocs import ReuseMemrefAllocs

return ReuseMemrefAllocs

def get_set_memory_layout():
from compiler.transforms.set_memory_layout import SetMemoryLayout

return SetMemoryLayout

def get_set_memory_space():
from compiler.transforms.set_memory_space import SetMemorySpace

return SetMemorySpace

def get_snax_bufferize():
from compiler.transforms.snax_bufferize import SnaxBufferize

return SnaxBufferize

def get_snax_copy_to_dma():
from compiler.transforms.snax_copy_to_dma import SNAXCopyToDMA

return SNAXCopyToDMA

def get_snax_lower_mcycle():
from compiler.transforms.snax_lower_mcycle import SNAXLowerMCycle

return SNAXLowerMCycle

def get_snax_to_func():
from compiler.transforms.snax_to_func import SNAXToFunc

return SNAXToFunc

def get_test_add_mcycle_around_loop():
from compiler.transforms.test_add_mcycle_around_loop import (
AddMcycleAroundLoopPass,
)

return AddMcycleAroundLoopPass

def get_test_add_mcycle_around_launch():
from compiler.transforms.test.test_add_mcycle_around_launch import (
AddMcycleAroundLaunch,
)

return AddMcycleAroundLaunch

def get_test_debug_to_func():
from compiler.transforms.test.debug_to_func import DebugToFuncPass

return DebugToFuncPass

def get_test_insert_debugs():
from compiler.transforms.test.insert_debugs import InsertDebugPass

return InsertDebugPass

def get_test_remove_memref_copy():
from compiler.transforms.test_remove_memref_copy import RemoveMemrefCopyPass

return RemoveMemrefCopyPass

return {
"accfg-config-overlap": get_accfg_config_overlap,
"accfg-dedup": get_accfg_dedup,
"accfg-insert-resets": get_accfg_insert_resets,
"accfg-trace-states": get_accfg_trace_states,
"alloc-to-global": get_alloc_to_global,
"clear-memory-space": get_clear_memory_space,
"convert-accfg-to-csr": get_convert_accfg_to_csr,
"convert-dart-to-snax-stream": get_convert_dart_to_snax_stream,
"convert-kernel-to-linalg": get_convert_kernel_to_linalg,
"convert-linalg-to-accfg": get_convert_linalg_to_accfg,
"convert-linalg-to-dart": get_convert_linalg_to_dart,
"convert-linalg-to-kernel": get_convert_linalg_to_kernel,
"convert-tosa-to-kernel": get_convert_tosa_to_kernel,
"dart-fuse-operations": get_dart_fuse_operations,
"dart-layout-resolution": get_dart_layout_resolution,
"dart-scheduler": get_dart_scheduler,
"dispatch-kernels": get_dispatch_kernels,
"dispatch-regions": get_dispatch_regions,
"insert-accfg-op": get_insert_accfg_op,
"insert-sync-barrier": get_insert_sync_barrier,
"memref-to-snax": get_memref_to_snax,
"postprocess": get_postprocess_mlir,
"preprocess": get_preprocess_mlir,
"preprocess-mlperftiny": get_preprocess_mlperf_tiny,
"realize-memref-casts": get_realize_memref_casts,
"reuse-memref-allocs": get_reuse_memref_allocs,
"set-memory-layout": get_set_memory_layout,
"set-memory-space": get_set_memory_space,
"snax-bufferize": get_snax_bufferize,
"snax-copy-to-dma": get_snax_copy_to_dma,
"snax-lower-mcycle": get_snax_lower_mcycle,
"snax-to-func": get_snax_to_func,
"test-add-mcycle-around-launch": get_test_add_mcycle_around_launch,
"test-add-mcycle-around-loop": get_test_add_mcycle_around_loop,
"test-debug-to-func": get_test_debug_to_func,
"test-insert-debugs": get_test_insert_debugs,
"test-remove-memref-copy": get_test_remove_memref_copy,
}

0 comments on commit a435ab7

Please sign in to comment.