Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further refine the unit tests of ensemble_md #49

Merged
merged 4 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/simulations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ include parameters for data analysis here.
the parameters :code:`gro` and :code:`top`, only one MDP file can be specified for the parameter :code:`mdp`. If you wish to use
different parameters for different replicas, please use the parameter :code:`mdp_args`.
- :code:`modify_coords`: (Optional, Default: :code:`None`)
The name of the Python module (without including the :code:`.py` extension) for modifying the output coordinates of the swapping replicas
The file path to the Python module for modifying the output coordinates of the swapping replicas
before the coordinate exchange, which is generally required in REXEE simulations for multiple serial mutations.
For the CLI :code:`run_REXEE` to work, here is the predefined contract for the module/function based on the assumptions :code:`run_REXEE` makes.
Modules/functions not obeying the contract are unlikely to work.
Expand Down
21 changes: 15 additions & 6 deletions ensemble_md/replica_exchange_EE.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,14 @@ def set_params(self, analysis):
raise ParameterError(f"The parameter '{i}' should be a dictionary.")

if self.add_swappables is not None:
if not isinstance(self.add_swappables, list):
raise ParameterError("The parameter 'add_swappables' should be a nested list.")
for sublist in self.add_swappables:
if not isinstance(sublist, list):
raise ParameterError("The parameter 'add_swappables' should be a nested list.")
for item in sublist:
if not isinstance(item, int) or item < 0:
raise ParameterError("Each number specified in 'add_swappables' should be a non-negative integer.") # noqa: E501
if [len(i) for i in self.add_swappables] != [2] * len(self.add_swappables):
raise ParameterError("Each sublist in 'add_swappables' should contain two integers.")

if self.mdp_args is not None:
# Note that mdp_args is a dictionary including MDP parameters DIFFERING across replicas.
Expand Down Expand Up @@ -441,9 +441,17 @@ def set_params(self, analysis):

# 7-12. External module for coordinate modification
if self.modify_coords is not None:
sys.path.append(os.getcwd())
module = importlib.import_module(self.modify_coords)
self.modify_coords_fn = getattr(module, self.modify_coords)
module_file = os.path.basename(self.modify_coords)
module_dir = os.path.dirname(self.modify_coords)
if module_dir not in sys.path:
sys.path.append(module_dir) # so that the module can be imported
module_name = os.path.splitext(module_file)[0]
module = importlib.import_module(module_name)
if not hasattr(module, module_name):
err_msg = f'The module for coordinate manipulation (specified through the parameter) must have a function with the same name as the module, i.e., {module_name}.' # noqa: E501
raise ParameterError(err_msg)
else:
self.modify_coords_fn = getattr(module, module_name)
else:
self.modify_coords_fn = None

Expand Down Expand Up @@ -509,6 +517,7 @@ def print_params(self, params_analysis=False):
print(f"Additionally defined swappable states: {self.add_swappables}")
print(f"Additional grompp arguments: {self.grompp_args}")
print(f"Additional runtime arguments: {self.runtime_args}")
print(f"External modules for coordinate manipulation: {self.modify_coords}")
# print(f"Number of attempted swaps in one exchange interval: {self.n_ex}")
if self.mdp_args is not None and len(self.mdp_args.keys()) > 1:
print("MDP parameters differing across replicas:")
Expand Down Expand Up @@ -935,7 +944,6 @@ def get_swapping_pattern(self, dhdl_files, states):
print('No swap is proposed because there is no swappable pair at all.')
break
else:
self.n_swap_attempts += 1
if self.proposal == 'exhaustive':
n_ex_exhaustive += 1

Expand All @@ -946,6 +954,7 @@ def get_swapping_pattern(self, dhdl_files, states):
print('No swap is proposed because there is no swappable pair at all.')
break # no need to re-identify swappable pairs and draw new samples
else:
self.n_swap_attempts += 1
if self.verbose is True and self.proposal != 'exhaustive':
print(f'A swap ({i + 1}/{n_ex}) is proposed between the configurations of Simulation {swap[0]} (state {states[swap[0]]}) and Simulation {swap[1]} (state {states[swap[1]]}) ...') # noqa: E501

Expand Down
43 changes: 40 additions & 3 deletions ensemble_md/tests/test_analyze_free_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
@patch('ensemble_md.analysis.analyze_free_energy.extract_dHdl')
@patch('ensemble_md.analysis.analyze_free_energy.detect_equilibration')
@patch('ensemble_md.analysis.analyze_free_energy.subsample_correlated_data')
def test_preprocess_data(mock_corr, mock_equil, mock_extract_dHdl, mock_extract_u_nk, mock_subsampling, mock_alchemlyb, capfd): # noqa: E501
def test_preprocess_data(mock_corr, mock_equil, mock_extract_dhdl, mock_extract_u_nk, mock_subsampling, mock_alchemlyb, capfd): # noqa: E501
mock_data, mock_data_series = MagicMock(), MagicMock()
mock_alchemlyb.concat.return_value = mock_data
mock_subsampling.u_nk2series.return_value = mock_data_series
mock_subsampling._prepare_input.return_value = (mock_data, mock_data_series)
mock_equil.return_value = (10, 5, 50) # t, g, Neff_max
mock_equil.return_value = (10, 5, 18) # t, g, Neff_max
mock_data_series.__len__.return_value = 100 # For one of the print statements

# Set slicing to return different mock objects based on input
Expand All @@ -53,6 +53,7 @@ def generic_list_slicing(key):
mock_data_series.__getitem__.side_effect = slicing_side_effect # so that we can use mock_data_series[t:]
mock_data_series_equil = mock_data_series[10:] # Mock the equilibrated data series, given t=10

# Case 1: data_type = u_nk
files = [[f'ensemble_md/tests/data/dhdl/simulation_example/sim_{i}/iteration_{j}/dhdl.xvg' for j in range(3)] for i in range(4)] # noqa: E501
results = analyze_free_energy.preprocess_data(files, 300, 'u_nk')

Expand All @@ -74,13 +75,48 @@ def generic_list_slicing(key):
assert ' Adopted spacing: 1' in out
assert ' 10.0% of the u_nk data was in the equilibrium region and therfore discarded.' in out # noqa: E501
assert ' Statistical inefficiency of u_nk: 5.0' in out
assert ' Number of effective samples: 50' in out
assert ' Number of effective samples: 18' in out
assert mock_corr.call_args_list[i] == call(mock_data_series_equil, g=5)

assert len(results[0]) == 4
assert results[1] == [10, 10, 10, 10]
assert results[2] == [5, 5, 5, 5]

# Case 2: data_type = dHdl
mock_alchemlyb.concat.reset_mock()
mock_subsampling._prepare_input.reset_mock()
mock_subsampling.slicing.reset_mock()
mock_equil.reset_mock()

mock_subsampling.dhdl2series.return_value = mock_data_series
mock_subsampling._prepare_input.return_value = (mock_data, mock_data_series)
mock_data_series.__len__.return_value = 200
mock_data_series.values.__len__.return_value = 200

results = analyze_free_energy.preprocess_data(files, 300, 'dhdl', t=10, g=5)
out, err = capfd.readouterr()

for i in range(4):
for j in range(3):
assert mock_extract_dhdl.call_args_list[i * 3 + j] == call(files[i][j], T=300)
assert mock_subsampling._prepare_input.call_args_list[i] == call(mock_data, mock_data_series, drop_duplicates=True, sort=True) # noqa: E501
assert mock_subsampling.slicing.call_args_list[2 * i] == call(mock_data, step=1)
assert mock_subsampling.slicing.call_args_list[2 * i + 1] == call(mock_data_series, step=1)
assert 'Subsampling and decorrelating the concatenated dhdl data ...' in out
assert ' Adopted spacing: 1' in out
assert ' 5.0% of the dhdl data was in the equilibrium region and therfore discarded.' in out # noqa: E501
assert ' Statistical inefficiency of dhdl: 5.0' in out
assert ' Number of effective samples: 38' in out
assert mock_corr.call_args_list[i] == call(mock_data_series_equil, g=5)

assert len(results[0]) == 4
assert results[1] == []
assert results[2] == []

# Case 3: Invalid data_type
with pytest.raises(ValueError, match="Invalid data_type. Expected 'u_nk' or 'dhdl'."):
analyze_free_energy.preprocess_data(files, 300, 'xyz')


@pytest.mark.parametrize("method, expected_estimator", [
("MBAR", "MBAR estimator"),
Expand Down Expand Up @@ -152,6 +188,7 @@ def test_combine_df_adjacent():
state_ranges = [[0, 1, 2], [1, 2, 3]]

# Test 1: df_err_adjacent is None (in which case err_type is ignored)
# Note that this test would lead to two harmless RuntimeWarnings due to calculations like np.std([1], ddof=1), which return NaN # noqa: E501
results = analyze_free_energy._combine_df_adjacent(df_adjacent, state_ranges, None, "propagate")
assert results[0] == [1, 3.5, 6]
assert math.isnan(results[1][0])
Expand Down
16 changes: 16 additions & 0 deletions ensemble_md/tests/test_analyze_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def test_calc_transmtx():
assert B3 is None
assert C3 is None

# Case 4: Invalid simulation type
with pytest.raises(ValueError, match='Invalid simulation type test.'):
analyze_matrix.calc_transmtx(os.path.join(input_path, 'log/EXE.log'), simulation_type='test')


def test_calc_equil_prob(capfd):
# Case 1: Right stochastic
Expand All @@ -87,6 +91,18 @@ def test_calc_equil_prob(capfd):
assert 'The input transition matrix is neither right nor left stochastic' in out


def test_calc_t_relax():
    """Test analyze_matrix.calc_t_relax with and without spectral_gap_err."""
    # Floating-point results are compared with a tolerance (pytest.approx) rather
    # than exact equality, so the test does not depend on the exact order of the
    # arithmetic operations performed inside calc_t_relax.

    # Case 1: spectral_gap_err is specified
    results = analyze_matrix.calc_t_relax(0.5, 0.1, 0.1)
    assert results[0] == pytest.approx(0.2)
    assert results[1] == pytest.approx(0.1 * 0.1 / 0.5 ** 2)

    # Case 2: spectral_gap_err is not specified
    results = analyze_matrix.calc_t_relax(0.5, 0.1)
    assert results[0] == pytest.approx(0.2)
    assert results[1] is None


def test_calc_spectral_gap(capfd):
# Case 1 (sanity check): doubly stochastic
mtx = np.array([[0.5, 0.5], [0.5, 0.5]])
Expand Down
60 changes: 51 additions & 9 deletions ensemble_md/tests/test_analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,38 @@ def test_stitch_time_series_for_sim():
assert os.path.exists('state_trajs_for_sim.npy')
os.remove('state_trajs_for_sim.npy')

# Test 2: Test for discontinuous time series
# Test 2: The case where dhdl is False
# Here we again use dhdl.xvg files but use dhdl=False with col_idx=1, which corresponds to the state index
trajs = analyze_traj.stitch_time_series_for_sim(files, dhdl=False, col_idx=1)

trajs[0] == [
0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1,
1, 1, 1, 2, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 1, 0, 1, 1,
1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 2, 1, 1, 0, 0, 1, 0, 1, 0, 1
]

trajs[1] == [
1, 1, 2, 3, 3, 3, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 1, 1, 1, 1,
2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 1, 1, 1, 1, 2, 3, 3, 3, 2, 2,
1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 2, 1, 1, 0, 0, 1, 0, 1, 0, 1
]

trajs[2] == [
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2, 3, 3,
3, 3, 3, 3, 3, 3, 3, 2, 3, 2, 3, 3, 3, 2, 2, 3, 4, 3, 3, 2,
3, 3, 2, 2, 2, 3, 4, 3, 4, 4, 5, 5, 5, 5, 4, 3, 4, 3, 3, 4, 4
]

trajs[3] == [
3, 3, 3, 3, 3, 3, 3, 5, 4, 4, 5, 4, 4, 5, 4, 5, 5, 5, 4, 5,
4, 4, 5, 4, 5, 5, 4, 5, 5, 5, 4, 5, 5, 4, 5, 4, 5, 4, 5, 5,
6, 6, 6, 5, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 5, 6, 6, 6, 7, 6, 7
]

assert os.path.exists('state_trajs_for_sim.npy')
os.remove('state_trajs_for_sim.npy')

# Test 3: Test for discontinuous time series
# Here, for sim_2, we exclude the last 5 lines for the dhdl.xvg file in iteration_1 to create a gap
save_and_exclude(f'{folder}/sim_2/iteration_1/dhdl.xvg', 5)
os.rename(f'{folder}/sim_2/iteration_1/dhdl.xvg', f'{folder}/sim_2/iteration_1/dhdl_temp.xvg')
Expand Down Expand Up @@ -260,6 +291,7 @@ def test_traj2transmtx():
np.testing.assert_array_equal(analyze_traj.traj2transmtx(traj, N, normalize=False), array)

# Case 2: normalize=True
# This test would lead to a harmless RuntimeWarning due to 0/0 in the last row.
array = np.array([
[0, 0.5, 0, 0.5],
[0.5, 0, 0.5, 0],
Expand Down Expand Up @@ -624,16 +656,26 @@ def test_plot_transit_time(mock_plt):
assert mock_plt.ylabel.call_args_list[0] == call('Average transit time from states 0 to k (step)')
assert mock_plt.ylabel.call_args_list[1] == call('Average transit time from states k to 0 (step)')
assert mock_plt.ylabel.call_args_list[2] == call('Average round-trip time (step)')
assert [mock_plt.savefig.call_args_list[i][0][0] for i in range(3)] == [
'./t_0k.png',
'./t_k0.png',
'./t_roundtrip.png',
]

# Case 2: dt = 0.2 ps, fig_prefix = 'test', here we test the return values and the names of the saved figures
mock_plt.reset_mock()
t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N, dt=0.2)
t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N, dt=0.2, fig_prefix='test')
t_1_, t_2_, t_3_ = [[1.0, 1.4], [0.8, 0.8]], [[0.8, 0.6], [1.2]], [[1.8, 2.0], [2.0]]
for i in range(2):
np.testing.assert_array_almost_equal(t_1[i], t_1_[i])
np.testing.assert_array_almost_equal(t_2[i], t_2_[i])
np.testing.assert_array_almost_equal(t_3[i], t_3_[i])
assert u == 'ps'
assert [mock_plt.savefig.call_args_list[i][0][0] for i in range(3)] == [
'./test_t_0k.png',
'./test_t_k0.png',
'./test_t_roundtrip.png',
]

# Case 3: dt = 200 ps, long trajs
mock_plt.reset_mock()
Expand Down Expand Up @@ -661,7 +703,7 @@ def test_plot_transit_time(mock_plt):
# Case 5: More than 100 round trips so that a histogram is plotted
mock_plt.reset_mock()
trajs = np.array([[0, 1, 2, 3, 2] * 20000, [0, 1, 3, 2, 1] * 20000])
t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N)
t_1, t_2, t_3, u = analyze_traj.plot_transit_time(trajs, N, fig_prefix='test')

assert t_1 == [[3] * 20000, [2] * 20000]
assert t_2 == [[2] * 19999, [3] * 19999]
Expand Down Expand Up @@ -715,12 +757,12 @@ def test_plot_transit_time(mock_plt):
]

assert [mock_plt.savefig.call_args_list[i][0][0] for i in range(6)] == [
'./t_0k.png',
'./hist_t_0k.png',
'./t_k0.png',
'./hist_t_k0.png',
'./t_roundtrip.png',
'./hist_t_roundtrip.png'
'./test_t_0k.png',
'./test_hist_t_0k.png',
'./test_t_k0.png',
'./test_hist_t_k0.png',
'./test_t_roundtrip.png',
'./test_hist_t_roundtrip.png'
]


Expand Down
14 changes: 14 additions & 0 deletions ensemble_md/tests/test_gmx_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,20 @@ def test_write(self):
mdp = gmx_parser.MDP("ensemble_md/tests/data/expanded.mdp")
mdp.write('test_1.mdp', skipempty=False)
mdp.write('test_2.mdp', skipempty=True)

assert os.path.isfile('test_1.mdp')
assert os.path.isfile('test_2.mdp')

mdp = gmx_parser.MDP('test_1.mdp')
mdp.write(skipempty=True) # This should overwrite the file

# Check if the files are the same
with open('test_1.mdp', 'r') as f:
lines_1 = f.readlines()
with open('test_2.mdp', 'r') as f:
lines_2 = f.readlines()
assert lines_1 == lines_2

os.remove('test_1.mdp')
os.remove('test_2.mdp')

Expand Down
Loading
Loading