From 1baa5c19dd8229f206067bc88ee4d870928d9363 Mon Sep 17 00:00:00 2001 From: Morten Ledum Date: Mon, 10 Jan 2022 16:08:18 +0100 Subject: [PATCH] Fix crash caused by missing thermostat specification Properly tests provided start_temperature and target_temperature values ensuring they are numbers before checking validity. Fails in a sane way if this is not true. --- hymd/input_parser.py | 48 ++++++++++++++++++++++++++++----- test/test_input_parser.py | 57 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 6 deletions(-) diff --git a/hymd/input_parser.py b/hymd/input_parser.py index 24e9abc6..3ea7dc80 100644 --- a/hymd/input_parser.py +++ b/hymd/input_parser.py @@ -841,13 +841,49 @@ def check_tau(config, comm=MPI.COMM_WORLD): def check_start_and_target_temperature(config, comm=MPI.COMM_WORLD): + """Validate provided starting and target thermostat temperatures + + Assesses the provided temperature target and ensures it is a non-negative + floating point number or :code:`False`. Ensures the starting temperature is + a non-negative floating point number or :code:`False`. + + If the value for either is :code:`None`, the returned configuration object + has the values defaulted to :code:`False` in each case. + + Parameters + ---------- + config : Config + Configuration object. + + Returns + ------- + validated_config : Config + Configuration object with validated :code:`target_temperature` and + :code:`start_temperature`. + """ for t in ("start_temperature", "target_temperature"): - if getattr(config, t) < 0: - warn_str = "t set to negative value, defaulting 0" - setattr(config, t, 0.0) - Logger.rank0.log(logging.WARNING, warn_str) - if comm.Get_rank() == 0: - warnings.warn(warn_str) + if getattr(config, t) is not None: + try: + if getattr(config, t) < 0: + warn_str = ( + f"{t} set to negative value ({getattr(config, t)}), " + f"defaulting to False" + ) + setattr(config, t, False) + Logger.rank0.log(logging.WARNING, warn_str) + if comm.Get_rank() == 0: + warnings.warn(warn_str) + except TypeError as e: + err_str = ( + f"Could not interpret {t} = {repr(getattr(config, t))} as " + f"a number." + ) + raise TypeError(err_str) from e + + if config.start_temperature is None: + config.start_temperature = False + if config.target_temperature is None: + config.target_temperature = False return config diff --git a/test/test_input_parser.py b/test/test_input_parser.py index edd2d078..3656c9a8 100644 --- a/test/test_input_parser.py +++ b/test/test_input_parser.py @@ -18,6 +18,7 @@ check_integrator, check_thermostat_coupling_groups, check_cancel_com_momentum, + check_start_and_target_temperature, ) @@ -481,3 +482,59 @@ def test_input_parser_check_cancel_com_momentum(config_toml, caplog): config_ = check_cancel_com_momentum(config) assert config_.cancel_com_momentum is False caplog.clear() + + +def test_input_parser_check_start_and_target_temperature(config_toml, caplog): + caplog.set_level(logging.INFO) + _, config_toml_str = config_toml + config = parse_config_toml(config_toml_str) + for t in ( + "hello", + MPI.COMM_WORLD, + config_toml, + [1], + ): + config.start_temperature = t + with pytest.raises(TypeError) as recorded_error: + _ = check_start_and_target_temperature(config) + log = caplog.text + assert all([(s in log) for s in ("not interpret", "a number")]) + message = str(recorded_error.value) + assert all([(s in message) for s in ("not interpret", "a number")]) + caplog.clear() + + config.target_temperature = t + with pytest.raises(TypeError) as recorded_error: + _ = check_start_and_target_temperature(config) + log = caplog.text + assert all([(s in log) for s in ("not interpret", "a number")]) + message = str(recorded_error.value) + assert all([(s in message) for s in ("not interpret", "a number")]) + caplog.clear() + + config = parse_config_toml(config_toml_str) + + config.start_temperature = None + assert check_start_and_target_temperature(config).start_temperature is False # noqa: E501 + config.target_temperature = None + assert check_start_and_target_temperature(config).target_temperature is False # noqa: E501 + + with pytest.warns(Warning) as recorded_warning: + config.start_temperature = -6.2985252885781357 + config = check_start_and_target_temperature(config) + assert config.start_temperature is False + message = recorded_warning[0].message.args[0] + log = caplog.text + assert all([(s in message) for s in ("to negative", "defaulting to")]) + assert all([(s in log) for s in ("to negative", "defaulting to")]) + caplog.clear() + + with pytest.warns(Warning) as recorded_warning: + config.target_temperature = -0.000025892857873 + config = check_start_and_target_temperature(config) + assert config.target_temperature is False + message = recorded_warning[0].message.args[0] + log = caplog.text + assert all([(s in message) for s in ("to negative", "defaulting to")]) + assert all([(s in log) for s in ("to negative", "defaulting to")]) + caplog.clear()