Skip to content

Commit

Permalink
Use fused conv
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 23, 2024
1 parent e40a291 commit 446463b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 23 deletions.
29 changes: 11 additions & 18 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,12 @@ function parameterlength(c::Conv{N, use_bias}) where {N, use_bias}
(use_bias ? c.out_chs : 0)
end

@inline function (c::Conv{N, false})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups)
return apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
end

@inline function (c::Conv{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups)
return apply_bias_activation(c.activation, _conv(x, ps.weight, cdims), ps.bias), st
@inline function (c::Conv)(x::AbstractArray, ps, st::NamedTuple)
cdims = DenseConvDims(x, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)
return (
fused_conv_bias_activation(
c.activation, ps.weight, x, _getproperty(ps, Val(:bias)), cdims),
st)
end

function Base.show(io::IO, l::Conv)
Expand Down Expand Up @@ -617,16 +613,13 @@ function parameterlength(c::CrossCor{N, use_bias}) where {N, use_bias}
return prod(c.kernel_size) * c.in_chs * c.out_chs + (use_bias ? c.out_chs : 0)
end

@inline function (c::CrossCor{N, false})(x::AbstractArray, ps, st::NamedTuple) where {N}
@inline function (c::CrossCor)(x::AbstractArray, ps, st::NamedTuple)
cdims = DenseConvDims(
DenseConvDims(x, ps.weight; c.stride, padding=c.pad, c.dilation); F=true)
return apply_activation(c.activation, _conv(x, ps.weight, cdims)), st
end

@inline function (c::CrossCor{N, true})(x::AbstractArray, ps, st::NamedTuple) where {N}
cdims = DenseConvDims(
DenseConvDims(x, ps.weight; c.stride, padding=c.pad, c.dilation); F=true)
return apply_bias_activation(c.activation, _conv(x, ps.weight, cdims), ps.bias), st
return (
fused_conv_bias_activation(
c.activation, ps.weight, x, _getproperty(ps, Val(:bias)), cdims),
st)
end

function Base.show(io::IO, l::CrossCor)
Expand Down
5 changes: 0 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,6 @@ end

# Backend Integration
## Convolution
@inline _conv(x, weight, cdims) = conv(x, weight, cdims)
@inline function _conv(x::SubArray{T, N, <:AbstractArray}, weight, cdims) where {T, N}
return _conv(copy(x), weight, cdims)
end

@inline _conv_transpose(x, weight, cdims) = ∇conv_data(x, weight, cdims)
@inline function _conv_transpose(
x::SubArray{T, N, <:GPUArraysCore.AnyGPUArray}, weight, cdims) where {T, N}
Expand Down

0 comments on commit 446463b

Please sign in to comment.