
Tests are too slow for libraries with eager-mode dispatch overhead #197

Open
jakevdp opened this issue Nov 9, 2023 · 13 comments

Comments

@jakevdp
Contributor

jakevdp commented Nov 9, 2023

For example this job: https://github.com/google/jax/actions/runs/6804780902/job/18502996983, which is associated with jax-ml/jax#16099 took 4.5 hours to run about 1000 tests.

========== 567 failed, 439 passed, 116 skipped in 16419.47s (4:33:39) ==========

This is slow enough to be virtually unusable.

Is there anything I can do to speed up the tests during development of the JAX array API?

Edit: as mentioned below, this is due to the frequent use of test patterns that check every value in each array via indexing, combined with the fact that in eager execution each of these indexing operations has a small dispatch overhead, which leads to slow tests when arrays are large.

@asmeurer
Member

asmeurer commented Nov 9, 2023

They shouldn't be that slow. The NumPy tests finish in a fraction of that time: https://github.com/data-apis/array-api-compat/actions/runs/6565885145/job/17835396340. It might be worth investigating what is going on.

Although even there, the NumPy tests took an hour to run, which seems like a lot. We should investigate if something slowed down recently.

Generally, though, you can speed things up by lowering the hypothesis --max-examples. By default it is 100, but something like 50 should make the tests run in roughly half the time.
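For example, assuming the suite is run with pytest from the repo root (`--max-examples` is a pytest option the array-api-tests suite itself registers):

```shell
# Halve the number of hypothesis-generated examples per test
# (the suite's default is 100), roughly halving the runtime.
pytest array_api_tests/ --max-examples=50
```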

@asmeurer
Member

asmeurer commented Nov 9, 2023

I just ran ARRAY_API_TESTS_MODULE=numpy.array_api pytest array_api_tests/ locally and it took 12 minutes. So that's about the time the tests should run in (there were a few failures, which speeds up the runtime a bit, but in general it shouldn't take more than 20 minutes).

@asmeurer
Member

asmeurer commented Nov 9, 2023

Can you run something like pytest --durations=10 to print the 10 slowest tests? That would help to pin down what is going on.

@jakevdp
Contributor Author

jakevdp commented Nov 9, 2023

I think the slowness comes from test patterns that look like this:

for i in range(n_rows):
    for j in range(_n_cols):
        f_indexed_out = f"out[{i}, {j}]={out[i, j]}"
        if j - i == kw.get("k", 0):
            assert out[i, j] == 1, f"{f_indexed_out}, should be 1 {f_func}"
        else:
            assert out[i, j] == 0, f"{f_indexed_out}, should be 0 {f_func}"

In JAX, each operation outside the context of JIT compilation has a small amount of overhead related to dispatch and device placement for the output, so running $\mathcal{O}[N^2]$ indexing operations in a loop will accumulate that overhead and be very slow.
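The effect can be sketched with NumPy standing in for an eager array library (the sizes here are illustrative; NumPy only pays Python-level dispatch cost per access, while JAX additionally pays device-dispatch and placement overhead, so the gap is far larger there):

```python
import time
import numpy as np

# Build an "eye-like" output, as test_eye does (size is illustrative).
n = 300
out = np.eye(n, dtype=np.int64)

# Per-element check: O(n^2) separate __getitem__ dispatches.
start = time.perf_counter()
for i in range(n):
    for j in range(n):
        assert out[i, j] == (1 if i == j else 0)
loop_time = time.perf_counter() - start

# Vectorized check: one comparison kernel plus one reduction.
start = time.perf_counter()
expected = np.eye(n, dtype=np.int64)
assert np.all(out == expected)
vec_time = time.perf_counter() - start

print(f"loop: {loop_time:.4f}s  vectorized: {vec_time:.4f}s")
```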

@asmeurer
Member

asmeurer commented Nov 9, 2023

There's a variable that limits the max array size. The default is 10000, but we should make it configurable. Lowering it to 1000 or so for JAX would probably fix this.

@jakevdp
Contributor Author

jakevdp commented Nov 9, 2023

I pasted the slowest 20 tests below.

For the most part I think it comes down to the slow repeated indexing I mentioned above: every one of these that I've checked indexes each element of the output to check it against a reference implementation.

54.85s call     array_api_tests/test_creation_functions.py::test_linspace
35.72s call     array_api_tests/test_creation_functions.py::test_eye
25.87s call     array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x, s)]
21.24s call     array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)]
20.68s call     array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)]
20.64s call     array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_and[bitwise_and(x1, x2)]
19.93s call     array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x, s)]
19.88s call     array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[not_equal(x1, x2)]
19.72s call     array_api_tests/test_operators_and_elementwise_functions.py::test_less_equal[less_equal(x1, x2)]
19.59s call     array_api_tests/test_operators_and_elementwise_functions.py::test_less[less(x1, x2)]
18.77s call     array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_or[bitwise_or(x1, x2)]
18.14s call     array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)]
17.24s call     array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[bitwise_left_shift(x1, x2)]
16.71s call     array_api_tests/test_operators_and_elementwise_functions.py::test_divide[divide(x1, x2)]
14.99s call     array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[multiply(x1, x2)]
13.70s call     array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_xor[bitwise_xor(x1, x2)]
13.69s call     array_api_tests/test_operators_and_elementwise_functions.py::test_logical_or
13.46s call     array_api_tests/test_operators_and_elementwise_functions.py::test_atan2
13.00s call     array_api_tests/test_operators_and_elementwise_functions.py::test_add[add(x1, x2)]

@jakevdp
Contributor Author

jakevdp commented Nov 10, 2023

Would you be open to changing logic that looks like this:

for i in range(n_rows):
    for j in range(_n_cols):
        f_indexed_out = f"out[{i}, {j}]={out[i, j]}"
        if j - i == kw.get("k", 0):
            assert out[i, j] == 1, f"{f_indexed_out}, should be 1 {f_func}"
        else:
            assert out[i, j] == 0, f"{f_indexed_out}, should be 0 {f_func}"

to something more like this?

k = kw.get("k", 0)
expected = [[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)]
assert xp.all(xp.asarray(expected) == out)

It does depend on asarray and all working properly, but the previous approach depends on indexing working properly. You lose the granularity of the error, but that could be addressed using where when generating the error message.
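A minimal sketch of that where-style error reporting, with NumPy standing in for `xp` (the shapes, `k`, and `out` here are illustrative, not the test suite's actual helpers):

```python
import numpy as np

# Illustrative stand-ins for the values the eye test works with.
n_rows, n_cols, k = 4, 5, 1
out = np.eye(n_rows, n_cols, k=k, dtype=np.int64)

# Build the full reference on the host in one go.
expected = np.asarray(
    [[1 if j - i == k else 0 for j in range(n_cols)] for i in range(n_rows)]
)

# Single vectorized comparison; recover per-element detail only on failure.
mismatch = expected != out
if np.any(mismatch):
    i, j = np.argwhere(mismatch)[0]
    raise AssertionError(f"out[{i}, {j}]={out[i, j]}, should be {expected[i, j]}")
```

This keeps the happy path to one comparison and one reduction, and only pays per-element indexing cost when the test is already failing.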

Rewriting tests this way would make the test suite usable for JAX and other libraries that have non-negligible dispatch overhead in eager mode.

I would be happy to prepare a PR if you think this is the right direction for the testing suite.

@asmeurer
Member

That would work for the creation functions like eye, but how would you change the elementwise function tests? For example, test_bitwise_left_shift (one of your slowest tests):

def test_bitwise_left_shift(ctx, data):
    left = data.draw(ctx.left_strat, label=ctx.left_sym)
    right = data.draw(ctx.right_strat, label=ctx.right_sym)
    if ctx.right_is_scalar:
        assume(right >= 0)
    else:
        assume(not xp.any(ah.isnegative(right)))
    res = ctx.func(left, right)
    binary_param_assert_dtype(ctx, left, right, res)
    binary_param_assert_shape(ctx, left, right, res)
    nbits = res.dtype
    binary_param_assert_against_refimpl(
        ctx, left, right, res, "<<", lambda l, r: l << r if r < nbits else 0
    )

The loop is done in this helper

for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):

You have to loop through the input arrays to generate the exact outputs. Maybe there's some way to extract the original input array list from hypothesis (assuming it didn't go through any further transformations after being converted to an array). @honno, thoughts?

Anyway, I think changing it to be like that for the creation functions is fine.

@asmeurer
Member

Anyway, I think the best solution for you is to make MAX_ARRAY_SIZE configurable. I expect 1000 (or even 100) would be fine for most test cases, but would drop the runtime of those tests down correspondingly.

@jakevdp
Contributor Author

jakevdp commented Nov 10, 2023

Another option would be to use dlpack to export the full array to numpy or some other standard, where eager-mode indexing is not problematic.
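A sketch of that approach, with NumPy playing both producer and consumer via the DLPack protocol (`x` is illustrative; `np.from_dlpack` requires NumPy 1.22+, and the producing library must implement `__dlpack__`):

```python
import numpy as np

# Pretend `x` came from an eager library with dispatch overhead (e.g. JAX).
x = np.arange(12).reshape(3, 4)

# One bulk, zero-copy-where-possible transfer via DLPack.
host = np.from_dlpack(x)

# Per-element checks against a reference are now cheap NumPy indexing.
for i in range(host.shape[0]):
    for j in range(host.shape[1]):
        assert host[i, j] == 4 * i + j
```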

@jakevdp jakevdp changed the title Tests are too slow... Tests are too slow for libraries with eager-mode dispatch overhead Nov 10, 2023
@rgommers
Member

I did a bit of profiling and testing to figure out what's going on. The result of:

$ py-spy record -o profile.svg -- pytest array_api_tests/

is this:

[py-spy flame graph: profile.svg]

That takes about 15 minutes, and ~75% of the time is spent in test_special_cases.py. Special cases are really not of interest to start with when developing an array API implementation, so the next step was to add:

--ignore=array_api_tests/test_special_cases.py

That brings it down to ~4 minutes. With NumPy 1.26.0, there are still 4 failures when running on numpy.array_api. These 4 failures can also be reproduced with

--max-examples=1

and that runs in about 3 seconds. Adding back the special cases tests makes it run in 18 seconds with --max-examples=1.
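Put together, a quick smoke run during development looks like this (run from the array-api-tests repo root):

```shell
# Skip the slow special-cases file and cap hypothesis at one
# example per test; runs in seconds rather than minutes.
pytest array_api_tests/ \
    --ignore=array_api_tests/test_special_cases.py \
    --max-examples=1
```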

Right now we are in the position that we need this 3 second run to pass on libraries. That is far more important than anything else; there are, as of now, zero implementations that actually pass 100%. The array_api_compat layer is missing array methods like .mT and .to_device for numpy and shows casting rule issues (as expected until numpy 2.0); plain numpy.array_api comes closest but still needs a couple of fixes.

It looks to me like we need to focus on that. And in addition, make some of the extra-costly checks for JAX vectorized, as we're discussing in gh-200 right now.

@jakevdp
Contributor Author

jakevdp commented Nov 14, 2023

Thanks - with --max-examples=5 and a number of skips related to mutation and other issues, the jax.experimental.array_api PR now passes the test suite! jax-ml/jax#16099

I did find that it errors with this week's hypothesis 6.88.4 release; perhaps I should file a bug for that separately.

@rgommers
Member

Awesome, that's great to see!
