Improve performance of CAReduce in Numba backend #1109

Merged
ricardoV94 merged 2 commits into pymc-devs:main from better_numba_careduce
Nov 29, 2024

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Nov 29, 2024

Closes #935
Closes #931

The implementation for multiple axes no longer operates one axis at a time. Here are the benchmarks for the Sum test before and after this PR:

NUMBA Before:
---------------------------------------------------------------------------------------------------------- benchmark: 14 tests -----------------------------------------------------------------------------------------------------------
Name (time in ms)                                                  Min                   Max                  Mean             StdDev                Median                IQR            Outliers       OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_careduce_benchmark[c_contiguous=True-axis=None]            6.0967 (1.0)          7.5594 (1.0)          6.2753 (1.0)       0.2068 (1.13)         6.2230 (1.0)       0.0879 (1.0)         14;17  159.3552 (1.0)         158           1
test_careduce_benchmark[c_contiguous=True-axis=(1, 2)]         14.3584 (2.36)        15.3677 (2.03)        14.5429 (2.32)      0.1825 (1.0)         14.5005 (2.33)      0.1169 (1.33)          5;4   68.7621 (0.43)         70           1
test_careduce_benchmark[c_contiguous=True-axis=(0, 2)]         14.6134 (2.40)        24.6524 (3.26)        15.2444 (2.43)      1.4982 (8.21)        14.8519 (2.39)      0.2748 (3.13)          4;9   65.5977 (0.41)         69           1
test_careduce_benchmark[c_contiguous=True-axis=2]              23.1851 (3.80)        39.2687 (5.19)        35.4508 (5.65)      5.1560 (28.26)       38.1168 (6.13)      3.7166 (42.30)         4;4   28.2081 (0.18)         28           1
test_careduce_benchmark[c_contiguous=True-axis=1]              41.4532 (6.80)        42.5562 (5.63)        41.9859 (6.69)      0.2191 (1.20)        42.0163 (6.75)      0.2281 (2.60)          6;2   23.8175 (0.15)         24           1
test_careduce_benchmark[c_contiguous=True-axis=(0, 1)]         41.5490 (6.81)        45.7928 (6.06)        42.8554 (6.83)      1.1483 (6.29)        42.5019 (6.83)      0.8048 (9.16)          4;3   23.3343 (0.15)         24           1
test_careduce_benchmark[c_contiguous=False-axis=None]         165.3848 (27.13)      174.0071 (23.02)      168.2743 (26.82)     3.3571 (18.40)      167.2065 (26.87)     4.1945 (47.74)         1;0    5.9427 (0.04)          6           1
test_careduce_benchmark[c_contiguous=False-axis=2]            174.8747 (28.68)      190.2124 (25.16)      179.4774 (28.60)     5.5442 (30.39)      178.5084 (28.69)     3.6153 (41.15)         1;1    5.5717 (0.03)          6           1
test_careduce_benchmark[c_contiguous=False-axis=(1, 2)]       174.9006 (28.69)      177.3417 (23.46)      176.2445 (28.09)     0.8405 (4.61)       176.1939 (28.31)     0.9043 (10.29)         2;0    5.6739 (0.04)          6           1
test_careduce_benchmark[c_contiguous=False-axis=1]            197.9328 (32.47)      203.0997 (26.87)      200.9122 (32.02)     2.5227 (13.83)      202.5147 (32.54)     4.4556 (50.71)         1;0    4.9773 (0.03)          5           1
test_careduce_benchmark[c_contiguous=False-axis=(0, 2)]       199.6480 (32.75)      206.9815 (27.38)      203.5135 (32.43)     3.2742 (17.94)      203.4289 (32.69)     6.0479 (68.83)         2;0    4.9137 (0.03)          5           1
test_careduce_benchmark[c_contiguous=False-axis=(0, 1)]       204.7145 (33.58)      209.9537 (27.77)      206.4983 (32.91)     2.1112 (11.57)      206.2483 (33.14)     2.6657 (30.34)         1;0    4.8427 (0.03)          5           1
test_careduce_benchmark[c_contiguous=False-axis=0]            888.5353 (145.74)   1,002.9280 (132.67)     939.2021 (149.67)   47.3530 (259.52)     924.2738 (148.53)   77.3944 (880.83)        2;0    1.0647 (0.01)          5           1
test_careduce_benchmark[c_contiguous=True-axis=0]           1,004.1833 (164.71)   1,171.4621 (154.97)   1,106.3967 (176.31)   65.4015 (358.44)   1,121.6751 (180.25)   88.7262 (>1000.0)       1;0    0.9038 (0.01)          5           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


NUMBA After:
------------------------------------------------------------------------------------------------------ benchmark: 14 tests -------------------------------------------------------------------------------------------------------
Name (time in ms)                                                Min                 Max                Mean             StdDev              Median                IQR            Outliers       OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_careduce_benchmark[c_contiguous=True-axis=None]          5.3386 (1.0)        7.1268 (1.0)        6.1742 (1.0)       0.3182 (3.79)       6.1859 (1.0)       0.1687 (1.25)        52;42  161.9631 (1.0)         182           1
test_careduce_benchmark[c_contiguous=True-axis=1]             8.8967 (1.67)      13.4488 (1.89)      10.0940 (1.63)      1.1528 (13.73)      9.6278 (1.56)      1.5788 (11.67)        17;4   99.0684 (0.61)         94           1
test_careduce_benchmark[c_contiguous=True-axis=(0, 1)]       10.3976 (1.95)      13.2198 (1.85)      10.6963 (1.73)      0.3954 (4.71)      10.5585 (1.71)      0.2725 (2.02)          8;5   93.4903 (0.58)         94           1
test_careduce_benchmark[c_contiguous=True-axis=(0, 2)]       13.8707 (2.60)      14.7292 (2.07)      14.2018 (2.30)      0.2990 (3.56)      14.0646 (2.27)      0.5905 (4.37)         28;0   70.4137 (0.43)         71           1
test_careduce_benchmark[c_contiguous=True-axis=(1, 2)]       14.5889 (2.73)      14.9401 (2.10)      14.7559 (2.39)      0.0840 (1.0)       14.7547 (2.39)      0.1352 (1.0)          27;0   67.7696 (0.42)         68           1
test_careduce_benchmark[c_contiguous=True-axis=0]            14.7935 (2.77)      17.8587 (2.51)      15.7053 (2.54)      0.9524 (11.34)     15.2319 (2.46)      1.5442 (11.42)        13;0   63.6727 (0.39)         65           1
test_careduce_benchmark[c_contiguous=True-axis=2]            22.5702 (4.23)      38.0029 (5.33)      30.4078 (4.92)      7.3567 (87.62)     36.7073 (5.93)     14.7076 (108.76)       21;0   32.8863 (0.20)         44           1
test_careduce_benchmark[c_contiguous=False-axis=None]       167.2684 (31.33)    172.2568 (24.17)    168.6540 (27.32)     1.8952 (22.57)    167.9704 (27.15)     1.6657 (12.32)         1;1    5.9293 (0.04)          6           1
test_careduce_benchmark[c_contiguous=False-axis=1]          173.8041 (32.56)    176.8286 (24.81)    175.0461 (28.35)     1.1514 (13.71)    174.8555 (28.27)     1.5803 (11.69)         2;0    5.7128 (0.04)          6           1
test_careduce_benchmark[c_contiguous=False-axis=0]          175.8450 (32.94)    178.8685 (25.10)    177.4204 (28.74)     1.2897 (15.36)    177.7929 (28.74)     2.4946 (18.45)         3;0    5.6363 (0.03)          6           1
test_careduce_benchmark[c_contiguous=False-axis=(1, 2)]     180.2156 (33.76)    181.9736 (25.53)    181.2290 (29.35)     0.6001 (7.15)     181.2596 (29.30)     0.5724 (4.23)          2;0    5.5179 (0.03)          6           1
test_careduce_benchmark[c_contiguous=False-axis=(0, 2)]     182.3413 (34.16)    185.7313 (26.06)    184.1320 (29.82)     1.1762 (14.01)    184.2766 (29.79)     1.2764 (9.44)          2;0    5.4309 (0.03)          6           1
test_careduce_benchmark[c_contiguous=False-axis=2]          194.9572 (36.52)    283.0956 (39.72)    231.4646 (37.49)    34.2189 (407.53)   224.2623 (36.25)    55.8371 (412.91)        2;0    4.3203 (0.03)          6           1
test_careduce_benchmark[c_contiguous=False-axis=(0, 1)]     202.0370 (37.84)    261.3171 (36.67)    222.9340 (36.11)    26.6018 (316.82)   207.2764 (33.51)    42.3533 (313.20)        1;0    4.4856 (0.03)          5           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
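
As a rough illustration of the change described at the top of this comment (not the code generated by the Numba dispatch), the sketch below contrasts reducing one axis at a time with a single pass over the input. The function names are hypothetical, and the single-pass version assumes a 3-D input for brevity.

```python
import numpy as np


def sum_one_axis_at_a_time(x, axes):
    # Reduce the highest axis first so the remaining axis numbers stay valid.
    out = x
    for axis in sorted(axes, reverse=True):
        out = out.sum(axis=axis)  # allocates one intermediate array per axis
    return out


def sum_single_pass(x, axes):
    # Hypothetical 3-D-only sketch: every input element is visited exactly
    # once and accumulated directly into the final output.
    keep = [ax for ax in range(x.ndim) if ax not in axes]
    out = np.zeros([x.shape[ax] for ax in keep], dtype=x.dtype)
    for i0 in range(x.shape[0]):
        for i1 in range(x.shape[1]):
            for i2 in range(x.shape[2]):
                idx = (i0, i1, i2)
                out[tuple(idx[ax] for ax in keep)] += x[idx]
    return out


x = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
assert np.allclose(sum_one_axis_at_a_time(x, (0, 2)), x.sum(axis=(0, 2)))
assert np.allclose(sum_single_pass(x, (0, 2)), x.sum(axis=(0, 2)))
```

The single-pass form avoids the intermediate arrays, but its speed still depends on how the loop nest lines up with the memory layout of the input.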

Note that we have a special dispatch for Sum(axes=None) introduced in #92, so the changes are not reflected in that benchmark. I temporarily disabled the special dispatch, to confirm that case is still improved:

NUMBA Before (default CAReduce impl):
----------------------------------------------------------------------------------------------------- benchmark: 2 tests ----------------------------------------------------------------------------------------------------
Name (time in ms)                                              Min                 Max                Mean            StdDev              Median               IQR            Outliers      OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_careduce_benchmark[c_contiguous=True-axis=None]       14.3855 (1.0)       14.8842 (1.0)       14.5902 (1.0)      0.0985 (1.0)       14.5816 (1.0)      0.1100 (1.0)          18;4  68.5389 (1.0)          70           1
test_careduce_benchmark[c_contiguous=False-axis=None]     203.1434 (14.12)    208.9211 (14.04)    205.1373 (14.06)    2.2386 (22.74)    204.2839 (14.01)    2.2561 (20.51)         1;0   4.8748 (0.07)          5           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

NUMBA After  (default CAReduce impl):
----------------------------------------------------------------------------------------------------- benchmark: 2 tests -----------------------------------------------------------------------------------------------------
Name (time in ms)                                              Min                 Max                Mean            StdDev              Median               IQR            Outliers       OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_careduce_benchmark[c_contiguous=True-axis=None]        7.6042 (1.0)        9.2345 (1.0)        8.0017 (1.0)      0.3589 (1.0)        7.8468 (1.0)      0.5667 (1.0)          30;1  124.9732 (1.0)         129           1
test_careduce_benchmark[c_contiguous=False-axis=None]     173.5030 (22.82)    179.8559 (19.48)    176.4238 (22.05)    2.4490 (6.82)     175.8669 (22.41)    3.9248 (6.93)          2;0    5.6682 (0.05)          6           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Because it is still a bit slower, and this is the most common reduction, I decided to leave the special case.
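
To put the tables in perspective, here is a small calculation over the mean times reported above. The numbers are copied from the tables; the dictionary keys and variable names are only labels chosen for this sketch.

```python
# Mean times in ms, copied from the benchmark tables above: (before, after).
cases = {
    "c_contiguous=True-axis=0": (1106.3967, 15.7053),
    "c_contiguous=True-axis=1": (41.9859, 10.0940),
    "c_contiguous=False-axis=0": (939.2021, 177.4204),
    "c_contiguous=True-axis=None (default CAReduce impl)": (14.5902, 8.0017),
}

speedups = {name: before / after for name, (before, after) in cases.items()}
for name, speedup in speedups.items():
    print(f"{name}: {speedup:.1f}x faster")

# Special Sum(axis=None) dispatch versus the new default implementation,
# both measured after this PR on the C-contiguous input.
special_ms, default_ms = 6.1742, 8.0017
slowdown = default_ms / special_ms
print(f"default impl is {slowdown:.2f}x slower than the special dispatch")
```

The leading-axis reduction on C-contiguous input gains roughly 70x, while the default implementation for `axis=None` remains about 1.3x slower than the special dispatch, which is the gap that motivates keeping the special case.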

Numba doesn't seem to optimize non-contiguous arrays very well. The C backend implementation with explicit loop reordering written in #971 does not show such a penalty.
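
A minimal sketch of what stride-based loop reordering means, assuming the non-contiguous input is built with the `(2, 0, 1)` transpose used in the comparison scripts in this thread. The helper `loop_order_by_stride` is hypothetical and is not the C backend's implementation.

```python
import numpy as np


def loop_order_by_stride(x):
    # Hypothetical helper: the outermost loop gets the largest stride and the
    # innermost loop the smallest, so the inner loop walks adjacent memory.
    return sorted(range(x.ndim), key=lambda ax: abs(x.strides[ax]), reverse=True)


N = 4
x = np.zeros((N, N, N)).transpose(2, 0, 1)  # non-contiguous view, float64

print(x.strides)                # (8, 128, 32) for float64 and N = 4
print(loop_order_by_stride(x))  # [1, 2, 0]
```

For this layout, a loop nest written in axis order `0, 1, 2` ends up with the 32-byte-stride axis innermost, whereas the stride-sorted order puts axis 0 (stride 8) innermost.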

C-implementation
------------------------------------------------------------------------------------------------------ benchmark: 14 tests ------------------------------------------------------------------------------------------------------
Name (time in ms)                                               Min                 Max                Mean             StdDev              Median                IQR            Outliers       OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_careduce_benchmark[c_contiguous=True-axis=1]            6.7267 (1.0)       12.1475 (1.29)       8.0947 (1.07)      0.7214 (2.82)       8.0757 (1.09)      0.5199 (4.52)        33;10  123.5382 (0.94)        123           1
test_careduce_benchmark[c_contiguous=True-axis=0]            7.1908 (1.07)       9.4370 (1.0)        7.5843 (1.0)       0.4038 (1.58)       7.4319 (1.0)       0.3982 (3.46)         21;8  131.8516 (1.0)         148           1
test_careduce_benchmark[c_contiguous=True-axis=(0, 1)]       7.2123 (1.07)       9.7355 (1.03)       8.1594 (1.08)      0.4370 (1.71)       8.1734 (1.10)      0.1904 (1.65)        31;33  122.5583 (0.93)        128           1
test_careduce_benchmark[c_contiguous=False-axis=(1, 2)]      8.0725 (1.20)      11.4813 (1.22)       8.5335 (1.13)      0.5003 (1.96)       8.3601 (1.12)      0.3283 (2.85)          8;8  117.1858 (0.89)         82           1
test_careduce_benchmark[c_contiguous=True-axis=None]        13.8293 (2.06)      18.6684 (1.98)      14.3862 (1.90)      0.8901 (3.48)      14.0311 (1.89)      0.4894 (4.25)        10;11   69.5112 (0.53)         72           1
test_careduce_benchmark[c_contiguous=False-axis=None]       13.8392 (2.06)      15.6759 (1.66)      14.0091 (1.85)      0.2654 (1.04)      13.9427 (1.88)      0.1151 (1.0)           5;5   71.3823 (0.54)         72           1
test_careduce_benchmark[c_contiguous=True-axis=2]           45.2172 (6.72)      58.6688 (6.22)      46.8904 (6.18)      3.5423 (13.84)     45.6780 (6.15)      0.5301 (4.60)          2;3   21.3263 (0.16)         22           1
test_careduce_benchmark[c_contiguous=False-axis=(0, 1)]     45.2255 (6.72)      49.0917 (5.20)      46.6270 (6.15)      1.1743 (4.59)      46.1825 (6.21)      1.7203 (14.94)         8;0   21.4468 (0.16)         22           1
test_careduce_benchmark[c_contiguous=True-axis=(0, 2)]      45.2671 (6.73)      46.1873 (4.89)      45.6296 (6.02)      0.2559 (1.0)       45.5888 (6.13)      0.3778 (3.28)          6;0   21.9156 (0.17)         22           1
test_careduce_benchmark[c_contiguous=False-axis=0]          45.7718 (6.80)      49.9460 (5.29)      46.9616 (6.19)      0.9911 (3.87)      46.7856 (6.30)      1.2120 (10.53)         8;1   21.2940 (0.16)         22           1
test_careduce_benchmark[c_contiguous=True-axis=(1, 2)]      48.4265 (7.20)      54.0043 (5.72)      48.9075 (6.45)      1.1805 (4.61)      48.6437 (6.55)      0.1659 (1.44)          1;2   20.4468 (0.16)         21           1
test_careduce_benchmark[c_contiguous=False-axis=(0, 2)]     48.9519 (7.28)      53.3287 (5.65)      49.9922 (6.59)      1.1338 (4.43)      49.5702 (6.67)      1.1638 (10.11)         2;2   20.0031 (0.15)         20           1
test_careduce_benchmark[c_contiguous=False-axis=1]          87.3253 (12.98)    117.7347 (12.48)     97.1947 (12.82)    10.8305 (42.32)     93.1650 (12.54)    10.9362 (95.00)         2;1   10.2886 (0.08)         10           1
test_careduce_benchmark[c_contiguous=False-axis=2]          94.2052 (14.00)    123.4957 (13.09)    109.8449 (14.48)     9.9904 (39.04)    109.6667 (14.76)    13.5596 (117.79)        3;0    9.1037 (0.07)          9           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Finally we also see an improvement in the slowest case of the pre-existing numba-logsumexp benchmark:

NUMBA Before:
--------------------------------------------------------------------------------------------------------------- benchmark: 6 tests --------------------------------------------------------------------------------------------------------------
Name (time in us)                                Min                       Max                      Mean                  StdDev                    Median                     IQR            Outliers          OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_logsumexp_benchmark[1-size0]            13.3750 (1.0)             49.2820 (1.11)            14.1633 (1.0)            1.1505 (1.0)             13.9560 (1.0)            0.2400 (1.0)     1242;1978  70,605.0384 (1.0)       30394           1
test_logsumexp_benchmark[0-size0]            13.9260 (1.04)            44.4430 (1.0)             16.3895 (1.16)           2.4661 (2.14)            15.9300 (1.14)           0.6010 (2.50)     873;3072  61,014.4931 (0.86)      10928           1
test_logsumexp_benchmark[1-size1]         9,546.4760 (713.76)      11,210.3610 (252.24)       9,762.2832 (689.27)       212.6699 (184.84)       9,756.4640 (699.09)       193.6830 (807.01)        8;2     102.4351 (0.00)         78           1
test_logsumexp_benchmark[0-size1]        10,306.7690 (770.60)      13,220.6750 (297.47)      10,757.3244 (759.52)       553.2049 (480.82)      10,552.3635 (756.12)       389.2280 (>1000.0)     11;10      92.9599 (0.00)         74           1
test_logsumexp_benchmark[1-size2]     1,368,416.7610 (>1000.0)  1,399,692.1620 (>1000.0)  1,382,021.1000 (>1000.0)   13,889.7094 (>1000.0)  1,376,665.6560 (>1000.0)   24,415.4015 (>1000.0)       1;0       0.7236 (0.00)          5           1
test_logsumexp_benchmark[0-size2]     2,741,798.7080 (>1000.0)  3,161,025.8590 (>1000.0)  2,877,190.6602 (>1000.0)  184,517.6885 (>1000.0)  2,760,426.2470 (>1000.0)  266,215.4525 (>1000.0)       1;0       0.3476 (0.00)          5           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

NUMBA After:
-------------------------------------------------------------------------------------------------------------- benchmark: 6 tests -------------------------------------------------------------------------------------------------------------
Name (time in us)                                Min                       Max                      Mean                 StdDev                    Median                    IQR            Outliers          OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_logsumexp_benchmark[0-size0]            14.1960 (1.0)            230.8520 (4.02)            15.7622 (1.0)           3.4654 (1.60)            14.9680 (1.0)           0.4543 (1.16)     784;2537  63,443.0537 (1.0)       12349           1
test_logsumexp_benchmark[1-size0]            14.2470 (1.00)            57.4970 (1.0)             16.3850 (1.04)          2.1622 (1.0)             16.0000 (1.07)          0.3910 (1.0)     1971;7148  61,031.3361 (0.96)      26322           1
test_logsumexp_benchmark[0-size1]         8,908.0190 (627.50)       9,353.5040 (162.68)       9,067.2673 (575.26)      107.2531 (49.60)        9,036.6210 (603.73)      140.0220 (358.11)       26;1     110.2868 (0.00)         84           1
test_logsumexp_benchmark[1-size1]         9,565.0100 (673.78)      12,657.2000 (220.14)      10,572.0747 (670.72)    1,046.2932 (483.90)      10,191.3075 (680.87)    1,755.4770 (>1000.0)      13;0      94.5888 (0.00)         64           1
test_logsumexp_benchmark[0-size2]     1,284,855.3190 (>1000.0)  1,396,819.2750 (>1000.0)  1,319,934.6108 (>1000.0)  45,751.4029 (>1000.0)  1,313,243.5640 (>1000.0)  53,306.7305 (>1000.0)       1;0       0.7576 (0.00)          5           1
test_logsumexp_benchmark[1-size2]     1,354,385.0090 (>1000.0)  1,384,451.8370 (>1000.0)  1,367,609.9862 (>1000.0)  11,224.5104 (>1000.0)  1,364,489.5490 (>1000.0)  13,840.7880 (>1000.0)       2;0       0.7312 (0.00)          5           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

This Op does not really fit the CAReduce API, as it requires an extra piece of information (the number of elements in the axis) during the loop. A better solution would be a fused Elemwise+CAReduce.
@ricardoV94 ricardoV94 force-pushed the better_numba_careduce branch from bfa16dd to 2bc894a Compare November 29, 2024 14:38

codecov bot commented Nov 29, 2024

Codecov Report

Attention: Patch coverage is 93.33333% with 3 lines in your changes missing coverage. Please review.

Project coverage is 82.10%. Comparing base (0824dba) to head (6268d99).
Report is 6 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/numba/dispatch/elemwise.py 93.18% 2 Missing and 1 partial ⚠️
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1109      +/-   ##
==========================================
- Coverage   82.12%   82.10%   -0.03%     
==========================================
  Files         183      183              
  Lines       48111    48030      -81     
  Branches     8667     8658       -9     
==========================================
- Hits        39510    39433      -77     
+ Misses       6435     6434       -1     
+ Partials     2166     2163       -3     
Files with missing lines Coverage Δ
pytensor/scalar/basic.py 80.50% <ø> (-0.19%) ⬇️
pytensor/tensor/math.py 91.85% <100.00%> (+0.53%) ⬆️
pytensor/link/numba/dispatch/elemwise.py 92.05% <93.18%> (-0.02%) ⬇️

... and 1 file with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the better_numba_careduce branch from 2bc894a to 79e8109 Compare November 29, 2024 16:14
Member Author

ricardoV94 commented Nov 29, 2024

Here is a direct comparison of C and numba backends for the non C-contiguous case:

import numpy as np
import pytensor

c_contiguous = False
for transpose_in_graph in (True, False):
    rng = np.random.default_rng(123)
    N = 256
    x_test = rng.uniform(size=(N, N, N))
    transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1)
    
    if not transpose_in_graph:
        x_test = x_test.transpose(transpose_axis)
    
    x = pytensor.shared(x_test, name="x", shape=x_test.shape, borrow=True)
    
    if transpose_in_graph:
        x = x.transpose(transpose_axis)
        
    out = x.sum(axis=0)
    c_fn = pytensor.function([], out, mode="FAST_COMPILE")
    numba_fn = pytensor.function([], out, mode="NUMBA").vm.jit_fn
    np.testing.assert_allclose(c_fn(), numba_fn()[0])
    print(f"{transpose_in_graph=}")
    %timeit c_fn()
    %timeit numba_fn()
          
# transpose_in_graph=True
# 33.7 ms ± 2.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 188 ms ± 4.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# transpose_in_graph=False
# 33 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 103 ms ± 1.96 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

A direct Numba implementation shows the same poor performance:

import numpy as np
import numba

c_contiguous = False
rng = np.random.default_rng(123)
N = 256
x_test = rng.uniform(size=(N, N, N))
transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1)
x_test = x_test.transpose(transpose_axis)

out_dtype = np.float64

@numba.njit(fastmath=True, boundscheck=False)
def careduce_add(x):
    x_shape = x.shape
    res_shape = (x_shape[1], x_shape[2])
    # Accumulator for the sum over axis 0, initialized to the additive identity.
    res = np.full(res_shape, 0.0, dtype=out_dtype)
    for i0 in range(x_shape[0]):
        for i1 in range(x_shape[1]):
            for i2 in range(x_shape[2]):
                res[i1, i2] += x[i0, i1, i2]
    return res

np.testing.assert_allclose(careduce_add(x_test), np.sum(x_test, 0))
%timeit careduce_add(x_test)  
# 136 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
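
For comparison, here is a sketch of the same axis-0 reduction with the loops permuted so that the axis with the smallest stride in the transposed input (axis 0) is innermost. It is written in plain Python and NumPy so it runs without Numba; decorating it with `numba.njit` is an assumption that was not benchmarked here, and the function name is hypothetical.

```python
import numpy as np


def careduce_add_reordered(x):
    # Same result as summing over axis 0, but the reduced axis is the
    # innermost loop, so consecutive iterations read adjacent memory when
    # the input comes from a (2, 0, 1) transpose of a C-contiguous array.
    res = np.zeros((x.shape[1], x.shape[2]), dtype=x.dtype)
    for i1 in range(x.shape[1]):
        for i2 in range(x.shape[2]):
            acc = 0.0
            for i0 in range(x.shape[0]):
                acc += x[i0, i1, i2]
            res[i1, i2] = acc
    return res


x_small = np.random.default_rng(123).uniform(size=(8, 8, 8)).transpose(2, 0, 1)
np.testing.assert_allclose(careduce_add_reordered(x_small), x_small.sum(axis=0))
```

Accumulating into a local scalar also keeps the running sum out of the output array until the inner loop finishes.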

@ricardoV94 ricardoV94 force-pushed the better_numba_careduce branch from 79e8109 to 6268d99 Compare November 29, 2024 17:01

@AlexAndorra AlexAndorra left a comment

Thanks for the walkthrough in the comparison @ricardoV94, definitely interesting.

Member Author

ricardoV94 commented Nov 29, 2024

Numba doing badly on the non-contiguous case is all due to loop ordering. LLVM doesn't reorder based on strides :(

Anyway, this PR is an overall improvement. The cases where the old implementation was faster were just due to chance, when the reduced loop happened to be the one with the smallest strides.

@ricardoV94 ricardoV94 merged commit ef97287 into pymc-devs:main Nov 29, 2024
61 of 62 checks passed
@ricardoV94 ricardoV94 deleted the better_numba_careduce branch November 30, 2024 11:55
Successfully merging this pull request may close these issues.

Reductions along leading axes can be incredibly slow in C and Numba backends
2 participants