Skip to content

Commit

Permalink
Refactor statements copying in loop unrolling
Browse files Browse the repository at this point in the history
  • Loading branch information
sim642 committed Feb 21, 2024
1 parent bdebd1b commit 2fb6b61
Showing 1 changed file with 12 additions and 24 deletions.
36 changes: 12 additions & 24 deletions src/util/loopUnrolling.ml
Original file line number Diff line number Diff line change
Expand Up @@ -390,27 +390,13 @@ end
Also assigns fresh names to all labels and patches gotos for labels appearing in the current
fragment to their new name
*)
class copyandPatchLabelsVisitor(loopEnd,currentIterationEnd) = object
class copyandPatchLabelsVisitor(loopEnd, currentIterationEnd, gotos) = object
inherit nopCilVisitor

val mutable depth = 0
val mutable loopNestingDepth = 0

val gotos = StatementHashTable.create 20

method! vstmt s =
let after x =
depth <- depth-1;
if depth = 0 then
(* the labels can only be patched once the entire part of the AST we want has been transformed, and *)
(* we know all lables appear in the hash table *)
let patchLabelsVisitor = new patchLabelsGotosVisitor(StatementHashTable.find_opt gotos) in
let x = visitCilStmt patchLabelsVisitor x in
StatementHashTable.clear gotos;
x
else
x
in
let after x = x in
let rename_labels sn =
let new_labels = List.map (function Label(str,loc,b) -> Label (Cil.freshLabel str,loc,b) | x -> x) sn.labels in
(* this makes new physical copy*)
Expand All @@ -421,7 +407,6 @@ class copyandPatchLabelsVisitor(loopEnd,currentIterationEnd) = object
StatementHashTable.add gotos s new_s;
new_s
in
depth <- depth+1;
match s.skind with
| Continue loc ->
if loopNestingDepth = 0 then
Expand All @@ -440,6 +425,15 @@ class copyandPatchLabelsVisitor(loopEnd,currentIterationEnd) = object
| _ -> ChangeDoChildrenPost(rename_labels s, after)
end

let copy_and_patch_labels break_target current_continue_target stmts =
let gotos = StatementHashTable.create 20 in
let patcher = new copyandPatchLabelsVisitor (break_target, current_continue_target, gotos) in
let stmts' = List.map (visitCilStmt patcher) stmts in
(* the labels can only be patched once the entire part of the AST we want has been transformed, and *)
(* we know all lables appear in the hash table *)
let patchLabelsVisitor = new patchLabelsGotosVisitor(StatementHashTable.find_opt gotos) in
List.map (visitCilStmt patchLabelsVisitor) stmts'

class loopUnrollingVisitor(func, totalLoops) = object
(* Labels are simply handled by giving them a fresh name. Jumps coming from outside will still always go to the original label! *)
inherit nopCilVisitor
Expand All @@ -457,13 +451,7 @@ class loopUnrollingVisitor(func, totalLoops) = object
let copies = List.init factor (fun i ->
(* continues should go to the next unrolling *)
let current_continue_target = { (Cil.mkEmptyStmt ()) with labels = [Label (Cil.freshLabel ("loop_continue_" ^ (string_of_int i)),loc, false)]} in
let patcher = new copyandPatchLabelsVisitor (break_target, current_continue_target) in
let one_copy = visitCilStmt patcher (mkStmt (Block (mkBlock b.bstmts))) in
let one_copy_stmts = (* TODO: avoid this nonsense, directly visiting only b.bstmts breaks some continue labels for some reason *)
match one_copy.skind with
| Block b -> b.bstmts
| _ -> assert false
in
let one_copy_stmts = copy_and_patch_labels break_target current_continue_target b.bstmts in
one_copy_stmts @ [current_continue_target]
)
in
Expand Down

0 comments on commit 2fb6b61

Please sign in to comment.