Merge #1393
1393: remove implicit conversions r=DhairyaLGandhi a=CarloLucibello

For some layers, we currently downcast the input type from Float64 to Float32 if the weights are Float32. 
I think we should follow Julia's promotion rules: users should provide Float32 inputs if they want Float32 outputs. 
This also simplifies layers' definitions.

This change may have a performance impact on some people's code and may go unnoticed, so while we want to promote good practice, I can understand if someone opposes this change. 
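
A minimal sketch of the behavioural change (not from the PR itself; it assumes Flux's default Float32 parameter initialisation, and `Dense(3, 2)` is just a toy layer):

```julia
using Flux

m = Dense(3, 2)              # parameters are Float32 by default
x64 = rand(Float64, 3)

eltype(m(x64))               # Float64: the input is no longer downcast to Float32
eltype(m(Float32.(x64)))     # Float32: convert explicitly to stay on the fast Float32 path
```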

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [ ] Documentation, if applicable
- [x] Final review from `@dhairyagandhi96` (for API changes).


Co-authored-by: Carlo Lucibello <[email protected]>
bors[bot] and CarloLucibello authored Jan 15, 2021
2 parents 33f99ef + 3b10434 commit 1ac78b5
Showing 5 changed files with 17 additions and 43 deletions.
3 changes: 2 additions & 1 deletion NEWS.md
@@ -2,7 +2,8 @@

## v0.12.0

* The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405).
* Dense and Conv layers no longer perform [implicit type conversion](https://github.com/FluxML/Flux.jl/pull/1394).
* Excise datasets in favour of other providers in the Julia ecosystem.
* Added option to set `bias` to [false](https://github.com/FluxML/Flux.jl/pull/1379) to exclude `bias` from being trained.
* Removed kwarg-only constructors for [`convolutional layers`](https://github.com/FluxML/Flux.jl/pull/1379).
16 changes: 8 additions & 8 deletions docs/src/performance.md
@@ -13,16 +13,15 @@ not because the operations are faster, but because the memory usage is halved.
This means allocations occur much faster, and you use less memory.


## Preserve inputs' types

Not only should your activation and loss functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
they should also preserve the type of their inputs.

A very artificial example using an activation function like

```julia
my_tanh(x) = Float64(tanh(x))
```

will result in performance on `Float32` input that is orders of magnitude slower than the normal `tanh` would be,
@@ -35,20 +34,21 @@ you will see a large slow-down.
This can occur sneakily, because you can cause type promotion by interacting with numeric literals.
E.g. the following runs into the same problem as above:

```julia
leaky_tanh(x) = 0.01*x + tanh(x)
```
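
For instance (a small illustrative check, not part of the original docs), a `Float32` input is promoted by the `Float64` literal:

```julia
x32 = rand(Float32, 5)
eltype(leaky_tanh.(x32))   # Float64: the literal 0.01 is a Float64
```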

While one could change the activation function (e.g. to use `0.01f0*x`), the idiomatic (and safe) way to avoid type casts whenever the input type changes is to use `oftype`:

```julia
leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)
```
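
A quick check (again purely illustrative) that this version preserves the input's element type:

```julia
x32 = rand(Float32, 5)
eltype(leaky_tanh.(x32))   # Float32: oftype(x/1, 0.01) matches the literal to the input's float type
```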

## Evaluate batches as Matrices of features

While it can sometimes be tempting to process your observations (feature vectors) one at a time
e.g.

```julia
function loss_total(xs::AbstractVector{<:Vector}, ys::AbstractVector{<:Vector})
sum(zip(xs, ys)) do (x, y_target)
11 changes: 1 addition & 10 deletions src/layers/basic.jl
@@ -121,9 +121,8 @@ end

function (a::Dense)(x::AbstractArray)
W, b, σ = a.W, a.b, a.σ
sz = size(x)
x = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions
x = σ.(W*x .+ b)
return reshape(x, :, sz[2:end]...)
end
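
Not part of the diff, but a small usage sketch of the multi-batch-dimension handling this reshape enables (the layer and input sizes are arbitrary):

```julia
d = Dense(10, 5)
x = rand(Float32, 10, 8, 4)   # feature dimension first, two batch-like trailing dimensions
size(d(x))                    # (5, 8, 4): trailing dimensions are passed through
```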
@@ -134,14 +133,6 @@ function Base.show(io::IO, l::Dense)
print(io, ")")
end

# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)

(a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

"""
Diagonal(in::Integer)
21 changes: 0 additions & 21 deletions src/layers/conv.jl
@@ -164,11 +164,6 @@ function Base.show(io::IO, l::Conv)
print(io, ")")
end

(a::Conv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)

(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
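
As an illustrative sketch (not from the PR; it assumes Flux's default Float32 filters and a toy WHCN input), the convolutional layers now follow the input's element type as well:

```julia
c = Conv((3, 3), 1 => 2)            # Float32 weights by default
x64 = rand(Float64, 8, 8, 1, 1)

eltype(c(x64))                      # expected Float64 under promotion rules (may take a slower generic path)
eltype(c(Float32.(x64)))            # Float32: convert the input to keep the fast path
```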

"""
ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1)
@@ -265,11 +260,6 @@ function Base.show(io::IO, l::ConvTranspose)
print(io, ")")
end

(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)

(a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

function calc_padding(::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilation, stride) where {N,T}
calc_padding(Conv, pad, k .- stride .+ 1, dilation, stride)
@@ -363,11 +353,6 @@ function Base.show(io::IO, l::DepthwiseConv)
print(io, ")")
end

(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)

(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

"""
CrossCor(filter, in => out, σ=identity; stride=1, pad=0, dilation=1)
@@ -449,12 +434,6 @@ function Base.show(io::IO, l::CrossCor)
print(io, ")")
end

(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)

(a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

"""
AdaptiveMaxPool(out::NTuple)
9 changes: 6 additions & 3 deletions test/utils.jl
@@ -145,10 +145,13 @@ end

@testset "Precision" begin
m = Chain(Dense(10, 5, relu), Dense(5, 2))
x64 = rand(Float64, 10)
x32 = rand(Float32, 10)
@test eltype(m[1].W) == Float32
@test eltype(m(x32)) == Float32
@test eltype(m(x64)) == Float64
@test eltype(f64(m)(x32)) == Float64
@test eltype(f64(m)(x64)) == Float64
@test eltype(f64(m)[1].W) == Float64
@test eltype(f32(f64(m))[1].W) == Float32
end
