Tests are too slow for libraries with eager-mode dispatch overhead #197
Comments
They shouldn't be that slow. The NumPy tests finish in a fraction of that time https://github.com/data-apis/array-api-compat/actions/runs/6565885145/job/17835396340. It might be worth investigating what is going on. Although even there, the NumPy tests took an hour to run, which seems like a lot. We should investigate if something slowed down recently. Generally, though, you can speed things up by lowering the hypothesis |
I just ran |
Can you run something like |
I think the slowness comes from test patterns that look like this: array-api-tests/array_api_tests/test_creation_functions.py Lines 358 to 364 in f82c7bc
In JAX, each operation outside the context of JIT compilation has a small amount of overhead related to dispatch and device placement for the output, so running |
There's a variable that limits the max array size
|
I pasted the slowest 20 tests below. For the most part I think it comes down to the slow repeated indexing I mentioned above: every one of these that I've checked indexes each element of the output to check it against a reference implementation.
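To make the cost concrete, here is a minimal sketch that counts how many array operations each checking style issues. `CountingArray` and the helper functions are hypothetical stand-ins written for this illustration (they are not part of JAX or array-api-tests); each counted call models one eagerly dispatched operation, and the whole-array comparison is modelled as a single dispatch even though a real library would issue a small constant number.

```python
class CountingArray:
    """Hypothetical 2-D array stand-in that counts simulated dispatches."""

    ops = 0  # class-wide counter of simulated dispatched operations

    def __init__(self, rows):
        self.rows = [list(r) for r in rows]

    def __getitem__(self, idx):
        # Scalar indexing: one simulated dispatch per element read.
        CountingArray.ops += 1
        i, j = idx
        return self.rows[i][j]

    def all_equal(self, other):
        # Whole-array comparison plus reduction, modelled as one dispatch.
        CountingArray.ops += 1
        return self.rows == other.rows


def eye_rows(n_rows, n_cols, k=0):
    # Reference values for an eye-like matrix, built in plain Python.
    return [[1 if j - i == k else 0 for j in range(n_cols)] for i in range(n_rows)]


def check_elementwise(out, n_rows, n_cols, k=0):
    # Indexes every element of the output individually.
    for i in range(n_rows):
        for j in range(n_cols):
            assert out[i, j] == (1 if j - i == k else 0)


def check_vectorized(out, n_rows, n_cols, k=0):
    # Builds the expected values once and compares in a single operation.
    assert out.all_equal(CountingArray(eye_rows(n_rows, n_cols, k)))


def count_ops(check, n_rows, n_cols):
    out = CountingArray(eye_rows(n_rows, n_cols))
    CountingArray.ops = 0
    check(out, n_rows, n_cols)
    return CountingArray.ops
```

Under this model, the element-wise check's operation count grows with the number of elements in the output, while the vectorized check's count stays constant regardless of array size.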
|
Would you be open to changing logic that looks like this: array-api-tests/array_api_tests/test_creation_functions.py Lines 358 to 364 in f82c7bc
to something more like this?
k = kw.get("k", 0)
expected = [[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)]
assert xp.all(xp.asarray(expected) == out)
It does depend on
Rewriting tests this way would make the test suite usable for JAX and other libraries that have non-negligible dispatch overhead in eager mode. I would be happy to prepare a PR if you think this is the right direction for the testing suite. |
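The proposal above can be sketched as a self-contained helper. This is only an illustration under stated assumptions: `assert_eye_vectorized` is a hypothetical name, NumPy stands in for the array namespace `xp` under test, and the explicit `dtype` and `reshape` are additions made here so that the comparison does not rely on type promotion and so that zero-row outputs keep a 2-D shape.

```python
import numpy as np


def assert_eye_vectorized(xp, out, n_rows, n_cols=None, k=0):
    """Compare `out` against a reference eye matrix with one array comparison."""
    _n_cols = n_rows if n_cols is None else n_cols
    # Build the expected values in plain Python, with no per-element array ops.
    expected = [[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)]
    expected_arr = xp.asarray(expected, dtype=out.dtype)
    # An empty outer list would otherwise become a 1-D array of shape (0,).
    expected_arr = xp.reshape(expected_arr, (n_rows, _n_cols))
    assert out.shape == expected_arr.shape
    # A single comparison and reduction instead of n_rows * n_cols index reads.
    assert bool(xp.all(expected_arr == out))


# NumPy stands in for the library under test (assumption for this sketch).
out = np.eye(3, 4, k=1)
assert_eye_vectorized(np, out, 3, 4, k=1)
```

The trade-off is that the check now depends on `asarray`, `reshape`, `all`, and `==` behaving correctly in the library being tested, whereas scalar indexing depends on fewer operations.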
That would work for the creation functions like array-api-tests/array_api_tests/test_operators_and_elementwise_functions.py Lines 822 to 837 in f82c7bc
The loop is done in this helper
You have to loop through the input arrays to generate the exact outputs. Maybe there's some way to extract the original input array list from hypothesis (assuming it didn't go through any further transformations after being converted to an array). @honno, thoughts? Anyway, I think changing it to be like that for the creation functions is fine. |
Anyway, I think the best solution for you is to make |
Another option would be to use dlpack to export the full array to |
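A minimal sketch of the DLPack route, assuming NumPy as the export target (an assumption made for this illustration). `np.from_dlpack` is NumPy's DLPack consumer (available since NumPy 1.22); `to_numpy_once` and `assert_all_elements` are hypothetical helper names. The input here is a NumPy array only to keep the example self-contained; in the scenario under discussion it would be the tested library's own array on a CPU device.

```python
import numpy as np


def to_numpy_once(x):
    """Export a whole array through DLPack in a single call."""
    return np.from_dlpack(x)


def assert_all_elements(x, scalar_check):
    # One export up front, then iterate over NumPy scalars instead of
    # issuing one indexing operation per element on the original library.
    host = to_numpy_once(x)
    for idx in np.ndindex(host.shape):
        assert scalar_check(host[idx]), f"check failed at index {idx}"


x = np.arange(6.0).reshape(2, 3)  # stand-in for the array under test
assert_all_elements(x, lambda v: v >= 0)
```

This keeps the per-element reference checks intact while moving the indexing loop off the library with the dispatch overhead. It does require the library to support DLPack export for the dtypes being tested.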
I did a bit of profiling and testing to figure out what's going on. The result of:
$ py-spy record -o profile.svg -- pytest array_api_tests/
is this:
That takes about 15 minutes, and ~75% of the time is spent in
--ignore=array_api_tests/test_special_cases.py
That brings it down to ~4 minutes. With NumPy 1.26.0, there are still 4 failures when running on
--max-examples=1
and that runs in about 3 seconds. Adding back the special cases tests makes it run in 18 seconds with
Right now we are in the position that we need this 3 second run to pass on libraries. That is far more important than anything else; there are, as of now, zero implementations that actually pass 100%. The
It looks to me like we need to focus on that. And in addition, make some of the extra-costly checks for JAX vectorized, as we're discussing in gh-200 right now. |
Thanks - with I did find that it errors with this week's hypothesis 6.88.4 release; perhaps I should file a bug for that separately. |
Awesome, that's great to see! |
For example, this job: https://github.com/google/jax/actions/runs/6804780902/job/18502996983, which is associated with jax-ml/jax#16099, took 4.5 hours to run about 1000 tests.
This is slow enough to be virtually unusable.
Is there anything I can do to speed up the tests during development of the JAX array API?
Edit: as mentioned below, this is due to the frequent use of patterns within the tests that check every value in each array via indexing, and the fact that in eager execution, each of these indexing operations has a small dispatch overhead that leads to slow tests when arrays are large.