diff --git a/mammoth/tests/test_models.py b/mammoth/tests/test_models.py index 089ca796..aea9355f 100644 --- a/mammoth/tests/test_models.py +++ b/mammoth/tests/test_models.py @@ -14,7 +14,7 @@ mammoth.opts._add_train_general_opts(parser) # -data option is required, but not used in this test, so dummy. -opts = parser.parse_known_args(['-tasks', 'dummy', '-node_rank', '0', '-model_dim', '500'])[0] +opts = parser.parse_known_args(['-tasks', 'dummy', '-node_rank', '0', '-model_dim', '500'], strict=False)[0] class TestModel(unittest.TestCase): diff --git a/mammoth/utils/parse.py b/mammoth/utils/parse.py index 54056cf9..79a000f7 100644 --- a/mammoth/utils/parse.py +++ b/mammoth/utils/parse.py @@ -269,6 +269,12 @@ def defaults(cls, *args): defaults = dummy_parser.parse_known_args([])[0] return defaults + def parse_known_args(self, *args, strict=True, **kwargs): + opts, unknown = super().parse_known_args(*args, **kwargs) + if strict and unknown: + raise ValueError(f'unknown arguments provided:\n{unknown}') + return opts, unknown + @classmethod def update_model_opts(cls, model_opts): cls._validate_adapters(model_opts)