From 2fb6b618ca557122c250178791c2056ae7f178fb Mon Sep 17 00:00:00 2001 From: Simmo Saan Date: Wed, 21 Feb 2024 11:47:54 +0200 Subject: [PATCH] Refactor statements copying in loop unrolling --- src/util/loopUnrolling.ml | 36 ++++++++++++------------------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/src/util/loopUnrolling.ml b/src/util/loopUnrolling.ml index abdff8fa57..d5efd0d937 100644 --- a/src/util/loopUnrolling.ml +++ b/src/util/loopUnrolling.ml @@ -390,27 +390,13 @@ end Also assigns fresh names to all labels and patches gotos for labels appearing in the current fragment to their new name *) -class copyandPatchLabelsVisitor(loopEnd,currentIterationEnd) = object +class copyandPatchLabelsVisitor(loopEnd, currentIterationEnd, gotos) = object inherit nopCilVisitor - val mutable depth = 0 val mutable loopNestingDepth = 0 - val gotos = StatementHashTable.create 20 - method! vstmt s = - let after x = - depth <- depth-1; - if depth = 0 then - (* the labels can only be patched once the entire part of the AST we want has been transformed, and *) - (* we know all lables appear in the hash table *) - let patchLabelsVisitor = new patchLabelsGotosVisitor(StatementHashTable.find_opt gotos) in - let x = visitCilStmt patchLabelsVisitor x in - StatementHashTable.clear gotos; - x - else - x - in + let after x = x in let rename_labels sn = let new_labels = List.map (function Label(str,loc,b) -> Label (Cil.freshLabel str,loc,b) | x -> x) sn.labels in (* this makes new physical copy*) @@ -421,7 +407,6 @@ class copyandPatchLabelsVisitor(loopEnd,currentIterationEnd) = object StatementHashTable.add gotos s new_s; new_s in - depth <- depth+1; match s.skind with | Continue loc -> if loopNestingDepth = 0 then @@ -440,6 +425,15 @@ class copyandPatchLabelsVisitor(loopEnd,currentIterationEnd) = object | _ -> ChangeDoChildrenPost(rename_labels s, after) end +let copy_and_patch_labels break_target current_continue_target stmts = + let gotos = StatementHashTable.create 20 in + let patcher = new copyandPatchLabelsVisitor (break_target, current_continue_target, gotos) in + let stmts' = List.map (visitCilStmt patcher) stmts in + (* the labels can only be patched once the entire part of the AST we want has been transformed, and *) + (* we know all lables appear in the hash table *) + let patchLabelsVisitor = new patchLabelsGotosVisitor(StatementHashTable.find_opt gotos) in + List.map (visitCilStmt patchLabelsVisitor) stmts' + class loopUnrollingVisitor(func, totalLoops) = object (* Labels are simply handled by giving them a fresh name. Jumps coming from outside will still always go to the original label! *) inherit nopCilVisitor @@ -457,13 +451,7 @@ class loopUnrollingVisitor(func, totalLoops) = object let copies = List.init factor (fun i -> (* continues should go to the next unrolling *) let current_continue_target = { (Cil.mkEmptyStmt ()) with labels = [Label (Cil.freshLabel ("loop_continue_" ^ (string_of_int i)),loc, false)]} in - let patcher = new copyandPatchLabelsVisitor (break_target, current_continue_target) in - let one_copy = visitCilStmt patcher (mkStmt (Block (mkBlock b.bstmts))) in - let one_copy_stmts = (* TODO: avoid this nonsense, directly visiting only b.bstmts breaks some continue labels for some reason *) - match one_copy.skind with - | Block b -> b.bstmts - | _ -> assert false - in + let one_copy_stmts = copy_and_patch_labels break_target current_continue_target b.bstmts in one_copy_stmts @ [current_continue_target] ) in