Add Array.asformat method and add back reshape function. #800

Merged
merged 14 commits on Nov 13, 2024
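
For orientation, a minimal usage sketch of the API this PR touches (not part of the diff). It assumes the backend's existing `asarray` constructor and a pre-built `StorageFormat` instance named `csr_format`; both are assumptions here, not things this PR adds.

import numpy as np
import sparse.mlir_backend as smb

x = smb.asarray(np.eye(4))     # assumed constructor for an MLIR-backed Array
y = smb.reshape(x, (2, 8))     # `reshape` is re-exported from ._ops (see below)
z = x.asformat(csr_format)     # new method: convert to another storage format
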
11 changes: 7 additions & 4 deletions pixi.toml
@@ -27,7 +27,7 @@ mkdocs-jupyter = "*"

[feature.tests.tasks]
test = "pytest --pyargs sparse -n auto"
test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -v" }
test-mlir = { cmd = "pytest --pyargs sparse.mlir_backend -v" }
test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -v", depends-on = ["precompile"] }

[feature.tests.dependencies]
@@ -55,17 +55,20 @@ finch-tensor = ">=0.1.31"
SPARSE_BACKEND = "Finch"

[feature.finch.target.osx-arm64.activation.env]
SPARSE_BACKEND = "Finch"
PYTHONFAULTHANDLER = "${HOME}/faulthandler.log"

[feature.mlir.dependencies]
scipy = ">=0.19"
mlir-python-bindings = "19.*"

[feature.mlir.target.osx-arm64.pypi-dependencies]
finch-mlir = ">=0.0.2"

[feature.mlir.activation.env]
SPARSE_BACKEND = "MLIR"

[environments]
tests = ["tests", "extras"]
docs = ["docs", "extras"]
mlir-dev = ["tests", "mlir"]
finch-dev = ["tests", "finch"]
mlir-dev = {features = ["tests", "mlir"], no-default-feature = true}
finch-dev = {features = ["tests", "finch"], no-default-feature = true}
3 changes: 2 additions & 1 deletion sparse/mlir_backend/__init__.py
@@ -27,7 +27,7 @@
uint32,
uint64,
)
from ._ops import add
from ._ops import add, reshape

__all__ = [
"add",
@@ -36,6 +36,7 @@
"to_numpy",
"to_scipy",
"levels",
"reshape",
"from_constituent_arrays",
"int8",
"int16",
5 changes: 5 additions & 0 deletions sparse/mlir_backend/_array.py
@@ -41,5 +41,10 @@ def copy(self) -> "Array":
arrs = tuple(arr.copy() for arr in self.get_constituent_arrays())
return from_constituent_arrays(format=self.format, arrays=arrs, shape=self.shape)

def asformat(self, format: StorageFormat) -> "Array":
from ._ops import asformat

return asformat(self, format=format)

def get_constituent_arrays(self) -> tuple[np.ndarray, ...]:
return self._storage.get_constituent_arrays()
11 changes: 11 additions & 0 deletions sparse/mlir_backend/_common.py
@@ -1,6 +1,7 @@
import ctypes
import functools
import weakref
from collections.abc import Iterable

import mlir_finch.runtime as rt

@@ -52,3 +53,13 @@
ctypes.pythonapi.Py_DecRef(ptr)

weakref.finalize(owner, finalizer, ptr)


def as_shape(x) -> tuple[int, ...]:
if not isinstance(x, Iterable):
x = (x,)

if not all(isinstance(xi, int) for xi in x):
raise TypeError("Shape must be an `int` or tuple of `int`s.")

return tuple(int(xi) for xi in x)
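
An illustrative sketch of the normalization the helper above performs (not part of the diff):

as_shape(5)         # -> (5,)
as_shape((2, 3))    # -> (2, 3)
as_shape([2, 3])    # -> (2, 3)
as_shape(2.5)       # raises TypeError: entries must be `int`s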
6 changes: 6 additions & 0 deletions sparse/mlir_backend/_core.py
@@ -28,6 +28,12 @@
libc.free.argtypes = [ctypes.c_void_p]
libc.free.restype = None

SHARED_LIBS = []
if DEBUG:
SHARED_LIBS.append(MLIR_C_RUNNER_UTILS)

OPT_LEVEL = 0 if DEBUG else 2

# TODO: remove global state
ctx = Context()

16 changes: 8 additions & 8 deletions sparse/mlir_backend/_dtypes.py
@@ -76,10 +76,10 @@ def np_dtype(self) -> np.dtype:
return np.dtype(getattr(np, f"uint{self.bit_width}"))


int8 = UnsignedIntegerDType(bit_width=8)
int16 = UnsignedIntegerDType(bit_width=16)
int32 = UnsignedIntegerDType(bit_width=32)
int64 = UnsignedIntegerDType(bit_width=64)
uint8 = UnsignedIntegerDType(bit_width=8)
uint16 = UnsignedIntegerDType(bit_width=16)
uint32 = UnsignedIntegerDType(bit_width=32)
uint64 = UnsignedIntegerDType(bit_width=64)


@dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
@@ -89,10 +89,10 @@ def np_dtype(self) -> np.dtype:
return np.dtype(getattr(np, f"int{self.bit_width}"))


uint8 = SignedIntegerDType(bit_width=8)
uint16 = SignedIntegerDType(bit_width=16)
uint32 = SignedIntegerDType(bit_width=32)
uint64 = SignedIntegerDType(bit_width=64)
int8 = SignedIntegerDType(bit_width=8)
int16 = SignedIntegerDType(bit_width=16)
int32 = SignedIntegerDType(bit_width=32)
int64 = SignedIntegerDType(bit_width=64)


intp: SignedIntegerDType = locals()[f"int{_PTR_WIDTH}"]
99 changes: 84 additions & 15 deletions sparse/mlir_backend/_ops.py
@@ -5,10 +5,13 @@
from mlir_finch import ir
from mlir_finch.dialects import arith, complex, func, linalg, sparse_tensor, tensor

import numpy as np

from ._array import Array
from ._common import fn_cache
from ._core import CWD, DEBUG, SHARED_LIBS, ctx, pm
from ._common import as_shape, fn_cache
from ._core import CWD, DEBUG, OPT_LEVEL, SHARED_LIBS, ctx, pm
from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType
from .levels import StorageFormat, _determine_format


@fn_cache
@@ -17,7 +20,6 @@
b_tensor_type: ir.RankedTensorType,
out_tensor_type: ir.RankedTensorType,
dtype: DType,
rank: int,
) -> ir.Module:
with ir.Location.unknown(ctx):
module = ir.Module.create()
@@ -31,7 +33,7 @@
raise RuntimeError(f"Can not add {dtype=}.")

dtype = dtype._get_mlir_type()
ordering = ir.AffineMap.get_permutation(range(rank))
max_rank = out_tensor_type.rank

with ir.InsertionPoint(module.body):

@@ -42,8 +44,13 @@
[out_tensor_type],
[a, b],
[out],
ir.ArrayAttr.get([ir.AffineMapAttr.get(p) for p in (ordering,) * 3]),
ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * rank),
ir.ArrayAttr.get(
[
ir.AffineMapAttr.get(ir.AffineMap.get_minor_identity(max_rank, t.rank))
for t in (a_tensor_type, b_tensor_type, out_tensor_type)
]
),
ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * max_rank),
)
block = generic_op.regions[0].blocks.append(dtype, dtype, dtype)
with ir.InsertionPoint(block):
@@ -72,7 +79,7 @@
if DEBUG:
(CWD / "add_module_opt.mlir").write_text(str(module))

return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)


@fn_cache
@@ -97,7 +104,7 @@
if DEBUG:
(CWD / "reshape_module_opt.mlir").write_text(str(module))

return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)


@fn_cache
@@ -125,26 +132,88 @@
if DEBUG:
(CWD / "broadcast_to_module_opt.mlir").write_text(str(module))

return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)


@fn_cache
def get_convert_module(
in_tensor_type: ir.RankedTensorType,
out_tensor_type: ir.RankedTensorType,
):
with ir.Location.unknown(ctx):
module = ir.Module.create()

with ir.InsertionPoint(module.body):

@func.FuncOp.from_py_func(in_tensor_type)
def convert(in_tensor):
return sparse_tensor.convert(out_tensor_type, in_tensor)

convert.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
if DEBUG:
(CWD / "broadcast_to_module.mlir").write_text(str(module))

pm.run(module.operation)
if DEBUG:
(CWD / "broadcast_to_module_opt.mlir").write_text(str(module))


return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)


def add(x1: Array, x2: Array) -> Array:
ret_storage_format = x1.format
def add(x1: Array, x2: Array, /) -> Array:
# TODO: Determine output format via autoscheduler
ret_storage_format = _determine_format(x1.format, x2.format, dtype=x1.dtype, union=True)
ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
out_tensor_type = ret_storage_format._get_mlir_type(shape=x1.shape)
out_tensor_type = ret_storage_format._get_mlir_type(shape=np.broadcast_shapes(x1.shape, x2.shape))

# TODO: Decide what will be the output tensor_type
add_module = get_add_module(
x1._get_mlir_type(),
x2._get_mlir_type(),
out_tensor_type=out_tensor_type,
dtype=x1.dtype,
rank=x1.ndim,
)
add_module.invoke(
"add",
ctypes.pointer(ctypes.pointer(ret_storage)),
*x1._to_module_arg(),
*x2._to_module_arg(),
)
return Array(storage=ret_storage, shape=out_tensor_type.shape)
return Array(storage=ret_storage, shape=tuple(out_tensor_type.shape))


def asformat(x: Array, /, format: StorageFormat) -> Array:
out_tensor_type = format._get_mlir_type(shape=x.shape)
ret_storage = format._get_ctypes_type(owns_memory=True)()

convert_module = get_convert_module(
x._get_mlir_type(),
out_tensor_type,
)

convert_module.invoke(
"convert",
ctypes.pointer(ctypes.pointer(ret_storage)),
*x._to_module_arg(),
)

return Array(storage=ret_storage, shape=x.shape)


def reshape(x: Array, /, shape: tuple[int, ...]) -> Array:
from ._conversions import _from_numpy

shape = as_shape(shape)
ret_storage_format = _determine_format(x.format, dtype=x.dtype, union=len(shape) > x.ndim, out_ndim=len(shape))
shape_array = _from_numpy(np.asarray(shape, dtype=np.uint64))
out_tensor_type = ret_storage_format._get_mlir_type(shape=shape)
ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()

reshape_module = get_reshape_module(x._get_mlir_type(), shape_array._get_mlir_type(), out_tensor_type)

reshape_module.invoke(
"reshape",
ctypes.pointer(ctypes.pointer(ret_storage)),
*x._to_module_arg(),
*shape_array._to_module_arg(),
)

return Array(storage=ret_storage, shape=shape)
76 changes: 76 additions & 0 deletions sparse/mlir_backend/levels.py
@@ -209,3 +209,79 @@
crd_width=crd_width,
dtype=dtype,
)


def _is_sparse_level(lvl: Level | LevelFormat, /) -> bool:
assert isinstance(lvl, Level | LevelFormat)
if isinstance(lvl, Level):
lvl = lvl.format
return LevelFormat.Dense != lvl


def _count_sparse_levels(format: StorageFormat) -> int:
return sum(_is_sparse_level(lvl) for lvl in format.levels)


def _count_dense_levels(format: StorageFormat) -> int:
return sum(not _is_sparse_level(lvl) for lvl in format.levels)


def _determine_format(*formats: StorageFormat, dtype: DType, union: bool, out_ndim: int | None = None) -> StorageFormat:
"""Determines the output format from a group of input formats.

1. Counts the dense levels of each input for `union=True`, and the sparse ones for `union=False`.
2. Takes the maximum count across all input formats.
3. Constructs a format with that many levels of the counted kind, filling the
remaining dimensions with the other kind; sparse output levels use `LevelFormat.Compressed`.

Returns
-------
StorageFormat
Output storage format.
"""
if len(formats) == 0:
if out_ndim is None:
out_ndim = 0
return get_storage_format(
levels=(Level(LevelFormat.Dense),) * out_ndim,
order="C",
pos_width=64,
crd_width=64,
dtype=dtype,
)

if out_ndim is None:
out_ndim = max(fmt.rank for fmt in formats)

pos_width = 0
crd_width = 0
counter = _count_sparse_levels if not union else _count_dense_levels
n_counted = None
order = ()
for fmt in formats:
n_counted = counter(fmt) if n_counted is None else max(n_counted, counter(fmt))
pos_width = max(pos_width, fmt.pos_width)
crd_width = max(crd_width, fmt.crd_width)
if order != "C":
if fmt.order[: len(order)] == order:
order = fmt.order
elif order[: len(fmt.order)] != fmt.order:
order = "C"

if not isinstance(order, str):
order = order + tuple(range(len(order), out_ndim))
order = order[:out_ndim]

if out_ndim < n_counted:
n_counted = out_ndim

n_sparse = n_counted if not union else out_ndim - n_counted

levels = (Level(LevelFormat.Dense),) * (out_ndim - n_sparse) + (Level(LevelFormat.Compressed),) * n_sparse
return get_storage_format(
levels=levels,
order=order,
pos_width=pos_width,
crd_width=crd_width,
dtype=dtype,
)
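
To make the heuristic above concrete, a small sketch of `_determine_format` on two CSR-style inputs (not part of the diff). It assumes a `float64` dtype is defined in `._dtypes` alongside the integer dtypes shown earlier; everything else comes from this file.

from sparse.mlir_backend.levels import Level, LevelFormat, get_storage_format, _determine_format
from sparse.mlir_backend._dtypes import float64  # assumption: defined next to int8/uint8 above

# A CSR-style 2-D format: one dense level followed by one compressed level.
csr = get_storage_format(
    levels=(Level(LevelFormat.Dense), Level(LevelFormat.Compressed)),
    order="C",
    pos_width=64,
    crd_width=64,
    dtype=float64,
)

# union=True (as used by `add`): dense levels are counted, so n_counted = 1
# and n_sparse = 2 - 1 = 1; the output keeps the (Dense, Compressed) layout.
out = _determine_format(csr, csr, dtype=float64, union=True)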