Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

misc: add a ScopedDict util, similar to LLVM's ScopedHashTable #3355

Merged
merged 3 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 4 additions & 21 deletions docs/Toy/toy/frontend/ir_gen.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

from collections.abc import Iterable
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import NoReturn

from xdsl.builder import Builder
from xdsl.dialects.builtin import ModuleOp, TensorType, UnrankedTensorType, f64
from xdsl.ir import Block, Region, SSAValue
from xdsl.utils.scoped_dict import ScopedDict

from ..dialects.toy import (
AddOp,
Expand Down Expand Up @@ -43,24 +44,6 @@ class IRGenError(Exception):
pass


@dataclass
class ScopedSymbolTable:
    """
    Append-only mapping from variable names to SSAValues.

    A name may be bound exactly once; binding it a second time raises an
    AssertionError.
    """

    # Backing storage for the bindings of this scope.
    table: dict[str, SSAValue] = field(default_factory=dict)

    def __contains__(self, __o: object) -> bool:
        """Return True when `__o` is a name already bound in this table."""
        bound_names = self.table
        return __o in bound_names

    def __getitem__(self, __key: str) -> SSAValue:
        """Return the value bound to `__key`; a missing name raises KeyError."""
        bound_names = self.table
        return bound_names[__key]

    def __setitem__(self, __key: str, __value: SSAValue) -> None:
        """Bind `__key` to `__value`, refusing to rebind an existing name."""
        if __key not in self:
            self.table[__key] = __value
            return
        raise AssertionError(f"Cannot add value for key {__key} in scope {self}")


@dataclass(init=False)
class IRGen:
"""
Expand All @@ -80,7 +63,7 @@ class IRGen:
is stateful, in particular it keeps an "insertion point": this is where
the next operations will be introduced."""

symbol_table: ScopedSymbolTable | None = None
symbol_table: ScopedDict[str, SSAValue] | None = None
"""
The symbol table maps a variable name to a value in the current scope.
Entering a function creates a new scope, and the function arguments are
Expand Down Expand Up @@ -156,7 +139,7 @@ def ir_gen_function(self, function_ast: FunctionAST) -> FuncOp:
parent_builder = self.builder

# Create a scope in the symbol table to hold variable declarations.
self.symbol_table = ScopedSymbolTable()
self.symbol_table = ScopedDict[str, SSAValue]()

proto_args = function_ast.proto.args

Expand Down
34 changes: 34 additions & 0 deletions tests/utils/test_scoped_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest

from xdsl.utils.scoped_dict import ScopedDict


def test_simple():
    """Check binding, lookup, overwrite protection and parent fallback of ScopedDict."""
    outer = ScopedDict[int, int]()

    # Values bound in a scope can be read back from that scope.
    outer[1] = 2
    assert outer[1] == 2

    outer[2] = 3
    assert outer[2] == 3

    # Rebinding a key within the same scope is rejected.
    with pytest.raises(ValueError, match="Cannot overwrite value 3 for key 2"):
        outer[2] = 4

    # Looking up a key that was never bound fails.
    with pytest.raises(KeyError):
        outer[3]

    nested = ScopedDict(outer, name="inner")

    # A child scope may shadow a key of its parent without touching the parent.
    nested[2] = 4
    assert nested[2] == 4
    assert outer[2] == 3

    # Keys bound in the child are not visible from the parent.
    nested[3] = 5
    assert 3 not in outer
    assert 3 in nested
    assert 4 not in nested
73 changes: 20 additions & 53 deletions xdsl/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import platform
from collections import Counter
from collections.abc import Callable, Generator, Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from typing import (
IO,
Expand All @@ -28,6 +28,7 @@
)
from xdsl.traits import CallableOpInterface, IsTerminator, SymbolOpInterface
from xdsl.utils.exceptions import InterpretationError
from xdsl.utils.scoped_dict import ScopedDict

_IMPL_OP_TYPE = "__impl_op_type"
_CAST_IMPL_TYPES = "__cast_impl_types"
Expand Down Expand Up @@ -503,53 +504,6 @@ def call(
return ext_func(ft, interpreter, op, args)


@dataclass
class InterpreterContext:
    """
    One scope of the environment that associates SSAValues with Python values
    during interpretation.

    Scopes are linked through `parent`: new bindings always land in this scope,
    while lookups that miss here are retried on the parent chain.
    """

    name: str = "unknown"
    parent: InterpreterContext | None = None
    env: dict[SSAValue, Any] = field(default_factory=dict)

    def __getitem__(self, key: SSAValue) -> Any:
        """
        Look `key` up in this scope, then in each enclosing scope in turn.
        Raises InterpretationError when no scope in the chain holds it.
        """
        bindings = self.env
        if key in bindings:
            return bindings[key]
        if self.parent is None:
            raise InterpretationError(f"Could not find value for {key} in {self}")
        return self.parent[key]

    def __setitem__(self, key: SSAValue, value: Any):
        """
        Bind `key` in this scope. Raises InterpretationError when this scope
        already holds a binding for `key`.
        """
        if key not in self.env:
            self.env[key] = value
            return
        raise InterpretationError(
            f"Attempting to register SSAValue {value} for name {key}"
            f", but value with that name already exists in {self}"
        )

    def stack(self) -> Generator[InterpreterContext, None, None]:
        """
        Yield every scope of the chain, root scope first and this scope last.
        """
        chain: list[InterpreterContext] = []
        scope: InterpreterContext | None = self
        while scope is not None:
            chain.append(scope)
            scope = scope.parent
        # The chain was collected innermost-first; emit it outermost-first.
        yield from reversed(chain)

    def __format__(self, __format_spec: str) -> str:
        """Render the scope chain as a `/`-separated path of scope names."""
        names = [scope.name for scope in self.stack()]
        return "/".join(names)


def _get_system_bitwidth() -> Literal[32, 64] | None:
match platform.architecture()[0]:
case "64bit":
Expand Down Expand Up @@ -590,9 +544,14 @@ def did_interpret_op(self, op: Operation, results: PythonValues) -> None: ...
Number of bits in the binary representation of the index
"""
_impls: _InterpreterFunctionImpls = field(default_factory=_InterpreterFunctionImpls)
_ctx: InterpreterContext = field(
default_factory=lambda: InterpreterContext(name="root")
_ctx: ScopedDict[SSAValue, Any] = field(
default_factory=lambda: ScopedDict(name="root")
)
"""
Object holding the Python values associated with SSAValues during an
interpretation context. An environment is a stack of scopes, values are
assigned to the current scope, but can be fetched from a parent scope.
"""
file: IO[str] | None = field(default=None)
_symbol_table: dict[str, Operation] | None = None
_impl_data: dict[type[InterpreterFunctions], dict[str, Any]] = field(
Expand Down Expand Up @@ -630,11 +589,11 @@ def set_values(self, pairs: Iterable[tuple[SSAValue, Any]]):
for ssa_value, result_value in pairs:
self._ctx[ssa_value] = result_value

def push_scope(self, name: str = "unknown") -> None:
def push_scope(self, name: str | None = None) -> None:
"""
Create new scope in current environment, with optional custom `name`.
"""
self._ctx = InterpreterContext(name, self._ctx)
self._ctx = ScopedDict(name=name, parent=self._ctx)

def pop_scope(self) -> None:
"""
Expand Down Expand Up @@ -814,8 +773,16 @@ def interpreter_assert(self, condition: bool, message: str | None = None):
if not condition:
self.raise_error(message)

def scope_names(self):
    """
    Yield the name of every scope from the current one outwards through its
    parents, substituting "unknown" for scopes without a name.
    """
    scope = self._ctx

    while scope is not None:
        label = scope.name or "unknown"
        yield label
        scope = scope.parent

def raise_error(self, message: str | None = None):
raise InterpretationError(f"AssertionError: ({self._ctx})({message})")
scope_description = "/".join(self.scope_names())
raise InterpretationError(f"AssertionError: ({scope_description})({message})")


@dataclass
Expand Down
62 changes: 62 additions & 0 deletions xdsl/utils/scoped_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

from typing import Generic, TypeVar

_Key = TypeVar("_Key")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be bound to `Hashable` or something similar? I'm not sure why pyright doesn't complain here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`Hashable` isn't a real type, AFAIK, and the `dict` generic type uses the same unbounded TypeVar definition for its keys.

_Value = TypeVar("_Value")


class ScopedDict(Generic[_Key, _Value]):
    """
    A tiered mapping from keys to values.
    Once a value is set for a key in a scope, it cannot be overwritten in that scope.
    A ScopedDict may have a parent dict, which is used as a fallback when a value for a
    key is not found.
    If a ScopedDict and its parent have values for the same key, the child value will be
    returned.
    This structure is useful for contexts where keys and values have a known scope, such
    as during IR construction from an Abstract Syntax Tree.
    ScopedDict instances may have a `name` property as a hint during debugging.
    """

    _local_scope: dict[_Key, _Value]
    parent: ScopedDict[_Key, _Value] | None
    name: str | None

    def __init__(
        self,
        parent: ScopedDict[_Key, _Value] | None = None,
        *,
        name: str | None = None,
    ) -> None:
        """
        Create an empty scope, optionally nested inside `parent` and optionally
        labelled with `name`.
        """
        self._local_scope = {}
        self.parent = parent
        self.name = name

    def __getitem__(self, key: _Key) -> _Value:
        """
        Fetch key from environment. Attempts to first fetch from current scope,
        then from parent scopes. Raises KeyError if not found.
        """
        # Test membership instead of comparing the looked-up value to None, so that
        # a key explicitly bound to None is found here rather than being treated as
        # missing (which would disagree with `__contains__`).
        if key in self._local_scope:
            return self._local_scope[key]
        if self.parent is None:
            raise KeyError(f"No value for key {key}")
        return self.parent[key]

    def __setitem__(self, key: _Key, value: _Value) -> None:
        """
        Assign key to current scope. Raises ValueError if key already
        assigned to in this scope; keys of parent scopes may be shadowed.
        """
        if key in self._local_scope:
            # ValueError rather than KeyError: KeyError signals that nothing exists
            # for a given key, whereas here a value already exists.
            raise ValueError(
                f"Cannot overwrite value {self._local_scope[key]} for key {key}"
            )
        self._local_scope[key] = value

    def __contains__(self, key: object) -> bool:
        """
        Return True if key is bound in the current scope or in any parent scope.
        """
        if key in self._local_scope:
            return True
        return self.parent is not None and key in self.parent
Loading