From e44cd87a11d30ecc0571f7a8b9461219b489f21a Mon Sep 17 00:00:00 2001
From: emmau678
Date: Wed, 5 Feb 2025 13:29:52 +0000
Subject: [PATCH 01/23] fix: change forward block refs to forward ssa refs
---
tests/filecheck/parser-printer/graph_region.mlir | 10 ++++++++++
xdsl/parser/core.py | 2 +-
2 files changed, 11 insertions(+), 1 deletion(-)
diff --git a/tests/filecheck/parser-printer/graph_region.mlir b/tests/filecheck/parser-printer/graph_region.mlir
index b6ef50a77c..59f102121b 100644
--- a/tests/filecheck/parser-printer/graph_region.mlir
+++ b/tests/filecheck/parser-printer/graph_region.mlir
@@ -53,6 +53,16 @@ builtin.module {
// -----
+// A graph region that refers to values that are not defined in the module.
+
+builtin.module {
+ %0 = "test.termop"(%1, %2) : (i32, i32) -> i32
+}
+
+// CHECK: values %1, %2 were used but not defined
+
+// -----
+
// A forward value used with a wrong index
builtin.module {
diff --git a/xdsl/parser/core.py b/xdsl/parser/core.py
index 7d8745bbe0..d49e855e57 100644
--- a/xdsl/parser/core.py
+++ b/xdsl/parser/core.py
@@ -145,7 +145,7 @@ def parse_module(self, allow_implicit_module: bool = True) -> ModuleOp:
value_names = ", ".join(
"%" + name for name in self.forward_ssa_references.keys()
)
- if len(self.forward_block_references.keys()) > 1:
+ if len(self.forward_ssa_references.keys()) > 1:
self.raise_error(f"values {value_names} were used but not defined")
else:
self.raise_error(f"value {value_names} was used but not defined")
From b86c47eda4e8e6d5a08cc2a1282aee1338db5954 Mon Sep 17 00:00:00 2001
From: Alex Rice
Date: Tue, 4 Feb 2025 16:56:36 +0000
Subject: [PATCH 02/23] dialects: (builtin) remove AnyMemRefTypeConstr (#3832)
Can be replaced with `MemRefType.constr()`
---
tests/dialects/test_bufferization.py | 5 ++---
tests/test_traits.py | 5 ++---
xdsl/dialects/bufferization.py | 9 ++++-----
xdsl/dialects/builtin.py | 3 ---
xdsl/dialects/csl/csl_stencil.py | 19 +++++++++----------
xdsl/dialects/memref_stream.py | 4 ++--
xdsl/dialects/stencil.py | 4 ++--
.../convert_stencil_to_csl_stencil.py | 4 ++--
xdsl/transforms/csl_stencil_to_csl_wrapper.py | 5 ++---
9 files changed, 25 insertions(+), 33 deletions(-)
diff --git a/tests/dialects/test_bufferization.py b/tests/dialects/test_bufferization.py
index ab27a9dff0..1dd809c0d5 100644
--- a/tests/dialects/test_bufferization.py
+++ b/tests/dialects/test_bufferization.py
@@ -9,7 +9,6 @@
ToTensorOp,
)
from xdsl.dialects.builtin import (
- AnyMemRefTypeConstr,
AnyUnrankedMemrefTypeConstr,
IndexType,
IntegerType,
@@ -34,7 +33,7 @@
def test_tensor_from_memref_inference():
- constr = TensorFromMemrefConstraint(AnyMemRefTypeConstr)
+ constr = TensorFromMemrefConstraint(MemRefType.constr())
assert not constr.can_infer(set())
constr2 = TensorFromMemrefConstraint(
@@ -53,7 +52,7 @@ def test_tensor_from_memref_inference():
@irdl_op_definition
class TensorFromMemrefOp(IRDLOperation):
name = "test.tensor_from_memref"
- T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)
+ T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)
in_tensor = operand_def(
TensorFromMemrefConstraint(
diff --git a/tests/test_traits.py b/tests/test_traits.py
index f06fbcb98e..f8d06f55c8 100644
--- a/tests/test_traits.py
+++ b/tests/test_traits.py
@@ -15,7 +15,6 @@
from xdsl.dialects.builtin import (
DYNAMIC_INDEX,
AnyIntegerAttr,
- AnyMemRefTypeConstr,
AnyTensorTypeConstr,
AnyUnrankedMemrefTypeConstr,
AnyUnrankedTensorTypeConstr,
@@ -596,14 +595,14 @@ class SameOperandsAndResultTypeOp(IRDLOperation):
name = "test.same_operand_and_result_type"
ops = var_operand_def(
- AnyMemRefTypeConstr
+ MemRefType.constr()
| AnyUnrankedMemrefTypeConstr
| AnyUnrankedTensorTypeConstr
| AnyTensorTypeConstr
)
res = var_result_def(
- AnyMemRefTypeConstr
+ MemRefType.constr()
| AnyUnrankedMemrefTypeConstr
| AnyUnrankedTensorTypeConstr
| AnyTensorTypeConstr
diff --git a/xdsl/dialects/bufferization.py b/xdsl/dialects/bufferization.py
index 4efb28f460..5ff58c522a 100644
--- a/xdsl/dialects/bufferization.py
+++ b/xdsl/dialects/bufferization.py
@@ -3,7 +3,6 @@
from typing import Any, ClassVar
from xdsl.dialects.builtin import (
- AnyMemRefTypeConstr,
AnyTensorTypeConstr,
AnyUnrankedMemrefTypeConstr,
AnyUnrankedTensorTypeConstr,
@@ -140,7 +139,7 @@ def __init__(
class CloneOp(IRDLOperation):
name = "bufferization.clone"
- T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)
+ T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)
input = operand_def(T)
output = result_def(T)
@@ -156,7 +155,7 @@ def __init__(self, input: SSAValue | Operation):
class ToTensorOp(IRDLOperation):
name = "bufferization.to_tensor"
- T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)
+ T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)
memref = operand_def(T)
tensor = result_def(TensorFromMemrefConstraint(T))
@@ -196,7 +195,7 @@ def __init__(
class ToMemrefOp(IRDLOperation):
name = "bufferization.to_memref"
- T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)
+ T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)
tensor = operand_def(TensorFromMemrefConstraint(T))
memref = result_def(T)
@@ -209,7 +208,7 @@ class ToMemrefOp(IRDLOperation):
class MaterializeInDestinationOp(IRDLOperation):
name = "bufferization.materialize_in_destination"
- T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)
+ T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)
source = operand_def(TensorFromMemrefConstraint(T))
dest = operand_def(T | TensorFromMemrefConstraint(T))
result = opt_result_def(TensorFromMemrefConstraint(T))
diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py
index 49ad1bec6f..efefb47015 100644
--- a/xdsl/dialects/builtin.py
+++ b/xdsl/dialects/builtin.py
@@ -1975,9 +1975,6 @@ def constr(
)
-AnyMemRefTypeConstr = BaseAttr[MemRefType[Attribute]](MemRefType)
-
-
@dataclass(frozen=True, init=False)
class TensorOrMemrefOf(
GenericAttrConstraint[TensorType[AttributeCovT] | MemRefType[AttributeCovT]]
diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py
index 24b2340fd9..06d7e19bf8 100644
--- a/xdsl/dialects/csl/csl_stencil.py
+++ b/xdsl/dialects/csl/csl_stencil.py
@@ -6,7 +6,6 @@
from xdsl.dialects.builtin import (
AnyFloat,
AnyIntegerAttr,
- AnyMemRefTypeConstr,
AnyTensorTypeConstr,
Float16Type,
Float32Type,
@@ -132,7 +131,7 @@ class PrefetchOp(IRDLOperation):
name = "csl_stencil.prefetch"
input_stencil = operand_def(
- stencil.StencilTypeConstr | AnyMemRefTypeConstr | AnyTensorTypeConstr
+ stencil.StencilTypeConstr | MemRefType.constr() | AnyTensorTypeConstr
)
swaps = prop_def(builtin.ArrayAttr[ExchangeDeclarationAttr])
@@ -141,7 +140,7 @@ class PrefetchOp(IRDLOperation):
num_chunks = prop_def(AnyIntegerAttr)
- result = result_def(AnyMemRefTypeConstr | AnyTensorTypeConstr)
+ result = result_def(MemRefType.constr() | AnyTensorTypeConstr)
def __init__(
self,
@@ -227,13 +226,13 @@ class ApplyOp(IRDLOperation):
name = "csl_stencil.apply"
- field = operand_def(stencil.StencilTypeConstr | AnyMemRefTypeConstr)
+ field = operand_def(stencil.StencilTypeConstr | MemRefType.constr())
- accumulator = operand_def(AnyTensorTypeConstr | AnyMemRefTypeConstr)
+ accumulator = operand_def(AnyTensorTypeConstr | MemRefType.constr())
args_rchunk = var_operand_def(Attribute)
args_dexchng = var_operand_def(Attribute)
- dest = var_operand_def(stencil.FieldTypeConstr | AnyMemRefTypeConstr)
+ dest = var_operand_def(stencil.FieldTypeConstr | MemRefType.constr())
receive_chunk = region_def()
done_exchange = region_def()
@@ -364,7 +363,7 @@ def verify_(self) -> None:
# typecheck required (only) block arguments
assert isattr(
self.accumulator.type,
- AnyTensorTypeConstr | AnyMemRefTypeConstr,
+ AnyTensorTypeConstr | MemRefType.constr(),
)
chunk_region_req_types = [
type(self.accumulator.type)(
@@ -460,11 +459,11 @@ class AccessOp(IRDLOperation):
name = "csl_stencil.access"
op = operand_def(
- AnyMemRefTypeConstr | stencil.StencilTypeConstr | AnyTensorTypeConstr
+ MemRefType.constr() | stencil.StencilTypeConstr | AnyTensorTypeConstr
)
offset = prop_def(stencil.IndexAttr)
offset_mapping = opt_prop_def(stencil.IndexAttr)
- result = result_def(AnyTensorTypeConstr | AnyMemRefTypeConstr)
+ result = result_def(AnyTensorTypeConstr | MemRefType.constr())
traits = traits_def(HasAncestor(stencil.ApplyOp, ApplyOp), Pure())
@@ -582,7 +581,7 @@ def verify_(self) -> None:
f"{type(self)} access to own data requires type stencil.StencilType or memref.MemRefType but found {self.op.type}"
)
else:
- if not isattr(self.op.type, AnyTensorTypeConstr | AnyMemRefTypeConstr):
+ if not isattr(self.op.type, AnyTensorTypeConstr | MemRefType.constr()):
raise VerifyException(
f"{type(self)} access to neighbor data requires type memref.MemRefType or TensorType but found {self.op.type}"
)
diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py
index 0f8caaedc0..b8742b5905 100644
--- a/xdsl/dialects/memref_stream.py
+++ b/xdsl/dialects/memref_stream.py
@@ -17,13 +17,13 @@
from xdsl.dialects import memref
from xdsl.dialects.builtin import (
AffineMapAttr,
- AnyMemRefTypeConstr,
ArrayAttr,
ContainerType,
IndexType,
IntAttr,
IntegerAttr,
IntegerType,
+ MemRefType,
StringAttr,
)
from xdsl.dialects.utils import AbstractYieldOperation
@@ -463,7 +463,7 @@ class GenericOp(IRDLOperation):
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be read.
"""
- outputs = var_operand_def(AnyMemRefTypeConstr | WritableStreamType.constr())
+ outputs = var_operand_def(MemRefType.constr() | WritableStreamType.constr())
"""
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be written
diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py
index e9864d33d4..77f1a3b61d 100644
--- a/xdsl/dialects/stencil.py
+++ b/xdsl/dialects/stencil.py
@@ -9,11 +9,11 @@
from xdsl.dialects import builtin, memref
from xdsl.dialects.builtin import (
- AnyMemRefTypeConstr,
ArrayAttr,
IndexType,
IntAttr,
IntegerAttr,
+ MemRefType,
TensorType,
)
from xdsl.ir import (
@@ -910,7 +910,7 @@ class ExternalLoadOp(IRDLOperation):
name = "stencil.external_load"
field = operand_def(Attribute)
- result = result_def(base(FieldType[Attribute]) | AnyMemRefTypeConstr)
+ result = result_def(base(FieldType[Attribute]) | MemRefType.constr())
assembly_format = (
"$field attr-dict-with-keyword `:` type($field) `->` type($result)"
diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py
index 85a8f938f8..df8c739c04 100644
--- a/xdsl/transforms/convert_stencil_to_csl_stencil.py
+++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py
@@ -7,12 +7,12 @@
from xdsl.dialects import arith, builtin, memref, stencil, tensor, varith
from xdsl.dialects.builtin import (
AnyFloatAttr,
- AnyMemRefTypeConstr,
AnyTensorType,
DenseIntOrFPElementsAttr,
IndexType,
IntegerAttr,
IntegerType,
+ MemRefType,
ModuleOp,
TensorType,
)
@@ -177,7 +177,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):
assert isattr(
op.input_stencil.type,
- AnyMemRefTypeConstr | stencil.StencilTypeConstr,
+ MemRefType.constr() | stencil.StencilTypeConstr,
)
assert isa(
t_type := op.input_stencil.type.get_element_type(), TensorType[Attribute]
diff --git a/xdsl/transforms/csl_stencil_to_csl_wrapper.py b/xdsl/transforms/csl_stencil_to_csl_wrapper.py
index 501cf9fdd3..065f7e01da 100644
--- a/xdsl/transforms/csl_stencil_to_csl_wrapper.py
+++ b/xdsl/transforms/csl_stencil_to_csl_wrapper.py
@@ -5,7 +5,6 @@
from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, func, llvm, memref, stencil
from xdsl.dialects.builtin import (
- AnyMemRefTypeConstr,
AnyTensorTypeConstr,
ArrayAttr,
DictionaryAttr,
@@ -113,7 +112,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
stencil.StencilTypeConstr,
) and isattr(
el_type := field_t.element_type,
- AnyTensorTypeConstr | AnyMemRefTypeConstr,
+ AnyTensorTypeConstr | MemRefType.constr(),
):
# unbufferized csl_stencil
z_dim = max(z_dim, el_type.get_shape()[-1])
@@ -124,7 +123,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
num_chunks = max(num_chunks, apply_op.num_chunks.value.data)
if isattr(
buf_t := apply_op.receive_chunk.block.args[0].type,
- AnyTensorTypeConstr | AnyMemRefTypeConstr,
+ AnyTensorTypeConstr | MemRefType.constr(),
):
chunk_size = max(chunk_size, buf_t.get_shape()[-1])
From 0bc26167cb5aaaf28b62d15fd990b0107693e721 Mon Sep 17 00:00:00 2001
From: Chris Vasiladiotis
Date: Tue, 4 Feb 2025 18:52:04 +0000
Subject: [PATCH 03/23] core: Disallow duplicate keys in attribute dictionaries
(#3830)
This PR:
- Disallows duplicate keys in properties and attribute dictionaries
- Tests of the above
Bumped into this by accident and the tests are similar to
https://github.com/llvm/llvm-project/blob/46b1543dc04970719caab0d4f9f65699fea6adbc/mlir/test/IR/invalid-builtin-attributes.mlir#L485
I don't venture often in the parser, so any feedback is welcome.
---
.../duplicate_attribute_keys.mlir | 14 +++++++++++++
xdsl/parser/attribute_parser.py | 12 +++++++++++
xdsl/parser/core.py | 20 ++++++++++++-------
3 files changed, 39 insertions(+), 7 deletions(-)
create mode 100644 tests/filecheck/parser-printer/duplicate_attribute_keys.mlir
diff --git a/tests/filecheck/parser-printer/duplicate_attribute_keys.mlir b/tests/filecheck/parser-printer/duplicate_attribute_keys.mlir
new file mode 100644
index 0000000000..68839ceb76
--- /dev/null
+++ b/tests/filecheck/parser-printer/duplicate_attribute_keys.mlir
@@ -0,0 +1,14 @@
+// RUN: xdsl-opt %s --parsing-diagnostics --split-input-file | filecheck %s
+
+// CHECK: Duplicate key 'key1' in dictionary attribute
+"test.op"() {a = {key1, key1}} : () -> ()
+
+// -----
+
+// CHECK: Duplicate key 'key1' in dictionary attribute
+"test.op"() {key1, key1} : () -> ()
+
+// -----
+
+// CHECK: Duplicate key 'key1' in properties dictionary
+"test.op"() <{key1, key1}> : () -> ()
diff --git a/xdsl/parser/attribute_parser.py b/xdsl/parser/attribute_parser.py
index f4cc20df73..ba6e02a78c 100644
--- a/xdsl/parser/attribute_parser.py
+++ b/xdsl/parser/attribute_parser.py
@@ -195,12 +195,24 @@ def _parse_attribute_entry(self) -> tuple[str, Attribute]:
return name, self.parse_attribute()
+ def _find_duplicated_key(self, attrs: list[tuple[str, Attribute]]) -> str | None:
+ seen_keys: set[str] = set()
+ for key, _ in attrs:
+ if key in seen_keys:
+ return key
+ seen_keys.add(key)
+ return None
+
def parse_optional_dictionary_attr_dict(self) -> dict[str, Attribute]:
attrs = self.parse_optional_comma_separated_list(
self.Delimiter.BRACES, self._parse_attribute_entry
)
if attrs is None:
return dict()
+
+ if (key := self._find_duplicated_key(attrs)) is not None:
+ self.raise_error(f"Duplicate key '{key}' in dictionary attribute")
+
return dict(attrs)
def _parse_dialect_type_or_attribute_body(
diff --git a/xdsl/parser/core.py b/xdsl/parser/core.py
index d49e855e57..c9b15e14db 100644
--- a/xdsl/parser/core.py
+++ b/xdsl/parser/core.py
@@ -646,6 +646,7 @@ def parse_optional_attr_dict_with_reserved_attr_names(
"""
begin_pos = self.lexer.pos
attr = self._parse_builtin_dict_attr()
+
for reserved_name in reserved_attr_names:
if reserved_name in attr.data:
self.raise_error(
@@ -792,13 +793,18 @@ def parse_optional_properties_dict(self) -> dict[str, Attribute]:
dictionary-attribute ::= `{` (attribute-entry (`,` attribute-entry)*)? `}`
properties ::= `<` dictionary-attribute `>`
"""
- if self.parse_optional_punctuation("<") is not None:
- entries = self.parse_comma_separated_list(
- self.Delimiter.BRACES, self._parse_attribute_entry
- )
- self.parse_punctuation(">")
- return dict(entries)
- return dict()
+ if self.parse_optional_punctuation("<") is None:
+ return dict()
+
+ entries = self.parse_comma_separated_list(
+ self.Delimiter.BRACES, self._parse_attribute_entry
+ )
+ self.parse_punctuation(">")
+
+ if (key := self._find_duplicated_key(entries)) is not None:
+ self.raise_error(f"Duplicate key '{key}' in properties dictionary")
+
+ return dict(entries)
def resolve_operands(
self,
From 324792fe505ca011738897bb49eee3cd795d3230 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Wed, 5 Feb 2025 08:33:36 +0000
Subject: [PATCH 04/23] pip prod(deps): bump marimo from 0.10.19 to 0.11.0
(#3834)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Bumps [marimo](https://github.com/marimo-team/marimo) from 0.10.19 to
0.11.0.
Release notes
Sourced from marimo's
releases.
0.11.0
Highlights ⭐
- SQL Engine Support. Connect to various databases
like postgresql, mysql, snowflake and more, using your preferred SQL
engine.
This release adds support for using multiple SQL connection
libraries, such as SQLModel and SQLAlchemy. You can now define SQL
connections in your code like:
import sqlalchemy
import sqlmodel
import duckdb
Create an in-memory SQLite database with SQLAlchemy
sqlite_engine =
sqlalchemy.create_engine("sqlite:///:memory:")
Create a Postgres database with SQLModel
postgres_engine =
sqlmodel.create_engine("postgresql://username:password@server:port/database")
Create a DuckDB connection
duckdb_conn = duckdb.connect("file.db")
And then select which connection to use in the SQL cell.
![image](https://github.com/user-attachments/assets/8ba7f094-aefd-4cfc-95c0-19da5ee3378e)
h/t @Light2Dark
- Markdown file-format improvements - Markdown
notebooks (i.e.
marimo edit notebook.md
) has an improved
syntax format: python {.marimo}
. You can also use SQL cells
in the markdown file-format, using sql {.marimo}
. To learn
more, run marimo tutorial markdown-format
h/t @dmadisetti
-
Markdown syntax - Added support for details,
admonitions, and emojis in markdown
-
Performance & Reliability - Lots of bug fixes
for better resource cleanup and memory management, as well as disabling
features not used in run-mode.
What's Changed
... (truncated)
Commits
3b84b1f
release: 0.11.0
2dc721c
fix: sql dropdown when initialized (#3687)
33dcefb
fix: checks whether latestEngineSelected is in list of engines (#3685)
79da93a
smoke test: docstring_to_markdown (#3686)
3426a56
chore(deps): update dependency vitest to v3.0.5 [security] (#3684)
2ee4419
Add openai-whisper to module name list (#3681)
16a1d6a
improv: update the sql dropdown to a better design and show disconnected
engi...
90e3987
[pre-commit.ci] pre-commit autoupdate (#3678)
c583b6d
fix: filtering datasets based on variables (#3676)
4c6c6d1
improvement: show engine's tables in the datasources panel (#3665)
- Additional commits viewable in compare
view
[![Dependabot compatibility
score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=marimo&package-manager=pip&previous-version=0.10.19&new-version=0.11.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
---------
Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: compor
---
pyproject.toml | 2 +-
uv.lock | 8 ++++----
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 54acf2bede..413bef5a79 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,7 @@ dev = [
"nbval<0.12",
"filecheck==1.0.1",
"lit<19.0.0",
- "marimo==0.10.19",
+ "marimo==0.11.0",
"pre-commit==4.1.0",
"ruff==0.9.4",
"nbconvert>=7.7.2,<8.0.0",
diff --git a/uv.lock b/uv.lock
index e58096ffc6..e14ee9ebd7 100644
--- a/uv.lock
+++ b/uv.lock
@@ -907,7 +907,7 @@ wheels = [
[[package]]
name = "marimo"
-version = "0.10.19"
+version = "0.11.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
@@ -929,9 +929,9 @@ dependencies = [
{ name = "uvicorn" },
{ name = "websockets" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/9c/24/1010293c645f486a4f3175f89a95ea92e0a5005b368cc34738554828734e/marimo-0.10.19.tar.gz", hash = "sha256:eec812e874156852d825f50aa738cbaf2f55d8be8d36449febce443f991ccb75", size = 11870444 }
+sdist = { url = "https://files.pythonhosted.org/packages/a4/cd/066190b340063ba9002f8833ba683a1f49ef64ad0875fbea302099575b66/marimo-0.11.0.tar.gz", hash = "sha256:9c56a85202828535559e3b7b77db59f8df672166812d8bc81cb3b658df51a5f5", size = 10553579 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/ee/a1/6c6f7d17b2a7e11d0bbcca4dcfb392a9e8cea130b819f938ab027deba097/marimo-0.10.19-py3-none-any.whl", hash = "sha256:9e3f4ba8437359b7adcfa24dcf828eff99f6ed5e5a061b8f2943ed5b4509773a", size = 12175728 },
+ { url = "https://files.pythonhosted.org/packages/1f/7d/49ef32a53506762c932b6197f2df28450fe25e43a01fa7cfbd732d578f00/marimo-0.11.0-py3-none-any.whl", hash = "sha256:688652c0b69397ef40b412ecbc313a69f6ccadb1958b9ee66a1fcdfba98d0825", size = 10862632 },
]
[[package]]
@@ -2776,7 +2776,7 @@ requires-dist = [
{ name = "ipykernel", marker = "extra == 'dev'" },
{ name = "jax", marker = "extra == 'jax'", specifier = "==0.5.0" },
{ name = "lit", marker = "extra == 'dev'", specifier = "<19.0.0" },
- { name = "marimo", marker = "extra == 'dev'", specifier = "==0.10.19" },
+ { name = "marimo", marker = "extra == 'dev'", specifier = "==0.11.0" },
{ name = "mkdocs", marker = "extra == 'docs'", specifier = ">=1.6.1" },
{ name = "mkdocs-gen-files", marker = "extra == 'docs'", specifier = ">=0.5.0" },
{ name = "mkdocs-material", marker = "extra == 'docs'", specifier = ">=9.5.49" },
From 83ee4c0e8840f27626d7d0df5713645e082fbb4c Mon Sep 17 00:00:00 2001
From: Alex Rice
Date: Wed, 5 Feb 2025 09:48:01 +0000
Subject: [PATCH 05/23] dialects: (builtin/memref) rename Memref to MemRef
everywhere (#3833)
Makes the spelling consistent between `MemRefType` and
`UnrankedMemrefType`.
Fixes #3474
---
docs/Toy/toy/__main__.py | 4 +-
docs/Toy/toy/compiler.py | 4 +-
docs/Toy/toy/rewrites/lower_toy_affine.py | 6 +--
docs/marimo/__marimo__/linalg_snitch.ipynb | 4 +-
docs/marimo/linalg_snitch.py | 4 +-
tests/dialects/test_bufferization.py | 30 +++++++-------
tests/dialects/test_builtin.py | 6 +--
tests/dialects/test_memref.py | 4 +-
tests/dialects/test_mpi.py | 2 +-
tests/interpreters/test_affine_interpreter.py | 4 +-
tests/interpreters/test_memref_interpreter.py | 6 +--
.../test_memref_stream_interpreter.py | 14 +++----
tests/test_traits.py | 6 +--
xdsl/backend/csl/print_csl.py | 2 +-
.../riscv/lowering/convert_memref_to_riscv.py | 32 +++++++--------
xdsl/dialects/__init__.py | 4 +-
xdsl/dialects/affine.py | 2 +-
xdsl/dialects/bufferization.py | 32 +++++++--------
xdsl/dialects/builtin.py | 40 +++++++++----------
xdsl/dialects/gpu.py | 4 +-
xdsl/dialects/memref.py | 32 +++++++--------
xdsl/dialects/memref_stream.py | 2 +-
xdsl/dialects/mpi.py | 4 +-
xdsl/interpreters/__init__.py | 4 +-
xdsl/interpreters/memref.py | 4 +-
xdsl/interpreters/memref_stream.py | 2 +-
xdsl/interpreters/shaped_array.py | 2 +-
xdsl/parser/attribute_parser.py | 16 ++++----
xdsl/printer.py | 8 ++--
xdsl/transforms/__init__.py | 34 ++++++++--------
.../canonicalization_patterns/memref.py | 2 +-
.../convert_linalg_to_memref_stream.py | 2 +-
.../convert_memref_stream_to_loops.py | 2 +-
.../convert_memref_stream_to_snitch_stream.py | 2 +-
xdsl/transforms/convert_memref_to_ptr.py | 14 +++----
.../convert_ml_program_to_memref.py | 2 +-
xdsl/transforms/convert_ptr_to_riscv.py | 4 +-
xdsl/transforms/csl_stencil_bufferize.py | 6 +--
.../convert_stencil_to_ll_mlir.py | 24 +++++------
.../dmp/stencil_global_to_local.py | 14 +++----
.../hls_convert_stencil_to_ll_mlir.py | 12 +++---
xdsl/transforms/gpu_allocs.py | 2 +-
xdsl/transforms/loop_hoist_memref.py | 6 +--
xdsl/transforms/lower_mpi.py | 10 ++---
xdsl/transforms/memref_stream_fold_fill.py | 2 +-
.../memref_stream_generalize_fill.py | 2 +-
xdsl/transforms/memref_stream_infer_fill.py | 2 +-
xdsl/transforms/memref_stream_interleave.py | 2 +-
xdsl/transforms/memref_stream_legalize.py | 6 +--
.../memref_stream_tile_outer_loops.py | 2 +-
.../memref_stream_unnest_out_parameters.py | 2 +-
xdsl/transforms/memref_streamify.py | 2 +-
xdsl/transforms/memref_to_dsd.py | 8 ++--
.../transforms/test_lower_linalg_to_snitch.py | 24 +++++------
54 files changed, 236 insertions(+), 236 deletions(-)
diff --git a/docs/Toy/toy/__main__.py b/docs/Toy/toy/__main__.py
index 0626dedb44..6dc5125175 100644
--- a/docs/Toy/toy/__main__.py
+++ b/docs/Toy/toy/__main__.py
@@ -6,7 +6,7 @@
from xdsl.interpreters.arith import ArithFunctions
from xdsl.interpreters.builtin import BuiltinFunctions
from xdsl.interpreters.func import FuncFunctions
-from xdsl.interpreters.memref import MemrefFunctions
+from xdsl.interpreters.memref import MemRefFunctions
from xdsl.interpreters.printf import PrintfFunctions
from xdsl.interpreters.riscv_cf import RiscvCfFunctions
from xdsl.interpreters.riscv_debug import RiscvDebugFunctions
@@ -104,7 +104,7 @@ def main(path: Path, emit: str, ir: bool, print_generic: bool):
interpreter.register_implementations(AffineFunctions())
if emit in ("affine", "scf"):
interpreter.register_implementations(ArithFunctions())
- interpreter.register_implementations(MemrefFunctions())
+ interpreter.register_implementations(MemRefFunctions())
interpreter.register_implementations(PrintfFunctions())
interpreter.register_implementations(FuncFunctions())
if emit == "scf":
diff --git a/docs/Toy/toy/compiler.py b/docs/Toy/toy/compiler.py
index 4729dee117..7385a459a4 100644
--- a/docs/Toy/toy/compiler.py
+++ b/docs/Toy/toy/compiler.py
@@ -5,7 +5,7 @@
from xdsl.backend.riscv.lowering.convert_func_to_riscv_func import (
ConvertFuncToRiscvFuncPass,
)
-from xdsl.backend.riscv.lowering.convert_memref_to_riscv import ConvertMemrefToRiscvPass
+from xdsl.backend.riscv.lowering.convert_memref_to_riscv import ConvertMemRefToRiscvPass
from xdsl.backend.riscv.lowering.convert_print_format_to_riscv_debug import (
ConvertPrintFormatToRiscvDebugPass,
)
@@ -99,7 +99,7 @@ def transform(
return
ConvertFuncToRiscvFuncPass().apply(ctx, module_op)
- ConvertMemrefToRiscvPass().apply(ctx, module_op)
+ ConvertMemRefToRiscvPass().apply(ctx, module_op)
ConvertPrintFormatToRiscvDebugPass().apply(ctx, module_op)
ConvertArithToRiscvPass().apply(ctx, module_op)
ConvertScfToRiscvPass().apply(ctx, module_op)
diff --git a/docs/Toy/toy/rewrites/lower_toy_affine.py b/docs/Toy/toy/rewrites/lower_toy_affine.py
index 435ed59168..502fb95f6d 100644
--- a/docs/Toy/toy/rewrites/lower_toy_affine.py
+++ b/docs/Toy/toy/rewrites/lower_toy_affine.py
@@ -36,10 +36,10 @@
# region Helpers
-MemrefTypeF64: TypeAlias = memref.MemRefType[Float64Type]
+MemRefTypeF64: TypeAlias = memref.MemRefType[Float64Type]
-def convert_tensor_to_memref(type: toy.TensorTypeF64) -> MemrefTypeF64:
+def convert_tensor_to_memref(type: toy.TensorTypeF64) -> MemRefTypeF64:
"""
Convert the given RankedTensorType into the corresponding MemRefType.
"""
@@ -47,7 +47,7 @@ def convert_tensor_to_memref(type: toy.TensorTypeF64) -> MemrefTypeF64:
def insert_alloc_and_dealloc(
- type: MemrefTypeF64, op: Operation, rewriter: PatternRewriter
+ type: MemRefTypeF64, op: Operation, rewriter: PatternRewriter
) -> memref.AllocOp:
"""
Insert an allocation and deallocation for the given MemRefType.
diff --git a/docs/marimo/__marimo__/linalg_snitch.ipynb b/docs/marimo/__marimo__/linalg_snitch.ipynb
index b3490c91e6..7cf956880a 100644
--- a/docs/marimo/__marimo__/linalg_snitch.ipynb
+++ b/docs/marimo/__marimo__/linalg_snitch.ipynb
@@ -281,7 +281,7 @@
" [\n",
" convert_linalg_to_loops.ConvertLinalgToLoopsPass(),\n",
" convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass(),\n",
- " convert_memref_to_riscv.ConvertMemrefToRiscvPass(),\n",
+ " convert_memref_to_riscv.ConvertMemRefToRiscvPass(),\n",
" convert_arith_to_riscv.ConvertArithToRiscvPass(),\n",
" convert_scf_to_riscv_scf.ConvertScfToRiscvPass(),\n",
" reconcile_unrealized_casts.ReconcileUnrealizedCastsPass(),\n",
@@ -434,7 +434,7 @@
"\n",
"convert_linalg_to_snitch = PipelinePass(\n",
" [\n",
- " convert_linalg_to_memref_stream.ConvertLinalgToMemrefStreamPass(),\n",
+ " convert_linalg_to_memref_stream.ConvertLinalgToMemRefStreamPass(),\n",
" arith_add_fastmath.AddArithFastMathFlagsPass(),\n",
" *OPTIMISE_MEMREF_STREAM_PASSES,\n",
" *LOWER_MEMREF_STREAM_TO_SNITCH_STREAM_PASSES,\n",
diff --git a/docs/marimo/linalg_snitch.py b/docs/marimo/linalg_snitch.py
index 18d58e2f15..d02ac55cf8 100644
--- a/docs/marimo/linalg_snitch.py
+++ b/docs/marimo/linalg_snitch.py
@@ -287,7 +287,7 @@ def _(
[
convert_linalg_to_loops.ConvertLinalgToLoopsPass(),
convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass(),
- convert_memref_to_riscv.ConvertMemrefToRiscvPass(),
+ convert_memref_to_riscv.ConvertMemRefToRiscvPass(),
convert_arith_to_riscv.ConvertArithToRiscvPass(),
convert_scf_to_riscv_scf.ConvertScfToRiscvPass(),
reconcile_unrealized_casts.ReconcileUnrealizedCastsPass(),
@@ -407,7 +407,7 @@ def _(
convert_linalg_to_snitch = PipelinePass(
[
- convert_linalg_to_memref_stream.ConvertLinalgToMemrefStreamPass(),
+ convert_linalg_to_memref_stream.ConvertLinalgToMemRefStreamPass(),
arith_add_fastmath.AddArithFastMathFlagsPass(),
*OPTIMISE_MEMREF_STREAM_PASSES,
*LOWER_MEMREF_STREAM_TO_SNITCH_STREAM_PASSES,
diff --git a/tests/dialects/test_bufferization.py b/tests/dialects/test_bufferization.py
index 1dd809c0d5..a8d5ac05ab 100644
--- a/tests/dialects/test_bufferization.py
+++ b/tests/dialects/test_bufferization.py
@@ -5,17 +5,17 @@
from xdsl.dialects.bufferization import (
AllocTensorOp,
CloneOp,
- TensorFromMemrefConstraint,
+ TensorFromMemRefConstraint,
ToTensorOp,
)
from xdsl.dialects.builtin import (
- AnyUnrankedMemrefTypeConstr,
+ AnyUnrankedMemRefTypeConstr,
IndexType,
IntegerType,
MemRefType,
TensorType,
UnitAttr,
- UnrankedMemrefType,
+ UnrankedMemRefType,
UnrankedTensorType,
f64,
)
@@ -33,36 +33,36 @@
def test_tensor_from_memref_inference():
- constr = TensorFromMemrefConstraint(MemRefType.constr())
+ constr = TensorFromMemRefConstraint(MemRefType.constr())
assert not constr.can_infer(set())
- constr2 = TensorFromMemrefConstraint(
+ constr2 = TensorFromMemRefConstraint(
EqAttrConstraint(MemRefType(f64, [10, 20, 30]))
)
assert constr2.can_infer(set())
assert constr2.infer(InferenceContext()) == TensorType(f64, [10, 20, 30])
- constr3 = TensorFromMemrefConstraint(
- EqAttrConstraint(UnrankedMemrefType.from_type(f64))
+ constr3 = TensorFromMemRefConstraint(
+ EqAttrConstraint(UnrankedMemRefType.from_type(f64))
)
assert constr3.can_infer(set())
assert constr3.infer(InferenceContext()) == UnrankedTensorType(f64)
@irdl_op_definition
-class TensorFromMemrefOp(IRDLOperation):
+class TensorFromMemRefOp(IRDLOperation):
name = "test.tensor_from_memref"
- T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)
+ T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemRefTypeConstr)
in_tensor = operand_def(
- TensorFromMemrefConstraint(
+ TensorFromMemRefConstraint(
MemRefType.constr(element_type=EqAttrConstraint(IndexType()))
)
)
in_var_memref = operand_def(T)
- in_var_tensor = operand_def(TensorFromMemrefConstraint(T))
+ in_var_tensor = operand_def(TensorFromMemRefConstraint(T))
def test_tensor_from_memref_constraint():
@@ -72,16 +72,16 @@ def test_tensor_from_memref_constraint():
TensorType(IndexType(), [10, 20, 30]),
]
).res
- op1 = TensorFromMemrefOp(operands=(v_tensor, v_memref, v_tensor))
+ op1 = TensorFromMemRefOp(operands=(v_tensor, v_memref, v_tensor))
op1.verify()
[v_unranked_memref, v_unranked_tensor] = TestOp(
result_types=[
- UnrankedMemrefType.from_type(IndexType()),
+ UnrankedMemRefType.from_type(IndexType()),
UnrankedTensorType(IndexType()),
]
).res
- op2 = TensorFromMemrefOp(operands=(v_tensor, v_unranked_memref, v_unranked_tensor))
+ op2 = TensorFromMemRefOp(operands=(v_tensor, v_unranked_memref, v_unranked_tensor))
op2.verify()
@@ -131,7 +131,7 @@ def test_tensor_from_memref_constraint_failure(
]
).res
- op1 = TensorFromMemrefOp(operands=(v1, v2, v3))
+ op1 = TensorFromMemRefOp(operands=(v1, v2, v3))
with pytest.raises(VerifyException, match=error):
op1.verify()
diff --git a/tests/dialects/test_builtin.py b/tests/dialects/test_builtin.py
index 2636c9fab7..6f3be45191 100644
--- a/tests/dialects/test_builtin.py
+++ b/tests/dialects/test_builtin.py
@@ -31,7 +31,7 @@
Signedness,
StridedLayoutAttr,
SymbolRefAttr,
- TensorOrMemrefOf,
+ TensorOrMemRefOf,
TensorType,
UnrealizedConversionCastOp,
VectorBaseTypeAndRankConstraint,
@@ -658,14 +658,14 @@ def test_strides():
def test_tensor_or_memref_of_constraint_verify():
- constraint = TensorOrMemrefOf(i64)
+ constraint = TensorOrMemRefOf(i64)
constraint.verify(MemRefType(i64, [1]), ConstraintContext())
constraint.verify(TensorType(i64, [1]), ConstraintContext())
def test_tensor_or_memref_of_constraint_attribute_mismatch():
- constraint = TensorOrMemrefOf(i64)
+ constraint = TensorOrMemRefOf(i64)
with pytest.raises(
VerifyException, match=f"Expected tensor or memref type, got {i64}"
diff --git a/tests/dialects/test_memref.py b/tests/dialects/test_memref.py
index 26f1b29e4d..6a239dad79 100644
--- a/tests/dialects/test_memref.py
+++ b/tests/dialects/test_memref.py
@@ -12,7 +12,7 @@
MemRefType,
NoneAttr,
StridedLayoutAttr,
- UnrankedMemrefType,
+ UnrankedMemRefType,
i32,
i64,
)
@@ -327,7 +327,7 @@ def test_memref_cast():
i32_memref_type = MemRefType(i32, [10, 2])
memref_ssa_value = TestSSAValue(i32_memref_type)
- res_type = UnrankedMemrefType.from_type(i32)
+ res_type = UnrankedMemRefType.from_type(i32)
cast = CastOp.get(memref_ssa_value, res_type)
diff --git a/tests/dialects/test_mpi.py b/tests/dialects/test_mpi.py
index 5c6d832be1..0a30142029 100644
--- a/tests/dialects/test_mpi.py
+++ b/tests/dialects/test_mpi.py
@@ -9,7 +9,7 @@ def test_mpi_baseop():
"""
alloc0 = memref.AllocOp.get(f64, 32, [100, 14, 14])
dest = ConstantOp.from_int_and_width(1, i32)
- unwrap = mpi.UnwrapMemrefOp(alloc0)
+ unwrap = mpi.UnwrapMemRefOp(alloc0)
req_vec = mpi.AllocateTypeOp(mpi.RequestType, dest)
req_obj = mpi.VectorGetOp(req_vec, dest)
tag = ConstantOp.from_int_and_width(1, i32)
diff --git a/tests/interpreters/test_affine_interpreter.py b/tests/interpreters/test_affine_interpreter.py
index 24ee180d66..107b27a25c 100644
--- a/tests/interpreters/test_affine_interpreter.py
+++ b/tests/interpreters/test_affine_interpreter.py
@@ -7,7 +7,7 @@
from xdsl.interpreters.affine import AffineFunctions
from xdsl.interpreters.arith import ArithFunctions
from xdsl.interpreters.func import FuncFunctions
-from xdsl.interpreters.memref import MemrefFunctions
+from xdsl.interpreters.memref import MemRefFunctions
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr
from xdsl.ir.affine import AffineMap
@@ -75,7 +75,7 @@ def test_functions():
interpreter = Interpreter(module_op)
interpreter.register_implementations(ArithFunctions())
- interpreter.register_implementations(MemrefFunctions())
+ interpreter.register_implementations(MemRefFunctions())
interpreter.register_implementations(AffineFunctions())
interpreter.register_implementations(FuncFunctions())
diff --git a/tests/interpreters/test_memref_interpreter.py b/tests/interpreters/test_memref_interpreter.py
index b688c32229..79e37b6ce2 100644
--- a/tests/interpreters/test_memref_interpreter.py
+++ b/tests/interpreters/test_memref_interpreter.py
@@ -10,13 +10,13 @@
)
from xdsl.interpreter import Interpreter
from xdsl.interpreters.arith import ArithFunctions
-from xdsl.interpreters.memref import MemrefFunctions
+from xdsl.interpreters.memref import MemRefFunctions
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr
interpreter = Interpreter(ModuleOp([]), index_bitwidth=32)
interpreter.register_implementations(ArithFunctions())
-interpreter.register_implementations(MemrefFunctions())
+interpreter.register_implementations(MemRefFunctions())
index = IndexType()
@@ -63,7 +63,7 @@ def test_memref_get_global():
fetch = memref.GetGlobalOp("my_global", memref_type)
interpreter = Interpreter(module, index_bitwidth=32)
- interpreter.register_implementations(MemrefFunctions())
+ interpreter.register_implementations(MemRefFunctions())
(result,) = interpreter.run_op(fetch, ())
assert result == ShapedArray(
diff --git a/tests/interpreters/test_memref_stream_interpreter.py b/tests/interpreters/test_memref_stream_interpreter.py
index 0c3340d442..e7a3dd4468 100644
--- a/tests/interpreters/test_memref_stream_interpreter.py
+++ b/tests/interpreters/test_memref_stream_interpreter.py
@@ -15,7 +15,7 @@
)
from xdsl.interpreter import Interpreter
from xdsl.interpreters.arith import ArithFunctions
-from xdsl.interpreters.memref_stream import MemrefStreamFunctions
+from xdsl.interpreters.memref_stream import MemRefStreamFunctions
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr
from xdsl.ir import Block, Region
@@ -31,7 +31,7 @@ def index(value: int) -> IntegerAttr[IndexType]:
def test_memref_stream_generic():
interpreter = Interpreter(ModuleOp([]))
- interpreter.register_implementations(MemrefStreamFunctions())
+ interpreter.register_implementations(MemRefStreamFunctions())
interpreter.register_implementations(ArithFunctions())
op = memref_stream.GenericOp(
@@ -85,7 +85,7 @@ def test_memref_stream_generic():
def test_memref_stream_generic_scalar():
interpreter = Interpreter(ModuleOp([]))
- interpreter.register_implementations(MemrefStreamFunctions())
+ interpreter.register_implementations(MemRefStreamFunctions())
interpreter.register_implementations(ArithFunctions())
op = memref_stream.GenericOp(
@@ -139,7 +139,7 @@ def test_memref_stream_generic_scalar():
def test_memref_stream_generic_reduction():
interpreter = Interpreter(ModuleOp([]))
- interpreter.register_implementations(MemrefStreamFunctions())
+ interpreter.register_implementations(MemRefStreamFunctions())
interpreter.register_implementations(ArithFunctions())
op = memref_stream.GenericOp(
@@ -180,7 +180,7 @@ def test_memref_stream_generic_reduction():
def test_memref_stream_generic_imperfect_nesting():
interpreter = Interpreter(ModuleOp([]))
- interpreter.register_implementations(MemrefStreamFunctions())
+ interpreter.register_implementations(MemRefStreamFunctions())
interpreter.register_implementations(ArithFunctions())
f32 = Float32Type()
@@ -231,7 +231,7 @@ def test_memref_stream_generic_imperfect_nesting():
def test_memref_stream_generic_reduction_with_initial_value():
interpreter = Interpreter(ModuleOp([]))
- interpreter.register_implementations(MemrefStreamFunctions())
+ interpreter.register_implementations(MemRefStreamFunctions())
interpreter.register_implementations(ArithFunctions())
f32 = Float32Type()
@@ -282,7 +282,7 @@ def test_memref_stream_generic_reduction_with_initial_value():
def test_memref_stream_interleaved_reduction_with_initial_value():
interpreter = Interpreter(ModuleOp([]))
- interpreter.register_implementations(MemrefStreamFunctions())
+ interpreter.register_implementations(MemRefStreamFunctions())
interpreter.register_implementations(ArithFunctions())
f32 = Float32Type()
diff --git a/tests/test_traits.py b/tests/test_traits.py
index f8d06f55c8..7121a1470b 100644
--- a/tests/test_traits.py
+++ b/tests/test_traits.py
@@ -16,7 +16,7 @@
DYNAMIC_INDEX,
AnyIntegerAttr,
AnyTensorTypeConstr,
- AnyUnrankedMemrefTypeConstr,
+ AnyUnrankedMemRefTypeConstr,
AnyUnrankedTensorTypeConstr,
IntegerAttr,
IntegerType,
@@ -596,14 +596,14 @@ class SameOperandsAndResultTypeOp(IRDLOperation):
ops = var_operand_def(
MemRefType.constr()
- | AnyUnrankedMemrefTypeConstr
+ | AnyUnrankedMemRefTypeConstr
| AnyUnrankedTensorTypeConstr
| AnyTensorTypeConstr
)
res = var_result_def(
MemRefType.constr()
- | AnyUnrankedMemrefTypeConstr
+ | AnyUnrankedMemRefTypeConstr
| AnyUnrankedTensorTypeConstr
| AnyTensorTypeConstr
)
diff --git a/xdsl/backend/csl/print_csl.py b/xdsl/backend/csl/print_csl.py
index 82929ae758..ec1a395232 100644
--- a/xdsl/backend/csl/print_csl.py
+++ b/xdsl/backend/csl/print_csl.py
@@ -253,7 +253,7 @@ def _memref_global_init(self, init: Attribute, type: str) -> str:
case DenseIntOrFPElementsAttr():
data = init.get_attrs()
assert len(data) == 1, (
- f"Memref global initialiser has to have 1 value, got {len(data)}"
+ f"MemRef global initialiser has to have 1 value, got {len(data)}"
)
return f" = @constants({type}, {self.attribute_value_to_str(data[0])})"
case other:
diff --git a/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py b/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py
index d98aa08cf5..30e3d966c7 100644
--- a/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py
+++ b/xdsl/backend/riscv/lowering/convert_memref_to_riscv.py
@@ -38,7 +38,7 @@
from xdsl.utils.exceptions import DiagnosticException
-class ConvertMemrefAllocOp(RewritePattern):
+class ConvertMemRefAllocOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.AllocOp, rewriter: PatternRewriter) -> None:
assert isinstance(op_memref_type := op.memref.type, memref.MemRefType)
@@ -61,7 +61,7 @@ def match_and_rewrite(self, op: memref.AllocOp, rewriter: PatternRewriter) -> No
)
-class ConvertMemrefDeallocOp(RewritePattern):
+class ConvertMemRefDeallocOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
self, op: memref.DeallocOp, rewriter: PatternRewriter
@@ -174,7 +174,7 @@ def get_strided_pointer(
return ops, ptr.rd
-class ConvertMemrefStoreOp(RewritePattern):
+class ConvertMemRefStoreOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.StoreOp, rewriter: PatternRewriter):
assert isinstance(op_memref_type := op.memref.type, memref.MemRefType)
@@ -217,7 +217,7 @@ def match_and_rewrite(self, op: memref.StoreOp, rewriter: PatternRewriter):
rewriter.replace_matched_op(new_op)
-class ConvertMemrefLoadOp(RewritePattern):
+class ConvertMemRefLoadOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter):
assert isinstance(op_memref_type := op.memref.type, memref.MemRefType), (
@@ -263,7 +263,7 @@ def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter):
)
-class ConvertMemrefGlobalOp(RewritePattern):
+class ConvertMemRefGlobalOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.GlobalOp, rewriter: PatternRewriter):
initial_value = op.initial_value
@@ -311,7 +311,7 @@ def match_and_rewrite(self, op: memref.GlobalOp, rewriter: PatternRewriter):
rewriter.replace_matched_op(section)
-class ConvertMemrefGetGlobalOp(RewritePattern):
+class ConvertMemRefGetGlobalOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.GetGlobalOp, rewriter: PatternRewriter):
rewriter.replace_matched_op(
@@ -322,7 +322,7 @@ def match_and_rewrite(self, op: memref.GetGlobalOp, rewriter: PatternRewriter):
)
-class ConvertMemrefSubviewOp(RewritePattern):
+class ConvertMemRefSubviewOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter):
# Assumes that the operation is valid, meaning that the subview is indeed a
@@ -425,25 +425,25 @@ def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter):
)
-class ConvertMemrefToRiscvPass(ModulePass):
+class ConvertMemRefToRiscvPass(ModulePass):
name = "convert-memref-to-riscv"
def apply(self, ctx: MLContext, op: ModuleOp) -> None:
- contains_malloc = PatternRewriteWalker(ConvertMemrefAllocOp()).rewrite_module(
+ contains_malloc = PatternRewriteWalker(ConvertMemRefAllocOp()).rewrite_module(
op
)
contains_dealloc = PatternRewriteWalker(
- ConvertMemrefDeallocOp()
+ ConvertMemRefDeallocOp()
).rewrite_module(op)
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
- ConvertMemrefDeallocOp(),
- ConvertMemrefStoreOp(),
- ConvertMemrefLoadOp(),
- ConvertMemrefGlobalOp(),
- ConvertMemrefGetGlobalOp(),
- ConvertMemrefSubviewOp(),
+ ConvertMemRefDeallocOp(),
+ ConvertMemRefStoreOp(),
+ ConvertMemRefLoadOp(),
+ ConvertMemRefGlobalOp(),
+ ConvertMemRefGetGlobalOp(),
+ ConvertMemRefSubviewOp(),
]
)
).rewrite_module(op)
diff --git a/xdsl/dialects/__init__.py b/xdsl/dialects/__init__.py
index d96749e297..651f20e6b3 100644
--- a/xdsl/dialects/__init__.py
+++ b/xdsl/dialects/__init__.py
@@ -154,9 +154,9 @@ def get_memref():
return MemRef
def get_memref_stream():
- from xdsl.dialects.memref_stream import MemrefStream
+ from xdsl.dialects.memref_stream import MemRefStream
- return MemrefStream
+ return MemRefStream
def get_ml_program():
from xdsl.dialects.ml_program import MLProgram
diff --git a/xdsl/dialects/affine.py b/xdsl/dialects/affine.py
index 48c5319f17..7608faf645 100644
--- a/xdsl/dialects/affine.py
+++ b/xdsl/dialects/affine.py
@@ -279,7 +279,7 @@ def __init__(
# for zero-dimensional memrefs.
if not isinstance(memref_type := memref.type, MemRefType):
raise ValueError(
- "affine.store memref operand must be of type MemrefType"
+ "affine.store memref operand must be of type MemRefType"
)
rank = memref_type.get_num_dims()
map = AffineMapAttr(AffineMap.identity(rank))
diff --git a/xdsl/dialects/bufferization.py b/xdsl/dialects/bufferization.py
index 5ff58c522a..303cb76e84 100644
--- a/xdsl/dialects/bufferization.py
+++ b/xdsl/dialects/bufferization.py
@@ -4,7 +4,7 @@
from xdsl.dialects.builtin import (
AnyTensorTypeConstr,
- AnyUnrankedMemrefTypeConstr,
+ AnyUnrankedMemRefTypeConstr,
AnyUnrankedTensorTypeConstr,
ContainerType,
IndexType,
@@ -12,7 +12,7 @@
ShapedType,
TensorType,
UnitAttr,
- UnrankedMemrefType,
+ UnrankedMemRefType,
UnrankedTensorType,
)
from xdsl.ir import Attribute, Dialect, Operation, SSAValue
@@ -36,7 +36,7 @@
@dataclass(frozen=True)
-class TensorFromMemrefConstraint(
+class TensorFromMemRefConstraint(
GenericAttrConstraint[TensorType[Attribute] | UnrankedTensorType[Attribute]]
):
"""
@@ -46,7 +46,7 @@ class TensorFromMemrefConstraint(
"""
memref_constraint: GenericAttrConstraint[
- MemRefType[Attribute] | UnrankedMemrefType[Attribute]
+ MemRefType[Attribute] | UnrankedMemRefType[Attribute]
]
def can_infer(self, var_constraint_names: Set[str]) -> bool:
@@ -64,7 +64,7 @@ def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None
if isa(attr, TensorType[Attribute]):
memref_type = MemRefType(attr.element_type, attr.shape)
elif isa(attr, UnrankedTensorType[Attribute]):
- memref_type = UnrankedMemrefType.from_type(attr.element_type)
+ memref_type = UnrankedMemRefType.from_type(attr.element_type)
else:
raise VerifyException(
f"Expected tensor or unranked tensor type, got {attr}"
@@ -139,7 +139,7 @@ def __init__(
class CloneOp(IRDLOperation):
name = "bufferization.clone"
- T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)
+ T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemRefTypeConstr)
input = operand_def(T)
output = result_def(T)
@@ -155,10 +155,10 @@ def __init__(self, input: SSAValue | Operation):
class ToTensorOp(IRDLOperation):
name = "bufferization.to_tensor"
- T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)
+ T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemRefTypeConstr)
memref = operand_def(T)
- tensor = result_def(TensorFromMemrefConstraint(T))
+ tensor = result_def(TensorFromMemRefConstraint(T))
writable = opt_prop_def(UnitAttr)
restrict = opt_prop_def(UnitAttr)
@@ -192,11 +192,11 @@ def __init__(
@irdl_op_definition
-class ToMemrefOp(IRDLOperation):
+class ToMemRefOp(IRDLOperation):
name = "bufferization.to_memref"
- T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)
- tensor = operand_def(TensorFromMemrefConstraint(T))
+ T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemRefTypeConstr)
+ tensor = operand_def(TensorFromMemRefConstraint(T))
memref = result_def(T)
read_only = opt_prop_def(UnitAttr)
@@ -208,10 +208,10 @@ class ToMemrefOp(IRDLOperation):
class MaterializeInDestinationOp(IRDLOperation):
name = "bufferization.materialize_in_destination"
- T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemrefTypeConstr)
- source = operand_def(TensorFromMemrefConstraint(T))
- dest = operand_def(T | TensorFromMemrefConstraint(T))
- result = opt_result_def(TensorFromMemrefConstraint(T))
+ T: ClassVar = VarConstraint("T", MemRefType.constr() | AnyUnrankedMemRefTypeConstr)
+ source = operand_def(TensorFromMemRefConstraint(T))
+ dest = operand_def(T | TensorFromMemRefConstraint(T))
+ result = opt_result_def(TensorFromMemRefConstraint(T))
restrict = opt_prop_def(UnitAttr)
writable = opt_prop_def(UnitAttr)
@@ -225,7 +225,7 @@ class MaterializeInDestinationOp(IRDLOperation):
AllocTensorOp,
CloneOp,
ToTensorOp,
- ToMemrefOp,
+ ToMemRefOp,
MaterializeInDestinationOp,
],
[],
diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py
index efefb47015..83048aa7ad 100644
--- a/xdsl/dialects/builtin.py
+++ b/xdsl/dialects/builtin.py
@@ -1406,7 +1406,7 @@ def from_strings(name: str, value: str, type: Attribute = NoneAttr()) -> OpaqueA
return OpaqueAttr([StringAttr(name), StringAttr(value), type])
-class MemrefLayoutAttr(Attribute, ABC):
+class MemRefLayoutAttr(Attribute, ABC):
"""
Interface for any attribute acceptable as a memref layout.
"""
@@ -1438,7 +1438,7 @@ def get_strides(self) -> Sequence[int | None] | None:
@irdl_attr_definition
-class StridedLayoutAttr(MemrefLayoutAttr, ParametrizedAttribute):
+class StridedLayoutAttr(MemRefLayoutAttr, ParametrizedAttribute):
"""
An attribute representing a strided layout of a shaped type.
See https://mlir.llvm.org/docs/Dialects/Builtin/#stridedlayoutattr
@@ -1526,7 +1526,7 @@ def get_affine_map(self) -> AffineMap:
@irdl_attr_definition
-class AffineMapAttr(MemrefLayoutAttr, Data[AffineMap]):
+class AffineMapAttr(MemRefLayoutAttr, Data[AffineMap]):
"""An Attribute containing an AffineMap object."""
name = "affine_map"
@@ -1821,10 +1821,10 @@ def print(self, printer: Printer) -> None:
_MemRefTypeElement = TypeVar(
"_MemRefTypeElement", bound=Attribute, covariant=True, default=Attribute
)
-_UnrankedMemrefTypeElems = TypeVar(
- "_UnrankedMemrefTypeElems", bound=Attribute, covariant=True
+_UnrankedMemRefTypeElems = TypeVar(
+ "_UnrankedMemRefTypeElems", bound=Attribute, covariant=True
)
-_UnrankedMemrefTypeElemsInit = TypeVar("_UnrankedMemrefTypeElemsInit", bound=Attribute)
+_UnrankedMemRefTypeElemsInit = TypeVar("_UnrankedMemRefTypeElemsInit", bound=Attribute)
@irdl_attr_definition
@@ -1844,14 +1844,14 @@ class MemRefType(
shape: ParameterDef[ArrayAttr[IntAttr]]
element_type: ParameterDef[_MemRefTypeElement]
- layout: ParameterDef[MemrefLayoutAttr | NoneAttr]
+ layout: ParameterDef[MemRefLayoutAttr | NoneAttr]
memory_space: ParameterDef[Attribute]
def __init__(
self,
element_type: _MemRefTypeElement,
shape: ArrayAttr[IntAttr] | Iterable[int | IntAttr],
- layout: MemrefLayoutAttr | NoneAttr = NoneAttr(),
+ layout: MemRefLayoutAttr | NoneAttr = NoneAttr(),
memory_space: Attribute = NoneAttr(),
):
s: ArrayAttr[IntAttr]
@@ -1976,7 +1976,7 @@ def constr(
@dataclass(frozen=True, init=False)
-class TensorOrMemrefOf(
+class TensorOrMemRefOf(
GenericAttrConstraint[TensorType[AttributeCovT] | MemRefType[AttributeCovT]]
):
"""A type constraint that can be nested once in a memref or a tensor."""
@@ -2022,30 +2022,30 @@ def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None
@irdl_attr_definition
-class UnrankedMemrefType(
- Generic[_UnrankedMemrefTypeElems],
+class UnrankedMemRefType(
+ Generic[_UnrankedMemRefTypeElems],
ParametrizedAttribute,
TypeAttribute,
- ContainerType[_UnrankedMemrefTypeElems],
+ ContainerType[_UnrankedMemRefTypeElems],
):
name = "unranked_memref"
- element_type: ParameterDef[_UnrankedMemrefTypeElems]
+ element_type: ParameterDef[_UnrankedMemRefTypeElems]
memory_space: ParameterDef[Attribute]
@staticmethod
def from_type(
- referenced_type: _UnrankedMemrefTypeElemsInit,
+ referenced_type: _UnrankedMemRefTypeElemsInit,
memory_space: Attribute = NoneAttr(),
- ) -> UnrankedMemrefType[_UnrankedMemrefTypeElemsInit]:
- return UnrankedMemrefType([referenced_type, memory_space])
+ ) -> UnrankedMemRefType[_UnrankedMemRefTypeElemsInit]:
+ return UnrankedMemRefType([referenced_type, memory_space])
- def get_element_type(self) -> _UnrankedMemrefTypeElems:
+ def get_element_type(self) -> _UnrankedMemRefTypeElems:
return self.element_type
-AnyUnrankedMemrefType: TypeAlias = UnrankedMemrefType[Attribute]
-AnyUnrankedMemrefTypeConstr = BaseAttr[AnyUnrankedMemrefType](UnrankedMemrefType)
+AnyUnrankedMemRefType: TypeAlias = UnrankedMemRefType[Attribute]
+AnyUnrankedMemRefTypeConstr = BaseAttr[AnyUnrankedMemRefType](UnrankedMemRefType)
RankedStructure: TypeAlias = (
VectorType[AttributeCovT] | TensorType[AttributeCovT] | MemRefType[AttributeCovT]
@@ -2369,6 +2369,6 @@ def print_without_type(self, printer: Printer):
AffineMapAttr,
AffineSetAttr,
MemRefType,
- UnrankedMemrefType,
+ UnrankedMemRefType,
],
)
diff --git a/xdsl/dialects/gpu.py b/xdsl/dialects/gpu.py
index 9706c743be..de4ef9fc73 100644
--- a/xdsl/dialects/gpu.py
+++ b/xdsl/dialects/gpu.py
@@ -462,7 +462,7 @@ class HostRegisterOp(IRDLOperation):
name = "gpu.host_register"
- value = operand_def(memref.UnrankedMemrefType)
+ value = operand_def(memref.UnrankedMemRefType)
def __init__(self, memref: SSAValue | Operation):
super().__init__(operands=[SSAValue.get(memref)])
@@ -476,7 +476,7 @@ class HostUnregisterOp(IRDLOperation):
name = "gpu.host_unregister"
- value = operand_def(memref.UnrankedMemrefType)
+ value = operand_def(memref.UnrankedMemRefType)
def __init__(self, memref: SSAValue | Operation):
super().__init__(operands=[SSAValue.get(memref)])
diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py
index 25c3a9b5fb..3a8f87eea6 100644
--- a/xdsl/dialects/memref.py
+++ b/xdsl/dialects/memref.py
@@ -18,7 +18,7 @@
IntAttr,
IntegerAttr,
IntegerType,
- MemrefLayoutAttr,
+ MemRefLayoutAttr,
MemRefType,
NoneAttr,
SignlessIntegerConstraint,
@@ -26,7 +26,7 @@
StringAttr,
SymbolRefAttr,
UnitAttr,
- UnrankedMemrefType,
+ UnrankedMemRefType,
i32,
i64,
)
@@ -186,7 +186,7 @@ def get(
alignment: int | AnyIntegerAttr | None = None,
shape: Iterable[int | IntAttr] | None = None,
dynamic_sizes: Sequence[SSAValue | Operation] | None = None,
- layout: MemrefLayoutAttr | NoneAttr = NoneAttr(),
+ layout: MemRefLayoutAttr | NoneAttr = NoneAttr(),
memory_space: Attribute = NoneAttr(),
) -> Self:
if shape is None:
@@ -324,7 +324,7 @@ def get(
alignment: int | AnyIntegerAttr | None = None,
shape: Iterable[int | IntAttr] | None = None,
dynamic_sizes: Sequence[SSAValue | Operation] | None = None,
- layout: MemrefLayoutAttr | NoneAttr = NoneAttr(),
+ layout: MemRefLayoutAttr | NoneAttr = NoneAttr(),
memory_space: Attribute = NoneAttr(),
) -> AllocaOp:
if shape is None:
@@ -376,7 +376,7 @@ class AtomicRMWOp(IRDLOperation):
class DeallocOp(IRDLOperation):
name = "memref.dealloc"
memref = operand_def(
- base(MemRefType[Attribute]) | base(UnrankedMemrefType[Attribute])
+ base(MemRefType[Attribute]) | base(UnrankedMemRefType[Attribute])
)
@staticmethod
@@ -465,7 +465,7 @@ class DimOp(IRDLOperation):
name = "memref.dim"
source = operand_def(
- base(MemRefType[Attribute]) | base(UnrankedMemrefType[Attribute])
+ base(MemRefType[Attribute]) | base(UnrankedMemRefType[Attribute])
)
index = operand_def(IndexType)
@@ -650,14 +650,14 @@ def get(source: SSAValue | Operation):
)
-class MemrefHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
+class MemRefHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from xdsl.transforms.canonicalization_patterns.memref import (
- MemrefSubviewOfSubviewFolding,
+ MemRefSubviewOfSubviewFolding,
)
- return (MemrefSubviewOfSubviewFolding(),)
+ return (MemRefSubviewOfSubviewFolding(),)
@irdl_op_definition
@@ -682,7 +682,7 @@ class SubviewOp(IRDLOperation):
irdl_options = [AttrSizedOperandSegments(as_property=True)]
traits = lazy_traits_def(
- lambda: (MemrefHasCanonicalizationPatternsTrait(), NoMemoryEffect())
+ lambda: (MemRefHasCanonicalizationPatternsTrait(), NoMemoryEffect())
)
def __init__(
@@ -901,16 +901,16 @@ class CastOp(IRDLOperation):
name = "memref.cast"
source = operand_def(
- base(MemRefType[Attribute]) | base(UnrankedMemrefType[Attribute])
+ base(MemRefType[Attribute]) | base(UnrankedMemRefType[Attribute])
)
- dest = result_def(base(MemRefType[Attribute]) | base(UnrankedMemrefType[Attribute]))
+ dest = result_def(base(MemRefType[Attribute]) | base(UnrankedMemRefType[Attribute]))
traits = traits_def(NoMemoryEffect())
@staticmethod
def get(
source: SSAValue | Operation,
- type: MemRefType[Attribute] | UnrankedMemrefType[Attribute],
+ type: MemRefType[Attribute] | UnrankedMemRefType[Attribute],
):
return CastOp.build(operands=[source], result_types=[type])
@@ -920,16 +920,16 @@ class MemorySpaceCastOp(IRDLOperation):
name = "memref.memory_space_cast"
source = operand_def(
- base(MemRefType[Attribute]) | base(UnrankedMemrefType[Attribute])
+ base(MemRefType[Attribute]) | base(UnrankedMemRefType[Attribute])
)
- dest = result_def(base(MemRefType[Attribute]) | base(UnrankedMemrefType[Attribute]))
+ dest = result_def(base(MemRefType[Attribute]) | base(UnrankedMemRefType[Attribute]))
traits = traits_def(NoMemoryEffect())
def __init__(
self,
source: SSAValue | Operation,
- dest: MemRefType[Attribute] | UnrankedMemrefType[Attribute],
+ dest: MemRefType[Attribute] | UnrankedMemRefType[Attribute],
):
super().__init__(operands=[source], result_types=[dest])
diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py
index b8742b5905..bcfb143c18 100644
--- a/xdsl/dialects/memref_stream.py
+++ b/xdsl/dialects/memref_stream.py
@@ -953,7 +953,7 @@ def __init__(self, memref: SSAValue, value: SSAValue):
super().__init__(operands=(memref, value))
-MemrefStream = Dialect(
+MemRefStream = Dialect(
"memref_stream",
[
ReadOp,
diff --git a/xdsl/dialects/mpi.py b/xdsl/dialects/mpi.py
index d06299d5bc..6a87d7f6d2 100644
--- a/xdsl/dialects/mpi.py
+++ b/xdsl/dialects/mpi.py
@@ -651,7 +651,7 @@ class FinalizeOp(MPIBaseOp):
@irdl_op_definition
-class UnwrapMemrefOp(MPIBaseOp):
+class UnwrapMemRefOp(MPIBaseOp):
"""
This Op can be used as a helper to get memrefs into MPI calls.
@@ -848,7 +848,7 @@ def __init__(
InitOp,
FinalizeOp,
CommRankOp,
- UnwrapMemrefOp,
+ UnwrapMemRefOp,
GetDtypeOp,
AllocateTypeOp,
VectorGetOp,
diff --git a/xdsl/interpreters/__init__.py b/xdsl/interpreters/__init__.py
index 1c597b8ad6..666cf3927e 100644
--- a/xdsl/interpreters/__init__.py
+++ b/xdsl/interpreters/__init__.py
@@ -32,8 +32,8 @@ def register_implementations(interpreter: Interpreter, ctx: MLContext):
interpreter.register_implementations(cf.CfFunctions())
interpreter.register_implementations(func.FuncFunctions())
interpreter.register_implementations(linalg.LinalgFunctions())
- interpreter.register_implementations(memref_stream.MemrefStreamFunctions())
- interpreter.register_implementations(memref.MemrefFunctions())
+ interpreter.register_implementations(memref_stream.MemRefStreamFunctions())
+ interpreter.register_implementations(memref.MemRefFunctions())
interpreter.register_implementations(ml_program.MLProgramFunctions())
interpreter.register_implementations(pdl.PDLRewriteFunctions(ctx))
interpreter.register_implementations(printf.PrintfFunctions())
diff --git a/xdsl/interpreters/memref.py b/xdsl/interpreters/memref.py
index 278f7ad27c..745a356bc3 100644
--- a/xdsl/interpreters/memref.py
+++ b/xdsl/interpreters/memref.py
@@ -17,7 +17,7 @@
@register_impls
-class MemrefFunctions(InterpreterFunctions):
+class MemRefFunctions(InterpreterFunctions):
@impl(memref.AllocOp)
def run_alloc(
self, interpreter: Interpreter, op: memref.AllocOp, args: PythonValues
@@ -74,7 +74,7 @@ def run_get_global(
initial_value = mem.initial_value
if not isinstance(initial_value, builtin.DenseIntOrFPElementsAttr):
raise NotImplementedError(
- "Memrefs that are not dense int or float arrays are not implemented"
+ "MemRefs that are not dense int or float arrays are not implemented"
)
data = initial_value.get_values()
shape = initial_value.get_shape()
diff --git a/xdsl/interpreters/memref_stream.py b/xdsl/interpreters/memref_stream.py
index 62e16659fa..ebaec2559a 100644
--- a/xdsl/interpreters/memref_stream.py
+++ b/xdsl/interpreters/memref_stream.py
@@ -15,7 +15,7 @@
@register_impls
-class MemrefStreamFunctions(InterpreterFunctions):
+class MemRefStreamFunctions(InterpreterFunctions):
@impl(memref_stream.GenericOp)
def run_generic(
self,
diff --git a/xdsl/interpreters/shaped_array.py b/xdsl/interpreters/shaped_array.py
index 5bdf22629b..fe63cf9ee3 100644
--- a/xdsl/interpreters/shaped_array.py
+++ b/xdsl/interpreters/shaped_array.py
@@ -18,7 +18,7 @@
@dataclass
class ShapedArray(Generic[_T]):
"""
- A helper structure to represent instances of type MemrefType, TensorType, VectorType, etc.
+ A helper structure to represent instances of type MemRefType, TensorType, VectorType, etc.
in the interpreter.
"""
diff --git a/xdsl/parser/attribute_parser.py b/xdsl/parser/attribute_parser.py
index ba6e02a78c..499470ef86 100644
--- a/xdsl/parser/attribute_parser.py
+++ b/xdsl/parser/attribute_parser.py
@@ -41,7 +41,7 @@
IntegerAttr,
IntegerType,
LocationAttr,
- MemrefLayoutAttr,
+ MemRefLayoutAttr,
MemRefType,
NoneAttr,
NoneType,
@@ -53,7 +53,7 @@
SymbolRefAttr,
TensorType,
UnitAttr,
- UnrankedMemrefType,
+ UnrankedMemRefType,
UnrankedTensorType,
UnregisteredAttr,
VectorType,
@@ -527,15 +527,15 @@ def _parse_complex_attrs(self) -> ComplexType:
def _parse_memref_attrs(
self,
- ) -> MemRefType[Attribute] | UnrankedMemrefType[Attribute]:
+ ) -> MemRefType[Attribute] | UnrankedMemRefType[Attribute]:
shape, type = self.parse_shape()
# Unranked case
if shape is None:
if self.parse_optional_punctuation(",") is None:
- return UnrankedMemrefType.from_type(type)
+ return UnrankedMemRefType.from_type(type)
memory_space = self.parse_attribute()
- return UnrankedMemrefType.from_type(type, memory_space)
+ return UnrankedMemRefType.from_type(type, memory_space)
if self.parse_optional_punctuation(",") is None:
return MemRefType(type, shape)
@@ -546,12 +546,12 @@ def _parse_memref_attrs(
# layout is the second one
if self.parse_optional_punctuation(",") is not None:
memory_space = self.parse_attribute()
- if not isinstance(memory_or_layout, MemrefLayoutAttr):
+ if not isinstance(memory_or_layout, MemRefLayoutAttr):
self.raise_error("Expected a MemRef layout attribute")
return MemRefType(type, shape, memory_or_layout, memory_space)
- # If the argument is a MemrefLayoutAttr, use it as layout
- if isinstance(memory_or_layout, MemrefLayoutAttr):
+ # If the argument is a MemRefLayoutAttr, use it as layout
+ if isinstance(memory_or_layout, MemRefLayoutAttr):
return MemRefType(type, shape, layout=memory_or_layout)
# Otherwise, consider it as the memory space.
diff --git a/xdsl/printer.py b/xdsl/printer.py
index c81e2b0dd0..caa7609032 100644
--- a/xdsl/printer.py
+++ b/xdsl/printer.py
@@ -13,7 +13,7 @@
AffineSetAttr,
AnyFloat,
AnyFloatAttr,
- AnyUnrankedMemrefType,
+ AnyUnrankedMemRefType,
AnyUnrankedTensorType,
AnyVectorType,
ArrayAttr,
@@ -44,7 +44,7 @@
SymbolRefAttr,
TensorType,
UnitAttr,
- UnrankedMemrefType,
+ UnrankedMemRefType,
UnrankedTensorType,
UnregisteredAttr,
UnregisteredOp,
@@ -608,8 +608,8 @@ def print_int_or_question(value: IntAttr | NoneAttr) -> None:
self.print_string(">")
return
- if isinstance(attribute, UnrankedMemrefType):
- attribute = cast(AnyUnrankedMemrefType, attribute)
+ if isinstance(attribute, UnrankedMemRefType):
+ attribute = cast(AnyUnrankedMemRefType, attribute)
self.print_string("memref<*x")
self.print_attribute(attribute.element_type)
if not isinstance(attribute.memory_space, NoneAttr):
diff --git a/xdsl/transforms/__init__.py b/xdsl/transforms/__init__.py
index b039d5ad81..bc9f18f5b6 100644
--- a/xdsl/transforms/__init__.py
+++ b/xdsl/transforms/__init__.py
@@ -69,34 +69,34 @@ def get_convert_linalg_to_loops():
def get_convert_linalg_to_memref_stream():
from xdsl.transforms import convert_linalg_to_memref_stream
- return convert_linalg_to_memref_stream.ConvertLinalgToMemrefStreamPass
+ return convert_linalg_to_memref_stream.ConvertLinalgToMemRefStreamPass
def get_convert_memref_stream_to_loops():
from xdsl.transforms import convert_memref_stream_to_loops
- return convert_memref_stream_to_loops.ConvertMemrefStreamToLoopsPass
+ return convert_memref_stream_to_loops.ConvertMemRefStreamToLoopsPass
def get_convert_memref_stream_to_snitch_stream():
from xdsl.transforms import convert_memref_stream_to_snitch_stream
return (
- convert_memref_stream_to_snitch_stream.ConvertMemrefStreamToSnitchStreamPass
+ convert_memref_stream_to_snitch_stream.ConvertMemRefStreamToSnitchStreamPass
)
def get_convert_memref_to_ptr():
from xdsl.transforms import convert_memref_to_ptr
- return convert_memref_to_ptr.ConvertMemrefToPtr
+ return convert_memref_to_ptr.ConvertMemRefToPtr
def get_convert_memref_to_riscv():
from xdsl.backend.riscv.lowering import convert_memref_to_riscv
- return convert_memref_to_riscv.ConvertMemrefToRiscvPass
+ return convert_memref_to_riscv.ConvertMemRefToRiscvPass
def get_convert_ml_program_to_memref():
from xdsl.transforms import convert_ml_program_to_memref
- return convert_ml_program_to_memref.ConvertMlProgramToMemrefPass
+ return convert_ml_program_to_memref.ConvertMlProgramToMemRefPass
def get_convert_print_format_to_riscv_debug():
from xdsl.backend.riscv.lowering import convert_print_format_to_riscv_debug
@@ -286,7 +286,7 @@ def get_linalg_to_csl():
def get_loop_hoist_memref():
from xdsl.transforms import loop_hoist_memref
- return loop_hoist_memref.LoopHoistMemrefPass
+ return loop_hoist_memref.LoopHoistMemRefPass
def get_lower_affine():
from xdsl.transforms import lower_affine
@@ -331,52 +331,52 @@ def get_lower_snitch():
def get_memref_stream_fold_fill():
from xdsl.transforms import memref_stream_fold_fill
- return memref_stream_fold_fill.MemrefStreamFoldFillPass
+ return memref_stream_fold_fill.MemRefStreamFoldFillPass
def get_memref_stream_generalize_fill():
from xdsl.transforms import memref_stream_generalize_fill
- return memref_stream_generalize_fill.MemrefStreamGeneralizeFillPass
+ return memref_stream_generalize_fill.MemRefStreamGeneralizeFillPass
def get_memref_stream_infer_fill():
from xdsl.transforms import memref_stream_infer_fill
- return memref_stream_infer_fill.MemrefStreamInferFillPass
+ return memref_stream_infer_fill.MemRefStreamInferFillPass
def get_memref_stream_interleave():
from xdsl.transforms import memref_stream_interleave
- return memref_stream_interleave.MemrefStreamInterleavePass
+ return memref_stream_interleave.MemRefStreamInterleavePass
def get_memref_stream_legalize():
from xdsl.transforms import memref_stream_legalize
- return memref_stream_legalize.MemrefStreamLegalizePass
+ return memref_stream_legalize.MemRefStreamLegalizePass
def get_memref_stream_tile_outer_loops():
from xdsl.transforms import memref_stream_tile_outer_loops
- return memref_stream_tile_outer_loops.MemrefStreamTileOuterLoopsPass
+ return memref_stream_tile_outer_loops.MemRefStreamTileOuterLoopsPass
def get_memref_stream_unnest_out_parameters():
from xdsl.transforms import memref_stream_unnest_out_parameters
- return memref_stream_unnest_out_parameters.MemrefStreamUnnestOutParametersPass
+ return memref_stream_unnest_out_parameters.MemRefStreamUnnestOutParametersPass
def get_memref_streamify():
from xdsl.transforms import memref_streamify
- return memref_streamify.MemrefStreamifyPass
+ return memref_streamify.MemRefStreamifyPass
def get_memref_to_dsd():
from xdsl.transforms import memref_to_dsd
- return memref_to_dsd.MemrefToDsdPass
+ return memref_to_dsd.MemRefToDsdPass
def get_memref_to_gpu():
from xdsl.transforms import gpu_allocs
- return gpu_allocs.MemrefToGPUPass
+ return gpu_allocs.MemRefToGPUPass
def get_mlir_opt():
from xdsl.transforms import mlir_opt
diff --git a/xdsl/transforms/canonicalization_patterns/memref.py b/xdsl/transforms/canonicalization_patterns/memref.py
index e9efe8efa4..ef1c2e5265 100644
--- a/xdsl/transforms/canonicalization_patterns/memref.py
+++ b/xdsl/transforms/canonicalization_patterns/memref.py
@@ -10,7 +10,7 @@
from xdsl.utils.hints import isa
-class MemrefSubviewOfSubviewFolding(RewritePattern):
+class MemRefSubviewOfSubviewFolding(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter, /):
source_subview = op.source.owner
diff --git a/xdsl/transforms/convert_linalg_to_memref_stream.py b/xdsl/transforms/convert_linalg_to_memref_stream.py
index 83c02a744d..283e46ce9c 100644
--- a/xdsl/transforms/convert_linalg_to_memref_stream.py
+++ b/xdsl/transforms/convert_linalg_to_memref_stream.py
@@ -63,7 +63,7 @@ def match_and_rewrite(self, op: linalg.YieldOp, rewriter: PatternRewriter) -> No
rewriter.replace_matched_op(memref_stream.YieldOp(*op.operands))
-class ConvertLinalgToMemrefStreamPass(ModulePass):
+class ConvertLinalgToMemRefStreamPass(ModulePass):
name = "convert-linalg-to-memref-stream"
def apply(self, ctx: MLContext, op: ModuleOp) -> None:
diff --git a/xdsl/transforms/convert_memref_stream_to_loops.py b/xdsl/transforms/convert_memref_stream_to_loops.py
index ef8d50de2d..f97417fa2e 100644
--- a/xdsl/transforms/convert_memref_stream_to_loops.py
+++ b/xdsl/transforms/convert_memref_stream_to_loops.py
@@ -202,7 +202,7 @@ def insert_store(
)
-class ConvertMemrefStreamToLoopsPass(ModulePass):
+class ConvertMemRefStreamToLoopsPass(ModulePass):
"""
Converts a memref_stream generic to loop.
"""
diff --git a/xdsl/transforms/convert_memref_stream_to_snitch_stream.py b/xdsl/transforms/convert_memref_stream_to_snitch_stream.py
index a3277aed47..d3e04b7dd7 100644
--- a/xdsl/transforms/convert_memref_stream_to_snitch_stream.py
+++ b/xdsl/transforms/convert_memref_stream_to_snitch_stream.py
@@ -228,7 +228,7 @@ def strides_for_affine_map(
return result
-class ConvertMemrefStreamToSnitchStreamPass(ModulePass):
+class ConvertMemRefStreamToSnitchStreamPass(ModulePass):
"""
Converts memref_stream `read` and `write` operations to the snitch_stream equivalents.
diff --git a/xdsl/transforms/convert_memref_to_ptr.py b/xdsl/transforms/convert_memref_to_ptr.py
index f543ff5b48..ea995b14b5 100644
--- a/xdsl/transforms/convert_memref_to_ptr.py
+++ b/xdsl/transforms/convert_memref_to_ptr.py
@@ -155,7 +155,7 @@ def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /):
@dataclass
-class LowerMemrefFuncOpPattern(RewritePattern):
+class LowerMemRefFuncOpPattern(RewritePattern):
"""
Rewrites function arguments of MemRefType to PtrType.
"""
@@ -200,7 +200,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
@dataclass
-class LowerMemrefFuncReturnPattern(RewritePattern):
+class LowerMemRefFuncReturnPattern(RewritePattern):
"""
Rewrites all `memref` arguments to `func.return` into `ptr.PtrType`
"""
@@ -230,7 +230,7 @@ def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /):
@dataclass
-class LowerMemrefFuncCallPattern(RewritePattern):
+class LowerMemRefFuncCallPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /):
if not any(
@@ -327,7 +327,7 @@ def match_and_rewrite(
@dataclass(frozen=True)
-class ConvertMemrefToPtr(ModulePass):
+class ConvertMemRefToPtr(ModulePass):
name = "convert-memref-to-ptr"
lower_func: bool = False
@@ -341,9 +341,9 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
- LowerMemrefFuncOpPattern(),
- LowerMemrefFuncCallPattern(),
- LowerMemrefFuncReturnPattern(),
+ LowerMemRefFuncOpPattern(),
+ LowerMemRefFuncCallPattern(),
+ LowerMemRefFuncReturnPattern(),
ReconcileUnrealizedPtrCasts(),
]
)
diff --git a/xdsl/transforms/convert_ml_program_to_memref.py b/xdsl/transforms/convert_ml_program_to_memref.py
index 040a26d729..bd0fa5a812 100644
--- a/xdsl/transforms/convert_ml_program_to_memref.py
+++ b/xdsl/transforms/convert_ml_program_to_memref.py
@@ -58,7 +58,7 @@ def match_and_rewrite(
)
-class ConvertMlProgramToMemrefPass(ModulePass):
+class ConvertMlProgramToMemRefPass(ModulePass):
"""
Converts operations in the `ml_program` dialect to `memref`.
`ml_program` operations are at the `tensor` level of abstraction, so some of the
diff --git a/xdsl/transforms/convert_ptr_to_riscv.py b/xdsl/transforms/convert_ptr_to_riscv.py
index 6b59d794b5..b01e867e1b 100644
--- a/xdsl/transforms/convert_ptr_to_riscv.py
+++ b/xdsl/transforms/convert_ptr_to_riscv.py
@@ -110,7 +110,7 @@ def match_and_rewrite(self, op: ptr.LoadOp, rewriter: PatternRewriter, /):
@dataclass
-class ConvertMemrefToPtrOp(RewritePattern):
+class ConvertMemRefToPtrOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ptr.ToPtrOp, rewriter: PatternRewriter, /):
rewriter.replace_matched_op(
@@ -131,7 +131,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
ConvertPtrAddOp(),
ConvertStoreOp(),
ConvertLoadOp(),
- ConvertMemrefToPtrOp(),
+ ConvertMemRefToPtrOp(),
]
),
).rewrite_module(op)
diff --git a/xdsl/transforms/csl_stencil_bufferize.py b/xdsl/transforms/csl_stencil_bufferize.py
index 01cb99eb54..81f44a3869 100644
--- a/xdsl/transforms/csl_stencil_bufferize.py
+++ b/xdsl/transforms/csl_stencil_bufferize.py
@@ -44,13 +44,13 @@ def tensor_to_memref_type(t: TensorType[Attribute]) -> memref.MemRefType[Attribu
return memref.MemRefType(t.get_element_type(), t.get_shape())
-def to_memref_op(op: SSAValue) -> bufferization.ToMemrefOp:
+def to_memref_op(op: SSAValue) -> bufferization.ToMemRefOp:
"""Creates a `bufferization.to_memref` operation."""
assert isa(op.type, AnyTensorType)
r_type = memref.MemRefType(
op.type.get_element_type(), op.type.get_shape()
) # todo set strided+offset here?
- return bufferization.ToMemrefOp(operands=[op], result_types=[r_type])
+ return bufferization.ToMemRefOp(operands=[op], result_types=[r_type])
def to_tensor_op(
@@ -412,7 +412,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
for arg, yld_arg in zip(op.dest, yld.arguments, strict=True):
if (
not isinstance(yld_arg, OpResult)
- or not isinstance(yld_arg.op, bufferization.ToMemrefOp)
+ or not isinstance(yld_arg.op, bufferization.ToMemRefOp)
or not isinstance(yld_arg.op.tensor, OpResult)
or not isinstance(
linalg_op := yld_arg.op.tensor.op,
diff --git a/xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py b/xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
index 5de1407881..245da8b4ac 100644
--- a/xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
+++ b/xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
@@ -69,7 +69,7 @@ def StencilToMemRefType(
@dataclass
-class CastOpToMemref(RewritePattern):
+class CastOpToMemRef(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: CastOp, rewriter: PatternRewriter, /):
assert isa(op.result.type, FieldType[Attribute])
@@ -133,7 +133,7 @@ def _find_result_store(result: SSAValue) -> tuple[StoreResultOp, ...]:
@dataclass
-class ReturnOpToMemref(RewritePattern):
+class ReturnOpToMemRef(RewritePattern):
return_target: dict[ApplyOp, list[SSAValue | None]]
@op_type_rewrite_pattern
@@ -228,7 +228,7 @@ def assert_subset(field: FieldType[Attribute], temp: TempType[Attribute]):
)
-class LoadOpToMemref(RewritePattern):
+class LoadOpToMemRef(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: LoadOp, rewriter: PatternRewriter, /):
for use in op.field.uses:
@@ -277,7 +277,7 @@ def prepare_apply_body(op: ApplyOp):
@dataclass
-class BufferOpToMemref(RewritePattern):
+class BufferOpToMemRef(RewritePattern):
return_targets: dict[ApplyOp, list[SSAValue | None]]
@op_type_rewrite_pattern
@@ -338,7 +338,7 @@ def field_subview(field: SSAValue):
)
-class AllocOpToMemref(RewritePattern):
+class AllocOpToMemRef(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: AllocOp, rewriter: PatternRewriter, /):
alloc = memref.AllocOp(
@@ -456,7 +456,7 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /):
@dataclass
-class AccessOpToMemref(RewritePattern):
+class AccessOpToMemRef(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter, /):
temp = op.temp.type
@@ -676,15 +676,15 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
[
ApplyOpFieldSubviews(),
ApplyOpToParallel(return_targets),
- BufferOpToMemref(return_targets),
+ BufferOpToMemRef(return_targets),
StencilStoreToSubview(return_targets),
- CastOpToMemref(),
- LoadOpToMemref(),
- AccessOpToMemref(),
- ReturnOpToMemref(return_targets),
+ CastOpToMemRef(),
+ LoadOpToMemRef(),
+ AccessOpToMemRef(),
+ ReturnOpToMemRef(return_targets),
TrivialExternalLoadOpCleanup(),
TrivialExternalStoreOpCleanup(),
- AllocOpToMemref(),
+ AllocOpToMemRef(),
]
),
apply_recursively=True,
diff --git a/xdsl/transforms/experimental/dmp/stencil_global_to_local.py b/xdsl/transforms/experimental/dmp/stencil_global_to_local.py
index 68eb3d76b3..4c14874419 100644
--- a/xdsl/transforms/experimental/dmp/stencil_global_to_local.py
+++ b/xdsl/transforms/experimental/dmp/stencil_global_to_local.py
@@ -285,7 +285,7 @@ def then() -> Iterable[Operation]:
# copy source area to outbound buffer
yield from generate_memcpy(source, ex.source_area(), alloc_outbound.memref)
# get ptr, count, dtype
- unwrap_out = mpi.UnwrapMemrefOp(alloc_outbound)
+ unwrap_out = mpi.UnwrapMemRefOp(alloc_outbound)
unwrap_out.ptr.name_hint = f"send_buff_ex{i}_ptr"
yield unwrap_out
@@ -305,7 +305,7 @@ def then() -> Iterable[Operation]:
)
# get ptr for receive buffer
- unwrap_in = mpi.UnwrapMemrefOp(alloc_inbound)
+ unwrap_in = mpi.UnwrapMemRefOp(alloc_inbound)
unwrap_in.ptr.name_hint = f"recv_buff_ex{i}_ptr"
yield unwrap_in
# Irecv call
@@ -439,7 +439,7 @@ def rewrite(
memref.AllocOp
| mpi.CommRankOp
| mpi.AllocateTypeOp
- | mpi.UnwrapMemrefOp
+ | mpi.UnwrapMemRefOp
| mpi.InitOp
),
rewriter: Rewriter,
@@ -450,7 +450,7 @@ def rewrite(
self.seen_ops.add(op)
# memref unwraps can always be moved to their allocation
- if isinstance(op, mpi.UnwrapMemrefOp) and isinstance(
+ if isinstance(op, mpi.UnwrapMemRefOp) and isinstance(
op.ref.owner, memref.AllocOp
):
op.detach()
@@ -499,7 +499,7 @@ def get_matcher(
memref.AllocOp
| mpi.CommRankOp
| mpi.AllocateTypeOp
- | mpi.UnwrapMemrefOp
+ | mpi.UnwrapMemRefOp
| mpi.InitOp
],
) -> Callable[[Operation], None]:
@@ -514,7 +514,7 @@ def match(op: Operation):
memref.AllocOp
| mpi.CommRankOp
| mpi.AllocateTypeOp
- | mpi.UnwrapMemrefOp
+ | mpi.UnwrapMemRefOp
| mpi.InitOp,
):
worklist.append(op)
@@ -533,7 +533,7 @@ def rewrite_module(self, op: builtin.ModuleOp):
memref.AllocOp
| mpi.CommRankOp
| mpi.AllocateTypeOp
- | mpi.UnwrapMemrefOp
+ | mpi.UnwrapMemRefOp
| mpi.InitOp
] = list()
matcher = self.get_matcher(worklist)
diff --git a/xdsl/transforms/experimental/hls_convert_stencil_to_ll_mlir.py b/xdsl/transforms/experimental/hls_convert_stencil_to_ll_mlir.py
index 41d5c8b1aa..e2ea78ad24 100644
--- a/xdsl/transforms/experimental/hls_convert_stencil_to_ll_mlir.py
+++ b/xdsl/transforms/experimental/hls_convert_stencil_to_ll_mlir.py
@@ -63,9 +63,9 @@
)
from xdsl.rewriter import InsertPoint
from xdsl.transforms.experimental.convert_stencil_to_ll_mlir import (
- AccessOpToMemref,
- CastOpToMemref,
- LoadOpToMemref,
+ AccessOpToMemRef,
+ CastOpToMemRef,
+ LoadOpToMemRef,
StencilToMemRefType,
TrivialExternalLoadOpCleanup,
TrivialExternalStoreOpCleanup,
@@ -1317,9 +1317,9 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
),
StencilAccessOpToReadBlockOp(),
StencilStoreToSubview(),
- CastOpToMemref(),
- LoadOpToMemref(),
- AccessOpToMemref(),
+ CastOpToMemRef(),
+ LoadOpToMemRef(),
+ AccessOpToMemRef(),
]
),
apply_recursively=False,
diff --git a/xdsl/transforms/gpu_allocs.py b/xdsl/transforms/gpu_allocs.py
index 4b3a8125de..ca87e188b5 100644
--- a/xdsl/transforms/gpu_allocs.py
+++ b/xdsl/transforms/gpu_allocs.py
@@ -28,7 +28,7 @@ def match_and_rewrite(self, op: memref.DeallocOp, rewriter: PatternRewriter, /):
rewriter.replace_matched_op(gpu.DeallocOp(op.memref))
-class MemrefToGPUPass(ModulePass):
+class MemRefToGPUPass(ModulePass):
name = "memref-to-gpu"
def apply(self, ctx: MLContext, op: ModuleOp) -> None:
diff --git a/xdsl/transforms/loop_hoist_memref.py b/xdsl/transforms/loop_hoist_memref.py
index 0d189fd2a7..91a5296070 100644
--- a/xdsl/transforms/loop_hoist_memref.py
+++ b/xdsl/transforms/loop_hoist_memref.py
@@ -69,7 +69,7 @@ def is_loop_dependent(val: SSAValue, loop: scf.ForOp):
@dataclass
-class LoopHoistMemref(RewritePattern):
+class LoopHoistMemRef(RewritePattern):
"""
Hoist pairs of memref.loads and memref.stores out of a loop.
@@ -190,14 +190,14 @@ def match_and_rewrite(
@dataclass(frozen=True)
-class LoopHoistMemrefPass(ModulePass):
+class LoopHoistMemRefPass(ModulePass):
name = "loop-hoist-memref"
def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
- LoopHoistMemref(),
+ LoopHoistMemRef(),
]
),
walk_regions_first=True,
diff --git a/xdsl/transforms/lower_mpi.py b/xdsl/transforms/lower_mpi.py
index 3b8d880692..07a30bd08c 100644
--- a/xdsl/transforms/lower_mpi.py
+++ b/xdsl/transforms/lower_mpi.py
@@ -210,7 +210,7 @@ def _emit_memref_counts(
"""
assert isinstance(ssa_val_type := ssa_val.type, memref.MemRefType)
- # Note: we only allow MemRef, not UnrankedMemref!
+ # Note: we only allow MemRef, not UnrankedMemRef!
# TODO: handle -1 in sizes
if not all(dim >= 0 for dim in ssa_val_type.get_shape()):
raise RuntimeError("MPI lowering does not support unknown-size memrefs!")
@@ -619,13 +619,13 @@ def lower(self, op: mpi.RecvOp) -> tuple[list[Operation], list[SSAValue | None]]
], new_results
-class LowerMpiUnwrapMemrefOp(_MPIToLLVMRewriteBase):
+class LowerMpiUnwrapMemRefOp(_MPIToLLVMRewriteBase):
@op_type_rewrite_pattern
- def match_and_rewrite(self, op: mpi.UnwrapMemrefOp, rewriter: PatternRewriter, /):
+ def match_and_rewrite(self, op: mpi.UnwrapMemRefOp, rewriter: PatternRewriter, /):
rewriter.replace_matched_op(*self.lower(op))
def lower(
- self, op: mpi.UnwrapMemrefOp
+ self, op: mpi.UnwrapMemRefOp
) -> tuple[list[Operation], list[SSAValue | None]]:
count_ops, count_ssa_val = self._emit_memref_counts(op.ref)
extract_ptr_ops, ptr = self._memref_get_llvm_ptr(op.ref)
@@ -856,7 +856,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
LowerMpiReduce(lib_info),
LowerMpiAllreduce(lib_info),
LowerMpiBcast(lib_info),
- LowerMpiUnwrapMemrefOp(lib_info),
+ LowerMpiUnwrapMemRefOp(lib_info),
LowerMpiGetDtype(lib_info),
LowerMpiAllocateType(lib_info),
LowerNullRequestOp(lib_info),
diff --git a/xdsl/transforms/memref_stream_fold_fill.py b/xdsl/transforms/memref_stream_fold_fill.py
index 0036b470f5..5fa18f0318 100644
--- a/xdsl/transforms/memref_stream_fold_fill.py
+++ b/xdsl/transforms/memref_stream_fold_fill.py
@@ -55,7 +55,7 @@ def fold_fills_in_module(module_op: ModuleOp):
@dataclass(frozen=True)
-class MemrefStreamFoldFillPass(ModulePass):
+class MemRefStreamFoldFillPass(ModulePass):
"""
Folds `memref_stream.fill` operations that run immediately before a
`memref_stream.generic` operation into the init value.
diff --git a/xdsl/transforms/memref_stream_generalize_fill.py b/xdsl/transforms/memref_stream_generalize_fill.py
index e332e3bc5c..09e5989274 100644
--- a/xdsl/transforms/memref_stream_generalize_fill.py
+++ b/xdsl/transforms/memref_stream_generalize_fill.py
@@ -61,7 +61,7 @@ def match_and_rewrite(
@dataclass(frozen=True)
-class MemrefStreamGeneralizeFillPass(ModulePass):
+class MemRefStreamGeneralizeFillPass(ModulePass):
"""
Generalizes memref_stream.fill ops.
"""
diff --git a/xdsl/transforms/memref_stream_infer_fill.py b/xdsl/transforms/memref_stream_infer_fill.py
index 5d506ad5c3..dbc56fc6d5 100644
--- a/xdsl/transforms/memref_stream_infer_fill.py
+++ b/xdsl/transforms/memref_stream_infer_fill.py
@@ -70,7 +70,7 @@ def match_and_rewrite(
@dataclass(frozen=True)
-class MemrefStreamInferFillPass(ModulePass):
+class MemRefStreamInferFillPass(ModulePass):
"""
Detects memref_stream.generic operations that can be represented as
`memref_stream.fill` ops.
diff --git a/xdsl/transforms/memref_stream_interleave.py b/xdsl/transforms/memref_stream_interleave.py
index a413965d84..89a9121f1e 100644
--- a/xdsl/transforms/memref_stream_interleave.py
+++ b/xdsl/transforms/memref_stream_interleave.py
@@ -142,7 +142,7 @@ def match_and_rewrite(
@dataclass(frozen=True)
-class MemrefStreamInterleavePass(ModulePass):
+class MemRefStreamInterleavePass(ModulePass):
"""
Tiles the innermost parallel dimension of a `memref_stream.generic`.
If specified, the `pipeline-depth` parameter specifies the number of operations in the
diff --git a/xdsl/transforms/memref_stream_legalize.py b/xdsl/transforms/memref_stream_legalize.py
index bc1d86f2fc..5f0aadea55 100644
--- a/xdsl/transforms/memref_stream_legalize.py
+++ b/xdsl/transforms/memref_stream_legalize.py
@@ -116,7 +116,7 @@ def _legalize_block(
@dataclass(frozen=True)
-class MemrefStreamGenericLegalize(RewritePattern):
+class MemRefStreamGenericLegalize(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
self, op: memref_stream.GenericOp, rewriter: PatternRewriter
@@ -191,7 +191,7 @@ def match_and_rewrite(
@dataclass(frozen=True)
-class MemrefStreamLegalizePass(ModulePass):
+class MemRefStreamLegalizePass(ModulePass):
"""
Legalize memref_stream.generic payload and bounds for streaming.
"""
@@ -200,6 +200,6 @@ class MemrefStreamLegalizePass(ModulePass):
def apply(self, ctx: MLContext, op: ModuleOp) -> None:
PatternRewriteWalker(
- GreedyRewritePatternApplier([MemrefStreamGenericLegalize()]),
+ GreedyRewritePatternApplier([MemRefStreamGenericLegalize()]),
apply_recursively=False,
).rewrite_module(op)
diff --git a/xdsl/transforms/memref_stream_tile_outer_loops.py b/xdsl/transforms/memref_stream_tile_outer_loops.py
index 6f4eb5b5b9..70d9e6d92d 100644
--- a/xdsl/transforms/memref_stream_tile_outer_loops.py
+++ b/xdsl/transforms/memref_stream_tile_outer_loops.py
@@ -220,7 +220,7 @@ def match_and_rewrite(
@dataclass(frozen=True)
-class MemrefStreamTileOuterLoopsPass(ModulePass):
+class MemRefStreamTileOuterLoopsPass(ModulePass):
"""
Materializes loops around memref_stream.generic operations that have greater than
specified number of non-1 upper bounds.
diff --git a/xdsl/transforms/memref_stream_unnest_out_parameters.py b/xdsl/transforms/memref_stream_unnest_out_parameters.py
index 73bf7db221..7028772a40 100644
--- a/xdsl/transforms/memref_stream_unnest_out_parameters.py
+++ b/xdsl/transforms/memref_stream_unnest_out_parameters.py
@@ -54,7 +54,7 @@ def match_and_rewrite(
@dataclass(frozen=True)
-class MemrefStreamUnnestOutParametersPass(ModulePass):
+class MemRefStreamUnnestOutParametersPass(ModulePass):
"""
Converts the affine maps of memref_stream.generic out parameters from taking all the
indices to only taking "parallel" ones.
diff --git a/xdsl/transforms/memref_streamify.py b/xdsl/transforms/memref_streamify.py
index b00694d7a3..7fc76d7050 100644
--- a/xdsl/transforms/memref_streamify.py
+++ b/xdsl/transforms/memref_streamify.py
@@ -133,7 +133,7 @@ def match_and_rewrite(
@dataclass(frozen=True)
-class MemrefStreamifyPass(ModulePass):
+class MemRefStreamifyPass(ModulePass):
"""
Converts a memref generic on memrefs to a memref generic on streams, by moving it into
a streaming region.
diff --git a/xdsl/transforms/memref_to_dsd.py b/xdsl/transforms/memref_to_dsd.py
index a637f40033..ac2291ad00 100644
--- a/xdsl/transforms/memref_to_dsd.py
+++ b/xdsl/transforms/memref_to_dsd.py
@@ -100,9 +100,9 @@ def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter, /):
raise ValueError("Failed to resolve GetMemDsdOp called on dsd type")
-class FixMemrefLoadOnGetDsd(RewritePattern):
+class FixMemRefLoadOnGetDsd(RewritePattern):
"""
- Memref load ops should load from the underlying memref, not from the dsd.
+ MemRef load ops should load from the underlying memref, not from the dsd.
"""
@op_type_rewrite_pattern
@@ -401,7 +401,7 @@ def match_and_rewrite(self, op: csl.LoadVarOp, rewriter: PatternRewriter, /):
@dataclass(frozen=True)
-class MemrefToDsdPass(ModulePass):
+class MemRefToDsdPass(ModulePass):
"""
Lowers memref ops to CSL DSDs.
@@ -434,7 +434,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
LowerAllocOpPass(),
DsdOpUpdateType(),
RetainAddressOfOpPass(),
- FixMemrefLoadOnGetDsd(),
+ FixMemRefLoadOnGetDsd(),
FixGetDsdOnGetDsd(),
]
),
diff --git a/xdsl/transforms/test_lower_linalg_to_snitch.py b/xdsl/transforms/test_lower_linalg_to_snitch.py
index 09b413ea44..4c8657e78e 100644
--- a/xdsl/transforms/test_lower_linalg_to_snitch.py
+++ b/xdsl/transforms/test_lower_linalg_to_snitch.py
@@ -37,27 +37,27 @@
OPTIMISE_MEMREF_STREAM_PASSES: tuple[ModulePass, ...] = (
canonicalize.CanonicalizePass(),
- memref_stream_infer_fill.MemrefStreamInferFillPass(),
- memref_stream_unnest_out_parameters.MemrefStreamUnnestOutParametersPass(),
- memref_stream_fold_fill.MemrefStreamFoldFillPass(),
- memref_stream_generalize_fill.MemrefStreamGeneralizeFillPass(),
- memref_stream_interleave.MemrefStreamInterleavePass(),
- memref_stream_tile_outer_loops.MemrefStreamTileOuterLoopsPass(target_rank=4),
- memref_streamify.MemrefStreamifyPass(),
- convert_memref_stream_to_loops.ConvertMemrefStreamToLoopsPass(),
+ memref_stream_infer_fill.MemRefStreamInferFillPass(),
+ memref_stream_unnest_out_parameters.MemRefStreamUnnestOutParametersPass(),
+ memref_stream_fold_fill.MemRefStreamFoldFillPass(),
+ memref_stream_generalize_fill.MemRefStreamGeneralizeFillPass(),
+ memref_stream_interleave.MemRefStreamInterleavePass(),
+ memref_stream_tile_outer_loops.MemRefStreamTileOuterLoopsPass(target_rank=4),
+ memref_streamify.MemRefStreamifyPass(),
+ convert_memref_stream_to_loops.ConvertMemRefStreamToLoopsPass(),
canonicalize.CanonicalizePass(),
scf_for_loop_flatten.ScfForLoopFlattenPass(),
)
LOWER_MEMREF_STREAM_TO_SNITCH_STREAM_PASSES: tuple[ModulePass, ...] = (
canonicalize.CanonicalizePass(),
- convert_memref_to_riscv.ConvertMemrefToRiscvPass(),
+ convert_memref_to_riscv.ConvertMemRefToRiscvPass(),
lower_affine.LowerAffinePass(),
convert_scf_to_riscv_scf.ConvertScfToRiscvPass(),
convert_arith_to_riscv_snitch.ConvertArithToRiscvSnitchPass(),
convert_arith_to_riscv.ConvertArithToRiscvPass(),
convert_func_to_riscv_func.ConvertFuncToRiscvFuncPass(),
- convert_memref_stream_to_snitch_stream.ConvertMemrefStreamToSnitchStreamPass(),
+ convert_memref_stream_to_snitch_stream.ConvertMemRefStreamToSnitchStreamPass(),
reconcile_unrealized_casts.ReconcileUnrealizedCastsPass(),
)
@@ -78,8 +78,8 @@
TEST_LOWER_LINALG_TO_SNITCH_PASSES: tuple[ModulePass, ...] = (
canonicalize.CanonicalizePass(),
- convert_linalg_to_memref_stream.ConvertLinalgToMemrefStreamPass(),
- memref_stream_legalize.MemrefStreamLegalizePass(),
+ convert_linalg_to_memref_stream.ConvertLinalgToMemRefStreamPass(),
+ memref_stream_legalize.MemRefStreamLegalizePass(),
*OPTIMISE_MEMREF_STREAM_PASSES,
*LOWER_MEMREF_STREAM_TO_SNITCH_STREAM_PASSES,
*LOWER_SNITCH_STREAM_TO_ASM_PASSES,
From 60fd9981fd02c310d773faa77dc33ce53f509185 Mon Sep 17 00:00:00 2001
From: Joren Dumoulin
Date: Wed, 5 Feb 2025 11:27:03 +0100
Subject: [PATCH 06/23] dialects: (linalg) add hidden region to transpose op
(#3838)
This adds a hidden region to the linalg.transpose op to ensure correct
generic printing
Also changes permutation to a property instead of attribute.
This resolves the transpose op in #2959
This has now been checked manually, and will be put in ci with #3837
(but for that 3 other ops need to be fixed, PRs incoming...)
---
.../with-mlir/dialects/linalg/invalid_ops.mlir | 6 +++---
xdsl/dialects/linalg.py | 13 +++++++++++--
2 files changed, 14 insertions(+), 5 deletions(-)
diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/invalid_ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/invalid_ops.mlir
index 23def88c2a..7b13db5546 100644
--- a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/invalid_ops.mlir
+++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/invalid_ops.mlir
@@ -15,7 +15,7 @@ builtin.module {
%0, %1 = "test.op"() : () -> (tensor<16x64xf32>, tensor<64x16x1xf32>)
// CHECK: Operation does not verify: Input rank (2) does not match output rank (3)
- %res_transpose = "linalg.transpose"(%0, %1) {"permutation" = array<i64: 1, 0>} : (tensor<16x64xf32>, tensor<64x16x1xf32>) -> tensor<64x16x1xf32>
+ %res_transpose = linalg.transpose ins(%0 : tensor<16x64xf32>) outs(%1 : tensor<64x16x1xf32>) permutation = [1, 0]
}
@@ -25,7 +25,7 @@ builtin.module {
%0, %1 = "test.op"() : () -> (tensor<16x64xf32>, tensor<64x16xf32>)
// CHECK: Operation does not verify: Input rank (2) does not match size of permutation (3)
- %res_transpose = "linalg.transpose"(%0, %1) {"permutation" = array<i64: 1, 2, 3>} : (tensor<16x64xf32>, tensor<64x16xf32>) -> tensor<64x16xf32>
+ %res_transpose = linalg.transpose ins(%0 : tensor<16x64xf32>) outs(%1 : tensor<64x16xf32>) permutation = [1, 2, 3]
}
@@ -35,7 +35,7 @@ builtin.module {
%0, %1 = "test.op"() : () -> (tensor<16x32x64xf32>, tensor<32x64x16xf32>)
// CHECK: Operation does not verify: dim(result, 1) = 64 doesn't match dim(input, permutation[1]) = 32
- %res_transpose = "linalg.transpose"(%0, %1) {"permutation" = array<i64: 1, 1, 2>} : (tensor<16x32x64xf32>, tensor<32x64x16xf32>) -> tensor<32x64x16xf32>
+ %res_transpose = linalg.transpose ins(%0 : tensor<16x32x64xf32>) outs(%1 : tensor<32x64x16xf32>) permutation = [1, 1, 2]
}
diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py
index 9d3d4a6d87..f73c715a03 100644
--- a/xdsl/dialects/linalg.py
+++ b/xdsl/dialects/linalg.py
@@ -732,7 +732,9 @@ class TransposeOp(IRDLOperation):
init = operand_def(base(MemRefType) | base(AnyTensorType))
result = var_result_def(AnyTensorType)
- permutation = attr_def(DenseArrayBase)
+ hidden_region = region_def("single_block")
+
+ permutation = prop_def(DenseArrayBase)
def __init__(
self,
@@ -741,12 +743,19 @@ def __init__(
permutation: Attribute,
result: Attribute | None = None,
):
+ arg_types = NamedOpBase.body_arg_types((input, init))
+
+ @Builder.implicit_region(arg_types)
+ def hidden_region(args: tuple[BlockArgument, ...]) -> None:
+ YieldOp(args[0])
+
super().__init__(
- attributes={
+ properties={
"permutation": permutation,
},
operands=(input, init),
result_types=(result,),
+ regions=(hidden_region,),
)
def verify_(self) -> None:
From 9eceb8c98abb53228c66de5fd3af59051e87b56c Mon Sep 17 00:00:00 2001
From: Joren Dumoulin
Date: Wed, 5 Feb 2025 11:27:16 +0100
Subject: [PATCH 07/23] dialects: (linalg) let PoolingOpsBase inherit from
NamedOpBase (#3839)
To enable correct printing of the hidden regions in generic printing.
Resolves pooling in #2959 , will be tested in #3837
The constructor of these ops had to change to comply with the
NamedOpBase constructor ordering of arguments
---
tests/interpreters/test_linalg_interpreter.py | 12 ++--
xdsl/dialects/linalg.py | 61 +++++++++----------
2 files changed, 36 insertions(+), 37 deletions(-)
diff --git a/tests/interpreters/test_linalg_interpreter.py b/tests/interpreters/test_linalg_interpreter.py
index 542958d669..8e9407e6f0 100644
--- a/tests/interpreters/test_linalg_interpreter.py
+++ b/tests/interpreters/test_linalg_interpreter.py
@@ -275,14 +275,16 @@ def test_linalg_pooling_nchw_max():
interpreter = Interpreter(ModuleOp([]))
interpreter.register_implementations(LinalgFunctions())
op = linalg.PoolingNchwMaxOp(
- DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
- DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
(
TestSSAValue(TensorType(f32, [1, 1, 4, 4])),
TestSSAValue(TensorType(f32, [2, 2])),
),
(TestSSAValue(TensorType(f32, [1, 1, 3, 3])),),
(TensorType(f32, [1, 1, 3, 3]),),
+ {
+ "dilations": DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
+ "strides": DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
+ },
)
a = ShapedArray(TypedPtr.new_float32(list(range(1, 17))), [1, 1, 4, 4])
b = ShapedArray(
@@ -306,14 +308,16 @@ def test_linalg_pooling_nchw_max_strides_two():
interpreter = Interpreter(ModuleOp([]))
interpreter.register_implementations(LinalgFunctions())
op = linalg.PoolingNchwMaxOp(
- DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
- DenseIntOrFPElementsAttr.tensor_from_list([2], i64, [2]),
(
TestSSAValue(TensorType(f32, [1, 1, 4, 4])),
TestSSAValue(TensorType(f32, [2, 2])),
),
(TestSSAValue(TensorType(f32, [1, 1, 2, 2])),),
(TensorType(f32, [1, 1, 2, 2]),),
+ {
+ "dilations": DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
+ "strides": DenseIntOrFPElementsAttr.tensor_from_list([2], i64, [2]),
+ },
)
a = ShapedArray(
TypedPtr.new_float32([1, 1, 2, 4, 5, 6, 7, 8, 3, 2, 1, 0, 1, 2, 3, 4]),
diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py
index f73c715a03..f11d3c5072 100644
--- a/xdsl/dialects/linalg.py
+++ b/xdsl/dialects/linalg.py
@@ -958,45 +958,14 @@ def hidden_region(args: tuple[BlockArgument, ...]) -> None:
)
-class PoolingOpsBase(IRDLOperation, ABC):
+class PoolingOpsBase(NamedOpBase, ABC):
"""Base class for linalg pooling operations."""
- inputs = var_operand_def()
- outputs = var_operand_def(base(ShapedType))
-
- res = var_result_def(AnyTensorType)
-
- assembly_format = (
- "attr-dict `ins` `(` $inputs `:` type($inputs) `)` ` ` "
- "`outs` `(` $outputs `:` type($outputs) `)` `->` type($res)"
- )
+ PRINT_ATTRS_IN_FRONT: ClassVar[bool] = True
strides = attr_def(DenseIntOrFPElementsAttr)
dilations = attr_def(DenseIntOrFPElementsAttr)
- irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()]
-
- def __init__(
- self,
- dilations: Attribute,
- strides: Attribute,
- inputs: Sequence[SSAValue],
- outputs: Sequence[SSAValue] = (),
- res: Sequence[Attribute] | None = None,
- ):
- if res is None:
- result_types = tuple(output.type for output in outputs)
- else:
- result_types = res
- super().__init__(
- attributes={
- "dilations": dilations,
- "strides": strides,
- },
- operands=(inputs, outputs),
- result_types=result_types,
- )
-
@irdl_op_definition
class PoolingNchwMaxOp(PoolingOpsBase):
@@ -1008,6 +977,32 @@ class PoolingNchwMaxOp(PoolingOpsBase):
name = "linalg.pooling_nchw_max"
+ def __init__(
+ self,
+ inputs: Sequence[SSAValue],
+ outputs: Sequence[SSAValue] = (),
+ res: Sequence[Attribute] | None = None,
+ attributes: dict[str, Attribute] | None = None,
+ ):
+ arg_types = self.body_arg_types((*inputs, *outputs))
+
+ max_op = (
+ arith.MaximumfOp if isinstance(arg_types[-1], AnyFloat) else arith.MaxSIOp
+ )
+
+ @Builder.implicit_region(arg_types)
+ def hidden_region(args: tuple[BlockArgument, ...]) -> None:
+ result = max_op(args[0], args[1])
+ YieldOp(result)
+
+ super().__init__(
+ ins=inputs,
+ outs=outputs,
+ result_types=res,
+ attributes=attributes,
+ hidden_region=hidden_region,
+ )
+
class ConvOpsBase(IRDLOperation, ABC):
"""Base class for linalg convolution operations."""
From a216535b78599a722896c76867736ecee0af5113 Mon Sep 17 00:00:00 2001
From: Joren Dumoulin
Date: Wed, 5 Feb 2025 11:27:27 +0100
Subject: [PATCH 08/23] dialects: (linalg) add hidden region to BroadcastOp
(#3840)
Resolves pooling in https://github.com/xdslproject/xdsl/issues/2959 ,
will be tested in https://github.com/xdslproject/xdsl/pull/3837
---
.../with-mlir/dialects/linalg/invalid_ops.mlir | 6 +++---
xdsl/dialects/linalg.py | 9 +++++++++
2 files changed, 12 insertions(+), 3 deletions(-)
diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/invalid_ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/invalid_ops.mlir
index 7b13db5546..ba4197c4b1 100644
--- a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/invalid_ops.mlir
+++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/invalid_ops.mlir
@@ -44,7 +44,7 @@ builtin.module {
builtin.module {
%0, %1 = "test.op"() : () -> (tensor<16xf32>, tensor<32x64x16xf32>)
// CHECK: Operation does not verify: Input rank plus added dimensions (2) does not match output rank (3)
- %res_transpose = "linalg.broadcast"(%0, %1) {"dimensions" = array} : (tensor<16xf32>, tensor<32x64x16xf32>) -> tensor<32x64x16xf32>
+ %res_broadcast = linalg.broadcast ins(%0 : tensor<16xf32>) outs(%1 : tensor<32x64x16xf32>) dimensions = [1]
}
@@ -53,7 +53,7 @@ builtin.module {
builtin.module {
%0, %1 = "test.op"() : () -> (tensor<16xf32>, tensor<16x64xf32>)
// CHECK: Operation does not verify: Dimension 0 is out of range. Expected range: [0, 1], got: 9
- %res_transpose = "linalg.broadcast"(%0, %1) {"dimensions" = array} : (tensor<16xf32>, tensor<16x64xf32>) -> tensor<16x64xf32>
+ %res_broadcast = linalg.broadcast ins(%0 : tensor<16xf32>) outs(%1 : tensor<16x64xf32>) dimensions = [9]
}
@@ -62,6 +62,6 @@ builtin.module {
builtin.module {
%0, %1 = "test.op"() : () -> (tensor<3x4x5xf32>, tensor<4x5x6x2xf32>)
// CHECK: Operation does not verify: input dimension 0 should match output dimension 0. input: 3, output: 4
- %res_transpose = "linalg.broadcast"(%0, %1) {"dimensions" = array} : (tensor<3x4x5xf32>, tensor<4x5x6x2xf32>) -> tensor<4x5x6x2xf32>
+ %res_broadcast = linalg.broadcast ins(%0 : tensor<3x4x5xf32>) outs(%1 : tensor<4x5x6x2xf32>) dimensions = [1]
}
diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py
index f11d3c5072..3655d4d6cb 100644
--- a/xdsl/dialects/linalg.py
+++ b/xdsl/dialects/linalg.py
@@ -1069,6 +1069,8 @@ class BroadcastOp(IRDLOperation):
init = operand_def(base(MemRefType) | base(AnyTensorType))
result = var_result_def(AnyTensorType)
+ hidden_region = region_def("single_block")
+
dimensions = attr_def(DenseArrayBase)
def __init__(
@@ -1078,12 +1080,19 @@ def __init__(
dimensions: Attribute,
result: Attribute | None = None,
):
+ arg_types = NamedOpBase.body_arg_types((input, init))
+
+ @Builder.implicit_region(arg_types)
+ def hidden_region(args: tuple[BlockArgument, ...]) -> None:
+ YieldOp(args[0])
+
super().__init__(
attributes={
"dimensions": dimensions,
},
operands=(input, init),
result_types=(result,),
+ regions=(hidden_region,),
)
def verify_(self) -> None:
From 41b16cdf48611c85035648c203cb4a151373d393 Mon Sep 17 00:00:00 2001
From: Joren Dumoulin
Date: Wed, 5 Feb 2025 11:27:39 +0100
Subject: [PATCH 09/23] dialects: (linalg) let ConvOpsBase inherit from
NamedOpBase (#3841)
This enables correct printing of the hidden regions in generic format
Resolves conv ops in #2959, will be tested with #3837
---
.../with-mlir/dialects/linalg/ops.mlir | 12 ++++-
tests/interpreters/test_linalg_interpreter.py | 6 ++-
xdsl/dialects/linalg.py | 52 ++++++++++---------
3 files changed, 43 insertions(+), 27 deletions(-)
diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir
index 9d942b95d9..6b788a6e3e 100644
--- a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir
+++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir
@@ -59,7 +59,7 @@ linalg.fill ins(%4 : f32) outs(%1 : memref<1x256xf32>)
%18, %19 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>)
%20 = "test.op"() : () -> (memref<64x4096xf32>)
-%zero = arith.constant 0: f32
+%zero = arith.constant 0.0 : f32
linalg.fill {id} ins(%zero : f32) outs(%20 : memref<64x4096xf32>)
linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%20 : memref<64x4096xf32>)
@@ -72,6 +72,14 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou
%quant_mat_mul = linalg.quantized_matmul ins(%21, %22, %23, %24 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%25 : tensor<64x4096xi32>) -> tensor<64x4096xi32>
+%26, %27, %28 = "test.op"(): () -> (tensor<1x1x5x5xi8>, tensor<1x1x3x3xi8>, tensor<1x1x3x3xi32>)
+
+%conv_2d_nchw_i = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%26, %27: tensor<1x1x5x5xi8>, tensor<1x1x3x3xi8>)
+ outs(%28: tensor<1x1x3x3xi32>) -> tensor<1x1x3x3xi32>
+
+
+
// CHECK-NEXT: #map = affine_map<(d0, d1) -> ()>
// CHECK-NEXT: #map1 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-NEXT: module {
@@ -117,4 +125,6 @@ linalg.matmul {id} ins(%18, %19 : memref<64x9216xf32>, memref<9216x4096xf32>) ou
// CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32
// CHECK-NEXT: %19 = "test.op"() : () -> tensor<64x4096xi32>
// CHECK-NEXT: %20 = linalg.quantized_matmul ins(%18#0, %18#1, %c0_i32, %c0_i32_1 : tensor<64x9216xi8>, tensor<9216x4096xi8>, i32, i32) outs(%19 : tensor<64x4096xi32>) -> tensor<64x4096xi32>
+// CHECK-NEXT: %21:3 = "test.op"() : () -> (tensor<1x1x5x5xi8>, tensor<1x1x3x3xi8>, tensor<1x1x3x3xi32>)
+// CHECK-NEXT: %22 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%21#0, %21#1 : tensor<1x1x5x5xi8>, tensor<1x1x3x3xi8>) outs(%21#2 : tensor<1x1x3x3xi32>) -> tensor<1x1x3x3xi32>
// CHECK-NEXT: }
diff --git a/tests/interpreters/test_linalg_interpreter.py b/tests/interpreters/test_linalg_interpreter.py
index 8e9407e6f0..5dbfa5068f 100644
--- a/tests/interpreters/test_linalg_interpreter.py
+++ b/tests/interpreters/test_linalg_interpreter.py
@@ -341,14 +341,16 @@ def test_linalg_conv_2d_nchw_fchw():
interpreter = Interpreter(ModuleOp([]))
interpreter.register_implementations(LinalgFunctions())
op = linalg.Conv2DNchwFchwOp(
- DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
- DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
(
TestSSAValue(TensorType(f32, [1, 1, 5, 5])),
TestSSAValue(TensorType(f32, [1, 1, 3, 3])),
),
(TestSSAValue(TensorType(f32, [1, 1, 3, 3])),),
(TensorType(f32, [1, 1, 3, 3]),),
+ {
+ "dilations": DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
+ "strides": DenseIntOrFPElementsAttr.tensor_from_list([1], i64, [2]),
+ },
)
a = ShapedArray(TypedPtr.new_float32(list(range(25))), [1, 1, 5, 5])
b = ShapedArray(
diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py
index 3655d4d6cb..0dfdc70d65 100644
--- a/xdsl/dialects/linalg.py
+++ b/xdsl/dialects/linalg.py
@@ -1004,43 +1004,47 @@ def hidden_region(args: tuple[BlockArgument, ...]) -> None:
)
-class ConvOpsBase(IRDLOperation, ABC):
+class ConvOpsBase(NamedOpBase, ABC):
"""Base class for linalg convolution operations."""
- inputs = var_operand_def()
- outputs = var_operand_def(base(ShapedType))
-
- res = var_result_def(AnyTensorType)
-
- assembly_format = (
- "attr-dict `ins` `(` $inputs `:` type($inputs) `)` ` ` "
- "`outs` `(` $outputs `:` type($outputs) `)` `->` type($res)"
- )
+ PRINT_ATTRS_IN_FRONT: ClassVar[bool] = True
strides = attr_def(DenseIntOrFPElementsAttr)
dilations = attr_def(DenseIntOrFPElementsAttr)
- irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()]
-
def __init__(
self,
- dilations: Attribute,
- strides: Attribute,
inputs: Sequence[SSAValue],
outputs: Sequence[SSAValue] = (),
res: Sequence[Attribute] | None = None,
+ attributes: dict[str, Attribute] | None = None,
):
- if res is None:
- result_types = tuple(output.type for output in outputs)
- else:
- result_types = res
+ arg_types = self.body_arg_types((*inputs, *outputs))
+ add, mul = (
+ (arith.AddfOp, arith.MulfOp)
+ if isinstance(arg_types[-1], AnyFloat)
+ else (arith.AddiOp, arith.MuliOp)
+ )
+
+ @Builder.implicit_region(arg_types)
+ def hidden_region(args: tuple[BlockArgument, ...]) -> None:
+ if arg_types[0] != arg_types[-1]:
+ assert isinstance(arg_types[-1], IntegerType)
+ a = arith.ExtSIOp(args[0], arg_types[-1])
+ b = arith.ExtSIOp(args[1], arg_types[-1])
+ else:
+ a = args[0]
+ b = args[1]
+ result = mul(a, b)
+ mac = add(result, args[2])
+ YieldOp(mac)
+
super().__init__(
- attributes={
- "dilations": dilations,
- "strides": strides,
- },
- operands=(inputs, outputs),
- result_types=result_types,
+ ins=inputs,
+ outs=outputs,
+ attributes=attributes,
+ result_types=res,
+ hidden_region=hidden_region,
)
From 75fbbe0afb8791b441403792d663f3895f8829e8 Mon Sep 17 00:00:00 2001
From: Joren Dumoulin
Date: Wed, 5 Feb 2025 11:49:27 +0100
Subject: [PATCH 10/23] dialects: (linalg) enable generic printing in mlir
conversion filecheck (#3837)
This PR adds generic printing to the mlir conversion filecheck
This will check whether the issues posed in #2959 are correctly
resolved, by checking if mlir correctly parses the generic output of
xdsl
There are still remaining issues, solved in the following PRs:
(stacked on: ) #3838, #3839, #3840, #3841
---
.../filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir | 1 +
1 file changed, 1 insertion(+)
diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir
index 6b788a6e3e..74b801d4fe 100644
--- a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir
+++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/ops.mlir
@@ -1,4 +1,5 @@
// RUN: xdsl-opt %s | xdsl-opt | mlir-opt --allow-unregistered-dialect | filecheck %s
+// RUN: xdsl-opt %s | xdsl-opt --print-op-generic | mlir-opt --allow-unregistered-dialect | filecheck %s
%0, %1 = "test.op"() : () -> (f32, memref<1x256xf32>)
From 617303c00f884475840564ee1ac2c6b73427a579 Mon Sep 17 00:00:00 2001
From: Alex Rice
Date: Wed, 5 Feb 2025 12:19:09 +0000
Subject: [PATCH 11/23] dialects: (builtin) remove AnyIntegerAttrConstr (#3842)
`IntegerAttr` and `IntegerAttr.constr()` are sufficient
---
tests/irdl/test_declarative_assembly_format.py | 5 ++---
xdsl/dialects/arith.py | 3 +--
xdsl/dialects/builtin.py | 1 -
xdsl/dialects/csl/csl.py | 3 +--
xdsl/transforms/lower_csl_wrapper.py | 2 +-
5 files changed, 5 insertions(+), 9 deletions(-)
diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py
index 42c8e9ca42..109ee39341 100644
--- a/tests/irdl/test_declarative_assembly_format.py
+++ b/tests/irdl/test_declarative_assembly_format.py
@@ -12,7 +12,6 @@
from xdsl.dialects import test
from xdsl.dialects.builtin import (
I32,
- AnyIntegerAttrConstr,
BoolAttr,
Float64Type,
FloatAttr,
@@ -2951,7 +2950,7 @@ class OptConstantOp(IRDLOperation):
name = "test.opt_constant"
T: ClassVar = VarConstraint("T", AnyAttr())
- value = opt_prop_def(TypedAttributeConstraint(AnyIntegerAttrConstr, T))
+ value = opt_prop_def(TypedAttributeConstraint(IntegerAttr.constr(), T))
res = opt_result_def(T)
@@ -2984,7 +2983,7 @@ class DefaultConstantOp(IRDLOperation):
T: ClassVar = VarConstraint("T", AnyAttr())
value = prop_def(
- TypedAttributeConstraint(AnyIntegerAttrConstr, T),
+ TypedAttributeConstraint(IntegerAttr.constr(), T),
default_value=BoolAttr.from_bool(True),
)
diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py
index 6850f33ffc..a5a7ffd0f6 100644
--- a/xdsl/dialects/arith.py
+++ b/xdsl/dialects/arith.py
@@ -8,7 +8,6 @@
AnyFloat,
AnyFloatConstr,
AnyIntegerAttr,
- AnyIntegerAttrConstr,
ContainerOf,
DenseIntOrFPElementsAttr,
Float16Type,
@@ -132,7 +131,7 @@ class ConstantOp(IRDLOperation):
result = result_def(_T)
value = prop_def(
TypedAttributeConstraint(
- AnyIntegerAttrConstr
+ IntegerAttr.constr()
| BaseAttr[FloatAttr[AnyFloat]](FloatAttr)
| BaseAttr(DenseIntOrFPElementsAttr),
_T,
diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py
index 83048aa7ad..c2d6622004 100644
--- a/xdsl/dialects/builtin.py
+++ b/xdsl/dialects/builtin.py
@@ -783,7 +783,6 @@ def unpack(
AnyIntegerAttr: TypeAlias = IntegerAttr[IntegerType | IndexType]
-AnyIntegerAttrConstr: BaseAttr[AnyIntegerAttr] = BaseAttr(IntegerAttr)
BoolAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(1)]]
diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py
index 606e72174d..bc15deaec9 100644
--- a/xdsl/dialects/csl/csl.py
+++ b/xdsl/dialects/csl/csl.py
@@ -20,7 +20,6 @@
AnyFloatAttr,
AnyFloatAttrConstr,
AnyIntegerAttr,
- AnyIntegerAttrConstr,
ArrayAttr,
BoolAttr,
ContainerType,
@@ -420,7 +419,7 @@ def get_element_type(self) -> TypeAttribute:
QueueIdAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(3)]]
ParamAttr: TypeAlias = AnyFloatAttr | AnyIntegerAttr
-ParamAttrConstr = AnyFloatAttrConstr | AnyIntegerAttrConstr
+ParamAttrConstr = AnyFloatAttrConstr | IntegerAttr.constr()
@irdl_op_definition
diff --git a/xdsl/transforms/lower_csl_wrapper.py b/xdsl/transforms/lower_csl_wrapper.py
index 428fad1ab3..f58cefba1d 100644
--- a/xdsl/transforms/lower_csl_wrapper.py
+++ b/xdsl/transforms/lower_csl_wrapper.py
@@ -55,7 +55,7 @@ def _collect_params(
params = list[SSAValue]()
for param in op.params:
- if isattr(param.value, builtin.AnyIntegerAttrConstr):
+ if isattr(param.value, builtin.IntegerAttr):
value = arith.ConstantOp(param.value)
else:
value = None
From 51db69f8e12b70076d19bb304274806368fb360d Mon Sep 17 00:00:00 2001
From: Alex Rice
Date: Wed, 5 Feb 2025 12:31:54 +0000
Subject: [PATCH 12/23] dialects: (builtin) remove AnyIntegerAttr (#3843)
s/AnyIntegerAttr/IntegerAttr/g
---
tests/test_parser.py | 3 +-
tests/test_traits.py | 3 +-
xdsl/dialects/accfg.py | 9 +--
xdsl/dialects/affine.py | 5 +-
xdsl/dialects/arith.py | 45 ++++++------
xdsl/dialects/builtin.py | 11 ++-
xdsl/dialects/csl/csl.py | 3 +-
xdsl/dialects/csl/csl_stencil.py | 8 +-
xdsl/dialects/csl/csl_wrapper.py | 3 +-
xdsl/dialects/experimental/air.py | 4 +-
xdsl/dialects/experimental/fir.py | 7 +-
xdsl/dialects/experimental/hlfir.py | 3 +-
xdsl/dialects/llvm.py | 11 ++-
xdsl/dialects/memref.py | 9 +--
xdsl/dialects/mod_arith.py | 4 +-
xdsl/dialects/riscv.py | 33 ++++-----
xdsl/dialects/riscv_func.py | 3 +-
xdsl/dialects/seq.py | 5 +-
xdsl/dialects/transform.py | 35 +++++----
xdsl/dialects/x86/assembly.py | 9 +--
xdsl/dialects/x86/ops.py | 73 +++++++++----------
xdsl/interpreters/arith.py | 6 +-
xdsl/interpreters/builtin.py | 6 +-
xdsl/interpreters/riscv.py | 3 +-
xdsl/parser/attribute_parser.py | 3 +-
xdsl/tools/tblgen_to_py.py | 2 +-
.../canonicalization_patterns/cf.py | 11 ++-
.../canonicalization_patterns/csl.py | 6 +-
.../canonicalization_patterns/utils.py | 4 +-
xdsl/transforms/linalg_to_csl.py | 4 +-
30 files changed, 155 insertions(+), 176 deletions(-)
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 2abf201384..3e660f6960 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -7,7 +7,6 @@
from xdsl.context import MLContext
from xdsl.dialects.builtin import (
AnyFloatAttr,
- AnyIntegerAttr,
ArrayAttr,
Builtin,
DictionaryAttr,
@@ -823,7 +822,7 @@ def test_parse_number(
],
)
def test_parse_optional_builtin_int_or_float_attr(
- text: str, expected_value: AnyIntegerAttr | AnyFloatAttr | None
+ text: str, expected_value: IntegerAttr | AnyFloatAttr | None
):
parser = Parser(MLContext(), text)
if expected_value is None:
diff --git a/tests/test_traits.py b/tests/test_traits.py
index 7121a1470b..2c599fd579 100644
--- a/tests/test_traits.py
+++ b/tests/test_traits.py
@@ -14,7 +14,6 @@
from xdsl.dialects import test
from xdsl.dialects.builtin import (
DYNAMIC_INDEX,
- AnyIntegerAttr,
AnyTensorTypeConstr,
AnyUnrankedMemRefTypeConstr,
AnyUnrankedTensorTypeConstr,
@@ -306,7 +305,7 @@ class NoSymNameOp(IRDLOperation):
class SymNameWrongTypeOp(IRDLOperation):
name = "wrong_sym_name_type"
- sym_name = attr_def(AnyIntegerAttr)
+ sym_name = attr_def(IntegerAttr)
traits = traits_def(SymbolOpInterface())
op1 = SymNameWrongTypeOp(
diff --git a/xdsl/dialects/accfg.py b/xdsl/dialects/accfg.py
index 92921fa216..17f7fc5fe9 100644
--- a/xdsl/dialects/accfg.py
+++ b/xdsl/dialects/accfg.py
@@ -5,7 +5,6 @@
from typing import cast
from xdsl.dialects.builtin import (
- AnyIntegerAttr,
ArrayAttr,
DictionaryAttr,
IntegerAttr,
@@ -422,16 +421,16 @@ def verify_(self) -> None:
def field_names(self) -> tuple[str, ...]:
return tuple(self.fields.data.keys())
- def field_items(self) -> Iterable[tuple[str, AnyIntegerAttr]]:
+ def field_items(self) -> Iterable[tuple[str, IntegerAttr]]:
for name, val in self.fields.data.items():
- yield name, cast(AnyIntegerAttr, val)
+ yield name, cast(IntegerAttr, val)
def launch_field_names(self) -> tuple[str, ...]:
return tuple(self.launch_fields.data.keys())
- def launch_field_items(self) -> Iterable[tuple[str, AnyIntegerAttr]]:
+ def launch_field_items(self) -> Iterable[tuple[str, IntegerAttr]]:
for name, val in self.launch_fields.data.items():
- yield name, cast(AnyIntegerAttr, val)
+ yield name, cast(IntegerAttr, val)
@irdl_op_definition
diff --git a/xdsl/dialects/affine.py b/xdsl/dialects/affine.py
index 7608faf645..f7c42d51c8 100644
--- a/xdsl/dialects/affine.py
+++ b/xdsl/dialects/affine.py
@@ -6,7 +6,6 @@
from xdsl.dialects.builtin import (
AffineMapAttr,
AffineSetAttr,
- AnyIntegerAttr,
ArrayAttr,
ContainerType,
DenseIntOrFPElementsAttr,
@@ -121,7 +120,7 @@ class ForOp(IRDLOperation):
lowerBoundMap = prop_def(AffineMapAttr)
upperBoundMap = prop_def(AffineMapAttr)
- step = prop_def(AnyIntegerAttr)
+ step = prop_def(IntegerAttr)
body = region_def()
@@ -168,7 +167,7 @@ def from_region(
lower_bound: int | AffineMapAttr,
upper_bound: int | AffineMapAttr,
region: Region,
- step: int | AnyIntegerAttr = 1,
+ step: int | IntegerAttr = 1,
) -> ForOp:
if isinstance(lower_bound, int):
lower_bound = AffineMapAttr(
diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py
index a5a7ffd0f6..cd343a5da9 100644
--- a/xdsl/dialects/arith.py
+++ b/xdsl/dialects/arith.py
@@ -7,7 +7,6 @@
from xdsl.dialects.builtin import (
AnyFloat,
AnyFloatConstr,
- AnyIntegerAttr,
ContainerOf,
DenseIntOrFPElementsAttr,
Float16Type,
@@ -145,7 +144,7 @@ class ConstantOp(IRDLOperation):
@overload
def __init__(
self,
- value: AnyIntegerAttr | FloatAttr[AnyFloat] | DenseIntOrFPElementsAttr,
+ value: IntegerAttr | FloatAttr[AnyFloat] | DenseIntOrFPElementsAttr,
value_type: None = None,
) -> None: ...
@@ -154,11 +153,11 @@ def __init__(self, value: Attribute, value_type: Attribute) -> None: ...
def __init__(
self,
- value: AnyIntegerAttr | FloatAttr[AnyFloat] | Attribute,
+ value: IntegerAttr | FloatAttr[AnyFloat] | Attribute,
value_type: Attribute | None = None,
):
if value_type is None:
- value = cast(AnyIntegerAttr | FloatAttr[AnyFloat], value)
+ value = cast(IntegerAttr | FloatAttr[AnyFloat], value)
value_type = value.type
super().__init__(
operands=[], result_types=[value_type], properties={"value": value}
@@ -207,7 +206,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return None
@staticmethod
- def is_right_zero(attr: AnyIntegerAttr) -> bool:
+ def is_right_zero(attr: IntegerAttr) -> bool:
"""
Returns True only when 'attr' is a right zero for the operation
https://en.wikipedia.org/wiki/Absorbing_element
@@ -218,7 +217,7 @@ def is_right_zero(attr: AnyIntegerAttr) -> bool:
return False
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
"""
Return True only when 'attr' is a right unit/identity for the operation
https://en.wikipedia.org/wiki/Identity_element
@@ -379,7 +378,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs + rhs
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0
@@ -462,11 +461,11 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs * rhs
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)
@staticmethod
- def is_right_zero(attr: AnyIntegerAttr) -> bool:
+ def is_right_zero(attr: IntegerAttr) -> bool:
return attr.value.data == 0
@@ -524,7 +523,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs - rhs
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0
@@ -555,7 +554,7 @@ class DivUIOp(SignlessIntegerBinaryOperation):
)
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)
@@ -574,7 +573,7 @@ class DivSIOp(SignlessIntegerBinaryOperation):
)
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)
@@ -591,7 +590,7 @@ class FloorDivSIOp(SignlessIntegerBinaryOperation):
)
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)
@@ -604,7 +603,7 @@ class CeilDivSIOp(SignlessIntegerBinaryOperation):
)
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)
@@ -618,7 +617,7 @@ class CeilDivUIOp(SignlessIntegerBinaryOperation):
)
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
return attr == IntegerAttr(1, attr.type)
@@ -677,7 +676,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs & rhs
@staticmethod
- def is_right_zero(attr: AnyIntegerAttr) -> bool:
+ def is_right_zero(attr: IntegerAttr) -> bool:
return attr.value.data == 0
@@ -696,7 +695,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs | rhs
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0
@@ -715,7 +714,7 @@ def py_operation(lhs: int, rhs: int) -> int | None:
return lhs ^ rhs
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0
@@ -733,7 +732,7 @@ class ShLIOp(SignlessIntegerBinaryOperationWithOverflow):
)
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0
@@ -752,7 +751,7 @@ class ShRUIOp(SignlessIntegerBinaryOperation):
)
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0
@@ -772,7 +771,7 @@ class ShRSIOp(SignlessIntegerBinaryOperation):
)
@staticmethod
- def is_right_unit(attr: AnyIntegerAttr) -> bool:
+ def is_right_unit(attr: IntegerAttr) -> bool:
return attr.value.data == 0
@@ -853,7 +852,7 @@ class CmpiOp(ComparisonOperation):
"""
name = "arith.cmpi"
- predicate = prop_def(AnyIntegerAttr)
+ predicate = prop_def(IntegerAttr)
lhs = operand_def(signlessIntegerLike)
rhs = operand_def(signlessIntegerLike)
result = result_def(IntegerType(1))
@@ -945,7 +944,7 @@ class CmpfOp(ComparisonOperation):
"""
name = "arith.cmpf"
- predicate = prop_def(AnyIntegerAttr)
+ predicate = prop_def(IntegerAttr)
lhs = operand_def(floatingPointLike)
rhs = operand_def(floatingPointLike)
fastmath = prop_def(FastMathFlagsAttr, default_value=FastMathFlagsAttr("none"))
diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py
index c2d6622004..d856658cfc 100644
--- a/xdsl/dialects/builtin.py
+++ b/xdsl/dialects/builtin.py
@@ -782,7 +782,6 @@ def unpack(
return tuple(IntegerAttr(value, type) for value in type.unpack(buffer, num))
-AnyIntegerAttr: TypeAlias = IntegerAttr[IntegerType | IndexType]
BoolAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(1)]]
@@ -1352,13 +1351,13 @@ def iter_values(self) -> Iterator[float] | Iterator[int]:
def get_values(self) -> tuple[int, ...] | tuple[float, ...]:
return self.elt_type.unpack(self.data.data, len(self))
- def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]:
+ def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[AnyFloatAttr]:
if isinstance(self.elt_type, IntegerType):
return IntegerAttr.iter_unpack(self.elt_type, self.data.data)
else:
return FloatAttr.iter_unpack(self.elt_type, self.data.data)
- def get_attrs(self) -> tuple[AnyIntegerAttr, ...] | tuple[AnyFloatAttr, ...]:
+ def get_attrs(self) -> tuple[IntegerAttr, ...] | tuple[AnyFloatAttr, ...]:
if isinstance(self.elt_type, IntegerType):
return IntegerAttr.unpack(self.elt_type, self.data.data, len(self))
else:
@@ -2180,7 +2179,7 @@ def from_list(
| RankedStructure[IntegerType]
| RankedStructure[IndexType]
),
- data: Sequence[int | float] | Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr],
+ data: Sequence[int | float] | Sequence[IntegerAttr] | Sequence[AnyFloatAttr],
) -> DenseIntOrFPElementsAttr:
# zero rank type should only hold 1 value
if not type.get_shape() and len(data) != 1:
@@ -2249,7 +2248,7 @@ def get_values(self) -> Sequence[int] | Sequence[float]:
"""
return self.get_element_type().unpack(self.data.data, len(self))
- def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]:
+ def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[AnyFloatAttr]:
"""
Return an iterator over all elements of the dense attribute in their relevant
attribute representation (IntegerAttr / FloatAttr)
@@ -2259,7 +2258,7 @@ def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]:
else:
return FloatAttr.iter_unpack(eltype, self.data.data)
- def get_attrs(self) -> Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr]:
+ def get_attrs(self) -> Sequence[IntegerAttr] | Sequence[AnyFloatAttr]:
"""
Return all elements of the dense attribute in their relevant
attribute representation (IntegerAttr / FloatAttr)
diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py
index bc15deaec9..5c48bc9325 100644
--- a/xdsl/dialects/csl/csl.py
+++ b/xdsl/dialects/csl/csl.py
@@ -19,7 +19,6 @@
AffineMapAttr,
AnyFloatAttr,
AnyFloatAttrConstr,
- AnyIntegerAttr,
ArrayAttr,
BoolAttr,
ContainerType,
@@ -418,7 +417,7 @@ def get_element_type(self) -> TypeAttribute:
QueueIdAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(3)]]
-ParamAttr: TypeAlias = AnyFloatAttr | AnyIntegerAttr
+ParamAttr: TypeAlias = AnyFloatAttr | IntegerAttr
ParamAttrConstr = AnyFloatAttrConstr | IntegerAttr.constr()
diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py
index 06d7e19bf8..0f484bd3ac 100644
--- a/xdsl/dialects/csl/csl_stencil.py
+++ b/xdsl/dialects/csl/csl_stencil.py
@@ -5,12 +5,12 @@
from xdsl.dialects import builtin, memref, stencil
from xdsl.dialects.builtin import (
AnyFloat,
- AnyIntegerAttr,
AnyTensorTypeConstr,
Float16Type,
Float32Type,
FloatAttr,
IndexType,
+ IntegerAttr,
MemRefType,
TensorType,
)
@@ -138,7 +138,7 @@ class PrefetchOp(IRDLOperation):
topo = prop_def(dmp.RankTopoAttr)
- num_chunks = prop_def(AnyIntegerAttr)
+ num_chunks = prop_def(IntegerAttr)
result = result_def(MemRefType.constr() | AnyTensorTypeConstr)
@@ -146,7 +146,7 @@ def __init__(
self,
input_stencil: SSAValue | Operation,
topo: dmp.RankTopoAttr,
- num_chunks: AnyIntegerAttr,
+ num_chunks: IntegerAttr,
swaps: Sequence[ExchangeDeclarationAttr],
result_type: memref.MemRefType[Attribute] | TensorType[Attribute] | None = None,
):
@@ -241,7 +241,7 @@ class ApplyOp(IRDLOperation):
topo = prop_def(dmp.RankTopoAttr)
- num_chunks = prop_def(AnyIntegerAttr)
+ num_chunks = prop_def(IntegerAttr)
bounds = opt_prop_def(stencil.StencilBoundsAttr)
diff --git a/xdsl/dialects/csl/csl_wrapper.py b/xdsl/dialects/csl/csl_wrapper.py
index 0e6d658cdf..414a1c96ed 100644
--- a/xdsl/dialects/csl/csl_wrapper.py
+++ b/xdsl/dialects/csl/csl_wrapper.py
@@ -6,7 +6,6 @@
from xdsl.dialects import builtin
from xdsl.dialects.builtin import (
- AnyIntegerAttr,
ArrayAttr,
IntegerAttr,
IntegerType,
@@ -73,7 +72,7 @@ def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
if parser.parse_optional_keyword("default"):
parser.parse_punctuation("=")
val = parser.parse_attribute()
- assert isa(val, AnyIntegerAttr)
+ assert isa(val, IntegerAttr)
assert isinstance(val.type, IntegerType)
type = val.type
else:
diff --git a/xdsl/dialects/experimental/air.py b/xdsl/dialects/experimental/air.py
index f87b2b1795..fe30103925 100644
--- a/xdsl/dialects/experimental/air.py
+++ b/xdsl/dialects/experimental/air.py
@@ -12,10 +12,10 @@
from xdsl.dialects.builtin import (
AnyArrayAttr,
- AnyIntegerAttr,
ArrayAttr,
IndexType,
IntAttr,
+ IntegerAttr,
MemRefType,
StringAttr,
SymbolRefAttr,
@@ -93,7 +93,7 @@ class ChannelOp(IRDLOperation):
size = prop_def(AnyArrayAttr)
def __init__(
- self, sym_name: SymbolRefAttr, size: ArrayAttr[AnyIntegerAttr]
+ self, sym_name: SymbolRefAttr, size: ArrayAttr[IntegerAttr]
): # TODO: add verify to check 64-bit integer array attribute
super().__init__(properties={"sym_name": sym_name, "size": size})
diff --git a/xdsl/dialects/experimental/fir.py b/xdsl/dialects/experimental/fir.py
index fc053a6b73..c1ce5d7287 100644
--- a/xdsl/dialects/experimental/fir.py
+++ b/xdsl/dialects/experimental/fir.py
@@ -18,7 +18,6 @@
from xdsl.dialects.arith import FastMathFlagsAttr
from xdsl.dialects.builtin import (
AnyFloat,
- AnyIntegerAttr,
ArrayAttr,
IndexType,
IntAttr,
@@ -203,7 +202,7 @@ class LLVMPointerType(ParametrizedAttribute, TypeAttribute):
name = "fir.llvm_ptr"
- type: ParameterDef[AnyIntegerAttr | AnyFloat]
+ type: ParameterDef[IntegerAttr | AnyFloat]
@irdl_attr_definition
@@ -227,7 +226,7 @@ class SequenceType(ParametrizedAttribute, TypeAttribute):
"""
name = "fir.array"
- shape: ParameterDef[ArrayAttr[AnyIntegerAttr | DeferredAttr | NoneType]]
+ shape: ParameterDef[ArrayAttr[IntegerAttr | DeferredAttr | NoneType]]
type: ParameterDef[IntegerType | AnyFloat | ReferenceType]
type2: ParameterDef[IntegerType | AnyFloat | ReferenceType | NoneType]
@@ -1259,7 +1258,7 @@ class DispatchOp(IRDLOperation):
"""
name = "fir.dispatch"
- pass_arg_pos = opt_prop_def(AnyIntegerAttr)
+ pass_arg_pos = opt_prop_def(IntegerAttr)
object = operand_def()
args = operand_def()
result_0 = result_def()
diff --git a/xdsl/dialects/experimental/hlfir.py b/xdsl/dialects/experimental/hlfir.py
index 45f2dc14b6..1eac938cff 100644
--- a/xdsl/dialects/experimental/hlfir.py
+++ b/xdsl/dialects/experimental/hlfir.py
@@ -15,7 +15,6 @@
from xdsl.dialects.arith import FastMathFlagsAttr
from xdsl.dialects.builtin import (
AnyFloat,
- AnyIntegerAttr,
ArrayAttr,
Attribute,
BoolAttr,
@@ -62,7 +61,7 @@ class ExprType(ParametrizedAttribute, TypeAttribute):
"""
name = "hlfir.expr"
- shape: ParameterDef[ArrayAttr[AnyIntegerAttr | DeferredAttr | NoneType]]
+ shape: ParameterDef[ArrayAttr[IntegerAttr | DeferredAttr | NoneType]]
elementType: ParameterDef[IntegerType | AnyFloat | ReferenceType]
def print_parameters(self, printer: Printer) -> None:
diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py
index 0799faddfe..9bab9224b8 100644
--- a/xdsl/dialects/llvm.py
+++ b/xdsl/dialects/llvm.py
@@ -9,7 +9,6 @@
from xdsl.dialects.builtin import (
I64,
AnyFloatConstr,
- AnyIntegerAttr,
ArrayAttr,
ContainerType,
DenseArrayBase,
@@ -1072,7 +1071,7 @@ class AllocaOp(IRDLOperation):
size = operand_def(IntegerType)
- alignment = opt_prop_def(AnyIntegerAttr)
+ alignment = opt_prop_def(IntegerAttr)
elem_type = opt_prop_def(Attribute)
res = result_def()
@@ -1352,9 +1351,9 @@ class GlobalOp(IRDLOperation):
thread_local_ = opt_prop_def(UnitAttr)
visibility_ = opt_prop_def(IntegerAttr[IntegerType])
value = opt_prop_def(Attribute)
- alignment = opt_prop_def(AnyIntegerAttr)
- addr_space = prop_def(AnyIntegerAttr)
- unnamed_addr = opt_prop_def(AnyIntegerAttr)
+ alignment = opt_prop_def(IntegerAttr)
+ addr_space = prop_def(IntegerAttr)
+ unnamed_addr = opt_prop_def(IntegerAttr)
section = opt_prop_def(StringAttr)
# This always needs an empty region as it is in the top level module definition
@@ -1586,7 +1585,7 @@ def parse(cls, parser: Parser):
def print(self, printer: Printer) -> None:
printer.print("(")
- if isattr(self.value, AnyIntegerAttr) and self.result.type == IntegerType(64):
+ if isattr(self.value, IntegerAttr) and self.result.type == IntegerType(64):
self.value.print_without_type(printer)
else:
printer.print(self.value)
diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py
index 3a8f87eea6..20b02b3cb5 100644
--- a/xdsl/dialects/memref.py
+++ b/xdsl/dialects/memref.py
@@ -9,7 +9,6 @@
from xdsl.dialects.builtin import (
I64,
AnyFloatConstr,
- AnyIntegerAttr,
ArrayAttr,
BoolAttr,
DenseArrayBase,
@@ -160,7 +159,7 @@ class AllocOp(IRDLOperation):
memref = result_def(MemRefType[Attribute])
# TODO how to constraint the IntegerAttr type?
- alignment = opt_prop_def(AnyIntegerAttr)
+ alignment = opt_prop_def(IntegerAttr)
irdl_options = [AttrSizedOperandSegments(as_property=True)]
@@ -183,7 +182,7 @@ def __init__(
def get(
cls,
return_type: Attribute,
- alignment: int | AnyIntegerAttr | None = None,
+ alignment: int | IntegerAttr | None = None,
shape: Iterable[int | IntAttr] | None = None,
dynamic_sizes: Sequence[SSAValue | Operation] | None = None,
layout: MemRefLayoutAttr | NoneAttr = NoneAttr(),
@@ -314,14 +313,14 @@ class AllocaOp(IRDLOperation):
memref = result_def(MemRefType[Attribute])
# TODO how to constraint the IntegerAttr type?
- alignment = opt_prop_def(AnyIntegerAttr)
+ alignment = opt_prop_def(IntegerAttr)
irdl_options = [AttrSizedOperandSegments(as_property=True)]
@staticmethod
def get(
return_type: Attribute,
- alignment: int | AnyIntegerAttr | None = None,
+ alignment: int | IntegerAttr | None = None,
shape: Iterable[int | IntAttr] | None = None,
dynamic_sizes: Sequence[SSAValue | Operation] | None = None,
layout: MemRefLayoutAttr | NoneAttr = NoneAttr(),
diff --git a/xdsl/dialects/mod_arith.py b/xdsl/dialects/mod_arith.py
index faffbe3b87..89d74bca85 100644
--- a/xdsl/dialects/mod_arith.py
+++ b/xdsl/dialects/mod_arith.py
@@ -7,7 +7,7 @@
from typing import ClassVar
from xdsl.dialects.arith import signlessIntegerLike
-from xdsl.dialects.builtin import AnyIntegerAttr
+from xdsl.dialects.builtin import IntegerAttr
from xdsl.ir import Attribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
IRDLOperation,
@@ -32,7 +32,7 @@ class BinaryOp(IRDLOperation, ABC):
lhs = operand_def(T)
rhs = operand_def(T)
output = result_def(T)
- modulus = prop_def(AnyIntegerAttr)
+ modulus = prop_def(IntegerAttr)
irdl_options = [ParsePropInAttrDict()]
diff --git a/xdsl/dialects/riscv.py b/xdsl/dialects/riscv.py
index 7f9efaf437..8130e234f6 100644
--- a/xdsl/dialects/riscv.py
+++ b/xdsl/dialects/riscv.py
@@ -14,7 +14,6 @@
)
from xdsl.backend.register_type import RegisterType
from xdsl.dialects.builtin import (
- AnyIntegerAttr,
IndexType,
IntegerAttr,
IntegerType,
@@ -443,7 +442,7 @@ def print_op_type(self, printer: Printer) -> None:
AssemblyInstructionArg: TypeAlias = (
- AnyIntegerAttr | LabelAttr | SSAValue | IntRegisterType | str | int
+ IntegerAttr | LabelAttr | SSAValue | IntRegisterType | str | int
)
@@ -501,7 +500,7 @@ def _append_comment(line: str, comment: StringAttr | None) -> str:
def _assembly_arg_str(arg: AssemblyInstructionArg) -> str:
- if isa(arg, AnyIntegerAttr):
+ if isa(arg, IntegerAttr):
return f"{arg.value.data}"
elif isinstance(arg, int):
return f"{arg}"
@@ -661,7 +660,7 @@ class RdImmIntegerOperation(RISCVCustomFormatOperation, RISCVInstruction, ABC):
def __init__(
self,
- immediate: int | AnyIntegerAttr | str | LabelAttr,
+ immediate: int | IntegerAttr | str | LabelAttr,
*,
rd: IntRegisterType | str | None = None,
comment: str | StringAttr | None = None,
@@ -1159,13 +1158,13 @@ class CsrReadWriteOperation(RISCVCustomFormatOperation, RISCVInstruction, ABC):
rd = result_def(IntRegisterType)
rs1 = operand_def(IntRegisterType)
- csr = attr_def(AnyIntegerAttr)
+ csr = attr_def(IntegerAttr)
writeonly = opt_attr_def(UnitAttr)
def __init__(
self,
rs1: Operation | SSAValue,
- csr: AnyIntegerAttr,
+ csr: IntegerAttr,
*,
writeonly: bool = False,
rd: IntRegisterType | str | None = None,
@@ -1237,13 +1236,13 @@ class CsrBitwiseOperation(RISCVCustomFormatOperation, RISCVInstruction, ABC):
rd = result_def(IntRegisterType)
rs1 = operand_def(IntRegisterType)
- csr = attr_def(AnyIntegerAttr)
+ csr = attr_def(IntegerAttr)
readonly = opt_attr_def(UnitAttr)
def __init__(
self,
rs1: Operation | SSAValue,
- csr: AnyIntegerAttr,
+ csr: IntegerAttr,
*,
readonly: bool = False,
rd: IntRegisterType | str | None = None,
@@ -1312,14 +1311,14 @@ class CsrReadWriteImmOperation(RISCVCustomFormatOperation, RISCVInstruction, ABC
"""
rd = result_def(IntRegisterType)
- csr = attr_def(AnyIntegerAttr)
- immediate = attr_def(AnyIntegerAttr)
+ csr = attr_def(IntegerAttr)
+ immediate = attr_def(IntegerAttr)
writeonly = opt_attr_def(UnitAttr)
def __init__(
self,
- csr: AnyIntegerAttr,
- immediate: AnyIntegerAttr,
+ csr: IntegerAttr,
+ immediate: IntegerAttr,
*,
writeonly: bool = False,
rd: IntRegisterType | str | None = None,
@@ -1394,13 +1393,13 @@ class CsrBitwiseImmOperation(RISCVCustomFormatOperation, RISCVInstruction, ABC):
"""
rd = result_def(IntRegisterType)
- csr = attr_def(AnyIntegerAttr)
- immediate = attr_def(AnyIntegerAttr)
+ csr = attr_def(IntegerAttr)
+ immediate = attr_def(IntegerAttr)
def __init__(
self,
- csr: AnyIntegerAttr,
- immediate: AnyIntegerAttr,
+ csr: IntegerAttr,
+ immediate: IntegerAttr,
*,
rd: IntRegisterType | str | None = None,
comment: str | StringAttr | None = None,
@@ -3999,7 +3998,7 @@ def parse_immediate_value(
)
-def print_immediate_value(printer: Printer, immediate: AnyIntegerAttr | LabelAttr):
+def print_immediate_value(printer: Printer, immediate: IntegerAttr | LabelAttr):
match immediate:
case IntegerAttr():
printer.print(immediate.value.data)
diff --git a/xdsl/dialects/riscv_func.py b/xdsl/dialects/riscv_func.py
index 854fa13e9d..a66d293c67 100644
--- a/xdsl/dialects/riscv_func.py
+++ b/xdsl/dialects/riscv_func.py
@@ -4,7 +4,6 @@
from xdsl.dialects import riscv
from xdsl.dialects.builtin import (
- AnyIntegerAttr,
FunctionType,
IntegerAttr,
IntegerType,
@@ -48,7 +47,7 @@ class SyscallOp(IRDLOperation):
def __init__(
self,
- num: int | AnyIntegerAttr,
+ num: int | IntegerAttr,
has_result: bool = False,
operands: list[SSAValue | Operation] = [],
):
diff --git a/xdsl/dialects/seq.py b/xdsl/dialects/seq.py
index 199c32c4ca..1fa49621e2 100644
--- a/xdsl/dialects/seq.py
+++ b/xdsl/dialects/seq.py
@@ -8,7 +8,6 @@
from typing import ClassVar
from xdsl.dialects.builtin import (
- AnyIntegerAttr,
IntegerAttr,
IntegerType,
TypeAttribute,
@@ -53,11 +52,11 @@ class ClockDividerOp(IRDLOperation):
name = "seq.clock_div"
- pow2 = attr_def(AnyIntegerAttr)
+ pow2 = attr_def(IntegerAttr)
clockIn = operand_def(ClockType)
clockOut = result_def(ClockType)
- def __init__(self, clockIn: SSAValue | Operation, pow2: int | AnyIntegerAttr):
+ def __init__(self, clockIn: SSAValue | Operation, pow2: int | IntegerAttr):
if isinstance(pow2, int):
pow2 = IntegerAttr(pow2, IntegerType(8))
super().__init__(
diff --git a/xdsl/dialects/transform.py b/xdsl/dialects/transform.py
index 7f22996303..c2b16ee5cd 100644
--- a/xdsl/dialects/transform.py
+++ b/xdsl/dialects/transform.py
@@ -5,7 +5,6 @@
from typing import Annotated, TypeAlias
from xdsl.dialects.builtin import (
- AnyIntegerAttr,
ArrayAttr,
DenseArrayBase,
DictionaryAttr,
@@ -157,7 +156,7 @@ class GetConsumersOfResultOp(IRDLOperation):
name = "transform.get_consumers_of_result"
- result_number = prop_def(AnyIntegerAttr)
+ result_number = prop_def(IntegerAttr)
target = operand_def(TransformOpHandleType)
consumers = result_def(TransformOpHandleType)
@@ -200,7 +199,7 @@ class GetParentOp(IRDLOperation):
allow_empty_results = opt_prop_def(UnitAttr)
op_name = opt_prop_def(StringAttr)
deduplicate = opt_prop_def(UnitAttr)
- nth_parent = prop_def(AnyIntegerAttr)
+ nth_parent = prop_def(IntegerAttr)
target = operand_def(TransformOpHandleType)
parent_result = result_def(TransformOpHandleType)
@@ -211,7 +210,7 @@ def __init__(
allow_empty_results: bool = False,
op_name: str | None = None,
deduplicate: bool = False,
- nth_parent: int | AnyIntegerAttr = 1,
+ nth_parent: int | IntegerAttr = 1,
):
if isinstance(nth_parent, int):
nth_parent = IntegerAttr(nth_parent, IntegerType(64))
@@ -236,13 +235,13 @@ class GetProducerOfOperandOp(IRDLOperation):
name = "transform.get_producer_of_operand"
- operand_number = prop_def(AnyIntegerAttr)
+ operand_number = prop_def(IntegerAttr)
target = operand_def(TransformOpHandleType)
producer = result_def(TransformOpHandleType)
def __init__(
self,
- operand_number: int | AnyIntegerAttr,
+ operand_number: int | IntegerAttr,
target: SSAValue,
):
if isinstance(operand_number, int):
@@ -326,7 +325,7 @@ class IncludeOp(IRDLOperation):
def __init__(
self,
target: str,
- failure_propagation_mode: FailurePropagationModeAttr | AnyIntegerAttr | int,
+ failure_propagation_mode: FailurePropagationModeAttr | IntegerAttr | int,
operands_input: Sequence[SSAValue],
):
if isinstance(failure_propagation_mode, int):
@@ -395,13 +394,13 @@ class MatchParamCmpIOp(IRDLOperation):
name = "transform.match.param.cmpi"
predicate = prop_def(
- AnyIntegerAttr
+ IntegerAttr
) # Valid values given in xdsl/xdsl/dialects/arith.py
param = operand_def(TransformParamHandleType)
reference = operand_def(TransformParamHandleType)
def __init__(
- self, predicate: int | AnyIntegerAttr, param: SSAValue, reference: SSAValue
+ self, predicate: int | IntegerAttr, param: SSAValue, reference: SSAValue
):
if isinstance(predicate, int):
predicate = IntegerAttr(predicate, IntegerType(64))
@@ -456,9 +455,9 @@ class SplitHandleOp(IRDLOperation):
name = "transform.split_handle"
- pass_through_empty_handle = prop_def(AnyIntegerAttr)
- fail_on_payload_too_small = prop_def(AnyIntegerAttr)
- overflow_result = opt_prop_def(AnyIntegerAttr)
+ pass_through_empty_handle = prop_def(IntegerAttr)
+ fail_on_payload_too_small = prop_def(IntegerAttr)
+ overflow_result = opt_prop_def(IntegerAttr)
handle = operand_def(TransformHandleType)
results_ = var_result_def(TransformHandleType)
@@ -466,9 +465,9 @@ def __init__(
self,
handle: SSAValue,
number_of_results: int,
- pass_through_empty_handle: int | AnyIntegerAttr | bool = False,
- fail_on_payload_too_small: int | AnyIntegerAttr | bool = False,
- overflow_result: int | AnyIntegerAttr | None = None,
+ pass_through_empty_handle: int | IntegerAttr | bool = False,
+ fail_on_payload_too_small: int | IntegerAttr | bool = False,
+ overflow_result: int | IntegerAttr | None = None,
):
if isinstance(pass_through_empty_handle, bool):
pass_through_empty_handle = IntegerAttr(
@@ -528,7 +527,7 @@ class SequenceOp(IRDLOperation):
def __init__(
self,
- failure_propagation_mode: FailurePropagationModeAttr | AnyIntegerAttr | int,
+ failure_propagation_mode: FailurePropagationModeAttr | IntegerAttr | int,
root: Sequence[SSAValue],
extra_bindings: Sequence[SSAValue],
body: Region,
@@ -766,7 +765,7 @@ class MatchOp(IRDLOperation):
name = "transform.structured.match"
ops = opt_prop_def(ArrayAttr[StringAttr])
- interface = opt_prop_def(AnyIntegerAttr)
+ interface = opt_prop_def(IntegerAttr)
op_attrs = opt_prop_def(DictionaryAttr)
filter_result_types = opt_prop_def(TypeAttribute)
filter_operand_types = opt_prop_def(TypeAttribute)
@@ -778,7 +777,7 @@ def __init__(
self,
target: SSAValue,
ops: Sequence[str] | ArrayAttr[StringAttr] | None = None,
- interface: int | AnyIntegerAttr | str | None = None,
+ interface: int | IntegerAttr | str | None = None,
op_attrs: dict[str, Attribute] | DictionaryAttr | None = None,
filter_result_types: TypeAttribute | None = None,
filter_operand_types: TypeAttribute | None = None,
diff --git a/xdsl/dialects/x86/assembly.py b/xdsl/dialects/x86/assembly.py
index 2d507dfd99..9f350ce15f 100644
--- a/xdsl/dialects/x86/assembly.py
+++ b/xdsl/dialects/x86/assembly.py
@@ -3,7 +3,6 @@
from typing import TypeAlias
from xdsl.dialects.builtin import (
- AnyIntegerAttr,
IndexType,
IntegerAttr,
IntegerType,
@@ -20,7 +19,7 @@
from .register import AVXRegisterType, GeneralRegisterType, RFLAGSRegisterType
AssemblyInstructionArg: TypeAlias = (
- AnyIntegerAttr | SSAValue | GeneralRegisterType | str | int | LabelAttr
+ IntegerAttr | SSAValue | GeneralRegisterType | str | int | LabelAttr
)
@@ -34,7 +33,7 @@ def append_comment(line: str, comment: StringAttr | None) -> str:
def assembly_arg_str(arg: AssemblyInstructionArg) -> str:
- if isa(arg, AnyIntegerAttr):
+ if isa(arg, IntegerAttr):
return f"{arg.value.data}"
elif isinstance(arg, int):
return f"{arg}"
@@ -97,7 +96,7 @@ def parse_optional_immediate_value(
return LabelAttr(immediate)
-def print_immediate_value(printer: Printer, immediate: AnyIntegerAttr | LabelAttr):
+def print_immediate_value(printer: Printer, immediate: IntegerAttr | LabelAttr):
match immediate:
case IntegerAttr():
printer.print(immediate.value.data)
@@ -106,7 +105,7 @@ def print_immediate_value(printer: Printer, immediate: AnyIntegerAttr | LabelAtt
def memory_access_str(
- register: AssemblyInstructionArg, offset: AnyIntegerAttr | None
+ register: AssemblyInstructionArg, offset: IntegerAttr | None
) -> str:
register_str = assembly_arg_str(register)
if offset is not None:
diff --git a/xdsl/dialects/x86/ops.py b/xdsl/dialects/x86/ops.py
index cf526b8d8a..6e6eb6140f 100644
--- a/xdsl/dialects/x86/ops.py
+++ b/xdsl/dialects/x86/ops.py
@@ -8,7 +8,6 @@
from typing_extensions import Self
from xdsl.dialects.builtin import (
- AnyIntegerAttr,
IntegerAttr,
IntegerType,
ModuleOp,
@@ -554,7 +553,7 @@ class R_RM_Operation(Generic[R1InvT, R2InvT], IRDLOperation, X86Instruction, ABC
r1 = operand_def(R1InvT)
r2 = operand_def(R2InvT)
- offset = opt_attr_def(AnyIntegerAttr)
+ offset = opt_attr_def(IntegerAttr)
result = result_def(R1InvT)
@@ -562,7 +561,7 @@ def __init__(
self,
r1: Operation | SSAValue,
r2: Operation | SSAValue,
- offset: int | AnyIntegerAttr | None = None,
+ offset: int | IntegerAttr | None = None,
*,
comment: str | StringAttr | None = None,
result: R1InvT,
@@ -713,14 +712,14 @@ class R_RI_Operation(Generic[R1InvT], IRDLOperation, X86Instruction, ABC):
"""
r1 = operand_def(R1InvT)
- immediate = attr_def(AnyIntegerAttr)
+ immediate = attr_def(IntegerAttr)
result = result_def(R1InvT)
def __init__(
self,
r1: Operation | SSAValue,
- immediate: int | AnyIntegerAttr,
+ immediate: int | IntegerAttr,
*,
comment: str | StringAttr | None = None,
result: R1InvT,
@@ -845,13 +844,13 @@ class M_MR_Operation(Generic[R1InvT, R2InvT], IRDLOperation, X86Instruction, ABC
r1 = operand_def(R1InvT)
r2 = operand_def(R2InvT)
- offset = opt_attr_def(AnyIntegerAttr)
+ offset = opt_attr_def(IntegerAttr)
def __init__(
self,
r1: Operation | SSAValue,
r2: Operation | SSAValue,
- offset: int | AnyIntegerAttr | None,
+ offset: int | IntegerAttr | None,
*,
comment: str | StringAttr | None = None,
):
@@ -964,14 +963,14 @@ class M_MI_Operation(Generic[R1InvT], IRDLOperation, X86Instruction, ABC):
"""
r1 = operand_def(R1InvT)
- immediate = attr_def(AnyIntegerAttr)
- offset = opt_attr_def(AnyIntegerAttr)
+ immediate = attr_def(IntegerAttr)
+ offset = opt_attr_def(IntegerAttr)
def __init__(
self,
r1: Operation | SSAValue,
- offset: int | AnyIntegerAttr | None,
- immediate: int | AnyIntegerAttr,
+ offset: int | IntegerAttr | None,
+ immediate: int | IntegerAttr,
*,
comment: str | StringAttr | None = None,
):
@@ -1099,14 +1098,14 @@ class R_RRI_Operation(Generic[R1InvT, R2InvT], IRDLOperation, X86Instruction, AB
"""
r2 = operand_def(R2InvT)
- immediate = attr_def(AnyIntegerAttr)
+ immediate = attr_def(IntegerAttr)
r1 = result_def(R1InvT)
def __init__(
self,
r2: Operation | SSAValue,
- immediate: int | AnyIntegerAttr,
+ immediate: int | IntegerAttr,
*,
comment: str | StringAttr | None = None,
r1: R1InvT,
@@ -1160,16 +1159,16 @@ class R_RMI_Operation(Generic[R1InvT, R2InvT], IRDLOperation, X86Instruction, AB
"""
r2 = operand_def(R2InvT)
- immediate = attr_def(AnyIntegerAttr)
- offset = opt_attr_def(AnyIntegerAttr)
+ immediate = attr_def(IntegerAttr)
+ offset = opt_attr_def(IntegerAttr)
r1 = result_def(R1InvT)
def __init__(
self,
r2: Operation | SSAValue,
- immediate: int | AnyIntegerAttr,
- offset: int | AnyIntegerAttr | None,
+ immediate: int | IntegerAttr,
+ offset: int | IntegerAttr | None,
*,
comment: str | StringAttr | None = None,
r1: R1InvT,
@@ -1243,7 +1242,7 @@ class M_PushOp(IRDLOperation, X86Instruction):
rsp_input = operand_def(GeneralRegisterType("rsp"))
source = operand_def(R1InvT)
- offset = opt_attr_def(AnyIntegerAttr)
+ offset = opt_attr_def(IntegerAttr)
rsp_output = result_def(GeneralRegisterType("rsp"))
def __init__(
@@ -1252,7 +1251,7 @@ def __init__(
source: Operation | SSAValue,
*,
comment: str | StringAttr | None = None,
- offset: int | AnyIntegerAttr | None,
+ offset: int | IntegerAttr | None,
rsp_output: GeneralRegisterType,
):
if isinstance(comment, str):
@@ -1305,7 +1304,7 @@ class M_PopOp(IRDLOperation, X86Instruction):
destination = operand_def(
GeneralRegisterType
) # the destination is a pointer to the memory location and the register itself is not modified
- offset = opt_attr_def(AnyIntegerAttr)
+ offset = opt_attr_def(IntegerAttr)
rsp_output = result_def(GeneralRegisterType("rsp"))
def __init__(
@@ -1314,7 +1313,7 @@ def __init__(
destination: Operation | SSAValue,
*,
comment: str | StringAttr | None = None,
- offset: int | AnyIntegerAttr | None = None,
+ offset: int | IntegerAttr | None = None,
rsp_output: GeneralRegisterType,
):
if isinstance(offset, int):
@@ -1358,12 +1357,12 @@ class M_M_Operation(Generic[R1InvT], IRDLOperation, X86Instruction, ABC):
"""
source = operand_def(R1InvT)
- offset = opt_attr_def(AnyIntegerAttr)
+ offset = opt_attr_def(IntegerAttr)
def __init__(
self,
source: Operation | SSAValue,
- offset: int | AnyIntegerAttr | None,
+ offset: int | IntegerAttr | None,
*,
comment: str | StringAttr | None = None,
):
@@ -1462,7 +1461,7 @@ class M_IDivOp(IRDLOperation, X86Instruction):
r1 = operand_def(R1InvT)
rdx_input = operand_def(GeneralRegisterType("rdx"))
rax_input = operand_def(GeneralRegisterType("rax"))
- offset = opt_attr_def(AnyIntegerAttr)
+ offset = opt_attr_def(IntegerAttr)
rdx_output = result_def(GeneralRegisterType("rdx"))
rax_output = result_def(GeneralRegisterType("rax"))
@@ -1472,7 +1471,7 @@ def __init__(
r1: Operation | SSAValue,
rdx_input: Operation | SSAValue,
rax_input: Operation | SSAValue,
- offset: int | AnyIntegerAttr | None = None,
+ offset: int | IntegerAttr | None = None,
*,
comment: str | StringAttr | None = None,
rdx_output: GeneralRegisterType,
@@ -1525,7 +1524,7 @@ class M_ImulOp(IRDLOperation, X86Instruction):
r1 = operand_def(GeneralRegisterType)
rax_input = operand_def(GeneralRegisterType("rax"))
- offset = opt_attr_def(AnyIntegerAttr)
+ offset = opt_attr_def(IntegerAttr)
rdx_output = result_def(GeneralRegisterType("rdx"))
rax_output = result_def(GeneralRegisterType("rax"))
@@ -1534,7 +1533,7 @@ def __init__(
self,
r1: Operation | SSAValue,
rax_input: Operation | SSAValue,
- offset: int | AnyIntegerAttr | None = None,
+ offset: int | IntegerAttr | None = None,
*,
comment: str | StringAttr | None = None,
rdx_output: GeneralRegisterType,
@@ -1815,7 +1814,7 @@ class RM_CmpOp(IRDLOperation, X86Instruction):
r1 = operand_def(GeneralRegisterType)
r2 = operand_def(GeneralRegisterType)
- offset = opt_attr_def(AnyIntegerAttr)
+ offset = opt_attr_def(IntegerAttr)
result = result_def(RFLAGSRegisterType)
@@ -1823,7 +1822,7 @@ def __init__(
self,
r1: Operation | SSAValue,
r2: Operation | SSAValue,
- offset: int | AnyIntegerAttr | None,
+ offset: int | IntegerAttr | None,
*,
comment: str | StringAttr | None = None,
result: RFLAGSRegisterType,
@@ -1873,14 +1872,14 @@ class RI_CmpOp(IRDLOperation, X86Instruction):
name = "x86.ri.cmp"
r1 = operand_def(GeneralRegisterType)
- immediate = attr_def(AnyIntegerAttr)
+ immediate = attr_def(IntegerAttr)
result = result_def(RFLAGSRegisterType)
def __init__(
self,
r1: Operation | SSAValue,
- immediate: int | AnyIntegerAttr,
+ immediate: int | IntegerAttr,
*,
comment: str | StringAttr | None = None,
result: RFLAGSRegisterType,
@@ -1926,7 +1925,7 @@ class MR_CmpOp(IRDLOperation, X86Instruction):
r1 = operand_def(GeneralRegisterType)
r2 = operand_def(GeneralRegisterType)
- offset = opt_attr_def(AnyIntegerAttr)
+ offset = opt_attr_def(IntegerAttr)
result = result_def(RFLAGSRegisterType)
@@ -1934,7 +1933,7 @@ def __init__(
self,
r1: Operation | SSAValue,
r2: Operation | SSAValue,
- offset: int | AnyIntegerAttr | None,
+ offset: int | IntegerAttr | None,
*,
comment: str | StringAttr | None = None,
result: RFLAGSRegisterType,
@@ -1984,16 +1983,16 @@ class MI_CmpOp(IRDLOperation, X86Instruction):
name = "x86.mi.cmp"
r1 = operand_def(GeneralRegisterType)
- immediate = attr_def(AnyIntegerAttr)
- offset = opt_attr_def(AnyIntegerAttr)
+ immediate = attr_def(IntegerAttr)
+ offset = opt_attr_def(IntegerAttr)
result = result_def(RFLAGSRegisterType)
def __init__(
self,
r1: Operation | SSAValue,
- offset: int | AnyIntegerAttr | None,
- immediate: int | AnyIntegerAttr,
+ offset: int | IntegerAttr | None,
+ immediate: int | IntegerAttr,
*,
comment: str | StringAttr | None = None,
result: RFLAGSRegisterType,
diff --git a/xdsl/interpreters/arith.py b/xdsl/interpreters/arith.py
index eba3fbb3d8..ff158df829 100644
--- a/xdsl/interpreters/arith.py
+++ b/xdsl/interpreters/arith.py
@@ -2,7 +2,7 @@
from typing import cast
from xdsl.dialects import arith
-from xdsl.dialects.builtin import AnyFloatAttr, AnyIntegerAttr
+from xdsl.dialects.builtin import AnyFloatAttr, IntegerAttr
from xdsl.interpreter import (
Interpreter,
InterpreterFunctions,
@@ -23,10 +23,10 @@ def run_constant(
) -> PythonValues:
value = op.value
interpreter.interpreter_assert(
- isattr(op.value, base(AnyIntegerAttr) | base(AnyFloatAttr)),
+ isattr(op.value, base(IntegerAttr) | base(AnyFloatAttr)),
f"arith.constant not implemented for {type(op.value)}",
)
- value = cast(AnyIntegerAttr, op.value)
+ value = cast(IntegerAttr, op.value)
return (value.value.data,)
@impl(arith.SubiOp)
diff --git a/xdsl/interpreters/builtin.py b/xdsl/interpreters/builtin.py
index 1d94b8d05b..1cfa5062e8 100644
--- a/xdsl/interpreters/builtin.py
+++ b/xdsl/interpreters/builtin.py
@@ -3,9 +3,9 @@
from xdsl.dialects import builtin
from xdsl.dialects.builtin import (
AnyFloatAttr,
- AnyIntegerAttr,
Float32Type,
Float64Type,
+ IntegerAttr,
IntegerType,
PackableType,
UnrealizedConversionCastOp,
@@ -75,8 +75,8 @@ def float32_attr_value(
def integer_attr_value(
self, interpreter: Interpreter, attr: Attribute, attr_type: IntegerType
) -> float:
- interpreter.interpreter_assert(isa(attr, AnyIntegerAttr))
- attr = cast(AnyIntegerAttr, attr)
+ interpreter.interpreter_assert(isa(attr, IntegerAttr))
+ attr = cast(IntegerAttr, attr)
return attr.value.data
@impl_attr(builtin.MemRefType)
diff --git a/xdsl/interpreters/riscv.py b/xdsl/interpreters/riscv.py
index 2d98872437..fc7e5859c2 100644
--- a/xdsl/interpreters/riscv.py
+++ b/xdsl/interpreters/riscv.py
@@ -5,7 +5,6 @@
from xdsl.dialects import builtin, riscv
from xdsl.dialects.builtin import (
- AnyIntegerAttr,
IndexType,
IntegerAttr,
IntegerType,
@@ -214,7 +213,7 @@ def get_data_value(self, interpreter: Interpreter, key: str) -> Any:
return data[key]
def get_immediate_value(
- self, interpreter: Interpreter, imm: AnyIntegerAttr | riscv.LabelAttr
+ self, interpreter: Interpreter, imm: IntegerAttr | riscv.LabelAttr
) -> int | ptr.RawPtr:
match imm:
case IntegerAttr():
diff --git a/xdsl/parser/attribute_parser.py b/xdsl/parser/attribute_parser.py
index 499470ef86..1dea5b0661 100644
--- a/xdsl/parser/attribute_parser.py
+++ b/xdsl/parser/attribute_parser.py
@@ -18,7 +18,6 @@
AnyFloat,
AnyFloatAttr,
AnyFloatConstr,
- AnyIntegerAttr,
AnyTensorType,
AnyUnrankedTensorType,
AnyVectorType,
@@ -1203,7 +1202,7 @@ def parse_optional_location(self) -> LocationAttr | None:
def parse_optional_builtin_int_or_float_attr(
self,
- ) -> AnyIntegerAttr | AnyFloatAttr | None:
+ ) -> IntegerAttr | AnyFloatAttr | None:
bool = self.try_parse_builtin_boolean_attr()
if bool is not None:
return bool
diff --git a/xdsl/tools/tblgen_to_py.py b/xdsl/tools/tblgen_to_py.py
index 29dc42db7a..2b7ec07458 100644
--- a/xdsl/tools/tblgen_to_py.py
+++ b/xdsl/tools/tblgen_to_py.py
@@ -345,7 +345,7 @@ def _resolve_prop_constraint(self, rec: TblgenRecord | str) -> str:
""")
if (
- "AnyIntegerAttrBase" in rec.superclasses
+ "IntegerAttrBase" in rec.superclasses
or "SignlessIntegerAttrBase" in rec.superclasses
or "SignedIntegerAttrBase" in rec.superclasses
or "UnsignedIntegerAttrBase" in rec.superclasses
diff --git a/xdsl/transforms/canonicalization_patterns/cf.py b/xdsl/transforms/canonicalization_patterns/cf.py
index 0af86049d0..1532645a76 100644
--- a/xdsl/transforms/canonicalization_patterns/cf.py
+++ b/xdsl/transforms/canonicalization_patterns/cf.py
@@ -3,7 +3,6 @@
from xdsl.dialects import arith, cf
from xdsl.dialects.builtin import (
- AnyIntegerAttr,
BoolAttr,
DenseIntOrFPElementsAttr,
IntegerAttr,
@@ -289,7 +288,7 @@ def match_and_rewrite(self, op: cf.SwitchOp, rewriter: PatternRewriter):
def drop_case_helper(
rewriter: PatternRewriter,
op: cf.SwitchOp,
- predicate: Callable[[AnyIntegerAttr, Block, Sequence[Operation | SSAValue]], bool],
+ predicate: Callable[[IntegerAttr, Block, Sequence[Operation | SSAValue]], bool],
):
case_values = op.case_values
if case_values is None:
@@ -306,11 +305,11 @@ def drop_case_helper(
op.case_operand,
strict=True,
):
- int_switch_case = cast(AnyIntegerAttr, switch_case)
+ int_switch_case = cast(IntegerAttr, switch_case)
if predicate(int_switch_case, block, operands):
requires_change = True
continue
- new_case_values.append(cast(AnyIntegerAttr, switch_case).value.data)
+ new_case_values.append(cast(IntegerAttr, switch_case).value.data)
new_case_blocks.append(block)
new_case_operands.append(operands)
@@ -346,7 +345,7 @@ class DropSwitchCasesThatMatchDefault(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: cf.SwitchOp, rewriter: PatternRewriter):
def predicate(
- switch_case: AnyIntegerAttr,
+ switch_case: IntegerAttr,
block: Block,
operands: Sequence[Operation | SSAValue],
) -> bool:
@@ -539,7 +538,7 @@ def match_and_rewrite(self, op: cf.SwitchOp, rewriter: PatternRewriter):
else:
def predicate(
- switch_case: AnyIntegerAttr,
+ switch_case: IntegerAttr,
block: Block,
operands: Sequence[Operation | SSAValue],
) -> bool:
diff --git a/xdsl/transforms/canonicalization_patterns/csl.py b/xdsl/transforms/canonicalization_patterns/csl.py
index e8f1771577..aa1a605a2c 100644
--- a/xdsl/transforms/canonicalization_patterns/csl.py
+++ b/xdsl/transforms/canonicalization_patterns/csl.py
@@ -1,7 +1,7 @@
from xdsl.dialects import arith
from xdsl.dialects.builtin import (
AffineMapAttr,
- AnyIntegerAttr,
+ IntegerAttr,
)
from xdsl.dialects.csl import csl
from xdsl.ir import OpResult
@@ -34,7 +34,7 @@ def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter) -> N
if (
isinstance(offset_op.offset, OpResult)
and isinstance(cnst := offset_op.offset.op, arith.ConstantOp)
- and isa(attr_val := cnst.value, AnyIntegerAttr)
+ and isa(attr_val := cnst.value, IntegerAttr)
):
tensor_access = AffineMap.from_callable(
lambda x: (x + attr_val.value.data,)
@@ -101,7 +101,7 @@ def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter) -> N
if (
isinstance(stride_op.stride, OpResult)
and isinstance(cnst := stride_op.stride.op, arith.ConstantOp)
- and isa(attr_val := cnst.value, AnyIntegerAttr)
+ and isa(attr_val := cnst.value, IntegerAttr)
):
tensor_access = AffineMap.from_callable(
lambda x: (x * attr_val.value.data,)
diff --git a/xdsl/transforms/canonicalization_patterns/utils.py b/xdsl/transforms/canonicalization_patterns/utils.py
index bab4fdb59a..929e4d3210 100644
--- a/xdsl/transforms/canonicalization_patterns/utils.py
+++ b/xdsl/transforms/canonicalization_patterns/utils.py
@@ -1,9 +1,9 @@
from xdsl.dialects import arith
-from xdsl.dialects.builtin import AnyIntegerAttr, IntegerAttr
+from xdsl.dialects.builtin import IntegerAttr
from xdsl.ir import SSAValue
-def const_evaluate_operand_attribute(operand: SSAValue) -> AnyIntegerAttr | None:
+def const_evaluate_operand_attribute(operand: SSAValue) -> IntegerAttr | None:
"""
Try to constant evaluate an SSA value, returning None on failure.
"""
diff --git a/xdsl/transforms/linalg_to_csl.py b/xdsl/transforms/linalg_to_csl.py
index f50c50617d..1816e619c9 100644
--- a/xdsl/transforms/linalg_to_csl.py
+++ b/xdsl/transforms/linalg_to_csl.py
@@ -4,10 +4,10 @@
from xdsl.dialects import arith, linalg
from xdsl.dialects.builtin import (
AnyFloatAttr,
- AnyIntegerAttr,
DenseIntOrFPElementsAttr,
Float16Type,
Float32Type,
+ IntegerAttr,
MemRefType,
ModuleOp,
)
@@ -39,7 +39,7 @@ def match_op_for_precision(
raise ValueError(f"Unsupported element type {prec}")
-def get_scalar_const(op: SSAValue) -> AnyFloatAttr | AnyIntegerAttr | None:
+def get_scalar_const(op: SSAValue) -> AnyFloatAttr | IntegerAttr | None:
"""Returns the value of a scalar arith.constant, or None if not a constant or not scalar)."""
if (
isinstance(op, OpResult)
From 4504a60541f4366ef50877a6cff82a053a2ba51b Mon Sep 17 00:00:00 2001
From: Alex Rice
Date: Wed, 5 Feb 2025 15:19:25 +0000
Subject: [PATCH 13/23] dialects: (builtin) remove AnyFloatAttr(Constr)?
(#3844)
Adds a default type to `FloatAttr` and removes `AnyFloatAttr` and
`AnyFloatAttrConstr`
---
tests/test_parser.py | 26 ++++++++---------
tests/test_printer.py | 3 +-
xdsl/dialects/builtin.py | 28 +++++++++----------
xdsl/dialects/csl/csl.py | 7 ++---
xdsl/interpreters/arith.py | 4 +--
xdsl/interpreters/builtin.py | 10 +++----
xdsl/parser/attribute_parser.py | 3 +-
xdsl/printer.py | 4 +--
.../canonicalization_patterns/arith.py | 12 ++++----
.../convert_stencil_to_csl_stencil.py | 4 +--
xdsl/transforms/linalg_to_csl.py | 4 +--
xdsl/transforms/lower_csl_stencil.py | 3 +-
12 files changed, 51 insertions(+), 57 deletions(-)
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 3e660f6960..04bc768952 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -6,10 +6,10 @@
from xdsl.context import MLContext
from xdsl.dialects.builtin import (
- AnyFloatAttr,
ArrayAttr,
Builtin,
DictionaryAttr,
+ FloatAttr,
IntAttr,
IntegerAttr,
IntegerType,
@@ -807,22 +807,22 @@ def test_parse_number(
("24: i32", IntegerAttr(24, 32)),
("0: index", IntegerAttr.from_index_int_value(0)),
("-64: i64", IntegerAttr(-64, 64)),
- ("-64.4: f64", AnyFloatAttr(-64.4, 64)),
- ("32.4: f32", AnyFloatAttr(32.4, 32)),
- ("0x7e00 : f16", AnyFloatAttr(float("nan"), 16)),
- ("0x7c00 : f16", AnyFloatAttr(float("inf"), 16)),
- ("0xfc00 : f16", AnyFloatAttr(float("-inf"), 16)),
- ("0x7fc00000 : f32", AnyFloatAttr(float("nan"), 32)),
- ("0x7f800000 : f32", AnyFloatAttr(float("inf"), 32)),
- ("0xff800000 : f32", AnyFloatAttr(float("-inf"), 32)),
- ("0x7ff8000000000000 : f64", AnyFloatAttr(float("nan"), 64)),
- ("0x7ff0000000000000 : f64", AnyFloatAttr(float("inf"), 64)),
- ("0xfff0000000000000 : f64", AnyFloatAttr(float("-inf"), 64)),
+ ("-64.4: f64", FloatAttr(-64.4, 64)),
+ ("32.4: f32", FloatAttr(32.4, 32)),
+ ("0x7e00 : f16", FloatAttr(float("nan"), 16)),
+ ("0x7c00 : f16", FloatAttr(float("inf"), 16)),
+ ("0xfc00 : f16", FloatAttr(float("-inf"), 16)),
+ ("0x7fc00000 : f32", FloatAttr(float("nan"), 32)),
+ ("0x7f800000 : f32", FloatAttr(float("inf"), 32)),
+ ("0xff800000 : f32", FloatAttr(float("-inf"), 32)),
+ ("0x7ff8000000000000 : f64", FloatAttr(float("nan"), 64)),
+ ("0x7ff0000000000000 : f64", FloatAttr(float("inf"), 64)),
+ ("0xfff0000000000000 : f64", FloatAttr(float("-inf"), 64)),
# ("3 : f64", None), # todo this fails in mlir-opt but not in xdsl
],
)
def test_parse_optional_builtin_int_or_float_attr(
- text: str, expected_value: IntegerAttr | AnyFloatAttr | None
+ text: str, expected_value: IntegerAttr | FloatAttr | None
):
parser = Parser(MLContext(), text)
if expected_value is None:
diff --git a/tests/test_printer.py b/tests/test_printer.py
index 9cf49c4c2e..895b3fa58c 100644
--- a/tests/test_printer.py
+++ b/tests/test_printer.py
@@ -11,7 +11,6 @@
from xdsl.dialects.arith import AddiOp, Arith, ConstantOp
from xdsl.dialects.builtin import (
AnyFloat,
- AnyFloatAttr,
Builtin,
FloatAttr,
FunctionType,
@@ -814,7 +813,7 @@ def _test_float_attr(value: float, type: AnyFloat):
def test_float_attr_specials():
printer = Printer()
- def _test_attr_print(expected: str, attr: AnyFloatAttr):
+ def _test_attr_print(expected: str, attr: FloatAttr):
io = StringIO()
printer.stream = io
printer.print_attribute(attr)
diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py
index d856658cfc..cf101e9c1e 100644
--- a/xdsl/dialects/builtin.py
+++ b/xdsl/dialects/builtin.py
@@ -905,7 +905,9 @@ def __hash__(self):
return hash(self.data)
-_FloatAttrType = TypeVar("_FloatAttrType", bound=AnyFloat, covariant=True)
+_FloatAttrType = TypeVar(
+ "_FloatAttrType", bound=AnyFloat, covariant=True, default=AnyFloat
+)
_FloatAttrTypeInvT = TypeVar("_FloatAttrTypeInvT", bound=AnyFloat)
@@ -980,10 +982,6 @@ def unpack(
return tuple(FloatAttr(value, type) for value in type.unpack(buffer, num))
-AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat]
-AnyFloatAttrConstr: BaseAttr[AnyFloatAttr] = BaseAttr(FloatAttr)
-
-
@irdl_attr_definition
class ComplexType(ParametrizedAttribute, TypeAttribute):
name = "complex"
@@ -1351,13 +1349,13 @@ def iter_values(self) -> Iterator[float] | Iterator[int]:
def get_values(self) -> tuple[int, ...] | tuple[float, ...]:
return self.elt_type.unpack(self.data.data, len(self))
- def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[AnyFloatAttr]:
+ def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[FloatAttr]:
if isinstance(self.elt_type, IntegerType):
return IntegerAttr.iter_unpack(self.elt_type, self.data.data)
else:
return FloatAttr.iter_unpack(self.elt_type, self.data.data)
- def get_attrs(self) -> tuple[IntegerAttr, ...] | tuple[AnyFloatAttr, ...]:
+ def get_attrs(self) -> tuple[IntegerAttr, ...] | tuple[FloatAttr, ...]:
if isinstance(self.elt_type, IntegerType):
return IntegerAttr.unpack(self.elt_type, self.data.data, len(self))
else:
@@ -2134,10 +2132,10 @@ def create_dense_int(
@staticmethod
def create_dense_float(
type: RankedStructure[AnyFloat],
- data: Sequence[int | float] | Sequence[AnyFloatAttr],
+ data: Sequence[int | float] | Sequence[FloatAttr],
) -> DenseIntOrFPElementsAttr:
- if len(data) and isa(data[0], AnyFloatAttr):
- data = [el.value.data for el in cast(Sequence[AnyFloatAttr], data)]
+ if len(data) and isa(data[0], FloatAttr):
+ data = [el.value.data for el in cast(Sequence[FloatAttr], data)]
else:
data = cast(Sequence[float], data)
@@ -2168,7 +2166,7 @@ def from_list(
| RankedStructure[IntegerType]
| RankedStructure[IndexType]
),
- data: Sequence[int | float] | Sequence[AnyFloatAttr],
+ data: Sequence[int | float] | Sequence[FloatAttr],
) -> DenseIntOrFPElementsAttr: ...
@staticmethod
@@ -2179,7 +2177,7 @@ def from_list(
| RankedStructure[IntegerType]
| RankedStructure[IndexType]
),
- data: Sequence[int | float] | Sequence[IntegerAttr] | Sequence[AnyFloatAttr],
+ data: Sequence[int | float] | Sequence[IntegerAttr] | Sequence[FloatAttr],
) -> DenseIntOrFPElementsAttr:
# zero rank type should only hold 1 value
if not type.get_shape() and len(data) != 1:
@@ -2228,7 +2226,7 @@ def tensor_from_list(
| Sequence[float]
| Sequence[IntegerAttr[IndexType]]
| Sequence[IntegerAttr[IntegerType]]
- | Sequence[AnyFloatAttr]
+ | Sequence[FloatAttr]
),
data_type: IntegerType | IndexType | AnyFloat,
shape: Sequence[int],
@@ -2248,7 +2246,7 @@ def get_values(self) -> Sequence[int] | Sequence[float]:
"""
return self.get_element_type().unpack(self.data.data, len(self))
- def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[AnyFloatAttr]:
+ def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[FloatAttr]:
"""
Return an iterator over all elements of the dense attribute in their relevant
attribute representation (IntegerAttr / FloatAttr)
@@ -2258,7 +2256,7 @@ def iter_attrs(self) -> Iterator[IntegerAttr] | Iterator[AnyFloatAttr]:
else:
return FloatAttr.iter_unpack(eltype, self.data.data)
- def get_attrs(self) -> Sequence[IntegerAttr] | Sequence[AnyFloatAttr]:
+ def get_attrs(self) -> Sequence[IntegerAttr] | Sequence[FloatAttr]:
"""
Return all elements of the dense attribute in their relevant
attribute representation (IntegerAttr / FloatAttr)
diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py
index 5c48bc9325..e47c5bf2ab 100644
--- a/xdsl/dialects/csl/csl.py
+++ b/xdsl/dialects/csl/csl.py
@@ -17,14 +17,13 @@
from xdsl.dialects import builtin
from xdsl.dialects.builtin import (
AffineMapAttr,
- AnyFloatAttr,
- AnyFloatAttrConstr,
ArrayAttr,
BoolAttr,
ContainerType,
DictionaryAttr,
Float16Type,
Float32Type,
+ FloatAttr,
FunctionType,
IntegerAttr,
IntegerType,
@@ -417,8 +416,8 @@ def get_element_type(self) -> TypeAttribute:
QueueIdAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(3)]]
-ParamAttr: TypeAlias = AnyFloatAttr | IntegerAttr
-ParamAttrConstr = AnyFloatAttrConstr | IntegerAttr.constr()
+ParamAttr: TypeAlias = FloatAttr | IntegerAttr
+ParamAttrConstr = FloatAttr.constr() | IntegerAttr.constr()
@irdl_op_definition
diff --git a/xdsl/interpreters/arith.py b/xdsl/interpreters/arith.py
index ff158df829..7cf6f4b2e6 100644
--- a/xdsl/interpreters/arith.py
+++ b/xdsl/interpreters/arith.py
@@ -2,7 +2,7 @@
from typing import cast
from xdsl.dialects import arith
-from xdsl.dialects.builtin import AnyFloatAttr, IntegerAttr
+from xdsl.dialects.builtin import FloatAttr, IntegerAttr
from xdsl.interpreter import (
Interpreter,
InterpreterFunctions,
@@ -23,7 +23,7 @@ def run_constant(
) -> PythonValues:
value = op.value
interpreter.interpreter_assert(
- isattr(op.value, base(IntegerAttr) | base(AnyFloatAttr)),
+ isattr(op.value, base(IntegerAttr) | base(FloatAttr)),
f"arith.constant not implemented for {type(op.value)}",
)
value = cast(IntegerAttr, op.value)
diff --git a/xdsl/interpreters/builtin.py b/xdsl/interpreters/builtin.py
index 1cfa5062e8..043714dccb 100644
--- a/xdsl/interpreters/builtin.py
+++ b/xdsl/interpreters/builtin.py
@@ -2,9 +2,9 @@
from xdsl.dialects import builtin
from xdsl.dialects.builtin import (
- AnyFloatAttr,
Float32Type,
Float64Type,
+ FloatAttr,
IntegerAttr,
IntegerType,
PackableType,
@@ -59,16 +59,16 @@ def run_cast(
def float64_attr_value(
self, interpreter: Interpreter, attr: Attribute, attr_type: Float64Type
) -> float:
- interpreter.interpreter_assert(isa(attr, AnyFloatAttr))
- attr = cast(AnyFloatAttr, attr)
+ interpreter.interpreter_assert(isa(attr, FloatAttr))
+ attr = cast(FloatAttr, attr)
return attr.value.data
@impl_attr(Float32Type)
def float32_attr_value(
self, interpreter: Interpreter, attr: Attribute, attr_type: Float32Type
) -> float:
- interpreter.interpreter_assert(isa(attr, AnyFloatAttr))
- attr = cast(AnyFloatAttr, attr)
+ interpreter.interpreter_assert(isa(attr, FloatAttr))
+ attr = cast(FloatAttr, attr)
return attr.value.data
@impl_attr(IntegerType)
diff --git a/xdsl/parser/attribute_parser.py b/xdsl/parser/attribute_parser.py
index 1dea5b0661..c8158672f4 100644
--- a/xdsl/parser/attribute_parser.py
+++ b/xdsl/parser/attribute_parser.py
@@ -16,7 +16,6 @@
AnyArrayAttr,
AnyDenseElement,
AnyFloat,
- AnyFloatAttr,
AnyFloatConstr,
AnyTensorType,
AnyUnrankedTensorType,
@@ -1202,7 +1201,7 @@ def parse_optional_location(self) -> LocationAttr | None:
def parse_optional_builtin_int_or_float_attr(
self,
- ) -> IntegerAttr | AnyFloatAttr | None:
+ ) -> IntegerAttr | FloatAttr | None:
bool = self.try_parse_builtin_boolean_attr()
if bool is not None:
return bool
diff --git a/xdsl/printer.py b/xdsl/printer.py
index caa7609032..61091ec13c 100644
--- a/xdsl/printer.py
+++ b/xdsl/printer.py
@@ -12,7 +12,6 @@
AffineMapAttr,
AffineSetAttr,
AnyFloat,
- AnyFloatAttr,
AnyUnrankedMemRefType,
AnyUnrankedTensorType,
AnyVectorType,
@@ -28,6 +27,7 @@
Float64Type,
Float80Type,
Float128Type,
+ FloatAttr,
FunctionType,
IndexType,
IntAttr,
@@ -333,7 +333,7 @@ def print_bytes_literal(self, bytestring: bytes):
self.print_string(chr(byte))
self.print_string('"')
- def print_float_attr(self, attribute: AnyFloatAttr):
+ def print_float_attr(self, attribute: FloatAttr):
self.print_float(attribute.value.data, attribute.type)
def print_float(self, value: float, type: AnyFloat):
diff --git a/xdsl/transforms/canonicalization_patterns/arith.py b/xdsl/transforms/canonicalization_patterns/arith.py
index caa569c529..4b8b3147c2 100644
--- a/xdsl/transforms/canonicalization_patterns/arith.py
+++ b/xdsl/transforms/canonicalization_patterns/arith.py
@@ -50,8 +50,8 @@ def match_and_rewrite(
def _fold_const_operation(
op_t: type[arith.FloatingPointLikeBinaryOperation],
- lhs: builtin.AnyFloatAttr,
- rhs: builtin.AnyFloatAttr,
+ lhs: builtin.FloatAttr,
+ rhs: builtin.FloatAttr,
) -> arith.ConstantOp | None:
match op_t:
case arith.AddfOp:
@@ -88,8 +88,8 @@ def match_and_rewrite(
if (
isinstance(op.lhs.owner, arith.ConstantOp)
and isinstance(op.rhs.owner, arith.ConstantOp)
- and isa(l := op.lhs.owner.value, builtin.AnyFloatAttr)
- and isa(r := op.rhs.owner.value, builtin.AnyFloatAttr)
+ and isa(l := op.lhs.owner.value, builtin.FloatAttr)
+ and isa(r := op.rhs.owner.value, builtin.FloatAttr)
and (cnst := _fold_const_operation(type(op), l, r))
):
rewriter.replace_matched_op(cnst)
@@ -126,8 +126,8 @@ def match_and_rewrite(
or u.fastmath is None
or arith.FastMathFlag.REASSOC not in op.fastmath.flags
or arith.FastMathFlag.REASSOC not in u.fastmath.flags
- or not isa(c1 := const1.value, builtin.AnyFloatAttr)
- or not isa(c2 := const2.value, builtin.AnyFloatAttr)
+ or not isa(c1 := const1.value, builtin.FloatAttr)
+ or not isa(c2 := const2.value, builtin.FloatAttr)
):
return
diff --git a/xdsl/transforms/convert_stencil_to_csl_stencil.py b/xdsl/transforms/convert_stencil_to_csl_stencil.py
index df8c739c04..1cda222cdb 100644
--- a/xdsl/transforms/convert_stencil_to_csl_stencil.py
+++ b/xdsl/transforms/convert_stencil_to_csl_stencil.py
@@ -6,9 +6,9 @@
from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, memref, stencil, tensor, varith
from xdsl.dialects.builtin import (
- AnyFloatAttr,
AnyTensorType,
DenseIntOrFPElementsAttr,
+ FloatAttr,
IndexType,
IntegerAttr,
IntegerType,
@@ -578,7 +578,7 @@ def match_and_rewrite(self, op: csl_stencil.AccessOp, rewriter: PatternRewriter,
return
val = dense.get_attrs()[0]
- assert isattr(val, AnyFloatAttr)
+ assert isattr(val, FloatAttr)
apply.add_coeff(op.offset, val)
rewriter.replace_op(mulf, [], new_results=[op.result])
diff --git a/xdsl/transforms/linalg_to_csl.py b/xdsl/transforms/linalg_to_csl.py
index 1816e619c9..83081f4ab2 100644
--- a/xdsl/transforms/linalg_to_csl.py
+++ b/xdsl/transforms/linalg_to_csl.py
@@ -3,10 +3,10 @@
from xdsl.context import MLContext
from xdsl.dialects import arith, linalg
from xdsl.dialects.builtin import (
- AnyFloatAttr,
DenseIntOrFPElementsAttr,
Float16Type,
Float32Type,
+ FloatAttr,
IntegerAttr,
MemRefType,
ModuleOp,
@@ -39,7 +39,7 @@ def match_op_for_precision(
raise ValueError(f"Unsupported element type {prec}")
-def get_scalar_const(op: SSAValue) -> AnyFloatAttr | IntegerAttr | None:
+def get_scalar_const(op: SSAValue) -> FloatAttr | IntegerAttr | None:
"""Returns the value of a scalar arith.constant, or None if not a constant or not scalar)."""
if (
isinstance(op, OpResult)
diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py
index d2b4b5afd7..9384b9eb33 100644
--- a/xdsl/transforms/lower_csl_stencil.py
+++ b/xdsl/transforms/lower_csl_stencil.py
@@ -5,7 +5,6 @@
from xdsl.dialects import arith, func, memref, stencil
from xdsl.dialects.builtin import (
AffineMapAttr,
- AnyFloatAttr,
DenseIntOrFPElementsAttr,
Float16Type,
Float32Type,
@@ -265,7 +264,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
pattern = wrapper.get_param_value("pattern").value.data
neighbours = pattern - 1
empty = [FloatAttr(f, elem_t) for f in [0] + neighbours * [1]]
- cmap: dict[csl.Direction, list[AnyFloatAttr]] = {
+ cmap: dict[csl.Direction, list[FloatAttr]] = {
csl.Direction.NORTH: empty,
csl.Direction.SOUTH: empty.copy(),
csl.Direction.EAST: empty.copy(),
From 6e51e9be42609544767861d36fe2a7e231640cad Mon Sep 17 00:00:00 2001
From: Alex Rice
Date: Wed, 5 Feb 2025 16:26:57 +0000
Subject: [PATCH 14/23] installation: fix packages discovery (#3847)
Partial revert of 27bf82162f438557520caf3500f8467dae4a016b
---
pyproject.toml | 1 -
1 file changed, 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 413bef5a79..ad44c13015 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,7 +57,6 @@ xdsl-tblgen = "xdsl.tools.tblgen_to_py:main"
[tool.setuptools]
platforms = ["Linux", "Mac OS-X", "Unix"]
zip-safe = false
-packages = ["xdsl"]
[tool.setuptools.package-data]
xdsl = ["**/*.irdl", "py.typed", "interactive/*.tcss", "_version.py"]
From 9f50033b22952698f4f312f3e2fabf26573eb5b3 Mon Sep 17 00:00:00 2001
From: Joren Dumoulin
Date: Wed, 5 Feb 2025 17:38:21 +0100
Subject: [PATCH 15/23] dialects: (builtin) print hex str for
DenseIntOrFPElementsAttrs with > 100 elements (#3846)
mimicking mlir behaviour
This, along with #3845 should allow for very fast printing/parsing of
large dense attrs, allowing for low-overhead calling of intermediate
mlir-opt passes in our pipeline.
---
tests/filecheck/parser-printer/builtin_attrs.mlir | 2 +-
xdsl/dialects/builtin.py | 2 ++
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/tests/filecheck/parser-printer/builtin_attrs.mlir b/tests/filecheck/parser-printer/builtin_attrs.mlir
index f67312a114..838ee8f281 100644
--- a/tests/filecheck/parser-printer/builtin_attrs.mlir
+++ b/tests/filecheck/parser-printer/builtin_attrs.mlir
@@ -84,7 +84,7 @@
"func.func"() ({}) {function_type = () -> (), value = dense<"0xEEA7CC3DF47612BE2BA4173E8B75E8BDE0B915BDA3191CBE8388E0BDC826DB3DFE78273E6B037E3DEF140D3EF0B5803D4026693CD6B6E1BCE08B4DBDC3A9E63D943B163EE64E46BD808C253EB8F4893D30270CBE36696C3D045E1DBED06A703DA33EBBBD66D646BD36507BBD764D8FBD7010FA3DB6E1B53D9B83C8BDD33FA73D58AD293EB0A6123EAB2627BA40B4CB3C20E9B6BD805AB2BDE047BDBC809A743DE01ADD3D9B77D5BDCEE7043E00B8C1BDCBA80A3DBB03DA3D787C993D163968BC208510BDABFDB1BD8C07213EA34614BEAB06B73A0091413B8013B3BD768F193E7B6515BE7306833D363183BC36BC8B3CA016B7BD3E05D33DE67C28BDCABB0EBEDA2A013EA67DF6BD007EB5BA782A04BEAB69F73D16DD703D3B93A43D1BE45B3DEBAEE8BD8891F1BDF8B18F3D20EC923CE67101BE8382A8BDAB9EE7BA0006CA3AA3F224BE1B56A5BDC06B8A3DC3E6BE3D562310BB964B713C2CC11FBE4BC68F3DAEACD7BDFB093A3D00070F3EC3E4C93D5BCF0D3D1B01E13D9B7D7F3D537CD43D6BEDFBBC4BD9AEBD17BA023E569906BB86599CBD4E28073E1639F5BDF60909BE8B4727BEE4AD153EDF3C05BEB01913BEEB1A59BD03E8D4BD4BD3123D9EA381BD6058F03CD0EFF73D00747FBADBC5AEBD5054273E204DB4BD00CA683B1E28C93D3BCC2A3D9B0E683D4302923D9A3408BEABC89D3A565336BCC0A7F3BD76D1F93D68A3B93D44891C3E1685243E1B3FDBBD5E06A4BD2B4192BD2B19983C50C97B3D40A808BEC0994C3D4B3435BD0B88293D506749BDFC13063E2B7ADF3CF3B013BE"> : tensor<4x4x3x3xf32>, sym_name = "hex_f32_large_attr"} : () -> ()
- // CHECK: value = dense<[[[[0.0999296755, -0.143031895, 0.148087189], [-0.113505445, -0.0365542173, -0.152441546], [-0.109635375, 0.107007563, 0.163547486]], [[0.0620149784, 0.137775168, 0.0628470182], [0.0142303109, -0.0275530033, -0.0501822233], [0.112628482, 0.146711648, -0.0484150872]], [[0.161668777, 0.0673612952, -0.136868238], [0.0577175245, -0.153678954, 0.0586956143], [-0.0914280638, -0.04854431, -0.061355792]], [[-0.0699719638, 0.122101665, 0.0888094157], [-0.0979072675, 0.0816647038, 0.165700316], [0.143213987, -0.000637630641, 0.0248662233]]], [[[-0.0893118382, -0.0870866776, -0.0231055617], [0.0597176552, 0.107961416, -0.104232036], [0.129790515, -0.0945892334, 0.0338523798]], [[0.106452428, 0.0749444366, -0.0141737666], [-0.0352832079, -0.0869096145, 0.157255352], [-0.144800708, 0.00139637792, 0.00295358896]], [[-0.087439537, 0.149961323, -0.14589493], [0.0639771447, -0.0160146765, 0.0170575194], [-0.0893986225, 0.103037342, -0.0411347374]], [[-0.139388233, 0.126140028, -0.120356843], [-0.0013846755, -0.129068255, 0.120807014], [0.058804594, 0.0803589448, 0.0536843352]]], [[[-0.11361488, -0.11795336, 0.0701636672], [0.0179348588, -0.126411051, -0.0822801813], [-0.00176711881, 0.00154131651, -0.161081836]], [[-0.0807306394, 0.0675883293, 0.0932135805], [-0.00219937181, 0.0147274937, -0.15601033], [0.0702024326, -0.105309829, 0.0454196744]], [[0.13967514, 0.0985808596, 0.0346215777], [0.10986539, 0.0623756461, 0.103752755], [-0.0307528581, -0.0853753909, 1.276630e-01]], [[-0.00205381727, -0.0763426274, 0.131989688], [-0.119737789, -0.13382706, -0.163358852], [0.146171153, -0.130115017, -0.143652678]]], [[[-0.0530041866, -0.103958152, 0.0358460359], [-0.0633003563, 0.0293390155, 0.121062875], [-0.000974476337, -0.0853383169, 0.163407564]], [[-0.0880377293, 0.0035520792, 0.0982210487], [0.0416986756, 0.0566545539, 0.0712933764], [-0.133013159, 0.00120379531, -0.0111282673]], [[-0.118972301, 0.121981546, 0.0906437039], [0.152867377, 0.160663933, 
-0.107053958], [-0.0800902694, -0.0714133605, 0.0185666885]], [[0.0614712834, -0.133454323, 0.0499513149], [-0.0442393236, 0.0413895063, -0.0491707921], [0.130935609, 0.0272799339, -0.144229695]]]]> : tensor<4x4x3x3xf32>} : () -> ()
+ // CHECK: value = dense<"0xEEA7CC3DF47612BE2BA4173E8B75E8BDE0B915BDA3191CBE8388E0BDC826DB3DFE78273E6B037E3DEF140D3EF0B5803D4026693CD6B6E1BCE08B4DBDC3A9E63D943B163EE64E46BD808C253EB8F4893D30270CBE36696C3D045E1DBED06A703DA33EBBBD66D646BD36507BBD764D8FBD7010FA3DB6E1B53D9B83C8BDD33FA73D58AD293EB0A6123EAB2627BA40B4CB3C20E9B6BD805AB2BDE047BDBC809A743DE01ADD3D9B77D5BDCEE7043E00B8C1BDCBA80A3DBB03DA3D787C993D163968BC208510BDABFDB1BD8C07213EA34614BEAB06B73A0091413B8013B3BD768F193E7B6515BE7306833D363183BC36BC8B3CA016B7BD3E05D33DE67C28BDCABB0EBEDA2A013EA67DF6BD007EB5BA782A04BEAB69F73D16DD703D3B93A43D1BE45B3DEBAEE8BD8891F1BDF8B18F3D20EC923CE67101BE8382A8BDAB9EE7BA0006CA3AA3F224BE1B56A5BDC06B8A3DC3E6BE3D562310BB964B713C2CC11FBE4BC68F3DAEACD7BDFB093A3D00070F3EC3E4C93D5BCF0D3D1B01E13D9B7D7F3D537CD43D6BEDFBBC4BD9AEBD17BA023E569906BB86599CBD4E28073E1639F5BDF60909BE8B4727BEE4AD153EDF3C05BEB01913BEEB1A59BD03E8D4BD4BD3123D9EA381BD6058F03CD0EFF73D00747FBADBC5AEBD5054273E204DB4BD00CA683B1E28C93D3BCC2A3D9B0E683D4302923D9A3408BEABC89D3A565336BCC0A7F3BD76D1F93D68A3B93D44891C3E1685243E1B3FDBBD5E06A4BD2B4192BD2B19983C50C97B3D40A808BEC0994C3D4B3435BD0B88293D506749BDFC13063E2B7ADF3CF3B013BE"> : tensor<4x4x3x3xf32>}
"func.func"() ({}) {function_type = () -> (), value = "foo", sym_name = "string_attr"} : () -> ()
diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py
index cf101e9c1e..24e3b89d04 100644
--- a/xdsl/dialects/builtin.py
+++ b/xdsl/dialects/builtin.py
@@ -2315,6 +2315,8 @@ def print_without_type(self, printer: Printer):
pass
elif self.is_splat():
self._print_one_elem(data[0], printer)
+ elif len(self) > 100:
+ printer.print('"', "0x", self.data.data.hex().upper(), '"')
else:
self._print_dense_list(data, shape, printer)
printer.print_string(">")
From 34ab330e9f4ab7229b7ee9329bddec83040b4e95 Mon Sep 17 00:00:00 2001
From: Emma Urquhart <77412390+emmau678@users.noreply.github.com>
Date: Thu, 6 Feb 2025 10:05:34 +0000
Subject: [PATCH 16/23] core: Fix multiline error printing (#3849)
Fix the MultipleSpanParseError to print multiple lines in the error
output
---------
Co-authored-by: Alex Rice
---
.../parser-printer/graph_region.mlir | 20 +++++++++++++++++++
xdsl/utils/exceptions.py | 7 ++++---
2 files changed, 24 insertions(+), 3 deletions(-)
diff --git a/tests/filecheck/parser-printer/graph_region.mlir b/tests/filecheck/parser-printer/graph_region.mlir
index 59f102121b..222a1a7f5d 100644
--- a/tests/filecheck/parser-printer/graph_region.mlir
+++ b/tests/filecheck/parser-printer/graph_region.mlir
@@ -71,3 +71,23 @@ builtin.module {
}
// CHECK: SSA value %1 is referenced with an index larger than its size
+
+// -----
+
+// A block defined twice
+
+builtin.module {
+ ^blockA:
+ "test.op"() : () -> ()
+ ^blockA:
+ "test.op"() : () -> ()
+}
+
+// CHECK: /graph_region.mlir:72:4
+// CHECK-NEXT: ^blockA:
+// CHECK-NEXT: ^^^^^^^
+// CHECK-NEXT: re-declaration of block 'blockA'
+// CHECK-NEXT: originally declared here:
+// CHECK-NEXT: /graph_region.mlir:6:4
+// CHECK-NEXT: ^blockA:
+// CHECK-NEXT: ^^^^^^^
diff --git a/xdsl/utils/exceptions.py b/xdsl/utils/exceptions.py
index 413204fb7d..b927a54be8 100644
--- a/xdsl/utils/exceptions.py
+++ b/xdsl/utils/exceptions.py
@@ -87,9 +87,10 @@ class MultipleSpansParseError(ParseError):
ref_text: str | None
refs: list[tuple[Span, str | None]]
- def __repr__(self) -> str:
- res = super().__repr__() + "\n"
- res += self.ref_text or "With respect to:\n"
+ def __str__(self) -> str:
+ res = self.span.print_with_context(self.msg)
+ if self.ref_text is not None:
+ res += self.ref_text + "\n"
for span, msg in self.refs:
res += span.print_with_context(msg) + "\n"
return res
From 4088e1b7f2b1916a4cdf8f5e72597e0f970815ca Mon Sep 17 00:00:00 2001
From: emmau678
Date: Thu, 6 Feb 2025 10:22:41 +0000
Subject: [PATCH 17/23] update line numbers in filecheck
---
tests/filecheck/parser-printer/graph_region.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/filecheck/parser-printer/graph_region.mlir b/tests/filecheck/parser-printer/graph_region.mlir
index 222a1a7f5d..a36ea14877 100644
--- a/tests/filecheck/parser-printer/graph_region.mlir
+++ b/tests/filecheck/parser-printer/graph_region.mlir
@@ -83,7 +83,7 @@ builtin.module {
"test.op"() : () -> ()
}
-// CHECK: /graph_region.mlir:72:4
+// CHECK: /graph_region.mlir:82:4
// CHECK-NEXT: ^blockA:
// CHECK-NEXT: ^^^^^^^
// CHECK-NEXT: re-declaration of block 'blockA'
From 2291eb940d9f94182a51de9975307384774bbb74 Mon Sep 17 00:00:00 2001
From: emmau678
Date: Thu, 6 Feb 2025 10:35:00 +0000
Subject: [PATCH 18/23] move check to before the code
---
.../filecheck/parser-printer/graph_region.mlir | 18 ++++++++----------
1 file changed, 8 insertions(+), 10 deletions(-)
diff --git a/tests/filecheck/parser-printer/graph_region.mlir b/tests/filecheck/parser-printer/graph_region.mlir
index a36ea14877..5f7f5a1f06 100644
--- a/tests/filecheck/parser-printer/graph_region.mlir
+++ b/tests/filecheck/parser-printer/graph_region.mlir
@@ -74,7 +74,14 @@ builtin.module {
// -----
-// A block defined twice
+// CHECK: /graph_region.mlir:89:4
+// CHECK-NEXT: ^blockA:
+// CHECK-NEXT: ^^^^^^^
+// CHECK-NEXT: re-declaration of block 'blockA'
+// CHECK-NEXT: originally declared here:
+// CHECK-NEXT: /graph_region.mlir:13:4
+// CHECK-NEXT: ^blockA:
+// CHECK-NEXT: ^^^^^^^
builtin.module {
^blockA:
@@ -82,12 +89,3 @@ builtin.module {
^blockA:
"test.op"() : () -> ()
}
-
-// CHECK: /graph_region.mlir:82:4
-// CHECK-NEXT: ^blockA:
-// CHECK-NEXT: ^^^^^^^
-// CHECK-NEXT: re-declaration of block 'blockA'
-// CHECK-NEXT: originally declared here:
-// CHECK-NEXT: /graph_region.mlir:6:4
-// CHECK-NEXT: ^blockA:
-// CHECK-NEXT: ^^^^^^^
From b356946b4d9f91fe1d9dcfcce1da459ed5165af3 Mon Sep 17 00:00:00 2001
From: emmau678
Date: Thu, 6 Feb 2025 10:44:12 +0000
Subject: [PATCH 19/23] fix last change to move correct check above code
---
.../parser-printer/graph_region.mlir | 28 +++++++++----------
1 file changed, 13 insertions(+), 15 deletions(-)
diff --git a/tests/filecheck/parser-printer/graph_region.mlir b/tests/filecheck/parser-printer/graph_region.mlir
index 5f7f5a1f06..6c7676ca4f 100644
--- a/tests/filecheck/parser-printer/graph_region.mlir
+++ b/tests/filecheck/parser-printer/graph_region.mlir
@@ -43,24 +43,22 @@ builtin.module {
// -----
-// A graph region that refers to a value that is not defined in the module.
+// A graph region that refers to values that are not defined in the module.
+
+// CHECK: value %1 was used but not defined
builtin.module {
%0 = "test.termop"(%1) : (i32) -> i32
}
-// CHECK: value %1 was used but not defined
-
// -----
-// A graph region that refers to values that are not defined in the module.
+// CHECK: values %1, %2 were used but not defined
builtin.module {
%0 = "test.termop"(%1, %2) : (i32, i32) -> i32
}
-// CHECK: values %1, %2 were used but not defined
-
// -----
// A forward value used with a wrong index
@@ -74,18 +72,18 @@ builtin.module {
// -----
-// CHECK: /graph_region.mlir:89:4
-// CHECK-NEXT: ^blockA:
-// CHECK-NEXT: ^^^^^^^
-// CHECK-NEXT: re-declaration of block 'blockA'
-// CHECK-NEXT: originally declared here:
-// CHECK-NEXT: /graph_region.mlir:13:4
-// CHECK-NEXT: ^blockA:
-// CHECK-NEXT: ^^^^^^^
-
builtin.module {
^blockA:
"test.op"() : () -> ()
^blockA:
"test.op"() : () -> ()
}
+
+// CHECK: /graph_region.mlir:78:4
+// CHECK-NEXT: ^blockA:
+// CHECK-NEXT: ^^^^^^^
+// CHECK-NEXT: re-declaration of block 'blockA'
+// CHECK-NEXT: originally declared here:
+// CHECK-NEXT: /graph_region.mlir:4:4
+// CHECK-NEXT: ^blockA:
+// CHECK-NEXT: ^^^^^^^
From 7650810d95255f5f5f422a1ceff9e278b3009ab3 Mon Sep 17 00:00:00 2001
From: emmau678
Date: Tue, 11 Feb 2025 13:08:38 +0000
Subject: [PATCH 20/23] minor change to error printing format
---
tests/filecheck/parser-printer/graph_region.mlir | 4 ++--
xdsl/parser/core.py | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/tests/filecheck/parser-printer/graph_region.mlir b/tests/filecheck/parser-printer/graph_region.mlir
index 6c7676ca4f..40b78ffc82 100644
--- a/tests/filecheck/parser-printer/graph_region.mlir
+++ b/tests/filecheck/parser-printer/graph_region.mlir
@@ -45,7 +45,7 @@ builtin.module {
// A graph region that refers to values that are not defined in the module.
-// CHECK: value %1 was used but not defined
+// CHECK: value used but not defined: [%1]
builtin.module {
%0 = "test.termop"(%1) : (i32) -> i32
@@ -53,7 +53,7 @@ builtin.module {
// -----
-// CHECK: values %1, %2 were used but not defined
+// CHECK: values used but not defined: [%1, %2]
builtin.module {
%0 = "test.termop"(%1, %2) : (i32, i32) -> i32
diff --git a/xdsl/parser/core.py b/xdsl/parser/core.py
index c9b15e14db..2b2a25a59d 100644
--- a/xdsl/parser/core.py
+++ b/xdsl/parser/core.py
@@ -146,9 +146,9 @@ def parse_module(self, allow_implicit_module: bool = True) -> ModuleOp:
"%" + name for name in self.forward_ssa_references.keys()
)
if len(self.forward_ssa_references.keys()) > 1:
- self.raise_error(f"values {value_names} were used but not defined")
+ self.raise_error(f"values used but not defined: [{value_names}]")
else:
- self.raise_error(f"value {value_names} was used but not defined")
+ self.raise_error(f"value used but not defined: [{value_names}]")
return module_op
From 9cf7ddbc65f393bd778ce1181a540fd391bc49a6 Mon Sep 17 00:00:00 2001
From: Emma Urquhart <77412390+emmau678@users.noreply.github.com>
Date: Tue, 11 Feb 2025 13:50:47 +0000
Subject: [PATCH 21/23] Update xdsl/parser/core.py
Co-authored-by: Sasha Lopoukhine
---
xdsl/parser/core.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/xdsl/parser/core.py b/xdsl/parser/core.py
index 2b2a25a59d..b94b065e86 100644
--- a/xdsl/parser/core.py
+++ b/xdsl/parser/core.py
@@ -145,7 +145,7 @@ def parse_module(self, allow_implicit_module: bool = True) -> ModuleOp:
value_names = ", ".join(
"%" + name for name in self.forward_ssa_references.keys()
)
- if len(self.forward_ssa_references.keys()) > 1:
+ self.raise_error(f"value used but not defined: [{value_names}]")
self.raise_error(f"values used but not defined: [{value_names}]")
else:
self.raise_error(f"value used but not defined: [{value_names}]")
From efd244eec730ae123a06c7638257ab6af77ef7dc Mon Sep 17 00:00:00 2001
From: Emma Urquhart <77412390+emmau678@users.noreply.github.com>
Date: Tue, 11 Feb 2025 13:51:24 +0000
Subject: [PATCH 22/23] Update tests/filecheck/parser-printer/graph_region.mlir
Co-authored-by: Sasha Lopoukhine
---
tests/filecheck/parser-printer/graph_region.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/filecheck/parser-printer/graph_region.mlir b/tests/filecheck/parser-printer/graph_region.mlir
index 40b78ffc82..a0b1c95f8e 100644
--- a/tests/filecheck/parser-printer/graph_region.mlir
+++ b/tests/filecheck/parser-printer/graph_region.mlir
@@ -45,7 +45,7 @@ builtin.module {
// A graph region that refers to values that are not defined in the module.
-// CHECK: value used but not defined: [%1]
+// CHECK: values used but not defined: [%1]
builtin.module {
%0 = "test.termop"(%1) : (i32) -> i32
From dc78d820c57cafce7d001410a7417428a5800d52 Mon Sep 17 00:00:00 2001
From: emmau678
Date: Tue, 11 Feb 2025 13:59:53 +0000
Subject: [PATCH 23/23] updates from review comments - condense error handling
---
xdsl/parser/core.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/xdsl/parser/core.py b/xdsl/parser/core.py
index b94b065e86..9a1df7812b 100644
--- a/xdsl/parser/core.py
+++ b/xdsl/parser/core.py
@@ -145,10 +145,7 @@ def parse_module(self, allow_implicit_module: bool = True) -> ModuleOp:
value_names = ", ".join(
"%" + name for name in self.forward_ssa_references.keys()
)
- self.raise_error(f"value used but not defined: [{value_names}]")
- self.raise_error(f"values used but not defined: [{value_names}]")
- else:
- self.raise_error(f"value used but not defined: [{value_names}]")
+ self.raise_error(f"values used but not defined: [{value_names}]")
return module_op