Skip to content

Commit

Permalink
add top-level .JuliaFormatter.toml, format sovler
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha Demin committed Aug 17, 2024
1 parent 5500727 commit b5e6a79
Show file tree
Hide file tree
Showing 11 changed files with 214 additions and 208 deletions.
2 changes: 2 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style = "sciml"
format_markdown = true
93 changes: 53 additions & 40 deletions src/solver/attract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function detect_addlogs(lhs, var)
args = arguments(u_lhs)
oper = operation(u_lhs)
!isequal(oper, (+)) && return false

found = [false, false]
c = 1
for arg in args
Expand Down Expand Up @@ -86,7 +86,7 @@ function detect_exponential(lhs, var)
end
end

return all(found)
return all(found)
end

"""
Expand All @@ -109,13 +109,18 @@ Main.RootFinding.slog(-1 + x^2)
function attract_logs(lhs, var)
contains_var(arg) = n_occurrences(arg, var) > 0

r_addlogs = Vector{Any}()
push!(r_addlogs, @acrule log(~x::(contains_var)) + log(~y::(contains_var)) => slog(~x * ~y))
push!(r_addlogs, @acrule ~z*log(~x::(contains_var)) + log(~y::(contains_var)) => slog((~x)^(~z) * ~y))
push!(r_addlogs, @acrule ~z*log(~x::(contains_var)) + ~h*log(~y::(contains_var)) => slog((~x)^(~z) * (~y)^(~h)))

r_addlogs = Vector{Any}()
push!(r_addlogs,
@acrule log(~x::(contains_var)) + log(~y::(contains_var)) => slog(~x * ~y))
push!(r_addlogs,
@acrule ~z * log(~x::(contains_var)) + log(~y::(contains_var)) => slog((~x)^(~z) *
~y))
push!(r_addlogs,
@acrule ~z * log(~x::(contains_var)) + ~h * log(~y::(contains_var)) => slog((~x)^(~z) *
(~y)^(~h)))

lhs = expand(simplify(lhs, rewriter=SymbolicUtils.Postwalk(SymbolicUtils.Chain(r_addlogs))))
lhs = expand(simplify(
lhs, rewriter = SymbolicUtils.Postwalk(SymbolicUtils.Chain(r_addlogs))))

return lhs
end
Expand Down Expand Up @@ -143,11 +148,23 @@ function attract_exponential(lhs, var)
contains_var(arg) = n_occurrences(arg, var) > 0

r_addexpon = Vector{Any}()
push!(r_addexpon, @acrule (~b)^(~f::(contains_var)) + (~d)^(~g::(contains_var)) => ~f*term(slog, ~b) - ~g*term(slog, ~d) + term(log, term(complex, -1)))
push!(r_addexpon, @acrule (~a)*(~b)^(~f::(contains_var)) + (~d)^(~g::(contains_var)) => ~f*term(slog, ~b) - ~g*term(slog, ~d) + term(slog, -~a))
push!(r_addexpon, @acrule (~a)*(~b)^(~f::(contains_var)) + (~c)*(~d)^(~g::(contains_var)) => ~f*term(slog, ~b) - ~g*term(slog, ~d) + term(slog, -(~a)//(~c)))

lhs = expand(simplify(lhs, rewriter=SymbolicUtils.Postwalk(SymbolicUtils.Chain(r_addexpon))))
push!(r_addexpon,
@acrule (~b)^(~f::(contains_var)) + (~d)^(~g::(contains_var)) => ~f *
term(slog, ~b) -
~g *
term(slog, ~d) +
term(
log, term(complex, -1)))
push!(r_addexpon,
@acrule (~a) * (~b)^(~f::(contains_var)) + (~d)^(~g::(contains_var)) => ~f * term(
slog, ~b) - ~g * term(slog, ~d) + term(slog, -~a))
push!(r_addexpon,
@acrule (~a) * (~b)^(~f::(contains_var)) + (~c) * (~d)^(~g::(contains_var)) => ~f *
term(
slog, ~b) - ~g * term(slog, ~d) + term(slog, -(~a) // (~c)))

lhs = expand(simplify(
lhs, rewriter = SymbolicUtils.Postwalk(SymbolicUtils.Chain(r_addexpon))))

return expand(lhs)
end
Expand Down Expand Up @@ -179,33 +196,29 @@ function attract_trig(lhs, var)
contains_var(arg) = n_occurrences(arg, var) > 0

# r_doubleangle1 = @acrule 2*sin(~x::(contains_var))*cos(~x::(contains_var)) => sin(2*~x)
r_trig = [
@acrule(sin(~x::(contains_var))^2 + cos(~x::(contains_var))^2 => one(~x))
@acrule(sin(~x::(contains_var))^2 + -1 => -1*cos(~x)^2)
@acrule(cos(~x::(contains_var))^2 + -1 => -1*sin(~x)^2)

@acrule(cos(~x::(contains_var))^2 + -1*sin(~x::(contains_var))^2 => cos(2 * ~x))
@acrule(sin(~x::(contains_var))^2 + -1*cos(~x::(contains_var))^2 => -cos(2 * ~x))
@acrule(cos(~x::(contains_var)) * sin(~x::(contains_var)) => sin(2 * ~x)/2)

@acrule(tan(~x::(contains_var))^2 + -1*sec(~x::(contains_var))^2 => one(~x))
@acrule(-1*tan(~x::(contains_var))^2 + sec(~x::(contains_var))^2 => one(~x))
@acrule(tan(~x::(contains_var))^2 + 1 => sec(~x)^2)
@acrule(sec(~x::(contains_var))^2 + -1 => tan(~x)^2)

@acrule(cot(~x::(contains_var))^2 + -1*csc(~x)^2 => one(~x))
@acrule(cot(~x::(contains_var))^2 + 1 => csc(~x)^2)
@acrule(csc(~x::(contains_var))^2 + -1 => cot(~x)^2)

@acrule(cosh(~x::(contains_var))^2 + -1*sinh(~x)^2 => one(~x))
@acrule(cosh(~x::(contains_var))^2 + -1 => sinh(~x)^2)
@acrule(sinh(~x::(contains_var))^2 + 1 => cosh(~x)^2)

@acrule(cosh(~x::(contains_var))^2 + sinh(~x::(contains_var))^2 => cosh(2 * ~x))
@acrule(cosh(~x::(contains_var)) * sinh(~x::(contains_var)) => sinh(2 * ~x)/2)
]

lhs = expand(simplify(lhs, rewriter=SymbolicUtils.Postwalk(SymbolicUtils.Chain(r_trig))))
r_trig = [@acrule(sin(~x::(contains_var))^2 + cos(~x::(contains_var))^2=>one(~x))
@acrule(sin(~x::(contains_var))^2 + -1=>-1 * cos(~x)^2)
@acrule(cos(~x::(contains_var))^2 + -1=>-1 * sin(~x)^2)
@acrule(cos(~x::(contains_var))^2 + -1 * sin(~x::(contains_var))^2=>cos(2 *
~x))
@acrule(sin(~x::(contains_var))^2 + -1 * cos(~x::(contains_var))^2=>-cos(2 *
~x))
@acrule(cos(~x::(contains_var)) * sin(~x::(contains_var))=>sin(2 * ~x) / 2)
@acrule(tan(~x::(contains_var))^2 + -1 * sec(~x::(contains_var))^2=>one(~x))
@acrule(-1 * tan(~x::(contains_var))^2 + sec(~x::(contains_var))^2=>one(~x))
@acrule(tan(~x::(contains_var))^2 + 1=>sec(~x)^2)
@acrule(sec(~x::(contains_var))^2 + -1=>tan(~x)^2)
@acrule(cot(~x::(contains_var))^2 + -1 * csc(~x)^2=>one(~x))
@acrule(cot(~x::(contains_var))^2 + 1=>csc(~x)^2)
@acrule(csc(~x::(contains_var))^2 + -1=>cot(~x)^2)
@acrule(cosh(~x::(contains_var))^2 + -1 * sinh(~x)^2=>one(~x))
@acrule(cosh(~x::(contains_var))^2 + -1=>sinh(~x)^2)
@acrule(sinh(~x::(contains_var))^2 + 1=>cosh(~x)^2)
@acrule(cosh(~x::(contains_var))^2 + sinh(~x::(contains_var))^2=>cosh(2 * ~x))
@acrule(cosh(~x::(contains_var)) * sinh(~x::(contains_var))=>sinh(2 * ~x) / 2)]

lhs = expand(simplify(
lhs, rewriter = SymbolicUtils.Postwalk(SymbolicUtils.Chain(r_trig))))

return lhs
end
50 changes: 27 additions & 23 deletions src/solver/ia_helpers.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function n_occurrences(expr, var)
n = 0
!iscall(unwrap(expr)) && any(isequal(var, x) for x in get_variables(expr)) && return 1
!iscall(unwrap(expr)) && return 0
!iscall(unwrap(expr)) && return 0

args = arguments(unwrap(expr))

Expand Down Expand Up @@ -30,12 +30,11 @@ function n_func_occ(expr, var)
expr = unwrap(expr)
!iscall(expr) && return n_occurrences(expr, var)
args, cur_oper = arguments(expr), operation(expr)
counted_ops = [sqrt, cbrt, sin, log, log2, log10, cos, tan, asin, acos, atan, exp, ssqrt, scbrt, slog]
counted_ops = [sqrt, cbrt, sin, log, log2, log10, cos, tan,
asin, acos, atan, exp, ssqrt, scbrt, slog]
n = 0


if cur_oper === (*) || cur_oper === (+)

outside = false
for arg in args
n_occurrences(arg, var) == 0 && continue
Expand All @@ -52,35 +51,42 @@ function n_func_occ(expr, var)

args_arg = arguments(arg)
oper_arg = operation(arg)
is_var_outside(arg) = check_poly_inunivar(arg, var) && !outside && n_occurrences(arg, var) != 0
case_1_pow = oper_arg === (^) && n_occurrences(args_arg[2], var) == 0 && n_occurrences(args_arg[1], var) != 0 && check_poly_inunivar(args_arg[1], var) && n_occurrences(arg, var) != 0 && !(args_arg[2] isa Number)
case_2_pow = oper_arg === (^) && n_occurrences(args_arg[2], var) != 0 && n_occurrences(args_arg[1], var) == 0
case_3_pow = oper_arg === (^) && n_occurrences(args_arg[2], var) == 0 && n_occurrences(args_arg[1], var) != 0 && !check_poly_inunivar(args_arg[1], var)

function is_var_outside(arg)
check_poly_inunivar(arg, var) && !outside && n_occurrences(arg, var) != 0
end
case_1_pow = oper_arg === (^) && n_occurrences(args_arg[2], var) == 0 &&
n_occurrences(args_arg[1], var) != 0 &&
check_poly_inunivar(args_arg[1], var) &&
n_occurrences(arg, var) != 0 && !(args_arg[2] isa Number)
case_2_pow = oper_arg === (^) && n_occurrences(args_arg[2], var) != 0 &&
n_occurrences(args_arg[1], var) == 0
case_3_pow = oper_arg === (^) && n_occurrences(args_arg[2], var) == 0 &&
n_occurrences(args_arg[1], var) != 0 &&
!check_poly_inunivar(args_arg[1], var)

# any transcedental operation and the case: (weird_transcedental_f(x))^(something)
if any(isequal(oper, op) for op in counted_ops) || case_3_pow
n += n_func_occ(args_arg[1], var)
# the case (some constant)^(f(x))

# the case (some constant)^(f(x))
elseif case_2_pow
n += n_func_occ(args_arg[2], var)
# var is outside 'x'+1
n += n_func_occ(args_arg[2], var)

# var is outside 'x'+1
elseif is_var_outside(arg)
n += 1
outside = true

# case (f(x))^(weird stuff)
# case (f(x))^(weird stuff)
elseif case_1_pow
n += 1

# n(2 / x) = 1; n(x/x^2) = 2?
# n(2 / x) = 1; n(x/x^2) = 2?
elseif oper_arg === (/)
n += n_func_occ(args_arg[1], var)
n += n_func_occ(args_arg[2], var)

# multiplication cases
# multiplication cases
elseif oper_arg === (*)
args_arg = arguments(arg)

Expand All @@ -89,7 +95,7 @@ function n_func_occ(expr, var)
if is_var_outside(sub_arg)
n += 1
outside = true
# log(x)*y
# log(x)*y
elseif !check_poly_inunivar(sub_arg, var)
n += n_func_occ(sub_arg, var)
end
Expand All @@ -102,11 +108,10 @@ function n_func_occ(expr, var)
n += n_func_occ(arg, var)
end
end

return n
end


function arg_contains_log(arg, var)
oper = operation(arg)
isequal(oper, log) && return true
Expand All @@ -118,15 +123,14 @@ function arg_contains_log(arg, var)
return false
end


function find_logandexpon(arg, var, oper, poly_index)
args_arg = arguments(arg)

oper_term, constant_term = 0, 0
oper_term, constant_term = 0, 0

for a in args_arg
if n_occurrences(a, var) != 0 && iscall(a) && operation(a) == (oper) &&
check_poly_inunivar(arguments(a)[poly_index], var)
check_poly_inunivar(arguments(a)[poly_index], var)
oper_term = a
elseif n_occurrences(a, var) == 0
constant_term = a
Expand Down
Loading

0 comments on commit b5e6a79

Please sign in to comment.