Skip to content

Commit

Permalink
Add small fixes to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesYang007 committed Nov 18, 2024
1 parent 3622f51 commit 6a9a141
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
18 changes: 12 additions & 6 deletions tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,18 +425,20 @@ def test_naive_cconcatenate(n, ps, dtype, n_threads=2, seed=0):
)
run_naive(X, cX, dtype)

atol = 1e-4 if dtype == np.float32 else 1e-14

# test mean
w = np.random.uniform(0, 1, cX.shape[0]).astype(dtype)
mean = np.empty(cX.shape[1], dtype=dtype)
cX.mean(w, mean)
expected = X.T @ w
assert np.allclose(mean, expected)
assert np.allclose(mean, expected, atol=atol)

# test var
var = np.empty(cX.shape[1], dtype=dtype)
cX.var(mean, w, var)
expected = np.sum((X - mean) ** 2 * w[:, None], axis=0)
assert np.allclose(var, expected)
assert np.allclose(var, expected, atol=atol)


@pytest.mark.filterwarnings("ignore: Detected matrix to be C-contiguous.")
Expand Down Expand Up @@ -482,18 +484,20 @@ def test_naive_convex_relu(n, d, m, gated, storage, dtype, n_threads=2, seed=0):
cX = mod.convex_relu(Z, mask, gated=gated, n_threads=n_threads)
run_naive(X, cX, dtype)

atol = 1e-4 if dtype == np.float32 else 1e-14

# test mean
w = np.random.uniform(0, 1, cX.shape[0]).astype(dtype)
mean = np.empty(cX.shape[1], dtype=dtype)
cX.mean(w, mean)
expected = X.T @ w
assert np.allclose(mean, expected)
assert np.allclose(mean, expected, atol=atol)

# test var
var = np.empty(cX.shape[1], dtype=dtype)
cX.var(mean, w, var)
expected = np.sum((X - mean) ** 2 * w[:, None], axis=0)
assert np.allclose(var, expected)
assert np.allclose(var, expected, atol=atol)


@pytest.mark.filterwarnings("ignore: Detected matrix to be C-contiguous.")
Expand All @@ -511,18 +515,20 @@ def test_naive_dense(n, p, dtype, order, seed=0):
cX = mod.dense(X, method="naive", n_threads=15)
run_naive(X, cX, dtype)

atol = 1e-4 if dtype == np.float32 else 1e-14

# test mean
w = np.random.uniform(0, 1, cX.shape[0]).astype(dtype)
mean = np.empty(cX.shape[1], dtype=dtype)
cX.mean(w, mean)
expected = X.T @ w
assert np.allclose(mean, expected)
assert np.allclose(mean, expected, atol=atol)

# test var
var = np.empty(cX.shape[1], dtype=dtype)
cX.var(mean, w, var)
expected = np.sum((X - mean) ** 2 * w[:, None], axis=0)
assert np.allclose(var, expected)
assert np.allclose(var, expected, atol=atol)


@pytest.mark.filterwarnings("ignore: Detected matrix to be C-contiguous.")
Expand Down
6 changes: 4 additions & 2 deletions tests/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,8 @@ def test_css_cov_greedy(p, k, loss, seed):
state = ad.solver.css_cov(S, k, method="greedy", loss=loss)
actual = state.subset
expected = CSSCov(S, loss).greedy(k)
assert np.all(actual == expected)
# sort first since numerical instability may result in different ordering
assert np.all(np.sort(actual) == np.sort(expected))


@pytest.mark.parametrize("p, k", [
Expand All @@ -1235,4 +1236,5 @@ def test_css_cov_swapping(p, k, loss, seed):
state = ad.solver.css_cov(S, k, method="swapping", loss=loss)
actual = state.subset
expected = CSSCov(S, loss).swapping(actual)
assert np.all(actual == expected)
# sort first since numerical instability may result in different ordering
assert np.all(np.sort(actual) == np.sort(expected))

0 comments on commit 6a9a141

Please sign in to comment.