Skip to content

Commit

Permalink
Fixing implementation of _big_finale by adding two more methods for
Browse files Browse the repository at this point in the history
`_childarray_sum`.
  • Loading branch information
codetalker7 committed Sep 21, 2023
1 parent 4d57436 commit 918da32
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/layers/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ for T in [
end
end

function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
function _layer_show(io::IO, layer, indent::Int=1, name=nothing)
_str = isnothing(name) ? "" : "$name = "
str = _str * sprint(show, layer, context=io)
print(io, " "^indent, str, indent==0 ? "" : ",")
Expand All @@ -94,7 +94,7 @@ function _big_finale(io::IO, m)
pars = underscorise(sum(length, ps; init=0))
bytes = Base.format_bytes(Base.summarysize(m))
unique_params = IdSet()
noncnt = _childarray_sum(x -> unique_param!(x, unique_params), m) - length(ps)
noncnt = _childarray_sum(_ -> 1, m, unique_params) - length(ps)
if noncnt > 0
nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps; init=0))
printstyled(io, " "^08, "# Total: ", length(ps), " trainable arrays, "; color=:light_black)
Expand All @@ -111,21 +111,23 @@ end
_childarray_sum(f, x::AbstractArray{<:Number}) = f(x)
_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x),
init=0)
_childarray_sum(f, x::AbstractArray{<:Number}, idset::Base.IdSet) = f(x)
function _childarray_sum(f, x, idset::Base.IdSet)
isleaf(x) && return 0

# utility functions

underscorise(n::Integer) =
join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_')

function unique_param!(x::AbstractArray{<:Number}, idset::Base.IdSet)
if x in idset
0
return 0
else
push!(idset, x)
1
return sum(y -> _childarray_sum(f, y, idset), Functors.children(x), init = 0)
end
end

# utility functions

underscorise(n::Integer) =
join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_')

function _nan_show(io::IO, x)
if !isempty(x) && _all(iszero, x)
printstyled(io, " (all zero)", color=:cyan)
Expand Down

0 comments on commit 918da32

Please sign in to comment.