Skip to content

Commit

Permalink
Fixed linting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
RRavikiran66 committed Aug 19, 2024
1 parent 81b5abc commit 44fedab
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 45 deletions.
1 change: 1 addition & 0 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def get_licm():
from xdsl.transforms import loop_invariant_code_motion

return loop_invariant_code_motion.ScfForLoopInavarintCodeMotionPass

return {
"arith-add-fastmath": get_arith_add_fastmath,
"loop-hoist-memref": get_loop_hoist_memref,
Expand Down
166 changes: 121 additions & 45 deletions xdsl/transforms/loop_invariant_code_motion.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,22 @@
from collections.abc import Iterator
from typing import Any, Sequence, cast
import re
import queue
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from collections.abc import Callable

from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, scf, memref
from xdsl.ir import SSAValue, Operation, Block, Region
from xdsl.traits import NoTerminator, OpTrait, OpTraitInvT
from xdsl.traits import (
IsTerminator,
IsolatedFromAbove,
IsTerminator,
MemoryEffectKind,
get_effects,
Pure,
RecursiveMemoryEffect
# is_side_effect_free,
# only_has_effect,
)
from xdsl.dialects import builtin, scf
from xdsl.ir import Operation, Region, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.traits import (
IsTerminator,
Pure,
RecursiveMemoryEffect,
# is_side_effect_free,
# only_has_effect,
)

# This pass hoists operation that are invariant to the loops.
#
Expand All @@ -33,51 +25,80 @@
#
# An operation is loop-invariant if it depends only of values defined outside of the loop. LICM moves these operations out of the loop body so that they are not computed more than once.
#
# for i in range(x, N, M): for i in range(x, N, M):
# for i in range(x, N, M): for i in range(x, N, M):
# for j in range(0, M, K): ----> c[i]= A[1] + b[1]
# c[i]=A[1]+b[1]


# Checks whether the given op can be hoisted by checking that
# - the op and none of its contained operations depend on values inside of the
# loop (by means of calling definedOutside).
# - the op has no side-effects.
def canBeHoisted(op: Operation, region_target: Region) -> bool | None:
# Do not move terminators.
print("enter can be hoisted?")
if op.has_trait(IsTerminator):
print("terminated")
return False

for ops in op.walk():
print("op12s: ", ops)

# Walk the nested operations and check that all used values are either
# defined outside of the loop or in a nested region, but not at the level of
# the loop body.
for child in op.walk():
for operand in child.operands:
for own in operand.owner.walk():
if not isinstance(own, scf.For):
if op.is_ancestor(operand.owner):
print("continue")
continue
print("operand owner: ", own)
print("ancester?: ", region_target.is_ancestor(own))
# if op.is_ancestor(operand.owner):
# print("continue")
# continue
if region_target.is_ancestor(own):
return False
print("owner and operand: ", op.is_ancestor(own))
print("region target: ", region_target.is_ancestor(own))
print(
"is defined outside?: ",
isDefinedOutsideOfRegoin(own, region_target),
)
return False
return True

# return any(
# op.is_ancestor(operand.owner) or condition(operand)
# for child in op.walk()
# for operand in child.operands
# )


def can_Be_Hoisted(op: Operation, region_target: Region) -> bool | None:

# def can_Be_Hoisted(op: Operation, region_target: Region) -> bool | None:
if op.has_trait(IsTerminator):
return False

return any(
op.is_ancestor(operand.owner) or isDefinedOutsideOfRegoin(op, region_target)
for child in op.walk()
for operand in child.operands
)

# if op.has_trait(IsTerminator):
# return False

# return any(
# op.is_ancestor(operand.owner) or isDefinedOutsideOfRegoin(op, region_target)
# for child in op.walk()
# for operand in child.operands
# )

def move_Out_of_Region(op: Operation, region: Region):
print("hoisted op: ", op.name)


def isDefinedOutsideOfRegoin(op: Operation, region: Region) -> bool | None:
return not op.is_ancestor(region)


def can_be_hoisted_with_value_check(
op: Operation, defined_outside: Callable[[SSAValue], bool]
) -> bool | None:
return canBeHoisted(op, defined_outside(op))

Check failure on line 99 in xdsl/transforms/loop_invariant_code_motion.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Argument of type "bool" cannot be assigned to parameter "region_target" of type "Region" in function "canBeHoisted"   "bool" is incompatible with "Region" (reportGeneralTypeIssues)

Check failure on line 99 in xdsl/transforms/loop_invariant_code_motion.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Argument of type "Operation" cannot be assigned to parameter of type "SSAValue"   "Operation" is incompatible with "SSAValue" (reportGeneralTypeIssues)

Check failure on line 99 in xdsl/transforms/loop_invariant_code_motion.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Argument of type "bool" cannot be assigned to parameter "region_target" of type "Region" in function "canBeHoisted"   "bool" is incompatible with "Region" (reportGeneralTypeIssues)

Check failure on line 99 in xdsl/transforms/loop_invariant_code_motion.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Argument of type "Operation" cannot be assigned to parameter of type "SSAValue"   "Operation" is incompatible with "SSAValue" (reportGeneralTypeIssues)


def isMemoryEffectFree(op: Operation) -> bool | None:
if not op.has_trait(Pure):
return False
Expand All @@ -86,12 +107,13 @@ def isMemoryEffectFree(op: Operation) -> bool | None:
return True
elif not op.has_trait(RecursiveMemoryEffect):
return False

for regions in op.regions:
for ops in regions.ops:
if not ops.has_trait(Pure):
return False


def isSpeculatable(op: Operation) -> bool | None:
if not op.has_trait(Pure):
return False
Expand All @@ -100,32 +122,85 @@ def isSpeculatable(op: Operation) -> bool | None:
return True
elif not op.has_trait(RecursiveMemoryEffect):
return False

for regions in op.regions:
for ops in regions.ops:
if not ops.has_trait(Pure):
return False
return False


# def moveLoopInvariantCode(
# regions : Sequence[Region],
# isDefinedOutsideRegion: Callable[[Region], bool], #function_ref<bool(Value, Region *)> isDefinedOutsideRegion
# shouldMoveOutofRegion: Callable[[Operation], bool], #function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion
# moveOutofRegion: Callable[[Operation], Region]) -> int | None: #function_ref<void(Operation *, Region *)> moveOutOfRegion

# numMoved = 0
# worklist : list[Operation] = []

# for region in regions: #iter thorugh the regions
# print("Original loops: ", region.parent_node)
# for op in region.block.ops:
# print("Operation: ", op)
# worklist.append(op)
# while not worklist:
# oper = worklist.pop()
# #Skip ops that have already been moved. Check if the op can be hoisted.
# if oper.parent_region() != region:
# continue
# print("Check op: ", oper)
# print("is memory affect free: ", isMemoryEffectFree(oper))
# print("can be hoisted?: ", canBeHoisted(oper, region))
# if not(isMemoryEffectFree(oper) and isSpeculatable(oper)) or not(canBeHoisted(oper, region)):
# continue
# print("Moving loop-invariant op: ", oper)
# move_Out_of_Region(oper, region)
# numMoved = numMoved + 1

# for user in oper.results[0].uses:
# if user.operation.parent_region is region:
# worklist.append(user.operation)

# return numMoved


class LoopsInvariantCodeMotion(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter) -> None:
# numMoved = 0
# worklist : list[Operation] = []
# for region in op.regions:
# print("Original loop:", region.parent_op())
# for block in region.ops:
# for lp in block.walk():
# for oper in lp.operands:
# print("lp: ",lp)
# print("region: ", region)
# print("operands: ", oper._name)
# print("isDefinedOutside: ", isDefinedOutsideOfRegoin(oper, region))

if any(isinstance(ha, scf.For) for ha in op.body.walk()):
return
numMoved = 0
worklist : list[Operation] = []
for region in op.regions: #iter thorugh the regions
worklist: list[Operation] = []
for region in op.regions: # iter thorugh the regions
print("Original loops: ", region.parent_node)
for ops in region.block.ops:
print("Operation: ", ops)
worklist.append(ops)
# print("worklist: ", not worklist)
while worklist:
# print("entered while loop")
oper = worklist.pop()
#Skip ops that have already been moved. Check if the op can be hoisted.
# Skip ops that have already been moved. Check if the op can be hoisted.
if oper.parent_region() != region:
continue
if (not(isMemoryEffectFree(oper) and isSpeculatable(oper)) or not(canBeHoisted(oper, region))):
print("Check op: ", oper)
print("is memory affect free: ", isMemoryEffectFree(oper))
print("can be hoisted?: ", canBeHoisted(oper, region))
if not (isMemoryEffectFree(oper) and isSpeculatable(oper)) or not (
canBeHoisted(oper, region)
):
continue
print("Moving loop-invariant op: ", oper)
move_Out_of_Region(oper, region)
Expand All @@ -135,11 +210,11 @@ def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter) -> None:
if user.operation.parent_region is region:
worklist.append(user.operation)

print(numMoved)
print(numMoved)
#####################dumb implementation###################
#works on only inner most scf.for loop for now but something that works
#dependent_stack holds all the operations that dependent and only this instructions should be in the loop
#hoistable instructions are temp_stack - dependent_stack
# works on only inner most scf.for loop for now but something that works
# dependent_stack holds all the operations that dependent and only this instructions should be in the loop
# hoistable instructions are temp_stack - dependent_stack
# temp_stack : list[Operation] = []
# dependent_stack : list[Operation] = []
# hoist_instruction : list[Operation] = []
Expand All @@ -153,7 +228,7 @@ def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter) -> None:
# loop_variable_match = re.search(r'scf\.for\s+%(\w+)\s*=', str(i))
# loop_variable = loop_variable_match.group(1) if loop_variable_match else None
# print("loop variable: ", loop_variable)

# for iter in op.body.block.walk():
# if str(iter).__contains__(loop_variable):
# temp_stack.append(iter)
Expand All @@ -174,6 +249,7 @@ def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter) -> None:
# for ins in dependent_stack:
# print(ins)


class ScfForLoopInavarintCodeMotionPass(ModulePass):
"""
Folds perfect loop nests if they can be represented with a single loop.
Expand Down

0 comments on commit 44fedab

Please sign in to comment.