Skip to content

Commit

Permalink
Only extract numeric arrays (#5)
Browse files Browse the repository at this point in the history
Uses `is_numeric_array` to filter out any non-numeric arrays.
  • Loading branch information
haberdashPI authored Aug 26, 2021
1 parent cd965c8 commit aaa94fd
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LegolasFlux"
uuid = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4"
authors = ["Beacon Biosignals, Inc."]
version = "0.1.1"
version = "0.1.2"

[deps]
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
Expand Down
10 changes: 7 additions & 3 deletions src/functors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@ end
fetch_weights(m) -> Vector{Array}
Returns the weights of a model by using `Functors.children` to recurse
through the model, keeping any arrays found. The `@functor` macro defines
through the model, keeping any numeric arrays found. The `@functor` macro defines
`Functors.children` automatically so that should be sufficient to support
custom types.
Note that this function does not copy the results, so that e.g. mutating `fetch_weights(m)[1]` modifies the model.
Note that this function does not copy the results, so that e.g. mutating
`fetch_weights(m)[1]` modifies the model.
"""
fetch_weights(m) = filter(x -> x isa Array, fcollect2(m))
fetch_weights(m) = filter(is_numeric_array, fcollect2(m))
is_numeric_array(x) = false
is_numeric_array(x::Array{<:Number}) = true
is_numeric_array(x::Array) = all(x -> x isa Number || is_numeric_array(x), x)

"""
load_weights!(m, xs)
Expand Down
18 changes: 18 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@ end
rm("my_model.model.arrow")
end

struct MyArrayModel
dense_array::Array
end
Flux.@functor MyArrayModel

@testset "Non-numeric arrays ignored" begin
m = MyArrayModel([Dense(1, 10), Dense(10, 10), Dense(10, 1)])
weights = fetch_weights(m)
@test length(weights) == 6

model_row = ModelRow(; weights=collect(weights))
write_model_row("my_model.model.arrow", model_row)

new_model_row = read_model_row("my_model.model.arrow")
new_weights = collect(new_model_row.weights)
@test new_weights == weights
end

@testset "Errors" begin
my_model = make_my_model()
w = test_weights()
Expand Down

2 comments on commit aaa94fd

@ericphanson
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/43588

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.2 -m "<description of version>" aaa94fdb050e0b5333725d7d968b58b080b67df2
git push origin v0.1.2

Please sign in to comment.