diff --git a/tests/test_common.py b/tests/test_common.py index 088e6442..fb990e7e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -224,3 +224,14 @@ def test_remove_numeric_parentheses(): assert common.remove_numeric_parentheses('C2H5(547)H') == 'C2H5(547)H' assert common.remove_numeric_parentheses('HNO(T)(21)') == 'HNO(T)' + +def test_numpy_to_list(): + """Test the numpy_to_list() function""" + import numpy as np + assert common.numpy_to_list(np.array([1, 2, 3])) == [1, 2, 3] + assert common.numpy_to_list(np.array([1.0, 2.0, 3.0])) == [1.0, 2.0, 3.0] + assert common.numpy_to_list(np.array([1.0, 2.0, 3.0], dtype=np.float32)) == [1.0, 2.0, 3.0] + assert common.numpy_to_list(np.array([1.0, 2.0, 3.0], dtype=np.float64)) == [1.0, 2.0, 3.0] + assert common.numpy_to_list(np.array([1, 2, 3], dtype=np.int32)) == [1, 2, 3] + assert common.numpy_to_list(np.array([1, 2, 3], dtype=np.int64)) == [1, 2, 3] + assert common.numpy_to_list(np.array([1, 2, 3], dtype=np.bool_)) == [True, True, True]