Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (builtin) add PackableType and StructPackableType #3581

Merged
merged 5 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions tests/dialects/test_builtin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import re
from collections.abc import Sequence

Expand Down Expand Up @@ -33,7 +34,10 @@
VectorRankConstraint,
VectorType,
f32,
f64,
i1,
i8,
i16,
i32,
i64,
)
Expand All @@ -51,6 +55,39 @@ def test_FloatType_bitwidths():
assert Float128Type().bitwidth == 128


def test_FloatType_formats():
with pytest.raises(NotImplementedError):
BFloat16Type().format
with pytest.raises(NotImplementedError):
Float16Type().format
assert Float32Type().format == "<f"
assert Float64Type().format == "<d"
with pytest.raises(NotImplementedError):
Float80Type().format
with pytest.raises(NotImplementedError):
Float128Type().format


def test_IntegerType_formats():
with pytest.raises(NotImplementedError):
IntegerType(2).format
assert IntegerType(1).format == "<b"
assert IntegerType(8).format == "<b"
assert IntegerType(16).format == "<h"
assert IntegerType(32).format == "<i"
assert IntegerType(64).format == "<q"


def test_FloatType_packing():
nums = (-128, -1, 0, 1, 127)
buffer = f32.pack(nums)
unpacked = f32.unpack(buffer, len(nums))
assert nums == unpacked

pi = f64.unpack(f64.pack((math.pi,)), 1)[0]
assert pi == math.pi


def test_IntegerType_size():
assert IntegerType(1).size == 1
assert IntegerType(2).size == 1
Expand Down Expand Up @@ -104,6 +141,68 @@ def test_IntegerAttr_normalize():
IntegerAttr(256, 8)


def test_IntegerType_packing():
# i1
nums_i1 = (0, 1, 0, 1)
buffer_i1 = i1.pack(nums_i1)
unpacked_i1 = i1.unpack(buffer_i1, len(nums_i1))
assert nums_i1 == unpacked_i1

# i8
nums_i8 = (-128, -1, 0, 1, 127)
buffer_i8 = i8.pack(nums_i8)
unpacked_i8 = i8.unpack(buffer_i8, len(nums_i8))
assert nums_i8 == unpacked_i8

# i16
nums_i16 = (-32768, -1, 0, 1, 32767)
buffer_i16 = i16.pack(nums_i16)
unpacked_i16 = i16.unpack(buffer_i16, len(nums_i16))
assert nums_i16 == unpacked_i16

# i32
nums_i32 = (-2147483648, -1, 0, 1, 2147483647)
buffer_i32 = i32.pack(nums_i32)
unpacked_i32 = i32.unpack(buffer_i32, len(nums_i32))
assert nums_i32 == unpacked_i32

# i64
nums_i64 = (-9223372036854775808, -1, 0, 1, 9223372036854775807)
buffer_i64 = i64.pack(nums_i64)
unpacked_i64 = i64.unpack(buffer_i64, len(nums_i64))
assert nums_i64 == unpacked_i64

# f32
nums_f32 = (-3.140000104904175, -1.0, 0.0, 1.0, 3.140000104904175)
buffer_f32 = f32.pack(nums_f32)
unpacked_f32 = f32.unpack(buffer_f32, len(nums_f32))
assert nums_f32 == unpacked_f32

# f64
nums_f64 = (-3.14159265359, -1.0, 0.0, 1.0, 3.14159265359)
buffer_f64 = f64.pack(nums_f64)
unpacked_f64 = f64.unpack(buffer_f64, len(nums_f64))
assert nums_f64 == unpacked_f64

# Test error cases
with pytest.raises(Exception, match="'b' format requires -128 <= number <= 127"):
# Values must be normalized before packing
i8.pack((255,))
with pytest.raises(
Exception, match="'h' format requires -32768 <= number <= 32767"
):
i16.pack((32768,))
with pytest.raises(
Exception, match="'i' format requires -2147483648 <= number <= 2147483647"
):
i32.pack((2147483648,))
with pytest.raises(
Exception,
match="'q' format requires -9223372036854775808 <= number <= 9223372036854775807",
):
i64.pack((9223372036854775808,))


def test_DenseIntOrFPElementsAttr_fp_type_conversion():
check1 = DenseIntOrFPElementsAttr.tensor_from_list([4, 5], f32, [])

Expand Down
130 changes: 120 additions & 10 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import math
import struct
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass
Expand Down Expand Up @@ -85,9 +86,12 @@
from xdsl.utils.isattr import isattr

if TYPE_CHECKING:
from _typeshed import ReadableBuffer, WriteableBuffer

from xdsl.parser import AttrParser, Parser
from xdsl.printer import Printer


DYNAMIC_INDEX = -1
"""
A constant value denoting a dynamic index in a shape.
Expand Down Expand Up @@ -347,10 +351,10 @@ def print_parameter(self, printer: Printer) -> None:

class FixedBitwidthType(TypeAttribute, ABC):
"""
A type attribute with a defined bitwidth
A type attribute whose runtime bitwidth is target-independent.
"""

name = "abstract.bitwidth_type"
name = "abstract.fixed_bitwidth_type"

@property
@abstractmethod
Expand All @@ -368,8 +372,76 @@ def size(self) -> int:
return (self.bitwidth + 7) >> 3


_PyT = TypeVar("_PyT")


class PackableType(Generic[_PyT], FixedBitwidthType, ABC):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to combine the FixedBitwidthType and PackableType? Maybe we are mixing host bitdwidth with target bitwidth? The target system could for example have support for i4 types, but will be packed as i8 at compile time.

Maybe this is too niche, I'm just not to sure about this one, the two concepts just seem to have little in common

Or f80 or f128 for example, they are not packable, but do have a fixed bitwidth

Copy link
Collaborator

@jorendumoulin jorendumoulin Dec 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Starting to change my mind about this, if something has a fixedbitwidthtype it should be packable, just not necessarily packable with that specified fixed bitwidth. I guess it thus makes sense to keep them together

"""
Abstract base class for xDSL types whose values can be encoded and decoded as bytes.
"""

@abstractmethod
def iter_unpack(self, buffer: ReadableBuffer, /) -> Iterator[_PyT]:
"""
Yields unpacked values one at a time, starting at the beginning of the buffer.
"""
raise NotImplementedError()

@abstractmethod
def unpack(self, buffer: ReadableBuffer, num: int, /) -> tuple[_PyT, ...]:
"""
Unpack `num` values from the beginning of the buffer.
"""
raise NotImplementedError()

@abstractmethod
def pack_into(self, buffer: WriteableBuffer, offset: int, value: _PyT) -> None:
"""
Pack a value at a given offset into a buffer.
"""
raise NotImplementedError()

@abstractmethod
def pack(self, values: Sequence[_PyT]) -> bytes:
"""
Create a new buffer containing the input `values`.
"""
raise NotImplementedError()


class StructPackableType(Generic[_PyT], PackableType[_PyT], ABC):
"""
Abstract base class for xDSL types that can be packed and unpacked using Python's
`struct` package, using a format string.
"""

@property
@abstractmethod
def format(self) -> str:
"""
Format to be used when decoding and encoding bytes.

https://docs.python.org/3/library/struct.html
"""
raise NotImplementedError()

def iter_unpack(self, buffer: ReadableBuffer, /) -> Iterator[_PyT]:
return (values[0] for values in struct.iter_unpack(self.format, buffer))

def unpack(self, buffer: ReadableBuffer, num: int, /) -> tuple[_PyT, ...]:
fmt = self.format[0] + str(num) + self.format[1:]
return struct.unpack(fmt, buffer)
jorendumoulin marked this conversation as resolved.
Show resolved Hide resolved

def pack_into(self, buffer: WriteableBuffer, offset: int, value: _PyT) -> None:
struct.pack_into(self.format, buffer, offset, value)

def pack(self, values: Sequence[_PyT]) -> bytes:
fmt = self.format[0] + str(len(values)) + self.format[1:]
return struct.pack(fmt, *values)


@irdl_attr_definition
class IntegerType(ParametrizedAttribute, FixedBitwidthType):
class IntegerType(ParametrizedAttribute, StructPackableType[int]):
name = "integer_type"
width: ParameterDef[IntAttr]
signedness: ParameterDef[SignednessAttr]
Expand Down Expand Up @@ -432,6 +504,20 @@ def print_value_without_type(self, value: int, printer: Printer):
else:
printer.print_string(f"{value}")

@property
def format(self) -> str:
match self.bitwidth:
case 1 | 8:
return "<b"
case 16:
return "<h"
case 32:
return "<i"
case 64:
return "<q"
case _:
raise NotImplementedError(f"Format not implemented for {self}")


i64 = IntegerType(64)
i32 = IntegerType(32)
Expand Down Expand Up @@ -583,66 +669,90 @@ def constr(
BoolAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(1)]]


class _FloatType(ABC):
class _FloatType(StructPackableType[float], ABC):
@property
@abstractmethod
def bitwidth(self) -> int:
raise NotImplementedError()


@irdl_attr_definition
class BFloat16Type(ParametrizedAttribute, FixedBitwidthType, _FloatType):
class BFloat16Type(ParametrizedAttribute, _FloatType):
name = "bf16"

@property
def bitwidth(self) -> int:
return 16

@property
def format(self) -> str:
raise NotImplementedError()


@irdl_attr_definition
class Float16Type(ParametrizedAttribute, FixedBitwidthType, _FloatType):
class Float16Type(ParametrizedAttribute, _FloatType):
name = "f16"

@property
def bitwidth(self) -> int:
return 16

@property
def format(self) -> str:
raise NotImplementedError()


@irdl_attr_definition
class Float32Type(ParametrizedAttribute, FixedBitwidthType, _FloatType):
class Float32Type(ParametrizedAttribute, _FloatType):
name = "f32"

@property
def bitwidth(self) -> int:
return 32

@property
def format(self) -> str:
return "<f"


@irdl_attr_definition
class Float64Type(ParametrizedAttribute, FixedBitwidthType, _FloatType):
class Float64Type(ParametrizedAttribute, _FloatType):
name = "f64"

@property
def bitwidth(self) -> int:
return 64

@property
def format(self) -> str:
return "<d"


@irdl_attr_definition
class Float80Type(ParametrizedAttribute, FixedBitwidthType, _FloatType):
class Float80Type(ParametrizedAttribute, _FloatType):
name = "f80"

@property
def bitwidth(self) -> int:
return 80

@property
def format(self) -> str:
raise NotImplementedError()


@irdl_attr_definition
class Float128Type(ParametrizedAttribute, FixedBitwidthType, _FloatType):
class Float128Type(ParametrizedAttribute, _FloatType):
name = "f128"

@property
def bitwidth(self) -> int:
return 128

@property
def format(self) -> str:
raise NotImplementedError()


AnyFloat: TypeAlias = (
BFloat16Type | Float16Type | Float32Type | Float64Type | Float80Type | Float128Type
Expand Down
Loading