Skip to content

Commit

Permalink
Add some basic tests for utils
Browse files Browse the repository at this point in the history
  • Loading branch information
AuguB committed Dec 12, 2023
1 parent 5234733 commit df6f186
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 18 deletions.
43 changes: 25 additions & 18 deletions pytest_tests/test_normative.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
import pytest
import sys
from unittest import mock
from pcntoolkit.normative import get_args
from unittest.mock import patch
from pcntoolkit.normative import evaluate
import numpy as np

def test_get_args():
# Define the test arguments
test_args = ['program_name', 'responses', '-f', 'estimate', '-m', 'maskfile', '-c', 'covfile', '-k', '5', '-t', 'testcov', '-r', 'testresp', '-a', 'gpr', '-x', 'configparam']
def test_evaluate():
# Create mock data
Y = np.random.randn(10,3)
Yhat = np.random.randn(10,3)

# Replace sys.argv with the test arguments
with mock.patch.object(sys, 'argv', test_args):
# Call the function with the test arguments
args = get_args()
# Call the function with the mock data
result = evaluate(Y, Yhat)

# Assert that the function returns the expected results
assert args.responses == 'responses'
assert args.func == 'estimate'
assert args.maskfile == 'maskfile'
assert args.covfile == 'covfile'
assert args.cvfolds == '5'
assert args.testcov == 'testcov'
assert args.testresp == 'testresp'
assert args.alg == 'gpr'
assert args.configparam == 'configparam'
# Assert that the result is a dictionary
assert isinstance(result, dict)

# Assert that the result contains the expected keys with the correct values
expected_keys = ['RMSE', 'Rho', 'pRho']
for key in expected_keys:
assert key in result
assert isinstance(result[key], np.ndarray)
assert result[key].shape == (3,)

# Optionally, assert that the values are close to the expected values
# This will require calculating the expected values manually or using a different implementation for comparison
# For example:
# expected_values = {'RMSE': np.array([0, 0, 0]), 'Rho': np.array([1, 1, 1]), 'pRho': np.array([1, 1, 1])}
# for key in expected_keys:
# np.testing.assert_allclose(result[key], expected_values[key])
39 changes: 39 additions & 0 deletions pytest_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
from pcntoolkit.util.utils import create_poly_basis
import numpy as np


@pytest.mark.parametrize("samples, dim, degree, expected_shape", [(10, 3, 2, (10, 6)), (10, 3, 3, (10, 9))])
def test_create_poly_basis_shape(samples, dim, degree, expected_shape):
# Create mock data
X = np.random.randn(samples,dim)

# Call the function with the mock data
result = create_poly_basis(X, degree)

# Assert that the result is a numpy array
assert isinstance(result, np.ndarray)

# Assert that the result has the correct shape
assert result.shape == expected_shape


def test_create_poly_basis():
X = np.array([[1,2,3], [4, 5, 6]])
result = create_poly_basis(X, 2)
np.testing.assert_array_equal(result, np.array([[1., 2., 3., 1., 4., 9.], [4., 5., 6., 16., 25., 36.]]))


from pcntoolkit.util.utils import create_bspline_basis

@pytest.mark.parametrize("xmin, xmax, p, nknots", [(1, 3, 3, 4), (2, 5, 6, 9)])
def test_create_bspline_basis_shape(xmin, xmax, p, nknots):

# Call the function with the mock data
result = create_bspline_basis(xmin, xmax, p, nknots)

# Assert that the properties of the result match the input parameters
assert result.p == p
assert result.knot_vector.size == 2*p + nknots
assert result.knot_vector[0] == xmin
assert result.knot_vector[-1] == xmax

0 comments on commit df6f186

Please sign in to comment.