From df6f186ccd3cbd8f42b6a7851965270889359364 Mon Sep 17 00:00:00 2001 From: Stijn de Boer Date: Tue, 12 Dec 2023 08:43:10 +0100 Subject: [PATCH] Add some basic tests for utils --- pytest_tests/test_normative.py | 43 ++++++++++++++++++++-------------- pytest_tests/test_utils.py | 39 ++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 18 deletions(-) create mode 100644 pytest_tests/test_utils.py diff --git a/pytest_tests/test_normative.py b/pytest_tests/test_normative.py index 2d50981e..8de9479d 100644 --- a/pytest_tests/test_normative.py +++ b/pytest_tests/test_normative.py @@ -1,24 +1,31 @@ import pytest import sys from unittest import mock -from pcntoolkit.normative import get_args +from unittest.mock import patch +from pcntoolkit.normative import evaluate +import numpy as np -def test_get_args(): - # Define the test arguments - test_args = ['program_name', 'responses', '-f', 'estimate', '-m', 'maskfile', '-c', 'covfile', '-k', '5', '-t', 'testcov', '-r', 'testresp', '-a', 'gpr', '-x', 'configparam'] +def test_evaluate(): + # Create mock data + Y = np.random.randn(10,3) + Yhat = np.random.randn(10,3) - # Replace sys.argv with the test arguments - with mock.patch.object(sys, 'argv', test_args): - # Call the function with the test arguments - args = get_args() + # Call the function with the mock data + result = evaluate(Y, Yhat) - # Assert that the function returns the expected results - assert args.responses == 'responses' - assert args.func == 'estimate' - assert args.maskfile == 'maskfile' - assert args.covfile == 'covfile' - assert args.cvfolds == '5' - assert args.testcov == 'testcov' - assert args.testresp == 'testresp' - assert args.alg == 'gpr' - assert args.configparam == 'configparam' \ No newline at end of file + # Assert that the result is a dictionary + assert isinstance(result, dict) + + # Assert that the result contains the expected keys with the correct values + expected_keys = ['RMSE', 'Rho', 'pRho'] + for key in expected_keys: + assert key in result + assert isinstance(result[key], np.ndarray) + assert result[key].shape == (3,) + + # Optionally, assert that the values are close to the expected values + # This will require calculating the expected values manually or using a different implementation for comparison + # For example: + # expected_values = {'RMSE': np.array([0, 0, 0]), 'Rho': np.array([1, 1, 1]), 'pRho': np.array([1, 1, 1])} + # for key in expected_keys: + # np.testing.assert_allclose(result[key], expected_values[key]) \ No newline at end of file diff --git a/pytest_tests/test_utils.py b/pytest_tests/test_utils.py new file mode 100644 index 00000000..7d890c1f --- /dev/null +++ b/pytest_tests/test_utils.py @@ -0,0 +1,39 @@ +import pytest +from pcntoolkit.util.utils import create_poly_basis +import numpy as np + + +@pytest.mark.parametrize("samples, dim, degree, expected_shape", [(10, 3, 2, (10, 6)), (10, 3, 3, (10, 9))]) +def test_create_poly_basis_shape(samples, dim, degree, expected_shape): + # Create mock data + X = np.random.randn(samples,dim) + + # Call the function with the mock data + result = create_poly_basis(X, degree) + + # Assert that the result is a numpy array + assert isinstance(result, np.ndarray) + + # Assert that the result has the correct shape + assert result.shape == expected_shape + + +def test_create_poly_basis(): + X = np.array([[1,2,3], [4, 5, 6]]) + result = create_poly_basis(X, 2) + np.testing.assert_array_equal(result, np.array([[1., 2., 3., 1., 4., 9.], [4., 5., 6., 16., 25., 36.]])) + + +from pcntoolkit.util.utils import create_bspline_basis + +@pytest.mark.parametrize("xmin, xmax, p, nknots", [(1, 3, 3, 4), (2, 5, 6, 9)]) +def test_create_bspline_basis_shape(xmin, xmax, p, nknots): + + # Call the function with the mock data + result = create_bspline_basis(xmin, xmax, p, nknots) + + # Assert that the properties of the result match the input parameters + assert result.p == p + assert result.knot_vector.size == 2*p + nknots + assert result.knot_vector[0] == xmin + assert result.knot_vector[-1] == xmax \ No newline at end of file