diff --git a/alchemiscale/storage/models.py b/alchemiscale/storage/models.py index b9090160..fae7af93 100644 --- a/alchemiscale/storage/models.py +++ b/alchemiscale/storage/models.py @@ -159,6 +159,10 @@ class TaskRestartPattern(GufeTokenizable): max_retries: int def __init__(self, pattern: str, max_retries: int): + + if not isinstance(pattern, str) or pattern == "": + raise ValueError("`pattern` must be a non-empty string") + self.pattern = pattern if not isinstance(max_retries, int) or max_retries <= 0: @@ -189,6 +193,17 @@ def __eq__(self, other): class Traceback(GufeTokenizable): def __init__(self, tracebacks: List[str]): + value_error = ValueError( + "`tracebacks` must be a non-empty list of string values" + ) + if not isinstance(tracebacks, list) or tracebacks == []: + raise value_error + else: + # in the case where tracebacks is not an iterable, this will raise a TypeError + all_string_values = all([isinstance(value, str) for value in tracebacks]) + if not all_string_values or "" in tracebacks: + raise value_error + self.tracebacks = tracebacks def _gufe_tokenize(self): diff --git a/alchemiscale/tests/unit/test_storage_models.py b/alchemiscale/tests/unit/test_storage_models.py index 68c9b8c7..02fe188e 100644 --- a/alchemiscale/tests/unit/test_storage_models.py +++ b/alchemiscale/tests/unit/test_storage_models.py @@ -1,6 +1,11 @@ import pytest -from alchemiscale.storage.models import NetworkStateEnum, NetworkMark +from alchemiscale.storage.models import ( + NetworkStateEnum, + NetworkMark, + TaskRestartPattern, + Traceback, +) from alchemiscale import ScopedKey @@ -42,33 +47,113 @@ def test_suggested_states_message(self): class TestTaskRestartPattern(object): - @pytest.mark.xfail(raises=NotImplementedError) + pattern_value_error = "`pattern` must be a non-empty string" + max_retries_value_error = "`max_retries` must have a positive integer value." + def test_empty_pattern(self): - raise NotImplementedError + with pytest.raises(ValueError, match=self.pattern_value_error): + _ = TaskRestartPattern("", 3) + + def test_non_string_pattern(self): + with pytest.raises(ValueError, match=self.pattern_value_error): + _ = TaskRestartPattern(None, 3) + + with pytest.raises(ValueError, match=self.pattern_value_error): + _ = TaskRestartPattern([], 3) + + def test_non_positive_max_retries(self): - @pytest.mark.xfail(raises=NotImplementedError) - def test_negative_max_retries(self): - raise NotImplementedError + with pytest.raises(ValueError, match=self.max_retries_value_error): + TaskRestartPattern("Example pattern", 0) + + with pytest.raises(ValueError, match=self.max_retries_value_error): + TaskRestartPattern("Example pattern", -1) - @pytest.mark.xfail(raises=NotImplementedError) def test_non_int_max_retries(self): - raise NotImplementedError + with pytest.raises(ValueError, match=self.max_retries_value_error): + TaskRestartPattern("Example pattern", 4.0) - @pytest.mark.xfail(raises=NotImplementedError) def test_to_dict(self): - raise NotImplementedError + trp = TaskRestartPattern("Example pattern", 3) + dict_trp = trp.to_dict() + + assert len(dict_trp.keys()) == 5 + + assert dict_trp.pop("__qualname__") == "TaskRestartPattern" + assert dict_trp.pop("__module__") == "alchemiscale.storage.models" + + # light test of the version key + try: + dict_trp.pop(":version:") + except KeyError: + raise AssertionError("expected to find :version:") + + expected = {"pattern": "Example pattern", "max_retries": 3} + + assert expected == dict_trp - @pytest.mark.xfail(raises=NotImplementedError) def test_from_dict(self): - raise NotImplementedError + + original_pattern = "Example pattern" + original_max_retries = 3 + + trp_orig = TaskRestartPattern(original_pattern, original_max_retries) + trp_dict = trp_orig.to_dict() + trp_reconstructed: TaskRestartPattern = TaskRestartPattern.from_dict(trp_dict) + + assert trp_reconstructed.pattern == original_pattern + assert trp_reconstructed.max_retries == original_max_retries class TestTraceback(object): - @pytest.mark.xfail(raises=NotImplementedError) + valid_entry = ["traceback1", "traceback2", "traceback3"] + tracebacks_value_error = "`tracebacks` must be a non-empty list of string values" + + def test_empty_string_element(self): + with pytest.raises(ValueError, match=self.tracebacks_value_error): + Traceback(self.valid_entry + [""]) + + def test_non_list_parameter(self): + with pytest.raises(ValueError, match=self.tracebacks_value_error): + Traceback(None) + + with pytest.raises(ValueError, match=self.tracebacks_value_error): + Traceback(100) + + with pytest.raises(ValueError, match=self.tracebacks_value_error): + Traceback("not a list, but still an iterable that yields strings") + + def test_list_non_string_elements(self): + with pytest.raises(ValueError, match=self.tracebacks_value_error): + Traceback(self.valid_entry + [None]) + + def test_empty_list(self): + with pytest.raises(ValueError, match=self.tracebacks_value_error): + Traceback([]) + def test_to_dict(self): - raise NotImplementedError + tb = Traceback(self.valid_entry) + tb_dict = tb.to_dict() + + assert len(tb_dict) == 4 + + assert tb_dict.pop("__qualname__") == "Traceback" + assert tb_dict.pop("__module__") == "alchemiscale.storage.models" + + # light test of the version key + try: + tb_dict.pop(":version:") + except KeyError: + raise AssertionError("expected to find :version:") + + expected = {"tracebacks": self.valid_entry} + + assert expected == tb_dict - @pytest.mark.xfail(raises=NotImplementedError) def test_from_dict(self): - raise NotImplementedError + tb_orig = Traceback(self.valid_entry) + tb_dict = tb_orig.to_dict() + tb_reconstructed: TaskRestartPattern = TaskRestartPattern.from_dict(tb_dict) + + assert tb_reconstructed.tracebacks == self.valid_entry