Skip to content

Commit

Permalink
upgrade warnings (#1926)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Apr 4, 2022
1 parent d317492 commit 6405ab3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# v0.12 deprecations

function ones(dims...)
Base.depwarn("Flux.ones(size...) is deprecated, please use Flux.ones32(size...) or Base.ones(Float32, size...)", :ones)
Base.depwarn("Flux.ones(size...) is deprecated, please use Flux.ones32(size...) or Base.ones(Float32, size...)", :ones, force=true)
Base.ones(Float32, dims...)
end
ones(T::Type, dims...) = Base.ones(T, dims...)

function zeros(dims...)
Base.depwarn("Flux.zeros(size...) is deprecated, please use Flux.zeros32(size...) or Base.zeros(Float32, size...)", :zeros)
Base.depwarn("Flux.zeros(size...) is deprecated, please use Flux.zeros32(size...) or Base.zeros(Float32, size...)", :zeros, force=true)
Base.zeros(Float32, dims...)
end
zeros(T::Type, dims...) = Base.zeros(T, dims...)
Expand Down
7 changes: 3 additions & 4 deletions src/losses/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ end
ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y) # should help Diffractor's broadcasting
ChainRulesCore.@scalar_rule xlogx(x) (log(y) + true)

# This can be made an error in Flux v0.13, for now just a warning
function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
for d in 1:max(ndims(ŷ), ndims(y))
if size(ŷ,d) != size(y,d)
@warn "Size mismatch in loss function! In future this will be an error. In Flux <= 0.12 broadcasting accepts this, but may not give sensible results" summary(ŷ) summary(y) maxlog=3 _id=hash(size(y))
end
size(ŷ,d) == size(y,d) || throw(DimensionMismatch(
"loss function expects size(ŷ) = $(size(ŷ)) to match size(y) = $(size(y))"
))
end
end
_check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1
Expand Down

0 comments on commit 6405ab3

Please sign in to comment.