Improve performance of CAReduce in Numba backend #1109
Conversation
This Op does not really fit the CAReduce API, as it requires an extra bit of information (the number of elements along the reduced axis) during the loop. A better solution would be a fused Elemwise+CAReduce.
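As a minimal sketch of the mismatch (hypothetical code, not PyTensor's actual API): a CAReduce-style loop only sees a binary accumulation `acc = f(acc, x[i])`, so an op like `mean` that needs the element count does not fit, while a fused elemwise+reduce loop has the axis length in scope:

```python
import numpy as np

def careduce_sum(x):
    # CAReduce pattern: only a binary accumulation is available.
    acc = np.zeros(x.shape[1:], dtype=x.dtype)
    for i in range(x.shape[0]):
        acc = acc + x[i]
    return acc

def fused_mean(x):
    # Fused loop: the reduction knows n, so it can finalize with a divide.
    # `n` is exactly the extra bit of information a pure CAReduce lacks.
    n = x.shape[0]
    acc = np.zeros(x.shape[1:], dtype=x.dtype)
    for i in range(n):
        acc += x[i]
    return acc / n

x = np.arange(12.0).reshape(3, 4)
assert np.allclose(careduce_sum(x), x.sum(axis=0))
assert np.allclose(fused_mean(x), x.mean(axis=0))
```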
Force-pushed bfa16dd to 2bc894a
Codecov Report

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main    #1109      +/-   ##
==========================================
- Coverage   82.12%   82.10%   -0.03%
==========================================
  Files         183      183
  Lines       48111    48030      -81
  Branches     8667     8658       -9
==========================================
- Hits        39510    39433      -77
+ Misses       6435     6434       -1
+ Partials     2166     2163       -3
```
Force-pushed 2bc894a to 79e8109
Here is a direct comparison of the C and numba backends for the non C-contiguous case:

```python
import numpy as np
import pytensor

c_contiguous = False
for transpose_in_graph in (True, False):
    rng = np.random.default_rng(123)
    N = 256
    x_test = rng.uniform(size=(N, N, N))
    transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1)
    if not transpose_in_graph:
        x_test = x_test.transpose(transpose_axis)
    x = pytensor.shared(x_test, name="x", shape=x_test.shape, borrow=True)
    if transpose_in_graph:
        x = x.transpose(transpose_axis)
    out = x.sum(axis=0)
    c_fn = pytensor.function([], out, mode="FAST_COMPILE")
    numba_fn = pytensor.function([], out, mode="NUMBA").vm.jit_fn
    np.testing.assert_allclose(c_fn(), numba_fn()[0])
    print(f"{transpose_in_graph=}")
    %timeit c_fn()
    %timeit numba_fn()

# transpose_in_graph=True
# 33.7 ms ± 2.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 188 ms ± 4.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# transpose_in_graph=False
# 33 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 103 ms ± 1.96 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

A direct numba implementation shows the same bad performance:

```python
import numpy as np
import numba

c_contiguous = False
rng = np.random.default_rng(123)
N = 256
x_test = rng.uniform(size=(N, N, N))
transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1)
x_test = x_test.transpose(transpose_axis)
out_dtype = np.float64

@numba.njit(fastmath=True, boundscheck=False)
def careduce_add(x):
    x_shape = x.shape
    res = np.full((x_shape[1], x_shape[2]), np.asarray(0.0).item(), dtype=out_dtype)
    for i0 in range(x_shape[0]):
        for i1 in range(x_shape[1]):
            for i2 in range(x_shape[2]):
                res[i1, i2] += x[i0, i1, i2]
    return res

np.testing.assert_allclose(careduce_add(x_test), np.sum(x_test, 0))
%timeit careduce_add(x_test)
# 136 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
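For illustration, the loop-reordering idea can be sketched in pure Python/NumPy (a sketch only; it does not show the speedup in interpreted Python, and it is not PyTensor's actual implementation). The nested loops are ordered by descending stride of the input, so the innermost loop walks contiguous memory regardless of which axis is reduced:

```python
import numpy as np

def careduce_add_ordered(x):
    # Sum over axis 0 of a 3-D array, iterating the loop axes in
    # descending-stride order. The visit order does not change the
    # result of the accumulation, only its memory-access pattern.
    res = np.zeros((x.shape[1], x.shape[2]), dtype=x.dtype)
    order = sorted(range(x.ndim), key=lambda a: -x.strides[a])
    for j0 in range(x.shape[order[0]]):
        for j1 in range(x.shape[order[1]]):
            for j2 in range(x.shape[order[2]]):
                # Map loop counters back to the logical indices.
                idx = [0, 0, 0]
                idx[order[0]], idx[order[1]], idx[order[2]] = j0, j1, j2
                i0, i1, i2 = idx
                res[i1, i2] += x[i0, i1, i2]
    return res

# Small non C-contiguous input (illustration size only).
x = np.random.default_rng(0).uniform(size=(4, 5, 6)).transpose(2, 0, 1)
np.testing.assert_allclose(careduce_add_ordered(x), x.sum(axis=0))
```

In a compiled backend this ordering is what keeps the innermost loop cache-friendly; in Python the loops are too slow for the effect to be visible.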
79e8109
to
6268d99
Compare
Thanks for the walk-through in the comparison @ricardoV94, definitely interesting.
Numba doing badly on the non-contiguous case is all down to loop ordering; LLVM doesn't reorder loops based on strides :( Anyway, this PR is an overall improvement. The better speeds before were just due to chance, when the reduced loop happened to be the one with the smallest strides.
Closes #935
Closes #931
The implementation for multiple axes no longer operates one axis at a time. Here are the benchmarks for the Sum test before and after this PR:
Note that we have a special dispatch for `Sum(axes=None)`, introduced in #92, so the changes are not reflected in that benchmark. I temporarily disabled the special dispatch to confirm that case is also improved. Because the general path is still a bit slower, and this is the most common reduction, I decided to keep the special case.
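To illustrate why a full reduction admits a simpler special case (hypothetical code for illustration, not the actual dispatch): `Sum(axes=None)` is a single flat accumulation over all elements, with no output buffer or output indexing at all, whereas the general reduction loop must index into a result array:

```python
import numpy as np

def sum_all(x):
    # Special case: one flat pass, scalar accumulator, no output indexing.
    acc = 0.0
    for v in x.ravel():
        acc += v
    return acc

def sum_axis0(x):
    # General case: needs an output buffer and per-step output writes.
    res = np.zeros(x.shape[1:], dtype=x.dtype)
    for i in range(x.shape[0]):
        res += x[i]
    return res

x = np.arange(24.0).reshape(2, 3, 4)
assert np.isclose(sum_all(x), x.sum())
assert np.allclose(sum_axis0(x), x.sum(axis=0))
```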
Numba doesn't seem to optimize non-contiguous arrays very well. The C backend implementation with explicit loop reordering written in #971 does not show such a penalty.
Finally, we also see an improvement in the slowest case of the pre-existing numba logsumexp benchmark: