Add Array.asformat method and add back reshape function. #800

Merged: 14 commits, merged on Nov 13, 2024
pixi.toml: 11 changes (7 additions, 4 deletions)

@@ -27,7 +27,7 @@ mkdocs-jupyter = "*"

 [feature.tests.tasks]
 test = "pytest --pyargs sparse -n auto"
-test-mlir = { cmd = "pytest --pyargs sparse/mlir_backend -v" }
+test-mlir = { cmd = "pytest --pyargs sparse.mlir_backend -v" }
 test-finch = { cmd = "pytest --pyargs sparse/tests -n auto -v", depends-on = ["precompile"] }

 [feature.tests.dependencies]

@@ -55,17 +55,20 @@ finch-tensor = ">=0.1.31"
 SPARSE_BACKEND = "Finch"

 [feature.finch.target.osx-arm64.activation.env]
 SPARSE_BACKEND = "Finch"
+PYTHONFAULTHANDLER = "${HOME}/faulthandler.log"

 [feature.mlir.dependencies]
 scipy = ">=0.19"
 mlir-python-bindings = "19.*"

+[feature.mlir.target.osx-arm64.pypi-dependencies]
+finch-mlir = ">=0.0.2"

 [feature.mlir.activation.env]
 SPARSE_BACKEND = "MLIR"

 [environments]
 tests = ["tests", "extras"]
 docs = ["docs", "extras"]
-mlir-dev = ["tests", "mlir"]
-finch-dev = ["tests", "finch"]
+mlir-dev = {features = ["tests", "mlir"], no-default-feature = true}
+finch-dev = {features = ["tests", "finch"], no-default-feature = true}
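
Two notes on this file: `pytest --pyargs` treats its argument as an importable module path, so the dotted `sparse.mlir_backend` is the correct form; the earlier slash-separated form only resolves as a filesystem path. In pixi, `no-default-feature = true` keeps an environment from inheriting the default feature, so the `mlir-dev` and `finch-dev` environments now contain only the features listed.
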
sparse/mlir_backend/__init__.py: 3 changes (2 additions, 1 deletion)

@@ -27,7 +27,7 @@
     uint32,
     uint64,
 )
-from ._ops import add
+from ._ops import add, reshape

 __all__ = [
     "add",
@@ -36,6 +36,7 @@
     "to_numpy",
     "to_scipy",
     "levels",
+    "reshape",
     "from_constituent_arrays",
     "int8",
     "int16",
sparse/mlir_backend/_array.py: 5 changes (5 additions, 0 deletions)

@@ -41,5 +41,10 @@ def copy(self) -> "Array":
         arrs = tuple(arr.copy() for arr in self.get_constituent_arrays())
         return from_constituent_arrays(format=self.format, arrays=arrs, shape=self.shape)

+    def asformat(self, format: StorageFormat) -> "Array":
+        from ._ops import asformat
+
+        return asformat(self, format=format)
+
     def get_constituent_arrays(self) -> tuple[np.ndarray, ...]:
         return self._storage.get_constituent_arrays()
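
The import inside the method body defers loading `_ops`, avoiding a circular import between `_array.py` and `_ops.py`. A minimal usage sketch, assuming `arr` is an existing `Array` and `fmt` a `StorageFormat` (constructing either is outside this diff):

    converted = arr.asformat(fmt)        # thin wrapper over _ops.asformat
    assert converted.format == fmt       # data converted to the new storage layout
    assert converted.shape == arr.shape  # shape is preserved
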
sparse/mlir_backend/_common.py: 11 changes (11 additions, 0 deletions)

@@ -1,6 +1,7 @@
 import ctypes
 import functools
 import weakref
+from collections.abc import Iterable

 import mlir_finch.runtime as rt

@@ -52,3 +53,13 @@
         ctypes.pythonapi.Py_DecRef(ptr)

     weakref.finalize(owner, finalizer, ptr)
+
+
+def as_shape(x) -> tuple[int, ...]:
+    if not isinstance(x, Iterable):
+        x = (x,)
+
+    if not all(isinstance(xi, int) for xi in x):
+        raise TypeError("Shape must be an `int` or tuple of `int`s.")
+
+    return tuple(int(xi) for xi in x)
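
Behaviorally, `as_shape` canonicalizes anything `reshape` accepts into a tuple of Python ints; a few examples grounded in the code above:

    as_shape(6)         # (6,)
    as_shape((2, 3))    # (2, 3)
    as_shape([2, 3])    # (2, 3)
    as_shape((2.0, 3))  # TypeError: Shape must be an `int` or tuple of `int`s.

One caveat worth noting: NumPy integer scalars are not instances of the builtin `int`, so as written they are rejected too.
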
sparse/mlir_backend/_core.py: 6 changes (6 additions, 0 deletions)

@@ -28,6 +28,12 @@
 libc.free.argtypes = [ctypes.c_void_p]
 libc.free.restype = None

+SHARED_LIBS = []
+if DEBUG:
+    SHARED_LIBS.append(MLIR_C_RUNNER_UTILS)
+
+OPT_LEVEL = 0 if DEBUG else 2
+
 # TODO: remove global state
 ctx = Context()
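
Reading the intent (the diff itself does not say): debug builds load the MLIR C runner utilities library and compile at opt level 0 so the generated code stays debuggable, while release builds compile at opt level 2 with no extra shared libraries. The `OPT_LEVEL` constant replaces the hard-coded `opt_level=2` arguments in `_ops.py` below.
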
sparse/mlir_backend/_dtypes.py: 16 changes (8 additions, 8 deletions)

@@ -76,10 +76,10 @@ def np_dtype(self) -> np.dtype:
         return np.dtype(getattr(np, f"uint{self.bit_width}"))


-int8 = UnsignedIntegerDType(bit_width=8)
-int16 = UnsignedIntegerDType(bit_width=16)
-int32 = UnsignedIntegerDType(bit_width=32)
-int64 = UnsignedIntegerDType(bit_width=64)
+uint8 = UnsignedIntegerDType(bit_width=8)
+uint16 = UnsignedIntegerDType(bit_width=16)
+uint32 = UnsignedIntegerDType(bit_width=32)
+uint64 = UnsignedIntegerDType(bit_width=64)


 @dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
@@ -89,10 +89,10 @@ def np_dtype(self) -> np.dtype:
         return np.dtype(getattr(np, f"int{self.bit_width}"))


-uint8 = SignedIntegerDType(bit_width=8)
-uint16 = SignedIntegerDType(bit_width=16)
-uint32 = SignedIntegerDType(bit_width=32)
-uint64 = SignedIntegerDType(bit_width=64)
+int8 = SignedIntegerDType(bit_width=8)
+int16 = SignedIntegerDType(bit_width=16)
+int32 = SignedIntegerDType(bit_width=32)
+int64 = SignedIntegerDType(bit_width=64)


 intp: SignedIntegerDType = locals()[f"int{_PTR_WIDTH}"]
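
This hunk is a straight bug fix: the `int*` aliases were previously bound to `UnsignedIntegerDType` instances and the `uint*` aliases to `SignedIntegerDType` instances; the names now match their constructors.
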
sparse/mlir_backend/_ops.py: 106 changes (91 additions, 15 deletions)

@@ -1,14 +1,18 @@
 import ctypes
+import math

 import mlir_finch.execution_engine
 import mlir_finch.passmanager
 from mlir_finch import ir
 from mlir_finch.dialects import arith, complex, func, linalg, sparse_tensor, tensor

+import numpy as np
+
 from ._array import Array
-from ._common import fn_cache
-from ._core import CWD, DEBUG, SHARED_LIBS, ctx, pm
+from ._common import as_shape, fn_cache
+from ._core import CWD, DEBUG, OPT_LEVEL, SHARED_LIBS, ctx, pm
 from ._dtypes import DType, IeeeComplexFloatingDType, IeeeRealFloatingDType, IntegerDType
+from .levels import StorageFormat, _determine_format


 @fn_cache
@@ -17,7 +21,6 @@
     b_tensor_type: ir.RankedTensorType,
     out_tensor_type: ir.RankedTensorType,
     dtype: DType,
-    rank: int,
 ) -> ir.Module:
     with ir.Location.unknown(ctx):
         module = ir.Module.create()
@@ -31,7 +34,7 @@
             raise RuntimeError(f"Can not add {dtype=}.")

         dtype = dtype._get_mlir_type()
-        ordering = ir.AffineMap.get_permutation(range(rank))
+        max_rank = out_tensor_type.rank

         with ir.InsertionPoint(module.body):
@@ -42,8 +45,13 @@
                 [out_tensor_type],
                 [a, b],
                 [out],
-                ir.ArrayAttr.get([ir.AffineMapAttr.get(p) for p in (ordering,) * 3]),
-                ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * rank),
+                ir.ArrayAttr.get(
+                    [
+                        ir.AffineMapAttr.get(ir.AffineMap.get_minor_identity(max_rank, t.rank))
+                        for t in (a_tensor_type, b_tensor_type, out_tensor_type)
+                    ]
+                ),
+                ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * max_rank),
             )
             block = generic_op.regions[0].blocks.append(dtype, dtype, dtype)
             with ir.InsertionPoint(block):
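
A note on the indexing-map change, since the diff only shows the mechanics: the old code built a single permutation map shared by all three operands, which assumes every operand has the same rank. The new code gives each operand a minor-identity map over the output's iteration space, so a lower-rank input is indexed by the trailing dimensions only; for example, with `max_rank = 3` a rank-2 operand gets `ir.AffineMap.get_minor_identity(3, 2)`, i.e. `(d0, d1, d2) -> (d1, d2)`, which matches NumPy-style broadcasting over leading dimensions.
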
@@ -72,7 +80,7 @@
     if DEBUG:
         (CWD / "add_module_opt.mlir").write_text(str(module))

-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)


 @fn_cache
@@ -97,7 +105,7 @@
     if DEBUG:
         (CWD / "reshape_module_opt.mlir").write_text(str(module))

-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)


 @fn_cache
@@ -125,26 +133,94 @@
     if DEBUG:
         (CWD / "broadcast_to_module_opt.mlir").write_text(str(module))

-    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=SHARED_LIBS)
+    return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)

+@fn_cache
+def get_convert_module(
+    in_tensor_type: ir.RankedTensorType,
+    out_tensor_type: ir.RankedTensorType,
+):
+    with ir.Location.unknown(ctx):
+        module = ir.Module.create()
+
+        with ir.InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(in_tensor_type)
+            def convert(in_tensor):
+                return sparse_tensor.convert(out_tensor_type, in_tensor)
+
+        convert.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+        if DEBUG:
+            (CWD / "convert_module.mlir").write_text(str(module))
+        pm.run(module.operation)
+        if DEBUG:
+            (CWD / "convert_module_opt.mlir").write_text(str(module))
+
+        return mlir_finch.execution_engine.ExecutionEngine(module, opt_level=OPT_LEVEL, shared_libs=SHARED_LIBS)
+
+
-def add(x1: Array, x2: Array) -> Array:
-    ret_storage_format = x1.format
+def add(x1: Array, x2: Array, /) -> Array:
+    # TODO: Determine output format via autoscheduler
+    ret_storage_format = _determine_format(x1.format, x2.format, dtype=x1.dtype, union=True)
     ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
-    out_tensor_type = ret_storage_format._get_mlir_type(shape=x1.shape)
+    out_tensor_type = ret_storage_format._get_mlir_type(shape=np.broadcast_shapes(x1.shape, x2.shape))

     # TODO: Decide what will be the output tensor_type
     add_module = get_add_module(
         x1._get_mlir_type(),
         x2._get_mlir_type(),
         out_tensor_type=out_tensor_type,
         dtype=x1.dtype,
-        rank=x1.ndim,
     )
     add_module.invoke(
         "add",
         ctypes.pointer(ctypes.pointer(ret_storage)),
         *x1._to_module_arg(),
         *x2._to_module_arg(),
     )
-    return Array(storage=ret_storage, shape=out_tensor_type.shape)
+    return Array(storage=ret_storage, shape=tuple(out_tensor_type.shape))

+def asformat(x: Array, /, format: StorageFormat) -> Array:
+    if x.format == format:
+        return x
+
+    out_tensor_type = format._get_mlir_type(shape=x.shape)
+    ret_storage = format._get_ctypes_type(owns_memory=True)()
+
+    convert_module = get_convert_module(
+        x._get_mlir_type(),
+        out_tensor_type,
+    )
+
+    convert_module.invoke(
+        "convert",
+        ctypes.pointer(ctypes.pointer(ret_storage)),
+        *x._to_module_arg(),
+    )
+
+    return Array(storage=ret_storage, shape=x.shape)
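
Design-wise, `asformat` short-circuits when the source and target formats compare equal and otherwise lowers to a cached, compiled `sparse_tensor.convert` kernel via `get_convert_module`; the shape is carried over unchanged, so only the storage layout of the result differs from `x`.
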


+def reshape(x: Array, /, shape: tuple[int, ...]) -> Array:
+    from ._conversions import _from_numpy
+
+    shape = as_shape(shape)
+    if math.prod(x.shape) != math.prod(shape):
+        raise ValueError(f"`math.prod(x.shape) != math.prod(shape)`, {x.shape=}, {shape=}")
+
+    ret_storage_format = _determine_format(x.format, dtype=x.dtype, union=len(shape) > x.ndim, out_ndim=len(shape))
+    shape_array = _from_numpy(np.asarray(shape, dtype=np.uint64))
+    out_tensor_type = ret_storage_format._get_mlir_type(shape=shape)
+    ret_storage = ret_storage_format._get_ctypes_type(owns_memory=True)()
+
+    reshape_module = get_reshape_module(x._get_mlir_type(), shape_array._get_mlir_type(), out_tensor_type)
+
+    reshape_module.invoke(
+        "reshape",
+        ctypes.pointer(ctypes.pointer(ret_storage)),
+        *x._to_module_arg(),
+        *shape_array._to_module_arg(),
+    )
+
+    return Array(storage=ret_storage, shape=shape)
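
A usage sketch of the restored function, assuming `x` is an `Array` of shape `(6,)` (how to construct one is outside this diff):

    from sparse import mlir_backend as smb

    y = smb.reshape(x, (2, 3))  # ok: math.prod((6,)) == math.prod((2, 3))
    y = smb.reshape(x, [2, 3])  # lists are fine; as_shape canonicalizes them
    y = smb.reshape(x, 5)       # normalized to (5,), then ValueError: sizes differ

Note that the target shape itself is passed to the compiled kernel as a one-dimensional uint64 tensor (`shape_array`), mirroring how `tensor.reshape` takes its shape operand in MLIR.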