Skip to content

Commit

Permalink
transforms: (stencil-tensorize-z-dimension) use DenseIntOrFPElementsA…
Browse files Browse the repository at this point in the history
…ttr constructor

make indextype dense

further wip

wip

wip: make indextype dense

further wip

wip

typo

normalize integers

s

fix print_csl

fixes

fix pyright issues

s

resolve reviewer comments

update to other pr

revert stencil pass change
  • Loading branch information
jorendumoulin committed Jan 7, 2025
1 parent a0d424d commit f9d4a14
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions xdsl/transforms/experimental/stencil_tensorize_z_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
ArrayAttr,
ContainerType,
DenseIntOrFPElementsAttr,
FloatAttr,
IndexType,
IntAttr,
IntegerType,
ModuleOp,
ShapedType,
TensorType,
Expand Down Expand Up @@ -193,7 +196,7 @@ def match_and_rewrite(
@staticmethod
def _rewrite_scalar_operand(
scalar_op: SSAValue,
dest_typ: TensorType[Attribute],
dest_typ: TensorType[IndexType | IntegerType | AnyFloat],
op: FloatingPointLikeBinaryOperation,
rewriter: PatternRewriter,
) -> SSAValue:
Expand All @@ -203,8 +206,10 @@ def _rewrite_scalar_operand(
If it is not a constant, create an empty tensor and `linalg.fill` it with the scalar value.
"""
if isinstance(scalar_op, OpResult) and isinstance(scalar_op.op, ConstantOp):
assert isinstance(float_attr := scalar_op.op.value, FloatAttr)
scalar_value = float_attr.value.data
tens_const = ConstantOp(
DenseIntOrFPElementsAttr([dest_typ, ArrayAttr([scalar_op.op.value])])
DenseIntOrFPElementsAttr.from_list(dest_typ, [scalar_value])
)
rewriter.insert_op(tens_const, InsertPoint.before(scalar_op.op))
return tens_const.result
Expand Down Expand Up @@ -265,7 +270,9 @@ def is_tensorized(
return len(typ.get_shape()) == 2 and isinstance(typ.get_element_type(), TensorType)


def is_tensor(typ: Attribute) -> TypeGuard[TensorType[Attribute]]:
def is_tensor(
typ: Attribute,
) -> TypeGuard[TensorType[IndexType | IntegerType | AnyFloat]]:
return isinstance(typ, TensorType)


Expand Down

0 comments on commit f9d4a14

Please sign in to comment.