Skip to content

Commit

Permalink
dialects: (builtin) replace DenseArrayBase's data array with bytes (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh authored Dec 6, 2024
1 parent c450996 commit fe9f649
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 41 deletions.
28 changes: 10 additions & 18 deletions tests/dialects/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AnyTensorType,
ArrayAttr,
BFloat16Type,
BytesAttr,
ComplexType,
DenseArrayBase,
DenseIntOrFPElementsAttr,
Expand All @@ -18,7 +19,6 @@
Float80Type,
Float128Type,
FloatAttr,
FloatData,
IntAttr,
IntegerAttr,
IntegerType,
Expand Down Expand Up @@ -238,22 +238,6 @@ def test_DenseIntOrFPElementsAttr_from_list():
assert attr.type == AnyTensorType(f32, [])


def test_DenseArrayBase_verifier_failure():
# Check that a malformed attribute raises a verify error

with pytest.raises(VerifyException) as err:
DenseArrayBase([f32, ArrayAttr([IntAttr(0)])])
assert err.value.args[0] == (
"dense array of float element type " "should only contain floats"
)

with pytest.raises(VerifyException) as err:
DenseArrayBase([i32, ArrayAttr([FloatData(0.0)])])
assert err.value.args[0] == (
"dense array of integer element type " "should only contain integers"
)


@pytest.mark.parametrize(
"ref,expected",
[
Expand Down Expand Up @@ -439,7 +423,7 @@ def test_complex_init():

def test_dense_as_tuple():
floats = DenseArrayBase.from_list(f32, [3.14159, 2.71828])
assert floats.get_values() == (3.14159, 2.71828)
assert floats.get_values() == (3.141590118408203, 2.718280076980591)

ints = DenseArrayBase.from_list(i32, [1, 1, 2, 3, 5, 8])
assert ints.get_values() == (1, 1, 2, 3, 5, 8)
Expand All @@ -455,6 +439,14 @@ def test_create_dense_int():
DenseArrayBase.create_dense_int(i8, (99999999, 255, 256))


def test_create_dense_wrong_size():
with pytest.raises(
VerifyException,
match=re.escape("Data length of array (1) not divisible by element size 2"),
):
DenseArrayBase((i16, BytesAttr(b"F")))


def test_strides():
assert ShapedType.strides_for_shape(()) == ()
assert ShapedType.strides_for_shape((), factor=2) == ()
Expand Down
2 changes: 1 addition & 1 deletion tests/filecheck/dialects/builtin/attrs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
%x6 = "arith.constant"() {"value" = 0 : i64, "test" = false} : () -> i64
// CHECK: "test" = array<i32: 2, 3, 4>
%x7 = "arith.constant"() {"value" = 0 : i64, "test" = array<i32: 2, 3, 4>} : () -> i64
// CHECK: "test" = array<f32: 2.1, 3.2, 4.3>
// CHECK: "test" = array<f32: 2.0999999046325684, 3.200000047683716, 4.300000190734863>
%x8 = "arith.constant"() {"value" = 0 : i64, "test" = array<f32: 2.1, 3.2, 4.3>} : () -> i64
// CHECK: "test" = #builtin.signedness<signless>
%x9 = "arith.constant"() {"value" = 0 : i64, "test" = #builtin.signedness<signless>} : () -> i64
Expand Down
2 changes: 1 addition & 1 deletion tests/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ def test_densearray_attr():
"""Test that a DenseArrayAttr can be parsed and then printed."""

prog = """
"func.func"() <{"sym_name" = "test", "function_type" = i64, "sym_visibility" = "private", "unit_attr"}> {"bool_attrs" = array<i1: false, true>, "int_attr" = array<i32: 19, 23, 55>, "float_attr" = array<f32: 0.34>} : () -> ()
"func.func"() <{"sym_name" = "test", "function_type" = i64, "sym_visibility" = "private", "unit_attr"}> {"bool_attrs" = array<i1: false, true>, "int_attr" = array<i32: 19, 23, 55>, "float_attr" = array<f32: 0.3400000035762787>} : () -> ()
"""

ctx = MLContext()
Expand Down
44 changes: 23 additions & 21 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,22 +1107,16 @@ class DenseArrayBase(ParametrizedAttribute):
name = "array"

elt_type: ParameterDef[IntegerType | AnyFloat]
data: ParameterDef[ArrayAttr[IntAttr] | ArrayAttr[FloatData]]
data: ParameterDef[BytesAttr]

def verify(self):
if isinstance(self.elt_type, IntegerType):
for d in self.data.data:
if isinstance(d, FloatData):
raise VerifyException(
"dense array of integer element type should only contain "
"integers"
)
else:
for d in self.data.data:
if isinstance(d, IntAttr):
raise VerifyException(
"dense array of float element type should only contain floats"
)
data_len = len(self.data.data)
elt_size = self.elt_type.size
if data_len % elt_size:
raise VerifyException(
f"Data length of {self.name} ({data_len}) not divisible by element "
f"size {elt_size}"
)

@staticmethod
def create_dense_int(
Expand All @@ -1147,18 +1141,26 @@ def create_dense_int(

values = cast(tuple[IntAttr, ...], normalized_values)

return DenseArrayBase([data_type, ArrayAttr(values)])
fmt = data_type.format[0] + str(len(data)) + data_type.format[1:]

bytes_data = struct.pack(fmt, *(attr.data for attr in values))

return DenseArrayBase([data_type, BytesAttr(bytes_data)])

@staticmethod
def create_dense_float(
data_type: AnyFloat, data: Sequence[int | float] | Sequence[FloatData]
) -> DenseArrayBase:
if len(data) and isinstance(data[0], int | float):
attr_list = [FloatData(float(d)) for d in cast(Sequence[int | float], data)]
vals = data
else:
attr_list = cast(Sequence[FloatData], data)
vals = tuple(attr.data for attr in cast(Sequence[FloatData], data))

fmt = data_type.format[0] + str(len(data)) + data_type.format[1:]

bytes_data = struct.pack(fmt, *vals)

return DenseArrayBase([data_type, ArrayAttr(attr_list)])
return DenseArrayBase([data_type, BytesAttr(bytes_data)])

@overload
@staticmethod
Expand Down Expand Up @@ -1192,13 +1194,13 @@ def from_list(
raise TypeError(f"Unsupported element type {data_type}")

def iter_values(self) -> Iterator[float] | Iterator[int]:
return (attr.data for attr in self.data.data)
return self.elt_type.iter_unpack(self.data.data)

def get_values(self) -> tuple[int, ...] | tuple[float, ...]:
return tuple(self.iter_values())
return self.elt_type.unpack(self.data.data, len(self))

def __len__(self) -> int:
return len(self.data.data)
return len(self.data.data) // self.elt_type.size


@irdl_attr_definition
Expand Down

0 comments on commit fe9f649

Please sign in to comment.