Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #3824, fixes #19154, and hopefully #24094. Re-applies #23787. #24316

Merged
merged 3 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 25 additions & 103 deletions compiler/closureiters.nim
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,6 @@ type
# is their finally. For finally it is parent finally. Otherwise -1
idgen: IdGenerator
varStates: Table[ItemId, int] # Used to detect if local variable belongs to multiple states
stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting
# remove if -d:nimOptIters is default, treating it as always nil
nimOptItersEnabled: bool # tracks if -d:nimOptIters is enabled
# should be default when issues are fixed, see #24094

const
nkSkip = {nkEmpty..nkNilLit, nkTemplateDef, nkTypeSection, nkStaticStmt,
Expand All @@ -174,11 +170,8 @@ const
localRequiresLifting = -2

proc newStateAccess(ctx: var Ctx): PNode =
if ctx.stateVarSym.isNil:
result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)),
result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)),
getStateField(ctx.g, ctx.fn), ctx.fn.info)
else:
result = newSymNode(ctx.stateVarSym)

proc newStateAssgn(ctx: var Ctx, toValue: PNode): PNode =
# Creates state assignment:
Expand All @@ -196,22 +189,12 @@ proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym =
result.flags.incl sfNoInit
assert(not typ.isNil, "Env var needs a type")

if not ctx.stateVarSym.isNil:
# We haven't gone through labmda lifting yet, so just create a local var,
# it will be lifted later
if ctx.tempVars.isNil:
ctx.tempVars = newNodeI(nkVarSection, ctx.fn.info)
addVar(ctx.tempVars, newSymNode(result))
else:
let envParam = getEnvParam(ctx.fn)
# let obj = envParam.typ.lastSon
result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen)
let envParam = getEnvParam(ctx.fn)
# let obj = envParam.typ.lastSon
result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen)

proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode =
if ctx.stateVarSym.isNil:
result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info)
else:
result = newSymNode(s)
result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info)

proc newTempVarAccess(ctx: Ctx, s: PSym): PNode =
result = newSymNode(s, ctx.fn.info)
Expand Down Expand Up @@ -263,20 +246,12 @@ proc newTempVarDef(ctx: Ctx, s: PSym, initialValue: PNode): PNode =
v = ctx.g.emptyNode
newTree(nkVarSection, newTree(nkIdentDefs, newSymNode(s), ctx.g.emptyNode, v))

proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode

proc newTempVar(ctx: var Ctx, typ: PType, parent: PNode, initialValue: PNode = nil): PSym =
if ctx.nimOptItersEnabled:
result = newSym(skVar, getIdent(ctx.g.cache, ":tmpSlLower" & $ctx.tempVarId), ctx.idgen, ctx.fn, ctx.fn.info)
else:
result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ)
result = newSym(skVar, getIdent(ctx.g.cache, ":tmpSlLower" & $ctx.tempVarId), ctx.idgen, ctx.fn, ctx.fn.info)
inc ctx.tempVarId
result.typ = typ
assert(not typ.isNil, "Temp var needs a type")
if ctx.nimOptItersEnabled:
parent.add(ctx.newTempVarDef(result, initialValue))
elif initialValue != nil:
parent.add(ctx.newEnvVarAsgn(result, initialValue))
parent.add(ctx.newTempVarDef(result, initialValue))

proc hasYields(n: PNode): bool =
# TODO: This is very inefficient. It traverses the node, looking for nkYieldStmt.
Expand Down Expand Up @@ -455,24 +430,13 @@ proc newTempVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode =
result = newTree(nkFastAsgn, ctx.newTempVarAccess(s), v)
result.info = v.info

proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode =
# unused with -d:nimOptIters
if isEmptyType(v.typ):
result = v
else:
result = newTree(nkFastAsgn, ctx.newEnvVarAccess(s), v)
result.info = v.info

proc addExprAssgn(ctx: Ctx, output, input: PNode, sym: PSym) =
var input = input
if input.kind == nkStmtListExpr:
let (st, res) = exprToStmtList(input)
output.add(st)
input = res
if ctx.nimOptItersEnabled:
output.add(ctx.newTempVarAsgn(sym, input))
output.add(ctx.newTempVarAsgn(sym, res))
else:
output.add(ctx.newEnvVarAsgn(sym, input))
output.add(ctx.newTempVarAsgn(sym, input))

proc convertExprBodyToAsgn(ctx: Ctx, exprBody: PNode, res: PSym): PNode =
result = newNodeI(nkStmtList, exprBody.info)
Expand Down Expand Up @@ -601,11 +565,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
else:
internalError(ctx.g.config, "lowerStmtListExpr(nkIf): " & $branch.kind)

if isExpr:
if ctx.nimOptItersEnabled:
result.add(ctx.newTempVarAccess(tmp))
else:
result.add(ctx.newEnvVarAccess(tmp))
if isExpr: result.add(ctx.newTempVarAccess(tmp))

of nkTryStmt, nkHiddenTryStmt:
var ns = false
Expand Down Expand Up @@ -635,10 +595,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
else:
internalError(ctx.g.config, "lowerStmtListExpr(nkTryStmt): " & $branch.kind)
result.add(n)
if ctx.nimOptItersEnabled:
result.add(ctx.newTempVarAccess(tmp))
else:
result.add(ctx.newEnvVarAccess(tmp))
result.add(ctx.newTempVarAccess(tmp))

of nkCaseStmt:
var ns = false
Expand Down Expand Up @@ -670,10 +627,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
else:
internalError(ctx.g.config, "lowerStmtListExpr(nkCaseStmt): " & $branch.kind)
result.add(n)
if ctx.nimOptItersEnabled:
result.add(ctx.newTempVarAccess(tmp))
else:
result.add(ctx.newEnvVarAccess(tmp))
result.add(ctx.newTempVarAccess(tmp))
elif n[0].kind == nkStmtListExpr:
result = newNodeI(nkStmtList, n.info)
let (st, ex) = exprToStmtList(n[0])
Expand Down Expand Up @@ -706,11 +660,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
let tmp = ctx.newTempVar(cond.typ, result, cond)
# result.add(ctx.newTempVarAsgn(tmp, cond))

var check: PNode
if ctx.nimOptItersEnabled:
check = ctx.newTempVarAccess(tmp)
else:
check = ctx.newEnvVarAccess(tmp)
var check = ctx.newTempVarAccess(tmp)
if n[0].sym.magic == mOr:
check = ctx.g.newNotCall(check)

Expand All @@ -720,18 +670,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
let (st, ex) = exprToStmtList(cond)
ifBody.add(st)
cond = ex
if ctx.nimOptItersEnabled:
ifBody.add(ctx.newTempVarAsgn(tmp, cond))
else:
ifBody.add(ctx.newEnvVarAsgn(tmp, cond))
ifBody.add(ctx.newTempVarAsgn(tmp, cond))

let ifBranch = newTree(nkElifBranch, check, ifBody)
let ifNode = newTree(nkIfStmt, ifBranch)
result.add(ifNode)
if ctx.nimOptItersEnabled:
result.add(ctx.newTempVarAccess(tmp))
else:
result.add(ctx.newEnvVarAccess(tmp))
result.add(ctx.newTempVarAccess(tmp))
else:
for i in 0..<n.len:
if n[i].kind == nkStmtListExpr:
Expand All @@ -742,10 +686,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
if n[i].kind in nkCallKinds: # XXX: This should better be some sort of side effect tracking
let tmp = ctx.newTempVar(n[i].typ, result, n[i])
# result.add(ctx.newTempVarAsgn(tmp, n[i]))
if ctx.nimOptItersEnabled:
n[i] = ctx.newTempVarAccess(tmp)
else:
n[i] = ctx.newEnvVarAccess(tmp)
n[i] = ctx.newTempVarAccess(tmp)

result.add(n)

Expand Down Expand Up @@ -1343,13 +1284,6 @@ proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode =
result.info = n.info

let localVars = newNodeI(nkStmtList, n.info)
if not ctx.stateVarSym.isNil:
let varSect = newNodeI(nkVarSection, n.info)
addVar(varSect, newSymNode(ctx.stateVarSym))
localVars.add(varSect)

if not ctx.tempVars.isNil:
localVars.add(ctx.tempVars)

let blockStmt = newNodeI(nkBlockStmt, n.info)
blockStmt.add(newSymNode(ctx.stateLoopLabel))
Expand Down Expand Up @@ -1552,21 +1486,11 @@ proc liftLocals(c: var Ctx, n: PNode): PNode =
n[i] = liftLocals(c, n[i])

proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: PNode): PNode =
var ctx = Ctx(g: g, fn: fn, idgen: idgen,
# should be default when issues are fixed, see #24094:
nimOptItersEnabled: isDefined(g.config, "nimOptIters"))

if getEnvParam(fn).isNil:
if ctx.nimOptItersEnabled:
# The transformation should always happen after at least partial lambdalifting
# is performed, so that the closure iter environment is always created upfront.
doAssert(false, "Env param not created before iter transformation")
else:
# Lambda lifting was not done yet. Use temporary :state sym, which will
# be handled specially by lambda lifting. Local temp vars (if needed)
# should follow the same logic.
ctx.stateVarSym = newSym(skVar, getIdent(ctx.g.cache, ":state"), idgen, fn, fn.info)
ctx.stateVarSym.typ = g.createClosureIterStateType(fn, idgen)
var ctx = Ctx(g: g, fn: fn, idgen: idgen)

# The transformation should always happen after at least partial lambdalifting
# is performed, so that the closure iter environment is always created upfront.
doAssert(getEnvParam(fn) != nil, "Env param not created before iter transformation")

ctx.stateLoopLabel = newSym(skLabel, getIdent(ctx.g.cache, ":stateLoop"), idgen, fn, fn.info)
var pc = PreprocessContext(finallys: @[], config: g.config, idgen: idgen)
Expand All @@ -1592,10 +1516,9 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n:
let caseDispatcher = newTreeI(nkCaseStmt, n.info,
ctx.newStateAccess())

if ctx.nimOptItersEnabled:
# Lamdalifting will not touch our locals, it is our responsibility to lift those that
# need it.
detectCapturedVars(ctx)
# Lamdalifting will not touch our locals, it is our responsibility to lift those that
# need it.
detectCapturedVars(ctx)

for s in ctx.states:
let body = ctx.transformStateAssignments(s.body)
Expand All @@ -1604,8 +1527,7 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n:
caseDispatcher.add newTreeI(nkElse, n.info, newTreeI(nkReturnStmt, n.info, g.emptyNode))

result = wrapIntoStateLoop(ctx, caseDispatcher)
if ctx.nimOptItersEnabled:
result = liftLocals(ctx, result)
result = liftLocals(ctx, result)

when false:
echo "TRANSFORM TO STATES: "
Expand Down
Loading