Skip to content

Commit

Permalink
misc: add a ScopedDict util, similar to LLVM's ScopedHashTable (#3355)
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh authored Oct 30, 2024
1 parent 4040e03 commit e6c7282
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 74 deletions.
25 changes: 4 additions & 21 deletions docs/Toy/toy/frontend/ir_gen.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

from collections.abc import Iterable
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import NoReturn

from xdsl.builder import Builder
from xdsl.dialects.builtin import ModuleOp, TensorType, UnrankedTensorType, f64
from xdsl.ir import Block, Region, SSAValue
from xdsl.utils.scoped_dict import ScopedDict

from ..dialects.toy import (
AddOp,
Expand Down Expand Up @@ -43,24 +44,6 @@ class IRGenError(Exception):
pass


@dataclass
class ScopedSymbolTable:
    """
    Append-only mapping from variable names to SSAValues.

    A name can be bound at most once; binding it a second time raises
    AssertionError.
    """

    table: dict[str, SSAValue] = field(default_factory=dict)

    def __contains__(self, __o: object) -> bool:
        """Return whether `__o` is already bound in this table."""
        bound_names = self.table
        return __o in bound_names

    def __getitem__(self, __key: str) -> SSAValue:
        """Return the value bound to `__key`; the dict lookup raises KeyError if unbound."""
        return self.table[__key]

    def __setitem__(self, __key: str, __value: SSAValue) -> None:
        """Bind `__key` to `__value`; raises AssertionError if `__key` is already bound."""
        if __key not in self:
            self.table[__key] = __value
            return
        raise AssertionError(f"Cannot add value for key {__key} in scope {self}")


@dataclass(init=False)
class IRGen:
"""
Expand All @@ -80,7 +63,7 @@ class IRGen:
is stateful, in particular it keeps an "insertion point": this is where
the next operations will be introduced."""

symbol_table: ScopedSymbolTable | None = None
symbol_table: ScopedDict[str, SSAValue] | None = None
"""
The symbol table maps a variable name to a value in the current scope.
Entering a function creates a new scope, and the function arguments are
Expand Down Expand Up @@ -156,7 +139,7 @@ def ir_gen_function(self, function_ast: FunctionAST) -> FuncOp:
parent_builder = self.builder

# Create a scope in the symbol table to hold variable declarations.
self.symbol_table = ScopedSymbolTable()
self.symbol_table = ScopedDict[str, SSAValue]()

proto_args = function_ast.proto.args

Expand Down
34 changes: 34 additions & 0 deletions tests/utils/test_scoped_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest

from xdsl.utils.scoped_dict import ScopedDict


def test_simple():
    """Check local writes, overwrite rejection, and parent fallback of ScopedDict."""
    outer = ScopedDict[int, int]()

    # Values written to a scope can be read back from that scope.
    outer[1] = 2
    assert outer[1] == 2
    outer[2] = 3
    assert outer[2] == 3

    # A key may be assigned only once within a given scope.
    with pytest.raises(ValueError, match="Cannot overwrite value 3 for key 2"):
        outer[2] = 4

    # Looking up a key that was never set raises KeyError.
    with pytest.raises(KeyError):
        outer[3]

    nested = ScopedDict(outer, name="inner")

    # A child scope may shadow a parent key without changing the parent.
    nested[2] = 4
    assert nested[2] == 4
    assert outer[2] == 3

    # Keys added to the child are not visible from the parent.
    nested[3] = 5
    assert 3 not in outer
    assert 3 in nested
    assert 4 not in nested
73 changes: 20 additions & 53 deletions xdsl/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import platform
from collections import Counter
from collections.abc import Callable, Generator, Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from typing import (
IO,
Expand All @@ -28,6 +28,7 @@
)
from xdsl.traits import CallableOpInterface, IsTerminator, SymbolOpInterface
from xdsl.utils.exceptions import InterpretationError
from xdsl.utils.scoped_dict import ScopedDict

_IMPL_OP_TYPE = "__impl_op_type"
_CAST_IMPL_TYPES = "__cast_impl_types"
Expand Down Expand Up @@ -503,53 +504,6 @@ def call(
return ext_func(ft, interpreter, op, args)


@dataclass
class InterpreterContext:
    """
    Holds the Python values associated with SSAValues while interpreting.

    Contexts form a chain of scopes through `parent`: assignments go to the
    current scope, while lookups fall back to enclosing scopes.
    """

    name: str = field(default="unknown")
    parent: InterpreterContext | None = None
    env: dict[SSAValue, Any] = field(default_factory=dict)

    def __getitem__(self, key: SSAValue) -> Any:
        """
        Look `key` up in this scope, then in each enclosing scope in turn.
        Raises InterpretationError once the outermost scope has no value either.
        """
        if key in self.env:
            return self.env[key]
        enclosing = self.parent
        if enclosing is None:
            raise InterpretationError(f"Could not find value for {key} in {self}")
        return enclosing[key]

    def __setitem__(self, key: SSAValue, value: Any):
        """
        Bind `key` to `value` in this scope. Raises InterpretationError if this
        scope already holds a value for `key`.
        """
        if key not in self.env:
            self.env[key] = value
            return
        raise InterpretationError(
            f"Attempting to register SSAValue {value} for name {key}"
            f", but value with that name already exists in {self}"
        )

    def stack(self) -> Generator[InterpreterContext, None, None]:
        """
        Yield the chain of scopes, outermost (root) first and this scope last.
        """
        chain: list[InterpreterContext] = []
        scope: InterpreterContext | None = self
        while scope is not None:
            chain.append(scope)
            scope = scope.parent
        # The walk above runs innermost-first; callers expect root-first order.
        yield from reversed(chain)

    def __format__(self, __format_spec: str) -> str:
        """Render the scope chain as "/"-joined names; the format spec is ignored."""
        names = [scope.name for scope in self.stack()]
        return "/".join(names)


def _get_system_bitwidth() -> Literal[32, 64] | None:
match platform.architecture()[0]:
case "64bit":
Expand Down Expand Up @@ -590,9 +544,14 @@ def did_interpret_op(self, op: Operation, results: PythonValues) -> None: ...
Number of bits in the binary representation of the index
"""
_impls: _InterpreterFunctionImpls = field(default_factory=_InterpreterFunctionImpls)
_ctx: InterpreterContext = field(
default_factory=lambda: InterpreterContext(name="root")
_ctx: ScopedDict[SSAValue, Any] = field(
default_factory=lambda: ScopedDict(name="root")
)
"""
Object holding the Python values associated with SSAValues during an
interpretation context. An environment is a stack of scopes, values are
assigned to the current scope, but can be fetched from a parent scope.
"""
file: IO[str] | None = field(default=None)
_symbol_table: dict[str, Operation] | None = None
_impl_data: dict[type[InterpreterFunctions], dict[str, Any]] = field(
Expand Down Expand Up @@ -630,11 +589,11 @@ def set_values(self, pairs: Iterable[tuple[SSAValue, Any]]):
for ssa_value, result_value in pairs:
self._ctx[ssa_value] = result_value

def push_scope(self, name: str = "unknown") -> None:
def push_scope(self, name: str | None = None) -> None:
"""
Create new scope in current environment, with optional custom `name`.
"""
self._ctx = InterpreterContext(name, self._ctx)
self._ctx = ScopedDict(name=name, parent=self._ctx)

def pop_scope(self) -> None:
"""
Expand Down Expand Up @@ -814,8 +773,16 @@ def interpreter_assert(self, condition: bool, message: str | None = None):
if not condition:
self.raise_error(message)

def scope_names(self):
    """
    Yield the name of each scope, starting at the current scope (`self._ctx`)
    and following `parent` links until there are none left. A scope whose
    name is falsy (e.g. None) is reported as "unknown".
    """
    current = self._ctx
    while True:
        if current is None:
            return
        yield current.name or "unknown"
        current = current.parent

def raise_error(self, message: str | None = None):
raise InterpretationError(f"AssertionError: ({self._ctx})({message})")
scope_description = "/".join(self.scope_names())
raise InterpretationError(f"AssertionError: ({scope_description})({message})")


@dataclass
Expand Down
62 changes: 62 additions & 0 deletions xdsl/utils/scoped_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

from typing import Generic, TypeVar

_Key = TypeVar("_Key")
_Value = TypeVar("_Value")


class ScopedDict(Generic[_Key, _Value]):
    """
    A tiered mapping from keys to values.

    Once a value is set for a key in a given scope, it cannot be overwritten in
    that scope.
    A ScopedDict may have a parent dict, which is used as a fallback when a value
    for a key is not found locally.
    If a ScopedDict and its parent have values for the same key, the child value
    will be returned.
    Any value may be stored, including `None`; a key bound to `None` counts as
    present.

    This structure is useful for contexts where keys and values have a known
    scope, such as during IR construction from an Abstract Syntax Tree.
    ScopedDict instances may have a `name` property as a hint during debugging.
    """

    _local_scope: dict[_Key, _Value]
    parent: ScopedDict[_Key, _Value] | None
    name: str | None

    def __init__(
        self,
        parent: ScopedDict[_Key, _Value] | None = None,
        *,
        name: str | None = None,
    ) -> None:
        """
        Create an empty scope.

        `parent` is the enclosing scope consulted when a key is missing locally
        (None for a root scope); `name` is an optional debugging label.
        """
        self._local_scope = {}
        self.parent = parent
        self.name = name

    def __getitem__(self, key: _Key) -> _Value:
        """
        Fetch the value for `key`, looking first in the current scope and then
        in each parent scope in turn. Raises KeyError if no scope has the key.
        """
        scope: ScopedDict[_Key, _Value] | None = self
        while scope is not None:
            # Test membership rather than comparing the value against None, so
            # that a key explicitly bound to None (or any falsy value) is found.
            if key in scope._local_scope:
                return scope._local_scope[key]
            scope = scope.parent
        raise KeyError(f"No value for key {key}")

    def __setitem__(self, key: _Key, value: _Value) -> None:
        """
        Assign `value` to `key` in the current scope. Raises ValueError if the
        current scope already has a value for `key`; values in parent scopes
        are shadowed, not modified.
        """
        if key in self._local_scope:
            raise ValueError(
                f"Cannot overwrite value {self._local_scope[key]} for key {key}"
            )
        self._local_scope[key] = value

    def __contains__(self, key: object) -> bool:
        """
        Return whether `key` has a value in the current scope or any parent.
        """
        scope: ScopedDict[_Key, _Value] | None = self
        while scope is not None:
            if key in scope._local_scope:
                return True
            scope = scope.parent
        return False

0 comments on commit e6c7282

Please sign in to comment.