From aaa94fdb050e0b5333725d7d968b58b080b67df2 Mon Sep 17 00:00:00 2001 From: David Little Date: Thu, 26 Aug 2021 10:10:14 -0400 Subject: [PATCH] Only extract numeric arrays (#5) Uses `is_numeric_array` to filter out any non-numeric arrays. --- Project.toml | 2 +- src/functors.jl | 10 +++++++--- test/runtests.jl | 18 ++++++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 4cb6429..a22eec0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LegolasFlux" uuid = "eb5f792d-d1b1-4535-bae3-d5649ec7daa4" authors = ["Beacon Biosignals, Inc."] -version = "0.1.1" +version = "0.1.2" [deps] Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" diff --git a/src/functors.jl b/src/functors.jl index 175d8e3..9b98504 100644 --- a/src/functors.jl +++ b/src/functors.jl @@ -15,13 +15,17 @@ end fetch_weights(m) -> Vector{Array} Returns the weights of a model by using `Functors.children` to recurse -through the model, keeping any arrays found. The `@functor` macro defines +through the model, keeping any numeric arrays found. The `@functor` macro defines `Functors.children` automatically so that should be sufficient to support custom types. -Note that this function does not copy the results, so that e.g. mutating `fetch_weights(m)[1]` modifies the model. +Note that this function does not copy the results, so that e.g. mutating +`fetch_weights(m)[1]` modifies the model. """ -fetch_weights(m) = filter(x -> x isa Array, fcollect2(m)) +fetch_weights(m) = filter(is_numeric_array, fcollect2(m)) +is_numeric_array(x) = false +is_numeric_array(x::Array{<:Number}) = true +is_numeric_array(x::Array) = all(x -> x isa Number || is_numeric_array(x), x) """ load_weights!(m, xs) diff --git a/test/runtests.jl b/test/runtests.jl index 82ae7c3..e78cc85 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,6 +37,24 @@ end rm("my_model.model.arrow") end +struct MyArrayModel + dense_array::Array +end +Flux.@functor MyArrayModel + +@testset "Non-numeric arrays ignored" begin + m = MyArrayModel([Dense(1, 10), Dense(10, 10), Dense(10, 1)]) + weights = fetch_weights(m) + @test length(weights) == 6 + + model_row = ModelRow(; weights=collect(weights)) + write_model_row("my_model.model.arrow", model_row) + + new_model_row = read_model_row("my_model.model.arrow") + new_weights = collect(new_model_row.weights) + @test new_weights == weights +end + @testset "Errors" begin my_model = make_my_model() w = test_weights()