remove typewarn
CarloLucibello committed Jan 8, 2021
1 parent 6d03787 commit 65c37c1
Showing 2 changed files with 8 additions and 12 deletions.

docs/src/performance.md — 16 changes: 8 additions & 8 deletions

@@ -13,16 +13,15 @@
 not because the operations are faster, but because the memory usage is halved.
 This means allocations occur much faster.
 And you use less memory.
-
 
 ## Preserve inputs' types
 
 Not only should your activation and loss functions be [type-stable](https://docs.julialang.org/en/v1/manual/performance-tips/#Write-%22type-stable%22-functions-1),
 they should also preserve the type of their inputs.
 
 A very artificial example using an activation function like
 
-```
-my_tanh(x) = Float64(tanh(x))
+```julia
+my_tanh(x) = Float64(tanh(x))
 ```
 
 will make performance on `Float32` input orders of magnitude slower than the normal `tanh` would be,
@@ -35,20 +34,21 @@
 you will see a large slow-down.
 This can occur sneakily, because you can cause type promotion by interacting with numeric literals.
 E.g. the following will run into the same problem as above:
 
-```
-leaky_tanh(x) = 0.01*x + tanh(x)
+```julia
+leaky_tanh(x) = 0.01*x + tanh(x)
 ```
 
 While one could change the activation function (e.g. to use `0.01f0*x`), the idiomatic (and safe) way to avoid type casts whenever the input changes is to use `oftype`:
-```
-leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)
-```
+
+```julia
+leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)
+```
 
 ## Evaluate batches as Matrices of features
 
 While it can sometimes be tempting to process your observations (feature vectors) one at a time,
 e.g.
 
 ```julia
 function loss_total(xs::AbstractVector{<:Vector}, ys::AbstractVector{<:Vector})
   sum(zip(xs, ys)) do (x, y_target)
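
The type-preservation point above is easy to check at the REPL. A minimal sketch, not part of this commit (`naive_leaky_tanh` is a hypothetical name used only for contrast):

```julia
# The 0.01 literal is a Float64, so the naive version promotes Float32
# input to Float64; `oftype` converts the literal to the input's type instead.
naive_leaky_tanh(x) = 0.01*x + tanh(x)
leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)

x = 0.5f0                                 # a Float32 input
@assert naive_leaky_tanh(x) isa Float64   # promoted to the wider type
@assert leaky_tanh(x) isa Float32         # input type preserved
```

The `x/1` in the docs' version is what makes it safe for integer inputs too: dividing by 1 yields a floating-point value, so `oftype` always targets a float type.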
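
For the batching point in the truncated hunk, a runnable sketch under stated assumptions (`model` and `loss` here are hypothetical stand-ins, not from the docs):

```julia
using Flux

model = Dense(4, 2)             # hypothetical model
loss(ŷ, y) = Flux.mse(ŷ, y)     # hypothetical loss

# Batched version: each observation is a column of a matrix, so one
# matrix-matrix multiply replaces many matrix-vector multiplies.
loss_total(xs::AbstractMatrix, ys::AbstractMatrix) = loss(model(xs), ys)

xs = rand(Float32, 4, 100)      # 100 feature vectors as columns
ys = rand(Float32, 2, 100)
loss_total(xs, ys)
```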

src/layers/basic.jl — 4 changes: 0 additions & 4 deletions

@@ -120,7 +120,6 @@ end
 @functor Dense
 
 function (a::Dense)(x::AbstractVecOrMat)
-  eltype(a.W) == eltype(x) || _dense_typewarn(a, x)
   W, b, σ = a.W, a.b, a.σ
   # reshape to handle dims > 1 as batch dimensions
   sz = size(x)
@@ -129,9 +128,6 @@ function (a::Dense)(x::AbstractVecOrMat)
   return reshape(x, :, sz[2:end]...)
 end
 
-_dense_typewarn(d, x) = @warn "Element types don't match for layer $d, this will be slow." typeof(d.W) typeof(x) maxlog=1
-Zygote.@nograd _dense_typewarn
-
 function Base.show(io::IO, l::Dense)
   print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
   l.σ == identity || print(io, ", ", l.σ)
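
For context on the removed warning, a minimal sketch of the behavior after this change (not part of the diff; it assumes Flux's default `Float32` initialization for `Dense`):

```julia
using Flux

m = Dense(3, 2)          # weights are Float32 by default (glorot_uniform)
x = rand(Float64, 3)     # input eltype does not match the weights

# Before this commit, the first such call logged _dense_typewarn (maxlog=1);
# now the mismatch is silent and the result is simply promoted to Float64.
y = m(x)
@assert eltype(y) == Float64
```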
