From 6bf24e991290385615ecf7da607ff430bdbd3dbf Mon Sep 17 00:00:00 2001 From: Yuriy Glukhov Date: Thu, 19 Sep 2024 11:58:04 +0200 Subject: [PATCH 1/2] Revert "disable closure iterator changes in #23787 unless `-d:nimOptIters` is enabled (#24108)" This reverts commit 22d2cf217597468ace8ba540d6990b1f6d8a816a. --- compiler/closureiters.nim | 128 ++++-------------- compiler/lambdalifting.nim | 36 +---- compiler/transf.nim | 16 +-- tests/destructor/tuse_ownedref_after_move.nim | 2 +- tests/iter/tyieldintry.nim | 4 +- 5 files changed, 36 insertions(+), 150 deletions(-) diff --git a/compiler/closureiters.nim b/compiler/closureiters.nim index 8bdd04ca78e19..9ee394111ba26 100644 --- a/compiler/closureiters.nim +++ b/compiler/closureiters.nim @@ -161,10 +161,6 @@ type # is their finally. For finally it is parent finally. Otherwise -1 idgen: IdGenerator varStates: Table[ItemId, int] # Used to detect if local variable belongs to multiple states - stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting - # remove if -d:nimOptIters is default, treating it as always nil - nimOptItersEnabled: bool # tracks if -d:nimOptIters is enabled - # should be default when issues are fixed, see #24094 const nkSkip = {nkEmpty..nkNilLit, nkTemplateDef, nkTypeSection, nkStaticStmt, @@ -174,11 +170,8 @@ const localRequiresLifting = -2 proc newStateAccess(ctx: var Ctx): PNode = - if ctx.stateVarSym.isNil: - result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), + result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), getStateField(ctx.g, ctx.fn), ctx.fn.info) - else: - result = newSymNode(ctx.stateVarSym) proc newStateAssgn(ctx: var Ctx, toValue: PNode): PNode = # Creates state assignment: @@ -196,22 +189,12 @@ proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym = result.flags.incl sfNoInit assert(not typ.isNil, "Env var needs a type") - if not ctx.stateVarSym.isNil: - # We haven't gone through labmda lifting yet, so just create a local var, - # it will be lifted later - if ctx.tempVars.isNil: - ctx.tempVars = newNodeI(nkVarSection, ctx.fn.info) - addVar(ctx.tempVars, newSymNode(result)) - else: - let envParam = getEnvParam(ctx.fn) - # let obj = envParam.typ.lastSon - result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen) + let envParam = getEnvParam(ctx.fn) + # let obj = envParam.typ.lastSon + result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen) proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode = - if ctx.stateVarSym.isNil: - result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info) - else: - result = newSymNode(s) + result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info) proc newTempVarAccess(ctx: Ctx, s: PSym): PNode = result = newSymNode(s, ctx.fn.info) @@ -263,20 +246,12 @@ proc newTempVarDef(ctx: Ctx, s: PSym, initialValue: PNode): PNode = v = ctx.g.emptyNode newTree(nkVarSection, newTree(nkIdentDefs, newSymNode(s), ctx.g.emptyNode, v)) -proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode - proc newTempVar(ctx: var Ctx, typ: PType, parent: PNode, initialValue: PNode = nil): PSym = - if ctx.nimOptItersEnabled: - result = newSym(skVar, getIdent(ctx.g.cache, ":tmpSlLower" & $ctx.tempVarId), ctx.idgen, ctx.fn, ctx.fn.info) - else: - result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ) + result = newSym(skVar, getIdent(ctx.g.cache, ":tmpSlLower" & $ctx.tempVarId), ctx.idgen, ctx.fn, ctx.fn.info) inc ctx.tempVarId result.typ = typ assert(not typ.isNil, "Temp var needs a type") - if ctx.nimOptItersEnabled: - parent.add(ctx.newTempVarDef(result, initialValue)) - elif initialValue != nil: - parent.add(ctx.newEnvVarAsgn(result, initialValue)) + parent.add(ctx.newTempVarDef(result, initialValue)) proc hasYields(n: PNode): bool = # TODO: This is very inefficient. It traverses the node, looking for nkYieldStmt. @@ -455,24 +430,13 @@ proc newTempVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode = result = newTree(nkFastAsgn, ctx.newTempVarAccess(s), v) result.info = v.info -proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode = - # unused with -d:nimOptIters - if isEmptyType(v.typ): - result = v - else: - result = newTree(nkFastAsgn, ctx.newEnvVarAccess(s), v) - result.info = v.info - proc addExprAssgn(ctx: Ctx, output, input: PNode, sym: PSym) = - var input = input if input.kind == nkStmtListExpr: let (st, res) = exprToStmtList(input) output.add(st) - input = res - if ctx.nimOptItersEnabled: - output.add(ctx.newTempVarAsgn(sym, input)) + output.add(ctx.newTempVarAsgn(sym, res)) else: - output.add(ctx.newEnvVarAsgn(sym, input)) + output.add(ctx.newTempVarAsgn(sym, input)) proc convertExprBodyToAsgn(ctx: Ctx, exprBody: PNode, res: PSym): PNode = result = newNodeI(nkStmtList, exprBody.info) @@ -601,11 +565,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = else: internalError(ctx.g.config, "lowerStmtListExpr(nkIf): " & $branch.kind) - if isExpr: - if ctx.nimOptItersEnabled: - result.add(ctx.newTempVarAccess(tmp)) - else: - result.add(ctx.newEnvVarAccess(tmp)) + if isExpr: result.add(ctx.newTempVarAccess(tmp)) of nkTryStmt, nkHiddenTryStmt: var ns = false @@ -635,10 +595,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = else: internalError(ctx.g.config, "lowerStmtListExpr(nkTryStmt): " & $branch.kind) result.add(n) - if ctx.nimOptItersEnabled: - result.add(ctx.newTempVarAccess(tmp)) - else: - result.add(ctx.newEnvVarAccess(tmp)) + result.add(ctx.newTempVarAccess(tmp)) of nkCaseStmt: var ns = false @@ -670,10 +627,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = else: internalError(ctx.g.config, "lowerStmtListExpr(nkCaseStmt): " & $branch.kind) result.add(n) - if ctx.nimOptItersEnabled: - result.add(ctx.newTempVarAccess(tmp)) - else: - result.add(ctx.newEnvVarAccess(tmp)) + result.add(ctx.newTempVarAccess(tmp)) elif n[0].kind == nkStmtListExpr: result = newNodeI(nkStmtList, n.info) let (st, ex) = exprToStmtList(n[0]) @@ -706,11 +660,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = let tmp = ctx.newTempVar(cond.typ, result, cond) # result.add(ctx.newTempVarAsgn(tmp, cond)) - var check: PNode - if ctx.nimOptItersEnabled: - check = ctx.newTempVarAccess(tmp) - else: - check = ctx.newEnvVarAccess(tmp) + var check = ctx.newTempVarAccess(tmp) if n[0].sym.magic == mOr: check = ctx.g.newNotCall(check) @@ -720,18 +670,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode = let (st, ex) = exprToStmtList(cond) ifBody.add(st) cond = ex - if ctx.nimOptItersEnabled: - ifBody.add(ctx.newTempVarAsgn(tmp, cond)) - else: - ifBody.add(ctx.newEnvVarAsgn(tmp, cond)) + ifBody.add(ctx.newTempVarAsgn(tmp, cond)) let ifBranch = newTree(nkElifBranch, check, ifBody) let ifNode = newTree(nkIfStmt, ifBranch) result.add(ifNode) - if ctx.nimOptItersEnabled: - result.add(ctx.newTempVarAccess(tmp)) - else: - result.add(ctx.newEnvVarAccess(tmp)) + result.add(ctx.newTempVarAccess(tmp)) else: for i in 0..; requires a copy because it's not the last read of ':envAlt.b1'; routine: main" + errormsg: "'=copy' is not available for type ; requires a copy because it's not the last read of ':envAlt.b0'; routine: main" line: 48 """ diff --git a/tests/iter/tyieldintry.nim b/tests/iter/tyieldintry.nim index e51ab7f0d59a1..04409795b0763 100644 --- a/tests/iter/tyieldintry.nim +++ b/tests/iter/tyieldintry.nim @@ -1,5 +1,5 @@ discard """ - matrix: "; --experimental:strictdefs; -d:nimOptIters" + matrix: "; --experimental:strictdefs" targets: "c cpp" """ @@ -505,7 +505,7 @@ block: # void iterator discard var a = it -if defined(nimOptIters): # Locals present in only 1 state should be on the stack +block: # Locals present in only 1 state should be on the stack proc checkOnStack(a: pointer, shouldBeOnStack: bool) = # Quick and dirty way to check if a points to stack var dummy = 0 From af1813150c0e457df5733b5fd2a036b15ccdcb5a Mon Sep 17 00:00:00 2001 From: Yuriy Glukhov Date: Wed, 16 Oct 2024 13:10:43 +0200 Subject: [PATCH 2/2] Fixes #3824, fixes #19154, refs #24094. --- compiler/lambdalifting.nim | 40 +++++++--- tests/iter/tnestedclosures.nim | 139 +++++++++++++++++++++++++++++++++ 2 files changed, 169 insertions(+), 10 deletions(-) create mode 100644 tests/iter/tnestedclosures.nim diff --git a/compiler/lambdalifting.nim b/compiler/lambdalifting.nim index 136af59ef35d1..3d3a50089b65c 100644 --- a/compiler/lambdalifting.nim +++ b/compiler/lambdalifting.nim @@ -175,6 +175,7 @@ proc addHiddenParam(routine: PSym, param: PSym) = #echo "produced environment: ", param.id, " for ", routine.id proc getEnvParam*(routine: PSym): PSym = + if routine.ast.isNil: return nil let params = routine.ast[paramsPos] let hidden = lastSon(params) if hidden.kind == nkSym and hidden.sym.kind == skParam and hidden.sym.name.s == paramName: @@ -419,6 +420,12 @@ proc addClosureParam(c: var DetectionPass; fn: PSym; info: TLineInfo) = localError(c.graph.config, fn.info, "internal error: inconsistent environment type") #echo "adding closure to ", fn.name.s +proc iterEnvHasUpField(g: ModuleGraph, iter: PSym): bool = + let cp = getEnvParam(iter) + doAssert(cp != nil, "Env param not present in iter") + let upField = lookupInRecord(cp.typ.skipTypes({tyOwned, tyRef, tyPtr}).n, getIdent(g.cache, upName)) + upField != nil + proc detectCapturedVars(n: PNode; owner: PSym; c: var DetectionPass) = case n.kind of nkSym: @@ -437,7 +444,7 @@ proc detectCapturedVars(n: PNode; owner: PSym; c: var DetectionPass) = let body = transformBody(c.graph, c.idgen, s, {useCache}) detectCapturedVars(body, s, c) let ow = s.skipGenericOwner - let innerClosure = innerProc and s.typ.callConv == ccClosure and not s.isIterator + let innerClosure = innerProc and s.typ.callConv == ccClosure and (not s.isIterator or iterEnvHasUpField(c.graph, s)) let interested = interestingVar(s) if ow == owner: if owner.isIterator: @@ -642,16 +649,27 @@ proc finishClosureCreation(owner: PSym; d: var DetectionPass; c: LiftingPass; res.add newAsgnStmt(unowned, nilLit, info) createTypeBoundOpsLL(d.graph, unowned.typ, info, d.idgen, owner) -proc closureCreationForIter(iter: PNode; +proc getUpForIter(g: ModuleGraph; owner, iterOwner: PSym, expectedUpTyp: PType): PNode = + var p = getHiddenParam(g, owner) + var res = p.newSymNode + while res.typ.skipTypes({tyOwned, tyRef, tyPtr}) != expectedUpTyp: + let upField = lookupInRecord(p.typ.skipTypes({tyOwned, tyRef, tyPtr}).n, getIdent(g.cache, upName)) + if upField == nil: + return nil + p = upField + res = rawIndirectAccess(res, upField, p.info) + res + +proc closureCreationForIter(owner: PSym, iter: PNode; d: var DetectionPass; c: var LiftingPass): PNode = result = newNodeIT(nkStmtListExpr, iter.info, iter.sym.typ) - let owner = iter.sym.skipGenericOwner - var v = newSym(skVar, getIdent(d.graph.cache, envName), d.idgen, owner, iter.info) + let iterOwner = iter.sym.skipGenericOwner + var v = newSym(skVar, getIdent(d.graph.cache, envName), d.idgen, iterOwner, iter.info) incl(v.flags, sfShadowed) v.typ = asOwnedRef(d, getHiddenParam(d.graph, iter.sym).typ) var vnode: PNode - if owner.isIterator: - let it = getHiddenParam(d.graph, owner) + if iterOwner.isIterator: + let it = getHiddenParam(d.graph, iterOwner) addUniqueField(it.typ.skipTypes({tyOwned, tyRef, tyPtr}), v, d.graph.cache, d.idgen) vnode = indirectAccess(newSymNode(it), v, v.info) else: @@ -660,12 +678,14 @@ proc closureCreationForIter(iter: PNode; addVar(vs, vnode) result.add(vs) result.add genCreateEnv(vnode) - createTypeBoundOpsLL(d.graph, vnode.typ, iter.info, d.idgen, owner) + createTypeBoundOpsLL(d.graph, vnode.typ, iter.info, d.idgen, iterOwner) let upField = lookupInRecord(v.typ.skipTypes({tyOwned, tyRef, tyPtr}).n, getIdent(d.graph.cache, upName)) if upField != nil: - let u = setupEnvVar(owner, d, c, iter.info) - if u.typ.skipTypes({tyOwned, tyRef, tyPtr}) == upField.typ.skipTypes({tyOwned, tyRef, tyPtr}): + let expectedUpTyp = upField.typ.skipTypes({tyOwned, tyRef, tyPtr}) + let u = if iterOwner == owner: setupEnvVar(iterOwner, d, c, iter.info) + else: getUpForIter(d.graph, owner, iterOwner, expectedUpTyp) + if u != nil and u.typ.skipTypes({tyOwned, tyRef, tyPtr}) == expectedUpTyp: result.add(newAsgnStmt(rawIndirectAccess(vnode, upField, iter.info), u, iter.info)) else: @@ -699,7 +719,7 @@ proc symToClosure(n: PNode; owner: PSym; d: var DetectionPass; let available = getHiddenParam(d.graph, owner) result = makeClosure(d.graph, d.idgen, s, available.newSymNode, n.info) elif s.isIterator: - result = closureCreationForIter(n, d, c) + result = closureCreationForIter(owner, n, d, c) elif s.skipGenericOwner == owner: # direct dependency, so use the outer's env variable: result = makeClosure(d.graph, d.idgen, s, setupEnvVar(owner, d, c, n.info), n.info) diff --git a/tests/iter/tnestedclosures.nim b/tests/iter/tnestedclosures.nim new file mode 100644 index 0000000000000..273c4aa303d01 --- /dev/null +++ b/tests/iter/tnestedclosures.nim @@ -0,0 +1,139 @@ +discard """ + targets: "c" + output: ''' +Test 1: +12 +Test 2: +23 +23 +Test 3: +34 +34 +Test 4: +45 +45 +50 +50 +Test 5: +45 +123 +47 +50 +Test 6: + +Test 7: +0 +1 +2 +''' +""" + +block: #24094 + echo "Test 1:" + proc foo() = + let x = 12 + iterator bar2(): int {.closure.} = + yield x + proc bar() = + let z = bar2 + for y in z(): # just doing bar2() gives param not in env: x + echo y + bar() + + foo() + +block: #24094 + echo "Test 2:" + iterator foo(): int {.closure.} = + let x = 23 + iterator bar2(): int {.closure.} = + yield x + proc bar() = + let z = bar2 + for y in z(): + echo y + bar() + yield x + + for x in foo(): echo x + +block: #24094 + echo "Test 3:" + iterator foo(): int {.closure.} = + let x = 34 + proc bar() = + echo x + iterator bar2(): int {.closure.} = + bar() + yield x + for y in bar2(): + yield y + + for x in foo(): echo x + +block: + echo "Test 4:" + proc foo() = + var x = 45 + iterator bar2(): int {.closure.} = + yield x + yield x + 3 + + let b1 = bar2 + let b2 = bar2 + echo b1() + echo b2() + x = 47 + echo b1() + echo b2() + foo() + +block: + echo "Test 5:" + proc foo() = + var x = 45 + iterator bar2(): int {.closure.} = + yield x + yield x + 3 + + proc bar() = + var y = 123 + iterator bar3(): int {.closure.} = + yield x + yield y + let b3 = bar3 + for z in b3(): + echo z + x = 47 + let b2 = bar2 + for z in b2(): + echo z + bar() + foo() + +block: #19154 + echo "Test 6:" + proc test(s: string): proc(): iterator(): string = + iterator it(): string = yield s + proc f(): iterator(): string = it + return f + + let it = test("hi")() + for s in it(): + echo "<", s, ">" + +block: #3824 + echo "Test 7:" + proc main = + iterator factory(): int {.closure.} = + iterator bar(): int {.closure.} = + yield 0 + yield 1 + yield 2 + + for x in bar(): yield x + + for x in factory(): + echo x + + main()