From 176a9a2d6ba879013b725312362e49fa1cb360ee Mon Sep 17 00:00:00 2001 From: Vindaar Date: Sun, 16 Jun 2024 18:41:31 +0200 Subject: [PATCH 1/4] add `fuseLoops` macro to fuse loops in Nim Fusing the loops in this context means the equivalent of OpenMP's `collapse` statement, i.e. merging multiple nested loops into a single loop. One can use the `nofuse` argument to indicate a nested loop should _not_ be fused. --- scinim/fuse_loops.nim | 189 ++++++++++++++++++++++++++++++++++++++++++ tests/tFuseLoops.nim | 54 ++++++++++++ 2 files changed, 243 insertions(+) create mode 100644 scinim/fuse_loops.nim create mode 100644 tests/tFuseLoops.nim diff --git a/scinim/fuse_loops.nim b/scinim/fuse_loops.nim new file mode 100644 index 0000000..8e44b4c --- /dev/null +++ b/scinim/fuse_loops.nim @@ -0,0 +1,189 @@ +import std / [macros, options, algorithm] + +type + ForLoop = object + n: NimNode # the actual node + body: NimNode # body of the loop *WITHOUT* any inner loops! + idx: NimNode # the loop index + start: NimNode # start of the loop + stop: NimNode # stop of the loop + +template nofuse*(arg: untyped): untyped = + ## Just a dummy template, which can be easily used to disable fusing of + ## a nested loop + arg + +proc extractBody(n: NimNode): NimNode = + ## Returns the input tree without any possible nested for loops. Nested + ## loops are replaced by `nnkEmpty` nodes to be filled again later in `bodies`. + case n.kind + of nnkForStmt: + if n[1].kind == nnkInfix and n[1][0].strVal == "..<": + result = newEmptyNode() ## Flattened nested loop body will be inserted here + else: + result = n + else: + if n.len > 0: + result = newTree(n.kind) + for ch in n: + let bd = extractBody(ch) + if bd != nil: + result.add bd + else: + result = n + +proc toForLoop(n: NimNode): Option[ForLoop] = + ## Returns a `some(ForLoop)` if the given node is a fuse-able for loop + doAssert n.kind == nnkForStmt + if n[1].kind != nnkInfix: return + if n[1][0].strVal != "..<": + error("Unexpected iterator: " & $n[1].repr & + ". It must be of the form `0 ..< X`.") + if not (n[1][1].kind == nnkIntLit and n[1][1].intVal == 0): + error("Starting iteration index must be 0!") + result = some(ForLoop(n: n, + body: extractBody(n[2]), + idx: n[0], + start: n[1][1], + stop: n[1][2])) + +template addIf(s, opt): untyped = + if opt.isSome: + s.add opt.unsafeGet + +proc extractLoops(n: NimNode): seq[ForLoop] = + ## Extracts (fuse-able) loops from the given Nim node and errors if more than + ## one for loop found at the same level. + case n.kind + of nnkForStmt: + result.addIf toForLoop(n) + result.add extractLoops(n[2]) # go over body + else: + var foundLoops = 0 # counter for number of loops at current body + for ch in n: + let loops = extractLoops(ch) + if loops.len > 0: + result.add loops + inc foundLoops + if foundLoops > 1: + error("Found more than one loop (" & $foundLoops & ") at the level of node: " & + n.repr & ". Please wrap " & "these loops as `nofuse`, i.e. `nofuse(0 ..< X)`") + +proc genFusedLoop(idx: NimNode, stop: NimNode, ompStr = ""): NimNode = + ## Generate either regular or OpenMP for loop + let loopIter = if ompStr.len == 0: + nnkInfix.newTree(ident"..<", + newLit 0, + stop) + else: + nnkCall.newTree(ident"||", + newLit 0, + stop, + newLit ompStr) + result = nnkForStmt.newTree( + idx, + loopIter + ) + +proc calcStop(loops: seq[ForLoop]): NimNode = + ## Returns `N * T * U * ...` expression where the indices are + ## the stop indices of the loops to be fused. + case loops.len + of 0: doAssert false, "Must not happen" + of 1: result = loops[0].stop + else: + var ml = loops.reversed # want last elements first + let x = ml.pop + result = nnkInfix.newTree(ident"*", x.stop, + calcStop(ml.reversed)) + +proc modOrDiv(prefix, suffix: NimNode, isDiv: bool): NimNode = + if isDiv: + result = quote do: + `prefix` div `suffix` + else: + result = quote do: + `prefix` mod `suffix` + +proc asLet(v, val: NimNode): NimNode = + result = quote do: + let `v` = `val` + +proc genPrelude(idx: NimNode, loops: seq[ForLoop]): NimNode = + ## The basic algorithm for generating the correct index for fused loops is + ## + ## Notation: + ## `i` = Loop index of single remaining outer loop + ## `N_i` = Stopping index (-1) of the inner loop `i` + ## `n` = Total number of nested loops + ## + ## Whichever is easiest to read for you: + ## + ## `let i0 = i div (N_0 * N_1 ... N_n)` + ## `let i1 = (i mod (N_0 * N_1 ... N_n)) div (N_1 * N_2 * ... N_n)` + ## `let i2 = ((i mod (N_0 * N_1 ... N_n)) mod (N_1 * N_2 * ... N_n)) div (N_2 * ... * N_n)` + ## ... + ## + ## ... or + ## + ## `let i0 = i div Π_i=0^n N_i` + ## `let i1 = (i mod Π_i=0^n N_i) div Π_i=1^n N_i` + ## `let i2 = ((i mod Π_i=0^n N_i) mod Π_i=1^n N_i) div Π_i=2^n N_i` + ## + ## ...or + ## + ## `let i0 = Idx div [Product of remaining N-1 loops]` + ## `let i1 = (Idx mod [Product of remaining loops]) div [Product of remaining N-2 loops]` + ## `let i2 = (Idx mod [Product of remaining loops]) mod [Product of remaining N-2 loops]` + result = newStmtList() + var prefix = idx + var ml = loops.reversed + var lIdx = ml.pop # drop first element + var suffix = ml.calcStop() + while ml.len > 0: + result.add asLet(lIdx.idx, modOrDiv(prefix, suffix, isDiv = true)) + lIdx = ml.pop # get next loop index & adjust remaining loops + # now adjust prefix and suffix + prefix = modOrDiv(prefix, suffix, isDiv = false) + if ml.len > 0: # adjust suffix + suffix = ml.calcStop() + else: # simply add last 'prefix' + result.add asLet(lIdx.idx, prefix) + +proc bodies(loops: seq[ForLoop]): NimNode = + ## Concatenates all loop bodies, by placing the next loop into the + ## `nnkEmpty` node of the current node + var ml = loops.reversed + #echo ml.repr + var cur = ml.pop + result = cur.body + for i in 0 ..< result.len: + let ch = result[i] + if ch.kind == nnkEmpty: + # insert next loop + result[i] = bodies(ml.reversed) # revert order again + break # there can only be a single `nnkEmpty` (multiple loops not allowed, + # yields CT error) + +proc fuseLoopImpl(ompStr: string, body: NimNode): NimNode = + # 1. extract all loops from the body + let loops = extractLoops(body) + # 2. generate identifier for the final loop + let idx = genSym(nskForVar, "idx") + # 3. generate the fused outer loop + result = genFusedLoop(idx, calcStop(loops), ompStr) + # 4. generate final loop body by... + var loopBody = newStmtList() + # 4a. generate prelude of loop variables of original loops + loopBody.add genPrelude(idx, loops) # gen code to produce the old loop variables + # 4b. insert old loop bodies into respective positions + loopBody.add bodies(loops) + result.add loopBody + when defined(DebugFuseLoop): + echo result.repr + +macro fuseLoops*(body: untyped): untyped = + result = fuseLoopImpl("", body) + +macro fuseLoops*(ompStr: untyped{lit}, body: untyped): untyped = + result = fuseLoopImpl(ompStr.strVal, body) diff --git a/tests/tFuseLoops.nim b/tests/tFuseLoops.nim new file mode 100644 index 0000000..474da11 --- /dev/null +++ b/tests/tFuseLoops.nim @@ -0,0 +1,54 @@ +import ../scinim/fuse_loops +import std / unittest + +suite "fuseLoops": + test "Compiles test for different `fuseLoops` setups": + const N = 5 + const T = 10 + const X = 3 + + ## XXX: These should probably become proper tests. :) + + fuseLoops: + for i in 0 ..< N: + let x = i * 2 + for j in 0 ..< T: + let z = x * j + echo i, j, x, z + echo x + + fuseLoops: + for i in 0 ..< N: + let x = i * 2 + for j in 0 ..< T: + let z = x * j + echo i, j, x, z + for k in nofuse(0 ..< T): + echo k + echo x + + fuseLoops("parallel for"): + for i in 0 ..< N: + let x = i * 2 + for j in 0 ..< T: + let z = x * j + for k in 0 ..< X: + echo i, j, k, x, z + echo x + + ## The following raises a CT error + when compiles(( + fuseLoops: + for i in 0 ..< N: + let x = i * 2 + var zsum = 0 + for j in 0 ..< T: + let z = x * j + zsum += z + echo i, x, z + echo x + for j in 0 ..< 2 * T: + zsum += j + echo zsum + )): + doAssert false From 35152e0b2462bcf7b91790f0f5cd0a51c1f02809 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Sun, 16 Jun 2024 18:51:44 +0200 Subject: [PATCH 2/4] export `fuseLoops` by default in `scinim` --- scinim.nim | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scinim.nim b/scinim.nim index 4111ad4..4dbe9cd 100644 --- a/scinim.nim +++ b/scinim.nim @@ -3,3 +3,6 @@ export signals import ./scinim/numpyarrays export numpyarrays + +import ./scinim/fuse_loops +export fuse_loops From fd52267d863fd5027ff7831c8d94cf88e41d24b1 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Sun, 16 Jun 2024 19:02:25 +0200 Subject: [PATCH 3/4] add minor docstring and add note about OpenMP compilation --- scinim/fuse_loops.nim | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/scinim/fuse_loops.nim b/scinim/fuse_loops.nim index 8e44b4c..3355880 100644 --- a/scinim/fuse_loops.nim +++ b/scinim/fuse_loops.nim @@ -183,7 +183,24 @@ proc fuseLoopImpl(ompStr: string, body: NimNode): NimNode = echo result.repr macro fuseLoops*(body: untyped): untyped = + ## Fuses all loops inside the body of the macro, unless they are annotated with + ## `nofuse`. result = fuseLoopImpl("", body) macro fuseLoops*(ompStr: untyped{lit}, body: untyped): untyped = + ## Fuses all loops inside the body of the macro, unless they are annotated with + ## `nofuse`. + ## + ## This version supports handing a string to be passed to OpenMP, i.e. + ## `fuseLoops("parallelFor"): body` + ## + ## Note: To utilize OpenMP, you may have to compile with + ## `--passC:"-fopenmp" --passL:"-lgomp"` + ## (at least for GCC. For Clang the commands differ slightly I believe). + ## + ## There is also a chance you either have to compile with + ## `--exceptions:quirky` + ## or using the C++ backend, due to the C backend producing `goto` statements + ## inside the loops, which lead to C compiler errors when combined with + ## OpenMP. result = fuseLoopImpl(ompStr.strVal, body) From dbb79a948a26529dad050b2dc4747cb508a8799a Mon Sep 17 00:00:00 2001 From: Vindaar Date: Sun, 16 Jun 2024 19:08:01 +0200 Subject: [PATCH 4/4] update arraymancer dependency to 0.7.32 for SomeSets fix --- scinim.nimble | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scinim.nimble b/scinim.nimble index de0fc18..f443502 100644 --- a/scinim.nimble +++ b/scinim.nimble @@ -10,7 +10,7 @@ backend = "cpp" # Dependencies requires "nim >= 1.6.0" requires "threading" -requires "arraymancer >= 0.7.31" +requires "arraymancer >= 0.7.32" requires "polynumeric >= 0.2.0" requires "nimpy >= 0.2.0"