Skip to content

Commit

Permalink
fix controlled dict init
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Nov 27, 2024
1 parent 9e45e02 commit 74c4fd2
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 32 deletions.
37 changes: 17 additions & 20 deletions src/monty/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ class ControlledDict(collections.UserDict, ABC):
allow_del: ClassVar[bool] = True
allow_update: ClassVar[bool] = True

def __init__(self, *args, **kwargs):
"""Temporarily allow all add/update during initialization."""
original_allow_add = self.allow_add
original_allow_update = self.allow_update
try:
self.allow_add = True
self.allow_update = True
super().__init__(*args, **kwargs)
finally:
self.allow_add = original_allow_add
self.allow_update = original_allow_update

# TODO: extract checkers

# Overriding add/update operations
Expand Down Expand Up @@ -142,30 +154,15 @@ def clear(self):
super().clear()


class frozendict(dict):
class frozendict(ControlledDict):
"""
A dictionary that does not permit changes. The naming
violates PEP 8 to be consistent with the built-in "frozenset" naming.
violates PEP 8 to be consistent with the built-in `frozenset` naming.
"""

def __init__(self, *args, **kwargs) -> None:
"""
Args:
args: Passthrough arguments for standard dict.
kwargs: Passthrough keyword arguments for standard dict.
"""
dict.__init__(self, *args, **kwargs)

def __setitem__(self, key: Any, val: Any) -> None:
raise TypeError(f"{type(self).__name__} does not support item assignment")

def update(self, *args, **kwargs) -> None:
"""
Args:
args: Passthrough arguments for standard dict.
kwargs: Passthrough keyword arguments for standard dict.
"""
raise TypeError(f"Cannot update a {self.__class__.__name__}")
allow_add: ClassVar[bool] = False
allow_del: ClassVar[bool] = False
allow_update: ClassVar[bool] = False


class Namespace(dict): # TODO: this name is a bit confusing, deprecate it?
Expand Down
36 changes: 24 additions & 12 deletions tests/test_collections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from collections import UserDict

import pytest

from monty.collections import (
Expand Down Expand Up @@ -108,32 +110,42 @@ def test_del_disabled(self):
with pytest.raises(TypeError, match="delete is disabled"):
dct.clear()

def test_frozen_like(self):
"""Make sure setter is allow at init time."""
ControlledDict.allow_add = False
ControlledDict.allow_update = False

dct = ControlledDict({"hello": "world"})
assert isinstance(dct, UserDict)
assert dct["hello"] == "world"

assert not dct.allow_add
assert not dct.allow_update


def test_frozendict():
dct = frozendict({"hello": "world"})
assert isinstance(dct, dict)
assert isinstance(dct, UserDict)
assert dct["hello"] == "world"

assert not dct.allow_add
assert not dct.allow_update
assert not dct.allow_del

# Test setter
with pytest.raises(TypeError, match="Cannot overwrite existing key"):
dct["key"] == "val"
with pytest.raises(TypeError, match="allow_add is set to False"):
dct["key"] = "val"

# Test update
with pytest.raises(TypeError, match="Cannot overwrite existing key"):
with pytest.raises(TypeError, match="allow_add is set to False"):
dct.update(key="val")

# Test inplace-or (|=)
with pytest.raises(TypeError, match="Cannot overwrite existing key"):
dct |= {"key": "val"}

# TODO: from this point we need a different error message

# Test pop
with pytest.raises(TypeError, match="Cannot overwrite existing key"):
with pytest.raises(TypeError, match="delete is disabled"):
dct.pop("key")

# Test delete
with pytest.raises(TypeError, match="Cannot overwrite existing key"):
with pytest.raises(TypeError, match="delete is disabled"):
del dct["key"]


Expand Down

0 comments on commit 74c4fd2

Please sign in to comment.