Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor to make sigmatch use LayeredIdTable for bindings #24216

Merged
merged 11 commits into from
Oct 6, 2024
12 changes: 6 additions & 6 deletions compiler/concepts.nim
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
## for details. Note this is a first implementation and only the "Concept matching"
## section has been implemented.

import ast, astalgo, semdata, lookups, lineinfos, idents, msgs, renderer, types
import ast, semdata, lookups, lineinfos, idents, msgs, renderer, types, layeredtable

import std/intsets

Expand Down Expand Up @@ -309,7 +309,7 @@ proc conceptMatchNode(c: PContext; n: PNode; m: var MatchCon): bool =
# error was reported earlier.
result = false

proc conceptMatch*(c: PContext; concpt, arg: PType; bindings: var TypeMapping; invocation: PType): bool =
proc conceptMatch*(c: PContext; concpt, arg: PType; bindings: var LayeredIdTable; invocation: PType): bool =
## Entry point from sigmatch. 'concpt' is the concept we try to match (here still a PType but
## we extract its AST via 'concpt.n.lastSon'). 'arg' is the type that might fulfill the
## concept's requirements. If so, we return true and fill the 'bindings' with pairs of
Expand All @@ -328,16 +328,16 @@ proc conceptMatch*(c: PContext; concpt, arg: PType; bindings: var TypeMapping; i
dest = existingBinding(m, dest)
if dest == nil or dest.kind != tyGenericParam: break
if dest != nil:
bindings.idTablePut(a, dest)
bindings.put(a, dest)
when logBindings: echo "A bind ", a, " ", dest
else:
bindings.idTablePut(a, b)
bindings.put(a, b)
when logBindings: echo "B bind ", a, " ", b
# we have a match, so bind 'arg' itself to 'concpt':
bindings.idTablePut(concpt, arg)
bindings.put(concpt, arg)
# invocation != nil means we have a non-atomic concept:
if invocation != nil and arg.kind == tyGenericInst and invocation.kidsLen == arg.kidsLen-1:
# bind even more generic parameters
assert invocation.kind == tyGenericInvocation
for i in FirstGenericParamAt ..< invocation.kidsLen:
bindings.idTablePut(invocation[i], arg[i])
bindings.put(invocation[i], arg[i])
82 changes: 82 additions & 0 deletions compiler/layeredtable.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import std/tables
import ast

type
LayeredIdTableObj* {.acyclic.} = object
## stack of type binding contexts implemented as a linked list
topLayer*: TypeMapping
## the mappings on the current layer
nextLayer*: ref LayeredIdTableObj
## the parent type binding context, possibly `nil`
previousLen*: int
## total length of the bindings up to the parent layer,
## used to track if new bindings were added

const useRef = not defined(gcDestructors)
# implementation detail, only arc/orc doesn't cause issues when
# using LayeredIdTable as an object and not a ref

when useRef:
type LayeredIdTable* = ref LayeredIdTableObj
else:
type LayeredIdTable* = LayeredIdTableObj

proc initLayeredTypeMap*(pt: sink TypeMapping = initTypeMapping()): LayeredIdTable =
result = LayeredIdTable(topLayer: pt, nextLayer: nil)

proc shallowCopy*(pt: LayeredIdTable): LayeredIdTable {.inline.} =
## copies only the type bindings of the current layer, but not any parent layers,
## useful for write-only bindings
result = LayeredIdTable(topLayer: pt.topLayer, nextLayer: pt.nextLayer, previousLen: pt.previousLen)

proc currentLen*(pt: LayeredIdTable): int =
## the sum of the cached total binding count of the parents and
## the current binding count, just used to track if bindings were added
pt.previousLen + pt.topLayer.len

proc newTypeMapLayer*(pt: LayeredIdTable): LayeredIdTable =
result = LayeredIdTable(topLayer: initTable[ItemId, PType](), previousLen: pt.currentLen)
when useRef:
result.nextLayer = pt
else:
new(result.nextLayer)
result.nextLayer[] = pt

proc setToPreviousLayer*(pt: var LayeredIdTable) {.inline.} =
when useRef:
pt = pt.nextLayer
else:
when defined(gcDestructors):
pt = pt.nextLayer[]
else:
# workaround refc
let tmp = pt.nextLayer[]
pt = tmp

proc lookup(typeMap: ref LayeredIdTableObj, key: ItemId): PType =
result = nil
var tm = typeMap
while tm != nil:
result = getOrDefault(tm.topLayer, key)
if result != nil: return
tm = tm.nextLayer

template lookup*(typeMap: ref LayeredIdTableObj, key: PType): PType =
## recursively looks up binding of `key` in all parent layers
lookup(typeMap, key.itemId)

when not useRef:
proc lookup(typeMap: LayeredIdTableObj, key: ItemId): PType {.inline.} =
result = getOrDefault(typeMap.topLayer, key)
if result == nil and typeMap.nextLayer != nil:
result = lookup(typeMap.nextLayer, key)

template lookup*(typeMap: LayeredIdTableObj, key: PType): PType =
lookup(typeMap, key.itemId)

proc put(typeMap: var LayeredIdTable, key: ItemId, value: PType) {.inline.} =
typeMap.topLayer[key] = value

template put*(typeMap: var LayeredIdTable, key, value: PType) =
## binds `key` to `value` only in current layer
put(typeMap, key.itemId, value)
6 changes: 3 additions & 3 deletions compiler/sem.nim
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import
evaltempl, patterns, parampatterns, sempass2, linter, semmacrosanity,
lowerings, plugins/active, lineinfos, int128,
isolation_check, typeallowed, modulegraphs, enumtostr, concepts, astmsgs,
extccomp
extccomp, layeredtable

import vtables
import std/[strtabs, math, tables, intsets, strutils, packedsets]
Expand Down Expand Up @@ -473,15 +473,15 @@ proc semAfterMacroCall(c: PContext, call, macroResult: PNode,
# e.g. template foo(T: typedesc): seq[T]
# We will instantiate the return type here, because
# we now know the supplied arguments
var paramTypes = initTypeMapping()
var paramTypes = initLayeredTypeMap()
for param, value in genericParamsInMacroCall(s, call):
var givenType = value.typ
# the sym nodes used for the supplied generic arguments for
# templates and macros leave type nil so regular sem can handle it
# in this case, get the type directly from the sym
if givenType == nil and value.kind == nkSym and value.sym.typ != nil:
givenType = value.sym.typ
idTablePut(paramTypes, param.typ, givenType)
put(paramTypes, param.typ, givenType)

retType = generateTypeInstance(c, paramTypes,
macroResult.info, retType)
Expand Down
4 changes: 2 additions & 2 deletions compiler/semcall.nim
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ proc inheritBindings(c: PContext, x: var TCandidate, expectedType: PType) =
if t[i] == nil or u[i] == nil: return
stackPut(t[i], u[i])
of tyGenericParam:
let prebound = x.bindings.idTableGet(t)
let prebound = x.bindings.lookup(t)
if prebound != nil:
continue # Skip param, already bound

Expand All @@ -760,7 +760,7 @@ proc inheritBindings(c: PContext, x: var TCandidate, expectedType: PType) =
discard
# update bindings
for i in 0 ..< flatUnbound.len():
x.bindings.idTablePut(flatUnbound[i], flatBound[i])
x.bindings.put(flatUnbound[i], flatBound[i])

proc semResolvedCall(c: PContext, x: var TCandidate,
n: PNode, flags: TExprFlags;
Expand Down
10 changes: 5 additions & 5 deletions compiler/semdata.nim
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ when defined(nimPreviewSlimSystem):
import std/assertions

import
options, ast, astalgo, msgs, idents, renderer,
magicsys, vmdef, modulegraphs, lineinfos, pathutils
options, ast, msgs, idents, renderer,
magicsys, vmdef, modulegraphs, lineinfos, pathutils, layeredtable

import ic / ic

Expand Down Expand Up @@ -136,10 +136,10 @@ type
semOverloadedCall*: proc (c: PContext, n, nOrig: PNode,
filter: TSymKinds, flags: TExprFlags, expectedType: PType = nil): PNode {.nimcall.}
semTypeNode*: proc(c: PContext, n: PNode, prev: PType): PType {.nimcall.}
semInferredLambda*: proc(c: PContext, pt: Table[ItemId, PType], n: PNode): PNode
semGenerateInstance*: proc (c: PContext, fn: PSym, pt: Table[ItemId, PType],
semInferredLambda*: proc(c: PContext, pt: LayeredIdTable, n: PNode): PNode
semGenerateInstance*: proc (c: PContext, fn: PSym, pt: LayeredIdTable,
info: TLineInfo): PSym
instantiateOnlyProcType*: proc (c: PContext, pt: TypeMapping,
instantiateOnlyProcType*: proc (c: PContext, pt: LayeredIdTable,
prc: PSym, info: TLineInfo): PType
# used by sigmatch for explicit generic instantiations
includedFiles*: IntSet # used to detect recursive include files
Expand Down
4 changes: 2 additions & 2 deletions compiler/semexprs.nim
Original file line number Diff line number Diff line change
Expand Up @@ -2516,8 +2516,8 @@ proc instantiateCreateFlowVarCall(c: PContext; t: PType;
let sym = magicsys.getCompilerProc(c.graph, "nimCreateFlowVar")
if sym == nil:
localError(c.config, info, "system needs: nimCreateFlowVar")
var bindings = initTypeMapping()
bindings.idTablePut(sym.ast[genericParamsPos][0].typ, t)
var bindings = initLayeredTypeMap()
bindings.put(sym.ast[genericParamsPos][0].typ, t)
result = c.semGenerateInstance(c, sym, bindings, info)
# since it's an instantiation, we unmark it as a compilerproc. Otherwise
# codegen would fail:
Expand Down
12 changes: 6 additions & 6 deletions compiler/seminst.nim
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ proc pushProcCon*(c: PContext; owner: PSym) =
const
errCannotInstantiateX = "cannot instantiate: '$1'"

iterator instantiateGenericParamList(c: PContext, n: PNode, pt: TypeMapping): PSym =
iterator instantiateGenericParamList(c: PContext, n: PNode, pt: LayeredIdTable): PSym =
internalAssert c.config, n.kind == nkGenericParams
for a in n.items:
internalAssert c.config, a.kind == nkSym
Expand All @@ -43,7 +43,7 @@ iterator instantiateGenericParamList(c: PContext, n: PNode, pt: TypeMapping): PS
let symKind = if q.typ.kind == tyStatic: skConst else: skType
var s = newSym(symKind, q.name, c.idgen, getCurrOwner(c), q.info)
s.flags.incl {sfUsed, sfFromGeneric}
var t = idTableGet(pt, q.typ)
var t = lookup(pt, q.typ)
if t == nil:
if tfRetType in q.typ.flags:
# keep the generic type and allow the return type to be bound
Expand Down Expand Up @@ -220,7 +220,7 @@ proc referencesAnotherParam(n: PNode, p: PSym): bool =
if referencesAnotherParam(n[i], p): return true
return false

proc instantiateProcType(c: PContext, pt: TypeMapping,
proc instantiateProcType(c: PContext, pt: LayeredIdTable,
prc: PSym, info: TLineInfo) =
# XXX: Instantiates a generic proc signature, while at the same
# time adding the instantiated proc params into the current scope.
Expand All @@ -237,7 +237,7 @@ proc instantiateProcType(c: PContext, pt: TypeMapping,
# will need to use openScope, addDecl, etc.
#addDecl(c, prc)
pushInfoContext(c.config, info)
var typeMap = initLayeredTypeMap(pt)
var typeMap = shallowCopy(pt) # use previous bindings without writing to them
var cl = initTypeVars(c, typeMap, info, nil)
var result = instCopyType(cl, prc.typ)
let originalParams = result.n
Expand Down Expand Up @@ -324,7 +324,7 @@ proc instantiateProcType(c: PContext, pt: TypeMapping,
prc.typ = result
popInfoContext(c.config)

proc instantiateOnlyProcType(c: PContext, pt: TypeMapping, prc: PSym, info: TLineInfo): PType =
proc instantiateOnlyProcType(c: PContext, pt: LayeredIdTable, prc: PSym, info: TLineInfo): PType =
# instantiates only the type of a given proc symbol
# used by sigmatch for explicit generics
# wouldn't be needed if sigmatch could handle complex cases,
Expand Down Expand Up @@ -360,7 +360,7 @@ proc getLocalPassC(c: PContext, s: PSym): string =
for p in n:
extractPassc(p)

proc generateInstance(c: PContext, fn: PSym, pt: TypeMapping,
proc generateInstance(c: PContext, fn: PSym, pt: LayeredIdTable,
info: TLineInfo): PSym =
## Generates a new instance of a generic procedure.
## The `pt` parameter is a type-unsafe mapping table used to link generic
Expand Down
2 changes: 1 addition & 1 deletion compiler/semstmts.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1969,7 +1969,7 @@ proc semProcAnnotation(c: PContext, prc: PNode;

return result

proc semInferredLambda(c: PContext, pt: TypeMapping, n: PNode): PNode =
proc semInferredLambda(c: PContext, pt: LayeredIdTable, n: PNode): PNode =
## used for resolving 'auto' in lambdas based on their callsite
var n = n
let original = n[namePos].sym
Expand Down
43 changes: 12 additions & 31 deletions compiler/semtypinst.nim
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import std / tables

import ast, astalgo, msgs, types, magicsys, semdata, renderer, options,
lineinfos, modulegraphs
lineinfos, modulegraphs, layeredtable

when defined(nimPreviewSlimSystem):
import std/assertions
Expand Down Expand Up @@ -65,10 +65,6 @@ proc cacheTypeInst(c: PContext; inst: PType) =
addToGenericCache(c, gt.sym, inst)

type
LayeredIdTable* {.acyclic.} = ref object
topLayer*: TypeMapping
nextLayer*: LayeredIdTable

TReplTypeVars* = object
c*: PContext
typeMap*: LayeredIdTable # map PType to PType
Expand All @@ -88,23 +84,8 @@ proc replaceTypeVarsTAux(cl: var TReplTypeVars, t: PType): PType
proc replaceTypeVarsS(cl: var TReplTypeVars, s: PSym, t: PType): PSym
proc replaceTypeVarsN*(cl: var TReplTypeVars, n: PNode; start=0; expectedType: PType = nil): PNode

proc initLayeredTypeMap*(pt: sink TypeMapping): LayeredIdTable =
result = LayeredIdTable()
result.topLayer = pt

proc newTypeMapLayer*(cl: var TReplTypeVars): LayeredIdTable =
result = LayeredIdTable(nextLayer: cl.typeMap, topLayer: initTable[ItemId, PType]())

proc lookup(typeMap: LayeredIdTable, key: PType): PType =
result = nil
var tm = typeMap
while tm != nil:
result = getOrDefault(tm.topLayer, key.itemId)
if result != nil: return
tm = tm.nextLayer

template put(typeMap: LayeredIdTable, key, value: PType) =
typeMap.topLayer[key.itemId] = value
result = newTypeMapLayer(cl.typeMap)

template checkMetaInvariants(cl: TReplTypeVars, t: PType) = # noop code
when false:
Expand Down Expand Up @@ -500,7 +481,7 @@ proc handleGenericInvocation(cl: var TReplTypeVars, t: PType): PType =
newbody.flags = newbody.flags + (t.flags + body.flags - tfInstClearedFlags)
result.flags = result.flags + newbody.flags - tfInstClearedFlags

cl.typeMap = cl.typeMap.nextLayer
setToPreviousLayer(cl.typeMap)

# This is actually wrong: tgeneric_closure fails with this line:
#newbody.callConv = body.callConv
Expand Down Expand Up @@ -791,19 +772,19 @@ proc initTypeVars*(p: PContext, typeMap: LayeredIdTable, info: TLineInfo;
localCache: initTypeMapping(), typeMap: typeMap,
info: info, c: p, owner: owner)

proc replaceTypesInBody*(p: PContext, pt: TypeMapping, n: PNode;
proc replaceTypesInBody*(p: PContext, pt: LayeredIdTable, n: PNode;
owner: PSym, allowMetaTypes = false,
fromStaticExpr = false, expectedType: PType = nil): PNode =
var typeMap = initLayeredTypeMap(pt)
var typeMap = shallowCopy(pt) # use previous bindings without writing to them
var cl = initTypeVars(p, typeMap, n.info, owner)
cl.allowMetaTypes = allowMetaTypes
pushInfoContext(p.config, n.info)
result = replaceTypeVarsN(cl, n, expectedType = expectedType)
popInfoContext(p.config)

proc prepareTypesInBody*(p: PContext, pt: TypeMapping, n: PNode;
proc prepareTypesInBody*(p: PContext, pt: LayeredIdTable, n: PNode;
owner: PSym = nil): PNode =
var typeMap = initLayeredTypeMap(pt)
var typeMap = shallowCopy(pt) # use previous bindings without writing to them
var cl = initTypeVars(p, typeMap, n.info, owner)
pushInfoContext(p.config, n.info)
result = prepareNode(cl, n)
Expand Down Expand Up @@ -836,13 +817,13 @@ proc recomputeFieldPositions*(t: PType; obj: PNode; currPosition: var int) =
inc currPosition
else: discard "cannot happen"

proc generateTypeInstance*(p: PContext, pt: TypeMapping, info: TLineInfo,
proc generateTypeInstance*(p: PContext, pt: LayeredIdTable, info: TLineInfo,
t: PType): PType =
# Given `t` like Foo[T]
# pt: Table with type mappings: T -> int
# Desired result: Foo[int]
# proc (x: T = 0); T -> int ----> proc (x: int = 0)
var typeMap = initLayeredTypeMap(pt)
var typeMap = shallowCopy(pt) # use previous bindings without writing to them
var cl = initTypeVars(p, typeMap, info, nil)
pushInfoContext(p.config, info)
result = replaceTypeVarsT(cl, t)
Expand All @@ -852,15 +833,15 @@ proc generateTypeInstance*(p: PContext, pt: TypeMapping, info: TLineInfo,
var position = 0
recomputeFieldPositions(objType, objType.n, position)

proc prepareMetatypeForSigmatch*(p: PContext, pt: TypeMapping, info: TLineInfo,
proc prepareMetatypeForSigmatch*(p: PContext, pt: LayeredIdTable, info: TLineInfo,
t: PType): PType =
var typeMap = initLayeredTypeMap(pt)
var typeMap = shallowCopy(pt) # use previous bindings without writing to them
var cl = initTypeVars(p, typeMap, info, nil)
cl.allowMetaTypes = true
pushInfoContext(p.config, info)
result = replaceTypeVarsT(cl, t)
popInfoContext(p.config)

template generateTypeInstance*(p: PContext, pt: TypeMapping, arg: PNode,
template generateTypeInstance*(p: PContext, pt: LayeredIdTable, arg: PNode,
t: PType): untyped =
generateTypeInstance(p, pt, arg.info, t)
Loading
Loading