From 20e4799466e41074cbb9743a6c9bbd63683ae9a3 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Thu, 18 Apr 2024 17:23:12 -0400 Subject: [PATCH] update constructor and format repo --- src/problems/bvp_problems.jl | 14 ++-- src/scimlfunctions.jl | 19 +++--- test/function_building_error_messages.jl | 84 ++++++++++++++++-------- 3 files changed, 76 insertions(+), 41 deletions(-) diff --git a/src/problems/bvp_problems.jl b/src/problems/bvp_problems.jl index 8db6e45a0..cecd6c088 100644 --- a/src/problems/bvp_problems.jl +++ b/src/problems/bvp_problems.jl @@ -320,7 +320,8 @@ struct SecondOrderBVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <: problem_type::PT kwargs::K - @add_kwonly function SecondOrderBVProblem{iip}(f::DynamicalBVPFunction{iip, TP}, u0, tspan, + @add_kwonly function SecondOrderBVProblem{iip}( + f::DynamicalBVPFunction{iip, TP}, u0, tspan, p = NullParameters(); problem_type = nothing, nlls = nothing, kwargs...) where {iip, TP} _u0 = prepare_initial_state(u0) @@ -331,16 +332,19 @@ struct SecondOrderBVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <: typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, problem_type, kwargs) end - function SecondOrderBVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip} + function SecondOrderBVProblem{iip}( + f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip} SecondOrderBVProblem(DynamicalBVPFunction{iip}(f, bc), u0, tspan, p; kwargs...) end end function SecondOrderBVProblem(f, bc, u0, tspan, p = NullParameters(); kwargs...) iip = isinplace(f, 5) - return SecondOrderBVProblem{iip}(DynamicalBVPFunction{iip}(f, bc), u0, tspan, p; kwargs...) + return SecondOrderBVProblem{iip}( + DynamicalBVPFunction{iip}(f, bc), u0, tspan, p; kwargs...) end -function SecondOrderBVProblem(f::DynamicalBVPFunction, u0, tspan, p = NullParameters(); kwargs...) +function SecondOrderBVProblem( + f::DynamicalBVPFunction, u0, tspan, p = NullParameters(); kwargs...) return SecondOrderBVProblem{isinplace(f)}(f, u0, tspan, p; kwargs...) -end \ No newline at end of file +end diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 32b0e11c4..becff1bf9 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2142,7 +2142,8 @@ For more details on this argument, see the ODEFunction documentation. The fields of the DynamicalBVPFunction type directly match the names of the inputs. """ -struct DynamicalBVPFunction{iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, +struct DynamicalBVPFunction{ + iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP, JP, BCJP, BCRP, SP, TW, TWt, TPJ, O, TCV, BCTCV, SYS} <: AbstractBVPFunction{iip, twopoint} f::F @@ -3763,7 +3764,8 @@ OptimizationFunction(args...; kwargs...) = OptimizationFunction{true}(args...; k function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); grad = nothing, hess = nothing, hv = nothing, - cons = nothing, cons_j = nothing, cons_h = nothing, + cons = nothing, cons_j = nothing, cons_jvp = nothing, + cons_vjp = nothing, cons_h = nothing, hess_prototype = nothing, cons_jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing, @@ -3785,7 +3787,8 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); sys = sys_or_symbolcache(sys, syms, paramsyms) OptimizationFunction{iip, typeof(adtype), typeof(f), typeof(grad), typeof(hess), typeof(hv), - typeof(cons), typeof(cons_j), typeof(cons_h), + typeof(cons), typeof(cons_j), typeof(cons_jvp), + typeof(cons_vjp), typeof(cons_h), typeof(hess_prototype), typeof(cons_jac_prototype), typeof(cons_hess_prototype), typeof(observed), @@ -3794,7 +3797,8 @@ function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD(); typeof(cons_jac_colorvec), typeof(cons_hess_colorvec), typeof(lag_hess_colorvec) }(f, adtype, grad, hess, - hv, cons, cons_j, cons_h, + hv, cons, cons_j, cons_jvp, + cons_vjp, cons_h, hess_prototype, cons_jac_prototype, cons_hess_prototype, observed, expr, cons_expr, sys, lag_h, lag_hess_prototype, hess_colorvec, cons_jac_colorvec, @@ -3992,7 +3996,6 @@ function DynamicalBVPFunction{iip, specialize, twopoint}(f, bc; colorvec = __has_colorvec(f) ? f.colorvec : nothing, bccolorvec = __has_colorvec(bc) ? bc.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize, twopoint} - if mass_matrix === I && f isa Tuple mass_matrix = ((I for i in 1:length(f))...,) end @@ -4100,7 +4103,7 @@ function DynamicalBVPFunction{iip, specialize, twopoint}(f, bc; _f = prepare_function(f) sys = something(sys, SymbolCache(syms, paramsyms, indepsym)) - + if specialize === NoSpecialize DynamicalBVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, @@ -4132,11 +4135,11 @@ function DynamicalBVPFunction{iip}(f, bc; twopoint::Union{Val, Bool} = Val(false end DynamicalBVPFunction{iip}(f::DynamicalBVPFunction, bc; kwargs...) where {iip} = f function DynamicalBVPFunction(f, bc; twopoint::Union{Val, Bool} = Val(false), kwargs...) - DynamicalBVPFunction{isinplace(f, 5), FullSpecialize, _unwrap_val(twopoint)}(f, bc; kwargs...) + DynamicalBVPFunction{isinplace(f, 5), FullSpecialize, _unwrap_val(twopoint)}( + f, bc; kwargs...) end DynamicalBVPFunction(f::DynamicalBVPFunction; kwargs...) = f - function IntegralFunction{iip, specialize}(f, integrand_prototype) where {iip, specialize} _f = prepare_function(f) IntegralFunction{iip, specialize, typeof(_f), typeof(integrand_prototype)}(_f, diff --git a/test/function_building_error_messages.jl b/test/function_building_error_messages.jl index c9773b93e..3437bf106 100644 --- a/test/function_building_error_messages.jl +++ b/test/function_building_error_messages.jl @@ -706,24 +706,35 @@ DynamicalBVPFunction(dbfiip, dbciip, jac = dbjac, bcjac = dbcjac) DynamicalBVPFunction(dbfoop, dbcoop, jac = dbjac, bcjac = dbcjac) dbWfact(du, u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, Wfact = dbWfact) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, Wfact = dbWfact) dbWfact(du, u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, Wfact = dbWfact) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, Wfact = dbWfact) dbWfact(du, u, p, gamma, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, Wfact = dbWfact) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfiip, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, Wfact = dbWfact) dbWfact(ddu, du, u, p, gamma, t) = [1.0] DynamicalBVPFunction(dbfiip, dbciip, Wfact = dbWfact) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, Wfact = dbWfact) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, Wfact = dbWfact) dbWfact_t(du, u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, Wfact_t = dbWfact_t) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, Wfact_t = dbWfact_t) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, Wfact_t = dbWfact_t) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, Wfact_t = dbWfact_t) dbWfact_t(du, u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, Wfact_t = dbWfact_t) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, Wfact_t = dbWfact_t) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, Wfact_t = dbWfact_t) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, Wfact_t = dbWfact_t) dbWfact_t(du, u, p, gamma, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, @@ -738,18 +749,25 @@ DynamicalBVPFunction(dbfiip, dbciip, Wfact_t = dbWfact_t) Wfact_t = dbWfact_t) dbtgrad(du, u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, tgrad = dbtgrad) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, tgrad = dbtgrad) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, tgrad = dbtgrad) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, tgrad = dbtgrad) dbtgrad(du, u, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, tgrad = dbtgrad) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, tgrad = dbtgrad) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfiip, dbciip, tgrad = dbtgrad) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, tgrad = dbtgrad) dbtgrad(ddu, du, u, p, t) = [1.0] DynamicalBVPFunction(dbfiip, dbciip, tgrad = dbtgrad) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, tgrad = dbtgrad) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, tgrad = dbtgrad) dbparamjac(du, u, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, paramjac = dbparamjac) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, paramjac = dbparamjac) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, paramjac = dbparamjac) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, paramjac = dbparamjac) dbparamjac(du, u, p, t) = [1.0] @test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, @@ -764,25 +782,35 @@ DynamicalBVPFunction(dbfiip, dbciip, paramjac = dbparamjac) paramjac = dbparamjac) dbjvp(du, u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, jvp = dbjvp) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, jvp = dbjvp) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, jvp = dbjvp) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, jvp = dbjvp) dbjvp(du, u, v, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, jvp = dbjvp) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, jvp = dbjvp) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfiip, dbciip, jvp = dbjvp) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, jvp = dbjvp) dbjvp(ddu, du, u, v, p, t) = [1.0] DynamicalBVPFunction(dbfiip, dbciip, jvp = dbjvp) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, jvp = dbjvp) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, jvp = dbjvp) dbvjp(du, u, p, t) = [1.0] -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfiip, dbciip, vjp = dbvjp) -@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction(dbfoop, dbciip, vjp = dbvjp) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfiip, dbciip, vjp = dbvjp) +@test_throws SciMLBase.TooFewArgumentsError DynamicalBVPFunction( + dbfoop, dbciip, vjp = dbvjp) dbvjp(du, u, v, p, t) = [1.0] -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfiip, dbciip, vjp = dbvjp) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, vjp = dbvjp) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfiip, dbciip, vjp = dbvjp) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, vjp = dbvjp) dbvjp(ddu, du, u, v, p, t) = [1.0] DynamicalBVPFunction(dbfiip, dbciip, vjp = dbvjp) -@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction(dbfoop, dbciip, vjp = dbvjp) +@test_throws SciMLBase.NonconformingFunctionsError DynamicalBVPFunction( + dbfoop, dbciip, vjp = dbvjp) # IntegralFunction