From 46e928982d3f6ee349939fd031ba0133ae146aa8 Mon Sep 17 00:00:00 2001 From: pogudingleb Date: Sat, 21 Dec 2024 22:20:02 +0100 Subject: [PATCH] actual fix --- src/ODEexport.jl | 64 +++++++++++++++++++++++++++++++----------------- 1 file changed, 42 insertions(+), 22 deletions(-) diff --git a/src/ODEexport.jl b/src/ODEexport.jl index c7489847..1f91e4d9 100644 --- a/src/ODEexport.jl +++ b/src/ODEexport.jl @@ -1,3 +1,17 @@ +function __add_t(s) + if endswith(s, "(t)") + return s + end + return s * "(t)" +end + +function __remove_t(s) + if endswith(s, "(t)") + return s[1:(end - 3)] + end + return s +end + """ print_for_maple(ode, package) @@ -7,16 +21,17 @@ Prints the ODE in the format accepted by maple packages - DifferentialThomas if package=:DifferentialThomas """ function print_for_maple(ode::ODE, package = :SIAN) - varstr = - Dict(x => var_to_str(x) * "(t)" for x in vcat(ode.x_vars, ode.u_vars, ode.y_vars)) + varstr = Dict( + x => (__add_t ∘ var_to_str)(x) for x in vcat(ode.x_vars, ode.u_vars, ode.y_vars) + ) merge!(varstr, Dict(p => var_to_str(p) for p in ode.parameters)) R_print, vars_print = Nemo.polynomial_ring( base_ring(ode.poly_ring), [varstr[v] for v in gens(ode.poly_ring)], ) - x_names = join(map(var_to_str, ode.x_vars), ", ") - y_names = join(map(var_to_str, ode.y_vars), ", ") - u_names = join(map(var_to_str, ode.u_vars), ", ") + x_names = join(map(__remove_t ∘ var_to_str, ode.x_vars), ", ") + y_names = join(map(__remove_t ∘ var_to_str, ode.y_vars), ", ") + u_names = join(map(__remove_t ∘ var_to_str, ode.u_vars), ", ") u_string = length(u_names) > 0 ? ", [$u_names]" : "" ranking = "[$x_names], [$y_names] $u_string" @@ -39,7 +54,7 @@ function print_for_maple(ode::ODE, package = :SIAN) end # Form the equations - function _rhs_to_str(rhs) + function _rhs_to_str_t(rhs) num, den = unpack_fraction(rhs) res = string(evaluate(num, vars_print)) if den != 1 @@ -50,10 +65,10 @@ function print_for_maple(ode::ODE, package = :SIAN) eqs = [] for (x, f) in ode.x_equations - push!(eqs, ("diff(" * var_to_str(x) * "(t), t)", _rhs_to_str(f))) + push!(eqs, ("diff(" * (__add_t ∘ var_to_str)(x) * ", t)", _rhs_to_str_t(f))) end for (y, g) in ode.y_equations - push!(eqs, (var_to_str(y) * "(t)", _rhs_to_str(g))) + push!(eqs, ((__add_t ∘ var_to_str)(y), _rhs_to_str_t(g))) end if package == :SIAN result *= join(map(a -> a[1] * " = " * a[2], eqs), ",\n") * "\n];\n" @@ -78,13 +93,13 @@ end #------------------------------------------------------------------------------ -function _rhs_to_str(lhs) +function _rhs_to_str_not(lhs) num, den = unpack_fraction(lhs) rslt = string(num) if den != 1 rslt = "($rslt) / ($den)" end - return rslt + return replace(rslt, "(t)" => "") end """ @@ -99,7 +114,7 @@ function print_for_DAISY(ode::ODE) result = result * "B_:={" * - join(map(var_to_str, vcat(ode.u_vars, ode.y_vars, ode.x_vars)), ", ") * + join(map(__remove_t ∘ var_to_str, vcat(ode.u_vars, ode.y_vars, ode.x_vars)), ", ") * "}\$\n" result = result * "FOR EACH EL_ IN B_ DO DEPEND EL_,T\$\n\n" @@ -114,10 +129,10 @@ function print_for_DAISY(ode::ODE) eqs = [] for (x, f) in ode.x_equations - push!(eqs, "df($(var_to_str(x)), t) = " * _rhs_to_str(f)) + push!(eqs, "df($((__remove_t ∘ var_to_str)(x)), t) = " * _rhs_to_str_not(f)) end for (y, g) in ode.y_equations - push!(eqs, "$(var_to_str(y)) = " * _rhs_to_str(g)) + push!(eqs, "$((__remove_t ∘ var_to_str)(y)) = " * _rhs_to_str_not(g)) end result = result * "C_:={" * join(eqs, ",\n") * "}\$\n" @@ -138,28 +153,33 @@ Prints the ODE in the format accepted by GenSSI 2.0 (https://github.com/genssi-d function print_for_GenSSI(ode::ODE) result = "function model = SMTH()\n" - result *= "syms " * join(var_to_str.(ode.x_vars), " ") * "\n" + result *= "syms " * join((__remove_t ∘ var_to_str).(ode.x_vars), " ") * "\n" result *= "syms " * join(var_to_str.(ode.parameters), " ") * "\n" - result *= "syms " * join(map(v -> var_to_str(v) * "0", ode.x_vars), " ") * "\n" + result *= + "syms " * join(map(v -> (__remove_t ∘ var_to_str)(v) * "0", ode.x_vars), " ") * "\n" if length(ode.u_vars) > 0 - result *= "syms " * join(var_to_str.(ode.u_vars), " ") * "\n" + result *= "syms " * join((__remove_t ∘ var_to_str).(ode.u_vars), " ") * "\n" end result *= "model.sym.p = [" * join(var_to_str.(ode.parameters), "; ") * "; " - result *= join(map(v -> var_to_str(v) * "0", ode.x_vars), "; ") * "];\n" + result *= join(map(v -> (__remove_t ∘ var_to_str)(v) * "0", ode.x_vars), "; ") * "];\n" - result *= "model.sym.x = [" * join(var_to_str.(ode.x_vars), "; ") * "];\n" - result *= "model.sym.g = [" * join(var_to_str.(ode.u_vars), "; ") * "];\n" + result *= + "model.sym.x = [" * join((__remove_t ∘ var_to_str).(ode.x_vars), "; ") * "];\n" + result *= + "model.sym.g = [" * join((__remove_t ∘ var_to_str).(ode.u_vars), "; ") * "];\n" result *= - "model.sym.x0 = [" * join(map(v -> var_to_str(v) * "0", ode.x_vars), "; ") * "];\n" + "model.sym.x0 = [" * + join(map(v -> (__remove_t ∘ var_to_str)(v) * "0", ode.x_vars), "; ") * + "];\n" result *= "model.sym.xdot = [" eqs = [] for x in ode.x_vars f = ode.x_equations[x] - push!(eqs, _rhs_to_str(f)) + push!(eqs, _rhs_to_str_not(f)) end result *= join(eqs, "\n") * "];\n" @@ -167,7 +187,7 @@ function print_for_GenSSI(ode::ODE) eqs = [] for y in ode.y_vars g = ode.y_equations[y] - push!(eqs, _rhs_to_str(g)) + push!(eqs, _rhs_to_str_not(g)) end result *= join(eqs, "\n") * "];\n"