Skip to content

Commit

Permalink
remove inactive case fields from VM object constructor nodes
Browse files Browse the repository at this point in the history
fixes #17571
  • Loading branch information
metagn committed Nov 15, 2024
1 parent 371f50f commit 1fc3c45
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 19 deletions.
2 changes: 1 addition & 1 deletion compiler/sem.nim
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ proc fixupTypeAfterEval(c: PContext, evaluated, eOrig: PNode): PNode =
if hasCycle(result):
result = localErrorNode(c, eOrig, "the resulting AST is cyclic and cannot be processed further")
else:
semmacrosanity.annotateType(result, expectedType, c.config)
result = semmacrosanity.annotateType(result, expectedType, c.config)
else:
result = semExprWithType(c, evaluated)
#result = fitNode(c, e.typ, result) inlined with special case:
Expand Down
82 changes: 66 additions & 16 deletions compiler/semmacrosanity.nim
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,28 @@
## Implements type sanity checking for ASTs resulting from macros. Lots of
## room for improvement here.

import ast, msgs, types, options
import ast, msgs, types, options, trees, nimsets

proc ithField(n: PNode, field: var int): PSym =
type
FieldTracker = object
index: int
remaining: int
constr: PNode
delete: bool # to delete fields from inactive case branches
FieldInfo = ref object
sym: PSym
delete: bool

proc caseBranchMatchesExpr(branch, matched: PNode): bool =
# copied from sem
result = false
for i in 0 ..< branch.len-1:
if branch[i].kind == nkRange:
if overlap(branch[i], matched): return true
elif exprStructuralEquivalent(branch[i], matched):
return true

proc ithField(n: PNode, field: var FieldTracker): FieldInfo =
result = nil
case n.kind
of nkRecList:
Expand All @@ -23,18 +42,42 @@ proc ithField(n: PNode, field: var int): PSym =
if n[0].kind != nkSym: return
result = ithField(n[0], field)
if result != nil: return
# value of the discriminator field, from (index - remaining - 1 + 1):
# - 1 because the `ithField` call above decreased it by 1,
# + 1 because the constructor node has an initial type child
let val = field.constr[field.index - field.remaining][1]
var branchFound = false
for i in 1..<n.len:
let previousDelete = field.delete
case n[i].kind
of nkOfBranch, nkElse:
of nkOfBranch:
if branchFound or previousDelete or
not caseBranchMatchesExpr(n[i], val):
# if this is not the active case branch,
# mark all fields inside as deleted
field.delete = true
else:
branchFound = true
result = ithField(lastSon(n[i]), field)
if result != nil: return
field.delete = previousDelete
of nkElse:
if branchFound:
# if this is not the active case branch,
# mark all fields inside as deleted
field.delete = true
result = ithField(lastSon(n[i]), field)
if result != nil: return
field.delete = previousDelete
else: discard
of nkSym:
if field == 0: result = n.sym
else: dec(field)
if field.remaining == 0:
result = FieldInfo(sym: n.sym, delete: field.delete)
else:
dec(field.remaining)
else: discard

proc ithField(t: PType, field: var int): PSym =
proc ithField(t: PType, field: var FieldTracker): FieldInfo =
var base = t.baseClass
while base != nil:
let b = skipTypes(base, skipPtrs)
Expand All @@ -43,7 +86,8 @@ proc ithField(t: PType, field: var int): PSym =
base = b.baseClass
result = ithField(t.n, field)

proc annotateType*(n: PNode, t: PType; conf: ConfigRef) =
proc annotateType*(n: PNode, t: PType; conf: ConfigRef): PNode =
result = n
let x = t.skipTypes(abstractInst+{tyRange})
# Note: x can be unequal to t and we need to be careful to use 't'
# to not to skip tyGenericInst
Expand All @@ -52,20 +96,24 @@ proc annotateType*(n: PNode, t: PType; conf: ConfigRef) =
let x = t.skipTypes(abstractPtrs)
n.typ() = t
n[0].typ() = t
result = copyNode(n)
result.add(n[0])
for i in 1..<n.len:
var j = i-1
let field = x.ithField(j)
var tracker = FieldTracker(index: i-1, remaining: i-1, constr: n, delete: false)
let field = x.ithField(tracker)
if field.isNil:
globalError conf, n.info, "invalid field at index " & $i
else:
elif not field.delete:
# only add fields from active case branches
internalAssert(conf, n[i].kind == nkExprColonExpr)
annotateType(n[i][1], field.typ, conf)
n[i][1] = annotateType(n[i][1], field.sym.typ, conf)
result.add(n[i])
of nkPar, nkTupleConstr:
if x.kind == tyTuple:
n.typ() = t
for i in 0..<n.len:
if i >= x.kidsLen: globalError conf, n.info, "invalid field at index " & $i
else: annotateType(n[i], x[i], conf)
else: n[i] = annotateType(n[i], x[i], conf)
elif x.kind == tyProc and x.callConv == ccClosure:
n.typ() = t
elif x.kind == tyOpenArray: # `opcSlice` transforms slices into tuples
Expand All @@ -79,11 +127,11 @@ proc annotateType*(n: PNode, t: PType; conf: ConfigRef) =
of nkStrKinds:
for i in left..right:
bracketExpr.add newIntNode(nkCharLit, BiggestInt n[0].strVal[i])
annotateType(bracketExpr[^1], x.elementType, conf)
bracketExpr[^1] = annotateType(bracketExpr[^1], x.elementType, conf)
of nkBracket:
for i in left..right:
bracketExpr.add n[0][i]
annotateType(bracketExpr[^1], x.elementType, conf)
bracketExpr[^1] = annotateType(bracketExpr[^1], x.elementType, conf)
else:
globalError(conf, n.info, "Incorrectly generated tuple constr")
n[] = bracketExpr[]
Expand All @@ -94,13 +142,15 @@ proc annotateType*(n: PNode, t: PType; conf: ConfigRef) =
of nkBracket:
if x.kind in {tyArray, tySequence, tyOpenArray}:
n.typ() = t
for m in n: annotateType(m, x.elemType, conf)
for i in 0 ..< n.len:
n[i] = annotateType(n[i], x.elemType, conf)
else:
globalError(conf, n.info, "[] must have some form of array type")
of nkCurly:
if x.kind in {tySet}:
n.typ() = t
for m in n: annotateType(m, x.elemType, conf)
for i in 0 ..< n.len:
n[i] = annotateType(n[i], x.elemType, conf)
else:
globalError(conf, n.info, "{} must have the set type")
of nkFloatLit..nkFloat128Lit:
Expand Down
4 changes: 2 additions & 2 deletions compiler/vm.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1458,10 +1458,10 @@ proc rawExecute(c: PCtx, start: int, tos: PStackFrame): TFullReg =
var macroCall = newNodeI(nkCall, c.debug[pc])
macroCall.add(newSymNode(prc))
for i in 1..rc-1:
let node = regs[rb+i].regToNode
var node = regs[rb+i].regToNode
node.info = c.debug[pc]
if prc.typ[i].kind notin {tyTyped, tyUntyped}:
node.annotateType(prc.typ[i], c.config)
node = node.annotateType(prc.typ[i], c.config)

macroCall.add(node)
var a = evalTemplate(macroCall, prc, genSymOwner, c.config, c.cache, c.templInstCounter, c.idgen)
Expand Down
19 changes: 19 additions & 0 deletions tests/vm/tcaseobj.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# issue #17571

import std/[macros, objectdollar]

type
MyEnum = enum
F, S, T
Foo = object
case o: MyEnum
of F:
f: string
of S:
s: string
of T:
t: string

let val = static(Foo(o: F, f: "foo")).f
doAssert val == "foo"
doAssert $static(Foo(o: F, f: "foo")) == $Foo(o: F, f: "foo")

0 comments on commit 1fc3c45

Please sign in to comment.