Skip to content

Commit

Permalink
LICM scf.for
Browse files Browse the repository at this point in the history
  • Loading branch information
RRavikiran66 committed Aug 16, 2024
1 parent 326e8f6 commit 81b5abc
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ func.func public @ddot(%arg0: memref<8xf64>, %arg1: memref<8xf64>, %arg2: memref
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
scf.for %arg3 = %c0 to %c8 step %c1 {
%0 = memref.load %arg0[%arg3] : memref<8xf64>
%1 = memref.load %arg1[%arg3] : memref<8xf64>
%2 = memref.load %arg2[] : memref<f64>
%3 = arith.mulf %0, %1 : f64
%4 = arith.addf %2, %3 : f64
memref.store %4, %arg2[] : memref<f64>
%a0 = memref.load %arg0[%arg3] : memref<8xf64>
%a1 = memref.load %arg1[%arg3] : memref<8xf64>
%a2 = memref.load %arg2[] : memref<f64>
%a3 = arith.mulf %a0, %a1 : f64
%a4 = arith.addf %a2, %a3 : f64
memref.store %a4, %arg2[] : memref<f64>
}
return %arg2 : memref<f64>
}
Expand Down Expand Up @@ -129,4 +129,17 @@ func.func @invariant_loop_dialect() {
// CHECK-NEXT: arith.addf

return
}
}

func.func @speculate_tensor_dim_unknown_rank_known_dim(
// CHECK-LABEL: @speculate_tensor_dim_unknown_rank_known_dim
%t: tensor<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
%c0 = arith.constant 0 : index
// CHECK: scf.for
// CHECK-NEXT: tensor.dim
scf.for %i = %lb to %ub step %step {
%val = tensor.dim %t, %c0 : tensor<*xf32>
}

return
}
208 changes: 122 additions & 86 deletions xdsl/transforms/loop_invariant_code_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@
from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, scf, memref
from xdsl.ir import SSAValue, Operation, Block, Region
from xdsl.traits import IsTerminator, NoTerminator, OpTrait, OpTraitInvT
from xdsl.traits import NoTerminator, OpTrait, OpTraitInvT
from xdsl.traits import (
IsTerminator,
IsolatedFromAbove,
IsTerminator,
MemoryEffectKind,
get_effects,
Pure,
RecursiveMemoryEffect
# is_side_effect_free,
# only_has_effect,
)
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
Expand All @@ -30,113 +41,138 @@
# - the op and none of its contained operations depend on values inside of the
# loop (by means of calling definedOutside).
# - the op has no side-effects.
def canBeHoisted(op: Operation, condition: Callable[[SSAValue], bool]) -> bool | None:
def canBeHoisted(op: Operation, region_target: Region) -> bool | None:
# Do not move terminators.
if op.has_trait(IsTerminator):
return False

# Walk the nested operations and check that all used values are either
# defined outside of the loop or in a nested region, but not at the level of
# the loop body.
for child in op.walk():
for operand in child.operands:
for own in operand.owner.walk():
if not isinstance(own, scf.For):
if op.is_ancestor(operand.owner):
print("continue")
continue
if region_target.is_ancestor(own):
return False
return True


# def can_Be_Hoisted(op: Operation, region_target: Region) -> bool | None:

# if op.has_trait(IsTerminator):
# return False

# return any(
# op.is_ancestor(operand.owner) or isDefinedOutsideOfRegoin(op, region_target)
# for child in op.walk()
# for operand in child.operands
# )

return any(
op.is_ancestor(operand.owner) or condition(operand)
for child in op.walk()
for operand in child.operands
)

def move_Out_of_Region(op: Operation, region: Region):
print("hoisted op: ", op.name)

def isDefinedOutsideOfRegoin(value: SSAValue, region: Region) -> bool | None:
return not region.is_ancestor(value.owner)
def isDefinedOutsideOfRegoin(op: Operation, region: Region) -> bool | None:
return not op.is_ancestor(region)

def can_be_hoisted_with_value_check(op: Operation, defined_outside: Callable[[SSAValue], bool])-> bool | None:
return canBeHoisted(op, defined_outside(op))
def isMemoryEffectFree(op: Operation) -> bool | None:
if not op.has_trait(Pure):
return False
# Have a close look if the op might have side effects.
if not op.has_trait(RecursiveMemoryEffect):
return True
elif not op.has_trait(RecursiveMemoryEffect):
return False

for regions in op.regions:
for ops in regions.ops:
if not ops.has_trait(Pure):
return False

def moveLoopInvariantCode(
regions : Sequence[Region],
isDefinedOutsideRegion: Callable[[Region], bool], #function_ref<bool(Value, Region *)> isDefinedOutsideRegion
shouldMoveOutofRegion: Callable[[Operation], bool], #function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion
moveOutofRegion: Callable[[Operation], Region]) -> int | None: #function_ref<void(Operation *, Region *)> moveOutOfRegion
def isSpeculatable(op: Operation) -> bool | None:
if not op.has_trait(Pure):
return False
# Have a close look if the op might have side effects.
if not op.has_trait(RecursiveMemoryEffect):
return True
elif not op.has_trait(RecursiveMemoryEffect):
return False

numMoved = 0
worklist : list[Operation] = []

for region in regions: #iter thorugh the regions
print("Original loop: ", region.parent_node)
for op in region.block.ops:
print("Operation: ", op)
worklist.append(op)

definedOutside = isDefinedOutsideRegion(region)

while not worklist:
oper = worklist.pop()
#Skip ops that have already been moved. Check if the op can be hoisted.
if oper.parent_region() != region:
continue
print("Check op: ", oper)

if not(shouldMoveOutofRegion(oper, region) or not(canBeHoisted(oper, definedOutside))):
continue
print("Moving loop-invariant op: ", oper)
moveOutofRegion(oper, region)
numMoved = numMoved + 1

for user in oper.results[0].uses:
if user.operation.parent_region is region:
worklist.append(user.operation)

return numMoved
for regions in op.regions:
for ops in regions.ops:
if not ops.has_trait(Pure):
return False


class LoopsInvariantCodeMotion(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter) -> None:
for region in op.regions:
for block in region.blocks:
for lp in block.walk():
for oper in lp.operands:
print("lp: ",lp)
print("region: ", region)
print("operands: ", oper._name)
print("isDefinedOutside: ", isDefinedOutsideOfRegoin(oper, region))

if any(isinstance(ha, scf.For) for ha in op.body.walk()):
return
numMoved = 0
worklist : list[Operation] = []
for region in op.regions: #iter thorugh the regions
for ops in region.block.ops:
worklist.append(ops)
# print("worklist: ", not worklist)
while worklist:
# print("entered while loop")
oper = worklist.pop()
#Skip ops that have already been moved. Check if the op can be hoisted.
if oper.parent_region() != region:
continue
if (not(isMemoryEffectFree(oper) and isSpeculatable(oper)) or not(canBeHoisted(oper, region))):
continue
print("Moving loop-invariant op: ", oper)
move_Out_of_Region(oper, region)
numMoved = numMoved + 1
if not isinstance(oper, scf.Yield):
for user in oper.results[0].uses:
if user.operation.parent_region is region:
worklist.append(user.operation)

print(numMoved)
#####################dumb implementation###################
#works on only inner most scf.for loop for now but something that works
#dependent_stack holds all the operations that dependent and only this instructions should be in the loop
#hoistable instructions are temp_stack - dependent_stack
temp_stack : list[Operation] = []
dependent_stack : list[Operation] = []
hoist_instruction : list[Operation] = []

#walking to the inner most loop
if any(isinstance(ha, scf.For) for ha in op.body.walk()):
return
print(op.parent_block())
for i in op.parent_block().ops:
if isinstance(i, scf.For):
loop_variable_match = re.search(r'scf\.for\s+%(\w+)\s*=', str(i))
loop_variable = loop_variable_match.group(1) if loop_variable_match else None
print("loop variable: ", loop_variable)
# temp_stack : list[Operation] = []
# dependent_stack : list[Operation] = []
# hoist_instruction : list[Operation] = []

# #walking to the inner most loop
# if any(isinstance(ha, scf.For) for ha in op.body.walk()):
# return
# print(op.parent_block())
# for i in op.parent_block().ops:
# if isinstance(i, scf.For):
# loop_variable_match = re.search(r'scf\.for\s+%(\w+)\s*=', str(i))
# loop_variable = loop_variable_match.group(1) if loop_variable_match else None
# print("loop variable: ", loop_variable)

for iter in op.body.block.walk():
if str(iter).__contains__(loop_variable):
temp_stack.append(iter)
dependent_stack.append(iter)

while temp_stack:
item = temp_stack.pop()
if any(l is item for l in op.body.walk()):
print("item: ", item)
if isinstance(item, scf.Yield) or isinstance(item, scf.For) or isinstance(item, memref.Store):
continue
for user in item.results[0].uses:
print("Dependent Instructions are used: ", user.operation)
temp_stack.append(user.operation)
if not dependent_stack.__contains__(user.operation):
dependent_stack.append(user.operation)

for ins in dependent_stack:
print(ins)
# for iter in op.body.block.walk():
# if str(iter).__contains__(loop_variable):
# temp_stack.append(iter)
# dependent_stack.append(iter)

# while temp_stack:
# item = temp_stack.pop()
# if any(l is item for l in op.body.walk()):
# print("item: ", item)
# if isinstance(item, scf.Yield) or isinstance(item, scf.For) or isinstance(item, memref.Store):
# continue
# for user in item.results[0].uses:
# print("Dependent Instructions are used: ", user.operation)
# temp_stack.append(user.operation)
# if not dependent_stack.__contains__(user.operation):
# dependent_stack.append(user.operation)

# for ins in dependent_stack:
# print(ins)

class ScfForLoopInavarintCodeMotionPass(ModulePass):
"""
Expand Down

0 comments on commit 81b5abc

Please sign in to comment.