Skip to content

Commit

Permalink
Added validation and unit tests for storgage models
Browse files Browse the repository at this point in the history
* TaskReturnPattern: Confirm that the input pattern is a string type and that it is not empty.
* Traceback: Confirm that the input is a list of strings and that none of them are empty.
  • Loading branch information
ianmkenney committed Jul 18, 2024
1 parent b7f63d4 commit 7e82f54
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 16 deletions.
15 changes: 15 additions & 0 deletions alchemiscale/storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ class TaskRestartPattern(GufeTokenizable):
max_retries: int

def __init__(self, pattern: str, max_retries: int):

if not isinstance(pattern, str) or pattern == "":
raise ValueError("`pattern` must be a non-empty string")

self.pattern = pattern

if not isinstance(max_retries, int) or max_retries <= 0:
Expand Down Expand Up @@ -189,6 +193,17 @@ def __eq__(self, other):
class Traceback(GufeTokenizable):

def __init__(self, tracebacks: List[str]):
value_error = ValueError(
"`tracebacks` must be a non-empty list of string values"
)
if not isinstance(tracebacks, list) or tracebacks == []:
raise value_error
else:
# in the case where tracebacks is not an iterable, this will raise a TypeError
all_string_values = all([isinstance(value, str) for value in tracebacks])
if not all_string_values or "" in tracebacks:
raise value_error

self.tracebacks = tracebacks

def _gufe_tokenize(self):
Expand Down
117 changes: 101 additions & 16 deletions alchemiscale/tests/unit/test_storage_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import pytest

from alchemiscale.storage.models import NetworkStateEnum, NetworkMark
from alchemiscale.storage.models import (
NetworkStateEnum,
NetworkMark,
TaskRestartPattern,
Traceback,
)
from alchemiscale import ScopedKey


Expand Down Expand Up @@ -42,33 +47,113 @@ def test_suggested_states_message(self):

class TestTaskRestartPattern(object):

@pytest.mark.xfail(raises=NotImplementedError)
pattern_value_error = "`pattern` must be a non-empty string"
max_retries_value_error = "`max_retries` must have a positive integer value."

def test_empty_pattern(self):
raise NotImplementedError
with pytest.raises(ValueError, match=self.pattern_value_error):
_ = TaskRestartPattern("", 3)

def test_non_string_pattern(self):
with pytest.raises(ValueError, match=self.pattern_value_error):
_ = TaskRestartPattern(None, 3)

with pytest.raises(ValueError, match=self.pattern_value_error):
_ = TaskRestartPattern([], 3)

def test_non_positive_max_retries(self):

@pytest.mark.xfail(raises=NotImplementedError)
def test_negative_max_retries(self):
raise NotImplementedError
with pytest.raises(ValueError, match=self.max_retries_value_error):
TaskRestartPattern("Example pattern", 0)

with pytest.raises(ValueError, match=self.max_retries_value_error):
TaskRestartPattern("Example pattern", -1)

@pytest.mark.xfail(raises=NotImplementedError)
def test_non_int_max_retries(self):
raise NotImplementedError
with pytest.raises(ValueError, match=self.max_retries_value_error):
TaskRestartPattern("Example pattern", 4.0)

@pytest.mark.xfail(raises=NotImplementedError)
def test_to_dict(self):
raise NotImplementedError
trp = TaskRestartPattern("Example pattern", 3)
dict_trp = trp.to_dict()

assert len(dict_trp.keys()) == 5

assert dict_trp.pop("__qualname__") == "TaskRestartPattern"
assert dict_trp.pop("__module__") == "alchemiscale.storage.models"

# light test of the version key
try:
dict_trp.pop(":version:")
except KeyError:
raise AssertionError("expected to find :version:")

expected = {"pattern": "Example pattern", "max_retries": 3}

assert expected == dict_trp

@pytest.mark.xfail(raises=NotImplementedError)
def test_from_dict(self):
raise NotImplementedError

original_pattern = "Example pattern"
original_max_retries = 3

trp_orig = TaskRestartPattern(original_pattern, original_max_retries)
trp_dict = trp_orig.to_dict()
trp_reconstructed: TaskRestartPattern = TaskRestartPattern.from_dict(trp_dict)

assert trp_reconstructed.pattern == original_pattern
assert trp_reconstructed.max_retries == original_max_retries


class TestTraceback(object):

@pytest.mark.xfail(raises=NotImplementedError)
valid_entry = ["traceback1", "traceback2", "traceback3"]
tracebacks_value_error = "`tracebacks` must be a non-empty list of string values"

def test_empty_string_element(self):
with pytest.raises(ValueError, match=self.tracebacks_value_error):
Traceback(self.valid_entry + [""])

def test_non_list_parameter(self):
with pytest.raises(ValueError, match=self.tracebacks_value_error):
Traceback(None)

with pytest.raises(ValueError, match=self.tracebacks_value_error):
Traceback(100)

with pytest.raises(ValueError, match=self.tracebacks_value_error):
Traceback("not a list, but still an iterable that yields strings")

def test_list_non_string_elements(self):
with pytest.raises(ValueError, match=self.tracebacks_value_error):
Traceback(self.valid_entry + [None])

def test_empty_list(self):
with pytest.raises(ValueError, match=self.tracebacks_value_error):
Traceback([])

def test_to_dict(self):
raise NotImplementedError
tb = Traceback(self.valid_entry)
tb_dict = tb.to_dict()

assert len(tb_dict) == 4

assert tb_dict.pop("__qualname__") == "Traceback"
assert tb_dict.pop("__module__") == "alchemiscale.storage.models"

# light test of the version key
try:
tb_dict.pop(":version:")
except KeyError:
raise AssertionError("expected to find :version:")

expected = {"tracebacks": self.valid_entry}

assert expected == tb_dict

@pytest.mark.xfail(raises=NotImplementedError)
def test_from_dict(self):
raise NotImplementedError
tb_orig = Traceback(self.valid_entry)
tb_dict = tb_orig.to_dict()
tb_reconstructed: TaskRestartPattern = TaskRestartPattern.from_dict(tb_dict)

assert tb_reconstructed.tracebacks == self.valid_entry

0 comments on commit 7e82f54

Please sign in to comment.