Skip to content

Commit

Permalink
test: mock test cases for xlmr model.
Browse files Browse the repository at this point in the history
  • Loading branch information
ltbringer committed Aug 29, 2021
1 parent 34eb65f commit f3fc4aa
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/plugin/text/classification/test_xlmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,27 @@ def test_xlmr_plugin_no_module_error():


def test_xlmr_plugin_when_no_labelencoder_saved():
save_module_name = const.XLMR_MODULE
save_model_name = const.XLMR_MULTI_CLASS_MODEL
const.XLMR_MODULE = "tests.plugin.text.classification.test_xlmr"
const.XLMR_MULTI_CLASS_MODEL = "MockClassifier"

xlmr_clf = XLMRMultiClass(
model_dir=".",
access=lambda w: w.input[const.CLASSIFICATION_INPUT],
mutate=write_intent_to_workflow,
)
assert isinstance(xlmr_clf, XLMRMultiClass)
assert xlmr_clf.model is None
const.XLMR_MODULE = save_module_name
const.XLMR_MULTI_CLASS_MODEL = save_model_name


def test_xlmr_plugin_when_labelencoder_EOFError(capsys):
save_module_name = const.XLMR_MODULE
save_model_name = const.XLMR_MULTI_CLASS_MODEL
const.XLMR_MODULE = "tests.plugin.text.classification.test_xlmr"
const.XLMR_MULTI_CLASS_MODEL = "MockClassifier"
_, file_path = tempfile.mkstemp(suffix=".pkl")
save_label_encoder_file = const.LABELENCODER_FILE
directory, file_name = os.path.split(file_path)
Expand All @@ -94,6 +105,8 @@ def test_xlmr_plugin_when_labelencoder_EOFError(capsys):
assert xlmr_plugin.model is None
os.remove(file_path)
const.LABELENCODER_FILE = save_label_encoder_file
const.XLMR_MODULE = save_module_name
const.XLMR_MULTI_CLASS_MODEL = save_model_name


def test_xlmr_init_mock():
Expand Down Expand Up @@ -155,6 +168,11 @@ def test_train_xlmr_mock():


def test_invalid_operations():
save_module_name = const.XLMR_MODULE
save_model_name = const.XLMR_MULTI_CLASS_MODEL
const.XLMR_MODULE = "tests.plugin.text.classification.test_xlmr"
const.XLMR_MULTI_CLASS_MODEL = "MockClassifier"

directory = "/tmp"
file_path = os.path.join(directory, const.LABELENCODER_FILE)
if os.path.exists(file_path):
Expand Down Expand Up @@ -197,6 +215,8 @@ def test_invalid_operations():

if os.path.exists(file_path):
os.remove(file_path)
const.XLMR_MODULE = save_module_name
const.XLMR_MULTI_CLASS_MODEL = save_model_name


@pytest.mark.parametrize("payload", load_tests("cases", __file__))
Expand Down

0 comments on commit f3fc4aa

Please sign in to comment.