-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
misc: add a ScopedDict util, similar to LLVM's ScopedHashTable (#3355)
- Loading branch information
1 parent
4040e03
commit e6c7282
Showing
4 changed files
with
120 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import pytest | ||
|
||
from xdsl.utils.scoped_dict import ScopedDict | ||
|
||
|
||
def test_simple(): | ||
table = ScopedDict[int, int]() | ||
|
||
table[1] = 2 | ||
|
||
assert table[1] == 2 | ||
|
||
table[2] = 3 | ||
|
||
assert table[2] == 3 | ||
|
||
with pytest.raises(ValueError, match="Cannot overwrite value 3 for key 2"): | ||
table[2] = 4 | ||
|
||
with pytest.raises(KeyError): | ||
table[3] | ||
|
||
inner = ScopedDict(table, name="inner") | ||
|
||
inner[2] = 4 | ||
|
||
assert inner[2] == 4 | ||
assert table[2] == 3 | ||
|
||
inner[3] = 5 | ||
|
||
assert 3 not in table | ||
assert 3 in inner | ||
assert 4 not in inner |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Generic, TypeVar | ||
|
||
_Key = TypeVar("_Key") | ||
_Value = TypeVar("_Value") | ||
|
||
|
||
class ScopedDict(Generic[_Key, _Value]): | ||
""" | ||
A tiered mapping from keys to values. | ||
Once a value is set for a key, it cannot be overwritten. | ||
A ScopedDict may have a parent dict, which is used as a fallback when a value for a | ||
key is not found. | ||
If a ScopedDict and its parent have values for the same key, the child value will be | ||
returned. | ||
This structure is useful for contexts where keys and values have a known scope, such | ||
as during IR construction from an Abstract Syntax Tree. | ||
ScopedDict instances may have a `name` property as a hint during debugging. | ||
""" | ||
|
||
_local_scope: dict[_Key, _Value] | ||
parent: ScopedDict[_Key, _Value] | None | ||
name: str | None | ||
|
||
def __init__( | ||
self, | ||
parent: ScopedDict[_Key, _Value] | None = None, | ||
*, | ||
name: str | None = None, | ||
) -> None: | ||
self._local_scope = {} | ||
self.parent = parent | ||
self.name = name | ||
|
||
def __getitem__(self, key: _Key) -> _Value: | ||
""" | ||
Fetch key from environment. Attempts to first fetch from current scope, | ||
then from parent scopes. Raises KeyError error if not found. | ||
""" | ||
local = self._local_scope.get(key) | ||
if local is not None: | ||
return local | ||
if self.parent is None: | ||
raise KeyError(f"No value for key {key}") | ||
return self.parent[key] | ||
|
||
def __setitem__(self, key: _Key, value: _Value): | ||
""" | ||
Assign key to current scope. Raises InterpretationError if key already | ||
assigned to. | ||
""" | ||
if key in self._local_scope: | ||
raise ValueError( | ||
f"Cannot overwrite value {self._local_scope[key]} for key {key}" | ||
) | ||
self._local_scope[key] = value | ||
|
||
def __contains__(self, key: _Key) -> bool: | ||
return ( | ||
key in self._local_scope or self.parent is not None and key in self.parent | ||
) |