diff --git a/compiler/ir/tsl/__init__.py b/compiler/ir/tsl/__init__.py deleted file mode 100644 index 4c50e7ce..00000000 --- a/compiler/ir/tsl/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from compiler.ir.tsl.stride import Stride -from compiler.ir.tsl.tiled_stride import TiledStride -from compiler.ir.tsl.tiled_strided_layout import TiledStridedLayout - -__all__ = ["Stride", "TiledStride", "TiledStridedLayout"] diff --git a/compiler/__init__.py b/snaxc/compiler/__init__.py similarity index 100% rename from compiler/__init__.py rename to snaxc/compiler/__init__.py diff --git a/compiler/accelerators/__init__.py b/snaxc/compiler/accelerators/__init__.py similarity index 100% rename from compiler/accelerators/__init__.py rename to snaxc/compiler/accelerators/__init__.py diff --git a/compiler/accelerators/accelerator.py b/snaxc/compiler/accelerators/accelerator.py similarity index 98% rename from compiler/accelerators/accelerator.py rename to snaxc/compiler/accelerators/accelerator.py index 8976df0a..558e4dfb 100644 --- a/compiler/accelerators/accelerator.py +++ b/snaxc/compiler/accelerators/accelerator.py @@ -3,7 +3,7 @@ from xdsl.ir import Operation -from compiler.dialects import accfg +from snaxc.dialects import accfg class Accelerator(ABC): diff --git a/compiler/accelerators/dispatching.py b/snaxc/compiler/accelerators/dispatching.py similarity index 80% rename from compiler/accelerators/dispatching.py rename to snaxc/compiler/accelerators/dispatching.py index 137a3ccf..36ffe3ef 100644 --- a/compiler/accelerators/dispatching.py +++ b/snaxc/compiler/accelerators/dispatching.py @@ -4,8 +4,8 @@ from xdsl.ir import Attribute -from compiler.accelerators.accelerator import Accelerator -from compiler.dialects.kernel import KernelOp +from snaxc.accelerators.accelerator import Accelerator +from snaxc.dialects.kernel import KernelOp @dataclass diff --git a/compiler/accelerators/gemmini.py b/snaxc/compiler/accelerators/gemmini.py similarity index 97% rename from compiler/accelerators/gemmini.py rename to snaxc/compiler/accelerators/gemmini.py index 2b7a8e80..56540581 100644 --- a/compiler/accelerators/gemmini.py +++ b/snaxc/compiler/accelerators/gemmini.py @@ -4,9 +4,9 @@ from xdsl.dialects.builtin import IndexType, i64 from xdsl.ir import Operation -from compiler.accelerators.rocc import RoCCAccelerator -from compiler.dialects import accfg -from compiler.util.pack_bitlist import pack_bitlist +from snaxc.accelerators.rocc import RoCCAccelerator +from snaxc.dialects import accfg +from snaxc.util.pack_bitlist import pack_bitlist class GemminiAccelerator(RoCCAccelerator): diff --git a/compiler/accelerators/matmul_unit.py b/snaxc/compiler/accelerators/matmul_unit.py similarity index 100% rename from compiler/accelerators/matmul_unit.py rename to snaxc/compiler/accelerators/matmul_unit.py diff --git a/compiler/accelerators/registry.py b/snaxc/compiler/accelerators/registry.py similarity index 82% rename from compiler/accelerators/registry.py rename to snaxc/compiler/accelerators/registry.py index 438cf72e..a5ecc285 100644 --- a/compiler/accelerators/registry.py +++ b/snaxc/compiler/accelerators/registry.py @@ -2,14 +2,14 @@ from xdsl.dialects.builtin import ModuleOp, StringAttr -from compiler.accelerators import find_accelerator_op -from compiler.accelerators.accelerator import Accelerator -from compiler.accelerators.gemmini import GemminiAccelerator -from compiler.accelerators.snax_alu import SNAXAluAccelerator -from compiler.accelerators.snax_gemm import SNAXGEMMAccelerator -from compiler.accelerators.snax_gemmx import SNAXGEMMXAccelerator -from compiler.accelerators.snax_hwpe_mult import SNAXHWPEMultAccelerator -from compiler.dialects.accfg import AcceleratorOp +from snaxc.accelerators import find_accelerator_op +from snaxc.accelerators.accelerator import Accelerator +from snaxc.accelerators.gemmini import GemminiAccelerator +from snaxc.accelerators.snax_alu import SNAXAluAccelerator +from snaxc.accelerators.snax_gemm import SNAXGEMMAccelerator +from snaxc.accelerators.snax_gemmx import SNAXGEMMXAccelerator +from snaxc.accelerators.snax_hwpe_mult import SNAXHWPEMultAccelerator +from snaxc.dialects.accfg import AcceleratorOp class AcceleratorRegistry: diff --git a/compiler/accelerators/rocc.py b/snaxc/compiler/accelerators/rocc.py similarity index 97% rename from compiler/accelerators/rocc.py rename to snaxc/compiler/accelerators/rocc.py index 0c6ab0f6..37a5c5c7 100644 --- a/compiler/accelerators/rocc.py +++ b/snaxc/compiler/accelerators/rocc.py @@ -5,9 +5,9 @@ from xdsl.dialects.builtin import IntegerAttr, i64 from xdsl.ir import Operation, SSAValue -from compiler.accelerators.accelerator import Accelerator -from compiler.dialects import accfg -from compiler.inference.trace_acc_state import infer_state_of +from snaxc.accelerators.accelerator import Accelerator +from snaxc.dialects import accfg +from snaxc.inference.trace_acc_state import infer_state_of class RoCCAccelerator(Accelerator, ABC): diff --git a/compiler/accelerators/snax.py b/snaxc/compiler/accelerators/snax.py similarity index 97% rename from compiler/accelerators/snax.py rename to snaxc/compiler/accelerators/snax.py index 8e609c52..c5f50e38 100644 --- a/compiler/accelerators/snax.py +++ b/snaxc/compiler/accelerators/snax.py @@ -7,13 +7,13 @@ from xdsl.dialects.scf import ConditionOp, WhileOp, YieldOp from xdsl.ir import Operation, OpResult, SSAValue -from compiler.accelerators.accelerator import Accelerator -from compiler.accelerators.streamers import StreamerConfiguration -from compiler.accelerators.streamers.streamers import StreamerFlag, StreamerOpts -from compiler.dialects import accfg -from compiler.dialects.dart import StreamingRegionOpBase -from compiler.dialects.snax_stream import StreamerConfigurationAttr, StreamingRegionOp -from compiler.ir.dart.access_pattern import Template +from snaxc.accelerators.accelerator import Accelerator +from snaxc.accelerators.streamers import StreamerConfiguration +from snaxc.accelerators.streamers.streamers import StreamerFlag, StreamerOpts +from snaxc.dialects import accfg +from snaxc.dialects.dart import StreamingRegionOpBase +from snaxc.dialects.snax_stream import StreamerConfigurationAttr, StreamingRegionOp +from snaxc.ir.dart.access_pattern import Template c0_attr = builtin.IntegerAttr(0, builtin.IndexType()) diff --git a/compiler/accelerators/snax_alu.py b/snaxc/compiler/accelerators/snax_alu.py similarity index 95% rename from compiler/accelerators/snax_alu.py rename to snaxc/compiler/accelerators/snax_alu.py index ba706379..fe3f7c2a 100644 --- a/compiler/accelerators/snax_alu.py +++ b/snaxc/compiler/accelerators/snax_alu.py @@ -5,20 +5,20 @@ from xdsl.ir import Operation, SSAValue from xdsl.ir.affine import AffineMap -import compiler.dialects.kernel as kernel -from compiler.accelerators.dispatching import DispatchTemplate, SupportedKernel -from compiler.accelerators.snax import ( +import snaxc.dialects.kernel as kernel +from snaxc.accelerators.dispatching import DispatchTemplate, SupportedKernel +from snaxc.accelerators.snax import ( SNAXAccelerator, SNAXPollingBarrier3, SNAXStreamer, ) -from compiler.accelerators.streamers import ( +from snaxc.accelerators.streamers import ( Streamer, StreamerConfiguration, StreamerType, ) -from compiler.dialects import accfg, dart, snax_stream -from compiler.ir.dart.access_pattern import Template, TemplatePattern +from snaxc.dialects import accfg, dart, snax_stream +from snaxc.ir.dart.access_pattern import Template, TemplatePattern default_streamer = StreamerConfiguration( [ diff --git a/compiler/accelerators/snax_gemm.py b/snaxc/compiler/accelerators/snax_gemm.py similarity index 93% rename from compiler/accelerators/snax_gemm.py rename to snaxc/compiler/accelerators/snax_gemm.py index 91df943d..8058e94c 100644 --- a/compiler/accelerators/snax_gemm.py +++ b/snaxc/compiler/accelerators/snax_gemm.py @@ -5,16 +5,16 @@ from xdsl.ir import Operation, SSAValue from xdsl.ir.affine import AffineDimExpr, AffineMap -import compiler.dialects.kernel as kernel -from compiler.accelerators.dispatching import DispatchTemplate, SupportedKernel -from compiler.accelerators.snax import SNAXAccelerator, SNAXStreamer -from compiler.accelerators.streamers import ( +import snaxc.dialects.kernel as kernel +from snaxc.accelerators.dispatching import DispatchTemplate, SupportedKernel +from snaxc.accelerators.snax import SNAXAccelerator, SNAXStreamer +from snaxc.accelerators.streamers import ( Streamer, StreamerConfiguration, StreamerType, ) -from compiler.dialects import accfg, dart, snax_stream -from compiler.ir.dart.access_pattern import Template, TemplatePattern +from snaxc.dialects import accfg, dart, snax_stream +from snaxc.ir.dart.access_pattern import Template, TemplatePattern default_streamer = StreamerConfiguration( [ diff --git a/compiler/accelerators/snax_gemmx.py b/snaxc/compiler/accelerators/snax_gemmx.py similarity index 96% rename from compiler/accelerators/snax_gemmx.py rename to snaxc/compiler/accelerators/snax_gemmx.py index 4b64133e..7a1073d6 100644 --- a/compiler/accelerators/snax_gemmx.py +++ b/snaxc/compiler/accelerators/snax_gemmx.py @@ -6,21 +6,21 @@ from xdsl.ir import BlockArgument, Operation, SSAValue from xdsl.ir.affine import AffineDimExpr, AffineMap -from compiler.accelerators.dispatching import DispatchTemplate, SupportedKernel -from compiler.accelerators.snax import ( +from snaxc.accelerators.dispatching import DispatchTemplate, SupportedKernel +from snaxc.accelerators.snax import ( SNAXAccelerator, SNAXPollingBarrier4, SNAXStreamer, ) -from compiler.accelerators.streamers import ( +from snaxc.accelerators.streamers import ( Streamer, StreamerConfiguration, StreamerType, ) -from compiler.accelerators.streamers.streamers import StreamerOpts -from compiler.dialects import accfg, dart, kernel, snax_stream -from compiler.ir.dart.access_pattern import Template, TemplatePattern -from compiler.util.pack_bitlist import pack_bitlist +from snaxc.accelerators.streamers.streamers import StreamerOpts +from snaxc.dialects import accfg, dart, kernel, snax_stream +from snaxc.ir.dart.access_pattern import Template, TemplatePattern +from snaxc.util.pack_bitlist import pack_bitlist default_streamer = StreamerConfiguration( [ diff --git a/compiler/accelerators/snax_hwpe_mult.py b/snaxc/compiler/accelerators/snax_hwpe_mult.py similarity index 97% rename from compiler/accelerators/snax_hwpe_mult.py rename to snaxc/compiler/accelerators/snax_hwpe_mult.py index f2855852..ed353a71 100644 --- a/compiler/accelerators/snax_hwpe_mult.py +++ b/snaxc/compiler/accelerators/snax_hwpe_mult.py @@ -4,8 +4,8 @@ from xdsl.ir import Attribute, Operation, SSAValue from xdsl.utils.hints import isa -from compiler.accelerators.snax import SNAXAccelerator, SNAXPollingBarrier -from compiler.dialects import accfg +from snaxc.accelerators.snax import SNAXAccelerator, SNAXPollingBarrier +from snaxc.dialects import accfg class SNAXHWPEMultAccelerator(SNAXAccelerator, SNAXPollingBarrier): diff --git a/compiler/accelerators/streamers/__init__.py b/snaxc/compiler/accelerators/streamers/__init__.py similarity index 100% rename from compiler/accelerators/streamers/__init__.py rename to snaxc/compiler/accelerators/streamers/__init__.py diff --git a/compiler/accelerators/streamers/streamers.py b/snaxc/compiler/accelerators/streamers/streamers.py similarity index 100% rename from compiler/accelerators/streamers/streamers.py rename to snaxc/compiler/accelerators/streamers/streamers.py diff --git a/compiler/accelerators/util.py b/snaxc/compiler/accelerators/util.py similarity index 94% rename from compiler/accelerators/util.py rename to snaxc/compiler/accelerators/util.py index f025066c..0a0d7f98 100644 --- a/compiler/accelerators/util.py +++ b/snaxc/compiler/accelerators/util.py @@ -2,7 +2,7 @@ from xdsl.ir import Operation from xdsl.traits import SymbolTable -from compiler.dialects.accfg import AcceleratorOp +from snaxc.dialects.accfg import AcceleratorOp def find_accelerator_op( diff --git a/compiler/dialects/__init__.py b/snaxc/compiler/dialects/__init__.py similarity index 66% rename from compiler/dialects/__init__.py rename to snaxc/compiler/dialects/__init__.py index fd97524c..0114f547 100644 --- a/compiler/dialects/__init__.py +++ b/snaxc/compiler/dialects/__init__.py @@ -7,37 +7,37 @@ def get_all_snax_dialects() -> dict[str, Callable[[], Dialect]]: """Returns all available snax dialects""" def get_accfg(): - from compiler.dialects.accfg import ACCFG + from snaxc.dialects.accfg import ACCFG return ACCFG def get_dart(): - from compiler.dialects.dart import Dart + from snaxc.dialects.dart import Dart return Dart def get_debug(): - from compiler.dialects.test.debug import Debug + from snaxc.dialects.test.debug import Debug return Debug def get_kernel(): - from compiler.dialects.kernel import Kernel + from snaxc.dialects.kernel import Kernel return Kernel def get_snax(): - from compiler.dialects.snax import Snax + from snaxc.dialects.snax import Snax return Snax def get_snax_stream(): - from compiler.dialects.snax_stream import SnaxStream + from snaxc.dialects.snax_stream import SnaxStream return SnaxStream def get_tsl(): - from compiler.dialects.tsl import TSL + from snaxc.dialects.tsl import TSL return TSL diff --git a/compiler/dialects/accfg.py b/snaxc/compiler/dialects/accfg.py similarity index 100% rename from compiler/dialects/accfg.py rename to snaxc/compiler/dialects/accfg.py diff --git a/compiler/dialects/dart.py b/snaxc/compiler/dialects/dart.py similarity index 100% rename from compiler/dialects/dart.py rename to snaxc/compiler/dialects/dart.py diff --git a/compiler/dialects/kernel.py b/snaxc/compiler/dialects/kernel.py similarity index 100% rename from compiler/dialects/kernel.py rename to snaxc/compiler/dialects/kernel.py diff --git a/compiler/dialects/snax.py b/snaxc/compiler/dialects/snax.py similarity index 97% rename from compiler/dialects/snax.py rename to snaxc/compiler/dialects/snax.py index 7b8f2f65..46e7b231 100644 --- a/compiler/dialects/snax.py +++ b/snaxc/compiler/dialects/snax.py @@ -31,14 +31,14 @@ from xdsl.printer import Printer from xdsl.utils.exceptions import VerifyException -from compiler.accelerators.streamers import ( +from snaxc.accelerators.streamers import ( Streamer, StreamerConfiguration, StreamerFlag, StreamerType, ) -from compiler.accelerators.streamers.streamers import StreamerOpts -from compiler.util.memref_descriptor import LLVMMemrefDescriptor +from snaxc.accelerators.streamers.streamers import StreamerOpts +from snaxc.util.memref_descriptor import LLVMMemrefDescriptor @irdl_op_definition diff --git a/compiler/dialects/snax_stream.py b/snaxc/compiler/dialects/snax_stream.py similarity index 97% rename from compiler/dialects/snax_stream.py rename to snaxc/compiler/dialects/snax_stream.py index 3afc138c..5b253772 100644 --- a/compiler/dialects/snax_stream.py +++ b/snaxc/compiler/dialects/snax_stream.py @@ -25,9 +25,9 @@ from xdsl.parser import AttrParser from xdsl.printer import Printer -from compiler.accelerators import find_accelerator_op -from compiler.accelerators.streamers import StreamerConfiguration -from compiler.dialects.snax import StreamerConfigurationAttr +from snaxc.accelerators import find_accelerator_op +from snaxc.accelerators.streamers import StreamerConfiguration +from snaxc.dialects.snax import StreamerConfigurationAttr @irdl_attr_definition diff --git a/compiler/dialects/test/__init__.py b/snaxc/compiler/dialects/test/__init__.py similarity index 100% rename from compiler/dialects/test/__init__.py rename to snaxc/compiler/dialects/test/__init__.py diff --git a/compiler/dialects/test/debug.py b/snaxc/compiler/dialects/test/debug.py similarity index 100% rename from compiler/dialects/test/debug.py rename to snaxc/compiler/dialects/test/debug.py diff --git a/compiler/dialects/tsl.py b/snaxc/compiler/dialects/tsl.py similarity index 99% rename from compiler/dialects/tsl.py rename to snaxc/compiler/dialects/tsl.py index 91b7d477..b3e47085 100644 --- a/compiler/dialects/tsl.py +++ b/snaxc/compiler/dialects/tsl.py @@ -20,8 +20,8 @@ from xdsl.parser import AttrParser from xdsl.printer import Printer -from compiler.ir.tsl import TiledStridedLayout -from compiler.parser.tsl_parser import TSLParser +from snaxc.ir.tsl import TiledStridedLayout +from snaxc.parser.tsl_parser import TSLParser @irdl_attr_definition diff --git a/compiler/inference/__init__.py b/snaxc/compiler/inference/__init__.py similarity index 100% rename from compiler/inference/__init__.py rename to snaxc/compiler/inference/__init__.py diff --git a/compiler/inference/dataflow.py b/snaxc/compiler/inference/dataflow.py similarity index 100% rename from compiler/inference/dataflow.py rename to snaxc/compiler/inference/dataflow.py diff --git a/compiler/inference/helpers.py b/snaxc/compiler/inference/helpers.py similarity index 99% rename from compiler/inference/helpers.py rename to snaxc/compiler/inference/helpers.py index fa28558b..102155c4 100644 --- a/compiler/inference/helpers.py +++ b/snaxc/compiler/inference/helpers.py @@ -3,7 +3,7 @@ from xdsl.dialects import func, llvm, scf from xdsl.ir import Block, BlockArgument, Operation, OpResult, Region, SSAValue -from compiler.dialects import accfg +from snaxc.dialects import accfg def has_accfg_effects(op: Operation) -> bool: diff --git a/compiler/inference/scoped_setups.py b/snaxc/compiler/inference/scoped_setups.py similarity index 98% rename from compiler/inference/scoped_setups.py rename to snaxc/compiler/inference/scoped_setups.py index e1c55fca..66cbb9ad 100644 --- a/compiler/inference/scoped_setups.py +++ b/snaxc/compiler/inference/scoped_setups.py @@ -56,8 +56,8 @@ from xdsl.rewriter import InsertPoint from xdsl.traits import is_side_effect_free -from compiler.dialects import accfg -from compiler.inference.helpers import val_is_defined_in_block +from snaxc.dialects import accfg +from snaxc.inference.helpers import val_is_defined_in_block def get_scoped_setup_inputs( diff --git a/compiler/inference/trace_acc_state.py b/snaxc/compiler/inference/trace_acc_state.py similarity index 98% rename from compiler/inference/trace_acc_state.py rename to snaxc/compiler/inference/trace_acc_state.py index dc2abc6c..2038dbbf 100644 --- a/compiler/inference/trace_acc_state.py +++ b/snaxc/compiler/inference/trace_acc_state.py @@ -18,7 +18,7 @@ from xdsl.dialects import scf from xdsl.ir import Block, BlockArgument, Region, SSAValue -from compiler.dialects import accfg +from snaxc.dialects import accfg State = dict[str, SSAValue] diff --git a/compiler/ir/__init__.py b/snaxc/compiler/ir/__init__.py similarity index 100% rename from compiler/ir/__init__.py rename to snaxc/compiler/ir/__init__.py diff --git a/compiler/ir/dart/__init__.py b/snaxc/compiler/ir/dart/__init__.py similarity index 100% rename from compiler/ir/dart/__init__.py rename to snaxc/compiler/ir/dart/__init__.py diff --git a/compiler/ir/dart/access_pattern.py b/snaxc/compiler/ir/dart/access_pattern.py similarity index 99% rename from compiler/ir/dart/access_pattern.py rename to snaxc/compiler/ir/dart/access_pattern.py index 96fe6198..0691ea13 100644 --- a/compiler/ir/dart/access_pattern.py +++ b/snaxc/compiler/ir/dart/access_pattern.py @@ -6,7 +6,7 @@ from typing_extensions import Self, TypeVar, overload from xdsl.ir.affine import AffineDimExpr, AffineMap -from compiler.ir.dart.affine_transform import AffineTransform +from snaxc.ir.dart.affine_transform import AffineTransform @dataclass(frozen=True) diff --git a/compiler/ir/dart/affine_transform.py b/snaxc/compiler/ir/dart/affine_transform.py similarity index 100% rename from compiler/ir/dart/affine_transform.py rename to snaxc/compiler/ir/dart/affine_transform.py diff --git a/compiler/ir/dart/scheduler.py b/snaxc/compiler/ir/dart/scheduler.py similarity index 98% rename from compiler/ir/dart/scheduler.py rename to snaxc/compiler/ir/dart/scheduler.py index abd3b9f2..b03f720e 100644 --- a/compiler/ir/dart/scheduler.py +++ b/snaxc/compiler/ir/dart/scheduler.py @@ -2,7 +2,7 @@ import numpy as np -from compiler.ir.dart.access_pattern import Schedule, Template +from snaxc.ir.dart.access_pattern import Schedule, Template def scheduler_backtrack( diff --git a/compiler/ir/tsl/README.md b/snaxc/compiler/ir/tsl/README.md similarity index 100% rename from compiler/ir/tsl/README.md rename to snaxc/compiler/ir/tsl/README.md diff --git a/snaxc/compiler/ir/tsl/__init__.py b/snaxc/compiler/ir/tsl/__init__.py new file mode 100644 index 00000000..007fce61 --- /dev/null +++ b/snaxc/compiler/ir/tsl/__init__.py @@ -0,0 +1,5 @@ +from snaxc.ir.tsl.stride import Stride +from snaxc.ir.tsl.tiled_stride import TiledStride +from snaxc.ir.tsl.tiled_strided_layout import TiledStridedLayout + +__all__ = ["Stride", "TiledStride", "TiledStridedLayout"] diff --git a/compiler/ir/tsl/stride.py b/snaxc/compiler/ir/tsl/stride.py similarity index 100% rename from compiler/ir/tsl/stride.py rename to snaxc/compiler/ir/tsl/stride.py diff --git a/compiler/ir/tsl/tiled_stride.py b/snaxc/compiler/ir/tsl/tiled_stride.py similarity index 99% rename from compiler/ir/tsl/tiled_stride.py rename to snaxc/compiler/ir/tsl/tiled_stride.py index 9e75c09e..0f75c323 100644 --- a/compiler/ir/tsl/tiled_stride.py +++ b/snaxc/compiler/ir/tsl/tiled_stride.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Self -from compiler.ir.tsl.stride import Stride +from snaxc.ir.tsl.stride import Stride @dataclass diff --git a/compiler/ir/tsl/tiled_strided_layout.py b/snaxc/compiler/ir/tsl/tiled_strided_layout.py similarity index 98% rename from compiler/ir/tsl/tiled_strided_layout.py rename to snaxc/compiler/ir/tsl/tiled_strided_layout.py index fc9f545c..3ffd42d3 100644 --- a/compiler/ir/tsl/tiled_strided_layout.py +++ b/snaxc/compiler/ir/tsl/tiled_strided_layout.py @@ -7,8 +7,8 @@ import numpy as np from numpy._typing import NDArray -from compiler.ir.tsl.stride import Stride -from compiler.ir.tsl.tiled_stride import TiledStride +from snaxc.ir.tsl.stride import Stride +from snaxc.ir.tsl.tiled_stride import TiledStride @dataclass(frozen=True) diff --git a/compiler/parser/__init__.py b/snaxc/compiler/parser/__init__.py similarity index 100% rename from compiler/parser/__init__.py rename to snaxc/compiler/parser/__init__.py diff --git a/compiler/parser/tsl_parser.py b/snaxc/compiler/parser/tsl_parser.py similarity index 97% rename from compiler/parser/tsl_parser.py rename to snaxc/compiler/parser/tsl_parser.py index be879ffc..42bec307 100644 --- a/compiler/parser/tsl_parser.py +++ b/snaxc/compiler/parser/tsl_parser.py @@ -5,7 +5,7 @@ from xdsl.utils.exceptions import ParseError from xdsl.utils.mlir_lexer import MLIRTokenKind -from compiler.ir.tsl import Stride, TiledStride, TiledStridedLayout +from snaxc.ir.tsl import Stride, TiledStride, TiledStridedLayout class TSLParser(BaseParser): diff --git a/compiler/snax-opt b/snaxc/compiler/snax-opt similarity index 100% rename from compiler/snax-opt rename to snaxc/compiler/snax-opt diff --git a/compiler/tools/__init__.py b/snaxc/compiler/tools/__init__.py similarity index 100% rename from compiler/tools/__init__.py rename to snaxc/compiler/tools/__init__.py diff --git a/compiler/tools/snax_opt_main.py b/snaxc/compiler/tools/snax_opt_main.py similarity index 94% rename from compiler/tools/snax_opt_main.py rename to snaxc/compiler/tools/snax_opt_main.py index 1124f175..97c5fb4d 100644 --- a/compiler/tools/snax_opt_main.py +++ b/snaxc/compiler/tools/snax_opt_main.py @@ -6,8 +6,8 @@ from xdsl.transforms import get_all_passes from xdsl.xdsl_opt_main import xDSLOptMain -from compiler.dialects import get_all_snax_dialects -from compiler.transforms import get_all_snax_passes +from snaxc.dialects import get_all_snax_dialects +from snaxc.transforms import get_all_snax_passes class SNAXOptMain(xDSLOptMain): diff --git a/compiler/transforms/__init__.py b/snaxc/compiler/transforms/__init__.py similarity index 64% rename from compiler/transforms/__init__.py rename to snaxc/compiler/transforms/__init__.py index 648484ce..c8f7d62f 100644 --- a/compiler/transforms/__init__.py +++ b/snaxc/compiler/transforms/__init__.py @@ -7,197 +7,197 @@ def get_all_snax_passes() -> dict[str, Callable[[], type[ModulePass]]]: """Return the list of all available passes.""" def get_accfg_config_overlap(): - from compiler.transforms.accfg_config_overlap import AccfgConfigOverlapPass + from snaxc.transforms.accfg_config_overlap import AccfgConfigOverlapPass return AccfgConfigOverlapPass def get_accfg_dedup(): - from compiler.transforms.accfg_dedup import AccfgDeduplicate + from snaxc.transforms.accfg_dedup import AccfgDeduplicate return AccfgDeduplicate def get_accfg_insert_resets(): - from compiler.transforms.accfg_insert_resets import InsertResetsPass + from snaxc.transforms.accfg_insert_resets import InsertResetsPass return InsertResetsPass def get_accfg_trace_states(): - from compiler.transforms.convert_linalg_to_accfg import TraceStatesPass + from snaxc.transforms.convert_linalg_to_accfg import TraceStatesPass return TraceStatesPass def get_alloc_to_global(): - from compiler.transforms.alloc_to_global import AllocToGlobalPass + from snaxc.transforms.alloc_to_global import AllocToGlobalPass return AllocToGlobalPass def get_clear_memory_space(): - from compiler.transforms.clear_memory_space import ClearMemorySpace + from snaxc.transforms.clear_memory_space import ClearMemorySpace return ClearMemorySpace def get_convert_accfg_to_csr(): - from compiler.transforms.convert_accfg_to_csr import ConvertAccfgToCsrPass + from snaxc.transforms.convert_accfg_to_csr import ConvertAccfgToCsrPass return ConvertAccfgToCsrPass def get_convert_dart_to_snax_stream(): - from compiler.transforms.convert_dart_to_snax_stream import ( + from snaxc.transforms.convert_dart_to_snax_stream import ( ConvertDartToSnaxStream, ) return ConvertDartToSnaxStream def get_convert_kernel_to_linalg(): - from compiler.transforms.convert_kernel_to_linalg import ConvertKernelToLinalg + from snaxc.transforms.convert_kernel_to_linalg import ConvertKernelToLinalg return ConvertKernelToLinalg def get_convert_linalg_to_accfg(): - from compiler.transforms.convert_linalg_to_accfg import ConvertLinalgToAccPass + from snaxc.transforms.convert_linalg_to_accfg import ConvertLinalgToAccPass return ConvertLinalgToAccPass def get_convert_linalg_to_dart(): - from compiler.transforms.dart.convert_linalg_to_dart import ConvertLinalgToDart + from snaxc.transforms.dart.convert_linalg_to_dart import ConvertLinalgToDart return ConvertLinalgToDart def get_convert_linalg_to_kernel(): - from compiler.transforms.convert_linalg_to_kernel import ConvertLinalgToKernel + from snaxc.transforms.convert_linalg_to_kernel import ConvertLinalgToKernel return ConvertLinalgToKernel def get_convert_tosa_to_kernel(): - from compiler.transforms.convert_tosa_to_kernel import ConvertTosaToKernelPass + from snaxc.transforms.convert_tosa_to_kernel import ConvertTosaToKernelPass return ConvertTosaToKernelPass def get_dart_fuse_operations(): - from compiler.transforms.dart.dart_fuse_operations import DartFuseOperationsPass + from snaxc.transforms.dart.dart_fuse_operations import DartFuseOperationsPass return DartFuseOperationsPass def get_dart_layout_resolution(): - from compiler.transforms.dart.dart_layout_resolution import ( + from snaxc.transforms.dart.dart_layout_resolution import ( DartLayoutResolutionPass, ) return DartLayoutResolutionPass def get_dart_scheduler(): - from compiler.transforms.dart.dart_scheduler import DartSchedulerPass + from snaxc.transforms.dart.dart_scheduler import DartSchedulerPass return DartSchedulerPass def get_dispatch_kernels(): - from compiler.transforms.dispatch_kernels import DispatchKernels + from snaxc.transforms.dispatch_kernels import DispatchKernels return DispatchKernels def get_dispatch_regions(): - from compiler.transforms.dispatch_regions import DispatchRegions + from snaxc.transforms.dispatch_regions import DispatchRegions return DispatchRegions def get_insert_accfg_op(): - from compiler.transforms.insert_accfg_op import InsertAccOp + from snaxc.transforms.insert_accfg_op import InsertAccOp return InsertAccOp def get_insert_sync_barrier(): - from compiler.transforms.insert_sync_barrier import InsertSyncBarrier + from snaxc.transforms.insert_sync_barrier import InsertSyncBarrier return InsertSyncBarrier def get_memref_to_snax(): - from compiler.transforms.memref_to_snax import MemrefToSNAX + from snaxc.transforms.memref_to_snax import MemrefToSNAX return MemrefToSNAX def get_postprocess_mlir(): - from compiler.transforms.backend.postprocess_mlir import PostprocessPass + from snaxc.transforms.backend.postprocess_mlir import PostprocessPass return PostprocessPass def get_preprocess_mlir(): - from compiler.transforms.frontend.preprocess_mlir import PreprocessPass + from snaxc.transforms.frontend.preprocess_mlir import PreprocessPass return PreprocessPass def get_preprocess_mlperf_tiny(): - from compiler.transforms.frontend.preprocess_mlperf_tiny import ( + from snaxc.transforms.frontend.preprocess_mlperf_tiny import ( PreprocessMLPerfTiny, ) return PreprocessMLPerfTiny def get_realize_memref_casts(): - from compiler.transforms.realize_memref_casts import RealizeMemrefCastsPass + from snaxc.transforms.realize_memref_casts import RealizeMemrefCastsPass return RealizeMemrefCastsPass def get_reuse_memref_allocs(): - from compiler.transforms.reuse_memref_allocs import ReuseMemrefAllocs + from snaxc.transforms.reuse_memref_allocs import ReuseMemrefAllocs return ReuseMemrefAllocs def get_set_memory_layout(): - from compiler.transforms.set_memory_layout import SetMemoryLayout + from snaxc.transforms.set_memory_layout import SetMemoryLayout return SetMemoryLayout def get_set_memory_space(): - from compiler.transforms.set_memory_space import SetMemorySpace + from snaxc.transforms.set_memory_space import SetMemorySpace return SetMemorySpace def get_snax_bufferize(): - from compiler.transforms.snax_bufferize import SnaxBufferize + from snaxc.transforms.snax_bufferize import SnaxBufferize return SnaxBufferize def get_snax_copy_to_dma(): - from compiler.transforms.snax_copy_to_dma import SNAXCopyToDMA + from snaxc.transforms.snax_copy_to_dma import SNAXCopyToDMA return SNAXCopyToDMA def get_snax_lower_mcycle(): - from compiler.transforms.snax_lower_mcycle import SNAXLowerMCycle + from snaxc.transforms.snax_lower_mcycle import SNAXLowerMCycle return SNAXLowerMCycle def get_snax_to_func(): - from compiler.transforms.snax_to_func import SNAXToFunc + from snaxc.transforms.snax_to_func import SNAXToFunc return SNAXToFunc def get_test_add_mcycle_around_loop(): - from compiler.transforms.test_add_mcycle_around_loop import ( + from snaxc.transforms.test_add_mcycle_around_loop import ( AddMcycleAroundLoopPass, ) return AddMcycleAroundLoopPass def get_test_add_mcycle_around_launch(): - from compiler.transforms.test.test_add_mcycle_around_launch import ( + from snaxc.transforms.test.test_add_mcycle_around_launch import ( AddMcycleAroundLaunch, ) return AddMcycleAroundLaunch def get_test_debug_to_func(): - from compiler.transforms.test.debug_to_func import DebugToFuncPass + from snaxc.transforms.test.debug_to_func import DebugToFuncPass return DebugToFuncPass def get_test_insert_debugs(): - from compiler.transforms.test.insert_debugs import InsertDebugPass + from snaxc.transforms.test.insert_debugs import InsertDebugPass return InsertDebugPass def get_test_remove_memref_copy(): - from compiler.transforms.test_remove_memref_copy import RemoveMemrefCopyPass + from snaxc.transforms.test_remove_memref_copy import RemoveMemrefCopyPass return RemoveMemrefCopyPass diff --git a/compiler/transforms/accfg_config_overlap.py b/snaxc/compiler/transforms/accfg_config_overlap.py similarity index 97% rename from compiler/transforms/accfg_config_overlap.py rename to snaxc/compiler/transforms/accfg_config_overlap.py index 6d1e07e7..19220771 100644 --- a/compiler/transforms/accfg_config_overlap.py +++ b/snaxc/compiler/transforms/accfg_config_overlap.py @@ -11,9 +11,9 @@ ) from xdsl.rewriter import InsertPoint -from compiler.dialects import accfg -from compiler.inference.helpers import iter_ops_range, previous_ops_of -from compiler.inference.scoped_setups import get_scoped_setup_inputs +from snaxc.dialects import accfg +from snaxc.inference.helpers import iter_ops_range, previous_ops_of +from snaxc.inference.scoped_setups import get_scoped_setup_inputs class BlockLevelSetupAwaitOverlapPattern(RewritePattern): diff --git a/compiler/transforms/accfg_dedup.py b/snaxc/compiler/transforms/accfg_dedup.py similarity index 98% rename from compiler/transforms/accfg_dedup.py rename to snaxc/compiler/transforms/accfg_dedup.py index 5aca2a76..215ef71c 100644 --- a/compiler/transforms/accfg_dedup.py +++ b/snaxc/compiler/transforms/accfg_dedup.py @@ -13,12 +13,12 @@ ) from xdsl.traits import is_side_effect_free -from compiler.dialects import accfg -from compiler.inference.helpers import ( +from snaxc.dialects import accfg +from snaxc.inference.helpers import ( get_initial_value_for_scf_for_lcv, val_is_defined_in_block, ) -from compiler.inference.trace_acc_state import all_setup_ops_in_region, infer_state_of +from snaxc.inference.trace_acc_state import all_setup_ops_in_region, infer_state_of class SimplifyRedundantSetupCalls(RewritePattern): diff --git a/compiler/transforms/accfg_insert_resets.py b/snaxc/compiler/transforms/accfg_insert_resets.py similarity index 97% rename from compiler/transforms/accfg_insert_resets.py rename to snaxc/compiler/transforms/accfg_insert_resets.py index 75ae8054..ca6b0d91 100644 --- a/compiler/transforms/accfg_insert_resets.py +++ b/snaxc/compiler/transforms/accfg_insert_resets.py @@ -9,8 +9,8 @@ from xdsl.passes import ModulePass from xdsl.pattern_rewriter import PatternRewriter, PatternRewriteWalker, RewritePattern -from compiler.dialects import accfg -from compiler.inference.dataflow import ( +from snaxc.dialects import accfg +from snaxc.inference.dataflow import ( get_insertion_points_where_val_dangles, uses_through_controlflow, ) diff --git a/compiler/transforms/alloc_to_global.py b/snaxc/compiler/transforms/alloc_to_global.py similarity index 100% rename from compiler/transforms/alloc_to_global.py rename to snaxc/compiler/transforms/alloc_to_global.py diff --git a/compiler/transforms/backend/__init__.py b/snaxc/compiler/transforms/backend/__init__.py similarity index 100% rename from compiler/transforms/backend/__init__.py rename to snaxc/compiler/transforms/backend/__init__.py diff --git a/compiler/transforms/backend/postprocess_mlir.py b/snaxc/compiler/transforms/backend/postprocess_mlir.py similarity index 100% rename from compiler/transforms/backend/postprocess_mlir.py rename to snaxc/compiler/transforms/backend/postprocess_mlir.py diff --git a/compiler/transforms/clear_memory_space.py b/snaxc/compiler/transforms/clear_memory_space.py similarity index 98% rename from compiler/transforms/clear_memory_space.py rename to snaxc/compiler/transforms/clear_memory_space.py index 677a2908..f48c6188 100644 --- a/compiler/transforms/clear_memory_space.py +++ b/snaxc/compiler/transforms/clear_memory_space.py @@ -4,7 +4,7 @@ from xdsl.passes import ModulePass from xdsl.utils.hints import isa -from compiler.dialects.tsl import TiledStridedLayoutAttr +from snaxc.dialects.tsl import TiledStridedLayoutAttr class ClearMemorySpace(ModulePass): diff --git a/compiler/transforms/convert_accfg_to_csr.py b/snaxc/compiler/transforms/convert_accfg_to_csr.py similarity index 98% rename from compiler/transforms/convert_accfg_to_csr.py rename to snaxc/compiler/transforms/convert_accfg_to_csr.py index 3c5f3f49..dfdc1bd2 100644 --- a/compiler/transforms/convert_accfg_to_csr.py +++ b/snaxc/compiler/transforms/convert_accfg_to_csr.py @@ -15,9 +15,9 @@ op_type_rewrite_pattern, ) -from compiler.accelerators.accelerator import Accelerator -from compiler.accelerators.registry import AcceleratorRegistry -from compiler.dialects import accfg +from snaxc.accelerators.accelerator import Accelerator +from snaxc.accelerators.registry import AcceleratorRegistry +from snaxc.dialects import accfg @dataclass diff --git a/compiler/transforms/convert_dart_to_snax_stream.py b/snaxc/compiler/transforms/convert_dart_to_snax_stream.py similarity index 97% rename from compiler/transforms/convert_dart_to_snax_stream.py rename to snaxc/compiler/transforms/convert_dart_to_snax_stream.py index cf5b7525..7bd9f154 100644 --- a/compiler/transforms/convert_dart_to_snax_stream.py +++ b/snaxc/compiler/transforms/convert_dart_to_snax_stream.py @@ -11,10 +11,10 @@ op_type_rewrite_pattern, ) -from compiler.accelerators.registry import AcceleratorRegistry -from compiler.accelerators.snax import SNAXStreamer -from compiler.dialects import dart, snax_stream -from compiler.ir.dart.affine_transform import AffineTransform +from snaxc.accelerators.registry import AcceleratorRegistry +from snaxc.accelerators.snax import SNAXStreamer +from snaxc.dialects import dart, snax_stream +from snaxc.ir.dart.affine_transform import AffineTransform @dataclass diff --git a/compiler/transforms/convert_kernel_to_linalg.py b/snaxc/compiler/transforms/convert_kernel_to_linalg.py similarity index 97% rename from compiler/transforms/convert_kernel_to_linalg.py rename to snaxc/compiler/transforms/convert_kernel_to_linalg.py index 8274b1a1..86d3abb6 100644 --- a/compiler/transforms/convert_kernel_to_linalg.py +++ b/snaxc/compiler/transforms/convert_kernel_to_linalg.py @@ -8,7 +8,7 @@ op_type_rewrite_pattern, ) -from compiler.dialects.kernel import Parsable +from snaxc.dialects.kernel import Parsable class LowerLinalgBody(RewritePattern): diff --git a/compiler/transforms/convert_linalg_to_accfg.py b/snaxc/compiler/transforms/convert_linalg_to_accfg.py similarity index 98% rename from compiler/transforms/convert_linalg_to_accfg.py rename to snaxc/compiler/transforms/convert_linalg_to_accfg.py index 6d8c0edb..e91293be 100644 --- a/compiler/transforms/convert_linalg_to_accfg.py +++ b/snaxc/compiler/transforms/convert_linalg_to_accfg.py @@ -12,10 +12,10 @@ op_type_rewrite_pattern, ) -from compiler.accelerators.registry import AcceleratorRegistry -from compiler.dialects import accfg -from compiler.dialects.snax_stream import StreamingRegionOp -from compiler.inference.helpers import ( +from snaxc.accelerators.registry import AcceleratorRegistry +from snaxc.dialects import accfg +from snaxc.dialects.snax_stream import StreamingRegionOp +from snaxc.inference.helpers import ( calc_if_state_delta, find_all_acc_names_in_region, find_existing_block_arg, diff --git a/compiler/transforms/convert_linalg_to_kernel.py b/snaxc/compiler/transforms/convert_linalg_to_kernel.py similarity index 98% rename from compiler/transforms/convert_linalg_to_kernel.py rename to snaxc/compiler/transforms/convert_linalg_to_kernel.py index 4631da94..03e30d9c 100644 --- a/compiler/transforms/convert_linalg_to_kernel.py +++ b/snaxc/compiler/transforms/convert_linalg_to_kernel.py @@ -11,7 +11,7 @@ ) from xdsl.rewriter import InsertPoint -from compiler.dialects.kernel import Kernel, Parsable +from snaxc.dialects.kernel import Kernel, Parsable def check_kernel_equivalence(block_a: Block, block_b: Block) -> bool: diff --git a/compiler/transforms/convert_tosa_to_kernel.py b/snaxc/compiler/transforms/convert_tosa_to_kernel.py similarity index 99% rename from compiler/transforms/convert_tosa_to_kernel.py rename to snaxc/compiler/transforms/convert_tosa_to_kernel.py index 3809d852..63cec662 100644 --- a/compiler/transforms/convert_tosa_to_kernel.py +++ b/snaxc/compiler/transforms/convert_tosa_to_kernel.py @@ -13,7 +13,7 @@ ) from xdsl.utils.hints import isa -from compiler.dialects import kernel +from snaxc.dialects import kernel def assert_int8(val: float | int) -> int: diff --git a/compiler/transforms/dart/__init__.py b/snaxc/compiler/transforms/dart/__init__.py similarity index 100% rename from compiler/transforms/dart/__init__.py rename to snaxc/compiler/transforms/dart/__init__.py diff --git a/compiler/transforms/dart/convert_linalg_to_dart.py b/snaxc/compiler/transforms/dart/convert_linalg_to_dart.py similarity index 99% rename from compiler/transforms/dart/convert_linalg_to_dart.py rename to snaxc/compiler/transforms/dart/convert_linalg_to_dart.py index 5b8aacb1..7077a344 100644 --- a/compiler/transforms/dart/convert_linalg_to_dart.py +++ b/snaxc/compiler/transforms/dart/convert_linalg_to_dart.py @@ -13,7 +13,7 @@ ) from xdsl.rewriter import InsertPoint -from compiler.dialects import dart +from snaxc.dialects import dart @dataclass diff --git a/compiler/transforms/dart/dart_fuse_operations.py b/snaxc/compiler/transforms/dart/dart_fuse_operations.py similarity index 99% rename from compiler/transforms/dart/dart_fuse_operations.py rename to snaxc/compiler/transforms/dart/dart_fuse_operations.py index 255d5d90..8f8cff66 100644 --- a/compiler/transforms/dart/dart_fuse_operations.py +++ b/snaxc/compiler/transforms/dart/dart_fuse_operations.py @@ -13,7 +13,7 @@ ) from xdsl.rewriter import InsertPoint -from compiler.dialects import dart +from snaxc.dialects import dart @dataclass diff --git a/compiler/transforms/dart/dart_layout_resolution.py b/snaxc/compiler/transforms/dart/dart_layout_resolution.py similarity index 95% rename from compiler/transforms/dart/dart_layout_resolution.py rename to snaxc/compiler/transforms/dart/dart_layout_resolution.py index 61f6067e..318f02f2 100644 --- a/compiler/transforms/dart/dart_layout_resolution.py +++ b/snaxc/compiler/transforms/dart/dart_layout_resolution.py @@ -14,9 +14,9 @@ op_type_rewrite_pattern, ) -from compiler.dialects import dart -from compiler.ir.dart.access_pattern import Schedule, SchedulePattern -from compiler.ir.dart.affine_transform import AffineTransform +from snaxc.dialects import dart +from snaxc.ir.dart.access_pattern import Schedule, SchedulePattern +from snaxc.ir.dart.affine_transform import AffineTransform @dataclass diff --git a/compiler/transforms/dart/dart_scheduler.py b/snaxc/compiler/transforms/dart/dart_scheduler.py similarity index 88% rename from compiler/transforms/dart/dart_scheduler.py rename to snaxc/compiler/transforms/dart/dart_scheduler.py index 85278fdc..1d633987 100644 --- a/compiler/transforms/dart/dart_scheduler.py +++ b/snaxc/compiler/transforms/dart/dart_scheduler.py @@ -11,11 +11,11 @@ op_type_rewrite_pattern, ) -from compiler.accelerators.registry import AcceleratorRegistry -from compiler.accelerators.snax import SNAXStreamer -from compiler.dialects import dart -from compiler.ir.dart.access_pattern import Schedule, SchedulePattern -from compiler.ir.dart.scheduler import scheduler +from snaxc.accelerators.registry import AcceleratorRegistry +from snaxc.accelerators.snax import SNAXStreamer +from snaxc.dialects import dart +from snaxc.ir.dart.access_pattern import Schedule, SchedulePattern +from snaxc.ir.dart.scheduler import scheduler @dataclass diff --git a/compiler/transforms/dispatch_kernels.py b/snaxc/compiler/transforms/dispatch_kernels.py similarity index 92% rename from compiler/transforms/dispatch_kernels.py rename to snaxc/compiler/transforms/dispatch_kernels.py index 1e65d40e..1bd3a196 100644 --- a/compiler/transforms/dispatch_kernels.py +++ b/snaxc/compiler/transforms/dispatch_kernels.py @@ -11,11 +11,11 @@ op_type_rewrite_pattern, ) -from compiler.accelerators.dispatching import DispatchTemplate -from compiler.accelerators.registry import AcceleratorRegistry -from compiler.accelerators.snax import SNAXStreamer -from compiler.dialects.accfg import AcceleratorOp -from compiler.dialects.kernel import KernelOp +from snaxc.accelerators.dispatching import DispatchTemplate +from snaxc.accelerators.registry import AcceleratorRegistry +from snaxc.accelerators.snax import SNAXStreamer +from snaxc.dialects.accfg import AcceleratorOp +from snaxc.dialects.kernel import KernelOp class DispatchTemplatePattern(RewritePattern): diff --git a/compiler/transforms/dispatch_regions.py b/snaxc/compiler/transforms/dispatch_regions.py similarity index 98% rename from compiler/transforms/dispatch_regions.py rename to snaxc/compiler/transforms/dispatch_regions.py index 7f8c0b1c..6cbf433e 100644 --- a/compiler/transforms/dispatch_regions.py +++ b/snaxc/compiler/transforms/dispatch_regions.py @@ -14,7 +14,7 @@ from xdsl.rewriter import InsertPoint from xdsl.traits import SymbolTable -from compiler.util.dispatching_rules import dispatch_to_compute, dispatch_to_dm +from snaxc.util.dispatching_rules import dispatch_to_compute, dispatch_to_dm @dataclass diff --git a/compiler/transforms/frontend/__init__.py b/snaxc/compiler/transforms/frontend/__init__.py similarity index 100% rename from compiler/transforms/frontend/__init__.py rename to snaxc/compiler/transforms/frontend/__init__.py diff --git a/compiler/transforms/frontend/preprocess_mlir.py b/snaxc/compiler/transforms/frontend/preprocess_mlir.py similarity index 100% rename from compiler/transforms/frontend/preprocess_mlir.py rename to snaxc/compiler/transforms/frontend/preprocess_mlir.py diff --git a/compiler/transforms/frontend/preprocess_mlperf_tiny.py b/snaxc/compiler/transforms/frontend/preprocess_mlperf_tiny.py similarity index 97% rename from compiler/transforms/frontend/preprocess_mlperf_tiny.py rename to snaxc/compiler/transforms/frontend/preprocess_mlperf_tiny.py index 0338c805..44724b9b 100644 --- a/compiler/transforms/frontend/preprocess_mlperf_tiny.py +++ b/snaxc/compiler/transforms/frontend/preprocess_mlperf_tiny.py @@ -15,10 +15,10 @@ from xdsl.rewriter import InsertPoint from xdsl.transforms.mlir_opt import MLIROptPass -from compiler.dialects import snax -from compiler.transforms.alloc_to_global import AllocToGlobal -from compiler.transforms.convert_tosa_to_kernel import RescaleClampPattern -from compiler.transforms.test.insert_debugs import InsertDebugStatements +from snaxc.dialects import snax +from snaxc.transforms.alloc_to_global import AllocToGlobal +from snaxc.transforms.convert_tosa_to_kernel import RescaleClampPattern +from snaxc.transforms.test.insert_debugs import InsertDebugStatements class InsertStaticFunctionCall(RewritePattern): diff --git a/compiler/transforms/insert_accfg_op.py b/snaxc/compiler/transforms/insert_accfg_op.py similarity index 95% rename from compiler/transforms/insert_accfg_op.py rename to snaxc/compiler/transforms/insert_accfg_op.py index 2747a8f1..0c542fc5 100644 --- a/compiler/transforms/insert_accfg_op.py +++ b/snaxc/compiler/transforms/insert_accfg_op.py @@ -5,7 +5,7 @@ from xdsl.passes import ModulePass from xdsl.traits import SymbolTable -from compiler.accelerators.registry import AcceleratorRegistry +from snaxc.accelerators.registry import AcceleratorRegistry @dataclass(frozen=True) diff --git a/compiler/transforms/insert_sync_barrier.py b/snaxc/compiler/transforms/insert_sync_barrier.py similarity index 94% rename from compiler/transforms/insert_sync_barrier.py rename to snaxc/compiler/transforms/insert_sync_barrier.py index 58cea27f..b25f5346 100644 --- a/compiler/transforms/insert_sync_barrier.py +++ b/snaxc/compiler/transforms/insert_sync_barrier.py @@ -4,8 +4,8 @@ from xdsl.passes import ModulePass from xdsl.rewriter import InsertPoint, Rewriter -from compiler.dialects import snax -from compiler.util.dispatching_rules import dispatch_to_compute, dispatch_to_dm +from snaxc.dialects import snax +from snaxc.util.dispatching_rules import dispatch_to_compute, dispatch_to_dm class InsertSyncBarrier(ModulePass): diff --git a/compiler/transforms/memref_to_snax.py b/snaxc/compiler/transforms/memref_to_snax.py similarity index 97% rename from compiler/transforms/memref_to_snax.py rename to snaxc/compiler/transforms/memref_to_snax.py index 0f78d56d..39c001e4 100644 --- a/compiler/transforms/memref_to_snax.py +++ b/snaxc/compiler/transforms/memref_to_snax.py @@ -19,9 +19,9 @@ op_type_rewrite_pattern, ) -from compiler.dialects import snax -from compiler.dialects.tsl import TiledStridedLayoutAttr -from compiler.util.snax_memory import L1 +from snaxc.dialects import snax +from snaxc.dialects.tsl import TiledStridedLayoutAttr +from snaxc.util.snax_memory import L1 class AllocOpRewrite(RewritePattern): diff --git a/compiler/transforms/realize_memref_casts.py b/snaxc/compiler/transforms/realize_memref_casts.py similarity index 98% rename from compiler/transforms/realize_memref_casts.py rename to snaxc/compiler/transforms/realize_memref_casts.py index 537f69e4..b1552ced 100644 --- a/compiler/transforms/realize_memref_casts.py +++ b/snaxc/compiler/transforms/realize_memref_casts.py @@ -12,8 +12,8 @@ from xdsl.rewriter import InsertPoint from xdsl.utils.hints import isa -from compiler.dialects import dart -from compiler.dialects.snax import LayoutCast +from snaxc.dialects import dart +from snaxc.dialects.snax import LayoutCast def is_cast_op(op: Operation) -> bool: diff --git a/compiler/transforms/reuse_memref_allocs.py b/snaxc/compiler/transforms/reuse_memref_allocs.py similarity index 100% rename from compiler/transforms/reuse_memref_allocs.py rename to snaxc/compiler/transforms/reuse_memref_allocs.py diff --git a/compiler/transforms/set_memory_layout.py b/snaxc/compiler/transforms/set_memory_layout.py similarity index 94% rename from compiler/transforms/set_memory_layout.py rename to snaxc/compiler/transforms/set_memory_layout.py index 714ef164..43bd2780 100644 --- a/compiler/transforms/set_memory_layout.py +++ b/snaxc/compiler/transforms/set_memory_layout.py @@ -16,11 +16,11 @@ from xdsl.rewriter import InsertPoint from xdsl.utils.hints import isa -from compiler.dialects import dart -from compiler.dialects.snax import LayoutCast -from compiler.dialects.tsl import TiledStridedLayoutAttr -from compiler.ir.dart.access_pattern import Schedule, SchedulePattern -from compiler.ir.tsl import Stride, TiledStride, TiledStridedLayout +from snaxc.dialects import dart +from snaxc.dialects.snax import LayoutCast +from snaxc.dialects.tsl import TiledStridedLayoutAttr +from snaxc.ir.dart.access_pattern import Schedule, SchedulePattern +from snaxc.ir.tsl import Stride, TiledStride, TiledStridedLayout @dataclass diff --git a/compiler/transforms/set_memory_space.py b/snaxc/compiler/transforms/set_memory_space.py similarity index 99% rename from compiler/transforms/set_memory_space.py rename to snaxc/compiler/transforms/set_memory_space.py index ff5477ac..40d1566f 100644 --- a/compiler/transforms/set_memory_space.py +++ b/snaxc/compiler/transforms/set_memory_space.py @@ -12,8 +12,8 @@ ) from xdsl.utils.hints import isa -from compiler.dialects import dart -from compiler.util.snax_memory import L1, L3 +from snaxc.dialects import dart +from snaxc.util.snax_memory import L1, L3 class InitFuncMemorySpace(RewritePattern): diff --git a/compiler/transforms/snax_bufferize.py b/snaxc/compiler/transforms/snax_bufferize.py similarity index 98% rename from compiler/transforms/snax_bufferize.py rename to snaxc/compiler/transforms/snax_bufferize.py index aa6eed69..b003d9cf 100644 --- a/compiler/transforms/snax_bufferize.py +++ b/snaxc/compiler/transforms/snax_bufferize.py @@ -12,7 +12,7 @@ ) from xdsl.transforms.mlir_opt import MLIROptPass -from compiler.dialects import dart +from snaxc.dialects import dart @dataclass diff --git a/compiler/transforms/snax_copy_to_dma.py b/snaxc/compiler/transforms/snax_copy_to_dma.py similarity index 99% rename from compiler/transforms/snax_copy_to_dma.py rename to snaxc/compiler/transforms/snax_copy_to_dma.py index 2a565677..d2a6ee53 100644 --- a/compiler/transforms/snax_copy_to_dma.py +++ b/snaxc/compiler/transforms/snax_copy_to_dma.py @@ -26,8 +26,8 @@ ) from xdsl.traits import SymbolTable -from compiler.dialects.tsl import TiledStridedLayoutAttr -from compiler.ir.tsl import TiledStridedLayout +from snaxc.dialects.tsl import TiledStridedLayoutAttr +from snaxc.ir.tsl import TiledStridedLayout def get_total_size_op(source: SSAValue): diff --git a/compiler/transforms/snax_lower_mcycle.py b/snaxc/compiler/transforms/snax_lower_mcycle.py similarity index 96% rename from compiler/transforms/snax_lower_mcycle.py rename to snaxc/compiler/transforms/snax_lower_mcycle.py index 9d2aa0aa..15783ff5 100644 --- a/compiler/transforms/snax_lower_mcycle.py +++ b/snaxc/compiler/transforms/snax_lower_mcycle.py @@ -8,7 +8,7 @@ op_type_rewrite_pattern, ) -from compiler.dialects import snax +from snaxc.dialects import snax class ConvertMCycleToLLVM(RewritePattern): diff --git a/compiler/transforms/snax_to_func.py b/snaxc/compiler/transforms/snax_to_func.py similarity index 98% rename from compiler/transforms/snax_to_func.py rename to snaxc/compiler/transforms/snax_to_func.py index 8122c4b1..9a3bcc56 100644 --- a/compiler/transforms/snax_to_func.py +++ b/snaxc/compiler/transforms/snax_to_func.py @@ -11,8 +11,8 @@ ) from xdsl.traits import SymbolTable -from compiler.dialects import snax -from compiler.util.snax_memory import L1 +from snaxc.dialects import snax +from snaxc.util.snax_memory import L1 class InsertFunctionCall(RewritePattern): diff --git a/compiler/transforms/test/__init__.py b/snaxc/compiler/transforms/test/__init__.py similarity index 100% rename from compiler/transforms/test/__init__.py rename to snaxc/compiler/transforms/test/__init__.py diff --git a/compiler/transforms/test/debug_to_func.py b/snaxc/compiler/transforms/test/debug_to_func.py similarity index 98% rename from compiler/transforms/test/debug_to_func.py rename to snaxc/compiler/transforms/test/debug_to_func.py index 2b3dd316..89a71287 100644 --- a/compiler/transforms/test/debug_to_func.py +++ b/snaxc/compiler/transforms/test/debug_to_func.py @@ -10,7 +10,7 @@ ) from xdsl.traits import SymbolTable -from compiler.dialects.test import debug +from snaxc.dialects.test import debug class DebugToFunc(RewritePattern): diff --git a/compiler/transforms/test/insert_debugs.py b/snaxc/compiler/transforms/test/insert_debugs.py similarity index 97% rename from compiler/transforms/test/insert_debugs.py rename to snaxc/compiler/transforms/test/insert_debugs.py index 34bc466d..cf9a545f 100644 --- a/compiler/transforms/test/insert_debugs.py +++ b/snaxc/compiler/transforms/test/insert_debugs.py @@ -11,7 +11,7 @@ ) from xdsl.rewriter import InsertPoint -from compiler.dialects.test import debug +from snaxc.dialects.test import debug @dataclass(frozen=True) diff --git a/compiler/transforms/test/test_add_mcycle_around_launch.py b/snaxc/compiler/transforms/test/test_add_mcycle_around_launch.py similarity index 96% rename from compiler/transforms/test/test_add_mcycle_around_launch.py rename to snaxc/compiler/transforms/test/test_add_mcycle_around_launch.py index edc0ecaa..6dd3e203 100644 --- a/compiler/transforms/test/test_add_mcycle_around_launch.py +++ b/snaxc/compiler/transforms/test/test_add_mcycle_around_launch.py @@ -9,7 +9,7 @@ ) from xdsl.rewriter import InsertPoint -from compiler.dialects import accfg, snax +from snaxc.dialects import accfg, snax class InsertBeforeLaunch(RewritePattern): diff --git a/compiler/transforms/test_add_mcycle_around_loop.py b/snaxc/compiler/transforms/test_add_mcycle_around_loop.py similarity index 97% rename from compiler/transforms/test_add_mcycle_around_loop.py rename to snaxc/compiler/transforms/test_add_mcycle_around_loop.py index 649fea7c..86bcbdaf 100644 --- a/compiler/transforms/test_add_mcycle_around_loop.py +++ b/snaxc/compiler/transforms/test_add_mcycle_around_loop.py @@ -9,7 +9,7 @@ ) from xdsl.rewriter import InsertPoint -from compiler.dialects import snax +from snaxc.dialects import snax class InsertMcycleForLoop(RewritePattern): diff --git a/compiler/transforms/test_remove_memref_copy.py b/snaxc/compiler/transforms/test_remove_memref_copy.py similarity index 100% rename from compiler/transforms/test_remove_memref_copy.py rename to snaxc/compiler/transforms/test_remove_memref_copy.py diff --git a/compiler/util/__init__.py b/snaxc/compiler/util/__init__.py similarity index 100% rename from compiler/util/__init__.py rename to snaxc/compiler/util/__init__.py diff --git a/compiler/util/canonicalize_affine.py b/snaxc/compiler/util/canonicalize_affine.py similarity index 100% rename from compiler/util/canonicalize_affine.py rename to snaxc/compiler/util/canonicalize_affine.py diff --git a/compiler/util/dispatching_rules.py b/snaxc/compiler/util/dispatching_rules.py similarity index 94% rename from compiler/util/dispatching_rules.py rename to snaxc/compiler/util/dispatching_rules.py index b1814012..a4367eaa 100644 --- a/compiler/util/dispatching_rules.py +++ b/snaxc/compiler/util/dispatching_rules.py @@ -1,7 +1,7 @@ from xdsl.dialects import linalg, memref from xdsl.ir import Operation -from compiler.dialects import dart +from snaxc.dialects import dart def dispatch_to_dm(op: Operation): diff --git a/compiler/util/memref_descriptor.py b/snaxc/compiler/util/memref_descriptor.py similarity index 100% rename from compiler/util/memref_descriptor.py rename to snaxc/compiler/util/memref_descriptor.py diff --git a/compiler/util/pack_bitlist.py b/snaxc/compiler/util/pack_bitlist.py similarity index 100% rename from compiler/util/pack_bitlist.py rename to snaxc/compiler/util/pack_bitlist.py diff --git a/compiler/util/snax_memory.py b/snaxc/compiler/util/snax_memory.py similarity index 100% rename from compiler/util/snax_memory.py rename to snaxc/compiler/util/snax_memory.py diff --git a/tests/filecheck/lit.cfg b/tests/filecheck/lit.cfg index 0408a756..66e784f3 100644 --- a/tests/filecheck/lit.cfg +++ b/tests/filecheck/lit.cfg @@ -8,8 +8,8 @@ config.name = "SNAX" config.test_format = lit.formats.ShTest(preamble_commands=[f"cd {snax_src}"]) config.suffixes = ['.test', '.mlir', '.py'] -config.substitutions.append(('XDSL_PARSING_DIAG', "./compiler/snax-opt %s --print-op-generic --parsing-diagnostics --split-input-file | filecheck %s")) -config.substitutions.append(('XDSL_VERIFY_DIAG', "./compiler/snax-opt %s --print-op-generic --verify-diagnostics --split-input-file | filecheck %s")) -config.substitutions.append(('XDSL_ROUNDTRIP', "./compiler/snax-opt %s --print-op-generic --split-input-file | ./compiler/snax-opt --split-input-file | filecheck %s")) -config.substitutions.append(('XDSL_SINGLETRIP', "./compiler/snax-opt %s --split-input-file | filecheck %s")) -config.substitutions.append(("XDSL_GENERIC_ROUNDTRIP", "./compiler/snax-opt %s --print-op-generic --split-input-file | filecheck %s --check-prefix=CHECK-GENERIC")) +config.substitutions.append(('XDSL_PARSING_DIAG', "snax-opt %s --print-op-generic --parsing-diagnostics --split-input-file | filecheck %s")) +config.substitutions.append(('XDSL_VERIFY_DIAG', "snax-opt %s --print-op-generic --verify-diagnostics --split-input-file | filecheck %s")) +config.substitutions.append(('XDSL_ROUNDTRIP', "snax-opt %s --print-op-generic --split-input-file | ./compiler/snax-opt --split-input-file | filecheck %s")) +config.substitutions.append(('XDSL_SINGLETRIP', "snax-opt %s --split-input-file | filecheck %s")) +config.substitutions.append(("XDSL_GENERIC_ROUNDTRIP", "snax-opt %s --print-op-generic --split-input-file | filecheck %s --check-prefix=CHECK-GENERIC")) diff --git a/tests/filecheck/transforms/acc-dedup.mlir b/tests/filecheck/transforms/acc-dedup.mlir index cf1ac68a..4ef7104e 100644 --- a/tests/filecheck/transforms/acc-dedup.mlir +++ b/tests/filecheck/transforms/acc-dedup.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt %s -p accfg-dedup | filecheck %s +// RUN: snax-opt %s -p accfg-dedup | filecheck %s func.func public @simple_mult(%arg0 : memref, %arg1 : memref, %arg2 : memref) { %0 = arith.constant 0 : index diff --git a/tests/filecheck/transforms/accfg-end-to-end.mlir b/tests/filecheck/transforms/accfg-end-to-end.mlir index 7f87375f..94c295fb 100644 --- a/tests/filecheck/transforms/accfg-end-to-end.mlir +++ b/tests/filecheck/transforms/accfg-end-to-end.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt --split-input-file -p 'convert-linalg-to-accfg,mlir-opt{executable=mlir-opt generic=true arguments=-cse,-canonicalize,-allow-unregistered-dialect,-mlir-print-op-generic,-split-input-file},accfg-dedup,convert-accfg-to-csr' %s | filecheck %s +// RUN: snax-opt --split-input-file -p 'convert-linalg-to-accfg,mlir-opt{executable=mlir-opt generic=true arguments=-cse,-canonicalize,-allow-unregistered-dialect,-mlir-print-op-generic,-split-input-file},accfg-dedup,convert-accfg-to-csr' %s | filecheck %s builtin.module { "accfg.accelerator"() <{ diff --git a/tests/filecheck/transforms/accfg-trace-states.mlir b/tests/filecheck/transforms/accfg-trace-states.mlir index 8acc167a..5269cd71 100644 --- a/tests/filecheck/transforms/accfg-trace-states.mlir +++ b/tests/filecheck/transforms/accfg-trace-states.mlir @@ -1,4 +1,4 @@ -//RUN: ./compiler/snax-opt -p accfg-trace-states %s | filecheck %s +//RUN: snax-opt -p accfg-trace-states %s | filecheck %s // CHECK-NEXT: builtin.module { diff --git a/tests/filecheck/transforms/alloc-to-global.mlir b/tests/filecheck/transforms/alloc-to-global.mlir index 90a5cfc7..efff4225 100644 --- a/tests/filecheck/transforms/alloc-to-global.mlir +++ b/tests/filecheck/transforms/alloc-to-global.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt %s -p alloc-to-global | filecheck %s +// RUN: snax-opt %s -p alloc-to-global | filecheck %s %0 = memref.alloc() {"alignment" = 64 : i64} : memref<16x16xi8> diff --git a/tests/filecheck/transforms/convert-acc-to-csr.mlir b/tests/filecheck/transforms/convert-acc-to-csr.mlir index b85ba49e..36782829 100644 --- a/tests/filecheck/transforms/convert-acc-to-csr.mlir +++ b/tests/filecheck/transforms/convert-acc-to-csr.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt %s -p convert-accfg-to-csr --split-input-file | filecheck %s +// RUN: snax-opt %s -p convert-accfg-to-csr --split-input-file | filecheck %s builtin.module { diff --git a/tests/filecheck/transforms/convert-dart-to-snax-stream.mlir b/tests/filecheck/transforms/convert-dart-to-snax-stream.mlir index c86e3ad9..bb6ee361 100644 --- a/tests/filecheck/transforms/convert-dart-to-snax-stream.mlir +++ b/tests/filecheck/transforms/convert-dart-to-snax-stream.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt --split-input-file %s -p insert-accfg-op{accelerator=snax_alu},insert-accfg-op{accelerator=snax_gemmx},dart-scheduler,dart-layout-resolution,convert-dart-to-snax-stream | filecheck %s +// RUN: snax-opt --split-input-file %s -p insert-accfg-op{accelerator=snax_alu},insert-accfg-op{accelerator=snax_gemmx},dart-scheduler,dart-layout-resolution,convert-dart-to-snax-stream | filecheck %s func.func public @streamer_add(%arg0 : memref<16xi64>, %arg1 : memref<16xi64>, %arg2 : memref<16xi64>) { "dart.operation"(%arg0, %arg1, %arg2) <{patterns = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], accelerator = "snax_alu", operandSegmentSizes = array}> ({ diff --git a/tests/filecheck/transforms/convert-kernel-to-linalg.mlir b/tests/filecheck/transforms/convert-kernel-to-linalg.mlir index a525c510..f9067c62 100644 --- a/tests/filecheck/transforms/convert-kernel-to-linalg.mlir +++ b/tests/filecheck/transforms/convert-kernel-to-linalg.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt --split-input-file -p convert-kernel-to-linalg %s | filecheck %s +// RUN: snax-opt --split-input-file -p convert-kernel-to-linalg %s | filecheck %s %0, %1, %2 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>) linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%0, %1 : memref<64xi32>, memref<64xi32>) outs(%2 : memref<64xi32>) { diff --git a/tests/filecheck/transforms/convert-linalg-to-acc.mlir b/tests/filecheck/transforms/convert-linalg-to-acc.mlir index ca119456..5ef0ca10 100644 --- a/tests/filecheck/transforms/convert-linalg-to-acc.mlir +++ b/tests/filecheck/transforms/convert-linalg-to-acc.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt -p 'convert-linalg-to-accfg,mlir-opt{executable=mlir-opt generic=true arguments=-cse,-canonicalize,-allow-unregistered-dialect,-mlir-print-op-generic,-split-input-file}' %s | filecheck %s +// RUN: snax-opt -p 'convert-linalg-to-accfg,mlir-opt{executable=mlir-opt generic=true arguments=-cse,-canonicalize,-allow-unregistered-dialect,-mlir-print-op-generic,-split-input-file}' %s | filecheck %s "builtin.module"() ({ "accfg.accelerator"() <{ diff --git a/tests/filecheck/transforms/convert-linalg-to-dart.mlir b/tests/filecheck/transforms/convert-linalg-to-dart.mlir index 822baf81..c480b2de 100644 --- a/tests/filecheck/transforms/convert-linalg-to-dart.mlir +++ b/tests/filecheck/transforms/convert-linalg-to-dart.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt -p convert-linalg-to-dart %s | filecheck %s +// RUN: snax-opt -p convert-linalg-to-dart %s | filecheck %s %arg0, %arg1, %arg2 = "test.op"() : () -> (tensor<16x16xi8>, tensor<16x16xi8>, tensor<16x16xi32>) %c0_i32 = arith.constant 0 : i32 diff --git a/tests/filecheck/transforms/convert-linalg-to-kernel.mlir b/tests/filecheck/transforms/convert-linalg-to-kernel.mlir index 04c8cead..2bf0e818 100644 --- a/tests/filecheck/transforms/convert-linalg-to-kernel.mlir +++ b/tests/filecheck/transforms/convert-linalg-to-kernel.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt --split-input-file -p convert-linalg-to-kernel %s | filecheck %s +// RUN: snax-opt --split-input-file -p convert-linalg-to-kernel %s | filecheck %s %0, %1, %2 = "test.op"() : () -> (memref<64xi32>, memref<64xi32>, memref<64xi32>) linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%0, %1 : memref<64xi32>, memref<64xi32>) outs(%2 : memref<64xi32>) { diff --git a/tests/filecheck/transforms/convert-tosa-to-kernel.mlir b/tests/filecheck/transforms/convert-tosa-to-kernel.mlir index a87c9ca2..9a4677d4 100644 --- a/tests/filecheck/transforms/convert-tosa-to-kernel.mlir +++ b/tests/filecheck/transforms/convert-tosa-to-kernel.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt -p convert-tosa-to-kernel %s | filecheck %s +// RUN: snax-opt -p convert-tosa-to-kernel %s | filecheck %s %0 = "test.op"() : () -> tensor %1 = tosa.rescale %0 {"double_round" = true, "input_zp" = 0 : i32, "multiplier" = array, "output_zp" = -128 : i32, "per_channel" = false, "scale32" = true, "shift" = array} : (tensor) -> tensor diff --git a/tests/filecheck/transforms/copy_to_dma.mlir b/tests/filecheck/transforms/copy_to_dma.mlir index 8afae9bc..61791b90 100644 --- a/tests/filecheck/transforms/copy_to_dma.mlir +++ b/tests/filecheck/transforms/copy_to_dma.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt --split-input-file %s -p snax-copy-to-dma | filecheck %s +// RUN: snax-opt --split-input-file %s -p snax-copy-to-dma | filecheck %s "builtin.module"() ({ "func.func"() <{"sym_name" = "simple_mult", "function_type" = (memref, memref) -> (), "sym_visibility" = "public"}> ({ diff --git a/tests/filecheck/transforms/dart/dart-scheduler.mlir b/tests/filecheck/transforms/dart/dart-scheduler.mlir index e2dbdca7..3f50ea66 100644 --- a/tests/filecheck/transforms/dart/dart-scheduler.mlir +++ b/tests/filecheck/transforms/dart/dart-scheduler.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt --split-input-file %s -p insert-accfg-op{accelerator=snax_gemmx},dart-scheduler | filecheck %s +// RUN: snax-opt --split-input-file %s -p insert-accfg-op{accelerator=snax_gemmx},dart-scheduler | filecheck %s func.func @streamer_matmul(%arg0 : memref<16x16xi8>, %arg1 : memref<16x16xi8, strided<[1, 16]>>, %arg2 : memref<16x16xi32>) { %0 = arith.constant 0 : i32 diff --git a/tests/filecheck/transforms/dispatch_kernels.mlir b/tests/filecheck/transforms/dispatch_kernels.mlir index d245f5f1..59505e9c 100644 --- a/tests/filecheck/transforms/dispatch_kernels.mlir +++ b/tests/filecheck/transforms/dispatch_kernels.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt --split-input-file %s -p insert-accfg-op{accelerator=snax_alu},insert-accfg-op{accelerator=snax_gemm},dispatch-kernels --allow-unregistered-dialect --print-op-generic | filecheck %s +// RUN: snax-opt --split-input-file %s -p insert-accfg-op{accelerator=snax_alu},insert-accfg-op{accelerator=snax_gemm},dispatch-kernels --allow-unregistered-dialect --print-op-generic | filecheck %s builtin.module { func.func @mnist(%arg0 : memref, %arg1 : memref<128x128xi8>, %arg2 : memref) -> memref { diff --git a/tests/filecheck/transforms/dispatch_regions.mlir b/tests/filecheck/transforms/dispatch_regions.mlir index bbc65293..c60ddea0 100644 --- a/tests/filecheck/transforms/dispatch_regions.mlir +++ b/tests/filecheck/transforms/dispatch_regions.mlir @@ -1,5 +1,5 @@ -// RUN: ./compiler/snax-opt --split-input-file %s -p dispatch-regions | filecheck %s --check-prefixes=CHECK,NB_TWO -// RUN: ./compiler/snax-opt --split-input-file %s -p dispatch-regions{nb_cores=3} | filecheck %s --check-prefixes=CHECK,NB_THREE +// RUN: snax-opt --split-input-file %s -p dispatch-regions | filecheck %s --check-prefixes=CHECK,NB_TWO +// RUN: snax-opt --split-input-file %s -p dispatch-regions{nb_cores=3} | filecheck %s --check-prefixes=CHECK,NB_THREE // test function without dispatchable ops "builtin.module"() ({ diff --git a/tests/filecheck/transforms/fuse-streaming-regions.mlir b/tests/filecheck/transforms/fuse-streaming-regions.mlir index 22cf4952..e162a941 100644 --- a/tests/filecheck/transforms/fuse-streaming-regions.mlir +++ b/tests/filecheck/transforms/fuse-streaming-regions.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt -p dart-fuse-operations %s | filecheck %s +// RUN: snax-opt -p dart-fuse-operations %s | filecheck %s func.func @streamer_matmul(%arg0 : tensor<16x16xi8>, %arg1 : tensor<16x16xi8>, %arg2 : tensor<16x16xi32>) -> tensor<16x16xi32> { %c0_i32 = arith.constant 0 : i32 diff --git a/tests/filecheck/transforms/insert-acc-op.mlir b/tests/filecheck/transforms/insert-acc-op.mlir index e6ae9ef1..ff07fc7a 100644 --- a/tests/filecheck/transforms/insert-acc-op.mlir +++ b/tests/filecheck/transforms/insert-acc-op.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt %s -p insert-accfg-op{accelerator=snax_hwpe_mult} | filecheck %s +// RUN: snax-opt %s -p insert-accfg-op{accelerator=snax_hwpe_mult} | filecheck %s builtin.module{} diff --git a/tests/filecheck/transforms/insert-sync-barrier.mlir b/tests/filecheck/transforms/insert-sync-barrier.mlir index cd617f91..c42e1a83 100644 --- a/tests/filecheck/transforms/insert-sync-barrier.mlir +++ b/tests/filecheck/transforms/insert-sync-barrier.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt --split-input-file %s -p insert-sync-barrier --print-op-generic | filecheck %s +// RUN: snax-opt --split-input-file %s -p insert-sync-barrier --print-op-generic | filecheck %s // two global ops: no synchronization barrier required "builtin.module"() ({ diff --git a/tests/filecheck/transforms/memref_to_snax.mlir b/tests/filecheck/transforms/memref_to_snax.mlir index 3a19f0de..e99644ad 100644 --- a/tests/filecheck/transforms/memref_to_snax.mlir +++ b/tests/filecheck/transforms/memref_to_snax.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt --split-input-file %s -p memref-to-snax | filecheck %s +// RUN: snax-opt --split-input-file %s -p memref-to-snax | filecheck %s "builtin.module"() ({ %0 = "memref.alloc"() <{"alignment" = 64 : i64, "operandSegmentSizes" = array}> : () -> memref<16x16xi32> diff --git a/tests/filecheck/transforms/realize-memref-casts.mlir b/tests/filecheck/transforms/realize-memref-casts.mlir index 3d3a5cf0..1870797c 100644 --- a/tests/filecheck/transforms/realize-memref-casts.mlir +++ b/tests/filecheck/transforms/realize-memref-casts.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt --split-input-file %s -p realize-memref-casts --print-op-generic | filecheck %s +// RUN: snax-opt --split-input-file %s -p realize-memref-casts --print-op-generic | filecheck %s "builtin.module"() ({ %0 = "test.op"() : () -> (memref<64xi32, "L3">) diff --git a/tests/filecheck/transforms/reuse-memref-allocs.mlir b/tests/filecheck/transforms/reuse-memref-allocs.mlir index 969866bf..0df81aec 100644 --- a/tests/filecheck/transforms/reuse-memref-allocs.mlir +++ b/tests/filecheck/transforms/reuse-memref-allocs.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt --split-input-file %s -p reuse-memref-allocs | filecheck %s +// RUN: snax-opt --split-input-file %s -p reuse-memref-allocs | filecheck %s builtin.module { func.func @streamer_matmul(%arg0 : memref, %arg1 : memref) { diff --git a/tests/filecheck/transforms/rocc-dedup.mlir b/tests/filecheck/transforms/rocc-dedup.mlir index 5b105512..ed5bb4b0 100644 --- a/tests/filecheck/transforms/rocc-dedup.mlir +++ b/tests/filecheck/transforms/rocc-dedup.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt %s -p convert-accfg-to-csr | filecheck %s +// RUN: snax-opt %s -p convert-accfg-to-csr | filecheck %s builtin.module { diff --git a/tests/filecheck/transforms/set-memory-layout.mlir b/tests/filecheck/transforms/set-memory-layout.mlir index 11667d50..7c816b28 100644 --- a/tests/filecheck/transforms/set-memory-layout.mlir +++ b/tests/filecheck/transforms/set-memory-layout.mlir @@ -1,5 +1,5 @@ -// RUN: ./compiler/snax-opt --split-input-file %s -p set-memory-layout{tiled=false} --print-op-generic | filecheck %s -// RUN: ./compiler/snax-opt --split-input-file %s -p set-memory-layout{tiled=true} --print-op-generic | filecheck %s --check-prefix=TILED +// RUN: snax-opt --split-input-file %s -p set-memory-layout{tiled=false} --print-op-generic | filecheck %s +// RUN: snax-opt --split-input-file %s -p set-memory-layout{tiled=true} --print-op-generic | filecheck %s --check-prefix=TILED func.func @gemm(%arg0 : memref<16x16xi8, "L1">, %arg1 : memref<16x16xi8, "L1">, %arg2 : memref<16x16xi32, "L1">) { %0 = arith.constant 0 : i32 diff --git a/tests/filecheck/transforms/set-memory-space.mlir b/tests/filecheck/transforms/set-memory-space.mlir index 300a0bf7..6a89068e 100644 --- a/tests/filecheck/transforms/set-memory-space.mlir +++ b/tests/filecheck/transforms/set-memory-space.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt --split-input-file %s -p set-memory-space | filecheck %s +// RUN: snax-opt --split-input-file %s -p set-memory-space | filecheck %s "builtin.module"() ({ %0 = "memref.get_global"() <{"name" = @constant}> : () -> memref<640xi32> diff --git a/tests/filecheck/transforms/snax-pin-core-end-to-end.mlir b/tests/filecheck/transforms/snax-pin-core-end-to-end.mlir index 6e27eb99..ef4a9a07 100644 --- a/tests/filecheck/transforms/snax-pin-core-end-to-end.mlir +++ b/tests/filecheck/transforms/snax-pin-core-end-to-end.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt %s -p dispatch-regions{nb_cores=3},function-constant-pinning,snax-to-func | mlir-opt --canonicalize --inline| filecheck %s +// RUN: snax-opt %s -p dispatch-regions{nb_cores=3},function-constant-pinning,snax-to-func | mlir-opt --canonicalize --inline| filecheck %s // test function with dispatchable ops to both cores "builtin.module"() ({ diff --git a/tests/filecheck/transforms/snax_lower_mcycle.mlir b/tests/filecheck/transforms/snax_lower_mcycle.mlir index 6b829e03..adab1d2b 100644 --- a/tests/filecheck/transforms/snax_lower_mcycle.mlir +++ b/tests/filecheck/transforms/snax_lower_mcycle.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt %s -p snax-lower-mcycle --print-op-generic | filecheck %s +// RUN: snax-opt %s -p snax-lower-mcycle --print-op-generic | filecheck %s "snax.mcycle"() : () -> () //func.func @mcycle () -> () { diff --git a/tests/filecheck/transforms/snax_pin_core.mlir b/tests/filecheck/transforms/snax_pin_core.mlir index 29492c3a..f91713c5 100644 --- a/tests/filecheck/transforms/snax_pin_core.mlir +++ b/tests/filecheck/transforms/snax_pin_core.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt %s -p function-constant-pinning | filecheck %s +// RUN: snax-opt %s -p function-constant-pinning | filecheck %s builtin.module { func.func public @simple_mult(%0 : memref<64xi32>, %1 : memref<64xi32>, %2 : memref<64xi32>) { diff --git a/tests/filecheck/transforms/snax_to_func.mlir b/tests/filecheck/transforms/snax_to_func.mlir index 2b9797aa..8a3044e7 100644 --- a/tests/filecheck/transforms/snax_to_func.mlir +++ b/tests/filecheck/transforms/snax_to_func.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt --split-input-file %s -p snax-to-func --print-op-generic | filecheck %s +// RUN: snax-opt --split-input-file %s -p snax-to-func --print-op-generic | filecheck %s "builtin.module"() ({ "snax.cluster_sync_op"() : () -> () diff --git a/tests/filecheck/transforms/test-remove-memref-copy.mlir b/tests/filecheck/transforms/test-remove-memref-copy.mlir index 1774deca..36b1c06b 100644 --- a/tests/filecheck/transforms/test-remove-memref-copy.mlir +++ b/tests/filecheck/transforms/test-remove-memref-copy.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt %s -p test-remove-memref-copy | filecheck %s +// RUN: snax-opt %s -p test-remove-memref-copy | filecheck %s builtin.module { %0 = "test.op"() : () -> memref<64xi32, "L3"> diff --git a/tests/filecheck/transforms/test/debug-to-func.mlir b/tests/filecheck/transforms/test/debug-to-func.mlir index 047346c0..c1164e5a 100644 --- a/tests/filecheck/transforms/test/debug-to-func.mlir +++ b/tests/filecheck/transforms/test/debug-to-func.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt %s -p test-debug-to-func | filecheck %s +// RUN: snax-opt %s -p test-debug-to-func | filecheck %s builtin.module { func.func public @streamer_add(%arg0 : memref, %arg1 : memref, %arg2 : memref) { diff --git a/tests/filecheck/transforms/test/insert-debugs.mlir b/tests/filecheck/transforms/test/insert-debugs.mlir index 4708e051..148e7e08 100644 --- a/tests/filecheck/transforms/test/insert-debugs.mlir +++ b/tests/filecheck/transforms/test/insert-debugs.mlir @@ -1,4 +1,4 @@ -// RUN: ./compiler/snax-opt %s -p test-insert-debugs | filecheck %s +// RUN: snax-opt %s -p test-insert-debugs | filecheck %s func.func public @streamer_add(%arg0 : memref, %arg1 : memref, %arg2 : memref) { linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%arg0, %arg1 : memref, memref) outs(%arg2 : memref) { diff --git a/util/tracing/snax_event_generator.py b/util/tracing/snax_event_generator.py index 8fb0c770..429c1eba 100644 --- a/util/tracing/snax_event_generator.py +++ b/util/tracing/snax_event_generator.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field -from compiler.dialects.accfg import AcceleratorOp +from snaxc.dialects.accfg import AcceleratorOp from util.tracing.annotation import EventGenerator from util.tracing.event import DurationEvent from util.tracing.state import CSRInstruction, TraceState diff --git a/util/tracing/trace_to_perfetto.py b/util/tracing/trace_to_perfetto.py index e823e3bd..7aa3906b 100644 --- a/util/tracing/trace_to_perfetto.py +++ b/util/tracing/trace_to_perfetto.py @@ -5,8 +5,8 @@ import typing from concurrent.futures import ProcessPoolExecutor -from compiler.accelerators.registry import AcceleratorRegistry -from compiler.accelerators.snax import SNAXAccelerator +from snaxc.accelerators.registry import AcceleratorRegistry +from snaxc.accelerators.snax import SNAXAccelerator from util.tracing.annotation import ( BarrierEventGenerator, DMAEventGenerator,