-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
64 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,31 @@ | ||
import pytest | ||
import sys | ||
from unittest import mock | ||
from pcntoolkit.normative import get_args | ||
from unittest.mock import patch | ||
from pcntoolkit.normative import evaluate | ||
import numpy as np | ||
|
||
def test_get_args(): | ||
# Define the test arguments | ||
test_args = ['program_name', 'responses', '-f', 'estimate', '-m', 'maskfile', '-c', 'covfile', '-k', '5', '-t', 'testcov', '-r', 'testresp', '-a', 'gpr', '-x', 'configparam'] | ||
def test_evaluate(): | ||
# Create mock data | ||
Y = np.random.randn(10,3) | ||
Yhat = np.random.randn(10,3) | ||
|
||
# Replace sys.argv with the test arguments | ||
with mock.patch.object(sys, 'argv', test_args): | ||
# Call the function with the test arguments | ||
args = get_args() | ||
# Call the function with the mock data | ||
result = evaluate(Y, Yhat) | ||
|
||
# Assert that the function returns the expected results | ||
assert args.responses == 'responses' | ||
assert args.func == 'estimate' | ||
assert args.maskfile == 'maskfile' | ||
assert args.covfile == 'covfile' | ||
assert args.cvfolds == '5' | ||
assert args.testcov == 'testcov' | ||
assert args.testresp == 'testresp' | ||
assert args.alg == 'gpr' | ||
assert args.configparam == 'configparam' | ||
# Assert that the result is a dictionary | ||
assert isinstance(result, dict) | ||
|
||
# Assert that the result contains the expected keys with the correct values | ||
expected_keys = ['RMSE', 'Rho', 'pRho'] | ||
for key in expected_keys: | ||
assert key in result | ||
assert isinstance(result[key], np.ndarray) | ||
assert result[key].shape == (3,) | ||
|
||
# Optionally, assert that the values are close to the expected values | ||
# This will require calculating the expected values manually or using a different implementation for comparison | ||
# For example: | ||
# expected_values = {'RMSE': np.array([0, 0, 0]), 'Rho': np.array([1, 1, 1]), 'pRho': np.array([1, 1, 1])} | ||
# for key in expected_keys: | ||
# np.testing.assert_allclose(result[key], expected_values[key]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import pytest | ||
from pcntoolkit.util.utils import create_poly_basis | ||
import numpy as np | ||
|
||
|
||
@pytest.mark.parametrize("samples, dim, degree, expected_shape", [(10, 3, 2, (10, 6)), (10, 3, 3, (10, 9))]) | ||
def test_create_poly_basis_shape(samples, dim, degree, expected_shape): | ||
# Create mock data | ||
X = np.random.randn(samples,dim) | ||
|
||
# Call the function with the mock data | ||
result = create_poly_basis(X, degree) | ||
|
||
# Assert that the result is a numpy array | ||
assert isinstance(result, np.ndarray) | ||
|
||
# Assert that the result has the correct shape | ||
assert result.shape == expected_shape | ||
|
||
|
||
def test_create_poly_basis(): | ||
X = np.array([[1,2,3], [4, 5, 6]]) | ||
result = create_poly_basis(X, 2) | ||
np.testing.assert_array_equal(result, np.array([[1., 2., 3., 1., 4., 9.], [4., 5., 6., 16., 25., 36.]])) | ||
|
||
|
||
from pcntoolkit.util.utils import create_bspline_basis | ||
|
||
@pytest.mark.parametrize("xmin, xmax, p, nknots", [(1, 3, 3, 4), (2, 5, 6, 9)]) | ||
def test_create_bspline_basis_shape(xmin, xmax, p, nknots): | ||
|
||
# Call the function with the mock data | ||
result = create_bspline_basis(xmin, xmax, p, nknots) | ||
|
||
# Assert that the properties of the result match the input parameters | ||
assert result.p == p | ||
assert result.knot_vector.size == 2*p + nknots | ||
assert result.knot_vector[0] == xmin | ||
assert result.knot_vector[-1] == xmax |