Possible way to implement a LoopVectorization extension for conv2d & meanpool2d & activations #540

Status: Open · wants to merge 37 commits into base: master

Commits (37):
6d32073  Add files via upload (jonas208, Sep 26, 2023)
2a3cf3e  Delete ext/NNlibLoopVectorizationExt/conv_old.jl (jonas208, Sep 26, 2023)
28a027a  Delete ext/NNlibLoopVectorizationExt/pooling_old.jl (jonas208, Sep 26, 2023)
5339aaf  Add files via upload (jonas208, Sep 26, 2023)
9e0dc6d  Add files via upload (jonas208, Sep 26, 2023)
dd0f0ed  Add files via upload (jonas208, Sep 26, 2023)
b341d1c  Add files via upload (jonas208, Sep 26, 2023)
94f7964  Update runtests.jl (jonas208, Sep 26, 2023)
6cc2e75  Add files via upload (jonas208, Sep 27, 2023)
c5c79ee  Add files via upload (jonas208, Sep 27, 2023)
132e35c  Add files via upload (jonas208, Sep 27, 2023)
aa019e9  Add files via upload (jonas208, Sep 27, 2023)
52e2a78  Add files via upload (jonas208, Sep 27, 2023)
d63f8a5  Add files via upload (jonas208, Sep 28, 2023)
ae86d13  Add files via upload (jonas208, Sep 28, 2023)
776835d  Add files via upload (jonas208, Sep 28, 2023)
13205da  Add files via upload (jonas208, Sep 28, 2023)
5850341  Add files via upload (jonas208, Sep 28, 2023)
00b28f2  Add files via upload (jonas208, Sep 28, 2023)
af04cc6  Delete runtests.jl (jonas208, Sep 28, 2023)
990a34c  Delete Project.toml (jonas208, Sep 28, 2023)
db0ad66  Add files via upload (jonas208, Sep 28, 2023)
a4e18e6  Add files via upload (jonas208, Sep 29, 2023)
f584377  Add files via upload (jonas208, Sep 30, 2023)
274db10  Add files via upload (jonas208, Sep 30, 2023)
6c33d5c  Add files via upload (jonas208, Sep 30, 2023)
7affd46  Add files via upload (jonas208, Oct 3, 2023)
3130f8a  Add files via upload (jonas208, Oct 3, 2023)
d87f909  Add files via upload (jonas208, Oct 7, 2023)
82abca8  Add files via upload (jonas208, Oct 7, 2023)
c5ec713  Delete bench_torch.py (jonas208, Oct 8, 2023)
3f1c6dc  Add files via upload (jonas208, Oct 8, 2023)
8dde5f9  Add files via upload (jonas208, Oct 8, 2023)
5505157  Add files via upload (jonas208, Oct 8, 2023)
0aa3a3f  Add files via upload (jonas208, Oct 8, 2023)
35f2b77  Add files via upload (jonas208, Oct 8, 2023)
07943d7  Add files via upload (jonas208, Oct 9, 2023)
28 changes: 20 additions & 8 deletions Project.toml
@@ -1,6 +1,6 @@
name = "NNlib"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.9.6"
version = "0.9.7"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -9,41 +9,54 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"

[extensions]
NNlibAMDGPUExt = "AMDGPU"
NNlibCUDAExt = "CUDA"
NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
NNlibLoopVectorizationExt = "LoopVectorization"

[compat]
AMDGPU = "0.5, 0.6"
Adapt = "3.2"
Atomix = "0.1"
ChainRulesCore = "1.13"
CUDA = "4, 5"
cuDNN = "1"
ChainRulesCore = "1.13"
EnzymeCore = "0.5, 0.6"
GPUArraysCore = "0.1"
KernelAbstractions = "0.9.2"
Requires = "1.0"
cuDNN = "1"
julia = "1.9"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -52,6 +65,5 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["AMDGPU", "CUDA", "ChainRulesTestUtils", "Documenter",
"FiniteDifferences", "ForwardDiff", "Logging", "ReverseDiff",
"StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN"]
test = ["AMDGPU", "BenchmarkTools", "CUDA", "ChainRulesTestUtils", "CpuId", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "LoopVectorization", "ReverseDiff", "StableRNGs", "Test", "UnicodePlots", "Zygote", "cuDNN",
"Enzyme", "EnzymeCore", "EnzymeTestUtils"]
32 changes: 32 additions & 0 deletions bench_torch.py
@@ -0,0 +1,32 @@
import torch
import torchvision.models as models
from torch.profiler import profile, record_function, ProfilerActivity
import time

model = models.efficientnet_v2_m()
model.eval()

b_size = 1
img = torch.rand(b_size, 3, 224, 224)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True, profile_memory=True) as prof:
    with record_function("model_inference"):
        pred = model(img)
    """
    with record_function("model_backward"):
        loss = torch.sum(pred - 0.5)  # dummy loss
        loss.backward()
    """

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=-1))
# print(prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=-1))

start1 = time.perf_counter()
pred = model(img)
start2 = time.perf_counter()
loss = torch.sum(pred - 0.5) # dummy loss
loss.backward()
end = time.perf_counter()
print(f"Time used inference: {start2 - start1} seconds")
print(f"Time used backward: {end - start2} seconds")
print(f"Time used inference and backward: {end - start1} seconds")
82 changes: 82 additions & 0 deletions benchmark.jl
@@ -0,0 +1,82 @@
using NNlib, Flux, Metalhead
using BenchmarkTools, Statistics
using DataFrames, CSV

forward(model, input) = model(input)

dummy_loss(output) = sum(output .- 1)

function train_step(model, input)
    ∇model, ∇input = gradient(model, input) do m, x
        dummy_loss(m(x))
    end
    return ∇model, ∇input
end

function benchmark(models, dtype, batch_sizes, channels, spatial_size)
    model_names = sort(collect(keys(models))) # make sure the models are always in the same order
    forward_times = zeros(length(model_names), length(batch_sizes))
    train_step_times = zeros(length(model_names), length(batch_sizes))

    for (i, model_name) in enumerate(model_names)
        println("Benchmarking $model_name...")
        for (j, batch_size) in enumerate(batch_sizes)

            input = rand(dtype, spatial_size..., channels, batch_size)
            model = models[model_name]

            forward(model, input) # compilation
            train_step(model, input) # compilation

            # using @belapsed (minimum time)
            #=
            forward_times[i, j] = @belapsed forward($model, $input)
            train_step_times[i, j] = @belapsed train_step($model, $input)
            =#
            # using median time
            forward_times[i, j] = median(@benchmark forward($model, $input)).time / 10^9
            train_step_times[i, j] = median(@benchmark train_step($model, $input)).time / 10^9

        end
    end

    return forward_times, train_step_times
end

# models which should be benchmarked
models = Dict(
    "ResNet18" => ResNet(18),
    "WideResNet50" => WideResNet(50),
    "DenseNet121" => DenseNet(121),
    "EfficientNet" => EfficientNet(:b0),
    "EfficientNetv2" => EfficientNetv2(:small),
    "MobileNetv3" => MobileNetv3(:small),
    # "GoogLeNet" => GoogLeNet(),
    "ConvNeXt" => ConvNeXt(:tiny),
)

# the data type and batch sizes which should be benchmarked
dtype = Float32
batch_sizes = (1, 32)
# size information (e.g. ImageNet-like images)
channels = 3
spatial_size = (224, 224) # WH

forward_times1, train_step_times1 = benchmark(models, dtype, batch_sizes, channels, spatial_size)
using LoopVectorization # load LoopVectorization here to load the lv-extension
forward_times2, train_step_times2 = benchmark(models, dtype, batch_sizes, channels, spatial_size)

df = DataFrame()
df[!, "model_names"] = sort(collect(keys(models))) # make sure the models are always in the same order

for (i, batch_size) in enumerate(batch_sizes)
    df[!, "acceleration inference, batch_size: $batch_size"] = forward_times1[:, i] ./ forward_times2[:, i]
    df[!, "acceleration train, batch_size: $batch_size"] = train_step_times1[:, i] ./ train_step_times2[:, i]

    df[!, "im2col, inference, batch_size: $batch_size"] = forward_times1[:, i]
    df[!, "lv-ext, inference, batch_size: $batch_size"] = forward_times2[:, i]
    df[!, "im2col, train, batch_size: $batch_size"] = train_step_times1[:, i]
    df[!, "lv-ext, train, batch_size: $batch_size"] = train_step_times2[:, i]
end

CSV.write("benchmark_result_julia.csv", df)
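
Reading the result CSVs below: each "acceleration" column is the im2col time divided by the lv-ext time for the same model and batch size, so values above 1 mean the LoopVectorization extension was faster; the remaining columns are absolute times in seconds.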
8 changes: 8 additions & 0 deletions benchmark_result_julia.csv
@@ -0,0 +1,8 @@
model_names,"acceleration inference, batch_size: 1","acceleration train, batch_size: 1","im2col, inference, batch_size: 1","lv-ext, inference, batch_size: 1","im2col, train, batch_size: 1","lv-ext, train, batch_size: 1","acceleration inference, batch_size: 32","acceleration train, batch_size: 32","im2col, inference, batch_size: 32","lv-ext, inference, batch_size: 32","im2col, train, batch_size: 32","lv-ext, train, batch_size: 32"
ConvNeXt,3.0131868428607564,1.3994729097036838,0.4240265,0.1407236,1.43802405,1.0275469,1.1620754379865017,1.0632876498150545,4.6366846,3.9900031,14.7683656,13.8893418
DenseNet121,2.7011062104755816,1.575624841888005,0.1855009,0.0686759,0.7069096,0.4486535,1.2534693725910626,1.036169100624124,2.6923755,2.1479388,12.6862194,12.2433871
EfficientNet,6.669989485006747,2.4963127103581892,0.49731575,0.0745602,1.33507035,0.53481695,1.121879889261884,1.1298641673853496,2.5537233,2.2762894,8.1940817,7.2522715
EfficientNetv2,16.28186773870062,7.202334907846903,2.5620854,0.1573582,12.053267,1.67352215,1.4558721609174592,1.203183521905458,6.1329556,4.21256465,21.0444893,17.4906728
MobileNetv3,12.105678302652656,1.5684538123069776,0.1103481,0.0091154,0.31291775,0.19950715,1.2884028351358188,1.1391237206595466,0.43895395,0.3406962,2.0458146,1.7959547
ResNet18,1.321074637025202,1.0621579200481972,0.0558948,0.0423101,0.2110332,0.19868345,1.0855325609238786,0.8862720054297211,0.8071219,0.7435262,2.98925695,3.3728437
WideResNet50,0.6797203960326701,0.7846926795922912,0.1863516,0.2741592,0.6916193,0.88138875,0.8693605563082452,0.7977827693085691,3.68315245,4.23662245,13.5918181,17.0369913
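
The pattern in this table: speedups are largest at batch size 1 (e.g. roughly 16x inference and 7x training for EfficientNetv2) and shrink toward parity at batch size 32, while WideResNet50 comes out slower with the extension across the board (all acceleration values below 1).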
8 changes: 8 additions & 0 deletions benchmark_result_julia_BLAS.set_num_threads(1).csv
@@ -0,0 +1,8 @@
model_names,"acceleration inference, batch_size: 1","acceleration train, batch_size: 1","im2col, inference, batch_size: 1","lv-ext, inference, batch_size: 1","im2col, train, batch_size: 1","lv-ext, train, batch_size: 1","acceleration inference, batch_size: 32","acceleration train, batch_size: 32","im2col, inference, batch_size: 32","lv-ext, inference, batch_size: 32","im2col, train, batch_size: 32","lv-ext, train, batch_size: 32"
ConvNeXt,2.3560260677286187,1.177474414020855,0.4724393,0.2005238,1.5661168,1.3300644,1.1190087278994194,1.0605556431775283,6.1761583,5.519312,20.2704419,19.1130395
DenseNet121,3.122060107720349,1.6430011190729519,0.2130831,0.0682508,0.7872384,0.4791466,1.0415097744527333,0.9663276503256204,2.2208416,2.1323291,13.0980777,13.5544892
EfficientNet,7.620280694206955,2.614653451310472,0.55973705,0.0734536,1.3205672,0.50506395,1.0814317166163767,1.4227466786836107,2.3782218,2.1991419,7.8866889,5.543284
EfficientNetv2,19.58985619614861,6.567161064233096,3.0016087,0.1532226,12.1968172,1.8572435,1.393368850971046,1.4053620350129883,5.7712793,4.1419609,21.7360648,15.4665234
MobileNetv3,11.880463635205707,1.4356279307443949,0.1165919,0.00981375,0.31690295,0.2207417,1.2244903507960017,1.0962253667265058,0.4233406,0.345728,1.9161035,1.7479102
ResNet18,1.8553921481455138,1.1963831276674246,0.0782507,0.04217475,0.26524125,0.2217026,0.7610563817484197,0.8134254105937594,0.5683152,0.7467452,3.1617936,3.8870111
WideResNet50,1.244074585655623,1.0739855756563101,0.33919395,0.2726476,1.0936016,1.0182647,0.625895474130196,0.8091050354680169,2.59841495,4.1515158,16.8332114,20.8047295
8 changes: 8 additions & 0 deletions benchmark_result_pytorch.csv
@@ -0,0 +1,8 @@
,model_names,"inference, batch_size: 1","train, batch_size: 1","inference, batch_size: 32","train, batch_size: 32"
0,ConvNeXt,0.0838077,0.3047408,1.6291157,4.6741033
1,DenseNet121,0.108188,0.2309508,1.5926049,4.6524887
2,EfficientNet,0.0661599,0.1544869,0.9558551,2.707274
3,EfficientNetv2,0.129968,0.2956851,1.830179,5.1837134
4,MobileNetv3,0.022812,0.0398323,0.2456348,0.5598582
5,ResNet18,0.0305176,0.074905,0.4218851,1.192504
6,WideResNet50,0.1575364,0.4871242,2.2386081,6.8065705
64 changes: 64 additions & 0 deletions benchmark_torch.py
@@ -0,0 +1,64 @@
import torch
import torchvision.models as visionmodels
import time
import pandas as pd

def dummy_loss(output):
    return torch.sum(output - 1)

def train_step(model, input_to_model):
    output = model(input_to_model)
    loss = dummy_loss(output)
    loss.backward()

def benchmark(models, batch_sizes, channels, spatial_size):
    model_names = sorted(list(models.keys()))  # make sure the models are always in the same order
    forward_times = torch.zeros(len(model_names), len(batch_sizes))
    train_step_times = torch.zeros(len(model_names), len(batch_sizes))

    for i, model_name in enumerate(model_names):
        print(f"Benchmarking {model_name}...")
        for j, batch_size in enumerate(batch_sizes):

            input_to_model = torch.rand(batch_size, channels, spatial_size[0], spatial_size[1])
            model = models[model_name]

            time_start = time.perf_counter()
            model(input_to_model)
            time_duration = time.perf_counter() - time_start
            forward_times[i, j] = time_duration

            time_start = time.perf_counter()
            train_step(model, input_to_model)
            time_duration = time.perf_counter() - time_start
            train_step_times[i, j] = time_duration

    return forward_times, train_step_times

models = {
    "ResNet18" : visionmodels.resnet18(),
    "WideResNet50" : visionmodels.wide_resnet50_2(),
    "DenseNet121" : visionmodels.densenet121(),
    "EfficientNet" : visionmodels.efficientnet_b0(),
    "EfficientNetv2" : visionmodels.efficientnet_v2_s(),
    "MobileNetv3" : visionmodels.mobilenet_v3_small(),
    # "GoogLeNet" : visionmodels.googlenet(),
    "ConvNeXt" : visionmodels.convnext_tiny(),
}

# the batch sizes which should be benchmarked
batch_sizes = (1, 32)
# size information (e.g. ImageNet-like images)
channels = 3
spatial_size = (224, 224) # HW

forward_times, train_step_times = benchmark(models, batch_sizes, channels, spatial_size)

df = pd.DataFrame()
df["model_names"] = sorted(list(models.keys())) # make sure the models are always in the same order

for (i, batch_size) in enumerate(batch_sizes):
    df[f"inference, batch_size: {batch_size}"] = forward_times[:, i]
    df[f"train, batch_size: {batch_size}"] = train_step_times[:, i]

df.to_csv("benchmark_result_pytorch.csv")
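
A caveat when comparing against the PyTorch numbers: this script times a single forward pass and a single train step per configuration with time.perf_counter and no warm-up iterations, whereas the Julia script takes the median over many @benchmark samples after a compilation run, so benchmark_result_pytorch.csv is only a rough cross-framework reference.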
12 changes: 12 additions & 0 deletions ext/NNlibLoopVectorizationExt/NNlibLoopVectorizationExt.jl
@@ -0,0 +1,12 @@
module NNlibLoopVectorizationExt

using NNlib
using LoopVectorization
using Random, Statistics
using OffsetArrays, Static

include("conv.jl")
include("pooling.jl")
include("activations.jl")

end # module
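
For orientation: conv.jl and pooling.jl (not shown in this excerpt) carry the @turbo-based conv2d and meanpool2d implementations named in the PR title, while activations.jl below swaps broadcasts of several NNlib activation functions for @turbo loops.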
26 changes: 26 additions & 0 deletions ext/NNlibLoopVectorizationExt/activations.jl
@@ -0,0 +1,26 @@
_tanh(x) = tanh(x)
Base.broadcasted(::typeof(tanh), x::AbstractArray) = @turbo _tanh.(x)

_softsign(x) = x / (1 + abs(x))
Base.broadcasted(::typeof(NNlib.softsign), x::AbstractArray) = @turbo _softsign.(x)

_softplus(x) = log1p(exp(-abs(x)))
Base.broadcasted(::typeof(NNlib.softplus), x::AbstractArray) = (@turbo _softplus.(x)) .+ NNlib.relu.(x)

function _sigmoid(x)
    t = exp(-abs(x))
    ifelse(x ≥ 0, inv(1 + t), t / (1 + t))
end
Base.broadcasted(::typeof(NNlib.sigmoid), x::AbstractArray) = @turbo _sigmoid.(x)
Base.broadcasted(::typeof(NNlib.sigmoid_fast), x::AbstractArray) = @turbo _sigmoid.(x) # don't do the same for tanh_fast, it would be slower

function _hardsigmoid(x)
    clamp((x + 3) / 6, 0, 1)
end
Base.broadcasted(::typeof(NNlib.hardsigmoid), x::AbstractArray) = @turbo _hardsigmoid.(x)

_logsigmoid(x) = -_softplus(-x)
Base.broadcasted(::typeof(NNlib.logsigmoid), x::AbstractArray) = @turbo _logsigmoid.(x)

_swish(x) = x * _sigmoid(x)
Base.broadcasted(::typeof(NNlib.swish), x::AbstractArray) = @turbo _swish.(x)
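
A quick way to see the activation overrides in action: broadcasting one of the functions listed above now routes through the @turbo kernels while returning the same values as the stock definitions. A small sketch (array name and size are arbitrary):

    using NNlib, LoopVectorization

    x = randn(Float32, 1024)
    y = NNlib.sigmoid.(x)              # dispatches to the @turbo override above
    @assert y ≈ 1 ./ (1 .+ exp.(-x))   # stable _sigmoid matches the naive formula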