From 9cebf82f2a8118136639d18d0f16e5dcfb473f46 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 9 Aug 2022 19:02:21 +0100 Subject: [PATCH] Support more non-perturbable types --- Project.toml | 2 +- src/to_vec.jl | 3 ++- test/to_vec.jl | 38 +++++++++++++++++++++++++++++--------- 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index 68561cc..d4b65a0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDifferences" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.24" +version = "0.12.25" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/to_vec.jl b/src/to_vec.jl index 6dd1fbc..42608c2 100644 --- a/src/to_vec.jl +++ b/src/to_vec.jl @@ -20,6 +20,7 @@ end # Base case -- if x is already a Vector{<:Real} there's no conversion necessary. to_vec(x::Vector{<:Real}) = (x, identity) +to_vec(x::Vector{<:Bool}) = invoke(to_vec, Tuple{DenseVector}, x) # not for bool arrays # get around the constructors and make the type directly # Note this is moderately evil accessing julia's internals @@ -260,7 +261,7 @@ function to_vec(d::Dict) end # non-perturbable types -for T in (:DataType, :CartesianIndex, :AbstractZero) +for T in (:DataType, :CartesianIndex, :AbstractZero, :Bool, :Nothing, :AbstractString, :Symbol) T_from_vec = Symbol(T, :_from_vec) @eval function FiniteDifferences.to_vec(x::$T) function $T_from_vec(x_vec::Vector) diff --git a/test/to_vec.jl b/test/to_vec.jl index 4f01ff1..4149573 100644 --- a/test/to_vec.jl +++ b/test/to_vec.jl @@ -200,16 +200,32 @@ end end end - @testset "DataType" begin - test_to_vec(Float64; check_inferred=false) # isa DataType - test_to_vec(Vector; check_inferred=false) # isa UnionAll - end + @testset "Nondifferentiable types" begin + @testset "DataType" begin + test_to_vec(Float64; check_inferred=false) # isa DataType + test_to_vec(Vector; check_inferred=false) # isa UnionAll + end + + @testset "CartesianIndex" begin + test_to_vec(CartesianIndex(1)) + test_to_vec(CartesianIndex(1, 2)) + @test to_vec(CartesianIndex(1))[1] == [] + @test to_vec(CartesianIndex(1, 3))[1] == [] + end + + @testset "Bool" begin + test_to_vec(true) + @test to_vec(true)[1] == [] - @testset "CartesianIndex" begin - test_to_vec(CartesianIndex(1)) - test_to_vec(CartesianIndex(1, 2)) - @test to_vec(CartesianIndex(1))[1] == [] - @test to_vec(CartesianIndex(1, 3))[1] == [] + test_to_vec([true, false]) + @test to_vec([true, false])[1] == [] + end + + @testset "misc Base types" begin + test_to_vec(nothing) + test_to_vec("a") + test_to_vec(:b) + end end @testset "ChainRulesCore Differentials" begin @@ -260,8 +276,12 @@ end end @testset "fallback" begin + test_to_vec(ThreeFields(nothing, 1.5, false), check_inferred=false) + @test to_vec(ThreeFields(nothing, 1.5, false))[1] == [1.5] # drops the two nonpertubable fields + nested = Nested(ThreeFields(1.0, 2.0, "Three"), Singleton()) test_to_vec(nested; check_inferred=false) # map + end @testset "WrapperArray" begin