
Use KernelAbstractions.jl for upsample kernels #486

Merged
merged 11 commits into FluxML:master on Apr 10, 2023
Conversation

pxl-th
Member

@pxl-th pxl-th commented Apr 5, 2023

Ref #479.

  • Migrate upsample kernels to KernelAbstractions.jl (implementation migrated from NNlibCUDA).
    Successfully tested on CPU, ROCBackend, and CUDABackend.
    From the user's perspective, there are no API changes (apart from the new align_corners option).
    The backend is selected automatically using KernelAbstractions.get_backend(x).

  • Introduce an NNlib test suite, which accepts a backend as an argument and runs tests for the respective device.
    For now it contains only the upsample tests, but as we move other kernels to KernelAbstractions, it makes sense to move the rest of the tests to it as well.
    This way, we don't need to rewrite tests for each backend.

    When the backend is:

    • CPU, gradtest checks the Zygote gradient against FiniteDifferences (as before).
    • not CPU, gradtest checks against the CPU results.
      This is because AMDGPU.jl fails with FiniteDifferences.jl.
      Still, for now this is fine, because the CPU test suite runs first.
AssertionError: ROCArray only supports bits types
  Stacktrace:
    [1] ROCVector{Tuple{Vector{Float32}, FiniteDifferences.var"#Real_from_vec#20"}}(buf::AMDGPU.Runtime.Mem.Buffer, dims::Tuple{Int64}; offset::Int64)
    ...
  • Add align_corners option (copied from NNlibCUDA.jl, but now supports CPU as well).
  • Remove upsample from ext/NNlibCUDA.

Side note: this PR is based on #483, since that PR introduces a separate Project.toml for tests.

  • Add benchmarks comparing performance before & after.
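For context on the align_corners option: it controls how output pixel coordinates map back to input coordinates in bilinear interpolation. Below is a minimal illustrative sketch of that mapping (source_coord is a hypothetical helper, not part of the NNlib API); with align_corners=true the corner pixels of input and output are aligned exactly, while with false (the PyTorch-style default) pixel centers are aligned instead.

```julia
# Hypothetical sketch: the source coordinate (0-based) that an output
# index samples from, for the two `align_corners` conventions.
function source_coord(dst_idx::Integer, in_size::Integer, out_size::Integer;
                      align_corners::Bool=false)
    if align_corners
        # Corners align exactly: index 1 -> 0.0, index `out_size` -> in_size - 1.
        ratio = out_size > 1 ? (in_size - 1) / (out_size - 1) : 0.0
        return (dst_idx - 1) * ratio
    else
        # Pixel centers align; clamp to stay inside the input.
        ratio = in_size / out_size
        return max(0.0, (dst_idx - 0.5) * ratio - 0.5)
    end
end

source_coord(1, 4, 8; align_corners=true)  # == 0.0 (first corner)
source_coord(8, 4, 8; align_corners=true)  # ≈ 3.0 (last corner maps to last input pixel)
```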

PR Checklist

  • Tests are added
  • Documentation, if applicable

Member

@ToucheSir ToucheSir left a comment


I know we don't have upsampling (or anything other than activation functions) in the benchmarks, but do you have a rough idea of how the KA implementations stack up against the existing ones?

src/upsample.jl (review thread, resolved)
pxl-th added 3 commits April 6, 2023 11:41
- Add `align_corners` option.
- Add unified test suite which accepts backend as an argument
    and runs tests for it.
@pxl-th
Member Author

pxl-th commented Apr 6, 2023

Benchmarks

CPU:

  • Before:
julia> using BenchmarkTools, NNlib, Zygote

julia> x = rand(Float32, 128, 128, 128, 16);

julia> @benchmark NNlib.upsample_bilinear(x, (2, 2))
BenchmarkTools.Trial: 76 samples with 1 evaluation.
 Range (min … max):  47.665 ms … 74.961 ms  ┊ GC (min … max):  0.65% … 27.56%
 Time  (median):     67.239 ms              ┊ GC (median):    26.38%
 Time  (mean ± σ):   66.386 ms ±  4.412 ms  ┊ GC (mean ± σ):  27.19% ±  5.61%

                                      █▅        ▇▆             
  ▃▁▃▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅██▆▄▁▅▁▁▁▄██▄▄▃▄▁▃▁▃▃▃▃ ▁
  47.7 ms         Histogram: frequency by time        73.9 ms <

 Memory estimate: 512.01 MiB, allocs estimate: 101.

julia> @benchmark Zygote.gradient(x -> sum(NNlib.upsample_bilinear(x, (2, 2))), x)
BenchmarkTools.Trial: 28 samples with 1 evaluation.
 Range (min … max):  154.072 ms … 191.380 ms  ┊ GC (min … max):  1.02% … 15.65%
 Time  (median):     180.883 ms               ┊ GC (median):    16.55%
 Time  (mean ± σ):   181.300 ms ±   6.365 ms  ┊ GC (mean ± σ):  15.91% ±  2.94%

                                            ██▁                  
  ▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▄███▇▇▁▄▁▁▁▁▄▇▁▁▁▁▁▇ ▁
  154 ms           Histogram: frequency by time          191 ms <

 Memory estimate: 768.03 MiB, allocs estimate: 422.
  • This PR:
julia> using BenchmarkTools, NNlib, Zygote

julia> x = rand(Float32, 128, 128, 128, 16);

julia> @benchmark NNlib.upsample_bilinear(x, (2, 2))
BenchmarkTools.Trial: 76 samples with 1 evaluation.
 Range (min … max):  53.949 ms … 72.410 ms  ┊ GC (min … max):  0.66% … 19.75%
 Time  (median):     66.807 ms              ┊ GC (median):    19.92%
 Time  (mean ± σ):   66.478 ms ±  3.042 ms  ┊ GC (mean ± σ):  19.06% ±  4.15%

                                   ▄▄▆▂▆      ▂█▆▂             
  ▄▁▁▁▁▁▁▁▄▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▄▆█████▄▄▆█▆▆█████▄▁▄▁▁▆▁▁▄▄ ▁
  53.9 ms         Histogram: frequency by time        72.3 ms <

 Memory estimate: 512.01 MiB, allocs estimate: 166.

julia> @benchmark Zygote.gradient(x -> sum(NNlib.upsample_bilinear(x, (2, 2))), x)
BenchmarkTools.Trial: 27 samples with 1 evaluation.
 Range (min … max):  168.529 ms … 198.798 ms  ┊ GC (min … max):  0.80% … 12.97%
 Time  (median):     191.926 ms               ┊ GC (median):    13.53%
 Time  (mean ± σ):   191.509 ms ±   4.891 ms  ┊ GC (mean ± σ):  13.07% ±  2.46%

                                                ██               
  ▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁███▄▆▆▁▆▁▁▁▁▁▁▁▄ ▁
  169 ms           Histogram: frequency by time          199 ms <

 Memory estimate: 640.04 MiB, allocs estimate: 585.

CUDA:

  • Before:
julia> using CUDA, Zygote, BenchmarkTools, NNlib, NNlibCUDA

julia> x = CUDA.rand(Float32, 128, 128, 128, 16);

julia> @benchmark CUDA.@sync NNlib.upsample_bilinear(x, (2, 2))
BenchmarkTools.Trial: 951 samples with 1 evaluation.
 Range (min … max):  4.072 ms … 17.352 ms  ┊ GC (min … max): 0.00% … 9.22%
 Time  (median):     4.924 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.193 ms ±  1.168 ms  ┊ GC (mean ± σ):  4.06% ± 7.58%

    ▁▂▃▆▆▅▆█▇▃                                                
  ▆▄██████████▆▇▁▁▅▄▁▄▄▁▁▁▁▁▁▁▄▁▁▁▄▅▁▄▅▅▆▅▇▄▆█▇█▅▄▆▅▆▇▆▆▅▅▆▆ █
  4.07 ms      Histogram: log(frequency) by time     9.38 ms <

 Memory estimate: 3.67 KiB, allocs estimate: 63.

julia> @benchmark CUDA.@sync Zygote.gradient(x -> sum(NNlib.upsample_bilinear(x, (2, 2))), x)
BenchmarkTools.Trial: 189 samples with 1 evaluation.
 Range (min … max):  23.853 ms … 31.900 ms  ┊ GC (min … max): 0.00% … 11.43%
 Time  (median):     25.381 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   26.413 ms ±  2.252 ms  ┊ GC (mean ± σ):  2.73% ±  4.30%

          ▃█▇▅▃▁                                               
  ▃▃▃▅▃▇█▇██████▅▄▄▃▃▁▁▃▃▁▁▁▁▃▄▃▃▁▁▃▃▁▁▃▁▁▃▁▃▃▃▄▃▃▃▃▃▃▄▃▃▅▃▄▃ ▃
  23.9 ms         Histogram: frequency by time        31.7 ms <

 Memory estimate: 18.97 KiB, allocs estimate: 405.
  • This PR:
julia> using CUDA, Zygote, BenchmarkTools, NNlib

julia> x = CUDA.rand(Float32, 128, 128, 128, 16);

julia> @benchmark CUDA.@sync NNlib.upsample_bilinear(x, (2, 2))
BenchmarkTools.Trial: 1080 samples with 1 evaluation.
 Range (min … max):  4.039 ms … 17.729 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     4.426 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.611 ms ± 728.586 μs  ┊ GC (mean ± σ):  1.39% ± 3.29%

  ▂▄▆█▆▆▅▆▅▅▄▅▄▅▂▃▂▂▁                                         ▁
  ███████████████████▇▅▄▄▄▆▄▆▄▆▁▆▆▇▇▆▆▆▅▆▄▄▆▆▄▆▄▄▆▆▅▁▆▆▁▅▁▆▆▇ █
  4.04 ms      Histogram: log(frequency) by time      7.05 ms <

 Memory estimate: 6.52 KiB, allocs estimate: 128.

julia> @benchmark CUDA.@sync Zygote.gradient(x -> sum(NNlib.upsample_bilinear(x, (2, 2))), x)
BenchmarkTools.Trial: 199 samples with 1 evaluation.
 Range (min … max):  22.150 ms … 30.296 ms  ┊ GC (min … max): 0.00% … 5.63%
 Time  (median):     24.243 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   25.103 ms ±  1.843 ms  ┊ GC (mean ± σ):  2.07% ± 3.29%

              ▇▄▄█▂█                                           
  ▃▁▁▁▁▁▁▁▁▁▃████████▆▅▅▃▃▃▃▁▁▁▁▁▁▁▁▃▁▁▁▁▃▁▄▃▅▃▅▅▄▃▃▄▅▄▅▃▃▃▃▄ ▃
  22.2 ms         Histogram: frequency by time        29.3 ms <

 Memory estimate: 25.81 KiB, allocs estimate: 560.

@pxl-th
Member Author

pxl-th commented Apr 6, 2023

I had to specialize the kernels for the CPU & GPU backends, as they benefit from different indexing orders.
Otherwise we get pretty big regressions (applying the GPU kernel on the CPU is 4x slower).

Also, in the future we may avoid passing the backend as an argument to the kernel for CPU/GPU specialization and instead use AbstractGPUArray for the array types.
This requires, however, that XXXDeviceArray subtype AbstractGPUArray instead of AbstractArray (or DenseArray) as it does now.
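The general pattern described in this thread (recovering the backend from the array with KernelAbstractions.get_backend and launching a kernel on it) can be sketched roughly as follows. This is a simplified nearest-neighbour example with illustrative names, not the PR's actual upsampling code, and the CPU/GPU indexing-order specialization mentioned above is omitted for brevity.

```julia
using KernelAbstractions

# Each work-item writes one output pixel, reading the nearest input pixel.
@kernel function upsample_nearest_kernel!(y, @Const(x), scale)
    i, j = @index(Global, NTuple)
    @inbounds y[i, j] = x[cld(i, scale), cld(j, scale)]
end

# Backend-agnostic entry point: works for Array, CuArray, ROCArray, etc.,
# because the backend is derived from the input array itself.
function upsample_nearest2(x::AbstractMatrix, scale::Int)
    backend = KernelAbstractions.get_backend(x)
    y = similar(x, size(x, 1) * scale, size(x, 2) * scale)
    upsample_nearest_kernel!(backend)(y, x, scale; ndrange=size(y))
    KernelAbstractions.synchronize(backend)
    return y
end
```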

@pxl-th pxl-th requested a review from ToucheSir April 6, 2023 13:39
@CarloLucibello
Member

Instead of wrapping tests within functions, can we define a device(x) function (based on the available backend) and a cpu(x) function at the global level, and maintain the current style?

@pxl-th
Member Author

pxl-th commented Apr 10, 2023

> Instead of wrapping tests within functions, can we define a device(x) function (based on the available backend) and a cpu(x) function at the global level, and maintain the current style?

For device(x) to be global, we would also have to switch some global DEVICE variable, so that nnlib_testsuite(backend) can be called multiple times during the tests:

DEVICE = CPU
nnlib_testsuite(CPU) # device(x) = adapt(CPU, x)
...
DEVICE = ROCBackend
nnlib_testsuite(ROCBackend) # device(x) = adapt(ROCBackend, x)

If this is preferred, I can do it this way.
But wouldn't that make device(x) type unstable?

Alternatively, we can leave tests as functions and define device(x) there.

For now, I've added cpu(x) as a global function, and device(x) is defined inside the upsample_testsuite function.
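That approach can be sketched as follows. This is illustrative only (not the test suite's actual code) and assumes the GPU packages define Adapt rules for their backend types, so that adapt(backend, x) moves an array onto the device under test:

```julia
using Adapt, KernelAbstractions

# Hypothetical sketch: the suite takes a KernelAbstractions backend and
# defines `device(x)` as a local closure over it, avoiding a type-unstable
# global DEVICE variable.
function upsample_testsuite(backend)
    cpu(x)    = adapt(Array, x)     # bring any device array back to host
    device(x) = adapt(backend, x)   # move onto the backend under test
    x = device(rand(Float32, 4, 4, 1, 1))
    # ... run the upsample tests on `x`, comparing against results on cpu(x) ...
end
```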

@ToucheSir
Member

ToucheSir commented Apr 10, 2023

Sorry, I forgot to ask how perf compares on CPU with varying numbers of threads. We should probably have a separate PR adding upsampling to the benchmarks suite. Otherwise this LGTM

@pxl-th
Member Author

pxl-th commented Apr 10, 2023

I've optimized the CPU kernels a bit more; they are now even faster.

--threads=1:

  • Before:
julia> @benchmark NNlib.upsample_bilinear(x, (2, 2))
BenchmarkTools.Trial: 14 samples with 1 evaluation.
 Range (min … max):  364.796 ms … 380.752 ms  ┊ GC (min … max): 0.06% … 2.46%
 Time  (median):     374.844 ms               ┊ GC (median):    2.49%
 Time  (mean ± σ):   374.699 ms ±   3.776 ms  ┊ GC (mean ± σ):  2.14% ± 0.89%

                                      █ ▄                        
  ▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁█▁▁▁▁▁▆▆▆▁▁▁▁▆▁▁▁▁▁▁▁▁▆ ▁
  365 ms           Histogram: frequency by time          381 ms <

 Memory estimate: 512.00 MiB, allocs estimate: 10.

julia> @benchmark Zygote.gradient(x -> sum(NNlib.upsample_bilinear(x, (2, 2))), x)
BenchmarkTools.Trial: 7 samples with 1 evaluation.
 Range (min … max):  780.481 ms … 801.564 ms  ┊ GC (min … max): 0.13% … 2.20%
 Time  (median):     799.364 ms               ┊ GC (median):    2.09%
 Time  (mean ± σ):   796.938 ms ±   7.326 ms  ┊ GC (mean ± σ):  1.81% ± 0.75%

  ▁                                                   ▁▁█  ▁  ▁  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁███▁▁█▁▁█ ▁
  780 ms           Histogram: frequency by time          802 ms <

 Memory estimate: 768.01 MiB, allocs estimate: 238.
  • This PR:
julia> @benchmark NNlib.upsample_bilinear(x, (2, 2))
BenchmarkTools.Trial: 15 samples with 1 evaluation.
 Range (min … max):  342.028 ms … 352.197 ms  ┊ GC (min … max): 0.07% … 2.85%
 Time  (median):     351.279 ms               ┊ GC (median):    2.74%
 Time  (mean ± σ):   350.428 ms ±   2.765 ms  ┊ GC (mean ± σ):  2.40% ± 0.95%

                                                        █▄▁ ▄    
  ▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁███▁█▁▆ ▁
  342 ms           Histogram: frequency by time          352 ms <

 Memory estimate: 512.00 MiB, allocs estimate: 13.

julia> @benchmark Zygote.gradient(x -> sum(NNlib.upsample_bilinear(x, (2, 2))), x)
BenchmarkTools.Trial: 7 samples with 1 evaluation.
 Range (min … max):  752.479 ms … 768.879 ms  ┊ GC (min … max): 0.15% … 2.22%
 Time  (median):     766.059 ms               ┊ GC (median):    2.17%
 Time  (mean ± σ):   764.923 ms ±   5.610 ms  ┊ GC (mean ± σ):  1.86% ± 0.76%

  ▁                                                ▁█    ▁ ▁  ▁  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁██▁▁▁▁█▁█▁▁█ ▁
  752 ms           Histogram: frequency by time          769 ms <

 Memory estimate: 640.01 MiB, allocs estimate: 275.

--threads=16:

  • Before:
julia> using BenchmarkTools, NNlib, Zygote

julia> x = rand(Float32, 128, 128, 128, 16);

julia> @benchmark NNlib.upsample_bilinear(x, (2, 2))
BenchmarkTools.Trial: 76 samples with 1 evaluation.
 Range (min … max):  47.665 ms … 74.961 ms  ┊ GC (min … max):  0.65% … 27.56%
 Time  (median):     67.239 ms              ┊ GC (median):    26.38%
 Time  (mean ± σ):   66.386 ms ±  4.412 ms  ┊ GC (mean ± σ):  27.19% ±  5.61%

                                      █▅        ▇▆             
  ▃▁▃▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅██▆▄▁▅▁▁▁▄██▄▄▃▄▁▃▁▃▃▃▃ ▁
  47.7 ms         Histogram: frequency by time        73.9 ms <

 Memory estimate: 512.01 MiB, allocs estimate: 101.

julia> @benchmark Zygote.gradient(x -> sum(NNlib.upsample_bilinear(x, (2, 2))), x)
BenchmarkTools.Trial: 28 samples with 1 evaluation.
 Range (min … max):  154.072 ms … 191.380 ms  ┊ GC (min … max):  1.02% … 15.65%
 Time  (median):     180.883 ms               ┊ GC (median):    16.55%
 Time  (mean ± σ):   181.300 ms ±   6.365 ms  ┊ GC (mean ± σ):  15.91% ±  2.94%

                                            ██▁                  
  ▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▄███▇▇▁▄▁▁▁▁▄▇▁▁▁▁▁▇ ▁
  154 ms           Histogram: frequency by time          191 ms <

 Memory estimate: 768.03 MiB, allocs estimate: 422.
  • This PR:
julia> @benchmark NNlib.upsample_bilinear(x, (2, 2))
BenchmarkTools.Trial: 87 samples with 1 evaluation.
 Range (min … max):  47.044 ms … 68.864 ms  ┊ GC (min … max):  0.80% … 16.62%
 Time  (median):     57.800 ms              ┊ GC (median):    19.52%
 Time  (mean ± σ):   57.725 ms ±  2.507 ms  ┊ GC (mean ± σ):  18.97% ±  3.56%

                                          ▁ ▇█▂▄▁▁ ▅           
  ▃▃▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▃▅▆▆█▅████████▆▆▃▁▁▅▅▁▃ ▁
  47 ms           Histogram: frequency by time        61.2 ms <

 Memory estimate: 512.01 MiB, allocs estimate: 167.

julia> @benchmark Zygote.gradient(x -> sum(NNlib.upsample_bilinear(x, (2, 2))), x)
BenchmarkTools.Trial: 30 samples with 1 evaluation.
 Range (min … max):  149.817 ms … 181.332 ms  ┊ GC (min … max):  0.93% … 10.90%
 Time  (median):     166.550 ms               ┊ GC (median):    11.34%
 Time  (mean ± σ):   166.666 ms ±   4.378 ms  ┊ GC (mean ± σ):  11.04% ±  1.96%

                               ▃  █▁                             
  ▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄█▁▇██▇▄▄▁▄▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄ ▁
  150 ms           Histogram: frequency by time          181 ms <

 Memory estimate: 640.04 MiB, allocs estimate: 584.

@ToucheSir ToucheSir closed this Apr 10, 2023
@ToucheSir ToucheSir reopened this Apr 10, 2023
Member

@ToucheSir ToucheSir left a comment


The integration test failure looks like a hiccup, because it's gone now. Anything left to do before merging?

@pxl-th
Member Author

pxl-th commented Apr 10, 2023

No, should be good to go.

@ToucheSir ToucheSir merged commit ee909e6 into FluxML:master Apr 10, 2023
@CarloLucibello
Member

Why have the CI jobs doubled?

@pxl-th
Member Author

pxl-th commented Apr 11, 2023

I think it's because the PR was closed and re-opened to re-trigger the CI.

@pxl-th pxl-th deleted the ka-kernels branch April 11, 2023 10:38
@maxfreu
Contributor

maxfreu commented May 15, 2023

Hi @pxl-th, thanks for making these changes! I've had something like this in mind since I ported the PyTorch code, but never got around to it - good to see someone else did it :)

@pxl-th
Member Author

pxl-th commented May 16, 2023

@maxfreu, no problem! :) I plan to tackle the other kernels as well at some point.

4 participants