diff --git a/Project.toml b/Project.toml index fd29a0e..b739289 100644 --- a/Project.toml +++ b/Project.toml @@ -4,10 +4,12 @@ authors = ["Mike J Innes "] version = "0.4.1" [deps] +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [compat] +ConstructionBase = "1.4" Documenter = "0.27" julia = "1.6" diff --git a/src/Functors.jl b/src/Functors.jl index ca04dfd..c1d3e0a 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -1,5 +1,6 @@ module Functors +using ConstructionBase: constructorof export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect include("functor.jl") diff --git a/src/functor.jl b/src/functor.jl index d21ff8e..b29a39f 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -18,7 +18,7 @@ function functor(T, x) if isempty(names) return (), _ -> x end - S = T.name.wrapper # remove parameters from parametric types + S = constructorof(T) # remove parameters from parametric types and support anonymous functions vals = ntuple(i -> getfield(x, names[i]), length(names)) return NamedTuple{names}(vals), y -> S(y...) end diff --git a/test/basics.jl b/test/basics.jl index 1fca9c7..bb94f1b 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -148,6 +148,17 @@ end @test_throws Exception functor(NamedTuple{(:x, :y)}, (z=33, x=1)) end +@testset "anonymous functions" begin + model = let W = rand(2,2), b = ones(2) + x -> tanh.(W*x .+ b) + end + newmodel = fmap(zero, model) + @test newmodel isa Function + @test newmodel([1,2]) == [0,0] + @test newmodel.W == [0 0; 0 0] + @test newmodel.b == [0, 0] +end + ### ### Extras ###