From 422691344e99d4d63ca213922da90e42343e2332 Mon Sep 17 00:00:00 2001 From: Lilith Hafner Date: Sat, 23 Sep 2023 16:19:40 -0500 Subject: [PATCH] Fix numargs for ComposedFunction --- src/utils.jl | 2 ++ test/function_building_error_messages.jl | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index d013c9ac9d..7f5f4b8246 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -27,6 +27,8 @@ function numargs(f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{ (length(T),) end +numargs(f::ComposedFunction) = numargs(f.inner) + """ $(SIGNATURES) diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index 269985a397..924c342e51 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -7,10 +7,11 @@ function test_num_args() numpar = SciMLBase.numargs(f) # Should be [1,2] g = (x, y) -> x^2 numpar2 = SciMLBase.numargs(g) # [2] + numpar3 = SciMLBase.numargs(sqrt ∘ g) # [2] @show numpar, minimum(numpar) == 1, maximum(numpar) == 2 minimum(numpar) == 1 && maximum(numpar) == 2 && maximum(numpar2) == 2 && - minimum(numpar2) == 2 + only(numpar3) == 2 end @test test_num_args() @@ -614,7 +615,7 @@ bvjp(u, v, p, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, vjp = bvjp) bvjp(du, u, v, p, t) = [1.0] BVPFunction(bfiip, bciip, vjp = bvjp) - + @test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, vjp = bvjp) # IntegralFunction