Skip to content

Commit

Permalink
Merge pull request #375 from SciML/export_fix
Browse files Browse the repository at this point in the history
Fixing function for exporting into other formats
  • Loading branch information
pogudingleb authored Dec 21, 2024
2 parents 4655510 + 46e9289 commit 3519c17
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 22 deletions.
64 changes: 42 additions & 22 deletions src/ODEexport.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
function __add_t(s)
if endswith(s, "(t)")
return s
end
return s * "(t)"
end

function __remove_t(s)
if endswith(s, "(t)")
return s[1:(end - 3)]
end
return s
end

"""
print_for_maple(ode, package)
Expand All @@ -7,16 +21,17 @@ Prints the ODE in the format accepted by maple packages
- DifferentialThomas if package=:DifferentialThomas
"""
function print_for_maple(ode::ODE, package = :SIAN)
varstr =
Dict(x => var_to_str(x) * "(t)" for x in vcat(ode.x_vars, ode.u_vars, ode.y_vars))
varstr = Dict(
x => (__add_t var_to_str)(x) for x in vcat(ode.x_vars, ode.u_vars, ode.y_vars)
)
merge!(varstr, Dict(p => var_to_str(p) for p in ode.parameters))
R_print, vars_print = Nemo.polynomial_ring(
base_ring(ode.poly_ring),
[varstr[v] for v in gens(ode.poly_ring)],
)
x_names = join(map(var_to_str, ode.x_vars), ", ")
y_names = join(map(var_to_str, ode.y_vars), ", ")
u_names = join(map(var_to_str, ode.u_vars), ", ")
x_names = join(map(__remove_t var_to_str, ode.x_vars), ", ")
y_names = join(map(__remove_t var_to_str, ode.y_vars), ", ")
u_names = join(map(__remove_t var_to_str, ode.u_vars), ", ")
u_string = length(u_names) > 0 ? ", [$u_names]" : ""
ranking = "[$x_names], [$y_names] $u_string"

Expand All @@ -39,7 +54,7 @@ function print_for_maple(ode::ODE, package = :SIAN)
end

# Form the equations
function _rhs_to_str(rhs)
function _rhs_to_str_t(rhs)
num, den = unpack_fraction(rhs)
res = string(evaluate(num, vars_print))
if den != 1
Expand All @@ -50,10 +65,10 @@ function print_for_maple(ode::ODE, package = :SIAN)

eqs = []
for (x, f) in ode.x_equations
push!(eqs, ("diff(" * var_to_str(x) * "(t), t)", _rhs_to_str(f)))
push!(eqs, ("diff(" * (__add_t var_to_str)(x) * ", t)", _rhs_to_str_t(f)))
end
for (y, g) in ode.y_equations
push!(eqs, (var_to_str(y) * "(t)", _rhs_to_str(g)))
push!(eqs, ((__add_t var_to_str)(y), _rhs_to_str_t(g)))
end
if package == :SIAN
result *= join(map(a -> a[1] * " = " * a[2], eqs), ",\n") * "\n];\n"
Expand All @@ -78,13 +93,13 @@ end

#------------------------------------------------------------------------------

function _rhs_to_str(lhs)
function _rhs_to_str_not(lhs)
num, den = unpack_fraction(lhs)
rslt = string(num)
if den != 1
rslt = "($rslt) / ($den)"
end
return rslt
return replace(rslt, "(t)" => "")
end

"""
Expand All @@ -99,7 +114,7 @@ function print_for_DAISY(ode::ODE)
result =
result *
"B_:={" *
join(map(var_to_str, vcat(ode.u_vars, ode.y_vars, ode.x_vars)), ", ") *
join(map(__remove_t var_to_str, vcat(ode.u_vars, ode.y_vars, ode.x_vars)), ", ") *
"}\$\n"

result = result * "FOR EACH EL_ IN B_ DO DEPEND EL_,T\$\n\n"
Expand All @@ -114,10 +129,10 @@ function print_for_DAISY(ode::ODE)
eqs = []

for (x, f) in ode.x_equations
push!(eqs, "df($(var_to_str(x)), t) = " * _rhs_to_str(f))
push!(eqs, "df($((__remove_t var_to_str)(x)), t) = " * _rhs_to_str_not(f))
end
for (y, g) in ode.y_equations
push!(eqs, "$(var_to_str(y)) = " * _rhs_to_str(g))
push!(eqs, "$((__remove_t var_to_str)(y)) = " * _rhs_to_str_not(g))
end

result = result * "C_:={" * join(eqs, ",\n") * "}\$\n"
Expand All @@ -138,36 +153,41 @@ Prints the ODE in the format accepted by GenSSI 2.0 (https://github.com/genssi-d
function print_for_GenSSI(ode::ODE)
result = "function model = SMTH()\n"

result *= "syms " * join(var_to_str.(ode.x_vars), " ") * "\n"
result *= "syms " * join((__remove_t var_to_str).(ode.x_vars), " ") * "\n"
result *= "syms " * join(var_to_str.(ode.parameters), " ") * "\n"
result *= "syms " * join(map(v -> var_to_str(v) * "0", ode.x_vars), " ") * "\n"
result *=
"syms " * join(map(v -> (__remove_t var_to_str)(v) * "0", ode.x_vars), " ") * "\n"
if length(ode.u_vars) > 0
result *= "syms " * join(var_to_str.(ode.u_vars), " ") * "\n"
result *= "syms " * join((__remove_t var_to_str).(ode.u_vars), " ") * "\n"
end

result *= "model.sym.p = [" * join(var_to_str.(ode.parameters), "; ") * "; "
result *= join(map(v -> var_to_str(v) * "0", ode.x_vars), "; ") * "];\n"
result *= join(map(v -> (__remove_t var_to_str)(v) * "0", ode.x_vars), "; ") * "];\n"

result *= "model.sym.x = [" * join(var_to_str.(ode.x_vars), "; ") * "];\n"
result *= "model.sym.g = [" * join(var_to_str.(ode.u_vars), "; ") * "];\n"
result *=
"model.sym.x = [" * join((__remove_t var_to_str).(ode.x_vars), "; ") * "];\n"
result *=
"model.sym.g = [" * join((__remove_t var_to_str).(ode.u_vars), "; ") * "];\n"

result *=
"model.sym.x0 = [" * join(map(v -> var_to_str(v) * "0", ode.x_vars), "; ") * "];\n"
"model.sym.x0 = [" *
join(map(v -> (__remove_t var_to_str)(v) * "0", ode.x_vars), "; ") *
"];\n"

result *= "model.sym.xdot = ["

eqs = []
for x in ode.x_vars
f = ode.x_equations[x]
push!(eqs, _rhs_to_str(f))
push!(eqs, _rhs_to_str_not(f))
end
result *= join(eqs, "\n") * "];\n"

result *= "model.sym.y = ["
eqs = []
for y in ode.y_vars
g = ode.y_equations[y]
push!(eqs, _rhs_to_str(g))
push!(eqs, _rhs_to_str_not(g))
end
result *= join(eqs, "\n") * "];\n"

Expand Down
14 changes: 14 additions & 0 deletions test/exports.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@testset "Exporting to other formats" begin
ode = @ODEmodel(x'(t) = a * x(t) + b * u(t)^2, y(t) = 1 // x(t))
@test print_for_maple(ode) ==
"read \"../IdentifiabilityODE.mpl\";\n\nsys := [\ndiff(x(t), t) = x(t)*a + u(t)^2*b,\ny(t) = (1) / (x(t))\n];\nCodeTools[CPUTime](IdentifiabilityODE(sys, GetParameters(sys)));"
@test print_for_maple(ode, :DifferentialAlgebra) ==
"with(DifferentialAlgebra):\nring_diff := DifferentialRing(blocks = [[x], [y] , [u]], derivations = [t]):\nsys := [\ndiff(x(t), t) - (x(t)*a + u(t)^2*b),\ny(t) - ((1) / (x(t)))\n];\nres := CodeTools[CPUTime](RosenfeldGroebner(sys, ring_diff, singsol=none));"
@test print_for_maple(ode, :DifferentialThomas) ==
"with(DifferentialThomas):\nwith(Tools):\nRanking([t], [[x], [y] , [u]]):\nsys := [\ndiff(x(t), t) - (x(t)*a + u(t)^2*b),\ny(t) - ((1) / (x(t)))\n];\nres := CodeTools[CPUTime](ThomasDecomposition(sys));"
@test print_for_DAISY(ode) ==
"B_:={u, y, x}\$\nFOR EACH EL_ IN B_ DO DEPEND EL_,T\$\n\nB1_:={a, b}\$\n %NUMBER OF STATES\nNX_:=1\$\n %NUMBER OF INPUTS\nNU_:=1\$\n %NUMBER OF OUTPUTS\nNY_:=1\$\n\nC_:={df(x, t) = x*a + u^2*b,\ny = (1) / (x)}\$\nFLAG_:=1\$\nSHOWTIME\$\nDAISY()\$\nSHOWTIME\$\nEND\$\n"
@test print_for_GenSSI(ode) ==
"function model = SMTH()\nsyms x\nsyms a b\nsyms x0\nsyms u\nmodel.sym.p = [a; b; x0];\nmodel.sym.x = [x];\nmodel.sym.g = [u];\nmodel.sym.x0 = [x0];\nmodel.sym.xdot = [x*a + u^2*b];\nmodel.sym.y = [(1) / (x)];\nend"
@test print_for_COMBOS(ode) == "dx1/dt = x1*a + u1^2*b;\ny1 = (1) / (x1)"
end

0 comments on commit 3519c17

Please sign in to comment.