Skip to content

Commit

Permalink
rename and moved cross multiplication to an external function
Browse files Browse the repository at this point in the history
  • Loading branch information
n0rbed committed Nov 1, 2024
1 parent e20c268 commit 63b17a0
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ include("solver/polynomialization.jl")
include("solver/attract.jl")
include("solver/ia_main.jl")
include("solver/main.jl")
include("solver/new_feature.jl")
include("solver/ia_rules.jl")
export symbolic_solve

function symbolics_to_sympy end
Expand Down
62 changes: 42 additions & 20 deletions src/solver/new_feature.jl → src/solver/ia_rules.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,58 @@
function solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true)
vars = Symbolics.get_variables(eq)
vars = filter(v -> !isequal(v, s), vars)
vars = wrap.(vars)
function cross_multiply(eq)
og_oper = operation(unwrap(eq))
done = true
loop_add = false

term_tm = 1
done = false
if og_oper === (/)
done = false
args = arguments(unwrap(eq))
eq = wrap(args[1])
end

# do this until no / are present
while !done
args = arguments(unwrap(eq))
done = true
if og_oper === (+)
while !loop_add
args = arguments(unwrap(eq))
loop_add = true

for arg in args
!iscall(arg) && continue
oper = operation(arg)
if oper == (/)
done = false
term_tm *= wrap(arguments(arg)[2])
for arg in args
!iscall(arg) && continue
oper = operation(arg)
if oper == (/)
done = false
loop_add = false
term_tm *= wrap(arguments(arg)[2])
end
end
args = [arg*term_tm for arg in args]
eq = expand(Symbolics.term((+), unwrap.(args)...))
term_tm = 1
end
args = [arg*term_tm for arg in args]
eq = expand(Symbolics.term((+), unwrap.(args)...))
term_tm = 1
end

if done
return eq
else
return cross_multiply(eq)
end
end

function solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true)
@assert iscall(unwrap(eq))
vars = Symbolics.get_variables(eq)
vars = filter(v -> !isequal(v, s), vars)
vars = wrap.(vars)

term_tm = 1

eq = cross_multiply(eq)
coeffs, constant = polynomial_coeffs(eq, [s])
eqs = wrap.(collect(values(coeffs)))

symbolic_solve(eqs, vars)

solve_multivar(eqs, vars, dropmultiplicity=dropmultiplicity, warns=warns)
end


# an attempt at using ia_solve recursively.
function find_v(eqs, v, vars)
vars = filter(var -> !isequal(var, v), vars)
Expand Down

0 comments on commit 63b17a0

Please sign in to comment.