diff --git a/Manifest.toml b/Manifest.toml index 5809670..43974d8 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,13 +2,13 @@ julia_version = "1.8.1" manifest_format = "2.0" -project_hash = "acd57e765b4f1dab1bfcb51cba7cf1a8a0f08029" +project_hash = "171ac7d7a6173a404267c4f8b757ee8c7604f245" [[deps.AbstractAlgebra]] deps = ["GroupsCore", "InteractiveUtils", "LinearAlgebra", "MacroTools", "Markdown", "Random", "RandomExtensions", "SparseArrays", "Test"] -git-tree-sha1 = "29e65c331f97db9189ef00a4c7aed8127c2fd2d4" +git-tree-sha1 = "a69dbe3b376ace7d9eebe2db43216e8b52ba6da9" uuid = "c3fe647b-3220-5bb0-a1ea-a7954cac585d" -version = "0.27.10" +version = "0.29.2" [[deps.AbstractTrees]] git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" @@ -21,11 +21,6 @@ git-tree-sha1 = "cc37d689f599e8df4f464b2fa3870ff7db7492ef" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" version = "3.6.1" -[[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" - [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.1" @@ -42,34 +37,12 @@ git-tree-sha1 = "38911c7737e123b28182d89027f4216cfc8a9da7" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" version = "7.4.3" -[[deps.ArrayInterfaceCore]] -deps = ["LinearAlgebra", "SnoopPrecompile", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "e5f08b5689b1aad068e01751889f2f615c7db36d" -uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" -version = "0.1.29" - [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" -[[deps.AutoHashEquals]] -git-tree-sha1 = "45bb6705d93be619b81451bb2006b7ee5d4e4453" -uuid = "15f4f7f2-30c1-5605-9d31-71845cf9641f" -version = "0.2.0" - -[[deps.BangBang]] -deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] -git-tree-sha1 = "7fe6d92c4f281cf4ca6f2fba0ce7b299742da7ca" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.3.37" - [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -[[deps.Baselet]] -git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - [[deps.Bijections]] git-tree-sha1 = "fe4f8c5ee7f76f2198d5c2a06d3961c249cce7bd" uuid = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04" @@ -155,11 +128,6 @@ git-tree-sha1 = "02d2316b7ffceff992f3096ae48c7829a8aa0638" uuid = "b152e2b5-7a66-4b01-a709-34e65c35f657" version = "0.1.3" -[[deps.CompositionsBase]] -git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.1" - [[deps.ConstructionBase]] deps = ["LinearAlgebra"] git-tree-sha1 = "89a9db8d28102b094992472d333674bd1a83ce2a" @@ -197,11 +165,6 @@ version = "1.0.0" deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" -[[deps.DefineSingletons]] -git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.2" - [[deps.DensityInterface]] deps = ["InverseFunctions", "Test"] git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" @@ -256,9 +219,9 @@ version = "0.9.3" [[deps.DomainSets]] deps = ["CompositeTypes", "IntervalSets", "LinearAlgebra", "Random", "StaticArrays", "Statistics"] -git-tree-sha1 = "988e2db482abeb69efc76ae8b6eba2e93805ee70" +git-tree-sha1 = "698124109da77b6914f64edd696be8dccf90229e" uuid = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" -version = "0.5.15" +version = "0.6.6" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] @@ -354,9 +317,9 @@ version = "1.8.0" [[deps.Groebner]] deps = ["AbstractAlgebra", "Combinatorics", "Logging", "MultivariatePolynomials", "Primes", "Random"] -git-tree-sha1 = "47f0f03eddecd7ad59c42b1dd46d5f42916aff63" +git-tree-sha1 = "827f29c95676735719f8d6acbf0a3aaf73b3c9e5" uuid = "0b43b601-686d-58a3-8a1c-6623616c7cd4" -version = "0.2.11" +version = "0.3.2" [[deps.GroupsCore]] deps = ["Markdown", "Random"] @@ -380,11 +343,6 @@ git-tree-sha1 = "5cd07aab533df5170988219191dfad0519391428" uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" version = "0.1.3" -[[deps.InitialValues]] -git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.3.1" - [[deps.IntegerMathUtils]] git-tree-sha1 = "f366daebdfb079fd1fe4e3d560f99a0c892e15bc" uuid = "18e54dd8-cb9d-406c-a71d-865a43cbb235" @@ -396,9 +354,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[deps.IntervalSets]] deps = ["Dates", "Random", "Statistics"] -git-tree-sha1 = "3f91cd3f56ea48d4d2a75c2a65455c5fc74fa347" +git-tree-sha1 = "16c0cc91853084cb5f58a78bd209513900206ce6" uuid = "8197267c-284f-5f27-9208-e0e47529a953" -version = "0.7.3" +version = "0.7.4" [[deps.InverseFunctions]] deps = ["Test"] @@ -535,18 +493,6 @@ deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" version = "2.28.0+0" -[[deps.Metatheory]] -deps = ["AutoHashEquals", "DataStructures", "Dates", "DocStringExtensions", "Parameters", "Reexport", "TermInterface", "ThreadsX", "TimerOutputs"] -git-tree-sha1 = "0f39bc7f71abdff12ead4fc4a7d998fb2f3c171f" -uuid = "e9d8d322-4543-424a-9be4-0cc815abe26c" -version = "1.3.5" - -[[deps.MicroCollections]] -deps = ["BangBang", "InitialValues", "Setfield"] -git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e" -uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.1.4" - [[deps.Missings]] deps = ["DataAPI"] git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" @@ -557,10 +503,10 @@ version = "1.1.0" uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.ModelingToolkit]] -deps = ["AbstractTrees", "ArrayInterfaceCore", "Combinatorics", "Compat", "ConstructionBase", "DataStructures", "DiffEqBase", "DiffEqCallbacks", "DiffRules", "Distributed", "Distributions", "DocStringExtensions", "DomainSets", "ForwardDiff", "FunctionWrappersWrappers", "Graphs", "IfElse", "InteractiveUtils", "JuliaFormatter", "JumpProcesses", "LabelledArrays", "Latexify", "Libdl", "LinearAlgebra", "MacroTools", "NaNMath", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLBase", "Serialization", "Setfield", "SimpleNonlinearSolve", "SparseArrays", "SpecialFunctions", "StaticArrays", "SymbolicIndexingInterface", "SymbolicUtils", "Symbolics", "UnPack", "Unitful"] -git-tree-sha1 = "aea7045bc1aec761725c70cbad064b21169128ea" +deps = ["AbstractTrees", "ArrayInterface", "Combinatorics", "Compat", "ConstructionBase", "DataStructures", "DiffEqBase", "DiffEqCallbacks", "DiffRules", "Distributed", "Distributions", "DocStringExtensions", "DomainSets", "ForwardDiff", "FunctionWrappersWrappers", "Graphs", "IfElse", "InteractiveUtils", "JuliaFormatter", "JumpProcesses", "LabelledArrays", "Latexify", "Libdl", "LinearAlgebra", "MacroTools", "NaNMath", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLBase", "Serialization", "Setfield", "SimpleNonlinearSolve", "SparseArrays", "SpecialFunctions", "StaticArrays", "SymbolicIndexingInterface", "SymbolicUtils", "Symbolics", "UnPack", "Unitful"] +git-tree-sha1 = "de2daac4b0ca05c2cedfb4535dcee453e3f5fabd" uuid = "961ee093-0014-501f-94e3-6117800e7a78" -version = "8.46.1" +version = "8.52.0" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" @@ -732,12 +678,6 @@ git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" uuid = "189a3867-3050-52da-a836-e630ba90ab69" version = "1.2.2" -[[deps.Referenceables]] -deps = ["Adapt"] -git-tree-sha1 = "e681d3bfa49cd46c3c161505caddf20f0e62aaa9" -uuid = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" -version = "0.1.2" - [[deps.Requires]] deps = ["UUIDs"] git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" @@ -833,12 +773,6 @@ git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "2.2.0" -[[deps.SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.15" - [[deps.Static]] deps = ["IfElse"] git-tree-sha1 = "08be5ee09a7632c32695d954a602df96a877bf0d" @@ -901,16 +835,16 @@ uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" version = "0.2.2" [[deps.SymbolicUtils]] -deps = ["AbstractTrees", "Bijections", "ChainRulesCore", "Combinatorics", "ConstructionBase", "DataStructures", "DocStringExtensions", "DynamicPolynomials", "IfElse", "LabelledArrays", "LinearAlgebra", "Metatheory", "MultivariatePolynomials", "NaNMath", "Setfield", "SparseArrays", "SpecialFunctions", "StaticArrays", "TermInterface", "TimerOutputs"] -git-tree-sha1 = "027b43d312f6d52187bb16c2d4f0588ddb8c4bb2" +deps = ["AbstractTrees", "Bijections", "ChainRulesCore", "Combinatorics", "ConstructionBase", "DataStructures", "DocStringExtensions", "DynamicPolynomials", "IfElse", "LabelledArrays", "LinearAlgebra", "MultivariatePolynomials", "NaNMath", "Setfield", "SparseArrays", "SpecialFunctions", "StaticArrays", "TimerOutputs", "Unityper"] +git-tree-sha1 = "bfbd444c209b41c7b2fef36b6e146a66da0be9f1" uuid = "d1185830-fcd6-423d-90d6-eec64667417b" -version = "0.19.11" +version = "1.0.4" [[deps.Symbolics]] -deps = ["ArrayInterfaceCore", "ConstructionBase", "DataStructures", "DiffRules", "Distributions", "DocStringExtensions", "DomainSets", "Groebner", "IfElse", "LaTeXStrings", "LambertW", "Latexify", "Libdl", "LinearAlgebra", "MacroTools", "Markdown", "Metatheory", "NaNMath", "RecipesBase", "Reexport", "Requires", "RuntimeGeneratedFunctions", "SciMLBase", "Setfield", "SparseArrays", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "TermInterface", "TreeViews"] -git-tree-sha1 = "111fbf43883d95989577133aeeb889f2040d0aea" +deps = ["ArrayInterface", "ConstructionBase", "DataStructures", "DiffRules", "Distributions", "DocStringExtensions", "DomainSets", "Groebner", "IfElse", "LaTeXStrings", "LambertW", "Latexify", "Libdl", "LinearAlgebra", "MacroTools", "Markdown", "NaNMath", "RecipesBase", "Reexport", "Requires", "RuntimeGeneratedFunctions", "SciMLBase", "Setfield", "SparseArrays", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "TreeViews"] +git-tree-sha1 = "7ecd651e3829d2957478516e92f693f12d5b4781" uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7" -version = "4.14.0" +version = "5.2.0" [[deps.TOML]] deps = ["Dates"] @@ -934,11 +868,6 @@ deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" version = "1.10.0" -[[deps.TermInterface]] -git-tree-sha1 = "7aa601f12708243987b88d1b453541a75e3d8c7a" -uuid = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" -version = "0.2.3" - [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -949,12 +878,6 @@ git-tree-sha1 = "c97f60dd4f2331e1a495527f80d242501d2f9865" uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" version = "0.5.1" -[[deps.ThreadsX]] -deps = ["ArgCheck", "BangBang", "ConstructionBase", "InitialValues", "MicroCollections", "Referenceables", "Setfield", "SplittablesBase", "Transducers"] -git-tree-sha1 = "34e6bcf36b9ed5d56489600cf9f3c16843fa2aa2" -uuid = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" -version = "0.1.11" - [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] git-tree-sha1 = "f2fd3f288dfc6f507b0c3a2eb3bac009251e548b" @@ -966,12 +889,6 @@ git-tree-sha1 = "90538bf898832b6ebd900fa40f223e695970e3a5" uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" version = "0.5.25" -[[deps.Transducers]] -deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "c42fa452a60f022e9e087823b47e5a5f8adc53d5" -uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.75" - [[deps.TreeViews]] deps = ["Test"] git-tree-sha1 = "8d0d7a3fe2f30d6a7f833a5f19f7c7a5b396eae6" @@ -1012,6 +929,12 @@ git-tree-sha1 = "bb37ed24f338bc59b83e3fc9f32dd388e5396c53" uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" version = "1.12.4" +[[deps.Unityper]] +deps = ["ConstructionBase"] +git-tree-sha1 = "d5f4ec8c22db63bd3ccb239f640e895cfde145aa" +uuid = "a7c27f48-0311-42f6-a7f8-2c11e75eb415" +version = "0.1.2" + [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" diff --git a/Project.toml b/Project.toml index f4f6fb3..a321f59 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SourceCodeMcCormick" uuid = "a7283dc5-4ecf-47fb-a95b-1412723fc960" authors = ["Robert Gottlieb "] -version = "0.1.3" +version = "0.2.0" [deps] DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -13,7 +13,7 @@ SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" DocStringExtensions = "0.8 - 0.9" ModelingToolkit = "8" IfElse = "0.1.0 - 0.1.1" -SymbolicUtils = "0.19.7" +SymbolicUtils = "1" julia = "1.7" [extras] diff --git a/src/SourceCodeMcCormick.jl b/src/SourceCodeMcCormick.jl index 26df9a9..ea20de2 100644 --- a/src/SourceCodeMcCormick.jl +++ b/src/SourceCodeMcCormick.jl @@ -6,6 +6,8 @@ using SymbolicUtils.Code using IfElse using DocStringExtensions +import SymbolicUtils: BasicSymbolic, exprtype, SYM, TERM, ADD, MUL, POW, DIV + """ AbstractTransform @@ -45,6 +47,6 @@ include(joinpath(@__DIR__, "transform", "transform.jl")) export McCormickIntervalTransform, IntervalTransform export apply_transform, all_evaluators, convex_evaluator, extract_terms, - genvar, genparam, get_name, factor!, binarize!, pull_vars, shrink_eqs + genvar, genparam, get_name, factor, binarize!, pull_vars, shrink_eqs end \ No newline at end of file diff --git a/src/interval/interval.jl b/src/interval/interval.jl index fd2231e..7087de6 100644 --- a/src/interval/interval.jl +++ b/src/interval/interval.jl @@ -5,57 +5,29 @@ Rules for constructing interval bounding expressions # Structure used to indicate an overload with intervals is preferable struct IntervalTransform <: AbstractTransform end -function var_names(::IntervalTransform, s::Term{Real, Base.ImmutableDict{DataType, Any}}) #The variables - arg_list = Symbol[] - if haskey(s.metadata, ModelingToolkit.MTKParameterCtx) - sL = genparam(Symbol(string(get_name(s))*"_lo")) - sU = genparam(Symbol(string(get_name(s))*"_hi")) - else - for i in s.arguments - push!(arg_list, get_name(i)) - end - sL = genvar(Symbol(string(get_name(s))*"_lo"), arg_list) - sU = genvar(Symbol(string(get_name(s))*"_hi"), arg_list) - end - return Symbolics.value(sL), Symbolics.value(sU) -end -function var_names(::IntervalTransform, s::Real) - return s, s -end -function var_names(::IntervalTransform, s::Term{Real, Nothing}) #Any terms like "Differential", or "x[1]" (NOT x[1](t)) - if typeof(s.arguments[1])<:Term #then it has typical args like "x", "y", ... - args = Symbol[] - for i in s.arguments[1].arguments - push!(args, get_name(i)) - end - var = get_name(s.arguments[1]) - var_lo = genvar(Symbol(string(var)*"_lo"), args) - var_hi = genvar(Symbol(string(var)*"_hi"), args) - elseif typeof(s.arguments[1])<:Sym #Then it has no typical args, i.e., x[1] has args Any[x, 1] - if length(s.arguments)==1 - var_lo = genparam(Symbol(string(s.arguments[1].name)*"_lo")) - var_hi = genparam(Symbol(string(s.arguments[1].name)*"_hi")) +var_names(::IntervalTransform, a::Real) = a, a +function var_names(::IntervalTransform, a::BasicSymbolic) + if exprtype(a)==SYM + aL = genvar(Symbol(string(get_name(a))*"_lo")) + aU = genvar(Symbol(string(get_name(a))*"_hi")) + return aL.val, aU.val + elseif exprtype(a)==TERM + if varterm(a) + arg_list = Symbol[] + for i in a.arguments + push!(arg_list, get_name(i)) + end + aL = genvar(Symbol(string(get_name(a))*"_lo"), arg_list) + aU = genvar(Symbol(string(get_name(a))*"_hi"), arg_list) + return aL.val, aU.val else - var_lo = genparam(Symbol(string(s.arguments[1].name)*"_"*string(s.arguments[2])*"_lo")) - var_hi = genparam(Symbol(string(s.arguments[1].name)*"_"*string(s.arguments[2])*"_hi")) + aL = genvar(Symbol(string(get_name(a))*"_lo")) + aU = genvar(Symbol(string(get_name(a))*"_hi")) + return aL.val, aU.val end else - println("Term: $s") - for arg in s.arguments - @show arg - @show typeof(arg) - end - error("Type of argument invalid") + error("Reached `var_names` with an unexpected type [ADD/MUL/DIV/POW]. Check expression factorization to make sure it is being binarized correctly.") end - - sL = s.f(var_lo) - sU = s.f(var_hi) - return Symbolics.value(sL), Symbolics.value(sU) -end -function var_names(::IntervalTransform, s::Sym) #The parameters - sL = genparam(Symbol(string(get_name(s))*"_lo")) - sU = genparam(Symbol(string(get_name(s))*"_hi")) - return Symbolics.value(sL), Symbolics.value(sU) end function translate_initial_conditions(::IntervalTransform, prob::ODESystem, new_eqs::Vector{Equation}) @@ -81,37 +53,28 @@ function translate_initial_conditions(::IntervalTransform, prob::ODESystem, new_ end -# Helper functions for navigating SymbolicUtils structures -get_name(x::Sym{SymbolicUtils.FnType{Tuple{Any}, Real}, Nothing}) = x.name - """ get_name -Take a Symbolic-type object such as `x[1,1]` and return a symbol like `:x_1_1`. +Take a `BasicSymbolic` object such as `x[1,1]` and return a symbol like `:x_1_1`. """ -function get_name(s::Term{SymbolicUtils.FnType{Tuple, Real}, Nothing}) - d = s.arguments - new_var = string(d[1]) - for i in 2:length(d) - new_var = new_var*"_"*string(d[i]) - end - return Symbol(new_var) -end - -function get_name(s::Term) - if haskey(s.metadata, ModelingToolkit.MTKParameterCtx) - d = s.arguments - new_param = string(d[1]) - for i in 2:length(d) - new_param = new_param*"_"*string(d[i]) +function get_name(a::BasicSymbolic) + if exprtype(a)==SYM + return a.name + elseif exprtype(a)==TERM + if varterm(a) + return a.f.name + elseif (a.f==getindex) + args = a.arguments + new_var = string(args[1]) + for i in 2:lastindex(args) + new_var = new_var * "_" * string(args[i]) + end + return Symbol(new_var) + else + error("Problem generating variable name. This may happen if the variable is non-standard. Please post an issue if you get this error.") end - return Symbol(new_param) - else - return get_name(s.f) end end -function get_name(s::Sym) - return s.name -end include(joinpath(@__DIR__, "rules.jl")) \ No newline at end of file diff --git a/src/relaxation/relaxation.jl b/src/relaxation/relaxation.jl index 65f9e09..ebde253 100644 --- a/src/relaxation/relaxation.jl +++ b/src/relaxation/relaxation.jl @@ -1,54 +1,39 @@ struct McCormickTransform <: AbstractTransform end struct McCormickIntervalTransform <: AbstractTransform end -function var_names(::McCormickTransform, s::Term{Real, Base.ImmutableDict{DataType, Any}}) - arg_list = Symbol[] - if haskey(s.metadata, ModelingToolkit.MTKParameterCtx) - scv = genparam(Symbol(string(get_name(s))*"_cv")) - scc = genparam(Symbol(string(get_name(s))*"_cc")) - else - for i in s.arguments - push!(arg_list, get_name(i)) - end - scv = genvar(Symbol(string(get_name(s))*"_cv"), arg_list) - scc = genvar(Symbol(string(get_name(s))*"_cc"), arg_list) - end - return Symbolics.value(scv), Symbolics.value(scc) -end -function var_names(::McCormickTransform, s::Real) - return s, s -end -function var_names(::McCormickTransform, s::Term{Real, Nothing}) #Any terms like "Differential" or x[1] - if typeof(s.arguments[1])<:Term #then it has args - args = Symbol[] - for i in s.arguments[1].arguments - push!(args, get_name(i)) - end - var = get_name(s.arguments[1]) - var_cv = genvar(Symbol(string(var)*"_cv"), args) - var_cc = genvar(Symbol(string(var)*"_cc"), args) - elseif typeof(s.arguments[1])<:Sym #Then it has no args - if length(s.arguments)==1 - var_cv = genparam(Symbol(string(s.arguments[1].name)*"_cv")) - var_cc = genparam(Symbol(string(s.arguments[1].name)*"_cc")) + +var_names(::McCormickTransform, a::Real) = a, a +function var_names(::McCormickTransform, a::BasicSymbolic) + if exprtype(a)==SYM + acv = genvar(Symbol(string(get_name(a))*"_cv")) + acc = genvar(Symbol(string(get_name(a))*"_cc")) + return acv.val, acc.val + elseif exprtype(a)==TERM + if varterm(a) + arg_list = Symbol[] + for i in a.arguments + push!(arg_list, get_name(i)) + end + acv = genvar(Symbol(string(get_name(a))*"_cv"), arg_list) + acc = genvar(Symbol(string(get_name(a))*"_cc"), arg_list) + return acv.val, acc.val else - var_cv = genparam(Symbol(string(s.arguments[1].name)*"_"*string(s.arguments[2])*"_cv")) - var_cc = genparam(Symbol(string(s.arguments[1].name)*"_"*string(s.arguments[2])*"_cc")) + acv = genvar(Symbol(string(get_name(a))*"_cv")) + acc = genvar(Symbol(string(get_name(a))*"_cc")) + return acv.val, acc.val end else - error("Type of argument invalid") + error("Reached `var_names` with an unexpected type [ADD/MUL/DIV/POW]. Check expression factorization to make sure it is being binarized correctly.") end - - scv = s.f(var_cv) - scc = s.f(var_cc) - return Symbolics.value(scv), Symbolics.value(scc) end -function var_names(::McCormickTransform, s::Sym) #The parameters - scv = genparam(Symbol(string(get_name(s))*"_cv")) - scc = genparam(Symbol(string(get_name(s))*"_cc")) - return Symbolics.value(scv), Symbolics.value(scc) + +function var_names(::McCormickIntervalTransform, a::Any) + aL, aU = var_names(IntervalTransform(), a) + acv, acc = var_names(McCormickTransform(), a) + return aL, aU, acv, acc end + function translate_initial_conditions(::McCormickTransform, prob::ODESystem, new_eqs::Vector{Equation}) vars, params = extract_terms(new_eqs) var_defaults = Dict{Any, Any}() @@ -71,14 +56,6 @@ function translate_initial_conditions(::McCormickTransform, prob::ODESystem, new return var_defaults, param_defaults end - -function var_names(::McCormickIntervalTransform, s::Any) - sL, sU = var_names(IntervalTransform(), s) - scv, scc = var_names(McCormickTransform(), s) - return sL, sU, scv, scc -end - - function translate_initial_conditions(::McCormickIntervalTransform, prob::ODESystem, new_eqs::Vector{Equation}) vars, params = extract_terms(new_eqs) var_defaults = Dict{Any, Any}() diff --git a/src/transform/binarize.jl b/src/transform/binarize.jl index 1404387..7b6fbb7 100644 --- a/src/transform/binarize.jl +++ b/src/transform/binarize.jl @@ -1,39 +1,37 @@ #= Rules for transforming narity operations into arity 1 operations =# -function binarize!(ex::SymbolicUtils.Add) - (arity(ex) < 3) && return ex - # Op is already + - skipfirst = iszero(ex.coeff) - newdict = Dict{Any, Number}() - for (key, val) in ex.dict - if skipfirst - skipfirst = false - continue +function binarize!(ex::BasicSymbolic) + exprtype(ex) in (SYM, TERM, DIV, POW) && return nothing + if exprtype(ex)==ADD + skipfirst = iszero(ex.coeff) + newdict = Dict{Any, Number}() + for (key, val) in ex.dict + if skipfirst + skipfirst = false + continue + end + newdict[key] = val + delete!(ex.dict, key) end - newdict[key] = val - delete!(ex.dict, key) - end - a = SymbolicUtils.Add(Real, 0, newdict) - binarize!(a) - ex.dict[a] = 1 - return nothing -end -function binarize!(ex::SymbolicUtils.Mul) - (arity(ex) < 3) && return ex - # Op is already * - skipfirst = isone(ex.coeff) - newdict = Dict{Any, Number}() - for (key, val) in ex.dict - if skipfirst - skipfirst = false - continue + a = SymbolicUtils.Add(Real, 0, newdict) + binarize!(a) + ex.dict[a] = 1 + return nothing + elseif exprtype(ex)==MUL + skipfirst = isone(ex.coeff) + newdict = Dict{Any, Number}() + for (key, val) in ex.dict + if skipfirst + skipfirst = false + continue + end + newdict[key] = val + delete!(ex.dict, key) end - newdict[key] = val - delete!(ex.dict, key) + a = SymbolicUtils.Mul(Real, 1, newdict) + binarize!(a) + ex.dict[a] = 1 + return nothing end - a = SymbolicUtils.Mul(Real, 1, newdict) - binarize!(a) - ex.dict[a] = 1 - return nothing end diff --git a/src/transform/factor.jl b/src/transform/factor.jl index 24d74c1..b82312e 100644 --- a/src/transform/factor.jl +++ b/src/transform/factor.jl @@ -1,66 +1,56 @@ base_term(a::Any) = false -base_term(a::Term{Real, Base.ImmutableDict{DataType,Any}}) = true -base_term(a::Term{Real, Nothing}) = (a.f==getindex) -base_term(a::Sym) = true base_term(a::Real) = true - - -function isfactor(ex::SymbolicUtils.Add) - (~iszero(ex.coeff)) && (length(ex.dict)>1) && return false - (iszero(ex.coeff)) && (length(ex.dict)>2) && return false - for (key, val) in ex.dict - ~(isone(val)) && return false - ~(base_term(key)) && return false - end - return true -end -function isfactor(ex::SymbolicUtils.Mul) - (~isone(ex.coeff)) && (length(ex.dict)>1) && return false - (isone(ex.coeff)) && (length(ex.dict)>2) && return false - for (key, val) in ex.dict - ~(isone(val)) && return false - ~(base_term(key)) && return false - end - return true -end -function isfactor(ex::SymbolicUtils.Div) - ~(base_term(ex.num)) && return false - ~(base_term(ex.den)) && return false - return true +function base_term(a::BasicSymbolic) + exprtype(a)==SYM && return true + exprtype(a)==TERM && return varterm(a) || (a.f==getindex) + return false end -function isfactor(ex::SymbolicUtils.Pow) - ~(base_term(ex.base)) && return false - ~(base_term(ex.exp)) && return false - return true -end -function isfactor(ex::Term{Real,Nothing}) - for i in ex.arguments - ~(base_term(i)) && return false + +function isfactor(a::BasicSymbolic) + if exprtype(a)==SYM + return true + elseif exprtype(a)==TERM + varterm(a) || (a.f==getindex) && return true + for i in a.arguments + ~(base_term(i)) && return false + end + return true + elseif exprtype(a)==ADD + (~iszero(a.coeff)) && (length(a.dict)>1) && return false + (iszero(a.coeff)) && (length(a.dict)>2) && return false + for (key, val) in a.dict + ~(isone(val)) && return false + ~(base_term(key)) && return false + end + return true + elseif exprtype(a)==MUL + (~isone(a.coeff)) && (length(a.dict)>1) && return false + (isone(a.coeff)) && (length(a.dict)>2) && return false + for (key, val) in a.dict + ~(isone(val)) && return false + ~(base_term(key)) && return false + end + return true + elseif exprtype(a)==DIV + ~(base_term(a.num)) && return false + ~(base_term(a.den)) && return false + return true + elseif exprtype(a)==POW + ~(base_term(a.base)) && return false + ~(base_term(a.exp)) && return false + return true end - return true end -factor!(ex::Num) = factor!(ex.val) -factor!(ex::Num, eqs::Vector{Equation}) = factor!(ex.val, eqs=eqs) -function factor!(ex::Sym{Real, Base.ImmutableDict{DataType, Any}}; eqs = Equation[]) - index = findall(x -> isequal(x.rhs,ex), eqs) - if isempty(index) - newsym = gensym(:aux) - newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end]) - newvar = genvar(newsym) - new = Equation(Symbolics.value(newvar), ex) - push!(eqs, new) - else - p = collect(1:length(eqs)) - deleteat!(p, index[1]) - push!(p, index[1]) - eqs[:] = eqs[p] - end - return eqs +function factor!(a...) + @warn """Use of "!" is deprecated as of v0.2.0. Please call `factor()` instead.""" + return factor(a...) end +factor(ex::Num) = factor(ex.val) +factor(ex::Num, eqs::Vector{Equation}) = factor(ex.val, eqs=eqs) -function factor!(ex::SymbolicUtils.Add; eqs = Equation[]) +function factor(ex::BasicSymbolic; eqs = Equation[]) binarize!(ex) if isfactor(ex) index = findall(x -> isequal(x.rhs,ex), eqs) @@ -78,185 +68,102 @@ function factor!(ex::SymbolicUtils.Add; eqs = Equation[]) end return eqs end - new_terms = Dict{Any, Number}() - for (key, val) in ex.dict - if base_term(key) && isone(val) - new_terms[key] = val - elseif (base_term(key)) - index = findall(x -> isequal(x.rhs,val*key), eqs) - if isempty(index) - newsym = gensym(:aux) - newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end]) - newvar = genvar(newsym) - new = Equation(Symbolics.value(newvar), val*key) - push!(eqs, new) - new_terms[Symbolics.value(newvar)] = 1 + if exprtype(ex)==ADD + new_terms = Dict{Any, Number}() + for (key, val) in ex.dict + if base_term(key) && isone(val) + new_terms[key] = val + elseif (base_term(key)) + index = findall(x -> isequal(x.rhs,val*key), eqs) + if isempty(index) + newsym = gensym(:aux) + newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end]) + newvar = genvar(newsym) + new = Equation(Symbolics.value(newvar), val*key) + push!(eqs, new) + new_terms[Symbolics.value(newvar)] = 1 + else + new_terms[eqs[index[1]].lhs] = 1 + end else - new_terms[eqs[index[1]].lhs] = 1 + factor(val*key, eqs=eqs) + new_terms[eqs[end].lhs] = 1 end - else - factor!(val*key, eqs=eqs) - new_terms[eqs[end].lhs] = 1 - end - end - new_add = SymbolicUtils.Add(Real, ex.coeff, new_terms) - factor!(new_add, eqs=eqs) - return eqs -end -function factor!(ex::SymbolicUtils.Mul; eqs = Equation[]) - binarize!(ex) - if isfactor(ex) - index = findall(x -> isequal(x.rhs,ex), eqs) - if isempty(index) - newsym = gensym(:aux) - newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end]) - newvar = genvar(newsym) - new = Equation(Symbolics.value(newvar), ex) - push!(eqs, new) - else - p = collect(1:length(eqs)) - deleteat!(p, index[1]) - push!(p, index[1]) - eqs[:] = eqs[p] end + new_add = SymbolicUtils.Add(Real, ex.coeff, new_terms) + factor(new_add, eqs=eqs) return eqs - end - new_terms = Dict{Any, Number}() - for (key, val) in ex.dict - if base_term(key) && isone(val) - new_terms[key] = val - elseif base_term(key) - index = findall(x -> isequal(x.rhs,key^val), eqs) - if isempty(index) - newsym = gensym(:aux) - newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end]) - newvar = genvar(newsym) - new = Equation(Symbolics.value(newvar), key^val) - push!(eqs, new) - new_terms[Symbolics.value(newvar)] = 1 + elseif exprtype(ex)==MUL + new_terms = Dict{Any, Number}() + for (key, val) in ex.dict + if base_term(key) && isone(val) + new_terms[key] = val + elseif base_term(key) + index = findall(x -> isequal(x.rhs,key^val), eqs) + if isempty(index) + newsym = gensym(:aux) + newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end]) + newvar = genvar(newsym) + new = Equation(Symbolics.value(newvar), key^val) + push!(eqs, new) + new_terms[Symbolics.value(newvar)] = 1 + else + new_terms[eqs[index[1]].lhs] = 1 + end else - new_terms[eqs[index[1]].lhs] = 1 + factor(key^val, eqs=eqs) + new_terms[eqs[end].lhs] = 1 end + end + new_mul = SymbolicUtils.Mul(Real, ex.coeff, new_terms) + factor(new_mul, eqs=eqs) + return eqs + elseif exprtype(ex)==DIV + if base_term(ex.num) + new_num = ex.num else - factor!(key^val, eqs=eqs) - new_terms[eqs[end].lhs] = 1 + factor(ex.num, eqs=eqs) + new_num = eqs[end].lhs end - end - new_mul = SymbolicUtils.Mul(Real, ex.coeff, new_terms) - factor!(new_mul, eqs=eqs) - return eqs -end -function factor!(ex::SymbolicUtils.Pow; eqs = Equation[]) - if isfactor(ex) - index = findall(x -> isequal(x.rhs,ex), eqs) - if isempty(index) - newsym = gensym(:aux) - newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end]) - newvar = genvar(newsym) - new = Equation(Symbolics.value(newvar), ex) - push!(eqs, new) + if base_term(ex.den) + new_den = ex.den else - p = collect(1:length(eqs)) - deleteat!(p, index[1]) - push!(p, index[1]) - eqs[:] = eqs[p] + factor(ex.den, eqs=eqs) + new_den = eqs[end].lhs end + new_div = SymbolicUtils.Div(new_num, new_den) + factor(new_div, eqs=eqs) return eqs - end - if base_term(ex.base) - new_base = ex.base - else - factor!(ex.base, eqs=eqs) - new_base = eqs[end].lhs - end - if base_term(ex.exp) - new_exp = ex.exp - else - factor!(ex.exp, eqs=eqs) - new_exp = eqs[end].lhs - end - new_pow = SymbolicUtils.Pow(new_base, new_exp) - factor!(new_pow, eqs=eqs) - return eqs -end -function factor!(ex::SymbolicUtils.Div; eqs = Equation[]) - if isfactor(ex) - index = findall(x -> isequal(x.rhs,ex), eqs) - if isempty(index) - newsym = gensym(:aux) - newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end]) - newvar = genvar(newsym) - new = Equation(Symbolics.value(newvar), ex) - push!(eqs, new) + elseif exprtype(ex)==POW + if base_term(ex.base) + new_base = ex.base else - p = collect(1:length(eqs)) - deleteat!(p, index[1]) - push!(p, index[1]) - eqs[:] = eqs[p] + factor(ex.base, eqs=eqs) + new_base = eqs[end].lhs end - return eqs - end - if base_term(ex.num) - new_num = ex.num - else - factor!(ex.num, eqs=eqs) - new_num = eqs[end].lhs - end - if base_term(ex.den) - new_den = ex.den - else - factor!(ex.den, eqs=eqs) - new_den = eqs[end].lhs - end - new_div = SymbolicUtils.Div(new_num, new_den) - factor!(new_div, eqs=eqs) - return eqs -end -function factor!(ex::Term{Real, Nothing}; eqs = Equation[]) - if isfactor(ex) - index = findall(x -> isequal(x.rhs,ex), eqs) - if isempty(index) - newsym = gensym(:aux) - newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end]) - newvar = genvar(newsym) - new = Equation(Symbolics.value(newvar), ex) - push!(eqs, new) + if base_term(ex.exp) + new_exp = ex.exp else - p = collect(1:length(eqs)) - deleteat!(p, index[1]) - push!(p, index[1]) - eqs[:] = eqs[p] + factor(ex.exp, eqs=eqs) + new_exp = eqs[end].lhs end + new_pow = SymbolicUtils.Pow(new_base, new_exp) + factor(new_pow, eqs=eqs) return eqs - end - new_args = [] - for arg in ex.arguments - if base_term(arg) - push!(new_args, arg) - else - factor!(arg, eqs=eqs) - push!(new_args, eqs[end].lhs) + elseif exprtype(ex)==TERM + new_args = [] + for arg in ex.arguments + if base_term(arg) + push!(new_args, arg) + else + factor(arg, eqs=eqs) + push!(new_args, eqs[end].lhs) + end end - end - new_func = Term(ex.f, new_args) - factor!(new_func, eqs=eqs) - return eqs -end -function factor!(ex::Term{Real, Base.ImmutableDict{DataType, Any}}; eqs = Equation[]) - index = findall(x -> isequal(x.rhs,ex), eqs) - if isempty(index) - newsym = gensym(:aux) - newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end]) - newvar = genvar(newsym) - new = Equation(Symbolics.value(newvar), ex) - push!(eqs, new) - else - p = collect(1:length(eqs)) - deleteat!(p, index[1]) - push!(p, index[1]) - eqs[:] = eqs[p] + new_func = Term(ex.f, new_args) + factor(new_func, eqs=eqs) + return eqs end return eqs end - diff --git a/src/transform/transform.jl b/src/transform/transform.jl index c6224e7..bc3f7eb 100644 --- a/src/transform/transform.jl +++ b/src/transform/transform.jl @@ -3,13 +3,57 @@ include(joinpath(@__DIR__, "binarize.jl")) include(joinpath(@__DIR__, "factor.jl")) +function apply_transform(transform::T, prob::ODESystem) where T<:AbstractTransform + + # Factorize all model equations to generate a new set of equations + + genparam(get_name(prob.iv.val)) + + equations = Equation[] + for eqn in prob.eqs + current = length(equations) + factor(eqn.rhs, eqs=equations) + if length(equations) > current + push!(equations, Equation(eqn.lhs, equations[end].rhs)) + deleteat!(equations, length(equations)-1) + else + index = findall(x -> isequal(x.rhs, eqn.rhs), equations) + push!(equations, Equation(eqn.lhs, equations[index[1]].lhs)) + end + end + + # Apply transform rules to the factored equations to make the final equation set + new_equations = Equation[] + for a in equations + zn = var_names(transform, zstr(a)) + xn = var_names(transform, xstr(a)) + if isone(arity(a)) + targs = (transform, op(a), zn..., xn...) + else + targs = (transform, op(a), zn..., xn..., var_names(transform, ystr(a))...) + end + new = transform_rule(targs...) + for i in new + push!(new_equations, i) + end + end + + # Copy model start points to the newly transformed variables + var_defaults, param_defaults = translate_initial_conditions(transform, prob, new_equations) + + # Use the transformed equations and new start points to generate a new ODE system + @named new_sys = ODESystem(new_equations, defaults=merge(var_defaults, param_defaults)) + + return new_sys +end + function apply_transform(transform::T, eqn_vector::Vector{Equation}) where T<:AbstractTransform # Factorize all model equations to generate a new set of equations equations = Equation[] for eqn in eqn_vector current = length(equations) - factor!(eqn.rhs, eqs=equations) + factor(eqn.rhs, eqs=equations) if length(equations) > current push!(equations, Equation(eqn.lhs, equations[end].rhs)) deleteat!(equations, length(equations)-1) @@ -42,7 +86,7 @@ function apply_transform(transform::T, eqn::Equation) where T<:AbstractTransform # Factorize the equations to generate a new set of equations equations = Equation[] - factor!(eqn.rhs, eqs=equations) + factor(eqn.rhs, eqs=equations) if length(equations) > 0 push!(equations, Equation(eqn.lhs, equations[end].rhs)) deleteat!(equations, length(equations)-1) @@ -76,7 +120,7 @@ function apply_transform(transform::T, num::Num) where T<:AbstractTransform @variables result eqn = result ~ num equations = Equation[] - factor!(eqn.rhs, eqs=equations) + factor(eqn.rhs, eqs=equations) if length(equations) > 0 push!(equations, Equation(eqn.lhs, equations[end].rhs)) deleteat!(equations, length(equations)-1) diff --git a/src/transform/utilities.jl b/src/transform/utilities.jl index fa27e2e..2b958e5 100644 --- a/src/transform/utilities.jl +++ b/src/transform/utilities.jl @@ -1,74 +1,79 @@ +# Initial feed functions arity(a::Equation) = arity(a.rhs) -arity(a::Term{Real, Base.ImmutableDict{DataType,Any}}) = 1 -arity(a::Term{Real, Nothing}) = 1 -arity(a::Sym{Real, Base.ImmutableDict{DataType,Any}}) = 1 -arity(a::SymbolicUtils.Add) = length(a.dict) + (~iszero(a.coeff)) -arity(a::SymbolicUtils.Mul) = length(a.dict) + (~isone(a.coeff)) -arity(a::SymbolicUtils.Pow) = 2 -arity(a::SymbolicUtils.Div) = 2 - +arity(a::Num) = arity(a.val) op(a::Equation) = op(a.rhs) -op(::SymbolicUtils.Add) = + -op(::SymbolicUtils.Mul) = * -op(::SymbolicUtils.Pow) = ^ -op(::SymbolicUtils.Div) = / -op(::Term{Real, Base.ImmutableDict{DataType,Any}}) = nothing -op(a::Term{Real, Nothing}) = a.f -op(a::Sym{Real, Base.ImmutableDict{DataType,Any}}) = getindex - +op(a::Num) = op(a.val) + +# Helpful classification checker to differentiate between terms like +# "exp(x)" where x is a variable, and terms like "y(t)" where y and t +# are both variables +varterm(a::BasicSymbolic) = typeof(a.f)<:BasicSymbolic ? true : false + +# Informational functions +function arity(a::BasicSymbolic) + exprtype(a)==SYM && return 1 + exprtype(a)==TERM && return 1 + exprtype(a)==ADD && return length(a.dict) + (~iszero(a.coeff)) + exprtype(a)==MUL && return length(a.dict) + (~isone(a.coeff)) + exprtype(a)==POW && return 2 + exprtype(a)==DIV && return 2 +end +function op(a::BasicSymbolic) + exprtype(a)==SYM && return nothing + exprtype(a)==TERM && return varterm(a) ? nothing : a.f + exprtype(a)==ADD && return + + exprtype(a)==MUL && return * + exprtype(a)==POW && return ^ + exprtype(a)==DIV && return / +end + +# Component extraction functions xstr(a::Equation) = sub_1(a.rhs) ystr(a::Equation) = sub_2(a.rhs) zstr(a::Equation) = a.lhs -sub_1(a::Term{Real, Base.ImmutableDict{DataType,Any}}) = a -sub_1(a::Sym{Real, Base.ImmutableDict{DataType,Any}}) = a -function sub_1(a::SymbolicUtils.Add) - sorted_dict = sort(collect(a.dict), by=x->string(x[1])) - return sorted_dict[1].first -end -function sub_2(a::SymbolicUtils.Add) - ~(iszero(a.coeff)) && return a.coeff - sorted_dict = sort(collect(a.dict), by=x->string(x[1])) - return sorted_dict[2].first -end - -function sub_1(a::SymbolicUtils.Mul) - sorted_dict = sort(collect(a.dict), by=x->string(x[1])) - return sorted_dict[1].first -end -function sub_2(a::SymbolicUtils.Mul) - ~(isone(a.coeff)) && return a.coeff - sorted_dict = sort(collect(a.dict), by=x->string(x[1])) - return sorted_dict[2].first -end -function sub_1(a::SymbolicUtils.Div) - return a.num -end -function sub_2(a::SymbolicUtils.Div) - return a.den -end -function sub_1(a::SymbolicUtils.Pow) - return a.base -end -function sub_2(a::SymbolicUtils.Pow) - return a.exp -end - -function sub_1(a::Term{Real, Nothing}) - if a.f==getindex +function sub_1(a::BasicSymbolic) + if exprtype(a)==SYM return a - else + elseif exprtype(a)==TERM + varterm(a) || a.f==getindex && return a return a.arguments[1] + elseif exprtype(a)==ADD + sorted_dict = sort(collect(a.dict), by=x->string(x[1])) + return sorted_dict[1].first + elseif exprtype(a)==MUL + sorted_dict = sort(collect(a.dict), by=x->string(x[1])) + return sorted_dict[1].first + elseif exprtype(a)==DIV + return a.num + elseif exprtype(a)==POW + return a.base + end +end +function sub_2(a::BasicSymbolic) + if exprtype(a)==SYM + return nothing + elseif exprtype(a)==TERM + varterm(a) || a.f==getindex && return nothing + return a.arguments[2] + elseif exprtype(a)==ADD + ~(iszero(a.coeff)) && return a.coeff + sorted_dict = sort(collect(a.dict), by=x->string(x[1])) + return sorted_dict[2].first + elseif exprtype(a)==MUL + ~(isone(a.coeff)) && return a.coeff + sorted_dict = sort(collect(a.dict), by=x->string(x[1])) + return sorted_dict[2].first + elseif exprtype(a)==DIV + return a.den + elseif exprtype(a)==POW + return a.exp end end -sub_2(a::Term{Real, Nothing}) = a.arguments[2] -# Uses Symbolics functions to generate a variable as a function of the dependent variables of choice (default: t) -function genvar(a::Symbol) - @isdefined(t) ? genvar(a, :t) : genparam(a) -end +# Uses Symbolics functions to generate a variable as a function of the dependent variables of choice function genvar(a::Symbol, b::Symbol) vars = Symbol[] ex = Expr(:block) @@ -89,6 +94,9 @@ function genvar(a::Symbol, b::Vector{Symbol}) push!(ex.args, rhs) eval(ex)[1] end + +# If no variables are given, instead create a parameter +genvar(a::Symbol) = genparam(a) function genparam(a::Symbol) params = Symbol[] ex = Expr(:block) @@ -101,6 +109,7 @@ function genparam(a::Symbol) end +# A function to extract terms from a set of equations, for use in dynamic systems function extract_terms(eqs::Vector{Equation}) allstates = SymbolicUtils.OrderedSet() ps = SymbolicUtils.OrderedSet() @@ -204,14 +213,6 @@ function get_cvcc_start_dict(sys::ODESystem, term::Num, start_point::Float64) end -# Side note: this is how you can get a and b to show up NOT as a(t) and b(t) -# t = genparam(:t) -# a = genvar(:a) -# b = genvar(:b) -# st = SymbolicUtils.Code.NameState(Dict(a => :a, b => :b)) -# toexpr(a+b, st) - - """ pull_vars(::Num) pull_vars(::Vector{Num}) @@ -268,49 +269,32 @@ function pull_vars(eqns::Vector{Equation}) return vars end -function _pull_vars(term::SymbolicUtils.Add, vars::Vector{Num}, strings::Vector{String}) - args = arguments(term) - for arg in args - if (typeof(arg) <: Sym{Real, Base.ImmutableDict{DataType, Any}}) - if ~(string(arg) in strings) - push!(strings, string(arg)) - push!(vars, arg) - end - elseif (typeof(arg) <: Int) || (typeof(arg) <: AbstractFloat) - nothing - else - vars, strings = _pull_vars(arg, vars, strings) +function _pull_vars(term::BasicSymbolic, vars::Vector{Num}, strings::Vector{String}) + if exprtype(term)==SYM + if ~(string(term) in strings) + push!(strings, string(term)) + push!(vars, term) + return vars, strings end + return vars, strings end - return vars, strings -end - -function _pull_vars(term::SymbolicUtils.Mul, vars::Vector{Num}, strings::Vector{String}) - args = arguments(term) - for arg in args - if (typeof(arg) <: Sym{Real, Base.ImmutableDict{DataType, Any}}) - if ~(string(arg) in strings) - push!(strings, string(arg)) - push!(vars, arg) - end - elseif (typeof(arg) <: Int) || (typeof(arg) <: AbstractFloat) - nothing - else - vars, strings = _pull_vars(arg, vars, strings) + if exprtype(term)==TERM && varterm(term) + if ~(string(term.f) in strings) + push!(strings, string(term.f)) + push!(vars, term) + return vars, strings end + return vars, strings end - return vars, strings -end - -function _pull_vars(term::SymbolicUtils.Div, vars::Vector{Num}, strings::Vector{String}) args = arguments(term) for arg in args - if (typeof(arg) <: Sym{Real, Base.ImmutableDict{DataType, Any}}) + ~(typeof(arg)<:BasicSymbolic) ? continue : nothing + if exprtype(arg)==SYM if ~(string(arg) in strings) push!(strings, string(arg)) push!(vars, arg) end - elseif (typeof(arg) <: Int) || (typeof(arg) <: AbstractFloat) + elseif typeof(arg) <: Real nothing else vars, strings = _pull_vars(arg, vars, strings) @@ -319,63 +303,6 @@ function _pull_vars(term::SymbolicUtils.Div, vars::Vector{Num}, strings::Vector{ return vars, strings end -function _pull_vars(term::SymbolicUtils.Pow, vars::Vector{Num}, strings::Vector{String}) - args = arguments(term) - for arg in args - if (typeof(arg) <: Sym{Real, Base.ImmutableDict{DataType, Any}}) - if ~(string(arg) in strings) - push!(strings, string(arg)) - push!(vars, arg) - end - elseif (typeof(arg) <: Int) || (typeof(arg) <: AbstractFloat) - nothing - else - vars, strings = _pull_vars(arg, vars, strings) - end - end - return vars, strings -end - -function _pull_vars(term::SymbolicUtils.Term{Real, Nothing}, vars::Vector{Num}, strings::Vector{String}) - args = arguments(term) - for arg in args - if (typeof(arg) <: Sym{Real, Base.ImmutableDict{DataType, Any}}) - if ~(string(arg) in strings) - push!(strings, string(arg)) - push!(vars, arg) - end - elseif (typeof(arg) <: Int) || (typeof(arg) <: AbstractFloat) - nothing - else - vars, strings = _pull_vars(arg, vars, strings) - end - end - return vars, strings -end - -function _pull_vars(term::SymbolicUtils.Term{Bool, Nothing}, vars::Vector{Num}, strings::Vector{String}) - args = arguments(term) - for arg in args - if (typeof(arg) <: Sym{Real, Base.ImmutableDict{DataType, Any}}) - if ~(string(arg) in strings) - push!(strings, string(arg)) - push!(vars, arg) - end - elseif (typeof(arg) <: Int) || (typeof(arg) <: AbstractFloat) - nothing - else - vars, strings = _pull_vars(arg, vars, strings) - end - end - return vars, strings -end - -function _pull_vars(term::SymbolicUtils.Term{Float64, Nothing}, vars::Vector{Num}, strings::Vector{String}) - return vars, strings -end -function _pull_vars(term::Rational{Int64}, vars::Vector{Num}, strings::Vector{String}) - return vars, strings -end """ shrink_eqs(::Vector{Equation}) @@ -475,7 +402,7 @@ function convex_evaluator(term::Num; force::Bool=false) # huge time savings by separating out the expression using the knowledge # that the sum of convex relaxations is equal to the convex relaxation # of the sum (i.e., a_cv + b_cv = (a+b)_cv, and same for lo/hi/cc) - if typeof(term.val) <: SymbolicUtils.Add + if exprtype(term.val) == ADD # Start with any real-valued operands [if present] cv_eqn = term.val.coeff @@ -532,7 +459,7 @@ end function convex_evaluator(equation::Equation; force::Bool=false) # Same as when the input is `Num`, but we have to deal with the input # already being an equation (whose LHS is irrelevant) - if typeof(equation.rhs.val) <: SymbolicUtils.Add + if exprtype(equation.rhs.val) == ADD cv_eqn = equation.rhs.val.coeff for (key,val) in equation.rhs.val.dict new_equation = 0 ~ (val*key) @@ -568,7 +495,7 @@ four functions (representing lower bound, upper bound, convex relaxation, and concave relaxation evaluation functions) [lo, hi, cv, cc] and the order vector. """ function all_evaluators(term::Num; force::Bool=false) - if typeof(term.val) <: SymbolicUtils.Add + if exprtype(term.val) == ADD lo_eqn = term.val.coeff hi_eqn = term.val.coeff cv_eqn = term.val.coeff @@ -606,7 +533,7 @@ function all_evaluators(term::Num; force::Bool=false) return lo_evaluator, hi_evaluator, cv_evaluator, cc_evaluator, ordered_vars end function all_evaluators(equation::Equation; force::Bool=false) - if typeof(equation.rhs) <: SymbolicUtils.Add + if exprtype(equation.rhs) == ADD lo_eqn = equation.rhs.coeff hi_eqn = equation.rhs.coeff cv_eqn = equation.rhs.coeff