From 90789c069cf072ec31e7e45a2b44871ac2de8ca5 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 7 Nov 2023 23:10:21 -0500 Subject: [PATCH 01/13] WIP --- Project.toml | 2 ++ src/ModelingToolkit.jl | 1 + src/systems/diffeqs/odesystem.jl | 3 ++- src/systems/unit_check.jl | 42 ++++++++++++++++++++++++++++++++ src/systems/validation.jl | 32 +++++++++++++++--------- test/runtests.jl | 9 ++++++- test/units.jl | 2 +- 7 files changed, 76 insertions(+), 15 deletions(-) create mode 100644 src/systems/unit_check.jl diff --git a/Project.toml b/Project.toml index 69d50d7ba8..5382757b5e 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" +DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -71,6 +72,7 @@ DiffRules = "0.1, 1.0" Distributions = "0.23, 0.24, 0.25" DocStringExtensions = "0.7, 0.8, 0.9" DomainSets = "0.6" +DynamicQuantities = "0.8" ForwardDiff = "0.10.3" FunctionWrappersWrappers = "0.1" Graphs = "1.5.2" diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 99bf0b015b..a82110b45d 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -152,6 +152,7 @@ include("systems/pde/pdesystem.jl") include("systems/sparsematrixclil.jl") include("systems/discrete_system/discrete_system.jl") +include("systems/unit_check.jl") include("systems/validation.jl") include("systems/dependency_graphs.jl") include("clock.jl") diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 77975c91d9..508cb8cfe5 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -163,7 +163,8 @@ struct ODESystem <: AbstractODESystem check_equations(equations(cevents), iv) end if checks == true || (checks & CheckUnits) > 0 - all_dimensionless([dvs; ps; iv]) || check_units(deqs) + u = __get_unit_type(dvs, ps, iv) + check_units(u, deqs) end new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching, diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl new file mode 100644 index 0000000000..902501f9ea --- /dev/null +++ b/src/systems/unit_check.jl @@ -0,0 +1,42 @@ +import DynamicQuantities +const DQ = DynamicQuantities + +struct ValidationError <: Exception + message::String +end + +check_units(::Nothing, _...) = true + +__get_literal_unit(x) = getmetadata(x, VariableUnit, nothing) +function __get_unit_type(vs′...) + vs = Iterators.flatten(vs′) + for v in vs + u = __get_literal_unit(v) + if u isa DQ.AbstractQuantity + return Val(:Unitful) + else + return Val(:DynamicQuantities) + end + end + return nothing +end + +function check_units(::Val{:DynamicQuantities}, eqs...) + validate(eqs...) || + throw(ValidationError("Some equations had invalid units. See warnings for details.")) +end + +function screen_units(result) + if result isa DQ.AbstractQuantity + d = DQ.dimension(result) + if d isa DQ.Dimensions + return result + elseif d isa DQ.SymbolicDimensions + throw(ValidationError("$result uses SymbolicDimensions, please use `u\"m\"` to instantiate SI unit only.")) + else + throw(ValidationError("$result doesn't use SI unit, please use `u\"m\"` to instantiate SI unit only.")) + end + end +end + + diff --git a/src/systems/validation.jl b/src/systems/validation.jl index d3ce0ea9a4..ce15b68a1f 100644 --- a/src/systems/validation.jl +++ b/src/systems/validation.jl @@ -1,10 +1,12 @@ +module UnitfulUnitCheck + +using .ModelingToolkit, Symbolics, SciMLBase +using .ModelingToolkit: ValidationError +const MT = ModelingToolkit + Base.:*(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y Base.:/(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x / y -struct ValidationError <: Exception - message::String -end - """ Throw exception on invalid unit types, otherwise return argument. """ @@ -60,7 +62,11 @@ get_literal_unit(x) = screen_unit(getmetadata(x, VariableUnit, unitless)) function get_unit(op, args) # Fallback result = op(1 .* get_unit.(args)...) try - unit(result) + if result isa DQ.AbstractQuantity + oneunit(result) + else + unit(result) + end catch throw(ValidationError("Unable to get unit for operation $op with arguments $args.")) end @@ -211,15 +217,15 @@ function _validate(conn::Connection; info::String = "") valid end -function validate(jump::Union{ModelingToolkit.VariableRateJump, - ModelingToolkit.ConstantRateJump}, t::Symbolic; +function validate(jump::Union{MT.VariableRateJump, + MT.ConstantRateJump}, t::Symbolic; info::String = "") newinfo = replace(info, "eq." => "jump") _validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units validate(jump.affect!, info = newinfo) end -function validate(jump::ModelingToolkit.MassActionJump, t::Symbolic; info::String = "") +function validate(jump::MT.MassActionJump, t::Symbolic; info::String = "") left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols net_symbols = [x[1] for x in jump.net_stoch] all_symbols = vcat(left_symbols, net_symbols) @@ -235,18 +241,18 @@ function validate(jumps::ArrayPartition{<:Union{Any, Vector{<:JumpType}}}, t::Sy all([validate(jumps.x[idx], t, info = labels[idx]) for idx in 1:3]) end -function validate(eq::ModelingToolkit.Equation; info::String = "") +function validate(eq::MT.Equation; info::String = "") if typeof(eq.lhs) == Connection _validate(eq.rhs; info) else _validate([eq.lhs, eq.rhs], ["left", "right"]; info) end end -function validate(eq::ModelingToolkit.Equation, +function validate(eq::MT.Equation, term::Union{Symbolic, Unitful.Quantity, Num}; info::String = "") _validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info) end -function validate(eq::ModelingToolkit.Equation, terms::Vector; info::String = "") +function validate(eq::MT.Equation, terms::Vector; info::String = "") _validate(vcat([eq.lhs, eq.rhs], terms), vcat(["left", "right"], "noise #" .* string.(1:length(terms))); info) end @@ -273,8 +279,10 @@ validate(term::Symbolics.SymbolicUtils.Symbolic) = safe_get_unit(term, "") !== n """ Throws error if units of equations are invalid. """ -function check_units(eqs...) +function MT.check_units(::Val{:Unitful}, eqs...) validate(eqs...) || throw(ValidationError("Some equations had invalid units. See warnings for details.")) end all_dimensionless(states) = all(x -> safe_get_unit(x, "") in (unitless, nothing), states) + +end # module diff --git a/test/runtests.jl b/test/runtests.jl index 5cafa89f68..39338bba00 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,7 +13,14 @@ using SafeTestsets, Test @safetestset "Clock Test" include("clock.jl") @safetestset "DiscreteSystem Test" include("discretesystem.jl") @safetestset "ODESystem Test" include("odesystem.jl") -@safetestset "Unitful Quantities Test" include("units.jl") +@safetestset "Dynamic Quantities Test" begin + using DynamicQuantities + include("units.jl") +end +@safetestset "Unitful Quantities Test" begin + using Unitful + include("units.jl") +end @safetestset "LabelledArrays Test" include("labelledarrays.jl") @safetestset "Mass Matrix Test" include("mass_matrix.jl") @safetestset "SteadyStateSystem Test" include("steadystatesystems.jl") diff --git a/test/units.jl b/test/units.jl index 9abe428cd2..2442d71831 100644 --- a/test/units.jl +++ b/test/units.jl @@ -1,4 +1,4 @@ -using ModelingToolkit, Unitful, OrdinaryDiffEq, JumpProcesses, IfElse +using ModelingToolkit, OrdinaryDiffEq, JumpProcesses, IfElse using Test MT = ModelingToolkit @parameters τ [unit = u"ms"] γ From b1d736dca019566d311ed4b0087c2be960a104f8 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Thu, 9 Nov 2023 13:39:02 -0500 Subject: [PATCH 02/13] Move Unitful checks to a separate module --- src/systems/diffeqs/sdesystem.jl | 3 +- src/systems/jumps/jumpsystem.jl | 3 +- src/systems/nonlinear/nonlinearsystem.jl | 3 +- .../optimization/constraints_system.jl | 3 +- .../optimization/optimizationsystem.jl | 3 +- src/systems/pde/pdesystem.jl | 3 +- src/systems/unit_check.jl | 28 +++++--- src/systems/validation.jl | 13 ++-- test/runtests.jl | 8 +-- test/units.jl | 65 ++++++++++--------- 10 files changed, 73 insertions(+), 59 deletions(-) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index e584857fd8..5f3f9fb00f 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -134,7 +134,8 @@ struct SDESystem <: AbstractODESystem check_equations(equations(cevents), iv) end if checks == true || (checks & CheckUnits) > 0 - all_dimensionless([dvs; ps; iv]) || check_units(deqs, neqs) + u = __get_unit_type(dvs, ps, iv) + check_units(u, deqs, neqs) end new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index df50834903..3633c9c7b3 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -112,7 +112,8 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem check_parameters(ps, iv) end if checks == true || (checks & CheckUnits) > 0 - all_dimensionless([states; ps; iv]) || check_units(ap, iv) + u = __get_unit_type(states, ps, iv) + check_units(u, ap, iv) end new{U}(tag, ap, iv, states, ps, var_to_name, observed, name, systems, defaults, connector_type, devents, metadata, gui_metadata, complete) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 0eca0ce7d3..8ed1e198f8 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -88,7 +88,8 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem tearing_state = nothing, substitutions = nothing, complete = false, parent = nothing; checks::Union{Bool, Int} = true) if checks == true || (checks & CheckUnits) > 0 - all_dimensionless([states; ps]) || check_units(eqs) + u = __get_unit_type(states, ps) + check_units(u, eqs) end new(tag, eqs, states, ps, var_to_name, observed, jac, name, systems, defaults, connector_type, metadata, gui_metadata, tearing_state, substitutions, complete, diff --git a/src/systems/optimization/constraints_system.jl b/src/systems/optimization/constraints_system.jl index 46def83701..4e4a1ed6f1 100644 --- a/src/systems/optimization/constraints_system.jl +++ b/src/systems/optimization/constraints_system.jl @@ -77,7 +77,8 @@ struct ConstraintsSystem <: AbstractTimeIndependentSystem tearing_state = nothing, substitutions = nothing; checks::Union{Bool, Int} = true) if checks == true || (checks & CheckUnits) > 0 - all_dimensionless([states; ps]) || check_units(constraints) + u = __get_unit_type(states, ps) + check_units(u, constraints) end new(tag, constraints, states, ps, var_to_name, observed, jac, name, systems, defaults, diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index e1ab36c9ae..fe6768daed 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -68,7 +68,8 @@ struct OptimizationSystem <: AbstractOptimizationSystem if checks == true || (checks & CheckUnits) > 0 unwrap(op) isa Symbolic && check_units(op) check_units(observed) - all_dimensionless([states; ps]) || check_units(constraints) + u = __get_unit_type(states, ps) + check_units(u, constraints) end new(tag, op, states, ps, var_to_name, observed, constraints, name, systems, defaults, metadata, gui_metadata, complete, diff --git a/src/systems/pde/pdesystem.jl b/src/systems/pde/pdesystem.jl index f0d8ec3869..50e01e907c 100644 --- a/src/systems/pde/pdesystem.jl +++ b/src/systems/pde/pdesystem.jl @@ -98,7 +98,8 @@ struct PDESystem <: ModelingToolkit.AbstractMultivariateSystem checks::Union{Bool, Int} = true, name) if checks == true || (checks & CheckUnits) > 0 - all_dimensionless([dvs; ivs; ps]) || check_units(eqs) + u = __get_unit_type(dvs, ivs, ps) + check_units(u, deqs) end eqs = eqs isa Vector ? eqs : [eqs] diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index 902501f9ea..e13f382f84 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -1,4 +1,4 @@ -import DynamicQuantities +import DynamicQuantities, Unitful const DQ = DynamicQuantities struct ValidationError <: Exception @@ -8,14 +8,26 @@ end check_units(::Nothing, _...) = true __get_literal_unit(x) = getmetadata(x, VariableUnit, nothing) +function __get_scalar_unit_type(v) + u = __get_literal_unit(v) + if u isa DQ.AbstractQuantity + return Val(:DynamicQuantities) + elseif u isa Unitful.Unitlike + return Val(:Unitful) + end + return nothing +end function __get_unit_type(vs′...) - vs = Iterators.flatten(vs′) - for v in vs - u = __get_literal_unit(v) - if u isa DQ.AbstractQuantity - return Val(:Unitful) + for vs in vs′ + if vs isa AbstractVector + for v in vs + u = __get_scalar_unit_type(v) + u === nothing || return u + end else - return Val(:DynamicQuantities) + v = vs + u = __get_scalar_unit_type(v) + u === nothing || return u end end return nothing @@ -38,5 +50,3 @@ function screen_units(result) end end end - - diff --git a/src/systems/validation.jl b/src/systems/validation.jl index ce15b68a1f..0acab06281 100644 --- a/src/systems/validation.jl +++ b/src/systems/validation.jl @@ -1,7 +1,9 @@ module UnitfulUnitCheck -using .ModelingToolkit, Symbolics, SciMLBase -using .ModelingToolkit: ValidationError +using ..ModelingToolkit, Symbolics, SciMLBase, Unitful, IfElse, RecursiveArrayTools +using ..ModelingToolkit: ValidationError, + ModelingToolkit, Connection, instream, JumpType, VariableUnit, get_systems +using Symbolics: Symbolic, value, issym, isadd, ismul, ispow const MT = ModelingToolkit Base.:*(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y @@ -62,11 +64,7 @@ get_literal_unit(x) = screen_unit(getmetadata(x, VariableUnit, unitless)) function get_unit(op, args) # Fallback result = op(1 .* get_unit.(args)...) try - if result isa DQ.AbstractQuantity - oneunit(result) - else - unit(result) - end + unit(result) catch throw(ValidationError("Unable to get unit for operation $op with arguments $args.")) end @@ -283,6 +281,5 @@ function MT.check_units(::Val{:Unitful}, eqs...) validate(eqs...) || throw(ValidationError("Some equations had invalid units. See warnings for details.")) end -all_dimensionless(states) = all(x -> safe_get_unit(x, "") in (unitless, nothing), states) end # module diff --git a/test/runtests.jl b/test/runtests.jl index 39338bba00..d6562a09d3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,10 +13,10 @@ using SafeTestsets, Test @safetestset "Clock Test" include("clock.jl") @safetestset "DiscreteSystem Test" include("discretesystem.jl") @safetestset "ODESystem Test" include("odesystem.jl") -@safetestset "Dynamic Quantities Test" begin - using DynamicQuantities - include("units.jl") -end +#@safetestset "Dynamic Quantities Test" begin +# using DynamicQuantities +# include("units.jl") +#end @safetestset "Unitful Quantities Test" begin using Unitful include("units.jl") diff --git a/test/units.jl b/test/units.jl index 2442d71831..9135a5b51e 100644 --- a/test/units.jl +++ b/test/units.jl @@ -1,51 +1,52 @@ using ModelingToolkit, OrdinaryDiffEq, JumpProcesses, IfElse using Test MT = ModelingToolkit +UMT = ModelingToolkit.UnitfulUnitCheck @parameters τ [unit = u"ms"] γ @variables t [unit = u"ms"] E(t) [unit = u"kJ"] P(t) [unit = u"MW"] D = Differential(t) #This is how equivalent works: -@test MT.equivalent(u"MW", u"kJ/ms") -@test !MT.equivalent(u"m", u"cm") -@test MT.equivalent(MT.get_unit(P^γ), MT.get_unit((E / τ)^γ)) +@test UMT.equivalent(u"MW", u"kJ/ms") +@test !UMT.equivalent(u"m", u"cm") +@test UMT.equivalent(UMT.get_unit(P^γ), UMT.get_unit((E / τ)^γ)) # Basic access -@test MT.get_unit(t) == u"ms" -@test MT.get_unit(E) == u"kJ" -@test MT.get_unit(τ) == u"ms" -@test MT.get_unit(γ) == MT.unitless -@test MT.get_unit(0.5) == MT.unitless -@test MT.get_unit(MT.SciMLBase.NullParameters()) == MT.unitless +@test UMT.get_unit(t) == u"ms" +@test UMT.get_unit(E) == u"kJ" +@test UMT.get_unit(τ) == u"ms" +@test UMT.get_unit(γ) == UMT.unitless +@test UMT.get_unit(0.5) == UMT.unitless +@test UMT.get_unit(UMT.SciMLBase.NullParameters()) == UMT.unitless # Prohibited unit types @parameters β [unit = u"°"] α [unit = u"°C"] γ [unit = 1u"s"] -@test_throws MT.ValidationError MT.get_unit(β) -@test_throws MT.ValidationError MT.get_unit(α) -@test_throws MT.ValidationError MT.get_unit(γ) +@test_throws UMT.ValidationError UMT.get_unit(β) +@test_throws UMT.ValidationError UMT.get_unit(α) +@test_throws UMT.ValidationError UMT.get_unit(γ) # Non-trivial equivalence & operators -@test MT.get_unit(τ^-1) == u"ms^-1" -@test MT.equivalent(MT.get_unit(D(E)), u"MW") -@test MT.equivalent(MT.get_unit(E / τ), u"MW") -@test MT.get_unit(2 * P) == u"MW" -@test MT.get_unit(t / τ) == MT.unitless -@test MT.equivalent(MT.get_unit(P - E / τ), u"MW") -@test MT.equivalent(MT.get_unit(D(D(E))), u"MW/ms") -@test MT.get_unit(IfElse.ifelse(t > t, P, E / τ)) == u"MW" -@test MT.get_unit(1.0^(t / τ)) == MT.unitless -@test MT.get_unit(exp(t / τ)) == MT.unitless -@test MT.get_unit(sin(t / τ)) == MT.unitless -@test MT.get_unit(sin(1u"rad")) == MT.unitless -@test MT.get_unit(t^2) == u"ms^2" +@test UMT.get_unit(τ^-1) == u"ms^-1" +@test UMT.equivalent(UMT.get_unit(D(E)), u"MW") +@test UMT.equivalent(UMT.get_unit(E / τ), u"MW") +@test UMT.get_unit(2 * P) == u"MW" +@test UMT.get_unit(t / τ) == UMT.unitless +@test UMT.equivalent(UMT.get_unit(P - E / τ), u"MW") +@test UMT.equivalent(UMT.get_unit(D(D(E))), u"MW/ms") +@test UMT.get_unit(IfElse.ifelse(t > t, P, E / τ)) == u"MW" +@test UMT.get_unit(1.0^(t / τ)) == UMT.unitless +@test UMT.get_unit(exp(t / τ)) == UMT.unitless +@test UMT.get_unit(sin(t / τ)) == UMT.unitless +@test UMT.get_unit(sin(1u"rad")) == UMT.unitless +@test UMT.get_unit(t^2) == u"ms^2" eqs = [D(E) ~ P - E / τ 0 ~ P] -@test MT.validate(eqs) +@test UMT.validate(eqs) @named sys = ODESystem(eqs) -@test !MT.validate(D(D(E)) ~ P) -@test !MT.validate(0 ~ P + E * τ) +@test !UMT.validate(D(D(E)) ~ P) +@test !UMT.validate(0 ~ P + E * τ) # Disabling unit validation/checks selectively @test_throws MT.ArgumentError ODESystem(eqs, t, [E, P, t], [τ], name = :sys) @@ -86,9 +87,9 @@ end good_eqs = [connect(p1, p2)] bad_eqs = [connect(p1, p2, op)] bad_length_eqs = [connect(op, lp)] -@test MT.validate(good_eqs) -@test !MT.validate(bad_eqs) -@test !MT.validate(bad_length_eqs) +@test UMT.validate(good_eqs) +@test !UMT.validate(bad_eqs) +@test !UMT.validate(bad_length_eqs) @named sys = ODESystem(good_eqs, t, [], []) @test_throws MT.ValidationError ODESystem(bad_eqs, t, [], []; name = :sys) @@ -136,7 +137,7 @@ noiseeqs = [0.1u"MW" 0.1u"MW" # Invalid noise matrix noiseeqs = [0.1u"MW" 0.1u"MW" 0.1u"MW" 0.1u"s"] -@test !MT.validate(eqs, noiseeqs) +@test !UMT.validate(eqs, noiseeqs) # Non-trivial simplifications @variables t [unit = u"s"] V(t) [unit = u"m"^3] L(t) [unit = u"m"] From ec3ac52e48b918102770cb3c9e297eaf8f0ed955 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Thu, 9 Nov 2023 14:00:50 -0500 Subject: [PATCH 03/13] Fix typo --- src/systems/discrete_system/discrete_system.jl | 3 ++- src/systems/pde/pdesystem.jl | 2 +- src/systems/unit_check.jl | 11 ++++++++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index cb6118dfd2..45efcf506b 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -103,7 +103,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem check_parameters(ps, iv) end if checks == true || (checks & CheckUnits) > 0 - all_dimensionless([dvs; ps; iv; ctrls]) || check_units(discreteEqs) + u = __get_unit_type(dvs, ps, iv, ctrls) + check_units(u, discreteEqs) end new(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, name, systems, diff --git a/src/systems/pde/pdesystem.jl b/src/systems/pde/pdesystem.jl index 50e01e907c..c3bea8cbec 100644 --- a/src/systems/pde/pdesystem.jl +++ b/src/systems/pde/pdesystem.jl @@ -99,7 +99,7 @@ struct PDESystem <: ModelingToolkit.AbstractMultivariateSystem name) if checks == true || (checks & CheckUnits) > 0 u = __get_unit_type(dvs, ivs, ps) - check_units(u, deqs) + check_units(u, eqs) end eqs = eqs isa Vector ? eqs : [eqs] diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index e13f382f84..8160f84352 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -7,7 +7,16 @@ end check_units(::Nothing, _...) = true -__get_literal_unit(x) = getmetadata(x, VariableUnit, nothing) +function __get_literal_unit(x) + if x isa Pair + x = x[1] + end + if !(x isa Union{Num, Symbolic}) + return nothing + end + v = value(x) + getmetadata(v, VariableUnit, nothing) +end function __get_scalar_unit_type(v) u = __get_literal_unit(v) if u isa DQ.AbstractQuantity From c0cf44a8e5ea1143ebbd9e353005fe43d5f6d8e4 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 10 Nov 2023 14:42:54 -0500 Subject: [PATCH 04/13] Add DynamicQuantities support --- src/systems/unit_check.jl | 256 +++++++++++++++++++++++++++++++++++++- src/systems/validation.jl | 7 +- test/dq_units.jl | 177 ++++++++++++++++++++++++++ test/runtests.jl | 10 +- test/units.jl | 2 +- 5 files changed, 432 insertions(+), 20 deletions(-) create mode 100644 test/dq_units.jl diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index 8160f84352..ac73bb24c9 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -1,6 +1,10 @@ import DynamicQuantities, Unitful const DQ = DynamicQuantities +#For dispatching get_unit +const Conditional = Union{typeof(ifelse), typeof(IfElse.ifelse)} +const Comparison = Union{typeof.([==, !=, ≠, <, <=, ≤, >, >=, ≥])...} + struct ValidationError <: Exception message::String end @@ -42,20 +46,260 @@ function __get_unit_type(vs′...) return nothing end -function check_units(::Val{:DynamicQuantities}, eqs...) - validate(eqs...) || - throw(ValidationError("Some equations had invalid units. See warnings for details.")) -end - -function screen_units(result) +function screen_unit(result) if result isa DQ.AbstractQuantity d = DQ.dimension(result) if d isa DQ.Dimensions + if result != oneunit(result) + throw(ValidationError("$result uses non SI unit. Please use SI unit only.")) + end return result elseif d isa DQ.SymbolicDimensions throw(ValidationError("$result uses SymbolicDimensions, please use `u\"m\"` to instantiate SI unit only.")) else throw(ValidationError("$result doesn't use SI unit, please use `u\"m\"` to instantiate SI unit only.")) end + else + throw(ValidationError("$result doesn't have any unit.")) + end +end + +const unitless = DQ.Quantity(1.0) +get_literal_unit(x) = screen_unit(something(__get_literal_unit(x), unitless)) + +""" +Find the unit of a symbolic item. +""" +get_unit(x::Real) = unitless +get_unit(x::DQ.AbstractQuantity) = screen_unit(oneunit(x)) +get_unit(x::AbstractArray) = map(get_unit, x) +get_unit(x::Num) = get_unit(unwrap(x)) +get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x) +get_unit(op::Difference, args) = get_unit(args[1]) / get_unit(op.t) +get_unit(op::typeof(getindex), args) = get_unit(args[1]) +get_unit(x::SciMLBase.NullParameters) = unitless +get_unit(op::typeof(instream), args) = get_unit(args[1]) + +function get_unit(op, args) # Fallback + result = op(get_unit.(args)...) + try + oneunit(result) + catch + throw(ValidationError("Unable to get unit for operation $op with arguments $args.")) + end +end + +function get_unit(op::Integral, args) + unit = 1 + if op.domain.variables isa Vector + for u in op.domain.variables + unit *= get_unit(u) + end + else + unit *= get_unit(op.domain.variables) + end + return oneunit(get_unit(args[1]) * unit) +end + +equivalent(x, y) = isequal(x, y) +function get_unit(op::Conditional, args) + terms = get_unit.(args) + terms[1] == unitless || + throw(ValidationError(", in $op, [$(terms[1])] is not dimensionless.")) + equivalent(terms[2], terms[3]) || + throw(ValidationError(", in $op, units [$(terms[2])] and [$(terms[3])] do not match.")) + return terms[2] +end + +function get_unit(op::typeof(Symbolics._mapreduce), args) + if args[2] == + + get_unit(args[3]) + else + throw(ValidationError("Unsupported array operation $op")) + end +end + +function get_unit(op::Comparison, args) + terms = get_unit.(args) + equivalent(terms[1], terms[2]) || + throw(ValidationError(", in comparison $op, units [$(terms[1])] and [$(terms[2])] do not match.")) + return unitless +end + +function get_unit(x::Symbolic) + if (u = __get_literal_unit(x)) !== nothing + screen_unit(u) + elseif issym(x) + get_literal_unit(x) + elseif isadd(x) + terms = get_unit.(arguments(x)) + firstunit = terms[1] + for other in terms[2:end] + termlist = join(map(repr, terms), ", ") + equivalent(other, firstunit) || + throw(ValidationError(", in sum $x, units [$termlist] do not match.")) + end + return firstunit + elseif ispow(x) + pargs = arguments(x) + base, expon = get_unit.(pargs) + @assert oneunit(expon) == unitless + if base == unitless + unitless + else + pargs[2] isa Number ? base^pargs[2] : (1 * base)^pargs[2] + end + elseif istree(x) + op = operation(x) + if issym(op) || (istree(op) && istree(operation(op))) # Dependent variables, not function calls + return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i] + elseif istree(op) && !istree(operation(op)) + gp = getmetadata(x, Symbolics.GetindexParent, nothing) # Like x[1](t) + return screen_unit(getmetadata(gp, VariableUnit, unitless)) + end # Actual function calls: + args = arguments(x) + return get_unit(op, args) + else # This function should only be reached by Terms, for which `istree` is true + throw(ArgumentError("Unsupported value $x.")) + end +end + +""" +Get unit of term, returning nothing & showing warning instead of throwing errors. +""" +function safe_get_unit(term, info) + side = nothing + try + side = get_unit(term) + catch err + if err isa DQ.DimensionError + @warn("$info: $(err.x) and $(err.y) are not dimensionally compatible.") + elseif err isa ValidationError + @warn(info*err.message) + elseif err isa MethodError + @warn("$info: no method matching $(err.f) for arguments $(typeof.(err.args)).") + else + rethrow() + end end + side +end + +function _validate(terms::Vector, labels::Vector{String}; info::String = "") + valid = true + first_unit = nothing + first_label = nothing + for (term, label) in zip(terms, labels) + equnit = safe_get_unit(term, info * label) + if equnit === nothing + valid = false + elseif !isequal(term, 0) + if first_unit === nothing + first_unit = equnit + first_label = label + elseif !equivalent(first_unit, equnit) + valid = false + @warn("$info: units [$(first_unit)] for $(first_label) and [$(equnit)] for $(label) do not match.") + end + end + end + valid +end + +function _validate(conn::Connection; info::String = "") + valid = true + syss = get_systems(conn) + sys = first(syss) + st = states(sys) + for i in 2:length(syss) + s = syss[i] + sst = states(s) + if length(st) != length(sst) + valid = false + @warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) have $(length(st)) and $(length(sst)) states, cannor connect.") + continue + end + for (i, x) in enumerate(st) + j = findfirst(isequal(x), sst) + if j == nothing + valid = false + @warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) do not have the same states.") + else + aunit = safe_get_unit(x, info * string(nameof(sys)) * "#$i") + bunit = safe_get_unit(sst[j], info * string(nameof(s)) * "#$j") + if !equivalent(aunit, bunit) + valid = false + @warn("$info: connected system states $x and $(sst[j]) have mismatched units.") + end + end + end + end + valid +end + +function validate(jump::Union{VariableRateJump, + ConstantRateJump}, t::Symbolic; + info::String = "") + newinfo = replace(info, "eq." => "jump") + _validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units + validate(jump.affect!, info = newinfo) +end + +function validate(jump::MassActionJump, t::Symbolic; info::String = "") + left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols + net_symbols = [x[1] for x in jump.net_stoch] + all_symbols = vcat(left_symbols, net_symbols) + allgood = _validate(all_symbols, string.(all_symbols); info) + n = sum(x -> x[2], jump.reactant_stoch, init = 0) + base_unitful = all_symbols[1] #all same, get first + allgood && _validate([jump.scaled_rates, 1 / (t * base_unitful^n)], + ["scaled_rates", "1/(t*reactants^$n))"]; info) +end + +function validate(jumps::ArrayPartition{<:Union{Any, Vector{<:JumpType}}}, t::Symbolic) + labels = ["in Mass Action Jumps,", "in Constant Rate Jumps,", "in Variable Rate Jumps,"] + all([validate(jumps.x[idx], t, info = labels[idx]) for idx in 1:3]) +end + +function validate(eq::Equation; info::String = "") + if typeof(eq.lhs) == Connection + _validate(eq.rhs; info) + else + _validate([eq.lhs, eq.rhs], ["left", "right"]; info) + end +end +function validate(eq::Equation, + term::Union{Symbolic, DQ.AbstractQuantity, Num}; info::String = "") + _validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info) +end +function validate(eq::Equation, terms::Vector; info::String = "") + _validate(vcat([eq.lhs, eq.rhs], terms), + vcat(["left", "right"], "noise #" .* string.(1:length(terms))); info) +end + +""" +Returns true iff units of equations are valid. +""" +function validate(eqs::Vector; info::String = "") + all([validate(eqs[idx], info = info * " in eq. #$idx") for idx in 1:length(eqs)]) +end +function validate(eqs::Vector, noise::Vector; info::String = "") + all([validate(eqs[idx], noise[idx], info = info * " in eq. #$idx") + for idx in 1:length(eqs)]) +end +function validate(eqs::Vector, noise::Matrix; info::String = "") + all([validate(eqs[idx], noise[idx, :], info = info * " in eq. #$idx") + for idx in 1:length(eqs)]) +end +function validate(eqs::Vector, term::Symbolic; info::String = "") + all([validate(eqs[idx], term, info = info * " in eq. #$idx") for idx in 1:length(eqs)]) +end +validate(term::Symbolics.SymbolicUtils.Symbolic) = safe_get_unit(term, "") !== nothing + +""" +Throws error if units of equations are invalid. +""" +function check_units(::Val{:DynamicQuantities}, eqs...) + validate(eqs...) || + throw(ValidationError("Some equations had invalid units. See warnings for details.")) end diff --git a/src/systems/validation.jl b/src/systems/validation.jl index 0acab06281..90b0745cd2 100644 --- a/src/systems/validation.jl +++ b/src/systems/validation.jl @@ -2,7 +2,8 @@ module UnitfulUnitCheck using ..ModelingToolkit, Symbolics, SciMLBase, Unitful, IfElse, RecursiveArrayTools using ..ModelingToolkit: ValidationError, - ModelingToolkit, Connection, instream, JumpType, VariableUnit, get_systems + ModelingToolkit, Connection, instream, JumpType, VariableUnit, get_systems, + Conditional, Comparison using Symbolics: Symbolic, value, issym, isadd, ismul, ispow const MT = ModelingToolkit @@ -39,10 +40,6 @@ MT = ModelingToolkit equivalent(x, y) = isequal(1 * x, 1 * y) const unitless = Unitful.unit(1) -#For dispatching get_unit -const Conditional = Union{typeof(ifelse), typeof(IfElse.ifelse)} -const Comparison = Union{typeof.([==, !=, ≠, <, <=, ≤, >, >=, ≥])...} - """ Find the unit of a symbolic item. """ diff --git a/test/dq_units.jl b/test/dq_units.jl new file mode 100644 index 0000000000..d0dc2b0bb4 --- /dev/null +++ b/test/dq_units.jl @@ -0,0 +1,177 @@ +using ModelingToolkit, OrdinaryDiffEq, JumpProcesses, IfElse, DynamicQuantities +using Test +MT = ModelingToolkit +@parameters τ [unit = u"s"] γ +@variables t [unit = u"s"] E(t) [unit = u"J"] P(t) [unit = u"W"] +D = Differential(t) + +# Basic access +@test MT.get_unit(t) == u"s" +@test MT.get_unit(E) == u"J" +@test MT.get_unit(τ) == u"s" +@test MT.get_unit(γ) == MT.unitless +@test MT.get_unit(0.5) == MT.unitless +@test MT.get_unit(MT.SciMLBase.NullParameters()) == MT.unitless + +# Prohibited unit types +@parameters γ [unit = 1u"ms"] +@test_throws MT.ValidationError MT.get_unit(γ) + +eqs = [D(E) ~ P - E / τ + 0 ~ P] +@test MT.validate(eqs) +@named sys = ODESystem(eqs) + +@test !MT.validate(D(D(E)) ~ P) +@test !MT.validate(0 ~ P + E * τ) + +# Disabling unit validation/checks selectively +@test_throws MT.ArgumentError ODESystem(eqs, t, [E, P, t], [τ], name = :sys) +ODESystem(eqs, t, [E, P, t], [τ], name = :sys, checks = MT.CheckUnits) +eqs = [D(E) ~ P - E / τ + 0 ~ P + E * τ] +@test_throws MT.ValidationError ODESystem(eqs, name = :sys, checks = MT.CheckAll) +@test_throws MT.ValidationError ODESystem(eqs, name = :sys, checks = true) +ODESystem(eqs, name = :sys, checks = MT.CheckNone) +ODESystem(eqs, name = :sys, checks = false) +@test_throws MT.ValidationError ODESystem(eqs, name = :sys, + checks = MT.CheckComponents | MT.CheckUnits) +@named sys = ODESystem(eqs, checks = MT.CheckComponents) +@test_throws MT.ValidationError ODESystem(eqs, t, [E, P, t], [τ], name = :sys, + checks = MT.CheckUnits) + +# connection validation +@connector function Pin(; name) + sts = @variables(v(t)=1.0, [unit = u"V"], + i(t)=1.0, [unit = u"A", connect = Flow]) + ODESystem(Equation[], t, sts, []; name = name) +end +@connector function OtherPin(; name) + sts = @variables(v(t)=1.0, [unit = u"mV"], + i(t)=1.0, [unit = u"mA", connect = Flow]) + ODESystem(Equation[], t, sts, []; name = name) +end +@connector function LongPin(; name) + sts = @variables(v(t)=1.0, [unit = u"V"], + i(t)=1.0, [unit = u"A", connect = Flow], + x(t)=1.0) + ODESystem(Equation[], t, sts, []; name = name) +end +@named p1 = Pin() +@named p2 = Pin() +@named op = OtherPin() +@named lp = LongPin() +good_eqs = [connect(p1, p2)] +bad_eqs = [connect(p1, p2, op)] +bad_length_eqs = [connect(op, lp)] +@test MT.validate(good_eqs) +@test !MT.validate(bad_eqs) +@test !MT.validate(bad_length_eqs) +@named sys = ODESystem(good_eqs, t, [], []) +@test_throws MT.ValidationError ODESystem(bad_eqs, t, [], []; name = :sys) + +# Array variables +@variables t [unit = u"s"] x(t)[1:3] [unit = u"m"] +@parameters v[1:3]=[1, 2, 3] [unit = u"m/s"] +D = Differential(t) +eqs = D.(x) .~ v +ODESystem(eqs, name = :sys) + +# Difference equation +@parameters t [unit = u"s"] a [unit = u"s"^-1] +@variables x(t) [unit = u"kg"] +δ = Differential(t) +D = Difference(t; dt = 0.1u"s") +eqs = [ + δ(x) ~ a * x, +] +de = ODESystem(eqs, t, [x], [a], name = :sys) + +# Nonlinear system +@parameters a [unit = u"kg"^-1] +@variables x [unit = u"kg"] +eqs = [ + 0 ~ a * x, +] +@named nls = NonlinearSystem(eqs, [x], [a]) + +# SDE test w/ noise vector +@parameters τ [unit = u"s"] Q [unit = u"W"] +@variables t [unit = u"s"] E(t) [unit = u"J"] P(t) [unit = u"W"] +D = Differential(t) +eqs = [D(E) ~ P - E / τ + P ~ Q] + +noiseeqs = [0.1u"W", + 0.1u"W"] +@named sys = SDESystem(eqs, noiseeqs, t, [P, E], [τ, Q]) + +# With noise matrix +noiseeqs = [0.1u"W" 0.1u"W" + 0.1u"W" 0.1u"W"] +@named sys = SDESystem(eqs, noiseeqs, t, [P, E], [τ, Q]) + +# Invalid noise matrix +noiseeqs = [0.1u"W" 0.1u"W" + 0.1u"W" 0.1u"s"] +@test !MT.validate(eqs, noiseeqs) + +# Non-trivial simplifications +@variables t [unit = u"s"] V(t) [unit = u"m"^3] L(t) [unit = u"m"] +@parameters v [unit = u"m/s"] r [unit = u"m"^3 / u"s"] +D = Differential(t) +eqs = [D(L) ~ v, + V ~ L^3] +@named sys = ODESystem(eqs) +sys_simple = structural_simplify(sys) + +eqs = [D(V) ~ r, + V ~ L^3] +@named sys = ODESystem(eqs) +sys_simple = structural_simplify(sys) + +@variables V [unit = u"m"^3] L [unit = u"m"] +@parameters v [unit = u"m/s"] r [unit = u"m"^3 / u"s"] t [unit = u"s"] +eqs = [V ~ r * t, + V ~ L^3] +@named sys = NonlinearSystem(eqs, [V, L], [t, r]) +sys_simple = structural_simplify(sys) + +eqs = [L ~ v * t, + V ~ L^3] +@named sys = NonlinearSystem(eqs, [V, L], [t, r]) +sys_simple = structural_simplify(sys) + +#Jump System +@parameters β [unit = u"(mol^2*s)^-1"] γ [unit = u"(mol*s)^-1"] t [unit = u"s"] jumpmol [ + unit = u"mol", +] +@variables S(t) [unit = u"mol"] I(t) [unit = u"mol"] R(t) [unit = u"mol"] +rate₁ = β * S * I +affect₁ = [S ~ S - 1 * jumpmol, I ~ I + 1 * jumpmol] +rate₂ = γ * I +affect₂ = [I ~ I - 1 * jumpmol, R ~ R + 1 * jumpmol] +j₁ = ConstantRateJump(rate₁, affect₁) +j₂ = VariableRateJump(rate₂, affect₂) +js = JumpSystem([j₁, j₂], t, [S, I, R], [β, γ], name = :sys) + +affect_wrong = [S ~ S - jumpmol, I ~ I + 1] +j_wrong = ConstantRateJump(rate₁, affect_wrong) +@test_throws MT.ValidationError JumpSystem([j_wrong, j₂], t, [S, I, R], [β, γ], name = :sys) + +rate_wrong = γ^2 * I +j_wrong = ConstantRateJump(rate_wrong, affect₂) +@test_throws MT.ValidationError JumpSystem([j₁, j_wrong], t, [S, I, R], [β, γ], name = :sys) + +# mass action jump tests for SIR model +maj1 = MassActionJump(2 * β / 2, [S => 1, I => 1], [S => -1, I => 1]) +maj2 = MassActionJump(γ, [I => 1], [I => -1, R => 1]) +@named js3 = JumpSystem([maj1, maj2], t, [S, I, R], [β, γ]) + +#Test unusual jump system +@parameters β γ t +@variables S(t) I(t) R(t) + +maj1 = MassActionJump(2.0, [0 => 1], [S => 1]) +maj2 = MassActionJump(γ, [S => 1], [S => -1]) +@named js4 = JumpSystem([maj1, maj2], t, [S], [β, γ]) diff --git a/test/runtests.jl b/test/runtests.jl index d6562a09d3..23806cae62 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,14 +13,8 @@ using SafeTestsets, Test @safetestset "Clock Test" include("clock.jl") @safetestset "DiscreteSystem Test" include("discretesystem.jl") @safetestset "ODESystem Test" include("odesystem.jl") -#@safetestset "Dynamic Quantities Test" begin -# using DynamicQuantities -# include("units.jl") -#end -@safetestset "Unitful Quantities Test" begin - using Unitful - include("units.jl") -end +@safetestset "Dynamic Quantities Test" include("dq_units.jl") +@safetestset "Unitful Quantities Test" include("units.jl") @safetestset "LabelledArrays Test" include("labelledarrays.jl") @safetestset "Mass Matrix Test" include("mass_matrix.jl") @safetestset "SteadyStateSystem Test" include("steadystatesystems.jl") diff --git a/test/units.jl b/test/units.jl index 9135a5b51e..c518af8459 100644 --- a/test/units.jl +++ b/test/units.jl @@ -1,4 +1,4 @@ -using ModelingToolkit, OrdinaryDiffEq, JumpProcesses, IfElse +using ModelingToolkit, OrdinaryDiffEq, JumpProcesses, IfElse, Unitful using Test MT = ModelingToolkit UMT = ModelingToolkit.UnitfulUnitCheck From cea83b379f4f29135d9773890cb547ea3fa05a77 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 10 Nov 2023 15:02:30 -0500 Subject: [PATCH 05/13] Migrate model parsing test to DynamicQuantities and better tests --- test/model_parsing.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/model_parsing.jl b/test/model_parsing.jl index 43eb6dd0fb..d55a90d0a8 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -4,7 +4,7 @@ using ModelingToolkit: get_gui_metadata, VariableDescription, RegularConnector using URIs: URI using Distributions -using Unitful +using DynamicQuantities, OrdinaryDiffEq ENV["MTK_ICONS_DIR"] = "$(@__DIR__)/icons" @@ -143,6 +143,10 @@ C_val = 20 R_val = 20 res__R = 100 @mtkbuild rc = RC(; C_val, R_val, resistor.R = res__R) +prob = ODEProblem(rc, [], (0, 1e9)) +sol = solve(prob, Rodas5P()) +defs = ModelingToolkit.defaults(rc) +@test sol[rc.capacitor.v, end] ≈ defs[rc.constant.k] resistor = getproperty(rc, :resistor; namespace = false) @test getname(rc.resistor) === getname(resistor) @test getname(rc.resistor.R) === getname(resistor.R) From f793c0bd5886e5a1b9957a2ac089b61ff01cddde Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 10 Nov 2023 16:16:56 -0500 Subject: [PATCH 06/13] Fix `check_units` for optimization systems --- src/systems/optimization/optimizationsystem.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index fe6768daed..c02bf33809 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -66,9 +66,9 @@ struct OptimizationSystem <: AbstractOptimizationSystem gui_metadata = nothing, complete = false, parent = nothing; checks::Union{Bool, Int} = true) if checks == true || (checks & CheckUnits) > 0 - unwrap(op) isa Symbolic && check_units(op) - check_units(observed) u = __get_unit_type(states, ps) + unwrap(op) isa Symbolic && check_units(u, op) + check_units(u, observed) check_units(u, constraints) end new(tag, op, states, ps, var_to_name, observed, From ffd8a5f42744274e6229e4b22ea4dbd4f3cbd0c1 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Fri, 10 Nov 2023 19:08:38 -0500 Subject: [PATCH 07/13] Fix constants check --- test/constants.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/constants.jl b/test/constants.jl index 6e60434201..ee9c9fa618 100644 --- a/test/constants.jl +++ b/test/constants.jl @@ -1,6 +1,7 @@ using ModelingToolkit, OrdinaryDiffEq, Unitful using Test MT = ModelingToolkit +UMT = ModelingToolkit.UnitfulUnitCheck @constants a = 1 @test_throws MT.ArgumentError @constants b @@ -25,7 +26,7 @@ simp = structural_simplify(sys) #Constant with units @constants β=1 [unit = u"m/s"] -MT.get_unit(β) +UMT.get_unit(β) @test MT.isconstant(β) @variables t [unit = u"s"] x(t) [unit = u"m"] D = Differential(t) From 1a8b7235423a40180ad3a88a525b09e1213a4406 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Tue, 2 Jan 2024 15:08:43 -0500 Subject: [PATCH 08/13] Ignore iv in `__get_unit_type` --- src/systems/diffeqs/odesystem.jl | 2 +- src/systems/diffeqs/sdesystem.jl | 2 +- src/systems/discrete_system/discrete_system.jl | 2 +- src/systems/jumps/jumpsystem.jl | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 508cb8cfe5..c453625ac3 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -163,7 +163,7 @@ struct ODESystem <: AbstractODESystem check_equations(equations(cevents), iv) end if checks == true || (checks & CheckUnits) > 0 - u = __get_unit_type(dvs, ps, iv) + u = __get_unit_type(dvs, ps) check_units(u, deqs) end new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 5f3f9fb00f..abff555682 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -134,7 +134,7 @@ struct SDESystem <: AbstractODESystem check_equations(equations(cevents), iv) end if checks == true || (checks & CheckUnits) > 0 - u = __get_unit_type(dvs, ps, iv) + u = __get_unit_type(dvs, ps) check_units(u, deqs, neqs) end new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 45efcf506b..b0c57135cd 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -103,7 +103,7 @@ struct DiscreteSystem <: AbstractTimeDependentSystem check_parameters(ps, iv) end if checks == true || (checks & CheckUnits) > 0 - u = __get_unit_type(dvs, ps, iv, ctrls) + u = __get_unit_type(dvs, ps, ctrls) check_units(u, discreteEqs) end new(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, name, diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index 3633c9c7b3..bd2b22574d 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -112,7 +112,7 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem check_parameters(ps, iv) end if checks == true || (checks & CheckUnits) > 0 - u = __get_unit_type(states, ps, iv) + u = __get_unit_type(states, ps) check_units(u, ap, iv) end new{U}(tag, ap, iv, states, ps, var_to_name, observed, name, systems, defaults, From 8f09054f1d2231293039184b2e7ca66a62ded281 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 3 Jan 2024 08:18:52 -0500 Subject: [PATCH 09/13] Update odesystem.jl --- test/odesystem.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/odesystem.jl b/test/odesystem.jl index c4a09b6ecd..9bcbbfc130 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -542,8 +542,8 @@ difference_cb = ModelingToolkit.PeriodicCallback(periodic_difference_affect!, 0. sol2 = solve(prob2, Tsit5(); callback = difference_cb, tstops = collect(prob.tspan[1]:0.1:prob.tspan[2])[2:end], verbose = false) -@test sol(0:0.01:1)[x] ≈ sol2(0:0.01:1)[1, :] -@test sol(0:0.01:1)[y] ≈ sol2(0:0.01:1)[2, :] +@test_broken sol(0:0.01:1)[x] ≈ sol2(0:0.01:1)[1, :] +@test_broken sol(0:0.01:1)[y] ≈ sol2(0:0.01:1)[2, :] using ModelingToolkit From 3e474bd52b43589532656ef3f183bd11a19a0137 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 3 Jan 2024 12:28:00 -0500 Subject: [PATCH 10/13] Eagerly screen units --- Project.toml | 2 +- src/systems/unit_check.jl | 3 ++- test/dq_units.jl | 7 +------ 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 88aceb0055..7aa31fb600 100644 --- a/Project.toml +++ b/Project.toml @@ -73,7 +73,7 @@ Distributed = "1" Distributions = "0.23, 0.24, 0.25" DocStringExtensions = "0.7, 0.8, 0.9" DomainSets = "0.6" -DynamicQuantities = "0.8" +DynamicQuantities = "0.8, 0.9, 0.10" ForwardDiff = "0.10.3" FunctionWrappersWrappers = "0.1" Graphs = "1.5.2" diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index ac73bb24c9..02567fa6a9 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -19,7 +19,8 @@ function __get_literal_unit(x) return nothing end v = value(x) - getmetadata(v, VariableUnit, nothing) + u = getmetadata(v, VariableUnit, nothing) + u === nothing ? nothing : screen_unit(u) end function __get_scalar_unit_type(v) u = __get_literal_unit(v) diff --git a/test/dq_units.jl b/test/dq_units.jl index d0dc2b0bb4..406292a344 100644 --- a/test/dq_units.jl +++ b/test/dq_units.jl @@ -59,16 +59,11 @@ end end @named p1 = Pin() @named p2 = Pin() -@named op = OtherPin() +@test_throws MT.ValidationError @named op = OtherPin() @named lp = LongPin() good_eqs = [connect(p1, p2)] -bad_eqs = [connect(p1, p2, op)] -bad_length_eqs = [connect(op, lp)] @test MT.validate(good_eqs) -@test !MT.validate(bad_eqs) -@test !MT.validate(bad_length_eqs) @named sys = ODESystem(good_eqs, t, [], []) -@test_throws MT.ValidationError ODESystem(bad_eqs, t, [], []; name = :sys) # Array variables @variables t [unit = u"s"] x(t)[1:3] [unit = u"m"] From fc87dd65ab3bbb48c604bf72fff98af228567bde Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 3 Jan 2024 14:20:57 -0500 Subject: [PATCH 11/13] Revert "Ignore iv in `__get_unit_type`" This reverts commit 1a8b7235423a40180ad3a88a525b09e1213a4406. --- src/systems/diffeqs/odesystem.jl | 2 +- src/systems/diffeqs/sdesystem.jl | 2 +- src/systems/discrete_system/discrete_system.jl | 2 +- src/systems/jumps/jumpsystem.jl | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index c453625ac3..508cb8cfe5 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -163,7 +163,7 @@ struct ODESystem <: AbstractODESystem check_equations(equations(cevents), iv) end if checks == true || (checks & CheckUnits) > 0 - u = __get_unit_type(dvs, ps) + u = __get_unit_type(dvs, ps, iv) check_units(u, deqs) end new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index abff555682..5f3f9fb00f 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -134,7 +134,7 @@ struct SDESystem <: AbstractODESystem check_equations(equations(cevents), iv) end if checks == true || (checks & CheckUnits) > 0 - u = __get_unit_type(dvs, ps) + u = __get_unit_type(dvs, ps, iv) check_units(u, deqs, neqs) end new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index e28076eb97..dedb7bf6b3 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -103,7 +103,7 @@ struct DiscreteSystem <: AbstractTimeDependentSystem check_parameters(ps, iv) end if checks == true || (checks & CheckUnits) > 0 - u = __get_unit_type(dvs, ps, ctrls) + u = __get_unit_type(dvs, ps, iv, ctrls) check_units(u, discreteEqs) end new(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, name, diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index 7b7061fd61..6fe2b503e1 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -112,7 +112,7 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem check_parameters(ps, iv) end if checks == true || (checks & CheckUnits) > 0 - u = __get_unit_type(states, ps) + u = __get_unit_type(states, ps, iv) check_units(u, ap, iv) end new{U}(tag, ap, iv, states, ps, var_to_name, observed, name, systems, defaults, From 089649e692f03a3754e2e8faccf71d9c51f6b70e Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 3 Jan 2024 14:21:53 -0500 Subject: [PATCH 12/13] Only eagerly check DQ --- src/systems/unit_check.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index 02567fa6a9..9b474fcc1d 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -20,7 +20,7 @@ function __get_literal_unit(x) end v = value(x) u = getmetadata(v, VariableUnit, nothing) - u === nothing ? nothing : screen_unit(u) + u isa DQ.AbstractQuantity ? screen_unit(u) : u end function __get_scalar_unit_type(v) u = __get_literal_unit(v) From e763b63711e411af5013c203635b0a5a6c56d284 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 3 Jan 2024 22:55:09 -0500 Subject: [PATCH 13/13] Fix model_parsing tests --- test/model_parsing.jl | 5 ++++- test/precompile_test/ModelParsingPrecompile.jl | 3 +-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/model_parsing.jl b/test/model_parsing.jl index b8a9d0b036..f94d7483fd 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -9,7 +9,7 @@ ENV["MTK_ICONS_DIR"] = "$(@__DIR__)/icons" # Mock module used to test if the `@mtkmodel` macro works with fully-qualified names as well. module MyMockModule -using ..ModelingToolkit, ..Unitful +using ModelingToolkit, DynamicQuantities export Pin @connector Pin begin @@ -328,6 +328,8 @@ end # Ensure that modules consisting MTKModels with component arrays and icons of # `Expr` type and `unit` metadata can be precompiled. +module PrecompilationTest +using Unitful, Test, ModelingToolkit @testset "Precompile packages with MTKModels" begin push!(LOAD_PATH, joinpath(@__DIR__, "precompile_test")) @@ -340,6 +342,7 @@ end pop!(LOAD_PATH) end +end @testset "Conditional statements inside the blocks" begin @mtkmodel C begin end diff --git a/test/precompile_test/ModelParsingPrecompile.jl b/test/precompile_test/ModelParsingPrecompile.jl index 8430831b55..ed67dd8a0c 100644 --- a/test/precompile_test/ModelParsingPrecompile.jl +++ b/test/precompile_test/ModelParsingPrecompile.jl @@ -1,7 +1,6 @@ module ModelParsingPrecompile -using ModelingToolkit -using Unitful +using ModelingToolkit, Unitful @mtkmodel ModelWithComponentArray begin @parameters begin